diff --git a/src/agent/executor.ts b/src/agent/executor.ts index 3d1ead9..f98e25c 100644 --- a/src/agent/executor.ts +++ b/src/agent/executor.ts @@ -20,6 +20,28 @@ export class Executor { sessionId: string, toolCall: ToolCall ): Promise { + const validationError = this.registry.validateInput( + toolCall.name, + toolCall.input + ); + if (validationError) { + await this.eventStore.append(sessionId, "tool_call", { + id: toolCall.id, + tool: toolCall.name, + input: toolCall.input, + }); + + await this.eventStore.append(sessionId, "tool_result", { + toolCallId: toolCall.id, + tool: toolCall.name, + content: validationError.content.slice(0, 2000), + isError: true, + metadata: validationError.metadata, + }); + + return validationError; + } + const decision = this.policyEngine.evaluate(toolCall.name, toolCall.input); await this.eventStore.append(sessionId, "policy_decision", { diff --git a/src/tools/registry.ts b/src/tools/registry.ts index 5746ab7..72724d4 100644 --- a/src/tools/registry.ts +++ b/src/tools/registry.ts @@ -1,5 +1,14 @@ import type { ToolDefinition, ToolResult, ToolSchema } from "../types/tools.js"; +const SUPPORTED_PROPERTY_TYPES = new Set([ + "string", + "number", + "integer", + "boolean", + "object", + "array", +]); + export class ToolRegistry { private tools = new Map(); @@ -29,6 +38,118 @@ export class ToolRegistry { if (!tool) { return { content: `Unknown tool: ${name}`, isError: true }; } + + const validationError = this.validateInput(name, input); + if (validationError) return validationError; + return tool.execute(input); } + + validateInput( + name: string, + input: Record + ): ToolResult | null { + const tool = this.tools.get(name); + if (!tool) return null; + + const errors = validateAgainstSchema(tool.inputSchema, input); + if (errors.length === 0) return null; + + return { + content: `Invalid input for tool "${name}": ${errors.join(" ")}`, + isError: true, + metadata: { + validationErrors: errors, + }, + }; + } +} + +function validateAgainstSchema( + schema: Record, + input: Record +): string[] { + const errors: string[] = []; + const required = readRequiredFields(schema.required); + + for (const field of required) { + if ( + !Object.prototype.hasOwnProperty.call(input, field) || + input[field] === undefined + ) { + errors.push(`missing required field "${field}".`); + } + } + + const properties = readProperties(schema.properties); + for (const [field, propertySchema] of Object.entries(properties)) { + if ( + !Object.prototype.hasOwnProperty.call(input, field) || + input[field] === undefined + ) { + continue; + } + + const expectedType = readExpectedType(propertySchema); + if (!expectedType) continue; + + const value = input[field]; + if (!matchesType(value, expectedType)) { + errors.push(`field "${field}" must be ${formatExpectedType(expectedType)}.`); + } + } + + return errors; +} + +function readRequiredFields(required: unknown): string[] { + if (!Array.isArray(required)) return []; + return required.filter((field): field is string => typeof field === "string"); +} + +function readProperties( + properties: unknown +): Record> { + if (!isPlainObject(properties)) return {}; + + const result: Record> = {}; + for (const [field, propertySchema] of Object.entries(properties)) { + if (isPlainObject(propertySchema)) { + result[field] = propertySchema; + } + } + return result; +} + +function readExpectedType(schema: Record): string | null { + const type = schema.type; + if (typeof type !== "string") return null; + return SUPPORTED_PROPERTY_TYPES.has(type) ? type : null; +} + +function matchesType(value: unknown, expectedType: string): boolean { + switch (expectedType) { + case "array": + return Array.isArray(value); + case "boolean": + return typeof value === "boolean"; + case "integer": + return Number.isInteger(value); + case "number": + return typeof value === "number" && Number.isFinite(value); + case "object": + return isPlainObject(value); + case "string": + return typeof value === "string"; + default: + return true; + } +} + +function formatExpectedType(expectedType: string): string { + return expectedType === "integer" ? "an integer" : `a ${expectedType}`; +} + +function isPlainObject(value: unknown): value is Record { + return typeof value === "object" && value !== null && !Array.isArray(value); } diff --git a/tests/harness-behavior.test.ts b/tests/harness-behavior.test.ts index f1858c3..8ceab82 100644 --- a/tests/harness-behavior.test.ts +++ b/tests/harness-behavior.test.ts @@ -13,6 +13,7 @@ import type { ChatParams, ModelProvider } from "../src/models/provider.js"; import { EventStore } from "../src/storage/event-store.js"; import { SessionStore } from "../src/storage/session-store.js"; import { ToolRegistry } from "../src/tools/registry.js"; +import { runCommandTool } from "../src/tools/run-command.js"; import type { ModelResponse, SessionState } from "../src/types/agent.js"; import type { Message } from "../src/types/messages.js"; import type { ToolDefinition } from "../src/types/tools.js"; @@ -614,3 +615,78 @@ test("unknown tool calls are logged and returned as error tool results", async ( rmSync(cwd, { recursive: true, force: true }); } }); + +test("malformed run_command input returns a tool error without policy evaluation", async () => { + const cwd = createTempDir(); + const session = createSession(cwd); + const registry = new ToolRegistry(); + registry.register(runCommandTool); + const responses: ModelResponse[] = [ + { + stopReason: "tool_use", + content: [ + { + type: "tool_use", + id: "tool-1", + name: "run_command", + input: {}, + }, + ], + }, + { + stopReason: "end_turn", + content: [{ type: "text", text: "I need to include a command." }], + }, + ]; + const provider: ModelProvider = { + async chat() { + const response = responses.shift(); + assert.ok(response); + return response; + }, + }; + const { loop, eventStore, sessionStore } = createHarness(cwd, provider, { + registry, + policyEngine: PolicyEngine.withDefaults(), + }); + + try { + await silenceConsole(() => loop.run(session, "Run a command")); + + const saved = await sessionStore.load(session.id); + assert.ok(saved); + assert.deepEqual(saved.messages.at(-2), { + role: "user", + content: [ + { + type: "tool_result", + toolUseId: "tool-1", + content: + 'Invalid input for tool "run_command": missing required field "command".', + isError: true, + metadata: { + validationErrors: ['missing required field "command".'], + }, + }, + ], + }); + + const events = await eventStore.getEvents(session.id); + assert.deepEqual( + events.map((event) => event.type), + ["user_message", "assistant_response", "tool_call", "tool_result", "assistant_response"] + ); + assert.deepEqual(events[3].data, { + toolCallId: "tool-1", + tool: "run_command", + content: + 'Invalid input for tool "run_command": missing required field "command".', + isError: true, + metadata: { + validationErrors: ['missing required field "command".'], + }, + }); + } finally { + rmSync(cwd, { recursive: true, force: true }); + } +}); diff --git a/tests/tool-registry.test.ts b/tests/tool-registry.test.ts index 1d3ce86..ad300b2 100644 --- a/tests/tool-registry.test.ts +++ b/tests/tool-registry.test.ts @@ -47,3 +47,33 @@ test("ToolRegistry.toSchemas returns deterministic name-sorted schemas", () => { parameters: alpha.inputSchema, }); }); + +test("ToolRegistry.execute returns validation errors for malformed tool input", async () => { + const registry = new ToolRegistry(); + registry.register(createTool("alpha")); + + const result = await registry.execute("alpha", {}); + + assert.deepEqual(result, { + content: 'Invalid input for tool "alpha": missing required field "value".', + isError: true, + metadata: { + validationErrors: ['missing required field "value".'], + }, + }); +}); + +test("ToolRegistry.execute validates primitive property types", async () => { + const registry = new ToolRegistry(); + registry.register(createTool("alpha")); + + const result = await registry.execute("alpha", { value: 42 }); + + assert.deepEqual(result, { + content: 'Invalid input for tool "alpha": field "value" must be a string.', + isError: true, + metadata: { + validationErrors: ['field "value" must be a string.'], + }, + }); +});