Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 76 additions & 1 deletion go/adk/pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,8 @@ func extractHeaders(headers map[string]string) map[string]string {
return headers
}

// makeBeforeToolCallback returns a BeforeToolCallback that logs tool invocations.
// makeBeforeToolCallback returns a BeforeToolCallback that logs tool invocations
// and short-circuits calls where required parameters are nil.
func makeBeforeToolCallback(logger logr.Logger) llmagent.BeforeToolCallback {
return func(ctx tool.Context, t tool.Tool, args map[string]any) (map[string]any, error) {
logger.Info("Tool execution started",
Expand All @@ -357,10 +358,84 @@ func makeBeforeToolCallback(logger logr.Logger) llmagent.BeforeToolCallback {
"invocationID", ctx.InvocationID(),
"args", truncateArgs(args),
)

// Short-circuit if required params are missing or nil; the LLM may drop large values
// due to output-token limits, producing an opaque downstream error otherwise.
if missingParams := findNilRequiredParams(t, args); len(missingParams) > 0 {
msg := fmt.Sprintf(
"tool %q: required parameter(s) %v are missing or nil; the LLM may have omitted large values due to output-token limits",
t.Name(), missingParams,
)
Comment thread
areebahmeddd marked this conversation as resolved.
logger.Info("required parameters missing or nil, skipping tool call",
"tool", t.Name(),
"missingParams", missingParams,
)
return map[string]any{"error": msg}, nil
}

return nil, nil
}
}

// toolDeclarationProvider is implemented by tools that expose a genai.FunctionDeclaration.
type toolDeclarationProvider interface {
Declaration() *genai.FunctionDeclaration
}

// requiredParamNames returns the required parameter names from the tool's declaration.
func requiredParamNames(t tool.Tool) []string {
dp, ok := t.(toolDeclarationProvider)
if !ok {
return nil
}
decl := dp.Declaration()
if decl == nil {
return nil
}

// Structured schema (function tools).
if decl.Parameters != nil && len(decl.Parameters.Required) > 0 {
return decl.Parameters.Required
}

// Opaque JSON schema (MCP tools): re-marshal to extract the "required" array.
if decl.ParametersJsonSchema != nil {
return extractRequiredFromJSONSchema(decl.ParametersJsonSchema)
}

return nil
}

// extractRequiredFromJSONSchema decodes the "required" array from an opaque JSON-schema value.
func extractRequiredFromJSONSchema(schema any) []string {
b, err := json.Marshal(schema)
if err != nil {
return nil
}
var s struct {
Required []string `json:"required"`
}
if err := json.Unmarshal(b, &s); err != nil {
return nil
}
return s.Required
}

// findNilRequiredParams returns required parameter names whose value in args is nil or absent.
func findNilRequiredParams(t tool.Tool, args map[string]any) []string {
required := requiredParamNames(t)
if len(required) == 0 {
return nil
}
var missing []string
for _, param := range required {
if v, exists := args[param]; !exists || v == nil {
missing = append(missing, param)
}
}
return missing
}

// makeAfterToolCallback returns an AfterToolCallback that logs tool completion.
func makeAfterToolCallback(logger logr.Logger) llmagent.AfterToolCallback {
return func(ctx tool.Context, t tool.Tool, args, result map[string]any, err error) (map[string]any, error) {
Expand Down
189 changes: 189 additions & 0 deletions go/adk/pkg/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"github.com/go-logr/logr"
"github.com/kagent-dev/kagent/go/adk/pkg/models"
"github.com/kagent-dev/kagent/go/api/adk"
"google.golang.org/adk/tool"
"google.golang.org/genai"
)

// TestConfigDeserialization_OpenAI verifies that a realistic OpenAI config.json
Expand Down Expand Up @@ -424,3 +426,190 @@ func TestAgentConfigFieldUsage(t *testing.T) {
})
}
}

// mockTool implements tool.Tool and toolDeclarationProvider.
type mockTool struct {
name string
declaration *genai.FunctionDeclaration
}

func (m *mockTool) Name() string { return m.name }
func (m *mockTool) Description() string { return "" }
func (m *mockTool) IsLongRunning() bool { return false }

