Skip to content

Commit 6f9d48e

Browse files
committed
fix: Connect to MCP Servers during agent execution only
This way, the trace propagation should work better and concurrent access to MCP Servers is improved.
1 parent d7b1fd1 commit 6f9d48e

File tree

3 files changed

+38
-86
lines changed

3 files changed

+38
-86
lines changed

msaf/agenticlayer/msaf/agent.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from agent_framework._mcp import MCPStreamableHTTPTool
1515
from agent_framework._tools import FunctionTool
1616
from agenticlayer.shared.config import McpTool, SubAgent
17-
from agenticlayer.shared.otel import TraceContextHttpClient
1817
from httpx_retries import Retry, RetryTransport
1918

2019
logger = logging.getLogger(__name__)
@@ -136,15 +135,11 @@ def create_mcp_tools(self, mcp_tools: list[McpTool]) -> list[MCPStreamableHTTPTo
136135
tools: list[MCPStreamableHTTPTool] = []
137136
for mcp_tool in mcp_tools:
138137
logger.info("Creating MCP tool %s at %s", mcp_tool.name, mcp_tool.url)
139-
# Use TraceContextHttpClient so that stored trace context is injected
140-
# into MCP HTTP requests made by the background post_writer task.
141-
http_client = TraceContextHttpClient()
142138
tools.append(
143139
MCPStreamableHTTPTool(
144140
name=mcp_tool.name,
145141
url=str(mcp_tool.url),
146142
request_timeout=mcp_tool.timeout,
147-
http_client=http_client,
148143
)
149144
)
150145
return tools

msaf/agenticlayer/msaf/agent_to_a2a.py

Lines changed: 38 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from agent_framework._mcp import MCPStreamableHTTPTool
3131
from agent_framework._tools import FunctionTool
3232
from agenticlayer.shared.config import McpTool, SubAgent
33-
from agenticlayer.shared.otel import TraceContextHttpClient
3433
from httpx_retries import Retry
3534
from starlette.applications import Starlette
3635

@@ -55,10 +54,14 @@ class MsafAgentExecutor(AgentExecutor):
5554
def __init__(
5655
self,
5756
agent: SupportsAgentRun,
58-
extra_tools: list[FunctionTool | MCPStreamableHTTPTool] | None = None,
57+
sub_agent_tools: list[FunctionTool] | None = None,
58+
mcp_tool_configs: list[McpTool] | None = None,
59+
agent_factory: MsafAgentFactory | None = None,
5960
) -> None:
6061
self._agent = agent
61-
self._extra_tools: list[FunctionTool | MCPStreamableHTTPTool] = extra_tools or []
62+
self._sub_agent_tools: list[FunctionTool] = sub_agent_tools or []
63+
self._mcp_tool_configs: list[McpTool] = mcp_tool_configs or []
64+
self._agent_factory = agent_factory
6265

6366
async def execute(self, context: RequestContext, event_queue: EventQueue) -> None:
6467
"""Execute the agent and publish results to the event queue."""
@@ -93,16 +96,15 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non
9396
)
9497

9598
try:
96-
# Capture current OTel trace context so that MCP HTTP requests
97-
# (which run in a background post_writer task without span context)
98-
# carry the correct traceparent/tracestate headers.
99-
for tool in self._extra_tools:
100-
if isinstance(tool, MCPStreamableHTTPTool):
101-
client = getattr(tool, "_httpx_client", None)
102-
if isinstance(client, TraceContextHttpClient):
103-
client.capture_trace_context()
104-
105-
response = await self._agent.run(user_input, tools=self._extra_tools if self._extra_tools else None)
99+
async with contextlib.AsyncExitStack() as stack:
100+
mcp_tools: list[MCPStreamableHTTPTool] = []
101+
if self._mcp_tool_configs and self._agent_factory:
102+
for mcp_tool in self._agent_factory.create_mcp_tools(self._mcp_tool_configs):
103+
await stack.enter_async_context(mcp_tool)
104+
mcp_tools.append(mcp_tool)
105+
106+
all_tools: list[FunctionTool | MCPStreamableHTTPTool] = [*self._sub_agent_tools, *mcp_tools]
107+
response = await self._agent.run(user_input, tools=all_tools if all_tools else None)
106108
response_text = response.text if hasattr(response, "text") else str(response)
107109

