Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 48 additions & 6 deletions go/adk/pkg/tools/remote_a2a_tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
a2atype "github.com/a2aproject/a2a-go/a2a"
"github.com/a2aproject/a2a-go/a2aclient"
"github.com/a2aproject/a2a-go/a2aclient/agentcard"
"github.com/a2aproject/a2a-go/a2asrv"
"github.com/kagent-dev/kagent/go/adk/pkg/a2a"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"google.golang.org/adk/tool"
Expand All @@ -20,18 +21,59 @@ import (
// userIDContextKey is the context key for passing the session user_id to the subagent.
type userIDContextKey struct{}

// userIDForwardingInterceptor forwards the session user_id as an x-user-id header.
type userIDForwardingInterceptor struct {
// authorizationHeaderContextKey is the context key for passing the parent Authorization header to the subagent.
type authorizationHeaderContextKey struct{}

// subagentForwardingInterceptor forwards x-user-id and the parent Authorization header.
// Only Authorization is promoted from parent request metadata; all other headers stay scoped to the parent call.
type subagentForwardingInterceptor struct {
a2aclient.PassthroughInterceptor
}

func (u *userIDForwardingInterceptor) Before(ctx context.Context, req *a2aclient.Request) (context.Context, error) {
func (u *subagentForwardingInterceptor) Before(ctx context.Context, req *a2aclient.Request) (context.Context, error) {
if uid, ok := ctx.Value(userIDContextKey{}).(string); ok && uid != "" {
req.Meta.Append("x-user-id", uid)
}
if authorization, ok := ctx.Value(authorizationHeaderContextKey{}).(string); ok && authorization != "" {
for key := range req.Meta {
if strings.EqualFold(key, "authorization") {
delete(req.Meta, key)
}
}
req.Meta.Append("Authorization", authorization)
}
return ctx, nil
}

func subagentCallContext(ctx tool.Context) context.Context {
sendCtx := context.WithValue(ctx, userIDContextKey{}, ctx.UserID())
if authorization := authorizationHeaderFromContext(ctx); authorization != "" {
sendCtx = context.WithValue(sendCtx, authorizationHeaderContextKey{}, authorization)
}
return sendCtx
}

func authorizationHeaderFromContext(ctx context.Context) string {
callCtx, ok := a2asrv.CallContextFrom(ctx)
if !ok {
return ""
}
meta := callCtx.RequestMeta()
if meta == nil {
return ""
}
values, ok := meta.Get("authorization")
if !ok {
return ""
}
for _, value := range values {
if strings.TrimSpace(value) != "" {
return value
}
}
return ""
}

// remoteA2AInput is the typed argument for the remote A2A function tool.
type remoteA2AInput struct {
Request string `json:"request"`
Expand Down Expand Up @@ -121,7 +163,7 @@ func (s *remoteA2AState) ensureClient(ctx context.Context) (*a2aclient.Client, e
}
opts = append(opts, a2aclient.WithInterceptors(
a2aclient.NewStaticCallMetaInjector(meta),
&userIDForwardingInterceptor{},
&subagentForwardingInterceptor{},
))

client, err := a2aclient.NewFromCard(ctx, card, opts...)
Expand Down Expand Up @@ -159,7 +201,7 @@ func (s *remoteA2AState) handleFirstCall(ctx tool.Context, requestText string) (
)
message.ContextID = s.lastContextID

sendCtx := context.WithValue(ctx, userIDContextKey{}, ctx.UserID())
sendCtx := subagentCallContext(ctx)
result, err := client.SendMessage(sendCtx, &a2atype.MessageSendParams{Message: message})
if err != nil {
slog.Error("Remote agent request failed", "tool", s.name, "error", err)
Expand Down Expand Up @@ -209,7 +251,7 @@ func (s *remoteA2AState) handleResume(ctx tool.Context) (map[string]any, error)
return map[string]any{"error": err.Error()}, nil
}

sendCtx := context.WithValue(ctx, userIDContextKey{}, ctx.UserID())
sendCtx := subagentCallContext(ctx)
result, err := client.SendMessage(sendCtx, &a2atype.MessageSendParams{Message: message})
if err != nil {
slog.Error("Remote agent resume failed", "tool", subagentName, "error", err)
Expand Down
65 changes: 65 additions & 0 deletions go/adk/pkg/tools/remote_a2a_tool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package tools

import (
"context"
"reflect"
"testing"

"github.com/a2aproject/a2a-go/a2aclient"
"github.com/a2aproject/a2a-go/a2asrv"
)

func TestSubagentForwardingInterceptorForwardsUserIDAndAuthorization(t *testing.T) {
ctx := context.WithValue(context.Background(), userIDContextKey{}, "user-1")
ctx = context.WithValue(ctx, authorizationHeaderContextKey{}, "Bearer parent-token")

req := &a2aclient.Request{Meta: a2aclient.CallMeta{}}
_, err := (&subagentForwardingInterceptor{}).Before(ctx, req)
if err != nil {
t.Fatalf("Before() error = %v", err)
}

if got, want := req.Meta.Get("x-user-id"), []string{"user-1"}; !reflect.DeepEqual(got, want) {
t.Fatalf("x-user-id = %v, want %v", got, want)
}
if got, want := req.Meta.Get("authorization"), []string{"Bearer parent-token"}; !reflect.DeepEqual(got, want) {
t.Fatalf("authorization = %v, want %v", got, want)
}
}

func TestSubagentForwardingInterceptorReplacesExistingAuthorization(t *testing.T) {
ctx := context.WithValue(context.Background(), authorizationHeaderContextKey{}, "Bearer parent-token")

req := &a2aclient.Request{Meta: a2aclient.CallMeta{}}
req.Meta.Append("Authorization", "Bearer stale-token")

_, err := (&subagentForwardingInterceptor{}).Before(ctx, req)
if err != nil {
t.Fatalf("Before() error = %v", err)
}

if got, want := req.Meta.Get("authorization"), []string{"Bearer parent-token"}; !reflect.DeepEqual(got, want) {
t.Fatalf("authorization = %v, want %v", got, want)
}
}

func TestAuthorizationHeaderFromContext(t *testing.T) {
ctx, _ := a2asrv.WithCallContext(context.Background(), a2asrv.NewRequestMeta(map[string][]string{
"Authorization": []string{"Bearer parent-token"},
"X-Other": []string{"ignored"},
}))

if got, want := authorizationHeaderFromContext(ctx), "Bearer parent-token"; got != want {
t.Fatalf("authorizationHeaderFromContext() = %q, want %q", got, want)
}
}

func TestAuthorizationHeaderFromContextWithoutHeader(t *testing.T) {
ctx, _ := a2asrv.WithCallContext(context.Background(), a2asrv.NewRequestMeta(map[string][]string{
"X-Other": []string{"ignored"},
}))

if got := authorizationHeaderFromContext(ctx); got != "" {
t.Fatalf("authorizationHeaderFromContext() = %q, want empty", got)
}
}
38 changes: 31 additions & 7 deletions python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,27 +56,51 @@
logger = logging.getLogger("kagent_adk." + __name__)

_USER_ID_CONTEXT_KEY = "x-user-id"
_HEADERS_CONTEXT_KEY = "headers"
_SOURCE_HEADER = "x-kagent-source"
_SOURCE_SUBAGENT = "agent"


class _SubagentInterceptor(ClientCallInterceptor):
"""
Injects the authenticated user's ID as an ``x-user-id`` HTTP header and
marks the request as originating from an agent call via
``x-kagent-source: agent`` on every outgoing A2A request.
Injects the authenticated user's ID as an ``x-user-id`` HTTP header,
forwards the parent ``Authorization`` header when available, and marks
the request as originating from an agent call via ``x-kagent-source:
agent`` on every outgoing A2A request.

Only ``Authorization`` is promoted from parent session headers; all
other session headers remain context-only.
"""

async def intercept(self, method_name, request_payload, http_kwargs, agent_card, context):
headers = dict(http_kwargs.get("headers", {}))
# Always mark requests from a parent agent tool as subagent-originated
headers[_SOURCE_HEADER] = _SOURCE_SUBAGENT
if context and _USER_ID_CONTEXT_KEY in context.state:
headers["x-user-id"] = context.state[_USER_ID_CONTEXT_KEY]
if context:
if _USER_ID_CONTEXT_KEY in context.state:
headers["x-user-id"] = context.state[_USER_ID_CONTEXT_KEY]

request_headers = context.state.get(_HEADERS_CONTEXT_KEY, {})
if isinstance(request_headers, dict):
for key, value in request_headers.items():
if key.lower() == "authorization":
headers = {k: v for k, v in headers.items() if k.lower() != "authorization"}
headers[key] = value
break
Comment thread
towsif-rahman marked this conversation as resolved.
http_kwargs["headers"] = headers
return request_payload, http_kwargs


def _build_subagent_call_context(tool_context: ToolContext) -> ClientCallContext:
"""Build A2A call context for requests delegated to sub-agents."""
ctx_state = {_USER_ID_CONTEXT_KEY: tool_context.session.user_id}
session_state = getattr(tool_context.session, "state", {}) or {}
session_headers = session_state.get(_HEADERS_CONTEXT_KEY, {}) if isinstance(session_state, dict) else {}
if session_headers:
ctx_state[_HEADERS_CONTEXT_KEY] = session_headers
return ClientCallContext(state=ctx_state)


def _extract_text_from_task(task: Task) -> str:
"""Extract text content from a completed task's artifacts or status message."""
# Prefer artifacts (the canonical result)
Expand Down Expand Up @@ -239,7 +263,7 @@ async def _handle_first_call(self, args: dict[str, Any], tool_context: ToolConte

# Forward the authenticated user ID so the subagent session is scoped
# to the same user as the parent agent session.
call_context = ClientCallContext(state={_USER_ID_CONTEXT_KEY: tool_context.session.user_id})
call_context = _build_subagent_call_context(tool_context)

task: Optional[Task] = None
try:
Expand Down Expand Up @@ -381,7 +405,7 @@ async def _handle_resume(self, tool_context: ToolContext) -> Any:
)

client = await self._ensure_client()
call_context = ClientCallContext(state={_USER_ID_CONTEXT_KEY: tool_context.session.user_id})
call_context = _build_subagent_call_context(tool_context)
task: Optional[Task] = None
try:
async for response in client.send_message(request=decision_message, context=call_context):
Expand Down
Loading
Loading