// Declaration implements toolDeclarationProvider.
func (m *mockTool) Declaration() *genai.FunctionDeclaration {
return m.declaration
}

// mockToolNoDeclaration implements tool.Tool without a Declaration method.
type mockToolNoDeclaration struct{ name string }

func (m *mockToolNoDeclaration) Name() string { return m.name }
func (m *mockToolNoDeclaration) Description() string { return "" }
func (m *mockToolNoDeclaration) IsLongRunning() bool { return false }

// TestExtractRequiredFromJSONSchema tests required field extraction from opaque JSON-schema values.
func TestExtractRequiredFromJSONSchema(t *testing.T) {
tests := []struct {
name string
schema any
want []string
}{
{
name: "map with required",
schema: map[string]any{
"type": "object",
"required": []any{"file_content", "filename"},
"properties": map[string]any{
"file_content": map[string]any{"type": "string"},
"filename": map[string]any{"type": "string"},
},
},
want: []string{"file_content", "filename"},
},
{
name: "map without required",
schema: map[string]any{
"type": "object",
"properties": map[string]any{
"optional_param": map[string]any{"type": "string"},
},
},
want: nil,
},
{
name: "nil schema",
schema: nil,
want: nil,
},
{
name: "empty map",
schema: map[string]any{},
want: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractRequiredFromJSONSchema(tt.schema)
if len(got) != len(tt.want) {
t.Fatalf("extractRequiredFromJSONSchema() = %v, want %v", got, tt.want)
}
for i := range got {
if got[i] != tt.want[i] {
t.Errorf("extractRequiredFromJSONSchema()[%d] = %q, want %q", i, got[i], tt.want[i])
}
}
})
}
}

// TestFindNilRequiredParams tests detection of nil required parameters.
func TestFindNilRequiredParams(t *testing.T) {
tests := []struct {
name string
tool tool.Tool
args map[string]any
wantNil bool
wantLen int
}{
{
name: "tool without Declaration - always OK",
tool: &mockToolNoDeclaration{name: "no_decl"},
args: map[string]any{"p": nil},
wantNil: true,
},
{
name: "all required params present",
tool: &mockTool{
name: "uploader",
declaration: &genai.FunctionDeclaration{
Parameters: &genai.Schema{
Required: []string{"file_content"},
},
},
},
args: map[string]any{"file_content": "hello"},
wantNil: true,
},
{
name: "required param is nil via genai schema",
tool: &mockTool{
name: "uploader",
declaration: &genai.FunctionDeclaration{
Parameters: &genai.Schema{
Required: []string{"file_content", "filename"},
},
},
},
args: map[string]any{"file_content": nil, "filename": "report.txt"},
wantNil: false,
wantLen: 1,
},
{
name: "required param absent from args via genai schema",
tool: &mockTool{
name: "uploader",
declaration: &genai.FunctionDeclaration{
Parameters: &genai.Schema{
Required: []string{"file_content"},
},
},
},
args: map[string]any{},
wantNil: false,
wantLen: 1,
},
{
name: "required param nil via JSON schema (MCP tool path)",
tool: &mockTool{
name: "sandbox_upload",
declaration: &genai.FunctionDeclaration{
ParametersJsonSchema: map[string]any{
"type": "object",
"required": []any{"file_content"},
},
},
},
args: map[string]any{"file_content": nil},
wantNil: false,
wantLen: 1,
},
{
name: "all required params present via JSON schema",
tool: &mockTool{
name: "sandbox_upload",
declaration: &genai.FunctionDeclaration{
ParametersJsonSchema: map[string]any{
"type": "object",
"required": []any{"file_content"},
},
},
},
args: map[string]any{"file_content": "data"},
wantNil: true,
},
{
name: "nil declaration",
tool: &mockTool{name: "tool_no_decl", declaration: nil},
args: map[string]any{"p": nil},
wantNil: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := findNilRequiredParams(tt.tool, tt.args)
if tt.wantNil {
if got != nil {
t.Errorf("findNilRequiredParams() = %v, want nil", got)
}
} else {
if len(got) != tt.wantLen {
t.Errorf("findNilRequiredParams() = %v (len %d), want len %d", got, len(got), tt.wantLen)
}
}
})
}
}