108110
await event_queue.enqueue_event(
@@ -162,7 +164,8 @@ async def create_a2a_app(
162164
description: str | None,
163165
rpc_url: str,
164166
sub_agent_tools: list[FunctionTool],
165-
mcp_tools: list[MCPStreamableHTTPTool],
167+
mcp_tool_configs: list[McpTool] | None = None,
168+
agent_factory: MsafAgentFactory | None = None,
166169
) -> A2AStarletteApplication:
167170
"""Create an A2A Starlette application from a Microsoft Agent Framework agent.
168171
@@ -172,14 +175,19 @@ async def create_a2a_app(
172175
description: Optional description of the agent
173176
rpc_url: The URL where the agent will be available for A2A communication
174177
sub_agent_tools: Pre-loaded FunctionTools wrapping remote A2A sub-agents
175-
mcp_tools: Connected MCPStreamableHTTPTool instances
178+
mcp_tool_configs: MCP tool configurations; per-request connections are created at execution time
179+
agent_factory: Factory used to create MCP tools per request
176180
177181
Returns:
178182
An A2AStarletteApplication instance
179183
"""
180184
task_store = InMemoryTaskStore()
181-
extra_tools: list[FunctionTool | MCPStreamableHTTPTool] = [*sub_agent_tools, *mcp_tools]
182-
agent_executor = MsafAgentExecutor(agent=agent, extra_tools=extra_tools if extra_tools else None)
185+
agent_executor = MsafAgentExecutor(
186+
agent=agent,
187+
sub_agent_tools=sub_agent_tools if sub_agent_tools else None,
188+
mcp_tool_configs=mcp_tool_configs,
189+
agent_factory=agent_factory,
190+
)
183191
request_handler = DefaultRequestHandler(agent_executor=agent_executor, task_store=task_store)
184192

185193
agent_card = AgentCard(
@@ -247,44 +255,33 @@ async def _build_app(
247255
sub_agents: list[SubAgent],
248256
tools: list[McpTool],
249257
factory: MsafAgentFactory,
250-
) -> tuple[A2AStarletteApplication, list[MCPStreamableHTTPTool]]:
251-
"""Load sub-agents, connect MCP tools, and return the A2A app plus tools to manage.
258+
) -> A2AStarletteApplication:
259+
"""Load sub-agents and return the A2A app.
252260
253-
The returned MCP tools have already been entered as async context managers;
254-
the caller (lifespan) is responsible for exiting them on shutdown.
261+
MCP tools are created per-request inside the executor; no connections are
262+
established here.
255263
"""
256264
sub_agent_tools = await factory.load_sub_agents(sub_agents)
257-
mcp_tools = factory.create_mcp_tools(tools)
258-
259-
connected_mcp: list[MCPStreamableHTTPTool] = []
260-
try:
261-
for mcp_tool in mcp_tools:
262-
await mcp_tool.__aenter__()
263-
connected_mcp.append(mcp_tool)
264-
except Exception:
265-
for already_connected in reversed(connected_mcp):
266-
await already_connected.__aexit__(None, None, None)
267-
raise
268-
269-
app = await create_a2a_app(
265+
266+
return await create_a2a_app(
270267
agent=agent,
271268
name=name,
272269
description=description,
273270
rpc_url=rpc_url,
274271
sub_agent_tools=sub_agent_tools,
275-
mcp_tools=connected_mcp,
272+
mcp_tool_configs=tools if tools else None,
273+
agent_factory=factory,
276274
)
277-
return app, connected_mcp
278275

279276

280277
def to_starlette(
281-
a2a_app_creator: Callable[[], Awaitable[tuple[A2AStarletteApplication, list[MCPStreamableHTTPTool]]]],
278+
a2a_app_creator: Callable[[], Awaitable[A2AStarletteApplication]],
282279
) -> Starlette:
283280
"""Convert an A2A application creator to a Starlette application.
284281
285282
Args:
286-
a2a_app_creator: A callable that creates an A2AStarletteApplication and
287-
connected MCP tools asynchronously during startup.
283+
a2a_app_creator: A callable that creates an A2AStarletteApplication
284+
asynchronously during startup.
288285
289286
Returns:
290287
A Starlette application that can be run with uvicorn
@@ -295,14 +292,10 @@ def to_starlette(
295292

296293
@contextlib.asynccontextmanager
297294
async def lifespan(app: Starlette) -> AsyncIterator[None]:
298-
a2a_app, connected_mcp = await a2a_app_creator()
295+
a2a_app = await a2a_app_creator()
299296
# Add A2A routes to the main app
300297
a2a_app.add_routes_to_app(app)
301-
try:
302-
yield
303-
finally:
304-
for mcp_tool in connected_mcp:
305-
await mcp_tool.__aexit__(None, None, None)
298+
yield
306299

307300
# Create a Starlette app that will be configured during startup
308301
starlette_app = Starlette(lifespan=lifespan)

shared/agenticlayer/shared/otel.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,12 @@
22

33
import logging
44
import os
5-
from typing import Any
65

76
import httpx
87
from opentelemetry import _logs, metrics, trace
98
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
109
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
1110
from opentelemetry.instrumentation.logging import LoggingInstrumentor
12-
from opentelemetry.propagate import inject
1311
from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
1412
from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
1513
from opentelemetry.sdk.metrics import MeterProvider
@@ -59,40 +57,6 @@ def response_hook(span: trace.Span, request: httpx.Request, response: httpx.Resp
5957
_logger.exception("Failed to log response body")
6058

6159

62-
class TraceContextHttpClient(httpx.AsyncClient):
63-
"""httpx client that propagates stored trace context headers into every request.
64-
65-
Used for MCP clients where HTTP requests happen in background tasks
66-
that don't inherit the request handler's OTel span context.
67-
68-
Call :meth:`capture_trace_context` from the request handler context
69-
(before agent execution) to snapshot the current trace context for
70-
later injection by the background ``post_writer`` task.
71-
"""
72-
73-
def __init__(self, **kwargs: Any) -> None:
74-
self._trace_headers: dict[str, str] = {}
75-
# Apply MCP-compatible defaults (follow redirects, no env proxy lookup)
76-
kwargs.setdefault("follow_redirects", True)
77-
super().__init__(**kwargs)
78-
# Prepend our hook so stored headers are set first; the monkey-patched
79-
# _async_inject_trace_context hook (from setup_otel) runs after and will
80-
# overwrite with live context when an active span exists.
81-
self.event_hooks.setdefault("request", []).insert(0, self._inject_trace_headers)
82-
83-
async def _inject_trace_headers(self, request: httpx.Request) -> None:
84-
"""Inject stored trace context headers into the request."""
85-
for k, v in self._trace_headers.items():
86-
request.headers[k] = v
87-
88-
def capture_trace_context(self) -> None:
89-
"""Capture current OTel trace context for injection into future requests."""
90-
carrier: dict[str, str] = {}
91-
inject(carrier)
92-
if carrier:
93-
self._trace_headers = carrier
94-
95-
9660
def setup_otel() -> None:
9761
"""Set up OpenTelemetry tracing, logging and metrics (framework-independent)."""
9862
# Set log level for urllib to WARNING to reduce noise (like sending logs to OTLP)

0 commit comments

Comments
 (0)