diff --git a/src/bedrock_agentcore/evaluation/utils/cloudwatch_span_helper.py b/src/bedrock_agentcore/evaluation/utils/cloudwatch_span_helper.py index c69421c2..ee0a0390 100644 --- a/src/bedrock_agentcore/evaluation/utils/cloudwatch_span_helper.py +++ b/src/bedrock_agentcore/evaluation/utils/cloudwatch_span_helper.py @@ -44,6 +44,7 @@ def query_log_group( start_time: datetime, end_time: datetime, query_string: Optional[str] = None, + agent_id: Optional[str] = None, ) -> List[dict]: """Query a single CloudWatch log group for session data. @@ -54,16 +55,23 @@ def query_log_group( end_time: Query end time query_string: Optional custom query string. When provided, used instead of the default substring match query. + agent_id: Optional agent ID to filter by (prevents cross-agent session collisions) Returns: List of parsed JSON log messages """ if query_string is None: + agent_filter = "" + if agent_id is not None: + agent_filter = ( + f'\n | parse resource.attributes.cloud.resource_id "runtime/*/" as parsedAgentId' + f"\n | filter parsedAgentId = '{agent_id}'" + ) query_string = f"""fields @timestamp, @message | filter @message like "{session_id}" | filter ispresent(scope.name) | filter ispresent(traceId) - | filter ispresent(spanId) + | filter ispresent(spanId){agent_filter} | sort @timestamp asc""" max_attempts = 30 @@ -149,6 +157,7 @@ def fetch_spans( event_log_group: str, start_time: datetime, end_time: Optional[datetime] = None, + agent_id: Optional[str] = None, ) -> List[dict]: """Fetch ADOT spans from CloudWatch with configurable event log group. 
@@ -162,6 +171,7 @@ def fetch_spans( - For custom agents: Any log group you configured (e.g., "/my-app/agent-events") start_time: Start time for log query end_time: End time for log query + agent_id: Optional agent ID to filter by (prevents cross-agent session collisions) Returns: List of ADOT span and log record dictionaries @@ -190,14 +200,167 @@ def fetch_spans( end_time = datetime.now() # Query both log groups - aws_spans = self.query_log_group("aws/spans", session_id, start_time, end_time) - event_logs = self.query_log_group(event_log_group, session_id, start_time, end_time) + aws_spans = self.query_log_group("aws/spans", session_id, start_time, end_time, agent_id=agent_id) + event_logs = self.query_log_group(event_log_group, session_id, start_time, end_time, agent_id=agent_id) all_data = aws_spans + event_logs logger.info("Fetched %d span items from CloudWatch", len(all_data)) return all_data + def query_spans_by_trace( + self, + trace_id: str, + start_time_ms: int, + end_time_ms: int, + ) -> List[dict]: + """Query all spans for a trace from aws/spans log group. 
+ + Note: Trace IDs are globally unique, so no agent_id filter needed to prevent cross-agent access + + Args: + trace_id: The trace ID to query + start_time_ms: Start time in milliseconds since epoch + end_time_ms: End time in milliseconds since epoch + + Returns: + List of result dictionaries from CloudWatch Logs Insights + """ + query_string = f"""fields @timestamp, @message, traceId, spanId, name as spanName, + kind, status.code as statusCode, status.message as statusMessage, + durationNano/1000000 as durationMs, attributes.session.id as sessionId, + startTimeUnixNano, endTimeUnixNano, parentSpanId, events, + resource.attributes.service.name as serviceName + | filter traceId = '{trace_id}' + | sort startTimeUnixNano asc""" + return self._execute_query(query_string, "aws/spans", start_time_ms, end_time_ms) + + def query_runtime_logs_by_traces( + self, + trace_ids: List[str], + start_time_ms: int, + end_time_ms: int, + agent_id: str, + endpoint_name: str = "DEFAULT", + ) -> List[dict]: + """Query runtime logs for multiple traces from agent-specific log group. 
+ + Args: + trace_ids: List of trace IDs to query + start_time_ms: Start time in milliseconds since epoch + end_time_ms: End time in milliseconds since epoch + agent_id: Agent ID for constructing the log group name + endpoint_name: Runtime endpoint name (default: DEFAULT) + + Returns: + List of result dictionaries from CloudWatch Logs Insights + """ + if not trace_ids: + return [] + + log_group = f"/aws/bedrock-agentcore/runtimes/{agent_id}-{endpoint_name}" + trace_ids_quoted = ", ".join([f"'{tid}'" for tid in trace_ids]) + query_string = f"""fields @timestamp, @message, spanId, traceId, @logStream + | filter traceId in [{trace_ids_quoted}] + | sort @timestamp asc""" + + try: + return self._execute_query(query_string, log_group, start_time_ms, end_time_ms) + except Exception as e: + logger.warning("Batch query failed, falling back to individual queries: %s", e) + return self._query_runtime_logs_individually(trace_ids, log_group, start_time_ms, end_time_ms) + + def get_latest_session_id( + self, + start_time_ms: int, + end_time_ms: int, + agent_id: str, + ) -> Optional[str]: + """Get the most recent session ID for an agent. 
+ + Args: + start_time_ms: Start time in milliseconds since epoch + end_time_ms: End time in milliseconds since epoch + agent_id: Agent ID to query for + + Returns: + Latest session ID or None if no sessions found + """ + query_string = f"""filter resource.attributes.aws.service.type = "gen_ai_agent" + | parse resource.attributes.cloud.resource_id "runtime/*/" as parsedAgentId + | filter parsedAgentId = '{agent_id}' + | stats max(endTimeUnixNano) as maxEnd by attributes.session.id + | sort maxEnd desc + | limit 1""" + + results = self._execute_query(query_string, "aws/spans", start_time_ms, end_time_ms) + if not results or not results[0]: + return None + + for field in results[0]: + if field.get("field") == "attributes.session.id": + return field.get("value") + return None + + def _execute_query( + self, + query_string: str, + log_group_name: str, + start_time_ms: int, + end_time_ms: int, + ) -> List[dict]: + """Execute a CloudWatch Logs Insights query and wait for results. + + Args: + query_string: The query string + log_group_name: Log group to query + start_time_ms: Start time in milliseconds since epoch + end_time_ms: End time in milliseconds since epoch + + Returns: + List of result row dictionaries + """ + response = self.logs_client.start_query( + logGroupName=log_group_name, + startTime=start_time_ms // 1000, + endTime=end_time_ms // 1000, + queryString=query_string, + ) + query_id = response["queryId"] + + timeout = 60 + poll_interval = 2 + start = time.time() + while True: + if time.time() - start > timeout: + raise TimeoutError(f"Query {query_id} timed out after {timeout} seconds") + result = self.logs_client.get_query_results(queryId=query_id) + status = result["status"] + if status == "Complete": + return result.get("results", []) + elif status in ("Failed", "Cancelled"): + raise RuntimeError(f"Query {query_id} failed with status: {status}") + time.sleep(poll_interval) + + def _query_runtime_logs_individually( + self, + trace_ids: List[str], + 
log_group: str, + start_time_ms: int, + end_time_ms: int, + ) -> List[dict]: + """Fallback to query runtime logs one trace at a time.""" + results = [] + for trace_id in trace_ids: + query = f"""fields @timestamp, @message, spanId, traceId, @logStream + | filter traceId = '{trace_id}' + | sort @timestamp asc""" + try: + results.extend(self._execute_query(query, log_group, start_time_ms, end_time_ms)) + except Exception as e: + logger.warning("Failed to query runtime logs for trace %s: %s", trace_id, e) + return results + def fetch_spans_from_cloudwatch( session_id: str, diff --git a/src/bedrock_agentcore/observability/__init__.py b/src/bedrock_agentcore/observability/__init__.py new file mode 100644 index 00000000..8f6cfa25 --- /dev/null +++ b/src/bedrock_agentcore/observability/__init__.py @@ -0,0 +1,5 @@ +"""Observability delivery and configuration for AgentCore resources.""" + +from .client import ObservabilityClient + +__all__ = ["ObservabilityClient"] diff --git a/src/bedrock_agentcore/observability/client.py b/src/bedrock_agentcore/observability/client.py new file mode 100644 index 00000000..f2803409 --- /dev/null +++ b/src/bedrock_agentcore/observability/client.py @@ -0,0 +1,345 @@ +"""Client for managing CloudWatch observability delivery and X-Ray Transaction Search for AgentCore resources.""" + +import json +import logging +from typing import Any, Dict, Optional + +import boto3 +from botocore.exceptions import ClientError + +logger = logging.getLogger(__name__) + +SUPPORTED_RESOURCE_TYPES = {"memory", "gateway", "runtime"} +AUTO_LOG_RESOURCE_TYPES = {"runtime"} + + +class ObservabilityClient: + """Manages CloudWatch delivery configuration and X-Ray Transaction Search for AgentCore resources.""" + + def __init__( + self, + region_name: Optional[str] = None, + session: Optional[boto3.Session] = None, + ): + """Initialize the ObservabilityClient.""" + self._session = session or boto3.Session() + self.region = region_name or self._session.region_name + if not 
self.region: + raise ValueError( + "AWS region must be specified either via region_name parameter " + "or configured in boto3 session/environment" + ) + self._logs_client = self._session.client("logs", region_name=self.region) + self._xray_client = self._session.client("xray", region_name=self.region) + sts_client = self._session.client("sts", region_name=self.region) + self._account_id = sts_client.get_caller_identity()["Account"] + + def enable_observability_for_resource( + self, + resource_arn: str, + resource_id: Optional[str] = None, + resource_type: Optional[str] = None, + enable_logs: bool = True, + enable_traces: bool = True, + custom_log_group: Optional[str] = None, + ) -> Dict[str, Any]: + """Enable CloudWatch logs and/or traces delivery for an AgentCore resource.""" + if resource_type is None or resource_id is None: + try: + resource_part = resource_arn.split(":")[-1] + parsed_type, parsed_id = resource_part.split("/", 1) + resource_type = resource_type or parsed_type + resource_id = resource_id or parsed_id + except (IndexError, ValueError) as e: + raise ValueError( + f"Could not parse resource_type/resource_id from ARN: {resource_arn}. " + f"Please provide them explicitly. Error: {e}" + ) from e + + if resource_type not in SUPPORTED_RESOURCE_TYPES: + raise ValueError( + f"Unsupported resource_type: '{resource_type}'. 
Must be one of: {SUPPORTED_RESOURCE_TYPES}" + ) + + results: Dict[str, Any] = { + "resource_id": resource_id, + "resource_type": resource_type, + "resource_arn": resource_arn, + "logs_enabled": False, + "traces_enabled": False, + "log_group": None, + "deliveries": {}, + } + + if custom_log_group: + log_group_name = custom_log_group + elif resource_type == "runtime": + log_group_name = f"/aws/bedrock-agentcore/runtimes/{resource_id}" + else: + log_group_name = f"/aws/vendedlogs/bedrock-agentcore/{resource_type}/APPLICATION_LOGS/{resource_id}" + + log_group_arn = f"arn:aws:logs:{self.region}:{self._account_id}:log-group:{log_group_name}" + results["log_group"] = log_group_name + + try: + if resource_type not in AUTO_LOG_RESOURCE_TYPES: + self._create_log_group_if_not_exists(log_group_name) + + if enable_logs and resource_type not in AUTO_LOG_RESOURCE_TYPES: + results["deliveries"]["logs"] = self._setup_logs_delivery(resource_arn, resource_id, log_group_arn) + results["logs_enabled"] = True + elif resource_type in AUTO_LOG_RESOURCE_TYPES: + results["logs_enabled"] = True + results["deliveries"]["logs"] = {"status": "auto-created by AWS"} + + if enable_traces: + results["deliveries"]["traces"] = self._setup_traces_delivery(resource_arn, resource_id) + results["traces_enabled"] = True + + results["status"] = "success" + except Exception as e: + logger.error("Failed to enable observability for %s/%s: %s", resource_type, resource_id, e) + results["status"] = "error" + results["error"] = str(e) + + return results + + def disable_observability_for_resource( + self, + resource_id: str, + delete_log_group: bool = False, + ) -> Dict[str, Any]: + """Disable CloudWatch observability delivery for a resource.""" + results: Dict[str, Any] = {"resource_id": resource_id, "deleted": [], "errors": []} + + for suffix in ["logs", "traces"]: + source_name = f"{resource_id}-{suffix}-source" + dest_name = f"{resource_id}-{suffix}-destination" + + # Delete deliveries referencing this source 
first + try: + deliveries = self._logs_client.describe_deliveries() + for delivery in deliveries.get("deliveries", []): + if delivery.get("deliverySourceName") == source_name: + try: + self._logs_client.delete_delivery(id=delivery["id"]) + results["deleted"].append(f"delivery:{delivery['id']}") + except ClientError as e: + if e.response["Error"]["Code"] != "ResourceNotFoundException": + results["errors"].append(f"Failed to delete delivery {delivery['id']}: {e}") + except ClientError: + pass + + try: + self._logs_client.delete_delivery_source(name=source_name) + results["deleted"].append(f"source:{source_name}") + except ClientError as e: + if e.response["Error"]["Code"] != "ResourceNotFoundException": + results["errors"].append(f"Failed to delete {source_name}: {e}") + + try: + self._logs_client.delete_delivery_destination(name=dest_name) + results["deleted"].append(f"destination:{dest_name}") + except ClientError as e: + if e.response["Error"]["Code"] != "ResourceNotFoundException": + results["errors"].append(f"Failed to delete {dest_name}: {e}") + + if delete_log_group: + for resource_type in SUPPORTED_RESOURCE_TYPES: + if resource_type == "runtime": + lg = f"/aws/bedrock-agentcore/runtimes/{resource_id}" + else: + lg = f"/aws/vendedlogs/bedrock-agentcore/{resource_type}/APPLICATION_LOGS/{resource_id}" + try: + self._logs_client.delete_log_group(logGroupName=lg) + results["deleted"].append(f"log_group:{lg}") + except ClientError as e: + if e.response["Error"]["Code"] != "ResourceNotFoundException": + results["errors"].append(f"Failed to delete log group {lg}: {e}") + + results["status"] = "success" if not results["errors"] else "partial" + return results + + def enable_transaction_search(self) -> bool: + """Enable X-Ray Transaction Search (resource policy, trace destination, indexing rule).""" + try: + if self._need_resource_policy(): + self._create_resource_policy() + + if self._need_trace_destination(): + try: + 
self._xray_client.update_trace_segment_destination(Destination="CloudWatchLogs") + except ClientError as e: + if e.response["Error"]["Code"] != "InvalidRequestException": + raise + + if self._need_indexing_rule(): + try: + self._xray_client.update_indexing_rule( + Name="Default", Rule={"Probabilistic": {"DesiredSamplingPercentage": 1}} + ) + except ClientError as e: + if e.response["Error"]["Code"] != "InvalidRequestException": + raise + + return True + except Exception as e: + logger.warning("Transaction Search configuration failed: %s", e) + return False + + def get_observability_status( + self, + resource_id: str, + ) -> Dict[str, Any]: + """Check the observability configuration status for a resource.""" + status: Dict[str, Any] = { + "resource_id": resource_id, + "logs": {"configured": False}, + "traces": {"configured": False}, + } + + for suffix in ["logs", "traces"]: + source_name = f"{resource_id}-{suffix}-source" + try: + self._logs_client.get_delivery_source(name=source_name) + status[suffix]["configured"] = True + status[suffix]["source_name"] = source_name + except ClientError: + pass + + return status + + # Private helpers + # ------------------------------------------------------------------------- + + def _create_log_group_if_not_exists(self, log_group_name: str) -> None: + try: + self._logs_client.create_log_group(logGroupName=log_group_name) + except ClientError as e: + if e.response["Error"]["Code"] != "ResourceAlreadyExistsException": + raise + + def _setup_logs_delivery(self, resource_arn: str, resource_id: str, log_group_arn: str) -> Dict[str, str]: + source_name = f"{resource_id}-logs-source" + dest_name = f"{resource_id}-logs-destination" + + try: + logs_source = self._logs_client.put_delivery_source( + name=source_name, logType="APPLICATION_LOGS", resourceArn=resource_arn + ) + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceAlreadyExistsException": + logs_source = {"deliverySource": {"name": source_name}} + else: + 
raise + + try: + logs_dest = self._logs_client.put_delivery_destination( + name=dest_name, + deliveryDestinationType="CWL", + deliveryDestinationConfiguration={"destinationResourceArn": log_group_arn}, + ) + dest_arn = logs_dest["deliveryDestination"]["arn"] + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceAlreadyExistsException": + dest_arn = f"arn:aws:logs:{self.region}:{self._account_id}:delivery-destination:{dest_name}" + else: + raise + + try: + delivery = self._logs_client.create_delivery( + deliverySourceName=logs_source["deliverySource"]["name"], deliveryDestinationArn=dest_arn + ) + delivery_id = delivery.get("id", "created") + except ClientError as e: + if e.response["Error"]["Code"] == "ConflictException": + delivery_id = "existing" + else: + raise + + return {"delivery_id": delivery_id, "source_name": source_name, "destination_name": dest_name} + + def _setup_traces_delivery(self, resource_arn: str, resource_id: str) -> Dict[str, str]: + source_name = f"{resource_id}-traces-source" + dest_name = f"{resource_id}-traces-destination" + + try: + traces_source = self._logs_client.put_delivery_source( + name=source_name, logType="TRACES", resourceArn=resource_arn + ) + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceAlreadyExistsException": + traces_source = {"deliverySource": {"name": source_name}} + else: + raise + + try: + traces_dest = self._logs_client.put_delivery_destination(name=dest_name, deliveryDestinationType="XRAY") + dest_arn = traces_dest["deliveryDestination"]["arn"] + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceAlreadyExistsException": + dest_arn = f"arn:aws:logs:{self.region}:{self._account_id}:delivery-destination:{dest_name}" + else: + raise + + try: + delivery = self._logs_client.create_delivery( + deliverySourceName=traces_source["deliverySource"]["name"], deliveryDestinationArn=dest_arn + ) + delivery_id = delivery.get("id", "created") + except ClientError as e: + if 
e.response["Error"]["Code"] == "ConflictException": + delivery_id = "existing" + else: + raise + + return {"delivery_id": delivery_id, "source_name": source_name, "destination_name": dest_name} + + def _need_resource_policy(self, policy_name: str = "TransactionSearchXRayAccess") -> bool: + try: + response = self._logs_client.describe_resource_policies() + return not any(p.get("policyName") == policy_name for p in response.get("resourcePolicies", [])) + except Exception: + return True + + def _create_resource_policy(self) -> None: + policy_document = { + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "TransactionSearchXRayAccess", + "Effect": "Allow", + "Principal": {"Service": "xray.amazonaws.com"}, + "Action": "logs:PutLogEvents", + "Resource": [ + f"arn:aws:logs:{self.region}:{self._account_id}:log-group:aws/spans:*", + f"arn:aws:logs:{self.region}:{self._account_id}:log-group:/aws/application-signals/data:*", + ], + "Condition": { + "ArnLike": {"aws:SourceArn": f"arn:aws:xray:{self.region}:{self._account_id}:*"}, + "StringEquals": {"aws:SourceAccount": self._account_id}, + }, + } + ], + } + try: + self._logs_client.put_resource_policy( + policyName="TransactionSearchXRayAccess", policyDocument=json.dumps(policy_document) + ) + except ClientError as e: + if e.response["Error"]["Code"] != "InvalidParameterException": + raise + + def _need_trace_destination(self) -> bool: + try: + response = self._xray_client.get_trace_segment_destination() + return response.get("Destination") != "CloudWatchLogs" + except Exception: + return True # If check fails, assume we need it + + def _need_indexing_rule(self) -> bool: + try: + response = self._xray_client.get_indexing_rules() + return not any(r.get("Name") == "Default" for r in response.get("IndexingRules", [])) + except Exception: + return True diff --git a/tests/bedrock_agentcore/evaluation/utils/test_cloudwatch_span_helper.py b/tests/bedrock_agentcore/evaluation/utils/test_cloudwatch_span_helper.py index 
d3096f3d..029d9001 100644 --- a/tests/bedrock_agentcore/evaluation/utils/test_cloudwatch_span_helper.py +++ b/tests/bedrock_agentcore/evaluation/utils/test_cloudwatch_span_helper.py @@ -3,6 +3,8 @@ from datetime import datetime, timezone from unittest.mock import Mock, patch +import pytest + from bedrock_agentcore.evaluation.utils.cloudwatch_span_helper import ( CloudWatchSpanHelper, _is_valid_adot_document, @@ -237,3 +239,250 @@ def test_fetch_spans_from_cloudwatch_combines_both_log_groups(self): assert len(spans) == 2 assert spans[0]["scope"]["name"] == "test" assert spans[1]["scope"]["name"] == "test2" + + +class TestAgentIdFiltering: + """Test agent_id filtering on existing methods.""" + + def test_query_log_group_with_agent_id(self): + """Test that agent_id adds parse/filter clauses to default query.""" + mock_client = Mock() + mock_client.start_query.return_value = {"queryId": "q-1"} + mock_client.get_query_results.return_value = {"status": "Complete", "results": []} + + helper = CloudWatchSpanHelper() + helper.logs_client = mock_client + + start_time = datetime(2024, 1, 1, tzinfo=timezone.utc) + end_time = datetime(2024, 1, 2, tzinfo=timezone.utc) + + helper.query_log_group("aws/spans", "sess-1", start_time, end_time, agent_id="my-agent") + + query = mock_client.start_query.call_args[1]["queryString"] + assert "parsedAgentId" in query + assert "my-agent" in query + + def test_query_log_group_without_agent_id(self): + """Test that omitting agent_id does not add agent filter.""" + mock_client = Mock() + mock_client.start_query.return_value = {"queryId": "q-1"} + mock_client.get_query_results.return_value = {"status": "Complete", "results": []} + + helper = CloudWatchSpanHelper() + helper.logs_client = mock_client + + start_time = datetime(2024, 1, 1, tzinfo=timezone.utc) + end_time = datetime(2024, 1, 2, tzinfo=timezone.utc) + + helper.query_log_group("aws/spans", "sess-1", start_time, end_time) + + query = mock_client.start_query.call_args[1]["queryString"] + 
assert "parsedAgentId" not in query + + def test_fetch_spans_passes_agent_id(self): + """Test that fetch_spans forwards agent_id to query_log_group.""" + mock_client = Mock() + mock_client.start_query.return_value = {"queryId": "q-1"} + mock_client.get_query_results.return_value = {"status": "Complete", "results": []} + + helper = CloudWatchSpanHelper() + helper.logs_client = mock_client + + start_time = datetime(2024, 1, 1, tzinfo=timezone.utc) + end_time = datetime(2024, 1, 2, tzinfo=timezone.utc) + + helper.fetch_spans("sess-1", "/aws/logs/agent", start_time, end_time, agent_id="my-agent") + + # Both calls (aws/spans + event log group) should include agent filter + for call in mock_client.start_query.call_args_list: + assert "my-agent" in call[1]["queryString"] + + +class TestQuerySpansByTrace: + """Test query_spans_by_trace method.""" + + def test_returns_results(self): + mock_client = Mock() + mock_client.start_query.return_value = {"queryId": "q-1"} + mock_client.get_query_results.return_value = { + "status": "Complete", + "results": [[{"field": "traceId", "value": "trace-abc"}]], + } + + helper = CloudWatchSpanHelper() + helper.logs_client = mock_client + + results = helper.query_spans_by_trace("trace-abc", 1000000, 2000000) + assert len(results) == 1 + mock_client.start_query.assert_called_once() + query = mock_client.start_query.call_args[1]["queryString"] + assert "trace-abc" in query + assert mock_client.start_query.call_args[1]["logGroupName"] == "aws/spans" + + def test_empty_results(self): + mock_client = Mock() + mock_client.start_query.return_value = {"queryId": "q-1"} + mock_client.get_query_results.return_value = {"status": "Complete", "results": []} + + helper = CloudWatchSpanHelper() + helper.logs_client = mock_client + + results = helper.query_spans_by_trace("trace-abc", 1000000, 2000000) + assert results == [] + + +class TestQueryRuntimeLogsByTraces: + """Test query_runtime_logs_by_traces method.""" + + def test_batch_query(self): + 
mock_client = Mock() + mock_client.start_query.return_value = {"queryId": "q-1"} + mock_client.get_query_results.return_value = { + "status": "Complete", + "results": [ + [{"field": "traceId", "value": "t1"}, {"field": "@message", "value": "log1"}], + [{"field": "traceId", "value": "t2"}, {"field": "@message", "value": "log2"}], + ], + } + + helper = CloudWatchSpanHelper() + helper.logs_client = mock_client + + results = helper.query_runtime_logs_by_traces(["t1", "t2"], 1000000, 2000000, "agent-1") + assert len(results) == 2 + query = mock_client.start_query.call_args[1]["queryString"] + assert "'t1'" in query + assert "'t2'" in query + assert mock_client.start_query.call_args[1]["logGroupName"] == "/aws/bedrock-agentcore/runtimes/agent-1-DEFAULT" + + def test_custom_endpoint_name(self): + mock_client = Mock() + mock_client.start_query.return_value = {"queryId": "q-1"} + mock_client.get_query_results.return_value = {"status": "Complete", "results": []} + + helper = CloudWatchSpanHelper() + helper.logs_client = mock_client + + helper.query_runtime_logs_by_traces(["t1"], 1000000, 2000000, "agent-1", endpoint_name="prod") + assert mock_client.start_query.call_args[1]["logGroupName"] == "/aws/bedrock-agentcore/runtimes/agent-1-prod" + + def test_empty_trace_ids(self): + helper = CloudWatchSpanHelper() + helper.logs_client = Mock() + + results = helper.query_runtime_logs_by_traces([], 1000000, 2000000, "agent-1") + assert results == [] + helper.logs_client.start_query.assert_not_called() + + def test_fallback_to_individual_queries(self): + mock_client = Mock() + # First call (batch) fails, individual calls succeed + mock_client.start_query.side_effect = [ + Exception("batch failed"), + {"queryId": "q-1"}, + {"queryId": "q-2"}, + ] + mock_client.get_query_results.return_value = { + "status": "Complete", + "results": [[{"field": "traceId", "value": "t1"}]], + } + + helper = CloudWatchSpanHelper() + helper.logs_client = mock_client + + results = 
helper.query_runtime_logs_by_traces(["t1", "t2"], 1000000, 2000000, "agent-1") + assert len(results) == 2 + assert mock_client.start_query.call_count == 3 + + +class TestGetLatestSessionId: + """Test get_latest_session_id method.""" + + def test_returns_session_id(self): + mock_client = Mock() + mock_client.start_query.return_value = {"queryId": "q-1"} + mock_client.get_query_results.return_value = { + "status": "Complete", + "results": [[{"field": "attributes.session.id", "value": "sess-latest"}]], + } + + helper = CloudWatchSpanHelper() + helper.logs_client = mock_client + + result = helper.get_latest_session_id(1000000, 2000000, "agent-1") + assert result == "sess-latest" + query = mock_client.start_query.call_args[1]["queryString"] + assert "agent-1" in query + assert "limit 1" in query + + def test_returns_none_when_no_results(self): + mock_client = Mock() + mock_client.start_query.return_value = {"queryId": "q-1"} + mock_client.get_query_results.return_value = {"status": "Complete", "results": []} + + helper = CloudWatchSpanHelper() + helper.logs_client = mock_client + + result = helper.get_latest_session_id(1000000, 2000000, "agent-1") + assert result is None + + def test_returns_none_when_field_missing(self): + mock_client = Mock() + mock_client.start_query.return_value = {"queryId": "q-1"} + mock_client.get_query_results.return_value = { + "status": "Complete", + "results": [[{"field": "other_field", "value": "something"}]], + } + + helper = CloudWatchSpanHelper() + helper.logs_client = mock_client + + result = helper.get_latest_session_id(1000000, 2000000, "agent-1") + assert result is None + + +class TestExecuteQuery: + """Test _execute_query helper.""" + + def test_successful_query(self): + mock_client = Mock() + mock_client.start_query.return_value = {"queryId": "q-1"} + mock_client.get_query_results.return_value = { + "status": "Complete", + "results": [[{"field": "f", "value": "v"}]], + } + + helper = CloudWatchSpanHelper() + helper.logs_client = 
mock_client + + results = helper._execute_query("fields @message", "lg", 1000000, 2000000) + assert len(results) == 1 + # Verify ms -> seconds conversion + assert mock_client.start_query.call_args[1]["startTime"] == 1000 + assert mock_client.start_query.call_args[1]["endTime"] == 2000 + + def test_failed_query_raises(self): + mock_client = Mock() + mock_client.start_query.return_value = {"queryId": "q-1"} + mock_client.get_query_results.return_value = {"status": "Failed"} + + helper = CloudWatchSpanHelper() + helper.logs_client = mock_client + + with pytest.raises(RuntimeError, match="failed with status"): + helper._execute_query("fields @message", "lg", 1000000, 2000000) + + @patch("bedrock_agentcore.evaluation.utils.cloudwatch_span_helper.time") + def test_timeout_raises(self, mock_time): + mock_client = Mock() + mock_client.start_query.return_value = {"queryId": "q-1"} + mock_client.get_query_results.return_value = {"status": "Running"} + # Simulate time passing beyond timeout + mock_time.time.side_effect = [0, 0, 61] + mock_time.sleep = Mock() + + helper = CloudWatchSpanHelper() + helper.logs_client = mock_client + + with pytest.raises(TimeoutError): + helper._execute_query("fields @message", "lg", 1000000, 2000000) diff --git a/tests/unit/observability/__init__.py b/tests/unit/observability/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/observability/test_observability_client.py b/tests/unit/observability/test_observability_client.py new file mode 100644 index 00000000..07d9f1ed --- /dev/null +++ b/tests/unit/observability/test_observability_client.py @@ -0,0 +1,314 @@ +"""Unit tests for ObservabilityClient.""" + +from unittest.mock import MagicMock, patch + +import pytest +from botocore.exceptions import ClientError + +from bedrock_agentcore.observability.client import ObservabilityClient + + +def _client_error(code, message="error"): + return ClientError({"Error": {"Code": code, "Message": message}}, "op") + + +def 
# Single source of truth for the memory-resource ARN used across tests.
_MEMORY_ARN = "arn:aws:bedrock-agentcore:us-east-1:123456789012:memory/mem-1"


def _make_client():
    """Construct an ObservabilityClient backed entirely by mocks.

    Returns:
        Tuple of (client, logs_mock, xray_mock) so tests can stub and
        inspect the CloudWatch Logs and X-Ray API surfaces.
    """
    with patch("boto3.Session") as mock_session_class:
        mock_session = MagicMock()
        mock_session.region_name = "us-east-1"
        sts = MagicMock()
        sts.get_caller_identity.return_value = {"Account": "123456789012"}
        logs = MagicMock()
        xray = MagicMock()

        def client_factory(service, **kwargs):
            # Route boto3 client requests to the matching per-service mock.
            return {"sts": sts, "logs": logs, "xray": xray}[service]

        mock_session.client.side_effect = client_factory
        mock_session_class.return_value = mock_session
        client = ObservabilityClient(region_name="us-east-1", session=mock_session)
        return client, logs, xray


def _stub_delivery_success(logs):
    """Stub the CloudWatch Logs delivery APIs for the happy path.

    Centralizes the put_delivery_source / put_delivery_destination /
    create_delivery success stubbing that the enable-observability tests
    previously duplicated inline.
    """
    logs.put_delivery_source.return_value = {"deliverySource": {"name": "src"}}
    logs.put_delivery_destination.return_value = {
        "deliveryDestination": {
            "name": "dst",
            "arn": "arn:aws:logs:us-east-1:123456789012:delivery-destination:dst",
        }
    }
    logs.create_delivery.return_value = {"id": "d-123"}


class TestInit:
    def test_init_with_region(self):
        client, _, _ = _make_client()
        assert client.region == "us-east-1"
        assert client._account_id == "123456789012"

    def test_init_without_region_raises(self):
        with patch("boto3.Session") as mock_cls:
            mock_session = MagicMock()
            mock_session.region_name = None
            mock_cls.return_value = mock_session
            with pytest.raises(ValueError, match="AWS region must be specified"):
                ObservabilityClient(session=mock_session)


class TestEnableObservability:
    def test_memory_success(self):
        client, logs, _ = _make_client()
        _stub_delivery_success(logs)

        result = client.enable_observability_for_resource(
            resource_arn=_MEMORY_ARN,
            resource_id="mem-1",
            resource_type="memory",
        )
        assert result["status"] == "success"
        assert result["logs_enabled"] is True
        assert result["traces_enabled"] is True
        assert result["log_group"] == "/aws/vendedlogs/bedrock-agentcore/memory/APPLICATION_LOGS/mem-1"

    def test_runtime_skips_log_creation(self):
        client, logs, _ = _make_client()
        _stub_delivery_success(logs)

        result = client.enable_observability_for_resource(
            resource_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/rt-1",
            resource_id="rt-1",
            resource_type="runtime",
        )
        assert result["status"] == "success"
        assert result["logs_enabled"] is True
        assert result["deliveries"]["logs"] == {"status": "auto-created by AWS"}
        logs.create_log_group.assert_not_called()

    def test_logs_only(self):
        client, logs, _ = _make_client()
        _stub_delivery_success(logs)

        result = client.enable_observability_for_resource(
            resource_arn=_MEMORY_ARN,
            resource_id="mem-1",
            resource_type="memory",
            enable_traces=False,
        )
        assert result["status"] == "success"
        assert result["logs_enabled"] is True
        assert result["traces_enabled"] is False

    def test_traces_only(self):
        client, logs, _ = _make_client()
        _stub_delivery_success(logs)

        result = client.enable_observability_for_resource(
            resource_arn=_MEMORY_ARN,
            resource_id="mem-1",
            resource_type="memory",
            enable_logs=False,
        )
        assert result["status"] == "success"
        assert result["logs_enabled"] is False
        assert result["traces_enabled"] is True

    def test_custom_log_group(self):
        client, logs, _ = _make_client()
        _stub_delivery_success(logs)

        result = client.enable_observability_for_resource(
            resource_arn=_MEMORY_ARN,
            resource_id="mem-1",
            resource_type="memory",
            custom_log_group="/my/custom/group",
        )
        assert result["log_group"] == "/my/custom/group"

    def test_invalid_resource_type(self):
        client, _, _ = _make_client()
        with pytest.raises(ValueError, match="Unsupported resource_type"):
            client.enable_observability_for_resource(
                resource_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:invalid/x",
                resource_id="x",
                resource_type="invalid",
            )

    def test_parses_arn_when_ids_not_provided(self):
        client, logs, _ = _make_client()
        _stub_delivery_success(logs)

        result = client.enable_observability_for_resource(
            resource_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:gateway/gw-99",
        )
        assert result["resource_type"] == "gateway"
        assert result["resource_id"] == "gw-99"

    def test_invalid_arn_raises(self):
        client, _, _ = _make_client()
        with pytest.raises(ValueError, match="Could not parse"):
            client.enable_observability_for_resource(resource_arn="bad-arn")

    def test_log_group_already_exists(self):
        client, logs, _ = _make_client()
        logs.create_log_group.side_effect = _client_error("ResourceAlreadyExistsException")
        _stub_delivery_success(logs)

        result = client.enable_observability_for_resource(
            resource_arn=_MEMORY_ARN,
            resource_id="mem-1",
            resource_type="memory",
        )
        assert result["status"] == "success"

    def test_delivery_already_exists(self):
        client, logs, _ = _make_client()
        _stub_delivery_success(logs)
        # side_effect takes precedence over return_value, so this override
        # makes create_delivery raise despite the helper's success stub.
        logs.create_delivery.side_effect = _client_error("ConflictException")

        result = client.enable_observability_for_resource(
            resource_arn=_MEMORY_ARN,
            resource_id="mem-1",
            resource_type="memory",
        )
        assert result["status"] == "success"
        assert result["deliveries"]["logs"]["delivery_id"] == "existing"

    def test_api_error_returns_error_status(self):
        client, logs, _ = _make_client()
        logs.create_log_group.side_effect = _client_error("AccessDeniedException")

        result = client.enable_observability_for_resource(
            resource_arn=_MEMORY_ARN,
            resource_id="mem-1",
            resource_type="memory",
        )
        assert result["status"] == "error"
        assert "AccessDeniedException" in result["error"]


class TestDisableObservability:
    def test_success(self):
        client, _, _ = _make_client()
        result = client.disable_observability_for_resource(resource_id="mem-1")
        assert result["status"] == "success"
        assert len(result["deleted"]) == 4

    def test_resource_not_found_is_ok(self):
        client, logs, _ = _make_client()
        logs.delete_delivery_source.side_effect = _client_error("ResourceNotFoundException")
        logs.delete_delivery_destination.side_effect = _client_error("ResourceNotFoundException")

        result = client.disable_observability_for_resource(resource_id="nonexistent")
        assert result["status"] == "success"
        assert len(result["deleted"]) == 0

    def test_with_log_group_deletion(self):
        client, logs, _ = _make_client()
        result = client.disable_observability_for_resource(resource_id="mem-1", delete_log_group=True)
        assert result["status"] == "success"
        assert logs.delete_log_group.called


class TestEnableTransactionSearch:
    def test_all_steps_needed(self):
        client, logs, xray = _make_client()
        logs.describe_resource_policies.return_value = {"resourcePolicies": []}
        xray.get_trace_segment_destination.return_value = {"Destination": "XRay"}
        xray.get_indexing_rules.return_value = {"IndexingRules": []}

        assert client.enable_transaction_search() is True
        logs.put_resource_policy.assert_called_once()
        xray.update_trace_segment_destination.assert_called_once_with(Destination="CloudWatchLogs")
        xray.update_indexing_rule.assert_called_once()

    def test_all_already_configured(self):
        client, logs, xray = _make_client()
        logs.describe_resource_policies.return_value = {
            "resourcePolicies": [{"policyName": "TransactionSearchXRayAccess"}]
        }
        xray.get_trace_segment_destination.return_value = {"Destination": "CloudWatchLogs"}
        xray.get_indexing_rules.return_value = {"IndexingRules": [{"Name": "Default"}]}

        assert client.enable_transaction_search() is True
        logs.put_resource_policy.assert_not_called()
        xray.update_trace_segment_destination.assert_not_called()
        xray.update_indexing_rule.assert_not_called()

    def test_idempotent_on_invalid_request(self):
        client, logs, xray = _make_client()
        logs.describe_resource_policies.return_value = {"resourcePolicies": []}
        xray.get_trace_segment_destination.return_value = {"Destination": "XRay"}
        xray.update_trace_segment_destination.side_effect = _client_error("InvalidRequestException")
        xray.get_indexing_rules.return_value = {"IndexingRules": [{"Name": "Default"}]}

        assert client.enable_transaction_search() is True

    def test_failure_returns_false(self):
        client, logs, _ = _make_client()
        logs.describe_resource_policies.return_value = {"resourcePolicies": []}
        logs.put_resource_policy.side_effect = _client_error("AccessDeniedException")

        assert client.enable_transaction_search() is False


class TestGetObservabilityStatus:
    def test_both_configured(self):
        client, logs, _ = _make_client()
        logs.get_delivery_source.return_value = {"deliverySource": {"name": "src"}}

        result = client.get_observability_status(resource_id="mem-1")
        assert result["logs"]["configured"] is True
        assert result["traces"]["configured"] is True

    def test_none_configured(self):
        client, logs, _ = _make_client()
        logs.get_delivery_source.side_effect = _client_error("ResourceNotFoundException")

        result = client.get_observability_status(resource_id="mem-1")
        assert result["logs"]["configured"] is False
        assert result["traces"]["configured"] is False
"""Integration tests for CloudWatchSpanHelper new methods.

