diff --git a/agent.go b/agent.go index 34d4f8a1f..86c30affa 100644 --- a/agent.go +++ b/agent.go @@ -30,6 +30,8 @@ type stepExecutionResult struct { // StopCondition defines a function that determines when an agent should stop executing. type StopCondition = func(steps []StepResult) bool +type responseGenerator = func(ctx context.Context, model LanguageModel, call Call) (*Response, error) + // StepCountIs returns a stop condition that stops after the specified number of steps. func StepCountIs(stepCount int) StopCondition { return func(steps []StepResult) bool { @@ -304,6 +306,7 @@ type AgentResult struct { // Agent represents an AI agent that can generate responses and stream responses. type Agent interface { Generate(context.Context, AgentCall) (*AgentResult, error) + GenerateObject(context.Context, schema.Schema, AgentCall) (*AgentResult, error) Stream(context.Context, AgentStreamCall) (*AgentResult, error) } @@ -367,13 +370,12 @@ func (a *agent) prepareCall(call AgentCall) AgentCall { return call } -// Generate implements Agent. -func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) { - opts = a.prepareCall(opts) - initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...) - if err != nil { - return nil, err - } +func (a *agent) executeLoop( + ctx context.Context, + initialPrompt Prompt, + gen responseGenerator, + opts AgentCall, +) ([]StepResult, error) { var responseMessages []Message var steps []StepResult @@ -446,7 +448,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err retryOptions.OnRetry = opts.OnRetry retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions) result, err := retry(ctx, func() (*Response, error) { - return stepModel.Generate(ctx, Call{ + return gen(ctx, stepModel, Call{ Prompt: stepInputMessages, MaxOutputTokens: opts.MaxOutputTokens, Temperature: opts.Temperature, @@ -485,7 +487,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err toolResults, err := a.executeTools(ctx, stepTools, stepExecProviderTools, stepToolCalls, nil) - // Build step content with validated tool calls and tool results. // Provider-executed tool calls are kept as-is. + // Build step content with validated tool calls and tool results. Provider-executed tool calls are kept as-is. stepContent := []Content{} toolCallIndex := 0 for _, content := range result.Content { @@ -528,8 +530,12 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err } } - totalUsage := Usage{} + //nolint:nilerr // tool execution failure breaks the loop but does not prevent an answer from being returned + return steps, nil +} +func toAgentResult(steps []StepResult) *AgentResult { + totalUsage := Usage{} for _, step := range steps { usage := step.Usage totalUsage.InputTokens += usage.InputTokens @@ -540,12 +546,89 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err totalUsage.TotalTokens += usage.TotalTokens } - agentResult := &AgentResult{ + return &AgentResult{ Steps: steps, Response: steps[len(steps)-1].Response, TotalUsage: totalUsage, } - return agentResult, nil +} + +// Generate implements Agent. +func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) { + opts = a.prepareCall(opts) + initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...) + if err != nil { + return nil, err + } + steps, err := a.executeLoop( + ctx, + initialPrompt, + func(ctx context.Context, stepModel LanguageModel, call Call) (*Response, error) { + return stepModel.Generate(ctx, call) + }, + opts, + ) + if err != nil { + return nil, err + } + + return toAgentResult(steps), nil +} + +func (a *agent) GenerateObject(ctx context.Context, s schema.Schema, opts AgentCall) (*AgentResult, error) { + opts = a.prepareCall(opts) + initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...) + if err != nil { + return nil, err + } + + steps, err := a.executeLoop( + ctx, + initialPrompt, + func(ctx context.Context, model LanguageModel, call Call) (*Response, error) { + res, err := model.GenerateObject(ctx, ObjectCall{ + Prompt: call.Prompt, + Schema: s, + MaxOutputTokens: call.MaxOutputTokens, + Temperature: call.Temperature, + TopP: call.TopP, + TopK: call.TopK, + PresencePenalty: call.PresencePenalty, + FrequencyPenalty: call.FrequencyPenalty, + UserAgent: call.UserAgent, + ProviderOptions: call.ProviderOptions, + RepairText: nil, + Tools: call.Tools, + ToolChoice: call.ToolChoice, + }) + if err != nil { + return nil, err + } + + var content ResponseContent + for _, toolCall := range res.ToolCalls { + content = append(content, toolCall) + } + + if res.RawText != "" { + content = append(content, TextContent{Text: res.RawText}) + } + + return &Response{ + Content: content, + FinishReason: res.FinishReason, + Usage: res.Usage, + Warnings: res.Warnings, + ProviderMetadata: res.ProviderMetadata, + }, nil + }, + opts, + ) + if err != nil { + return nil, err + } + + return toAgentResult(steps), nil } func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool { diff --git a/object.go b/object.go index 3cbcbd3eb..da027e2cf 100644 --- a/object.go +++ b/object.go @@ -47,6 +47,9 @@ type ObjectCall struct { ProviderOptions ProviderOptions RepairText schema.ObjectRepairFunc + + Tools []Tool `json:"tools"` + ToolChoice *ToolChoice `json:"tool_choice"` } // ObjectResponse represents the response from a structured object generation. @@ -57,6 +60,7 @@ type ObjectResponse struct { FinishReason FinishReason Warnings []CallWarning ProviderMetadata ProviderMetadata + ToolCalls []ToolCallContent } // ObjectStreamPartType indicates the type of stream part. @@ -99,6 +103,7 @@ type ObjectResult[T any] struct { FinishReason FinishReason Warnings []CallWarning ProviderMetadata ProviderMetadata + ToolCalls []ToolCallContent } // StreamObjectResult provides typed access to a streaming object generation result. diff --git a/object/object.go b/object/object.go index 16c77d1cf..d8e1566e3 100644 --- a/object/object.go +++ b/object/object.go @@ -51,6 +51,7 @@ func Generate[T any]( FinishReason: resp.FinishReason, Warnings: resp.Warnings, ProviderMetadata: resp.ProviderMetadata, + ToolCalls: resp.ToolCalls, }, nil } diff --git a/providers/openai/responses_language_model.go b/providers/openai/responses_language_model.go index bd61a68ba..56bc1cf90 100644 --- a/providers/openai/responses_language_model.go +++ b/providers/openai/responses_language_model.go @@ -1325,6 +1325,8 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context, PresencePenalty: call.PresencePenalty, FrequencyPenalty: call.FrequencyPenalty, ProviderOptions: call.ProviderOptions, + Tools: call.Tools, + ToolChoice: call.ToolChoice, } params, warnings, err := o.prepareParams(fantasyCall) @@ -1350,8 +1352,8 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context, } } - // Extract JSON text from response var jsonText string + var toolCalls []fantasy.ToolCallContent for _, outputItem := range response.Output { if outputItem.Type == "message" { for _, contentPart := range outputItem.Content { @@ -1361,15 +1363,21 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context, } } } + if outputItem.Type == "function_call" { + toolCalls = append(toolCalls, fantasy.ToolCallContent{ + ProviderExecuted: false, + ToolCallID: outputItem.CallID, + ToolName: outputItem.Name, + Input: outputItem.Arguments.OfString, + }) + } } - if jsonText == "" { - usage := fantasy.Usage{ - InputTokens: response.Usage.InputTokens, - OutputTokens: response.Usage.OutputTokens, - TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, - } - finishReason := mapResponsesFinishReason(response.IncompleteDetails.Reason, false) + usage := responsesUsage(*response) + hasFunctionCall := len(toolCalls) > 0 + finishReason := mapResponsesFinishReason(response.IncompleteDetails.Reason, hasFunctionCall) + + if jsonText == "" && len(toolCalls) == 0 { return nil, &fantasy.NoObjectGeneratedError{ RawText: "", ParseError: fmt.Errorf("no text content in response"), @@ -1378,6 +1386,14 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context, } } + if jsonText == "" && len(toolCalls) > 0 { + return &fantasy.ObjectResponse{ + Usage: usage, + FinishReason: finishReason, + ToolCalls: toolCalls, + }, nil + } + // Parse and validate var obj any if call.RepairText != nil { @@ -1386,9 +1402,6 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context, obj, err = schema.ParseAndValidate(jsonText, call.Schema) } - usage := responsesUsage(*response) - finishReason := mapResponsesFinishReason(response.IncompleteDetails.Reason, false) - if err != nil { // Add usage info to error if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok { @@ -1405,6 +1418,7 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context, FinishReason: finishReason, Warnings: warnings, ProviderMetadata: responsesProviderMetadata(response.ID), + ToolCalls: toolCalls, }, nil }