From 4ce9d5ae255f19082ac59d23a1735dd58e33bca9 Mon Sep 17 00:00:00 2001 From: Ashok161 Date: Mon, 23 Mar 2026 19:46:13 +0530 Subject: [PATCH] feat(core): add tool execution timeout guard --- .changeset/tool-timeout-guard.md | 6 + packages/core/src/providers/tool-execution.ts | 21 ++- packages/core/src/types.ts | 1 + .../tests/providers/tool-execution.test.ts | 90 +++++++++++++ packages/server/src/config.ts | 1 + packages/server/src/handler.ts | 1 + packages/server/tests/handler.test.ts | 123 +++++++++++++++++- 7 files changed, 239 insertions(+), 4 deletions(-) create mode 100644 .changeset/tool-timeout-guard.md create mode 100644 packages/core/tests/providers/tool-execution.test.ts diff --git a/.changeset/tool-timeout-guard.md b/.changeset/tool-timeout-guard.md new file mode 100644 index 0000000..1999795 --- /dev/null +++ b/.changeset/tool-timeout-guard.md @@ -0,0 +1,6 @@ +--- +'@chatcops/core': patch +'@chatcops/server': patch +--- + +Add a configurable timeout guard for tool execution so hanging tools fail gracefully instead of blocking provider responses indefinitely. diff --git a/packages/core/src/providers/tool-execution.ts b/packages/core/src/providers/tool-execution.ts index a3da7dd..3ffbc4a 100644 --- a/packages/core/src/providers/tool-execution.ts +++ b/packages/core/src/providers/tool-execution.ts @@ -1,6 +1,7 @@ import type { ProviderChatParams, ProviderToolCall, ToolResult } from '../types.js'; export const MAX_TOOL_ROUNDS = 5; +export const DEFAULT_TOOL_TIMEOUT_MS = 30_000; export function parseToolInput(rawInput: string): Record { if (!rawInput.trim()) return {}; @@ -45,9 +46,27 @@ export async function executeToolCall( }; } + const timeoutMs = params.toolTimeoutMs ?? DEFAULT_TOOL_TIMEOUT_MS; + const timers = globalThis as typeof globalThis & { + setTimeout: (callback: () => void, delay?: number) => unknown; + clearTimeout: (timeoutId: unknown) => void; + }; + let timeoutId: unknown; + try { - return await params.toolExecutor(call); + return await Promise.race([ + params.toolExecutor(call), + new Promise((_, reject) => { + timeoutId = timers.setTimeout(() => { + reject(new Error(`Tool execution timed out after ${timeoutMs}ms.`)); + }, timeoutMs); + }), + ]); } catch (error) { return toToolFailure(error); + } finally { + if (timeoutId !== undefined) { + timers.clearTimeout(timeoutId); + } } } diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts index b5b6e5d..6018fe2 100644 --- a/packages/core/src/types.ts +++ b/packages/core/src/types.ts @@ -88,6 +88,7 @@ export interface ProviderChatParams { systemPrompt: string; tools?: ToolDefinition[]; toolExecutor?: ProviderToolExecutor; + toolTimeoutMs?: number; maxTokens?: number; temperature?: number; } diff --git a/packages/core/tests/providers/tool-execution.test.ts b/packages/core/tests/providers/tool-execution.test.ts new file mode 100644 index 0000000..bfcb2c9 --- /dev/null +++ b/packages/core/tests/providers/tool-execution.test.ts @@ -0,0 +1,90 @@ +import { afterEach, describe, expect, it, vi } from 'vitest'; +import { executeToolCall } from '../../src/providers/tool-execution.js'; +import type { ProviderChatParams, ProviderToolCall, ToolResult } from '../../src/types.js'; + +function createParams(overrides?: Partial): ProviderChatParams { + return { + messages: [{ id: '1', role: 'user', content: 'Check order 123', timestamp: Date.now() }], + systemPrompt: 'You are helpful.', + ...overrides, + }; +} + +const toolCall: ProviderToolCall = { + id: 'call_1', + name: 'lookup_order', + input: { orderId: '123' }, +}; + +function createHangingToolExecutor(): NonNullable { + return vi.fn().mockImplementation( + () => new Promise(() => undefined) + ) as NonNullable; +} + +describe('executeToolCall', () => { + afterEach(() => { + vi.useRealTimers(); + }); + + it('returns a failure result when tool execution exceeds the timeout', async () => { + vi.useFakeTimers(); + + const resultPromise = executeToolCall( + createParams({ + toolTimeoutMs: 50, + toolExecutor: createHangingToolExecutor(), + }), + toolCall + ); + + await vi.advanceTimersByTimeAsync(50); + + await expect(resultPromise).resolves.toEqual({ + success: false, + message: 'Tool execution timed out after 50ms.', + }); + }); + + it('returns the tool result when execution completes before the timeout', async () => { + vi.useFakeTimers(); + + const expectedResult: ToolResult = { + success: true, + data: { status: 'shipped' }, + message: 'Order found', + }; + + const result = await executeToolCall( + createParams({ + toolTimeoutMs: 50, + toolExecutor: vi.fn().mockResolvedValue(expectedResult), + }), + toolCall + ); + + expect(result).toEqual(expectedResult); + expect(vi.getTimerCount()).toBe(0); + }); + + it('uses the configured timeout value', async () => { + vi.useFakeTimers(); + + const resultPromise = executeToolCall( + createParams({ + toolTimeoutMs: 10, + toolExecutor: createHangingToolExecutor(), + }), + toolCall + ); + + await vi.advanceTimersByTimeAsync(9); + await expect(Promise.race([resultPromise, Promise.resolve('pending')])).resolves.toBe('pending'); + + await vi.advanceTimersByTimeAsync(1); + await expect(resultPromise).resolves.toEqual({ + success: false, + message: 'Tool execution timed out after 10ms.', + }); + }); +}); diff --git a/packages/server/src/config.ts b/packages/server/src/config.ts index be91053..0e85b0d 100644 --- a/packages/server/src/config.ts +++ b/packages/server/src/config.ts @@ -5,6 +5,7 @@ export interface ChatCopsServerConfig { provider: ProviderConfig; systemPrompt: string; tools?: ChatTool[]; + toolTimeoutMs?: number; knowledge?: KnowledgeSource[]; rateLimit?: { maxRequests: number; windowMs: number }; webhooks?: WebhookConfig[]; diff --git a/packages/server/src/handler.ts b/packages/server/src/handler.ts index c78feb3..55adf31 100644 --- a/packages/server/src/handler.ts +++ b/packages/server/src/handler.ts @@ -109,6 +109,7 @@ export function createChatHandler(config: ChatCopsServerConfig) { messages, systemPrompt, tools: toolDefs.length > 0 ? toolDefs : undefined, + toolTimeoutMs: config.toolTimeoutMs, toolExecutor: async (toolCall) => { const tool = toolsByName.get(toolCall.name); if (!tool) { diff --git a/packages/server/tests/handler.test.ts b/packages/server/tests/handler.test.ts index 5abc337..66f6a99 100644 --- a/packages/server/tests/handler.test.ts +++ b/packages/server/tests/handler.test.ts @@ -1,10 +1,96 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; -vi.mock('@chatcops/core', async (importOriginal) => { - const actual = await importOriginal(); +vi.mock('@chatcops/core', () => { + class ConversationManager { + private readonly conversations = new Map< + string, + { + id: string; + messages: Array<{ + id: string; + role: 'user' | 'assistant' | 'system'; + content: string; + timestamp: number; + metadata?: Record; + }>; + metadata?: Record; + createdAt: number; + updatedAt: number; + } + >(); + + async getOrCreate(id: string) { + const existing = this.conversations.get(id); + if (existing) { + return existing; + } + + const now = Date.now(); + const conversation = { + id, + messages: [], + createdAt: now, + updatedAt: now, + }; + this.conversations.set(id, conversation); + return conversation; + } + + async addMessage(conversationId: string, message: { + id: string; + role: 'user' | 'assistant' | 'system'; + content: string; + timestamp: number; + metadata?: Record; + }) { + const conversation = await this.getOrCreate(conversationId); + conversation.messages.push(message); + conversation.updatedAt = Date.now(); + } + + async getMessages(conversationId: string) { + const conversation = await this.getOrCreate(conversationId); + return [...conversation.messages]; + } + } + + class AnalyticsCollector { + private totalConversations = 0; + private leadsCapture = 0; + + track(event: string) { + if (event === 'conversation:started') { + this.totalConversations += 1; + } + + if (event === 'lead:captured') { + this.leadsCapture += 1; + } + } + + getStats() { + return { + totalConversations: this.totalConversations, + leadsCapture: this.leadsCapture, + }; + } + } + return { - ...actual, createProvider: vi.fn(), + ConversationManager, + AnalyticsCollector, + toolToDefinition: (tool: { + name: string; + description: string; + parameters: Record; + required: string[]; + }) => ({ + name: tool.name, + description: tool.description, + parameters: tool.parameters, + required: tool.required, + }), }; }); @@ -96,6 +182,37 @@ describe('createChatHandler', () => { ]); }); + it('passes the configured tool timeout into the provider chat params', async () => { + const receivedTimeouts: Array = []; + + mockedCreateProvider.mockResolvedValue({ + name: 'test-provider', + async *chat(params) { + receivedTimeouts.push(params.toolTimeoutMs); + yield 'Configured timeout received.'; + }, + async chatSync() { + return ''; + }, + }); + + const { handleChat } = createChatHandler({ + provider: { type: 'openai', apiKey: 'test-key' }, + systemPrompt: 'Test', + cors: '*', + toolTimeoutMs: 2_500, + }); + + for await (const _chunk of handleChat({ + conversationId: 'conv-timeout', + message: 'Hello', + })) { + // Drain the stream. + } + + expect(receivedTimeouts).toEqual([2_500]); + }); + it('tracks lead capture analytics when a tool succeeds', async () => { const leadTool = { name: 'capture_lead',