Requires:
    SPAN_TEST_AGENT_ID: Agent runtime ID with pre-existing spans.
    SPAN_TEST_TRACE_ID: A known trace ID from that agent.
    SPAN_TEST_SESSION_ID: A known session ID from that agent.
"""

import os
import time

import pytest

from bedrock_agentcore.evaluation.utils.cloudwatch_span_helper import CloudWatchSpanHelper


def _trace_id_values(rows):
    """Yield the value of every traceId field found in CloudWatch query rows."""
    for row in rows:
        for field in row:
            if field.get("field") == "traceId":
                yield field.get("value")


@pytest.mark.integration
class TestCloudWatchSpanHelperInteg:
    """Integration tests for CloudWatchSpanHelper query methods."""

    @classmethod
    def setup_class(cls):
        env = os.environ.get
        cls.region = env("BEDROCK_TEST_REGION", "us-west-2")
        cls.agent_id = env("SPAN_TEST_AGENT_ID")
        cls.trace_id = env("SPAN_TEST_TRACE_ID")
        cls.session_id = env("SPAN_TEST_SESSION_ID")
        if not (cls.agent_id and cls.trace_id and cls.session_id):
            pytest.fail("SPAN_TEST_AGENT_ID, SPAN_TEST_TRACE_ID, and SPAN_TEST_SESSION_ID must be set")
        cls.helper = CloudWatchSpanHelper(region=cls.region)
        # Window spans from the epoch to "now" so pre-populated data is always captured.
        cls.start_time_ms = 0
        cls.end_time_ms = int(time.time() * 1000)

    @pytest.mark.order(1)
    def test_query_spans_by_trace(self):
        rows = self.helper.query_spans_by_trace(
            trace_id=self.trace_id,
            start_time_ms=self.start_time_ms,
            end_time_ms=self.end_time_ms,
        )
        assert rows
        # Every traceId field in the results must match the queried trace.
        for value in _trace_id_values(rows):
            assert value == self.trace_id

    @pytest.mark.order(2)
    def test_query_runtime_logs_by_traces(self):
        rows = self.helper.query_runtime_logs_by_traces(
            trace_ids=[self.trace_id],
            start_time_ms=self.start_time_ms,
            end_time_ms=self.end_time_ms,
            agent_id=self.agent_id,
        )
        assert rows
        # Every traceId field in the results must match the queried trace.
        for value in _trace_id_values(rows):
            assert value == self.trace_id

    @pytest.mark.order(3)
    def test_get_latest_session_id(self):
        latest = self.helper.get_latest_session_id(
            start_time_ms=self.start_time_ms,
            end_time_ms=self.end_time_ms,
            agent_id=self.agent_id,
        )
        assert latest is not None

    @pytest.mark.order(4)
    def test_query_spans_by_trace_nonexistent(self):
        rows = self.helper.query_spans_by_trace(
            trace_id="00000000000000000000000000000000",
            start_time_ms=self.start_time_ms,
            end_time_ms=self.end_time_ms,
        )
        assert rows == []

    @pytest.mark.order(5)
    def test_query_runtime_logs_empty_traces(self):
        rows = self.helper.query_runtime_logs_by_traces(
            trace_ids=[],
            start_time_ms=self.start_time_ms,
            end_time_ms=self.end_time_ms,
            agent_id=self.agent_id,
        )
        assert rows == []
"""Integration tests for ObservabilityClient.

