From 107767a39681c8a83709227dcfd7b59927e7e02d Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Sun, 7 Dec 2025 09:27:06 +0200 Subject: [PATCH 01/17] save commit --- src/server/context.ts | 222 +++++++++++++++++++++++++++++++++++ src/server/index.ts | 79 ++++++++++++- src/server/requestContext.ts | 109 +++++++++++++++++ src/shared/protocol.ts | 94 +++++++++------ 4 files changed, 462 insertions(+), 42 deletions(-) create mode 100644 src/server/context.ts create mode 100644 src/server/requestContext.ts diff --git a/src/server/context.ts b/src/server/context.ts new file mode 100644 index 000000000..f529e1ea0 --- /dev/null +++ b/src/server/context.ts @@ -0,0 +1,222 @@ +import { + CreateMessageRequest, + CreateMessageResult, + ElicitRequest, + ElicitResult, + ElicitResultSchema, + LoggingMessageNotification, + Notification, + Request, + RequestId, + RequestInfo, + RequestMeta, + Result, + ServerNotification, + ServerRequest, + ServerResult +} from '../types.js'; +import { RequestHandlerExtra, RequestOptions, RequestTaskStore } from '../shared/protocol.js'; +import { Server } from './index.js'; +import { RequestContext } from './requestContext.js'; +import { AuthInfo } from './auth/types.js'; +import { AnySchema, SchemaOutput } from './zod-compat.js'; + +export interface ContextInterface extends RequestHandlerExtra { + elicit(params: ElicitRequest['params'], options?: RequestOptions): Promise; + requestSampling: (params: CreateMessageRequest['params'], options?: RequestOptions) => Promise; + log(params: LoggingMessageNotification['params'], sessionId?: string): Promise; + debug(message: string, extraLogData?: Record, sessionId?: string): Promise; + info(message: string, extraLogData?: Record, sessionId?: string): Promise; + warning(message: string, extraLogData?: Record, sessionId?: string): Promise; + error(message: string, extraLogData?: Record, sessionId?: string): Promise; +} +/** + * A context object that is passed to request handlers. + * + * Implements the RequestHandlerExtra interface for backwards compatibility. + */ +export class Context< + RequestT extends Request = Request, + NotificationT extends Notification = Notification, + ResultT extends Result = Result +> implements ContextInterface +{ + private readonly server: Server; + + /** + * The request context. + * A type-safe context that is passed to request handlers. + */ + public readonly requestCtx: RequestContext< + RequestT | ServerRequest, + NotificationT | ServerNotification, + ResultT | ServerResult + >; + + constructor(args: { + server: Server; + requestCtx: RequestContext; + }) { + this.server = args.server; + this.requestCtx = args.requestCtx; + } + + public get requestId(): RequestId { + return this.requestCtx.requestId; + } + + public get signal(): AbortSignal { + return this.requestCtx.signal; + } + + public get authInfo(): AuthInfo | undefined { + return this.requestCtx.authInfo; + } + + public get requestInfo(): RequestInfo | undefined { + return this.requestCtx.requestInfo; + } + + public get _meta(): RequestMeta | undefined { + return this.requestCtx._meta; + } + + public get sessionId(): string | undefined { + return this.requestCtx.sessionId; + } + + public get taskId(): string | undefined { + return this.requestCtx.taskId; + } + + public get taskStore(): RequestTaskStore | undefined { + return this.requestCtx.taskStore; + } + + public get taskRequestedTtl(): number | undefined { + return this.requestCtx.taskRequestedTtl ?? undefined; + } + + public closeSSEStream = (): void => { + return this.requestCtx?.closeSSEStream(); + } + + public closeStandaloneSSEStream = (): void => { + return this.requestCtx?.closeStandaloneSSEStream(); + } + + /** + * Sends a notification that relates to the current request being handled. + * + * This is used by certain transports to correctly associate related messages. + */ + public sendNotification = (notification: NotificationT | ServerNotification): Promise => { + return this.requestCtx.sendNotification(notification); + }; + + /** + * Sends a request that relates to the current request being handled. + * + * This is used by certain transports to correctly associate related messages. + */ + public sendRequest = ( + request: RequestT | ServerRequest, + resultSchema: U, + options?: RequestOptions + ): Promise> => { + return this.requestCtx.sendRequest(request, resultSchema, { ...options, relatedRequestId: this.requestId }); + }; + + /** + * Sends a request to sample an LLM via the client. + */ + public requestSampling(params: CreateMessageRequest['params'], options?: RequestOptions) { + return this.server.createMessage(params, options); + } + + /** + * Sends an elicitation request to the client. + */ + public async elicit(params: ElicitRequest['params'], options?: RequestOptions): Promise { + const request: ElicitRequest = { + method: 'elicitation/create', + params + }; + return await this.server.request(request, ElicitResultSchema, { ...options, relatedRequestId: this.requestId }); + } + + /** + * Sends a logging message. + */ + public async log(params: LoggingMessageNotification['params'], sessionId?: string) { + await this.server.sendLoggingMessage(params, sessionId); + } + + /** + * Sends a debug log message. + */ + public async debug(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'debug', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } + + /** + * Sends an info log message. + */ + public async info(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'info', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } + + /** + * Sends a warning log message. + */ + public async warning(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'warning', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } + + /** + * Sends an error log message. + */ + public async error(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'error', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } +} \ No newline at end of file diff --git a/src/server/index.ts b/src/server/index.ts index aa1a62d00..19cb39c39 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -40,7 +40,10 @@ import { type ToolUseContent, CallToolRequestSchema, CallToolResultSchema, - CreateTaskResultSchema + CreateTaskResultSchema, + JSONRPCRequest, + TaskCreationParams, + MessageExtraInfo } from '../types.js'; import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js'; import type { JsonSchemaType, jsonSchemaValidator } from '../validation/types.js'; @@ -56,6 +59,10 @@ import { import { RequestHandlerExtra } from '../shared/protocol.js'; import { ExperimentalServerTasks } from '../experimental/tasks/server.js'; import { assertToolsCallTaskCapability, assertClientRequestTaskCapability } from '../experimental/tasks/helpers.js'; +import { Context } from './context.js'; +import { TaskStore } from '../experimental/index.js'; +import { Transport } from '../shared/transport.js'; +import { RequestContext } from './requestContext.js'; export type ServerOptions = ProtocolOptions & { /** @@ -219,9 +226,23 @@ export class Server< requestSchema: T, handler: ( request: SchemaOutput, - extra: RequestHandlerExtra + extra: Context ) => ServerResult | ResultT | Promise ): void { + // Wrap the handler to ensure the extra is a Context and return a decorated handler that can be passed to the base implementation + + // Factory function to create a handler decorator that ensures the extra is a Context and returns a decorated handler that can be passed to the base implementation + const handlerDecoratorFactory = (innerHandler: (request: SchemaOutput, extra: Context) => ServerResult | ResultT | Promise) => { + const decoratedHandler = (request: SchemaOutput, extra: RequestHandlerExtra) => { + if (!this.isContextExtra(extra)) { + throw new Error('Internal error: Expected Context for request handler extra'); + } + return innerHandler(request, extra); + } + + return decoratedHandler; + } + const shape = getObjectShape(requestSchema); const methodSchema = shape?.method; if (!methodSchema) { @@ -259,7 +280,7 @@ export class Server< const { params } = validatedRequest.data; - const result = await Promise.resolve(handler(request, extra)); + const result = await Promise.resolve(handlerDecoratorFactory(handler)(request, extra)); // When task creation is requested, validate and return CreateTaskResult if (params.task) { @@ -286,11 +307,18 @@ export class Server< }; // Install the wrapped handler - return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler); + return super.setRequestHandler(requestSchema, handlerDecoratorFactory(wrappedHandler)); } // Other handlers use default behavior - return super.setRequestHandler(requestSchema, handler); + return super.setRequestHandler(requestSchema, handlerDecoratorFactory(handler)); + } + + // Runtime type guard: ensure extra is our Context + private isContextExtra( + extra: RequestHandlerExtra + ): extra is Context { + return extra instanceof Context; } protected assertCapabilityForMethod(method: RequestT['method']): void { @@ -468,6 +496,47 @@ export class Server< return this._capabilities; } + protected createRequestExtra( + args: { + request: JSONRPCRequest, + taskStore: TaskStore | undefined, + relatedTaskId: string | undefined, + taskCreationParams: TaskCreationParams | undefined, + abortController: AbortController, + capturedTransport: Transport | undefined, + extra?: MessageExtraInfo + } + ): RequestHandlerExtra { + const base = super.createRequestExtra(args) as RequestHandlerExtra< + ServerRequest | RequestT, + ServerNotification | NotificationT + >; + + // Wrap base in Context to add server utilities while preserving shape + const requestCtx = new RequestContext< + ServerRequest | RequestT, + ServerNotification | NotificationT, + ServerResult | ResultT + >({ + signal: base.signal, + authInfo: base.authInfo, + requestInfo: base.requestInfo, + requestId: base.requestId, + _meta: base._meta, + sessionId: base.sessionId, + protocol: this, + closeSSEStream: base.closeSSEStream ?? undefined, + closeStandaloneSSEStream: base.closeStandaloneSSEStream ?? undefined + }); + + const ctx = new Context({ + server: this, + requestCtx + }); + + return ctx; + } + async ping() { return this.request({ method: 'ping' }, EmptyResultSchema); } diff --git a/src/server/requestContext.ts b/src/server/requestContext.ts new file mode 100644 index 000000000..d7be71986 --- /dev/null +++ b/src/server/requestContext.ts @@ -0,0 +1,109 @@ +import { AuthInfo } from './auth/types.js'; +import { Notification, Request, RequestId, RequestInfo, RequestMeta, Result } from '../types.js'; +import { Protocol, RequestHandlerExtra, RequestTaskStore, TaskRequestOptions } from '../shared/protocol.js'; +import { AnySchema, SchemaOutput } from './zod-compat.js'; + +/** + * A context object that is passed to request handlers. + * + * Implements the RequestHandlerExtra interface for backwards compatibility. + */ +export class RequestContext< + RequestT extends Request = Request, + NotificationT extends Notification = Notification, + ResultT extends Result = Result +> implements RequestHandlerExtra +{ + /** + * An abort signal used to communicate if the request was cancelled from the sender's side. + */ + public readonly signal: AbortSignal; + + /** + * Information about a validated access token, provided to request handlers. + */ + public readonly authInfo?: AuthInfo; + + /** + * The original HTTP request. + */ + public readonly requestInfo?: RequestInfo; + + /** + * The JSON-RPC ID of the request being handled. + * This can be useful for tracking or logging purposes. + */ + public readonly requestId: RequestId; + + /** + * Metadata from the original request. + */ + public readonly _meta?: RequestMeta; + + /** + * The session ID from the transport, if available. + */ + public readonly sessionId?: string; + + /** + * The task store, if available. + */ + public readonly taskStore?: RequestTaskStore; + + public readonly taskId?: string; + + public readonly taskRequestedTtl?: number | null; + + private readonly protocol: Protocol; + constructor(args: { + signal: AbortSignal; + authInfo?: AuthInfo; + requestInfo?: RequestInfo; + requestId: RequestId; + _meta?: RequestMeta; + sessionId?: string; + protocol: Protocol; + taskStore?: RequestTaskStore; + taskId?: string; + taskRequestedTtl?: number | null; + closeSSEStream: (() => void) | undefined; + closeStandaloneSSEStream: (() => void) | undefined; + }) { + this.signal = args.signal; + this.authInfo = args.authInfo; + this.requestInfo = args.requestInfo; + this.requestId = args.requestId; + this._meta = args._meta; + this.sessionId = args.sessionId; + this.protocol = args.protocol; + this.taskStore = args.taskStore; + this.taskId = args.taskId; + this.taskRequestedTtl = args.taskRequestedTtl; + } + + /** + * Sends a notification that relates to the current request being handled. + * + * This is used by certain transports to correctly associate related messages. + */ + public sendNotification = (notification: NotificationT): Promise => { + return this.protocol.notification(notification, { relatedRequestId: this.requestId }); + }; + + /** + * Sends a request that relates to the current request being handled. + * + * This is used by certain transports to correctly associate related messages. + */ + public sendRequest = (request: RequestT, resultSchema: U, options?: TaskRequestOptions): Promise> => { + return this.protocol.request(request, resultSchema, { ...options, relatedRequestId: this.requestId }); + }; + + public closeSSEStream = (): void => { + return this.closeSSEStream(); + } + + public closeStandaloneSSEStream = (): void => { + return this.closeStandaloneSSEStream(); + } +} \ No newline at end of file diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index e195478f2..5ec1d0151 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -709,43 +709,7 @@ export abstract class Protocol = { - signal: abortController.signal, - sessionId: capturedTransport?.sessionId, - _meta: request.params?._meta, - sendNotification: async notification => { - // Include related-task metadata if this request is part of a task - const notificationOptions: NotificationOptions = { relatedRequestId: request.id }; - if (relatedTaskId) { - notificationOptions.relatedTask = { taskId: relatedTaskId }; - } - await this.notification(notification, notificationOptions); - }, - sendRequest: async (r, resultSchema, options?) => { - // Include related-task metadata if this request is part of a task - const requestOptions: RequestOptions = { ...options, relatedRequestId: request.id }; - if (relatedTaskId && !requestOptions.relatedTask) { - requestOptions.relatedTask = { taskId: relatedTaskId }; - } - - // Set task status to input_required when sending a request within a task context - // Use the taskId from options (explicit) or fall back to relatedTaskId (inherited) - const effectiveTaskId = requestOptions.relatedTask?.taskId ?? relatedTaskId; - if (effectiveTaskId && taskStore) { - await taskStore.updateTaskStatus(effectiveTaskId, 'input_required'); - } - - return await this.request(r, resultSchema, requestOptions); - }, - authInfo: extra?.authInfo, - requestId: request.id, - requestInfo: extra?.requestInfo, - taskId: relatedTaskId, - taskStore: taskStore, - taskRequestedTtl: taskCreationParams?.ttl, - closeSSEStream: extra?.closeSSEStream, - closeStandaloneSSEStream: extra?.closeStandaloneSSEStream - }; + const fullExtra: RequestHandlerExtra = this.createRequestExtra({ request, taskStore, relatedTaskId, taskCreationParams, abortController, capturedTransport, extra }); // Starting with Promise.resolve() puts any synchronous errors into the monad as well. Promise.resolve() @@ -823,6 +787,62 @@ export abstract class Protocol { + const { request, taskStore, relatedTaskId, taskCreationParams, abortController, capturedTransport, extra } = args; + + return { + signal: abortController.signal, + sessionId: capturedTransport?.sessionId, + _meta: request.params?._meta, + sendNotification: async notification => { + // Include related-task metadata if this request is part of a task + const notificationOptions: NotificationOptions = { relatedRequestId: request.id }; + if (relatedTaskId) { + notificationOptions.relatedTask = { taskId: relatedTaskId }; + } + await this.notification(notification, notificationOptions); + }, + sendRequest: async (r, resultSchema, options?) => { + // Include related-task metadata if this request is part of a task + const requestOptions: RequestOptions = { ...options, relatedRequestId: request.id }; + if (relatedTaskId && !requestOptions.relatedTask) { + requestOptions.relatedTask = { taskId: relatedTaskId }; + } + + // Set task status to input_required when sending a request within a task context + // Use the taskId from options (explicit) or fall back to relatedTaskId (inherited) + const effectiveTaskId = requestOptions.relatedTask?.taskId ?? relatedTaskId; + if (effectiveTaskId && taskStore) { + await taskStore.updateTaskStatus(effectiveTaskId, 'input_required'); + } + + return await this.request(r, resultSchema, requestOptions); + }, + authInfo: extra?.authInfo, + requestId: request.id, + requestInfo: extra?.requestInfo, + taskId: relatedTaskId, + taskStore: taskStore, + taskRequestedTtl: taskCreationParams?.ttl, + closeSSEStream: extra?.closeSSEStream, + closeStandaloneSSEStream: extra?.closeStandaloneSSEStream + } as RequestHandlerExtra; + } + private _onprogress(notification: ProgressNotification): void { const { progressToken, ...params } = notification.params; const messageId = Number(progressToken); From 68ff6650d8703efd8e59223c9ab8b67dc58892de Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Sun, 7 Dec 2025 10:24:29 +0200 Subject: [PATCH 02/17] context API - backwards compatible introduction --- src/server/context.ts | 81 ++++++--- src/server/index.ts | 69 +++---- src/server/mcp.ts | 17 +- src/server/requestContext.ts | 109 ------------ src/shared/protocol.ts | 32 ++-- test/server/context.test.ts | 277 +++++++++++++++++++++++++++++ test/server/mcp.test.ts | 30 ++-- test/server/streamableHttp.test.ts | 4 +- 8 files changed, 411 insertions(+), 208 deletions(-) delete mode 100644 src/server/requestContext.ts create mode 100644 test/server/context.test.ts diff --git a/src/server/context.ts b/src/server/context.ts index f529e1ea0..a90d3db09 100644 --- a/src/server/context.ts +++ b/src/server/context.ts @@ -4,6 +4,7 @@ import { ElicitRequest, ElicitResult, ElicitResultSchema, + JSONRPCRequest, LoggingMessageNotification, Notification, Request, @@ -12,16 +13,15 @@ import { RequestMeta, Result, ServerNotification, - ServerRequest, - ServerResult + ServerRequest } from '../types.js'; import { RequestHandlerExtra, RequestOptions, RequestTaskStore } from '../shared/protocol.js'; import { Server } from './index.js'; -import { RequestContext } from './requestContext.js'; import { AuthInfo } from './auth/types.js'; import { AnySchema, SchemaOutput } from './zod-compat.js'; -export interface ContextInterface extends RequestHandlerExtra { +export interface ContextInterface + extends RequestHandlerExtra { elicit(params: ElicitRequest['params'], options?: RequestOptions): Promise; requestSampling: (params: CreateMessageRequest['params'], options?: RequestOptions) => Promise; log(params: LoggingMessageNotification['params'], sessionId?: string): Promise; @@ -35,11 +35,8 @@ export interface ContextInterface implements ContextInterface +export class Context + implements ContextInterface { private readonly server: Server; @@ -47,20 +44,52 @@ export class Context< * The request context. * A type-safe context that is passed to request handlers. */ - public readonly requestCtx: RequestContext< - RequestT | ServerRequest, - NotificationT | ServerNotification, - ResultT | ServerResult - >; + private readonly requestCtx: RequestHandlerExtra; + + /** + * The MCP context - Contains information about the current MCP request and session. + */ + public readonly mcpContext: { + /** + * The JSON-RPC ID of the request being handled. + * This can be useful for tracking or logging purposes. + */ + requestId: RequestId; + /** + * The method of the request. + */ + method: string; + /** + * The metadata of the request. + */ + _meta?: RequestMeta; + /** + * The session ID of the request. + */ + sessionId?: string; + }; constructor(args: { server: Server; - requestCtx: RequestContext; + request: JSONRPCRequest; + requestCtx: RequestHandlerExtra; }) { this.server = args.server; this.requestCtx = args.requestCtx; + this.mcpContext = { + requestId: args.requestCtx.requestId, + method: args.request.method, + _meta: args.requestCtx._meta, + sessionId: args.requestCtx.sessionId + }; } + /** + * The JSON-RPC ID of the request being handled. + * This can be useful for tracking or logging purposes. + * + * @deprecated Use {@link mcpContext.requestId} instead. + */ public get requestId(): RequestId { return this.requestCtx.requestId; } @@ -77,12 +106,18 @@ export class Context< return this.requestCtx.requestInfo; } + /** + * @deprecated Use {@link mcpContext._meta} instead. + */ public get _meta(): RequestMeta | undefined { return this.requestCtx._meta; } + /** + * @deprecated Use {@link mcpContext.sessionId} instead. + */ public get sessionId(): string | undefined { - return this.requestCtx.sessionId; + return this.mcpContext.sessionId; } public get taskId(): string | undefined { @@ -97,12 +132,12 @@ export class Context< return this.requestCtx.taskRequestedTtl ?? undefined; } - public closeSSEStream = (): void => { - return this.requestCtx?.closeSSEStream(); + public get closeSSEStream(): (() => void) | undefined { + return this.requestCtx.closeSSEStream; } - public closeStandaloneSSEStream = (): void => { - return this.requestCtx?.closeStandaloneSSEStream(); + public get closeStandaloneSSEStream(): (() => void) | undefined { + return this.requestCtx.closeStandaloneSSEStream; } /** @@ -111,7 +146,7 @@ export class Context< * This is used by certain transports to correctly associate related messages. */ public sendNotification = (notification: NotificationT | ServerNotification): Promise => { - return this.requestCtx.sendNotification(notification); + return this.server.notification(notification); }; /** @@ -124,7 +159,7 @@ export class Context< resultSchema: U, options?: RequestOptions ): Promise> => { - return this.requestCtx.sendRequest(request, resultSchema, { ...options, relatedRequestId: this.requestId }); + return this.server.request(request, resultSchema, { ...options, relatedRequestId: this.requestId }); }; /** @@ -219,4 +254,4 @@ export class Context< sessionId ); } -} \ No newline at end of file +} diff --git a/src/server/index.ts b/src/server/index.ts index 19cb39c39..fa0d038f1 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -62,7 +62,6 @@ import { assertToolsCallTaskCapability, assertClientRequestTaskCapability } from import { Context } from './context.js'; import { TaskStore } from '../experimental/index.js'; import { Transport } from '../shared/transport.js'; -import { RequestContext } from './requestContext.js'; export type ServerOptions = ProtocolOptions & { /** @@ -232,16 +231,24 @@ export class Server< // Wrap the handler to ensure the extra is a Context and return a decorated handler that can be passed to the base implementation // Factory function to create a handler decorator that ensures the extra is a Context and returns a decorated handler that can be passed to the base implementation - const handlerDecoratorFactory = (innerHandler: (request: SchemaOutput, extra: Context) => ServerResult | ResultT | Promise) => { - const decoratedHandler = (request: SchemaOutput, extra: RequestHandlerExtra) => { + const handlerDecoratorFactory = ( + innerHandler: ( + request: SchemaOutput, + extra: Context + ) => ServerResult | ResultT | Promise + ) => { + const decoratedHandler = ( + request: SchemaOutput, + extra: RequestHandlerExtra + ) => { if (!this.isContextExtra(extra)) { throw new Error('Internal error: Expected Context for request handler extra'); } return innerHandler(request, extra); - } + }; return decoratedHandler; - } + }; const shape = getObjectShape(requestSchema); const methodSchema = shape?.method; @@ -496,45 +503,23 @@ export class Server< return this._capabilities; } - protected createRequestExtra( - args: { - request: JSONRPCRequest, - taskStore: TaskStore | undefined, - relatedTaskId: string | undefined, - taskCreationParams: TaskCreationParams | undefined, - abortController: AbortController, - capturedTransport: Transport | undefined, - extra?: MessageExtraInfo - } - ): RequestHandlerExtra { - const base = super.createRequestExtra(args) as RequestHandlerExtra< - ServerRequest | RequestT, - ServerNotification | NotificationT - >; - - // Wrap base in Context to add server utilities while preserving shape - const requestCtx = new RequestContext< - ServerRequest | RequestT, - ServerNotification | NotificationT, - ServerResult | ResultT - >({ - signal: base.signal, - authInfo: base.authInfo, - requestInfo: base.requestInfo, - requestId: base.requestId, - _meta: base._meta, - sessionId: base.sessionId, - protocol: this, - closeSSEStream: base.closeSSEStream ?? undefined, - closeStandaloneSSEStream: base.closeStandaloneSSEStream ?? undefined - }); - - const ctx = new Context({ + protected createRequestExtra(args: { + request: JSONRPCRequest; + taskStore: TaskStore | undefined; + relatedTaskId: string | undefined; + taskCreationParams: TaskCreationParams | undefined; + abortController: AbortController; + capturedTransport: Transport | undefined; + extra?: MessageExtraInfo; + }): RequestHandlerExtra { + const base = super.createRequestExtra(args) as RequestHandlerExtra; + + // Expose a Context instance to handlers, which implements RequestHandlerExtra + return new Context({ server: this, - requestCtx + request: args.request, + requestCtx: base }); - - return ctx; } async ping() { diff --git a/src/server/mcp.ts b/src/server/mcp.ts index 7e61b4364..ba7024df9 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -63,6 +63,7 @@ import { validateAndWarnToolName } from '../shared/toolNameValidation.js'; import { ExperimentalMcpServerTasks } from '../experimental/tasks/mcp-server.js'; import type { ToolTaskHandler } from '../experimental/tasks/interfaces.js'; import { ZodOptional } from 'zod'; +import { ContextInterface } from './context.js'; /** * High-level MCP server that provides a simpler API for working with resources, tools, and prompts. @@ -324,7 +325,7 @@ export class McpServer { private async executeToolHandler( tool: RegisteredTool, args: unknown, - extra: RequestHandlerExtra + extra: ContextInterface ): Promise { const handler = tool.handler as AnyToolHandler; const isTaskHandler = 'createTask' in handler; @@ -1270,7 +1271,7 @@ export class ResourceTemplate { export type BaseToolCallback< SendResultT extends Result, - Extra extends RequestHandlerExtra, + Extra extends ContextInterface, Args extends undefined | ZodRawShapeCompat | AnySchema > = Args extends ZodRawShapeCompat ? (args: ShapeOutput, extra: Extra) => SendResultT | Promise @@ -1290,7 +1291,7 @@ export type BaseToolCallback< */ export type ToolCallback = BaseToolCallback< CallToolResult, - RequestHandlerExtra, + ContextInterface, Args >; @@ -1409,7 +1410,7 @@ export type ResourceMetadata = Omit; * Callback to list all resources matching a given template. */ export type ListResourcesCallback = ( - extra: RequestHandlerExtra + extra: ContextInterface ) => ListResourcesResult | Promise; /** @@ -1417,7 +1418,7 @@ export type ListResourcesCallback = ( */ export type ReadResourceCallback = ( uri: URL, - extra: RequestHandlerExtra + extra: ContextInterface ) => ReadResourceResult | Promise; export type RegisteredResource = { @@ -1445,7 +1446,7 @@ export type RegisteredResource = { export type ReadResourceTemplateCallback = ( uri: URL, variables: Variables, - extra: RequestHandlerExtra + extra: ContextInterface ) => ReadResourceResult | Promise; export type RegisteredResourceTemplate = { @@ -1470,8 +1471,8 @@ export type RegisteredResourceTemplate = { type PromptArgsRawShape = ZodRawShapeCompat; export type PromptCallback = Args extends PromptArgsRawShape - ? (args: ShapeOutput, extra: RequestHandlerExtra) => GetPromptResult | Promise - : (extra: RequestHandlerExtra) => GetPromptResult | Promise; + ? (args: ShapeOutput, extra: ContextInterface) => GetPromptResult | Promise + : (extra: ContextInterface) => GetPromptResult | Promise; export type RegisteredPrompt = { title?: string; diff --git a/src/server/requestContext.ts b/src/server/requestContext.ts deleted file mode 100644 index d7be71986..000000000 --- a/src/server/requestContext.ts +++ /dev/null @@ -1,109 +0,0 @@ -import { AuthInfo } from './auth/types.js'; -import { Notification, Request, RequestId, RequestInfo, RequestMeta, Result } from '../types.js'; -import { Protocol, RequestHandlerExtra, RequestTaskStore, TaskRequestOptions } from '../shared/protocol.js'; -import { AnySchema, SchemaOutput } from './zod-compat.js'; - -/** - * A context object that is passed to request handlers. - * - * Implements the RequestHandlerExtra interface for backwards compatibility. - */ -export class RequestContext< - RequestT extends Request = Request, - NotificationT extends Notification = Notification, - ResultT extends Result = Result -> implements RequestHandlerExtra -{ - /** - * An abort signal used to communicate if the request was cancelled from the sender's side. - */ - public readonly signal: AbortSignal; - - /** - * Information about a validated access token, provided to request handlers. - */ - public readonly authInfo?: AuthInfo; - - /** - * The original HTTP request. - */ - public readonly requestInfo?: RequestInfo; - - /** - * The JSON-RPC ID of the request being handled. - * This can be useful for tracking or logging purposes. - */ - public readonly requestId: RequestId; - - /** - * Metadata from the original request. - */ - public readonly _meta?: RequestMeta; - - /** - * The session ID from the transport, if available. - */ - public readonly sessionId?: string; - - /** - * The task store, if available. - */ - public readonly taskStore?: RequestTaskStore; - - public readonly taskId?: string; - - public readonly taskRequestedTtl?: number | null; - - private readonly protocol: Protocol; - constructor(args: { - signal: AbortSignal; - authInfo?: AuthInfo; - requestInfo?: RequestInfo; - requestId: RequestId; - _meta?: RequestMeta; - sessionId?: string; - protocol: Protocol; - taskStore?: RequestTaskStore; - taskId?: string; - taskRequestedTtl?: number | null; - closeSSEStream: (() => void) | undefined; - closeStandaloneSSEStream: (() => void) | undefined; - }) { - this.signal = args.signal; - this.authInfo = args.authInfo; - this.requestInfo = args.requestInfo; - this.requestId = args.requestId; - this._meta = args._meta; - this.sessionId = args.sessionId; - this.protocol = args.protocol; - this.taskStore = args.taskStore; - this.taskId = args.taskId; - this.taskRequestedTtl = args.taskRequestedTtl; - } - - /** - * Sends a notification that relates to the current request being handled. - * - * This is used by certain transports to correctly associate related messages. - */ - public sendNotification = (notification: NotificationT): Promise => { - return this.protocol.notification(notification, { relatedRequestId: this.requestId }); - }; - - /** - * Sends a request that relates to the current request being handled. - * - * This is used by certain transports to correctly associate related messages. - */ - public sendRequest = (request: RequestT, resultSchema: U, options?: TaskRequestOptions): Promise> => { - return this.protocol.request(request, resultSchema, { ...options, relatedRequestId: this.requestId }); - }; - - public closeSSEStream = (): void => { - return this.closeSSEStream(); - } - - public closeStandaloneSSEStream = (): void => { - return this.closeStandaloneSSEStream(); - } -} \ No newline at end of file diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 5ec1d0151..87d70b10d 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -231,6 +231,8 @@ export interface RequestTaskStore { /** * Extra data given to request handlers. + * + * @deprecated Use {@link ContextInterface} from {@link Context} instead. Future major versions will remove this type. */ export type RequestHandlerExtra = { /** @@ -709,7 +711,15 @@ export abstract class Protocol = this.createRequestExtra({ request, taskStore, relatedTaskId, taskCreationParams, abortController, capturedTransport, extra }); + const fullExtra: RequestHandlerExtra = this.createRequestExtra({ + request, + taskStore, + relatedTaskId, + taskCreationParams, + abortController, + capturedTransport, + extra + }); // Starting with Promise.resolve() puts any synchronous errors into the monad as well. Promise.resolve() @@ -791,17 +801,15 @@ export abstract class Protocol { + protected createRequestExtra(args: { + request: JSONRPCRequest; + taskStore: TaskStore | undefined; + relatedTaskId: string | undefined; + taskCreationParams: TaskCreationParams | undefined; + abortController: AbortController; + capturedTransport: Transport | undefined; + extra?: MessageExtraInfo; + }): RequestHandlerExtra { const { request, taskStore, relatedTaskId, taskCreationParams, abortController, capturedTransport, extra } = args; return { diff --git a/test/server/context.test.ts b/test/server/context.test.ts new file mode 100644 index 000000000..8c9518ae8 --- /dev/null +++ b/test/server/context.test.ts @@ -0,0 +1,277 @@ +import { z } from 'zod/v4'; +import { Client } from '../../src/client/index.js'; +import { McpServer, ResourceTemplate } from '../../src/server/mcp.js'; +import { Context } from '../../src/server/context.js'; +import { + CallToolResultSchema, + GetPromptResultSchema, + ListResourcesResultSchema, + LoggingMessageNotificationSchema, + ReadResourceResultSchema, + ServerNotification, + ServerRequest +} from '../../src/types.js'; +import { InMemoryTransport } from '../../src/inMemory.js'; +import { RequestHandlerExtra } from '../../src/shared/protocol.js'; +import { ShapeOutput, ZodRawShapeCompat } from '../../src/server/zod-compat.js'; + +describe('Context', () => { + /*** + * Test: `extra` provided to callbacks is Context (parameterized) + */ + type Seen = { isContext: boolean; hasRequestId: boolean }; + const contextCases: Array<[string, (mcpServer: McpServer, seen: Seen) => void | Promise, (client: Client) => Promise]> = + [ + [ + 'tool', + (mcpServer, seen) => { + mcpServer.registerTool( + 'ctx-tool', + { + inputSchema: z.object({ name: z.string() }) + }, + (_args: { name: string }, extra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { content: [{ type: 'text', text: 'ok' }] }; + } + ); + }, + client => + client.request( + { + method: 'tools/call', + params: { + name: 'ctx-tool', + arguments: { + name: 'ctx-tool-name' + } + } + }, + CallToolResultSchema + ) + ], + [ + 'resource', + (mcpServer, seen) => { + mcpServer.registerResource('ctx-resource', 'test://res/1', { title: 'ctx-resource' }, async (_uri, extra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { contents: [{ uri: 'test://res/1', mimeType: 'text/plain', text: 'hello' }] }; + }); + }, + client => client.request({ method: 'resources/read', params: { uri: 'test://res/1' } }, ReadResourceResultSchema) + ], + [ + 'resource template list', + (mcpServer, seen) => { + const template = new ResourceTemplate('test://items/{id}', { + list: async extra => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { resources: [] }; + } + }); + mcpServer.registerResource('ctx-template', template, { title: 'ctx-template' }, async (_uri, _vars, _extra) => ({ + contents: [] + })); + }, + client => client.request({ method: 'resources/list', params: {} }, ListResourcesResultSchema) + ], + [ + 'prompt', + (mcpServer, seen) => { + mcpServer.registerPrompt('ctx-prompt', {}, async extra => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { messages: [] }; + }); + }, + client => client.request({ method: 'prompts/get', params: { name: 'ctx-prompt', arguments: {} } }, GetPromptResultSchema) + ] + ]; + + test.each(contextCases)('should pass Context as extra to %s callbacks', async (_kind, register, trigger) => { + const mcpServer = new McpServer({ name: 'ctx-test', version: '1.0' }); + const client = new Client({ name: 'ctx-client', version: '1.0' }); + + const seen: Seen = { isContext: false, hasRequestId: false }; + + await register(mcpServer, seen); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await trigger(client); + + expect(seen.isContext).toBe(true); + expect(seen.hasRequestId).toBe(true); + }); + + const logLevelsThroughContext = ['debug', 'info', 'warning', 'error'] as const; + + //it.each for each log level, test that logging message is sent to client + it.each(logLevelsThroughContext)('should send logging message to client for %s level from Context', async level => { + const mcpServer = new McpServer( + { name: 'ctx-test', version: '1.0' }, + { + capabilities: { + logging: {} + } + } + ); + const client = new Client( + { name: 'ctx-client', version: '1.0' }, + { + capabilities: {} + } + ); + + let seen = 0; + + client.setNotificationHandler(LoggingMessageNotificationSchema, notification => { + seen++; + expect(notification.params.level).toBe(level); + expect(notification.params.data).toBe('Test message'); + expect(notification.params.test).toBe('test'); + expect(notification.params.sessionId).toBe('sample-session-id'); + return; + }); + + mcpServer.registerTool('ctx-log-test', { inputSchema: z.object({ name: z.string() }) }, async (_args: { name: string }, extra) => { + await extra[level]('Test message', { test: 'test' }, 'sample-session-id'); + await extra.log( + { + level, + data: 'Test message', + logger: 'test-logger-namespace' + }, + 'sample-session-id' + ); + return { content: [{ type: 'text', text: 'ok' }] }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/call', + params: { name: 'ctx-log-test', arguments: { name: 'ctx-log-test-name' } } + }, + CallToolResultSchema + ); + + // two messages should have been sent - one from the .log method and one from the .debug/info/warning/error method + expect(seen).toBe(2); + + expect(result.content).toHaveLength(1); + expect(result.content[0]).toMatchObject({ + type: 'text', + text: 'ok' + }); + }); + describe('Legacy RequestHandlerExtra API', () => { + const contextCases: Array< + [string, (mcpServer: McpServer, seen: Seen) => void | Promise, (client: Client) => Promise] + > = [ + [ + 'tool', + (mcpServer, seen) => { + mcpServer.registerTool( + 'ctx-tool', + { + inputSchema: z.object({ name: z.string() }) + }, + // The test is to ensure that the extra is compatible with the RequestHandlerExtra type + (_args: { name: string }, extra: RequestHandlerExtra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { content: [{ type: 'text', text: 'ok' }] }; + } + ); + }, + client => + client.request( + { + method: 'tools/call', + params: { + name: 'ctx-tool', + arguments: { + name: 'ctx-tool-name' + } + } + }, + CallToolResultSchema + ) + ], + [ + 'resource', + (mcpServer, seen) => { + // The test is to ensure that the extra is compatible with the RequestHandlerExtra type + mcpServer.registerResource( + 'ctx-resource', + 'test://res/1', + { title: 'ctx-resource' }, + async (_uri, extra: RequestHandlerExtra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { contents: [{ uri: 'test://res/1', mimeType: 'text/plain', text: 'hello' }] }; + } + ); + }, + client => client.request({ method: 'resources/read', params: { uri: 'test://res/1' } }, ReadResourceResultSchema) + ], + [ + 'resource template list', + (mcpServer, seen) => { + // The test is to ensure that the extra is compatible with the RequestHandlerExtra type + const template = new ResourceTemplate('test://items/{id}', { + list: async (extra: RequestHandlerExtra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { resources: [] }; + } + }); + mcpServer.registerResource('ctx-template', template, { title: 'ctx-template' }, async (_uri, _vars, _extra) => ({ + contents: [] + })); + }, + client => client.request({ method: 'resources/list', params: {} }, ListResourcesResultSchema) + ], + [ + 'prompt', + (mcpServer, seen) => { + // The test is to ensure that the extra is compatible with the RequestHandlerExtra type + mcpServer.registerPrompt( + 'ctx-prompt', + {}, + async (args: ShapeOutput, extra: RequestHandlerExtra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { messages: [] }; + } + ); + }, + client => client.request({ method: 'prompts/get', params: { name: 'ctx-prompt', arguments: {} } }, GetPromptResultSchema) + ] + ]; + + test.each(contextCases)('should pass Context as extra to %s callbacks', async (_kind, register, trigger) => { + const mcpServer = new McpServer({ name: 'ctx-test', version: '1.0' }); + const client = new Client({ name: 'ctx-client', version: '1.0' }); + + const seen: Seen = { isContext: false, hasRequestId: false }; + + await register(mcpServer, seen); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await trigger(client); + + expect(seen.isContext).toBe(true); + expect(seen.hasRequestId).toBe(true); + }); + }); +}); diff --git a/test/server/mcp.test.ts b/test/server/mcp.test.ts index f6c2124e1..4be2d24f4 100644 --- a/test/server/mcp.test.ts +++ b/test/server/mcp.test.ts @@ -17,12 +17,15 @@ import { ReadResourceResultSchema, type TextContent, UrlElicitationRequiredError, - ErrorCode + ErrorCode, + ServerRequest, + ServerNotification } from '../../src/types.js'; import { completable } from '../../src/server/completable.js'; import { McpServer, ResourceTemplate } from '../../src/server/mcp.js'; import { InMemoryTaskStore } from '../../src/experimental/tasks/stores/in-memory.js'; import { zodTestMatrix, type ZodMatrixEntry } from '../../src/__fixtures__/zodTestMatrix.js'; +import { Context, ContextInterface } from '../../src/server/context.js'; function createLatch() { let latch = false; @@ -243,7 +246,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { sendNotification: () => { throw new Error('Not implemented'); } - }); + } as unknown as ContextInterface); expect(result?.resources).toHaveLength(1); expect(list).toHaveBeenCalled(); }); @@ -4387,17 +4390,20 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }) } }, - async ({ department, name }) => ({ - messages: [ - { - role: 'assistant', - content: { - type: 'text', - text: `Hello ${name}, welcome to the ${department} team!` + async ({ department, name }, extra: ContextInterface) => { + expect(extra).toBeInstanceOf(Context); + return { + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Hello ${name}, welcome to the ${department} team!` + } } - } - ] - }) + ] + }; + } ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); diff --git a/test/server/streamableHttp.test.ts b/test/server/streamableHttp.test.ts index 8d94b272e..be5908f60 100644 --- a/test/server/streamableHttp.test.ts +++ b/test/server/streamableHttp.test.ts @@ -2280,8 +2280,8 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Verify we received the notification that was sent while disconnected expect(allText).toContain('Missed while disconnected'); - }); - }, 10000); + }, 10000); + }); // Test onsessionclosed callback describe('StreamableHTTPServerTransport onsessionclosed callback', () => { From f58b491fb6b15f5d170fb54ec0c296ca6d6cd7b4 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Sun, 7 Dec 2025 10:39:36 +0200 Subject: [PATCH 03/17] fixes --- src/server/context.ts | 4 ++-- src/server/mcp.ts | 2 +- test/server/context.test.ts | 15 +++++---------- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/src/server/context.ts b/src/server/context.ts index a90d3db09..ee6a97f70 100644 --- a/src/server/context.ts +++ b/src/server/context.ts @@ -146,7 +146,7 @@ export class Context => { - return this.server.notification(notification); + return this.requestCtx.sendNotification(notification); }; /** @@ -159,7 +159,7 @@ export class Context> => { - return this.server.request(request, resultSchema, { ...options, relatedRequestId: this.requestId }); + return this.requestCtx.sendRequest(request, resultSchema, { ...options, relatedRequestId: this.requestId }); }; /** diff --git a/src/server/mcp.ts b/src/server/mcp.ts index ba7024df9..6d4e1a5ef 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -1136,7 +1136,7 @@ export class McpServer { /** * Registers a prompt with a config object and callback. */ - registerPrompt( + registerPrompt( name: string, config: { title?: string; diff --git a/test/server/context.test.ts b/test/server/context.test.ts index 8c9518ae8..eaa9374c7 100644 --- a/test/server/context.test.ts +++ b/test/server/context.test.ts @@ -13,7 +13,6 @@ import { } from '../../src/types.js'; import { InMemoryTransport } from '../../src/inMemory.js'; import { RequestHandlerExtra } from '../../src/shared/protocol.js'; -import { ShapeOutput, ZodRawShapeCompat } from '../../src/server/zod-compat.js'; describe('Context', () => { /*** @@ -243,15 +242,11 @@ describe('Context', () => { 'prompt', (mcpServer, seen) => { // The test is to ensure that the extra is compatible with the RequestHandlerExtra type - mcpServer.registerPrompt( - 'ctx-prompt', - {}, - async (args: ShapeOutput, extra: RequestHandlerExtra) => { - seen.isContext = extra instanceof Context; - seen.hasRequestId = !!extra.requestId; - return { messages: [] }; - } - ); + mcpServer.registerPrompt('ctx-prompt', {}, async (extra: RequestHandlerExtra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { messages: [] }; + }); }, client => client.request({ method: 'prompts/get', params: { name: 'ctx-prompt', arguments: {} } }, GetPromptResultSchema) ] From e89d9d4c4b1c3dba226934ec0dcd320c8bde3619 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Mon, 8 Dec 2025 20:42:21 +0200 Subject: [PATCH 04/17] moved properties under objects --- src/server/context.ts | 240 ++++++++++++++++++++++++------------ test/server/context.test.ts | 4 +- 2 files changed, 162 insertions(+), 82 deletions(-) diff --git a/src/server/context.ts b/src/server/context.ts index ee6a97f70..0a541ac21 100644 --- a/src/server/context.ts +++ b/src/server/context.ts @@ -20,16 +20,118 @@ import { Server } from './index.js'; import { AuthInfo } from './auth/types.js'; import { AnySchema, SchemaOutput } from './zod-compat.js'; -export interface ContextInterface - extends RequestHandlerExtra { - elicit(params: ElicitRequest['params'], options?: RequestOptions): Promise; - requestSampling: (params: CreateMessageRequest['params'], options?: RequestOptions) => Promise; +/** + * Interface for sending logging messages to the client via {@link LoggingMessageNotification}. + */ +export interface LoggingMessageSenderInterface { + /** + * Sends a logging message to the client. + */ log(params: LoggingMessageNotification['params'], sessionId?: string): Promise; + /** + * Sends a debug log message to the client. + */ debug(message: string, extraLogData?: Record, sessionId?: string): Promise; + /** + * Sends an info log message to the client. + */ info(message: string, extraLogData?: Record, sessionId?: string): Promise; + /** + * Sends a warning log message to the client. + */ warning(message: string, extraLogData?: Record, sessionId?: string): Promise; + /** + * Sends an error log message to the client. + */ error(message: string, extraLogData?: Record, sessionId?: string): Promise; } + +export class ServerLogger implements LoggingMessageSenderInterface { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + constructor(private readonly server: Server) {} + + /** + * Sends a logging message. + */ + public async log(params: LoggingMessageNotification['params'], sessionId?: string) { + await this.server.sendLoggingMessage(params, sessionId); + } + + /** + * Sends a debug log message. + */ + public async debug(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'debug', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } + + /** + * Sends an info log message. + */ + public async info(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'info', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } + + /** + * Sends a warning log message. + */ + public async warning(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'warning', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } + + /** + * Sends an error log message. + */ + public async error(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'error', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } +} + +export interface ContextInterface + extends RequestHandlerExtra { + elicitInput(params: ElicitRequest['params'], options?: RequestOptions): Promise; + requestSampling: (params: CreateMessageRequest['params'], options?: RequestOptions) => Promise; + logger: LoggingMessageSenderInterface; +} /** * A context object that is passed to request handlers. * @@ -69,6 +171,31 @@ export class Context void) | undefined; + /** + * Closes the standalone GET SSE stream, triggering client reconnection. + * Only available when using StreamableHTTPServerTransport with eventStore configured. + * Use this to implement polling behavior for server-initiated notifications. + */ + closeStandaloneSSEStream: (() => void) | undefined; + } | undefined; + + public readonly logger: LoggingMessageSenderInterface; + constructor(args: { server: Server; request: JSONRPCRequest; @@ -82,6 +209,19 @@ export class Context void) | undefined { return this.requestCtx.closeSSEStream; } + /** + * @deprecated Use {@link stream.closeStandaloneSSEStream} instead. + */ public get closeStandaloneSSEStream(): (() => void) | undefined { return this.requestCtx.closeStandaloneSSEStream; } @@ -172,86 +327,11 @@ export class Context { + public async elicitInput(params: ElicitRequest['params'], options?: RequestOptions): Promise { const request: ElicitRequest = { method: 'elicitation/create', params }; return await this.server.request(request, ElicitResultSchema, { ...options, relatedRequestId: this.requestId }); } - - /** - * Sends a logging message. - */ - public async log(params: LoggingMessageNotification['params'], sessionId?: string) { - await this.server.sendLoggingMessage(params, sessionId); - } - - /** - * Sends a debug log message. - */ - public async debug(message: string, extraLogData?: Record, sessionId?: string) { - await this.log( - { - level: 'debug', - data: { - ...extraLogData, - message - }, - logger: 'server' - }, - sessionId - ); - } - - /** - * Sends an info log message. - */ - public async info(message: string, extraLogData?: Record, sessionId?: string) { - await this.log( - { - level: 'info', - data: { - ...extraLogData, - message - }, - logger: 'server' - }, - sessionId - ); - } - - /** - * Sends a warning log message. - */ - public async warning(message: string, extraLogData?: Record, sessionId?: string) { - await this.log( - { - level: 'warning', - data: { - ...extraLogData, - message - }, - logger: 'server' - }, - sessionId - ); - } - - /** - * Sends an error log message. - */ - public async error(message: string, extraLogData?: Record, sessionId?: string) { - await this.log( - { - level: 'error', - data: { - ...extraLogData, - message - }, - logger: 'server' - }, - sessionId - ); - } } diff --git a/test/server/context.test.ts b/test/server/context.test.ts index eaa9374c7..0f1190df0 100644 --- a/test/server/context.test.ts +++ b/test/server/context.test.ts @@ -138,8 +138,8 @@ describe('Context', () => { }); mcpServer.registerTool('ctx-log-test', { inputSchema: z.object({ name: z.string() }) }, async (_args: { name: string }, extra) => { - await extra[level]('Test message', { test: 'test' }, 'sample-session-id'); - await extra.log( + await extra.logger[level]('Test message', { test: 'test' }, 'sample-session-id'); + await extra.logger.log( { level, data: 'Test message', From 187a3cd13fccd51e8edb4aafd85f9a8860af1a57 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Mon, 8 Dec 2025 20:44:05 +0200 Subject: [PATCH 05/17] prettier fix --- src/server/context.ts | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/server/context.ts b/src/server/context.ts index 0a541ac21..4c6e858b8 100644 --- a/src/server/context.ts +++ b/src/server/context.ts @@ -179,20 +179,22 @@ export class Context void) | undefined; - /** - * Closes the standalone GET SSE stream, triggering client reconnection. - * Only available when using StreamableHTTPServerTransport with eventStore configured. - * Use this to implement polling behavior for server-initiated notifications. - */ - closeStandaloneSSEStream: (() => void) | undefined; - } | undefined; + public readonly stream: + | { + /** + * Closes the SSE stream for this request, triggering client reconnection. + * Only available when using StreamableHTTPServerTransport with eventStore configured. + * Use this to implement polling behavior during long-running operations. + */ + closeSSEStream: (() => void) | undefined; + /** + * Closes the standalone GET SSE stream, triggering client reconnection. + * Only available when using StreamableHTTPServerTransport with eventStore configured. + * Use this to implement polling behavior for server-initiated notifications. + */ + closeStandaloneSSEStream: (() => void) | undefined; + } + | undefined; public readonly logger: LoggingMessageSenderInterface; From 96169b33dcb8954f1997e0cd7e75d98af82ac3fb Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Mon, 8 Dec 2025 20:48:34 +0200 Subject: [PATCH 06/17] move logger methods under loggingNotification --- src/server/context.ts | 10 +++++----- test/server/context.test.ts | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/server/context.ts b/src/server/context.ts index 4c6e858b8..91b075615 100644 --- a/src/server/context.ts +++ b/src/server/context.ts @@ -23,7 +23,7 @@ import { AnySchema, SchemaOutput } from './zod-compat.js'; /** * Interface for sending logging messages to the client via {@link LoggingMessageNotification}. */ -export interface LoggingMessageSenderInterface { +export interface LoggingMessageNotificationSenderInterface { /** * Sends a logging message to the client. */ @@ -46,7 +46,7 @@ export interface LoggingMessageSenderInterface { error(message: string, extraLogData?: Record, sessionId?: string): Promise; } -export class ServerLogger implements LoggingMessageSenderInterface { +export class ServerLogger implements LoggingMessageNotificationSenderInterface { // eslint-disable-next-line @typescript-eslint/no-explicit-any constructor(private readonly server: Server) {} @@ -130,7 +130,7 @@ export interface ContextInterface { elicitInput(params: ElicitRequest['params'], options?: RequestOptions): Promise; requestSampling: (params: CreateMessageRequest['params'], options?: RequestOptions) => Promise; - logger: LoggingMessageSenderInterface; + loggingNotification: LoggingMessageNotificationSenderInterface; } /** * A context object that is passed to request handlers. @@ -196,7 +196,7 @@ export class Context; @@ -218,7 +218,7 @@ export class Context { }); mcpServer.registerTool('ctx-log-test', { inputSchema: z.object({ name: z.string() }) }, async (_args: { name: string }, extra) => { - await extra.logger[level]('Test message', { test: 'test' }, 'sample-session-id'); - await extra.logger.log( + await extra.loggingNotification[level]('Test message', { test: 'test' }, 'sample-session-id'); + await extra.loggingNotification.log( { level, data: 'Test message', From d5f5047ecee5d630d56929b52b4d091d8c4c8a44 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Tue, 23 Dec 2025 07:06:12 +0200 Subject: [PATCH 07/17] merge commit - v2 --- examples/server/src/simpleStreamableHttp.ts | 7 ++- .../core/src/experimental/tasks/interfaces.ts | 20 -------- .../src/experimental/tasks/interfaces.ts | 9 ++-- packages/server/src/server/mcp.ts | 10 ++-- packages/server/tsconfig.json | 2 +- test/integration/test/server/mcp.test.ts | 49 +++++++++++++------ 6 files changed, 50 insertions(+), 47 deletions(-) diff --git a/examples/server/src/simpleStreamableHttp.ts b/examples/server/src/simpleStreamableHttp.ts index 7613e3786..0332ab81a 100644 --- a/examples/server/src/simpleStreamableHttp.ts +++ b/examples/server/src/simpleStreamableHttp.ts @@ -480,6 +480,7 @@ const getServer = () => { { async createTask({ duration }, { taskStore, taskRequestedTtl }) { // Create the task + if (!taskStore) throw new Error('Task store not found'); const task = await taskStore.createTask({ ttl: taskRequestedTtl }); @@ -503,10 +504,12 @@ const getServer = () => { }; }, async getTask(_args, { taskId, taskStore }) { - return await taskStore.getTask(taskId); + if (!taskStore) throw new Error('Task store not found'); + return await taskStore.getTask(taskId!); }, async getTaskResult(_args, { taskId, taskStore }) { - const result = await taskStore.getTaskResult(taskId); + if (!taskStore) throw new Error('Task store not found'); + const result = await taskStore.getTaskResult(taskId!); return result as CallToolResult; } } diff --git a/packages/core/src/experimental/tasks/interfaces.ts b/packages/core/src/experimental/tasks/interfaces.ts index c1901d70a..4bf11942c 100644 --- a/packages/core/src/experimental/tasks/interfaces.ts +++ b/packages/core/src/experimental/tasks/interfaces.ts @@ -3,7 +3,6 @@ * WARNING: These APIs are experimental and may change without notice. */ -import type { RequestHandlerExtra, RequestTaskStore } from '../../shared/protocol.js'; import type { JSONRPCErrorResponse, JSONRPCNotification, @@ -12,8 +11,6 @@ import type { Request, RequestId, Result, - ServerNotification, - ServerRequest, Task, ToolExecution } from '../../types/types.js'; @@ -22,23 +19,6 @@ import type { // Task Handler Types (for registerToolTask) // ============================================================================ -/** - * Extended handler extra with task store for task creation. - * @experimental - */ -export interface CreateTaskRequestHandlerExtra extends RequestHandlerExtra { - taskStore: RequestTaskStore; -} - -/** - * Extended handler extra with task ID and store for task operations. - * @experimental - */ -export interface TaskRequestHandlerExtra extends RequestHandlerExtra { - taskId: string; - taskStore: RequestTaskStore; -} - /** * Task-specific execution configuration. * taskSupport cannot be 'forbidden' for task-based tools. diff --git a/packages/server/src/experimental/tasks/interfaces.ts b/packages/server/src/experimental/tasks/interfaces.ts index 0b32be213..574d256db 100644 --- a/packages/server/src/experimental/tasks/interfaces.ts +++ b/packages/server/src/experimental/tasks/interfaces.ts @@ -6,14 +6,15 @@ import type { AnySchema, CallToolResult, - CreateTaskRequestHandlerExtra, CreateTaskResult, GetTaskResult, Result, - TaskRequestHandlerExtra, + ServerNotification, + ServerRequest, ZodRawShapeCompat } from '@modelcontextprotocol/core'; +import type { ContextInterface } from '../../server/context.js'; import type { BaseToolCallback } from '../../server/mcp.js'; // ============================================================================ @@ -27,7 +28,7 @@ import type { BaseToolCallback } from '../../server/mcp.js'; export type CreateTaskRequestHandler< SendResultT extends Result, Args extends undefined | ZodRawShapeCompat | AnySchema = undefined -> = BaseToolCallback; +> = BaseToolCallback, Args>; /** * Handler for task operations (get, getResult). @@ -36,7 +37,7 @@ export type CreateTaskRequestHandler< export type TaskRequestHandler< SendResultT extends Result, Args extends undefined | ZodRawShapeCompat | AnySchema = undefined -> = BaseToolCallback; +> = BaseToolCallback, Args>; /** * Interface for task-based tool handlers. diff --git a/packages/server/src/server/mcp.ts b/packages/server/src/server/mcp.ts index 10c78b744..64cda5269 100644 --- a/packages/server/src/server/mcp.ts +++ b/packages/server/src/server/mcp.ts @@ -18,7 +18,6 @@ import type { PromptArgument, PromptReference, ReadResourceResult, - RequestHandlerExtra, Resource, ResourceTemplateReference, Result, @@ -341,7 +340,7 @@ export class McpServer { if (tool.inputSchema) { const typedHandler = handler as ToolTaskHandler; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve(typedHandler.createTask(args as any, taskExtra)); + return await Promise.resolve(typedHandler.createTask(args as any, extra)); } else { const typedHandler = handler as ToolTaskHandler; // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -366,7 +365,7 @@ export class McpServer { private async handleAutomaticTaskPolling( tool: RegisteredTool, request: RequestT, - extra: RequestHandlerExtra + extra: ContextInterface ): Promise { if (!extra.taskStore) { throw new Error('No task store provided for task-capable tool.'); @@ -375,12 +374,11 @@ export class McpServer { // Validate input and create task const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); const handler = tool.handler as ToolTaskHandler; - const taskExtra = { ...extra, taskStore: extra.taskStore }; const createTaskResult: CreateTaskResult = args // undefined only if tool.inputSchema is undefined - ? await Promise.resolve((handler as ToolTaskHandler).createTask(args, taskExtra)) + ? await Promise.resolve((handler as ToolTaskHandler).createTask(args, extra)) : // eslint-disable-next-line @typescript-eslint/no-explicit-any - await Promise.resolve(((handler as ToolTaskHandler).createTask as any)(taskExtra)); + await Promise.resolve(((handler as ToolTaskHandler).createTask as any)(extra)); // Poll until completion const taskId = createTaskResult.task.taskId; diff --git a/packages/server/tsconfig.json b/packages/server/tsconfig.json index 79594b169..a16bfd7d9 100644 --- a/packages/server/tsconfig.json +++ b/packages/server/tsconfig.json @@ -1,6 +1,6 @@ { "extends": "@modelcontextprotocol/tsconfig", - "include": ["./", "../../test/integration/test/server/context.test.ts"], + "include": ["./"], "exclude": ["node_modules", "dist"], "compilerOptions": { "paths": { diff --git a/test/integration/test/server/mcp.test.ts b/test/integration/test/server/mcp.test.ts index 27014b2fa..7e925b4ec 100644 --- a/test/integration/test/server/mcp.test.ts +++ b/test/integration/test/server/mcp.test.ts @@ -1902,16 +1902,19 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async (_args, extra) => { + if (!extra.taskStore) throw new Error('Task store not found'); const task = await extra.taskStore.createTask({ ttl: 60000 }); return { task }; }, getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const task = await extra.taskStore.getTask(extra.taskId!); if (!task) throw new Error('Task not found'); return task; }, getTaskResult: async (_args, extra) => { - return (await extra.taskStore.getTaskResult(extra.taskId)) as CallToolResult; + if (!extra.taskStore) throw new Error('Task store not found'); + return (await extra.taskStore.getTaskResult(extra.taskId!)) as CallToolResult; } } ); @@ -1971,16 +1974,17 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async (_args, extra) => { + if (!extra.taskStore) throw new Error('Task store not found'); const task = await extra.taskStore.createTask({ ttl: 60000 }); return { task }; }, getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); + const task = await extra.taskStore?.getTask(extra.taskId!); if (!task) throw new Error('Task not found'); return task; }, getTaskResult: async (_args, extra) => { - return (await extra.taskStore.getTaskResult(extra.taskId)) as CallToolResult; + return (await extra.taskStore?.getTaskResult(extra.taskId!)) as CallToolResult; } } ); @@ -6304,6 +6308,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async ({ input }, extra) => { + if (!extra.taskStore) throw new Error('Task store not found'); const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); // Capture taskStore for use in setTimeout @@ -6319,14 +6324,16 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const task = await extra.taskStore.getTask(extra.taskId!); if (!task) { throw new Error('Task not found'); } return task; }, getTaskResult: async (_input, extra) => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const result = await extra.taskStore.getTaskResult(extra.taskId!); return result as CallToolResult; } } @@ -6409,6 +6416,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async ({ value }, extra) => { + if (!extra.taskStore) throw new Error('Task store not found'); const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); // Capture taskStore for use in setTimeout @@ -6425,14 +6433,16 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const task = await extra.taskStore.getTask(extra.taskId!); if (!task) { throw new Error('Task not found'); } return task; }, getTaskResult: async (_value, extra) => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const result = await extra.taskStore.getTaskResult(extra.taskId!); return result as CallToolResult; } } @@ -6517,6 +6527,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async ({ data }, extra) => { + if (!extra.taskStore) throw new Error('Task store not found'); const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); // Capture taskStore for use in setTimeout @@ -6524,6 +6535,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Simulate async work setTimeout(async () => { + if (!store) throw new Error('Task store not found'); await store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text' as const, text: `Completed: ${data}` }] }); @@ -6533,14 +6545,16 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const task = await extra.taskStore.getTask(extra.taskId!); if (!task) { throw new Error('Task not found'); } return task; }, getTaskResult: async (_data, extra) => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const result = await extra.taskStore.getTaskResult(extra.taskId!); return result as CallToolResult; } } @@ -6634,6 +6648,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async extra => { + if (!extra.taskStore) throw new Error('Task store not found'); const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); // Capture taskStore for use in setTimeout @@ -6740,6 +6755,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async extra => { + if (!extra.taskStore) throw new Error('Task store not found'); const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); // Capture taskStore for use in setTimeout @@ -6754,14 +6770,16 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, getTask: async extra => { - const task = await extra.taskStore.getTask(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const task = await extra.taskStore.getTask(extra.taskId!); if (!task) { throw new Error('Task not found'); } return task; }, getTaskResult: async extra => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const result = await extra.taskStore.getTaskResult(extra.taskId!); return result as CallToolResult; } } @@ -6827,18 +6845,21 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async (_args, extra) => { + if (!extra.taskStore) throw new Error('Task store not found'); const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); return { task }; }, getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const task = await extra.taskStore.getTask(extra.taskId!); if (!task) { throw new Error('Task not found'); } return task; }, getTaskResult: async (_args, extra) => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const result = await extra.taskStore.getTaskResult(extra.taskId!); return result as CallToolResult; } } From deca7fd882702cf764ab2f7c7cd2e7ae8ec2d333 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Wed, 21 Jan 2026 23:52:03 +0200 Subject: [PATCH 08/17] rename method to createRequestContext, update docs --- CLAUDE.md | 45 +++++++++++++------ packages/client/src/client/client.ts | 2 +- packages/core/src/shared/protocol.ts | 8 ++-- packages/core/test/shared/protocol.test.ts | 2 +- .../shared/protocolTransportHandling.test.ts | 2 +- packages/server/src/server/server.ts | 2 +- 6 files changed, 39 insertions(+), 22 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 0f6eaeece..32d98ac8f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -136,7 +136,7 @@ When a request arrives from the remote side: 2. **`Protocol.connect()`** routes to `_onrequest()`, `_onresponse()`, or `_onnotification()` 3. **`Protocol._onrequest()`**: - Looks up handler in `_requestHandlers` map (keyed by method name) - - Creates `RequestHandlerExtra` with `signal`, `sessionId`, `sendNotification`, `sendRequest` + - Creates a context object (`ServerContext` or `ClientContext`) via `createRequestContext()` - Invokes handler, sends JSON-RPC response back via transport 4. **Handler** was registered via `setRequestHandler(Schema, handler)` @@ -144,29 +144,46 @@ When a request arrives from the remote side: ```typescript // In Client (for server→client requests like sampling, elicitation) -client.setRequestHandler(CreateMessageRequestSchema, async (request, extra) => { +client.setRequestHandler(CreateMessageRequestSchema, async (request, ctx) => { // Handle sampling request from server return { role: "assistant", content: {...}, model: "..." }; }); // In Server (for client→server requests like tools/call) -server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { +server.setRequestHandler(CallToolRequestSchema, async (request, ctx) => { // Handle tool call from client return { content: [...] }; }); ``` -### Request Handler Extra +### Request Handler Context -The `extra` parameter in handlers (`RequestHandlerExtra`) provides: +The `ctx` parameter in handlers provides a structured context with three layers: -- `signal`: AbortSignal for cancellation +**`ctx.mcpCtx`** - MCP-level context: +- `requestId`: JSON-RPC message ID +- `method`: The method being called +- `_meta`: Request metadata - `sessionId`: Transport session identifier + +**`ctx.requestCtx`** - Request-level context: +- `signal`: AbortSignal for cancellation - `authInfo`: Validated auth token info (if authenticated) -- `requestId`: JSON-RPC message ID -- `sendNotification(notification)`: Send related notification back -- `sendRequest(request, schema)`: Send related request (for bidirectional flows) -- `taskStore`: Task storage interface (if tasks enabled) +- For server: `uri`, `headers`, `stream` (HTTP details) + +**`ctx.taskCtx`** - Task context (when tasks are enabled): +- `id`: Current task ID (updates after `store.createTask()`) +- `store`: Request-scoped task store (`RequestTaskStore`) +- `requestedTtl`: Requested TTL for the task + +**Context methods**: +- `ctx.sendNotification(notification)`: Send notification back +- `ctx.sendRequest(request, schema)`: Send request (for bidirectional flows) + +For server contexts, additional helpers: +- `ctx.loggingNotification(level, data, logger)`: Send logging notification +- `ctx.requestSampling(params)`: Request sampling from client +- `ctx.elicitInput(params)`: Request user input from client ### Capability Checking @@ -197,7 +214,7 @@ const result = await server.createMessage({ }); // Client must have registered handler: -client.setRequestHandler(CreateMessageRequestSchema, async (request, extra) => { +client.setRequestHandler(CreateMessageRequestSchema, async (request, ctx) => { // Client-side LLM call return { role: "assistant", content: {...} }; }); @@ -208,8 +225,8 @@ client.setRequestHandler(CreateMessageRequestSchema, async (request, extra) => { ### Request Handler Registration (Low-Level Server) ```typescript -server.setRequestHandler(SomeRequestSchema, async (request, extra) => { - // extra contains sessionId, authInfo, sendNotification, etc. +server.setRequestHandler(SomeRequestSchema, async (request, ctx) => { + // ctx provides mcpCtx, requestCtx, taskCtx, sendNotification, sendRequest return { /* result */ }; @@ -219,7 +236,7 @@ server.setRequestHandler(SomeRequestSchema, async (request, extra) => { ### Tool Registration (High-Level McpServer) ```typescript -mcpServer.tool('tool-name', { param: z.string() }, async ({ param }, extra) => { +mcpServer.tool('tool-name', { param: z.string() }, async ({ param }, ctx) => { return { content: [{ type: 'text', text: 'result' }] }; }); ``` diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index a2195b3c3..c5888c945 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -493,7 +493,7 @@ export class Client< return super.setRequestHandler(requestSchema, handler); } - protected createRequestExtra(args: { + protected createRequestContext(args: { request: JSONRPCRequest; taskStore: TaskStore | undefined; relatedTaskId: string | undefined; diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index 585546488..4aeddeed9 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -588,7 +588,7 @@ export abstract class Protocol = this.createRequestExtra({ + const fullExtra: ContextInterface = this.createRequestContext({ request, taskStore: this._taskStore, relatedTaskId, @@ -672,7 +672,7 @@ export abstract class Protocol Date: Thu, 22 Jan 2026 00:02:11 +0200 Subject: [PATCH 09/17] fix server conformance --- packages/server/src/experimental/tasks/mcpServer.ts | 12 ++++++------ src/conformance/everything-server.ts | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/packages/server/src/experimental/tasks/mcpServer.ts b/packages/server/src/experimental/tasks/mcpServer.ts index 6fd5a6cc5..9c065f20c 100644 --- a/packages/server/src/experimental/tasks/mcpServer.ts +++ b/packages/server/src/experimental/tasks/mcpServer.ts @@ -55,16 +55,16 @@ export class ExperimentalMcpServerTasks { * inputSchema: { input: z.string() }, * execution: { taskSupport: 'required' } * }, { - * createTask: async (args, extra) => { - * const task = await extra.taskStore.createTask({ ttl: 300000 }); + * createTask: async (args, ctx) => { + * const task = await ctx.taskCtx!.store.createTask({ ttl: 300000 }); * startBackgroundWork(task.taskId, args); * return { task }; * }, - * getTask: async (args, extra) => { - * return extra.taskStore.getTask(extra.taskId); + * getTask: async (args, ctx) => { + * return ctx.taskCtx!.store.getTask(ctx.taskCtx!.id); * }, - * getTaskResult: async (args, extra) => { - * return extra.taskStore.getTaskResult(extra.taskId); + * getTaskResult: async (args, ctx) => { + * return ctx.taskCtx!.store.getTaskResult(ctx.taskCtx!.id); * } * }); * ``` diff --git a/src/conformance/everything-server.ts b/src/conformance/everything-server.ts index 7f75ae3e2..d55078cbf 100644 --- a/src/conformance/everything-server.ts +++ b/src/conformance/everything-server.ts @@ -1028,4 +1028,4 @@ const PORT = process.env.PORT || 3000; app.listen(PORT, () => { console.log(`MCP Conformance Test Server running on http://localhost:${PORT}`); console.log(` - MCP endpoint: http://localhost:${PORT}/mcp`); -}); +}); \ No newline at end of file From f60947827e2445d1039faf42df7b070d34ce69e9 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Thu, 22 Jan 2026 00:15:05 +0200 Subject: [PATCH 10/17] fix conformance --- package.json | 2 +- pnpm-lock.yaml | 18 +++++++++--------- src/conformance/everything-server.ts | 14 +++++++------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/package.json b/package.json index 58e6a160a..3fcffb96e 100644 --- a/package.json +++ b/package.json @@ -47,7 +47,7 @@ "@changesets/cli": "^2.29.8", "@eslint/js": "catalog:devTools", "@modelcontextprotocol/client": "workspace:^", - "@modelcontextprotocol/conformance": "0.1.9", + "@modelcontextprotocol/conformance": "0.1.10", "@modelcontextprotocol/node": "workspace:^", "@modelcontextprotocol/server": "workspace:^", "@pnpm/workspace.find-packages": "^1000.0.54", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index c2d18d927..6b980d0be 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -147,8 +147,8 @@ importers: specifier: workspace:^ version: link:packages/client '@modelcontextprotocol/conformance': - specifier: 0.1.9 - version: 0.1.9(@cfworker/json-schema@4.1.1)(hono@4.11.3) + specifier: 0.1.10 + version: 0.1.10(@cfworker/json-schema@4.1.1)(hono@4.11.3) '@modelcontextprotocol/node': specifier: workspace:^ version: link:packages/middleware/node @@ -1292,12 +1292,12 @@ packages: '@manypkg/get-packages@1.1.3': resolution: {integrity: sha512-fo+QhuU3qE/2TQMQmbVMqaQ6EWbMhi4ABWP+O4AM1NqPBuy0OrApV5LO6BrrgnhtAHS2NH6RrVk9OL181tTi8A==} - '@modelcontextprotocol/conformance@0.1.9': - resolution: {integrity: sha512-hpR5PoW0feue3LHSi1kJNhQxbySEQNWR6McuB3QCoK0zsxIdoq+id4GxRwWVOnRnjOiTecDKMD1QMfXuurDZPQ==} + '@modelcontextprotocol/conformance@0.1.10': + resolution: {integrity: sha512-efzLxW3sNiC48ARADxNkSNSZREdjtQNJ+12MJHCUDSnHZMbiFa7v5SivB2Aja4LCE0ZiWgnG5I5PSNMBrp1xKg==} hasBin: true - '@modelcontextprotocol/sdk@1.25.1': - resolution: {integrity: sha512-yO28oVFFC7EBoiKdAn+VqRm+plcfv4v0xp6osG/VsCB0NlPZWi87ajbCZZ8f/RvOFLEu7//rSRmuZZ7lMoe3gQ==} + '@modelcontextprotocol/sdk@1.25.2': + resolution: {integrity: sha512-LZFeo4F9M5qOhC/Uc1aQSrBHxMrvxett+9KLHt7OhcExtoiRN9DKgbZffMP/nxjutWDQpfMDfP3nkHI4X9ijww==} engines: {node: '>=18'} peerDependencies: '@cfworker/json-schema': ^4.1.1 @@ -5972,9 +5972,9 @@ snapshots: globby: 11.1.0 read-yaml-file: 1.1.0 - '@modelcontextprotocol/conformance@0.1.9(@cfworker/json-schema@4.1.1)(hono@4.11.3)': + '@modelcontextprotocol/conformance@0.1.10(@cfworker/json-schema@4.1.1)(hono@4.11.3)': dependencies: - '@modelcontextprotocol/sdk': 1.25.1(@cfworker/json-schema@4.1.1)(hono@4.11.3)(zod@3.25.76) + '@modelcontextprotocol/sdk': 1.25.2(@cfworker/json-schema@4.1.1)(hono@4.11.3)(zod@3.25.76) commander: 14.0.2 eventsource-parser: 3.0.6 express: 5.2.1 @@ -5985,7 +5985,7 @@ snapshots: - hono - supports-color - '@modelcontextprotocol/sdk@1.25.1(@cfworker/json-schema@4.1.1)(hono@4.11.3)(zod@3.25.76)': + '@modelcontextprotocol/sdk@1.25.2(@cfworker/json-schema@4.1.1)(hono@4.11.3)(zod@3.25.76)': dependencies: '@hono/node-server': 1.19.8(hono@4.11.3) ajv: 8.17.1 diff --git a/src/conformance/everything-server.ts b/src/conformance/everything-server.ts index d55078cbf..637c1ed3f 100644 --- a/src/conformance/everything-server.ts +++ b/src/conformance/everything-server.ts @@ -251,7 +251,7 @@ function createMcpServer(sessionId?: string) { inputSchema: {} }, async (_args, extra): Promise => { - const progressToken = extra._meta?.progressToken ?? 0; + const progressToken = extra.mcpCtx._meta?.progressToken ?? 0; console.log('Progress token:', progressToken); await extra.sendNotification({ method: 'notifications/progress', @@ -313,20 +313,20 @@ function createMcpServer(sessionId?: string) { async (_args, extra): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); - console.log(`[${extra.sessionId}] Starting test_reconnection tool...`); + console.log(`[${extra.mcpCtx.sessionId}] Starting test_reconnection tool...`); // Get the transport for this session - const transport = extra.sessionId ? transports[extra.sessionId] : undefined; - if (transport && extra.requestId) { + const transport = extra.mcpCtx.sessionId ? transports[extra.mcpCtx.sessionId] : undefined; + if (transport && extra.mcpCtx.requestId) { // Close the SSE stream to trigger client reconnection - console.log(`[${extra.sessionId}] Closing SSE stream to trigger client polling...`); - transport.closeSSEStream(extra.requestId); + console.log(`[${extra.mcpCtx.sessionId}] Closing SSE stream to trigger client polling...`); + transport.closeSSEStream(extra.mcpCtx.requestId); } // Wait for client to reconnect (should respect retry field) await sleep(100); - console.log(`[${extra.sessionId}] test_reconnection tool complete`); + console.log(`[${extra.mcpCtx.sessionId}] test_reconnection tool complete`); return { content: [ From 3c4f9d859bad75402e198590f39932a762371359 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Thu, 22 Jan 2026 00:25:48 +0200 Subject: [PATCH 11/17] move types --- CLAUDE.md | 5 + examples/server/src/ssePollingExample.ts | 3 +- packages/client/src/client/client.ts | 2 +- packages/client/src/client/context.ts | 18 ++- .../client/test/client/streamableHttp.test.ts | 138 ++++++++++-------- packages/core/src/shared/context.ts | 67 +-------- .../node/test/streamableHttp.test.ts | 69 ++++----- packages/server/src/server/context.ts | 40 ++++- packages/server/src/server/server.ts | 2 +- src/conformance/everything-server.ts | 2 +- 10 files changed, 173 insertions(+), 173 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 32d98ac8f..68302a22b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -161,26 +161,31 @@ server.setRequestHandler(CallToolRequestSchema, async (request, ctx) => { The `ctx` parameter in handlers provides a structured context with three layers: **`ctx.mcpCtx`** - MCP-level context: + - `requestId`: JSON-RPC message ID - `method`: The method being called - `_meta`: Request metadata - `sessionId`: Transport session identifier **`ctx.requestCtx`** - Request-level context: + - `signal`: AbortSignal for cancellation - `authInfo`: Validated auth token info (if authenticated) - For server: `uri`, `headers`, `stream` (HTTP details) **`ctx.taskCtx`** - Task context (when tasks are enabled): + - `id`: Current task ID (updates after `store.createTask()`) - `store`: Request-scoped task store (`RequestTaskStore`) - `requestedTtl`: Requested TTL for the task **Context methods**: + - `ctx.sendNotification(notification)`: Send notification back - `ctx.sendRequest(request, schema)`: Send request (for bidirectional flows) For server contexts, additional helpers: + - `ctx.loggingNotification(level, data, logger)`: Send logging notification - `ctx.requestSampling(params)`: Request sampling from client - `ctx.elicitInput(params)`: Request user input from client diff --git a/examples/server/src/ssePollingExample.ts b/examples/server/src/ssePollingExample.ts index b7de18eb1..10591de9b 100644 --- a/examples/server/src/ssePollingExample.ts +++ b/examples/server/src/ssePollingExample.ts @@ -14,10 +14,9 @@ */ import { randomUUID } from 'node:crypto'; -import type { ServerRequestContext } from '@modelcontextprotocol/core'; import { createMcpExpressApp } from '@modelcontextprotocol/express'; import { NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/node'; -import type { CallToolResult } from '@modelcontextprotocol/server'; +import type { CallToolResult, ServerRequestContext } from '@modelcontextprotocol/server'; import { McpServer } from '@modelcontextprotocol/server'; import cors from 'cors'; import type { Request, Response } from 'express'; diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index c5888c945..4c54c7f7e 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -5,7 +5,6 @@ import type { ClientCapabilities, ClientNotification, ClientRequest, - ClientRequestContext, ClientResult, CompatibilityCallToolResultSchema, CompleteRequest, @@ -79,6 +78,7 @@ import { } from '@modelcontextprotocol/core'; import { ExperimentalClientTasks } from '../experimental/tasks/client.js'; +import type { ClientRequestContext } from './context.js'; import { ClientContext } from './context.js'; /** diff --git a/packages/client/src/client/context.ts b/packages/client/src/client/context.ts index 74143f413..48965db5e 100644 --- a/packages/client/src/client/context.ts +++ b/packages/client/src/client/context.ts @@ -1,12 +1,12 @@ import type { + BaseRequestContext, ClientNotification, ClientRequest, - ClientRequestContext, + ClientResult, ContextInterface, JSONRPCRequest, McpContext, Notification, - ProtocolInterface, Request, Result, TaskContext @@ -15,6 +15,16 @@ import { BaseContext } from '@modelcontextprotocol/core'; import type { Client } from './client.js'; +/** + * Client-specific request context. + * Clients don't receive HTTP requests, so this is minimal. + * Extends BaseRequestContext with any client-specific fields. + */ +export type ClientRequestContext = BaseRequestContext & { + // Client doesn't receive HTTP requests, just JSON-RPC messages over transport. + // Additional client-specific fields can be added here if needed. +}; + /** * Type alias for client-side request handler context. * Extends the base ContextInterface with ClientRequestContext. @@ -34,7 +44,7 @@ export class ClientContext< NotificationT extends Notification = Notification, ResultT extends Result = Result > - extends BaseContext + extends BaseContext implements ClientContextInterface { private readonly client: Client; @@ -58,7 +68,7 @@ export class ClientContext< /** * Returns the client instance for sending notifications and requests. */ - protected getProtocol(): ProtocolInterface { + protected getProtocol(): Client { return this.client; } } diff --git a/packages/client/test/client/streamableHttp.test.ts b/packages/client/test/client/streamableHttp.test.ts index 0c5d2dc01..13d746ca9 100644 --- a/packages/client/test/client/streamableHttp.test.ts +++ b/packages/client/test/client/streamableHttp.test.ts @@ -1,6 +1,6 @@ import type { JSONRPCMessage, JSONRPCRequest } from '@modelcontextprotocol/core'; import { InvalidClientError, InvalidGrantError, UnauthorizedClientError } from '@modelcontextprotocol/core'; -import { type Mock, type Mocked } from 'vitest'; +import type { Mock, Mocked } from 'vitest'; import type { OAuthClientProvider } from '../../src/client/auth.js'; import { UnauthorizedError } from '../../src/client/auth.js'; @@ -28,7 +28,7 @@ describe('StreamableHTTPClientTransport', () => { invalidateCredentials: vi.fn() }; transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider: mockAuthProvider }); - vi.spyOn(global, 'fetch'); + vi.spyOn(globalThis, 'fetch'); }); afterEach(async () => { @@ -44,7 +44,7 @@ describe('StreamableHTTPClientTransport', () => { id: 'test-id' }; - (global.fetch as Mock).mockResolvedValueOnce({ + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() @@ -52,7 +52,7 @@ describe('StreamableHTTPClientTransport', () => { await transport.send(message); - expect(global.fetch).toHaveBeenCalledWith( + expect(globalThis.fetch).toHaveBeenCalledWith( expect.anything(), expect.objectContaining({ method: 'POST', @@ -68,7 +68,7 @@ describe('StreamableHTTPClientTransport', () => { { jsonrpc: '2.0', method: 'test2', params: {}, id: 'id2' } ]; - (global.fetch as Mock).mockResolvedValueOnce({ + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 200, headers: new Headers({ 'content-type': 'text/event-stream' }), @@ -77,7 +77,7 @@ describe('StreamableHTTPClientTransport', () => { await transport.send(messages); - expect(global.fetch).toHaveBeenCalledWith( + expect(globalThis.fetch).toHaveBeenCalledWith( expect.anything(), expect.objectContaining({ method: 'POST', @@ -98,7 +98,7 @@ describe('StreamableHTTPClientTransport', () => { id: 'init-id' }; - (global.fetch as Mock).mockResolvedValueOnce({ + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 200, headers: new Headers({ 'content-type': 'text/event-stream', 'mcp-session-id': 'test-session-id' }) @@ -107,7 +107,7 @@ describe('StreamableHTTPClientTransport', () => { await transport.send(message); // Send a second message that should include the session ID - (global.fetch as Mock).mockResolvedValueOnce({ + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() @@ -116,8 +116,8 @@ describe('StreamableHTTPClientTransport', () => { await transport.send({ jsonrpc: '2.0', method: 'test', params: {} } as JSONRPCMessage); // Check that second request included session ID header - const calls = (global.fetch as Mock).mock.calls; - const lastCall = calls[calls.length - 1]!; + const calls = (globalThis.fetch as Mock).mock.calls; + const lastCall = calls.at(-1)!; expect(lastCall[1].headers).toBeDefined(); expect(lastCall[1].headers.get('mcp-session-id')).toBe('test-session-id'); }); @@ -134,7 +134,7 @@ describe('StreamableHTTPClientTransport', () => { id: 'init-id' }; - (global.fetch as Mock).mockResolvedValueOnce({ + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 200, headers: new Headers({ 'content-type': 'text/event-stream', 'mcp-session-id': 'test-session-id' }) @@ -144,7 +144,7 @@ describe('StreamableHTTPClientTransport', () => { expect(transport.sessionId).toBe('test-session-id'); // Now terminate the session - (global.fetch as Mock).mockResolvedValueOnce({ + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 200, headers: new Headers() @@ -153,8 +153,8 @@ describe('StreamableHTTPClientTransport', () => { await transport.terminateSession(); // Verify the DELETE request was sent with the session ID - const calls = (global.fetch as Mock).mock.calls; - const lastCall = calls[calls.length - 1]!; + const calls = (globalThis.fetch as Mock).mock.calls; + const lastCall = calls.at(-1)!; expect(lastCall[1].method).toBe('DELETE'); expect(lastCall[1].headers.get('mcp-session-id')).toBe('test-session-id'); @@ -174,7 +174,7 @@ describe('StreamableHTTPClientTransport', () => { id: 'init-id' }; - (global.fetch as Mock).mockResolvedValueOnce({ + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 200, headers: new Headers({ 'content-type': 'text/event-stream', 'mcp-session-id': 'test-session-id' }) @@ -183,7 +183,7 @@ describe('StreamableHTTPClientTransport', () => { await transport.send(message); // Now terminate the session, but server responds with 405 - (global.fetch as Mock).mockResolvedValueOnce({ + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: false, status: 405, statusText: 'Method Not Allowed', @@ -201,7 +201,7 @@ describe('StreamableHTTPClientTransport', () => { id: 'test-id' }; - (global.fetch as Mock).mockResolvedValueOnce({ + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: false, status: 404, statusText: 'Not Found', @@ -230,7 +230,7 @@ describe('StreamableHTTPClientTransport', () => { id: 'test-id' }; - (global.fetch as Mock).mockResolvedValueOnce({ + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 200, headers: new Headers({ 'content-type': 'application/json' }), @@ -247,7 +247,7 @@ describe('StreamableHTTPClientTransport', () => { it('should attempt initial GET connection and handle 405 gracefully', async () => { // Mock the server not supporting GET for SSE (returning 405) - (global.fetch as Mock).mockResolvedValueOnce({ + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: false, status: 405, statusText: 'Method Not Allowed' @@ -258,7 +258,7 @@ describe('StreamableHTTPClientTransport', () => { await transport.start(); await expect(transport['_startOrAuthSse']({})).resolves.not.toThrow('Failed to open SSE stream: Method Not Allowed'); // Check that GET was attempted - expect(global.fetch).toHaveBeenCalledWith( + expect(globalThis.fetch).toHaveBeenCalledWith( expect.anything(), expect.objectContaining({ method: 'GET', @@ -267,14 +267,14 @@ describe('StreamableHTTPClientTransport', () => { ); // Verify transport still works after 405 - (global.fetch as Mock).mockResolvedValueOnce({ + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); await transport.send({ jsonrpc: '2.0', method: 'test', params: {} } as JSONRPCMessage); - expect(global.fetch).toHaveBeenCalledTimes(2); + expect(globalThis.fetch).toHaveBeenCalledTimes(2); }); it('should handle successful initial GET connection for SSE', async () => { @@ -289,7 +289,7 @@ describe('StreamableHTTPClientTransport', () => { }); // Mock successful GET connection - (global.fetch as Mock).mockResolvedValueOnce({ + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 200, headers: new Headers({ 'content-type': 'text/event-stream' }), @@ -326,7 +326,7 @@ describe('StreamableHTTPClientTransport', () => { }); }; - (global.fetch as Mock) + (globalThis.fetch as Mock) .mockResolvedValueOnce({ ok: true, status: 200, @@ -376,7 +376,7 @@ describe('StreamableHTTPClientTransport', () => { transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions: { initialReconnectionDelay: 500, - maxReconnectionDelay: 10000, + maxReconnectionDelay: 10_000, reconnectionDelayGrowFactor: 2, maxRetries: 5 } @@ -396,7 +396,7 @@ describe('StreamableHTTPClientTransport', () => { transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); // Mock fetch to verify headers sent - const fetchSpy = global.fetch as Mock; + const fetchSpy = globalThis.fetch as Mock; fetchSpy.mockReset(); fetchSpy.mockResolvedValue({ ok: true, @@ -444,7 +444,7 @@ describe('StreamableHTTPClientTransport', () => { const errorSpy = vi.fn(); transport.onerror = errorSpy; - (global.fetch as Mock).mockResolvedValueOnce({ + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 200, headers: new Headers({ 'content-type': 'text/plain' }), @@ -477,7 +477,7 @@ describe('StreamableHTTPClientTransport', () => { expect(customFetch).toHaveBeenCalled(); // Global fetch should never have been called - expect(global.fetch).not.toHaveBeenCalled(); + expect(globalThis.fetch).not.toHaveBeenCalled(); }); it('should always send specified custom headers', async () => { @@ -493,7 +493,7 @@ describe('StreamableHTTPClientTransport', () => { let actualReqInit: RequestInit = {}; - (global.fetch as Mock).mockImplementation(async (_url, reqInit) => { + (globalThis.fetch as Mock).mockImplementation(async (_url, reqInit) => { actualReqInit = reqInit; return new Response(null, { status: 200, headers: { 'content-type': 'text/event-stream' } }); }); @@ -509,7 +509,7 @@ describe('StreamableHTTPClientTransport', () => { await transport.send({ jsonrpc: '2.0', method: 'test', params: {} } as JSONRPCMessage); expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('SecondCustomValue'); - expect(global.fetch).toHaveBeenCalledTimes(2); + expect(globalThis.fetch).toHaveBeenCalledTimes(2); }); it('should always send specified custom headers (Headers class)', async () => { @@ -525,7 +525,7 @@ describe('StreamableHTTPClientTransport', () => { let actualReqInit: RequestInit = {}; - (global.fetch as Mock).mockImplementation(async (_url, reqInit) => { + (globalThis.fetch as Mock).mockImplementation(async (_url, reqInit) => { actualReqInit = reqInit; return new Response(null, { status: 200, headers: { 'content-type': 'text/event-stream' } }); }); @@ -541,7 +541,7 @@ describe('StreamableHTTPClientTransport', () => { await transport.send({ jsonrpc: '2.0', method: 'test', params: {} } as JSONRPCMessage); expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('SecondCustomValue'); - expect(global.fetch).toHaveBeenCalledTimes(2); + expect(globalThis.fetch).toHaveBeenCalledTimes(2); }); it('should always send specified custom headers (array of tuples)', async () => { @@ -556,7 +556,7 @@ describe('StreamableHTTPClientTransport', () => { let actualReqInit: RequestInit = {}; - (global.fetch as Mock).mockImplementation(async (_url, reqInit) => { + (globalThis.fetch as Mock).mockImplementation(async (_url, reqInit) => { actualReqInit = reqInit; return new Response(null, { status: 200, headers: { 'content-type': 'text/event-stream' } }); }); @@ -608,18 +608,22 @@ describe('StreamableHTTPClientTransport', () => { id: 'test-id' }; - (global.fetch as Mock) + (globalThis.fetch as Mock) .mockResolvedValueOnce({ ok: false, status: 401, statusText: 'Unauthorized', headers: new Headers(), - text: async () => Promise.reject('dont read my body') + text: async () => { + throw 'dont read my body'; + } }) .mockResolvedValue({ ok: false, status: 404, - text: async () => Promise.reject('dont read my body') + text: async () => { + throw 'dont read my body'; + } }); await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); @@ -634,7 +638,7 @@ describe('StreamableHTTPClientTransport', () => { id: 'test-id' }; - const fetchMock = global.fetch as Mock; + const fetchMock = globalThis.fetch as Mock; fetchMock // First call: returns 403 with insufficient_scope .mockResolvedValueOnce({ @@ -685,7 +689,7 @@ describe('StreamableHTTPClientTransport', () => { }; // Mock fetch calls to always return 403 with insufficient_scope - const fetchMock = global.fetch as Mock; + const fetchMock = globalThis.fetch as Mock; fetchMock.mockResolvedValue({ ok: false, status: 403, @@ -745,7 +749,7 @@ describe('StreamableHTTPClientTransport', () => { } }); - const fetchMock = global.fetch as Mock; + const fetchMock = globalThis.fetch as Mock; // Mock the initial GET request, which will fail. fetchMock.mockResolvedValueOnce({ ok: true, @@ -799,7 +803,7 @@ describe('StreamableHTTPClientTransport', () => { } }); - const fetchMock = global.fetch as Mock; + const fetchMock = globalThis.fetch as Mock; // Mock the POST request. It returns a streaming content-type but a failing body. fetchMock.mockResolvedValueOnce({ ok: true, @@ -854,7 +858,7 @@ describe('StreamableHTTPClientTransport', () => { } }); - const fetchMock = global.fetch as Mock; + const fetchMock = globalThis.fetch as Mock; // First call: POST returns streaming response with priming event fetchMock.mockResolvedValueOnce({ ok: true, @@ -917,7 +921,7 @@ describe('StreamableHTTPClientTransport', () => { } }); - const fetchMock = global.fetch as Mock; + const fetchMock = globalThis.fetch as Mock; fetchMock.mockResolvedValueOnce({ ok: true, status: 200, @@ -966,7 +970,7 @@ describe('StreamableHTTPClientTransport', () => { } }); - const fetchMock = global.fetch as Mock; + const fetchMock = globalThis.fetch as Mock; // POST request returns streaming response fetchMock.mockResolvedValueOnce({ @@ -1016,7 +1020,7 @@ describe('StreamableHTTPClientTransport', () => { } }); - const fetchMock = global.fetch as Mock; + const fetchMock = globalThis.fetch as Mock; fetchMock.mockResolvedValueOnce({ ok: true, status: 200, @@ -1065,9 +1069,11 @@ describe('StreamableHTTPClientTransport', () => { status: 401, statusText: 'Unauthorized', headers: new Headers(), - text: async () => Promise.reject('dont read my body') + text: async () => { + throw 'dont read my body'; + } }; - (global.fetch as Mock) + (globalThis.fetch as Mock) // Initial connection .mockResolvedValueOnce(unauthedResponse) // Resource discovery, path aware @@ -1120,9 +1126,11 @@ describe('StreamableHTTPClientTransport', () => { status: 401, statusText: 'Unauthorized', headers: new Headers(), - text: async () => Promise.reject('dont read my body') + text: async () => { + throw 'dont read my body'; + } }; - (global.fetch as Mock) + (globalThis.fetch as Mock) // Initial connection .mockResolvedValueOnce(unauthedResponse) // Resource discovery, path aware @@ -1147,7 +1155,9 @@ describe('StreamableHTTPClientTransport', () => { .mockResolvedValue({ ok: false, status: 404, - text: async () => Promise.reject('dont read my body') + text: async () => { + throw 'dont read my body'; + } }); // As above, just ensure the auth flow completes without unhandled @@ -1174,9 +1184,11 @@ describe('StreamableHTTPClientTransport', () => { status: 401, statusText: 'Unauthorized', headers: new Headers(), - text: async () => Promise.reject('dont read my body') + text: async () => { + throw 'dont read my body'; + } }; - (global.fetch as Mock) + (globalThis.fetch as Mock) // Initial connection .mockResolvedValueOnce(unauthedResponse) // Resource discovery, path aware @@ -1201,7 +1213,9 @@ describe('StreamableHTTPClientTransport', () => { .mockResolvedValue({ ok: false, status: 404, - text: async () => Promise.reject('dont read my body') + text: async () => { + throw 'dont read my body'; + } }); // Behavior for InvalidGrantError during auth is covered in dedicated OAuth @@ -1217,7 +1231,9 @@ describe('StreamableHTTPClientTransport', () => { status: 401, statusText: 'Unauthorized', headers: new Headers(), - text: async () => Promise.reject('dont read my body') + text: async () => { + throw 'dont read my body'; + } }; // Create custom fetch @@ -1278,7 +1294,7 @@ describe('StreamableHTTPClientTransport', () => { expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); // Global fetch should never have been called - expect(global.fetch).not.toHaveBeenCalled(); + expect(globalThis.fetch).not.toHaveBeenCalled(); }); it('uses custom fetch in finishAuth method - no global fetch fallback', async () => { @@ -1353,14 +1369,14 @@ describe('StreamableHTTPClientTransport', () => { }); // Global fetch should never have been called - expect(global.fetch).not.toHaveBeenCalled(); + expect(globalThis.fetch).not.toHaveBeenCalled(); }); }); describe('SSE retry field handling', () => { beforeEach(() => { vi.useFakeTimers(); - (global.fetch as Mock).mockReset(); + (globalThis.fetch as Mock).mockReset(); }); afterEach(() => vi.useRealTimers()); @@ -1387,7 +1403,7 @@ describe('StreamableHTTPClientTransport', () => { } }); - const fetchMock = global.fetch as Mock; + const fetchMock = globalThis.fetch as Mock; fetchMock.mockResolvedValueOnce({ ok: true, status: 200, @@ -1463,7 +1479,7 @@ describe('StreamableHTTPClientTransport', () => { } }); - const fetchMock = global.fetch as Mock; + const fetchMock = globalThis.fetch as Mock; fetchMock.mockResolvedValueOnce({ ok: true, status: 200, @@ -1582,10 +1598,12 @@ describe('StreamableHTTPClientTransport', () => { status: 401, statusText: 'Unauthorized', headers: new Headers(), - text: async () => Promise.reject('dont read my body') + text: async () => { + throw 'dont read my body'; + } }; - (global.fetch as Mock) + (globalThis.fetch as Mock) // First request - 401, triggers auth flow .mockResolvedValueOnce(unauthedResponse) // Resource discovery, path aware diff --git a/packages/core/src/shared/context.ts b/packages/core/src/shared/context.ts index 748467501..4e117bcee 100644 --- a/packages/core/src/shared/context.ts +++ b/packages/core/src/shared/context.ts @@ -1,7 +1,7 @@ import type { RequestTaskStoreInterface } from '../experimental/requestTaskStore.js'; -import type { AuthInfo, JSONRPCRequest, Notification, Request, RequestId, RequestMeta } from '../types/types.js'; +import type { AuthInfo, JSONRPCRequest, Notification, Request, RequestId, RequestMeta, Result } from '../types/types.js'; import type { AnySchema, SchemaOutput } from '../util/zodCompat.js'; -import type { NotificationOptions, RequestOptions } from './protocol.js'; +import type { NotificationOptions, Protocol, RequestOptions } from './protocol.js'; /** * MCP-level context for a request being handled. @@ -27,22 +27,6 @@ export type McpContext = { sessionId?: string; }; -/** - * Interface for protocol operations needed by context classes. - * This allows the base context to work with both Client and Server. - */ -export interface ProtocolInterface { - /** - * Sends a notification through the protocol. - */ - notification(notification: NotificationT, options?: NotificationOptions): Promise; - - /** - * Sends a request through the protocol. - */ - request(request: RequestT, resultSchema: U, options?: RequestOptions): Promise>; -} - /** * Base request context with fields common to both client and server. */ @@ -57,48 +41,6 @@ export type BaseRequestContext = { authInfo?: AuthInfo; }; -/** - * Server-specific request context with HTTP request details. - * Extends BaseRequestContext with fields only available on the server side. - */ -export type ServerRequestContext = BaseRequestContext & { - /** - * The URI of the incoming HTTP request. - */ - uri: URL; - /** - * The headers of the incoming HTTP request. - */ - headers: Headers; - /** - * Stream control methods for SSE connections. - */ - stream: { - /** - * Closes the SSE stream for this request, triggering client reconnection. - * Only available when using StreamableHTTPServerTransport with eventStore configured. - * Use this to implement polling behavior during long-running operations. - */ - closeSSEStream: (() => void) | undefined; - /** - * Closes the standalone GET SSE stream, triggering client reconnection. - * Only available when using StreamableHTTPServerTransport with eventStore configured. - * Use this to implement polling behavior for server-initiated notifications. - */ - closeStandaloneSSEStream: (() => void) | undefined; - }; -}; - -/** - * Client-specific request context. - * Clients don't receive HTTP requests, so this is minimal. - * Extends BaseRequestContext with any client-specific fields. - */ -export type ClientRequestContext = BaseRequestContext & { - // Client doesn't receive HTTP requests, just JSON-RPC messages over transport. - // Additional client-specific fields can be added here if needed. -}; - /** * Task-related context for task-augmented requests. */ @@ -187,7 +129,8 @@ export interface BaseContextArgs implements ContextInterface { /** @@ -209,7 +152,7 @@ export abstract class BaseContext< * Returns the protocol instance for sending notifications and requests. * Subclasses must implement this to provide the appropriate Client or Server instance. */ - protected abstract getProtocol(): ProtocolInterface; + protected abstract getProtocol(): Protocol; constructor(args: BaseContextArgs) { this.mcpCtx = { diff --git a/packages/middleware/node/test/streamableHttp.test.ts b/packages/middleware/node/test/streamableHttp.test.ts index 5c061dc22..1ddcde389 100644 --- a/packages/middleware/node/test/streamableHttp.test.ts +++ b/packages/middleware/node/test/streamableHttp.test.ts @@ -10,10 +10,9 @@ import type { JSONRPCErrorResponse, JSONRPCMessage, JSONRPCResultResponse, - RequestId, - ServerRequestContext + RequestId } from '@modelcontextprotocol/core'; -import type { EventId, EventStore, StreamId } from '@modelcontextprotocol/server'; +import type { EventId, EventStore, ServerRequestContext, StreamId } from '@modelcontextprotocol/server'; import { McpServer } from '@modelcontextprotocol/server'; import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; import { listenOnRandomPort, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; @@ -27,7 +26,7 @@ async function getFreePort() { srv.listen(0, () => { const address = srv.address()!; if (typeof address === 'string') { - throw new Error('Unexpected address type: ' + typeof address); + throw new TypeError('Unexpected address type: ' + typeof address); } const port = (address as AddressInfo).port; srv.close(_err => res(port)); @@ -183,11 +182,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const server = createServer(async (req, res) => { try { - if (config.customRequestHandler) { - await config.customRequestHandler(req, res); - } else { - await transport.handleRequest(req, res); - } + await (config.customRequestHandler ? config.customRequestHandler(req, res) : transport.handleRequest(req, res)); } catch (error) { console.error('Error handling request:', error); if (!res.headersSent) res.writeHead(500).end(); @@ -304,7 +299,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(response.status).toBe(400); const errorData = await response.json(); - expectErrorResponse(errorData, -32600, /Server already initialized/); + expectErrorResponse(errorData, -32_600, /Server already initialized/); }); it('should reject batch initialize request', async () => { @@ -325,7 +320,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(response.status).toBe(400); const errorData = await response.json(); - expectErrorResponse(errorData, -32600, /Only one initialization request is allowed/); + expectErrorResponse(errorData, -32_600, /Only one initialization request is allowed/); }); it('should handle post requests via sse response correctly', async () => { @@ -343,7 +338,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const dataLine = eventLines.find(line => line.startsWith('data:')); expect(dataLine).toBeDefined(); - const eventData = JSON.parse(dataLine!.substring(5)); + const eventData = JSON.parse(dataLine!.slice(5)); expect(eventData).toMatchObject({ jsonrpc: '2.0', result: expect.objectContaining({ @@ -381,7 +376,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const dataLine = eventLines.find(line => line.startsWith('data:')); expect(dataLine).toBeDefined(); - const eventData = JSON.parse(dataLine!.substring(5)); + const eventData = JSON.parse(dataLine!.slice(5)); expect(eventData).toMatchObject({ jsonrpc: '2.0', result: { @@ -442,7 +437,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const dataLine = eventLines.find(line => line.startsWith('data:')); expect(dataLine).toBeDefined(); - const eventData = JSON.parse(dataLine!.substring(5)); + const eventData = JSON.parse(dataLine!.slice(5)); expect(eventData).toMatchObject({ jsonrpc: '2.0', @@ -475,7 +470,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(response.status).toBe(400); const errorData = (await response.json()) as JSONRPCErrorResponse; - expectErrorResponse(errorData, -32000, /Bad Request/); + expectErrorResponse(errorData, -32_000, /Bad Request/); expect(errorData.id).toBeNull(); }); @@ -488,7 +483,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(response.status).toBe(404); const errorData = await response.json(); - expectErrorResponse(errorData, -32001, /Session not found/); + expectErrorResponse(errorData, -32_001, /Session not found/); }); it('should establish standalone SSE stream and receive server-initiated messages', async () => { @@ -525,7 +520,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const dataLine = eventLines.find(line => line.startsWith('data:')); expect(dataLine).toBeDefined(); - const eventData = JSON.parse(dataLine!.substring(5)); + const eventData = JSON.parse(dataLine!.slice(5)); expect(eventData).toMatchObject({ jsonrpc: '2.0', method: 'notifications/message', @@ -593,7 +588,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Should be rejected expect(secondStream.status).toBe(409); // Conflict const errorData = await secondStream.json(); - expectErrorResponse(errorData, -32000, /Only one SSE stream is allowed per session/); + expectErrorResponse(errorData, -32_000, /Only one SSE stream is allowed per session/); }); it('should reject GET requests without Accept: text/event-stream header', async () => { @@ -611,7 +606,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(response.status).toBe(406); const errorData = await response.json(); - expectErrorResponse(errorData, -32000, /Client must accept text\/event-stream/); + expectErrorResponse(errorData, -32_000, /Client must accept text\/event-stream/); }); it('should reject POST requests without proper Accept header', async () => { @@ -630,7 +625,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(response.status).toBe(406); const errorData = await response.json(); - expectErrorResponse(errorData, -32000, /Client must accept both application\/json and text\/event-stream/); + expectErrorResponse(errorData, -32_000, /Client must accept both application\/json and text\/event-stream/); }); it('should reject unsupported Content-Type', async () => { @@ -649,7 +644,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(response.status).toBe(415); const errorData = await response.json(); - expectErrorResponse(errorData, -32000, /Content-Type must be application\/json/); + expectErrorResponse(errorData, -32_000, /Content-Type must be application\/json/); }); it('should handle JSON-RPC batch notification messages with 202 response', async () => { @@ -707,7 +702,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(response.status).toBe(400); const errorData = await response.json(); - expectErrorResponse(errorData, -32700, /Parse error/); + expectErrorResponse(errorData, -32_700, /Parse error/); }); it('should include error data in parse error response for unexpected errors', async () => { @@ -727,7 +722,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(response.status).toBe(400); const errorData = (await response.json()) as JSONRPCErrorResponse; - expectErrorResponse(errorData, -32700, /Parse error/); + expectErrorResponse(errorData, -32_700, /Parse error/); // The error message should contain details about what went wrong expect(errorData.error.message).toContain('Invalid JSON'); }); @@ -765,7 +760,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(response.status).toBe(400); const errorData = await response.json(); - expectErrorResponse(errorData, -32000, /Server not initialized/); + expectErrorResponse(errorData, -32_000, /Server not initialized/); // Cleanup await stopTestServer({ server: uninitializedServer, transport: uninitializedTransport }); @@ -884,7 +879,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(response.status).toBe(404); const errorData = await response.json(); - expectErrorResponse(errorData, -32001, /Session not found/); + expectErrorResponse(errorData, -32_001, /Session not found/); }); describe('protocol version header validation', () => { @@ -932,7 +927,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(response.status).toBe(400); const errorData = await response.json(); - expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version: .+ \(supported versions: .+\)/); + expectErrorResponse(errorData, -32_000, /Bad Request: Unsupported protocol version: .+ \(supported versions: .+\)/); }); it('should accept when protocol version differs from negotiated version', async () => { @@ -969,7 +964,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(response.status).toBe(400); const errorData = await response.json(); - expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version/); + expectErrorResponse(errorData, -32_000, /Bad Request: Unsupported protocol version/); }); it('should reject unsupported protocol version on DELETE requests', async () => { @@ -986,7 +981,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(response.status).toBe(400); const errorData = await response.json(); - expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version/); + expectErrorResponse(errorData, -32_000, /Bad Request: Unsupported protocol version/); }); }); }); @@ -1038,7 +1033,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const dataLine = eventLines.find(line => line.startsWith('data:')); expect(dataLine).toBeDefined(); - const eventData = JSON.parse(dataLine!.substring(5)); + const eventData = JSON.parse(dataLine!.slice(5)); expect(eventData).toMatchObject({ jsonrpc: '2.0', result: { @@ -1074,7 +1069,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const dataLine = eventLines.find(line => line.startsWith('data:')); expect(dataLine).toBeDefined(); - const eventData = JSON.parse(dataLine!.substring(5)); + const eventData = JSON.parse(dataLine!.slice(5)); expect(eventData).toMatchObject({ jsonrpc: '2.0', result: { @@ -1189,11 +1184,11 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const result = await createTestServer({ customRequestHandler: async (req, res) => { try { - if (parsedBody !== null) { + if (parsedBody === null) { + await transport.handleRequest(req, res); + } else { await transport.handleRequest(req, res, parsedBody); parsedBody = null; // Reset after use - } else { - await transport.handleRequest(req, res); } } catch (error) { console.error('Error handling request:', error); @@ -2347,7 +2342,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Verify we received the notification that was sent while disconnected expect(allText).toContain('Missed while disconnected'); - }, 15000); + }, 15_000); }); // Test onsessionclosed callback @@ -2521,8 +2516,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { initializationOrder.push('async-start'); // Simulate async operation await new Promise(resolve => setTimeout(resolve, 10)); - initializationOrder.push('async-end'); - initializationOrder.push(sessionId); + initializationOrder.push('async-end', sessionId); } }); @@ -2576,8 +2570,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { closureOrder.push('async-close-start'); // Simulate async operation await new Promise(resolve => setTimeout(resolve, 10)); - closureOrder.push('async-close-end'); - closureOrder.push(sessionId); + closureOrder.push('async-close-end', sessionId); } }); diff --git a/packages/server/src/server/context.ts b/packages/server/src/server/context.ts index 7f9eb6b60..81e693f68 100644 --- a/packages/server/src/server/context.ts +++ b/packages/server/src/server/context.ts @@ -1,4 +1,5 @@ import type { + BaseRequestContext, ContextInterface, CreateMessageRequest, CreateMessageResult, @@ -8,19 +9,50 @@ import type { LoggingMessageNotification, McpContext, Notification, - ProtocolInterface, Request, RequestOptions, Result, ServerNotification, ServerRequest, - ServerRequestContext, + ServerResult, TaskContext } from '@modelcontextprotocol/core'; import { BaseContext, ElicitResultSchema } from '@modelcontextprotocol/core'; import type { Server } from './server.js'; +/** + * Server-specific request context with HTTP request details. + * Extends BaseRequestContext with fields only available on the server side. + */ +export type ServerRequestContext = BaseRequestContext & { + /** + * The URI of the incoming HTTP request. + */ + uri: URL; + /** + * The headers of the incoming HTTP request. + */ + headers: Headers; + /** + * Stream control methods for SSE connections. + */ + stream: { + /** + * Closes the SSE stream for this request, triggering client reconnection. + * Only available when using StreamableHTTPServerTransport with eventStore configured. + * Use this to implement polling behavior during long-running operations. + */ + closeSSEStream: (() => void) | undefined; + /** + * Closes the standalone GET SSE stream, triggering client reconnection. + * Only available when using StreamableHTTPServerTransport with eventStore configured. + * Use this to implement polling behavior for server-initiated notifications. + */ + closeStandaloneSSEStream: (() => void) | undefined; + }; +}; + /** * Interface for sending logging messages to the client via {@link LoggingMessageNotification}. */ @@ -156,7 +188,7 @@ export class ServerContext< NotificationT extends Notification = Notification, ResultT extends Result = Result > - extends BaseContext + extends BaseContext implements ServerContextInterface { private readonly server: Server; @@ -186,7 +218,7 @@ export class ServerContext< /** * Returns the server instance for sending notifications and requests. */ - protected getProtocol(): ProtocolInterface { + protected getProtocol(): Server { return this.server; } diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index a4d21e157..58f4a1fee 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -33,7 +33,6 @@ import type { ServerCapabilities, ServerNotification, ServerRequest, - ServerRequestContext, ServerResult, TaskContext, TaskCreationParams, @@ -72,6 +71,7 @@ import { } from '@modelcontextprotocol/core'; import { ExperimentalServerTasks } from '../experimental/tasks/server.js'; +import type { ServerRequestContext } from './context.js'; import { ServerContext } from './context.js'; export type ServerOptions = ProtocolOptions & { diff --git a/src/conformance/everything-server.ts b/src/conformance/everything-server.ts index 637c1ed3f..685bfb382 100644 --- a/src/conformance/everything-server.ts +++ b/src/conformance/everything-server.ts @@ -1028,4 +1028,4 @@ const PORT = process.env.PORT || 3000; app.listen(PORT, () => { console.log(`MCP Conformance Test Server running on http://localhost:${PORT}`); console.log(` - MCP endpoint: http://localhost:${PORT}/mcp`); -}); \ No newline at end of file +}); From 86549fac656a8fc2409336506e6c01ac8d497bf1 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Thu, 22 Jan 2026 18:51:27 +0200 Subject: [PATCH 12/17] rename extra vars to ctx --- examples/server/src/elicitationUrlExample.ts | 10 +-- .../server/src/jsonResponseStreamableHttp.ts | 8 +- .../src/simpleStatelessStreamableHttp.ts | 4 +- examples/server/src/simpleStreamableHttp.ts | 18 ++--- examples/server/src/simpleTaskInteractive.ts | 10 +-- examples/server/src/ssePollingExample.ts | 2 +- packages/client/src/client/client.ts | 8 +- packages/core/src/shared/protocol.ts | 34 ++++----- packages/core/test/shared/protocol.test.ts | 18 ++--- .../shared/protocolTransportHandling.test.ts | 4 +- packages/server/src/server/mcp.ts | 64 ++++++++-------- packages/server/src/server/server.ts | 24 +++--- src/conformance/everything-server.ts | 48 ++++++------ test/integration/test/server/context.test.ts | 74 +++++++++---------- test/integration/test/server/mcp.test.ts | 28 +++---- 15 files changed, 177 insertions(+), 177 deletions(-) diff --git a/examples/server/src/elicitationUrlExample.ts b/examples/server/src/elicitationUrlExample.ts index 63da6bde2..3348bb9df 100644 --- a/examples/server/src/elicitationUrlExample.ts +++ b/examples/server/src/elicitationUrlExample.ts @@ -46,12 +46,12 @@ const getServer = () => { cartId: z.string().describe('The ID of the cart to confirm') } }, - async ({ cartId }, extra): Promise => { + async ({ cartId }, ctx): Promise => { /* In a real world scenario, there would be some logic here to check if the user has the provided cartId. For the purposes of this example, we'll throw an error (-> elicits the client to open a URL to confirm payment) */ - const sessionId = extra.mcpCtx.sessionId; + const sessionId = ctx.mcpCtx.sessionId; if (!sessionId) { throw new Error('Expected a Session ID'); } @@ -79,15 +79,15 @@ const getServer = () => { param1: z.string().describe('First parameter') } }, - async (_, extra): Promise => { + async (_, ctx): Promise => { /* In a real world scenario, there would be some logic here to check if we already have a valid access token for the user. - Auth info (with a subject or `sub` claim) can be typically be found in `extra.authInfo`. + Auth info (with a subject or `sub` claim) can be typically be found in `ctx.requestCtx.authInfo`. If we do, we can just return the result of the tool call. If we don't, we can throw an ElicitationRequiredError to request the user to authenticate. For the purposes of this example, we'll throw an error (-> elicits the client to open a URL to authenticate). */ - const sessionId = extra.mcpCtx.sessionId; + const sessionId = ctx.mcpCtx.sessionId; if (!sessionId) { throw new Error('Expected a Session ID'); } diff --git a/examples/server/src/jsonResponseStreamableHttp.ts b/examples/server/src/jsonResponseStreamableHttp.ts index 2dab4d0c0..481b47716 100644 --- a/examples/server/src/jsonResponseStreamableHttp.ts +++ b/examples/server/src/jsonResponseStreamableHttp.ts @@ -51,7 +51,7 @@ const getServer = () => { name: z.string().describe('Name to greet') } }, - async ({ name }, extra): Promise => { + async ({ name }, ctx): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); await server.sendLoggingMessage( @@ -59,7 +59,7 @@ const getServer = () => { level: 'debug', data: `Starting multi-greet for ${name}` }, - extra.mcpCtx.sessionId + ctx.mcpCtx.sessionId ); await sleep(1000); // Wait 1 second before first greeting @@ -69,7 +69,7 @@ const getServer = () => { level: 'info', data: `Sending first greeting to ${name}` }, - extra.mcpCtx.sessionId + ctx.mcpCtx.sessionId ); await sleep(1000); // Wait another second before second greeting @@ -79,7 +79,7 @@ const getServer = () => { level: 'info', data: `Sending second greeting to ${name}` }, - extra.mcpCtx.sessionId + ctx.mcpCtx.sessionId ); return { diff --git a/examples/server/src/simpleStatelessStreamableHttp.ts b/examples/server/src/simpleStatelessStreamableHttp.ts index 6533b32e7..0c7a791b9 100644 --- a/examples/server/src/simpleStatelessStreamableHttp.ts +++ b/examples/server/src/simpleStatelessStreamableHttp.ts @@ -49,7 +49,7 @@ const getServer = () => { count: z.number().describe('Number of notifications to send (0 for 100)').default(10) } }, - async ({ interval, count }, extra): Promise => { + async ({ interval, count }, ctx): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); let counter = 0; @@ -61,7 +61,7 @@ const getServer = () => { level: 'info', data: `Periodic notification #${counter} at ${new Date().toISOString()}` }, - extra.mcpCtx.sessionId + ctx.mcpCtx.sessionId ); } catch (error) { console.error('Error sending notification:', error); diff --git a/examples/server/src/simpleStreamableHttp.ts b/examples/server/src/simpleStreamableHttp.ts index 8a881f8f9..808f252cc 100644 --- a/examples/server/src/simpleStreamableHttp.ts +++ b/examples/server/src/simpleStreamableHttp.ts @@ -86,7 +86,7 @@ const getServer = () => { openWorldHint: false } }, - async ({ name }, extra): Promise => { + async ({ name }, ctx): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); await server.sendLoggingMessage( @@ -94,7 +94,7 @@ const getServer = () => { level: 'debug', data: `Starting multi-greet for ${name}` }, - extra.mcpCtx.sessionId + ctx.mcpCtx.sessionId ); await sleep(1000); // Wait 1 second before first greeting @@ -104,7 +104,7 @@ const getServer = () => { level: 'info', data: `Sending first greeting to ${name}` }, - extra.mcpCtx.sessionId + ctx.mcpCtx.sessionId ); await sleep(1000); // Wait another second before second greeting @@ -114,7 +114,7 @@ const getServer = () => { level: 'info', data: `Sending second greeting to ${name}` }, - extra.mcpCtx.sessionId + ctx.mcpCtx.sessionId ); return { @@ -137,7 +137,7 @@ const getServer = () => { infoType: z.enum(['contact', 'preferences', 'feedback']).describe('Type of information to collect') } }, - async ({ infoType }, extra): Promise => { + async ({ infoType }, ctx): Promise => { let message: string; let requestedSchema: { type: 'object'; @@ -236,8 +236,8 @@ const getServer = () => { } try { - // Use sendRequest through the extra parameter to elicit input - const result = await extra.sendRequest( + // Use sendRequest through the ctx parameter to elicit input + const result = await ctx.sendRequest( { method: 'elicitation/create', params: { @@ -325,7 +325,7 @@ const getServer = () => { count: z.number().describe('Number of notifications to send (0 for 100)').default(50) } }, - async ({ interval, count }, extra): Promise => { + async ({ interval, count }, ctx): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); let counter = 0; @@ -337,7 +337,7 @@ const getServer = () => { level: 'info', data: `Periodic notification #${counter} at ${new Date().toISOString()}` }, - extra.mcpCtx.sessionId + ctx.mcpCtx.sessionId ); } catch (error) { console.error('Error sending notification:', error); diff --git a/examples/server/src/simpleTaskInteractive.ts b/examples/server/src/simpleTaskInteractive.ts index 458c2ee41..d02f9bf0b 100644 --- a/examples/server/src/simpleTaskInteractive.ts +++ b/examples/server/src/simpleTaskInteractive.ts @@ -516,7 +516,7 @@ const createServer = (): Server => { }); // Handle tool calls - server.setRequestHandler(CallToolRequestSchema, async (request, extra): Promise => { + server.setRequestHandler(CallToolRequestSchema, async (request, ctx): Promise => { const { name, arguments: args } = request.params; const taskParams = (request.params._meta?.task || request.params.task) as { ttl?: number; pollInterval?: number } | undefined; @@ -531,7 +531,7 @@ const createServer = (): Server => { pollInterval: taskParams.pollInterval ?? 1000 }; - const task = await taskStore.createTask(taskOptions, extra.mcpCtx.requestId, request, extra.mcpCtx.sessionId); + const task = await taskStore.createTask(taskOptions, ctx.mcpCtx.requestId, request, ctx.mcpCtx.sessionId); console.log(`\n[Server] ${name} called, task created: ${task.taskId}`); @@ -609,7 +609,7 @@ const createServer = (): Server => { activeTaskExecutions.set(task.taskId, { promise: taskExecution, server, - sessionId: extra.mcpCtx.sessionId ?? '' + sessionId: ctx.mcpCtx.sessionId ?? '' }); return { task }; @@ -626,10 +626,10 @@ const createServer = (): Server => { }); // Handle tasks/result - server.setRequestHandler(GetTaskPayloadRequestSchema, async (request, extra): Promise => { + server.setRequestHandler(GetTaskPayloadRequestSchema, async (request, ctx): Promise => { const { taskId } = request.params; console.log(`[Server] tasks/result called for task ${taskId}`); - return taskResultHandler.handle(taskId, server, extra.mcpCtx.sessionId ?? ''); + return taskResultHandler.handle(taskId, server, ctx.mcpCtx.sessionId ?? ''); }); return server; diff --git a/examples/server/src/ssePollingExample.ts b/examples/server/src/ssePollingExample.ts index 10591de9b..8a539de42 100644 --- a/examples/server/src/ssePollingExample.ts +++ b/examples/server/src/ssePollingExample.ts @@ -7,7 +7,7 @@ * Key features: * - Configures `retryInterval` to tell clients how long to wait before reconnecting * - Uses `eventStore` to persist events for replay after reconnection - * - Uses `extra.closeSSEStream()` callback to gracefully disconnect clients mid-operation + * - Uses `ctx.requestCtx.stream.closeSSEStream()` callback to gracefully disconnect clients mid-operation * * Run with: pnpm tsx src/ssePollingExample.ts * Test with: curl or the MCP Inspector diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index 4c54c7f7e..5eaf2371a 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -373,7 +373,7 @@ export class Client< if (method === 'elicitation/create') { const wrappedHandler = async ( request: SchemaOutput, - extra: ContextInterface + ctx: ContextInterface ): Promise => { const validatedRequest = safeParse(ElicitRequestSchema, request); if (!validatedRequest.success) { @@ -395,7 +395,7 @@ export class Client< throw new McpError(ErrorCode.InvalidParams, 'Client does not support URL-mode elicitation requests'); } - const result = await Promise.resolve(handler(request, extra)); + const result = await Promise.resolve(handler(request, ctx)); // When task creation is requested, validate and return CreateTaskResult if (params.task) { @@ -446,7 +446,7 @@ export class Client< if (method === 'sampling/createMessage') { const wrappedHandler = async ( request: SchemaOutput, - extra: ContextInterface + ctx: ContextInterface ): Promise => { const validatedRequest = safeParse(CreateMessageRequestSchema, request); if (!validatedRequest.success) { @@ -457,7 +457,7 @@ export class Client< const { params } = validatedRequest.data; - const result = await Promise.resolve(handler(request, extra)); + const result = await Promise.resolve(handler(request, ctx)); // When task creation is requested, validate and return CreateTaskResult if (params.task) { diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index 4aeddeed9..74764077b 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -257,8 +257,8 @@ export abstract class Protocol { - const task = await this._taskStore!.getTask(request.params.taskId, extra.mcpCtx.sessionId); + this.setRequestHandler(GetTaskRequestSchema, async (request, ctx) => { + const task = await this._taskStore!.getTask(request.params.taskId, ctx.mcpCtx.sessionId); if (!task) { throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); } @@ -271,14 +271,14 @@ export abstract class Protocol { + this.setRequestHandler(GetTaskPayloadRequestSchema, async (request, ctx) => { const handleTaskResult = async (): Promise => { const taskId = request.params.taskId; // Deliver queued messages if (this._taskMessageQueue) { let queuedMessage: QueuedMessage | undefined; - while ((queuedMessage = await this._taskMessageQueue.dequeue(taskId, extra.mcpCtx.sessionId))) { + while ((queuedMessage = await this._taskMessageQueue.dequeue(taskId, ctx.mcpCtx.sessionId))) { // Handle response and error messages by routing them to the appropriate resolver if (queuedMessage.type === 'response' || queuedMessage.type === 'error') { const message = queuedMessage.message; @@ -316,12 +316,12 @@ export abstract class Protocol { + this.setRequestHandler(ListTasksRequestSchema, async (request, ctx) => { try { - const { tasks, nextCursor } = await this._taskStore!.listTasks(request.params?.cursor, extra.mcpCtx.sessionId); + const { tasks, nextCursor } = await this._taskStore!.listTasks(request.params?.cursor, ctx.mcpCtx.sessionId); // @ts-expect-error SendResultT cannot contain ListTasksResult, but we include it in our derived types everywhere else return { tasks, @@ -375,10 +375,10 @@ export abstract class Protocol { + this.setRequestHandler(CancelTaskRequestSchema, async (request, ctx) => { try { // Get the current task to check if it's in a terminal state, in case the implementation is not atomic - const task = await this._taskStore!.getTask(request.params.taskId, extra.mcpCtx.sessionId); + const task = await this._taskStore!.getTask(request.params.taskId, ctx.mcpCtx.sessionId); if (!task) { throw new McpError(ErrorCode.InvalidParams, `Task not found: ${request.params.taskId}`); @@ -393,12 +393,12 @@ export abstract class Protocol, - extra: ContextInterface + ctx: ContextInterface ) => SendResultT | Promise ): void { const method = getMethodLiteral(requestSchema); this.assertRequestHandlerCapability(method); - this._requestHandlers.set(method, (request, extra) => { + this._requestHandlers.set(method, (request, ctx) => { const parsed = parseWithCompat(requestSchema, request) as SchemaOutput; - return Promise.resolve(handler(parsed, extra)); + return Promise.resolve(handler(parsed, ctx)); }); } diff --git a/packages/core/test/shared/protocol.test.ts b/packages/core/test/shared/protocol.test.ts index fe487ba38..460c27852 100644 --- a/packages/core/test/shared/protocol.test.ts +++ b/packages/core/test/shared/protocol.test.ts @@ -1963,9 +1963,9 @@ describe('Task-based execution', () => { await serverProtocol.connect(serverTransport); // Set up a handler that uses sendRequest and sendNotification - serverProtocol.setRequestHandler(CallToolRequestSchema, async (_request, extra) => { - // Send a notification using the extra.sendNotification - await extra.sendNotification({ + serverProtocol.setRequestHandler(CallToolRequestSchema, async (_request, ctx) => { + // Send a notification using the ctx.sendNotification + await ctx.sendNotification({ method: 'notifications/message', params: { level: 'info', data: 'test' } }); @@ -2045,10 +2045,10 @@ describe('Request Cancellation vs Task Cancellation', () => { method: z.literal('test/longRunning'), params: z.optional(z.record(z.string(), z.unknown())) }); - protocol.setRequestHandler(TestRequestSchema, async (_request, extra) => { + protocol.setRequestHandler(TestRequestSchema, async (_request, ctx) => { // Simulate a long-running operation await new Promise(resolve => setTimeout(resolve, 100)); - wasAborted = extra.requestCtx.signal.aborted; + wasAborted = ctx.requestCtx.signal.aborted; return { _meta: {} } as Result; }); @@ -2418,13 +2418,13 @@ describe('Progress notification support for tasks', () => { await protocol.connect(transport); // Set up a request handler that will complete the task - protocol.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - if (extra.taskCtx?.store) { - const task = await extra.taskCtx.store.createTask({ ttl: 60_000 }); + protocol.setRequestHandler(CallToolRequestSchema, async (request, ctx) => { + if (ctx.taskCtx?.store) { + const task = await ctx.taskCtx.store.createTask({ ttl: 60_000 }); // Simulate async work then complete the task setTimeout(async () => { - await extra.taskCtx!.store.storeTaskResult(task.taskId, 'completed', { + await ctx.taskCtx!.store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: 'Done' }] }); }, 50); diff --git a/packages/core/test/shared/protocolTransportHandling.test.ts b/packages/core/test/shared/protocolTransportHandling.test.ts index ba33e69cf..69b092a59 100644 --- a/packages/core/test/shared/protocolTransportHandling.test.ts +++ b/packages/core/test/shared/protocolTransportHandling.test.ts @@ -195,14 +195,14 @@ describe('Protocol transport handling bug', () => { }); // Set up handler with variable delay - protocol.setRequestHandler(DelayedRequestSchema, async (request, extra) => { + protocol.setRequestHandler(DelayedRequestSchema, async (request, ctx) => { const delay = request.params?.delay || 0; delays.push(delay); await new Promise(resolve => setTimeout(resolve, delay)); return { - processedBy: `handler-${extra.mcpCtx.requestId}`, + processedBy: `handler-${ctx.mcpCtx.requestId}`, delay: delay } as Result; }); diff --git a/packages/server/src/server/mcp.ts b/packages/server/src/server/mcp.ts index ef724dd10..fb3cce577 100644 --- a/packages/server/src/server/mcp.ts +++ b/packages/server/src/server/mcp.ts @@ -176,7 +176,7 @@ export class McpServer { }) ); - this.server.setRequestHandler(CallToolRequestSchema, async (request, extra): Promise => { + this.server.setRequestHandler(CallToolRequestSchema, async (request, ctx): Promise => { try { const tool = this._registeredTools[request.params.name]; if (!tool) { @@ -208,12 +208,12 @@ export class McpServer { // Handle taskSupport 'optional' without task augmentation - automatic polling if (taskSupport === 'optional' && !isTaskRequest && isTaskHandler) { - return await this.handleAutomaticTaskPolling(tool, request, extra); + return await this.handleAutomaticTaskPolling(tool, request, ctx); } // Normal execution path const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); - const result = await this.executeToolHandler(tool, args, extra); + const result = await this.executeToolHandler(tool, args, ctx); // Return CreateTaskResult immediately for task requests if (isTaskRequest) { @@ -324,36 +324,36 @@ export class McpServer { private async executeToolHandler( tool: RegisteredTool, args: unknown, - extra: ContextInterface + ctx: ContextInterface ): Promise { const handler = tool.handler as AnyToolHandler; const isTaskHandler = 'createTask' in handler; if (isTaskHandler) { - if (!extra.taskCtx?.store) { + if (!ctx.taskCtx?.store) { throw new Error('No task store provided.'); } - const taskExtra = extra; + const taskCtx = ctx; if (tool.inputSchema) { const typedHandler = handler as ToolTaskHandler; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve(typedHandler.createTask(args as any, extra)); + return await Promise.resolve(typedHandler.createTask(args as any, ctx)); } else { const typedHandler = handler as ToolTaskHandler; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((typedHandler.createTask as any)(taskExtra)); + return await Promise.resolve((typedHandler.createTask as any)(taskCtx)); } } if (tool.inputSchema) { const typedHandler = handler as ToolCallback; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve(typedHandler(args as any, extra)); + return await Promise.resolve(typedHandler(args as any, ctx)); } else { const typedHandler = handler as ToolCallback; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((typedHandler as any)(extra)); + return await Promise.resolve((typedHandler as any)(ctx)); } } @@ -363,9 +363,9 @@ export class McpServer { private async handleAutomaticTaskPolling( tool: RegisteredTool, request: RequestT, - extra: ContextInterface + ctx: ContextInterface ): Promise { - if (!extra.taskCtx?.store) { + if (!ctx.taskCtx?.store) { throw new Error('No task store provided for task-capable tool.'); } @@ -374,9 +374,9 @@ export class McpServer { const handler = tool.handler as ToolTaskHandler; const createTaskResult: CreateTaskResult = args // undefined only if tool.inputSchema is undefined - ? await Promise.resolve((handler as ToolTaskHandler).createTask(args, extra)) + ? await Promise.resolve((handler as ToolTaskHandler).createTask(args, ctx)) : // eslint-disable-next-line @typescript-eslint/no-explicit-any - await Promise.resolve(((handler as ToolTaskHandler).createTask as any)(extra)); + await Promise.resolve(((handler as ToolTaskHandler).createTask as any)(ctx)); // Poll until completion const taskId = createTaskResult.task.taskId; @@ -385,12 +385,12 @@ export class McpServer { while (task.status !== 'completed' && task.status !== 'failed' && task.status !== 'cancelled') { await new Promise(resolve => setTimeout(resolve, pollInterval)); - const updatedTask = await extra.taskCtx!.store.getTask(taskId); + const updatedTask = await ctx.taskCtx!.store.getTask(taskId); task = updatedTask; } // Return the final result - return (await extra.taskCtx!.store.getTaskResult(taskId)) as CallToolResult; + return (await ctx.taskCtx!.store.getTaskResult(taskId)) as CallToolResult; } private _completionHandlerInitialized = false; @@ -496,7 +496,7 @@ export class McpServer { } }); - this.server.setRequestHandler(ListResourcesRequestSchema, async (request, extra) => { + this.server.setRequestHandler(ListResourcesRequestSchema, async (request, ctx) => { const resources = Object.entries(this._registeredResources) .filter(([_, resource]) => resource.enabled) .map(([uri, resource]) => ({ @@ -511,7 +511,7 @@ export class McpServer { continue; } - const result = await template.resourceTemplate.listCallback(extra); + const result = await template.resourceTemplate.listCallback(ctx); for (const resource of result.resources) { templateResources.push({ ...template.metadata, @@ -534,7 +534,7 @@ export class McpServer { return { resourceTemplates }; }); - this.server.setRequestHandler(ReadResourceRequestSchema, async (request, extra) => { + this.server.setRequestHandler(ReadResourceRequestSchema, async (request, ctx) => { const uri = new URL(request.params.uri); // First check for exact resource match @@ -543,14 +543,14 @@ export class McpServer { if (!resource.enabled) { throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} disabled`); } - return resource.readCallback(uri, extra); + return resource.readCallback(uri, ctx); } // Then check templates for (const template of Object.values(this._registeredResourceTemplates)) { const variables = template.resourceTemplate.uriTemplate.match(uri.toString()); if (variables) { - return template.readCallback(uri, variables, extra); + return template.readCallback(uri, variables, ctx); } } @@ -592,7 +592,7 @@ export class McpServer { }) ); - this.server.setRequestHandler(GetPromptRequestSchema, async (request, extra): Promise => { + this.server.setRequestHandler(GetPromptRequestSchema, async (request, ctx): Promise => { const prompt = this._registeredPrompts[request.params.name]; if (!prompt) { throw new McpError(ErrorCode.InvalidParams, `Prompt ${request.params.name} not found`); @@ -613,11 +613,11 @@ export class McpServer { const args = parseResult.data; const cb = prompt.callback as PromptCallback; - return await Promise.resolve(cb(args, extra)); + return await Promise.resolve(cb(args, ctx)); } else { const cb = prompt.callback as PromptCallback; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((cb as any)(extra)); + return await Promise.resolve((cb as any)(ctx)); } }); @@ -1272,10 +1272,10 @@ export type BaseToolCallback< Extra extends ContextInterface, Args extends undefined | ZodRawShapeCompat | AnySchema > = Args extends ZodRawShapeCompat - ? (args: ShapeOutput, extra: Extra) => SendResultT | Promise + ? (args: ShapeOutput, ctx: Extra) => SendResultT | Promise : Args extends AnySchema - ? (args: SchemaOutput, extra: Extra) => SendResultT | Promise - : (extra: Extra) => SendResultT | Promise; + ? (args: SchemaOutput, ctx: Extra) => SendResultT | Promise + : (ctx: Extra) => SendResultT | Promise; /** * Callback for a tool handler registered with Server.tool(). @@ -1408,7 +1408,7 @@ export type ResourceMetadata = Omit; * Callback to list all resources matching a given template. */ export type ListResourcesCallback = ( - extra: ContextInterface + ctx: ContextInterface ) => ListResourcesResult | Promise; /** @@ -1416,7 +1416,7 @@ export type ListResourcesCallback = ( */ export type ReadResourceCallback = ( uri: URL, - extra: ContextInterface + ctx: ContextInterface ) => ReadResourceResult | Promise; export type RegisteredResource = { @@ -1444,7 +1444,7 @@ export type RegisteredResource = { export type ReadResourceTemplateCallback = ( uri: URL, variables: Variables, - extra: ContextInterface + ctx: ContextInterface ) => ReadResourceResult | Promise; export type RegisteredResourceTemplate = { @@ -1469,8 +1469,8 @@ export type RegisteredResourceTemplate = { type PromptArgsRawShape = ZodRawShapeCompat; export type PromptCallback = Args extends PromptArgsRawShape - ? (args: ShapeOutput, extra: ContextInterface) => GetPromptResult | Promise - : (extra: ContextInterface) => GetPromptResult | Promise; + ? (args: ShapeOutput, ctx: ContextInterface) => GetPromptResult | Promise + : (ctx: ContextInterface) => GetPromptResult | Promise; export type RegisteredPrompt = { title?: string; diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index 58f4a1fee..36dc55835 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -176,10 +176,10 @@ export class Server< this.setNotificationHandler(InitializedNotificationSchema, () => this.oninitialized?.()); if (this._capabilities.logging) { - this.setRequestHandler(SetLevelRequestSchema, async (request, extra) => { - const ctx = extra as ServerContext; + this.setRequestHandler(SetLevelRequestSchema, async (request, ctx) => { + const serverCtx = ctx as ServerContext; const transportSessionId: string | undefined = - ctx.mcpCtx.sessionId || (ctx.requestCtx.headers.get('mcp-session-id') as string) || undefined; + serverCtx.mcpCtx.sessionId || (serverCtx.requestCtx.headers.get('mcp-session-id') as string) || undefined; const { level } = request.params; const parseResult = LoggingLevelSchema.safeParse(level); if (parseResult.success) { @@ -240,23 +240,23 @@ export class Server< extra: ServerContext ) => ServerResult | ResultT | Promise ): void { - // Wrap the handler to ensure the extra is a ServerContext and return a decorated handler that can be passed to the base implementation + // Wrap the handler to ensure the context is a ServerContext and return a decorated handler that can be passed to the base implementation - // Factory function to create a handler decorator that ensures the extra is a ServerContext and returns a decorated handler that can be passed to the base implementation + // Factory function to create a handler decorator that ensures the context is a ServerContext and returns a decorated handler that can be passed to the base implementation const handlerDecoratorFactory = ( innerHandler: ( request: SchemaOutput, - extra: ServerContext + ctx: ServerContext ) => ServerResult | ResultT | Promise ) => { const decoratedHandler = ( request: SchemaOutput, - extra: ContextInterface + ctx: ContextInterface ) => { - if (!this.isContextExtra(extra)) { - throw new Error('Internal error: Expected ServerContext for request handler extra'); + if (!this.isContextExtra(ctx)) { + throw new Error('Internal error: Expected ServerContext for request handler context'); } - return innerHandler(request, extra); + return innerHandler(request, ctx); }; return decoratedHandler; @@ -288,7 +288,7 @@ export class Server< if (method === 'tools/call') { const wrappedHandler = async ( request: SchemaOutput, - extra: ContextInterface + ctx: ContextInterface ): Promise => { const validatedRequest = safeParse(CallToolRequestSchema, request); if (!validatedRequest.success) { @@ -299,7 +299,7 @@ export class Server< const { params } = validatedRequest.data; - const result = await Promise.resolve(handlerDecoratorFactory(handler)(request, extra)); + const result = await Promise.resolve(handlerDecoratorFactory(handler)(request, ctx)); // When task creation is requested, validate and return CreateTaskResult if (params.task) { diff --git a/src/conformance/everything-server.ts b/src/conformance/everything-server.ts index 685bfb382..411aebd5e 100644 --- a/src/conformance/everything-server.ts +++ b/src/conformance/everything-server.ts @@ -211,8 +211,8 @@ function createMcpServer(sessionId?: string) { description: 'Tests tool that emits log messages during execution', inputSchema: {} }, - async (_args, extra): Promise => { - await extra.sendNotification({ + async (_args, ctx): Promise => { + await ctx.sendNotification({ method: 'notifications/message', params: { level: 'info', @@ -221,7 +221,7 @@ function createMcpServer(sessionId?: string) { }); await new Promise(resolve => setTimeout(resolve, 50)); - await extra.sendNotification({ + await ctx.sendNotification({ method: 'notifications/message', params: { level: 'info', @@ -230,7 +230,7 @@ function createMcpServer(sessionId?: string) { }); await new Promise(resolve => setTimeout(resolve, 50)); - await extra.sendNotification({ + await ctx.sendNotification({ method: 'notifications/message', params: { level: 'info', @@ -250,10 +250,10 @@ function createMcpServer(sessionId?: string) { description: 'Tests tool that reports progress notifications', inputSchema: {} }, - async (_args, extra): Promise => { - const progressToken = extra.mcpCtx._meta?.progressToken ?? 0; + async (_args, ctx): Promise => { + const progressToken = ctx.mcpCtx._meta?.progressToken ?? 0; console.log('Progress token:', progressToken); - await extra.sendNotification({ + await ctx.sendNotification({ method: 'notifications/progress', params: { progressToken, @@ -264,7 +264,7 @@ function createMcpServer(sessionId?: string) { }); await new Promise(resolve => setTimeout(resolve, 50)); - await extra.sendNotification({ + await ctx.sendNotification({ method: 'notifications/progress', params: { progressToken, @@ -275,7 +275,7 @@ function createMcpServer(sessionId?: string) { }); await new Promise(resolve => setTimeout(resolve, 50)); - await extra.sendNotification({ + await ctx.sendNotification({ method: 'notifications/progress', params: { progressToken, @@ -310,23 +310,23 @@ function createMcpServer(sessionId?: string) { 'Tests SSE stream disconnection and client reconnection (SEP-1699). Server will close the stream mid-call and send the result after client reconnects.', inputSchema: {} }, - async (_args, extra): Promise => { + async (_args, ctx): Promise => { const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); - console.log(`[${extra.mcpCtx.sessionId}] Starting test_reconnection tool...`); + console.log(`[${ctx.mcpCtx.sessionId}] Starting test_reconnection tool...`); // Get the transport for this session - const transport = extra.mcpCtx.sessionId ? transports[extra.mcpCtx.sessionId] : undefined; - if (transport && extra.mcpCtx.requestId) { + const transport = ctx.mcpCtx.sessionId ? transports[ctx.mcpCtx.sessionId] : undefined; + if (transport && ctx.mcpCtx.requestId) { // Close the SSE stream to trigger client reconnection - console.log(`[${extra.mcpCtx.sessionId}] Closing SSE stream to trigger client polling...`); - transport.closeSSEStream(extra.mcpCtx.requestId); + console.log(`[${ctx.mcpCtx.sessionId}] Closing SSE stream to trigger client polling...`); + transport.closeSSEStream(ctx.mcpCtx.requestId); } // Wait for client to reconnect (should respect retry field) await sleep(100); - console.log(`[${extra.mcpCtx.sessionId}] test_reconnection tool complete`); + console.log(`[${ctx.mcpCtx.sessionId}] test_reconnection tool complete`); return { content: [ @@ -348,10 +348,10 @@ function createMcpServer(sessionId?: string) { prompt: z.string().describe('The prompt to send to the LLM') } }, - async (args: { prompt: string }, extra): Promise => { + async (args: { prompt: string }, ctx): Promise => { try { // Request sampling from client - const result = (await extra.sendRequest( + const result = (await ctx.sendRequest( { method: 'sampling/createMessage', params: { @@ -402,10 +402,10 @@ function createMcpServer(sessionId?: string) { message: z.string().describe('The message to show the user') } }, - async (args: { message: string }, extra): Promise => { + async (args: { message: string }, ctx): Promise => { try { // Request user input from client - const result = await extra.sendRequest( + const result = await ctx.sendRequest( { method: 'elicitation/create', params: { @@ -454,10 +454,10 @@ function createMcpServer(sessionId?: string) { description: 'Tests elicitation with default values per SEP-1034', inputSchema: {} }, - async (_args, extra): Promise => { + async (_args, ctx): Promise => { try { // Request user input with default values for all primitive types - const result = await extra.sendRequest( + const result = await ctx.sendRequest( { method: 'elicitation/create', params: { @@ -528,10 +528,10 @@ function createMcpServer(sessionId?: string) { description: 'Tests elicitation with enum schema improvements per SEP-1330', inputSchema: {} }, - async (_args, extra): Promise => { + async (_args, ctx): Promise => { try { // Request user input with all 5 enum schema variants - const result = await extra.sendRequest( + const result = await ctx.sendRequest( { method: 'elicitation/create', params: { diff --git a/test/integration/test/server/context.test.ts b/test/integration/test/server/context.test.ts index bd030b3f7..5b3997676 100644 --- a/test/integration/test/server/context.test.ts +++ b/test/integration/test/server/context.test.ts @@ -13,7 +13,7 @@ import { z } from 'zod/v4'; describe('ServerContext', () => { /*** - * Test: `extra` provided to callbacks is ServerContext (parameterized) + * Test: `ctx` provided to callbacks is ServerContext (parameterized) */ type Seen = { isContext: boolean; hasRequestId: boolean }; const contextCases: Array<[string, (mcpServer: McpServer, seen: Seen) => void | Promise, (client: Client) => Promise]> = @@ -26,9 +26,9 @@ describe('ServerContext', () => { { inputSchema: z.object({ name: z.string() }) }, - (_args: { name: string }, extra) => { - seen.isContext = extra instanceof ServerContext; - seen.hasRequestId = !!extra.mcpCtx.requestId; + (_args: { name: string }, ctx) => { + seen.isContext = ctx instanceof ServerContext; + seen.hasRequestId = !!ctx.mcpCtx.requestId; return { content: [{ type: 'text', text: 'ok' }] }; } ); @@ -50,9 +50,9 @@ describe('ServerContext', () => { [ 'resource', (mcpServer, seen) => { - mcpServer.registerResource('ctx-resource', 'test://res/1', { title: 'ctx-resource' }, async (_uri, extra) => { - seen.isContext = extra instanceof ServerContext; - seen.hasRequestId = !!extra.mcpCtx.requestId; + mcpServer.registerResource('ctx-resource', 'test://res/1', { title: 'ctx-resource' }, async (_uri, ctx) => { + seen.isContext = ctx instanceof ServerContext; + seen.hasRequestId = !!ctx.mcpCtx.requestId; return { contents: [{ uri: 'test://res/1', mimeType: 'text/plain', text: 'hello' }] }; }); }, @@ -62,13 +62,13 @@ describe('ServerContext', () => { 'resource template list', (mcpServer, seen) => { const template = new ResourceTemplate('test://items/{id}', { - list: async extra => { - seen.isContext = extra instanceof ServerContext; - seen.hasRequestId = !!extra.mcpCtx.requestId; + list: async ctx => { + seen.isContext = ctx instanceof ServerContext; + seen.hasRequestId = !!ctx.mcpCtx.requestId; return { resources: [] }; } }); - mcpServer.registerResource('ctx-template', template, { title: 'ctx-template' }, async (_uri, _vars, _extra) => ({ + mcpServer.registerResource('ctx-template', template, { title: 'ctx-template' }, async (_uri, _vars, _ctx) => ({ contents: [] })); }, @@ -77,9 +77,9 @@ describe('ServerContext', () => { [ 'prompt', (mcpServer, seen) => { - mcpServer.registerPrompt('ctx-prompt', {}, async extra => { - seen.isContext = extra instanceof ServerContext; - seen.hasRequestId = !!extra.mcpCtx.requestId; + mcpServer.registerPrompt('ctx-prompt', {}, async ctx => { + seen.isContext = ctx instanceof ServerContext; + seen.hasRequestId = !!ctx.mcpCtx.requestId; return { messages: [] }; }); }, @@ -87,7 +87,7 @@ describe('ServerContext', () => { ] ]; - test.each(contextCases)('should pass ServerContext as extra to %s callbacks', async (_kind, register, trigger) => { + test.each(contextCases)('should pass ServerContext as ctx to %s callbacks', async (_kind, register, trigger) => { const mcpServer = new McpServer({ name: 'ctx-test', version: '1.0' }); const client = new Client({ name: 'ctx-client', version: '1.0' }); @@ -134,10 +134,10 @@ describe('ServerContext', () => { return; }); - mcpServer.registerTool('ctx-log-test', { inputSchema: z.object({ name: z.string() }) }, async (_args: { name: string }, extra) => { - const ctx = extra as ServerContext; - await ctx.loggingNotification[level]('Test message', { test: 'test' }, 'sample-session-id'); - await ctx.loggingNotification.log( + mcpServer.registerTool('ctx-log-test', { inputSchema: z.object({ name: z.string() }) }, async (_args: { name: string }, ctx) => { + const serverCtx = ctx as ServerContext; + await serverCtx.loggingNotification[level]('Test message', { test: 'test' }, 'sample-session-id'); + await serverCtx.loggingNotification.log( { level, data: 'Test message', @@ -184,10 +184,10 @@ describe('ServerContext', () => { { inputSchema: z.object({ name: z.string() }) }, - // The test is to ensure that the extra is compatible with the ContextInterface type - (_args: { name: string }, extra: ContextInterface) => { - seen.isContext = extra instanceof ServerContext; - seen.hasRequestId = !!extra.mcpCtx.requestId; + // The test is to ensure that the ctx is compatible with the ContextInterface type + (_args: { name: string }, ctx: ContextInterface) => { + seen.isContext = ctx instanceof ServerContext; + seen.hasRequestId = !!ctx.mcpCtx.requestId; return { content: [{ type: 'text', text: 'ok' }] }; } ); @@ -209,14 +209,14 @@ describe('ServerContext', () => { [ 'resource', (mcpServer, seen) => { - // The test is to ensure that the extra is compatible with the ContextInterface type + // The test is to ensure that the ctx is compatible with the ContextInterface type mcpServer.registerResource( 'ctx-resource', 'test://res/1', { title: 'ctx-resource' }, - async (_uri, extra: ContextInterface) => { - seen.isContext = extra instanceof ServerContext; - seen.hasRequestId = !!extra.mcpCtx.requestId; + async (_uri, ctx: ContextInterface) => { + seen.isContext = ctx instanceof ServerContext; + seen.hasRequestId = !!ctx.mcpCtx.requestId; return { contents: [{ uri: 'test://res/1', mimeType: 'text/plain', text: 'hello' }] }; } ); @@ -226,15 +226,15 @@ describe('ServerContext', () => { [ 'resource template list', (mcpServer, seen) => { - // The test is to ensure that the extra is compatible with the ContextInterface type + // The test is to ensure that the ctx is compatible with the ContextInterface type const template = new ResourceTemplate('test://items/{id}', { - list: async (extra: ContextInterface) => { - seen.isContext = extra instanceof ServerContext; - seen.hasRequestId = !!extra.mcpCtx.requestId; + list: async (ctx: ContextInterface) => { + seen.isContext = ctx instanceof ServerContext; + seen.hasRequestId = !!ctx.mcpCtx.requestId; return { resources: [] }; } }); - mcpServer.registerResource('ctx-template', template, { title: 'ctx-template' }, async (_uri, _vars, _extra) => ({ + mcpServer.registerResource('ctx-template', template, { title: 'ctx-template' }, async (_uri, _vars, _ctx) => ({ contents: [] })); }, @@ -243,13 +243,13 @@ describe('ServerContext', () => { [ 'prompt', (mcpServer, seen) => { - // The test is to ensure that the extra is compatible with the ContextInterface type + // The test is to ensure that the ctx is compatible with the ContextInterface type mcpServer.registerPrompt( 'ctx-prompt', {}, - async (extra: ContextInterface) => { - seen.isContext = extra instanceof ServerContext; - seen.hasRequestId = !!extra.mcpCtx.requestId; + async (ctx: ContextInterface) => { + seen.isContext = ctx instanceof ServerContext; + seen.hasRequestId = !!ctx.mcpCtx.requestId; return { messages: [] }; } ); @@ -258,7 +258,7 @@ describe('ServerContext', () => { ] ]; - test.each(contextCases)('should pass ServerContext as extra to %s callbacks', async (_kind, register, trigger) => { + test.each(contextCases)('should pass ServerContext as ctx to %s callbacks', async (_kind, register, trigger) => { const mcpServer = new McpServer({ name: 'ctx-test', version: '1.0' }); const client = new Client({ name: 'ctx-client', version: '1.0' }); diff --git a/test/integration/test/server/mcp.test.ts b/test/integration/test/server/mcp.test.ts index 0cf95d0c7..9023e3ded 100644 --- a/test/integration/test/server/mcp.test.ts +++ b/test/integration/test/server/mcp.test.ts @@ -1455,8 +1455,8 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); let receivedSessionId: string | undefined; - mcpServer.tool('test-tool', async extra => { - receivedSessionId = extra.mcpCtx.sessionId; + mcpServer.tool('test-tool', async ctx => { + receivedSessionId = ctx.mcpCtx.sessionId; return { content: [ { @@ -1501,13 +1501,13 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); let receivedRequestId: string | number | undefined; - mcpServer.tool('request-id-test', async extra => { - receivedRequestId = extra.mcpCtx.requestId; + mcpServer.tool('request-id-test', async ctx => { + receivedRequestId = ctx.mcpCtx.requestId; return { content: [ { type: 'text', - text: `Received request ID: ${extra.mcpCtx.requestId}` + text: `Received request ID: ${ctx.mcpCtx.requestId}` } ] }; @@ -2969,13 +2969,13 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); let receivedRequestId: string | number | undefined; - mcpServer.resource('request-id-test', 'test://resource', async (_uri, extra) => { - receivedRequestId = extra.mcpCtx.requestId; + mcpServer.resource('request-id-test', 'test://resource', async (_uri, ctx) => { + receivedRequestId = ctx.mcpCtx.requestId; return { contents: [ { uri: 'test://resource', - text: `Received request ID: ${extra.mcpCtx.requestId}` + text: `Received request ID: ${ctx.mcpCtx.requestId}` } ] }; @@ -3888,15 +3888,15 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); let receivedRequestId: string | number | undefined; - mcpServer.prompt('request-id-test', async extra => { - receivedRequestId = extra.mcpCtx.requestId; + mcpServer.prompt('request-id-test', async ctx => { + receivedRequestId = ctx.mcpCtx.requestId; return { messages: [ { role: 'assistant', content: { type: 'text', - text: `Received request ID: ${extra.mcpCtx.requestId}` + text: `Received request ID: ${ctx.mcpCtx.requestId}` } } ] @@ -4401,15 +4401,15 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }) } }, - async ({ department, name }, extra: ContextInterface) => { - expect(extra).toBeInstanceOf(ServerContext); + async (args, ctx: ContextInterface) => { + expect(ctx).toBeInstanceOf(ServerContext); return { messages: [ { role: 'assistant', content: { type: 'text', - text: `Hello ${name}, welcome to the ${department} team!` + text: `Hello ${args.name}, welcome to the ${args.department} team!` } } ] From 920da7ea5bc99a099ac7f52796ebf21f3c6827dd Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Thu, 22 Jan 2026 18:53:05 +0200 Subject: [PATCH 13/17] prettier fix --- pnpm-workspace.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index 4213eb9ff..04288d9f7 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -57,8 +57,8 @@ linkWorkspacePackages: deep minimumReleaseAge: 10080 # 7 days minimumReleaseAgeExclude: - '@modelcontextprotocol/conformance' - - hono@4.11.4 # fixes https://github.com/advisories/GHSA-3vhc-576x-3qv4 https://github.com/advisories/GHSA-f67f-6cw9-8mq4 - - '@hono/node-server@1.19.9' # https://github.com/honojs/node-server/pull/295 + - hono@4.11.4 # fixes https://github.com/advisories/GHSA-3vhc-576x-3qv4 https://github.com/advisories/GHSA-f67f-6cw9-8mq4 + - '@hono/node-server@1.19.9' # https://github.com/honojs/node-server/pull/295 onlyBuiltDependencies: - better-sqlite3 From 639d7bdfed7acbbcd6da1ae1837b226aa4e3b8a6 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Thu, 22 Jan 2026 20:46:00 +0200 Subject: [PATCH 14/17] add changeset --- .changeset/hot-trees-sing.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 .changeset/hot-trees-sing.md diff --git a/.changeset/hot-trees-sing.md b/.changeset/hot-trees-sing.md new file mode 100644 index 000000000..74df1dd00 --- /dev/null +++ b/.changeset/hot-trees-sing.md @@ -0,0 +1,12 @@ +--- +'@modelcontextprotocol/express': patch +'@modelcontextprotocol/hono': patch +'@modelcontextprotocol/node': patch +'@modelcontextprotocol/eslint-config': patch +'@modelcontextprotocol/test-integration': patch +'@modelcontextprotocol/client': patch +'@modelcontextprotocol/server': patch +'@modelcontextprotocol/core': patch +--- + +add context API to tool, prompt, resource callbacks, linting From f512df45b06e453ab0b969426a93f909609a10b5 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Mon, 26 Jan 2026 00:30:24 +0200 Subject: [PATCH 15/17] save commit --- .../client/src/multipleClientsParallel.ts | 3 +- .../client/src/parallelToolCallsClient.ts | 3 +- examples/client/src/simpleOAuthClient.ts | 5 +- examples/client/src/simpleStreamableHttp.ts | 7 +- .../client/src/simpleStreamableHttpBuilder.ts | 828 +++++++++ .../client/src/simpleTaskInteractiveClient.ts | 9 +- .../streamableHttpWithSseFallbackClient.ts | 3 +- examples/server/src/elicitationFormExample.ts | 64 +- examples/server/src/simpleStreamableHttp.ts | 15 +- .../server/src/simpleStreamableHttpBuilder.ts | 578 +++++++ examples/server/src/toolWithSampleServer.ts | 4 +- packages/client/src/client/builder.ts | 452 +++++ packages/client/src/client/client.ts | 303 +++- packages/client/src/client/context.ts | 2 +- packages/client/src/client/middleware.ts | 410 ++++- .../client/src/experimental/tasks/client.ts | 48 +- packages/client/src/index.ts | 1 + packages/core/src/errors.ts | 400 +++++ packages/core/src/index.ts | 10 + packages/core/src/shared/context.ts | 27 +- packages/core/src/shared/events.ts | 274 +++ packages/core/src/shared/handlerRegistry.ts | 184 ++ packages/core/src/shared/plugin.ts | 481 ++++++ packages/core/src/shared/pluginContext.ts | 192 +++ packages/core/src/shared/progressManager.ts | 126 ++ packages/core/src/shared/protocol.ts | 1501 ++++++++--------- packages/core/src/shared/taskClientPlugin.ts | 446 +++++ packages/core/src/shared/taskPlugin.ts | 489 ++++++ packages/core/src/shared/timeoutManager.ts | 172 ++ packages/core/src/shared/transport.ts | 48 + packages/core/src/util/content.ts | 211 +++ packages/middleware/express/package.json | 4 +- packages/middleware/hono/package.json | 4 +- packages/middleware/node/package.json | 4 +- .../src/experimental/tasks/interfaces.ts | 6 +- packages/server/src/index.ts | 4 + packages/server/src/server/builder.ts | 427 +++++ packages/server/src/server/context.ts | 2 +- packages/server/src/server/mcp.ts | 910 ++++++---- packages/server/src/server/middleware.ts | 453 +++++ .../src/server/registries/baseRegistry.ts | 229 +++ .../server/src/server/registries/index.ts | 10 + .../src/server/registries/promptRegistry.ts | 242 +++ .../src/server/registries/resourceRegistry.ts | 496 ++++++ .../src/server/registries/toolRegistry.ts | 297 ++++ packages/server/src/server/server.ts | 49 +- packages/server/src/server/sessions.ts | 344 ++++ src/conformance/everything-server.ts | 72 +- test/integration/test/server.test.ts | 3 +- test/integration/test/server/mcp.test.ts | 45 +- .../integration/test/taskResumability.test.ts | 6 +- 51 files changed, 9537 insertions(+), 1366 deletions(-) create mode 100644 examples/client/src/simpleStreamableHttpBuilder.ts create mode 100644 examples/server/src/simpleStreamableHttpBuilder.ts create mode 100644 packages/client/src/client/builder.ts create mode 100644 packages/core/src/errors.ts create mode 100644 packages/core/src/shared/events.ts create mode 100644 packages/core/src/shared/handlerRegistry.ts create mode 100644 packages/core/src/shared/plugin.ts create mode 100644 packages/core/src/shared/pluginContext.ts create mode 100644 packages/core/src/shared/progressManager.ts create mode 100644 packages/core/src/shared/taskClientPlugin.ts create mode 100644 packages/core/src/shared/taskPlugin.ts create mode 100644 packages/core/src/shared/timeoutManager.ts create mode 100644 packages/core/src/util/content.ts create mode 100644 packages/server/src/server/builder.ts create mode 100644 packages/server/src/server/middleware.ts create mode 100644 packages/server/src/server/registries/baseRegistry.ts create mode 100644 packages/server/src/server/registries/index.ts create mode 100644 packages/server/src/server/registries/promptRegistry.ts create mode 100644 packages/server/src/server/registries/resourceRegistry.ts create mode 100644 packages/server/src/server/registries/toolRegistry.ts create mode 100644 packages/server/src/server/sessions.ts diff --git a/examples/client/src/multipleClientsParallel.ts b/examples/client/src/multipleClientsParallel.ts index f537dff55..e24767408 100644 --- a/examples/client/src/multipleClientsParallel.ts +++ b/examples/client/src/multipleClientsParallel.ts @@ -2,6 +2,7 @@ import type { CallToolRequest, CallToolResult } from '@modelcontextprotocol/clie import { CallToolResultSchema, Client, + isTextContent, LoggingMessageNotificationSchema, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; @@ -133,7 +134,7 @@ async function main(): Promise { console.log(`\n[${id}] Tool result:`); if (Array.isArray(result.content)) { for (const item of result.content) { - if (item.type === 'text' && item.text) { + if (isTextContent(item)) { console.log(` ${item.text}`); } else { console.log(` ${item.type} content:`, item); diff --git a/examples/client/src/parallelToolCallsClient.ts b/examples/client/src/parallelToolCallsClient.ts index fbd3910de..5114c6a22 100644 --- a/examples/client/src/parallelToolCallsClient.ts +++ b/examples/client/src/parallelToolCallsClient.ts @@ -2,6 +2,7 @@ import type { CallToolResult, ListToolsRequest } from '@modelcontextprotocol/cli import { CallToolResultSchema, Client, + isTextContent, ListToolsResultSchema, LoggingMessageNotificationSchema, StreamableHTTPClientTransport @@ -60,7 +61,7 @@ async function main(): Promise { for (const [caller, result] of Object.entries(toolResults)) { console.log(`\n=== Tool result for ${caller} ===`); for (const item of result.content) { - if (item.type === 'text') { + if (isTextContent(item)) { console.log(` ${item.text}`); } else { console.log(` ${item.type} content:`, item); diff --git a/examples/client/src/simpleOAuthClient.ts b/examples/client/src/simpleOAuthClient.ts index 23ea05a99..84791acb9 100644 --- a/examples/client/src/simpleOAuthClient.ts +++ b/examples/client/src/simpleOAuthClient.ts @@ -9,6 +9,7 @@ import type { CallToolRequest, ListToolsRequest, OAuthClientMetadata } from '@mo import { CallToolResultSchema, Client, + isTextContent, ListToolsResultSchema, StreamableHTTPClientTransport, UnauthorizedError @@ -315,7 +316,7 @@ class InteractiveOAuthClient { console.log(`\n🔧 Tool '${toolName}' result:`); if (result.content) { for (const content of result.content) { - if (content.type === 'text') { + if (isTextContent(content)) { console.log(content.text); } else { console.log(content); @@ -396,7 +397,7 @@ class InteractiveOAuthClient { case 'result': { console.log('✓ Completed!'); for (const content of message.result.content) { - if (content.type === 'text') { + if (isTextContent(content)) { console.log(content.text); } else { console.log(content); diff --git a/examples/client/src/simpleStreamableHttp.ts b/examples/client/src/simpleStreamableHttp.ts index ced687027..b802fb64b 100644 --- a/examples/client/src/simpleStreamableHttp.ts +++ b/examples/client/src/simpleStreamableHttp.ts @@ -16,6 +16,7 @@ import { ErrorCode, getDisplayName, GetPromptResultSchema, + isTextContent, ListPromptsResultSchema, ListResourcesResultSchema, ListToolsResultSchema, @@ -737,7 +738,7 @@ async function runNotificationsToolWithResumability(interval: number, count: num console.log('Tool result:'); for (const item of result.content) { - if (item.type === 'text') { + if (isTextContent(item)) { console.log(` ${item.text}`); } else { console.log(` ${item.type} content:`, item); @@ -791,7 +792,7 @@ async function getPrompt(name: string, args: Record): Promise): Promis console.log('Task completed!'); console.log('Tool result:'); for (const item of message.result.content) { - if (item.type === 'text') { + if (isTextContent(item)) { console.log(` ${item.text}`); } } diff --git a/examples/client/src/simpleStreamableHttpBuilder.ts b/examples/client/src/simpleStreamableHttpBuilder.ts new file mode 100644 index 000000000..4246f7355 --- /dev/null +++ b/examples/client/src/simpleStreamableHttpBuilder.ts @@ -0,0 +1,828 @@ +/* eslint-disable unicorn/no-process-exit */ +/** + * Simple Streamable HTTP Client Example using Builder Pattern + * + * This example demonstrates using the Client.builder() fluent API + * to create and configure an MCP client with: + * - Builder pattern configuration + * - Universal middleware (logging) + * - Outgoing middleware (retry logic) + * - Tool call middleware (instrumentation) + * - Sampling request handler + * - Elicitation request handler + * - Roots list handler + * - Error handlers (onError, onProtocolError) + * + * Run with: npx tsx src/simpleStreamableHttpBuilder.ts + */ + +import { createInterface } from 'node:readline'; + +import type { + CallToolRequest, + ClientMiddleware, + GetPromptRequest, + ListPromptsRequest, + ListResourcesRequest, + ListToolsRequest, + OutgoingMiddleware, + ReadResourceRequest, + ToolCallMiddleware +} from '@modelcontextprotocol/client'; +import { + CallToolResultSchema, + Client, + getDisplayName, + GetPromptResultSchema, + isTextContent, + ListPromptsResultSchema, + ListResourcesResultSchema, + ListToolsResultSchema, + LoggingMessageNotificationSchema +, + ReadResourceResultSchema, + StreamableHTTPClientTransport +} from '@modelcontextprotocol/client'; + +// Create readline interface for user input +const readline = createInterface({ + input: process.stdin, + output: process.stdout +}); + +// Track received notifications +let notificationCount = 0; + +// Global client and transport +let client: Client | null = null; +let transport: StreamableHTTPClientTransport | null = null; +let serverUrl = 'http://localhost:3000/mcp'; +let sessionId: string | undefined; + +// ═══════════════════════════════════════════════════════════════════════════ +// Custom Middleware Examples +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Options for MCP client logging middleware. + */ +export interface ClientLoggingMiddlewareOptions { + /** Log level */ + level?: 'debug' | 'info' | 'warn' | 'error'; + /** Custom logger function */ + logger?: (level: string, message: string, data?: unknown) => void; +} + +/** + * Creates a logging middleware for MCP client operations. + * + * @example + * ```typescript + * client.useMiddleware(createClientLoggingMiddleware({ level: 'debug' })); + * ``` + */ +export function createClientLoggingMiddleware(options: ClientLoggingMiddlewareOptions = {}): ClientMiddleware { + const { level = 'info', logger = console.log } = options; + + return async (ctx, next) => { + logger(level, `${ctx.direction} ${ctx.type}: ${ctx.method}`, { + direction: ctx.direction, + type: ctx.type, + method: ctx.method, + requestId: ctx.requestId + }); + + const start = Date.now(); + + try { + const result = await next(); + const duration = Date.now() - start; + logger(level, `← ${ctx.type}: ${ctx.method} (${duration}ms)`, { + direction: ctx.direction, + type: ctx.type, + method: ctx.method, + requestId: ctx.requestId, + duration + }); + return result; + } catch (error) { + const duration = Date.now() - start; + logger('error', `✗ ${ctx.type}: ${ctx.method} (${duration}ms)`, { + direction: ctx.direction, + type: ctx.type, + method: ctx.method, + requestId: ctx.requestId, + duration, + error + }); + throw error; + } + }; +} + + +/** + * Options for retry middleware. + */ +interface RetryMiddlewareOptions { + /** Maximum number of retries */ + maxRetries?: number; + /** Base delay between retries in ms */ + baseDelay?: number; + /** Function to determine if an error is retryable */ + isRetryable?: (error: unknown) => boolean; +} +/** + * Creates a retry middleware for outgoing MCP requests. + * + * @example + * ```typescript + * client.useOutgoingMiddleware(createRetryMiddleware({ + * maxRetries: 3, + * baseDelay: 100, + * })); + * ``` + */ +export function createRetryMiddleware(options: RetryMiddlewareOptions = {}): OutgoingMiddleware { + const { maxRetries = 3, baseDelay = 100, isRetryable = () => true } = options; + + return async (ctx, next) => { + let lastError: unknown; + + for (let attempt = 1; attempt <= maxRetries + 1; attempt++) { + try { + return await next(); + } catch (error) { + lastError = error; + + if (attempt > maxRetries || !isRetryable(error)) { + throw error; + } + + // Exponential backoff + const delay = baseDelay * Math.pow(2, attempt - 1); + await new Promise(resolve => setTimeout(resolve, delay)); + } + } + + throw lastError; + }; +} + + +/** + * Custom tool call instrumentation middleware. + * Logs tool calls with timing information. + */ +const toolCallInstrumentationMiddleware: ToolCallMiddleware = async (ctx, next) => { + console.log(`\n[TOOL CALL] Starting: ${ctx.params.name}`); + console.log(`[TOOL CALL] Arguments: ${JSON.stringify(ctx.params.arguments || {})}`); + + const start = performance.now(); + try { + const result = await next(); + const duration = (performance.now() - start).toFixed(2); + console.log(`[TOOL CALL] Completed: ${ctx.params.name} (${duration}ms)`); + return result; + } catch (error) { + const duration = (performance.now() - start).toFixed(2); + console.log(`[TOOL CALL] Failed: ${ctx.params.name} (${duration}ms) - ${error}`); + throw error; + } +}; + +/** + * Custom request timing middleware. + * Tracks timing for all outgoing requests. + */ +const requestTimingMiddleware: ClientMiddleware = async (ctx, next) => { + const start = performance.now(); + try { + const result = await next(); + const duration = (performance.now() - start).toFixed(2); + console.log(`[TIMING] ${ctx.direction} ${ctx.method} completed in ${duration}ms`); + return result; + } catch (error) { + const duration = (performance.now() - start).toFixed(2); + console.log(`[TIMING] ${ctx.direction} ${ctx.method} failed in ${duration}ms`); + throw error; + } +}; + +// ═══════════════════════════════════════════════════════════════════════════ +// Main Function +// ═══════════════════════════════════════════════════════════════════════════ + +async function main(): Promise { + console.log('═══════════════════════════════════════════════════════════════'); + console.log('MCP Interactive Client (Builder Pattern Example)'); + console.log('═══════════════════════════════════════════════════════════════'); + console.log(''); + console.log('Features demonstrated:'); + console.log(' - Builder pattern for client configuration'); + console.log(' - Universal middleware (logging, timing)'); + console.log(' - Outgoing middleware (retry logic)'); + console.log(' - Tool call middleware (instrumentation)'); + console.log(' - Sampling request handler'); + console.log(' - Elicitation request handler'); + console.log(' - Roots list handler'); + console.log(' - Error handlers (onError, onProtocolError)'); + console.log('═══════════════════════════════════════════════════════════════'); + + // Connect to server immediately + await connect(); + + // Print help and start the command loop + printHelp(); + commandLoop(); +} + +function printHelp(): void { + console.log('\nAvailable commands:'); + console.log(' connect [url] - Connect to MCP server (default: http://localhost:3000/mcp)'); + console.log(' disconnect - Disconnect from server'); + console.log(' reconnect - Reconnect to the server'); + console.log(' list-tools - List available tools'); + console.log(' call-tool [args] - Call a tool with optional JSON arguments'); + console.log(' greet [name] - Call the greet tool'); + console.log(' multi-greet [name] - Call the multi-greet tool with notifications'); + console.log(' context-demo [msg] - Call the context-demo tool'); + console.log(' admin-action - Call admin-action (no auth)'); + console.log(' admin-action-auth - Call admin-action with auth token'); + console.log(' error-test - Test error handling (application/validation)'); + console.log(' list-prompts - List available prompts'); + console.log(' get-prompt [args] - Get a prompt with optional JSON arguments'); + console.log(' list-resources - List available resources'); + console.log(' read-resource - Read a specific resource by URI'); + console.log(' session-info - Read session info resource'); + console.log(' help - Show this help'); + console.log(' quit - Exit the program'); +} + +function commandLoop(): void { + readline.question('\n> ', async (input) => { + const args = input.trim().split(/\s+/); + const command = args[0]?.toLowerCase(); + + try { + switch (command) { + case 'connect': { + await connect(args[1]); + break; + } + + case 'disconnect': { + await disconnect(); + break; + } + + case 'reconnect': { + await reconnect(); + break; + } + + case 'list-tools': { + await listTools(); + break; + } + + case 'call-tool': { + if (args.length < 2) { + console.log('Usage: call-tool [args]'); + } else { + const toolName = args[1]!; + let toolArgs = {}; + if (args.length > 2) { + try { + toolArgs = JSON.parse(args.slice(2).join(' ')); + } catch { + console.log('Invalid JSON arguments. Using empty args.'); + } + } + await callTool(toolName, toolArgs); + } + break; + } + + case 'greet': { + await callTool('greet', { name: args[1] || 'World' }); + break; + } + + case 'multi-greet': { + console.log('Calling multi-greet tool (watch for notifications)...'); + await callTool('multi-greet', { name: args[1] || 'World' }); + break; + } + + case 'context-demo': { + await callTool('context-demo', { message: args.slice(1).join(' ') || 'Hello from client!' }); + break; + } + + case 'admin-action': { + if (args.length < 2) { + console.log('Usage: admin-action '); + } else { + // Call without requiresAdmin flag - should work + await callTool('admin-action', { action: args[1] }); + } + break; + } + + case 'admin-action-auth': { + if (args.length < 2) { + console.log('Usage: admin-action-auth '); + } else { + // Call with requiresAdmin but provide token + await callTool('admin-action', { + action: args[1], + requiresAdmin: true, + adminToken: 'demo-token-123' + }); + } + break; + } + + case 'error-test': { + if (args.length < 2) { + console.log('Usage: error-test '); + } else { + await callTool('error-test', { errorType: args[1] }); + } + break; + } + + case 'list-prompts': { + await listPrompts(); + break; + } + + case 'get-prompt': { + if (args.length < 2) { + console.log('Usage: get-prompt [args]'); + } else { + const promptName = args[1]!; + let promptArgs = {}; + if (args.length > 2) { + try { + promptArgs = JSON.parse(args.slice(2).join(' ')); + } catch { + console.log('Invalid JSON arguments. Using empty args.'); + } + } + await getPrompt(promptName, promptArgs); + } + break; + } + + case 'list-resources': { + await listResources(); + break; + } + + case 'read-resource': { + if (args.length < 2) { + console.log('Usage: read-resource '); + } else { + await readResource(args[1]!); + } + break; + } + + case 'session-info': { + await readResource('https://example.com/session/info'); + break; + } + + case 'help': { + printHelp(); + break; + } + + case 'quit': + case 'exit': { + await cleanup(); + return; + } + + default: { + if (command) { + console.log(`Unknown command: ${command}`); + } + break; + } + } + } catch (error) { + console.error(`Error executing command: ${error}`); + } + + // Continue the command loop + commandLoop(); + }); +} + +/** + * Connect to the MCP server using the builder pattern. + * + * The builder provides a fluent API for configuring the client: + * - .name() and .version() set client info + * - .capabilities() configures client capabilities + * - .useMiddleware() adds universal middleware + * - .useOutgoingMiddleware() adds outgoing-only middleware + * - .useToolCallMiddleware() adds tool call specific middleware + * - .onSamplingRequest() handles sampling requests from server + * - .onElicitation() handles elicitation requests from server + * - .onRootsList() handles roots list requests from server + * - .onError() handles application errors + * - .onProtocolError() handles protocol errors + * - .build() creates the configured Client instance + */ +async function connect(url?: string): Promise { + if (client) { + console.log('Already connected. Disconnect first.'); + return; + } + + if (url) { + serverUrl = url; + } + + console.log(`\nConnecting to ${serverUrl}...`); + + try { + // Create a new client using the builder pattern + client = Client.builder() + .name('builder-example-client') + .version('1.0.0') + + // ─── Capabilities ─── + // Enable sampling, elicitation, and roots capabilities + .capabilities({ + sampling: {}, + elicitation: { form: {} }, + roots: { listChanged: true } + }) + + // ─── Universal Middleware ─── + // Logging middleware for all requests + .useMiddleware( + createClientLoggingMiddleware({ + level: 'debug', + logger: (level, message, data) => { + const timestamp = new Date().toISOString(); + console.log(`[${timestamp}] [CLIENT ${level.toUpperCase()}] ${message}`); + if (data) { + console.log(`[${timestamp}] [CLIENT ${level.toUpperCase()}] Data:`, JSON.stringify(data, null, 2)); + } + } + }) + ) + + // Custom timing middleware + .useMiddleware(requestTimingMiddleware) + + // ─── Outgoing Middleware ─── + // Retry middleware for transient failures + .useOutgoingMiddleware( + createRetryMiddleware({ + maxRetries: 3, + baseDelay: 100, + isRetryable: (error) => { + // Retry on network errors + const message = error instanceof Error ? error.message : String(error); + return ( + message.includes('ECONNREFUSED') || + message.includes('ETIMEDOUT') || + message.includes('network') + ); + } + }) + ) + + // ─── Tool Call Middleware ─── + .useToolCallMiddleware(toolCallInstrumentationMiddleware) + + // ─── Request Handlers ─── + + // Sampling request handler (when server requests LLM completion) + .onSamplingRequest(async (params) => { + console.log('\n[SAMPLING] Received sampling request from server'); + console.log('[SAMPLING] Messages:', JSON.stringify(params, null, 2)); + + // In a real implementation, this would call an LLM + // For demo, return a simulated response + return { + role: 'assistant', + content: { + type: 'text', + text: 'This is a simulated sampling response from the client.' + }, + model: 'simulated-model-v1' + }; + }) + + // Elicitation handler (when server requests user input) + .onElicitation(async (params) => { + const elicitParams = params as { mode?: string; message?: string; requestedSchema?: unknown }; + console.log('\n[ELICITATION] Received elicitation request from server'); + console.log('[ELICITATION] Mode:', elicitParams.mode); + console.log('[ELICITATION] Message:', elicitParams.message); + + if (elicitParams.mode === 'form') { + // For demo, auto-accept with sample data + console.log('[ELICITATION] Auto-accepting form with sample data'); + return { + action: 'accept', + content: { + name: 'Demo User', + email: 'demo@example.com', + confirmed: true + } + }; + } + + // Decline other modes + console.log('[ELICITATION] Declining non-form elicitation'); + return { action: 'decline' }; + }) + + // Roots list handler (when server requests filesystem roots) + .onRootsList(async () => { + console.log('\n[ROOTS] Received roots list request from server'); + return { + roots: [ + { uri: 'file:///workspace', name: 'Workspace' }, + { uri: 'file:///home/user', name: 'Home Directory' }, + { uri: 'file:///tmp', name: 'Temporary Files' } + ] + }; + }) + + // ─── Error Handlers ─── + .onError((error, ctx) => { + console.error(`\n[CLIENT ERROR] ${ctx.type}: ${error.message}`); + console.error(`[CLIENT ERROR] Request ID: ${ctx.requestId}`); + // Return the original error (could also transform it) + return error; + }) + .onProtocolError((error, ctx) => { + console.error(`\n[PROTOCOL ERROR] ${ctx.method}: ${error.message}`); + console.error(`[PROTOCOL ERROR] Request ID: ${ctx.requestId}`); + }) + + .build(); + + // Set up client error handler + client.onerror = (error) => { + console.error('\n[CLIENT] Error event:', error); + }; + + // Create transport with optional session ID for reconnection + transport = new StreamableHTTPClientTransport(new URL(serverUrl), { + sessionId: sessionId + }); + + // Set up notification handler for logging messages + client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { + notificationCount++; + console.log(`\n[NOTIFICATION #${notificationCount}] ${notification.params.level}: ${notification.params.data}`); + process.stdout.write('> '); + }); + + // Connect the client + await client.connect(transport); + sessionId = transport.sessionId; + console.log('Connected to MCP server'); + console.log('Session ID:', sessionId); + } catch (error) { + console.error('Failed to connect:', error); + client = null; + transport = null; + } +} + +async function disconnect(): Promise { + if (!client || !transport) { + console.log('Not connected.'); + return; + } + + try { + await transport.close(); + console.log('Disconnected from MCP server'); + client = null; + transport = null; + } catch (error) { + console.error('Error disconnecting:', error); + } +} + +async function reconnect(): Promise { + if (client) { + await disconnect(); + } + await connect(); +} + +async function listTools(): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const request: ListToolsRequest = { + method: 'tools/list', + params: {} + }; + const result = await client.request(request, ListToolsResultSchema); + + console.log('\nAvailable tools:'); + if (result.tools.length === 0) { + console.log(' No tools available'); + } else { + for (const tool of result.tools) { + console.log(` - ${tool.name}: ${getDisplayName(tool)} - ${tool.description}`); + } + } + } catch (error) { + console.log(`Tools not supported by this server (${error})`); + } +} + +async function callTool(name: string, args: Record): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const request: CallToolRequest = { + method: 'tools/call', + params: { + name, + arguments: args + } + }; + + const result = await client.request(request, CallToolResultSchema); + + console.log('\nTool result:'); + for (const item of result.content) { + if (isTextContent(item)) { + console.log(` ${item.text}`); + } else { + console.log(` [${item.type}]:`, item); + } + } + } catch (error) { + console.log(`Error calling tool ${name}: ${error}`); + } +} + +async function listPrompts(): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const request: ListPromptsRequest = { + method: 'prompts/list', + params: {} + }; + const result = await client.request(request, ListPromptsResultSchema); + + console.log('\nAvailable prompts:'); + if (result.prompts.length === 0) { + console.log(' No prompts available'); + } else { + for (const prompt of result.prompts) { + console.log(` - ${prompt.name}: ${getDisplayName(prompt)} - ${prompt.description}`); + } + } + } catch (error) { + console.log(`Prompts not supported by this server (${error})`); + } +} + +async function getPrompt(name: string, args: Record): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const request: GetPromptRequest = { + method: 'prompts/get', + params: { + name, + arguments: args as Record + } + }; + + const result = await client.request(request, GetPromptResultSchema); + console.log('\nPrompt template:'); + for (const [index, msg] of result.messages.entries()) { + console.log(` [${index + 1}] ${msg.role}: ${isTextContent(msg.content) ? msg.content.text : JSON.stringify(msg.content)}`); + } + } catch (error) { + console.log(`Error getting prompt ${name}: ${error}`); + } +} + +async function listResources(): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const request: ListResourcesRequest = { + method: 'resources/list', + params: {} + }; + const result = await client.request(request, ListResourcesResultSchema); + + console.log('\nAvailable resources:'); + if (result.resources.length === 0) { + console.log(' No resources available'); + } else { + for (const resource of result.resources) { + console.log(` - ${resource.name}: ${getDisplayName(resource)} - ${resource.uri}`); + } + } + } catch (error) { + console.log(`Resources not supported by this server (${error})`); + } +} + +async function readResource(uri: string): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const request: ReadResourceRequest = { + method: 'resources/read', + params: { uri } + }; + + console.log(`\nReading resource: ${uri}`); + const result = await client.request(request, ReadResourceResultSchema); + + console.log('Resource contents:'); + for (const content of result.contents) { + console.log(` URI: ${content.uri}`); + if (content.mimeType) { + console.log(` Type: ${content.mimeType}`); + } + + if ('text' in content && typeof content.text === 'string') { + console.log(' Content:'); + console.log(' ---'); + console.log( + content.text + .split('\n') + .map((line: string) => ' ' + line) + .join('\n') + ); + console.log(' ---'); + } else if ('blob' in content && typeof content.blob === 'string') { + console.log(` [Binary data: ${content.blob.length} bytes]`); + } + } + } catch (error) { + console.log(`Error reading resource ${uri}: ${error}`); + } +} + +async function cleanup(): Promise { + if (client && transport) { + try { + await transport.close(); + } catch (error) { + console.error('Error closing transport:', error); + } + } + + readline.close(); + console.log('\nGoodbye!'); + process.exit(0); +} + +// Handle Ctrl+C +process.on('SIGINT', async () => { + console.log('\nReceived SIGINT. Cleaning up...'); + await cleanup(); +}); + +// Start the interactive client +try { + await main(); +} catch (error) { + console.error('Error running MCP client:', error); + process.exit(1); +} diff --git a/examples/client/src/simpleTaskInteractiveClient.ts b/examples/client/src/simpleTaskInteractiveClient.ts index 2a4d47043..89310ef29 100644 --- a/examples/client/src/simpleTaskInteractiveClient.ts +++ b/examples/client/src/simpleTaskInteractiveClient.ts @@ -9,13 +9,14 @@ import { createInterface } from 'node:readline'; -import type { CreateMessageRequest, CreateMessageResult, TextContent } from '@modelcontextprotocol/client'; +import type { ContentBlock, CreateMessageRequest, CreateMessageResult } from '@modelcontextprotocol/client'; import { CallToolResultSchema, Client, CreateMessageRequestSchema, ElicitRequestSchema, ErrorCode, + isTextContent, McpError, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; @@ -34,9 +35,9 @@ function question(prompt: string): Promise { }); } -function getTextContent(result: { content: Array<{ type: string; text?: string }> }): string { - const textContent = result.content.find((c): c is TextContent => c.type === 'text'); - return textContent?.text ?? '(no text)'; +function getTextContent(result: { content: ContentBlock[] }): string | undefined { + const textContent = result.content.find(element => isTextContent(element)); + return textContent?.text; } async function elicitationCallback(params: { diff --git a/examples/client/src/streamableHttpWithSseFallbackClient.ts b/examples/client/src/streamableHttpWithSseFallbackClient.ts index 90fee9270..bda71d3f7 100644 --- a/examples/client/src/streamableHttpWithSseFallbackClient.ts +++ b/examples/client/src/streamableHttpWithSseFallbackClient.ts @@ -2,6 +2,7 @@ import type { CallToolRequest, ListToolsRequest } from '@modelcontextprotocol/cl import { CallToolResultSchema, Client, + isTextContent, ListToolsResultSchema, LoggingMessageNotificationSchema, SSEClientTransport, @@ -173,7 +174,7 @@ async function startNotificationTool(client: Client): Promise { console.log('Tool result:'); for (const item of result.content) { - if (item.type === 'text') { + if (isTextContent(item)) { console.log(` ${item.text}`); } else { console.log(` ${item.type} content:`, item); diff --git a/examples/server/src/elicitationFormExample.ts b/examples/server/src/elicitationFormExample.ts index 70ff8ecb5..2e8ad1b25 100644 --- a/examples/server/src/elicitationFormExample.ts +++ b/examples/server/src/elicitationFormExample.ts @@ -11,7 +11,7 @@ import { randomUUID } from 'node:crypto'; import { createMcpExpressApp } from '@modelcontextprotocol/express'; import { NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/node'; -import { isInitializeRequest, McpServer } from '@modelcontextprotocol/server'; +import { isInitializeRequest, McpServer, text } from '@modelcontextprotocol/server'; import type { Request, Response } from 'express'; // Create MCP server - it will automatically use AjvJsonSchemaValidator with sensible defaults @@ -86,39 +86,21 @@ mcpServer.registerTool( return { content: [ - { - type: 'text', - text: `Registration successful!\n\nUsername: ${username}\nEmail: ${email}\nNewsletter: ${newsletter ? 'Yes' : 'No'}` - } + text(`Registration successful!\n\nUsername: ${username}\nEmail: ${email}\nNewsletter: ${newsletter ? 'Yes' : 'No'}`) ] }; } else if (result.action === 'decline') { return { - content: [ - { - type: 'text', - text: 'Registration cancelled by user.' - } - ] + content: [text('Registration cancelled by user.')] }; } else { return { - content: [ - { - type: 'text', - text: 'Registration was cancelled.' - } - ] + content: [text('Registration was cancelled.')] }; } } catch (error) { return { - content: [ - { - type: 'text', - text: `Registration failed: ${error instanceof Error ? error.message : String(error)}` - } - ], + content: [text(`Registration failed: ${error instanceof Error ? error.message : String(error)}`)], isError: true }; } @@ -162,7 +144,7 @@ mcpServer.registerTool( if (basicInfo.action !== 'accept' || !basicInfo.content) { return { - content: [{ type: 'text', text: 'Event creation cancelled.' }] + content: [text('Event creation cancelled.')] }; } @@ -198,7 +180,7 @@ mcpServer.registerTool( if (dateTime.action !== 'accept' || !dateTime.content) { return { - content: [{ type: 'text', text: 'Event creation cancelled.' }] + content: [text('Event creation cancelled.')] }; } @@ -209,21 +191,11 @@ mcpServer.registerTool( }; return { - content: [ - { - type: 'text', - text: `Event created successfully!\n\n${JSON.stringify(event, null, 2)}` - } - ] + content: [text(`Event created successfully!\n\n${JSON.stringify(event, null, 2)}`)] }; } catch (error) { return { - content: [ - { - type: 'text', - text: `Event creation failed: ${error instanceof Error ? error.message : String(error)}` - } - ], + content: [text(`Event creation failed: ${error instanceof Error ? error.message : String(error)}`)], isError: true }; } @@ -287,30 +259,20 @@ mcpServer.registerTool( if (result.action === 'accept' && result.content) { return { - content: [ - { - type: 'text', - text: `Address updated successfully!\n\n${JSON.stringify(result.content, null, 2)}` - } - ] + content: [text(`Address updated successfully!\n\n${JSON.stringify(result.content, null, 2)}`)] }; } else if (result.action === 'decline') { return { - content: [{ type: 'text', text: 'Address update cancelled by user.' }] + content: [text('Address update cancelled by user.')] }; } else { return { - content: [{ type: 'text', text: 'Address update was cancelled.' }] + content: [text('Address update was cancelled.')] }; } } catch (error) { return { - content: [ - { - type: 'text', - text: `Address update failed: ${error instanceof Error ? error.message : String(error)}` - } - ], + content: [text(`Address update failed: ${error instanceof Error ? error.message : String(error)}`)], isError: true }; } diff --git a/examples/server/src/simpleStreamableHttp.ts b/examples/server/src/simpleStreamableHttp.ts index 808f252cc..85ae517d0 100644 --- a/examples/server/src/simpleStreamableHttp.ts +++ b/examples/server/src/simpleStreamableHttp.ts @@ -20,7 +20,8 @@ import { InMemoryTaskMessageQueue, InMemoryTaskStore, isInitializeRequest, - McpServer + McpServer, + TaskPlugin } from '@modelcontextprotocol/server'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; @@ -44,12 +45,18 @@ const getServer = () => { websiteUrl: 'https://github.com/modelcontextprotocol/typescript-sdk' }, { - capabilities: { logging: {}, tasks: { requests: { tools: { call: {} } } } }, - taskStore, // Enable task support - taskMessageQueue: new InMemoryTaskMessageQueue() + capabilities: { logging: {}, tasks: { requests: { tools: { call: {} } } } } } ); + // Enable task support via TaskPlugin + server.server.usePlugin( + new TaskPlugin({ + taskStore, + taskMessageQueue: new InMemoryTaskMessageQueue() + }) + ); + // Register a simple tool that returns a greeting server.registerTool( 'greet', diff --git a/examples/server/src/simpleStreamableHttpBuilder.ts b/examples/server/src/simpleStreamableHttpBuilder.ts new file mode 100644 index 000000000..de8574c9a --- /dev/null +++ b/examples/server/src/simpleStreamableHttpBuilder.ts @@ -0,0 +1,578 @@ +/** + * Simple Streamable HTTP Server Example using Builder Pattern + * + * This example demonstrates using the McpServer.builder() fluent API + * to create and configure an MCP server with: + * - Tools, resources, and prompts registration + * - Middleware (logging, rate limiting, custom metrics) + * - Per-tool middleware (authorization) + * - Error handlers (onError, onProtocolError) + * - Session management with SessionStore + * - Context helpers (logging, notifications) + * + * Run with: npx tsx src/simpleStreamableHttpBuilder.ts + */ + +import { randomUUID } from 'node:crypto'; + +import { createMcpExpressApp } from '@modelcontextprotocol/express'; +import { NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/node'; +import type { CallToolResult, GetPromptResult, ReadResourceResult ,ToolMiddleware} from '@modelcontextprotocol/server'; +import { + createLoggingMiddleware, + createSessionStore, + isInitializeRequest, + McpServer, + text +} from '@modelcontextprotocol/server'; +import type { Request, Response } from 'express'; +import * as z from 'zod/v4'; + +import { InMemoryEventStore } from './inMemoryEventStore.js'; + +// ═══════════════════════════════════════════════════════════════════════════ +// Custom Middleware Examples +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Custom metrics middleware that tracks tool execution time. + * Demonstrates how to create custom middleware. + */ +const metricsMiddleware: ToolMiddleware = async (ctx, next) => { + const start = performance.now(); + try { + const result = await next(); + const duration = (performance.now() - start).toFixed(2); + console.log(`[METRICS] Tool '${ctx.name}' completed in ${duration}ms`); + return result; + } catch (error) { + const duration = (performance.now() - start).toFixed(2); + console.log(`[METRICS] Tool '${ctx.name}' failed in ${duration}ms`); + throw error; + } +}; + +/** + * Per-tool authorization middleware example. + * This is passed directly to a specific tool registration. + */ +const adminAuthMiddleware: ToolMiddleware = async (ctx, next) => { + // In a real app, check ctx.authInfo for admin scope + // For demo purposes, we'll check for a special argument + const args = ctx.args as Record; + if (args.requiresAdmin && !args.adminToken) { + throw new Error('Admin authorization required. Provide adminToken argument.'); + } + console.log(`[AUTH] Admin action authorized for tool '${ctx.name}'`); + return next(); +}; + +// ═══════════════════════════════════════════════════════════════════════════ +// Session Store Setup +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Session data type - stores transport for each session. + */ +interface SessionData { + transport: NodeStreamableHTTPServerTransport; + createdAt: Date; +} + +/** + * Create session store with lifecycle events and timeout. + * This replaces the manual session map management. + */ +const sessionStore = createSessionStore({ + sessionTimeout: 30 * 60 * 1000, // 30 minutes + maxSessions: 100, + cleanupInterval: 60_000, // Check for expired sessions every minute + events: { + onSessionCreated: (id) => { + console.log(`[SESSION] Created: ${id}`); + }, + onSessionDestroyed: (id) => { + console.log(`[SESSION] Destroyed: ${id}`); + }, + onSessionUpdated: (id) => { + console.log(`[SESSION] Updated: ${id}`); + } + } +}); + +// ═══════════════════════════════════════════════════════════════════════════ +// Server Factory +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Creates an MCP server using the builder pattern. + * + * The builder provides a fluent API for configuring the server: + * - .name() and .version() set server info + * - .options() configures capabilities + * - .useMiddleware() adds universal middleware + * - .useToolMiddleware() adds tool-specific middleware + * - .tool() registers tools inline (with optional per-tool middleware) + * - .resource() registers resources inline + * - .prompt() registers prompts inline + * - .onError() handles application errors + * - .onProtocolError() handles protocol errors + * - .build() creates the configured McpServer instance + */ +const getServer = () => { + const server = McpServer.builder() + .name('builder-example-server') + .version('1.0.0') + .options({ + capabilities: { logging: {} } + }) + + // ─── Universal Middleware ─── + // Runs for all request types (tools, resources, prompts) + .useMiddleware( + createLoggingMiddleware({ + level: 'info', + logger: (level, message, data) => { + const timestamp = new Date().toISOString(); + console.log(`[${timestamp}] [${level.toUpperCase()}] ${message}`, data ? JSON.stringify(data) : ''); + } + }) + ) + + // ─── Tool-Specific Middleware ─── + .useToolMiddleware( + async (ctx, next) => { + console.log(`Tool '${ctx.name}' called`); + return next(); + } + ) + + // Custom metrics middleware + .useToolMiddleware(metricsMiddleware) + + // ─── Error Handlers ─── + .onError((error, ctx) => { + console.error(`[APP ERROR] ${ctx.type}/${ctx.name || ctx.method}: ${error.message}`); + // Return custom error response with additional context + return { + code: -32_000, + message: `Error in ${ctx.name || ctx.method}: ${error.message}`, + data: { type: ctx.type, requestId: ctx.requestId } + }; + }) + .onProtocolError((error, ctx) => { + console.error(`[PROTOCOL ERROR] ${ctx.method}: ${error.message}`); + // Protocol errors preserve error code, can customize message/data + return { + message: `Protocol error: ${error.message}`, + data: { requestId: ctx.requestId } + }; + }) + + // ─── Tool Registrations ─── + + // Simple greeting tool + .tool( + 'greet', + { + title: 'Greeting Tool', + description: 'A simple greeting tool that returns a personalized greeting', + inputSchema: { + name: z.string().describe('Name to greet') + } + }, + async ({ name }): Promise => { + return { + content: [text(`Hello, ${name}!`)] + }; + } + ) + + // Tool with notifications demonstrating context usage + .tool( + 'multi-greet', + { + title: 'Multiple Greeting Tool', + description: 'A tool that sends different greetings with delays and notifications', + inputSchema: { + name: z.string().describe('Name to greet') + } + }, + async function ({ name }, ctx): Promise { + const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)); + + // Use context logging helper + await ctx.loggingNotification.debug(`Starting multi-greet for ${name}`); + + await sleep(1000); + + // Use sendNotification directly + await ctx.sendNotification({ + method: 'notifications/message', + params: { + level: 'info', + data: `Sending first greeting to ${name}` + } + }); + + await sleep(1000); + + await ctx.loggingNotification.info(`Sending second greeting to ${name}`); + + return { + content: [text(`Good morning, ${name}!`)] + }; + } + ) + + // Context demo tool - shows all context features + .tool( + 'context-demo', + { + title: 'Context Demo', + description: 'Demonstrates all context helper methods and properties', + inputSchema: { + message: z.string().describe('A message to echo back') + } + }, + async ({ message }, ctx): Promise => { + // Access MCP context + const mcpInfo = { + requestId: ctx.mcpCtx.requestId, + sessionId: ctx.mcpCtx.sessionId, + method: ctx.mcpCtx.method + }; + + // Access request context + const requestInfo = { + signalAborted: ctx.requestCtx.signal.aborted, + hasAuthInfo: !!ctx.requestCtx.authInfo + }; + + // Use logging helpers at different levels + await ctx.loggingNotification.debug('Debug: Processing context-demo'); + await ctx.loggingNotification.info('Info: Context inspection complete'); + + // Send custom notification + await ctx.sendNotification({ + method: 'notifications/message', + params: { + level: 'debug', + data: `Echo: ${message}` + } + }); + + return { + content: [ + text('Context Demo Results:'), + text(`MCP Context: ${JSON.stringify(mcpInfo, null, 2)}`), + text(`Request Context: ${JSON.stringify(requestInfo, null, 2)}`), + text(`Your message: ${message}`) + ] + }; + } + ) + + // Tool with per-tool middleware (authorization) + .tool( + 'admin-action', + { + title: 'Admin Action', + description: 'An admin-only tool demonstrating per-tool middleware', + inputSchema: { + action: z.string().describe('Admin action to perform'), + requiresAdmin: z.boolean().optional().describe('Whether this action requires admin auth'), + adminToken: z.string().optional().describe('Admin token for authorization') + }, + middleware: adminAuthMiddleware // Per-tool middleware + }, + async ({ action }): Promise => { + return { + content: [text(`Admin action '${action}' executed successfully`)] + }; + } + ) + + // Tool that intentionally throws an error (for testing error handlers) + .tool( + 'error-test', + { + title: 'Error Test', + description: 'A tool that throws errors to test error handlers', + inputSchema: { + errorType: z.enum(['application', 'validation']).describe('Type of error to throw') + } + }, + async ({ errorType }): Promise => { + const error = errorType === 'application' ? new Error('This is a test application error') : new Error('Validation failed: invalid input format'); + throw error; + } + ) + + // ─── Resource Registration ─── + .resource( + 'greeting-resource', + 'https://example.com/greetings/default', + { + title: 'Default Greeting', + description: 'A simple greeting resource' + }, + async (): Promise => { + return { + contents: [ + { + uri: 'https://example.com/greetings/default', + mimeType: 'text/plain', + text: 'Hello, world!' + } + ] + }; + } + ) + + // Resource demonstrating session info + .resource( + 'session-info', + 'https://example.com/session/info', + { + title: 'Session Information', + description: 'Returns current session statistics' + }, + async (): Promise => { + const stats = { + activeSessions: sessionStore.size(), + sessionIds: sessionStore.keys() + }; + return { + contents: [ + { + uri: 'https://example.com/session/info', + mimeType: 'application/json', + text: JSON.stringify(stats, null, 2) + } + ] + }; + } + ) + + // ─── Prompt Registration ─── + .prompt( + 'greeting-template', + { + title: 'Greeting Template', + description: 'A simple greeting prompt template', + argsSchema: { + name: z.string().describe('Name to include in greeting') + } + }, + async ({ name }): Promise => { + return { + messages: [ + { + role: 'user', + content: text(`Please greet ${name} in a friendly manner.`) + } + ] + }; + } + ) + + .build(); + + return server; +}; + +// ═══════════════════════════════════════════════════════════════════════════ +// Express App Setup +// ═══════════════════════════════════════════════════════════════════════════ + +const PORT = process.env.PORT ? Number.parseInt(process.env.PORT, 10) : 3000; + +const app = createMcpExpressApp(); + +// ═══════════════════════════════════════════════════════════════════════════ +// MCP Request Handlers +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * MCP POST endpoint handler. + * Uses SessionStore for session management. + */ +const mcpPostHandler = async (req: Request, res: Response) => { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + + try { + // Check for existing session + const session = sessionId ? sessionStore.get(sessionId) : undefined; + + if (session) { + // Reuse existing transport + console.log(`[REQUEST] Using existing session: ${sessionId}`); + await session.transport.handleRequest(req, res, req.body); + return; + } + + if (!sessionId && isInitializeRequest(req.body)) { + // New initialization request - create session + console.log('[REQUEST] New initialization request'); + + const eventStore = new InMemoryEventStore(); + const transport = new NodeStreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + eventStore, + onsessioninitialized: (sid) => { + // Store session with SessionStore + sessionStore.set(sid, { + transport, + createdAt: new Date() + }); + } + }); + + // Clean up session when transport closes + transport.onclose = () => { + const sid = transport.sessionId; + if (sid) { + sessionStore.delete(sid); + } + }; + + // Connect the transport to the MCP server + const server = getServer(); + await server.connect(transport); + + await transport.handleRequest(req, res, req.body); + return; + } + + // Invalid request + res.status(400).json({ + jsonrpc: '2.0', + error: { + code: -32_000, + message: 'Bad Request: No valid session ID provided' + }, + id: null + }); + } catch (error) { + console.error('[ERROR] Handling MCP request:', error); + if (!res.headersSent) { + res.status(500).json({ + jsonrpc: '2.0', + error: { + code: -32_603, + message: 'Internal server error' + }, + id: null + }); + } + } +}; + +app.post('/mcp', mcpPostHandler); + +/** + * MCP GET endpoint handler for SSE streams. + */ +const mcpGetHandler = async (req: Request, res: Response) => { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + const session = sessionId ? sessionStore.get(sessionId) : undefined; + + if (!session) { + res.status(400).send('Invalid or missing session ID'); + return; + } + + const lastEventId = req.headers['last-event-id'] as string | undefined; + if (lastEventId) { + console.log(`[SSE] Client reconnecting with Last-Event-ID: ${lastEventId}`); + } else { + console.log(`[SSE] Establishing new stream for session ${sessionId}`); + } + + await session.transport.handleRequest(req, res); +}; + +app.get('/mcp', mcpGetHandler); + +/** + * MCP DELETE endpoint handler for session termination. + */ +const mcpDeleteHandler = async (req: Request, res: Response) => { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + const session = sessionId ? sessionStore.get(sessionId) : undefined; + + if (!session) { + res.status(400).send('Invalid or missing session ID'); + return; + } + + console.log(`[SESSION] Termination request for session ${sessionId}`); + + try { + await session.transport.handleRequest(req, res); + } catch (error) { + console.error('[ERROR] Session termination:', error); + if (!res.headersSent) { + res.status(500).send('Error processing session termination'); + } + } +}; + +app.delete('/mcp', mcpDeleteHandler); + +// ═══════════════════════════════════════════════════════════════════════════ +// Server Startup +// ═══════════════════════════════════════════════════════════════════════════ + +app.listen(PORT, (error) => { + if (error) { + console.error('Failed to start server:', error); + // eslint-disable-next-line unicorn/no-process-exit + process.exit(1); + } + console.log('═══════════════════════════════════════════════════════════════'); + console.log('MCP Builder Example Server'); + console.log('═══════════════════════════════════════════════════════════════'); + console.log(`Listening on port ${PORT}`); + console.log(`MCP endpoint: http://localhost:${PORT}/mcp`); + console.log(''); + console.log('Features demonstrated:'); + console.log(' - Builder pattern for server configuration'); + console.log(' - Universal middleware (logging)'); + console.log(' - Tool-specific middleware (rate limiting, metrics)'); + console.log(' - Per-tool middleware (authorization)'); + console.log(' - Error handlers (onError, onProtocolError)'); + console.log(' - Session management with SessionStore'); + console.log(' - Context helpers (logging, notifications)'); + console.log('═══════════════════════════════════════════════════════════════'); +}); + +// ═══════════════════════════════════════════════════════════════════════════ +// Graceful Shutdown +// ═══════════════════════════════════════════════════════════════════════════ + +process.on('SIGINT', async () => { + console.log('\n[SHUTDOWN] Received SIGINT, shutting down...'); + + // Close all sessions using SessionStore + const sessionIds = sessionStore.keys(); + for (const sid of sessionIds) { + try { + const session = sessionStore.get(sid); + if (session) { + console.log(`[SHUTDOWN] Closing session ${sid}`); + await session.transport.close(); + } + } catch (error) { + console.error(`[SHUTDOWN] Error closing session ${sid}:`, error); + } + } + + // Clear the session store (also stops cleanup timer) + sessionStore.clear(); + (sessionStore as { dispose?: () => void }).dispose?.(); + + console.log('[SHUTDOWN] Complete'); + process.exit(0); +}); diff --git a/examples/server/src/toolWithSampleServer.ts b/examples/server/src/toolWithSampleServer.ts index 9881830b5..e3a529b98 100644 --- a/examples/server/src/toolWithSampleServer.ts +++ b/examples/server/src/toolWithSampleServer.ts @@ -1,6 +1,6 @@ // Run with: pnpm tsx src/toolWithSampleServer.ts -import { McpServer, StdioServerTransport } from '@modelcontextprotocol/server'; +import { isTextContent, McpServer, StdioServerTransport } from '@modelcontextprotocol/server'; import * as z from 'zod/v4'; const mcpServer = new McpServer({ @@ -37,7 +37,7 @@ mcpServer.registerTool( content: [ { type: 'text', - text: response.content.type === 'text' ? response.content.text : 'Unable to generate summary' + text: isTextContent(response.content) ? response.content.text : 'Unable to generate summary' } ] }; diff --git a/packages/client/src/client/builder.ts b/packages/client/src/client/builder.ts new file mode 100644 index 000000000..eb5423887 --- /dev/null +++ b/packages/client/src/client/builder.ts @@ -0,0 +1,452 @@ +/** + * Client Builder + * + * Provides a fluent API for configuring and creating Client instances. + * The builder is an additive convenience layer - the existing constructor + * API remains available for users who prefer it. + * + * @example + * ```typescript + * const client = Client.builder() + * .name('my-client') + * .version('1.0.0') + * .capabilities({ sampling: {} }) + * .useMiddleware(loggingMiddleware) + * .onSamplingRequest(samplingHandler) + * .build(); + * ``` + */ + +import type { + ClientCapabilities, + CreateMessageRequest, + CreateMessageResult, + CreateMessageResultWithTools, + CreateTaskResult, + ElicitRequest, + ElicitResult, + jsonSchemaValidator, + ListChangedHandlers, + ListRootsRequest, + ListRootsResult +} from '@modelcontextprotocol/core'; + +import type { Client } from './client.js'; +import type { ClientContextInterface } from './context.js'; +import type { + ClientMiddleware, + ElicitationMiddleware, + IncomingMiddleware, + OutgoingMiddleware, + ResourceReadMiddleware, + SamplingMiddleware, + ToolCallMiddleware +} from './middleware.js'; + +/** + * Handler for sampling requests from the server. + * Receives the full CreateMessageRequest and returns the sampling result. + * When task creation is requested via params.task, returns CreateTaskResult instead. + */ +export type SamplingRequestHandler = ( + request: CreateMessageRequest, + ctx: ClientContextInterface +) => + | CreateMessageResult + | CreateMessageResultWithTools + | CreateTaskResult + | Promise; + +/** + * Handler for elicitation requests from the server. + * Receives the full ElicitRequest and returns the elicitation result. + * When task creation is requested via params.task, returns CreateTaskResult instead. + */ +export type ElicitationRequestHandler = ( + request: ElicitRequest, + ctx: ClientContextInterface +) => ElicitResult | CreateTaskResult | Promise; + +/** + * Handler for roots list requests from the server. + * Receives the full ListRootsRequest and returns the list of roots. + */ +export type RootsListHandler = ( + request: ListRootsRequest, + ctx: ClientContextInterface +) => ListRootsResult | Promise; + +/** + * Error handler type for application errors + */ +export type OnErrorHandler = (error: Error, ctx: ErrorContext) => OnErrorReturn | void | Promise; + +/** + * Error handler type for protocol errors + */ +export type OnProtocolErrorHandler = ( + error: Error, + ctx: ErrorContext +) => OnProtocolErrorReturn | void | Promise; + +/** + * Return type for onError handler + */ +export type OnErrorReturn = string | { code?: number; message?: string; data?: unknown } | Error; + +/** + * Return type for onProtocolError handler (code cannot be changed) + */ +export type OnProtocolErrorReturn = string | { message?: string; data?: unknown }; + +/** + * Context provided to error handlers + */ +export interface ErrorContext { + type: 'sampling' | 'elicitation' | 'rootsList' | 'protocol'; + method: string; + requestId: string; +} + +/** + * Options for client configuration + */ +export interface ClientBuilderOptions { + /** Enforce strict capability checking */ + enforceStrictCapabilities?: boolean; +} + +/** + * Fluent builder for Client instances. + * + * Provides a declarative, chainable API for configuring clients. + * All configuration is collected and applied when build() is called. + */ +export class ClientBuilder { + private _name?: string; + private _version?: string; + private _capabilities?: ClientCapabilities; + private _options: ClientBuilderOptions = {}; + private _jsonSchemaValidator?: jsonSchemaValidator; + private _listChanged?: ListChangedHandlers; + + // Middleware + private _universalMiddleware: ClientMiddleware[] = []; + private _outgoingMiddleware: OutgoingMiddleware[] = []; + private _incomingMiddleware: IncomingMiddleware[] = []; + private _toolCallMiddleware: ToolCallMiddleware[] = []; + private _resourceReadMiddleware: ResourceReadMiddleware[] = []; + private _samplingMiddleware: SamplingMiddleware[] = []; + private _elicitationMiddleware: ElicitationMiddleware[] = []; + + // Handlers + private _samplingHandler?: SamplingRequestHandler; + private _elicitationHandler?: ElicitationRequestHandler; + private _rootsListHandler?: RootsListHandler; + + // Error handlers + private _onError?: OnErrorHandler; + private _onProtocolError?: OnProtocolErrorHandler; + + /** + * Sets the client name. + */ + name(name: string): this { + this._name = name; + return this; + } + + /** + * Sets the client version. + */ + version(version: string): this { + this._version = version; + return this; + } + + /** + * Sets the client capabilities. + * + * @example + * ```typescript + * .capabilities({ + * sampling: {}, + * roots: { listChanged: true } + * }) + * ``` + */ + capabilities(capabilities: ClientCapabilities): this { + this._capabilities = { ...this._capabilities, ...capabilities }; + return this; + } + + /** + * Sets client options. + */ + options(options: ClientBuilderOptions): this { + this._options = { ...this._options, ...options }; + return this; + } + + /** + * Sets the JSON Schema validator for tool output validation. + * + * @example + * ```typescript + * .jsonSchemaValidator(new AjvJsonSchemaValidator()) + * ``` + */ + jsonSchemaValidator(validator: jsonSchemaValidator): this { + this._jsonSchemaValidator = validator; + return this; + } + + /** + * Configures handlers for list changed notifications (tools, prompts, resources). + * + * @example + * ```typescript + * .onListChanged({ + * tools: { + * onChanged: (error, tools) => console.log('Tools updated:', tools) + * }, + * prompts: { + * onChanged: (error, prompts) => console.log('Prompts updated:', prompts) + * } + * }) + * ``` + */ + onListChanged(handlers: ListChangedHandlers): this { + this._listChanged = { ...this._listChanged, ...handlers }; + return this; + } + + /** + * Adds universal middleware that runs for all requests. + */ + useMiddleware(middleware: ClientMiddleware): this { + this._universalMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware for outgoing requests only. + */ + useOutgoingMiddleware(middleware: OutgoingMiddleware): this { + this._outgoingMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware for incoming requests only. + */ + useIncomingMiddleware(middleware: IncomingMiddleware): this { + this._incomingMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware specifically for tool calls. + */ + useToolCallMiddleware(middleware: ToolCallMiddleware): this { + this._toolCallMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware specifically for resource reads. + */ + useResourceReadMiddleware(middleware: ResourceReadMiddleware): this { + this._resourceReadMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware specifically for sampling requests. + */ + useSamplingMiddleware(middleware: SamplingMiddleware): this { + this._samplingMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware specifically for elicitation requests. + */ + useElicitationMiddleware(middleware: ElicitationMiddleware): this { + this._elicitationMiddleware.push(middleware); + return this; + } + + /** + * Sets the handler for sampling requests from the server. + * + * @example + * ```typescript + * .onSamplingRequest(async (params, ctx) => { + * const result = await llm.complete(params.messages); + * return { role: 'assistant', content: result }; + * }) + * ``` + */ + onSamplingRequest(handler: SamplingRequestHandler): this { + this._samplingHandler = handler; + return this; + } + + /** + * Sets the handler for elicitation requests from the server. + */ + onElicitation(handler: ElicitationRequestHandler): this { + this._elicitationHandler = handler; + return this; + } + + /** + * Sets the handler for roots list requests from the server. + */ + onRootsList(handler: RootsListHandler): this { + this._rootsListHandler = handler; + return this; + } + + /** + * Sets the application error handler. + * Called when a handler throws an error. + */ + onError(handler: OnErrorHandler): this { + this._onError = handler; + return this; + } + + /** + * Sets the protocol error handler. + * Called for protocol-level errors. + */ + onProtocolError(handler: OnProtocolErrorHandler): this { + this._onProtocolError = handler; + return this; + } + + /** + * Gets the collected configuration (for debugging/testing). + */ + getConfig(): { + name?: string; + version?: string; + capabilities?: ClientCapabilities; + options: ClientBuilderOptions; + middlewareCount: number; + hasHandlers: boolean; + } { + return { + name: this._name, + version: this._version, + capabilities: this._capabilities, + options: this._options, + middlewareCount: + this._universalMiddleware.length + + this._outgoingMiddleware.length + + this._incomingMiddleware.length + + this._toolCallMiddleware.length + + this._resourceReadMiddleware.length + + this._samplingMiddleware.length + + this._elicitationMiddleware.length, + hasHandlers: !!this._samplingHandler || !!this._elicitationHandler || !!this._rootsListHandler + }; + } + + /** + * Builds and returns the configured Client instance. + */ + build(): Client { + if (!this._name) { + throw new Error('Client name is required. Use .name() to set it.'); + } + if (!this._version) { + throw new Error('Client version is required. Use .version() to set it.'); + } + + const result: ClientBuilderResult = { + clientInfo: { + name: this._name, + version: this._version + }, + capabilities: this._capabilities, + options: this._options, + jsonSchemaValidator: this._jsonSchemaValidator, + listChanged: this._listChanged, + middleware: { + universal: this._universalMiddleware, + outgoing: this._outgoingMiddleware, + incoming: this._incomingMiddleware, + toolCall: this._toolCallMiddleware, + resourceRead: this._resourceReadMiddleware, + sampling: this._samplingMiddleware, + elicitation: this._elicitationMiddleware + }, + handlers: { + sampling: this._samplingHandler, + elicitation: this._elicitationHandler, + rootsList: this._rootsListHandler + }, + errorHandlers: { + onError: this._onError, + onProtocolError: this._onProtocolError + } + }; + + // Dynamically import Client to create the instance + // eslint-disable-next-line @typescript-eslint/no-require-imports + const { Client: ClientClass } = require('./client.js'); + return ClientClass.fromBuilderResult(result); + } +} + +/** + * Result of building the client configuration. + * Used to create the actual Client instance. + */ +export interface ClientBuilderResult { + clientInfo: { + name: string; + version: string; + }; + capabilities?: ClientCapabilities; + options: ClientBuilderOptions; + jsonSchemaValidator?: jsonSchemaValidator; + listChanged?: ListChangedHandlers; + middleware: { + universal: ClientMiddleware[]; + outgoing: OutgoingMiddleware[]; + incoming: IncomingMiddleware[]; + toolCall: ToolCallMiddleware[]; + resourceRead: ResourceReadMiddleware[]; + sampling: SamplingMiddleware[]; + elicitation: ElicitationMiddleware[]; + }; + handlers: { + sampling?: SamplingRequestHandler; + elicitation?: ElicitationRequestHandler; + rootsList?: RootsListHandler; + }; + errorHandlers: { + onError?: OnErrorHandler; + onProtocolError?: OnProtocolErrorHandler; + }; +} + +/** + * Creates a new ClientBuilder instance. + * + * @example + * ```typescript + * const client = createClientBuilder() + * .name('my-client') + * .version('1.0.0') + * .capabilities({ sampling: {} }) + * .build(); + * ``` + */ +export function createClientBuilder(): ClientBuilder { + return new ClientBuilder(); +} diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index 5eaf2371a..e5e4bbb91 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -20,6 +20,7 @@ import type { ListPromptsRequest, ListResourcesRequest, ListResourceTemplatesRequest, + ListRootsResult, ListToolsRequest, LoggingLevel, McpContext, @@ -33,9 +34,6 @@ import type { SchemaOutput, ServerCapabilities, SubscribeRequest, - TaskContext, - TaskCreationParams, - TaskStore, Tool, Transport, UnsubscribeRequest, @@ -65,6 +63,7 @@ import { ListPromptsResultSchema, ListResourcesResultSchema, ListResourceTemplatesResultSchema, + ListRootsRequestSchema, ListToolsResultSchema, McpError, mergeCapabilities, @@ -78,8 +77,20 @@ import { } from '@modelcontextprotocol/core'; import { ExperimentalClientTasks } from '../experimental/tasks/client.js'; +import type { ClientBuilderResult, ErrorContext, OnErrorHandler, OnProtocolErrorHandler } from './builder.js'; +import { ClientBuilder } from './builder.js'; import type { ClientRequestContext } from './context.js'; import { ClientContext } from './context.js'; +import type { + ClientMiddleware, + ElicitationMiddleware, + IncomingMiddleware, + OutgoingMiddleware, + ResourceReadMiddleware, + SamplingMiddleware, + ToolCallMiddleware +} from './middleware.js'; +import { ClientMiddlewareManager } from './middleware.js'; /** * Elicitation default application helper. Applies defaults to the data based on the schema. @@ -262,6 +273,7 @@ export class Client< private _experimental?: { tasks: ExperimentalClientTasks }; private _listChangedDebounceTimers: Map> = new Map(); private _pendingListChangedConfig?: ListChangedHandlers; + private readonly _middleware: ClientMiddlewareManager; /** * Initializes this client with the given name and version information. @@ -273,6 +285,7 @@ export class Client< super(options); this._capabilities = options?.capabilities ?? {}; this._jsonSchemaValidator = options?.jsonSchemaValidator ?? new AjvJsonSchemaValidator(); + this._middleware = new ClientMiddlewareManager(); // Store list changed config for setup after connection (when we know server capabilities) if (options?.listChanged) { @@ -280,6 +293,184 @@ export class Client< } } + /** + * Gets the middleware manager for advanced middleware configuration. + */ + get middleware(): ClientMiddlewareManager { + return this._middleware; + } + + /** + * Registers universal middleware that runs for all request types. + * + * @param middleware - The middleware function to register + * @returns This Client instance for chaining + */ + useMiddleware(middleware: ClientMiddleware): this { + this._middleware.useMiddleware(middleware); + return this; + } + + /** + * Registers middleware for outgoing requests only. + * + * @param middleware - The outgoing middleware function to register + * @returns This Client instance for chaining + */ + useOutgoingMiddleware(middleware: OutgoingMiddleware): this { + this._middleware.useOutgoingMiddleware(middleware); + return this; + } + + /** + * Registers middleware for incoming requests only. + * + * @param middleware - The incoming middleware function to register + * @returns This Client instance for chaining + */ + useIncomingMiddleware(middleware: IncomingMiddleware): this { + this._middleware.useIncomingMiddleware(middleware); + return this; + } + + /** + * Registers middleware specifically for tool calls. + * + * @param middleware - The tool call middleware function to register + * @returns This Client instance for chaining + */ + useToolCallMiddleware(middleware: ToolCallMiddleware): this { + this._middleware.useToolCallMiddleware(middleware); + return this; + } + + /** + * Registers middleware specifically for resource reads. + * + * @param middleware - The resource read middleware function to register + * @returns This Client instance for chaining + */ + useResourceReadMiddleware(middleware: ResourceReadMiddleware): this { + this._middleware.useResourceReadMiddleware(middleware); + return this; + } + + /** + * Registers middleware specifically for sampling requests. + * + * @param middleware - The sampling middleware function to register + * @returns This Client instance for chaining + */ + useSamplingMiddleware(middleware: SamplingMiddleware): this { + this._middleware.useSamplingMiddleware(middleware); + return this; + } + + /** + * Registers middleware specifically for elicitation requests. + * + * @param middleware - The elicitation middleware function to register + * @returns This Client instance for chaining + */ + useElicitationMiddleware(middleware: ElicitationMiddleware): this { + this._middleware.useElicitationMiddleware(middleware); + return this; + } + + /** + * Creates a new ClientBuilder for fluent configuration. + * + * @example + * ```typescript + * const client = Client.builder() + * .name('my-client') + * .version('1.0.0') + * .capabilities({ sampling: {} }) + * .onSamplingRequest(async (params) => { + * // Handle sampling request from server + * return { role: 'assistant', content: { type: 'text', text: '...' } }; + * }) + * .build(); + * ``` + */ + static builder(): ClientBuilder { + return new ClientBuilder(); + } + + /** + * Creates a Client from a ClientBuilderResult configuration. + * + * @param result - The result from ClientBuilder.build() + * @returns A configured Client instance + */ + static fromBuilderResult(result: ClientBuilderResult): Client { + const client = new Client(result.clientInfo, { + capabilities: result.capabilities, + enforceStrictCapabilities: result.options.enforceStrictCapabilities, + jsonSchemaValidator: result.jsonSchemaValidator, + listChanged: result.listChanged + }); + + // Register handlers + if (result.handlers.sampling) { + client.setRequestHandler( + CreateMessageRequestSchema, + result.handlers.sampling as Parameters[1] + ); + } + + if (result.handlers.elicitation) { + client.setRequestHandler(ElicitRequestSchema, result.handlers.elicitation as Parameters[1]); + } + + if (result.handlers.rootsList) { + client.setRequestHandler(ListRootsRequestSchema, result.handlers.rootsList as Parameters[1]); + } + + // Wire up error handlers to Protocol events + if (result.errorHandlers.onError || result.errorHandlers.onProtocolError) { + client.events.on('error', ({ error, context }) => { + const errorContext = { + type: (context as 'sampling' | 'elicitation' | 'rootsList' | 'protocol') || 'protocol', + method: context || 'unknown', + requestId: 'unknown' + }; + + // Call the appropriate error handler based on context + if (context === 'protocol' && result.errorHandlers.onProtocolError) { + (result.errorHandlers.onProtocolError as (error: Error, ctx: typeof errorContext) => void)(error, errorContext); + } else if (result.errorHandlers.onError) { + (result.errorHandlers.onError as (error: Error, ctx: typeof errorContext) => void)(error, errorContext); + } + }); + } + + // Apply middleware from builder + for (const middleware of result.middleware.universal) { + client.useMiddleware(middleware); + } + for (const middleware of result.middleware.outgoing) { + client.useOutgoingMiddleware(middleware); + } + for (const middleware of result.middleware.incoming) { + client.useIncomingMiddleware(middleware); + } + for (const middleware of result.middleware.toolCall) { + client.useToolCallMiddleware(middleware); + } + for (const middleware of result.middleware.resourceRead) { + client.useResourceReadMiddleware(middleware); + } + for (const middleware of result.middleware.sampling) { + client.useSamplingMiddleware(middleware); + } + for (const middleware of result.middleware.elicitation) { + client.useElicitationMiddleware(middleware); + } + + return client; + } + /** * Set up handlers for list changed notifications based on config and server capabilities. * This should only be called after initialization when server capabilities are known. @@ -495,14 +686,11 @@ export class Client< protected createRequestContext(args: { request: JSONRPCRequest; - taskStore: TaskStore | undefined; - relatedTaskId: string | undefined; - taskCreationParams: TaskCreationParams | undefined; abortController: AbortController; capturedTransport: Transport | undefined; extra?: MessageExtraInfo; }): ContextInterface { - const { request, taskStore, relatedTaskId, taskCreationParams, abortController, capturedTransport, extra } = args; + const { request, abortController, capturedTransport, extra } = args; const sessionId = capturedTransport?.sessionId; // Build the MCP context using the helper from Protocol @@ -514,22 +702,12 @@ export class Client< authInfo: extra?.authInfo }; - // Build the task context using the helper from Protocol - const taskCtx: TaskContext | undefined = this.buildTaskContext({ - taskStore, - request, - sessionId, - relatedTaskId, - taskCreationParams - }); - - // Return a ClientContext instance + // Return a ClientContext instance (task context is added by plugins if needed) return new ClientContext({ client: this, request, mcpContext, - requestCtx, - task: taskCtx + requestCtx }); } @@ -980,4 +1158,91 @@ export class Client< async sendRootsListChanged() { return this.notification({ method: 'notifications/roots/list_changed' }); } + + /** + * Registers a handler for roots/list requests from the server. + * + * @param handler - Handler function that returns the list of roots + * @returns This Client instance for chaining + * + * @example + * ```typescript + * client.onRootsList(async () => ({ + * roots: [ + * { uri: 'file:///workspace', name: 'Workspace' } + * ] + * })); + * ``` + */ + onRootsList( + handler: ( + ctx: ContextInterface + ) => ListRootsResult | Promise + ): this { + this.setRequestHandler(ListRootsRequestSchema, (_request, ctx) => handler(ctx)); + return this; + } + + /** + * Registers an error handler for application errors. + * + * The handler receives the error and a context object with information about where + * the error occurred. It can optionally return a custom error response. + * + * @param handler - Error handler function + * @returns Unsubscribe function + * + * @example + * ```typescript + * const unsubscribe = client.onError((error, ctx) => { + * console.error(`Error in ${ctx.type}: ${error.message}`); + * // Optionally return a custom error response + * return { + * code: -32000, + * message: `Application error: ${error.message}`, + * data: { type: ctx.type } + * }; + * }); + * ``` + */ + onError(handler: OnErrorHandler): () => void { + return this.events.on('error', ({ error, context }) => { + const errorContext: ErrorContext = { + type: (context as 'sampling' | 'elicitation' | 'rootsList' | 'protocol') || 'protocol', + method: context || 'unknown', + requestId: 'unknown' + }; + handler(error, errorContext); + }); + } + + /** + * Registers an error handler for protocol errors. + * + * The handler receives the error and a context object. It can optionally return + * a custom error response (but cannot change the error code). + * + * @param handler - Error handler function + * @returns Unsubscribe function + * + * @example + * ```typescript + * const unsubscribe = client.onProtocolError((error, ctx) => { + * console.error(`Protocol error in ${ctx.method}: ${error.message}`); + * return { message: `Protocol error: ${error.message}` }; + * }); + * ``` + */ + onProtocolError(handler: OnProtocolErrorHandler): () => void { + return this.events.on('error', ({ error, context }) => { + if (context === 'protocol') { + const errorContext: ErrorContext = { + type: 'protocol', + method: context || 'unknown', + requestId: 'unknown' + }; + handler(error, errorContext); + } + }); + } } diff --git a/packages/client/src/client/context.ts b/packages/client/src/client/context.ts index 48965db5e..108a89e0f 100644 --- a/packages/client/src/client/context.ts +++ b/packages/client/src/client/context.ts @@ -54,7 +54,7 @@ export class ClientContext< request: JSONRPCRequest; mcpContext: McpContext; requestCtx: ClientRequestContext; - task: TaskContext | undefined; + task?: TaskContext; }) { super({ request: args.request, diff --git a/packages/client/src/client/middleware.ts b/packages/client/src/client/middleware.ts index 3fd52e41a..6525fb32a 100644 --- a/packages/client/src/client/middleware.ts +++ b/packages/client/src/client/middleware.ts @@ -1,8 +1,28 @@ -import type { FetchLike } from '@modelcontextprotocol/core'; +/** + * Client Middleware System + * + * This module provides two distinct middleware systems: + * + * 1. Fetch Middleware - For HTTP/fetch level operations (OAuth, logging, etc.) + * 2. MCP Client Middleware - For MCP protocol level operations (tool calls, sampling, etc.) + */ + +import type { + AuthInfo, + CallToolResult, + CreateMessageResult, + ElicitResult, + FetchLike, + ReadResourceResult +} from '@modelcontextprotocol/core'; import type { OAuthClientProvider } from './auth.js'; import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js'; +// ═══════════════════════════════════════════════════════════════════════════ +// Fetch Middleware (HTTP Level) +// ═══════════════════════════════════════════════════════════════════════════ + /** * Middleware function that wraps and enhances fetch functionality. * Takes a fetch handler and returns an enhanced fetch handler. @@ -320,3 +340,391 @@ export const applyMiddlewares = (...middleware: Middleware[]): Middleware => { export const createMiddleware = (handler: (next: FetchLike, input: string | URL, init?: RequestInit) => Promise): Middleware => { return next => (input, init) => handler(next, input as string | URL, init); }; + +// ═══════════════════════════════════════════════════════════════════════════ +// MCP Client Middleware (Protocol Level) +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Base context shared by all MCP client middleware + */ +interface BaseClientContext { + /** The request ID */ + requestId: string; + /** Abort signal for cancellation */ + signal: AbortSignal; +} + +/** + * Context for outgoing requests (client → server) + */ +export interface OutgoingContext extends BaseClientContext { + direction: 'outgoing'; + /** The type of outgoing request */ + type: + | 'callTool' + | 'readResource' + | 'getPrompt' + | 'listTools' + | 'listResources' + | 'listPrompts' + | 'ping' + | 'complete' + | 'initialize' + | 'other'; + /** The JSON-RPC method name */ + method: string; + /** The request parameters */ + params: unknown; +} + +/** + * Context for incoming requests (server → client) + */ +export interface IncomingContext extends BaseClientContext { + direction: 'incoming'; + /** The type of incoming request */ + type: 'sampling' | 'elicitation' | 'rootsList' | 'other'; + /** The JSON-RPC method name */ + method: string; + /** The request parameters */ + params: unknown; + /** Authentication info if available */ + authInfo?: AuthInfo; +} + +/** + * Union type for all client contexts + */ +export type ClientContext = OutgoingContext | IncomingContext; + +// ═══════════════════════════════════════════════════════════════════════════ +// Type-Specific Contexts +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Context for tool call requests + */ +export interface ToolCallContext extends OutgoingContext { + type: 'callTool'; + params: { + name: string; + arguments?: unknown; + }; +} + +/** + * Context for resource read requests + */ +export interface ResourceReadContext extends OutgoingContext { + type: 'readResource'; + params: { + uri: string; + }; +} + +/** + * Context for sampling requests (server → client) + */ +export interface SamplingContext extends IncomingContext { + type: 'sampling'; + params: { + messages: unknown[]; + maxTokens?: number; + [key: string]: unknown; + }; +} + +/** + * Context for elicitation requests (server → client) + */ +export interface ElicitationContext extends IncomingContext { + type: 'elicitation'; + params: { + message?: string; + [key: string]: unknown; + }; +} + +// ═══════════════════════════════════════════════════════════════════════════ +// MCP Middleware Types +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Next function for MCP client middleware + */ +export type ClientNextFn = (modifiedParams?: unknown) => Promise; + +/** + * Universal middleware for all MCP client requests + */ +export type ClientMiddleware = (ctx: ClientContext, next: ClientNextFn) => Promise; + +/** + * Middleware for outgoing requests only + */ +export type OutgoingMiddleware = (ctx: OutgoingContext, next: ClientNextFn) => Promise; + +/** + * Middleware for incoming requests only + */ +export type IncomingMiddleware = (ctx: IncomingContext, next: ClientNextFn) => Promise; + +/** + * Middleware specifically for tool calls + */ +export type ToolCallMiddleware = (ctx: ToolCallContext, next: ClientNextFn) => Promise; + +/** + * Middleware specifically for resource reads + */ +export type ResourceReadMiddleware = (ctx: ResourceReadContext, next: ClientNextFn) => Promise; + +/** + * Middleware specifically for sampling requests + */ +export type SamplingMiddleware = (ctx: SamplingContext, next: ClientNextFn) => Promise; + +/** + * Middleware specifically for elicitation requests + */ +export type ElicitationMiddleware = (ctx: ElicitationContext, next: ClientNextFn) => Promise; + +// ═══════════════════════════════════════════════════════════════════════════ +// MCP Middleware Manager +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Manages MCP middleware registration and execution for Client. + */ +export class ClientMiddlewareManager { + private _universalMiddleware: ClientMiddleware[] = []; + private _outgoingMiddleware: OutgoingMiddleware[] = []; + private _incomingMiddleware: IncomingMiddleware[] = []; + private _toolCallMiddleware: ToolCallMiddleware[] = []; + private _resourceReadMiddleware: ResourceReadMiddleware[] = []; + private _samplingMiddleware: SamplingMiddleware[] = []; + private _elicitationMiddleware: ElicitationMiddleware[] = []; + + /** + * Registers universal middleware that runs for all requests. + */ + useMiddleware(middleware: ClientMiddleware): this { + this._universalMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware for outgoing requests only. + */ + useOutgoingMiddleware(middleware: OutgoingMiddleware): this { + this._outgoingMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware for incoming requests only. + */ + useIncomingMiddleware(middleware: IncomingMiddleware): this { + this._incomingMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware specifically for tool calls. + */ + useToolCallMiddleware(middleware: ToolCallMiddleware): this { + this._toolCallMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware specifically for resource reads. + */ + useResourceReadMiddleware(middleware: ResourceReadMiddleware): this { + this._resourceReadMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware specifically for sampling requests. + */ + useSamplingMiddleware(middleware: SamplingMiddleware): this { + this._samplingMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware specifically for elicitation requests. + */ + useElicitationMiddleware(middleware: ElicitationMiddleware): this { + this._elicitationMiddleware.push(middleware); + return this; + } + + /** + * Executes the middleware chain for an outgoing tool call. + */ + async executeToolCall(ctx: ToolCallContext, handler: (params?: unknown) => Promise): Promise { + return this._executeChain( + ctx, + [ + ...this._adaptToTyped(this._universalMiddleware), + ...this._adaptToTyped(this._outgoingMiddleware as unknown as ClientMiddleware[]), + ...this._toolCallMiddleware + ], + handler + ); + } + + /** + * Executes the middleware chain for an outgoing resource read. + */ + async executeResourceRead( + ctx: ResourceReadContext, + handler: (params?: unknown) => Promise + ): Promise { + return this._executeChain( + ctx, + [ + ...this._adaptToTyped(this._universalMiddleware), + ...this._adaptToTyped(this._outgoingMiddleware as unknown as ClientMiddleware[]), + ...this._resourceReadMiddleware + ], + handler + ); + } + + /** + * Executes the middleware chain for an incoming sampling request. + */ + async executeSampling(ctx: SamplingContext, handler: (params?: unknown) => Promise): Promise { + return this._executeChain( + ctx, + [ + ...this._adaptToTyped(this._universalMiddleware), + ...this._adaptToTyped(this._incomingMiddleware as unknown as ClientMiddleware[]), + ...this._samplingMiddleware + ], + handler + ); + } + + /** + * Executes the middleware chain for an incoming elicitation request. + */ + async executeElicitation(ctx: ElicitationContext, handler: (params?: unknown) => Promise): Promise { + return this._executeChain( + ctx, + [ + ...this._adaptToTyped(this._universalMiddleware), + ...this._adaptToTyped(this._incomingMiddleware as unknown as ClientMiddleware[]), + ...this._elicitationMiddleware + ], + handler + ); + } + + /** + * Executes the middleware chain for a generic outgoing request. + */ + async executeOutgoing(ctx: OutgoingContext, handler: (params?: unknown) => Promise): Promise { + return this._executeChain( + ctx, + [ + ...this._adaptToTyped(this._universalMiddleware), + ...this._adaptToTyped(this._outgoingMiddleware as unknown as ClientMiddleware[]) + ], + handler + ); + } + + /** + * Executes the middleware chain for a generic incoming request. + */ + async executeIncoming(ctx: IncomingContext, handler: (params?: unknown) => Promise): Promise { + return this._executeChain( + ctx, + [ + ...this._adaptToTyped(this._universalMiddleware), + ...this._adaptToTyped(this._incomingMiddleware as unknown as ClientMiddleware[]) + ], + handler + ); + } + + /** + * Checks if any middleware is registered. + */ + hasMiddleware(): boolean { + return ( + this._universalMiddleware.length > 0 || + this._outgoingMiddleware.length > 0 || + this._incomingMiddleware.length > 0 || + this._toolCallMiddleware.length > 0 || + this._resourceReadMiddleware.length > 0 || + this._samplingMiddleware.length > 0 || + this._elicitationMiddleware.length > 0 + ); + } + + /** + * Clears all registered middleware. + */ + clear(): void { + this._universalMiddleware = []; + this._outgoingMiddleware = []; + this._incomingMiddleware = []; + this._toolCallMiddleware = []; + this._resourceReadMiddleware = []; + this._samplingMiddleware = []; + this._elicitationMiddleware = []; + } + + /** + * Adapts generic middleware to a typed middleware. + */ + private _adaptToTyped( + middlewares: ClientMiddleware[] + ): Array<(ctx: TCtx, next: ClientNextFn) => Promise> { + return middlewares.map(mw => { + return async (ctx: TCtx, next: ClientNextFn): Promise => { + return (await mw(ctx, next as ClientNextFn)) as TResult; + }; + }); + } + + /** + * Executes a chain of middleware. + */ + private async _executeChain( + ctx: TCtx, + middlewares: Array<(ctx: TCtx, next: ClientNextFn) => Promise>, + handler: (params?: unknown) => Promise + ): Promise { + let index = -1; + let currentParams: unknown = ctx.params; + + const dispatch = async (i: number, params?: unknown): Promise => { + if (i <= index) { + throw new Error('next() called multiple times'); + } + index = i; + if (params !== undefined) { + currentParams = params; + } + + if (i >= middlewares.length) { + return handler(currentParams); + } + + const middleware = middlewares[i]; + if (!middleware) { + return handler(currentParams); + } + return middleware(ctx, (modifiedParams?: unknown) => dispatch(i + 1, modifiedParams)); + }; + + return dispatch(0); + } +} \ No newline at end of file diff --git a/packages/client/src/experimental/tasks/client.ts b/packages/client/src/experimental/tasks/client.ts index df57e91a4..b4f62b6c7 100644 --- a/packages/client/src/experimental/tasks/client.ts +++ b/packages/client/src/experimental/tasks/client.ts @@ -20,7 +20,7 @@ import type { Result, SchemaOutput } from '@modelcontextprotocol/core'; -import { CallToolResultSchema, ErrorCode, McpError } from '@modelcontextprotocol/core'; +import { CallToolResultSchema, ErrorCode, McpError, TaskClientPlugin } from '@modelcontextprotocol/core'; import type { Client } from '../../client/client.js'; @@ -56,6 +56,20 @@ export class ExperimentalClientTasks< > { constructor(private readonly _client: Client) {} + /** + * Gets the TaskClientPlugin, throwing if not installed. + */ + private _getTaskClient(): TaskClientPlugin { + const plugin = this._client.getPlugin(TaskClientPlugin); + if (!plugin) { + throw new McpError( + ErrorCode.InternalError, + 'TaskClientPlugin not installed. Use client.usePlugin(new TaskClientPlugin()) first.' + ); + } + return plugin; + } + /** * Calls a tool and returns an AsyncGenerator that yields response messages. * The generator is guaranteed to end with either a 'result' or 'error' message. @@ -179,9 +193,7 @@ export class ExperimentalClientTasks< * @experimental */ async getTask(taskId: string, options?: RequestOptions): Promise { - // Delegate to the client's underlying Protocol method - type ClientWithGetTask = { getTask(params: { taskId: string }, options?: RequestOptions): Promise }; - return (this._client as unknown as ClientWithGetTask).getTask({ taskId }, options); + return this._getTaskClient().getTask({ taskId }, options); } /** @@ -195,16 +207,10 @@ export class ExperimentalClientTasks< * @experimental */ async getTaskResult(taskId: string, resultSchema?: T, options?: RequestOptions): Promise> { - // Delegate to the client's underlying Protocol method - return ( - this._client as unknown as { - getTaskResult: ( - params: { taskId: string }, - resultSchema?: U, - options?: RequestOptions - ) => Promise>; - } - ).getTaskResult({ taskId }, resultSchema, options); + if (!resultSchema) { + throw new McpError(ErrorCode.InvalidParams, 'resultSchema is required'); + } + return this._getTaskClient().getTaskResult({ taskId }, resultSchema, options); } /** @@ -217,12 +223,7 @@ export class ExperimentalClientTasks< * @experimental */ async listTasks(cursor?: string, options?: RequestOptions): Promise { - // Delegate to the client's underlying Protocol method - return ( - this._client as unknown as { - listTasks: (params?: { cursor?: string }, options?: RequestOptions) => Promise; - } - ).listTasks(cursor ? { cursor } : undefined, options); + return this._getTaskClient().listTasks(cursor ? { cursor } : undefined, options); } /** @@ -234,12 +235,7 @@ export class ExperimentalClientTasks< * @experimental */ async cancelTask(taskId: string, options?: RequestOptions): Promise { - // Delegate to the client's underlying Protocol method - return ( - this._client as unknown as { - cancelTask: (params: { taskId: string }, options?: RequestOptions) => Promise; - } - ).cancelTask({ taskId }, options); + return this._getTaskClient().cancelTask({ taskId }, options); } /** diff --git a/packages/client/src/index.ts b/packages/client/src/index.ts index 787cfd2f0..71674f898 100644 --- a/packages/client/src/index.ts +++ b/packages/client/src/index.ts @@ -1,5 +1,6 @@ export * from './client/auth.js'; export * from './client/authExtensions.js'; +export * from './client/builder.js'; export * from './client/client.js'; export * from './client/middleware.js'; export * from './client/sse.js'; diff --git a/packages/core/src/errors.ts b/packages/core/src/errors.ts new file mode 100644 index 000000000..581cae058 --- /dev/null +++ b/packages/core/src/errors.ts @@ -0,0 +1,400 @@ +/** + * MCP SDK Error Hierarchy + * + * This module defines a comprehensive error hierarchy for the MCP SDK: + * + * 1. Protocol Errors (McpError subclasses) - Errors that cross the wire as JSON-RPC errors + * - ProtocolError: SDK-generated errors with spec-mandated codes (code is locked) + * - ApplicationError: User handler errors wrapped by SDK (code can be customized) + * + * 2. SDK Errors (SdkError subclasses) - Local errors that don't cross the wire + * - StateError: Wrong SDK state (not connected, already connected, etc.) + * - CapabilityError: Missing required capability + * - TransportError: Network/connection issues + * - ValidationError: Local schema validation issues + * + * 3. OAuth Errors - Kept in auth/errors.ts (unchanged) + */ + +import { ErrorCode } from './types/types.js'; + +// ═══════════════════════════════════════════════════════════════════════════ +// SDK Error Codes (for local errors that don't cross the wire) +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Error codes for local SDK errors (not transmitted over JSON-RPC) + */ +export enum SdkErrorCode { + // State errors + NOT_CONNECTED = 'NOT_CONNECTED', + ALREADY_CONNECTED = 'ALREADY_CONNECTED', + INVALID_STATE = 'INVALID_STATE', + REGISTRATION_AFTER_CONNECT = 'REGISTRATION_AFTER_CONNECT', + + // Capability errors + CAPABILITY_NOT_SUPPORTED = 'CAPABILITY_NOT_SUPPORTED', + + // Transport errors + CONNECTION_FAILED = 'CONNECTION_FAILED', + CONNECTION_LOST = 'CONNECTION_LOST', + CONNECTION_TIMEOUT = 'CONNECTION_TIMEOUT', + SEND_FAILED = 'SEND_FAILED', + + // Validation errors + INVALID_SCHEMA = 'INVALID_SCHEMA', + INVALID_REQUEST = 'INVALID_REQUEST', + INVALID_RESPONSE = 'INVALID_RESPONSE' +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Protocol Error Subclasses (McpError hierarchy - crosses the wire) +// ═══════════════════════════════════════════════════════════════════════════ + +// Note: McpError is defined in types/types.ts and re-exported from there. +// These subclasses provide more specific error types. + +/** + * Protocol-level errors generated by the SDK for protocol violations. + * The error code is LOCKED and cannot be changed in onProtocolError handlers. + * + * These errors are for spec-mandated situations like: + * - Parse errors (-32700) + * - Invalid request (-32600) + * - Method not found (-32601) + * - Invalid params (-32602) + */ +export class ProtocolError extends Error { + /** + * Indicates this is a protocol-level error with a locked code + */ + readonly isProtocolLevel = true as const; + + constructor( + public readonly code: number, + message: string, + public readonly data?: unknown + ) { + super(`MCP protocol error ${code}: ${message}`); + this.name = 'ProtocolError'; + } + + /** + * Creates a parse error (-32700) + */ + static parseError(message: string = 'Parse error', data?: unknown): ProtocolError { + return new ProtocolError(ErrorCode.ParseError, message, data); + } + + /** + * Creates an invalid request error (-32600) + */ + static invalidRequest(message: string = 'Invalid request', data?: unknown): ProtocolError { + return new ProtocolError(ErrorCode.InvalidRequest, message, data); + } + + /** + * Creates a method not found error (-32601) + */ + static methodNotFound(method: string, data?: unknown): ProtocolError { + return new ProtocolError(ErrorCode.MethodNotFound, `Method not found: ${method}`, data); + } + + /** + * Creates an invalid params error (-32602) + */ + static invalidParams(message: string = 'Invalid params', data?: unknown): ProtocolError { + return new ProtocolError(ErrorCode.InvalidParams, message, data); + } +} + +/** + * Application-level errors from user handler code, wrapped by the SDK. + * The error code CAN be customized in onError handlers. + * + * Default code is InternalError (-32603), but can be changed. + */ +export class ApplicationError extends Error { + /** + * Indicates this is an application-level error with a customizable code + */ + readonly isProtocolLevel = false as const; + + constructor( + public code: number = ErrorCode.InternalError, + message: string, + public readonly data?: unknown, + public override readonly cause?: Error + ) { + super(`MCP application error ${code}: ${message}`); + this.name = 'ApplicationError'; + if (cause) { + this.cause = cause; + } + } + + /** + * Wraps any error as an ApplicationError + */ + static wrap(error: unknown, code: number = ErrorCode.InternalError): ApplicationError { + if (error instanceof ApplicationError) { + return error; + } + if (error instanceof Error) { + return new ApplicationError(code, error.message, undefined, error); + } + return new ApplicationError(code, String(error)); + } + + /** + * Creates an internal error (-32603) + */ + static internalError(message: string, data?: unknown, cause?: Error): ApplicationError { + return new ApplicationError(ErrorCode.InternalError, message, data, cause); + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// SDK Error Hierarchy (local errors - don't cross the wire) +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Base class for local SDK errors that don't cross the wire. + * These are thrown locally and should be caught by the SDK user. + */ +export abstract class SdkError extends Error { + /** + * The SDK error code for programmatic handling + */ + abstract readonly code: SdkErrorCode; + + /** + * Whether this error is potentially recoverable + */ + readonly recoverable: boolean = false; + + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * Errors related to incorrect SDK state. + * Examples: "Not connected", "Already connected", "Cannot register after connecting" + */ +export class StateError extends SdkError { + readonly code: SdkErrorCode; + + constructor( + code: + | SdkErrorCode.NOT_CONNECTED + | SdkErrorCode.ALREADY_CONNECTED + | SdkErrorCode.INVALID_STATE + | SdkErrorCode.REGISTRATION_AFTER_CONNECT, + message: string + ) { + super(message); + this.code = code; + } + + /** + * Creates a "not connected" error + */ + static notConnected(operation: string = 'perform this operation'): StateError { + return new StateError(SdkErrorCode.NOT_CONNECTED, `Cannot ${operation}: not connected`); + } + + /** + * Creates an "already connected" error + */ + static alreadyConnected(): StateError { + return new StateError(SdkErrorCode.ALREADY_CONNECTED, 'Already connected'); + } + + /** + * Creates an "invalid state" error + */ + static invalidState(message: string): StateError { + return new StateError(SdkErrorCode.INVALID_STATE, message); + } + + /** + * Creates a "registration after connect" error + */ + static registrationAfterConnect(type: string): StateError { + return new StateError(SdkErrorCode.REGISTRATION_AFTER_CONNECT, `Cannot register ${type} after connecting`); + } +} + +/** + * Errors related to missing or unsupported capabilities. + * Example: "Server does not support X (required for Y)" + */ +export class CapabilityError extends SdkError { + readonly code = SdkErrorCode.CAPABILITY_NOT_SUPPORTED as const; + + constructor( + public readonly capability: string, + public readonly requiredFor?: string + ) { + const message = requiredFor + ? `Capability '${capability}' is not supported (required for ${requiredFor})` + : `Capability '${capability}' is not supported`; + super(message); + } + + /** + * Creates a capability error for a missing server capability + */ + static serverDoesNotSupport(capability: string, requiredFor?: string): CapabilityError { + return new CapabilityError(capability, requiredFor); + } + + /** + * Creates a capability error for a missing client capability + */ + static clientDoesNotSupport(capability: string, requiredFor?: string): CapabilityError { + return new CapabilityError(capability, requiredFor); + } +} + +/** + * Errors related to transport/network issues. + * Examples: Connection failed, timeout, connection lost + */ +export class TransportError extends SdkError { + readonly code: SdkErrorCode; + override readonly recoverable: boolean; + + constructor( + code: SdkErrorCode.CONNECTION_FAILED | SdkErrorCode.CONNECTION_LOST | SdkErrorCode.CONNECTION_TIMEOUT | SdkErrorCode.SEND_FAILED, + message: string, + public override readonly cause?: Error + ) { + super(message); + this.code = code; + // Connection lost and timeout are potentially recoverable via retry + this.recoverable = code === SdkErrorCode.CONNECTION_LOST || code === SdkErrorCode.CONNECTION_TIMEOUT; + } + + /** + * Creates a connection failed error + */ + static connectionFailed(message: string = 'Connection failed', cause?: Error): TransportError { + return new TransportError(SdkErrorCode.CONNECTION_FAILED, message, cause); + } + + /** + * Creates a connection lost error + */ + static connectionLost(message: string = 'Connection lost', cause?: Error): TransportError { + const error = new TransportError(SdkErrorCode.CONNECTION_LOST, message, cause); + return error; + } + + /** + * Creates a connection timeout error + */ + static connectionTimeout(timeoutMs: number, cause?: Error): TransportError { + return new TransportError(SdkErrorCode.CONNECTION_TIMEOUT, `Connection timed out after ${timeoutMs}ms`, cause); + } + + /** + * Creates a send failed error + */ + static sendFailed(message: string = 'Failed to send message', cause?: Error): TransportError { + return new TransportError(SdkErrorCode.SEND_FAILED, message, cause); + } +} + +/** + * Errors related to local schema/validation issues (before sending). + * Examples: "Schema is missing a method literal", "Invalid request format" + */ +export class ValidationError extends SdkError { + readonly code: SdkErrorCode; + + constructor( + code: SdkErrorCode.INVALID_SCHEMA | SdkErrorCode.INVALID_REQUEST | SdkErrorCode.INVALID_RESPONSE, + message: string, + public readonly details?: unknown + ) { + super(message); + this.code = code; + } + + /** + * Creates an invalid schema error + */ + static invalidSchema(message: string, details?: unknown): ValidationError { + return new ValidationError(SdkErrorCode.INVALID_SCHEMA, message, details); + } + + /** + * Creates an invalid request error (local validation) + */ + static invalidRequest(message: string, details?: unknown): ValidationError { + return new ValidationError(SdkErrorCode.INVALID_REQUEST, message, details); + } + + /** + * Creates an invalid response error (local validation) + */ + static invalidResponse(message: string, details?: unknown): ValidationError { + return new ValidationError(SdkErrorCode.INVALID_RESPONSE, message, details); + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Type Guards +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Type guard to check if an error is a ProtocolError + */ +export function isProtocolError(error: unknown): error is ProtocolError { + return error instanceof ProtocolError; +} + +/** + * Type guard to check if an error is an ApplicationError + */ +export function isApplicationError(error: unknown): error is ApplicationError { + return error instanceof ApplicationError; +} + +/** + * Type guard to check if an error is an SdkError + */ +export function isSdkError(error: unknown): error is SdkError { + return error instanceof SdkError; +} + +/** + * Type guard to check if an error is a StateError + */ +export function isStateError(error: unknown): error is StateError { + return error instanceof StateError; +} + +/** + * Type guard to check if an error is a CapabilityError + */ +export function isCapabilityError(error: unknown): error is CapabilityError { + return error instanceof CapabilityError; +} + +/** + * Type guard to check if an error is a TransportError + */ +export function isTransportError(error: unknown): error is TransportError { + return error instanceof TransportError; +} + +/** + * Type guard to check if an error is a ValidationError + */ +export function isValidationError(error: unknown): error is ValidationError { + return error instanceof ValidationError; +} diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index ca5875421..e111b9956 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -1,15 +1,25 @@ export * from './auth/errors.js'; +export * from './errors.js'; export * from './shared/auth.js'; export * from './shared/authUtils.js'; export * from './shared/context.js'; +export * from './shared/events.js'; +export * from './shared/handlerRegistry.js'; export * from './shared/metadataUtils.js'; +export * from './shared/plugin.js'; +export * from './shared/pluginContext.js'; +export * from './shared/progressManager.js'; export * from './shared/protocol.js'; export * from './shared/responseMessage.js'; export * from './shared/stdio.js'; +export * from './shared/taskClientPlugin.js'; +export * from './shared/taskPlugin.js'; +export * from './shared/timeoutManager.js'; export * from './shared/toolNameValidation.js'; export * from './shared/transport.js'; export * from './shared/uriTemplate.js'; export * from './types/types.js'; +export * from './util/content.js'; export * from './util/inMemory.js'; export * from './util/zodCompat.js'; export * from './util/zodJsonSchemaCompat.js'; diff --git a/packages/core/src/shared/context.ts b/packages/core/src/shared/context.ts index 4e117bcee..df05d906b 100644 --- a/packages/core/src/shared/context.ts +++ b/packages/core/src/shared/context.ts @@ -1,8 +1,23 @@ import type { RequestTaskStoreInterface } from '../experimental/requestTaskStore.js'; -import type { AuthInfo, JSONRPCRequest, Notification, Request, RequestId, RequestMeta, Result } from '../types/types.js'; +import type { + AuthInfo, + JSONRPCRequest, + Notification, + RelatedTaskMetadata, + Request, + RequestId, + RequestMeta, + Result +} from '../types/types.js'; import type { AnySchema, SchemaOutput } from '../util/zodCompat.js'; import type { NotificationOptions, Protocol, RequestOptions } from './protocol.js'; +/** + * Internal type for options that may include task-related fields. + * Used by context methods that need to set relatedTask. + */ +type OptionsWithTask = { relatedTask?: RelatedTaskMetadata }; + /** * MCP-level context for a request being handled. * Contains information about the JSON-RPC request and session. @@ -114,8 +129,9 @@ export interface BaseContextArgs => { - const notificationOptions: NotificationOptions = { relatedRequestId: this.mcpCtx.requestId }; + const notificationOptions: NotificationOptions & OptionsWithTask = { relatedRequestId: this.mcpCtx.requestId }; // Only set relatedTask if there's a valid (non-empty) task ID // Empty task ID means no task has been created yet or task queuing isn't applicable @@ -194,7 +211,7 @@ export abstract class BaseContext< resultSchema: U, options?: RequestOptions ): Promise> => { - const requestOptions: RequestOptions = { ...options, relatedRequestId: this.mcpCtx.requestId }; + const requestOptions: RequestOptions & OptionsWithTask = { ...options, relatedRequestId: this.mcpCtx.requestId }; // Only set relatedTask if there's a valid (non-empty) task ID // Empty task ID means no task has been created yet or task queuing isn't applicable diff --git a/packages/core/src/shared/events.ts b/packages/core/src/shared/events.ts new file mode 100644 index 000000000..60adc039d --- /dev/null +++ b/packages/core/src/shared/events.ts @@ -0,0 +1,274 @@ +/** + * Event Emitter System + * + * A lightweight, type-safe event emitter for SDK observability. + * + * Design decisions: + * - Custom implementation instead of Node's EventEmitter for cross-platform compatibility + * - Works in Node.js, browsers, and edge runtimes + * - Type-safe event names and payloads + * - Modern API with unsubscribe function returned from `on()` + */ + +/** + * Type-safe event emitter interface. + * Events is a record mapping event names to their payload types. + */ +export interface McpEventEmitter> { + /** + * Subscribe to an event. + * @param event - The event name + * @param listener - The callback to invoke when the event is emitted + * @returns An unsubscribe function + */ + on(event: K, listener: (data: Events[K]) => void): () => void; + + /** + * Subscribe to an event for a single occurrence. + * @param event - The event name + * @param listener - The callback to invoke when the event is emitted + * @returns An unsubscribe function + */ + once(event: K, listener: (data: Events[K]) => void): () => void; + + /** + * Unsubscribe from an event. + * @param event - The event name + * @param listener - The callback to remove + */ + off(event: K, listener: (data: Events[K]) => void): void; + + /** + * Emit an event with data. + * @param event - The event name + * @param data - The event payload + */ + emit(event: K, data: Events[K]): void; +} + +/** + * Type-safe event emitter implementation. + * Provides a minimal, cross-platform event system. + */ +export class TypedEventEmitter> implements McpEventEmitter { + private _listeners = new Map void>>(); + + /** + * Subscribe to an event. + * + * @param event - The event name + * @param listener - The callback to invoke when the event is emitted + * @returns An unsubscribe function + * + * @example + * ```typescript + * const unsubscribe = emitter.on('connection:opened', ({ sessionId }) => { + * console.log(`Connected: ${sessionId}`); + * }); + * + * // Later, to unsubscribe: + * unsubscribe(); + * ``` + */ + on(event: K, listener: (data: Events[K]) => void): () => void { + if (!this._listeners.has(event)) { + this._listeners.set(event, new Set()); + } + const listeners = this._listeners.get(event)!; + listeners.add(listener as (data: unknown) => void); + + // Return unsubscribe function + return () => this.off(event, listener); + } + + /** + * Subscribe to an event for a single occurrence. + * The listener is automatically removed after the first invocation. + * + * @param event - The event name + * @param listener - The callback to invoke when the event is emitted + * @returns An unsubscribe function + */ + once(event: K, listener: (data: Events[K]) => void): () => void { + const wrapper = (data: Events[K]): void => { + this.off(event, wrapper); + listener(data); + }; + return this.on(event, wrapper); + } + + /** + * Unsubscribe from an event. + * + * @param event - The event name + * @param listener - The callback to remove + */ + off(event: K, listener: (data: Events[K]) => void): void { + const listeners = this._listeners.get(event); + if (listeners) { + listeners.delete(listener as (data: unknown) => void); + if (listeners.size === 0) { + this._listeners.delete(event); + } + } + } + + /** + * Emit an event with data. + * All registered listeners for the event will be invoked synchronously. + * + * @param event - The event name + * @param data - The event payload + */ + emit(event: K, data: Events[K]): void { + const listeners = this._listeners.get(event); + if (listeners) { + // Create a copy to allow listeners to unsubscribe during iteration + for (const listener of listeners) { + try { + listener(data); + } catch { + // Silently ignore listener errors to prevent one listener + // from breaking others. Errors should be handled by the listener. + } + } + } + } + + /** + * Check if any listeners are registered for an event. + * + * @param event - The event name + * @returns true if there are listeners for the event + */ + hasListeners(event: K): boolean { + const listeners = this._listeners.get(event); + return listeners !== undefined && listeners.size > 0; + } + + /** + * Get the number of listeners for an event. + * + * @param event - The event name + * @returns The number of listeners + */ + listenerCount(event: K): number { + const listeners = this._listeners.get(event); + return listeners?.size ?? 0; + } + + /** + * Remove all listeners for a specific event, or all events if no event is specified. + * + * @param event - Optional event name. If not provided, removes all listeners. + */ + removeAllListeners(event?: K): void { + if (event === undefined) { + this._listeners.clear(); + } else { + this._listeners.delete(event); + } + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Pre-defined Event Maps for SDK Components +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Events emitted by McpServer. + */ +export interface McpServerEvents { + [key: string]: unknown; + + /** + * Emitted when a tool is registered. + */ + 'tool:registered': { name: string; tool: unknown }; + + /** + * Emitted when a tool is removed. + */ + 'tool:removed': { name: string }; + + /** + * Emitted when a resource is registered. + */ + 'resource:registered': { uri: string; resource: unknown }; + + /** + * Emitted when a resource is removed. + */ + 'resource:removed': { uri: string }; + + /** + * Emitted when a prompt is registered. + */ + 'prompt:registered': { name: string; prompt: unknown }; + + /** + * Emitted when a prompt is removed. + */ + 'prompt:removed': { name: string }; + + /** + * Emitted when a connection is opened. + */ + 'connection:opened': { sessionId: string }; + + /** + * Emitted when a connection is closed. + */ + 'connection:closed': { sessionId: string; reason?: string }; + + /** + * Emitted when an error occurs. + */ + error: { error: Error; context?: string }; +} + +/** + * Events emitted by Client. + */ +export interface McpClientEvents { + [key: string]: unknown; + + /** + * Emitted when a connection is opened. + */ + 'connection:opened': { sessionId: string }; + + /** + * Emitted when a connection is closed. + */ + 'connection:closed': { sessionId: string; reason?: string }; + + /** + * Emitted when a tool call is made. + */ + 'tool:called': { name: string; args: unknown }; + + /** + * Emitted when a tool call returns a result. + */ + 'tool:result': { name: string; result: unknown }; + + /** + * Emitted when an error occurs. + */ + error: { error: Error; context?: string }; +} + +/** + * Creates a new typed event emitter for McpServer events. + */ +export function createServerEventEmitter(): TypedEventEmitter { + return new TypedEventEmitter(); +} + +/** + * Creates a new typed event emitter for Client events. + */ +export function createClientEventEmitter(): TypedEventEmitter { + return new TypedEventEmitter(); +} diff --git a/packages/core/src/shared/handlerRegistry.ts b/packages/core/src/shared/handlerRegistry.ts new file mode 100644 index 000000000..12cb0a0d8 --- /dev/null +++ b/packages/core/src/shared/handlerRegistry.ts @@ -0,0 +1,184 @@ +/** + * Handler Registry + * + * Manages request and notification handlers for the Protocol class. + * Extracted from Protocol to follow Single Responsibility Principle. + * + * This registry is focused on storage and management - it does NOT handle: + * - Schema parsing (handled by Protocol) + * - Capability assertions (handled by Protocol) + */ + +import type { JSONRPCNotification, JSONRPCRequest, Notification, Request, RequestId, Result } from '../types/types.js'; +import type { BaseRequestContext, ContextInterface } from './context.js'; + +/** + * Internal handler type for request handlers (after parsing by Protocol) + */ +export type InternalRequestHandler = ( + request: JSONRPCRequest, + extra: ContextInterface +) => Promise; + +/** + * Internal notification handler type (after parsing by Protocol) + */ +export type InternalNotificationHandler = (notification: JSONRPCNotification) => Promise; + +/** + * Manages request and notification handlers for the Protocol. + * Focused on storage, retrieval, and abort controller management. + */ +export class HandlerRegistry { + private _requestHandlers = new Map>(); + private _notificationHandlers = new Map(); + private _requestHandlerAbortControllers = new Map(); + + /** + * A handler to invoke for any request types that do not have their own handler installed. + */ + fallbackRequestHandler?: InternalRequestHandler; + + /** + * A handler to invoke for any notification types that do not have their own handler installed. + */ + fallbackNotificationHandler?: (notification: Notification) => Promise; + + // ═══════════════════════════════════════════════════════════════════════════ + // Request Handler Management + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Sets a request handler for a method. + * The handler should already be wrapped to handle JSONRPCRequest. + */ + setRequestHandler(method: string, handler: InternalRequestHandler): void { + this._requestHandlers.set(method, handler); + } + + /** + * Gets a request handler for a method, or the fallback handler if none exists. + */ + getRequestHandler(method: string): InternalRequestHandler | undefined { + return this._requestHandlers.get(method) ?? this.fallbackRequestHandler; + } + + /** + * Checks if a request handler exists for a method. + */ + hasRequestHandler(method: string): boolean { + return this._requestHandlers.has(method); + } + + /** + * Removes a request handler for a method. + */ + removeRequestHandler(method: string): void { + this._requestHandlers.delete(method); + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Notification Handler Management + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Sets a notification handler for a method. + * The handler should already be wrapped to handle JSONRPCNotification. + */ + setNotificationHandler(method: string, handler: InternalNotificationHandler): void { + this._notificationHandlers.set(method, handler); + } + + /** + * Gets a notification handler for a method, or the fallback handler if none exists. + */ + getNotificationHandler(method: string): InternalNotificationHandler | undefined { + const handler = this._notificationHandlers.get(method); + if (handler) return handler; + // Wrap fallback to match InternalNotificationHandler signature + if (this.fallbackNotificationHandler) { + return async (notification: JSONRPCNotification) => { + await this.fallbackNotificationHandler!(notification as Notification); + }; + } + return undefined; + } + + /** + * Checks if a notification handler exists for a method. + */ + hasNotificationHandler(method: string): boolean { + return this._notificationHandlers.has(method); + } + + /** + * Removes a notification handler for a method. + */ + removeNotificationHandler(method: string): void { + this._notificationHandlers.delete(method); + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Abort Controller Management + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Creates an AbortController for a request and stores it. + */ + createAbortController(requestId: RequestId): AbortController { + const controller = new AbortController(); + this._requestHandlerAbortControllers.set(requestId, controller); + return controller; + } + + /** + * Gets the AbortController for a request. + */ + getAbortController(requestId: RequestId): AbortController | undefined { + return this._requestHandlerAbortControllers.get(requestId); + } + + /** + * Removes the AbortController for a request. + */ + removeAbortController(requestId: RequestId): void { + this._requestHandlerAbortControllers.delete(requestId); + } + + /** + * Aborts all pending request handlers. + */ + abortAllPendingRequests(reason?: string): void { + for (const controller of this._requestHandlerAbortControllers.values()) { + controller.abort(reason); + } + this._requestHandlerAbortControllers.clear(); + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Utility Methods + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Gets all registered request handler methods. + */ + getRequestMethods(): string[] { + return [...this._requestHandlers.keys()]; + } + + /** + * Gets all registered notification handler methods. + */ + getNotificationMethods(): string[] { + return [...this._notificationHandlers.keys()]; + } + + /** + * Clears all handlers and abort controllers. + */ + clear(): void { + this._requestHandlers.clear(); + this._notificationHandlers.clear(); + this.abortAllPendingRequests('Registry cleared'); + } +} diff --git a/packages/core/src/shared/plugin.ts b/packages/core/src/shared/plugin.ts new file mode 100644 index 000000000..0fa58f2d1 --- /dev/null +++ b/packages/core/src/shared/plugin.ts @@ -0,0 +1,481 @@ +/** + * Protocol Plugin System + * + * This module defines the plugin interface for extending Protocol functionality. + * Plugins are INTERNAL to the SDK - they are used for decomposing the Protocol class + * into focused components. They are not exposed as a public API for SDK users. + * + * For application-level extensibility (logging, auth, metrics), SDK users should + * use McpServer Middleware (see server/middleware.ts) or Client Middleware. + */ + +import type { + JSONRPCErrorResponse, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCResultResponse, + RequestId, + Result +} from '../types/types.js'; +import type { AnyObjectSchema, SchemaOutput } from '../util/zodCompat.js'; +import type { ProgressManagerInterface } from './progressManager.js'; +import type { Transport, TransportSendOptions } from './transport.js'; + +// ═══════════════════════════════════════════════════════════════════════════ +// Sub-Component Interfaces +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Interface for transport-related operations accessible to plugins. + */ +export interface PluginTransportInterface { + /** + * Get the current transport (may be undefined if not connected) + */ + getTransport(): Transport | undefined; + + /** + * Get the session ID (if available) + */ + getSessionId(): string | undefined; + + /** + * Send a message through the transport + */ + send( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + options?: TransportSendOptions + ): Promise; +} + +/** + * Interface for making outbound requests from plugins. + */ +export interface PluginRequestsInterface { + /** + * Send a request through the protocol and wait for a response. + * + * @param request - The request to send + * @param resultSchema - Schema to validate the response + * @param options - Optional request options (timeout, signal, etc.) + * @returns The validated response + */ + sendRequest( + request: JSONRPCRequest, + resultSchema: T, + options?: PluginRequestOptions + ): Promise>; +} + +/** + * Interface for registering and managing handlers. + */ +export interface PluginHandlersInterface { + /** + * Register a request handler for a specific method. + * Handler returns SendResultT to ensure type safety with the Protocol. + */ + setRequestHandler( + schema: T, + handler: (request: SchemaOutput, extra: PluginHandlerExtra) => SendResultT | Promise + ): void; + + /** + * Register a notification handler for a specific method + */ + setNotificationHandler(schema: T, handler: (notification: SchemaOutput) => void | Promise): void; + + /** + * Remove a request handler + */ + removeRequestHandler(method: string): void; + + /** + * Remove a notification handler + */ + removeNotificationHandler(method: string): void; +} + +/** + * Interface for managing request resolvers. + * Used by TaskPlugin for routing queued responses back to their original callers. + */ +export interface PluginResolversInterface { + /** + * Register a resolver for a pending request. + */ + register(id: RequestId, resolver: (response: JSONRPCResultResponse | Error) => void): void; + + /** + * Get a resolver for a pending request. + */ + get(id: RequestId): ((response: JSONRPCResultResponse | Error) => void) | undefined; + + /** + * Remove a resolver for a pending request. + */ + remove(id: RequestId): void; +} + +/** + * Options for plugin requests. + */ +export interface PluginRequestOptions { + /** + * Timeout in milliseconds for the request + */ + timeout?: number; + + /** + * Abort signal for cancelling the request + */ + signal?: AbortSignal; + + /** Allow additional options */ + [key: string]: unknown; +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Plugin Context +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Context provided to plugins during installation. + * Composed of focused sub-components for different concerns. + */ +export interface PluginContext { + /** + * Transport operations (get transport, send messages) + */ + readonly transport: PluginTransportInterface; + + /** + * Outbound request operations + */ + readonly requests: PluginRequestsInterface; + + /** + * Handler registration and management + */ + readonly handlers: PluginHandlersInterface; + + /** + * Request resolver management (for task response routing) + */ + readonly resolvers: PluginResolversInterface; + + /** + * Progress handler management + */ + readonly progress: ProgressManagerInterface; + + /** + * Report an error through the protocol's error handling + */ + reportError(error: Error): void; +} + +/** + * Extra context passed to plugin request handlers. + */ +export interface PluginHandlerExtra { + /** + * MCP context with request metadata + */ + readonly mcpCtx: { + readonly requestId: RequestId; + readonly sessionId?: string; + }; + + /** + * Request context with abort signal + */ + readonly requestCtx: { + readonly signal: AbortSignal; + }; +} + +/** + * Context provided to plugin hooks during request processing. + */ +export interface RequestContext { + /** + * The session ID for this request + */ + readonly sessionId?: string; + + /** + * The request ID from the JSON-RPC message + */ + readonly requestId: number | string; + + /** + * The method being called + */ + readonly method: string; +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Protocol Plugin Interface +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Plugin interface for extending Protocol functionality. + * + * Plugins are internal SDK components for decomposing the Protocol class. + * They can: + * - Register handlers during installation + * - Hook into request/response lifecycle + * - Route messages (e.g., for task queueing) + * + * Note: Plugins are NOT a public API for SDK users. For application-level + * extensibility, use McpServer/Client middleware instead. + */ +export interface ProtocolPlugin { + /** + * Unique name for this plugin (for debugging and identification) + */ + readonly name: string; + + /** + * Priority determines execution order. Higher priority = runs first. + * Default: 0 + */ + readonly priority?: number; + + // ─── LIFECYCLE HOOKS ─── + + /** + * Called when the plugin is installed on a Protocol instance. + * Use this to register handlers, set up state, etc. + */ + install?(ctx: PluginContext): void | Promise; + + /** + * Called when a transport is connected. + */ + onConnect?(transport: Transport): void | Promise; + + /** + * Called when the connection is closed. + */ + onClose?(): void | Promise; + + // ─── MESSAGE ROUTING ─── + + /** + * Determines if this plugin should route the message instead of the default transport. + * Used by TaskPlugin to queue messages for task-related responses. + */ + shouldRouteMessage?( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + options?: TransportSendOptions + ): boolean; + + /** + * Routes the message. Only called if shouldRouteMessage returned true. + */ + routeMessage?( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + options?: TransportSendOptions + ): Promise; + + // ─── REQUEST/RESPONSE HOOKS ─── + + /** + * Called before a request is processed. + * Can modify the request or return void to pass through unchanged. + */ + onRequest?(request: JSONRPCRequest, ctx: RequestContext): JSONRPCRequest | void | Promise; + + /** + * Called after a request is successfully processed. + * Can modify the result or return void to pass through unchanged. + */ + onRequestResult?(request: JSONRPCRequest, result: Result, ctx: RequestContext): Result | void | Promise; + + /** + * Called when a request handler throws an error. + * Can modify the error or return void to pass through unchanged. + */ + onRequestError?(request: JSONRPCRequest, error: Error, ctx: RequestContext): Error | void | Promise; + + /** + * Called when a response is received (for outgoing requests). + * Plugins can use this to manage progress handlers or other state. + * @param response - The response received + * @param messageId - The message ID (progress token) for this request + */ + onResponse?(response: JSONRPCResponse | JSONRPCErrorResponse, messageId: number): void | Promise; + + // ─── NOTIFICATION HOOKS ─── + + /** + * Called before a notification is processed. + * Can modify the notification or return void to pass through unchanged. + */ + onNotification?(notification: JSONRPCNotification): JSONRPCNotification | void | Promise; + + // ─── OUTGOING MESSAGE HOOKS ─── + + /** + * Called before sending an outgoing request. + * Plugins can augment request params (e.g., add task metadata) or register response resolvers. + * @param request - The request being sent (can be mutated) + * @param options - The request options (can be mutated) + * @returns Modified request, or void to use original + */ + onBeforeSendRequest?(request: JSONRPCRequest, options: OutgoingRequestContext): JSONRPCRequest | void | Promise; + + /** + * Called before sending an outgoing notification. + * Plugins can augment notification params (e.g., add task metadata). + * @param notification - The notification being sent (can be mutated) + * @param options - The notification options (can be mutated) + * @returns Modified notification, or void to use original + */ + onBeforeSendNotification?( + notification: JSONRPCNotification, + options: OutgoingNotificationContext + ): JSONRPCNotification | void | Promise; + + // ─── HANDLER CONTEXT HOOKS ─── + + /** + * Called when building context for an incoming request handler. + * Plugins can contribute additional context (e.g., task context). + * @param request - The incoming request + * @param baseContext - Base context with session info + * @returns Additional context fields to merge, or void + */ + onBuildHandlerContext?( + request: JSONRPCRequest, + baseContext: HandlerContextBase + ): Record | void | Promise | void>; +} + +/** + * Context passed to onBeforeSendRequest hook. + */ +export interface OutgoingRequestContext { + /** Message ID for this request */ + readonly messageId: number; + /** Session ID if available */ + readonly sessionId?: string; + /** Original request options (plugins can read task, relatedTask, etc.) */ + readonly requestOptions?: Record; + /** Register a resolver to handle the response */ + registerResolver(resolver: (response: JSONRPCResultResponse | Error) => void): void; +} + +/** + * Context passed to onBeforeSendNotification hook. + */ +export interface OutgoingNotificationContext { + /** Session ID if available */ + readonly sessionId?: string; + /** Related request ID if this notification is in response to a request */ + readonly relatedRequestId?: RequestId; + /** Original notification options (plugins can read relatedTask, etc.) */ + readonly notificationOptions?: Record; +} + +/** + * Base context passed to onBuildHandlerContext hook. + */ +export interface HandlerContextBase { + /** Session ID if available */ + readonly sessionId?: string; + /** The incoming request */ + readonly request: JSONRPCRequest; +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Base Plugin Class +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Abstract base class for plugins. + * Provides default no-op implementations for all hooks. + * Plugins only need to override the methods they care about. + */ +export abstract class BasePlugin implements ProtocolPlugin { + abstract readonly name: string; + readonly priority?: number; + + // Default no-op implementations + install?(_ctx: PluginContext): void | Promise { + // Override in subclass + } + + onConnect?(_transport: Transport): void | Promise { + // Override in subclass + } + + onClose?(): void | Promise { + // Override in subclass + } + + shouldRouteMessage?( + _message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + _options?: TransportSendOptions + ): boolean { + return false; + } + + routeMessage?( + _message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + _options?: TransportSendOptions + ): Promise { + return Promise.resolve(); + } + + onRequest?(_request: JSONRPCRequest, _ctx: RequestContext): JSONRPCRequest | void | Promise { + // Override in subclass + } + + onRequestResult?(_request: JSONRPCRequest, _result: Result, _ctx: RequestContext): Result | void | Promise { + // Override in subclass + } + + onRequestError?(_request: JSONRPCRequest, _error: Error, _ctx: RequestContext): Error | void | Promise { + // Override in subclass + } + + onResponse?(_response: JSONRPCResponse | JSONRPCErrorResponse, _messageId: number): void | Promise { + // Override in subclass + } + + onNotification?(_notification: JSONRPCNotification): JSONRPCNotification | void | Promise { + // Override in subclass + } + + onBeforeSendRequest?( + _request: JSONRPCRequest, + _options: OutgoingRequestContext + ): JSONRPCRequest | void | Promise { + // Override in subclass + } + + onBeforeSendNotification?( + _notification: JSONRPCNotification, + _options: OutgoingNotificationContext + ): JSONRPCNotification | void | Promise { + // Override in subclass + } + + onBuildHandlerContext?( + _request: JSONRPCRequest, + _baseContext: HandlerContextBase + ): Record | void | Promise | void> { + // Override in subclass + } +} + +/** + * Helper function to sort plugins by priority (higher priority first) + */ +export function sortPluginsByPriority

(plugins: P[]): P[] { + return plugins.toSorted((a, b) => (b.priority ?? 0) - (a.priority ?? 0)); +} diff --git a/packages/core/src/shared/pluginContext.ts b/packages/core/src/shared/pluginContext.ts new file mode 100644 index 000000000..e6f04d3ca --- /dev/null +++ b/packages/core/src/shared/pluginContext.ts @@ -0,0 +1,192 @@ +/** + * Plugin Context Implementation + * + * This module provides the concrete implementations of the plugin context interfaces. + * These are internal to the SDK and are created by Protocol for plugin installation. + */ + +import type { + JSONRPCErrorResponse, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCResultResponse, + RequestId, + Result +} from '../types/types.js'; +import type { AnyObjectSchema, SchemaOutput } from '../util/zodCompat.js'; +import type { + PluginContext, + PluginHandlerExtra, + PluginHandlersInterface, + PluginRequestOptions, + PluginRequestsInterface, + PluginResolversInterface, + PluginTransportInterface +} from './plugin.js'; +import type { ProgressManagerInterface } from './progressManager.js'; +import type { Transport, TransportSendOptions } from './transport.js'; + +/** + * Protocol interface for plugin context creation. + * This avoids circular dependency with Protocol. + */ +export interface PluginHostProtocol { + readonly transport?: Transport; + request(request: JSONRPCRequest, resultSchema: T, options?: PluginRequestOptions): Promise>; + setRequestHandler( + schema: T, + handler: ( + request: SchemaOutput, + ctx: { mcpCtx: { requestId: RequestId; sessionId?: string }; requestCtx: { signal: AbortSignal } } + ) => SendResultT | Promise + ): void; + setNotificationHandler(schema: T, handler: (notification: SchemaOutput) => void | Promise): void; + removeRequestHandler(method: string): void; + removeNotificationHandler(method: string): void; +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Transport Access Implementation +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Implementation of PluginTransportInterface. + * Provides transport-related operations to plugins. + */ +export class PluginTransport implements PluginTransportInterface { + constructor(private readonly getTransportFn: () => Transport | undefined) {} + + getTransport(): Transport | undefined { + return this.getTransportFn(); + } + + getSessionId(): string | undefined { + return this.getTransportFn()?.sessionId; + } + + async send( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + options?: TransportSendOptions + ): Promise { + await this.getTransportFn()?.send(message, options); + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Requests Access Implementation +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Implementation of PluginRequestsInterface. + * Allows plugins to make outbound requests. + */ +export class PluginRequests implements PluginRequestsInterface { + constructor(private readonly protocol: PluginHostProtocol) {} + + async sendRequest( + request: JSONRPCRequest, + resultSchema: T, + options?: PluginRequestOptions + ): Promise> { + return this.protocol.request(request, resultSchema, options); + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Handler Registry Implementation +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Implementation of PluginHandlersInterface. + * Allows plugins to register request and notification handlers. + */ +export class PluginHandlers implements PluginHandlersInterface { + constructor(private readonly protocol: PluginHostProtocol) {} + + setRequestHandler( + schema: T, + handler: (request: SchemaOutput, extra: PluginHandlerExtra) => SendResultT | Promise + ): void { + this.protocol.setRequestHandler(schema, (parsedRequest, ctx) => { + const pluginExtra: PluginHandlerExtra = { + mcpCtx: { + requestId: ctx.mcpCtx.requestId, + sessionId: ctx.mcpCtx.sessionId + }, + requestCtx: { + signal: ctx.requestCtx.signal + } + }; + return handler(parsedRequest, pluginExtra); + }); + } + + setNotificationHandler(schema: T, handler: (notification: SchemaOutput) => void | Promise): void { + this.protocol.setNotificationHandler(schema, handler); + } + + removeRequestHandler(method: string): void { + this.protocol.removeRequestHandler(method); + } + + removeNotificationHandler(method: string): void { + this.protocol.removeNotificationHandler(method); + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Request Resolver Implementation +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Implementation of PluginResolversInterface. + * Manages request resolvers for routing queued responses. + */ +export class PluginResolvers implements PluginResolversInterface { + constructor(private readonly resolvers: Map void>) {} + + register(id: RequestId, resolver: (response: JSONRPCResultResponse | Error) => void): void { + this.resolvers.set(id, resolver); + } + + get(id: RequestId): ((response: JSONRPCResultResponse | Error) => void) | undefined { + return this.resolvers.get(id); + } + + remove(id: RequestId): void { + this.resolvers.delete(id); + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Factory Function +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Configuration for creating a PluginContext. + */ +export interface PluginContextConfig { + protocol: PluginHostProtocol; + getTransport: () => Transport | undefined; + resolvers: Map void>; + progressManager: ProgressManagerInterface; + reportError: (error: Error) => void; +} + +/** + * Creates a PluginContext from the given configuration. + * This is called once by Protocol and cached for reuse. + */ +export function createPluginContext( + config: PluginContextConfig +): PluginContext { + return { + transport: new PluginTransport(config.getTransport), + requests: new PluginRequests(config.protocol), + handlers: new PluginHandlers(config.protocol), + resolvers: new PluginResolvers(config.resolvers), + progress: config.progressManager, + reportError: config.reportError + }; +} diff --git a/packages/core/src/shared/progressManager.ts b/packages/core/src/shared/progressManager.ts new file mode 100644 index 000000000..a20070cc2 --- /dev/null +++ b/packages/core/src/shared/progressManager.ts @@ -0,0 +1,126 @@ +/** + * Progress Manager + * + * Manages progress tracking for the Protocol class. + * Extracted from Protocol to follow Single Responsibility Principle. + */ + +import type { Progress, ProgressNotification } from '../types/types.js'; + +/** + * Callback for progress notifications. + */ +export type ProgressCallback = (progress: Progress) => void; + +/** + * Interface for progress management. + * Plugins use this interface to register and manage progress handlers. + */ +export interface ProgressManagerInterface { + /** + * Registers a progress callback for a message. + * @param messageId - The message ID (used as progress token) + * @param callback - The callback to invoke when progress is received + */ + registerHandler(messageId: number, callback: ProgressCallback): void; + + /** + * Gets the progress callback for a message. + * @param messageId - The message ID + * @returns The progress callback or undefined + */ + getHandler(messageId: number): ProgressCallback | undefined; + + /** + * Removes the progress callback for a message. + * @param messageId - The message ID + */ + removeHandler(messageId: number): void; + + /** + * Handles an incoming progress notification. + * @param notification - The progress notification + * @returns true if handled, false if no handler was found + */ + handleProgress(notification: ProgressNotification): boolean; +} + +/** + * Manages progress tracking for requests. + */ +export class ProgressManager implements ProgressManagerInterface { + /** + * Maps message IDs to progress callbacks + */ + private _progressHandlers: Map = new Map(); + + /** + * Registers a progress callback for a message. + * + * @param messageId - The message ID (used as progress token) + * @param callback - The callback to invoke when progress is received + */ + registerHandler(messageId: number, callback: ProgressCallback): void { + this._progressHandlers.set(messageId, callback); + } + + /** + * Gets the progress callback for a message. + * + * @param messageId - The message ID + * @returns The progress callback or undefined + */ + getHandler(messageId: number): ProgressCallback | undefined { + return this._progressHandlers.get(messageId); + } + + /** + * Removes the progress callback for a message. + * + * @param messageId - The message ID + */ + removeHandler(messageId: number): void { + this._progressHandlers.delete(messageId); + } + + /** + * Handles an incoming progress notification. + * Returns true if the progress was handled, false if no handler was found. + * + * @param notification - The progress notification + * @returns true if handled, false otherwise + */ + handleProgress(notification: ProgressNotification): boolean { + const token = notification.params.progressToken; + if (typeof token !== 'number') { + // Token must be a number for our internal tracking + return false; + } + + const callback = this._progressHandlers.get(token); + if (callback) { + callback({ + progress: notification.params.progress, + total: notification.params.total, + message: notification.params.message + }); + return true; + } + + return false; + } + + /** + * Clears all progress handlers. + */ + clear(): void { + this._progressHandlers.clear(); + } + + /** + * Gets the number of active progress handlers. + */ + get handlerCount(): number { + return this._progressHandlers.size; + } +} diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index 74764077b..eff48d57e 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -1,12 +1,7 @@ -import { RequestTaskStore } from '../experimental/requestTaskStore.js'; -import type { QueuedMessage, TaskMessageQueue, TaskStore } from '../experimental/tasks/interfaces.js'; -import { isTerminal } from '../experimental/tasks/interfaces.js'; +import { StateError } from '../errors.js'; import type { CancelledNotification, ClientCapabilities, - GetTaskPayloadRequest, - GetTaskRequest, - GetTaskResult, JSONRPCErrorResponse, JSONRPCNotification, JSONRPCRequest, @@ -14,48 +9,47 @@ import type { JSONRPCResultResponse, MessageExtraInfo, Notification, - Progress, ProgressNotification, - RelatedTaskMetadata, Request, RequestId, Result, - ServerCapabilities, - TaskCreationParams + ServerCapabilities } from '../types/types.js'; import { CancelledNotificationSchema, - CancelTaskRequestSchema, - CancelTaskResultSchema, - CreateTaskResultSchema, ErrorCode, - GetTaskPayloadRequestSchema, - GetTaskRequestSchema, - GetTaskResultSchema, isJSONRPCErrorResponse, isJSONRPCNotification, isJSONRPCRequest, isJSONRPCResultResponse, - isTaskAugmentedRequestParams, - ListTasksRequestSchema, - ListTasksResultSchema, McpError, PingRequestSchema, - ProgressNotificationSchema, - RELATED_TASK_META_KEY + ProgressNotificationSchema } from '../types/types.js'; import type { AnyObjectSchema, AnySchema, SchemaOutput } from '../util/zodCompat.js'; import { safeParse } from '../util/zodCompat.js'; import { getMethodLiteral, parseWithCompat } from '../util/zodJsonSchemaCompat.js'; import type { BaseRequestContext, ContextInterface } from './context.js'; +import type { McpEventEmitter } from './events.js'; +import { TypedEventEmitter } from './events.js'; +import { HandlerRegistry } from './handlerRegistry.js'; +import type { + HandlerContextBase, + OutgoingNotificationContext, + OutgoingRequestContext, + PluginContext, + PluginRequestOptions, + ProtocolPlugin, + RequestContext +} from './plugin.js'; +import { sortPluginsByPriority } from './plugin.js'; +import { createPluginContext } from './pluginContext.js'; +import type { ProgressCallback } from './progressManager.js'; +import { ProgressManager } from './progressManager.js'; import type { ResponseMessage } from './responseMessage.js'; +import { TimeoutManager } from './timeoutManager.js'; import type { Transport, TransportSendOptions } from './transport.js'; -/** - * Callback for progress notifications. - */ -export type ProgressCallback = (progress: Progress) => void; - /** * Additional initialization options. */ @@ -75,29 +69,6 @@ export type ProtocolOptions = { * e.g., ['notifications/tools/list_changed'] */ debouncedNotificationMethods?: string[]; - /** - * Optional task storage implementation. If provided, enables task-related request handlers - * and provides task storage capabilities to request handlers. - */ - taskStore?: TaskStore; - /** - * Optional task message queue implementation for managing server-initiated messages - * that will be delivered through the tasks/result response stream. - */ - taskMessageQueue?: TaskMessageQueue; - /** - * Default polling interval (in milliseconds) for task status checks when no pollInterval - * is provided by the server. Defaults to 5000ms if not specified. - */ - defaultTaskPollInterval?: number; - /** - * Maximum number of messages that can be queued per task for side-channel delivery. - * If undefined, the queue size is unbounded. - * When the limit is exceeded, the TaskMessageQueue implementation's enqueue() method - * will throw an error. It's the implementation's responsibility to handle overflow - * appropriately (e.g., by failing the task, dropping messages, etc.). - */ - maxTaskQueueSize?: number; }; /** @@ -107,12 +78,30 @@ export const DEFAULT_REQUEST_TIMEOUT_MSEC = 60_000; /** * Options that can be given per request. + * + * ## Plugin Extension Pattern + * + * Plugins can define their own typed options by creating intersection types. + * For type safety at call sites, use the plugin-specific type with `satisfies`: + * + * @example + * ```typescript + * import type { TaskRequestOptions } from '@modelcontextprotocol/core'; + * + * // Type-safe task options + * await ctx.sendRequest(req, schema, { + * task: { ttl: 60000 }, + * relatedTask: { taskId: 'parent-123' } + * } satisfies TaskRequestOptions); + * ``` + * + * The index signature allows plugins to read their options from the + * `requestOptions` field in their `onBeforeSendRequest` hooks. */ export type RequestOptions = { /** - * If set, requests progress notifications from the remote end (if supported). When progress notifications are received, this callback will be invoked. - * - * For task-augmented requests: progress notifications continue after CreateTaskResult is returned and stop automatically when the task reaches a terminal status. + * If set, requests progress notifications from the remote end (if supported). + * When progress notifications are received, this callback will be invoked. */ onprogress?: ProgressCallback; @@ -142,19 +131,30 @@ export type RequestOptions = { */ maxTotalTimeout?: number; - /** - * If provided, augments the request with task creation parameters to enable call-now, fetch-later execution patterns. - */ - task?: TaskCreationParams; - - /** - * If provided, associates this request with a related task. - */ - relatedTask?: RelatedTaskMetadata; + /** Allow plugin-specific options via index signature */ + [key: string]: unknown; } & TransportSendOptions; /** * Options that can be given per notification. + * + * ## Plugin Extension Pattern + * + * Plugins can define their own typed options by creating intersection types. + * For type safety at call sites, use the plugin-specific type with `satisfies`: + * + * @example + * ```typescript + * import type { TaskNotificationOptions } from '@modelcontextprotocol/core'; + * + * // Type-safe task options + * await ctx.sendNotification(notification, { + * relatedTask: { taskId: 'parent-123' } + * } satisfies TaskNotificationOptions); + * ``` + * + * The index signature allows plugins to read their options from the + * `notificationOptions` field in their `onBeforeSendNotification` hooks. */ export type NotificationOptions = { /** @@ -162,28 +162,112 @@ export type NotificationOptions = { */ relatedRequestId?: RequestId; + /** Allow plugin-specific options via index signature */ + [key: string]: unknown; +}; + +// ═══════════════════════════════════════════════════════════════════════════ +// Error Interception +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Context provided to error interceptors. + */ +export interface ErrorInterceptionContext { /** - * If provided, associates this notification with a related task. + * The type of error: + * - 'protocol': Protocol-level errors (method not found, parse error, etc.) + * - 'application': Application errors (handler threw an exception) */ - relatedTask?: RelatedTaskMetadata; -}; + type: 'protocol' | 'application'; + + /** + * The method that was being called when the error occurred. + */ + method: string; + + /** + * The request ID from the JSON-RPC message. + */ + requestId: RequestId; + + /** + * For protocol errors, the fixed error code that cannot be changed. + * For application errors, the error code that will be used (can be modified via returned Error). + */ + errorCode: number; +} /** - * Options that can be given per request. + * Result from an error interceptor that can modify the error response. */ -// relatedTask is excluded as the SDK controls if this is sent according to if the source is a task. -export type TaskRequestOptions = Omit; +export interface ErrorInterceptionResult { + /** + * Override the error message. If not provided, the original error message is used. + */ + message?: string; + + /** + * Additional data to include in the error response. + */ + data?: unknown; + + /** + * For application errors only: override the error code. + * Ignored for protocol errors (they have fixed codes per MCP spec). + */ + code?: number; +} + /** - * Information about a request's timeout state + * Error interceptor function type. + * Called before sending error responses, allows customizing the error. + * + * @param error - The original error + * @param context - Context about where the error occurred + * @returns Optional modifications to the error response, or void to use defaults */ -type TimeoutInfo = { - timeoutId: ReturnType; - startTime: number; - timeout: number; - maxTotalTimeout?: number; - resetTimeoutOnProgress: boolean; - onTimeout: () => void; -}; +export type ErrorInterceptor = ( + error: Error, + context: ErrorInterceptionContext +) => ErrorInterceptionResult | void | Promise; + +// ═══════════════════════════════════════════════════════════════════════════ +// Protocol Events +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Events emitted by the Protocol class. + * + * @example + * ```typescript + * const unsubscribe = protocol.events.on('connection:opened', ({ sessionId }) => { + * console.log(`Connected with session: ${sessionId}`); + * }); + * + * protocol.events.on('error', ({ error, context }) => { + * console.error(`Protocol error in ${context}:`, error); + * }); + * ``` + */ +export interface ProtocolEvents { + [key: string]: unknown; + + /** + * Emitted when a connection is successfully established. + */ + 'connection:opened': { sessionId?: string }; + + /** + * Emitted when the connection is closed. + */ + 'connection:closed': { sessionId?: string; reason?: string }; + + /** + * Emitted when an error occurs during protocol operations. + */ + error: { error: Error; context?: string }; +} /** * Implements MCP protocol framing on top of a pluggable transport, including @@ -192,25 +276,25 @@ type TimeoutInfo = { export abstract class Protocol { private _transport?: Transport; private _requestMessageId = 0; - private _requestHandlers: Map< - string, - (request: JSONRPCRequest, extra: ContextInterface) => Promise - > = new Map(); - private _requestHandlerAbortControllers: Map = new Map(); - private _notificationHandlers: Map Promise> = new Map(); private _responseHandlers: Map void> = new Map(); - private _progressHandlers: Map = new Map(); - private _timeoutInfo: Map = new Map(); private _pendingDebouncedNotifications = new Set(); - // Maps task IDs to progress tokens to keep handlers alive after CreateTaskResult - private _taskProgressTokens: Map = new Map(); + // Extracted managers + private _timeoutManager = new TimeoutManager(); + private _progressManager = new ProgressManager(); + private _handlerRegistry = new HandlerRegistry(); - private _taskStore?: TaskStore; - private _taskMessageQueue?: TaskMessageQueue; + // Plugin system + private _plugins: ProtocolPlugin[] = []; private _requestResolvers: Map void> = new Map(); + // Event emitter for observability + private _events = new TypedEventEmitter(); + + // Error interception callback + private _errorInterceptor?: ErrorInterceptor; + /** * Callback for when the connection is closed for any reason. * @@ -225,18 +309,80 @@ export abstract class Protocol void; + /** + * Event emitter for observability and monitoring. + * + * Subscribe to events like connection lifecycle, errors, etc. + * + * @example + * ```typescript + * protocol.events.on('connection:opened', ({ sessionId }) => { + * console.log(`Connected: ${sessionId}`); + * }); + * + * protocol.events.on('error', ({ error }) => { + * console.error('Protocol error:', error); + * }); + * ``` + */ + get events(): McpEventEmitter { + return this._events; + } + + /** + * Sets an error interceptor that can customize error responses before they are sent. + * + * The interceptor is called for both protocol errors (method not found, etc.) and + * application errors (when a handler throws). It can modify the error message and data, + * and for application errors, can also change the error code. + * + * @param interceptor - The error interceptor function, or undefined to clear + * + * @example + * ```typescript + * server.setErrorInterceptor(async (error, ctx) => { + * console.error(`Error in ${ctx.method}: ${error.message}`); + * return { + * message: 'An error occurred', + * data: { originalMessage: error.message } + * }; + * }); + * ``` + */ + protected setErrorInterceptor(interceptor: ErrorInterceptor | undefined): void { + this._errorInterceptor = interceptor; + } + /** * A handler to invoke for any request types that do not have their own handler installed. */ - fallbackRequestHandler?: ( - request: JSONRPCRequest, - extra: ContextInterface - ) => Promise; + get fallbackRequestHandler(): + | ((request: JSONRPCRequest, extra: ContextInterface) => Promise) + | undefined { + return this._handlerRegistry.fallbackRequestHandler; + } + + set fallbackRequestHandler( + handler: + | (( + request: JSONRPCRequest, + extra: ContextInterface + ) => Promise) + | undefined + ) { + this._handlerRegistry.fallbackRequestHandler = handler; + } /** * A handler to invoke for any notification types that do not have their own handler installed. */ - fallbackNotificationHandler?: (notification: Notification) => Promise; + get fallbackNotificationHandler(): ((notification: Notification) => Promise) | undefined { + return this._handlerRegistry.fallbackNotificationHandler; + } + + set fallbackNotificationHandler(handler: ((notification: Notification) => Promise) | undefined) { + this._handlerRegistry.fallbackNotificationHandler = handler; + } constructor(private _options?: ProtocolOptions) { this.setNotificationHandler(CancelledNotificationSchema, notification => { @@ -252,182 +398,258 @@ export abstract class Protocol ({}) as SendResultT ); + } - // Install task handlers if TaskStore is provided - this._taskStore = _options?.taskStore; - this._taskMessageQueue = _options?.taskMessageQueue; - if (this._taskStore) { - this.setRequestHandler(GetTaskRequestSchema, async (request, ctx) => { - const task = await this._taskStore!.getTask(request.params.taskId, ctx.mcpCtx.sessionId); - if (!task) { - throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); - } - - // Per spec: tasks/get responses SHALL NOT include related-task metadata - // as the taskId parameter is the source of truth - // @ts-expect-error SendResultT cannot contain GetTaskResult, but we include it in our derived types everywhere else - return { - ...task - } as SendResultT; - }); + // ═══════════════════════════════════════════════════════════════════════════ + // Plugin System + // ═══════════════════════════════════════════════════════════════════════════ - this.setRequestHandler(GetTaskPayloadRequestSchema, async (request, ctx) => { - const handleTaskResult = async (): Promise => { - const taskId = request.params.taskId; - - // Deliver queued messages - if (this._taskMessageQueue) { - let queuedMessage: QueuedMessage | undefined; - while ((queuedMessage = await this._taskMessageQueue.dequeue(taskId, ctx.mcpCtx.sessionId))) { - // Handle response and error messages by routing them to the appropriate resolver - if (queuedMessage.type === 'response' || queuedMessage.type === 'error') { - const message = queuedMessage.message; - const requestId = message.id; - - // Lookup resolver in _requestResolvers map - const resolver = this._requestResolvers.get(requestId as RequestId); - - if (resolver) { - // Remove resolver from map after invocation - this._requestResolvers.delete(requestId as RequestId); - - // Invoke resolver with response or error - if (queuedMessage.type === 'response') { - resolver(message as JSONRPCResultResponse); - } else { - // Convert JSONRPCError to McpError - const errorMessage = message as JSONRPCErrorResponse; - const error = new McpError( - errorMessage.error.code, - errorMessage.error.message, - errorMessage.error.data - ); - resolver(error); - } - } else { - // Handle missing resolver gracefully with error logging - const messageType = queuedMessage.type === 'response' ? 'Response' : 'Error'; - this._onerror(new Error(`${messageType} handler missing for request ${requestId}`)); - } - - // Continue to next message - continue; - } + /** + * Registers a plugin with the protocol. + * Plugins are installed immediately and sorted by priority. + * + * @param plugin - The plugin to register + * @returns this for chaining + */ + usePlugin(plugin: ProtocolPlugin): this { + this._plugins.push(plugin); + this._plugins = sortPluginsByPriority(this._plugins); - // Send the message on the response stream by passing the relatedRequestId - // This tells the transport to write the message to the tasks/result response stream - await this._transport?.send(queuedMessage.message, { relatedRequestId: ctx.mcpCtx.requestId }); - } - } + // Install the plugin immediately + const ctx = this._getPluginContext(); + plugin.install?.(ctx); - // Now check task status - const task = await this._taskStore!.getTask(taskId, ctx.mcpCtx.sessionId); - if (!task) { - throw new McpError(ErrorCode.InvalidParams, `Task not found: ${taskId}`); - } + return this; + } - // Block if task is not terminal (we've already delivered all queued messages above) - if (!isTerminal(task.status)) { - // Wait for status change or new messages - await this._waitForTaskUpdate(taskId, ctx.requestCtx.signal); + /** + * Retrieves a registered plugin by its class. + * Returns undefined if the plugin is not registered. + * + * @param PluginClass - The plugin class to find + * @returns The plugin instance or undefined + * + * @example + * ```typescript + * const taskPlugin = server.getPlugin(TaskPlugin); + * if (taskPlugin) { + * // Access plugin-specific methods + * } + * ``` + */ + getPlugin>(PluginClass: abstract new (...args: unknown[]) => T): T | undefined { + return this._plugins.find((p): p is T => p instanceof PluginClass); + } - // After waking up, recursively call to deliver any new messages or result - return await handleTaskResult(); - } + /** + * Cached plugin context, created once and reused for all plugins. + */ + private _pluginContext?: PluginContext; - // If task is terminal, return the result - if (isTerminal(task.status)) { - const result = await this._taskStore!.getTaskResult(taskId, ctx.mcpCtx.sessionId); + /** + * Gets or creates the plugin context for plugin installation. + * The context is created once and cached for reuse. + */ + private _getPluginContext(): PluginContext { + if (!this._pluginContext) { + this._pluginContext = createPluginContext({ + protocol: this._createPluginHostProtocol(), + getTransport: () => this._transport, + resolvers: this._requestResolvers, + progressManager: this._progressManager, + reportError: error => this._onerror(error, 'plugin') + }); + } + return this._pluginContext; + } - this._clearTaskQueue(taskId); + /** + * Creates the protocol interface for plugin context. + * This provides a typed view of Protocol for the plugin system. + */ + private _createPluginHostProtocol() { + return { + transport: this._transport, + request: (request: JSONRPCRequest, resultSchema: T, options?: PluginRequestOptions) => + this.request(request as SendRequestT, resultSchema, options), + setRequestHandler: ( + schema: T, + handler: ( + request: SchemaOutput, + ctx: { mcpCtx: { requestId: RequestId; sessionId?: string }; requestCtx: { signal: AbortSignal } } + ) => SendResultT | Promise + ) => this.setRequestHandler(schema, handler), + setNotificationHandler: ( + schema: T, + handler: (notification: SchemaOutput) => void | Promise + ) => this.setNotificationHandler(schema, handler), + removeRequestHandler: (method: string) => this.removeRequestHandler(method), + removeNotificationHandler: (method: string) => this.removeNotificationHandler(method) + }; + } - return { - ...result, - _meta: { - ...result._meta, - [RELATED_TASK_META_KEY]: { - taskId: taskId - } - } - } as SendResultT; - } + /** + * Calls onConnect on all plugins. + */ + private async _notifyPluginsConnect(transport: Transport): Promise { + for (const plugin of this._plugins) { + await plugin.onConnect?.(transport); + } + } - return await handleTaskResult(); - }; + /** + * Calls onClose on all plugins. + */ + private async _notifyPluginsClose(): Promise { + for (const plugin of this._plugins) { + await plugin.onClose?.(); + } + } - return await handleTaskResult(); - }); + /** + * Checks if any plugin wants to route a message instead of the default transport. + */ + private _findMessageRouter( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + options?: TransportSendOptions + ): ProtocolPlugin | undefined { + return this._plugins.find(p => p.shouldRouteMessage?.(message, options)); + } - this.setRequestHandler(ListTasksRequestSchema, async (request, ctx) => { - try { - const { tasks, nextCursor } = await this._taskStore!.listTasks(request.params?.cursor, ctx.mcpCtx.sessionId); - // @ts-expect-error SendResultT cannot contain ListTasksResult, but we include it in our derived types everywhere else - return { - tasks, - nextCursor, - _meta: {} - } as SendResultT; - } catch (error) { - throw new McpError( - ErrorCode.InvalidParams, - `Failed to list tasks: ${error instanceof Error ? error.message : String(error)}` - ); - } - }); + /** + * Calls onRequest on all plugins, allowing them to modify the request. + */ + private async _runPluginOnRequest(request: JSONRPCRequest, ctx: RequestContext): Promise { + let current = request; + for (const plugin of this._plugins) { + const modified = await plugin.onRequest?.(current, ctx); + if (modified) { + current = modified; + } + } + return current; + } - this.setRequestHandler(CancelTaskRequestSchema, async (request, ctx) => { - try { - // Get the current task to check if it's in a terminal state, in case the implementation is not atomic - const task = await this._taskStore!.getTask(request.params.taskId, ctx.mcpCtx.sessionId); + /** + * Calls onRequestResult on all plugins, allowing them to modify the result. + */ + private async _runPluginOnRequestResult(request: JSONRPCRequest, result: Result, ctx: RequestContext): Promise { + let current = result; + for (const plugin of this._plugins) { + const modified = await plugin.onRequestResult?.(request, current, ctx); + if (modified) { + current = modified; + } + } + return current; + } - if (!task) { - throw new McpError(ErrorCode.InvalidParams, `Task not found: ${request.params.taskId}`); - } + /** + * Calls onRequestError on all plugins, allowing them to modify the error. + */ + private async _runPluginOnRequestError(request: JSONRPCRequest, error: Error, ctx: RequestContext): Promise { + let current = error; + for (const plugin of this._plugins) { + const modified = await plugin.onRequestError?.(request, current, ctx); + if (modified) { + current = modified; + } + } + return current; + } - // Reject cancellation of terminal tasks - if (isTerminal(task.status)) { - throw new McpError(ErrorCode.InvalidParams, `Cannot cancel task in terminal status: ${task.status}`); - } + /** + * Calls onNotification on all plugins, allowing them to modify the notification. + */ + private async _runPluginOnNotification(notification: JSONRPCNotification): Promise { + let current = notification; + for (const plugin of this._plugins) { + const modified = await plugin.onNotification?.(current); + if (modified) { + current = modified; + } + } + return current; + } - await this._taskStore!.updateTaskStatus( - request.params.taskId, - 'cancelled', - 'Client cancelled task execution.', - ctx.mcpCtx.sessionId - ); + /** + * Calls onBeforeSendRequest on all plugins, allowing them to augment the request. + */ + private async _runPluginOnBeforeSendRequest(request: JSONRPCRequest, ctx: OutgoingRequestContext): Promise { + let current = request; + for (const plugin of this._plugins) { + const modified = await plugin.onBeforeSendRequest?.(current, ctx); + if (modified) { + current = modified; + } + } + return current; + } - this._clearTaskQueue(request.params.taskId); + /** + * Calls onBeforeSendNotification on all plugins, allowing them to augment the notification. + */ + private async _runPluginOnBeforeSendNotification( + notification: JSONRPCNotification, + ctx: OutgoingNotificationContext + ): Promise { + let current = notification; + for (const plugin of this._plugins) { + const modified = await plugin.onBeforeSendNotification?.(current, ctx); + if (modified) { + current = modified; + } + } + return current; + } - const cancelledTask = await this._taskStore!.getTask(request.params.taskId, ctx.mcpCtx.sessionId); - if (!cancelledTask) { - // Task was deleted during cancellation (e.g., cleanup happened) - throw new McpError(ErrorCode.InvalidParams, `Task not found after cancellation: ${request.params.taskId}`); - } + /** + * Calls onBuildHandlerContext on all plugins, merging additional context. + */ + private async _runPluginOnBuildHandlerContext( + request: JSONRPCRequest, + baseContext: HandlerContextBase + ): Promise> { + const additions: Record = {}; + for (const plugin of this._plugins) { + const pluginContext = await plugin.onBuildHandlerContext?.(request, baseContext); + if (pluginContext) { + Object.assign(additions, pluginContext); + } + } + return additions; + } - return { - _meta: {}, - ...cancelledTask - } as unknown as SendResultT; - } catch (error) { - // Re-throw McpError as-is - if (error instanceof McpError) { - throw error; - } - throw new McpError( - ErrorCode.InvalidRequest, - `Failed to cancel task: ${error instanceof Error ? error.message : String(error)}` - ); - } - }); + /** + * Routes a message through plugins or transport. + * Plugins can intercept messages (e.g., for task queueing) via shouldRouteMessage/routeMessage. + */ + private async _routeMessage( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + options?: TransportSendOptions + ): Promise { + // Check if any plugin wants to route this message + for (const plugin of this._plugins) { + if (plugin.shouldRouteMessage?.(message, options)) { + await plugin.routeMessage?.(message, options); + return; + } } + + // No plugin routing - send via transport + await this._transport?.send(message, options); } + // ═══════════════════════════════════════════════════════════════════════════ + // Internal Handlers + // ═══════════════════════════════════════════════════════════════════════════ + private async _oncancel(notification: CancelledNotification): Promise { if (!notification.params.requestId) { return; } // Handle request cancellation - const controller = this._requestHandlerAbortControllers.get(notification.params.requestId); + const controller = this._handlerRegistry.getAbortController(notification.params.requestId); controller?.abort(notification.params.reason); } @@ -438,9 +660,7 @@ export abstract class Protocol void, resetTimeoutOnProgress: boolean = false ) { - this._timeoutInfo.set(messageId, { - timeoutId: setTimeout(onTimeout, timeout), - startTime: Date.now(), + this._timeoutManager.setup(messageId, { timeout, maxTotalTimeout, resetTimeoutOnProgress, @@ -449,29 +669,26 @@ export abstract class Protocol= info.maxTotalTimeout) { - this._timeoutInfo.delete(messageId); - throw McpError.fromError(ErrorCode.RequestTimeout, 'Maximum total timeout exceeded', { - maxTotalTimeout: info.maxTotalTimeout, - totalElapsed - }); + // Check max total timeout before delegating to manager + if (info.maxTotalTimeout) { + const totalElapsed = Date.now() - info.startTime; + if (totalElapsed >= info.maxTotalTimeout) { + this._timeoutManager.cleanup(messageId); + throw McpError.fromError(ErrorCode.RequestTimeout, 'Maximum total timeout exceeded', { + maxTotalTimeout: info.maxTotalTimeout, + totalElapsed + }); + } } - clearTimeout(info.timeoutId); - info.timeoutId = setTimeout(info.onTimeout, info.timeout); - return true; + return this._timeoutManager.reset(messageId); } private _cleanupTimeout(messageId: number) { - const info = this._timeoutInfo.get(messageId); - if (info) { - clearTimeout(info.timeoutId); - this._timeoutInfo.delete(messageId); - } + this._timeoutManager.cleanup(messageId); } /** @@ -503,36 +720,96 @@ export abstract class Protocol this._onerror(error_, 'plugin-close')); + for (const handler of responseHandlers.values()) { handler(error); } } - private _onerror(error: Error): void { + private _onerror(error: Error, context?: string): void { this.onerror?.(error); + this._events.emit('error', { error, context }); + } + + /** + * Sends a protocol-level error response (e.g., method not found, parse error). + * Protocol errors have fixed error codes per MCP spec - the interceptor can only + * modify the message and data, not the code. + */ + private _sendProtocolError(request: JSONRPCRequest, errorCode: number, defaultMessage: string, sessionId: string | undefined): void { + const error = new McpError(errorCode, defaultMessage); + + // Call error interceptor if set (async, fire-and-forget for the interception result usage) + Promise.resolve() + .then(async () => { + let message = defaultMessage; + let data: unknown; + + if (this._errorInterceptor) { + const ctx: ErrorInterceptionContext = { + type: 'protocol', + method: request.method, + requestId: request.id, + errorCode + }; + const result = await this._errorInterceptor(error, ctx); + if (result) { + message = result.message ?? message; + data = result.data; + // Note: result.code is ignored for protocol errors (fixed codes per MCP spec) + } + } + + const errorResponse: JSONRPCErrorResponse = { + jsonrpc: '2.0', + id: request.id, + error: { + code: errorCode, + message, + ...(data !== undefined && { data }) + } + }; + + // Route error response through plugins + await this._routeMessage(errorResponse, { sessionId }); + }) + .catch(error_ => this._onerror(new Error(`Failed to send error response: ${error_}`), 'send-error-response')); } private _onnotification(notification: JSONRPCNotification): void { - const handler = this._notificationHandlers.get(notification.method) ?? this.fallbackNotificationHandler; + const handler = this._handlerRegistry.getNotificationHandler(notification.method); // Ignore notifications not being subscribed to. if (handler === undefined) { @@ -541,73 +818,59 @@ export abstract class Protocol handler(notification)) - .catch(error => this._onerror(new Error(`Uncaught error in notification handler: ${error}`))); + .then(async () => { + // Let plugins modify the notification + const modifiedNotification = await this._runPluginOnNotification(notification); + return handler(modifiedNotification); + }) + .catch(error => this._onerror(new Error(`Uncaught error in notification handler: ${error}`), 'notification-handler')); } private _onrequest(request: JSONRPCRequest, extra?: MessageExtraInfo): void { - const handler = this._requestHandlers.get(request.method) ?? this.fallbackRequestHandler; + const handler = this._handlerRegistry.getRequestHandler(request.method); // Capture the current transport at request time to ensure responses go to the correct client const capturedTransport = this._transport; - // Extract taskId from request metadata if present (needed early for method not found case) - const relatedTaskId = request.params?._meta?.[RELATED_TASK_META_KEY]?.taskId; - if (handler === undefined) { - const errorResponse: JSONRPCErrorResponse = { - jsonrpc: '2.0', - id: request.id, - error: { - code: ErrorCode.MethodNotFound, - message: 'Method not found' - } - }; - - // Queue or send the error response based on whether this is a task-related request - if (relatedTaskId && this._taskMessageQueue) { - this._enqueueTaskMessage( - relatedTaskId, - { - type: 'error', - message: errorResponse, - timestamp: Date.now() - }, - capturedTransport?.sessionId - ).catch(error => this._onerror(new Error(`Failed to enqueue error response: ${error}`))); - } else { - capturedTransport - ?.send(errorResponse) - .catch(error => this._onerror(new Error(`Failed to send an error response: ${error}`))); - } + // Handle method not found - this is a protocol error + this._sendProtocolError(request, ErrorCode.MethodNotFound, 'Method not found', capturedTransport?.sessionId); return; } - const abortController = new AbortController(); - this._requestHandlerAbortControllers.set(request.id, abortController); - - const taskCreationParams = isTaskAugmentedRequestParams(request.params) ? request.params.task : undefined; + const abortController = this._handlerRegistry.createAbortController(request.id); + const sessionId = capturedTransport?.sessionId; - const fullExtra: ContextInterface = this.createRequestContext({ + const baseExtra: ContextInterface = this.createRequestContext({ request, - taskStore: this._taskStore, - relatedTaskId, - taskCreationParams, abortController, capturedTransport, extra }); + // Build plugin request context + const pluginReqCtx: RequestContext = { + requestId: request.id, + method: request.method, + sessionId + }; + // Starting with Promise.resolve() puts any synchronous errors into the monad as well. Promise.resolve() - .then(() => { - // If this request asked for task creation, check capability first - if (taskCreationParams) { - // Check if the request method supports task creation - this.assertTaskHandlerCapability(request.method); + // Let plugins modify the request + .then(() => this._runPluginOnRequest(request, pluginReqCtx)) + .then(async modifiedRequest => { + // Let plugins contribute additional context (e.g., task context) + const additionalContext = await this._runPluginOnBuildHandlerContext(request, { sessionId, request }); + + // Assign additional context properties to the existing context object + // This preserves the prototype chain (instanceof checks work) + if (additionalContext) { + Object.assign(baseExtra, additionalContext); } + + return handler(modifiedRequest, baseExtra); }) - .then(() => handler(request, fullExtra)) .then( async result => { if (abortController.signal.aborted) { @@ -615,24 +878,19 @@ export abstract class Protocol { if (abortController.signal.aborted) { @@ -640,33 +898,55 @@ export abstract class Protocol this._onerror(new Error(`Failed to send response: ${error}`))) + .catch(error => this._onerror(new Error(`Failed to send response: ${error}`), 'send-response')) .finally(() => { - this._requestHandlerAbortControllers.delete(request.id); + this._handlerRegistry.removeAbortController(request.id); }); } @@ -688,50 +968,6 @@ export abstract class Protocol; - if (result.task && typeof result.task === 'object') { - const task = result.task as Record; - if (typeof task.taskId === 'string') { - isTaskResponse = true; - this._taskProgressTokens.set(task.taskId, messageId); - } - } - } + // Let plugins process the response (e.g., for task progress management) + // Plugins can inspect the response and manage progress handlers via getProgressManager() + this._runPluginOnOutboundResponse(response, messageId); - if (!isTaskResponse) { - this._progressHandlers.delete(messageId); - } + // Default: remove progress handler + // Plugins that need to keep progress handlers active should re-register them in their onResponse hook + this._progressManager.removeHandler(messageId); if (isJSONRPCResultResponse(response)) { handler(response); @@ -826,6 +1053,15 @@ export abstract class Protocol( request: SendRequestT, resultSchema: T, options?: RequestOptions ): AsyncGenerator>, void, void> { - const { task } = options ?? {}; - - // For non-task requests, just yield the result - if (!task) { - try { - const result = await this.request(request, resultSchema, options); - yield { type: 'result', result }; - } catch (error) { - yield { - type: 'error', - error: error instanceof McpError ? error : new McpError(ErrorCode.InternalError, String(error)) - }; - } - return; - } - - // For task-augmented requests, we need to poll for status - // First, make the request to create the task - let taskId: string | undefined; try { - // Send the request and get the CreateTaskResult - const createResult = await this.request(request, CreateTaskResultSchema, options); - - // Extract taskId from the result - if (createResult.task) { - taskId = createResult.task.taskId; - yield { type: 'taskCreated', task: createResult.task }; - } else { - throw new McpError(ErrorCode.InternalError, 'Task creation did not return a task'); - } - - // Poll for task completion - while (true) { - // Get current task status - const task = await this.getTask({ taskId }, options); - yield { type: 'taskStatus', task }; - - // Check if task is terminal - if (isTerminal(task.status)) { - switch (task.status) { - case 'completed': { - // Get the final result - const result = await this.getTaskResult({ taskId }, resultSchema, options); - yield { type: 'result', result }; - - break; - } - case 'failed': { - yield { - type: 'error', - error: new McpError(ErrorCode.InternalError, `Task ${taskId} failed`) - }; - - break; - } - case 'cancelled': { - yield { - type: 'error', - error: new McpError(ErrorCode.InternalError, `Task ${taskId} was cancelled`) - }; - - break; - } - // No default - } - return; - } - - // When input_required, call tasks/result to deliver queued messages - // (elicitation, sampling) via SSE and block until terminal - if (task.status === 'input_required') { - const result = await this.getTaskResult({ taskId }, resultSchema, options); - yield { type: 'result', result }; - return; - } - - // Wait before polling again - const pollInterval = task.pollInterval ?? this._options?.defaultTaskPollInterval ?? 1000; - await new Promise(resolve => setTimeout(resolve, pollInterval)); - - // Check if cancelled - options?.signal?.throwIfAborted(); - } + const result = await this.request(request, resultSchema, options); + yield { type: 'result', result }; } catch (error) { yield { type: 'error', @@ -1001,7 +1135,7 @@ export abstract class Protocol(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise> { - const { relatedRequestId, resumptionToken, onresumptiontoken, task, relatedTask } = options ?? {}; + const { relatedRequestId, resumptionToken, onresumptiontoken } = options ?? {}; // Send the request return new Promise>((resolve, reject) => { @@ -1017,11 +1151,6 @@ export abstract class Protocol { this._responseHandlers.delete(messageId); - this._progressHandlers.delete(messageId); + this._progressManager.removeHandler(messageId); this._cleanupTimeout(messageId); this._transport @@ -1084,7 +1194,7 @@ export abstract class Protocol this._onerror(new Error(`Failed to send cancellation: ${error}`))); + .catch(error => this._onerror(new Error(`Failed to send cancellation: ${error}`), 'send-cancellation')); // Wrap the reason in an McpError if it isn't already const error = reason instanceof McpError ? reason : new McpError(ErrorCode.RequestTimeout, String(reason)); @@ -1122,129 +1232,69 @@ export abstract class Protocol { - const handler = this._responseHandlers.get(messageId); - if (handler) { - handler(response); - } else { - // Log error when resolver is missing, but don't fail - this._onerror(new Error(`Response handler missing for side-channeled request ${messageId}`)); - } - }; - this._requestResolvers.set(messageId, responseResolver); - - this._enqueueTaskMessage(relatedTaskId, { - type: 'request', - message: jsonrpcRequest, - timestamp: Date.now() - }).catch(error => { - this._cleanupTimeout(messageId); - reject(error); - }); + // Create plugin context for outgoing request + const outgoingCtx: OutgoingRequestContext = { + messageId, + sessionId: this._transport?.sessionId, + requestOptions: options as Record, + registerResolver: () => { + // Register resolver so responses can be routed back (used by task plugin) + const responseResolver = (response: JSONRPCResultResponse | Error) => { + const handler = this._responseHandlers.get(messageId); + if (handler) { + handler(response); + } else { + this._onerror( + new Error(`Response handler missing for side-channeled request ${messageId}`), + 'side-channel-routing' + ); + } + }; + this._requestResolvers.set(messageId, responseResolver); + } + }; - // Don't send through transport - queued messages are delivered via tasks/result only - // This prevents duplicate delivery for bidirectional transports - } else { - // No related task OR no task message queue configured - send through transport normally - // Note: relatedTask metadata is still included in jsonrpcRequest.params._meta for receivers to see - this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { + // Let plugins augment the request (e.g., add task metadata) + this._runPluginOnBeforeSendRequest(jsonrpcRequest, outgoingCtx) + .then(modifiedRequest => { + jsonrpcRequest = modifiedRequest; + + // Route message through plugins or transport + return this._routeMessage(jsonrpcRequest, { + relatedRequestId, + sessionId: this._transport?.sessionId, + resumptionToken, + onresumptiontoken + }); + }) + .catch(error => { this._cleanupTimeout(messageId); reject(error); }); - } }); } - /** - * Gets the current status of a task. - * - * @experimental Use `client.experimental.tasks.getTask()` to access this method. - */ - protected async getTask(params: GetTaskRequest['params'], options?: RequestOptions): Promise { - // @ts-expect-error SendRequestT cannot directly contain GetTaskRequest, but we ensure all type instantiations contain it anyways - return this.request({ method: 'tasks/get', params }, GetTaskResultSchema, options); - } - - /** - * Retrieves the result of a completed task. - * - * @experimental Use `client.experimental.tasks.getTaskResult()` to access this method. - */ - protected async getTaskResult( - params: GetTaskPayloadRequest['params'], - resultSchema: T, - options?: RequestOptions - ): Promise> { - // @ts-expect-error SendRequestT cannot directly contain GetTaskPayloadRequest, but we ensure all type instantiations contain it anyways - return this.request({ method: 'tasks/result', params }, resultSchema, options); - } - - /** - * Lists tasks, optionally starting from a pagination cursor. - * - * @experimental Use `client.experimental.tasks.listTasks()` to access this method. - */ - protected async listTasks(params?: { cursor?: string }, options?: RequestOptions): Promise> { - // @ts-expect-error SendRequestT cannot directly contain ListTasksRequest, but we ensure all type instantiations contain it anyways - return this.request({ method: 'tasks/list', params }, ListTasksResultSchema, options); - } - - /** - * Cancels a specific task. - * - * @experimental Use `client.experimental.tasks.cancelTask()` to access this method. - */ - protected async cancelTask(params: { taskId: string }, options?: RequestOptions): Promise> { - // @ts-expect-error SendRequestT cannot directly contain CancelTaskRequest, but we ensure all type instantiations contain it anyways - return this.request({ method: 'tasks/cancel', params }, CancelTaskResultSchema, options); - } - /** * Emits a notification, which is a one-way message that does not expect a response. */ async notification(notification: SendNotificationT, options?: NotificationOptions): Promise { if (!this._transport) { - throw new Error('Not connected'); + throw StateError.notConnected('send notification'); } this.assertNotificationCapability(notification.method); - // Queue notification if related to a task AND task message queue is configured - const relatedTaskId = options?.relatedTask?.taskId; - if (relatedTaskId && this._taskStore && this._taskMessageQueue) { - // Build the JSONRPC notification with metadata - const jsonrpcNotification: JSONRPCNotification = { - ...notification, - jsonrpc: '2.0', - params: { - ...notification.params, - _meta: { - ...notification.params?._meta, - [RELATED_TASK_META_KEY]: options.relatedTask - } - } - }; - - await this._enqueueTaskMessage(relatedTaskId, { - type: 'notification', - message: jsonrpcNotification, - timestamp: Date.now() - }); - - // Don't send through transport - queued messages are delivered via tasks/result only - // This prevents duplicate delivery for bidirectional transports - return; - } - const debouncedMethods = this._options?.debouncedNotificationMethods ?? []; // A notification can only be debounced if it's in the list AND it's "simple" - // (i.e., has no parameters and no related request ID or related task that could be lost). - const canDebounce = - debouncedMethods.includes(notification.method) && !notification.params && !options?.relatedRequestId && !options?.relatedTask; + // (i.e., has no parameters and no related request ID). + const canDebounce = debouncedMethods.includes(notification.method) && !notification.params && !options?.relatedRequestId; + + // Create plugin context for outgoing notification + const outgoingCtx: OutgoingNotificationContext = { + sessionId: this._transport?.sessionId, + relatedRequestId: options?.relatedRequestId, + notificationOptions: options as Record + }; if (canDebounce) { // If a notification of this type is already scheduled, do nothing. @@ -1257,7 +1307,7 @@ export abstract class Protocol { + Promise.resolve().then(async () => { // Un-mark the notification so the next one can be scheduled. this._pendingDebouncedNotifications.delete(notification.method); @@ -1271,23 +1321,14 @@ export abstract class Protocol this._onerror(error)); + // Route notification through plugins + this._routeMessage(jsonrpcNotification, { + ...options, + sessionId: this._transport?.sessionId + }).catch(error => this._onerror(error, 'send-notification')); }); // Return immediately. @@ -1299,21 +1340,14 @@ export abstract class Protocol { + // Wrap handler to parse the request and delegate to registry + this._handlerRegistry.setRequestHandler(method, (request, ctx) => { const parsed = parseWithCompat(requestSchema, request) as SchemaOutput; return Promise.resolve(handler(parsed, ctx)); }); @@ -1341,15 +1376,15 @@ export abstract class Protocol) => void | Promise ): void { const method = getMethodLiteral(notificationSchema); - this._notificationHandlers.set(method, notification => { + // Wrap handler to parse the notification and delegate to registry + this._handlerRegistry.setNotificationHandler(method, notification => { const parsed = parseWithCompat(notificationSchema, notification) as SchemaOutput; return Promise.resolve(handler(parsed)); }); @@ -1373,106 +1409,7 @@ export abstract class Protocol { - // Task message queues are only used when taskStore is configured - if (!this._taskStore || !this._taskMessageQueue) { - throw new Error('Cannot enqueue task message: taskStore and taskMessageQueue are not configured'); - } - - const maxQueueSize = this._options?.maxTaskQueueSize; - await this._taskMessageQueue.enqueue(taskId, message, sessionId, maxQueueSize); - } - - /** - * Clears the message queue for a task and rejects any pending request resolvers. - * @param taskId The task ID whose queue should be cleared - * @param sessionId Optional session ID for binding the operation to a specific session - */ - private async _clearTaskQueue(taskId: string, sessionId?: string): Promise { - if (this._taskMessageQueue) { - // Reject any pending request resolvers - const messages = await this._taskMessageQueue.dequeueAll(taskId, sessionId); - for (const message of messages) { - if (message.type === 'request' && isJSONRPCRequest(message.message)) { - // Extract request ID from the message - const requestId = message.message.id as RequestId; - const resolver = this._requestResolvers.get(requestId); - if (resolver) { - resolver(new McpError(ErrorCode.InternalError, 'Task cancelled or completed')); - this._requestResolvers.delete(requestId); - } else { - // Log error when resolver is missing during cleanup for better observability - this._onerror(new Error(`Resolver missing for request ${requestId} during task ${taskId} cleanup`)); - } - } - } - } - } - - /** - * Waits for a task update (new messages or status change) with abort signal support. - * Uses polling to check for updates at the task's configured poll interval. - * @param taskId The task ID to wait for - * @param signal Abort signal to cancel the wait - * @returns Promise that resolves when an update occurs or rejects if aborted - */ - private async _waitForTaskUpdate(taskId: string, signal: AbortSignal): Promise { - // Get the task's poll interval, falling back to default - let interval = this._options?.defaultTaskPollInterval ?? 1000; - try { - const task = await this._taskStore?.getTask(taskId); - if (task?.pollInterval) { - interval = task.pollInterval; - } - } catch { - // Use default interval if task lookup fails - } - - return new Promise((resolve, reject) => { - if (signal.aborted) { - reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled')); - return; - } - - // Wait for the poll interval, then resolve so caller can check for updates - const timeoutId = setTimeout(resolve, interval); - - // Clean up timeout and reject if aborted - signal.addEventListener( - 'abort', - () => { - clearTimeout(timeoutId); - reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled')); - }, - { once: true } - ); - }); + this._handlerRegistry.removeNotificationHandler(method); } } diff --git a/packages/core/src/shared/taskClientPlugin.ts b/packages/core/src/shared/taskClientPlugin.ts new file mode 100644 index 000000000..66f24ad9b --- /dev/null +++ b/packages/core/src/shared/taskClientPlugin.ts @@ -0,0 +1,446 @@ +/** + * Task Client Plugin + * + * This plugin provides client-side methods for calling task APIs on a remote server. + * It also manages task-related progress handlers. + * + * Usage: + * ```typescript + * const taskClient = client.getPlugin(TaskClientPlugin); + * const task = await taskClient?.getTask({ taskId: 'task-123' }); + * ``` + */ + +import { isTerminal } from '../experimental/tasks/interfaces.js'; +import type { + CancelTaskResult, + GetTaskPayloadRequest, + GetTaskRequest, + GetTaskResult, + JSONRPCErrorResponse, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ListTasksResult, + RelatedTaskMetadata, + Request, + Result, + TaskCreationParams +} from '../types/types.js'; +import { + CancelTaskResultSchema, + CreateTaskResultSchema, + ErrorCode, + GetTaskResultSchema, + isJSONRPCResultResponse, + ListTasksResultSchema, + McpError, + RELATED_TASK_META_KEY +} from '../types/types.js'; +import type { AnySchema, SchemaOutput } from '../util/zodCompat.js'; +import type { OutgoingNotificationContext, OutgoingRequestContext, PluginContext, PluginRequestOptions, ProtocolPlugin } from './plugin.js'; +import type { ProgressCallback, ProgressManagerInterface } from './progressManager.js'; +import type { RequestOptions } from './protocol.js'; +import type { ResponseMessage } from './responseMessage.js'; + +// ═══════════════════════════════════════════════════════════════════════════ +// Task-Specific Option Types +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Extended request options for task-augmented requests. + * + * Use these options when sending requests that should create or relate to tasks. + * For type safety at call sites, use `satisfies TaskRequestOptions`: + * + * @example + * ```typescript + * import type { TaskRequestOptions } from '@modelcontextprotocol/core'; + * + * // Create a task with the request + * await client.request(callToolRequest, CallToolResultSchema, { + * task: { ttl: 60000 } + * } satisfies TaskRequestOptions); + * + * // Inside a handler, associate with a parent task + * await ctx.sendRequest(req, schema, { + * relatedTask: { taskId: ctx.taskCtx?.id } + * } satisfies TaskRequestOptions); + * ``` + */ +export type TaskRequestOptions = RequestOptions & { + /** + * If provided, augments the request with task creation parameters + * to enable call-now, fetch-later execution patterns. + */ + task?: TaskCreationParams; + + /** + * If provided, associates this request with a related task. + * This is typically set internally by the SDK when handling task-augmented requests. + */ + relatedTask?: RelatedTaskMetadata; +}; + +/** + * Extended notification options for task-related notifications. + * + * Use these options when sending notifications that should be associated with a task. + * For type safety at call sites, use `satisfies TaskNotificationOptions`: + * + * @example + * ```typescript + * import type { TaskNotificationOptions } from '@modelcontextprotocol/core'; + * + * // Inside a handler, associate notification with a parent task + * await ctx.sendNotification(progressNotification, { + * relatedTask: { taskId: ctx.taskCtx?.id } + * } satisfies TaskNotificationOptions); + * ``` + */ +export type TaskNotificationOptions = { + /** + * If provided, associates this notification with a related task. + * This is typically set internally by the SDK when handling task-augmented requests. + */ + relatedTask?: RelatedTaskMetadata; +}; + +/** + * Plugin that provides client-side task API methods. + * Clients access this via getPlugin(TaskClientPlugin) to call task APIs on remote servers. + */ +export class TaskClientPlugin implements ProtocolPlugin { + readonly name = 'TaskClientPlugin'; + readonly priority = 50; // Standard priority + + private ctx?: PluginContext; + private progressManager?: ProgressManagerInterface; + + /** + * Maps task IDs to their associated progress token (message ID) and handler. + * This allows progress to continue after CreateTaskResult is returned. + */ + private readonly taskProgressHandlers = new Map(); + + /** + * Install the plugin. + */ + install(ctx: PluginContext): void { + this.ctx = ctx; + this.progressManager = ctx.progress; + } + + /** + * Called when a response is received for an outgoing request. + * Detects task creation responses and preserves progress handlers. + */ + onResponse(response: JSONRPCResponse | JSONRPCErrorResponse, messageId: number): void { + if (!this.progressManager) return; + + // Check if this is a CreateTaskResult response + if (isJSONRPCResultResponse(response) && response.result && typeof response.result === 'object') { + const result = response.result as Record; + if (result.task && typeof result.task === 'object') { + const task = result.task as Record; + if (typeof task.taskId === 'string') { + const taskId = task.taskId; + + // Get the current progress handler before Protocol removes it + const handler = this.progressManager.getHandler(messageId); + if (handler) { + // Store the handler for this task + this.taskProgressHandlers.set(taskId, { messageId, handler }); + + // Re-register the handler so it stays active + // This is called before Protocol.removeHandler, so we need to + // re-register after Protocol removes it. We do this by + // scheduling it on next tick. + queueMicrotask(() => { + this.progressManager?.registerHandler(messageId, handler); + }); + } + } + } + } + } + + /** + * Clears the progress handler for a completed task. + * Call this when a task reaches terminal state. + * + * @param taskId - The task ID whose progress handler should be removed + */ + clearTaskProgress(taskId: string): void { + const entry = this.taskProgressHandlers.get(taskId); + if (entry) { + this.progressManager?.removeHandler(entry.messageId); + this.taskProgressHandlers.delete(taskId); + } + } + + /** + * Checks if a task has an active progress handler. + */ + hasTaskProgress(taskId: string): boolean { + return this.taskProgressHandlers.has(taskId); + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Outgoing Message Hooks + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Augments outgoing requests with task metadata. + * - Adds task creation params if `task` option is provided + * - Adds related task metadata if `relatedTask` option is provided + * - Registers response resolver for task-related requests + */ + onBeforeSendRequest(request: JSONRPCRequest, ctx: OutgoingRequestContext): JSONRPCRequest | void { + // Read task-specific options from the raw options object + const options = ctx.requestOptions as TaskRequestOptions | undefined; + if (!options) return; + + let modified = request; + const { task, relatedTask } = options; + + // Augment with task creation parameters if provided + if (task) { + modified = { + ...modified, + params: { + ...modified.params, + task + } + }; + } + + // Augment with related task metadata if provided + if (relatedTask) { + const existingParams = (modified.params ?? {}) as Record; + const existingMeta = (existingParams._meta ?? {}) as Record; + modified = { + ...modified, + params: { + ...existingParams, + _meta: { + ...existingMeta, + [RELATED_TASK_META_KEY]: relatedTask + } + } + }; + + // Register resolver for task-related requests so responses route back + ctx.registerResolver(() => { + // The resolver is registered automatically by Protocol + }); + } + + // Return modified request if changes were made + if (modified === request) { + return undefined; + } + return modified; + } + + /** + * Augments outgoing notifications with task metadata. + * Adds related task metadata if `relatedTask` option is provided. + */ + onBeforeSendNotification(notification: JSONRPCNotification, ctx: OutgoingNotificationContext): JSONRPCNotification | void { + // Read task-specific options from the raw options object + const options = ctx.notificationOptions as TaskNotificationOptions | undefined; + if (!options?.relatedTask) return; + + const existingParams = (notification.params ?? {}) as Record; + const existingMeta = (existingParams._meta ?? {}) as Record; + const modified = { + ...notification, + params: { + ...existingParams, + _meta: { + ...existingMeta, + [RELATED_TASK_META_KEY]: options.relatedTask + } + } + }; + + return modified; + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Task API Methods + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Gets the current status of a task. + */ + async getTask(params: GetTaskRequest['params'], options?: PluginRequestOptions): Promise { + if (!this.ctx) { + throw new Error('TaskClientPlugin not installed'); + } + return this.ctx.requests.sendRequest({ jsonrpc: '2.0', id: 0, method: 'tasks/get', params }, GetTaskResultSchema, options); + } + + /** + * Retrieves the result of a completed task. + * Uses long-polling to wait for task completion. + */ + async getTaskResult( + params: GetTaskPayloadRequest['params'], + resultSchema: T, + options?: PluginRequestOptions + ): Promise> { + if (!this.ctx) { + throw new Error('TaskClientPlugin not installed'); + } + const result = await this.ctx.requests.sendRequest( + { jsonrpc: '2.0', id: 0, method: 'tasks/result', params }, + resultSchema, + options + ); + + // Clear progress handler when task result is retrieved + this.clearTaskProgress(params.taskId); + + return result; + } + + /** + * Lists all tasks, optionally with pagination. + */ + async listTasks(params?: { cursor?: string }, options?: PluginRequestOptions): Promise { + if (!this.ctx) { + throw new Error('TaskClientPlugin not installed'); + } + return this.ctx.requests.sendRequest({ jsonrpc: '2.0', id: 0, method: 'tasks/list', params }, ListTasksResultSchema, options); + } + + /** + * Cancels a running task. + */ + async cancelTask(params: { taskId: string }, options?: PluginRequestOptions): Promise { + if (!this.ctx) { + throw new Error('TaskClientPlugin not installed'); + } + const result = await this.ctx.requests.sendRequest( + { jsonrpc: '2.0', id: 0, method: 'tasks/cancel', params }, + CancelTaskResultSchema, + options + ); + + // Clear progress handler when task is cancelled + this.clearTaskProgress(params.taskId); + + return result; + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Task Streaming + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Sends a task-augmented request and streams status updates until completion. + * This handles the full task lifecycle: creation, polling, and result retrieval. + * + * @param request - The request to send (method and params) + * @param resultSchema - Schema to validate the final result + * @param options - Options including task creation params + * @yields ResponseMessage events for task creation, status updates, and final result/error + */ + async *requestStream( + request: Request, + resultSchema: T, + options: TaskClientRequestStreamOptions + ): AsyncGenerator>, void, void> { + if (!this.ctx) { + throw new Error('TaskClientPlugin not installed'); + } + + let taskId: string | undefined; + try { + // Send the request and get the CreateTaskResult + // Convert Request to JSONRPCRequest format for sendRequest + const jsonRpcRequest = { jsonrpc: '2.0' as const, id: 0, ...request }; + const createResult = await this.ctx.requests.sendRequest(jsonRpcRequest, CreateTaskResultSchema, options); + + // Extract taskId from the result + if (createResult.task) { + taskId = createResult.task.taskId; + yield { type: 'taskCreated', task: createResult.task }; + } else { + throw new McpError(ErrorCode.InternalError, 'Task creation did not return a task'); + } + + // Poll for task completion + while (true) { + // Get current task status + const task = await this.getTask({ taskId }, options); + yield { type: 'taskStatus', task }; + + // Check if task is terminal + if (isTerminal(task.status)) { + switch (task.status) { + case 'completed': { + // Get the final result + const result = await this.getTaskResult({ taskId }, resultSchema, options); + yield { type: 'result', result }; + break; + } + case 'failed': { + yield { + type: 'error', + error: new McpError(ErrorCode.InternalError, `Task ${taskId} failed`) + }; + break; + } + case 'cancelled': { + yield { + type: 'error', + error: new McpError(ErrorCode.InternalError, `Task ${taskId} was cancelled`) + }; + break; + } + // No default + } + return; + } + + // When input_required, call tasks/result to deliver queued messages + // (elicitation, sampling) via SSE and block until terminal + if (task.status === 'input_required') { + const result = await this.getTaskResult({ taskId }, resultSchema, options); + yield { type: 'result', result }; + return; + } + + // Wait before polling again + const pollInterval = task.pollInterval ?? options.defaultPollInterval ?? 1000; + await new Promise(resolve => setTimeout(resolve, pollInterval)); + + // Check if cancelled + options.signal?.throwIfAborted(); + } + } catch (error) { + yield { + type: 'error', + error: error instanceof McpError ? error : new McpError(ErrorCode.InternalError, String(error)) + }; + } + } +} + +/** + * Options for TaskClientPlugin.requestStream. + */ +export interface TaskClientRequestStreamOptions extends PluginRequestOptions { + task: TaskCreationParams; + defaultPollInterval?: number; +} + +/** + * Factory function to create a TaskClientPlugin. + */ +export function createTaskClientPlugin(): TaskClientPlugin { + return new TaskClientPlugin(); +} diff --git a/packages/core/src/shared/taskPlugin.ts b/packages/core/src/shared/taskPlugin.ts new file mode 100644 index 000000000..34662bd20 --- /dev/null +++ b/packages/core/src/shared/taskPlugin.ts @@ -0,0 +1,489 @@ +/** + * Task Plugin + * + * This plugin completely abstracts all task-related functionality from the Protocol class: + * - Message routing for task-related messages (queue instead of send) + * - Task API handlers (tasks/get, tasks/result, tasks/list, tasks/cancel) + * - Task message queue management + * + * The plugin is internal to the SDK and not exposed as a public API. + */ + +import { RequestTaskStore } from '../experimental/requestTaskStore.js'; +import type { QueuedMessage, TaskMessageQueue, TaskStore } from '../experimental/tasks/interfaces.js'; +import { isTerminal } from '../experimental/tasks/interfaces.js'; +import type { + CancelTaskResult, + GetTaskPayloadRequest, + GetTaskRequest, + GetTaskResult, + JSONRPCErrorResponse, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCResultResponse, + ListTasksResult, + RequestId, + Result +} from '../types/types.js'; +import { + CancelTaskRequestSchema, + ErrorCode, + GetTaskPayloadRequestSchema, + GetTaskRequestSchema, + isJSONRPCRequest, + isTaskAugmentedRequestParams, + ListTasksRequestSchema, + McpError, + RELATED_TASK_META_KEY +} from '../types/types.js'; +import type { HandlerContextBase, PluginContext, PluginHandlerExtra, ProtocolPlugin } from './plugin.js'; +import type { Transport, TransportSendOptions } from './transport.js'; + +/** + * Configuration for the TaskPlugin. + */ +export interface TaskPluginConfig { + /** + * The task store implementation for persisting task state. + */ + readonly taskStore: TaskStore; + + /** + * Optional message queue for async message delivery during task execution. + */ + readonly taskMessageQueue?: TaskMessageQueue; + + /** + * Default polling interval (in milliseconds) for task status checks. + * Defaults to 1000ms if not specified. + */ + readonly defaultTaskPollInterval?: number; + + /** + * Maximum number of messages that can be queued per task. + * If undefined, the queue size is unbounded. + */ + readonly maxTaskQueueSize?: number; +} + +/** + * Plugin that handles all task-related MCP operations. + * This completely abstracts task functionality from the Protocol class. + */ +export class TaskPlugin implements ProtocolPlugin { + readonly name = 'TaskPlugin'; + readonly priority = 100; // High priority to run before other plugins + + private ctx?: PluginContext; + private transport?: Transport; + + constructor(private readonly config: TaskPluginConfig) {} + + // ═══════════════════════════════════════════════════════════════════════════ + // Plugin Lifecycle + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Install the plugin by registering task request handlers. + */ + install(ctx: PluginContext): void { + this.ctx = ctx; + + // Register tasks/get handler + ctx.handlers.setRequestHandler(GetTaskRequestSchema, async (request, extra) => { + return this.handleGetTask(request, extra); + }); + + // Register tasks/result handler + ctx.handlers.setRequestHandler(GetTaskPayloadRequestSchema, async (request, extra) => { + return this.handleGetTaskPayload(request, extra); + }); + + // Register tasks/list handler + ctx.handlers.setRequestHandler(ListTasksRequestSchema, async (request, extra) => { + return this.handleListTasks(request.params, extra); + }); + + // Register tasks/cancel handler + ctx.handlers.setRequestHandler(CancelTaskRequestSchema, async (request, extra) => { + return this.handleCancelTask(request.params, extra); + }); + } + + /** + * Called when transport connects. + */ + onConnect(transport: Transport): void { + this.transport = transport; + } + + /** + * Called when connection closes. + */ + onClose(): void { + this.transport = undefined; + } + + /** + * Called before a request is processed. + * Checks if task creation is supported for the request method. + */ + onRequest(request: JSONRPCRequest): JSONRPCRequest | void { + // If this request asks for task creation, check capability + const taskCreationParams = isTaskAugmentedRequestParams(request.params) ? request.params.task : undefined; + if (taskCreationParams) { + // Check if this method supports task creation + // For now, we support tasks for tools/call and sampling/createMessage + const taskCapableMethods = ['tools/call', 'sampling/createMessage']; + if (!taskCapableMethods.includes(request.method)) { + throw new McpError(ErrorCode.InvalidRequest, `Task creation is not supported for method: ${request.method}`); + } + } + // Return void to pass through unchanged + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Message Routing + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Extracts the relatedTaskId from a message's _meta field. + */ + private extractRelatedTaskId( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse + ): string | undefined { + // For requests/notifications, check params._meta + if ('method' in message && 'params' in message && message.params) { + const params = message.params as Record; + const meta = params._meta as Record | undefined; + const taskMeta = meta?.[RELATED_TASK_META_KEY] as { taskId?: string } | undefined; + return taskMeta?.taskId; + } + return undefined; + } + + /** + * Determines if this plugin should route the message (queue for task delivery). + * Returns true if the message has a relatedTaskId in its metadata and task queue is configured. + */ + shouldRouteMessage( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + _options?: TransportSendOptions + ): boolean { + // Route if there's a related task ID in the message and we have a message queue + const relatedTaskId = this.extractRelatedTaskId(message); + return Boolean(relatedTaskId && this.config.taskMessageQueue); + } + + /** + * Routes the message by queueing it for task delivery. + */ + async routeMessage( + message: JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCErrorResponse, + options?: TransportSendOptions + ): Promise { + const relatedTaskId = this.extractRelatedTaskId(message); + const sessionId = options?.sessionId; + if (!relatedTaskId || !this.config.taskMessageQueue) { + throw new Error('Cannot route message: relatedTaskId or taskMessageQueue not available'); + } + + const timestamp = Date.now(); + + // Create properly typed QueuedMessage based on message structure + let queuedMessage: QueuedMessage; + if ('method' in message && 'id' in message) { + queuedMessage = { type: 'request', message: message as JSONRPCRequest, timestamp }; + } else if ('method' in message && !('id' in message)) { + queuedMessage = { type: 'notification', message: message as JSONRPCNotification, timestamp }; + } else if ('result' in message) { + queuedMessage = { type: 'response', message: message as JSONRPCResultResponse, timestamp }; + } else { + queuedMessage = { type: 'error', message: message as JSONRPCErrorResponse, timestamp }; + } + + await this.enqueueTaskMessage(relatedTaskId, queuedMessage, sessionId); + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Handler Context Hook + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Builds task context for incoming request handlers. + * Extracts task creation params and related task metadata from the request, + * creates a RequestTaskStore, and returns the task context. + */ + onBuildHandlerContext(request: JSONRPCRequest, baseContext: HandlerContextBase): Record | undefined { + // Only build task context if we have a task store configured + if (!this.config.taskStore) { + return undefined; + } + + // Extract task metadata from request + const relatedTaskId = this.extractRelatedTaskId(request); + const taskCreationParams = isTaskAugmentedRequestParams(request.params) ? request.params.task : undefined; + + // Create the RequestTaskStore + const requestTaskStore = new RequestTaskStore({ + taskStore: this.config.taskStore, + requestId: request.id, + request, + sessionId: baseContext.sessionId, + initialTaskId: relatedTaskId ?? '' + }); + + // Return task context that will be merged into the handler context + return { + taskCtx: { + get id() { + return requestTaskStore.currentTaskId; + }, + store: requestTaskStore, + requestedTtl: taskCreationParams?.ttl ?? null + } + }; + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Task Message Queue Management + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Enqueues a message for task delivery. + */ + private async enqueueTaskMessage(taskId: string, message: QueuedMessage, sessionId?: string): Promise { + if (!this.config.taskMessageQueue) { + throw new Error('Cannot enqueue task message: taskMessageQueue is not configured'); + } + + await this.config.taskMessageQueue.enqueue(taskId, message, sessionId, this.config.maxTaskQueueSize); + } + + /** + * Clears the message queue for a task and rejects any pending request resolvers. + */ + private async clearTaskQueue(taskId: string, sessionId?: string): Promise { + if (!this.config.taskMessageQueue || !this.ctx) { + return; + } + + // Dequeue all messages and reject pending request resolvers + const messages = await this.config.taskMessageQueue.dequeueAll(taskId, sessionId); + for (const message of messages) { + if (message.type === 'request' && isJSONRPCRequest(message.message)) { + const requestId = message.message.id as RequestId; + const resolver = this.ctx.resolvers.get(requestId); + if (resolver) { + resolver(new McpError(ErrorCode.InternalError, 'Task cancelled or completed')); + this.ctx.resolvers.remove(requestId); + } else { + this.ctx.reportError(new Error(`Resolver missing for request ${requestId} during task ${taskId} cleanup`)); + } + } + } + } + + /** + * Waits for a task update (new messages or status change) with abort signal support. + */ + private async waitForTaskUpdate(taskId: string, signal: AbortSignal): Promise { + // Get the task's poll interval, falling back to default + let interval = this.config.defaultTaskPollInterval ?? 1000; + try { + const task = await this.config.taskStore.getTask(taskId); + if (task?.pollInterval) { + interval = task.pollInterval; + } + } catch { + // Use default interval if task lookup fails + } + + return new Promise((resolve, reject) => { + if (signal.aborted) { + reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled')); + return; + } + + // Wait for the poll interval, then resolve so caller can check for updates + const timeoutId = setTimeout(resolve, interval); + + // Clean up timeout and reject if aborted + signal.addEventListener( + 'abort', + () => { + clearTimeout(timeoutId); + reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled')); + }, + { once: true } + ); + }); + } + + // ═══════════════════════════════════════════════════════════════════════════ + // Task API Handlers + // ═══════════════════════════════════════════════════════════════════════════ + + /** + * Handler for tasks/get - retrieves task metadata. + */ + private async handleGetTask(request: GetTaskRequest, extra: PluginHandlerExtra): Promise { + const task = await this.config.taskStore.getTask(request.params.taskId, extra.mcpCtx.sessionId); + if (!task) { + throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); + } + + // Per spec: tasks/get responses SHALL NOT include related-task metadata + return { ...task }; + } + + /** + * Handler for tasks/result - delivers task results and queued messages. + * Implements long-polling pattern for task updates. + */ + private async handleGetTaskPayload(request: GetTaskPayloadRequest, extra: PluginHandlerExtra): Promise { + const taskId = request.params.taskId; + + const poll = async (): Promise => { + // Deliver any queued messages first + await this.deliverQueuedMessages(taskId, extra); + + // Check task status + const task = await this.config.taskStore.getTask(taskId, extra.mcpCtx.sessionId); + if (!task) { + throw new McpError(ErrorCode.InvalidParams, `Task not found: ${taskId}`); + } + + // If task is not terminal, wait for updates and poll again + if (!isTerminal(task.status)) { + await this.waitForTaskUpdate(taskId, extra.requestCtx.signal); + return poll(); + } + + // Task is terminal - return the result + const result = await this.config.taskStore.getTaskResult(taskId, extra.mcpCtx.sessionId); + await this.clearTaskQueue(taskId, extra.mcpCtx.sessionId); + + return { + ...result, + _meta: { + ...result._meta, + [RELATED_TASK_META_KEY]: { taskId } + } + }; + }; + + return poll(); + } + + /** + * Delivers queued messages for a task. + */ + private async deliverQueuedMessages(taskId: string, extra: PluginHandlerExtra): Promise { + const { taskMessageQueue } = this.config; + if (!taskMessageQueue || !this.ctx) { + return; + } + + let queuedMessage: QueuedMessage | undefined; + while ((queuedMessage = await taskMessageQueue.dequeue(taskId, extra.mcpCtx.sessionId))) { + // Handle response and error messages by routing to original resolver + if (queuedMessage.type === 'response' || queuedMessage.type === 'error') { + await this.routeQueuedResponse(queuedMessage); + continue; + } + + // Send other messages (notifications, requests) on the response stream + const transport = this.ctx.transport.getTransport(); + await transport?.send(queuedMessage.message, { relatedRequestId: extra.mcpCtx.requestId }); + } + } + + /** + * Routes a queued response/error back to its original request resolver. + */ + private async routeQueuedResponse(queuedMessage: QueuedMessage): Promise { + if (!this.ctx) return; + + const message = queuedMessage.message as JSONRPCResultResponse | JSONRPCErrorResponse; + const requestId = message.id as RequestId; + + const resolver = this.ctx.resolvers.get(requestId); + if (!resolver) { + const messageType = queuedMessage.type === 'response' ? 'Response' : 'Error'; + this.ctx.reportError(new Error(`${messageType} handler missing for request ${requestId}`)); + return; + } + + this.ctx.resolvers.remove(requestId); + + if (queuedMessage.type === 'response') { + resolver(message as JSONRPCResultResponse); + } else { + const errorMessage = message as JSONRPCErrorResponse; + const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + resolver(error); + } + } + + /** + * Handler for tasks/list - lists all tasks. + */ + private async handleListTasks(params: { cursor?: string } | undefined, extra: PluginHandlerExtra): Promise { + try { + const { tasks, nextCursor } = await this.config.taskStore.listTasks(params?.cursor, extra.mcpCtx.sessionId); + return { tasks, nextCursor, _meta: {} }; + } catch (error) { + throw new McpError(ErrorCode.InvalidParams, `Failed to list tasks: ${error instanceof Error ? error.message : String(error)}`); + } + } + + /** + * Handler for tasks/cancel - cancels a running task. + */ + private async handleCancelTask(params: { taskId: string }, extra: PluginHandlerExtra): Promise { + try { + const task = await this.config.taskStore.getTask(params.taskId, extra.mcpCtx.sessionId); + + if (!task) { + throw new McpError(ErrorCode.InvalidParams, `Task not found: ${params.taskId}`); + } + + if (isTerminal(task.status)) { + throw new McpError(ErrorCode.InvalidParams, `Cannot cancel task in terminal status: ${task.status}`); + } + + await this.config.taskStore.updateTaskStatus( + params.taskId, + 'cancelled', + 'Client cancelled task execution.', + extra.mcpCtx.sessionId + ); + + await this.clearTaskQueue(params.taskId, extra.mcpCtx.sessionId); + + const cancelledTask = await this.config.taskStore.getTask(params.taskId, extra.mcpCtx.sessionId); + if (!cancelledTask) { + throw new McpError(ErrorCode.InvalidParams, `Task not found after cancellation: ${params.taskId}`); + } + + return { _meta: {}, ...cancelledTask }; + } catch (error) { + if (error instanceof McpError) { + throw error; + } + throw new McpError( + ErrorCode.InvalidRequest, + `Failed to cancel task: ${error instanceof Error ? error.message : String(error)}` + ); + } + } +} + +/** + * Factory function to create a TaskPlugin. + */ +export function createTaskPlugin(config: TaskPluginConfig): TaskPlugin { + return new TaskPlugin(config); +} diff --git a/packages/core/src/shared/timeoutManager.ts b/packages/core/src/shared/timeoutManager.ts new file mode 100644 index 000000000..e87da7367 --- /dev/null +++ b/packages/core/src/shared/timeoutManager.ts @@ -0,0 +1,172 @@ +/** + * Timeout Manager + * + * Manages request timeouts for the Protocol class. + * Extracted from Protocol to follow Single Responsibility Principle. + */ + +/** + * Information about a request's timeout state + */ +export interface TimeoutInfo { + timeoutId: ReturnType; + startTime: number; + timeout: number; + maxTotalTimeout?: number; + resetTimeoutOnProgress: boolean; + onTimeout: () => void; +} + +/** + * Options for setting up a timeout + */ +export interface TimeoutOptions { + /** + * The timeout duration in milliseconds + */ + timeout: number; + + /** + * Maximum total time allowed (optional) + */ + maxTotalTimeout?: number; + + /** + * Whether to reset the timeout when progress is received + */ + resetTimeoutOnProgress?: boolean; + + /** + * Callback to invoke when the timeout expires + */ + onTimeout: () => void; +} + +/** + * Manages request timeouts for outgoing requests. + */ +export class TimeoutManager { + private _timeoutInfo: Map = new Map(); + + /** + * Sets up a timeout for a message. + * + * @param messageId - The unique identifier for the message + * @param options - Timeout configuration options + */ + setup(messageId: number, options: TimeoutOptions): void { + const { timeout, maxTotalTimeout, resetTimeoutOnProgress, onTimeout } = options; + + this._timeoutInfo.set(messageId, { + timeoutId: setTimeout(onTimeout, timeout), + startTime: Date.now(), + timeout, + maxTotalTimeout, + resetTimeoutOnProgress: resetTimeoutOnProgress ?? false, + onTimeout + }); + } + + /** + * Resets the timeout for a message (e.g., when progress is received). + * Returns true if the timeout was reset, false if it wasn't found or + * if the max total timeout would be exceeded. + * + * @param messageId - The message ID whose timeout should be reset + * @returns true if reset succeeded, false otherwise + */ + reset(messageId: number): boolean { + const info = this._timeoutInfo.get(messageId); + if (!info || !info.resetTimeoutOnProgress) { + return false; + } + + const elapsed = Date.now() - info.startTime; + + // Check if max total timeout would be exceeded + if (info.maxTotalTimeout === undefined) { + // No max total timeout, just reset with original timeout + clearTimeout(info.timeoutId); + info.timeoutId = setTimeout(info.onTimeout, info.timeout); + } else { + const remainingTotal = info.maxTotalTimeout - elapsed; + if (remainingTotal <= 0) { + // Don't reset, let the timeout fire + return false; + } + + // Clear old timeout and set new one with the smaller of: + // - original timeout + // - remaining total time + clearTimeout(info.timeoutId); + const newTimeout = Math.min(info.timeout, remainingTotal); + info.timeoutId = setTimeout(info.onTimeout, newTimeout); + } + + return true; + } + + /** + * Cleans up the timeout for a message (e.g., when a response is received). + * + * @param messageId - The message ID whose timeout should be cleaned up + */ + cleanup(messageId: number): void { + const info = this._timeoutInfo.get(messageId); + if (info) { + clearTimeout(info.timeoutId); + this._timeoutInfo.delete(messageId); + } + } + + /** + * Gets the timeout info for a message. + * + * @param messageId - The message ID + * @returns The timeout info or undefined if not found + */ + get(messageId: number): TimeoutInfo | undefined { + return this._timeoutInfo.get(messageId); + } + + /** + * Checks if a timeout exists for a message. + * + * @param messageId - The message ID + * @returns true if a timeout exists + */ + has(messageId: number): boolean { + return this._timeoutInfo.has(messageId); + } + + /** + * Gets the elapsed time for a message's timeout. + * + * @param messageId - The message ID + * @returns The elapsed time in milliseconds, or undefined if not found + */ + getElapsed(messageId: number): number | undefined { + const info = this._timeoutInfo.get(messageId); + if (!info) { + return undefined; + } + return Date.now() - info.startTime; + } + + /** + * Clears all timeouts. + */ + clearAll(): void { + for (const info of this._timeoutInfo.values()) { + clearTimeout(info.timeoutId); + } + this._timeoutInfo.clear(); + } + + /** + * Gets the number of active timeouts. + */ + get size(): number { + return this._timeoutInfo.size; + } +} diff --git a/packages/core/src/shared/transport.ts b/packages/core/src/shared/transport.ts index 87608f124..844445c7b 100644 --- a/packages/core/src/shared/transport.ts +++ b/packages/core/src/shared/transport.ts @@ -1,5 +1,29 @@ import type { JSONRPCMessage, MessageExtraInfo, RequestId } from '../types/types.js'; +// ═══════════════════════════════════════════════════════════════════════════ +// Connection State +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Represents the current state of a transport connection. + * + * State transitions: + * - disconnected → connecting → connected + * - connected → reconnecting → connected + * - * → error (from any state on unrecoverable error) + * - * → disconnected (on close) + */ +export type ConnectionState = 'disconnected' | 'connecting' | 'authenticating' | 'connected' | 'reconnecting' | 'error'; + +/** + * Callback for connection state changes + */ +export type ConnectionStateChangeCallback = (state: ConnectionState, previousState: ConnectionState) => void; + +// ═══════════════════════════════════════════════════════════════════════════ +// Fetch Utilities +// ═══════════════════════════════════════════════════════════════════════════ + export type FetchLike = (url: string | URL, init?: RequestInit) => Promise; /** @@ -54,6 +78,11 @@ export type TransportSendOptions = { */ relatedRequestId?: RequestId; + /** + * Optional session ID for the message routing context. + */ + sessionId?: string; + /** * The resumption token used to continue long-running requests that were interrupted. * @@ -125,4 +154,23 @@ export interface Transport { * Sets the protocol version used for the connection (called when the initialize response is received). */ setProtocolVersion?: (version: string) => void; + + // ─── Connection State (optional, for transports that support it) ─── + + /** + * The current connection state. + * Optional - not all transports track state. + */ + readonly state?: ConnectionState; + + /** + * Whether the transport is currently connected. + * This is a convenience property equivalent to `state === 'connected'`. + */ + readonly isConnected?: boolean; + + /** + * Callback for when the connection state changes. + */ + onStateChange?: ConnectionStateChangeCallback; } diff --git a/packages/core/src/util/content.ts b/packages/core/src/util/content.ts new file mode 100644 index 000000000..e1ac5fa30 --- /dev/null +++ b/packages/core/src/util/content.ts @@ -0,0 +1,211 @@ +/** + * Content Formatting Helpers + * + * Utilities for working with tool call results and content types. + * Reduces boilerplate when processing mixed content types in results. + */ + +import type { + AudioContent, + BlobResourceContents, + ContentBlock, + EmbeddedResource, + ImageContent, + ResourceLink, + TextContent, + TextResourceContents +} from '../types/types.js'; + +/** + * Type guard to check if content is TextContent + */ +export function isTextContent(item: ContentBlock): item is TextContent { + return item.type === 'text'; +} + +/** + * Type guard to check if content is ImageContent + */ +export function isImageContent(item: ContentBlock): item is ImageContent { + return item.type === 'image'; +} + +/** + * Type guard to check if content is AudioContent + */ +export function isAudioContent(item: ContentBlock): item is AudioContent { + return item.type === 'audio'; +} + +/** + * Type guard to check if content is EmbeddedResource + */ +export function isEmbeddedResource(item: ContentBlock): item is EmbeddedResource { + return item.type === 'resource'; +} + +/** + * Type guard to check if content is ResourceLink + */ +export function isResourceLink(item: ContentBlock): item is ResourceLink { + return item.type === 'resource_link'; +} + +/** + * Extracts all text content from a tool result content array. + * + * @example + * ```typescript + * const result = await client.callTool('search', { query: 'hello' }); + * const texts = extractTextContent(result.content); + * console.log(texts.join('\n')); + * ``` + */ +export function extractTextContent(content: ContentBlock[]): string[] { + return content.filter(item => isTextContent(item)).map(item => item.text); +} + +/** + * Formats all text content from a tool result as a single string. + * + * @param content - The content array from a tool result + * @param separator - Separator between text items (default: newline) + * @returns Concatenated text content + * + * @example + * ```typescript + * const result = await client.callTool('search', { query: 'hello' }); + * const text = formatTextContent(result.content); + * ``` + */ +export function formatTextContent(content: ContentBlock[], separator: string = '\n'): string { + return extractTextContent(content).join(separator); +} + +/** + * Extracts all image content from a tool result content array. + */ +export function extractImageContent(content: ContentBlock[]): ImageContent[] { + return content.filter(item => isImageContent(item)); +} + +/** + * Extracts all audio content from a tool result content array. + */ +export function extractAudioContent(content: ContentBlock[]): AudioContent[] { + return content.filter(item => isAudioContent(item)); +} + +/** + * Extracts all embedded resources from a tool result content array. + */ +export function extractEmbeddedResources(content: ContentBlock[]): EmbeddedResource[] { + return content.filter(item => isEmbeddedResource(item)); +} + +/** + * Extracts all resource links from a tool result content array. + */ +export function extractResourceLinks(content: ContentBlock[]): ResourceLink[] { + return content.filter(item => isResourceLink(item)); +} + +/** + * Creates a text content item. + * + * @example + * ```typescript + * return { content: [text('Hello, world!')] }; + * ``` + */ +export function text(content: string, annotations?: TextContent['annotations']): TextContent { + return { + type: 'text', + text: content, + annotations + }; +} + +/** + * Creates an image content item from base64 data. + * + * @example + * ```typescript + * return { content: [image(base64Data, 'image/png')] }; + * ``` + */ +export function image(data: string, mimeType: string, annotations?: ImageContent['annotations']): ImageContent { + return { + type: 'image', + data, + mimeType, + annotations + }; +} + +/** + * Creates an audio content item from base64 data. + * + * @example + * ```typescript + * return { content: [audio(base64Data, 'audio/wav')] }; + * ``` + */ +export function audio(data: string, mimeType: string, annotations?: AudioContent['annotations']): AudioContent { + return { + type: 'audio', + data, + mimeType, + annotations + }; +} + +/** + * Creates an embedded resource content item. + * + * @example + * ```typescript + * return { + * content: [ + * embeddedResource({ + * uri: 'file:///path/to/file.txt', + * mimeType: 'text/plain', + * text: 'File contents' + * }) + * ] + * }; + * ``` + */ +export function embeddedResource( + resource: TextResourceContents | BlobResourceContents, + annotations?: EmbeddedResource['annotations'] +): EmbeddedResource { + return { + type: 'resource', + resource, + annotations + }; +} + +/** + * Creates a resource link content item. + * + * @example + * ```typescript + * return { + * content: [ + * resourceLink({ + * uri: 'file:///path/to/file.txt', + * mimeType: 'text/plain', + * name: 'file.txt' + * }) + * ] + * }; + * ``` + */ +export function resourceLink(link: Omit): ResourceLink { + return { + type: 'resource_link', + ...link + }; +} diff --git a/packages/middleware/express/package.json b/packages/middleware/express/package.json index 408cf446a..844b3e3ec 100644 --- a/packages/middleware/express/package.json +++ b/packages/middleware/express/package.json @@ -37,8 +37,8 @@ "build": "tsdown", "build:watch": "tsdown --watch", "prepack": "npm run build", - "lint": "eslint src/ && prettier --ignore-path ../../.prettierignore --check .", - "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../.prettierignore --write .", + "lint": "eslint src/ && prettier --ignore-path ../../../.prettierignore --check .", + "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../../.prettierignore --write .", "check": "npm run typecheck && npm run lint", "test": "vitest run", "test:watch": "vitest" diff --git a/packages/middleware/hono/package.json b/packages/middleware/hono/package.json index 3377c5fb4..afef9b02e 100644 --- a/packages/middleware/hono/package.json +++ b/packages/middleware/hono/package.json @@ -37,8 +37,8 @@ "build": "tsdown", "build:watch": "tsdown --watch", "prepack": "npm run build", - "lint": "eslint src/ && prettier --ignore-path ../../.prettierignore --check .", - "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../.prettierignore --write .", + "lint": "eslint src/ && prettier --ignore-path ../../../.prettierignore --check .", + "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../../.prettierignore --write .", "check": "npm run typecheck && npm run lint", "test": "vitest run", "test:watch": "vitest" diff --git a/packages/middleware/node/package.json b/packages/middleware/node/package.json index 766346613..5024fa55f 100644 --- a/packages/middleware/node/package.json +++ b/packages/middleware/node/package.json @@ -36,8 +36,8 @@ "build": "tsdown", "build:watch": "tsdown --watch", "prepack": "npm run build", - "lint": "eslint src/ && prettier --ignore-path ../../.prettierignore --check .", - "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../.prettierignore --write .", + "lint": "eslint src/ && prettier --ignore-path ../../../.prettierignore --check .", + "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../../.prettierignore --write .", "check": "npm run typecheck && npm run lint", "test": "vitest run", "test:watch": "vitest", diff --git a/packages/server/src/experimental/tasks/interfaces.ts b/packages/server/src/experimental/tasks/interfaces.ts index 05b679215..391b79e49 100644 --- a/packages/server/src/experimental/tasks/interfaces.ts +++ b/packages/server/src/experimental/tasks/interfaces.ts @@ -6,7 +6,6 @@ import type { AnySchema, CallToolResult, - ContextInterface, CreateTaskResult, GetTaskResult, Result, @@ -15,6 +14,7 @@ import type { ZodRawShapeCompat } from '@modelcontextprotocol/core'; +import type { ServerContextInterface } from '../../server/context.js'; import type { BaseToolCallback } from '../../server/mcp.js'; // ============================================================================ @@ -28,7 +28,7 @@ import type { BaseToolCallback } from '../../server/mcp.js'; export type CreateTaskRequestHandler< SendResultT extends Result, Args extends undefined | ZodRawShapeCompat | AnySchema = undefined -> = BaseToolCallback, Args>; +> = BaseToolCallback, Args>; /** * Handler for task operations (get, getResult). @@ -37,7 +37,7 @@ export type CreateTaskRequestHandler< export type TaskRequestHandler< SendResultT extends Result, Args extends undefined | ZodRawShapeCompat | AnySchema = undefined -> = BaseToolCallback, Args>; +> = BaseToolCallback, Args>; /** * Interface for task-based tool handlers. diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index b99333e2f..51f3e0ace 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -1,8 +1,12 @@ +export * from './server/builder.js'; export * from './server/completable.js'; export * from './server/context.js'; export * from './server/mcp.js'; +export * from './server/middleware.js'; export * from './server/middleware/hostHeaderValidation.js'; +export * from './server/registries/index.js'; export * from './server/server.js'; +export * from './server/sessions.js'; export * from './server/stdio.js'; export * from './server/streamableHttp.js'; diff --git a/packages/server/src/server/builder.ts b/packages/server/src/server/builder.ts new file mode 100644 index 000000000..380dee527 --- /dev/null +++ b/packages/server/src/server/builder.ts @@ -0,0 +1,427 @@ +/** + * McpServer Builder + * + * Provides a fluent API for configuring and creating McpServer instances. + * The builder is an additive convenience layer - the existing constructor + * API remains available for users who prefer it. + * + * @example + * ```typescript + * const server = McpServer.builder() + * .name('my-server') + * .version('1.0.0') + * .useMiddleware(loggingMiddleware) + * .tool('greet', { inputSchema: { name: z.string() } }, handler) + * .build(); + * ``` + */ + +import type { ToolAnnotations, ToolExecution, ZodRawShapeCompat } from '@modelcontextprotocol/core'; +import { objectFromShape } from '@modelcontextprotocol/core'; + +import type { McpServer, PromptCallback, ReadResourceCallback, ResourceMetadata, ToolCallback } from './mcp.js'; +import type { PromptMiddleware, ResourceMiddleware, ToolMiddleware, UniversalMiddleware } from './middleware.js'; +import { PromptRegistry } from './registries/promptRegistry.js'; +import { ResourceRegistry } from './registries/resourceRegistry.js'; +import { ToolRegistry } from './registries/toolRegistry.js'; +import type { ServerOptions as BaseServerOptions } from './server.js'; + +// ZodRawShape for backward compatibility +type ZodRawShape = ZodRawShapeCompat; + +/** + * Extended server options including builder-specific options + */ +export interface McpServerBuilderOptions extends BaseServerOptions { + /** Server name */ + name?: string; + /** Server version */ + version?: string; +} + +/** + * Error handler type for application errors + */ +export type OnErrorHandler = (error: Error, ctx: ErrorContext) => OnErrorReturn | void | Promise; + +/** + * Error handler type for protocol errors + */ +export type OnProtocolErrorHandler = ( + error: Error, + ctx: ErrorContext +) => OnProtocolErrorReturn | void | Promise; + +/** + * Return type for onError handler + */ +export type OnErrorReturn = string | { code?: number; message?: string; data?: unknown } | Error; + +/** + * Return type for onProtocolError handler (code cannot be changed) + */ +export type OnProtocolErrorReturn = string | { message?: string; data?: unknown }; + +/** + * Context provided to error handlers + */ +export interface ErrorContext { + type: 'tool' | 'resource' | 'prompt' | 'protocol'; + name?: string; + method: string; + requestId: string; +} + +/** + * Fluent builder for McpServer instances. + * + * Provides a declarative, chainable API for configuring servers. + * All configuration is collected and applied when build() is called. + */ +export class McpServerBuilder { + private _name?: string; + private _version?: string; + private _options: McpServerBuilderOptions = {}; + + // Global middleware + private _universalMiddleware: UniversalMiddleware[] = []; + private _toolMiddleware: ToolMiddleware[] = []; + private _resourceMiddleware: ResourceMiddleware[] = []; + private _promptMiddleware: PromptMiddleware[] = []; + + // Registries (created without callbacks - McpServer will bind them later) + private _toolRegistry = new ToolRegistry(); + private _resourceRegistry = new ResourceRegistry(); + private _promptRegistry = new PromptRegistry(); + + // Per-item middleware (keyed by name/uri) + private _perToolMiddleware = new Map(); + private _perResourceMiddleware = new Map(); + private _perPromptMiddleware = new Map(); + + // Error handlers + private _onError?: OnErrorHandler; + private _onProtocolError?: OnProtocolErrorHandler; + + /** + * Sets the server name. + */ + name(name: string): this { + this._name = name; + return this; + } + + /** + * Sets the server version. + */ + version(version: string): this { + this._version = version; + return this; + } + + /** + * Sets server options. + */ + options(options: McpServerBuilderOptions): this { + this._options = { ...this._options, ...options }; + return this; + } + + /** + * Adds universal middleware that runs for all request types. + */ + useMiddleware(middleware: UniversalMiddleware): this { + this._universalMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware specifically for tool calls. + */ + useToolMiddleware(middleware: ToolMiddleware): this { + this._toolMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware specifically for resource reads. + */ + useResourceMiddleware(middleware: ResourceMiddleware): this { + this._resourceMiddleware.push(middleware); + return this; + } + + /** + * Adds middleware specifically for prompt requests. + */ + usePromptMiddleware(middleware: PromptMiddleware): this { + this._promptMiddleware.push(middleware); + return this; + } + + /** + * Registers a tool with the server. + * + * @example + * ```typescript + * .tool('greet', { + * description: 'Greet a user', + * inputSchema: { name: z.string() } + * }, async ({ name }) => { + * return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; + * }) + * ``` + */ + tool( + name: string, + config: { + title?: string; + description?: string; + inputSchema?: InputArgs; + outputSchema?: ZodRawShape; + middleware?: ToolMiddleware; + annotations?: ToolAnnotations; + execution?: ToolExecution; + _meta?: Record; + }, + handler: ToolCallback + ): this { + this._toolRegistry.register({ + name, + title: config.title, + description: config.description, + inputSchema: config.inputSchema ? objectFromShape(config.inputSchema) : undefined, + outputSchema: config.outputSchema ? objectFromShape(config.outputSchema) : undefined, + annotations: config.annotations, + execution: config.execution, + _meta: config._meta, + handler: handler as ToolCallback + }); + + // Store per-tool middleware if provided + if (config.middleware) { + this._perToolMiddleware.set(name, config.middleware); + } + + return this; + } + + /** + * Registers a resource with the server. + * + * @example + * ```typescript + * .resource('config', 'file:///config', { + * description: 'Configuration file' + * }, async (uri) => { + * return { contents: [{ uri, mimeType: 'application/json', text: '{}' }] }; + * }) + * ``` + */ + resource( + name: string, + uri: string, + config: { + title?: string; + description?: string; + mimeType?: string; + metadata?: ResourceMetadata; + middleware?: ResourceMiddleware; + }, + readCallback: ReadResourceCallback + ): this { + this._resourceRegistry.register({ + name, + uri, + title: config.title, + description: config.description, + mimeType: config.mimeType, + metadata: config.metadata, + readCallback + }); + + // Store per-resource middleware if provided + if (config.middleware) { + this._perResourceMiddleware.set(uri, config.middleware); + } + + return this; + } + + /** + * Registers a prompt with the server. + * + * @example + * ```typescript + * .prompt('summarize', { + * description: 'Summarize text', + * argsSchema: { text: z.string() } + * }, async ({ text }) => { + * return { messages: [{ role: 'user', content: { type: 'text', text } }] }; + * }) + * ``` + */ + prompt( + name: string, + config: { + title?: string; + description?: string; + argsSchema?: Args; + middleware?: PromptMiddleware; + }, + callback: PromptCallback + ): this { + this._promptRegistry.register({ + name, + title: config.title, + description: config.description, + argsSchema: config.argsSchema, + callback: callback as PromptCallback + }); + + // Store per-prompt middleware if provided + if (config.middleware) { + this._perPromptMiddleware.set(name, config.middleware); + } + + return this; + } + + /** + * Sets the application error handler. + * Called when a handler throws an error. + */ + onError(handler: OnErrorHandler): this { + this._onError = handler; + return this; + } + + /** + * Sets the protocol error handler. + * Called for protocol-level errors (parse, method not found, etc.) + */ + onProtocolError(handler: OnProtocolErrorHandler): this { + this._onProtocolError = handler; + return this; + } + + /** + * Gets the collected configuration (for debugging/testing). + */ + getConfig(): { + name?: string; + version?: string; + options: McpServerBuilderOptions; + toolCount: number; + resourceCount: number; + promptCount: number; + middlewareCount: number; + } { + return { + name: this._name, + version: this._version, + options: this._options, + toolCount: this._toolRegistry.size, + resourceCount: this._resourceRegistry.size, + promptCount: this._promptRegistry.size, + middlewareCount: + this._universalMiddleware.length + + this._toolMiddleware.length + + this._resourceMiddleware.length + + this._promptMiddleware.length + }; + } + + /** + * Builds and returns the configured McpServer instance. + */ + build(): McpServer { + if (!this._name) { + throw new Error('Server name is required. Use .name() to set it.'); + } + if (!this._version) { + throw new Error('Server version is required. Use .version() to set it.'); + } + + const result: BuilderResult = { + serverInfo: { + name: this._name, + version: this._version + }, + options: this._options, + middleware: { + universal: this._universalMiddleware, + tool: this._toolMiddleware, + resource: this._resourceMiddleware, + prompt: this._promptMiddleware + }, + registries: { + tools: this._toolRegistry, + resources: this._resourceRegistry, + prompts: this._promptRegistry + }, + perItemMiddleware: { + tools: this._perToolMiddleware, + resources: this._perResourceMiddleware, + prompts: this._perPromptMiddleware + }, + errorHandlers: { + onError: this._onError, + onProtocolError: this._onProtocolError + } + }; + + // Dynamically import McpServer to create the instance + // eslint-disable-next-line @typescript-eslint/no-require-imports + const { McpServer: McpServerClass } = require('./mcp.js'); + return McpServerClass.fromBuilderResult(result); + } +} + +/** + * Result of building the server configuration. + * Used to create the actual McpServer instance. + */ +export interface BuilderResult { + serverInfo: { + name: string; + version: string; + }; + options: McpServerBuilderOptions; + middleware: { + universal: UniversalMiddleware[]; + tool: ToolMiddleware[]; + resource: ResourceMiddleware[]; + prompt: PromptMiddleware[]; + }; + registries: { + tools: ToolRegistry; + resources: ResourceRegistry; + prompts: PromptRegistry; + }; + perItemMiddleware: { + tools: Map; + resources: Map; + prompts: Map; + }; + errorHandlers: { + onError?: OnErrorHandler; + onProtocolError?: OnProtocolErrorHandler; + }; +} + +/** + * Creates a new McpServerBuilder instance. + * + * @example + * ```typescript + * const server = createServerBuilder() + * .name('my-server') + * .version('1.0.0') + * .tool('greet', { inputSchema: { name: z.string() } }, handler) + * .build(); + * ``` + */ +export function createServerBuilder(): McpServerBuilder { + return new McpServerBuilder(); +} diff --git a/packages/server/src/server/context.ts b/packages/server/src/server/context.ts index 81e693f68..5913cff30 100644 --- a/packages/server/src/server/context.ts +++ b/packages/server/src/server/context.ts @@ -203,7 +203,7 @@ export class ServerContext< request: JSONRPCRequest; mcpContext: McpContext; requestCtx: ServerRequestContext; - task: TaskContext | undefined; + task?: TaskContext; }) { super({ request: args.request, diff --git a/packages/server/src/server/mcp.ts b/packages/server/src/server/mcp.ts index 35298c791..be5b75831 100644 --- a/packages/server/src/server/mcp.ts +++ b/packages/server/src/server/mcp.ts @@ -7,16 +7,15 @@ import type { CompleteRequestPrompt, CompleteRequestResourceTemplate, CompleteResult, - ContextInterface, CreateTaskResult, + ErrorInterceptionContext, + ErrorInterceptionResult, GetPromptResult, Implementation, ListPromptsResult, ListResourcesResult, ListToolsResult, LoggingMessageNotification, - Prompt, - PromptArgument, PromptReference, ReadResourceResult, Resource, @@ -26,7 +25,6 @@ import type { ServerNotification, ServerRequest, ShapeOutput, - Tool, ToolAnnotations, ToolExecution, Transport, @@ -43,8 +41,6 @@ import { getObjectShape, getParseErrorMessage, GetPromptRequestSchema, - getSchemaDescription, - isSchemaOptional, ListPromptsRequestSchema, ListResourcesRequestSchema, ListResourceTemplatesRequestSchema, @@ -54,18 +50,55 @@ import { objectFromShape, ReadResourceRequestSchema, safeParseAsync, - toJsonSchemaCompat, - UriTemplate, - validateAndWarnToolName + UriTemplate } from '@modelcontextprotocol/core'; import { ZodOptional } from 'zod'; import type { ToolTaskHandler } from '../experimental/tasks/interfaces.js'; import { ExperimentalMcpServerTasks } from '../experimental/tasks/mcpServer.js'; +import type { + BuilderResult, + ErrorContext, + OnErrorHandler, + OnErrorReturn, + OnProtocolErrorHandler, + OnProtocolErrorReturn +} from './builder.js'; +import { McpServerBuilder } from './builder.js'; import { getCompleter, isCompletable } from './completable.js'; +import type { ServerContextInterface } from './context.js'; +import type { + PromptContext, + PromptMiddleware, + ResourceContext, + ResourceMiddleware, + ToolContext, + ToolMiddleware, + UniversalMiddleware +} from './middleware.js'; +import { MiddlewareManager } from './middleware.js'; +import type { RegisteredPromptEntity } from './registries/promptRegistry.js'; +import { PromptRegistry } from './registries/promptRegistry.js'; +import type { RegisteredResourceEntity, RegisteredResourceTemplateEntity } from './registries/resourceRegistry.js'; +import { ResourceRegistry, ResourceTemplateRegistry } from './registries/resourceRegistry.js'; +import type { RegisteredToolEntity } from './registries/toolRegistry.js'; +import { ToolRegistry } from './registries/toolRegistry.js'; import type { ServerOptions } from './server.js'; import { Server } from './server.js'; +/** + * Internal options for McpServer that can include pre-created registries. + * Used by fromBuilderResult to pass registries from the builder. + */ +interface InternalMcpServerOptions extends ServerOptions { + /** Pre-created tool registry (callbacks will be bound by McpServer) */ + _toolRegistry?: ToolRegistry; + /** Pre-created resource registry (callbacks will be bound by McpServer) */ + _resourceRegistry?: ResourceRegistry; + /** Pre-created prompt registry (callbacks will be bound by McpServer) */ + _promptRegistry?: PromptRegistry; +} + /** * High-level MCP server that provides a simpler API for working with resources, tools, and prompts. * For advanced usage (like sending notifications or setting custom request handlers), use the underlying @@ -77,16 +110,200 @@ export class McpServer { */ public readonly server: Server; - private _registeredResources: { [uri: string]: RegisteredResource } = {}; - private _registeredResourceTemplates: { - [name: string]: RegisteredResourceTemplate; - } = {}; - private _registeredTools: { [name: string]: RegisteredTool } = {}; - private _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; + private readonly _toolRegistry: ToolRegistry; + private readonly _resourceRegistry: ResourceRegistry; + private readonly _resourceTemplateRegistry: ResourceTemplateRegistry; + private readonly _promptRegistry: PromptRegistry; + private readonly _middleware: MiddlewareManager; private _experimental?: { tasks: ExperimentalMcpServerTasks }; + // Error handlers (single callback pattern, not event-based) + private _onErrorHandler?: OnErrorHandler; + private _onProtocolErrorHandler?: OnProtocolErrorHandler; + constructor(serverInfo: Implementation, options?: ServerOptions) { + const internalOptions = options as InternalMcpServerOptions | undefined; this.server = new Server(serverInfo, options); + + // Use pre-created registries if provided, otherwise create new ones + // Either way, bind the notification callbacks to this server instance + this._toolRegistry = internalOptions?._toolRegistry ?? new ToolRegistry(); + this._toolRegistry.setNotifyCallback(() => this.sendToolListChanged()); + + this._resourceRegistry = internalOptions?._resourceRegistry ?? new ResourceRegistry(); + this._resourceRegistry.setNotifyCallback(() => this.sendResourceListChanged()); + + // Resource template registry is always created fresh (not passed from builder) + this._resourceTemplateRegistry = new ResourceTemplateRegistry(); + this._resourceTemplateRegistry.setNotifyCallback(() => this.sendResourceListChanged()); + + this._promptRegistry = internalOptions?._promptRegistry ?? new PromptRegistry(); + this._promptRegistry.setNotifyCallback(() => this.sendPromptListChanged()); + + // Initialize middleware manager + this._middleware = new MiddlewareManager(); + + // If registries were pre-populated, set up request handlers + if (this._toolRegistry.size > 0) { + this.setToolRequestHandlers(); + } + if (this._resourceRegistry.size > 0) { + this.setResourceRequestHandlers(); + } + if (this._promptRegistry.size > 0) { + this.setPromptRequestHandlers(); + } + } + + /** + * Gets the middleware manager for advanced middleware configuration. + */ + get middleware(): MiddlewareManager { + return this._middleware; + } + + /** + * Registers universal middleware that runs for all request types (tools, resources, prompts). + * + * @param middleware - The middleware function to register + * @returns This McpServer instance for chaining + */ + useMiddleware(middleware: UniversalMiddleware): this { + this._middleware.useMiddleware(middleware); + return this; + } + + /** + * Registers middleware specifically for tool calls. + * + * @param middleware - The tool middleware function to register + * @returns This McpServer instance for chaining + */ + useToolMiddleware(middleware: ToolMiddleware): this { + this._middleware.useToolMiddleware(middleware); + return this; + } + + /** + * Registers middleware specifically for resource reads. + * + * @param middleware - The resource middleware function to register + * @returns This McpServer instance for chaining + */ + useResourceMiddleware(middleware: ResourceMiddleware): this { + this._middleware.useResourceMiddleware(middleware); + return this; + } + + /** + * Registers middleware specifically for prompt requests. + * + * @param middleware - The prompt middleware function to register + * @returns This McpServer instance for chaining + */ + usePromptMiddleware(middleware: PromptMiddleware): this { + this._middleware.usePromptMiddleware(middleware); + return this; + } + + /** + * Gets the tool registry for advanced tool management. + */ + get tools(): ToolRegistry { + return this._toolRegistry; + } + + /** + * Gets the resource registry for advanced resource management. + */ + get resources(): ResourceRegistry { + return this._resourceRegistry; + } + + /** + * Gets the resource template registry for advanced template management. + */ + get resourceTemplates(): ResourceTemplateRegistry { + return this._resourceTemplateRegistry; + } + + /** + * Gets the prompt registry for advanced prompt management. + */ + get prompts(): PromptRegistry { + return this._promptRegistry; + } + + /** + * Creates a new McpServerBuilder for fluent configuration. + * + * @example + * ```typescript + * const server = McpServer.builder() + * .name('my-server') + * .version('1.0.0') + * .tool('greet', { name: z.string() }, async ({ name }) => ({ + * content: [{ type: 'text', text: `Hello, ${name}!` }] + * })) + * .build(); + * ``` + */ + static builder(): McpServerBuilder { + return new McpServerBuilder(); + } + + /** + * Creates an McpServer from a BuilderResult configuration. + * + * @param result - The result from McpServerBuilder.build() + * @returns A configured McpServer instance + */ + static fromBuilderResult(result: BuilderResult): McpServer { + // Create server with pre-populated registries from the builder + // The constructor will bind notification callbacks to the registries + const internalOptions: InternalMcpServerOptions = { + ...result.options, + _toolRegistry: result.registries.tools, + _resourceRegistry: result.registries.resources, + _promptRegistry: result.registries.prompts + }; + + const server = new McpServer(result.serverInfo, internalOptions); + + // Wire up error handlers + if (result.errorHandlers.onError) { + server.onError(result.errorHandlers.onError); + } + if (result.errorHandlers.onProtocolError) { + server.onProtocolError(result.errorHandlers.onProtocolError); + } + + // Apply global middleware from builder + for (const middleware of result.middleware.universal) { + server.useMiddleware(middleware); + } + for (const middleware of result.middleware.tool) { + server.useToolMiddleware(middleware); + } + for (const middleware of result.middleware.resource) { + server.useResourceMiddleware(middleware); + } + for (const middleware of result.middleware.prompt) { + server.usePromptMiddleware(middleware); + } + + // Apply per-item middleware + for (const [name, middleware] of result.perItemMiddleware.tools) { + server._middleware.useToolMiddlewareFor(name, middleware); + } + for (const [uri, middleware] of result.perItemMiddleware.resources) { + server._middleware.useResourceMiddlewareFor(uri, middleware); + } + for (const [name, middleware] of result.perItemMiddleware.prompts) { + server._middleware.usePromptMiddlewareFor(name, middleware); + } + + return server; } /** @@ -140,45 +357,13 @@ export class McpServer { this.server.setRequestHandler( ListToolsRequestSchema, (): ListToolsResult => ({ - tools: Object.entries(this._registeredTools) - .filter(([, tool]) => tool.enabled) - .map(([name, tool]): Tool => { - const toolDefinition: Tool = { - name, - title: tool.title, - description: tool.description, - inputSchema: (() => { - const obj = normalizeObjectSchema(tool.inputSchema); - return obj - ? (toJsonSchemaCompat(obj, { - strictUnions: true, - pipeStrategy: 'input' - }) as Tool['inputSchema']) - : EMPTY_OBJECT_JSON_SCHEMA; - })(), - annotations: tool.annotations, - execution: tool.execution, - _meta: tool._meta - }; - - if (tool.outputSchema) { - const obj = normalizeObjectSchema(tool.outputSchema); - if (obj) { - toolDefinition.outputSchema = toJsonSchemaCompat(obj, { - strictUnions: true, - pipeStrategy: 'output' - }) as Tool['outputSchema']; - } - } - - return toolDefinition; - }) + tools: this._toolRegistry.getProtocolTools() }) ); this.server.setRequestHandler(CallToolRequestSchema, async (request, ctx): Promise => { try { - const tool = this._registeredTools[request.params.name]; + const tool = this._toolRegistry.getTool(request.params.name); if (!tool) { throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} not found`); } @@ -211,17 +396,37 @@ export class McpServer { return await this.handleAutomaticTaskPolling(tool, request, ctx); } - // Normal execution path - const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); - const result = await this.executeToolHandler(tool, args, ctx); + // Build middleware context + const middlewareCtx: ToolContext = { + type: 'tool', + requestId: String(ctx.mcpCtx.requestId), + authInfo: ctx.requestCtx.authInfo, + signal: ctx.requestCtx.signal, + name: request.params.name, + args: request.params.arguments + }; + + // Execute with middleware (including per-tool middleware if registered) + const perToolMiddleware = this._middleware.getToolMiddlewareFor(request.params.name); + const result = await this._middleware.executeToolMiddleware( + middlewareCtx, + async (mwCtx, modifiedArgs) => { + const argsToUse = modifiedArgs ?? mwCtx.args; + const validatedArgs = await this.validateToolInput(tool, argsToUse, request.params.name); + const handlerResult = await this.executeToolHandler(tool, validatedArgs, ctx); + + // Return CreateTaskResult immediately for task requests + if (isTaskRequest) { + return handlerResult as CallToolResult; + } - // Return CreateTaskResult immediately for task requests - if (isTaskRequest) { - return result; - } + // Validate output schema for non-task requests + await this.validateToolOutput(tool, handlerResult, request.params.name); + return handlerResult as CallToolResult; + }, + perToolMiddleware + ); - // Validate output schema for non-task requests - await this.validateToolOutput(tool, result, request.params.name); return result; } catch (error) { if (error instanceof McpError && error.code === ErrorCode.UrlElicitationRequired) { @@ -256,7 +461,7 @@ export class McpServer { * Validates tool input arguments against the tool's input schema. */ private async validateToolInput< - Tool extends RegisteredTool, + Tool extends RegisteredToolEntity, Args extends Tool['inputSchema'] extends infer InputSchema ? InputSchema extends AnySchema ? SchemaOutput @@ -284,7 +489,11 @@ export class McpServer { /** * Validates tool output against the tool's output schema. */ - private async validateToolOutput(tool: RegisteredTool, result: CallToolResult | CreateTaskResult, toolName: string): Promise { + private async validateToolOutput( + tool: RegisteredToolEntity, + result: CallToolResult | CreateTaskResult, + toolName: string + ): Promise { if (!tool.outputSchema) { return; } @@ -322,9 +531,9 @@ export class McpServer { * Executes a tool handler (either regular or task-based). */ private async executeToolHandler( - tool: RegisteredTool, + tool: RegisteredToolEntity, args: unknown, - ctx: ContextInterface + ctx: ServerContextInterface ): Promise { const handler = tool.handler as AnyToolHandler; const isTaskHandler = 'createTask' in handler; @@ -361,9 +570,9 @@ export class McpServer { * Handles automatic task polling for tools with taskSupport 'optional'. */ private async handleAutomaticTaskPolling( - tool: RegisteredTool, + tool: RegisteredToolEntity, request: RequestT, - ctx: ContextInterface + ctx: ServerContextInterface ): Promise { if (!ctx.taskCtx?.store) { throw new Error('No task store provided for task-capable tool.'); @@ -428,7 +637,7 @@ export class McpServer { } private async handlePromptCompletion(request: CompleteRequestPrompt, ref: PromptReference): Promise { - const prompt = this._registeredPrompts[ref.name]; + const prompt = this._promptRegistry.getPrompt(ref.name); if (!prompt) { throw new McpError(ErrorCode.InvalidParams, `Prompt ${ref.name} not found`); } @@ -459,10 +668,10 @@ export class McpServer { request: CompleteRequestResourceTemplate, ref: ResourceTemplateReference ): Promise { - const template = Object.values(this._registeredResourceTemplates).find(t => t.resourceTemplate.uriTemplate.toString() === ref.uri); + const template = this._resourceTemplateRegistry.values().find(t => t.template.uriTemplate.toString() === ref.uri); if (!template) { - if (this._registeredResources[ref.uri]) { + if (this._resourceRegistry.getResource(ref.uri)) { // Attempting to autocomplete a fixed resource URI is not an error in the spec (but probably should be). return EMPTY_COMPLETION_RESULT; } @@ -470,7 +679,7 @@ export class McpServer { throw new McpError(ErrorCode.InvalidParams, `Resource template ${request.params.ref.uri} not found`); } - const completer = template.resourceTemplate.completeCallback(request.params.argument.name); + const completer = template.template.completeCallback(request.params.argument.name); if (!completer) { return EMPTY_COMPLETION_RESULT; } @@ -497,21 +706,15 @@ export class McpServer { }); this.server.setRequestHandler(ListResourcesRequestSchema, async (request, ctx) => { - const resources = Object.entries(this._registeredResources) - .filter(([_, resource]) => resource.enabled) - .map(([uri, resource]) => ({ - uri, - name: resource.name, - ...resource.metadata - })); + const resources = this._resourceRegistry.getProtocolResources(); const templateResources: Resource[] = []; - for (const template of Object.values(this._registeredResourceTemplates)) { - if (!template.resourceTemplate.listCallback) { + for (const template of this._resourceTemplateRegistry.getEnabled()) { + if (!template.template.listCallback) { continue; } - const result = await template.resourceTemplate.listCallback(ctx); + const result = await template.template.listCallback(ctx); for (const resource of result.resources) { templateResources.push({ ...template.metadata, @@ -524,34 +727,57 @@ export class McpServer { return { resources: [...resources, ...templateResources] }; }); - this.server.setRequestHandler(ListResourceTemplatesRequestSchema, async () => { - const resourceTemplates = Object.entries(this._registeredResourceTemplates).map(([name, template]) => ({ - name, - uriTemplate: template.resourceTemplate.uriTemplate.toString(), - ...template.metadata - })); - - return { resourceTemplates }; - }); + this.server.setRequestHandler(ListResourceTemplatesRequestSchema, async () => ({ + resourceTemplates: this._resourceTemplateRegistry.getProtocolResourceTemplates() + })); this.server.setRequestHandler(ReadResourceRequestSchema, async (request, ctx) => { const uri = new URL(request.params.uri); // First check for exact resource match - const resource = this._registeredResources[uri.toString()]; + const resource = this._resourceRegistry.getResource(uri.toString()); if (resource) { if (!resource.enabled) { throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} disabled`); } - return resource.readCallback(uri, ctx); + + // Build middleware context + const middlewareCtx: ResourceContext = { + type: 'resource', + requestId: String(ctx.mcpCtx.requestId), + authInfo: ctx.requestCtx.authInfo, + signal: ctx.requestCtx.signal, + uri: uri.toString() + }; + + // Execute with middleware (including per-resource middleware if registered) + const perResourceMiddleware = this._middleware.getResourceMiddlewareFor(uri.toString()); + return this._middleware.executeResourceMiddleware( + middlewareCtx, + async (mwCtx, modifiedUri) => { + const uriToUse = modifiedUri ? new URL(modifiedUri) : uri; + return resource.readCallback(uriToUse, ctx); + }, + perResourceMiddleware + ); } // Then check templates - for (const template of Object.values(this._registeredResourceTemplates)) { - const variables = template.resourceTemplate.uriTemplate.match(uri.toString()); - if (variables) { - return template.readCallback(uri, variables, ctx); - } + const match = this._resourceTemplateRegistry.findMatchingTemplate(uri.toString()); + if (match) { + // Build middleware context for template + const middlewareCtx: ResourceContext = { + type: 'resource', + requestId: String(ctx.mcpCtx.requestId), + authInfo: ctx.requestCtx.authInfo, + signal: ctx.requestCtx.signal, + uri: uri.toString() + }; + + // Execute with middleware (templates don't have per-item middleware from builder) + return this._middleware.executeResourceMiddleware(middlewareCtx, async () => { + return match.template.readCallback(uri, match.variables, ctx); + }); } throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} not found`); @@ -579,21 +805,12 @@ export class McpServer { this.server.setRequestHandler( ListPromptsRequestSchema, (): ListPromptsResult => ({ - prompts: Object.entries(this._registeredPrompts) - .filter(([, prompt]) => prompt.enabled) - .map(([name, prompt]): Prompt => { - return { - name, - title: prompt.title, - description: prompt.description, - arguments: prompt.argsSchema ? promptArgumentsFromSchema(prompt.argsSchema) : undefined - }; - }) + prompts: this._promptRegistry.getProtocolPrompts() }) ); this.server.setRequestHandler(GetPromptRequestSchema, async (request, ctx): Promise => { - const prompt = this._registeredPrompts[request.params.name]; + const prompt = this._promptRegistry.getPrompt(request.params.name); if (!prompt) { throw new McpError(ErrorCode.InvalidParams, `Prompt ${request.params.name} not found`); } @@ -602,23 +819,43 @@ export class McpServer { throw new McpError(ErrorCode.InvalidParams, `Prompt ${request.params.name} disabled`); } - if (prompt.argsSchema) { - const argsObj = normalizeObjectSchema(prompt.argsSchema) as AnyObjectSchema; - const parseResult = await safeParseAsync(argsObj, request.params.arguments); - if (!parseResult.success) { - const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; - const errorMessage = getParseErrorMessage(error); - throw new McpError(ErrorCode.InvalidParams, `Invalid arguments for prompt ${request.params.name}: ${errorMessage}`); - } + // Build middleware context + const middlewareCtx: PromptContext = { + type: 'prompt', + requestId: String(ctx.mcpCtx.requestId), + authInfo: ctx.requestCtx.authInfo, + signal: ctx.requestCtx.signal, + name: request.params.name, + args: request.params.arguments + }; - const args = parseResult.data; - const cb = prompt.callback as PromptCallback; - return await Promise.resolve(cb(args, ctx)); - } else { - const cb = prompt.callback as PromptCallback; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((cb as any)(ctx)); - } + // Execute with middleware (including per-prompt middleware if registered) + const perPromptMiddleware = this._middleware.getPromptMiddlewareFor(request.params.name); + return this._middleware.executePromptMiddleware( + middlewareCtx, + async (mwCtx, modifiedArgs) => { + const argsToUse = modifiedArgs ?? mwCtx.args; + + if (prompt.argsSchema) { + const argsObj = normalizeObjectSchema(prompt.argsSchema) as AnyObjectSchema; + const parseResult = await safeParseAsync(argsObj, argsToUse); + if (!parseResult.success) { + const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; + const errorMessage = getParseErrorMessage(error); + throw new McpError(ErrorCode.InvalidParams, `Invalid arguments for prompt ${request.params.name}: ${errorMessage}`); + } + + const args = parseResult.data; + const cb = prompt.callback as PromptCallback; + return await Promise.resolve(cb(args, ctx)); + } else { + const cb = prompt.callback as PromptCallback; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return await Promise.resolve((cb as any)(ctx)); + } + }, + perPromptMiddleware + ); }); this._promptHandlersInitialized = true; @@ -628,225 +865,59 @@ export class McpServer { * Registers a resource with a config object and callback. * For static resources, use a URI string. For dynamic resources, use a ResourceTemplate. */ - registerResource(name: string, uriOrTemplate: string, config: ResourceMetadata, readCallback: ReadResourceCallback): RegisteredResource; + registerResource( + name: string, + uriOrTemplate: string, + config: ResourceMetadata, + readCallback: ReadResourceCallback + ): RegisteredResourceEntity; registerResource( name: string, uriOrTemplate: ResourceTemplate, config: ResourceMetadata, readCallback: ReadResourceTemplateCallback - ): RegisteredResourceTemplate; + ): RegisteredResourceTemplateEntity; registerResource( name: string, uriOrTemplate: string | ResourceTemplate, config: ResourceMetadata, readCallback: ReadResourceCallback | ReadResourceTemplateCallback - ): RegisteredResource | RegisteredResourceTemplate { + ): RegisteredResourceEntity | RegisteredResourceTemplateEntity { if (typeof uriOrTemplate === 'string') { - if (this._registeredResources[uriOrTemplate]) { - throw new Error(`Resource ${uriOrTemplate} is already registered`); - } - - const registeredResource = this._createRegisteredResource( + const registeredResource = this._resourceRegistry.register({ name, - (config as BaseMetadata).title, - uriOrTemplate, - config, - readCallback as ReadResourceCallback - ); + uri: uriOrTemplate, + title: (config as BaseMetadata).title, + description: config.description, + mimeType: config.mimeType, + metadata: config, + readCallback: readCallback as ReadResourceCallback + }); this.setResourceRequestHandlers(); - this.sendResourceListChanged(); return registeredResource; } else { - if (this._registeredResourceTemplates[name]) { - throw new Error(`Resource template ${name} is already registered`); - } - - const registeredResourceTemplate = this._createRegisteredResourceTemplate( + const registeredResourceTemplate = this._resourceTemplateRegistry.register({ name, - (config as BaseMetadata).title, - uriOrTemplate, - config, - readCallback as ReadResourceTemplateCallback - ); + template: uriOrTemplate, + title: (config as BaseMetadata).title, + description: config.description, + mimeType: config.mimeType, + metadata: config, + readCallback: readCallback as ReadResourceTemplateCallback + }); this.setResourceRequestHandlers(); - this.sendResourceListChanged(); - return registeredResourceTemplate; - } - } - - private _createRegisteredResource( - name: string, - title: string | undefined, - uri: string, - metadata: ResourceMetadata | undefined, - readCallback: ReadResourceCallback - ): RegisteredResource { - const registeredResource: RegisteredResource = { - name, - title, - metadata, - readCallback, - enabled: true, - disable: () => registeredResource.update({ enabled: false }), - enable: () => registeredResource.update({ enabled: true }), - remove: () => registeredResource.update({ uri: null }), - update: updates => { - if (updates.uri !== undefined && updates.uri !== uri) { - delete this._registeredResources[uri]; - if (updates.uri) this._registeredResources[updates.uri] = registeredResource; - } - if (updates.name !== undefined) registeredResource.name = updates.name; - if (updates.title !== undefined) registeredResource.title = updates.title; - if (updates.metadata !== undefined) registeredResource.metadata = updates.metadata; - if (updates.callback !== undefined) registeredResource.readCallback = updates.callback; - if (updates.enabled !== undefined) registeredResource.enabled = updates.enabled; - this.sendResourceListChanged(); - } - }; - this._registeredResources[uri] = registeredResource; - return registeredResource; - } - - private _createRegisteredResourceTemplate( - name: string, - title: string | undefined, - template: ResourceTemplate, - metadata: ResourceMetadata | undefined, - readCallback: ReadResourceTemplateCallback - ): RegisteredResourceTemplate { - const registeredResourceTemplate: RegisteredResourceTemplate = { - resourceTemplate: template, - title, - metadata, - readCallback, - enabled: true, - disable: () => registeredResourceTemplate.update({ enabled: false }), - enable: () => registeredResourceTemplate.update({ enabled: true }), - remove: () => registeredResourceTemplate.update({ name: null }), - update: updates => { - if (updates.name !== undefined && updates.name !== name) { - delete this._registeredResourceTemplates[name]; - if (updates.name) this._registeredResourceTemplates[updates.name] = registeredResourceTemplate; - } - if (updates.title !== undefined) registeredResourceTemplate.title = updates.title; - if (updates.template !== undefined) registeredResourceTemplate.resourceTemplate = updates.template; - if (updates.metadata !== undefined) registeredResourceTemplate.metadata = updates.metadata; - if (updates.callback !== undefined) registeredResourceTemplate.readCallback = updates.callback; - if (updates.enabled !== undefined) registeredResourceTemplate.enabled = updates.enabled; - this.sendResourceListChanged(); - } - }; - this._registeredResourceTemplates[name] = registeredResourceTemplate; - - // If the resource template has any completion callbacks, enable completions capability - const variableNames = template.uriTemplate.variableNames; - const hasCompleter = Array.isArray(variableNames) && variableNames.some(v => !!template.completeCallback(v)); - if (hasCompleter) { - this.setCompletionRequestHandler(); - } - - return registeredResourceTemplate; - } - private _createRegisteredPrompt( - name: string, - title: string | undefined, - description: string | undefined, - argsSchema: PromptArgsRawShape | undefined, - callback: PromptCallback - ): RegisteredPrompt { - const registeredPrompt: RegisteredPrompt = { - title, - description, - argsSchema: argsSchema === undefined ? undefined : objectFromShape(argsSchema), - callback, - enabled: true, - disable: () => registeredPrompt.update({ enabled: false }), - enable: () => registeredPrompt.update({ enabled: true }), - remove: () => registeredPrompt.update({ name: null }), - update: updates => { - if (updates.name !== undefined && updates.name !== name) { - delete this._registeredPrompts[name]; - if (updates.name) this._registeredPrompts[updates.name] = registeredPrompt; - } - if (updates.title !== undefined) registeredPrompt.title = updates.title; - if (updates.description !== undefined) registeredPrompt.description = updates.description; - if (updates.argsSchema !== undefined) registeredPrompt.argsSchema = objectFromShape(updates.argsSchema); - if (updates.callback !== undefined) registeredPrompt.callback = updates.callback; - if (updates.enabled !== undefined) registeredPrompt.enabled = updates.enabled; - this.sendPromptListChanged(); - } - }; - this._registeredPrompts[name] = registeredPrompt; - - // If any argument uses a Completable schema, enable completions capability - if (argsSchema) { - const hasCompletable = Object.values(argsSchema).some(field => { - const inner: unknown = field instanceof ZodOptional ? field._def?.innerType : field; - return isCompletable(inner); - }); - if (hasCompletable) { + // If the resource template has any completion callbacks, enable completions capability + const variableNames = uriOrTemplate.uriTemplate.variableNames; + const hasCompleter = Array.isArray(variableNames) && variableNames.some(v => !!uriOrTemplate.completeCallback(v)); + if (hasCompleter) { this.setCompletionRequestHandler(); } - } - - return registeredPrompt; - } - - private _createRegisteredTool( - name: string, - title: string | undefined, - description: string | undefined, - inputSchema: ZodRawShapeCompat | AnySchema | undefined, - outputSchema: ZodRawShapeCompat | AnySchema | undefined, - annotations: ToolAnnotations | undefined, - execution: ToolExecution | undefined, - _meta: Record | undefined, - handler: AnyToolHandler - ): RegisteredTool { - // Validate tool name according to SEP specification - validateAndWarnToolName(name); - - const registeredTool: RegisteredTool = { - title, - description, - inputSchema: getZodSchemaObject(inputSchema), - outputSchema: getZodSchemaObject(outputSchema), - annotations, - execution, - _meta, - handler: handler, - enabled: true, - disable: () => registeredTool.update({ enabled: false }), - enable: () => registeredTool.update({ enabled: true }), - remove: () => registeredTool.update({ name: null }), - update: updates => { - if (updates.name !== undefined && updates.name !== name) { - if (typeof updates.name === 'string') { - validateAndWarnToolName(updates.name); - } - delete this._registeredTools[name]; - if (updates.name) this._registeredTools[updates.name] = registeredTool; - } - if (updates.title !== undefined) registeredTool.title = updates.title; - if (updates.description !== undefined) registeredTool.description = updates.description; - if (updates.paramsSchema !== undefined) registeredTool.inputSchema = objectFromShape(updates.paramsSchema); - if (updates.outputSchema !== undefined) registeredTool.outputSchema = objectFromShape(updates.outputSchema); - if (updates.callback !== undefined) registeredTool.handler = updates.callback; - if (updates.annotations !== undefined) registeredTool.annotations = updates.annotations; - if (updates._meta !== undefined) registeredTool._meta = updates._meta; - if (updates.enabled !== undefined) registeredTool.enabled = updates.enabled; - this.sendToolListChanged(); - } - }; - this._registeredTools[name] = registeredTool; - this.setToolRequestHandlers(); - this.sendToolListChanged(); - - return registeredTool; + return registeredResourceTemplate; + } } /** @@ -860,27 +931,27 @@ export class McpServer { inputSchema?: InputArgs; outputSchema?: OutputArgs; annotations?: ToolAnnotations; + execution?: ToolExecution; _meta?: Record; }, cb: ToolCallback - ): RegisteredTool { - if (this._registeredTools[name]) { - throw new Error(`Tool ${name} is already registered`); - } + ): RegisteredToolEntity { + const { title, description, inputSchema, outputSchema, annotations, execution, _meta } = config; - const { title, description, inputSchema, outputSchema, annotations, _meta } = config; - - return this._createRegisteredTool( + const registeredTool = this._toolRegistry.register({ name, title, description, - inputSchema, - outputSchema, + inputSchema: getZodSchemaObject(inputSchema), + outputSchema: getZodSchemaObject(outputSchema), annotations, - { taskSupport: 'forbidden' }, + execution: execution ?? { taskSupport: 'forbidden' }, _meta, - cb as ToolCallback - ); + handler: cb as ToolCallback + }); + + this.setToolRequestHandlers(); + return registeredTool; } /** @@ -894,23 +965,29 @@ export class McpServer { argsSchema?: Args; }, cb: PromptCallback - ): RegisteredPrompt { - if (this._registeredPrompts[name]) { - throw new Error(`Prompt ${name} is already registered`); - } - + ): RegisteredPromptEntity { const { title, description, argsSchema } = config; - const registeredPrompt = this._createRegisteredPrompt( + const registeredPrompt = this._promptRegistry.register({ name, title, description, argsSchema, - cb as PromptCallback - ); + callback: cb as PromptCallback + }); this.setPromptRequestHandlers(); - this.sendPromptListChanged(); + + // If any argument uses a Completable schema, enable completions capability + if (argsSchema) { + const hasCompletable = Object.values(argsSchema).some(field => { + const inner: unknown = field instanceof ZodOptional ? field._def?.innerType : field; + return isCompletable(inner); + }); + if (hasCompletable) { + this.setCompletionRequestHandler(); + } + } return registeredPrompt; } @@ -959,6 +1036,128 @@ export class McpServer { this.server.sendPromptListChanged(); } } + + /** + * Updates the error interceptor on the underlying Server based on current handlers. + * This combines both onError and onProtocolError handlers into a single interceptor. + */ + private _updateErrorInterceptor(): void { + if (!this._onErrorHandler && !this._onProtocolErrorHandler) { + // No handlers, clear the interceptor + this.server.setErrorInterceptor(undefined); + return; + } + + this.server.setErrorInterceptor(async (error: Error, ctx: ErrorInterceptionContext): Promise => { + const errorContext: ErrorContext = { + type: ctx.type === 'protocol' ? 'protocol' : 'tool', // Map to ErrorContext type + method: ctx.method, + requestId: typeof ctx.requestId === 'string' ? ctx.requestId : String(ctx.requestId) + }; + + let result: OnErrorReturn | OnProtocolErrorReturn | void = undefined; + + if (ctx.type === 'protocol' && this._onProtocolErrorHandler) { + // Protocol error - use onProtocolError handler + result = await this._onProtocolErrorHandler(error, errorContext); + } else if (this._onErrorHandler) { + // Application error (or protocol error without specific handler) - use onError handler + result = await this._onErrorHandler(error, errorContext); + } + + if (result === undefined || result === null) { + return undefined; + } + + // Convert the handler result to ErrorInterceptionResult + if (typeof result === 'string') { + return { message: result }; + } else if (result instanceof Error) { + const errorWithCode = result as Error & { code?: number; data?: unknown }; + return { + message: result.message, + code: ctx.type === 'application' ? errorWithCode.code : undefined, + data: errorWithCode.data + }; + } else { + // Object with code/message/data + return { + message: result.message, + code: ctx.type === 'application' ? (result as OnErrorReturn & { code?: number }).code : undefined, + data: result.data + }; + } + }); + } + + /** + * Registers an error handler for application errors in tool/resource/prompt handlers. + * + * The handler receives the error and a context object with information about where + * the error occurred. It can optionally return a custom error response that will + * modify the error sent to the client. + * + * Note: This is a single-handler pattern. Setting a new handler replaces any previous one. + * The handler is awaited, so async handlers are fully supported. + * + * @param handler - Error handler function + * @returns Unsubscribe function + * + * @example + * ```typescript + * const unsubscribe = server.onError(async (error, ctx) => { + * console.error(`Error in ${ctx.type}/${ctx.method}: ${error.message}`); + * // Optionally return a custom error response + * return { + * code: -32000, + * message: `Application error: ${error.message}`, + * data: { type: ctx.type } + * }; + * }); + * ``` + */ + onError(handler: OnErrorHandler): () => void { + this._onErrorHandler = handler; + this._updateErrorInterceptor(); + return this._clearOnErrorHandler.bind(this); + } + + private _clearOnErrorHandler(): void { + this._onErrorHandler = undefined; + this._updateErrorInterceptor(); + } + + /** + * Registers an error handler for protocol errors (method not found, parse error, etc.). + * + * The handler receives the error and a context object. It can optionally return + * a custom error response. Note that the error code cannot be changed for protocol + * errors as they have fixed codes per the MCP specification. + * + * Note: This is a single-handler pattern. Setting a new handler replaces any previous one. + * The handler is awaited, so async handlers are fully supported. + * + * @param handler - Error handler function + * @returns Unsubscribe function + * + * @example + * ```typescript + * const unsubscribe = server.onProtocolError(async (error, ctx) => { + * console.error(`Protocol error in ${ctx.method}: ${error.message}`); + * return { message: `Protocol error: ${error.message}` }; + * }); + * ``` + */ + onProtocolError(handler: OnProtocolErrorHandler): () => void { + this._onProtocolErrorHandler = handler; + this._updateErrorInterceptor(); + return this._clearOnProtocolErrorHandler.bind(this); + } + + private _clearOnProtocolErrorHandler(): void { + this._onProtocolErrorHandler = undefined; + this._updateErrorInterceptor(); + } } /** @@ -1021,7 +1220,7 @@ export class ResourceTemplate { export type BaseToolCallback< SendResultT extends Result, - Extra extends ContextInterface, + Extra extends ServerContextInterface, Args extends undefined | ZodRawShapeCompat | AnySchema > = Args extends ZodRawShapeCompat ? (args: ShapeOutput, ctx: Extra) => SendResultT | Promise @@ -1041,7 +1240,7 @@ export type BaseToolCallback< */ export type ToolCallback = BaseToolCallback< CallToolResult, - ContextInterface, + ServerContextInterface, Args >; @@ -1076,11 +1275,6 @@ export type RegisteredTool = { remove(): void; }; -const EMPTY_OBJECT_JSON_SCHEMA = { - type: 'object' as const, - properties: {} -}; - /** * Checks if a value looks like a Zod schema by checking for parse/safeParse methods. */ @@ -1160,7 +1354,7 @@ export type ResourceMetadata = Omit; * Callback to list all resources matching a given template. */ export type ListResourcesCallback = ( - ctx: ContextInterface + ctx: ServerContextInterface ) => ListResourcesResult | Promise; /** @@ -1168,7 +1362,7 @@ export type ListResourcesCallback = ( */ export type ReadResourceCallback = ( uri: URL, - ctx: ContextInterface + ctx: ServerContextInterface ) => ReadResourceResult | Promise; export type RegisteredResource = { @@ -1196,7 +1390,7 @@ export type RegisteredResource = { export type ReadResourceTemplateCallback = ( uri: URL, variables: Variables, - ctx: ContextInterface + ctx: ServerContextInterface ) => ReadResourceResult | Promise; export type RegisteredResourceTemplate = { @@ -1221,8 +1415,8 @@ export type RegisteredResourceTemplate = { type PromptArgsRawShape = ZodRawShapeCompat; export type PromptCallback = Args extends PromptArgsRawShape - ? (args: ShapeOutput, ctx: ContextInterface) => GetPromptResult | Promise - : (ctx: ContextInterface) => GetPromptResult | Promise; + ? (args: ShapeOutput, ctx: ServerContextInterface) => GetPromptResult | Promise + : (ctx: ServerContextInterface) => GetPromptResult | Promise; export type RegisteredPrompt = { title?: string; @@ -1243,22 +1437,6 @@ export type RegisteredPrompt = { remove(): void; }; -function promptArgumentsFromSchema(schema: AnyObjectSchema): PromptArgument[] { - const shape = getObjectShape(schema); - if (!shape) return []; - return Object.entries(shape).map(([name, field]): PromptArgument => { - // Get description - works for both v3 and v4 - const description = getSchemaDescription(field); - // Check if optional - works for both v3 and v4 - const isOptional = isSchemaOptional(field); - return { - name, - description, - required: !isOptional - }; - }); -} - function getMethodValue(schema: AnyObjectSchema): string { const shape = getObjectShape(schema); const methodSchema = shape?.method as AnySchema | undefined; diff --git a/packages/server/src/server/middleware.ts b/packages/server/src/server/middleware.ts new file mode 100644 index 000000000..529ff39bb --- /dev/null +++ b/packages/server/src/server/middleware.ts @@ -0,0 +1,453 @@ +/** + * McpServer Middleware System + * + * Provides a flexible middleware system for cross-cutting concerns like + * logging, authentication, rate limiting, metrics, and caching. + * + * Design follows Express/Koa/Hono patterns with the next() pattern for + * maximum flexibility. + */ + +import type { AuthInfo, CallToolResult, GetPromptResult, ReadResourceResult } from '@modelcontextprotocol/core'; + +// ═══════════════════════════════════════════════════════════════════════════ +// Context Interfaces +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Base context shared by all middleware + */ +interface BaseMiddlewareContext { + /** The request ID from JSON-RPC */ + requestId: string; + /** Authentication info if available */ + authInfo?: AuthInfo; + /** Abort signal for cancellation */ + signal: AbortSignal; +} + +/** + * Context for tool middleware + */ +export interface ToolContext extends BaseMiddlewareContext { + type: 'tool'; + /** The name of the tool being called */ + name: string; + /** The arguments passed to the tool */ + args: unknown; +} + +/** + * Context for resource middleware + */ +export interface ResourceContext extends BaseMiddlewareContext { + type: 'resource'; + /** The URI of the resource being read */ + uri: string; +} + +/** + * Context for prompt middleware + */ +export interface PromptContext extends BaseMiddlewareContext { + type: 'prompt'; + /** The name of the prompt being requested */ + name: string; + /** The arguments passed to the prompt */ + args: unknown; +} + +/** + * Union type for all middleware contexts + */ +export type MiddlewareContext = ToolContext | ResourceContext | PromptContext; + +// ═══════════════════════════════════════════════════════════════════════════ +// Middleware Types +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Next function for tool middleware. + * Can optionally pass modified args to the handler. + */ +export type ToolNextFn = (modifiedArgs?: unknown) => Promise; + +/** + * Next function for resource middleware. + * Can optionally pass a modified URI to the handler. + */ +export type ResourceNextFn = (modifiedUri?: string) => Promise; + +/** + * Next function for prompt middleware. + * Can optionally pass modified args to the handler. + */ +export type PromptNextFn = (modifiedArgs?: unknown) => Promise; + +/** + * Next function for universal middleware. + * Can optionally pass modified input to the handler. + */ +export type UniversalNextFn = (modified?: unknown) => Promise; + +/** + * Middleware for tool calls. + * Can abort, short-circuit, modify args, or pass through. + */ +export type ToolMiddleware = (ctx: ToolContext, next: ToolNextFn) => Promise; + +/** + * Middleware for resource reads. + * Can abort, short-circuit, modify URI, or pass through. + */ +export type ResourceMiddleware = (ctx: ResourceContext, next: ResourceNextFn) => Promise; + +/** + * Middleware for prompt requests. + * Can abort, short-circuit, modify args, or pass through. + */ +export type PromptMiddleware = (ctx: PromptContext, next: PromptNextFn) => Promise; + +/** + * Universal middleware that works for all types. + * Use the `type` property on the context to differentiate. + */ +export type UniversalMiddleware = (ctx: MiddlewareContext, next: UniversalNextFn) => Promise; + +// ═══════════════════════════════════════════════════════════════════════════ +// Middleware Chain Builder +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Composes multiple middleware functions into a single function. + * Each middleware can: + * - Abort with error: throw + * - Short-circuit: return result without calling next() + * - Modify input: call next(modified) + * - Pass through: call next() + * + * @param middlewares - Array of middleware functions + * @param handler - The final handler to call + * @returns A composed function that runs all middleware and the handler + */ +export function composeMiddleware( + middlewares: Array<(ctx: TCtx, next: (input?: TInput) => Promise) => Promise>, + handler: (ctx: TCtx, input?: TInput) => Promise +): (ctx: TCtx, initialInput?: TInput) => Promise { + return async (ctx: TCtx, initialInput?: TInput): Promise => { + let index = -1; + let currentInput: TInput | undefined = initialInput; + + const dispatch = async (i: number, input?: TInput): Promise => { + if (i <= index) { + throw new Error('next() called multiple times'); + } + index = i; + currentInput = input ?? currentInput; + + if (i >= middlewares.length) { + // All middleware processed, call the final handler + return handler(ctx, currentInput); + } + + const middleware = middlewares[i]; + if (!middleware) { + return handler(ctx, currentInput); + } + return middleware(ctx, (modifiedInput?: TInput) => dispatch(i + 1, modifiedInput)); + }; + + return dispatch(0, initialInput); + }; +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Middleware Manager +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Manages middleware registration and execution for McpServer. + */ +export class MiddlewareManager { + private _universalMiddleware: UniversalMiddleware[] = []; + private _toolMiddleware: ToolMiddleware[] = []; + private _resourceMiddleware: ResourceMiddleware[] = []; + private _promptMiddleware: PromptMiddleware[] = []; + + // Per-item middleware (keyed by name/uri) + private _perToolMiddleware = new Map(); + private _perResourceMiddleware = new Map(); + private _perPromptMiddleware = new Map(); + + /** + * Registers universal middleware that runs for all request types. + */ + useMiddleware(middleware: UniversalMiddleware): this { + this._universalMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware specifically for tool calls. + */ + useToolMiddleware(middleware: ToolMiddleware): this { + this._toolMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware specifically for resource reads. + */ + useResourceMiddleware(middleware: ResourceMiddleware): this { + this._resourceMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware specifically for prompt requests. + */ + usePromptMiddleware(middleware: PromptMiddleware): this { + this._promptMiddleware.push(middleware); + return this; + } + + /** + * Registers middleware for a specific tool by name. + */ + useToolMiddlewareFor(name: string, middleware: ToolMiddleware): this { + this._perToolMiddleware.set(name, middleware); + return this; + } + + /** + * Registers middleware for a specific resource by URI. + */ + useResourceMiddlewareFor(uri: string, middleware: ResourceMiddleware): this { + this._perResourceMiddleware.set(uri, middleware); + return this; + } + + /** + * Registers middleware for a specific prompt by name. + */ + usePromptMiddlewareFor(name: string, middleware: PromptMiddleware): this { + this._perPromptMiddleware.set(name, middleware); + return this; + } + + /** + * Gets per-tool middleware if registered. + */ + getToolMiddlewareFor(name: string): ToolMiddleware | undefined { + return this._perToolMiddleware.get(name); + } + + /** + * Gets per-resource middleware if registered. + */ + getResourceMiddlewareFor(uri: string): ResourceMiddleware | undefined { + return this._perResourceMiddleware.get(uri); + } + + /** + * Gets per-prompt middleware if registered. + */ + getPromptMiddlewareFor(name: string): PromptMiddleware | undefined { + return this._perPromptMiddleware.get(name); + } + + /** + * Executes tool middleware chain with the given context and handler. + */ + async executeToolMiddleware( + ctx: ToolContext, + handler: (ctx: ToolContext, args?: unknown) => Promise, + perRegistrationMiddleware?: ToolMiddleware + ): Promise { + // Build middleware chain: universal -> tool-specific -> per-registration + const chain: ToolMiddleware[] = []; + + // Add universal middleware (cast to tool middleware) + for (const mw of this._universalMiddleware) { + chain.push(async (c, next) => { + return (await mw(c, async modified => { + return next(modified as unknown); + })) as CallToolResult; + }); + } + + // Add tool-specific middleware + chain.push(...this._toolMiddleware); + + // Add per-registration middleware if provided + if (perRegistrationMiddleware) { + chain.push(perRegistrationMiddleware); + } + + // Compose and execute + const composed = composeMiddleware(chain, handler); + return composed(ctx, ctx.args); + } + + /** + * Executes resource middleware chain with the given context and handler. + */ + async executeResourceMiddleware( + ctx: ResourceContext, + handler: (ctx: ResourceContext, uri?: string) => Promise, + perRegistrationMiddleware?: ResourceMiddleware + ): Promise { + // Build middleware chain: universal -> resource-specific -> per-registration + const chain: ResourceMiddleware[] = []; + + // Add universal middleware (cast to resource middleware) + for (const mw of this._universalMiddleware) { + chain.push(async (c, next) => { + return (await mw(c, async modified => { + return next(modified as string); + })) as ReadResourceResult; + }); + } + + // Add resource-specific middleware + chain.push(...this._resourceMiddleware); + + // Add per-registration middleware if provided + if (perRegistrationMiddleware) { + chain.push(perRegistrationMiddleware); + } + + // Compose and execute + const composed = composeMiddleware(chain, handler); + return composed(ctx, ctx.uri); + } + + /** + * Executes prompt middleware chain with the given context and handler. + */ + async executePromptMiddleware( + ctx: PromptContext, + handler: (ctx: PromptContext, args?: unknown) => Promise, + perRegistrationMiddleware?: PromptMiddleware + ): Promise { + // Build middleware chain: universal -> prompt-specific -> per-registration + const chain: PromptMiddleware[] = []; + + // Add universal middleware (cast to prompt middleware) + for (const mw of this._universalMiddleware) { + chain.push(async (c, next) => { + return (await mw(c, async modified => { + return next(modified as unknown); + })) as GetPromptResult; + }); + } + + // Add prompt-specific middleware + chain.push(...this._promptMiddleware); + + // Add per-registration middleware if provided + if (perRegistrationMiddleware) { + chain.push(perRegistrationMiddleware); + } + + // Compose and execute + const composed = composeMiddleware(chain, handler); + return composed(ctx, ctx.args); + } + + /** + * Checks if any middleware is registered. + */ + hasMiddleware(): boolean { + return ( + this._universalMiddleware.length > 0 || + this._toolMiddleware.length > 0 || + this._resourceMiddleware.length > 0 || + this._promptMiddleware.length > 0 || + this._perToolMiddleware.size > 0 || + this._perResourceMiddleware.size > 0 || + this._perPromptMiddleware.size > 0 + ); + } + + /** + * Clears all registered middleware. + */ + clear(): void { + this._universalMiddleware = []; + this._toolMiddleware = []; + this._resourceMiddleware = []; + this._promptMiddleware = []; + this._perToolMiddleware.clear(); + this._perResourceMiddleware.clear(); + this._perPromptMiddleware.clear(); + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Built-in Middleware Factories +// ═══════════════════════════════════════════════════════════════════════════ + +/** + * Options for the logging middleware. + */ +export interface LoggingMiddlewareOptions { + /** Log level: 'debug', 'info', 'warn', 'error' */ + level?: 'debug' | 'info' | 'warn' | 'error'; + /** Custom logger function */ + logger?: (level: string, message: string, data?: unknown) => void; +} + +/** + * Creates a logging middleware that logs all requests. + * + * @example + * ```typescript + * server.useMiddleware(createLoggingMiddleware({ level: 'debug' })); + * ``` + */ +export function createLoggingMiddleware(options: LoggingMiddlewareOptions = {}): UniversalMiddleware { + const { level = 'info', logger = console.log } = options; + + return async (ctx, next) => { + const identifier = ctx.type === 'resource' ? ctx.uri : ctx.name; + logger(level, `→ ${ctx.type}: ${identifier}`, { + type: ctx.type, + requestId: ctx.requestId + }); + + const start = Date.now(); + + try { + const result = await next(); + const duration = Date.now() - start; + logger(level, `← ${ctx.type}: ${identifier} (${duration}ms)`, { + type: ctx.type, + requestId: ctx.requestId, + duration + }); + return result; + } catch (error) { + const duration = Date.now() - start; + logger('error', `✗ ${ctx.type}: ${identifier} (${duration}ms)`, { + type: ctx.type, + requestId: ctx.requestId, + duration, + error + }); + throw error; + } + }; +} + +/** + * Options for the rate limit middleware. + */ +export interface RateLimitMiddlewareOptions { + /** Maximum requests per time window */ + max: number; + /** Time window in milliseconds */ + windowMs?: number; + /** Error message when rate limited */ + message?: string; +} \ No newline at end of file diff --git a/packages/server/src/server/registries/baseRegistry.ts b/packages/server/src/server/registries/baseRegistry.ts new file mode 100644 index 000000000..fa168e2b5 --- /dev/null +++ b/packages/server/src/server/registries/baseRegistry.ts @@ -0,0 +1,229 @@ +/** + * Base Registry + * + * Abstract base class for managing collections of registered entities + * (tools, resources, prompts). Provides common functionality for + * CRUD operations and notifications. + */ + +/** + * Base interface for all registered definitions + */ +export interface RegisteredDefinition { + /** + * Whether the definition is currently enabled + */ + enabled: boolean; + + /** + * Enable the definition + */ + enable(): this; + + /** + * Disable the definition + */ + disable(): this; + + /** + * Remove the definition from its registry + */ + remove(): void; +} + +/** + * Callback type for registry change notifications + */ +export type RegistryNotifyCallback = () => void; + +/** + * Abstract base class for registries. + * Provides common functionality for managing collections of registered entities. + * + * @template T - The type of registered entity this registry manages + */ +export abstract class BaseRegistry { + /** + * Internal storage for registered items + */ + protected _items = new Map(); + + /** + * Optional callback for change notifications. + * Can be set after construction via setNotifyCallback(). + */ + protected _notifyCallback?: RegistryNotifyCallback; + + /** + * Sets or updates the notification callback. + * This allows the callback to be bound after construction (e.g., by McpServer + * when using registries created by the builder). + * + * @param callback - The callback to invoke when the registry changes + */ + setNotifyCallback(callback: RegistryNotifyCallback): void { + this._notifyCallback = callback; + } + + /** + * Called when the registry contents change. + * Invokes the notification callback if one is set. + */ + protected notifyChanged(): void { + this._notifyCallback?.(); + } + + /** + * Checks if an item with the given ID exists in the registry. + * + * @param id - The identifier to check + * @returns true if the item exists + */ + has(id: string): boolean { + return this._items.has(id); + } + + /** + * Gets an item by its ID. + * + * @param id - The identifier of the item + * @returns The item or undefined if not found + */ + get(id: string): T | undefined { + return this._items.get(id); + } + + /** + * Gets all items in the registry as a read-only map. + * + * @returns A read-only map of all items + */ + getAll(): ReadonlyMap { + return this._items; + } + + /** + * Gets all items as an array. + * + * @returns Array of all registered items + */ + values(): T[] { + return [...this._items.values()]; + } + + /** + * Gets all enabled items as an array. + * + * @returns Array of enabled items + */ + getEnabled(): T[] { + return this.values().filter(item => item.enabled); + } + + /** + * Gets all disabled items as an array. + * + * @returns Array of disabled items + */ + getDisabled(): T[] { + return this.values().filter(item => !item.enabled); + } + + /** + * Gets the number of items in the registry. + */ + get size(): number { + return this._items.size; + } + + /** + * Removes an item from the registry. + * + * @param id - The identifier of the item to remove + * @returns true if the item was removed, false if it didn't exist + */ + remove(id: string): boolean { + const deleted = this._items.delete(id); + if (deleted) { + this.notifyChanged(); + } + return deleted; + } + + /** + * Disables all items in the registry. + */ + disableAll(): void { + let changed = false; + for (const item of this._items.values()) { + if (item.enabled) { + item.disable(); + changed = true; + } + } + if (changed) { + this.notifyChanged(); + } + } + + /** + * Enables all items in the registry. + */ + enableAll(): void { + let changed = false; + for (const item of this._items.values()) { + if (!item.enabled) { + item.enable(); + changed = true; + } + } + if (changed) { + this.notifyChanged(); + } + } + + /** + * Clears all items from the registry. + */ + clear(): void { + if (this._items.size > 0) { + this._items.clear(); + this.notifyChanged(); + } + } + + /** + * Internal method to add or update an item in the registry. + * Used by subclasses during registration. + * + * @param id - The identifier for the item + * @param item - The item to add + */ + protected _set(id: string, item: T): void { + this._items.set(id, item); + } + + /** + * Internal method to rename an item in the registry. + * + * @param oldId - The current identifier + * @param newId - The new identifier + * @returns true if renamed successfully + */ + protected _rename(oldId: string, newId: string): boolean { + const item = this._items.get(oldId); + if (!item) { + return false; + } + if (oldId === newId) { + return true; + } + if (this._items.has(newId)) { + throw new Error(`Cannot rename: '${newId}' already exists`); + } + this._items.delete(oldId); + this._items.set(newId, item); + this.notifyChanged(); + return true; + } +} diff --git a/packages/server/src/server/registries/index.ts b/packages/server/src/server/registries/index.ts new file mode 100644 index 000000000..f0ba60a77 --- /dev/null +++ b/packages/server/src/server/registries/index.ts @@ -0,0 +1,10 @@ +/** + * Registries Module + * + * Exports registry classes and entities for managing tools, resources, and prompts. + */ + +export * from './baseRegistry.js'; +export * from './promptRegistry.js'; +export * from './resourceRegistry.js'; +export * from './toolRegistry.js'; diff --git a/packages/server/src/server/registries/promptRegistry.ts b/packages/server/src/server/registries/promptRegistry.ts new file mode 100644 index 000000000..06f7d6b11 --- /dev/null +++ b/packages/server/src/server/registries/promptRegistry.ts @@ -0,0 +1,242 @@ +/** + * Prompt Registry + * + * Manages registration and retrieval of prompts. + * Provides class-based RegisteredPromptEntity entities with proper encapsulation. + */ + +import type { AnyObjectSchema, AnySchema, Prompt, PromptArgument, ZodRawShapeCompat } from '@modelcontextprotocol/core'; +import { getObjectShape, getSchemaDescription, isSchemaOptional, objectFromShape } from '@modelcontextprotocol/core'; + +import type { PromptCallback } from '../mcp.js'; +import type { RegisteredDefinition } from './baseRegistry.js'; +import { BaseRegistry } from './baseRegistry.js'; + +/** + * Configuration for registering a prompt + */ +export interface PromptConfig { + name: string; + title?: string; + description?: string; + argsSchema?: ZodRawShapeCompat; + callback: PromptCallback; +} + +/** + * Updates that can be applied to a registered prompt + */ +export interface PromptUpdates { + name?: string | null; + title?: string; + description?: string; + argsSchema?: ZodRawShapeCompat; + callback?: PromptCallback; + enabled?: boolean; +} + +/** + * Class-based representation of a registered prompt. + * Provides methods for managing the prompt's lifecycle. + */ +export class RegisteredPromptEntity implements RegisteredDefinition { + private _name: string; + private _enabled: boolean = true; + private readonly _registry: PromptRegistry; + + private _title?: string; + private _description?: string; + private _argsSchema?: AnyObjectSchema; + private _callback: PromptCallback; + + constructor(config: PromptConfig, registry: PromptRegistry) { + this._name = config.name; + this._registry = registry; + this._title = config.title; + this._description = config.description; + this._argsSchema = config.argsSchema ? objectFromShape(config.argsSchema) : undefined; + this._callback = config.callback; + } + + /** The prompt's name (identifier) */ + get name(): string { + return this._name; + } + + /** Whether the prompt is currently enabled */ + get enabled(): boolean { + return this._enabled; + } + + /** The prompt's title */ + get title(): string | undefined { + return this._title; + } + + /** The prompt's description */ + get description(): string | undefined { + return this._description; + } + + /** The prompt's args schema */ + get argsSchema(): AnyObjectSchema | undefined { + return this._argsSchema; + } + + /** The prompt's callback */ + get callback(): PromptCallback { + return this._callback; + } + + /** + * Enables the prompt + */ + enable(): this { + if (!this._enabled) { + this._enabled = true; + this._registry['notifyChanged'](); + } + return this; + } + + /** + * Disables the prompt + */ + disable(): this { + if (this._enabled) { + this._enabled = false; + this._registry['notifyChanged'](); + } + return this; + } + + /** + * Removes the prompt from its registry + */ + remove(): void { + this._registry.remove(this._name); + } + + /** + * Renames the prompt + * + * @param newName - The new name for the prompt + */ + rename(newName: string): this { + this._registry['_rename'](this._name, newName); + this._name = newName; + return this; + } + + /** + * Updates the prompt's properties + * + * @param updates - The updates to apply + */ + update(updates: PromptUpdates): void { + if (updates.name !== undefined) { + if (updates.name === null) { + this.remove(); + return; + } + this.rename(updates.name); + } + if (updates.title !== undefined) this._title = updates.title; + if (updates.description !== undefined) this._description = updates.description; + if (updates.argsSchema !== undefined) this._argsSchema = objectFromShape(updates.argsSchema); + if (updates.callback !== undefined) this._callback = updates.callback; + if (updates.enabled === undefined) { + this._registry['notifyChanged'](); + } else { + if (updates.enabled) { + this.enable(); + } else { + this.disable(); + } + } + } + + /** + * Converts to the Prompt protocol type (for list responses) + */ + toProtocolPrompt(): Prompt { + return { + name: this._name, + title: this._title, + description: this._description, + arguments: this._argsSchema ? promptArgumentsFromSchema(this._argsSchema) : undefined + }; + } +} + +/** + * Registry for managing prompts. + */ +export class PromptRegistry extends BaseRegistry { + /** + * Creates a new PromptRegistry. + * + * @param sendNotification - Optional callback to invoke when the prompt list changes. + * Can be set later via setNotifyCallback(). + */ + constructor(sendNotification?: () => void) { + super(); + if (sendNotification) { + this.setNotifyCallback(sendNotification); + } + } + + /** + * Registers a new prompt. + * + * @param config - The prompt configuration + * @returns The registered prompt + * @throws If a prompt with the same name already exists + */ + register(config: PromptConfig): RegisteredPromptEntity { + if (this._items.has(config.name)) { + throw new Error(`Prompt '${config.name}' is already registered`); + } + + const prompt = new RegisteredPromptEntity(config, this); + this._set(config.name, prompt); + this.notifyChanged(); + return prompt; + } + + /** + * Gets the list of enabled prompts in protocol format. + * + * @returns Array of Prompt objects for the protocol response + */ + getProtocolPrompts(): Prompt[] { + return this.getEnabled().map(prompt => prompt.toProtocolPrompt()); + } + + /** + * Gets a prompt by name. + * + * @param name - The prompt name + * @returns The registered prompt or undefined + */ + getPrompt(name: string): RegisteredPromptEntity | undefined { + return this.get(name); + } +} + +/** + * Converts a Zod object schema to an array of PromptArgument for the protocol. + */ +function promptArgumentsFromSchema(schema: AnyObjectSchema): PromptArgument[] { + const shape = getObjectShape(schema); + if (!shape) return []; + return Object.entries(shape).map(([name, field]): PromptArgument => { + const description = getSchemaDescription(field as AnySchema); + const isOptional = isSchemaOptional(field as AnySchema); + return { + name, + description, + required: !isOptional + }; + }); +} diff --git a/packages/server/src/server/registries/resourceRegistry.ts b/packages/server/src/server/registries/resourceRegistry.ts new file mode 100644 index 000000000..0468fda92 --- /dev/null +++ b/packages/server/src/server/registries/resourceRegistry.ts @@ -0,0 +1,496 @@ +/** + * Resource Registry + * + * Manages registration and retrieval of resources and resource templates. + * Provides class-based RegisteredResourceEntity entities with proper encapsulation. + */ + +import type { Resource, ResourceTemplateType as ResourceTemplateProtocol, Variables } from '@modelcontextprotocol/core'; + +import type { ReadResourceCallback, ReadResourceTemplateCallback, ResourceMetadata, ResourceTemplate } from '../mcp.js'; +import type { RegisteredDefinition } from './baseRegistry.js'; +import { BaseRegistry } from './baseRegistry.js'; + +/** + * Configuration for registering a static resource + */ +export interface ResourceConfig { + name: string; + uri: string; + title?: string; + description?: string; + mimeType?: string; + metadata?: ResourceMetadata; + readCallback: ReadResourceCallback; +} + +/** + * Configuration for registering a resource template + */ +export interface ResourceTemplateConfig { + name: string; + template: ResourceTemplate; + title?: string; + description?: string; + mimeType?: string; + metadata?: ResourceMetadata; + readCallback: ReadResourceTemplateCallback; +} + +/** + * Updates that can be applied to a registered resource + */ +export interface ResourceUpdates { + name?: string; + uri?: string | null; + title?: string; + description?: string; + mimeType?: string; + metadata?: ResourceMetadata; + callback?: ReadResourceCallback; + enabled?: boolean; +} + +/** + * Updates that can be applied to a registered resource template + */ +export interface ResourceTemplateUpdates { + name?: string | null; + title?: string; + description?: string; + mimeType?: string; + template?: ResourceTemplate; + metadata?: ResourceMetadata; + callback?: ReadResourceTemplateCallback; + enabled?: boolean; +} + +/** + * Class-based representation of a registered static resource. + * Provides methods for managing the resource's lifecycle. + */ +export class RegisteredResourceEntity implements RegisteredDefinition { + private _name: string; + private _uri: string; + private _enabled: boolean = true; + private readonly _registry: ResourceRegistry; + + private _title?: string; + private _description?: string; + private _mimeType?: string; + private _metadata?: ResourceMetadata; + private _readCallback: ReadResourceCallback; + + constructor(config: ResourceConfig, registry: ResourceRegistry) { + this._name = config.name; + this._uri = config.uri; + this._registry = registry; + this._title = config.title; + this._description = config.description; + this._mimeType = config.mimeType; + this._metadata = config.metadata; + this._readCallback = config.readCallback; + } + + /** The resource's name */ + get name(): string { + return this._name; + } + + /** The resource's URI */ + get uri(): string { + return this._uri; + } + + /** Whether the resource is currently enabled */ + get enabled(): boolean { + return this._enabled; + } + + /** The resource's title */ + get title(): string | undefined { + return this._title; + } + + /** The resource's description */ + get description(): string | undefined { + return this._description; + } + + /** The resource's MIME type */ + get mimeType(): string | undefined { + return this._mimeType; + } + + /** The resource's metadata */ + get metadata(): ResourceMetadata | undefined { + return this._metadata; + } + + /** The resource's read callback */ + get readCallback(): ReadResourceCallback { + return this._readCallback; + } + + /** + * Enables the resource + */ + enable(): this { + if (!this._enabled) { + this._enabled = true; + this._registry['notifyChanged'](); + } + return this; + } + + /** + * Disables the resource + */ + disable(): this { + if (this._enabled) { + this._enabled = false; + this._registry['notifyChanged'](); + } + return this; + } + + /** + * Removes the resource from its registry + */ + remove(): void { + this._registry.remove(this._uri); + } + + /** + * Updates the resource's properties + * + * @param updates - The updates to apply + */ + update(updates: ResourceUpdates): void { + if (updates.uri !== undefined) { + if (updates.uri === null) { + this.remove(); + return; + } + // Handle URI change - need to re-register under new URI + const oldUri = this._uri; + this._uri = updates.uri; + this._registry['_items'].delete(oldUri); + this._registry['_items'].set(updates.uri, this); + } + if (updates.name !== undefined) this._name = updates.name; + if (updates.title !== undefined) this._title = updates.title; + if (updates.description !== undefined) this._description = updates.description; + if (updates.mimeType !== undefined) this._mimeType = updates.mimeType; + if (updates.metadata !== undefined) this._metadata = updates.metadata; + if (updates.callback !== undefined) this._readCallback = updates.callback; + if (updates.enabled === undefined) { + this._registry['notifyChanged'](); + } else { + if (updates.enabled) { + this.enable(); + } else { + this.disable(); + } + } + } + + /** + * Converts to the Resource protocol type (for list responses) + */ + toProtocolResource(): Resource { + return { + uri: this._uri, + name: this._name, + title: this._title, + description: this._description, + mimeType: this._mimeType, + ...this._metadata + }; + } +} + +/** + * Class-based representation of a registered resource template. + * Provides methods for managing the template's lifecycle. + */ +export class RegisteredResourceTemplateEntity implements RegisteredDefinition { + private _name: string; + private _enabled: boolean = true; + private readonly _registry: ResourceTemplateRegistry; + + private _title?: string; + private _description?: string; + private _mimeType?: string; + private _metadata?: ResourceMetadata; + private _template: ResourceTemplate; + private _readCallback: ReadResourceTemplateCallback; + + constructor(config: ResourceTemplateConfig, registry: ResourceTemplateRegistry) { + this._name = config.name; + this._registry = registry; + this._title = config.title; + this._description = config.description; + this._mimeType = config.mimeType; + this._metadata = config.metadata; + this._template = config.template; + this._readCallback = config.readCallback; + } + + /** The template's name (identifier) */ + get name(): string { + return this._name; + } + + /** Whether the template is currently enabled */ + get enabled(): boolean { + return this._enabled; + } + + /** The template's title */ + get title(): string | undefined { + return this._title; + } + + /** The template's description */ + get description(): string | undefined { + return this._description; + } + + /** The template's MIME type */ + get mimeType(): string | undefined { + return this._mimeType; + } + + /** The template's metadata */ + get metadata(): ResourceMetadata | undefined { + return this._metadata; + } + + /** The resource template */ + get template(): ResourceTemplate { + return this._template; + } + + /** Alias for template for backward compatibility */ + get resourceTemplate(): ResourceTemplate { + return this._template; + } + + /** The template's read callback */ + get readCallback(): ReadResourceTemplateCallback { + return this._readCallback; + } + + /** + * Enables the template + */ + enable(): this { + if (!this._enabled) { + this._enabled = true; + this._registry['notifyChanged'](); + } + return this; + } + + /** + * Disables the template + */ + disable(): this { + if (this._enabled) { + this._enabled = false; + this._registry['notifyChanged'](); + } + return this; + } + + /** + * Removes the template from its registry + */ + remove(): void { + this._registry.remove(this._name); + } + + /** + * Renames the template + * + * @param newName - The new name for the template + */ + rename(newName: string): this { + this._registry['_rename'](this._name, newName); + this._name = newName; + return this; + } + + /** + * Updates the template's properties + * + * @param updates - The updates to apply + */ + update(updates: ResourceTemplateUpdates): void { + if (updates.name !== undefined) { + if (updates.name === null) { + this.remove(); + return; + } + this.rename(updates.name); + } + if (updates.title !== undefined) this._title = updates.title; + if (updates.description !== undefined) this._description = updates.description; + if (updates.mimeType !== undefined) this._mimeType = updates.mimeType; + if (updates.metadata !== undefined) this._metadata = updates.metadata; + if (updates.template !== undefined) this._template = updates.template; + if (updates.callback !== undefined) this._readCallback = updates.callback; + if (updates.enabled === undefined) { + this._registry['notifyChanged'](); + } else { + if (updates.enabled) { + this.enable(); + } else { + this.disable(); + } + } + } + + /** + * Converts to the ResourceTemplate protocol type (for list responses) + */ + toProtocolResourceTemplate(): ResourceTemplateProtocol { + return { + name: this._name, + uriTemplate: this._template.uriTemplate.toString(), + title: this._title, + description: this._description, + mimeType: this._mimeType, + ...this._metadata + }; + } +} + +/** + * Registry for managing static resources. + * Resources are keyed by URI. + */ +export class ResourceRegistry extends BaseRegistry { + /** + * Creates a new ResourceRegistry. + * + * @param sendNotification - Optional callback to invoke when the resource list changes. + * Can be set later via setNotifyCallback(). + */ + constructor(sendNotification?: () => void) { + super(); + if (sendNotification) { + this.setNotifyCallback(sendNotification); + } + } + + /** + * Registers a new resource. + * + * @param config - The resource configuration + * @returns The registered resource + * @throws If a resource with the same URI already exists + */ + register(config: ResourceConfig): RegisteredResourceEntity { + if (this._items.has(config.uri)) { + throw new Error(`Resource '${config.uri}' is already registered`); + } + + const resource = new RegisteredResourceEntity(config, this); + this._set(config.uri, resource); + this.notifyChanged(); + return resource; + } + + /** + * Gets the list of enabled resources in protocol format. + * + * @returns Array of Resource objects for the protocol response + */ + getProtocolResources(): Resource[] { + return this.getEnabled().map(resource => resource.toProtocolResource()); + } + + /** + * Gets a resource by URI. + * + * @param uri - The resource URI + * @returns The registered resource or undefined + */ + getResource(uri: string): RegisteredResourceEntity | undefined { + return this.get(uri); + } +} + +/** + * Registry for managing resource templates. + * Templates are keyed by name. + */ +export class ResourceTemplateRegistry extends BaseRegistry { + /** + * Creates a new ResourceTemplateRegistry. + * + * @param sendNotification - Optional callback to invoke when the template list changes. + * Can be set later via setNotifyCallback(). + */ + constructor(sendNotification?: () => void) { + super(); + if (sendNotification) { + this.setNotifyCallback(sendNotification); + } + } + + /** + * Registers a new resource template. + * + * @param config - The template configuration + * @returns The registered template + * @throws If a template with the same name already exists + */ + register(config: ResourceTemplateConfig): RegisteredResourceTemplateEntity { + if (this._items.has(config.name)) { + throw new Error(`Resource template '${config.name}' is already registered`); + } + + const template = new RegisteredResourceTemplateEntity(config, this); + this._set(config.name, template); + this.notifyChanged(); + return template; + } + + /** + * Gets the list of enabled templates in protocol format. + * + * @returns Array of ResourceTemplate objects for the protocol response + */ + getProtocolResourceTemplates(): ResourceTemplateProtocol[] { + return this.getEnabled().map(template => template.toProtocolResourceTemplate()); + } + + /** + * Gets a template by name. + * + * @param name - The template name + * @returns The registered template or undefined + */ + getTemplate(name: string): RegisteredResourceTemplateEntity | undefined { + return this.get(name); + } + + /** + * Finds a template that matches the given URI. + * + * @param uri - The URI to match against templates + * @returns The matching template and extracted variables, or undefined + */ + findMatchingTemplate(uri: string): { template: RegisteredResourceTemplateEntity; variables: Variables } | undefined { + for (const template of this.getEnabled()) { + const variables = template.template.uriTemplate.match(uri); + if (variables) { + return { template, variables }; + } + } + return undefined; + } +} diff --git a/packages/server/src/server/registries/toolRegistry.ts b/packages/server/src/server/registries/toolRegistry.ts new file mode 100644 index 000000000..360e26638 --- /dev/null +++ b/packages/server/src/server/registries/toolRegistry.ts @@ -0,0 +1,297 @@ +/** + * Tool Registry + * + * Manages registration and retrieval of tools. + * Provides class-based RegisteredTool entities with proper encapsulation. + */ + +import type { AnySchema, Tool, ToolAnnotations, ToolExecution, ZodRawShapeCompat } from '@modelcontextprotocol/core'; +import { normalizeObjectSchema, toJsonSchemaCompat, validateAndWarnToolName } from '@modelcontextprotocol/core'; + +import type { AnyToolHandler } from '../mcp.js'; +import type { RegisteredDefinition } from './baseRegistry.js'; +import { BaseRegistry } from './baseRegistry.js'; + +/** + * Tool handler type - compatible with both ToolCallback and ToolTaskHandler + */ +export type ToolHandler = AnyToolHandler; + +/** + * Configuration for registering a tool + */ +export interface ToolConfig { + name: string; + title?: string; + description?: string; + inputSchema?: AnySchema; + outputSchema?: AnySchema; + annotations?: ToolAnnotations; + execution?: ToolExecution; + _meta?: Record; + handler: ToolHandler; +} + +/** + * Updates that can be applied to a registered tool + */ +export interface ToolUpdates { + name?: string | null; + title?: string; + description?: string; + inputSchema?: AnySchema; + outputSchema?: AnySchema; + annotations?: ToolAnnotations; + execution?: ToolExecution; + _meta?: Record; + handler?: ToolHandler; + enabled?: boolean; +} + +const EMPTY_OBJECT_JSON_SCHEMA = { + type: 'object' as const, + properties: {} +}; + +/** + * Class-based representation of a registered tool. + * Provides methods for managing the tool's lifecycle. + */ +export class RegisteredToolEntity implements RegisteredDefinition { + private _name: string; + private _enabled: boolean = true; + private readonly _registry: ToolRegistry; + + private _title?: string; + private _description?: string; + private _inputSchema?: AnySchema; + private _outputSchema?: AnySchema; + private _annotations?: ToolAnnotations; + private _execution?: ToolExecution; + private __meta?: Record; + private _handler: ToolHandler; + + constructor(config: ToolConfig, registry: ToolRegistry) { + this._name = config.name; + this._registry = registry; + this._title = config.title; + this._description = config.description; + this._inputSchema = config.inputSchema; + this._outputSchema = config.outputSchema; + this._annotations = config.annotations; + this._execution = config.execution; + this.__meta = config._meta; + this._handler = config.handler; + } + + /** The tool's name (identifier) */ + get name(): string { + return this._name; + } + + /** Whether the tool is currently enabled */ + get enabled(): boolean { + return this._enabled; + } + + /** The tool's title */ + get title(): string | undefined { + return this._title; + } + + /** The tool's description */ + get description(): string | undefined { + return this._description; + } + + /** The tool's input schema */ + get inputSchema(): AnySchema | undefined { + return this._inputSchema; + } + + /** The tool's output schema */ + get outputSchema(): AnySchema | undefined { + return this._outputSchema; + } + + /** The tool's annotations */ + get annotations(): ToolAnnotations | undefined { + return this._annotations; + } + + /** The tool's execution settings */ + get execution(): ToolExecution | undefined { + return this._execution; + } + + /** The tool's metadata */ + get _meta(): Record | undefined { + return this.__meta; + } + + /** The tool's handler function */ + get handler(): ToolHandler { + return this._handler; + } + + /** + * Enables the tool + */ + enable(): this { + if (!this._enabled) { + this._enabled = true; + this._registry['notifyChanged'](); + } + return this; + } + + /** + * Disables the tool + */ + disable(): this { + if (this._enabled) { + this._enabled = false; + this._registry['notifyChanged'](); + } + return this; + } + + /** + * Removes the tool from its registry + */ + remove(): void { + this._registry.remove(this._name); + } + + /** + * Renames the tool + * + * @param newName - The new name for the tool + */ + rename(newName: string): this { + validateAndWarnToolName(newName); + this._registry['_rename'](this._name, newName); + this._name = newName; + return this; + } + + /** + * Updates the tool's properties + * + * @param updates - The updates to apply + */ + update(updates: ToolUpdates): void { + if (updates.name !== undefined) { + if (updates.name === null) { + this.remove(); + return; + } + this.rename(updates.name); + } + if (updates.title !== undefined) this._title = updates.title; + if (updates.description !== undefined) this._description = updates.description; + if (updates.inputSchema !== undefined) this._inputSchema = updates.inputSchema; + if (updates.outputSchema !== undefined) this._outputSchema = updates.outputSchema; + if (updates.annotations !== undefined) this._annotations = updates.annotations; + if (updates.execution !== undefined) this._execution = updates.execution; + if (updates._meta !== undefined) this.__meta = updates._meta; + if (updates.handler !== undefined) this._handler = updates.handler; + if (updates.enabled === undefined) { + this._registry['notifyChanged'](); + } else { + if (updates.enabled) { + this.enable(); + } else { + this.disable(); + } + } + } + + /** + * Converts to the Tool protocol type (for list responses) + */ + toProtocolTool(): Tool { + const tool: Tool = { + name: this._name, + title: this._title, + description: this._description, + inputSchema: this._inputSchema + ? (toJsonSchemaCompat(normalizeObjectSchema(this._inputSchema) ?? this._inputSchema, { + strictUnions: true, + pipeStrategy: 'input' + }) as Tool['inputSchema']) + : EMPTY_OBJECT_JSON_SCHEMA, + annotations: this._annotations, + execution: this._execution, + _meta: this.__meta + }; + + if (this._outputSchema) { + const obj = normalizeObjectSchema(this._outputSchema); + if (obj) { + tool.outputSchema = toJsonSchemaCompat(obj, { + strictUnions: true, + pipeStrategy: 'output' + }) as Tool['outputSchema']; + } + } + + return tool; + } +} + +/** + * Registry for managing tools. + */ +export class ToolRegistry extends BaseRegistry { + /** + * Creates a new ToolRegistry. + * + * @param sendNotification - Optional callback to invoke when the tool list changes. + * Can be set later via setNotifyCallback(). + */ + constructor(sendNotification?: () => void) { + super(); + if (sendNotification) { + this.setNotifyCallback(sendNotification); + } + } + + /** + * Registers a new tool. + * + * @param config - The tool configuration + * @returns The registered tool + * @throws If a tool with the same name already exists + */ + register(config: ToolConfig): RegisteredToolEntity { + if (this._items.has(config.name)) { + throw new Error(`Tool '${config.name}' is already registered`); + } + + validateAndWarnToolName(config.name); + const tool = new RegisteredToolEntity(config, this); + this._set(config.name, tool); + this.notifyChanged(); + return tool; + } + + /** + * Gets the list of enabled tools in protocol format. + * + * @returns Array of Tool objects for the protocol response + */ + getProtocolTools(): Tool[] { + return this.getEnabled().map(tool => tool.toProtocolTool()); + } + + /** + * Gets a tool by name. + * + * @param name - The tool name + * @returns The registered tool or undefined + */ + getTool(name: string): RegisteredToolEntity | undefined { + return this.get(name); + } +} diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index 36dc55835..cd2ff6a0e 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -11,6 +11,7 @@ import type { ElicitRequestFormParams, ElicitRequestURLParams, ElicitResult, + ErrorInterceptor, Implementation, InitializeRequest, InitializeResult, @@ -34,9 +35,6 @@ import type { ServerNotification, ServerRequest, ServerResult, - TaskContext, - TaskCreationParams, - TaskStore, ToolResultContent, ToolUseContent, Transport, @@ -230,6 +228,32 @@ export class Server< this._capabilities = mergeCapabilities(this._capabilities, capabilities); } + /** + * Sets an error interceptor that can customize error responses before they are sent. + * + * The interceptor is called for both protocol errors (method not found, etc.) and + * application errors (when a handler throws). It can modify the error message and data. + * For application errors, it can also change the error code. + * + * @param interceptor - The error interceptor function, or undefined to clear + * + * @example + * ```typescript + * server.setErrorInterceptor(async (error, ctx) => { + * console.error(`Error in ${ctx.method}: ${error.message}`); + * if (ctx.type === 'application') { + * return { + * message: 'An error occurred', + * data: { originalMessage: error.message } + * }; + * } + * }); + * ``` + */ + public override setErrorInterceptor(interceptor: ErrorInterceptor | undefined): void { + super.setErrorInterceptor(interceptor); + } + /** * Override request handler registration to enforce server-side validation for tools/call. */ @@ -535,14 +559,11 @@ export class Server< protected createRequestContext(args: { request: JSONRPCRequest; - taskStore: TaskStore | undefined; - relatedTaskId: string | undefined; - taskCreationParams: TaskCreationParams | undefined; abortController: AbortController; capturedTransport: Transport | undefined; extra?: MessageExtraInfo; }): ContextInterface { - const { request, taskStore, relatedTaskId, taskCreationParams, abortController, capturedTransport, extra } = args; + const { request, abortController, capturedTransport, extra } = args; const sessionId = capturedTransport?.sessionId; // Build the MCP context using the helper from Protocol @@ -561,22 +582,12 @@ export class Server< } }; - // Build the task context using the helper from Protocol - const taskCtx: TaskContext | undefined = this.buildTaskContext({ - taskStore, - request, - sessionId, - relatedTaskId, - taskCreationParams - }); - - // Return a ServerContext instance + // Return a ServerContext instance (task context is added by plugins if needed) return new ServerContext({ server: this, request, mcpContext, - requestCtx, - task: taskCtx + requestCtx }); } diff --git a/packages/server/src/server/sessions.ts b/packages/server/src/server/sessions.ts new file mode 100644 index 000000000..09e835eaa --- /dev/null +++ b/packages/server/src/server/sessions.ts @@ -0,0 +1,344 @@ +/** + * Session Management Abstraction + * + * Provides a SessionStore interface and implementations for managing + * server session state. This replaces the manual session map management + * patterns seen across examples. + */ + +/** + * Session lifecycle event callbacks + */ +export interface SessionStoreEvents { + /** + * Called when a new session is created + */ + onSessionCreated?: (sessionId: string, data: T) => void; + + /** + * Called when a session is destroyed + */ + onSessionDestroyed?: (sessionId: string) => void; + + /** + * Called when session data is updated + */ + onSessionUpdated?: (sessionId: string, data: T) => void; +} + +/** + * Interface for session storage implementations. + * + * @template T - The type of session data + */ +export interface SessionStore { + /** + * Gets the session data for a given session ID. + * + * @param sessionId - The session identifier + * @returns The session data or undefined if not found + */ + get(sessionId: string): T | undefined; + + /** + * Sets the session data for a given session ID. + * Creates a new session if it doesn't exist. + * + * @param sessionId - The session identifier + * @param data - The session data to store + */ + set(sessionId: string, data: T): void; + + /** + * Deletes a session. + * + * @param sessionId - The session identifier + * @returns true if the session was deleted, false if it didn't exist + */ + delete(sessionId: string): boolean; + + /** + * Checks if a session exists. + * + * @param sessionId - The session identifier + * @returns true if the session exists + */ + has(sessionId: string): boolean; + + /** + * Gets the number of active sessions. + */ + size(): number; + + /** + * Gets all session IDs. + */ + keys(): string[]; + + /** + * Clears all sessions. + */ + clear(): void; +} + +/** + * Options for InMemorySessionStore + */ +export interface InMemorySessionStoreOptions { + /** + * Maximum number of sessions to store. + * When exceeded, the oldest session will be evicted. + * Default: unlimited + */ + maxSessions?: number; + + /** + * Session timeout in milliseconds. + * Sessions older than this will be automatically cleaned up. + * Default: no timeout + */ + sessionTimeout?: number; + + /** + * Interval for checking expired sessions in milliseconds. + * Default: 60000 (1 minute) + */ + cleanupInterval?: number; + + /** + * Event callbacks + */ + events?: SessionStoreEvents; +} + +/** + * Internal session entry with metadata + */ +interface SessionEntry { + data: T; + createdAt: number; + lastAccessedAt: number; +} + +/** + * In-memory implementation of SessionStore. + * + * Features: + * - Optional maximum session limit with LRU eviction + * - Optional session timeout with automatic cleanup + * - Lifecycle event callbacks + * + * @template T - The type of session data + */ +export class InMemorySessionStore implements SessionStore { + private _sessions = new Map>(); + private _options: InMemorySessionStoreOptions; + private _cleanupTimer?: ReturnType; + + constructor(options: InMemorySessionStoreOptions = {}) { + this._options = options; + + // Set up automatic cleanup if timeout is configured + if (options.sessionTimeout && options.sessionTimeout > 0) { + const interval = options.cleanupInterval ?? 60_000; + this._cleanupTimer = setInterval(() => { + this._cleanupExpiredSessions(); + }, interval); + + // Prevent timer from keeping process alive + if (typeof this._cleanupTimer.unref === 'function') { + this._cleanupTimer.unref(); + } + } + } + + /** + * Gets the session data for a given session ID. + * Updates the last accessed time on access. + */ + get(sessionId: string): T | undefined { + const entry = this._sessions.get(sessionId); + if (!entry) { + return undefined; + } + + // Check if expired + if (this._isExpired(entry)) { + this.delete(sessionId); + return undefined; + } + + // Update last accessed time + entry.lastAccessedAt = Date.now(); + return entry.data; + } + + /** + * Sets the session data for a given session ID. + * Creates a new session if it doesn't exist. + */ + set(sessionId: string, data: T): void { + const existing = this._sessions.get(sessionId); + const now = Date.now(); + + if (existing) { + // Update existing session + existing.data = data; + existing.lastAccessedAt = now; + this._options.events?.onSessionUpdated?.(sessionId, data); + } else { + // Create new session + // Check max sessions limit + if (this._options.maxSessions && this._sessions.size >= this._options.maxSessions) { + this._evictOldestSession(); + } + + this._sessions.set(sessionId, { + data, + createdAt: now, + lastAccessedAt: now + }); + this._options.events?.onSessionCreated?.(sessionId, data); + } + } + + /** + * Deletes a session. + */ + delete(sessionId: string): boolean { + const deleted = this._sessions.delete(sessionId); + if (deleted) { + this._options.events?.onSessionDestroyed?.(sessionId); + } + return deleted; + } + + /** + * Checks if a session exists. + */ + has(sessionId: string): boolean { + const entry = this._sessions.get(sessionId); + if (!entry) { + return false; + } + + // Check if expired + if (this._isExpired(entry)) { + this.delete(sessionId); + return false; + } + + return true; + } + + /** + * Gets the number of active sessions. + */ + size(): number { + return this._sessions.size; + } + + /** + * Gets all session IDs. + */ + keys(): string[] { + return [...this._sessions.keys()]; + } + + /** + * Clears all sessions. + */ + clear(): void { + const sessionIds = this.keys(); + for (const sessionId of sessionIds) { + this.delete(sessionId); + } + } + + /** + * Stops the cleanup timer. + * Call this when the store is no longer needed. + */ + dispose(): void { + if (this._cleanupTimer) { + clearInterval(this._cleanupTimer); + this._cleanupTimer = undefined; + } + } + + /** + * Checks if a session entry is expired. + */ + private _isExpired(entry: SessionEntry): boolean { + if (!this._options.sessionTimeout) { + return false; + } + const age = Date.now() - entry.lastAccessedAt; + return age > this._options.sessionTimeout; + } + + /** + * Evicts the oldest session (by last access time). + */ + private _evictOldestSession(): void { + let oldestId: string | undefined; + let oldestTime = Infinity; + + for (const [id, entry] of this._sessions) { + if (entry.lastAccessedAt < oldestTime) { + oldestTime = entry.lastAccessedAt; + oldestId = id; + } + } + + if (oldestId) { + this.delete(oldestId); + } + } + + /** + * Cleans up expired sessions. + */ + private _cleanupExpiredSessions(): void { + for (const [sessionId, entry] of this._sessions) { + if (this._isExpired(entry)) { + this.delete(sessionId); + } + } + } +} + +/** + * Creates a new in-memory session store. + * + * @example + * ```typescript + * const sessionStore = createSessionStore<{ userId: string }>({ + * sessionTimeout: 30 * 60 * 1000, // 30 minutes + * maxSessions: 1000, + * events: { + * onSessionCreated: (id) => console.log(`Session created: ${id}`), + * onSessionDestroyed: (id) => console.log(`Session destroyed: ${id}`), + * } + * }); + * ``` + */ +export function createSessionStore(options?: InMemorySessionStoreOptions): InMemorySessionStore { + return new InMemorySessionStore(options); +} + +/** + * Session ID generator using crypto.randomUUID. + * Falls back to Math.random if crypto is not available. + */ +export function generateSessionId(): string { + if (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') { + return crypto.randomUUID(); + } + // Fallback for environments without crypto.randomUUID + return 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replaceAll(/[xy]/g, c => { + const r = Math.trunc(Math.random() * 16); + const v = c === 'x' ? r : (r & 0x3) | 0x8; + return v.toString(16); + }); +} diff --git a/src/conformance/everything-server.ts b/src/conformance/everything-server.ts index 411aebd5e..a7d7a8f21 100644 --- a/src/conformance/everything-server.ts +++ b/src/conformance/everything-server.ts @@ -11,13 +11,17 @@ import { randomUUID } from 'node:crypto'; import type { CallToolResult, GetPromptResult, ReadResourceResult, EventId, EventStore, StreamId } from '@modelcontextprotocol/server'; import { + audio, CompleteRequestSchema, ElicitResultSchema, + embeddedResource, + image, isInitializeRequest, - SetLevelRequestSchema, McpServer, ResourceTemplate, + SetLevelRequestSchema, SubscribeRequestSchema, + text, UnsubscribeRequestSchema } from '@modelcontextprotocol/server'; import { NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/node'; @@ -127,7 +131,7 @@ function createMcpServer(sessionId?: string) { }, async (): Promise => { return { - content: [{ type: 'text', text: 'This is a simple text response for testing.' }] + content: [text('This is a simple text response for testing.')] }; } ); @@ -140,7 +144,7 @@ function createMcpServer(sessionId?: string) { }, async (): Promise => { return { - content: [{ type: 'image', data: TEST_IMAGE_BASE64, mimeType: 'image/png' }] + content: [image(TEST_IMAGE_BASE64, 'image/png')] }; } ); @@ -153,7 +157,7 @@ function createMcpServer(sessionId?: string) { }, async (): Promise => { return { - content: [{ type: 'audio', data: TEST_AUDIO_BASE64, mimeType: 'audio/wav' }] + content: [audio(TEST_AUDIO_BASE64, 'audio/wav')] }; } ); @@ -167,14 +171,11 @@ function createMcpServer(sessionId?: string) { async (): Promise => { return { content: [ - { - type: 'resource', - resource: { - uri: 'test://embedded-resource', - mimeType: 'text/plain', - text: 'This is an embedded resource content.' - } - } + embeddedResource({ + uri: 'test://embedded-resource', + mimeType: 'text/plain', + text: 'This is an embedded resource content.' + }) ] }; } @@ -189,16 +190,13 @@ function createMcpServer(sessionId?: string) { async (): Promise => { return { content: [ - { type: 'text', text: 'Multiple content types test:' }, - { type: 'image', data: TEST_IMAGE_BASE64, mimeType: 'image/png' }, - { - type: 'resource', - resource: { - uri: 'test://mixed-content-resource', - mimeType: 'application/json', - text: JSON.stringify({ test: 'data', value: 123 }) - } - } + text('Multiple content types test:'), + image(TEST_IMAGE_BASE64, 'image/png'), + embeddedResource({ + uri: 'test://mixed-content-resource', + mimeType: 'application/json', + text: JSON.stringify({ test: 'data', value: 123 }) + }) ] }; } @@ -238,7 +236,7 @@ function createMcpServer(sessionId?: string) { } }); return { - content: [{ type: 'text', text: 'Tool with logging executed successfully' }] + content: [text('Tool with logging executed successfully')] }; } ); @@ -286,7 +284,7 @@ function createMcpServer(sessionId?: string) { }); return { - content: [{ type: 'text', text: String(progressToken) }] + content: [text(String(progressToken))] }; } ); @@ -819,21 +817,15 @@ function createMcpServer(sessionId?: string) { messages: [ { role: 'user', - content: { - type: 'resource', - resource: { - uri: args.resourceUri, - mimeType: 'text/plain', - text: 'Embedded resource content for testing.' - } - } + content: embeddedResource({ + uri: args.resourceUri, + mimeType: 'text/plain', + text: 'Embedded resource content for testing.' + }) }, { role: 'user', - content: { - type: 'text', - text: 'Please process the embedded resource above.' - } + content: text('Please process the embedded resource above.') } ] }; @@ -852,15 +844,11 @@ function createMcpServer(sessionId?: string) { messages: [ { role: 'user', - content: { - type: 'image', - data: TEST_IMAGE_BASE64, - mimeType: 'image/png' - } + content: image(TEST_IMAGE_BASE64, 'image/png') }, { role: 'user', - content: { type: 'text', text: 'Please analyze the image above.' } + content: text('Please analyze the image above.') } ] }; diff --git a/test/integration/test/server.test.ts b/test/integration/test/server.test.ts index 9b75dfc04..082892fa3 100644 --- a/test/integration/test/server.test.ts +++ b/test/integration/test/server.test.ts @@ -19,6 +19,7 @@ import { ElicitResultSchema, ErrorCode, InMemoryTransport, + isTextContent, LATEST_PROTOCOL_VERSION, ListPromptsRequestSchema, ListResourcesRequestSchema, @@ -1952,7 +1953,7 @@ describe('createMessage backwards compatibility', () => { expect(result.model).toBe('test-model'); expect(Array.isArray(result.content)).toBe(false); expect(result.content.type).toBe('text'); - if (result.content.type === 'text') { + if (isTextContent(result.content)) { expect(result.content.text).toBe('Hello from LLM'); } }); diff --git a/test/integration/test/server/mcp.test.ts b/test/integration/test/server/mcp.test.ts index 85692b8b1..1f0a6eb16 100644 --- a/test/integration/test/server/mcp.test.ts +++ b/test/integration/test/server/mcp.test.ts @@ -22,6 +22,7 @@ import { ListToolsResultSchema, LoggingMessageNotificationSchema, ReadResourceResultSchema, + text, UriTemplate, UrlElicitationRequiredError } from '@modelcontextprotocol/core'; @@ -516,7 +517,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, async () => ({ - content: [{ type: 'text', text: '' }], + content: [text('')], structuredContent: { result: 42 } @@ -530,7 +531,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { sum: z.number() }, callback: async () => ({ - content: [{ type: 'text', text: '' }], + content: [text('')], structuredContent: { result: 42, sum: 100 @@ -660,7 +661,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { inputSchema: { name: z.string(), value: z.number() } }, async ({ name, value }) => ({ - content: [{ type: 'text', text: `${name}: ${value}` }] + content: [text(`${name}: ${value}`)] }) ); @@ -809,7 +810,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { annotations: { title: 'Test Tool', readOnlyHint: true } }, async ({ name }) => ({ - content: [{ type: 'text', text: `Hello, ${name}!` }] + content: [text(`Hello, ${name}!`)] }) ); @@ -856,7 +857,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, async ({ name }) => ({ - content: [{ type: 'text', text: `Hello, ${name}!` }] + content: [text(`Hello, ${name}!`)] }) ); @@ -904,7 +905,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, async () => ({ - content: [{ type: 'text', text: 'Test response' }] + content: [text('Test response')] }) ); @@ -1709,7 +1710,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { _meta: metaData }, async ({ name }) => ({ - content: [{ type: 'text', text: `Hello, ${name}!` }] + content: [text(`Hello, ${name}!`)] }) ); @@ -1745,7 +1746,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { inputSchema: { name: z.string() } }, async ({ name }) => ({ - content: [{ type: 'text', text: `Hello, ${name}!` }] + content: [text(`Hello, ${name}!`)] }) ); @@ -1917,7 +1918,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { { description: 'A valid tool name' }, - async () => ({ content: [{ type: 'text', text: 'Success' }] }) + async () => ({ content: [text('Success')] }) ); // Test tool name with warnings (starts with dash) @@ -1926,7 +1927,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { { description: 'A tool name that generates warnings' }, - async () => ({ content: [{ type: 'text', text: 'Success' }] }) + async () => ({ content: [text('Success')] }) ); // Test invalid tool name (contains spaces) @@ -1935,7 +1936,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { { description: 'An invalid tool name' }, - async () => ({ content: [{ type: 'text', text: 'Success' }] }) + async () => ({ content: [text('Success')] }) ); // Verify that warnings were issued (both for warnings and validation failures) @@ -4041,7 +4042,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Tool 1: Only name mcpServer.registerTool('tool_name_only', {}, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [text('Response')] })); // Tool 2: Name and annotations.title @@ -4054,7 +4055,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [text('Response')] }) ); @@ -4066,7 +4067,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { description: 'Tool with regular title' }, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [text('Response')] }) ); @@ -4081,7 +4082,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [text('Response')] }) ); @@ -5334,7 +5335,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Tool 1: Only name mcpServer.registerTool('tool_name_only', {}, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [text('Response')] })); // Tool 2: Name and annotations.title @@ -5347,7 +5348,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [text('Response')] }) ); @@ -5359,7 +5360,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { description: 'Tool with regular title' }, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [text('Response')] }) ); @@ -5374,7 +5375,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } }, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [text('Response')] }) ); @@ -5972,10 +5973,10 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { server.registerTool('contact', { inputSchema: unionSchema }, async args => { return args.type === 'email' ? { - content: [{ type: 'text', text: `Email contact: ${args.email}` }] + content: [text(`Email contact: ${args.email}`)] } : { - content: [{ type: 'text', text: `Phone contact: ${args.phone}` }] + content: [text(`Phone contact: ${args.phone}`)] }; }); @@ -6136,7 +6137,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { server.registerTool('union-test', { inputSchema: unionSchema }, async () => { return { - content: [{ type: 'text', text: 'Success' }] + content: [text('Success')] }; }); diff --git a/test/integration/test/taskResumability.test.ts b/test/integration/test/taskResumability.test.ts index ce124eb93..f9847989c 100644 --- a/test/integration/test/taskResumability.test.ts +++ b/test/integration/test/taskResumability.test.ts @@ -5,7 +5,7 @@ import { createServer } from 'node:http'; import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; import { NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/node'; import type { EventStore, JSONRPCMessage } from '@modelcontextprotocol/server'; -import { CallToolResultSchema, LoggingMessageNotificationSchema, McpServer } from '@modelcontextprotocol/server'; +import { CallToolResultSchema, LoggingMessageNotificationSchema, McpServer, text } from '@modelcontextprotocol/server'; import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; import { listenOnRandomPort, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; @@ -79,7 +79,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); return { - content: [{ type: 'text', text: 'Notification sent' }] + content: [text('Notification sent')] }; } ); @@ -112,7 +112,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { } return { - content: [{ type: 'text', text: `Sent ${count} notifications` }] + content: [text(`Sent ${count} notifications`)] }; } ); From 27989ca7e69ec6afd5cceb754f3347d3a3420a53 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Mon, 26 Jan 2026 09:01:01 +0200 Subject: [PATCH 16/17] save commit --- .../client/src/simpleStreamableHttpBuilder.ts | 23 +- .../server/src/simpleStreamableHttpBuilder.ts | 106 ++---- packages/client/src/client/builder.ts | 5 +- packages/client/src/client/middleware.ts | 2 +- .../src/experimental/tasks/mcpServer.ts | 3 +- packages/server/src/index.ts | 1 - packages/server/src/server/builder.ts | 3 +- packages/server/src/server/mcp.ts | 134 +------ packages/server/src/server/middleware.ts | 2 +- .../src/server/registries/promptRegistry.ts | 87 +++-- .../src/server/registries/resourceRegistry.ts | 3 +- .../src/server/registries/toolRegistry.ts | 156 ++++---- packages/server/src/server/sessions.ts | 344 ------------------ packages/server/src/types/types.ts | 155 ++++++++ 14 files changed, 348 insertions(+), 676 deletions(-) delete mode 100644 packages/server/src/server/sessions.ts create mode 100644 packages/server/src/types/types.ts diff --git a/examples/client/src/simpleStreamableHttpBuilder.ts b/examples/client/src/simpleStreamableHttpBuilder.ts index 4246f7355..32d8ba8b1 100644 --- a/examples/client/src/simpleStreamableHttpBuilder.ts +++ b/examples/client/src/simpleStreamableHttpBuilder.ts @@ -38,8 +38,7 @@ import { ListPromptsResultSchema, ListResourcesResultSchema, ListToolsResultSchema, - LoggingMessageNotificationSchema -, + LoggingMessageNotificationSchema, ReadResourceResultSchema, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; @@ -120,7 +119,6 @@ export function createClientLoggingMiddleware(options: ClientLoggingMiddlewareOp }; } - /** * Options for retry middleware. */ @@ -169,7 +167,6 @@ export function createRetryMiddleware(options: RetryMiddlewareOptions = {}): Out }; } - /** * Custom tool call instrumentation middleware. * Logs tool calls with timing information. @@ -260,7 +257,7 @@ function printHelp(): void { } function commandLoop(): void { - readline.question('\n> ', async (input) => { + readline.question('\n> ', async input => { const args = input.trim().split(/\s+/); const command = args[0]?.toLowerCase(); @@ -488,14 +485,10 @@ async function connect(url?: string): Promise { createRetryMiddleware({ maxRetries: 3, baseDelay: 100, - isRetryable: (error) => { + isRetryable: error => { // Retry on network errors const message = error instanceof Error ? error.message : String(error); - return ( - message.includes('ECONNREFUSED') || - message.includes('ETIMEDOUT') || - message.includes('network') - ); + return message.includes('ECONNREFUSED') || message.includes('ETIMEDOUT') || message.includes('network'); } }) ) @@ -506,7 +499,7 @@ async function connect(url?: string): Promise { // ─── Request Handlers ─── // Sampling request handler (when server requests LLM completion) - .onSamplingRequest(async (params) => { + .onSamplingRequest(async params => { console.log('\n[SAMPLING] Received sampling request from server'); console.log('[SAMPLING] Messages:', JSON.stringify(params, null, 2)); @@ -523,7 +516,7 @@ async function connect(url?: string): Promise { }) // Elicitation handler (when server requests user input) - .onElicitation(async (params) => { + .onElicitation(async params => { const elicitParams = params as { mode?: string; message?: string; requestedSchema?: unknown }; console.log('\n[ELICITATION] Received elicitation request from server'); console.log('[ELICITATION] Mode:', elicitParams.mode); @@ -574,7 +567,7 @@ async function connect(url?: string): Promise { .build(); // Set up client error handler - client.onerror = (error) => { + client.onerror = error => { console.error('\n[CLIENT] Error event:', error); }; @@ -584,7 +577,7 @@ async function connect(url?: string): Promise { }); // Set up notification handler for logging messages - client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { + client.setNotificationHandler(LoggingMessageNotificationSchema, notification => { notificationCount++; console.log(`\n[NOTIFICATION #${notificationCount}] ${notification.params.level}: ${notification.params.data}`); process.stdout.write('> '); diff --git a/examples/server/src/simpleStreamableHttpBuilder.ts b/examples/server/src/simpleStreamableHttpBuilder.ts index de8574c9a..d1c64f8a4 100644 --- a/examples/server/src/simpleStreamableHttpBuilder.ts +++ b/examples/server/src/simpleStreamableHttpBuilder.ts @@ -4,10 +4,9 @@ * This example demonstrates using the McpServer.builder() fluent API * to create and configure an MCP server with: * - Tools, resources, and prompts registration - * - Middleware (logging, rate limiting, custom metrics) + * - Middleware (logging, custom metrics) * - Per-tool middleware (authorization) * - Error handlers (onError, onProtocolError) - * - Session management with SessionStore * - Context helpers (logging, notifications) * * Run with: npx tsx src/simpleStreamableHttpBuilder.ts @@ -17,14 +16,8 @@ import { randomUUID } from 'node:crypto'; import { createMcpExpressApp } from '@modelcontextprotocol/express'; import { NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/node'; -import type { CallToolResult, GetPromptResult, ReadResourceResult ,ToolMiddleware} from '@modelcontextprotocol/server'; -import { - createLoggingMiddleware, - createSessionStore, - isInitializeRequest, - McpServer, - text -} from '@modelcontextprotocol/server'; +import type { CallToolResult, GetPromptResult, ReadResourceResult, ToolMiddleware } from '@modelcontextprotocol/server'; +import { createLoggingMiddleware, isInitializeRequest, McpServer, text } from '@modelcontextprotocol/server'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; @@ -68,7 +61,7 @@ const adminAuthMiddleware: ToolMiddleware = async (ctx, next) => { }; // ═══════════════════════════════════════════════════════════════════════════ -// Session Store Setup +// Session Management // ═══════════════════════════════════════════════════════════════════════════ /** @@ -80,25 +73,9 @@ interface SessionData { } /** - * Create session store with lifecycle events and timeout. - * This replaces the manual session map management. + * Simple Map-based session storage. */ -const sessionStore = createSessionStore({ - sessionTimeout: 30 * 60 * 1000, // 30 minutes - maxSessions: 100, - cleanupInterval: 60_000, // Check for expired sessions every minute - events: { - onSessionCreated: (id) => { - console.log(`[SESSION] Created: ${id}`); - }, - onSessionDestroyed: (id) => { - console.log(`[SESSION] Destroyed: ${id}`); - }, - onSessionUpdated: (id) => { - console.log(`[SESSION] Updated: ${id}`); - } - } -}); +const sessions = new Map(); // ═══════════════════════════════════════════════════════════════════════════ // Server Factory @@ -140,12 +117,10 @@ const getServer = () => { ) // ─── Tool-Specific Middleware ─── - .useToolMiddleware( - async (ctx, next) => { - console.log(`Tool '${ctx.name}' called`); - return next(); - } - ) + .useToolMiddleware(async (ctx, next) => { + console.log(`Tool '${ctx.name}' called`); + return next(); + }) // Custom metrics middleware .useToolMiddleware(metricsMiddleware) @@ -199,7 +174,7 @@ const getServer = () => { } }, async function ({ name }, ctx): Promise { - const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)); + const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); // Use context logging helper await ctx.loggingNotification.debug(`Starting multi-greet for ${name}`); @@ -304,7 +279,10 @@ const getServer = () => { } }, async ({ errorType }): Promise => { - const error = errorType === 'application' ? new Error('This is a test application error') : new Error('Validation failed: invalid input format'); + const error = + errorType === 'application' + ? new Error('This is a test application error') + : new Error('Validation failed: invalid input format'); throw error; } ) @@ -330,23 +308,23 @@ const getServer = () => { } ) - // Resource demonstrating session info + // Resource demonstrating server info .resource( - 'session-info', - 'https://example.com/session/info', + 'server-info', + 'https://example.com/server/info', { - title: 'Session Information', - description: 'Returns current session statistics' + title: 'Server Information', + description: 'Returns current server statistics' }, async (): Promise => { const stats = { - activeSessions: sessionStore.size(), - sessionIds: sessionStore.keys() + activeSessions: sessions.size, + uptime: process.uptime() }; return { contents: [ { - uri: 'https://example.com/session/info', + uri: 'https://example.com/server/info', mimeType: 'application/json', text: JSON.stringify(stats, null, 2) } @@ -396,14 +374,14 @@ const app = createMcpExpressApp(); /** * MCP POST endpoint handler. - * Uses SessionStore for session management. + * Uses a simple Map for session management. */ const mcpPostHandler = async (req: Request, res: Response) => { const sessionId = req.headers['mcp-session-id'] as string | undefined; try { // Check for existing session - const session = sessionId ? sessionStore.get(sessionId) : undefined; + const session = sessionId ? sessions.get(sessionId) : undefined; if (session) { // Reuse existing transport @@ -420,9 +398,9 @@ const mcpPostHandler = async (req: Request, res: Response) => { const transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), eventStore, - onsessioninitialized: (sid) => { - // Store session with SessionStore - sessionStore.set(sid, { + onsessioninitialized: sid => { + // Store session + sessions.set(sid, { transport, createdAt: new Date() }); @@ -433,7 +411,7 @@ const mcpPostHandler = async (req: Request, res: Response) => { transport.onclose = () => { const sid = transport.sessionId; if (sid) { - sessionStore.delete(sid); + sessions.delete(sid); } }; @@ -476,7 +454,7 @@ app.post('/mcp', mcpPostHandler); */ const mcpGetHandler = async (req: Request, res: Response) => { const sessionId = req.headers['mcp-session-id'] as string | undefined; - const session = sessionId ? sessionStore.get(sessionId) : undefined; + const session = sessionId ? sessions.get(sessionId) : undefined; if (!session) { res.status(400).send('Invalid or missing session ID'); @@ -500,7 +478,7 @@ app.get('/mcp', mcpGetHandler); */ const mcpDeleteHandler = async (req: Request, res: Response) => { const sessionId = req.headers['mcp-session-id'] as string | undefined; - const session = sessionId ? sessionStore.get(sessionId) : undefined; + const session = sessionId ? sessions.get(sessionId) : undefined; if (!session) { res.status(400).send('Invalid or missing session ID'); @@ -525,7 +503,7 @@ app.delete('/mcp', mcpDeleteHandler); // Server Startup // ═══════════════════════════════════════════════════════════════════════════ -app.listen(PORT, (error) => { +app.listen(PORT, error => { if (error) { console.error('Failed to start server:', error); // eslint-disable-next-line unicorn/no-process-exit @@ -540,10 +518,9 @@ app.listen(PORT, (error) => { console.log('Features demonstrated:'); console.log(' - Builder pattern for server configuration'); console.log(' - Universal middleware (logging)'); - console.log(' - Tool-specific middleware (rate limiting, metrics)'); + console.log(' - Tool-specific middleware (metrics)'); console.log(' - Per-tool middleware (authorization)'); console.log(' - Error handlers (onError, onProtocolError)'); - console.log(' - Session management with SessionStore'); console.log(' - Context helpers (logging, notifications)'); console.log('═══════════════════════════════════════════════════════════════'); }); @@ -555,23 +532,18 @@ app.listen(PORT, (error) => { process.on('SIGINT', async () => { console.log('\n[SHUTDOWN] Received SIGINT, shutting down...'); - // Close all sessions using SessionStore - const sessionIds = sessionStore.keys(); - for (const sid of sessionIds) { + // Close all sessions + for (const [sid, session] of sessions) { try { - const session = sessionStore.get(sid); - if (session) { - console.log(`[SHUTDOWN] Closing session ${sid}`); - await session.transport.close(); - } + console.log(`[SHUTDOWN] Closing session ${sid}`); + await session.transport.close(); } catch (error) { console.error(`[SHUTDOWN] Error closing session ${sid}:`, error); } } - // Clear the session store (also stops cleanup timer) - sessionStore.clear(); - (sessionStore as { dispose?: () => void }).dispose?.(); + // Clear the sessions map + sessions.clear(); console.log('[SHUTDOWN] Complete'); process.exit(0); diff --git a/packages/client/src/client/builder.ts b/packages/client/src/client/builder.ts index eb5423887..2724097f9 100644 --- a/packages/client/src/client/builder.ts +++ b/packages/client/src/client/builder.ts @@ -71,10 +71,7 @@ export type ElicitationRequestHandler = ( * Handler for roots list requests from the server. * Receives the full ListRootsRequest and returns the list of roots. */ -export type RootsListHandler = ( - request: ListRootsRequest, - ctx: ClientContextInterface -) => ListRootsResult | Promise; +export type RootsListHandler = (request: ListRootsRequest, ctx: ClientContextInterface) => ListRootsResult | Promise; /** * Error handler type for application errors diff --git a/packages/client/src/client/middleware.ts b/packages/client/src/client/middleware.ts index 6525fb32a..1fc1dcd91 100644 --- a/packages/client/src/client/middleware.ts +++ b/packages/client/src/client/middleware.ts @@ -727,4 +727,4 @@ export class ClientMiddlewareManager { return dispatch(0); } -} \ No newline at end of file +} diff --git a/packages/server/src/experimental/tasks/mcpServer.ts b/packages/server/src/experimental/tasks/mcpServer.ts index 9c065f20c..a8a526d1e 100644 --- a/packages/server/src/experimental/tasks/mcpServer.ts +++ b/packages/server/src/experimental/tasks/mcpServer.ts @@ -7,7 +7,8 @@ import type { AnySchema, TaskToolExecution, ToolAnnotations, ToolExecution, ZodRawShapeCompat } from '@modelcontextprotocol/core'; -import type { AnyToolHandler, McpServer, RegisteredTool } from '../../server/mcp.js'; +import type { AnyToolHandler, McpServer } from '../../server/mcp.js'; +import type { RegisteredTool } from '../../server/registries/toolRegistry.js'; import type { ToolTaskHandler } from './interfaces.js'; /** diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index 51f3e0ace..df415bf24 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -6,7 +6,6 @@ export * from './server/middleware.js'; export * from './server/middleware/hostHeaderValidation.js'; export * from './server/registries/index.js'; export * from './server/server.js'; -export * from './server/sessions.js'; export * from './server/stdio.js'; export * from './server/streamableHttp.js'; diff --git a/packages/server/src/server/builder.ts b/packages/server/src/server/builder.ts index 380dee527..978abe4ec 100644 --- a/packages/server/src/server/builder.ts +++ b/packages/server/src/server/builder.ts @@ -19,7 +19,8 @@ import type { ToolAnnotations, ToolExecution, ZodRawShapeCompat } from '@modelcontextprotocol/core'; import { objectFromShape } from '@modelcontextprotocol/core'; -import type { McpServer, PromptCallback, ReadResourceCallback, ResourceMetadata, ToolCallback } from './mcp.js'; +import type { PromptCallback, ReadResourceCallback } from '../types/types.js'; +import type { McpServer, ResourceMetadata, ToolCallback } from './mcp.js'; import type { PromptMiddleware, ResourceMiddleware, ToolMiddleware, UniversalMiddleware } from './middleware.js'; import { PromptRegistry } from './registries/promptRegistry.js'; import { ResourceRegistry } from './registries/resourceRegistry.js'; diff --git a/packages/server/src/server/mcp.ts b/packages/server/src/server/mcp.ts index be5b75831..1ac7283a7 100644 --- a/packages/server/src/server/mcp.ts +++ b/packages/server/src/server/mcp.ts @@ -17,7 +17,6 @@ import type { ListToolsResult, LoggingMessageNotification, PromptReference, - ReadResourceResult, Resource, ResourceTemplateReference, Result, @@ -28,7 +27,6 @@ import type { ToolAnnotations, ToolExecution, Transport, - Variables, ZodRawShapeCompat } from '@modelcontextprotocol/core'; import { @@ -56,6 +54,7 @@ import { ZodOptional } from 'zod'; import type { ToolTaskHandler } from '../experimental/tasks/interfaces.js'; import { ExperimentalMcpServerTasks } from '../experimental/tasks/mcpServer.js'; +import type { PromptArgsRawShape, PromptCallback, ReadResourceCallback, ReadResourceTemplateCallback } from '../types/types.js'; import type { BuilderResult, ErrorContext, @@ -77,11 +76,11 @@ import type { UniversalMiddleware } from './middleware.js'; import { MiddlewareManager } from './middleware.js'; -import type { RegisteredPromptEntity } from './registries/promptRegistry.js'; +import type { RegisteredPrompt } from './registries/promptRegistry.js'; import { PromptRegistry } from './registries/promptRegistry.js'; import type { RegisteredResourceEntity, RegisteredResourceTemplateEntity } from './registries/resourceRegistry.js'; import { ResourceRegistry, ResourceTemplateRegistry } from './registries/resourceRegistry.js'; -import type { RegisteredToolEntity } from './registries/toolRegistry.js'; +import type { RegisteredTool } from './registries/toolRegistry.js'; import { ToolRegistry } from './registries/toolRegistry.js'; import type { ServerOptions } from './server.js'; import { Server } from './server.js'; @@ -461,7 +460,7 @@ export class McpServer { * Validates tool input arguments against the tool's input schema. */ private async validateToolInput< - Tool extends RegisteredToolEntity, + Tool extends RegisteredTool, Args extends Tool['inputSchema'] extends infer InputSchema ? InputSchema extends AnySchema ? SchemaOutput @@ -489,11 +488,7 @@ export class McpServer { /** * Validates tool output against the tool's output schema. */ - private async validateToolOutput( - tool: RegisteredToolEntity, - result: CallToolResult | CreateTaskResult, - toolName: string - ): Promise { + private async validateToolOutput(tool: RegisteredTool, result: CallToolResult | CreateTaskResult, toolName: string): Promise { if (!tool.outputSchema) { return; } @@ -531,7 +526,7 @@ export class McpServer { * Executes a tool handler (either regular or task-based). */ private async executeToolHandler( - tool: RegisteredToolEntity, + tool: RegisteredTool, args: unknown, ctx: ServerContextInterface ): Promise { @@ -570,7 +565,7 @@ export class McpServer { * Handles automatic task polling for tools with taskSupport 'optional'. */ private async handleAutomaticTaskPolling( - tool: RegisteredToolEntity, + tool: RegisteredTool, request: RequestT, ctx: ServerContextInterface ): Promise { @@ -842,7 +837,10 @@ export class McpServer { if (!parseResult.success) { const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; const errorMessage = getParseErrorMessage(error); - throw new McpError(ErrorCode.InvalidParams, `Invalid arguments for prompt ${request.params.name}: ${errorMessage}`); + throw new McpError( + ErrorCode.InvalidParams, + `Invalid arguments for prompt ${request.params.name}: ${errorMessage}` + ); } const args = parseResult.data; @@ -935,7 +933,7 @@ export class McpServer { _meta?: Record; }, cb: ToolCallback - ): RegisteredToolEntity { + ): RegisteredTool { const { title, description, inputSchema, outputSchema, annotations, execution, _meta } = config; const registeredTool = this._toolRegistry.register({ @@ -965,7 +963,7 @@ export class McpServer { argsSchema?: Args; }, cb: PromptCallback - ): RegisteredPromptEntity { + ): RegisteredPrompt { const { title, description, argsSchema } = config; const registeredPrompt = this._promptRegistry.register({ @@ -1249,32 +1247,6 @@ export type ToolCallback = ToolCallback | ToolTaskHandler; -export type RegisteredTool = { - title?: string; - description?: string; - inputSchema?: AnySchema; - outputSchema?: AnySchema; - annotations?: ToolAnnotations; - execution?: ToolExecution; - _meta?: Record; - handler: AnyToolHandler; - enabled: boolean; - enable(): void; - disable(): void; - update(updates: { - name?: string | null; - title?: string; - description?: string; - paramsSchema?: InputArgs; - outputSchema?: OutputArgs; - annotations?: ToolAnnotations; - _meta?: Record; - callback?: ToolCallback; - enabled?: boolean; - }): void; - remove(): void; -}; - /** * Checks if a value looks like a Zod schema by checking for parse/safeParse methods. */ @@ -1357,86 +1329,6 @@ export type ListResourcesCallback = ( ctx: ServerContextInterface ) => ListResourcesResult | Promise; -/** - * Callback to read a resource at a given URI. - */ -export type ReadResourceCallback = ( - uri: URL, - ctx: ServerContextInterface -) => ReadResourceResult | Promise; - -export type RegisteredResource = { - name: string; - title?: string; - metadata?: ResourceMetadata; - readCallback: ReadResourceCallback; - enabled: boolean; - enable(): void; - disable(): void; - update(updates: { - name?: string; - title?: string; - uri?: string | null; - metadata?: ResourceMetadata; - callback?: ReadResourceCallback; - enabled?: boolean; - }): void; - remove(): void; -}; - -/** - * Callback to read a resource at a given URI, following a filled-in URI template. - */ -export type ReadResourceTemplateCallback = ( - uri: URL, - variables: Variables, - ctx: ServerContextInterface -) => ReadResourceResult | Promise; - -export type RegisteredResourceTemplate = { - resourceTemplate: ResourceTemplate; - title?: string; - metadata?: ResourceMetadata; - readCallback: ReadResourceTemplateCallback; - enabled: boolean; - enable(): void; - disable(): void; - update(updates: { - name?: string | null; - title?: string; - template?: ResourceTemplate; - metadata?: ResourceMetadata; - callback?: ReadResourceTemplateCallback; - enabled?: boolean; - }): void; - remove(): void; -}; - -type PromptArgsRawShape = ZodRawShapeCompat; - -export type PromptCallback = Args extends PromptArgsRawShape - ? (args: ShapeOutput, ctx: ServerContextInterface) => GetPromptResult | Promise - : (ctx: ServerContextInterface) => GetPromptResult | Promise; - -export type RegisteredPrompt = { - title?: string; - description?: string; - argsSchema?: AnyObjectSchema; - callback: PromptCallback; - enabled: boolean; - enable(): void; - disable(): void; - update(updates: { - name?: string | null; - title?: string; - description?: string; - argsSchema?: Args; - callback?: PromptCallback; - enabled?: boolean; - }): void; - remove(): void; -}; - function getMethodValue(schema: AnyObjectSchema): string { const shape = getObjectShape(schema); const methodSchema = shape?.method as AnySchema | undefined; diff --git a/packages/server/src/server/middleware.ts b/packages/server/src/server/middleware.ts index 529ff39bb..04324a31c 100644 --- a/packages/server/src/server/middleware.ts +++ b/packages/server/src/server/middleware.ts @@ -450,4 +450,4 @@ export interface RateLimitMiddlewareOptions { windowMs?: number; /** Error message when rate limited */ message?: string; -} \ No newline at end of file +} diff --git a/packages/server/src/server/registries/promptRegistry.ts b/packages/server/src/server/registries/promptRegistry.ts index 06f7d6b11..8a9aac658 100644 --- a/packages/server/src/server/registries/promptRegistry.ts +++ b/packages/server/src/server/registries/promptRegistry.ts @@ -8,8 +8,7 @@ import type { AnyObjectSchema, AnySchema, Prompt, PromptArgument, ZodRawShapeCompat } from '@modelcontextprotocol/core'; import { getObjectShape, getSchemaDescription, isSchemaOptional, objectFromShape } from '@modelcontextprotocol/core'; -import type { PromptCallback } from '../mcp.js'; -import type { RegisteredDefinition } from './baseRegistry.js'; +import type { PromptCallback, RegisteredPromptInterface } from '../../types/types.js'; import { BaseRegistry } from './baseRegistry.js'; /** @@ -39,62 +38,62 @@ export interface PromptUpdates { * Class-based representation of a registered prompt. * Provides methods for managing the prompt's lifecycle. */ -export class RegisteredPromptEntity implements RegisteredDefinition { - private _name: string; - private _enabled: boolean = true; - private readonly _registry: PromptRegistry; +export class RegisteredPrompt implements RegisteredPromptInterface { + #name: string; + #enabled: boolean = true; + readonly #registry: PromptRegistry; - private _title?: string; - private _description?: string; - private _argsSchema?: AnyObjectSchema; - private _callback: PromptCallback; + #title?: string; + #description?: string; + #argsSchema?: AnyObjectSchema; + #callback: PromptCallback; constructor(config: PromptConfig, registry: PromptRegistry) { - this._name = config.name; - this._registry = registry; - this._title = config.title; - this._description = config.description; - this._argsSchema = config.argsSchema ? objectFromShape(config.argsSchema) : undefined; - this._callback = config.callback; + this.#name = config.name; + this.#registry = registry; + this.#title = config.title; + this.#description = config.description; + this.#argsSchema = config.argsSchema ? objectFromShape(config.argsSchema) : undefined; + this.#callback = config.callback; } /** The prompt's name (identifier) */ get name(): string { - return this._name; + return this.#name; } /** Whether the prompt is currently enabled */ get enabled(): boolean { - return this._enabled; + return this.#enabled; } /** The prompt's title */ get title(): string | undefined { - return this._title; + return this.#title; } /** The prompt's description */ get description(): string | undefined { - return this._description; + return this.#description; } /** The prompt's args schema */ get argsSchema(): AnyObjectSchema | undefined { - return this._argsSchema; + return this.#argsSchema; } /** The prompt's callback */ get callback(): PromptCallback { - return this._callback; + return this.#callback; } /** * Enables the prompt */ enable(): this { - if (!this._enabled) { - this._enabled = true; - this._registry['notifyChanged'](); + if (!this.#enabled) { + this.#enabled = true; + this.#registry['notifyChanged'](); } return this; } @@ -103,9 +102,9 @@ export class RegisteredPromptEntity implements RegisteredDefinition { * Disables the prompt */ disable(): this { - if (this._enabled) { - this._enabled = false; - this._registry['notifyChanged'](); + if (this.#enabled) { + this.#enabled = false; + this.#registry['notifyChanged'](); } return this; } @@ -114,7 +113,7 @@ export class RegisteredPromptEntity implements RegisteredDefinition { * Removes the prompt from its registry */ remove(): void { - this._registry.remove(this._name); + this.#registry.remove(this.#name); } /** @@ -123,8 +122,8 @@ export class RegisteredPromptEntity implements RegisteredDefinition { * @param newName - The new name for the prompt */ rename(newName: string): this { - this._registry['_rename'](this._name, newName); - this._name = newName; + this.#registry['_rename'](this.#name, newName); + this.#name = newName; return this; } @@ -141,12 +140,12 @@ export class RegisteredPromptEntity implements RegisteredDefinition { } this.rename(updates.name); } - if (updates.title !== undefined) this._title = updates.title; - if (updates.description !== undefined) this._description = updates.description; - if (updates.argsSchema !== undefined) this._argsSchema = objectFromShape(updates.argsSchema); - if (updates.callback !== undefined) this._callback = updates.callback; + if (updates.title !== undefined) this.#title = updates.title; + if (updates.description !== undefined) this.#description = updates.description; + if (updates.argsSchema !== undefined) this.#argsSchema = objectFromShape(updates.argsSchema); + if (updates.callback !== undefined) this.#callback = updates.callback; if (updates.enabled === undefined) { - this._registry['notifyChanged'](); + this.#registry['notifyChanged'](); } else { if (updates.enabled) { this.enable(); @@ -161,10 +160,10 @@ export class RegisteredPromptEntity implements RegisteredDefinition { */ toProtocolPrompt(): Prompt { return { - name: this._name, - title: this._title, - description: this._description, - arguments: this._argsSchema ? promptArgumentsFromSchema(this._argsSchema) : undefined + name: this.#name, + title: this.#title, + description: this.#description, + arguments: this.#argsSchema ? promptArgumentsFromSchema(this.#argsSchema) : undefined }; } } @@ -172,7 +171,7 @@ export class RegisteredPromptEntity implements RegisteredDefinition { /** * Registry for managing prompts. */ -export class PromptRegistry extends BaseRegistry { +export class PromptRegistry extends BaseRegistry { /** * Creates a new PromptRegistry. * @@ -193,12 +192,12 @@ export class PromptRegistry extends BaseRegistry { * @returns The registered prompt * @throws If a prompt with the same name already exists */ - register(config: PromptConfig): RegisteredPromptEntity { + register(config: PromptConfig): RegisteredPrompt { if (this._items.has(config.name)) { throw new Error(`Prompt '${config.name}' is already registered`); } - const prompt = new RegisteredPromptEntity(config, this); + const prompt = new RegisteredPrompt(config, this); this._set(config.name, prompt); this.notifyChanged(); return prompt; @@ -219,7 +218,7 @@ export class PromptRegistry extends BaseRegistry { * @param name - The prompt name * @returns The registered prompt or undefined */ - getPrompt(name: string): RegisteredPromptEntity | undefined { + getPrompt(name: string): RegisteredPrompt | undefined { return this.get(name); } } diff --git a/packages/server/src/server/registries/resourceRegistry.ts b/packages/server/src/server/registries/resourceRegistry.ts index 0468fda92..75b07b535 100644 --- a/packages/server/src/server/registries/resourceRegistry.ts +++ b/packages/server/src/server/registries/resourceRegistry.ts @@ -7,7 +7,8 @@ import type { Resource, ResourceTemplateType as ResourceTemplateProtocol, Variables } from '@modelcontextprotocol/core'; -import type { ReadResourceCallback, ReadResourceTemplateCallback, ResourceMetadata, ResourceTemplate } from '../mcp.js'; +import type { ReadResourceCallback, ReadResourceTemplateCallback } from '../../types/types.js'; +import type { ResourceMetadata, ResourceTemplate } from '../mcp.js'; import type { RegisteredDefinition } from './baseRegistry.js'; import { BaseRegistry } from './baseRegistry.js'; diff --git a/packages/server/src/server/registries/toolRegistry.ts b/packages/server/src/server/registries/toolRegistry.ts index 360e26638..995893a34 100644 --- a/packages/server/src/server/registries/toolRegistry.ts +++ b/packages/server/src/server/registries/toolRegistry.ts @@ -8,15 +8,10 @@ import type { AnySchema, Tool, ToolAnnotations, ToolExecution, ZodRawShapeCompat } from '@modelcontextprotocol/core'; import { normalizeObjectSchema, toJsonSchemaCompat, validateAndWarnToolName } from '@modelcontextprotocol/core'; +import type { RegisteredToolInterface } from '../../types/types.js'; import type { AnyToolHandler } from '../mcp.js'; -import type { RegisteredDefinition } from './baseRegistry.js'; import { BaseRegistry } from './baseRegistry.js'; -/** - * Tool handler type - compatible with both ToolCallback and ToolTaskHandler - */ -export type ToolHandler = AnyToolHandler; - /** * Configuration for registering a tool */ @@ -29,7 +24,7 @@ export interface ToolConfig { annotations?: ToolAnnotations; execution?: ToolExecution; _meta?: Record; - handler: ToolHandler; + handler: AnyToolHandler; } /** @@ -44,7 +39,7 @@ export interface ToolUpdates { annotations?: ToolAnnotations; execution?: ToolExecution; _meta?: Record; - handler?: ToolHandler; + handler?: AnyToolHandler; enabled?: boolean; } @@ -57,90 +52,90 @@ const EMPTY_OBJECT_JSON_SCHEMA = { * Class-based representation of a registered tool. * Provides methods for managing the tool's lifecycle. */ -export class RegisteredToolEntity implements RegisteredDefinition { - private _name: string; - private _enabled: boolean = true; - private readonly _registry: ToolRegistry; - - private _title?: string; - private _description?: string; - private _inputSchema?: AnySchema; - private _outputSchema?: AnySchema; - private _annotations?: ToolAnnotations; - private _execution?: ToolExecution; - private __meta?: Record; - private _handler: ToolHandler; +export class RegisteredTool implements RegisteredToolInterface { + #name: string; + #enabled: boolean = true; + readonly #registry: ToolRegistry; + + #title?: string; + #description?: string; + #inputSchema?: AnySchema; + #outputSchema?: AnySchema; + #annotations?: ToolAnnotations; + #execution?: ToolExecution; + #__meta?: Record; + #handler: AnyToolHandler; constructor(config: ToolConfig, registry: ToolRegistry) { - this._name = config.name; - this._registry = registry; - this._title = config.title; - this._description = config.description; - this._inputSchema = config.inputSchema; - this._outputSchema = config.outputSchema; - this._annotations = config.annotations; - this._execution = config.execution; - this.__meta = config._meta; - this._handler = config.handler; + this.#name = config.name; + this.#registry = registry; + this.#title = config.title; + this.#description = config.description; + this.#inputSchema = config.inputSchema; + this.#outputSchema = config.outputSchema; + this.#annotations = config.annotations; + this.#execution = config.execution; + this.#__meta = config._meta; + this.#handler = config.handler; } /** The tool's name (identifier) */ get name(): string { - return this._name; + return this.#name; } /** Whether the tool is currently enabled */ get enabled(): boolean { - return this._enabled; + return this.#enabled; } /** The tool's title */ get title(): string | undefined { - return this._title; + return this.#title; } /** The tool's description */ get description(): string | undefined { - return this._description; + return this.#description; } /** The tool's input schema */ get inputSchema(): AnySchema | undefined { - return this._inputSchema; + return this.#inputSchema; } /** The tool's output schema */ get outputSchema(): AnySchema | undefined { - return this._outputSchema; + return this.#outputSchema; } /** The tool's annotations */ get annotations(): ToolAnnotations | undefined { - return this._annotations; + return this.#annotations; } /** The tool's execution settings */ get execution(): ToolExecution | undefined { - return this._execution; + return this.#execution; } /** The tool's metadata */ get _meta(): Record | undefined { - return this.__meta; + return this.#__meta; } /** The tool's handler function */ - get handler(): ToolHandler { - return this._handler; + get handler(): AnyToolHandler { + return this.#handler; } /** * Enables the tool */ enable(): this { - if (!this._enabled) { - this._enabled = true; - this._registry['notifyChanged'](); + if (!this.#enabled) { + this.#enabled = true; + this.#registry['notifyChanged'](); } return this; } @@ -149,9 +144,9 @@ export class RegisteredToolEntity implements RegisteredDefinition { * Disables the tool */ disable(): this { - if (this._enabled) { - this._enabled = false; - this._registry['notifyChanged'](); + if (this.#enabled) { + this.#enabled = false; + this.#registry['notifyChanged'](); } return this; } @@ -160,7 +155,7 @@ export class RegisteredToolEntity implements RegisteredDefinition { * Removes the tool from its registry */ remove(): void { - this._registry.remove(this._name); + this.#registry.remove(this.#name); } /** @@ -170,8 +165,8 @@ export class RegisteredToolEntity implements RegisteredDefinition { */ rename(newName: string): this { validateAndWarnToolName(newName); - this._registry['_rename'](this._name, newName); - this._name = newName; + this.#registry['_rename'](this.#name, newName); + this.#name = newName; return this; } @@ -180,7 +175,18 @@ export class RegisteredToolEntity implements RegisteredDefinition { * * @param updates - The updates to apply */ - update(updates: ToolUpdates): void { + update(updates: { + name?: string | null; + title?: string; + description?: string; + inputSchema?: InputArgs; + outputSchema?: OutputArgs; + annotations?: ToolAnnotations; + _meta?: Record; + handler?: AnyToolHandler; + execution?: ToolExecution; + enabled?: boolean; + }): void { if (updates.name !== undefined) { if (updates.name === null) { this.remove(); @@ -188,16 +194,16 @@ export class RegisteredToolEntity implements RegisteredDefinition { } this.rename(updates.name); } - if (updates.title !== undefined) this._title = updates.title; - if (updates.description !== undefined) this._description = updates.description; - if (updates.inputSchema !== undefined) this._inputSchema = updates.inputSchema; - if (updates.outputSchema !== undefined) this._outputSchema = updates.outputSchema; - if (updates.annotations !== undefined) this._annotations = updates.annotations; - if (updates.execution !== undefined) this._execution = updates.execution; - if (updates._meta !== undefined) this.__meta = updates._meta; - if (updates.handler !== undefined) this._handler = updates.handler; + if (updates.title !== undefined) this.#title = updates.title; + if (updates.description !== undefined) this.#description = updates.description; + if (updates.inputSchema !== undefined) this.#inputSchema = updates.inputSchema; + if (updates.outputSchema !== undefined) this.#outputSchema = updates.outputSchema; + if (updates.annotations !== undefined) this.#annotations = updates.annotations; + if (updates.execution !== undefined) this.#execution = updates.execution; + if (updates._meta !== undefined) this.#__meta = updates._meta; + if (updates.handler !== undefined) this.#handler = updates.handler; if (updates.enabled === undefined) { - this._registry['notifyChanged'](); + this.#registry['notifyChanged'](); } else { if (updates.enabled) { this.enable(); @@ -212,22 +218,22 @@ export class RegisteredToolEntity implements RegisteredDefinition { */ toProtocolTool(): Tool { const tool: Tool = { - name: this._name, - title: this._title, - description: this._description, - inputSchema: this._inputSchema - ? (toJsonSchemaCompat(normalizeObjectSchema(this._inputSchema) ?? this._inputSchema, { + name: this.#name, + title: this.#title, + description: this.#description, + inputSchema: this.#inputSchema + ? (toJsonSchemaCompat(normalizeObjectSchema(this.#inputSchema) ?? this.#inputSchema, { strictUnions: true, pipeStrategy: 'input' }) as Tool['inputSchema']) : EMPTY_OBJECT_JSON_SCHEMA, - annotations: this._annotations, - execution: this._execution, - _meta: this.__meta + annotations: this.#annotations, + execution: this.#execution, + _meta: this.#__meta }; - if (this._outputSchema) { - const obj = normalizeObjectSchema(this._outputSchema); + if (this.#outputSchema) { + const obj = normalizeObjectSchema(this.#outputSchema); if (obj) { tool.outputSchema = toJsonSchemaCompat(obj, { strictUnions: true, @@ -243,7 +249,7 @@ export class RegisteredToolEntity implements RegisteredDefinition { /** * Registry for managing tools. */ -export class ToolRegistry extends BaseRegistry { +export class ToolRegistry extends BaseRegistry { /** * Creates a new ToolRegistry. * @@ -264,13 +270,13 @@ export class ToolRegistry extends BaseRegistry { * @returns The registered tool * @throws If a tool with the same name already exists */ - register(config: ToolConfig): RegisteredToolEntity { + register(config: ToolConfig): RegisteredTool { if (this._items.has(config.name)) { throw new Error(`Tool '${config.name}' is already registered`); } validateAndWarnToolName(config.name); - const tool = new RegisteredToolEntity(config, this); + const tool = new RegisteredTool(config, this); this._set(config.name, tool); this.notifyChanged(); return tool; @@ -291,7 +297,7 @@ export class ToolRegistry extends BaseRegistry { * @param name - The tool name * @returns The registered tool or undefined */ - getTool(name: string): RegisteredToolEntity | undefined { + getTool(name: string): RegisteredTool | undefined { return this.get(name); } } diff --git a/packages/server/src/server/sessions.ts b/packages/server/src/server/sessions.ts deleted file mode 100644 index 09e835eaa..000000000 --- a/packages/server/src/server/sessions.ts +++ /dev/null @@ -1,344 +0,0 @@ -/** - * Session Management Abstraction - * - * Provides a SessionStore interface and implementations for managing - * server session state. This replaces the manual session map management - * patterns seen across examples. - */ - -/** - * Session lifecycle event callbacks - */ -export interface SessionStoreEvents { - /** - * Called when a new session is created - */ - onSessionCreated?: (sessionId: string, data: T) => void; - - /** - * Called when a session is destroyed - */ - onSessionDestroyed?: (sessionId: string) => void; - - /** - * Called when session data is updated - */ - onSessionUpdated?: (sessionId: string, data: T) => void; -} - -/** - * Interface for session storage implementations. - * - * @template T - The type of session data - */ -export interface SessionStore { - /** - * Gets the session data for a given session ID. - * - * @param sessionId - The session identifier - * @returns The session data or undefined if not found - */ - get(sessionId: string): T | undefined; - - /** - * Sets the session data for a given session ID. - * Creates a new session if it doesn't exist. - * - * @param sessionId - The session identifier - * @param data - The session data to store - */ - set(sessionId: string, data: T): void; - - /** - * Deletes a session. - * - * @param sessionId - The session identifier - * @returns true if the session was deleted, false if it didn't exist - */ - delete(sessionId: string): boolean; - - /** - * Checks if a session exists. - * - * @param sessionId - The session identifier - * @returns true if the session exists - */ - has(sessionId: string): boolean; - - /** - * Gets the number of active sessions. - */ - size(): number; - - /** - * Gets all session IDs. - */ - keys(): string[]; - - /** - * Clears all sessions. - */ - clear(): void; -} - -/** - * Options for InMemorySessionStore - */ -export interface InMemorySessionStoreOptions { - /** - * Maximum number of sessions to store. - * When exceeded, the oldest session will be evicted. - * Default: unlimited - */ - maxSessions?: number; - - /** - * Session timeout in milliseconds. - * Sessions older than this will be automatically cleaned up. - * Default: no timeout - */ - sessionTimeout?: number; - - /** - * Interval for checking expired sessions in milliseconds. - * Default: 60000 (1 minute) - */ - cleanupInterval?: number; - - /** - * Event callbacks - */ - events?: SessionStoreEvents; -} - -/** - * Internal session entry with metadata - */ -interface SessionEntry { - data: T; - createdAt: number; - lastAccessedAt: number; -} - -/** - * In-memory implementation of SessionStore. - * - * Features: - * - Optional maximum session limit with LRU eviction - * - Optional session timeout with automatic cleanup - * - Lifecycle event callbacks - * - * @template T - The type of session data - */ -export class InMemorySessionStore implements SessionStore { - private _sessions = new Map>(); - private _options: InMemorySessionStoreOptions; - private _cleanupTimer?: ReturnType; - - constructor(options: InMemorySessionStoreOptions = {}) { - this._options = options; - - // Set up automatic cleanup if timeout is configured - if (options.sessionTimeout && options.sessionTimeout > 0) { - const interval = options.cleanupInterval ?? 60_000; - this._cleanupTimer = setInterval(() => { - this._cleanupExpiredSessions(); - }, interval); - - // Prevent timer from keeping process alive - if (typeof this._cleanupTimer.unref === 'function') { - this._cleanupTimer.unref(); - } - } - } - - /** - * Gets the session data for a given session ID. - * Updates the last accessed time on access. - */ - get(sessionId: string): T | undefined { - const entry = this._sessions.get(sessionId); - if (!entry) { - return undefined; - } - - // Check if expired - if (this._isExpired(entry)) { - this.delete(sessionId); - return undefined; - } - - // Update last accessed time - entry.lastAccessedAt = Date.now(); - return entry.data; - } - - /** - * Sets the session data for a given session ID. - * Creates a new session if it doesn't exist. - */ - set(sessionId: string, data: T): void { - const existing = this._sessions.get(sessionId); - const now = Date.now(); - - if (existing) { - // Update existing session - existing.data = data; - existing.lastAccessedAt = now; - this._options.events?.onSessionUpdated?.(sessionId, data); - } else { - // Create new session - // Check max sessions limit - if (this._options.maxSessions && this._sessions.size >= this._options.maxSessions) { - this._evictOldestSession(); - } - - this._sessions.set(sessionId, { - data, - createdAt: now, - lastAccessedAt: now - }); - this._options.events?.onSessionCreated?.(sessionId, data); - } - } - - /** - * Deletes a session. - */ - delete(sessionId: string): boolean { - const deleted = this._sessions.delete(sessionId); - if (deleted) { - this._options.events?.onSessionDestroyed?.(sessionId); - } - return deleted; - } - - /** - * Checks if a session exists. - */ - has(sessionId: string): boolean { - const entry = this._sessions.get(sessionId); - if (!entry) { - return false; - } - - // Check if expired - if (this._isExpired(entry)) { - this.delete(sessionId); - return false; - } - - return true; - } - - /** - * Gets the number of active sessions. - */ - size(): number { - return this._sessions.size; - } - - /** - * Gets all session IDs. - */ - keys(): string[] { - return [...this._sessions.keys()]; - } - - /** - * Clears all sessions. - */ - clear(): void { - const sessionIds = this.keys(); - for (const sessionId of sessionIds) { - this.delete(sessionId); - } - } - - /** - * Stops the cleanup timer. - * Call this when the store is no longer needed. - */ - dispose(): void { - if (this._cleanupTimer) { - clearInterval(this._cleanupTimer); - this._cleanupTimer = undefined; - } - } - - /** - * Checks if a session entry is expired. - */ - private _isExpired(entry: SessionEntry): boolean { - if (!this._options.sessionTimeout) { - return false; - } - const age = Date.now() - entry.lastAccessedAt; - return age > this._options.sessionTimeout; - } - - /** - * Evicts the oldest session (by last access time). - */ - private _evictOldestSession(): void { - let oldestId: string | undefined; - let oldestTime = Infinity; - - for (const [id, entry] of this._sessions) { - if (entry.lastAccessedAt < oldestTime) { - oldestTime = entry.lastAccessedAt; - oldestId = id; - } - } - - if (oldestId) { - this.delete(oldestId); - } - } - - /** - * Cleans up expired sessions. - */ - private _cleanupExpiredSessions(): void { - for (const [sessionId, entry] of this._sessions) { - if (this._isExpired(entry)) { - this.delete(sessionId); - } - } - } -} - -/** - * Creates a new in-memory session store. - * - * @example - * ```typescript - * const sessionStore = createSessionStore<{ userId: string }>({ - * sessionTimeout: 30 * 60 * 1000, // 30 minutes - * maxSessions: 1000, - * events: { - * onSessionCreated: (id) => console.log(`Session created: ${id}`), - * onSessionDestroyed: (id) => console.log(`Session destroyed: ${id}`), - * } - * }); - * ``` - */ -export function createSessionStore(options?: InMemorySessionStoreOptions): InMemorySessionStore { - return new InMemorySessionStore(options); -} - -/** - * Session ID generator using crypto.randomUUID. - * Falls back to Math.random if crypto is not available. - */ -export function generateSessionId(): string { - if (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') { - return crypto.randomUUID(); - } - // Fallback for environments without crypto.randomUUID - return 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replaceAll(/[xy]/g, c => { - const r = Math.trunc(Math.random() * 16); - const v = c === 'x' ? r : (r & 0x3) | 0x8; - return v.toString(16); - }); -} diff --git a/packages/server/src/types/types.ts b/packages/server/src/types/types.ts new file mode 100644 index 000000000..10ff99e26 --- /dev/null +++ b/packages/server/src/types/types.ts @@ -0,0 +1,155 @@ +import type { + AnyObjectSchema, + AnySchema, + GetPromptResult, + ReadResourceResult, + ServerNotification, + ServerRequest, + ShapeOutput, + ToolAnnotations, + ToolExecution, + Variables, + ZodRawShapeCompat +} from '@modelcontextprotocol/core'; + +import type { ServerContextInterface } from '../server/context.js'; +import type { AnyToolHandler, ResourceMetadata, ResourceTemplate, ToolCallback } from '../server/mcp.js'; + +/** + * Base interface for all registered definitions + */ +export interface RegisteredDefinition { + /** + * Whether the definition is currently enabled + */ + enabled: boolean; + + /** + * Enable the definition + */ + enable(): void; + + /** + * Disable the definition + */ + disable(): void; + + /** + * Remove the definition from its registry + */ + remove(): void; + + /** + * Update the definition + */ + update(updates: unknown): void; +} + +export interface RegisteredToolInterface extends RegisteredDefinition { + title?: string; + description?: string; + inputSchema?: AnySchema; + outputSchema?: AnySchema; + annotations?: ToolAnnotations; + execution?: ToolExecution; + _meta?: Record; + handler: AnyToolHandler; + enabled: boolean; + enable(): void; + disable(): void; + update(updates: { + name?: string | null; + title?: string; + description?: string; + inputSchema?: InputArgs; + outputSchema?: OutputArgs; + annotations?: ToolAnnotations; + _meta?: Record; + callback?: ToolCallback; + enabled?: boolean; + }): void; + remove(): void; +} + +/** + * Callback to read a resource at a given URI. + */ +export type ReadResourceCallback = ( + uri: URL, + ctx: ServerContextInterface +) => ReadResourceResult | Promise; + +export interface RegisteredResourceInterface extends RegisteredDefinition { + name: string; + title?: string; + metadata?: ResourceMetadata; + readCallback: ReadResourceCallback; + enabled: boolean; + enable(): void; + disable(): void; + update(updates: { + name?: string; + title?: string; + uri?: string | null; + metadata?: ResourceMetadata; + callback?: ReadResourceCallback; + enabled?: boolean; + }): void; + remove(): void; +} + +/** + * Callback to read a resource at a given URI, following a filled-in URI template. + */ +export type ReadResourceTemplateCallback = ( + uri: URL, + variables: Variables, + ctx: ServerContextInterface +) => ReadResourceResult | Promise; + +export interface RegisteredResourceTemplateInterface extends RegisteredDefinition { + resourceTemplate: ResourceTemplate; + title?: string; + metadata?: ResourceMetadata; + readCallback: ReadResourceTemplateCallback; + enabled: boolean; + enable(): void; + disable(): void; + update(updates: { + name?: string | null; + title?: string; + template?: ResourceTemplate; + metadata?: ResourceMetadata; + callback?: ReadResourceTemplateCallback; + enabled?: boolean; + }): void; + remove(): void; +} + +export type PromptArgsRawShape = ZodRawShapeCompat; + +export type PromptCallback = Args extends PromptArgsRawShape + ? ( + args: ShapeOutput, + ctx: ServerContextInterface + ) => GetPromptResult | Promise + : (ctx: ServerContextInterface) => GetPromptResult | Promise; + +export interface RegisteredPromptInterface extends RegisteredDefinition { + title?: string; + description?: string; + argsSchema?: AnyObjectSchema; + callback: PromptCallback; + enabled: boolean; + enable(): void; + disable(): void; + update(updates: { + name?: string | null; + title?: string; + description?: string; + argsSchema?: Args; + callback?: PromptCallback; + enabled?: boolean; + }): void; + remove(): void; +} From a0a89bb37e05a0362f0d4d353f7049c1caf3adff Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Mon, 26 Jan 2026 23:06:57 +0200 Subject: [PATCH 17/17] save commit --- examples/client/src/elicitationUrlExample.ts | 4 +- examples/client/src/simpleStreamableHttp.ts | 4 +- .../client/src/simpleTaskInteractiveClient.ts | 4 +- examples/server/src/simpleStreamableHttp.ts | 2 +- packages/client/src/client/client.ts | 188 ++++++++++++------ packages/client/src/client/sse.ts | 6 +- packages/client/src/client/stdio.ts | 8 +- packages/client/src/client/streamableHttp.ts | 7 +- packages/client/src/client/websocket.ts | 8 +- .../client/src/experimental/tasks/client.ts | 14 +- packages/core/src/errors.ts | 117 +++++------ packages/core/src/shared/protocol.ts | 61 ++++-- packages/core/src/shared/responseMessage.ts | 5 +- packages/core/src/shared/taskClientPlugin.ts | 10 +- packages/core/src/shared/taskPlugin.ts | 31 +-- packages/core/src/types/types.ts | 45 +---- packages/core/src/util/inMemory.ts | 3 +- packages/core/src/util/zodJsonSchemaCompat.ts | 3 +- packages/core/test/shared/protocol.test.ts | 43 ++-- packages/server/src/server/mcp.ts | 55 +++-- packages/server/src/server/server.ts | 57 +++--- packages/server/src/server/streamableHttp.ts | 3 +- test/integration/test/client/client.test.ts | 6 +- .../experimental/tasks/taskListing.test.ts | 6 +- test/integration/test/server.test.ts | 6 +- test/integration/test/server/mcp.test.ts | 12 +- test/integration/test/taskLifecycle.test.ts | 18 +- 27 files changed, 391 insertions(+), 335 deletions(-) diff --git a/examples/client/src/elicitationUrlExample.ts b/examples/client/src/elicitationUrlExample.ts index 4ac59aa6a..a61ae8535 100644 --- a/examples/client/src/elicitationUrlExample.ts +++ b/examples/client/src/elicitationUrlExample.ts @@ -26,7 +26,7 @@ import { ErrorCode, getDisplayName, ListToolsResultSchema, - McpError, + ProtocolError, StreamableHTTPClientTransport, UnauthorizedError, UrlElicitationRequiredError @@ -339,7 +339,7 @@ async function handleElicitationRequest(request: ElicitRequest): Promise { // Set up elicitation request handler with proper validation client.setRequestHandler(ElicitRequestSchema, async request => { if (request.params.mode !== 'form') { - throw new McpError(ErrorCode.InvalidParams, `Unsupported elicitation mode: ${request.params.mode}`); + throw new ProtocolError(ErrorCode.InvalidParams, `Unsupported elicitation mode: ${request.params.mode}`); } console.log('\n🔔 Elicitation (form) Request Received:'); console.log(`Message: ${request.params.message}`); diff --git a/examples/client/src/simpleTaskInteractiveClient.ts b/examples/client/src/simpleTaskInteractiveClient.ts index 89310ef29..31b645127 100644 --- a/examples/client/src/simpleTaskInteractiveClient.ts +++ b/examples/client/src/simpleTaskInteractiveClient.ts @@ -17,7 +17,7 @@ import { ElicitRequestSchema, ErrorCode, isTextContent, - McpError, + ProtocolError, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; @@ -105,7 +105,7 @@ async function run(url: string): Promise { // Set up elicitation request handler client.setRequestHandler(ElicitRequestSchema, async request => { if (request.params.mode && request.params.mode !== 'form') { - throw new McpError(ErrorCode.InvalidParams, `Unsupported elicitation mode: ${request.params.mode}`); + throw new ProtocolError(ErrorCode.InvalidParams, `Unsupported elicitation mode: ${request.params.mode}`); } return elicitationCallback(request.params); }); diff --git a/examples/server/src/simpleStreamableHttp.ts b/examples/server/src/simpleStreamableHttp.ts index 85ae517d0..4157b878f 100644 --- a/examples/server/src/simpleStreamableHttp.ts +++ b/examples/server/src/simpleStreamableHttp.ts @@ -50,7 +50,7 @@ const getServer = () => { ); // Enable task support via TaskPlugin - server.server.usePlugin( + server.usePlugin( new TaskPlugin({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index e5e4bbb91..6ed1a4cfd 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -9,6 +9,8 @@ import type { CompatibilityCallToolResultSchema, CompleteRequest, ContextInterface, + ErrorInterceptionContext, + ErrorInterceptionResult, GetPromptRequest, Implementation, JSONRPCRequest, @@ -45,6 +47,7 @@ import { assertClientRequestTaskCapability, assertToolsCallTaskCapability, CallToolResultSchema, + CapabilityError, CompleteResultSchema, CreateMessageRequestSchema, CreateMessageResultSchema, @@ -53,10 +56,10 @@ import { ElicitRequestSchema, ElicitResultSchema, EmptyResultSchema, - ErrorCode, getObjectShape, GetPromptResultSchema, InitializeResultSchema, + isProtocolError, isZ4Schema, LATEST_PROTOCOL_VERSION, ListChangedOptionsBaseSchema, @@ -65,19 +68,27 @@ import { ListResourceTemplatesResultSchema, ListRootsRequestSchema, ListToolsResultSchema, - McpError, mergeCapabilities, PromptListChangedNotificationSchema, Protocol, + ProtocolError, ReadResourceResultSchema, ResourceListChangedNotificationSchema, safeParse, + StateError, SUPPORTED_PROTOCOL_VERSIONS, ToolListChangedNotificationSchema } from '@modelcontextprotocol/core'; import { ExperimentalClientTasks } from '../experimental/tasks/client.js'; -import type { ClientBuilderResult, ErrorContext, OnErrorHandler, OnProtocolErrorHandler } from './builder.js'; +import type { + ClientBuilderResult, + ErrorContext, + OnErrorHandler, + OnErrorReturn, + OnProtocolErrorHandler, + OnProtocolErrorReturn +} from './builder.js'; import { ClientBuilder } from './builder.js'; import type { ClientRequestContext } from './context.js'; import { ClientContext } from './context.js'; @@ -275,6 +286,10 @@ export class Client< private _pendingListChangedConfig?: ListChangedHandlers; private readonly _middleware: ClientMiddlewareManager; + // Error handlers (single callback pattern, matching McpServer) + private _onErrorHandler?: OnErrorHandler; + private _onProtocolErrorHandler?: OnProtocolErrorHandler; + /** * Initializes this client with the given name and version information. */ @@ -523,7 +538,7 @@ export class Client< */ public registerCapabilities(capabilities: ClientCapabilities): void { if (this.transport) { - throw new Error('Cannot register capabilities after connecting to transport'); + throw StateError.registrationAfterConnect('capabilities'); } this._capabilities = mergeCapabilities(this._capabilities, capabilities); @@ -571,7 +586,7 @@ export class Client< // Type guard: if success is false, error is guaranteed to exist const errorMessage = validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation request: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid elicitation request: ${errorMessage}`); } const { params } = validatedRequest.data; @@ -579,11 +594,11 @@ export class Client< const { supportsFormMode, supportsUrlMode } = getSupportedElicitationModes(this._capabilities.elicitation); if (params.mode === 'form' && !supportsFormMode) { - throw new McpError(ErrorCode.InvalidParams, 'Client does not support form-mode elicitation requests'); + throw ProtocolError.invalidParams('Client does not support form-mode elicitation requests'); } if (params.mode === 'url' && !supportsUrlMode) { - throw new McpError(ErrorCode.InvalidParams, 'Client does not support URL-mode elicitation requests'); + throw ProtocolError.invalidParams('Client does not support URL-mode elicitation requests'); } const result = await Promise.resolve(handler(request, ctx)); @@ -596,7 +611,7 @@ export class Client< taskValidationResult.error instanceof Error ? taskValidationResult.error.message : String(taskValidationResult.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid task creation result: ${errorMessage}`); } return taskValidationResult.data; } @@ -607,7 +622,7 @@ export class Client< // Type guard: if success is false, error is guaranteed to exist const errorMessage = validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation result: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid elicitation result: ${errorMessage}`); } const validatedResult = validationResult.data; @@ -643,7 +658,7 @@ export class Client< if (!validatedRequest.success) { const errorMessage = validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid sampling request: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid sampling request: ${errorMessage}`); } const { params } = validatedRequest.data; @@ -658,7 +673,7 @@ export class Client< taskValidationResult.error instanceof Error ? taskValidationResult.error.message : String(taskValidationResult.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid task creation result: ${errorMessage}`); } return taskValidationResult.data; } @@ -670,7 +685,7 @@ export class Client< if (!validationResult.success) { const errorMessage = validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid sampling result: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid sampling result: ${errorMessage}`); } return validationResult.data; @@ -713,7 +728,7 @@ export class Client< protected assertCapability(capability: keyof ServerCapabilities, method: string): void { if (!this._serverCapabilities?.[capability]) { - throw new Error(`Server does not support ${capability} (required for ${method})`); + throw CapabilityError.serverDoesNotSupport(capability, method); } } @@ -796,7 +811,7 @@ export class Client< switch (method as ClientRequest['method']) { case 'logging/setLevel': { if (!this._serverCapabilities?.logging) { - throw new Error(`Server does not support logging (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('logging', method); } break; } @@ -804,7 +819,7 @@ export class Client< case 'prompts/get': case 'prompts/list': { if (!this._serverCapabilities?.prompts) { - throw new Error(`Server does not support prompts (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('prompts', method); } break; } @@ -815,11 +830,11 @@ export class Client< case 'resources/subscribe': case 'resources/unsubscribe': { if (!this._serverCapabilities?.resources) { - throw new Error(`Server does not support resources (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('resources', method); } if (method === 'resources/subscribe' && !this._serverCapabilities.resources.subscribe) { - throw new Error(`Server does not support resource subscriptions (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('resources.subscribe', method); } break; @@ -828,14 +843,14 @@ export class Client< case 'tools/call': case 'tools/list': { if (!this._serverCapabilities?.tools) { - throw new Error(`Server does not support tools (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('tools', method); } break; } case 'completion/complete': { if (!this._serverCapabilities?.completions) { - throw new Error(`Server does not support completions (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('completions', method); } break; } @@ -856,7 +871,7 @@ export class Client< switch (method as ClientNotification['method']) { case 'notifications/roots/list_changed': { if (!this._capabilities.roots?.listChanged) { - throw new Error(`Client does not support roots list changed notifications (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('roots.listChanged', method); } break; } @@ -888,21 +903,21 @@ export class Client< switch (method) { case 'sampling/createMessage': { if (!this._capabilities.sampling) { - throw new Error(`Client does not support sampling capability (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('sampling', method); } break; } case 'elicitation/create': { if (!this._capabilities.elicitation) { - throw new Error(`Client does not support elicitation capability (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('elicitation', method); } break; } case 'roots/list': { if (!this._capabilities.roots) { - throw new Error(`Client does not support roots capability (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('roots', method); } break; } @@ -912,7 +927,7 @@ export class Client< case 'tasks/result': case 'tasks/cancel': { if (!this._capabilities.tasks) { - throw new Error(`Client does not support tasks capability (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('tasks', method); } break; } @@ -990,8 +1005,7 @@ export class Client< ) { // Guard: required-task tools need experimental API if (this.isToolTaskRequired(params.name)) { - throw new McpError( - ErrorCode.InvalidRequest, + throw ProtocolError.invalidRequest( `Tool "${params.name}" requires task-based execution. Use client.experimental.tasks.callToolStream() instead.` ); } @@ -1003,10 +1017,7 @@ export class Client< if (validator) { // If tool has outputSchema, it MUST return structuredContent (unless it's an error) if (!result.structuredContent && !result.isError) { - throw new McpError( - ErrorCode.InvalidRequest, - `Tool ${params.name} has an output schema but did not return structured content` - ); + throw ProtocolError.invalidRequest(`Tool ${params.name} has an output schema but did not return structured content`); } // Only validate structured content if present (not when there's an error) @@ -1016,17 +1027,15 @@ export class Client< const validationResult = validator(result.structuredContent); if (!validationResult.valid) { - throw new McpError( - ErrorCode.InvalidParams, + throw ProtocolError.invalidParams( `Structured content does not match the tool's output schema: ${validationResult.errorMessage}` ); } } catch (error) { - if (error instanceof McpError) { + if (isProtocolError(error)) { throw error; } - throw new McpError( - ErrorCode.InvalidParams, + throw ProtocolError.invalidParams( `Failed to validate structured content: ${error instanceof Error ? error.message : String(error)}` ); } @@ -1184,18 +1193,85 @@ export class Client< } /** - * Registers an error handler for application errors. + * Updates the error interceptor based on current handlers. + * This combines both onError and onProtocolError handlers into a single interceptor. + */ + private _updateErrorInterceptor(): void { + if (!this._onErrorHandler && !this._onProtocolErrorHandler) { + // No handlers, clear the interceptor + this.setErrorInterceptor(undefined); + return; + } + + this.setErrorInterceptor(async (error: Error, ctx: ErrorInterceptionContext): Promise => { + const errorContext: ErrorContext = { + type: ctx.type === 'protocol' ? 'protocol' : (ctx.method as ErrorContext['type']) || 'sampling', + method: ctx.method, + requestId: typeof ctx.requestId === 'string' ? ctx.requestId : String(ctx.requestId) + }; + + let result: OnErrorReturn | OnProtocolErrorReturn | void = undefined; + + if (ctx.type === 'protocol' && this._onProtocolErrorHandler) { + // Protocol error - use onProtocolError handler + result = await this._onProtocolErrorHandler(error, errorContext); + } else if (this._onErrorHandler) { + // Application error (or protocol error without specific handler) - use onError handler + result = await this._onErrorHandler(error, errorContext); + } + + if (result === undefined || result === null) { + return undefined; + } + + // Convert the handler result to ErrorInterceptionResult + if (typeof result === 'string') { + return { message: result }; + } else if (result instanceof Error) { + const errorWithCode = result as Error & { code?: number; data?: unknown }; + return { + message: result.message, + code: ctx.type === 'application' ? errorWithCode.code : undefined, + data: errorWithCode.data + }; + } else { + // Object with code/message/data + return { + message: result.message, + code: ctx.type === 'application' ? (result as OnErrorReturn & { code?: number }).code : undefined, + data: result.data + }; + } + }); + } + + private _clearOnErrorHandler(): void { + this._onErrorHandler = undefined; + this._updateErrorInterceptor(); + } + + private _clearOnProtocolErrorHandler(): void { + this._onProtocolErrorHandler = undefined; + this._updateErrorInterceptor(); + } + + /** + * Registers an error handler for application errors in sampling/elicitation/rootsList handlers. * * The handler receives the error and a context object with information about where - * the error occurred. It can optionally return a custom error response. + * the error occurred. It can optionally return a custom error response that will + * modify the error sent to the server. + * + * Note: This is a single-handler pattern. Setting a new handler replaces any previous one. + * The handler is awaited, so async handlers are fully supported. * * @param handler - Error handler function * @returns Unsubscribe function * * @example * ```typescript - * const unsubscribe = client.onError((error, ctx) => { - * console.error(`Error in ${ctx.type}: ${error.message}`); + * const unsubscribe = client.onError(async (error, ctx) => { + * console.error(`Error in ${ctx.type}/${ctx.method}: ${error.message}`); * // Optionally return a custom error response * return { * code: -32000, @@ -1206,43 +1282,35 @@ export class Client< * ``` */ onError(handler: OnErrorHandler): () => void { - return this.events.on('error', ({ error, context }) => { - const errorContext: ErrorContext = { - type: (context as 'sampling' | 'elicitation' | 'rootsList' | 'protocol') || 'protocol', - method: context || 'unknown', - requestId: 'unknown' - }; - handler(error, errorContext); - }); + this._onErrorHandler = handler; + this._updateErrorInterceptor(); + return this._clearOnErrorHandler.bind(this); } /** - * Registers an error handler for protocol errors. + * Registers an error handler for protocol errors (method not found, parse error, etc.). * * The handler receives the error and a context object. It can optionally return - * a custom error response (but cannot change the error code). + * a custom error response. Note that the error code cannot be changed for protocol + * errors as they have fixed codes per the MCP specification. + * + * Note: This is a single-handler pattern. Setting a new handler replaces any previous one. + * The handler is awaited, so async handlers are fully supported. * * @param handler - Error handler function * @returns Unsubscribe function * * @example * ```typescript - * const unsubscribe = client.onProtocolError((error, ctx) => { + * const unsubscribe = client.onProtocolError(async (error, ctx) => { * console.error(`Protocol error in ${ctx.method}: ${error.message}`); * return { message: `Protocol error: ${error.message}` }; * }); * ``` */ onProtocolError(handler: OnProtocolErrorHandler): () => void { - return this.events.on('error', ({ error, context }) => { - if (context === 'protocol') { - const errorContext: ErrorContext = { - type: 'protocol', - method: context || 'unknown', - requestId: 'unknown' - }; - handler(error, errorContext); - } - }); + this._onProtocolErrorHandler = handler; + this._updateErrorInterceptor(); + return this._clearOnProtocolErrorHandler.bind(this); } } diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index 2141cb12d..253f84a0e 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -1,5 +1,5 @@ import type { FetchLike, JSONRPCMessage, Transport } from '@modelcontextprotocol/core'; -import { createFetchWithInit, JSONRPCMessageSchema, normalizeHeaders } from '@modelcontextprotocol/core'; +import { createFetchWithInit, JSONRPCMessageSchema, normalizeHeaders, StateError } from '@modelcontextprotocol/core'; import type { ErrorEvent, EventSourceInit } from 'eventsource'; import { EventSource } from 'eventsource'; @@ -211,7 +211,7 @@ export class SSEClientTransport implements Transport { async start() { if (this._eventSource) { - throw new Error('SSEClientTransport already started! If using Client class, note that connect() calls start() automatically.'); + throw StateError.alreadyConnected(); } return await this._startOrAuth(); @@ -245,7 +245,7 @@ export class SSEClientTransport implements Transport { async send(message: JSONRPCMessage): Promise { if (!this._endpoint) { - throw new Error('Not connected'); + throw StateError.notConnected('send message'); } try { diff --git a/packages/client/src/client/stdio.ts b/packages/client/src/client/stdio.ts index 47df59e3b..178d979c4 100644 --- a/packages/client/src/client/stdio.ts +++ b/packages/client/src/client/stdio.ts @@ -4,7 +4,7 @@ import type { Stream } from 'node:stream'; import { PassThrough } from 'node:stream'; import type { JSONRPCMessage, Transport } from '@modelcontextprotocol/core'; -import { ReadBuffer, serializeMessage } from '@modelcontextprotocol/core'; +import { ReadBuffer, serializeMessage, StateError } from '@modelcontextprotocol/core'; import spawn from 'cross-spawn'; export type StdioServerParameters = { @@ -112,9 +112,7 @@ export class StdioClientTransport implements Transport { */ async start(): Promise { if (this._process) { - throw new Error( - 'StdioClientTransport already started! If using Client class, note that connect() calls start() automatically.' - ); + throw StateError.alreadyConnected(); } return new Promise((resolve, reject) => { @@ -246,7 +244,7 @@ export class StdioClientTransport implements Transport { send(message: JSONRPCMessage): Promise { return new Promise(resolve => { if (!this._process?.stdin) { - throw new Error('Not connected'); + throw StateError.notConnected('send message'); } const json = serializeMessage(message); diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index dbee90f31..82645d30e 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -7,7 +7,8 @@ import { isJSONRPCRequest, isJSONRPCResultResponse, JSONRPCMessageSchema, - normalizeHeaders + normalizeHeaders, + StateError } from '@modelcontextprotocol/core'; import { EventSourceParserStream } from 'eventsource-parser/stream'; @@ -422,9 +423,7 @@ export class StreamableHTTPClientTransport implements Transport { async start() { if (this._abortController) { - throw new Error( - 'StreamableHTTPClientTransport already started! If using Client class, note that connect() calls start() automatically.' - ); + throw StateError.alreadyConnected(); } this._abortController = new AbortController(); diff --git a/packages/client/src/client/websocket.ts b/packages/client/src/client/websocket.ts index cb0c34687..6b6eda667 100644 --- a/packages/client/src/client/websocket.ts +++ b/packages/client/src/client/websocket.ts @@ -1,5 +1,5 @@ import type { JSONRPCMessage, Transport } from '@modelcontextprotocol/core'; -import { JSONRPCMessageSchema } from '@modelcontextprotocol/core'; +import { JSONRPCMessageSchema, StateError } from '@modelcontextprotocol/core'; const SUBPROTOCOL = 'mcp'; @@ -20,9 +20,7 @@ export class WebSocketClientTransport implements Transport { start(): Promise { if (this._socket) { - throw new Error( - 'WebSocketClientTransport already started! If using Client class, note that connect() calls start() automatically.' - ); + throw StateError.alreadyConnected(); } return new Promise((resolve, reject) => { @@ -63,7 +61,7 @@ export class WebSocketClientTransport implements Transport { send(message: JSONRPCMessage): Promise { return new Promise((resolve, reject) => { if (!this._socket) { - reject(new Error('Not connected')); + reject(StateError.notConnected('send message')); return; } diff --git a/packages/client/src/experimental/tasks/client.ts b/packages/client/src/experimental/tasks/client.ts index b4f62b6c7..f0a1ec01c 100644 --- a/packages/client/src/experimental/tasks/client.ts +++ b/packages/client/src/experimental/tasks/client.ts @@ -20,7 +20,7 @@ import type { Result, SchemaOutput } from '@modelcontextprotocol/core'; -import { CallToolResultSchema, ErrorCode, McpError, TaskClientPlugin } from '@modelcontextprotocol/core'; +import { CallToolResultSchema, ErrorCode, ProtocolError, TaskClientPlugin } from '@modelcontextprotocol/core'; import type { Client } from '../../client/client.js'; @@ -62,7 +62,7 @@ export class ExperimentalClientTasks< private _getTaskClient(): TaskClientPlugin { const plugin = this._client.getPlugin(TaskClientPlugin); if (!plugin) { - throw new McpError( + throw new ProtocolError( ErrorCode.InternalError, 'TaskClientPlugin not installed. Use client.usePlugin(new TaskClientPlugin()) first.' ); @@ -137,7 +137,7 @@ export class ExperimentalClientTasks< if (!result.structuredContent && !result.isError) { yield { type: 'error', - error: new McpError( + error: new ProtocolError( ErrorCode.InvalidRequest, `Tool ${params.name} has an output schema but did not return structured content` ) @@ -154,7 +154,7 @@ export class ExperimentalClientTasks< if (!validationResult.valid) { yield { type: 'error', - error: new McpError( + error: new ProtocolError( ErrorCode.InvalidParams, `Structured content does not match the tool's output schema: ${validationResult.errorMessage}` ) @@ -162,13 +162,13 @@ export class ExperimentalClientTasks< return; } } catch (error) { - if (error instanceof McpError) { + if (error instanceof ProtocolError) { yield { type: 'error', error }; return; } yield { type: 'error', - error: new McpError( + error: new ProtocolError( ErrorCode.InvalidParams, `Failed to validate structured content: ${error instanceof Error ? error.message : String(error)}` ) @@ -208,7 +208,7 @@ export class ExperimentalClientTasks< */ async getTaskResult(taskId: string, resultSchema?: T, options?: RequestOptions): Promise> { if (!resultSchema) { - throw new McpError(ErrorCode.InvalidParams, 'resultSchema is required'); + throw new ProtocolError(ErrorCode.InvalidParams, 'resultSchema is required'); } return this._getTaskClient().getTaskResult({ taskId }, resultSchema, options); } diff --git a/packages/core/src/errors.ts b/packages/core/src/errors.ts index 581cae058..7b5dd81c5 100644 --- a/packages/core/src/errors.ts +++ b/packages/core/src/errors.ts @@ -3,9 +3,10 @@ * * This module defines a comprehensive error hierarchy for the MCP SDK: * - * 1. Protocol Errors (McpError subclasses) - Errors that cross the wire as JSON-RPC errors - * - ProtocolError: SDK-generated errors with spec-mandated codes (code is locked) - * - ApplicationError: User handler errors wrapped by SDK (code can be customized) + * 1. Protocol Errors - Errors that cross the wire as JSON-RPC errors + * - ProtocolError: Protocol-level errors with locked codes + * - Users can throw ProtocolError for intentional locked-code errors + * - Other errors thrown by users are customizable via onError handler * * 2. SDK Errors (SdkError subclasses) - Local errors that don't cross the wire * - StateError: Wrong SDK state (not connected, already connected, etc.) @@ -16,6 +17,7 @@ * 3. OAuth Errors - Kept in auth/errors.ts (unchanged) */ +import type { ElicitRequestURLParams } from './types/types.js'; import { ErrorCode } from './types/types.js'; // ═══════════════════════════════════════════════════════════════════════════ @@ -48,21 +50,19 @@ export enum SdkErrorCode { } // ═══════════════════════════════════════════════════════════════════════════ -// Protocol Error Subclasses (McpError hierarchy - crosses the wire) +// Protocol Errors (cross the wire as JSON-RPC errors) // ═══════════════════════════════════════════════════════════════════════════ -// Note: McpError is defined in types/types.ts and re-exported from there. -// These subclasses provide more specific error types. - /** - * Protocol-level errors generated by the SDK for protocol violations. + * Protocol-level errors that cross the wire as JSON-RPC errors. * The error code is LOCKED and cannot be changed in onProtocolError handlers. * - * These errors are for spec-mandated situations like: - * - Parse errors (-32700) - * - Invalid request (-32600) - * - Method not found (-32601) - * - Invalid params (-32602) + * Use this when you want a specific error code that should not be customized: + * - SDK uses this for spec-mandated errors (parse error, method not found, etc.) + * - Users can throw this for intentional locked-code errors + * + * For errors where you want the onError handler to customize the response, + * throw a plain Error instead. */ export class ProtocolError extends Error { /** @@ -75,7 +75,7 @@ export class ProtocolError extends Error { message: string, public readonly data?: unknown ) { - super(`MCP protocol error ${code}: ${message}`); + super(`MCP error ${code}: ${message}`); this.name = 'ProtocolError'; } @@ -106,51 +106,45 @@ export class ProtocolError extends Error { static invalidParams(message: string = 'Invalid params', data?: unknown): ProtocolError { return new ProtocolError(ErrorCode.InvalidParams, message, data); } -} -/** - * Application-level errors from user handler code, wrapped by the SDK. - * The error code CAN be customized in onError handlers. - * - * Default code is InternalError (-32603), but can be changed. - */ -export class ApplicationError extends Error { /** - * Indicates this is an application-level error with a customizable code + * Creates an internal error (-32603) */ - readonly isProtocolLevel = false as const; - - constructor( - public code: number = ErrorCode.InternalError, - message: string, - public readonly data?: unknown, - public override readonly cause?: Error - ) { - super(`MCP application error ${code}: ${message}`); - this.name = 'ApplicationError'; - if (cause) { - this.cause = cause; - } + static internalError(message: string = 'Internal error', data?: unknown): ProtocolError { + return new ProtocolError(ErrorCode.InternalError, message, data); } /** - * Wraps any error as an ApplicationError + * Factory method to create the appropriate error type based on the error code and data */ - static wrap(error: unknown, code: number = ErrorCode.InternalError): ApplicationError { - if (error instanceof ApplicationError) { - return error; - } - if (error instanceof Error) { - return new ApplicationError(code, error.message, undefined, error); + static fromError(code: number, message: string, data?: unknown): ProtocolError { + // Check for specific error types + if (code === ErrorCode.UrlElicitationRequired && data) { + const errorData = data as { elicitations?: unknown[] }; + if (errorData.elicitations) { + return new UrlElicitationRequiredError(errorData.elicitations as ElicitRequestURLParams[], message); + } } - return new ApplicationError(code, String(error)); + + // Default to generic ProtocolError + return new ProtocolError(code, message, data); } +} - /** - * Creates an internal error (-32603) - */ - static internalError(message: string, data?: unknown, cause?: Error): ApplicationError { - return new ApplicationError(ErrorCode.InternalError, message, data, cause); +/** + * Specialized error type when a tool requires a URL mode elicitation. + * This makes it nicer for the client to handle since there is specific data to work with. + */ +export class UrlElicitationRequiredError extends ProtocolError { + constructor(elicitations: ElicitRequestURLParams[], message: string = `URL elicitation${elicitations.length > 1 ? 's' : ''} required`) { + super(ErrorCode.UrlElicitationRequired, message, { + elicitations: elicitations + }); + this.name = 'UrlElicitationRequiredError'; + } + + get elicitations(): ElicitRequestURLParams[] { + return (this.data as { elicitations: ElicitRequestURLParams[] })?.elicitations ?? []; } } @@ -300,12 +294,30 @@ export class TransportError extends SdkError { return new TransportError(SdkErrorCode.CONNECTION_TIMEOUT, `Connection timed out after ${timeoutMs}ms`, cause); } + /** + * Creates a request timeout error (request sent but no response received in time) + */ + static requestTimeout( + message: string = 'Request timed out', + details?: { timeout?: number; maxTotalTimeout?: number; totalElapsed?: number } + ): TransportError { + const detailsStr = details ? ` (${JSON.stringify(details)})` : ''; + return new TransportError(SdkErrorCode.CONNECTION_TIMEOUT, `${message}${detailsStr}`); + } + /** * Creates a send failed error */ static sendFailed(message: string = 'Failed to send message', cause?: Error): TransportError { return new TransportError(SdkErrorCode.SEND_FAILED, message, cause); } + + /** + * Creates a connection closed error + */ + static connectionClosed(message: string = 'Connection closed'): TransportError { + return new TransportError(SdkErrorCode.CONNECTION_LOST, message); + } } /** @@ -357,13 +369,6 @@ export function isProtocolError(error: unknown): error is ProtocolError { return error instanceof ProtocolError; } -/** - * Type guard to check if an error is an ApplicationError - */ -export function isApplicationError(error: unknown): error is ApplicationError { - return error instanceof ApplicationError; -} - /** * Type guard to check if an error is an SdkError */ diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index eff48d57e..5401c9d8f 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -1,4 +1,4 @@ -import { StateError } from '../errors.js'; +import { isProtocolError, ProtocolError, StateError, TransportError } from '../errors.js'; import type { CancelledNotification, ClientCapabilities, @@ -22,7 +22,6 @@ import { isJSONRPCNotification, isJSONRPCRequest, isJSONRPCResultResponse, - McpError, PingRequestSchema, ProgressNotificationSchema } from '../types/types.js'; @@ -111,7 +110,7 @@ export type RequestOptions = { signal?: AbortSignal; /** - * A timeout (in milliseconds) for this request. If exceeded, an McpError with code `RequestTimeout` will be raised from request(). + * A timeout (in milliseconds) for this request. If exceeded, a TransportError will be raised from request(). * * If not specified, `DEFAULT_REQUEST_TIMEOUT_MSEC` will be used as the timeout. */ @@ -126,7 +125,7 @@ export type RequestOptions = { /** * Maximum total time (in milliseconds) to wait for a response. - * If exceeded, an McpError with code `RequestTimeout` will be raised, regardless of progress notifications. + * If exceeded, a TransportError will be raised, regardless of progress notifications. * If not specified, there is no maximum total timeout. */ maxTotalTimeout?: number; @@ -677,7 +676,7 @@ export abstract class Protocol= info.maxTotalTimeout) { this._timeoutManager.cleanup(messageId); - throw McpError.fromError(ErrorCode.RequestTimeout, 'Maximum total timeout exceeded', { + throw TransportError.requestTimeout('Maximum total timeout exceeded', { maxTotalTimeout: info.maxTotalTimeout, totalElapsed }); @@ -740,7 +739,7 @@ export abstract class Protocol this._onerror(new Error(`Failed to send response: ${error}`), 'send-response')) + .catch(async error => { + // Last resort: try to send an error response even if something went wrong above + // This prevents the client from hanging indefinitely + try { + const errorCode = isProtocolError(error) ? error.code : ErrorCode.InternalError; + const errorResponse: JSONRPCErrorResponse = { + jsonrpc: '2.0', + id: request.id, + error: { + code: errorCode, + message: error?.message ?? 'Internal error' + } + }; + await capturedTransport?.send(errorResponse); + } catch { + // Truly give up - can't even send error response + } + this._onerror(new Error(`Failed to send response: ${error}`), 'send-response'); + }) .finally(() => { this._handlerRegistry.removeAbortController(request.id); }); @@ -1022,7 +1045,7 @@ export abstract class Protocol this._onerror(new Error(`Failed to send cancellation: ${error}`), 'send-cancellation')); - // Wrap the reason in an McpError if it isn't already - const error = reason instanceof McpError ? reason : new McpError(ErrorCode.RequestTimeout, String(reason)); + // Wrap the reason in a TransportError if it isn't already an error we recognize + const error = reason instanceof Error ? reason : TransportError.requestTimeout(String(reason)); reject(error); }; @@ -1228,7 +1251,7 @@ export abstract class Protocol cancel(McpError.fromError(ErrorCode.RequestTimeout, 'Request timed out', { timeout })); + const timeoutHandler = () => cancel(TransportError.requestTimeout('Request timed out', { timeout })); this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler, options?.resetTimeoutOnProgress ?? false); diff --git a/packages/core/src/shared/responseMessage.ts b/packages/core/src/shared/responseMessage.ts index 8a0dcc2c2..b67f52362 100644 --- a/packages/core/src/shared/responseMessage.ts +++ b/packages/core/src/shared/responseMessage.ts @@ -1,4 +1,5 @@ -import type { McpError, Result, Task } from '../types/types.js'; +import type { ProtocolError } from '../errors.js'; +import type { Result, Task } from '../types/types.js'; /** * Base message type @@ -36,7 +37,7 @@ export interface ResultMessage extends BaseResponseMessage { */ export interface ErrorMessage extends BaseResponseMessage { type: 'error'; - error: McpError; + error: ProtocolError; } /** diff --git a/packages/core/src/shared/taskClientPlugin.ts b/packages/core/src/shared/taskClientPlugin.ts index 66f24ad9b..e29211133 100644 --- a/packages/core/src/shared/taskClientPlugin.ts +++ b/packages/core/src/shared/taskClientPlugin.ts @@ -11,6 +11,7 @@ * ``` */ +import { ProtocolError } from '../errors.js'; import { isTerminal } from '../experimental/tasks/interfaces.js'; import type { CancelTaskResult, @@ -34,7 +35,6 @@ import { GetTaskResultSchema, isJSONRPCResultResponse, ListTasksResultSchema, - McpError, RELATED_TASK_META_KEY } from '../types/types.js'; import type { AnySchema, SchemaOutput } from '../util/zodCompat.js'; @@ -369,7 +369,7 @@ export class TaskClientPlugin implements ProtocolPlugin { taskId = createResult.task.taskId; yield { type: 'taskCreated', task: createResult.task }; } else { - throw new McpError(ErrorCode.InternalError, 'Task creation did not return a task'); + throw new ProtocolError(ErrorCode.InternalError, 'Task creation did not return a task'); } // Poll for task completion @@ -390,14 +390,14 @@ export class TaskClientPlugin implements ProtocolPlugin { case 'failed': { yield { type: 'error', - error: new McpError(ErrorCode.InternalError, `Task ${taskId} failed`) + error: new ProtocolError(ErrorCode.InternalError, `Task ${taskId} failed`) }; break; } case 'cancelled': { yield { type: 'error', - error: new McpError(ErrorCode.InternalError, `Task ${taskId} was cancelled`) + error: new ProtocolError(ErrorCode.InternalError, `Task ${taskId} was cancelled`) }; break; } @@ -424,7 +424,7 @@ export class TaskClientPlugin implements ProtocolPlugin { } catch (error) { yield { type: 'error', - error: error instanceof McpError ? error : new McpError(ErrorCode.InternalError, String(error)) + error: error instanceof ProtocolError ? error : new ProtocolError(ErrorCode.InternalError, String(error)) }; } } diff --git a/packages/core/src/shared/taskPlugin.ts b/packages/core/src/shared/taskPlugin.ts index 34662bd20..6d6e86589 100644 --- a/packages/core/src/shared/taskPlugin.ts +++ b/packages/core/src/shared/taskPlugin.ts @@ -9,6 +9,7 @@ * The plugin is internal to the SDK and not exposed as a public API. */ +import { ProtocolError } from '../errors.js'; import { RequestTaskStore } from '../experimental/requestTaskStore.js'; import type { QueuedMessage, TaskMessageQueue, TaskStore } from '../experimental/tasks/interfaces.js'; import { isTerminal } from '../experimental/tasks/interfaces.js'; @@ -34,7 +35,6 @@ import { isJSONRPCRequest, isTaskAugmentedRequestParams, ListTasksRequestSchema, - McpError, RELATED_TASK_META_KEY } from '../types/types.js'; import type { HandlerContextBase, PluginContext, PluginHandlerExtra, ProtocolPlugin } from './plugin.js'; @@ -137,7 +137,7 @@ export class TaskPlugin implements ProtocolPlugin { // For now, we support tasks for tools/call and sampling/createMessage const taskCapableMethods = ['tools/call', 'sampling/createMessage']; if (!taskCapableMethods.includes(request.method)) { - throw new McpError(ErrorCode.InvalidRequest, `Task creation is not supported for method: ${request.method}`); + throw new ProtocolError(ErrorCode.InvalidRequest, `Task creation is not supported for method: ${request.method}`); } } // Return void to pass through unchanged @@ -276,7 +276,7 @@ export class TaskPlugin implements ProtocolPlugin { const requestId = message.message.id as RequestId; const resolver = this.ctx.resolvers.get(requestId); if (resolver) { - resolver(new McpError(ErrorCode.InternalError, 'Task cancelled or completed')); + resolver(new ProtocolError(ErrorCode.InternalError, 'Task cancelled or completed')); this.ctx.resolvers.remove(requestId); } else { this.ctx.reportError(new Error(`Resolver missing for request ${requestId} during task ${taskId} cleanup`)); @@ -302,7 +302,7 @@ export class TaskPlugin implements ProtocolPlugin { return new Promise((resolve, reject) => { if (signal.aborted) { - reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled')); + reject(new ProtocolError(ErrorCode.InvalidRequest, 'Request cancelled')); return; } @@ -314,7 +314,7 @@ export class TaskPlugin implements ProtocolPlugin { 'abort', () => { clearTimeout(timeoutId); - reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled')); + reject(new ProtocolError(ErrorCode.InvalidRequest, 'Request cancelled')); }, { once: true } ); @@ -331,7 +331,7 @@ export class TaskPlugin implements ProtocolPlugin { private async handleGetTask(request: GetTaskRequest, extra: PluginHandlerExtra): Promise { const task = await this.config.taskStore.getTask(request.params.taskId, extra.mcpCtx.sessionId); if (!task) { - throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); + throw new ProtocolError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); } // Per spec: tasks/get responses SHALL NOT include related-task metadata @@ -352,7 +352,7 @@ export class TaskPlugin implements ProtocolPlugin { // Check task status const task = await this.config.taskStore.getTask(taskId, extra.mcpCtx.sessionId); if (!task) { - throw new McpError(ErrorCode.InvalidParams, `Task not found: ${taskId}`); + throw new ProtocolError(ErrorCode.InvalidParams, `Task not found: ${taskId}`); } // If task is not terminal, wait for updates and poll again @@ -422,7 +422,7 @@ export class TaskPlugin implements ProtocolPlugin { resolver(message as JSONRPCResultResponse); } else { const errorMessage = message as JSONRPCErrorResponse; - const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); resolver(error); } } @@ -435,7 +435,10 @@ export class TaskPlugin implements ProtocolPlugin { const { tasks, nextCursor } = await this.config.taskStore.listTasks(params?.cursor, extra.mcpCtx.sessionId); return { tasks, nextCursor, _meta: {} }; } catch (error) { - throw new McpError(ErrorCode.InvalidParams, `Failed to list tasks: ${error instanceof Error ? error.message : String(error)}`); + throw new ProtocolError( + ErrorCode.InvalidParams, + `Failed to list tasks: ${error instanceof Error ? error.message : String(error)}` + ); } } @@ -447,11 +450,11 @@ export class TaskPlugin implements ProtocolPlugin { const task = await this.config.taskStore.getTask(params.taskId, extra.mcpCtx.sessionId); if (!task) { - throw new McpError(ErrorCode.InvalidParams, `Task not found: ${params.taskId}`); + throw new ProtocolError(ErrorCode.InvalidParams, `Task not found: ${params.taskId}`); } if (isTerminal(task.status)) { - throw new McpError(ErrorCode.InvalidParams, `Cannot cancel task in terminal status: ${task.status}`); + throw new ProtocolError(ErrorCode.InvalidParams, `Cannot cancel task in terminal status: ${task.status}`); } await this.config.taskStore.updateTaskStatus( @@ -465,15 +468,15 @@ export class TaskPlugin implements ProtocolPlugin { const cancelledTask = await this.config.taskStore.getTask(params.taskId, extra.mcpCtx.sessionId); if (!cancelledTask) { - throw new McpError(ErrorCode.InvalidParams, `Task not found after cancellation: ${params.taskId}`); + throw new ProtocolError(ErrorCode.InvalidParams, `Task not found after cancellation: ${params.taskId}`); } return { _meta: {}, ...cancelledTask }; } catch (error) { - if (error instanceof McpError) { + if (error instanceof ProtocolError) { throw error; } - throw new McpError( + throw new ProtocolError( ErrorCode.InvalidRequest, `Failed to cancel task: ${error instanceof Error ? error.message : String(error)}` ); diff --git a/packages/core/src/types/types.ts b/packages/core/src/types/types.ts index d3e404c58..a8c08f4f8 100644 --- a/packages/core/src/types/types.ts +++ b/packages/core/src/types/types.ts @@ -2302,48 +2302,9 @@ export const ServerResultSchema = z.union([ CreateTaskResultSchema ]); -export class McpError extends Error { - constructor( - public readonly code: number, - message: string, - public readonly data?: unknown - ) { - super(`MCP error ${code}: ${message}`); - this.name = 'McpError'; - } - - /** - * Factory method to create the appropriate error type based on the error code and data - */ - static fromError(code: number, message: string, data?: unknown): McpError { - // Check for specific error types - if (code === ErrorCode.UrlElicitationRequired && data) { - const errorData = data as { elicitations?: unknown[] }; - if (errorData.elicitations) { - return new UrlElicitationRequiredError(errorData.elicitations as ElicitRequestURLParams[], message); - } - } - - // Default to generic McpError - return new McpError(code, message, data); - } -} - -/** - * Specialized error type when a tool requires a URL mode elicitation. - * This makes it nicer for the client to handle since there is specific data to work with instead of just a code to check against. - */ -export class UrlElicitationRequiredError extends McpError { - constructor(elicitations: ElicitRequestURLParams[], message: string = `URL elicitation${elicitations.length > 1 ? 's' : ''} required`) { - super(ErrorCode.UrlElicitationRequired, message, { - elicitations: elicitations - }); - } - - get elicitations(): ElicitRequestURLParams[] { - return (this.data as { elicitations: ElicitRequestURLParams[] })?.elicitations ?? []; - } -} +// Note: McpError has been removed. Use ProtocolError from '../errors.js' instead. +// ProtocolError is for errors with locked codes (SDK-generated or user-intentional). +// For customizable errors, throw a plain Error and use the onError handler. type Primitive = string | number | boolean | bigint | null | undefined; type Flatten = T extends Primitive diff --git a/packages/core/src/util/inMemory.ts b/packages/core/src/util/inMemory.ts index 3f832b06b..9e541b053 100644 --- a/packages/core/src/util/inMemory.ts +++ b/packages/core/src/util/inMemory.ts @@ -1,3 +1,4 @@ +import { StateError } from '../errors.js'; import type { Transport } from '../shared/transport.js'; import type { AuthInfo, JSONRPCMessage, RequestId } from '../types/types.js'; @@ -50,7 +51,7 @@ export class InMemoryTransport implements Transport { */ async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId; authInfo?: AuthInfo }): Promise { if (!this._otherTransport) { - throw new Error('Not connected'); + throw StateError.notConnected('send message'); } if (this._otherTransport.onmessage) { diff --git a/packages/core/src/util/zodJsonSchemaCompat.ts b/packages/core/src/util/zodJsonSchemaCompat.ts index 12e5e88c4..d144a6941 100644 --- a/packages/core/src/util/zodJsonSchemaCompat.ts +++ b/packages/core/src/util/zodJsonSchemaCompat.ts @@ -9,6 +9,7 @@ import type * as z4c from 'zod/v4/core'; import * as z4mini from 'zod/v4-mini'; import { zodToJsonSchema } from 'zod-to-json-schema'; +import { ValidationError } from '../errors.js'; import type { AnyObjectSchema, AnySchema } from './zodCompat.js'; import { getLiteralValue, getObjectShape, isZ4Schema, safeParse } from './zodCompat.js'; @@ -48,7 +49,7 @@ export function getMethodLiteral(schema: AnyObjectSchema): string { const shape = getObjectShape(schema); const methodSchema = shape?.method as AnySchema | undefined; if (!methodSchema) { - throw new Error('Schema is missing a method literal'); + throw ValidationError.invalidSchema('Schema is missing a method literal'); } const value = getLiteralValue(methodSchema); diff --git a/packages/core/test/shared/protocol.test.ts b/packages/core/test/shared/protocol.test.ts index 460c27852..4c76d8fee 100644 --- a/packages/core/test/shared/protocol.test.ts +++ b/packages/core/test/shared/protocol.test.ts @@ -31,7 +31,8 @@ import type { Task, TaskCreationParams } from '../../src/types/types.js'; -import { CallToolRequestSchema, ErrorCode, McpError, RELATED_TASK_META_KEY } from '../../src/types/types.js'; +import { CallToolRequestSchema, ErrorCode, RELATED_TASK_META_KEY } from '../../src/types/types.js'; +import { ProtocolError } from '../../src/errors.js'; // Type helper for accessing private/protected Protocol properties in tests interface TestProtocol { @@ -258,8 +259,8 @@ describe('protocol tests', () => { timeout: 0 }); } catch (error) { - expect(error).toBeInstanceOf(McpError); - if (error instanceof McpError) { + expect(error).toBeInstanceOf(ProtocolError); + if (error instanceof ProtocolError) { expect(error.code).toBe(ErrorCode.RequestTimeout); } } @@ -3592,7 +3593,7 @@ describe('Message Interception', () => { }); protocol.setRequestHandler(TestRequestSchema, async () => { - throw new McpError(ErrorCode.InternalError, 'Test error message'); + throw new ProtocolError(ErrorCode.InternalError, 'Test error message'); }); // Simulate an incoming request with relatedTask metadata @@ -4134,7 +4135,7 @@ describe('Queue lifecycle management', () => { // Verify the request promise is rejected const result = await requestPromise; - expect(result).toBeInstanceOf(McpError); + expect(result).toBeInstanceOf(ProtocolError); expect(result.message).toContain('Task cancelled or completed'); // Verify queue is cleared (no messages available) @@ -4196,7 +4197,7 @@ describe('Queue lifecycle management', () => { // Verify the request promise is rejected const result = await requestPromise; - expect(result).toBeInstanceOf(McpError); + expect(result).toBeInstanceOf(ProtocolError); expect(result.message).toContain('Task cancelled or completed'); // Verify queue is cleared (no messages available) @@ -4244,11 +4245,11 @@ describe('Queue lifecycle management', () => { const result2 = await request2Promise; const result3 = await request3Promise; - expect(result1).toBeInstanceOf(McpError); + expect(result1).toBeInstanceOf(ProtocolError); expect(result1.message).toContain('Task cancelled or completed'); - expect(result2).toBeInstanceOf(McpError); + expect(result2).toBeInstanceOf(ProtocolError); expect(result2.message).toContain('Task cancelled or completed'); - expect(result3).toBeInstanceOf(McpError); + expect(result3).toBeInstanceOf(ProtocolError); expect(result3.message).toContain('Task cancelled or completed'); // Verify queue is cleared (no messages available) @@ -4284,7 +4285,7 @@ describe('Queue lifecycle management', () => { // Verify request promise is rejected const result = await requestPromise; - expect(result).toBeInstanceOf(McpError); + expect(result).toBeInstanceOf(ProtocolError); expect(result.message).toContain('Task cancelled or completed'); // Verify resolver mapping is cleaned up @@ -4835,7 +4836,7 @@ describe('Error handling for missing resolvers', () => { await testProtocol._clearTaskQueue(task.taskId); // Verify resolver was called with cancellation error - expect(resolverMock).toHaveBeenCalledWith(expect.any(McpError)); + expect(resolverMock).toHaveBeenCalledWith(expect.any(ProtocolError)); // Verify the error has the correct properties const calledError = resolverMock.mock.calls[0]![0]; @@ -4895,7 +4896,7 @@ describe('Error handling for missing resolvers', () => { await testProtocol._clearTaskQueue(task.taskId); // Verify resolver was called for first request - expect(resolverMock).toHaveBeenCalledWith(expect.any(McpError)); + expect(resolverMock).toHaveBeenCalledWith(expect.any(ProtocolError)); // Verify the error has the correct properties const calledError = resolverMock.mock.calls[0]![0]; @@ -5045,13 +5046,13 @@ describe('Error handling for missing resolvers', () => { if (resolver) { testProtocol._requestResolvers.delete(reqId); - const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); resolver(error); } } - // Verify resolver was called with McpError - expect(resolverMock).toHaveBeenCalledWith(expect.any(McpError)); + // Verify resolver was called with ProtocolError + expect(resolverMock).toHaveBeenCalledWith(expect.any(ProtocolError)); const calledError = resolverMock.mock.calls[0]![0]; expect(calledError.code).toBe(ErrorCode.InvalidRequest); expect(calledError.message).toContain('Invalid request parameters'); @@ -5140,13 +5141,13 @@ describe('Error handling for missing resolvers', () => { if (resolver) { testProtocol._requestResolvers.delete(reqId); - const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); resolver(error); } } - // Verify resolver was called with McpError including data - expect(resolverMock).toHaveBeenCalledWith(expect.any(McpError)); + // Verify resolver was called with ProtocolError including data + expect(resolverMock).toHaveBeenCalledWith(expect.any(ProtocolError)); const calledError = resolverMock.mock.calls[0]![0]; expect(calledError.code).toBe(ErrorCode.InvalidParams); expect(calledError.message).toContain('Validation failed'); @@ -5256,7 +5257,7 @@ describe('Error handling for missing resolvers', () => { const resolver = testProtocol._requestResolvers.get(requestId); if (resolver) { testProtocol._requestResolvers.delete(requestId); - const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); resolver(error); } } @@ -5264,7 +5265,7 @@ describe('Error handling for missing resolvers', () => { // Verify all resolvers were called correctly expect(resolver1).toHaveBeenCalledWith(expect.objectContaining({ id: 1 })); - expect(resolver2).toHaveBeenCalledWith(expect.any(McpError)); + expect(resolver2).toHaveBeenCalledWith(expect.any(ProtocolError)); expect(resolver3).toHaveBeenCalledWith(expect.objectContaining({ id: 3 })); // Verify error has correct properties @@ -5331,7 +5332,7 @@ describe('Error handling for missing resolvers', () => { const resolver = testProtocol._requestResolvers.get(requestId); if (resolver) { testProtocol._requestResolvers.delete(requestId); - const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); resolver(error); } } diff --git a/packages/server/src/server/mcp.ts b/packages/server/src/server/mcp.ts index 1ac7283a7..9d4504833 100644 --- a/packages/server/src/server/mcp.ts +++ b/packages/server/src/server/mcp.ts @@ -17,12 +17,14 @@ import type { ListToolsResult, LoggingMessageNotification, PromptReference, + ProtocolPlugin, Resource, ResourceTemplateReference, Result, SchemaOutput, ServerNotification, ServerRequest, + ServerResult, ShapeOutput, ToolAnnotations, ToolExecution, @@ -39,13 +41,14 @@ import { getObjectShape, getParseErrorMessage, GetPromptRequestSchema, + isProtocolError, ListPromptsRequestSchema, ListResourcesRequestSchema, ListResourceTemplatesRequestSchema, ListToolsRequestSchema, - McpError, normalizeObjectSchema, objectFromShape, + ProtocolError, ReadResourceRequestSchema, safeParseAsync, UriTemplate @@ -364,10 +367,10 @@ export class McpServer { try { const tool = this._toolRegistry.getTool(request.params.name); if (!tool) { - throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} not found`); + throw ProtocolError.invalidParams(`Tool ${request.params.name} not found`); } if (!tool.enabled) { - throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} disabled`); + throw ProtocolError.invalidParams(`Tool ${request.params.name} disabled`); } const isTaskRequest = !!request.params.task; @@ -376,18 +379,14 @@ export class McpServer { // Validate task hint configuration if ((taskSupport === 'required' || taskSupport === 'optional') && !isTaskHandler) { - throw new McpError( - ErrorCode.InternalError, + throw ProtocolError.internalError( `Tool ${request.params.name} has taskSupport '${taskSupport}' but was not registered with registerToolTask` ); } // Handle taskSupport 'required' without task augmentation if (taskSupport === 'required' && !isTaskRequest) { - throw new McpError( - ErrorCode.MethodNotFound, - `Tool ${request.params.name} requires task augmentation (taskSupport: 'required')` - ); + throw ProtocolError.methodNotFound(`Tool ${request.params.name} requires task augmentation (taskSupport: 'required')`); } // Handle taskSupport 'optional' without task augmentation - automatic polling @@ -428,7 +427,7 @@ export class McpServer { return result; } catch (error) { - if (error instanceof McpError && error.code === ErrorCode.UrlElicitationRequired) { + if (isProtocolError(error) && error.code === ErrorCode.UrlElicitationRequired) { throw error; // Return the error to the caller without wrapping in CallToolResult } return this.createToolError(error instanceof Error ? error.message : String(error)); @@ -479,7 +478,7 @@ export class McpServer { if (!parseResult.success) { const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; const errorMessage = getParseErrorMessage(error); - throw new McpError(ErrorCode.InvalidParams, `Input validation error: Invalid arguments for tool ${toolName}: ${errorMessage}`); + throw ProtocolError.invalidParams(`Input validation error: Invalid arguments for tool ${toolName}: ${errorMessage}`); } return parseResult.data as unknown as Args; @@ -503,8 +502,7 @@ export class McpServer { } if (!result.structuredContent) { - throw new McpError( - ErrorCode.InvalidParams, + throw ProtocolError.invalidParams( `Output validation error: Tool ${toolName} has an output schema but no structured content was provided` ); } @@ -515,10 +513,7 @@ export class McpServer { if (!parseResult.success) { const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; const errorMessage = getParseErrorMessage(error); - throw new McpError( - ErrorCode.InvalidParams, - `Output validation error: Invalid structured content for tool ${toolName}: ${errorMessage}` - ); + throw ProtocolError.invalidParams(`Output validation error: Invalid structured content for tool ${toolName}: ${errorMessage}`); } } @@ -623,7 +618,7 @@ export class McpServer { } default: { - throw new McpError(ErrorCode.InvalidParams, `Invalid completion reference: ${request.params.ref}`); + throw ProtocolError.invalidParams(`Invalid completion reference: ${request.params.ref}`); } } }); @@ -634,11 +629,11 @@ export class McpServer { private async handlePromptCompletion(request: CompleteRequestPrompt, ref: PromptReference): Promise { const prompt = this._promptRegistry.getPrompt(ref.name); if (!prompt) { - throw new McpError(ErrorCode.InvalidParams, `Prompt ${ref.name} not found`); + throw ProtocolError.invalidParams(`Prompt ${ref.name} not found`); } if (!prompt.enabled) { - throw new McpError(ErrorCode.InvalidParams, `Prompt ${ref.name} disabled`); + throw ProtocolError.invalidParams(`Prompt ${ref.name} disabled`); } if (!prompt.argsSchema) { @@ -671,7 +666,7 @@ export class McpServer { return EMPTY_COMPLETION_RESULT; } - throw new McpError(ErrorCode.InvalidParams, `Resource template ${request.params.ref.uri} not found`); + throw ProtocolError.invalidParams(`Resource template ${request.params.ref.uri} not found`); } const completer = template.template.completeCallback(request.params.argument.name); @@ -733,7 +728,7 @@ export class McpServer { const resource = this._resourceRegistry.getResource(uri.toString()); if (resource) { if (!resource.enabled) { - throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} disabled`); + throw ProtocolError.invalidParams(`Resource ${uri} disabled`); } // Build middleware context @@ -775,7 +770,7 @@ export class McpServer { }); } - throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} not found`); + throw ProtocolError.invalidParams(`Resource ${uri} not found`); }); this._resourceHandlersInitialized = true; @@ -807,11 +802,11 @@ export class McpServer { this.server.setRequestHandler(GetPromptRequestSchema, async (request, ctx): Promise => { const prompt = this._promptRegistry.getPrompt(request.params.name); if (!prompt) { - throw new McpError(ErrorCode.InvalidParams, `Prompt ${request.params.name} not found`); + throw ProtocolError.invalidParams(`Prompt ${request.params.name} not found`); } if (!prompt.enabled) { - throw new McpError(ErrorCode.InvalidParams, `Prompt ${request.params.name} disabled`); + throw ProtocolError.invalidParams(`Prompt ${request.params.name} disabled`); } // Build middleware context @@ -837,10 +832,7 @@ export class McpServer { if (!parseResult.success) { const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; const errorMessage = getParseErrorMessage(error); - throw new McpError( - ErrorCode.InvalidParams, - `Invalid arguments for prompt ${request.params.name}: ${errorMessage}` - ); + throw ProtocolError.invalidParams(`Invalid arguments for prompt ${request.params.name}: ${errorMessage}`); } const args = parseResult.data; @@ -859,6 +851,11 @@ export class McpServer { this._promptHandlersInitialized = true; } + usePlugin(plugin: ProtocolPlugin): this { + this.server.usePlugin(plugin); + return this; + } + /** * Registers a resource with a config object and callback. * For static resources, use a URI string. For dynamic resources, use a ResourceTemplate. diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index cd2ff6a0e..f6fa9f346 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -47,12 +47,12 @@ import { assertToolsCallTaskCapability, CallToolRequestSchema, CallToolResultSchema, + CapabilityError, CreateMessageResultSchema, CreateMessageResultWithToolsSchema, CreateTaskResultSchema, ElicitResultSchema, EmptyResultSchema, - ErrorCode, getObjectShape, InitializedNotificationSchema, InitializeRequestSchema, @@ -60,11 +60,12 @@ import { LATEST_PROTOCOL_VERSION, ListRootsResultSchema, LoggingLevelSchema, - McpError, mergeCapabilities, Protocol, + ProtocolError, safeParse, SetLevelRequestSchema, + StateError, SUPPORTED_PROTOCOL_VERSIONS } from '@modelcontextprotocol/core'; @@ -223,7 +224,7 @@ export class Server< */ public registerCapabilities(capabilities: ServerCapabilities): void { if (this.transport) { - throw new Error('Cannot register capabilities after connecting to transport'); + throw StateError.registrationAfterConnect('capabilities'); } this._capabilities = mergeCapabilities(this._capabilities, capabilities); } @@ -318,7 +319,7 @@ export class Server< if (!validatedRequest.success) { const errorMessage = validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid tools/call request: ${errorMessage}`); } const { params } = validatedRequest.data; @@ -333,7 +334,7 @@ export class Server< taskValidationResult.error instanceof Error ? taskValidationResult.error.message : String(taskValidationResult.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid task creation result: ${errorMessage}`); } return taskValidationResult.data; } @@ -343,7 +344,7 @@ export class Server< if (!validationResult.success) { const errorMessage = validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); - throw new McpError(ErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`); + throw ProtocolError.invalidParams(`Invalid tools/call result: ${errorMessage}`); } return validationResult.data; @@ -368,21 +369,21 @@ export class Server< switch (method as ServerRequest['method']) { case 'sampling/createMessage': { if (!this._clientCapabilities?.sampling) { - throw new Error(`Client does not support sampling (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('sampling', method); } break; } case 'elicitation/create': { if (!this._clientCapabilities?.elicitation) { - throw new Error(`Client does not support elicitation (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('elicitation', method); } break; } case 'roots/list': { if (!this._clientCapabilities?.roots) { - throw new Error(`Client does not support listing roots (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('roots', method); } break; } @@ -398,7 +399,7 @@ export class Server< switch (method as ServerNotification['method']) { case 'notifications/message': { if (!this._capabilities.logging) { - throw new Error(`Server does not support logging (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('logging', method); } break; } @@ -406,28 +407,28 @@ export class Server< case 'notifications/resources/updated': case 'notifications/resources/list_changed': { if (!this._capabilities.resources) { - throw new Error(`Server does not support notifying about resources (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('resources', method); } break; } case 'notifications/tools/list_changed': { if (!this._capabilities.tools) { - throw new Error(`Server does not support notifying of tool list changes (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('tools', method); } break; } case 'notifications/prompts/list_changed': { if (!this._capabilities.prompts) { - throw new Error(`Server does not support notifying of prompt list changes (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('prompts', method); } break; } case 'notifications/elicitation/complete': { if (!this._clientCapabilities?.elicitation?.url) { - throw new Error(`Client does not support URL elicitation (required for ${method})`); + throw CapabilityError.clientDoesNotSupport('elicitation.url', method); } break; } @@ -454,14 +455,14 @@ export class Server< switch (method) { case 'completion/complete': { if (!this._capabilities.completions) { - throw new Error(`Server does not support completions (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('completions', method); } break; } case 'logging/setLevel': { if (!this._capabilities.logging) { - throw new Error(`Server does not support logging (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('logging', method); } break; } @@ -469,7 +470,7 @@ export class Server< case 'prompts/get': case 'prompts/list': { if (!this._capabilities.prompts) { - throw new Error(`Server does not support prompts (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('prompts', method); } break; } @@ -478,7 +479,7 @@ export class Server< case 'resources/templates/list': case 'resources/read': { if (!this._capabilities.resources) { - throw new Error(`Server does not support resources (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('resources', method); } break; } @@ -486,7 +487,7 @@ export class Server< case 'tools/call': case 'tools/list': { if (!this._capabilities.tools) { - throw new Error(`Server does not support tools (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('tools', method); } break; } @@ -496,7 +497,7 @@ export class Server< case 'tasks/result': case 'tasks/cancel': { if (!this._capabilities.tasks) { - throw new Error(`Server does not support tasks capability (required for ${method})`); + throw CapabilityError.serverDoesNotSupport('tasks', method); } break; } @@ -623,7 +624,7 @@ export class Server< ): Promise { // Capability check - only required when tools/toolChoice are provided if ((params.tools || params.toolChoice) && !this._clientCapabilities?.sampling?.tools) { - throw new Error('Client does not support sampling tools capability.'); + throw CapabilityError.clientDoesNotSupport('sampling.tools', 'sampling/createMessage'); } // Message structure validation - always validate tool_use/tool_result pairs. @@ -681,7 +682,7 @@ export class Server< switch (mode) { case 'url': { if (!this._clientCapabilities?.elicitation?.url) { - throw new Error('Client does not support url elicitation.'); + throw CapabilityError.clientDoesNotSupport('elicitation.url', 'elicitation/create'); } const urlParams = params as ElicitRequestURLParams; @@ -689,7 +690,7 @@ export class Server< } case 'form': { if (!this._clientCapabilities?.elicitation?.form) { - throw new Error('Client does not support form elicitation.'); + throw CapabilityError.clientDoesNotSupport('elicitation.form', 'elicitation/create'); } const formParams: ElicitRequestFormParams = @@ -703,17 +704,15 @@ export class Server< const validationResult = validator(result.content); if (!validationResult.valid) { - throw new McpError( - ErrorCode.InvalidParams, + throw ProtocolError.invalidParams( `Elicitation response content does not match requested schema: ${validationResult.errorMessage}` ); } } catch (error) { - if (error instanceof McpError) { + if (error instanceof ProtocolError) { throw error; } - throw new McpError( - ErrorCode.InternalError, + throw ProtocolError.internalError( `Error validating elicitation response: ${error instanceof Error ? error.message : String(error)}` ); } @@ -733,7 +732,7 @@ export class Server< */ createElicitationCompletionNotifier(elicitationId: string, options?: NotificationOptions): () => Promise { if (!this._clientCapabilities?.elicitation?.url) { - throw new Error('Client does not support URL elicitation (required for notifications/elicitation/complete)'); + throw CapabilityError.clientDoesNotSupport('elicitation.url', 'notifications/elicitation/complete'); } return () => diff --git a/packages/server/src/server/streamableHttp.ts b/packages/server/src/server/streamableHttp.ts index ae8bad97e..a51c5da4b 100644 --- a/packages/server/src/server/streamableHttp.ts +++ b/packages/server/src/server/streamableHttp.ts @@ -17,6 +17,7 @@ import { isJSONRPCRequest, isJSONRPCResultResponse, JSONRPCMessageSchema, + StateError, SUPPORTED_PROTOCOL_VERSIONS } from '@modelcontextprotocol/core'; @@ -244,7 +245,7 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { */ async start(): Promise { if (this._started) { - throw new Error('Transport already started'); + throw StateError.alreadyConnected(); } this._started = true; } diff --git a/test/integration/test/client/client.test.ts b/test/integration/test/client/client.test.ts index ff8ec4135..41ac860f9 100644 --- a/test/integration/test/client/client.test.ts +++ b/test/integration/test/client/client.test.ts @@ -19,8 +19,8 @@ import { ListRootsRequestSchema, ListToolsRequestSchema, ListToolsResultSchema, - McpError, NotificationSchema, + ProtocolError, RequestSchema, ResultSchema, SUPPORTED_PROTOCOL_VERSIONS @@ -1175,8 +1175,8 @@ test('should handle client cancelling a request', async () => { }); controller.abort('Cancelled by test'); - // Request should be rejected with an McpError - await expect(listResourcesPromise).rejects.toThrow(McpError); + // Request should be rejected with an ProtocolError + await expect(listResourcesPromise).rejects.toThrow(ProtocolError); }); /*** diff --git a/test/integration/test/experimental/tasks/taskListing.test.ts b/test/integration/test/experimental/tasks/taskListing.test.ts index 28b39bb3b..6d3f1c6ef 100644 --- a/test/integration/test/experimental/tasks/taskListing.test.ts +++ b/test/integration/test/experimental/tasks/taskListing.test.ts @@ -1,4 +1,4 @@ -import { ErrorCode, McpError } from '@modelcontextprotocol/core'; +import { ErrorCode, ProtocolError } from '@modelcontextprotocol/core'; import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { createInMemoryTaskEnvironment } from '../../helpers/mcp.js'; @@ -88,8 +88,8 @@ describe('Task Listing with Pagination', () => { }); // Try to use an invalid cursor - should return -32602 (Invalid params) per MCP spec - await expect(client.experimental.tasks.listTasks('invalid-cursor')).rejects.toSatisfy((error: McpError) => { - expect(error).toBeInstanceOf(McpError); + await expect(client.experimental.tasks.listTasks('invalid-cursor')).rejects.toSatisfy((error: ProtocolError) => { + expect(error).toBeInstanceOf(ProtocolError); expect(error.code).toBe(ErrorCode.InvalidParams); expect(error.message).toContain('Invalid cursor'); return true; diff --git a/test/integration/test/server.test.ts b/test/integration/test/server.test.ts index 082892fa3..76f1989b4 100644 --- a/test/integration/test/server.test.ts +++ b/test/integration/test/server.test.ts @@ -24,8 +24,8 @@ import { ListPromptsRequestSchema, ListResourcesRequestSchema, ListToolsRequestSchema, - McpError, NotificationSchema, + ProtocolError, RequestSchema, ResultSchema, SetLevelRequestSchema, @@ -1461,8 +1461,8 @@ test('should handle server cancelling a request', async () => { ); controller.abort('Cancelled by test'); - // Request should be rejected with an McpError - await expect(createMessagePromise).rejects.toThrow(McpError); + // Request should be rejected with an ProtocolError + await expect(createMessagePromise).rejects.toThrow(ProtocolError); }); test('should handle request timeout', async () => { diff --git a/test/integration/test/server/mcp.test.ts b/test/integration/test/server/mcp.test.ts index 1f0a6eb16..db3643cfd 100644 --- a/test/integration/test/server/mcp.test.ts +++ b/test/integration/test/server/mcp.test.ts @@ -1581,9 +1581,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); /*** - * Test: McpError for Invalid Tool Name + * Test: ProtocolError for Invalid Tool Name */ - test('should throw McpError for invalid tool name', async () => { + test('should throw ProtocolError for invalid tool name', async () => { const mcpServer = new McpServer({ name: 'test server', version: '1.0' @@ -2625,9 +2625,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); /*** - * Test: McpError for Invalid Resource URI + * Test: ProtocolError for Invalid Resource URI */ - test('should throw McpError for invalid resource URI', async () => { + test('should throw ProtocolError for invalid resource URI', async () => { const mcpServer = new McpServer({ name: 'test server', version: '1.0' @@ -3550,9 +3550,9 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }); /*** - * Test: McpError for Invalid Prompt Name + * Test: ProtocolError for Invalid Prompt Name */ - test('should throw McpError for invalid prompt name', async () => { + test('should throw ProtocolError for invalid prompt name', async () => { const mcpServer = new McpServer({ name: 'test server', version: '1.0' diff --git a/test/integration/test/taskLifecycle.test.ts b/test/integration/test/taskLifecycle.test.ts index 2737cb5cf..68265b2f4 100644 --- a/test/integration/test/taskLifecycle.test.ts +++ b/test/integration/test/taskLifecycle.test.ts @@ -13,8 +13,8 @@ import { ErrorCode, InMemoryTaskMessageQueue, InMemoryTaskStore, - McpError, McpServer, + ProtocolError, RELATED_TASK_META_KEY, TaskSchema } from '@modelcontextprotocol/server'; @@ -395,8 +395,8 @@ describe('Task Lifecycle Integration Tests', () => { expect(task.status).toBe('completed'); // Try to cancel via tasks/cancel request (should fail with -32602) - await expect(client.experimental.tasks.cancelTask(taskId)).rejects.toSatisfy((error: McpError) => { - expect(error).toBeInstanceOf(McpError); + await expect(client.experimental.tasks.cancelTask(taskId)).rejects.toSatisfy((error: ProtocolError) => { + expect(error).toBeInstanceOf(ProtocolError); expect(error.code).toBe(ErrorCode.InvalidParams); expect(error.message).toContain('Cannot cancel task in terminal status'); return true; @@ -789,8 +789,8 @@ describe('Task Lifecycle Integration Tests', () => { await client.connect(transport); // Try to get non-existent task via tasks/get request - await expect(client.experimental.tasks.getTask('non-existent-task-id')).rejects.toSatisfy((error: McpError) => { - expect(error).toBeInstanceOf(McpError); + await expect(client.experimental.tasks.getTask('non-existent-task-id')).rejects.toSatisfy((error: ProtocolError) => { + expect(error).toBeInstanceOf(ProtocolError); expect(error.code).toBe(ErrorCode.InvalidParams); expect(error.message).toContain('Task not found'); return true; @@ -809,8 +809,8 @@ describe('Task Lifecycle Integration Tests', () => { await client.connect(transport); // Try to cancel non-existent task via tasks/cancel request - await expect(client.experimental.tasks.cancelTask('non-existent-task-id')).rejects.toSatisfy((error: McpError) => { - expect(error).toBeInstanceOf(McpError); + await expect(client.experimental.tasks.cancelTask('non-existent-task-id')).rejects.toSatisfy((error: ProtocolError) => { + expect(error).toBeInstanceOf(ProtocolError); expect(error.code).toBe(ErrorCode.InvalidParams); expect(error.message).toContain('Task not found'); return true; @@ -837,8 +837,8 @@ describe('Task Lifecycle Integration Tests', () => { }, CallToolResultSchema ) - ).rejects.toSatisfy((error: McpError) => { - expect(error).toBeInstanceOf(McpError); + ).rejects.toSatisfy((error: ProtocolError) => { + expect(error).toBeInstanceOf(ProtocolError); expect(error.code).toBe(ErrorCode.InvalidParams); expect(error.message).toContain('Task not found'); return true;