diff --git a/go/adk/pkg/agent/agent.go b/go/adk/pkg/agent/agent.go index 3d6159768..dc8fd3bc3 100644 --- a/go/adk/pkg/agent/agent.go +++ b/go/adk/pkg/agent/agent.go @@ -347,7 +347,8 @@ func extractHeaders(headers map[string]string) map[string]string { return headers } -// makeBeforeToolCallback returns a BeforeToolCallback that logs tool invocations. +// makeBeforeToolCallback returns a BeforeToolCallback that logs tool invocations +// and short-circuits calls where required parameters are nil. func makeBeforeToolCallback(logger logr.Logger) llmagent.BeforeToolCallback { return func(ctx tool.Context, t tool.Tool, args map[string]any) (map[string]any, error) { logger.Info("Tool execution started", @@ -357,10 +358,84 @@ func makeBeforeToolCallback(logger logr.Logger) llmagent.BeforeToolCallback { "invocationID", ctx.InvocationID(), "args", truncateArgs(args), ) + + // Short-circuit if required params are missing or nil; the LLM may drop large values + // due to output-token limits, producing an opaque downstream error otherwise. + if missingParams := findNilRequiredParams(t, args); len(missingParams) > 0 { + msg := fmt.Sprintf( + "tool %q: required parameter(s) %v are missing or nil; the LLM may have omitted large values due to output-token limits", + t.Name(), missingParams, + ) + logger.Info("required parameters missing or nil, skipping tool call", + "tool", t.Name(), + "missingParams", missingParams, + ) + return map[string]any{"error": msg}, nil + } + return nil, nil } } +// toolDeclarationProvider is implemented by tools that expose a genai.FunctionDeclaration. +type toolDeclarationProvider interface { + Declaration() *genai.FunctionDeclaration +} + +// requiredParamNames returns the required parameter names from the tool's declaration. +func requiredParamNames(t tool.Tool) []string { + dp, ok := t.(toolDeclarationProvider) + if !ok { + return nil + } + decl := dp.Declaration() + if decl == nil { + return nil + } + + // Structured schema (function tools). + if decl.Parameters != nil && len(decl.Parameters.Required) > 0 { + return decl.Parameters.Required + } + + // Opaque JSON schema (MCP tools): re-marshal to extract the "required" array. + if decl.ParametersJsonSchema != nil { + return extractRequiredFromJSONSchema(decl.ParametersJsonSchema) + } + + return nil +} + +// extractRequiredFromJSONSchema decodes the "required" array from an opaque JSON-schema value. +func extractRequiredFromJSONSchema(schema any) []string { + b, err := json.Marshal(schema) + if err != nil { + return nil + } + var s struct { + Required []string `json:"required"` + } + if err := json.Unmarshal(b, &s); err != nil { + return nil + } + return s.Required +} + +// findNilRequiredParams returns required parameter names whose value in args is nil or absent. +func findNilRequiredParams(t tool.Tool, args map[string]any) []string { + required := requiredParamNames(t) + if len(required) == 0 { + return nil + } + var missing []string + for _, param := range required { + if v, exists := args[param]; !exists || v == nil { + missing = append(missing, param) + } + } + return missing +} + // makeAfterToolCallback returns an AfterToolCallback that logs tool completion. func makeAfterToolCallback(logger logr.Logger) llmagent.AfterToolCallback { return func(ctx tool.Context, t tool.Tool, args, result map[string]any, err error) (map[string]any, error) { diff --git a/go/adk/pkg/agent/agent_test.go b/go/adk/pkg/agent/agent_test.go index 4f4c5fb45..664dd1661 100644 --- a/go/adk/pkg/agent/agent_test.go +++ b/go/adk/pkg/agent/agent_test.go @@ -9,6 +9,8 @@ import ( "github.com/go-logr/logr" "github.com/kagent-dev/kagent/go/adk/pkg/models" "github.com/kagent-dev/kagent/go/api/adk" + "google.golang.org/adk/tool" + "google.golang.org/genai" ) // TestConfigDeserialization_OpenAI verifies that a realistic OpenAI config.json @@ -424,3 +426,190 @@ func TestAgentConfigFieldUsage(t *testing.T) { }) } } + +// mockTool implements tool.Tool and toolDeclarationProvider. +type mockTool struct { + name string + declaration *genai.FunctionDeclaration +} + +func (m *mockTool) Name() string { return m.name } +func (m *mockTool) Description() string { return "" } +func (m *mockTool) IsLongRunning() bool { return false } + +// Declaration implements toolDeclarationProvider. +func (m *mockTool) Declaration() *genai.FunctionDeclaration { + return m.declaration +} + +// mockToolNoDeclaration implements tool.Tool without a Declaration method. +type mockToolNoDeclaration struct{ name string } + +func (m *mockToolNoDeclaration) Name() string { return m.name } +func (m *mockToolNoDeclaration) Description() string { return "" } +func (m *mockToolNoDeclaration) IsLongRunning() bool { return false } + +// TestExtractRequiredFromJSONSchema tests required field extraction from opaque JSON-schema values. +func TestExtractRequiredFromJSONSchema(t *testing.T) { + tests := []struct { + name string + schema any + want []string + }{ + { + name: "map with required", + schema: map[string]any{ + "type": "object", + "required": []any{"file_content", "filename"}, + "properties": map[string]any{ + "file_content": map[string]any{"type": "string"}, + "filename": map[string]any{"type": "string"}, + }, + }, + want: []string{"file_content", "filename"}, + }, + { + name: "map without required", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "optional_param": map[string]any{"type": "string"}, + }, + }, + want: nil, + }, + { + name: "nil schema", + schema: nil, + want: nil, + }, + { + name: "empty map", + schema: map[string]any{}, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractRequiredFromJSONSchema(tt.schema) + if len(got) != len(tt.want) { + t.Fatalf("extractRequiredFromJSONSchema() = %v, want %v", got, tt.want) + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("extractRequiredFromJSONSchema()[%d] = %q, want %q", i, got[i], tt.want[i]) + } + } + }) + } +} + +// TestFindNilRequiredParams tests detection of nil required parameters. +func TestFindNilRequiredParams(t *testing.T) { + tests := []struct { + name string + tool tool.Tool + args map[string]any + wantNil bool + wantLen int + }{ + { + name: "tool without Declaration - always OK", + tool: &mockToolNoDeclaration{name: "no_decl"}, + args: map[string]any{"p": nil}, + wantNil: true, + }, + { + name: "all required params present", + tool: &mockTool{ + name: "uploader", + declaration: &genai.FunctionDeclaration{ + Parameters: &genai.Schema{ + Required: []string{"file_content"}, + }, + }, + }, + args: map[string]any{"file_content": "hello"}, + wantNil: true, + }, + { + name: "required param is nil via genai schema", + tool: &mockTool{ + name: "uploader", + declaration: &genai.FunctionDeclaration{ + Parameters: &genai.Schema{ + Required: []string{"file_content", "filename"}, + }, + }, + }, + args: map[string]any{"file_content": nil, "filename": "report.txt"}, + wantNil: false, + wantLen: 1, + }, + { + name: "required param absent from args via genai schema", + tool: &mockTool{ + name: "uploader", + declaration: &genai.FunctionDeclaration{ + Parameters: &genai.Schema{ + Required: []string{"file_content"}, + }, + }, + }, + args: map[string]any{}, + wantNil: false, + wantLen: 1, + }, + { + name: "required param nil via JSON schema (MCP tool path)", + tool: &mockTool{ + name: "sandbox_upload", + declaration: &genai.FunctionDeclaration{ + ParametersJsonSchema: map[string]any{ + "type": "object", + "required": []any{"file_content"}, + }, + }, + }, + args: map[string]any{"file_content": nil}, + wantNil: false, + wantLen: 1, + }, + { + name: "all required params present via JSON schema", + tool: &mockTool{ + name: "sandbox_upload", + declaration: &genai.FunctionDeclaration{ + ParametersJsonSchema: map[string]any{ + "type": "object", + "required": []any{"file_content"}, + }, + }, + }, + args: map[string]any{"file_content": "data"}, + wantNil: true, + }, + { + name: "nil declaration", + tool: &mockTool{name: "tool_no_decl", declaration: nil}, + args: map[string]any{"p": nil}, + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := findNilRequiredParams(tt.tool, tt.args) + if tt.wantNil { + if got != nil { + t.Errorf("findNilRequiredParams() = %v, want nil", got) + } + } else { + if len(got) != tt.wantLen { + t.Errorf("findNilRequiredParams() = %v (len %d), want len %d", got, len(got), tt.wantLen) + } + } + }) + } +}