Requires:
    OBSERVABILITY_TEST_MEMORY_ID: ID of a persistent memory resource in the test account.
    OBSERVABILITY_TEST_MEMORY_ARN: Full ARN of that memory resource.
"""

import os

import pytest

from bedrock_agentcore.observability.client import ObservabilityClient


def _best_effort_disable(client, memory_id):
    """Tear down observability for a resource, reporting (not raising) failures."""
    try:
        client.disable_observability_for_resource(
            resource_id=memory_id,
            delete_log_group=True,
        )
    except Exception as e:
        print(f"Teardown: {e}")


@pytest.mark.integration
class TestObservabilityDelivery:
    """Tests enable/disable/status lifecycle using a real memory resource."""

    @classmethod
    def setup_class(cls):
        cls.region = os.environ.get("BEDROCK_TEST_REGION", "us-west-2")
        cls.memory_id = os.environ.get("OBSERVABILITY_TEST_MEMORY_ID")
        cls.memory_arn = os.environ.get("OBSERVABILITY_TEST_MEMORY_ARN")
        if not (cls.memory_id and cls.memory_arn):
            pytest.fail("OBSERVABILITY_TEST_MEMORY_ID and OBSERVABILITY_TEST_MEMORY_ARN must be set")
        cls.client = ObservabilityClient(region_name=cls.region)

    @classmethod
    def teardown_class(cls):
        _best_effort_disable(cls.client, cls.memory_id)

    @pytest.mark.order(1)
    def test_enable_observability_logs_and_traces(self):
        outcome = self.client.enable_observability_for_resource(
            resource_arn=self.memory_arn,
            resource_id=self.memory_id,
            resource_type="memory",
            enable_logs=True,
            enable_traces=True,
        )
        assert outcome["status"] == "success"
        assert outcome["logs_enabled"] is True
        assert outcome["traces_enabled"] is True
        assert "APPLICATION_LOGS" in outcome["log_group"]

    @pytest.mark.order(2)
    def test_get_status_shows_configured(self):
        status = self.client.get_observability_status(resource_id=self.memory_id)
        assert status["logs"]["configured"] is True
        assert status["traces"]["configured"] is True

    @pytest.mark.order(3)
    def test_enable_is_idempotent(self):
        outcome = self.client.enable_observability_for_resource(
            resource_arn=self.memory_arn,
            resource_id=self.memory_id,
            resource_type="memory",
        )
        assert outcome["status"] == "success"

    @pytest.mark.order(4)
    def test_disable_observability(self):
        outcome = self.client.disable_observability_for_resource(
            resource_id=self.memory_id,
            delete_log_group=True,
        )
        assert outcome["status"] == "success"
        assert len(outcome["deleted"]) > 0

    @pytest.mark.order(5)
    def test_get_status_shows_not_configured_after_disable(self):
        status = self.client.get_observability_status(resource_id=self.memory_id)
        assert status["logs"]["configured"] is False
        assert status["traces"]["configured"] is False

    @pytest.mark.order(6)
    def test_disable_is_idempotent(self):
        outcome = self.client.disable_observability_for_resource(
            resource_id=self.memory_id,
        )
        assert outcome["status"] == "success"


@pytest.mark.integration
class TestObservabilityArnParsing:
    """Tests that ARN parsing works with real API calls."""

    @classmethod
    def setup_class(cls):
        cls.region = os.environ.get("BEDROCK_TEST_REGION", "us-west-2")
        cls.memory_id = os.environ.get("OBSERVABILITY_TEST_MEMORY_ID")
        cls.memory_arn = os.environ.get("OBSERVABILITY_TEST_MEMORY_ARN")
        if not (cls.memory_id and cls.memory_arn):
            pytest.fail("OBSERVABILITY_TEST_MEMORY_ID and OBSERVABILITY_TEST_MEMORY_ARN must be set")
        cls.client = ObservabilityClient(region_name=cls.region)

    @classmethod
    def teardown_class(cls):
        _best_effort_disable(cls.client, cls.memory_id)

    @pytest.mark.order(7)
    def test_enable_with_arn_only(self):
        """Test that resource_type and resource_id are inferred from ARN."""
        outcome = self.client.enable_observability_for_resource(
            resource_arn=self.memory_arn,
        )
        assert outcome["status"] == "success"
        assert outcome["resource_type"] == "memory"
        assert outcome["resource_id"] == self.memory_id


@pytest.mark.integration
class TestTransactionSearch:
    """Tests enable_transaction_search with real X-Ray and CloudWatch APIs."""

    @classmethod
    def setup_class(cls):
        cls.region = os.environ.get("BEDROCK_TEST_REGION", "us-west-2")
        cls.memory_id = os.environ.get("OBSERVABILITY_TEST_MEMORY_ID")
        if not cls.memory_id:
            pytest.fail("OBSERVABILITY_TEST_MEMORY_ID must be set")
        cls.client = ObservabilityClient(region_name=cls.region)

    @pytest.mark.order(8)
    def test_enable_transaction_search(self):
        assert self.client.enable_transaction_search() is True

    @pytest.mark.order(9)
    def test_enable_transaction_search_is_idempotent(self):
        assert self.client.enable_transaction_search() is True