Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 95 additions & 12 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ type stepExecutionResult struct {
// StopCondition defines a function that determines when an agent should stop executing.
// It is given the step results collected so far and reports whether the agent should stop.
type StopCondition = func(steps []StepResult) bool

// responseGenerator produces the Response for a single model call made by the
// agent loop. It receives the context, the model selected for the step and the
// Call built for that step, so that callers of executeLoop can decide which
// model method is actually invoked.
type responseGenerator = func(ctx context.Context, model LanguageModel, call Call) (*Response, error)

// StepCountIs returns a stop condition that stops after the specified number of steps.
func StepCountIs(stepCount int) StopCondition {
return func(steps []StepResult) bool {
Expand Down Expand Up @@ -304,6 +306,7 @@ type AgentResult struct {
// Agent represents an AI agent that can generate responses and stream responses.
type Agent interface {
	// Generate runs the agent for the given call and returns the collected result.
	Generate(ctx context.Context, call AgentCall) (*AgentResult, error)

	// GenerateObject runs the agent for the given call, passing the schema s
	// to every model call, and returns the collected result.
	GenerateObject(ctx context.Context, s schema.Schema, call AgentCall) (*AgentResult, error)

	// Stream is the streaming variant; it takes an AgentStreamCall and returns
	// the collected result.
	Stream(ctx context.Context, call AgentStreamCall) (*AgentResult, error)
}

Expand Down Expand Up @@ -367,13 +370,12 @@ func (a *agent) prepareCall(call AgentCall) AgentCall {
return call
}

// Generate implements Agent.
func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
opts = a.prepareCall(opts)
initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
if err != nil {
return nil, err
}
func (a *agent) executeLoop(
ctx context.Context,
initialPrompt Prompt,
gen responseGenerator,
opts AgentCall,
) ([]StepResult, error) {
var responseMessages []Message
var steps []StepResult

Expand Down Expand Up @@ -446,7 +448,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
retryOptions.OnRetry = opts.OnRetry
retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
result, err := retry(ctx, func() (*Response, error) {
return stepModel.Generate(ctx, Call{
return gen(ctx, stepModel, Call{
Prompt: stepInputMessages,
MaxOutputTokens: opts.MaxOutputTokens,
Temperature: opts.Temperature,
Expand Down Expand Up @@ -485,7 +487,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err

toolResults, err := a.executeTools(ctx, stepTools, stepExecProviderTools, stepToolCalls, nil)

// Build step content with validated tool calls and tool results. // Provider-executed tool calls are kept as-is.
// Build step content with validated tool calls and tool results. Provider-executed tool calls are kept as-is.
stepContent := []Content{}
toolCallIndex := 0
for _, content := range result.Content {
Expand Down Expand Up @@ -528,8 +530,12 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
}
}

totalUsage := Usage{}
//nolint:nilerr // tool execution failure breaks the loop but does not prevent an answer from being returned
return steps, nil
}

func toAgentResult(steps []StepResult) *AgentResult {
totalUsage := Usage{}
for _, step := range steps {
usage := step.Usage
totalUsage.InputTokens += usage.InputTokens
Expand All @@ -540,12 +546,89 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
totalUsage.TotalTokens += usage.TotalTokens
}

agentResult := &AgentResult{
return &AgentResult{
Steps: steps,
Response: steps[len(steps)-1].Response,
TotalUsage: totalUsage,
}
return agentResult, nil
}

// Generate implements Agent.
//
// It applies the agent defaults to opts, builds the initial prompt, runs the
// agent loop with the step model's Generate method as the per-step generator,
// and converts the resulting steps into an AgentResult. Any error from prompt
// creation or from the loop is returned unchanged with a nil result.
func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
	opts = a.prepareCall(opts)

	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
	if err != nil {
		return nil, err
	}

	// Plain generation forwards each step's Call to the step model untouched.
	generateStep := func(stepCtx context.Context, stepModel LanguageModel, call Call) (*Response, error) {
		return stepModel.Generate(stepCtx, call)
	}

	steps, loopErr := a.executeLoop(ctx, initialPrompt, generateStep, opts)
	if loopErr != nil {
		return nil, loopErr
	}
	return toAgentResult(steps), nil
}

// GenerateObject implements Agent.
//
// It runs the same agent loop as Generate, but every step calls the model's
// GenerateObject method with the schema s. The ObjectResponse of each step is
// converted into a Response whose content holds the returned tool calls
// followed by the raw text (when it is non-empty). Any error from prompt
// creation or from the loop is returned unchanged with a nil result.
func (a *agent) GenerateObject(ctx context.Context, s schema.Schema, opts AgentCall) (*AgentResult, error) {
	opts = a.prepareCall(opts)

	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
	if err != nil {
		return nil, err
	}

	// generateStep adapts the step's Call to an ObjectCall and maps the
	// ObjectResponse back to the Response shape expected by the agent loop.
	generateStep := func(stepCtx context.Context, stepModel LanguageModel, call Call) (*Response, error) {
		objectCall := ObjectCall{
			Prompt:           call.Prompt,
			Schema:           s,
			MaxOutputTokens:  call.MaxOutputTokens,
			Temperature:      call.Temperature,
			TopP:             call.TopP,
			TopK:             call.TopK,
			PresencePenalty:  call.PresencePenalty,
			FrequencyPenalty: call.FrequencyPenalty,
			UserAgent:        call.UserAgent,
			ProviderOptions:  call.ProviderOptions,
			RepairText:       nil,
			Tools:            call.Tools,
			ToolChoice:       call.ToolChoice,
		}

		objResp, genErr := stepModel.GenerateObject(stepCtx, objectCall)
		if genErr != nil {
			return nil, genErr
		}

		// Tool calls come first, then the raw text if the model produced any.
		var parts ResponseContent
		for i := range objResp.ToolCalls {
			parts = append(parts, objResp.ToolCalls[i])
		}
		if objResp.RawText != "" {
			parts = append(parts, TextContent{Text: objResp.RawText})
		}

		stepResponse := &Response{
			Content:          parts,
			FinishReason:     objResp.FinishReason,
			Usage:            objResp.Usage,
			Warnings:         objResp.Warnings,
			ProviderMetadata: objResp.ProviderMetadata,
		}
		return stepResponse, nil
	}

	steps, loopErr := a.executeLoop(ctx, initialPrompt, generateStep, opts)
	if loopErr != nil {
		return nil, loopErr
	}
	return toAgentResult(steps), nil
}

func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
Expand Down
5 changes: 5 additions & 0 deletions object.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ type ObjectCall struct {
ProviderOptions ProviderOptions

RepairText schema.ObjectRepairFunc

Tools []Tool `json:"tools"`
ToolChoice *ToolChoice `json:"tool_choice"`
}

// ObjectResponse represents the response from a structured object generation.
Expand All @@ -57,6 +60,7 @@ type ObjectResponse struct {
FinishReason FinishReason
Warnings []CallWarning
ProviderMetadata ProviderMetadata
ToolCalls []ToolCallContent
}

// ObjectStreamPartType indicates the type of stream part.
Expand Down Expand Up @@ -99,6 +103,7 @@ type ObjectResult[T any] struct {
FinishReason FinishReason
Warnings []CallWarning
ProviderMetadata ProviderMetadata
ToolCalls []ToolCallContent
}

// StreamObjectResult provides typed access to a streaming object generation result.
Expand Down
1 change: 1 addition & 0 deletions object/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ func Generate[T any](
FinishReason: resp.FinishReason,
Warnings: resp.Warnings,
ProviderMetadata: resp.ProviderMetadata,
ToolCalls: resp.ToolCalls,
}, nil
}

Expand Down
36 changes: 25 additions & 11 deletions providers/openai/responses_language_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,8 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
PresencePenalty: call.PresencePenalty,
FrequencyPenalty: call.FrequencyPenalty,
ProviderOptions: call.ProviderOptions,
Tools: call.Tools,
ToolChoice: call.ToolChoice,
}

params, warnings, err := o.prepareParams(fantasyCall)
Expand All @@ -1350,8 +1352,8 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
}
}

// Extract JSON text from response
var jsonText string
var toolCalls []fantasy.ToolCallContent
for _, outputItem := range response.Output {
if outputItem.Type == "message" {
for _, contentPart := range outputItem.Content {
Expand All @@ -1361,15 +1363,21 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
}
}
}
if outputItem.Type == "function_call" {
toolCalls = append(toolCalls, fantasy.ToolCallContent{
ProviderExecuted: false,
ToolCallID: outputItem.CallID,
ToolName: outputItem.Name,
Input: outputItem.Arguments.OfString,
})
}
}

if jsonText == "" {
usage := fantasy.Usage{
InputTokens: response.Usage.InputTokens,
OutputTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
}
finishReason := mapResponsesFinishReason(response.IncompleteDetails.Reason, false)
usage := responsesUsage(*response)
hasFunctionCall := len(toolCalls) > 0
finishReason := mapResponsesFinishReason(response.IncompleteDetails.Reason, hasFunctionCall)

if jsonText == "" && len(toolCalls) == 0 {
return nil, &fantasy.NoObjectGeneratedError{
RawText: "",
ParseError: fmt.Errorf("no text content in response"),
Expand All @@ -1378,6 +1386,14 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
}
}

if jsonText == "" && len(toolCalls) > 0 {
return &fantasy.ObjectResponse{
Usage: usage,
FinishReason: finishReason,
ToolCalls: toolCalls,
}, nil
}

// Parse and validate
var obj any
if call.RepairText != nil {
Expand All @@ -1386,9 +1402,6 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
obj, err = schema.ParseAndValidate(jsonText, call.Schema)
}

usage := responsesUsage(*response)
finishReason := mapResponsesFinishReason(response.IncompleteDetails.Reason, false)

if err != nil {
// Add usage info to error
if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
Expand All @@ -1405,6 +1418,7 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
FinishReason: finishReason,
Warnings: warnings,
ProviderMetadata: responsesProviderMetadata(response.ID),
ToolCalls: toolCalls,
}, nil
}

Expand Down
Loading