Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 166 additions & 3 deletions src/bedrock_agentcore/evaluation/utils/cloudwatch_span_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def query_log_group(
start_time: datetime,
end_time: datetime,
query_string: Optional[str] = None,
agent_id: Optional[str] = None,
) -> List[dict]:
"""Query a single CloudWatch log group for session data.

Expand All @@ -54,16 +55,23 @@ def query_log_group(
end_time: Query end time
query_string: Optional custom query string. When provided, used instead
of the default substring match query.
agent_id: Optional agent ID to filter by (prevents cross-agent session collisions)

Returns:
List of parsed JSON log messages
"""
if query_string is None:
agent_filter = ""
if agent_id is not None:
agent_filter = (
f'\n | parse resource.attributes.cloud.resource_id "runtime/*/" as parsedAgentId'
f"\n | filter parsedAgentId = '{agent_id}'"
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[CRITICAL] CWL Insights injection — agent_id, session_id, trace_id, endpoint_name

Every new method interpolates caller-controlled values directly into CWL Insights query strings with only quote-wrapping and no validation:

f"| filter parsedAgentId = '{agent_id}'"          # line 69
f'| filter @message like "{session_id}"'           # line 71
f"| filter traceId = '{trace_id}'"                 # lines 234, 356
f"/aws/bedrock-agentcore/runtimes/{agent_id}-{endpoint_name}"  # line 261 (log group name)
f"| filter parsedAgentId = '{agent_id}'"           # line 291

A value like `x' | stats count(*) as n` rewrites the query semantics. A malicious endpoint_name can pivot the log group path to any group the caller's IAM role can read. In any multi-tenant SDK wrapper this is a cross-tenant log disclosure vector. No test exercises a hostile input.

Fix: validate all interpolated values against a strict pattern (e.g. ^[A-Za-z0-9_\-:.]{1,64}$) and reject anything else before interpolation.

query_string = f"""fields @timestamp, @message
| filter @message like "{session_id}"
| filter ispresent(scope.name)
| filter ispresent(traceId)
| filter ispresent(spanId)
| filter ispresent(spanId){agent_filter}
| sort @timestamp asc"""

max_attempts = 30
Expand Down Expand Up @@ -149,6 +157,7 @@ def fetch_spans(
event_log_group: str,
start_time: datetime,
end_time: Optional[datetime] = None,
agent_id: Optional[str] = None,
) -> List[dict]:
"""Fetch ADOT spans from CloudWatch with configurable event log group.

Expand All @@ -162,6 +171,7 @@ def fetch_spans(
- For custom agents: Any log group you configured (e.g., "/my-app/agent-events")
start_time: Start time for log query
end_time: End time for log query
agent_id: Optional agent ID to filter by (prevents cross-agent session collisions)

Returns:
List of ADOT span and log record dictionaries
Expand Down Expand Up @@ -190,14 +200,167 @@ def fetch_spans(
end_time = datetime.now()

# Query both log groups
aws_spans = self.query_log_group("aws/spans", session_id, start_time, end_time)
event_logs = self.query_log_group(event_log_group, session_id, start_time, end_time)
aws_spans = self.query_log_group("aws/spans", session_id, start_time, end_time, agent_id=agent_id)
event_logs = self.query_log_group(event_log_group, session_id, start_time, end_time, agent_id=agent_id)

all_data = aws_spans + event_logs

logger.info("Fetched %d span items from CloudWatch", len(all_data))
return all_data

def query_spans_by_trace(
    self,
    trace_id: str,
    start_time_ms: int,
    end_time_ms: int,
) -> List[dict]:
    """Query all spans for a trace from the aws/spans log group.

    Note: Trace IDs are globally unique, so no agent_id filter is needed
    to prevent cross-agent session collisions.

    Args:
        trace_id: The trace ID to query
        start_time_ms: Start time in milliseconds since epoch
        end_time_ms: End time in milliseconds since epoch

    Returns:
        List of result dictionaries from CloudWatch Logs Insights

    Raises:
        ValueError: If trace_id contains characters outside
            ``[A-Za-z0-9_\\-:.]`` or exceeds 64 characters. The value is
            interpolated into the Insights query string, so this rejects
            injection payloads such as ``x' | stats count(*)``.
    """
    import re

    # trace_id is interpolated into the query below; only allow a strict
    # identifier alphabet so the query semantics cannot be rewritten.
    if not re.fullmatch(r"[A-Za-z0-9_\-:.]{1,64}", trace_id or ""):
        raise ValueError(f"Invalid trace_id: {trace_id!r}")

    query_string = f"""fields @timestamp, @message, traceId, spanId, name as spanName,
        kind, status.code as statusCode, status.message as statusMessage,
        durationNano/1000000 as durationMs, attributes.session.id as sessionId,
        startTimeUnixNano, endTimeUnixNano, parentSpanId, events,
        resource.attributes.service.name as serviceName
        | filter traceId = '{trace_id}'
        | sort startTimeUnixNano asc"""
    return self._execute_query(query_string, "aws/spans", start_time_ms, end_time_ms)

def query_runtime_logs_by_traces(
    self,
    trace_ids: List[str],
    start_time_ms: int,
    end_time_ms: int,
    agent_id: str,
    endpoint_name: str = "DEFAULT",
) -> List[dict]:
    """Query runtime logs for multiple traces from the agent-specific log group.

    Args:
        trace_ids: List of trace IDs to query
        start_time_ms: Start time in milliseconds since epoch
        end_time_ms: End time in milliseconds since epoch
        agent_id: Agent ID for constructing the log group name
        endpoint_name: Runtime endpoint name (default: DEFAULT)

    Returns:
        List of result dictionaries from CloudWatch Logs Insights

    Raises:
        ValueError: If agent_id, endpoint_name, or any trace ID contains
            characters outside ``[A-Za-z0-9_\\-:.]`` or exceeds 64
            characters. These values are interpolated into the log-group
            path and the Insights query, so this blocks query injection and
            prevents a hostile endpoint_name from pivoting the log-group
            path to another group.
    """
    if not trace_ids:
        return []

    import re

    # Validate everything that gets interpolated into the log-group name
    # or the query string below.
    safe = re.compile(r"[A-Za-z0-9_\-:.]{1,64}")
    to_check = [("agent_id", agent_id), ("endpoint_name", endpoint_name)]
    to_check += [("trace_id", tid) for tid in trace_ids]
    for label, value in to_check:
        if not isinstance(value, str) or not safe.fullmatch(value):
            raise ValueError(f"Invalid {label}: {value!r}")

    log_group = f"/aws/bedrock-agentcore/runtimes/{agent_id}-{endpoint_name}"
    trace_ids_quoted = ", ".join(f"'{tid}'" for tid in trace_ids)
    query_string = f"""fields @timestamp, @message, spanId, traceId, @logStream
        | filter traceId in [{trace_ids_quoted}]
        | sort @timestamp asc"""

    try:
        return self._execute_query(query_string, log_group, start_time_ms, end_time_ms)
    except Exception as e:
        # A batch query can fail (e.g. query-string size limits); fall back
        # to one query per trace. The fallback's results may be partial.
        logger.warning("Batch query failed, falling back to individual queries: %s", e)
        return self._query_runtime_logs_individually(trace_ids, log_group, start_time_ms, end_time_ms)

def get_latest_session_id(
    self,
    start_time_ms: int,
    end_time_ms: int,
    agent_id: str,
) -> Optional[str]:
    """Get the most recent session ID for an agent.

    Args:
        start_time_ms: Start time in milliseconds since epoch
        end_time_ms: End time in milliseconds since epoch
        agent_id: Agent ID to query for

    Returns:
        Latest session ID or None if no sessions found

    Raises:
        ValueError: If agent_id contains characters outside
            ``[A-Za-z0-9_\\-:.]`` or exceeds 64 characters. The value is
            interpolated into the Insights query string, so this blocks
            query injection.
    """
    import re

    # agent_id is interpolated into the filter below; restrict it to a
    # strict identifier alphabet before building the query.
    if not re.fullmatch(r"[A-Za-z0-9_\-:.]{1,64}", agent_id or ""):
        raise ValueError(f"Invalid agent_id: {agent_id!r}")

    query_string = f"""filter resource.attributes.aws.service.type = "gen_ai_agent"
        | parse resource.attributes.cloud.resource_id "runtime/*/" as parsedAgentId
        | filter parsedAgentId = '{agent_id}'
        | stats max(endTimeUnixNano) as maxEnd by attributes.session.id
        | sort maxEnd desc
        | limit 1"""

    results = self._execute_query(query_string, "aws/spans", start_time_ms, end_time_ms)
    if not results or not results[0]:
        return None

    # GetQueryResults rows are lists of {"field": ..., "value": ...} dicts;
    # pick out the session-id column from the single returned row.
    for field in results[0]:
        if field.get("field") == "attributes.session.id":
            return field.get("value")
    return None

def _execute_query(
    self,
    query_string: str,
    log_group_name: str,
    start_time_ms: int,
    end_time_ms: int,
) -> List[dict]:
    """Execute a CloudWatch Logs Insights query and wait for results.

    Args:
        query_string: The query string
        log_group_name: Log group to query
        start_time_ms: Start time in milliseconds since epoch
        end_time_ms: End time in milliseconds since epoch

    Returns:
        List of result row dictionaries. CloudWatch caps results at 10,000
        rows; a warning is logged when more records matched than that.

    Raises:
        TimeoutError: If the query does not finish within 60 seconds.
        RuntimeError: If the query ends in Failed or Cancelled status.
    """
    response = self.logs_client.start_query(
        logGroupName=log_group_name,
        startTime=start_time_ms // 1000,
        endTime=end_time_ms // 1000,
        queryString=query_string,
    )
    query_id = response["queryId"]

    timeout = 60
    poll_interval = 2
    start = time.time()
    status = None
    try:
        while True:
            if time.time() - start > timeout:
                raise TimeoutError(f"Query {query_id} timed out after {timeout} seconds")
            result = self.logs_client.get_query_results(queryId=query_id)
            status = result["status"]
            if status == "Complete":
                # GetQueryResults returns at most 10,000 rows, but
                # statistics.recordsMatched reflects the true total — use it
                # to surface silent truncation instead of returning a
                # partial dataset with no signal.
                records_matched = result.get("statistics", {}).get("recordsMatched", 0)
                if records_matched > 10000:
                    logger.warning(
                        "Query matched %s records but CloudWatch returns at most 10000; results are truncated",
                        records_matched,
                    )
                return result.get("results", [])
            if status in ("Failed", "Cancelled"):
                raise RuntimeError(f"Query {query_id} failed with status: {status}")
            time.sleep(poll_interval)
    finally:
        # Abandoned queries keep running server-side (holding one of the
        # regional concurrent-query slots) unless StopQuery is called.
        # Cancel on timeout or unexpected exception; terminal statuses
        # need no cleanup.
        if status not in ("Complete", "Failed", "Cancelled"):
            try:
                self.logs_client.stop_query(queryId=query_id)
            except Exception:
                logger.warning("Failed to stop query %s", query_id, exc_info=True)

def _query_runtime_logs_individually(
    self,
    trace_ids: List[str],
    log_group: str,
    start_time_ms: int,
    end_time_ms: int,
) -> List[dict]:
    """Fallback: query runtime logs one trace at a time.

    Used when the batch ``traceId in [...]`` query fails. Per-trace failures
    are tolerated so one bad trace does not sink the rest, but they are not
    silent: each failure is logged individually, and a summary warning lists
    every trace missing from the returned (possibly partial) results.

    Args:
        trace_ids: Trace IDs to query individually
        log_group: Log group name to query
        start_time_ms: Start time in milliseconds since epoch
        end_time_ms: End time in milliseconds since epoch

    Returns:
        Concatenated result rows for all traces that succeeded. May be
        partial; the summary warning in the logs names the failed traces.
    """
    results: List[dict] = []
    failed_trace_ids: List[str] = []
    for trace_id in trace_ids:
        query = f"""fields @timestamp, @message, spanId, traceId, @logStream
        | filter traceId = '{trace_id}'
        | sort @timestamp asc"""
        try:
            results.extend(self._execute_query(query, log_group, start_time_ms, end_time_ms))
        except Exception as e:
            failed_trace_ids.append(trace_id)
            logger.warning("Failed to query runtime logs for trace %s: %s", trace_id, e)
    if failed_trace_ids:
        # One aggregate line makes partial results detectable at a glance.
        logger.warning(
            "Runtime log results are partial: %d of %d traces failed: %s",
            len(failed_trace_ids),
            len(trace_ids),
            failed_trace_ids,
        )
    return results
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[HIGH] Fallback silently returns partial results — no indication of which traces failed

Each per-trace except Exception is swallowed and the loop continues. The caller receives a plain List[dict] with no metadata about completeness — indistinguishable from a full successful result. Evaluation pipelines will compute error rates, latency, and span counts over an unknown partial dataset and report them as authoritative.

Fix: return a structured result such as {"results": [...], "failed_trace_ids": [...], "complete": bool}, or raise after the batch fallback already fired to ensure the caller is aware.



def fetch_spans_from_cloudwatch(
session_id: str,
Expand Down
5 changes: 5 additions & 0 deletions src/bedrock_agentcore/observability/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Observability delivery and configuration for AgentCore resources."""

from .client import ObservabilityClient

__all__ = ["ObservabilityClient"]
Loading
Loading