From 84a30ae102199ae12cab0c12af9bbc3005b87499 Mon Sep 17 00:00:00 2001 From: crazybolillo Date: Fri, 3 Apr 2026 22:20:52 -0600 Subject: [PATCH] feat: support structured output for agents To enable this, LanguageModel required updates to support tool calls as normal, non structured generation methods already do. These changes are fully backwards compatible. Note that for this feature to work, each provider must explicitly support tool calls alongside structured output. This commit only implements it for the OpenAI provider. If a provider lacks support, the agent will only generate structured output without executing tool calls. Closes #118. --- agent.go | 107 ++++++++++++++++--- object.go | 5 + object/object.go | 1 + providers/openai/responses_language_model.go | 36 +++++-- 4 files changed, 126 insertions(+), 23 deletions(-) diff --git a/agent.go b/agent.go index 34d4f8a1f..86c30affa 100644 --- a/agent.go +++ b/agent.go @@ -30,6 +30,8 @@ type stepExecutionResult struct { // StopCondition defines a function that determines when an agent should stop executing. type StopCondition = func(steps []StepResult) bool +type responseGenerator = func(ctx context.Context, model LanguageModel, call Call) (*Response, error) + // StepCountIs returns a stop condition that stops after the specified number of steps. func StepCountIs(stepCount int) StopCondition { return func(steps []StepResult) bool { @@ -304,6 +306,7 @@ type AgentResult struct { // Agent represents an AI agent that can generate responses and stream responses. type Agent interface { Generate(context.Context, AgentCall) (*AgentResult, error) + GenerateObject(context.Context, schema.Schema, AgentCall) (*AgentResult, error) Stream(context.Context, AgentStreamCall) (*AgentResult, error) } @@ -367,13 +370,12 @@ func (a *agent) prepareCall(call AgentCall) AgentCall { return call } -// Generate implements Agent. -func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) { - opts = a.prepareCall(opts) - initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...) - if err != nil { - return nil, err - } +func (a *agent) executeLoop( + ctx context.Context, + initialPrompt Prompt, + gen responseGenerator, + opts AgentCall, +) ([]StepResult, error) { var responseMessages []Message var steps []StepResult @@ -446,7 +448,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err retryOptions.OnRetry = opts.OnRetry retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions) result, err := retry(ctx, func() (*Response, error) { - return stepModel.Generate(ctx, Call{ + return gen(ctx, stepModel, Call{ Prompt: stepInputMessages, MaxOutputTokens: opts.MaxOutputTokens, Temperature: opts.Temperature, @@ -485,7 +487,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err toolResults, err := a.executeTools(ctx, stepTools, stepExecProviderTools, stepToolCalls, nil) - // Build step content with validated tool calls and tool results. // Provider-executed tool calls are kept as-is. + // Build step content with validated tool calls and tool results. Provider-executed tool calls are kept as-is. stepContent := []Content{} toolCallIndex := 0 for _, content := range result.Content { @@ -528,8 +530,12 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err } } - totalUsage := Usage{} + //nolint:nilerr // tool execution failure breaks the loop but does not prevent an answer from being returned + return steps, nil +} +func toAgentResult(steps []StepResult) *AgentResult { + totalUsage := Usage{} for _, step := range steps { usage := step.Usage totalUsage.InputTokens += usage.InputTokens @@ -540,12 +546,89 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err totalUsage.TotalTokens += usage.TotalTokens } - agentResult := &AgentResult{ + return &AgentResult{ Steps: steps, Response: steps[len(steps)-1].Response, TotalUsage: totalUsage, } - return agentResult, nil +} + +// Generate implements Agent. +func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) { + opts = a.prepareCall(opts) + initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...) + if err != nil { + return nil, err + } + steps, err := a.executeLoop( + ctx, + initialPrompt, + func(ctx context.Context, stepModel LanguageModel, call Call) (*Response, error) { + return stepModel.Generate(ctx, call) + }, + opts, + ) + if err != nil { + return nil, err + } + + return toAgentResult(steps), nil +} + +func (a *agent) GenerateObject(ctx context.Context, s schema.Schema, opts AgentCall) (*AgentResult, error) { + opts = a.prepareCall(opts) + initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...) + if err != nil { + return nil, err + } + + steps, err := a.executeLoop( + ctx, + initialPrompt, + func(ctx context.Context, model LanguageModel, call Call) (*Response, error) { + res, err := model.GenerateObject(ctx, ObjectCall{ + Prompt: call.Prompt, + Schema: s, + MaxOutputTokens: call.MaxOutputTokens, + Temperature: call.Temperature, + TopP: call.TopP, + TopK: call.TopK, + PresencePenalty: call.PresencePenalty, + FrequencyPenalty: call.FrequencyPenalty, + UserAgent: call.UserAgent, + ProviderOptions: call.ProviderOptions, + RepairText: nil, + Tools: call.Tools, + ToolChoice: call.ToolChoice, + }) + if err != nil { + return nil, err + } + + var content ResponseContent + for _, toolCall := range res.ToolCalls { + content = append(content, toolCall) + } + + if res.RawText != "" { + content = append(content, TextContent{Text: res.RawText}) + } + + return &Response{ + Content: content, + FinishReason: res.FinishReason, + Usage: res.Usage, + Warnings: res.Warnings, + ProviderMetadata: res.ProviderMetadata, + }, nil + }, + opts, + ) + if err != nil { + return nil, err + } + + return toAgentResult(steps), nil } func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool { diff --git a/object.go b/object.go index 3cbcbd3eb..da027e2cf 100644 --- a/object.go +++ b/object.go @@ -47,6 +47,9 @@ type ObjectCall struct { ProviderOptions ProviderOptions RepairText schema.ObjectRepairFunc + + Tools []Tool `json:"tools"` + ToolChoice *ToolChoice `json:"tool_choice"` } // ObjectResponse represents the response from a structured object generation. @@ -57,6 +60,7 @@ type ObjectResponse struct { FinishReason FinishReason Warnings []CallWarning ProviderMetadata ProviderMetadata + ToolCalls []ToolCallContent } // ObjectStreamPartType indicates the type of stream part. @@ -99,6 +103,7 @@ type ObjectResult[T any] struct { FinishReason FinishReason Warnings []CallWarning ProviderMetadata ProviderMetadata + ToolCalls []ToolCallContent } // StreamObjectResult provides typed access to a streaming object generation result. diff --git a/object/object.go b/object/object.go index 16c77d1cf..d8e1566e3 100644 --- a/object/object.go +++ b/object/object.go @@ -51,6 +51,7 @@ func Generate[T any]( FinishReason: resp.FinishReason, Warnings: resp.Warnings, ProviderMetadata: resp.ProviderMetadata, + ToolCalls: resp.ToolCalls, }, nil } diff --git a/providers/openai/responses_language_model.go b/providers/openai/responses_language_model.go index bd61a68ba..56bc1cf90 100644 --- a/providers/openai/responses_language_model.go +++ b/providers/openai/responses_language_model.go @@ -1325,6 +1325,8 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context, PresencePenalty: call.PresencePenalty, FrequencyPenalty: call.FrequencyPenalty, ProviderOptions: call.ProviderOptions, + Tools: call.Tools, + ToolChoice: call.ToolChoice, } params, warnings, err := o.prepareParams(fantasyCall) @@ -1350,8 +1352,8 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context, } } - // Extract JSON text from response var jsonText string + var toolCalls []fantasy.ToolCallContent for _, outputItem := range response.Output { if outputItem.Type == "message" { for _, contentPart := range outputItem.Content { @@ -1361,15 +1363,21 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context, } } } + if outputItem.Type == "function_call" { + toolCalls = append(toolCalls, fantasy.ToolCallContent{ + ProviderExecuted: false, + ToolCallID: outputItem.CallID, + ToolName: outputItem.Name, + Input: outputItem.Arguments.OfString, + }) + } } - if jsonText == "" { - usage := fantasy.Usage{ - InputTokens: response.Usage.InputTokens, - OutputTokens: response.Usage.OutputTokens, - TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, - } - finishReason := mapResponsesFinishReason(response.IncompleteDetails.Reason, false) + usage := responsesUsage(*response) + hasFunctionCall := len(toolCalls) > 0 + finishReason := mapResponsesFinishReason(response.IncompleteDetails.Reason, hasFunctionCall) + + if jsonText == "" && len(toolCalls) == 0 { return nil, &fantasy.NoObjectGeneratedError{ RawText: "", ParseError: fmt.Errorf("no text content in response"), @@ -1378,6 +1386,14 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context, } } + if jsonText == "" && len(toolCalls) > 0 { + return &fantasy.ObjectResponse{ + Usage: usage, + FinishReason: finishReason, + ToolCalls: toolCalls, + }, nil + } + // Parse and validate var obj any if call.RepairText != nil { @@ -1386,9 +1402,6 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context, obj, err = schema.ParseAndValidate(jsonText, call.Schema) } - usage := responsesUsage(*response) - finishReason := mapResponsesFinishReason(response.IncompleteDetails.Reason, false) - if err != nil { // Add usage info to error if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok { @@ -1405,6 +1418,7 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context, FinishReason: finishReason, Warnings: warnings, ProviderMetadata: responsesProviderMetadata(response.ID), + ToolCalls: toolCalls, }, nil }