diff --git a/.agents/__tests__/editor-best-of-n.integration.test.ts b/.agents/__tests__/editor-best-of-n.integration.test.ts new file mode 100644 index 000000000..b4f8fcce0 --- /dev/null +++ b/.agents/__tests__/editor-best-of-n.integration.test.ts @@ -0,0 +1,91 @@ +import { API_KEY_ENV_VAR } from '@codebuff/common/old-constants' +import { describe, expect, it } from 'bun:test' + +import { CodebuffClient } from '@codebuff/sdk' + +import type { PrintModeEvent } from '@codebuff/common/types/print-mode' + +/** + * Integration tests for the editor-best-of-n-max agent. + * These tests verify that the best-of-n editor workflow works correctly: + * 1. Spawns multiple implementor agents in parallel + * 2. Collects their implementation proposals + * 3. Uses a selector agent to choose the best implementation + * 4. Applies the chosen implementation + */ +describe('Editor Best-of-N Max Agent Integration', () => { + it( + 'should generate and select the best implementation for a simple edit', + async () => { + const apiKey = process.env[API_KEY_ENV_VAR] + if (!apiKey) { + throw new Error('API key not found') + } + + // Create mock project files with a simple TypeScript file to edit + const projectFiles: Record = { + 'src/utils/math.ts': ` +export function add(a: number, b: number): number { + return a + b +} + +export function subtract(a: number, b: number): number { + return a - b +} +`, + 'src/index.ts': ` +import { add, subtract } from './utils/math' + +console.log(add(1, 2)) +console.log(subtract(5, 3)) +`, + 'package.json': JSON.stringify({ + name: 'test-project', + version: '1.0.0', + dependencies: {}, + }), + } + + const client = new CodebuffClient({ + apiKey, + cwd: '/tmp/test-best-of-n-project', + projectFiles, + }) + + const events: PrintModeEvent[] = [] + + // Run the editor-best-of-n-max agent with a simple task + // Using n=2 to keep the test fast while still testing the best-of-n workflow + const run = await client.run({ + agent: 'editor-best-of-n-max', + prompt: + 'Add a multiply function to src/utils/math.ts that takes two numbers and returns their product', + params: { n: 2 }, + handleEvent: (event) => { + console.log(event) + events.push(event) + }, + }) + + // The output should not be an error + expect(run.output.type).not.toEqual('error') + + // Verify we got some output + expect(run.output).toBeDefined() + + // The output should contain the implementation response + const outputStr = + typeof run.output === 'string' ? run.output : JSON.stringify(run.output) + console.log('Output:', outputStr) + + // Should contain evidence of the multiply function being added + const relevantTerms = ['multiply', 'product', 'str_replace', 'write_file'] + const foundRelevantTerm = relevantTerms.some((term) => + outputStr.toLowerCase().includes(term.toLowerCase()), + ) + + expect(foundRelevantTerm).toBe(true) + }, + { timeout: 120_000 }, // 2 minute timeout for best-of-n workflow + ) +}) diff --git a/.agents/__tests__/file-explorer.integration.test.ts b/.agents/__tests__/file-explorer.integration.test.ts new file mode 100644 index 000000000..0aa3cc3f6 --- /dev/null +++ b/.agents/__tests__/file-explorer.integration.test.ts @@ -0,0 +1,348 @@ +import { API_KEY_ENV_VAR } from '@codebuff/common/old-constants' +import { describe, expect, it } from 'bun:test' + +import { CodebuffClient } from '@codebuff/sdk' +import filePickerDefinition from '../file-explorer/file-picker' +import fileListerDefinition from '../file-explorer/file-lister' + +import type { PrintModeEvent } from '@codebuff/common/types/print-mode' + +/** + * Integration tests for agents that use the read_subtree tool. + * These tests verify that the SDK properly initializes the session state + * with project files and that agents can access the file tree through + * the read_subtree tool. + * + * The file-lister agent is used directly instead of file-picker because: + * - file-lister directly uses the read_subtree tool + * - file-picker spawns file-lister as a subagent, adding complexity + * - Testing file-lister directly verifies the core functionality + */ +describe('File Lister Agent Integration - read_subtree tool', () => { + it( + 'should find relevant files using read_subtree tool', + async () => { + const apiKey = process.env[API_KEY_ENV_VAR] + if (!apiKey) { + throw new Error('API key not found') + } + + // Create mock project files that the file-lister should be able to find + const projectFiles: Record = { + 'src/index.ts': ` +import { UserService } from './services/user-service' +import { AuthService } from './services/auth-service' + +export function main() { + const userService = new UserService() + const authService = new AuthService() + console.log('Application started') +} +`, + 'src/services/user-service.ts': ` +export class UserService { + async getUser(id: string) { + return { id, name: 'John Doe' } + } + + async createUser(name: string) { + return { id: 'new-user-id', name } + } + + async deleteUser(id: string) { + console.log('User deleted:', id) + } +} +`, + 'src/services/auth-service.ts': ` +export class AuthService { + async login(email: string, password: string) { + return { token: 'mock-token' } + } + + async logout() { + console.log('Logged out') + } + + async validateToken(token: string) { + return token === 'mock-token' + } +} +`, + 'src/utils/logger.ts': ` +export function log(message: string) { + console.log('[LOG]', message) +} + +export function error(message: string) { + console.error('[ERROR]', message) +} +`, + 'src/types/user.ts': ` +export interface User { + id: string + name: string + email?: string +} +`, + 'package.json': JSON.stringify({ + name: 'test-project', + version: '1.0.0', + dependencies: {}, + }), + 'README.md': + '# Test Project\n\nA simple test project for integration testing.', + } + + const client = new CodebuffClient({ + apiKey, + cwd: '/tmp/test-project', + projectFiles, + }) + + const events: PrintModeEvent[] = [] + + // Run the file-lister agent to find files related to user service + // The file-lister agent uses the read_subtree tool directly + const run = await client.run({ + agent: 'file-lister', + prompt: 'Find files related to user authentication and user management', + handleEvent: (event) => { + events.push(event) + }, + }) + + // The output should not be an error + expect(run.output.type).not.toEqual('error') + + // Verify we got some output + expect(run.output).toBeDefined() + + // The file-lister should have found relevant files + const outputStr = + typeof run.output === 'string' ? run.output : JSON.stringify(run.output) + + // Verify that the file-lister found some relevant files + const relevantFiles = [ + 'user-service', + 'auth-service', + 'user', + 'auth', + 'services', + ] + const foundRelevantFile = relevantFiles.some((file) => + outputStr.toLowerCase().includes(file.toLowerCase()), + ) + + expect(foundRelevantFile).toBe(true) + }, + { timeout: 60_000 }, + ) + + it( + 'should use the file tree from session state', + async () => { + const apiKey = process.env[API_KEY_ENV_VAR] + if (!apiKey) { + throw new Error('API key not found') + } + + // Create a different set of project files with a specific structure + const projectFiles: Record = { + 'packages/core/src/index.ts': 'export const VERSION = "1.0.0"', + 'packages/core/src/api/server.ts': + 'export function startServer() { console.log("started") }', + 'packages/core/src/api/routes.ts': + 'export const routes = { health: "/health" }', + 'packages/utils/src/helpers.ts': + 'export function formatDate(d: Date) { return d.toISOString() }', + 'docs/api.md': '# API Documentation\n\nAPI docs here.', + 'package.json': JSON.stringify({ name: 'mono-repo', version: '2.0.0' }), + } + + const client = new CodebuffClient({ + apiKey, + cwd: '/tmp/test-project', + projectFiles, + }) + + const events: PrintModeEvent[] = [] + + // Run file-lister to find API-related files + const run = await client.run({ + agent: 'file-lister', + prompt: 'Find files related to the API server implementation', + handleEvent: (event) => { + events.push(event) + }, + }) + + expect(run.output.type).not.toEqual('error') + + const outputStr = + typeof run.output === 'string' ? run.output : JSON.stringify(run.output) + + // Should find API-related files + const apiRelatedTerms = ['server', 'routes', 'api', 'core'] + const foundApiFile = apiRelatedTerms.some((term) => + outputStr.toLowerCase().includes(term.toLowerCase()), + ) + + expect(foundApiFile).toBe(true) + }, + { timeout: 60_000 }, + ) + + it( + 'should respect directories parameter', + async () => { + const apiKey = process.env[API_KEY_ENV_VAR] + if (!apiKey) { + throw new Error('API key not found') + } + + // Create project with multiple top-level directories + const projectFiles: Record = { + 'frontend/src/App.tsx': + 'export function App() { return
App
}', + 'frontend/src/components/Button.tsx': + 'export function Button() { return }', + 'backend/src/server.ts': + 'export function start() { console.log("started") }', + 'backend/src/routes/users.ts': + 'export function getUsers() { return [] }', + 'shared/types/common.ts': 'export type ID = string', + 'package.json': JSON.stringify({ name: 'full-stack-app' }), + } + + const client = new CodebuffClient({ + apiKey, + cwd: '/tmp/test-project', + projectFiles, + }) + + // Run file-lister with directories parameter to limit to frontend only + const run = await client.run({ + agent: 'file-lister', + prompt: 'Find React component files', + params: { + directories: ['frontend'], + }, + handleEvent: () => {}, + }) + + expect(run.output.type).not.toEqual('error') + + const outputStr = + typeof run.output === 'string' ? run.output : JSON.stringify(run.output) + + // Should find frontend files + const frontendTerms = ['app', 'button', 'component', 'frontend'] + const foundFrontendFile = frontendTerms.some((term) => + outputStr.toLowerCase().includes(term.toLowerCase()), + ) + + expect(foundFrontendFile).toBe(true) + }, + { timeout: 60_000 }, + ) +}) + +/** + * Integration tests for the file-picker agent that spawns subagents. + * The file-picker spawns file-lister as a subagent to find files. + * This tests the spawn_agents tool functionality through the SDK. + */ +describe('File Picker Agent Integration - spawn_agents tool', () => { + // Note: This test requires the local agent definitions to be used for both + // file-picker AND its spawned file-lister subagent. Currently, the spawned + // agent may resolve to the server version which has the old parsing bug. + // Skip until we have a way to ensure spawned agents use local definitions. + it.skip( + 'should spawn file-lister subagent and find relevant files', + async () => { + const apiKey = process.env[API_KEY_ENV_VAR] + if (!apiKey) { + throw new Error('API key not found') + } + + // Create mock project files + const projectFiles: Record = { + 'src/index.ts': ` +import { UserService } from './services/user-service' +export function main() { + const userService = new UserService() + console.log('Application started') +} +`, + 'src/services/user-service.ts': ` +export class UserService { + async getUser(id: string) { + return { id, name: 'John Doe' } + } +} +`, + 'src/services/auth-service.ts': ` +export class AuthService { + async login(email: string, password: string) { + return { token: 'mock-token' } + } +} +`, + 'package.json': JSON.stringify({ + name: 'test-project', + version: '1.0.0', + }), + } + + // Use local agent definitions to test the updated handleSteps + const localFilePickerDef = filePickerDefinition as unknown as any + const localFileListerDef = fileListerDefinition as unknown as any + + const client = new CodebuffClient({ + apiKey, + cwd: '/tmp/test-project-picker', + projectFiles, + agentDefinitions: [localFilePickerDef, localFileListerDef], + }) + + const events: PrintModeEvent[] = [] + + // Run the file-picker agent which spawns file-lister as a subagent + const run = await client.run({ + agent: localFilePickerDef.id, + prompt: 'Find files related to user authentication', + handleEvent: (event) => { + events.push(event) + }, + }) + + // Check for errors in the output + if (run.output.type === 'error') { + console.error('File picker error:', run.output) + } + + console.log('File picker output type:', run.output.type) + console.log('File picker output:', JSON.stringify(run.output, null, 2)) + + // The output should not be an error + expect(run.output.type).not.toEqual('error') + + // Verify we got some output + expect(run.output).toBeDefined() + + // The file-picker should have found relevant files via its spawned file-lister + const outputStr = + typeof run.output === 'string' ? run.output : JSON.stringify(run.output) + + // Verify that the file-picker found some relevant files + const relevantFiles = ['user', 'auth', 'service'] + const foundRelevantFile = relevantFiles.some((file) => + outputStr.toLowerCase().includes(file.toLowerCase()), + ) + + expect(foundRelevantFile).toBe(true) + }, + { timeout: 90_000 }, + ) +}) diff --git a/.agents/base2/base2.ts b/.agents/base2/base2.ts index 38dca4054..e662f9d79 100644 --- a/.agents/base2/base2.ts +++ b/.agents/base2/base2.ts @@ -53,6 +53,7 @@ export function createBase2( 'str_replace', 'write_file', 'ask_user', + 'set_output', ), spawnableAgents: buildArray( 'file-picker', diff --git a/.agents/editor/best-of-n/editor-best-of-n.ts b/.agents/editor/best-of-n/editor-best-of-n.ts index a2dc9501c..ee2b6bddc 100644 --- a/.agents/editor/best-of-n/editor-best-of-n.ts +++ b/.agents/editor/best-of-n/editor-best-of-n.ts @@ -229,7 +229,9 @@ function* handleStepsDefault({ } } function* handleStepsMax({ + agentState, params, + logger, }: AgentStepContext): ReturnType< NonNullable > { @@ -254,6 +256,27 @@ function* handleStepsMax({ 'editor-implementor-opus', ] as const + // Only keep messages up to just before the last user role message (skips input prompt, instrucitons prompt). + const { messageHistory: initialMessageHistory } = agentState + let userMessageIndex = initialMessageHistory.length + + while (userMessageIndex > 0) { + const message = initialMessageHistory[userMessageIndex - 1] + if (message.role === 'user') { + userMessageIndex-- + } else { + break + } + } + const updatedMessageHistory = initialMessageHistory.slice(0, userMessageIndex) + yield { + toolName: 'set_messages', + input: { + messages: updatedMessageHistory, + }, + includeToolCall: false, + } satisfies ToolCall<'set_messages'> + // Spawn implementor agents using the model pattern const implementorAgents = MAX_MODEL_PATTERN.slice(0, n).map((agent_type) => ({ agent_type, @@ -269,8 +292,9 @@ function* handleStepsMax({ } satisfies ToolCall<'spawn_agents'> // Extract spawn results - const spawnedImplementations = - extractSpawnResults<{ text: string }[]>(implementorResults) + const spawnedImplementations = extractSpawnResults( + implementorResults, + ) as any[] // Extract all the plans from the structured outputs const letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' @@ -280,11 +304,16 @@ function* handleStepsMax({ content: 'errorMessage' in result ? `Error: ${result.errorMessage}` - : result[0].text, + : extractLastMessageText(result), })) + logger.info( + { spawnedImplementations, implementations }, + 'spawnedImplementations', + ) + // Spawn selector with implementations as params - const { toolResult: selectorResult } = yield { + const { toolResult: selectorResult, agentState: selectorAgentState } = yield { toolName: 'spawn_agents', input: { agents: [ @@ -298,8 +327,10 @@ function* handleStepsMax({ } satisfies ToolCall<'spawn_agents'> const selectorOutput = extractSpawnResults<{ - implementationId: string - reasoning: string + value: { + implementationId: string + reasoning: string + } }>(selectorResult)[0] if ('errorMessage' in selectorOutput) { @@ -309,7 +340,7 @@ function* handleStepsMax({ } satisfies ToolCall<'set_output'> return } - const { implementationId } = selectorOutput + const { implementationId } = selectorOutput.value const chosenImplementation = implementations.find( (implementation) => implementation.id === implementationId, ) @@ -321,68 +352,77 @@ function* handleStepsMax({ return } - // Apply the chosen implementation using STEP_TEXT (only tool calls, no commentary) - const toolCallsOnly = extractToolCallsOnly( - typeof chosenImplementation.content === 'string' - ? chosenImplementation.content - : '', - ) + const numMessagesBeforeStepText = selectorAgentState.messageHistory.length + const { agentState: postEditsAgentState } = yield { type: 'STEP_TEXT', - text: toolCallsOnly, + text: chosenImplementation.content, } as StepText const { messageHistory } = postEditsAgentState - const lastAssistantMessageIndex = messageHistory.findLastIndex( - (message) => message.role === 'assistant', - ) - const editToolResults = messageHistory - .slice(lastAssistantMessageIndex) - .filter((message) => message.role === 'tool') - .flatMap((message) => message.content) - .filter((output) => output.type === 'json') - .map((output) => output.value) - // Set output with the chosen implementation and reasoning + // Set output with the messages from running the step text of the chosen implementation yield { toolName: 'set_output', input: { - response: chosenImplementation.content, - toolResults: editToolResults, + messages: messageHistory.slice(numMessagesBeforeStepText), }, includeToolCall: false, } satisfies ToolCall<'set_output'> - function extractSpawnResults( - results: any[] | undefined, - ): (T | { errorMessage: string })[] { - if (!results) return [] - const spawnedResults = results - .filter((result) => result.type === 'json') - .map((result) => result.value) - .flat() as { - agentType: string - value: { value?: T; errorMessage?: string } - }[] - return spawnedResults.map( - (result) => - result.value.value ?? { - errorMessage: - result.value.errorMessage ?? 'Error extracting spawn results', - }, - ) + /** + * Extracts the array of subagent results from spawn_agents tool output. + * + * The spawn_agents tool result structure is: + * [{ type: 'json', value: [{ agentName, agentType, value: AgentOutput }] }] + * + * Returns an array of agent outputs, one per spawned agent. + */ + function extractSpawnResults(results: any[] | undefined): T[] { + if (!results || results.length === 0) return [] + + // Find the json result containing spawn results + const jsonResult = results.find((r) => r.type === 'json') + if (!jsonResult?.value) return [] + + // Get the spawned agent results array + const spawnedResults = Array.isArray(jsonResult.value) + ? jsonResult.value + : [jsonResult.value] + + // Extract the value (AgentOutput) from each result + return spawnedResults.map((result: any) => result?.value).filter(Boolean) } - // Extract only tool calls from text, removing any commentary - function extractToolCallsOnly(text: string): string { - const toolExtractionPattern = - /\n(.*?)\n<\/codebuff_tool_call>/gs - const matches: string[] = [] - - for (const match of text.matchAll(toolExtractionPattern)) { - matches.push(match[0]) // Include the full tool call with tags + /** + * Extracts the text content from a 'lastMessage' AgentOutput. + * + * For agents with outputMode: 'last_message', the output structure is: + * { type: 'lastMessage', value: [{ role: 'assistant', content: [{ type: 'text', text: '...' }] }] } + * + * Returns the text from the last assistant message, or null if not found. + */ + function extractLastMessageText(agentOutput: any): string | null { + if (!agentOutput) return null + + // Handle 'lastMessage' output mode - the value contains an array of messages + if ( + agentOutput.type === 'lastMessage' && + Array.isArray(agentOutput.value) + ) { + // Find the last assistant message with text content + for (let i = agentOutput.value.length - 1; i >= 0; i--) { + const message = agentOutput.value[i] + if (message.role === 'assistant' && Array.isArray(message.content)) { + // Find text content in the message + for (const part of message.content) { + if (part.type === 'text' && typeof part.text === 'string') { + return part.text + } + } + } + } } - - return matches.join('\n') + return null } } diff --git a/.agents/editor/best-of-n/editor-implementor.ts b/.agents/editor/best-of-n/editor-implementor.ts index f159df2ce..c27af72a2 100644 --- a/.agents/editor/best-of-n/editor-implementor.ts +++ b/.agents/editor/best-of-n/editor-implementor.ts @@ -37,7 +37,7 @@ export const createBestOfNImplementor = (options: { Your task is to write out ALL the code changes needed to complete the user's request in a single comprehensive response. -Important: You can not make any other tool calls besides editing files. You cannot read more files, write todos, or spawn agents. +Important: You can not make any other tool calls besides editing files. You cannot read more files, write todos, spawn agents, or set output. Do not call any of these tools! Write out what changes you would make using the tool call format below. Use this exact format for each file change: diff --git a/.agents/file-explorer/file-picker.ts b/.agents/file-explorer/file-picker.ts index 25f7b6008..048d904d3 100644 --- a/.agents/file-explorer/file-picker.ts +++ b/.agents/file-explorer/file-picker.ts @@ -64,17 +64,22 @@ Do not use any further tools or spawn any further agents. }, } satisfies ToolCall - const filesResult = - extractSpawnResults<{ text: string }[]>(fileListerResults)[0] - if (!Array.isArray(filesResult)) { + const spawnResults = extractSpawnResults(fileListerResults) + const firstResult = spawnResults[0] + const fileListText = extractLastMessageText(firstResult) + + if (!fileListText) { + const errorMessage = extractErrorMessage(firstResult) yield { type: 'STEP_TEXT', - text: filesResult.errorMessage, + text: errorMessage + ? `Error from file-lister: ${errorMessage}` + : 'Error: Could not extract file list from spawned agent', } satisfies StepText return } - const paths = filesResult[0].text.split('\n').filter(Boolean) + const paths = fileListText.split('\n').filter(Boolean) yield { toolName: 'read_files', @@ -85,24 +90,71 @@ Do not use any further tools or spawn any further agents. yield 'STEP' - function extractSpawnResults( - results: any[] | undefined, - ): (T | { errorMessage: string })[] { - if (!results) return [] - const spawnedResults = results - .filter((result) => result.type === 'json') - .map((result) => result.value) - .flat() as { - agentType: string - value: { value?: T; errorMessage?: string } - }[] - return spawnedResults.map( - (result) => - result.value.value ?? { - errorMessage: - result.value.errorMessage ?? 'Error extracting spawn results', - }, - ) + /** + * Extracts the array of subagent results from spawn_agents tool output. + * + * The spawn_agents tool result structure is: + * [{ type: 'json', value: [{ agentName, agentType, value: AgentOutput }] }] + * + * Returns an array of agent outputs, one per spawned agent. + */ + function extractSpawnResults(results: any[] | undefined): any[] { + if (!results || results.length === 0) return [] + + // Find the json result containing spawn results + const jsonResult = results.find((r) => r.type === 'json') + if (!jsonResult?.value) return [] + + // Get the spawned agent results array + const spawnedResults = Array.isArray(jsonResult.value) ? jsonResult.value : [jsonResult.value] + + // Extract the value (AgentOutput) from each result + return spawnedResults.map((result: any) => result?.value).filter(Boolean) + } + + /** + * Extracts the text content from a 'lastMessage' AgentOutput. + * + * For agents with outputMode: 'last_message', the output structure is: + * { type: 'lastMessage', value: [{ role: 'assistant', content: [{ type: 'text', text: '...' }] }] } + * + * Returns the text from the last assistant message, or null if not found. + */ + function extractLastMessageText(agentOutput: any): string | null { + if (!agentOutput) return null + + // Handle 'lastMessage' output mode - the value contains an array of messages + if (agentOutput.type === 'lastMessage' && Array.isArray(agentOutput.value)) { + // Find the last assistant message with text content + for (let i = agentOutput.value.length - 1; i >= 0; i--) { + const message = agentOutput.value[i] + if (message.role === 'assistant' && Array.isArray(message.content)) { + // Find text content in the message + for (const part of message.content) { + if (part.type === 'text' && typeof part.text === 'string') { + return part.text + } + } + } + } + } + + return null + } + + /** + * Extracts the error message from an AgentOutput if it's an error type. + * + * Returns the error message string, or null if not an error output. + */ + function extractErrorMessage(agentOutput: any): string | null { + if (!agentOutput) return null + + if (agentOutput.type === 'error') { + return agentOutput.message ?? agentOutput.value ?? null + } + + return null } }, } diff --git a/.agents/tsconfig.json b/.agents/tsconfig.json index 4387f3d66..dbb372c16 100644 --- a/.agents/tsconfig.json +++ b/.agents/tsconfig.json @@ -5,6 +5,7 @@ "skipLibCheck": true, "types": ["bun", "node"], "paths": { + "@codebuff/sdk": ["../sdk/src/index.ts"], "@codebuff/common/*": ["../common/src/*"] } }, diff --git a/backend/src/__tests__/cost-aggregation.integration.test.ts b/backend/src/__tests__/cost-aggregation.integration.test.ts index 3206f3e5d..5dd6a5cd8 100644 --- a/backend/src/__tests__/cost-aggregation.integration.test.ts +++ b/backend/src/__tests__/cost-aggregation.integration.test.ts @@ -4,6 +4,7 @@ import * as agentRegistry from '@codebuff/agent-runtime/templates/agent-registry import { TEST_USER_ID } from '@codebuff/common/old-constants' import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' import { getInitialSessionState } from '@codebuff/common/types/session-state' +import { generateCompactId } from '@codebuff/common/util/string' import { spyOn, beforeEach, @@ -22,6 +23,7 @@ import type { AgentRuntimeScopedDeps, } from '@codebuff/common/types/contracts/agent-runtime' import type { SendActionFn } from '@codebuff/common/types/contracts/client' +import type { StreamChunk } from '@codebuff/common/types/contracts/llm' import type { ParamsExcluding } from '@codebuff/common/types/function-params' import type { ProjectFileContext } from '@codebuff/common/util/file' import type { Mock } from 'bun:test' @@ -149,15 +151,30 @@ describe('Cost Aggregation Integration Tests', () => { if (callCount === 1) { // Main agent spawns a subagent yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write a simple hello world file"}]}\n', - } + type: 'tool-call', + toolName: 'spawn_agents', + toolCallId: generateCompactId('test-id-'), + input: { + agents: [ + { + agent_type: 'editor', + prompt: 'Write a simple hello world file', + }, + ], + }, + } satisfies StreamChunk } else { // Subagent writes a file yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "write_file", "path": "hello.txt", "instructions": "Create hello world file", "content": "Hello, World!"}\n', - } + type: 'tool-call', + toolName: 'write_file', + toolCallId: generateCompactId('test-id-'), + input: { + path: 'hello.txt', + instructions: 'Create hello world file', + content: 'Hello, World!', + }, + } satisfies StreamChunk } return 'mock-message-id' }, @@ -252,8 +269,8 @@ describe('Cost Aggregation Integration Tests', () => { // Verify the total cost includes both main agent and subagent costs const finalCreditsUsed = result.sessionState.mainAgentState.creditsUsed - // The actual cost is higher than expected due to multiple steps in agent execution - expect(finalCreditsUsed).toEqual(73) + // 10 for the first call, 7 for the subagent, 7*9 for the next 9 calls + expect(finalCreditsUsed).toEqual(80) // Verify the cost breakdown makes sense expect(finalCreditsUsed).toBeGreaterThan(0) @@ -307,21 +324,35 @@ describe('Cost Aggregation Integration Tests', () => { if (callCount === 1) { // Main agent spawns first-level subagent yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Create files"}]}\n', - } + type: 'tool-call', + toolName: 'spawn_agents', + toolCallId: generateCompactId('test-id-'), + input: { + agents: [{ agent_type: 'editor', prompt: 'Create files' }], + }, + } satisfies StreamChunk } else if (callCount === 2) { // First-level subagent spawns second-level subagent yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write specific file"}]}\n', - } + type: 'tool-call', + toolName: 'spawn_agents', + toolCallId: generateCompactId('test-id-'), + input: { + agents: [{ agent_type: 'editor', prompt: 'Write specific file' }], + }, + } satisfies StreamChunk } else { // Second-level subagent does actual work yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "write_file", "path": "nested.txt", "instructions": "Create nested file", "content": "Nested content"}\n', - } + type: 'tool-call', + toolName: 'write_file', + toolCallId: generateCompactId('test-id-'), + input: { + path: 'nested.txt', + instructions: 'Create nested file', + content: 'Nested content', + }, + } satisfies StreamChunk } return 'mock-message-id' @@ -348,8 +379,8 @@ describe('Cost Aggregation Integration Tests', () => { // Should aggregate costs from all levels: main + sub1 + sub2 const finalCreditsUsed = result.sessionState.mainAgentState.creditsUsed - // Multi-level agents should have higher costs than simple ones - expect(finalCreditsUsed).toEqual(50) + // 10 calls from base agent, 1 from first subagent, 1 from second subagent: 12 calls total + expect(finalCreditsUsed).toEqual(60) }) it('should maintain cost integrity when subagents fail', async () => { @@ -365,12 +396,19 @@ describe('Cost Aggregation Integration Tests', () => { if (callCount === 1) { // Main agent spawns subagent yield { - type: 'text' as const, - text: '\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "This will fail"}]}\n', - } + type: 'tool-call', + toolName: 'spawn_agents', + toolCallId: generateCompactId('test-id-'), + input: { + agents: [{ agent_type: 'editor', prompt: 'This will fail' }], + }, + } satisfies StreamChunk } else { // Subagent fails after incurring cost - yield { type: 'text' as const, text: 'Some response' } + yield { + type: 'text', + text: 'Some response', + } satisfies StreamChunk throw new Error('Subagent execution failed') } diff --git a/cli/src/hooks/use-send-message.ts b/cli/src/hooks/use-send-message.ts index 6bf3b46dc..3fe9e44fd 100644 --- a/cli/src/hooks/use-send-message.ts +++ b/cli/src/hooks/use-send-message.ts @@ -1470,6 +1470,7 @@ export const useSendMessage = ({ input, agentId, includeToolCall, + parentAgentId, } = event if (toolName === 'spawn_agents' && input?.agents) { @@ -1541,7 +1542,7 @@ export const useSendMessage = ({ } // If this tool call belongs to a subagent, add it to that agent's blocks - if (agentId) { + if (parentAgentId && agentId) { applyMessageUpdate((prev) => prev.map((msg) => { if (msg.id !== aiMessageId || !msg.blocks) { diff --git a/common/src/__tests__/agent-validation.test.ts b/common/src/__tests__/agent-validation.test.ts index d2995e962..dab2efa16 100644 --- a/common/src/__tests__/agent-validation.test.ts +++ b/common/src/__tests__/agent-validation.test.ts @@ -779,7 +779,10 @@ describe('Agent Validation', () => { } }) - test('should reject set_output tool without json output mode', () => { + // Note: The validation that rejected set_output without structured_output mode was + // intentionally disabled to allow parent agents to have set_output tool with 'last_message' + // outputMode while their subagents use 'structured_output' (preserves prompt caching). + test('should allow set_output tool without structured_output mode', () => { const { DynamicAgentTemplateSchema, } = require('../types/dynamic-agent-template') @@ -791,7 +794,7 @@ describe('Agent Validation', () => { spawnerPrompt: 'Testing', model: 'claude-3-5-sonnet-20241022', outputMode: 'last_message' as const, // Not structured_output - toolNames: ['end_turn', 'set_output'], // Has set_output + toolNames: ['end_turn', 'set_output'], // Has set_output - now allowed spawnableAgents: [], systemPrompt: 'Test', instructionsPrompt: 'Test', @@ -799,13 +802,7 @@ describe('Agent Validation', () => { } const result = DynamicAgentTemplateSchema.safeParse(agentConfig) - expect(result.success).toBe(false) - if (!result.success) { - const errorMessage = result.error.issues[0]?.message || '' - expect(errorMessage).toContain( - "'set_output' tool requires outputMode to be 'structured_output'", - ) - } + expect(result.success).toBe(true) }) test('should validate that handleSteps is a generator function', async () => { diff --git a/common/src/__tests__/dynamic-agent-template-schema.test.ts b/common/src/__tests__/dynamic-agent-template-schema.test.ts index 8eff0bf8c..7a71bfb52 100644 --- a/common/src/__tests__/dynamic-agent-template-schema.test.ts +++ b/common/src/__tests__/dynamic-agent-template-schema.test.ts @@ -282,45 +282,29 @@ describe('DynamicAgentDefinitionSchema', () => { expect(result.success).toBe(true) }) - it('should reject template with set_output tool but non-structured_output outputMode', () => { + // Note: The validation that rejected set_output without structured_output mode was + // intentionally disabled to allow parent agents to have set_output tool with 'last_message' + // outputMode while their subagents use 'structured_output' (preserves prompt caching). + it('should allow template with set_output tool and non-structured_output outputMode', () => { const template = { ...validBaseTemplate, outputMode: 'last_message' as const, - toolNames: ['end_turn', 'set_output'], // set_output without structured_output mode + toolNames: ['end_turn', 'set_output'], // set_output is now allowed with any outputMode } const result = DynamicAgentTemplateSchema.safeParse(template) - expect(result.success).toBe(false) - if (!result.success) { - const setOutputError = result.error.issues.find((issue) => - issue.message.includes( - "'set_output' tool requires outputMode to be 'structured_output'", - ), - ) - expect(setOutputError).toBeDefined() - expect(setOutputError?.message).toContain( - "'set_output' tool requires outputMode to be 'structured_output'", - ) - } + expect(result.success).toBe(true) }) - it('should reject template with set_output tool and all_messages outputMode', () => { + it('should allow template with set_output tool and all_messages outputMode', () => { const template = { ...validBaseTemplate, outputMode: 'all_messages' as const, - toolNames: ['end_turn', 'set_output'], // set_output without structured_output mode + toolNames: ['end_turn', 'set_output'], // set_output is now allowed with any outputMode } const result = DynamicAgentTemplateSchema.safeParse(template) - expect(result.success).toBe(false) - if (!result.success) { - const setOutputError = result.error.issues.find((issue) => - issue.message.includes( - "'set_output' tool requires outputMode to be 'structured_output'", - ), - ) - expect(setOutputError).toBeDefined() - } + expect(result.success).toBe(true) }) it('should reject template with non-empty spawnableAgents but missing spawn_agents tool', () => { diff --git a/common/src/tools/params/tool/add-message.ts b/common/src/tools/params/tool/add-message.ts index 2866cc2d3..ed680468d 100644 --- a/common/src/tools/params/tool/add-message.ts +++ b/common/src/tools/params/tool/add-message.ts @@ -1,6 +1,9 @@ import z from 'zod/v4' -import { $getToolCallString, emptyToolResultSchema } from '../utils' +import { + $getNativeToolCallExampleString, + textToolResultSchema, +} from '../utils' import type { $ToolParams } from '../../constants' @@ -16,7 +19,7 @@ const inputSchema = z ) const description = ` Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -32,5 +35,5 @@ export const addMessageParams = { endsAgentStep, description, inputSchema, - outputSchema: emptyToolResultSchema(), + outputSchema: textToolResultSchema(), } satisfies $ToolParams diff --git a/common/src/tools/params/tool/add-subgoal.ts b/common/src/tools/params/tool/add-subgoal.ts index ed592797b..0630e76de 100644 --- a/common/src/tools/params/tool/add-subgoal.ts +++ b/common/src/tools/params/tool/add-subgoal.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -32,7 +32,7 @@ const inputSchema = z ) const description = ` Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/ask-user.ts b/common/src/tools/params/tool/ask-user.ts index 8a228de46..a87e7d7fd 100644 --- a/common/src/tools/params/tool/ask-user.ts +++ b/common/src/tools/params/tool/ask-user.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -8,13 +8,19 @@ export const questionSchema = z.object({ question: z.string().describe('The question to ask the user'), header: z .string() - .max(12) + .max(18) .optional() - .describe('Short label (max 12 chars) displayed as a chip/tag. Example: "Auth method"'), + .describe( + // Tell the llm 12 chars so that if it goes over slightly, it will still be under the max. + 'Short label (max 12 chars) displayed as a chip/tag. Example: "Auth method"', + ), options: z .object({ label: z.string().describe('The display text for this option'), - description: z.string().optional().describe('Explanation shown when option is focused'), + description: z + .string() + .optional() + .describe('Explanation shown when option is focused'), }) .array() .refine((opts) => opts.length >= 2, { @@ -30,10 +36,22 @@ export const questionSchema = z.object({ ), validation: z .object({ - maxLength: z.number().optional().describe('Maximum length for "Other" text input'), - minLength: z.number().optional().describe('Minimum length for "Other" text input'), - pattern: z.string().optional().describe('Regex pattern for "Other" text input'), - patternError: z.string().optional().describe('Custom error message when pattern fails'), + maxLength: z + .number() + .optional() + .describe('Maximum length for "Other" text input'), + minLength: z + .number() + .optional() + .describe('Minimum length for "Other" text input'), + pattern: z + .string() + .optional() + .describe('Regex pattern for "Other" text input'), + patternError: z + .string() + .optional() + .describe('Custom error message when pattern fails'), }) .optional() .describe('Validation rules for "Other" text input'), @@ -67,14 +85,20 @@ const outputSchema = z.object({ .array(z.string()) .optional() .describe('Array of selected option texts (multi-select mode)'), - otherText: z.string().optional().describe('Custom text input (if user typed their own answer)'), + otherText: z + .string() + .optional() + .describe('Custom text input (if user typed their own answer)'), }), ) .optional() .describe( 'Array of user answers, one per question. Each answer has either selectedOption (single), selectedOptions (multi), or otherText.', ), - skipped: z.boolean().optional().describe('True if user skipped the questions'), + skipped: z + .boolean() + .optional() + .describe('True if user skipped the questions'), }) const description = ` @@ -87,7 +111,7 @@ The user can either: - Skip the questions to provide different instructions instead Single-select example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -96,9 +120,18 @@ ${$getToolCallString({ question: 'Which authentication method should we use?', header: 'Auth method', options: [ - { label: 'JWT tokens', description: 'Stateless tokens stored in localStorage' }, - { label: 'Session cookies', description: 'Server-side sessions with httpOnly cookies' }, - { label: 'OAuth2', description: 'Third-party authentication (Google, GitHub, etc.)' }, + { + label: 'JWT tokens', + description: 'Stateless tokens stored in localStorage', + }, + { + label: 'Session cookies', + description: 'Server-side sessions with httpOnly cookies', + }, + { + label: 'OAuth2', + description: 'Third-party authentication (Google, GitHub, etc.)', + }, ], }, ], @@ -107,7 +140,7 @@ ${$getToolCallString({ })} Multi-select example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/browser-logs.ts b/common/src/tools/params/tool/browser-logs.ts index acb4d51d9..742c2168c 100644 --- a/common/src/tools/params/tool/browser-logs.ts +++ b/common/src/tools/params/tool/browser-logs.ts @@ -1,7 +1,7 @@ import z from 'zod/v4' import { BrowserResponseSchema } from '../../../browser-actions' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -64,7 +64,7 @@ Navigate: - \`waitUntil\`: (required) One of 'load', 'domcontentloaded', 'networkidle0' Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/code-search.ts b/common/src/tools/params/tool/code-search.ts index 876ea2934..2f5d82791 100644 --- a/common/src/tools/params/tool/code-search.ts +++ b/common/src/tools/params/tool/code-search.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -85,37 +85,37 @@ RESULT LIMITING: - If the global limit is reached, remaining files will be skipped Examples: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { pattern: 'foo' }, endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { pattern: 'foo\\.bar = 1\\.0' }, endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { pattern: 'import.*foo', cwd: 'src' }, endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { pattern: 'function.*authenticate', flags: '-i -t ts -t js' }, endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { pattern: 'TODO', flags: '-n --type-not py' }, endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { pattern: 'getUserData', maxResults: 10 }, diff --git a/common/src/tools/params/tool/create-plan.ts b/common/src/tools/params/tool/create-plan.ts index 1aca1d6ce..56c027da2 100644 --- a/common/src/tools/params/tool/create-plan.ts +++ b/common/src/tools/params/tool/create-plan.ts @@ -1,7 +1,7 @@ import z from 'zod/v4' import { updateFileResultSchema } from './str-replace' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -52,7 +52,7 @@ After creating the plan, you should end turn to let the user review the plan. Important: Use this tool sparingly. Do not use this tool more than once in a conversation, unless in ask mode. Examples: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/end-turn.ts b/common/src/tools/params/tool/end-turn.ts index 16d21a672..be9fa94f3 100644 --- a/common/src/tools/params/tool/end-turn.ts +++ b/common/src/tools/params/tool/end-turn.ts @@ -1,6 +1,9 @@ import z from 'zod/v4' -import { $getToolCallString, emptyToolResultSchema } from '../utils' +import { + $getNativeToolCallExampleString, + textToolResultSchema, +} from '../utils' import type { $ToolParams } from '../../constants' @@ -20,14 +23,14 @@ Only use this tool to hand control back to the user. - Effect: Signals the UI to wait for the user's reply; any pending tool results will be ignored. *INCORRECT USAGE*: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName: 'some_tool_that_produces_results', inputSchema: null, input: { query: 'some example search term' }, endsAgentStep: false, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: {}, @@ -37,7 +40,7 @@ ${$getToolCallString({ *CORRECT USAGE*: All done! Would you like some more help with xyz? -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: {}, @@ -50,5 +53,5 @@ export const endTurnParams = { endsAgentStep, description, inputSchema, - outputSchema: emptyToolResultSchema(), + outputSchema: textToolResultSchema(), } satisfies $ToolParams diff --git a/common/src/tools/params/tool/find-files.ts b/common/src/tools/params/tool/find-files.ts index 4b46e15ec..3a931b342 100644 --- a/common/src/tools/params/tool/find-files.ts +++ b/common/src/tools/params/tool/find-files.ts @@ -1,7 +1,7 @@ import z from 'zod/v4' import { fileContentsSchema } from './read-files' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -21,7 +21,7 @@ const inputSchema = z ) const description = ` Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/glob.ts b/common/src/tools/params/tool/glob.ts index e98dc6798..b944dd73e 100644 --- a/common/src/tools/params/tool/glob.ts +++ b/common/src/tools/params/tool/glob.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -26,7 +26,7 @@ const inputSchema = z ) const description = ` Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/list-directory.ts b/common/src/tools/params/tool/list-directory.ts index 403179981..d70590f37 100644 --- a/common/src/tools/params/tool/list-directory.ts +++ b/common/src/tools/params/tool/list-directory.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -19,7 +19,7 @@ const description = ` Lists all files and directories in the specified path. Useful for exploring directory structure and finding files. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -28,7 +28,7 @@ ${$getToolCallString({ endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/lookup-agent-info.ts b/common/src/tools/params/tool/lookup-agent-info.ts index 4f1ee5cc5..029668ec4 100644 --- a/common/src/tools/params/tool/lookup-agent-info.ts +++ b/common/src/tools/params/tool/lookup-agent-info.ts @@ -1,7 +1,7 @@ import z from 'zod/v4' import { jsonValueSchema } from '../../../types/json' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -18,7 +18,7 @@ const description = ` Retrieve information about an agent by ID for proper spawning. Use this when you see a request with a full agent ID like "@publisher/agent-id@version" to validate the agent exists and get its metadata. Only agents that are published under a publisher and version are supported for this tool. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/read-docs.ts b/common/src/tools/params/tool/read-docs.ts index 235c3faee..25e5ee06b 100644 --- a/common/src/tools/params/tool/read-docs.ts +++ b/common/src/tools/params/tool/read-docs.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -50,7 +50,7 @@ Use cases: The tool will search for the library and return the most relevant documentation content. If a topic is specified, it will focus the results on that specific area. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -61,7 +61,7 @@ ${$getToolCallString({ endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/read-files.ts b/common/src/tools/params/tool/read-files.ts index 2c1877720..3f757aa9b 100644 --- a/common/src/tools/params/tool/read-files.ts +++ b/common/src/tools/params/tool/read-files.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -36,7 +36,7 @@ const inputSchema = z ) const description = ` Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/read-subtree.ts b/common/src/tools/params/tool/read-subtree.ts index 3156d8ca7..09f0c1f58 100644 --- a/common/src/tools/params/tool/read-subtree.ts +++ b/common/src/tools/params/tool/read-subtree.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -28,7 +28,7 @@ const inputSchema = z ) const description = ` Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/run-file-change-hooks.ts b/common/src/tools/params/tool/run-file-change-hooks.ts index 1b1379982..e69c211d6 100644 --- a/common/src/tools/params/tool/run-file-change-hooks.ts +++ b/common/src/tools/params/tool/run-file-change-hooks.ts @@ -1,7 +1,7 @@ import z from 'zod/v4' import { terminalCommandOutputSchema } from './run-terminal-command' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -25,7 +25,7 @@ Use cases: The client will run only the hooks whose filePattern matches the provided files. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/run-terminal-command.ts b/common/src/tools/params/tool/run-terminal-command.ts index c89e16e57..4bd53f0c2 100644 --- a/common/src/tools/params/tool/run-terminal-command.ts +++ b/common/src/tools/params/tool/run-terminal-command.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -156,7 +156,7 @@ Notes: ${gitCommitGuidePrompt} Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -165,7 +165,7 @@ ${$getToolCallString({ endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/set-messages.ts b/common/src/tools/params/tool/set-messages.ts index bb062cadf..0bab2f04f 100644 --- a/common/src/tools/params/tool/set-messages.ts +++ b/common/src/tools/params/tool/set-messages.ts @@ -1,6 +1,9 @@ import z from 'zod/v4' -import { $getToolCallString, emptyToolResultSchema } from '../utils' +import { + $getNativeToolCallExampleString, + textToolResultSchema, +} from '../utils' import type { $ToolParams } from '../../constants' @@ -13,7 +16,7 @@ const inputSchema = z .describe(`Set the conversation history to the provided messages.`) const description = ` Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -37,5 +40,5 @@ export const setMessagesParams = { endsAgentStep, description, inputSchema, - outputSchema: emptyToolResultSchema(), + outputSchema: textToolResultSchema(), } satisfies $ToolParams diff --git a/common/src/tools/params/tool/set-output.ts b/common/src/tools/params/tool/set-output.ts index f86c94f80..d9a69ea5d 100644 --- a/common/src/tools/params/tool/set-output.ts +++ b/common/src/tools/params/tool/set-output.ts @@ -1,22 +1,26 @@ import z from 'zod/v4' -import { $getToolCallString } from '../utils' +import { $getNativeToolCallExampleString } from '../utils' import type { $ToolParams } from '../../constants' const toolName = 'set_output' const endsAgentStep = false const inputSchema = z - .looseObject({}) + .looseObject({ + data: z.record(z.string(), z.any()).optional(), + }) .describe( - 'JSON object to set as the agent output. This completely replaces any previous output. If the agent was spawned, this value will be passed back to its parent. If the agent has an outputSchema defined, the output will be validated against it.', + 'JSON object to set as the agent output. The shape of the parameters are specified dynamically further down in the conversation. This completely replaces any previous output. If the agent was spawned, this value will be passed back to its parent. If the agent has an outputSchema defined, the output will be validated against it.', ) const description = ` -You must use this tool as it is the only way to report any findings to the user. Nothing else you write will be shown to the user. +Subagents must use this tool as it is the only way to report any findings. Nothing else you write will be visible to the user/parent agent. -Please set the output with all the information and analysis you want to pass on to the user. If you just want to send a simple message, use an object with the key "message" and value of the message you want to send. +Note that the output schema is provided dynamically in a user prompt further down in the conversation. Be sure to follow what the latest output schema is when using this tool. + +Please set the output with all the information and analysis you want to pass on. If you just want to send a simple message, use an object with the key "message" and value of the message you want to send. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/spawn-agent-inline.ts b/common/src/tools/params/tool/spawn-agent-inline.ts index 6ee9a9d44..60e234594 100644 --- a/common/src/tools/params/tool/spawn-agent-inline.ts +++ b/common/src/tools/params/tool/spawn-agent-inline.ts @@ -1,6 +1,9 @@ import z from 'zod/v4' -import { $getToolCallString, emptyToolResultSchema } from '../utils' +import { + $getNativeToolCallExampleString, + textToolResultSchema, +} from '../utils' import type { $ToolParams } from '../../constants' @@ -31,7 +34,7 @@ This is useful for: - Managing message history (e.g., summarization) The agent will run until it calls end_turn, then control returns to you. There is no tool result for this tool. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -48,5 +51,5 @@ export const spawnAgentInlineParams = { endsAgentStep, description, inputSchema, - outputSchema: emptyToolResultSchema(), + outputSchema: textToolResultSchema(), } satisfies $ToolParams diff --git a/common/src/tools/params/tool/spawn-agents.ts b/common/src/tools/params/tool/spawn-agents.ts index f7da5e5d7..e7cb0ff54 100644 --- a/common/src/tools/params/tool/spawn-agents.ts +++ b/common/src/tools/params/tool/spawn-agents.ts @@ -1,7 +1,7 @@ import z from 'zod/v4' import { jsonObjectSchema } from '../../../types/json' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -31,18 +31,22 @@ const inputSchema = z `Spawn multiple agents and send a prompt and/or parameters to each of them. These agents will run in parallel. Note that that means they will run independently. If you need to run agents sequentially, use spawn_agents with one agent at a time instead.`, ) const description = ` -Use this tool to spawn agents to help you complete the user request. Each agent has specific requirements for prompt and params based on their inputSchema. +Use this tool to spawn agents to help you complete the user request. Each agent has specific requirements for prompt and params based on their tools schema. The prompt field is a simple string, while params is a JSON object that gets validated against the agent's schema. +Each agent available is already defined as another tool, or, dynamically defined later in the conversation. + +You can call agents either as direct tool calls (e.g., \`example-agent\`) or use \`spawn_agents\`. Both formats work, but **prefer using spawn_agents** because it allows you to spawn multiple agents in parallel for better performance. When using direct tool calls, the schema is flat (prompt is a field alongside other params), whereas spawn_agents uses nested \`prompt\` and \`params\` fields. + Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { agents: [ { - agent_type: 'planner', + agent_type: 'example-agent', prompt: 'Create a plan for implementing user authentication', params: { filePaths: ['src/auth.ts', 'src/user.ts'] }, }, diff --git a/common/src/tools/params/tool/str-replace.ts b/common/src/tools/params/tool/str-replace.ts index 5aee745fe..b02ce1e81 100644 --- a/common/src/tools/params/tool/str-replace.ts +++ b/common/src/tools/params/tool/str-replace.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -61,7 +61,7 @@ Important: If you are making multiple edits in a row to a file, use only one str_replace call with multiple replacements instead of multiple str_replace tool calls. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/task-completed.ts b/common/src/tools/params/tool/task-completed.ts index a8c35d1c6..3d9cdc5a0 100644 --- a/common/src/tools/params/tool/task-completed.ts +++ b/common/src/tools/params/tool/task-completed.ts @@ -1,6 +1,9 @@ import z from 'zod/v4' -import { $getToolCallString, emptyToolResultSchema } from '../utils' +import { + $getNativeToolCallExampleString, + textToolResultSchema, +} from '../utils' import type { $ToolParams } from '../../constants' @@ -34,19 +37,19 @@ Use this tool to signal that the task is complete. All changes have been implemented and tested successfully! -${$getToolCallString({ toolName, inputSchema, input: {}, endsAgentStep })} +${$getNativeToolCallExampleString({ toolName, inputSchema, input: {}, endsAgentStep })} OR I need more information to proceed. Which database schema should I use for this migration? -${$getToolCallString({ toolName, inputSchema, input: {}, endsAgentStep })} +${$getNativeToolCallExampleString({ toolName, inputSchema, input: {}, endsAgentStep })} OR I can't get the tests to pass after several different attempts. I need help from the user to proceed. -${$getToolCallString({ toolName, inputSchema, input: {}, endsAgentStep })} +${$getNativeToolCallExampleString({ toolName, inputSchema, input: {}, endsAgentStep })} `.trim() export const taskCompletedParams = { @@ -54,5 +57,5 @@ export const taskCompletedParams = { endsAgentStep, description, inputSchema, - outputSchema: emptyToolResultSchema(), + outputSchema: textToolResultSchema(), } satisfies $ToolParams diff --git a/common/src/tools/params/tool/think-deeply.ts b/common/src/tools/params/tool/think-deeply.ts index 4292332fa..ec387454b 100644 --- a/common/src/tools/params/tool/think-deeply.ts +++ b/common/src/tools/params/tool/think-deeply.ts @@ -1,6 +1,9 @@ import z from 'zod/v4' -import { $getToolCallString, emptyToolResultSchema } from '../utils' +import { + $getNativeToolCallExampleString, + textToolResultSchema, +} from '../utils' import type { $ToolParams } from '../../constants' @@ -29,7 +32,7 @@ Avoid for simple changes (e.g., single functions, minor edits). This tool does not generate a tool result. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -49,5 +52,5 @@ export const thinkDeeplyParams = { endsAgentStep, description, inputSchema, - outputSchema: emptyToolResultSchema(), + outputSchema: textToolResultSchema(), } satisfies $ToolParams diff --git a/common/src/tools/params/tool/update-subgoal.ts b/common/src/tools/params/tool/update-subgoal.ts index 299ca9eea..75e778c63 100644 --- a/common/src/tools/params/tool/update-subgoal.ts +++ b/common/src/tools/params/tool/update-subgoal.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -31,7 +31,7 @@ const description = ` Examples: Usage 1 (update status): -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -42,7 +42,7 @@ ${$getToolCallString({ })} Usage 2 (update plan): -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -53,7 +53,7 @@ ${$getToolCallString({ })} Usage 3 (add log): -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -64,7 +64,7 @@ ${$getToolCallString({ })} Usage 4 (update status and add log): -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/web-search.ts b/common/src/tools/params/tool/web-search.ts index 7a458cc01..e87c8f271 100644 --- a/common/src/tools/params/tool/web-search.ts +++ b/common/src/tools/params/tool/web-search.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -34,7 +34,7 @@ Use cases: The tool will return search results with titles, URLs, and content snippets. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -44,7 +44,7 @@ ${$getToolCallString({ endsAgentStep, })} -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/write-file.ts b/common/src/tools/params/tool/write-file.ts index 00ec71c6d..cf50fee05 100644 --- a/common/src/tools/params/tool/write-file.ts +++ b/common/src/tools/params/tool/write-file.ts @@ -1,7 +1,7 @@ import z from 'zod/v4' import { updateFileResultSchema } from './str-replace' -import { $getToolCallString, jsonToolResultSchema } from '../utils' +import { $getNativeToolCallExampleString, jsonToolResultSchema } from '../utils' import type { $ToolParams } from '../../constants' @@ -39,7 +39,7 @@ Do not use this tool to delete or rename a file. Instead run a terminal command Examples: Example 1 - Simple file creation: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { @@ -51,7 +51,7 @@ ${$getToolCallString({ })} Example 2 - Editing with placeholder comments: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/tool/write-todos.ts b/common/src/tools/params/tool/write-todos.ts index 7b7489a6f..ae73e72a1 100644 --- a/common/src/tools/params/tool/write-todos.ts +++ b/common/src/tools/params/tool/write-todos.ts @@ -1,6 +1,6 @@ import z from 'zod/v4' -import { $getToolCallString } from '../utils' +import { $getNativeToolCallExampleString } from '../utils' import type { $ToolParams } from '../../constants' @@ -30,7 +30,7 @@ After completing each todo step, call this tool again to update the list and mar Use this tool frequently as you work through tasks to update the list of todos with their current status. Doing this is extremely useful because it helps you stay on track and complete all the requirements of the user's request. It also helps inform the user of your plans and the current progress, which they want to know at all times. Example: -${$getToolCallString({ +${$getNativeToolCallExampleString({ toolName, inputSchema, input: { diff --git a/common/src/tools/params/utils.ts b/common/src/tools/params/utils.ts index cbf79d327..1c27d0097 100644 --- a/common/src/tools/params/utils.ts +++ b/common/src/tools/params/utils.ts @@ -34,6 +34,20 @@ export function $getToolCallString(params: { return [startToolTag, JSON.stringify(obj, null, 2), endToolTag].join('') } +export function $getNativeToolCallExampleString(params: { + toolName: string + inputSchema: z.ZodType | null + input: Input + endsAgentStep?: boolean // unused +}): string { + const { toolName, input } = params + return [ + `<${toolName}_params_example>\n`, + JSON.stringify(input, null, 2), + `\n`, + ].join('') +} + /** Generates the zod schema for a single JSON tool result. */ export function jsonToolResultSchema( valueSchema: z.ZodType, @@ -50,3 +64,15 @@ export function jsonToolResultSchema( export function emptyToolResultSchema() { return z.tuple([]) } + +/** Generates the zod schema for a simple text tool result. */ +export function textToolResultSchema() { + return z.tuple([ + z.object({ + type: z.literal('json'), + value: z.object({ + message: z.string(), + }), + }) satisfies z.ZodType, + ]) +} diff --git a/common/src/types/agent-template.ts b/common/src/types/agent-template.ts index 77989fc6d..9cd57c24d 100644 --- a/common/src/types/agent-template.ts +++ b/common/src/types/agent-template.ts @@ -5,6 +5,8 @@ * It imports base types from the user-facing template to eliminate duplication. */ +import { z } from 'zod/v4' + import type { MCPConfig } from './mcp' import type { Model } from '../old-constants' import type { ToolResultOutput } from './messages/content-part' @@ -15,7 +17,6 @@ import type { } from '../templates/initial-agents-dir/types/agent-definition' import type { Logger } from '../templates/initial-agents-dir/types/util-types' import type { ToolName } from '../tools/constants' -import type { z } from 'zod/v4' export type AgentId = `${string}/${string}@${number}.${number}.${number}` @@ -141,6 +142,33 @@ export type AgentTemplate< export type StepText = { type: 'STEP_TEXT'; text: string } export type GenerateN = { type: 'GENERATE_N'; n: number } +// Zod schemas for handleSteps yield values +export const StepTextSchema = z.object({ + type: z.literal('STEP_TEXT'), + text: z.string(), +}) + +export const GenerateNSchema = z.object({ + type: z.literal('GENERATE_N'), + n: z.number().int().positive(), +}) + +export const HandleStepsToolCallSchema = z.object({ + toolName: z.string().min(1), + input: z.record(z.string(), z.any()), + includeToolCall: z.boolean().optional(), +}) + +export const HandleStepsYieldValueSchema = z.union([ + z.literal('STEP'), + z.literal('STEP_ALL'), + StepTextSchema, + GenerateNSchema, + HandleStepsToolCallSchema, +]) + +export type HandleStepsYieldValue = z.infer + export type StepGenerator = Generator< Omit | 'STEP' | 'STEP_ALL' | StepText | GenerateN, // Generic tool call type void, diff --git a/common/src/types/contracts/llm.ts b/common/src/types/contracts/llm.ts index 23fc5ede7..58ff0ff15 100644 --- a/common/src/types/contracts/llm.ts +++ b/common/src/types/contracts/llm.ts @@ -1,12 +1,13 @@ import type { TrackEventFn } from './analytics' import type { SendActionFn } from './client' import type { CheckLiveUserInputFn } from './live-user-input' +import type { OpenRouterProviderRoutingOptions } from '../agent-template' import type { ParamsExcluding } from '../function-params' import type { Logger } from './logger' import type { Model } from '../../old-constants' import type { Message } from '../messages/codebuff-message' -import type { OpenRouterProviderRoutingOptions } from '../agent-template' -import type { generateText, streamText } from 'ai' +import type { AgentTemplate } from '../agent-template' +import type { generateText, streamText, ToolCallPart } from 'ai' import type z from 'zod/v4' export type StreamChunk = @@ -19,6 +20,10 @@ export type StreamChunk = type: 'reasoning' text: string } + | Pick< + ToolCallPart, + 'type' | 'toolCallId' | 'toolName' | 'input' | 'providerOptions' + > | { type: 'error'; message: string } export type PromptAiSdkStreamFn = ( @@ -38,6 +43,10 @@ export type PromptAiSdkStreamFn = ( onCostCalculated?: (credits: number) => Promise includeCacheControl?: boolean agentProviderOptions?: OpenRouterProviderRoutingOptions + /** List of agents that can be spawned - used to transform agent tool calls */ + spawnableAgents?: string[] + /** Map of locally available agent templates - used to transform agent tool calls */ + localAgentTemplates?: Record sendAction: SendActionFn logger: Logger trackEvent: TrackEventFn diff --git a/common/src/types/dynamic-agent-template.ts b/common/src/types/dynamic-agent-template.ts index d838b4be1..e8a22d96c 100644 --- a/common/src/types/dynamic-agent-template.ts +++ b/common/src/types/dynamic-agent-template.ts @@ -244,23 +244,24 @@ export const DynamicAgentTemplateSchema = DynamicAgentDefinitionSchema.extend({ path: ['toolNames'], }, ) - .refine( - (data) => { - // If 'set_output' tool is included, outputMode must be 'structured_output' - if ( - data.toolNames.includes('set_output') && - data.outputMode !== 'structured_output' - ) { - return false - } - return true - }, - { - message: - "'set_output' tool requires outputMode to be 'structured_output'. Change outputMode to 'structured_output' or remove 'set_output' from toolNames.", - path: ['outputMode'], - }, - ) + // Note(James): Disabled so that a parent agent can have set_output tool and 'last_message' outputMode while its subagents use 'structured_output'. (The set_output tool must be included in parent to preserver prompt caching.) + // .refine( + // (data) => { + // // If 'set_output' tool is included, outputMode must be 'structured_output' + // if ( + // data.toolNames.includes('set_output') && + // data.outputMode !== 'structured_output' + // ) { + // return false + // } + // return true + // }, + // { + // message: + // "'set_output' tool requires outputMode to be 'structured_output'. Change outputMode to 'structured_output' or remove 'set_output' from toolNames.", + // path: ['outputMode'], + // }, + // ) .refine( (data) => { // If spawnableAgents array is non-empty, 'spawn_agents' tool must be included diff --git a/common/src/types/session-state.ts b/common/src/types/session-state.ts index 8a3abaec2..7fc5907a4 100644 --- a/common/src/types/session-state.ts +++ b/common/src/types/session-state.ts @@ -48,7 +48,7 @@ export const AgentOutputSchema = z.discriminatedUnion('type', [ }), z.object({ type: z.literal('lastMessage'), - value: z.any(), + value: z.array(z.any()), // Array of assistant and tool messages from the last turn, including tool results }), z.object({ type: z.literal('allMessages'), diff --git a/common/src/util/__tests__/messages.test.ts b/common/src/util/__tests__/messages.test.ts index 53e1cb722..72658d1a0 100644 --- a/common/src/util/__tests__/messages.test.ts +++ b/common/src/util/__tests__/messages.test.ts @@ -13,6 +13,7 @@ import { } from '../messages' import type { Message } from '../../types/messages/codebuff-message' +import type { AssistantModelMessage, ToolResultPart } from 'ai' describe('withCacheControl', () => { it('should add cache control to object without providerOptions', () => { @@ -189,12 +190,6 @@ describe('convertCbToModelMessages', () => { describe('tool message conversion', () => { it('should convert tool messages with JSON output', () => { - const toolResult = [ - { - type: 'json', - value: { result: 'success' }, - }, - ] const messages: Message[] = [ { role: 'tool', @@ -211,15 +206,17 @@ describe('convertCbToModelMessages', () => { expect(result).toEqual([ expect.objectContaining({ - role: 'user', + role: 'tool', content: [ expect.objectContaining({ - type: 'text', - }), + type: 'tool-result', + toolCallId: 'call_123', + toolName: 'test_tool', + output: { type: 'json', value: { result: 'success' } }, + } satisfies ToolResultPart), ], }), ]) - expect((result as any)[0].content[0].text).toContain('') }) it('should convert tool messages with media output', () => { @@ -270,14 +267,15 @@ describe('convertCbToModelMessages', () => { includeCacheControl: false, }) - console.dir({ result }, { depth: null }) // Multiple tool outputs are aggregated into one user message expect(result).toEqual([ expect.objectContaining({ - role: 'user', + role: 'tool', + }), + expect.objectContaining({ + role: 'tool', }), ]) - expect(result[0].content).toHaveLength(2) }) }) @@ -806,14 +804,19 @@ describe('convertCbToModelMessages', () => { includeCacheControl: false, }) - expect(result).toHaveLength(1) - expect(result[0].role).toBe('assistant') - if (typeof result[0].content !== 'string') { - expect(result[0].content[0].type).toBe('text') - if (result[0].content[0].type === 'text') { - expect(result[0].content[0].text).toContain('test_tool') - } - } + expect(result).toEqual([ + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'call_123', + toolName: 'test_tool', + input: { param: 'value' }, + }, + ], + } satisfies AssistantModelMessage, + ]) }) it('should preserve message metadata during conversion', () => { diff --git a/common/src/util/messages.ts b/common/src/util/messages.ts index b8387a96f..478a2d4d2 100644 --- a/common/src/util/messages.ts +++ b/common/src/util/messages.ts @@ -1,7 +1,5 @@ import { cloneDeep, has, isEqual } from 'lodash' -import { getToolCallString } from '../tools/utils' - import type { JSONValue } from '../types/json' import type { AssistantMessage, @@ -100,21 +98,21 @@ function assistantToCodebuffMessage( content: Exclude[number] }, ): AssistantMessage { - if (message.content.type === 'tool-call') { - return cloneDeep({ - ...message, - content: [ - { - type: 'text', - text: getToolCallString( - message.content.toolName, - message.content.input, - false, - ), - }, - ], - }) - } + // if (message.content.type === 'tool-call') { + // return cloneDeep({ + // ...message, + // content: [ + // { + // type: 'text', + // text: getToolCallString( + // message.content.toolName, + // message.content.input, + // false, + // ), + // }, + // ], + // }) + // } return cloneDeep({ ...message, content: [message.content] }) } @@ -123,20 +121,10 @@ function convertToolResultMessage( ): ModelMessageWithAuxiliaryData[] { return message.content.map((c) => { if (c.type === 'json') { - const toolResult = { - toolName: message.toolName, - toolCallId: message.toolCallId, - output: c.value, - } - return cloneDeep({ + return cloneDeep({ ...message, - role: 'user', - content: [ - { - type: 'text', - text: `\n${JSON.stringify(toolResult, null, 2)}\n`, - }, - ], + role: 'tool', + content: [{ ...message, output: c, type: 'tool-result' }], }) } if (c.type === 'media') { @@ -147,8 +135,8 @@ function convertToolResultMessage( }) } c satisfies never - const oAny = c as any - throw new Error(`Invalid tool output type: ${oAny.type}`) + const cAny = c as any + throw new Error(`Invalid tool output type: ${cAny.type}`) }) } diff --git a/evals/buffbench/run-buffbench.ts b/evals/buffbench/run-buffbench.ts index 9d93d45ed..8acf9b70e 100644 --- a/evals/buffbench/run-buffbench.ts +++ b/evals/buffbench/run-buffbench.ts @@ -37,6 +37,7 @@ async function runTask(options: { extractLessons: boolean printEvents: boolean finalCheckCommands?: string[] + disableAnalysis?: boolean }) { const { client, @@ -53,6 +54,7 @@ async function runTask(options: { extractLessons, printEvents, finalCheckCommands, + disableAnalysis, } = options console.log( @@ -161,12 +163,14 @@ async function runTask(options: { const agentResults = await Promise.all(agentPromises) // After all agents complete for this commit, run trace analysis - const traceAnalysis = await analyzeAgentTraces({ - client, - traces: commitTraces, - codingAgentPrompt: commit.prompt, - analyzerContext, - }) + const traceAnalysis = disableAnalysis + ? undefined + : await analyzeAgentTraces({ + client, + traces: commitTraces, + codingAgentPrompt: commit.prompt, + analyzerContext, + }) const analysisData = { commitSha: commit.sha, @@ -268,6 +272,7 @@ export async function runBuffBench(options: { client?: CodebuffClient taskIds?: string[] extractLessons?: boolean + disableAnalysis?: boolean }) { const { evalDataPath, @@ -275,6 +280,7 @@ export async function runBuffBench(options: { taskConcurrency = 1, taskIds, extractLessons = false, + disableAnalysis = false, } = options const evalData: EvalDataV2 = JSON.parse( @@ -384,6 +390,7 @@ export async function runBuffBench(options: { extractLessons, printEvents: agents.length === 1 && taskConcurrency === 1, finalCheckCommands: evalData.finalCheckCommands, + disableAnalysis, }), ), ) @@ -448,36 +455,40 @@ export async function runBuffBench(options: { const logFiles = fs.readdirSync(logsDir) - const metaAnalysis = await analyzeAllTasks({ - client, - logsDir, - agents, - analyzerContext, - }) + const metaAnalysis = disableAnalysis + ? undefined + : await analyzeAllTasks({ + client, + logsDir, + agents, + analyzerContext, + }) - // Print meta-analysis results - console.log('\n=== Meta-Analysis Results ===') - console.log('\nOverall Comparison:') - console.log(metaAnalysis.overallComparison) - - if (metaAnalysis.agentInsights.length > 0) { - console.log('\nAgent-Specific Insights:') - for (const insight of metaAnalysis.agentInsights) { - console.log(`\n[${insight.agentId}]`) - if (insight.consistentStrengths.length > 0) { - console.log(' Strengths:', insight.consistentStrengths.join(', ')) - } - if (insight.consistentWeaknesses.length > 0) { - console.log(' Weaknesses:', insight.consistentWeaknesses.join(', ')) + if (metaAnalysis) { + // Print meta-analysis results + console.log('\n=== Meta-Analysis Results ===') + console.log('\nOverall Comparison:') + console.log(metaAnalysis.overallComparison) + + if (metaAnalysis.agentInsights.length > 0) { + console.log('\nAgent-Specific Insights:') + for (const insight of metaAnalysis.agentInsights) { + console.log(`\n[${insight.agentId}]`) + if (insight.consistentStrengths.length > 0) { + console.log(' Strengths:', insight.consistentStrengths.join(', ')) + } + if (insight.consistentWeaknesses.length > 0) { + console.log(' Weaknesses:', insight.consistentWeaknesses.join(', ')) + } } } - } - if (metaAnalysis.keyFindings.length > 0) { - console.log('\nKey Findings:') - metaAnalysis.keyFindings.forEach((finding, i) => { - console.log(` ${i + 1}. ${finding}`) - }) + if (metaAnalysis.keyFindings.length > 0) { + console.log('\nKey Findings:') + metaAnalysis.keyFindings.forEach((finding, i) => { + console.log(` ${i + 1}. ${finding}`) + }) + } } const finalResults = { diff --git a/evals/scaffolding.ts b/evals/scaffolding.ts index 6250a2f0b..a86b7b4e3 100644 --- a/evals/scaffolding.ts +++ b/evals/scaffolding.ts @@ -206,13 +206,15 @@ export async function runAgentStepScaffolding( const result = await runAgentStep({ ...EVALS_AGENT_RUNTIME_IMPL, ...agentRuntimeScopedImpl, + additionalToolDefinitions: () => Promise.resolve({}), - textOverride: null, - runId: 'test-run-id', - userId: TEST_USER_ID, - userInputId: generateCompactId(), + agentState, + agentType, + ancestorRunIds: [], clientSessionId: sessionId, + fileContext, fingerprintId: 'test-fingerprint-id', + localAgentTemplates, onResponseChunk: (chunk: string | PrintModeEvent) => { if (typeof chunk !== 'string') { return @@ -222,17 +224,16 @@ export async function runAgentStepScaffolding( } fullResponse += chunk }, - agentType, - fileContext, - localAgentTemplates, - agentState, prompt, - ancestorRunIds: [], - spawnParams: undefined, - repoUrl: undefined, repoId: undefined, - system: 'Test system prompt', + repoUrl: undefined, + runId: 'test-run-id', signal: new AbortController().signal, + spawnParams: undefined, + system: 'Test system prompt', + tools: {}, + userId: TEST_USER_ID, + userInputId: generateCompactId(), }) return { diff --git a/packages/agent-runtime/src/__tests__/cost-aggregation.test.ts b/packages/agent-runtime/src/__tests__/cost-aggregation.test.ts index 5c73a01b3..dc6fe6414 100644 --- a/packages/agent-runtime/src/__tests__/cost-aggregation.test.ts +++ b/packages/agent-runtime/src/__tests__/cost-aggregation.test.ts @@ -90,6 +90,7 @@ describe('Cost Aggregation System', () => { repoUrl: undefined, signal: new AbortController().signal, system: 'Test system prompt', + tools: {}, userId: 'test-user', userInputId: 'test-input', writeToClient: () => {}, @@ -159,7 +160,7 @@ describe('Cost Aggregation System', () => { stepsRemaining: 10, creditsUsed: 75, // First subagent uses 75 credits }, - output: { type: 'lastMessage', value: 'Sub-agent 1 response' }, + output: { type: 'lastMessage', value: [assistantMessage('Sub-agent 1 response')] }, }) .mockResolvedValueOnce({ agentState: { @@ -169,7 +170,7 @@ describe('Cost Aggregation System', () => { stepsRemaining: 10, creditsUsed: 100, // Second subagent uses 100 credits }, - output: { type: 'lastMessage', value: 'Sub-agent 2 response' }, + output: { type: 'lastMessage', value: [assistantMessage('Sub-agent 2 response')] }, }) const mockToolCall = { @@ -223,7 +224,7 @@ describe('Cost Aggregation System', () => { stepsRemaining: 10, creditsUsed: 50, // Successful agent }, - output: { type: 'lastMessage', value: 'Successful response' }, + output: { type: 'lastMessage', value: [assistantMessage('Successful response')] }, }) .mockRejectedValueOnce( (() => { @@ -370,7 +371,7 @@ describe('Cost Aggregation System', () => { stepsRemaining: 10, creditsUsed: subAgent1Cost, } as AgentState, - output: { type: 'lastMessage', value: 'Sub-agent 1 response' }, + output: { type: 'lastMessage', value: [assistantMessage('Sub-agent 1 response')] }, }) .mockResolvedValueOnce({ agentState: { @@ -381,7 +382,7 @@ describe('Cost Aggregation System', () => { stepsRemaining: 10, creditsUsed: subAgent2Cost, } as AgentState, - output: { type: 'lastMessage', value: 'Sub-agent 2 response' }, + output: { type: 'lastMessage', value: [assistantMessage('Sub-agent 2 response')] }, }) const mockToolCall = { diff --git a/packages/agent-runtime/src/__tests__/loop-agent-steps.test.ts b/packages/agent-runtime/src/__tests__/loop-agent-steps.test.ts index 31a5fc81a..baa922ff3 100644 --- a/packages/agent-runtime/src/__tests__/loop-agent-steps.test.ts +++ b/packages/agent-runtime/src/__tests__/loop-agent-steps.test.ts @@ -5,7 +5,6 @@ import { clearMockedModules, mockModule, } from '@codebuff/common/testing/mock-modules' -import { getToolCallString } from '@codebuff/common/tools/utils' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { assistantMessage, userMessage } from '@codebuff/common/util/messages' import db from '@codebuff/internal/db' @@ -25,7 +24,7 @@ import { z } from 'zod/v4' import { disableLiveUserInputCheck } from '../live-user-inputs' import { loopAgentSteps } from '../run-agent-step' import { clearAgentGeneratorCache } from '../run-programmatic-step' -import { mockFileContext } from './test-utils' +import { createToolCallChunk, mockFileContext } from './test-utils' import type { AgentTemplate } from '../templates/types' import type { StepGenerator } from '@codebuff/common/types/agent-template' @@ -81,10 +80,8 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => agentRuntimeImpl.promptAiSdkStream = async function* ({}) { llmCallCount++ - yield { - type: 'text' as const, - text: `LLM response\n\n${getToolCallString('end_turn', {})}`, - } + yield { type: 'text' as const, text: 'LLM response\n\n' } + yield createToolCallChunk('end_turn', {}) return 'mock-message-id' } @@ -508,10 +505,8 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => llmStepCount++ // LLM always tries to end turn - yield { - type: 'text' as const, - text: `LLM response\n\n${getToolCallString('end_turn', {})}`, - } + yield { type: 'text' as const, text: 'LLM response\n\n' } + yield createToolCallChunk('end_turn', {}) return `mock-message-id-${promptCallCount}` } @@ -558,10 +553,8 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => llmCallNumber++ if (llmCallNumber === 1) { // First call: agent tries to end turn without setting output - yield { - type: 'text' as const, - text: `First response without output\n\n${getToolCallString('end_turn', {})}`, - } + yield { type: 'text' as const, text: 'First response without output\n\n' } + yield createToolCallChunk('end_turn', {}) } else if (llmCallNumber === 2) { // Second call: agent sets output after being reminded // Manually set the output to simulate the set_output tool execution @@ -571,16 +564,14 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => status: 'success', } } - yield { - type: 'text' as const, - text: `Setting output now\n\n${getToolCallString('set_output', { result: 'test result', status: 'success' })}\n\n${getToolCallString('end_turn', {})}`, - } + yield { type: 'text' as const, text: 'Setting output now\n\n' } + yield createToolCallChunk('set_output', { result: 'test result', status: 'success' }) + yield { type: 'text' as const, text: '\n\n' } + yield createToolCallChunk('end_turn', {}) } else { // Safety: if called more than twice, just end - yield { - type: 'text' as const, - text: `Ending\n\n${getToolCallString('end_turn', {})}`, - } + yield { type: 'text' as const, text: 'Ending\n\n' } + yield createToolCallChunk('end_turn', {}) } return 'mock-message-id' } @@ -641,10 +632,10 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => if (capturedAgentState) { capturedAgentState.output = { result: 'success' } } - yield { - type: 'text' as const, - text: `Setting output\n\n${getToolCallString('set_output', { result: 'success' })}\n\n${getToolCallString('end_turn', {})}`, - } + yield { type: 'text' as const, text: 'Setting output\n\n' } + yield createToolCallChunk('set_output', { result: 'success' }) + yield { type: 'text' as const, text: '\n\n' } + yield createToolCallChunk('end_turn', {}) return 'mock-message-id' } @@ -757,10 +748,8 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => let llmCallNumber = 0 loopAgentStepsBaseParams.promptAiSdkStream = async function* ({}) { llmCallNumber++ - yield { - type: 'text' as const, - text: `Response without output\n\n${getToolCallString('end_turn', {})}`, - } + yield { type: 'text' as const, text: 'Response without output\n\n' } + yield createToolCallChunk('end_turn', {}) return 'mock-message-id' } @@ -802,19 +791,17 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => llmCallNumber++ if (llmCallNumber === 1) { // First call: agent does some work but doesn't end turn - yield { - type: 'text' as const, - text: `Doing work\n\n${getToolCallString('read_files', { paths: ['test.txt'] })}`, - } + yield { type: 'text' as const, text: 'Doing work\n\n' } + yield createToolCallChunk('read_files', { paths: ['test.txt'] }) } else { // Second call: agent sets output and ends if (capturedAgentState) { capturedAgentState.output = { result: 'done' } } - yield { - type: 'text' as const, - text: `Finishing\n\n${getToolCallString('set_output', { result: 'done' })}\n\n${getToolCallString('end_turn', {})}`, - } + yield { type: 'text' as const, text: 'Finishing\n\n' } + yield createToolCallChunk('set_output', { result: 'done' }) + yield { type: 'text' as const, text: '\n\n' } + yield createToolCallChunk('end_turn', {}) } return 'mock-message-id' } diff --git a/packages/agent-runtime/src/__tests__/main-prompt.test.ts b/packages/agent-runtime/src/__tests__/main-prompt.test.ts index cb552dfe1..2bb57809e 100644 --- a/packages/agent-runtime/src/__tests__/main-prompt.test.ts +++ b/packages/agent-runtime/src/__tests__/main-prompt.test.ts @@ -2,7 +2,6 @@ import * as bigquery from '@codebuff/bigquery' import * as analytics from '@codebuff/common/analytics' import { TEST_USER_ID } from '@codebuff/common/old-constants' import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' -import { getToolCallString } from '@codebuff/common/tools/utils' import { AgentTemplateTypes, getInitialSessionState, @@ -33,9 +32,15 @@ import type { ProjectFileContext } from '@codebuff/common/util/file' let mainPromptBaseParams: ParamsExcluding -const mockAgentStream = (streamOutput: string) => { +import { createToolCallChunk } from './test-utils' + +import type { StreamChunk } from '@codebuff/common/types/contracts/llm' + +const mockAgentStream = (chunks: StreamChunk[]) => { mainPromptBaseParams.promptAiSdkStream = async function* ({}) { - yield { type: 'text' as const, text: streamOutput } + for (const chunk of chunks) { + yield chunk + } return 'mock-message-id' } } @@ -117,7 +122,7 @@ describe('mainPrompt', () => { ) // Mock LLM APIs - mockAgentStream('Test response') + mockAgentStream([{ type: 'text', text: 'Test response' }]) // Mock websocket actions mainPromptBaseParams.requestFiles = async ({ filePaths }) => { @@ -196,15 +201,15 @@ describe('mainPrompt', () => { } it('should handle write_file tool call', async () => { - // Mock LLM to return a write_file tool call using getToolCallString - const mockResponse = - getToolCallString('write_file', { + // Mock LLM to return a write_file tool call using native tool call chunks + mockAgentStream([ + createToolCallChunk('write_file', { path: 'new-file.txt', instructions: 'Added Hello World', content: 'Hello, world!', - }) + getToolCallString('end_turn', {}) - - mockAgentStream(mockResponse) + }), + createToolCallChunk('end_turn', {}), + ]) // Get reference to the spy so we can check if it was called const requestToolCallSpy = mainPromptBaseParams.requestToolCall @@ -355,7 +360,7 @@ describe('mainPrompt', () => { it('should return no tool calls when LLM response is empty', async () => { // Mock the LLM stream to return nothing - mockAgentStream('') + mockAgentStream([]) const sessionState = getInitialSessionState(mockFileContext) const action = { @@ -380,16 +385,15 @@ describe('mainPrompt', () => { it('should unescape ampersands in run_terminal_command tool calls', async () => { const sessionState = getInitialSessionState(mockFileContext) const userPromptText = 'Run the backend tests' - const escapedCommand = 'cd backend && bun test' const expectedCommand = 'cd backend && bun test' - const mockResponse = - getToolCallString('run_terminal_command', { - command: escapedCommand, + mockAgentStream([ + createToolCallChunk('run_terminal_command', { + command: expectedCommand, process_type: 'SYNC', - }) + getToolCallString('end_turn', {}) - - mockAgentStream(mockResponse) + }), + createToolCallChunk('end_turn', {}), + ]) // Get reference to the spy so we can check if it was called const requestToolCallSpy = mainPromptBaseParams.requestToolCall diff --git a/packages/agent-runtime/src/__tests__/malformed-tool-call.test.ts b/packages/agent-runtime/src/__tests__/malformed-tool-call.test.ts index 8b32ea54a..39761fe91 100644 --- a/packages/agent-runtime/src/__tests__/malformed-tool-call.test.ts +++ b/packages/agent-runtime/src/__tests__/malformed-tool-call.test.ts @@ -2,7 +2,6 @@ import * as bigquery from '@codebuff/bigquery' import * as analytics from '@codebuff/common/analytics' import { TEST_USER_ID } from '@codebuff/common/old-constants' import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' -import { getToolCallString } from '@codebuff/common/tools/utils' import { getInitialSessionState } from '@codebuff/common/types/session-state' import * as stringUtils from '@codebuff/common/util/string' import { @@ -15,8 +14,10 @@ import { test, } from 'bun:test' -import { mockFileContext } from './test-utils' -import { processStreamWithTools } from '../tools/stream-parser' +import { createToolCallChunk, mockFileContext } from './test-utils' +import { processStream } from '../tools/stream-parser' + +import type { StreamChunk } from '@codebuff/common/types/contracts/llm' import type { AgentTemplate } from '../templates/types' import type { @@ -34,7 +35,7 @@ let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL } describe('malformed tool call error handling', () => { let testAgent: AgentTemplate let agentRuntimeImpl: AgentRuntimeDeps & AgentRuntimeScopedDeps - let defaultParams: ParamsOf + let defaultParams: ParamsOf beforeEach(() => { agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } @@ -119,27 +120,31 @@ describe('malformed tool call error handling', () => { agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL } }) - function createMockStream(chunks: string[]) { + function createMockStream(chunks: StreamChunk[]) { async function* generator() { for (const chunk of chunks) { - yield { type: 'text' as const, text: chunk } + yield chunk } return 'mock-message-id' } return generator() } + function textChunk(text: string): StreamChunk { + return { type: 'text' as const, text } + } + test('should add tool result errors to message history after stream completes', async () => { - const chunks = [ - // Malformed JSON tool call - '\n{\n "cb_tool_name": "read_files",\n "paths": ["test.ts"\n}\n', - // Valid end turn - getToolCallString('end_turn', {}), + // With native tools, malformed tool calls are handled at the API level. + // This test now verifies that an unknown tool is properly handled. + const chunks: StreamChunk[] = [ + createToolCallChunk('unknown_tool_xyz', { paths: ['test.ts'] }), + createToolCallChunk('end_turn', {}), ] const stream = createMockStream(chunks) - await processStreamWithTools({ + await processStream({ ...defaultParams, stream, }) @@ -152,7 +157,7 @@ describe('malformed tool call error handling', () => { expect(toolMessages.length).toBeGreaterThan(0) - // Find the error tool result + // Find the error tool result for the unknown tool const errorToolResult = toolMessages.find( (m) => m.content?.[0]?.type === 'json' && @@ -162,22 +167,20 @@ describe('malformed tool call error handling', () => { expect(errorToolResult).toBeDefined() expect( (errorToolResult?.content?.[0] as any)?.value?.errorMessage, - ).toContain('Invalid JSON') + ).toContain('not found') }) - test('should handle multiple malformed tool calls', async () => { - const chunks = [ - // First malformed call - '\n{\n "cb_tool_name": "read_files",\n invalid\n}\n', - 'Some text between calls', - // Second malformed call - '\n{\n missing_quotes: value\n}\n', - getToolCallString('end_turn', {}), + test('should handle multiple unknown tool calls', async () => { + const chunks: StreamChunk[] = [ + createToolCallChunk('unknown_tool_1', { param: 'value1' }), + textChunk('Some text between calls'), + createToolCallChunk('unknown_tool_2', { param: 'value2' }), + createToolCallChunk('end_turn', {}), ] const stream = createMockStream(chunks) - await processStreamWithTools({ + await processStream({ ...defaultParams, stream, }) @@ -197,14 +200,14 @@ describe('malformed tool call error handling', () => { }) test('should preserve original toolResults array alongside message history', async () => { - const chunks = [ - '\n{\n "cb_tool_name": "read_files",\n malformed\n}\n', - getToolCallString('end_turn', {}), + const chunks: StreamChunk[] = [ + createToolCallChunk('unknown_tool_xyz', { param: 'value' }), + createToolCallChunk('end_turn', {}), ] const stream = createMockStream(chunks) - const result = await processStreamWithTools({ + const result = await processStream({ ...defaultParams, stream, }) @@ -228,14 +231,14 @@ describe('malformed tool call error handling', () => { }) test('should handle unknown tool names and add error to message history', async () => { - const chunks = [ - '\n{\n "cb_tool_name": "unknown_tool",\n "param": "value"\n}\n', - getToolCallString('end_turn', {}), + const chunks: StreamChunk[] = [ + createToolCallChunk('unknown_tool', { param: 'value' }), + createToolCallChunk('end_turn', {}), ] const stream = createMockStream(chunks) - await processStreamWithTools({ + await processStream({ ...defaultParams, stream, }) @@ -258,17 +261,17 @@ describe('malformed tool call error handling', () => { }) test('should not affect valid tool calls in message history', async () => { - const chunks = [ + const chunks: StreamChunk[] = [ // Valid tool call - getToolCallString('read_files', { paths: ['test.ts'] }), - // Malformed tool call - '\n{\n "cb_tool_name": "read_files",\n invalid\n}\n', - getToolCallString('end_turn', {}), + createToolCallChunk('read_files', { paths: ['test.ts'] }), + // Unknown tool call + createToolCallChunk('unknown_tool_xyz', { param: 'value' }), + createToolCallChunk('end_turn', {}), ] const stream = createMockStream(chunks) - await processStreamWithTools({ + await processStream({ ...defaultParams, requestFiles: async ({ filePaths }) => { return Object.fromEntries( @@ -299,15 +302,15 @@ describe('malformed tool call error handling', () => { expect(errorResults.length).toBeGreaterThan(0) }) - test('should handle stream with only malformed calls', async () => { - const chunks = [ - '\n{\n invalid1\n}\n', - '\n{\n invalid2\n}\n', + test('should handle stream with only unknown tool calls', async () => { + const chunks: StreamChunk[] = [ + createToolCallChunk('unknown_tool_1', { param: 'value1' }), + createToolCallChunk('unknown_tool_2', { param: 'value2' }), ] const stream = createMockStream(chunks) - await processStreamWithTools({ + await processStream({ ...defaultParams, stream, }) @@ -320,7 +323,7 @@ describe('malformed tool call error handling', () => { toolMessages.forEach((msg) => { expect(msg.content?.[0]?.type).toBe('json') expect((msg.content?.[0] as any)?.value?.errorMessage).toContain( - 'Invalid JSON', + 'not found', ) }) }) diff --git a/packages/agent-runtime/src/__tests__/n-parameter.test.ts b/packages/agent-runtime/src/__tests__/n-parameter.test.ts index c30ef339f..6cecb22f5 100644 --- a/packages/agent-runtime/src/__tests__/n-parameter.test.ts +++ b/packages/agent-runtime/src/__tests__/n-parameter.test.ts @@ -104,7 +104,6 @@ describe('n parameter and GENERATE_N functionality', () => { runAgentStepBaseParams = { ...agentRuntimeImpl, additionalToolDefinitions: () => Promise.resolve({}), - textOverride: null, runId: 'test-run-id', ancestorRunIds: [], repoId: undefined, @@ -122,6 +121,7 @@ describe('n parameter and GENERATE_N functionality', () => { spawnParams: undefined, system: 'Test system', signal: new AbortController().signal, + tools: {} } }) diff --git a/packages/agent-runtime/src/__tests__/prompt-caching-subagents.test.ts b/packages/agent-runtime/src/__tests__/prompt-caching-subagents.test.ts index 2cbf43e85..48e10960f 100644 --- a/packages/agent-runtime/src/__tests__/prompt-caching-subagents.test.ts +++ b/packages/agent-runtime/src/__tests__/prompt-caching-subagents.test.ts @@ -407,6 +407,83 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => { // allowing the LLM provider to cache and reuse the system prompt }) + it('should pass parent tools and add subagent tools message when inheritParentSystemPrompt is true', async () => { + const sessionState = getInitialSessionState(mockFileContext) + + // Create a child that inherits system prompt and has specific tools + const childWithTools: AgentTemplate = { + id: 'child-with-tools', + displayName: 'Child With Tools', + outputMode: 'last_message', + inputSchema: {}, + spawnerPrompt: '', + model: 'anthropic/claude-sonnet-4', + includeMessageHistory: false, + inheritParentSystemPrompt: true, + mcpServers: {}, + toolNames: ['read_files', 'code_search'], + spawnableAgents: [], + systemPrompt: '', + instructionsPrompt: '', + stepPrompt: '', + } + + mockLocalAgentTemplates['child-with-tools'] = childWithTools + + // Run parent agent first + await loopAgentSteps({ + ...loopAgentStepsBaseParams, + userInputId: 'test-parent', + prompt: 'Parent task', + agentType: 'parent', + agentState: sessionState.mainAgentState, + }) + + const parentMessages = capturedMessages + const parentSystemPrompt = (parentMessages[0].content[0] as TextPart).text + + // Mock parent tools + const parentTools = { read_files: {}, write_file: {}, code_search: {} } + + // Run child agent with inheritParentSystemPrompt=true and parentTools + capturedMessages = [] + const childAgentState = { + ...sessionState.mainAgentState, + agentId: 'child-agent', + agentType: 'child-with-tools' as const, + messageHistory: [], + } + + await loopAgentSteps({ + ...loopAgentStepsBaseParams, + userInputId: 'test-child', + prompt: 'Child task', + agentType: 'child-with-tools', + agentState: childAgentState, + parentSystemPrompt: parentSystemPrompt, + parentTools: parentTools as any, + }) + + const childMessages = capturedMessages + + // Verify child uses parent's system prompt + expect(childMessages[0].role).toBe('system') + expect((childMessages[0].content[0] as TextPart).text).toBe( + parentSystemPrompt, + ) + + // Verify there's an instructions prompt message that includes subagent tools info + const instructionsMessage = childMessages.find( + (msg) => + msg.role === 'user' && + msg.content[0].type === 'text' && + msg.content[0].text.includes('subagent') && + msg.content[0].text.includes('read_files') && + msg.content[0].text.includes('code_search'), + ) + expect(instructionsMessage).toBeTruthy() + }) + it('should support both inheritParentSystemPrompt and includeMessageHistory together', async () => { const sessionState = getInitialSessionState(mockFileContext) diff --git a/packages/agent-runtime/src/__tests__/read-docs-tool.test.ts b/packages/agent-runtime/src/__tests__/read-docs-tool.test.ts index 4b62cb588..1ed0d8b28 100644 --- a/packages/agent-runtime/src/__tests__/read-docs-tool.test.ts +++ b/packages/agent-runtime/src/__tests__/read-docs-tool.test.ts @@ -2,7 +2,6 @@ import * as bigquery from '@codebuff/bigquery' import * as analytics from '@codebuff/common/analytics' import { TEST_USER_ID } from '@codebuff/common/old-constants' import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' -import { getToolCallString } from '@codebuff/common/tools/utils' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { afterEach, @@ -16,7 +15,7 @@ import { } from 'bun:test' import { disableLiveUserInputCheck } from '../live-user-inputs' -import { mockFileContext } from './test-utils' +import { createToolCallChunk, mockFileContext } from './test-utils' import researcherAgent from '../../../../.agents/researcher/researcher' import * as webApi from '../llm-api/codebuff-web-api' import { runAgentStep } from '../run-agent-step' @@ -34,13 +33,12 @@ let runAgentStepBaseParams: ParamsExcluding< 'fileContext' | 'localAgentTemplates' | 'agentState' | 'prompt' > -function mockAgentStream(content: string | string[]) { +import type { StreamChunk } from '@codebuff/common/types/contracts/llm' + +function mockAgentStream(chunks: StreamChunk[]) { const mockPromptAiSdkStream = async function* ({}) { - if (typeof content === 'string') { - content = [content] - } - for (const chunk of content) { - yield { type: 'text' as const, text: chunk } + for (const chunk of chunks) { + yield chunk } return 'mock-message-id' } @@ -75,7 +73,6 @@ describe('read_docs tool with researcher agent (via web API facade)', () => { runAgentStepBaseParams = { ...agentRuntimeImpl, additionalToolDefinitions: () => Promise.resolve({}), - textOverride: null, runId: 'test-run-id', ancestorRunIds: [], repoId: undefined, @@ -89,6 +86,7 @@ describe('read_docs tool with researcher agent (via web API facade)', () => { agentType: 'researcher', spawnParams: undefined, signal: new AbortController().signal, + tools: {}, } }) @@ -108,13 +106,13 @@ describe('read_docs tool with researcher agent (via web API facade)', () => { documentation: mockDocumentation, }) - const mockResponse = - getToolCallString('read_docs', { + mockAgentStream([ + createToolCallChunk('read_docs', { libraryTitle: 'React', topic: 'hooks', - }) + getToolCallString('end_turn', {}) - - mockAgentStream(mockResponse) + }), + createToolCallChunk('end_turn', {}), + ]) const sessionState = getInitialSessionState(mockFileContextWithAgents) const agentState = { @@ -154,14 +152,14 @@ describe('read_docs tool with researcher agent (via web API facade)', () => { documentation: mockDocumentation, }) - const mockResponse = - getToolCallString('read_docs', { + mockAgentStream([ + createToolCallChunk('read_docs', { libraryTitle: 'React', topic: 'hooks', max_tokens: 5000, - }) + getToolCallString('end_turn', {}) - - mockAgentStream(mockResponse) + }), + createToolCallChunk('end_turn', {}), + ]) const sessionState = getInitialSessionState(mockFileContextWithAgents) const agentState = { @@ -194,13 +192,13 @@ describe('read_docs tool with researcher agent (via web API facade)', () => { const msg = 'No documentation found for "NonExistentLibrary"' spyOn(webApi, 'callDocsSearchAPI').mockResolvedValue({ error: msg }) - const mockResponse = - getToolCallString('read_docs', { + mockAgentStream([ + createToolCallChunk('read_docs', { libraryTitle: 'NonExistentLibrary', topic: 'blah', - }) + getToolCallString('end_turn', {}) - - mockAgentStream(mockResponse) + }), + createToolCallChunk('end_turn', {}), + ]) const sessionState = getInitialSessionState(mockFileContextWithAgents) const agentState = { @@ -214,7 +212,6 @@ describe('read_docs tool with researcher agent (via web API facade)', () => { const { agentState: newAgentState } = await runAgentStep({ ...runAgentStepBaseParams, - textOverride: null, fileContext: mockFileContextWithAgents, localAgentTemplates: agentTemplates, agentState, @@ -234,13 +231,13 @@ describe('read_docs tool with researcher agent (via web API facade)', () => { error: 'Network timeout', }) - const mockResponse = - getToolCallString('read_docs', { + mockAgentStream([ + createToolCallChunk('read_docs', { libraryTitle: 'React', topic: 'hooks', - }) + getToolCallString('end_turn', {}) - - mockAgentStream(mockResponse) + }), + createToolCallChunk('end_turn', {}), + ]) const sessionState = getInitialSessionState(mockFileContextWithAgents) const agentState = { @@ -272,13 +269,13 @@ describe('read_docs tool with researcher agent (via web API facade)', () => { test('should include topic in error message when specified', async () => { spyOn(webApi, 'callDocsSearchAPI').mockResolvedValue({ error: 'No docs' }) - const mockResponse = - getToolCallString('read_docs', { + mockAgentStream([ + createToolCallChunk('read_docs', { libraryTitle: 'React', topic: 'server-components', - }) + getToolCallString('end_turn', {}) - - mockAgentStream(mockResponse) + }), + createToolCallChunk('end_turn', {}), + ]) const sessionState = getInitialSessionState(mockFileContextWithAgents) const agentState = { @@ -312,13 +309,13 @@ describe('read_docs tool with researcher agent (via web API facade)', () => { throw 'String error' }) - const mockResponse = - getToolCallString('read_docs', { + mockAgentStream([ + createToolCallChunk('read_docs', { libraryTitle: 'React', topic: 'hooks', - }) + getToolCallString('end_turn', {}) - - mockAgentStream(mockResponse) + }), + createToolCallChunk('end_turn', {}), + ]) const sessionState = getInitialSessionState(mockFileContextWithAgents) const agentState = { @@ -355,13 +352,13 @@ describe('read_docs tool with researcher agent (via web API facade)', () => { creditsUsed: mockCreditsUsed, }) - const mockResponse = - getToolCallString('read_docs', { + mockAgentStream([ + createToolCallChunk('read_docs', { libraryTitle: 'React', topic: 'hooks', - }) + getToolCallString('end_turn', {}) - - mockAgentStream(mockResponse) + }), + createToolCallChunk('end_turn', {}), + ]) const sessionState = getInitialSessionState(mockFileContextWithAgents) const agentState = { diff --git a/packages/agent-runtime/src/__tests__/run-agent-step-tools.test.ts b/packages/agent-runtime/src/__tests__/run-agent-step-tools.test.ts index 62e026c0f..d7f51fdc3 100644 --- a/packages/agent-runtime/src/__tests__/run-agent-step-tools.test.ts +++ b/packages/agent-runtime/src/__tests__/run-agent-step-tools.test.ts @@ -2,7 +2,6 @@ import * as bigquery from '@codebuff/bigquery' import * as analytics from '@codebuff/common/analytics' import { TEST_USER_ID } from '@codebuff/common/old-constants' import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' -import { getToolCallString } from '@codebuff/common/tools/utils' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { assistantMessage, userMessage } from '@codebuff/common/util/messages' import db from '@codebuff/internal/db' @@ -22,6 +21,7 @@ import { disableLiveUserInputCheck } from '../live-user-inputs' import { runAgentStep } from '../run-agent-step' import { clearAgentGeneratorCache } from '../run-programmatic-step' import { asUserMessage } from '../util/messages' +import { createToolCallChunk } from './test-utils' import type { AgentTemplate } from '../templates/types' import type { @@ -116,22 +116,21 @@ describe('runAgentStep - set_output tool', () => { runAgentStepBaseParams = { ...agentRuntimeImpl, + additionalToolDefinitions: () => Promise.resolve({}), ancestorRunIds: [], clientSessionId: 'test-session', fileContext: mockFileContext, fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, repoId: undefined, repoUrl: undefined, runId: 'test-run-id', signal: new AbortController().signal, spawnParams: undefined, system: 'Test system prompt', - textOverride: null, + tools: {}, userId: TEST_USER_ID, userInputId: 'test-input', - - additionalToolDefinitions: () => Promise.resolve({}), - onResponseChunk: () => {}, } }) @@ -170,15 +169,10 @@ describe('runAgentStep - set_output tool', () => { } it('should set output with simple key-value pair', async () => { - const mockResponse = - getToolCallString('set_output', { - message: 'Hi', - }) + - '\n\n' + - getToolCallString('end_turn', {}) - runAgentStepBaseParams.promptAiSdkStream = async function* ({}) { - yield { type: 'text' as const, text: mockResponse } + yield createToolCallChunk('set_output', { message: 'Hi' }) + yield { type: 'text' as const, text: '\n\n' } + yield createToolCallChunk('end_turn', {}) return 'mock-message-id' } @@ -203,15 +197,13 @@ describe('runAgentStep - set_output tool', () => { }) it('should set output with complex data', async () => { - const mockResponse = - getToolCallString('set_output', { + runAgentStepBaseParams.promptAiSdkStream = async function* ({}) { + yield createToolCallChunk('set_output', { message: 'Analysis complete', status: 'success', findings: ['Bug in auth.ts', 'Missing validation'], - }) + getToolCallString('end_turn', {}) - - runAgentStepBaseParams.promptAiSdkStream = async function* ({}) { - yield { type: 'text' as const, text: mockResponse } + }) + yield createToolCallChunk('end_turn', {}) return 'mock-message-id' } @@ -238,14 +230,12 @@ describe('runAgentStep - set_output tool', () => { }) it('should replace existing output data', async () => { - const mockResponse = - getToolCallString('set_output', { + runAgentStepBaseParams.promptAiSdkStream = async function* ({}) { + yield createToolCallChunk('set_output', { newField: 'new value', existingField: 'updated value', - }) + getToolCallString('end_turn', {}) - - runAgentStepBaseParams.promptAiSdkStream = async function* ({}) { - yield { type: 'text' as const, text: mockResponse } + }) + yield createToolCallChunk('end_turn', {}) return 'mock-message-id' } @@ -275,11 +265,9 @@ describe('runAgentStep - set_output tool', () => { }) it('should handle empty output parameter', async () => { - const mockResponse = - getToolCallString('set_output', {}) + getToolCallString('end_turn', {}) - runAgentStepBaseParams.promptAiSdkStream = async function* ({}) { - yield { type: 'text' as const, text: mockResponse } + yield createToolCallChunk('set_output', {}) + yield createToolCallChunk('end_turn', {}) return 'mock-message-id' } @@ -490,13 +478,10 @@ describe('runAgentStep - set_output tool', () => { // Mock the LLM stream to spawn the inline agent runAgentStepBaseParams.promptAiSdkStream = async function* ({}) { - yield { - type: 'text' as const, - text: getToolCallString('spawn_agent_inline', { - agent_type: 'message-deleter-agent', - prompt: 'Delete the last two assistant messages', - }), - } + yield createToolCallChunk('spawn_agent_inline', { + agent_type: 'message-deleter-agent', + prompt: 'Delete the last two assistant messages', + }) return 'mock-message-id' } diff --git a/packages/agent-runtime/src/__tests__/run-programmatic-step.test.ts b/packages/agent-runtime/src/__tests__/run-programmatic-step.test.ts index df7ded81d..3abfee451 100644 --- a/packages/agent-runtime/src/__tests__/run-programmatic-step.test.ts +++ b/packages/agent-runtime/src/__tests__/run-programmatic-step.test.ts @@ -259,16 +259,12 @@ describe('runProgrammaticStep', () => { // Check that no tool call chunk was sent for add_message const addMessageToolCallChunk = sentChunks.find( (chunk) => - chunk.includes('add_message') && chunk.includes('Hello world'), + typeof chunk === 'string' && + chunk.includes('add_message') && + chunk.includes('Hello world'), ) expect(addMessageToolCallChunk).toBeUndefined() - // Check that tool call chunk WAS sent for read_files (normal behavior) - const readFilesToolCallChunk = sentChunks.find( - (chunk) => chunk.includes('read_files') && chunk.includes('test.txt'), - ) - expect(readFilesToolCallChunk).toBeDefined() - // Verify final message history doesn't contain add_message tool call const addMessageToolCallInHistory = result.agentState.messageHistory.find( (msg) => @@ -778,12 +774,28 @@ describe('runProgrammaticStep', () => { const result = await runProgrammaticStep(mockParams) - expect(result.agentState.messageHistory).toEqual([ - ...previousMessageHistory, - assistantMessage( - '\n{\n "cb_tool_name": "end_turn",\n "cb_easp": true\n}\n', - ), - ]) + // Verify previous messages are preserved + expect(result.agentState.messageHistory.length).toBeGreaterThanOrEqual( + previousMessageHistory.length, + ) + // Check first messages match + expect(result.agentState.messageHistory[0]).toEqual( + previousMessageHistory[0], + ) + expect(result.agentState.messageHistory[1]).toEqual( + previousMessageHistory[1], + ) + // Verify an assistant message was added (with native tools, this is a tool-call structure) + const lastMessage = + result.agentState.messageHistory[ + result.agentState.messageHistory.length - 1 + ] + expect(lastMessage.role).toBe('assistant') + // With native tools, the tool call is structured differently than the old XML format + expect(lastMessage.content[0]).toMatchObject({ + type: 'tool-call', + toolName: 'end_turn', + }) }) }) @@ -1433,6 +1445,244 @@ describe('runProgrammaticStep', () => { }) }) + describe('yield value validation', () => { + it('should reject invalid yield values', async () => { + const mockGenerator = (function* () { + yield { invalid: 'value' } as any + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const responseChunks: any[] = [] + mockParams.onResponseChunk = (chunk) => responseChunks.push(chunk) + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + + it('should reject yield values with wrong types', async () => { + const mockGenerator = (function* () { + yield { type: 'STEP_TEXT', text: 123 } as any // text should be string + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const responseChunks: any[] = [] + mockParams.onResponseChunk = (chunk) => responseChunks.push(chunk) + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + + it('should reject GENERATE_N with non-positive n', async () => { + const mockGenerator = (function* () { + yield { type: 'GENERATE_N', n: 0 } as any + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const responseChunks: any[] = [] + mockParams.onResponseChunk = (chunk) => responseChunks.push(chunk) + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + + it('should reject GENERATE_N with negative n', async () => { + const mockGenerator = (function* () { + yield { type: 'GENERATE_N', n: -5 } as any + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const responseChunks: any[] = [] + mockParams.onResponseChunk = (chunk) => responseChunks.push(chunk) + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + + it('should accept valid STEP literal', async () => { + const mockGenerator = (function* () { + yield 'STEP' + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(false) + expect(result.agentState.output?.error).toBeUndefined() + }) + + it('should accept valid STEP_ALL literal', async () => { + const mockGenerator = (function* () { + yield 'STEP_ALL' + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(false) + expect(result.agentState.output?.error).toBeUndefined() + }) + + it('should accept valid STEP_TEXT object and continue to next step', async () => { + // STEP_TEXT continues to the next step of handleSteps (unlike STEP which breaks) + // So we need additional yields after STEP_TEXT to test the continuation + const mockGenerator = (function* () { + yield { type: 'STEP_TEXT', text: 'Custom response text' } + yield { toolName: 'end_turn', input: {} } + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + // Should end turn because the generator continues after STEP_TEXT and reaches end_turn + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toBeUndefined() + }) + + it('should accept valid GENERATE_N object', async () => { + const mockGenerator = (function* () { + yield { type: 'GENERATE_N', n: 3 } + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(false) + expect(result.generateN).toBe(3) + expect(result.agentState.output?.error).toBeUndefined() + }) + + it('should accept valid tool call object', async () => { + const mockGenerator = (function* () { + yield { toolName: 'read_files', input: { paths: ['test.txt'] } } + yield { toolName: 'end_turn', input: {} } + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toBeUndefined() + }) + + it('should accept tool call with includeToolCall option', async () => { + const mockGenerator = (function* () { + yield { + toolName: 'read_files', + input: { paths: ['test.txt'] }, + includeToolCall: false, + } + yield { toolName: 'end_turn', input: {} } + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toBeUndefined() + }) + + it('should reject random string values', async () => { + const mockGenerator = (function* () { + yield 'INVALID_STEP' as any + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + + it('should reject null yield values', async () => { + const mockGenerator = (function* () { + yield null as any + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + + it('should reject undefined yield values', async () => { + const mockGenerator = (function* () { + yield undefined as any + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + + it('should reject tool call without toolName', async () => { + const mockGenerator = (function* () { + yield { input: { paths: ['test.txt'] } } as any + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + + it('should reject tool call without input', async () => { + const mockGenerator = (function* () { + yield { toolName: 'read_files' } as any + })() as StepGenerator + + mockTemplate.handleSteps = () => mockGenerator + + const result = await runProgrammaticStep(mockParams) + + expect(result.endTurn).toBe(true) + expect(result.agentState.output?.error).toContain( + 'Invalid yield value from handleSteps', + ) + }) + }) + describe('logging and context', () => { it('should log agent execution start', async () => { const mockGenerator = (function* () { diff --git a/packages/agent-runtime/src/__tests__/spawn-agents-message-history.test.ts b/packages/agent-runtime/src/__tests__/spawn-agents-message-history.test.ts index b47223129..cfb92f380 100644 --- a/packages/agent-runtime/src/__tests__/spawn-agents-message-history.test.ts +++ b/packages/agent-runtime/src/__tests__/spawn-agents-message-history.test.ts @@ -52,7 +52,7 @@ describe('Spawn Agents Message History', () => { assistantMessage('Mock agent response'), ], }, - output: { type: 'lastMessage', value: 'Mock agent response' }, + output: { type: 'lastMessage', value: [assistantMessage('Mock agent response')] }, } }) @@ -68,6 +68,7 @@ describe('Spawn Agents Message History', () => { sendSubagentChunk: mockSendSubagentChunk, signal: new AbortController().signal, system: 'Test system prompt', + tools: {}, userId: TEST_USER_ID, userInputId: 'test-input', writeToClient: () => {}, diff --git a/packages/agent-runtime/src/__tests__/spawn-agents-permissions.test.ts b/packages/agent-runtime/src/__tests__/spawn-agents-permissions.test.ts index 3f827e2a6..ef3ed0e7b 100644 --- a/packages/agent-runtime/src/__tests__/spawn-agents-permissions.test.ts +++ b/packages/agent-runtime/src/__tests__/spawn-agents-permissions.test.ts @@ -67,6 +67,7 @@ describe('Spawn Agents Permissions', () => { sendSubagentChunk: mockSendSubagentChunk, signal: new AbortController().signal, system: 'Test system prompt', + tools: {}, userId: TEST_USER_ID, userInputId: 'test-input', writeToClient: () => {}, @@ -85,7 +86,7 @@ describe('Spawn Agents Permissions', () => { ...options.agentState, messageHistory: [assistantMessage('Mock agent response')], }, - output: { type: 'lastMessage', value: 'Mock agent response' }, + output: { type: 'lastMessage', value: [assistantMessage('Mock agent response')] }, } }) }) diff --git a/packages/agent-runtime/src/__tests__/subagent-streaming.test.ts b/packages/agent-runtime/src/__tests__/subagent-streaming.test.ts index 1bd1a6970..d65c9f10a 100644 --- a/packages/agent-runtime/src/__tests__/subagent-streaming.test.ts +++ b/packages/agent-runtime/src/__tests__/subagent-streaming.test.ts @@ -70,6 +70,7 @@ describe('Subagent Streaming', () => { sendSubagentChunk: mockSendSubagentChunk, signal: new AbortController().signal, system: 'Test system prompt', + tools: {}, userId: TEST_USER_ID, userInputId: 'test-input', writeToClient: mockWriteToClient, @@ -96,7 +97,7 @@ describe('Subagent Streaming', () => { ...options.agentState, messageHistory: [assistantMessage('Test response from subagent')], }, - output: { type: 'lastMessage', value: 'Test response from subagent' }, + output: { type: 'lastMessage', value: [assistantMessage('Test response from subagent')] }, } }) diff --git a/packages/agent-runtime/src/__tests__/test-utils.ts b/packages/agent-runtime/src/__tests__/test-utils.ts index b2e1fc6be..66382f3cb 100644 --- a/packages/agent-runtime/src/__tests__/test-utils.ts +++ b/packages/agent-runtime/src/__tests__/test-utils.ts @@ -1,5 +1,45 @@ +import { generateCompactId } from '@codebuff/common/util/string' + +import type { StreamChunk } from '@codebuff/common/types/contracts/llm' import type { ProjectFileContext } from '@codebuff/common/util/file' +/** + * Creates a native tool call stream chunk for testing. + * This replaces the old getToolCallString() approach which generated XML format. + */ +export function createToolCallChunk( + toolName: T, + input: Record, + toolCallId?: string, +): StreamChunk { + return { + type: 'tool-call', + toolName, + toolCallId: toolCallId ?? generateCompactId(), + input, + } +} + +/** + * Creates a mock stream that yields native tool call chunks. + * Use this instead of streams that yield text with XML tool calls. + */ +export function createMockStreamWithToolCalls( + chunks: (string | { toolName: string; input: Record })[], +): AsyncGenerator { + async function* generator(): AsyncGenerator { + for (const chunk of chunks) { + if (typeof chunk === 'string') { + yield { type: 'text' as const, text: chunk } + } else { + yield createToolCallChunk(chunk.toolName, chunk.input) + } + } + return 'mock-message-id' + } + return generator() +} + export const mockFileContext: ProjectFileContext = { projectRoot: '/test', cwd: '/test', diff --git a/packages/agent-runtime/src/__tests__/tool-stream-parser.test.ts b/packages/agent-runtime/src/__tests__/tool-stream-parser.test.ts index b5c5dfb23..ef73368e6 100644 --- a/packages/agent-runtime/src/__tests__/tool-stream-parser.test.ts +++ b/packages/agent-runtime/src/__tests__/tool-stream-parser.test.ts @@ -1,22 +1,25 @@ import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' -import { endsAgentStepParam } from '@codebuff/common/tools/constants' -import { getToolCallString } from '@codebuff/common/tools/utils' import { beforeEach, describe, expect, it } from 'bun:test' -import { globalStopSequence } from '../constants' -import { processStreamWithTags } from '../tool-stream-parser' +import { processStreamWithTools } from '../tool-stream-parser' +import { createToolCallChunk } from './test-utils' import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime' +import type { StreamChunk } from '@codebuff/common/types/contracts/llm' describe('processStreamWithTags', () => { - async function* createMockStream(chunks: string[]) { + async function* createMockStream(chunks: StreamChunk[]) { for (const chunk of chunks) { - yield { type: 'text' as const, text: chunk } + yield chunk } return 'mock-message-id' } + function textChunk(text: string): StreamChunk { + return { type: 'text' as const, text } + } + let agentRuntimeImpl: AgentRuntimeDeps beforeEach(() => { @@ -24,8 +27,8 @@ describe('processStreamWithTags', () => { }) it('should handle basic tool call parsing', async () => { - const streamChunks = [ - '\n{\n "cb_tool_name": "test_tool",\n "param1": "value1"\n}\n', + const streamChunks: StreamChunk[] = [ + createToolCallChunk('test_tool', { param1: 'value1' }), ] const stream = createMockStream(streamChunks) @@ -61,7 +64,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -86,14 +89,12 @@ describe('processStreamWithTags', () => { params: { param1: 'value1' }, }, ]) - expect(result).toEqual(streamChunks) }) - it('should handle tool calls split across chunks', async () => { - const streamChunks = [ - '\n{\n "cb_tool_name": "test', - '_tool",\n "param1": "val', - 'ue1"\n}\n', + it('should handle tool calls with text before', async () => { + const streamChunks: StreamChunk[] = [ + textChunk('Some text before tool call'), + createToolCallChunk('test_tool', { param1: 'value1' }), ] const stream = createMockStream(streamChunks) @@ -129,7 +130,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -154,14 +155,13 @@ describe('processStreamWithTags', () => { params: { param1: 'value1' }, }, ]) - expect(result).toEqual(streamChunks) }) it('should handle multiple tool calls in sequence', async () => { - const streamChunks = [ - '\n{\n "cb_tool_name": "tool1",\n "param1": "value1"\n}\n', - 'text between tools', - '\n{\n "cb_tool_name": "tool2",\n "param2": "value2"\n}\n', + const streamChunks: StreamChunk[] = [ + createToolCallChunk('tool1', { param1: 'value1' }), + textChunk('text between tools'), + createToolCallChunk('tool2', { param2: 'value2' }), ] const stream = createMockStream(streamChunks) @@ -206,7 +206,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -241,12 +241,11 @@ describe('processStreamWithTags', () => { params: { param2: 'value2' }, }, ]) - expect(result).toEqual(streamChunks) }) - it('should handle malformed JSON and call onError', async () => { - const streamChunks = [ - '\n{\n "cb_tool_name": "test_tool",\n "param1": invalid_json\n}\n', + it('should handle unknown tool names via defaultProcessor', async () => { + const streamChunks: StreamChunk[] = [ + createToolCallChunk('unknown_tool', { param1: 'value1' }), ] const stream = createMockStream(streamChunks) @@ -268,68 +267,6 @@ describe('processStreamWithTags', () => { events.push({ name, error, type: 'error' }) } - const result: string[] = [] - const responseChunks: any[] = [] - - function onResponseChunk(chunk: any) { - responseChunks.push(chunk) - } - - function defaultProcessor(toolName: string) { - return { - onTagStart: () => {}, - onTagEnd: () => {}, - } - } - - for await (const chunk of processStreamWithTags({ - ...agentRuntimeImpl, - stream, - processors, - defaultProcessor, - onError, - onResponseChunk, - })) { - if (chunk.type === 'text') { - result.push(chunk.text) - } - } - - expect(events).toEqual([ - { - name: 'parse_error', - error: expect.stringContaining('Unexpected identifier'), - type: 'error', - }, - ]) - expect(result).toEqual(streamChunks) - }) - - it('should handle unknown tool names and call onError', async () => { - const streamChunks = [ - '\n{\n "cb_tool_name": "unknown_tool",\n "param1": "value1"\n}\n', - ] - const stream = createMockStream(streamChunks) - - const events: any[] = [] - - const processors = { - test_tool: { - params: ['param1'] as string[], - onTagStart: (tagName: string, attributes: Record) => { - events.push({ tagName, type: 'start', attributes }) - }, - onTagEnd: (tagName: string, params: Record) => { - events.push({ tagName, type: 'end', params }) - }, - }, - } - - function onError(name: string, error: string) { - events.push({ name, error, type: 'error' }) - } - - const result: string[] = [] const responseChunks: any[] = [] function onResponseChunk(chunk: any) { @@ -349,7 +286,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -357,9 +294,7 @@ describe('processStreamWithTags', () => { onError, onResponseChunk, })) { - if (chunk.type === 'text') { - result.push(chunk.text) - } + // consume stream } expect(events).toEqual([ @@ -369,12 +304,16 @@ describe('processStreamWithTags', () => { type: 'error', }, ]) - expect(result).toEqual(streamChunks) }) it('should handle tool calls with complex parameters', async () => { - const streamChunks = [ - '\n{\n "cb_tool_name": "complex_tool",\n "array_param": ["item1", "item2"],\n "object_param": {"nested": "value"},\n "boolean_param": true,\n "number_param": 42\n}\n', + const streamChunks: StreamChunk[] = [ + createToolCallChunk('complex_tool', { + array_param: ['item1', 'item2'], + object_param: { nested: 'value' }, + boolean_param: true, + number_param: 42, + }), ] const stream = createMockStream(streamChunks) @@ -415,7 +354,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -445,14 +384,13 @@ describe('processStreamWithTags', () => { }, }, ]) - expect(result).toEqual(streamChunks) }) it('should handle text content mixed with tool calls', async () => { - const streamChunks = [ - 'Some text before', - '\n{\n "cb_tool_name": "test_tool",\n "param1": "value1"\n}\n', - 'Some text after', + const streamChunks: StreamChunk[] = [ + textChunk('Some text before'), + createToolCallChunk('test_tool', { param1: 'value1' }), + textChunk('Some text after'), ] const stream = createMockStream(streamChunks) @@ -488,7 +426,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -513,81 +451,10 @@ describe('processStreamWithTags', () => { params: { param1: 'value1' }, }, ]) - expect(result).toEqual(streamChunks) - }) - - it('should handle incomplete tool calls at end of stream due to stop sequence', async () => { - const streamChunks = [ - '\n{\n "cb_tool_name": "test_tool",\n "param1": "value1",\n ', - // Missing closing tag - ] - const stream = createMockStream(streamChunks) - - const events: any[] = [] - - const processors = { - test_tool: { - params: ['param1'] as string[], - onTagStart: (tagName: string, attributes: Record) => { - events.push({ tagName, type: 'start', attributes }) - }, - onTagEnd: (tagName: string, params: Record) => { - events.push({ tagName, type: 'end', params }) - }, - }, - } - - function onError(name: string, error: string) { - events.push({ name, error, type: 'error' }) - } - - const result: string[] = [] - const responseChunks: any[] = [] - - function onResponseChunk(chunk: any) { - responseChunks.push(chunk) - } - - function defaultProcessor(toolName: string) { - return { - onTagStart: () => {}, - onTagEnd: () => {}, - } - } - - for await (const chunk of processStreamWithTags({ - ...agentRuntimeImpl, - stream, - processors, - defaultProcessor, - onError, - onResponseChunk, - })) { - if (chunk.type === 'text') { - result.push(chunk.text) - } - } - - // Should complete the tool call with the completion suffix - expect(events).toEqual([ - { - tagName: 'test_tool', - type: 'start', - attributes: {}, - }, - { - tagName: 'test_tool', - type: 'end', - params: { param1: 'value1', [endsAgentStepParam]: true }, - }, - ]) - - // Should include the completion suffix in the result - expect(result.join('')).toContain(globalStopSequence) }) it('should handle empty stream', async () => { - const streamChunks: string[] = [] + const streamChunks: StreamChunk[] = [] const stream = createMockStream(streamChunks) const events: any[] = [] @@ -612,7 +479,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -630,7 +497,10 @@ describe('processStreamWithTags', () => { }) it('should handle stream with only text content', async () => { - const streamChunks = ['Just some text', ' with no tool calls'] + const streamChunks: StreamChunk[] = [ + textChunk('Just some text'), + textChunk(' with no tool calls'), + ] const stream = createMockStream(streamChunks) const events: any[] = [] @@ -655,7 +525,7 @@ describe('processStreamWithTags', () => { } } - for await (const chunk of processStreamWithTags({ + for await (const chunk of processStreamWithTools({ ...agentRuntimeImpl, stream, processors, @@ -669,155 +539,5 @@ describe('processStreamWithTags', () => { } expect(events).toEqual([]) - expect(result).toEqual(streamChunks) - }) - - it('should handle tool call with missing cb_tool_name', async () => { - const streamChunks = [ - '\n{\n "param1": "value1"\n}\n', - ] - const stream = createMockStream(streamChunks) - - const events: any[] = [] - - const processors = { - test_tool: { - params: ['param1'] as string[], - onTagStart: (tagName: string, attributes: Record) => { - events.push({ tagName, type: 'start', attributes }) - }, - onTagEnd: (tagName: string, params: Record) => { - events.push({ tagName, type: 'end', params }) - }, - }, - } - - function onError(name: string, error: string) { - events.push({ name, error, type: 'error' }) - } - - const result: string[] = [] - const responseChunks: any[] = [] - - function onResponseChunk(chunk: any) { - responseChunks.push(chunk) - } - - function defaultProcessor(toolName: string) { - // Track when defaultProcessor is called (which means tool not found in processors) - events.push({ - name: toolName, - error: `Tool not found: ${toolName}`, - type: 'error', - }) - return { - onTagStart: () => {}, - onTagEnd: () => {}, - } - } - - for await (const chunk of processStreamWithTags({ - ...agentRuntimeImpl, - stream, - processors, - defaultProcessor, - onError, - onResponseChunk, - })) { - if (chunk.type === 'text') { - result.push(chunk.text) - } - } - - expect(events).toEqual([ - { - name: 'parse_error', - error: - 'Unknown tool undefined for tool call: {\n "param1": "value1"\n}', - type: 'error', - }, - ]) - expect(result).toEqual(streamChunks) - }) - - describe('real world examples', () => { - it('should handle within tool contents', async () => { - const toolName = 'write_file' - const streamChunks = [ - getToolCallString(toolName, { - path: 'backend/src/__tests__/xml-stream-parser.test.ts', - instructions: - 'Write comprehensive unit tests for the processStreamWithTags function', - content: - "import { describe, expect, it } from 'bun:test'\nimport { toolSchema } from '@codebuff/common/constants/tools'\nimport { processStreamWithTags } from '../xml-stream-parser'\n\ndescribe('processStreamWithTags', () => {\n async function* createMockStream(chunks: string[]) {\n for (const chunk of chunks) {\n yield chunk\n }\n }\n\n it('should handle basic tool call parsing', async () => {\n const streamChunks = [\n '\\n{\\n \"cb_tool_name\": \"test_tool\",\\n \"param1\": \"value1\"\\n}\\n',\n ]\n const stream = createMockStream(streamChunks)\n\n const events: any[] = []\n\n const processors = {\n test_tool: {\n params: ['param1'] as string[],\n onTagStart: (tagName: string, attributes: Record) => {\n events.push({ tagName, type: 'start', attributes })\n },\n onTagEnd: (tagName: string, params: Record) => {\n events.push({ tagName, type: 'end', params })\n },\n },\n }\n\n function onError(name: string, error: string) {\n events.push({ name, error })\n }\n\n const result = []\n for await (const chunk of processStreamWithTags(\n stream,\n processors,\n onError\n )) {\n result.push(chunk)\n }\n\n expect(events).toEqual([\n {\n tagName: 'test_tool',\n type: 'start',\n attributes: {},\n },\n {\n tagName: 'test_tool',\n type: 'end',\n params: { param1: 'value1' },\n },\n ])\n expect(result).toEqual(streamChunks)\n })\n\n it('should handle tool calls split across chunks', async () => {\n const streamChunks = [\n '\\n{\\n \"cb_tool_name\": \"test',\n '_tool\",\\n \"param1\": \"val',\n 'ue1\"\\n}\\n',\n ]\n const stream = createMockStream(streamChunks)\n\n const events: any[] = []\n\n const processors = {\n test_tool: {\n params: ['param1'] as string[],\n onTagStart: (tagName: string, attributes: Record) => {\n events.push({ tagName, type: 'start', attributes })\n },\n onTagEnd: (tagName: string, params: Record) => {\n events.push({ tagName, type: 'end', params })\n },\n },\n }\n\n function onError(name: string, error: string) {\n events.push({ name, error })\n }\n\n const result = []\n for await (const chunk of processStreamWithTags(\n stream,\n processors,\n onError\n )) {\n result.push(chunk)\n }\n\n expect(events).toEqual([\n {\n tagName: 'test_tool',\n type: 'start',\n attributes: {},\n },\n {\n tagName: 'test_tool',\n type: 'end',\n params: { param1: 'value1' },\n },\n ])\n expect(result).toEqual(streamChunks)\n })\n\n it('should handle multiple tool calls in sequence', async () => {\n const streamChunks = [\n '\\n{\\n \"cb_tool_name\": \"tool1\",\\n \"param1\": \"value1\"\\n}\\n',\n 'text between tools',\n '\\n{\\n \"cb_tool_name\": \"tool2\",\\n \"param2\": \"value2\"\\n}\\n',\n ]\n const stream = createMockStream(streamChunks)\n\n const events: any[] = []\n\n const processors = {\n tool1: {\n params: ['param1'] as string[],\n onTagStart: (tagName: string, attributes: Record) => {\n events.push({ tagName, type: 'start', attributes })\n },\n onTagEnd: (tagName: string, params: Record) => {\n events.push({ tagName, type: 'end', params })\n },\n },\n tool2: {\n params: ['param2'] as string[],\n onTagStart: (tagName: string, attributes: Record) => {\n events.push({ tagName, type: 'start', attributes })\n },\n onTagEnd: (tagName: string, params: Record) => {\n events.push({ tagName, type: 'end', params })\n },\n },\n }\n\n function onError(name: string, error: string) {\n events.push({ name, error })\n }\n\n const result = []\n for await (const chunk of processStreamWithTags(\n stream,\n processors,\n onError\n )) {\n result.push(chunk)\n }\n\n expect(events).toEqual([\n {\n tagName: 'tool1',\n type: 'start',\n attributes: {},\n },\n {\n tagName: 'tool1',\n type: 'end',\n params: { param1: 'value1' },\n },\n {\n tagName: 'tool2',\n type: 'start',\n attributes: {},\n },\n {\n tagName: 'tool2',\n type: 'end',\n params: { param2: 'value2' },\n },\n ])\n expect(result).toEqual(streamChunks)\n })\n\n it('should handle malformed JSON and call onError', async () => {\n const streamChunks = [\n '\\n{\\n \"cb_tool_name\": \"test_tool\",\\n \"param1\": invalid_json\\n}\\n',\n ]\n const stream = createMockStream(streamChunks)\n\n const events: any[] = []\n\n const processors = {\n test_tool: {\n params: ['param1'] as string[],\n onTagStart: (tagName: string, attributes: Record) => {\n events.push({ tagName, type: 'start', attributes })\n },\n onTagEnd: (tagName: string, params: Record) => {\n events.push({ tagName, type: 'end', params })\n },\n },\n }\n\n function onError(name: string, error: string) {\n events.push({ name, error, type: 'error' })\n }\n\n const result = []\n for await (const chunk of processStreamWithTags(\n stream,\n processors,\n onError\n )) {\n result.push(chunk)\n }\n\n expect(events).toEqual([\n {\n name: 'parse_error',\n error: expect.stringContaining('Unexpected token'),\n type: 'error',\n },\n ])\n expect(result).toEqual(streamChunks)\n })\n\n it('should handle unknown tool names and call onError', async () => {\n const streamChunks = [\n '\\n{\\n \"cb_tool_name\": \"unknown_tool\",\\n \"param1\": \"value1\"\\n}\\n',\n ]\n const stream = createMockStream(streamChunks)\n\n const events: any[] = []\n\n const processors = {\n test_tool: {\n params: ['param1'] as string[],\n onTagStart: (tagName: string, attributes: Record) => {\n events.push({ tagName, type: 'start', attributes })\n },\n onTagEnd: (tagName: string, params: Record) => {\n events.push({ tagName, type: 'end', params })\n },\n },\n }\n\n function onError(name: string, error: string) {\n events.push({ name, error, type: 'error' })\n }\n\n const result = []\n for await (const chunk of processStreamWithTags(\n stream,\n processors,\n onError\n )) {\n result.push(chunk)\n }\n\n expect(events).toEqual([\n {\n name: 'unknown_tool',\n error: 'Tool not found: unknown_tool',\n type: 'error',\n },\n ])\n expect(result).toEqual(streamChunks)\n })\n\n it('should handle tool calls with complex parameters', async () => {\n const streamChunks = [\n '\\n{\\n \"cb_tool_name\": \"complex_tool\",\\n \"array_param\": [\"item1\", \"item2\"],\\n \"object_param\": {\"nested\": \"value\"},\\n \"boolean_param\": true,\\n \"number_param\": 42\\n}\\n',\n ]\n const stream = createMockStream(streamChunks)\n\n const events: any[] = []\n\n const processors = {\n complex_tool: {\n params: ['array_param', 'object_param', 'boolean_param', 'number_param'] as string[],\n onTagStart: (tagName: string, attributes: Record) => {\n events.push({ tagName, type: 'start', attributes })\n },\n onTagEnd: (tagName: string, params: Record) => {\n events.push({ tagName, type: 'end', params })\n },\n },\n }\n\n function onError(name: string, error: string) {\n events.push({ name, error, type: 'error' })\n }\n\n const result = []\n for await (const chunk of processStreamWithTags(\n stream,\n processors,\n onError\n )) {\n result.push(chunk)\n }\n\n expect(events).toEqual([\n {\n tagName: 'complex_tool',\n type: 'start',\n attributes: {},\n },\n {\n tagName: 'complex_tool',\n type: 'end',\n params: {\n array_param: ['item1', 'item2'],\n object_param: { nested: 'value' },\n boolean_param: true,\n number_param: 42,\n },\n },\n ])\n expect(result).toEqual(streamChunks)\n })\n\n it('should handle text content mixed with tool calls', async () => {\n const streamChunks = [\n 'Some text before',\n '\\n{\\n \"cb_tool_name\": \"test_tool\",\\n \"param1\": \"value1\"\\n}\\n',\n 'Some text after',\n ]\n const stream = createMockStream(streamChunks)\n\n const events: any[] = []\n\n const processors = {\n test_tool: {\n params: ['param1'] as string[],\n onTagStart: (tagName: string, attributes: Record) => {\n events.push({ tagName, type: 'start', attributes })\n },\n onTagEnd: (tagName: string, params: Record) => {\n events.push({ tagName, type: 'end', params })\n },\n },\n }\n\n function onError(name: string, error: string) {\n events.push({ name, error, type: 'error' })\n }\n\n const result = []\n for await (const chunk of processStreamWithTags(\n stream,\n processors,\n onError\n )) {\n result.push(chunk)\n }\n\n expect(events).toEqual([\n {\n tagName: 'test_tool',\n type: 'start',\n attributes: {},\n },\n {\n tagName: 'test_tool',\n type: 'end',\n params: { param1: 'value1' },\n },\n ])\n expect(result).toEqual(streamChunks)\n })\n\n it('should handle incomplete tool calls at end of stream', async () => {\n const streamChunks = [\n '\\n{\\n \"cb_tool_name\": \"test_tool\",\\n \"param1\": \"value1\"\\n}',\n // Missing closing tag\n ]\n const stream = createMockStream(streamChunks)\n\n const events: any[] = []\n\n const processors = {\n test_tool: {\n params: ['param1'] as string[],\n onTagStart: (tagName: string, attributes: Record) => {\n events.push({ tagName, type: 'start', attributes })\n },\n onTagEnd: (tagName: string, params: Record) => {\n events.push({ tagName, type: 'end', params })\n },\n },\n }\n\n function onError(name: string, error: string) {\n events.push({ name, error, type: 'error' })\n }\n\n const result = []\n for await (const chunk of processStreamWithTags(\n stream,\n processors,\n onError\n )) {\n result.push(chunk)\n }\n\n // Should complete the tool call with the completion suffix\n expect(events).toEqual([\n {\n tagName: 'test_tool',\n type: 'start',\n attributes: {},\n },\n {\n tagName: 'test_tool',\n type: 'end',\n params: { param1: 'value1' },\n },\n ])\n \n // Should include the completion suffix in the result\n expect(result.join('')).toContain('\"codebuff_easp\": true')\n })\n\n it('should handle empty stream', async () => {\n const streamChunks: string[] = []\n const stream = createMockStream(streamChunks)\n\n const events: any[] = []\n\n const processors = {}\n\n function onError(name: string, error: string) {\n events.push({ name, error, type: 'error' })\n }\n\n const result = []\n for await (const chunk of processStreamWithTags(\n stream,\n processors,\n onError\n )) {\n result.push(chunk)\n }\n\n expect(events).toEqual([])\n expect(result).toEqual([])\n })\n\n it('should handle stream with only text content', async () => {\n const streamChunks = ['Just some text', ' with no tool calls']\n const stream = createMockStream(streamChunks)\n\n const events: any[] = []\n\n const processors = {}\n\n function onError(name: string, error: string) {\n events.push({ name, error, type: 'error' })\n }\n\n const result = []\n for await (const chunk of processStreamWithTags(\n stream,\n processors,\n onError\n )) {\n result.push(chunk)\n }\n\n expect(events).toEqual([])\n expect(result).toEqual(streamChunks)\n })\n\n it('should handle tool call with missing cb_tool_name', async () => {\n const streamChunks = [\n '\\n{\\n \"param1\": \"value1\"\\n}\\n',\n ]\n const stream = createMockStream(streamChunks)\n\n const events: any[] = []\n\n const processors = {\n test_tool: {\n params: ['param1'] as string[],\n onTagStart: (tagName: string, attributes: Record) => {\n events.push({ tagName, type: 'start', attributes })\n },\n onTagEnd: (tagName: string, params: Record) => {\n events.push({ tagName, type: 'end', params })\n },\n },\n }\n\n function onError(name: string, error: string) {\n events.push({ name, error, type: 'error' })\n }\n\n const result = []\n for await (const chunk of processStreamWithTags(\n stream,\n processors,\n onError\n )) {\n result.push(chunk)\n }\n\n expect(events).toEqual([\n {\n name: 'undefined',\n error: 'Tool not found: undefined',\n type: 'error',\n },\n ])\n expect(result).toEqual(streamChunks)\n })\n})", - }), - ] - - const stream = createMockStream(streamChunks) - - const events: any[] = [] - - const processors = { - write_file: { - params: ['path', 'instructions', 'content'] as string[], - onTagStart: (tagName: string, attributes: Record) => { - events.push({ tagName, type: 'start', attributes }) - }, - onTagEnd: (tagName: string, params: Record) => { - events.push({ tagName, type: 'end', params }) - }, - }, - } - - function onError(name: string, error: string) { - events.push({ name, error }) - } - - const result: string[] = [] - const responseChunks: any[] = [] - - function onResponseChunk(chunk: any) { - responseChunks.push(chunk) - } - - function defaultProcessor(toolName: string) { - return { - onTagStart: () => {}, - onTagEnd: () => {}, - } - } - - for await (const chunk of processStreamWithTags({ - ...agentRuntimeImpl, - stream, - processors, - defaultProcessor, - onError, - onResponseChunk, - })) { - if (chunk.type === 'text') { - result.push(chunk.text) - } - } - - expect(events).toEqual([ - { - attributes: {}, - tagName: 'write_file', - type: 'start', - }, - { - params: { - content: expect.stringContaining(''), - instructions: - 'Write comprehensive unit tests for the processStreamWithTags function', - path: 'backend/src/__tests__/xml-stream-parser.test.ts', - }, - tagName: 'write_file', - type: 'end', - }, - ]) - expect(result).toEqual(streamChunks) - }) }) }) diff --git a/packages/agent-runtime/src/__tests__/web-search-tool.test.ts b/packages/agent-runtime/src/__tests__/web-search-tool.test.ts index 417f57819..34becae01 100644 --- a/packages/agent-runtime/src/__tests__/web-search-tool.test.ts +++ b/packages/agent-runtime/src/__tests__/web-search-tool.test.ts @@ -2,7 +2,6 @@ import * as bigquery from '@codebuff/bigquery' import * as analytics from '@codebuff/common/analytics' import { TEST_USER_ID } from '@codebuff/common/old-constants' import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime' -import { getToolCallString } from '@codebuff/common/tools/utils' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { success } from '@codebuff/common/util/error' import { @@ -17,7 +16,7 @@ import { } from 'bun:test' import { disableLiveUserInputCheck } from '../live-user-inputs' -import { mockFileContext } from './test-utils' +import { createToolCallChunk, mockFileContext } from './test-utils' import researcherAgent from '../../../../.agents/researcher/researcher' import * as webApi from '../llm-api/codebuff-web-api' import { runAgentStep } from '../run-agent-step' @@ -34,13 +33,12 @@ let runAgentStepBaseParams: ParamsExcluding< typeof runAgentStep, 'localAgentTemplates' | 'agentState' | 'prompt' > -function mockAgentStream(content: string | string[]) { +import type { StreamChunk } from '@codebuff/common/types/contracts/llm' + +function mockAgentStream(chunks: StreamChunk[]) { runAgentStepBaseParams.promptAiSdkStream = async function* ({}) { - if (typeof content === 'string') { - content = [content] - } - for (const chunk of content) { - yield { type: 'text' as const, text: chunk } + for (const chunk of chunks) { + yield chunk } return 'mock-message-id' } @@ -61,23 +59,22 @@ describe('web_search tool with researcher agent (via web API facade)', () => { runAgentStepBaseParams = { ...agentRuntimeImpl, + additionalToolDefinitions: () => Promise.resolve({}), agentType: 'researcher', ancestorRunIds: [], clientSessionId: 'test-session', fileContext: mockFileContext, fingerprintId: 'test-fingerprint', + onResponseChunk: () => {}, repoId: undefined, repoUrl: undefined, runId: 'test-run-id', signal: new AbortController().signal, spawnParams: undefined, system: 'Test system prompt', - textOverride: null, + tools: {}, userId: TEST_USER_ID, userInputId: 'test-input', - - additionalToolDefinitions: () => Promise.resolve({}), - onResponseChunk: () => {}, } // Mock analytics and tracing @@ -116,11 +113,10 @@ describe('web_search tool with researcher agent (via web API facade)', () => { result: mockSearchResult, }) - const mockResponse = - getToolCallString('web_search', { query: 'test query' }) + - getToolCallString('end_turn', {}) - - mockAgentStream(mockResponse) + mockAgentStream([ + createToolCallChunk('web_search', { query: 'test query' }), + createToolCallChunk('end_turn', {}), + ]) const sessionState = getInitialSessionState(mockFileContextWithAgents) const agentState = { @@ -151,11 +147,10 @@ describe('web_search tool with researcher agent (via web API facade)', () => { result: mockSearchResult, }) - const mockResponse = - getToolCallString('web_search', { query: 'Next.js 15 new features' }) + - getToolCallString('end_turn', {}) - - mockAgentStream(mockResponse) + mockAgentStream([ + createToolCallChunk('web_search', { query: 'Next.js 15 new features' }), + createToolCallChunk('end_turn', {}), + ]) const sessionState = getInitialSessionState(mockFileContextWithAgents) const agentState = { @@ -188,13 +183,13 @@ describe('web_search tool with researcher agent (via web API facade)', () => { result: 'Deep result', }) - const mockResponse = - getToolCallString('web_search', { + mockAgentStream([ + createToolCallChunk('web_search', { query: 'RSC tutorial', depth: 'deep', - }) + getToolCallString('end_turn', {}) - - mockAgentStream(mockResponse) + }), + createToolCallChunk('end_turn', {}), + ]) const sessionState = getInitialSessionState(mockFileContextWithAgents) const agentState = { @@ -222,11 +217,10 @@ describe('web_search tool with researcher agent (via web API facade)', () => { const msg = 'No search results found for "very obscure"' spyOn(webApi, 'callWebSearchAPI').mockResolvedValue({ error: msg }) - const mockResponse = - getToolCallString('web_search', { query: 'very obscure' }) + - getToolCallString('end_turn', {}) - - mockAgentStream(mockResponse) + mockAgentStream([ + createToolCallChunk('web_search', { query: 'very obscure' }), + createToolCallChunk('end_turn', {}), + ]) const sessionState = getInitialSessionState(mockFileContextWithAgents) const agentState = { @@ -259,11 +253,10 @@ describe('web_search tool with researcher agent (via web API facade)', () => { error: 'Linkup API timeout', }) - const mockResponse = - getToolCallString('web_search', { query: 'test query' }) + - getToolCallString('end_turn', {}) - - mockAgentStream(mockResponse) + mockAgentStream([ + createToolCallChunk('web_search', { query: 'test query' }), + createToolCallChunk('end_turn', {}), + ]) const sessionState = getInitialSessionState(mockFileContextWithAgents) const agentState = { @@ -296,11 +289,10 @@ describe('web_search tool with researcher agent (via web API facade)', () => { throw 'String error' }) - const mockResponse = - getToolCallString('web_search', { query: 'test query' }) + - getToolCallString('end_turn', {}) - - mockAgentStream(mockResponse) + mockAgentStream([ + createToolCallChunk('web_search', { query: 'test query' }), + createToolCallChunk('end_turn', {}), + ]) const sessionState = getInitialSessionState(mockFileContextWithAgents) const agentState = { @@ -334,11 +326,10 @@ describe('web_search tool with researcher agent (via web API facade)', () => { result: mockSearchResult, }) - const mockResponse = - getToolCallString('web_search', { query: 'test formatting' }) + - getToolCallString('end_turn', {}) - - mockAgentStream(mockResponse) + mockAgentStream([ + createToolCallChunk('web_search', { query: 'test formatting' }), + createToolCallChunk('end_turn', {}), + ]) const sessionState = getInitialSessionState(mockFileContextWithAgents) const agentState = { @@ -374,11 +365,10 @@ describe('web_search tool with researcher agent (via web API facade)', () => { creditsUsed: mockCreditsUsed, }) - const mockResponse = - getToolCallString('web_search', { query: 'test query' }) + - getToolCallString('end_turn', {}) - - mockAgentStream(mockResponse) + mockAgentStream([ + createToolCallChunk('web_search', { query: 'test query' }), + createToolCallChunk('end_turn', {}), + ]) const sessionState = getInitialSessionState(mockFileContextWithAgents) const agentState = { diff --git a/packages/agent-runtime/src/llm-api/relace-api.ts b/packages/agent-runtime/src/llm-api/relace-api.ts index 8ba89da43..e9a01f358 100644 --- a/packages/agent-runtime/src/llm-api/relace-api.ts +++ b/packages/agent-runtime/src/llm-api/relace-api.ts @@ -19,9 +19,10 @@ export async function promptRelaceAI( const { initialCode, editSnippet, instructions, promptAiSdk, logger } = params try { + const { tools, ...rest } = params // const model = 'relace-apply-2.5-lite' const content = await promptAiSdk({ - ...params, + ...rest, model: 'relace/relace-apply-3', messages: [ userMessage( diff --git a/packages/agent-runtime/src/prompt-agent-stream.ts b/packages/agent-runtime/src/prompt-agent-stream.ts index 3447b5948..4a5272f83 100644 --- a/packages/agent-runtime/src/prompt-agent-stream.ts +++ b/packages/agent-runtime/src/prompt-agent-stream.ts @@ -12,6 +12,7 @@ import type { Logger } from '@codebuff/common/types/contracts/logger' import type { ParamsOf } from '@codebuff/common/types/function-params' import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { OpenRouterProviderOptions } from '@codebuff/internal/openrouter-ai-sdk' +import type { ToolSet } from 'ai' export const getAgentStreamFromTemplate = (params: { agentId?: string @@ -20,12 +21,13 @@ export const getAgentStreamFromTemplate = (params: { fingerprintId: string includeCacheControl?: boolean liveUserInputRecord: UserInputRecord + localAgentTemplates: Record logger: Logger messages: Message[] runId: string sessionConnections: SessionRecord template: AgentTemplate - textOverride: string | null + tools: ToolSet userId: string | undefined userInputId: string @@ -41,12 +43,13 @@ export const getAgentStreamFromTemplate = (params: { fingerprintId, includeCacheControl, liveUserInputRecord, + localAgentTemplates, logger, messages, runId, sessionConnections, template, - textOverride, + tools, userId, userInputId, @@ -56,14 +59,6 @@ export const getAgentStreamFromTemplate = (params: { trackEvent, } = params - if (textOverride !== null) { - async function* stream(): ReturnType { - yield { type: 'text', text: textOverride!, agentId } - return crypto.randomUUID() - } - return stream() - } - if (!template) { throw new Error('Agent template is null/undefined') } @@ -71,24 +66,28 @@ export const getAgentStreamFromTemplate = (params: { const { model } = template const aiSdkStreamParams: ParamsOf = { + agentId, apiKey, - runId, + clientSessionId, + fingerprintId, + includeCacheControl, + logger, + liveUserInputRecord, + localAgentTemplates, + maxOutputTokens: 32_000, + maxRetries: 3, messages, model, + runId, + sessionConnections, + spawnableAgents: template.spawnableAgents, stopSequences: [globalStopSequence], - clientSessionId, - fingerprintId, - userInputId, + tools, userId, - maxOutputTokens: 32_000, + userInputId, + onCostCalculated, - includeCacheControl, - agentId, - maxRetries: 3, sendAction, - liveUserInputRecord, - sessionConnections, - logger, trackEvent, } diff --git a/packages/agent-runtime/src/run-agent-step.ts b/packages/agent-runtime/src/run-agent-step.ts index 9ad9b8c95..2ceb1ec59 100644 --- a/packages/agent-runtime/src/run-agent-step.ts +++ b/packages/agent-runtime/src/run-agent-step.ts @@ -13,8 +13,10 @@ import { getAgentStreamFromTemplate } from './prompt-agent-stream' import { runProgrammaticStep } from './run-programmatic-step' import { additionalSystemPrompts } from './system-prompt/prompts' import { getAgentTemplate } from './templates/agent-registry' +import { buildAgentToolSet } from './templates/prompts' import { getAgentPrompt } from './templates/strings' -import { processStreamWithTools } from './tools/stream-parser' +import { getToolSet } from './tools/prompts' +import { processStream } from './tools/stream-parser' import { getAgentOutput } from './util/agent-output' import { withSystemInstructionTags, @@ -57,6 +59,7 @@ import type { CustomToolDefinitions, ProjectFileContext, } from '@codebuff/common/util/file' +import type { ToolSet } from 'ai' async function additionalToolDefinitions( params: { @@ -106,7 +109,7 @@ export const runAgentStep = async ( trackEvent: TrackEventFn promptAiSdk: PromptAiSdkFn } & ParamsExcluding< - typeof processStreamWithTools, + typeof processStream, | 'agentContext' | 'agentState' | 'agentStepId' @@ -291,6 +294,7 @@ export const runAgentStep = async ( agentContext, systemTokens, agentTemplate, + tools: params.tools, }, `Start agent ${agentType} step ${iterationNum} (${userInputId}${prompt ? ` - Prompt: ${prompt.slice(0, 20)}` : ''})`, ) @@ -337,6 +341,7 @@ export const runAgentStep = async ( let fullResponse = '' const toolResults: ToolMessage[] = [] + // Raw stream from AI SDK const stream = getAgentStreamFromTemplate({ ...params, agentId: agentState.parentId ? agentState.agentId : undefined, @@ -349,10 +354,11 @@ export const runAgentStep = async ( const { fullResponse: fullResponseAfterStream, fullResponseChunks, + hadToolCallError, messageId, toolCalls, toolResults: newToolResults, - } = await processStreamWithTools({ + } = await processStream({ ...params, agentContext, agentState, @@ -410,7 +416,8 @@ export const runAgentStep = async ( ).length === 0 && toolResults.filter( (result) => !TOOLS_WHICH_WONT_FORCE_NEXT_STEP.includes(result.toolName), - ).length === 0 + ).length === 0 && + !hadToolCallError // Tool call errors should also force another step so the agent can retry const hasTaskCompleted = toolCalls.some( (call) => @@ -468,37 +475,36 @@ export const runAgentStep = async ( export async function loopAgentSteps( params: { - userInputId: string - agentType: AgentTemplateType + addAgentStep: AddAgentStepFn agentState: AgentState - prompt: string | undefined + agentType: AgentTemplateType + clearUserPromptMessagesAfterResponse?: boolean + clientSessionId: string content?: Array - spawnParams: Record | undefined fileContext: ProjectFileContext + finishAgentRun: FinishAgentRunFn localAgentTemplates: Record - clearUserPromptMessagesAfterResponse?: boolean + logger: Logger parentSystemPrompt?: string + parentTools?: ToolSet + prompt: string | undefined signal: AbortSignal - - userId: string | undefined - clientSessionId: string - + spawnParams: Record | undefined startAgentRun: StartAgentRunFn - finishAgentRun: FinishAgentRunFn - addAgentStep: AddAgentStepFn - logger: Logger + userId: string | undefined + userInputId: string } & ParamsExcluding & ParamsExcluding< typeof runProgrammaticStep, - | 'runId' | 'agentState' - | 'template' + | 'onCostCalculated' | 'prompt' - | 'toolCallParams' - | 'stepsComplete' + | 'runId' | 'stepNumber' + | 'stepsComplete' | 'system' - | 'onCostCalculated' + | 'template' + | 'toolCallParams' > & ParamsExcluding & ParamsExcluding< @@ -526,7 +532,7 @@ export async function loopAgentSteps( | 'runId' | 'spawnParams' | 'system' - | 'textOverride' + | 'tools' > & ParamsExcluding< AddAgentStepFn, @@ -543,23 +549,24 @@ export async function loopAgentSteps( output: AgentOutput }> { const { - userInputId, - agentType, + addAgentStep, agentState, - prompt, + agentType, + clearUserPromptMessagesAfterResponse = true, + clientSessionId, content, - spawnParams, fileContext, + finishAgentRun, localAgentTemplates, - userId, - clientSessionId, - clearUserPromptMessagesAfterResponse = true, + logger, parentSystemPrompt, + parentTools, + prompt, signal, + spawnParams, startAgentRun, - finishAgentRun, - addAgentStep, - logger, + userId, + userInputId, } = params const agentTemplate = await getAgentTemplate({ @@ -591,12 +598,17 @@ export async function loopAgentSteps( agentState.runId = runId let cachedAdditionalToolDefinitions: CustomToolDefinitions | undefined + // Use parent's tools for prompt caching when inheritParentSystemPrompt is true + const useParentTools = + agentTemplate.inheritParentSystemPrompt && parentTools !== undefined + // Initialize message history with user prompt and instructions on first iteration const instructionsPrompt = await getAgentPrompt({ ...params, agentTemplate, promptType: { type: 'instructionsPrompt' }, agentTemplates: localAgentTemplates, + useParentTools, additionalToolDefinitions: async () => { if (!cachedAdditionalToolDefinitions) { cachedAdditionalToolDefinitions = await additionalToolDefinitions({ @@ -631,6 +643,31 @@ export async function loopAgentSteps( }, })) ?? '' + // Build agent tools (agents as direct tool calls) for non-inherited tools + const agentTools = useParentTools + ? {} + : await buildAgentToolSet({ + ...params, + spawnableAgents: agentTemplate.spawnableAgents, + agentTemplates: localAgentTemplates, + }) + + const tools = useParentTools + ? parentTools + : await getToolSet({ + toolNames: agentTemplate.toolNames, + additionalToolDefinitions: async () => { + if (!cachedAdditionalToolDefinitions) { + cachedAdditionalToolDefinitions = await additionalToolDefinitions({ + ...params, + agentTemplate, + }) + } + return cachedAdditionalToolDefinitions + }, + agentTools, + }) + const hasUserMessage = Boolean( prompt || (spawnParams && Object.keys(spawnParams).length > 0) || @@ -704,26 +741,26 @@ export async function loopAgentSteps( const startTime = new Date() // 1. Run programmatic step first if it exists - let textOverride = null let n: number | undefined = undefined if (agentTemplate.handleSteps) { const programmaticResult = await runProgrammaticStep({ ...params, - runId, + agentState: currentAgentState, - template: agentTemplate, localAgentTemplates, - prompt: currentPrompt, - toolCallParams: currentParams, - system, - stepsComplete: shouldEndTurn, - stepNumber: totalSteps, nResponses, onCostCalculated: async (credits: number) => { agentState.creditsUsed += credits agentState.directCreditsUsed += credits }, + prompt: currentPrompt, + runId, + stepNumber: totalSteps, + stepsComplete: shouldEndTurn, + system, + template: agentTemplate, + toolCallParams: currentParams, }) const { agentState: programmaticAgentState, @@ -731,7 +768,6 @@ export async function loopAgentSteps( stepNumber, generateN, } = programmaticResult - textOverride = programmaticResult.textOverride n = generateN currentAgentState = programmaticAgentState @@ -788,6 +824,15 @@ export async function loopAgentSteps( nResponses: generatedResponses, } = await runAgentStep({ ...params, + + agentState: currentAgentState, + n, + prompt: currentPrompt, + runId, + spawnParams: currentParams, + system, + tools, + additionalToolDefinitions: async () => { if (!cachedAdditionalToolDefinitions) { cachedAdditionalToolDefinitions = await additionalToolDefinitions({ @@ -797,13 +842,6 @@ export async function loopAgentSteps( } return cachedAdditionalToolDefinitions }, - textOverride: textOverride, - runId, - agentState: currentAgentState, - prompt: currentPrompt, - spawnParams: currentParams, - system, - n, }) if (newAgentState.runId) { @@ -865,7 +903,10 @@ export async function loopAgentSteps( ) // Re-throw NetworkError and PaymentRequiredError to allow SDK retry wrapper to handle it - if (error instanceof Error && (error.name === 'NetworkError' || error.name === 'PaymentRequiredError')) { + if ( + error instanceof Error && + (error.name === 'NetworkError' || error.name === 'PaymentRequiredError') + ) { throw error } diff --git a/packages/agent-runtime/src/run-programmatic-step.ts b/packages/agent-runtime/src/run-programmatic-step.ts index f8cda7edf..ada1ed101 100644 --- a/packages/agent-runtime/src/run-programmatic-step.ts +++ b/packages/agent-runtime/src/run-programmatic-step.ts @@ -1,13 +1,17 @@ -import { getToolCallString } from '@codebuff/common/tools/utils' import { getErrorObject } from '@codebuff/common/util/error' import { assistantMessage } from '@codebuff/common/util/messages' import { cloneDeep } from 'lodash' import { executeToolCall } from './tools/tool-executor' +import { parseTextWithToolCalls } from './util/parse-tool-calls-from-text' + +import type { ParsedSegment } from './util/parse-tool-calls-from-text' import type { FileProcessingState } from './tools/handlers/tool/write-file' import type { ExecuteToolCallParams } from './tools/tool-executor' import type { CodebuffToolCall } from '@codebuff/common/tools/list' +import { HandleStepsYieldValueSchema } from '@codebuff/common/types/agent-template' + import type { AgentTemplate, StepGenerator, @@ -21,10 +25,12 @@ import type { AddAgentStepFn } from '@codebuff/common/types/contracts/database' import type { Logger } from '@codebuff/common/types/contracts/logger' import type { ParamsExcluding } from '@codebuff/common/types/function-params' import type { ToolMessage } from '@codebuff/common/types/messages/codebuff-message' -import type { ToolResultOutput } from '@codebuff/common/types/messages/content-part' +import type { + ToolCallPart, + ToolResultOutput, +} from '@codebuff/common/types/messages/content-part' import type { PrintModeEvent } from '@codebuff/common/types/print-mode' import type { AgentState } from '@codebuff/common/types/session-state' - // Maintains generator state for all agents. Generator state can't be serialized, so we store it in memory. const runIdToGenerator: Record = {} export const runIdToStepAll: Set = new Set() @@ -40,26 +46,26 @@ export function clearAgentGeneratorCache(params: { logger: Logger }) { // Function to handle programmatic agents export async function runProgrammaticStep( params: { + addAgentStep: AddAgentStepFn agentState: AgentState - template: AgentTemplate + clientSessionId: string + fingerprintId: string + handleStepsLogChunk: HandleStepsLogChunkFn + localAgentTemplates: Record + logger: Logger + nResponses?: string[] + onResponseChunk: (chunk: string | PrintModeEvent) => void prompt: string | undefined - toolCallParams: Record | undefined - system: string | undefined - userId: string | undefined repoId: string | undefined repoUrl: string | undefined - userInputId: string - fingerprintId: string - clientSessionId: string - onResponseChunk: (chunk: string | PrintModeEvent) => void - localAgentTemplates: Record - stepsComplete: boolean stepNumber: number - handleStepsLogChunk: HandleStepsLogChunkFn + stepsComplete: boolean + template: AgentTemplate + toolCallParams: Record | undefined sendAction: SendActionFn - addAgentStep: AddAgentStepFn - logger: Logger - nResponses?: string[] + system: string | undefined + userId: string | undefined + userInputId: string } & Omit< ExecuteToolCallParams, | 'toolName' @@ -89,7 +95,6 @@ export async function runProgrammaticStep( >, ): Promise<{ agentState: AgentState - textOverride: string | null endTurn: boolean stepNumber: number generateN?: number @@ -170,7 +175,7 @@ export async function runProgrammaticStep( // Clear the STEP_ALL mode. Stepping can continue if handleSteps doesn't return. runIdToStepAll.delete(agentState.runId) } else { - return { agentState, textOverride: null, endTurn: false, stepNumber } + return { agentState, endTurn: false, stepNumber } } } @@ -205,7 +210,6 @@ export async function runProgrammaticStep( let toolResult: ToolResultOutput[] | undefined = undefined let endTurn = false - let textOverride: string | null = null let generateN: number | undefined = undefined let startTime = new Date() @@ -232,6 +236,16 @@ export async function runProgrammaticStep( endTurn = true break } + + // Validate the yield value from handleSteps + const parseResult = HandleStepsYieldValueSchema.safeParse(result.value) + if (!parseResult.success) { + throw new Error( + `Invalid yield value from handleSteps in agent ${template.id}: ${parseResult.error.message}. ` + + `Received: ${JSON.stringify(result.value)}`, + ) + } + if (result.value === 'STEP') { break } @@ -241,8 +255,26 @@ export async function runProgrammaticStep( } if ('type' in result.value && result.value.type === 'STEP_TEXT') { - textOverride = result.value.text - break + // Parse text and tool calls, preserving interleaved order + const segments = parseTextWithToolCalls(result.value.text) + + if (segments.length > 0) { + // Execute segments (text and tool calls) in order + toolResult = await executeSegmentsArray(segments, { + ...params, + agentContext, + agentStepId, + agentTemplate: template, + agentState, + fileProcessingState, + fullResponse: '', + previousToolCallFinished: Promise.resolve(), + toolCalls, + toolResults, + onResponseChunk, + }) + } + continue } if ('type' in result.value && result.value.type === 'GENERATE_N') { @@ -254,121 +286,22 @@ export async function runProgrammaticStep( } // Process tool calls yielded by the generator - const toolCallWithoutId = result.value - const toolCallId = crypto.randomUUID() - const toolCall = { - ...toolCallWithoutId, - toolCallId, - } as CodebuffToolCall & { - includeToolCall?: boolean - } + const toolCall = result.value as ToolCallToExecute - // Note: We don't check if the tool is available for the agent template anymore. - // You can run any tool from handleSteps now! - // if (!template.toolNames.includes(toolCall.toolName)) { - // throw new Error( - // `Tool ${toolCall.toolName} is not available for agent ${template.id}. Available tools: ${template.toolNames.join(', ')}`, - // ) - // } - - const excludeToolFromMessageHistory = toolCall?.includeToolCall === false - // Add assistant message with the tool call before executing it - if (!excludeToolFromMessageHistory) { - const toolCallString = getToolCallString( - toolCall.toolName, - toolCall.input, - ) - onResponseChunk(toolCallString) - agentState.messageHistory.push(assistantMessage(toolCallString)) - // Optional call handles both top-level and nested agents - sendSubagentChunk({ - userInputId, - agentId: agentState.agentId, - agentType: agentState.agentType!, - chunk: toolCallString, - forwardToPrompt: !agentState.parentId, - }) - } - - // Execute the tool synchronously and get the result immediately - // Wrap onResponseChunk to add parentAgentId to nested agent events - await executeToolCall({ + toolResult = await executeSingleToolCall(toolCall, { ...params, - toolName: toolCall.toolName, - input: toolCall.input, - autoInsertEndStepParam: true, - excludeToolFromMessageHistory, - fromHandleSteps: true, - agentContext, agentStepId, agentTemplate: template, + agentState, fileProcessingState, fullResponse: '', previousToolCallFinished: Promise.resolve(), - toolCallId, toolCalls, toolResults, - toolResultsToAddAfterStream: [], - - onResponseChunk: (chunk: string | PrintModeEvent) => { - if (typeof chunk === 'string') { - onResponseChunk(chunk) - return - } - - // Only add parentAgentId if this programmatic agent has a parent (i.e., it's nested) - // This ensures we don't add parentAgentId to top-level spawns - if (agentState.parentId) { - const parentAgentId = agentState.agentId - - switch (chunk.type) { - case 'subagent_start': - case 'subagent_finish': - if (!chunk.parentAgentId) { - onResponseChunk({ - ...chunk, - parentAgentId, - }) - return - } - break - case 'tool_call': - case 'tool_result': { - if (!chunk.parentAgentId) { - const debugPayload = - chunk.type === 'tool_call' - ? { - eventType: chunk.type, - agentId: chunk.agentId, - parentId: parentAgentId, - } - : { - eventType: chunk.type, - parentId: parentAgentId, - } - onResponseChunk({ - ...chunk, - parentAgentId, - }) - return - } - break - } - default: - break - } - } - - // For other events or top-level spawns, send as-is - onResponseChunk(chunk) - }, + onResponseChunk, }) - // Get the latest tool result - const latestToolResult = toolResults[toolResults.length - 1] - toolResult = latestToolResult?.content - if (agentState.runId) { await addAgentStep({ ...params, @@ -393,7 +326,6 @@ export async function runProgrammaticStep( return { agentState, - textOverride, endTurn, stepNumber, generateN, @@ -437,7 +369,6 @@ export async function runProgrammaticStep( return { agentState, - textOverride: null, endTurn, stepNumber, generateN: undefined, @@ -462,3 +393,170 @@ export const getPublicAgentState = ( output, } } + +/** + * Represents a tool call to be executed. + * Can optionally include `includeToolCall: false` to exclude from message history. + */ +type ToolCallToExecute = { + toolName: string + input: Record + includeToolCall?: boolean +} + +/** + * Parameters for executing an array of tool calls. + */ +type ExecuteToolCallsArrayParams = Omit< + ExecuteToolCallParams, + | 'toolName' + | 'input' + | 'autoInsertEndStepParam' + | 'excludeToolFromMessageHistory' + | 'toolCallId' + | 'toolResultsToAddAfterStream' +> & { + agentState: AgentState + onResponseChunk: (chunk: string | PrintModeEvent) => void +} + +/** + * Executes a single tool call. + * Adds the tool call as an assistant message and then executes it. + * + * @returns The tool result from the executed tool call. + */ +async function executeSingleToolCall( + toolCallToExecute: ToolCallToExecute, + params: ExecuteToolCallsArrayParams, +): Promise { + const { agentState, onResponseChunk, toolResults } = params + + // Note: We don't check if the tool is available for the agent template anymore. + // You can run any tool from handleSteps now! + // if (!template.toolNames.includes(toolCall.toolName)) { + // throw new Error( + // `Tool ${toolCall.toolName} is not available for agent ${template.id}. Available tools: ${template.toolNames.join(', ')}`, + // ) + // } + + const toolCallId = crypto.randomUUID() + const excludeToolFromMessageHistory = + toolCallToExecute.includeToolCall === false + + // Add assistant message with the tool call before executing it + if (!excludeToolFromMessageHistory) { + const toolCallPart: ToolCallPart = { + type: 'tool-call', + toolCallId, + toolName: toolCallToExecute.toolName, + input: toolCallToExecute.input, + } + // onResponseChunk({ + // ...toolCallPart, + // type: 'tool_call', + // agentId: agentState.agentId, + // parentAgentId: agentState.parentId, + // }) + // NOTE(James): agentState.messageHistory is readonly for some reason (?!). Recreating the array is a workaround. + agentState.messageHistory = [...agentState.messageHistory] + agentState.messageHistory.push(assistantMessage(toolCallPart)) + // Optional call handles both top-level and nested agents + // sendSubagentChunk({ + // userInputId, + // agentId: agentState.agentId, + // agentType: agentState.agentType!, + // chunk: toolCallString, + // forwardToPrompt: !agentState.parentId, + // }) + } + + // Execute the tool call + await executeToolCall({ + ...params, + toolName: toolCallToExecute.toolName as any, + input: toolCallToExecute.input, + autoInsertEndStepParam: true, + excludeToolFromMessageHistory, + fromHandleSteps: true, + toolCallId, + toolResultsToAddAfterStream: [], + + onResponseChunk: (chunk: string | PrintModeEvent) => { + if (typeof chunk === 'string') { + onResponseChunk(chunk) + return + } + + // Only add parentAgentId if this programmatic agent has a parent (i.e., it's nested) + // This ensures we don't add parentAgentId to top-level spawns + if (agentState.parentId) { + const parentAgentId = agentState.agentId + + switch (chunk.type) { + case 'subagent_start': + case 'subagent_finish': + if (!chunk.parentAgentId) { + onResponseChunk({ + ...chunk, + parentAgentId, + }) + return + } + break + case 'tool_call': + case 'tool_result': { + if (!chunk.parentAgentId) { + onResponseChunk({ + ...chunk, + parentAgentId, + }) + return + } + break + } + default: + break + } + } + + // For other events or top-level spawns, send as-is + onResponseChunk(chunk) + }, + }) + + // Get the latest tool result + return toolResults[toolResults.length - 1]?.content +} + +/** + * Executes an array of segments (text and tool calls) sequentially. + * Text segments are added as assistant messages. + * Tool calls are added as assistant messages and then executed. + * + * @returns The tool result from the last executed tool call. + */ +async function executeSegmentsArray( + segments: ParsedSegment[], + params: ExecuteToolCallsArrayParams, +): Promise { + const { agentState } = params + + let toolResults: ToolResultOutput[] = [] + + for (const segment of segments) { + if (segment.type === 'text') { + // Add text as an assistant message + agentState.messageHistory = [...agentState.messageHistory] + agentState.messageHistory.push(assistantMessage(segment.text)) + } else { + // Handle tool call segment + const toolResult = await executeSingleToolCall(segment, params) + if (toolResult) { + toolResults.push(...toolResult) + } + } + } + + return toolResults +} diff --git a/packages/agent-runtime/src/templates/prompts.ts b/packages/agent-runtime/src/templates/prompts.ts index e1cb77d0a..0e83ea8e9 100644 --- a/packages/agent-runtime/src/templates/prompts.ts +++ b/packages/agent-runtime/src/templates/prompts.ts @@ -1,14 +1,169 @@ import { getAgentTemplate } from './agent-registry' import { buildArray } from '@codebuff/common/util/array' import { schemaToJsonStr } from '@codebuff/common/util/zod-schema' +import { z } from 'zod/v4' import type { AgentTemplate } from '@codebuff/common/types/agent-template' import type { Logger } from '@codebuff/common/types/contracts/logger' import type { ParamsExcluding } from '@codebuff/common/types/function-params' import type { AgentTemplateType } from '@codebuff/common/types/session-state' -import { getToolCallString } from '@codebuff/common/tools/utils' +import type { ToolSet } from 'ai' -export async function buildSpawnableAgentsDescription( +/** + * Gets the short agent name from a fully qualified agent ID. + * E.g., 'codebuff/file-picker@1.0.0' -> 'file-picker' + */ +export function getAgentShortName(agentType: AgentTemplateType): string { + const withoutVersion = agentType.split('@')[0] + const parts = withoutVersion.split('/') + return parts[parts.length - 1] +} + +/** + * Builds a flat input schema for an agent tool by combining prompt and params. + * E.g., { prompt?: string, ...paramsFields } + */ +export function buildAgentFlatInputSchema( + agentTemplate: AgentTemplate, +): z.ZodType { + const { inputSchema } = agentTemplate + + // Start with an empty object schema + let schemaFields: Record = {} + + // Add prompt field if defined + if (inputSchema?.prompt) { + schemaFields.prompt = inputSchema.prompt.optional() + } + + // Merge params fields directly into the schema (flat structure) + if (inputSchema?.params) { + // Try to get the shape from the params schema directly if it's a ZodObject + // This preserves the full nested structure instead of converting to z.any() + const paramsShape = getZodObjectShape(inputSchema.params) + + if (paramsShape) { + // We have the original Zod shape, use it directly + for (const [key, fieldSchema] of Object.entries(paramsShape)) { + // Skip if we already have a prompt field + if (key === 'prompt') continue + schemaFields[key] = fieldSchema as z.ZodType + } + } + } + + return z + .object(schemaFields) + .describe( + agentTemplate.spawnerPrompt || + `Spawn the ${agentTemplate.displayName} agent`, + ) +} + +/** + * Extracts the shape from a Zod schema if it's a ZodObject. + * Handles wrapped types like ZodOptional, ZodNullable, ZodDefault, etc. + */ +function getZodObjectShape( + schema: z.ZodType, +): Record | null { + // ZodObject has a public .shape property in Zod v4 + if ( + 'shape' in schema && + typeof schema.shape === 'object' && + schema.shape !== null + ) { + return schema.shape as Record + } + + // Handle wrapped types (optional, nullable, default, etc.) via internal def + const def = (schema as any)?._zod?.def + if (def?.inner) { + return getZodObjectShape(def.inner) + } + + return null +} + +/** + * Builds AI SDK tool definitions for spawnable agents. + * These tools allow the model to call agents directly as tool calls. + */ +export async function buildAgentToolSet( + params: { + spawnableAgents: AgentTemplateType[] + agentTemplates: Record + logger: Logger + } & ParamsExcluding< + typeof getAgentTemplate, + 'agentId' | 'localAgentTemplates' + >, +): Promise { + const { spawnableAgents, agentTemplates } = params + + const toolSet: ToolSet = {} + + for (const agentType of spawnableAgents) { + const agentTemplate = await getAgentTemplate({ + ...params, + agentId: agentType, + localAgentTemplates: agentTemplates, + }) + + if (!agentTemplate) continue + + const shortName = getAgentShortName(agentType) + const inputSchema = buildAgentFlatInputSchema(agentTemplate) + + // Use the same structure as other tools in toolParams + toolSet[shortName] = { + description: + agentTemplate.spawnerPrompt || + `Spawn the ${agentTemplate.displayName} agent`, + inputSchema, + } + } + + return toolSet +} + +/** + * Builds the description of a single agent for the system prompt. + */ +function buildSingleAgentDescription( + agentType: AgentTemplateType, + agentTemplate: AgentTemplate | null, +): string { + if (!agentTemplate) { + // Fallback for unknown agents + return `- ${agentType}: Dynamic agent (description not available) +prompt: {"description": "A coding task to complete", "type": "string"} +params: None` + } + + const { inputSchema } = agentTemplate + const inputSchemaStr = inputSchema + ? [ + `prompt: ${schemaToJsonStr(inputSchema.prompt)}`, + `params: ${schemaToJsonStr(inputSchema.params)}`, + ].join('\n') + : ['prompt: None', 'params: None'].join('\n') + + return buildArray( + `- ${agentType}: ${agentTemplate.spawnerPrompt}`, + agentTemplate.includeMessageHistory && + 'This agent can see the current message history.', + agentTemplate.inheritParentSystemPrompt && + "This agent inherits the parent's system prompt for prompt caching.", + inputSchemaStr, + ).join('\n') +} + +/** + * Builds the full spawnable agents specification for subagent instructions. + * This is used when inheritSystemPrompt is true to tell subagents which agents they can spawn. + */ +export async function buildFullSpawnableAgentsSpec( params: { spawnableAgents: AgentTemplateType[] agentTemplates: Record @@ -18,7 +173,7 @@ export async function buildSpawnableAgentsDescription( 'agentId' | 'localAgentTemplates' >, ): Promise { - const { spawnableAgents, agentTemplates, logger } = params + const { spawnableAgents, agentTemplates } = params if (spawnableAgents.length === 0) { return '' } @@ -37,54 +192,13 @@ export async function buildSpawnableAgentsDescription( ) const agentsDescription = subAgentTypesAndTemplates - .map(([agentType, agentTemplate]) => { - if (!agentTemplate) { - // Fallback for unknown agents - return `- ${agentType}: Dynamic agent (description not available) -prompt: {"description": "A coding task to complete", "type": "string"} -params: None` - } - const { inputSchema } = agentTemplate - const inputSchemaStr = inputSchema - ? [ - `prompt: ${schemaToJsonStr(inputSchema.prompt)}`, - `params: ${schemaToJsonStr(inputSchema.params)}`, - ].join('\n') - : ['prompt: None', 'params: None'].join('\n') - - return buildArray( - `- ${agentType}: ${agentTemplate.spawnerPrompt}`, - agentTemplate.includeMessageHistory && - 'This agent can see the current message history.', - agentTemplate.inheritParentSystemPrompt && - "This agent inherits the parent's system prompt for prompt caching.", - inputSchemaStr, - ).join('\n') - }) + .map(([agentType, agentTemplate]) => + buildSingleAgentDescription(agentType, agentTemplate), + ) .filter(Boolean) .join('\n\n') - return `\n\n## Spawnable Agents - -Use the spawn_agents tool to spawn agents to help you complete the user request. - -Notes: -- You can not call the agents as tool names directly: you must use the spawn_agents tool with the correct parameters to spawn them! -- There are two types of input arguments for agents: prompt and params. The prompt is a string, and the params is a json object. Some agents require only one or the other, some require both, and some require none. -- Below are the *only* available agents by their agent_type. Other agents may be referenced earlier in the conversation, but they are not available to you. - -Example: - -${getToolCallString('spawn_agents', { - agents: [ - { - agent_type: 'example-agent', - prompt: 'Do an example task for me', - }, - ], -})} - -Spawn only the below agents: + return `You are a subagent that can only spawn the following agents using the spawn_agents tool: ${agentsDescription}` } diff --git a/packages/agent-runtime/src/templates/strings.ts b/packages/agent-runtime/src/templates/strings.ts index 2f7c4e75f..766817226 100644 --- a/packages/agent-runtime/src/templates/strings.ts +++ b/packages/agent-runtime/src/templates/strings.ts @@ -4,18 +4,13 @@ import { schemaToJsonStr } from '@codebuff/common/util/zod-schema' import { z } from 'zod/v4' import { getAgentTemplate } from './agent-registry' -import { buildSpawnableAgentsDescription } from './prompts' +import { buildFullSpawnableAgentsSpec } from './prompts' import { PLACEHOLDER, placeholderValues } from './types' import { getGitChangesPrompt, getProjectFileTreePrompt, getSystemInfoPrompt, } from '../system-prompt/prompts' -import { - fullToolList, - getShortToolInstructions, - getToolsInstructions, -} from '../tools/prompts' import { parseUserMessage } from '../util/messages' import type { AgentTemplate, PlaceholderValue } from './types' @@ -113,9 +108,6 @@ export async function formatPrompt( [PLACEHOLDER.REMAINING_STEPS]: () => `${agentState.stepsRemaining!}`, [PLACEHOLDER.PROJECT_ROOT]: () => fileContext.projectRoot, [PLACEHOLDER.SYSTEM_INFO_PROMPT]: () => getSystemInfoPrompt(fileContext), - [PLACEHOLDER.TOOLS_PROMPT]: async () => - getToolsInstructions(tools, (await additionalToolDefinitions()) ?? {}), - [PLACEHOLDER.AGENTS_PROMPT]: () => buildSpawnableAgentsDescription(params), [PLACEHOLDER.USER_CWD]: () => fileContext.cwd, [PLACEHOLDER.USER_INPUT_PROMPT]: () => escapeString(lastUserInput ?? ''), [PLACEHOLDER.INITIAL_AGENT_PROMPT]: () => @@ -150,11 +142,6 @@ export async function formatPrompt( } type StringField = 'systemPrompt' | 'instructionsPrompt' | 'stepPrompt' -const additionalPlaceholders = { - systemPrompt: [PLACEHOLDER.TOOLS_PROMPT, PLACEHOLDER.AGENTS_PROMPT], - instructionsPrompt: [], - stepPrompt: [], -} satisfies Record export async function getAgentPrompt( params: { agentTemplate: AgentTemplate @@ -164,12 +151,13 @@ export async function getAgentPrompt( agentTemplates: Record additionalToolDefinitions: () => Promise logger: Logger + useParentTools?: boolean } & ParamsExcluding< typeof formatPrompt, 'prompt' | 'tools' | 'spawnableAgents' > & ParamsExcluding< - typeof buildSpawnableAgentsDescription, + typeof buildFullSpawnableAgentsSpec, 'spawnableAgents' | 'agentTemplates' >, ): Promise { @@ -179,14 +167,10 @@ export async function getAgentPrompt( agentState, agentTemplates, additionalToolDefinitions, + useParentTools, } = params let promptValue = agentTemplate[promptType.type] - for (const placeholder of additionalPlaceholders[promptType.type]) { - if (!promptValue.includes(placeholder)) { - promptValue += `\n\n${placeholder}` - } - } let prompt = await formatPrompt({ ...params, @@ -204,21 +188,26 @@ export async function getAgentPrompt( // Add tool instructions, spawnable agents, and output schema prompts to instructionsPrompt if (promptType.type === 'instructionsPrompt' && agentState.agentType) { - const toolsInstructions = agentTemplate.inheritParentSystemPrompt - ? fullToolList(agentTemplate.toolNames, await additionalToolDefinitions()) - : getShortToolInstructions( - agentTemplate.toolNames, - await additionalToolDefinitions(), - ) - addendum += - '\n\n' + - toolsInstructions + - '\n\n' + - (await buildSpawnableAgentsDescription({ - ...params, - spawnableAgents: agentTemplate.spawnableAgents, - agentTemplates, - })) + // Add subagent tools message when using parent's tools for prompt caching + if (useParentTools) { + addendum += `\n\nYou are a subagent that only has access to the following tools: ${agentTemplate.toolNames.join(', ')}. Do not attempt to use any other tools.` + + // For subagents with inheritSystemPrompt, include full spawnable agents spec + // since the parent's system prompt may not have these agents listed + if (agentTemplate.spawnableAgents.length > 0) { + addendum += + '\n\n' + + (await buildFullSpawnableAgentsSpec({ + ...params, + spawnableAgents: agentTemplate.spawnableAgents, + agentTemplates, + })) + } + } else if (agentTemplate.spawnableAgents.length > 0) { + // For non-inherited tools, agents are already defined as tools with full schemas, + // so we just list the available agent IDs here + addendum += `\n\nYou can spawn the following agents: ${agentTemplate.spawnableAgents.join(', ')}.` + } // Add output schema information if defined if (agentTemplate.outputSchema) { diff --git a/packages/agent-runtime/src/templates/types.ts b/packages/agent-runtime/src/templates/types.ts index b0e547ce1..ee46e095c 100644 --- a/packages/agent-runtime/src/templates/types.ts +++ b/packages/agent-runtime/src/templates/types.ts @@ -13,7 +13,6 @@ export type { AgentTemplate, StepGenerator, StepHandler } const placeholderNames = [ 'AGENT_NAME', - 'AGENTS_PROMPT', 'CONFIG_SCHEMA', 'FILE_TREE_PROMPT_SMALL', 'FILE_TREE_PROMPT', @@ -24,7 +23,6 @@ const placeholderNames = [ 'PROJECT_ROOT', 'REMAINING_STEPS', 'SYSTEM_INFO_PROMPT', - 'TOOLS_PROMPT', 'USER_CWD', 'USER_INPUT_PROMPT', ] as const diff --git a/packages/agent-runtime/src/tool-stream-parser.old.ts b/packages/agent-runtime/src/tool-stream-parser.old.ts new file mode 100644 index 000000000..e7e07ca43 --- /dev/null +++ b/packages/agent-runtime/src/tool-stream-parser.old.ts @@ -0,0 +1,217 @@ +import { AnalyticsEvent } from '@codebuff/common/constants/analytics-events' +import { + endsAgentStepParam, + endToolTag, + startToolTag, + toolNameParam, +} from '@codebuff/common/tools/constants' + +import type { Model } from '@codebuff/common/old-constants' +import type { TrackEventFn } from '@codebuff/common/types/contracts/analytics' +import type { StreamChunk } from '@codebuff/common/types/contracts/llm' +import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { + PrintModeError, + PrintModeText, +} from '@codebuff/common/types/print-mode' + +const toolExtractionPattern = new RegExp( + `${startToolTag}(.*?)${endToolTag}`, + 'gs', +) + +const completionSuffix = `${JSON.stringify(endsAgentStepParam)}: true\n}${endToolTag}` + +export async function* processStreamWithTags(params: { + stream: AsyncGenerator + processors: Record< + string, + { + onTagStart: (tagName: string, attributes: Record) => void + onTagEnd: (tagName: string, params: Record) => void + } + > + defaultProcessor: (toolName: string) => { + onTagStart: (tagName: string, attributes: Record) => void + onTagEnd: (tagName: string, params: Record) => void + } + onError: (tagName: string, errorMessage: string) => void + onResponseChunk: (chunk: PrintModeText | PrintModeError) => void + logger: Logger + loggerOptions?: { + userId?: string + model?: Model + agentName?: string + } + trackEvent: TrackEventFn +}): AsyncGenerator { + const { + stream, + processors, + defaultProcessor, + onError, + onResponseChunk, + logger, + loggerOptions, + trackEvent, + } = params + + let streamCompleted = false + let buffer = '' + let autocompleted = false + + function extractToolCalls(): string[] { + const matches: string[] = [] + let lastIndex = 0 + for (const match of buffer.matchAll(toolExtractionPattern)) { + if (match.index > lastIndex) { + onResponseChunk({ + type: 'text', + text: buffer.slice(lastIndex, match.index), + }) + } + lastIndex = match.index + match[0].length + matches.push(match[1]) + } + + buffer = buffer.slice(lastIndex) + return matches + } + + function processToolCallContents(contents: string): void { + let parsedParams: any + try { + parsedParams = JSON.parse(contents) + } catch (error: any) { + trackEvent({ + event: AnalyticsEvent.MALFORMED_TOOL_CALL_JSON, + userId: loggerOptions?.userId ?? '', + properties: { + contents: JSON.stringify(contents), + model: loggerOptions?.model, + agent: loggerOptions?.agentName, + error: { + name: error.name, + message: error.message, + stack: error.stack, + }, + autocompleted, + }, + logger, + }) + const shortenedContents = + contents.length < 200 + ? contents + : contents.slice(0, 100) + '...' + contents.slice(-100) + const errorMessage = `Invalid JSON: ${JSON.stringify(shortenedContents)}\nError: ${error.message}` + onResponseChunk({ + type: 'error', + message: errorMessage, + }) + onError('parse_error', errorMessage) + return + } + + const toolName = parsedParams[toolNameParam] as keyof typeof processors + const processor = + typeof toolName === 'string' + ? processors[toolName] ?? defaultProcessor(toolName) + : undefined + if (!processor) { + trackEvent({ + event: AnalyticsEvent.UNKNOWN_TOOL_CALL, + userId: loggerOptions?.userId ?? '', + properties: { + contents, + toolName, + model: loggerOptions?.model, + agent: loggerOptions?.agentName, + autocompleted, + }, + logger, + }) + onError( + 'parse_error', + `Unknown tool ${JSON.stringify(toolName)} for tool call: ${contents}`, + ) + return + } + + trackEvent({ + event: AnalyticsEvent.TOOL_USE, + userId: loggerOptions?.userId ?? '', + properties: { + toolName, + contents, + parsedParams, + autocompleted, + model: loggerOptions?.model, + agent: loggerOptions?.agentName, + }, + logger, + }) + delete parsedParams[toolNameParam] + + processor.onTagStart(toolName, {}) + processor.onTagEnd(toolName, parsedParams) + } + + function extractToolsFromBufferAndProcess(forceFlush = false) { + const matches = extractToolCalls() + matches.forEach(processToolCallContents) + if (forceFlush) { + onResponseChunk({ + type: 'text', + text: buffer, + }) + buffer = '' + } + } + + function* processChunk( + chunk: StreamChunk | undefined, + ): Generator { + if (chunk !== undefined && chunk.type === 'text') { + buffer += chunk.text + } + extractToolsFromBufferAndProcess() + + if (chunk === undefined) { + streamCompleted = true + if (buffer.includes(startToolTag)) { + buffer += completionSuffix + chunk = { + type: 'text', + text: completionSuffix, + } + autocompleted = true + } + extractToolsFromBufferAndProcess(true) + } + + if (chunk) { + yield chunk + } + } + + let messageId: string | null = null + while (true) { + const { value, done } = await stream.next() + if (done) { + messageId = value + break + } + if (streamCompleted) { + break + } + + yield* processChunk(value) + } + + if (!streamCompleted) { + // After the stream ends, try parsing one last time in case there's leftover text + yield* processChunk(undefined) + } + + return messageId +} diff --git a/packages/agent-runtime/src/tool-stream-parser.ts b/packages/agent-runtime/src/tool-stream-parser.ts index 0191596c4..2f096695d 100644 --- a/packages/agent-runtime/src/tool-stream-parser.ts +++ b/packages/agent-runtime/src/tool-stream-parser.ts @@ -1,10 +1,4 @@ import { AnalyticsEvent } from '@codebuff/common/constants/analytics-events' -import { - endsAgentStepParam, - endToolTag, - startToolTag, - toolNameParam, -} from '@codebuff/common/tools/constants' import type { Model } from '@codebuff/common/old-constants' import type { TrackEventFn } from '@codebuff/common/types/contracts/analytics' @@ -13,17 +7,9 @@ import type { Logger } from '@codebuff/common/types/contracts/logger' import type { PrintModeError, PrintModeText, - PrintModeToolCall, } from '@codebuff/common/types/print-mode' -const toolExtractionPattern = new RegExp( - `${startToolTag}(.*?)${endToolTag}`, - 'gs', -) - -const completionSuffix = `${JSON.stringify(endsAgentStepParam)}: true\n}${endToolTag}` - -export async function* processStreamWithTags(params: { +export async function* processStreamWithTools(params: { stream: AsyncGenerator processors: Record< string, @@ -37,9 +23,7 @@ export async function* processStreamWithTags(params: { onTagEnd: (tagName: string, params: Record) => void } onError: (tagName: string, errorMessage: string) => void - onResponseChunk: ( - chunk: PrintModeText | PrintModeToolCall | PrintModeError, - ) => void + onResponseChunk: (chunk: PrintModeText | PrintModeError) => void logger: Logger loggerOptions?: { userId?: string @@ -58,87 +42,18 @@ export async function* processStreamWithTags(params: { loggerOptions, trackEvent, } = params - let streamCompleted = false let buffer = '' let autocompleted = false - function extractToolCalls(): string[] { - const matches: string[] = [] - let lastIndex = 0 - for (const match of buffer.matchAll(toolExtractionPattern)) { - if (match.index > lastIndex) { - onResponseChunk({ - type: 'text', - text: buffer.slice(lastIndex, match.index), - }) - } - lastIndex = match.index + match[0].length - matches.push(match[1]) - } + function processToolCallObject(params: { + toolName: string + input: any + contents?: string + }): void { + const { toolName, input, contents } = params - buffer = buffer.slice(lastIndex) - return matches - } - - function processToolCallContents(contents: string): void { - let parsedParams: any - try { - parsedParams = JSON.parse(contents) - } catch (error: any) { - trackEvent({ - event: AnalyticsEvent.MALFORMED_TOOL_CALL_JSON, - userId: loggerOptions?.userId ?? '', - properties: { - contents: JSON.stringify(contents), - model: loggerOptions?.model, - agent: loggerOptions?.agentName, - error: { - name: error.name, - message: error.message, - stack: error.stack, - }, - autocompleted, - }, - logger, - }) - const shortenedContents = - contents.length < 200 - ? contents - : contents.slice(0, 100) + '...' + contents.slice(-100) - const errorMessage = `Invalid JSON: ${JSON.stringify(shortenedContents)}\nError: ${error.message}` - onResponseChunk({ - type: 'error', - message: errorMessage, - }) - onError('parse_error', errorMessage) - return - } - - const toolName = parsedParams[toolNameParam] as keyof typeof processors - const processor = - typeof toolName === 'string' - ? processors[toolName] ?? defaultProcessor(toolName) - : undefined - if (!processor) { - trackEvent({ - event: AnalyticsEvent.UNKNOWN_TOOL_CALL, - userId: loggerOptions?.userId ?? '', - properties: { - contents, - toolName, - model: loggerOptions?.model, - agent: loggerOptions?.agentName, - autocompleted, - }, - logger, - }) - onError( - 'parse_error', - `Unknown tool ${JSON.stringify(toolName)} for tool call: ${contents}`, - ) - return - } + const processor = processors[toolName] ?? defaultProcessor(toolName) trackEvent({ event: AnalyticsEvent.TOOL_USE, @@ -146,55 +61,48 @@ export async function* processStreamWithTags(params: { properties: { toolName, contents, - parsedParams, + parsedParams: input, autocompleted, model: loggerOptions?.model, agent: loggerOptions?.agentName, }, logger, }) - delete parsedParams[toolNameParam] processor.onTagStart(toolName, {}) - processor.onTagEnd(toolName, parsedParams) + processor.onTagEnd(toolName, input) } - function extractToolsFromBufferAndProcess(forceFlush = false) { - const matches = extractToolCalls() - matches.forEach(processToolCallContents) - if (forceFlush) { + function flush() { + if (buffer) { onResponseChunk({ type: 'text', text: buffer, }) - buffer = '' } + buffer = '' } function* processChunk( chunk: StreamChunk | undefined, ): Generator { - if (chunk !== undefined && chunk.type === 'text') { - buffer += chunk.text - } - extractToolsFromBufferAndProcess() - if (chunk === undefined) { + flush() streamCompleted = true - if (buffer.includes(startToolTag)) { - buffer += completionSuffix - chunk = { - type: 'text', - text: completionSuffix, - } - autocompleted = true - } - extractToolsFromBufferAndProcess(true) + return } - if (chunk) { - yield chunk + if (chunk.type === 'text') { + buffer += chunk.text + } else { + flush() } + + if (chunk.type === 'tool-call') { + processToolCallObject(chunk) + } + + yield chunk } let messageId: string | null = null @@ -207,14 +115,11 @@ export async function* processStreamWithTags(params: { if (streamCompleted) { break } - yield* processChunk(value) } - if (!streamCompleted) { // After the stream ends, try parsing one last time in case there's leftover text yield* processChunk(undefined) } - return messageId } diff --git a/packages/agent-runtime/src/tools/handlers/handler-function-type.ts b/packages/agent-runtime/src/tools/handlers/handler-function-type.ts index 25016f2c6..9244cecbd 100644 --- a/packages/agent-runtime/src/tools/handlers/handler-function-type.ts +++ b/packages/agent-runtime/src/tools/handlers/handler-function-type.ts @@ -18,6 +18,7 @@ import type { Logger } from '@codebuff/common/types/contracts/logger' import type { PrintModeEvent } from '@codebuff/common/types/print-mode' import type { AgentState, Subgoal } from '@codebuff/common/types/session-state' import type { ProjectFileContext } from '@codebuff/common/util/file' +import type { ToolSet } from 'ai' type PresentOrAbsent = | { [P in K]: V } @@ -49,6 +50,7 @@ export type CodebuffToolHandlerFunction = ( sendSubagentChunk: SendSubagentChunkFn signal: AbortSignal system: string + tools?: ToolSet trackEvent: TrackEventFn userId: string | undefined userInputId: string diff --git a/packages/agent-runtime/src/tools/handlers/tool/add-message.ts b/packages/agent-runtime/src/tools/handlers/tool/add-message.ts index 3bd82a13a..734b37c68 100644 --- a/packages/agent-runtime/src/tools/handlers/tool/add-message.ts +++ b/packages/agent-runtime/src/tools/handlers/tool/add-message.ts @@ -30,5 +30,5 @@ export const handleAddMessage = (async (params: { : assistantMessage(toolCall.input.content), ) - return { output: [] } + return { output: [{ type: 'json', value: { message: 'Message added.' } }] } }) satisfies CodebuffToolHandlerFunction<'add_message'> diff --git a/packages/agent-runtime/src/tools/handlers/tool/end-turn.ts b/packages/agent-runtime/src/tools/handlers/tool/end-turn.ts index 403c97f25..00b2ceaae 100644 --- a/packages/agent-runtime/src/tools/handlers/tool/end-turn.ts +++ b/packages/agent-runtime/src/tools/handlers/tool/end-turn.ts @@ -11,5 +11,5 @@ export const handleEndTurn = (async (params: { const { previousToolCallFinished } = params await previousToolCallFinished - return { output: [] } + return { output: [{ type: 'json', value: { message: 'Turn ended.' } }] } }) satisfies CodebuffToolHandlerFunction<'end_turn'> diff --git a/packages/agent-runtime/src/tools/handlers/tool/set-messages.ts b/packages/agent-runtime/src/tools/handlers/tool/set-messages.ts index 7a574ef8d..ae17fd50a 100644 --- a/packages/agent-runtime/src/tools/handlers/tool/set-messages.ts +++ b/packages/agent-runtime/src/tools/handlers/tool/set-messages.ts @@ -15,5 +15,5 @@ export const handleSetMessages = (async (params: { await previousToolCallFinished agentState.messageHistory = toolCall.input.messages - return { output: [] } + return { output: [{ type: 'json', value: { message: 'Messages set.' } }] } }) satisfies CodebuffToolHandlerFunction<'set_messages'> diff --git a/packages/agent-runtime/src/tools/handlers/tool/set-output.ts b/packages/agent-runtime/src/tools/handlers/tool/set-output.ts index 10944b6ba..2def7b1d5 100644 --- a/packages/agent-runtime/src/tools/handlers/tool/set-output.ts +++ b/packages/agent-runtime/src/tools/handlers/tool/set-output.ts @@ -28,10 +28,10 @@ export const handleSetOutput = (async (params: { }): Promise<{ output: CodebuffToolOutput }> => { const { previousToolCallFinished, toolCall, agentState, logger } = params const output = toolCall.input + const { data } = output ?? {} await previousToolCallFinished - // Validate output against outputSchema if defined let agentTemplate = null if (agentState.agentType) { agentTemplate = await getAgentTemplate({ @@ -39,26 +39,42 @@ export const handleSetOutput = (async (params: { agentId: agentState.agentType, }) } + + let finalOutput: unknown if (agentTemplate?.outputSchema) { + // When outputSchema is defined, validate against it try { agentTemplate.outputSchema.parse(output) + finalOutput = output } catch (error) { - const errorMessage = `Output validation error: Output failed to match the output schema and was ignored. You might want to try again! Issues: ${error}` - logger.error( - { - output, - agentType: agentState.agentType, - agentId: agentState.agentId, - error, - }, - 'set_output validation error', - ) - return { output: jsonToolResult({ message: errorMessage }) } + try { + // Fallback to the 'data' field if the whole output object is not valid + agentTemplate.outputSchema.parse(data) + finalOutput = data + } catch (error2) { + const errorMessage = `Output validation error: Output failed to match the output schema and was ignored. You might want to try again! Issues: ${error}` + logger.error( + { + output, + agentType: agentState.agentType, + agentId: agentState.agentId, + error, + }, + 'set_output validation error', + ) + return { output: jsonToolResult({ message: errorMessage }) } + } } + } else { + // When no outputSchema, use the data field if it is the only field + // otherwise use the entire output object + const keys = Object.keys(output) + const hasOnlyDataField = keys.length === 1 && keys[0] === 'data' + finalOutput = hasOnlyDataField ? data : output } // Set the output (completely replaces previous output) - agentState.output = output + agentState.output = finalOutput as Record return { output: jsonToolResult({ message: 'Output set' }) } }) satisfies CodebuffToolHandlerFunction diff --git a/packages/agent-runtime/src/tools/handlers/tool/spawn-agent-inline.ts b/packages/agent-runtime/src/tools/handlers/tool/spawn-agent-inline.ts index 95a1ad4b1..b168dc68b 100644 --- a/packages/agent-runtime/src/tools/handlers/tool/spawn-agent-inline.ts +++ b/packages/agent-runtime/src/tools/handlers/tool/spawn-agent-inline.ts @@ -17,6 +17,7 @@ import type { ParamsExcluding } from '@codebuff/common/types/function-params' import type { PrintModeEvent } from '@codebuff/common/types/print-mode' import type { AgentState } from '@codebuff/common/types/session-state' import type { ProjectFileContext } from '@codebuff/common/util/file' +import type { ToolSet } from 'ai' type ToolName = 'spawn_agent_inline' export const handleSpawnAgentInline = (async ( @@ -32,6 +33,7 @@ export const handleSpawnAgentInline = (async ( localAgentTemplates: Record logger: Logger system: string + tools?: ToolSet userId: string | undefined userInputId: string writeToClient: (chunk: string | PrintModeEvent) => void @@ -44,6 +46,7 @@ export const handleSpawnAgentInline = (async ( | 'parentAgentState' | 'agentState' | 'parentSystemPrompt' + | 'parentTools' | 'onResponseChunk' | 'clearUserPromptMessagesAfterResponse' | 'fingerprintId' @@ -57,6 +60,7 @@ export const handleSpawnAgentInline = (async ( agentTemplate: parentAgentTemplate, fingerprintId, system, + tools: parentTools = {}, userInputId, writeToClient, } = params @@ -105,6 +109,9 @@ export const handleSpawnAgentInline = (async ( agentState: childAgentState, fingerprintId, parentSystemPrompt: system, + parentTools: agentTemplate.inheritParentSystemPrompt + ? parentTools + : undefined, onResponseChunk: (chunk) => { // Inherits parent's onResponseChunk, except for context-pruner (TODO: add an option for it to be silent?) if (agentType !== 'context-pruner') { @@ -117,5 +124,5 @@ export const handleSpawnAgentInline = (async ( // Update parent agent state to reflect shared message history parentAgentState.messageHistory = result.agentState.messageHistory - return { output: [] } + return { output: [{ type: 'json', value: { message: 'Agent spawned.' } }] } }) satisfies CodebuffToolHandlerFunction diff --git a/packages/agent-runtime/src/tools/handlers/tool/spawn-agent-utils.ts b/packages/agent-runtime/src/tools/handlers/tool/spawn-agent-utils.ts index e63cd3801..3f2fef9b9 100644 --- a/packages/agent-runtime/src/tools/handlers/tool/spawn-agent-utils.ts +++ b/packages/agent-runtime/src/tools/handlers/tool/spawn-agent-utils.ts @@ -4,6 +4,7 @@ import { generateCompactId } from '@codebuff/common/util/string' import { loopAgentSteps } from '../../../run-agent-step' import { getAgentTemplate } from '../../../templates/agent-registry' +import { filterUnfinishedToolCalls } from '../../../util/messages' import type { AgentTemplate } from '@codebuff/common/types/agent-template' import type { Logger } from '@codebuff/common/types/contracts/logger' @@ -11,6 +12,7 @@ import type { ParamsExcluding, OptionalFields, } from '@codebuff/common/types/function-params' +import type { ToolSet } from 'ai' import type { PrintModeEvent } from '@codebuff/common/types/print-mode' import type { AgentState, @@ -161,8 +163,11 @@ export function createAgentState( ): AgentState { const agentId = generateCompactId() + // When including message history, filter out any tool calls that don't have + // corresponding tool responses. This prevents the spawned agent from seeing + // unfinished tool calls which throw errors in the Anthropic API. const messageHistory = agentTemplate.includeMessageHistory - ? parentAgentState.messageHistory + ? filterUnfinishedToolCalls(parentAgentState.messageHistory) : [] return { @@ -227,6 +232,7 @@ export async function executeSubagent( { agentTemplate: AgentTemplate parentAgentState: AgentState + parentTools?: ToolSet onResponseChunk: (chunk: string | PrintModeEvent) => void isOnlyChild?: boolean ancestorRunIds: string[] diff --git a/packages/agent-runtime/src/tools/handlers/tool/spawn-agents.ts b/packages/agent-runtime/src/tools/handlers/tool/spawn-agents.ts index 38cb0b731..9ee345c7e 100644 --- a/packages/agent-runtime/src/tools/handlers/tool/spawn-agents.ts +++ b/packages/agent-runtime/src/tools/handlers/tool/spawn-agents.ts @@ -18,6 +18,7 @@ import type { Logger } from '@codebuff/common/types/contracts/logger' import type { ParamsExcluding } from '@codebuff/common/types/function-params' import type { PrintModeEvent } from '@codebuff/common/types/print-mode' import type { AgentState } from '@codebuff/common/types/session-state' +import type { ToolSet } from 'ai' export type SendSubagentChunk = (data: { userInputId: string @@ -40,6 +41,7 @@ export const handleSpawnAgents = (async ( localAgentTemplates: Record logger: Logger system: string + tools?: ToolSet userId: string | undefined userInputId: string sendSubagentChunk: SendSubagentChunk @@ -59,6 +61,7 @@ export const handleSpawnAgents = (async ( | 'fingerprintId' | 'isOnlyChild' | 'parentSystemPrompt' + | 'parentTools' | 'onResponseChunk' >, ): Promise<{ output: CodebuffToolOutput }> => { @@ -70,6 +73,7 @@ export const handleSpawnAgents = (async ( agentTemplate: parentAgentTemplate, fingerprintId, system: parentSystemPrompt, + tools: parentTools = {}, userInputId, sendSubagentChunk, writeToClient, @@ -118,6 +122,9 @@ export const handleSpawnAgents = (async ( fingerprintId, isOnlyChild: agents.length === 1, parentSystemPrompt, + parentTools: agentTemplate.inheritParentSystemPrompt + ? parentTools + : undefined, onResponseChunk: (chunk: string | PrintModeEvent) => { if (typeof chunk === 'string') { sendSubagentChunk({ diff --git a/packages/agent-runtime/src/tools/handlers/tool/task-completed.ts b/packages/agent-runtime/src/tools/handlers/tool/task-completed.ts index 4e8ea657f..8c8aea2d1 100644 --- a/packages/agent-runtime/src/tools/handlers/tool/task-completed.ts +++ b/packages/agent-runtime/src/tools/handlers/tool/task-completed.ts @@ -11,5 +11,5 @@ export const handleTaskCompleted = (async ({ toolCall: CodebuffToolCall<'task_completed'> }): Promise<{ output: CodebuffToolOutput<'task_completed'> }> => { await previousToolCallFinished - return { output: [] } + return { output: [{ type: 'json', value: { message: 'Task completed.' } }] } }) satisfies CodebuffToolHandlerFunction<'task_completed'> diff --git a/packages/agent-runtime/src/tools/handlers/tool/think-deeply.ts b/packages/agent-runtime/src/tools/handlers/tool/think-deeply.ts index fcac682c8..67dddbd6b 100644 --- a/packages/agent-runtime/src/tools/handlers/tool/think-deeply.ts +++ b/packages/agent-runtime/src/tools/handlers/tool/think-deeply.ts @@ -21,5 +21,5 @@ export const handleThinkDeeply = (async (params: { ) await previousToolCallFinished - return { output: [] } + return { output: [{ type: 'json', value: { message: 'Thought logged.' } }] } }) satisfies CodebuffToolHandlerFunction<'think_deeply'> diff --git a/packages/agent-runtime/src/tools/prompts.ts b/packages/agent-runtime/src/tools/prompts.ts index d893b8cf7..fc0342481 100644 --- a/packages/agent-runtime/src/tools/prompts.ts +++ b/packages/agent-runtime/src/tools/prompts.ts @@ -272,8 +272,9 @@ ${toolDescriptions.join('\n\n')} export async function getToolSet(params: { toolNames: string[] additionalToolDefinitions: () => Promise + agentTools: ToolSet }): Promise { - const { toolNames, additionalToolDefinitions } = params + const { toolNames, additionalToolDefinitions, agentTools } = params const toolSet: ToolSet = {} for (const toolName of toolNames) { @@ -289,5 +290,10 @@ export async function getToolSet(params: { } as Omit & { inputSchema: z.ZodType } } + // Add agent tools (agents as direct tool calls) + for (const [toolName, toolDefinition] of Object.entries(agentTools)) { + toolSet[toolName] = toolDefinition + } + return toolSet } diff --git a/packages/agent-runtime/src/tools/stream-parser.ts b/packages/agent-runtime/src/tools/stream-parser.ts index f47d41a67..c95ac60b9 100644 --- a/packages/agent-runtime/src/tools/stream-parser.ts +++ b/packages/agent-runtime/src/tools/stream-parser.ts @@ -3,13 +3,14 @@ import { buildArray } from '@codebuff/common/util/array' import { jsonToolResult, assistantMessage, + userMessage, } from '@codebuff/common/util/messages' import { generateCompactId } from '@codebuff/common/util/string' import { cloneDeep } from 'lodash' -import { processStreamWithTags } from '../tool-stream-parser' -import { executeCustomToolCall, executeToolCall } from './tool-executor' -import { expireMessages } from '../util/messages' +import { processStreamWithTools } from '../tool-stream-parser' +import { executeCustomToolCall, executeToolCall, tryTransformAgentToolCall } from './tool-executor' +import { expireMessages, withSystemTags } from '../util/messages' import type { CustomToolCall, ExecuteToolCallParams } from './tool-executor' import type { AgentTemplate } from '../templates/types' @@ -33,7 +34,7 @@ export type ToolCallError = { error: string } & Omit -export async function processStreamWithTools( +export async function processStream( params: { agentContext: Record agentTemplate: AgentTemplate @@ -65,7 +66,7 @@ export async function processStreamWithTools( | 'toolResultsToAddAfterStream' > & ParamsExcluding< - typeof processStreamWithTags, + typeof processStreamWithTools, 'processors' | 'defaultProcessor' | 'onError' | 'loggerOptions' >, ) { @@ -80,12 +81,14 @@ export async function processStreamWithTools( runId, signal, userId, + logger, } = params const fullResponseChunks: string[] = [fullResponse] const toolResults: ToolMessage[] = [] const toolResultsToAddAfterStream: ToolMessage[] = [] const toolCalls: (CodebuffToolCall | CustomToolCall)[] = [] + const assistantMessages: Message[] = [] const { promise: streamDonePromise, resolve: resolveStreamDonePromise } = Promise.withResolvers() let previousToolCallFinished = streamDonePromise @@ -122,6 +125,14 @@ export async function processStreamWithTools( toolResultsToAddAfterStream, onCostCalculated, + onResponseChunk: (chunk) => { + if (typeof chunk !== 'string' && chunk.type === 'tool_call') { + assistantMessages.push( + assistantMessage({ ...chunk, type: 'tool-call' }), + ) + } + return onResponseChunk(chunk) + }, }) }, } @@ -134,32 +145,76 @@ export async function processStreamWithTools( return } const toolCallId = generateCompactId() - // delegated to reusable helper - previousToolCallFinished = executeCustomToolCall({ - ...params, + + // Check if this is an agent tool call - if so, transform to spawn_agents + const transformed = tryTransformAgentToolCall({ toolName, input, - - fileProcessingState, - fullResponse: fullResponseChunks.join(''), - previousToolCallFinished, - toolCallId, - toolCalls, - toolResults, - toolResultsToAddAfterStream, + spawnableAgents: agentTemplate.spawnableAgents, }) + + if (transformed) { + // Use executeToolCall for spawn_agents (a native tool) + previousToolCallFinished = executeToolCall({ + ...params, + toolName: transformed.toolName, + input: transformed.input, + fromHandleSteps: false, + + fileProcessingState, + fullResponse: fullResponseChunks.join(''), + previousToolCallFinished, + toolCallId, + toolCalls, + toolResults, + toolResultsToAddAfterStream, + + onCostCalculated, + onResponseChunk: (chunk) => { + if (typeof chunk !== 'string' && chunk.type === 'tool_call') { + assistantMessages.push( + assistantMessage({ ...chunk, type: 'tool-call' }), + ) + } + return onResponseChunk(chunk) + }, + }) + } else { + // delegated to reusable helper for custom tools + previousToolCallFinished = executeCustomToolCall({ + ...params, + toolName, + input, + + fileProcessingState, + fullResponse: fullResponseChunks.join(''), + previousToolCallFinished, + toolCallId, + toolCalls, + toolResults, + toolResultsToAddAfterStream, + + onResponseChunk: (chunk) => { + if (typeof chunk !== 'string' && chunk.type === 'tool_call') { + assistantMessages.push( + assistantMessage({ ...chunk, type: 'tool-call' }), + ) + } + return onResponseChunk(chunk) + }, + }) + } }, } } - const streamWithTags = processStreamWithTags({ + const streamWithTags = processStreamWithTools({ ...params, processors: Object.fromEntries([ ...toolNames.map((toolName) => [toolName, toolCallback(toolName)]), - ...Object.keys(fileContext.customToolDefinitions ?? {}).map((toolName) => [ - toolName, - customToolCallback(toolName), - ]), + ...Object.keys(fileContext.customToolDefinitions ?? {}).map( + (toolName) => [toolName, customToolCallback(toolName)], + ), ]), defaultProcessor: customToolCallback, onError: (toolName, error) => { @@ -179,9 +234,25 @@ export async function processStreamWithTools( model: agentTemplate.model, agentName: agentTemplate.id, }, + onResponseChunk: (chunk) => { + if (chunk.type === 'text') { + if (chunk.text) { + assistantMessages.push(assistantMessage(chunk.text)) + } + } else if (chunk.type === 'error') { + // do nothing + } else { + chunk satisfies never + throw new Error( + `Internal error: unhandled chunk type: ${(chunk as any).type}`, + ) + } + return onResponseChunk(chunk) + }, }) let messageId: string | null = null + let hadToolCallError = false while (true) { if (signal.aborted) { break @@ -204,15 +275,27 @@ export async function processStreamWithTools( fullResponseChunks.push(chunk.text) } else if (chunk.type === 'error') { onResponseChunk(chunk) + + hadToolCallError = true + // Add error message to assistant messages so the agent can see what went wrong and retry + assistantMessages.push( + userMessage( + withSystemTags( + `Error during tool call: ${chunk.message}. Please check the tool name and arguments and try again.`, + ), + ), + ) + } else if (chunk.type === 'tool-call') { + // Do nothing, the onResponseChunk for tool is handled in the processor } else { chunk satisfies never + throw new Error(`Unhandled chunk type: ${(chunk as any).type}`) } } agentState.messageHistory = buildArray([ ...expireMessages(agentState.messageHistory, 'agentStep'), - fullResponseChunks.length > 0 && - assistantMessage(fullResponseChunks.join('')), + ...assistantMessages, ...toolResultsToAddAfterStream, ]) @@ -223,6 +306,7 @@ export async function processStreamWithTools( return { fullResponse: fullResponseChunks.join(''), fullResponseChunks, + hadToolCallError, messageId, toolCalls, toolResults, diff --git a/packages/agent-runtime/src/tools/tool-executor.ts b/packages/agent-runtime/src/tools/tool-executor.ts index 1baa2b774..6886f9b90 100644 --- a/packages/agent-runtime/src/tools/tool-executor.ts +++ b/packages/agent-runtime/src/tools/tool-executor.ts @@ -3,12 +3,14 @@ import { toolParams } from '@codebuff/common/tools/list' import { jsonToolResult } from '@codebuff/common/util/messages' import { generateCompactId } from '@codebuff/common/util/string' import { cloneDeep } from 'lodash' -import z from 'zod/v4' import { checkLiveUserInput } from '../live-user-inputs' import { getMCPToolData } from '../mcp' +import { getAgentShortName } from '../templates/prompts' import { codebuffToolHandlers } from './handlers/list' +import type { AgentTemplateType } from '@codebuff/common/types/session-state' + import type { AgentTemplate } from '../templates/types' import type { CodebuffToolHandlerFunction } from './handlers/handler-function-type' import type { FileProcessingState } from './handlers/tool/write-file' @@ -32,7 +34,7 @@ import type { CustomToolDefinitions, ProjectFileContext, } from '@codebuff/common/util/file' -import type { ToolCallPart } from 'ai' +import type { ToolCallPart, ToolSet } from 'ai' export type CustomToolCall = { toolName: string @@ -66,24 +68,28 @@ export function parseRawToolCall(params: { } const validName = toolName as T - const processedParameters: Record = {} - for (const [param, val] of Object.entries(rawToolCall.input ?? {})) { - processedParameters[param] = val - } + // const processedParameters: Record = {} + // for (const [param, val] of Object.entries(rawToolCall.input ?? {})) { + // processedParameters[param] = val + // } // Add the required codebuff_end_step parameter with the correct value for this tool if requested - if (autoInsertEndStepParam) { - processedParameters[endsAgentStepParam] = - toolParams[validName].endsAgentStep - } + // if (autoInsertEndStepParam) { + // processedParameters[endsAgentStepParam] = + // toolParams[validName].endsAgentStep + // } + + // const paramsSchema = toolParams[validName].endsAgentStep + // ? ( + // toolParams[validName].inputSchema satisfies z.ZodObject as z.ZodObject + // ).extend({ + // [endsAgentStepParam]: z.literal(toolParams[validName].endsAgentStep), + // }) + // : toolParams[validName].inputSchema + + const processedParameters = rawToolCall.input + const paramsSchema = toolParams[validName].inputSchema - const paramsSchema = toolParams[validName].endsAgentStep - ? ( - toolParams[validName].inputSchema satisfies z.ZodObject as z.ZodObject - ).extend({ - [endsAgentStepParam]: z.literal(toolParams[validName].endsAgentStep), - }) - : toolParams[validName].inputSchema const result = paramsSchema.safeParse(processedParameters) if (!result.success) { @@ -136,6 +142,7 @@ export type ExecuteToolCallParams = { runId: string signal: AbortSignal system: string + tools?: ToolSet toolCallId: string | undefined toolCalls: (CodebuffToolCall | CustomToolCall)[] toolResults: ToolMessage[] @@ -178,10 +185,9 @@ export function executeToolCall( toolCallId, toolName, input, - // Only include agentId for subagents (agents with a parent) - ...(agentState.parentId && { agentId: agentState.agentId }), - // Include includeToolCall flag if explicitly set to false - ...(excludeToolFromMessageHistory && { includeToolCall: false }), + agentId: agentState.agentId, + parentAgentId: agentState.parentId, + includeToolCall: !excludeToolFromMessageHistory, }) const toolCall: CodebuffToolCall | ToolCallError = parseRawToolCall({ @@ -495,3 +501,43 @@ export async function executeCustomToolCall( return }) } + +/** + * Checks if a tool name matches a spawnable agent and returns the transformed + * spawn_agents input if so. Returns null if not an agent tool call. + */ +export function tryTransformAgentToolCall(params: { + toolName: string + input: Record + spawnableAgents: AgentTemplateType[] +}): { toolName: 'spawn_agents'; input: Record } | null { + const { toolName, input, spawnableAgents } = params + + const agentShortNames = spawnableAgents.map(getAgentShortName) + if (!agentShortNames.includes(toolName)) { + return null + } + + // Find the full agent type for this short name + const fullAgentType = spawnableAgents.find( + (agentType) => getAgentShortName(agentType) === toolName, + ) + + // Convert to spawn_agents call + const spawnAgentsInput = { + agents: [ + { + agent_type: fullAgentType || toolName, + ...(typeof input.prompt === 'string' && { prompt: input.prompt }), + // Put all other fields into params + ...(Object.keys(input).filter((k) => k !== 'prompt').length > 0 && { + params: Object.fromEntries( + Object.entries(input).filter(([k]) => k !== 'prompt'), + ), + }), + }, + ], + } + + return { toolName: 'spawn_agents', input: spawnAgentsInput } +} diff --git a/packages/agent-runtime/src/util/__tests__/messages.test.ts b/packages/agent-runtime/src/util/__tests__/messages.test.ts index 3ee55e709..3a82fdd58 100644 --- a/packages/agent-runtime/src/util/__tests__/messages.test.ts +++ b/packages/agent-runtime/src/util/__tests__/messages.test.ts @@ -18,6 +18,7 @@ import { trimMessagesToFitTokenLimit, messagesWithSystem, getPreviouslyReadFiles, + filterUnfinishedToolCalls, } from '../../util/messages' import * as tokenCounter from '../token-counter' @@ -406,6 +407,235 @@ describe('trimMessagesToFitTokenLimit', () => { }) }) +describe('filterUnfinishedToolCalls', () => { + it('returns empty array when given empty messages', () => { + const result = filterUnfinishedToolCalls([]) + expect(result).toEqual([]) + }) + + it('keeps messages that are not assistant messages', () => { + const messages: Message[] = [ + userMessage('Hello'), + systemMessage('System prompt'), + { + role: 'tool', + toolName: 'read_files', + toolCallId: 'tool-1', + content: jsonToolResult({ files: [] }), + }, + ] + + const result = filterUnfinishedToolCalls(messages) + expect(result).toHaveLength(3) + expect(result).toEqual(messages) + }) + + it('keeps assistant messages with text content only', () => { + const messages: Message[] = [ + userMessage('Hello'), + assistantMessage('Hi there!'), + userMessage('How are you?'), + assistantMessage('I am doing well.'), + ] + + const result = filterUnfinishedToolCalls(messages) + expect(result).toHaveLength(4) + expect(result).toEqual(messages) + }) + + it('keeps tool calls that have corresponding tool responses', () => { + const messages: Message[] = [ + userMessage('Read a file'), + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'call-1', + toolName: 'read_files', + input: { paths: ['test.ts'] }, + }, + ], + }, + { + role: 'tool', + toolName: 'read_files', + toolCallId: 'call-1', + content: jsonToolResult({ content: 'file content' }), + }, + ] + + const result = filterUnfinishedToolCalls(messages) + expect(result).toHaveLength(3) + expect(result[1].role).toBe('assistant') + expect(result[1].content).toHaveLength(1) + expect(result[1].content[0].type).toBe('tool-call') + }) + + it('removes tool calls that do not have corresponding tool responses', () => { + const messages: Message[] = [ + userMessage('Read a file'), + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'call-1', + toolName: 'read_files', + input: { paths: ['test.ts'] }, + }, + ], + }, + // No tool response for call-1 + ] + + const result = filterUnfinishedToolCalls(messages) + expect(result).toHaveLength(1) // Only the user message + expect(result[0].role).toBe('user') + }) + + it('removes only unfinished tool calls from assistant messages with mixed content', () => { + const messages: Message[] = [ + userMessage('Read files'), + { + role: 'assistant', + content: [ + { type: 'text', text: 'I will read these files' }, + { + type: 'tool-call', + toolCallId: 'call-1', + toolName: 'read_files', + input: { paths: ['file1.ts'] }, + }, + { + type: 'tool-call', + toolCallId: 'call-2', + toolName: 'read_files', + input: { paths: ['file2.ts'] }, + }, + ], + }, + { + role: 'tool', + toolName: 'read_files', + toolCallId: 'call-1', + content: jsonToolResult({ content: 'file1 content' }), + }, + // No tool response for call-2 + ] + + const result = filterUnfinishedToolCalls(messages) + expect(result).toHaveLength(3) // user, assistant (filtered), tool + + const assistantMsg = result[1] + expect(assistantMsg.role).toBe('assistant') + expect(assistantMsg.content).toHaveLength(2) // text + call-1 (call-2 removed) + expect(assistantMsg.content[0].type).toBe('text') + expect(assistantMsg.content[1].type).toBe('tool-call') + expect((assistantMsg.content[1] as any).toolCallId).toBe('call-1') + }) + + it('removes assistant message entirely if all content parts are unfinished tool calls', () => { + const messages: Message[] = [ + userMessage('Do something'), + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'call-1', + toolName: 'write_file', + input: { path: 'test.ts', content: 'test' }, + }, + { + type: 'tool-call', + toolCallId: 'call-2', + toolName: 'read_files', + input: { paths: ['other.ts'] }, + }, + ], + }, + // No tool responses + ] + + const result = filterUnfinishedToolCalls(messages) + expect(result).toHaveLength(1) // Only the user message + expect(result[0].role).toBe('user') + }) + + it('handles multiple assistant messages with different tool call states', () => { + const messages: Message[] = [ + userMessage('First request'), + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'call-1', + toolName: 'read_files', + input: { paths: ['file1.ts'] }, + }, + ], + }, + { + role: 'tool', + toolName: 'read_files', + toolCallId: 'call-1', + content: jsonToolResult({ content: 'content1' }), + }, + userMessage('Second request'), + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'call-2', + toolName: 'write_file', + input: { path: 'test.ts', content: 'test' }, + }, + ], + }, + // No tool response for call-2 (unfinished) + ] + + const result = filterUnfinishedToolCalls(messages) + expect(result).toHaveLength(4) // user1, assistant1 (kept), tool1, user2 + expect(result[0].role).toBe('user') + expect(result[1].role).toBe('assistant') + expect(result[2].role).toBe('tool') + expect(result[3].role).toBe('user') + }) + + it('preserves auxiliary message data on filtered assistant messages', () => { + const messages: Message[] = [ + userMessage('Test'), + { + role: 'assistant', + content: [ + { type: 'text', text: 'Response' }, + { + type: 'tool-call', + toolCallId: 'call-1', + toolName: 'read_files', + input: { paths: ['test.ts'] }, + }, + ], + tags: ['important'], + keepDuringTruncation: true, + }, + // No tool response + ] + + const result = filterUnfinishedToolCalls(messages) + expect(result).toHaveLength(2) + + const assistantMsg = result[1] + expect(assistantMsg.tags).toEqual(['important']) + expect(assistantMsg.keepDuringTruncation).toBe(true) + expect(assistantMsg.content).toHaveLength(1) // Only text, tool-call removed + }) +}) + describe('getPreviouslyReadFiles', () => { it('returns empty array when no messages provided', () => { const result = getPreviouslyReadFiles({ messages: [], logger }) diff --git a/packages/agent-runtime/src/util/__tests__/parse-tool-calls-from-text.test.ts b/packages/agent-runtime/src/util/__tests__/parse-tool-calls-from-text.test.ts new file mode 100644 index 000000000..a61e82703 --- /dev/null +++ b/packages/agent-runtime/src/util/__tests__/parse-tool-calls-from-text.test.ts @@ -0,0 +1,363 @@ +import { describe, expect, it } from 'bun:test' + +import { + parseToolCallsFromText, + parseTextWithToolCalls, +} from '../parse-tool-calls-from-text' + +describe('parseToolCallsFromText', () => { + it('should parse a single tool call', () => { + const text = ` +{ + "cb_tool_name": "read_files", + "paths": ["test.ts"] +} +` + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(1) + expect(result[0]).toEqual({ + toolName: 'read_files', + input: { paths: ['test.ts'] }, + }) + }) + + it('should parse multiple tool calls', () => { + const text = `Some commentary before + + +{ + "cb_tool_name": "read_files", + "paths": ["file1.ts"] +} + + +Some text between + + +{ + "cb_tool_name": "str_replace", + "path": "file1.ts", + "replacements": [{"old": "foo", "new": "bar"}] +} + + +Some commentary after` + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(2) + expect(result[0]).toEqual({ + toolName: 'read_files', + input: { paths: ['file1.ts'] }, + }) + expect(result[1]).toEqual({ + toolName: 'str_replace', + input: { + path: 'file1.ts', + replacements: [{ old: 'foo', new: 'bar' }], + }, + }) + }) + + it('should remove cb_tool_name from input', () => { + const text = ` +{ + "cb_tool_name": "write_file", + "path": "test.ts", + "content": "console.log('hello')" +} +` + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(1) + expect(result[0].input).not.toHaveProperty('cb_tool_name') + expect(result[0].input).toEqual({ + path: 'test.ts', + content: "console.log('hello')", + }) + }) + + it('should remove cb_easp from input', () => { + const text = ` +{ + "cb_tool_name": "read_files", + "paths": ["test.ts"], + "cb_easp": true +} +` + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(1) + expect(result[0].input).not.toHaveProperty('cb_easp') + expect(result[0].input).toEqual({ paths: ['test.ts'] }) + }) + + it('should skip malformed JSON', () => { + const text = ` +{ + "cb_tool_name": "read_files", + "paths": ["test.ts" +} + + + +{ + "cb_tool_name": "write_file", + "path": "good.ts", + "content": "valid" +} +` + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(1) + expect(result[0].toolName).toBe('write_file') + }) + + it('should skip tool calls without cb_tool_name', () => { + const text = ` +{ + "paths": ["test.ts"] +} +` + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(0) + }) + + it('should return empty array for text without tool calls', () => { + const text = 'Just some regular text without any tool calls' + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(0) + }) + + it('should return empty array for empty string', () => { + const result = parseToolCallsFromText('') + + expect(result).toHaveLength(0) + }) + + it('should handle complex nested objects in input', () => { + const text = ` +{ + "cb_tool_name": "spawn_agents", + "agents": [ + { + "agent_type": "file-picker", + "prompt": "Find relevant files" + }, + { + "agent_type": "code-searcher", + "params": { + "searchQueries": [ + {"pattern": "function test"} + ] + } + } + ] +} +` + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(1) + expect(result[0].toolName).toBe('spawn_agents') + expect(result[0].input.agents).toHaveLength(2) + }) + + it('should handle tool calls with escaped characters in strings', () => { + const text = + '\n' + + '{\n' + + ' "cb_tool_name": "str_replace",\n' + + ' "path": "test.ts",\n' + + ' "replacements": [{"old": "console.log(\\"hello\\")", "new": "console.log(\'world\')"}]\n' + + '}\n' + + '' + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(1) + const replacements = result[0].input.replacements as Array<{ + old: string + new: string + }> + expect(replacements[0].old).toBe('console.log("hello")') + }) + + it('should handle tool calls with newlines in content', () => { + const text = + '\n' + + '{\n' + + ' "cb_tool_name": "write_file",\n' + + ' "path": "test.ts",\n' + + ' "content": "line1\\nline2\\nline3"\n' + + '}\n' + + '' + + const result = parseToolCallsFromText(text) + + expect(result).toHaveLength(1) + expect(result[0].input.content).toBe('line1\nline2\nline3') + }) +}) + +describe('parseTextWithToolCalls', () => { + it('should parse interleaved text and tool calls', () => { + const text = `Some commentary before + + +{ + "cb_tool_name": "read_files", + "paths": ["file1.ts"] +} + + +Some text between + + +{ + "cb_tool_name": "write_file", + "path": "file2.ts", + "content": "test" +} + + +Some commentary after` + + const result = parseTextWithToolCalls(text) + + expect(result).toHaveLength(5) + expect(result[0]).toEqual({ type: 'text', text: 'Some commentary before' }) + expect(result[1]).toEqual({ + type: 'tool_call', + toolName: 'read_files', + input: { paths: ['file1.ts'] }, + }) + expect(result[2]).toEqual({ type: 'text', text: 'Some text between' }) + expect(result[3]).toEqual({ + type: 'tool_call', + toolName: 'write_file', + input: { path: 'file2.ts', content: 'test' }, + }) + expect(result[4]).toEqual({ type: 'text', text: 'Some commentary after' }) + }) + + it('should return only tool call when no surrounding text', () => { + const text = ` +{ + "cb_tool_name": "read_files", + "paths": ["test.ts"] +} +` + + const result = parseTextWithToolCalls(text) + + expect(result).toHaveLength(1) + expect(result[0]).toEqual({ + type: 'tool_call', + toolName: 'read_files', + input: { paths: ['test.ts'] }, + }) + }) + + it('should return only text when no tool calls', () => { + const text = 'Just some regular text without any tool calls' + + const result = parseTextWithToolCalls(text) + + expect(result).toHaveLength(1) + expect(result[0]).toEqual({ + type: 'text', + text: 'Just some regular text without any tool calls', + }) + }) + + it('should return empty array for empty string', () => { + const result = parseTextWithToolCalls('') + + expect(result).toHaveLength(0) + }) + + it('should handle text only before tool call', () => { + const text = `Introduction text + + +{ + "cb_tool_name": "read_files", + "paths": ["test.ts"] +} +` + + const result = parseTextWithToolCalls(text) + + expect(result).toHaveLength(2) + expect(result[0]).toEqual({ type: 'text', text: 'Introduction text' }) + expect(result[1].type).toBe('tool_call') + }) + + it('should handle text only after tool call', () => { + const text = ` +{ + "cb_tool_name": "read_files", + "paths": ["test.ts"] +} + + +Conclusion text` + + const result = parseTextWithToolCalls(text) + + expect(result).toHaveLength(2) + expect(result[0].type).toBe('tool_call') + expect(result[1]).toEqual({ type: 'text', text: 'Conclusion text' }) + }) + + it('should skip malformed tool calls but keep surrounding text', () => { + const text = `Before text + + +{ + "cb_tool_name": "read_files", + "paths": ["test.ts" +} + + +After text` + + const result = parseTextWithToolCalls(text) + + expect(result).toHaveLength(2) + expect(result[0]).toEqual({ type: 'text', text: 'Before text' }) + expect(result[1]).toEqual({ type: 'text', text: 'After text' }) + }) + + it('should trim whitespace from text segments', () => { + const text = ` + Text with whitespace + + +{ + "cb_tool_name": "read_files", + "paths": ["test.ts"] +} + + + More text + ` + + const result = parseTextWithToolCalls(text) + + expect(result).toHaveLength(3) + expect(result[0]).toEqual({ type: 'text', text: 'Text with whitespace' }) + expect(result[1].type).toBe('tool_call') + expect(result[2]).toEqual({ type: 'text', text: 'More text' }) + }) +}) diff --git a/packages/agent-runtime/src/util/agent-output.ts b/packages/agent-runtime/src/util/agent-output.ts index 624e3ca63..fe3a8da0a 100644 --- a/packages/agent-runtime/src/util/agent-output.ts +++ b/packages/agent-runtime/src/util/agent-output.ts @@ -1,10 +1,49 @@ import type { AgentTemplate } from '@codebuff/common/types/agent-template' -import type { AssistantMessage } from '@codebuff/common/types/messages/codebuff-message' +import type { Message } from '@codebuff/common/types/messages/codebuff-message' import type { AgentState, AgentOutput, } from '@codebuff/common/types/session-state' +/** + * Get the last assistant turn messages, which includes the last assistant message + * and any subsequent tool messages that are responses to its tool calls. + */ +function getLastAssistantTurnMessages(messageHistory: Message[]): Message[] { + // Find the index of the last assistant message + let lastAssistantIndex = -1 + for (let i = messageHistory.length - 1; i >= 0; i--) { + if (messageHistory[i].role === 'assistant') { + lastAssistantIndex = i + break + } + } + + for (let i = lastAssistantIndex; i >= 0; i--) { + if (messageHistory[i].role === 'assistant') { + lastAssistantIndex = i + } else break + } + + if (lastAssistantIndex === -1) { + return [] + } + + // Collect the assistant message and all subsequent tool messages + const result: Message[] = [] + for (let i = lastAssistantIndex; i < messageHistory.length; i++) { + const message = messageHistory[i] + if (message.role === 'assistant' || message.role === 'tool') { + result.push(message) + } else { + // Stop if we hit a user or system message + break + } + } + + return result +} + export function getAgentOutput( agentState: AgentState, agentTemplate: AgentTemplate, @@ -16,11 +55,10 @@ export function getAgentOutput( } } if (agentTemplate.outputMode === 'last_message') { - const assistantMessages = agentState.messageHistory.filter( - (message): message is AssistantMessage => message.role === 'assistant', + const lastTurnMessages = getLastAssistantTurnMessages( + agentState.messageHistory, ) - const lastAssistantMessage = assistantMessages[assistantMessages.length - 1] - if (!lastAssistantMessage) { + if (lastTurnMessages.length === 0) { return { type: 'error', message: 'No response from agent', @@ -28,7 +66,7 @@ export function getAgentOutput( } return { type: 'lastMessage', - value: lastAssistantMessage.content, + value: lastTurnMessages, } } if (agentTemplate.outputMode === 'all_messages') { diff --git a/packages/agent-runtime/src/util/messages.ts b/packages/agent-runtime/src/util/messages.ts index 9b44eccc0..ec36916d2 100644 --- a/packages/agent-runtime/src/util/messages.ts +++ b/packages/agent-runtime/src/util/messages.ts @@ -313,6 +313,53 @@ export function expireMessages( }) } +/** + * Removes tool calls from the message history that don't have corresponding tool responses. + * This is important when passing message history to spawned agents, as unfinished tool calls + * will cause issues with the LLM expecting tool responses. + * + * The function: + * 1. Collects all toolCallIds from tool response messages + * 2. Filters assistant messages to remove tool-call content parts without responses + * 3. Removes assistant messages that become empty after filtering + */ +export function filterUnfinishedToolCalls(messages: Message[]): Message[] { + // Collect all toolCallIds that have corresponding tool responses + const respondedToolCallIds = new Set() + for (const message of messages) { + if (message.role === 'tool') { + respondedToolCallIds.add(message.toolCallId) + } + } + + // Filter messages, removing unfinished tool calls from assistant messages + const filteredMessages: Message[] = [] + for (const message of messages) { + if (message.role !== 'assistant') { + filteredMessages.push(message) + continue + } + + // Filter out tool-call content parts that don't have responses + const filteredContent = message.content.filter((part) => { + if (part.type !== 'tool-call') { + return true + } + return respondedToolCallIds.has(part.toolCallId) + }) + + // Only include the assistant message if it has content after filtering + if (filteredContent.length > 0) { + filteredMessages.push({ + ...message, + content: filteredContent, + }) + } + } + + return filteredMessages +} + export function getEditedFiles(params: { messages: Message[] logger: Logger diff --git a/packages/agent-runtime/src/util/parse-tool-calls-from-text.ts b/packages/agent-runtime/src/util/parse-tool-calls-from-text.ts new file mode 100644 index 000000000..4f9900a9e --- /dev/null +++ b/packages/agent-runtime/src/util/parse-tool-calls-from-text.ts @@ -0,0 +1,117 @@ +import { + startToolTag, + endToolTag, + toolNameParam, +} from '@codebuff/common/tools/constants' + +export type ParsedToolCallFromText = { + type: 'tool_call' + toolName: string + input: Record +} + +export type ParsedTextSegment = { + type: 'text' + text: string +} + +export type ParsedSegment = ParsedToolCallFromText | ParsedTextSegment + +/** + * Parses text containing tool calls in the XML format, + * returning interleaved text and tool call segments in order. + * + * Example input: + * ``` + * Some text before + * + * { + * "cb_tool_name": "read_files", + * "paths": ["file.ts"] + * } + * + * Some text after + * ``` + * + * @param text - The text containing tool calls in XML format + * @returns Array of segments (text and tool calls) in order of appearance + */ +export function parseTextWithToolCalls(text: string): ParsedSegment[] { + const segments: ParsedSegment[] = [] + + // Match ... blocks + const toolExtractionPattern = new RegExp( + `${escapeRegex(startToolTag)}([\\s\\S]*?)${escapeRegex(endToolTag)}`, + 'gs', + ) + + let lastIndex = 0 + + for (const match of text.matchAll(toolExtractionPattern)) { + // Add any text before this tool call + if (match.index !== undefined && match.index > lastIndex) { + const textBefore = text.slice(lastIndex, match.index).trim() + if (textBefore) { + segments.push({ type: 'text', text: textBefore }) + } + } + + const jsonContent = match[1].trim() + + try { + const parsed = JSON.parse(jsonContent) + const toolName = parsed[toolNameParam] + + if (typeof toolName === 'string') { + // Remove the tool name param from the input + const input = { ...parsed } + delete input[toolNameParam] + + // Also remove cb_easp if present + delete input['cb_easp'] + + segments.push({ + type: 'tool_call', + toolName, + input, + }) + } + } catch { + // Skip malformed JSON - don't add segment + } + + // Update lastIndex to after this match + if (match.index !== undefined) { + lastIndex = match.index + match[0].length + } + } + + // Add any remaining text after the last tool call + if (lastIndex < text.length) { + const textAfter = text.slice(lastIndex).trim() + if (textAfter) { + segments.push({ type: 'text', text: textAfter }) + } + } + + return segments +} + +/** + * Parses tool calls from text in the XML format. + * This is a convenience function that returns only tool calls (no text segments). + * + * @param text - The text containing tool calls in XML format + * @returns Array of parsed tool calls with toolName and input + */ +export function parseToolCallsFromText( + text: string, +): Omit[] { + return parseTextWithToolCalls(text) + .filter((segment): segment is ParsedToolCallFromText => segment.type === 'tool_call') + .map(({ toolName, input }) => ({ toolName, input })) +} + +function escapeRegex(string: string): string { + return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&') +} diff --git a/packages/internal/src/openrouter-ai-sdk/chat/index.test.ts b/packages/internal/src/openrouter-ai-sdk/chat/index.test.ts index 29a4ddc2b..6fa153a10 100644 --- a/packages/internal/src/openrouter-ai-sdk/chat/index.test.ts +++ b/packages/internal/src/openrouter-ai-sdk/chat/index.test.ts @@ -660,13 +660,15 @@ describe('doGenerate', () => { const requestHeaders = server.calls[0]!.requestHeaders - expect(requestHeaders).toStrictEqual({ - authorization: 'Bearer test-api-key', - 'content-type': 'application/json', - 'custom-provider-header': 'provider-header-value', - 'custom-request-header': 'request-header-value', - 'user-agent': 'ai-sdk/provider-utils/3.0.17 runtime/bun/1.3.0', - }) + expect(requestHeaders.authorization).toBe('Bearer test-api-key') + expect(requestHeaders['content-type']).toBe('application/json') + expect(requestHeaders['custom-provider-header']).toBe( + 'provider-header-value', + ) + expect(requestHeaders['custom-request-header']).toBe('request-header-value') + expect(requestHeaders['user-agent']).toMatch( + /^ai-sdk\/provider-utils\/\d+\.\d+\.\d+ runtime\/bun\/\d+\.\d+\.\d+$/, + ) }) it('should pass responseFormat for JSON schema structured outputs', async () => { @@ -1496,13 +1498,15 @@ describe('doStream', () => { const requestHeaders = server.calls[0]!.requestHeaders - expect(requestHeaders).toStrictEqual({ - authorization: 'Bearer test-api-key', - 'content-type': 'application/json', - 'custom-provider-header': 'provider-header-value', - 'custom-request-header': 'request-header-value', - 'user-agent': 'ai-sdk/provider-utils/3.0.17 runtime/bun/1.3.0', - }) + expect(requestHeaders.authorization).toBe('Bearer test-api-key') + expect(requestHeaders['content-type']).toBe('application/json') + expect(requestHeaders['custom-provider-header']).toBe( + 'provider-header-value', + ) + expect(requestHeaders['custom-request-header']).toBe('request-header-value') + expect(requestHeaders['user-agent']).toMatch( + /^ai-sdk\/provider-utils\/\d+\.\d+\.\d+ runtime\/bun\/\d+\.\d+\.\d+$/, + ) }) it('should pass extra body', async () => { diff --git a/packages/internal/src/openrouter-ai-sdk/completion/index.test.ts b/packages/internal/src/openrouter-ai-sdk/completion/index.test.ts index 8d1fad908..94d2197cd 100644 --- a/packages/internal/src/openrouter-ai-sdk/completion/index.test.ts +++ b/packages/internal/src/openrouter-ai-sdk/completion/index.test.ts @@ -381,13 +381,15 @@ describe('doGenerate', () => { const requestHeaders = server.calls[0]!.requestHeaders - expect(requestHeaders).toStrictEqual({ - authorization: 'Bearer test-api-key', - 'content-type': 'application/json', - 'custom-provider-header': 'provider-header-value', - 'custom-request-header': 'request-header-value', - 'user-agent': 'ai-sdk/provider-utils/3.0.17 runtime/bun/1.3.0', - }) + expect(requestHeaders.authorization).toBe('Bearer test-api-key') + expect(requestHeaders['content-type']).toBe('application/json') + expect(requestHeaders['custom-provider-header']).toBe( + 'provider-header-value', + ) + expect(requestHeaders['custom-request-header']).toBe('request-header-value') + expect(requestHeaders['user-agent']).toMatch( + /^ai-sdk\/provider-utils\/\d+\.\d+\.\d+ runtime\/bun\/\d+\.\d+\.\d+$/, + ) }) }) @@ -614,13 +616,15 @@ describe('doStream', () => { const requestHeaders = server.calls[0]!.requestHeaders - expect(requestHeaders).toStrictEqual({ - authorization: 'Bearer test-api-key', - 'content-type': 'application/json', - 'custom-provider-header': 'provider-header-value', - 'custom-request-header': 'request-header-value', - 'user-agent': 'ai-sdk/provider-utils/3.0.17 runtime/bun/1.3.0', - }) + expect(requestHeaders.authorization).toBe('Bearer test-api-key') + expect(requestHeaders['content-type']).toBe('application/json') + expect(requestHeaders['custom-provider-header']).toBe( + 'provider-header-value', + ) + expect(requestHeaders['custom-request-header']).toBe('request-header-value') + expect(requestHeaders['user-agent']).toMatch( + /^ai-sdk\/provider-utils\/\d+\.\d+\.\d+ runtime\/bun\/\d+\.\d+\.\d+$/, + ) }) it('should pass extra body', async () => { diff --git a/sdk/src/__tests__/run-with-retry.test.ts b/sdk/src/__tests__/run-with-retry.test.ts index cf0351cf5..e240b8cff 100644 --- a/sdk/src/__tests__/run-with-retry.test.ts +++ b/sdk/src/__tests__/run-with-retry.test.ts @@ -1,10 +1,12 @@ +import { assistantMessage } from '@codebuff/common/util/messages' import { afterEach, describe, expect, it, mock, spyOn } from 'bun:test' -import { ErrorCodes, NetworkError } from '../errors' +import { ErrorCodes } from '../errors' import { run } from '../run' import * as runModule from '../run' import type { RunState } from '../run-state' +import type { SessionState } from '@codebuff/common/types/session-state' const baseOptions = { apiKey: 'test-key', @@ -19,8 +21,13 @@ describe('run retry wrapper', () => { }) it('returns immediately on success without retrying', async () => { - const expectedState = { sessionState: {} as any, output: { type: 'lastMessage', value: 'hi' } } as RunState - const runSpy = spyOn(runModule, 'runOnce').mockResolvedValueOnce(expectedState) + const expectedState: RunState = { + sessionState: {} as SessionState, + output: { type: 'lastMessage', value: [assistantMessage('hi')] }, + } + const runSpy = spyOn(runModule, 'runOnce').mockResolvedValueOnce( + expectedState, + ) const result = await run(baseOptions) @@ -29,11 +36,14 @@ describe('run retry wrapper', () => { }) it('retries once on retryable error output and then succeeds', async () => { - const errorState = { - sessionState: {} as any, - output: { type: 'error', message: 'NetworkError: Service unavailable' } - } as RunState - const successState = { sessionState: {} as any, output: { type: 'lastMessage', value: 'hi' } } as RunState + const errorState: RunState = { + sessionState: {} as SessionState, + output: { type: 'error', message: 'NetworkError: Service unavailable' }, + } + const successState: RunState = { + sessionState: {} as SessionState, + output: { type: 'lastMessage', value: [assistantMessage('hi')] }, + } const runSpy = spyOn(runModule, 'runOnce') .mockResolvedValueOnce(errorState) @@ -51,7 +61,7 @@ describe('run retry wrapper', () => { it('stops after max retries are exhausted and returns error output', async () => { const errorState = { sessionState: {} as any, - output: { type: 'error', message: 'NetworkError: Connection timeout' } + output: { type: 'error', message: 'NetworkError: Connection timeout' }, } as RunState const runSpy = spyOn(runModule, 'runOnce').mockResolvedValue(errorState) @@ -73,7 +83,7 @@ describe('run retry wrapper', () => { it('does not retry non-retryable error outputs', async () => { const errorState = { sessionState: {} as any, - output: { type: 'error', message: 'Invalid input' } + output: { type: 'error', message: 'Invalid input' }, } as RunState const runSpy = spyOn(runModule, 'runOnce').mockResolvedValue(errorState) @@ -91,7 +101,7 @@ describe('run retry wrapper', () => { it('skips retry when retry is false even for retryable error outputs', async () => { const errorState = { sessionState: {} as any, - output: { type: 'error', message: 'NetworkError: Connection failed' } + output: { type: 'error', message: 'NetworkError: Connection failed' }, } as RunState const runSpy = spyOn(runModule, 'runOnce').mockResolvedValue(errorState) @@ -106,11 +116,14 @@ describe('run retry wrapper', () => { }) it('retries when provided custom retryableErrorCodes set', async () => { - const errorState = { + const errorState: RunState = { sessionState: {} as any, - output: { type: 'error', message: 'Server error (500)' } - } as RunState - const successState = { sessionState: {} as any, output: { type: 'lastMessage', value: 'hi' } } as RunState + output: { type: 'error', message: 'Server error (500)' }, + } + const successState: RunState = { + sessionState: {} as SessionState, + output: { type: 'lastMessage', value: [assistantMessage('hi')] }, + } const runSpy = spyOn(runModule, 'runOnce') .mockResolvedValueOnce(errorState) @@ -149,11 +162,14 @@ describe('run retry wrapper', () => { }) it('calls onRetry callback with correct parameters on error output', async () => { - const errorState = { - sessionState: {} as any, - output: { type: 'error', message: 'Service unavailable (503)' } - } as RunState - const successState = { sessionState: {} as any, output: { type: 'lastMessage', value: 'done' } } as RunState + const errorState: RunState = { + sessionState: {} as SessionState, + output: { type: 'error', message: 'Service unavailable (503)' }, + } + const successState: RunState = { + sessionState: {} as SessionState, + output: { type: 'lastMessage', value: [assistantMessage('done')] }, + } const runSpy = spyOn(runModule, 'runOnce') .mockResolvedValueOnce(errorState) @@ -178,7 +194,7 @@ describe('run retry wrapper', () => { it('calls onRetryExhausted after all retries fail', async () => { const errorState = { sessionState: {} as any, - output: { type: 'error', message: 'NetworkError: timeout' } + output: { type: 'error', message: 'NetworkError: timeout' }, } as RunState spyOn(runModule, 'runOnce').mockResolvedValue(errorState) @@ -200,7 +216,7 @@ describe('run retry wrapper', () => { it('returns error output without sessionState on first attempt failure', async () => { const errorState = { - output: { type: 'error', message: 'Not retryable' } + output: { type: 'error', message: 'Not retryable' }, } as RunState spyOn(runModule, 'runOnce').mockResolvedValue(errorState) @@ -216,14 +232,14 @@ describe('run retry wrapper', () => { it('preserves sessionState from previousRun on retry', async () => { const previousSession = { fileContext: { cwd: '/test' } } as any - const errorState = { - sessionState: { fileContext: { cwd: '/new' } } as any, - output: { type: 'error', message: 'Service unavailable' } - } as RunState - const successState = { - sessionState: { fileContext: { cwd: '/final' } } as any, - output: { type: 'lastMessage', value: 'ok' } - } as RunState + const errorState: RunState = { + sessionState: { fileContext: { cwd: '/new' } } as SessionState, + output: { type: 'error', message: 'Service unavailable' }, + } + const successState: RunState = { + sessionState: { fileContext: { cwd: '/final' } } as SessionState, + output: { type: 'lastMessage', value: [assistantMessage('ok')] }, + } const runSpy = spyOn(runModule, 'runOnce') .mockResolvedValueOnce(errorState) @@ -231,7 +247,10 @@ describe('run retry wrapper', () => { const result = await run({ ...baseOptions, - previousRun: { sessionState: previousSession, output: { type: 'lastMessage', value: 'prev' } }, + previousRun: { + sessionState: previousSession, + output: { type: 'lastMessage', value: [assistantMessage('prev')] }, + }, retry: { backoffBaseMs: 1, backoffMaxMs: 2 }, }) @@ -240,11 +259,17 @@ describe('run retry wrapper', () => { }) it('handles 503 Service Unavailable errors as retryable', async () => { - const errorState = { - sessionState: {} as any, - output: { type: 'error', message: 'Error from AI SDK: 503 Service Unavailable' } - } as RunState - const successState = { sessionState: {} as any, output: { type: 'lastMessage', value: 'ok' } } as RunState + const errorState: RunState = { + sessionState: {} as SessionState, + output: { + type: 'error', + message: 'Error from AI SDK: 503 Service Unavailable', + }, + } + const successState: RunState = { + sessionState: {} as SessionState, + output: { type: 'lastMessage', value: [assistantMessage('ok')] }, + } const runSpy = spyOn(runModule, 'runOnce') .mockResolvedValueOnce(errorState) @@ -260,11 +285,14 @@ describe('run retry wrapper', () => { }) it('applies exponential backoff correctly', async () => { - const errorState = { - sessionState: {} as any, - output: { type: 'error', message: 'NetworkError: Connection refused' } + const errorState: RunState = { + sessionState: {} as SessionState, + output: { type: 'error', message: 'NetworkError: Connection refused' }, } as RunState - const successState = { sessionState: {} as any, output: { type: 'lastMessage', value: 'ok' } } as RunState + const successState: RunState = { + sessionState: {} as SessionState, + output: { type: 'lastMessage', value: [assistantMessage('ok')] }, + } spyOn(runModule, 'runOnce') .mockResolvedValueOnce(errorState) diff --git a/sdk/src/impl/llm.ts b/sdk/src/impl/llm.ts index 07f9563c0..f968d14d2 100644 --- a/sdk/src/impl/llm.ts +++ b/sdk/src/impl/llm.ts @@ -18,12 +18,18 @@ import { OpenAICompatibleChatLanguageModel, VERSION, } from '@codebuff/internal/openai-compatible/index' -import { streamText, APICallError, generateText, generateObject } from 'ai' +import { + streamText, + generateText, + generateObject, + NoSuchToolError, + APICallError, + ToolCallRepairError, + InvalidToolInputError, + TypeValidationError, +} from 'ai' import { WEBSITE_URL } from '../constants' -import { NetworkError, PaymentRequiredError, ErrorCodes } from '../errors' - -import type { ErrorCode } from '../errors' import type { LanguageModelV2 } from '@ai-sdk/provider' import type { OpenRouterProviderRoutingOptions } from '@codebuff/common/types/agent-template' import type { @@ -217,13 +223,124 @@ export async function* promptAiSdkStream( ...params, agentProviderOptions: params.agentProviderOptions, }), + // Handle tool call errors gracefully by passing them through to our validation layer + // instead of throwing (which would halt the agent). The only special case is when + // the tool name matches a spawnable agent - transform those to spawn_agents calls. + experimental_repairToolCall: async ({ toolCall, tools, error }) => { + const { spawnableAgents = [], localAgentTemplates = {} } = params + const toolName = toolCall.toolName + + // Check if this is a NoSuchToolError for a spawnable agent + // If so, transform to spawn_agents call + if (NoSuchToolError.isInstance(error) && 'spawn_agents' in tools) { + // Also check for underscore variant (e.g., "file_picker" -> "file-picker") + const toolNameWithHyphens = toolName.replace(/_/g, '-') + + const matchingAgentId = spawnableAgents.find((agentId) => { + const withoutVersion = agentId.split('@')[0] + const parts = withoutVersion.split('/') + const agentName = parts[parts.length - 1] + return ( + agentName === toolName || + agentName === toolNameWithHyphens || + agentId === toolName + ) + }) + const isSpawnableAgent = matchingAgentId !== undefined + const isLocalAgent = + toolName in localAgentTemplates || + toolNameWithHyphens in localAgentTemplates + + if (isSpawnableAgent || isLocalAgent) { + // Transform agent tool call to spawn_agents + const deepParseJson = (value: unknown): unknown => { + if (typeof value === 'string') { + try { + return deepParseJson(JSON.parse(value)) + } catch { + return value + } + } + if (Array.isArray(value)) return value.map(deepParseJson) + if (value !== null && typeof value === 'object') { + return Object.fromEntries( + Object.entries(value).map(([k, v]) => [k, deepParseJson(v)]), + ) + } + return value + } + + let input: Record = {} + try { + const rawInput = + typeof toolCall.input === 'string' + ? JSON.parse(toolCall.input) + : (toolCall.input as Record) + input = deepParseJson(rawInput) as Record + } catch { + // If parsing fails, use empty object + } + + const prompt = + typeof input.prompt === 'string' ? input.prompt : undefined + const agentParams = Object.fromEntries( + Object.entries(input).filter( + ([key, value]) => + !(key === 'prompt' && typeof value === 'string'), + ), + ) + + // Use the matching agent ID or corrected name with hyphens + const correctedAgentType = + matchingAgentId ?? + (toolNameWithHyphens in localAgentTemplates + ? toolNameWithHyphens + : toolName) + + const spawnAgentsInput = { + agents: [ + { + agent_type: correctedAgentType, + ...(prompt !== undefined && { prompt }), + ...(Object.keys(agentParams).length > 0 && { + params: agentParams, + }), + }, + ], + } + + logger.info( + { originalToolName: toolName, transformedInput: spawnAgentsInput }, + 'Transformed agent tool call to spawn_agents', + ) + + return { + ...toolCall, + toolName: 'spawn_agents', + input: JSON.stringify(spawnAgentsInput), + } + } + } + + // For all other cases (invalid args, unknown tools, etc.), pass through + // the original tool call. + logger.info( + { + toolName, + errorType: error.name, + error: error.message, + }, + 'Tool error - passing through for graceful error handling', + ) + return toolCall + }, }) let content = '' const stopSequenceHandler = new StopSequenceHandler(params.stopSequences) - for await (const chunk of response.fullStream) { - if (chunk.type !== 'text-delta') { + for await (const chunkValue of response.fullStream) { + if (chunkValue.type !== 'text-delta') { const flushed = stopSequenceHandler.flush() if (flushed) { content += flushed @@ -234,84 +351,57 @@ export async function* promptAiSdkStream( } } } - if (chunk.type === 'error') { - logger.error( - { - chunk: { ...chunk, error: undefined }, - error: getErrorObject(chunk.error), - model: params.model, - }, - 'Error from AI SDK', - ) + if (chunkValue.type === 'error') { + // Error chunks from fullStream are non-network errors (tool failures, model issues, etc.) + // Network errors are thrown, not yielded as chunks. - const errorBody = APICallError.isInstance(chunk.error) - ? chunk.error.responseBody + const errorBody = APICallError.isInstance(chunkValue.error) + ? chunkValue.error.responseBody : undefined const mainErrorMessage = - chunk.error instanceof Error - ? chunk.error.message - : typeof chunk.error === 'string' - ? chunk.error - : JSON.stringify(chunk.error) - const errorMessage = `Error from AI SDK (model ${params.model}): ${buildArray([mainErrorMessage, errorBody]).join('\n')}` - - // Determine error code from the error - let errorCode: ErrorCode = ErrorCodes.UNKNOWN_ERROR - let statusCode: number | undefined - - if (APICallError.isInstance(chunk.error)) { - statusCode = chunk.error.statusCode - if (statusCode) { - if (statusCode === 402) { - // Payment required - extract message from JSON response body - let paymentErrorMessage = mainErrorMessage - if (errorBody) { - try { - const parsed = JSON.parse(errorBody) - paymentErrorMessage = parsed.message || errorBody - } catch { - paymentErrorMessage = errorBody - } - } - throw new PaymentRequiredError(paymentErrorMessage) - } else if (statusCode === 503) { - errorCode = ErrorCodes.SERVICE_UNAVAILABLE - } else if (statusCode >= 500) { - errorCode = ErrorCodes.SERVER_ERROR - } else if (statusCode === 408 || statusCode === 429) { - errorCode = ErrorCodes.TIMEOUT - } - } - } else if (chunk.error instanceof Error) { - // Check error message for error type indicators (case-insensitive) - const msg = chunk.error.message.toLowerCase() - if (msg.includes('service unavailable') || msg.includes('503')) { - errorCode = ErrorCodes.SERVICE_UNAVAILABLE - } else if ( - msg.includes('econnrefused') || - msg.includes('connection refused') - ) { - errorCode = ErrorCodes.CONNECTION_REFUSED - } else if (msg.includes('enotfound') || msg.includes('dns')) { - errorCode = ErrorCodes.DNS_FAILURE - } else if (msg.includes('timeout')) { - errorCode = ErrorCodes.TIMEOUT - } else if ( - msg.includes('server error') || - msg.includes('500') || - msg.includes('502') || - msg.includes('504') - ) { - errorCode = ErrorCodes.SERVER_ERROR - } else if (msg.includes('network') || msg.includes('fetch failed')) { - errorCode = ErrorCodes.NETWORK_ERROR + chunkValue.error instanceof Error + ? chunkValue.error.message + : typeof chunkValue.error === 'string' + ? chunkValue.error + : JSON.stringify(chunkValue.error) + const errorMessage = buildArray([mainErrorMessage, errorBody]).join('\n') + + // Pass these errors back to the agent so it can see what went wrong and retry. + // Note: If you find any other error types that should be passed through to the agent, add them here! + if ( + NoSuchToolError.isInstance(chunkValue.error) || + InvalidToolInputError.isInstance(chunkValue.error) || + ToolCallRepairError.isInstance(chunkValue.error) || + TypeValidationError.isInstance(chunkValue.error) + ) { + logger.warn( + { + chunk: { ...chunkValue, error: undefined }, + error: getErrorObject(chunkValue.error), + model: params.model, + }, + 'Tool call error in AI SDK stream - passing through to agent to retry', + ) + yield { + type: 'error', + message: errorMessage, } + continue } - // Throw NetworkError so retry logic can handle it - throw new NetworkError(errorMessage, errorCode, statusCode, chunk.error) + logger.error( + { + chunk: { ...chunkValue, error: undefined }, + error: getErrorObject(chunkValue.error), + model: params.model, + }, + 'Error in AI SDK stream', + ) + + // For all other errors, throw them -- they are fatal. + throw chunkValue.error } - if (chunk.type === 'reasoning-delta') { + if (chunkValue.type === 'reasoning-delta') { for (const provider of ['openrouter', 'codebuff'] as const) { if ( ( @@ -325,23 +415,23 @@ export async function* promptAiSdkStream( } yield { type: 'reasoning', - text: chunk.text, + text: chunkValue.text, } } - if (chunk.type === 'text-delta') { + if (chunkValue.type === 'text-delta') { if (!params.stopSequences) { - content += chunk.text - if (chunk.text) { + content += chunkValue.text + if (chunkValue.text) { yield { type: 'text', - text: chunk.text, + text: chunkValue.text, ...(agentChunkMetadata ?? {}), } } continue } - const stopSequenceResult = stopSequenceHandler.process(chunk.text) + const stopSequenceResult = stopSequenceHandler.process(chunkValue.text) if (stopSequenceResult.text) { content += stopSequenceResult.text yield { @@ -351,6 +441,9 @@ export async function* promptAiSdkStream( } } } + if (chunkValue.type === 'tool-call') { + yield chunkValue + } } const flushed = stopSequenceHandler.flush() if (flushed) { diff --git a/sdk/src/run.ts b/sdk/src/run.ts index 0f77ca046..44bba6fbd 100644 --- a/sdk/src/run.ts +++ b/sdk/src/run.ts @@ -932,7 +932,7 @@ async function handleToolCall({ if (override) { result = await override(input as any) } else if (toolName === 'end_turn') { - result = [] + result = [{ type: 'json', value: { message: 'Turn ended.' } }] } else if (toolName === 'write_file' || toolName === 'str_replace') { result = await changeFile({ parameters: input, diff --git a/web/src/app/api/v1/chat/completions/_post.ts b/web/src/app/api/v1/chat/completions/_post.ts index 0762f3f1b..45f99d675 100644 --- a/web/src/app/api/v1/chat/completions/_post.ts +++ b/web/src/app/api/v1/chat/completions/_post.ts @@ -12,6 +12,7 @@ import { import { handleOpenRouterNonStream, handleOpenRouterStream, + OpenRouterError, } from '@/llm-api/openrouter' import { extractApiKeyFromHeader } from '@/util/auth' @@ -339,6 +340,12 @@ export async function postChatCompletions(params: { }, logger, }) + + // Pass through OpenRouter provider-specific errors + if (error instanceof OpenRouterError) { + return NextResponse.json(error.toJSON(), { status: error.statusCode }) + } + return NextResponse.json( { error: 'Failed to process request' }, { status: 500 }, diff --git a/web/src/llm-api/openrouter.ts b/web/src/llm-api/openrouter.ts index d9a85ed64..173eb9bfc 100644 --- a/web/src/llm-api/openrouter.ts +++ b/web/src/llm-api/openrouter.ts @@ -6,7 +6,10 @@ import { extractRequestMetadata, insertMessageToBigQuery, } from './helpers' -import { OpenRouterStreamChatCompletionChunkSchema } from './type/openrouter' +import { + OpenRouterErrorResponseSchema, + OpenRouterStreamChatCompletionChunkSchema, +} from './type/openrouter' import type { UsageData } from './helpers' import type { OpenRouterStreamChatCompletionChunk } from './type/openrouter' @@ -14,7 +17,6 @@ import type { InsertMessageBigqueryFn } from '@codebuff/common/types/contracts/b import type { Logger } from '@codebuff/common/types/contracts/logger' type StreamState = { responseText: string; reasoningText: string } - function createOpenRouterRequest(params: { body: any openrouterApiKey: string | null @@ -93,9 +95,9 @@ export async function handleOpenRouterNonStream({ const responses = await Promise.all(requests) if (responses.every((r) => !r.ok)) { - throw new Error( - `Failed to make all ${n} requests: ${responses.map((r) => r.statusText).join(', ')}`, - ) + // Return provider-specific error from the first failed response + const firstFailedResponse = responses[0] + throw await parseOpenRouterError(firstFailedResponse) } const allData = await Promise.all(responses.map((r) => r.json())) @@ -183,9 +185,7 @@ export async function handleOpenRouterNonStream({ }) if (!response.ok) { - throw new Error( - `OpenRouter API error (${response.statusText}): ${await response.text()}`, - ) + throw await parseOpenRouterError(response) } const data = await response.json() @@ -261,9 +261,7 @@ export async function handleOpenRouterStream({ }) if (!response.ok) { - throw new Error( - `OpenRouter API error (${response.statusText}): ${await response.text()}`, - ) + throw await parseOpenRouterError(response) } const reader = response.body?.getReader() @@ -532,3 +530,84 @@ async function handleStreamChunk({ state.reasoningText += choice.delta?.reasoning ?? '' return state } + +/** + * Custom error class for OpenRouter API errors that preserves provider-specific details. + */ +export class OpenRouterError extends Error { + constructor( + public readonly statusCode: number, + public readonly statusText: string, + public readonly errorBody: { + error: { + message: string + code: string | number | null + type?: string | null + param?: unknown + metadata?: { + raw?: string + provider_name?: string + } + } + }, + ) { + super(errorBody.error.message) + this.name = 'OpenRouterError' + } + + /** + * Returns the error in a format suitable for API responses. + */ + toJSON() { + return { + error: { + message: this.errorBody.error.message, + code: this.errorBody.error.code, + type: this.errorBody.error.type, + param: this.errorBody.error.param, + metadata: this.errorBody.error.metadata, + }, + } + } +} + +/** + * Parses an error response from OpenRouter and returns an OpenRouterError. + */ +async function parseOpenRouterError( + response: Response, +): Promise { + const errorText = await response.text() + let errorBody: OpenRouterError['errorBody'] + try { + const parsed = JSON.parse(errorText) + const validated = OpenRouterErrorResponseSchema.safeParse(parsed) + if (validated.success) { + errorBody = { + error: { + message: validated.data.error.message, + code: validated.data.error.code ?? null, + type: validated.data.error.type, + param: validated.data.error.param, + // metadata is not in the schema but OpenRouter includes it for provider errors + metadata: (parsed as any).error?.metadata, + }, + } + } else { + errorBody = { + error: { + message: errorText || response.statusText, + code: response.status, + }, + } + } + } catch { + errorBody = { + error: { + message: errorText || response.statusText, + code: response.status, + }, + } + } + return new OpenRouterError(response.status, response.statusText, errorBody) +}