diff --git a/.changeset/hot-trees-sing.md b/.changeset/hot-trees-sing.md new file mode 100644 index 000000000..74df1dd00 --- /dev/null +++ b/.changeset/hot-trees-sing.md @@ -0,0 +1,12 @@ +--- +'@modelcontextprotocol/express': patch +'@modelcontextprotocol/hono': patch +'@modelcontextprotocol/node': patch +'@modelcontextprotocol/eslint-config': patch +'@modelcontextprotocol/test-integration': patch +'@modelcontextprotocol/client': patch +'@modelcontextprotocol/server': patch +'@modelcontextprotocol/core': patch +--- + +add context API to tool, prompt, resource callbacks, linting diff --git a/CLAUDE.md b/CLAUDE.md index 0f6eaeece..68302a22b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -136,7 +136,7 @@ When a request arrives from the remote side: 2. **`Protocol.connect()`** routes to `_onrequest()`, `_onresponse()`, or `_onnotification()` 3. **`Protocol._onrequest()`**: - Looks up handler in `_requestHandlers` map (keyed by method name) - - Creates `RequestHandlerExtra` with `signal`, `sessionId`, `sendNotification`, `sendRequest` + - Creates a context object (`ServerContext` or `ClientContext`) via `createRequestContext()` - Invokes handler, sends JSON-RPC response back via transport 4. **Handler** was registered via `setRequestHandler(Schema, handler)` @@ -144,29 +144,51 @@ When a request arrives from the remote side: ```typescript // In Client (for serverβ†’client requests like sampling, elicitation) -client.setRequestHandler(CreateMessageRequestSchema, async (request, extra) => { +client.setRequestHandler(CreateMessageRequestSchema, async (request, ctx) => { // Handle sampling request from server return { role: "assistant", content: {...}, model: "..." }; }); // In Server (for clientβ†’server requests like tools/call) -server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { +server.setRequestHandler(CallToolRequestSchema, async (request, ctx) => { // Handle tool call from client return { content: [...] }; }); ``` -### Request Handler Extra +### Request Handler Context -The `extra` parameter in handlers (`RequestHandlerExtra`) provides: +The `ctx` parameter in handlers provides a structured context with three layers: -- `signal`: AbortSignal for cancellation +**`ctx.mcpCtx`** - MCP-level context: + +- `requestId`: JSON-RPC message ID +- `method`: The method being called +- `_meta`: Request metadata - `sessionId`: Transport session identifier + +**`ctx.requestCtx`** - Request-level context: + +- `signal`: AbortSignal for cancellation - `authInfo`: Validated auth token info (if authenticated) -- `requestId`: JSON-RPC message ID -- `sendNotification(notification)`: Send related notification back -- `sendRequest(request, schema)`: Send related request (for bidirectional flows) -- `taskStore`: Task storage interface (if tasks enabled) +- For server: `uri`, `headers`, `stream` (HTTP details) + +**`ctx.taskCtx`** - Task context (when tasks are enabled): + +- `id`: Current task ID (updates after `store.createTask()`) +- `store`: Request-scoped task store (`RequestTaskStore`) +- `requestedTtl`: Requested TTL for the task + +**Context methods**: + +- `ctx.sendNotification(notification)`: Send notification back +- `ctx.sendRequest(request, schema)`: Send request (for bidirectional flows) + +For server contexts, additional helpers: + +- `ctx.loggingNotification(level, data, logger)`: Send logging notification +- `ctx.requestSampling(params)`: Request sampling from client +- `ctx.elicitInput(params)`: Request user input from client ### Capability Checking @@ -197,7 +219,7 @@ const result = await server.createMessage({ }); // Client must have registered handler: -client.setRequestHandler(CreateMessageRequestSchema, async (request, extra) => { +client.setRequestHandler(CreateMessageRequestSchema, async (request, ctx) => { // Client-side LLM call return { role: "assistant", content: {...} }; }); @@ -208,8 +230,8 @@ client.setRequestHandler(CreateMessageRequestSchema, async (request, extra) => { ### Request Handler Registration (Low-Level Server) ```typescript -server.setRequestHandler(SomeRequestSchema, async (request, extra) => { - // extra contains sessionId, authInfo, sendNotification, etc. +server.setRequestHandler(SomeRequestSchema, async (request, ctx) => { + // ctx provides mcpCtx, requestCtx, taskCtx, sendNotification, sendRequest return { /* result */ }; @@ -219,7 +241,7 @@ server.setRequestHandler(SomeRequestSchema, async (request, extra) => { ### Tool Registration (High-Level McpServer) ```typescript -mcpServer.tool('tool-name', { param: z.string() }, async ({ param }, extra) => { +mcpServer.tool('tool-name', { param: z.string() }, async ({ param }, ctx) => { return { content: [{ type: 'text', text: 'result' }] }; }); ``` diff --git a/examples/client/src/elicitationUrlExample.ts b/examples/client/src/elicitationUrlExample.ts index 4ac59aa6a..a61ae8535 100644 --- a/examples/client/src/elicitationUrlExample.ts +++ b/examples/client/src/elicitationUrlExample.ts @@ -26,7 +26,7 @@ import { ErrorCode, getDisplayName, ListToolsResultSchema, - McpError, + ProtocolError, StreamableHTTPClientTransport, UnauthorizedError, UrlElicitationRequiredError @@ -339,7 +339,7 @@ async function handleElicitationRequest(request: ElicitRequest): Promise { console.log(`\n[${id}] Tool result:`); if (Array.isArray(result.content)) { for (const item of result.content) { - if (item.type === 'text' && item.text) { + if (isTextContent(item)) { console.log(` ${item.text}`); } else { console.log(` ${item.type} content:`, item); diff --git a/examples/client/src/parallelToolCallsClient.ts b/examples/client/src/parallelToolCallsClient.ts index fbd3910de..5114c6a22 100644 --- a/examples/client/src/parallelToolCallsClient.ts +++ b/examples/client/src/parallelToolCallsClient.ts @@ -2,6 +2,7 @@ import type { CallToolResult, ListToolsRequest } from '@modelcontextprotocol/cli import { CallToolResultSchema, Client, + isTextContent, ListToolsResultSchema, LoggingMessageNotificationSchema, StreamableHTTPClientTransport @@ -60,7 +61,7 @@ async function main(): Promise { for (const [caller, result] of Object.entries(toolResults)) { console.log(`\n=== Tool result for ${caller} ===`); for (const item of result.content) { - if (item.type === 'text') { + if (isTextContent(item)) { console.log(` ${item.text}`); } else { console.log(` ${item.type} content:`, item); diff --git a/examples/client/src/simpleOAuthClient.ts b/examples/client/src/simpleOAuthClient.ts index 23ea05a99..84791acb9 100644 --- a/examples/client/src/simpleOAuthClient.ts +++ b/examples/client/src/simpleOAuthClient.ts @@ -9,6 +9,7 @@ import type { CallToolRequest, ListToolsRequest, OAuthClientMetadata } from '@mo import { CallToolResultSchema, Client, + isTextContent, ListToolsResultSchema, StreamableHTTPClientTransport, UnauthorizedError @@ -315,7 +316,7 @@ class InteractiveOAuthClient { console.log(`\nπŸ”§ Tool '${toolName}' result:`); if (result.content) { for (const content of result.content) { - if (content.type === 'text') { + if (isTextContent(content)) { console.log(content.text); } else { console.log(content); @@ -396,7 +397,7 @@ class InteractiveOAuthClient { case 'result': { console.log('βœ“ Completed!'); for (const content of message.result.content) { - if (content.type === 'text') { + if (isTextContent(content)) { console.log(content.text); } else { console.log(content); diff --git a/examples/client/src/simpleStreamableHttp.ts b/examples/client/src/simpleStreamableHttp.ts index ced687027..d969f0b92 100644 --- a/examples/client/src/simpleStreamableHttp.ts +++ b/examples/client/src/simpleStreamableHttp.ts @@ -16,11 +16,12 @@ import { ErrorCode, getDisplayName, GetPromptResultSchema, + isTextContent, ListPromptsResultSchema, ListResourcesResultSchema, ListToolsResultSchema, LoggingMessageNotificationSchema, - McpError, + ProtocolError, ReadResourceResultSchema, RELATED_TASK_META_KEY, ResourceListChangedNotificationSchema, @@ -273,7 +274,7 @@ async function connect(url?: string): Promise { // Set up elicitation request handler with proper validation client.setRequestHandler(ElicitRequestSchema, async request => { if (request.params.mode !== 'form') { - throw new McpError(ErrorCode.InvalidParams, `Unsupported elicitation mode: ${request.params.mode}`); + throw new ProtocolError(ErrorCode.InvalidParams, `Unsupported elicitation mode: ${request.params.mode}`); } console.log('\nπŸ”” Elicitation (form) Request Received:'); console.log(`Message: ${request.params.message}`); @@ -737,7 +738,7 @@ async function runNotificationsToolWithResumability(interval: number, count: num console.log('Tool result:'); for (const item of result.content) { - if (item.type === 'text') { + if (isTextContent(item)) { console.log(` ${item.text}`); } else { console.log(` ${item.type} content:`, item); @@ -791,7 +792,7 @@ async function getPrompt(name: string, args: Record): Promise): Promis console.log('Task completed!'); console.log('Tool result:'); for (const item of message.result.content) { - if (item.type === 'text') { + if (isTextContent(item)) { console.log(` ${item.text}`); } } diff --git a/examples/client/src/simpleStreamableHttpBuilder.ts b/examples/client/src/simpleStreamableHttpBuilder.ts new file mode 100644 index 000000000..32d8ba8b1 --- /dev/null +++ b/examples/client/src/simpleStreamableHttpBuilder.ts @@ -0,0 +1,821 @@ +/* eslint-disable unicorn/no-process-exit */ +/** + * Simple Streamable HTTP Client Example using Builder Pattern + * + * This example demonstrates using the Client.builder() fluent API + * to create and configure an MCP client with: + * - Builder pattern configuration + * - Universal middleware (logging) + * - Outgoing middleware (retry logic) + * - Tool call middleware (instrumentation) + * - Sampling request handler + * - Elicitation request handler + * - Roots list handler + * - Error handlers (onError, onProtocolError) + * + * Run with: npx tsx src/simpleStreamableHttpBuilder.ts + */ + +import { createInterface } from 'node:readline'; + +import type { + CallToolRequest, + ClientMiddleware, + GetPromptRequest, + ListPromptsRequest, + ListResourcesRequest, + ListToolsRequest, + OutgoingMiddleware, + ReadResourceRequest, + ToolCallMiddleware +} from '@modelcontextprotocol/client'; +import { + CallToolResultSchema, + Client, + getDisplayName, + GetPromptResultSchema, + isTextContent, + ListPromptsResultSchema, + ListResourcesResultSchema, + ListToolsResultSchema, + LoggingMessageNotificationSchema, + ReadResourceResultSchema, + StreamableHTTPClientTransport +} from '@modelcontextprotocol/client'; + +// Create readline interface for user input +const readline = createInterface({ + input: process.stdin, + output: process.stdout +}); + +// Track received notifications +let notificationCount = 0; + +// Global client and transport +let client: Client | null = null; +let transport: StreamableHTTPClientTransport | null = null; +let serverUrl = 'http://localhost:3000/mcp'; +let sessionId: string | undefined; + +// ═══════════════════════════════════════════════════════════════════════════ +// Custom Middleware Examples +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Options for MCP client logging middleware. + */ +export interface ClientLoggingMiddlewareOptions { + /** Log level */ + level?: 'debug' | 'info' | 'warn' | 'error'; + /** Custom logger function */ + logger?: (level: string, message: string, data?: unknown) => void; +} + +/** + * Creates a logging middleware for MCP client operations. + * + * @example + * ```typescript + * client.useMiddleware(createClientLoggingMiddleware({ level: 'debug' })); + * ``` + */ +export function createClientLoggingMiddleware(options: ClientLoggingMiddlewareOptions = {}): ClientMiddleware { + const { level = 'info', logger = console.log } = options; + + return async (ctx, next) => { + logger(level, `${ctx.direction} ${ctx.type}: ${ctx.method}`, { + direction: ctx.direction, + type: ctx.type, + method: ctx.method, + requestId: ctx.requestId + }); + + const start = Date.now(); + + try { + const result = await next(); + const duration = Date.now() - start; + logger(level, `← ${ctx.type}: ${ctx.method} (${duration}ms)`, { + direction: ctx.direction, + type: ctx.type, + method: ctx.method, + requestId: ctx.requestId, + duration + }); + return result; + } catch (error) { + const duration = Date.now() - start; + logger('error', `βœ— ${ctx.type}: ${ctx.method} (${duration}ms)`, { + direction: ctx.direction, + type: ctx.type, + method: ctx.method, + requestId: ctx.requestId, + duration, + error + }); + throw error; + } + }; +} + +/** + * Options for retry middleware. + */ +interface RetryMiddlewareOptions { + /** Maximum number of retries */ + maxRetries?: number; + /** Base delay between retries in ms */ + baseDelay?: number; + /** Function to determine if an error is retryable */ + isRetryable?: (error: unknown) => boolean; +} +/** + * Creates a retry middleware for outgoing MCP requests. + * + * @example + * ```typescript + * client.useOutgoingMiddleware(createRetryMiddleware({ + * maxRetries: 3, + * baseDelay: 100, + * })); + * ``` + */ +export function createRetryMiddleware(options: RetryMiddlewareOptions = {}): OutgoingMiddleware { + const { maxRetries = 3, baseDelay = 100, isRetryable = () => true } = options; + + return async (ctx, next) => { + let lastError: unknown; + + for (let attempt = 1; attempt <= maxRetries + 1; attempt++) { + try { + return await next(); + } catch (error) { + lastError = error; + + if (attempt > maxRetries || !isRetryable(error)) { + throw error; + } + + // Exponential backoff + const delay = baseDelay * Math.pow(2, attempt - 1); + await new Promise(resolve => setTimeout(resolve, delay)); + } + } + + throw lastError; + }; +} + +/** + * Custom tool call instrumentation middleware. + * Logs tool calls with timing information. + */ +const toolCallInstrumentationMiddleware: ToolCallMiddleware = async (ctx, next) => { + console.log(`\n[TOOL CALL] Starting: ${ctx.params.name}`); + console.log(`[TOOL CALL] Arguments: ${JSON.stringify(ctx.params.arguments || {})}`); + + const start = performance.now(); + try { + const result = await next(); + const duration = (performance.now() - start).toFixed(2); + console.log(`[TOOL CALL] Completed: ${ctx.params.name} (${duration}ms)`); + return result; + } catch (error) { + const duration = (performance.now() - start).toFixed(2); + console.log(`[TOOL CALL] Failed: ${ctx.params.name} (${duration}ms) - ${error}`); + throw error; + } +}; + +/** + * Custom request timing middleware. + * Tracks timing for all outgoing requests. + */ +const requestTimingMiddleware: ClientMiddleware = async (ctx, next) => { + const start = performance.now(); + try { + const result = await next(); + const duration = (performance.now() - start).toFixed(2); + console.log(`[TIMING] ${ctx.direction} ${ctx.method} completed in ${duration}ms`); + return result; + } catch (error) { + const duration = (performance.now() - start).toFixed(2); + console.log(`[TIMING] ${ctx.direction} ${ctx.method} failed in ${duration}ms`); + throw error; + } +}; + +// ═══════════════════════════════════════════════════════════════════════════ +// Main Function +// ═══════════════════════════════════════════════════════════════════════════ + +async function main(): Promise { + console.log('═══════════════════════════════════════════════════════════════'); + console.log('MCP Interactive Client (Builder Pattern Example)'); + console.log('═══════════════════════════════════════════════════════════════'); + console.log(''); + console.log('Features demonstrated:'); + console.log(' - Builder pattern for client configuration'); + console.log(' - Universal middleware (logging, timing)'); + console.log(' - Outgoing middleware (retry logic)'); + console.log(' - Tool call middleware (instrumentation)'); + console.log(' - Sampling request handler'); + console.log(' - Elicitation request handler'); + console.log(' - Roots list handler'); + console.log(' - Error handlers (onError, onProtocolError)'); + console.log('═══════════════════════════════════════════════════════════════'); + + // Connect to server immediately + await connect(); + + // Print help and start the command loop + printHelp(); + commandLoop(); +} + +function printHelp(): void { + console.log('\nAvailable commands:'); + console.log(' connect [url] - Connect to MCP server (default: http://localhost:3000/mcp)'); + console.log(' disconnect - Disconnect from server'); + console.log(' reconnect - Reconnect to the server'); + console.log(' list-tools - List available tools'); + console.log(' call-tool [args] - Call a tool with optional JSON arguments'); + console.log(' greet [name] - Call the greet tool'); + console.log(' multi-greet [name] - Call the multi-greet tool with notifications'); + console.log(' context-demo [msg] - Call the context-demo tool'); + console.log(' admin-action - Call admin-action (no auth)'); + console.log(' admin-action-auth - Call admin-action with auth token'); + console.log(' error-test - Test error handling (application/validation)'); + console.log(' list-prompts - List available prompts'); + console.log(' get-prompt [args] - Get a prompt with optional JSON arguments'); + console.log(' list-resources - List available resources'); + console.log(' read-resource - Read a specific resource by URI'); + console.log(' session-info - Read session info resource'); + console.log(' help - Show this help'); + console.log(' quit - Exit the program'); +} + +function commandLoop(): void { + readline.question('\n> ', async input => { + const args = input.trim().split(/\s+/); + const command = args[0]?.toLowerCase(); + + try { + switch (command) { + case 'connect': { + await connect(args[1]); + break; + } + + case 'disconnect': { + await disconnect(); + break; + } + + case 'reconnect': { + await reconnect(); + break; + } + + case 'list-tools': { + await listTools(); + break; + } + + case 'call-tool': { + if (args.length < 2) { + console.log('Usage: call-tool [args]'); + } else { + const toolName = args[1]!; + let toolArgs = {}; + if (args.length > 2) { + try { + toolArgs = JSON.parse(args.slice(2).join(' ')); + } catch { + console.log('Invalid JSON arguments. Using empty args.'); + } + } + await callTool(toolName, toolArgs); + } + break; + } + + case 'greet': { + await callTool('greet', { name: args[1] || 'World' }); + break; + } + + case 'multi-greet': { + console.log('Calling multi-greet tool (watch for notifications)...'); + await callTool('multi-greet', { name: args[1] || 'World' }); + break; + } + + case 'context-demo': { + await callTool('context-demo', { message: args.slice(1).join(' ') || 'Hello from client!' }); + break; + } + + case 'admin-action': { + if (args.length < 2) { + console.log('Usage: admin-action '); + } else { + // Call without requiresAdmin flag - should work + await callTool('admin-action', { action: args[1] }); + } + break; + } + + case 'admin-action-auth': { + if (args.length < 2) { + console.log('Usage: admin-action-auth '); + } else { + // Call with requiresAdmin but provide token + await callTool('admin-action', { + action: args[1], + requiresAdmin: true, + adminToken: 'demo-token-123' + }); + } + break; + } + + case 'error-test': { + if (args.length < 2) { + console.log('Usage: error-test '); + } else { + await callTool('error-test', { errorType: args[1] }); + } + break; + } + + case 'list-prompts': { + await listPrompts(); + break; + } + + case 'get-prompt': { + if (args.length < 2) { + console.log('Usage: get-prompt [args]'); + } else { + const promptName = args[1]!; + let promptArgs = {}; + if (args.length > 2) { + try { + promptArgs = JSON.parse(args.slice(2).join(' ')); + } catch { + console.log('Invalid JSON arguments. Using empty args.'); + } + } + await getPrompt(promptName, promptArgs); + } + break; + } + + case 'list-resources': { + await listResources(); + break; + } + + case 'read-resource': { + if (args.length < 2) { + console.log('Usage: read-resource '); + } else { + await readResource(args[1]!); + } + break; + } + + case 'session-info': { + await readResource('https://example.com/session/info'); + break; + } + + case 'help': { + printHelp(); + break; + } + + case 'quit': + case 'exit': { + await cleanup(); + return; + } + + default: { + if (command) { + console.log(`Unknown command: ${command}`); + } + break; + } + } + } catch (error) { + console.error(`Error executing command: ${error}`); + } + + // Continue the command loop + commandLoop(); + }); +} + +/** + * Connect to the MCP server using the builder pattern. + * + * The builder provides a fluent API for configuring the client: + * - .name() and .version() set client info + * - .capabilities() configures client capabilities + * - .useMiddleware() adds universal middleware + * - .useOutgoingMiddleware() adds outgoing-only middleware + * - .useToolCallMiddleware() adds tool call specific middleware + * - .onSamplingRequest() handles sampling requests from server + * - .onElicitation() handles elicitation requests from server + * - .onRootsList() handles roots list requests from server + * - .onError() handles application errors + * - .onProtocolError() handles protocol errors + * - .build() creates the configured Client instance + */ +async function connect(url?: string): Promise { + if (client) { + console.log('Already connected. Disconnect first.'); + return; + } + + if (url) { + serverUrl = url; + } + + console.log(`\nConnecting to ${serverUrl}...`); + + try { + // Create a new client using the builder pattern + client = Client.builder() + .name('builder-example-client') + .version('1.0.0') + + // ─── Capabilities ─── + // Enable sampling, elicitation, and roots capabilities + .capabilities({ + sampling: {}, + elicitation: { form: {} }, + roots: { listChanged: true } + }) + + // ─── Universal Middleware ─── + // Logging middleware for all requests + .useMiddleware( + createClientLoggingMiddleware({ + level: 'debug', + logger: (level, message, data) => { + const timestamp = new Date().toISOString(); + console.log(`[${timestamp}] [CLIENT ${level.toUpperCase()}] ${message}`); + if (data) { + console.log(`[${timestamp}] [CLIENT ${level.toUpperCase()}] Data:`, JSON.stringify(data, null, 2)); + } + } + }) + ) + + // Custom timing middleware + .useMiddleware(requestTimingMiddleware) + + // ─── Outgoing Middleware ─── + // Retry middleware for transient failures + .useOutgoingMiddleware( + createRetryMiddleware({ + maxRetries: 3, + baseDelay: 100, + isRetryable: error => { + // Retry on network errors + const message = error instanceof Error ? error.message : String(error); + return message.includes('ECONNREFUSED') || message.includes('ETIMEDOUT') || message.includes('network'); + } + }) + ) + + // ─── Tool Call Middleware ─── + .useToolCallMiddleware(toolCallInstrumentationMiddleware) + + // ─── Request Handlers ─── + + // Sampling request handler (when server requests LLM completion) + .onSamplingRequest(async params => { + console.log('\n[SAMPLING] Received sampling request from server'); + console.log('[SAMPLING] Messages:', JSON.stringify(params, null, 2)); + + // In a real implementation, this would call an LLM + // For demo, return a simulated response + return { + role: 'assistant', + content: { + type: 'text', + text: 'This is a simulated sampling response from the client.' + }, + model: 'simulated-model-v1' + }; + }) + + // Elicitation handler (when server requests user input) + .onElicitation(async params => { + const elicitParams = params as { mode?: string; message?: string; requestedSchema?: unknown }; + console.log('\n[ELICITATION] Received elicitation request from server'); + console.log('[ELICITATION] Mode:', elicitParams.mode); + console.log('[ELICITATION] Message:', elicitParams.message); + + if (elicitParams.mode === 'form') { + // For demo, auto-accept with sample data + console.log('[ELICITATION] Auto-accepting form with sample data'); + return { + action: 'accept', + content: { + name: 'Demo User', + email: 'demo@example.com', + confirmed: true + } + }; + } + + // Decline other modes + console.log('[ELICITATION] Declining non-form elicitation'); + return { action: 'decline' }; + }) + + // Roots list handler (when server requests filesystem roots) + .onRootsList(async () => { + console.log('\n[ROOTS] Received roots list request from server'); + return { + roots: [ + { uri: 'file:///workspace', name: 'Workspace' }, + { uri: 'file:///home/user', name: 'Home Directory' }, + { uri: 'file:///tmp', name: 'Temporary Files' } + ] + }; + }) + + // ─── Error Handlers ─── + .onError((error, ctx) => { + console.error(`\n[CLIENT ERROR] ${ctx.type}: ${error.message}`); + console.error(`[CLIENT ERROR] Request ID: ${ctx.requestId}`); + // Return the original error (could also transform it) + return error; + }) + .onProtocolError((error, ctx) => { + console.error(`\n[PROTOCOL ERROR] ${ctx.method}: ${error.message}`); + console.error(`[PROTOCOL ERROR] Request ID: ${ctx.requestId}`); + }) + + .build(); + + // Set up client error handler + client.onerror = error => { + console.error('\n[CLIENT] Error event:', error); + }; + + // Create transport with optional session ID for reconnection + transport = new StreamableHTTPClientTransport(new URL(serverUrl), { + sessionId: sessionId + }); + + // Set up notification handler for logging messages + client.setNotificationHandler(LoggingMessageNotificationSchema, notification => { + notificationCount++; + console.log(`\n[NOTIFICATION #${notificationCount}] ${notification.params.level}: ${notification.params.data}`); + process.stdout.write('> '); + }); + + // Connect the client + await client.connect(transport); + sessionId = transport.sessionId; + console.log('Connected to MCP server'); + console.log('Session ID:', sessionId); + } catch (error) { + console.error('Failed to connect:', error); + client = null; + transport = null; + } +} + +async function disconnect(): Promise { + if (!client || !transport) { + console.log('Not connected.'); + return; + } + + try { + await transport.close(); + console.log('Disconnected from MCP server'); + client = null; + transport = null; + } catch (error) { + console.error('Error disconnecting:', error); + } +} + +async function reconnect(): Promise { + if (client) { + await disconnect(); + } + await connect(); +} + +async function listTools(): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const request: ListToolsRequest = { + method: 'tools/list', + params: {} + }; + const result = await client.request(request, ListToolsResultSchema); + + console.log('\nAvailable tools:'); + if (result.tools.length === 0) { + console.log(' No tools available'); + } else { + for (const tool of result.tools) { + console.log(` - ${tool.name}: ${getDisplayName(tool)} - ${tool.description}`); + } + } + } catch (error) { + console.log(`Tools not supported by this server (${error})`); + } +} + +async function callTool(name: string, args: Record): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const request: CallToolRequest = { + method: 'tools/call', + params: { + name, + arguments: args + } + }; + + const result = await client.request(request, CallToolResultSchema); + + console.log('\nTool result:'); + for (const item of result.content) { + if (isTextContent(item)) { + console.log(` ${item.text}`); + } else { + console.log(` [${item.type}]:`, item); + } + } + } catch (error) { + console.log(`Error calling tool ${name}: ${error}`); + } +} + +async function listPrompts(): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const request: ListPromptsRequest = { + method: 'prompts/list', + params: {} + }; + const result = await client.request(request, ListPromptsResultSchema); + + console.log('\nAvailable prompts:'); + if (result.prompts.length === 0) { + console.log(' No prompts available'); + } else { + for (const prompt of result.prompts) { + console.log(` - ${prompt.name}: ${getDisplayName(prompt)} - ${prompt.description}`); + } + } + } catch (error) { + console.log(`Prompts not supported by this server (${error})`); + } +} + +async function getPrompt(name: string, args: Record): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const request: GetPromptRequest = { + method: 'prompts/get', + params: { + name, + arguments: args as Record + } + }; + + const result = await client.request(request, GetPromptResultSchema); + console.log('\nPrompt template:'); + for (const [index, msg] of result.messages.entries()) { + console.log(` [${index + 1}] ${msg.role}: ${isTextContent(msg.content) ? msg.content.text : JSON.stringify(msg.content)}`); + } + } catch (error) { + console.log(`Error getting prompt ${name}: ${error}`); + } +} + +async function listResources(): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const request: ListResourcesRequest = { + method: 'resources/list', + params: {} + }; + const result = await client.request(request, ListResourcesResultSchema); + + console.log('\nAvailable resources:'); + if (result.resources.length === 0) { + console.log(' No resources available'); + } else { + for (const resource of result.resources) { + console.log(` - ${resource.name}: ${getDisplayName(resource)} - ${resource.uri}`); + } + } + } catch (error) { + console.log(`Resources not supported by this server (${error})`); + } +} + +async function readResource(uri: string): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const request: ReadResourceRequest = { + method: 'resources/read', + params: { uri } + }; + + console.log(`\nReading resource: ${uri}`); + const result = await client.request(request, ReadResourceResultSchema); + + console.log('Resource contents:'); + for (const content of result.contents) { + console.log(` URI: ${content.uri}`); + if (content.mimeType) { + console.log(` Type: ${content.mimeType}`); + } + + if ('text' in content && typeof content.text === 'string') { + console.log(' Content:'); + console.log(' ---'); + console.log( + content.text + .split('\n') + .map((line: string) => ' ' + line) + .join('\n') + ); + console.log(' ---'); + } else if ('blob' in content && typeof content.blob === 'string') { + console.log(` [Binary data: ${content.blob.length} bytes]`); + } + } + } catch (error) { + console.log(`Error reading resource ${uri}: ${error}`); + } +} + +async function cleanup(): Promise { + if (client && transport) { + try { + await transport.close(); + } catch (error) { + console.error('Error closing transport:', error); + } + } + + readline.close(); + console.log('\nGoodbye!'); + process.exit(0); +} + +// Handle Ctrl+C +process.on('SIGINT', async () => { + console.log('\nReceived SIGINT. Cleaning up...'); + await cleanup(); +}); + +// Start the interactive client +try { + await main(); +} catch (error) { + console.error('Error running MCP client:', error); + process.exit(1); +} diff --git a/examples/client/src/simpleTaskInteractiveClient.ts b/examples/client/src/simpleTaskInteractiveClient.ts index 2a4d47043..31b645127 100644 --- a/examples/client/src/simpleTaskInteractiveClient.ts +++ b/examples/client/src/simpleTaskInteractiveClient.ts @@ -9,14 +9,15 @@ import { createInterface } from 'node:readline'; -import type { CreateMessageRequest, CreateMessageResult, TextContent } from '@modelcontextprotocol/client'; +import type { ContentBlock, CreateMessageRequest, CreateMessageResult } from '@modelcontextprotocol/client'; import { CallToolResultSchema, Client, CreateMessageRequestSchema, ElicitRequestSchema, ErrorCode, - McpError, + isTextContent, + ProtocolError, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; @@ -34,9 +35,9 @@ function question(prompt: string): Promise { }); } -function getTextContent(result: { content: Array<{ type: string; text?: string }> }): string { - const textContent = result.content.find((c): c is TextContent => c.type === 'text'); - return textContent?.text ?? '(no text)'; +function getTextContent(result: { content: ContentBlock[] }): string | undefined { + const textContent = result.content.find(element => isTextContent(element)); + return textContent?.text; } async function elicitationCallback(params: { @@ -104,7 +105,7 @@ async function run(url: string): Promise { // Set up elicitation request handler client.setRequestHandler(ElicitRequestSchema, async request => { if (request.params.mode && request.params.mode !== 'form') { - throw new McpError(ErrorCode.InvalidParams, `Unsupported elicitation mode: ${request.params.mode}`); + throw new ProtocolError(ErrorCode.InvalidParams, `Unsupported elicitation mode: ${request.params.mode}`); } return elicitationCallback(request.params); }); diff --git a/examples/client/src/streamableHttpWithSseFallbackClient.ts b/examples/client/src/streamableHttpWithSseFallbackClient.ts index 90fee9270..bda71d3f7 100644 --- a/examples/client/src/streamableHttpWithSseFallbackClient.ts +++ b/examples/client/src/streamableHttpWithSseFallbackClient.ts @@ -2,6 +2,7 @@ import type { CallToolRequest, ListToolsRequest } from '@modelcontextprotocol/cl import { CallToolResultSchema, Client, + isTextContent, ListToolsResultSchema, LoggingMessageNotificationSchema, SSEClientTransport, @@ -173,7 +174,7 @@ async function startNotificationTool(client: Client): Promise { console.log('Tool result:'); for (const item of result.content) { - if (item.type === 'text') { + if (isTextContent(item)) { console.log(` ${item.text}`); } else { console.log(` ${item.type} content:`, item); diff --git a/examples/server/src/elicitationFormExample.ts b/examples/server/src/elicitationFormExample.ts index 70ff8ecb5..2e8ad1b25 100644 --- a/examples/server/src/elicitationFormExample.ts +++ b/examples/server/src/elicitationFormExample.ts @@ -11,7 +11,7 @@ import { randomUUID } from 'node:crypto'; import { createMcpExpressApp } from '@modelcontextprotocol/express'; import { NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/node'; -import { isInitializeRequest, McpServer } from '@modelcontextprotocol/server'; +import { isInitializeRequest, McpServer, text } from '@modelcontextprotocol/server'; import type { Request, Response } from 'express'; // Create MCP server - it will automatically use AjvJsonSchemaValidator with sensible defaults @@ -86,39 +86,21 @@ mcpServer.registerTool( return { content: [ - { - type: 'text', - text: `Registration successful!\n\nUsername: ${username}\nEmail: ${email}\nNewsletter: ${newsletter ? 'Yes' : 'No'}` - } + text(`Registration successful!\n\nUsername: ${username}\nEmail: ${email}\nNewsletter: ${newsletter ? 'Yes' : 'No'}`) ] }; } else if (result.action === 'decline') { return { - content: [ - { - type: 'text', - text: 'Registration cancelled by user.' - } - ] + content: [text('Registration cancelled by user.')] }; } else { return { - content: [ - { - type: 'text', - text: 'Registration was cancelled.' - } - ] + content: [text('Registration was cancelled.')] }; } } catch (error) { return { - content: [ - { - type: 'text', - text: `Registration failed: ${error instanceof Error ? error.message : String(error)}` - } - ], + content: [text(`Registration failed: ${error instanceof Error ? error.message : String(error)}`)], isError: true }; } @@ -162,7 +144,7 @@ mcpServer.registerTool( if (basicInfo.action !== 'accept' || !basicInfo.content) { return { - content: [{ type: 'text', text: 'Event creation cancelled.' }] + content: [text('Event creation cancelled.')] }; } @@ -198,7 +180,7 @@ mcpServer.registerTool( if (dateTime.action !== 'accept' || !dateTime.content) { return { - content: [{ type: 'text', text: 'Event creation cancelled.' }] + content: [text('Event creation cancelled.')] }; } @@ -209,21 +191,11 @@ mcpServer.registerTool( }; return { - content: [ - { - type: 'text', - text: `Event created successfully!\n\n${JSON.stringify(event, null, 2)}` - } - ] + content: [text(`Event created successfully!\n\n${JSON.stringify(event, null, 2)}`)] }; } catch (error) { return { - content: [ - { - type: 'text', - text: `Event creation failed: ${error instanceof Error ? error.message : String(error)}` - } - ], + content: [text(`Event creation failed: ${error instanceof Error ? error.message : String(error)}`)], isError: true }; } @@ -287,30 +259,20 @@ mcpServer.registerTool( if (result.action === 'accept' && result.content) { return { - content: [ - { - type: 'text', - text: `Address updated successfully!\n\n${JSON.stringify(result.content, null, 2)}` - } - ] + content: [text(`Address updated successfully!\n\n${JSON.stringify(result.content, null, 2)}`)] }; } else if (result.action === 'decline') { return { - content: [{ type: 'text', text: 'Address update cancelled by user.' }] + content: [text('Address update cancelled by user.')] }; } else { return { - content: [{ type: 'text', text: 'Address update was cancelled.' }] + content: [text('Address update was cancelled.')] }; } } catch (error) { return { - content: [ - { - type: 'text', - text: `Address update failed: ${error instanceof Error ? error.message : String(error)}` - } - ], + content: [text(`Address update failed: ${error instanceof Error ? error.message : String(error)}`)], isError: true }; } diff --git a/examples/server/src/elicitationUrlExample.ts b/examples/server/src/elicitationUrlExample.ts index e72845f1a..3348bb9df 100644 --- a/examples/server/src/elicitationUrlExample.ts +++ b/examples/server/src/elicitationUrlExample.ts @@ -46,12 +46,12 @@ const getServer = () => { cartId: z.string().describe('The ID of the cart to confirm') } }, - async ({ cartId }, extra): Promise => { + async ({ cartId }, ctx): Promise => { /* In a real world scenario, there would be some logic here to check if the user has the provided cartId. For the purposes of this example, we'll throw an error (-> elicits the client to open a URL to confirm payment) */ - const sessionId = extra.sessionId; + const sessionId = ctx.mcpCtx.sessionId; if (!sessionId) { throw new Error('Expected a Session ID'); } @@ -79,15 +79,15 @@ const getServer = () => { param1: z.string().describe('First parameter') } }, - async (_, extra): Promise => { + async (_, ctx): Promise => { /* In a real world scenario, there would be some logic here to check if we already have a valid access token for the user. - Auth info (with a subject or `sub` claim) can be typically be found in `extra.authInfo`. + Auth info (with a subject or `sub` claim) can be typically be found in `ctx.requestCtx.authInfo`. If we do, we can just return the result of the tool call. If we don't, we can throw an ElicitationRequiredError to request the user to authenticate. For the purposes of this example, we'll throw an error (-> elicits the client to open a URL to authenticate). */ - const sessionId = extra.sessionId; + const sessionId = ctx.mcpCtx.sessionId; if (!sessionId) { throw new Error('Expected a Session ID'); } diff --git a/examples/server/src/jsonResponseStreamableHttp.ts b/examples/server/src/jsonResponseStreamableHttp.ts index fa0e3a300..481b47716 100644 --- a/examples/server/src/jsonResponseStreamableHttp.ts +++ b/examples/server/src/jsonResponseStreamableHttp.ts @@ -51,7 +51,7 @@ const getServer = () => { name: z.string().describe('Name to greet') } }, - async ({ name }, extra): Promise => { + async ({ name }, ctx): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); await server.sendLoggingMessage( @@ -59,7 +59,7 @@ const getServer = () => { level: 'debug', data: `Starting multi-greet for ${name}` }, - extra.sessionId + ctx.mcpCtx.sessionId ); await sleep(1000); // Wait 1 second before first greeting @@ -69,7 +69,7 @@ const getServer = () => { level: 'info', data: `Sending first greeting to ${name}` }, - extra.sessionId + ctx.mcpCtx.sessionId ); await sleep(1000); // Wait another second before second greeting @@ -79,7 +79,7 @@ const getServer = () => { level: 'info', data: `Sending second greeting to ${name}` }, - extra.sessionId + ctx.mcpCtx.sessionId ); return { diff --git a/examples/server/src/simpleStatelessStreamableHttp.ts b/examples/server/src/simpleStatelessStreamableHttp.ts index 15669131a..0c7a791b9 100644 --- a/examples/server/src/simpleStatelessStreamableHttp.ts +++ b/examples/server/src/simpleStatelessStreamableHttp.ts @@ -49,7 +49,7 @@ const getServer = () => { count: z.number().describe('Number of notifications to send (0 for 100)').default(10) } }, - async ({ interval, count }, extra): Promise => { + async ({ interval, count }, ctx): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); let counter = 0; @@ -61,7 +61,7 @@ const getServer = () => { level: 'info', data: `Periodic notification #${counter} at ${new Date().toISOString()}` }, - extra.sessionId + ctx.mcpCtx.sessionId ); } catch (error) { console.error('Error sending notification:', error); diff --git a/examples/server/src/simpleStreamableHttp.ts b/examples/server/src/simpleStreamableHttp.ts index 75308beff..4157b878f 100644 --- a/examples/server/src/simpleStreamableHttp.ts +++ b/examples/server/src/simpleStreamableHttp.ts @@ -20,7 +20,8 @@ import { InMemoryTaskMessageQueue, InMemoryTaskStore, isInitializeRequest, - McpServer + McpServer, + TaskPlugin } from '@modelcontextprotocol/server'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; @@ -44,12 +45,18 @@ const getServer = () => { websiteUrl: 'https://github.com/modelcontextprotocol/typescript-sdk' }, { - capabilities: { logging: {}, tasks: { requests: { tools: { call: {} } } } }, - taskStore, // Enable task support - taskMessageQueue: new InMemoryTaskMessageQueue() + capabilities: { logging: {}, tasks: { requests: { tools: { call: {} } } } } } ); + // Enable task support via TaskPlugin + server.usePlugin( + new TaskPlugin({ + taskStore, + taskMessageQueue: new InMemoryTaskMessageQueue() + }) + ); + // Register a simple tool that returns a greeting server.registerTool( 'greet', @@ -86,7 +93,7 @@ const getServer = () => { openWorldHint: false } }, - async ({ name }, extra): Promise => { + async ({ name }, ctx): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); await server.sendLoggingMessage( @@ -94,7 +101,7 @@ const getServer = () => { level: 'debug', data: `Starting multi-greet for ${name}` }, - extra.sessionId + ctx.mcpCtx.sessionId ); await sleep(1000); // Wait 1 second before first greeting @@ -104,7 +111,7 @@ const getServer = () => { level: 'info', data: `Sending first greeting to ${name}` }, - extra.sessionId + ctx.mcpCtx.sessionId ); await sleep(1000); // Wait another second before second greeting @@ -114,7 +121,7 @@ const getServer = () => { level: 'info', data: `Sending second greeting to ${name}` }, - extra.sessionId + ctx.mcpCtx.sessionId ); return { @@ -137,7 +144,7 @@ const getServer = () => { infoType: z.enum(['contact', 'preferences', 'feedback']).describe('Type of information to collect') } }, - async ({ infoType }, extra): Promise => { + async ({ infoType }, ctx): Promise => { let message: string; let requestedSchema: { type: 'object'; @@ -236,8 +243,8 @@ const getServer = () => { } try { - // Use sendRequest through the extra parameter to elicit input - const result = await extra.sendRequest( + // Use sendRequest through the ctx parameter to elicit input + const result = await ctx.sendRequest( { method: 'elicitation/create', params: { @@ -325,7 +332,7 @@ const getServer = () => { count: z.number().describe('Number of notifications to send (0 for 100)').default(50) } }, - async ({ interval, count }, extra): Promise => { + async ({ interval, count }, ctx): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); let counter = 0; @@ -337,7 +344,7 @@ const getServer = () => { level: 'info', data: `Periodic notification #${counter} at ${new Date().toISOString()}` }, - extra.sessionId + ctx.mcpCtx.sessionId ); } catch (error) { console.error('Error sending notification:', error); @@ -482,10 +489,12 @@ const getServer = () => { } }, { - async createTask({ duration }, { taskStore, taskRequestedTtl }) { + async createTask({ duration }, ctx) { // Create the task + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const taskStore = ctx.taskCtx.store; const task = await taskStore.createTask({ - ttl: taskRequestedTtl + ttl: ctx.taskCtx.requestedTtl }); // Simulate out-of-band work @@ -506,11 +515,13 @@ const getServer = () => { task }; }, - async getTask(_args, { taskId, taskStore }) { - return await taskStore.getTask(taskId); + async getTask(_args, ctx) { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + return await ctx.taskCtx.store.getTask(ctx.taskCtx.id!); }, - async getTaskResult(_args, { taskId, taskStore }) { - const result = await taskStore.getTaskResult(taskId); + async getTaskResult(_args, ctx) { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const result = await ctx.taskCtx.store.getTaskResult(ctx.taskCtx.id!); return result as CallToolResult; } } diff --git a/examples/server/src/simpleStreamableHttpBuilder.ts b/examples/server/src/simpleStreamableHttpBuilder.ts new file mode 100644 index 000000000..d1c64f8a4 --- /dev/null +++ b/examples/server/src/simpleStreamableHttpBuilder.ts @@ -0,0 +1,550 @@ +/** + * Simple Streamable HTTP Server Example using Builder Pattern + * + * This example demonstrates using the McpServer.builder() fluent API + * to create and configure an MCP server with: + * - Tools, resources, and prompts registration + * - Middleware (logging, custom metrics) + * - Per-tool middleware (authorization) + * - Error handlers (onError, onProtocolError) + * - Context helpers (logging, notifications) + * + * Run with: npx tsx src/simpleStreamableHttpBuilder.ts + */ + +import { randomUUID } from 'node:crypto'; + +import { createMcpExpressApp } from '@modelcontextprotocol/express'; +import { NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/node'; +import type { CallToolResult, GetPromptResult, ReadResourceResult, ToolMiddleware } from '@modelcontextprotocol/server'; +import { createLoggingMiddleware, isInitializeRequest, McpServer, text } from '@modelcontextprotocol/server'; +import type { Request, Response } from 'express'; +import * as z from 'zod/v4'; + +import { InMemoryEventStore } from './inMemoryEventStore.js'; + +// ═══════════════════════════════════════════════════════════════════════════ +// Custom Middleware Examples +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Custom metrics middleware that tracks tool execution time. + * Demonstrates how to create custom middleware. + */ +const metricsMiddleware: ToolMiddleware = async (ctx, next) => { + const start = performance.now(); + try { + const result = await next(); + const duration = (performance.now() - start).toFixed(2); + console.log(`[METRICS] Tool '${ctx.name}' completed in ${duration}ms`); + return result; + } catch (error) { + const duration = (performance.now() - start).toFixed(2); + console.log(`[METRICS] Tool '${ctx.name}' failed in ${duration}ms`); + throw error; + } +}; + +/** + * Per-tool authorization middleware example. + * This is passed directly to a specific tool registration. + */ +const adminAuthMiddleware: ToolMiddleware = async (ctx, next) => { + // In a real app, check ctx.authInfo for admin scope + // For demo purposes, we'll check for a special argument + const args = ctx.args as Record; + if (args.requiresAdmin && !args.adminToken) { + throw new Error('Admin authorization required. Provide adminToken argument.'); + } + console.log(`[AUTH] Admin action authorized for tool '${ctx.name}'`); + return next(); +}; + +// ═══════════════════════════════════════════════════════════════════════════ +// Session Management +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Session data type - stores transport for each session. + */ +interface SessionData { + transport: NodeStreamableHTTPServerTransport; + createdAt: Date; +} + +/** + * Simple Map-based session storage. + */ +const sessions = new Map(); + +// ═══════════════════════════════════════════════════════════════════════════ +// Server Factory +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Creates an MCP server using the builder pattern. + * + * The builder provides a fluent API for configuring the server: + * - .name() and .version() set server info + * - .options() configures capabilities + * - .useMiddleware() adds universal middleware + * - .useToolMiddleware() adds tool-specific middleware + * - .tool() registers tools inline (with optional per-tool middleware) + * - .resource() registers resources inline + * - .prompt() registers prompts inline + * - .onError() handles application errors + * - .onProtocolError() handles protocol errors + * - .build() creates the configured McpServer instance + */ +const getServer = () => { + const server = McpServer.builder() + .name('builder-example-server') + .version('1.0.0') + .options({ + capabilities: { logging: {} } + }) + + // ─── Universal Middleware ─── + // Runs for all request types (tools, resources, prompts) + .useMiddleware( + createLoggingMiddleware({ + level: 'info', + logger: (level, message, data) => { + const timestamp = new Date().toISOString(); + console.log(`[${timestamp}] [${level.toUpperCase()}] ${message}`, data ? JSON.stringify(data) : ''); + } + }) + ) + + // ─── Tool-Specific Middleware ─── + .useToolMiddleware(async (ctx, next) => { + console.log(`Tool '${ctx.name}' called`); + return next(); + }) + + // Custom metrics middleware + .useToolMiddleware(metricsMiddleware) + + // ─── Error Handlers ─── + .onError((error, ctx) => { + console.error(`[APP ERROR] ${ctx.type}/${ctx.name || ctx.method}: ${error.message}`); + // Return custom error response with additional context + return { + code: -32_000, + message: `Error in ${ctx.name || ctx.method}: ${error.message}`, + data: { type: ctx.type, requestId: ctx.requestId } + }; + }) + .onProtocolError((error, ctx) => { + console.error(`[PROTOCOL ERROR] ${ctx.method}: ${error.message}`); + // Protocol errors preserve error code, can customize message/data + return { + message: `Protocol error: ${error.message}`, + data: { requestId: ctx.requestId } + }; + }) + + // ─── Tool Registrations ─── + + // Simple greeting tool + .tool( + 'greet', + { + title: 'Greeting Tool', + description: 'A simple greeting tool that returns a personalized greeting', + inputSchema: { + name: z.string().describe('Name to greet') + } + }, + async ({ name }): Promise => { + return { + content: [text(`Hello, ${name}!`)] + }; + } + ) + + // Tool with notifications demonstrating context usage + .tool( + 'multi-greet', + { + title: 'Multiple Greeting Tool', + description: 'A tool that sends different greetings with delays and notifications', + inputSchema: { + name: z.string().describe('Name to greet') + } + }, + async function ({ name }, ctx): Promise { + const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + + // Use context logging helper + await ctx.loggingNotification.debug(`Starting multi-greet for ${name}`); + + await sleep(1000); + + // Use sendNotification directly + await ctx.sendNotification({ + method: 'notifications/message', + params: { + level: 'info', + data: `Sending first greeting to ${name}` + } + }); + + await sleep(1000); + + await ctx.loggingNotification.info(`Sending second greeting to ${name}`); + + return { + content: [text(`Good morning, ${name}!`)] + }; + } + ) + + // Context demo tool - shows all context features + .tool( + 'context-demo', + { + title: 'Context Demo', + description: 'Demonstrates all context helper methods and properties', + inputSchema: { + message: z.string().describe('A message to echo back') + } + }, + async ({ message }, ctx): Promise => { + // Access MCP context + const mcpInfo = { + requestId: ctx.mcpCtx.requestId, + sessionId: ctx.mcpCtx.sessionId, + method: ctx.mcpCtx.method + }; + + // Access request context + const requestInfo = { + signalAborted: ctx.requestCtx.signal.aborted, + hasAuthInfo: !!ctx.requestCtx.authInfo + }; + + // Use logging helpers at different levels + await ctx.loggingNotification.debug('Debug: Processing context-demo'); + await ctx.loggingNotification.info('Info: Context inspection complete'); + + // Send custom notification + await ctx.sendNotification({ + method: 'notifications/message', + params: { + level: 'debug', + data: `Echo: ${message}` + } + }); + + return { + content: [ + text('Context Demo Results:'), + text(`MCP Context: ${JSON.stringify(mcpInfo, null, 2)}`), + text(`Request Context: ${JSON.stringify(requestInfo, null, 2)}`), + text(`Your message: ${message}`) + ] + }; + } + ) + + // Tool with per-tool middleware (authorization) + .tool( + 'admin-action', + { + title: 'Admin Action', + description: 'An admin-only tool demonstrating per-tool middleware', + inputSchema: { + action: z.string().describe('Admin action to perform'), + requiresAdmin: z.boolean().optional().describe('Whether this action requires admin auth'), + adminToken: z.string().optional().describe('Admin token for authorization') + }, + middleware: adminAuthMiddleware // Per-tool middleware + }, + async ({ action }): Promise => { + return { + content: [text(`Admin action '${action}' executed successfully`)] + }; + } + ) + + // Tool that intentionally throws an error (for testing error handlers) + .tool( + 'error-test', + { + title: 'Error Test', + description: 'A tool that throws errors to test error handlers', + inputSchema: { + errorType: z.enum(['application', 'validation']).describe('Type of error to throw') + } + }, + async ({ errorType }): Promise => { + const error = + errorType === 'application' + ? new Error('This is a test application error') + : new Error('Validation failed: invalid input format'); + throw error; + } + ) + + // ─── Resource Registration ─── + .resource( + 'greeting-resource', + 'https://example.com/greetings/default', + { + title: 'Default Greeting', + description: 'A simple greeting resource' + }, + async (): Promise => { + return { + contents: [ + { + uri: 'https://example.com/greetings/default', + mimeType: 'text/plain', + text: 'Hello, world!' + } + ] + }; + } + ) + + // Resource demonstrating server info + .resource( + 'server-info', + 'https://example.com/server/info', + { + title: 'Server Information', + description: 'Returns current server statistics' + }, + async (): Promise => { + const stats = { + activeSessions: sessions.size, + uptime: process.uptime() + }; + return { + contents: [ + { + uri: 'https://example.com/server/info', + mimeType: 'application/json', + text: JSON.stringify(stats, null, 2) + } + ] + }; + } + ) + + // ─── Prompt Registration ─── + .prompt( + 'greeting-template', + { + title: 'Greeting Template', + description: 'A simple greeting prompt template', + argsSchema: { + name: z.string().describe('Name to include in greeting') + } + }, + async ({ name }): Promise => { + return { + messages: [ + { + role: 'user', + content: text(`Please greet ${name} in a friendly manner.`) + } + ] + }; + } + ) + + .build(); + + return server; +}; + +// ═══════════════════════════════════════════════════════════════════════════ +// Express App Setup +// ═══════════════════════════════════════════════════════════════════════════ + +const PORT = process.env.PORT ? Number.parseInt(process.env.PORT, 10) : 3000; + +const app = createMcpExpressApp(); + +// ═══════════════════════════════════════════════════════════════════════════ +// MCP Request Handlers +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * MCP POST endpoint handler. + * Uses a simple Map for session management. + */ +const mcpPostHandler = async (req: Request, res: Response) => { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + + try { + // Check for existing session + const session = sessionId ? sessions.get(sessionId) : undefined; + + if (session) { + // Reuse existing transport + console.log(`[REQUEST] Using existing session: ${sessionId}`); + await session.transport.handleRequest(req, res, req.body); + return; + } + + if (!sessionId && isInitializeRequest(req.body)) { + // New initialization request - create session + console.log('[REQUEST] New initialization request'); + + const eventStore = new InMemoryEventStore(); + const transport = new NodeStreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + eventStore, + onsessioninitialized: sid => { + // Store session + sessions.set(sid, { + transport, + createdAt: new Date() + }); + } + }); + + // Clean up session when transport closes + transport.onclose = () => { + const sid = transport.sessionId; + if (sid) { + sessions.delete(sid); + } + }; + + // Connect the transport to the MCP server + const server = getServer(); + await server.connect(transport); + + await transport.handleRequest(req, res, req.body); + return; + } + + // Invalid request + res.status(400).json({ + jsonrpc: '2.0', + error: { + code: -32_000, + message: 'Bad Request: No valid session ID provided' + }, + id: null + }); + } catch (error) { + console.error('[ERROR] Handling MCP request:', error); + if (!res.headersSent) { + res.status(500).json({ + jsonrpc: '2.0', + error: { + code: -32_603, + message: 'Internal server error' + }, + id: null + }); + } + } +}; + +app.post('/mcp', mcpPostHandler); + +/** + * MCP GET endpoint handler for SSE streams. + */ +const mcpGetHandler = async (req: Request, res: Response) => { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + const session = sessionId ? sessions.get(sessionId) : undefined; + + if (!session) { + res.status(400).send('Invalid or missing session ID'); + return; + } + + const lastEventId = req.headers['last-event-id'] as string | undefined; + if (lastEventId) { + console.log(`[SSE] Client reconnecting with Last-Event-ID: ${lastEventId}`); + } else { + console.log(`[SSE] Establishing new stream for session ${sessionId}`); + } + + await session.transport.handleRequest(req, res); +}; + +app.get('/mcp', mcpGetHandler); + +/** + * MCP DELETE endpoint handler for session termination. + */ +const mcpDeleteHandler = async (req: Request, res: Response) => { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + const session = sessionId ? sessions.get(sessionId) : undefined; + + if (!session) { + res.status(400).send('Invalid or missing session ID'); + return; + } + + console.log(`[SESSION] Termination request for session ${sessionId}`); + + try { + await session.transport.handleRequest(req, res); + } catch (error) { + console.error('[ERROR] Session termination:', error); + if (!res.headersSent) { + res.status(500).send('Error processing session termination'); + } + } +}; + +app.delete('/mcp', mcpDeleteHandler); + +// ═══════════════════════════════════════════════════════════════════════════ +// Server Startup +// ═══════════════════════════════════════════════════════════════════════════ + +app.listen(PORT, error => { + if (error) { + console.error('Failed to start server:', error); + // eslint-disable-next-line unicorn/no-process-exit + process.exit(1); + } + console.log('═══════════════════════════════════════════════════════════════'); + console.log('MCP Builder Example Server'); + console.log('═══════════════════════════════════════════════════════════════'); + console.log(`Listening on port ${PORT}`); + console.log(`MCP endpoint: http://localhost:${PORT}/mcp`); + console.log(''); + console.log('Features demonstrated:'); + console.log(' - Builder pattern for server configuration'); + console.log(' - Universal middleware (logging)'); + console.log(' - Tool-specific middleware (metrics)'); + console.log(' - Per-tool middleware (authorization)'); + console.log(' - Error handlers (onError, onProtocolError)'); + console.log(' - Context helpers (logging, notifications)'); + console.log('═══════════════════════════════════════════════════════════════'); +}); + +// ═══════════════════════════════════════════════════════════════════════════ +// Graceful Shutdown +// ═══════════════════════════════════════════════════════════════════════════ + +process.on('SIGINT', async () => { + console.log('\n[SHUTDOWN] Received SIGINT, shutting down...'); + + // Close all sessions + for (const [sid, session] of sessions) { + try { + console.log(`[SHUTDOWN] Closing session ${sid}`); + await session.transport.close(); + } catch (error) { + console.error(`[SHUTDOWN] Error closing session ${sid}:`, error); + } + } + + // Clear the sessions map + sessions.clear(); + + console.log('[SHUTDOWN] Complete'); + process.exit(0); +}); diff --git a/examples/server/src/simpleTaskInteractive.ts b/examples/server/src/simpleTaskInteractive.ts index 1b8532525..d02f9bf0b 100644 --- a/examples/server/src/simpleTaskInteractive.ts +++ b/examples/server/src/simpleTaskInteractive.ts @@ -516,7 +516,7 @@ const createServer = (): Server => { }); // Handle tool calls - server.setRequestHandler(CallToolRequestSchema, async (request, extra): Promise => { + server.setRequestHandler(CallToolRequestSchema, async (request, ctx): Promise => { const { name, arguments: args } = request.params; const taskParams = (request.params._meta?.task || request.params.task) as { ttl?: number; pollInterval?: number } | undefined; @@ -531,7 +531,7 @@ const createServer = (): Server => { pollInterval: taskParams.pollInterval ?? 1000 }; - const task = await taskStore.createTask(taskOptions, extra.requestId, request, extra.sessionId); + const task = await taskStore.createTask(taskOptions, ctx.mcpCtx.requestId, request, ctx.mcpCtx.sessionId); console.log(`\n[Server] ${name} called, task created: ${task.taskId}`); @@ -609,7 +609,7 @@ const createServer = (): Server => { activeTaskExecutions.set(task.taskId, { promise: taskExecution, server, - sessionId: extra.sessionId ?? '' + sessionId: ctx.mcpCtx.sessionId ?? '' }); return { task }; @@ -626,10 +626,10 @@ const createServer = (): Server => { }); // Handle tasks/result - server.setRequestHandler(GetTaskPayloadRequestSchema, async (request, extra): Promise => { + server.setRequestHandler(GetTaskPayloadRequestSchema, async (request, ctx): Promise => { const { taskId } = request.params; console.log(`[Server] tasks/result called for task ${taskId}`); - return taskResultHandler.handle(taskId, server, extra.sessionId ?? ''); + return taskResultHandler.handle(taskId, server, ctx.mcpCtx.sessionId ?? ''); }); return server; diff --git a/examples/server/src/ssePollingExample.ts b/examples/server/src/ssePollingExample.ts index 2416d2ec3..f4ac271f0 100644 --- a/examples/server/src/ssePollingExample.ts +++ b/examples/server/src/ssePollingExample.ts @@ -7,7 +7,7 @@ * Key features: * - Configures `retryInterval` to tell clients how long to wait before reconnecting * - Uses `eventStore` to persist events for replay after reconnection - * - Uses `extra.closeSSEStream()` callback to gracefully disconnect clients mid-operation + * - Uses `ctx.requestCtx.stream.closeSSEStream()` callback to gracefully disconnect clients mid-operation * * Run with: pnpm tsx src/ssePollingExample.ts * Test with: curl or the MCP Inspector @@ -16,7 +16,7 @@ import { randomUUID } from 'node:crypto'; import { createMcpExpressApp } from '@modelcontextprotocol/express'; import { NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/node'; -import type { CallToolResult } from '@modelcontextprotocol/server'; +import type { CallToolResult, ServerRequestContext } from '@modelcontextprotocol/server'; import { McpServer } from '@modelcontextprotocol/server'; import cors from 'cors'; import type { Request, Response } from 'express'; @@ -40,10 +40,12 @@ server.registerTool( { description: 'A long-running task that sends progress updates. Server will disconnect mid-task to demonstrate polling.' }, - async (extra): Promise => { + async (ctx): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + const sessionId = ctx.mcpCtx.sessionId; + const requestCtx = ctx.requestCtx as ServerRequestContext; - console.log(`[${extra.sessionId}] Starting long-task...`); + console.log(`[${sessionId}] Starting long-task...`); // Send first progress notification await server.sendLoggingMessage( @@ -51,7 +53,7 @@ server.registerTool( level: 'info', data: 'Progress: 25% - Starting work...' }, - extra.sessionId + sessionId ); await sleep(1000); @@ -61,16 +63,16 @@ server.registerTool( level: 'info', data: 'Progress: 50% - Halfway there...' }, - extra.sessionId + sessionId ); await sleep(1000); // Server decides to disconnect the client to free resources // Client will reconnect via GET with Last-Event-ID after the transport's retryInterval - // Use extra.closeSSEStream callback - available when eventStore is configured - if (extra.closeSSEStream) { - console.log(`[${extra.sessionId}] Closing SSE stream to trigger client polling...`); - extra.closeSSEStream(); + // Use requestCtx.stream.closeSSEStream callback - available when eventStore is configured + if (requestCtx.stream.closeSSEStream) { + console.log(`[${sessionId}] Closing SSE stream to trigger client polling...`); + requestCtx.stream.closeSSEStream(); } // Continue processing while client is disconnected @@ -81,7 +83,7 @@ server.registerTool( level: 'info', data: 'Progress: 75% - Almost done (sent while client disconnected)...' }, - extra.sessionId + sessionId ); await sleep(500); @@ -90,10 +92,10 @@ server.registerTool( level: 'info', data: 'Progress: 100% - Complete!' }, - extra.sessionId + sessionId ); - console.log(`[${extra.sessionId}] Task complete`); + console.log(`[${sessionId}] Task complete`); return { content: [ diff --git a/examples/server/src/toolWithSampleServer.ts b/examples/server/src/toolWithSampleServer.ts index 9881830b5..e3a529b98 100644 --- a/examples/server/src/toolWithSampleServer.ts +++ b/examples/server/src/toolWithSampleServer.ts @@ -1,6 +1,6 @@ // Run with: pnpm tsx src/toolWithSampleServer.ts -import { McpServer, StdioServerTransport } from '@modelcontextprotocol/server'; +import { isTextContent, McpServer, StdioServerTransport } from '@modelcontextprotocol/server'; import * as z from 'zod/v4'; const mcpServer = new McpServer({ @@ -37,7 +37,7 @@ mcpServer.registerTool( content: [ { type: 'text', - text: response.content.type === 'text' ? response.content.text : 'Unable to generate summary' + text: isTextContent(response.content) ? response.content.text : 'Unable to generate summary' } ] }; diff --git a/packages/client/src/client/builder.ts b/packages/client/src/client/builder.ts new file mode 100644 index 000000000..2724097f9 --- /dev/null +++ b/packages/client/src/client/builder.ts @@ -0,0 +1,449 @@ +/** + * Client Builder + * + * Provides a fluent API for configuring and creating Client instances. + * The builder is an additive convenience layer - the existing constructor + * API remains available for users who prefer it. + * + * @example + * ```typescript + * const client = Client.builder() + * .name('my-client') + * .version('1.0.0') + * .capabilities({ sampling: {} }) + * .useMiddleware(loggingMiddleware) + * .onSamplingRequest(samplingHandler) + * .build(); + * ``` + */ + +import type { + ClientCapabilities, + CreateMessageRequest, + CreateMessageResult, + CreateMessageResultWithTools, + CreateTaskResult, + ElicitRequest, + ElicitResult, + jsonSchemaValidator, + ListChangedHandlers, + ListRootsRequest, + ListRootsResult +} from '@modelcontextprotocol/core'; + +import type { Client } from './client.js'; +import type { ClientContextInterface } from './context.js'; +import type { + ClientMiddleware, + ElicitationMiddleware, + IncomingMiddleware, + OutgoingMiddleware, + ResourceReadMiddleware, + SamplingMiddleware, + ToolCallMiddleware +} from './middleware.js'; + +/** + * Handler for sampling requests from the server. + * Receives the full CreateMessageRequest and returns the sampling result. + * When task creation is requested via params.task, returns CreateTaskResult instead. + */ +export type SamplingRequestHandler = ( + request: CreateMessageRequest, + ctx: ClientContextInterface +) => + | CreateMessageResult + | CreateMessageResultWithTools + | CreateTaskResult + | Promise; + +/** + * Handler for elicitation requests from the server. + * Receives the full ElicitRequest and returns the elicitation result. + * When task creation is requested via params.task, returns CreateTaskResult instead. + */ +export type ElicitationRequestHandler = ( + request: ElicitRequest, + ctx: ClientContextInterface +) => ElicitResult | CreateTaskResult | Promise; + +/** + * Handler for roots list requests from the server. + * Receives the full ListRootsRequest and returns the list of roots. + */ +export type RootsListHandler = (request: ListRootsRequest, ctx: ClientContextInterface) => ListRootsResult | Promise; + +/** + * Error handler type for application errors + */ +export type OnErrorHandler = (error: Error, ctx: ErrorContext) => OnErrorReturn | void | Promise; + +/** + * Error handler type for protocol errors + */ +export type OnProtocolErrorHandler = ( + error: Error, + ctx: ErrorContext +) => OnProtocolErrorReturn | void | Promise; + +/** + * Return type for onError handler + */ +export type OnErrorReturn = string | { code?: number; message?: string; data?: unknown } | Error; + +/** + * Return type for onProtocolError handler (code cannot be changed) + */ +export type OnProtocolErrorReturn = string | { message?: string; data?: unknown }; + +/** + * Context provided to error handlers + */ +export interface ErrorContext { + type: 'sampling' | 'elicitation' | 'rootsList' | 'protocol'; + method: string; + requestId: string; +} + +/** + * Options for client configuration + */ +export interface ClientBuilderOptions { + /** Enforce strict capability checking */ + enforceStrictCapabilities?: boolean; +} + +/** + * Fluent builder for Client instances. + * + * Provides a declarative, chainable API for configuring clients. + * All configuration is collected and applied when build() is called. + */ +export class ClientBuilder { + private _name?: string; + private _version?: string; + private _capabilities?: ClientCapabilities; + private _options: ClientBuilderOptions = {}; + private _jsonSchemaValidator?: jsonSchemaValidator; + private _listChanged?: ListChangedHandlers; + + // Middleware + private _universalMiddleware: ClientMiddleware[] = []; + private _outgoingMiddleware: OutgoingMiddleware[] = []; + private _incomingMiddleware: IncomingMiddleware[] = []; + private _toolCallMiddleware: ToolCallMiddleware[] = []; + private _resourceReadMiddleware: ResourceReadMiddleware[] = []; + private _samplingMiddleware: SamplingMiddleware[] = []; + private _elicitationMiddleware: ElicitationMiddleware[] = []; + + // Handlers + private _samplingHandler?: SamplingRequestHandler; + private _elicitationHandler?: ElicitationRequestHandler; + private _rootsListHandler?: RootsListHandler; + + // Error handlers + private _onError?: OnErrorHandler; + private _onProtocolError?: OnProtocolErrorHandler; + + /** + * Sets the client name. + */ + name(name: string): this { + this._name = name; + return this; + } + + /** + * Sets the client version. + */ + version(version: string): this { + this._version = version; + return this; + } + + /** + * Sets the client capabilities. + * + * @example + * ```typescript + * .capabilities({ + * sampling: {}, + * roots: { listChanged: true } + * }) + * ``` + */ + capabilities(capabilities: ClientCapabilities): this { + this._capabilities = { ...this._capabilities, ...capabilities }; + return this; + } + + /** + * Sets client options. + */ + options(options: ClientBuilderOptions): this { + this._options = { ...this._options, ...options }; + return this; + } + + /** + * Sets the JSON Schema validator for tool output validation. + * + * @example + * ```typescript + * .jsonSchemaValidator(new AjvJsonSchemaValidator()) + * ``` + */ + jsonSchemaValidator(validator: jsonSchemaValidator): this { + this._jsonSchemaValidator = validator; + return this; + } + + /** + * Configures handlers for list changed notifications (tools, prompts, resources). + * + * @example + * ```typescript + * .onListChanged({ + * tools: { + * onChanged: (error, tools) => console.log('Tools updated:', tools) + * }, + * prompts: { + * onChanged: (error, prompts) => console.log('Prompts updated:', prompts) + * } + * }) + * ``` + */ + onListChanged(handlers: ListChangedHandlers): this { + this._listChanged = { ...this._listChanged, ...handlers }; + return this; + } + + /** + * Adds universal middleware that runs for all requests. + */ + useMiddleware(middleware: ClientMiddleware): this { + this._universalMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware for outgoing requests only. + */ + useOutgoingMiddleware(middleware: OutgoingMiddleware): this { + this._outgoingMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware for incoming requests only. + */ + useIncomingMiddleware(middleware: IncomingMiddleware): this { + this._incomingMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware specifically for tool calls. + */ + useToolCallMiddleware(middleware: ToolCallMiddleware): this { + this._toolCallMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware specifically for resource reads. + */ + useResourceReadMiddleware(middleware: ResourceReadMiddleware): this { + this._resourceReadMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware specifically for sampling requests. + */ + useSamplingMiddleware(middleware: SamplingMiddleware): this { + this._samplingMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware specifically for elicitation requests. + */ + useElicitationMiddleware(middleware: ElicitationMiddleware): this { + this._elicitationMiddleware.push(middleware); + return this; + } + + /** + * Sets the handler for sampling requests from the server. + * + * @example + * ```typescript + * .onSamplingRequest(async (params, ctx) => { + * const result = await llm.complete(params.messages); + * return { role: 'assistant', content: result }; + * }) + * ``` + */ + onSamplingRequest(handler: SamplingRequestHandler): this { + this._samplingHandler = handler; + return this; + } + + /** + * Sets the handler for elicitation requests from the server. + */ + onElicitation(handler: ElicitationRequestHandler): this { + this._elicitationHandler = handler; + return this; + } + + /** + * Sets the handler for roots list requests from the server. + */ + onRootsList(handler: RootsListHandler): this { + this._rootsListHandler = handler; + return this; + } + + /** + * Sets the application error handler. + * Called when a handler throws an error. + */ + onError(handler: OnErrorHandler): this { + this._onError = handler; + return this; + } + + /** + * Sets the protocol error handler. + * Called for protocol-level errors. + */ + onProtocolError(handler: OnProtocolErrorHandler): this { + this._onProtocolError = handler; + return this; + } + + /** + * Gets the collected configuration (for debugging/testing). + */ + getConfig(): { + name?: string; + version?: string; + capabilities?: ClientCapabilities; + options: ClientBuilderOptions; + middlewareCount: number; + hasHandlers: boolean; + } { + return { + name: this._name, + version: this._version, + capabilities: this._capabilities, + options: this._options, + middlewareCount: + this._universalMiddleware.length + + this._outgoingMiddleware.length + + this._incomingMiddleware.length + + this._toolCallMiddleware.length + + this._resourceReadMiddleware.length + + this._samplingMiddleware.length + + this._elicitationMiddleware.length, + hasHandlers: !!this._samplingHandler || !!this._elicitationHandler || !!this._rootsListHandler + }; + } + + /** + * Builds and returns the configured Client instance. + */ + build(): Client { + if (!this._name) { + throw new Error('Client name is required. Use .name() to set it.'); + } + if (!this._version) { + throw new Error('Client version is required. Use .version() to set it.'); + } + + const result: ClientBuilderResult = { + clientInfo: { + name: this._name, + version: this._version + }, + capabilities: this._capabilities, + options: this._options, + jsonSchemaValidator: this._jsonSchemaValidator, + listChanged: this._listChanged, + middleware: { + universal: this._universalMiddleware, + outgoing: this._outgoingMiddleware, + incoming: this._incomingMiddleware, + toolCall: this._toolCallMiddleware, + resourceRead: this._resourceReadMiddleware, + sampling: this._samplingMiddleware, + elicitation: this._elicitationMiddleware + }, + handlers: { + sampling: this._samplingHandler, + elicitation: this._elicitationHandler, + rootsList: this._rootsListHandler + }, + errorHandlers: { + onError: this._onError, + onProtocolError: this._onProtocolError + } + }; + + // Dynamically import Client to create the instance + // eslint-disable-next-line @typescript-eslint/no-require-imports + const { Client: ClientClass } = require('./client.js'); + return ClientClass.fromBuilderResult(result); + } +} + +/** + * Result of building the client configuration. + * Used to create the actual Client instance. + */ +export interface ClientBuilderResult { + clientInfo: { + name: string; + version: string; + }; + capabilities?: ClientCapabilities; + options: ClientBuilderOptions; + jsonSchemaValidator?: jsonSchemaValidator; + listChanged?: ListChangedHandlers; + middleware: { + universal: ClientMiddleware[]; + outgoing: OutgoingMiddleware[]; + incoming: IncomingMiddleware[]; + toolCall: ToolCallMiddleware[]; + resourceRead: ResourceReadMiddleware[]; + sampling: SamplingMiddleware[]; + elicitation: ElicitationMiddleware[]; + }; + handlers: { + sampling?: SamplingRequestHandler; + elicitation?: ElicitationRequestHandler; + rootsList?: RootsListHandler; + }; + errorHandlers: { + onError?: OnErrorHandler; + onProtocolError?: OnProtocolErrorHandler; + }; +} + +/** + * Creates a new ClientBuilder instance. + * + * @example + * ```typescript + * const client = createClientBuilder() + * .name('my-client') + * .version('1.0.0') + * .capabilities({ sampling: {} }) + * .build(); + * ``` + */ +export function createClientBuilder(): ClientBuilder { + return new ClientBuilder(); +} diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index 754d77277..9bb81146e 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -1,5 +1,6 @@ import type { AnyObjectSchema, + BaseRequestContext, CallToolRequest, ClientCapabilities, ClientNotification, @@ -7,8 +8,12 @@ import type { ClientResult, CompatibilityCallToolResultSchema, CompleteRequest, + ContextInterface, + ErrorInterceptionContext, + ErrorInterceptionResult, GetPromptRequest, Implementation, + JSONRPCRequest, JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator, @@ -17,13 +22,15 @@ import type { ListPromptsRequest, ListResourcesRequest, ListResourceTemplatesRequest, + ListRootsResult, ListToolsRequest, LoggingLevel, + McpContext, + MessageExtraInfo, Notification, ProtocolOptions, ReadResourceRequest, Request, - RequestHandlerExtra, RequestOptions, Result, SchemaOutput, @@ -40,6 +47,7 @@ import { assertClientRequestTaskCapability, assertToolsCallTaskCapability, CallToolResultSchema, + CapabilityError, CompleteResultSchema, CreateMessageRequestSchema, CreateMessageResultSchema, @@ -48,29 +56,52 @@ import { ElicitRequestSchema, ElicitResultSchema, EmptyResultSchema, - ErrorCode, getObjectShape, GetPromptResultSchema, InitializeResultSchema, + isProtocolError, isZ4Schema, LATEST_PROTOCOL_VERSION, ListChangedOptionsBaseSchema, ListPromptsResultSchema, ListResourcesResultSchema, ListResourceTemplatesResultSchema, + ListRootsRequestSchema, ListToolsResultSchema, - McpError, mergeCapabilities, PromptListChangedNotificationSchema, Protocol, + ProtocolError, ReadResourceResultSchema, ResourceListChangedNotificationSchema, safeParse, + StateError, SUPPORTED_PROTOCOL_VERSIONS, ToolListChangedNotificationSchema } from '@modelcontextprotocol/core'; import { ExperimentalClientTasks } from '../experimental/tasks/client.js'; +import type { + ClientBuilderResult, + ErrorContext, + OnErrorHandler, + OnErrorReturn, + OnProtocolErrorHandler, + OnProtocolErrorReturn +} from './builder.js'; +import { ClientBuilder } from './builder.js'; +import type { ClientRequestContext } from './context.js'; +import { ClientContext } from './context.js'; +import type { + ClientMiddleware, + ElicitationMiddleware, + IncomingMiddleware, + OutgoingMiddleware, + ResourceReadMiddleware, + SamplingMiddleware, + ToolCallMiddleware +} from './middleware.js'; +import { ClientMiddlewareManager } from './middleware.js'; /** * Elicitation default application helper. Applies defaults to the data based on the schema. @@ -253,6 +284,11 @@ export class Client< private _experimental?: { tasks: ExperimentalClientTasks }; private _listChangedDebounceTimers: Map> = new Map(); private _pendingListChangedConfig?: ListChangedHandlers; + private readonly _middleware: ClientMiddlewareManager; + + // Error handlers (single callback pattern, matching McpServer) + private _onErrorHandler?: OnErrorHandler; + private _onProtocolErrorHandler?: OnProtocolErrorHandler; private _enforceStrictCapabilities: boolean; /** @@ -265,6 +301,7 @@ export class Client< super(options); this._capabilities = options?.capabilities ?? {}; this._jsonSchemaValidator = options?.jsonSchemaValidator ?? new AjvJsonSchemaValidator(); + this._middleware = new ClientMiddlewareManager(); this._enforceStrictCapabilities = options?.enforceStrictCapabilities ?? false; // Store list changed config for setup after connection (when we know server capabilities) @@ -273,6 +310,184 @@ export class Client< } } + /** + * Gets the middleware manager for advanced middleware configuration. + */ + get middleware(): ClientMiddlewareManager { + return this._middleware; + } + + /** + * Registers universal middleware that runs for all request types. + * + * @param middleware - The middleware function to register + * @returns This Client instance for chaining + */ + useMiddleware(middleware: ClientMiddleware): this { + this._middleware.useMiddleware(middleware); + return this; + } + + /** + * Registers middleware for outgoing requests only. + * + * @param middleware - The outgoing middleware function to register + * @returns This Client instance for chaining + */ + useOutgoingMiddleware(middleware: OutgoingMiddleware): this { + this._middleware.useOutgoingMiddleware(middleware); + return this; + } + + /** + * Registers middleware for incoming requests only. + * + * @param middleware - The incoming middleware function to register + * @returns This Client instance for chaining + */ + useIncomingMiddleware(middleware: IncomingMiddleware): this { + this._middleware.useIncomingMiddleware(middleware); + return this; + } + + /** + * Registers middleware specifically for tool calls. + * + * @param middleware - The tool call middleware function to register + * @returns This Client instance for chaining + */ + useToolCallMiddleware(middleware: ToolCallMiddleware): this { + this._middleware.useToolCallMiddleware(middleware); + return this; + } + + /** + * Registers middleware specifically for resource reads. + * + * @param middleware - The resource read middleware function to register + * @returns This Client instance for chaining + */ + useResourceReadMiddleware(middleware: ResourceReadMiddleware): this { + this._middleware.useResourceReadMiddleware(middleware); + return this; + } + + /** + * Registers middleware specifically for sampling requests. + * + * @param middleware - The sampling middleware function to register + * @returns This Client instance for chaining + */ + useSamplingMiddleware(middleware: SamplingMiddleware): this { + this._middleware.useSamplingMiddleware(middleware); + return this; + } + + /** + * Registers middleware specifically for elicitation requests. + * + * @param middleware - The elicitation middleware function to register + * @returns This Client instance for chaining + */ + useElicitationMiddleware(middleware: ElicitationMiddleware): this { + this._middleware.useElicitationMiddleware(middleware); + return this; + } + + /** + * Creates a new ClientBuilder for fluent configuration. + * + * @example + * ```typescript + * const client = Client.builder() + * .name('my-client') + * .version('1.0.0') + * .capabilities({ sampling: {} }) + * .onSamplingRequest(async (params) => { + * // Handle sampling request from server + * return { role: 'assistant', content: { type: 'text', text: '...' } }; + * }) + * .build(); + * ``` + */ + static builder(): ClientBuilder { + return new ClientBuilder(); + } + + /** + * Creates a Client from a ClientBuilderResult configuration. + * + * @param result - The result from ClientBuilder.build() + * @returns A configured Client instance + */ + static fromBuilderResult(result: ClientBuilderResult): Client { + const client = new Client(result.clientInfo, { + capabilities: result.capabilities, + enforceStrictCapabilities: result.options.enforceStrictCapabilities, + jsonSchemaValidator: result.jsonSchemaValidator, + listChanged: result.listChanged + }); + + // Register handlers + if (result.handlers.sampling) { + client.setRequestHandler( + CreateMessageRequestSchema, + result.handlers.sampling as Parameters[1] + ); + } + + if (result.handlers.elicitation) { + client.setRequestHandler(ElicitRequestSchema, result.handlers.elicitation as Parameters[1]); + } + + if (result.handlers.rootsList) { + client.setRequestHandler(ListRootsRequestSchema, result.handlers.rootsList as Parameters[1]); + } + + // Wire up error handlers to Protocol events + if (result.errorHandlers.onError || result.errorHandlers.onProtocolError) { + client.events.on('error', ({ error, context }) => { + const errorContext = { + type: (context as 'sampling' | 'elicitation' | 'rootsList' | 'protocol') || 'protocol', + method: context || 'unknown', + requestId: 'unknown' + }; + + // Call the appropriate error handler based on context + if (context === 'protocol' && result.errorHandlers.onProtocolError) { + (result.errorHandlers.onProtocolError as (error: Error, ctx: typeof errorContext) => void)(error, errorContext); + } else if (result.errorHandlers.onError) { + (result.errorHandlers.onError as (error: Error, ctx: typeof errorContext) => void)(error, errorContext); + } + }); + } + + // Apply middleware from builder + for (const middleware of result.middleware.universal) { + client.useMiddleware(middleware); + } + for (const middleware of result.middleware.outgoing) { + client.useOutgoingMiddleware(middleware); + } + for (const middleware of result.middleware.incoming) { + client.useIncomingMiddleware(middleware); + } + for (const middleware of result.middleware.toolCall) { + client.useToolCallMiddleware(middleware); + } + for (const middleware of result.middleware.resourceRead) { + client.useResourceReadMiddleware(middleware); + } + for (const middleware of result.middleware.sampling) { + client.useSamplingMiddleware(middleware); + } + for (const middleware of result.middleware.elicitation) { + client.useElicitationMiddleware(middleware); + } + + return client; + } + /** * Set up handlers for list changed notifications based on config and server capabilities. * This should only be called after initialization when server capabilities are known. @@ -325,7 +540,7 @@ export class Client< */ public registerCapabilities(capabilities: ClientCapabilities): void { if (this.transport) { - throw new Error('Cannot register capabilities after connecting to transport'); + throw StateError.registrationAfterConnect('capabilities'); } this._capabilities = mergeCapabilities(this._capabilities, capabilities); @@ -338,7 +553,7 @@ export class Client< requestSchema: T, handler: ( request: SchemaOutput, - extra: RequestHandlerExtra + extra: ContextInterface ) => ClientResult | ResultT | Promise ): void { const shape = getObjectShape(requestSchema); @@ -366,14 +581,14 @@ export class Client< if (method === 'elicitation/create') { const wrappedHandler = async ( request: SchemaOutput, - extra: RequestHandlerExtra + ctx: ContextInterface ): Promise => { const validatedRequest = safeParse(ElicitRequestSchema, request); if (!validatedRequest.success) { // Type guard: if success is false, error is guaranteed to exist const errorMessage = validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation request: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid elicitation request: ${errorMessage}`); } const { params } = validatedRequest.data; @@ -381,14 +596,14 @@ export class Client< const { supportsFormMode, supportsUrlMode } = getSupportedElicitationModes(this._capabilities.elicitation); if (params.mode === 'form' && !supportsFormMode) { - throw new McpError(ErrorCode.InvalidParams, 'Client does not support form-mode elicitation requests'); + throw ProtocolError.invalidParams('Client does not support form-mode elicitation requests'); } if (params.mode === 'url' && !supportsUrlMode) { - throw new McpError(ErrorCode.InvalidParams, 'Client does not support URL-mode elicitation requests'); + throw ProtocolError.invalidParams('Client does not support URL-mode elicitation requests'); } - const result = await Promise.resolve(handler(request, extra)); + const result = await Promise.resolve(handler(request, ctx)); // When task creation is requested, validate and return CreateTaskResult if (params.task) { @@ -398,7 +613,7 @@ export class Client< taskValidationResult.error instanceof Error ? taskValidationResult.error.message : String(taskValidationResult.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid task creation result: ${errorMessage}`); } return taskValidationResult.data; } @@ -409,7 +624,7 @@ export class Client< // Type guard: if success is false, error is guaranteed to exist const errorMessage = validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation result: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid elicitation result: ${errorMessage}`); } const validatedResult = validationResult.data; @@ -439,18 +654,18 @@ export class Client< if (method === 'sampling/createMessage') { const wrappedHandler = async ( request: SchemaOutput, - extra: RequestHandlerExtra + ctx: ContextInterface ): Promise => { const validatedRequest = safeParse(CreateMessageRequestSchema, request); if (!validatedRequest.success) { const errorMessage = validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid sampling request: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid sampling request: ${errorMessage}`); } const { params } = validatedRequest.data; - const result = await Promise.resolve(handler(request, extra)); + const result = await Promise.resolve(handler(request, ctx)); // When task creation is requested, validate and return CreateTaskResult if (params.task) { @@ -460,7 +675,7 @@ export class Client< taskValidationResult.error instanceof Error ? taskValidationResult.error.message : String(taskValidationResult.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid task creation result: ${errorMessage}`); } return taskValidationResult.data; } @@ -472,7 +687,7 @@ export class Client< if (!validationResult.success) { const errorMessage = validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid sampling result: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid sampling result: ${errorMessage}`); } return validationResult.data; @@ -486,9 +701,36 @@ export class Client< return super.setRequestHandler(requestSchema, handler); } + protected createRequestContext(args: { + request: JSONRPCRequest; + abortController: AbortController; + capturedTransport: Transport | undefined; + extra?: MessageExtraInfo; + }): ContextInterface { + const { request, abortController, capturedTransport, extra } = args; + const sessionId = capturedTransport?.sessionId; + + // Build the MCP context using the helper from Protocol + const mcpContext: McpContext = this.buildMcpContext({ request, sessionId }); + + // Build the client request context (minimal, no HTTP details - client-specific) + const requestCtx: ClientRequestContext = { + signal: abortController.signal, + authInfo: extra?.authInfo + }; + + // Return a ClientContext instance (task context is added by plugins if needed) + return new ClientContext({ + client: this, + request, + mcpContext, + requestCtx + }); + } + protected assertCapability(capability: keyof ServerCapabilities, method: string): void { if (!this._serverCapabilities?.[capability]) { - throw new Error(`Server does not support ${capability} (required for ${method})`); + throw CapabilityError.serverDoesNotSupport(capability, method); } } @@ -571,7 +813,7 @@ export class Client< switch (method as ClientRequest['method']) { case 'logging/setLevel': { if (!this._serverCapabilities?.logging) { - throw new Error(`Server does not support logging (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('logging', method); } break; } @@ -579,7 +821,7 @@ export class Client< case 'prompts/get': case 'prompts/list': { if (!this._serverCapabilities?.prompts) { - throw new Error(`Server does not support prompts (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('prompts', method); } break; } @@ -590,11 +832,11 @@ export class Client< case 'resources/subscribe': case 'resources/unsubscribe': { if (!this._serverCapabilities?.resources) { - throw new Error(`Server does not support resources (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('resources', method); } if (method === 'resources/subscribe' && !this._serverCapabilities.resources.subscribe) { - throw new Error(`Server does not support resource subscriptions (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('resources.subscribe', method); } break; @@ -603,14 +845,14 @@ export class Client< case 'tools/call': case 'tools/list': { if (!this._serverCapabilities?.tools) { - throw new Error(`Server does not support tools (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('tools', method); } break; } case 'completion/complete': { if (!this._serverCapabilities?.completions) { - throw new Error(`Server does not support completions (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('completions', method); } break; } @@ -631,7 +873,7 @@ export class Client< switch (method as ClientNotification['method']) { case 'notifications/roots/list_changed': { if (!this._capabilities.roots?.listChanged) { - throw new Error(`Client does not support roots list changed notifications (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('roots.listChanged', method); } break; } @@ -663,21 +905,21 @@ export class Client< switch (method) { case 'sampling/createMessage': { if (!this._capabilities.sampling) { - throw new Error(`Client does not support sampling capability (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('sampling', method); } break; } case 'elicitation/create': { if (!this._capabilities.elicitation) { - throw new Error(`Client does not support elicitation capability (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('elicitation', method); } break; } case 'roots/list': { if (!this._capabilities.roots) { - throw new Error(`Client does not support roots capability (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('roots', method); } break; } @@ -687,7 +929,7 @@ export class Client< case 'tasks/result': case 'tasks/cancel': { if (!this._capabilities.tasks) { - throw new Error(`Client does not support tasks capability (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('tasks', method); } break; } @@ -782,8 +1024,7 @@ export class Client< ) { // Guard: required-task tools need experimental API if (this.isToolTaskRequired(params.name)) { - throw new McpError( - ErrorCode.InvalidRequest, + throw ProtocolError.invalidRequest( `Tool "${params.name}" requires task-based execution. Use client.experimental.tasks.callToolStream() instead.` ); } @@ -795,10 +1036,7 @@ export class Client< if (validator) { // If tool has outputSchema, it MUST return structuredContent (unless it's an error) if (!result.structuredContent && !result.isError) { - throw new McpError( - ErrorCode.InvalidRequest, - `Tool ${params.name} has an output schema but did not return structured content` - ); + throw ProtocolError.invalidRequest(`Tool ${params.name} has an output schema but did not return structured content`); } // Only validate structured content if present (not when there's an error) @@ -808,17 +1046,15 @@ export class Client< const validationResult = validator(result.structuredContent); if (!validationResult.valid) { - throw new McpError( - ErrorCode.InvalidParams, + throw ProtocolError.invalidParams( `Structured content does not match the tool's output schema: ${validationResult.errorMessage}` ); } } catch (error) { - if (error instanceof McpError) { + if (isProtocolError(error)) { throw error; } - throw new McpError( - ErrorCode.InvalidParams, + throw ProtocolError.invalidParams( `Failed to validate structured content: ${error instanceof Error ? error.message : String(error)}` ); } @@ -955,4 +1191,150 @@ export class Client< async sendRootsListChanged() { return this.notification({ method: 'notifications/roots/list_changed' }); } + + /** + * Registers a handler for roots/list requests from the server. + * + * @param handler - Handler function that returns the list of roots + * @returns This Client instance for chaining + * + * @example + * ```typescript + * client.onRootsList(async () => ({ + * roots: [ + * { uri: 'file:///workspace', name: 'Workspace' } + * ] + * })); + * ``` + */ + onRootsList( + handler: ( + ctx: ContextInterface + ) => ListRootsResult | Promise + ): this { + this.setRequestHandler(ListRootsRequestSchema, (_request, ctx) => handler(ctx)); + return this; + } + + /** + * Updates the error interceptor based on current handlers. + * This combines both onError and onProtocolError handlers into a single interceptor. + */ + private _updateErrorInterceptor(): void { + if (!this._onErrorHandler && !this._onProtocolErrorHandler) { + // No handlers, clear the interceptor + this.setErrorInterceptor(undefined); + return; + } + + this.setErrorInterceptor(async (error: Error, ctx: ErrorInterceptionContext): Promise => { + const errorContext: ErrorContext = { + type: ctx.type === 'protocol' ? 'protocol' : (ctx.method as ErrorContext['type']) || 'sampling', + method: ctx.method, + requestId: typeof ctx.requestId === 'string' ? ctx.requestId : String(ctx.requestId) + }; + + let result: OnErrorReturn | OnProtocolErrorReturn | void = undefined; + + if (ctx.type === 'protocol' && this._onProtocolErrorHandler) { + // Protocol error - use onProtocolError handler + result = await this._onProtocolErrorHandler(error, errorContext); + } else if (this._onErrorHandler) { + // Application error (or protocol error without specific handler) - use onError handler + result = await this._onErrorHandler(error, errorContext); + } + + if (result === undefined || result === null) { + return undefined; + } + + // Convert the handler result to ErrorInterceptionResult + if (typeof result === 'string') { + return { message: result }; + } else if (result instanceof Error) { + const errorWithCode = result as Error & { code?: number; data?: unknown }; + return { + message: result.message, + code: ctx.type === 'application' ? errorWithCode.code : undefined, + data: errorWithCode.data + }; + } else { + // Object with code/message/data + return { + message: result.message, + code: ctx.type === 'application' ? (result as OnErrorReturn & { code?: number }).code : undefined, + data: result.data + }; + } + }); + } + + private _clearOnErrorHandler(): void { + this._onErrorHandler = undefined; + this._updateErrorInterceptor(); + } + + private _clearOnProtocolErrorHandler(): void { + this._onProtocolErrorHandler = undefined; + this._updateErrorInterceptor(); + } + + /** + * Registers an error handler for application errors in sampling/elicitation/rootsList handlers. + * + * The handler receives the error and a context object with information about where + * the error occurred. It can optionally return a custom error response that will + * modify the error sent to the server. + * + * Note: This is a single-handler pattern. Setting a new handler replaces any previous one. + * The handler is awaited, so async handlers are fully supported. + * + * @param handler - Error handler function + * @returns Unsubscribe function + * + * @example + * ```typescript + * const unsubscribe = client.onError(async (error, ctx) => { + * console.error(`Error in ${ctx.type}/${ctx.method}: ${error.message}`); + * // Optionally return a custom error response + * return { + * code: -32000, + * message: `Application error: ${error.message}`, + * data: { type: ctx.type } + * }; + * }); + * ``` + */ + onError(handler: OnErrorHandler): () => void { + this._onErrorHandler = handler; + this._updateErrorInterceptor(); + return this._clearOnErrorHandler.bind(this); + } + + /** + * Registers an error handler for protocol errors (method not found, parse error, etc.). + * + * The handler receives the error and a context object. It can optionally return + * a custom error response. Note that the error code cannot be changed for protocol + * errors as they have fixed codes per the MCP specification. + * + * Note: This is a single-handler pattern. Setting a new handler replaces any previous one. + * The handler is awaited, so async handlers are fully supported. + * + * @param handler - Error handler function + * @returns Unsubscribe function + * + * @example + * ```typescript + * const unsubscribe = client.onProtocolError(async (error, ctx) => { + * console.error(`Protocol error in ${ctx.method}: ${error.message}`); + * return { message: `Protocol error: ${error.message}` }; + * }); + * ``` + */ + onProtocolError(handler: OnProtocolErrorHandler): () => void { + this._onProtocolErrorHandler = handler; + this._updateErrorInterceptor(); + return this._clearOnProtocolErrorHandler.bind(this); + } } diff --git a/packages/client/src/client/context.ts b/packages/client/src/client/context.ts new file mode 100644 index 000000000..108a89e0f --- /dev/null +++ b/packages/client/src/client/context.ts @@ -0,0 +1,74 @@ +import type { + BaseRequestContext, + ClientNotification, + ClientRequest, + ClientResult, + ContextInterface, + JSONRPCRequest, + McpContext, + Notification, + Request, + Result, + TaskContext +} from '@modelcontextprotocol/core'; +import { BaseContext } from '@modelcontextprotocol/core'; + +import type { Client } from './client.js'; + +/** + * Client-specific request context. + * Clients don't receive HTTP requests, so this is minimal. + * Extends BaseRequestContext with any client-specific fields. + */ +export type ClientRequestContext = BaseRequestContext & { + // Client doesn't receive HTTP requests, just JSON-RPC messages over transport. + // Additional client-specific fields can be added here if needed. +}; + +/** + * Type alias for client-side request handler context. + * Extends the base ContextInterface with ClientRequestContext. + * The generic parameters match the Client's combined types. + */ +export type ClientContextInterface< + RequestT extends Request = Request, + NotificationT extends Notification = Notification +> = ContextInterface; + +/** + * A context object that is passed to client-side request handlers. + * Used when the client handles requests from the server (e.g., sampling, elicitation). + */ +export class ClientContext< + RequestT extends Request = Request, + NotificationT extends Notification = Notification, + ResultT extends Result = Result + > + extends BaseContext + implements ClientContextInterface +{ + private readonly client: Client; + + constructor(args: { + client: Client; + request: JSONRPCRequest; + mcpContext: McpContext; + requestCtx: ClientRequestContext; + task?: TaskContext; + }) { + super({ + request: args.request, + mcpContext: args.mcpContext, + requestCtx: args.requestCtx, + task: args.task + }); + this.client = args.client; + } + + /** + * Returns the client instance for sending notifications and requests. + */ + protected getProtocol(): Client { + return this.client; + } +} diff --git a/packages/client/src/client/middleware.ts b/packages/client/src/client/middleware.ts index 3fd52e41a..1fc1dcd91 100644 --- a/packages/client/src/client/middleware.ts +++ b/packages/client/src/client/middleware.ts @@ -1,8 +1,28 @@ -import type { FetchLike } from '@modelcontextprotocol/core'; +/** + * Client Middleware System + * + * This module provides two distinct middleware systems: + * + * 1. Fetch Middleware - For HTTP/fetch level operations (OAuth, logging, etc.) + * 2. MCP Client Middleware - For MCP protocol level operations (tool calls, sampling, etc.) + */ + +import type { + AuthInfo, + CallToolResult, + CreateMessageResult, + ElicitResult, + FetchLike, + ReadResourceResult +} from '@modelcontextprotocol/core'; import type { OAuthClientProvider } from './auth.js'; import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js'; +// ═══════════════════════════════════════════════════════════════════════════ +// Fetch Middleware (HTTP Level) +// ═══════════════════════════════════════════════════════════════════════════ + /** * Middleware function that wraps and enhances fetch functionality. * Takes a fetch handler and returns an enhanced fetch handler. @@ -320,3 +340,391 @@ export const applyMiddlewares = (...middleware: Middleware[]): Middleware => { export const createMiddleware = (handler: (next: FetchLike, input: string | URL, init?: RequestInit) => Promise): Middleware => { return next => (input, init) => handler(next, input as string | URL, init); }; + +// ═══════════════════════════════════════════════════════════════════════════ +// MCP Client Middleware (Protocol Level) +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Base context shared by all MCP client middleware + */ +interface BaseClientContext { + /** The request ID */ + requestId: string; + /** Abort signal for cancellation */ + signal: AbortSignal; +} + +/** + * Context for outgoing requests (client β†’ server) + */ +export interface OutgoingContext extends BaseClientContext { + direction: 'outgoing'; + /** The type of outgoing request */ + type: + | 'callTool' + | 'readResource' + | 'getPrompt' + | 'listTools' + | 'listResources' + | 'listPrompts' + | 'ping' + | 'complete' + | 'initialize' + | 'other'; + /** The JSON-RPC method name */ + method: string; + /** The request parameters */ + params: unknown; +} + +/** + * Context for incoming requests (server β†’ client) + */ +export interface IncomingContext extends BaseClientContext { + direction: 'incoming'; + /** The type of incoming request */ + type: 'sampling' | 'elicitation' | 'rootsList' | 'other'; + /** The JSON-RPC method name */ + method: string; + /** The request parameters */ + params: unknown; + /** Authentication info if available */ + authInfo?: AuthInfo; +} + +/** + * Union type for all client contexts + */ +export type ClientContext = OutgoingContext | IncomingContext; + +// ═══════════════════════════════════════════════════════════════════════════ +// Type-Specific Contexts +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Context for tool call requests + */ +export interface ToolCallContext extends OutgoingContext { + type: 'callTool'; + params: { + name: string; + arguments?: unknown; + }; +} + +/** + * Context for resource read requests + */ +export interface ResourceReadContext extends OutgoingContext { + type: 'readResource'; + params: { + uri: string; + }; +} + +/** + * Context for sampling requests (server β†’ client) + */ +export interface SamplingContext extends IncomingContext { + type: 'sampling'; + params: { + messages: unknown[]; + maxTokens?: number; + [key: string]: unknown; + }; +} + +/** + * Context for elicitation requests (server β†’ client) + */ +export interface ElicitationContext extends IncomingContext { + type: 'elicitation'; + params: { + message?: string; + [key: string]: unknown; + }; +} + +// ═══════════════════════════════════════════════════════════════════════════ +// MCP Middleware Types +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Next function for MCP client middleware + */ +export type ClientNextFn = (modifiedParams?: unknown) => Promise; + +/** + * Universal middleware for all MCP client requests + */ +export type ClientMiddleware = (ctx: ClientContext, next: ClientNextFn) => Promise; + +/** + * Middleware for outgoing requests only + */ +export type OutgoingMiddleware = (ctx: OutgoingContext, next: ClientNextFn) => Promise; + +/** + * Middleware for incoming requests only + */ +export type IncomingMiddleware = (ctx: IncomingContext, next: ClientNextFn) => Promise; + +/** + * Middleware specifically for tool calls + */ +export type ToolCallMiddleware = (ctx: ToolCallContext, next: ClientNextFn) => Promise; + +/** + * Middleware specifically for resource reads + */ +export type ResourceReadMiddleware = (ctx: ResourceReadContext, next: ClientNextFn) => Promise; + +/** + * Middleware specifically for sampling requests + */ +export type SamplingMiddleware = (ctx: SamplingContext, next: ClientNextFn) => Promise; + +/** + * Middleware specifically for elicitation requests + */ +export type ElicitationMiddleware = (ctx: ElicitationContext, next: ClientNextFn) => Promise; + +// ═══════════════════════════════════════════════════════════════════════════ +// MCP Middleware Manager +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Manages MCP middleware registration and execution for Client. + */ +export class ClientMiddlewareManager { + private _universalMiddleware: ClientMiddleware[] = []; + private _outgoingMiddleware: OutgoingMiddleware[] = []; + private _incomingMiddleware: IncomingMiddleware[] = []; + private _toolCallMiddleware: ToolCallMiddleware[] = []; + private _resourceReadMiddleware: ResourceReadMiddleware[] = []; + private _samplingMiddleware: SamplingMiddleware[] = []; + private _elicitationMiddleware: ElicitationMiddleware[] = []; + + /** + * Registers universal middleware that runs for all requests. + */ + useMiddleware(middleware: ClientMiddleware): this { + this._universalMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware for outgoing requests only. + */ + useOutgoingMiddleware(middleware: OutgoingMiddleware): this { + this._outgoingMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware for incoming requests only. + */ + useIncomingMiddleware(middleware: IncomingMiddleware): this { + this._incomingMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware specifically for tool calls. + */ + useToolCallMiddleware(middleware: ToolCallMiddleware): this { + this._toolCallMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware specifically for resource reads. + */ + useResourceReadMiddleware(middleware: ResourceReadMiddleware): this { + this._resourceReadMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware specifically for sampling requests. + */ + useSamplingMiddleware(middleware: SamplingMiddleware): this { + this._samplingMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware specifically for elicitation requests. + */ + useElicitationMiddleware(middleware: ElicitationMiddleware): this { + this._elicitationMiddleware.push(middleware); + return this; + } + + /** + * Executes the middleware chain for an outgoing tool call. + */ + async executeToolCall(ctx: ToolCallContext, handler: (params?: unknown) => Promise): Promise { + return this._executeChain( + ctx, + [ + ...this._adaptToTyped(this._universalMiddleware), + ...this._adaptToTyped(this._outgoingMiddleware as unknown as ClientMiddleware[]), + ...this._toolCallMiddleware + ], + handler + ); + } + + /** + * Executes the middleware chain for an outgoing resource read. + */ + async executeResourceRead( + ctx: ResourceReadContext, + handler: (params?: unknown) => Promise + ): Promise { + return this._executeChain( + ctx, + [ + ...this._adaptToTyped(this._universalMiddleware), + ...this._adaptToTyped(this._outgoingMiddleware as unknown as ClientMiddleware[]), + ...this._resourceReadMiddleware + ], + handler + ); + } + + /** + * Executes the middleware chain for an incoming sampling request. + */ + async executeSampling(ctx: SamplingContext, handler: (params?: unknown) => Promise): Promise { + return this._executeChain( + ctx, + [ + ...this._adaptToTyped(this._universalMiddleware), + ...this._adaptToTyped(this._incomingMiddleware as unknown as ClientMiddleware[]), + ...this._samplingMiddleware + ], + handler + ); + } + + /** + * Executes the middleware chain for an incoming elicitation request. + */ + async executeElicitation(ctx: ElicitationContext, handler: (params?: unknown) => Promise): Promise { + return this._executeChain( + ctx, + [ + ...this._adaptToTyped(this._universalMiddleware), + ...this._adaptToTyped(this._incomingMiddleware as unknown as ClientMiddleware[]), + ...this._elicitationMiddleware + ], + handler + ); + } + + /** + * Executes the middleware chain for a generic outgoing request. + */ + async executeOutgoing(ctx: OutgoingContext, handler: (params?: unknown) => Promise): Promise { + return this._executeChain( + ctx, + [ + ...this._adaptToTyped(this._universalMiddleware), + ...this._adaptToTyped(this._outgoingMiddleware as unknown as ClientMiddleware[]) + ], + handler + ); + } + + /** + * Executes the middleware chain for a generic incoming request. + */ + async executeIncoming(ctx: IncomingContext, handler: (params?: unknown) => Promise): Promise { + return this._executeChain( + ctx, + [ + ...this._adaptToTyped(this._universalMiddleware), + ...this._adaptToTyped(this._incomingMiddleware as unknown as ClientMiddleware[]) + ], + handler + ); + } + + /** + * Checks if any middleware is registered. + */ + hasMiddleware(): boolean { + return ( + this._universalMiddleware.length > 0 || + this._outgoingMiddleware.length > 0 || + this._incomingMiddleware.length > 0 || + this._toolCallMiddleware.length > 0 || + this._resourceReadMiddleware.length > 0 || + this._samplingMiddleware.length > 0 || + this._elicitationMiddleware.length > 0 + ); + } + + /** + * Clears all registered middleware. + */ + clear(): void { + this._universalMiddleware = []; + this._outgoingMiddleware = []; + this._incomingMiddleware = []; + this._toolCallMiddleware = []; + this._resourceReadMiddleware = []; + this._samplingMiddleware = []; + this._elicitationMiddleware = []; + } + + /** + * Adapts generic middleware to a typed middleware. + */ + private _adaptToTyped( + middlewares: ClientMiddleware[] + ): Array<(ctx: TCtx, next: ClientNextFn) => Promise> { + return middlewares.map(mw => { + return async (ctx: TCtx, next: ClientNextFn): Promise => { + return (await mw(ctx, next as ClientNextFn)) as TResult; + }; + }); + } + + /** + * Executes a chain of middleware. + */ + private async _executeChain( + ctx: TCtx, + middlewares: Array<(ctx: TCtx, next: ClientNextFn) => Promise>, + handler: (params?: unknown) => Promise + ): Promise { + let index = -1; + let currentParams: unknown = ctx.params; + + const dispatch = async (i: number, params?: unknown): Promise => { + if (i <= index) { + throw new Error('next() called multiple times'); + } + index = i; + if (params !== undefined) { + currentParams = params; + } + + if (i >= middlewares.length) { + return handler(currentParams); + } + + const middleware = middlewares[i]; + if (!middleware) { + return handler(currentParams); + } + return middleware(ctx, (modifiedParams?: unknown) => dispatch(i + 1, modifiedParams)); + }; + + return dispatch(0); + } +} diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index 2141cb12d..253f84a0e 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -1,5 +1,5 @@ import type { FetchLike, JSONRPCMessage, Transport } from '@modelcontextprotocol/core'; -import { createFetchWithInit, JSONRPCMessageSchema, normalizeHeaders } from '@modelcontextprotocol/core'; +import { createFetchWithInit, JSONRPCMessageSchema, normalizeHeaders, StateError } from '@modelcontextprotocol/core'; import type { ErrorEvent, EventSourceInit } from 'eventsource'; import { EventSource } from 'eventsource'; @@ -211,7 +211,7 @@ export class SSEClientTransport implements Transport { async start() { if (this._eventSource) { - throw new Error('SSEClientTransport already started! If using Client class, note that connect() calls start() automatically.'); + throw StateError.alreadyConnected(); } return await this._startOrAuth(); @@ -245,7 +245,7 @@ export class SSEClientTransport implements Transport { async send(message: JSONRPCMessage): Promise { if (!this._endpoint) { - throw new Error('Not connected'); + throw StateError.notConnected('send message'); } try { diff --git a/packages/client/src/client/stdio.ts b/packages/client/src/client/stdio.ts index 47df59e3b..178d979c4 100644 --- a/packages/client/src/client/stdio.ts +++ b/packages/client/src/client/stdio.ts @@ -4,7 +4,7 @@ import type { Stream } from 'node:stream'; import { PassThrough } from 'node:stream'; import type { JSONRPCMessage, Transport } from '@modelcontextprotocol/core'; -import { ReadBuffer, serializeMessage } from '@modelcontextprotocol/core'; +import { ReadBuffer, serializeMessage, StateError } from '@modelcontextprotocol/core'; import spawn from 'cross-spawn'; export type StdioServerParameters = { @@ -112,9 +112,7 @@ export class StdioClientTransport implements Transport { */ async start(): Promise { if (this._process) { - throw new Error( - 'StdioClientTransport already started! If using Client class, note that connect() calls start() automatically.' - ); + throw StateError.alreadyConnected(); } return new Promise((resolve, reject) => { @@ -246,7 +244,7 @@ export class StdioClientTransport implements Transport { send(message: JSONRPCMessage): Promise { return new Promise(resolve => { if (!this._process?.stdin) { - throw new Error('Not connected'); + throw StateError.notConnected('send message'); } const json = serializeMessage(message); diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index dbee90f31..82645d30e 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -7,7 +7,8 @@ import { isJSONRPCRequest, isJSONRPCResultResponse, JSONRPCMessageSchema, - normalizeHeaders + normalizeHeaders, + StateError } from '@modelcontextprotocol/core'; import { EventSourceParserStream } from 'eventsource-parser/stream'; @@ -422,9 +423,7 @@ export class StreamableHTTPClientTransport implements Transport { async start() { if (this._abortController) { - throw new Error( - 'StreamableHTTPClientTransport already started! If using Client class, note that connect() calls start() automatically.' - ); + throw StateError.alreadyConnected(); } this._abortController = new AbortController(); diff --git a/packages/client/src/client/websocket.ts b/packages/client/src/client/websocket.ts index cb0c34687..6b6eda667 100644 --- a/packages/client/src/client/websocket.ts +++ b/packages/client/src/client/websocket.ts @@ -1,5 +1,5 @@ import type { JSONRPCMessage, Transport } from '@modelcontextprotocol/core'; -import { JSONRPCMessageSchema } from '@modelcontextprotocol/core'; +import { JSONRPCMessageSchema, StateError } from '@modelcontextprotocol/core'; const SUBPROTOCOL = 'mcp'; @@ -20,9 +20,7 @@ export class WebSocketClientTransport implements Transport { start(): Promise { if (this._socket) { - throw new Error( - 'WebSocketClientTransport already started! If using Client class, note that connect() calls start() automatically.' - ); + throw StateError.alreadyConnected(); } return new Promise((resolve, reject) => { @@ -63,7 +61,7 @@ export class WebSocketClientTransport implements Transport { send(message: JSONRPCMessage): Promise { return new Promise((resolve, reject) => { if (!this._socket) { - reject(new Error('Not connected')); + reject(StateError.notConnected('send message')); return; } diff --git a/packages/client/src/experimental/tasks/client.ts b/packages/client/src/experimental/tasks/client.ts index df57e91a4..f0a1ec01c 100644 --- a/packages/client/src/experimental/tasks/client.ts +++ b/packages/client/src/experimental/tasks/client.ts @@ -20,7 +20,7 @@ import type { Result, SchemaOutput } from '@modelcontextprotocol/core'; -import { CallToolResultSchema, ErrorCode, McpError } from '@modelcontextprotocol/core'; +import { CallToolResultSchema, ErrorCode, ProtocolError, TaskClientPlugin } from '@modelcontextprotocol/core'; import type { Client } from '../../client/client.js'; @@ -56,6 +56,20 @@ export class ExperimentalClientTasks< > { constructor(private readonly _client: Client) {} + /** + * Gets the TaskClientPlugin, throwing if not installed. + */ + private _getTaskClient(): TaskClientPlugin { + const plugin = this._client.getPlugin(TaskClientPlugin); + if (!plugin) { + throw new ProtocolError( + ErrorCode.InternalError, + 'TaskClientPlugin not installed. Use client.usePlugin(new TaskClientPlugin()) first.' + ); + } + return plugin; + } + /** * Calls a tool and returns an AsyncGenerator that yields response messages. * The generator is guaranteed to end with either a 'result' or 'error' message. @@ -123,7 +137,7 @@ export class ExperimentalClientTasks< if (!result.structuredContent && !result.isError) { yield { type: 'error', - error: new McpError( + error: new ProtocolError( ErrorCode.InvalidRequest, `Tool ${params.name} has an output schema but did not return structured content` ) @@ -140,7 +154,7 @@ export class ExperimentalClientTasks< if (!validationResult.valid) { yield { type: 'error', - error: new McpError( + error: new ProtocolError( ErrorCode.InvalidParams, `Structured content does not match the tool's output schema: ${validationResult.errorMessage}` ) @@ -148,13 +162,13 @@ export class ExperimentalClientTasks< return; } } catch (error) { - if (error instanceof McpError) { + if (error instanceof ProtocolError) { yield { type: 'error', error }; return; } yield { type: 'error', - error: new McpError( + error: new ProtocolError( ErrorCode.InvalidParams, `Failed to validate structured content: ${error instanceof Error ? error.message : String(error)}` ) @@ -179,9 +193,7 @@ export class ExperimentalClientTasks< * @experimental */ async getTask(taskId: string, options?: RequestOptions): Promise { - // Delegate to the client's underlying Protocol method - type ClientWithGetTask = { getTask(params: { taskId: string }, options?: RequestOptions): Promise }; - return (this._client as unknown as ClientWithGetTask).getTask({ taskId }, options); + return this._getTaskClient().getTask({ taskId }, options); } /** @@ -195,16 +207,10 @@ export class ExperimentalClientTasks< * @experimental */ async getTaskResult(taskId: string, resultSchema?: T, options?: RequestOptions): Promise> { - // Delegate to the client's underlying Protocol method - return ( - this._client as unknown as { - getTaskResult: ( - params: { taskId: string }, - resultSchema?: U, - options?: RequestOptions - ) => Promise>; - } - ).getTaskResult({ taskId }, resultSchema, options); + if (!resultSchema) { + throw new ProtocolError(ErrorCode.InvalidParams, 'resultSchema is required'); + } + return this._getTaskClient().getTaskResult({ taskId }, resultSchema, options); } /** @@ -217,12 +223,7 @@ export class ExperimentalClientTasks< * @experimental */ async listTasks(cursor?: string, options?: RequestOptions): Promise { - // Delegate to the client's underlying Protocol method - return ( - this._client as unknown as { - listTasks: (params?: { cursor?: string }, options?: RequestOptions) => Promise; - } - ).listTasks(cursor ? { cursor } : undefined, options); + return this._getTaskClient().listTasks(cursor ? { cursor } : undefined, options); } /** @@ -234,12 +235,7 @@ export class ExperimentalClientTasks< * @experimental */ async cancelTask(taskId: string, options?: RequestOptions): Promise { - // Delegate to the client's underlying Protocol method - return ( - this._client as unknown as { - cancelTask: (params: { taskId: string }, options?: RequestOptions) => Promise; - } - ).cancelTask({ taskId }, options); + return this._getTaskClient().cancelTask({ taskId }, options); } /** diff --git a/packages/client/src/index.ts b/packages/client/src/index.ts index 787cfd2f0..71674f898 100644 --- a/packages/client/src/index.ts +++ b/packages/client/src/index.ts @@ -1,5 +1,6 @@ export * from './client/auth.js'; export * from './client/authExtensions.js'; +export * from './client/builder.js'; export * from './client/client.js'; export * from './client/middleware.js'; export * from './client/sse.js'; diff --git a/packages/core/src/errors.ts b/packages/core/src/errors.ts new file mode 100644 index 000000000..7b5dd81c5 --- /dev/null +++ b/packages/core/src/errors.ts @@ -0,0 +1,405 @@ +/** + * MCP SDK Error Hierarchy + * + * This module defines a comprehensive error hierarchy for the MCP SDK: + * + * 1. Protocol Errors - Errors that cross the wire as JSON-RPC errors + * - ProtocolError: Protocol-level errors with locked codes + * - Users can throw ProtocolError for intentional locked-code errors + * - Other errors thrown by users are customizable via onError handler + * + * 2. SDK Errors (SdkError subclasses) - Local errors that don't cross the wire + * - StateError: Wrong SDK state (not connected, already connected, etc.) + * - CapabilityError: Missing required capability + * - TransportError: Network/connection issues + * - ValidationError: Local schema validation issues + * + * 3. OAuth Errors - Kept in auth/errors.ts (unchanged) + */ + +import type { ElicitRequestURLParams } from './types/types.js'; +import { ErrorCode } from './types/types.js'; + +// ═══════════════════════════════════════════════════════════════════════════ +// SDK Error Codes (for local errors that don't cross the wire) +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Error codes for local SDK errors (not transmitted over JSON-RPC) + */ +export enum SdkErrorCode { + // State errors + NOT_CONNECTED = 'NOT_CONNECTED', + ALREADY_CONNECTED = 'ALREADY_CONNECTED', + INVALID_STATE = 'INVALID_STATE', + REGISTRATION_AFTER_CONNECT = 'REGISTRATION_AFTER_CONNECT', + + // Capability errors + CAPABILITY_NOT_SUPPORTED = 'CAPABILITY_NOT_SUPPORTED', + + // Transport errors + CONNECTION_FAILED = 'CONNECTION_FAILED', + CONNECTION_LOST = 'CONNECTION_LOST', + CONNECTION_TIMEOUT = 'CONNECTION_TIMEOUT', + SEND_FAILED = 'SEND_FAILED', + + // Validation errors + INVALID_SCHEMA = 'INVALID_SCHEMA', + INVALID_REQUEST = 'INVALID_REQUEST', + INVALID_RESPONSE = 'INVALID_RESPONSE' +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Protocol Errors (cross the wire as JSON-RPC errors) +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Protocol-level errors that cross the wire as JSON-RPC errors. + * The error code is LOCKED and cannot be changed in onProtocolError handlers. + * + * Use this when you want a specific error code that should not be customized: + * - SDK uses this for spec-mandated errors (parse error, method not found, etc.) + * - Users can throw this for intentional locked-code errors + * + * For errors where you want the onError handler to customize the response, + * throw a plain Error instead. + */ +export class ProtocolError extends Error { + /** + * Indicates this is a protocol-level error with a locked code + */ + readonly isProtocolLevel = true as const; + + constructor( + public readonly code: number, + message: string, + public readonly data?: unknown + ) { + super(`MCP error ${code}: ${message}`); + this.name = 'ProtocolError'; + } + + /** + * Creates a parse error (-32700) + */ + static parseError(message: string = 'Parse error', data?: unknown): ProtocolError { + return new ProtocolError(ErrorCode.ParseError, message, data); + } + + /** + * Creates an invalid request error (-32600) + */ + static invalidRequest(message: string = 'Invalid request', data?: unknown): ProtocolError { + return new ProtocolError(ErrorCode.InvalidRequest, message, data); + } + + /** + * Creates a method not found error (-32601) + */ + static methodNotFound(method: string, data?: unknown): ProtocolError { + return new ProtocolError(ErrorCode.MethodNotFound, `Method not found: ${method}`, data); + } + + /** + * Creates an invalid params error (-32602) + */ + static invalidParams(message: string = 'Invalid params', data?: unknown): ProtocolError { + return new ProtocolError(ErrorCode.InvalidParams, message, data); + } + + /** + * Creates an internal error (-32603) + */ + static internalError(message: string = 'Internal error', data?: unknown): ProtocolError { + return new ProtocolError(ErrorCode.InternalError, message, data); + } + + /** + * Factory method to create the appropriate error type based on the error code and data + */ + static fromError(code: number, message: string, data?: unknown): ProtocolError { + // Check for specific error types + if (code === ErrorCode.UrlElicitationRequired && data) { + const errorData = data as { elicitations?: unknown[] }; + if (errorData.elicitations) { + return new UrlElicitationRequiredError(errorData.elicitations as ElicitRequestURLParams[], message); + } + } + + // Default to generic ProtocolError + return new ProtocolError(code, message, data); + } +} + +/** + * Specialized error type when a tool requires a URL mode elicitation. + * This makes it nicer for the client to handle since there is specific data to work with. + */ +export class UrlElicitationRequiredError extends ProtocolError { + constructor(elicitations: ElicitRequestURLParams[], message: string = `URL elicitation${elicitations.length > 1 ? 's' : ''} required`) { + super(ErrorCode.UrlElicitationRequired, message, { + elicitations: elicitations + }); + this.name = 'UrlElicitationRequiredError'; + } + + get elicitations(): ElicitRequestURLParams[] { + return (this.data as { elicitations: ElicitRequestURLParams[] })?.elicitations ?? []; + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// SDK Error Hierarchy (local errors - don't cross the wire) +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Base class for local SDK errors that don't cross the wire. + * These are thrown locally and should be caught by the SDK user. + */ +export abstract class SdkError extends Error { + /** + * The SDK error code for programmatic handling + */ + abstract readonly code: SdkErrorCode; + + /** + * Whether this error is potentially recoverable + */ + readonly recoverable: boolean = false; + + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * Errors related to incorrect SDK state. + * Examples: "Not connected", "Already connected", "Cannot register after connecting" + */ +export class StateError extends SdkError { + readonly code: SdkErrorCode; + + constructor( + code: + | SdkErrorCode.NOT_CONNECTED + | SdkErrorCode.ALREADY_CONNECTED + | SdkErrorCode.INVALID_STATE + | SdkErrorCode.REGISTRATION_AFTER_CONNECT, + message: string + ) { + super(message); + this.code = code; + } + + /** + * Creates a "not connected" error + */ + static notConnected(operation: string = 'perform this operation'): StateError { + return new StateError(SdkErrorCode.NOT_CONNECTED, `Cannot ${operation}: not connected`); + } + + /** + * Creates an "already connected" error + */ + static alreadyConnected(): StateError { + return new StateError(SdkErrorCode.ALREADY_CONNECTED, 'Already connected'); + } + + /** + * Creates an "invalid state" error + */ + static invalidState(message: string): StateError { + return new StateError(SdkErrorCode.INVALID_STATE, message); + } + + /** + * Creates a "registration after connect" error + */ + static registrationAfterConnect(type: string): StateError { + return new StateError(SdkErrorCode.REGISTRATION_AFTER_CONNECT, `Cannot register ${type} after connecting`); + } +} + +/** + * Errors related to missing or unsupported capabilities. + * Example: "Server does not support X (required for Y)" + */ +export class CapabilityError extends SdkError { + readonly code = SdkErrorCode.CAPABILITY_NOT_SUPPORTED as const; + + constructor( + public readonly capability: string, + public readonly requiredFor?: string + ) { + const message = requiredFor + ? `Capability '${capability}' is not supported (required for ${requiredFor})` + : `Capability '${capability}' is not supported`; + super(message); + } + + /** + * Creates a capability error for a missing server capability + */ + static serverDoesNotSupport(capability: string, requiredFor?: string): CapabilityError { + return new CapabilityError(capability, requiredFor); + } + + /** + * Creates a capability error for a missing client capability + */ + static clientDoesNotSupport(capability: string, requiredFor?: string): CapabilityError { + return new CapabilityError(capability, requiredFor); + } +} + +/** + * Errors related to transport/network issues. + * Examples: Connection failed, timeout, connection lost + */ +export class TransportError extends SdkError { + readonly code: SdkErrorCode; + override readonly recoverable: boolean; + + constructor( + code: SdkErrorCode.CONNECTION_FAILED | SdkErrorCode.CONNECTION_LOST | SdkErrorCode.CONNECTION_TIMEOUT | SdkErrorCode.SEND_FAILED, + message: string, + public override readonly cause?: Error + ) { + super(message); + this.code = code; + // Connection lost and timeout are potentially recoverable via retry + this.recoverable = code === SdkErrorCode.CONNECTION_LOST || code === SdkErrorCode.CONNECTION_TIMEOUT; + } + + /** + * Creates a connection failed error + */ + static connectionFailed(message: string = 'Connection failed', cause?: Error): TransportError { + return new TransportError(SdkErrorCode.CONNECTION_FAILED, message, cause); + } + + /** + * Creates a connection lost error + */ + static connectionLost(message: string = 'Connection lost', cause?: Error): TransportError { + const error = new TransportError(SdkErrorCode.CONNECTION_LOST, message, cause); + return error; + } + + /** + * Creates a connection timeout error + */ + static connectionTimeout(timeoutMs: number, cause?: Error): TransportError { + return new TransportError(SdkErrorCode.CONNECTION_TIMEOUT, `Connection timed out after ${timeoutMs}ms`, cause); + } + + /** + * Creates a request timeout error (request sent but no response received in time) + */ + static requestTimeout( + message: string = 'Request timed out', + details?: { timeout?: number; maxTotalTimeout?: number; totalElapsed?: number } + ): TransportError { + const detailsStr = details ? ` (${JSON.stringify(details)})` : ''; + return new TransportError(SdkErrorCode.CONNECTION_TIMEOUT, `${message}${detailsStr}`); + } + + /** + * Creates a send failed error + */ + static sendFailed(message: string = 'Failed to send message', cause?: Error): TransportError { + return new TransportError(SdkErrorCode.SEND_FAILED, message, cause); + } + + /** + * Creates a connection closed error + */ + static connectionClosed(message: string = 'Connection closed'): TransportError { + return new TransportError(SdkErrorCode.CONNECTION_LOST, message); + } +} + +/** + * Errors related to local schema/validation issues (before sending). + * Examples: "Schema is missing a method literal", "Invalid request format" + */ +export class ValidationError extends SdkError { + readonly code: SdkErrorCode; + + constructor( + code: SdkErrorCode.INVALID_SCHEMA | SdkErrorCode.INVALID_REQUEST | SdkErrorCode.INVALID_RESPONSE, + message: string, + public readonly details?: unknown + ) { + super(message); + this.code = code; + } + + /** + * Creates an invalid schema error + */ + static invalidSchema(message: string, details?: unknown): ValidationError { + return new ValidationError(SdkErrorCode.INVALID_SCHEMA, message, details); + } + + /** + * Creates an invalid request error (local validation) + */ + static invalidRequest(message: string, details?: unknown): ValidationError { + return new ValidationError(SdkErrorCode.INVALID_REQUEST, message, details); + } + + /** + * Creates an invalid response error (local validation) + */ + static invalidResponse(message: string, details?: unknown): ValidationError { + return new ValidationError(SdkErrorCode.INVALID_RESPONSE, message, details); + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Type Guards +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Type guard to check if an error is a ProtocolError + */ +export function isProtocolError(error: unknown): error is ProtocolError { + return error instanceof ProtocolError; +} + +/** + * Type guard to check if an error is an SdkError + */ +export function isSdkError(error: unknown): error is SdkError { + return error instanceof SdkError; +} + +/** + * Type guard to check if an error is a StateError + */ +export function isStateError(error: unknown): error is StateError { + return error instanceof StateError; +} + +/** + * Type guard to check if an error is a CapabilityError + */ +export function isCapabilityError(error: unknown): error is CapabilityError { + return error instanceof CapabilityError; +} + +/** + * Type guard to check if an error is a TransportError + */ +export function isTransportError(error: unknown): error is TransportError { + return error instanceof TransportError; +} + +/** + * Type guard to check if an error is a ValidationError + */ +export function isValidationError(error: unknown): error is ValidationError { + return error instanceof ValidationError; +} diff --git a/packages/core/src/experimental/index.ts b/packages/core/src/experimental/index.ts index ea39eb79f..b59a388b4 100644 --- a/packages/core/src/experimental/index.ts +++ b/packages/core/src/experimental/index.ts @@ -1,3 +1,4 @@ +export * from './requestTaskStore.js'; export * from './tasks/helpers.js'; export * from './tasks/interfaces.js'; export * from './tasks/stores/inMemory.js'; diff --git a/packages/core/src/experimental/requestTaskStore.ts b/packages/core/src/experimental/requestTaskStore.ts new file mode 100644 index 000000000..f3e9887af --- /dev/null +++ b/packages/core/src/experimental/requestTaskStore.ts @@ -0,0 +1,122 @@ +import type { JSONRPCRequest, RequestId, Result, Task } from '../types/types.js'; +import type { CreateTaskOptions, TaskStore } from './tasks/interfaces.js'; + +/** + * Request-scoped TaskStore interface. + */ +export interface RequestTaskStoreInterface { + /** + * Creates a new task with the given creation parameters. + * The implementation generates a unique taskId and createdAt timestamp. + * + * @param taskParams - The task creation parameters from the request + * @returns The created task object + */ + createTask(taskParams: CreateTaskOptions): Promise; + + /** + * Gets the current status of a task. + * + * @param taskId - The task identifier + * @returns The task object + * @throws If the task does not exist + */ + getTask(taskId: string): Promise; + + /** + * Stores the result of a task and sets its final status. + * + * @param taskId - The task identifier + * @param status - The final status: 'completed' for success, 'failed' for errors + * @param result - The result to store + */ + storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result): Promise; + + /** + * Retrieves the stored result of a task. + * + * @param taskId - The task identifier + * @returns The stored result + */ + getTaskResult(taskId: string): Promise; + + /** + * Updates a task's status (e.g., to 'cancelled', 'failed', 'completed'). + * + * @param taskId - The task identifier + * @param status - The new status + * @param statusMessage - Optional diagnostic message for failed tasks or other status information + */ + updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string): Promise; + + /** + * Lists tasks, optionally starting from a pagination cursor. + * + * @param cursor - Optional cursor for pagination + * @returns An object containing the tasks array and an optional nextCursor + */ + listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; +} + +/** + * Request-scoped task store implementation that wraps a TaskStore with session binding + * and provides a mutable task ID that updates after task creation. + */ +export class RequestTaskStore implements RequestTaskStoreInterface { + private readonly _taskStore: TaskStore; + private readonly _requestId: RequestId; + private readonly _request: JSONRPCRequest; + private readonly _sessionId: string | undefined; + private readonly _taskIdHolder: { id: string }; + + constructor(args: { + taskStore: TaskStore; + requestId: RequestId; + request: JSONRPCRequest; + sessionId: string | undefined; + initialTaskId: string; + }) { + this._taskStore = args.taskStore; + this._requestId = args.requestId; + this._request = args.request; + this._sessionId = args.sessionId; + this._taskIdHolder = { id: args.initialTaskId }; + } + + /** + * Gets the current task ID. This may be updated after createTask is called. + */ + get currentTaskId(): string { + return this._taskIdHolder.id; + } + + async createTask(taskParams: CreateTaskOptions): Promise { + const task = await this._taskStore.createTask(taskParams, this._requestId, this._request, this._sessionId); + // Update the task ID so subsequent sendRequest/sendNotification calls + // will use the correct task ID for message routing + this._taskIdHolder.id = task.taskId; + return task; + } + + async getTask(taskId: string): Promise { + const task = await this._taskStore.getTask(taskId, this._sessionId); + if (!task) throw new Error(`Task not found: ${taskId}`); + return task; + } + + async storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result): Promise { + return this._taskStore.storeTaskResult(taskId, status, result, this._sessionId); + } + + async getTaskResult(taskId: string): Promise { + return this._taskStore.getTaskResult(taskId, this._sessionId); + } + + async updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string): Promise { + return this._taskStore.updateTaskStatus(taskId, status, statusMessage, this._sessionId); + } + + async listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }> { + return this._taskStore.listTasks(this._sessionId, cursor); + } +} diff --git a/packages/core/src/experimental/tasks/interfaces.ts b/packages/core/src/experimental/tasks/interfaces.ts index c1901d70a..4bf11942c 100644 --- a/packages/core/src/experimental/tasks/interfaces.ts +++ b/packages/core/src/experimental/tasks/interfaces.ts @@ -3,7 +3,6 @@ * WARNING: These APIs are experimental and may change without notice. */ -import type { RequestHandlerExtra, RequestTaskStore } from '../../shared/protocol.js'; import type { JSONRPCErrorResponse, JSONRPCNotification, @@ -12,8 +11,6 @@ import type { Request, RequestId, Result, - ServerNotification, - ServerRequest, Task, ToolExecution } from '../../types/types.js'; @@ -22,23 +19,6 @@ import type { // Task Handler Types (for registerToolTask) // ============================================================================ -/** - * Extended handler extra with task store for task creation. - * @experimental - */ -export interface CreateTaskRequestHandlerExtra extends RequestHandlerExtra { - taskStore: RequestTaskStore; -} - -/** - * Extended handler extra with task ID and store for task operations. - * @experimental - */ -export interface TaskRequestHandlerExtra extends RequestHandlerExtra { - taskId: string; - taskStore: RequestTaskStore; -} - /** * Task-specific execution configuration. * taskSupport cannot be 'forbidden' for task-based tools. diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index b7980fadb..e111b9956 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -1,14 +1,25 @@ export * from './auth/errors.js'; +export * from './errors.js'; export * from './shared/auth.js'; export * from './shared/authUtils.js'; +export * from './shared/context.js'; +export * from './shared/events.js'; +export * from './shared/handlerRegistry.js'; export * from './shared/metadataUtils.js'; +export * from './shared/plugin.js'; +export * from './shared/pluginContext.js'; +export * from './shared/progressManager.js'; export * from './shared/protocol.js'; export * from './shared/responseMessage.js'; export * from './shared/stdio.js'; +export * from './shared/taskClientPlugin.js'; +export * from './shared/taskPlugin.js'; +export * from './shared/timeoutManager.js'; export * from './shared/toolNameValidation.js'; export * from './shared/transport.js'; export * from './shared/uriTemplate.js'; export * from './types/types.js'; +export * from './util/content.js'; export * from './util/inMemory.js'; export * from './util/zodCompat.js'; export * from './util/zodJsonSchemaCompat.js'; diff --git a/packages/core/src/shared/context.ts b/packages/core/src/shared/context.ts new file mode 100644 index 000000000..df05d906b --- /dev/null +++ b/packages/core/src/shared/context.ts @@ -0,0 +1,230 @@ +import type { RequestTaskStoreInterface } from '../experimental/requestTaskStore.js'; +import type { + AuthInfo, + JSONRPCRequest, + Notification, + RelatedTaskMetadata, + Request, + RequestId, + RequestMeta, + Result +} from '../types/types.js'; +import type { AnySchema, SchemaOutput } from '../util/zodCompat.js'; +import type { NotificationOptions, Protocol, RequestOptions } from './protocol.js'; + +/** + * Internal type for options that may include task-related fields. + * Used by context methods that need to set relatedTask. + */ +type OptionsWithTask = { relatedTask?: RelatedTaskMetadata }; + +/** + * MCP-level context for a request being handled. + * Contains information about the JSON-RPC request and session. + */ +export type McpContext = { + /** + * The JSON-RPC ID of the request being handled. + * This can be useful for tracking or logging purposes. + */ + requestId: RequestId; + /** + * The method of the request. + */ + method: string; + /** + * The metadata of the request. + */ + _meta?: RequestMeta; + /** + * The session ID of the request. + */ + sessionId?: string; +}; + +/** + * Base request context with fields common to both client and server. + */ +export type BaseRequestContext = { + /** + * An abort signal used to communicate if the request was cancelled. + */ + signal: AbortSignal; + /** + * The authentication information, if available. + */ + authInfo?: AuthInfo; +}; + +/** + * Task-related context for task-augmented requests. + */ +export type TaskContext = { + /** + * The ID of the task. + */ + id: string; + /** + * The task store for managing task state. + */ + store: RequestTaskStoreInterface; + /** + * The requested TTL for the task, or null if not specified. + */ + requestedTtl: number | null; +}; + +/** + * Base context interface for request handlers. + * Generic over request type, notification type, and request context type. + * + * @typeParam RequestT - The type of requests that can be sent from this context + * @typeParam NotificationT - The type of notifications that can be sent from this context + * @typeParam RequestContextT - The type of request context (server or client specific) + */ +export interface ContextInterface< + RequestT extends Request = Request, + NotificationT extends Notification = Notification, + RequestContextT extends BaseRequestContext = BaseRequestContext +> { + /** + * MCP-level context containing request ID, method, metadata, and session info. + */ + mcpCtx: McpContext; + /** + * Request-specific context (transport/HTTP details). + */ + requestCtx: RequestContextT; + /** + * Task context if this is a task-augmented request, undefined otherwise. + */ + taskCtx: TaskContext | undefined; + /** + * Sends a notification that relates to the current request being handled. + * This is used by certain transports to correctly associate related messages. + */ + sendNotification: (notification: NotificationT) => Promise; + /** + * Sends a request that relates to the current request being handled. + * This is used by certain transports to correctly associate related messages. + */ + sendRequest: (request: RequestT, resultSchema: U, options?: RequestOptions) => Promise>; +} + +/** + * Arguments for constructing a BaseContext. + */ +export interface BaseContextArgs { + /** + * The JSON-RPC request being handled. + */ + request: JSONRPCRequest; + /** + * The MCP context for the request. + */ + mcpContext: McpContext; + /** + * The request-specific context (transport/HTTP details). + */ + requestCtx: RequestContextT; + /** + * The task context, if the request is task-augmented. + * Will be added by plugins if task support is enabled. + */ + task?: TaskContext; +} + +/** + * Abstract base class for context objects passed to request handlers. + * Provides shared implementation for sendNotification and sendRequest. + * + * @typeParam RequestT - The type of requests that can be sent from this context + * @typeParam NotificationT - The type of notifications that can be sent from this context + * @typeParam RequestContextT - The type of request context (server or client specific) + */ +export abstract class BaseContext< + RequestT extends Request = Request, + NotificationT extends Notification = Notification, + RequestContextT extends BaseRequestContext = BaseRequestContext, + ResultT extends Result = Result +> implements ContextInterface +{ + /** + * The MCP context - Contains information about the current MCP request and session. + */ + public readonly mcpCtx: McpContext; + + /** + * The request context with transport-specific fields. + */ + public readonly requestCtx: RequestContextT; + + /** + * The task context, if the request is task-augmented. + * This property can be set by plugins via onBuildHandlerContext. + */ + public taskCtx: TaskContext | undefined; + + /** + * Returns the protocol instance for sending notifications and requests. + * Subclasses must implement this to provide the appropriate Client or Server instance. + */ + protected abstract getProtocol(): Protocol; + + constructor(args: BaseContextArgs) { + this.mcpCtx = { + requestId: args.request.id, + method: args.mcpContext.method, + _meta: args.mcpContext._meta, + sessionId: args.mcpContext.sessionId + }; + this.requestCtx = args.requestCtx; + // Use the task object directly instead of copying to preserve any getters + // (e.g., the id getter that updates after createTask is called) + this.taskCtx = args.task; + } + + /** + * Sends a notification that relates to the current request being handled. + * This is used by certain transports to correctly associate related messages. + * Note: This is an arrow function to preserve 'this' binding when destructured. + */ + public sendNotification = async (notification: NotificationT): Promise => { + const notificationOptions: NotificationOptions & OptionsWithTask = { relatedRequestId: this.mcpCtx.requestId }; + + // Only set relatedTask if there's a valid (non-empty) task ID + // Empty task ID means no task has been created yet or task queuing isn't applicable + if (this.taskCtx && this.taskCtx.id) { + notificationOptions.relatedTask = { taskId: this.taskCtx.id }; + } + + return this.getProtocol().notification(notification, notificationOptions); + }; + + /** + * Sends a request that relates to the current request being handled. + * This is used by certain transports to correctly associate related messages. + * Note: This is an arrow function to preserve 'this' binding when destructured. + */ + public sendRequest = async ( + request: RequestT, + resultSchema: U, + options?: RequestOptions + ): Promise> => { + const requestOptions: RequestOptions & OptionsWithTask = { ...options, relatedRequestId: this.mcpCtx.requestId }; + + // Only set relatedTask if there's a valid (non-empty) task ID + // Empty task ID means no task has been created yet or task queuing isn't applicable + const taskId = this.taskCtx?.id; + if (taskId) { + requestOptions.relatedTask = { taskId }; + + // Set task status to input_required when sending a request within a task context + if (this.taskCtx?.store) { + await this.taskCtx.store.updateTaskStatus(taskId, 'input_required'); + } + } + + return await this.getProtocol().request(request, resultSchema, requestOptions); + }; +} diff --git a/packages/core/src/shared/events.ts b/packages/core/src/shared/events.ts new file mode 100644 index 000000000..60adc039d --- /dev/null +++ b/packages/core/src/shared/events.ts @@ -0,0 +1,274 @@ +/** + * Event Emitter System + * + * A lightweight, type-safe event emitter for SDK observability. + * + * Design decisions: + * - Custom implementation instead of Node's EventEmitter for cross-platform compatibility + * - Works in Node.js, browsers, and edge runtimes + * - Type-safe event names and payloads + * - Modern API with unsubscribe function returned from `on()` + */ + +/** + * Type-safe event emitter interface. + * Events is a record mapping event names to their payload types. + */ +export interface McpEventEmitter> { + /** + * Subscribe to an event. + * @param event - The event name + * @param listener - The callback to invoke when the event is emitted + * @returns An unsubscribe function + */ + on(event: K, listener: (data: Events[K]) => void): () => void; + + /** + * Subscribe to an event for a single occurrence. + * @param event - The event name + * @param listener - The callback to invoke when the event is emitted + * @returns An unsubscribe function + */ + once(event: K, listener: (data: Events[K]) => void): () => void; + + /** + * Unsubscribe from an event. + * @param event - The event name + * @param listener - The callback to remove + */ + off(event: K, listener: (data: Events[K]) => void): void; + + /** + * Emit an event with data. + * @param event - The event name + * @param data - The event payload + */ + emit(event: K, data: Events[K]): void; +} + +/** + * Type-safe event emitter implementation. + * Provides a minimal, cross-platform event system. + */ +export class TypedEventEmitter> implements McpEventEmitter { + private _listeners = new Map void>>(); + + /** + * Subscribe to an event. + * + * @param event - The event name + * @param listener - The callback to invoke when the event is emitted + * @returns An unsubscribe function + * + * @example + * ```typescript + * const unsubscribe = emitter.on('connection:opened', ({ sessionId }) => { + * console.log(`Connected: ${sessionId}`); + * }); + * + * // Later, to unsubscribe: + * unsubscribe(); + * ``` + */ + on(event: K, listener: (data: Events[K]) => void): () => void { + if (!this._listeners.has(event)) { + this._listeners.set(event, new Set()); + } + const listeners = this._listeners.get(event)!; + listeners.add(listener as (data: unknown) => void); + + // Return unsubscribe function + return () => this.off(event, listener); + } + + /** + * Subscribe to an event for a single occurrence. + * The listener is automatically removed after the first invocation. + * + * @param event - The event name + * @param listener - The callback to invoke when the event is emitted + * @returns An unsubscribe function + */ + once(event: K, listener: (data: Events[K]) => void): () => void { + const wrapper = (data: Events[K]): void => { + this.off(event, wrapper); + listener(data); + }; + return this.on(event, wrapper); + } + + /** + * Unsubscribe from an event. + * + * @param event - The event name + * @param listener - The callback to remove + */ + off(event: K, listener: (data: Events[K]) => void): void { + const listeners = this._listeners.get(event); + if (listeners) { + listeners.delete(listener as (data: unknown) => void); + if (listeners.size === 0) { + this._listeners.delete(event); + } + } + } + + /** + * Emit an event with data. + * All registered listeners for the event will be invoked synchronously. + * + * @param event - The event name + * @param data - The event payload + */ + emit(event: K, data: Events[K]): void { + const listeners = this._listeners.get(event); + if (listeners) { + // Create a copy to allow listeners to unsubscribe during iteration + for (const listener of listeners) { + try { + listener(data); + } catch { + // Silently ignore listener errors to prevent one listener + // from breaking others. Errors should be handled by the listener. + } + } + } + } + + /** + * Check if any listeners are registered for an event. + * + * @param event - The event name + * @returns true if there are listeners for the event + */ + hasListeners(event: K): boolean { + const listeners = this._listeners.get(event); + return listeners !== undefined && listeners.size > 0; + } + + /** + * Get the number of listeners for an event. + * + * @param event - The event name + * @returns The number of listeners + */ + listenerCount(event: K): number { + const listeners = this._listeners.get(event); + return listeners?.size ?? 0; + } + + /** + * Remove all listeners for a specific event, or all events if no event is specified. + * + * @param event - Optional event name. If not provided, removes all listeners. + */ + removeAllListeners(event?: K): void { + if (event === undefined) { + this._listeners.clear(); + } else { + this._listeners.delete(event); + } + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Pre-defined Event Maps for SDK Components +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Events emitted by McpServer. + */ +export interface McpServerEvents { + [key: string]: unknown; + + /** + * Emitted when a tool is registered. + */ + 'tool:registered': { name: string; tool: unknown }; + + /** + * Emitted when a tool is removed. + */ + 'tool:removed': { name: string }; + + /** + * Emitted when a resource is registered. + */ + 'resource:registered': { uri: string; resource: unknown }; + + /** + * Emitted when a resource is removed. + */ + 'resource:removed': { uri: string }; + + /** + * Emitted when a prompt is registered. + */ + 'prompt:registered': { name: string; prompt: unknown }; + + /** + * Emitted when a prompt is removed. + */ + 'prompt:removed': { name: string }; + + /** + * Emitted when a connection is opened. + */ + 'connection:opened': { sessionId: string }; + + /** + * Emitted when a connection is closed. + */ + 'connection:closed': { sessionId: string; reason?: string }; + + /** + * Emitted when an error occurs. + */ + error: { error: Error; context?: string }; +} + +/** + * Events emitted by Client. + */ +export interface McpClientEvents { + [key: string]: unknown; + + /** + * Emitted when a connection is opened. + */ + 'connection:opened': { sessionId: string }; + + /** + * Emitted when a connection is closed. + */ + 'connection:closed': { sessionId: string; reason?: string }; + + /** + * Emitted when a tool call is made. + */ + 'tool:called': { name: string; args: unknown }; + + /** + * Emitted when a tool call returns a result. + */ + 'tool:result': { name: string; result: unknown }; + + /** + * Emitted when an error occurs. + */ + error: { error: Error; context?: string }; +} + +/** + * Creates a new typed event emitter for McpServer events. + */ +export function createServerEventEmitter(): TypedEventEmitter { + return new TypedEventEmitter(); +} + +/** + * Creates a new typed event emitter for Client events. + */ +export function createClientEventEmitter(): TypedEventEmitter { + return new TypedEventEmitter(); +} diff --git a/packages/core/src/shared/handlerRegistry.ts b/packages/core/src/shared/handlerRegistry.ts new file mode 100644 index 000000000..12cb0a0d8 --- /dev/null +++ b/packages/core/src/shared/handlerRegistry.ts @@ -0,0 +1,184 @@ +/** + * Handler Registry + * + * Manages request and notification handlers for the Protocol class. + * Extracted from Protocol to follow Single Responsibility Principle. + * + * This registry is focused on storage and management - it does NOT handle: + * - Schema parsing (handled by Protocol) + * - Capability assertions (handled by Protocol) + */ + +import type { JSONRPCNotification, JSONRPCRequest, Notification, Request, RequestId, Result } from '../types/types.js'; +import type { BaseRequestContext, ContextInterface } from './context.js'; + +/** + * Internal handler type for request handlers (after parsing by Protocol) + */ +export type InternalRequestHandler = ( + request: JSONRPCRequest, + extra: ContextInterface +) => Promise; + +/** + * Internal notification handler type (after parsing by Protocol) + */ +export type InternalNotificationHandler = (notification: JSONRPCNotification) => Promise; + +/** + * Manages request and notification handlers for the Protocol. + * Focused on storage, retrieval, and abort controller management. + */ +export class HandlerRegistry { + private _requestHandlers = new Map>(); + private _notificationHandlers = new Map(); + private _requestHandlerAbortControllers = new Map(); + + /** + * A handler to invoke for any request types that do not have their own handler installed. + */ + fallbackRequestHandler?: InternalRequestHandler; + + /** + * A handler to invoke for any notification types that do not have their own handler installed. + */ + fallbackNotificationHandler?: (notification: Notification) => Promise; + + // ═══════════════════════════════════════════════════════════════════════════ + // Request Handler Management + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Sets a request handler for a method. + * The handler should already be wrapped to handle JSONRPCRequest. + */ + setRequestHandler(method: string, handler: InternalRequestHandler): void { + this._requestHandlers.set(method, handler); + } + + /** + * Gets a request handler for a method, or the fallback handler if none exists. + */ + getRequestHandler(method: string): InternalRequestHandler | undefined { + return this._requestHandlers.get(method) ?? this.fallbackRequestHandler; + } + + /** + * Checks if a request handler exists for a method. + */ + hasRequestHandler(method: string): boolean { + return this._requestHandlers.has(method); + } + + /** + * Removes a request handler for a method. + */ + removeRequestHandler(method: string): void { + this._requestHandlers.delete(method); + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Notification Handler Management + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Sets a notification handler for a method. + * The handler should already be wrapped to handle JSONRPCNotification. + */ + setNotificationHandler(method: string, handler: InternalNotificationHandler): void { + this._notificationHandlers.set(method, handler); + } + + /** + * Gets a notification handler for a method, or the fallback handler if none exists. + */ + getNotificationHandler(method: string): InternalNotificationHandler | undefined { + const handler = this._notificationHandlers.get(method); + if (handler) return handler; + // Wrap fallback to match InternalNotificationHandler signature + if (this.fallbackNotificationHandler) { + return async (notification: JSONRPCNotification) => { + await this.fallbackNotificationHandler!(notification as Notification); + }; + } + return undefined; + } + + /** + * Checks if a notification handler exists for a method. + */ + hasNotificationHandler(method: string): boolean { + return this._notificationHandlers.has(method); + } + + /** + * Removes a notification handler for a method. + */ + removeNotificationHandler(method: string): void { + this._notificationHandlers.delete(method); + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Abort Controller Management + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Creates an AbortController for a request and stores it. + */ + createAbortController(requestId: RequestId): AbortController { + const controller = new AbortController(); + this._requestHandlerAbortControllers.set(requestId, controller); + return controller; + } + + /** + * Gets the AbortController for a request. + */ + getAbortController(requestId: RequestId): AbortController | undefined { + return this._requestHandlerAbortControllers.get(requestId); + } + + /** + * Removes the AbortController for a request. + */ + removeAbortController(requestId: RequestId): void { + this._requestHandlerAbortControllers.delete(requestId); + } + + /** + * Aborts all pending request handlers. + */ + abortAllPendingRequests(reason?: string): void { + for (const controller of this._requestHandlerAbortControllers.values()) { + controller.abort(reason); + } + this._requestHandlerAbortControllers.clear(); + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Utility Methods + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Gets all registered request handler methods. + */ + getRequestMethods(): string[] { + return [...this._requestHandlers.keys()]; + } + + /** + * Gets all registered notification handler methods. + */ + getNotificationMethods(): string[] { + return [...this._notificationHandlers.keys()]; + } + + /** + * Clears all handlers and abort controllers. + */ + clear(): void { + this._requestHandlers.clear(); + this._notificationHandlers.clear(); + this.abortAllPendingRequests('Registry cleared'); + } +} diff --git a/packages/core/src/shared/plugin.ts b/packages/core/src/shared/plugin.ts new file mode 100644 index 000000000..0fa58f2d1 --- /dev/null +++ b/packages/core/src/shared/plugin.ts @@ -0,0 +1,481 @@ +/** + * Protocol Plugin System + * + * This module defines the plugin interface for extending Protocol functionality. + * Plugins are INTERNAL to the SDK - they are used for decomposing the Protocol class + * into focused components. They are not exposed as a public API for SDK users. + * + * For application-level extensibility (logging, auth, metrics), SDK users should + * use McpServer Middleware (see server/middleware.ts) or Client Middleware. + */ + +import type { + JSONRPCErrorResponse, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCResultResponse, + RequestId, + Result +} from '../types/types.js'; +import type { AnyObjectSchema, SchemaOutput } from '../util/zodCompat.js'; +import type { ProgressManagerInterface } from './progressManager.js'; +import type { Transport, TransportSendOptions } from './transport.js'; + +// ═══════════════════════════════════════════════════════════════════════════ +// Sub-Component Interfaces +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Interface for transport-related operations accessible to plugins. + */ +export interface PluginTransportInterface { + /** + * Get the current transport (may be undefined if not connected) + */ + getTransport(): Transport | undefined; + + /** + * Get the session ID (if available) + */ + getSessionId(): string | undefined; + + /** + * Send a message through the transport + */ + send( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + options?: TransportSendOptions + ): Promise; +} + +/** + * Interface for making outbound requests from plugins. + */ +export interface PluginRequestsInterface { + /** + * Send a request through the protocol and wait for a response. + * + * @param request - The request to send + * @param resultSchema - Schema to validate the response + * @param options - Optional request options (timeout, signal, etc.) + * @returns The validated response + */ + sendRequest( + request: JSONRPCRequest, + resultSchema: T, + options?: PluginRequestOptions + ): Promise>; +} + +/** + * Interface for registering and managing handlers. + */ +export interface PluginHandlersInterface { + /** + * Register a request handler for a specific method. + * Handler returns SendResultT to ensure type safety with the Protocol. + */ + setRequestHandler( + schema: T, + handler: (request: SchemaOutput, extra: PluginHandlerExtra) => SendResultT | Promise + ): void; + + /** + * Register a notification handler for a specific method + */ + setNotificationHandler(schema: T, handler: (notification: SchemaOutput) => void | Promise): void; + + /** + * Remove a request handler + */ + removeRequestHandler(method: string): void; + + /** + * Remove a notification handler + */ + removeNotificationHandler(method: string): void; +} + +/** + * Interface for managing request resolvers. + * Used by TaskPlugin for routing queued responses back to their original callers. + */ +export interface PluginResolversInterface { + /** + * Register a resolver for a pending request. + */ + register(id: RequestId, resolver: (response: JSONRPCResultResponse | Error) => void): void; + + /** + * Get a resolver for a pending request. + */ + get(id: RequestId): ((response: JSONRPCResultResponse | Error) => void) | undefined; + + /** + * Remove a resolver for a pending request. + */ + remove(id: RequestId): void; +} + +/** + * Options for plugin requests. + */ +export interface PluginRequestOptions { + /** + * Timeout in milliseconds for the request + */ + timeout?: number; + + /** + * Abort signal for cancelling the request + */ + signal?: AbortSignal; + + /** Allow additional options */ + [key: string]: unknown; +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Plugin Context +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Context provided to plugins during installation. + * Composed of focused sub-components for different concerns. + */ +export interface PluginContext { + /** + * Transport operations (get transport, send messages) + */ + readonly transport: PluginTransportInterface; + + /** + * Outbound request operations + */ + readonly requests: PluginRequestsInterface; + + /** + * Handler registration and management + */ + readonly handlers: PluginHandlersInterface; + + /** + * Request resolver management (for task response routing) + */ + readonly resolvers: PluginResolversInterface; + + /** + * Progress handler management + */ + readonly progress: ProgressManagerInterface; + + /** + * Report an error through the protocol's error handling + */ + reportError(error: Error): void; +} + +/** + * Extra context passed to plugin request handlers. + */ +export interface PluginHandlerExtra { + /** + * MCP context with request metadata + */ + readonly mcpCtx: { + readonly requestId: RequestId; + readonly sessionId?: string; + }; + + /** + * Request context with abort signal + */ + readonly requestCtx: { + readonly signal: AbortSignal; + }; +} + +/** + * Context provided to plugin hooks during request processing. + */ +export interface RequestContext { + /** + * The session ID for this request + */ + readonly sessionId?: string; + + /** + * The request ID from the JSON-RPC message + */ + readonly requestId: number | string; + + /** + * The method being called + */ + readonly method: string; +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Protocol Plugin Interface +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Plugin interface for extending Protocol functionality. + * + * Plugins are internal SDK components for decomposing the Protocol class. + * They can: + * - Register handlers during installation + * - Hook into request/response lifecycle + * - Route messages (e.g., for task queueing) + * + * Note: Plugins are NOT a public API for SDK users. For application-level + * extensibility, use McpServer/Client middleware instead. + */ +export interface ProtocolPlugin { + /** + * Unique name for this plugin (for debugging and identification) + */ + readonly name: string; + + /** + * Priority determines execution order. Higher priority = runs first. + * Default: 0 + */ + readonly priority?: number; + + // ─── LIFECYCLE HOOKS ─── + + /** + * Called when the plugin is installed on a Protocol instance. + * Use this to register handlers, set up state, etc. + */ + install?(ctx: PluginContext): void | Promise; + + /** + * Called when a transport is connected. + */ + onConnect?(transport: Transport): void | Promise; + + /** + * Called when the connection is closed. + */ + onClose?(): void | Promise; + + // ─── MESSAGE ROUTING ─── + + /** + * Determines if this plugin should route the message instead of the default transport. + * Used by TaskPlugin to queue messages for task-related responses. + */ + shouldRouteMessage?( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + options?: TransportSendOptions + ): boolean; + + /** + * Routes the message. Only called if shouldRouteMessage returned true. + */ + routeMessage?( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + options?: TransportSendOptions + ): Promise; + + // ─── REQUEST/RESPONSE HOOKS ─── + + /** + * Called before a request is processed. + * Can modify the request or return void to pass through unchanged. + */ + onRequest?(request: JSONRPCRequest, ctx: RequestContext): JSONRPCRequest | void | Promise; + + /** + * Called after a request is successfully processed. + * Can modify the result or return void to pass through unchanged. + */ + onRequestResult?(request: JSONRPCRequest, result: Result, ctx: RequestContext): Result | void | Promise; + + /** + * Called when a request handler throws an error. + * Can modify the error or return void to pass through unchanged. + */ + onRequestError?(request: JSONRPCRequest, error: Error, ctx: RequestContext): Error | void | Promise; + + /** + * Called when a response is received (for outgoing requests). + * Plugins can use this to manage progress handlers or other state. + * @param response - The response received + * @param messageId - The message ID (progress token) for this request + */ + onResponse?(response: JSONRPCResponse | JSONRPCErrorResponse, messageId: number): void | Promise; + + // ─── NOTIFICATION HOOKS ─── + + /** + * Called before a notification is processed. + * Can modify the notification or return void to pass through unchanged. + */ + onNotification?(notification: JSONRPCNotification): JSONRPCNotification | void | Promise; + + // ─── OUTGOING MESSAGE HOOKS ─── + + /** + * Called before sending an outgoing request. + * Plugins can augment request params (e.g., add task metadata) or register response resolvers. + * @param request - The request being sent (can be mutated) + * @param options - The request options (can be mutated) + * @returns Modified request, or void to use original + */ + onBeforeSendRequest?(request: JSONRPCRequest, options: OutgoingRequestContext): JSONRPCRequest | void | Promise; + + /** + * Called before sending an outgoing notification. + * Plugins can augment notification params (e.g., add task metadata). + * @param notification - The notification being sent (can be mutated) + * @param options - The notification options (can be mutated) + * @returns Modified notification, or void to use original + */ + onBeforeSendNotification?( + notification: JSONRPCNotification, + options: OutgoingNotificationContext + ): JSONRPCNotification | void | Promise; + + // ─── HANDLER CONTEXT HOOKS ─── + + /** + * Called when building context for an incoming request handler. + * Plugins can contribute additional context (e.g., task context). + * @param request - The incoming request + * @param baseContext - Base context with session info + * @returns Additional context fields to merge, or void + */ + onBuildHandlerContext?( + request: JSONRPCRequest, + baseContext: HandlerContextBase + ): Record | void | Promise | void>; +} + +/** + * Context passed to onBeforeSendRequest hook. + */ +export interface OutgoingRequestContext { + /** Message ID for this request */ + readonly messageId: number; + /** Session ID if available */ + readonly sessionId?: string; + /** Original request options (plugins can read task, relatedTask, etc.) */ + readonly requestOptions?: Record; + /** Register a resolver to handle the response */ + registerResolver(resolver: (response: JSONRPCResultResponse | Error) => void): void; +} + +/** + * Context passed to onBeforeSendNotification hook. + */ +export interface OutgoingNotificationContext { + /** Session ID if available */ + readonly sessionId?: string; + /** Related request ID if this notification is in response to a request */ + readonly relatedRequestId?: RequestId; + /** Original notification options (plugins can read relatedTask, etc.) */ + readonly notificationOptions?: Record; +} + +/** + * Base context passed to onBuildHandlerContext hook. + */ +export interface HandlerContextBase { + /** Session ID if available */ + readonly sessionId?: string; + /** The incoming request */ + readonly request: JSONRPCRequest; +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Base Plugin Class +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Abstract base class for plugins. + * Provides default no-op implementations for all hooks. + * Plugins only need to override the methods they care about. + */ +export abstract class BasePlugin implements ProtocolPlugin { + abstract readonly name: string; + readonly priority?: number; + + // Default no-op implementations + install?(_ctx: PluginContext): void | Promise { + // Override in subclass + } + + onConnect?(_transport: Transport): void | Promise { + // Override in subclass + } + + onClose?(): void | Promise { + // Override in subclass + } + + shouldRouteMessage?( + _message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + _options?: TransportSendOptions + ): boolean { + return false; + } + + routeMessage?( + _message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + _options?: TransportSendOptions + ): Promise { + return Promise.resolve(); + } + + onRequest?(_request: JSONRPCRequest, _ctx: RequestContext): JSONRPCRequest | void | Promise { + // Override in subclass + } + + onRequestResult?(_request: JSONRPCRequest, _result: Result, _ctx: RequestContext): Result | void | Promise { + // Override in subclass + } + + onRequestError?(_request: JSONRPCRequest, _error: Error, _ctx: RequestContext): Error | void | Promise { + // Override in subclass + } + + onResponse?(_response: JSONRPCResponse | JSONRPCErrorResponse, _messageId: number): void | Promise { + // Override in subclass + } + + onNotification?(_notification: JSONRPCNotification): JSONRPCNotification | void | Promise { + // Override in subclass + } + + onBeforeSendRequest?( + _request: JSONRPCRequest, + _options: OutgoingRequestContext + ): JSONRPCRequest | void | Promise { + // Override in subclass + } + + onBeforeSendNotification?( + _notification: JSONRPCNotification, + _options: OutgoingNotificationContext + ): JSONRPCNotification | void | Promise { + // Override in subclass + } + + onBuildHandlerContext?( + _request: JSONRPCRequest, + _baseContext: HandlerContextBase + ): Record | void | Promise | void> { + // Override in subclass + } +} + +/** + * Helper function to sort plugins by priority (higher priority first) + */ +export function sortPluginsByPriority

(plugins: P[]): P[] { + return plugins.toSorted((a, b) => (b.priority ?? 0) - (a.priority ?? 0)); +} diff --git a/packages/core/src/shared/pluginContext.ts b/packages/core/src/shared/pluginContext.ts new file mode 100644 index 000000000..e6f04d3ca --- /dev/null +++ b/packages/core/src/shared/pluginContext.ts @@ -0,0 +1,192 @@ +/** + * Plugin Context Implementation + * + * This module provides the concrete implementations of the plugin context interfaces. + * These are internal to the SDK and are created by Protocol for plugin installation. + */ + +import type { + JSONRPCErrorResponse, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCResultResponse, + RequestId, + Result +} from '../types/types.js'; +import type { AnyObjectSchema, SchemaOutput } from '../util/zodCompat.js'; +import type { + PluginContext, + PluginHandlerExtra, + PluginHandlersInterface, + PluginRequestOptions, + PluginRequestsInterface, + PluginResolversInterface, + PluginTransportInterface +} from './plugin.js'; +import type { ProgressManagerInterface } from './progressManager.js'; +import type { Transport, TransportSendOptions } from './transport.js'; + +/** + * Protocol interface for plugin context creation. + * This avoids circular dependency with Protocol. + */ +export interface PluginHostProtocol { + readonly transport?: Transport; + request(request: JSONRPCRequest, resultSchema: T, options?: PluginRequestOptions): Promise>; + setRequestHandler( + schema: T, + handler: ( + request: SchemaOutput, + ctx: { mcpCtx: { requestId: RequestId; sessionId?: string }; requestCtx: { signal: AbortSignal } } + ) => SendResultT | Promise + ): void; + setNotificationHandler(schema: T, handler: (notification: SchemaOutput) => void | Promise): void; + removeRequestHandler(method: string): void; + removeNotificationHandler(method: string): void; +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Transport Access Implementation +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Implementation of PluginTransportInterface. + * Provides transport-related operations to plugins. + */ +export class PluginTransport implements PluginTransportInterface { + constructor(private readonly getTransportFn: () => Transport | undefined) {} + + getTransport(): Transport | undefined { + return this.getTransportFn(); + } + + getSessionId(): string | undefined { + return this.getTransportFn()?.sessionId; + } + + async send( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + options?: TransportSendOptions + ): Promise { + await this.getTransportFn()?.send(message, options); + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Requests Access Implementation +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Implementation of PluginRequestsInterface. + * Allows plugins to make outbound requests. + */ +export class PluginRequests implements PluginRequestsInterface { + constructor(private readonly protocol: PluginHostProtocol) {} + + async sendRequest( + request: JSONRPCRequest, + resultSchema: T, + options?: PluginRequestOptions + ): Promise> { + return this.protocol.request(request, resultSchema, options); + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Handler Registry Implementation +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Implementation of PluginHandlersInterface. + * Allows plugins to register request and notification handlers. + */ +export class PluginHandlers implements PluginHandlersInterface { + constructor(private readonly protocol: PluginHostProtocol) {} + + setRequestHandler( + schema: T, + handler: (request: SchemaOutput, extra: PluginHandlerExtra) => SendResultT | Promise + ): void { + this.protocol.setRequestHandler(schema, (parsedRequest, ctx) => { + const pluginExtra: PluginHandlerExtra = { + mcpCtx: { + requestId: ctx.mcpCtx.requestId, + sessionId: ctx.mcpCtx.sessionId + }, + requestCtx: { + signal: ctx.requestCtx.signal + } + }; + return handler(parsedRequest, pluginExtra); + }); + } + + setNotificationHandler(schema: T, handler: (notification: SchemaOutput) => void | Promise): void { + this.protocol.setNotificationHandler(schema, handler); + } + + removeRequestHandler(method: string): void { + this.protocol.removeRequestHandler(method); + } + + removeNotificationHandler(method: string): void { + this.protocol.removeNotificationHandler(method); + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Request Resolver Implementation +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Implementation of PluginResolversInterface. + * Manages request resolvers for routing queued responses. + */ +export class PluginResolvers implements PluginResolversInterface { + constructor(private readonly resolvers: Map void>) {} + + register(id: RequestId, resolver: (response: JSONRPCResultResponse | Error) => void): void { + this.resolvers.set(id, resolver); + } + + get(id: RequestId): ((response: JSONRPCResultResponse | Error) => void) | undefined { + return this.resolvers.get(id); + } + + remove(id: RequestId): void { + this.resolvers.delete(id); + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Factory Function +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Configuration for creating a PluginContext. + */ +export interface PluginContextConfig { + protocol: PluginHostProtocol; + getTransport: () => Transport | undefined; + resolvers: Map void>; + progressManager: ProgressManagerInterface; + reportError: (error: Error) => void; +} + +/** + * Creates a PluginContext from the given configuration. + * This is called once by Protocol and cached for reuse. + */ +export function createPluginContext( + config: PluginContextConfig +): PluginContext { + return { + transport: new PluginTransport(config.getTransport), + requests: new PluginRequests(config.protocol), + handlers: new PluginHandlers(config.protocol), + resolvers: new PluginResolvers(config.resolvers), + progress: config.progressManager, + reportError: config.reportError + }; +} diff --git a/packages/core/src/shared/progressManager.ts b/packages/core/src/shared/progressManager.ts new file mode 100644 index 000000000..a20070cc2 --- /dev/null +++ b/packages/core/src/shared/progressManager.ts @@ -0,0 +1,126 @@ +/** + * Progress Manager + * + * Manages progress tracking for the Protocol class. + * Extracted from Protocol to follow Single Responsibility Principle. + */ + +import type { Progress, ProgressNotification } from '../types/types.js'; + +/** + * Callback for progress notifications. + */ +export type ProgressCallback = (progress: Progress) => void; + +/** + * Interface for progress management. + * Plugins use this interface to register and manage progress handlers. + */ +export interface ProgressManagerInterface { + /** + * Registers a progress callback for a message. + * @param messageId - The message ID (used as progress token) + * @param callback - The callback to invoke when progress is received + */ + registerHandler(messageId: number, callback: ProgressCallback): void; + + /** + * Gets the progress callback for a message. + * @param messageId - The message ID + * @returns The progress callback or undefined + */ + getHandler(messageId: number): ProgressCallback | undefined; + + /** + * Removes the progress callback for a message. + * @param messageId - The message ID + */ + removeHandler(messageId: number): void; + + /** + * Handles an incoming progress notification. + * @param notification - The progress notification + * @returns true if handled, false if no handler was found + */ + handleProgress(notification: ProgressNotification): boolean; +} + +/** + * Manages progress tracking for requests. + */ +export class ProgressManager implements ProgressManagerInterface { + /** + * Maps message IDs to progress callbacks + */ + private _progressHandlers: Map = new Map(); + + /** + * Registers a progress callback for a message. + * + * @param messageId - The message ID (used as progress token) + * @param callback - The callback to invoke when progress is received + */ + registerHandler(messageId: number, callback: ProgressCallback): void { + this._progressHandlers.set(messageId, callback); + } + + /** + * Gets the progress callback for a message. + * + * @param messageId - The message ID + * @returns The progress callback or undefined + */ + getHandler(messageId: number): ProgressCallback | undefined { + return this._progressHandlers.get(messageId); + } + + /** + * Removes the progress callback for a message. + * + * @param messageId - The message ID + */ + removeHandler(messageId: number): void { + this._progressHandlers.delete(messageId); + } + + /** + * Handles an incoming progress notification. + * Returns true if the progress was handled, false if no handler was found. + * + * @param notification - The progress notification + * @returns true if handled, false otherwise + */ + handleProgress(notification: ProgressNotification): boolean { + const token = notification.params.progressToken; + if (typeof token !== 'number') { + // Token must be a number for our internal tracking + return false; + } + + const callback = this._progressHandlers.get(token); + if (callback) { + callback({ + progress: notification.params.progress, + total: notification.params.total, + message: notification.params.message + }); + return true; + } + + return false; + } + + /** + * Clears all progress handlers. + */ + clear(): void { + this._progressHandlers.clear(); + } + + /** + * Gets the number of active progress handlers. + */ + get handlerCount(): number { + return this._progressHandlers.size; + } +} diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index def841832..5401c9d8f 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -1,12 +1,7 @@ -import type { CreateTaskOptions, QueuedMessage, TaskMessageQueue, TaskStore } from '../experimental/tasks/interfaces.js'; -import { isTerminal } from '../experimental/tasks/interfaces.js'; +import { isProtocolError, ProtocolError, StateError, TransportError } from '../errors.js'; import type { - AuthInfo, CancelledNotification, ClientCapabilities, - GetTaskPayloadRequest, - GetTaskRequest, - GetTaskResult, JSONRPCErrorResponse, JSONRPCNotification, JSONRPCRequest, @@ -14,52 +9,46 @@ import type { JSONRPCResultResponse, MessageExtraInfo, Notification, - Progress, ProgressNotification, - RelatedTaskMetadata, Request, RequestId, - RequestInfo, - RequestMeta, Result, - ServerCapabilities, - Task, - TaskCreationParams, - TaskStatusNotification + ServerCapabilities } from '../types/types.js'; import { CancelledNotificationSchema, - CancelTaskRequestSchema, - CancelTaskResultSchema, - CreateTaskResultSchema, ErrorCode, - GetTaskPayloadRequestSchema, - GetTaskRequestSchema, - GetTaskResultSchema, isJSONRPCErrorResponse, isJSONRPCNotification, isJSONRPCRequest, isJSONRPCResultResponse, - isTaskAugmentedRequestParams, - ListTasksRequestSchema, - ListTasksResultSchema, - McpError, PingRequestSchema, - ProgressNotificationSchema, - RELATED_TASK_META_KEY, - TaskStatusNotificationSchema + ProgressNotificationSchema } from '../types/types.js'; import type { AnyObjectSchema, AnySchema, SchemaOutput } from '../util/zodCompat.js'; import { safeParse } from '../util/zodCompat.js'; import { getMethodLiteral, parseWithCompat } from '../util/zodJsonSchemaCompat.js'; +import type { BaseRequestContext, ContextInterface } from './context.js'; +import type { McpEventEmitter } from './events.js'; +import { TypedEventEmitter } from './events.js'; +import { HandlerRegistry } from './handlerRegistry.js'; +import type { + HandlerContextBase, + OutgoingNotificationContext, + OutgoingRequestContext, + PluginContext, + PluginRequestOptions, + ProtocolPlugin, + RequestContext +} from './plugin.js'; +import { sortPluginsByPriority } from './plugin.js'; +import { createPluginContext } from './pluginContext.js'; +import type { ProgressCallback } from './progressManager.js'; +import { ProgressManager } from './progressManager.js'; import type { ResponseMessage } from './responseMessage.js'; +import { TimeoutManager } from './timeoutManager.js'; import type { Transport, TransportSendOptions } from './transport.js'; -/** - * Callback for progress notifications. - */ -export type ProgressCallback = (progress: Progress) => void; - /** * Additional initialization options. */ @@ -79,29 +68,6 @@ export type ProtocolOptions = { * e.g., ['notifications/tools/list_changed'] */ debouncedNotificationMethods?: string[]; - /** - * Optional task storage implementation. If provided, enables task-related request handlers - * and provides task storage capabilities to request handlers. - */ - taskStore?: TaskStore; - /** - * Optional task message queue implementation for managing server-initiated messages - * that will be delivered through the tasks/result response stream. - */ - taskMessageQueue?: TaskMessageQueue; - /** - * Default polling interval (in milliseconds) for task status checks when no pollInterval - * is provided by the server. Defaults to 5000ms if not specified. - */ - defaultTaskPollInterval?: number; - /** - * Maximum number of messages that can be queued per task for side-channel delivery. - * If undefined, the queue size is unbounded. - * When the limit is exceeded, the TaskMessageQueue implementation's enqueue() method - * will throw an error. It's the implementation's responsibility to handle overflow - * appropriately (e.g., by failing the task, dropping messages, etc.). - */ - maxTaskQueueSize?: number; }; /** @@ -111,12 +77,30 @@ export const DEFAULT_REQUEST_TIMEOUT_MSEC = 60_000; /** * Options that can be given per request. + * + * ## Plugin Extension Pattern + * + * Plugins can define their own typed options by creating intersection types. + * For type safety at call sites, use the plugin-specific type with `satisfies`: + * + * @example + * ```typescript + * import type { TaskRequestOptions } from '@modelcontextprotocol/core'; + * + * // Type-safe task options + * await ctx.sendRequest(req, schema, { + * task: { ttl: 60000 }, + * relatedTask: { taskId: 'parent-123' } + * } satisfies TaskRequestOptions); + * ``` + * + * The index signature allows plugins to read their options from the + * `requestOptions` field in their `onBeforeSendRequest` hooks. */ export type RequestOptions = { /** - * If set, requests progress notifications from the remote end (if supported). When progress notifications are received, this callback will be invoked. - * - * For task-augmented requests: progress notifications continue after CreateTaskResult is returned and stop automatically when the task reaches a terminal status. + * If set, requests progress notifications from the remote end (if supported). + * When progress notifications are received, this callback will be invoked. */ onprogress?: ProgressCallback; @@ -126,7 +110,7 @@ export type RequestOptions = { signal?: AbortSignal; /** - * A timeout (in milliseconds) for this request. If exceeded, an McpError with code `RequestTimeout` will be raised from request(). + * A timeout (in milliseconds) for this request. If exceeded, a TransportError will be raised from request(). * * If not specified, `DEFAULT_REQUEST_TIMEOUT_MSEC` will be used as the timeout. */ @@ -141,24 +125,35 @@ export type RequestOptions = { /** * Maximum total time (in milliseconds) to wait for a response. - * If exceeded, an McpError with code `RequestTimeout` will be raised, regardless of progress notifications. + * If exceeded, a TransportError will be raised, regardless of progress notifications. * If not specified, there is no maximum total timeout. */ maxTotalTimeout?: number; - /** - * If provided, augments the request with task creation parameters to enable call-now, fetch-later execution patterns. - */ - task?: TaskCreationParams; - - /** - * If provided, associates this request with a related task. - */ - relatedTask?: RelatedTaskMetadata; + /** Allow plugin-specific options via index signature */ + [key: string]: unknown; } & TransportSendOptions; /** * Options that can be given per notification. + * + * ## Plugin Extension Pattern + * + * Plugins can define their own typed options by creating intersection types. + * For type safety at call sites, use the plugin-specific type with `satisfies`: + * + * @example + * ```typescript + * import type { TaskNotificationOptions } from '@modelcontextprotocol/core'; + * + * // Type-safe task options + * await ctx.sendNotification(notification, { + * relatedTask: { taskId: 'parent-123' } + * } satisfies TaskNotificationOptions); + * ``` + * + * The index signature allows plugins to read their options from the + * `notificationOptions` field in their `onBeforeSendNotification` hooks. */ export type NotificationOptions = { /** @@ -166,156 +161,112 @@ export type NotificationOptions = { */ relatedRequestId?: RequestId; - /** - * If provided, associates this notification with a related task. - */ - relatedTask?: RelatedTaskMetadata; + /** Allow plugin-specific options via index signature */ + [key: string]: unknown; }; -/** - * Options that can be given per request. - */ -// relatedTask is excluded as the SDK controls if this is sent according to if the source is a task. -export type TaskRequestOptions = Omit; +// ═══════════════════════════════════════════════════════════════════════════ +// Error Interception +// ═══════════════════════════════════════════════════════════════════════════ /** - * Request-scoped TaskStore interface. + * Context provided to error interceptors. */ -export interface RequestTaskStore { +export interface ErrorInterceptionContext { /** - * Creates a new task with the given creation parameters. - * The implementation generates a unique taskId and createdAt timestamp. - * - * @param taskParams - The task creation parameters from the request - * @returns The created task object + * The type of error: + * - 'protocol': Protocol-level errors (method not found, parse error, etc.) + * - 'application': Application errors (handler threw an exception) */ - createTask(taskParams: CreateTaskOptions): Promise; + type: 'protocol' | 'application'; /** - * Gets the current status of a task. - * - * @param taskId - The task identifier - * @returns The task object - * @throws If the task does not exist + * The method that was being called when the error occurred. */ - getTask(taskId: string): Promise; + method: string; /** - * Stores the result of a task and sets its final status. - * - * @param taskId - The task identifier - * @param status - The final status: 'completed' for success, 'failed' for errors - * @param result - The result to store + * The request ID from the JSON-RPC message. */ - storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result): Promise; - - /** - * Retrieves the stored result of a task. - * - * @param taskId - The task identifier - * @returns The stored result - */ - getTaskResult(taskId: string): Promise; - - /** - * Updates a task's status (e.g., to 'cancelled', 'failed', 'completed'). - * - * @param taskId - The task identifier - * @param status - The new status - * @param statusMessage - Optional diagnostic message for failed tasks or other status information - */ - updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string): Promise; + requestId: RequestId; /** - * Lists tasks, optionally starting from a pagination cursor. - * - * @param cursor - Optional cursor for pagination - * @returns An object containing the tasks array and an optional nextCursor + * For protocol errors, the fixed error code that cannot be changed. + * For application errors, the error code that will be used (can be modified via returned Error). */ - listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; + errorCode: number; } /** - * Extra data given to request handlers. + * Result from an error interceptor that can modify the error response. */ -export type RequestHandlerExtra = { - /** - * An abort signal used to communicate if the request was cancelled from the sender's side. - */ - signal: AbortSignal; - - /** - * Information about a validated access token, provided to request handlers. - */ - authInfo?: AuthInfo; - +export interface ErrorInterceptionResult { /** - * The session ID from the transport, if available. + * Override the error message. If not provided, the original error message is used. */ - sessionId?: string; + message?: string; /** - * Metadata from the original request. + * Additional data to include in the error response. */ - _meta?: RequestMeta; + data?: unknown; /** - * The JSON-RPC ID of the request being handled. - * This can be useful for tracking or logging purposes. + * For application errors only: override the error code. + * Ignored for protocol errors (they have fixed codes per MCP spec). */ - requestId: RequestId; - - taskId?: string; + code?: number; +} - taskStore?: RequestTaskStore; +/** + * Error interceptor function type. + * Called before sending error responses, allows customizing the error. + * + * @param error - The original error + * @param context - Context about where the error occurred + * @returns Optional modifications to the error response, or void to use defaults + */ +export type ErrorInterceptor = ( + error: Error, + context: ErrorInterceptionContext +) => ErrorInterceptionResult | void | Promise; - taskRequestedTtl?: number | null; +// ═══════════════════════════════════════════════════════════════════════════ +// Protocol Events +// ═══════════════════════════════════════════════════════════════════════════ - /** - * The original HTTP request. - */ - requestInfo?: RequestInfo; +/** + * Events emitted by the Protocol class. + * + * @example + * ```typescript + * const unsubscribe = protocol.events.on('connection:opened', ({ sessionId }) => { + * console.log(`Connected with session: ${sessionId}`); + * }); + * + * protocol.events.on('error', ({ error, context }) => { + * console.error(`Protocol error in ${context}:`, error); + * }); + * ``` + */ +export interface ProtocolEvents { + [key: string]: unknown; /** - * Sends a notification that relates to the current request being handled. - * - * This is used by certain transports to correctly associate related messages. + * Emitted when a connection is successfully established. */ - sendNotification: (notification: SendNotificationT) => Promise; + 'connection:opened': { sessionId?: string }; /** - * Sends a request that relates to the current request being handled. - * - * This is used by certain transports to correctly associate related messages. + * Emitted when the connection is closed. */ - sendRequest: (request: SendRequestT, resultSchema: U, options?: TaskRequestOptions) => Promise>; + 'connection:closed': { sessionId?: string; reason?: string }; /** - * Closes the SSE stream for this request, triggering client reconnection. - * Only available when using a StreamableHTTPServerTransport with eventStore configured. - * Use this to implement polling behavior during long-running operations. + * Emitted when an error occurs during protocol operations. */ - closeSSEStream?: () => void; - - /** - * Closes the standalone GET SSE stream, triggering client reconnection. - * Only available when using a StreamableHTTPServerTransport with eventStore configured. - * Use this to implement polling behavior for server-initiated notifications. - */ - closeStandaloneSSEStream?: () => void; -}; - -/** - * Information about a request's timeout state - */ -type TimeoutInfo = { - timeoutId: ReturnType; - startTime: number; - timeout: number; - maxTotalTimeout?: number; - resetTimeoutOnProgress: boolean; - onTimeout: () => void; -}; + error: { error: Error; context?: string }; +} /** * Implements MCP protocol framing on top of a pluggable transport, including @@ -324,25 +275,25 @@ type TimeoutInfo = { export abstract class Protocol { private _transport?: Transport; private _requestMessageId = 0; - private _requestHandlers: Map< - string, - (request: JSONRPCRequest, extra: RequestHandlerExtra) => Promise - > = new Map(); - private _requestHandlerAbortControllers: Map = new Map(); - private _notificationHandlers: Map Promise> = new Map(); private _responseHandlers: Map void> = new Map(); - private _progressHandlers: Map = new Map(); - private _timeoutInfo: Map = new Map(); private _pendingDebouncedNotifications = new Set(); - // Maps task IDs to progress tokens to keep handlers alive after CreateTaskResult - private _taskProgressTokens: Map = new Map(); + // Extracted managers + private _timeoutManager = new TimeoutManager(); + private _progressManager = new ProgressManager(); + private _handlerRegistry = new HandlerRegistry(); - private _taskStore?: TaskStore; - private _taskMessageQueue?: TaskMessageQueue; + // Plugin system + private _plugins: ProtocolPlugin[] = []; private _requestResolvers: Map void> = new Map(); + // Event emitter for observability + private _events = new TypedEventEmitter(); + + // Error interception callback + private _errorInterceptor?: ErrorInterceptor; + /** * Callback for when the connection is closed for any reason. * @@ -357,15 +308,80 @@ export abstract class Protocol void; + /** + * Event emitter for observability and monitoring. + * + * Subscribe to events like connection lifecycle, errors, etc. + * + * @example + * ```typescript + * protocol.events.on('connection:opened', ({ sessionId }) => { + * console.log(`Connected: ${sessionId}`); + * }); + * + * protocol.events.on('error', ({ error }) => { + * console.error('Protocol error:', error); + * }); + * ``` + */ + get events(): McpEventEmitter { + return this._events; + } + + /** + * Sets an error interceptor that can customize error responses before they are sent. + * + * The interceptor is called for both protocol errors (method not found, etc.) and + * application errors (when a handler throws). It can modify the error message and data, + * and for application errors, can also change the error code. + * + * @param interceptor - The error interceptor function, or undefined to clear + * + * @example + * ```typescript + * server.setErrorInterceptor(async (error, ctx) => { + * console.error(`Error in ${ctx.method}: ${error.message}`); + * return { + * message: 'An error occurred', + * data: { originalMessage: error.message } + * }; + * }); + * ``` + */ + protected setErrorInterceptor(interceptor: ErrorInterceptor | undefined): void { + this._errorInterceptor = interceptor; + } + /** * A handler to invoke for any request types that do not have their own handler installed. */ - fallbackRequestHandler?: (request: JSONRPCRequest, extra: RequestHandlerExtra) => Promise; + get fallbackRequestHandler(): + | ((request: JSONRPCRequest, extra: ContextInterface) => Promise) + | undefined { + return this._handlerRegistry.fallbackRequestHandler; + } + + set fallbackRequestHandler( + handler: + | (( + request: JSONRPCRequest, + extra: ContextInterface + ) => Promise) + | undefined + ) { + this._handlerRegistry.fallbackRequestHandler = handler; + } /** * A handler to invoke for any notification types that do not have their own handler installed. */ - fallbackNotificationHandler?: (notification: Notification) => Promise; + get fallbackNotificationHandler(): ((notification: Notification) => Promise) | undefined { + return this._handlerRegistry.fallbackNotificationHandler; + } + + set fallbackNotificationHandler(handler: ((notification: Notification) => Promise) | undefined) { + this._handlerRegistry.fallbackNotificationHandler = handler; + } constructor(private _options?: ProtocolOptions) { this.setNotificationHandler(CancelledNotificationSchema, notification => { @@ -381,182 +397,258 @@ export abstract class Protocol ({}) as SendResultT ); + } - // Install task handlers if TaskStore is provided - this._taskStore = _options?.taskStore; - this._taskMessageQueue = _options?.taskMessageQueue; - if (this._taskStore) { - this.setRequestHandler(GetTaskRequestSchema, async (request, extra) => { - const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId); - if (!task) { - throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); - } - - // Per spec: tasks/get responses SHALL NOT include related-task metadata - // as the taskId parameter is the source of truth - // @ts-expect-error SendResultT cannot contain GetTaskResult, but we include it in our derived types everywhere else - return { - ...task - } as SendResultT; - }); + // ═══════════════════════════════════════════════════════════════════════════ + // Plugin System + // ═══════════════════════════════════════════════════════════════════════════ - this.setRequestHandler(GetTaskPayloadRequestSchema, async (request, extra) => { - const handleTaskResult = async (): Promise => { - const taskId = request.params.taskId; - - // Deliver queued messages - if (this._taskMessageQueue) { - let queuedMessage: QueuedMessage | undefined; - while ((queuedMessage = await this._taskMessageQueue.dequeue(taskId, extra.sessionId))) { - // Handle response and error messages by routing them to the appropriate resolver - if (queuedMessage.type === 'response' || queuedMessage.type === 'error') { - const message = queuedMessage.message; - const requestId = message.id; - - // Lookup resolver in _requestResolvers map - const resolver = this._requestResolvers.get(requestId as RequestId); - - if (resolver) { - // Remove resolver from map after invocation - this._requestResolvers.delete(requestId as RequestId); - - // Invoke resolver with response or error - if (queuedMessage.type === 'response') { - resolver(message as JSONRPCResultResponse); - } else { - // Convert JSONRPCError to McpError - const errorMessage = message as JSONRPCErrorResponse; - const error = new McpError( - errorMessage.error.code, - errorMessage.error.message, - errorMessage.error.data - ); - resolver(error); - } - } else { - // Handle missing resolver gracefully with error logging - const messageType = queuedMessage.type === 'response' ? 'Response' : 'Error'; - this._onerror(new Error(`${messageType} handler missing for request ${requestId}`)); - } - - // Continue to next message - continue; - } + /** + * Registers a plugin with the protocol. + * Plugins are installed immediately and sorted by priority. + * + * @param plugin - The plugin to register + * @returns this for chaining + */ + usePlugin(plugin: ProtocolPlugin): this { + this._plugins.push(plugin); + this._plugins = sortPluginsByPriority(this._plugins); - // Send the message on the response stream by passing the relatedRequestId - // This tells the transport to write the message to the tasks/result response stream - await this._transport?.send(queuedMessage.message, { relatedRequestId: extra.requestId }); - } - } + // Install the plugin immediately + const ctx = this._getPluginContext(); + plugin.install?.(ctx); - // Now check task status - const task = await this._taskStore!.getTask(taskId, extra.sessionId); - if (!task) { - throw new McpError(ErrorCode.InvalidParams, `Task not found: ${taskId}`); - } + return this; + } - // Block if task is not terminal (we've already delivered all queued messages above) - if (!isTerminal(task.status)) { - // Wait for status change or new messages - await this._waitForTaskUpdate(taskId, extra.signal); + /** + * Retrieves a registered plugin by its class. + * Returns undefined if the plugin is not registered. + * + * @param PluginClass - The plugin class to find + * @returns The plugin instance or undefined + * + * @example + * ```typescript + * const taskPlugin = server.getPlugin(TaskPlugin); + * if (taskPlugin) { + * // Access plugin-specific methods + * } + * ``` + */ + getPlugin>(PluginClass: abstract new (...args: unknown[]) => T): T | undefined { + return this._plugins.find((p): p is T => p instanceof PluginClass); + } - // After waking up, recursively call to deliver any new messages or result - return await handleTaskResult(); - } + /** + * Cached plugin context, created once and reused for all plugins. + */ + private _pluginContext?: PluginContext; - // If task is terminal, return the result - if (isTerminal(task.status)) { - const result = await this._taskStore!.getTaskResult(taskId, extra.sessionId); + /** + * Gets or creates the plugin context for plugin installation. + * The context is created once and cached for reuse. + */ + private _getPluginContext(): PluginContext { + if (!this._pluginContext) { + this._pluginContext = createPluginContext({ + protocol: this._createPluginHostProtocol(), + getTransport: () => this._transport, + resolvers: this._requestResolvers, + progressManager: this._progressManager, + reportError: error => this._onerror(error, 'plugin') + }); + } + return this._pluginContext; + } - this._clearTaskQueue(taskId); + /** + * Creates the protocol interface for plugin context. + * This provides a typed view of Protocol for the plugin system. + */ + private _createPluginHostProtocol() { + return { + transport: this._transport, + request: (request: JSONRPCRequest, resultSchema: T, options?: PluginRequestOptions) => + this.request(request as SendRequestT, resultSchema, options), + setRequestHandler: ( + schema: T, + handler: ( + request: SchemaOutput, + ctx: { mcpCtx: { requestId: RequestId; sessionId?: string }; requestCtx: { signal: AbortSignal } } + ) => SendResultT | Promise + ) => this.setRequestHandler(schema, handler), + setNotificationHandler: ( + schema: T, + handler: (notification: SchemaOutput) => void | Promise + ) => this.setNotificationHandler(schema, handler), + removeRequestHandler: (method: string) => this.removeRequestHandler(method), + removeNotificationHandler: (method: string) => this.removeNotificationHandler(method) + }; + } - return { - ...result, - _meta: { - ...result._meta, - [RELATED_TASK_META_KEY]: { - taskId: taskId - } - } - } as SendResultT; - } + /** + * Calls onConnect on all plugins. + */ + private async _notifyPluginsConnect(transport: Transport): Promise { + for (const plugin of this._plugins) { + await plugin.onConnect?.(transport); + } + } - return await handleTaskResult(); - }; + /** + * Calls onClose on all plugins. + */ + private async _notifyPluginsClose(): Promise { + for (const plugin of this._plugins) { + await plugin.onClose?.(); + } + } - return await handleTaskResult(); - }); + /** + * Checks if any plugin wants to route a message instead of the default transport. + */ + private _findMessageRouter( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + options?: TransportSendOptions + ): ProtocolPlugin | undefined { + return this._plugins.find(p => p.shouldRouteMessage?.(message, options)); + } - this.setRequestHandler(ListTasksRequestSchema, async (request, extra) => { - try { - const { tasks, nextCursor } = await this._taskStore!.listTasks(request.params?.cursor, extra.sessionId); - // @ts-expect-error SendResultT cannot contain ListTasksResult, but we include it in our derived types everywhere else - return { - tasks, - nextCursor, - _meta: {} - } as SendResultT; - } catch (error) { - throw new McpError( - ErrorCode.InvalidParams, - `Failed to list tasks: ${error instanceof Error ? error.message : String(error)}` - ); - } - }); + /** + * Calls onRequest on all plugins, allowing them to modify the request. + */ + private async _runPluginOnRequest(request: JSONRPCRequest, ctx: RequestContext): Promise { + let current = request; + for (const plugin of this._plugins) { + const modified = await plugin.onRequest?.(current, ctx); + if (modified) { + current = modified; + } + } + return current; + } - this.setRequestHandler(CancelTaskRequestSchema, async (request, extra) => { - try { - // Get the current task to check if it's in a terminal state, in case the implementation is not atomic - const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId); + /** + * Calls onRequestResult on all plugins, allowing them to modify the result. + */ + private async _runPluginOnRequestResult(request: JSONRPCRequest, result: Result, ctx: RequestContext): Promise { + let current = result; + for (const plugin of this._plugins) { + const modified = await plugin.onRequestResult?.(request, current, ctx); + if (modified) { + current = modified; + } + } + return current; + } - if (!task) { - throw new McpError(ErrorCode.InvalidParams, `Task not found: ${request.params.taskId}`); - } + /** + * Calls onRequestError on all plugins, allowing them to modify the error. + */ + private async _runPluginOnRequestError(request: JSONRPCRequest, error: Error, ctx: RequestContext): Promise { + let current = error; + for (const plugin of this._plugins) { + const modified = await plugin.onRequestError?.(request, current, ctx); + if (modified) { + current = modified; + } + } + return current; + } - // Reject cancellation of terminal tasks - if (isTerminal(task.status)) { - throw new McpError(ErrorCode.InvalidParams, `Cannot cancel task in terminal status: ${task.status}`); - } + /** + * Calls onNotification on all plugins, allowing them to modify the notification. + */ + private async _runPluginOnNotification(notification: JSONRPCNotification): Promise { + let current = notification; + for (const plugin of this._plugins) { + const modified = await plugin.onNotification?.(current); + if (modified) { + current = modified; + } + } + return current; + } - await this._taskStore!.updateTaskStatus( - request.params.taskId, - 'cancelled', - 'Client cancelled task execution.', - extra.sessionId - ); + /** + * Calls onBeforeSendRequest on all plugins, allowing them to augment the request. + */ + private async _runPluginOnBeforeSendRequest(request: JSONRPCRequest, ctx: OutgoingRequestContext): Promise { + let current = request; + for (const plugin of this._plugins) { + const modified = await plugin.onBeforeSendRequest?.(current, ctx); + if (modified) { + current = modified; + } + } + return current; + } - this._clearTaskQueue(request.params.taskId); + /** + * Calls onBeforeSendNotification on all plugins, allowing them to augment the notification. + */ + private async _runPluginOnBeforeSendNotification( + notification: JSONRPCNotification, + ctx: OutgoingNotificationContext + ): Promise { + let current = notification; + for (const plugin of this._plugins) { + const modified = await plugin.onBeforeSendNotification?.(current, ctx); + if (modified) { + current = modified; + } + } + return current; + } - const cancelledTask = await this._taskStore!.getTask(request.params.taskId, extra.sessionId); - if (!cancelledTask) { - // Task was deleted during cancellation (e.g., cleanup happened) - throw new McpError(ErrorCode.InvalidParams, `Task not found after cancellation: ${request.params.taskId}`); - } + /** + * Calls onBuildHandlerContext on all plugins, merging additional context. + */ + private async _runPluginOnBuildHandlerContext( + request: JSONRPCRequest, + baseContext: HandlerContextBase + ): Promise> { + const additions: Record = {}; + for (const plugin of this._plugins) { + const pluginContext = await plugin.onBuildHandlerContext?.(request, baseContext); + if (pluginContext) { + Object.assign(additions, pluginContext); + } + } + return additions; + } - return { - _meta: {}, - ...cancelledTask - } as unknown as SendResultT; - } catch (error) { - // Re-throw McpError as-is - if (error instanceof McpError) { - throw error; - } - throw new McpError( - ErrorCode.InvalidRequest, - `Failed to cancel task: ${error instanceof Error ? error.message : String(error)}` - ); - } - }); + /** + * Routes a message through plugins or transport. + * Plugins can intercept messages (e.g., for task queueing) via shouldRouteMessage/routeMessage. + */ + private async _routeMessage( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + options?: TransportSendOptions + ): Promise { + // Check if any plugin wants to route this message + for (const plugin of this._plugins) { + if (plugin.shouldRouteMessage?.(message, options)) { + await plugin.routeMessage?.(message, options); + return; + } } + + // No plugin routing - send via transport + await this._transport?.send(message, options); } + // ═══════════════════════════════════════════════════════════════════════════ + // Internal Handlers + // ═══════════════════════════════════════════════════════════════════════════ + private async _oncancel(notification: CancelledNotification): Promise { if (!notification.params.requestId) { return; } // Handle request cancellation - const controller = this._requestHandlerAbortControllers.get(notification.params.requestId); + const controller = this._handlerRegistry.getAbortController(notification.params.requestId); controller?.abort(notification.params.reason); } @@ -567,9 +659,7 @@ export abstract class Protocol void, resetTimeoutOnProgress: boolean = false ) { - this._timeoutInfo.set(messageId, { - timeoutId: setTimeout(onTimeout, timeout), - startTime: Date.now(), + this._timeoutManager.setup(messageId, { timeout, maxTotalTimeout, resetTimeoutOnProgress, @@ -578,29 +668,26 @@ export abstract class Protocol= info.maxTotalTimeout) { - this._timeoutInfo.delete(messageId); - throw McpError.fromError(ErrorCode.RequestTimeout, 'Maximum total timeout exceeded', { - maxTotalTimeout: info.maxTotalTimeout, - totalElapsed - }); + // Check max total timeout before delegating to manager + if (info.maxTotalTimeout) { + const totalElapsed = Date.now() - info.startTime; + if (totalElapsed >= info.maxTotalTimeout) { + this._timeoutManager.cleanup(messageId); + throw TransportError.requestTimeout('Maximum total timeout exceeded', { + maxTotalTimeout: info.maxTotalTimeout, + totalElapsed + }); + } } - clearTimeout(info.timeoutId); - info.timeoutId = setTimeout(info.onTimeout, info.timeout); - return true; + return this._timeoutManager.reset(messageId); } private _cleanupTimeout(messageId: number) { - const info = this._timeoutInfo.get(messageId); - if (info) { - clearTimeout(info.timeoutId); - this._timeoutInfo.delete(messageId); - } + this._timeoutManager.cleanup(messageId); } /** @@ -632,36 +719,96 @@ export abstract class Protocol this._onerror(error_, 'plugin-close')); + for (const handler of responseHandlers.values()) { handler(error); } } - private _onerror(error: Error): void { + private _onerror(error: Error, context?: string): void { this.onerror?.(error); + this._events.emit('error', { error, context }); + } + + /** + * Sends a protocol-level error response (e.g., method not found, parse error). + * Protocol errors have fixed error codes per MCP spec - the interceptor can only + * modify the message and data, not the code. + */ + private _sendProtocolError(request: JSONRPCRequest, errorCode: number, defaultMessage: string, sessionId: string | undefined): void { + const error = new ProtocolError(errorCode, defaultMessage); + + // Call error interceptor if set (async, fire-and-forget for the interception result usage) + Promise.resolve() + .then(async () => { + let message = defaultMessage; + let data: unknown; + + if (this._errorInterceptor) { + const ctx: ErrorInterceptionContext = { + type: 'protocol', + method: request.method, + requestId: request.id, + errorCode + }; + const result = await this._errorInterceptor(error, ctx); + if (result) { + message = result.message ?? message; + data = result.data; + // Note: result.code is ignored for protocol errors (fixed codes per MCP spec) + } + } + + const errorResponse: JSONRPCErrorResponse = { + jsonrpc: '2.0', + id: request.id, + error: { + code: errorCode, + message, + ...(data !== undefined && { data }) + } + }; + + // Route error response through plugins + await this._routeMessage(errorResponse, { sessionId }); + }) + .catch(error_ => this._onerror(new Error(`Failed to send error response: ${error_}`), 'send-error-response')); } private _onnotification(notification: JSONRPCNotification): void { - const handler = this._notificationHandlers.get(notification.method) ?? this.fallbackNotificationHandler; + const handler = this._handlerRegistry.getNotificationHandler(notification.method); // Ignore notifications not being subscribed to. if (handler === undefined) { @@ -670,102 +817,59 @@ export abstract class Protocol handler(notification)) - .catch(error => this._onerror(new Error(`Uncaught error in notification handler: ${error}`))); + .then(async () => { + // Let plugins modify the notification + const modifiedNotification = await this._runPluginOnNotification(notification); + return handler(modifiedNotification); + }) + .catch(error => this._onerror(new Error(`Uncaught error in notification handler: ${error}`), 'notification-handler')); } private _onrequest(request: JSONRPCRequest, extra?: MessageExtraInfo): void { - const handler = this._requestHandlers.get(request.method) ?? this.fallbackRequestHandler; + const handler = this._handlerRegistry.getRequestHandler(request.method); // Capture the current transport at request time to ensure responses go to the correct client const capturedTransport = this._transport; - // Extract taskId from request metadata if present (needed early for method not found case) - const relatedTaskId = request.params?._meta?.[RELATED_TASK_META_KEY]?.taskId; - if (handler === undefined) { - const errorResponse: JSONRPCErrorResponse = { - jsonrpc: '2.0', - id: request.id, - error: { - code: ErrorCode.MethodNotFound, - message: 'Method not found' - } - }; - - // Queue or send the error response based on whether this is a task-related request - if (relatedTaskId && this._taskMessageQueue) { - this._enqueueTaskMessage( - relatedTaskId, - { - type: 'error', - message: errorResponse, - timestamp: Date.now() - }, - capturedTransport?.sessionId - ).catch(error => this._onerror(new Error(`Failed to enqueue error response: ${error}`))); - } else { - capturedTransport - ?.send(errorResponse) - .catch(error => this._onerror(new Error(`Failed to send an error response: ${error}`))); - } + // Handle method not found - this is a protocol error + this._sendProtocolError(request, ErrorCode.MethodNotFound, 'Method not found', capturedTransport?.sessionId); return; } - const abortController = new AbortController(); - this._requestHandlerAbortControllers.set(request.id, abortController); - - const taskCreationParams = isTaskAugmentedRequestParams(request.params) ? request.params.task : undefined; - const taskStore = this._taskStore ? this.requestTaskStore(request, capturedTransport?.sessionId) : undefined; - - const fullExtra: RequestHandlerExtra = { - signal: abortController.signal, - sessionId: capturedTransport?.sessionId, - _meta: request.params?._meta, - sendNotification: async notification => { - // Include related-task metadata if this request is part of a task - const notificationOptions: NotificationOptions = { relatedRequestId: request.id }; - if (relatedTaskId) { - notificationOptions.relatedTask = { taskId: relatedTaskId }; - } - await this.notification(notification, notificationOptions); - }, - sendRequest: async (r, resultSchema, options?) => { - // Include related-task metadata if this request is part of a task - const requestOptions: RequestOptions = { ...options, relatedRequestId: request.id }; - if (relatedTaskId && !requestOptions.relatedTask) { - requestOptions.relatedTask = { taskId: relatedTaskId }; - } + const abortController = this._handlerRegistry.createAbortController(request.id); + const sessionId = capturedTransport?.sessionId; - // Set task status to input_required when sending a request within a task context - // Use the taskId from options (explicit) or fall back to relatedTaskId (inherited) - const effectiveTaskId = requestOptions.relatedTask?.taskId ?? relatedTaskId; - if (effectiveTaskId && taskStore) { - await taskStore.updateTaskStatus(effectiveTaskId, 'input_required'); - } + const baseExtra: ContextInterface = this.createRequestContext({ + request, + abortController, + capturedTransport, + extra + }); - return await this.request(r, resultSchema, requestOptions); - }, - authInfo: extra?.authInfo, + // Build plugin request context + const pluginReqCtx: RequestContext = { requestId: request.id, - requestInfo: extra?.requestInfo, - taskId: relatedTaskId, - taskStore: taskStore, - taskRequestedTtl: taskCreationParams?.ttl, - closeSSEStream: extra?.closeSSEStream, - closeStandaloneSSEStream: extra?.closeStandaloneSSEStream + method: request.method, + sessionId }; // Starting with Promise.resolve() puts any synchronous errors into the monad as well. Promise.resolve() - .then(() => { - // If this request asked for task creation, check capability first - if (taskCreationParams) { - // Check if the request method supports task creation - this.assertTaskHandlerCapability(request.method); + // Let plugins modify the request + .then(() => this._runPluginOnRequest(request, pluginReqCtx)) + .then(async modifiedRequest => { + // Let plugins contribute additional context (e.g., task context) + const additionalContext = await this._runPluginOnBuildHandlerContext(request, { sessionId, request }); + + // Assign additional context properties to the existing context object + // This preserves the prototype chain (instanceof checks work) + if (additionalContext) { + Object.assign(baseExtra, additionalContext); } + + return handler(modifiedRequest, baseExtra); }) - .then(() => handler(request, fullExtra)) .then( async result => { if (abortController.signal.aborted) { @@ -773,24 +877,19 @@ export abstract class Protocol { if (abortController.signal.aborted) { @@ -798,48 +897,127 @@ export abstract class Protocol this._onerror(new Error(`Failed to send response: ${error}`))) + .catch(async error => { + // Last resort: try to send an error response even if something went wrong above + // This prevents the client from hanging indefinitely + try { + const errorCode = isProtocolError(error) ? error.code : ErrorCode.InternalError; + const errorResponse: JSONRPCErrorResponse = { + jsonrpc: '2.0', + id: request.id, + error: { + code: errorCode, + message: error?.message ?? 'Internal error' + } + }; + await capturedTransport?.send(errorResponse); + } catch { + // Truly give up - can't even send error response + } + this._onerror(new Error(`Failed to send response: ${error}`), 'send-response'); + }) .finally(() => { - this._requestHandlerAbortControllers.delete(request.id); + this._handlerRegistry.removeAbortController(request.id); }); } + /** + * Builds the common MCP context from a request. + * This is used by subclass implementations of createRequestContext. + */ + protected buildMcpContext(args: { request: JSONRPCRequest; sessionId: string | undefined }): { + requestId: RequestId; + method: string; + _meta: Record | undefined; + sessionId: string | undefined; + } { + return { + requestId: args.request.id, + method: args.request.method, + _meta: args.request.params?._meta, + sessionId: args.sessionId + }; + } + + /** + * Creates the context object passed to request handlers. + * Subclasses must implement this to provide the appropriate context type + * (ClientContext for Client, ServerContext for Server). + */ + protected abstract createRequestContext(args: { + request: JSONRPCRequest; + abortController: AbortController; + capturedTransport: Transport | undefined; + extra?: MessageExtraInfo; + }): ContextInterface; + private _onprogress(notification: ProgressNotification): void { const { progressToken, ...params } = notification.params; const messageId = Number(progressToken); - const handler = this._progressHandlers.get(messageId); + const handler = this._progressManager.getHandler(messageId); if (!handler) { - this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`)); + this._onerror( + new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`), + 'progress-notification' + ); return; } const responseHandler = this._responseHandlers.get(messageId); - const timeoutInfo = this._timeoutInfo.get(messageId); + const timeoutInfo = this._timeoutManager.get(messageId); if (timeoutInfo && responseHandler && timeoutInfo.resetTimeoutOnProgress) { try { @@ -847,7 +1025,7 @@ export abstract class Protocol; - if (result.task && typeof result.task === 'object') { - const task = result.task as Record; - if (typeof task.taskId === 'string') { - isTaskResponse = true; - this._taskProgressTokens.set(task.taskId, messageId); - } - } - } + // Let plugins process the response (e.g., for task progress management) + // Plugins can inspect the response and manage progress handlers via getProgressManager() + this._runPluginOnOutboundResponse(response, messageId); - if (!isTaskResponse) { - this._progressHandlers.delete(messageId); - } + // Default: remove progress handler + // Plugins that need to keep progress handlers active should re-register them in their onResponse hook + this._progressManager.removeHandler(messageId); if (isJSONRPCResultResponse(response)) { handler(response); } else { - const error = McpError.fromError(response.error.code, response.error.message, response.error.data); + const error = ProtocolError.fromError(response.error.code, response.error.message, response.error.data); handler(error); } } + /** + * Calls onResponse on all plugins for outbound response processing. + */ + private _runPluginOnOutboundResponse(response: JSONRPCResponse | JSONRPCErrorResponse, messageId: number): void { + for (const plugin of this._plugins) { + plugin.onResponse?.(response, messageId); + } + } + get transport(): Transport | undefined { return this._transport; } @@ -939,20 +1117,6 @@ export abstract class Protocol( request: SendRequestT, resultSchema: T, options?: RequestOptions ): AsyncGenerator>, void, void> { - const { task } = options ?? {}; - - // For non-task requests, just yield the result - if (!task) { - try { - const result = await this.request(request, resultSchema, options); - yield { type: 'result', result }; - } catch (error) { - yield { - type: 'error', - error: error instanceof McpError ? error : new McpError(ErrorCode.InternalError, String(error)) - }; - } - return; - } - - // For task-augmented requests, we need to poll for status - // First, make the request to create the task - let taskId: string | undefined; try { - // Send the request and get the CreateTaskResult - const createResult = await this.request(request, CreateTaskResultSchema, options); - - // Extract taskId from the result - if (createResult.task) { - taskId = createResult.task.taskId; - yield { type: 'taskCreated', task: createResult.task }; - } else { - throw new McpError(ErrorCode.InternalError, 'Task creation did not return a task'); - } - - // Poll for task completion - while (true) { - // Get current task status - const task = await this.getTask({ taskId }, options); - yield { type: 'taskStatus', task }; - - // Check if task is terminal - if (isTerminal(task.status)) { - switch (task.status) { - case 'completed': { - // Get the final result - const result = await this.getTaskResult({ taskId }, resultSchema, options); - yield { type: 'result', result }; - - break; - } - case 'failed': { - yield { - type: 'error', - error: new McpError(ErrorCode.InternalError, `Task ${taskId} failed`) - }; - - break; - } - case 'cancelled': { - yield { - type: 'error', - error: new McpError(ErrorCode.InternalError, `Task ${taskId} was cancelled`) - }; - - break; - } - // No default - } - return; - } - - // When input_required, call tasks/result to deliver queued messages - // (elicitation, sampling) via SSE and block until terminal - if (task.status === 'input_required') { - const result = await this.getTaskResult({ taskId }, resultSchema, options); - yield { type: 'result', result }; - return; - } - - // Wait before polling again - const pollInterval = task.pollInterval ?? this._options?.defaultTaskPollInterval ?? 1000; - await new Promise(resolve => setTimeout(resolve, pollInterval)); - - // Check if cancelled - options?.signal?.throwIfAborted(); - } + const result = await this.request(request, resultSchema, options); + yield { type: 'result', result }; } catch (error) { yield { type: 'error', - error: error instanceof McpError ? error : new McpError(ErrorCode.InternalError, String(error)) + error: isProtocolError(error) ? error : new ProtocolError(ErrorCode.InternalError, String(error)) }; } } @@ -1082,7 +1158,7 @@ export abstract class Protocol(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise> { - const { relatedRequestId, resumptionToken, onresumptiontoken, task, relatedTask } = options ?? {}; + const { relatedRequestId, resumptionToken, onresumptiontoken } = options ?? {}; // Send the request return new Promise>((resolve, reject) => { @@ -1098,11 +1174,6 @@ export abstract class Protocol { this._responseHandlers.delete(messageId); - this._progressHandlers.delete(messageId); + this._progressManager.removeHandler(messageId); this._cleanupTimeout(messageId); this._transport @@ -1165,10 +1217,10 @@ export abstract class Protocol this._onerror(new Error(`Failed to send cancellation: ${error}`))); + .catch(error => this._onerror(new Error(`Failed to send cancellation: ${error}`), 'send-cancellation')); - // Wrap the reason in an McpError if it isn't already - const error = reason instanceof McpError ? reason : new McpError(ErrorCode.RequestTimeout, String(reason)); + // Wrap the reason in a TransportError if it isn't already an error we recognize + const error = reason instanceof Error ? reason : TransportError.requestTimeout(String(reason)); reject(error); }; @@ -1199,132 +1251,73 @@ export abstract class Protocol cancel(McpError.fromError(ErrorCode.RequestTimeout, 'Request timed out', { timeout })); + const timeoutHandler = () => cancel(TransportError.requestTimeout('Request timed out', { timeout })); this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler, options?.resetTimeoutOnProgress ?? false); - // Queue request if related to a task - const relatedTaskId = relatedTask?.taskId; - if (relatedTaskId) { - // Store the response resolver for this request so responses can be routed back - const responseResolver = (response: JSONRPCResultResponse | Error) => { - const handler = this._responseHandlers.get(messageId); - if (handler) { - handler(response); - } else { - // Log error when resolver is missing, but don't fail - this._onerror(new Error(`Response handler missing for side-channeled request ${messageId}`)); - } - }; - this._requestResolvers.set(messageId, responseResolver); - - this._enqueueTaskMessage(relatedTaskId, { - type: 'request', - message: jsonrpcRequest, - timestamp: Date.now() - }).catch(error => { - this._cleanupTimeout(messageId); - reject(error); - }); + // Create plugin context for outgoing request + const outgoingCtx: OutgoingRequestContext = { + messageId, + sessionId: this._transport?.sessionId, + requestOptions: options as Record, + registerResolver: () => { + // Register resolver so responses can be routed back (used by task plugin) + const responseResolver = (response: JSONRPCResultResponse | Error) => { + const handler = this._responseHandlers.get(messageId); + if (handler) { + handler(response); + } else { + this._onerror( + new Error(`Response handler missing for side-channeled request ${messageId}`), + 'side-channel-routing' + ); + } + }; + this._requestResolvers.set(messageId, responseResolver); + } + }; - // Don't send through transport - queued messages are delivered via tasks/result only - // This prevents duplicate delivery for bidirectional transports - } else { - // No related task - send through transport normally - this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { + // Let plugins augment the request (e.g., add task metadata) + this._runPluginOnBeforeSendRequest(jsonrpcRequest, outgoingCtx) + .then(modifiedRequest => { + jsonrpcRequest = modifiedRequest; + + // Route message through plugins or transport + return this._routeMessage(jsonrpcRequest, { + relatedRequestId, + sessionId: this._transport?.sessionId, + resumptionToken, + onresumptiontoken + }); + }) + .catch(error => { this._cleanupTimeout(messageId); reject(error); }); - } }); } - /** - * Gets the current status of a task. - * - * @experimental Use `client.experimental.tasks.getTask()` to access this method. - */ - protected async getTask(params: GetTaskRequest['params'], options?: RequestOptions): Promise { - // @ts-expect-error SendRequestT cannot directly contain GetTaskRequest, but we ensure all type instantiations contain it anyways - return this.request({ method: 'tasks/get', params }, GetTaskResultSchema, options); - } - - /** - * Retrieves the result of a completed task. - * - * @experimental Use `client.experimental.tasks.getTaskResult()` to access this method. - */ - protected async getTaskResult( - params: GetTaskPayloadRequest['params'], - resultSchema: T, - options?: RequestOptions - ): Promise> { - // @ts-expect-error SendRequestT cannot directly contain GetTaskPayloadRequest, but we ensure all type instantiations contain it anyways - return this.request({ method: 'tasks/result', params }, resultSchema, options); - } - - /** - * Lists tasks, optionally starting from a pagination cursor. - * - * @experimental Use `client.experimental.tasks.listTasks()` to access this method. - */ - protected async listTasks(params?: { cursor?: string }, options?: RequestOptions): Promise> { - // @ts-expect-error SendRequestT cannot directly contain ListTasksRequest, but we ensure all type instantiations contain it anyways - return this.request({ method: 'tasks/list', params }, ListTasksResultSchema, options); - } - - /** - * Cancels a specific task. - * - * @experimental Use `client.experimental.tasks.cancelTask()` to access this method. - */ - protected async cancelTask(params: { taskId: string }, options?: RequestOptions): Promise> { - // @ts-expect-error SendRequestT cannot directly contain CancelTaskRequest, but we ensure all type instantiations contain it anyways - return this.request({ method: 'tasks/cancel', params }, CancelTaskResultSchema, options); - } - /** * Emits a notification, which is a one-way message that does not expect a response. */ async notification(notification: SendNotificationT, options?: NotificationOptions): Promise { if (!this._transport) { - throw new Error('Not connected'); + throw StateError.notConnected('send notification'); } this.assertNotificationCapability(notification.method); - // Queue notification if related to a task - const relatedTaskId = options?.relatedTask?.taskId; - if (relatedTaskId) { - // Build the JSONRPC notification with metadata - const jsonrpcNotification: JSONRPCNotification = { - ...notification, - jsonrpc: '2.0', - params: { - ...notification.params, - _meta: { - ...notification.params?._meta, - [RELATED_TASK_META_KEY]: options.relatedTask - } - } - }; - - await this._enqueueTaskMessage(relatedTaskId, { - type: 'notification', - message: jsonrpcNotification, - timestamp: Date.now() - }); - - // Don't send through transport - queued messages are delivered via tasks/result only - // This prevents duplicate delivery for bidirectional transports - return; - } - const debouncedMethods = this._options?.debouncedNotificationMethods ?? []; // A notification can only be debounced if it's in the list AND it's "simple" - // (i.e., has no parameters and no related request ID or related task that could be lost). - const canDebounce = - debouncedMethods.includes(notification.method) && !notification.params && !options?.relatedRequestId && !options?.relatedTask; + // (i.e., has no parameters and no related request ID). + const canDebounce = debouncedMethods.includes(notification.method) && !notification.params && !options?.relatedRequestId; + + // Create plugin context for outgoing notification + const outgoingCtx: OutgoingNotificationContext = { + sessionId: this._transport?.sessionId, + relatedRequestId: options?.relatedRequestId, + notificationOptions: options as Record + }; if (canDebounce) { // If a notification of this type is already scheduled, do nothing. @@ -1337,7 +1330,7 @@ export abstract class Protocol { + Promise.resolve().then(async () => { // Un-mark the notification so the next one can be scheduled. this._pendingDebouncedNotifications.delete(notification.method); @@ -1351,23 +1344,14 @@ export abstract class Protocol this._onerror(error)); + // Route notification through plugins + this._routeMessage(jsonrpcNotification, { + ...options, + sessionId: this._transport?.sessionId + }).catch(error => this._onerror(error, 'send-notification')); }); // Return immediately. @@ -1379,21 +1363,14 @@ export abstract class Protocol, - extra: RequestHandlerExtra + ctx: ContextInterface ) => SendResultT | Promise ): void { const method = getMethodLiteral(requestSchema); this.assertRequestHandlerCapability(method); - this._requestHandlers.set(method, (request, extra) => { + // Wrap handler to parse the request and delegate to registry + this._handlerRegistry.setRequestHandler(method, (request, ctx) => { const parsed = parseWithCompat(requestSchema, request) as SchemaOutput; - return Promise.resolve(handler(parsed, extra)); + return Promise.resolve(handler(parsed, ctx)); }); } @@ -1421,15 +1399,15 @@ export abstract class Protocol) => void | Promise ): void { const method = getMethodLiteral(notificationSchema); - this._notificationHandlers.set(method, notification => { + // Wrap handler to parse the notification and delegate to registry + this._handlerRegistry.setNotificationHandler(method, notification => { const parsed = parseWithCompat(notificationSchema, notification) as SchemaOutput; return Promise.resolve(handler(parsed)); }); @@ -1453,195 +1432,7 @@ export abstract class Protocol { - // Task message queues are only used when taskStore is configured - if (!this._taskStore || !this._taskMessageQueue) { - throw new Error('Cannot enqueue task message: taskStore and taskMessageQueue are not configured'); - } - - const maxQueueSize = this._options?.maxTaskQueueSize; - await this._taskMessageQueue.enqueue(taskId, message, sessionId, maxQueueSize); - } - - /** - * Clears the message queue for a task and rejects any pending request resolvers. - * @param taskId The task ID whose queue should be cleared - * @param sessionId Optional session ID for binding the operation to a specific session - */ - private async _clearTaskQueue(taskId: string, sessionId?: string): Promise { - if (this._taskMessageQueue) { - // Reject any pending request resolvers - const messages = await this._taskMessageQueue.dequeueAll(taskId, sessionId); - for (const message of messages) { - if (message.type === 'request' && isJSONRPCRequest(message.message)) { - // Extract request ID from the message - const requestId = message.message.id as RequestId; - const resolver = this._requestResolvers.get(requestId); - if (resolver) { - resolver(new McpError(ErrorCode.InternalError, 'Task cancelled or completed')); - this._requestResolvers.delete(requestId); - } else { - // Log error when resolver is missing during cleanup for better observability - this._onerror(new Error(`Resolver missing for request ${requestId} during task ${taskId} cleanup`)); - } - } - } - } - } - - /** - * Waits for a task update (new messages or status change) with abort signal support. - * Uses polling to check for updates at the task's configured poll interval. - * @param taskId The task ID to wait for - * @param signal Abort signal to cancel the wait - * @returns Promise that resolves when an update occurs or rejects if aborted - */ - private async _waitForTaskUpdate(taskId: string, signal: AbortSignal): Promise { - // Get the task's poll interval, falling back to default - let interval = this._options?.defaultTaskPollInterval ?? 1000; - try { - const task = await this._taskStore?.getTask(taskId); - if (task?.pollInterval) { - interval = task.pollInterval; - } - } catch { - // Use default interval if task lookup fails - } - - return new Promise((resolve, reject) => { - if (signal.aborted) { - reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled')); - return; - } - - // Wait for the poll interval, then resolve so caller can check for updates - const timeoutId = setTimeout(resolve, interval); - - // Clean up timeout and reject if aborted - signal.addEventListener( - 'abort', - () => { - clearTimeout(timeoutId); - reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled')); - }, - { once: true } - ); - }); - } - - private requestTaskStore(request?: JSONRPCRequest, sessionId?: string): RequestTaskStore { - const taskStore = this._taskStore; - if (!taskStore) { - throw new Error('No task store configured'); - } - - return { - createTask: async taskParams => { - if (!request) { - throw new Error('No request provided'); - } - - return await taskStore.createTask( - taskParams, - request.id, - { - method: request.method, - params: request.params - }, - sessionId - ); - }, - getTask: async taskId => { - const task = await taskStore.getTask(taskId, sessionId); - if (!task) { - throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); - } - - return task; - }, - storeTaskResult: async (taskId, status, result) => { - await taskStore.storeTaskResult(taskId, status, result, sessionId); - - // Get updated task state and send notification - const task = await taskStore.getTask(taskId, sessionId); - if (task) { - const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ - method: 'notifications/tasks/status', - params: task - }); - await this.notification(notification as SendNotificationT); - - if (isTerminal(task.status)) { - this._cleanupTaskProgressHandler(taskId); - // Don't clear queue here - it will be cleared after delivery via tasks/result - } - } - }, - getTaskResult: taskId => { - return taskStore.getTaskResult(taskId, sessionId); - }, - updateTaskStatus: async (taskId, status, statusMessage) => { - // Check if task exists - const task = await taskStore.getTask(taskId, sessionId); - if (!task) { - throw new McpError(ErrorCode.InvalidParams, `Task "${taskId}" not found - it may have been cleaned up`); - } - - // Don't allow transitions from terminal states - if (isTerminal(task.status)) { - throw new McpError( - ErrorCode.InvalidParams, - `Cannot update task "${taskId}" from terminal status "${task.status}" to "${status}". Terminal states (completed, failed, cancelled) cannot transition to other states.` - ); - } - - await taskStore.updateTaskStatus(taskId, status, statusMessage, sessionId); - - // Get updated task state and send notification - const updatedTask = await taskStore.getTask(taskId, sessionId); - if (updatedTask) { - const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ - method: 'notifications/tasks/status', - params: updatedTask - }); - await this.notification(notification as SendNotificationT); - - if (isTerminal(updatedTask.status)) { - this._cleanupTaskProgressHandler(taskId); - // Don't clear queue here - it will be cleared after delivery via tasks/result - } - } - }, - listTasks: cursor => { - return taskStore.listTasks(cursor, sessionId); - } - }; + this._handlerRegistry.removeNotificationHandler(method); } } diff --git a/packages/core/src/shared/responseMessage.ts b/packages/core/src/shared/responseMessage.ts index 8a0dcc2c2..b67f52362 100644 --- a/packages/core/src/shared/responseMessage.ts +++ b/packages/core/src/shared/responseMessage.ts @@ -1,4 +1,5 @@ -import type { McpError, Result, Task } from '../types/types.js'; +import type { ProtocolError } from '../errors.js'; +import type { Result, Task } from '../types/types.js'; /** * Base message type @@ -36,7 +37,7 @@ export interface ResultMessage extends BaseResponseMessage { */ export interface ErrorMessage extends BaseResponseMessage { type: 'error'; - error: McpError; + error: ProtocolError; } /** diff --git a/packages/core/src/shared/taskClientPlugin.ts b/packages/core/src/shared/taskClientPlugin.ts new file mode 100644 index 000000000..e29211133 --- /dev/null +++ b/packages/core/src/shared/taskClientPlugin.ts @@ -0,0 +1,446 @@ +/** + * Task Client Plugin + * + * This plugin provides client-side methods for calling task APIs on a remote server. + * It also manages task-related progress handlers. + * + * Usage: + * ```typescript + * const taskClient = client.getPlugin(TaskClientPlugin); + * const task = await taskClient?.getTask({ taskId: 'task-123' }); + * ``` + */ + +import { ProtocolError } from '../errors.js'; +import { isTerminal } from '../experimental/tasks/interfaces.js'; +import type { + CancelTaskResult, + GetTaskPayloadRequest, + GetTaskRequest, + GetTaskResult, + JSONRPCErrorResponse, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ListTasksResult, + RelatedTaskMetadata, + Request, + Result, + TaskCreationParams +} from '../types/types.js'; +import { + CancelTaskResultSchema, + CreateTaskResultSchema, + ErrorCode, + GetTaskResultSchema, + isJSONRPCResultResponse, + ListTasksResultSchema, + RELATED_TASK_META_KEY +} from '../types/types.js'; +import type { AnySchema, SchemaOutput } from '../util/zodCompat.js'; +import type { OutgoingNotificationContext, OutgoingRequestContext, PluginContext, PluginRequestOptions, ProtocolPlugin } from './plugin.js'; +import type { ProgressCallback, ProgressManagerInterface } from './progressManager.js'; +import type { RequestOptions } from './protocol.js'; +import type { ResponseMessage } from './responseMessage.js'; + +// ═══════════════════════════════════════════════════════════════════════════ +// Task-Specific Option Types +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Extended request options for task-augmented requests. + * + * Use these options when sending requests that should create or relate to tasks. + * For type safety at call sites, use `satisfies TaskRequestOptions`: + * + * @example + * ```typescript + * import type { TaskRequestOptions } from '@modelcontextprotocol/core'; + * + * // Create a task with the request + * await client.request(callToolRequest, CallToolResultSchema, { + * task: { ttl: 60000 } + * } satisfies TaskRequestOptions); + * + * // Inside a handler, associate with a parent task + * await ctx.sendRequest(req, schema, { + * relatedTask: { taskId: ctx.taskCtx?.id } + * } satisfies TaskRequestOptions); + * ``` + */ +export type TaskRequestOptions = RequestOptions & { + /** + * If provided, augments the request with task creation parameters + * to enable call-now, fetch-later execution patterns. + */ + task?: TaskCreationParams; + + /** + * If provided, associates this request with a related task. + * This is typically set internally by the SDK when handling task-augmented requests. + */ + relatedTask?: RelatedTaskMetadata; +}; + +/** + * Extended notification options for task-related notifications. + * + * Use these options when sending notifications that should be associated with a task. + * For type safety at call sites, use `satisfies TaskNotificationOptions`: + * + * @example + * ```typescript + * import type { TaskNotificationOptions } from '@modelcontextprotocol/core'; + * + * // Inside a handler, associate notification with a parent task + * await ctx.sendNotification(progressNotification, { + * relatedTask: { taskId: ctx.taskCtx?.id } + * } satisfies TaskNotificationOptions); + * ``` + */ +export type TaskNotificationOptions = { + /** + * If provided, associates this notification with a related task. + * This is typically set internally by the SDK when handling task-augmented requests. + */ + relatedTask?: RelatedTaskMetadata; +}; + +/** + * Plugin that provides client-side task API methods. + * Clients access this via getPlugin(TaskClientPlugin) to call task APIs on remote servers. + */ +export class TaskClientPlugin implements ProtocolPlugin { + readonly name = 'TaskClientPlugin'; + readonly priority = 50; // Standard priority + + private ctx?: PluginContext; + private progressManager?: ProgressManagerInterface; + + /** + * Maps task IDs to their associated progress token (message ID) and handler. + * This allows progress to continue after CreateTaskResult is returned. + */ + private readonly taskProgressHandlers = new Map(); + + /** + * Install the plugin. + */ + install(ctx: PluginContext): void { + this.ctx = ctx; + this.progressManager = ctx.progress; + } + + /** + * Called when a response is received for an outgoing request. + * Detects task creation responses and preserves progress handlers. + */ + onResponse(response: JSONRPCResponse | JSONRPCErrorResponse, messageId: number): void { + if (!this.progressManager) return; + + // Check if this is a CreateTaskResult response + if (isJSONRPCResultResponse(response) && response.result && typeof response.result === 'object') { + const result = response.result as Record; + if (result.task && typeof result.task === 'object') { + const task = result.task as Record; + if (typeof task.taskId === 'string') { + const taskId = task.taskId; + + // Get the current progress handler before Protocol removes it + const handler = this.progressManager.getHandler(messageId); + if (handler) { + // Store the handler for this task + this.taskProgressHandlers.set(taskId, { messageId, handler }); + + // Re-register the handler so it stays active + // This is called before Protocol.removeHandler, so we need to + // re-register after Protocol removes it. We do this by + // scheduling it on next tick. + queueMicrotask(() => { + this.progressManager?.registerHandler(messageId, handler); + }); + } + } + } + } + } + + /** + * Clears the progress handler for a completed task. + * Call this when a task reaches terminal state. + * + * @param taskId - The task ID whose progress handler should be removed + */ + clearTaskProgress(taskId: string): void { + const entry = this.taskProgressHandlers.get(taskId); + if (entry) { + this.progressManager?.removeHandler(entry.messageId); + this.taskProgressHandlers.delete(taskId); + } + } + + /** + * Checks if a task has an active progress handler. + */ + hasTaskProgress(taskId: string): boolean { + return this.taskProgressHandlers.has(taskId); + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Outgoing Message Hooks + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Augments outgoing requests with task metadata. + * - Adds task creation params if `task` option is provided + * - Adds related task metadata if `relatedTask` option is provided + * - Registers response resolver for task-related requests + */ + onBeforeSendRequest(request: JSONRPCRequest, ctx: OutgoingRequestContext): JSONRPCRequest | void { + // Read task-specific options from the raw options object + const options = ctx.requestOptions as TaskRequestOptions | undefined; + if (!options) return; + + let modified = request; + const { task, relatedTask } = options; + + // Augment with task creation parameters if provided + if (task) { + modified = { + ...modified, + params: { + ...modified.params, + task + } + }; + } + + // Augment with related task metadata if provided + if (relatedTask) { + const existingParams = (modified.params ?? {}) as Record; + const existingMeta = (existingParams._meta ?? {}) as Record; + modified = { + ...modified, + params: { + ...existingParams, + _meta: { + ...existingMeta, + [RELATED_TASK_META_KEY]: relatedTask + } + } + }; + + // Register resolver for task-related requests so responses route back + ctx.registerResolver(() => { + // The resolver is registered automatically by Protocol + }); + } + + // Return modified request if changes were made + if (modified === request) { + return undefined; + } + return modified; + } + + /** + * Augments outgoing notifications with task metadata. + * Adds related task metadata if `relatedTask` option is provided. + */ + onBeforeSendNotification(notification: JSONRPCNotification, ctx: OutgoingNotificationContext): JSONRPCNotification | void { + // Read task-specific options from the raw options object + const options = ctx.notificationOptions as TaskNotificationOptions | undefined; + if (!options?.relatedTask) return; + + const existingParams = (notification.params ?? {}) as Record; + const existingMeta = (existingParams._meta ?? {}) as Record; + const modified = { + ...notification, + params: { + ...existingParams, + _meta: { + ...existingMeta, + [RELATED_TASK_META_KEY]: options.relatedTask + } + } + }; + + return modified; + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Task API Methods + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Gets the current status of a task. + */ + async getTask(params: GetTaskRequest['params'], options?: PluginRequestOptions): Promise { + if (!this.ctx) { + throw new Error('TaskClientPlugin not installed'); + } + return this.ctx.requests.sendRequest({ jsonrpc: '2.0', id: 0, method: 'tasks/get', params }, GetTaskResultSchema, options); + } + + /** + * Retrieves the result of a completed task. + * Uses long-polling to wait for task completion. + */ + async getTaskResult( + params: GetTaskPayloadRequest['params'], + resultSchema: T, + options?: PluginRequestOptions + ): Promise> { + if (!this.ctx) { + throw new Error('TaskClientPlugin not installed'); + } + const result = await this.ctx.requests.sendRequest( + { jsonrpc: '2.0', id: 0, method: 'tasks/result', params }, + resultSchema, + options + ); + + // Clear progress handler when task result is retrieved + this.clearTaskProgress(params.taskId); + + return result; + } + + /** + * Lists all tasks, optionally with pagination. + */ + async listTasks(params?: { cursor?: string }, options?: PluginRequestOptions): Promise { + if (!this.ctx) { + throw new Error('TaskClientPlugin not installed'); + } + return this.ctx.requests.sendRequest({ jsonrpc: '2.0', id: 0, method: 'tasks/list', params }, ListTasksResultSchema, options); + } + + /** + * Cancels a running task. + */ + async cancelTask(params: { taskId: string }, options?: PluginRequestOptions): Promise { + if (!this.ctx) { + throw new Error('TaskClientPlugin not installed'); + } + const result = await this.ctx.requests.sendRequest( + { jsonrpc: '2.0', id: 0, method: 'tasks/cancel', params }, + CancelTaskResultSchema, + options + ); + + // Clear progress handler when task is cancelled + this.clearTaskProgress(params.taskId); + + return result; + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Task Streaming + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Sends a task-augmented request and streams status updates until completion. + * This handles the full task lifecycle: creation, polling, and result retrieval. + * + * @param request - The request to send (method and params) + * @param resultSchema - Schema to validate the final result + * @param options - Options including task creation params + * @yields ResponseMessage events for task creation, status updates, and final result/error + */ + async *requestStream( + request: Request, + resultSchema: T, + options: TaskClientRequestStreamOptions + ): AsyncGenerator>, void, void> { + if (!this.ctx) { + throw new Error('TaskClientPlugin not installed'); + } + + let taskId: string | undefined; + try { + // Send the request and get the CreateTaskResult + // Convert Request to JSONRPCRequest format for sendRequest + const jsonRpcRequest = { jsonrpc: '2.0' as const, id: 0, ...request }; + const createResult = await this.ctx.requests.sendRequest(jsonRpcRequest, CreateTaskResultSchema, options); + + // Extract taskId from the result + if (createResult.task) { + taskId = createResult.task.taskId; + yield { type: 'taskCreated', task: createResult.task }; + } else { + throw new ProtocolError(ErrorCode.InternalError, 'Task creation did not return a task'); + } + + // Poll for task completion + while (true) { + // Get current task status + const task = await this.getTask({ taskId }, options); + yield { type: 'taskStatus', task }; + + // Check if task is terminal + if (isTerminal(task.status)) { + switch (task.status) { + case 'completed': { + // Get the final result + const result = await this.getTaskResult({ taskId }, resultSchema, options); + yield { type: 'result', result }; + break; + } + case 'failed': { + yield { + type: 'error', + error: new ProtocolError(ErrorCode.InternalError, `Task ${taskId} failed`) + }; + break; + } + case 'cancelled': { + yield { + type: 'error', + error: new ProtocolError(ErrorCode.InternalError, `Task ${taskId} was cancelled`) + }; + break; + } + // No default + } + return; + } + + // When input_required, call tasks/result to deliver queued messages + // (elicitation, sampling) via SSE and block until terminal + if (task.status === 'input_required') { + const result = await this.getTaskResult({ taskId }, resultSchema, options); + yield { type: 'result', result }; + return; + } + + // Wait before polling again + const pollInterval = task.pollInterval ?? options.defaultPollInterval ?? 1000; + await new Promise(resolve => setTimeout(resolve, pollInterval)); + + // Check if cancelled + options.signal?.throwIfAborted(); + } + } catch (error) { + yield { + type: 'error', + error: error instanceof ProtocolError ? error : new ProtocolError(ErrorCode.InternalError, String(error)) + }; + } + } +} + +/** + * Options for TaskClientPlugin.requestStream. + */ +export interface TaskClientRequestStreamOptions extends PluginRequestOptions { + task: TaskCreationParams; + defaultPollInterval?: number; +} + +/** + * Factory function to create a TaskClientPlugin. + */ +export function createTaskClientPlugin(): TaskClientPlugin { + return new TaskClientPlugin(); +} diff --git a/packages/core/src/shared/taskPlugin.ts b/packages/core/src/shared/taskPlugin.ts new file mode 100644 index 000000000..6d6e86589 --- /dev/null +++ b/packages/core/src/shared/taskPlugin.ts @@ -0,0 +1,492 @@ +/** + * Task Plugin + * + * This plugin completely abstracts all task-related functionality from the Protocol class: + * - Message routing for task-related messages (queue instead of send) + * - Task API handlers (tasks/get, tasks/result, tasks/list, tasks/cancel) + * - Task message queue management + * + * The plugin is internal to the SDK and not exposed as a public API. + */ + +import { ProtocolError } from '../errors.js'; +import { RequestTaskStore } from '../experimental/requestTaskStore.js'; +import type { QueuedMessage, TaskMessageQueue, TaskStore } from '../experimental/tasks/interfaces.js'; +import { isTerminal } from '../experimental/tasks/interfaces.js'; +import type { + CancelTaskResult, + GetTaskPayloadRequest, + GetTaskRequest, + GetTaskResult, + JSONRPCErrorResponse, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCResultResponse, + ListTasksResult, + RequestId, + Result +} from '../types/types.js'; +import { + CancelTaskRequestSchema, + ErrorCode, + GetTaskPayloadRequestSchema, + GetTaskRequestSchema, + isJSONRPCRequest, + isTaskAugmentedRequestParams, + ListTasksRequestSchema, + RELATED_TASK_META_KEY +} from '../types/types.js'; +import type { HandlerContextBase, PluginContext, PluginHandlerExtra, ProtocolPlugin } from './plugin.js'; +import type { Transport, TransportSendOptions } from './transport.js'; + +/** + * Configuration for the TaskPlugin. + */ +export interface TaskPluginConfig { + /** + * The task store implementation for persisting task state. + */ + readonly taskStore: TaskStore; + + /** + * Optional message queue for async message delivery during task execution. + */ + readonly taskMessageQueue?: TaskMessageQueue; + + /** + * Default polling interval (in milliseconds) for task status checks. + * Defaults to 1000ms if not specified. + */ + readonly defaultTaskPollInterval?: number; + + /** + * Maximum number of messages that can be queued per task. + * If undefined, the queue size is unbounded. + */ + readonly maxTaskQueueSize?: number; +} + +/** + * Plugin that handles all task-related MCP operations. + * This completely abstracts task functionality from the Protocol class. + */ +export class TaskPlugin implements ProtocolPlugin { + readonly name = 'TaskPlugin'; + readonly priority = 100; // High priority to run before other plugins + + private ctx?: PluginContext; + private transport?: Transport; + + constructor(private readonly config: TaskPluginConfig) {} + + // ═══════════════════════════════════════════════════════════════════════════ + // Plugin Lifecycle + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Install the plugin by registering task request handlers. + */ + install(ctx: PluginContext): void { + this.ctx = ctx; + + // Register tasks/get handler + ctx.handlers.setRequestHandler(GetTaskRequestSchema, async (request, extra) => { + return this.handleGetTask(request, extra); + }); + + // Register tasks/result handler + ctx.handlers.setRequestHandler(GetTaskPayloadRequestSchema, async (request, extra) => { + return this.handleGetTaskPayload(request, extra); + }); + + // Register tasks/list handler + ctx.handlers.setRequestHandler(ListTasksRequestSchema, async (request, extra) => { + return this.handleListTasks(request.params, extra); + }); + + // Register tasks/cancel handler + ctx.handlers.setRequestHandler(CancelTaskRequestSchema, async (request, extra) => { + return this.handleCancelTask(request.params, extra); + }); + } + + /** + * Called when transport connects. + */ + onConnect(transport: Transport): void { + this.transport = transport; + } + + /** + * Called when connection closes. + */ + onClose(): void { + this.transport = undefined; + } + + /** + * Called before a request is processed. + * Checks if task creation is supported for the request method. + */ + onRequest(request: JSONRPCRequest): JSONRPCRequest | void { + // If this request asks for task creation, check capability + const taskCreationParams = isTaskAugmentedRequestParams(request.params) ? request.params.task : undefined; + if (taskCreationParams) { + // Check if this method supports task creation + // For now, we support tasks for tools/call and sampling/createMessage + const taskCapableMethods = ['tools/call', 'sampling/createMessage']; + if (!taskCapableMethods.includes(request.method)) { + throw new ProtocolError(ErrorCode.InvalidRequest, `Task creation is not supported for method: ${request.method}`); + } + } + // Return void to pass through unchanged + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Message Routing + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Extracts the relatedTaskId from a message's _meta field. + */ + private extractRelatedTaskId( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse + ): string | undefined { + // For requests/notifications, check params._meta + if ('method' in message && 'params' in message && message.params) { + const params = message.params as Record; + const meta = params._meta as Record | undefined; + const taskMeta = meta?.[RELATED_TASK_META_KEY] as { taskId?: string } | undefined; + return taskMeta?.taskId; + } + return undefined; + } + + /** + * Determines if this plugin should route the message (queue for task delivery). + * Returns true if the message has a relatedTaskId in its metadata and task queue is configured. + */ + shouldRouteMessage( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + _options?: TransportSendOptions + ): boolean { + // Route if there's a related task ID in the message and we have a message queue + const relatedTaskId = this.extractRelatedTaskId(message); + return Boolean(relatedTaskId && this.config.taskMessageQueue); + } + + /** + * Routes the message by queueing it for task delivery. + */ + async routeMessage( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + options?: TransportSendOptions + ): Promise { + const relatedTaskId = this.extractRelatedTaskId(message); + const sessionId = options?.sessionId; + if (!relatedTaskId || !this.config.taskMessageQueue) { + throw new Error('Cannot route message: relatedTaskId or taskMessageQueue not available'); + } + + const timestamp = Date.now(); + + // Create properly typed QueuedMessage based on message structure + let queuedMessage: QueuedMessage; + if ('method' in message && 'id' in message) { + queuedMessage = { type: 'request', message: message as JSONRPCRequest, timestamp }; + } else if ('method' in message && !('id' in message)) { + queuedMessage = { type: 'notification', message: message as JSONRPCNotification, timestamp }; + } else if ('result' in message) { + queuedMessage = { type: 'response', message: message as JSONRPCResultResponse, timestamp }; + } else { + queuedMessage = { type: 'error', message: message as JSONRPCErrorResponse, timestamp }; + } + + await this.enqueueTaskMessage(relatedTaskId, queuedMessage, sessionId); + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Handler Context Hook + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Builds task context for incoming request handlers. + * Extracts task creation params and related task metadata from the request, + * creates a RequestTaskStore, and returns the task context. + */ + onBuildHandlerContext(request: JSONRPCRequest, baseContext: HandlerContextBase): Record | undefined { + // Only build task context if we have a task store configured + if (!this.config.taskStore) { + return undefined; + } + + // Extract task metadata from request + const relatedTaskId = this.extractRelatedTaskId(request); + const taskCreationParams = isTaskAugmentedRequestParams(request.params) ? request.params.task : undefined; + + // Create the RequestTaskStore + const requestTaskStore = new RequestTaskStore({ + taskStore: this.config.taskStore, + requestId: request.id, + request, + sessionId: baseContext.sessionId, + initialTaskId: relatedTaskId ?? '' + }); + + // Return task context that will be merged into the handler context + return { + taskCtx: { + get id() { + return requestTaskStore.currentTaskId; + }, + store: requestTaskStore, + requestedTtl: taskCreationParams?.ttl ?? null + } + }; + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Task Message Queue Management + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Enqueues a message for task delivery. + */ + private async enqueueTaskMessage(taskId: string, message: QueuedMessage, sessionId?: string): Promise { + if (!this.config.taskMessageQueue) { + throw new Error('Cannot enqueue task message: taskMessageQueue is not configured'); + } + + await this.config.taskMessageQueue.enqueue(taskId, message, sessionId, this.config.maxTaskQueueSize); + } + + /** + * Clears the message queue for a task and rejects any pending request resolvers. + */ + private async clearTaskQueue(taskId: string, sessionId?: string): Promise { + if (!this.config.taskMessageQueue || !this.ctx) { + return; + } + + // Dequeue all messages and reject pending request resolvers + const messages = await this.config.taskMessageQueue.dequeueAll(taskId, sessionId); + for (const message of messages) { + if (message.type === 'request' && isJSONRPCRequest(message.message)) { + const requestId = message.message.id as RequestId; + const resolver = this.ctx.resolvers.get(requestId); + if (resolver) { + resolver(new ProtocolError(ErrorCode.InternalError, 'Task cancelled or completed')); + this.ctx.resolvers.remove(requestId); + } else { + this.ctx.reportError(new Error(`Resolver missing for request ${requestId} during task ${taskId} cleanup`)); + } + } + } + } + + /** + * Waits for a task update (new messages or status change) with abort signal support. + */ + private async waitForTaskUpdate(taskId: string, signal: AbortSignal): Promise { + // Get the task's poll interval, falling back to default + let interval = this.config.defaultTaskPollInterval ?? 1000; + try { + const task = await this.config.taskStore.getTask(taskId); + if (task?.pollInterval) { + interval = task.pollInterval; + } + } catch { + // Use default interval if task lookup fails + } + + return new Promise((resolve, reject) => { + if (signal.aborted) { + reject(new ProtocolError(ErrorCode.InvalidRequest, 'Request cancelled')); + return; + } + + // Wait for the poll interval, then resolve so caller can check for updates + const timeoutId = setTimeout(resolve, interval); + + // Clean up timeout and reject if aborted + signal.addEventListener( + 'abort', + () => { + clearTimeout(timeoutId); + reject(new ProtocolError(ErrorCode.InvalidRequest, 'Request cancelled')); + }, + { once: true } + ); + }); + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Task API Handlers + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Handler for tasks/get - retrieves task metadata. + */ + private async handleGetTask(request: GetTaskRequest, extra: PluginHandlerExtra): Promise { + const task = await this.config.taskStore.getTask(request.params.taskId, extra.mcpCtx.sessionId); + if (!task) { + throw new ProtocolError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); + } + + // Per spec: tasks/get responses SHALL NOT include related-task metadata + return { ...task }; + } + + /** + * Handler for tasks/result - delivers task results and queued messages. + * Implements long-polling pattern for task updates. + */ + private async handleGetTaskPayload(request: GetTaskPayloadRequest, extra: PluginHandlerExtra): Promise { + const taskId = request.params.taskId; + + const poll = async (): Promise => { + // Deliver any queued messages first + await this.deliverQueuedMessages(taskId, extra); + + // Check task status + const task = await this.config.taskStore.getTask(taskId, extra.mcpCtx.sessionId); + if (!task) { + throw new ProtocolError(ErrorCode.InvalidParams, `Task not found: ${taskId}`); + } + + // If task is not terminal, wait for updates and poll again + if (!isTerminal(task.status)) { + await this.waitForTaskUpdate(taskId, extra.requestCtx.signal); + return poll(); + } + + // Task is terminal - return the result + const result = await this.config.taskStore.getTaskResult(taskId, extra.mcpCtx.sessionId); + await this.clearTaskQueue(taskId, extra.mcpCtx.sessionId); + + return { + ...result, + _meta: { + ...result._meta, + [RELATED_TASK_META_KEY]: { taskId } + } + }; + }; + + return poll(); + } + + /** + * Delivers queued messages for a task. + */ + private async deliverQueuedMessages(taskId: string, extra: PluginHandlerExtra): Promise { + const { taskMessageQueue } = this.config; + if (!taskMessageQueue || !this.ctx) { + return; + } + + let queuedMessage: QueuedMessage | undefined; + while ((queuedMessage = await taskMessageQueue.dequeue(taskId, extra.mcpCtx.sessionId))) { + // Handle response and error messages by routing to original resolver + if (queuedMessage.type === 'response' || queuedMessage.type === 'error') { + await this.routeQueuedResponse(queuedMessage); + continue; + } + + // Send other messages (notifications, requests) on the response stream + const transport = this.ctx.transport.getTransport(); + await transport?.send(queuedMessage.message, { relatedRequestId: extra.mcpCtx.requestId }); + } + } + + /** + * Routes a queued response/error back to its original request resolver. + */ + private async routeQueuedResponse(queuedMessage: QueuedMessage): Promise { + if (!this.ctx) return; + + const message = queuedMessage.message as JSONRPCResultResponse | JSONRPCErrorResponse; + const requestId = message.id as RequestId; + + const resolver = this.ctx.resolvers.get(requestId); + if (!resolver) { + const messageType = queuedMessage.type === 'response' ? 'Response' : 'Error'; + this.ctx.reportError(new Error(`${messageType} handler missing for request ${requestId}`)); + return; + } + + this.ctx.resolvers.remove(requestId); + + if (queuedMessage.type === 'response') { + resolver(message as JSONRPCResultResponse); + } else { + const errorMessage = message as JSONRPCErrorResponse; + const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + resolver(error); + } + } + + /** + * Handler for tasks/list - lists all tasks. + */ + private async handleListTasks(params: { cursor?: string } | undefined, extra: PluginHandlerExtra): Promise { + try { + const { tasks, nextCursor } = await this.config.taskStore.listTasks(params?.cursor, extra.mcpCtx.sessionId); + return { tasks, nextCursor, _meta: {} }; + } catch (error) { + throw new ProtocolError( + ErrorCode.InvalidParams, + `Failed to list tasks: ${error instanceof Error ? error.message : String(error)}` + ); + } + } + + /** + * Handler for tasks/cancel - cancels a running task. + */ + private async handleCancelTask(params: { taskId: string }, extra: PluginHandlerExtra): Promise { + try { + const task = await this.config.taskStore.getTask(params.taskId, extra.mcpCtx.sessionId); + + if (!task) { + throw new ProtocolError(ErrorCode.InvalidParams, `Task not found: ${params.taskId}`); + } + + if (isTerminal(task.status)) { + throw new ProtocolError(ErrorCode.InvalidParams, `Cannot cancel task in terminal status: ${task.status}`); + } + + await this.config.taskStore.updateTaskStatus( + params.taskId, + 'cancelled', + 'Client cancelled task execution.', + extra.mcpCtx.sessionId + ); + + await this.clearTaskQueue(params.taskId, extra.mcpCtx.sessionId); + + const cancelledTask = await this.config.taskStore.getTask(params.taskId, extra.mcpCtx.sessionId); + if (!cancelledTask) { + throw new ProtocolError(ErrorCode.InvalidParams, `Task not found after cancellation: ${params.taskId}`); + } + + return { _meta: {}, ...cancelledTask }; + } catch (error) { + if (error instanceof ProtocolError) { + throw error; + } + throw new ProtocolError( + ErrorCode.InvalidRequest, + `Failed to cancel task: ${error instanceof Error ? error.message : String(error)}` + ); + } + } +} + +/** + * Factory function to create a TaskPlugin. + */ +export function createTaskPlugin(config: TaskPluginConfig): TaskPlugin { + return new TaskPlugin(config); +} diff --git a/packages/core/src/shared/timeoutManager.ts b/packages/core/src/shared/timeoutManager.ts new file mode 100644 index 000000000..e87da7367 --- /dev/null +++ b/packages/core/src/shared/timeoutManager.ts @@ -0,0 +1,172 @@ +/** + * Timeout Manager + * + * Manages request timeouts for the Protocol class. + * Extracted from Protocol to follow Single Responsibility Principle. + */ + +/** + * Information about a request's timeout state + */ +export interface TimeoutInfo { + timeoutId: ReturnType; + startTime: number; + timeout: number; + maxTotalTimeout?: number; + resetTimeoutOnProgress: boolean; + onTimeout: () => void; +} + +/** + * Options for setting up a timeout + */ +export interface TimeoutOptions { + /** + * The timeout duration in milliseconds + */ + timeout: number; + + /** + * Maximum total time allowed (optional) + */ + maxTotalTimeout?: number; + + /** + * Whether to reset the timeout when progress is received + */ + resetTimeoutOnProgress?: boolean; + + /** + * Callback to invoke when the timeout expires + */ + onTimeout: () => void; +} + +/** + * Manages request timeouts for outgoing requests. + */ +export class TimeoutManager { + private _timeoutInfo: Map = new Map(); + + /** + * Sets up a timeout for a message. + * + * @param messageId - The unique identifier for the message + * @param options - Timeout configuration options + */ + setup(messageId: number, options: TimeoutOptions): void { + const { timeout, maxTotalTimeout, resetTimeoutOnProgress, onTimeout } = options; + + this._timeoutInfo.set(messageId, { + timeoutId: setTimeout(onTimeout, timeout), + startTime: Date.now(), + timeout, + maxTotalTimeout, + resetTimeoutOnProgress: resetTimeoutOnProgress ?? false, + onTimeout + }); + } + + /** + * Resets the timeout for a message (e.g., when progress is received). + * Returns true if the timeout was reset, false if it wasn't found or + * if the max total timeout would be exceeded. + * + * @param messageId - The message ID whose timeout should be reset + * @returns true if reset succeeded, false otherwise + */ + reset(messageId: number): boolean { + const info = this._timeoutInfo.get(messageId); + if (!info || !info.resetTimeoutOnProgress) { + return false; + } + + const elapsed = Date.now() - info.startTime; + + // Check if max total timeout would be exceeded + if (info.maxTotalTimeout === undefined) { + // No max total timeout, just reset with original timeout + clearTimeout(info.timeoutId); + info.timeoutId = setTimeout(info.onTimeout, info.timeout); + } else { + const remainingTotal = info.maxTotalTimeout - elapsed; + if (remainingTotal <= 0) { + // Don't reset, let the timeout fire + return false; + } + + // Clear old timeout and set new one with the smaller of: + // - original timeout + // - remaining total time + clearTimeout(info.timeoutId); + const newTimeout = Math.min(info.timeout, remainingTotal); + info.timeoutId = setTimeout(info.onTimeout, newTimeout); + } + + return true; + } + + /** + * Cleans up the timeout for a message (e.g., when a response is received). + * + * @param messageId - The message ID whose timeout should be cleaned up + */ + cleanup(messageId: number): void { + const info = this._timeoutInfo.get(messageId); + if (info) { + clearTimeout(info.timeoutId); + this._timeoutInfo.delete(messageId); + } + } + + /** + * Gets the timeout info for a message. + * + * @param messageId - The message ID + * @returns The timeout info or undefined if not found + */ + get(messageId: number): TimeoutInfo | undefined { + return this._timeoutInfo.get(messageId); + } + + /** + * Checks if a timeout exists for a message. + * + * @param messageId - The message ID + * @returns true if a timeout exists + */ + has(messageId: number): boolean { + return this._timeoutInfo.has(messageId); + } + + /** + * Gets the elapsed time for a message's timeout. + * + * @param messageId - The message ID + * @returns The elapsed time in milliseconds, or undefined if not found + */ + getElapsed(messageId: number): number | undefined { + const info = this._timeoutInfo.get(messageId); + if (!info) { + return undefined; + } + return Date.now() - info.startTime; + } + + /** + * Clears all timeouts. + */ + clearAll(): void { + for (const info of this._timeoutInfo.values()) { + clearTimeout(info.timeoutId); + } + this._timeoutInfo.clear(); + } + + /** + * Gets the number of active timeouts. + */ + get size(): number { + return this._timeoutInfo.size; + } +} diff --git a/packages/core/src/shared/transport.ts b/packages/core/src/shared/transport.ts index 87608f124..844445c7b 100644 --- a/packages/core/src/shared/transport.ts +++ b/packages/core/src/shared/transport.ts @@ -1,5 +1,29 @@ import type { JSONRPCMessage, MessageExtraInfo, RequestId } from '../types/types.js'; +// ═══════════════════════════════════════════════════════════════════════════ +// Connection State +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Represents the current state of a transport connection. + * + * State transitions: + * - disconnected β†’ connecting β†’ connected + * - connected β†’ reconnecting β†’ connected + * - * β†’ error (from any state on unrecoverable error) + * - * β†’ disconnected (on close) + */ +export type ConnectionState = 'disconnected' | 'connecting' | 'authenticating' | 'connected' | 'reconnecting' | 'error'; + +/** + * Callback for connection state changes + */ +export type ConnectionStateChangeCallback = (state: ConnectionState, previousState: ConnectionState) => void; + +// ═══════════════════════════════════════════════════════════════════════════ +// Fetch Utilities +// ═══════════════════════════════════════════════════════════════════════════ + export type FetchLike = (url: string | URL, init?: RequestInit) => Promise; /** @@ -54,6 +78,11 @@ export type TransportSendOptions = { */ relatedRequestId?: RequestId; + /** + * Optional session ID for the message routing context. + */ + sessionId?: string; + /** * The resumption token used to continue long-running requests that were interrupted. * @@ -125,4 +154,23 @@ export interface Transport { * Sets the protocol version used for the connection (called when the initialize response is received). */ setProtocolVersion?: (version: string) => void; + + // ─── Connection State (optional, for transports that support it) ─── + + /** + * The current connection state. + * Optional - not all transports track state. + */ + readonly state?: ConnectionState; + + /** + * Whether the transport is currently connected. + * This is a convenience property equivalent to `state === 'connected'`. + */ + readonly isConnected?: boolean; + + /** + * Callback for when the connection state changes. + */ + onStateChange?: ConnectionStateChangeCallback; } diff --git a/packages/core/src/types/types.ts b/packages/core/src/types/types.ts index d3e404c58..a8c08f4f8 100644 --- a/packages/core/src/types/types.ts +++ b/packages/core/src/types/types.ts @@ -2302,48 +2302,9 @@ export const ServerResultSchema = z.union([ CreateTaskResultSchema ]); -export class McpError extends Error { - constructor( - public readonly code: number, - message: string, - public readonly data?: unknown - ) { - super(`MCP error ${code}: ${message}`); - this.name = 'McpError'; - } - - /** - * Factory method to create the appropriate error type based on the error code and data - */ - static fromError(code: number, message: string, data?: unknown): McpError { - // Check for specific error types - if (code === ErrorCode.UrlElicitationRequired && data) { - const errorData = data as { elicitations?: unknown[] }; - if (errorData.elicitations) { - return new UrlElicitationRequiredError(errorData.elicitations as ElicitRequestURLParams[], message); - } - } - - // Default to generic McpError - return new McpError(code, message, data); - } -} - -/** - * Specialized error type when a tool requires a URL mode elicitation. - * This makes it nicer for the client to handle since there is specific data to work with instead of just a code to check against. - */ -export class UrlElicitationRequiredError extends McpError { - constructor(elicitations: ElicitRequestURLParams[], message: string = `URL elicitation${elicitations.length > 1 ? 's' : ''} required`) { - super(ErrorCode.UrlElicitationRequired, message, { - elicitations: elicitations - }); - } - - get elicitations(): ElicitRequestURLParams[] { - return (this.data as { elicitations: ElicitRequestURLParams[] })?.elicitations ?? []; - } -} +// Note: McpError has been removed. Use ProtocolError from '../errors.js' instead. +// ProtocolError is for errors with locked codes (SDK-generated or user-intentional). +// For customizable errors, throw a plain Error and use the onError handler. type Primitive = string | number | boolean | bigint | null | undefined; type Flatten = T extends Primitive diff --git a/packages/core/src/util/content.ts b/packages/core/src/util/content.ts new file mode 100644 index 000000000..e1ac5fa30 --- /dev/null +++ b/packages/core/src/util/content.ts @@ -0,0 +1,211 @@ +/** + * Content Formatting Helpers + * + * Utilities for working with tool call results and content types. + * Reduces boilerplate when processing mixed content types in results. + */ + +import type { + AudioContent, + BlobResourceContents, + ContentBlock, + EmbeddedResource, + ImageContent, + ResourceLink, + TextContent, + TextResourceContents +} from '../types/types.js'; + +/** + * Type guard to check if content is TextContent + */ +export function isTextContent(item: ContentBlock): item is TextContent { + return item.type === 'text'; +} + +/** + * Type guard to check if content is ImageContent + */ +export function isImageContent(item: ContentBlock): item is ImageContent { + return item.type === 'image'; +} + +/** + * Type guard to check if content is AudioContent + */ +export function isAudioContent(item: ContentBlock): item is AudioContent { + return item.type === 'audio'; +} + +/** + * Type guard to check if content is EmbeddedResource + */ +export function isEmbeddedResource(item: ContentBlock): item is EmbeddedResource { + return item.type === 'resource'; +} + +/** + * Type guard to check if content is ResourceLink + */ +export function isResourceLink(item: ContentBlock): item is ResourceLink { + return item.type === 'resource_link'; +} + +/** + * Extracts all text content from a tool result content array. + * + * @example + * ```typescript + * const result = await client.callTool('search', { query: 'hello' }); + * const texts = extractTextContent(result.content); + * console.log(texts.join('\n')); + * ``` + */ +export function extractTextContent(content: ContentBlock[]): string[] { + return content.filter(item => isTextContent(item)).map(item => item.text); +} + +/** + * Formats all text content from a tool result as a single string. + * + * @param content - The content array from a tool result + * @param separator - Separator between text items (default: newline) + * @returns Concatenated text content + * + * @example + * ```typescript + * const result = await client.callTool('search', { query: 'hello' }); + * const text = formatTextContent(result.content); + * ``` + */ +export function formatTextContent(content: ContentBlock[], separator: string = '\n'): string { + return extractTextContent(content).join(separator); +} + +/** + * Extracts all image content from a tool result content array. + */ +export function extractImageContent(content: ContentBlock[]): ImageContent[] { + return content.filter(item => isImageContent(item)); +} + +/** + * Extracts all audio content from a tool result content array. + */ +export function extractAudioContent(content: ContentBlock[]): AudioContent[] { + return content.filter(item => isAudioContent(item)); +} + +/** + * Extracts all embedded resources from a tool result content array. + */ +export function extractEmbeddedResources(content: ContentBlock[]): EmbeddedResource[] { + return content.filter(item => isEmbeddedResource(item)); +} + +/** + * Extracts all resource links from a tool result content array. + */ +export function extractResourceLinks(content: ContentBlock[]): ResourceLink[] { + return content.filter(item => isResourceLink(item)); +} + +/** + * Creates a text content item. + * + * @example + * ```typescript + * return { content: [text('Hello, world!')] }; + * ``` + */ +export function text(content: string, annotations?: TextContent['annotations']): TextContent { + return { + type: 'text', + text: content, + annotations + }; +} + +/** + * Creates an image content item from base64 data. + * + * @example + * ```typescript + * return { content: [image(base64Data, 'image/png')] }; + * ``` + */ +export function image(data: string, mimeType: string, annotations?: ImageContent['annotations']): ImageContent { + return { + type: 'image', + data, + mimeType, + annotations + }; +} + +/** + * Creates an audio content item from base64 data. + * + * @example + * ```typescript + * return { content: [audio(base64Data, 'audio/wav')] }; + * ``` + */ +export function audio(data: string, mimeType: string, annotations?: AudioContent['annotations']): AudioContent { + return { + type: 'audio', + data, + mimeType, + annotations + }; +} + +/** + * Creates an embedded resource content item. + * + * @example + * ```typescript + * return { + * content: [ + * embeddedResource({ + * uri: 'file:///path/to/file.txt', + * mimeType: 'text/plain', + * text: 'File contents' + * }) + * ] + * }; + * ``` + */ +export function embeddedResource( + resource: TextResourceContents | BlobResourceContents, + annotations?: EmbeddedResource['annotations'] +): EmbeddedResource { + return { + type: 'resource', + resource, + annotations + }; +} + +/** + * Creates a resource link content item. + * + * @example + * ```typescript + * return { + * content: [ + * resourceLink({ + * uri: 'file:///path/to/file.txt', + * mimeType: 'text/plain', + * name: 'file.txt' + * }) + * ] + * }; + * ``` + */ +export function resourceLink(link: Omit): ResourceLink { + return { + type: 'resource_link', + ...link + }; +} diff --git a/packages/core/src/util/inMemory.ts b/packages/core/src/util/inMemory.ts index 3f832b06b..9e541b053 100644 --- a/packages/core/src/util/inMemory.ts +++ b/packages/core/src/util/inMemory.ts @@ -1,3 +1,4 @@ +import { StateError } from '../errors.js'; import type { Transport } from '../shared/transport.js'; import type { AuthInfo, JSONRPCMessage, RequestId } from '../types/types.js'; @@ -50,7 +51,7 @@ export class InMemoryTransport implements Transport { */ async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId; authInfo?: AuthInfo }): Promise { if (!this._otherTransport) { - throw new Error('Not connected'); + throw StateError.notConnected('send message'); } if (this._otherTransport.onmessage) { diff --git a/packages/core/src/util/zodJsonSchemaCompat.ts b/packages/core/src/util/zodJsonSchemaCompat.ts index 12e5e88c4..d144a6941 100644 --- a/packages/core/src/util/zodJsonSchemaCompat.ts +++ b/packages/core/src/util/zodJsonSchemaCompat.ts @@ -9,6 +9,7 @@ import type * as z4c from 'zod/v4/core'; import * as z4mini from 'zod/v4-mini'; import { zodToJsonSchema } from 'zod-to-json-schema'; +import { ValidationError } from '../errors.js'; import type { AnyObjectSchema, AnySchema } from './zodCompat.js'; import { getLiteralValue, getObjectShape, isZ4Schema, safeParse } from './zodCompat.js'; @@ -48,7 +49,7 @@ export function getMethodLiteral(schema: AnyObjectSchema): string { const shape = getObjectShape(schema); const methodSchema = shape?.method as AnySchema | undefined; if (!methodSchema) { - throw new Error('Schema is missing a method literal'); + throw ValidationError.invalidSchema('Schema is missing a method literal'); } const value = getLiteralValue(methodSchema); diff --git a/packages/core/test/shared/protocol.test.ts b/packages/core/test/shared/protocol.test.ts index ab657f0dc..4c76d8fee 100644 --- a/packages/core/test/shared/protocol.test.ts +++ b/packages/core/test/shared/protocol.test.ts @@ -11,6 +11,7 @@ import type { TaskStore } from '../../src/experimental/tasks/interfaces.js'; import { InMemoryTaskMessageQueue } from '../../src/experimental/tasks/stores/inMemory.js'; +import type { BaseRequestContext, ContextInterface } from '../../src/shared/context.js'; import { mergeCapabilities, Protocol } from '../../src/shared/protocol.js'; import type { ErrorMessage, ResponseMessage } from '../../src/shared/responseMessage.js'; import { toArrayAsync } from '../../src/shared/responseMessage.js'; @@ -21,6 +22,7 @@ import type { JSONRPCMessage, JSONRPCRequest, JSONRPCResultResponse, + MessageExtraInfo, Notification, Request, RequestId, @@ -29,7 +31,8 @@ import type { Task, TaskCreationParams } from '../../src/types/types.js'; -import { CallToolRequestSchema, ErrorCode, McpError, RELATED_TASK_META_KEY } from '../../src/types/types.js'; +import { CallToolRequestSchema, ErrorCode, RELATED_TASK_META_KEY } from '../../src/types/types.js'; +import { ProtocolError } from '../../src/errors.js'; // Type helper for accessing private/protected Protocol properties in tests interface TestProtocol { @@ -38,7 +41,7 @@ interface TestProtocol { _responseHandlers: Map void>; _taskProgressTokens: Map; _clearTaskQueue: (taskId: string, sessionId?: string) => Promise; - requestTaskStore: (request: Request, authInfo: unknown) => TaskStore; + _cleanupTaskProgressHandler: (taskId: string) => void; // Protected task methods (exposed for testing) listTasks: (params?: { cursor?: string }) => Promise<{ tasks: Task[]; nextCursor?: string }>; cancelTask: (params: { taskId: string }) => Promise; @@ -66,7 +69,7 @@ function createMockTaskStore(options?: { return { createTask: vi.fn((taskParams: TaskCreationParams, _1: RequestId, _2: Request) => { // Generate a unique task ID - const taskId = `test-task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + const taskId = `test-task-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`; const createdAt = new Date().toISOString(); const task = (tasks[taskId] = { taskId, @@ -147,6 +150,93 @@ function assertQueuedRequest(o?: QueuedMessage): asserts o is QueuedRequest { expect(o?.type).toBe('request'); } +/** + * Creates a mock ContextInterface for testing. + * This provides a minimal implementation of the context interface. + */ +function createMockContext(args: { + request: JSONRPCRequest; + abortController: AbortController; + sessionId?: string; +}): ContextInterface { + return { + mcpCtx: { + requestId: args.request.id, + method: args.request.method, + _meta: args.request.params?._meta, + sessionId: args.sessionId + }, + requestCtx: { + signal: args.abortController.signal, + authInfo: undefined + }, + taskCtx: undefined, + sendNotification: async () => {}, + sendRequest: async () => ({}) as never + }; +} + +/** + * Creates a mock Protocol class for testing with all abstract methods implemented. + */ +function createTestProtocolClass(options?: { + taskStore?: TaskStore; + taskMessageQueue?: TaskMessageQueue; + debouncedNotificationMethods?: string[]; + maxTaskQueueSize?: number; + defaultTaskPollInterval?: number; +}) { + return class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + protected createRequestContext(args: { + request: JSONRPCRequest; + taskStore: TaskStore | undefined; + relatedTaskId: string | undefined; + taskCreationParams: TaskCreationParams | undefined; + abortController: AbortController; + capturedTransport: Transport | undefined; + extra?: MessageExtraInfo; + }): ContextInterface { + // Create a context that properly delegates to the protocol + const mcpCtx = { + requestId: args.request.id, + method: args.request.method, + _meta: args.request.params?._meta, + sessionId: args.capturedTransport?.sessionId + }; + + return { + mcpCtx, + requestCtx: { + signal: args.abortController.signal, + authInfo: undefined + }, + taskCtx: undefined, + sendNotification: async (notification: Notification) => { + // Properly delegate to the protocol's notification method with relatedTask metadata + const notificationOptions: { relatedRequestId?: RequestId; relatedTask?: { taskId: string } } = { + relatedRequestId: mcpCtx.requestId + }; + // Extract relatedTask from the original request's _meta if present + if (args.relatedTaskId) { + notificationOptions.relatedTask = { taskId: args.relatedTaskId }; + } + await this.notification(notification, notificationOptions); + }, + sendRequest: async () => ({}) as never + }; + } + + constructor() { + super(options); + } + }; +} + describe('protocol tests', () => { let protocol: Protocol; let transport: MockTransport; @@ -155,13 +245,7 @@ describe('protocol tests', () => { beforeEach(() => { transport = new MockTransport(); sendSpy = vi.spyOn(transport, 'send'); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + protocol = new (createTestProtocolClass())(); }); test('should throw a timeout error if the request exceeds the timeout', async () => { @@ -175,8 +259,8 @@ describe('protocol tests', () => { timeout: 0 }); } catch (error) { - expect(error).toBeInstanceOf(McpError); - if (error instanceof McpError) { + expect(error).toBeInstanceOf(ProtocolError); + if (error instanceof ProtocolError) { expect(error.code).toBe(ErrorCode.RequestTimeout); } } @@ -629,13 +713,7 @@ describe('protocol tests', () => { it('should NOT debounce a notification that has parameters', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced_with_params'] }); + protocol = new (createTestProtocolClass({ debouncedNotificationMethods: ['test/debounced_with_params'] }))(); await protocol.connect(transport); // ACT @@ -652,13 +730,7 @@ describe('protocol tests', () => { it('should NOT debounce a notification that has a relatedRequestId', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced_with_options'] }); + protocol = new (createTestProtocolClass({ debouncedNotificationMethods: ['test/debounced_with_options'] }))(); await protocol.connect(transport); // ACT @@ -673,13 +745,7 @@ describe('protocol tests', () => { it('should clear pending debounced notifications on connection close', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); + protocol = new (createTestProtocolClass({ debouncedNotificationMethods: ['test/debounced'] }))(); await protocol.connect(transport); // ACT @@ -699,13 +765,7 @@ describe('protocol tests', () => { it('should debounce multiple synchronous calls when params property is omitted', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); + protocol = new (createTestProtocolClass({ debouncedNotificationMethods: ['test/debounced'] }))(); await protocol.connect(transport); // ACT @@ -728,13 +788,7 @@ describe('protocol tests', () => { it('should debounce calls when params is explicitly undefined', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); + protocol = new (createTestProtocolClass({ debouncedNotificationMethods: ['test/debounced'] }))(); await protocol.connect(transport); // ACT @@ -755,13 +809,7 @@ describe('protocol tests', () => { it('should send non-debounced notifications immediately and multiple times', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); // Configure for a different method + protocol = new (createTestProtocolClass({ debouncedNotificationMethods: ['test/debounced'] }))(); // Configure for a different method await protocol.connect(transport); // ACT @@ -790,13 +838,7 @@ describe('protocol tests', () => { it('should handle sequential batches of debounced notifications correctly', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); + protocol = new (createTestProtocolClass({ debouncedNotificationMethods: ['test/debounced'] }))(); await protocol.connect(transport); // ACT (Batch 1) @@ -1009,13 +1051,7 @@ describe('Task-based execution', () => { beforeEach(() => { transport = new MockTransport(); sendSpy = vi.spyOn(transport, 'send'); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: createMockTaskStore(), taskMessageQueue: new InMemoryTaskMessageQueue() }); + protocol = new (createTestProtocolClass({ taskStore: createMockTaskStore(), taskMessageQueue: new InMemoryTaskMessageQueue() }))(); }); describe('request with task metadata', () => { @@ -1034,7 +1070,7 @@ describe('Task-based execution', () => { void protocol .request(request, resultSchema, { task: { - ttl: 30000, + ttl: 30_000, pollInterval: 1000 } }) @@ -1048,7 +1084,7 @@ describe('Task-based execution', () => { params: { name: 'test-tool', task: { - ttl: 30000, + ttl: 30_000, pollInterval: 1000 } } @@ -1077,7 +1113,7 @@ describe('Task-based execution', () => { void protocol .request(request, resultSchema, { task: { - ttl: 60000 + ttl: 60_000 } }) .catch(() => { @@ -1092,7 +1128,7 @@ describe('Task-based execution', () => { customField: 'customValue' }, task: { - ttl: 60000 + ttl: 60_000 } } }), @@ -1114,7 +1150,7 @@ describe('Task-based execution', () => { const resultPromise = protocol.request(request, resultSchema, { task: { - ttl: 30000 + ttl: 30_000 } }); @@ -1204,7 +1240,7 @@ describe('Task-based execution', () => { void protocol .request(request, resultSchema, { task: { - ttl: 60000, + ttl: 60_000, pollInterval: 1000 }, relatedTask: { @@ -1232,7 +1268,7 @@ describe('Task-based execution', () => { expect(queuedMessage.message.params).toMatchObject({ name: 'test-tool', task: { - ttl: 60000, + ttl: 60_000, pollInterval: 1000 }, _meta: { @@ -1256,20 +1292,14 @@ describe('Task-based execution', () => { // rather than in _meta, and that task management is handled by tool implementors const mockTaskStore = createMockTaskStore(); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + protocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); await protocol.connect(transport); protocol.setRequestHandler(CallToolRequestSchema, async request => { // Tool implementor can access task creation parameters from request.params.task expect(request.params.task).toEqual({ - ttl: 60000, + ttl: 60_000, pollInterval: 1000 }); return { result: 'success' }; @@ -1283,7 +1313,7 @@ describe('Task-based execution', () => { name: 'test', arguments: {}, task: { - ttl: 60000, + ttl: 60_000, pollInterval: 1000 } } @@ -1315,7 +1345,7 @@ describe('Task-based execution', () => { const task2 = await mockTaskStore.createTask( { - ttl: 60000, + ttl: 60_000, pollInterval: 1000 }, 2, @@ -1325,13 +1355,7 @@ describe('Task-based execution', () => { } ); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + protocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); await protocol.connect(transport); @@ -1361,7 +1385,7 @@ describe('Task-based execution', () => { { taskId: task2.taskId, status: 'working', - ttl: 60000, + ttl: 60_000, createdAt: expect.any(String), lastUpdatedAt: expect.any(String), pollInterval: 1000 @@ -1386,13 +1410,7 @@ describe('Task-based execution', () => { } ); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + protocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); await protocol.connect(transport); @@ -1432,13 +1450,7 @@ describe('Task-based execution', () => { onList: () => listedTasks.releaseLatch() }); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + protocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); await protocol.connect(transport); @@ -1465,13 +1477,7 @@ describe('Task-based execution', () => { const mockTaskStore = createMockTaskStore(); mockTaskStore.listTasks.mockRejectedValue(new Error('Invalid cursor: bad-cursor')); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + protocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); await protocol.connect(transport); @@ -1492,7 +1498,7 @@ describe('Task-based execution', () => { expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(4); expect(sentMessage.error).toBeDefined(); - expect(sentMessage.error.code).toBe(-32602); // InvalidParams error code + expect(sentMessage.error.code).toBe(-32_602); // InvalidParams error code expect(sentMessage.error.message).toContain('Failed to list tasks'); expect(sentMessage.error.message).toContain('Invalid cursor'); }); @@ -1552,7 +1558,7 @@ describe('Task-based execution', () => { { taskId: 'task-11', status: 'working', - ttl: 30000, + ttl: 30_000, createdAt: '2024-01-01T00:00:00Z', lastUpdatedAt: '2024-01-01T00:00:00Z', pollInterval: 1000 @@ -1599,13 +1605,7 @@ describe('Task-based execution', () => { throw new Error('Task not found'); }); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1641,13 +1641,7 @@ describe('Task-based execution', () => { mockTaskStore.getTask.mockResolvedValue(null); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1671,7 +1665,7 @@ describe('Task-based execution', () => { expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(6); expect(sentMessage.error).toBeDefined(); - expect(sentMessage.error.code).toBe(-32602); // InvalidParams error code + expect(sentMessage.error.code).toBe(-32_602); // InvalidParams error code expect(sentMessage.error.message).toContain('Task not found'); }); @@ -1689,13 +1683,7 @@ describe('Task-based execution', () => { mockTaskStore.updateTaskStatus.mockClear(); mockTaskStore.getTask.mockResolvedValue(completedTask); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1719,7 +1707,7 @@ describe('Task-based execution', () => { expect(sentMessage.jsonrpc).toBe('2.0'); expect(sentMessage.id).toBe(7); expect(sentMessage.error).toBeDefined(); - expect(sentMessage.error.code).toBe(-32602); // InvalidParams error code + expect(sentMessage.error.code).toBe(-32_602); // InvalidParams error code expect(sentMessage.error.message).toContain('Cannot cancel task in terminal status'); }); @@ -1737,7 +1725,7 @@ describe('Task-based execution', () => { _meta: {}, taskId: 'task-to-delete', status: 'cancelled', - ttl: 60000, + ttl: 60_000, createdAt: new Date().toISOString(), lastUpdatedAt: new Date().toISOString() } @@ -1771,13 +1759,7 @@ describe('Task-based execution', () => { params: {} }); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); const serverTransport = new MockTransport(); await serverProtocol.connect(serverTransport); @@ -1806,7 +1788,7 @@ describe('Task-based execution', () => { // Verify that getTask was called after updateTaskStatus // This is done by the RequestTaskStore wrapper to get the updated task for the notification const getTaskCalls = mockTaskStore.getTask.mock.calls; - const lastGetTaskCall = getTaskCalls[getTaskCalls.length - 1]; + const lastGetTaskCall = getTaskCalls.at(-1); expect(lastGetTaskCall?.[0]).toBe(task.taskId); }); }); @@ -1821,13 +1803,7 @@ describe('Task-based execution', () => { params: {} }); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1870,13 +1846,7 @@ describe('Task-based execution', () => { params: {} }); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1907,13 +1877,7 @@ describe('Task-based execution', () => { params: {} }); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1952,13 +1916,7 @@ describe('Task-based execution', () => { await mockTaskStore.storeTaskResult(task.taskId, 'completed', testResult); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1995,13 +1953,10 @@ describe('Task-based execution', () => { it('should propagate related-task metadata to handler sendRequest and sendNotification', async () => { const mockTaskStore = createMockTaskStore(); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const serverProtocol = new (createTestProtocolClass({ + taskStore: mockTaskStore, + taskMessageQueue: new InMemoryTaskMessageQueue() + }))(); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -2009,9 +1964,9 @@ describe('Task-based execution', () => { await serverProtocol.connect(serverTransport); // Set up a handler that uses sendRequest and sendNotification - serverProtocol.setRequestHandler(CallToolRequestSchema, async (_request, extra) => { - // Send a notification using the extra.sendNotification - await extra.sendNotification({ + serverProtocol.setRequestHandler(CallToolRequestSchema, async (_request, ctx) => { + // Send a notification using the ctx.sendNotification + await ctx.sendNotification({ method: 'notifications/message', params: { level: 'info', data: 'test' } }); @@ -2078,13 +2033,7 @@ describe('Request Cancellation vs Task Cancellation', () => { beforeEach(() => { transport = new MockTransport(); taskStore = createMockTaskStore(); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + protocol = new (createTestProtocolClass({ taskStore }))(); }); describe('notifications/cancelled behavior', () => { @@ -2097,10 +2046,10 @@ describe('Request Cancellation vs Task Cancellation', () => { method: z.literal('test/longRunning'), params: z.optional(z.record(z.string(), z.unknown())) }); - protocol.setRequestHandler(TestRequestSchema, async (_request, extra) => { + protocol.setRequestHandler(TestRequestSchema, async (_request, ctx) => { // Simulate a long-running operation await new Promise(resolve => setTimeout(resolve, 100)); - wasAborted = extra.signal.aborted; + wasAborted = ctx.requestCtx.signal.aborted; return { _meta: {} } as Result; }); @@ -2141,7 +2090,7 @@ describe('Request Cancellation vs Task Cancellation', () => { await protocol.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + const task = await taskStore.createTask({ ttl: 60_000 }, 'req-1', { method: 'test/method', params: {} }); @@ -2173,7 +2122,7 @@ describe('Request Cancellation vs Task Cancellation', () => { await protocol.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + const task = await taskStore.createTask({ ttl: 60_000 }, 'req-1', { method: 'test/method', params: {} }); @@ -2207,7 +2156,7 @@ describe('Request Cancellation vs Task Cancellation', () => { const sendSpy = vi.spyOn(transport, 'send'); // Create a task and mark it as completed - const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + const task = await taskStore.createTask({ ttl: 60_000 }, 'req-1', { method: 'test/method', params: {} }); @@ -2279,7 +2228,7 @@ describe('Request Cancellation vs Task Cancellation', () => { await protocol.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + const task = await taskStore.createTask({ ttl: 60_000 }, 'req-1', { method: 'test/method', params: {} }); @@ -2319,7 +2268,7 @@ describe('Request Cancellation vs Task Cancellation', () => { }); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + const task = await taskStore.createTask({ ttl: 60_000 }, 'req-1', { method: 'test/method', params: {} }); @@ -2371,24 +2320,12 @@ describe('Progress notification support for tasks', () => { beforeEach(() => { transport = new MockTransport(); sendSpy = vi.spyOn(transport, 'send'); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + protocol = new (createTestProtocolClass())(); }); it('should maintain progress token association after CreateTaskResult is returned', async () => { const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const protocol = new (createTestProtocolClass({ taskStore }))(); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -2412,7 +2349,7 @@ describe('Progress notification support for tasks', () => { // Start a task-augmented request with progress callback void protocol .request(request, resultSchema, { - task: { ttl: 60000 }, + task: { ttl: 60_000 }, onprogress: progressCallback }) .catch(() => { @@ -2439,7 +2376,7 @@ describe('Progress notification support for tasks', () => { task: { taskId, status: 'working', - ttl: 60000, + ttl: 60_000, createdAt: new Date().toISOString() } } @@ -2475,26 +2412,20 @@ describe('Progress notification support for tasks', () => { it('should stop progress notifications when task reaches terminal status (completed)', async () => { const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const protocol = new (createTestProtocolClass({ taskStore }))(); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); await protocol.connect(transport); // Set up a request handler that will complete the task - protocol.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - if (extra.taskStore) { - const task = await extra.taskStore.createTask({ ttl: 60000 }); + protocol.setRequestHandler(CallToolRequestSchema, async (request, ctx) => { + if (ctx.taskCtx?.store) { + const task = await ctx.taskCtx.store.createTask({ ttl: 60_000 }); // Simulate async work then complete the task setTimeout(async () => { - await extra.taskStore!.storeTaskResult(task.taskId, 'completed', { + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: 'Done' }] }); }, 50); @@ -2522,7 +2453,7 @@ describe('Progress notification support for tasks', () => { // Start a task-augmented request with progress callback void protocol .request(request, resultSchema, { - task: { ttl: 60000 }, + task: { ttl: 60_000 }, onprogress: progressCallback }) .catch(() => { @@ -2537,7 +2468,7 @@ describe('Progress notification support for tasks', () => { const progressToken = sentRequest.params._meta.progressToken; // Create a task in the mock store first so it exists when we try to get it later - const createdTask = await taskStore.createTask({ ttl: 60000 }, messageId, request); + const createdTask = await taskStore.createTask({ ttl: 60_000 }, messageId, request); const taskId = createdTask.taskId; // Simulate CreateTaskResult response @@ -2576,11 +2507,12 @@ describe('Progress notification support for tasks', () => { expect(taskProgressTokens.has(taskId)).toBe(true); expect(taskProgressTokens.get(taskId)).toBe(progressToken); - // Simulate task completion by calling through the protocol's task store - // This will trigger the cleanup logic - const mockRequest = { jsonrpc: '2.0' as const, id: 999, method: 'test', params: {} }; - const requestTaskStore = (protocol as unknown as TestProtocol).requestTaskStore(mockRequest, undefined); - await requestTaskStore.storeTaskResult(taskId, 'completed', { content: [] }); + // Simulate task completion by updating task store and triggering cleanup + // First update the task status in the store + await taskStore.storeTaskResult(taskId, 'completed', { content: [] }); + + // Then manually trigger the cleanup (in real usage, this happens via tasks/result polling) + (protocol as unknown as TestProtocol)._cleanupTaskProgressHandler(taskId); // Wait for all async operations including notification sending to complete await new Promise(resolve => setTimeout(resolve, 50)); @@ -2610,13 +2542,7 @@ describe('Progress notification support for tasks', () => { it('should stop progress notifications when task reaches terminal status (failed)', async () => { const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const protocol = new (createTestProtocolClass({ taskStore }))(); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -2638,7 +2564,7 @@ describe('Progress notification support for tasks', () => { }); void protocol.request(request, resultSchema, { - task: { ttl: 60000 }, + task: { ttl: 60_000 }, onprogress: progressCallback }); @@ -2656,7 +2582,7 @@ describe('Progress notification support for tasks', () => { task: { taskId, status: 'working', - ttl: 60000, + ttl: 60_000, createdAt: new Date().toISOString() } } @@ -2679,7 +2605,7 @@ describe('Progress notification support for tasks', () => { params: { taskId, status: 'failed', - ttl: 60000, + ttl: 60_000, createdAt: new Date().toISOString(), lastUpdatedAt: new Date().toISOString(), statusMessage: 'Task failed' @@ -2708,13 +2634,7 @@ describe('Progress notification support for tasks', () => { it('should stop progress notifications when task is cancelled', async () => { const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const protocol = new (createTestProtocolClass({ taskStore }))(); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -2736,7 +2656,7 @@ describe('Progress notification support for tasks', () => { }); void protocol.request(request, resultSchema, { - task: { ttl: 60000 }, + task: { ttl: 60_000 }, onprogress: progressCallback }); @@ -2754,7 +2674,7 @@ describe('Progress notification support for tasks', () => { task: { taskId, status: 'working', - ttl: 60000, + ttl: 60_000, createdAt: new Date().toISOString() } } @@ -2774,7 +2694,7 @@ describe('Progress notification support for tasks', () => { params: { taskId, status: 'cancelled', - ttl: 60000, + ttl: 60_000, createdAt: new Date().toISOString(), lastUpdatedAt: new Date().toISOString(), statusMessage: 'User cancelled' @@ -2803,13 +2723,7 @@ describe('Progress notification support for tasks', () => { it('should use the same progressToken throughout task lifetime', async () => { const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const protocol = new (createTestProtocolClass({ taskStore }))(); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -2831,7 +2745,7 @@ describe('Progress notification support for tasks', () => { }); void protocol.request(request, resultSchema, { - task: { ttl: 60000 }, + task: { ttl: 60_000 }, onprogress: progressCallback }); @@ -2849,7 +2763,7 @@ describe('Progress notification support for tasks', () => { task: { taskId, status: 'working', - ttl: 60000, + ttl: 60_000, createdAt: new Date().toISOString() } } @@ -2903,7 +2817,7 @@ describe('Progress notification support for tasks', () => { void protocol.request(request, resultSchema, { task: { - ttl: 60000 + ttl: 60_000 }, onprogress: onProgressMock }); @@ -2928,7 +2842,7 @@ describe('Progress notification support for tasks', () => { void protocol.request(request, resultSchema, { task: { - ttl: 30000 + ttl: 30_000 }, onprogress: onProgressMock }); @@ -2978,7 +2892,7 @@ describe('Progress notification support for tasks', () => { void protocol.request(request, resultSchema, { task: { - ttl: 30000 + ttl: 30_000 }, onprogress: onProgressMock }); @@ -2995,7 +2909,7 @@ describe('Progress notification support for tasks', () => { task: { taskId: 'task-123', status: 'working', - ttl: 30000, + ttl: 30_000, createdAt: new Date().toISOString() } } @@ -3072,18 +2986,12 @@ describe('Message interception for task-related notifications', () => { it('should queue notifications with io.modelcontextprotocol/related-task metadata', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); await server.connect(transport); // Create a task first - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 'test-request-1', { method: 'tools/call', params: {} }); // Send a notification with related task metadata await server.notification( @@ -3109,13 +3017,7 @@ describe('Message interception for task-related notifications', () => { it('should not queue notifications without related-task metadata', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); await server.connect(transport); @@ -3136,18 +3038,16 @@ describe('Message interception for task-related notifications', () => { it('should propagate queue overflow errors without failing the task', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 }); + const server = new (createTestProtocolClass({ + taskStore, + taskMessageQueue: new InMemoryTaskMessageQueue(), + maxTaskQueueSize: 100 + }))(); await server.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 'test-request-1', { method: 'tools/call', params: {} }); // Fill the queue to max capacity (100 messages) for (let i = 0; i < 100; i++) { @@ -3183,13 +3083,7 @@ describe('Message interception for task-related notifications', () => { it('should extract task ID correctly from metadata', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); await server.connect(transport); @@ -3216,18 +3110,12 @@ describe('Message interception for task-related notifications', () => { it('should preserve message order when queuing multiple notifications', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); await server.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 'test-request-1', { method: 'tools/call', params: {} }); // Send multiple notifications for (let i = 0; i < 5; i++) { @@ -3258,18 +3146,12 @@ describe('Message interception for task-related requests', () => { it('should queue requests with io.modelcontextprotocol/related-task metadata', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); await server.connect(transport); // Create a task first - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 'test-request-1', { method: 'tools/call', params: {} }); // Send a request with related task metadata (don't await - we're testing queuing) const requestPromise = server.request( @@ -3310,13 +3192,7 @@ describe('Message interception for task-related requests', () => { it('should not queue requests without related-task metadata', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); await server.connect(transport); @@ -3349,18 +3225,12 @@ describe('Message interception for task-related requests', () => { it('should store request resolver for response routing', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); await server.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 'test-request-1', { method: 'tools/call', params: {} }); // Send a request with related task metadata const requestPromise = server.request( @@ -3402,18 +3272,12 @@ describe('Message interception for task-related requests', () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); const queue = new InMemoryTaskMessageQueue(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: queue }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: queue }))(); await server.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 'test-request-1', { method: 'tools/call', params: {} }); // Send a request with related task metadata const requestPromise = server.request( @@ -3467,13 +3331,7 @@ describe('Message interception for task-related requests', () => { it('should log error when resolver is missing for side-channeled request', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = new (createTestProtocolClass({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); const errors: Error[] = []; server.onerror = (error: Error) => { @@ -3483,7 +3341,7 @@ describe('Message interception for task-related requests', () => { await server.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 'test-request-1', { method: 'tools/call', params: {} }); // Send a request with related task metadata void server.request( @@ -3543,18 +3401,16 @@ describe('Message interception for task-related requests', () => { it('should propagate queue overflow errors for requests without failing the task', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 }); + const server = new (createTestProtocolClass({ + taskStore, + taskMessageQueue: new InMemoryTaskMessageQueue(), + maxTaskQueueSize: 100 + }))(); await server.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 'test-request-1', { method: 'tools/call', params: {} }); // Fill the queue to max capacity (100 messages) const promises: Promise[] = []; @@ -3604,13 +3460,7 @@ describe('Message Interception', () => { beforeEach(() => { transport = new MockTransport(); mockTaskStore = createMockTaskStore(); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + protocol = new (createTestProtocolClass({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); }); describe('messages with relatedTask metadata are queued', () => { @@ -3743,7 +3593,7 @@ describe('Message Interception', () => { }); protocol.setRequestHandler(TestRequestSchema, async () => { - throw new McpError(ErrorCode.InternalError, 'Test error message'); + throw new ProtocolError(ErrorCode.InternalError, 'Test error message'); }); // Simulate an incoming request with relatedTask metadata @@ -4145,13 +3995,7 @@ describe('Queue lifecycle management', () => { beforeEach(() => { transport = new MockTransport(); mockTaskStore = createMockTaskStore(); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + protocol = new (createTestProtocolClass({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }))(); }); describe('queue cleanup on task completion', () => { @@ -4269,7 +4113,7 @@ describe('Queue lifecycle management', () => { .request({ method: 'test/request', params: { data: 'test' } }, z.object({ result: z.string() }), { relatedTask: { taskId } }) - .catch(err => err); + .catch(error => error); // Verify request is queued const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; @@ -4291,7 +4135,7 @@ describe('Queue lifecycle management', () => { // Verify the request promise is rejected const result = await requestPromise; - expect(result).toBeInstanceOf(McpError); + expect(result).toBeInstanceOf(ProtocolError); expect(result.message).toContain('Task cancelled or completed'); // Verify queue is cleared (no messages available) @@ -4342,7 +4186,7 @@ describe('Queue lifecycle management', () => { .request({ method: 'test/request', params: { data: 'test' } }, z.object({ result: z.string() }), { relatedTask: { taskId } }) - .catch(err => err); + .catch(error => error); // Verify request is queued const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; @@ -4353,7 +4197,7 @@ describe('Queue lifecycle management', () => { // Verify the request promise is rejected const result = await requestPromise; - expect(result).toBeInstanceOf(McpError); + expect(result).toBeInstanceOf(ProtocolError); expect(result.message).toContain('Task cancelled or completed'); // Verify queue is cleared (no messages available) @@ -4375,19 +4219,19 @@ describe('Queue lifecycle management', () => { .request({ method: 'test/request1', params: { data: 'test1' } }, z.object({ result: z.string() }), { relatedTask: { taskId } }) - .catch(err => err); + .catch(error => error); const request2Promise = protocol .request({ method: 'test/request2', params: { data: 'test2' } }, z.object({ result: z.string() }), { relatedTask: { taskId } }) - .catch(err => err); + .catch(error => error); const request3Promise = protocol .request({ method: 'test/request3', params: { data: 'test3' } }, z.object({ result: z.string() }), { relatedTask: { taskId } }) - .catch(err => err); + .catch(error => error); // Verify requests are queued const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; @@ -4401,11 +4245,11 @@ describe('Queue lifecycle management', () => { const result2 = await request2Promise; const result3 = await request3Promise; - expect(result1).toBeInstanceOf(McpError); + expect(result1).toBeInstanceOf(ProtocolError); expect(result1.message).toContain('Task cancelled or completed'); - expect(result2).toBeInstanceOf(McpError); + expect(result2).toBeInstanceOf(ProtocolError); expect(result2.message).toContain('Task cancelled or completed'); - expect(result3).toBeInstanceOf(McpError); + expect(result3).toBeInstanceOf(ProtocolError); expect(result3.message).toContain('Task cancelled or completed'); // Verify queue is cleared (no messages available) @@ -4425,7 +4269,7 @@ describe('Queue lifecycle management', () => { .request({ method: 'test/request', params: { data: 'test' } }, z.object({ result: z.string() }), { relatedTask: { taskId } }) - .catch(err => err); + .catch(error => error); // Get the request ID that was sent const requestResolvers = (protocol as unknown as TestProtocol)._requestResolvers; @@ -4441,7 +4285,7 @@ describe('Queue lifecycle management', () => { // Verify request promise is rejected const result = await requestPromise; - expect(result).toBeInstanceOf(McpError); + expect(result).toBeInstanceOf(ProtocolError); expect(result.message).toContain('Task cancelled or completed'); // Verify resolver mapping is cleaned up @@ -4459,13 +4303,7 @@ describe('requestStream() method', () => { test('should yield result immediately for non-task requests', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); // Start the request stream @@ -4502,13 +4340,7 @@ describe('requestStream() method', () => { test('should yield error message on request failure', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); // Start the request stream @@ -4548,13 +4380,7 @@ describe('requestStream() method', () => { test('should handle cancellation via AbortSignal', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); const abortController = new AbortController(); @@ -4586,13 +4412,7 @@ describe('requestStream() method', () => { describe('Error responses', () => { test('should yield error as terminal message for server error response', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); const messagesPromise = toArrayAsync( @@ -4618,7 +4438,7 @@ describe('requestStream() method', () => { // Verify error is terminal and last message expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; + const lastMessage = messages.at(-1); assertErrorResponse(lastMessage!); expect(lastMessage.error).toBeDefined(); expect(lastMessage.error.message).toContain('Server error'); @@ -4628,13 +4448,7 @@ describe('requestStream() method', () => { vi.useFakeTimers(); try { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); const messagesPromise = toArrayAsync( @@ -4655,7 +4469,7 @@ describe('requestStream() method', () => { // Verify error is terminal and last message expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; + const lastMessage = messages.at(-1); assertErrorResponse(lastMessage!); expect(lastMessage.error).toBeDefined(); expect(lastMessage.error.code).toBe(ErrorCode.RequestTimeout); @@ -4666,13 +4480,7 @@ describe('requestStream() method', () => { test('should yield error as terminal message for cancellation', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); const abortController = new AbortController(); @@ -4691,7 +4499,7 @@ describe('requestStream() method', () => { // Verify error is terminal and last message expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; + const lastMessage = messages.at(-1); assertErrorResponse(lastMessage!); expect(lastMessage.error).toBeDefined(); expect(lastMessage.error.message).toContain('cancelled'); @@ -4699,13 +4507,7 @@ describe('requestStream() method', () => { test('should not yield any messages after error message', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); const messagesPromise = toArrayAsync( @@ -4751,13 +4553,7 @@ describe('requestStream() method', () => { test('should yield error as terminal message for task failure', async () => { const transport = new MockTransport(); const mockTaskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const protocol = new (createTestProtocolClass({ taskStore: mockTaskStore }))(); await protocol.connect(transport); const messagesPromise = toArrayAsync( @@ -4804,20 +4600,14 @@ describe('requestStream() method', () => { // Verify error is terminal and last message expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; + const lastMessage = messages.at(-1); assertErrorResponse(lastMessage!); expect(lastMessage.error).toBeDefined(); }); test('should yield error as terminal message for network error', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); // Override send to simulate network error @@ -4832,20 +4622,14 @@ describe('requestStream() method', () => { // Verify error is terminal and last message expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; + const lastMessage = messages.at(-1); assertErrorResponse(lastMessage!); expect(lastMessage.error).toBeDefined(); }); test('should ensure error is always the final message', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = new (createTestProtocolClass())(); await protocol.connect(transport); const messagesPromise = toArrayAsync( @@ -4871,7 +4655,7 @@ describe('requestStream() method', () => { // Verify error is the last message expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; + const lastMessage = messages.at(-1); expect(lastMessage?.type).toBe('error'); // Verify all messages before the last are not terminal @@ -4895,17 +4679,11 @@ describe('Error handling for missing resolvers', () => { taskMessageQueue = new InMemoryTaskMessageQueue(); errorHandler = vi.fn(); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected assertTaskHandlerCapability(_method: string): void {} - })({ + protocol = new (createTestProtocolClass({ taskStore, taskMessageQueue, defaultTaskPollInterval: 100 - }); + }))(); // @ts-expect-error deliberately overriding error handler with mock protocol.onerror = errorHandler; @@ -4917,7 +4695,7 @@ describe('Error handling for missing resolvers', () => { await protocol.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); // Enqueue a response message without a corresponding resolver await taskMessageQueue.enqueue(task.taskId, { @@ -4962,7 +4740,7 @@ describe('Error handling for missing resolvers', () => { await protocol.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); // Enqueue a response with missing resolver, then a valid notification await taskMessageQueue.enqueue(task.taskId, { @@ -5003,7 +4781,7 @@ describe('Error handling for missing resolvers', () => { await protocol.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); // Enqueue a request without storing a resolver await taskMessageQueue.enqueue(task.taskId, { @@ -5033,7 +4811,7 @@ describe('Error handling for missing resolvers', () => { await protocol.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); const requestId = 42; const resolverMock = vi.fn(); @@ -5058,7 +4836,7 @@ describe('Error handling for missing resolvers', () => { await testProtocol._clearTaskQueue(task.taskId); // Verify resolver was called with cancellation error - expect(resolverMock).toHaveBeenCalledWith(expect.any(McpError)); + expect(resolverMock).toHaveBeenCalledWith(expect.any(ProtocolError)); // Verify the error has the correct properties const calledError = resolverMock.mock.calls[0]![0]; @@ -5073,7 +4851,7 @@ describe('Error handling for missing resolvers', () => { await protocol.connect(transport); // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); const testProtocol = protocol as unknown as TestProtocol; @@ -5118,7 +4896,7 @@ describe('Error handling for missing resolvers', () => { await testProtocol._clearTaskQueue(task.taskId); // Verify resolver was called for first request - expect(resolverMock).toHaveBeenCalledWith(expect.any(McpError)); + expect(resolverMock).toHaveBeenCalledWith(expect.any(ProtocolError)); // Verify the error has the correct properties const calledError = resolverMock.mock.calls[0]![0]; @@ -5177,7 +4955,7 @@ describe('Error handling for missing resolvers', () => { it('should not throw when processing response with missing resolver', async () => { await protocol.connect(transport); - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); await taskMessageQueue.enqueue(task.taskId, { type: 'response', @@ -5209,7 +4987,7 @@ describe('Error handling for missing resolvers', () => { it('should not throw during task cleanup with missing resolvers', async () => { await protocol.connect(transport); - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); await taskMessageQueue.enqueue(task.taskId, { type: 'request', @@ -5233,7 +5011,7 @@ describe('Error handling for missing resolvers', () => { it('should route error messages to resolvers correctly', async () => { await protocol.connect(transport); - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); const requestId = 42; const resolverMock = vi.fn(); @@ -5268,13 +5046,13 @@ describe('Error handling for missing resolvers', () => { if (resolver) { testProtocol._requestResolvers.delete(reqId); - const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); resolver(error); } } - // Verify resolver was called with McpError - expect(resolverMock).toHaveBeenCalledWith(expect.any(McpError)); + // Verify resolver was called with ProtocolError + expect(resolverMock).toHaveBeenCalledWith(expect.any(ProtocolError)); const calledError = resolverMock.mock.calls[0]![0]; expect(calledError.code).toBe(ErrorCode.InvalidRequest); expect(calledError.message).toContain('Invalid request parameters'); @@ -5286,7 +5064,7 @@ describe('Error handling for missing resolvers', () => { it('should log error for unknown request ID in error messages', async () => { await protocol.connect(transport); - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); // Enqueue an error message without a corresponding resolver await taskMessageQueue.enqueue(task.taskId, { @@ -5330,7 +5108,7 @@ describe('Error handling for missing resolvers', () => { it('should handle error messages with data field', async () => { await protocol.connect(transport); - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); const requestId = 42; const resolverMock = vi.fn(); @@ -5363,13 +5141,13 @@ describe('Error handling for missing resolvers', () => { if (resolver) { testProtocol._requestResolvers.delete(reqId); - const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); resolver(error); } } - // Verify resolver was called with McpError including data - expect(resolverMock).toHaveBeenCalledWith(expect.any(McpError)); + // Verify resolver was called with ProtocolError including data + expect(resolverMock).toHaveBeenCalledWith(expect.any(ProtocolError)); const calledError = resolverMock.mock.calls[0]![0]; expect(calledError.code).toBe(ErrorCode.InvalidParams); expect(calledError.message).toContain('Validation failed'); @@ -5379,7 +5157,7 @@ describe('Error handling for missing resolvers', () => { it('should not throw when processing error with missing resolver', async () => { await protocol.connect(transport); - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); await taskMessageQueue.enqueue(task.taskId, { type: 'error', @@ -5416,7 +5194,7 @@ describe('Error handling for missing resolvers', () => { it('should handle mixed response and error messages in queue', async () => { await protocol.connect(transport); - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); const testProtocol = protocol as unknown as TestProtocol; // Set up resolvers for multiple requests @@ -5479,7 +5257,7 @@ describe('Error handling for missing resolvers', () => { const resolver = testProtocol._requestResolvers.get(requestId); if (resolver) { testProtocol._requestResolvers.delete(requestId); - const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); resolver(error); } } @@ -5487,7 +5265,7 @@ describe('Error handling for missing resolvers', () => { // Verify all resolvers were called correctly expect(resolver1).toHaveBeenCalledWith(expect.objectContaining({ id: 1 })); - expect(resolver2).toHaveBeenCalledWith(expect.any(McpError)); + expect(resolver2).toHaveBeenCalledWith(expect.any(ProtocolError)); expect(resolver3).toHaveBeenCalledWith(expect.objectContaining({ id: 3 })); // Verify error has correct properties @@ -5502,7 +5280,7 @@ describe('Error handling for missing resolvers', () => { it('should maintain FIFO order when processing responses and errors', async () => { await protocol.connect(transport); - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const task = await taskStore.createTask({ ttl: 60_000 }, 1, { method: 'test', params: {} }); const testProtocol = protocol as unknown as TestProtocol; const callOrder: number[] = []; @@ -5526,7 +5304,7 @@ describe('Error handling for missing resolvers', () => { message: { jsonrpc: '2.0', id: 2, - error: { code: -32600, message: 'Error' } + error: { code: -32_600, message: 'Error' } }, timestamp: 2000 }); @@ -5554,7 +5332,7 @@ describe('Error handling for missing resolvers', () => { const resolver = testProtocol._requestResolvers.get(requestId); if (resolver) { testProtocol._requestResolvers.delete(requestId); - const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); resolver(error); } } diff --git a/packages/core/test/shared/protocolTransportHandling.test.ts b/packages/core/test/shared/protocolTransportHandling.test.ts index 0e1b9b5c9..69b092a59 100644 --- a/packages/core/test/shared/protocolTransportHandling.test.ts +++ b/packages/core/test/shared/protocolTransportHandling.test.ts @@ -1,9 +1,19 @@ import { beforeEach, describe, expect, test } from 'vitest'; import * as z from 'zod/v4'; +import type { TaskStore } from '../../src/experimental/tasks/interfaces.js'; +import type { BaseRequestContext, ContextInterface } from '../../src/shared/context.js'; import { Protocol } from '../../src/shared/protocol.js'; import type { Transport } from '../../src/shared/transport.js'; -import type { JSONRPCMessage, Notification, Request, Result } from '../../src/types/types.js'; +import type { + JSONRPCMessage, + JSONRPCRequest, + MessageExtraInfo, + Notification, + Request, + Result, + TaskCreationParams +} from '../../src/types/types.js'; // Mock Transport class class MockTransport implements Transport { @@ -28,19 +38,66 @@ class MockTransport implements Transport { } } +/** + * Creates a mock ContextInterface for testing. + */ +function createMockContext(args: { + request: JSONRPCRequest; + abortController: AbortController; + sessionId?: string; +}): ContextInterface { + return { + mcpCtx: { + requestId: args.request.id, + method: args.request.method, + _meta: args.request.params?._meta, + sessionId: args.sessionId + }, + requestCtx: { + signal: args.abortController.signal, + authInfo: undefined + }, + taskCtx: undefined, + sendNotification: async () => {}, + sendRequest: async () => ({}) as never + }; +} + +/** + * Creates a test Protocol class with all abstract methods implemented. + */ +function createTestProtocolClass() { + return class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + protected createRequestContext(args: { + request: JSONRPCRequest; + taskStore: TaskStore | undefined; + relatedTaskId: string | undefined; + taskCreationParams: TaskCreationParams | undefined; + abortController: AbortController; + capturedTransport: Transport | undefined; + extra?: MessageExtraInfo; + }): ContextInterface { + return createMockContext({ + request: args.request, + abortController: args.abortController, + sessionId: args.capturedTransport?.sessionId + }); + } + }; +} + describe('Protocol transport handling bug', () => { let protocol: Protocol; let transportA: MockTransport; let transportB: MockTransport; beforeEach(() => { - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - })(); + protocol = new (createTestProtocolClass())(); transportA = new MockTransport('A'); transportB = new MockTransport('B'); @@ -138,14 +195,14 @@ describe('Protocol transport handling bug', () => { }); // Set up handler with variable delay - protocol.setRequestHandler(DelayedRequestSchema, async (request, extra) => { + protocol.setRequestHandler(DelayedRequestSchema, async (request, ctx) => { const delay = request.params?.delay || 0; delays.push(delay); await new Promise(resolve => setTimeout(resolve, delay)); return { - processedBy: `handler-${extra.requestId}`, + processedBy: `handler-${ctx.mcpCtx.requestId}`, delay: delay } as Result; }); diff --git a/packages/middleware/express/package.json b/packages/middleware/express/package.json index 408cf446a..844b3e3ec 100644 --- a/packages/middleware/express/package.json +++ b/packages/middleware/express/package.json @@ -37,8 +37,8 @@ "build": "tsdown", "build:watch": "tsdown --watch", "prepack": "npm run build", - "lint": "eslint src/ && prettier --ignore-path ../../.prettierignore --check .", - "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../.prettierignore --write .", + "lint": "eslint src/ && prettier --ignore-path ../../../.prettierignore --check .", + "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../../.prettierignore --write .", "check": "npm run typecheck && npm run lint", "test": "vitest run", "test:watch": "vitest" diff --git a/packages/middleware/hono/package.json b/packages/middleware/hono/package.json index 3377c5fb4..afef9b02e 100644 --- a/packages/middleware/hono/package.json +++ b/packages/middleware/hono/package.json @@ -37,8 +37,8 @@ "build": "tsdown", "build:watch": "tsdown --watch", "prepack": "npm run build", - "lint": "eslint src/ && prettier --ignore-path ../../.prettierignore --check .", - "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../.prettierignore --write .", + "lint": "eslint src/ && prettier --ignore-path ../../../.prettierignore --check .", + "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../../.prettierignore --write .", "check": "npm run typecheck && npm run lint", "test": "vitest run", "test:watch": "vitest" diff --git a/packages/middleware/node/package.json b/packages/middleware/node/package.json index 766346613..5024fa55f 100644 --- a/packages/middleware/node/package.json +++ b/packages/middleware/node/package.json @@ -36,8 +36,8 @@ "build": "tsdown", "build:watch": "tsdown --watch", "prepack": "npm run build", - "lint": "eslint src/ && prettier --ignore-path ../../.prettierignore --check .", - "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../.prettierignore --write .", + "lint": "eslint src/ && prettier --ignore-path ../../../.prettierignore --check .", + "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../../.prettierignore --write .", "check": "npm run typecheck && npm run lint", "test": "vitest run", "test:watch": "vitest", diff --git a/packages/middleware/node/test/streamableHttp.test.ts b/packages/middleware/node/test/streamableHttp.test.ts index ca7728d88..f4dee8866 100644 --- a/packages/middleware/node/test/streamableHttp.test.ts +++ b/packages/middleware/node/test/streamableHttp.test.ts @@ -12,7 +12,7 @@ import type { JSONRPCResultResponse, RequestId } from '@modelcontextprotocol/core'; -import type { EventId, EventStore, StreamId } from '@modelcontextprotocol/server'; +import type { EventId, EventStore, ServerRequestContext, StreamId } from '@modelcontextprotocol/server'; import { McpServer } from '@modelcontextprotocol/server'; import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; import { listenOnRandomPort, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; @@ -214,8 +214,12 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { description: 'A user profile data tool', inputSchema: { active: z.boolean().describe('Profile status') } }, - async ({ active }, { authInfo }): Promise => { - return { content: [{ type: 'text', text: `${active ? 'Active' : 'Inactive'} profile from token: ${authInfo?.token}!` }] }; + async ({ active }, { requestCtx }): Promise => { + return { + content: [ + { type: 'text', text: `${active ? 'Active' : 'Inactive'} profile from token: ${requestCtx.authInfo?.token}!` } + ] + }; } ); @@ -404,11 +408,12 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { description: 'A simple test tool with request info', inputSchema: { name: z.string().describe('Name to greet') } }, - async ({ name }, { requestInfo }): Promise => { + async ({ name }, ctx): Promise => { // Convert Headers object to plain object for JSON serialization // Headers is a Web API class that doesn't serialize with JSON.stringify + const requestCtx = ctx.requestCtx as ServerRequestContext; const serializedRequestInfo = { - headers: Object.fromEntries(requestInfo?.headers ?? new Headers()) + headers: Object.fromEntries(requestCtx.headers ?? new Headers()) }; return { content: [ @@ -1851,9 +1856,10 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); // Register a tool that closes its own SSE stream via extra callback - mcpServer.registerTool('close-stream-tool', { description: 'Closes its own stream' }, async extra => { + mcpServer.registerTool('close-stream-tool', { description: 'Closes its own stream' }, async ctx => { // Close the SSE stream for this request - extra.closeSSEStream?.(); + const requestCtx = ctx.requestCtx as ServerRequestContext; + requestCtx.stream.closeSSEStream?.(); streamCloseCalled = true; // Wait before returning so we can observe the stream closure @@ -1918,9 +1924,10 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Track whether closeSSEStream callback was provided let receivedCloseSSEStream: (() => void) | undefined; - // Register a tool that captures the extra.closeSSEStream callback - mcpServer.registerTool('test-callback-tool', { description: 'Test tool' }, async extra => { - receivedCloseSSEStream = extra.closeSSEStream; + // Register a tool that captures the ctx.requestCtx.stream.closeSSEStream callback + mcpServer.registerTool('test-callback-tool', { description: 'Test tool' }, async ctx => { + const requestCtx = ctx.requestCtx as ServerRequestContext; + receivedCloseSSEStream = requestCtx.stream.closeSSEStream; return { content: [{ type: 'text', text: 'Done' }] }; }); @@ -1977,10 +1984,11 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { let receivedCloseSSEStream: (() => void) | undefined; let receivedCloseStandaloneSSEStream: (() => void) | undefined; - // Register a tool that captures the extra.closeSSEStream callback - mcpServer.registerTool('test-old-version-tool', { description: 'Test tool' }, async extra => { - receivedCloseSSEStream = extra.closeSSEStream; - receivedCloseStandaloneSSEStream = extra.closeStandaloneSSEStream; + // Register a tool that captures the ctx.requestCtx.stream callbacks + mcpServer.registerTool('test-old-version-tool', { description: 'Test tool' }, async ctx => { + const requestCtx = ctx.requestCtx as ServerRequestContext; + receivedCloseSSEStream = requestCtx.stream.closeSSEStream; + receivedCloseStandaloneSSEStream = requestCtx.stream.closeStandaloneSSEStream; return { content: [{ type: 'text', text: 'Done' }] }; }); @@ -2036,9 +2044,10 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Track whether closeSSEStream callback was provided let receivedCloseSSEStream: (() => void) | undefined; - // Register a tool that captures the extra.closeSSEStream callback - mcpServer.registerTool('test-no-callback-tool', { description: 'Test tool' }, async extra => { - receivedCloseSSEStream = extra.closeSSEStream; + // Register a tool that captures the ctx.requestCtx.stream.closeSSEStream callback + mcpServer.registerTool('test-no-callback-tool', { description: 'Test tool' }, async ctx => { + const requestCtx = ctx.requestCtx as ServerRequestContext; + receivedCloseSSEStream = requestCtx.stream.closeSSEStream; return { content: [{ type: 'text', text: 'Done' }] }; }); @@ -2093,9 +2102,10 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Track whether closeStandaloneSSEStream callback was provided let receivedCloseStandaloneSSEStream: (() => void) | undefined; - // Register a tool that captures the extra.closeStandaloneSSEStream callback - mcpServer.registerTool('test-standalone-callback-tool', { description: 'Test tool' }, async extra => { - receivedCloseStandaloneSSEStream = extra.closeStandaloneSSEStream; + // Register a tool that captures the ctx.requestCtx.stream.closeStandaloneSSEStream callback + mcpServer.registerTool('test-standalone-callback-tool', { description: 'Test tool' }, async ctx => { + const requestCtx = ctx.requestCtx as ServerRequestContext; + receivedCloseStandaloneSSEStream = requestCtx.stream.closeStandaloneSSEStream; return { content: [{ type: 'text', text: 'Done' }] }; }); @@ -2148,9 +2158,10 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { baseUrl = result.baseUrl; mcpServer = result.mcpServer; - // Register a tool that closes the standalone SSE stream via extra callback - mcpServer.registerTool('close-standalone-stream-tool', { description: 'Closes standalone stream' }, async extra => { - extra.closeStandaloneSSEStream?.(); + // Register a tool that closes the standalone SSE stream via ctx callback + mcpServer.registerTool('close-standalone-stream-tool', { description: 'Closes standalone stream' }, async ctx => { + const requestCtx = ctx.requestCtx as ServerRequestContext; + requestCtx.stream.closeStandaloneSSEStream?.(); return { content: [{ type: 'text', text: 'Stream closed' }] }; }); @@ -2230,8 +2241,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { mcpServer = result.mcpServer; // Register a tool that closes the standalone SSE stream - mcpServer.registerTool('close-standalone-for-reconnect', { description: 'Closes standalone stream' }, async extra => { - extra.closeStandaloneSSEStream?.(); + mcpServer.registerTool('close-standalone-for-reconnect', { description: 'Closes standalone stream' }, async ctx => { + const requestCtx = ctx.requestCtx as ServerRequestContext; + requestCtx.stream.closeStandaloneSSEStream?.(); return { content: [{ type: 'text', text: 'Stream closed' }] }; }); diff --git a/packages/server/src/experimental/tasks/interfaces.ts b/packages/server/src/experimental/tasks/interfaces.ts index 0b32be213..391b79e49 100644 --- a/packages/server/src/experimental/tasks/interfaces.ts +++ b/packages/server/src/experimental/tasks/interfaces.ts @@ -6,14 +6,15 @@ import type { AnySchema, CallToolResult, - CreateTaskRequestHandlerExtra, CreateTaskResult, GetTaskResult, Result, - TaskRequestHandlerExtra, + ServerNotification, + ServerRequest, ZodRawShapeCompat } from '@modelcontextprotocol/core'; +import type { ServerContextInterface } from '../../server/context.js'; import type { BaseToolCallback } from '../../server/mcp.js'; // ============================================================================ @@ -27,7 +28,7 @@ import type { BaseToolCallback } from '../../server/mcp.js'; export type CreateTaskRequestHandler< SendResultT extends Result, Args extends undefined | ZodRawShapeCompat | AnySchema = undefined -> = BaseToolCallback; +> = BaseToolCallback, Args>; /** * Handler for task operations (get, getResult). @@ -36,7 +37,7 @@ export type CreateTaskRequestHandler< export type TaskRequestHandler< SendResultT extends Result, Args extends undefined | ZodRawShapeCompat | AnySchema = undefined -> = BaseToolCallback; +> = BaseToolCallback, Args>; /** * Interface for task-based tool handlers. diff --git a/packages/server/src/experimental/tasks/mcpServer.ts b/packages/server/src/experimental/tasks/mcpServer.ts index 6fd5a6cc5..a8a526d1e 100644 --- a/packages/server/src/experimental/tasks/mcpServer.ts +++ b/packages/server/src/experimental/tasks/mcpServer.ts @@ -7,7 +7,8 @@ import type { AnySchema, TaskToolExecution, ToolAnnotations, ToolExecution, ZodRawShapeCompat } from '@modelcontextprotocol/core'; -import type { AnyToolHandler, McpServer, RegisteredTool } from '../../server/mcp.js'; +import type { AnyToolHandler, McpServer } from '../../server/mcp.js'; +import type { RegisteredTool } from '../../server/registries/toolRegistry.js'; import type { ToolTaskHandler } from './interfaces.js'; /** @@ -55,16 +56,16 @@ export class ExperimentalMcpServerTasks { * inputSchema: { input: z.string() }, * execution: { taskSupport: 'required' } * }, { - * createTask: async (args, extra) => { - * const task = await extra.taskStore.createTask({ ttl: 300000 }); + * createTask: async (args, ctx) => { + * const task = await ctx.taskCtx!.store.createTask({ ttl: 300000 }); * startBackgroundWork(task.taskId, args); * return { task }; * }, - * getTask: async (args, extra) => { - * return extra.taskStore.getTask(extra.taskId); + * getTask: async (args, ctx) => { + * return ctx.taskCtx!.store.getTask(ctx.taskCtx!.id); * }, - * getTaskResult: async (args, extra) => { - * return extra.taskStore.getTaskResult(extra.taskId); + * getTaskResult: async (args, ctx) => { + * return ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id); * } * }); * ``` diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index 1a8dbf143..df415bf24 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -1,6 +1,10 @@ +export * from './server/builder.js'; export * from './server/completable.js'; +export * from './server/context.js'; export * from './server/mcp.js'; +export * from './server/middleware.js'; export * from './server/middleware/hostHeaderValidation.js'; +export * from './server/registries/index.js'; export * from './server/server.js'; export * from './server/stdio.js'; export * from './server/streamableHttp.js'; diff --git a/packages/server/src/server/builder.ts b/packages/server/src/server/builder.ts new file mode 100644 index 000000000..978abe4ec --- /dev/null +++ b/packages/server/src/server/builder.ts @@ -0,0 +1,428 @@ +/** + * McpServer Builder + * + * Provides a fluent API for configuring and creating McpServer instances. + * The builder is an additive convenience layer - the existing constructor + * API remains available for users who prefer it. + * + * @example + * ```typescript + * const server = McpServer.builder() + * .name('my-server') + * .version('1.0.0') + * .useMiddleware(loggingMiddleware) + * .tool('greet', { inputSchema: { name: z.string() } }, handler) + * .build(); + * ``` + */ + +import type { ToolAnnotations, ToolExecution, ZodRawShapeCompat } from '@modelcontextprotocol/core'; +import { objectFromShape } from '@modelcontextprotocol/core'; + +import type { PromptCallback, ReadResourceCallback } from '../types/types.js'; +import type { McpServer, ResourceMetadata, ToolCallback } from './mcp.js'; +import type { PromptMiddleware, ResourceMiddleware, ToolMiddleware, UniversalMiddleware } from './middleware.js'; +import { PromptRegistry } from './registries/promptRegistry.js'; +import { ResourceRegistry } from './registries/resourceRegistry.js'; +import { ToolRegistry } from './registries/toolRegistry.js'; +import type { ServerOptions as BaseServerOptions } from './server.js'; + +// ZodRawShape for backward compatibility +type ZodRawShape = ZodRawShapeCompat; + +/** + * Extended server options including builder-specific options + */ +export interface McpServerBuilderOptions extends BaseServerOptions { + /** Server name */ + name?: string; + /** Server version */ + version?: string; +} + +/** + * Error handler type for application errors + */ +export type OnErrorHandler = (error: Error, ctx: ErrorContext) => OnErrorReturn | void | Promise; + +/** + * Error handler type for protocol errors + */ +export type OnProtocolErrorHandler = ( + error: Error, + ctx: ErrorContext +) => OnProtocolErrorReturn | void | Promise; + +/** + * Return type for onError handler + */ +export type OnErrorReturn = string | { code?: number; message?: string; data?: unknown } | Error; + +/** + * Return type for onProtocolError handler (code cannot be changed) + */ +export type OnProtocolErrorReturn = string | { message?: string; data?: unknown }; + +/** + * Context provided to error handlers + */ +export interface ErrorContext { + type: 'tool' | 'resource' | 'prompt' | 'protocol'; + name?: string; + method: string; + requestId: string; +} + +/** + * Fluent builder for McpServer instances. + * + * Provides a declarative, chainable API for configuring servers. + * All configuration is collected and applied when build() is called. + */ +export class McpServerBuilder { + private _name?: string; + private _version?: string; + private _options: McpServerBuilderOptions = {}; + + // Global middleware + private _universalMiddleware: UniversalMiddleware[] = []; + private _toolMiddleware: ToolMiddleware[] = []; + private _resourceMiddleware: ResourceMiddleware[] = []; + private _promptMiddleware: PromptMiddleware[] = []; + + // Registries (created without callbacks - McpServer will bind them later) + private _toolRegistry = new ToolRegistry(); + private _resourceRegistry = new ResourceRegistry(); + private _promptRegistry = new PromptRegistry(); + + // Per-item middleware (keyed by name/uri) + private _perToolMiddleware = new Map(); + private _perResourceMiddleware = new Map(); + private _perPromptMiddleware = new Map(); + + // Error handlers + private _onError?: OnErrorHandler; + private _onProtocolError?: OnProtocolErrorHandler; + + /** + * Sets the server name. + */ + name(name: string): this { + this._name = name; + return this; + } + + /** + * Sets the server version. + */ + version(version: string): this { + this._version = version; + return this; + } + + /** + * Sets server options. + */ + options(options: McpServerBuilderOptions): this { + this._options = { ...this._options, ...options }; + return this; + } + + /** + * Adds universal middleware that runs for all request types. + */ + useMiddleware(middleware: UniversalMiddleware): this { + this._universalMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware specifically for tool calls. + */ + useToolMiddleware(middleware: ToolMiddleware): this { + this._toolMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware specifically for resource reads. + */ + useResourceMiddleware(middleware: ResourceMiddleware): this { + this._resourceMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware specifically for prompt requests. + */ + usePromptMiddleware(middleware: PromptMiddleware): this { + this._promptMiddleware.push(middleware); + return this; + } + + /** + * Registers a tool with the server. + * + * @example + * ```typescript + * .tool('greet', { + * description: 'Greet a user', + * inputSchema: { name: z.string() } + * }, async ({ name }) => { + * return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; + * }) + * ``` + */ + tool( + name: string, + config: { + title?: string; + description?: string; + inputSchema?: InputArgs; + outputSchema?: ZodRawShape; + middleware?: ToolMiddleware; + annotations?: ToolAnnotations; + execution?: ToolExecution; + _meta?: Record; + }, + handler: ToolCallback + ): this { + this._toolRegistry.register({ + name, + title: config.title, + description: config.description, + inputSchema: config.inputSchema ? objectFromShape(config.inputSchema) : undefined, + outputSchema: config.outputSchema ? objectFromShape(config.outputSchema) : undefined, + annotations: config.annotations, + execution: config.execution, + _meta: config._meta, + handler: handler as ToolCallback + }); + + // Store per-tool middleware if provided + if (config.middleware) { + this._perToolMiddleware.set(name, config.middleware); + } + + return this; + } + + /** + * Registers a resource with the server. + * + * @example + * ```typescript + * .resource('config', 'file:///config', { + * description: 'Configuration file' + * }, async (uri) => { + * return { contents: [{ uri, mimeType: 'application/json', text: '{}' }] }; + * }) + * ``` + */ + resource( + name: string, + uri: string, + config: { + title?: string; + description?: string; + mimeType?: string; + metadata?: ResourceMetadata; + middleware?: ResourceMiddleware; + }, + readCallback: ReadResourceCallback + ): this { + this._resourceRegistry.register({ + name, + uri, + title: config.title, + description: config.description, + mimeType: config.mimeType, + metadata: config.metadata, + readCallback + }); + + // Store per-resource middleware if provided + if (config.middleware) { + this._perResourceMiddleware.set(uri, config.middleware); + } + + return this; + } + + /** + * Registers a prompt with the server. + * + * @example + * ```typescript + * .prompt('summarize', { + * description: 'Summarize text', + * argsSchema: { text: z.string() } + * }, async ({ text }) => { + * return { messages: [{ role: 'user', content: { type: 'text', text } }] }; + * }) + * ``` + */ + prompt( + name: string, + config: { + title?: string; + description?: string; + argsSchema?: Args; + middleware?: PromptMiddleware; + }, + callback: PromptCallback + ): this { + this._promptRegistry.register({ + name, + title: config.title, + description: config.description, + argsSchema: config.argsSchema, + callback: callback as PromptCallback + }); + + // Store per-prompt middleware if provided + if (config.middleware) { + this._perPromptMiddleware.set(name, config.middleware); + } + + return this; + } + + /** + * Sets the application error handler. + * Called when a handler throws an error. + */ + onError(handler: OnErrorHandler): this { + this._onError = handler; + return this; + } + + /** + * Sets the protocol error handler. + * Called for protocol-level errors (parse, method not found, etc.) + */ + onProtocolError(handler: OnProtocolErrorHandler): this { + this._onProtocolError = handler; + return this; + } + + /** + * Gets the collected configuration (for debugging/testing). + */ + getConfig(): { + name?: string; + version?: string; + options: McpServerBuilderOptions; + toolCount: number; + resourceCount: number; + promptCount: number; + middlewareCount: number; + } { + return { + name: this._name, + version: this._version, + options: this._options, + toolCount: this._toolRegistry.size, + resourceCount: this._resourceRegistry.size, + promptCount: this._promptRegistry.size, + middlewareCount: + this._universalMiddleware.length + + this._toolMiddleware.length + + this._resourceMiddleware.length + + this._promptMiddleware.length + }; + } + + /** + * Builds and returns the configured McpServer instance. + */ + build(): McpServer { + if (!this._name) { + throw new Error('Server name is required. Use .name() to set it.'); + } + if (!this._version) { + throw new Error('Server version is required. Use .version() to set it.'); + } + + const result: BuilderResult = { + serverInfo: { + name: this._name, + version: this._version + }, + options: this._options, + middleware: { + universal: this._universalMiddleware, + tool: this._toolMiddleware, + resource: this._resourceMiddleware, + prompt: this._promptMiddleware + }, + registries: { + tools: this._toolRegistry, + resources: this._resourceRegistry, + prompts: this._promptRegistry + }, + perItemMiddleware: { + tools: this._perToolMiddleware, + resources: this._perResourceMiddleware, + prompts: this._perPromptMiddleware + }, + errorHandlers: { + onError: this._onError, + onProtocolError: this._onProtocolError + } + }; + + // Dynamically import McpServer to create the instance + // eslint-disable-next-line @typescript-eslint/no-require-imports + const { McpServer: McpServerClass } = require('./mcp.js'); + return McpServerClass.fromBuilderResult(result); + } +} + +/** + * Result of building the server configuration. + * Used to create the actual McpServer instance. + */ +export interface BuilderResult { + serverInfo: { + name: string; + version: string; + }; + options: McpServerBuilderOptions; + middleware: { + universal: UniversalMiddleware[]; + tool: ToolMiddleware[]; + resource: ResourceMiddleware[]; + prompt: PromptMiddleware[]; + }; + registries: { + tools: ToolRegistry; + resources: ResourceRegistry; + prompts: PromptRegistry; + }; + perItemMiddleware: { + tools: Map; + resources: Map; + prompts: Map; + }; + errorHandlers: { + onError?: OnErrorHandler; + onProtocolError?: OnProtocolErrorHandler; + }; +} + +/** + * Creates a new McpServerBuilder instance. + * + * @example + * ```typescript + * const server = createServerBuilder() + * .name('my-server') + * .version('1.0.0') + * .tool('greet', { inputSchema: { name: z.string() } }, handler) + * .build(); + * ``` + */ +export function createServerBuilder(): McpServerBuilder { + return new McpServerBuilder(); +} diff --git a/packages/server/src/server/context.ts b/packages/server/src/server/context.ts new file mode 100644 index 000000000..5913cff30 --- /dev/null +++ b/packages/server/src/server/context.ts @@ -0,0 +1,242 @@ +import type { + BaseRequestContext, + ContextInterface, + CreateMessageRequest, + CreateMessageResult, + ElicitRequest, + ElicitResult, + JSONRPCRequest, + LoggingMessageNotification, + McpContext, + Notification, + Request, + RequestOptions, + Result, + ServerNotification, + ServerRequest, + ServerResult, + TaskContext +} from '@modelcontextprotocol/core'; +import { BaseContext, ElicitResultSchema } from '@modelcontextprotocol/core'; + +import type { Server } from './server.js'; + +/** + * Server-specific request context with HTTP request details. + * Extends BaseRequestContext with fields only available on the server side. + */ +export type ServerRequestContext = BaseRequestContext & { + /** + * The URI of the incoming HTTP request. + */ + uri: URL; + /** + * The headers of the incoming HTTP request. + */ + headers: Headers; + /** + * Stream control methods for SSE connections. + */ + stream: { + /** + * Closes the SSE stream for this request, triggering client reconnection. + * Only available when using StreamableHTTPServerTransport with eventStore configured. + * Use this to implement polling behavior during long-running operations. + */ + closeSSEStream: (() => void) | undefined; + /** + * Closes the standalone GET SSE stream, triggering client reconnection. + * Only available when using StreamableHTTPServerTransport with eventStore configured. + * Use this to implement polling behavior for server-initiated notifications. + */ + closeStandaloneSSEStream: (() => void) | undefined; + }; +}; + +/** + * Interface for sending logging messages to the client via {@link LoggingMessageNotification}. + */ +export interface LoggingMessageNotificationSenderInterface { + /** + * Sends a logging message to the client. + */ + log(params: LoggingMessageNotification['params'], sessionId?: string): Promise; + /** + * Sends a debug log message to the client. + */ + debug(message: string, extraLogData?: Record, sessionId?: string): Promise; + /** + * Sends an info log message to the client. + */ + info(message: string, extraLogData?: Record, sessionId?: string): Promise; + /** + * Sends a warning log message to the client. + */ + warning(message: string, extraLogData?: Record, sessionId?: string): Promise; + /** + * Sends an error log message to the client. + */ + error(message: string, extraLogData?: Record, sessionId?: string): Promise; +} + +export class ServerLogger implements LoggingMessageNotificationSenderInterface { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + constructor(private readonly server: Server) {} + + /** + * Sends a logging message. + */ + public async log(params: LoggingMessageNotification['params'], sessionId?: string) { + await this.server.sendLoggingMessage(params, sessionId); + } + + /** + * Sends a debug log message. + */ + public async debug(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'debug', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } + + /** + * Sends an info log message. + */ + public async info(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'info', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } + + /** + * Sends a warning log message. + */ + public async warning(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'warning', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } + + /** + * Sends an error log message. + */ + public async error(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'error', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } +} + +/** + * Server-specific context interface extending the base ContextInterface. + * Includes server-specific methods for logging, elicitation, and sampling. + */ +export interface ServerContextInterface + extends ContextInterface { + /** + * Logger for sending logging messages to the client. + */ + loggingNotification: LoggingMessageNotificationSenderInterface; + /** + * Sends an elicitation request to the client. + */ + elicitInput: (params: ElicitRequest['params'], options?: RequestOptions) => Promise; + /** + * Sends a sampling request to the client. + */ + requestSampling: (params: CreateMessageRequest['params'], options?: RequestOptions) => Promise; +} + +/** + * A context object that is passed to server-side request handlers. + * Provides access to MCP context, request context, task context, and server-specific methods. + */ +export class ServerContext< + RequestT extends Request = Request, + NotificationT extends Notification = Notification, + ResultT extends Result = Result + > + extends BaseContext + implements ServerContextInterface +{ + private readonly server: Server; + + /** + * Logger for sending logging messages to the client. + */ + public readonly loggingNotification: LoggingMessageNotificationSenderInterface; + + constructor(args: { + server: Server; + request: JSONRPCRequest; + mcpContext: McpContext; + requestCtx: ServerRequestContext; + task?: TaskContext; + }) { + super({ + request: args.request, + mcpContext: args.mcpContext, + requestCtx: args.requestCtx, + task: args.task + }); + this.server = args.server; + this.loggingNotification = new ServerLogger(args.server); + } + + /** + * Returns the server instance for sending notifications and requests. + */ + protected getProtocol(): Server { + return this.server; + } + + /** + * Sends a sampling request to the client. + */ + public requestSampling(params: CreateMessageRequest['params'], options?: RequestOptions) { + return this.server.createMessage(params, options); + } + + /** + * Sends an elicitation request to the client. + */ + public async elicitInput(params: ElicitRequest['params'], options?: RequestOptions): Promise { + const request: ElicitRequest = { + method: 'elicitation/create', + params + }; + return await this.server.request(request, ElicitResultSchema, { ...options, relatedRequestId: this.mcpCtx.requestId }); + } +} diff --git a/packages/server/src/server/mcp.ts b/packages/server/src/server/mcp.ts index 975cca257..bf53e7dbe 100644 --- a/packages/server/src/server/mcp.ts +++ b/packages/server/src/server/mcp.ts @@ -8,29 +8,27 @@ import type { CompleteRequestResourceTemplate, CompleteResult, CreateTaskResult, + ErrorInterceptionContext, + ErrorInterceptionResult, GetPromptResult, Implementation, ListPromptsResult, ListResourcesResult, ListToolsResult, LoggingMessageNotification, - Prompt, - PromptArgument, PromptReference, - ReadResourceResult, - RequestHandlerExtra, + ProtocolPlugin, Resource, ResourceTemplateReference, Result, SchemaOutput, ServerNotification, ServerRequest, + ServerResult, ShapeOutput, - Tool, ToolAnnotations, ToolExecution, Transport, - Variables, ZodRawShapeCompat } from '@modelcontextprotocol/core'; import { @@ -43,29 +41,66 @@ import { getObjectShape, getParseErrorMessage, GetPromptRequestSchema, - getSchemaDescription, - isSchemaOptional, + isProtocolError, ListPromptsRequestSchema, ListResourcesRequestSchema, ListResourceTemplatesRequestSchema, ListToolsRequestSchema, - McpError, normalizeObjectSchema, objectFromShape, + ProtocolError, ReadResourceRequestSchema, safeParseAsync, - toJsonSchemaCompat, - UriTemplate, - validateAndWarnToolName + UriTemplate } from '@modelcontextprotocol/core'; import { ZodOptional } from 'zod'; import type { ToolTaskHandler } from '../experimental/tasks/interfaces.js'; import { ExperimentalMcpServerTasks } from '../experimental/tasks/mcpServer.js'; +import type { PromptArgsRawShape, PromptCallback, ReadResourceCallback, ReadResourceTemplateCallback } from '../types/types.js'; +import type { + BuilderResult, + ErrorContext, + OnErrorHandler, + OnErrorReturn, + OnProtocolErrorHandler, + OnProtocolErrorReturn +} from './builder.js'; +import { McpServerBuilder } from './builder.js'; import { getCompleter, isCompletable } from './completable.js'; +import type { ServerContextInterface } from './context.js'; +import type { + PromptContext, + PromptMiddleware, + ResourceContext, + ResourceMiddleware, + ToolContext, + ToolMiddleware, + UniversalMiddleware +} from './middleware.js'; +import { MiddlewareManager } from './middleware.js'; +import type { RegisteredPrompt } from './registries/promptRegistry.js'; +import { PromptRegistry } from './registries/promptRegistry.js'; +import type { RegisteredResourceEntity, RegisteredResourceTemplateEntity } from './registries/resourceRegistry.js'; +import { ResourceRegistry, ResourceTemplateRegistry } from './registries/resourceRegistry.js'; +import type { RegisteredTool } from './registries/toolRegistry.js'; +import { ToolRegistry } from './registries/toolRegistry.js'; import type { ServerOptions } from './server.js'; import { Server } from './server.js'; +/** + * Internal options for McpServer that can include pre-created registries. + * Used by fromBuilderResult to pass registries from the builder. + */ +interface InternalMcpServerOptions extends ServerOptions { + /** Pre-created tool registry (callbacks will be bound by McpServer) */ + _toolRegistry?: ToolRegistry; + /** Pre-created resource registry (callbacks will be bound by McpServer) */ + _resourceRegistry?: ResourceRegistry; + /** Pre-created prompt registry (callbacks will be bound by McpServer) */ + _promptRegistry?: PromptRegistry; +} + /** * High-level MCP server that provides a simpler API for working with resources, tools, and prompts. * For advanced usage (like sending notifications or setting custom request handlers), use the underlying @@ -77,16 +112,200 @@ export class McpServer { */ public readonly server: Server; - private _registeredResources: { [uri: string]: RegisteredResource } = {}; - private _registeredResourceTemplates: { - [name: string]: RegisteredResourceTemplate; - } = {}; - private _registeredTools: { [name: string]: RegisteredTool } = {}; - private _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; + private readonly _toolRegistry: ToolRegistry; + private readonly _resourceRegistry: ResourceRegistry; + private readonly _resourceTemplateRegistry: ResourceTemplateRegistry; + private readonly _promptRegistry: PromptRegistry; + private readonly _middleware: MiddlewareManager; private _experimental?: { tasks: ExperimentalMcpServerTasks }; + // Error handlers (single callback pattern, not event-based) + private _onErrorHandler?: OnErrorHandler; + private _onProtocolErrorHandler?: OnProtocolErrorHandler; + constructor(serverInfo: Implementation, options?: ServerOptions) { + const internalOptions = options as InternalMcpServerOptions | undefined; this.server = new Server(serverInfo, options); + + // Use pre-created registries if provided, otherwise create new ones + // Either way, bind the notification callbacks to this server instance + this._toolRegistry = internalOptions?._toolRegistry ?? new ToolRegistry(); + this._toolRegistry.setNotifyCallback(() => this.sendToolListChanged()); + + this._resourceRegistry = internalOptions?._resourceRegistry ?? new ResourceRegistry(); + this._resourceRegistry.setNotifyCallback(() => this.sendResourceListChanged()); + + // Resource template registry is always created fresh (not passed from builder) + this._resourceTemplateRegistry = new ResourceTemplateRegistry(); + this._resourceTemplateRegistry.setNotifyCallback(() => this.sendResourceListChanged()); + + this._promptRegistry = internalOptions?._promptRegistry ?? new PromptRegistry(); + this._promptRegistry.setNotifyCallback(() => this.sendPromptListChanged()); + + // Initialize middleware manager + this._middleware = new MiddlewareManager(); + + // If registries were pre-populated, set up request handlers + if (this._toolRegistry.size > 0) { + this.setToolRequestHandlers(); + } + if (this._resourceRegistry.size > 0) { + this.setResourceRequestHandlers(); + } + if (this._promptRegistry.size > 0) { + this.setPromptRequestHandlers(); + } + } + + /** + * Gets the middleware manager for advanced middleware configuration. + */ + get middleware(): MiddlewareManager { + return this._middleware; + } + + /** + * Registers universal middleware that runs for all request types (tools, resources, prompts). + * + * @param middleware - The middleware function to register + * @returns This McpServer instance for chaining + */ + useMiddleware(middleware: UniversalMiddleware): this { + this._middleware.useMiddleware(middleware); + return this; + } + + /** + * Registers middleware specifically for tool calls. + * + * @param middleware - The tool middleware function to register + * @returns This McpServer instance for chaining + */ + useToolMiddleware(middleware: ToolMiddleware): this { + this._middleware.useToolMiddleware(middleware); + return this; + } + + /** + * Registers middleware specifically for resource reads. + * + * @param middleware - The resource middleware function to register + * @returns This McpServer instance for chaining + */ + useResourceMiddleware(middleware: ResourceMiddleware): this { + this._middleware.useResourceMiddleware(middleware); + return this; + } + + /** + * Registers middleware specifically for prompt requests. + * + * @param middleware - The prompt middleware function to register + * @returns This McpServer instance for chaining + */ + usePromptMiddleware(middleware: PromptMiddleware): this { + this._middleware.usePromptMiddleware(middleware); + return this; + } + + /** + * Gets the tool registry for advanced tool management. + */ + get tools(): ToolRegistry { + return this._toolRegistry; + } + + /** + * Gets the resource registry for advanced resource management. + */ + get resources(): ResourceRegistry { + return this._resourceRegistry; + } + + /** + * Gets the resource template registry for advanced template management. + */ + get resourceTemplates(): ResourceTemplateRegistry { + return this._resourceTemplateRegistry; + } + + /** + * Gets the prompt registry for advanced prompt management. + */ + get prompts(): PromptRegistry { + return this._promptRegistry; + } + + /** + * Creates a new McpServerBuilder for fluent configuration. + * + * @example + * ```typescript + * const server = McpServer.builder() + * .name('my-server') + * .version('1.0.0') + * .tool('greet', { name: z.string() }, async ({ name }) => ({ + * content: [{ type: 'text', text: `Hello, ${name}!` }] + * })) + * .build(); + * ``` + */ + static builder(): McpServerBuilder { + return new McpServerBuilder(); + } + + /** + * Creates an McpServer from a BuilderResult configuration. + * + * @param result - The result from McpServerBuilder.build() + * @returns A configured McpServer instance + */ + static fromBuilderResult(result: BuilderResult): McpServer { + // Create server with pre-populated registries from the builder + // The constructor will bind notification callbacks to the registries + const internalOptions: InternalMcpServerOptions = { + ...result.options, + _toolRegistry: result.registries.tools, + _resourceRegistry: result.registries.resources, + _promptRegistry: result.registries.prompts + }; + + const server = new McpServer(result.serverInfo, internalOptions); + + // Wire up error handlers + if (result.errorHandlers.onError) { + server.onError(result.errorHandlers.onError); + } + if (result.errorHandlers.onProtocolError) { + server.onProtocolError(result.errorHandlers.onProtocolError); + } + + // Apply global middleware from builder + for (const middleware of result.middleware.universal) { + server.useMiddleware(middleware); + } + for (const middleware of result.middleware.tool) { + server.useToolMiddleware(middleware); + } + for (const middleware of result.middleware.resource) { + server.useResourceMiddleware(middleware); + } + for (const middleware of result.middleware.prompt) { + server.usePromptMiddleware(middleware); + } + + // Apply per-item middleware + for (const [name, middleware] of result.perItemMiddleware.tools) { + server._middleware.useToolMiddlewareFor(name, middleware); + } + for (const [uri, middleware] of result.perItemMiddleware.resources) { + server._middleware.useResourceMiddlewareFor(uri, middleware); + } + for (const [name, middleware] of result.perItemMiddleware.prompts) { + server._middleware.usePromptMiddlewareFor(name, middleware); + } + + return server; } /** @@ -140,50 +359,18 @@ export class McpServer { this.server.setRequestHandler( ListToolsRequestSchema, (): ListToolsResult => ({ - tools: Object.entries(this._registeredTools) - .filter(([, tool]) => tool.enabled) - .map(([name, tool]): Tool => { - const toolDefinition: Tool = { - name, - title: tool.title, - description: tool.description, - inputSchema: (() => { - const obj = normalizeObjectSchema(tool.inputSchema); - return obj - ? (toJsonSchemaCompat(obj, { - strictUnions: true, - pipeStrategy: 'input' - }) as Tool['inputSchema']) - : EMPTY_OBJECT_JSON_SCHEMA; - })(), - annotations: tool.annotations, - execution: tool.execution, - _meta: tool._meta - }; - - if (tool.outputSchema) { - const obj = normalizeObjectSchema(tool.outputSchema); - if (obj) { - toolDefinition.outputSchema = toJsonSchemaCompat(obj, { - strictUnions: true, - pipeStrategy: 'output' - }) as Tool['outputSchema']; - } - } - - return toolDefinition; - }) + tools: this._toolRegistry.getProtocolTools() }) ); - this.server.setRequestHandler(CallToolRequestSchema, async (request, extra): Promise => { + this.server.setRequestHandler(CallToolRequestSchema, async (request, ctx): Promise => { try { - const tool = this._registeredTools[request.params.name]; + const tool = this._toolRegistry.getTool(request.params.name); if (!tool) { - throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} not found`); + throw ProtocolError.invalidParams(`Tool ${request.params.name} not found`); } if (!tool.enabled) { - throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} disabled`); + throw ProtocolError.invalidParams(`Tool ${request.params.name} disabled`); } const isTaskRequest = !!request.params.task; @@ -192,39 +379,55 @@ export class McpServer { // Validate task hint configuration if ((taskSupport === 'required' || taskSupport === 'optional') && !isTaskHandler) { - throw new McpError( - ErrorCode.InternalError, + throw ProtocolError.internalError( `Tool ${request.params.name} has taskSupport '${taskSupport}' but was not registered with registerToolTask` ); } // Handle taskSupport 'required' without task augmentation if (taskSupport === 'required' && !isTaskRequest) { - throw new McpError( - ErrorCode.MethodNotFound, - `Tool ${request.params.name} requires task augmentation (taskSupport: 'required')` - ); + throw ProtocolError.methodNotFound(`Tool ${request.params.name} requires task augmentation (taskSupport: 'required')`); } // Handle taskSupport 'optional' without task augmentation - automatic polling if (taskSupport === 'optional' && !isTaskRequest && isTaskHandler) { - return await this.handleAutomaticTaskPolling(tool, request, extra); + return await this.handleAutomaticTaskPolling(tool, request, ctx); } - // Normal execution path - const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); - const result = await this.executeToolHandler(tool, args, extra); + // Build middleware context + const middlewareCtx: ToolContext = { + type: 'tool', + requestId: String(ctx.mcpCtx.requestId), + authInfo: ctx.requestCtx.authInfo, + signal: ctx.requestCtx.signal, + name: request.params.name, + args: request.params.arguments + }; + + // Execute with middleware (including per-tool middleware if registered) + const perToolMiddleware = this._middleware.getToolMiddlewareFor(request.params.name); + const result = await this._middleware.executeToolMiddleware( + middlewareCtx, + async (mwCtx, modifiedArgs) => { + const argsToUse = modifiedArgs ?? mwCtx.args; + const validatedArgs = await this.validateToolInput(tool, argsToUse, request.params.name); + const handlerResult = await this.executeToolHandler(tool, validatedArgs, ctx); + + // Return CreateTaskResult immediately for task requests + if (isTaskRequest) { + return handlerResult as CallToolResult; + } - // Return CreateTaskResult immediately for task requests - if (isTaskRequest) { - return result; - } + // Validate output schema for non-task requests + await this.validateToolOutput(tool, handlerResult, request.params.name); + return handlerResult as CallToolResult; + }, + perToolMiddleware + ); - // Validate output schema for non-task requests - await this.validateToolOutput(tool, result, request.params.name); return result; } catch (error) { - if (error instanceof McpError && error.code === ErrorCode.UrlElicitationRequired) { + if (isProtocolError(error) && error.code === ErrorCode.UrlElicitationRequired) { throw error; // Return the error to the caller without wrapping in CallToolResult } return this.createToolError(error instanceof Error ? error.message : String(error)); @@ -275,7 +478,7 @@ export class McpServer { if (!parseResult.success) { const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; const errorMessage = getParseErrorMessage(error); - throw new McpError(ErrorCode.InvalidParams, `Input validation error: Invalid arguments for tool ${toolName}: ${errorMessage}`); + throw ProtocolError.invalidParams(`Input validation error: Invalid arguments for tool ${toolName}: ${errorMessage}`); } return parseResult.data as unknown as Args; @@ -299,8 +502,7 @@ export class McpServer { } if (!result.structuredContent) { - throw new McpError( - ErrorCode.InvalidParams, + throw ProtocolError.invalidParams( `Output validation error: Tool ${toolName} has an output schema but no structured content was provided` ); } @@ -311,10 +513,7 @@ export class McpServer { if (!parseResult.success) { const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; const errorMessage = getParseErrorMessage(error); - throw new McpError( - ErrorCode.InvalidParams, - `Output validation error: Invalid structured content for tool ${toolName}: ${errorMessage}` - ); + throw ProtocolError.invalidParams(`Output validation error: Invalid structured content for tool ${toolName}: ${errorMessage}`); } } @@ -324,36 +523,36 @@ export class McpServer { private async executeToolHandler( tool: RegisteredTool, args: unknown, - extra: RequestHandlerExtra + ctx: ServerContextInterface ): Promise { const handler = tool.handler as AnyToolHandler; const isTaskHandler = 'createTask' in handler; if (isTaskHandler) { - if (!extra.taskStore) { + if (!ctx.taskCtx?.store) { throw new Error('No task store provided.'); } - const taskExtra = { ...extra, taskStore: extra.taskStore }; + const taskCtx = ctx; if (tool.inputSchema) { const typedHandler = handler as ToolTaskHandler; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve(typedHandler.createTask(args as any, taskExtra)); + return await Promise.resolve(typedHandler.createTask(args as any, ctx)); } else { const typedHandler = handler as ToolTaskHandler; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((typedHandler.createTask as any)(taskExtra)); + return await Promise.resolve((typedHandler.createTask as any)(taskCtx)); } } if (tool.inputSchema) { const typedHandler = handler as ToolCallback; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve(typedHandler(args as any, extra)); + return await Promise.resolve(typedHandler(args as any, ctx)); } else { const typedHandler = handler as ToolCallback; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((typedHandler as any)(extra)); + return await Promise.resolve((typedHandler as any)(ctx)); } } @@ -363,21 +562,20 @@ export class McpServer { private async handleAutomaticTaskPolling( tool: RegisteredTool, request: RequestT, - extra: RequestHandlerExtra + ctx: ServerContextInterface ): Promise { - if (!extra.taskStore) { + if (!ctx.taskCtx?.store) { throw new Error('No task store provided for task-capable tool.'); } // Validate input and create task const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); const handler = tool.handler as ToolTaskHandler; - const taskExtra = { ...extra, taskStore: extra.taskStore }; const createTaskResult: CreateTaskResult = args // undefined only if tool.inputSchema is undefined - ? await Promise.resolve((handler as ToolTaskHandler).createTask(args, taskExtra)) + ? await Promise.resolve((handler as ToolTaskHandler).createTask(args, ctx)) : // eslint-disable-next-line @typescript-eslint/no-explicit-any - await Promise.resolve(((handler as ToolTaskHandler).createTask as any)(taskExtra)); + await Promise.resolve(((handler as ToolTaskHandler).createTask as any)(ctx)); // Poll until completion const taskId = createTaskResult.task.taskId; @@ -386,15 +584,12 @@ export class McpServer { while (task.status !== 'completed' && task.status !== 'failed' && task.status !== 'cancelled') { await new Promise(resolve => setTimeout(resolve, pollInterval)); - const updatedTask = await extra.taskStore.getTask(taskId); - if (!updatedTask) { - throw new McpError(ErrorCode.InternalError, `Task ${taskId} not found during polling`); - } + const updatedTask = await ctx.taskCtx!.store.getTask(taskId); task = updatedTask; } // Return the final result - return (await extra.taskStore.getTaskResult(taskId)) as CallToolResult; + return (await ctx.taskCtx!.store.getTaskResult(taskId)) as CallToolResult; } private _completionHandlerInitialized = false; @@ -423,7 +618,7 @@ export class McpServer { } default: { - throw new McpError(ErrorCode.InvalidParams, `Invalid completion reference: ${request.params.ref}`); + throw ProtocolError.invalidParams(`Invalid completion reference: ${request.params.ref}`); } } }); @@ -432,13 +627,13 @@ export class McpServer { } private async handlePromptCompletion(request: CompleteRequestPrompt, ref: PromptReference): Promise { - const prompt = this._registeredPrompts[ref.name]; + const prompt = this._promptRegistry.getPrompt(ref.name); if (!prompt) { - throw new McpError(ErrorCode.InvalidParams, `Prompt ${ref.name} not found`); + throw ProtocolError.invalidParams(`Prompt ${ref.name} not found`); } if (!prompt.enabled) { - throw new McpError(ErrorCode.InvalidParams, `Prompt ${ref.name} disabled`); + throw ProtocolError.invalidParams(`Prompt ${ref.name} disabled`); } if (!prompt.argsSchema) { @@ -463,18 +658,18 @@ export class McpServer { request: CompleteRequestResourceTemplate, ref: ResourceTemplateReference ): Promise { - const template = Object.values(this._registeredResourceTemplates).find(t => t.resourceTemplate.uriTemplate.toString() === ref.uri); + const template = this._resourceTemplateRegistry.values().find(t => t.template.uriTemplate.toString() === ref.uri); if (!template) { - if (this._registeredResources[ref.uri]) { + if (this._resourceRegistry.getResource(ref.uri)) { // Attempting to autocomplete a fixed resource URI is not an error in the spec (but probably should be). return EMPTY_COMPLETION_RESULT; } - throw new McpError(ErrorCode.InvalidParams, `Resource template ${request.params.ref.uri} not found`); + throw ProtocolError.invalidParams(`Resource template ${request.params.ref.uri} not found`); } - const completer = template.resourceTemplate.completeCallback(request.params.argument.name); + const completer = template.template.completeCallback(request.params.argument.name); if (!completer) { return EMPTY_COMPLETION_RESULT; } @@ -500,22 +695,16 @@ export class McpServer { } }); - this.server.setRequestHandler(ListResourcesRequestSchema, async (request, extra) => { - const resources = Object.entries(this._registeredResources) - .filter(([_, resource]) => resource.enabled) - .map(([uri, resource]) => ({ - uri, - name: resource.name, - ...resource.metadata - })); + this.server.setRequestHandler(ListResourcesRequestSchema, async (request, ctx) => { + const resources = this._resourceRegistry.getProtocolResources(); const templateResources: Resource[] = []; - for (const template of Object.values(this._registeredResourceTemplates)) { - if (!template.resourceTemplate.listCallback) { + for (const template of this._resourceTemplateRegistry.getEnabled()) { + if (!template.template.listCallback) { continue; } - const result = await template.resourceTemplate.listCallback(extra); + const result = await template.template.listCallback(ctx); for (const resource of result.resources) { templateResources.push({ ...template.metadata, @@ -528,37 +717,60 @@ export class McpServer { return { resources: [...resources, ...templateResources] }; }); - this.server.setRequestHandler(ListResourceTemplatesRequestSchema, async () => { - const resourceTemplates = Object.entries(this._registeredResourceTemplates).map(([name, template]) => ({ - name, - uriTemplate: template.resourceTemplate.uriTemplate.toString(), - ...template.metadata - })); - - return { resourceTemplates }; - }); + this.server.setRequestHandler(ListResourceTemplatesRequestSchema, async () => ({ + resourceTemplates: this._resourceTemplateRegistry.getProtocolResourceTemplates() + })); - this.server.setRequestHandler(ReadResourceRequestSchema, async (request, extra) => { + this.server.setRequestHandler(ReadResourceRequestSchema, async (request, ctx) => { const uri = new URL(request.params.uri); // First check for exact resource match - const resource = this._registeredResources[uri.toString()]; + const resource = this._resourceRegistry.getResource(uri.toString()); if (resource) { if (!resource.enabled) { - throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} disabled`); + throw ProtocolError.invalidParams(`Resource ${uri} disabled`); } - return resource.readCallback(uri, extra); + + // Build middleware context + const middlewareCtx: ResourceContext = { + type: 'resource', + requestId: String(ctx.mcpCtx.requestId), + authInfo: ctx.requestCtx.authInfo, + signal: ctx.requestCtx.signal, + uri: uri.toString() + }; + + // Execute with middleware (including per-resource middleware if registered) + const perResourceMiddleware = this._middleware.getResourceMiddlewareFor(uri.toString()); + return this._middleware.executeResourceMiddleware( + middlewareCtx, + async (mwCtx, modifiedUri) => { + const uriToUse = modifiedUri ? new URL(modifiedUri) : uri; + return resource.readCallback(uriToUse, ctx); + }, + perResourceMiddleware + ); } // Then check templates - for (const template of Object.values(this._registeredResourceTemplates)) { - const variables = template.resourceTemplate.uriTemplate.match(uri.toString()); - if (variables) { - return template.readCallback(uri, variables, extra); - } + const match = this._resourceTemplateRegistry.findMatchingTemplate(uri.toString()); + if (match) { + // Build middleware context for template + const middlewareCtx: ResourceContext = { + type: 'resource', + requestId: String(ctx.mcpCtx.requestId), + authInfo: ctx.requestCtx.authInfo, + signal: ctx.requestCtx.signal, + uri: uri.toString() + }; + + // Execute with middleware (templates don't have per-item middleware from builder) + return this._middleware.executeResourceMiddleware(middlewareCtx, async () => { + return match.template.readCallback(uri, match.variables, ctx); + }); } - throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} not found`); + throw ProtocolError.invalidParams(`Resource ${uri} not found`); }); this._resourceHandlersInitialized = true; @@ -583,274 +795,124 @@ export class McpServer { this.server.setRequestHandler( ListPromptsRequestSchema, (): ListPromptsResult => ({ - prompts: Object.entries(this._registeredPrompts) - .filter(([, prompt]) => prompt.enabled) - .map(([name, prompt]): Prompt => { - return { - name, - title: prompt.title, - description: prompt.description, - arguments: prompt.argsSchema ? promptArgumentsFromSchema(prompt.argsSchema) : undefined - }; - }) + prompts: this._promptRegistry.getProtocolPrompts() }) ); - this.server.setRequestHandler(GetPromptRequestSchema, async (request, extra): Promise => { - const prompt = this._registeredPrompts[request.params.name]; + this.server.setRequestHandler(GetPromptRequestSchema, async (request, ctx): Promise => { + const prompt = this._promptRegistry.getPrompt(request.params.name); if (!prompt) { - throw new McpError(ErrorCode.InvalidParams, `Prompt ${request.params.name} not found`); + throw ProtocolError.invalidParams(`Prompt ${request.params.name} not found`); } if (!prompt.enabled) { - throw new McpError(ErrorCode.InvalidParams, `Prompt ${request.params.name} disabled`); + throw ProtocolError.invalidParams(`Prompt ${request.params.name} disabled`); } - if (prompt.argsSchema) { - const argsObj = normalizeObjectSchema(prompt.argsSchema) as AnyObjectSchema; - const parseResult = await safeParseAsync(argsObj, request.params.arguments); - if (!parseResult.success) { - const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; - const errorMessage = getParseErrorMessage(error); - throw new McpError(ErrorCode.InvalidParams, `Invalid arguments for prompt ${request.params.name}: ${errorMessage}`); - } + // Build middleware context + const middlewareCtx: PromptContext = { + type: 'prompt', + requestId: String(ctx.mcpCtx.requestId), + authInfo: ctx.requestCtx.authInfo, + signal: ctx.requestCtx.signal, + name: request.params.name, + args: request.params.arguments + }; - const args = parseResult.data; - const cb = prompt.callback as PromptCallback; - return await Promise.resolve(cb(args, extra)); - } else { - const cb = prompt.callback as PromptCallback; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((cb as any)(extra)); - } + // Execute with middleware (including per-prompt middleware if registered) + const perPromptMiddleware = this._middleware.getPromptMiddlewareFor(request.params.name); + return this._middleware.executePromptMiddleware( + middlewareCtx, + async (mwCtx, modifiedArgs) => { + const argsToUse = modifiedArgs ?? mwCtx.args; + + if (prompt.argsSchema) { + const argsObj = normalizeObjectSchema(prompt.argsSchema) as AnyObjectSchema; + const parseResult = await safeParseAsync(argsObj, argsToUse); + if (!parseResult.success) { + const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; + const errorMessage = getParseErrorMessage(error); + throw ProtocolError.invalidParams(`Invalid arguments for prompt ${request.params.name}: ${errorMessage}`); + } + + const args = parseResult.data; + const cb = prompt.callback as PromptCallback; + return await Promise.resolve(cb(args, ctx)); + } else { + const cb = prompt.callback as PromptCallback; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return await Promise.resolve((cb as any)(ctx)); + } + }, + perPromptMiddleware + ); }); this._promptHandlersInitialized = true; } + usePlugin(plugin: ProtocolPlugin): this { + this.server.usePlugin(plugin); + return this; + } + /** * Registers a resource with a config object and callback. * For static resources, use a URI string. For dynamic resources, use a ResourceTemplate. */ - registerResource(name: string, uriOrTemplate: string, config: ResourceMetadata, readCallback: ReadResourceCallback): RegisteredResource; + registerResource( + name: string, + uriOrTemplate: string, + config: ResourceMetadata, + readCallback: ReadResourceCallback + ): RegisteredResourceEntity; registerResource( name: string, uriOrTemplate: ResourceTemplate, config: ResourceMetadata, readCallback: ReadResourceTemplateCallback - ): RegisteredResourceTemplate; + ): RegisteredResourceTemplateEntity; registerResource( name: string, uriOrTemplate: string | ResourceTemplate, config: ResourceMetadata, readCallback: ReadResourceCallback | ReadResourceTemplateCallback - ): RegisteredResource | RegisteredResourceTemplate { + ): RegisteredResourceEntity | RegisteredResourceTemplateEntity { if (typeof uriOrTemplate === 'string') { - if (this._registeredResources[uriOrTemplate]) { - throw new Error(`Resource ${uriOrTemplate} is already registered`); - } - - const registeredResource = this._createRegisteredResource( + const registeredResource = this._resourceRegistry.register({ name, - (config as BaseMetadata).title, - uriOrTemplate, - config, - readCallback as ReadResourceCallback - ); + uri: uriOrTemplate, + title: (config as BaseMetadata).title, + description: config.description, + mimeType: config.mimeType, + metadata: config, + readCallback: readCallback as ReadResourceCallback + }); this.setResourceRequestHandlers(); - this.sendResourceListChanged(); return registeredResource; } else { - if (this._registeredResourceTemplates[name]) { - throw new Error(`Resource template ${name} is already registered`); - } - - const registeredResourceTemplate = this._createRegisteredResourceTemplate( + const registeredResourceTemplate = this._resourceTemplateRegistry.register({ name, - (config as BaseMetadata).title, - uriOrTemplate, - config, - readCallback as ReadResourceTemplateCallback - ); + template: uriOrTemplate, + title: (config as BaseMetadata).title, + description: config.description, + mimeType: config.mimeType, + metadata: config, + readCallback: readCallback as ReadResourceTemplateCallback + }); this.setResourceRequestHandlers(); - this.sendResourceListChanged(); - return registeredResourceTemplate; - } - } - - private _createRegisteredResource( - name: string, - title: string | undefined, - uri: string, - metadata: ResourceMetadata | undefined, - readCallback: ReadResourceCallback - ): RegisteredResource { - const registeredResource: RegisteredResource = { - name, - title, - metadata, - readCallback, - enabled: true, - disable: () => registeredResource.update({ enabled: false }), - enable: () => registeredResource.update({ enabled: true }), - remove: () => registeredResource.update({ uri: null }), - update: updates => { - if (updates.uri !== undefined && updates.uri !== uri) { - delete this._registeredResources[uri]; - if (updates.uri) this._registeredResources[updates.uri] = registeredResource; - } - if (updates.name !== undefined) registeredResource.name = updates.name; - if (updates.title !== undefined) registeredResource.title = updates.title; - if (updates.metadata !== undefined) registeredResource.metadata = updates.metadata; - if (updates.callback !== undefined) registeredResource.readCallback = updates.callback; - if (updates.enabled !== undefined) registeredResource.enabled = updates.enabled; - this.sendResourceListChanged(); - } - }; - this._registeredResources[uri] = registeredResource; - return registeredResource; - } - - private _createRegisteredResourceTemplate( - name: string, - title: string | undefined, - template: ResourceTemplate, - metadata: ResourceMetadata | undefined, - readCallback: ReadResourceTemplateCallback - ): RegisteredResourceTemplate { - const registeredResourceTemplate: RegisteredResourceTemplate = { - resourceTemplate: template, - title, - metadata, - readCallback, - enabled: true, - disable: () => registeredResourceTemplate.update({ enabled: false }), - enable: () => registeredResourceTemplate.update({ enabled: true }), - remove: () => registeredResourceTemplate.update({ name: null }), - update: updates => { - if (updates.name !== undefined && updates.name !== name) { - delete this._registeredResourceTemplates[name]; - if (updates.name) this._registeredResourceTemplates[updates.name] = registeredResourceTemplate; - } - if (updates.title !== undefined) registeredResourceTemplate.title = updates.title; - if (updates.template !== undefined) registeredResourceTemplate.resourceTemplate = updates.template; - if (updates.metadata !== undefined) registeredResourceTemplate.metadata = updates.metadata; - if (updates.callback !== undefined) registeredResourceTemplate.readCallback = updates.callback; - if (updates.enabled !== undefined) registeredResourceTemplate.enabled = updates.enabled; - this.sendResourceListChanged(); - } - }; - this._registeredResourceTemplates[name] = registeredResourceTemplate; - - // If the resource template has any completion callbacks, enable completions capability - const variableNames = template.uriTemplate.variableNames; - const hasCompleter = Array.isArray(variableNames) && variableNames.some(v => !!template.completeCallback(v)); - if (hasCompleter) { - this.setCompletionRequestHandler(); - } - return registeredResourceTemplate; - } - - private _createRegisteredPrompt( - name: string, - title: string | undefined, - description: string | undefined, - argsSchema: PromptArgsRawShape | undefined, - callback: PromptCallback - ): RegisteredPrompt { - const registeredPrompt: RegisteredPrompt = { - title, - description, - argsSchema: argsSchema === undefined ? undefined : objectFromShape(argsSchema), - callback, - enabled: true, - disable: () => registeredPrompt.update({ enabled: false }), - enable: () => registeredPrompt.update({ enabled: true }), - remove: () => registeredPrompt.update({ name: null }), - update: updates => { - if (updates.name !== undefined && updates.name !== name) { - delete this._registeredPrompts[name]; - if (updates.name) this._registeredPrompts[updates.name] = registeredPrompt; - } - if (updates.title !== undefined) registeredPrompt.title = updates.title; - if (updates.description !== undefined) registeredPrompt.description = updates.description; - if (updates.argsSchema !== undefined) registeredPrompt.argsSchema = objectFromShape(updates.argsSchema); - if (updates.callback !== undefined) registeredPrompt.callback = updates.callback; - if (updates.enabled !== undefined) registeredPrompt.enabled = updates.enabled; - this.sendPromptListChanged(); - } - }; - this._registeredPrompts[name] = registeredPrompt; - - // If any argument uses a Completable schema, enable completions capability - if (argsSchema) { - const hasCompletable = Object.values(argsSchema).some(field => { - const inner: unknown = field instanceof ZodOptional ? field._def?.innerType : field; - return isCompletable(inner); - }); - if (hasCompletable) { + // If the resource template has any completion callbacks, enable completions capability + const variableNames = uriOrTemplate.uriTemplate.variableNames; + const hasCompleter = Array.isArray(variableNames) && variableNames.some(v => !!uriOrTemplate.completeCallback(v)); + if (hasCompleter) { this.setCompletionRequestHandler(); } - } - - return registeredPrompt; - } - private _createRegisteredTool( - name: string, - title: string | undefined, - description: string | undefined, - inputSchema: ZodRawShapeCompat | AnySchema | undefined, - outputSchema: ZodRawShapeCompat | AnySchema | undefined, - annotations: ToolAnnotations | undefined, - execution: ToolExecution | undefined, - _meta: Record | undefined, - handler: AnyToolHandler - ): RegisteredTool { - // Validate tool name according to SEP specification - validateAndWarnToolName(name); - - const registeredTool: RegisteredTool = { - title, - description, - inputSchema: getZodSchemaObject(inputSchema), - outputSchema: getZodSchemaObject(outputSchema), - annotations, - execution, - _meta, - handler: handler, - enabled: true, - disable: () => registeredTool.update({ enabled: false }), - enable: () => registeredTool.update({ enabled: true }), - remove: () => registeredTool.update({ name: null }), - update: updates => { - if (updates.name !== undefined && updates.name !== name) { - if (typeof updates.name === 'string') { - validateAndWarnToolName(updates.name); - } - delete this._registeredTools[name]; - if (updates.name) this._registeredTools[updates.name] = registeredTool; - } - if (updates.title !== undefined) registeredTool.title = updates.title; - if (updates.description !== undefined) registeredTool.description = updates.description; - if (updates.paramsSchema !== undefined) registeredTool.inputSchema = objectFromShape(updates.paramsSchema); - if (updates.outputSchema !== undefined) registeredTool.outputSchema = objectFromShape(updates.outputSchema); - if (updates.callback !== undefined) registeredTool.handler = updates.callback; - if (updates.annotations !== undefined) registeredTool.annotations = updates.annotations; - if (updates._meta !== undefined) registeredTool._meta = updates._meta; - if (updates.enabled !== undefined) registeredTool.enabled = updates.enabled; - this.sendToolListChanged(); - } - }; - this._registeredTools[name] = registeredTool; - - this.setToolRequestHandlers(); - this.sendToolListChanged(); - - return registeredTool; + return registeredResourceTemplate; + } } /** @@ -864,33 +926,33 @@ export class McpServer { inputSchema?: InputArgs; outputSchema?: OutputArgs; annotations?: ToolAnnotations; + execution?: ToolExecution; _meta?: Record; }, cb: ToolCallback ): RegisteredTool { - if (this._registeredTools[name]) { - throw new Error(`Tool ${name} is already registered`); - } + const { title, description, inputSchema, outputSchema, annotations, execution, _meta } = config; - const { title, description, inputSchema, outputSchema, annotations, _meta } = config; - - return this._createRegisteredTool( + const registeredTool = this._toolRegistry.register({ name, title, description, - inputSchema, - outputSchema, + inputSchema: getZodSchemaObject(inputSchema), + outputSchema: getZodSchemaObject(outputSchema), annotations, - { taskSupport: 'forbidden' }, + execution: execution ?? { taskSupport: 'forbidden' }, _meta, - cb as ToolCallback - ); + handler: cb as ToolCallback + }); + + this.setToolRequestHandlers(); + return registeredTool; } /** * Registers a prompt with a config object and callback. */ - registerPrompt( + registerPrompt( name: string, config: { title?: string; @@ -899,22 +961,28 @@ export class McpServer { }, cb: PromptCallback ): RegisteredPrompt { - if (this._registeredPrompts[name]) { - throw new Error(`Prompt ${name} is already registered`); - } - const { title, description, argsSchema } = config; - const registeredPrompt = this._createRegisteredPrompt( + const registeredPrompt = this._promptRegistry.register({ name, title, description, argsSchema, - cb as PromptCallback - ); + callback: cb as PromptCallback + }); this.setPromptRequestHandlers(); - this.sendPromptListChanged(); + + // If any argument uses a Completable schema, enable completions capability + if (argsSchema) { + const hasCompletable = Object.values(argsSchema).some(field => { + const inner: unknown = field instanceof ZodOptional ? field._def?.innerType : field; + return isCompletable(inner); + }); + if (hasCompletable) { + this.setCompletionRequestHandler(); + } + } return registeredPrompt; } @@ -963,6 +1031,128 @@ export class McpServer { this.server.sendPromptListChanged(); } } + + /** + * Updates the error interceptor on the underlying Server based on current handlers. + * This combines both onError and onProtocolError handlers into a single interceptor. + */ + private _updateErrorInterceptor(): void { + if (!this._onErrorHandler && !this._onProtocolErrorHandler) { + // No handlers, clear the interceptor + this.server.setErrorInterceptor(undefined); + return; + } + + this.server.setErrorInterceptor(async (error: Error, ctx: ErrorInterceptionContext): Promise => { + const errorContext: ErrorContext = { + type: ctx.type === 'protocol' ? 'protocol' : 'tool', // Map to ErrorContext type + method: ctx.method, + requestId: typeof ctx.requestId === 'string' ? ctx.requestId : String(ctx.requestId) + }; + + let result: OnErrorReturn | OnProtocolErrorReturn | void = undefined; + + if (ctx.type === 'protocol' && this._onProtocolErrorHandler) { + // Protocol error - use onProtocolError handler + result = await this._onProtocolErrorHandler(error, errorContext); + } else if (this._onErrorHandler) { + // Application error (or protocol error without specific handler) - use onError handler + result = await this._onErrorHandler(error, errorContext); + } + + if (result === undefined || result === null) { + return undefined; + } + + // Convert the handler result to ErrorInterceptionResult + if (typeof result === 'string') { + return { message: result }; + } else if (result instanceof Error) { + const errorWithCode = result as Error & { code?: number; data?: unknown }; + return { + message: result.message, + code: ctx.type === 'application' ? errorWithCode.code : undefined, + data: errorWithCode.data + }; + } else { + // Object with code/message/data + return { + message: result.message, + code: ctx.type === 'application' ? (result as OnErrorReturn & { code?: number }).code : undefined, + data: result.data + }; + } + }); + } + + /** + * Registers an error handler for application errors in tool/resource/prompt handlers. + * + * The handler receives the error and a context object with information about where + * the error occurred. It can optionally return a custom error response that will + * modify the error sent to the client. + * + * Note: This is a single-handler pattern. Setting a new handler replaces any previous one. + * The handler is awaited, so async handlers are fully supported. + * + * @param handler - Error handler function + * @returns Unsubscribe function + * + * @example + * ```typescript + * const unsubscribe = server.onError(async (error, ctx) => { + * console.error(`Error in ${ctx.type}/${ctx.method}: ${error.message}`); + * // Optionally return a custom error response + * return { + * code: -32000, + * message: `Application error: ${error.message}`, + * data: { type: ctx.type } + * }; + * }); + * ``` + */ + onError(handler: OnErrorHandler): () => void { + this._onErrorHandler = handler; + this._updateErrorInterceptor(); + return this._clearOnErrorHandler.bind(this); + } + + private _clearOnErrorHandler(): void { + this._onErrorHandler = undefined; + this._updateErrorInterceptor(); + } + + /** + * Registers an error handler for protocol errors (method not found, parse error, etc.). + * + * The handler receives the error and a context object. It can optionally return + * a custom error response. Note that the error code cannot be changed for protocol + * errors as they have fixed codes per the MCP specification. + * + * Note: This is a single-handler pattern. Setting a new handler replaces any previous one. + * The handler is awaited, so async handlers are fully supported. + * + * @param handler - Error handler function + * @returns Unsubscribe function + * + * @example + * ```typescript + * const unsubscribe = server.onProtocolError(async (error, ctx) => { + * console.error(`Protocol error in ${ctx.method}: ${error.message}`); + * return { message: `Protocol error: ${error.message}` }; + * }); + * ``` + */ + onProtocolError(handler: OnProtocolErrorHandler): () => void { + this._onProtocolErrorHandler = handler; + this._updateErrorInterceptor(); + return this._clearOnProtocolErrorHandler.bind(this); + } + + private _clearOnProtocolErrorHandler(): void { + this._onProtocolErrorHandler = undefined; + this._updateErrorInterceptor(); + } } /** @@ -1025,13 +1215,13 @@ export class ResourceTemplate { export type BaseToolCallback< SendResultT extends Result, - Extra extends RequestHandlerExtra, + Extra extends ServerContextInterface, Args extends undefined | ZodRawShapeCompat | AnySchema > = Args extends ZodRawShapeCompat - ? (args: ShapeOutput, extra: Extra) => SendResultT | Promise + ? (args: ShapeOutput, ctx: Extra) => SendResultT | Promise : Args extends AnySchema - ? (args: SchemaOutput, extra: Extra) => SendResultT | Promise - : (extra: Extra) => SendResultT | Promise; + ? (args: SchemaOutput, ctx: Extra) => SendResultT | Promise + : (ctx: Extra) => SendResultT | Promise; /** * Callback for a tool handler registered with Server.tool(). @@ -1045,7 +1235,7 @@ export type BaseToolCallback< */ export type ToolCallback = BaseToolCallback< CallToolResult, - RequestHandlerExtra, + ServerContextInterface, Args >; @@ -1054,37 +1244,6 @@ export type ToolCallback = ToolCallback | ToolTaskHandler; -export type RegisteredTool = { - title?: string; - description?: string; - inputSchema?: AnySchema; - outputSchema?: AnySchema; - annotations?: ToolAnnotations; - execution?: ToolExecution; - _meta?: Record; - handler: AnyToolHandler; - enabled: boolean; - enable(): void; - disable(): void; - update(updates: { - name?: string | null; - title?: string; - description?: string; - paramsSchema?: InputArgs; - outputSchema?: OutputArgs; - annotations?: ToolAnnotations; - _meta?: Record; - callback?: ToolCallback; - enabled?: boolean; - }): void; - remove(): void; -}; - -const EMPTY_OBJECT_JSON_SCHEMA = { - type: 'object' as const, - properties: {} -}; - /** * Checks if a value looks like a Zod schema by checking for parse/safeParse methods. */ @@ -1164,105 +1323,9 @@ export type ResourceMetadata = Omit; * Callback to list all resources matching a given template. */ export type ListResourcesCallback = ( - extra: RequestHandlerExtra + ctx: ServerContextInterface ) => ListResourcesResult | Promise; -/** - * Callback to read a resource at a given URI. - */ -export type ReadResourceCallback = ( - uri: URL, - extra: RequestHandlerExtra -) => ReadResourceResult | Promise; - -export type RegisteredResource = { - name: string; - title?: string; - metadata?: ResourceMetadata; - readCallback: ReadResourceCallback; - enabled: boolean; - enable(): void; - disable(): void; - update(updates: { - name?: string; - title?: string; - uri?: string | null; - metadata?: ResourceMetadata; - callback?: ReadResourceCallback; - enabled?: boolean; - }): void; - remove(): void; -}; - -/** - * Callback to read a resource at a given URI, following a filled-in URI template. - */ -export type ReadResourceTemplateCallback = ( - uri: URL, - variables: Variables, - extra: RequestHandlerExtra -) => ReadResourceResult | Promise; - -export type RegisteredResourceTemplate = { - resourceTemplate: ResourceTemplate; - title?: string; - metadata?: ResourceMetadata; - readCallback: ReadResourceTemplateCallback; - enabled: boolean; - enable(): void; - disable(): void; - update(updates: { - name?: string | null; - title?: string; - template?: ResourceTemplate; - metadata?: ResourceMetadata; - callback?: ReadResourceTemplateCallback; - enabled?: boolean; - }): void; - remove(): void; -}; - -type PromptArgsRawShape = ZodRawShapeCompat; - -export type PromptCallback = Args extends PromptArgsRawShape - ? (args: ShapeOutput, extra: RequestHandlerExtra) => GetPromptResult | Promise - : (extra: RequestHandlerExtra) => GetPromptResult | Promise; - -export type RegisteredPrompt = { - title?: string; - description?: string; - argsSchema?: AnyObjectSchema; - callback: PromptCallback; - enabled: boolean; - enable(): void; - disable(): void; - update(updates: { - name?: string | null; - title?: string; - description?: string; - argsSchema?: Args; - callback?: PromptCallback; - enabled?: boolean; - }): void; - remove(): void; -}; - -function promptArgumentsFromSchema(schema: AnyObjectSchema): PromptArgument[] { - const shape = getObjectShape(schema); - if (!shape) return []; - return Object.entries(shape).map(([name, field]): PromptArgument => { - // Get description - works for both v3 and v4 - const description = getSchemaDescription(field); - // Check if optional - works for both v3 and v4 - const isOptional = isSchemaOptional(field); - return { - name, - description, - required: !isOptional - }; - }); -} - function getMethodValue(schema: AnyObjectSchema): string { const shape = getObjectShape(schema); const methodSchema = shape?.method as AnySchema | undefined; diff --git a/packages/server/src/server/middleware.ts b/packages/server/src/server/middleware.ts new file mode 100644 index 000000000..04324a31c --- /dev/null +++ b/packages/server/src/server/middleware.ts @@ -0,0 +1,453 @@ +/** + * McpServer Middleware System + * + * Provides a flexible middleware system for cross-cutting concerns like + * logging, authentication, rate limiting, metrics, and caching. + * + * Design follows Express/Koa/Hono patterns with the next() pattern for + * maximum flexibility. + */ + +import type { AuthInfo, CallToolResult, GetPromptResult, ReadResourceResult } from '@modelcontextprotocol/core'; + +// ═══════════════════════════════════════════════════════════════════════════ +// Context Interfaces +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Base context shared by all middleware + */ +interface BaseMiddlewareContext { + /** The request ID from JSON-RPC */ + requestId: string; + /** Authentication info if available */ + authInfo?: AuthInfo; + /** Abort signal for cancellation */ + signal: AbortSignal; +} + +/** + * Context for tool middleware + */ +export interface ToolContext extends BaseMiddlewareContext { + type: 'tool'; + /** The name of the tool being called */ + name: string; + /** The arguments passed to the tool */ + args: unknown; +} + +/** + * Context for resource middleware + */ +export interface ResourceContext extends BaseMiddlewareContext { + type: 'resource'; + /** The URI of the resource being read */ + uri: string; +} + +/** + * Context for prompt middleware + */ +export interface PromptContext extends BaseMiddlewareContext { + type: 'prompt'; + /** The name of the prompt being requested */ + name: string; + /** The arguments passed to the prompt */ + args: unknown; +} + +/** + * Union type for all middleware contexts + */ +export type MiddlewareContext = ToolContext | ResourceContext | PromptContext; + +// ═══════════════════════════════════════════════════════════════════════════ +// Middleware Types +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Next function for tool middleware. + * Can optionally pass modified args to the handler. + */ +export type ToolNextFn = (modifiedArgs?: unknown) => Promise; + +/** + * Next function for resource middleware. + * Can optionally pass a modified URI to the handler. + */ +export type ResourceNextFn = (modifiedUri?: string) => Promise; + +/** + * Next function for prompt middleware. + * Can optionally pass modified args to the handler. + */ +export type PromptNextFn = (modifiedArgs?: unknown) => Promise; + +/** + * Next function for universal middleware. + * Can optionally pass modified input to the handler. + */ +export type UniversalNextFn = (modified?: unknown) => Promise; + +/** + * Middleware for tool calls. + * Can abort, short-circuit, modify args, or pass through. + */ +export type ToolMiddleware = (ctx: ToolContext, next: ToolNextFn) => Promise; + +/** + * Middleware for resource reads. + * Can abort, short-circuit, modify URI, or pass through. + */ +export type ResourceMiddleware = (ctx: ResourceContext, next: ResourceNextFn) => Promise; + +/** + * Middleware for prompt requests. + * Can abort, short-circuit, modify args, or pass through. + */ +export type PromptMiddleware = (ctx: PromptContext, next: PromptNextFn) => Promise; + +/** + * Universal middleware that works for all types. + * Use the `type` property on the context to differentiate. + */ +export type UniversalMiddleware = (ctx: MiddlewareContext, next: UniversalNextFn) => Promise; + +// ═══════════════════════════════════════════════════════════════════════════ +// Middleware Chain Builder +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Composes multiple middleware functions into a single function. + * Each middleware can: + * - Abort with error: throw + * - Short-circuit: return result without calling next() + * - Modify input: call next(modified) + * - Pass through: call next() + * + * @param middlewares - Array of middleware functions + * @param handler - The final handler to call + * @returns A composed function that runs all middleware and the handler + */ +export function composeMiddleware( + middlewares: Array<(ctx: TCtx, next: (input?: TInput) => Promise) => Promise>, + handler: (ctx: TCtx, input?: TInput) => Promise +): (ctx: TCtx, initialInput?: TInput) => Promise { + return async (ctx: TCtx, initialInput?: TInput): Promise => { + let index = -1; + let currentInput: TInput | undefined = initialInput; + + const dispatch = async (i: number, input?: TInput): Promise => { + if (i <= index) { + throw new Error('next() called multiple times'); + } + index = i; + currentInput = input ?? currentInput; + + if (i >= middlewares.length) { + // All middleware processed, call the final handler + return handler(ctx, currentInput); + } + + const middleware = middlewares[i]; + if (!middleware) { + return handler(ctx, currentInput); + } + return middleware(ctx, (modifiedInput?: TInput) => dispatch(i + 1, modifiedInput)); + }; + + return dispatch(0, initialInput); + }; +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Middleware Manager +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Manages middleware registration and execution for McpServer. + */ +export class MiddlewareManager { + private _universalMiddleware: UniversalMiddleware[] = []; + private _toolMiddleware: ToolMiddleware[] = []; + private _resourceMiddleware: ResourceMiddleware[] = []; + private _promptMiddleware: PromptMiddleware[] = []; + + // Per-item middleware (keyed by name/uri) + private _perToolMiddleware = new Map(); + private _perResourceMiddleware = new Map(); + private _perPromptMiddleware = new Map(); + + /** + * Registers universal middleware that runs for all request types. + */ + useMiddleware(middleware: UniversalMiddleware): this { + this._universalMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware specifically for tool calls. + */ + useToolMiddleware(middleware: ToolMiddleware): this { + this._toolMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware specifically for resource reads. + */ + useResourceMiddleware(middleware: ResourceMiddleware): this { + this._resourceMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware specifically for prompt requests. + */ + usePromptMiddleware(middleware: PromptMiddleware): this { + this._promptMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware for a specific tool by name. + */ + useToolMiddlewareFor(name: string, middleware: ToolMiddleware): this { + this._perToolMiddleware.set(name, middleware); + return this; + } + + /** + * Registers middleware for a specific resource by URI. + */ + useResourceMiddlewareFor(uri: string, middleware: ResourceMiddleware): this { + this._perResourceMiddleware.set(uri, middleware); + return this; + } + + /** + * Registers middleware for a specific prompt by name. + */ + usePromptMiddlewareFor(name: string, middleware: PromptMiddleware): this { + this._perPromptMiddleware.set(name, middleware); + return this; + } + + /** + * Gets per-tool middleware if registered. + */ + getToolMiddlewareFor(name: string): ToolMiddleware | undefined { + return this._perToolMiddleware.get(name); + } + + /** + * Gets per-resource middleware if registered. + */ + getResourceMiddlewareFor(uri: string): ResourceMiddleware | undefined { + return this._perResourceMiddleware.get(uri); + } + + /** + * Gets per-prompt middleware if registered. + */ + getPromptMiddlewareFor(name: string): PromptMiddleware | undefined { + return this._perPromptMiddleware.get(name); + } + + /** + * Executes tool middleware chain with the given context and handler. + */ + async executeToolMiddleware( + ctx: ToolContext, + handler: (ctx: ToolContext, args?: unknown) => Promise, + perRegistrationMiddleware?: ToolMiddleware + ): Promise { + // Build middleware chain: universal -> tool-specific -> per-registration + const chain: ToolMiddleware[] = []; + + // Add universal middleware (cast to tool middleware) + for (const mw of this._universalMiddleware) { + chain.push(async (c, next) => { + return (await mw(c, async modified => { + return next(modified as unknown); + })) as CallToolResult; + }); + } + + // Add tool-specific middleware + chain.push(...this._toolMiddleware); + + // Add per-registration middleware if provided + if (perRegistrationMiddleware) { + chain.push(perRegistrationMiddleware); + } + + // Compose and execute + const composed = composeMiddleware(chain, handler); + return composed(ctx, ctx.args); + } + + /** + * Executes resource middleware chain with the given context and handler. + */ + async executeResourceMiddleware( + ctx: ResourceContext, + handler: (ctx: ResourceContext, uri?: string) => Promise, + perRegistrationMiddleware?: ResourceMiddleware + ): Promise { + // Build middleware chain: universal -> resource-specific -> per-registration + const chain: ResourceMiddleware[] = []; + + // Add universal middleware (cast to resource middleware) + for (const mw of this._universalMiddleware) { + chain.push(async (c, next) => { + return (await mw(c, async modified => { + return next(modified as string); + })) as ReadResourceResult; + }); + } + + // Add resource-specific middleware + chain.push(...this._resourceMiddleware); + + // Add per-registration middleware if provided + if (perRegistrationMiddleware) { + chain.push(perRegistrationMiddleware); + } + + // Compose and execute + const composed = composeMiddleware(chain, handler); + return composed(ctx, ctx.uri); + } + + /** + * Executes prompt middleware chain with the given context and handler. + */ + async executePromptMiddleware( + ctx: PromptContext, + handler: (ctx: PromptContext, args?: unknown) => Promise, + perRegistrationMiddleware?: PromptMiddleware + ): Promise { + // Build middleware chain: universal -> prompt-specific -> per-registration + const chain: PromptMiddleware[] = []; + + // Add universal middleware (cast to prompt middleware) + for (const mw of this._universalMiddleware) { + chain.push(async (c, next) => { + return (await mw(c, async modified => { + return next(modified as unknown); + })) as GetPromptResult; + }); + } + + // Add prompt-specific middleware + chain.push(...this._promptMiddleware); + + // Add per-registration middleware if provided + if (perRegistrationMiddleware) { + chain.push(perRegistrationMiddleware); + } + + // Compose and execute + const composed = composeMiddleware(chain, handler); + return composed(ctx, ctx.args); + } + + /** + * Checks if any middleware is registered. + */ + hasMiddleware(): boolean { + return ( + this._universalMiddleware.length > 0 || + this._toolMiddleware.length > 0 || + this._resourceMiddleware.length > 0 || + this._promptMiddleware.length > 0 || + this._perToolMiddleware.size > 0 || + this._perResourceMiddleware.size > 0 || + this._perPromptMiddleware.size > 0 + ); + } + + /** + * Clears all registered middleware. + */ + clear(): void { + this._universalMiddleware = []; + this._toolMiddleware = []; + this._resourceMiddleware = []; + this._promptMiddleware = []; + this._perToolMiddleware.clear(); + this._perResourceMiddleware.clear(); + this._perPromptMiddleware.clear(); + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Built-in Middleware Factories +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Options for the logging middleware. + */ +export interface LoggingMiddlewareOptions { + /** Log level: 'debug', 'info', 'warn', 'error' */ + level?: 'debug' | 'info' | 'warn' | 'error'; + /** Custom logger function */ + logger?: (level: string, message: string, data?: unknown) => void; +} + +/** + * Creates a logging middleware that logs all requests. + * + * @example + * ```typescript + * server.useMiddleware(createLoggingMiddleware({ level: 'debug' })); + * ``` + */ +export function createLoggingMiddleware(options: LoggingMiddlewareOptions = {}): UniversalMiddleware { + const { level = 'info', logger = console.log } = options; + + return async (ctx, next) => { + const identifier = ctx.type === 'resource' ? ctx.uri : ctx.name; + logger(level, `β†’ ${ctx.type}: ${identifier}`, { + type: ctx.type, + requestId: ctx.requestId + }); + + const start = Date.now(); + + try { + const result = await next(); + const duration = Date.now() - start; + logger(level, `← ${ctx.type}: ${identifier} (${duration}ms)`, { + type: ctx.type, + requestId: ctx.requestId, + duration + }); + return result; + } catch (error) { + const duration = Date.now() - start; + logger('error', `βœ— ${ctx.type}: ${identifier} (${duration}ms)`, { + type: ctx.type, + requestId: ctx.requestId, + duration, + error + }); + throw error; + } + }; +} + +/** + * Options for the rate limit middleware. + */ +export interface RateLimitMiddlewareOptions { + /** Maximum requests per time window */ + max: number; + /** Time window in milliseconds */ + windowMs?: number; + /** Error message when rate limited */ + message?: string; +} diff --git a/packages/server/src/server/registries/baseRegistry.ts b/packages/server/src/server/registries/baseRegistry.ts new file mode 100644 index 000000000..fa168e2b5 --- /dev/null +++ b/packages/server/src/server/registries/baseRegistry.ts @@ -0,0 +1,229 @@ +/** + * Base Registry + * + * Abstract base class for managing collections of registered entities + * (tools, resources, prompts). Provides common functionality for + * CRUD operations and notifications. + */ + +/** + * Base interface for all registered definitions + */ +export interface RegisteredDefinition { + /** + * Whether the definition is currently enabled + */ + enabled: boolean; + + /** + * Enable the definition + */ + enable(): this; + + /** + * Disable the definition + */ + disable(): this; + + /** + * Remove the definition from its registry + */ + remove(): void; +} + +/** + * Callback type for registry change notifications + */ +export type RegistryNotifyCallback = () => void; + +/** + * Abstract base class for registries. + * Provides common functionality for managing collections of registered entities. + * + * @template T - The type of registered entity this registry manages + */ +export abstract class BaseRegistry { + /** + * Internal storage for registered items + */ + protected _items = new Map(); + + /** + * Optional callback for change notifications. + * Can be set after construction via setNotifyCallback(). + */ + protected _notifyCallback?: RegistryNotifyCallback; + + /** + * Sets or updates the notification callback. + * This allows the callback to be bound after construction (e.g., by McpServer + * when using registries created by the builder). + * + * @param callback - The callback to invoke when the registry changes + */ + setNotifyCallback(callback: RegistryNotifyCallback): void { + this._notifyCallback = callback; + } + + /** + * Called when the registry contents change. + * Invokes the notification callback if one is set. + */ + protected notifyChanged(): void { + this._notifyCallback?.(); + } + + /** + * Checks if an item with the given ID exists in the registry. + * + * @param id - The identifier to check + * @returns true if the item exists + */ + has(id: string): boolean { + return this._items.has(id); + } + + /** + * Gets an item by its ID. + * + * @param id - The identifier of the item + * @returns The item or undefined if not found + */ + get(id: string): T | undefined { + return this._items.get(id); + } + + /** + * Gets all items in the registry as a read-only map. + * + * @returns A read-only map of all items + */ + getAll(): ReadonlyMap { + return this._items; + } + + /** + * Gets all items as an array. + * + * @returns Array of all registered items + */ + values(): T[] { + return [...this._items.values()]; + } + + /** + * Gets all enabled items as an array. + * + * @returns Array of enabled items + */ + getEnabled(): T[] { + return this.values().filter(item => item.enabled); + } + + /** + * Gets all disabled items as an array. + * + * @returns Array of disabled items + */ + getDisabled(): T[] { + return this.values().filter(item => !item.enabled); + } + + /** + * Gets the number of items in the registry. + */ + get size(): number { + return this._items.size; + } + + /** + * Removes an item from the registry. + * + * @param id - The identifier of the item to remove + * @returns true if the item was removed, false if it didn't exist + */ + remove(id: string): boolean { + const deleted = this._items.delete(id); + if (deleted) { + this.notifyChanged(); + } + return deleted; + } + + /** + * Disables all items in the registry. + */ + disableAll(): void { + let changed = false; + for (const item of this._items.values()) { + if (item.enabled) { + item.disable(); + changed = true; + } + } + if (changed) { + this.notifyChanged(); + } + } + + /** + * Enables all items in the registry. + */ + enableAll(): void { + let changed = false; + for (const item of this._items.values()) { + if (!item.enabled) { + item.enable(); + changed = true; + } + } + if (changed) { + this.notifyChanged(); + } + } + + /** + * Clears all items from the registry. + */ + clear(): void { + if (this._items.size > 0) { + this._items.clear(); + this.notifyChanged(); + } + } + + /** + * Internal method to add or update an item in the registry. + * Used by subclasses during registration. + * + * @param id - The identifier for the item + * @param item - The item to add + */ + protected _set(id: string, item: T): void { + this._items.set(id, item); + } + + /** + * Internal method to rename an item in the registry. + * + * @param oldId - The current identifier + * @param newId - The new identifier + * @returns true if renamed successfully + */ + protected _rename(oldId: string, newId: string): boolean { + const item = this._items.get(oldId); + if (!item) { + return false; + } + if (oldId === newId) { + return true; + } + if (this._items.has(newId)) { + throw new Error(`Cannot rename: '${newId}' already exists`); + } + this._items.delete(oldId); + this._items.set(newId, item); + this.notifyChanged(); + return true; + } +} diff --git a/packages/server/src/server/registries/index.ts b/packages/server/src/server/registries/index.ts new file mode 100644 index 000000000..f0ba60a77 --- /dev/null +++ b/packages/server/src/server/registries/index.ts @@ -0,0 +1,10 @@ +/** + * Registries Module + * + * Exports registry classes and entities for managing tools, resources, and prompts. + */ + +export * from './baseRegistry.js'; +export * from './promptRegistry.js'; +export * from './resourceRegistry.js'; +export * from './toolRegistry.js'; diff --git a/packages/server/src/server/registries/promptRegistry.ts b/packages/server/src/server/registries/promptRegistry.ts new file mode 100644 index 000000000..8a9aac658 --- /dev/null +++ b/packages/server/src/server/registries/promptRegistry.ts @@ -0,0 +1,241 @@ +/** + * Prompt Registry + * + * Manages registration and retrieval of prompts. + * Provides class-based RegisteredPromptEntity entities with proper encapsulation. + */ + +import type { AnyObjectSchema, AnySchema, Prompt, PromptArgument, ZodRawShapeCompat } from '@modelcontextprotocol/core'; +import { getObjectShape, getSchemaDescription, isSchemaOptional, objectFromShape } from '@modelcontextprotocol/core'; + +import type { PromptCallback, RegisteredPromptInterface } from '../../types/types.js'; +import { BaseRegistry } from './baseRegistry.js'; + +/** + * Configuration for registering a prompt + */ +export interface PromptConfig { + name: string; + title?: string; + description?: string; + argsSchema?: ZodRawShapeCompat; + callback: PromptCallback; +} + +/** + * Updates that can be applied to a registered prompt + */ +export interface PromptUpdates { + name?: string | null; + title?: string; + description?: string; + argsSchema?: ZodRawShapeCompat; + callback?: PromptCallback; + enabled?: boolean; +} + +/** + * Class-based representation of a registered prompt. + * Provides methods for managing the prompt's lifecycle. + */ +export class RegisteredPrompt implements RegisteredPromptInterface { + #name: string; + #enabled: boolean = true; + readonly #registry: PromptRegistry; + + #title?: string; + #description?: string; + #argsSchema?: AnyObjectSchema; + #callback: PromptCallback; + + constructor(config: PromptConfig, registry: PromptRegistry) { + this.#name = config.name; + this.#registry = registry; + this.#title = config.title; + this.#description = config.description; + this.#argsSchema = config.argsSchema ? objectFromShape(config.argsSchema) : undefined; + this.#callback = config.callback; + } + + /** The prompt's name (identifier) */ + get name(): string { + return this.#name; + } + + /** Whether the prompt is currently enabled */ + get enabled(): boolean { + return this.#enabled; + } + + /** The prompt's title */ + get title(): string | undefined { + return this.#title; + } + + /** The prompt's description */ + get description(): string | undefined { + return this.#description; + } + + /** The prompt's args schema */ + get argsSchema(): AnyObjectSchema | undefined { + return this.#argsSchema; + } + + /** The prompt's callback */ + get callback(): PromptCallback { + return this.#callback; + } + + /** + * Enables the prompt + */ + enable(): this { + if (!this.#enabled) { + this.#enabled = true; + this.#registry['notifyChanged'](); + } + return this; + } + + /** + * Disables the prompt + */ + disable(): this { + if (this.#enabled) { + this.#enabled = false; + this.#registry['notifyChanged'](); + } + return this; + } + + /** + * Removes the prompt from its registry + */ + remove(): void { + this.#registry.remove(this.#name); + } + + /** + * Renames the prompt + * + * @param newName - The new name for the prompt + */ + rename(newName: string): this { + this.#registry['_rename'](this.#name, newName); + this.#name = newName; + return this; + } + + /** + * Updates the prompt's properties + * + * @param updates - The updates to apply + */ + update(updates: PromptUpdates): void { + if (updates.name !== undefined) { + if (updates.name === null) { + this.remove(); + return; + } + this.rename(updates.name); + } + if (updates.title !== undefined) this.#title = updates.title; + if (updates.description !== undefined) this.#description = updates.description; + if (updates.argsSchema !== undefined) this.#argsSchema = objectFromShape(updates.argsSchema); + if (updates.callback !== undefined) this.#callback = updates.callback; + if (updates.enabled === undefined) { + this.#registry['notifyChanged'](); + } else { + if (updates.enabled) { + this.enable(); + } else { + this.disable(); + } + } + } + + /** + * Converts to the Prompt protocol type (for list responses) + */ + toProtocolPrompt(): Prompt { + return { + name: this.#name, + title: this.#title, + description: this.#description, + arguments: this.#argsSchema ? promptArgumentsFromSchema(this.#argsSchema) : undefined + }; + } +} + +/** + * Registry for managing prompts. + */ +export class PromptRegistry extends BaseRegistry { + /** + * Creates a new PromptRegistry. + * + * @param sendNotification - Optional callback to invoke when the prompt list changes. + * Can be set later via setNotifyCallback(). + */ + constructor(sendNotification?: () => void) { + super(); + if (sendNotification) { + this.setNotifyCallback(sendNotification); + } + } + + /** + * Registers a new prompt. + * + * @param config - The prompt configuration + * @returns The registered prompt + * @throws If a prompt with the same name already exists + */ + register(config: PromptConfig): RegisteredPrompt { + if (this._items.has(config.name)) { + throw new Error(`Prompt '${config.name}' is already registered`); + } + + const prompt = new RegisteredPrompt(config, this); + this._set(config.name, prompt); + this.notifyChanged(); + return prompt; + } + + /** + * Gets the list of enabled prompts in protocol format. + * + * @returns Array of Prompt objects for the protocol response + */ + getProtocolPrompts(): Prompt[] { + return this.getEnabled().map(prompt => prompt.toProtocolPrompt()); + } + + /** + * Gets a prompt by name. + * + * @param name - The prompt name + * @returns The registered prompt or undefined + */ + getPrompt(name: string): RegisteredPrompt | undefined { + return this.get(name); + } +} + +/** + * Converts a Zod object schema to an array of PromptArgument for the protocol. + */ +function promptArgumentsFromSchema(schema: AnyObjectSchema): PromptArgument[] { + const shape = getObjectShape(schema); + if (!shape) return []; + return Object.entries(shape).map(([name, field]): PromptArgument => { + const description = getSchemaDescription(field as AnySchema); + const isOptional = isSchemaOptional(field as AnySchema); + return { + name, + description, + required: !isOptional + }; + }); +} diff --git a/packages/server/src/server/registries/resourceRegistry.ts b/packages/server/src/server/registries/resourceRegistry.ts new file mode 100644 index 000000000..75b07b535 --- /dev/null +++ b/packages/server/src/server/registries/resourceRegistry.ts @@ -0,0 +1,497 @@ +/** + * Resource Registry + * + * Manages registration and retrieval of resources and resource templates. + * Provides class-based RegisteredResourceEntity entities with proper encapsulation. + */ + +import type { Resource, ResourceTemplateType as ResourceTemplateProtocol, Variables } from '@modelcontextprotocol/core'; + +import type { ReadResourceCallback, ReadResourceTemplateCallback } from '../../types/types.js'; +import type { ResourceMetadata, ResourceTemplate } from '../mcp.js'; +import type { RegisteredDefinition } from './baseRegistry.js'; +import { BaseRegistry } from './baseRegistry.js'; + +/** + * Configuration for registering a static resource + */ +export interface ResourceConfig { + name: string; + uri: string; + title?: string; + description?: string; + mimeType?: string; + metadata?: ResourceMetadata; + readCallback: ReadResourceCallback; +} + +/** + * Configuration for registering a resource template + */ +export interface ResourceTemplateConfig { + name: string; + template: ResourceTemplate; + title?: string; + description?: string; + mimeType?: string; + metadata?: ResourceMetadata; + readCallback: ReadResourceTemplateCallback; +} + +/** + * Updates that can be applied to a registered resource + */ +export interface ResourceUpdates { + name?: string; + uri?: string | null; + title?: string; + description?: string; + mimeType?: string; + metadata?: ResourceMetadata; + callback?: ReadResourceCallback; + enabled?: boolean; +} + +/** + * Updates that can be applied to a registered resource template + */ +export interface ResourceTemplateUpdates { + name?: string | null; + title?: string; + description?: string; + mimeType?: string; + template?: ResourceTemplate; + metadata?: ResourceMetadata; + callback?: ReadResourceTemplateCallback; + enabled?: boolean; +} + +/** + * Class-based representation of a registered static resource. + * Provides methods for managing the resource's lifecycle. + */ +export class RegisteredResourceEntity implements RegisteredDefinition { + private _name: string; + private _uri: string; + private _enabled: boolean = true; + private readonly _registry: ResourceRegistry; + + private _title?: string; + private _description?: string; + private _mimeType?: string; + private _metadata?: ResourceMetadata; + private _readCallback: ReadResourceCallback; + + constructor(config: ResourceConfig, registry: ResourceRegistry) { + this._name = config.name; + this._uri = config.uri; + this._registry = registry; + this._title = config.title; + this._description = config.description; + this._mimeType = config.mimeType; + this._metadata = config.metadata; + this._readCallback = config.readCallback; + } + + /** The resource's name */ + get name(): string { + return this._name; + } + + /** The resource's URI */ + get uri(): string { + return this._uri; + } + + /** Whether the resource is currently enabled */ + get enabled(): boolean { + return this._enabled; + } + + /** The resource's title */ + get title(): string | undefined { + return this._title; + } + + /** The resource's description */ + get description(): string | undefined { + return this._description; + } + + /** The resource's MIME type */ + get mimeType(): string | undefined { + return this._mimeType; + } + + /** The resource's metadata */ + get metadata(): ResourceMetadata | undefined { + return this._metadata; + } + + /** The resource's read callback */ + get readCallback(): ReadResourceCallback { + return this._readCallback; + } + + /** + * Enables the resource + */ + enable(): this { + if (!this._enabled) { + this._enabled = true; + this._registry['notifyChanged'](); + } + return this; + } + + /** + * Disables the resource + */ + disable(): this { + if (this._enabled) { + this._enabled = false; + this._registry['notifyChanged'](); + } + return this; + } + + /** + * Removes the resource from its registry + */ + remove(): void { + this._registry.remove(this._uri); + } + + /** + * Updates the resource's properties + * + * @param updates - The updates to apply + */ + update(updates: ResourceUpdates): void { + if (updates.uri !== undefined) { + if (updates.uri === null) { + this.remove(); + return; + } + // Handle URI change - need to re-register under new URI + const oldUri = this._uri; + this._uri = updates.uri; + this._registry['_items'].delete(oldUri); + this._registry['_items'].set(updates.uri, this); + } + if (updates.name !== undefined) this._name = updates.name; + if (updates.title !== undefined) this._title = updates.title; + if (updates.description !== undefined) this._description = updates.description; + if (updates.mimeType !== undefined) this._mimeType = updates.mimeType; + if (updates.metadata !== undefined) this._metadata = updates.metadata; + if (updates.callback !== undefined) this._readCallback = updates.callback; + if (updates.enabled === undefined) { + this._registry['notifyChanged'](); + } else { + if (updates.enabled) { + this.enable(); + } else { + this.disable(); + } + } + } + + /** + * Converts to the Resource protocol type (for list responses) + */ + toProtocolResource(): Resource { + return { + uri: this._uri, + name: this._name, + title: this._title, + description: this._description, + mimeType: this._mimeType, + ...this._metadata + }; + } +} + +/** + * Class-based representation of a registered resource template. + * Provides methods for managing the template's lifecycle. + */ +export class RegisteredResourceTemplateEntity implements RegisteredDefinition { + private _name: string; + private _enabled: boolean = true; + private readonly _registry: ResourceTemplateRegistry; + + private _title?: string; + private _description?: string; + private _mimeType?: string; + private _metadata?: ResourceMetadata; + private _template: ResourceTemplate; + private _readCallback: ReadResourceTemplateCallback; + + constructor(config: ResourceTemplateConfig, registry: ResourceTemplateRegistry) { + this._name = config.name; + this._registry = registry; + this._title = config.title; + this._description = config.description; + this._mimeType = config.mimeType; + this._metadata = config.metadata; + this._template = config.template; + this._readCallback = config.readCallback; + } + + /** The template's name (identifier) */ + get name(): string { + return this._name; + } + + /** Whether the template is currently enabled */ + get enabled(): boolean { + return this._enabled; + } + + /** The template's title */ + get title(): string | undefined { + return this._title; + } + + /** The template's description */ + get description(): string | undefined { + return this._description; + } + + /** The template's MIME type */ + get mimeType(): string | undefined { + return this._mimeType; + } + + /** The template's metadata */ + get metadata(): ResourceMetadata | undefined { + return this._metadata; + } + + /** The resource template */ + get template(): ResourceTemplate { + return this._template; + } + + /** Alias for template for backward compatibility */ + get resourceTemplate(): ResourceTemplate { + return this._template; + } + + /** The template's read callback */ + get readCallback(): ReadResourceTemplateCallback { + return this._readCallback; + } + + /** + * Enables the template + */ + enable(): this { + if (!this._enabled) { + this._enabled = true; + this._registry['notifyChanged'](); + } + return this; + } + + /** + * Disables the template + */ + disable(): this { + if (this._enabled) { + this._enabled = false; + this._registry['notifyChanged'](); + } + return this; + } + + /** + * Removes the template from its registry + */ + remove(): void { + this._registry.remove(this._name); + } + + /** + * Renames the template + * + * @param newName - The new name for the template + */ + rename(newName: string): this { + this._registry['_rename'](this._name, newName); + this._name = newName; + return this; + } + + /** + * Updates the template's properties + * + * @param updates - The updates to apply + */ + update(updates: ResourceTemplateUpdates): void { + if (updates.name !== undefined) { + if (updates.name === null) { + this.remove(); + return; + } + this.rename(updates.name); + } + if (updates.title !== undefined) this._title = updates.title; + if (updates.description !== undefined) this._description = updates.description; + if (updates.mimeType !== undefined) this._mimeType = updates.mimeType; + if (updates.metadata !== undefined) this._metadata = updates.metadata; + if (updates.template !== undefined) this._template = updates.template; + if (updates.callback !== undefined) this._readCallback = updates.callback; + if (updates.enabled === undefined) { + this._registry['notifyChanged'](); + } else { + if (updates.enabled) { + this.enable(); + } else { + this.disable(); + } + } + } + + /** + * Converts to the ResourceTemplate protocol type (for list responses) + */ + toProtocolResourceTemplate(): ResourceTemplateProtocol { + return { + name: this._name, + uriTemplate: this._template.uriTemplate.toString(), + title: this._title, + description: this._description, + mimeType: this._mimeType, + ...this._metadata + }; + } +} + +/** + * Registry for managing static resources. + * Resources are keyed by URI. + */ +export class ResourceRegistry extends BaseRegistry { + /** + * Creates a new ResourceRegistry. + * + * @param sendNotification - Optional callback to invoke when the resource list changes. + * Can be set later via setNotifyCallback(). + */ + constructor(sendNotification?: () => void) { + super(); + if (sendNotification) { + this.setNotifyCallback(sendNotification); + } + } + + /** + * Registers a new resource. + * + * @param config - The resource configuration + * @returns The registered resource + * @throws If a resource with the same URI already exists + */ + register(config: ResourceConfig): RegisteredResourceEntity { + if (this._items.has(config.uri)) { + throw new Error(`Resource '${config.uri}' is already registered`); + } + + const resource = new RegisteredResourceEntity(config, this); + this._set(config.uri, resource); + this.notifyChanged(); + return resource; + } + + /** + * Gets the list of enabled resources in protocol format. + * + * @returns Array of Resource objects for the protocol response + */ + getProtocolResources(): Resource[] { + return this.getEnabled().map(resource => resource.toProtocolResource()); + } + + /** + * Gets a resource by URI. + * + * @param uri - The resource URI + * @returns The registered resource or undefined + */ + getResource(uri: string): RegisteredResourceEntity | undefined { + return this.get(uri); + } +} + +/** + * Registry for managing resource templates. + * Templates are keyed by name. + */ +export class ResourceTemplateRegistry extends BaseRegistry { + /** + * Creates a new ResourceTemplateRegistry. + * + * @param sendNotification - Optional callback to invoke when the template list changes. + * Can be set later via setNotifyCallback(). + */ + constructor(sendNotification?: () => void) { + super(); + if (sendNotification) { + this.setNotifyCallback(sendNotification); + } + } + + /** + * Registers a new resource template. + * + * @param config - The template configuration + * @returns The registered template + * @throws If a template with the same name already exists + */ + register(config: ResourceTemplateConfig): RegisteredResourceTemplateEntity { + if (this._items.has(config.name)) { + throw new Error(`Resource template '${config.name}' is already registered`); + } + + const template = new RegisteredResourceTemplateEntity(config, this); + this._set(config.name, template); + this.notifyChanged(); + return template; + } + + /** + * Gets the list of enabled templates in protocol format. + * + * @returns Array of ResourceTemplate objects for the protocol response + */ + getProtocolResourceTemplates(): ResourceTemplateProtocol[] { + return this.getEnabled().map(template => template.toProtocolResourceTemplate()); + } + + /** + * Gets a template by name. + * + * @param name - The template name + * @returns The registered template or undefined + */ + getTemplate(name: string): RegisteredResourceTemplateEntity | undefined { + return this.get(name); + } + + /** + * Finds a template that matches the given URI. + * + * @param uri - The URI to match against templates + * @returns The matching template and extracted variables, or undefined + */ + findMatchingTemplate(uri: string): { template: RegisteredResourceTemplateEntity; variables: Variables } | undefined { + for (const template of this.getEnabled()) { + const variables = template.template.uriTemplate.match(uri); + if (variables) { + return { template, variables }; + } + } + return undefined; + } +} diff --git a/packages/server/src/server/registries/toolRegistry.ts b/packages/server/src/server/registries/toolRegistry.ts new file mode 100644 index 000000000..995893a34 --- /dev/null +++ b/packages/server/src/server/registries/toolRegistry.ts @@ -0,0 +1,303 @@ +/** + * Tool Registry + * + * Manages registration and retrieval of tools. + * Provides class-based RegisteredTool entities with proper encapsulation. + */ + +import type { AnySchema, Tool, ToolAnnotations, ToolExecution, ZodRawShapeCompat } from '@modelcontextprotocol/core'; +import { normalizeObjectSchema, toJsonSchemaCompat, validateAndWarnToolName } from '@modelcontextprotocol/core'; + +import type { RegisteredToolInterface } from '../../types/types.js'; +import type { AnyToolHandler } from '../mcp.js'; +import { BaseRegistry } from './baseRegistry.js'; + +/** + * Configuration for registering a tool + */ +export interface ToolConfig { + name: string; + title?: string; + description?: string; + inputSchema?: AnySchema; + outputSchema?: AnySchema; + annotations?: ToolAnnotations; + execution?: ToolExecution; + _meta?: Record; + handler: AnyToolHandler; +} + +/** + * Updates that can be applied to a registered tool + */ +export interface ToolUpdates { + name?: string | null; + title?: string; + description?: string; + inputSchema?: AnySchema; + outputSchema?: AnySchema; + annotations?: ToolAnnotations; + execution?: ToolExecution; + _meta?: Record; + handler?: AnyToolHandler; + enabled?: boolean; +} + +const EMPTY_OBJECT_JSON_SCHEMA = { + type: 'object' as const, + properties: {} +}; + +/** + * Class-based representation of a registered tool. + * Provides methods for managing the tool's lifecycle. + */ +export class RegisteredTool implements RegisteredToolInterface { + #name: string; + #enabled: boolean = true; + readonly #registry: ToolRegistry; + + #title?: string; + #description?: string; + #inputSchema?: AnySchema; + #outputSchema?: AnySchema; + #annotations?: ToolAnnotations; + #execution?: ToolExecution; + #__meta?: Record; + #handler: AnyToolHandler; + + constructor(config: ToolConfig, registry: ToolRegistry) { + this.#name = config.name; + this.#registry = registry; + this.#title = config.title; + this.#description = config.description; + this.#inputSchema = config.inputSchema; + this.#outputSchema = config.outputSchema; + this.#annotations = config.annotations; + this.#execution = config.execution; + this.#__meta = config._meta; + this.#handler = config.handler; + } + + /** The tool's name (identifier) */ + get name(): string { + return this.#name; + } + + /** Whether the tool is currently enabled */ + get enabled(): boolean { + return this.#enabled; + } + + /** The tool's title */ + get title(): string | undefined { + return this.#title; + } + + /** The tool's description */ + get description(): string | undefined { + return this.#description; + } + + /** The tool's input schema */ + get inputSchema(): AnySchema | undefined { + return this.#inputSchema; + } + + /** The tool's output schema */ + get outputSchema(): AnySchema | undefined { + return this.#outputSchema; + } + + /** The tool's annotations */ + get annotations(): ToolAnnotations | undefined { + return this.#annotations; + } + + /** The tool's execution settings */ + get execution(): ToolExecution | undefined { + return this.#execution; + } + + /** The tool's metadata */ + get _meta(): Record | undefined { + return this.#__meta; + } + + /** The tool's handler function */ + get handler(): AnyToolHandler { + return this.#handler; + } + + /** + * Enables the tool + */ + enable(): this { + if (!this.#enabled) { + this.#enabled = true; + this.#registry['notifyChanged'](); + } + return this; + } + + /** + * Disables the tool + */ + disable(): this { + if (this.#enabled) { + this.#enabled = false; + this.#registry['notifyChanged'](); + } + return this; + } + + /** + * Removes the tool from its registry + */ + remove(): void { + this.#registry.remove(this.#name); + } + + /** + * Renames the tool + * + * @param newName - The new name for the tool + */ + rename(newName: string): this { + validateAndWarnToolName(newName); + this.#registry['_rename'](this.#name, newName); + this.#name = newName; + return this; + } + + /** + * Updates the tool's properties + * + * @param updates - The updates to apply + */ + update(updates: { + name?: string | null; + title?: string; + description?: string; + inputSchema?: InputArgs; + outputSchema?: OutputArgs; + annotations?: ToolAnnotations; + _meta?: Record; + handler?: AnyToolHandler; + execution?: ToolExecution; + enabled?: boolean; + }): void { + if (updates.name !== undefined) { + if (updates.name === null) { + this.remove(); + return; + } + this.rename(updates.name); + } + if (updates.title !== undefined) this.#title = updates.title; + if (updates.description !== undefined) this.#description = updates.description; + if (updates.inputSchema !== undefined) this.#inputSchema = updates.inputSchema; + if (updates.outputSchema !== undefined) this.#outputSchema = updates.outputSchema; + if (updates.annotations !== undefined) this.#annotations = updates.annotations; + if (updates.execution !== undefined) this.#execution = updates.execution; + if (updates._meta !== undefined) this.#__meta = updates._meta; + if (updates.handler !== undefined) this.#handler = updates.handler; + if (updates.enabled === undefined) { + this.#registry['notifyChanged'](); + } else { + if (updates.enabled) { + this.enable(); + } else { + this.disable(); + } + } + } + + /** + * Converts to the Tool protocol type (for list responses) + */ + toProtocolTool(): Tool { + const tool: Tool = { + name: this.#name, + title: this.#title, + description: this.#description, + inputSchema: this.#inputSchema + ? (toJsonSchemaCompat(normalizeObjectSchema(this.#inputSchema) ?? this.#inputSchema, { + strictUnions: true, + pipeStrategy: 'input' + }) as Tool['inputSchema']) + : EMPTY_OBJECT_JSON_SCHEMA, + annotations: this.#annotations, + execution: this.#execution, + _meta: this.#__meta + }; + + if (this.#outputSchema) { + const obj = normalizeObjectSchema(this.#outputSchema); + if (obj) { + tool.outputSchema = toJsonSchemaCompat(obj, { + strictUnions: true, + pipeStrategy: 'output' + }) as Tool['outputSchema']; + } + } + + return tool; + } +} + +/** + * Registry for managing tools. + */ +export class ToolRegistry extends BaseRegistry { + /** + * Creates a new ToolRegistry. + * + * @param sendNotification - Optional callback to invoke when the tool list changes. + * Can be set later via setNotifyCallback(). + */ + constructor(sendNotification?: () => void) { + super(); + if (sendNotification) { + this.setNotifyCallback(sendNotification); + } + } + + /** + * Registers a new tool. + * + * @param config - The tool configuration + * @returns The registered tool + * @throws If a tool with the same name already exists + */ + register(config: ToolConfig): RegisteredTool { + if (this._items.has(config.name)) { + throw new Error(`Tool '${config.name}' is already registered`); + } + + validateAndWarnToolName(config.name); + const tool = new RegisteredTool(config, this); + this._set(config.name, tool); + this.notifyChanged(); + return tool; + } + + /** + * Gets the list of enabled tools in protocol format. + * + * @returns Array of Tool objects for the protocol response + */ + getProtocolTools(): Tool[] { + return this.getEnabled().map(tool => tool.toProtocolTool()); + } + + /** + * Gets a tool by name. + * + * @param name - The tool name + * @returns The registered tool or undefined + */ + getTool(name: string): RegisteredTool | undefined { + return this.get(name); + } +} diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index f705d6b01..f6fa9f346 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -1,6 +1,8 @@ import type { AnyObjectSchema, + BaseRequestContext, ClientCapabilities, + ContextInterface, CreateMessageRequest, CreateMessageRequestParamsBase, CreateMessageRequestParamsWithTools, @@ -9,19 +11,22 @@ import type { ElicitRequestFormParams, ElicitRequestURLParams, ElicitResult, + ErrorInterceptor, Implementation, InitializeRequest, InitializeResult, + JSONRPCRequest, JsonSchemaType, jsonSchemaValidator, ListRootsRequest, LoggingLevel, LoggingMessageNotification, + McpContext, + MessageExtraInfo, Notification, NotificationOptions, ProtocolOptions, Request, - RequestHandlerExtra, RequestOptions, ResourceUpdatedNotification, Result, @@ -32,6 +37,7 @@ import type { ServerResult, ToolResultContent, ToolUseContent, + Transport, ZodV3Internal, ZodV4Internal } from '@modelcontextprotocol/core'; @@ -41,12 +47,12 @@ import { assertToolsCallTaskCapability, CallToolRequestSchema, CallToolResultSchema, + CapabilityError, CreateMessageResultSchema, CreateMessageResultWithToolsSchema, CreateTaskResultSchema, ElicitResultSchema, EmptyResultSchema, - ErrorCode, getObjectShape, InitializedNotificationSchema, InitializeRequestSchema, @@ -54,15 +60,18 @@ import { LATEST_PROTOCOL_VERSION, ListRootsResultSchema, LoggingLevelSchema, - McpError, mergeCapabilities, Protocol, + ProtocolError, safeParse, SetLevelRequestSchema, + StateError, SUPPORTED_PROTOCOL_VERSIONS } from '@modelcontextprotocol/core'; import { ExperimentalServerTasks } from '../experimental/tasks/server.js'; +import type { ServerRequestContext } from './context.js'; +import { ServerContext } from './context.js'; export type ServerOptions = ProtocolOptions & { /** @@ -166,9 +175,10 @@ export class Server< this.setNotificationHandler(InitializedNotificationSchema, () => this.oninitialized?.()); if (this._capabilities.logging) { - this.setRequestHandler(SetLevelRequestSchema, async (request, extra) => { + this.setRequestHandler(SetLevelRequestSchema, async (request, ctx) => { + const serverCtx = ctx as ServerContext; const transportSessionId: string | undefined = - extra.sessionId || (extra.requestInfo?.headers.get('mcp-session-id') as string) || undefined; + serverCtx.mcpCtx.sessionId || (serverCtx.requestCtx.headers.get('mcp-session-id') as string) || undefined; const { level } = request.params; const parseResult = LoggingLevelSchema.safeParse(level); if (parseResult.success) { @@ -214,11 +224,37 @@ export class Server< */ public registerCapabilities(capabilities: ServerCapabilities): void { if (this.transport) { - throw new Error('Cannot register capabilities after connecting to transport'); + throw StateError.registrationAfterConnect('capabilities'); } this._capabilities = mergeCapabilities(this._capabilities, capabilities); } + /** + * Sets an error interceptor that can customize error responses before they are sent. + * + * The interceptor is called for both protocol errors (method not found, etc.) and + * application errors (when a handler throws). It can modify the error message and data. + * For application errors, it can also change the error code. + * + * @param interceptor - The error interceptor function, or undefined to clear + * + * @example + * ```typescript + * server.setErrorInterceptor(async (error, ctx) => { + * console.error(`Error in ${ctx.method}: ${error.message}`); + * if (ctx.type === 'application') { + * return { + * message: 'An error occurred', + * data: { originalMessage: error.message } + * }; + * } + * }); + * ``` + */ + public override setErrorInterceptor(interceptor: ErrorInterceptor | undefined): void { + super.setErrorInterceptor(interceptor); + } + /** * Override request handler registration to enforce server-side validation for tools/call. */ @@ -226,9 +262,31 @@ export class Server< requestSchema: T, handler: ( request: SchemaOutput, - extra: RequestHandlerExtra + extra: ServerContext ) => ServerResult | ResultT | Promise ): void { + // Wrap the handler to ensure the context is a ServerContext and return a decorated handler that can be passed to the base implementation + + // Factory function to create a handler decorator that ensures the context is a ServerContext and returns a decorated handler that can be passed to the base implementation + const handlerDecoratorFactory = ( + innerHandler: ( + request: SchemaOutput, + ctx: ServerContext + ) => ServerResult | ResultT | Promise + ) => { + const decoratedHandler = ( + request: SchemaOutput, + ctx: ContextInterface + ) => { + if (!this.isContextExtra(ctx)) { + throw new Error('Internal error: Expected ServerContext for request handler context'); + } + return innerHandler(request, ctx); + }; + + return decoratedHandler; + }; + const shape = getObjectShape(requestSchema); const methodSchema = shape?.method; if (!methodSchema) { @@ -255,18 +313,18 @@ export class Server< if (method === 'tools/call') { const wrappedHandler = async ( request: SchemaOutput, - extra: RequestHandlerExtra + ctx: ContextInterface ): Promise => { const validatedRequest = safeParse(CallToolRequestSchema, request); if (!validatedRequest.success) { const errorMessage = validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid tools/call request: ${errorMessage}`); } const { params } = validatedRequest.data; - const result = await Promise.resolve(handler(request, extra)); + const result = await Promise.resolve(handlerDecoratorFactory(handler)(request, ctx)); // When task creation is requested, validate and return CreateTaskResult if (params.task) { @@ -276,7 +334,7 @@ export class Server< taskValidationResult.error instanceof Error ? taskValidationResult.error.message : String(taskValidationResult.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid task creation result: ${errorMessage}`); } return taskValidationResult.data; } @@ -286,39 +344,46 @@ export class Server< if (!validationResult.success) { const errorMessage = validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid tools/call result: ${errorMessage}`); } return validationResult.data; }; // Install the wrapped handler - return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler); + return super.setRequestHandler(requestSchema, handlerDecoratorFactory(wrappedHandler)); } // Other handlers use default behavior - return super.setRequestHandler(requestSchema, handler); + return super.setRequestHandler(requestSchema, handlerDecoratorFactory(handler)); + } + + // Runtime type guard: ensure extra is our ServerContext + private isContextExtra( + extra: ContextInterface + ): extra is ServerContext { + return extra instanceof ServerContext; } protected assertCapabilityForMethod(method: RequestT['method']): void { switch (method as ServerRequest['method']) { case 'sampling/createMessage': { if (!this._clientCapabilities?.sampling) { - throw new Error(`Client does not support sampling (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('sampling', method); } break; } case 'elicitation/create': { if (!this._clientCapabilities?.elicitation) { - throw new Error(`Client does not support elicitation (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('elicitation', method); } break; } case 'roots/list': { if (!this._clientCapabilities?.roots) { - throw new Error(`Client does not support listing roots (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('roots', method); } break; } @@ -334,7 +399,7 @@ export class Server< switch (method as ServerNotification['method']) { case 'notifications/message': { if (!this._capabilities.logging) { - throw new Error(`Server does not support logging (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('logging', method); } break; } @@ -342,28 +407,28 @@ export class Server< case 'notifications/resources/updated': case 'notifications/resources/list_changed': { if (!this._capabilities.resources) { - throw new Error(`Server does not support notifying about resources (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('resources', method); } break; } case 'notifications/tools/list_changed': { if (!this._capabilities.tools) { - throw new Error(`Server does not support notifying of tool list changes (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('tools', method); } break; } case 'notifications/prompts/list_changed': { if (!this._capabilities.prompts) { - throw new Error(`Server does not support notifying of prompt list changes (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('prompts', method); } break; } case 'notifications/elicitation/complete': { if (!this._clientCapabilities?.elicitation?.url) { - throw new Error(`Client does not support URL elicitation (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('elicitation.url', method); } break; } @@ -390,14 +455,14 @@ export class Server< switch (method) { case 'completion/complete': { if (!this._capabilities.completions) { - throw new Error(`Server does not support completions (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('completions', method); } break; } case 'logging/setLevel': { if (!this._capabilities.logging) { - throw new Error(`Server does not support logging (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('logging', method); } break; } @@ -405,7 +470,7 @@ export class Server< case 'prompts/get': case 'prompts/list': { if (!this._capabilities.prompts) { - throw new Error(`Server does not support prompts (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('prompts', method); } break; } @@ -414,7 +479,7 @@ export class Server< case 'resources/templates/list': case 'resources/read': { if (!this._capabilities.resources) { - throw new Error(`Server does not support resources (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('resources', method); } break; } @@ -422,7 +487,7 @@ export class Server< case 'tools/call': case 'tools/list': { if (!this._capabilities.tools) { - throw new Error(`Server does not support tools (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('tools', method); } break; } @@ -432,7 +497,7 @@ export class Server< case 'tasks/result': case 'tasks/cancel': { if (!this._capabilities.tasks) { - throw new Error(`Server does not support tasks capability (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('tasks', method); } break; } @@ -493,6 +558,40 @@ export class Server< return this._capabilities; } + protected createRequestContext(args: { + request: JSONRPCRequest; + abortController: AbortController; + capturedTransport: Transport | undefined; + extra?: MessageExtraInfo; + }): ContextInterface { + const { request, abortController, capturedTransport, extra } = args; + const sessionId = capturedTransport?.sessionId; + + // Build the MCP context using the helper from Protocol + const mcpContext: McpContext = this.buildMcpContext({ request, sessionId }); + + // Build the server request context with HTTP details (server-specific) + const requestCtx: ServerRequestContext = { + signal: abortController.signal, + authInfo: extra?.authInfo, + // URL is not available in MessageExtraInfo, use a placeholder + uri: new URL('mcp://request'), + headers: extra?.requestInfo?.headers ?? new Headers(), + stream: { + closeSSEStream: extra?.closeSSEStream, + closeStandaloneSSEStream: extra?.closeStandaloneSSEStream + } + }; + + // Return a ServerContext instance (task context is added by plugins if needed) + return new ServerContext({ + server: this, + request, + mcpContext, + requestCtx + }); + } + async ping() { return this.request({ method: 'ping' }, EmptyResultSchema); } @@ -525,7 +624,7 @@ export class Server< ): Promise { // Capability check - only required when tools/toolChoice are provided if ((params.tools || params.toolChoice) && !this._clientCapabilities?.sampling?.tools) { - throw new Error('Client does not support sampling tools capability.'); + throw CapabilityError.clientDoesNotSupport('sampling.tools', 'sampling/createMessage'); } // Message structure validation - always validate tool_use/tool_result pairs. @@ -583,7 +682,7 @@ export class Server< switch (mode) { case 'url': { if (!this._clientCapabilities?.elicitation?.url) { - throw new Error('Client does not support url elicitation.'); + throw CapabilityError.clientDoesNotSupport('elicitation.url', 'elicitation/create'); } const urlParams = params as ElicitRequestURLParams; @@ -591,7 +690,7 @@ export class Server< } case 'form': { if (!this._clientCapabilities?.elicitation?.form) { - throw new Error('Client does not support form elicitation.'); + throw CapabilityError.clientDoesNotSupport('elicitation.form', 'elicitation/create'); } const formParams: ElicitRequestFormParams = @@ -605,17 +704,15 @@ export class Server< const validationResult = validator(result.content); if (!validationResult.valid) { - throw new McpError( - ErrorCode.InvalidParams, + throw ProtocolError.invalidParams( `Elicitation response content does not match requested schema: ${validationResult.errorMessage}` ); } } catch (error) { - if (error instanceof McpError) { + if (error instanceof ProtocolError) { throw error; } - throw new McpError( - ErrorCode.InternalError, + throw ProtocolError.internalError( `Error validating elicitation response: ${error instanceof Error ? error.message : String(error)}` ); } @@ -635,7 +732,7 @@ export class Server< */ createElicitationCompletionNotifier(elicitationId: string, options?: NotificationOptions): () => Promise { if (!this._clientCapabilities?.elicitation?.url) { - throw new Error('Client does not support URL elicitation (required for notifications/elicitation/complete)'); + throw CapabilityError.clientDoesNotSupport('elicitation.url', 'notifications/elicitation/complete'); } return () => diff --git a/packages/server/src/server/streamableHttp.ts b/packages/server/src/server/streamableHttp.ts index ae8bad97e..a51c5da4b 100644 --- a/packages/server/src/server/streamableHttp.ts +++ b/packages/server/src/server/streamableHttp.ts @@ -17,6 +17,7 @@ import { isJSONRPCRequest, isJSONRPCResultResponse, JSONRPCMessageSchema, + StateError, SUPPORTED_PROTOCOL_VERSIONS } from '@modelcontextprotocol/core'; @@ -244,7 +245,7 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { */ async start(): Promise { if (this._started) { - throw new Error('Transport already started'); + throw StateError.alreadyConnected(); } this._started = true; } diff --git a/packages/server/src/types/types.ts b/packages/server/src/types/types.ts new file mode 100644 index 000000000..10ff99e26 --- /dev/null +++ b/packages/server/src/types/types.ts @@ -0,0 +1,155 @@ +import type { + AnyObjectSchema, + AnySchema, + GetPromptResult, + ReadResourceResult, + ServerNotification, + ServerRequest, + ShapeOutput, + ToolAnnotations, + ToolExecution, + Variables, + ZodRawShapeCompat +} from '@modelcontextprotocol/core'; + +import type { ServerContextInterface } from '../server/context.js'; +import type { AnyToolHandler, ResourceMetadata, ResourceTemplate, ToolCallback } from '../server/mcp.js'; + +/** + * Base interface for all registered definitions + */ +export interface RegisteredDefinition { + /** + * Whether the definition is currently enabled + */ + enabled: boolean; + + /** + * Enable the definition + */ + enable(): void; + + /** + * Disable the definition + */ + disable(): void; + + /** + * Remove the definition from its registry + */ + remove(): void; + + /** + * Update the definition + */ + update(updates: unknown): void; +} + +export interface RegisteredToolInterface extends RegisteredDefinition { + title?: string; + description?: string; + inputSchema?: AnySchema; + outputSchema?: AnySchema; + annotations?: ToolAnnotations; + execution?: ToolExecution; + _meta?: Record; + handler: AnyToolHandler; + enabled: boolean; + enable(): void; + disable(): void; + update(updates: { + name?: string | null; + title?: string; + description?: string; + inputSchema?: InputArgs; + outputSchema?: OutputArgs; + annotations?: ToolAnnotations; + _meta?: Record; + callback?: ToolCallback; + enabled?: boolean; + }): void; + remove(): void; +} + +/** + * Callback to read a resource at a given URI. + */ +export type ReadResourceCallback = ( + uri: URL, + ctx: ServerContextInterface +) => ReadResourceResult | Promise; + +export interface RegisteredResourceInterface extends RegisteredDefinition { + name: string; + title?: string; + metadata?: ResourceMetadata; + readCallback: ReadResourceCallback; + enabled: boolean; + enable(): void; + disable(): void; + update(updates: { + name?: string; + title?: string; + uri?: string | null; + metadata?: ResourceMetadata; + callback?: ReadResourceCallback; + enabled?: boolean; + }): void; + remove(): void; +} + +/** + * Callback to read a resource at a given URI, following a filled-in URI template. + */ +export type ReadResourceTemplateCallback = ( + uri: URL, + variables: Variables, + ctx: ServerContextInterface +) => ReadResourceResult | Promise; + +export interface RegisteredResourceTemplateInterface extends RegisteredDefinition { + resourceTemplate: ResourceTemplate; + title?: string; + metadata?: ResourceMetadata; + readCallback: ReadResourceTemplateCallback; + enabled: boolean; + enable(): void; + disable(): void; + update(updates: { + name?: string | null; + title?: string; + template?: ResourceTemplate; + metadata?: ResourceMetadata; + callback?: ReadResourceTemplateCallback; + enabled?: boolean; + }): void; + remove(): void; +} + +export type PromptArgsRawShape = ZodRawShapeCompat; + +export type PromptCallback = Args extends PromptArgsRawShape + ? ( + args: ShapeOutput, + ctx: ServerContextInterface + ) => GetPromptResult | Promise + : (ctx: ServerContextInterface) => GetPromptResult | Promise; + +export interface RegisteredPromptInterface extends RegisteredDefinition { + title?: string; + description?: string; + argsSchema?: AnyObjectSchema; + callback: PromptCallback; + enabled: boolean; + enable(): void; + disable(): void; + update(updates: { + name?: string | null; + title?: string; + description?: string; + argsSchema?: Args; + callback?: PromptCallback; + enabled?: boolean; + }): void; + remove(): void; +} diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index 4213eb9ff..04288d9f7 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -57,8 +57,8 @@ linkWorkspacePackages: deep minimumReleaseAge: 10080 # 7 days minimumReleaseAgeExclude: - '@modelcontextprotocol/conformance' - - hono@4.11.4 # fixes https://github.com/advisories/GHSA-3vhc-576x-3qv4 https://github.com/advisories/GHSA-f67f-6cw9-8mq4 - - '@hono/node-server@1.19.9' # https://github.com/honojs/node-server/pull/295 + - hono@4.11.4 # fixes https://github.com/advisories/GHSA-3vhc-576x-3qv4 https://github.com/advisories/GHSA-f67f-6cw9-8mq4 + - '@hono/node-server@1.19.9' # https://github.com/honojs/node-server/pull/295 onlyBuiltDependencies: - better-sqlite3 diff --git a/src/conformance/everything-server.ts b/src/conformance/everything-server.ts index 7f75ae3e2..a7d7a8f21 100644 --- a/src/conformance/everything-server.ts +++ b/src/conformance/everything-server.ts @@ -11,13 +11,17 @@ import { randomUUID } from 'node:crypto'; import type { CallToolResult, GetPromptResult, ReadResourceResult, EventId, EventStore, StreamId } from '@modelcontextprotocol/server'; import { + audio, CompleteRequestSchema, ElicitResultSchema, + embeddedResource, + image, isInitializeRequest, - SetLevelRequestSchema, McpServer, ResourceTemplate, + SetLevelRequestSchema, SubscribeRequestSchema, + text, UnsubscribeRequestSchema } from '@modelcontextprotocol/server'; import { NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/node'; @@ -127,7 +131,7 @@ function createMcpServer(sessionId?: string) { }, async (): Promise => { return { - content: [{ type: 'text', text: 'This is a simple text response for testing.' }] + content: [text('This is a simple text response for testing.')] }; } ); @@ -140,7 +144,7 @@ function createMcpServer(sessionId?: string) { }, async (): Promise => { return { - content: [{ type: 'image', data: TEST_IMAGE_BASE64, mimeType: 'image/png' }] + content: [image(TEST_IMAGE_BASE64, 'image/png')] }; } ); @@ -153,7 +157,7 @@ function createMcpServer(sessionId?: string) { }, async (): Promise => { return { - content: [{ type: 'audio', data: TEST_AUDIO_BASE64, mimeType: 'audio/wav' }] + content: [audio(TEST_AUDIO_BASE64, 'audio/wav')] }; } ); @@ -167,14 +171,11 @@ function createMcpServer(sessionId?: string) { async (): Promise => { return { content: [ - { - type: 'resource', - resource: { - uri: 'test://embedded-resource', - mimeType: 'text/plain', - text: 'This is an embedded resource content.' - } - } + embeddedResource({ + uri: 'test://embedded-resource', + mimeType: 'text/plain', + text: 'This is an embedded resource content.' + }) ] }; } @@ -189,16 +190,13 @@ function createMcpServer(sessionId?: string) { async (): Promise => { return { content: [ - { type: 'text', text: 'Multiple content types test:' }, - { type: 'image', data: TEST_IMAGE_BASE64, mimeType: 'image/png' }, - { - type: 'resource', - resource: { - uri: 'test://mixed-content-resource', - mimeType: 'application/json', - text: JSON.stringify({ test: 'data', value: 123 }) - } - } + text('Multiple content types test:'), + image(TEST_IMAGE_BASE64, 'image/png'), + embeddedResource({ + uri: 'test://mixed-content-resource', + mimeType: 'application/json', + text: JSON.stringify({ test: 'data', value: 123 }) + }) ] }; } @@ -211,8 +209,8 @@ function createMcpServer(sessionId?: string) { description: 'Tests tool that emits log messages during execution', inputSchema: {} }, - async (_args, extra): Promise => { - await extra.sendNotification({ + async (_args, ctx): Promise => { + await ctx.sendNotification({ method: 'notifications/message', params: { level: 'info', @@ -221,7 +219,7 @@ function createMcpServer(sessionId?: string) { }); await new Promise(resolve => setTimeout(resolve, 50)); - await extra.sendNotification({ + await ctx.sendNotification({ method: 'notifications/message', params: { level: 'info', @@ -230,7 +228,7 @@ function createMcpServer(sessionId?: string) { }); await new Promise(resolve => setTimeout(resolve, 50)); - await extra.sendNotification({ + await ctx.sendNotification({ method: 'notifications/message', params: { level: 'info', @@ -238,7 +236,7 @@ function createMcpServer(sessionId?: string) { } }); return { - content: [{ type: 'text', text: 'Tool with logging executed successfully' }] + content: [text('Tool with logging executed successfully')] }; } ); @@ -250,10 +248,10 @@ function createMcpServer(sessionId?: string) { description: 'Tests tool that reports progress notifications', inputSchema: {} }, - async (_args, extra): Promise => { - const progressToken = extra._meta?.progressToken ?? 0; + async (_args, ctx): Promise => { + const progressToken = ctx.mcpCtx._meta?.progressToken ?? 0; console.log('Progress token:', progressToken); - await extra.sendNotification({ + await ctx.sendNotification({ method: 'notifications/progress', params: { progressToken, @@ -264,7 +262,7 @@ function createMcpServer(sessionId?: string) { }); await new Promise(resolve => setTimeout(resolve, 50)); - await extra.sendNotification({ + await ctx.sendNotification({ method: 'notifications/progress', params: { progressToken, @@ -275,7 +273,7 @@ function createMcpServer(sessionId?: string) { }); await new Promise(resolve => setTimeout(resolve, 50)); - await extra.sendNotification({ + await ctx.sendNotification({ method: 'notifications/progress', params: { progressToken, @@ -286,7 +284,7 @@ function createMcpServer(sessionId?: string) { }); return { - content: [{ type: 'text', text: String(progressToken) }] + content: [text(String(progressToken))] }; } ); @@ -310,23 +308,23 @@ function createMcpServer(sessionId?: string) { 'Tests SSE stream disconnection and client reconnection (SEP-1699). Server will close the stream mid-call and send the result after client reconnects.', inputSchema: {} }, - async (_args, extra): Promise => { + async (_args, ctx): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); - console.log(`[${extra.sessionId}] Starting test_reconnection tool...`); + console.log(`[${ctx.mcpCtx.sessionId}] Starting test_reconnection tool...`); // Get the transport for this session - const transport = extra.sessionId ? transports[extra.sessionId] : undefined; - if (transport && extra.requestId) { + const transport = ctx.mcpCtx.sessionId ? transports[ctx.mcpCtx.sessionId] : undefined; + if (transport && ctx.mcpCtx.requestId) { // Close the SSE stream to trigger client reconnection - console.log(`[${extra.sessionId}] Closing SSE stream to trigger client polling...`); - transport.closeSSEStream(extra.requestId); + console.log(`[${ctx.mcpCtx.sessionId}] Closing SSE stream to trigger client polling...`); + transport.closeSSEStream(ctx.mcpCtx.requestId); } // Wait for client to reconnect (should respect retry field) await sleep(100); - console.log(`[${extra.sessionId}] test_reconnection tool complete`); + console.log(`[${ctx.mcpCtx.sessionId}] test_reconnection tool complete`); return { content: [ @@ -348,10 +346,10 @@ function createMcpServer(sessionId?: string) { prompt: z.string().describe('The prompt to send to the LLM') } }, - async (args: { prompt: string }, extra): Promise => { + async (args: { prompt: string }, ctx): Promise => { try { // Request sampling from client - const result = (await extra.sendRequest( + const result = (await ctx.sendRequest( { method: 'sampling/createMessage', params: { @@ -402,10 +400,10 @@ function createMcpServer(sessionId?: string) { message: z.string().describe('The message to show the user') } }, - async (args: { message: string }, extra): Promise => { + async (args: { message: string }, ctx): Promise => { try { // Request user input from client - const result = await extra.sendRequest( + const result = await ctx.sendRequest( { method: 'elicitation/create', params: { @@ -454,10 +452,10 @@ function createMcpServer(sessionId?: string) { description: 'Tests elicitation with default values per SEP-1034', inputSchema: {} }, - async (_args, extra): Promise => { + async (_args, ctx): Promise => { try { // Request user input with default values for all primitive types - const result = await extra.sendRequest( + const result = await ctx.sendRequest( { method: 'elicitation/create', params: { @@ -528,10 +526,10 @@ function createMcpServer(sessionId?: string) { description: 'Tests elicitation with enum schema improvements per SEP-1330', inputSchema: {} }, - async (_args, extra): Promise => { + async (_args, ctx): Promise => { try { // Request user input with all 5 enum schema variants - const result = await extra.sendRequest( + const result = await ctx.sendRequest( { method: 'elicitation/create', params: { @@ -819,21 +817,15 @@ function createMcpServer(sessionId?: string) { messages: [ { role: 'user', - content: { - type: 'resource', - resource: { - uri: args.resourceUri, - mimeType: 'text/plain', - text: 'Embedded resource content for testing.' - } - } + content: embeddedResource({ + uri: args.resourceUri, + mimeType: 'text/plain', + text: 'Embedded resource content for testing.' + }) }, { role: 'user', - content: { - type: 'text', - text: 'Please process the embedded resource above.' - } + content: text('Please process the embedded resource above.') } ] }; @@ -852,15 +844,11 @@ function createMcpServer(sessionId?: string) { messages: [ { role: 'user', - content: { - type: 'image', - data: TEST_IMAGE_BASE64, - mimeType: 'image/png' - } + content: image(TEST_IMAGE_BASE64, 'image/png') }, { role: 'user', - content: { type: 'text', text: 'Please analyze the image above.' } + content: text('Please analyze the image above.') } ] }; diff --git a/test/integration/test/client/client.test.ts b/test/integration/test/client/client.test.ts index ed1ea7d67..6aa5db8fa 100644 --- a/test/integration/test/client/client.test.ts +++ b/test/integration/test/client/client.test.ts @@ -19,8 +19,8 @@ import { ListRootsRequestSchema, ListToolsRequestSchema, ListToolsResultSchema, - McpError, NotificationSchema, + ProtocolError, RequestSchema, ResultSchema, SUPPORTED_PROTOCOL_VERSIONS @@ -1218,7 +1218,7 @@ test('should handle client cancelling a request', async () => { ); // Set up server to delay responding to listResources - server.setRequestHandler(ListResourcesRequestSchema, async (request, extra) => { + server.setRequestHandler(ListResourcesRequestSchema, async (request, ctx) => { await new Promise(resolve => setTimeout(resolve, 1000)); return { resources: [] @@ -1248,8 +1248,8 @@ test('should handle client cancelling a request', async () => { }); controller.abort('Cancelled by test'); - // Request should be rejected with an McpError - await expect(listResourcesPromise).rejects.toThrow(McpError); + // Request should be rejected with an ProtocolError + await expect(listResourcesPromise).rejects.toThrow(ProtocolError); }); /*** @@ -1269,10 +1269,10 @@ test('should handle request timeout', async () => { ); // Set up server with a delayed response - server.setRequestHandler(ListResourcesRequestSchema, async (_request, extra) => { + server.setRequestHandler(ListResourcesRequestSchema, async (_request, ctx) => { const timer = new Promise(resolve => { const timeout = setTimeout(resolve, 100); - extra.signal.addEventListener('abort', () => clearTimeout(timeout)); + ctx.requestCtx.signal.addEventListener('abort', () => clearTimeout(timeout)); }); await timer; @@ -2442,27 +2442,27 @@ describe('Task-based execution', () => { inputSchema: {} }, { - async createTask(_args, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask(_args, ctx) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); const result = { content: [{ type: 'text', text: 'Tool executed successfully!' }] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.taskCtx!.store.getTask(ctx.taskCtx!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.taskCtx!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -2518,27 +2518,27 @@ describe('Task-based execution', () => { inputSchema: {} }, { - async createTask(_args, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask(_args, ctx) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); const result = { content: [{ type: 'text', text: 'Success!' }] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.taskCtx!.store.getTask(ctx.taskCtx!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.taskCtx!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -2595,27 +2595,27 @@ describe('Task-based execution', () => { inputSchema: {} }, { - async createTask(_args, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask(_args, ctx) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); const result = { content: [{ type: 'text', text: 'Result data!' }] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.taskCtx!.store.getTask(ctx.taskCtx!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.taskCtx!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -2676,27 +2676,27 @@ describe('Task-based execution', () => { inputSchema: {} }, { - async createTask(_args, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask(_args, ctx) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); const result = { content: [{ type: 'text', text: 'Success!' }] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.taskCtx!.store.getTask(ctx.taskCtx!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.taskCtx!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -2773,18 +2773,18 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + client.setRequestHandler(ElicitRequestSchema, async (request, ctx) => { const result = { action: 'accept', content: { username: 'list-user' } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.taskCtx!.store) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } @@ -2866,18 +2866,18 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + client.setRequestHandler(ElicitRequestSchema, async (request, ctx) => { const result = { action: 'accept', content: { username: 'list-user' } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.taskCtx!.store) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } @@ -2958,18 +2958,18 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + client.setRequestHandler(ElicitRequestSchema, async (request, ctx) => { const result = { action: 'accept', content: { username: 'result-user' } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.taskCtx!.store) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } @@ -3049,18 +3049,18 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + client.setRequestHandler(ElicitRequestSchema, async (request, ctx) => { const result = { action: 'accept', content: { username: 'list-user' } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.taskCtx!.store) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } @@ -3161,27 +3161,27 @@ describe('Task-based execution', () => { } }, { - async createTask({ id }, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask({ id }, ctx) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); const result = { content: [{ type: 'text', text: `Result for ${id || 'unknown'}` }] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.taskCtx!.store.getTask(ctx.taskCtx!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.taskCtx!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -3429,27 +3429,27 @@ test('should respect server task capabilities', async () => { inputSchema: {} }, { - async createTask(_args, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask(_args, ctx) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); const result = { content: [{ type: 'text', text: 'Success!' }] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.taskCtx!.store.getTask(ctx.taskCtx!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.taskCtx!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -4231,7 +4231,7 @@ describe('Client sampling validation with tools', () => { expect(result.stopReason).toBe('toolUse'); expect(Array.isArray(result.content)).toBe(true); - expect((result.content as Array<{ type: string }>)[0].type).toBe('tool_use'); + expect((result.content as Array<{ type: string }>)[0]!.type).toBe('tool_use'); }); test('should validate single content when request includes tools', async () => { diff --git a/test/integration/test/experimental/tasks/taskListing.test.ts b/test/integration/test/experimental/tasks/taskListing.test.ts index 28b39bb3b..6d3f1c6ef 100644 --- a/test/integration/test/experimental/tasks/taskListing.test.ts +++ b/test/integration/test/experimental/tasks/taskListing.test.ts @@ -1,4 +1,4 @@ -import { ErrorCode, McpError } from '@modelcontextprotocol/core'; +import { ErrorCode, ProtocolError } from '@modelcontextprotocol/core'; import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { createInMemoryTaskEnvironment } from '../../helpers/mcp.js'; @@ -88,8 +88,8 @@ describe('Task Listing with Pagination', () => { }); // Try to use an invalid cursor - should return -32602 (Invalid params) per MCP spec - await expect(client.experimental.tasks.listTasks('invalid-cursor')).rejects.toSatisfy((error: McpError) => { - expect(error).toBeInstanceOf(McpError); + await expect(client.experimental.tasks.listTasks('invalid-cursor')).rejects.toSatisfy((error: ProtocolError) => { + expect(error).toBeInstanceOf(ProtocolError); expect(error.code).toBe(ErrorCode.InvalidParams); expect(error.message).toContain('Invalid cursor'); return true; diff --git a/test/integration/test/server.test.ts b/test/integration/test/server.test.ts index 0b89898ba..76f1989b4 100644 --- a/test/integration/test/server.test.ts +++ b/test/integration/test/server.test.ts @@ -19,12 +19,13 @@ import { ElicitResultSchema, ErrorCode, InMemoryTransport, + isTextContent, LATEST_PROTOCOL_VERSION, ListPromptsRequestSchema, ListResourcesRequestSchema, ListToolsRequestSchema, - McpError, NotificationSchema, + ProtocolError, RequestSchema, ResultSchema, SetLevelRequestSchema, @@ -1460,8 +1461,8 @@ test('should handle server cancelling a request', async () => { ); controller.abort('Cancelled by test'); - // Request should be rejected with an McpError - await expect(createMessagePromise).rejects.toThrow(McpError); + // Request should be rejected with an ProtocolError + await expect(createMessagePromise).rejects.toThrow(ProtocolError); }); test('should handle request timeout', async () => { @@ -1488,12 +1489,12 @@ test('should handle request timeout', async () => { } ); - client.setRequestHandler(CreateMessageRequestSchema, async (_request, extra) => { + client.setRequestHandler(CreateMessageRequestSchema, async (_request, ctx) => { await new Promise((resolve, reject) => { const timeout = setTimeout(resolve, 100); - extra.signal.addEventListener('abort', () => { + ctx.requestCtx.signal.addEventListener('abort', () => { clearTimeout(timeout); - reject(extra.signal.reason); + reject(ctx.requestCtx.signal.reason); }); }); @@ -1952,7 +1953,7 @@ describe('createMessage backwards compatibility', () => { expect(result.model).toBe('test-model'); expect(Array.isArray(result.content)).toBe(false); expect(result.content.type).toBe('text'); - if (result.content.type === 'text') { + if (isTextContent(result.content)) { expect(result.content.text).toBe('Hello from LLM'); } }); @@ -2248,9 +2249,9 @@ describe('Task-based execution', () => { inputSchema: {} }, { - async createTask(_args, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask(_args, ctx) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); // Simulate some async work @@ -2259,20 +2260,20 @@ describe('Task-based execution', () => { const result = { content: [{ type: 'text', text: 'Tool executed successfully!' }] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); })(); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.taskCtx!.store.getTask(ctx.taskCtx!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.taskCtx!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -2440,13 +2441,13 @@ describe('Task-based execution', () => { let capturedElicitRequest: z4.infer | null = null; // Set up client elicitation handler - client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + client.setRequestHandler(ElicitRequestSchema, async (request, ctx) => { let taskId: string | undefined; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const createdTask = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.taskCtx!.store) { + const createdTask = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); taskId = createdTask.taskId; } @@ -2470,15 +2471,15 @@ describe('Task-based execution', () => { inputSchema: {} }, { - async createTask(_args, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask(_args, ctx) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); // Perform async work that makes a nested request (async () => { - // During tool execution, make a nested request to the client using extra.sendRequest - const elicitResult = await extra.sendRequest( + // During tool execution, make a nested request to the client using ctx.sendRequest + const elicitResult = await ctx.sendRequest( { method: 'elicitation/create', params: { @@ -2504,20 +2505,20 @@ describe('Task-based execution', () => { } ] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); })(); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.taskCtx!.store.getTask(ctx.taskCtx!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.taskCtx!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -2594,18 +2595,18 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + client.setRequestHandler(ElicitRequestSchema, async (request, ctx) => { const result = { action: 'accept', content: { username: 'server-test-user', confirmed: true } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.taskCtx!.store) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } @@ -2675,18 +2676,18 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + client.setRequestHandler(ElicitRequestSchema, async (request, ctx) => { const result = { action: 'accept', content: { username: 'list-user' } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.taskCtx!.store) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } @@ -2754,18 +2755,18 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + client.setRequestHandler(ElicitRequestSchema, async (request, ctx) => { const result = { action: 'accept', content: { username: 'result-user', confirmed: true } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.taskCtx!.store) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } @@ -2835,18 +2836,18 @@ describe('Task-based execution', () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + client.setRequestHandler(ElicitRequestSchema, async (request, ctx) => { const result = { action: 'accept', content: { username: 'list-user' } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.taskCtx!.store) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } @@ -2949,9 +2950,9 @@ describe('Task-based execution', () => { } }, { - async createTask({ delay, taskNum }, extra) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + async createTask({ delay, taskNum }, ctx) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); // Simulate async work @@ -2960,20 +2961,20 @@ describe('Task-based execution', () => { const result = { content: [{ type: 'text', text: `Completed task ${taskNum || 'unknown'}` }] }; - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); })(); return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.taskCtx!.store.getTask(ctx.taskCtx!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.taskCtx!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -3178,18 +3179,18 @@ test('should respect client task capabilities', async () => { } ); - client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + client.setRequestHandler(ElicitRequestSchema, async (request, ctx) => { const result = { action: 'accept', content: { username: 'test-user' } }; // Check if task creation is requested - if (request.params.task && extra.taskStore) { - const task = await extra.taskStore.createTask({ - ttl: extra.taskRequestedTtl + if (request.params.task && ctx.taskCtx!.store) { + const task = await ctx.taskCtx!.store.createTask({ + ttl: ctx.taskCtx!.requestedTtl }); - await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', result); // Return CreateTaskResult when task creation is requested return { task }; } diff --git a/test/integration/test/server/context.test.ts b/test/integration/test/server/context.test.ts new file mode 100644 index 000000000..5b3997676 --- /dev/null +++ b/test/integration/test/server/context.test.ts @@ -0,0 +1,278 @@ +import { Client } from '@modelcontextprotocol/client'; +import type { BaseRequestContext, ContextInterface, ServerNotification, ServerRequest } from '@modelcontextprotocol/core'; +import { + CallToolResultSchema, + GetPromptResultSchema, + InMemoryTransport, + ListResourcesResultSchema, + LoggingMessageNotificationSchema, + ReadResourceResultSchema +} from '@modelcontextprotocol/core'; +import { McpServer, ResourceTemplate, ServerContext } from '@modelcontextprotocol/server'; +import { z } from 'zod/v4'; + +describe('ServerContext', () => { + /*** + * Test: `ctx` provided to callbacks is ServerContext (parameterized) + */ + type Seen = { isContext: boolean; hasRequestId: boolean }; + const contextCases: Array<[string, (mcpServer: McpServer, seen: Seen) => void | Promise, (client: Client) => Promise]> = + [ + [ + 'tool', + (mcpServer, seen) => { + mcpServer.registerTool( + 'ctx-tool', + { + inputSchema: z.object({ name: z.string() }) + }, + (_args: { name: string }, ctx) => { + seen.isContext = ctx instanceof ServerContext; + seen.hasRequestId = !!ctx.mcpCtx.requestId; + return { content: [{ type: 'text', text: 'ok' }] }; + } + ); + }, + client => + client.request( + { + method: 'tools/call', + params: { + name: 'ctx-tool', + arguments: { + name: 'ctx-tool-name' + } + } + }, + CallToolResultSchema + ) + ], + [ + 'resource', + (mcpServer, seen) => { + mcpServer.registerResource('ctx-resource', 'test://res/1', { title: 'ctx-resource' }, async (_uri, ctx) => { + seen.isContext = ctx instanceof ServerContext; + seen.hasRequestId = !!ctx.mcpCtx.requestId; + return { contents: [{ uri: 'test://res/1', mimeType: 'text/plain', text: 'hello' }] }; + }); + }, + client => client.request({ method: 'resources/read', params: { uri: 'test://res/1' } }, ReadResourceResultSchema) + ], + [ + 'resource template list', + (mcpServer, seen) => { + const template = new ResourceTemplate('test://items/{id}', { + list: async ctx => { + seen.isContext = ctx instanceof ServerContext; + seen.hasRequestId = !!ctx.mcpCtx.requestId; + return { resources: [] }; + } + }); + mcpServer.registerResource('ctx-template', template, { title: 'ctx-template' }, async (_uri, _vars, _ctx) => ({ + contents: [] + })); + }, + client => client.request({ method: 'resources/list', params: {} }, ListResourcesResultSchema) + ], + [ + 'prompt', + (mcpServer, seen) => { + mcpServer.registerPrompt('ctx-prompt', {}, async ctx => { + seen.isContext = ctx instanceof ServerContext; + seen.hasRequestId = !!ctx.mcpCtx.requestId; + return { messages: [] }; + }); + }, + client => client.request({ method: 'prompts/get', params: { name: 'ctx-prompt', arguments: {} } }, GetPromptResultSchema) + ] + ]; + + test.each(contextCases)('should pass ServerContext as ctx to %s callbacks', async (_kind, register, trigger) => { + const mcpServer = new McpServer({ name: 'ctx-test', version: '1.0' }); + const client = new Client({ name: 'ctx-client', version: '1.0' }); + + const seen: Seen = { isContext: false, hasRequestId: false }; + + await register(mcpServer, seen); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await trigger(client); + + expect(seen.isContext).toBe(true); + expect(seen.hasRequestId).toBe(true); + }); + + const logLevelsThroughContext = ['debug', 'info', 'warning', 'error'] as const; + + //it.each for each log level, test that logging message is sent to client + it.each(logLevelsThroughContext)('should send logging message to client for %s level from ServerContext', async level => { + const mcpServer = new McpServer( + { name: 'ctx-test', version: '1.0' }, + { + capabilities: { + logging: {} + } + } + ); + const client = new Client( + { name: 'ctx-client', version: '1.0' }, + { + capabilities: {} + } + ); + + let seen = 0; + + client.setNotificationHandler(LoggingMessageNotificationSchema, notification => { + seen++; + expect(notification.params.level).toBe(level); + expect(notification.params.data).toBe('Test message'); + expect(notification.params._meta?.test).toBe('test'); + expect(notification.params._meta?.sessionId).toBe('sample-session-id'); + return; + }); + + mcpServer.registerTool('ctx-log-test', { inputSchema: z.object({ name: z.string() }) }, async (_args: { name: string }, ctx) => { + const serverCtx = ctx as ServerContext; + await serverCtx.loggingNotification[level]('Test message', { test: 'test' }, 'sample-session-id'); + await serverCtx.loggingNotification.log( + { + level, + data: 'Test message', + logger: 'test-logger-namespace', + _meta: { + test: 'test', + sessionId: 'sample-session-id' + } + }, + 'sample-session-id' + ); + return { content: [{ type: 'text', text: 'ok' }] }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/call', + params: { name: 'ctx-log-test', arguments: { name: 'ctx-log-test-name' } } + }, + CallToolResultSchema + ); + + // two messages should have been sent - one from the .log method and one from the .debug/info/warning/error method + expect(seen).toBe(2); + + expect(result.content).toHaveLength(1); + expect(result.content[0]).toMatchObject({ + type: 'text', + text: 'ok' + }); + }); + describe('ContextInterface API', () => { + const contextCases: Array< + [string, (mcpServer: McpServer, seen: Seen) => void | Promise, (client: Client) => Promise] + > = [ + [ + 'tool', + (mcpServer, seen) => { + mcpServer.registerTool( + 'ctx-tool', + { + inputSchema: z.object({ name: z.string() }) + }, + // The test is to ensure that the ctx is compatible with the ContextInterface type + (_args: { name: string }, ctx: ContextInterface) => { + seen.isContext = ctx instanceof ServerContext; + seen.hasRequestId = !!ctx.mcpCtx.requestId; + return { content: [{ type: 'text', text: 'ok' }] }; + } + ); + }, + client => + client.request( + { + method: 'tools/call', + params: { + name: 'ctx-tool', + arguments: { + name: 'ctx-tool-name' + } + } + }, + CallToolResultSchema + ) + ], + [ + 'resource', + (mcpServer, seen) => { + // The test is to ensure that the ctx is compatible with the ContextInterface type + mcpServer.registerResource( + 'ctx-resource', + 'test://res/1', + { title: 'ctx-resource' }, + async (_uri, ctx: ContextInterface) => { + seen.isContext = ctx instanceof ServerContext; + seen.hasRequestId = !!ctx.mcpCtx.requestId; + return { contents: [{ uri: 'test://res/1', mimeType: 'text/plain', text: 'hello' }] }; + } + ); + }, + client => client.request({ method: 'resources/read', params: { uri: 'test://res/1' } }, ReadResourceResultSchema) + ], + [ + 'resource template list', + (mcpServer, seen) => { + // The test is to ensure that the ctx is compatible with the ContextInterface type + const template = new ResourceTemplate('test://items/{id}', { + list: async (ctx: ContextInterface) => { + seen.isContext = ctx instanceof ServerContext; + seen.hasRequestId = !!ctx.mcpCtx.requestId; + return { resources: [] }; + } + }); + mcpServer.registerResource('ctx-template', template, { title: 'ctx-template' }, async (_uri, _vars, _ctx) => ({ + contents: [] + })); + }, + client => client.request({ method: 'resources/list', params: {} }, ListResourcesResultSchema) + ], + [ + 'prompt', + (mcpServer, seen) => { + // The test is to ensure that the ctx is compatible with the ContextInterface type + mcpServer.registerPrompt( + 'ctx-prompt', + {}, + async (ctx: ContextInterface) => { + seen.isContext = ctx instanceof ServerContext; + seen.hasRequestId = !!ctx.mcpCtx.requestId; + return { messages: [] }; + } + ); + }, + client => client.request({ method: 'prompts/get', params: { name: 'ctx-prompt', arguments: {} } }, GetPromptResultSchema) + ] + ]; + + test.each(contextCases)('should pass ServerContext as ctx to %s callbacks', async (_kind, register, trigger) => { + const mcpServer = new McpServer({ name: 'ctx-test', version: '1.0' }); + const client = new Client({ name: 'ctx-client', version: '1.0' }); + + const seen: Seen = { isContext: false, hasRequestId: false }; + + await register(mcpServer, seen); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await trigger(client); + + expect(seen.isContext).toBe(true); + expect(seen.hasRequestId).toBe(true); + }); + }); +}); diff --git a/test/integration/test/server/mcp.test.ts b/test/integration/test/server/mcp.test.ts index 5d811848b..db3643cfd 100644 --- a/test/integration/test/server/mcp.test.ts +++ b/test/integration/test/server/mcp.test.ts @@ -1,5 +1,12 @@ import { Client } from '@modelcontextprotocol/client'; -import type { CallToolResult, Notification, TextContent } from '@modelcontextprotocol/core'; +import type { + CallToolResult, + ContextInterface, + Notification, + ServerNotification, + ServerRequest, + TextContent +} from '@modelcontextprotocol/core'; import { CallToolResultSchema, CompleteResultSchema, @@ -15,10 +22,11 @@ import { ListToolsResultSchema, LoggingMessageNotificationSchema, ReadResourceResultSchema, + text, UriTemplate, UrlElicitationRequiredError } from '@modelcontextprotocol/core'; -import { completable, McpServer, ResourceTemplate } from '@modelcontextprotocol/server'; +import { completable, McpServer, ResourceTemplate, ServerContext } from '@modelcontextprotocol/server'; import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; import { zodTestMatrix } from '@modelcontextprotocol/test-helpers'; import { afterEach, beforeEach, describe, expect, test } from 'vitest'; @@ -117,13 +125,13 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { steps: z.number().min(1).describe('Number of steps to perform') } }, - async ({ steps }, { sendNotification, _meta }) => { - const progressToken = _meta?.progressToken; + async ({ steps }, ctx) => { + const progressToken = ctx.mcpCtx._meta?.progressToken; if (progressToken) { // Send progress notification for each step for (let i = 1; i <= steps; i++) { - await sendNotification({ + await ctx.sendNotification({ method: 'notifications/progress', params: { progressToken, @@ -244,7 +252,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { sendNotification: () => { throw new Error('Not implemented'); } - }); + } as unknown as ContextInterface); expect(result?.resources).toHaveLength(1); expect(list).toHaveBeenCalled(); }); @@ -509,7 +517,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, async () => ({ - content: [{ type: 'text', text: '' }], + content: [text('')], structuredContent: { result: 42 } @@ -523,7 +531,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { sum: z.number() }, callback: async () => ({ - content: [{ type: 'text', text: '' }], + content: [text('')], structuredContent: { result: 42, sum: 100 @@ -653,7 +661,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { inputSchema: { name: z.string(), value: z.number() } }, async ({ name, value }) => ({ - content: [{ type: 'text', text: `${name}: ${value}` }] + content: [text(`${name}: ${value}`)] }) ); @@ -802,7 +810,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { annotations: { title: 'Test Tool', readOnlyHint: true } }, async ({ name }) => ({ - content: [{ type: 'text', text: `Hello, ${name}!` }] + content: [text(`Hello, ${name}!`)] }) ); @@ -849,7 +857,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, async ({ name }) => ({ - content: [{ type: 'text', text: `Hello, ${name}!` }] + content: [text(`Hello, ${name}!`)] }) ); @@ -897,7 +905,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, async () => ({ - content: [{ type: 'text', text: 'Test response' }] + content: [text('Test response')] }) ); @@ -1325,7 +1333,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { /*** * Test: Pass Session ID to Tool Callback */ - test('should pass sessionId to tool callback via RequestHandlerExtra', async () => { + test('should pass sessionId to tool callback via context', async () => { const mcpServer = new McpServer({ name: 'test server', version: '1.0' @@ -1337,8 +1345,8 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); let receivedSessionId: string | undefined; - mcpServer.registerTool('test-tool', {}, async extra => { - receivedSessionId = extra.sessionId; + mcpServer.registerTool('test-tool', {}, async ctx => { + receivedSessionId = ctx.mcpCtx.sessionId; return { content: [ { @@ -1371,7 +1379,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { /*** * Test: Pass Request ID to Tool Callback */ - test('should pass requestId to tool callback via RequestHandlerExtra', async () => { + test('should pass requestId to tool callback via context', async () => { const mcpServer = new McpServer({ name: 'test server', version: '1.0' @@ -1383,13 +1391,13 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); let receivedRequestId: string | number | undefined; - mcpServer.registerTool('request-id-test', {}, async extra => { - receivedRequestId = extra.requestId; + mcpServer.registerTool('request-id-test', {}, async ctx => { + receivedRequestId = ctx.mcpCtx.requestId; return { content: [ { type: 'text', - text: `Received request ID: ${extra.requestId}` + text: `Received request ID: ${ctx.mcpCtx.requestId}` } ] }; @@ -1573,9 +1581,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); /*** - * Test: McpError for Invalid Tool Name + * Test: ProtocolError for Invalid Tool Name */ - test('should throw McpError for invalid tool name', async () => { + test('should throw ProtocolError for invalid tool name', async () => { const mcpServer = new McpServer({ name: 'test server', version: '1.0' @@ -1702,7 +1710,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { _meta: metaData }, async ({ name }) => ({ - content: [{ type: 'text', text: `Hello, ${name}!` }] + content: [text(`Hello, ${name}!`)] }) ); @@ -1738,7 +1746,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { inputSchema: { name: z.string() } }, async ({ name }) => ({ - content: [{ type: 'text', text: `Hello, ${name}!` }] + content: [text(`Hello, ${name}!`)] }) ); @@ -1792,17 +1800,19 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, { - createTask: async (_args, extra) => { - const task = await extra.taskStore.createTask({ ttl: 60_000 }); + createTask: async (_args, ctx) => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const task = await ctx.taskCtx.store.createTask({ ttl: 60_000 }); return { task }; }, - getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); - if (!task) throw new Error('Task not found'); + getTask: async (_args, ctx) => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const task = await ctx.taskCtx.store.getTask(ctx.taskCtx.id!); return task; }, - getTaskResult: async (_args, extra) => { - return (await extra.taskStore.getTaskResult(extra.taskId)) as CallToolResult; + getTaskResult: async (_args, ctx) => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + return (await ctx.taskCtx.store.getTaskResult(ctx.taskCtx.id!)) as CallToolResult; } } ); @@ -1861,17 +1871,18 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, { - createTask: async (_args, extra) => { - const task = await extra.taskStore.createTask({ ttl: 60_000 }); + createTask: async (_args, ctx) => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const task = await ctx.taskCtx.store.createTask({ ttl: 60_000 }); return { task }; }, - getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); + getTask: async (_args, ctx) => { + const task = await ctx.taskCtx?.store?.getTask(ctx.taskCtx.id!); if (!task) throw new Error('Task not found'); return task; }, - getTaskResult: async (_args, extra) => { - return (await extra.taskStore.getTaskResult(extra.taskId)) as CallToolResult; + getTaskResult: async (_args, ctx) => { + return (await ctx.taskCtx?.store?.getTaskResult(ctx.taskCtx.id!)) as CallToolResult; } } ); @@ -1907,7 +1918,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { { description: 'A valid tool name' }, - async () => ({ content: [{ type: 'text', text: 'Success' }] }) + async () => ({ content: [text('Success')] }) ); // Test tool name with warnings (starts with dash) @@ -1916,7 +1927,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { { description: 'A tool name that generates warnings' }, - async () => ({ content: [{ type: 'text', text: 'Success' }] }) + async () => ({ content: [text('Success')] }) ); // Test invalid tool name (contains spaces) @@ -1925,7 +1936,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { { description: 'An invalid tool name' }, - async () => ({ content: [{ type: 'text', text: 'Success' }] }) + async () => ({ content: [text('Success')] }) ); // Verify that warnings were issued (both for warnings and validation failures) @@ -2614,9 +2625,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); /*** - * Test: McpError for Invalid Resource URI + * Test: ProtocolError for Invalid Resource URI */ - test('should throw McpError for invalid resource URI', async () => { + test('should throw ProtocolError for invalid resource URI', async () => { const mcpServer = new McpServer({ name: 'test server', version: '1.0' @@ -2846,7 +2857,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { /*** * Test: Pass Request ID to Resource Callback */ - test('should pass requestId to resource callback via RequestHandlerExtra', async () => { + test('should pass requestId to resource callback via context', async () => { const mcpServer = new McpServer({ name: 'test server', version: '1.0' @@ -2858,13 +2869,13 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); let receivedRequestId: string | number | undefined; - mcpServer.registerResource('request-id-test', 'test://resource', {}, async (_uri, extra) => { - receivedRequestId = extra.requestId; + mcpServer.registerResource('request-id-test', 'test://resource', {}, async (_uri, ctx) => { + receivedRequestId = ctx.mcpCtx.requestId; return { contents: [ { uri: 'test://resource', - text: `Received request ID: ${extra.requestId}` + text: `Received request ID: ${ctx.mcpCtx.requestId}` } ] }; @@ -3539,9 +3550,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); /*** - * Test: McpError for Invalid Prompt Name + * Test: ProtocolError for Invalid Prompt Name */ - test('should throw McpError for invalid prompt name', async () => { + test('should throw ProtocolError for invalid prompt name', async () => { const mcpServer = new McpServer({ name: 'test server', version: '1.0' @@ -3784,7 +3795,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { /*** * Test: Pass Request ID to Prompt Callback */ - test('should pass requestId to prompt callback via RequestHandlerExtra', async () => { + test('should pass requestId to prompt callback via context', async () => { const mcpServer = new McpServer({ name: 'test server', version: '1.0' @@ -3796,15 +3807,15 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); let receivedRequestId: string | number | undefined; - mcpServer.registerPrompt('request-id-test', {}, async extra => { - receivedRequestId = extra.requestId; + mcpServer.registerPrompt('request-id-test', {}, async ctx => { + receivedRequestId = ctx.mcpCtx.requestId; return { messages: [ { role: 'assistant', content: { type: 'text', - text: `Received request ID: ${extra.requestId}` + text: `Received request ID: ${ctx.mcpCtx.requestId}` } } ] @@ -4031,7 +4042,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Tool 1: Only name mcpServer.registerTool('tool_name_only', {}, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [text('Response')] })); // Tool 2: Name and annotations.title @@ -4044,7 +4055,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [text('Response')] }) ); @@ -4056,7 +4067,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { description: 'Tool with regular title' }, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [text('Response')] }) ); @@ -4071,7 +4082,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [text('Response')] }) ); @@ -4311,17 +4322,20 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }) } }, - async ({ department, name }) => ({ - messages: [ - { - role: 'assistant', - content: { - type: 'text', - text: `Hello ${name}, welcome to the ${department} team!` + async (args, ctx: ContextInterface) => { + expect(ctx).toBeInstanceOf(ServerContext); + return { + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Hello ${args.name}, welcome to the ${args.department} team!` + } } - } - ] - }) + ] + }; + } ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -5321,7 +5335,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Tool 1: Only name mcpServer.registerTool('tool_name_only', {}, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [text('Response')] })); // Tool 2: Name and annotations.title @@ -5334,7 +5348,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [text('Response')] }) ); @@ -5346,7 +5360,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { description: 'Tool with regular title' }, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [text('Response')] }) ); @@ -5361,7 +5375,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [text('Response')] }) ); @@ -5959,10 +5973,10 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { server.registerTool('contact', { inputSchema: unionSchema }, async args => { return args.type === 'email' ? { - content: [{ type: 'text', text: `Email contact: ${args.email}` }] + content: [text(`Email contact: ${args.email}`)] } : { - content: [{ type: 'text', text: `Phone contact: ${args.phone}` }] + content: [text(`Phone contact: ${args.phone}`)] }; }); @@ -6123,7 +6137,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { server.registerTool('union-test', { inputSchema: unionSchema }, async () => { return { - content: [{ type: 'text', text: 'Success' }] + content: [text('Success')] }; }); @@ -6224,11 +6238,12 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, { - createTask: async ({ input }, extra) => { - const task = await extra.taskStore.createTask({ ttl: 60_000, pollInterval: 100 }); + createTask: async ({ input }, ctx) => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const task = await ctx.taskCtx.store.createTask({ ttl: 60_000, pollInterval: 100 }); // Capture taskStore for use in setTimeout - const store = extra.taskStore; + const store = ctx.taskCtx.store; // Simulate async work setTimeout(async () => { @@ -6239,15 +6254,14 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, - getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); - if (!task) { - throw new Error('Task not found'); - } + getTask: async (_args, ctx) => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const task = await ctx.taskCtx.store.getTask(ctx.taskCtx.id!); return task; }, - getTaskResult: async (_input, extra) => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + getTaskResult: async (_input, ctx) => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const result = await ctx.taskCtx.store.getTaskResult(ctx.taskCtx.id!); return result as CallToolResult; } } @@ -6329,11 +6343,12 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, { - createTask: async ({ value }, extra) => { - const task = await extra.taskStore.createTask({ ttl: 60_000, pollInterval: 100 }); + createTask: async ({ value }, ctx) => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const task = await ctx.taskCtx.store.createTask({ ttl: 60_000, pollInterval: 100 }); // Capture taskStore for use in setTimeout - const store = extra.taskStore; + const store = ctx.taskCtx.store; // Simulate async work setTimeout(async () => { @@ -6345,15 +6360,14 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, - getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); - if (!task) { - throw new Error('Task not found'); - } + getTask: async (_args, ctx) => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const task = await ctx.taskCtx.store.getTask(ctx.taskCtx.id!); return task; }, - getTaskResult: async (_value, extra) => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + getTaskResult: async (_value, ctx) => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const result = await ctx.taskCtx.store.getTaskResult(ctx.taskCtx.id!); return result as CallToolResult; } } @@ -6437,14 +6451,16 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, { - createTask: async ({ data }, extra) => { - const task = await extra.taskStore.createTask({ ttl: 60_000, pollInterval: 100 }); + createTask: async ({ data }, ctx) => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const task = await ctx.taskCtx.store.createTask({ ttl: 60_000, pollInterval: 100 }); // Capture taskStore for use in setTimeout - const store = extra.taskStore; + const store = ctx.taskCtx.store; // Simulate async work setTimeout(async () => { + if (!store) throw new Error('Task store not found'); await store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text' as const, text: `Completed: ${data}` }] }); @@ -6453,15 +6469,14 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, - getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); - if (!task) { - throw new Error('Task not found'); - } + getTask: async (_args, ctx) => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const task = await ctx.taskCtx.store.getTask(ctx.taskCtx.id!); return task; }, - getTaskResult: async (_data, extra) => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + getTaskResult: async (_data, ctx) => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const result = await ctx.taskCtx.store.getTaskResult(ctx.taskCtx.id!); return result as CallToolResult; } } @@ -6554,11 +6569,12 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, { - createTask: async extra => { - const task = await extra.taskStore.createTask({ ttl: 60_000, pollInterval: 100 }); + createTask: async ctx => { + if (!ctx.taskCtx) throw new Error('Task store not found'); + const task = await ctx.taskCtx.store.createTask({ ttl: 60_000, pollInterval: 100 }); // Capture taskStore for use in setTimeout - const store = extra.taskStore; + const store = ctx.taskCtx.store; // Simulate async failure setTimeout(async () => { @@ -6571,15 +6587,13 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, - getTask: async extra => { - const task = await extra.taskStore.getTask(extra.taskId); - if (!task) { - throw new Error('Task not found'); - } + getTask: async ctx => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const task = await ctx.taskCtx.store.getTask(ctx.taskCtx.id!); return task; }, - getTaskResult: async extra => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + getTaskResult: async ctx => { + const result = await ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id!); return result as CallToolResult; } } @@ -6660,11 +6674,12 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, { - createTask: async extra => { - const task = await extra.taskStore.createTask({ ttl: 60_000, pollInterval: 100 }); + createTask: async ctx => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const task = await ctx.taskCtx.store.createTask({ ttl: 60_000, pollInterval: 100 }); // Capture taskStore for use in setTimeout - const store = extra.taskStore; + const store = ctx.taskCtx.store; // Simulate async cancellation setTimeout(async () => { @@ -6674,15 +6689,14 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, - getTask: async extra => { - const task = await extra.taskStore.getTask(extra.taskId); - if (!task) { - throw new Error('Task not found'); - } + getTask: async ctx => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const task = await ctx.taskCtx.store.getTask(ctx.taskCtx.id!); return task; }, - getTaskResult: async extra => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + getTaskResult: async ctx => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const result = await ctx.taskCtx.store.getTaskResult(ctx.taskCtx.id!); return result as CallToolResult; } } @@ -6747,19 +6761,19 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, { - createTask: async (_args, extra) => { - const task = await extra.taskStore.createTask({ ttl: 60_000, pollInterval: 100 }); + createTask: async (_args, ctx) => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const task = await ctx.taskCtx.store.createTask({ ttl: 60_000, pollInterval: 100 }); return { task }; }, - getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); - if (!task) { - throw new Error('Task not found'); - } + getTask: async (_args, ctx) => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const task = await ctx.taskCtx.store.getTask(ctx.taskCtx.id!); return task; }, - getTaskResult: async (_args, extra) => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + getTaskResult: async (_args, ctx) => { + if (!ctx.taskCtx?.store) throw new Error('Task store not found'); + const result = await ctx.taskCtx.store.getTaskResult(ctx.taskCtx.id!); return result as CallToolResult; } } diff --git a/test/integration/test/taskLifecycle.test.ts b/test/integration/test/taskLifecycle.test.ts index 148082c93..68265b2f4 100644 --- a/test/integration/test/taskLifecycle.test.ts +++ b/test/integration/test/taskLifecycle.test.ts @@ -13,8 +13,8 @@ import { ErrorCode, InMemoryTaskMessageQueue, InMemoryTaskStore, - McpError, McpServer, + ProtocolError, RELATED_TASK_META_KEY, TaskSchema } from '@modelcontextprotocol/server'; @@ -64,8 +64,8 @@ describe('Task Lifecycle Integration Tests', () => { } }, { - async createTask({ duration, shouldFail }, extra) { - const task = await extra.taskStore.createTask({ + async createTask({ duration, shouldFail }, ctx) { + const task = await ctx.taskCtx!.store.createTask({ ttl: 60_000, pollInterval: 100 }); @@ -76,11 +76,11 @@ describe('Task Lifecycle Integration Tests', () => { try { await (shouldFail - ? extra.taskStore.storeTaskResult(task.taskId, 'failed', { + ? ctx.taskCtx!.store.storeTaskResult(task.taskId, 'failed', { content: [{ type: 'text', text: 'Task failed as requested' }], isError: true }) - : extra.taskStore.storeTaskResult(task.taskId, 'completed', { + : ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: `Completed after ${duration}ms` }] })); } catch { @@ -90,15 +90,15 @@ describe('Task Lifecycle Integration Tests', () => { return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.taskCtx!.store.getTask(ctx.taskCtx!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.taskCtx!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -115,8 +115,8 @@ describe('Task Lifecycle Integration Tests', () => { } }, { - async createTask({ userName }, extra) { - const task = await extra.taskStore.createTask({ + async createTask({ userName }, ctx) { + const task = await ctx.taskCtx!.store.createTask({ ttl: 60_000, pollInterval: 100 }); @@ -129,14 +129,14 @@ describe('Task Lifecycle Integration Tests', () => { if (userName) { // Complete immediately if userName was provided try { - await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: `Hello, ${userName}!` }] }); } catch { // Task may have been cleaned up if test ended } } else { - const elicitationResult = await extra.sendRequest( + const elicitationResult = await ctx.sendRequest( { method: 'elicitation/create', params: { @@ -161,7 +161,7 @@ describe('Task Lifecycle Integration Tests', () => { ? elicitationResult.content.userName : 'Unknown'; try { - await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: `Hello, ${name}!` }] }); } catch { @@ -172,15 +172,15 @@ describe('Task Lifecycle Integration Tests', () => { return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.taskCtx!.store.getTask(ctx.taskCtx!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.taskCtx!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -395,8 +395,8 @@ describe('Task Lifecycle Integration Tests', () => { expect(task.status).toBe('completed'); // Try to cancel via tasks/cancel request (should fail with -32602) - await expect(client.experimental.tasks.cancelTask(taskId)).rejects.toSatisfy((error: McpError) => { - expect(error).toBeInstanceOf(McpError); + await expect(client.experimental.tasks.cancelTask(taskId)).rejects.toSatisfy((error: ProtocolError) => { + expect(error).toBeInstanceOf(ProtocolError); expect(error.code).toBe(ErrorCode.InvalidParams); expect(error.message).toContain('Cannot cancel task in terminal status'); return true; @@ -419,8 +419,8 @@ describe('Task Lifecycle Integration Tests', () => { } }, { - async createTask({ requestCount }, extra) { - const task = await extra.taskStore.createTask({ + async createTask({ requestCount }, ctx) { + const task = await ctx.taskCtx!.store.createTask({ ttl: 60_000, pollInterval: 100 }); @@ -433,7 +433,7 @@ describe('Task Lifecycle Integration Tests', () => { // Send multiple elicitation requests for (let i = 0; i < requestCount; i++) { - const elicitationResult = await extra.sendRequest( + const elicitationResult = await ctx.sendRequest( { method: 'elicitation/create', params: { @@ -459,7 +459,7 @@ describe('Task Lifecycle Integration Tests', () => { // Complete with all responses try { - await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: `Received responses: ${responses.join(', ')}` }] }); } catch { @@ -469,15 +469,15 @@ describe('Task Lifecycle Integration Tests', () => { return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.taskCtx!.store.getTask(ctx.taskCtx!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.taskCtx!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -789,8 +789,8 @@ describe('Task Lifecycle Integration Tests', () => { await client.connect(transport); // Try to get non-existent task via tasks/get request - await expect(client.experimental.tasks.getTask('non-existent-task-id')).rejects.toSatisfy((error: McpError) => { - expect(error).toBeInstanceOf(McpError); + await expect(client.experimental.tasks.getTask('non-existent-task-id')).rejects.toSatisfy((error: ProtocolError) => { + expect(error).toBeInstanceOf(ProtocolError); expect(error.code).toBe(ErrorCode.InvalidParams); expect(error.message).toContain('Task not found'); return true; @@ -809,8 +809,8 @@ describe('Task Lifecycle Integration Tests', () => { await client.connect(transport); // Try to cancel non-existent task via tasks/cancel request - await expect(client.experimental.tasks.cancelTask('non-existent-task-id')).rejects.toSatisfy((error: McpError) => { - expect(error).toBeInstanceOf(McpError); + await expect(client.experimental.tasks.cancelTask('non-existent-task-id')).rejects.toSatisfy((error: ProtocolError) => { + expect(error).toBeInstanceOf(ProtocolError); expect(error.code).toBe(ErrorCode.InvalidParams); expect(error.message).toContain('Task not found'); return true; @@ -837,8 +837,8 @@ describe('Task Lifecycle Integration Tests', () => { }, CallToolResultSchema ) - ).rejects.toSatisfy((error: McpError) => { - expect(error).toBeInstanceOf(McpError); + ).rejects.toSatisfy((error: ProtocolError) => { + expect(error).toBeInstanceOf(ProtocolError); expect(error.code).toBe(ErrorCode.InvalidParams); expect(error.message).toContain('Task not found'); return true; @@ -908,8 +908,8 @@ describe('Task Lifecycle Integration Tests', () => { } }, { - async createTask({ messageCount }, extra) { - const task = await extra.taskStore.createTask({ + async createTask({ messageCount }, ctx) { + const task = await ctx.taskCtx!.store.createTask({ ttl: 60_000, pollInterval: 100 }); @@ -922,28 +922,26 @@ describe('Task Lifecycle Integration Tests', () => { // Queue multiple elicitation requests for (let i = 0; i < messageCount; i++) { // Send request but don't await - let it queue - extra - .sendRequest( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: `Message ${i + 1} of ${messageCount}`, - requestedSchema: { - type: 'object', - properties: { - response: { type: 'string' } - }, - required: ['response'] - } + ctx.sendRequest( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: `Message ${i + 1} of ${messageCount}`, + requestedSchema: { + type: 'object', + properties: { + response: { type: 'string' } + }, + required: ['response'] } - }, - ElicitResultSchema, - { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions - ) - .catch(() => { - // Ignore errors from cancelled requests - }); + } + }, + ElicitResultSchema, + { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions + ).catch(() => { + // Ignore errors from cancelled requests + }); } // Don't complete - let the task be cancelled @@ -958,15 +956,15 @@ describe('Task Lifecycle Integration Tests', () => { return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.taskCtx!.store.getTask(ctx.taskCtx!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.taskCtx!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -1106,8 +1104,8 @@ describe('Task Lifecycle Integration Tests', () => { } }, { - async createTask({ messageCount, delayBetweenMessages }, extra) { - const task = await extra.taskStore.createTask({ + async createTask({ messageCount, delayBetweenMessages }, ctx) { + const task = await ctx.taskCtx!.store.createTask({ ttl: 60_000, pollInterval: 100 }); @@ -1122,7 +1120,7 @@ describe('Task Lifecycle Integration Tests', () => { // Send messages with delays between them for (let i = 0; i < messageCount; i++) { - const elicitationResult = await extra.sendRequest( + const elicitationResult = await ctx.sendRequest( { method: 'elicitation/create', params: { @@ -1153,7 +1151,7 @@ describe('Task Lifecycle Integration Tests', () => { // Complete with all responses try { - await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: `Received all responses: ${responses.join(', ')}` }] }); } catch { @@ -1162,7 +1160,7 @@ describe('Task Lifecycle Integration Tests', () => { } catch (error) { // Handle errors try { - await extra.taskStore.storeTaskResult(task.taskId, 'failed', { + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'failed', { content: [{ type: 'text', text: `Error: ${error}` }], isError: true }); @@ -1174,15 +1172,15 @@ describe('Task Lifecycle Integration Tests', () => { return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.taskCtx!.store.getTask(ctx.taskCtx!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.taskCtx!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } @@ -1322,8 +1320,8 @@ describe('Task Lifecycle Integration Tests', () => { } }, { - async createTask({ messageCount }, extra) { - const task = await extra.taskStore.createTask({ + async createTask({ messageCount }, ctx) { + const task = await ctx.taskCtx!.store.createTask({ ttl: 60_000, pollInterval: 100 }); @@ -1336,33 +1334,31 @@ describe('Task Lifecycle Integration Tests', () => { for (let i = 0; i < messageCount; i++) { // Start the request but don't wait for response // The request gets queued when sendRequest is called - extra - .sendRequest( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: `Quick message ${i + 1} of ${messageCount}`, - requestedSchema: { - type: 'object', - properties: { - response: { type: 'string' } - }, - required: ['response'] - } + ctx.sendRequest( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: `Quick message ${i + 1} of ${messageCount}`, + requestedSchema: { + type: 'object', + properties: { + response: { type: 'string' } + }, + required: ['response'] } - }, - ElicitResultSchema, - { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions - ) - .catch(() => {}); + } + }, + ElicitResultSchema, + { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions + ).catch(() => {}); // Small delay to ensure message is queued before next iteration await new Promise(resolve => setTimeout(resolve, 10)); } // Complete the task after all messages are queued try { - await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: 'Task completed quickly' }] }); } catch { @@ -1371,7 +1367,7 @@ describe('Task Lifecycle Integration Tests', () => { } catch (error) { // Handle errors try { - await extra.taskStore.storeTaskResult(task.taskId, 'failed', { + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'failed', { content: [{ type: 'text', text: `Error: ${error}` }], isError: true }); @@ -1383,15 +1379,15 @@ describe('Task Lifecycle Integration Tests', () => { return { task }; }, - async getTask(_args, extra) { - const task = await extra.taskStore.getTask(extra.taskId); + async getTask(_args, ctx) { + const task = await ctx.taskCtx!.store.getTask(ctx.taskCtx!.id); if (!task) { - throw new Error(`Task ${extra.taskId} not found`); + throw new Error(`Task ${ctx.taskCtx!.id} not found`); } return task; }, - async getTaskResult(_args, extra) { - const result = await extra.taskStore.getTaskResult(extra.taskId); + async getTaskResult(_args, ctx) { + const result = await ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id); return result as { content: Array<{ type: 'text'; text: string }> }; } } diff --git a/test/integration/test/taskResumability.test.ts b/test/integration/test/taskResumability.test.ts index ce124eb93..f9847989c 100644 --- a/test/integration/test/taskResumability.test.ts +++ b/test/integration/test/taskResumability.test.ts @@ -5,7 +5,7 @@ import { createServer } from 'node:http'; import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; import { NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/node'; import type { EventStore, JSONRPCMessage } from '@modelcontextprotocol/server'; -import { CallToolResultSchema, LoggingMessageNotificationSchema, McpServer } from '@modelcontextprotocol/server'; +import { CallToolResultSchema, LoggingMessageNotificationSchema, McpServer, text } from '@modelcontextprotocol/server'; import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; import { listenOnRandomPort, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; @@ -79,7 +79,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); return { - content: [{ type: 'text', text: 'Notification sent' }] + content: [text('Notification sent')] }; } ); @@ -112,7 +112,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } return { - content: [{ type: 'text', text: `Sent ${count} notifications` }] + content: [text(`Sent ${count} notifications`)] }; } );