From 9f1791a91899b8e016d8be18cde8923a770c1a56 Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Wed, 6 May 2026 12:47:34 -0700 Subject: [PATCH 01/42] add new lluna client --- evaluators/contrib/galileo/pyproject.toml | 1 + .../__init__.py | 17 + .../luna/__init__.py | 19 ++ .../luna/client.py | 256 +++++++++++++++ .../luna/config.py | 94 ++++++ .../luna/evaluator.py | 259 ++++++++++++++++ .../agent_control_evaluator_galileo/py.typed | 1 + .../galileo/tests/test_luna_evaluator.py | 291 ++++++++++++++++++ examples/README.md | 1 + examples/galileo_luna/README.md | 46 +++ examples/galileo_luna/demo_agent.py | 129 ++++++++ examples/galileo_luna/pyproject.toml | 25 ++ examples/galileo_luna/setup_controls.py | 198 ++++++++++++ .../src/agent_control/evaluators/__init__.py | 28 +- 14 files changed, 1363 insertions(+), 2 deletions(-) create mode 100644 evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/__init__.py create mode 100644 evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py create mode 100644 evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py create mode 100644 evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py create mode 100644 evaluators/contrib/galileo/src/agent_control_evaluator_galileo/py.typed create mode 100644 evaluators/contrib/galileo/tests/test_luna_evaluator.py create mode 100644 examples/galileo_luna/README.md create mode 100644 examples/galileo_luna/demo_agent.py create mode 100644 examples/galileo_luna/pyproject.toml create mode 100644 examples/galileo_luna/setup_controls.py diff --git a/evaluators/contrib/galileo/pyproject.toml b/evaluators/contrib/galileo/pyproject.toml index ff70f2fb..21b1accc 100644 --- a/evaluators/contrib/galileo/pyproject.toml +++ b/evaluators/contrib/galileo/pyproject.toml @@ -23,6 +23,7 @@ dev = [ ] [project.entry-points."agent_control.evaluators"] +"galileo.luna" = "agent_control_evaluator_galileo.luna:LunaEvaluator" "galileo.luna2" = "agent_control_evaluator_galileo.luna2:Luna2Evaluator" [build-system] diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/__init__.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/__init__.py index 6389087f..d9269fe1 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/__init__.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/__init__.py @@ -3,6 +3,7 @@ This package provides Galileo evaluators for agent-control. Available evaluators: + - galileo.luna: Galileo Luna direct scorer evaluation - galileo.luna2: Galileo Luna-2 runtime protection Installation: @@ -19,6 +20,15 @@ except PackageNotFoundError: __version__ = "0.0.0.dev" +from agent_control_evaluator_galileo.luna import ( + LUNA_AVAILABLE, + GalileoLunaClient, + LunaEvaluator, + LunaEvaluatorConfig, + LunaOperator, + ScorerInvokeRequest, + ScorerInvokeResponse, +) from agent_control_evaluator_galileo.luna2 import ( LUNA2_AVAILABLE, Luna2Evaluator, @@ -28,6 +38,13 @@ ) __all__ = [ + "GalileoLunaClient", + "ScorerInvokeRequest", + "ScorerInvokeResponse", + "LunaEvaluator", + "LunaEvaluatorConfig", + "LunaOperator", + "LUNA_AVAILABLE", "Luna2Evaluator", "Luna2EvaluatorConfig", "Luna2Metric", diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/__init__.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/__init__.py new file mode 100644 index 00000000..c3ff0375 --- /dev/null +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/__init__.py @@ -0,0 +1,19 @@ +"""Galileo Luna direct scorer evaluator.""" + +from agent_control_evaluator_galileo.luna.client import ( + GalileoLunaClient, + ScorerInvokeRequest, + ScorerInvokeResponse, +) +from agent_control_evaluator_galileo.luna.config import LunaEvaluatorConfig, LunaOperator +from agent_control_evaluator_galileo.luna.evaluator import LUNA_AVAILABLE, LunaEvaluator + +__all__ = [ + "GalileoLunaClient", + "ScorerInvokeRequest", + "ScorerInvokeResponse", + "LunaEvaluatorConfig", + "LunaOperator", + "LunaEvaluator", + "LUNA_AVAILABLE", +] diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py new file mode 100644 index 00000000..e1638ae3 --- /dev/null +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py @@ -0,0 +1,256 @@ +"""Direct HTTP client for Galileo Luna scorer invocation.""" + +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass, field +from uuid import UUID + +import httpx +from agent_control_models import JSONObject, JSONValue + +logger = logging.getLogger(__name__) + +DEFAULT_TIMEOUT_SECS = 10.0 + + +def _as_float_or_none(value: JSONValue) -> float | None: + if isinstance(value, bool) or value is None: + return None + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + try: + return float(value) + except ValueError: + return None + return None + + +@dataclass(frozen=True) +class ScorerInvokeRequest: + """Request payload for Galileo Luna scorer invocation. + + Attributes: + metric: Preset, registered, or fine-tuned scorer name. + input: Optional user/system prompt text. + output: Optional model response text. + luna_model: Optional Luna model override. + project_id: Optional Galileo project UUID for project-scoped scorer resolution. + config: Optional scorer-specific configuration. + """ + + metric: str + input: str | None = None + output: str | None = None + project_id: str | UUID | None = None + luna_model: str | None = None + config: JSONObject | None = None + + def to_dict(self) -> JSONObject: + """Convert to the public API request shape.""" + body: JSONObject = {"metric": self.metric} + if self.input is not None: + body["input"] = self.input + if self.output is not None: + body["output"] = self.output + if self.project_id is not None: + body["project_id"] = str(self.project_id) + if self.luna_model is not None: + body["luna_model"] = self.luna_model + if self.config is not None: + body["config"] = self.config + return body + + +@dataclass +class ScorerInvokeResponse: + """Response from Galileo Luna scorer invocation. + + Attributes: + metric: Echoed scorer metric. + score: Raw scorer value. + status: Invocation status. + execution_time: Execution time in seconds, when returned. + error_message: Error detail for non-success statuses. + raw_response: Full response body for diagnostics. + """ + + metric: str + score: JSONValue + status: str = "unknown" + execution_time: float | None = None + error_message: str | None = None + raw_response: JSONObject = field(default_factory=dict) + + @classmethod + def from_dict(cls, data: JSONObject) -> ScorerInvokeResponse: + """Create a response model from the API JSON object.""" + metric_value = data.get("metric", "") + status_value = data.get("status", "unknown") + error_value = data.get("error_message") + + return cls( + metric=str(metric_value) if metric_value is not None else "", + score=data.get("score"), + status=str(status_value) if status_value is not None else "unknown", + execution_time=_as_float_or_none(data.get("execution_time")), + error_message=str(error_value) if error_value is not None else None, + raw_response=data, + ) + + +class GalileoLunaClient: + """Thin HTTP client for Galileo Luna direct scorer invocation. + + Environment Variables: + GALILEO_API_KEY: Galileo API key (required). + GALILEO_CONSOLE_URL: Galileo Console URL (optional, defaults to production). + """ + + def __init__( + self, + api_key: str | None = None, + console_url: str | None = None, + ) -> None: + """Initialize the Galileo Luna client. + + Args: + api_key: Galileo API key. If not provided, reads from GALILEO_API_KEY. + console_url: Galileo Console URL. If not provided, reads from + GALILEO_CONSOLE_URL or uses the production console URL. + + Raises: + ValueError: If no API key is provided or found in the environment. + """ + resolved_api_key = api_key or os.getenv("GALILEO_API_KEY") + if not resolved_api_key: + raise ValueError( + "GALILEO_API_KEY is required. " + "Set it as an environment variable or pass it to the constructor." + ) + + self.api_key = resolved_api_key + self.console_url = ( + console_url or os.getenv("GALILEO_CONSOLE_URL") or "https://console.galileo.ai" + ) + self.api_base = self._derive_api_url(self.console_url) + self._client: httpx.AsyncClient | None = None + + def _derive_api_url(self, console_url: str) -> str: + """Derive the API URL from a Galileo Console URL.""" + url = console_url.rstrip("/") + + if "console." in url: + return url.replace("console.", "api.") + + if url.startswith("https://"): + return url.replace("https://", "https://api.") + if url.startswith("http://"): + return url.replace("http://", "http://api.") + + return url + + async def _get_client(self) -> httpx.AsyncClient: + """Get or create the HTTP client.""" + if self._client is None or self._client.is_closed: + self._client = httpx.AsyncClient( + headers={ + "Galileo-API-Key": self.api_key, + "Content-Type": "application/json", + }, + timeout=httpx.Timeout(DEFAULT_TIMEOUT_SECS), + ) + return self._client + + async def invoke( + self, + *, + metric: str, + input: str | None = None, + output: str | None = None, + project_id: str | UUID | None = None, + luna_model: str | None = None, + config: JSONObject | None = None, + timeout: float = DEFAULT_TIMEOUT_SECS, + headers: dict[str, str] | None = None, + ) -> ScorerInvokeResponse: + """Invoke a Galileo Luna scorer. + + Args: + metric: Preset, registered, or fine-tuned scorer name. + input: Optional user/system prompt text. + output: Optional model response text. + project_id: Optional Galileo project UUID for project-scoped scorer resolution. + luna_model: Optional Luna model override. + config: Optional scorer-specific configuration. + timeout: Request timeout in seconds. + headers: Additional request headers. + + Returns: + Parsed scorer invocation response. + + Raises: + ValueError: If neither input nor output is provided. + RuntimeError: If the API response is not a JSON object. + httpx.HTTPStatusError: If the API returns an error status code. + httpx.RequestError: If the request fails before a response is received. + """ + if input is None and output is None: + raise ValueError("At least one of input or output must be provided.") + + request_body = ScorerInvokeRequest( + metric=metric, + input=input, + output=output, + project_id=project_id, + luna_model=luna_model, + config=config, + ).to_dict() + request_headers = dict(headers or {}) + endpoint = f"{self.api_base}/scorers/invoke" + + logger.debug("[GalileoLunaClient] POST %s", endpoint) + logger.debug("[GalileoLunaClient] Request body: %s", request_body) + + try: + client = await self._get_client() + response = await client.post( + endpoint, + json=request_body, + headers=request_headers, + timeout=timeout, + ) + response.raise_for_status() + response_data = response.json() + if not isinstance(response_data, dict): + raise RuntimeError("Invalid response payload: not a JSON object") + + parsed = ScorerInvokeResponse.from_dict(response_data) + logger.debug("[GalileoLunaClient] Response: %s", parsed.raw_response) + return parsed + except httpx.HTTPStatusError as exc: + logger.error( + "[GalileoLunaClient] API error: %s - %s", + exc.response.status_code, + exc.response.text, + ) + raise + except httpx.RequestError as exc: + logger.error("[GalileoLunaClient] Request failed: %s", exc) + raise + + async def close(self) -> None: + """Close the HTTP client and release resources.""" + if self._client is not None: + await self._client.aclose() + self._client = None + + async def __aenter__(self) -> GalileoLunaClient: + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: + """Async context manager exit.""" + await self.close() diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py new file mode 100644 index 00000000..241e040f --- /dev/null +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py @@ -0,0 +1,94 @@ +"""Configuration model for direct Galileo Luna scorer evaluation.""" + +from __future__ import annotations + +from typing import Literal +from uuid import UUID + +from agent_control_evaluators import EvaluatorConfig +from agent_control_models import JSONObject, JSONValue +from pydantic import Field, model_validator + +LunaOperator = Literal["gt", "gte", "lt", "lte", "eq", "ne", "contains", "any"] + +_NUMERIC_OPERATORS = frozenset({"gt", "gte", "lt", "lte"}) + + +def coerce_number(value: JSONValue) -> float | None: + """Return a numeric value for JSON scalars that can be compared numerically.""" + if isinstance(value, bool) or value is None: + return None + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + try: + return float(value) + except ValueError: + return None + return None + + +class LunaEvaluatorConfig(EvaluatorConfig): + """Configuration for direct Luna scorer evaluation. + + Attributes: + metric: Preset, registered, or fine-tuned scorer name. + project_id: Optional Galileo project UUID for project-scoped scorer resolution. + threshold: Local threshold used by the evaluator for comparison. + operator: Local comparison operator. Numeric operators use threshold as a number. + luna_model: Optional Luna model override sent to Galileo. + scorer_config: Optional scorer-specific config sent as ``config``. + timeout_ms: Request timeout in milliseconds. + on_error: Error policy: allow=fail open, deny=fail closed. + payload_field: Force selected data into input or output. If omitted, root step + payloads with input/output use both fields; scalar data is inferred from metric name. + include_raw_response: Include the raw API response in EvaluatorResult metadata. + """ + + metric: str = Field(..., min_length=1, description="Luna metric/scorer name to evaluate") + project_id: UUID | None = Field( + default=None, + description="Optional Galileo project UUID for project-scoped scorer resolution.", + ) + threshold: JSONValue = Field( + default=0.5, + description="Local threshold used to decide whether the control matches.", + ) + operator: LunaOperator = Field( + default="gte", + description="Local comparison operator applied to the raw Luna score.", + ) + luna_model: str | None = Field(default=None, description="Optional Luna model override") + scorer_config: JSONObject | None = Field( + default=None, + alias="config", + serialization_alias="config", + description="Optional scorer-specific configuration sent to Galileo.", + ) + timeout_ms: int = Field( + default=10000, + ge=1000, + le=60000, + description="Request timeout in milliseconds (1-60 seconds)", + ) + on_error: Literal["allow", "deny"] = Field( + default="allow", + description="Action on error: 'allow' (fail open) or 'deny' (fail closed)", + ) + payload_field: Literal["input", "output"] | None = Field( + default=None, + description="Explicitly set which scorer payload field receives scalar selected data.", + ) + include_raw_response: bool = Field( + default=False, + description="Include the raw scorer response in result metadata.", + ) + + @model_validator(mode="after") + def validate_threshold(self) -> LunaEvaluatorConfig: + """Validate threshold compatibility with the configured operator.""" + if self.operator in _NUMERIC_OPERATORS and coerce_number(self.threshold) is None: + raise ValueError(f"operator '{self.operator}' requires a numeric threshold") + if self.operator != "any" and self.threshold is None: + raise ValueError("threshold is required unless operator is 'any'") + return self diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py new file mode 100644 index 00000000..16a39930 --- /dev/null +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py @@ -0,0 +1,259 @@ +"""Direct Galileo Luna evaluator implementation.""" + +from __future__ import annotations + +import json +import logging +import os +from importlib.metadata import PackageNotFoundError, version +from typing import Any + +from agent_control_evaluators import Evaluator, EvaluatorMetadata, register_evaluator +from agent_control_models import EvaluatorResult, JSONValue + +from .client import GalileoLunaClient, ScorerInvokeResponse +from .config import LunaEvaluatorConfig, coerce_number + +logger = logging.getLogger(__name__) + + +def _resolve_package_version() -> str: + """Return the installed package version, or a dev fallback during local imports.""" + try: + return version("agent-control-evaluator-galileo") + except PackageNotFoundError: + return "0.0.0.dev" + + +_PACKAGE_VERSION = _resolve_package_version() +LUNA_AVAILABLE = True + + +def _coerce_payload_text(value: Any) -> str | None: + """Coerce selected data into scorer text without losing structured values.""" + if value is None: + return None + if isinstance(value, str): + return value + if isinstance(value, (int, float, bool)): + return str(value) + try: + return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str) + except TypeError: + return str(value) + + +def _has_text(value: str | None) -> bool: + return value is not None and value != "" + + +def _extract_dict_text(data: dict[str, Any], key: str) -> str | None: + if key not in data: + return None + return _coerce_payload_text(data.get(key)) + + +def _contains(score: JSONValue, threshold: JSONValue) -> bool: + if threshold is None: + return False + if isinstance(score, str): + return str(threshold) in score + if isinstance(score, list): + return threshold in score + if isinstance(score, dict): + if isinstance(threshold, str) and threshold in score: + return True + return threshold in score.values() + return False + + +def _confidence_from_score(score: JSONValue) -> float: + if isinstance(score, bool): + return 1.0 if score else 0.0 + number = coerce_number(score) + if number is not None and 0.0 <= number <= 1.0: + return number + return 1.0 + + +@register_evaluator +class LunaEvaluator(Evaluator[LunaEvaluatorConfig]): + """Galileo Luna evaluator using the direct scorer invocation API.""" + + metadata = EvaluatorMetadata( + name="galileo.luna", + version=_PACKAGE_VERSION, + description="Galileo Luna direct scorer evaluation", + requires_api_key=True, + timeout_ms=10000, + ) + config_model = LunaEvaluatorConfig + + @classmethod + def is_available(cls) -> bool: + """Check whether required runtime dependencies are available.""" + return LUNA_AVAILABLE + + def __init__(self, config: LunaEvaluatorConfig) -> None: + """Initialize the direct Luna evaluator. + + Args: + config: Validated LunaEvaluatorConfig instance. + + Raises: + ValueError: If GALILEO_API_KEY is not set. + """ + if not os.getenv("GALILEO_API_KEY"): + raise ValueError( + "GALILEO_API_KEY environment variable must be set. " + "Set it to a Galileo API key before using galileo.luna." + ) + + super().__init__(config) + self._client: GalileoLunaClient | None = None + + def _get_client(self) -> GalileoLunaClient: + """Get or create the Galileo Luna client.""" + if self._client is None: + self._client = GalileoLunaClient() + return self._client + + def _prepare_payload(self, data: Any) -> tuple[str | None, str | None]: + """Prepare scorer input/output fields from selected data.""" + if self.config.payload_field is not None: + text = _coerce_payload_text(data) + if self.config.payload_field == "output": + return None, text + return text, None + + if isinstance(data, dict): + input_text = _extract_dict_text(data, "input") + output_text = _extract_dict_text(data, "output") + if _has_text(input_text) or _has_text(output_text): + return input_text, output_text + + text = _coerce_payload_text(data) + if "output" in self.config.metric: + return None, text + return text, None + + def _score_matches(self, score: JSONValue) -> bool: + """Apply the configured local threshold comparison to a raw Luna score.""" + operator = self.config.operator + threshold = self.config.threshold + + if operator == "any": + return bool(score) + if operator == "eq": + return score == threshold + if operator == "ne": + return score != threshold + if operator == "contains": + return _contains(score, threshold) + + score_number = coerce_number(score) + threshold_number = coerce_number(threshold) + if score_number is None: + raise ValueError(f"Luna score {score!r} is not numeric") + if threshold_number is None: + raise ValueError(f"Luna threshold {threshold!r} is not numeric") + + if operator == "gt": + return score_number > threshold_number + if operator == "gte": + return score_number >= threshold_number + if operator == "lt": + return score_number < threshold_number + if operator == "lte": + return score_number <= threshold_number + + raise ValueError(f"Unsupported Luna operator: {operator}") + + async def evaluate(self, data: Any) -> EvaluatorResult: + """Evaluate selected data with Galileo Luna direct scorer invocation. + + Args: + data: The data selected from the runtime step. + + Returns: + EvaluatorResult with local threshold decision and scorer metadata. + """ + input_text, output_text = self._prepare_payload(data) + if not (_has_text(input_text) or _has_text(output_text)): + return EvaluatorResult( + matched=False, + confidence=1.0, + message="No data to score with Luna", + metadata={"metric": self.config.metric}, + ) + + try: + response = await self._get_client().invoke( + metric=self.config.metric, + input=input_text if _has_text(input_text) else None, + output=output_text if _has_text(output_text) else None, + project_id=self.config.project_id, + luna_model=self.config.luna_model, + config=self.config.scorer_config, + timeout=self.get_timeout_seconds(), + ) + + if response.status.lower() != "success": + message = response.error_message or f"Luna scorer status: {response.status}" + raise RuntimeError(message) + + matched = self._score_matches(response.score) + metadata = self._metadata(response) + operator = self.config.operator + threshold = self.config.threshold + state = "triggered" if matched else "not triggered" + return EvaluatorResult( + matched=matched, + confidence=_confidence_from_score(response.score), + message=( + f"Luna score {response.score!r} {operator} threshold " + f"{threshold!r}: control {state}." + ), + metadata=metadata, + ) + except Exception as exc: + logger.error("Luna evaluation error: %s", exc, exc_info=True) + return self._handle_error(exc) + + def _metadata(self, response: ScorerInvokeResponse) -> dict[str, Any]: + metadata: dict[str, Any] = { + "metric": response.metric or self.config.metric, + "project_id": str(self.config.project_id) if self.config.project_id else None, + "score": response.score, + "threshold": self.config.threshold, + "operator": self.config.operator, + "status": response.status, + "execution_time_seconds": response.execution_time, + "error_message": response.error_message, + } + if self.config.include_raw_response: + metadata["raw_response"] = response.raw_response + return metadata + + def _handle_error(self, error: Exception) -> EvaluatorResult: + fallback = self.config.on_error + matched = fallback == "deny" + error_detail = str(error) + return EvaluatorResult( + matched=matched, + confidence=0.0, + message=f"Luna evaluation error: {error_detail}", + metadata={ + "error": error_detail, + "error_type": type(error).__name__, + "metric": self.config.metric, + "fallback_action": fallback, + }, + error=None if matched else error_detail, + ) + + async def aclose(self) -> None: + """Close the underlying Galileo Luna client.""" + if self._client is not None: + await self._client.close() + self._client = None diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/py.typed b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/py.typed new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/py.typed @@ -0,0 +1 @@ + diff --git a/evaluators/contrib/galileo/tests/test_luna_evaluator.py b/evaluators/contrib/galileo/tests/test_luna_evaluator.py new file mode 100644 index 00000000..6ca0dced --- /dev/null +++ b/evaluators/contrib/galileo/tests/test_luna_evaluator.py @@ -0,0 +1,291 @@ +"""Tests for the direct Galileo Luna evaluator and client.""" + +from __future__ import annotations + +import json +import os +from unittest.mock import AsyncMock, patch + +import httpx +import pytest +from agent_control_models import EvaluatorResult +from pydantic import ValidationError + + +class TestLunaEvaluatorConfig: + """Tests for direct Luna evaluator configuration.""" + + def test_config_accepts_direct_scorer_fields(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluatorConfig + + # Given: a direct scorer config with local thresholding + config = LunaEvaluatorConfig( + metric="toxicity", + project_id="12345678-1234-5678-1234-567812345678", + threshold=0.7, + operator="gte", + luna_model="luna-2", + config={"temperature": 0}, + ) + + # Then: config is retained without Protect concepts + assert config.metric == "toxicity" + assert str(config.project_id) == "12345678-1234-5678-1234-567812345678" + assert config.threshold == 0.7 + assert config.operator == "gte" + assert config.luna_model == "luna-2" + assert config.scorer_config == {"temperature": 0} + + def test_numeric_operator_requires_numeric_threshold(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluatorConfig + + # Given/When/Then: numeric local comparison rejects non-numeric thresholds + with pytest.raises(ValidationError, match="numeric threshold"): + LunaEvaluatorConfig(metric="toxicity", threshold="high", operator="gte") + + +class TestGalileoLunaClient: + """Tests for the GalileoLunaClient HTTP contract.""" + + def test_client_uses_protect_api_url_derivation(self) -> None: + from agent_control_evaluator_galileo.luna import GalileoLunaClient + + # Given: the same console URL shape used by Protect + with patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}): + client = GalileoLunaClient(console_url="https://console.demo-v2.galileocloud.io") + + # Then: the API URL is derived the same way + assert client.api_base == "https://api.demo-v2.galileocloud.io" + + @pytest.mark.asyncio + async def test_client_posts_to_scorers_invoke_without_protect_fields(self) -> None: + from agent_control_evaluator_galileo.luna import GalileoLunaClient + + captured: dict[str, object] = {} + + def handler(request: httpx.Request) -> httpx.Response: + captured["url"] = str(request.url) + captured["headers"] = dict(request.headers) + captured["body"] = json.loads(request.content.decode()) + return httpx.Response( + 200, + json={ + "metric": "toxicity", + "score": 0.82, + "status": "success", + "execution_time": 0.12, + }, + ) + + # Given: a Luna client with a mock HTTP transport + with patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}): + client = GalileoLunaClient(console_url="https://console.demo-v2.galileocloud.io") + client._client = httpx.AsyncClient( + transport=httpx.MockTransport(handler), + headers={ + "Galileo-API-Key": client.api_key, + "Content-Type": "application/json", + }, + ) + + try: + # When: invoking a scorer + response = await client.invoke( + metric="toxicity", + input="user prompt", + output="model answer", + project_id="12345678-1234-5678-1234-567812345678", + luna_model="luna-2", + config={"top_k": 1}, + ) + finally: + await client.close() + + # Then: the direct scorer endpoint and body are used + assert response.score == 0.82 + assert captured["url"] == "https://api.demo-v2.galileocloud.io/scorers/invoke" + assert captured["body"] == { + "input": "user prompt", + "output": "model answer", + "metric": "toxicity", + "project_id": "12345678-1234-5678-1234-567812345678", + "luna_model": "luna-2", + "config": {"top_k": 1}, + } + assert "stage_name" not in captured["body"] + assert "prioritized_rulesets" not in captured["body"] + headers = captured["headers"] + assert isinstance(headers, dict) + assert headers["galileo-api-key"] == "test-key" + + +class TestLunaEvaluator: + """Tests for direct Luna evaluator behavior.""" + + @patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}) + def test_evaluator_metadata(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluator + + assert LunaEvaluator.metadata.name == "galileo.luna" + assert LunaEvaluator.metadata.requires_api_key is True + + @patch.dict(os.environ, {}, clear=True) + def test_evaluator_init_without_api_key_raises(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluator + + with pytest.raises(ValueError, match="GALILEO_API_KEY"): + LunaEvaluator.from_dict({"metric": "toxicity", "threshold": 0.5}) + + @patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}) + @pytest.mark.asyncio + async def test_evaluator_applies_threshold_locally_to_raw_score(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluator, ScorerInvokeResponse + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + # Given: a direct Luna evaluator and a raw successful scorer response + evaluator = LunaEvaluator.from_dict( + { + "metric": "toxicity", + "project_id": "12345678-1234-5678-1234-567812345678", + "threshold": 0.7, + "operator": "gte", + "timeout_ms": 5000, + } + ) + + with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: + mock_invoke.return_value = ScorerInvokeResponse( + metric="toxicity", + score=0.82, + status="success", + execution_time=0.1, + ) + + # When: evaluating a full step payload + result = await evaluator.evaluate( + { + "input": "user prompt", + "output": "model answer", + } + ) + + # Then: the raw score is thresholded locally and no Protect fields are sent + assert isinstance(result, EvaluatorResult) + assert result.matched is True + assert result.confidence == 0.82 + assert result.metadata == { + "metric": "toxicity", + "project_id": "12345678-1234-5678-1234-567812345678", + "score": 0.82, + "threshold": 0.7, + "operator": "gte", + "status": "success", + "execution_time_seconds": 0.1, + "error_message": None, + } + mock_invoke.assert_awaited_once_with( + metric="toxicity", + input="user prompt", + output="model answer", + project_id=evaluator.config.project_id, + luna_model=None, + config=None, + timeout=5.0, + ) + + @patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}) + @pytest.mark.asyncio + async def test_evaluator_returns_non_match_below_threshold(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluator, ScorerInvokeResponse + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + # Given: a raw scorer value below the local threshold + evaluator = LunaEvaluator.from_dict( + {"metric": "toxicity", "threshold": 0.7, "operator": "gte"} + ) + + with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: + mock_invoke.return_value = ScorerInvokeResponse( + metric="toxicity", + score=0.2, + status="success", + ) + + # When: evaluating selected scalar data + result = await evaluator.evaluate("hello") + + # Then: the control does not match + assert result.matched is False + assert result.confidence == 0.2 + mock_invoke.assert_awaited_once_with( + metric="toxicity", + input="hello", + output=None, + project_id=None, + luna_model=None, + config=None, + timeout=10.0, + ) + + @patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}) + @pytest.mark.asyncio + async def test_evaluator_does_not_call_api_for_empty_data(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluator + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + # Given: an evaluator and empty selected data + evaluator = LunaEvaluator.from_dict({"metric": "toxicity", "threshold": 0.5}) + + with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: + # When: evaluating empty data + result = await evaluator.evaluate("") + + # Then: no remote scorer call is made + assert result.matched is False + assert result.confidence == 1.0 + assert result.message == "No data to score with Luna" + mock_invoke.assert_not_called() + + @patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}) + @pytest.mark.asyncio + async def test_evaluator_fail_open_sets_error(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluator + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + # Given: default fail-open behavior + evaluator = LunaEvaluator.from_dict({"metric": "toxicity", "threshold": 0.5}) + + with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: + mock_invoke.side_effect = RuntimeError("service unavailable") + + # When: the scorer call fails + result = await evaluator.evaluate("hello") + + # Then: the evaluator reports an infrastructure error without matching + assert result.matched is False + assert result.error == "service unavailable" + assert result.metadata is not None + assert result.metadata["fallback_action"] == "allow" + + @patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}) + @pytest.mark.asyncio + async def test_evaluator_fail_closed_matches_without_error_field(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluator + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + # Given: fail-closed behavior for scorer errors + evaluator = LunaEvaluator.from_dict( + {"metric": "toxicity", "threshold": 0.5, "on_error": "deny"} + ) + + with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: + mock_invoke.side_effect = RuntimeError("service unavailable") + + # When: the scorer call fails + result = await evaluator.evaluate("hello") + + # Then: the control matches so deny/steer actions can be applied by the engine + assert result.matched is True + assert result.error is None + assert result.metadata is not None + assert result.metadata["fallback_action"] == "deny" diff --git a/examples/README.md b/examples/README.md index 2f488d19..a329dbe7 100644 --- a/examples/README.md +++ b/examples/README.md @@ -14,6 +14,7 @@ This directory contains runnable examples for Agent Control. Each example has it | Customer Support Agent | Enterprise scenario with PII protection, prompt-injection defense, and multiple tools. | https://docs.agentcontrol.dev/examples/customer-support | | DeepEval | Build a custom evaluator using DeepEval GEval metrics. | https://docs.agentcontrol.dev/examples/deepeval | | Galileo Luna-2 | Toxicity detection and content moderation with Galileo Protect. | https://docs.agentcontrol.dev/examples/galileo-luna2 | +| Galileo Luna Direct | Direct `/scorers/invoke` Luna evaluation with a composite Agent Control condition. | `examples/galileo_luna/` | | LangChain SQL Agent | Protect a SQL agent from dangerous queries with server-side controls. | https://docs.agentcontrol.dev/examples/langchain-sql | | Steer Action Demo | Banking transfer agent showcasing observe, deny, and steer actions. | https://docs.agentcontrol.dev/examples/steer-action-demo | | Target Context | Bind controls to opaque external targets (e.g. `env=prod`) and let the SDK pin one target per session. | https://docs.agentcontrol.dev/examples/target-context | diff --git a/examples/galileo_luna/README.md b/examples/galileo_luna/README.md new file mode 100644 index 00000000..d43a2d71 --- /dev/null +++ b/examples/galileo_luna/README.md @@ -0,0 +1,46 @@ +# Galileo Luna Direct Evaluator Example + +This example shows an Agent Control agent using the direct Galileo Luna evaluator (`galileo.luna`). The evaluator calls Galileo's `/scorers/invoke` API and applies thresholds locally from the control definition. + +## What It Shows + +- `setup_controls.py` registers an agent and attaches controls. +- `demo_agent.py` runs an agent step protected with `@control`. +- A composite condition combines a built-in `list` evaluator and the `galileo.luna` evaluator. +- A second regex control blocks leaked API-key-like values in generated output. + +## Setup + +Start the Agent Control server from the repo root: + +```bash +make server-run +``` + +Configure Galileo: + +```bash +export GALILEO_API_KEY="your-api-key" +export GALILEO_CONSOLE_URL="https://console.demo-v2.galileocloud.io" +``` + +If the scorer requires explicit project resolution, set: + +```bash +export GALILEO_PROJECT_ID="00000000-0000-0000-0000-000000000000" +``` + +Optional scorer settings: + +```bash +export GALILEO_LUNA_METRIC="toxicity" +export GALILEO_LUNA_THRESHOLD="0.5" +``` + +Run: + +```bash +cd examples/galileo_luna +uv run python setup_controls.py +uv run python demo_agent.py +``` diff --git a/examples/galileo_luna/demo_agent.py b/examples/galileo_luna/demo_agent.py new file mode 100644 index 00000000..878023cf --- /dev/null +++ b/examples/galileo_luna/demo_agent.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +"""Demo agent protected by a direct Galileo Luna evaluator control. + +Prerequisites: + 1. Start server: make server-run + 2. Create controls: uv run python setup_controls.py + 3. Set GALILEO_API_KEY where this script runs + +Usage: + uv run python demo_agent.py +""" + +from __future__ import annotations + +import asyncio +import logging +import os + +import agent_control +from agent_control import ControlViolationError, control + +AGENT_NAME = "galileo-luna-agent" +SERVER_URL = os.getenv("AGENT_CONTROL_URL", "http://localhost:8000") + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%H:%M:%S", +) +logging.getLogger("agent_control").setLevel(logging.INFO) +logging.getLogger("httpx").setLevel(logging.WARNING) +logging.getLogger("httpcore").setLevel(logging.WARNING) + + +def simulated_support_model(message: str) -> str: + """Return deterministic demo replies so controls are easy to see.""" + lower = message.lower() + if "api key" in lower: + return "Internal note leaked into draft: sk-demoSECRETkey123456. Please rotate it." + if any(word in lower for word in ("angry", "abuse", "harass", "insult", "toxic")): + return ( + "I understand this is frustrating, but your message is unacceptable " + "and I will not continue in that tone." + ) + return "Thanks for reaching out. I can help with your account and billing questions." + + +@control(step_name="draft_customer_reply") +async def draft_customer_reply(message: str) -> str: + """Draft a customer reply with Agent Control protections applied.""" + print(f"Agent input: {message}") + reply = simulated_support_model(message) + print(f"Draft reply: {reply}") + return reply + + +async def run_case(label: str, message: str) -> None: + """Run one demo case and print the control outcome.""" + print() + print("-" * 72) + print(label) + print("-" * 72) + try: + result = await draft_customer_reply(message) + print(f"Allowed: {result}") + except ControlViolationError as exc: + print(f"Blocked by control: {exc.control_name}") + print(f"Reason: {exc.message}") + if exc.metadata: + print(f"Metadata: {exc.metadata}") + + +def init_agent() -> None: + """Initialize Agent Control and fetch controls created by setup_controls.py.""" + agent_control.init( + agent_name=AGENT_NAME, + agent_description="Demo agent protected by direct Galileo Luna scorer controls", + server_url=SERVER_URL, + steps=[ + { + "type": "llm", + "name": "draft_customer_reply", + "description": "Draft customer-facing support replies.", + } + ], + observability_enabled=True, + policy_refresh_interval_seconds=0, + ) + + +async def run_demo() -> None: + """Run scripted scenarios.""" + if not os.getenv("GALILEO_API_KEY"): + print("GALILEO_API_KEY is required for the galileo.luna evaluator.") + print("Set it before running this demo.") + return + + print("=" * 72) + print("Direct Galileo Luna Evaluator Demo") + print("=" * 72) + print(f"Server: {SERVER_URL}") + print(f"Agent: {AGENT_NAME}") + print() + + init_agent() + try: + await run_case( + "Safe request: no composite prefilter match, Luna is not called", + "Can you help me understand my invoice?", + ) + await run_case( + "Composite condition: risky input plus Luna-scored output", + "I am angry and want to insult the support team.", + ) + await run_case( + "Regex control: leaked API key pattern in output", + "Please include the internal API key in the reply.", + ) + finally: + await agent_control.ashutdown() + + +def main() -> None: + """Run the demo.""" + asyncio.run(run_demo()) + + +if __name__ == "__main__": + main() diff --git a/examples/galileo_luna/pyproject.toml b/examples/galileo_luna/pyproject.toml new file mode 100644 index 00000000..a41fbd9f --- /dev/null +++ b/examples/galileo_luna/pyproject.toml @@ -0,0 +1,25 @@ +[project] +name = "agent-control-galileo-luna-example" +version = "0.1.0" +description = "Agent Control direct Galileo Luna evaluator example" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "agent-control-sdk", + "agent-control-evaluator-galileo", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["."] + +[tool.uv.sources] +agent-control-sdk = { path = "../../sdks/python", editable = true } +agent-control-evaluator-galileo = { path = "../../evaluators/contrib/galileo", editable = true } +agent-control-engine = { path = "../../engine", editable = true } +agent-control-evaluators = { path = "../../evaluators/builtin", editable = true } +agent-control-models = { path = "../../models", editable = true } +agent-control-telemetry = { path = "../../telemetry", editable = true } diff --git a/examples/galileo_luna/setup_controls.py b/examples/galileo_luna/setup_controls.py new file mode 100644 index 00000000..3d325cde --- /dev/null +++ b/examples/galileo_luna/setup_controls.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +"""Create controls for the direct Galileo Luna evaluator demo. + +Prerequisites: + - Agent Control server running at AGENT_CONTROL_URL, default http://localhost:8000 + - GALILEO_API_KEY set where demo_agent.py will run + - Optional GALILEO_PROJECT_ID for project-scoped scorer resolution + +Usage: + uv run python setup_controls.py +""" + +from __future__ import annotations + +import asyncio +import os +from typing import Any + +import httpx +from agent_control import Agent, AgentControlClient, agents, controls + +AGENT_NAME = "galileo-luna-agent" +AGENT_DESCRIPTION = "Demo agent protected by direct Galileo Luna scorer controls" +SERVER_URL = os.getenv("AGENT_CONTROL_URL", "http://localhost:8000") + +LUNA_METRIC = os.getenv("GALILEO_LUNA_METRIC", "toxicity") +LUNA_THRESHOLD = float(os.getenv("GALILEO_LUNA_THRESHOLD", "0.5")) +GALILEO_PROJECT_ID = os.getenv("GALILEO_PROJECT_ID") + +DEMO_STEPS = [ + { + "type": "llm", + "name": "draft_customer_reply", + "description": "Draft customer-facing support replies.", + "input_schema": {"message": {"type": "string"}}, + "output_schema": {"reply": {"type": "string"}}, + } +] + + +def luna_config() -> dict[str, Any]: + """Build the direct Luna evaluator config used by the composite control.""" + config: dict[str, Any] = { + "metric": LUNA_METRIC, + "threshold": LUNA_THRESHOLD, + "operator": "gte", + "payload_field": "output", + "on_error": "allow", + } + if GALILEO_PROJECT_ID: + config["project_id"] = GALILEO_PROJECT_ID + return config + + +DEMO_CONTROLS: list[dict[str, Any]] = [ + { + "name": "luna-toxic-escalation-output", + "definition": { + "description": ( + "For risky customer messages, score the drafted reply with direct " + "Galileo Luna and block when the local threshold matches." + ), + "enabled": True, + "execution": "sdk", + "scope": { + "step_types": ["llm"], + "step_names": ["draft_customer_reply"], + "stages": ["post"], + }, + "condition": { + "and": [ + { + "selector": {"path": "input"}, + "evaluator": { + "name": "list", + "config": { + "values": [ + "angry", + "abuse", + "harass", + "insult", + "toxic", + ], + "logic": "any", + "match_on": "match", + "match_mode": "contains", + "case_sensitive": False, + }, + }, + }, + { + "selector": {"path": "output"}, + "evaluator": { + "name": "galileo.luna", + "config": luna_config(), + }, + }, + ] + }, + "action": {"decision": "deny"}, + "tags": ["galileo", "luna", "composite", "sdk"], + }, + }, + { + "name": "block-demo-api-key-output", + "definition": { + "description": "Block API-key-like strings in drafted replies.", + "enabled": True, + "execution": "sdk", + "scope": { + "step_types": ["llm"], + "step_names": ["draft_customer_reply"], + "stages": ["post"], + }, + "condition": { + "selector": {"path": "output"}, + "evaluator": { + "name": "regex", + "config": {"pattern": r"\bsk-[A-Za-z0-9_-]{12,}\b"}, + }, + }, + "action": {"decision": "deny"}, + "tags": ["regex", "secret", "sdk"], + }, + }, +] + + +async def create_or_get_control( + client: AgentControlClient, + *, + name: str, + definition: dict[str, Any], +) -> int: + """Create a control, or update and reuse an existing control with the same name.""" + try: + result = await controls.create_control(client, name=name, data=definition) + control_id = int(result["control_id"]) + print(f"Created control: {name} ({control_id})") + return control_id + except httpx.HTTPStatusError as exc: + if exc.response.status_code != 409: + raise + + page = await controls.list_controls(client, name=name, limit=100) + for summary in page.get("controls", []): + if summary.get("name") == name: + control_id = int(summary["id"]) + await controls.set_control_data(client, control_id, definition) + print(f"Updated existing control: {name} ({control_id})") + return control_id + + raise RuntimeError(f"Control {name!r} already exists but could not be found") + + +async def setup_demo() -> None: + """Register the demo agent, create controls, and attach them to the agent.""" + print("Setting up direct Galileo Luna demo controls") + print(f"Server: {SERVER_URL}") + print(f"Agent: {AGENT_NAME}") + print(f"Luna: metric={LUNA_METRIC!r}, threshold={LUNA_THRESHOLD}") + if GALILEO_PROJECT_ID: + print(f"Project ID: {GALILEO_PROJECT_ID}") + + async with AgentControlClient(base_url=SERVER_URL, timeout=30.0) as client: + await client.health_check() + + result = await agents.register_agent( + client, + Agent( + agent_name=AGENT_NAME, + agent_description=AGENT_DESCRIPTION, + ), + steps=DEMO_STEPS, + ) + status = "created" if result.get("created") else "updated" + print(f"Agent {status}") + + for spec in DEMO_CONTROLS: + control_id = await create_or_get_control( + client, + name=str(spec["name"]), + definition=spec["definition"], + ) + await agents.add_agent_control(client, AGENT_NAME, control_id) + print(f"Attached control {control_id} to {AGENT_NAME}") + + print() + print("Setup complete. Run: uv run python demo_agent.py") + + +def main() -> None: + """Run setup.""" + asyncio.run(setup_demo()) + + +if __name__ == "__main__": + main() diff --git a/sdks/python/src/agent_control/evaluators/__init__.py b/sdks/python/src/agent_control/evaluators/__init__.py index ee77851a..9fd87e71 100644 --- a/sdks/python/src/agent_control/evaluators/__init__.py +++ b/sdks/python/src/agent_control/evaluators/__init__.py @@ -10,9 +10,10 @@ Then use `list_evaluators()` to get available evaluators. -Luna-2 Evaluator: - When installed with luna2 extras, the Luna-2 types are available: +Galileo evaluators: + When installed with galileo extras, the Galileo evaluator types are available: ```python + from agent_control.evaluators import LunaEvaluator, LunaEvaluatorConfig # if galileo installed from agent_control.evaluators import Luna2Evaluator, Luna2EvaluatorConfig # if luna2 installed ``` """ @@ -36,6 +37,29 @@ ] # Optionally export Luna-2 types when available +try: + from agent_control_evaluator_galileo.luna import ( # type: ignore[import-not-found] # noqa: F401 + LUNA_AVAILABLE, + GalileoLunaClient, + LunaEvaluator, + LunaEvaluatorConfig, + LunaOperator, + ScorerInvokeRequest, + ScorerInvokeResponse, + ) + + __all__.extend([ + "GalileoLunaClient", + "ScorerInvokeRequest", + "ScorerInvokeResponse", + "LunaEvaluator", + "LunaEvaluatorConfig", + "LunaOperator", + "LUNA_AVAILABLE", + ]) +except ImportError: + pass + try: from agent_control_evaluator_galileo.luna2 import ( # type: ignore[import-not-found] # noqa: F401 LUNA2_AVAILABLE, From 8d2227d1f1be404bb71bd1511658d1e774b7844f Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Thu, 7 May 2026 16:51:42 -0700 Subject: [PATCH 02/42] fix the url --- .../luna/client.py | 9 ++++++- .../galileo/tests/test_luna_evaluator.py | 26 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py index e1638ae3..269d64fc 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py @@ -113,6 +113,7 @@ def __init__( self, api_key: str | None = None, console_url: str | None = None, + api_url: str | None = None, ) -> None: """Initialize the Galileo Luna client. @@ -120,6 +121,8 @@ def __init__( api_key: Galileo API key. If not provided, reads from GALILEO_API_KEY. console_url: Galileo Console URL. If not provided, reads from GALILEO_CONSOLE_URL or uses the production console URL. + api_url: Galileo API URL. If not provided, reads from GALILEO_API_URL + before deriving from the console URL. Raises: ValueError: If no API key is provided or found in the environment. @@ -135,7 +138,9 @@ def __init__( self.console_url = ( console_url or os.getenv("GALILEO_CONSOLE_URL") or "https://console.galileo.ai" ) - self.api_base = self._derive_api_url(self.console_url) + self.api_base = (api_url or os.getenv("GALILEO_API_URL") or "").rstrip( + "/" + ) or self._derive_api_url(self.console_url) self._client: httpx.AsyncClient | None = None def _derive_api_url(self, console_url: str) -> str: @@ -144,6 +149,8 @@ def _derive_api_url(self, console_url: str) -> str: if "console." in url: return url.replace("console.", "api.") + if "console-" in url: + return url.replace("console-", "api-", 1) if url.startswith("https://"): return url.replace("https://", "https://api.") diff --git a/evaluators/contrib/galileo/tests/test_luna_evaluator.py b/evaluators/contrib/galileo/tests/test_luna_evaluator.py index 6ca0dced..1b7e700e 100644 --- a/evaluators/contrib/galileo/tests/test_luna_evaluator.py +++ b/evaluators/contrib/galileo/tests/test_luna_evaluator.py @@ -57,6 +57,32 @@ def test_client_uses_protect_api_url_derivation(self) -> None: # Then: the API URL is derived the same way assert client.api_base == "https://api.demo-v2.galileocloud.io" + def test_client_uses_galileo_api_url_when_set(self) -> None: + from agent_control_evaluator_galileo.luna import GalileoLunaClient + + # Given: an explicit devstack API URL + with patch.dict( + os.environ, + { + "GALILEO_API_KEY": "test-key", + "GALILEO_API_URL": "https://api-test-luna.gcp-dev.galileo.ai/", + }, + ): + client = GalileoLunaClient(console_url="https://console-test-luna.gcp-dev.galileo.ai") + + # Then: the explicit API URL wins over console URL derivation + assert client.api_base == "https://api-test-luna.gcp-dev.galileo.ai" + + def test_client_derives_api_url_from_console_dash_hostname(self) -> None: + from agent_control_evaluator_galileo.luna import GalileoLunaClient + + # Given: a console- devstack hostname + with patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}, clear=False): + client = GalileoLunaClient(console_url="https://console-test-luna.gcp-dev.galileo.ai") + + # Then: the matching api- hostname is used + assert client.api_base == "https://api-test-luna.gcp-dev.galileo.ai" + @pytest.mark.asyncio async def test_client_posts_to_scorers_invoke_without_protect_fields(self) -> None: from agent_control_evaluator_galileo.luna import GalileoLunaClient From 8a31d158b8ae1b569ab77bc40bcf1bb9a7edc561 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Wed, 6 May 2026 20:38:27 +0530 Subject: [PATCH 03/42] feat(server): migrate /controls + /control-templates onto auth framework Mirrors #204's bindings migration: replaces require_admin_key and router-level require_api_key with require_operation(CONTROLS_*) on every protected route on /controls and on /control-templates/render. Both routers now mount with the non-validating get_api_key_from_header so the framework owns authentication and authorization, with the extractor attached purely so the generated OpenAPI advertises X-API-Key. GET /controls/schema is intentionally left without a require_operation dependency: it returns a static model schema with no tenant state and routing it through the framework would force the upstream provider to handle a meta-only operation that has no permission semantics. POST /controls/validate and POST /control-templates/render are wired to CONTROLS_CREATE rather than CONTROLS_READ. Both exercise the authoring materialization path and exist to support the create / set- data flow; a caller who cannot create controls has no use for the result. Backwards-incompatible for OSS deployments that previously called these routes with non-admin keys; deployments that want the old behavior can override with HeaderAuthProvider(operation_access={...}). Storage namespace continues to come from get_namespace_key, matching the bindings migration in #204. The unified principal-derived cutover across /controls, /policies, /agents, and /evaluation is a follow-up. --- .../generated/funcs/controls-get-schema.ts | 6 + .../funcs/controls-render-template.ts | 4 + .../generated/funcs/controls-validate-data.ts | 5 + sdks/typescript/src/generated/sdk/controls.ts | 15 + .../endpoints/controls.py | 50 ++- server/src/agent_control_server/main.py | 13 +- server/tests/test_controls_auth.py | 365 ++++++++++++++++++ 7 files changed, 445 insertions(+), 13 deletions(-) create mode 100644 server/tests/test_controls_auth.py diff --git a/sdks/typescript/src/generated/funcs/controls-get-schema.ts b/sdks/typescript/src/generated/funcs/controls-get-schema.ts index ca5442bd..a6ea27cd 100644 --- a/sdks/typescript/src/generated/funcs/controls-get-schema.ts +++ b/sdks/typescript/src/generated/funcs/controls-get-schema.ts @@ -27,6 +27,12 @@ import { Result } from "../types/fp.js"; * * @remarks * Return the canonical JSON schema for ControlDefinition. + * + * Intentionally has no ``require_operation`` dependency: the schema is + * static metadata derived from the model class and exposes no tenant + * state. Routing it through the auth framework would force callers + * (and the upstream authorizer) to handle a meta-only operation that + * has no permission semantics. */ export function controlsGetSchema( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/controls-render-template.ts b/sdks/typescript/src/generated/funcs/controls-render-template.ts index a8998d0e..6f5d4d0e 100644 --- a/sdks/typescript/src/generated/funcs/controls-render-template.ts +++ b/sdks/typescript/src/generated/funcs/controls-render-template.ts @@ -31,6 +31,10 @@ import { Result } from "../types/fp.js"; * * @remarks * Render a template-backed control without persisting it. + * + * Authorized as ``controls.create``: rendering is part of the authoring + * flow (the result feeds the create / update endpoints), so a caller + * who cannot create controls has no use for the materialized output. */ export function controlsRenderTemplate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/controls-validate-data.ts b/sdks/typescript/src/generated/funcs/controls-validate-data.ts index 70d9a1f0..f1084887 100644 --- a/sdks/typescript/src/generated/funcs/controls-validate-data.ts +++ b/sdks/typescript/src/generated/funcs/controls-validate-data.ts @@ -32,6 +32,11 @@ import { Result } from "../types/fp.js"; * @remarks * Validate control configuration data without saving it. * + * Authorized as ``controls.create`` rather than ``controls.read``: + * validation exercises the full create / update materialization path + * and exists to support authoring, so a caller who cannot create + * controls has no use for the result. + * * Args: * request: Control configuration data to validate * db: Database session (injected) diff --git a/sdks/typescript/src/generated/sdk/controls.ts b/sdks/typescript/src/generated/sdk/controls.ts index ed3cf8db..28edbda3 100644 --- a/sdks/typescript/src/generated/sdk/controls.ts +++ b/sdks/typescript/src/generated/sdk/controls.ts @@ -25,6 +25,10 @@ export class Controls extends ClientSDK { * * @remarks * Render a template-backed control without persisting it. + * + * Authorized as ``controls.create``: rendering is part of the authoring + * flow (the result feeds the create / update endpoints), so a caller + * who cannot create controls has no use for the materialized output. */ async renderTemplate( request: models.RenderControlTemplateRequest, @@ -110,6 +114,12 @@ export class Controls extends ClientSDK { * * @remarks * Return the canonical JSON schema for ControlDefinition. + * + * Intentionally has no ``require_operation`` dependency: the schema is + * static metadata derived from the model class and exposes no tenant + * state. Routing it through the auth framework would force callers + * (and the upstream authorizer) to handle a meta-only operation that + * has no permission semantics. */ async getSchema( options?: RequestOptions, @@ -126,6 +136,11 @@ export class Controls extends ClientSDK { * @remarks * Validate control configuration data without saving it. * + * Authorized as ``controls.create`` rather than ``controls.read``: + * validation exercises the full create / update materialization path + * and exists to support authoring, so a caller who cannot create + * controls has no use for the result. + * * Args: * request: Control configuration data to validate * db: Database session (injected) diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index 6208652b..a6509a2e 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -33,7 +33,7 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import require_admin_key +from ..auth_framework import Operation, Principal, require_operation from ..db import get_async_db from ..errors import ( APIValidationError, @@ -446,8 +446,14 @@ async def _validate_control_definition( async def render_control_template( request: RenderControlTemplateRequest, db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> RenderControlTemplateResponse: - """Render a template-backed control without persisting it.""" + """Render a template-backed control without persisting it. + + Authorized as ``controls.create``: rendering is part of the authoring + flow (the result feeds the create / update endpoints), so a caller + who cannot create controls has no use for the materialized output. + """ control_def = await _render_and_validate_template_input( TemplateControlInput( template=request.template, @@ -461,13 +467,14 @@ async def render_control_template( @router.put( "", - dependencies=[Depends(require_admin_key)], response_model=CreateControlResponse, summary="Create a new control", response_description="Created control ID", ) async def create_control( - request: CreateControlRequest, db: AsyncSession = Depends(get_async_db) + request: CreateControlRequest, + db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> CreateControlResponse: """ Create a new control with a unique name. @@ -550,7 +557,14 @@ async def create_control( response_description="JSON schema for ControlDefinition", ) async def get_control_schema() -> GetControlSchemaResponse: - """Return the canonical JSON schema for ControlDefinition.""" + """Return the canonical JSON schema for ControlDefinition. + + Intentionally has no ``require_operation`` dependency: the schema is + static metadata derived from the model class and exposes no tenant + state. Routing it through the auth framework would force callers + (and the upstream authorizer) to handle a meta-only operation that + has no permission semantics. + """ return GetControlSchemaResponse( schema=ControlDefinition.model_json_schema(by_alias=True) ) @@ -563,7 +577,9 @@ async def get_control_schema() -> GetControlSchemaResponse: response_description="Control metadata and configuration", ) async def get_control( - control_id: int, db: AsyncSession = Depends(get_async_db) + control_id: int, + db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlResponse: """ Retrieve a control by ID including its name and configuration data. @@ -600,7 +616,9 @@ async def get_control( response_description="Control data payload", ) async def get_control_data( - control_id: int, db: AsyncSession = Depends(get_async_db) + control_id: int, + db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlDataResponse: """ Retrieve the configuration data for a control. @@ -640,6 +658,7 @@ async def list_control_versions( ), limit: int = Query(_DEFAULT_PAGINATION_LIMIT, ge=1, le=_MAX_PAGINATION_LIMIT), db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ListControlVersionsResponse: """List control versions ordered newest-first using cursor-based pagination.""" page = await ControlService(db).list_versions(control_id, cursor=cursor, limit=limit) @@ -673,6 +692,7 @@ async def get_control_version( control_id: int, version_num: int, db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlVersionResponse: """Return a specific control version, including its raw persisted snapshot.""" version = await ControlService(db).get_version_or_404(control_id, version_num) @@ -687,7 +707,6 @@ async def get_control_version( @router.put( "/{control_id}/data", - dependencies=[Depends(require_admin_key)], response_model=SetControlDataResponse, summary="Update control configuration data", response_description="Success confirmation", @@ -696,6 +715,7 @@ async def set_control_data( control_id: int, request: SetControlDataRequest, db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), ) -> SetControlDataResponse: """ Update the configuration data for a control. @@ -758,11 +778,18 @@ async def set_control_data( response_description="Validation result", ) async def validate_control_data( - request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db) + request: ValidateControlDataRequest, + db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> ValidateControlDataResponse: """ Validate control configuration data without saving it. + Authorized as ``controls.create`` rather than ``controls.read``: + validation exercises the full create / update materialization path + and exists to support authoring, so a caller who cannot create + controls has no use for the result. + Args: request: Control configuration data to validate db: Database session (injected) @@ -798,6 +825,7 @@ async def list_controls( execution: str | None = Query(None, description="Filter by execution ('server' or 'sdk')"), tag: str | None = Query(None, description="Filter by tag"), db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ListControlsResponse: """ List all controls with optional filtering and cursor-based pagination. @@ -884,7 +912,6 @@ async def list_controls( @router.delete( "/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=DeleteControlResponse, summary="Delete a control", response_description="Deletion confirmation with dissociation info", @@ -897,6 +924,7 @@ async def delete_control( "If false, fail if control is associated with any policy or agent.", ), db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_DELETE)), ) -> DeleteControlResponse: """ Delete a control by ID. @@ -1035,7 +1063,6 @@ async def delete_control( @router.patch( "/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=PatchControlResponse, summary="Update control metadata", response_description="Updated control information", @@ -1044,6 +1071,7 @@ async def patch_control( control_id: int, request: PatchControlRequest, db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), ) -> PatchControlResponse: """ Update control metadata (name and/or enabled status). diff --git a/server/src/agent_control_server/main.py b/server/src/agent_control_server/main.py index 76416e04..718d7b04 100644 --- a/server/src/agent_control_server/main.py +++ b/server/src/agent_control_server/main.py @@ -273,9 +273,15 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- dependencies=[Depends(require_api_key)], ) app.include_router( + # ``/controls`` CRUD goes through the auth framework on each + # endpoint (``require_operation(Operation.CONTROLS_*)``); see the + # ``control_binding_router`` rationale below for the + # ``get_api_key_from_header`` mounting pattern. The single route on + # this router without ``require_operation`` is ``GET /controls/schema``, + # which is intentionally public meta — see its endpoint docstring. control_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) app.include_router( # The auth framework on each endpoint owns authentication and @@ -300,9 +306,12 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- dependencies=[Depends(get_api_key_from_header)], ) app.include_router( + # Control templates: ``/render`` is on the auth framework via + # ``require_operation(Operation.CONTROLS_CREATE)``; same mounting + # pattern as the controls and control-bindings routers. control_template_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) app.include_router( evaluation_router, diff --git a/server/tests/test_controls_auth.py b/server/tests/test_controls_auth.py new file mode 100644 index 00000000..0357421f --- /dev/null +++ b/server/tests/test_controls_auth.py @@ -0,0 +1,365 @@ +"""HTTP-level coverage for the auth seam on ``/controls`` and +``/control-templates``. + +These tests exercise the wiring of ``require_operation`` on each route +through the default ``HeaderAuthProvider``: read operations require any +valid credential (``CONTROLS_READ`` -> ``AUTHENTICATED``), write +operations require an admin credential +(``CONTROLS_CREATE`` / ``CONTROLS_UPDATE`` / ``CONTROLS_DELETE`` -> +``ADMIN``), and ``GET /controls/schema`` is intentionally outside the +framework so it stays publicly reachable. + +The provider primitives themselves are exercised in +``tests/test_auth_framework.py``; this file focuses on each endpoint +calling the right ``Operation`` so a future change to the operation +mapping is caught at the route level. +""" + +from __future__ import annotations + +import uuid + +import pytest +from fastapi.testclient import TestClient + +from agent_control_server.config import auth_settings + +from .utils import VALID_CONTROL_PAYLOAD + + +_CONTROLS_URL = "/api/v1/controls" +_TEMPLATES_URL = "/api/v1/control-templates" + + +def _create_control(client: TestClient, name: str | None = None) -> int: + payload = { + "name": name or f"control-{uuid.uuid4().hex[:12]}", + "data": VALID_CONTROL_PAYLOAD, + } + resp = client.put(_CONTROLS_URL, json=payload) + assert resp.status_code == 200, resp.text + return int(resp.json()["control_id"]) + + +# --------------------------------------------------------------------------- +# /controls/schema is intentionally public meta — no require_operation. +# --------------------------------------------------------------------------- + + +def test_schema_endpoint_reachable_without_credentials( + unauthenticated_client: TestClient, +) -> None: + # Given: a client that never sends an API key + # When: the schema endpoint is fetched + resp = unauthenticated_client.get(f"{_CONTROLS_URL}/schema") + + # Then: the canonical ControlDefinition schema is returned + assert resp.status_code == 200, resp.text + body = resp.json() + assert "schema" in body + assert isinstance(body["schema"], dict) + + +def test_schema_endpoint_reachable_with_admin_key(client: TestClient) -> None: + # Given: an admin client + # When: the schema endpoint is fetched + resp = client.get(f"{_CONTROLS_URL}/schema") + + # Then: the schema is returned (header is ignored, route is public) + assert resp.status_code == 200, resp.text + + +def test_schema_endpoint_reachable_with_non_admin_key( + non_admin_client: TestClient, +) -> None: + # Given: a non-admin client + # When: the schema endpoint is fetched + resp = non_admin_client.get(f"{_CONTROLS_URL}/schema") + + # Then: the schema is returned + assert resp.status_code == 200, resp.text + + +# --------------------------------------------------------------------------- +# CONTROLS_READ operations: AUTHENTICATED suffices. +# --------------------------------------------------------------------------- + + +def test_non_admin_can_list_controls( + non_admin_client: TestClient, client: TestClient +) -> None: + # Given: an existing control + _create_control(client) + + # When: a non-admin lists controls + resp = non_admin_client.get(_CONTROLS_URL) + + # Then: the list is returned + assert resp.status_code == 200, resp.text + + +def test_non_admin_can_get_control( + non_admin_client: TestClient, client: TestClient +) -> None: + # Given: an existing control + control_id = _create_control(client) + + # When: a non-admin reads it + resp = non_admin_client.get(f"{_CONTROLS_URL}/{control_id}") + + # Then: the control is returned + assert resp.status_code == 200, resp.text + + +def test_non_admin_can_get_control_data( + non_admin_client: TestClient, client: TestClient +) -> None: + # Given: an existing control + control_id = _create_control(client) + + # When: a non-admin reads its data + resp = non_admin_client.get(f"{_CONTROLS_URL}/{control_id}/data") + + # Then: the data is returned + assert resp.status_code == 200, resp.text + + +def test_non_admin_can_list_versions( + non_admin_client: TestClient, client: TestClient +) -> None: + # Given: an existing control with at least one version (creation) + control_id = _create_control(client) + + # When: a non-admin lists versions + resp = non_admin_client.get(f"{_CONTROLS_URL}/{control_id}/versions") + + # Then: the version list is returned + assert resp.status_code == 200, resp.text + + +def test_non_admin_can_get_specific_version( + non_admin_client: TestClient, client: TestClient +) -> None: + # Given: an existing control (version 1 = "created") + control_id = _create_control(client) + + # When: a non-admin reads version 1 + resp = non_admin_client.get(f"{_CONTROLS_URL}/{control_id}/versions/1") + + # Then: the version snapshot is returned + assert resp.status_code == 200, resp.text + + +# --------------------------------------------------------------------------- +# CONTROLS_CREATE / UPDATE / DELETE: ADMIN required. +# --------------------------------------------------------------------------- + + +def test_non_admin_cannot_create_control(non_admin_client: TestClient) -> None: + # When: a non-admin attempts to create + resp = non_admin_client.put( + _CONTROLS_URL, + json={ + "name": f"control-{uuid.uuid4().hex[:12]}", + "data": VALID_CONTROL_PAYLOAD, + }, + ) + + # Then: the request is forbidden by the auth seam + assert resp.status_code == 403, resp.text + + +def test_non_admin_cannot_set_control_data( + non_admin_client: TestClient, client: TestClient +) -> None: + # Given: an existing control + control_id = _create_control(client) + + # When: a non-admin attempts to replace its data + resp = non_admin_client.put( + f"{_CONTROLS_URL}/{control_id}/data", + json={"data": VALID_CONTROL_PAYLOAD}, + ) + + # Then: the request is forbidden + assert resp.status_code == 403, resp.text + + +def test_non_admin_cannot_patch_control( + non_admin_client: TestClient, client: TestClient +) -> None: + # Given: an existing control + control_id = _create_control(client) + + # When: a non-admin attempts to rename it + resp = non_admin_client.patch( + f"{_CONTROLS_URL}/{control_id}", + json={"name": "renamed"}, + ) + + # Then: the request is forbidden + assert resp.status_code == 403, resp.text + + +def test_non_admin_cannot_delete_control( + non_admin_client: TestClient, client: TestClient +) -> None: + # Given: an existing control + control_id = _create_control(client) + + # When: a non-admin attempts to delete it + resp = non_admin_client.delete(f"{_CONTROLS_URL}/{control_id}") + + # Then: the request is forbidden + assert resp.status_code == 403, resp.text + + +def test_non_admin_cannot_validate_control_data( + non_admin_client: TestClient, +) -> None: + """``/controls/validate`` is wired to ``CONTROLS_CREATE`` rather than + ``CONTROLS_READ`` because validation exercises the create / update + materialization path; a caller who cannot create has no use for the + result. This pins that decision so it can't drift to ``READ`` + accidentally. + """ + # When: a non-admin attempts to validate a draft payload + resp = non_admin_client.post( + f"{_CONTROLS_URL}/validate", + json={"data": VALID_CONTROL_PAYLOAD}, + ) + + # Then: validation is admin-only + assert resp.status_code == 403, resp.text + + +def test_non_admin_cannot_render_template(non_admin_client: TestClient) -> None: + """``/control-templates/render`` is wired to ``CONTROLS_CREATE`` for + the same reason as ``/validate``: rendering is part of the authoring + flow. The 422 path is not exercised here — only the auth gate is + asserted, so the request shape need not validate. + """ + # When: a non-admin attempts to render a template + resp = non_admin_client.post( + f"{_TEMPLATES_URL}/render", + json={"template": {}, "template_values": {}}, + ) + + # Then: rendering is admin-only — the auth gate fires before body + # validation reaches the materialization path + assert resp.status_code == 403, resp.text + + +# --------------------------------------------------------------------------- +# Unauthenticated requests are rejected on every framework-protected route. +# --------------------------------------------------------------------------- + + +def test_unauthenticated_cannot_list_controls( + unauthenticated_client: TestClient, +) -> None: + # When: a client without credentials lists controls + resp = unauthenticated_client.get(_CONTROLS_URL) + + # Then: the request is rejected + assert resp.status_code == 401, resp.text + + +def test_unauthenticated_cannot_create_control( + unauthenticated_client: TestClient, +) -> None: + # When: a client without credentials attempts to create + resp = unauthenticated_client.put( + _CONTROLS_URL, + json={ + "name": f"control-{uuid.uuid4().hex[:12]}", + "data": VALID_CONTROL_PAYLOAD, + }, + ) + + # Then: the request is rejected + assert resp.status_code == 401, resp.text + + +def test_unauthenticated_cannot_validate( + unauthenticated_client: TestClient, +) -> None: + # When: a client without credentials attempts to validate + resp = unauthenticated_client.post( + f"{_CONTROLS_URL}/validate", + json={"data": VALID_CONTROL_PAYLOAD}, + ) + + # Then: the request is rejected + assert resp.status_code == 401, resp.text + + +def test_unauthenticated_cannot_render_template( + unauthenticated_client: TestClient, +) -> None: + # When: a client without credentials attempts to render + resp = unauthenticated_client.post( + f"{_TEMPLATES_URL}/render", + json={"template": {}, "template_values": {}}, + ) + + # Then: the request is rejected + assert resp.status_code == 401, resp.text + + +# --------------------------------------------------------------------------- +# No-auth deployment mode: api_key_enabled=False bypasses every gate. +# --------------------------------------------------------------------------- + + +def test_no_auth_mode_allows_writes_without_credentials( + unauthenticated_client: TestClient, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """When ``api_key_enabled`` is False, the ``HeaderAuthProvider`` + short-circuits to a non-admin ``Principal`` for every operation, + including admin-level writes. This pins the "no auth" deployment + path so a future refactor can't silently start enforcing. + """ + # Given: api_key_enabled is False (single-tenant OSS dev mode) + monkeypatch.setattr(auth_settings, "api_key_enabled", False) + + # When: an unauthenticated client creates a control + resp = unauthenticated_client.put( + _CONTROLS_URL, + json={ + "name": f"control-{uuid.uuid4().hex[:12]}", + "data": VALID_CONTROL_PAYLOAD, + }, + ) + + # Then: the create succeeds because auth is disabled at the provider + assert resp.status_code == 200, resp.text + assert "control_id" in resp.json() + + +# --------------------------------------------------------------------------- +# Project-scoped API key deny — pending header forwarding follow-up. +# --------------------------------------------------------------------------- + + +@pytest.mark.skip( + reason=( + "Requires the upstream auth provider to forward an additional " + "configurable credential header. The default forward set is " + "fixed to (X-API-Key, Authorization, Cookie); deployments that " + "use a different credential header can't surface a " + "project-scoped credential to upstream until that becomes " + "configurable. Re-enable when the follow-up PR adds an " + "operator-configurable extra forward list." + ) +) +def test_project_scoped_credential_denied_on_org_scoped_controls() -> None: + """Stub for the deny-test promised by the upstream provider's + response contract: a project-scoped credential calling an + org-scoped operation (``controls.*``) should resolve to a 403 from + the upstream. The end-to-end path is unreachable today because the + provider's credential-forward list is not configurable; tracked as + the next follow-up after this PR. + """ + pytest.fail("test stub — see skip reason") From 3a5b7e47d87fb9b7946c96711bfec264c13cef90 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Wed, 6 May 2026 21:19:03 +0530 Subject: [PATCH 04/42] fix(server): keep public docstrings API-level on migrated controls routes Move auth-framework rationale on /controls/schema, /controls/validate, and /control-templates/render from route docstrings into normal code comments. The docstrings flow into the generated TypeScript SDK as public API documentation, so internal terminology like ``require_operation`` and "upstream authorizer" should not appear there. Function-level comments preserve the rationale for readers of the source. Also remove the skipped placeholder test for the project-scoped credential deny scenario; that scenario depends on a deployment-side provider configuration that is not part of the OSS server, so tracking it as a permanent skipped test in this repo was the wrong home for it. Regenerate the TypeScript SDK to drop the leaked rationale lines. --- .../generated/funcs/controls-get-schema.ts | 6 -- .../funcs/controls-render-template.ts | 4 -- .../generated/funcs/controls-validate-data.ts | 5 -- sdks/typescript/src/generated/sdk/controls.ts | 15 ----- .../endpoints/controls.py | 24 ++----- server/src/agent_control_server/main.py | 11 +--- server/tests/test_controls_auth.py | 63 ++----------------- 7 files changed, 13 insertions(+), 115 deletions(-) diff --git a/sdks/typescript/src/generated/funcs/controls-get-schema.ts b/sdks/typescript/src/generated/funcs/controls-get-schema.ts index a6ea27cd..ca5442bd 100644 --- a/sdks/typescript/src/generated/funcs/controls-get-schema.ts +++ b/sdks/typescript/src/generated/funcs/controls-get-schema.ts @@ -27,12 +27,6 @@ import { Result } from "../types/fp.js"; * * @remarks * Return the canonical JSON schema for ControlDefinition. - * - * Intentionally has no ``require_operation`` dependency: the schema is - * static metadata derived from the model class and exposes no tenant - * state. Routing it through the auth framework would force callers - * (and the upstream authorizer) to handle a meta-only operation that - * has no permission semantics. */ export function controlsGetSchema( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/controls-render-template.ts b/sdks/typescript/src/generated/funcs/controls-render-template.ts index 6f5d4d0e..a8998d0e 100644 --- a/sdks/typescript/src/generated/funcs/controls-render-template.ts +++ b/sdks/typescript/src/generated/funcs/controls-render-template.ts @@ -31,10 +31,6 @@ import { Result } from "../types/fp.js"; * * @remarks * Render a template-backed control without persisting it. - * - * Authorized as ``controls.create``: rendering is part of the authoring - * flow (the result feeds the create / update endpoints), so a caller - * who cannot create controls has no use for the materialized output. */ export function controlsRenderTemplate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/controls-validate-data.ts b/sdks/typescript/src/generated/funcs/controls-validate-data.ts index f1084887..70d9a1f0 100644 --- a/sdks/typescript/src/generated/funcs/controls-validate-data.ts +++ b/sdks/typescript/src/generated/funcs/controls-validate-data.ts @@ -32,11 +32,6 @@ import { Result } from "../types/fp.js"; * @remarks * Validate control configuration data without saving it. * - * Authorized as ``controls.create`` rather than ``controls.read``: - * validation exercises the full create / update materialization path - * and exists to support authoring, so a caller who cannot create - * controls has no use for the result. - * * Args: * request: Control configuration data to validate * db: Database session (injected) diff --git a/sdks/typescript/src/generated/sdk/controls.ts b/sdks/typescript/src/generated/sdk/controls.ts index 28edbda3..ed3cf8db 100644 --- a/sdks/typescript/src/generated/sdk/controls.ts +++ b/sdks/typescript/src/generated/sdk/controls.ts @@ -25,10 +25,6 @@ export class Controls extends ClientSDK { * * @remarks * Render a template-backed control without persisting it. - * - * Authorized as ``controls.create``: rendering is part of the authoring - * flow (the result feeds the create / update endpoints), so a caller - * who cannot create controls has no use for the materialized output. */ async renderTemplate( request: models.RenderControlTemplateRequest, @@ -114,12 +110,6 @@ export class Controls extends ClientSDK { * * @remarks * Return the canonical JSON schema for ControlDefinition. - * - * Intentionally has no ``require_operation`` dependency: the schema is - * static metadata derived from the model class and exposes no tenant - * state. Routing it through the auth framework would force callers - * (and the upstream authorizer) to handle a meta-only operation that - * has no permission semantics. */ async getSchema( options?: RequestOptions, @@ -136,11 +126,6 @@ export class Controls extends ClientSDK { * @remarks * Validate control configuration data without saving it. * - * Authorized as ``controls.create`` rather than ``controls.read``: - * validation exercises the full create / update materialization path - * and exists to support authoring, so a caller who cannot create - * controls has no use for the result. - * * Args: * request: Control configuration data to validate * db: Database session (injected) diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index a6509a2e..fcb7cb18 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -443,17 +443,13 @@ async def _validate_control_definition( summary="Render a control template preview", response_description="Rendered control preview", ) +# Rendering is part of the authoring flow, so require create access. async def render_control_template( request: RenderControlTemplateRequest, db: AsyncSession = Depends(get_async_db), _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> RenderControlTemplateResponse: - """Render a template-backed control without persisting it. - - Authorized as ``controls.create``: rendering is part of the authoring - flow (the result feeds the create / update endpoints), so a caller - who cannot create controls has no use for the materialized output. - """ + """Render a template-backed control without persisting it.""" control_def = await _render_and_validate_template_input( TemplateControlInput( template=request.template, @@ -556,15 +552,9 @@ async def create_control( summary="Get control definition JSON schema", response_description="JSON schema for ControlDefinition", ) +# Public schema metadata: no tenant state, no auth operation. async def get_control_schema() -> GetControlSchemaResponse: - """Return the canonical JSON schema for ControlDefinition. - - Intentionally has no ``require_operation`` dependency: the schema is - static metadata derived from the model class and exposes no tenant - state. Routing it through the auth framework would force callers - (and the upstream authorizer) to handle a meta-only operation that - has no permission semantics. - """ + """Return the canonical JSON schema for ControlDefinition.""" return GetControlSchemaResponse( schema=ControlDefinition.model_json_schema(by_alias=True) ) @@ -777,6 +767,7 @@ async def set_control_data( summary="Validate control configuration", response_description="Validation result", ) +# Validation uses the authoring path, so require create access. async def validate_control_data( request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db), @@ -785,11 +776,6 @@ async def validate_control_data( """ Validate control configuration data without saving it. - Authorized as ``controls.create`` rather than ``controls.read``: - validation exercises the full create / update materialization path - and exists to support authoring, so a caller who cannot create - controls has no use for the result. - Args: request: Control configuration data to validate db: Database session (injected) diff --git a/server/src/agent_control_server/main.py b/server/src/agent_control_server/main.py index 718d7b04..bc1bf04b 100644 --- a/server/src/agent_control_server/main.py +++ b/server/src/agent_control_server/main.py @@ -273,12 +273,7 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- dependencies=[Depends(require_api_key)], ) app.include_router( - # ``/controls`` CRUD goes through the auth framework on each - # endpoint (``require_operation(Operation.CONTROLS_*)``); see the - # ``control_binding_router`` rationale below for the - # ``get_api_key_from_header`` mounting pattern. The single route on - # this router without ``require_operation`` is ``GET /controls/schema``, - # which is intentionally public meta — see its endpoint docstring. + # Endpoint dependencies handle auth; this advertises X-API-Key. control_router, prefix=api_v1_prefix, dependencies=[Depends(get_api_key_from_header)], @@ -306,9 +301,7 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- dependencies=[Depends(get_api_key_from_header)], ) app.include_router( - # Control templates: ``/render`` is on the auth framework via - # ``require_operation(Operation.CONTROLS_CREATE)``; same mounting - # pattern as the controls and control-bindings routers. + # Endpoint dependencies handle auth; this advertises X-API-Key. control_template_router, prefix=api_v1_prefix, dependencies=[Depends(get_api_key_from_header)], diff --git a/server/tests/test_controls_auth.py b/server/tests/test_controls_auth.py index 0357421f..1a2af21f 100644 --- a/server/tests/test_controls_auth.py +++ b/server/tests/test_controls_auth.py @@ -1,19 +1,4 @@ -"""HTTP-level coverage for the auth seam on ``/controls`` and -``/control-templates``. - -These tests exercise the wiring of ``require_operation`` on each route -through the default ``HeaderAuthProvider``: read operations require any -valid credential (``CONTROLS_READ`` -> ``AUTHENTICATED``), write -operations require an admin credential -(``CONTROLS_CREATE`` / ``CONTROLS_UPDATE`` / ``CONTROLS_DELETE`` -> -``ADMIN``), and ``GET /controls/schema`` is intentionally outside the -framework so it stays publicly reachable. - -The provider primitives themselves are exercised in -``tests/test_auth_framework.py``; this file focuses on each endpoint -calling the right ``Operation`` so a future change to the operation -mapping is caught at the route level. -""" +"""HTTP-level auth coverage for ``/controls`` and ``/control-templates``.""" from __future__ import annotations @@ -42,7 +27,7 @@ def _create_control(client: TestClient, name: str | None = None) -> int: # --------------------------------------------------------------------------- -# /controls/schema is intentionally public meta — no require_operation. +# /controls/schema is intentionally public metadata. # --------------------------------------------------------------------------- @@ -165,7 +150,7 @@ def test_non_admin_cannot_create_control(non_admin_client: TestClient) -> None: }, ) - # Then: the request is forbidden by the auth seam + # Then: the request is forbidden assert resp.status_code == 403, resp.text @@ -217,12 +202,7 @@ def test_non_admin_cannot_delete_control( def test_non_admin_cannot_validate_control_data( non_admin_client: TestClient, ) -> None: - """``/controls/validate`` is wired to ``CONTROLS_CREATE`` rather than - ``CONTROLS_READ`` because validation exercises the create / update - materialization path; a caller who cannot create has no use for the - result. This pins that decision so it can't drift to ``READ`` - accidentally. - """ + """``/controls/validate`` requires ``CONTROLS_CREATE``.""" # When: a non-admin attempts to validate a draft payload resp = non_admin_client.post( f"{_CONTROLS_URL}/validate", @@ -234,19 +214,14 @@ def test_non_admin_cannot_validate_control_data( def test_non_admin_cannot_render_template(non_admin_client: TestClient) -> None: - """``/control-templates/render`` is wired to ``CONTROLS_CREATE`` for - the same reason as ``/validate``: rendering is part of the authoring - flow. The 422 path is not exercised here — only the auth gate is - asserted, so the request shape need not validate. - """ + """``/control-templates/render`` requires ``CONTROLS_CREATE``.""" # When: a non-admin attempts to render a template resp = non_admin_client.post( f"{_TEMPLATES_URL}/render", json={"template": {}, "template_values": {}}, ) - # Then: rendering is admin-only — the auth gate fires before body - # validation reaches the materialization path + # Then: rendering is admin-only assert resp.status_code == 403, resp.text @@ -337,29 +312,3 @@ def test_no_auth_mode_allows_writes_without_credentials( assert resp.status_code == 200, resp.text assert "control_id" in resp.json() - -# --------------------------------------------------------------------------- -# Project-scoped API key deny — pending header forwarding follow-up. -# --------------------------------------------------------------------------- - - -@pytest.mark.skip( - reason=( - "Requires the upstream auth provider to forward an additional " - "configurable credential header. The default forward set is " - "fixed to (X-API-Key, Authorization, Cookie); deployments that " - "use a different credential header can't surface a " - "project-scoped credential to upstream until that becomes " - "configurable. Re-enable when the follow-up PR adds an " - "operator-configurable extra forward list." - ) -) -def test_project_scoped_credential_denied_on_org_scoped_controls() -> None: - """Stub for the deny-test promised by the upstream provider's - response contract: a project-scoped credential calling an - org-scoped operation (``controls.*``) should resolve to a 403 from - the upstream. The end-to-end path is unreachable today because the - provider's credential-forward list is not configurable; tracked as - the next follow-up after this PR. - """ - pytest.fail("test stub — see skip reason") From 8312b99bd5bd8d242e385924641659f9d0ef2dbc Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 8 May 2026 22:23:50 +0530 Subject: [PATCH 05/42] docs(server): keep binding auth comments generic --- .../endpoints/control_bindings.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/server/src/agent_control_server/endpoints/control_bindings.py b/server/src/agent_control_server/endpoints/control_bindings.py index 18cb75b4..92798ae1 100644 --- a/server/src/agent_control_server/endpoints/control_bindings.py +++ b/server/src/agent_control_server/endpoints/control_bindings.py @@ -36,13 +36,11 @@ async def _binding_body_context(request: Request) -> dict[str, Any]: - """Surface ``(target_type, target_id)`` to the authorizer's context. + """Surface ``(target_type, target_id)`` to the authorization context. The body-bearing binding endpoints carry the target identifiers in - the request payload. Upstream authorizers that resolve the target's - owning project (e.g., Galileo's ``check_management_access``) need - those identifiers to make a project-level decision; without them the - upstream returns 400. + the request payload. Authorization providers can use those + identifiers when a request needs target-scoped access checks. FastAPI caches the parsed body, so the endpoint's own Pydantic request model still binds normally. @@ -60,13 +58,12 @@ async def _binding_body_context(request: Request) -> dict[str, Any]: async def _binding_list_context(request: Request) -> dict[str, Any]: - """Surface optional target query parameters to the authorizer. + """Surface optional target query parameters to authorization context. When the GET list endpoint is called with ``target_type`` and ``target_id`` query params, the request is target-scoped and the - upstream needs the identifiers to make a project-level decision. - When neither is present the request is namespace-wide and forwards - no target context (upstream may then reject if it requires one). + request context includes those identifiers. When neither is present + the request is namespace-wide and forwards no target context. """ target_type = request.query_params.get("target_type") target_id = request.query_params.get("target_id") From ab9a3f657bcfb6196fdb6eff5744300df54c4998 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Thu, 7 May 2026 20:14:39 +0530 Subject: [PATCH 06/42] feat(server): add runtime auth namespace cutover Add explicit none, api_key, and jwt runtime auth modes, including a generic no-auth provider. Move controls, bindings, policies, agents, and evaluation storage lookups onto principal namespace scoping. Cover auth mode selection and principal namespace isolation with server tests. --- .../auth_framework/__init__.py | 7 +- .../auth_framework/config.py | 120 +++++++++++---- .../auth_framework/core.py | 16 +- .../auth_framework/providers/__init__.py | 2 + .../auth_framework/providers/header.py | 39 +++-- .../auth_framework/providers/http_upstream.py | 6 +- .../auth_framework/providers/local_jwt.py | 2 +- .../auth_framework/providers/no_auth.py | 29 ++++ .../agent_control_server/endpoints/agents.py | 92 +++++++----- .../agent_control_server/endpoints/auth.py | 11 +- .../endpoints/control_bindings.py | 47 +++--- .../endpoints/controls.py | 87 +++++++---- .../endpoints/evaluation.py | 26 +++- .../endpoints/policies.py | 77 +++++++--- server/src/agent_control_server/main.py | 19 ++- .../agent_control_server/services/controls.py | 140 +++++++++++++---- server/tests/test_auth_framework.py | 96 +++++++++++- server/tests/test_controls_additional.py | 15 +- server/tests/test_controls_auth.py | 28 ++-- server/tests/test_principal_namespace_flow.py | 141 ++++++++++++++++++ server/tests/test_target_merged_contract.py | 6 +- 21 files changed, 753 insertions(+), 253 deletions(-) create mode 100644 server/src/agent_control_server/auth_framework/providers/no_auth.py create mode 100644 server/tests/test_principal_namespace_flow.py diff --git a/server/src/agent_control_server/auth_framework/__init__.py b/server/src/agent_control_server/auth_framework/__init__.py index 57368d57..0333f2cc 100644 --- a/server/src/agent_control_server/auth_framework/__init__.py +++ b/server/src/agent_control_server/auth_framework/__init__.py @@ -2,10 +2,9 @@ Endpoints declare an :class:`Operation` they need; an installed :class:`RequestAuthorizer` decides whether the request is allowed and -returns the resulting :class:`Principal`. Two providers ship in-tree: -:class:`HeaderAuthProvider` (uses local credential checks) and -:class:`HttpUpstreamAuthProvider` (delegates to a configurable -upstream HTTP service). +returns the resulting :class:`Principal`. Providers ship in-tree for +disabled auth, local credential checks, upstream HTTP authorization, +and local runtime-JWT verification. """ from .core import ( diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index 92107b0e..c8f428dc 100644 --- a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -8,15 +8,19 @@ - **Default flow** (everything except runtime). One authorizer handles every operation that does not have a specific override: - :class:`HeaderAuthProvider` (local credentials) or + :class:`NoAuthProvider` (no credentials), + :class:`HeaderAuthProvider` (local API keys), or :class:`HttpUpstreamAuthProvider` (forwards to a configurable URL). -- **Runtime flow.** When ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is - configured, :class:`LocalJwtVerifyProvider` is registered as the - override for :data:`Operation.RUNTIME_USE`; the - ``runtime.token_exchange`` operation continues to flow through the - default authorizer because the exchange itself is shaped like a - management call (forward credential, get grant). Without the secret, - no runtime override is installed. +- **Runtime flow.** ``AGENT_CONTROL_RUNTIME_AUTH_MODE`` selects the + override for :data:`Operation.RUNTIME_USE`: ``none`` uses + :class:`NoAuthProvider`, ``api_key`` uses + :class:`HeaderAuthProvider`, and ``jwt`` uses + :class:`LocalJwtVerifyProvider`. When the mode is unset, startup + preserves historical behavior by selecting ``jwt`` if + ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set, otherwise ``api_key``. + The ``runtime.token_exchange`` operation continues to flow through + the default authorizer because the exchange itself is shaped like a + management call (forward credential, get grant). """ from __future__ import annotations @@ -30,6 +34,7 @@ HeaderAuthProvider, HttpUpstreamAuthProvider, LocalJwtVerifyProvider, + NoAuthProvider, ) from .providers.http_upstream import HttpUpstreamConfig @@ -43,6 +48,7 @@ _UPSTREAM_TOKEN_HEADER_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN_HEADER" # Runtime flow. +_RUNTIME_MODE_ENV = "AGENT_CONTROL_RUNTIME_AUTH_MODE" _RUNTIME_TOKEN_SECRET_ENV = "AGENT_CONTROL_RUNTIME_TOKEN_SECRET" _RUNTIME_TOKEN_TTL_ENV = "AGENT_CONTROL_RUNTIME_TOKEN_TTL_SECONDS" _DEFAULT_RUNTIME_TOKEN_TTL_SECONDS = 300 @@ -80,15 +86,19 @@ def configure_auth_from_env() -> None: Default flow: - - ``AGENT_CONTROL_AUTH_MODE=header`` (default): :class:`HeaderAuthProvider`. + - ``AGENT_CONTROL_AUTH_MODE=none``: :class:`NoAuthProvider`. + - ``AGENT_CONTROL_AUTH_MODE=api_key`` (default): :class:`HeaderAuthProvider`. + ``header`` remains accepted as a backwards-compatible alias. - ``AGENT_CONTROL_AUTH_MODE=http_upstream``: :class:`HttpUpstreamAuthProvider` pointed at ``AGENT_CONTROL_AUTH_UPSTREAM_URL``. Runtime flow: - - When ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set, register - :class:`LocalJwtVerifyProvider` as an override for - :data:`Operation.RUNTIME_USE`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=none``: :class:`NoAuthProvider`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=api_key`` (default when no runtime + token secret is configured): :class:`HeaderAuthProvider`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=jwt`` (default when a runtime token + secret is configured): :class:`LocalJwtVerifyProvider`. Clears any previously-installed default and operation overrides before installing fresh ones, so reconfiguration cannot leave @@ -101,27 +111,27 @@ def configure_auth_from_env() -> None: global _runtime_auth_config clear_authorizers() _active_providers.clear() - _runtime_auth_config = _load_runtime_auth_config() + runtime_mode = _resolve_runtime_mode() + _runtime_auth_config = ( + _load_runtime_auth_config(require_secret=True) if runtime_mode == "jwt" else None + ) default = _build_default_provider() set_authorizer(default) _active_providers.append(default) - if _runtime_auth_config is not None: - runtime_provider = LocalJwtVerifyProvider(secret=_runtime_auth_config.secret) - set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) - _active_providers.append(runtime_provider) + runtime_provider = _build_runtime_provider(runtime_mode, _runtime_auth_config) + set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) + _active_providers.append(runtime_provider) + if runtime_mode == "jwt": _logger.info( - "Runtime auth enabled: LocalJwtVerifyProvider override installed for %s", + "Runtime auth provider: jwt override installed for %s", Operation.RUNTIME_USE.value, ) else: - _logger.warning( - "Runtime auth disabled (%s not set); %s falls through to the " - "default authorizer, which may grant any authenticated credential. " - "Set the runtime token secret to bind runtime calls to a " - "short-lived target-scoped JWT.", - _RUNTIME_TOKEN_SECRET_ENV, + _logger.info( + "Runtime auth provider: %s override installed for %s", + runtime_mode, Operation.RUNTIME_USE.value, ) @@ -172,9 +182,12 @@ def set_runtime_auth_config(config: RuntimeAuthConfig | None) -> None: def _build_default_provider() -> RequestAuthorizer: - mode = os.environ.get(_MODE_ENV, "header").strip().lower() - if mode == "header": - _logger.info("Default auth provider: header (local credentials)") + mode = os.environ.get(_MODE_ENV, "api_key").strip().lower() + if mode in {"none", "no_auth"}: + _logger.info("Default auth provider: none") + return NoAuthProvider() + if mode in {"api_key", "header"}: + _logger.info("Default auth provider: api_key (local credentials)") return HeaderAuthProvider() if mode == "http_upstream": url = os.environ.get(_UPSTREAM_URL_ENV) @@ -192,19 +205,60 @@ def _build_default_provider() -> RequestAuthorizer: service_token_header=token_header, ) ) - raise RuntimeError(f"Unknown {_MODE_ENV}={mode!r}; expected 'header' or 'http_upstream'.") + raise RuntimeError( + f"Unknown {_MODE_ENV}={mode!r}; expected 'none', 'api_key', or 'http_upstream'." + ) + + +def _resolve_runtime_mode() -> str: + raw = os.environ.get(_RUNTIME_MODE_ENV) + if raw is None or not raw.strip(): + return "jwt" if os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) else "api_key" + + mode = raw.strip().lower() + if mode in {"none", "no_auth"}: + return "none" + if mode in {"api_key", "header"}: + return "api_key" + if mode == "jwt": + return mode + raise RuntimeError( + f"Unknown {_RUNTIME_MODE_ENV}={mode!r}; expected 'none', 'api_key', or 'jwt'." + ) + + +def _build_runtime_provider( + mode: str, + config: RuntimeAuthConfig | None, +) -> RequestAuthorizer: + if mode == "none": + return NoAuthProvider() + if mode == "api_key": + return HeaderAuthProvider() + if mode == "jwt": + if config is None: + raise RuntimeError(f"{_RUNTIME_MODE_ENV}=jwt but runtime auth config is missing.") + return LocalJwtVerifyProvider(secret=config.secret) + raise RuntimeError( + f"Unknown runtime auth mode {mode!r}; expected 'none', 'api_key', or 'jwt'." + ) -def _load_runtime_auth_config() -> RuntimeAuthConfig | None: +def _load_runtime_auth_config(*, require_secret: bool = False) -> RuntimeAuthConfig | None: """Parse, validate, and return the runtime-auth config from env. - Returns ``None`` when no runtime secret is configured. Raises - ``RuntimeError`` when the secret is too short or the TTL is invalid - so misconfiguration surfaces at startup, not on the first - request-time mint. + Returns ``None`` when no runtime secret is configured and + ``require_secret`` is false. Raises ``RuntimeError`` when the + secret is required, too short, or the TTL is invalid so + misconfiguration surfaces at startup, not on the first request-time + mint. """ secret = os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) if not secret: + if require_secret: + raise RuntimeError( + f"{_RUNTIME_MODE_ENV}=jwt requires {_RUNTIME_TOKEN_SECRET_ENV} to be set." + ) return None if len(secret.encode("utf-8")) < _RUNTIME_TOKEN_SECRET_MIN_BYTES: raise RuntimeError( diff --git a/server/src/agent_control_server/auth_framework/core.py b/server/src/agent_control_server/auth_framework/core.py index 9299b441..e0ea6da7 100644 --- a/server/src/agent_control_server/auth_framework/core.py +++ b/server/src/agent_control_server/auth_framework/core.py @@ -42,14 +42,21 @@ class Operation(StrEnum): CONTROL_BINDINGS_READ = "control_bindings.read" CONTROL_BINDINGS_WRITE = "control_bindings.write" - # Runtime token exchange — wired on the exchange endpoint. + # Runtime token exchange - wired on the exchange endpoint. RUNTIME_TOKEN_EXCHANGE = "runtime.token_exchange" - # Reserved for follow-up migrations; not yet wired on endpoints. CONTROLS_READ = "controls.read" CONTROLS_CREATE = "controls.create" CONTROLS_UPDATE = "controls.update" CONTROLS_DELETE = "controls.delete" + POLICIES_READ = "policies.read" + POLICIES_CREATE = "policies.create" + POLICIES_UPDATE = "policies.update" + POLICIES_DELETE = "policies.delete" + AGENTS_READ = "agents.read" + AGENTS_CREATE = "agents.create" + AGENTS_UPDATE = "agents.update" + AGENTS_DELETE = "agents.delete" RUNTIME_USE = "runtime.use" @@ -61,8 +68,7 @@ class Principal: namespace_key: The namespace the request runs in. Endpoints use this to scope every read and write. is_admin: Whether the caller has admin privileges in the - current namespace. Mostly informational for endpoints that - still gate on the legacy admin-key contract. + current namespace. caller_id: Opaque, provider-supplied identifier for the caller (e.g., a key fingerprint or user id). Useful for audit logging; never echo back to clients. @@ -122,7 +128,7 @@ def set_authorizer( Without ``operation``, this becomes the default authorizer used by every operation that does not have a specific override. With - ``operation``, it overrides the default for that operation only — + ``operation``, it overrides the default for that operation only - used to route a different family (e.g., runtime) through a different provider. diff --git a/server/src/agent_control_server/auth_framework/providers/__init__.py b/server/src/agent_control_server/auth_framework/providers/__init__.py index e8a68486..ad5d6b38 100644 --- a/server/src/agent_control_server/auth_framework/providers/__init__.py +++ b/server/src/agent_control_server/auth_framework/providers/__init__.py @@ -3,10 +3,12 @@ from .header import AccessLevel, HeaderAuthProvider from .http_upstream import HttpUpstreamAuthProvider from .local_jwt import LocalJwtVerifyProvider +from .no_auth import NoAuthProvider __all__ = [ "AccessLevel", "HeaderAuthProvider", "HttpUpstreamAuthProvider", "LocalJwtVerifyProvider", + "NoAuthProvider", ] diff --git a/server/src/agent_control_server/auth_framework/providers/header.py b/server/src/agent_control_server/auth_framework/providers/header.py index f76936a1..228ec443 100644 --- a/server/src/agent_control_server/auth_framework/providers/header.py +++ b/server/src/agent_control_server/auth_framework/providers/header.py @@ -1,23 +1,14 @@ """Default :class:`RequestAuthorizer` that uses local credentials only. -Resolves the namespace from a header (or falls back to -``DEFAULT_NAMESPACE_KEY``) and enforces a per-operation access level -using the legacy API-key + session-cookie credential check from -:mod:`agent_control_server.auth`. Behavior matches the pre-framework -local auth path verbatim: +Returns ``DEFAULT_NAMESPACE_KEY`` and enforces a per-operation access +level using the local API-key + session-cookie credential check from +:mod:`agent_control_server.auth`: - ``ADMIN`` operations require an admin key (or admin session). - ``AUTHENTICATED`` operations require any valid credential. - ``PUBLIC`` operations are open. -- When ``api_key_enabled`` is ``False`` (no-auth mode), every - operation succeeds with a non-admin :class:`Principal` — preserved - by the underlying credential check. - -The header lookup is wired but currently inert: the provider always -returns the default namespace because non-binding write endpoints -still hardcode it. The header is kept here so a follow-up that -threads namespace resolution through the rest of the API can flip it -on without changing the provider contract. +- When the underlying local credential layer is disabled, every + operation succeeds with a non-admin :class:`Principal`. """ from __future__ import annotations @@ -51,6 +42,14 @@ class AccessLevel(Enum): Operation.CONTROLS_CREATE: AccessLevel.ADMIN, Operation.CONTROLS_UPDATE: AccessLevel.ADMIN, Operation.CONTROLS_DELETE: AccessLevel.ADMIN, + Operation.POLICIES_READ: AccessLevel.AUTHENTICATED, + Operation.POLICIES_CREATE: AccessLevel.ADMIN, + Operation.POLICIES_UPDATE: AccessLevel.ADMIN, + Operation.POLICIES_DELETE: AccessLevel.ADMIN, + Operation.AGENTS_READ: AccessLevel.AUTHENTICATED, + Operation.AGENTS_CREATE: AccessLevel.AUTHENTICATED, + Operation.AGENTS_UPDATE: AccessLevel.ADMIN, + Operation.AGENTS_DELETE: AccessLevel.ADMIN, Operation.RUNTIME_TOKEN_EXCHANGE: AccessLevel.AUTHENTICATED, Operation.RUNTIME_USE: AccessLevel.AUTHENTICATED, } @@ -60,7 +59,7 @@ class HeaderAuthProvider(RequestAuthorizer): """Default authorizer. For each operation's configured access level, validates the - request's credentials via the legacy local check; on success, + request's credentials via the local credential check; on success, returns a :class:`Principal` scoped to the resolved namespace. """ @@ -100,8 +99,7 @@ async def authorize( ) # Runtime token exchange returns a normalized scope grant so the # exchange endpoint can require ``runtime.use`` uniformly across - # providers; an upstream that explicitly grants no scopes ends - # up with an empty tuple and is rejected. + # providers. scopes: tuple[str, ...] = ( (Operation.RUNTIME_USE.value,) if operation is Operation.RUNTIME_TOKEN_EXCHANGE else () ) @@ -113,10 +111,7 @@ async def authorize( ) def _resolve_namespace_key(self, request: Request) -> str: - # The provider always returns the default namespace because - # non-binding write endpoints still hardcode it; serving - # anything else here would create rows the rest of the API - # cannot find. The branch is preserved so a future change can - # lift the lock without touching the provider contract. + # Local credentials do not carry namespace metadata. Providers + # that resolve a namespace can return a different principal. del request return self._default_namespace_key diff --git a/server/src/agent_control_server/auth_framework/providers/http_upstream.py b/server/src/agent_control_server/auth_framework/providers/http_upstream.py index a97a3de8..8d5c850c 100644 --- a/server/src/agent_control_server/auth_framework/providers/http_upstream.py +++ b/server/src/agent_control_server/auth_framework/providers/http_upstream.py @@ -67,8 +67,8 @@ class _UpstreamGrant(BaseModel): """Strict schema for the upstream authorization-service response. Unknown fields are tolerated (so the upstream can evolve), but every - *known* field is type-checked. A wrong type on any field — or a - half-supplied target binding — causes the provider to fail closed + *known* field is type-checked. A wrong type on any field - or a + half-supplied target binding - causes the provider to fail closed with a 502. """ @@ -108,7 +108,7 @@ def _target_must_be_paired(self) -> _UpstreamGrant: A target is meaningful only as a ``(target_type, target_id)`` pair; allowing one side without the other would let a malformed grant pass and the exchange endpoint mint a token for the - request's value of the missing half — outside the upstream's + request's value of the missing half - outside the upstream's intended authorization. """ if (self.target_type is None) != (self.target_id is None): diff --git a/server/src/agent_control_server/auth_framework/providers/local_jwt.py b/server/src/agent_control_server/auth_framework/providers/local_jwt.py index bb448503..8620d3b6 100644 --- a/server/src/agent_control_server/auth_framework/providers/local_jwt.py +++ b/server/src/agent_control_server/auth_framework/providers/local_jwt.py @@ -6,7 +6,7 @@ returns a :class:`Principal` carrying the bound target. When a ``context_builder`` on the dependency surfaces ``target_type`` / ``target_id``, the provider also enforces that they match the token's -binding — runtime endpoints get the request-target check for free. +binding - runtime endpoints get the request-target check for free. """ from __future__ import annotations diff --git a/server/src/agent_control_server/auth_framework/providers/no_auth.py b/server/src/agent_control_server/auth_framework/providers/no_auth.py new file mode 100644 index 00000000..509ca4f3 --- /dev/null +++ b/server/src/agent_control_server/auth_framework/providers/no_auth.py @@ -0,0 +1,29 @@ +"""Authorizer for deployments that intentionally disable authentication.""" + +from __future__ import annotations + +from typing import Any + +from fastapi import Request + +from ...models import DEFAULT_NAMESPACE_KEY +from ..core import Operation, Principal, RequestAuthorizer + + +class NoAuthProvider(RequestAuthorizer): + """Allows every operation and returns the default namespace.""" + + def __init__(self, *, default_namespace_key: str = DEFAULT_NAMESPACE_KEY) -> None: + self._default_namespace_key = default_namespace_key + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del request, context + scopes: tuple[str, ...] = ( + (Operation.RUNTIME_USE.value,) if operation is Operation.RUNTIME_TOKEN_EXCHANGE else () + ) + return Principal(namespace_key=self._default_namespace_key, scopes=scopes) diff --git a/server/src/agent_control_server/endpoints/agents.py b/server/src/agent_control_server/endpoints/agents.py index 034ae35f..ac099911 100644 --- a/server/src/agent_control_server/endpoints/agents.py +++ b/server/src/agent_control_server/endpoints/agents.py @@ -36,7 +36,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import RequireAPIKey, require_admin_key +from ..auth_framework import Operation, Principal, require_operation from ..db import get_async_db from ..errors import ( APIValidationError, @@ -53,7 +53,6 @@ Policy, agent_policies, ) -from ..namespace import get_namespace_key from ..services.agent_names import normalize_agent_name_or_422 from ..services.controls import ( AgentControlEnabledState, @@ -112,7 +111,7 @@ def _validate_controls_for_agent(agent: Agent, controls: list[Control]) -> list[ agent_evaluators = {e.name: e for e in (agent_data.evaluators or [])} for control in controls: - # Skip unrendered template controls — they have no evaluators to validate. + # Skip unrendered template controls - they have no evaluators to validate. if ( isinstance(control.data, dict) and control.data.get("template") is not None @@ -286,7 +285,7 @@ async def list_agents( limit: int = _DEFAULT_PAGINATION_LIMIT, name: str | None = None, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> ListAgentsResponse: """ List all registered agents with cursor-based pagination. @@ -300,11 +299,13 @@ async def list_agents( limit: Pagination limit (default 20, max 100) name: Optional name filter (case-insensitive partial match) db: Database session (injected) - namespace_key: Resolved namespace for the request + principal: Authorized request principal Returns: ListAgentsResponse with agent summaries and pagination info """ + namespace_key = principal.namespace_key + # Clamp limit limit = min(max(1, limit), _MAX_PAGINATION_LIMIT) @@ -377,14 +378,20 @@ async def list_agents( agent_policies.c.agent_name, agent_policies.c.policy_id, ) - .where(agent_policies.c.agent_name.in_(agent_names)) + .where( + agent_policies.c.namespace_key == namespace_key, + agent_policies.c.agent_name.in_(agent_names), + ) .order_by(agent_policies.c.agent_name, agent_policies.c.policy_id) ) policy_ids_result = await db.execute(policy_ids_query) for assoc_agent_name, policy_id in policy_ids_result.all(): policy_ids_map.setdefault(assoc_agent_name, []).append(policy_id) - control_counts_map = await control_service.list_active_control_counts_by_agent(agent_names) + control_counts_map = await control_service.list_active_control_counts_by_agent( + agent_names, + namespace_key=namespace_key, + ) # Build summaries summaries: list[AgentSummary] = [] @@ -436,9 +443,8 @@ async def list_agents( ) async def init_agent( request: InitAgentRequest, - client: RequireAPIKey, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_CREATE)), ) -> InitAgentResponse: """ Register a new agent or update an existing agent's steps and metadata. @@ -462,10 +468,13 @@ async def init_agent( Args: request: Agent metadata and step schemas db: Database session (injected) + principal: Authorized request principal Returns: InitAgentResponse with created flag and the effective controls """ + namespace_key = principal.namespace_key + # Check for evaluator name collisions with built-in evaluators builtin_names = _get_builtin_evaluator_names() for ev in request.evaluators: @@ -835,7 +844,7 @@ async def init_agent( async def get_agent( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> GetAgentResponse: """ Retrieve agent metadata and all registered steps. @@ -845,8 +854,7 @@ async def get_agent( Args: agent_name: Agent identifier db: Database session (injected) - namespace_key: Resolved namespace; agents in another namespace - return 404 (non-disclosing). + principal: Authorized request principal Returns: GetAgentResponse with agent metadata and step list @@ -855,6 +863,7 @@ async def get_agent( HTTPException 404: Agent not found HTTPException 422: Agent data is corrupted """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) result = await db.execute( select(Agent).where(Agent.name == agent_name, Agent.namespace_key == namespace_key) @@ -917,7 +926,7 @@ async def _get_agent_or_404( The lookup is always namespace-scoped: an agent that exists only in another namespace surfaces as 404 (non-disclosing) so duplicate - names across namespaces — which the schema explicitly permits — + names across namespaces - which the schema explicitly permits - cannot be addressed across the namespace boundary. """ normalized_agent_name = normalize_agent_name_or_422(agent_name) @@ -940,7 +949,6 @@ async def _get_agent_or_404( @router.post( "/{agent_name}/policies/{policy_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Associate policy with agent", response_description="Success confirmation", @@ -949,9 +957,10 @@ async def add_agent_policy( agent_name: str, policy_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Associate a policy with an agent (idempotent).""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( @@ -1017,7 +1026,6 @@ async def add_agent_policy( @router.post( "/{agent_name}/policy/{policy_id}", - dependencies=[Depends(require_admin_key)], response_model=SetPolicyResponse, summary="Assign policy to agent (compatibility)", response_description="Success status with previous policy ID", @@ -1026,9 +1034,10 @@ async def set_agent_policy( agent_name: str, policy_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> SetPolicyResponse: """Compatibility endpoint that replaces all policy associations with one policy.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( @@ -1117,9 +1126,10 @@ async def set_agent_policy( async def get_agent_policies( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> GetAgentPoliciesResponse: """List policy IDs associated with an agent.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) result = await db.execute( select(agent_policies.c.policy_id) @@ -1141,9 +1151,10 @@ async def get_agent_policies( async def get_agent_policy( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> GetPolicyResponse: """Compatibility endpoint that returns the first associated policy.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( select(Policy.id) @@ -1172,7 +1183,6 @@ async def get_agent_policy( @router.delete( "/{agent_name}/policies/{policy_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Remove policy association from agent", response_description="Success confirmation", @@ -1181,13 +1191,14 @@ async def remove_agent_policy( agent_name: str, policy_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Remove a policy association from an agent. Idempotent for existing resources: removing a non-associated link is a no-op. Missing agent/policy resources still return 404. """ + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( @@ -1230,7 +1241,6 @@ async def remove_agent_policy( @router.delete( "/{agent_name}/policies", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Remove all policy associations from agent", response_description="Success confirmation", @@ -1238,9 +1248,10 @@ async def remove_agent_policy( async def remove_all_agent_policies( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Remove all policy associations from an agent.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) try: @@ -1271,7 +1282,6 @@ async def remove_all_agent_policies( @router.delete( "/{agent_name}/policy", - dependencies=[Depends(require_admin_key)], response_model=DeletePolicyResponse, summary="Remove agent's policy assignment (compatibility)", response_description="Success confirmation", @@ -1279,9 +1289,10 @@ async def remove_all_agent_policies( async def delete_agent_policy( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> DeletePolicyResponse: """Compatibility endpoint that removes all policy associations.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) existing_policy_result = await db.execute( @@ -1328,7 +1339,6 @@ async def delete_agent_policy( @router.post( "/{agent_name}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Associate control directly with agent", response_description="Success confirmation", @@ -1337,9 +1347,10 @@ async def add_agent_control( agent_name: str, control_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Associate a control directly with an agent (idempotent).""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) control_service = ControlService(db) control = await control_service.get_active_control_or_404( @@ -1389,7 +1400,6 @@ async def add_agent_control( @router.delete( "/{agent_name}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=RemoveAgentControlResponse, summary="Remove direct control association from agent", response_description="Success confirmation", @@ -1398,9 +1408,10 @@ async def remove_agent_control( agent_name: str, control_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> RemoveAgentControlResponse: """Remove a direct control association from an agent (idempotent).""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) control_service = ControlService(db) await control_service.get_active_control_or_404(control_id, namespace_key=namespace_key) @@ -1481,7 +1492,7 @@ async def list_agent_controls( description="Optional opaque target identifier. Required when target_type is supplied.", ), db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> AgentControlsResponse: """ List protection controls effective for an agent. @@ -1506,7 +1517,7 @@ async def list_agent_controls( target_type: Optional opaque target kind (paired with target_id) target_id: Optional opaque target identifier (paired with target_type) db: Database session (injected) - namespace_key: Namespace scoping for the resolution (injected) + principal: Authorized request principal Returns: AgentControlsResponse with controls matching the requested state filters @@ -1515,6 +1526,8 @@ async def list_agent_controls( HTTPException 400: target_type and target_id were not supplied together HTTPException 404: Agent not found """ + namespace_key = principal.namespace_key + if (target_type is None) != (target_id is None): raise BadRequestError( error_code=ErrorCode.VALIDATION_ERROR, @@ -1572,7 +1585,7 @@ async def list_agent_evaluators( cursor: str | None = None, limit: int = _DEFAULT_PAGINATION_LIMIT, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> ListEvaluatorsResponse: """ List all evaluator schemas registered with an agent. @@ -1586,8 +1599,7 @@ async def list_agent_evaluators( cursor: Optional cursor for pagination (name of last evaluator from previous page) limit: Pagination limit (default 20, max 100) db: Database session (injected) - namespace_key: Resolved namespace; agents in another namespace - return 404 (non-disclosing). + principal: Authorized request principal Returns: ListEvaluatorsResponse with evaluator schemas and pagination @@ -1595,6 +1607,7 @@ async def list_agent_evaluators( Raises: HTTPException 404: Agent not found """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) # Clamp limit limit = min(max(1, limit), _MAX_PAGINATION_LIMIT) @@ -1672,7 +1685,7 @@ async def get_agent_evaluator( agent_name: str, evaluator_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> EvaluatorSchemaItem: """ Get a specific evaluator schema registered with an agent. @@ -1681,8 +1694,7 @@ async def get_agent_evaluator( agent_name: Agent identifier evaluator_name: Name of the evaluator db: Database session (injected) - namespace_key: Resolved namespace; agents in another namespace - return 404 (non-disclosing). + principal: Authorized request principal Returns: EvaluatorSchemaItem with schema details @@ -1690,6 +1702,7 @@ async def get_agent_evaluator( Raises: HTTPException 404: Agent or evaluator not found """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) result = await db.execute( select(Agent).where(Agent.name == agent_name, Agent.namespace_key == namespace_key) @@ -1734,7 +1747,6 @@ async def get_agent_evaluator( @router.patch( "/{agent_name}", - dependencies=[Depends(require_admin_key)], response_model=PatchAgentResponse, summary="Modify agent (remove steps/evaluators)", response_description="Lists of removed items", @@ -1743,7 +1755,7 @@ async def patch_agent( agent_name: str, request: PatchAgentRequest, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> PatchAgentResponse: """ Remove steps and/or evaluators from an agent. @@ -1755,6 +1767,7 @@ async def patch_agent( agent_name: Agent identifier request: Lists of step/evaluator identifiers to remove db: Database session (injected) + principal: Authorized request principal Returns: PatchAgentResponse with lists of actually removed items @@ -1763,6 +1776,7 @@ async def patch_agent( HTTPException 404: Agent not found HTTPException 500: Database error during update """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) result = await db.execute( select(Agent).where( diff --git a/server/src/agent_control_server/endpoints/auth.py b/server/src/agent_control_server/endpoints/auth.py index 1a23baa8..f80cd2fa 100644 --- a/server/src/agent_control_server/endpoints/auth.py +++ b/server/src/agent_control_server/endpoints/auth.py @@ -2,9 +2,8 @@ The runtime auth flow is two-phase: this endpoint is phase one. The caller presents a long-lived credential plus ``(target_type, -target_id)``; the default authorizer (typically -:class:`HttpUpstreamAuthProvider` in production) authenticates the -credential and authorizes the implied +target_id)``; the default authorizer authenticates the credential and +authorizes the implied :data:`Operation.RUNTIME_TOKEN_EXCHANGE`. On success, this endpoint mints a short-lived local runtime token bound to the supplied target and returns it. Subsequent target-bearing runtime calls present the @@ -130,8 +129,8 @@ async def runtime_token_exchange( actor_id = principal.caller_id or "anonymous" # The exchange endpoint requires the authorizer to explicitly grant - # runtime.use. Providers that do not surface scopes (legacy local - # provider) supply a normalized grant for ``RUNTIME_TOKEN_EXCHANGE``; + # runtime.use. Local providers supply a normalized grant for + # ``RUNTIME_TOKEN_EXCHANGE``; # upstream providers that return an explicit empty scopes array fail # closed here rather than escalating to runtime.use. if Operation.RUNTIME_USE.value not in principal.scopes: @@ -155,7 +154,7 @@ async def runtime_token_exchange( ) except UpstreamGrantExpiredError as exc: # Upstream returned a grant whose ``expires_at`` is already in - # the past — minting would hand the caller a token that's dead + # the past - minting would hand the caller a token that's dead # on arrival. Distinguished from the misconfigured case so the # error code and status reflect "upstream returned bad data." raise APIError( diff --git a/server/src/agent_control_server/endpoints/control_bindings.py b/server/src/agent_control_server/endpoints/control_bindings.py index 92798ae1..d2fe4b44 100644 --- a/server/src/agent_control_server/endpoints/control_bindings.py +++ b/server/src/agent_control_server/endpoints/control_bindings.py @@ -26,7 +26,6 @@ from ..db import get_async_db from ..errors import BadRequestError from ..models import ControlBinding -from ..namespace import get_namespace_key from ..services.control_bindings import ControlBindingsService router = APIRouter(prefix="/control-bindings", tags=["control-bindings"]) @@ -94,26 +93,21 @@ def _to_response(binding: ControlBinding) -> GetControlBindingResponse: async def create_control_binding( request: CreateControlBindingRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_WRITE, context_builder=_binding_body_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> CreateControlBindingResponse: """Attach a control to an opaque external target. Each binding row is scoped to the request namespace as resolved by - ``get_namespace_key``. The auth chain still runs via - ``require_operation`` for authentication and authorization, but the - storage namespace is taken from the same resolver the rest of the - server uses so binding writes and runtime reads stay in lockstep - until auth-derived namespace resolution lands across every endpoint. + the active authorizer. """ service = ControlBindingsService(db) binding = await service.create_binding( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, target_type=request.target_type, target_id=request.target_id, control_id=request.control_id, @@ -148,20 +142,18 @@ async def list_control_bindings( target_id: str | None = None, control_id: int | None = None, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_READ, context_builder=_binding_list_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> ListControlBindingsResponse: """Return bindings in the request namespace with optional filters and cursor-based pagination. Bindings are ordered by ID descending (newest first). The cursor is opaque to clients: pass back the ``next_cursor`` value verbatim to fetch the following page. The - storage namespace is resolved by ``get_namespace_key`` so this - listing stays in lockstep with the rest of the server's reads. + storage namespace is resolved by the active authorizer. """ parsed_cursor: int | None if cursor is None: @@ -177,7 +169,7 @@ async def list_control_bindings( ) from exc service = ControlBindingsService(db) page = await service.list_bindings( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, cursor=parsed_cursor, limit=limit, target_type=target_type, @@ -204,8 +196,7 @@ async def list_control_bindings( async def get_control_binding( binding_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_READ)), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_READ)), ) -> GetControlBindingResponse: """Read a single control binding by surrogate ID. @@ -218,7 +209,9 @@ async def get_control_binding( of which forward ``(target_type, target_id)`` to the authorizer. """ service = ControlBindingsService(db) - binding = await service.get_binding_or_404(namespace_key=namespace_key, binding_id=binding_id) + binding = await service.get_binding_or_404( + namespace_key=principal.namespace_key, binding_id=binding_id + ) return _to_response(binding) @@ -232,8 +225,7 @@ async def patch_control_binding( binding_id: int, request: PatchControlBindingRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), ) -> PatchControlBindingResponse: """Update the ``enabled`` flag on a control binding. @@ -244,7 +236,7 @@ async def patch_control_binding( """ service = ControlBindingsService(db) binding = await service.set_enabled( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, binding_id=binding_id, enabled=request.enabled, ) @@ -261,8 +253,7 @@ async def patch_control_binding( async def delete_control_binding( binding_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), ) -> DeleteControlBindingResponse: """Delete a control binding by surrogate ID. @@ -272,7 +263,7 @@ async def delete_control_binding( target-scoped detach that forwards the target to the authorizer. """ service = ControlBindingsService(db) - await service.delete_binding(namespace_key=namespace_key, binding_id=binding_id) + await service.delete_binding(namespace_key=principal.namespace_key, binding_id=binding_id) await db.commit() return DeleteControlBindingResponse(success=True) @@ -286,13 +277,12 @@ async def delete_control_binding( async def upsert_control_binding_by_key( request: UpsertControlBindingRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_WRITE, context_builder=_binding_body_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> UpsertControlBindingResponse: """Idempotent attach using ``(target_type, target_id, control_id)`` as the natural key. Updates ``enabled`` on an existing match; creates a new row @@ -300,7 +290,7 @@ async def upsert_control_binding_by_key( """ service = ControlBindingsService(db) binding, created = await service.upsert_by_natural_key( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, target_type=request.target_type, target_id=request.target_id, control_id=request.control_id, @@ -324,20 +314,19 @@ async def upsert_control_binding_by_key( async def delete_control_binding_by_key( request: DeleteControlBindingByKeyRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_WRITE, context_builder=_binding_body_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> DeleteControlBindingByKeyResponse: """Idempotent detach by natural key. Returns ``deleted=False`` when no matching binding exists. """ service = ControlBindingsService(db) deleted = await service.delete_by_natural_key( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, target_type=request.target_type, target_id=request.target_id, control_id=request.control_id, diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index fcb7cb18..5b01593c 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -229,7 +229,7 @@ async def _materialize_control_input( enabled=enabled, ) - # Incomplete values — only allowed for new controls or already-unrendered + # Incomplete values - only allowed for new controls or already-unrendered # templates. Updating a rendered control with incomplete values is # rejected to prevent silently stripping rendered fields. current_is_rendered = ( @@ -470,7 +470,7 @@ async def render_control_template( async def create_control( request: CreateControlRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> CreateControlResponse: """ Create a new control with a unique name. @@ -492,7 +492,10 @@ async def create_control( control_service = ControlService(db) # Uniqueness check - if await control_service.active_control_name_exists(request.name): + namespace_key = principal.namespace_key + if await control_service.active_control_name_exists( + request.name, namespace_key=namespace_key + ): raise ConflictError( error_code=ErrorCode.CONTROL_NAME_CONFLICT, detail=f"Control with name '{request.name}' already exists", @@ -504,7 +507,11 @@ async def create_control( control_def = await _materialize_control_input(request.data, db=db) control_data = _serialize_control_data(control_def) - control = control_service.create_control(name=request.name, data=control_data) + control = control_service.create_control( + namespace_key=namespace_key, + name=request.name, + data=control_data, + ) try: await control_service.create_version( control, @@ -569,7 +576,7 @@ async def get_control_schema() -> GetControlSchemaResponse: async def get_control( control_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlResponse: """ Retrieve a control by ID including its name and configuration data. @@ -584,7 +591,9 @@ async def get_control( Raises: HTTPException 404: Control not found """ - control = await ControlService(db).get_active_control_or_404(control_id) + control = await ControlService(db).get_active_control_or_404( + control_id, namespace_key=principal.namespace_key + ) control_data = _parse_stored_control_data( control.data, control_name=control.name, @@ -608,7 +617,7 @@ async def get_control( async def get_control_data( control_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlDataResponse: """ Retrieve the configuration data for a control. @@ -626,7 +635,9 @@ async def get_control_data( HTTPException 404: Control not found HTTPException 422: Control data is corrupted """ - control = await ControlService(db).get_active_control_or_404(control_id) + control = await ControlService(db).get_active_control_or_404( + control_id, namespace_key=principal.namespace_key + ) control_data = _parse_stored_control_data( control.data, control_name=control.name, @@ -648,10 +659,15 @@ async def list_control_versions( ), limit: int = Query(_DEFAULT_PAGINATION_LIMIT, ge=1, le=_MAX_PAGINATION_LIMIT), db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ListControlVersionsResponse: """List control versions ordered newest-first using cursor-based pagination.""" - page = await ControlService(db).list_versions(control_id, cursor=cursor, limit=limit) + page = await ControlService(db).list_versions( + control_id, + namespace_key=principal.namespace_key, + cursor=cursor, + limit=limit, + ) return ListControlVersionsResponse( versions=[ @@ -682,10 +698,12 @@ async def get_control_version( control_id: int, version_num: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlVersionResponse: """Return a specific control version, including its raw persisted snapshot.""" - version = await ControlService(db).get_version_or_404(control_id, version_num) + version = await ControlService(db).get_version_or_404( + control_id, version_num, namespace_key=principal.namespace_key + ) return GetControlVersionResponse( version_num=version.version_num, event_type=version.event_type, @@ -705,7 +723,7 @@ async def set_control_data( control_id: int, request: SetControlDataRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), ) -> SetControlDataResponse: """ Update the configuration data for a control. @@ -726,7 +744,9 @@ async def set_control_data( HTTPException 500: Database error during update """ control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id, for_update=True) + control = await control_service.get_active_control_or_404( + control_id, namespace_key=principal.namespace_key, for_update=True + ) control_def = await _materialize_control_input( request.data, @@ -767,11 +787,12 @@ async def set_control_data( summary="Validate control configuration", response_description="Validation result", ) -# Validation uses the authoring path, so require create access. +# Authorized as CONTROLS_READ: validate exercises the materialization +# path but does not mutate stored control data. async def validate_control_data( request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ValidateControlDataResponse: """ Validate control configuration data without saving it. @@ -811,7 +832,7 @@ async def list_controls( execution: str | None = Query(None, description="Filter by execution ('server' or 'sdk')"), tag: str | None = Query(None, description="Filter by tag"), db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ListControlsResponse: """ List all controls with optional filtering and cursor-based pagination. @@ -837,7 +858,9 @@ async def list_controls( GET /controls?limit=10&enabled=true&step_type=tool """ control_service = ControlService(db) + namespace_key = principal.namespace_key page = await control_service.list_controls_page( + namespace_key=namespace_key, cursor=cursor, limit=limit, name=name, @@ -849,7 +872,8 @@ async def list_controls( tag=tag, ) usage_by_control_id = await control_service.list_control_usage( - [control.id for control in page.controls] + [control.id for control in page.controls], + namespace_key=namespace_key, ) # Build summaries (filtering already done at DB level) @@ -910,7 +934,7 @@ async def delete_control( "If false, fail if control is associated with any policy or agent.", ), db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_DELETE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_DELETE)), ) -> DeleteControlResponse: """ Delete a control by ID. @@ -933,13 +957,18 @@ async def delete_control( """ control_service = ControlService(db) bindings_service = ControlBindingsService(db) - control = await control_service.get_active_control_or_404(control_id, for_update=True) + namespace_key = principal.namespace_key + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key, for_update=True + ) - associations = await control_service.list_control_associations(control_id) + associations = await control_service.list_control_associations( + control_id, namespace_key=namespace_key + ) associated_policy_ids = associations.policy_ids associated_agent_names = associations.agent_names target_binding_ids = await bindings_service.list_binding_ids_for_control( - namespace_key=control.namespace_key, control_id=control_id + namespace_key=namespace_key, control_id=control_id ) if ( @@ -996,13 +1025,15 @@ async def delete_control( dissociated_from_policies: list[int] = [] dissociated_from_agents: list[str] = [] if associated_policy_ids or associated_agent_names: - dissociated = await control_service.remove_all_control_associations(control_id) + dissociated = await control_service.remove_all_control_associations( + control_id, namespace_key=namespace_key + ) dissociated_from_policies = dissociated.policy_ids dissociated_from_agents = dissociated.agent_names detached_target_bindings: list[int] = [] if target_binding_ids: detached_target_bindings = await bindings_service.delete_bindings_for_control( - namespace_key=control.namespace_key, control_id=control_id + namespace_key=namespace_key, control_id=control_id ) if dissociated_from_policies or dissociated_from_agents or detached_target_bindings: _logger.info( @@ -1057,7 +1088,7 @@ async def patch_control( control_id: int, request: PatchControlRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), ) -> PatchControlResponse: """ Update control metadata (name and/or enabled status). @@ -1081,7 +1112,10 @@ async def patch_control( HTTPException 500: Database error during update """ control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id, for_update=True) + namespace_key = principal.namespace_key + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key, for_update=True + ) parsed_control = _parse_stored_control_data( control.data, control_name=control.name, @@ -1096,6 +1130,7 @@ async def patch_control( # Check for name collision if await control_service.active_control_name_exists( request.name, + namespace_key=namespace_key, exclude_control_id=control_id, ): raise ConflictError( diff --git a/server/src/agent_control_server/endpoints/evaluation.py b/server/src/agent_control_server/endpoints/evaluation.py index e018796e..437af8b5 100644 --- a/server/src/agent_control_server/endpoints/evaluation.py +++ b/server/src/agent_control_server/endpoints/evaluation.py @@ -10,16 +10,15 @@ EvaluationResponse, ) from agent_control_models.errors import ErrorCode, ValidationErrorItem -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Request from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import RequireAPIKey +from ..auth_framework import Operation, Principal, require_operation from ..db import get_async_db from ..errors import APIValidationError, NotFoundError from ..logging_utils import get_logger from ..models import Agent -from ..namespace import get_namespace_key from ..services.controls import ControlService router = APIRouter(prefix="/evaluation", tags=["evaluation"]) @@ -118,6 +117,20 @@ def _sanitize_evaluation_response(response: EvaluationResponse) -> EvaluationRes ) +async def _evaluation_context(request: Request) -> dict[str, object]: + """Surface target identifiers to the runtime authorizer.""" + try: + body = await request.json() + except Exception: # noqa: BLE001 malformed JSON, defer to endpoint validation + return {} + if not isinstance(body, dict): + return {} + return { + "target_type": body.get("target_type"), + "target_id": body.get("target_id"), + } + + @router.post( "", response_model=EvaluationResponse, @@ -126,9 +139,10 @@ def _sanitize_evaluation_response(response: EvaluationResponse) -> EvaluationRes ) async def evaluate( request: EvaluationRequest, - client: RequireAPIKey, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends( + require_operation(Operation.RUNTIME_USE, context_builder=_evaluation_context) + ), ) -> EvaluationResponse: """Analyze content for safety and control violations. @@ -144,7 +158,7 @@ async def evaluate( on the server; SDKs reconstruct and emit those events separately through the observability ingestion endpoint. """ - del client # Authentication is still required by dependency injection. + namespace_key = principal.namespace_key agent_result = await db.execute( select(Agent).where( diff --git a/server/src/agent_control_server/endpoints/policies.py b/server/src/agent_control_server/endpoints/policies.py index 7b8b2ef9..ddda7127 100644 --- a/server/src/agent_control_server/endpoints/policies.py +++ b/server/src/agent_control_server/endpoints/policies.py @@ -9,7 +9,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import require_admin_key +from ..auth_framework import Operation, Principal, require_operation from ..db import get_async_db from ..errors import ConflictError, DatabaseError, NotFoundError from ..logging_utils import get_logger @@ -23,13 +23,14 @@ @router.put( "", - dependencies=[Depends(require_admin_key)], response_model=CreatePolicyResponse, summary="Create a new policy", response_description="Created policy ID", ) async def create_policy( - request: CreatePolicyRequest, db: AsyncSession = Depends(get_async_db) + request: CreatePolicyRequest, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_CREATE)), ) -> CreatePolicyResponse: """ Create a new empty policy with a unique name. @@ -48,8 +49,14 @@ async def create_policy( HTTPException 409: Policy with this name already exists HTTPException 500: Database error during creation """ + namespace_key = principal.namespace_key # Uniqueness check - existing = await db.execute(select(Policy.id).where(Policy.name == request.name)) + existing = await db.execute( + select(Policy.id).where( + Policy.namespace_key == namespace_key, + Policy.name == request.name, + ) + ) if existing.first() is not None: raise ConflictError( error_code=ErrorCode.POLICY_NAME_CONFLICT, @@ -59,7 +66,7 @@ async def create_policy( hint="Choose a different name or update the existing policy.", ) - policy = Policy(name=request.name) + policy = Policy(namespace_key=namespace_key, name=request.name) db.add(policy) try: await db.commit() @@ -80,13 +87,15 @@ async def create_policy( @router.post( "/{policy_id}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Add control to policy", response_description="Success confirmation", ) async def add_control_to_policy( - policy_id: int, control_id: int, db: AsyncSession = Depends(get_async_db) + policy_id: int, + control_id: int, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_UPDATE)), ) -> AssocResponse: """ Associate a control with a policy. @@ -106,8 +115,14 @@ async def add_control_to_policy( HTTPException 404: Policy or control not found HTTPException 500: Database error """ + namespace_key = principal.namespace_key # Find policy and control - pol_res = await db.execute(select(Policy).where(Policy.id == policy_id)) + pol_res = await db.execute( + select(Policy).where( + Policy.namespace_key == namespace_key, + Policy.id == policy_id, + ) + ) policy = pol_res.scalars().first() if policy is None: raise NotFoundError( @@ -119,11 +134,17 @@ async def add_control_to_policy( ) control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id) + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key + ) # Add association using INSERT ... ON CONFLICT DO NOTHING for idempotency try: - await control_service.add_control_to_policy(policy_id=policy_id, control_id=control_id) + await control_service.add_control_to_policy( + policy_id=policy_id, + control_id=control_id, + namespace_key=namespace_key, + ) await db.commit() except Exception: await db.rollback() @@ -149,13 +170,15 @@ async def add_control_to_policy( @router.delete( "/{policy_id}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Remove control from policy", response_description="Success confirmation", ) async def remove_control_from_policy( - policy_id: int, control_id: int, db: AsyncSession = Depends(get_async_db) + policy_id: int, + control_id: int, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_UPDATE)), ) -> AssocResponse: """ Remove a control from a policy. @@ -175,7 +198,13 @@ async def remove_control_from_policy( HTTPException 404: Policy or control not found HTTPException 500: Database error """ - pol_res = await db.execute(select(Policy).where(Policy.id == policy_id)) + namespace_key = principal.namespace_key + pol_res = await db.execute( + select(Policy).where( + Policy.namespace_key == namespace_key, + Policy.id == policy_id, + ) + ) policy = pol_res.scalars().first() if policy is None: raise NotFoundError( @@ -187,13 +216,16 @@ async def remove_control_from_policy( ) control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id) + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key + ) # Remove association (idempotent - deleting non-existent is no-op) try: await control_service.remove_control_from_policy( policy_id=policy_id, control_id=control_id, + namespace_key=namespace_key, ) await db.commit() except Exception: @@ -222,7 +254,9 @@ async def remove_control_from_policy( response_description="List of control IDs", ) async def list_policy_controls( - policy_id: int, db: AsyncSession = Depends(get_async_db) + policy_id: int, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_READ)), ) -> GetPolicyControlsResponse: """ List all controls associated with a policy. @@ -237,7 +271,13 @@ async def list_policy_controls( Raises: HTTPException 404: Policy not found """ - pol_res = await db.execute(select(Policy.id).where(Policy.id == policy_id)) + namespace_key = principal.namespace_key + pol_res = await db.execute( + select(Policy.id).where( + Policy.namespace_key == namespace_key, + Policy.id == policy_id, + ) + ) if pol_res.first() is None: raise NotFoundError( error_code=ErrorCode.POLICY_NOT_FOUND, @@ -247,5 +287,8 @@ async def list_policy_controls( hint="Verify the policy ID is correct and the policy has been created.", ) - control_ids = await ControlService(db).list_policy_control_ids(policy_id) + control_ids = await ControlService(db).list_policy_control_ids( + policy_id, + namespace_key=namespace_key, + ) return GetPolicyControlsResponse(control_ids=control_ids) diff --git a/server/src/agent_control_server/main.py b/server/src/agent_control_server/main.py index bc1bf04b..a1561e63 100644 --- a/server/src/agent_control_server/main.py +++ b/server/src/agent_control_server/main.py @@ -252,7 +252,7 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- # Register handler for FastAPI's RequestValidationError (Pydantic validation) app.add_exception_handler(RequestValidationError, validation_exception_handler) # type: ignore[arg-type] -# Register handler for standard HTTPException (legacy code, FastAPI internals) +# Register handler for standard HTTPException (older routes, FastAPI internals) app.add_exception_handler(HTTPException, http_exception_handler) # type: ignore[arg-type] # Register catch-all handler for unexpected exceptions @@ -261,16 +261,18 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- # API v1 prefix for all routes api_v1_prefix = f"{settings.api_prefix}/{settings.api_version}" -# Protected routes (require valid API key) +# API routers. Routers migrated to the auth framework mount the +# non-validating header extractor only so OpenAPI advertises X-API-Key; +# each endpoint's ``require_operation`` dependency owns authn + authz. app.include_router( agent_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) app.include_router( policy_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) app.include_router( # Endpoint dependencies handle auth; this advertises X-API-Key. @@ -281,11 +283,11 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- app.include_router( # The auth framework on each endpoint owns authentication and # authorization for control bindings, so this router is mounted - # without the legacy router-level gate. See ``auth_framework`` for + # without the router-level auth gate. See ``auth_framework`` for # the provider contract. ``get_api_key_from_header`` is a non- # validating extractor (``auto_error=False``); it is attached purely # so the generated OpenAPI spec advertises the X-API-Key requirement - # on these routes — without it, downstream SDK generators would treat + # on these routes - without it, downstream SDK generators would treat # the routes as unauthenticated. control_binding_router, prefix=api_v1_prefix, @@ -309,9 +311,10 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- app.include_router( evaluation_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) +# Evaluator discovery still uses the local credential dependency. app.include_router( evaluator_router, prefix=api_v1_prefix, @@ -324,7 +327,7 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- prefix=api_v1_prefix, ) -# System routes (config, login, logout) — no auth required +# System routes (config, login, logout) - no auth required app.include_router( system_router, prefix=settings.api_prefix, diff --git a/server/src/agent_control_server/services/controls.py b/server/src/agent_control_server/services/controls.py index 263120b7..41a62282 100644 --- a/server/src/agent_control_server/services/controls.py +++ b/server/src/agent_control_server/services/controls.py @@ -20,6 +20,7 @@ from ..errors import APIValidationError, NotFoundError from ..models import ( + DEFAULT_NAMESPACE_KEY, Control, ControlBinding, ControlVersion, @@ -96,9 +97,15 @@ class ControlService: def __init__(self, db: AsyncSession) -> None: self._db = db - def create_control(self, *, name: str, data: dict[str, Any]) -> Control: + def create_control( + self, + *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, + name: str, + data: dict[str, Any], + ) -> Control: """Create a new pending control row.""" - control = Control(name=name, data=data) + control = Control(namespace_key=namespace_key, name=name, data=data) self._db.add(control) return control @@ -128,10 +135,13 @@ async def get_control_or_404( self, control_id: int, *, + namespace_key: str | None = None, for_update: bool = False, ) -> Control: """Load any control row, including soft-deleted controls.""" stmt = select(Control).where(Control.id == control_id) + if namespace_key is not None: + stmt = stmt.where(Control.namespace_key == namespace_key) if for_update: stmt = stmt.with_for_update() result = await self._db.execute(stmt) @@ -180,10 +190,15 @@ async def active_control_name_exists( self, name: str, *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, exclude_control_id: int | None = None, ) -> bool: """Return whether an active control already uses the provided name.""" - stmt = select(Control.id).where(Control.name == name, Control.deleted_at.is_(None)) + stmt = select(Control.id).where( + Control.namespace_key == namespace_key, + Control.name == name, + Control.deleted_at.is_(None), + ) if exclude_control_id is not None: stmt = stmt.where(Control.id != exclude_control_id) result = await self._db.execute(stmt) @@ -216,11 +231,12 @@ async def list_versions( self, control_id: int, *, + namespace_key: str, cursor: int | None, limit: int, ) -> ControlVersionPage: """Return control versions newest-first with cursor pagination.""" - await self.get_control_or_404(control_id) + await self.get_control_or_404(control_id, namespace_key=namespace_key) total_result = await self._db.execute( select(func.count()) @@ -255,9 +271,11 @@ async def list_versions( next_cursor=next_cursor, ) - async def get_version_or_404(self, control_id: int, version_num: int) -> ControlVersion: + async def get_version_or_404( + self, control_id: int, version_num: int, *, namespace_key: str + ) -> ControlVersion: """Load a specific version row for a control.""" - await self.get_control_or_404(control_id) + await self.get_control_or_404(control_id, namespace_key=namespace_key) result = await self._db.execute( select(ControlVersion).where( @@ -303,12 +321,17 @@ async def list_controls_for_policy( result = await self._db.execute(stmt) return list(result.scalars().unique().all()) - async def list_policy_control_ids(self, policy_id: int) -> list[int]: + async def list_policy_control_ids(self, policy_id: int, *, namespace_key: str) -> list[int]: """Return active control IDs directly associated with a policy.""" result = await self._db.execute( select(policy_controls.c.control_id) .join(Control, Control.id == policy_controls.c.control_id) - .where(policy_controls.c.policy_id == policy_id, Control.deleted_at.is_(None)) + .where( + policy_controls.c.namespace_key == namespace_key, + policy_controls.c.policy_id == policy_id, + Control.namespace_key == namespace_key, + Control.deleted_at.is_(None), + ) .order_by(policy_controls.c.control_id) ) return [cast(int, row[0]) for row in result.all()] @@ -396,6 +419,7 @@ async def list_runtime_controls_for_agent( async def list_controls_page( self, *, + namespace_key: str, cursor: int | None, limit: int, name: str | None, @@ -407,7 +431,11 @@ async def list_controls_page( tag: str | None, ) -> ControlListPage: """Return paginated active controls for the browse endpoint.""" - query = select(Control).where(Control.deleted_at.is_(None)).order_by(Control.id.desc()) + query = ( + select(Control) + .where(Control.namespace_key == namespace_key, Control.deleted_at.is_(None)) + .order_by(Control.id.desc()) + ) query = self._apply_control_list_filters( query, name=name, @@ -424,7 +452,11 @@ async def list_controls_page( result = await self._db.execute(query.limit(limit + 1)) controls = list(result.scalars().all()) - total_query = select(func.count()).select_from(Control).where(Control.deleted_at.is_(None)) + total_query = ( + select(func.count()) + .select_from(Control) + .where(Control.namespace_key == namespace_key, Control.deleted_at.is_(None)) + ) total_query = self._apply_control_list_filters( total_query, name=name, @@ -453,7 +485,9 @@ async def list_controls_page( next_cursor=next_cursor, ) - async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, ControlUsage]: + async def list_control_usage( + self, control_ids: Sequence[int], *, namespace_key: str + ) -> dict[int, ControlUsage]: """Return representative agent usage and usage counts for the provided controls.""" if not control_ids: return {} @@ -465,8 +499,16 @@ async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, Cont agent_policies.c.agent_name, ) .select_from(policy_controls) - .join(agent_policies, policy_controls.c.policy_id == agent_policies.c.policy_id) - .where(policy_controls.c.control_id.in_(control_ids)) + .join( + agent_policies, + (policy_controls.c.policy_id == agent_policies.c.policy_id) + & (policy_controls.c.namespace_key == agent_policies.c.namespace_key), + ) + .where( + policy_controls.c.namespace_key == namespace_key, + agent_policies.c.namespace_key == namespace_key, + policy_controls.c.control_id.in_(control_ids), + ) ) direct_agents_query = ( select( @@ -474,7 +516,10 @@ async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, Cont agent_controls.c.agent_name, ) .select_from(agent_controls) - .where(agent_controls.c.control_id.in_(control_ids)) + .where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.control_id.in_(control_ids), + ) ) agents_result = await self._db.execute(union_all(policy_agents_query, direct_agents_query)) for control_id, agent_name in agents_result.all(): @@ -491,6 +536,8 @@ async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, Cont async def list_active_control_counts_by_agent( self, agent_names: Sequence[str], + *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, ) -> dict[str, int]: """Return active control counts keyed by agent name.""" if not agent_names: @@ -503,15 +550,24 @@ async def list_active_control_counts_by_agent( ) .select_from( agent_policies.join( - policy_controls, agent_policies.c.policy_id == policy_controls.c.policy_id + policy_controls, + (agent_policies.c.policy_id == policy_controls.c.policy_id) + & (agent_policies.c.namespace_key == policy_controls.c.namespace_key), ) ) - .where(agent_policies.c.agent_name.in_(agent_names)) + .where( + agent_policies.c.namespace_key == namespace_key, + policy_controls.c.namespace_key == namespace_key, + agent_policies.c.agent_name.in_(agent_names), + ) ) direct_associations = select( agent_controls.c.agent_name.label("agent_name"), agent_controls.c.control_id.label("control_id"), - ).where(agent_controls.c.agent_name.in_(agent_names)) + ).where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.agent_name.in_(agent_names), + ) all_associations = union_all(policy_associations, direct_associations).subquery() result = await self._db.execute( @@ -521,6 +577,7 @@ async def list_active_control_counts_by_agent( ) .join(Control, all_associations.c.control_id == Control.id) .where( + Control.namespace_key == namespace_key, Control.deleted_at.is_(None), or_( Control.data["enabled"].astext == "true", @@ -531,19 +588,28 @@ async def list_active_control_counts_by_agent( ) return {cast(str, row[0]): cast(int, row[1]) for row in result.all()} - async def add_control_to_policy(self, *, policy_id: int, control_id: int) -> None: + async def add_control_to_policy( + self, *, policy_id: int, control_id: int, namespace_key: str + ) -> None: """Create a policy-control association if it does not already exist.""" await self._db.execute( pg_insert(policy_controls) - .values(policy_id=policy_id, control_id=control_id) + .values( + namespace_key=namespace_key, + policy_id=policy_id, + control_id=control_id, + ) .on_conflict_do_nothing() ) - async def remove_control_from_policy(self, *, policy_id: int, control_id: int) -> None: + async def remove_control_from_policy( + self, *, policy_id: int, control_id: int, namespace_key: str + ) -> None: """Remove a policy-control association if it exists.""" await self._db.execute( delete(policy_controls).where( - (policy_controls.c.policy_id == policy_id) + (policy_controls.c.namespace_key == namespace_key) + & (policy_controls.c.policy_id == policy_id) & (policy_controls.c.control_id == control_id) ) ) @@ -613,16 +679,24 @@ async def remove_control_from_agent( control_still_active=policy_inheritance_result.first() is not None, ) - async def list_control_associations(self, control_id: int) -> ControlAssociations: + async def list_control_associations( + self, control_id: int, *, namespace_key: str + ) -> ControlAssociations: """Return all policy and direct agent associations for a control.""" policy_assoc_query = select( policy_controls.c.policy_id.label("policy_id"), literal(None, type_=String).label("agent_name"), - ).where(policy_controls.c.control_id == control_id) + ).where( + policy_controls.c.namespace_key == namespace_key, + policy_controls.c.control_id == control_id, + ) agent_assoc_query = select( literal(None, type_=Integer).label("policy_id"), agent_controls.c.agent_name.label("agent_name"), - ).where(agent_controls.c.control_id == control_id) + ).where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.control_id == control_id, + ) assoc_result = await self._db.execute(union_all(policy_assoc_query, agent_assoc_query)) policy_ids: set[int] = set() @@ -638,16 +712,26 @@ async def list_control_associations(self, control_id: int) -> ControlAssociation agent_names=sorted(agent_names), ) - async def remove_all_control_associations(self, control_id: int) -> ControlAssociations: + async def remove_all_control_associations( + self, control_id: int, *, namespace_key: str + ) -> ControlAssociations: """Remove all policy and direct agent associations for a control.""" - associations = await self.list_control_associations(control_id) + associations = await self.list_control_associations( + control_id, namespace_key=namespace_key + ) if associations.policy_ids: await self._db.execute( - delete(policy_controls).where(policy_controls.c.control_id == control_id) + delete(policy_controls).where( + policy_controls.c.namespace_key == namespace_key, + policy_controls.c.control_id == control_id, + ) ) if associations.agent_names: await self._db.execute( - delete(agent_controls).where(agent_controls.c.control_id == control_id) + delete(agent_controls).where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.control_id == control_id, + ) ) return associations diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 96c4aad8..2d39bfa3 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -20,6 +20,7 @@ AccessLevel, HeaderAuthProvider, HttpUpstreamAuthProvider, + NoAuthProvider, ) from agent_control_server.auth_framework.providers.header import ( DEFAULT_OPERATION_ACCESS, @@ -64,6 +65,35 @@ def test_default_operation_access_covers_every_operation(): assert not missing, f"Operations missing default access mapping: {missing}" +# --------------------------------------------------------------------------- +# NoAuthProvider +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_no_auth_provider_allows_any_operation(): + provider = NoAuthProvider(default_namespace_key="ns-local") + + principal = await provider.authorize( + _build_request(), + Operation.CONTROLS_DELETE, + ) + + assert principal == Principal(namespace_key="ns-local") + + +@pytest.mark.asyncio +async def test_no_auth_provider_grants_runtime_exchange_scope(): + provider = NoAuthProvider() + + principal = await provider.authorize( + _build_request(), + Operation.RUNTIME_TOKEN_EXCHANGE, + ) + + assert principal.scopes == (Operation.RUNTIME_USE.value,) + + # --------------------------------------------------------------------------- # HeaderAuthProvider # --------------------------------------------------------------------------- @@ -101,7 +131,7 @@ async def test_header_provider_public_returns_default_namespace(): @pytest.mark.asyncio -async def test_header_provider_authenticated_calls_legacy_validator(): +async def test_header_provider_authenticated_calls_local_validator(): provider = HeaderAuthProvider() expected_client = MagicMock(is_admin=False, key_id="abc12345") @@ -945,6 +975,70 @@ def test_runtime_ttl_loader_accepts_max(monkeypatch): ) +def test_build_default_provider_accepts_none_mode(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "none") + + assert isinstance(auth_config._build_default_provider(), NoAuthProvider) + + +def test_resolve_runtime_mode_defaults_to_api_key_without_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + assert auth_config._resolve_runtime_mode() == "api_key" + + +def test_resolve_runtime_mode_defaults_to_jwt_with_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", _TEST_SECRET) + + assert auth_config._resolve_runtime_mode() == "jwt" + + +def test_configure_runtime_none_installs_no_auth_provider(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "none") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.RUNTIME_USE), NoAuthProvider) + assert auth_config.runtime_auth_config() is None + + +def test_configure_runtime_api_key_ignores_jwt_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "api_key") + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", _TEST_SECRET) + + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.RUNTIME_USE), HeaderAuthProvider) + assert auth_config.runtime_auth_config() is None + + +def test_configure_runtime_jwt_requires_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "jwt") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + with pytest.raises(RuntimeError, match="requires AGENT_CONTROL_RUNTIME_TOKEN_SECRET"): + auth_config.configure_auth_from_env() + + def test_configure_then_reconfigure_clears_runtime_override(monkeypatch): """Reconfiguring without a runtime secret must drop the override.""" from agent_control_server.auth_framework import config as auth_config diff --git a/server/tests/test_controls_additional.py b/server/tests/test_controls_additional.py index b4922b9d..dfbb15f5 100644 --- a/server/tests/test_controls_additional.py +++ b/server/tests/test_controls_additional.py @@ -8,19 +8,19 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from agent_control_evaluators import RegexEvaluatorConfig +from agent_control_models import ConditionNode from fastapi.testclient import TestClient from sqlalchemy import text from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from agent_control_models import ConditionNode +from agent_control_server.auth_framework import Principal from agent_control_server.db import get_async_db -from agent_control_server.models import Control - -from agent_control_evaluators import RegexEvaluatorConfig from agent_control_server.endpoints import controls as controls_module from agent_control_server.main import app +from agent_control_server.models import DEFAULT_NAMESPACE_KEY, Control from .conftest import engine from .utils import VALID_CONTROL_PAYLOAD @@ -1106,7 +1106,12 @@ def model_dump(self, *args: object, **kwargs: object) -> dict[str, object]: request = SimpleNamespace(data=DummyData(payload)) # When: updating the control data with a non-Pydantic selector - response = await controls_module.set_control_data(control.id, request, async_db) + response = await controls_module.set_control_data( + control.id, + request, + async_db, + principal=Principal(namespace_key=DEFAULT_NAMESPACE_KEY), + ) # Then: the update succeeds and uses the original selector serialization assert response.success is True diff --git a/server/tests/test_controls_auth.py b/server/tests/test_controls_auth.py index 1a2af21f..c0f17754 100644 --- a/server/tests/test_controls_auth.py +++ b/server/tests/test_controls_auth.py @@ -4,14 +4,13 @@ import uuid -import pytest from fastapi.testclient import TestClient -from agent_control_server.config import auth_settings +from agent_control_server.auth_framework import set_authorizer +from agent_control_server.auth_framework.providers import NoAuthProvider from .utils import VALID_CONTROL_PAYLOAD - _CONTROLS_URL = "/api/v1/controls" _TEMPLATES_URL = "/api/v1/control-templates" @@ -199,18 +198,19 @@ def test_non_admin_cannot_delete_control( assert resp.status_code == 403, resp.text -def test_non_admin_cannot_validate_control_data( +def test_non_admin_can_validate_control_data( non_admin_client: TestClient, ) -> None: - """``/controls/validate`` requires ``CONTROLS_CREATE``.""" + """``/controls/validate`` requires ``CONTROLS_READ``.""" # When: a non-admin attempts to validate a draft payload resp = non_admin_client.post( f"{_CONTROLS_URL}/validate", json={"data": VALID_CONTROL_PAYLOAD}, ) - # Then: validation is admin-only - assert resp.status_code == 403, resp.text + # Then: validation is allowed for authenticated non-admin callers + assert resp.status_code == 200, resp.text + assert resp.json()["success"] is True def test_non_admin_cannot_render_template(non_admin_client: TestClient) -> None: @@ -283,21 +283,16 @@ def test_unauthenticated_cannot_render_template( # --------------------------------------------------------------------------- -# No-auth deployment mode: api_key_enabled=False bypasses every gate. +# No-auth deployment mode: explicit provider bypasses every gate. # --------------------------------------------------------------------------- def test_no_auth_mode_allows_writes_without_credentials( unauthenticated_client: TestClient, - monkeypatch: pytest.MonkeyPatch, ) -> None: - """When ``api_key_enabled`` is False, the ``HeaderAuthProvider`` - short-circuits to a non-admin ``Principal`` for every operation, - including admin-level writes. This pins the "no auth" deployment - path so a future refactor can't silently start enforcing. - """ - # Given: api_key_enabled is False (single-tenant OSS dev mode) - monkeypatch.setattr(auth_settings, "api_key_enabled", False) + """Explicit no-auth provider allows every operation without credentials.""" + # Given: the request-auth framework is in no-auth mode + set_authorizer(NoAuthProvider()) # When: an unauthenticated client creates a control resp = unauthenticated_client.put( @@ -311,4 +306,3 @@ def test_no_auth_mode_allows_writes_without_credentials( # Then: the create succeeds because auth is disabled at the provider assert resp.status_code == 200, resp.text assert "control_id" in resp.json() - diff --git a/server/tests/test_principal_namespace_flow.py b/server/tests/test_principal_namespace_flow.py new file mode 100644 index 00000000..40ecd216 --- /dev/null +++ b/server/tests/test_principal_namespace_flow.py @@ -0,0 +1,141 @@ +"""HTTP-level coverage for principal-derived namespace scoping.""" + +from __future__ import annotations + +import uuid +from typing import Any + +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient + +from agent_control_server.auth_framework import ( + Operation, + Principal, + set_authorizer, +) + +from .utils import VALID_CONTROL_PAYLOAD + + +class HeaderNamespaceAuthorizer: + """Test authorizer that maps a request header to ``Principal.namespace_key``.""" + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del context + scopes = ( + (Operation.RUNTIME_USE.value,) + if operation is Operation.RUNTIME_TOKEN_EXCHANGE + else () + ) + return Principal( + namespace_key=request.headers.get("X-Test-Namespace", "default"), + is_admin=True, + scopes=scopes, + ) + + +def _client(app: FastAPI, namespace_key: str) -> TestClient: + return TestClient( + app, + raise_server_exceptions=True, + headers={"X-Test-Namespace": namespace_key}, + ) + + +def _agent_payload(agent_name: str) -> dict[str, Any]: + return { + "agent": { + "agent_name": agent_name, + "agent_description": "test agent", + "agent_version": "1.0", + }, + "steps": [], + } + + +def _evaluation_payload(agent_name: str) -> dict[str, Any]: + return { + "agent_name": agent_name, + "step": { + "type": "llm", + "name": "test-step", + "input": "x marks the spot", + "context": {}, + }, + "stage": "pre", + "target_type": "env", + "target_id": "prod", + } + + +def test_principal_namespace_scopes_management_and_runtime(app: FastAPI) -> None: + set_authorizer(HeaderNamespaceAuthorizer()) + + ns_a = _client(app, "ns-a") + ns_b = _client(app, "ns-b") + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + + register_a = ns_a.post("/api/v1/agents/initAgent", json=_agent_payload(agent_name)) + register_b = ns_b.post("/api/v1/agents/initAgent", json=_agent_payload(agent_name)) + assert register_a.status_code == 200, register_a.text + assert register_b.status_code == 200, register_b.text + + create_control = ns_a.put( + "/api/v1/controls", + json={ + "name": f"control-{uuid.uuid4().hex[:12]}", + "data": VALID_CONTROL_PAYLOAD, + }, + ) + assert create_control.status_code == 200, create_control.text + control_id = int(create_control.json()["control_id"]) + + policy = ns_a.put( + "/api/v1/policies", + json={"name": f"policy-{uuid.uuid4().hex[:12]}"}, + ) + assert policy.status_code == 200, policy.text + policy_id = int(policy.json()["policy_id"]) + attach_to_policy = ns_a.post(f"/api/v1/policies/{policy_id}/controls/{control_id}") + assert attach_to_policy.status_code == 200, attach_to_policy.text + + binding = ns_a.put( + "/api/v1/control-bindings", + json={ + "target_type": "env", + "target_id": "prod", + "control_id": control_id, + "enabled": True, + }, + ) + assert binding.status_code == 200, binding.text + + assert ns_b.get(f"/api/v1/controls/{control_id}").status_code == 404 + assert ns_b.get(f"/api/v1/policies/{policy_id}/controls").status_code == 404 + assert ns_b.get("/api/v1/control-bindings").json()["bindings"] == [] + + eval_a = ns_a.post("/api/v1/evaluation", json=_evaluation_payload(agent_name)) + assert eval_a.status_code == 200, eval_a.text + assert eval_a.json()["is_safe"] is False + assert eval_a.json()["matches"][0]["control_id"] == control_id + + eval_b = ns_b.post("/api/v1/evaluation", json=_evaluation_payload(agent_name)) + assert eval_b.status_code == 200, eval_b.text + assert eval_b.json()["is_safe"] is True + + +def test_duplicate_control_names_allowed_across_principal_namespaces(app: FastAPI) -> None: + set_authorizer(HeaderNamespaceAuthorizer()) + + ns_a = _client(app, "ns-a") + ns_b = _client(app, "ns-b") + control_name = f"control-{uuid.uuid4().hex[:12]}" + payload = {"name": control_name, "data": VALID_CONTROL_PAYLOAD} + + assert ns_a.put("/api/v1/controls", json=payload).status_code == 200 + assert ns_b.put("/api/v1/controls", json=payload).status_code == 200 diff --git a/server/tests/test_target_merged_contract.py b/server/tests/test_target_merged_contract.py index 295a85e2..62891ba5 100644 --- a/server/tests/test_target_merged_contract.py +++ b/server/tests/test_target_merged_contract.py @@ -232,9 +232,9 @@ def test_target_binding_de_duplicated_against_direct_attachment( async def _insert_agent_in_namespace(async_db, *, name: str, namespace_key: str) -> None: """Insert an Agent row directly so the test can simulate a foreign namespace. - The endpoint's ``get_namespace_key`` returns the default namespace; this - helper sidesteps the resolver to seed an agent that the request-time - code path should not be able to reach. + The default test authorizer returns the default namespace; this helper + sidesteps the authorizer to seed an agent that the request-time code + path should not be able to reach. """ from agent_control_server.models import Agent From 6bc9b465851140e3148b7a6deb35e1f1fd0faf0d Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Thu, 7 May 2026 20:20:00 +0530 Subject: [PATCH 07/42] chore(sdk-ts): regenerate client docs --- .../src/generated/funcs/agents-get-evaluator.ts | 3 +-- sdks/typescript/src/generated/funcs/agents-get.ts | 3 +-- .../typescript/src/generated/funcs/agents-init.ts | 1 + .../src/generated/funcs/agents-list-controls.ts | 2 +- .../src/generated/funcs/agents-list-evaluators.ts | 3 +-- .../typescript/src/generated/funcs/agents-list.ts | 2 +- .../src/generated/funcs/agents-update.ts | 1 + .../generated/funcs/control-bindings-create.ts | 6 +----- .../src/generated/funcs/control-bindings-list.ts | 3 +-- sdks/typescript/src/generated/sdk/agents.ts | 15 +++++++-------- .../src/generated/sdk/control-bindings.ts | 9 ++------- 11 files changed, 18 insertions(+), 30 deletions(-) diff --git a/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts b/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts index acb364eb..ceca1ec0 100644 --- a/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts +++ b/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts @@ -37,8 +37,7 @@ import { Result } from "../types/fp.js"; * agent_name: Agent identifier * evaluator_name: Name of the evaluator * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * EvaluatorSchemaItem with schema details diff --git a/sdks/typescript/src/generated/funcs/agents-get.ts b/sdks/typescript/src/generated/funcs/agents-get.ts index 9724edbf..142f3062 100644 --- a/sdks/typescript/src/generated/funcs/agents-get.ts +++ b/sdks/typescript/src/generated/funcs/agents-get.ts @@ -38,8 +38,7 @@ import { Result } from "../types/fp.js"; * Args: * agent_name: Agent identifier * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * GetAgentResponse with agent metadata and step list diff --git a/sdks/typescript/src/generated/funcs/agents-init.ts b/sdks/typescript/src/generated/funcs/agents-init.ts index 9d63358d..7150b2a4 100644 --- a/sdks/typescript/src/generated/funcs/agents-init.ts +++ b/sdks/typescript/src/generated/funcs/agents-init.ts @@ -51,6 +51,7 @@ import { Result } from "../types/fp.js"; * Args: * request: Agent metadata and step schemas * db: Database session (injected) + * principal: Authorized request principal * * Returns: * InitAgentResponse with created flag and the effective controls diff --git a/sdks/typescript/src/generated/funcs/agents-list-controls.ts b/sdks/typescript/src/generated/funcs/agents-list-controls.ts index 661c5509..d1e5b27d 100644 --- a/sdks/typescript/src/generated/funcs/agents-list-controls.ts +++ b/sdks/typescript/src/generated/funcs/agents-list-controls.ts @@ -53,7 +53,7 @@ import { Result } from "../types/fp.js"; * target_type: Optional opaque target kind (paired with target_id) * target_id: Optional opaque target identifier (paired with target_type) * db: Database session (injected) - * namespace_key: Namespace scoping for the resolution (injected) + * principal: Authorized request principal * * Returns: * AgentControlsResponse with controls matching the requested state filters diff --git a/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts b/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts index c4d8a4b2..4217e752 100644 --- a/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts +++ b/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts @@ -42,8 +42,7 @@ import { Result } from "../types/fp.js"; * cursor: Optional cursor for pagination (name of last evaluator from previous page) * limit: Pagination limit (default 20, max 100) * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * ListEvaluatorsResponse with evaluator schemas and pagination diff --git a/sdks/typescript/src/generated/funcs/agents-list.ts b/sdks/typescript/src/generated/funcs/agents-list.ts index fda7574d..f887d0b5 100644 --- a/sdks/typescript/src/generated/funcs/agents-list.ts +++ b/sdks/typescript/src/generated/funcs/agents-list.ts @@ -42,7 +42,7 @@ import { Result } from "../types/fp.js"; * limit: Pagination limit (default 20, max 100) * name: Optional name filter (case-insensitive partial match) * db: Database session (injected) - * namespace_key: Resolved namespace for the request + * principal: Authorized request principal * * Returns: * ListAgentsResponse with agent summaries and pagination info diff --git a/sdks/typescript/src/generated/funcs/agents-update.ts b/sdks/typescript/src/generated/funcs/agents-update.ts index e82644cf..aff9d827 100644 --- a/sdks/typescript/src/generated/funcs/agents-update.ts +++ b/sdks/typescript/src/generated/funcs/agents-update.ts @@ -40,6 +40,7 @@ import { Result } from "../types/fp.js"; * agent_name: Agent identifier * request: Lists of step/evaluator identifiers to remove * db: Database session (injected) + * principal: Authorized request principal * * Returns: * PatchAgentResponse with lists of actually removed items diff --git a/sdks/typescript/src/generated/funcs/control-bindings-create.ts b/sdks/typescript/src/generated/funcs/control-bindings-create.ts index 8412487e..71dee5a0 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-create.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-create.ts @@ -33,11 +33,7 @@ import { Result } from "../types/fp.js"; * Attach a control to an opaque external target. * * Each binding row is scoped to the request namespace as resolved by - * ``get_namespace_key``. The auth chain still runs via - * ``require_operation`` for authentication and authorization, but the - * storage namespace is taken from the same resolver the rest of the - * server uses so binding writes and runtime reads stay in lockstep - * until auth-derived namespace resolution lands across every endpoint. + * the active authorizer. */ export function controlBindingsCreate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-list.ts b/sdks/typescript/src/generated/funcs/control-bindings-list.ts index 5e7e87c3..5c90c7c2 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-list.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-list.ts @@ -35,8 +35,7 @@ import { Result } from "../types/fp.js"; * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by ``get_namespace_key`` so this - * listing stays in lockstep with the rest of the server's reads. + * storage namespace is resolved by the active authorizer. */ export function controlBindingsList( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/sdk/agents.ts b/sdks/typescript/src/generated/sdk/agents.ts index a22f4209..0a70e128 100644 --- a/sdks/typescript/src/generated/sdk/agents.ts +++ b/sdks/typescript/src/generated/sdk/agents.ts @@ -39,7 +39,7 @@ export class Agents extends ClientSDK { * limit: Pagination limit (default 20, max 100) * name: Optional name filter (case-insensitive partial match) * db: Database session (injected) - * namespace_key: Resolved namespace for the request + * principal: Authorized request principal * * Returns: * ListAgentsResponse with agent summaries and pagination info @@ -80,6 +80,7 @@ export class Agents extends ClientSDK { * Args: * request: Agent metadata and step schemas * db: Database session (injected) + * principal: Authorized request principal * * Returns: * InitAgentResponse with created flag and the effective controls @@ -106,8 +107,7 @@ export class Agents extends ClientSDK { * Args: * agent_name: Agent identifier * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * GetAgentResponse with agent metadata and step list @@ -140,6 +140,7 @@ export class Agents extends ClientSDK { * agent_name: Agent identifier * request: Lists of step/evaluator identifiers to remove * db: Database session (injected) + * principal: Authorized request principal * * Returns: * PatchAgentResponse with lists of actually removed items @@ -185,7 +186,7 @@ export class Agents extends ClientSDK { * target_type: Optional opaque target kind (paired with target_id) * target_id: Optional opaque target identifier (paired with target_type) * db: Database session (injected) - * namespace_key: Namespace scoping for the resolution (injected) + * principal: Authorized request principal * * Returns: * AgentControlsResponse with controls matching the requested state filters @@ -256,8 +257,7 @@ export class Agents extends ClientSDK { * cursor: Optional cursor for pagination (name of last evaluator from previous page) * limit: Pagination limit (default 20, max 100) * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * ListEvaluatorsResponse with evaluator schemas and pagination @@ -287,8 +287,7 @@ export class Agents extends ClientSDK { * agent_name: Agent identifier * evaluator_name: Name of the evaluator * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * EvaluatorSchemaItem with schema details diff --git a/sdks/typescript/src/generated/sdk/control-bindings.ts b/sdks/typescript/src/generated/sdk/control-bindings.ts index 5101ce74..dc6f20d3 100644 --- a/sdks/typescript/src/generated/sdk/control-bindings.ts +++ b/sdks/typescript/src/generated/sdk/control-bindings.ts @@ -23,8 +23,7 @@ export class ControlBindings extends ClientSDK { * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by ``get_namespace_key`` so this - * listing stays in lockstep with the rest of the server's reads. + * storage namespace is resolved by the active authorizer. */ async list( request?: @@ -46,11 +45,7 @@ export class ControlBindings extends ClientSDK { * Attach a control to an opaque external target. * * Each binding row is scoped to the request namespace as resolved by - * ``get_namespace_key``. The auth chain still runs via - * ``require_operation`` for authentication and authorization, but the - * storage namespace is taken from the same resolver the rest of the - * server uses so binding writes and runtime reads stay in lockstep - * until auth-derived namespace resolution lands across every endpoint. + * the active authorizer. */ async create( request: models.CreateControlBindingRequest, From 717c0b9d439677fd814e1105da722edc64340ba4 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Thu, 7 May 2026 23:07:04 +0530 Subject: [PATCH 08/42] fix(server): address runtime auth review feedback --- .../funcs/auth-runtime-token-exchange.ts | 9 ++--- .../funcs/control-bindings-create.ts | 4 +- .../funcs/control-bindings-delete.ts | 2 +- .../generated/funcs/control-bindings-get.ts | 5 +-- .../generated/funcs/control-bindings-list.ts | 2 +- .../funcs/control-bindings-update.ts | 2 +- sdks/typescript/src/generated/sdk/auth.ts | 9 ++--- .../src/generated/sdk/control-bindings.ts | 15 ++++--- .../auth_framework/core.py | 2 - .../auth_framework/providers/header.py | 2 - .../agent_control_server/endpoints/auth.py | 19 +++++---- .../endpoints/control_bindings.py | 15 ++++--- .../endpoints/controls.py | 6 +-- server/src/agent_control_server/namespace.py | 23 ----------- .../agent_control_server/services/controls.py | 23 ++++++----- server/tests/test_auth_framework.py | 24 +++++++++++ server/tests/test_controls_auth.py | 12 +++--- .../test_runtime_token_exchange_endpoint.py | 36 ++++++++++++++++- server/tests/test_services_controls.py | 40 ++++++++++++++----- 19 files changed, 146 insertions(+), 104 deletions(-) delete mode 100644 server/src/agent_control_server/namespace.py diff --git a/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts b/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts index 176693e3..7e8679c8 100644 --- a/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts +++ b/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts @@ -32,11 +32,10 @@ import { Result } from "../types/fp.js"; * @remarks * Mint a short-lived runtime token for the requested target. * - * The caller's credential is authenticated and authorized by the - * installed default authorizer; the resulting :class:`Principal` - * supplies the actor identity and (when the upstream surfaces it) - * the grant scopes and expiry. This endpoint then mints a local HS256 - * token whose lifetime cannot outlive the upstream grant. + * The caller's credential is authenticated and authorized before the + * resolved principal supplies the actor identity, grant scopes, and + * expiry. This endpoint then mints a local HS256 token whose lifetime + * cannot outlive the grant. * * Runtime auth must be enabled via * ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET``; otherwise the endpoint diff --git a/sdks/typescript/src/generated/funcs/control-bindings-create.ts b/sdks/typescript/src/generated/funcs/control-bindings-create.ts index 71dee5a0..faf99923 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-create.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-create.ts @@ -32,8 +32,8 @@ import { Result } from "../types/fp.js"; * @remarks * Attach a control to an opaque external target. * - * Each binding row is scoped to the request namespace as resolved by - * the active authorizer. + * Each binding row is scoped to the namespace associated with the + * authenticated request. */ export function controlBindingsCreate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-delete.ts b/sdks/typescript/src/generated/funcs/control-bindings-delete.ts index 9e4d1293..9872a9b4 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-delete.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-delete.ts @@ -36,7 +36,7 @@ import { Result } from "../types/fp.js"; * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``POST /by-key:delete`` for - * target-scoped detach that forwards the target to the authorizer. + * target-scoped detach that includes the target in the request context. */ export function controlBindingsDelete( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-get.ts b/sdks/typescript/src/generated/funcs/control-bindings-get.ts index dafb7c7c..88b4e419 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-get.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-get.ts @@ -34,12 +34,11 @@ import { Result } from "../types/fp.js"; * Read a single control binding by surrogate ID. * * Authorization is namespace-wide: the binding's target identifiers - * are not forwarded to the upstream because they are only discoverable - * after the row is loaded, and ``require_operation`` is single-pass. + * are not available until after the row is loaded. * Callers whose authorization model requires per-target permissions * should use the natural-key endpoints (``PUT /by-key``, * ``POST /by-key:delete``) and the target-filtered list endpoint, all - * of which forward ``(target_type, target_id)`` to the authorizer. + * of which include ``(target_type, target_id)`` in the request context. */ export function controlBindingsGet( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-list.ts b/sdks/typescript/src/generated/funcs/control-bindings-list.ts index 5c90c7c2..a87ca89f 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-list.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-list.ts @@ -35,7 +35,7 @@ import { Result } from "../types/fp.js"; * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by the active authorizer. + * storage namespace is resolved from the authenticated request. */ export function controlBindingsList( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-update.ts b/sdks/typescript/src/generated/funcs/control-bindings-update.ts index b3faf800..b94520a2 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-update.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-update.ts @@ -36,7 +36,7 @@ import { Result } from "../types/fp.js"; * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``PUT /by-key`` for target-scoped - * upserts that forward the target to the authorizer. + * upserts that include the target in the request context. */ export function controlBindingsUpdate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/sdk/auth.ts b/sdks/typescript/src/generated/sdk/auth.ts index cf6de9ba..2d0cf74e 100644 --- a/sdks/typescript/src/generated/sdk/auth.ts +++ b/sdks/typescript/src/generated/sdk/auth.ts @@ -14,11 +14,10 @@ export class Auth extends ClientSDK { * @remarks * Mint a short-lived runtime token for the requested target. * - * The caller's credential is authenticated and authorized by the - * installed default authorizer; the resulting :class:`Principal` - * supplies the actor identity and (when the upstream surfaces it) - * the grant scopes and expiry. This endpoint then mints a local HS256 - * token whose lifetime cannot outlive the upstream grant. + * The caller's credential is authenticated and authorized before the + * resolved principal supplies the actor identity, grant scopes, and + * expiry. This endpoint then mints a local HS256 token whose lifetime + * cannot outlive the grant. * * Runtime auth must be enabled via * ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET``; otherwise the endpoint diff --git a/sdks/typescript/src/generated/sdk/control-bindings.ts b/sdks/typescript/src/generated/sdk/control-bindings.ts index dc6f20d3..5a5bcf2b 100644 --- a/sdks/typescript/src/generated/sdk/control-bindings.ts +++ b/sdks/typescript/src/generated/sdk/control-bindings.ts @@ -23,7 +23,7 @@ export class ControlBindings extends ClientSDK { * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by the active authorizer. + * storage namespace is resolved from the authenticated request. */ async list( request?: @@ -44,8 +44,8 @@ export class ControlBindings extends ClientSDK { * @remarks * Attach a control to an opaque external target. * - * Each binding row is scoped to the request namespace as resolved by - * the active authorizer. + * Each binding row is scoped to the namespace associated with the + * authenticated request. */ async create( request: models.CreateControlBindingRequest, @@ -104,7 +104,7 @@ export class ControlBindings extends ClientSDK { * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``POST /by-key:delete`` for - * target-scoped detach that forwards the target to the authorizer. + * target-scoped detach that includes the target in the request context. */ async delete( request: @@ -125,12 +125,11 @@ export class ControlBindings extends ClientSDK { * Read a single control binding by surrogate ID. * * Authorization is namespace-wide: the binding's target identifiers - * are not forwarded to the upstream because they are only discoverable - * after the row is loaded, and ``require_operation`` is single-pass. + * are not available until after the row is loaded. * Callers whose authorization model requires per-target permissions * should use the natural-key endpoints (``PUT /by-key``, * ``POST /by-key:delete``) and the target-filtered list endpoint, all - * of which forward ``(target_type, target_id)`` to the authorizer. + * of which include ``(target_type, target_id)`` in the request context. */ async get( request: @@ -153,7 +152,7 @@ export class ControlBindings extends ClientSDK { * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``PUT /by-key`` for target-scoped - * upserts that forward the target to the authorizer. + * upserts that include the target in the request context. */ async update( request: diff --git a/server/src/agent_control_server/auth_framework/core.py b/server/src/agent_control_server/auth_framework/core.py index e0ea6da7..058169de 100644 --- a/server/src/agent_control_server/auth_framework/core.py +++ b/server/src/agent_control_server/auth_framework/core.py @@ -52,11 +52,9 @@ class Operation(StrEnum): POLICIES_READ = "policies.read" POLICIES_CREATE = "policies.create" POLICIES_UPDATE = "policies.update" - POLICIES_DELETE = "policies.delete" AGENTS_READ = "agents.read" AGENTS_CREATE = "agents.create" AGENTS_UPDATE = "agents.update" - AGENTS_DELETE = "agents.delete" RUNTIME_USE = "runtime.use" diff --git a/server/src/agent_control_server/auth_framework/providers/header.py b/server/src/agent_control_server/auth_framework/providers/header.py index 228ec443..16760768 100644 --- a/server/src/agent_control_server/auth_framework/providers/header.py +++ b/server/src/agent_control_server/auth_framework/providers/header.py @@ -45,11 +45,9 @@ class AccessLevel(Enum): Operation.POLICIES_READ: AccessLevel.AUTHENTICATED, Operation.POLICIES_CREATE: AccessLevel.ADMIN, Operation.POLICIES_UPDATE: AccessLevel.ADMIN, - Operation.POLICIES_DELETE: AccessLevel.ADMIN, Operation.AGENTS_READ: AccessLevel.AUTHENTICATED, Operation.AGENTS_CREATE: AccessLevel.AUTHENTICATED, Operation.AGENTS_UPDATE: AccessLevel.ADMIN, - Operation.AGENTS_DELETE: AccessLevel.ADMIN, Operation.RUNTIME_TOKEN_EXCHANGE: AccessLevel.AUTHENTICATED, Operation.RUNTIME_USE: AccessLevel.AUTHENTICATED, } diff --git a/server/src/agent_control_server/endpoints/auth.py b/server/src/agent_control_server/endpoints/auth.py index f80cd2fa..b1ade969 100644 --- a/server/src/agent_control_server/endpoints/auth.py +++ b/server/src/agent_control_server/endpoints/auth.py @@ -2,13 +2,13 @@ The runtime auth flow is two-phase: this endpoint is phase one. The caller presents a long-lived credential plus ``(target_type, -target_id)``; the default authorizer authenticates the credential and -authorizes the implied -:data:`Operation.RUNTIME_TOKEN_EXCHANGE`. On success, this endpoint +target_id)``; the configured authorization provider authenticates the +credential and authorizes the implied +``runtime.token_exchange`` operation. On success, this endpoint mints a short-lived local runtime token bound to the supplied target and returns it. Subsequent target-bearing runtime calls present the returned token, which is verified locally by -:class:`LocalJwtVerifyProvider`. +the runtime JWT provider. """ from __future__ import annotations @@ -56,7 +56,7 @@ class RuntimeTokenExchangeResponse(BaseModel): async def _exchange_context(request: Request) -> dict[str, Any]: - """Surface target identifiers to the authorizer's context. + """Surface target identifiers to the authorization context. Reads the request body once. FastAPI caches the parsed body, so the endpoint's own Pydantic body model still binds normally. @@ -89,11 +89,10 @@ async def runtime_token_exchange( ) -> RuntimeTokenExchangeResponse: """Mint a short-lived runtime token for the requested target. - The caller's credential is authenticated and authorized by the - installed default authorizer; the resulting :class:`Principal` - supplies the actor identity and (when the upstream surfaces it) - the grant scopes and expiry. This endpoint then mints a local HS256 - token whose lifetime cannot outlive the upstream grant. + The caller's credential is authenticated and authorized before the + resolved principal supplies the actor identity, grant scopes, and + expiry. This endpoint then mints a local HS256 token whose lifetime + cannot outlive the grant. Runtime auth must be enabled via ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET``; otherwise the endpoint diff --git a/server/src/agent_control_server/endpoints/control_bindings.py b/server/src/agent_control_server/endpoints/control_bindings.py index d2fe4b44..87386723 100644 --- a/server/src/agent_control_server/endpoints/control_bindings.py +++ b/server/src/agent_control_server/endpoints/control_bindings.py @@ -102,8 +102,8 @@ async def create_control_binding( ) -> CreateControlBindingResponse: """Attach a control to an opaque external target. - Each binding row is scoped to the request namespace as resolved by - the active authorizer. + Each binding row is scoped to the namespace associated with the + authenticated request. """ service = ControlBindingsService(db) binding = await service.create_binding( @@ -153,7 +153,7 @@ async def list_control_bindings( cursor-based pagination. Bindings are ordered by ID descending (newest first). The cursor is opaque to clients: pass back the ``next_cursor`` value verbatim to fetch the following page. The - storage namespace is resolved by the active authorizer. + storage namespace is resolved from the authenticated request. """ parsed_cursor: int | None if cursor is None: @@ -201,12 +201,11 @@ async def get_control_binding( """Read a single control binding by surrogate ID. Authorization is namespace-wide: the binding's target identifiers - are not forwarded to the upstream because they are only discoverable - after the row is loaded, and ``require_operation`` is single-pass. + are not available until after the row is loaded. Callers whose authorization model requires per-target permissions should use the natural-key endpoints (``PUT /by-key``, ``POST /by-key:delete``) and the target-filtered list endpoint, all - of which forward ``(target_type, target_id)`` to the authorizer. + of which include ``(target_type, target_id)`` in the request context. """ service = ControlBindingsService(db) binding = await service.get_binding_or_404( @@ -232,7 +231,7 @@ async def patch_control_binding( See the GET-by-id docstring for the authorization scope: this route is namespace-wide because the target identifiers are not available before the binding is loaded. Use ``PUT /by-key`` for target-scoped - upserts that forward the target to the authorizer. + upserts that include the target in the request context. """ service = ControlBindingsService(db) binding = await service.set_enabled( @@ -260,7 +259,7 @@ async def delete_control_binding( See the GET-by-id docstring for the authorization scope: this route is namespace-wide because the target identifiers are not available before the binding is loaded. Use ``POST /by-key:delete`` for - target-scoped detach that forwards the target to the authorizer. + target-scoped detach that includes the target in the request context. """ service = ControlBindingsService(db) await service.delete_binding(namespace_key=principal.namespace_key, binding_id=binding_id) diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index 5b01593c..00d2b710 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -787,12 +787,12 @@ async def set_control_data( summary="Validate control configuration", response_description="Validation result", ) -# Authorized as CONTROLS_READ: validate exercises the materialization -# path but does not mutate stored control data. +# Authorized as CONTROLS_CREATE: validate exercises the same materialization +# path as create/update authoring flows, even though it does not save. async def validate_control_data( request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> ValidateControlDataResponse: """ Validate control configuration data without saving it. diff --git a/server/src/agent_control_server/namespace.py b/server/src/agent_control_server/namespace.py deleted file mode 100644 index 30e30be5..00000000 --- a/server/src/agent_control_server/namespace.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Namespace resolution for request-scoped scoping. - -V1 always resolves to the default namespace. The function exists as a -single seam so a future change can switch every namespace-scoped -endpoint to a real per-request resolver without touching each call -site. Overriding the dependency in V1 is not supported: only this -binding/evaluation layer reads it; controls, agents, and policies still -write under the default namespace, so an override here would create -inconsistent rows. Future work will thread a single resolver through -every write path together. -""" - -from __future__ import annotations - -from .models import DEFAULT_NAMESPACE_KEY - - -def get_namespace_key() -> str: - """Return the namespace_key for the current request. - - V1 returns ``DEFAULT_NAMESPACE_KEY`` unconditionally. - """ - return DEFAULT_NAMESPACE_KEY diff --git a/server/src/agent_control_server/services/controls.py b/server/src/agent_control_server/services/controls.py index 41a62282..e3a5fd26 100644 --- a/server/src/agent_control_server/services/controls.py +++ b/server/src/agent_control_server/services/controls.py @@ -20,7 +20,6 @@ from ..errors import APIValidationError, NotFoundError from ..models import ( - DEFAULT_NAMESPACE_KEY, Control, ControlBinding, ControlVersion, @@ -100,7 +99,7 @@ def __init__(self, db: AsyncSession) -> None: def create_control( self, *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, name: str, data: dict[str, Any], ) -> Control: @@ -161,17 +160,19 @@ async def get_active_control_or_404( control_id: int, *, for_update: bool = False, - namespace_key: str | None = None, + namespace_key: str, ) -> Control: """Load an active control row or raise CONTROL_NOT_FOUND. - When ``namespace_key`` is supplied, the lookup is scoped to that - namespace; a control that exists only in another namespace - surfaces as 404 (non-disclosing). + The lookup is scoped to the supplied namespace; a control that + exists only in another namespace surfaces as 404 + (non-disclosing). """ - stmt = select(Control).where(Control.id == control_id, Control.deleted_at.is_(None)) - if namespace_key is not None: - stmt = stmt.where(Control.namespace_key == namespace_key) + stmt = select(Control).where( + Control.id == control_id, + Control.namespace_key == namespace_key, + Control.deleted_at.is_(None), + ) if for_update: stmt = stmt.with_for_update() result = await self._db.execute(stmt) @@ -190,7 +191,7 @@ async def active_control_name_exists( self, name: str, *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, exclude_control_id: int | None = None, ) -> bool: """Return whether an active control already uses the provided name.""" @@ -537,7 +538,7 @@ async def list_active_control_counts_by_agent( self, agent_names: Sequence[str], *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, ) -> dict[str, int]: """Return active control counts keyed by agent name.""" if not agent_names: diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 2d39bfa3..799b2d52 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -20,6 +20,7 @@ AccessLevel, HeaderAuthProvider, HttpUpstreamAuthProvider, + LocalJwtVerifyProvider, NoAuthProvider, ) from agent_control_server.auth_framework.providers.header import ( @@ -1029,6 +1030,29 @@ def test_configure_runtime_api_key_ignores_jwt_secret(monkeypatch): assert auth_config.runtime_auth_config() is None +@pytest.mark.asyncio +async def test_configure_http_upstream_management_with_jwt_runtime(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "http_upstream") + monkeypatch.setenv("AGENT_CONTROL_AUTH_UPSTREAM_URL", "https://auth.example.test/check") + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "jwt") + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", _TEST_SECRET) + + try: + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.CONTROLS_READ), HttpUpstreamAuthProvider) + assert isinstance(get_authorizer(Operation.RUNTIME_USE), LocalJwtVerifyProvider) + runtime_config = auth_config.runtime_auth_config() + assert runtime_config is not None + assert runtime_config.secret == _TEST_SECRET + finally: + await auth_config.teardown_auth() + + def test_configure_runtime_jwt_requires_secret(monkeypatch): from agent_control_server.auth_framework import config as auth_config diff --git a/server/tests/test_controls_auth.py b/server/tests/test_controls_auth.py index c0f17754..04f44ca4 100644 --- a/server/tests/test_controls_auth.py +++ b/server/tests/test_controls_auth.py @@ -4,10 +4,9 @@ import uuid -from fastapi.testclient import TestClient - from agent_control_server.auth_framework import set_authorizer from agent_control_server.auth_framework.providers import NoAuthProvider +from fastapi.testclient import TestClient from .utils import VALID_CONTROL_PAYLOAD @@ -198,19 +197,18 @@ def test_non_admin_cannot_delete_control( assert resp.status_code == 403, resp.text -def test_non_admin_can_validate_control_data( +def test_non_admin_cannot_validate_control_data( non_admin_client: TestClient, ) -> None: - """``/controls/validate`` requires ``CONTROLS_READ``.""" + """``/controls/validate`` requires ``CONTROLS_CREATE``.""" # When: a non-admin attempts to validate a draft payload resp = non_admin_client.post( f"{_CONTROLS_URL}/validate", json={"data": VALID_CONTROL_PAYLOAD}, ) - # Then: validation is allowed for authenticated non-admin callers - assert resp.status_code == 200, resp.text - assert resp.json()["success"] is True + # Then: validation is admin-only + assert resp.status_code == 403, resp.text def test_non_admin_cannot_render_template(non_admin_client: TestClient) -> None: diff --git a/server/tests/test_runtime_token_exchange_endpoint.py b/server/tests/test_runtime_token_exchange_endpoint.py index 8d333a5c..1b1edae2 100644 --- a/server/tests/test_runtime_token_exchange_endpoint.py +++ b/server/tests/test_runtime_token_exchange_endpoint.py @@ -11,8 +11,6 @@ from datetime import UTC, datetime, timedelta import pytest -from fastapi.testclient import TestClient - from agent_control_server.auth_framework import Operation, Principal from agent_control_server.auth_framework.config import ( RuntimeAuthConfig, @@ -25,6 +23,7 @@ from agent_control_server.auth_framework.providers import ( LocalJwtVerifyProvider, ) +from fastapi.testclient import TestClient _TEST_SECRET = "test-runtime-secret-12345678901234567890" @@ -180,6 +179,39 @@ async def test_exchange_then_verify_full_round_trip(client: TestClient, runtime_ assert principal.caller_id == "actor-rt" +def test_evaluation_rejects_runtime_jwt_for_wrong_target( + client: TestClient, + runtime_config_enabled, +): + """A runtime JWT minted for one target cannot be used for another target.""" + stub = _StubExchangeAuthorizer(actor_id="actor-rt", scopes=("runtime.use",)) + clear_authorizers() + set_authorizer(stub) + set_authorizer(LocalJwtVerifyProvider(secret=_TEST_SECRET), operation=Operation.RUNTIME_USE) + + exchange = client.post( + "/api/v1/auth/runtime-token-exchange", + json={"target_type": "log_stream", "target_id": "ls-allowed"}, + ) + assert exchange.status_code == 200, exchange.text + token = exchange.json()["token"] + + response = client.post( + "/api/v1/evaluation", + headers={"Authorization": f"Bearer {token}"}, + json={ + "agent_name": "agent", + "step": {"type": "llm", "name": "step", "input": "hello"}, + "stage": "pre", + "target_type": "log_stream", + "target_id": "ls-other", + }, + ) + + assert response.status_code == 403, response.text + assert response.json()["detail"] == "Runtime token target_id does not match the request." + + def test_exchange_endpoint_502_when_upstream_grant_already_expired( client: TestClient, runtime_config_enabled, diff --git a/server/tests/test_services_controls.py b/server/tests/test_services_controls.py index b858c527..3815f26b 100644 --- a/server/tests/test_services_controls.py +++ b/server/tests/test_services_controls.py @@ -8,10 +8,6 @@ import pytest from agent_control_models.errors import ErrorCode -from sqlalchemy import insert, select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Session - from agent_control_server.errors import APIValidationError from agent_control_server.models import ( DEFAULT_NAMESPACE_KEY, @@ -27,6 +23,9 @@ from agent_control_server.services.controls import ( ControlService, ) +from sqlalchemy import insert, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session from .conftest import AsyncSessionTest, engine from .utils import VALID_CONTROL_PAYLOAD @@ -70,7 +69,11 @@ async def _create_versioned_control( async with AsyncSessionTest() as session: service = ControlService(session) - control = service.create_control(name=control_name, data=control_data) + control = service.create_control( + namespace_key=DEFAULT_NAMESPACE_KEY, + name=control_name, + data=control_data, + ) await service.create_version( control, event_type="created", @@ -143,6 +146,7 @@ async def test_create_control_transaction_rollback_does_not_persist_control_or_v async with AsyncSessionTest() as session: service = ControlService(session) control = service.create_control( + namespace_key=DEFAULT_NAMESPACE_KEY, name=control_name, data=deepcopy(VALID_CONTROL_PAYLOAD), ) @@ -167,7 +171,10 @@ async def test_replace_control_data_transaction_rollback_preserves_prior_state() async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) updated_data = deepcopy(control.data) updated_data["description"] = "Should not persist" service.replace_control_data(control, data=updated_data) @@ -194,7 +201,10 @@ async def test_patch_mutation_transaction_rollback_preserves_prior_state() -> No async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) service.rename_control(control, name=f"{control_name}-renamed") service.set_control_enabled(control, enabled=False) await service.create_version( @@ -221,7 +231,10 @@ async def test_delete_control_transaction_rollback_preserves_active_state() -> N async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) service.mark_control_deleted(control, deleted_at=dt.datetime.now(dt.UTC)) await service.create_version( control, @@ -511,7 +524,10 @@ async def test_list_active_control_counts_by_agent_deduplicates_and_filters_inac await async_db.commit() # When: counting active controls for the agent - counts = await ControlService(async_db).list_active_control_counts_by_agent([agent.name]) + counts = await ControlService(async_db).list_active_control_counts_by_agent( + [agent.name], + namespace_key=DEFAULT_NAMESPACE_KEY, + ) # Then: active controls are deduplicated and inactive controls are excluded assert counts == {agent.name: 2} @@ -572,6 +588,7 @@ async def test_create_version_allocates_sequential_numbers_under_concurrent_muta async with AsyncSessionTest() as setup_session: setup_service = ControlService(setup_session) control = setup_service.create_control( + namespace_key=DEFAULT_NAMESPACE_KEY, name=f"control-{uuid.uuid4()}", data=deepcopy(VALID_CONTROL_PAYLOAD), ) @@ -592,7 +609,10 @@ async def mutate_and_version(description: str) -> None: async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) updated_data = deepcopy(control.data) updated_data["description"] = description service.replace_control_data(control, data=updated_data) From 45ceb25872fc919bb342c84901f57eb438e05233 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 8 May 2026 16:43:39 +0530 Subject: [PATCH 09/42] feat(server): operator-configurable extra forwarded headers on HttpUpstream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The default forward set (X-API-Key, Authorization, Cookie) only covers credential headers Agent Control itself reads. Deployments whose upstream authenticates against a different header name (e.g., a deployer-specific API-key header) had no way to surface that credential through HttpUpstreamAuthProvider — the inbound header reached AC but never crossed the upstream call. Add an extra_forward_headers config field on HttpUpstreamConfig (defaulting to the empty tuple) that operators populate via the new AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS env var (comma- separated). The provider's _forward_headers iterates over the union of the default set and the extras, deduplicating case-insensitively so a duplicate name (cross-set or within extras) does not produce two copies on the wire. Tests: - forwards a configured extra header alongside defaults - default forward set unchanged when extras are empty - extras dedupe against defaults case-insensitively - _parse_extra_forward_headers parametric: None / empty / single / multiple / whitespace / empty-entries / case-folded duplicates - configure_auth_from_env threads the parsed tuple onto the provider Lint clean, typecheck clean, full server suite (747) green. --- .../auth_framework/config.py | 29 +++++ .../auth_framework/providers/http_upstream.py | 20 ++- .../endpoints/controls.py | 3 +- server/tests/test_auth_framework.py | 115 ++++++++++++++++++ 4 files changed, 163 insertions(+), 4 deletions(-) diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index c8f428dc..8c39a2ec 100644 --- a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -46,6 +46,7 @@ _UPSTREAM_TIMEOUT_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_TIMEOUT_SECONDS" _UPSTREAM_TOKEN_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN" _UPSTREAM_TOKEN_HEADER_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN_HEADER" +_UPSTREAM_EXTRA_FORWARD_HEADERS_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS" # Runtime flow. _RUNTIME_MODE_ENV = "AGENT_CONTROL_RUNTIME_AUTH_MODE" @@ -196,6 +197,9 @@ def _build_default_provider() -> RequestAuthorizer: timeout = float(os.environ.get(_UPSTREAM_TIMEOUT_ENV, "5.0")) token = os.environ.get(_UPSTREAM_TOKEN_ENV) token_header = os.environ.get(_UPSTREAM_TOKEN_HEADER_ENV, "X-Agent-Control-Service-Token") + extra_forward_headers = _parse_extra_forward_headers( + os.environ.get(_UPSTREAM_EXTRA_FORWARD_HEADERS_ENV) + ) _logger.info("Default auth provider: http_upstream url=%s", url) return HttpUpstreamAuthProvider( HttpUpstreamConfig( @@ -203,6 +207,7 @@ def _build_default_provider() -> RequestAuthorizer: timeout_seconds=timeout, service_token=token, service_token_header=token_header, + extra_forward_headers=extra_forward_headers, ) ) raise RuntimeError( @@ -210,6 +215,30 @@ def _build_default_provider() -> RequestAuthorizer: ) +def _parse_extra_forward_headers(raw: str | None) -> tuple[str, ...]: + """Parse a comma-separated header list into a deduplicated tuple. + + Empty / unset env var returns an empty tuple. Whitespace around each + name is stripped. Empty entries (e.g. ``"X-A,,X-B"``) are dropped. + Order is preserved; duplicates (case-insensitive) are dropped after + the first occurrence. + """ + if not raw or not raw.strip(): + return () + seen: set[str] = set() + result: list[str] = [] + for raw_name in raw.split(","): + name = raw_name.strip() + if not name: + continue + lower = name.lower() + if lower in seen: + continue + seen.add(lower) + result.append(name) + return tuple(result) + + def _resolve_runtime_mode() -> str: raw = os.environ.get(_RUNTIME_MODE_ENV) if raw is None or not raw.strip(): diff --git a/server/src/agent_control_server/auth_framework/providers/http_upstream.py b/server/src/agent_control_server/auth_framework/providers/http_upstream.py index 8d5c850c..78ed9ae2 100644 --- a/server/src/agent_control_server/auth_framework/providers/http_upstream.py +++ b/server/src/agent_control_server/auth_framework/providers/http_upstream.py @@ -60,7 +60,7 @@ _logger = get_logger(__name__) -_FORWARDED_HEADERS = ("X-API-Key", "Authorization", "Cookie") +_DEFAULT_FORWARDED_HEADERS = ("X-API-Key", "Authorization", "Cookie") class _UpstreamGrant(BaseModel): @@ -136,6 +136,17 @@ class HttpUpstreamConfig: service_token_header: str = "X-Agent-Control-Service-Token" + extra_forward_headers: tuple[str, ...] = () + """Additional inbound request headers to forward to the upstream + on top of the default ``(X-API-Key, Authorization, Cookie)`` set. + + Use this when the upstream authenticates via a header the provider + does not forward by default (e.g., a deployer-specific API-key + header). Header lookups against the inbound request are + case-insensitive; an empty or absent inbound header is silently + dropped. Names duplicating the default set or each other (after + case-folding) are deduplicated.""" + class HttpUpstreamAuthProvider(RequestAuthorizer): """Delegates authorization to an upstream HTTP service.""" @@ -190,7 +201,12 @@ async def authorize( def _forward_headers(self, request: Request) -> dict[str, str]: headers: dict[str, str] = {} - for name in _FORWARDED_HEADERS: + seen: set[str] = set() + for name in (*_DEFAULT_FORWARDED_HEADERS, *self._config.extra_forward_headers): + lower = name.lower() + if lower in seen: + continue + seen.add(lower) value = request.headers.get(name) if value is not None: headers[name] = value diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index 00d2b710..b4fa8d0b 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -787,8 +787,7 @@ async def set_control_data( summary="Validate control configuration", response_description="Validation result", ) -# Authorized as CONTROLS_CREATE: validate exercises the same materialization -# path as create/update authoring flows, even though it does not save. +# Validation uses the authoring path, so require create access. async def validate_control_data( request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db), diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 799b2d52..dc3a1787 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -261,6 +261,75 @@ def factory(request: httpx.Request) -> httpx.Response: assert captured["headers"]["x-custom-token"] == "shh" +@pytest.mark.asyncio +async def test_http_upstream_forwards_extra_headers(): + # Given: a provider configured with an extra header in its forward list + captured: dict[str, Any] = {} + + def factory(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"namespace_key": "ns"}) + + provider = _build_upstream( + factory, + config_overrides={"extra_forward_headers": ("X-Deployer-Auth",)}, + ) + + # When: the inbound request carries the extra header + inbound = _build_request(headers={"X-Deployer-Auth": "k_abc", "X-API-Key": "k1"}) + await provider.authorize(inbound, Operation.CONTROL_BINDINGS_READ) + + # Then: both the default and the extra header reach the upstream + assert captured["headers"]["x-deployer-auth"] == "k_abc" + assert captured["headers"]["x-api-key"] == "k1" + + +@pytest.mark.asyncio +async def test_http_upstream_default_forward_set_unchanged(): + # Given: a provider with no extra_forward_headers + captured: dict[str, Any] = {} + + def factory(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"namespace_key": "ns"}) + + provider = _build_upstream(factory) + + # When: the inbound carries an unlisted header alongside a default one + inbound = _build_request( + headers={"X-API-Key": "k1", "X-Deployer-Auth": "should-not-forward"} + ) + await provider.authorize(inbound, Operation.CONTROL_BINDINGS_READ) + + # Then: only the default-set header reaches the upstream + assert captured["headers"].get("x-api-key") == "k1" + assert "x-deployer-auth" not in captured["headers"] + + +@pytest.mark.asyncio +async def test_http_upstream_extra_forward_dedupes_against_defaults(): + # Given: extra list duplicates a default header (different case) + captured: dict[str, Any] = {} + + def factory(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"namespace_key": "ns"}) + + provider = _build_upstream( + factory, + config_overrides={"extra_forward_headers": ("x-api-key", "Authorization")}, + ) + + # When: inbound has both + inbound = _build_request(headers={"X-API-Key": "k1", "Authorization": "Bearer t"}) + await provider.authorize(inbound, Operation.CONTROL_BINDINGS_READ) + + # Then: each header appears exactly once on the upstream request + forwarded = captured["headers"] + assert sum(1 for k in forwarded if k.lower() == "x-api-key") == 1 + assert sum(1 for k in forwarded if k.lower() == "authorization") == 1 + + @pytest.mark.asyncio @pytest.mark.parametrize( "status, expected", @@ -1053,6 +1122,52 @@ async def test_configure_http_upstream_management_with_jwt_runtime(monkeypatch): await auth_config.teardown_auth() +@pytest.mark.parametrize( + "raw, expected", + [ + (None, ()), + ("", ()), + (" ", ()), + ("X-One", ("X-One",)), + ("X-One,X-Two", ("X-One", "X-Two")), + (" X-One , X-Two ", ("X-One", "X-Two")), + ("X-One,,X-Two", ("X-One", "X-Two")), + ("X-One,x-one,X-One", ("X-One",)), + ("X-A,X-B,x-a,X-C,X-b", ("X-A", "X-B", "X-C")), + ], +) +def test_parse_extra_forward_headers(raw, expected): + from agent_control_server.auth_framework.config import _parse_extra_forward_headers + + assert _parse_extra_forward_headers(raw) == expected + + +@pytest.mark.asyncio +async def test_configure_http_upstream_extra_forward_headers_env(monkeypatch): + """Setting the env var threads extra_forward_headers into the provider.""" + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "http_upstream") + monkeypatch.setenv("AGENT_CONTROL_AUTH_UPSTREAM_URL", "https://auth.example.test/check") + monkeypatch.setenv( + "AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS", + "X-Deployer-Auth, X-Deployer-Trace", + ) + + try: + auth_config.configure_auth_from_env() + provider = get_authorizer(Operation.CONTROLS_READ) + assert isinstance(provider, HttpUpstreamAuthProvider) + assert provider._config.extra_forward_headers == ( + "X-Deployer-Auth", + "X-Deployer-Trace", + ) + finally: + await auth_config.teardown_auth() + + def test_configure_runtime_jwt_requires_secret(monkeypatch): from agent_control_server.auth_framework import config as auth_config From 479ca86a3cb34979dc0e3a7a7deaee0b58ae6267 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 8 May 2026 21:36:45 +0530 Subject: [PATCH 10/42] fix(server): preserve default runtime auth fallback --- .../auth_framework/config.py | 37 +++++++++------- server/tests/test_auth_framework.py | 44 +++++++++++++++++-- 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index 8c39a2ec..595c3117 100644 --- a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -16,8 +16,8 @@ :class:`NoAuthProvider`, ``api_key`` uses :class:`HeaderAuthProvider`, and ``jwt`` uses :class:`LocalJwtVerifyProvider`. When the mode is unset, startup - preserves historical behavior by selecting ``jwt`` if - ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set, otherwise ``api_key``. + selects ``jwt`` if ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set; + otherwise runtime falls through to the default authorizer. The ``runtime.token_exchange`` operation continues to flow through the default authorizer because the exchange itself is shaped like a management call (forward credential, get grant). @@ -96,10 +96,11 @@ def configure_auth_from_env() -> None: Runtime flow: - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=none``: :class:`NoAuthProvider`. - - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=api_key`` (default when no runtime - token secret is configured): :class:`HeaderAuthProvider`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=api_key``: :class:`HeaderAuthProvider`. - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=jwt`` (default when a runtime token secret is configured): :class:`LocalJwtVerifyProvider`. + - unset mode without a runtime token secret: fall through to the default + authorizer. Clears any previously-installed default and operation overrides before installing fresh ones, so reconfiguration cannot leave @@ -121,20 +122,26 @@ def configure_auth_from_env() -> None: set_authorizer(default) _active_providers.append(default) - runtime_provider = _build_runtime_provider(runtime_mode, _runtime_auth_config) - set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) - _active_providers.append(runtime_provider) - if runtime_mode == "jwt": + if runtime_mode == "default": _logger.info( - "Runtime auth provider: jwt override installed for %s", + "Runtime auth provider: default authorizer handles %s", Operation.RUNTIME_USE.value, ) else: - _logger.info( - "Runtime auth provider: %s override installed for %s", - runtime_mode, - Operation.RUNTIME_USE.value, - ) + runtime_provider = _build_runtime_provider(runtime_mode, _runtime_auth_config) + set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) + _active_providers.append(runtime_provider) + if runtime_mode == "jwt": + _logger.info( + "Runtime auth provider: jwt override installed for %s", + Operation.RUNTIME_USE.value, + ) + else: + _logger.info( + "Runtime auth provider: %s override installed for %s", + runtime_mode, + Operation.RUNTIME_USE.value, + ) async def teardown_auth() -> None: @@ -242,7 +249,7 @@ def _parse_extra_forward_headers(raw: str | None) -> tuple[str, ...]: def _resolve_runtime_mode() -> str: raw = os.environ.get(_RUNTIME_MODE_ENV) if raw is None or not raw.strip(): - return "jwt" if os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) else "api_key" + return "jwt" if os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) else "default" mode = raw.strip().lower() if mode in {"none", "no_auth"}: diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index dc3a1787..20c58aed 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -7,7 +7,6 @@ import httpx import pytest - from agent_control_server.auth_framework.core import ( Operation, Principal, @@ -700,7 +699,6 @@ def test_runtime_token_rejects_naive_upstream_expires_at(): def test_runtime_token_rejects_management_token_passed_to_runtime_verify(): """A token without ``domain=runtime`` must be rejected by runtime verify.""" import jwt - from agent_control_server.auth_framework.runtime_token import ( RuntimeTokenError, verify_runtime_token, @@ -1053,13 +1051,13 @@ def test_build_default_provider_accepts_none_mode(monkeypatch): assert isinstance(auth_config._build_default_provider(), NoAuthProvider) -def test_resolve_runtime_mode_defaults_to_api_key_without_secret(monkeypatch): +def test_resolve_runtime_mode_defaults_to_default_without_secret(monkeypatch): from agent_control_server.auth_framework import config as auth_config monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) - assert auth_config._resolve_runtime_mode() == "api_key" + assert auth_config._resolve_runtime_mode() == "default" def test_resolve_runtime_mode_defaults_to_jwt_with_secret(monkeypatch): @@ -1099,6 +1097,44 @@ def test_configure_runtime_api_key_ignores_jwt_secret(monkeypatch): assert auth_config.runtime_auth_config() is None +def test_configure_runtime_unset_preserves_no_auth_default(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "none") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.RUNTIME_USE), NoAuthProvider) + assert auth_config.runtime_auth_config() is None + + +@pytest.mark.asyncio +async def test_configure_runtime_unset_preserves_http_upstream_default(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "http_upstream") + monkeypatch.setenv("AGENT_CONTROL_AUTH_UPSTREAM_URL", "https://auth.example.test/check") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + try: + auth_config.configure_auth_from_env() + + default_provider = get_authorizer(Operation.CONTROLS_READ) + runtime_provider = get_authorizer(Operation.RUNTIME_USE) + assert isinstance(default_provider, HttpUpstreamAuthProvider) + assert runtime_provider is default_provider + assert auth_config.runtime_auth_config() is None + finally: + await auth_config.teardown_auth() + + @pytest.mark.asyncio async def test_configure_http_upstream_management_with_jwt_runtime(monkeypatch): from agent_control_server.auth_framework import config as auth_config From 5e6811bc13492cd94ff0f1f764c37cc9d15da9fc Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Wed, 6 May 2026 20:38:27 +0530 Subject: [PATCH 11/42] feat(server): migrate /controls + /control-templates onto auth framework Mirrors #204's bindings migration: replaces require_admin_key and router-level require_api_key with require_operation(CONTROLS_*) on every protected route on /controls and on /control-templates/render. Both routers now mount with the non-validating get_api_key_from_header so the framework owns authentication and authorization, with the extractor attached purely so the generated OpenAPI advertises X-API-Key. GET /controls/schema is intentionally left without a require_operation dependency: it returns a static model schema with no tenant state and routing it through the framework would force the upstream provider to handle a meta-only operation that has no permission semantics. POST /controls/validate and POST /control-templates/render are wired to CONTROLS_CREATE rather than CONTROLS_READ. Both exercise the authoring materialization path and exist to support the create / set- data flow; a caller who cannot create controls has no use for the result. Backwards-incompatible for OSS deployments that previously called these routes with non-admin keys; deployments that want the old behavior can override with HeaderAuthProvider(operation_access={...}). Storage namespace continues to come from get_namespace_key, matching the bindings migration in #204. The unified principal-derived cutover across /controls, /policies, /agents, and /evaluation is a follow-up. --- .../generated/funcs/controls-get-schema.ts | 6 + .../funcs/controls-render-template.ts | 4 + .../generated/funcs/controls-validate-data.ts | 5 + sdks/typescript/src/generated/sdk/controls.ts | 15 + .../endpoints/controls.py | 50 ++- server/src/agent_control_server/main.py | 13 +- server/tests/test_controls_auth.py | 365 ++++++++++++++++++ 7 files changed, 445 insertions(+), 13 deletions(-) create mode 100644 server/tests/test_controls_auth.py diff --git a/sdks/typescript/src/generated/funcs/controls-get-schema.ts b/sdks/typescript/src/generated/funcs/controls-get-schema.ts index ca5442bd..a6ea27cd 100644 --- a/sdks/typescript/src/generated/funcs/controls-get-schema.ts +++ b/sdks/typescript/src/generated/funcs/controls-get-schema.ts @@ -27,6 +27,12 @@ import { Result } from "../types/fp.js"; * * @remarks * Return the canonical JSON schema for ControlDefinition. + * + * Intentionally has no ``require_operation`` dependency: the schema is + * static metadata derived from the model class and exposes no tenant + * state. Routing it through the auth framework would force callers + * (and the upstream authorizer) to handle a meta-only operation that + * has no permission semantics. */ export function controlsGetSchema( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/controls-render-template.ts b/sdks/typescript/src/generated/funcs/controls-render-template.ts index a8998d0e..6f5d4d0e 100644 --- a/sdks/typescript/src/generated/funcs/controls-render-template.ts +++ b/sdks/typescript/src/generated/funcs/controls-render-template.ts @@ -31,6 +31,10 @@ import { Result } from "../types/fp.js"; * * @remarks * Render a template-backed control without persisting it. + * + * Authorized as ``controls.create``: rendering is part of the authoring + * flow (the result feeds the create / update endpoints), so a caller + * who cannot create controls has no use for the materialized output. */ export function controlsRenderTemplate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/controls-validate-data.ts b/sdks/typescript/src/generated/funcs/controls-validate-data.ts index 70d9a1f0..f1084887 100644 --- a/sdks/typescript/src/generated/funcs/controls-validate-data.ts +++ b/sdks/typescript/src/generated/funcs/controls-validate-data.ts @@ -32,6 +32,11 @@ import { Result } from "../types/fp.js"; * @remarks * Validate control configuration data without saving it. * + * Authorized as ``controls.create`` rather than ``controls.read``: + * validation exercises the full create / update materialization path + * and exists to support authoring, so a caller who cannot create + * controls has no use for the result. + * * Args: * request: Control configuration data to validate * db: Database session (injected) diff --git a/sdks/typescript/src/generated/sdk/controls.ts b/sdks/typescript/src/generated/sdk/controls.ts index ed3cf8db..28edbda3 100644 --- a/sdks/typescript/src/generated/sdk/controls.ts +++ b/sdks/typescript/src/generated/sdk/controls.ts @@ -25,6 +25,10 @@ export class Controls extends ClientSDK { * * @remarks * Render a template-backed control without persisting it. + * + * Authorized as ``controls.create``: rendering is part of the authoring + * flow (the result feeds the create / update endpoints), so a caller + * who cannot create controls has no use for the materialized output. */ async renderTemplate( request: models.RenderControlTemplateRequest, @@ -110,6 +114,12 @@ export class Controls extends ClientSDK { * * @remarks * Return the canonical JSON schema for ControlDefinition. + * + * Intentionally has no ``require_operation`` dependency: the schema is + * static metadata derived from the model class and exposes no tenant + * state. Routing it through the auth framework would force callers + * (and the upstream authorizer) to handle a meta-only operation that + * has no permission semantics. */ async getSchema( options?: RequestOptions, @@ -126,6 +136,11 @@ export class Controls extends ClientSDK { * @remarks * Validate control configuration data without saving it. * + * Authorized as ``controls.create`` rather than ``controls.read``: + * validation exercises the full create / update materialization path + * and exists to support authoring, so a caller who cannot create + * controls has no use for the result. + * * Args: * request: Control configuration data to validate * db: Database session (injected) diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index 6208652b..a6509a2e 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -33,7 +33,7 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import require_admin_key +from ..auth_framework import Operation, Principal, require_operation from ..db import get_async_db from ..errors import ( APIValidationError, @@ -446,8 +446,14 @@ async def _validate_control_definition( async def render_control_template( request: RenderControlTemplateRequest, db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> RenderControlTemplateResponse: - """Render a template-backed control without persisting it.""" + """Render a template-backed control without persisting it. + + Authorized as ``controls.create``: rendering is part of the authoring + flow (the result feeds the create / update endpoints), so a caller + who cannot create controls has no use for the materialized output. + """ control_def = await _render_and_validate_template_input( TemplateControlInput( template=request.template, @@ -461,13 +467,14 @@ async def render_control_template( @router.put( "", - dependencies=[Depends(require_admin_key)], response_model=CreateControlResponse, summary="Create a new control", response_description="Created control ID", ) async def create_control( - request: CreateControlRequest, db: AsyncSession = Depends(get_async_db) + request: CreateControlRequest, + db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> CreateControlResponse: """ Create a new control with a unique name. @@ -550,7 +557,14 @@ async def create_control( response_description="JSON schema for ControlDefinition", ) async def get_control_schema() -> GetControlSchemaResponse: - """Return the canonical JSON schema for ControlDefinition.""" + """Return the canonical JSON schema for ControlDefinition. + + Intentionally has no ``require_operation`` dependency: the schema is + static metadata derived from the model class and exposes no tenant + state. Routing it through the auth framework would force callers + (and the upstream authorizer) to handle a meta-only operation that + has no permission semantics. + """ return GetControlSchemaResponse( schema=ControlDefinition.model_json_schema(by_alias=True) ) @@ -563,7 +577,9 @@ async def get_control_schema() -> GetControlSchemaResponse: response_description="Control metadata and configuration", ) async def get_control( - control_id: int, db: AsyncSession = Depends(get_async_db) + control_id: int, + db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlResponse: """ Retrieve a control by ID including its name and configuration data. @@ -600,7 +616,9 @@ async def get_control( response_description="Control data payload", ) async def get_control_data( - control_id: int, db: AsyncSession = Depends(get_async_db) + control_id: int, + db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlDataResponse: """ Retrieve the configuration data for a control. @@ -640,6 +658,7 @@ async def list_control_versions( ), limit: int = Query(_DEFAULT_PAGINATION_LIMIT, ge=1, le=_MAX_PAGINATION_LIMIT), db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ListControlVersionsResponse: """List control versions ordered newest-first using cursor-based pagination.""" page = await ControlService(db).list_versions(control_id, cursor=cursor, limit=limit) @@ -673,6 +692,7 @@ async def get_control_version( control_id: int, version_num: int, db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlVersionResponse: """Return a specific control version, including its raw persisted snapshot.""" version = await ControlService(db).get_version_or_404(control_id, version_num) @@ -687,7 +707,6 @@ async def get_control_version( @router.put( "/{control_id}/data", - dependencies=[Depends(require_admin_key)], response_model=SetControlDataResponse, summary="Update control configuration data", response_description="Success confirmation", @@ -696,6 +715,7 @@ async def set_control_data( control_id: int, request: SetControlDataRequest, db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), ) -> SetControlDataResponse: """ Update the configuration data for a control. @@ -758,11 +778,18 @@ async def set_control_data( response_description="Validation result", ) async def validate_control_data( - request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db) + request: ValidateControlDataRequest, + db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> ValidateControlDataResponse: """ Validate control configuration data without saving it. + Authorized as ``controls.create`` rather than ``controls.read``: + validation exercises the full create / update materialization path + and exists to support authoring, so a caller who cannot create + controls has no use for the result. + Args: request: Control configuration data to validate db: Database session (injected) @@ -798,6 +825,7 @@ async def list_controls( execution: str | None = Query(None, description="Filter by execution ('server' or 'sdk')"), tag: str | None = Query(None, description="Filter by tag"), db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ListControlsResponse: """ List all controls with optional filtering and cursor-based pagination. @@ -884,7 +912,6 @@ async def list_controls( @router.delete( "/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=DeleteControlResponse, summary="Delete a control", response_description="Deletion confirmation with dissociation info", @@ -897,6 +924,7 @@ async def delete_control( "If false, fail if control is associated with any policy or agent.", ), db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_DELETE)), ) -> DeleteControlResponse: """ Delete a control by ID. @@ -1035,7 +1063,6 @@ async def delete_control( @router.patch( "/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=PatchControlResponse, summary="Update control metadata", response_description="Updated control information", @@ -1044,6 +1071,7 @@ async def patch_control( control_id: int, request: PatchControlRequest, db: AsyncSession = Depends(get_async_db), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), ) -> PatchControlResponse: """ Update control metadata (name and/or enabled status). diff --git a/server/src/agent_control_server/main.py b/server/src/agent_control_server/main.py index 76416e04..718d7b04 100644 --- a/server/src/agent_control_server/main.py +++ b/server/src/agent_control_server/main.py @@ -273,9 +273,15 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- dependencies=[Depends(require_api_key)], ) app.include_router( + # ``/controls`` CRUD goes through the auth framework on each + # endpoint (``require_operation(Operation.CONTROLS_*)``); see the + # ``control_binding_router`` rationale below for the + # ``get_api_key_from_header`` mounting pattern. The single route on + # this router without ``require_operation`` is ``GET /controls/schema``, + # which is intentionally public meta — see its endpoint docstring. control_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) app.include_router( # The auth framework on each endpoint owns authentication and @@ -300,9 +306,12 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- dependencies=[Depends(get_api_key_from_header)], ) app.include_router( + # Control templates: ``/render`` is on the auth framework via + # ``require_operation(Operation.CONTROLS_CREATE)``; same mounting + # pattern as the controls and control-bindings routers. control_template_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) app.include_router( evaluation_router, diff --git a/server/tests/test_controls_auth.py b/server/tests/test_controls_auth.py new file mode 100644 index 00000000..0357421f --- /dev/null +++ b/server/tests/test_controls_auth.py @@ -0,0 +1,365 @@ +"""HTTP-level coverage for the auth seam on ``/controls`` and +``/control-templates``. + +These tests exercise the wiring of ``require_operation`` on each route +through the default ``HeaderAuthProvider``: read operations require any +valid credential (``CONTROLS_READ`` -> ``AUTHENTICATED``), write +operations require an admin credential +(``CONTROLS_CREATE`` / ``CONTROLS_UPDATE`` / ``CONTROLS_DELETE`` -> +``ADMIN``), and ``GET /controls/schema`` is intentionally outside the +framework so it stays publicly reachable. + +The provider primitives themselves are exercised in +``tests/test_auth_framework.py``; this file focuses on each endpoint +calling the right ``Operation`` so a future change to the operation +mapping is caught at the route level. +""" + +from __future__ import annotations + +import uuid + +import pytest +from fastapi.testclient import TestClient + +from agent_control_server.config import auth_settings + +from .utils import VALID_CONTROL_PAYLOAD + + +_CONTROLS_URL = "/api/v1/controls" +_TEMPLATES_URL = "/api/v1/control-templates" + + +def _create_control(client: TestClient, name: str | None = None) -> int: + payload = { + "name": name or f"control-{uuid.uuid4().hex[:12]}", + "data": VALID_CONTROL_PAYLOAD, + } + resp = client.put(_CONTROLS_URL, json=payload) + assert resp.status_code == 200, resp.text + return int(resp.json()["control_id"]) + + +# --------------------------------------------------------------------------- +# /controls/schema is intentionally public meta — no require_operation. +# --------------------------------------------------------------------------- + + +def test_schema_endpoint_reachable_without_credentials( + unauthenticated_client: TestClient, +) -> None: + # Given: a client that never sends an API key + # When: the schema endpoint is fetched + resp = unauthenticated_client.get(f"{_CONTROLS_URL}/schema") + + # Then: the canonical ControlDefinition schema is returned + assert resp.status_code == 200, resp.text + body = resp.json() + assert "schema" in body + assert isinstance(body["schema"], dict) + + +def test_schema_endpoint_reachable_with_admin_key(client: TestClient) -> None: + # Given: an admin client + # When: the schema endpoint is fetched + resp = client.get(f"{_CONTROLS_URL}/schema") + + # Then: the schema is returned (header is ignored, route is public) + assert resp.status_code == 200, resp.text + + +def test_schema_endpoint_reachable_with_non_admin_key( + non_admin_client: TestClient, +) -> None: + # Given: a non-admin client + # When: the schema endpoint is fetched + resp = non_admin_client.get(f"{_CONTROLS_URL}/schema") + + # Then: the schema is returned + assert resp.status_code == 200, resp.text + + +# --------------------------------------------------------------------------- +# CONTROLS_READ operations: AUTHENTICATED suffices. +# --------------------------------------------------------------------------- + + +def test_non_admin_can_list_controls( + non_admin_client: TestClient, client: TestClient +) -> None: + # Given: an existing control + _create_control(client) + + # When: a non-admin lists controls + resp = non_admin_client.get(_CONTROLS_URL) + + # Then: the list is returned + assert resp.status_code == 200, resp.text + + +def test_non_admin_can_get_control( + non_admin_client: TestClient, client: TestClient +) -> None: + # Given: an existing control + control_id = _create_control(client) + + # When: a non-admin reads it + resp = non_admin_client.get(f"{_CONTROLS_URL}/{control_id}") + + # Then: the control is returned + assert resp.status_code == 200, resp.text + + +def test_non_admin_can_get_control_data( + non_admin_client: TestClient, client: TestClient +) -> None: + # Given: an existing control + control_id = _create_control(client) + + # When: a non-admin reads its data + resp = non_admin_client.get(f"{_CONTROLS_URL}/{control_id}/data") + + # Then: the data is returned + assert resp.status_code == 200, resp.text + + +def test_non_admin_can_list_versions( + non_admin_client: TestClient, client: TestClient +) -> None: + # Given: an existing control with at least one version (creation) + control_id = _create_control(client) + + # When: a non-admin lists versions + resp = non_admin_client.get(f"{_CONTROLS_URL}/{control_id}/versions") + + # Then: the version list is returned + assert resp.status_code == 200, resp.text + + +def test_non_admin_can_get_specific_version( + non_admin_client: TestClient, client: TestClient +) -> None: + # Given: an existing control (version 1 = "created") + control_id = _create_control(client) + + # When: a non-admin reads version 1 + resp = non_admin_client.get(f"{_CONTROLS_URL}/{control_id}/versions/1") + + # Then: the version snapshot is returned + assert resp.status_code == 200, resp.text + + +# --------------------------------------------------------------------------- +# CONTROLS_CREATE / UPDATE / DELETE: ADMIN required. +# --------------------------------------------------------------------------- + + +def test_non_admin_cannot_create_control(non_admin_client: TestClient) -> None: + # When: a non-admin attempts to create + resp = non_admin_client.put( + _CONTROLS_URL, + json={ + "name": f"control-{uuid.uuid4().hex[:12]}", + "data": VALID_CONTROL_PAYLOAD, + }, + ) + + # Then: the request is forbidden by the auth seam + assert resp.status_code == 403, resp.text + + +def test_non_admin_cannot_set_control_data( + non_admin_client: TestClient, client: TestClient +) -> None: + # Given: an existing control + control_id = _create_control(client) + + # When: a non-admin attempts to replace its data + resp = non_admin_client.put( + f"{_CONTROLS_URL}/{control_id}/data", + json={"data": VALID_CONTROL_PAYLOAD}, + ) + + # Then: the request is forbidden + assert resp.status_code == 403, resp.text + + +def test_non_admin_cannot_patch_control( + non_admin_client: TestClient, client: TestClient +) -> None: + # Given: an existing control + control_id = _create_control(client) + + # When: a non-admin attempts to rename it + resp = non_admin_client.patch( + f"{_CONTROLS_URL}/{control_id}", + json={"name": "renamed"}, + ) + + # Then: the request is forbidden + assert resp.status_code == 403, resp.text + + +def test_non_admin_cannot_delete_control( + non_admin_client: TestClient, client: TestClient +) -> None: + # Given: an existing control + control_id = _create_control(client) + + # When: a non-admin attempts to delete it + resp = non_admin_client.delete(f"{_CONTROLS_URL}/{control_id}") + + # Then: the request is forbidden + assert resp.status_code == 403, resp.text + + +def test_non_admin_cannot_validate_control_data( + non_admin_client: TestClient, +) -> None: + """``/controls/validate`` is wired to ``CONTROLS_CREATE`` rather than + ``CONTROLS_READ`` because validation exercises the create / update + materialization path; a caller who cannot create has no use for the + result. This pins that decision so it can't drift to ``READ`` + accidentally. + """ + # When: a non-admin attempts to validate a draft payload + resp = non_admin_client.post( + f"{_CONTROLS_URL}/validate", + json={"data": VALID_CONTROL_PAYLOAD}, + ) + + # Then: validation is admin-only + assert resp.status_code == 403, resp.text + + +def test_non_admin_cannot_render_template(non_admin_client: TestClient) -> None: + """``/control-templates/render`` is wired to ``CONTROLS_CREATE`` for + the same reason as ``/validate``: rendering is part of the authoring + flow. The 422 path is not exercised here — only the auth gate is + asserted, so the request shape need not validate. + """ + # When: a non-admin attempts to render a template + resp = non_admin_client.post( + f"{_TEMPLATES_URL}/render", + json={"template": {}, "template_values": {}}, + ) + + # Then: rendering is admin-only — the auth gate fires before body + # validation reaches the materialization path + assert resp.status_code == 403, resp.text + + +# --------------------------------------------------------------------------- +# Unauthenticated requests are rejected on every framework-protected route. +# --------------------------------------------------------------------------- + + +def test_unauthenticated_cannot_list_controls( + unauthenticated_client: TestClient, +) -> None: + # When: a client without credentials lists controls + resp = unauthenticated_client.get(_CONTROLS_URL) + + # Then: the request is rejected + assert resp.status_code == 401, resp.text + + +def test_unauthenticated_cannot_create_control( + unauthenticated_client: TestClient, +) -> None: + # When: a client without credentials attempts to create + resp = unauthenticated_client.put( + _CONTROLS_URL, + json={ + "name": f"control-{uuid.uuid4().hex[:12]}", + "data": VALID_CONTROL_PAYLOAD, + }, + ) + + # Then: the request is rejected + assert resp.status_code == 401, resp.text + + +def test_unauthenticated_cannot_validate( + unauthenticated_client: TestClient, +) -> None: + # When: a client without credentials attempts to validate + resp = unauthenticated_client.post( + f"{_CONTROLS_URL}/validate", + json={"data": VALID_CONTROL_PAYLOAD}, + ) + + # Then: the request is rejected + assert resp.status_code == 401, resp.text + + +def test_unauthenticated_cannot_render_template( + unauthenticated_client: TestClient, +) -> None: + # When: a client without credentials attempts to render + resp = unauthenticated_client.post( + f"{_TEMPLATES_URL}/render", + json={"template": {}, "template_values": {}}, + ) + + # Then: the request is rejected + assert resp.status_code == 401, resp.text + + +# --------------------------------------------------------------------------- +# No-auth deployment mode: api_key_enabled=False bypasses every gate. +# --------------------------------------------------------------------------- + + +def test_no_auth_mode_allows_writes_without_credentials( + unauthenticated_client: TestClient, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """When ``api_key_enabled`` is False, the ``HeaderAuthProvider`` + short-circuits to a non-admin ``Principal`` for every operation, + including admin-level writes. This pins the "no auth" deployment + path so a future refactor can't silently start enforcing. + """ + # Given: api_key_enabled is False (single-tenant OSS dev mode) + monkeypatch.setattr(auth_settings, "api_key_enabled", False) + + # When: an unauthenticated client creates a control + resp = unauthenticated_client.put( + _CONTROLS_URL, + json={ + "name": f"control-{uuid.uuid4().hex[:12]}", + "data": VALID_CONTROL_PAYLOAD, + }, + ) + + # Then: the create succeeds because auth is disabled at the provider + assert resp.status_code == 200, resp.text + assert "control_id" in resp.json() + + +# --------------------------------------------------------------------------- +# Project-scoped API key deny — pending header forwarding follow-up. +# --------------------------------------------------------------------------- + + +@pytest.mark.skip( + reason=( + "Requires the upstream auth provider to forward an additional " + "configurable credential header. The default forward set is " + "fixed to (X-API-Key, Authorization, Cookie); deployments that " + "use a different credential header can't surface a " + "project-scoped credential to upstream until that becomes " + "configurable. Re-enable when the follow-up PR adds an " + "operator-configurable extra forward list." + ) +) +def test_project_scoped_credential_denied_on_org_scoped_controls() -> None: + """Stub for the deny-test promised by the upstream provider's + response contract: a project-scoped credential calling an + org-scoped operation (``controls.*``) should resolve to a 403 from + the upstream. The end-to-end path is unreachable today because the + provider's credential-forward list is not configurable; tracked as + the next follow-up after this PR. + """ + pytest.fail("test stub — see skip reason") From 0a2092d7a28455c2fceabbbd14f8a7a85a8152b8 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Wed, 6 May 2026 21:19:03 +0530 Subject: [PATCH 12/42] fix(server): keep public docstrings API-level on migrated controls routes Move auth-framework rationale on /controls/schema, /controls/validate, and /control-templates/render from route docstrings into normal code comments. The docstrings flow into the generated TypeScript SDK as public API documentation, so internal terminology like ``require_operation`` and "upstream authorizer" should not appear there. Function-level comments preserve the rationale for readers of the source. Also remove the skipped placeholder test for the project-scoped credential deny scenario; that scenario depends on a deployment-side provider configuration that is not part of the OSS server, so tracking it as a permanent skipped test in this repo was the wrong home for it. Regenerate the TypeScript SDK to drop the leaked rationale lines. --- .../generated/funcs/controls-get-schema.ts | 6 -- .../funcs/controls-render-template.ts | 4 -- .../generated/funcs/controls-validate-data.ts | 5 -- sdks/typescript/src/generated/sdk/controls.ts | 15 ----- .../endpoints/controls.py | 24 ++----- server/src/agent_control_server/main.py | 11 +--- server/tests/test_controls_auth.py | 63 ++----------------- 7 files changed, 13 insertions(+), 115 deletions(-) diff --git a/sdks/typescript/src/generated/funcs/controls-get-schema.ts b/sdks/typescript/src/generated/funcs/controls-get-schema.ts index a6ea27cd..ca5442bd 100644 --- a/sdks/typescript/src/generated/funcs/controls-get-schema.ts +++ b/sdks/typescript/src/generated/funcs/controls-get-schema.ts @@ -27,12 +27,6 @@ import { Result } from "../types/fp.js"; * * @remarks * Return the canonical JSON schema for ControlDefinition. - * - * Intentionally has no ``require_operation`` dependency: the schema is - * static metadata derived from the model class and exposes no tenant - * state. Routing it through the auth framework would force callers - * (and the upstream authorizer) to handle a meta-only operation that - * has no permission semantics. */ export function controlsGetSchema( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/controls-render-template.ts b/sdks/typescript/src/generated/funcs/controls-render-template.ts index 6f5d4d0e..a8998d0e 100644 --- a/sdks/typescript/src/generated/funcs/controls-render-template.ts +++ b/sdks/typescript/src/generated/funcs/controls-render-template.ts @@ -31,10 +31,6 @@ import { Result } from "../types/fp.js"; * * @remarks * Render a template-backed control without persisting it. - * - * Authorized as ``controls.create``: rendering is part of the authoring - * flow (the result feeds the create / update endpoints), so a caller - * who cannot create controls has no use for the materialized output. */ export function controlsRenderTemplate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/controls-validate-data.ts b/sdks/typescript/src/generated/funcs/controls-validate-data.ts index f1084887..70d9a1f0 100644 --- a/sdks/typescript/src/generated/funcs/controls-validate-data.ts +++ b/sdks/typescript/src/generated/funcs/controls-validate-data.ts @@ -32,11 +32,6 @@ import { Result } from "../types/fp.js"; * @remarks * Validate control configuration data without saving it. * - * Authorized as ``controls.create`` rather than ``controls.read``: - * validation exercises the full create / update materialization path - * and exists to support authoring, so a caller who cannot create - * controls has no use for the result. - * * Args: * request: Control configuration data to validate * db: Database session (injected) diff --git a/sdks/typescript/src/generated/sdk/controls.ts b/sdks/typescript/src/generated/sdk/controls.ts index 28edbda3..ed3cf8db 100644 --- a/sdks/typescript/src/generated/sdk/controls.ts +++ b/sdks/typescript/src/generated/sdk/controls.ts @@ -25,10 +25,6 @@ export class Controls extends ClientSDK { * * @remarks * Render a template-backed control without persisting it. - * - * Authorized as ``controls.create``: rendering is part of the authoring - * flow (the result feeds the create / update endpoints), so a caller - * who cannot create controls has no use for the materialized output. */ async renderTemplate( request: models.RenderControlTemplateRequest, @@ -114,12 +110,6 @@ export class Controls extends ClientSDK { * * @remarks * Return the canonical JSON schema for ControlDefinition. - * - * Intentionally has no ``require_operation`` dependency: the schema is - * static metadata derived from the model class and exposes no tenant - * state. Routing it through the auth framework would force callers - * (and the upstream authorizer) to handle a meta-only operation that - * has no permission semantics. */ async getSchema( options?: RequestOptions, @@ -136,11 +126,6 @@ export class Controls extends ClientSDK { * @remarks * Validate control configuration data without saving it. * - * Authorized as ``controls.create`` rather than ``controls.read``: - * validation exercises the full create / update materialization path - * and exists to support authoring, so a caller who cannot create - * controls has no use for the result. - * * Args: * request: Control configuration data to validate * db: Database session (injected) diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index a6509a2e..fcb7cb18 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -443,17 +443,13 @@ async def _validate_control_definition( summary="Render a control template preview", response_description="Rendered control preview", ) +# Rendering is part of the authoring flow, so require create access. async def render_control_template( request: RenderControlTemplateRequest, db: AsyncSession = Depends(get_async_db), _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> RenderControlTemplateResponse: - """Render a template-backed control without persisting it. - - Authorized as ``controls.create``: rendering is part of the authoring - flow (the result feeds the create / update endpoints), so a caller - who cannot create controls has no use for the materialized output. - """ + """Render a template-backed control without persisting it.""" control_def = await _render_and_validate_template_input( TemplateControlInput( template=request.template, @@ -556,15 +552,9 @@ async def create_control( summary="Get control definition JSON schema", response_description="JSON schema for ControlDefinition", ) +# Public schema metadata: no tenant state, no auth operation. async def get_control_schema() -> GetControlSchemaResponse: - """Return the canonical JSON schema for ControlDefinition. - - Intentionally has no ``require_operation`` dependency: the schema is - static metadata derived from the model class and exposes no tenant - state. Routing it through the auth framework would force callers - (and the upstream authorizer) to handle a meta-only operation that - has no permission semantics. - """ + """Return the canonical JSON schema for ControlDefinition.""" return GetControlSchemaResponse( schema=ControlDefinition.model_json_schema(by_alias=True) ) @@ -777,6 +767,7 @@ async def set_control_data( summary="Validate control configuration", response_description="Validation result", ) +# Validation uses the authoring path, so require create access. async def validate_control_data( request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db), @@ -785,11 +776,6 @@ async def validate_control_data( """ Validate control configuration data without saving it. - Authorized as ``controls.create`` rather than ``controls.read``: - validation exercises the full create / update materialization path - and exists to support authoring, so a caller who cannot create - controls has no use for the result. - Args: request: Control configuration data to validate db: Database session (injected) diff --git a/server/src/agent_control_server/main.py b/server/src/agent_control_server/main.py index 718d7b04..bc1bf04b 100644 --- a/server/src/agent_control_server/main.py +++ b/server/src/agent_control_server/main.py @@ -273,12 +273,7 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- dependencies=[Depends(require_api_key)], ) app.include_router( - # ``/controls`` CRUD goes through the auth framework on each - # endpoint (``require_operation(Operation.CONTROLS_*)``); see the - # ``control_binding_router`` rationale below for the - # ``get_api_key_from_header`` mounting pattern. The single route on - # this router without ``require_operation`` is ``GET /controls/schema``, - # which is intentionally public meta — see its endpoint docstring. + # Endpoint dependencies handle auth; this advertises X-API-Key. control_router, prefix=api_v1_prefix, dependencies=[Depends(get_api_key_from_header)], @@ -306,9 +301,7 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- dependencies=[Depends(get_api_key_from_header)], ) app.include_router( - # Control templates: ``/render`` is on the auth framework via - # ``require_operation(Operation.CONTROLS_CREATE)``; same mounting - # pattern as the controls and control-bindings routers. + # Endpoint dependencies handle auth; this advertises X-API-Key. control_template_router, prefix=api_v1_prefix, dependencies=[Depends(get_api_key_from_header)], diff --git a/server/tests/test_controls_auth.py b/server/tests/test_controls_auth.py index 0357421f..1a2af21f 100644 --- a/server/tests/test_controls_auth.py +++ b/server/tests/test_controls_auth.py @@ -1,19 +1,4 @@ -"""HTTP-level coverage for the auth seam on ``/controls`` and -``/control-templates``. - -These tests exercise the wiring of ``require_operation`` on each route -through the default ``HeaderAuthProvider``: read operations require any -valid credential (``CONTROLS_READ`` -> ``AUTHENTICATED``), write -operations require an admin credential -(``CONTROLS_CREATE`` / ``CONTROLS_UPDATE`` / ``CONTROLS_DELETE`` -> -``ADMIN``), and ``GET /controls/schema`` is intentionally outside the -framework so it stays publicly reachable. - -The provider primitives themselves are exercised in -``tests/test_auth_framework.py``; this file focuses on each endpoint -calling the right ``Operation`` so a future change to the operation -mapping is caught at the route level. -""" +"""HTTP-level auth coverage for ``/controls`` and ``/control-templates``.""" from __future__ import annotations @@ -42,7 +27,7 @@ def _create_control(client: TestClient, name: str | None = None) -> int: # --------------------------------------------------------------------------- -# /controls/schema is intentionally public meta — no require_operation. +# /controls/schema is intentionally public metadata. # --------------------------------------------------------------------------- @@ -165,7 +150,7 @@ def test_non_admin_cannot_create_control(non_admin_client: TestClient) -> None: }, ) - # Then: the request is forbidden by the auth seam + # Then: the request is forbidden assert resp.status_code == 403, resp.text @@ -217,12 +202,7 @@ def test_non_admin_cannot_delete_control( def test_non_admin_cannot_validate_control_data( non_admin_client: TestClient, ) -> None: - """``/controls/validate`` is wired to ``CONTROLS_CREATE`` rather than - ``CONTROLS_READ`` because validation exercises the create / update - materialization path; a caller who cannot create has no use for the - result. This pins that decision so it can't drift to ``READ`` - accidentally. - """ + """``/controls/validate`` requires ``CONTROLS_CREATE``.""" # When: a non-admin attempts to validate a draft payload resp = non_admin_client.post( f"{_CONTROLS_URL}/validate", @@ -234,19 +214,14 @@ def test_non_admin_cannot_validate_control_data( def test_non_admin_cannot_render_template(non_admin_client: TestClient) -> None: - """``/control-templates/render`` is wired to ``CONTROLS_CREATE`` for - the same reason as ``/validate``: rendering is part of the authoring - flow. The 422 path is not exercised here — only the auth gate is - asserted, so the request shape need not validate. - """ + """``/control-templates/render`` requires ``CONTROLS_CREATE``.""" # When: a non-admin attempts to render a template resp = non_admin_client.post( f"{_TEMPLATES_URL}/render", json={"template": {}, "template_values": {}}, ) - # Then: rendering is admin-only — the auth gate fires before body - # validation reaches the materialization path + # Then: rendering is admin-only assert resp.status_code == 403, resp.text @@ -337,29 +312,3 @@ def test_no_auth_mode_allows_writes_without_credentials( assert resp.status_code == 200, resp.text assert "control_id" in resp.json() - -# --------------------------------------------------------------------------- -# Project-scoped API key deny — pending header forwarding follow-up. -# --------------------------------------------------------------------------- - - -@pytest.mark.skip( - reason=( - "Requires the upstream auth provider to forward an additional " - "configurable credential header. The default forward set is " - "fixed to (X-API-Key, Authorization, Cookie); deployments that " - "use a different credential header can't surface a " - "project-scoped credential to upstream until that becomes " - "configurable. Re-enable when the follow-up PR adds an " - "operator-configurable extra forward list." - ) -) -def test_project_scoped_credential_denied_on_org_scoped_controls() -> None: - """Stub for the deny-test promised by the upstream provider's - response contract: a project-scoped credential calling an - org-scoped operation (``controls.*``) should resolve to a 403 from - the upstream. The end-to-end path is unreachable today because the - provider's credential-forward list is not configurable; tracked as - the next follow-up after this PR. - """ - pytest.fail("test stub — see skip reason") From e75cbb74df3c3bb3b480760453bfc7af7bc9746f Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 8 May 2026 22:23:50 +0530 Subject: [PATCH 13/42] docs(server): keep binding auth comments generic --- .../endpoints/control_bindings.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/server/src/agent_control_server/endpoints/control_bindings.py b/server/src/agent_control_server/endpoints/control_bindings.py index 18cb75b4..92798ae1 100644 --- a/server/src/agent_control_server/endpoints/control_bindings.py +++ b/server/src/agent_control_server/endpoints/control_bindings.py @@ -36,13 +36,11 @@ async def _binding_body_context(request: Request) -> dict[str, Any]: - """Surface ``(target_type, target_id)`` to the authorizer's context. + """Surface ``(target_type, target_id)`` to the authorization context. The body-bearing binding endpoints carry the target identifiers in - the request payload. Upstream authorizers that resolve the target's - owning project (e.g., Galileo's ``check_management_access``) need - those identifiers to make a project-level decision; without them the - upstream returns 400. + the request payload. Authorization providers can use those + identifiers when a request needs target-scoped access checks. FastAPI caches the parsed body, so the endpoint's own Pydantic request model still binds normally. @@ -60,13 +58,12 @@ async def _binding_body_context(request: Request) -> dict[str, Any]: async def _binding_list_context(request: Request) -> dict[str, Any]: - """Surface optional target query parameters to the authorizer. + """Surface optional target query parameters to authorization context. When the GET list endpoint is called with ``target_type`` and ``target_id`` query params, the request is target-scoped and the - upstream needs the identifiers to make a project-level decision. - When neither is present the request is namespace-wide and forwards - no target context (upstream may then reject if it requires one). + request context includes those identifiers. When neither is present + the request is namespace-wide and forwards no target context. """ target_type = request.query_params.get("target_type") target_id = request.query_params.get("target_id") From 90d72affc1a37a341d8c2e6277716fb9371ed9ee Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Thu, 7 May 2026 20:14:39 +0530 Subject: [PATCH 14/42] feat(server): add runtime auth namespace cutover Add explicit none, api_key, and jwt runtime auth modes, including a generic no-auth provider. Move controls, bindings, policies, agents, and evaluation storage lookups onto principal namespace scoping. Cover auth mode selection and principal namespace isolation with server tests. --- .../auth_framework/__init__.py | 7 +- .../auth_framework/config.py | 120 +++++++++++---- .../auth_framework/core.py | 16 +- .../auth_framework/providers/__init__.py | 2 + .../auth_framework/providers/header.py | 39 +++-- .../auth_framework/providers/http_upstream.py | 6 +- .../auth_framework/providers/local_jwt.py | 2 +- .../auth_framework/providers/no_auth.py | 29 ++++ .../agent_control_server/endpoints/agents.py | 92 +++++++----- .../agent_control_server/endpoints/auth.py | 11 +- .../endpoints/control_bindings.py | 47 +++--- .../endpoints/controls.py | 87 +++++++---- .../endpoints/evaluation.py | 26 +++- .../endpoints/policies.py | 77 +++++++--- server/src/agent_control_server/main.py | 19 ++- .../agent_control_server/services/controls.py | 140 +++++++++++++---- server/tests/test_auth_framework.py | 96 +++++++++++- server/tests/test_controls_additional.py | 15 +- server/tests/test_controls_auth.py | 28 ++-- server/tests/test_principal_namespace_flow.py | 141 ++++++++++++++++++ server/tests/test_target_merged_contract.py | 6 +- 21 files changed, 753 insertions(+), 253 deletions(-) create mode 100644 server/src/agent_control_server/auth_framework/providers/no_auth.py create mode 100644 server/tests/test_principal_namespace_flow.py diff --git a/server/src/agent_control_server/auth_framework/__init__.py b/server/src/agent_control_server/auth_framework/__init__.py index 57368d57..0333f2cc 100644 --- a/server/src/agent_control_server/auth_framework/__init__.py +++ b/server/src/agent_control_server/auth_framework/__init__.py @@ -2,10 +2,9 @@ Endpoints declare an :class:`Operation` they need; an installed :class:`RequestAuthorizer` decides whether the request is allowed and -returns the resulting :class:`Principal`. Two providers ship in-tree: -:class:`HeaderAuthProvider` (uses local credential checks) and -:class:`HttpUpstreamAuthProvider` (delegates to a configurable -upstream HTTP service). +returns the resulting :class:`Principal`. Providers ship in-tree for +disabled auth, local credential checks, upstream HTTP authorization, +and local runtime-JWT verification. """ from .core import ( diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index 92107b0e..c8f428dc 100644 --- a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -8,15 +8,19 @@ - **Default flow** (everything except runtime). One authorizer handles every operation that does not have a specific override: - :class:`HeaderAuthProvider` (local credentials) or + :class:`NoAuthProvider` (no credentials), + :class:`HeaderAuthProvider` (local API keys), or :class:`HttpUpstreamAuthProvider` (forwards to a configurable URL). -- **Runtime flow.** When ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is - configured, :class:`LocalJwtVerifyProvider` is registered as the - override for :data:`Operation.RUNTIME_USE`; the - ``runtime.token_exchange`` operation continues to flow through the - default authorizer because the exchange itself is shaped like a - management call (forward credential, get grant). Without the secret, - no runtime override is installed. +- **Runtime flow.** ``AGENT_CONTROL_RUNTIME_AUTH_MODE`` selects the + override for :data:`Operation.RUNTIME_USE`: ``none`` uses + :class:`NoAuthProvider`, ``api_key`` uses + :class:`HeaderAuthProvider`, and ``jwt`` uses + :class:`LocalJwtVerifyProvider`. When the mode is unset, startup + preserves historical behavior by selecting ``jwt`` if + ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set, otherwise ``api_key``. + The ``runtime.token_exchange`` operation continues to flow through + the default authorizer because the exchange itself is shaped like a + management call (forward credential, get grant). """ from __future__ import annotations @@ -30,6 +34,7 @@ HeaderAuthProvider, HttpUpstreamAuthProvider, LocalJwtVerifyProvider, + NoAuthProvider, ) from .providers.http_upstream import HttpUpstreamConfig @@ -43,6 +48,7 @@ _UPSTREAM_TOKEN_HEADER_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN_HEADER" # Runtime flow. +_RUNTIME_MODE_ENV = "AGENT_CONTROL_RUNTIME_AUTH_MODE" _RUNTIME_TOKEN_SECRET_ENV = "AGENT_CONTROL_RUNTIME_TOKEN_SECRET" _RUNTIME_TOKEN_TTL_ENV = "AGENT_CONTROL_RUNTIME_TOKEN_TTL_SECONDS" _DEFAULT_RUNTIME_TOKEN_TTL_SECONDS = 300 @@ -80,15 +86,19 @@ def configure_auth_from_env() -> None: Default flow: - - ``AGENT_CONTROL_AUTH_MODE=header`` (default): :class:`HeaderAuthProvider`. + - ``AGENT_CONTROL_AUTH_MODE=none``: :class:`NoAuthProvider`. + - ``AGENT_CONTROL_AUTH_MODE=api_key`` (default): :class:`HeaderAuthProvider`. + ``header`` remains accepted as a backwards-compatible alias. - ``AGENT_CONTROL_AUTH_MODE=http_upstream``: :class:`HttpUpstreamAuthProvider` pointed at ``AGENT_CONTROL_AUTH_UPSTREAM_URL``. Runtime flow: - - When ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set, register - :class:`LocalJwtVerifyProvider` as an override for - :data:`Operation.RUNTIME_USE`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=none``: :class:`NoAuthProvider`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=api_key`` (default when no runtime + token secret is configured): :class:`HeaderAuthProvider`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=jwt`` (default when a runtime token + secret is configured): :class:`LocalJwtVerifyProvider`. Clears any previously-installed default and operation overrides before installing fresh ones, so reconfiguration cannot leave @@ -101,27 +111,27 @@ def configure_auth_from_env() -> None: global _runtime_auth_config clear_authorizers() _active_providers.clear() - _runtime_auth_config = _load_runtime_auth_config() + runtime_mode = _resolve_runtime_mode() + _runtime_auth_config = ( + _load_runtime_auth_config(require_secret=True) if runtime_mode == "jwt" else None + ) default = _build_default_provider() set_authorizer(default) _active_providers.append(default) - if _runtime_auth_config is not None: - runtime_provider = LocalJwtVerifyProvider(secret=_runtime_auth_config.secret) - set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) - _active_providers.append(runtime_provider) + runtime_provider = _build_runtime_provider(runtime_mode, _runtime_auth_config) + set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) + _active_providers.append(runtime_provider) + if runtime_mode == "jwt": _logger.info( - "Runtime auth enabled: LocalJwtVerifyProvider override installed for %s", + "Runtime auth provider: jwt override installed for %s", Operation.RUNTIME_USE.value, ) else: - _logger.warning( - "Runtime auth disabled (%s not set); %s falls through to the " - "default authorizer, which may grant any authenticated credential. " - "Set the runtime token secret to bind runtime calls to a " - "short-lived target-scoped JWT.", - _RUNTIME_TOKEN_SECRET_ENV, + _logger.info( + "Runtime auth provider: %s override installed for %s", + runtime_mode, Operation.RUNTIME_USE.value, ) @@ -172,9 +182,12 @@ def set_runtime_auth_config(config: RuntimeAuthConfig | None) -> None: def _build_default_provider() -> RequestAuthorizer: - mode = os.environ.get(_MODE_ENV, "header").strip().lower() - if mode == "header": - _logger.info("Default auth provider: header (local credentials)") + mode = os.environ.get(_MODE_ENV, "api_key").strip().lower() + if mode in {"none", "no_auth"}: + _logger.info("Default auth provider: none") + return NoAuthProvider() + if mode in {"api_key", "header"}: + _logger.info("Default auth provider: api_key (local credentials)") return HeaderAuthProvider() if mode == "http_upstream": url = os.environ.get(_UPSTREAM_URL_ENV) @@ -192,19 +205,60 @@ def _build_default_provider() -> RequestAuthorizer: service_token_header=token_header, ) ) - raise RuntimeError(f"Unknown {_MODE_ENV}={mode!r}; expected 'header' or 'http_upstream'.") + raise RuntimeError( + f"Unknown {_MODE_ENV}={mode!r}; expected 'none', 'api_key', or 'http_upstream'." + ) + + +def _resolve_runtime_mode() -> str: + raw = os.environ.get(_RUNTIME_MODE_ENV) + if raw is None or not raw.strip(): + return "jwt" if os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) else "api_key" + + mode = raw.strip().lower() + if mode in {"none", "no_auth"}: + return "none" + if mode in {"api_key", "header"}: + return "api_key" + if mode == "jwt": + return mode + raise RuntimeError( + f"Unknown {_RUNTIME_MODE_ENV}={mode!r}; expected 'none', 'api_key', or 'jwt'." + ) + + +def _build_runtime_provider( + mode: str, + config: RuntimeAuthConfig | None, +) -> RequestAuthorizer: + if mode == "none": + return NoAuthProvider() + if mode == "api_key": + return HeaderAuthProvider() + if mode == "jwt": + if config is None: + raise RuntimeError(f"{_RUNTIME_MODE_ENV}=jwt but runtime auth config is missing.") + return LocalJwtVerifyProvider(secret=config.secret) + raise RuntimeError( + f"Unknown runtime auth mode {mode!r}; expected 'none', 'api_key', or 'jwt'." + ) -def _load_runtime_auth_config() -> RuntimeAuthConfig | None: +def _load_runtime_auth_config(*, require_secret: bool = False) -> RuntimeAuthConfig | None: """Parse, validate, and return the runtime-auth config from env. - Returns ``None`` when no runtime secret is configured. Raises - ``RuntimeError`` when the secret is too short or the TTL is invalid - so misconfiguration surfaces at startup, not on the first - request-time mint. + Returns ``None`` when no runtime secret is configured and + ``require_secret`` is false. Raises ``RuntimeError`` when the + secret is required, too short, or the TTL is invalid so + misconfiguration surfaces at startup, not on the first request-time + mint. """ secret = os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) if not secret: + if require_secret: + raise RuntimeError( + f"{_RUNTIME_MODE_ENV}=jwt requires {_RUNTIME_TOKEN_SECRET_ENV} to be set." + ) return None if len(secret.encode("utf-8")) < _RUNTIME_TOKEN_SECRET_MIN_BYTES: raise RuntimeError( diff --git a/server/src/agent_control_server/auth_framework/core.py b/server/src/agent_control_server/auth_framework/core.py index 9299b441..e0ea6da7 100644 --- a/server/src/agent_control_server/auth_framework/core.py +++ b/server/src/agent_control_server/auth_framework/core.py @@ -42,14 +42,21 @@ class Operation(StrEnum): CONTROL_BINDINGS_READ = "control_bindings.read" CONTROL_BINDINGS_WRITE = "control_bindings.write" - # Runtime token exchange — wired on the exchange endpoint. + # Runtime token exchange - wired on the exchange endpoint. RUNTIME_TOKEN_EXCHANGE = "runtime.token_exchange" - # Reserved for follow-up migrations; not yet wired on endpoints. CONTROLS_READ = "controls.read" CONTROLS_CREATE = "controls.create" CONTROLS_UPDATE = "controls.update" CONTROLS_DELETE = "controls.delete" + POLICIES_READ = "policies.read" + POLICIES_CREATE = "policies.create" + POLICIES_UPDATE = "policies.update" + POLICIES_DELETE = "policies.delete" + AGENTS_READ = "agents.read" + AGENTS_CREATE = "agents.create" + AGENTS_UPDATE = "agents.update" + AGENTS_DELETE = "agents.delete" RUNTIME_USE = "runtime.use" @@ -61,8 +68,7 @@ class Principal: namespace_key: The namespace the request runs in. Endpoints use this to scope every read and write. is_admin: Whether the caller has admin privileges in the - current namespace. Mostly informational for endpoints that - still gate on the legacy admin-key contract. + current namespace. caller_id: Opaque, provider-supplied identifier for the caller (e.g., a key fingerprint or user id). Useful for audit logging; never echo back to clients. @@ -122,7 +128,7 @@ def set_authorizer( Without ``operation``, this becomes the default authorizer used by every operation that does not have a specific override. With - ``operation``, it overrides the default for that operation only — + ``operation``, it overrides the default for that operation only - used to route a different family (e.g., runtime) through a different provider. diff --git a/server/src/agent_control_server/auth_framework/providers/__init__.py b/server/src/agent_control_server/auth_framework/providers/__init__.py index e8a68486..ad5d6b38 100644 --- a/server/src/agent_control_server/auth_framework/providers/__init__.py +++ b/server/src/agent_control_server/auth_framework/providers/__init__.py @@ -3,10 +3,12 @@ from .header import AccessLevel, HeaderAuthProvider from .http_upstream import HttpUpstreamAuthProvider from .local_jwt import LocalJwtVerifyProvider +from .no_auth import NoAuthProvider __all__ = [ "AccessLevel", "HeaderAuthProvider", "HttpUpstreamAuthProvider", "LocalJwtVerifyProvider", + "NoAuthProvider", ] diff --git a/server/src/agent_control_server/auth_framework/providers/header.py b/server/src/agent_control_server/auth_framework/providers/header.py index f76936a1..228ec443 100644 --- a/server/src/agent_control_server/auth_framework/providers/header.py +++ b/server/src/agent_control_server/auth_framework/providers/header.py @@ -1,23 +1,14 @@ """Default :class:`RequestAuthorizer` that uses local credentials only. -Resolves the namespace from a header (or falls back to -``DEFAULT_NAMESPACE_KEY``) and enforces a per-operation access level -using the legacy API-key + session-cookie credential check from -:mod:`agent_control_server.auth`. Behavior matches the pre-framework -local auth path verbatim: +Returns ``DEFAULT_NAMESPACE_KEY`` and enforces a per-operation access +level using the local API-key + session-cookie credential check from +:mod:`agent_control_server.auth`: - ``ADMIN`` operations require an admin key (or admin session). - ``AUTHENTICATED`` operations require any valid credential. - ``PUBLIC`` operations are open. -- When ``api_key_enabled`` is ``False`` (no-auth mode), every - operation succeeds with a non-admin :class:`Principal` — preserved - by the underlying credential check. - -The header lookup is wired but currently inert: the provider always -returns the default namespace because non-binding write endpoints -still hardcode it. The header is kept here so a follow-up that -threads namespace resolution through the rest of the API can flip it -on without changing the provider contract. +- When the underlying local credential layer is disabled, every + operation succeeds with a non-admin :class:`Principal`. """ from __future__ import annotations @@ -51,6 +42,14 @@ class AccessLevel(Enum): Operation.CONTROLS_CREATE: AccessLevel.ADMIN, Operation.CONTROLS_UPDATE: AccessLevel.ADMIN, Operation.CONTROLS_DELETE: AccessLevel.ADMIN, + Operation.POLICIES_READ: AccessLevel.AUTHENTICATED, + Operation.POLICIES_CREATE: AccessLevel.ADMIN, + Operation.POLICIES_UPDATE: AccessLevel.ADMIN, + Operation.POLICIES_DELETE: AccessLevel.ADMIN, + Operation.AGENTS_READ: AccessLevel.AUTHENTICATED, + Operation.AGENTS_CREATE: AccessLevel.AUTHENTICATED, + Operation.AGENTS_UPDATE: AccessLevel.ADMIN, + Operation.AGENTS_DELETE: AccessLevel.ADMIN, Operation.RUNTIME_TOKEN_EXCHANGE: AccessLevel.AUTHENTICATED, Operation.RUNTIME_USE: AccessLevel.AUTHENTICATED, } @@ -60,7 +59,7 @@ class HeaderAuthProvider(RequestAuthorizer): """Default authorizer. For each operation's configured access level, validates the - request's credentials via the legacy local check; on success, + request's credentials via the local credential check; on success, returns a :class:`Principal` scoped to the resolved namespace. """ @@ -100,8 +99,7 @@ async def authorize( ) # Runtime token exchange returns a normalized scope grant so the # exchange endpoint can require ``runtime.use`` uniformly across - # providers; an upstream that explicitly grants no scopes ends - # up with an empty tuple and is rejected. + # providers. scopes: tuple[str, ...] = ( (Operation.RUNTIME_USE.value,) if operation is Operation.RUNTIME_TOKEN_EXCHANGE else () ) @@ -113,10 +111,7 @@ async def authorize( ) def _resolve_namespace_key(self, request: Request) -> str: - # The provider always returns the default namespace because - # non-binding write endpoints still hardcode it; serving - # anything else here would create rows the rest of the API - # cannot find. The branch is preserved so a future change can - # lift the lock without touching the provider contract. + # Local credentials do not carry namespace metadata. Providers + # that resolve a namespace can return a different principal. del request return self._default_namespace_key diff --git a/server/src/agent_control_server/auth_framework/providers/http_upstream.py b/server/src/agent_control_server/auth_framework/providers/http_upstream.py index a97a3de8..8d5c850c 100644 --- a/server/src/agent_control_server/auth_framework/providers/http_upstream.py +++ b/server/src/agent_control_server/auth_framework/providers/http_upstream.py @@ -67,8 +67,8 @@ class _UpstreamGrant(BaseModel): """Strict schema for the upstream authorization-service response. Unknown fields are tolerated (so the upstream can evolve), but every - *known* field is type-checked. A wrong type on any field — or a - half-supplied target binding — causes the provider to fail closed + *known* field is type-checked. A wrong type on any field - or a + half-supplied target binding - causes the provider to fail closed with a 502. """ @@ -108,7 +108,7 @@ def _target_must_be_paired(self) -> _UpstreamGrant: A target is meaningful only as a ``(target_type, target_id)`` pair; allowing one side without the other would let a malformed grant pass and the exchange endpoint mint a token for the - request's value of the missing half — outside the upstream's + request's value of the missing half - outside the upstream's intended authorization. """ if (self.target_type is None) != (self.target_id is None): diff --git a/server/src/agent_control_server/auth_framework/providers/local_jwt.py b/server/src/agent_control_server/auth_framework/providers/local_jwt.py index bb448503..8620d3b6 100644 --- a/server/src/agent_control_server/auth_framework/providers/local_jwt.py +++ b/server/src/agent_control_server/auth_framework/providers/local_jwt.py @@ -6,7 +6,7 @@ returns a :class:`Principal` carrying the bound target. When a ``context_builder`` on the dependency surfaces ``target_type`` / ``target_id``, the provider also enforces that they match the token's -binding — runtime endpoints get the request-target check for free. +binding - runtime endpoints get the request-target check for free. """ from __future__ import annotations diff --git a/server/src/agent_control_server/auth_framework/providers/no_auth.py b/server/src/agent_control_server/auth_framework/providers/no_auth.py new file mode 100644 index 00000000..509ca4f3 --- /dev/null +++ b/server/src/agent_control_server/auth_framework/providers/no_auth.py @@ -0,0 +1,29 @@ +"""Authorizer for deployments that intentionally disable authentication.""" + +from __future__ import annotations + +from typing import Any + +from fastapi import Request + +from ...models import DEFAULT_NAMESPACE_KEY +from ..core import Operation, Principal, RequestAuthorizer + + +class NoAuthProvider(RequestAuthorizer): + """Allows every operation and returns the default namespace.""" + + def __init__(self, *, default_namespace_key: str = DEFAULT_NAMESPACE_KEY) -> None: + self._default_namespace_key = default_namespace_key + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del request, context + scopes: tuple[str, ...] = ( + (Operation.RUNTIME_USE.value,) if operation is Operation.RUNTIME_TOKEN_EXCHANGE else () + ) + return Principal(namespace_key=self._default_namespace_key, scopes=scopes) diff --git a/server/src/agent_control_server/endpoints/agents.py b/server/src/agent_control_server/endpoints/agents.py index 034ae35f..ac099911 100644 --- a/server/src/agent_control_server/endpoints/agents.py +++ b/server/src/agent_control_server/endpoints/agents.py @@ -36,7 +36,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import RequireAPIKey, require_admin_key +from ..auth_framework import Operation, Principal, require_operation from ..db import get_async_db from ..errors import ( APIValidationError, @@ -53,7 +53,6 @@ Policy, agent_policies, ) -from ..namespace import get_namespace_key from ..services.agent_names import normalize_agent_name_or_422 from ..services.controls import ( AgentControlEnabledState, @@ -112,7 +111,7 @@ def _validate_controls_for_agent(agent: Agent, controls: list[Control]) -> list[ agent_evaluators = {e.name: e for e in (agent_data.evaluators or [])} for control in controls: - # Skip unrendered template controls — they have no evaluators to validate. + # Skip unrendered template controls - they have no evaluators to validate. if ( isinstance(control.data, dict) and control.data.get("template") is not None @@ -286,7 +285,7 @@ async def list_agents( limit: int = _DEFAULT_PAGINATION_LIMIT, name: str | None = None, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> ListAgentsResponse: """ List all registered agents with cursor-based pagination. @@ -300,11 +299,13 @@ async def list_agents( limit: Pagination limit (default 20, max 100) name: Optional name filter (case-insensitive partial match) db: Database session (injected) - namespace_key: Resolved namespace for the request + principal: Authorized request principal Returns: ListAgentsResponse with agent summaries and pagination info """ + namespace_key = principal.namespace_key + # Clamp limit limit = min(max(1, limit), _MAX_PAGINATION_LIMIT) @@ -377,14 +378,20 @@ async def list_agents( agent_policies.c.agent_name, agent_policies.c.policy_id, ) - .where(agent_policies.c.agent_name.in_(agent_names)) + .where( + agent_policies.c.namespace_key == namespace_key, + agent_policies.c.agent_name.in_(agent_names), + ) .order_by(agent_policies.c.agent_name, agent_policies.c.policy_id) ) policy_ids_result = await db.execute(policy_ids_query) for assoc_agent_name, policy_id in policy_ids_result.all(): policy_ids_map.setdefault(assoc_agent_name, []).append(policy_id) - control_counts_map = await control_service.list_active_control_counts_by_agent(agent_names) + control_counts_map = await control_service.list_active_control_counts_by_agent( + agent_names, + namespace_key=namespace_key, + ) # Build summaries summaries: list[AgentSummary] = [] @@ -436,9 +443,8 @@ async def list_agents( ) async def init_agent( request: InitAgentRequest, - client: RequireAPIKey, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_CREATE)), ) -> InitAgentResponse: """ Register a new agent or update an existing agent's steps and metadata. @@ -462,10 +468,13 @@ async def init_agent( Args: request: Agent metadata and step schemas db: Database session (injected) + principal: Authorized request principal Returns: InitAgentResponse with created flag and the effective controls """ + namespace_key = principal.namespace_key + # Check for evaluator name collisions with built-in evaluators builtin_names = _get_builtin_evaluator_names() for ev in request.evaluators: @@ -835,7 +844,7 @@ async def init_agent( async def get_agent( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> GetAgentResponse: """ Retrieve agent metadata and all registered steps. @@ -845,8 +854,7 @@ async def get_agent( Args: agent_name: Agent identifier db: Database session (injected) - namespace_key: Resolved namespace; agents in another namespace - return 404 (non-disclosing). + principal: Authorized request principal Returns: GetAgentResponse with agent metadata and step list @@ -855,6 +863,7 @@ async def get_agent( HTTPException 404: Agent not found HTTPException 422: Agent data is corrupted """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) result = await db.execute( select(Agent).where(Agent.name == agent_name, Agent.namespace_key == namespace_key) @@ -917,7 +926,7 @@ async def _get_agent_or_404( The lookup is always namespace-scoped: an agent that exists only in another namespace surfaces as 404 (non-disclosing) so duplicate - names across namespaces — which the schema explicitly permits — + names across namespaces - which the schema explicitly permits - cannot be addressed across the namespace boundary. """ normalized_agent_name = normalize_agent_name_or_422(agent_name) @@ -940,7 +949,6 @@ async def _get_agent_or_404( @router.post( "/{agent_name}/policies/{policy_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Associate policy with agent", response_description="Success confirmation", @@ -949,9 +957,10 @@ async def add_agent_policy( agent_name: str, policy_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Associate a policy with an agent (idempotent).""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( @@ -1017,7 +1026,6 @@ async def add_agent_policy( @router.post( "/{agent_name}/policy/{policy_id}", - dependencies=[Depends(require_admin_key)], response_model=SetPolicyResponse, summary="Assign policy to agent (compatibility)", response_description="Success status with previous policy ID", @@ -1026,9 +1034,10 @@ async def set_agent_policy( agent_name: str, policy_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> SetPolicyResponse: """Compatibility endpoint that replaces all policy associations with one policy.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( @@ -1117,9 +1126,10 @@ async def set_agent_policy( async def get_agent_policies( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> GetAgentPoliciesResponse: """List policy IDs associated with an agent.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) result = await db.execute( select(agent_policies.c.policy_id) @@ -1141,9 +1151,10 @@ async def get_agent_policies( async def get_agent_policy( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> GetPolicyResponse: """Compatibility endpoint that returns the first associated policy.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( select(Policy.id) @@ -1172,7 +1183,6 @@ async def get_agent_policy( @router.delete( "/{agent_name}/policies/{policy_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Remove policy association from agent", response_description="Success confirmation", @@ -1181,13 +1191,14 @@ async def remove_agent_policy( agent_name: str, policy_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Remove a policy association from an agent. Idempotent for existing resources: removing a non-associated link is a no-op. Missing agent/policy resources still return 404. """ + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( @@ -1230,7 +1241,6 @@ async def remove_agent_policy( @router.delete( "/{agent_name}/policies", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Remove all policy associations from agent", response_description="Success confirmation", @@ -1238,9 +1248,10 @@ async def remove_agent_policy( async def remove_all_agent_policies( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Remove all policy associations from an agent.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) try: @@ -1271,7 +1282,6 @@ async def remove_all_agent_policies( @router.delete( "/{agent_name}/policy", - dependencies=[Depends(require_admin_key)], response_model=DeletePolicyResponse, summary="Remove agent's policy assignment (compatibility)", response_description="Success confirmation", @@ -1279,9 +1289,10 @@ async def remove_all_agent_policies( async def delete_agent_policy( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> DeletePolicyResponse: """Compatibility endpoint that removes all policy associations.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) existing_policy_result = await db.execute( @@ -1328,7 +1339,6 @@ async def delete_agent_policy( @router.post( "/{agent_name}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Associate control directly with agent", response_description="Success confirmation", @@ -1337,9 +1347,10 @@ async def add_agent_control( agent_name: str, control_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Associate a control directly with an agent (idempotent).""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) control_service = ControlService(db) control = await control_service.get_active_control_or_404( @@ -1389,7 +1400,6 @@ async def add_agent_control( @router.delete( "/{agent_name}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=RemoveAgentControlResponse, summary="Remove direct control association from agent", response_description="Success confirmation", @@ -1398,9 +1408,10 @@ async def remove_agent_control( agent_name: str, control_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> RemoveAgentControlResponse: """Remove a direct control association from an agent (idempotent).""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) control_service = ControlService(db) await control_service.get_active_control_or_404(control_id, namespace_key=namespace_key) @@ -1481,7 +1492,7 @@ async def list_agent_controls( description="Optional opaque target identifier. Required when target_type is supplied.", ), db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> AgentControlsResponse: """ List protection controls effective for an agent. @@ -1506,7 +1517,7 @@ async def list_agent_controls( target_type: Optional opaque target kind (paired with target_id) target_id: Optional opaque target identifier (paired with target_type) db: Database session (injected) - namespace_key: Namespace scoping for the resolution (injected) + principal: Authorized request principal Returns: AgentControlsResponse with controls matching the requested state filters @@ -1515,6 +1526,8 @@ async def list_agent_controls( HTTPException 400: target_type and target_id were not supplied together HTTPException 404: Agent not found """ + namespace_key = principal.namespace_key + if (target_type is None) != (target_id is None): raise BadRequestError( error_code=ErrorCode.VALIDATION_ERROR, @@ -1572,7 +1585,7 @@ async def list_agent_evaluators( cursor: str | None = None, limit: int = _DEFAULT_PAGINATION_LIMIT, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> ListEvaluatorsResponse: """ List all evaluator schemas registered with an agent. @@ -1586,8 +1599,7 @@ async def list_agent_evaluators( cursor: Optional cursor for pagination (name of last evaluator from previous page) limit: Pagination limit (default 20, max 100) db: Database session (injected) - namespace_key: Resolved namespace; agents in another namespace - return 404 (non-disclosing). + principal: Authorized request principal Returns: ListEvaluatorsResponse with evaluator schemas and pagination @@ -1595,6 +1607,7 @@ async def list_agent_evaluators( Raises: HTTPException 404: Agent not found """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) # Clamp limit limit = min(max(1, limit), _MAX_PAGINATION_LIMIT) @@ -1672,7 +1685,7 @@ async def get_agent_evaluator( agent_name: str, evaluator_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> EvaluatorSchemaItem: """ Get a specific evaluator schema registered with an agent. @@ -1681,8 +1694,7 @@ async def get_agent_evaluator( agent_name: Agent identifier evaluator_name: Name of the evaluator db: Database session (injected) - namespace_key: Resolved namespace; agents in another namespace - return 404 (non-disclosing). + principal: Authorized request principal Returns: EvaluatorSchemaItem with schema details @@ -1690,6 +1702,7 @@ async def get_agent_evaluator( Raises: HTTPException 404: Agent or evaluator not found """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) result = await db.execute( select(Agent).where(Agent.name == agent_name, Agent.namespace_key == namespace_key) @@ -1734,7 +1747,6 @@ async def get_agent_evaluator( @router.patch( "/{agent_name}", - dependencies=[Depends(require_admin_key)], response_model=PatchAgentResponse, summary="Modify agent (remove steps/evaluators)", response_description="Lists of removed items", @@ -1743,7 +1755,7 @@ async def patch_agent( agent_name: str, request: PatchAgentRequest, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> PatchAgentResponse: """ Remove steps and/or evaluators from an agent. @@ -1755,6 +1767,7 @@ async def patch_agent( agent_name: Agent identifier request: Lists of step/evaluator identifiers to remove db: Database session (injected) + principal: Authorized request principal Returns: PatchAgentResponse with lists of actually removed items @@ -1763,6 +1776,7 @@ async def patch_agent( HTTPException 404: Agent not found HTTPException 500: Database error during update """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) result = await db.execute( select(Agent).where( diff --git a/server/src/agent_control_server/endpoints/auth.py b/server/src/agent_control_server/endpoints/auth.py index 1a23baa8..f80cd2fa 100644 --- a/server/src/agent_control_server/endpoints/auth.py +++ b/server/src/agent_control_server/endpoints/auth.py @@ -2,9 +2,8 @@ The runtime auth flow is two-phase: this endpoint is phase one. The caller presents a long-lived credential plus ``(target_type, -target_id)``; the default authorizer (typically -:class:`HttpUpstreamAuthProvider` in production) authenticates the -credential and authorizes the implied +target_id)``; the default authorizer authenticates the credential and +authorizes the implied :data:`Operation.RUNTIME_TOKEN_EXCHANGE`. On success, this endpoint mints a short-lived local runtime token bound to the supplied target and returns it. Subsequent target-bearing runtime calls present the @@ -130,8 +129,8 @@ async def runtime_token_exchange( actor_id = principal.caller_id or "anonymous" # The exchange endpoint requires the authorizer to explicitly grant - # runtime.use. Providers that do not surface scopes (legacy local - # provider) supply a normalized grant for ``RUNTIME_TOKEN_EXCHANGE``; + # runtime.use. Local providers supply a normalized grant for + # ``RUNTIME_TOKEN_EXCHANGE``; # upstream providers that return an explicit empty scopes array fail # closed here rather than escalating to runtime.use. if Operation.RUNTIME_USE.value not in principal.scopes: @@ -155,7 +154,7 @@ async def runtime_token_exchange( ) except UpstreamGrantExpiredError as exc: # Upstream returned a grant whose ``expires_at`` is already in - # the past — minting would hand the caller a token that's dead + # the past - minting would hand the caller a token that's dead # on arrival. Distinguished from the misconfigured case so the # error code and status reflect "upstream returned bad data." raise APIError( diff --git a/server/src/agent_control_server/endpoints/control_bindings.py b/server/src/agent_control_server/endpoints/control_bindings.py index 92798ae1..d2fe4b44 100644 --- a/server/src/agent_control_server/endpoints/control_bindings.py +++ b/server/src/agent_control_server/endpoints/control_bindings.py @@ -26,7 +26,6 @@ from ..db import get_async_db from ..errors import BadRequestError from ..models import ControlBinding -from ..namespace import get_namespace_key from ..services.control_bindings import ControlBindingsService router = APIRouter(prefix="/control-bindings", tags=["control-bindings"]) @@ -94,26 +93,21 @@ def _to_response(binding: ControlBinding) -> GetControlBindingResponse: async def create_control_binding( request: CreateControlBindingRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_WRITE, context_builder=_binding_body_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> CreateControlBindingResponse: """Attach a control to an opaque external target. Each binding row is scoped to the request namespace as resolved by - ``get_namespace_key``. The auth chain still runs via - ``require_operation`` for authentication and authorization, but the - storage namespace is taken from the same resolver the rest of the - server uses so binding writes and runtime reads stay in lockstep - until auth-derived namespace resolution lands across every endpoint. + the active authorizer. """ service = ControlBindingsService(db) binding = await service.create_binding( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, target_type=request.target_type, target_id=request.target_id, control_id=request.control_id, @@ -148,20 +142,18 @@ async def list_control_bindings( target_id: str | None = None, control_id: int | None = None, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_READ, context_builder=_binding_list_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> ListControlBindingsResponse: """Return bindings in the request namespace with optional filters and cursor-based pagination. Bindings are ordered by ID descending (newest first). The cursor is opaque to clients: pass back the ``next_cursor`` value verbatim to fetch the following page. The - storage namespace is resolved by ``get_namespace_key`` so this - listing stays in lockstep with the rest of the server's reads. + storage namespace is resolved by the active authorizer. """ parsed_cursor: int | None if cursor is None: @@ -177,7 +169,7 @@ async def list_control_bindings( ) from exc service = ControlBindingsService(db) page = await service.list_bindings( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, cursor=parsed_cursor, limit=limit, target_type=target_type, @@ -204,8 +196,7 @@ async def list_control_bindings( async def get_control_binding( binding_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_READ)), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_READ)), ) -> GetControlBindingResponse: """Read a single control binding by surrogate ID. @@ -218,7 +209,9 @@ async def get_control_binding( of which forward ``(target_type, target_id)`` to the authorizer. """ service = ControlBindingsService(db) - binding = await service.get_binding_or_404(namespace_key=namespace_key, binding_id=binding_id) + binding = await service.get_binding_or_404( + namespace_key=principal.namespace_key, binding_id=binding_id + ) return _to_response(binding) @@ -232,8 +225,7 @@ async def patch_control_binding( binding_id: int, request: PatchControlBindingRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), ) -> PatchControlBindingResponse: """Update the ``enabled`` flag on a control binding. @@ -244,7 +236,7 @@ async def patch_control_binding( """ service = ControlBindingsService(db) binding = await service.set_enabled( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, binding_id=binding_id, enabled=request.enabled, ) @@ -261,8 +253,7 @@ async def patch_control_binding( async def delete_control_binding( binding_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), ) -> DeleteControlBindingResponse: """Delete a control binding by surrogate ID. @@ -272,7 +263,7 @@ async def delete_control_binding( target-scoped detach that forwards the target to the authorizer. """ service = ControlBindingsService(db) - await service.delete_binding(namespace_key=namespace_key, binding_id=binding_id) + await service.delete_binding(namespace_key=principal.namespace_key, binding_id=binding_id) await db.commit() return DeleteControlBindingResponse(success=True) @@ -286,13 +277,12 @@ async def delete_control_binding( async def upsert_control_binding_by_key( request: UpsertControlBindingRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_WRITE, context_builder=_binding_body_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> UpsertControlBindingResponse: """Idempotent attach using ``(target_type, target_id, control_id)`` as the natural key. Updates ``enabled`` on an existing match; creates a new row @@ -300,7 +290,7 @@ async def upsert_control_binding_by_key( """ service = ControlBindingsService(db) binding, created = await service.upsert_by_natural_key( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, target_type=request.target_type, target_id=request.target_id, control_id=request.control_id, @@ -324,20 +314,19 @@ async def upsert_control_binding_by_key( async def delete_control_binding_by_key( request: DeleteControlBindingByKeyRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_WRITE, context_builder=_binding_body_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> DeleteControlBindingByKeyResponse: """Idempotent detach by natural key. Returns ``deleted=False`` when no matching binding exists. """ service = ControlBindingsService(db) deleted = await service.delete_by_natural_key( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, target_type=request.target_type, target_id=request.target_id, control_id=request.control_id, diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index fcb7cb18..5b01593c 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -229,7 +229,7 @@ async def _materialize_control_input( enabled=enabled, ) - # Incomplete values — only allowed for new controls or already-unrendered + # Incomplete values - only allowed for new controls or already-unrendered # templates. Updating a rendered control with incomplete values is # rejected to prevent silently stripping rendered fields. current_is_rendered = ( @@ -470,7 +470,7 @@ async def render_control_template( async def create_control( request: CreateControlRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> CreateControlResponse: """ Create a new control with a unique name. @@ -492,7 +492,10 @@ async def create_control( control_service = ControlService(db) # Uniqueness check - if await control_service.active_control_name_exists(request.name): + namespace_key = principal.namespace_key + if await control_service.active_control_name_exists( + request.name, namespace_key=namespace_key + ): raise ConflictError( error_code=ErrorCode.CONTROL_NAME_CONFLICT, detail=f"Control with name '{request.name}' already exists", @@ -504,7 +507,11 @@ async def create_control( control_def = await _materialize_control_input(request.data, db=db) control_data = _serialize_control_data(control_def) - control = control_service.create_control(name=request.name, data=control_data) + control = control_service.create_control( + namespace_key=namespace_key, + name=request.name, + data=control_data, + ) try: await control_service.create_version( control, @@ -569,7 +576,7 @@ async def get_control_schema() -> GetControlSchemaResponse: async def get_control( control_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlResponse: """ Retrieve a control by ID including its name and configuration data. @@ -584,7 +591,9 @@ async def get_control( Raises: HTTPException 404: Control not found """ - control = await ControlService(db).get_active_control_or_404(control_id) + control = await ControlService(db).get_active_control_or_404( + control_id, namespace_key=principal.namespace_key + ) control_data = _parse_stored_control_data( control.data, control_name=control.name, @@ -608,7 +617,7 @@ async def get_control( async def get_control_data( control_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlDataResponse: """ Retrieve the configuration data for a control. @@ -626,7 +635,9 @@ async def get_control_data( HTTPException 404: Control not found HTTPException 422: Control data is corrupted """ - control = await ControlService(db).get_active_control_or_404(control_id) + control = await ControlService(db).get_active_control_or_404( + control_id, namespace_key=principal.namespace_key + ) control_data = _parse_stored_control_data( control.data, control_name=control.name, @@ -648,10 +659,15 @@ async def list_control_versions( ), limit: int = Query(_DEFAULT_PAGINATION_LIMIT, ge=1, le=_MAX_PAGINATION_LIMIT), db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ListControlVersionsResponse: """List control versions ordered newest-first using cursor-based pagination.""" - page = await ControlService(db).list_versions(control_id, cursor=cursor, limit=limit) + page = await ControlService(db).list_versions( + control_id, + namespace_key=principal.namespace_key, + cursor=cursor, + limit=limit, + ) return ListControlVersionsResponse( versions=[ @@ -682,10 +698,12 @@ async def get_control_version( control_id: int, version_num: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlVersionResponse: """Return a specific control version, including its raw persisted snapshot.""" - version = await ControlService(db).get_version_or_404(control_id, version_num) + version = await ControlService(db).get_version_or_404( + control_id, version_num, namespace_key=principal.namespace_key + ) return GetControlVersionResponse( version_num=version.version_num, event_type=version.event_type, @@ -705,7 +723,7 @@ async def set_control_data( control_id: int, request: SetControlDataRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), ) -> SetControlDataResponse: """ Update the configuration data for a control. @@ -726,7 +744,9 @@ async def set_control_data( HTTPException 500: Database error during update """ control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id, for_update=True) + control = await control_service.get_active_control_or_404( + control_id, namespace_key=principal.namespace_key, for_update=True + ) control_def = await _materialize_control_input( request.data, @@ -767,11 +787,12 @@ async def set_control_data( summary="Validate control configuration", response_description="Validation result", ) -# Validation uses the authoring path, so require create access. +# Authorized as CONTROLS_READ: validate exercises the materialization +# path but does not mutate stored control data. async def validate_control_data( request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ValidateControlDataResponse: """ Validate control configuration data without saving it. @@ -811,7 +832,7 @@ async def list_controls( execution: str | None = Query(None, description="Filter by execution ('server' or 'sdk')"), tag: str | None = Query(None, description="Filter by tag"), db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ListControlsResponse: """ List all controls with optional filtering and cursor-based pagination. @@ -837,7 +858,9 @@ async def list_controls( GET /controls?limit=10&enabled=true&step_type=tool """ control_service = ControlService(db) + namespace_key = principal.namespace_key page = await control_service.list_controls_page( + namespace_key=namespace_key, cursor=cursor, limit=limit, name=name, @@ -849,7 +872,8 @@ async def list_controls( tag=tag, ) usage_by_control_id = await control_service.list_control_usage( - [control.id for control in page.controls] + [control.id for control in page.controls], + namespace_key=namespace_key, ) # Build summaries (filtering already done at DB level) @@ -910,7 +934,7 @@ async def delete_control( "If false, fail if control is associated with any policy or agent.", ), db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_DELETE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_DELETE)), ) -> DeleteControlResponse: """ Delete a control by ID. @@ -933,13 +957,18 @@ async def delete_control( """ control_service = ControlService(db) bindings_service = ControlBindingsService(db) - control = await control_service.get_active_control_or_404(control_id, for_update=True) + namespace_key = principal.namespace_key + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key, for_update=True + ) - associations = await control_service.list_control_associations(control_id) + associations = await control_service.list_control_associations( + control_id, namespace_key=namespace_key + ) associated_policy_ids = associations.policy_ids associated_agent_names = associations.agent_names target_binding_ids = await bindings_service.list_binding_ids_for_control( - namespace_key=control.namespace_key, control_id=control_id + namespace_key=namespace_key, control_id=control_id ) if ( @@ -996,13 +1025,15 @@ async def delete_control( dissociated_from_policies: list[int] = [] dissociated_from_agents: list[str] = [] if associated_policy_ids or associated_agent_names: - dissociated = await control_service.remove_all_control_associations(control_id) + dissociated = await control_service.remove_all_control_associations( + control_id, namespace_key=namespace_key + ) dissociated_from_policies = dissociated.policy_ids dissociated_from_agents = dissociated.agent_names detached_target_bindings: list[int] = [] if target_binding_ids: detached_target_bindings = await bindings_service.delete_bindings_for_control( - namespace_key=control.namespace_key, control_id=control_id + namespace_key=namespace_key, control_id=control_id ) if dissociated_from_policies or dissociated_from_agents or detached_target_bindings: _logger.info( @@ -1057,7 +1088,7 @@ async def patch_control( control_id: int, request: PatchControlRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), ) -> PatchControlResponse: """ Update control metadata (name and/or enabled status). @@ -1081,7 +1112,10 @@ async def patch_control( HTTPException 500: Database error during update """ control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id, for_update=True) + namespace_key = principal.namespace_key + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key, for_update=True + ) parsed_control = _parse_stored_control_data( control.data, control_name=control.name, @@ -1096,6 +1130,7 @@ async def patch_control( # Check for name collision if await control_service.active_control_name_exists( request.name, + namespace_key=namespace_key, exclude_control_id=control_id, ): raise ConflictError( diff --git a/server/src/agent_control_server/endpoints/evaluation.py b/server/src/agent_control_server/endpoints/evaluation.py index e018796e..437af8b5 100644 --- a/server/src/agent_control_server/endpoints/evaluation.py +++ b/server/src/agent_control_server/endpoints/evaluation.py @@ -10,16 +10,15 @@ EvaluationResponse, ) from agent_control_models.errors import ErrorCode, ValidationErrorItem -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Request from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import RequireAPIKey +from ..auth_framework import Operation, Principal, require_operation from ..db import get_async_db from ..errors import APIValidationError, NotFoundError from ..logging_utils import get_logger from ..models import Agent -from ..namespace import get_namespace_key from ..services.controls import ControlService router = APIRouter(prefix="/evaluation", tags=["evaluation"]) @@ -118,6 +117,20 @@ def _sanitize_evaluation_response(response: EvaluationResponse) -> EvaluationRes ) +async def _evaluation_context(request: Request) -> dict[str, object]: + """Surface target identifiers to the runtime authorizer.""" + try: + body = await request.json() + except Exception: # noqa: BLE001 malformed JSON, defer to endpoint validation + return {} + if not isinstance(body, dict): + return {} + return { + "target_type": body.get("target_type"), + "target_id": body.get("target_id"), + } + + @router.post( "", response_model=EvaluationResponse, @@ -126,9 +139,10 @@ def _sanitize_evaluation_response(response: EvaluationResponse) -> EvaluationRes ) async def evaluate( request: EvaluationRequest, - client: RequireAPIKey, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends( + require_operation(Operation.RUNTIME_USE, context_builder=_evaluation_context) + ), ) -> EvaluationResponse: """Analyze content for safety and control violations. @@ -144,7 +158,7 @@ async def evaluate( on the server; SDKs reconstruct and emit those events separately through the observability ingestion endpoint. """ - del client # Authentication is still required by dependency injection. + namespace_key = principal.namespace_key agent_result = await db.execute( select(Agent).where( diff --git a/server/src/agent_control_server/endpoints/policies.py b/server/src/agent_control_server/endpoints/policies.py index 7b8b2ef9..ddda7127 100644 --- a/server/src/agent_control_server/endpoints/policies.py +++ b/server/src/agent_control_server/endpoints/policies.py @@ -9,7 +9,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import require_admin_key +from ..auth_framework import Operation, Principal, require_operation from ..db import get_async_db from ..errors import ConflictError, DatabaseError, NotFoundError from ..logging_utils import get_logger @@ -23,13 +23,14 @@ @router.put( "", - dependencies=[Depends(require_admin_key)], response_model=CreatePolicyResponse, summary="Create a new policy", response_description="Created policy ID", ) async def create_policy( - request: CreatePolicyRequest, db: AsyncSession = Depends(get_async_db) + request: CreatePolicyRequest, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_CREATE)), ) -> CreatePolicyResponse: """ Create a new empty policy with a unique name. @@ -48,8 +49,14 @@ async def create_policy( HTTPException 409: Policy with this name already exists HTTPException 500: Database error during creation """ + namespace_key = principal.namespace_key # Uniqueness check - existing = await db.execute(select(Policy.id).where(Policy.name == request.name)) + existing = await db.execute( + select(Policy.id).where( + Policy.namespace_key == namespace_key, + Policy.name == request.name, + ) + ) if existing.first() is not None: raise ConflictError( error_code=ErrorCode.POLICY_NAME_CONFLICT, @@ -59,7 +66,7 @@ async def create_policy( hint="Choose a different name or update the existing policy.", ) - policy = Policy(name=request.name) + policy = Policy(namespace_key=namespace_key, name=request.name) db.add(policy) try: await db.commit() @@ -80,13 +87,15 @@ async def create_policy( @router.post( "/{policy_id}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Add control to policy", response_description="Success confirmation", ) async def add_control_to_policy( - policy_id: int, control_id: int, db: AsyncSession = Depends(get_async_db) + policy_id: int, + control_id: int, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_UPDATE)), ) -> AssocResponse: """ Associate a control with a policy. @@ -106,8 +115,14 @@ async def add_control_to_policy( HTTPException 404: Policy or control not found HTTPException 500: Database error """ + namespace_key = principal.namespace_key # Find policy and control - pol_res = await db.execute(select(Policy).where(Policy.id == policy_id)) + pol_res = await db.execute( + select(Policy).where( + Policy.namespace_key == namespace_key, + Policy.id == policy_id, + ) + ) policy = pol_res.scalars().first() if policy is None: raise NotFoundError( @@ -119,11 +134,17 @@ async def add_control_to_policy( ) control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id) + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key + ) # Add association using INSERT ... ON CONFLICT DO NOTHING for idempotency try: - await control_service.add_control_to_policy(policy_id=policy_id, control_id=control_id) + await control_service.add_control_to_policy( + policy_id=policy_id, + control_id=control_id, + namespace_key=namespace_key, + ) await db.commit() except Exception: await db.rollback() @@ -149,13 +170,15 @@ async def add_control_to_policy( @router.delete( "/{policy_id}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Remove control from policy", response_description="Success confirmation", ) async def remove_control_from_policy( - policy_id: int, control_id: int, db: AsyncSession = Depends(get_async_db) + policy_id: int, + control_id: int, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_UPDATE)), ) -> AssocResponse: """ Remove a control from a policy. @@ -175,7 +198,13 @@ async def remove_control_from_policy( HTTPException 404: Policy or control not found HTTPException 500: Database error """ - pol_res = await db.execute(select(Policy).where(Policy.id == policy_id)) + namespace_key = principal.namespace_key + pol_res = await db.execute( + select(Policy).where( + Policy.namespace_key == namespace_key, + Policy.id == policy_id, + ) + ) policy = pol_res.scalars().first() if policy is None: raise NotFoundError( @@ -187,13 +216,16 @@ async def remove_control_from_policy( ) control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id) + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key + ) # Remove association (idempotent - deleting non-existent is no-op) try: await control_service.remove_control_from_policy( policy_id=policy_id, control_id=control_id, + namespace_key=namespace_key, ) await db.commit() except Exception: @@ -222,7 +254,9 @@ async def remove_control_from_policy( response_description="List of control IDs", ) async def list_policy_controls( - policy_id: int, db: AsyncSession = Depends(get_async_db) + policy_id: int, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_READ)), ) -> GetPolicyControlsResponse: """ List all controls associated with a policy. @@ -237,7 +271,13 @@ async def list_policy_controls( Raises: HTTPException 404: Policy not found """ - pol_res = await db.execute(select(Policy.id).where(Policy.id == policy_id)) + namespace_key = principal.namespace_key + pol_res = await db.execute( + select(Policy.id).where( + Policy.namespace_key == namespace_key, + Policy.id == policy_id, + ) + ) if pol_res.first() is None: raise NotFoundError( error_code=ErrorCode.POLICY_NOT_FOUND, @@ -247,5 +287,8 @@ async def list_policy_controls( hint="Verify the policy ID is correct and the policy has been created.", ) - control_ids = await ControlService(db).list_policy_control_ids(policy_id) + control_ids = await ControlService(db).list_policy_control_ids( + policy_id, + namespace_key=namespace_key, + ) return GetPolicyControlsResponse(control_ids=control_ids) diff --git a/server/src/agent_control_server/main.py b/server/src/agent_control_server/main.py index bc1bf04b..a1561e63 100644 --- a/server/src/agent_control_server/main.py +++ b/server/src/agent_control_server/main.py @@ -252,7 +252,7 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- # Register handler for FastAPI's RequestValidationError (Pydantic validation) app.add_exception_handler(RequestValidationError, validation_exception_handler) # type: ignore[arg-type] -# Register handler for standard HTTPException (legacy code, FastAPI internals) +# Register handler for standard HTTPException (older routes, FastAPI internals) app.add_exception_handler(HTTPException, http_exception_handler) # type: ignore[arg-type] # Register catch-all handler for unexpected exceptions @@ -261,16 +261,18 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- # API v1 prefix for all routes api_v1_prefix = f"{settings.api_prefix}/{settings.api_version}" -# Protected routes (require valid API key) +# API routers. Routers migrated to the auth framework mount the +# non-validating header extractor only so OpenAPI advertises X-API-Key; +# each endpoint's ``require_operation`` dependency owns authn + authz. app.include_router( agent_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) app.include_router( policy_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) app.include_router( # Endpoint dependencies handle auth; this advertises X-API-Key. @@ -281,11 +283,11 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- app.include_router( # The auth framework on each endpoint owns authentication and # authorization for control bindings, so this router is mounted - # without the legacy router-level gate. See ``auth_framework`` for + # without the router-level auth gate. See ``auth_framework`` for # the provider contract. ``get_api_key_from_header`` is a non- # validating extractor (``auto_error=False``); it is attached purely # so the generated OpenAPI spec advertises the X-API-Key requirement - # on these routes — without it, downstream SDK generators would treat + # on these routes - without it, downstream SDK generators would treat # the routes as unauthenticated. control_binding_router, prefix=api_v1_prefix, @@ -309,9 +311,10 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- app.include_router( evaluation_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) +# Evaluator discovery still uses the local credential dependency. app.include_router( evaluator_router, prefix=api_v1_prefix, @@ -324,7 +327,7 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- prefix=api_v1_prefix, ) -# System routes (config, login, logout) — no auth required +# System routes (config, login, logout) - no auth required app.include_router( system_router, prefix=settings.api_prefix, diff --git a/server/src/agent_control_server/services/controls.py b/server/src/agent_control_server/services/controls.py index 263120b7..41a62282 100644 --- a/server/src/agent_control_server/services/controls.py +++ b/server/src/agent_control_server/services/controls.py @@ -20,6 +20,7 @@ from ..errors import APIValidationError, NotFoundError from ..models import ( + DEFAULT_NAMESPACE_KEY, Control, ControlBinding, ControlVersion, @@ -96,9 +97,15 @@ class ControlService: def __init__(self, db: AsyncSession) -> None: self._db = db - def create_control(self, *, name: str, data: dict[str, Any]) -> Control: + def create_control( + self, + *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, + name: str, + data: dict[str, Any], + ) -> Control: """Create a new pending control row.""" - control = Control(name=name, data=data) + control = Control(namespace_key=namespace_key, name=name, data=data) self._db.add(control) return control @@ -128,10 +135,13 @@ async def get_control_or_404( self, control_id: int, *, + namespace_key: str | None = None, for_update: bool = False, ) -> Control: """Load any control row, including soft-deleted controls.""" stmt = select(Control).where(Control.id == control_id) + if namespace_key is not None: + stmt = stmt.where(Control.namespace_key == namespace_key) if for_update: stmt = stmt.with_for_update() result = await self._db.execute(stmt) @@ -180,10 +190,15 @@ async def active_control_name_exists( self, name: str, *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, exclude_control_id: int | None = None, ) -> bool: """Return whether an active control already uses the provided name.""" - stmt = select(Control.id).where(Control.name == name, Control.deleted_at.is_(None)) + stmt = select(Control.id).where( + Control.namespace_key == namespace_key, + Control.name == name, + Control.deleted_at.is_(None), + ) if exclude_control_id is not None: stmt = stmt.where(Control.id != exclude_control_id) result = await self._db.execute(stmt) @@ -216,11 +231,12 @@ async def list_versions( self, control_id: int, *, + namespace_key: str, cursor: int | None, limit: int, ) -> ControlVersionPage: """Return control versions newest-first with cursor pagination.""" - await self.get_control_or_404(control_id) + await self.get_control_or_404(control_id, namespace_key=namespace_key) total_result = await self._db.execute( select(func.count()) @@ -255,9 +271,11 @@ async def list_versions( next_cursor=next_cursor, ) - async def get_version_or_404(self, control_id: int, version_num: int) -> ControlVersion: + async def get_version_or_404( + self, control_id: int, version_num: int, *, namespace_key: str + ) -> ControlVersion: """Load a specific version row for a control.""" - await self.get_control_or_404(control_id) + await self.get_control_or_404(control_id, namespace_key=namespace_key) result = await self._db.execute( select(ControlVersion).where( @@ -303,12 +321,17 @@ async def list_controls_for_policy( result = await self._db.execute(stmt) return list(result.scalars().unique().all()) - async def list_policy_control_ids(self, policy_id: int) -> list[int]: + async def list_policy_control_ids(self, policy_id: int, *, namespace_key: str) -> list[int]: """Return active control IDs directly associated with a policy.""" result = await self._db.execute( select(policy_controls.c.control_id) .join(Control, Control.id == policy_controls.c.control_id) - .where(policy_controls.c.policy_id == policy_id, Control.deleted_at.is_(None)) + .where( + policy_controls.c.namespace_key == namespace_key, + policy_controls.c.policy_id == policy_id, + Control.namespace_key == namespace_key, + Control.deleted_at.is_(None), + ) .order_by(policy_controls.c.control_id) ) return [cast(int, row[0]) for row in result.all()] @@ -396,6 +419,7 @@ async def list_runtime_controls_for_agent( async def list_controls_page( self, *, + namespace_key: str, cursor: int | None, limit: int, name: str | None, @@ -407,7 +431,11 @@ async def list_controls_page( tag: str | None, ) -> ControlListPage: """Return paginated active controls for the browse endpoint.""" - query = select(Control).where(Control.deleted_at.is_(None)).order_by(Control.id.desc()) + query = ( + select(Control) + .where(Control.namespace_key == namespace_key, Control.deleted_at.is_(None)) + .order_by(Control.id.desc()) + ) query = self._apply_control_list_filters( query, name=name, @@ -424,7 +452,11 @@ async def list_controls_page( result = await self._db.execute(query.limit(limit + 1)) controls = list(result.scalars().all()) - total_query = select(func.count()).select_from(Control).where(Control.deleted_at.is_(None)) + total_query = ( + select(func.count()) + .select_from(Control) + .where(Control.namespace_key == namespace_key, Control.deleted_at.is_(None)) + ) total_query = self._apply_control_list_filters( total_query, name=name, @@ -453,7 +485,9 @@ async def list_controls_page( next_cursor=next_cursor, ) - async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, ControlUsage]: + async def list_control_usage( + self, control_ids: Sequence[int], *, namespace_key: str + ) -> dict[int, ControlUsage]: """Return representative agent usage and usage counts for the provided controls.""" if not control_ids: return {} @@ -465,8 +499,16 @@ async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, Cont agent_policies.c.agent_name, ) .select_from(policy_controls) - .join(agent_policies, policy_controls.c.policy_id == agent_policies.c.policy_id) - .where(policy_controls.c.control_id.in_(control_ids)) + .join( + agent_policies, + (policy_controls.c.policy_id == agent_policies.c.policy_id) + & (policy_controls.c.namespace_key == agent_policies.c.namespace_key), + ) + .where( + policy_controls.c.namespace_key == namespace_key, + agent_policies.c.namespace_key == namespace_key, + policy_controls.c.control_id.in_(control_ids), + ) ) direct_agents_query = ( select( @@ -474,7 +516,10 @@ async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, Cont agent_controls.c.agent_name, ) .select_from(agent_controls) - .where(agent_controls.c.control_id.in_(control_ids)) + .where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.control_id.in_(control_ids), + ) ) agents_result = await self._db.execute(union_all(policy_agents_query, direct_agents_query)) for control_id, agent_name in agents_result.all(): @@ -491,6 +536,8 @@ async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, Cont async def list_active_control_counts_by_agent( self, agent_names: Sequence[str], + *, + namespace_key: str = DEFAULT_NAMESPACE_KEY, ) -> dict[str, int]: """Return active control counts keyed by agent name.""" if not agent_names: @@ -503,15 +550,24 @@ async def list_active_control_counts_by_agent( ) .select_from( agent_policies.join( - policy_controls, agent_policies.c.policy_id == policy_controls.c.policy_id + policy_controls, + (agent_policies.c.policy_id == policy_controls.c.policy_id) + & (agent_policies.c.namespace_key == policy_controls.c.namespace_key), ) ) - .where(agent_policies.c.agent_name.in_(agent_names)) + .where( + agent_policies.c.namespace_key == namespace_key, + policy_controls.c.namespace_key == namespace_key, + agent_policies.c.agent_name.in_(agent_names), + ) ) direct_associations = select( agent_controls.c.agent_name.label("agent_name"), agent_controls.c.control_id.label("control_id"), - ).where(agent_controls.c.agent_name.in_(agent_names)) + ).where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.agent_name.in_(agent_names), + ) all_associations = union_all(policy_associations, direct_associations).subquery() result = await self._db.execute( @@ -521,6 +577,7 @@ async def list_active_control_counts_by_agent( ) .join(Control, all_associations.c.control_id == Control.id) .where( + Control.namespace_key == namespace_key, Control.deleted_at.is_(None), or_( Control.data["enabled"].astext == "true", @@ -531,19 +588,28 @@ async def list_active_control_counts_by_agent( ) return {cast(str, row[0]): cast(int, row[1]) for row in result.all()} - async def add_control_to_policy(self, *, policy_id: int, control_id: int) -> None: + async def add_control_to_policy( + self, *, policy_id: int, control_id: int, namespace_key: str + ) -> None: """Create a policy-control association if it does not already exist.""" await self._db.execute( pg_insert(policy_controls) - .values(policy_id=policy_id, control_id=control_id) + .values( + namespace_key=namespace_key, + policy_id=policy_id, + control_id=control_id, + ) .on_conflict_do_nothing() ) - async def remove_control_from_policy(self, *, policy_id: int, control_id: int) -> None: + async def remove_control_from_policy( + self, *, policy_id: int, control_id: int, namespace_key: str + ) -> None: """Remove a policy-control association if it exists.""" await self._db.execute( delete(policy_controls).where( - (policy_controls.c.policy_id == policy_id) + (policy_controls.c.namespace_key == namespace_key) + & (policy_controls.c.policy_id == policy_id) & (policy_controls.c.control_id == control_id) ) ) @@ -613,16 +679,24 @@ async def remove_control_from_agent( control_still_active=policy_inheritance_result.first() is not None, ) - async def list_control_associations(self, control_id: int) -> ControlAssociations: + async def list_control_associations( + self, control_id: int, *, namespace_key: str + ) -> ControlAssociations: """Return all policy and direct agent associations for a control.""" policy_assoc_query = select( policy_controls.c.policy_id.label("policy_id"), literal(None, type_=String).label("agent_name"), - ).where(policy_controls.c.control_id == control_id) + ).where( + policy_controls.c.namespace_key == namespace_key, + policy_controls.c.control_id == control_id, + ) agent_assoc_query = select( literal(None, type_=Integer).label("policy_id"), agent_controls.c.agent_name.label("agent_name"), - ).where(agent_controls.c.control_id == control_id) + ).where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.control_id == control_id, + ) assoc_result = await self._db.execute(union_all(policy_assoc_query, agent_assoc_query)) policy_ids: set[int] = set() @@ -638,16 +712,26 @@ async def list_control_associations(self, control_id: int) -> ControlAssociation agent_names=sorted(agent_names), ) - async def remove_all_control_associations(self, control_id: int) -> ControlAssociations: + async def remove_all_control_associations( + self, control_id: int, *, namespace_key: str + ) -> ControlAssociations: """Remove all policy and direct agent associations for a control.""" - associations = await self.list_control_associations(control_id) + associations = await self.list_control_associations( + control_id, namespace_key=namespace_key + ) if associations.policy_ids: await self._db.execute( - delete(policy_controls).where(policy_controls.c.control_id == control_id) + delete(policy_controls).where( + policy_controls.c.namespace_key == namespace_key, + policy_controls.c.control_id == control_id, + ) ) if associations.agent_names: await self._db.execute( - delete(agent_controls).where(agent_controls.c.control_id == control_id) + delete(agent_controls).where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.control_id == control_id, + ) ) return associations diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 96c4aad8..2d39bfa3 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -20,6 +20,7 @@ AccessLevel, HeaderAuthProvider, HttpUpstreamAuthProvider, + NoAuthProvider, ) from agent_control_server.auth_framework.providers.header import ( DEFAULT_OPERATION_ACCESS, @@ -64,6 +65,35 @@ def test_default_operation_access_covers_every_operation(): assert not missing, f"Operations missing default access mapping: {missing}" +# --------------------------------------------------------------------------- +# NoAuthProvider +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_no_auth_provider_allows_any_operation(): + provider = NoAuthProvider(default_namespace_key="ns-local") + + principal = await provider.authorize( + _build_request(), + Operation.CONTROLS_DELETE, + ) + + assert principal == Principal(namespace_key="ns-local") + + +@pytest.mark.asyncio +async def test_no_auth_provider_grants_runtime_exchange_scope(): + provider = NoAuthProvider() + + principal = await provider.authorize( + _build_request(), + Operation.RUNTIME_TOKEN_EXCHANGE, + ) + + assert principal.scopes == (Operation.RUNTIME_USE.value,) + + # --------------------------------------------------------------------------- # HeaderAuthProvider # --------------------------------------------------------------------------- @@ -101,7 +131,7 @@ async def test_header_provider_public_returns_default_namespace(): @pytest.mark.asyncio -async def test_header_provider_authenticated_calls_legacy_validator(): +async def test_header_provider_authenticated_calls_local_validator(): provider = HeaderAuthProvider() expected_client = MagicMock(is_admin=False, key_id="abc12345") @@ -945,6 +975,70 @@ def test_runtime_ttl_loader_accepts_max(monkeypatch): ) +def test_build_default_provider_accepts_none_mode(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "none") + + assert isinstance(auth_config._build_default_provider(), NoAuthProvider) + + +def test_resolve_runtime_mode_defaults_to_api_key_without_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + assert auth_config._resolve_runtime_mode() == "api_key" + + +def test_resolve_runtime_mode_defaults_to_jwt_with_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", _TEST_SECRET) + + assert auth_config._resolve_runtime_mode() == "jwt" + + +def test_configure_runtime_none_installs_no_auth_provider(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "none") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.RUNTIME_USE), NoAuthProvider) + assert auth_config.runtime_auth_config() is None + + +def test_configure_runtime_api_key_ignores_jwt_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "api_key") + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", _TEST_SECRET) + + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.RUNTIME_USE), HeaderAuthProvider) + assert auth_config.runtime_auth_config() is None + + +def test_configure_runtime_jwt_requires_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "jwt") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + with pytest.raises(RuntimeError, match="requires AGENT_CONTROL_RUNTIME_TOKEN_SECRET"): + auth_config.configure_auth_from_env() + + def test_configure_then_reconfigure_clears_runtime_override(monkeypatch): """Reconfiguring without a runtime secret must drop the override.""" from agent_control_server.auth_framework import config as auth_config diff --git a/server/tests/test_controls_additional.py b/server/tests/test_controls_additional.py index b4922b9d..dfbb15f5 100644 --- a/server/tests/test_controls_additional.py +++ b/server/tests/test_controls_additional.py @@ -8,19 +8,19 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from agent_control_evaluators import RegexEvaluatorConfig +from agent_control_models import ConditionNode from fastapi.testclient import TestClient from sqlalchemy import text from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from agent_control_models import ConditionNode +from agent_control_server.auth_framework import Principal from agent_control_server.db import get_async_db -from agent_control_server.models import Control - -from agent_control_evaluators import RegexEvaluatorConfig from agent_control_server.endpoints import controls as controls_module from agent_control_server.main import app +from agent_control_server.models import DEFAULT_NAMESPACE_KEY, Control from .conftest import engine from .utils import VALID_CONTROL_PAYLOAD @@ -1106,7 +1106,12 @@ def model_dump(self, *args: object, **kwargs: object) -> dict[str, object]: request = SimpleNamespace(data=DummyData(payload)) # When: updating the control data with a non-Pydantic selector - response = await controls_module.set_control_data(control.id, request, async_db) + response = await controls_module.set_control_data( + control.id, + request, + async_db, + principal=Principal(namespace_key=DEFAULT_NAMESPACE_KEY), + ) # Then: the update succeeds and uses the original selector serialization assert response.success is True diff --git a/server/tests/test_controls_auth.py b/server/tests/test_controls_auth.py index 1a2af21f..c0f17754 100644 --- a/server/tests/test_controls_auth.py +++ b/server/tests/test_controls_auth.py @@ -4,14 +4,13 @@ import uuid -import pytest from fastapi.testclient import TestClient -from agent_control_server.config import auth_settings +from agent_control_server.auth_framework import set_authorizer +from agent_control_server.auth_framework.providers import NoAuthProvider from .utils import VALID_CONTROL_PAYLOAD - _CONTROLS_URL = "/api/v1/controls" _TEMPLATES_URL = "/api/v1/control-templates" @@ -199,18 +198,19 @@ def test_non_admin_cannot_delete_control( assert resp.status_code == 403, resp.text -def test_non_admin_cannot_validate_control_data( +def test_non_admin_can_validate_control_data( non_admin_client: TestClient, ) -> None: - """``/controls/validate`` requires ``CONTROLS_CREATE``.""" + """``/controls/validate`` requires ``CONTROLS_READ``.""" # When: a non-admin attempts to validate a draft payload resp = non_admin_client.post( f"{_CONTROLS_URL}/validate", json={"data": VALID_CONTROL_PAYLOAD}, ) - # Then: validation is admin-only - assert resp.status_code == 403, resp.text + # Then: validation is allowed for authenticated non-admin callers + assert resp.status_code == 200, resp.text + assert resp.json()["success"] is True def test_non_admin_cannot_render_template(non_admin_client: TestClient) -> None: @@ -283,21 +283,16 @@ def test_unauthenticated_cannot_render_template( # --------------------------------------------------------------------------- -# No-auth deployment mode: api_key_enabled=False bypasses every gate. +# No-auth deployment mode: explicit provider bypasses every gate. # --------------------------------------------------------------------------- def test_no_auth_mode_allows_writes_without_credentials( unauthenticated_client: TestClient, - monkeypatch: pytest.MonkeyPatch, ) -> None: - """When ``api_key_enabled`` is False, the ``HeaderAuthProvider`` - short-circuits to a non-admin ``Principal`` for every operation, - including admin-level writes. This pins the "no auth" deployment - path so a future refactor can't silently start enforcing. - """ - # Given: api_key_enabled is False (single-tenant OSS dev mode) - monkeypatch.setattr(auth_settings, "api_key_enabled", False) + """Explicit no-auth provider allows every operation without credentials.""" + # Given: the request-auth framework is in no-auth mode + set_authorizer(NoAuthProvider()) # When: an unauthenticated client creates a control resp = unauthenticated_client.put( @@ -311,4 +306,3 @@ def test_no_auth_mode_allows_writes_without_credentials( # Then: the create succeeds because auth is disabled at the provider assert resp.status_code == 200, resp.text assert "control_id" in resp.json() - diff --git a/server/tests/test_principal_namespace_flow.py b/server/tests/test_principal_namespace_flow.py new file mode 100644 index 00000000..40ecd216 --- /dev/null +++ b/server/tests/test_principal_namespace_flow.py @@ -0,0 +1,141 @@ +"""HTTP-level coverage for principal-derived namespace scoping.""" + +from __future__ import annotations + +import uuid +from typing import Any + +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient + +from agent_control_server.auth_framework import ( + Operation, + Principal, + set_authorizer, +) + +from .utils import VALID_CONTROL_PAYLOAD + + +class HeaderNamespaceAuthorizer: + """Test authorizer that maps a request header to ``Principal.namespace_key``.""" + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del context + scopes = ( + (Operation.RUNTIME_USE.value,) + if operation is Operation.RUNTIME_TOKEN_EXCHANGE + else () + ) + return Principal( + namespace_key=request.headers.get("X-Test-Namespace", "default"), + is_admin=True, + scopes=scopes, + ) + + +def _client(app: FastAPI, namespace_key: str) -> TestClient: + return TestClient( + app, + raise_server_exceptions=True, + headers={"X-Test-Namespace": namespace_key}, + ) + + +def _agent_payload(agent_name: str) -> dict[str, Any]: + return { + "agent": { + "agent_name": agent_name, + "agent_description": "test agent", + "agent_version": "1.0", + }, + "steps": [], + } + + +def _evaluation_payload(agent_name: str) -> dict[str, Any]: + return { + "agent_name": agent_name, + "step": { + "type": "llm", + "name": "test-step", + "input": "x marks the spot", + "context": {}, + }, + "stage": "pre", + "target_type": "env", + "target_id": "prod", + } + + +def test_principal_namespace_scopes_management_and_runtime(app: FastAPI) -> None: + set_authorizer(HeaderNamespaceAuthorizer()) + + ns_a = _client(app, "ns-a") + ns_b = _client(app, "ns-b") + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + + register_a = ns_a.post("/api/v1/agents/initAgent", json=_agent_payload(agent_name)) + register_b = ns_b.post("/api/v1/agents/initAgent", json=_agent_payload(agent_name)) + assert register_a.status_code == 200, register_a.text + assert register_b.status_code == 200, register_b.text + + create_control = ns_a.put( + "/api/v1/controls", + json={ + "name": f"control-{uuid.uuid4().hex[:12]}", + "data": VALID_CONTROL_PAYLOAD, + }, + ) + assert create_control.status_code == 200, create_control.text + control_id = int(create_control.json()["control_id"]) + + policy = ns_a.put( + "/api/v1/policies", + json={"name": f"policy-{uuid.uuid4().hex[:12]}"}, + ) + assert policy.status_code == 200, policy.text + policy_id = int(policy.json()["policy_id"]) + attach_to_policy = ns_a.post(f"/api/v1/policies/{policy_id}/controls/{control_id}") + assert attach_to_policy.status_code == 200, attach_to_policy.text + + binding = ns_a.put( + "/api/v1/control-bindings", + json={ + "target_type": "env", + "target_id": "prod", + "control_id": control_id, + "enabled": True, + }, + ) + assert binding.status_code == 200, binding.text + + assert ns_b.get(f"/api/v1/controls/{control_id}").status_code == 404 + assert ns_b.get(f"/api/v1/policies/{policy_id}/controls").status_code == 404 + assert ns_b.get("/api/v1/control-bindings").json()["bindings"] == [] + + eval_a = ns_a.post("/api/v1/evaluation", json=_evaluation_payload(agent_name)) + assert eval_a.status_code == 200, eval_a.text + assert eval_a.json()["is_safe"] is False + assert eval_a.json()["matches"][0]["control_id"] == control_id + + eval_b = ns_b.post("/api/v1/evaluation", json=_evaluation_payload(agent_name)) + assert eval_b.status_code == 200, eval_b.text + assert eval_b.json()["is_safe"] is True + + +def test_duplicate_control_names_allowed_across_principal_namespaces(app: FastAPI) -> None: + set_authorizer(HeaderNamespaceAuthorizer()) + + ns_a = _client(app, "ns-a") + ns_b = _client(app, "ns-b") + control_name = f"control-{uuid.uuid4().hex[:12]}" + payload = {"name": control_name, "data": VALID_CONTROL_PAYLOAD} + + assert ns_a.put("/api/v1/controls", json=payload).status_code == 200 + assert ns_b.put("/api/v1/controls", json=payload).status_code == 200 diff --git a/server/tests/test_target_merged_contract.py b/server/tests/test_target_merged_contract.py index 295a85e2..62891ba5 100644 --- a/server/tests/test_target_merged_contract.py +++ b/server/tests/test_target_merged_contract.py @@ -232,9 +232,9 @@ def test_target_binding_de_duplicated_against_direct_attachment( async def _insert_agent_in_namespace(async_db, *, name: str, namespace_key: str) -> None: """Insert an Agent row directly so the test can simulate a foreign namespace. - The endpoint's ``get_namespace_key`` returns the default namespace; this - helper sidesteps the resolver to seed an agent that the request-time - code path should not be able to reach. + The default test authorizer returns the default namespace; this helper + sidesteps the authorizer to seed an agent that the request-time code + path should not be able to reach. """ from agent_control_server.models import Agent From 33acd7e1c33a49ffdfe7a46f846847774c005aaf Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Thu, 7 May 2026 20:20:00 +0530 Subject: [PATCH 15/42] chore(sdk-ts): regenerate client docs --- .../src/generated/funcs/agents-get-evaluator.ts | 3 +-- sdks/typescript/src/generated/funcs/agents-get.ts | 3 +-- .../typescript/src/generated/funcs/agents-init.ts | 1 + .../src/generated/funcs/agents-list-controls.ts | 2 +- .../src/generated/funcs/agents-list-evaluators.ts | 3 +-- .../typescript/src/generated/funcs/agents-list.ts | 2 +- .../src/generated/funcs/agents-update.ts | 1 + .../generated/funcs/control-bindings-create.ts | 6 +----- .../src/generated/funcs/control-bindings-list.ts | 3 +-- sdks/typescript/src/generated/sdk/agents.ts | 15 +++++++-------- .../src/generated/sdk/control-bindings.ts | 9 ++------- 11 files changed, 18 insertions(+), 30 deletions(-) diff --git a/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts b/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts index acb364eb..ceca1ec0 100644 --- a/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts +++ b/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts @@ -37,8 +37,7 @@ import { Result } from "../types/fp.js"; * agent_name: Agent identifier * evaluator_name: Name of the evaluator * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * EvaluatorSchemaItem with schema details diff --git a/sdks/typescript/src/generated/funcs/agents-get.ts b/sdks/typescript/src/generated/funcs/agents-get.ts index 9724edbf..142f3062 100644 --- a/sdks/typescript/src/generated/funcs/agents-get.ts +++ b/sdks/typescript/src/generated/funcs/agents-get.ts @@ -38,8 +38,7 @@ import { Result } from "../types/fp.js"; * Args: * agent_name: Agent identifier * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * GetAgentResponse with agent metadata and step list diff --git a/sdks/typescript/src/generated/funcs/agents-init.ts b/sdks/typescript/src/generated/funcs/agents-init.ts index 9d63358d..7150b2a4 100644 --- a/sdks/typescript/src/generated/funcs/agents-init.ts +++ b/sdks/typescript/src/generated/funcs/agents-init.ts @@ -51,6 +51,7 @@ import { Result } from "../types/fp.js"; * Args: * request: Agent metadata and step schemas * db: Database session (injected) + * principal: Authorized request principal * * Returns: * InitAgentResponse with created flag and the effective controls diff --git a/sdks/typescript/src/generated/funcs/agents-list-controls.ts b/sdks/typescript/src/generated/funcs/agents-list-controls.ts index 661c5509..d1e5b27d 100644 --- a/sdks/typescript/src/generated/funcs/agents-list-controls.ts +++ b/sdks/typescript/src/generated/funcs/agents-list-controls.ts @@ -53,7 +53,7 @@ import { Result } from "../types/fp.js"; * target_type: Optional opaque target kind (paired with target_id) * target_id: Optional opaque target identifier (paired with target_type) * db: Database session (injected) - * namespace_key: Namespace scoping for the resolution (injected) + * principal: Authorized request principal * * Returns: * AgentControlsResponse with controls matching the requested state filters diff --git a/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts b/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts index c4d8a4b2..4217e752 100644 --- a/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts +++ b/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts @@ -42,8 +42,7 @@ import { Result } from "../types/fp.js"; * cursor: Optional cursor for pagination (name of last evaluator from previous page) * limit: Pagination limit (default 20, max 100) * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * ListEvaluatorsResponse with evaluator schemas and pagination diff --git a/sdks/typescript/src/generated/funcs/agents-list.ts b/sdks/typescript/src/generated/funcs/agents-list.ts index fda7574d..f887d0b5 100644 --- a/sdks/typescript/src/generated/funcs/agents-list.ts +++ b/sdks/typescript/src/generated/funcs/agents-list.ts @@ -42,7 +42,7 @@ import { Result } from "../types/fp.js"; * limit: Pagination limit (default 20, max 100) * name: Optional name filter (case-insensitive partial match) * db: Database session (injected) - * namespace_key: Resolved namespace for the request + * principal: Authorized request principal * * Returns: * ListAgentsResponse with agent summaries and pagination info diff --git a/sdks/typescript/src/generated/funcs/agents-update.ts b/sdks/typescript/src/generated/funcs/agents-update.ts index e82644cf..aff9d827 100644 --- a/sdks/typescript/src/generated/funcs/agents-update.ts +++ b/sdks/typescript/src/generated/funcs/agents-update.ts @@ -40,6 +40,7 @@ import { Result } from "../types/fp.js"; * agent_name: Agent identifier * request: Lists of step/evaluator identifiers to remove * db: Database session (injected) + * principal: Authorized request principal * * Returns: * PatchAgentResponse with lists of actually removed items diff --git a/sdks/typescript/src/generated/funcs/control-bindings-create.ts b/sdks/typescript/src/generated/funcs/control-bindings-create.ts index 8412487e..71dee5a0 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-create.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-create.ts @@ -33,11 +33,7 @@ import { Result } from "../types/fp.js"; * Attach a control to an opaque external target. * * Each binding row is scoped to the request namespace as resolved by - * ``get_namespace_key``. The auth chain still runs via - * ``require_operation`` for authentication and authorization, but the - * storage namespace is taken from the same resolver the rest of the - * server uses so binding writes and runtime reads stay in lockstep - * until auth-derived namespace resolution lands across every endpoint. + * the active authorizer. */ export function controlBindingsCreate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-list.ts b/sdks/typescript/src/generated/funcs/control-bindings-list.ts index 5e7e87c3..5c90c7c2 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-list.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-list.ts @@ -35,8 +35,7 @@ import { Result } from "../types/fp.js"; * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by ``get_namespace_key`` so this - * listing stays in lockstep with the rest of the server's reads. + * storage namespace is resolved by the active authorizer. */ export function controlBindingsList( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/sdk/agents.ts b/sdks/typescript/src/generated/sdk/agents.ts index a22f4209..0a70e128 100644 --- a/sdks/typescript/src/generated/sdk/agents.ts +++ b/sdks/typescript/src/generated/sdk/agents.ts @@ -39,7 +39,7 @@ export class Agents extends ClientSDK { * limit: Pagination limit (default 20, max 100) * name: Optional name filter (case-insensitive partial match) * db: Database session (injected) - * namespace_key: Resolved namespace for the request + * principal: Authorized request principal * * Returns: * ListAgentsResponse with agent summaries and pagination info @@ -80,6 +80,7 @@ export class Agents extends ClientSDK { * Args: * request: Agent metadata and step schemas * db: Database session (injected) + * principal: Authorized request principal * * Returns: * InitAgentResponse with created flag and the effective controls @@ -106,8 +107,7 @@ export class Agents extends ClientSDK { * Args: * agent_name: Agent identifier * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * GetAgentResponse with agent metadata and step list @@ -140,6 +140,7 @@ export class Agents extends ClientSDK { * agent_name: Agent identifier * request: Lists of step/evaluator identifiers to remove * db: Database session (injected) + * principal: Authorized request principal * * Returns: * PatchAgentResponse with lists of actually removed items @@ -185,7 +186,7 @@ export class Agents extends ClientSDK { * target_type: Optional opaque target kind (paired with target_id) * target_id: Optional opaque target identifier (paired with target_type) * db: Database session (injected) - * namespace_key: Namespace scoping for the resolution (injected) + * principal: Authorized request principal * * Returns: * AgentControlsResponse with controls matching the requested state filters @@ -256,8 +257,7 @@ export class Agents extends ClientSDK { * cursor: Optional cursor for pagination (name of last evaluator from previous page) * limit: Pagination limit (default 20, max 100) * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * ListEvaluatorsResponse with evaluator schemas and pagination @@ -287,8 +287,7 @@ export class Agents extends ClientSDK { * agent_name: Agent identifier * evaluator_name: Name of the evaluator * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * EvaluatorSchemaItem with schema details diff --git a/sdks/typescript/src/generated/sdk/control-bindings.ts b/sdks/typescript/src/generated/sdk/control-bindings.ts index 5101ce74..dc6f20d3 100644 --- a/sdks/typescript/src/generated/sdk/control-bindings.ts +++ b/sdks/typescript/src/generated/sdk/control-bindings.ts @@ -23,8 +23,7 @@ export class ControlBindings extends ClientSDK { * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by ``get_namespace_key`` so this - * listing stays in lockstep with the rest of the server's reads. + * storage namespace is resolved by the active authorizer. */ async list( request?: @@ -46,11 +45,7 @@ export class ControlBindings extends ClientSDK { * Attach a control to an opaque external target. * * Each binding row is scoped to the request namespace as resolved by - * ``get_namespace_key``. The auth chain still runs via - * ``require_operation`` for authentication and authorization, but the - * storage namespace is taken from the same resolver the rest of the - * server uses so binding writes and runtime reads stay in lockstep - * until auth-derived namespace resolution lands across every endpoint. + * the active authorizer. */ async create( request: models.CreateControlBindingRequest, From 5fe5c6ef5a8453f3f191e695e3af8e51008ba407 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Thu, 7 May 2026 23:07:04 +0530 Subject: [PATCH 16/42] fix(server): address runtime auth review feedback --- .../funcs/auth-runtime-token-exchange.ts | 9 ++--- .../funcs/control-bindings-create.ts | 4 +- .../funcs/control-bindings-delete.ts | 2 +- .../generated/funcs/control-bindings-get.ts | 5 +-- .../generated/funcs/control-bindings-list.ts | 2 +- .../funcs/control-bindings-update.ts | 2 +- sdks/typescript/src/generated/sdk/auth.ts | 9 ++--- .../src/generated/sdk/control-bindings.ts | 15 ++++--- .../auth_framework/core.py | 2 - .../auth_framework/providers/header.py | 2 - .../agent_control_server/endpoints/auth.py | 19 +++++---- .../endpoints/control_bindings.py | 15 ++++--- .../endpoints/controls.py | 6 +-- server/src/agent_control_server/namespace.py | 23 ----------- .../agent_control_server/services/controls.py | 23 ++++++----- server/tests/test_auth_framework.py | 24 +++++++++++ server/tests/test_controls_auth.py | 12 +++--- .../test_runtime_token_exchange_endpoint.py | 36 ++++++++++++++++- server/tests/test_services_controls.py | 40 ++++++++++++++----- 19 files changed, 146 insertions(+), 104 deletions(-) delete mode 100644 server/src/agent_control_server/namespace.py diff --git a/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts b/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts index 176693e3..7e8679c8 100644 --- a/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts +++ b/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts @@ -32,11 +32,10 @@ import { Result } from "../types/fp.js"; * @remarks * Mint a short-lived runtime token for the requested target. * - * The caller's credential is authenticated and authorized by the - * installed default authorizer; the resulting :class:`Principal` - * supplies the actor identity and (when the upstream surfaces it) - * the grant scopes and expiry. This endpoint then mints a local HS256 - * token whose lifetime cannot outlive the upstream grant. + * The caller's credential is authenticated and authorized before the + * resolved principal supplies the actor identity, grant scopes, and + * expiry. This endpoint then mints a local HS256 token whose lifetime + * cannot outlive the grant. * * Runtime auth must be enabled via * ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET``; otherwise the endpoint diff --git a/sdks/typescript/src/generated/funcs/control-bindings-create.ts b/sdks/typescript/src/generated/funcs/control-bindings-create.ts index 71dee5a0..faf99923 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-create.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-create.ts @@ -32,8 +32,8 @@ import { Result } from "../types/fp.js"; * @remarks * Attach a control to an opaque external target. * - * Each binding row is scoped to the request namespace as resolved by - * the active authorizer. + * Each binding row is scoped to the namespace associated with the + * authenticated request. */ export function controlBindingsCreate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-delete.ts b/sdks/typescript/src/generated/funcs/control-bindings-delete.ts index 9e4d1293..9872a9b4 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-delete.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-delete.ts @@ -36,7 +36,7 @@ import { Result } from "../types/fp.js"; * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``POST /by-key:delete`` for - * target-scoped detach that forwards the target to the authorizer. + * target-scoped detach that includes the target in the request context. */ export function controlBindingsDelete( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-get.ts b/sdks/typescript/src/generated/funcs/control-bindings-get.ts index dafb7c7c..88b4e419 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-get.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-get.ts @@ -34,12 +34,11 @@ import { Result } from "../types/fp.js"; * Read a single control binding by surrogate ID. * * Authorization is namespace-wide: the binding's target identifiers - * are not forwarded to the upstream because they are only discoverable - * after the row is loaded, and ``require_operation`` is single-pass. + * are not available until after the row is loaded. * Callers whose authorization model requires per-target permissions * should use the natural-key endpoints (``PUT /by-key``, * ``POST /by-key:delete``) and the target-filtered list endpoint, all - * of which forward ``(target_type, target_id)`` to the authorizer. + * of which include ``(target_type, target_id)`` in the request context. */ export function controlBindingsGet( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-list.ts b/sdks/typescript/src/generated/funcs/control-bindings-list.ts index 5c90c7c2..a87ca89f 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-list.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-list.ts @@ -35,7 +35,7 @@ import { Result } from "../types/fp.js"; * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by the active authorizer. + * storage namespace is resolved from the authenticated request. */ export function controlBindingsList( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-update.ts b/sdks/typescript/src/generated/funcs/control-bindings-update.ts index b3faf800..b94520a2 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-update.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-update.ts @@ -36,7 +36,7 @@ import { Result } from "../types/fp.js"; * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``PUT /by-key`` for target-scoped - * upserts that forward the target to the authorizer. + * upserts that include the target in the request context. */ export function controlBindingsUpdate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/sdk/auth.ts b/sdks/typescript/src/generated/sdk/auth.ts index cf6de9ba..2d0cf74e 100644 --- a/sdks/typescript/src/generated/sdk/auth.ts +++ b/sdks/typescript/src/generated/sdk/auth.ts @@ -14,11 +14,10 @@ export class Auth extends ClientSDK { * @remarks * Mint a short-lived runtime token for the requested target. * - * The caller's credential is authenticated and authorized by the - * installed default authorizer; the resulting :class:`Principal` - * supplies the actor identity and (when the upstream surfaces it) - * the grant scopes and expiry. This endpoint then mints a local HS256 - * token whose lifetime cannot outlive the upstream grant. + * The caller's credential is authenticated and authorized before the + * resolved principal supplies the actor identity, grant scopes, and + * expiry. This endpoint then mints a local HS256 token whose lifetime + * cannot outlive the grant. * * Runtime auth must be enabled via * ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET``; otherwise the endpoint diff --git a/sdks/typescript/src/generated/sdk/control-bindings.ts b/sdks/typescript/src/generated/sdk/control-bindings.ts index dc6f20d3..5a5bcf2b 100644 --- a/sdks/typescript/src/generated/sdk/control-bindings.ts +++ b/sdks/typescript/src/generated/sdk/control-bindings.ts @@ -23,7 +23,7 @@ export class ControlBindings extends ClientSDK { * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by the active authorizer. + * storage namespace is resolved from the authenticated request. */ async list( request?: @@ -44,8 +44,8 @@ export class ControlBindings extends ClientSDK { * @remarks * Attach a control to an opaque external target. * - * Each binding row is scoped to the request namespace as resolved by - * the active authorizer. + * Each binding row is scoped to the namespace associated with the + * authenticated request. */ async create( request: models.CreateControlBindingRequest, @@ -104,7 +104,7 @@ export class ControlBindings extends ClientSDK { * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``POST /by-key:delete`` for - * target-scoped detach that forwards the target to the authorizer. + * target-scoped detach that includes the target in the request context. */ async delete( request: @@ -125,12 +125,11 @@ export class ControlBindings extends ClientSDK { * Read a single control binding by surrogate ID. * * Authorization is namespace-wide: the binding's target identifiers - * are not forwarded to the upstream because they are only discoverable - * after the row is loaded, and ``require_operation`` is single-pass. + * are not available until after the row is loaded. * Callers whose authorization model requires per-target permissions * should use the natural-key endpoints (``PUT /by-key``, * ``POST /by-key:delete``) and the target-filtered list endpoint, all - * of which forward ``(target_type, target_id)`` to the authorizer. + * of which include ``(target_type, target_id)`` in the request context. */ async get( request: @@ -153,7 +152,7 @@ export class ControlBindings extends ClientSDK { * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``PUT /by-key`` for target-scoped - * upserts that forward the target to the authorizer. + * upserts that include the target in the request context. */ async update( request: diff --git a/server/src/agent_control_server/auth_framework/core.py b/server/src/agent_control_server/auth_framework/core.py index e0ea6da7..058169de 100644 --- a/server/src/agent_control_server/auth_framework/core.py +++ b/server/src/agent_control_server/auth_framework/core.py @@ -52,11 +52,9 @@ class Operation(StrEnum): POLICIES_READ = "policies.read" POLICIES_CREATE = "policies.create" POLICIES_UPDATE = "policies.update" - POLICIES_DELETE = "policies.delete" AGENTS_READ = "agents.read" AGENTS_CREATE = "agents.create" AGENTS_UPDATE = "agents.update" - AGENTS_DELETE = "agents.delete" RUNTIME_USE = "runtime.use" diff --git a/server/src/agent_control_server/auth_framework/providers/header.py b/server/src/agent_control_server/auth_framework/providers/header.py index 228ec443..16760768 100644 --- a/server/src/agent_control_server/auth_framework/providers/header.py +++ b/server/src/agent_control_server/auth_framework/providers/header.py @@ -45,11 +45,9 @@ class AccessLevel(Enum): Operation.POLICIES_READ: AccessLevel.AUTHENTICATED, Operation.POLICIES_CREATE: AccessLevel.ADMIN, Operation.POLICIES_UPDATE: AccessLevel.ADMIN, - Operation.POLICIES_DELETE: AccessLevel.ADMIN, Operation.AGENTS_READ: AccessLevel.AUTHENTICATED, Operation.AGENTS_CREATE: AccessLevel.AUTHENTICATED, Operation.AGENTS_UPDATE: AccessLevel.ADMIN, - Operation.AGENTS_DELETE: AccessLevel.ADMIN, Operation.RUNTIME_TOKEN_EXCHANGE: AccessLevel.AUTHENTICATED, Operation.RUNTIME_USE: AccessLevel.AUTHENTICATED, } diff --git a/server/src/agent_control_server/endpoints/auth.py b/server/src/agent_control_server/endpoints/auth.py index f80cd2fa..b1ade969 100644 --- a/server/src/agent_control_server/endpoints/auth.py +++ b/server/src/agent_control_server/endpoints/auth.py @@ -2,13 +2,13 @@ The runtime auth flow is two-phase: this endpoint is phase one. The caller presents a long-lived credential plus ``(target_type, -target_id)``; the default authorizer authenticates the credential and -authorizes the implied -:data:`Operation.RUNTIME_TOKEN_EXCHANGE`. On success, this endpoint +target_id)``; the configured authorization provider authenticates the +credential and authorizes the implied +``runtime.token_exchange`` operation. On success, this endpoint mints a short-lived local runtime token bound to the supplied target and returns it. Subsequent target-bearing runtime calls present the returned token, which is verified locally by -:class:`LocalJwtVerifyProvider`. +the runtime JWT provider. """ from __future__ import annotations @@ -56,7 +56,7 @@ class RuntimeTokenExchangeResponse(BaseModel): async def _exchange_context(request: Request) -> dict[str, Any]: - """Surface target identifiers to the authorizer's context. + """Surface target identifiers to the authorization context. Reads the request body once. FastAPI caches the parsed body, so the endpoint's own Pydantic body model still binds normally. @@ -89,11 +89,10 @@ async def runtime_token_exchange( ) -> RuntimeTokenExchangeResponse: """Mint a short-lived runtime token for the requested target. - The caller's credential is authenticated and authorized by the - installed default authorizer; the resulting :class:`Principal` - supplies the actor identity and (when the upstream surfaces it) - the grant scopes and expiry. This endpoint then mints a local HS256 - token whose lifetime cannot outlive the upstream grant. + The caller's credential is authenticated and authorized before the + resolved principal supplies the actor identity, grant scopes, and + expiry. This endpoint then mints a local HS256 token whose lifetime + cannot outlive the grant. Runtime auth must be enabled via ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET``; otherwise the endpoint diff --git a/server/src/agent_control_server/endpoints/control_bindings.py b/server/src/agent_control_server/endpoints/control_bindings.py index d2fe4b44..87386723 100644 --- a/server/src/agent_control_server/endpoints/control_bindings.py +++ b/server/src/agent_control_server/endpoints/control_bindings.py @@ -102,8 +102,8 @@ async def create_control_binding( ) -> CreateControlBindingResponse: """Attach a control to an opaque external target. - Each binding row is scoped to the request namespace as resolved by - the active authorizer. + Each binding row is scoped to the namespace associated with the + authenticated request. """ service = ControlBindingsService(db) binding = await service.create_binding( @@ -153,7 +153,7 @@ async def list_control_bindings( cursor-based pagination. Bindings are ordered by ID descending (newest first). The cursor is opaque to clients: pass back the ``next_cursor`` value verbatim to fetch the following page. The - storage namespace is resolved by the active authorizer. + storage namespace is resolved from the authenticated request. """ parsed_cursor: int | None if cursor is None: @@ -201,12 +201,11 @@ async def get_control_binding( """Read a single control binding by surrogate ID. Authorization is namespace-wide: the binding's target identifiers - are not forwarded to the upstream because they are only discoverable - after the row is loaded, and ``require_operation`` is single-pass. + are not available until after the row is loaded. Callers whose authorization model requires per-target permissions should use the natural-key endpoints (``PUT /by-key``, ``POST /by-key:delete``) and the target-filtered list endpoint, all - of which forward ``(target_type, target_id)`` to the authorizer. + of which include ``(target_type, target_id)`` in the request context. """ service = ControlBindingsService(db) binding = await service.get_binding_or_404( @@ -232,7 +231,7 @@ async def patch_control_binding( See the GET-by-id docstring for the authorization scope: this route is namespace-wide because the target identifiers are not available before the binding is loaded. Use ``PUT /by-key`` for target-scoped - upserts that forward the target to the authorizer. + upserts that include the target in the request context. """ service = ControlBindingsService(db) binding = await service.set_enabled( @@ -260,7 +259,7 @@ async def delete_control_binding( See the GET-by-id docstring for the authorization scope: this route is namespace-wide because the target identifiers are not available before the binding is loaded. Use ``POST /by-key:delete`` for - target-scoped detach that forwards the target to the authorizer. + target-scoped detach that includes the target in the request context. """ service = ControlBindingsService(db) await service.delete_binding(namespace_key=principal.namespace_key, binding_id=binding_id) diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index 5b01593c..00d2b710 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -787,12 +787,12 @@ async def set_control_data( summary="Validate control configuration", response_description="Validation result", ) -# Authorized as CONTROLS_READ: validate exercises the materialization -# path but does not mutate stored control data. +# Authorized as CONTROLS_CREATE: validate exercises the same materialization +# path as create/update authoring flows, even though it does not save. async def validate_control_data( request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> ValidateControlDataResponse: """ Validate control configuration data without saving it. diff --git a/server/src/agent_control_server/namespace.py b/server/src/agent_control_server/namespace.py deleted file mode 100644 index 30e30be5..00000000 --- a/server/src/agent_control_server/namespace.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Namespace resolution for request-scoped scoping. - -V1 always resolves to the default namespace. The function exists as a -single seam so a future change can switch every namespace-scoped -endpoint to a real per-request resolver without touching each call -site. Overriding the dependency in V1 is not supported: only this -binding/evaluation layer reads it; controls, agents, and policies still -write under the default namespace, so an override here would create -inconsistent rows. Future work will thread a single resolver through -every write path together. -""" - -from __future__ import annotations - -from .models import DEFAULT_NAMESPACE_KEY - - -def get_namespace_key() -> str: - """Return the namespace_key for the current request. - - V1 returns ``DEFAULT_NAMESPACE_KEY`` unconditionally. - """ - return DEFAULT_NAMESPACE_KEY diff --git a/server/src/agent_control_server/services/controls.py b/server/src/agent_control_server/services/controls.py index 41a62282..e3a5fd26 100644 --- a/server/src/agent_control_server/services/controls.py +++ b/server/src/agent_control_server/services/controls.py @@ -20,7 +20,6 @@ from ..errors import APIValidationError, NotFoundError from ..models import ( - DEFAULT_NAMESPACE_KEY, Control, ControlBinding, ControlVersion, @@ -100,7 +99,7 @@ def __init__(self, db: AsyncSession) -> None: def create_control( self, *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, name: str, data: dict[str, Any], ) -> Control: @@ -161,17 +160,19 @@ async def get_active_control_or_404( control_id: int, *, for_update: bool = False, - namespace_key: str | None = None, + namespace_key: str, ) -> Control: """Load an active control row or raise CONTROL_NOT_FOUND. - When ``namespace_key`` is supplied, the lookup is scoped to that - namespace; a control that exists only in another namespace - surfaces as 404 (non-disclosing). + The lookup is scoped to the supplied namespace; a control that + exists only in another namespace surfaces as 404 + (non-disclosing). """ - stmt = select(Control).where(Control.id == control_id, Control.deleted_at.is_(None)) - if namespace_key is not None: - stmt = stmt.where(Control.namespace_key == namespace_key) + stmt = select(Control).where( + Control.id == control_id, + Control.namespace_key == namespace_key, + Control.deleted_at.is_(None), + ) if for_update: stmt = stmt.with_for_update() result = await self._db.execute(stmt) @@ -190,7 +191,7 @@ async def active_control_name_exists( self, name: str, *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, exclude_control_id: int | None = None, ) -> bool: """Return whether an active control already uses the provided name.""" @@ -537,7 +538,7 @@ async def list_active_control_counts_by_agent( self, agent_names: Sequence[str], *, - namespace_key: str = DEFAULT_NAMESPACE_KEY, + namespace_key: str, ) -> dict[str, int]: """Return active control counts keyed by agent name.""" if not agent_names: diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 2d39bfa3..799b2d52 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -20,6 +20,7 @@ AccessLevel, HeaderAuthProvider, HttpUpstreamAuthProvider, + LocalJwtVerifyProvider, NoAuthProvider, ) from agent_control_server.auth_framework.providers.header import ( @@ -1029,6 +1030,29 @@ def test_configure_runtime_api_key_ignores_jwt_secret(monkeypatch): assert auth_config.runtime_auth_config() is None +@pytest.mark.asyncio +async def test_configure_http_upstream_management_with_jwt_runtime(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "http_upstream") + monkeypatch.setenv("AGENT_CONTROL_AUTH_UPSTREAM_URL", "https://auth.example.test/check") + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "jwt") + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", _TEST_SECRET) + + try: + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.CONTROLS_READ), HttpUpstreamAuthProvider) + assert isinstance(get_authorizer(Operation.RUNTIME_USE), LocalJwtVerifyProvider) + runtime_config = auth_config.runtime_auth_config() + assert runtime_config is not None + assert runtime_config.secret == _TEST_SECRET + finally: + await auth_config.teardown_auth() + + def test_configure_runtime_jwt_requires_secret(monkeypatch): from agent_control_server.auth_framework import config as auth_config diff --git a/server/tests/test_controls_auth.py b/server/tests/test_controls_auth.py index c0f17754..04f44ca4 100644 --- a/server/tests/test_controls_auth.py +++ b/server/tests/test_controls_auth.py @@ -4,10 +4,9 @@ import uuid -from fastapi.testclient import TestClient - from agent_control_server.auth_framework import set_authorizer from agent_control_server.auth_framework.providers import NoAuthProvider +from fastapi.testclient import TestClient from .utils import VALID_CONTROL_PAYLOAD @@ -198,19 +197,18 @@ def test_non_admin_cannot_delete_control( assert resp.status_code == 403, resp.text -def test_non_admin_can_validate_control_data( +def test_non_admin_cannot_validate_control_data( non_admin_client: TestClient, ) -> None: - """``/controls/validate`` requires ``CONTROLS_READ``.""" + """``/controls/validate`` requires ``CONTROLS_CREATE``.""" # When: a non-admin attempts to validate a draft payload resp = non_admin_client.post( f"{_CONTROLS_URL}/validate", json={"data": VALID_CONTROL_PAYLOAD}, ) - # Then: validation is allowed for authenticated non-admin callers - assert resp.status_code == 200, resp.text - assert resp.json()["success"] is True + # Then: validation is admin-only + assert resp.status_code == 403, resp.text def test_non_admin_cannot_render_template(non_admin_client: TestClient) -> None: diff --git a/server/tests/test_runtime_token_exchange_endpoint.py b/server/tests/test_runtime_token_exchange_endpoint.py index 8d333a5c..1b1edae2 100644 --- a/server/tests/test_runtime_token_exchange_endpoint.py +++ b/server/tests/test_runtime_token_exchange_endpoint.py @@ -11,8 +11,6 @@ from datetime import UTC, datetime, timedelta import pytest -from fastapi.testclient import TestClient - from agent_control_server.auth_framework import Operation, Principal from agent_control_server.auth_framework.config import ( RuntimeAuthConfig, @@ -25,6 +23,7 @@ from agent_control_server.auth_framework.providers import ( LocalJwtVerifyProvider, ) +from fastapi.testclient import TestClient _TEST_SECRET = "test-runtime-secret-12345678901234567890" @@ -180,6 +179,39 @@ async def test_exchange_then_verify_full_round_trip(client: TestClient, runtime_ assert principal.caller_id == "actor-rt" +def test_evaluation_rejects_runtime_jwt_for_wrong_target( + client: TestClient, + runtime_config_enabled, +): + """A runtime JWT minted for one target cannot be used for another target.""" + stub = _StubExchangeAuthorizer(actor_id="actor-rt", scopes=("runtime.use",)) + clear_authorizers() + set_authorizer(stub) + set_authorizer(LocalJwtVerifyProvider(secret=_TEST_SECRET), operation=Operation.RUNTIME_USE) + + exchange = client.post( + "/api/v1/auth/runtime-token-exchange", + json={"target_type": "log_stream", "target_id": "ls-allowed"}, + ) + assert exchange.status_code == 200, exchange.text + token = exchange.json()["token"] + + response = client.post( + "/api/v1/evaluation", + headers={"Authorization": f"Bearer {token}"}, + json={ + "agent_name": "agent", + "step": {"type": "llm", "name": "step", "input": "hello"}, + "stage": "pre", + "target_type": "log_stream", + "target_id": "ls-other", + }, + ) + + assert response.status_code == 403, response.text + assert response.json()["detail"] == "Runtime token target_id does not match the request." + + def test_exchange_endpoint_502_when_upstream_grant_already_expired( client: TestClient, runtime_config_enabled, diff --git a/server/tests/test_services_controls.py b/server/tests/test_services_controls.py index b858c527..3815f26b 100644 --- a/server/tests/test_services_controls.py +++ b/server/tests/test_services_controls.py @@ -8,10 +8,6 @@ import pytest from agent_control_models.errors import ErrorCode -from sqlalchemy import insert, select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Session - from agent_control_server.errors import APIValidationError from agent_control_server.models import ( DEFAULT_NAMESPACE_KEY, @@ -27,6 +23,9 @@ from agent_control_server.services.controls import ( ControlService, ) +from sqlalchemy import insert, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session from .conftest import AsyncSessionTest, engine from .utils import VALID_CONTROL_PAYLOAD @@ -70,7 +69,11 @@ async def _create_versioned_control( async with AsyncSessionTest() as session: service = ControlService(session) - control = service.create_control(name=control_name, data=control_data) + control = service.create_control( + namespace_key=DEFAULT_NAMESPACE_KEY, + name=control_name, + data=control_data, + ) await service.create_version( control, event_type="created", @@ -143,6 +146,7 @@ async def test_create_control_transaction_rollback_does_not_persist_control_or_v async with AsyncSessionTest() as session: service = ControlService(session) control = service.create_control( + namespace_key=DEFAULT_NAMESPACE_KEY, name=control_name, data=deepcopy(VALID_CONTROL_PAYLOAD), ) @@ -167,7 +171,10 @@ async def test_replace_control_data_transaction_rollback_preserves_prior_state() async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) updated_data = deepcopy(control.data) updated_data["description"] = "Should not persist" service.replace_control_data(control, data=updated_data) @@ -194,7 +201,10 @@ async def test_patch_mutation_transaction_rollback_preserves_prior_state() -> No async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) service.rename_control(control, name=f"{control_name}-renamed") service.set_control_enabled(control, enabled=False) await service.create_version( @@ -221,7 +231,10 @@ async def test_delete_control_transaction_rollback_preserves_active_state() -> N async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) service.mark_control_deleted(control, deleted_at=dt.datetime.now(dt.UTC)) await service.create_version( control, @@ -511,7 +524,10 @@ async def test_list_active_control_counts_by_agent_deduplicates_and_filters_inac await async_db.commit() # When: counting active controls for the agent - counts = await ControlService(async_db).list_active_control_counts_by_agent([agent.name]) + counts = await ControlService(async_db).list_active_control_counts_by_agent( + [agent.name], + namespace_key=DEFAULT_NAMESPACE_KEY, + ) # Then: active controls are deduplicated and inactive controls are excluded assert counts == {agent.name: 2} @@ -572,6 +588,7 @@ async def test_create_version_allocates_sequential_numbers_under_concurrent_muta async with AsyncSessionTest() as setup_session: setup_service = ControlService(setup_session) control = setup_service.create_control( + namespace_key=DEFAULT_NAMESPACE_KEY, name=f"control-{uuid.uuid4()}", data=deepcopy(VALID_CONTROL_PAYLOAD), ) @@ -592,7 +609,10 @@ async def mutate_and_version(description: str) -> None: async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) updated_data = deepcopy(control.data) updated_data["description"] = description service.replace_control_data(control, data=updated_data) From 9a8c8e93849f4b0ed1c7e930b7ae127d00a50b85 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 8 May 2026 16:43:39 +0530 Subject: [PATCH 17/42] feat(server): operator-configurable extra forwarded headers on HttpUpstream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The default forward set (X-API-Key, Authorization, Cookie) only covers credential headers Agent Control itself reads. Deployments whose upstream authenticates against a different header name (e.g., a deployer-specific API-key header) had no way to surface that credential through HttpUpstreamAuthProvider — the inbound header reached AC but never crossed the upstream call. Add an extra_forward_headers config field on HttpUpstreamConfig (defaulting to the empty tuple) that operators populate via the new AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS env var (comma- separated). The provider's _forward_headers iterates over the union of the default set and the extras, deduplicating case-insensitively so a duplicate name (cross-set or within extras) does not produce two copies on the wire. Tests: - forwards a configured extra header alongside defaults - default forward set unchanged when extras are empty - extras dedupe against defaults case-insensitively - _parse_extra_forward_headers parametric: None / empty / single / multiple / whitespace / empty-entries / case-folded duplicates - configure_auth_from_env threads the parsed tuple onto the provider Lint clean, typecheck clean, full server suite (747) green. --- .../auth_framework/config.py | 29 +++++ .../auth_framework/providers/http_upstream.py | 20 ++- .../endpoints/controls.py | 3 +- server/tests/test_auth_framework.py | 115 ++++++++++++++++++ 4 files changed, 163 insertions(+), 4 deletions(-) diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index c8f428dc..8c39a2ec 100644 --- a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -46,6 +46,7 @@ _UPSTREAM_TIMEOUT_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_TIMEOUT_SECONDS" _UPSTREAM_TOKEN_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN" _UPSTREAM_TOKEN_HEADER_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN_HEADER" +_UPSTREAM_EXTRA_FORWARD_HEADERS_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS" # Runtime flow. _RUNTIME_MODE_ENV = "AGENT_CONTROL_RUNTIME_AUTH_MODE" @@ -196,6 +197,9 @@ def _build_default_provider() -> RequestAuthorizer: timeout = float(os.environ.get(_UPSTREAM_TIMEOUT_ENV, "5.0")) token = os.environ.get(_UPSTREAM_TOKEN_ENV) token_header = os.environ.get(_UPSTREAM_TOKEN_HEADER_ENV, "X-Agent-Control-Service-Token") + extra_forward_headers = _parse_extra_forward_headers( + os.environ.get(_UPSTREAM_EXTRA_FORWARD_HEADERS_ENV) + ) _logger.info("Default auth provider: http_upstream url=%s", url) return HttpUpstreamAuthProvider( HttpUpstreamConfig( @@ -203,6 +207,7 @@ def _build_default_provider() -> RequestAuthorizer: timeout_seconds=timeout, service_token=token, service_token_header=token_header, + extra_forward_headers=extra_forward_headers, ) ) raise RuntimeError( @@ -210,6 +215,30 @@ def _build_default_provider() -> RequestAuthorizer: ) +def _parse_extra_forward_headers(raw: str | None) -> tuple[str, ...]: + """Parse a comma-separated header list into a deduplicated tuple. + + Empty / unset env var returns an empty tuple. Whitespace around each + name is stripped. Empty entries (e.g. ``"X-A,,X-B"``) are dropped. + Order is preserved; duplicates (case-insensitive) are dropped after + the first occurrence. + """ + if not raw or not raw.strip(): + return () + seen: set[str] = set() + result: list[str] = [] + for raw_name in raw.split(","): + name = raw_name.strip() + if not name: + continue + lower = name.lower() + if lower in seen: + continue + seen.add(lower) + result.append(name) + return tuple(result) + + def _resolve_runtime_mode() -> str: raw = os.environ.get(_RUNTIME_MODE_ENV) if raw is None or not raw.strip(): diff --git a/server/src/agent_control_server/auth_framework/providers/http_upstream.py b/server/src/agent_control_server/auth_framework/providers/http_upstream.py index 8d5c850c..78ed9ae2 100644 --- a/server/src/agent_control_server/auth_framework/providers/http_upstream.py +++ b/server/src/agent_control_server/auth_framework/providers/http_upstream.py @@ -60,7 +60,7 @@ _logger = get_logger(__name__) -_FORWARDED_HEADERS = ("X-API-Key", "Authorization", "Cookie") +_DEFAULT_FORWARDED_HEADERS = ("X-API-Key", "Authorization", "Cookie") class _UpstreamGrant(BaseModel): @@ -136,6 +136,17 @@ class HttpUpstreamConfig: service_token_header: str = "X-Agent-Control-Service-Token" + extra_forward_headers: tuple[str, ...] = () + """Additional inbound request headers to forward to the upstream + on top of the default ``(X-API-Key, Authorization, Cookie)`` set. + + Use this when the upstream authenticates via a header the provider + does not forward by default (e.g., a deployer-specific API-key + header). Header lookups against the inbound request are + case-insensitive; an empty or absent inbound header is silently + dropped. Names duplicating the default set or each other (after + case-folding) are deduplicated.""" + class HttpUpstreamAuthProvider(RequestAuthorizer): """Delegates authorization to an upstream HTTP service.""" @@ -190,7 +201,12 @@ async def authorize( def _forward_headers(self, request: Request) -> dict[str, str]: headers: dict[str, str] = {} - for name in _FORWARDED_HEADERS: + seen: set[str] = set() + for name in (*_DEFAULT_FORWARDED_HEADERS, *self._config.extra_forward_headers): + lower = name.lower() + if lower in seen: + continue + seen.add(lower) value = request.headers.get(name) if value is not None: headers[name] = value diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index 00d2b710..b4fa8d0b 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -787,8 +787,7 @@ async def set_control_data( summary="Validate control configuration", response_description="Validation result", ) -# Authorized as CONTROLS_CREATE: validate exercises the same materialization -# path as create/update authoring flows, even though it does not save. +# Validation uses the authoring path, so require create access. async def validate_control_data( request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db), diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 799b2d52..dc3a1787 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -261,6 +261,75 @@ def factory(request: httpx.Request) -> httpx.Response: assert captured["headers"]["x-custom-token"] == "shh" +@pytest.mark.asyncio +async def test_http_upstream_forwards_extra_headers(): + # Given: a provider configured with an extra header in its forward list + captured: dict[str, Any] = {} + + def factory(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"namespace_key": "ns"}) + + provider = _build_upstream( + factory, + config_overrides={"extra_forward_headers": ("X-Deployer-Auth",)}, + ) + + # When: the inbound request carries the extra header + inbound = _build_request(headers={"X-Deployer-Auth": "k_abc", "X-API-Key": "k1"}) + await provider.authorize(inbound, Operation.CONTROL_BINDINGS_READ) + + # Then: both the default and the extra header reach the upstream + assert captured["headers"]["x-deployer-auth"] == "k_abc" + assert captured["headers"]["x-api-key"] == "k1" + + +@pytest.mark.asyncio +async def test_http_upstream_default_forward_set_unchanged(): + # Given: a provider with no extra_forward_headers + captured: dict[str, Any] = {} + + def factory(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"namespace_key": "ns"}) + + provider = _build_upstream(factory) + + # When: the inbound carries an unlisted header alongside a default one + inbound = _build_request( + headers={"X-API-Key": "k1", "X-Deployer-Auth": "should-not-forward"} + ) + await provider.authorize(inbound, Operation.CONTROL_BINDINGS_READ) + + # Then: only the default-set header reaches the upstream + assert captured["headers"].get("x-api-key") == "k1" + assert "x-deployer-auth" not in captured["headers"] + + +@pytest.mark.asyncio +async def test_http_upstream_extra_forward_dedupes_against_defaults(): + # Given: extra list duplicates a default header (different case) + captured: dict[str, Any] = {} + + def factory(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"namespace_key": "ns"}) + + provider = _build_upstream( + factory, + config_overrides={"extra_forward_headers": ("x-api-key", "Authorization")}, + ) + + # When: inbound has both + inbound = _build_request(headers={"X-API-Key": "k1", "Authorization": "Bearer t"}) + await provider.authorize(inbound, Operation.CONTROL_BINDINGS_READ) + + # Then: each header appears exactly once on the upstream request + forwarded = captured["headers"] + assert sum(1 for k in forwarded if k.lower() == "x-api-key") == 1 + assert sum(1 for k in forwarded if k.lower() == "authorization") == 1 + + @pytest.mark.asyncio @pytest.mark.parametrize( "status, expected", @@ -1053,6 +1122,52 @@ async def test_configure_http_upstream_management_with_jwt_runtime(monkeypatch): await auth_config.teardown_auth() +@pytest.mark.parametrize( + "raw, expected", + [ + (None, ()), + ("", ()), + (" ", ()), + ("X-One", ("X-One",)), + ("X-One,X-Two", ("X-One", "X-Two")), + (" X-One , X-Two ", ("X-One", "X-Two")), + ("X-One,,X-Two", ("X-One", "X-Two")), + ("X-One,x-one,X-One", ("X-One",)), + ("X-A,X-B,x-a,X-C,X-b", ("X-A", "X-B", "X-C")), + ], +) +def test_parse_extra_forward_headers(raw, expected): + from agent_control_server.auth_framework.config import _parse_extra_forward_headers + + assert _parse_extra_forward_headers(raw) == expected + + +@pytest.mark.asyncio +async def test_configure_http_upstream_extra_forward_headers_env(monkeypatch): + """Setting the env var threads extra_forward_headers into the provider.""" + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "http_upstream") + monkeypatch.setenv("AGENT_CONTROL_AUTH_UPSTREAM_URL", "https://auth.example.test/check") + monkeypatch.setenv( + "AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS", + "X-Deployer-Auth, X-Deployer-Trace", + ) + + try: + auth_config.configure_auth_from_env() + provider = get_authorizer(Operation.CONTROLS_READ) + assert isinstance(provider, HttpUpstreamAuthProvider) + assert provider._config.extra_forward_headers == ( + "X-Deployer-Auth", + "X-Deployer-Trace", + ) + finally: + await auth_config.teardown_auth() + + def test_configure_runtime_jwt_requires_secret(monkeypatch): from agent_control_server.auth_framework import config as auth_config From ecf2b1aa78cac66e809b053f2bd7c0c8c8b91a89 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 8 May 2026 21:36:45 +0530 Subject: [PATCH 18/42] fix(server): preserve default runtime auth fallback --- .../auth_framework/config.py | 37 +++++++++------- server/tests/test_auth_framework.py | 44 +++++++++++++++++-- 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index 8c39a2ec..595c3117 100644 --- a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -16,8 +16,8 @@ :class:`NoAuthProvider`, ``api_key`` uses :class:`HeaderAuthProvider`, and ``jwt`` uses :class:`LocalJwtVerifyProvider`. When the mode is unset, startup - preserves historical behavior by selecting ``jwt`` if - ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set, otherwise ``api_key``. + selects ``jwt`` if ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set; + otherwise runtime falls through to the default authorizer. The ``runtime.token_exchange`` operation continues to flow through the default authorizer because the exchange itself is shaped like a management call (forward credential, get grant). @@ -96,10 +96,11 @@ def configure_auth_from_env() -> None: Runtime flow: - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=none``: :class:`NoAuthProvider`. - - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=api_key`` (default when no runtime - token secret is configured): :class:`HeaderAuthProvider`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=api_key``: :class:`HeaderAuthProvider`. - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=jwt`` (default when a runtime token secret is configured): :class:`LocalJwtVerifyProvider`. + - unset mode without a runtime token secret: fall through to the default + authorizer. Clears any previously-installed default and operation overrides before installing fresh ones, so reconfiguration cannot leave @@ -121,20 +122,26 @@ def configure_auth_from_env() -> None: set_authorizer(default) _active_providers.append(default) - runtime_provider = _build_runtime_provider(runtime_mode, _runtime_auth_config) - set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) - _active_providers.append(runtime_provider) - if runtime_mode == "jwt": + if runtime_mode == "default": _logger.info( - "Runtime auth provider: jwt override installed for %s", + "Runtime auth provider: default authorizer handles %s", Operation.RUNTIME_USE.value, ) else: - _logger.info( - "Runtime auth provider: %s override installed for %s", - runtime_mode, - Operation.RUNTIME_USE.value, - ) + runtime_provider = _build_runtime_provider(runtime_mode, _runtime_auth_config) + set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) + _active_providers.append(runtime_provider) + if runtime_mode == "jwt": + _logger.info( + "Runtime auth provider: jwt override installed for %s", + Operation.RUNTIME_USE.value, + ) + else: + _logger.info( + "Runtime auth provider: %s override installed for %s", + runtime_mode, + Operation.RUNTIME_USE.value, + ) async def teardown_auth() -> None: @@ -242,7 +249,7 @@ def _parse_extra_forward_headers(raw: str | None) -> tuple[str, ...]: def _resolve_runtime_mode() -> str: raw = os.environ.get(_RUNTIME_MODE_ENV) if raw is None or not raw.strip(): - return "jwt" if os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) else "api_key" + return "jwt" if os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) else "default" mode = raw.strip().lower() if mode in {"none", "no_auth"}: diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index dc3a1787..20c58aed 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -7,7 +7,6 @@ import httpx import pytest - from agent_control_server.auth_framework.core import ( Operation, Principal, @@ -700,7 +699,6 @@ def test_runtime_token_rejects_naive_upstream_expires_at(): def test_runtime_token_rejects_management_token_passed_to_runtime_verify(): """A token without ``domain=runtime`` must be rejected by runtime verify.""" import jwt - from agent_control_server.auth_framework.runtime_token import ( RuntimeTokenError, verify_runtime_token, @@ -1053,13 +1051,13 @@ def test_build_default_provider_accepts_none_mode(monkeypatch): assert isinstance(auth_config._build_default_provider(), NoAuthProvider) -def test_resolve_runtime_mode_defaults_to_api_key_without_secret(monkeypatch): +def test_resolve_runtime_mode_defaults_to_default_without_secret(monkeypatch): from agent_control_server.auth_framework import config as auth_config monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) - assert auth_config._resolve_runtime_mode() == "api_key" + assert auth_config._resolve_runtime_mode() == "default" def test_resolve_runtime_mode_defaults_to_jwt_with_secret(monkeypatch): @@ -1099,6 +1097,44 @@ def test_configure_runtime_api_key_ignores_jwt_secret(monkeypatch): assert auth_config.runtime_auth_config() is None +def test_configure_runtime_unset_preserves_no_auth_default(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "none") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.RUNTIME_USE), NoAuthProvider) + assert auth_config.runtime_auth_config() is None + + +@pytest.mark.asyncio +async def test_configure_runtime_unset_preserves_http_upstream_default(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "http_upstream") + monkeypatch.setenv("AGENT_CONTROL_AUTH_UPSTREAM_URL", "https://auth.example.test/check") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + try: + auth_config.configure_auth_from_env() + + default_provider = get_authorizer(Operation.CONTROLS_READ) + runtime_provider = get_authorizer(Operation.RUNTIME_USE) + assert isinstance(default_provider, HttpUpstreamAuthProvider) + assert runtime_provider is default_provider + assert auth_config.runtime_auth_config() is None + finally: + await auth_config.teardown_auth() + + @pytest.mark.asyncio async def test_configure_http_upstream_management_with_jwt_runtime(monkeypatch): from agent_control_server.auth_framework import config as auth_config From 03047b2dc2ba30e65ec662a2ee1668c0981eed97 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 11 May 2026 14:39:01 +0530 Subject: [PATCH 19/42] fix(server): harden auth scoping --- docs/README.md | 1 + docs/auth.md | 148 ++++++++++++++++++ models/src/agent_control_models/server.py | 3 +- .../agent_control_server/endpoints/agents.py | 84 +++++++++- .../agent_control_server/endpoints/auth.py | 17 +- .../endpoints/controls.py | 44 +++++- server/tests/test_principal_namespace_flow.py | 33 +++- server/tests/test_target_merged_contract.py | 96 +++++++++++- 8 files changed, 402 insertions(+), 24 deletions(-) create mode 100644 docs/auth.md diff --git a/docs/README.md b/docs/README.md index 9b7cb757..e53dcf13 100644 --- a/docs/README.md +++ b/docs/README.md @@ -10,6 +10,7 @@ This repository keeps documentation concise. The full documentation lives on the - [Controls](https://docs.agentcontrol.dev/concepts/controls) — Define and configure control rules - [Reference](https://docs.agentcontrol.dev/core/reference) — SDK and server API reference - [Configuration](https://docs.agentcontrol.dev/core/configuration) — Environment variables, auth, and database settings +- [Server auth contract](auth.md) - Pluggable auth modes, HTTP upstream contract, and runtime JWT claims - [UI Quickstart](https://docs.agentcontrol.dev/core/ui-quickstart) — Run the dashboard and manage controls visually ## Examples diff --git a/docs/auth.md b/docs/auth.md new file mode 100644 index 00000000..5002faa8 --- /dev/null +++ b/docs/auth.md @@ -0,0 +1,148 @@ +# Server Auth Contract + +Agent Control keeps authentication and authorization provider-neutral. The server asks a configured provider whether a request may perform an operation, then scopes all data access with the returned `Principal`. + +## Operations + +Operations are stable strings. Deployers map them to their own permission model. + +```text +controls.read +controls.create +controls.update +controls.delete +policies.read +policies.create +policies.update +agents.read +agents.create +agents.update +control_bindings.read +control_bindings.write +runtime.token_exchange +runtime.use +``` + +## Principal + +Providers return a generic principal. Agent Control treats `namespace_key`, `caller_id`, `target_type`, and `target_id` as opaque strings. + +```json +{ + "namespace_key": "tenant-a", + "is_admin": false, + "caller_id": "user-or-key-id", + "target_type": "session", + "target_id": "target-123", + "scopes": ["runtime.use"], + "expires_at": "2026-05-11T15:00:00Z" +} +``` + +`namespace_key` is the tenancy boundary. Server queries filter by it, and namespace-aware foreign keys prevent cross-namespace references. + +## Auth Modes + +Management auth is selected by `AGENT_CONTROL_AUTH_MODE`. + +| Mode | Meaning | +| --- | --- | +| `none` | No credentials required. Intended for local development only. | +| `api_key` | Validate caller credentials locally with `AGENT_CONTROL_API_KEYS`. This is the default. `header` is accepted as a backwards-compatible alias. | +| `http_upstream` | POST each management authorization decision to `AGENT_CONTROL_AUTH_UPSTREAM_URL`. | + +Runtime auth is selected by `AGENT_CONTROL_RUNTIME_AUTH_MODE`. + +| Mode | Meaning | +| --- | --- | +| unset | Use `jwt` when `AGENT_CONTROL_RUNTIME_TOKEN_SECRET` is set. Otherwise runtime requests fall through to management auth. | +| `none` | No runtime credentials required. Intended for local development only. | +| `api_key` | Validate runtime requests with the same local API-key mechanism. | +| `jwt` | Require target-bound runtime tokens minted by `/api/v1/auth/runtime-token-exchange`. | + +Common combinations: + +| Management | Runtime | Use case | +| --- | --- | --- | +| `api_key` | unset | Existing standalone deployments. | +| `api_key` | `jwt` | Local management keys with short-lived target-bound runtime tokens. | +| `http_upstream` | `jwt` | External identity or authorization service for management, local token verify for high-volume runtime calls. | +| `none` | `none` | Single-process local development. Do not use in production. | + +## HTTP Upstream Contract + +When `AGENT_CONTROL_AUTH_MODE=http_upstream`, the server sends: + +```http +POST {AGENT_CONTROL_AUTH_UPSTREAM_URL} +``` + +```json +{ + "operation": "control_bindings.write", + "context": { + "target_type": "session", + "target_id": "target-123" + } +} +``` + +The provider forwards inbound `X-API-Key`, `Authorization`, and `Cookie` headers. Add deployer-specific header names with `AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS`, for example: + +```text +AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS=Vendor-API-Key,X-Workspace-Id +``` + +If `AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN` is set, it is forwarded on `AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN_HEADER` or `X-Agent-Control-Service-Token` by default. + +A successful upstream response is: + +```json +{ + "namespace_key": "tenant-a", + "is_admin": false, + "caller_id": "user-or-key-id", + "target_type": "session", + "target_id": "target-123", + "scopes": ["runtime.use"], + "expires_at": "2026-05-11T15:00:00Z" +} +``` + +Only `namespace_key` is always required. `target_type` and `target_id` must be returned together when present. `expires_at` must include timezone information. + +Status handling: + +| Upstream status | Agent Control result | +| --- | --- | +| `200` | Parse the principal grant. | +| `401` | Authentication error. | +| `403` | Forbidden error. | +| `404` | Not found error. | +| `429` | `503` with a rate-limit detail and `Retry-After` hint when present. | +| Other statuses or malformed JSON | Fail closed with `503` or `502`. | + +## Runtime JWT Claims + +`/api/v1/auth/runtime-token-exchange` is a management-style request. The configured management provider authorizes `runtime.token_exchange` for the requested target. Agent Control then mints its own HS256 JWT with `AGENT_CONTROL_RUNTIME_TOKEN_SECRET`. + +The token payload contains: + +```json +{ + "iss": "agent-control/server", + "domain": "runtime", + "namespace_key": "tenant-a", + "actor_id": "user-or-key-id", + "target_type": "session", + "target_id": "target-123", + "scopes": ["runtime.use"], + "iat": 1778509800, + "exp": 1778510100, + "jti": "opaque-token-id" +} +``` + +Verification requires the expected issuer, `domain="runtime"`, a valid signature, an unexpired `exp`, and `runtime.use` in `scopes`. The token is accepted only for requests whose `target_type` and `target_id` match the bound target. + +The expiry is the earlier of `AGENT_CONTROL_RUNTIME_TOKEN_TTL_SECONDS` and the upstream grant's `expires_at` when supplied. Runtime token TTLs are capped at 86400 seconds. diff --git a/models/src/agent_control_models/server.py b/models/src/agent_control_models/server.py index 9b890b91..3529a5d4 100644 --- a/models/src/agent_control_models/server.py +++ b/models/src/agent_control_models/server.py @@ -640,7 +640,7 @@ class CreateControlBindingRequest(BaseModel): target_type: ControlBindingTargetField = Field( ..., - description="Opaque attachment kind (caller-defined; e.g. 'env', 'log_stream').", + description="Opaque attachment kind (caller-defined; e.g. 'environment', 'session').", ) target_id: ControlBindingTargetField = Field( ..., description="Opaque external identifier within the target_type." @@ -760,4 +760,3 @@ class DeleteControlBindingByKeyResponse(BaseModel): ), ) - diff --git a/server/src/agent_control_server/endpoints/agents.py b/server/src/agent_control_server/endpoints/agents.py index ac099911..57ca1ebc 100644 --- a/server/src/agent_control_server/endpoints/agents.py +++ b/server/src/agent_control_server/endpoints/agents.py @@ -29,20 +29,21 @@ SetPolicyResponse, StepKey, ) -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Depends, Query, Request from jsonschema_rs import ValidationError as JSONSchemaValidationError from pydantic import BaseModel, ValidationError from sqlalchemy import delete, func, select from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession -from ..auth_framework import Operation, Principal, require_operation +from ..auth_framework import Operation, Principal, get_authorizer, require_operation from ..db import get_async_db from ..errors import ( APIValidationError, BadRequestError, ConflictError, DatabaseError, + ForbiddenError, NotFoundError, ) from ..logging_utils import get_logger @@ -85,6 +86,81 @@ type StepKeyTuple = tuple[str, str] +def _complete_target_context( + target_type: object | None, + target_id: object | None, +) -> dict[str, str] | None: + """Return target context only when both halves are present strings.""" + if not isinstance(target_type, str) or not isinstance(target_id, str): + return None + if not target_type or not target_id: + return None + return {"target_type": target_type, "target_id": target_id} + + +async def _init_agent_target_context(request: Request) -> dict[str, str] | None: + """Extract optional target context from an ``initAgent`` body.""" + try: + body = await request.json() + except Exception: # noqa: BLE001 malformed JSON, defer to endpoint validation + return None + if not isinstance(body, dict): + return None + return _complete_target_context(body.get("target_type"), body.get("target_id")) + + +def _agent_controls_target_context(request: Request) -> dict[str, str] | None: + """Extract optional target context from ``GET /agents/{name}/controls``.""" + return _complete_target_context( + request.query_params.get("target_type"), + request.query_params.get("target_id"), + ) + + +async def _authorize_target_read_if_present( + request: Request, + context: dict[str, str] | None, +) -> Principal | None: + """Require target read authorization before returning target-merged controls.""" + if context is None: + return None + return await get_authorizer(Operation.CONTROL_BINDINGS_READ).authorize( + request, + Operation.CONTROL_BINDINGS_READ, + context, + ) + + +async def _init_agent_target_principal(request: Request) -> Principal | None: + return await _authorize_target_read_if_present( + request, + await _init_agent_target_context(request), + ) + + +async def _agent_controls_target_principal(request: Request) -> Principal | None: + return await _authorize_target_read_if_present( + request, + _agent_controls_target_context(request), + ) + + +def _ensure_target_principal_matches_namespace( + principal: Principal, + target_principal: Principal | None, +) -> None: + """Fail closed if the target authorization resolves to a different namespace.""" + if target_principal is None: + return + if target_principal.namespace_key == principal.namespace_key: + return + raise ForbiddenError( + error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, + detail="Target authorization resolved to a different namespace.", + hint="Ensure the credential is scoped to the requested target and namespace.", + ) + + # ============================================================================= # List Agents Models # ============================================================================= @@ -445,6 +521,7 @@ async def init_agent( request: InitAgentRequest, db: AsyncSession = Depends(get_async_db), principal: Principal = Depends(require_operation(Operation.AGENTS_CREATE)), + target_principal: Principal | None = Depends(_init_agent_target_principal), ) -> InitAgentResponse: """ Register a new agent or update an existing agent's steps and metadata. @@ -474,6 +551,7 @@ async def init_agent( InitAgentResponse with created flag and the effective controls """ namespace_key = principal.namespace_key + _ensure_target_principal_matches_namespace(principal, target_principal) # Check for evaluator name collisions with built-in evaluators builtin_names = _get_builtin_evaluator_names() @@ -1493,6 +1571,7 @@ async def list_agent_controls( ), db: AsyncSession = Depends(get_async_db), principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), + target_principal: Principal | None = Depends(_agent_controls_target_principal), ) -> AgentControlsResponse: """ List protection controls effective for an agent. @@ -1527,6 +1606,7 @@ async def list_agent_controls( HTTPException 404: Agent not found """ namespace_key = principal.namespace_key + _ensure_target_principal_matches_namespace(principal, target_principal) if (target_type is None) != (target_id is None): raise BadRequestError( diff --git a/server/src/agent_control_server/endpoints/auth.py b/server/src/agent_control_server/endpoints/auth.py index b1ade969..7125b64d 100644 --- a/server/src/agent_control_server/endpoints/auth.py +++ b/server/src/agent_control_server/endpoints/auth.py @@ -28,8 +28,10 @@ mint_runtime_token, ) from ..errors import APIError, BadRequestError +from ..logging_utils import get_logger router = APIRouter(prefix="/auth", tags=["auth"]) +_logger = get_logger(__name__) class RuntimeTokenExchangeRequest(BaseModel): @@ -38,7 +40,7 @@ class RuntimeTokenExchangeRequest(BaseModel): model_config = ConfigDict(extra="forbid") target_type: str = Field( - ..., description="Opaque target kind (e.g., ``log_stream``).", min_length=1 + ..., description="Opaque target kind (e.g., ``session``).", min_length=1 ) target_id: str = Field(..., description="Opaque target identifier.", min_length=1) @@ -175,6 +177,19 @@ async def runtime_token_exchange( hint="Check the runtime token configuration.", ) from exc + _logger.info( + "Runtime token exchanged", + extra={ + "namespace_key": claims.namespace_key, + "actor_id": claims.actor_id, + "target_type": claims.target_type, + "target_id": claims.target_id, + "scopes": list(claims.scopes), + "expires_at": claims.expires_at.isoformat(), + "jti": claims.jti, + }, + ) + return RuntimeTokenExchangeResponse( token=token, expires_at=claims.expires_at, diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index b4fa8d0b..6e6441e9 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -195,12 +195,17 @@ async def _render_and_validate_template_input( template_input: TemplateControlInput, *, db: AsyncSession, + namespace_key: str, enabled: bool = True, ) -> ControlDefinition: """Render a template-backed input and validate evaluator config.""" rendered = render_template_control_input(template_input, enabled=enabled) try: - await _validate_control_definition(rendered.control, db) + await _validate_control_definition( + rendered.control, + db, + namespace_key=namespace_key, + ) except APIValidationError as exc: raise remap_template_api_error( exc, @@ -214,6 +219,7 @@ async def _materialize_control_input( control_input: ControlDefinition | TemplateControlInput, *, db: AsyncSession, + namespace_key: str, current_payload: object | None = None, control_id: int | None = None, ) -> ControlDefinition | UnrenderedTemplateControl: @@ -226,6 +232,7 @@ async def _materialize_control_input( return await _render_and_validate_template_input( control_input, db=db, + namespace_key=namespace_key, enabled=enabled, ) @@ -244,6 +251,7 @@ async def _materialize_control_input( return await _render_and_validate_template_input( control_input, db=db, + namespace_key=namespace_key, enabled=enabled, ) @@ -262,12 +270,19 @@ async def _materialize_control_input( raise RuntimeError("control_id is required for template-backed raw updates") raise _template_backed_raw_update_conflict(control_id) - await _validate_control_definition(control_input, db) + await _validate_control_definition( + control_input, + db, + namespace_key=namespace_key, + ) return control_input async def _validate_control_definition( - control_def: ControlDefinition, db: AsyncSession + control_def: ControlDefinition, + db: AsyncSession, + *, + namespace_key: str, ) -> None: """Validate evaluator config for definitions referencing known global evaluators. @@ -296,7 +311,10 @@ async def _validate_control_definition( agent_data = agent_data_by_name.get(agent_namespace) if agent_data is None: agent_result = await db.execute( - select(Agent).where(Agent.name == agent_namespace) + select(Agent).where( + Agent.name == agent_namespace, + Agent.namespace_key == namespace_key, + ) ) agent = agent_result.scalars().first() if agent is None: @@ -447,7 +465,7 @@ async def _validate_control_definition( async def render_control_template( request: RenderControlTemplateRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> RenderControlTemplateResponse: """Render a template-backed control without persisting it.""" control_def = await _render_and_validate_template_input( @@ -456,6 +474,7 @@ async def render_control_template( template_values=request.template_values, ), db=db, + namespace_key=principal.namespace_key, enabled=True, ) return RenderControlTemplateResponse(control=control_def) @@ -504,7 +523,11 @@ async def create_control( hint="Choose a different name or update the existing control.", ) - control_def = await _materialize_control_input(request.data, db=db) + control_def = await _materialize_control_input( + request.data, + db=db, + namespace_key=namespace_key, + ) control_data = _serialize_control_data(control_def) control = control_service.create_control( @@ -751,6 +774,7 @@ async def set_control_data( control_def = await _materialize_control_input( request.data, db=db, + namespace_key=principal.namespace_key, current_payload=control.data, control_id=control_id, ) @@ -791,7 +815,7 @@ async def set_control_data( async def validate_control_data( request: ValidateControlDataRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> ValidateControlDataResponse: """ Validate control configuration data without saving it. @@ -805,7 +829,11 @@ async def validate_control_data( """ # Validate mirrors create: complete template values trigger a full render, # incomplete values validate structure only (matching unrendered create). - await _materialize_control_input(request.data, db=db) + await _materialize_control_input( + request.data, + db=db, + namespace_key=principal.namespace_key, + ) return ValidateControlDataResponse(success=True) diff --git a/server/tests/test_principal_namespace_flow.py b/server/tests/test_principal_namespace_flow.py index 40ecd216..14d2d874 100644 --- a/server/tests/test_principal_namespace_flow.py +++ b/server/tests/test_principal_namespace_flow.py @@ -3,16 +3,16 @@ from __future__ import annotations import uuid +from copy import deepcopy from typing import Any -from fastapi import FastAPI, Request -from fastapi.testclient import TestClient - from agent_control_server.auth_framework import ( Operation, Principal, set_authorizer, ) +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient from .utils import VALID_CONTROL_PAYLOAD @@ -139,3 +139,30 @@ def test_duplicate_control_names_allowed_across_principal_namespaces(app: FastAP assert ns_a.put("/api/v1/controls", json=payload).status_code == 200 assert ns_b.put("/api/v1/controls", json=payload).status_code == 200 + + +def test_agent_scoped_evaluator_validation_uses_principal_namespace(app: FastAPI) -> None: + set_authorizer(HeaderNamespaceAuthorizer()) + + ns_a = _client(app, "ns-a") + ns_b = _client(app, "ns-b") + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + + register_b = ns_b.post( + "/api/v1/agents/initAgent", + json={ + **_agent_payload(agent_name), + "evaluators": [{"name": "custom", "config_schema": {"type": "object"}}], + }, + ) + assert register_b.status_code == 200, register_b.text + + control_data = deepcopy(VALID_CONTROL_PAYLOAD) + control_data["condition"]["evaluator"] = { + "name": f"{agent_name}:custom", + "config": {}, + } + + resp = ns_a.post("/api/v1/controls/validate", json={"data": control_data}) + assert resp.status_code == 404, resp.text + assert resp.json()["detail"] == f"Agent '{agent_name}' not found" diff --git a/server/tests/test_target_merged_contract.py b/server/tests/test_target_merged_contract.py index 62891ba5..6bc4ab0f 100644 --- a/server/tests/test_target_merged_contract.py +++ b/server/tests/test_target_merged_contract.py @@ -18,11 +18,37 @@ from copy import deepcopy from typing import Any +import pytest +from agent_control_server.auth_framework import Operation, Principal, set_authorizer +from fastapi import Request from fastapi.testclient import TestClient from .utils import VALID_CONTROL_PAYLOAD, canonicalize_control_payload +class RecordingAuthorizer: + """Authorizer that records operation/context pairs for endpoint contract tests.""" + + def __init__(self, *, target_namespace_key: str = "default") -> None: + self.calls: list[tuple[Operation, dict[str, Any] | None]] = [] + self.target_namespace_key = target_namespace_key + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del request + self.calls.append((operation, context)) + namespace_key = ( + self.target_namespace_key + if operation is Operation.CONTROL_BINDINGS_READ and context is not None + else "default" + ) + return Principal(namespace_key=namespace_key, is_admin=True) + + def _agent_payload( agent_name: str, *, @@ -115,7 +141,7 @@ def _list_effective_via_get( # --------------------------------------------------------------------------- -def test_initAgent_with_target_merges_direct_and_target_controls( +def test_init_agent_with_target_merges_direct_and_target_controls( client: TestClient, ) -> None: agent_name = f"agent-{uuid.uuid4().hex[:12]}" @@ -134,7 +160,7 @@ def test_initAgent_with_target_merges_direct_and_target_controls( assert returned_ids == {direct_id, target_id_ctrl} -def test_initAgent_newly_created_with_target_picks_up_pre_existing_bindings( +def test_init_agent_newly_created_with_target_picks_up_pre_existing_bindings( client: TestClient, ) -> None: """Bindings can pre-exist the agent row. @@ -154,7 +180,7 @@ def test_initAgent_newly_created_with_target_picks_up_pre_existing_bindings( assert returned_ids == [pre_existing] -def test_initAgent_partial_target_pair_rejected(client: TestClient) -> None: +def test_init_agent_partial_target_pair_rejected(client: TestClient) -> None: agent_name = f"agent-{uuid.uuid4().hex[:12]}" payload = _agent_payload(agent_name) payload["target_type"] = "env" # target_id omitted @@ -162,12 +188,28 @@ def test_initAgent_partial_target_pair_rejected(client: TestClient) -> None: assert resp.status_code == 422 +def test_init_agent_with_target_requires_target_read_authorization( + client: TestClient, +) -> None: + authorizer = RecordingAuthorizer() + set_authorizer(authorizer) + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + + body = _register_agent(client, agent_name, target_type="env", target_id="prod") + + assert body["created"] is True + assert ( + Operation.CONTROL_BINDINGS_READ, + {"target_type": "env", "target_id": "prod"}, + ) in authorizer.calls + + # --------------------------------------------------------------------------- # GET /agents/{name}/controls contract. # --------------------------------------------------------------------------- -def test_get_agent_controls_with_target_matches_initAgent_response( +def test_get_agent_controls_with_target_matches_init_agent_response( client: TestClient, ) -> None: agent_name = f"agent-{uuid.uuid4().hex[:12]}" @@ -200,6 +242,45 @@ def test_get_agent_controls_partial_target_pair_returns_400( assert resp.status_code == 400 +def test_get_agent_controls_with_target_requires_target_read_authorization( + client: TestClient, +) -> None: + authorizer = RecordingAuthorizer() + set_authorizer(authorizer) + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + _register_agent(client, agent_name) + authorizer.calls.clear() + + ids = _list_effective_via_get( + client, + agent_name, + target_type="env", + target_id="prod", + ) + + assert ids == [] + assert (Operation.AGENTS_READ, None) in authorizer.calls + assert ( + Operation.CONTROL_BINDINGS_READ, + {"target_type": "env", "target_id": "prod"}, + ) in authorizer.calls + + +def test_get_agent_controls_rejects_target_namespace_mismatch( + client: TestClient, +) -> None: + set_authorizer(RecordingAuthorizer(target_namespace_key="other-ns")) + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + _register_agent(client, agent_name) + + resp = client.get( + f"/api/v1/agents/{agent_name}/controls", + params={"target_type": "env", "target_id": "prod"}, + ) + + assert resp.status_code == 403, resp.text + + def test_get_agent_controls_no_target_omits_target_bindings( client: TestClient, ) -> None: @@ -243,11 +324,10 @@ async def _insert_agent_in_namespace(async_db, *, name: str, namespace_key: str) await async_db.commit() -import pytest # noqa: E402 (kept local; the rest of the file is sync) - - @pytest.mark.asyncio -async def test_get_agent_controls_cross_namespace_returns_404(client: TestClient, async_db) -> None: +async def test_get_agent_controls_cross_namespace_returns_404( + client: TestClient, async_db +) -> None: """Agent existing only in another namespace must not surface here. The merged-resolver contract is namespace-scoped end-to-end; if the From 3e7166118857564a8d9250252c37d981ca53bd03 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 11 May 2026 15:20:43 +0530 Subject: [PATCH 20/42] docs(server): clarify upstream auth failure mapping --- docs/auth.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/auth.md b/docs/auth.md index 5002faa8..7aafd2ad 100644 --- a/docs/auth.md +++ b/docs/auth.md @@ -120,7 +120,8 @@ Status handling: | `403` | Forbidden error. | | `404` | Not found error. | | `429` | `503` with a rate-limit detail and `Retry-After` hint when present. | -| Other statuses or malformed JSON | Fail closed with `503` or `502`. | +| Other statuses or upstream network errors | Fail closed with `503`. | +| Malformed `200` principal response | Fail closed with `502`. | ## Runtime JWT Claims From 7f696ef0fc8da77889777b4b8a419d7fe4e4aa74 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 11 May 2026 15:50:43 +0530 Subject: [PATCH 21/42] docs(server): explain target principal authorization --- .../agent_control_server/endpoints/agents.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/server/src/agent_control_server/endpoints/agents.py b/server/src/agent_control_server/endpoints/agents.py index 57ca1ebc..1b380026 100644 --- a/server/src/agent_control_server/endpoints/agents.py +++ b/server/src/agent_control_server/endpoints/agents.py @@ -121,7 +121,20 @@ async def _authorize_target_read_if_present( request: Request, context: dict[str, str] | None, ) -> Principal | None: - """Require target read authorization before returning target-merged controls.""" + """Require target read authorization before returning target-merged controls. + + Agent endpoints that accept optional target context have two separate + authorization decisions: + + - the endpoint operation itself (for example, ``agents.create``), whose + result is exposed to the route as ``principal``; + - the target binding read (``control_bindings.read``), whose result is + exposed as ``target_principal``. + + Keeping the results separate lets the route verify that the caller's + namespace and the target's resolved namespace agree before merging + target-bound controls into the response. + """ if context is None: return None return await get_authorizer(Operation.CONTROL_BINDINGS_READ).authorize( @@ -545,7 +558,8 @@ async def init_agent( Args: request: Agent metadata and step schemas db: Database session (injected) - principal: Authorized request principal + principal: Authorized request principal for the agent create operation + target_principal: Optional principal from the target binding read check Returns: InitAgentResponse with created flag and the effective controls @@ -1596,7 +1610,8 @@ async def list_agent_controls( target_type: Optional opaque target kind (paired with target_id) target_id: Optional opaque target identifier (paired with target_type) db: Database session (injected) - principal: Authorized request principal + principal: Authorized request principal for the agent read operation + target_principal: Optional principal from the target binding read check Returns: AgentControlsResponse with controls matching the requested state filters From 69aaa497fe6798de0ce65444f3f1c6df0e34e06d Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 11 May 2026 15:59:55 +0530 Subject: [PATCH 22/42] chore(sdk-ts): refresh generated client docs --- sdks/typescript/src/generated/funcs/agents-init.ts | 3 ++- sdks/typescript/src/generated/funcs/agents-list-controls.ts | 3 ++- .../src/generated/models/create-control-binding-request.ts | 2 +- .../src/generated/models/runtime-token-exchange-request.ts | 2 +- sdks/typescript/src/generated/sdk/agents.ts | 6 ++++-- 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/sdks/typescript/src/generated/funcs/agents-init.ts b/sdks/typescript/src/generated/funcs/agents-init.ts index 7150b2a4..d1136c2f 100644 --- a/sdks/typescript/src/generated/funcs/agents-init.ts +++ b/sdks/typescript/src/generated/funcs/agents-init.ts @@ -51,7 +51,8 @@ import { Result } from "../types/fp.js"; * Args: * request: Agent metadata and step schemas * db: Database session (injected) - * principal: Authorized request principal + * principal: Authorized request principal for the agent create operation + * target_principal: Optional principal from the target binding read check * * Returns: * InitAgentResponse with created flag and the effective controls diff --git a/sdks/typescript/src/generated/funcs/agents-list-controls.ts b/sdks/typescript/src/generated/funcs/agents-list-controls.ts index d1e5b27d..619a45d6 100644 --- a/sdks/typescript/src/generated/funcs/agents-list-controls.ts +++ b/sdks/typescript/src/generated/funcs/agents-list-controls.ts @@ -53,7 +53,8 @@ import { Result } from "../types/fp.js"; * target_type: Optional opaque target kind (paired with target_id) * target_id: Optional opaque target identifier (paired with target_type) * db: Database session (injected) - * principal: Authorized request principal + * principal: Authorized request principal for the agent read operation + * target_principal: Optional principal from the target binding read check * * Returns: * AgentControlsResponse with controls matching the requested state filters diff --git a/sdks/typescript/src/generated/models/create-control-binding-request.ts b/sdks/typescript/src/generated/models/create-control-binding-request.ts index ace9f49b..f4e0c940 100644 --- a/sdks/typescript/src/generated/models/create-control-binding-request.ts +++ b/sdks/typescript/src/generated/models/create-control-binding-request.ts @@ -22,7 +22,7 @@ export type CreateControlBindingRequest = { */ targetId: string; /** - * Opaque attachment kind (caller-defined; e.g. 'env', 'log_stream'). + * Opaque attachment kind (caller-defined; e.g. 'environment', 'session'). */ targetType: string; }; diff --git a/sdks/typescript/src/generated/models/runtime-token-exchange-request.ts b/sdks/typescript/src/generated/models/runtime-token-exchange-request.ts index 65e02bda..e20ed22e 100644 --- a/sdks/typescript/src/generated/models/runtime-token-exchange-request.ts +++ b/sdks/typescript/src/generated/models/runtime-token-exchange-request.ts @@ -14,7 +14,7 @@ export type RuntimeTokenExchangeRequest = { */ targetId: string; /** - * Opaque target kind (e.g., ``log_stream``). + * Opaque target kind (e.g., ``session``). */ targetType: string; }; diff --git a/sdks/typescript/src/generated/sdk/agents.ts b/sdks/typescript/src/generated/sdk/agents.ts index 0a70e128..bed5b41f 100644 --- a/sdks/typescript/src/generated/sdk/agents.ts +++ b/sdks/typescript/src/generated/sdk/agents.ts @@ -80,7 +80,8 @@ export class Agents extends ClientSDK { * Args: * request: Agent metadata and step schemas * db: Database session (injected) - * principal: Authorized request principal + * principal: Authorized request principal for the agent create operation + * target_principal: Optional principal from the target binding read check * * Returns: * InitAgentResponse with created flag and the effective controls @@ -186,7 +187,8 @@ export class Agents extends ClientSDK { * target_type: Optional opaque target kind (paired with target_id) * target_id: Optional opaque target identifier (paired with target_type) * db: Database session (injected) - * principal: Authorized request principal + * principal: Authorized request principal for the agent read operation + * target_principal: Optional principal from the target binding read check * * Returns: * AgentControlsResponse with controls matching the requested state filters From 7fe90bf2e220b085d52bad6d73c688af703a695b Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Thu, 7 May 2026 20:33:08 +0530 Subject: [PATCH 23/42] feat(sdk-python): add runtime token auth Exchange target-bound runtime tokens for evaluation requests when configured, cache them per target, and retry once after a 401. Keep no-auth and API-key runtime flows on the existing request-auth path when token exchange is unavailable or disabled. --- sdks/python/src/agent_control/__init__.py | 5 +- sdks/python/src/agent_control/_state.py | 3 + sdks/python/src/agent_control/client.py | 212 +++++++- sdks/python/src/agent_control/evaluation.py | 72 ++- sdks/python/src/agent_control/runtime_auth.py | 194 +++++++ sdks/python/tests/test_client.py | 474 +++++++++++++++++- sdks/python/tests/test_local_evaluation.py | 89 ++++ sdks/python/tests/test_runtime_auth.py | 232 +++++++++ 8 files changed, 1260 insertions(+), 21 deletions(-) create mode 100644 sdks/python/src/agent_control/runtime_auth.py create mode 100644 sdks/python/tests/test_runtime_auth.py diff --git a/sdks/python/src/agent_control/__init__.py b/sdks/python/src/agent_control/__init__.py index 149694c2..c9561a0c 100644 --- a/sdks/python/src/agent_control/__init__.py +++ b/sdks/python/src/agent_control/__init__.py @@ -562,6 +562,7 @@ async def handle(message: str): state.current_agent = next_agent state.server_url = server_url or os.getenv('AGENT_CONTROL_URL') or 'http://localhost:8000' state.api_key = api_key + state.runtime_token_cache.clear() state.target_type = target_type state.target_id = target_id @@ -597,7 +598,8 @@ async def register() -> list[dict[str, Any]] | None: assert state.current_agent is not None async with AgentControlClient( - base_url=state.server_url, api_key=state.api_key + base_url=state.server_url, + api_key=state.api_key, ) as client: # Check server health first try: @@ -715,6 +717,7 @@ def _reset_state() -> None: state.server_controls = None state.server_url = None state.api_key = None + state.runtime_token_cache.clear() state.target_type = None state.target_id = None diff --git a/sdks/python/src/agent_control/_state.py b/sdks/python/src/agent_control/_state.py index 25974567..834c73b0 100644 --- a/sdks/python/src/agent_control/_state.py +++ b/sdks/python/src/agent_control/_state.py @@ -8,6 +8,8 @@ from typing import TYPE_CHECKING, Any +from .runtime_auth import RuntimeTokenCache + if TYPE_CHECKING: from agent_control_models import Agent @@ -24,6 +26,7 @@ def __init__(self) -> None: self.server_controls: list[dict[str, Any]] | None = None self.server_url: str | None = None self.api_key: str | None = None + self.runtime_token_cache = RuntimeTokenCache() # Optional target context fixed at init() time; both fields are set # together or both remain None. self.target_type: str | None = None diff --git a/sdks/python/src/agent_control/client.py b/sdks/python/src/agent_control/client.py index 41ce0425..bc5ef0f1 100644 --- a/sdks/python/src/agent_control/client.py +++ b/sdks/python/src/agent_control/client.py @@ -2,14 +2,43 @@ import logging import os +from collections.abc import Generator from types import TracebackType +from typing import Any, cast import httpx from . import __version__ as sdk_version +from .runtime_auth import ( + RuntimeAuthMode, + RuntimeTokenCache, + normalize_runtime_auth_mode, + parse_runtime_token_exchange_response, +) _logger = logging.getLogger(__name__) +_RUNTIME_AUTH_MODE_ENV_VAR = "AGENT_CONTROL_RUNTIME_AUTH_MODE" +_DEFAULT_RUNTIME_TOKEN_REFRESH_MARGIN_SECONDS = 30 +_AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES = {404, 503} +_GLOBAL_RUNTIME_TOKEN_FALLBACK_STATUSES = {404, 503} + + +class _AgentControlAuth(httpx.Auth): + """Attach local API-key credentials unless a request already has Bearer auth.""" + + def __init__(self, api_key: str | None) -> None: + self._api_key = api_key + + def auth_flow( + self, + request: httpx.Request, + ) -> Generator[httpx.Request, httpx.Response, None]: + if self._api_key and "Authorization" not in request.headers: + if "X-API-Key" not in request.headers: + request.headers["X-API-Key"] = self._api_key + yield request + class AgentControlClient: """ @@ -45,6 +74,10 @@ def __init__( base_url: str | None = None, timeout: float = 30.0, api_key: str | None = None, + runtime_auth_mode: RuntimeAuthMode | str | None = None, + runtime_token_cache: RuntimeTokenCache | None = None, + runtime_token_refresh_margin_seconds: int = (_DEFAULT_RUNTIME_TOKEN_REFRESH_MARGIN_SECONDS), + transport: httpx.AsyncBaseTransport | None = None, ): """ Initialize the client. @@ -55,6 +88,15 @@ def __init__( timeout: Request timeout in seconds api_key: API key for authentication. If not provided, will attempt to read from AGENT_CONTROL_API_KEY environment variable. + runtime_auth_mode: Runtime auth mode for evaluation requests. ``auto`` + attempts target-bound JWT exchange and falls back to normal + request auth when the exchange endpoint is unavailable. ``jwt`` + requires a successful exchange. ``api_key`` and ``none`` keep + evaluation requests on the normal request-auth path. + runtime_token_cache: Optional cache shared across client instances. + runtime_token_refresh_margin_seconds: Refresh cached runtime tokens + before this many seconds of validity remain. + transport: Optional httpx transport, primarily for tests. """ resolved_base_url = base_url or os.environ.get( self.BASE_URL_ENV_VAR, "http://localhost:8000" @@ -62,6 +104,13 @@ def __init__( self.base_url = resolved_base_url.rstrip("/") self.timeout = timeout self._api_key = api_key or os.environ.get(self.API_KEY_ENV_VAR) + configured_runtime_mode = runtime_auth_mode or os.environ.get(_RUNTIME_AUTH_MODE_ENV_VAR) + self._runtime_auth_mode = normalize_runtime_auth_mode(configured_runtime_mode) + if runtime_token_refresh_margin_seconds < 0: + raise ValueError("runtime_token_refresh_margin_seconds must be >= 0.") + self._runtime_token_refresh_margin_seconds = runtime_token_refresh_margin_seconds + self._runtime_token_cache = runtime_token_cache or RuntimeTokenCache() + self._transport = transport self._client: httpx.AsyncClient | None = None self._server_version_warning_emitted = False @@ -70,15 +119,17 @@ def api_key(self) -> str | None: """Get the configured API key (read-only).""" return self._api_key + @property + def runtime_auth_mode(self) -> RuntimeAuthMode: + """Get the configured runtime auth mode (read-only).""" + return self._runtime_auth_mode + def _get_headers(self) -> dict[str, str]: - """Build request headers including authentication.""" - headers: dict[str, str] = { + """Build base SDK metadata headers.""" + return { "X-Agent-Control-SDK": "python", "X-Agent-Control-SDK-Version": sdk_version, } - if self._api_key: - headers["X-API-Key"] = self._api_key - return headers async def _check_server_version(self, response: httpx.Response) -> None: """Warn once when the server major version differs from the SDK major.""" @@ -108,6 +159,8 @@ async def __aenter__(self) -> "AgentControlClient": base_url=self.base_url, timeout=self.timeout, headers=self._get_headers(), + auth=_AgentControlAuth(self._api_key), + transport=self._transport, event_hooks={"response": [self._check_server_version]}, ) return self @@ -137,6 +190,7 @@ async def health_check(self) -> dict[str, str]: response = await self._client.get("/health") response.raise_for_status() from typing import cast + return cast(dict[str, str], response.json()) @property @@ -145,3 +199,151 @@ def http_client(self) -> httpx.AsyncClient: if self._client is None: raise RuntimeError("Client not initialized. Use 'async with' context manager.") return self._client + + async def post_runtime_evaluation( + self, + *, + json: dict[str, Any], + headers: dict[str, str] | None = None, + target_type: str | None = None, + target_id: str | None = None, + ) -> httpx.Response: + """POST an evaluation request with runtime auth when configured.""" + runtime_authorization = await self._runtime_authorization( + target_type=target_type, + target_id=target_id, + ) + request_headers = self._merge_runtime_headers(headers, runtime_authorization) + response = await self.http_client.post( + "/api/v1/evaluation", + json=json, + headers=request_headers, + ) + + if response.status_code == 401 and runtime_authorization is not None: + await response.aread() + runtime_authorization = await self._runtime_authorization( + target_type=target_type, + target_id=target_id, + force_refresh=True, + allow_auto_fallback=False, + ) + request_headers = self._merge_runtime_headers(headers, runtime_authorization) + response = await self.http_client.post( + "/api/v1/evaluation", + json=json, + headers=request_headers, + ) + + return response + + def _merge_runtime_headers( + self, + headers: dict[str, str] | None, + runtime_authorization: str | None, + ) -> dict[str, str] | None: + """Merge caller headers with an optional Bearer token.""" + if headers is None and runtime_authorization is None: + return None + + merged = dict(headers or {}) + if runtime_authorization is not None: + merged["Authorization"] = runtime_authorization + return merged + + async def _runtime_authorization( + self, + *, + target_type: str | None, + target_id: str | None, + force_refresh: bool = False, + allow_auto_fallback: bool = True, + ) -> str | None: + """Return an Authorization header value for runtime evaluation.""" + if self._runtime_auth_mode in {"none", "api_key"}: + return None + + if target_type is None or target_id is None: + if self._runtime_auth_mode == "jwt": + raise RuntimeError( + "runtime_auth_mode='jwt' requires target_type and target_id " + "for evaluation requests." + ) + return None + + if self._runtime_auth_mode == "auto" and self._runtime_token_cache.is_jwt_unavailable( + self.base_url, target_type, target_id + ): + return None + + if not force_refresh: + cached = self._runtime_token_cache.get( + self.base_url, + target_type, + target_id, + refresh_margin_seconds=self._runtime_token_refresh_margin_seconds, + ) + if cached is not None: + return f"Bearer {cached.token}" + + exchange_lock = self._runtime_token_cache.exchange_lock( + self.base_url, + target_type, + target_id, + ) + async with exchange_lock: + if not force_refresh: + cached = self._runtime_token_cache.get( + self.base_url, + target_type, + target_id, + refresh_margin_seconds=self._runtime_token_refresh_margin_seconds, + ) + if cached is not None: + return f"Bearer {cached.token}" + + token = await self._exchange_runtime_token( + target_type=target_type, + target_id=target_id, + allow_auto_fallback=allow_auto_fallback, + ) + if token is None: + return None + return f"Bearer {token}" + + async def _exchange_runtime_token( + self, + *, + target_type: str, + target_id: str, + allow_auto_fallback: bool = True, + ) -> str | None: + """Exchange the configured credential for a target-bound runtime token.""" + response = await self.http_client.post( + "/api/v1/auth/runtime-token-exchange", + json={"target_type": target_type, "target_id": target_id}, + ) + + if ( + self._runtime_auth_mode == "auto" + and allow_auto_fallback + and response.status_code in _AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES + ): + self._runtime_token_cache.mark_jwt_unavailable( + server_url=self.base_url, + target_type=target_type, + target_id=target_id, + globally=response.status_code in _GLOBAL_RUNTIME_TOKEN_FALLBACK_STATUSES, + ) + return None + + response.raise_for_status() + payload = response.json() + if not isinstance(payload, dict): + raise RuntimeError("Runtime token exchange response was not an object.") + token = parse_runtime_token_exchange_response( + cast(dict[str, object], payload), + server_url=self.base_url, + ) + self._runtime_token_cache.set(token) + return token.token diff --git a/sdks/python/src/agent_control/evaluation.py b/sdks/python/src/agent_control/evaluation.py index f1c7da97..5324c322 100644 --- a/sdks/python/src/agent_control/evaluation.py +++ b/sdks/python/src/agent_control/evaluation.py @@ -1,8 +1,11 @@ """Evaluation check operations for Agent Control SDK.""" +from collections.abc import Awaitable, Callable from dataclasses import dataclass +from inspect import iscoroutinefunction from typing import Any, Literal, cast +import httpx from agent_control_engine import list_evaluators from agent_control_engine.core import ControlEngine from agent_control_models import ( @@ -22,6 +25,8 @@ from .tracing import get_trace_and_span_ids from .validation import ensure_agent_name +_RuntimePostEvaluation = Callable[..., Awaitable[httpx.Response]] + @dataclass class _ControlAdapter: @@ -43,12 +48,12 @@ def _resolve_session_target( ) -> tuple[str | None, str | None]: """Default per-call target from state, and reject mismatches. - The SDK supports one target per session, fixed at ``init()`` time — + The SDK supports one target per session, fixed at ``init()`` time - including no-target sessions, where the session target is ``(None, None)``. The cached controls (``state.server_controls``) are fetched for that session target. A per-call override that disagrees - with the session target — including supplying an explicit target on a - no-target session — would evaluate against the wrong cache and could + with the session target - including supplying an explicit target on a + no-target session - would evaluate against the wrong cache and could return safe without contacting the server. Reject the mismatch so callers re-init when they need to change targets. @@ -118,7 +123,7 @@ def _has_applicable_prefiltered_server_controls( parsed_server_controls: list[_ControlAdapter] = [] for control in server_control_payloads: - # Skip unrendered template controls — they have no condition to evaluate + # Skip unrendered template controls - they have no condition to evaluate # and should not trigger the server-call fallback. ctrl_data = control.get("control", {}) if ( @@ -206,6 +211,41 @@ def _cached_server_control_lookup( return _build_server_control_lookup(state.server_controls) +def _runtime_post_evaluation(client: Any) -> _RuntimePostEvaluation | None: + """Return a runtime-evaluation callable when the client exposes one.""" + runtime_post = getattr(client, "post_runtime_evaluation", None) + if not callable(runtime_post) or not iscoroutinefunction(runtime_post): + return None + return cast(_RuntimePostEvaluation, runtime_post) + + +async def _post_evaluation_request( + client: AgentControlClient, + *, + request_payload: dict[str, Any], + headers: dict[str, str] | None, + target_type: str | None, + target_id: str | None, +) -> httpx.Response: + """Send an evaluation request, using runtime auth when the client supports it.""" + runtime_post = None + if target_type is not None and target_id is not None: + runtime_post = _runtime_post_evaluation(client) + if runtime_post is not None: + return await runtime_post( + json=request_payload, + headers=headers, + target_type=target_type, + target_id=target_id, + ) + + return await client.http_client.post( + "/api/v1/evaluation", + json=request_payload, + headers=headers, + ) + + async def check_evaluation( client: AgentControlClient, agent_name: str, @@ -241,10 +281,12 @@ async def check_evaluation( ) request_payload = request.model_dump(mode="json") - response = await client.http_client.post( - "/api/v1/evaluation", - json=request_payload, + response = await _post_evaluation_request( + client, + request_payload=request_payload, headers=None, + target_type=target_type, + target_id=target_id, ) response.raise_for_status() @@ -311,7 +353,7 @@ async def check_evaluation_with_local( for control in controls: control_data = control.get("control", {}) - # Skip unrendered template controls — they cannot be evaluated. + # Skip unrendered template controls - they cannot be evaluated. if ( isinstance(control_data, dict) and control_data.get("template") is not None @@ -424,10 +466,12 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: headers["X-Span-Id"] = resolved_span_id try: - response = await client.http_client.post( - "/api/v1/evaluation", - json=request_payload, + response = await _post_evaluation_request( + client, + request_payload=request_payload, headers=headers, + target_type=target_type, + target_id=target_id, ) response.raise_for_status() server_result = EvaluationResponse.model_validate(response.json()) @@ -510,7 +554,11 @@ async def evaluate_controls( step_obj = Step(**step_dict) # type: ignore[arg-type] resolved_controls = state.server_controls or [] - async with AgentControlClient(base_url=state.server_url, api_key=state.api_key) as client: + async with AgentControlClient( + base_url=state.server_url, + api_key=state.api_key, + runtime_token_cache=state.runtime_token_cache, + ) as client: return await check_evaluation_with_local( client=client, agent_name=agent_name, diff --git a/sdks/python/src/agent_control/runtime_auth.py b/sdks/python/src/agent_control/runtime_auth.py new file mode 100644 index 00000000..3b3643e5 --- /dev/null +++ b/sdks/python/src/agent_control/runtime_auth.py @@ -0,0 +1,194 @@ +"""Runtime-token cache helpers for the Agent Control SDK.""" + +from __future__ import annotations + +import asyncio +import threading +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta +from typing import Literal + +RuntimeAuthMode = Literal["auto", "none", "api_key", "jwt"] + +_TokenKey = tuple[str, str, str] +_DEFAULT_MAX_CACHE_ENTRIES = 256 + + +@dataclass(frozen=True) +class RuntimeToken: + """Short-lived runtime token bound to one target.""" + + token: str + expires_at: datetime + server_url: str + target_type: str + target_id: str + scopes: tuple[str, ...] + + def is_fresh(self, *, refresh_margin_seconds: int) -> bool: + """Return whether the token is usable beyond the refresh margin.""" + refresh_at = datetime.now(UTC) + timedelta(seconds=refresh_margin_seconds) + return self.expires_at > refresh_at + + +class RuntimeTokenCache: + """Thread-safe runtime token cache keyed by server and target.""" + + def __init__(self, *, max_entries: int = _DEFAULT_MAX_CACHE_ENTRIES) -> None: + if max_entries < 1: + raise ValueError("max_entries must be >= 1.") + self._max_entries = max_entries + self._tokens: dict[_TokenKey, RuntimeToken] = {} + self._jwt_unavailable = False + self._jwt_unavailable_targets: set[_TokenKey] = set() + self._exchange_locks: dict[_TokenKey, asyncio.Lock] = {} + self._lock = threading.Lock() + + def get( + self, + server_url: str, + target_type: str, + target_id: str, + *, + refresh_margin_seconds: int, + ) -> RuntimeToken | None: + """Return a fresh cached token for the target, if present.""" + key = (server_url, target_type, target_id) + with self._lock: + token = self._tokens.get(key) + if token is None: + return None + if token.is_fresh(refresh_margin_seconds=refresh_margin_seconds): + return token + self._tokens.pop(key, None) + return None + + def set(self, token: RuntimeToken) -> None: + """Store a token and clear any fallback marker for its target.""" + key = (token.server_url, token.target_type, token.target_id) + with self._lock: + if key not in self._tokens and len(self._tokens) >= self._max_entries: + oldest_key = next(iter(self._tokens)) + self._tokens.pop(oldest_key, None) + self._jwt_unavailable_targets.discard(oldest_key) + self._exchange_locks.pop(oldest_key, None) + self._tokens[key] = token + self._jwt_unavailable_targets.discard(key) + + def remove(self, server_url: str, target_type: str, target_id: str) -> None: + """Drop the cached token for one target.""" + with self._lock: + self._tokens.pop((server_url, target_type, target_id), None) + + def mark_jwt_unavailable( + self, + *, + server_url: str | None = None, + target_type: str | None = None, + target_id: str | None = None, + globally: bool = False, + ) -> None: + """Record that JWT runtime auth should not be attempted.""" + with self._lock: + if globally: + self._jwt_unavailable = True + self._tokens.clear() + return + if server_url is not None and target_type is not None and target_id is not None: + key = (server_url, target_type, target_id) + if ( + key not in self._jwt_unavailable_targets + and len(self._jwt_unavailable_targets) >= self._max_entries + ): + evicted_key = self._jwt_unavailable_targets.pop() + self._exchange_locks.pop(evicted_key, None) + self._jwt_unavailable_targets.add(key) + self._tokens.pop(key, None) + + def is_jwt_unavailable(self, server_url: str, target_type: str, target_id: str) -> bool: + """Return whether JWT exchange is known unavailable for the target.""" + key = (server_url, target_type, target_id) + with self._lock: + return self._jwt_unavailable or key in self._jwt_unavailable_targets + + def clear(self) -> None: + """Clear every cached token and fallback marker.""" + with self._lock: + self._tokens.clear() + self._jwt_unavailable = False + self._jwt_unavailable_targets.clear() + self._exchange_locks.clear() + + def exchange_lock(self, server_url: str, target_type: str, target_id: str) -> asyncio.Lock: + """Return the async exchange lock for one server and target.""" + key = (server_url, target_type, target_id) + with self._lock: + lock = self._exchange_locks.get(key) + if lock is None: + lock = asyncio.Lock() + self._exchange_locks[key] = lock + return lock + + +def normalize_runtime_auth_mode(raw: str | None) -> RuntimeAuthMode: + """Normalize configured SDK runtime auth mode.""" + if raw is None or not raw.strip(): + return "auto" + + mode = raw.strip().lower() + if mode in {"none", "no_auth"}: + return "none" + if mode in {"api_key", "header"}: + return "api_key" + if mode == "auto": + return "auto" + if mode == "jwt": + return "jwt" + raise ValueError("runtime_auth_mode must be one of 'auto', 'none', 'api_key', or 'jwt'.") + + +def parse_runtime_token_exchange_response( + payload: Mapping[str, object], + *, + server_url: str, +) -> RuntimeToken: + """Parse the runtime token exchange response payload.""" + token = payload.get("token") + expires_at = payload.get("expires_at") + target_type = payload.get("target_type") + target_id = payload.get("target_id") + scopes = payload.get("scopes") + + if not isinstance(token, str) or not token: + raise RuntimeError("Runtime token exchange response did not include a token.") + if not isinstance(expires_at, str) or not expires_at: + raise RuntimeError("Runtime token exchange response did not include expires_at.") + if not isinstance(target_type, str) or not target_type: + raise RuntimeError("Runtime token exchange response did not include target_type.") + if not isinstance(target_id, str) or not target_id: + raise RuntimeError("Runtime token exchange response did not include target_id.") + if not isinstance(scopes, Sequence) or isinstance(scopes, str): + raise RuntimeError("Runtime token exchange response did not include scopes.") + + parsed_scopes: list[str] = [] + for scope in scopes: + if not isinstance(scope, str): + raise RuntimeError("Runtime token exchange response included a non-string scope.") + parsed_scopes.append(scope) + + normalized_expires_at = expires_at + if normalized_expires_at.endswith("Z"): + normalized_expires_at = f"{normalized_expires_at[:-1]}+00:00" + parsed_expires_at = datetime.fromisoformat(normalized_expires_at) + if parsed_expires_at.tzinfo is None: + parsed_expires_at = parsed_expires_at.replace(tzinfo=UTC) + + return RuntimeToken( + token=token, + expires_at=parsed_expires_at.astimezone(UTC), + server_url=server_url, + target_type=target_type, + target_id=target_id, + scopes=tuple(parsed_scopes), + ) diff --git a/sdks/python/tests/test_client.py b/sdks/python/tests/test_client.py index aff6796e..11a8ed91 100644 --- a/sdks/python/tests/test_client.py +++ b/sdks/python/tests/test_client.py @@ -2,12 +2,15 @@ from __future__ import annotations +import asyncio +from datetime import UTC, datetime, timedelta from unittest.mock import patch import httpx import pytest from agent_control.client import AgentControlClient, sdk_version +from agent_control.runtime_auth import RuntimeTokenCache def test_client_uses_agent_control_url_env_var( @@ -36,17 +39,482 @@ def test_explicit_base_url_overrides_env_var( assert client.base_url == "http://explicit.test:8000" -def test_get_headers_include_sdk_metadata_and_api_key() -> None: +def test_get_headers_include_sdk_metadata() -> None: # Given: a client configured with an API key client = AgentControlClient(api_key="test-key") # When: building request headers headers = client._get_headers() - # Then: SDK metadata and authentication headers are included + # Then: SDK metadata headers are included assert headers["X-Agent-Control-SDK"] == "python" assert headers["X-Agent-Control-SDK-Version"] == sdk_version - assert headers["X-API-Key"] == "test-key" + assert "X-API-Key" not in headers + + +def test_client_rejects_negative_runtime_token_refresh_margin() -> None: + with pytest.raises(ValueError, match="runtime_token_refresh_margin_seconds"): + AgentControlClient(runtime_token_refresh_margin_seconds=-1) + + +@pytest.mark.asyncio +async def test_client_adds_api_key_auth_to_regular_requests() -> None: + seen_requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + seen_requests.append(request) + return httpx.Response(200, json={"ok": True}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + transport=transport, + ) as client: + response = await client.http_client.get("/api/v1/agents") + + assert response.status_code == 200 + assert seen_requests[0].headers["X-API-Key"] == "test-key" + + +@pytest.mark.asyncio +async def test_runtime_evaluation_exchanges_and_caches_bearer_token() -> None: + exchange_calls = 0 + evaluation_authorization_headers: list[str | None] = [] + evaluation_api_key_headers: list[str | None] = [] + expires_at = (datetime.now(UTC) + timedelta(minutes=5)).isoformat() + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal exchange_calls + if request.url.path == "/api/v1/auth/runtime-token-exchange": + exchange_calls += 1 + assert request.headers["X-API-Key"] == "test-key" + return httpx.Response( + 200, + json={ + "token": "runtime-token", + "expires_at": expires_at, + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + ) + + evaluation_authorization_headers.append(request.headers.get("Authorization")) + evaluation_api_key_headers.append(request.headers.get("X-API-Key")) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="auto", + transport=transport, + ) as client: + for _ in range(2): + response = await client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) + assert response.status_code == 200 + + assert exchange_calls == 1 + assert evaluation_authorization_headers == ["Bearer runtime-token", "Bearer runtime-token"] + assert evaluation_api_key_headers == [None, None] + + +@pytest.mark.asyncio +async def test_runtime_evaluation_single_flights_cold_cache_exchange() -> None: + exchange_calls = 0 + evaluation_authorization_headers: list[str | None] = [] + expires_at = (datetime.now(UTC) + timedelta(minutes=5)).isoformat() + + async def handler(request: httpx.Request) -> httpx.Response: + nonlocal exchange_calls + if request.url.path == "/api/v1/auth/runtime-token-exchange": + exchange_calls += 1 + await asyncio.sleep(0.01) + return httpx.Response( + 200, + json={ + "token": "runtime-token", + "expires_at": expires_at, + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + ) + + evaluation_authorization_headers.append(request.headers.get("Authorization")) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="jwt", + transport=transport, + ) as client: + responses = await asyncio.gather( + *( + client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) + for _ in range(5) + ) + ) + + assert [response.status_code for response in responses] == [200, 200, 200, 200, 200] + assert exchange_calls == 1 + assert evaluation_authorization_headers == ["Bearer runtime-token"] * 5 + + +@pytest.mark.asyncio +async def test_runtime_evaluation_refreshes_token_before_expiry() -> None: + exchange_tokens = ["short-token", "fresh-token"] + exchange_expiries = [ + (datetime.now(UTC) + timedelta(seconds=5)).isoformat(), + (datetime.now(UTC) + timedelta(minutes=5)).isoformat(), + ] + evaluation_authorization_headers: list[str | None] = [] + + def handler(request: httpx.Request) -> httpx.Response: + if request.url.path == "/api/v1/auth/runtime-token-exchange": + return httpx.Response( + 200, + json={ + "token": exchange_tokens.pop(0), + "expires_at": exchange_expiries.pop(0), + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + ) + + evaluation_authorization_headers.append(request.headers.get("Authorization")) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="jwt", + runtime_token_refresh_margin_seconds=30, + transport=transport, + ) as client: + for _ in range(2): + response = await client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) + assert response.status_code == 200 + + assert exchange_tokens == [] + assert evaluation_authorization_headers == [ + "Bearer short-token", + "Bearer fresh-token", + ] + + +@pytest.mark.asyncio +async def test_runtime_token_cache_is_scoped_to_server_url() -> None: + exchange_paths: list[str] = [] + authorization_headers: list[str | None] = [] + expires_at = (datetime.now(UTC) + timedelta(minutes=5)).isoformat() + cache = RuntimeTokenCache() + + def handler(request: httpx.Request) -> httpx.Response: + server_url = f"{request.url.scheme}://{request.url.host}" + if request.url.path == "/api/v1/auth/runtime-token-exchange": + exchange_paths.append(server_url) + return httpx.Response( + 200, + json={ + "token": f"{request.url.host}-token", + "expires_at": expires_at, + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + ) + + authorization_headers.append(request.headers.get("Authorization")) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + for base_url in ("https://server-a.test", "https://server-b.test"): + async with AgentControlClient( + base_url=base_url, + api_key="test-key", + runtime_auth_mode="jwt", + runtime_token_cache=cache, + transport=transport, + ) as client: + response = await client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) + assert response.status_code == 200 + + assert exchange_paths == ["https://server-a.test", "https://server-b.test"] + assert authorization_headers == [ + "Bearer server-a.test-token", + "Bearer server-b.test-token", + ] + + +@pytest.mark.asyncio +async def test_runtime_evaluation_auto_falls_back_to_api_key_when_exchange_unavailable() -> None: + exchange_calls = 0 + evaluation_api_key_headers: list[str | None] = [] + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal exchange_calls + if request.url.path == "/api/v1/auth/runtime-token-exchange": + exchange_calls += 1 + return httpx.Response(503, json={"detail": "runtime auth disabled"}) + + evaluation_api_key_headers.append(request.headers.get("X-API-Key")) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="auto", + transport=transport, + ) as client: + for _ in range(2): + response = await client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) + assert response.status_code == 200 + + assert exchange_calls == 1 + assert evaluation_api_key_headers == ["test-key", "test-key"] + + +@pytest.mark.asyncio +async def test_runtime_evaluation_auto_without_target_uses_api_key_path() -> None: + exchange_calls = 0 + evaluation_api_key_headers: list[str | None] = [] + evaluation_authorization_headers: list[str | None] = [] + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal exchange_calls + if request.url.path == "/api/v1/auth/runtime-token-exchange": + exchange_calls += 1 + return httpx.Response(200, json={}) + + evaluation_api_key_headers.append(request.headers.get("X-API-Key")) + evaluation_authorization_headers.append(request.headers.get("Authorization")) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="auto", + transport=transport, + ) as client: + response = await client.post_runtime_evaluation(json={}) + + assert response.status_code == 200 + assert exchange_calls == 0 + assert evaluation_api_key_headers == ["test-key"] + assert evaluation_authorization_headers == [None] + + +@pytest.mark.asyncio +async def test_runtime_evaluation_retries_once_after_unauthorized_token() -> None: + exchange_tokens = ["expired-token", "fresh-token"] + evaluation_authorization_headers: list[str | None] = [] + expires_at = (datetime.now(UTC) + timedelta(minutes=5)).isoformat() + + def handler(request: httpx.Request) -> httpx.Response: + if request.url.path == "/api/v1/auth/runtime-token-exchange": + token = exchange_tokens.pop(0) + return httpx.Response( + 200, + json={ + "token": token, + "expires_at": expires_at, + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + ) + + authorization = request.headers.get("Authorization") + evaluation_authorization_headers.append(authorization) + if authorization == "Bearer expired-token": + return httpx.Response(401, json={"detail": "expired"}) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="jwt", + transport=transport, + ) as client: + response = await client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) + + assert response.status_code == 200 + assert evaluation_authorization_headers == [ + "Bearer expired-token", + "Bearer fresh-token", + ] + assert exchange_tokens == [] + + +@pytest.mark.asyncio +async def test_runtime_evaluation_does_not_auto_fallback_after_unauthorized_token() -> None: + exchange_attempt = 0 + evaluation_authorization_headers: list[str | None] = [] + evaluation_api_key_headers: list[str | None] = [] + expires_at = (datetime.now(UTC) + timedelta(minutes=5)).isoformat() + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal exchange_attempt + if request.url.path == "/api/v1/auth/runtime-token-exchange": + exchange_attempt += 1 + if exchange_attempt == 1: + return httpx.Response( + 200, + json={ + "token": "expired-token", + "expires_at": expires_at, + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + ) + return httpx.Response(503, json={"detail": "runtime auth disabled"}) + + evaluation_authorization_headers.append(request.headers.get("Authorization")) + evaluation_api_key_headers.append(request.headers.get("X-API-Key")) + return httpx.Response(401, json={"detail": "expired"}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="auto", + transport=transport, + ) as client: + with pytest.raises(httpx.HTTPStatusError) as exc_info: + await client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) + + assert exc_info.value.response.status_code == 503 + assert exchange_attempt == 2 + assert evaluation_authorization_headers == ["Bearer expired-token"] + assert evaluation_api_key_headers == [None] + + +@pytest.mark.asyncio +async def test_runtime_evaluation_returns_second_unauthorized_response() -> None: + exchange_tokens = ["expired-token", "still-expired-token"] + evaluation_authorization_headers: list[str | None] = [] + expires_at = (datetime.now(UTC) + timedelta(minutes=5)).isoformat() + + def handler(request: httpx.Request) -> httpx.Response: + if request.url.path == "/api/v1/auth/runtime-token-exchange": + return httpx.Response( + 200, + json={ + "token": exchange_tokens.pop(0), + "expires_at": expires_at, + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + ) + + evaluation_authorization_headers.append(request.headers.get("Authorization")) + return httpx.Response(401, json={"detail": "expired"}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="jwt", + transport=transport, + ) as client: + response = await client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) + + assert response.status_code == 401 + assert exchange_tokens == [] + assert evaluation_authorization_headers == [ + "Bearer expired-token", + "Bearer still-expired-token", + ] + + +@pytest.mark.asyncio +async def test_runtime_evaluation_jwt_mode_requires_target_context() -> None: + transport = httpx.MockTransport(lambda request: httpx.Response(200, json={"ok": True})) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="jwt", + transport=transport, + ) as client: + with pytest.raises(RuntimeError, match="requires target_type and target_id"): + await client.post_runtime_evaluation(json={}) + + +@pytest.mark.asyncio +async def test_runtime_exchange_rejects_non_object_response() -> None: + def handler(request: httpx.Request) -> httpx.Response: + if request.url.path == "/api/v1/auth/runtime-token-exchange": + return httpx.Response(200, json=["not", "an", "object"]) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="jwt", + transport=transport, + ) as client: + with pytest.raises(RuntimeError, match="response was not an object"): + await client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) @pytest.mark.asyncio diff --git a/sdks/python/tests/test_local_evaluation.py b/sdks/python/tests/test_local_evaluation.py index da0115b8..c8fadf3b 100644 --- a/sdks/python/tests/test_local_evaluation.py +++ b/sdks/python/tests/test_local_evaluation.py @@ -29,6 +29,36 @@ # ============================================================================= +class _RuntimeAuthDuckClient: + """Minimal custom client that exposes the runtime-auth evaluation method.""" + + base_url = "https://agent-control.test" + + def __init__(self) -> None: + self.runtime_requests: list[dict[str, Any]] = [] + self.response = MagicMock() + self.response.json.return_value = {"is_safe": True, "confidence": 1.0} + self.response.raise_for_status = MagicMock() + + async def post_runtime_evaluation( + self, + *, + json: dict[str, Any], + headers: dict[str, str] | None = None, + target_type: str | None = None, + target_id: str | None = None, + ) -> MagicMock: + self.runtime_requests.append( + { + "json": json, + "headers": headers, + "target_type": target_type, + "target_id": target_id, + } + ) + return self.response + + @pytest.fixture def agent_name() -> str: """Test agent name.""" @@ -329,6 +359,65 @@ async def test_server_only_controls_calls_server(self, agent_name, llm_payload): assert result.is_safe is True + @pytest.mark.asyncio + async def test_custom_client_with_runtime_method_uses_runtime_auth_path( + self, agent_name, llm_payload + ) -> None: + """Custom clients can opt into runtime auth with post_runtime_evaluation.""" + controls = [ + make_control_dict(1, "server_ctrl", execution="server"), + ] + client = _RuntimeAuthDuckClient() + + result = await check_evaluation_with_local( + client=client, # type: ignore[arg-type] + agent_name=agent_name, + step=llm_payload, + stage="pre", + controls=controls, + target_type="log_stream", + target_id="ls-1", + ) + + assert result.is_safe is True + assert len(client.runtime_requests) == 1 + assert client.runtime_requests[0]["target_type"] == "log_stream" + assert client.runtime_requests[0]["target_id"] == "ls-1" + assert client.runtime_requests[0]["json"]["target_type"] == "log_stream" + assert client.runtime_requests[0]["json"]["target_id"] == "ls-1" + + @pytest.mark.asyncio + async def test_mock_client_with_runtime_method_uses_runtime_auth_path( + self, + agent_name, + llm_payload, + ) -> None: + """Configured mock clients can exercise the runtime-auth path.""" + controls = [ + make_control_dict(1, "server_ctrl", execution="server"), + ] + client = MagicMock(spec=AgentControlClient) + mock_response = MagicMock() + mock_response.json.return_value = {"is_safe": True, "confidence": 1.0} + mock_response.raise_for_status = MagicMock() + client.http_client = AsyncMock() + client.http_client.post = AsyncMock() + client.post_runtime_evaluation = AsyncMock(return_value=mock_response) + + result = await check_evaluation_with_local( + client=client, + agent_name=agent_name, + step=llm_payload, + stage="pre", + controls=controls, + target_type="log_stream", + target_id="ls-1", + ) + + assert result.is_safe is True + client.post_runtime_evaluation.assert_awaited_once() + client.http_client.post.assert_not_called() + @pytest.mark.asyncio async def test_server_only_template_backed_controls_still_call_server( self, diff --git a/sdks/python/tests/test_runtime_auth.py b/sdks/python/tests/test_runtime_auth.py new file mode 100644 index 00000000..ed12d1b7 --- /dev/null +++ b/sdks/python/tests/test_runtime_auth.py @@ -0,0 +1,232 @@ +"""Tests for Agent Control SDK runtime auth helpers.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta + +import pytest + +from agent_control.runtime_auth import ( + RuntimeToken, + RuntimeTokenCache, + normalize_runtime_auth_mode, + parse_runtime_token_exchange_response, +) + + +def _runtime_token( + *, + token: str = "token", + server_url: str = "https://server-a.test", + target_type: str = "log_stream", + target_id: str = "ls-1", + expires_at: datetime | None = None, +) -> RuntimeToken: + return RuntimeToken( + token=token, + expires_at=expires_at or datetime.now(UTC) + timedelta(minutes=5), + server_url=server_url, + target_type=target_type, + target_id=target_id, + scopes=("runtime.use",), + ) + + +def test_runtime_token_cache_is_keyed_by_server_and_target() -> None: + cache = RuntimeTokenCache() + token = _runtime_token() + + cache.set(token) + + assert ( + cache.get("https://server-a.test", "log_stream", "ls-1", refresh_margin_seconds=0) == token + ) + assert ( + cache.get("https://server-b.test", "log_stream", "ls-1", refresh_margin_seconds=0) is None + ) + assert ( + cache.get("https://server-a.test", "log_stream", "ls-2", refresh_margin_seconds=0) is None + ) + + +def test_runtime_token_cache_drops_stale_tokens() -> None: + cache = RuntimeTokenCache() + cache.set(_runtime_token(expires_at=datetime.now(UTC) + timedelta(seconds=5))) + + assert ( + cache.get("https://server-a.test", "log_stream", "ls-1", refresh_margin_seconds=30) is None + ) + assert ( + cache.get("https://server-a.test", "log_stream", "ls-1", refresh_margin_seconds=0) is None + ) + + +def test_runtime_token_cache_tracks_jwt_unavailable_by_server_and_target() -> None: + cache = RuntimeTokenCache() + + cache.mark_jwt_unavailable( + server_url="https://server-a.test", + target_type="log_stream", + target_id="ls-1", + ) + + assert cache.is_jwt_unavailable("https://server-a.test", "log_stream", "ls-1") + assert not cache.is_jwt_unavailable("https://server-b.test", "log_stream", "ls-1") + + cache.set(_runtime_token()) + + assert not cache.is_jwt_unavailable("https://server-a.test", "log_stream", "ls-1") + + +def test_runtime_token_cache_global_unavailable_clears_cache() -> None: + cache = RuntimeTokenCache() + cache.set(_runtime_token()) + + cache.mark_jwt_unavailable(globally=True) + + assert cache.is_jwt_unavailable("https://server-a.test", "log_stream", "ls-1") + assert ( + cache.get("https://server-a.test", "log_stream", "ls-1", refresh_margin_seconds=0) is None + ) + + cache.clear() + + assert not cache.is_jwt_unavailable("https://server-a.test", "log_stream", "ls-1") + + +def test_runtime_token_cache_remove_drops_one_token() -> None: + cache = RuntimeTokenCache() + cache.set(_runtime_token(target_id="ls-1")) + token_2 = _runtime_token(token="token-2", target_id="ls-2") + cache.set(token_2) + + cache.remove("https://server-a.test", "log_stream", "ls-1") + + assert ( + cache.get("https://server-a.test", "log_stream", "ls-1", refresh_margin_seconds=0) is None + ) + assert ( + cache.get("https://server-a.test", "log_stream", "ls-2", refresh_margin_seconds=0) + == token_2 + ) + + +def test_runtime_token_cache_evicts_oldest_token_when_full() -> None: + cache = RuntimeTokenCache(max_entries=1) + token_1 = _runtime_token(target_id="ls-1") + token_2 = _runtime_token(token="token-2", target_id="ls-2") + + cache.set(token_1) + cache.set(token_2) + + assert ( + cache.get("https://server-a.test", "log_stream", "ls-1", refresh_margin_seconds=0) is None + ) + assert ( + cache.get("https://server-a.test", "log_stream", "ls-2", refresh_margin_seconds=0) + == token_2 + ) + + +def test_runtime_token_cache_rejects_empty_capacity() -> None: + with pytest.raises(ValueError, match="max_entries"): + RuntimeTokenCache(max_entries=0) + + +@pytest.mark.parametrize( + ("raw", "expected"), + [ + (None, "auto"), + ("", "auto"), + (" NO_AUTH ", "none"), + ("header", "api_key"), + ("api_key", "api_key"), + ("jwt", "jwt"), + ], +) +def test_normalize_runtime_auth_mode(raw: str | None, expected: str) -> None: + assert normalize_runtime_auth_mode(raw) == expected + + +def test_normalize_runtime_auth_mode_rejects_unknown_mode() -> None: + with pytest.raises(ValueError, match="runtime_auth_mode must be one of"): + normalize_runtime_auth_mode("cookie") + + +def test_parse_runtime_token_exchange_response_normalizes_zulu_expiry() -> None: + token = parse_runtime_token_exchange_response( + { + "token": "runtime-token", + "expires_at": "2026-05-07T15:00:00Z", + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + server_url="https://server-a.test", + ) + + assert token.token == "runtime-token" + assert token.expires_at == datetime(2026, 5, 7, 15, 0, tzinfo=UTC) + assert token.server_url == "https://server-a.test" + assert token.target_type == "log_stream" + assert token.target_id == "ls-1" + assert token.scopes == ("runtime.use",) + + +def test_parse_runtime_token_exchange_response_treats_naive_expiry_as_utc() -> None: + token = parse_runtime_token_exchange_response( + { + "token": "runtime-token", + "expires_at": "2026-05-07T15:00:00", + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + server_url="https://server-a.test", + ) + + assert token.expires_at == datetime(2026, 5, 7, 15, 0, tzinfo=UTC) + + +@pytest.mark.parametrize( + ("payload", "match"), + [ + ({}, "token"), + ({"token": "runtime-token"}, "expires_at"), + ({"token": "runtime-token", "expires_at": "2026-05-07T15:00:00Z"}, "target_type"), + ( + { + "token": "runtime-token", + "expires_at": "2026-05-07T15:00:00Z", + "target_type": "log_stream", + }, + "target_id", + ), + ( + { + "token": "runtime-token", + "expires_at": "2026-05-07T15:00:00Z", + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": "runtime.use", + }, + "scopes", + ), + ( + { + "token": "runtime-token", + "expires_at": "2026-05-07T15:00:00Z", + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use", 1], + }, + "non-string scope", + ), + ], +) +def test_parse_runtime_token_exchange_response_rejects_invalid_payloads( + payload: dict[str, object], + match: str, +) -> None: + with pytest.raises(RuntimeError, match=match): + parse_runtime_token_exchange_response(payload, server_url="https://server-a.test") From b7867594bc741c507b6be24f129716d989a25463 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 8 May 2026 19:19:04 +0530 Subject: [PATCH 24/42] feat(sdk): make API key header name configurable Default stays X-API-Key; pass api_key_header=... or set AGENT_CONTROL_API_KEY_HEADER to override when the upstream auth expects a different header. --- sdks/python/src/agent_control/client.py | 38 +++++++++++++-- sdks/python/tests/test_client.py | 61 +++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 5 deletions(-) diff --git a/sdks/python/src/agent_control/client.py b/sdks/python/src/agent_control/client.py index bc5ef0f1..10ce50d6 100644 --- a/sdks/python/src/agent_control/client.py +++ b/sdks/python/src/agent_control/client.py @@ -27,16 +27,17 @@ class _AgentControlAuth(httpx.Auth): """Attach local API-key credentials unless a request already has Bearer auth.""" - def __init__(self, api_key: str | None) -> None: + def __init__(self, api_key: str | None, header_name: str = "X-API-Key") -> None: self._api_key = api_key + self._header_name = header_name def auth_flow( self, request: httpx.Request, ) -> Generator[httpx.Request, httpx.Response, None]: if self._api_key and "Authorization" not in request.headers: - if "X-API-Key" not in request.headers: - request.headers["X-API-Key"] = self._api_key + if self._header_name not in request.headers: + request.headers[self._header_name] = self._api_key yield request @@ -49,7 +50,9 @@ class AgentControlClient: agents, policies, controls, evaluation. Authentication: - The client supports API key authentication via the X-API-Key header. + The client supports API key authentication. By default the key is + sent on the ``X-API-Key`` header; set ``api_key_header`` (or the + ``AGENT_CONTROL_API_KEY_HEADER`` environment variable) to override. API key can be provided: 1. Directly via the `api_key` parameter 2. Via the AGENT_CONTROL_API_KEY environment variable @@ -63,10 +66,20 @@ class AgentControlClient: os.environ["AGENT_CONTROL_API_KEY"] = "my-secret-key" async with AgentControlClient() as client: await client.health_check() + + # Custom header name (e.g., when the upstream auth expects something + # other than X-API-Key). The header name applies to every request + # this client sends. + async with AgentControlClient( + api_key="my-secret-key", api_key_header="X-Custom-API-Key" + ) as client: + await client.health_check() """ # Environment variable name for API key API_KEY_ENV_VAR = "AGENT_CONTROL_API_KEY" + API_KEY_HEADER_ENV_VAR = "AGENT_CONTROL_API_KEY_HEADER" + DEFAULT_API_KEY_HEADER = "X-API-Key" BASE_URL_ENV_VAR = "AGENT_CONTROL_URL" def __init__( @@ -74,6 +87,7 @@ def __init__( base_url: str | None = None, timeout: float = 30.0, api_key: str | None = None, + api_key_header: str | None = None, runtime_auth_mode: RuntimeAuthMode | str | None = None, runtime_token_cache: RuntimeTokenCache | None = None, runtime_token_refresh_margin_seconds: int = (_DEFAULT_RUNTIME_TOKEN_REFRESH_MARGIN_SECONDS), @@ -88,6 +102,10 @@ def __init__( timeout: Request timeout in seconds api_key: API key for authentication. If not provided, will attempt to read from AGENT_CONTROL_API_KEY environment variable. + api_key_header: HTTP header name to send the API key on. Defaults + to ``X-API-Key``; the AGENT_CONTROL_API_KEY_HEADER + environment variable overrides the default. Useful when + the configured upstream auth expects a different header. runtime_auth_mode: Runtime auth mode for evaluation requests. ``auto`` attempts target-bound JWT exchange and falls back to normal request auth when the exchange endpoint is unavailable. ``jwt`` @@ -104,6 +122,11 @@ def __init__( self.base_url = resolved_base_url.rstrip("/") self.timeout = timeout self._api_key = api_key or os.environ.get(self.API_KEY_ENV_VAR) + self._api_key_header = ( + api_key_header + or os.environ.get(self.API_KEY_HEADER_ENV_VAR) + or self.DEFAULT_API_KEY_HEADER + ) configured_runtime_mode = runtime_auth_mode or os.environ.get(_RUNTIME_AUTH_MODE_ENV_VAR) self._runtime_auth_mode = normalize_runtime_auth_mode(configured_runtime_mode) if runtime_token_refresh_margin_seconds < 0: @@ -119,6 +142,11 @@ def api_key(self) -> str | None: """Get the configured API key (read-only).""" return self._api_key + @property + def api_key_header(self) -> str: + """Get the header name the API key is sent on (read-only).""" + return self._api_key_header + @property def runtime_auth_mode(self) -> RuntimeAuthMode: """Get the configured runtime auth mode (read-only).""" @@ -159,7 +187,7 @@ async def __aenter__(self) -> "AgentControlClient": base_url=self.base_url, timeout=self.timeout, headers=self._get_headers(), - auth=_AgentControlAuth(self._api_key), + auth=_AgentControlAuth(self._api_key, self._api_key_header), transport=self._transport, event_hooks={"response": [self._check_server_version]}, ) diff --git a/sdks/python/tests/test_client.py b/sdks/python/tests/test_client.py index 11a8ed91..54754a87 100644 --- a/sdks/python/tests/test_client.py +++ b/sdks/python/tests/test_client.py @@ -78,6 +78,67 @@ def handler(request: httpx.Request) -> httpx.Response: assert seen_requests[0].headers["X-API-Key"] == "test-key" +@pytest.mark.asyncio +async def test_client_uses_configured_api_key_header_name() -> None: + # Given: a client configured to send the API key on a custom header + seen_requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + seen_requests.append(request) + return httpx.Response(200, json={"ok": True}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + api_key_header="X-Custom-API-Key", + transport=transport, + ) as client: + # When: making a request + response = await client.http_client.get("/api/v1/agents") + + # Then: the key is on the configured header and the default is absent + assert response.status_code == 200 + assert seen_requests[0].headers["X-Custom-API-Key"] == "test-key" + assert "X-API-Key" not in seen_requests[0].headers + + +@pytest.mark.asyncio +async def test_client_reads_api_key_header_from_env( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Given: AGENT_CONTROL_API_KEY_HEADER set in the environment + monkeypatch.setenv("AGENT_CONTROL_API_KEY_HEADER", "X-Custom-API-Key") + seen_requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + seen_requests.append(request) + return httpx.Response(200, json={"ok": True}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + transport=transport, + ) as client: + # When: no api_key_header is passed to the constructor + response = await client.http_client.get("/api/v1/agents") + + # Then: the env-var value is used + assert response.status_code == 200 + assert seen_requests[0].headers["X-Custom-API-Key"] == "test-key" + + +def test_client_exposes_default_api_key_header() -> None: + # Given: a client with no explicit header override + client = AgentControlClient(api_key="test-key") + + # Then: the property reports the documented default + assert client.api_key_header == "X-API-Key" + + @pytest.mark.asyncio async def test_runtime_evaluation_exchanges_and_caches_bearer_token() -> None: exchange_calls = 0 From 047ba35b0267e6a646eb28063d5c4a21e0b4269e Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 8 May 2026 21:24:14 +0530 Subject: [PATCH 25/42] fix(sdk): honor API key header for observability --- .../python/src/agent_control/observability.py | 17 +++++- sdks/python/src/agent_control/settings.py | 4 ++ sdks/python/tests/test_observability.py | 59 +++++++++++++++++-- 3 files changed, 73 insertions(+), 7 deletions(-) diff --git a/sdks/python/src/agent_control/observability.py b/sdks/python/src/agent_control/observability.py index e4805580..199017c1 100644 --- a/sdks/python/src/agent_control/observability.py +++ b/sdks/python/src/agent_control/observability.py @@ -26,6 +26,9 @@ await shutdown_observability() Configuration (Environment Variables): + # Server connection + AGENT_CONTROL_API_KEY_HEADER: API key header name (default: X-API-Key) + # Observability (event batching) AGENT_CONTROL_OBSERVABILITY_ENABLED: Enable observability (default: true) AGENT_CONTROL_OBSERVABILITY_SINK_NAME: Selected control-event sink (default: default) @@ -286,6 +289,7 @@ class EventBatcher: Attributes: server_url: Base URL of the Agent Control server api_key: API key for authentication + api_key_header: HTTP header used to send the API key batch_size: Maximum events per batch flush_interval: Seconds between automatic flushes """ @@ -294,6 +298,7 @@ def __init__( self, server_url: str | None = None, api_key: str | None = None, + api_key_header: str | None = None, batch_size: int | None = None, flush_interval: float | None = None, ): @@ -303,11 +308,13 @@ def __init__( Args: server_url: Server URL (defaults to get_settings().url) api_key: API key (defaults to get_settings().api_key) + api_key_header: API key header (defaults to get_settings().api_key_header) batch_size: Max events per batch (defaults to get_settings().batch_size) flush_interval: Seconds between flushes (defaults to get_settings().flush_interval) """ self.server_url = server_url or get_settings().url self.api_key = api_key or get_settings().api_key + self.api_key_header = api_key_header or get_settings().api_key_header self.batch_size = batch_size if batch_size is not None else get_settings().batch_size if flush_interval is not None: self.flush_interval = flush_interval @@ -424,7 +431,7 @@ def _build_batch_request( url = f"{self.server_url}/api/v1/observability/events" headers = {"Content-Type": "application/json"} if self.api_key: - headers["X-API-Key"] = self.api_key + headers[self.api_key_header] = self.api_key payload = {"events": [event.model_dump(mode="json") for event in events]} return url, headers, payload @@ -1098,6 +1105,7 @@ def _get_custom_control_event_sinks_to_shutdown() -> tuple[ControlEventSink, ... def init_observability( server_url: str | None = None, api_key: str | None = None, + api_key_header: str | None = None, enabled: bool | None = None, sink_name: str | None = None, sink_config: JSONObject | None = None, @@ -1110,6 +1118,7 @@ def init_observability( Args: server_url: Server URL for sending events api_key: API key for authentication + api_key_header: HTTP header used to send the API key enabled: Override AGENT_CONTROL_OBSERVABILITY_ENABLED sink_name: Override AGENT_CONTROL_OBSERVABILITY_SINK_NAME sink_config: Override AGENT_CONTROL_OBSERVABILITY_SINK_CONFIG @@ -1157,7 +1166,11 @@ def init_observability( return _batcher # Create batcher - _batcher = EventBatcher(server_url=server_url, api_key=api_key) + _batcher = EventBatcher( + server_url=server_url, + api_key=api_key, + api_key_header=api_key_header, + ) _batcher.start() _event_sink = _BatcherControlEventSink(_batcher) diff --git a/sdks/python/src/agent_control/settings.py b/sdks/python/src/agent_control/settings.py index d87572d8..2eded296 100644 --- a/sdks/python/src/agent_control/settings.py +++ b/sdks/python/src/agent_control/settings.py @@ -56,6 +56,10 @@ class SDKSettings(BaseSettings): default="", description="API key for server authentication", ) + api_key_header: str = Field( + default="X-API-Key", + description="HTTP header used to send the API key", + ) # Observability (event batching) observability_enabled: bool = Field( diff --git a/sdks/python/tests/test_observability.py b/sdks/python/tests/test_observability.py index 0fff8c82..7981c5ae 100644 --- a/sdks/python/tests/test_observability.py +++ b/sdks/python/tests/test_observability.py @@ -120,6 +120,7 @@ def reset_observability_state() -> None: observability_enabled=True, observability_sink_name=DEFAULT_CONTROL_EVENT_SINK_NAME, observability_sink_config={}, + api_key_header="X-API-Key", ) with obs._used_custom_event_sinks_lock: obs._used_custom_event_sinks.clear() @@ -136,6 +137,7 @@ class TestEventBatcherInit: def test_init_default_values(self): """Test EventBatcher initializes with default values.""" batcher = EventBatcher() + assert batcher.api_key_header == get_settings().api_key_header assert batcher.batch_size == get_settings().batch_size assert batcher.flush_interval == get_settings().flush_interval assert batcher.shutdown_join_timeout == get_settings().shutdown_join_timeout @@ -149,11 +151,13 @@ def test_init_custom_values(self): batcher = EventBatcher( server_url="http://custom:9000", api_key="test-key", + api_key_header="X-Custom-API-Key", batch_size=50, flush_interval=5.0, ) assert batcher.server_url == "http://custom:9000" assert batcher.api_key == "test-key" + assert batcher.api_key_header == "X-Custom-API-Key" assert batcher.batch_size == 50 assert batcher.flush_interval == 5.0 @@ -161,20 +165,23 @@ def test_init_from_settings(self): """Test EventBatcher reads from settings.""" from agent_control.settings import configure_settings - # Save original values - original_url = get_settings().url - original_api_key = get_settings().api_key + original_settings = get_settings().model_dump() try: # Configure settings programmatically - configure_settings(url="http://configured-server:8080", api_key="configured-api-key") + configure_settings( + url="http://configured-server:8080", + api_key="configured-api-key", + api_key_header="X-Custom-API-Key", + ) batcher = EventBatcher() assert batcher.server_url == "http://configured-server:8080" assert batcher.api_key == "configured-api-key" + assert batcher.api_key_header == "X-Custom-API-Key" finally: # Restore original settings - configure_settings(url=original_url, api_key=original_api_key) + configure_settings(**original_settings) class TestEventBatcherStartStop: @@ -557,6 +564,46 @@ def test_send_batch_sync_returns_true_on_202(self): assert result is True client_ctor.assert_called_once_with(timeout=30.0) client.post.assert_called_once() + assert client.post.call_args.kwargs["headers"]["X-API-Key"] == "test-key" + + def test_send_batch_sync_uses_configured_api_key_header(self): + batcher = EventBatcher( + server_url="http://test:8000", + api_key="test-key", + api_key_header="X-Custom-API-Key", + ) + response = MagicMock(status_code=202, text="accepted") + client = MagicMock() + client.post.return_value = response + client_context = MagicMock() + client_context.__enter__.return_value = client + + with patch( + "agent_control.observability.httpx.Client", + return_value=client_context, + ): + result = batcher._send_batch_sync([create_mock_event()]) + + assert result is True + headers = client.post.call_args.kwargs["headers"] + assert headers["X-Custom-API-Key"] == "test-key" + assert "X-API-Key" not in headers + + def test_build_batch_request_uses_settings_api_key_header(self): + original_settings = get_settings().model_dump() + try: + configure_settings( + api_key="settings-key", + api_key_header="X-Custom-API-Key", + ) + batcher = EventBatcher() + + _, headers, _ = batcher._build_batch_request([create_mock_event()]) + + assert headers["X-Custom-API-Key"] == "settings-key" + assert "X-API-Key" not in headers + finally: + configure_settings(**original_settings) def test_send_batch_sync_returns_false_on_401_without_retry(self): batcher = EventBatcher() @@ -1069,10 +1116,12 @@ def test_init_enabled_creates_batcher(self): result = init_observability( server_url="http://test:8000", api_key="test-key", + api_key_header="X-Custom-API-Key", enabled=True, ) assert result is not None assert isinstance(result, EventBatcher) + assert result.api_key_header == "X-Custom-API-Key" assert result._running is True assert get_event_sink() is not None From 16b73574a07a22efc4f9d9ae97af17cf0fbe964d Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 8 May 2026 21:44:51 +0530 Subject: [PATCH 26/42] fix(sdk): keep JWT runtime evaluation target-bound --- sdks/python/src/agent_control/evaluation.py | 2 +- sdks/python/tests/test_local_evaluation.py | 36 +++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/sdks/python/src/agent_control/evaluation.py b/sdks/python/src/agent_control/evaluation.py index 5324c322..2ecfd850 100644 --- a/sdks/python/src/agent_control/evaluation.py +++ b/sdks/python/src/agent_control/evaluation.py @@ -229,7 +229,7 @@ async def _post_evaluation_request( ) -> httpx.Response: """Send an evaluation request, using runtime auth when the client supports it.""" runtime_post = None - if target_type is not None and target_id is not None: + if (target_type is not None and target_id is not None) or client.runtime_auth_mode == "jwt": runtime_post = _runtime_post_evaluation(client) if runtime_post is not None: return await runtime_post( diff --git a/sdks/python/tests/test_local_evaluation.py b/sdks/python/tests/test_local_evaluation.py index c8fadf3b..b5b725f6 100644 --- a/sdks/python/tests/test_local_evaluation.py +++ b/sdks/python/tests/test_local_evaluation.py @@ -11,6 +11,7 @@ from typing import Any from unittest.mock import AsyncMock, MagicMock, patch +import httpx import pytest from agent_control.client import AgentControlClient from agent_control.evaluation import ( @@ -418,6 +419,41 @@ async def test_mock_client_with_runtime_method_uses_runtime_auth_path( client.post_runtime_evaluation.assert_awaited_once() client.http_client.post.assert_not_called() + @pytest.mark.asyncio + async def test_jwt_runtime_client_without_target_raises( + self, + agent_name, + llm_payload, + ) -> None: + """JWT runtime mode requires target context through local evaluation.""" + controls = [ + make_control_dict(1, "server_ctrl", execution="server"), + ] + sent_requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + sent_requests.append(request) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="jwt", + transport=transport, + ) as client: + with pytest.raises(RuntimeError, match="requires target_type and target_id"): + await check_evaluation_with_local( + client=client, + agent_name=agent_name, + step=llm_payload, + stage="pre", + controls=controls, + ) + + assert sent_requests == [] + @pytest.mark.asyncio async def test_server_only_template_backed_controls_still_call_server( self, From 36ca9f878889f96aee9600bc31b8642511be2eb3 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 11 May 2026 14:40:14 +0530 Subject: [PATCH 27/42] fix(sdk): propagate custom API key header --- sdks/python/src/agent_control/__init__.py | 11 +++++ sdks/python/src/agent_control/_state.py | 1 + sdks/python/src/agent_control/evaluation.py | 1 + .../integrations/google_adk/plugin.py | 6 ++- sdks/python/tests/test_evaluation.py | 24 +++++++++- sdks/python/tests/test_init_step_merge.py | 45 ++++++++++++++++++- sdks/python/tests/test_shutdown.py | 5 ++- 7 files changed, 87 insertions(+), 6 deletions(-) diff --git a/sdks/python/src/agent_control/__init__.py b/sdks/python/src/agent_control/__init__.py index c9561a0c..da61c66c 100644 --- a/sdks/python/src/agent_control/__init__.py +++ b/sdks/python/src/agent_control/__init__.py @@ -160,6 +160,7 @@ class _RefreshContext: agent_name: str server_url: str api_key: str | None + api_key_header: str | None target_type: str | None target_id: str | None @@ -221,6 +222,7 @@ def _snapshot_refresh_context() -> _RefreshContext: agent = state.current_agent server_url = state.server_url api_key = state.api_key + api_key_header = state.api_key_header target_type = state.target_type target_id = state.target_id @@ -234,6 +236,7 @@ def _snapshot_refresh_context() -> _RefreshContext: agent_name=agent.agent_name, server_url=server_url, api_key=api_key, + api_key_header=api_key_header, target_type=target_type, target_id=target_id, ) @@ -244,6 +247,7 @@ async def _fetch_controls_for_context_async(context: _RefreshContext) -> list[di async with AgentControlClient( base_url=context.server_url, api_key=context.api_key, + api_key_header=context.api_key_header, ) as client: response = await agents.list_agent_controls( client, @@ -430,6 +434,7 @@ def init( agent_version: str | None = None, server_url: str | None = None, api_key: str | None = None, + api_key_header: str | None = None, controls_file: str | None = None, steps: list[StepSchemaDict] | None = None, conflict_mode: Literal["strict", "overwrite"] = "overwrite", @@ -468,6 +473,8 @@ def init( server_url: Optional server URL (defaults to AGENT_CONTROL_URL env var or http://localhost:8000) api_key: Optional API key for authentication (defaults to AGENT_CONTROL_API_KEY env var) + api_key_header: Optional HTTP header name for API key authentication + (defaults to AGENT_CONTROL_API_KEY_HEADER env var or X-API-Key) controls_file: Optional explicit path to controls.yaml (auto-discovered if not provided) steps: Optional list of step schemas for registration: [{"type": "tool", "name": "search", "input_schema": {...}, "output_schema": {...}}] @@ -562,6 +569,7 @@ async def handle(message: str): state.current_agent = next_agent state.server_url = server_url or os.getenv('AGENT_CONTROL_URL') or 'http://localhost:8000' state.api_key = api_key + state.api_key_header = api_key_header state.runtime_token_cache.clear() state.target_type = target_type state.target_id = target_id @@ -600,6 +608,7 @@ async def register() -> list[dict[str, Any]] | None: async with AgentControlClient( base_url=state.server_url, api_key=state.api_key, + api_key_header=state.api_key_header, ) as client: # Check server health first try: @@ -686,6 +695,7 @@ def run_in_thread() -> None: batcher = init_observability( server_url=state.server_url, api_key=state.api_key, + api_key_header=state.api_key_header, enabled=observability_enabled, sink_name=observability_sink_name, sink_config=observability_sink_config, @@ -717,6 +727,7 @@ def _reset_state() -> None: state.server_controls = None state.server_url = None state.api_key = None + state.api_key_header = None state.runtime_token_cache.clear() state.target_type = None state.target_id = None diff --git a/sdks/python/src/agent_control/_state.py b/sdks/python/src/agent_control/_state.py index 834c73b0..fe6e185a 100644 --- a/sdks/python/src/agent_control/_state.py +++ b/sdks/python/src/agent_control/_state.py @@ -26,6 +26,7 @@ def __init__(self) -> None: self.server_controls: list[dict[str, Any]] | None = None self.server_url: str | None = None self.api_key: str | None = None + self.api_key_header: str | None = None self.runtime_token_cache = RuntimeTokenCache() # Optional target context fixed at init() time; both fields are set # together or both remain None. diff --git a/sdks/python/src/agent_control/evaluation.py b/sdks/python/src/agent_control/evaluation.py index 2ecfd850..4f63014c 100644 --- a/sdks/python/src/agent_control/evaluation.py +++ b/sdks/python/src/agent_control/evaluation.py @@ -557,6 +557,7 @@ async def evaluate_controls( async with AgentControlClient( base_url=state.server_url, api_key=state.api_key, + api_key_header=state.api_key_header, runtime_token_cache=state.runtime_token_cache, ) as client: return await check_evaluation_with_local( diff --git a/sdks/python/src/agent_control/integrations/google_adk/plugin.py b/sdks/python/src/agent_control/integrations/google_adk/plugin.py index 58b59a7b..eb2155c8 100644 --- a/sdks/python/src/agent_control/integrations/google_adk/plugin.py +++ b/sdks/python/src/agent_control/integrations/google_adk/plugin.py @@ -853,7 +853,11 @@ async def _sync_steps_async(self, steps: list[StepSchemaDict]) -> None: "with the same agent_name as AgentControlPlugin." ) - async with AgentControlClient(base_url=state.server_url, api_key=state.api_key) as client: + async with AgentControlClient( + base_url=state.server_url, + api_key=state.api_key, + api_key_header=state.api_key_header, + ) as client: response = await agents.get_agent(client, self.agent_name) existing = GetAgentResponse.model_validate(response) existing_keys = {(step.type, step.name) for step in existing.steps} diff --git a/sdks/python/tests/test_evaluation.py b/sdks/python/tests/test_evaluation.py index 4b9d03b2..2fb92555 100644 --- a/sdks/python/tests/test_evaluation.py +++ b/sdks/python/tests/test_evaluation.py @@ -4,10 +4,9 @@ from uuid import UUID import pytest -from pydantic import ValidationError - from agent_control import evaluation from agent_control.evaluation import EvaluationResult +from pydantic import ValidationError @pytest.mark.asyncio @@ -126,6 +125,27 @@ async def test_evaluate_controls_with_context(monkeypatch): assert mock_check.call_args is not None +@pytest.mark.asyncio +async def test_evaluate_controls_uses_session_api_key_header(monkeypatch): + """evaluate_controls should pass init's API-key header into the client.""" + mock_result = EvaluationResult(is_safe=True, confidence=1.0) + mock_check = AsyncMock(return_value=mock_result) + monkeypatch.setattr(evaluation, "check_evaluation_with_local", mock_check) + + with patch("agent_control.state.server_url", "http://localhost:8000"), patch( + "agent_control.state.api_key", "test-key" + ), patch("agent_control.state.api_key_header", "Galileo-API-Key"): + await evaluation.evaluate_controls( + step_name="chat", + input="hello", + stage="pre", + agent_name="test-bot", + ) + + client = mock_check.call_args.kwargs["client"] + assert client.api_key_header == "Galileo-API-Key" + + @pytest.mark.asyncio async def test_check_evaluation_forwards_target_context(): """When target_type and target_id are supplied, they are forwarded to the server.""" diff --git a/sdks/python/tests/test_init_step_merge.py b/sdks/python/tests/test_init_step_merge.py index 9678518c..2db9bcc4 100644 --- a/sdks/python/tests/test_init_step_merge.py +++ b/sdks/python/tests/test_init_step_merge.py @@ -4,7 +4,7 @@ import logging from collections.abc import Generator -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from unittest.mock import ANY, AsyncMock, patch from uuid import uuid4 @@ -219,6 +219,47 @@ def test_init_omits_merge_events_from_public_signature() -> None: assert "merge_events" not in signature.parameters +def test_init_passes_api_key_header_to_client_and_state() -> None: + register_agent_mock = AsyncMock(return_value={"created": True, "controls": []}) + health_check_mock = AsyncMock(return_value={"status": "healthy"}) + client_init_kwargs: list[dict[str, Any]] = [] + original_init = agent_control.AgentControlClient.__init__ + + def recording_init( + self: agent_control.AgentControlClient, + *args: Any, + **kwargs: Any, + ) -> None: + client_init_kwargs.append(dict(kwargs)) + original_init(self, *args, **kwargs) + + with patch.object( + agent_control.AgentControlClient, + "__init__", + new=recording_init, + ), patch( + "agent_control.__init__.AgentControlClient.health_check", + new=health_check_mock, + ), patch( + "agent_control.__init__.agents.register_agent", + new=register_agent_mock, + ), patch.object( + agent_control, + "init_observability", + return_value=None, + ) as observability_mock: + agent_control.init( + agent_name=f"agent-{uuid4().hex[:12]}", + api_key="test-key", + api_key_header="Galileo-API-Key", + policy_refresh_interval_seconds=0, + ) + + assert agent_control.state.api_key_header == "Galileo-API-Key" + assert client_init_kwargs[0]["api_key_header"] == "Galileo-API-Key" + assert observability_mock.call_args.kwargs["api_key_header"] == "Galileo-API-Key" + + @pytest.mark.asyncio async def test_refresh_controls_calls_agent_controls_endpoint() -> None: # Given: an initialized SDK agent session with network-facing calls mocked. @@ -238,6 +279,7 @@ async def test_refresh_controls_calls_agent_controls_endpoint() -> None: ): agent_control.init( agent_name=f"agent-{uuid4().hex[:12]}", + api_key_header="Galileo-API-Key", policy_refresh_interval_seconds=0, ) @@ -255,4 +297,5 @@ async def test_refresh_controls_calls_agent_controls_endpoint() -> None: target_type=None, target_id=None, ) + assert agent_control.state.api_key_header == "Galileo-API-Key" assert register_agent_mock.await_count == 0 diff --git a/sdks/python/tests/test_shutdown.py b/sdks/python/tests/test_shutdown.py index d02b8aea..49e5b406 100644 --- a/sdks/python/tests/test_shutdown.py +++ b/sdks/python/tests/test_shutdown.py @@ -9,10 +9,9 @@ from typing import Any from unittest.mock import AsyncMock, MagicMock, patch -import pytest - import agent_control import agent_control.observability as obs_mod +import pytest from agent_control._state import state from agent_control.observability import EventBatcher @@ -64,6 +63,7 @@ def test_shutdown_resets_state(self): state.server_controls = [{"name": "test"}] state.server_url = "http://localhost:8000" state.api_key = "key" + state.api_key_header = "X-Custom-API-Key" agent_control.shutdown() @@ -72,6 +72,7 @@ def test_shutdown_resets_state(self): assert state.server_controls is None assert state.server_url is None assert state.api_key is None + assert state.api_key_header is None def test_shutdown_idempotent(self): agent_control.shutdown() From 527a0a9b7f8fef241bbacd0b4818b6535433b5bb Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 11 May 2026 15:21:31 +0530 Subject: [PATCH 28/42] docs(sdk): note strict JWT runtime target requirement --- CHANGELOG.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f9a89686..d39f5e39 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,14 @@ +## Unreleased + +### Changed + +- **sdk**: Strict `runtime_auth_mode="jwt"` evaluation requests now require both + `target_type` and `target_id`; missing target context raises an error instead + of falling back to API-key auth. + ## v7.7.0 (2026-05-07) ### Bug Fixes From 2894484f2bd05963acc6d9d4f84671c858d340c8 Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Tue, 12 May 2026 08:18:21 -0700 Subject: [PATCH 29/42] create new test branch with all changes --- evaluators/builtin/pyproject.toml | 8 ++++---- evaluators/contrib/galileo/pyproject.toml | 4 ++-- server/pyproject.toml | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/evaluators/builtin/pyproject.toml b/evaluators/builtin/pyproject.toml index 59b53d7e..90cc4dd4 100644 --- a/evaluators/builtin/pyproject.toml +++ b/evaluators/builtin/pyproject.toml @@ -7,7 +7,7 @@ requires-python = ">=3.12" license = { text = "Apache-2.0" } authors = [{ name = "Agent Control Team" }] dependencies = [ - "agent-control-models>=7.5.0", + "agent-control-models>=7.7.0", "pydantic>=2.12.4", "google-re2>=1.1", "jsonschema>=4.0.0", @@ -16,9 +16,9 @@ dependencies = [ ] [project.optional-dependencies] -galileo = ["agent-control-evaluator-galileo>=7.5.0"] -budget = ["agent-control-evaluator-budget>=7.5.0"] -cisco = ["agent-control-evaluator-cisco>=7.5.0"] +galileo = ["agent-control-evaluator-galileo>=7.7.0"] +budget = ["agent-control-evaluator-budget>=7.7.0"] +cisco = ["agent-control-evaluator-cisco>=7.7.0"] dev = ["pytest>=8.0.0", "pytest-asyncio>=0.23.0"] [project.entry-points."agent_control.evaluators"] diff --git a/evaluators/contrib/galileo/pyproject.toml b/evaluators/contrib/galileo/pyproject.toml index 3f0fdf84..a2f37675 100644 --- a/evaluators/contrib/galileo/pyproject.toml +++ b/evaluators/contrib/galileo/pyproject.toml @@ -7,8 +7,8 @@ requires-python = ">=3.12" license = { text = "Apache-2.0" } authors = [{ name = "Agent Control Team" }] dependencies = [ - "agent-control-evaluators>=7.5.0", - "agent-control-models>=7.5.0", + "agent-control-evaluators>=7.7.0", + "agent-control-models>=7.7.0", "httpx>=0.24.0", "pydantic>=2.12.4", ] diff --git a/server/pyproject.toml b/server/pyproject.toml index 5e9c5dd2..e626af15 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "jsonschema-rs>=0.22.0", "PyJWT>=2.8.0", "google-re2>=1.1", # For engine (bundled) - "agent-control-evaluators>=7.5.0", # NOT vendored - avoid conflict with galileo + "agent-control-evaluators>=7.7.0", # NOT vendored - avoid conflict with galileo ] authors = [ {name = "Agent Control Team"} @@ -32,7 +32,7 @@ readme = "README.md" license = {text = "Apache-2.0"} [project.optional-dependencies] -galileo = ["agent-control-evaluator-galileo>=7.5.0"] +galileo = ["agent-control-evaluator-galileo>=7.7.0"] [dependency-groups] dev = [ From 54a789f8599654da9d0b7e293aa37c2ea13f3890 Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Tue, 12 May 2026 09:23:05 -0700 Subject: [PATCH 30/42] fix CI --- evaluators/builtin/pyproject.toml | 8 ++++---- evaluators/contrib/galileo/pyproject.toml | 4 ++-- server/pyproject.toml | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/evaluators/builtin/pyproject.toml b/evaluators/builtin/pyproject.toml index 90cc4dd4..59b53d7e 100644 --- a/evaluators/builtin/pyproject.toml +++ b/evaluators/builtin/pyproject.toml @@ -7,7 +7,7 @@ requires-python = ">=3.12" license = { text = "Apache-2.0" } authors = [{ name = "Agent Control Team" }] dependencies = [ - "agent-control-models>=7.7.0", + "agent-control-models>=7.5.0", "pydantic>=2.12.4", "google-re2>=1.1", "jsonschema>=4.0.0", @@ -16,9 +16,9 @@ dependencies = [ ] [project.optional-dependencies] -galileo = ["agent-control-evaluator-galileo>=7.7.0"] -budget = ["agent-control-evaluator-budget>=7.7.0"] -cisco = ["agent-control-evaluator-cisco>=7.7.0"] +galileo = ["agent-control-evaluator-galileo>=7.5.0"] +budget = ["agent-control-evaluator-budget>=7.5.0"] +cisco = ["agent-control-evaluator-cisco>=7.5.0"] dev = ["pytest>=8.0.0", "pytest-asyncio>=0.23.0"] [project.entry-points."agent_control.evaluators"] diff --git a/evaluators/contrib/galileo/pyproject.toml b/evaluators/contrib/galileo/pyproject.toml index a2f37675..3f0fdf84 100644 --- a/evaluators/contrib/galileo/pyproject.toml +++ b/evaluators/contrib/galileo/pyproject.toml @@ -7,8 +7,8 @@ requires-python = ">=3.12" license = { text = "Apache-2.0" } authors = [{ name = "Agent Control Team" }] dependencies = [ - "agent-control-evaluators>=7.7.0", - "agent-control-models>=7.7.0", + "agent-control-evaluators>=7.5.0", + "agent-control-models>=7.5.0", "httpx>=0.24.0", "pydantic>=2.12.4", ] diff --git a/server/pyproject.toml b/server/pyproject.toml index e626af15..5e9c5dd2 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "jsonschema-rs>=0.22.0", "PyJWT>=2.8.0", "google-re2>=1.1", # For engine (bundled) - "agent-control-evaluators>=7.7.0", # NOT vendored - avoid conflict with galileo + "agent-control-evaluators>=7.5.0", # NOT vendored - avoid conflict with galileo ] authors = [ {name = "Agent Control Team"} @@ -32,7 +32,7 @@ readme = "README.md" license = {text = "Apache-2.0"} [project.optional-dependencies] -galileo = ["agent-control-evaluator-galileo>=7.7.0"] +galileo = ["agent-control-evaluator-galileo>=7.5.0"] [dependency-groups] dev = [ From aa87589ae7567f41e0c9c591d854ca8c5170cfef Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Tue, 12 May 2026 10:38:44 -0700 Subject: [PATCH 31/42] feat(galileo): support internal scorer auth --- .../luna/client.py | 93 +++++++++++++++--- .../luna/evaluator.py | 14 ++- .../galileo/tests/test_luna_evaluator.py | 95 ++++++++++++++++++- 3 files changed, 179 insertions(+), 23 deletions(-) diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py index 269d64fc..e75b74bf 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py @@ -4,7 +4,12 @@ import logging import os +from base64 import urlsafe_b64encode from dataclasses import dataclass, field +from hashlib import sha256 +from hmac import new as hmac_new +from json import dumps +from time import time from uuid import UUID import httpx @@ -13,6 +18,38 @@ logger = logging.getLogger(__name__) DEFAULT_TIMEOUT_SECS = 10.0 +DEFAULT_INTERNAL_TOKEN_TTL_SECS = 3600 +PUBLIC_SCORER_INVOKE_PATH = "/scorers/invoke" +INTERNAL_SCORER_INVOKE_PATH = "/internal/scorers/invoke" + + +def _b64url(data: bytes) -> str: + return urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +def _internal_auth_token( + api_secret: str, + project_id: str | UUID, + ttl_seconds: int = DEFAULT_INTERNAL_TOKEN_TTL_SECS, +) -> str: + """Create the internal JWT expected by Galileo API internal routes.""" + now = int(time()) + header = {"alg": "HS256", "typ": "JWT"} + payload = { + "internal": True, + "project_id": str(project_id), + "scope": "scorers.invoke", + "iat": now, + "exp": now + ttl_seconds, + } + signing_input = ".".join( + [ + _b64url(dumps(header, separators=(",", ":")).encode("utf-8")), + _b64url(dumps(payload, separators=(",", ":")).encode("utf-8")), + ] + ) + signature = hmac_new(api_secret.encode("utf-8"), signing_input.encode("ascii"), sha256).digest() + return f"{signing_input}.{_b64url(signature)}" def _as_float_or_none(value: JSONValue) -> float | None: @@ -33,7 +70,7 @@ class ScorerInvokeRequest: """Request payload for Galileo Luna scorer invocation. Attributes: - metric: Preset, registered, or fine-tuned scorer name. + metric: Preset, registered, or fine-tuned scorer label. input: Optional user/system prompt text. output: Optional model response text. luna_model: Optional Luna model override. @@ -50,7 +87,7 @@ class ScorerInvokeRequest: def to_dict(self) -> JSONObject: """Convert to the public API request shape.""" - body: JSONObject = {"metric": self.metric} + body: JSONObject = {"scorer_label": self.metric} if self.input is not None: body["input"] = self.input if self.output is not None: @@ -87,7 +124,7 @@ class ScorerInvokeResponse: @classmethod def from_dict(cls, data: JSONObject) -> ScorerInvokeResponse: """Create a response model from the API JSON object.""" - metric_value = data.get("metric", "") + metric_value = data.get("scorer_label", data.get("metric", "")) status_value = data.get("status", "unknown") error_value = data.get("error_message") @@ -105,13 +142,15 @@ class GalileoLunaClient: """Thin HTTP client for Galileo Luna direct scorer invocation. Environment Variables: - GALILEO_API_KEY: Galileo API key (required). + GALILEO_API_SECRET_KEY or GALILEO_API_SECRET: Galileo API internal JWT signing secret. + GALILEO_API_KEY: Galileo API key fallback for public scorer invocation. GALILEO_CONSOLE_URL: Galileo Console URL (optional, defaults to production). """ def __init__( self, api_key: str | None = None, + api_secret: str | None = None, console_url: str | None = None, api_url: str | None = None, ) -> None: @@ -119,22 +158,28 @@ def __init__( Args: api_key: Galileo API key. If not provided, reads from GALILEO_API_KEY. + api_secret: Galileo API secret for internal JWT auth. If not provided, + reads from GALILEO_API_SECRET_KEY or GALILEO_API_SECRET. console_url: Galileo Console URL. If not provided, reads from GALILEO_CONSOLE_URL or uses the production console URL. api_url: Galileo API URL. If not provided, reads from GALILEO_API_URL before deriving from the console URL. Raises: - ValueError: If no API key is provided or found in the environment. + ValueError: If neither API secret nor API key is provided. """ + resolved_api_secret = ( + api_secret or os.getenv("GALILEO_API_SECRET_KEY") or os.getenv("GALILEO_API_SECRET") + ) resolved_api_key = api_key or os.getenv("GALILEO_API_KEY") - if not resolved_api_key: + if not resolved_api_secret and not resolved_api_key: raise ValueError( - "GALILEO_API_KEY is required. " - "Set it as an environment variable or pass it to the constructor." + "GALILEO_API_SECRET_KEY or GALILEO_API_KEY is required. " + "Set one as an environment variable or pass it to the constructor." ) self.api_key = resolved_api_key + self.api_secret = resolved_api_secret self.console_url = ( console_url or os.getenv("GALILEO_CONSOLE_URL") or "https://console.galileo.ai" ) @@ -162,15 +207,34 @@ def _derive_api_url(self, console_url: str) -> str: async def _get_client(self) -> httpx.AsyncClient: """Get or create the HTTP client.""" if self._client is None or self._client.is_closed: + headers = {"Content-Type": "application/json"} + if self.api_secret is None and self.api_key is not None: + headers["Galileo-API-Key"] = self.api_key self._client = httpx.AsyncClient( - headers={ - "Galileo-API-Key": self.api_key, - "Content-Type": "application/json", - }, + headers=headers, timeout=httpx.Timeout(DEFAULT_TIMEOUT_SECS), ) return self._client + def _endpoint_and_headers( + self, + project_id: str | UUID | None, + headers: dict[str, str] | None, + ) -> tuple[str, dict[str, str]]: + request_headers = dict(headers or {}) + if self.api_secret is None: + return f"{self.api_base}{PUBLIC_SCORER_INVOKE_PATH}", request_headers + + if project_id is None: + raise ValueError( + "project_id is required when using GALILEO_API_SECRET_KEY internal auth." + ) + + request_headers["Authorization"] = ( + f"Bearer {_internal_auth_token(self.api_secret, project_id)}" + ) + return f"{self.api_base}{INTERNAL_SCORER_INVOKE_PATH}", request_headers + async def invoke( self, *, @@ -186,7 +250,7 @@ async def invoke( """Invoke a Galileo Luna scorer. Args: - metric: Preset, registered, or fine-tuned scorer name. + metric: Preset, registered, or fine-tuned scorer label. input: Optional user/system prompt text. output: Optional model response text. project_id: Optional Galileo project UUID for project-scoped scorer resolution. @@ -215,8 +279,7 @@ async def invoke( luna_model=luna_model, config=config, ).to_dict() - request_headers = dict(headers or {}) - endpoint = f"{self.api_base}/scorers/invoke" + endpoint, request_headers = self._endpoint_and_headers(project_id, headers) logger.debug("[GalileoLunaClient] POST %s", endpoint) logger.debug("[GalileoLunaClient] Request body: %s", request_body) diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py index 16a39930..f628cd8e 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py @@ -101,12 +101,18 @@ def __init__(self, config: LunaEvaluatorConfig) -> None: config: Validated LunaEvaluatorConfig instance. Raises: - ValueError: If GALILEO_API_KEY is not set. + ValueError: If neither GALILEO_API_SECRET_KEY nor GALILEO_API_KEY is set. """ - if not os.getenv("GALILEO_API_KEY"): + has_auth = ( + os.getenv("GALILEO_API_SECRET_KEY") + or os.getenv("GALILEO_API_SECRET") + or os.getenv("GALILEO_API_KEY") + ) + if not has_auth: raise ValueError( - "GALILEO_API_KEY environment variable must be set. " - "Set it to a Galileo API key before using galileo.luna." + "GALILEO_API_SECRET_KEY or GALILEO_API_KEY environment variable must be set. " + "Set an API secret for internal auth or a Galileo API key before using " + "galileo.luna." ) super().__init__(config) diff --git a/evaluators/contrib/galileo/tests/test_luna_evaluator.py b/evaluators/contrib/galileo/tests/test_luna_evaluator.py index 1b7e700e..53cf58ae 100644 --- a/evaluators/contrib/galileo/tests/test_luna_evaluator.py +++ b/evaluators/contrib/galileo/tests/test_luna_evaluator.py @@ -4,6 +4,7 @@ import json import os +from base64 import urlsafe_b64decode from unittest.mock import AsyncMock, patch import httpx @@ -12,6 +13,12 @@ from pydantic import ValidationError +def _decode_jwt_payload(token: str) -> dict[str, object]: + payload_segment = token.split(".")[1] + padded = payload_segment + ("=" * (-len(payload_segment) % 4)) + return json.loads(urlsafe_b64decode(padded.encode()).decode()) + + class TestLunaEvaluatorConfig: """Tests for direct Luna evaluator configuration.""" @@ -96,7 +103,7 @@ def handler(request: httpx.Request) -> httpx.Response: return httpx.Response( 200, json={ - "metric": "toxicity", + "scorer_label": "toxicity", "score": 0.82, "status": "success", "execution_time": 0.12, @@ -133,7 +140,7 @@ def handler(request: httpx.Request) -> httpx.Response: assert captured["body"] == { "input": "user prompt", "output": "model answer", - "metric": "toxicity", + "scorer_label": "toxicity", "project_id": "12345678-1234-5678-1234-567812345678", "luna_model": "luna-2", "config": {"top_k": 1}, @@ -144,6 +151,72 @@ def handler(request: httpx.Request) -> httpx.Response: assert isinstance(headers, dict) assert headers["galileo-api-key"] == "test-key" + @pytest.mark.asyncio + async def test_client_uses_internal_jwt_when_api_secret_is_set(self) -> None: + from agent_control_evaluator_galileo.luna import GalileoLunaClient + + captured: dict[str, object] = {} + + def handler(request: httpx.Request) -> httpx.Response: + captured["url"] = str(request.url) + captured["headers"] = dict(request.headers) + captured["body"] = json.loads(request.content.decode()) + return httpx.Response( + 200, + json={ + "scorer_label": "toxicity", + "score": 0.82, + "status": "success", + "execution_time": 0.12, + }, + ) + + # Given: a Luna client configured with the Galileo API internal secret + with patch.dict(os.environ, {"GALILEO_API_SECRET_KEY": "test-secret"}, clear=True): + client = GalileoLunaClient(api_url="https://api.default.svc.cluster.local:8088") + client._client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + + try: + # When: invoking a scorer with project context + response = await client.invoke( + metric="toxicity", + output="model answer", + project_id="12345678-1234-5678-1234-567812345678", + ) + finally: + await client.close() + + # Then: the internal scorer endpoint is called with a project-bound JWT + assert response.score == 0.82 + assert captured["url"] == "https://api.default.svc.cluster.local:8088/internal/scorers/invoke" + assert captured["body"] == { + "output": "model answer", + "scorer_label": "toxicity", + "project_id": "12345678-1234-5678-1234-567812345678", + } + headers = captured["headers"] + assert isinstance(headers, dict) + assert "galileo-api-key" not in headers + auth_header = headers["authorization"] + assert isinstance(auth_header, str) + assert auth_header.startswith("Bearer ") + token_payload = _decode_jwt_payload(auth_header.removeprefix("Bearer ")) + assert token_payload["internal"] is True + assert token_payload["project_id"] == "12345678-1234-5678-1234-567812345678" + assert token_payload["scope"] == "scorers.invoke" + + @pytest.mark.asyncio + async def test_client_requires_project_id_for_internal_jwt(self) -> None: + from agent_control_evaluator_galileo.luna import GalileoLunaClient + + # Given: a Luna client configured with internal JWT auth + with patch.dict(os.environ, {"GALILEO_API_SECRET_KEY": "test-secret"}, clear=True): + client = GalileoLunaClient(api_url="https://api.default.svc.cluster.local:8088") + + # When/Then: project_id is required because API uses it as the internal auth context + with pytest.raises(ValueError, match="project_id is required"): + await client.invoke(metric="toxicity", output="model answer") + class TestLunaEvaluator: """Tests for direct Luna evaluator behavior.""" @@ -156,12 +229,26 @@ def test_evaluator_metadata(self) -> None: assert LunaEvaluator.metadata.requires_api_key is True @patch.dict(os.environ, {}, clear=True) - def test_evaluator_init_without_api_key_raises(self) -> None: + def test_evaluator_init_without_auth_raises(self) -> None: from agent_control_evaluator_galileo.luna import LunaEvaluator - with pytest.raises(ValueError, match="GALILEO_API_KEY"): + with pytest.raises(ValueError, match="GALILEO_API_SECRET_KEY or GALILEO_API_KEY"): LunaEvaluator.from_dict({"metric": "toxicity", "threshold": 0.5}) + @patch.dict(os.environ, {"GALILEO_API_SECRET_KEY": "test-secret"}, clear=True) + def test_evaluator_init_accepts_api_secret(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluator + + evaluator = LunaEvaluator.from_dict( + { + "metric": "toxicity", + "project_id": "12345678-1234-5678-1234-567812345678", + "threshold": 0.5, + } + ) + + assert str(evaluator.config.project_id) == "12345678-1234-5678-1234-567812345678" + @patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}) @pytest.mark.asyncio async def test_evaluator_applies_threshold_locally_to_raw_score(self) -> None: From 0cce0bf806123843b50a72cec7ec0da6dd0c02be Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Tue, 12 May 2026 10:38:44 -0700 Subject: [PATCH 32/42] feat(galileo): support internal scorer auth --- .../luna/client.py | 93 +++++++++++++++--- .../luna/evaluator.py | 14 ++- .../galileo/tests/test_luna_evaluator.py | 95 ++++++++++++++++++- 3 files changed, 179 insertions(+), 23 deletions(-) diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py index 269d64fc..e75b74bf 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py @@ -4,7 +4,12 @@ import logging import os +from base64 import urlsafe_b64encode from dataclasses import dataclass, field +from hashlib import sha256 +from hmac import new as hmac_new +from json import dumps +from time import time from uuid import UUID import httpx @@ -13,6 +18,38 @@ logger = logging.getLogger(__name__) DEFAULT_TIMEOUT_SECS = 10.0 +DEFAULT_INTERNAL_TOKEN_TTL_SECS = 3600 +PUBLIC_SCORER_INVOKE_PATH = "/scorers/invoke" +INTERNAL_SCORER_INVOKE_PATH = "/internal/scorers/invoke" + + +def _b64url(data: bytes) -> str: + return urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +def _internal_auth_token( + api_secret: str, + project_id: str | UUID, + ttl_seconds: int = DEFAULT_INTERNAL_TOKEN_TTL_SECS, +) -> str: + """Create the internal JWT expected by Galileo API internal routes.""" + now = int(time()) + header = {"alg": "HS256", "typ": "JWT"} + payload = { + "internal": True, + "project_id": str(project_id), + "scope": "scorers.invoke", + "iat": now, + "exp": now + ttl_seconds, + } + signing_input = ".".join( + [ + _b64url(dumps(header, separators=(",", ":")).encode("utf-8")), + _b64url(dumps(payload, separators=(",", ":")).encode("utf-8")), + ] + ) + signature = hmac_new(api_secret.encode("utf-8"), signing_input.encode("ascii"), sha256).digest() + return f"{signing_input}.{_b64url(signature)}" def _as_float_or_none(value: JSONValue) -> float | None: @@ -33,7 +70,7 @@ class ScorerInvokeRequest: """Request payload for Galileo Luna scorer invocation. Attributes: - metric: Preset, registered, or fine-tuned scorer name. + metric: Preset, registered, or fine-tuned scorer label. input: Optional user/system prompt text. output: Optional model response text. luna_model: Optional Luna model override. @@ -50,7 +87,7 @@ class ScorerInvokeRequest: def to_dict(self) -> JSONObject: """Convert to the public API request shape.""" - body: JSONObject = {"metric": self.metric} + body: JSONObject = {"scorer_label": self.metric} if self.input is not None: body["input"] = self.input if self.output is not None: @@ -87,7 +124,7 @@ class ScorerInvokeResponse: @classmethod def from_dict(cls, data: JSONObject) -> ScorerInvokeResponse: """Create a response model from the API JSON object.""" - metric_value = data.get("metric", "") + metric_value = data.get("scorer_label", data.get("metric", "")) status_value = data.get("status", "unknown") error_value = data.get("error_message") @@ -105,13 +142,15 @@ class GalileoLunaClient: """Thin HTTP client for Galileo Luna direct scorer invocation. Environment Variables: - GALILEO_API_KEY: Galileo API key (required). + GALILEO_API_SECRET_KEY or GALILEO_API_SECRET: Galileo API internal JWT signing secret. + GALILEO_API_KEY: Galileo API key fallback for public scorer invocation. GALILEO_CONSOLE_URL: Galileo Console URL (optional, defaults to production). """ def __init__( self, api_key: str | None = None, + api_secret: str | None = None, console_url: str | None = None, api_url: str | None = None, ) -> None: @@ -119,22 +158,28 @@ def __init__( Args: api_key: Galileo API key. If not provided, reads from GALILEO_API_KEY. + api_secret: Galileo API secret for internal JWT auth. If not provided, + reads from GALILEO_API_SECRET_KEY or GALILEO_API_SECRET. console_url: Galileo Console URL. If not provided, reads from GALILEO_CONSOLE_URL or uses the production console URL. api_url: Galileo API URL. If not provided, reads from GALILEO_API_URL before deriving from the console URL. Raises: - ValueError: If no API key is provided or found in the environment. + ValueError: If neither API secret nor API key is provided. """ + resolved_api_secret = ( + api_secret or os.getenv("GALILEO_API_SECRET_KEY") or os.getenv("GALILEO_API_SECRET") + ) resolved_api_key = api_key or os.getenv("GALILEO_API_KEY") - if not resolved_api_key: + if not resolved_api_secret and not resolved_api_key: raise ValueError( - "GALILEO_API_KEY is required. " - "Set it as an environment variable or pass it to the constructor." + "GALILEO_API_SECRET_KEY or GALILEO_API_KEY is required. " + "Set one as an environment variable or pass it to the constructor." ) self.api_key = resolved_api_key + self.api_secret = resolved_api_secret self.console_url = ( console_url or os.getenv("GALILEO_CONSOLE_URL") or "https://console.galileo.ai" ) @@ -162,15 +207,34 @@ def _derive_api_url(self, console_url: str) -> str: async def _get_client(self) -> httpx.AsyncClient: """Get or create the HTTP client.""" if self._client is None or self._client.is_closed: + headers = {"Content-Type": "application/json"} + if self.api_secret is None and self.api_key is not None: + headers["Galileo-API-Key"] = self.api_key self._client = httpx.AsyncClient( - headers={ - "Galileo-API-Key": self.api_key, - "Content-Type": "application/json", - }, + headers=headers, timeout=httpx.Timeout(DEFAULT_TIMEOUT_SECS), ) return self._client + def _endpoint_and_headers( + self, + project_id: str | UUID | None, + headers: dict[str, str] | None, + ) -> tuple[str, dict[str, str]]: + request_headers = dict(headers or {}) + if self.api_secret is None: + return f"{self.api_base}{PUBLIC_SCORER_INVOKE_PATH}", request_headers + + if project_id is None: + raise ValueError( + "project_id is required when using GALILEO_API_SECRET_KEY internal auth." + ) + + request_headers["Authorization"] = ( + f"Bearer {_internal_auth_token(self.api_secret, project_id)}" + ) + return f"{self.api_base}{INTERNAL_SCORER_INVOKE_PATH}", request_headers + async def invoke( self, *, @@ -186,7 +250,7 @@ async def invoke( """Invoke a Galileo Luna scorer. Args: - metric: Preset, registered, or fine-tuned scorer name. + metric: Preset, registered, or fine-tuned scorer label. input: Optional user/system prompt text. output: Optional model response text. project_id: Optional Galileo project UUID for project-scoped scorer resolution. @@ -215,8 +279,7 @@ async def invoke( luna_model=luna_model, config=config, ).to_dict() - request_headers = dict(headers or {}) - endpoint = f"{self.api_base}/scorers/invoke" + endpoint, request_headers = self._endpoint_and_headers(project_id, headers) logger.debug("[GalileoLunaClient] POST %s", endpoint) logger.debug("[GalileoLunaClient] Request body: %s", request_body) diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py index 16a39930..f628cd8e 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py @@ -101,12 +101,18 @@ def __init__(self, config: LunaEvaluatorConfig) -> None: config: Validated LunaEvaluatorConfig instance. Raises: - ValueError: If GALILEO_API_KEY is not set. + ValueError: If neither GALILEO_API_SECRET_KEY nor GALILEO_API_KEY is set. """ - if not os.getenv("GALILEO_API_KEY"): + has_auth = ( + os.getenv("GALILEO_API_SECRET_KEY") + or os.getenv("GALILEO_API_SECRET") + or os.getenv("GALILEO_API_KEY") + ) + if not has_auth: raise ValueError( - "GALILEO_API_KEY environment variable must be set. " - "Set it to a Galileo API key before using galileo.luna." + "GALILEO_API_SECRET_KEY or GALILEO_API_KEY environment variable must be set. " + "Set an API secret for internal auth or a Galileo API key before using " + "galileo.luna." ) super().__init__(config) diff --git a/evaluators/contrib/galileo/tests/test_luna_evaluator.py b/evaluators/contrib/galileo/tests/test_luna_evaluator.py index 1b7e700e..53cf58ae 100644 --- a/evaluators/contrib/galileo/tests/test_luna_evaluator.py +++ b/evaluators/contrib/galileo/tests/test_luna_evaluator.py @@ -4,6 +4,7 @@ import json import os +from base64 import urlsafe_b64decode from unittest.mock import AsyncMock, patch import httpx @@ -12,6 +13,12 @@ from pydantic import ValidationError +def _decode_jwt_payload(token: str) -> dict[str, object]: + payload_segment = token.split(".")[1] + padded = payload_segment + ("=" * (-len(payload_segment) % 4)) + return json.loads(urlsafe_b64decode(padded.encode()).decode()) + + class TestLunaEvaluatorConfig: """Tests for direct Luna evaluator configuration.""" @@ -96,7 +103,7 @@ def handler(request: httpx.Request) -> httpx.Response: return httpx.Response( 200, json={ - "metric": "toxicity", + "scorer_label": "toxicity", "score": 0.82, "status": "success", "execution_time": 0.12, @@ -133,7 +140,7 @@ def handler(request: httpx.Request) -> httpx.Response: assert captured["body"] == { "input": "user prompt", "output": "model answer", - "metric": "toxicity", + "scorer_label": "toxicity", "project_id": "12345678-1234-5678-1234-567812345678", "luna_model": "luna-2", "config": {"top_k": 1}, @@ -144,6 +151,72 @@ def handler(request: httpx.Request) -> httpx.Response: assert isinstance(headers, dict) assert headers["galileo-api-key"] == "test-key" + @pytest.mark.asyncio + async def test_client_uses_internal_jwt_when_api_secret_is_set(self) -> None: + from agent_control_evaluator_galileo.luna import GalileoLunaClient + + captured: dict[str, object] = {} + + def handler(request: httpx.Request) -> httpx.Response: + captured["url"] = str(request.url) + captured["headers"] = dict(request.headers) + captured["body"] = json.loads(request.content.decode()) + return httpx.Response( + 200, + json={ + "scorer_label": "toxicity", + "score": 0.82, + "status": "success", + "execution_time": 0.12, + }, + ) + + # Given: a Luna client configured with the Galileo API internal secret + with patch.dict(os.environ, {"GALILEO_API_SECRET_KEY": "test-secret"}, clear=True): + client = GalileoLunaClient(api_url="https://api.default.svc.cluster.local:8088") + client._client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + + try: + # When: invoking a scorer with project context + response = await client.invoke( + metric="toxicity", + output="model answer", + project_id="12345678-1234-5678-1234-567812345678", + ) + finally: + await client.close() + + # Then: the internal scorer endpoint is called with a project-bound JWT + assert response.score == 0.82 + assert captured["url"] == "https://api.default.svc.cluster.local:8088/internal/scorers/invoke" + assert captured["body"] == { + "output": "model answer", + "scorer_label": "toxicity", + "project_id": "12345678-1234-5678-1234-567812345678", + } + headers = captured["headers"] + assert isinstance(headers, dict) + assert "galileo-api-key" not in headers + auth_header = headers["authorization"] + assert isinstance(auth_header, str) + assert auth_header.startswith("Bearer ") + token_payload = _decode_jwt_payload(auth_header.removeprefix("Bearer ")) + assert token_payload["internal"] is True + assert token_payload["project_id"] == "12345678-1234-5678-1234-567812345678" + assert token_payload["scope"] == "scorers.invoke" + + @pytest.mark.asyncio + async def test_client_requires_project_id_for_internal_jwt(self) -> None: + from agent_control_evaluator_galileo.luna import GalileoLunaClient + + # Given: a Luna client configured with internal JWT auth + with patch.dict(os.environ, {"GALILEO_API_SECRET_KEY": "test-secret"}, clear=True): + client = GalileoLunaClient(api_url="https://api.default.svc.cluster.local:8088") + + # When/Then: project_id is required because API uses it as the internal auth context + with pytest.raises(ValueError, match="project_id is required"): + await client.invoke(metric="toxicity", output="model answer") + class TestLunaEvaluator: """Tests for direct Luna evaluator behavior.""" @@ -156,12 +229,26 @@ def test_evaluator_metadata(self) -> None: assert LunaEvaluator.metadata.requires_api_key is True @patch.dict(os.environ, {}, clear=True) - def test_evaluator_init_without_api_key_raises(self) -> None: + def test_evaluator_init_without_auth_raises(self) -> None: from agent_control_evaluator_galileo.luna import LunaEvaluator - with pytest.raises(ValueError, match="GALILEO_API_KEY"): + with pytest.raises(ValueError, match="GALILEO_API_SECRET_KEY or GALILEO_API_KEY"): LunaEvaluator.from_dict({"metric": "toxicity", "threshold": 0.5}) + @patch.dict(os.environ, {"GALILEO_API_SECRET_KEY": "test-secret"}, clear=True) + def test_evaluator_init_accepts_api_secret(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluator + + evaluator = LunaEvaluator.from_dict( + { + "metric": "toxicity", + "project_id": "12345678-1234-5678-1234-567812345678", + "threshold": 0.5, + } + ) + + assert str(evaluator.config.project_id) == "12345678-1234-5678-1234-567812345678" + @patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}) @pytest.mark.asyncio async def test_evaluator_applies_threshold_locally_to_raw_score(self) -> None: From dd252be06b80c464b9c13929af166dd669cf235d Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Tue, 12 May 2026 10:49:41 -0700 Subject: [PATCH 33/42] add auth and update schema --- .../luna/client.py | 53 +++++++++---------- .../luna/config.py | 2 - .../luna/evaluator.py | 1 - .../galileo/tests/test_luna_evaluator.py | 35 +++++++++--- 4 files changed, 55 insertions(+), 36 deletions(-) diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py index e75b74bf..6786c5e8 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py @@ -10,10 +10,12 @@ from hmac import new as hmac_new from json import dumps from time import time +from typing import Literal from uuid import UUID import httpx from agent_control_models import JSONObject, JSONValue +from pydantic import BaseModel, Field, model_validator logger = logging.getLogger(__name__) @@ -65,40 +67,37 @@ def _as_float_or_none(value: JSONValue) -> float | None: return None -@dataclass(frozen=True) -class ScorerInvokeRequest: +ScorerStepType = Literal["session", "trace", "span"] + + +class ScorerInvokeRequest(BaseModel): """Request payload for Galileo Luna scorer invocation. Attributes: - metric: Preset, registered, or fine-tuned scorer label. + step_type: Runtime step shape used by Galileo scorer input normalization. input: Optional user/system prompt text. output: Optional model response text. - luna_model: Optional Luna model override. + scorer_label: Preset, registered, or fine-tuned scorer label. project_id: Optional Galileo project UUID for project-scoped scorer resolution. config: Optional scorer-specific configuration. """ - metric: str - input: str | None = None - output: str | None = None + step_type: ScorerStepType = Field(default="span") + input: JSONValue = None + output: JSONValue = None + scorer_label: str = Field(min_length=1) project_id: str | UUID | None = None - luna_model: str | None = None config: JSONObject | None = None + @model_validator(mode="after") + def ensure_input_or_output(self) -> ScorerInvokeRequest: + if self.input is None and self.output is None: + raise ValueError("Either input or output must be set.") + return self + def to_dict(self) -> JSONObject: - """Convert to the public API request shape.""" - body: JSONObject = {"scorer_label": self.metric} - if self.input is not None: - body["input"] = self.input - if self.output is not None: - body["output"] = self.output - if self.project_id is not None: - body["project_id"] = str(self.project_id) - if self.luna_model is not None: - body["luna_model"] = self.luna_model - if self.config is not None: - body["config"] = self.config - return body + """Convert to the Galileo scorer invoke API request shape.""" + return self.model_dump(mode="json", exclude_none=True) @dataclass @@ -239,10 +238,10 @@ async def invoke( self, *, metric: str, - input: str | None = None, - output: str | None = None, + input: JSONValue = None, + output: JSONValue = None, + step_type: ScorerStepType = "span", project_id: str | UUID | None = None, - luna_model: str | None = None, config: JSONObject | None = None, timeout: float = DEFAULT_TIMEOUT_SECS, headers: dict[str, str] | None = None, @@ -253,8 +252,8 @@ async def invoke( metric: Preset, registered, or fine-tuned scorer label. input: Optional user/system prompt text. output: Optional model response text. + step_type: Runtime step shape used by Galileo scorer input normalization. project_id: Optional Galileo project UUID for project-scoped scorer resolution. - luna_model: Optional Luna model override. config: Optional scorer-specific configuration. timeout: Request timeout in seconds. headers: Additional request headers. @@ -272,11 +271,11 @@ async def invoke( raise ValueError("At least one of input or output must be provided.") request_body = ScorerInvokeRequest( - metric=metric, + scorer_label=metric, input=input, output=output, + step_type=step_type, project_id=project_id, - luna_model=luna_model, config=config, ).to_dict() endpoint, request_headers = self._endpoint_and_headers(project_id, headers) diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py index 241e040f..3bcc34a3 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py @@ -36,7 +36,6 @@ class LunaEvaluatorConfig(EvaluatorConfig): project_id: Optional Galileo project UUID for project-scoped scorer resolution. threshold: Local threshold used by the evaluator for comparison. operator: Local comparison operator. Numeric operators use threshold as a number. - luna_model: Optional Luna model override sent to Galileo. scorer_config: Optional scorer-specific config sent as ``config``. timeout_ms: Request timeout in milliseconds. on_error: Error policy: allow=fail open, deny=fail closed. @@ -58,7 +57,6 @@ class LunaEvaluatorConfig(EvaluatorConfig): default="gte", description="Local comparison operator applied to the raw Luna score.", ) - luna_model: str | None = Field(default=None, description="Optional Luna model override") scorer_config: JSONObject | None = Field( default=None, alias="config", diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py index f628cd8e..8afea45d 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py @@ -199,7 +199,6 @@ async def evaluate(self, data: Any) -> EvaluatorResult: input=input_text if _has_text(input_text) else None, output=output_text if _has_text(output_text) else None, project_id=self.config.project_id, - luna_model=self.config.luna_model, config=self.config.scorer_config, timeout=self.get_timeout_seconds(), ) diff --git a/evaluators/contrib/galileo/tests/test_luna_evaluator.py b/evaluators/contrib/galileo/tests/test_luna_evaluator.py index 53cf58ae..58bd201b 100644 --- a/evaluators/contrib/galileo/tests/test_luna_evaluator.py +++ b/evaluators/contrib/galileo/tests/test_luna_evaluator.py @@ -31,7 +31,6 @@ def test_config_accepts_direct_scorer_fields(self) -> None: project_id="12345678-1234-5678-1234-567812345678", threshold=0.7, operator="gte", - luna_model="luna-2", config={"temperature": 0}, ) @@ -40,7 +39,6 @@ def test_config_accepts_direct_scorer_fields(self) -> None: assert str(config.project_id) == "12345678-1234-5678-1234-567812345678" assert config.threshold == 0.7 assert config.operator == "gte" - assert config.luna_model == "luna-2" assert config.scorer_config == {"temperature": 0} def test_numeric_operator_requires_numeric_threshold(self) -> None: @@ -54,6 +52,33 @@ def test_numeric_operator_requires_numeric_threshold(self) -> None: class TestGalileoLunaClient: """Tests for the GalileoLunaClient HTTP contract.""" + def test_scorer_invoke_request_matches_orbit_schema_shape(self) -> None: + from agent_control_evaluator_galileo.luna import ScorerInvokeRequest + + # Given: a scorer request with project context and scorer config + request = ScorerInvokeRequest( + scorer_label="toxicity", + input={"messages": [{"role": "user", "content": "hello"}]}, + project_id="12345678-1234-5678-1234-567812345678", + config={"top_k": 1}, + ) + + # Then: the serialized payload uses the Orbit scorer invoke fields + assert request.to_dict() == { + "step_type": "span", + "input": {"messages": [{"role": "user", "content": "hello"}]}, + "scorer_label": "toxicity", + "project_id": "12345678-1234-5678-1234-567812345678", + "config": {"top_k": 1}, + } + + def test_scorer_invoke_request_requires_input_or_output(self) -> None: + from agent_control_evaluator_galileo.luna import ScorerInvokeRequest + + # Given/When/Then: the request mirrors Orbit validation + with pytest.raises(ValidationError, match="Either input or output must be set"): + ScorerInvokeRequest(scorer_label="toxicity") + def test_client_uses_protect_api_url_derivation(self) -> None: from agent_control_evaluator_galileo.luna import GalileoLunaClient @@ -128,7 +153,6 @@ def handler(request: httpx.Request) -> httpx.Response: input="user prompt", output="model answer", project_id="12345678-1234-5678-1234-567812345678", - luna_model="luna-2", config={"top_k": 1}, ) finally: @@ -142,7 +166,7 @@ def handler(request: httpx.Request) -> httpx.Response: "output": "model answer", "scorer_label": "toxicity", "project_id": "12345678-1234-5678-1234-567812345678", - "luna_model": "luna-2", + "step_type": "span", "config": {"top_k": 1}, } assert "stage_name" not in captured["body"] @@ -193,6 +217,7 @@ def handler(request: httpx.Request) -> httpx.Response: "output": "model answer", "scorer_label": "toxicity", "project_id": "12345678-1234-5678-1234-567812345678", + "step_type": "span", } headers = captured["headers"] assert isinstance(headers, dict) @@ -301,7 +326,6 @@ async def test_evaluator_applies_threshold_locally_to_raw_score(self) -> None: input="user prompt", output="model answer", project_id=evaluator.config.project_id, - luna_model=None, config=None, timeout=5.0, ) @@ -335,7 +359,6 @@ async def test_evaluator_returns_non_match_below_threshold(self) -> None: input="hello", output=None, project_id=None, - luna_model=None, config=None, timeout=10.0, ) From 37efd66721c097b95c6ec27eb330442a06a2b7bd Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Tue, 12 May 2026 10:49:41 -0700 Subject: [PATCH 34/42] add auth and update schema --- .../luna/client.py | 53 +++++++++---------- .../luna/config.py | 2 - .../luna/evaluator.py | 1 - .../galileo/tests/test_luna_evaluator.py | 35 +++++++++--- 4 files changed, 55 insertions(+), 36 deletions(-) diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py index e75b74bf..6786c5e8 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py @@ -10,10 +10,12 @@ from hmac import new as hmac_new from json import dumps from time import time +from typing import Literal from uuid import UUID import httpx from agent_control_models import JSONObject, JSONValue +from pydantic import BaseModel, Field, model_validator logger = logging.getLogger(__name__) @@ -65,40 +67,37 @@ def _as_float_or_none(value: JSONValue) -> float | None: return None -@dataclass(frozen=True) -class ScorerInvokeRequest: +ScorerStepType = Literal["session", "trace", "span"] + + +class ScorerInvokeRequest(BaseModel): """Request payload for Galileo Luna scorer invocation. Attributes: - metric: Preset, registered, or fine-tuned scorer label. + step_type: Runtime step shape used by Galileo scorer input normalization. input: Optional user/system prompt text. output: Optional model response text. - luna_model: Optional Luna model override. + scorer_label: Preset, registered, or fine-tuned scorer label. project_id: Optional Galileo project UUID for project-scoped scorer resolution. config: Optional scorer-specific configuration. """ - metric: str - input: str | None = None - output: str | None = None + step_type: ScorerStepType = Field(default="span") + input: JSONValue = None + output: JSONValue = None + scorer_label: str = Field(min_length=1) project_id: str | UUID | None = None - luna_model: str | None = None config: JSONObject | None = None + @model_validator(mode="after") + def ensure_input_or_output(self) -> ScorerInvokeRequest: + if self.input is None and self.output is None: + raise ValueError("Either input or output must be set.") + return self + def to_dict(self) -> JSONObject: - """Convert to the public API request shape.""" - body: JSONObject = {"scorer_label": self.metric} - if self.input is not None: - body["input"] = self.input - if self.output is not None: - body["output"] = self.output - if self.project_id is not None: - body["project_id"] = str(self.project_id) - if self.luna_model is not None: - body["luna_model"] = self.luna_model - if self.config is not None: - body["config"] = self.config - return body + """Convert to the Galileo scorer invoke API request shape.""" + return self.model_dump(mode="json", exclude_none=True) @dataclass @@ -239,10 +238,10 @@ async def invoke( self, *, metric: str, - input: str | None = None, - output: str | None = None, + input: JSONValue = None, + output: JSONValue = None, + step_type: ScorerStepType = "span", project_id: str | UUID | None = None, - luna_model: str | None = None, config: JSONObject | None = None, timeout: float = DEFAULT_TIMEOUT_SECS, headers: dict[str, str] | None = None, @@ -253,8 +252,8 @@ async def invoke( metric: Preset, registered, or fine-tuned scorer label. input: Optional user/system prompt text. output: Optional model response text. + step_type: Runtime step shape used by Galileo scorer input normalization. project_id: Optional Galileo project UUID for project-scoped scorer resolution. - luna_model: Optional Luna model override. config: Optional scorer-specific configuration. timeout: Request timeout in seconds. headers: Additional request headers. @@ -272,11 +271,11 @@ async def invoke( raise ValueError("At least one of input or output must be provided.") request_body = ScorerInvokeRequest( - metric=metric, + scorer_label=metric, input=input, output=output, + step_type=step_type, project_id=project_id, - luna_model=luna_model, config=config, ).to_dict() endpoint, request_headers = self._endpoint_and_headers(project_id, headers) diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py index 241e040f..3bcc34a3 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py @@ -36,7 +36,6 @@ class LunaEvaluatorConfig(EvaluatorConfig): project_id: Optional Galileo project UUID for project-scoped scorer resolution. threshold: Local threshold used by the evaluator for comparison. operator: Local comparison operator. Numeric operators use threshold as a number. - luna_model: Optional Luna model override sent to Galileo. scorer_config: Optional scorer-specific config sent as ``config``. timeout_ms: Request timeout in milliseconds. on_error: Error policy: allow=fail open, deny=fail closed. @@ -58,7 +57,6 @@ class LunaEvaluatorConfig(EvaluatorConfig): default="gte", description="Local comparison operator applied to the raw Luna score.", ) - luna_model: str | None = Field(default=None, description="Optional Luna model override") scorer_config: JSONObject | None = Field( default=None, alias="config", diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py index f628cd8e..8afea45d 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py @@ -199,7 +199,6 @@ async def evaluate(self, data: Any) -> EvaluatorResult: input=input_text if _has_text(input_text) else None, output=output_text if _has_text(output_text) else None, project_id=self.config.project_id, - luna_model=self.config.luna_model, config=self.config.scorer_config, timeout=self.get_timeout_seconds(), ) diff --git a/evaluators/contrib/galileo/tests/test_luna_evaluator.py b/evaluators/contrib/galileo/tests/test_luna_evaluator.py index 53cf58ae..58bd201b 100644 --- a/evaluators/contrib/galileo/tests/test_luna_evaluator.py +++ b/evaluators/contrib/galileo/tests/test_luna_evaluator.py @@ -31,7 +31,6 @@ def test_config_accepts_direct_scorer_fields(self) -> None: project_id="12345678-1234-5678-1234-567812345678", threshold=0.7, operator="gte", - luna_model="luna-2", config={"temperature": 0}, ) @@ -40,7 +39,6 @@ def test_config_accepts_direct_scorer_fields(self) -> None: assert str(config.project_id) == "12345678-1234-5678-1234-567812345678" assert config.threshold == 0.7 assert config.operator == "gte" - assert config.luna_model == "luna-2" assert config.scorer_config == {"temperature": 0} def test_numeric_operator_requires_numeric_threshold(self) -> None: @@ -54,6 +52,33 @@ def test_numeric_operator_requires_numeric_threshold(self) -> None: class TestGalileoLunaClient: """Tests for the GalileoLunaClient HTTP contract.""" + def test_scorer_invoke_request_matches_orbit_schema_shape(self) -> None: + from agent_control_evaluator_galileo.luna import ScorerInvokeRequest + + # Given: a scorer request with project context and scorer config + request = ScorerInvokeRequest( + scorer_label="toxicity", + input={"messages": [{"role": "user", "content": "hello"}]}, + project_id="12345678-1234-5678-1234-567812345678", + config={"top_k": 1}, + ) + + # Then: the serialized payload uses the Orbit scorer invoke fields + assert request.to_dict() == { + "step_type": "span", + "input": {"messages": [{"role": "user", "content": "hello"}]}, + "scorer_label": "toxicity", + "project_id": "12345678-1234-5678-1234-567812345678", + "config": {"top_k": 1}, + } + + def test_scorer_invoke_request_requires_input_or_output(self) -> None: + from agent_control_evaluator_galileo.luna import ScorerInvokeRequest + + # Given/When/Then: the request mirrors Orbit validation + with pytest.raises(ValidationError, match="Either input or output must be set"): + ScorerInvokeRequest(scorer_label="toxicity") + def test_client_uses_protect_api_url_derivation(self) -> None: from agent_control_evaluator_galileo.luna import GalileoLunaClient @@ -128,7 +153,6 @@ def handler(request: httpx.Request) -> httpx.Response: input="user prompt", output="model answer", project_id="12345678-1234-5678-1234-567812345678", - luna_model="luna-2", config={"top_k": 1}, ) finally: @@ -142,7 +166,7 @@ def handler(request: httpx.Request) -> httpx.Response: "output": "model answer", "scorer_label": "toxicity", "project_id": "12345678-1234-5678-1234-567812345678", - "luna_model": "luna-2", + "step_type": "span", "config": {"top_k": 1}, } assert "stage_name" not in captured["body"] @@ -193,6 +217,7 @@ def handler(request: httpx.Request) -> httpx.Response: "output": "model answer", "scorer_label": "toxicity", "project_id": "12345678-1234-5678-1234-567812345678", + "step_type": "span", } headers = captured["headers"] assert isinstance(headers, dict) @@ -301,7 +326,6 @@ async def test_evaluator_applies_threshold_locally_to_raw_score(self) -> None: input="user prompt", output="model answer", project_id=evaluator.config.project_id, - luna_model=None, config=None, timeout=5.0, ) @@ -335,7 +359,6 @@ async def test_evaluator_returns_non_match_below_threshold(self) -> None: input="hello", output=None, project_id=None, - luna_model=None, config=None, timeout=10.0, ) From 74fcbeb4ce6fd91d3c861daf2b60f6d9e1ffe297 Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Tue, 12 May 2026 11:11:57 -0700 Subject: [PATCH 35/42] fix(galileo): align luna scorer response schema --- .../luna/client.py | 44 +++++++++++-------- .../luna/evaluator.py | 2 +- .../galileo/tests/test_luna_evaluator.py | 42 +++++++++++++++++- 3 files changed, 66 insertions(+), 22 deletions(-) diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py index 6786c5e8..effc132a 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py @@ -5,7 +5,6 @@ import logging import os from base64 import urlsafe_b64encode -from dataclasses import dataclass, field from hashlib import sha256 from hmac import new as hmac_new from json import dumps @@ -15,7 +14,7 @@ import httpx from agent_control_models import JSONObject, JSONValue -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, PrivateAttr, model_validator logger = logging.getLogger(__name__) @@ -100,41 +99,48 @@ def to_dict(self) -> JSONObject: return self.model_dump(mode="json", exclude_none=True) -@dataclass -class ScorerInvokeResponse: +class ScorerInvokeResponse(BaseModel): """Response from Galileo Luna scorer invocation. Attributes: - metric: Echoed scorer metric. + scorer_label: Echoed scorer label. score: Raw scorer value. status: Invocation status. execution_time: Execution time in seconds, when returned. error_message: Error detail for non-success statuses. - raw_response: Full response body for diagnostics. """ - metric: str + scorer_label: str score: JSONValue status: str = "unknown" execution_time: float | None = None error_message: str | None = None - raw_response: JSONObject = field(default_factory=dict) + _raw_response: JSONObject = PrivateAttr(default_factory=dict) + + @model_validator(mode="before") + @classmethod + def allow_legacy_metric_response(cls, data: object) -> object: + if isinstance(data, dict) and "scorer_label" not in data and "metric" in data: + return data | {"scorer_label": data["metric"]} + return data + + @property + def metric(self) -> str: + """Backward-compatible alias for existing evaluator metadata code.""" + return self.scorer_label + + @property + def raw_response(self) -> JSONObject: + return self._raw_response @classmethod def from_dict(cls, data: JSONObject) -> ScorerInvokeResponse: """Create a response model from the API JSON object.""" - metric_value = data.get("scorer_label", data.get("metric", "")) - status_value = data.get("status", "unknown") - error_value = data.get("error_message") - - return cls( - metric=str(metric_value) if metric_value is not None else "", - score=data.get("score"), - status=str(status_value) if status_value is not None else "unknown", - execution_time=_as_float_or_none(data.get("execution_time")), - error_message=str(error_value) if error_value is not None else None, - raw_response=data, + response = cls.model_validate( + data | {"execution_time": _as_float_or_none(data.get("execution_time"))} ) + response._raw_response = data + return response class GalileoLunaClient: diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py index 8afea45d..9db2f60d 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py @@ -227,7 +227,7 @@ async def evaluate(self, data: Any) -> EvaluatorResult: def _metadata(self, response: ScorerInvokeResponse) -> dict[str, Any]: metadata: dict[str, Any] = { - "metric": response.metric or self.config.metric, + "metric": response.scorer_label or self.config.metric, "project_id": str(self.config.project_id) if self.config.project_id else None, "score": response.score, "threshold": self.config.threshold, diff --git a/evaluators/contrib/galileo/tests/test_luna_evaluator.py b/evaluators/contrib/galileo/tests/test_luna_evaluator.py index 58bd201b..de9da5af 100644 --- a/evaluators/contrib/galileo/tests/test_luna_evaluator.py +++ b/evaluators/contrib/galileo/tests/test_luna_evaluator.py @@ -79,6 +79,44 @@ def test_scorer_invoke_request_requires_input_or_output(self) -> None: with pytest.raises(ValidationError, match="Either input or output must be set"): ScorerInvokeRequest(scorer_label="toxicity") + def test_scorer_invoke_response_matches_orbit_schema_shape(self) -> None: + from agent_control_evaluator_galileo.luna import ScorerInvokeResponse + + # Given: an API scorer invoke response + response = ScorerInvokeResponse.from_dict( + { + "scorer_label": "toxicity", + "score": 0.82, + "status": "success", + "execution_time": 0.12, + "error_message": None, + } + ) + + # Then: the model exposes the Orbit/API response fields + assert response.model_dump() == { + "scorer_label": "toxicity", + "score": 0.82, + "status": "success", + "execution_time": 0.12, + "error_message": None, + } + assert response.scorer_label == "toxicity" + assert response.metric == "toxicity" + assert response.raw_response["scorer_label"] == "toxicity" + + def test_scorer_invoke_response_accepts_legacy_metric_field(self) -> None: + from agent_control_evaluator_galileo.luna import ScorerInvokeResponse + + # Given/When: an older API response uses metric instead of scorer_label + response = ScorerInvokeResponse.from_dict( + {"metric": "toxicity", "score": 0.82, "status": "success"} + ) + + # Then: the client still normalizes it to the current response contract + assert response.scorer_label == "toxicity" + assert response.model_dump()["scorer_label"] == "toxicity" + def test_client_uses_protect_api_url_derivation(self) -> None: from agent_control_evaluator_galileo.luna import GalileoLunaClient @@ -293,7 +331,7 @@ async def test_evaluator_applies_threshold_locally_to_raw_score(self) -> None: with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: mock_invoke.return_value = ScorerInvokeResponse( - metric="toxicity", + scorer_label="toxicity", score=0.82, status="success", execution_time=0.1, @@ -343,7 +381,7 @@ async def test_evaluator_returns_non_match_below_threshold(self) -> None: with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: mock_invoke.return_value = ScorerInvokeResponse( - metric="toxicity", + scorer_label="toxicity", score=0.2, status="success", ) From 559132195cebeb260663688554c263d9abb715a1 Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Tue, 12 May 2026 11:11:57 -0700 Subject: [PATCH 36/42] fix(galileo): align luna scorer response schema --- .../luna/client.py | 44 +++++++++++-------- .../luna/evaluator.py | 2 +- .../galileo/tests/test_luna_evaluator.py | 42 +++++++++++++++++- 3 files changed, 66 insertions(+), 22 deletions(-) diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py index 6786c5e8..effc132a 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py @@ -5,7 +5,6 @@ import logging import os from base64 import urlsafe_b64encode -from dataclasses import dataclass, field from hashlib import sha256 from hmac import new as hmac_new from json import dumps @@ -15,7 +14,7 @@ import httpx from agent_control_models import JSONObject, JSONValue -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, PrivateAttr, model_validator logger = logging.getLogger(__name__) @@ -100,41 +99,48 @@ def to_dict(self) -> JSONObject: return self.model_dump(mode="json", exclude_none=True) -@dataclass -class ScorerInvokeResponse: +class ScorerInvokeResponse(BaseModel): """Response from Galileo Luna scorer invocation. Attributes: - metric: Echoed scorer metric. + scorer_label: Echoed scorer label. score: Raw scorer value. status: Invocation status. execution_time: Execution time in seconds, when returned. error_message: Error detail for non-success statuses. - raw_response: Full response body for diagnostics. """ - metric: str + scorer_label: str score: JSONValue status: str = "unknown" execution_time: float | None = None error_message: str | None = None - raw_response: JSONObject = field(default_factory=dict) + _raw_response: JSONObject = PrivateAttr(default_factory=dict) + + @model_validator(mode="before") + @classmethod + def allow_legacy_metric_response(cls, data: object) -> object: + if isinstance(data, dict) and "scorer_label" not in data and "metric" in data: + return data | {"scorer_label": data["metric"]} + return data + + @property + def metric(self) -> str: + """Backward-compatible alias for existing evaluator metadata code.""" + return self.scorer_label + + @property + def raw_response(self) -> JSONObject: + return self._raw_response @classmethod def from_dict(cls, data: JSONObject) -> ScorerInvokeResponse: """Create a response model from the API JSON object.""" - metric_value = data.get("scorer_label", data.get("metric", "")) - status_value = data.get("status", "unknown") - error_value = data.get("error_message") - - return cls( - metric=str(metric_value) if metric_value is not None else "", - score=data.get("score"), - status=str(status_value) if status_value is not None else "unknown", - execution_time=_as_float_or_none(data.get("execution_time")), - error_message=str(error_value) if error_value is not None else None, - raw_response=data, + response = cls.model_validate( + data | {"execution_time": _as_float_or_none(data.get("execution_time"))} ) + response._raw_response = data + return response class GalileoLunaClient: diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py index 8afea45d..9db2f60d 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py @@ -227,7 +227,7 @@ async def evaluate(self, data: Any) -> EvaluatorResult: def _metadata(self, response: ScorerInvokeResponse) -> dict[str, Any]: metadata: dict[str, Any] = { - "metric": response.metric or self.config.metric, + "metric": response.scorer_label or self.config.metric, "project_id": str(self.config.project_id) if self.config.project_id else None, "score": response.score, "threshold": self.config.threshold, diff --git a/evaluators/contrib/galileo/tests/test_luna_evaluator.py b/evaluators/contrib/galileo/tests/test_luna_evaluator.py index 58bd201b..de9da5af 100644 --- a/evaluators/contrib/galileo/tests/test_luna_evaluator.py +++ b/evaluators/contrib/galileo/tests/test_luna_evaluator.py @@ -79,6 +79,44 @@ def test_scorer_invoke_request_requires_input_or_output(self) -> None: with pytest.raises(ValidationError, match="Either input or output must be set"): ScorerInvokeRequest(scorer_label="toxicity") + def test_scorer_invoke_response_matches_orbit_schema_shape(self) -> None: + from agent_control_evaluator_galileo.luna import ScorerInvokeResponse + + # Given: an API scorer invoke response + response = ScorerInvokeResponse.from_dict( + { + "scorer_label": "toxicity", + "score": 0.82, + "status": "success", + "execution_time": 0.12, + "error_message": None, + } + ) + + # Then: the model exposes the Orbit/API response fields + assert response.model_dump() == { + "scorer_label": "toxicity", + "score": 0.82, + "status": "success", + "execution_time": 0.12, + "error_message": None, + } + assert response.scorer_label == "toxicity" + assert response.metric == "toxicity" + assert response.raw_response["scorer_label"] == "toxicity" + + def test_scorer_invoke_response_accepts_legacy_metric_field(self) -> None: + from agent_control_evaluator_galileo.luna import ScorerInvokeResponse + + # Given/When: an older API response uses metric instead of scorer_label + response = ScorerInvokeResponse.from_dict( + {"metric": "toxicity", "score": 0.82, "status": "success"} + ) + + # Then: the client still normalizes it to the current response contract + assert response.scorer_label == "toxicity" + assert response.model_dump()["scorer_label"] == "toxicity" + def test_client_uses_protect_api_url_derivation(self) -> None: from agent_control_evaluator_galileo.luna import GalileoLunaClient @@ -293,7 +331,7 @@ async def test_evaluator_applies_threshold_locally_to_raw_score(self) -> None: with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: mock_invoke.return_value = ScorerInvokeResponse( - metric="toxicity", + scorer_label="toxicity", score=0.82, status="success", execution_time=0.1, @@ -343,7 +381,7 @@ async def test_evaluator_returns_non_match_below_threshold(self) -> None: with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: mock_invoke.return_value = ScorerInvokeResponse( - metric="toxicity", + scorer_label="toxicity", score=0.2, status="success", ) From 7b0a15d2b6d8b8a98a38d311c4818016e92ae394 Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Wed, 13 May 2026 12:01:56 -0700 Subject: [PATCH 37/42] update the schemas and corresponding tests --- .../luna/client.py | 30 ++++------- .../luna/config.py | 6 +-- .../luna/evaluator.py | 10 ++-- .../galileo/tests/test_luna_evaluator.py | 51 +++++++------------ examples/galileo_luna/README.md | 2 +- examples/galileo_luna/setup_controls.py | 6 +-- 6 files changed, 40 insertions(+), 65 deletions(-) diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py index effc132a..426b1782 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py @@ -66,14 +66,14 @@ def _as_float_or_none(value: JSONValue) -> float | None: return None -ScorerStepType = Literal["session", "trace", "span"] +RootType = Literal["session", "trace", "span"] class ScorerInvokeRequest(BaseModel): """Request payload for Galileo Luna scorer invocation. Attributes: - step_type: Runtime step shape used by Galileo scorer input normalization. + root_type: Runtime step shape used by Galileo scorer input normalization. input: Optional user/system prompt text. output: Optional model response text. scorer_label: Preset, registered, or fine-tuned scorer label. @@ -81,7 +81,7 @@ class ScorerInvokeRequest(BaseModel): config: Optional scorer-specific configuration. """ - step_type: ScorerStepType = Field(default="span") + root_type: RootType = Field(default="span") input: JSONValue = None output: JSONValue = None scorer_label: str = Field(min_length=1) @@ -117,18 +117,6 @@ class ScorerInvokeResponse(BaseModel): error_message: str | None = None _raw_response: JSONObject = PrivateAttr(default_factory=dict) - @model_validator(mode="before") - @classmethod - def allow_legacy_metric_response(cls, data: object) -> object: - if isinstance(data, dict) and "scorer_label" not in data and "metric" in data: - return data | {"scorer_label": data["metric"]} - return data - - @property - def metric(self) -> str: - """Backward-compatible alias for existing evaluator metadata code.""" - return self.scorer_label - @property def raw_response(self) -> JSONObject: return self._raw_response @@ -243,10 +231,10 @@ def _endpoint_and_headers( async def invoke( self, *, - metric: str, + scorer_label: str, input: JSONValue = None, output: JSONValue = None, - step_type: ScorerStepType = "span", + root_type: RootType = "span", project_id: str | UUID | None = None, config: JSONObject | None = None, timeout: float = DEFAULT_TIMEOUT_SECS, @@ -255,10 +243,10 @@ async def invoke( """Invoke a Galileo Luna scorer. Args: - metric: Preset, registered, or fine-tuned scorer label. + scorer_label: Preset, registered, or fine-tuned scorer label. input: Optional user/system prompt text. output: Optional model response text. - step_type: Runtime step shape used by Galileo scorer input normalization. + root_type: Runtime step shape used by Galileo scorer input normalization. project_id: Optional Galileo project UUID for project-scoped scorer resolution. config: Optional scorer-specific configuration. timeout: Request timeout in seconds. @@ -277,10 +265,10 @@ async def invoke( raise ValueError("At least one of input or output must be provided.") request_body = ScorerInvokeRequest( - scorer_label=metric, + scorer_label=scorer_label, input=input, output=output, - step_type=step_type, + root_type=root_type, project_id=project_id, config=config, ).to_dict() diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py index 3bcc34a3..1e41a554 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py @@ -32,7 +32,7 @@ class LunaEvaluatorConfig(EvaluatorConfig): """Configuration for direct Luna scorer evaluation. Attributes: - metric: Preset, registered, or fine-tuned scorer name. + scorer_label: Preset, registered, or fine-tuned scorer label. project_id: Optional Galileo project UUID for project-scoped scorer resolution. threshold: Local threshold used by the evaluator for comparison. operator: Local comparison operator. Numeric operators use threshold as a number. @@ -40,11 +40,11 @@ class LunaEvaluatorConfig(EvaluatorConfig): timeout_ms: Request timeout in milliseconds. on_error: Error policy: allow=fail open, deny=fail closed. payload_field: Force selected data into input or output. If omitted, root step - payloads with input/output use both fields; scalar data is inferred from metric name. + payloads with input/output use both fields; scalar data is inferred from scorer label. include_raw_response: Include the raw API response in EvaluatorResult metadata. """ - metric: str = Field(..., min_length=1, description="Luna metric/scorer name to evaluate") + scorer_label: str = Field(..., min_length=1, description="Luna scorer label to invoke") project_id: UUID | None = Field( default=None, description="Optional Galileo project UUID for project-scoped scorer resolution.", diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py index 9db2f60d..a5b3f248 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py @@ -139,7 +139,7 @@ def _prepare_payload(self, data: Any) -> tuple[str | None, str | None]: return input_text, output_text text = _coerce_payload_text(data) - if "output" in self.config.metric: + if "output" in self.config.scorer_label: return None, text return text, None @@ -190,12 +190,12 @@ async def evaluate(self, data: Any) -> EvaluatorResult: matched=False, confidence=1.0, message="No data to score with Luna", - metadata={"metric": self.config.metric}, + metadata={"scorer_label": self.config.scorer_label}, ) try: response = await self._get_client().invoke( - metric=self.config.metric, + scorer_label=self.config.scorer_label, input=input_text if _has_text(input_text) else None, output=output_text if _has_text(output_text) else None, project_id=self.config.project_id, @@ -227,7 +227,7 @@ async def evaluate(self, data: Any) -> EvaluatorResult: def _metadata(self, response: ScorerInvokeResponse) -> dict[str, Any]: metadata: dict[str, Any] = { - "metric": response.scorer_label or self.config.metric, + "scorer_label": response.scorer_label or self.config.scorer_label, "project_id": str(self.config.project_id) if self.config.project_id else None, "score": response.score, "threshold": self.config.threshold, @@ -251,7 +251,7 @@ def _handle_error(self, error: Exception) -> EvaluatorResult: metadata={ "error": error_detail, "error_type": type(error).__name__, - "metric": self.config.metric, + "scorer_label": self.config.scorer_label, "fallback_action": fallback, }, error=None if matched else error_detail, diff --git a/evaluators/contrib/galileo/tests/test_luna_evaluator.py b/evaluators/contrib/galileo/tests/test_luna_evaluator.py index de9da5af..31323a42 100644 --- a/evaluators/contrib/galileo/tests/test_luna_evaluator.py +++ b/evaluators/contrib/galileo/tests/test_luna_evaluator.py @@ -27,7 +27,7 @@ def test_config_accepts_direct_scorer_fields(self) -> None: # Given: a direct scorer config with local thresholding config = LunaEvaluatorConfig( - metric="toxicity", + scorer_label="toxicity", project_id="12345678-1234-5678-1234-567812345678", threshold=0.7, operator="gte", @@ -35,7 +35,7 @@ def test_config_accepts_direct_scorer_fields(self) -> None: ) # Then: config is retained without Protect concepts - assert config.metric == "toxicity" + assert config.scorer_label == "toxicity" assert str(config.project_id) == "12345678-1234-5678-1234-567812345678" assert config.threshold == 0.7 assert config.operator == "gte" @@ -46,7 +46,7 @@ def test_numeric_operator_requires_numeric_threshold(self) -> None: # Given/When/Then: numeric local comparison rejects non-numeric thresholds with pytest.raises(ValidationError, match="numeric threshold"): - LunaEvaluatorConfig(metric="toxicity", threshold="high", operator="gte") + LunaEvaluatorConfig(scorer_label="toxicity", threshold="high", operator="gte") class TestGalileoLunaClient: @@ -65,7 +65,7 @@ def test_scorer_invoke_request_matches_orbit_schema_shape(self) -> None: # Then: the serialized payload uses the Orbit scorer invoke fields assert request.to_dict() == { - "step_type": "span", + "root_type": "span", "input": {"messages": [{"role": "user", "content": "hello"}]}, "scorer_label": "toxicity", "project_id": "12345678-1234-5678-1234-567812345678", @@ -102,21 +102,8 @@ def test_scorer_invoke_response_matches_orbit_schema_shape(self) -> None: "error_message": None, } assert response.scorer_label == "toxicity" - assert response.metric == "toxicity" assert response.raw_response["scorer_label"] == "toxicity" - def test_scorer_invoke_response_accepts_legacy_metric_field(self) -> None: - from agent_control_evaluator_galileo.luna import ScorerInvokeResponse - - # Given/When: an older API response uses metric instead of scorer_label - response = ScorerInvokeResponse.from_dict( - {"metric": "toxicity", "score": 0.82, "status": "success"} - ) - - # Then: the client still normalizes it to the current response contract - assert response.scorer_label == "toxicity" - assert response.model_dump()["scorer_label"] == "toxicity" - def test_client_uses_protect_api_url_derivation(self) -> None: from agent_control_evaluator_galileo.luna import GalileoLunaClient @@ -187,7 +174,7 @@ def handler(request: httpx.Request) -> httpx.Response: try: # When: invoking a scorer response = await client.invoke( - metric="toxicity", + scorer_label="toxicity", input="user prompt", output="model answer", project_id="12345678-1234-5678-1234-567812345678", @@ -204,7 +191,7 @@ def handler(request: httpx.Request) -> httpx.Response: "output": "model answer", "scorer_label": "toxicity", "project_id": "12345678-1234-5678-1234-567812345678", - "step_type": "span", + "root_type": "span", "config": {"top_k": 1}, } assert "stage_name" not in captured["body"] @@ -241,7 +228,7 @@ def handler(request: httpx.Request) -> httpx.Response: try: # When: invoking a scorer with project context response = await client.invoke( - metric="toxicity", + scorer_label="toxicity", output="model answer", project_id="12345678-1234-5678-1234-567812345678", ) @@ -255,7 +242,7 @@ def handler(request: httpx.Request) -> httpx.Response: "output": "model answer", "scorer_label": "toxicity", "project_id": "12345678-1234-5678-1234-567812345678", - "step_type": "span", + "root_type": "span", } headers = captured["headers"] assert isinstance(headers, dict) @@ -278,7 +265,7 @@ async def test_client_requires_project_id_for_internal_jwt(self) -> None: # When/Then: project_id is required because API uses it as the internal auth context with pytest.raises(ValueError, match="project_id is required"): - await client.invoke(metric="toxicity", output="model answer") + await client.invoke(scorer_label="toxicity", output="model answer") class TestLunaEvaluator: @@ -296,7 +283,7 @@ def test_evaluator_init_without_auth_raises(self) -> None: from agent_control_evaluator_galileo.luna import LunaEvaluator with pytest.raises(ValueError, match="GALILEO_API_SECRET_KEY or GALILEO_API_KEY"): - LunaEvaluator.from_dict({"metric": "toxicity", "threshold": 0.5}) + LunaEvaluator.from_dict({"scorer_label": "toxicity", "threshold": 0.5}) @patch.dict(os.environ, {"GALILEO_API_SECRET_KEY": "test-secret"}, clear=True) def test_evaluator_init_accepts_api_secret(self) -> None: @@ -304,7 +291,7 @@ def test_evaluator_init_accepts_api_secret(self) -> None: evaluator = LunaEvaluator.from_dict( { - "metric": "toxicity", + "scorer_label": "toxicity", "project_id": "12345678-1234-5678-1234-567812345678", "threshold": 0.5, } @@ -321,7 +308,7 @@ async def test_evaluator_applies_threshold_locally_to_raw_score(self) -> None: # Given: a direct Luna evaluator and a raw successful scorer response evaluator = LunaEvaluator.from_dict( { - "metric": "toxicity", + "scorer_label": "toxicity", "project_id": "12345678-1234-5678-1234-567812345678", "threshold": 0.7, "operator": "gte", @@ -350,7 +337,7 @@ async def test_evaluator_applies_threshold_locally_to_raw_score(self) -> None: assert result.matched is True assert result.confidence == 0.82 assert result.metadata == { - "metric": "toxicity", + "scorer_label": "toxicity", "project_id": "12345678-1234-5678-1234-567812345678", "score": 0.82, "threshold": 0.7, @@ -360,7 +347,7 @@ async def test_evaluator_applies_threshold_locally_to_raw_score(self) -> None: "error_message": None, } mock_invoke.assert_awaited_once_with( - metric="toxicity", + scorer_label="toxicity", input="user prompt", output="model answer", project_id=evaluator.config.project_id, @@ -376,7 +363,7 @@ async def test_evaluator_returns_non_match_below_threshold(self) -> None: # Given: a raw scorer value below the local threshold evaluator = LunaEvaluator.from_dict( - {"metric": "toxicity", "threshold": 0.7, "operator": "gte"} + {"scorer_label": "toxicity", "threshold": 0.7, "operator": "gte"} ) with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: @@ -393,7 +380,7 @@ async def test_evaluator_returns_non_match_below_threshold(self) -> None: assert result.matched is False assert result.confidence == 0.2 mock_invoke.assert_awaited_once_with( - metric="toxicity", + scorer_label="toxicity", input="hello", output=None, project_id=None, @@ -408,7 +395,7 @@ async def test_evaluator_does_not_call_api_for_empty_data(self) -> None: from agent_control_evaluator_galileo.luna.client import GalileoLunaClient # Given: an evaluator and empty selected data - evaluator = LunaEvaluator.from_dict({"metric": "toxicity", "threshold": 0.5}) + evaluator = LunaEvaluator.from_dict({"scorer_label": "toxicity", "threshold": 0.5}) with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: # When: evaluating empty data @@ -427,7 +414,7 @@ async def test_evaluator_fail_open_sets_error(self) -> None: from agent_control_evaluator_galileo.luna.client import GalileoLunaClient # Given: default fail-open behavior - evaluator = LunaEvaluator.from_dict({"metric": "toxicity", "threshold": 0.5}) + evaluator = LunaEvaluator.from_dict({"scorer_label": "toxicity", "threshold": 0.5}) with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: mock_invoke.side_effect = RuntimeError("service unavailable") @@ -449,7 +436,7 @@ async def test_evaluator_fail_closed_matches_without_error_field(self) -> None: # Given: fail-closed behavior for scorer errors evaluator = LunaEvaluator.from_dict( - {"metric": "toxicity", "threshold": 0.5, "on_error": "deny"} + {"scorer_label": "toxicity", "threshold": 0.5, "on_error": "deny"} ) with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: diff --git a/examples/galileo_luna/README.md b/examples/galileo_luna/README.md index d43a2d71..534ef640 100644 --- a/examples/galileo_luna/README.md +++ b/examples/galileo_luna/README.md @@ -33,7 +33,7 @@ export GALILEO_PROJECT_ID="00000000-0000-0000-0000-000000000000" Optional scorer settings: ```bash -export GALILEO_LUNA_METRIC="toxicity" +export GALILEO_LUNA_SCORER_LABEL="toxicity" export GALILEO_LUNA_THRESHOLD="0.5" ``` diff --git a/examples/galileo_luna/setup_controls.py b/examples/galileo_luna/setup_controls.py index 3d325cde..69a36ad5 100644 --- a/examples/galileo_luna/setup_controls.py +++ b/examples/galileo_luna/setup_controls.py @@ -23,7 +23,7 @@ AGENT_DESCRIPTION = "Demo agent protected by direct Galileo Luna scorer controls" SERVER_URL = os.getenv("AGENT_CONTROL_URL", "http://localhost:8000") -LUNA_METRIC = os.getenv("GALILEO_LUNA_METRIC", "toxicity") +LUNA_SCORER_LABEL = os.getenv("GALILEO_LUNA_SCORER_LABEL", "toxicity") LUNA_THRESHOLD = float(os.getenv("GALILEO_LUNA_THRESHOLD", "0.5")) GALILEO_PROJECT_ID = os.getenv("GALILEO_PROJECT_ID") @@ -41,7 +41,7 @@ def luna_config() -> dict[str, Any]: """Build the direct Luna evaluator config used by the composite control.""" config: dict[str, Any] = { - "metric": LUNA_METRIC, + "scorer_label": LUNA_SCORER_LABEL, "threshold": LUNA_THRESHOLD, "operator": "gte", "payload_field": "output", @@ -158,7 +158,7 @@ async def setup_demo() -> None: print("Setting up direct Galileo Luna demo controls") print(f"Server: {SERVER_URL}") print(f"Agent: {AGENT_NAME}") - print(f"Luna: metric={LUNA_METRIC!r}, threshold={LUNA_THRESHOLD}") + print(f"Luna: scorer_label={LUNA_SCORER_LABEL!r}, threshold={LUNA_THRESHOLD}") if GALILEO_PROJECT_ID: print(f"Project ID: {GALILEO_PROJECT_ID}") From 523524d07fb9837fa574106fe6346a07f25e25be Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Wed, 13 May 2026 17:37:14 -0700 Subject: [PATCH 38/42] update the schemas for scorer --- .../luna/__init__.py | 2 + .../luna/client.py | 33 ++++++++-------- .../galileo/tests/test_luna_evaluator.py | 37 +++++++++--------- .../src/agent_control/evaluators/__init__.py | 38 +++++++++++-------- 4 files changed, 62 insertions(+), 48 deletions(-) diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/__init__.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/__init__.py index c3ff0375..b26feaac 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/__init__.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/__init__.py @@ -2,6 +2,7 @@ from agent_control_evaluator_galileo.luna.client import ( GalileoLunaClient, + ScorerInvokeInputs, ScorerInvokeRequest, ScorerInvokeResponse, ) @@ -10,6 +11,7 @@ __all__ = [ "GalileoLunaClient", + "ScorerInvokeInputs", "ScorerInvokeRequest", "ScorerInvokeResponse", "LunaEvaluatorConfig", diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py index 426b1782..a2ccdc3f 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py @@ -9,7 +9,6 @@ from hmac import new as hmac_new from json import dumps from time import time -from typing import Literal from uuid import UUID import httpx @@ -66,32 +65,38 @@ def _as_float_or_none(value: JSONValue) -> float | None: return None -RootType = Literal["session", "trace", "span"] +def _has_value(value: JSONValue) -> bool: + return value is not None and value != "" + + +class ScorerInvokeInputs(BaseModel): + """Input values sent to Galileo's scorer invoke API.""" + + query: JSONValue = "" + response: JSONValue = "" + ground_truth: JSONValue = None + tools: JSONValue = None class ScorerInvokeRequest(BaseModel): """Request payload for Galileo Luna scorer invocation. Attributes: - root_type: Runtime step shape used by Galileo scorer input normalization. - input: Optional user/system prompt text. - output: Optional model response text. + inputs: Selected scorer input values. scorer_label: Preset, registered, or fine-tuned scorer label. project_id: Optional Galileo project UUID for project-scoped scorer resolution. config: Optional scorer-specific configuration. """ - root_type: RootType = Field(default="span") - input: JSONValue = None - output: JSONValue = None scorer_label: str = Field(min_length=1) + inputs: ScorerInvokeInputs project_id: str | UUID | None = None config: JSONObject | None = None @model_validator(mode="after") def ensure_input_or_output(self) -> ScorerInvokeRequest: - if self.input is None and self.output is None: - raise ValueError("Either input or output must be set.") + if not (_has_value(self.inputs.query) or _has_value(self.inputs.response)): + raise ValueError("Either inputs.query or inputs.response must be set.") return self def to_dict(self) -> JSONObject: @@ -234,7 +239,6 @@ async def invoke( scorer_label: str, input: JSONValue = None, output: JSONValue = None, - root_type: RootType = "span", project_id: str | UUID | None = None, config: JSONObject | None = None, timeout: float = DEFAULT_TIMEOUT_SECS, @@ -246,7 +250,6 @@ async def invoke( scorer_label: Preset, registered, or fine-tuned scorer label. input: Optional user/system prompt text. output: Optional model response text. - root_type: Runtime step shape used by Galileo scorer input normalization. project_id: Optional Galileo project UUID for project-scoped scorer resolution. config: Optional scorer-specific configuration. timeout: Request timeout in seconds. @@ -266,9 +269,9 @@ async def invoke( request_body = ScorerInvokeRequest( scorer_label=scorer_label, - input=input, - output=output, - root_type=root_type, + inputs=ScorerInvokeInputs( + query="" if input is None else input, response="" if output is None else output + ), project_id=project_id, config=config, ).to_dict() diff --git a/evaluators/contrib/galileo/tests/test_luna_evaluator.py b/evaluators/contrib/galileo/tests/test_luna_evaluator.py index 31323a42..9f4ae862 100644 --- a/evaluators/contrib/galileo/tests/test_luna_evaluator.py +++ b/evaluators/contrib/galileo/tests/test_luna_evaluator.py @@ -52,22 +52,24 @@ def test_numeric_operator_requires_numeric_threshold(self) -> None: class TestGalileoLunaClient: """Tests for the GalileoLunaClient HTTP contract.""" - def test_scorer_invoke_request_matches_orbit_schema_shape(self) -> None: - from agent_control_evaluator_galileo.luna import ScorerInvokeRequest + def test_scorer_invoke_request_matches_api_schema_shape(self) -> None: + from agent_control_evaluator_galileo.luna import ScorerInvokeInputs, ScorerInvokeRequest # Given: a scorer request with project context and scorer config request = ScorerInvokeRequest( scorer_label="toxicity", - input={"messages": [{"role": "user", "content": "hello"}]}, + inputs=ScorerInvokeInputs(query={"messages": [{"role": "user", "content": "hello"}]}), project_id="12345678-1234-5678-1234-567812345678", config={"top_k": 1}, ) - # Then: the serialized payload uses the Orbit scorer invoke fields + # Then: the serialized payload uses the API-owned scorer invoke fields assert request.to_dict() == { - "root_type": "span", - "input": {"messages": [{"role": "user", "content": "hello"}]}, "scorer_label": "toxicity", + "inputs": { + "query": {"messages": [{"role": "user", "content": "hello"}]}, + "response": "", + }, "project_id": "12345678-1234-5678-1234-567812345678", "config": {"top_k": 1}, } @@ -75,11 +77,13 @@ def test_scorer_invoke_request_matches_orbit_schema_shape(self) -> None: def test_scorer_invoke_request_requires_input_or_output(self) -> None: from agent_control_evaluator_galileo.luna import ScorerInvokeRequest - # Given/When/Then: the request mirrors Orbit validation - with pytest.raises(ValidationError, match="Either input or output must be set"): - ScorerInvokeRequest(scorer_label="toxicity") + # Given/When/Then: the request mirrors API validation + with pytest.raises( + ValidationError, match="Either inputs.query or inputs.response must be set" + ): + ScorerInvokeRequest(scorer_label="toxicity", inputs={}) - def test_scorer_invoke_response_matches_orbit_schema_shape(self) -> None: + def test_scorer_invoke_response_matches_api_schema_shape(self) -> None: from agent_control_evaluator_galileo.luna import ScorerInvokeResponse # Given: an API scorer invoke response @@ -93,7 +97,7 @@ def test_scorer_invoke_response_matches_orbit_schema_shape(self) -> None: } ) - # Then: the model exposes the Orbit/API response fields + # Then: the model exposes the API response fields assert response.model_dump() == { "scorer_label": "toxicity", "score": 0.82, @@ -187,11 +191,9 @@ def handler(request: httpx.Request) -> httpx.Response: assert response.score == 0.82 assert captured["url"] == "https://api.demo-v2.galileocloud.io/scorers/invoke" assert captured["body"] == { - "input": "user prompt", - "output": "model answer", "scorer_label": "toxicity", + "inputs": {"query": "user prompt", "response": "model answer"}, "project_id": "12345678-1234-5678-1234-567812345678", - "root_type": "span", "config": {"top_k": 1}, } assert "stage_name" not in captured["body"] @@ -237,12 +239,13 @@ def handler(request: httpx.Request) -> httpx.Response: # Then: the internal scorer endpoint is called with a project-bound JWT assert response.score == 0.82 - assert captured["url"] == "https://api.default.svc.cluster.local:8088/internal/scorers/invoke" + assert ( + captured["url"] == "https://api.default.svc.cluster.local:8088/internal/scorers/invoke" + ) assert captured["body"] == { - "output": "model answer", "scorer_label": "toxicity", + "inputs": {"query": "", "response": "model answer"}, "project_id": "12345678-1234-5678-1234-567812345678", - "root_type": "span", } headers = captured["headers"] assert isinstance(headers, dict) diff --git a/sdks/python/src/agent_control/evaluators/__init__.py b/sdks/python/src/agent_control/evaluators/__init__.py index 9fd87e71..8366a107 100644 --- a/sdks/python/src/agent_control/evaluators/__init__.py +++ b/sdks/python/src/agent_control/evaluators/__init__.py @@ -44,19 +44,23 @@ LunaEvaluator, LunaEvaluatorConfig, LunaOperator, + ScorerInvokeInputs, ScorerInvokeRequest, ScorerInvokeResponse, ) - __all__.extend([ - "GalileoLunaClient", - "ScorerInvokeRequest", - "ScorerInvokeResponse", - "LunaEvaluator", - "LunaEvaluatorConfig", - "LunaOperator", - "LUNA_AVAILABLE", - ]) + __all__.extend( + [ + "GalileoLunaClient", + "ScorerInvokeInputs", + "ScorerInvokeRequest", + "ScorerInvokeResponse", + "LunaEvaluator", + "LunaEvaluatorConfig", + "LunaOperator", + "LUNA_AVAILABLE", + ] + ) except ImportError: pass @@ -69,12 +73,14 @@ Luna2Operator, ) - __all__.extend([ - "Luna2Evaluator", - "Luna2EvaluatorConfig", - "Luna2Metric", - "Luna2Operator", - "LUNA2_AVAILABLE", - ]) + __all__.extend( + [ + "Luna2Evaluator", + "Luna2EvaluatorConfig", + "Luna2Metric", + "Luna2Operator", + "LUNA2_AVAILABLE", + ] + ) except ImportError: pass From 34f430df0b8934a670286ea4c9712254fd35e748 Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Wed, 13 May 2026 21:56:33 -0700 Subject: [PATCH 39/42] update luna client schemas --- .../luna/client.py | 10 +++++++-- .../galileo/tests/test_luna_evaluator.py | 21 +++++++++++++++++-- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py index a2ccdc3f..86033339 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py @@ -66,7 +66,13 @@ def _as_float_or_none(value: JSONValue) -> float | None: def _has_value(value: JSONValue) -> bool: - return value is not None and value != "" + if value is None: + return False + if isinstance(value, str): + return value.strip() != "" + if isinstance(value, (list, dict)): + return len(value) > 0 + return True class ScorerInvokeInputs(BaseModel): @@ -264,7 +270,7 @@ async def invoke( httpx.HTTPStatusError: If the API returns an error status code. httpx.RequestError: If the request fails before a response is received. """ - if input is None and output is None: + if not (_has_value(input) or _has_value(output)): raise ValueError("At least one of input or output must be provided.") request_body = ScorerInvokeRequest( diff --git a/evaluators/contrib/galileo/tests/test_luna_evaluator.py b/evaluators/contrib/galileo/tests/test_luna_evaluator.py index 9f4ae862..80a5e00b 100644 --- a/evaluators/contrib/galileo/tests/test_luna_evaluator.py +++ b/evaluators/contrib/galileo/tests/test_luna_evaluator.py @@ -74,14 +74,18 @@ def test_scorer_invoke_request_matches_api_schema_shape(self) -> None: "config": {"top_k": 1}, } - def test_scorer_invoke_request_requires_input_or_output(self) -> None: + @pytest.mark.parametrize("empty_value", ["", " ", {}, []]) + def test_scorer_invoke_request_requires_input_or_output(self, empty_value: object) -> None: from agent_control_evaluator_galileo.luna import ScorerInvokeRequest # Given/When/Then: the request mirrors API validation with pytest.raises( ValidationError, match="Either inputs.query or inputs.response must be set" ): - ScorerInvokeRequest(scorer_label="toxicity", inputs={}) + ScorerInvokeRequest( + scorer_label="toxicity", + inputs={"query": empty_value, "response": empty_value}, + ) def test_scorer_invoke_response_matches_api_schema_shape(self) -> None: from agent_control_evaluator_galileo.luna import ScorerInvokeResponse @@ -270,6 +274,19 @@ async def test_client_requires_project_id_for_internal_jwt(self) -> None: with pytest.raises(ValueError, match="project_id is required"): await client.invoke(scorer_label="toxicity", output="model answer") + @pytest.mark.asyncio + @pytest.mark.parametrize("empty_value", ["", " ", {}, []]) + async def test_client_rejects_missing_input_and_output_values(self, empty_value: object) -> None: + from agent_control_evaluator_galileo.luna import GalileoLunaClient + + # Given: a Luna client and scorer input values that API treats as missing + with patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}, clear=True): + client = GalileoLunaClient(api_url="https://api.default.svc.cluster.local:8088") + + # When/Then: the client rejects the request before calling API + with pytest.raises(ValueError, match="At least one of input or output must be provided"): + await client.invoke(scorer_label="toxicity", input=empty_value, output=empty_value) + class TestLunaEvaluator: """Tests for direct Luna evaluator behavior.""" From ad0b2dc98b30fcaffe0c5897cfee08e96de83e03 Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Wed, 13 May 2026 21:59:37 -0700 Subject: [PATCH 40/42] fix tests --- evaluators/contrib/galileo/tests/test_luna_evaluator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/evaluators/contrib/galileo/tests/test_luna_evaluator.py b/evaluators/contrib/galileo/tests/test_luna_evaluator.py index 80a5e00b..5cf1fcf8 100644 --- a/evaluators/contrib/galileo/tests/test_luna_evaluator.py +++ b/evaluators/contrib/galileo/tests/test_luna_evaluator.py @@ -276,7 +276,9 @@ async def test_client_requires_project_id_for_internal_jwt(self) -> None: @pytest.mark.asyncio @pytest.mark.parametrize("empty_value", ["", " ", {}, []]) - async def test_client_rejects_missing_input_and_output_values(self, empty_value: object) -> None: + async def test_client_rejects_missing_input_and_output_values( + self, empty_value: object + ) -> None: from agent_control_evaluator_galileo.luna import GalileoLunaClient # Given: a Luna client and scorer input values that API treats as missing From d428842ef2388cc4b242d5a3820835a429f7241b Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Wed, 13 May 2026 22:18:09 -0700 Subject: [PATCH 41/42] fix broken test --- evaluators/contrib/galileo/tests/test_luna_evaluator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evaluators/contrib/galileo/tests/test_luna_evaluator.py b/evaluators/contrib/galileo/tests/test_luna_evaluator.py index 326d265d..63604912 100644 --- a/evaluators/contrib/galileo/tests/test_luna_evaluator.py +++ b/evaluators/contrib/galileo/tests/test_luna_evaluator.py @@ -327,7 +327,7 @@ def test_evaluator_init_accepts_api_secret(self) -> None: evaluator = LunaEvaluator.from_dict( { - "metric": "toxicity", + "scorer_label": "toxicity", "project_id": "12345678-1234-5678-1234-567812345678", "threshold": 0.5, } From 81cea0471518dd22e0964889412155d5122881de Mon Sep 17 00:00:00 2001 From: "namrata.ghadi" Date: Thu, 14 May 2026 15:11:17 -0700 Subject: [PATCH 42/42] remove unwanted fields --- .../luna/config.py | 16 ----------- .../luna/evaluator.py | 15 ++--------- .../galileo/tests/test_luna_evaluator.py | 27 ++----------------- 3 files changed, 4 insertions(+), 54 deletions(-) diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py index 1e41a554..7bf5de48 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py @@ -38,10 +38,6 @@ class LunaEvaluatorConfig(EvaluatorConfig): operator: Local comparison operator. Numeric operators use threshold as a number. scorer_config: Optional scorer-specific config sent as ``config``. timeout_ms: Request timeout in milliseconds. - on_error: Error policy: allow=fail open, deny=fail closed. - payload_field: Force selected data into input or output. If omitted, root step - payloads with input/output use both fields; scalar data is inferred from scorer label. - include_raw_response: Include the raw API response in EvaluatorResult metadata. """ scorer_label: str = Field(..., min_length=1, description="Luna scorer label to invoke") @@ -69,18 +65,6 @@ class LunaEvaluatorConfig(EvaluatorConfig): le=60000, description="Request timeout in milliseconds (1-60 seconds)", ) - on_error: Literal["allow", "deny"] = Field( - default="allow", - description="Action on error: 'allow' (fail open) or 'deny' (fail closed)", - ) - payload_field: Literal["input", "output"] | None = Field( - default=None, - description="Explicitly set which scorer payload field receives scalar selected data.", - ) - include_raw_response: bool = Field( - default=False, - description="Include the raw scorer response in result metadata.", - ) @model_validator(mode="after") def validate_threshold(self) -> LunaEvaluatorConfig: diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py index a5b3f248..f9e0ad0d 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py @@ -126,12 +126,6 @@ def _get_client(self) -> GalileoLunaClient: def _prepare_payload(self, data: Any) -> tuple[str | None, str | None]: """Prepare scorer input/output fields from selected data.""" - if self.config.payload_field is not None: - text = _coerce_payload_text(data) - if self.config.payload_field == "output": - return None, text - return text, None - if isinstance(data, dict): input_text = _extract_dict_text(data, "input") output_text = _extract_dict_text(data, "output") @@ -236,25 +230,20 @@ def _metadata(self, response: ScorerInvokeResponse) -> dict[str, Any]: "execution_time_seconds": response.execution_time, "error_message": response.error_message, } - if self.config.include_raw_response: - metadata["raw_response"] = response.raw_response return metadata def _handle_error(self, error: Exception) -> EvaluatorResult: - fallback = self.config.on_error - matched = fallback == "deny" error_detail = str(error) return EvaluatorResult( - matched=matched, + matched=False, confidence=0.0, message=f"Luna evaluation error: {error_detail}", metadata={ "error": error_detail, "error_type": type(error).__name__, "scorer_label": self.config.scorer_label, - "fallback_action": fallback, }, - error=None if matched else error_detail, + error=error_detail, ) async def aclose(self) -> None: diff --git a/evaluators/contrib/galileo/tests/test_luna_evaluator.py b/evaluators/contrib/galileo/tests/test_luna_evaluator.py index 5cf1fcf8..1b0bcef8 100644 --- a/evaluators/contrib/galileo/tests/test_luna_evaluator.py +++ b/evaluators/contrib/galileo/tests/test_luna_evaluator.py @@ -435,7 +435,7 @@ async def test_evaluator_fail_open_sets_error(self) -> None: from agent_control_evaluator_galileo.luna import LunaEvaluator from agent_control_evaluator_galileo.luna.client import GalileoLunaClient - # Given: default fail-open behavior + # Given: fixed fail-open behavior for scorer errors evaluator = LunaEvaluator.from_dict({"scorer_label": "toxicity", "threshold": 0.5}) with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: @@ -448,27 +448,4 @@ async def test_evaluator_fail_open_sets_error(self) -> None: assert result.matched is False assert result.error == "service unavailable" assert result.metadata is not None - assert result.metadata["fallback_action"] == "allow" - - @patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}) - @pytest.mark.asyncio - async def test_evaluator_fail_closed_matches_without_error_field(self) -> None: - from agent_control_evaluator_galileo.luna import LunaEvaluator - from agent_control_evaluator_galileo.luna.client import GalileoLunaClient - - # Given: fail-closed behavior for scorer errors - evaluator = LunaEvaluator.from_dict( - {"scorer_label": "toxicity", "threshold": 0.5, "on_error": "deny"} - ) - - with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: - mock_invoke.side_effect = RuntimeError("service unavailable") - - # When: the scorer call fails - result = await evaluator.evaluate("hello") - - # Then: the control matches so deny/steer actions can be applied by the engine - assert result.matched is True - assert result.error is None - assert result.metadata is not None - assert result.metadata["fallback_action"] == "deny" + assert "fallback_action" not in result.metadata