Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/agent/executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,28 @@ export class Executor {
sessionId: string,
toolCall: ToolCall
): Promise<ToolResult> {
const validationError = this.registry.validateInput(
toolCall.name,
toolCall.input
);
if (validationError) {
await this.eventStore.append(sessionId, "tool_call", {
id: toolCall.id,
tool: toolCall.name,
input: toolCall.input,
});

await this.eventStore.append(sessionId, "tool_result", {
toolCallId: toolCall.id,
tool: toolCall.name,
content: validationError.content.slice(0, 2000),
isError: true,
metadata: validationError.metadata,
});

return validationError;
}

const decision = this.policyEngine.evaluate(toolCall.name, toolCall.input);

await this.eventStore.append(sessionId, "policy_decision", {
Expand Down
121 changes: 121 additions & 0 deletions src/tools/registry.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
import type { ToolDefinition, ToolResult, ToolSchema } from "../types/tools.js";

const SUPPORTED_PROPERTY_TYPES = new Set([
"string",
"number",
"integer",
"boolean",
"object",
"array",
]);

export class ToolRegistry {
private tools = new Map<string, ToolDefinition>();

Expand Down Expand Up @@ -29,6 +38,118 @@ export class ToolRegistry {
if (!tool) {
return { content: `Unknown tool: ${name}`, isError: true };
}

const validationError = this.validateInput(name, input);
if (validationError) return validationError;

return tool.execute(input);
}

validateInput(
name: string,
input: Record<string, unknown>
): ToolResult | null {
const tool = this.tools.get(name);
if (!tool) return null;

const errors = validateAgainstSchema(tool.inputSchema, input);
if (errors.length === 0) return null;

return {
content: `Invalid input for tool "${name}": ${errors.join(" ")}`,
isError: true,
metadata: {
validationErrors: errors,
},
};
}
}

function validateAgainstSchema(
schema: Record<string, unknown>,
input: Record<string, unknown>
): string[] {
const errors: string[] = [];
const required = readRequiredFields(schema.required);

for (const field of required) {
if (
!Object.prototype.hasOwnProperty.call(input, field) ||
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Guard non-object tool inputs before required checks

When a provider returns a malformed function call whose arguments parse to null, the OpenAI adapters currently cast JSON.parse(...) directly to Record<string, unknown> (for example in src/models/openai-responses.ts:327 and src/models/openai.ts:275). This new validation path then calls hasOwnProperty on that value and throws TypeError before the executor can return the structured validation tool result, so that malformed tool call still aborts the run instead of letting the model recover.

Useful? React with 👍 / 👎.

input[field] === undefined
) {
errors.push(`missing required field "${field}".`);
}
}

const properties = readProperties(schema.properties);
for (const [field, propertySchema] of Object.entries(properties)) {
if (
!Object.prototype.hasOwnProperty.call(input, field) ||
input[field] === undefined
) {
continue;
}

const expectedType = readExpectedType(propertySchema);
if (!expectedType) continue;

const value = input[field];
if (!matchesType(value, expectedType)) {
errors.push(`field "${field}" must be ${formatExpectedType(expectedType)}.`);
}
}

return errors;
}

function readRequiredFields(required: unknown): string[] {
if (!Array.isArray(required)) return [];
return required.filter((field): field is string => typeof field === "string");
}

function readProperties(
properties: unknown
): Record<string, Record<string, unknown>> {
if (!isPlainObject(properties)) return {};

const result: Record<string, Record<string, unknown>> = {};
for (const [field, propertySchema] of Object.entries(properties)) {
if (isPlainObject(propertySchema)) {
result[field] = propertySchema;
}
}
return result;
}

function readExpectedType(schema: Record<string, unknown>): string | null {
const type = schema.type;
if (typeof type !== "string") return null;
return SUPPORTED_PROPERTY_TYPES.has(type) ? type : null;
}

function matchesType(value: unknown, expectedType: string): boolean {
switch (expectedType) {
case "array":
return Array.isArray(value);
case "boolean":
return typeof value === "boolean";
case "integer":
return Number.isInteger(value);
case "number":
return typeof value === "number" && Number.isFinite(value);
case "object":
return isPlainObject(value);
case "string":
return typeof value === "string";
default:
return true;
}
}

function formatExpectedType(expectedType: string): string {
return expectedType === "integer" ? "an integer" : `a ${expectedType}`;
}

function isPlainObject(value: unknown): value is Record<string, unknown> {
return typeof value === "object" && value !== null && !Array.isArray(value);
}
76 changes: 76 additions & 0 deletions tests/harness-behavior.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import type { ChatParams, ModelProvider } from "../src/models/provider.js";
import { EventStore } from "../src/storage/event-store.js";
import { SessionStore } from "../src/storage/session-store.js";
import { ToolRegistry } from "../src/tools/registry.js";
import { runCommandTool } from "../src/tools/run-command.js";
import type { ModelResponse, SessionState } from "../src/types/agent.js";
import type { Message } from "../src/types/messages.js";
import type { ToolDefinition } from "../src/types/tools.js";
Expand Down Expand Up @@ -614,3 +615,78 @@ test("unknown tool calls are logged and returned as error tool results", async (
rmSync(cwd, { recursive: true, force: true });
}
});

test("malformed run_command input returns a tool error without policy evaluation", async () => {
const cwd = createTempDir();
const session = createSession(cwd);
const registry = new ToolRegistry();
registry.register(runCommandTool);
const responses: ModelResponse[] = [
{
stopReason: "tool_use",
content: [
{
type: "tool_use",
id: "tool-1",
name: "run_command",
input: {},
},
],
},
{
stopReason: "end_turn",
content: [{ type: "text", text: "I need to include a command." }],
},
];
const provider: ModelProvider = {
async chat() {
const response = responses.shift();
assert.ok(response);
return response;
},
};
const { loop, eventStore, sessionStore } = createHarness(cwd, provider, {
registry,
policyEngine: PolicyEngine.withDefaults(),
});

try {
await silenceConsole(() => loop.run(session, "Run a command"));

const saved = await sessionStore.load(session.id);
assert.ok(saved);
assert.deepEqual(saved.messages.at(-2), {
role: "user",
content: [
{
type: "tool_result",
toolUseId: "tool-1",
content:
'Invalid input for tool "run_command": missing required field "command".',
isError: true,
metadata: {
validationErrors: ['missing required field "command".'],
},
},
],
});

const events = await eventStore.getEvents(session.id);
assert.deepEqual(
events.map((event) => event.type),
["user_message", "assistant_response", "tool_call", "tool_result", "assistant_response"]
);
assert.deepEqual(events[3].data, {
toolCallId: "tool-1",
tool: "run_command",
content:
'Invalid input for tool "run_command": missing required field "command".',
isError: true,
metadata: {
validationErrors: ['missing required field "command".'],
},
});
} finally {
rmSync(cwd, { recursive: true, force: true });
}
});
30 changes: 30 additions & 0 deletions tests/tool-registry.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,33 @@ test("ToolRegistry.toSchemas returns deterministic name-sorted schemas", () => {
parameters: alpha.inputSchema,
});
});

test("ToolRegistry.execute returns validation errors for malformed tool input", async () => {
const registry = new ToolRegistry();
registry.register(createTool("alpha"));

const result = await registry.execute("alpha", {});

assert.deepEqual(result, {
content: 'Invalid input for tool "alpha": missing required field "value".',
isError: true,
metadata: {
validationErrors: ['missing required field "value".'],
},
});
});

test("ToolRegistry.execute validates primitive property types", async () => {
const registry = new ToolRegistry();
registry.register(createTool("alpha"));

const result = await registry.execute("alpha", { value: 42 });

assert.deepEqual(result, {
content: 'Invalid input for tool "alpha": field "value" must be a string.',
isError: true,
metadata: {
validationErrors: ['field "value" must be a string.'],
},
});
});