diff --git a/go/adk/pkg/tools/remote_a2a_tool.go b/go/adk/pkg/tools/remote_a2a_tool.go index 9fc79b641..75472aa79 100644 --- a/go/adk/pkg/tools/remote_a2a_tool.go +++ b/go/adk/pkg/tools/remote_a2a_tool.go @@ -11,6 +11,7 @@ import ( a2atype "github.com/a2aproject/a2a-go/a2a" "github.com/a2aproject/a2a-go/a2aclient" "github.com/a2aproject/a2a-go/a2aclient/agentcard" + "github.com/a2aproject/a2a-go/a2asrv" "github.com/kagent-dev/kagent/go/adk/pkg/a2a" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "google.golang.org/adk/tool" @@ -20,18 +21,59 @@ import ( // userIDContextKey is the context key for passing the session user_id to the subagent. type userIDContextKey struct{} -// userIDForwardingInterceptor forwards the session user_id as an x-user-id header. -type userIDForwardingInterceptor struct { +// authorizationHeaderContextKey is the context key for passing the parent Authorization header to the subagent. +type authorizationHeaderContextKey struct{} + +// subagentForwardingInterceptor forwards x-user-id and the parent Authorization header. +// Only Authorization is promoted from parent request metadata; all other headers stay scoped to the parent call. +type subagentForwardingInterceptor struct { a2aclient.PassthroughInterceptor } -func (u *userIDForwardingInterceptor) Before(ctx context.Context, req *a2aclient.Request) (context.Context, error) { +func (u *subagentForwardingInterceptor) Before(ctx context.Context, req *a2aclient.Request) (context.Context, error) { if uid, ok := ctx.Value(userIDContextKey{}).(string); ok && uid != "" { req.Meta.Append("x-user-id", uid) } + if authorization, ok := ctx.Value(authorizationHeaderContextKey{}).(string); ok && authorization != "" { + for key := range req.Meta { + if strings.EqualFold(key, "authorization") { + delete(req.Meta, key) + } + } + req.Meta.Append("Authorization", authorization) + } return ctx, nil } +func subagentCallContext(ctx tool.Context) context.Context { + sendCtx := context.WithValue(ctx, userIDContextKey{}, ctx.UserID()) + if authorization := authorizationHeaderFromContext(ctx); authorization != "" { + sendCtx = context.WithValue(sendCtx, authorizationHeaderContextKey{}, authorization) + } + return sendCtx +} + +func authorizationHeaderFromContext(ctx context.Context) string { + callCtx, ok := a2asrv.CallContextFrom(ctx) + if !ok { + return "" + } + meta := callCtx.RequestMeta() + if meta == nil { + return "" + } + values, ok := meta.Get("authorization") + if !ok { + return "" + } + for _, value := range values { + if strings.TrimSpace(value) != "" { + return value + } + } + return "" +} + // remoteA2AInput is the typed argument for the remote A2A function tool. type remoteA2AInput struct { Request string `json:"request"` @@ -121,7 +163,7 @@ func (s *remoteA2AState) ensureClient(ctx context.Context) (*a2aclient.Client, e } opts = append(opts, a2aclient.WithInterceptors( a2aclient.NewStaticCallMetaInjector(meta), - &userIDForwardingInterceptor{}, + &subagentForwardingInterceptor{}, )) client, err := a2aclient.NewFromCard(ctx, card, opts...) @@ -159,7 +201,7 @@ func (s *remoteA2AState) handleFirstCall(ctx tool.Context, requestText string) ( ) message.ContextID = s.lastContextID - sendCtx := context.WithValue(ctx, userIDContextKey{}, ctx.UserID()) + sendCtx := subagentCallContext(ctx) result, err := client.SendMessage(sendCtx, &a2atype.MessageSendParams{Message: message}) if err != nil { slog.Error("Remote agent request failed", "tool", s.name, "error", err) @@ -209,7 +251,7 @@ func (s *remoteA2AState) handleResume(ctx tool.Context) (map[string]any, error) return map[string]any{"error": err.Error()}, nil } - sendCtx := context.WithValue(ctx, userIDContextKey{}, ctx.UserID()) + sendCtx := subagentCallContext(ctx) result, err := client.SendMessage(sendCtx, &a2atype.MessageSendParams{Message: message}) if err != nil { slog.Error("Remote agent resume failed", "tool", subagentName, "error", err) diff --git a/go/adk/pkg/tools/remote_a2a_tool_test.go b/go/adk/pkg/tools/remote_a2a_tool_test.go new file mode 100644 index 000000000..d4626fbab --- /dev/null +++ b/go/adk/pkg/tools/remote_a2a_tool_test.go @@ -0,0 +1,65 @@ +package tools + +import ( + "context" + "reflect" + "testing" + + "github.com/a2aproject/a2a-go/a2aclient" + "github.com/a2aproject/a2a-go/a2asrv" +) + +func TestSubagentForwardingInterceptorForwardsUserIDAndAuthorization(t *testing.T) { + ctx := context.WithValue(context.Background(), userIDContextKey{}, "user-1") + ctx = context.WithValue(ctx, authorizationHeaderContextKey{}, "Bearer parent-token") + + req := &a2aclient.Request{Meta: a2aclient.CallMeta{}} + _, err := (&subagentForwardingInterceptor{}).Before(ctx, req) + if err != nil { + t.Fatalf("Before() error = %v", err) + } + + if got, want := req.Meta.Get("x-user-id"), []string{"user-1"}; !reflect.DeepEqual(got, want) { + t.Fatalf("x-user-id = %v, want %v", got, want) + } + if got, want := req.Meta.Get("authorization"), []string{"Bearer parent-token"}; !reflect.DeepEqual(got, want) { + t.Fatalf("authorization = %v, want %v", got, want) + } +} + +func TestSubagentForwardingInterceptorReplacesExistingAuthorization(t *testing.T) { + ctx := context.WithValue(context.Background(), authorizationHeaderContextKey{}, "Bearer parent-token") + + req := &a2aclient.Request{Meta: a2aclient.CallMeta{}} + req.Meta.Append("Authorization", "Bearer stale-token") + + _, err := (&subagentForwardingInterceptor{}).Before(ctx, req) + if err != nil { + t.Fatalf("Before() error = %v", err) + } + + if got, want := req.Meta.Get("authorization"), []string{"Bearer parent-token"}; !reflect.DeepEqual(got, want) { + t.Fatalf("authorization = %v, want %v", got, want) + } +} + +func TestAuthorizationHeaderFromContext(t *testing.T) { + ctx, _ := a2asrv.WithCallContext(context.Background(), a2asrv.NewRequestMeta(map[string][]string{ + "Authorization": {"Bearer parent-token"}, + "X-Other": {"ignored"}, + })) + + if got, want := authorizationHeaderFromContext(ctx), "Bearer parent-token"; got != want { + t.Fatalf("authorizationHeaderFromContext() = %q, want %q", got, want) + } +} + +func TestAuthorizationHeaderFromContextWithoutHeader(t *testing.T) { + ctx, _ := a2asrv.WithCallContext(context.Background(), a2asrv.NewRequestMeta(map[string][]string{ + "X-Other": {"ignored"}, + })) + + if got := authorizationHeaderFromContext(ctx); got != "" { + t.Fatalf("authorizationHeaderFromContext() = %q, want empty", got) + } +} diff --git a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py index ff73d0c6d..9a2fbe4f1 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py +++ b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py @@ -56,27 +56,51 @@ logger = logging.getLogger("kagent_adk." + __name__) _USER_ID_CONTEXT_KEY = "x-user-id" +_HEADERS_CONTEXT_KEY = "headers" _SOURCE_HEADER = "x-kagent-source" _SOURCE_SUBAGENT = "agent" class _SubagentInterceptor(ClientCallInterceptor): """ - Injects the authenticated user's ID as an ``x-user-id`` HTTP header and - marks the request as originating from an agent call via - ``x-kagent-source: agent`` on every outgoing A2A request. + Injects the authenticated user's ID as an ``x-user-id`` HTTP header, + forwards the parent ``Authorization`` header when available, and marks + the request as originating from an agent call via ``x-kagent-source: + agent`` on every outgoing A2A request. + + Only ``Authorization`` is promoted from parent session headers; all + other session headers remain context-only. """ async def intercept(self, method_name, request_payload, http_kwargs, agent_card, context): headers = dict(http_kwargs.get("headers", {})) # Always mark requests from a parent agent tool as subagent-originated headers[_SOURCE_HEADER] = _SOURCE_SUBAGENT - if context and _USER_ID_CONTEXT_KEY in context.state: - headers["x-user-id"] = context.state[_USER_ID_CONTEXT_KEY] + if context: + if _USER_ID_CONTEXT_KEY in context.state: + headers["x-user-id"] = context.state[_USER_ID_CONTEXT_KEY] + + request_headers = context.state.get(_HEADERS_CONTEXT_KEY, {}) + if isinstance(request_headers, dict): + for key, value in request_headers.items(): + if key.lower() == "authorization": + headers = {k: v for k, v in headers.items() if k.lower() != "authorization"} + headers[key] = value + break http_kwargs["headers"] = headers return request_payload, http_kwargs +def _build_subagent_call_context(tool_context: ToolContext) -> ClientCallContext: + """Build A2A call context for requests delegated to sub-agents.""" + ctx_state = {_USER_ID_CONTEXT_KEY: tool_context.session.user_id} + session_state = getattr(tool_context.session, "state", {}) or {} + session_headers = session_state.get(_HEADERS_CONTEXT_KEY, {}) if isinstance(session_state, dict) else {} + if session_headers: + ctx_state[_HEADERS_CONTEXT_KEY] = session_headers + return ClientCallContext(state=ctx_state) + + def _extract_text_from_task(task: Task) -> str: """Extract text content from a completed task's artifacts or status message.""" # Prefer artifacts (the canonical result) @@ -239,7 +263,7 @@ async def _handle_first_call(self, args: dict[str, Any], tool_context: ToolConte # Forward the authenticated user ID so the subagent session is scoped # to the same user as the parent agent session. - call_context = ClientCallContext(state={_USER_ID_CONTEXT_KEY: tool_context.session.user_id}) + call_context = _build_subagent_call_context(tool_context) task: Optional[Task] = None try: @@ -381,7 +405,7 @@ async def _handle_resume(self, tool_context: ToolContext) -> Any: ) client = await self._ensure_client() - call_context = ClientCallContext(state={_USER_ID_CONTEXT_KEY: tool_context.session.user_id}) + call_context = _build_subagent_call_context(tool_context) task: Optional[Task] = None try: async for response in client.send_message(request=decision_message, context=call_context): diff --git a/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py b/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py index 3185ad8e2..72c40601a 100644 --- a/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py +++ b/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import httpx +from a2a.client.middleware import ClientCallContext from a2a.types import ( DataPart, Role, @@ -26,6 +27,7 @@ KAgentRemoteA2ATool, KAgentRemoteA2AToolset, SubagentSessionProvider, + _SubagentInterceptor, ) # --------------------------------------------------------------------------- @@ -38,8 +40,9 @@ class _MockSession: """Minimal session mock providing user_id.""" - def __init__(self, user_id: str = _DEFAULT_USER_ID): + def __init__(self, user_id: str = _DEFAULT_USER_ID, state: dict[str, Any] | None = None): self.user_id = user_id + self.state = state or {} class MockToolContext: @@ -49,11 +52,12 @@ def __init__( self, tool_confirmation: ToolConfirmation | None = None, user_id: str = _DEFAULT_USER_ID, + session_state: dict[str, Any] | None = None, ): self.state: dict[str, Any] = {} self.function_call_id = "outer_fc_1" self.tool_confirmation = tool_confirmation - self.session = _MockSession(user_id) + self.session = _MockSession(user_id, session_state) self._confirmations: dict[str, ToolConfirmation] = {} def request_confirmation(self, *, hint: str = "", payload: dict | None = None) -> None: @@ -237,6 +241,48 @@ async def capture(*, request, context=None, **kw): assert captured_contexts[0].state["x-user-id"] == "alice@example.com" + async def test_session_headers_forwarded_in_call_context(self): + """The parent session's request headers are forwarded via ClientCallContext.""" + tool = _make_tool() + task = _make_task(TaskState.completed, text="ok") + captured_contexts: list = [] + session_headers = { + "Authorization": "Bearer user-token", + "x-not-forwarded-by-interceptor": "kept-in-context-only", + } + + async def capture(*, request, context=None, **kw): + captured_contexts.append(context) + yield (task, None) + + p, _ = _patch_client(tool, capture) + try: + ctx = MockToolContext(session_state={"headers": session_headers}) + await tool.run_async(args={"request": "go"}, tool_context=ctx) + finally: + p.stop() + + assert captured_contexts[0].state["headers"] == session_headers + + async def test_call_context_omits_headers_when_session_has_none(self): + """Missing session headers do not prevent building a valid ClientCallContext.""" + tool = _make_tool() + task = _make_task(TaskState.completed, text="ok") + captured_contexts: list = [] + + async def capture(*, request, context=None, **kw): + captured_contexts.append(context) + yield (task, None) + + p, _ = _patch_client(tool, capture) + try: + ctx = MockToolContext(user_id="alice@example.com") + await tool.run_async(args={"request": "go"}, tool_context=ctx) + finally: + p.stop() + + assert captured_contexts[0].state == {"x-user-id": "alice@example.com"} + # --------------------------------------------------------------------------- # HITL input_required tests @@ -404,6 +450,84 @@ async def test_resume_input_required_chains(self): assert ctx.function_call_id in ctx._confirmations assert "restart_pod" in ctx._confirmations[ctx.function_call_id].hint + async def test_session_headers_forwarded_in_resume_call_context(self): + """Resume calls also forward parent session request headers via ClientCallContext.""" + tool = _make_tool() + task = _make_task(TaskState.completed, text="ok") + captured_contexts: list = [] + session_headers = {"Authorization": "Bearer resumed-user-token"} + + async def capture(*, request, context=None, **kw): + captured_contexts.append(context) + yield (task, None) + + p, _ = _patch_client(tool, capture) + try: + ctx = _approval_ctx( + confirmed=True, + payload=_RESUME_PAYLOAD, + session_state={"headers": session_headers}, + ) + await tool.run_async(args={}, tool_context=ctx) + finally: + p.stop() + + assert captured_contexts[0].state["headers"] == session_headers + + +# --------------------------------------------------------------------------- +# Subagent interceptor tests +# --------------------------------------------------------------------------- + + +class TestSubagentInterceptor: + async def test_forwards_authorization_from_context_headers(self): + interceptor = _SubagentInterceptor() + context = ClientCallContext( + state={ + "x-user-id": "alice@example.com", + "headers": { + "Authorization": "Bearer user-token", + "x-secret-header": "should-not-forward", + }, + } + ) + + _, http_kwargs = await interceptor.intercept( + "send_message", + {}, + {"headers": {"authorization": "Bearer stale-token", "accept": "application/json"}}, + None, + context, + ) + + assert http_kwargs["headers"]["Authorization"] == "Bearer user-token" + assert http_kwargs["headers"]["x-user-id"] == "alice@example.com" + assert http_kwargs["headers"]["x-kagent-source"] == "agent" + assert http_kwargs["headers"]["accept"] == "application/json" + assert "authorization" not in http_kwargs["headers"] + assert "x-secret-header" not in http_kwargs["headers"] + + async def test_ignores_context_headers_without_authorization(self): + interceptor = _SubagentInterceptor() + context = ClientCallContext( + state={ + "headers": { + "x-secret-header": "should-not-forward", + }, + } + ) + + _, http_kwargs = await interceptor.intercept( + "send_message", + {}, + {"headers": {}}, + None, + context, + ) + + assert http_kwargs["headers"] == {"x-kagent-source": "agent"} + # --------------------------------------------------------------------------- # Toolset lifecycle tests