Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions sdks/python/src/agent_control/control_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ async def chat(message: str) -> str:
from typing import Any, TypeVar

from agent_control_models import Step, normalize_action
from agent_control_telemetry import get_trace_context_from_provider

from agent_control import AgentControlClient
from agent_control.evaluation import check_evaluation_with_local
Expand All @@ -53,6 +54,25 @@ async def chat(message: str) -> str:
F = TypeVar("F", bound=Callable[..., Any])


def _resolve_control_trace_context() -> tuple[str, str]:
"""Resolve trace/span IDs for a decorated control site.

External providers, such as the Galileo bridge, are authoritative because
they may reserve the concrete span ID that the eventual LLM/tool call will
use. Without a provider, keep the existing behavior: share an active trace
but create a fresh function span for this decorated call.
"""
provider_context = get_trace_context_from_provider()
if provider_context is not None:
return provider_context["trace_id"], provider_context["span_id"]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With a provider registered, this returns the provider's span_id verbatim, so every @control() call inside one provider-owned span now reports under the same (trace_id, span_id) pair — both the pre and post _evaluate POSTs, and any stacked decorators in the same trace. The docstring above frames this as intentional. Flagging as FYI: this is a semantic change from the old path, where get_current_trace_id() already returned the provider's trace_id and _generate_span_id() was minted fresh per decorated call (see tracing.py:147-150). Consumers that group control events by (trace_id, span_id) will now see them collapse onto one span; if the bridge keys events by control_execution_id this is fine. Worth a brief inline code comment so future readers don't expect a fresh span per decorator like the fallback branches still produce.


existing_trace_id = get_current_trace_id()
if existing_trace_id:
return existing_trace_id, _generate_span_id()

return get_trace_and_span_ids()


@dataclass
class ControlContext:
"""
Expand Down Expand Up @@ -697,14 +717,7 @@ async def _execute_with_control(
# Get cached controls for local evaluation support
controls = _get_server_controls()

# Get trace context: inherit trace_id if set, always generate new span_id
# This allows multiple @control() calls to share the same trace but have unique spans
existing_trace_id = get_current_trace_id()
if existing_trace_id:
trace_id = existing_trace_id
span_id = _generate_span_id() # New span for this function
else:
trace_id, span_id = get_trace_and_span_ids() # New trace and span
trace_id, span_id = _resolve_control_trace_context()

ctx = ControlContext(
agent_name=agent.agent_name,
Expand Down
52 changes: 50 additions & 2 deletions sdks/python/tests/test_control_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from unittest.mock import MagicMock, patch

import pytest
from agent_control_telemetry import clear_trace_context_provider, set_trace_context_provider

from agent_control.control_decorators import ControlViolationError, ControlSteerError, control

from agent_control.control_decorators import ControlSteerError, ControlViolationError, control

# =============================================================================
# FIXTURES
Expand Down Expand Up @@ -255,6 +255,54 @@ async def chat(message: str) -> str:
class TestPrePostExecution:
"""Tests for pre and post execution checks."""

@pytest.mark.asyncio
async def test_uses_external_provider_trace_context(self, mock_agent, mock_safe_response):
"""Test that an external provider supplies both trace and span IDs."""
# Given: an external telemetry provider that owns the active trace/span IDs
provided_trace_id = "6c4e3f7e-4a9a-4e7e-8c1f-3a9a3a9a3a9d"
provided_span_id = "8d30272e-23f7-4a4c-80d8-2decb2f3f9f8"
captured_contexts = []

async def mock_evaluate(
agent_name,
step,
stage,
server_url,
trace_id=None,
span_id=None,
controls=None,
event_agent_name=None,
):
captured_contexts.append((trace_id, span_id))
return mock_safe_response

set_trace_context_provider(
lambda: {"trace_id": provided_trace_id, "span_id": provided_span_id}
)
try:
with (
patch(
"agent_control.control_decorators._get_current_agent",
return_value=mock_agent,
),
patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate),
):

@control()
async def chat(message: str) -> str:
return f"Response to: {message}"

# When: a protected function runs pre and post checks
await chat("Hello!")
finally:
clear_trace_context_provider()

# Then: Agent Control preserves the provider's concrete target span ID
assert captured_contexts == [
(provided_trace_id, provided_span_id),
(provided_trace_id, provided_span_id),
]

@pytest.mark.asyncio
async def test_calls_pre_and_post(self, mock_agent, mock_safe_response):
"""Test that both pre and post checks are called."""
Expand Down
Loading