diff --git a/pyproject.toml b/pyproject.toml index 38aaacd3..4604285f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ dependencies = [ "langchain-text-splitters>=1.0,<2.0", "beautifulsoup4>=4.12", "psycopg[binary]>=3.2", + "toolguard>=0.2.17", "cuga-oak-health; python_version>='3.12'", "aiosmtpd", ] diff --git a/src/cuga/backend/cuga_graph/policy/__init__.py b/src/cuga/backend/cuga_graph/policy/__init__.py index 3dace391..50a224ab 100644 --- a/src/cuga/backend/cuga_graph/policy/__init__.py +++ b/src/cuga/backend/cuga_graph/policy/__init__.py @@ -8,6 +8,10 @@ PolicyType, Playbook, IntentGuard, + ToolGuide, + ToolGuard, + ToolApproval, + OutputFormatter, CustomPolicy, PlaybookStep, IntentGuardResponse, @@ -31,6 +35,10 @@ "PolicyType", "Playbook", "IntentGuard", + "ToolGuide", + "ToolGuard", + "ToolApproval", + "OutputFormatter", "CustomPolicy", "PlaybookStep", "IntentGuardResponse", diff --git a/src/cuga/backend/cuga_graph/policy/filesystem_sync.py b/src/cuga/backend/cuga_graph/policy/filesystem_sync.py index a290b29f..fb2801a3 100644 --- a/src/cuga/backend/cuga_graph/policy/filesystem_sync.py +++ b/src/cuga/backend/cuga_graph/policy/filesystem_sync.py @@ -91,7 +91,7 @@ def _policy_to_markdown(self, policy: Policy) -> str: # Build frontmatter # Get policy type - try 'type' first, then 'policy_type' policy_type = getattr(policy, 'type', None) or getattr(policy, 'policy_type', None) - policy_type_value = policy_type.value if hasattr(policy_type, 'value') else str(policy_type) + policy_type_value = policy_type.value if policy_type is not None and hasattr(policy_type, 'value') else str(policy_type) frontmatter = { 'id': policy.id, @@ -103,10 +103,11 @@ def _policy_to_markdown(self, policy: Policy) -> str: } # Add triggers if present - if hasattr(policy, 'triggers') and policy.triggers: + policy_triggers = getattr(policy, 'triggers', None) + if policy_triggers: triggers_config = {} - for trigger in policy.triggers: + for trigger in policy_triggers: if isinstance(trigger, KeywordTrigger): triggers_config['keywords'] = trigger.value triggers_config['target'] = trigger.target @@ -133,6 +134,10 @@ def _policy_to_markdown(self, policy: Policy) -> str: frontmatter['target_tools'] = policy.target_tools if policy.target_apps: frontmatter['target_apps'] = policy.target_apps + if policy.tool_guards: + frontmatter['tool_guards'] = { + tool_name: guard.model_dump() for tool_name, guard in policy.tool_guards.items() + } frontmatter['prepend'] = policy.prepend content = policy.guide_content or "" elif isinstance(policy, IntentGuard): diff --git a/src/cuga/backend/cuga_graph/policy/folder_loader.py b/src/cuga/backend/cuga_graph/policy/folder_loader.py index a88b1287..636d57bc 100644 --- a/src/cuga/backend/cuga_graph/policy/folder_loader.py +++ b/src/cuga/backend/cuga_graph/policy/folder_loader.py @@ -7,7 +7,7 @@ import os from pathlib import Path -from typing import Dict, Any, List +from typing import Dict, Any, List, cast import yaml from loguru import logger @@ -15,12 +15,14 @@ Playbook, OutputFormatter, ToolGuide, + ToolGuard, IntentGuard, ToolApproval, KeywordTrigger, NaturalLanguageTrigger, AlwaysTrigger, IntentGuardResponse, + Trigger, ) @@ -123,7 +125,7 @@ def create_playbook_from_markdown( raise ValueError(f"Playbook in {file_path} missing 'name' in frontmatter") triggers_config = frontmatter.get('triggers', {}) - triggers = create_triggers_from_metadata(triggers_config) + triggers = cast(List[Trigger], create_triggers_from_metadata(triggers_config)) if not triggers: raise ValueError(f"Playbook {name} must have at least one trigger") @@ -160,10 +162,10 @@ def create_output_formatter_from_markdown( raise ValueError(f"OutputFormatter in {file_path} missing 'name' in frontmatter") triggers_config = frontmatter.get('triggers', {}) - triggers = create_triggers_from_metadata(triggers_config) + triggers = cast(List[Trigger], create_triggers_from_metadata(triggers_config)) if not triggers: - triggers = [AlwaysTrigger()] + triggers = cast(List[Trigger], [AlwaysTrigger()]) format_type = frontmatter.get('format_type', 'markdown') if format_type not in ['markdown', 'json_schema', 'direct']: @@ -205,10 +207,17 @@ def create_tool_guide_from_markdown( raise ValueError(f"ToolGuide {name} must specify 'target_tools'") triggers_config = frontmatter.get('triggers', {}) - triggers = create_triggers_from_metadata(triggers_config) + triggers = cast(List[Trigger], create_triggers_from_metadata(triggers_config)) if not triggers: - triggers = [AlwaysTrigger()] + triggers = cast(List[Trigger], [AlwaysTrigger()]) + + raw_tool_guards = frontmatter.get('tool_guards') + tool_guards = ( + {tool_name: ToolGuard(**guard_config) for tool_name, guard_config in raw_tool_guards.items()} + if isinstance(raw_tool_guards, dict) + else None + ) return ToolGuide( id=frontmatter.get('id', f"tool_guide_{Path(file_path).stem}"), @@ -218,6 +227,7 @@ def create_tool_guide_from_markdown( target_tools=target_tools, target_apps=frontmatter.get('target_apps'), guide_content=content, + tool_guards=tool_guards, prepend=frontmatter.get('prepend', False), priority=frontmatter.get('priority', 50), enabled=frontmatter.get('enabled', True), @@ -244,7 +254,7 @@ def create_intent_guard_from_markdown( raise ValueError(f"IntentGuard in {file_path} missing 'name' in frontmatter") triggers_config = frontmatter.get('triggers', {}) - triggers = create_triggers_from_metadata(triggers_config) + triggers = cast(List[Trigger], create_triggers_from_metadata(triggers_config)) if not triggers: raise ValueError(f"IntentGuard {name} must have at least one trigger") @@ -261,6 +271,7 @@ def create_intent_guard_from_markdown( response=IntentGuardResponse( response_type=response_type, content=content, + status_code=frontmatter.get('status_code'), ), allow_override=frontmatter.get('allow_override', False), priority=frontmatter.get('priority', 50), diff --git a/src/cuga/backend/cuga_graph/policy/models.py b/src/cuga/backend/cuga_graph/policy/models.py index 3c8dc655..33066dcf 100644 --- a/src/cuga/backend/cuga_graph/policy/models.py +++ b/src/cuga/backend/cuga_graph/policy/models.py @@ -165,6 +165,26 @@ class IntentGuardResponse(BaseModel): status_code: Optional[int] = Field(None, description="HTTP status code if applicable") +class ToolGuard(BaseModel): + """Guard configuration for a specific tool with compliance rules.""" + + violating_examples: List[str] = Field( + default_factory=list, description="Examples of violating usage patterns" + ) + compliance_examples: List[str] = Field( + default_factory=list, description="Examples of compliant usage patterns" + ) + policy_code: str = Field( + default="", + description=( + "Python code that validates tool usage compliance. " + "This code is executed in a sandboxed environment using the toolguard library. " + "Only trusted administrators with manage access should be allowed to modify policy code. " + "While sandboxed, policy code should still be reviewed for correctness and performance." + ) + ) + + class IntentGuard(BaseModel): """Guard that intercepts intents and provides custom responses.""" @@ -214,6 +234,9 @@ class ToolGuide(BaseModel): None, description="List of app names to enrich tools for (optional)" ) guide_content: str = Field(..., description="Markdown content to append to tool descriptions") + tool_guards: Optional[Dict[str, ToolGuard]] = Field( + default=None, description="Optional guard configurations per tool (key: tool_name, value: ToolGuard)" + ) prepend: bool = Field(False, description="Whether to prepend content instead of appending") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata") priority: int = Field(0, description="Priority when multiple guides match (higher = more important)") diff --git a/src/cuga/backend/cuga_graph/policy/tests/test_filesystem_sync.py b/src/cuga/backend/cuga_graph/policy/tests/test_filesystem_sync.py index 828fc76e..f0fc27ec 100644 --- a/src/cuga/backend/cuga_graph/policy/tests/test_filesystem_sync.py +++ b/src/cuga/backend/cuga_graph/policy/tests/test_filesystem_sync.py @@ -24,6 +24,7 @@ IntentGuard, Playbook, ToolGuide, + ToolGuard, ToolApproval, OutputFormatter, KeywordTrigger, @@ -164,6 +165,39 @@ async def test_save_policy_to_filesystem( for expected in expected_content: assert expected in content + @pytest.mark.asyncio + async def test_save_tool_guide_with_tool_guards_to_filesystem(self, temp_cuga_folder): + """Test saving tool_guards in ToolGuide frontmatter.""" + fs_sync = PolicyFilesystemSync(cuga_folder=temp_cuga_folder) + policy = ToolGuide( + id="test_guide_with_guards", + name="Guide With Guards", + description="Test tool guide with per-tool guard configuration", + triggers=[AlwaysTrigger()], + target_tools=["test_tool"], + target_apps=None, + guide_content="## Guidelines\n- Be careful", + tool_guards={ + "test_tool": ToolGuard( + violating_examples=["Deleting all records without confirmation"], + compliance_examples=["Delete one record after explicit confirmation"], + policy_code="def validate(call):\n return True", + ) + }, + prepend=False, + priority=0, + enabled=True, + ) + + file_path = fs_sync.save_policy_to_file(policy) + + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + + assert "tool_guards:" in content + assert "test_tool:" in content + assert "Only allow safe usage" in content + @pytest.mark.asyncio async def test_delete_policy_file(self, temp_cuga_folder): """Test deleting a policy file""" @@ -376,7 +410,18 @@ async def test_auto_load_multiple_policy_types(self, temp_cuga_folder): description="Test", triggers=[AlwaysTrigger()], target_tools=["*"], + target_apps=None, guide_content="## Test", + tool_guards={ + "test_tool": ToolGuard( + violating_examples=["bad_example()"], + compliance_examples=["good_example()"], + policy_code="def validate(call):\n return True", + ) + }, + prepend=False, + priority=0, + enabled=True, ) fs_sync.save_policy_to_file(guard) @@ -399,6 +444,12 @@ async def test_auto_load_multiple_policy_types(self, temp_cuga_folder): assert "playbook" in policy_types assert "tool_guide" in policy_types + loaded_guide = await agent.policies.get("guide_1") + assert loaded_guide is not None + assert loaded_guide["policy"].tool_guards is not None + assert "test_tool" in loaded_guide["policy"].tool_guards + # ToolGuard no longer has description field - it's derived from ToolGuide + @pytest.mark.asyncio async def test_auto_load_disabled(self, temp_cuga_folder): """Test that auto-load can be disabled""" diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py b/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py new file mode 100644 index 00000000..841989f8 --- /dev/null +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py @@ -0,0 +1,13 @@ +""" +Tool Guard integration for CUGA. + +This module provides integration between CUGA's tool system and Toolguard's +policy enforcement framework. +""" + +from .manager import ToolGuardManager +from .tool_guard_runtime import ToolGuardRuntime + +__all__ = ["ToolGuardManager", "ToolGuardRuntime"] + + diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py b/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py new file mode 100644 index 00000000..1d022ea7 --- /dev/null +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py @@ -0,0 +1,418 @@ +""" +Manager for generating tool guard examples using the toolguard library. + +This module provides integration between CUGA's policy system and toolguard's +example generation capabilities to create violating and compliance examples +for tool usage policies. +""" + +import asyncio +import re +import shutil +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Dict, Iterator, List, Optional, Tuple + +from langchain_core.tools import StructuredTool +from loguru import logger +from toolguard.buildtime.buildtime import ( + ToolGuardSpec, + generate_guard_examples, + generate_guards_code, +) +from toolguard.buildtime.llm import LangchainModelWrapper +from toolguard.extra.langchain_to_oas import langchain_tools_to_openapi +from toolguard.runtime.data_types import ( + ToolGuardsCodeGenerationResult, + ToolGuardSpecItem, +) + +from cuga.backend.cuga_graph.policy.models import PolicyType, ToolGuide +from cuga.sdk import CugaAgent + + +class ToolGuardManager: + + def __init__( + self, + agent: CugaAgent, + ): + self.langchain_tools: List[StructuredTool] = [] # Store LangChain tools + self.tools_dict: Dict[str, Any] = {} # Store OpenAPI dict for ToolGuard + self._initialized = False + self._init_lock = asyncio.Lock() + + # Validate tool_provider upfront + if agent.tool_provider is None: + raise ValueError( + "Agent tool_provider is not initialized. Ensure the CugaAgent has a valid " + "tool_provider before creating ToolGuardManager." + ) + + # Validate cuga_folder upfront + if not agent.cuga_folder: + raise ValueError( + "Agent cuga_folder is not set. Ensure the CugaAgent has a valid " + "cuga_folder path before creating ToolGuardManager." + ) + + self.tool_provider = agent.tool_provider + + # Wrap CUGA's LLM with adapter for ToolGuard compatibility + if agent._model is None: + raise ValueError( + "Agent model is not initialized. Ensure the CugaAgent has a valid model " + "before creating ToolGuardManager." + ) + + self.llm = LangchainModelWrapper(agent._model) + logger.info(f"Initialized ToolGuardManager with {type(agent._model).__name__} via LangchainModelWrapper") + # Use Path to properly handle path concatenation and ensure directory exists + work_dir_path = Path(agent.cuga_folder) / "toolguard" + work_dir_path.mkdir(parents=True, exist_ok=True) + self.work_dir = str(work_dir_path) + logger.debug(f"ToolGuard working directory: {self.work_dir}") + + async def initialize( + self, + ) -> None: + async with self._init_lock: + if self._initialized: + logger.debug("ToolGuardManager already initialized, skipping") + return + + logger.info("Initializing ToolGuardManager...") + + # Get all LangChain tools from the tool provider + self.langchain_tools = await self.tool_provider.get_all_tools() + + # Convert LangChain tools to OpenAPI dict using ToolGuard's utility + self.tools_dict = langchain_tools_to_openapi(self.langchain_tools) + + self._initialized = True + logger.info(f"✅ ToolGuardManager initialized with {len(self.langchain_tools)} tools") + + def _ensure_initialized(self) -> None: + """Ensure manager is initialized, raise RuntimeError if not.""" + if not self._initialized or not self.langchain_tools: + raise RuntimeError( + "ToolGuardManager not initialized. Call initialize() first with a tool provider." + ) + + def _validate_policy_and_tool(self, policy: ToolGuide, target_tool: str) -> None: + """Validate policy type and target tool.""" + if policy.type != PolicyType.TOOL_GUIDE: + raise ValueError( + f"Policy must be of type 'tool_guide', got '{policy.type}'. " + f"Only tool_guide policies can generate examples." + ) + + if "*" not in policy.target_tools and target_tool not in policy.target_tools: + raise ValueError( + f"Tool '{target_tool}' is not in policy.target_tools. " + f"Policy targets: {policy.target_tools}" + ) + + tool_names = [tool.name for tool in self.langchain_tools] + if target_tool not in tool_names: + raise ValueError( + f"Tool '{target_tool}' not found in available tools. " + f"Available tools: {tool_names}" + ) + + def _validate_app_name(self, app_name: str) -> str: + """ + Validate and sanitize app_name to prevent path traversal attacks. + + Args: + app_name: Application name to validate + + Returns: + The validated app_name + + Raises: + ValueError: If app_name contains unsafe characters or patterns + """ + # Check for path separators and traversal segments + if '/' in app_name or '\\' in app_name: + raise ValueError( + f"Invalid app_name '{app_name}': path separators ('/', '\\') are not allowed" + ) + + if '..' in app_name: + raise ValueError( + f"Invalid app_name '{app_name}': path traversal segments ('..') are not allowed" + ) + + # Validate against safe whitelist: alphanumeric, underscore, and hyphen only + if not re.match(r'^[A-Za-z0-9_-]+$', app_name): + raise ValueError( + f"Invalid app_name '{app_name}': only alphanumeric characters, " + f"underscores, and hyphens are allowed (pattern: /^[A-Za-z0-9_-]+$/)" + ) + + return app_name + + def _build_description(self, policy: ToolGuide) -> str: + """Build description from policy, concatenating guide_content if present.""" + description = policy.description + if policy.guide_content: + description = f"{description}\n\n{policy.guide_content}" + return description + + def _create_spec_item( + self, + policy: ToolGuide, + violating_examples: Optional[List[str]] = None, + compliance_examples: Optional[List[str]] = None + ) -> ToolGuardSpecItem: + """Create ToolGuardSpecItem from policy.""" + description = self._build_description(policy) + + kwargs: Dict[str, Any] = { + "name": policy.name, + "description": description, + "references": [policy.guide_content] if policy.guide_content else [] + } + + if violating_examples is not None: + kwargs["violation_examples"] = violating_examples + if compliance_examples is not None: + kwargs["compliance_examples"] = compliance_examples + + return ToolGuardSpecItem(**kwargs) + + @contextmanager + def _temp_directory(self) -> Iterator[Path]: + """Context manager for temporary toolguard directory. + + Creates a unique temporary subdirectory per invocation to avoid + race conditions when generate_examples() or generate_guard_code() + are called concurrently. + """ + import uuid + tmp_dir = Path(self.work_dir) / f"tmp_{uuid.uuid4().hex}" + tmp_dir.mkdir(parents=True, exist_ok=True) + try: + yield tmp_dir + finally: + if tmp_dir.exists(): + shutil.rmtree(tmp_dir) + logger.debug(f"Cleaned up temporary directory: {tmp_dir}") + + def _save_domain_files(self, result: ToolGuardsCodeGenerationResult) -> None: + """Save RuntimeDomain files to domain directory.""" + work_dir_path = Path(self.work_dir) + domain_dir = work_dir_path / "domain" + domain_dir.mkdir(parents=True, exist_ok=True) + + for attr_name in ["app_types", "app_api", "app_api_impl"]: + domain_file = getattr(result.domain, attr_name) + domain_file.save(domain_dir) + logger.info(f"Saved {attr_name} to {domain_dir / domain_file.file_name}") + + async def generate_examples( + self, + policy: ToolGuide, + target_tool: str + ) -> Tuple[List[str], List[str]]: + """ + Generate violating and compliance examples for a specific tool in a ToolGuide policy. + + Args: + policy: ToolGuide policy to generate examples for + target_tool: Specific tool name to generate examples for + + Returns: + Tuple of (violating_examples, compliance_examples) + + Raises: + RuntimeError: If manager not initialized + ValueError: If policy is not a ToolGuide or target_tool not in policy.target_tools + """ + self._ensure_initialized() + self._validate_policy_and_tool(policy, target_tool) + + logger.info(f"Generating examples for tool '{target_tool}'...") + + # Create ToolGuardSpecItem with policy information + spec_item = self._create_spec_item(policy) + + # Create ToolGuardSpec with the spec item + spec = ToolGuardSpec( + tool_name=target_tool, + policy_items=[spec_item] + ) + + # Generate examples using toolguard + with self._temp_directory() as tmp_dir: + try: + updated_specs = await generate_guard_examples( + tools=self.tools_dict, # Pass the OpenAPI dict + tool_specs=[spec], + llm=self.llm, # type: ignore + work_dir=str(tmp_dir) + ) + + # Extract examples from the updated spec + if updated_specs: + updated_spec = updated_specs[0] + if updated_spec.policy_items: + policy_item = updated_spec.policy_items[0] + + violating_examples = policy_item.violation_examples + compliance_examples = policy_item.compliance_examples + + logger.info( + f"✅ Generated {len(violating_examples)} violating and {len(compliance_examples)} " + f"compliance examples for tool '{target_tool}'" + ) + + return violating_examples, compliance_examples + else: + logger.warning(f"No policy items in updated spec for tool '{target_tool}'") + return [], [] + else: + logger.warning(f"No results returned for tool '{target_tool}'") + return [], [] + + except asyncio.CancelledError: + raise + except Exception as e: + logger.error( + f"❌ Failed to generate examples for tool '{target_tool}': {e}" + ) + raise RuntimeError(f"Failed to generate examples for tool '{target_tool}'") from e + + async def generate_guard_code( + self, + policy: ToolGuide, + target_tool: str, + app_name: str = "cuga_app" + ) -> str: + """ + Generate guard code for a specific tool in a ToolGuide policy. + + This method creates a ToolGuardSpec from the policy, validates it has examples, + calls toolguard's generate_guards_code, saves the RuntimeDomain to a file, + and returns the generated guard code content. + + Args: + policy: ToolGuide policy to generate guard code for + target_tool: Specific tool name to generate guard code for + app_name: Application name for the generated code (default: "cuga_app") + + Returns: + String containing the generated guard code + + Raises: + RuntimeError: If manager not initialized + ValueError: If policy is not a ToolGuide, target_tool not in policy.target_tools, + if the policy doesn't have examples for the target tool, + or if app_name contains unsafe characters + """ + self._ensure_initialized() + self._validate_policy_and_tool(policy, target_tool) + + # Validate app_name to prevent path traversal attacks + app_name = self._validate_app_name(app_name) + + logger.info(f"Generating guard code for tool '{target_tool}'...") + + # Check if policy has tool_guards for this specific tool + tool_guard = None + if policy.tool_guards and target_tool in policy.tool_guards: + tool_guard = policy.tool_guards[target_tool] + + # Validate that we have examples (either from tool_guards or need to generate them first) + if tool_guard: + violating_examples = tool_guard.violating_examples + compliance_examples = tool_guard.compliance_examples + else: + violating_examples = [] + compliance_examples = [] + + # Ensure we have examples + if not violating_examples and not compliance_examples: + raise ValueError( + f"Policy for tool '{target_tool}' must have examples before generating guard code. " + f"Call generate_examples() first to create examples, or provide them in the policy's tool_guards." + ) + + # Create ToolGuardSpecItem with policy information and examples + spec_item = self._create_spec_item( + policy, + violating_examples=violating_examples, + compliance_examples=compliance_examples + ) + + # Create ToolGuardSpec with the spec item + spec = ToolGuardSpec( + tool_name=target_tool, + policy_items=[spec_item] + ) + + # Generate guard code using toolguard + with self._temp_directory() as tmp_dir: + try: + result: ToolGuardsCodeGenerationResult = await generate_guards_code( + tools=self.tools_dict, # Pass the OpenAPI dict + tool_specs=[spec], + work_dir=str(tmp_dir), + llm=self.llm, # type: ignore + app_name=app_name + ) + + # Save RuntimeDomain files directly under toolguard directory (not in tmp) + self._save_domain_files(result) + + # Extract guard code from the result + if target_tool in result.tools: + tool_result = result.tools[target_tool] + + # Get the item guard file content (should be only one) + if not tool_result.item_guard_files: + raise ValueError( + f"No item guard files generated for tool '{target_tool}'" + ) + + if len(tool_result.item_guard_files) > 1: + logger.warning( + f"Multiple item guard files found for tool '{target_tool}', using the first one" + ) + + item_guard_file = tool_result.item_guard_files[0] + if item_guard_file is None: + raise ValueError( + f"Item guard file is None for tool '{target_tool}'" + ) + + guard_code = item_guard_file.content + + logger.info( + f"✅ Generated guard code for tool '{target_tool}' " + f"(guard function: {tool_result.guard_fn_name})" + ) + + return guard_code + else: + raise ValueError( + f"Tool '{target_tool}' not found in generation results. " + f"Available tools: {list(result.tools.keys())}" + ) + + except asyncio.CancelledError: + raise + except Exception as e: + logger.error( + f"❌ Failed to generate guard code for tool '{target_tool}': {e}" + ) + raise RuntimeError(f"Failed to generate guard code for tool '{target_tool}'") from e + + @property + def is_initialized(self) -> bool: + """Check if the manager has been initialized.""" + return self._initialized + + + diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py new file mode 100644 index 00000000..e96699e3 --- /dev/null +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py @@ -0,0 +1,839 @@ +""" +Runtime execution of tool guards for policy enforcement. + +This module provides runtime validation of tool calls against registered +ToolGuide policies with policy_code. +""" + +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Dict, List, Optional, Sequence, Tuple + +from loguru import logger +from toolguard.runtime.data_types import ( + FileTwin, + PolicyViolationException, + RuntimeDomain, + ToolGuardCodeResult, + ToolGuardsCodeGenerationResult, + ToolGuardSpec, + ToolGuardSpecItem, +) +from toolguard.runtime.runtime import load_toolguards_from_memory + +from cuga.backend.cuga_graph.policy.models import Policy, PolicyType, ToolGuide +from cuga.backend.cuga_graph.policy.storage import PolicyStorage +from cuga.backend.cuga_graph.policy.tool_guard.tool_invoker import ToolGuardInvoker + + +class ToolGuardRuntime: + """ + Runtime system for executing tool guards during tool invocation. + + This class: + 1. Manages policy storage lifecycle (connect/disconnect) + 2. Initializes a ToolGuardInvoker for tool execution + 3. Loads all ToolGuide policies with policy_code + 4. Creates a mapping: tool_name -> List[ToolGuide with code] + 5. Prebuilds umbrella guard modules per tool + 6. Executes guard validation through toolguard runtime + """ + + def __init__( + self, + tool_provider, + enable_policies: bool = False, + policy_storage: Optional[PolicyStorage] = None + ) -> None: + """ + Initialize the ToolGuardRuntime. + + Args: + tool_provider: CUGA's tool provider instance + enable_policies: Whether to enable policy enforcement + policy_storage: Optional PolicyStorage instance (will be created if None and enable_policies=True) + """ + self.tool_provider = tool_provider + self.enable_policies = enable_policies + self.policy_storage = policy_storage + self.invoker = ToolGuardInvoker(tool_provider) + self.tool_to_guards: Dict[str, List[ToolGuide]] = {} + # Per-app runtime mapping to avoid cross-app collisions + self._runtimes_by_app: Dict[str, Any] = {} + self._runtime_domains_by_app: Dict[str, RuntimeDomain] = {} + self._initialized = False + self._policy_storage_owned = False # Track if we created the storage + logger.debug(f"Created ToolGuardRuntime instance (enable_policies={enable_policies})") + + async def initialize(self) -> None: + """ + Initialize the runtime by connecting to policy storage and loading policies. + + This method: + 1. Connects to policy storage if policies are enabled + 2. Fetches all ToolGuide policies from storage + 3. Filters for policies that have tool_guards with policy_code + 4. Builds the tool_to_guards mapping + 5. Per-app runtimes will be lazily loaded on first use + + Raises: + RuntimeError: If policy system is enabled but storage connection fails (fail-closed) + """ + logger.info("Initializing ToolGuardRuntime...") + self._reset_state() + + # Connect to policy storage if policies are enabled + if self.enable_policies: + if self.policy_storage is None: + # Create policy storage if not provided + from cuga.backend.cuga_graph.policy.storage import PolicyStorage + self.policy_storage = PolicyStorage() + self._policy_storage_owned = True + logger.debug("Created PolicyStorage instance") + + # Validate policy_storage has required interface + self._validate_policy_storage() + + try: + await self.policy_storage.connect() + logger.info("✅ Connected policy storage for ToolGuardRuntime") + except Exception as e: + logger.error(f"Failed to connect policy storage: {e}") + # Fail closed: if policy enforcement is enabled but storage fails, + # don't allow the service to start without policy validation + raise RuntimeError( + f"Policy system is enabled but PolicyStorage.connect() failed: {e}" + ) from e + + # Load policies if storage is available + if self.policy_storage is not None: + policies = await self.policy_storage.list_policies( + policy_type=PolicyType.TOOL_GUIDE, enabled_only=True + ) + logger.debug(f"Found {len(policies)} ToolGuide policies") + + # Filter to ensure we only have ToolGuide instances + tool_guide_policies = [p for p in policies if isinstance(p, ToolGuide)] + self._build_tool_to_guards_mapping(tool_guide_policies) + else: + logger.debug("No policy storage available, skipping policy loading") + + self._initialized = True + self._log_initialization_summary() + + def _validate_policy_storage(self) -> None: + """ + Validate that policy_storage has the required interface. + + Raises: + ValueError: If policy_storage doesn't implement required methods + """ + if self.policy_storage is None: + return + + required_methods = ['connect', 'disconnect', 'list_policies', 'get_policy'] + missing_methods = [] + + for method in required_methods: + if not hasattr(self.policy_storage, method): + missing_methods.append(method) + + if missing_methods: + raise ValueError( + f"policy_storage must implement the following methods: {', '.join(missing_methods)}. " + f"Provided object type: {type(self.policy_storage).__name__}" + ) + + logger.debug("✅ Policy storage interface validation passed") + + def _reset_state(self) -> None: + """Reset internal state for reinitialization.""" + # Clean up all per-app runtimes + for app_name, runtime in self._runtimes_by_app.items(): + if runtime is not None: + try: + runtime.__exit__(None, None, None) + except Exception: + logger.exception(f"Error while exiting ToolGuard runtime for app '{app_name}'") + self.tool_to_guards = {} + self._runtimes_by_app = {} + self._runtime_domains_by_app = {} + + def _build_tool_to_guards_mapping(self, policies: Sequence[ToolGuide]) -> None: + """ + Build mapping from tool names to their guard policies. + + Args: + policies: Sequence of ToolGuide policies to process + """ + for policy in policies: + if not policy.tool_guards: + logger.debug(f"Policy '{policy.name}' has no tool_guards, skipping") + continue + + self._register_policy_guards(policy) + + def _register_policy_guards(self, policy: ToolGuide) -> None: + """ + Register guards from a policy for all its tools. + + Args: + policy: ToolGuide policy to register + """ + if not policy.tool_guards: + return + + for tool_name, tool_guard in policy.tool_guards.items(): + if not tool_guard.policy_code: + logger.debug( + f"Tool guard for '{tool_name}' in policy '{policy.name}' " + f"has no policy_code, skipping" + ) + continue + + # Validate that policy_code contains at least one async def guard_* function + guard_func_name = self._extract_guard_function_name(tool_guard.policy_code) + if not guard_func_name: + logger.error( + f"Tool guard for '{tool_name}' in policy '{policy.name}' " + f"has policy_code but no valid 'async def guard_*' function found. " + f"Skipping registration to prevent marking tool as guarded without enforcement." + ) + continue + + if tool_name not in self.tool_to_guards: + self.tool_to_guards[tool_name] = [] + + self.tool_to_guards[tool_name].append(policy) + logger.debug( + f"Registered guard for tool '{tool_name}' from policy '{policy.name}' " + f"with guard function '{guard_func_name}'" + ) + + def _log_initialization_summary(self) -> None: + """Log summary of initialization results.""" + logger.info( + f"✅ ToolGuardRuntime initialized with guards for " + f"{len(self.tool_to_guards)} tools" + ) + for tool_name, guards in self.tool_to_guards.items(): + logger.debug( + f" - Tool '{tool_name}': {len(guards)} guard(s) " + f"({', '.join(g.name for g in guards)})" + ) + + def _build_runtime(self, app_name: str): + """ + Build an in-memory ToolGuard runtime from registered guard policies for a specific app. + + Args: + app_name: Name of the application to build runtime for + + Returns: + Runtime instance for the specified app + """ + runtime_domain = self._runtime_domains_by_app.get(app_name) + if runtime_domain is None: + raise RuntimeError(f"ToolGuard runtime domain not loaded for app '{app_name}'") + + file_twins: List[FileTwin] = [ + runtime_domain.app_types, + runtime_domain.app_api, + runtime_domain.app_api_impl, + ] + tools: Dict[str, ToolGuardCodeResult] = {} + + for tool_name, all_guards in self.tool_to_guards.items(): + # Filter guards to only those applicable to this app + guards = [ + guard for guard in all_guards + if guard.target_apps is None or not guard.target_apps or app_name in guard.target_apps + ] + + # Skip this tool if no guards apply to this app + if not guards: + logger.debug( + f"Skipping tool '{tool_name}' for app '{app_name}' - " + f"no applicable guards (out of {len(all_guards)} total)" + ) + continue + + module_name = self._module_name_for_tool(tool_name) + guard_fn_name = self._guard_function_name_for_tool(tool_name) + guard_module_path = Path(*module_name.split(".")).with_suffix(".py") + + module_content = self._build_tool_guard_module( + tool_name=tool_name, + guards=guards, + guard_fn_name=guard_fn_name, + ) + + guard_file = FileTwin( + file_name=guard_module_path, + content=module_content, + ) + file_twins.append(guard_file) + + tools[tool_name] = ToolGuardCodeResult( + tool=ToolGuardSpec( + tool_name=tool_name, + policy_items=[ + ToolGuardSpecItem( + name=policy.name, + description=f"Runtime guard from policy '{policy.name}'", + ) + for policy in guards + ], + ), + guard_fn_name=guard_fn_name, + guard_file=guard_file, + item_guard_files=[], + test_files=[], + ) + + result = ToolGuardsCodeGenerationResult( + out_dir=Path("."), + domain=runtime_domain, + tools=tools, + ) + + runtime = load_toolguards_from_memory(result) + runtime.__enter__() + return runtime + + def _load_runtime_domain(self, app_name: str) -> RuntimeDomain: + """ + Load RuntimeDomain files saved by ToolGuardManager for a specific app. + + Args: + app_name: Name of the application to load domain for + + Returns: + RuntimeDomain with loaded domain files for the specified app + + Raises: + RuntimeError: If domain directory or files are not found + """ + domain_dir = Path.cwd() / ".cuga" / "toolguard" / "domain" + self._validate_domain_directory(domain_dir) + + # Look for the specific app's domain + app_dir = domain_dir / app_name + if not app_dir.exists(): + raise RuntimeError( + f"ToolGuard domain directory not found for app '{app_name}': {app_dir}" + ) + + selected_domain = self._find_complete_domain_for_app(domain_dir, app_name) + + if selected_domain is None: + raise RuntimeError( + f"No complete ToolGuard domain found for app '{app_name}' under {domain_dir}" + ) + + return self._create_runtime_domain(domain_dir, selected_domain) + + def _validate_domain_directory(self, domain_dir: Path) -> None: + """ + Validate that the domain directory exists. + + Args: + domain_dir: Path to domain directory + + Raises: + RuntimeError: If domain directory doesn't exist + """ + if not domain_dir.exists(): + raise RuntimeError( + f"ToolGuard domain directory not found: {domain_dir}. " + "Generate tool guard code first so ToolGuardManager saves the domain files." + ) + + def _get_sorted_app_directories(self, domain_dir: Path) -> List[Path]: + """ + Get app directories sorted by modification time (newest first). + + Args: + domain_dir: Path to domain directory + + Returns: + List of app directory paths + + Raises: + RuntimeError: If no app directories found + """ + app_dirs = sorted( + [path for path in domain_dir.iterdir() if path.is_dir()], + key=lambda path: path.stat().st_mtime, + reverse=True, + ) + if not app_dirs: + raise RuntimeError( + f"No ToolGuard app directories found under {domain_dir}" + ) + return app_dirs + + def _find_complete_domain( + self, domain_dir: Path, app_dirs: List[Path] + ) -> Optional[Tuple[str, Path, Path, Path]]: + """ + Find the first complete domain with all required files. + + Args: + domain_dir: Path to domain directory + app_dirs: List of app directories to search + + Returns: + Tuple of (app_name, types_path, api_path, impl_path) or None + """ + for app_dir in app_dirs: + app_name = app_dir.name + app_types_rel = Path(app_name) / f"{app_name}_types.py" + app_api_rel = Path(app_name) / f"i_{app_name}.py" + app_api_impl_rel = Path(app_name) / f"{app_name}_impl.py" + + candidate_paths = [ + domain_dir / app_types_rel, + domain_dir / app_api_rel, + domain_dir / app_api_impl_rel, + ] + if all(path.exists() for path in candidate_paths): + return (app_name, app_types_rel, app_api_rel, app_api_impl_rel) + + return None + + def _find_complete_domain_for_app( + self, domain_dir: Path, app_name: str + ) -> Optional[Tuple[str, Path, Path, Path]]: + """ + Find complete domain files for a specific app. + + Args: + domain_dir: Path to domain directory + app_name: Name of the app to find domain for + + Returns: + Tuple of (app_name, types_path, api_path, impl_path) or None + """ + app_types_rel = Path(app_name) / f"{app_name}_types.py" + app_api_rel = Path(app_name) / f"i_{app_name}.py" + app_api_impl_rel = Path(app_name) / f"{app_name}_impl.py" + + candidate_paths = [ + domain_dir / app_types_rel, + domain_dir / app_api_rel, + domain_dir / app_api_impl_rel, + ] + if all(path.exists() for path in candidate_paths): + return (app_name, app_types_rel, app_api_rel, app_api_impl_rel) + + return None + + def _create_runtime_domain( + self, domain_dir: Path, selected_domain: Tuple[str, Path, Path, Path] + ) -> RuntimeDomain: + """ + Create RuntimeDomain from selected domain files. + + Args: + domain_dir: Path to domain directory + selected_domain: Tuple of (app_name, types_path, api_path, impl_path) + + Returns: + RuntimeDomain instance + """ + app_name, app_types_rel, app_api_rel, app_api_impl_rel = selected_domain + + api_content = FileTwin.load_from(domain_dir, app_api_rel).content + api_impl_content = FileTwin.load_from(domain_dir, app_api_impl_rel).content + + app_api_class_name = self._extract_class_name( + api_content, f"I{''.join(part.capitalize() for part in app_name.split('_'))}" + ) + app_api_impl_class_name = self._extract_class_name( + api_impl_content, ''.join(part.capitalize() for part in app_name.split('_')) + ) + + return RuntimeDomain( + app_name=app_name, + app_types=FileTwin.load_from(domain_dir, app_types_rel), + app_api_class_name=app_api_class_name, + app_api=FileTwin.load_from(domain_dir, app_api_rel), + app_api_size=0, + app_api_impl_class_name=app_api_impl_class_name, + app_api_impl=FileTwin.load_from(domain_dir, app_api_impl_rel), + ) + + def _extract_class_name(self, content: str, default: str) -> str: + """ + Extract class name from Python source code. + + Args: + content: Python source code + default: Default class name if not found + + Returns: + Extracted or default class name + """ + for line in content.splitlines(): + stripped = line.strip() + if stripped.startswith("class "): + return stripped.split()[1].split("(")[0].rstrip(":") + return default + + def _build_tool_guard_module( + self, + tool_name: str, + guards: List[ToolGuide], + guard_fn_name: str, + ) -> str: + """ + Create a module containing a single umbrella guard function for one tool. + + Args: + tool_name: Name of the tool + guards: List of ToolGuide policies for this tool + guard_fn_name: Name for the umbrella guard function + + Returns: + Generated Python module content as string + """ + guard_blocks: List[str] = [] + guard_calls: List[str] = [] + + for index, policy in enumerate(guards): + self._process_policy_guard( + policy, tool_name, index, guard_blocks, guard_calls + ) + + return self._generate_module_content(guard_fn_name, guard_blocks, guard_calls) + + def _process_policy_guard( + self, + policy: ToolGuide, + tool_name: str, + index: int, + guard_blocks: List[str], + guard_calls: List[str], + ) -> None: + """ + Process a single policy guard and add to blocks and calls. + + Args: + policy: ToolGuide policy to process + tool_name: Name of the tool + index: Index of this guard + guard_blocks: List to append guard code blocks to + guard_calls: List to append guard call statements to + """ + tool_guard = policy.tool_guards.get(tool_name) if policy.tool_guards else None + if not tool_guard or not tool_guard.policy_code: + logger.warning( + f"Policy '{policy.name}' missing tool_guard for '{tool_name}', skipping" + ) + return + + guard_func_name = self._extract_guard_function_name(tool_guard.policy_code) + if not guard_func_name: + logger.warning( + f"Could not find guard function in policy code for '{policy.name}', skipping" + ) + return + + validate_alias = f"_guard_validate_{index}" + + guard_blocks.append( + f"# Policy: {policy.name}\n" + f"{tool_guard.policy_code}\n" + f"# Assign the specific guard function for this policy\n" + f"{validate_alias} = {guard_func_name}\n" + ) + + # Sanitize policy name for safe embedding in generated Python code + policy_name_literal = repr(policy.name) + + guard_calls.extend([ + " try:", + f" await {validate_alias}(api=api, args=args)", + " except PolicyViolationException as e:", + " error_msg = str(e)", + " # Check if error already contains policy name to avoid duplication", + f" _policy_name = {policy_name_literal}", + " _prefix = f\"[{_policy_name}]\"", + " if not error_msg.startswith(_prefix):", + " error_msg = f\"{_prefix} {error_msg}\"", + " violations.append(error_msg)", + ]) + + def _extract_guard_function_name(self, policy_code: str) -> Optional[str]: + """ + Extract guard function name from policy code. + + Args: + policy_code: Generated policy code + + Returns: + Guard function name or None if not found + """ + for line in policy_code.split('\n'): + line = line.strip() + if line.startswith('async def guard_'): + # Extract function name: "async def guard_xxx(..." -> "guard_xxx" + return line.split('(')[0].replace('async def ', '').strip() + return None + + def _generate_module_content( + self, guard_fn_name: str, guard_blocks: List[str], guard_calls: List[str] + ) -> str: + """ + Generate the complete module content. + + Args: + guard_fn_name: Name for the umbrella guard function + guard_blocks: List of guard code blocks + guard_calls: List of guard call statements + + Returns: + Complete module content as string + """ + if not guard_calls: + guard_calls = [" return None"] + else: + guard_calls = [ + " violations = []", + *guard_calls, + " if violations:", + " raise PolicyViolationException(\"\\n\".join(violations))", + ] + + return ( + "from toolguard.runtime.data_types import (\n" + " PolicyViolationException,\n" + " assert_any_condition_met,\n" + ")\n" + "from toolguard.runtime.rules import rule\n\n" + f"{''.join(guard_blocks)}\n" + f"async def {guard_fn_name}(api, args):\n" + f"{chr(10).join(guard_calls)}\n" + ) + + def _module_name_for_tool(self, tool_name: str) -> str: + """ + Convert a tool name to a valid python module name. + + Args: + tool_name: Name of the tool + + Returns: + Valid Python module name + """ + normalized = self._normalize_name(tool_name) + return f"cuga_toolguard_runtime.generated.guard_{normalized}" + + def _guard_function_name_for_tool(self, tool_name: str) -> str: + """ + Convert a tool name to a valid umbrella guard function name. + + Args: + tool_name: Name of the tool + + Returns: + Valid Python function name + """ + normalized = self._normalize_name(tool_name) + return f"guard_{normalized}" + + def _normalize_name(self, name: str) -> str: + """ + Normalize a name to be a valid Python identifier with disambiguation. + + Args: + name: Name to normalize + + Returns: + Normalized name safe for use as Python identifier with hash suffix + """ + import hashlib + + # Create readable normalized portion + normalized = "".join( + ch if ch.isalnum() else "_" for ch in name.lower() + ).strip("_") + + # Use "tool" as base if normalization results in empty string + base = normalized if normalized else "tool" + + # Add short hash suffix for disambiguation + name_hash = hashlib.sha256(name.encode()).hexdigest()[:8] + + return f"{base}_{name_hash}" + + async def _get_or_create_runtime_for_app(self, app_name: str): + """ + Get or lazily create a runtime for the specified app. + + Args: + app_name: Name of the application + + Returns: + Runtime instance for the app, or None if it cannot be created + """ + # Return cached runtime if available + if app_name in self._runtimes_by_app: + return self._runtimes_by_app[app_name] + + # Try to load and build runtime for this app + try: + logger.info(f"Loading runtime domain for app '{app_name}'...") + runtime_domain = self._load_runtime_domain(app_name) + self._runtime_domains_by_app[app_name] = runtime_domain + + logger.info(f"Building runtime for app '{app_name}'...") + runtime = self._build_runtime(app_name) + self._runtimes_by_app[app_name] = runtime + + logger.info(f"✅ Runtime initialized for app '{app_name}'") + return runtime + except Exception as e: + logger.error( + f"Failed to initialize runtime for app '{app_name}': {e}", + exc_info=True + ) + # Cache None to avoid repeated failed attempts + self._runtimes_by_app[app_name] = None + return None + + async def guard_tool_call( + self, + app_name: str, + function_name: str, + arguments: Dict[str, Any] + ) -> Optional[str]: + """ + Validate a tool call against registered guards. + + This method delegates validation to the ToolGuard runtime using a + prebuilt umbrella guard function for the requested tool. + + Args: + app_name: Name of the application calling the tool + function_name: Name of the tool/function being called + arguments: Arguments being passed to the tool + + Returns: + Error message string if validation fails, None if validation passes + """ + if not self._initialized: + logger.warning("ToolGuardRuntime not initialized, skipping validation") + return None + + # Check if this tool has any guards + if function_name not in self.tool_to_guards: + logger.debug(f"No guards registered for tool '{function_name}'") + return None + + # Filter guards to only those applicable to this app + all_guards = self.tool_to_guards[function_name] + guards = [ + guard for guard in all_guards + if guard.target_apps is None or not guard.target_apps or app_name in guard.target_apps + ] + + if not guards: + logger.debug( + f"No guards applicable for tool '{function_name}' on app '{app_name}' " + f"(found {len(all_guards)} guard(s) but none match this app)" + ) + return None + + # Get or create app-specific runtime + runtime = await self._get_or_create_runtime_for_app(app_name) + if runtime is None: + logger.warning( + f"ToolGuard runtime unavailable for app '{app_name}' and tool '{function_name}', " + "skipping validation" + ) + return None + + logger.debug( + f"Validating tool call '{function_name}' for app '{app_name}' against " + f"{len(guards)} applicable guard(s) (out of {len(all_guards)} total) using umbrella runtime" + ) + + try: + args_obj = SimpleNamespace(**arguments) + await runtime.guard_toolcall( + tool_name=function_name, + args=arguments | {"args": args_obj}, + delegate=self.invoker, + ) + except PolicyViolationException as e: + error = str(e) + logger.warning( + f"Tool guard blocked call to '{function_name}' for app '{app_name}': {error}" + ) + return error + except Exception as e: + logger.error( + f"Error executing umbrella guard for tool '{function_name}' in app '{app_name}': {e}", + exc_info=True + ) + # Fail closed: treat internal guard errors as a violation so a buggy + # or malformed guard cannot silently bypass policy enforcement. + return ( + f"Internal guard error for '{function_name}': {e}. " + "Tool call blocked as a safety precaution." + ) + + logger.debug(f"Tool call '{function_name}' for app '{app_name}' passed all guards") + return None + + @property + def is_initialized(self) -> bool: + """Check if the runtime has been initialized.""" + return self._initialized + + def get_guarded_tools(self) -> List[str]: + """ + Get list of tool names that have guards registered. + + Returns: + List of tool names with active guards + """ + return list(self.tool_to_guards.keys()) + + def get_guards_for_tool(self, tool_name: str) -> List[ToolGuide]: + """ + Get all guards registered for a specific tool. + + Args: + tool_name: Name of the tool + + Returns: + List of ToolGuide policies with guards for this tool + """ + return self.tool_to_guards.get(tool_name, []) + + async def shutdown(self) -> None: + """Release in-memory ToolGuard runtime resources and disconnect policy storage.""" + # Clean up all per-app runtimes + for app_name, runtime in self._runtimes_by_app.items(): + if runtime is not None: + try: + runtime.__exit__(None, None, None) + except Exception: + logger.exception(f"Error while shutting down ToolGuard runtime for app '{app_name}'") + self._runtimes_by_app = {} + self._runtime_domains_by_app = {} + + # Disconnect policy storage if we own it + if self.policy_storage is not None and self._policy_storage_owned: + try: + await self.policy_storage.disconnect() + logger.debug("Disconnected policy storage") + except Exception as e: + logger.warning(f"Error disconnecting policy storage during shutdown: {e}") + self.policy_storage = None + + self._initialized = False + logger.debug("ToolGuardRuntime shutdown complete") \ No newline at end of file diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py new file mode 100644 index 00000000..6118a839 --- /dev/null +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py @@ -0,0 +1,141 @@ +""" +ToolGuard invoker for CUGA's tool provider. + +This module provides integration between toolguard's runtime validation +and CUGA's tool provider system. +""" + +import asyncio +from typing import Any, Dict, Optional, Type, TypeVar +from loguru import logger + +from toolguard.runtime import IToolInvoker + +T = TypeVar('T') + + +class ToolGuardInvoker(IToolInvoker): + """ + Tool invoker that uses CUGA's tool provider for executing tools + during toolguard validation. + + This class bridges toolguard's runtime validation with CUGA's + tool execution system, allowing guards to invoke tools for + validation purposes. + + Similar to LangchainToolInvoker and MCPToolInvoker from the toolguard + library, but adapted to work with CUGA's tool provider. + """ + + def __init__(self, tool_provider): + """ + Initialize the ToolGuardInvoker. + + Args: + tool_provider: CUGA's tool provider instance that manages + and executes tools + """ + self.tool_provider = tool_provider + self._tools_cache: Optional[Dict[str, Any]] = None + logger.debug("Initialized ToolGuardInvoker with CUGA tool provider") + + async def _get_tools(self) -> Dict[str, Any]: + """ + Get all available tools from the tool provider. + + Returns: + Dictionary mapping tool names to tool instances + + Raises: + ValueError: If duplicate tool names are detected + """ + if self._tools_cache is None: + tools_list = await self.tool_provider.get_all_tools() + + # Check for duplicate tool names before building cache + tools_map: Dict[str, Any] = {} + for tool in tools_list: + if tool.name in tools_map: + raise ValueError( + f"Duplicate tool name detected: '{tool.name}'. " + f"Tool names must be unique across all providers to ensure " + f"correct routing of guards to tools." + ) + tools_map[tool.name] = tool + + self._tools_cache = tools_map + logger.debug(f"Cached {len(self._tools_cache)} tools") + return self._tools_cache + + async def invoke( + self, + toolname: str, + arguments: Dict[str, Any], + return_type: Type[T] + ) -> T: + """ + Invoke a tool by name with the given arguments. + + This method is called by toolguard during guard validation + to execute tools and verify their behavior. + + Args: + toolname: Name of the tool to invoke + arguments: Dictionary of arguments to pass to the tool + return_type: Expected return type for the tool invocation + + Returns: + The result of the tool invocation, cast to the expected type + + Raises: + ValueError: If the tool is not found + RuntimeError: If tool invocation fails + """ + try: + # Redact sensitive arguments before logging + arg_summary = { + k: f"<{type(v).__name__}>" if v is not None else None + for k, v in (arguments.items() if isinstance(arguments, dict) else {}) + } + logger.debug(f"Invoking tool '{toolname}' with arg keys: {list(arg_summary.keys())}") + + # Get the tool from the provider + tools = await self._get_tools() + + if toolname not in tools: + available_tools = list(tools.keys()) + raise ValueError( + f"Tool '{toolname}' not found. " + f"Available tools: {available_tools}" + ) + + tool = tools[toolname] + + # Invoke the tool using LangChain's invoke method + # LangChain tools typically accept a single input or dict of inputs + result = await tool.ainvoke(arguments) + + logger.debug(f"Tool '{toolname}' invocation successful") + return result + + except ValueError: + # Re-raise ValueError as-is (tool not found) + raise + except asyncio.CancelledError: + raise + except Exception as e: + logger.error(f"Failed to invoke tool '{toolname}': {e}") + raise RuntimeError( + f"Tool invocation failed for '{toolname}': {str(e)}" + ) from e + + def clear_cache(self) -> None: + """ + Clear the cached tools. + + Call this method if tools are added or removed from the + tool provider after initialization. + """ + self._tools_cache = None + logger.debug("Cleared tools cache") + diff --git a/src/cuga/backend/cuga_graph/policy/utils.py b/src/cuga/backend/cuga_graph/policy/utils.py index 2abf20db..e336180e 100644 --- a/src/cuga/backend/cuga_graph/policy/utils.py +++ b/src/cuga/backend/cuga_graph/policy/utils.py @@ -14,6 +14,7 @@ Policy, PolicyType, ToolGuide, + ToolGuard, ToolApproval, ) from cuga.backend.cuga_graph.policy.storage import PolicyStorage @@ -215,6 +216,19 @@ async def apply_policies_data_to_storage( enabled=policy_data.get("enabled", True), ) elif policy_type == "tool_guide": + raw_tool_guards = policy_data.get("tool_guards") + tool_guards = ( + { + tool_name: ( + guard_config + if isinstance(guard_config, ToolGuard) + else ToolGuard(**guard_config) + ) + for tool_name, guard_config in raw_tool_guards.items() + } + if isinstance(raw_tool_guards, dict) + else None + ) policy = ToolGuide( id=policy_data["id"], name=policy_data["name"], @@ -223,6 +237,7 @@ async def apply_policies_data_to_storage( target_tools=policy_data.get("target_tools", []), target_apps=policy_data.get("target_apps"), guide_content=policy_data.get("guide_content", ""), + tool_guards=tool_guards, prepend=policy_data.get("prepend", False), priority=policy_data.get("priority", 50), enabled=policy_data.get("enabled", True), @@ -423,8 +438,8 @@ async def restore_policies(storage: PolicyStorage, backup_dir: str) -> int: for policy_type in PolicyType: backup_file = backup_path / f"policies_{policy_type.value}.json" if backup_file.exists(): - count = await load_policies_from_json(str(backup_file), storage) - total_count += count + result = await load_policies_from_json(str(backup_file), storage) + total_count += result.get("count", 0) logger.info(f"Restored {total_count} policies from {backup_dir}") return total_count @@ -453,7 +468,8 @@ def validate_policy(policy: Policy) -> tuple[bool, List[str]]: errors.append("Policy name is required") if not policy.description: errors.append("Policy description is required") - if not policy.triggers: + policy_triggers = getattr(policy, "triggers", None) + if policy_triggers is not None and not policy_triggers: errors.append("At least one trigger is required") # Type-specific validation @@ -522,16 +538,18 @@ def format_policy_summary(policy: Policy) -> str: Returns: Formatted summary string """ + policy_triggers = getattr(policy, "triggers", None) or [] + lines = [ f"Policy: {policy.name} ({policy.id})", f"Type: {policy.type}", f"Description: {policy.description}", f"Priority: {policy.priority}", f"Enabled: {'Yes' if policy.enabled else 'No'}", - f"Triggers: {len(policy.triggers)}", + f"Triggers: {len(policy_triggers)}", ] - for i, trigger in enumerate(policy.triggers): + for i, trigger in enumerate(policy_triggers): value = getattr(trigger, 'value', 'N/A') if isinstance(value, list): value_str = ', '.join(value) if value else '[]' diff --git a/src/cuga/backend/server/main.py b/src/cuga/backend/server/main.py index b9c9a8bb..e6d21edd 100644 --- a/src/cuga/backend/server/main.py +++ b/src/cuga/backend/server/main.py @@ -2429,6 +2429,7 @@ async def get_policies_config( frontend_policy["target_tools"] = policy_dict.get("target_tools", []) frontend_policy["target_apps"] = policy_dict.get("target_apps") frontend_policy["guide_content"] = policy_dict.get("guide_content", "") + frontend_policy["tool_guards"] = policy_dict.get("tool_guards") frontend_policy["prepend"] = policy_dict.get("prepend", False) elif policy_dict["type"] == "tool_approval": frontend_policy["required_tools"] = policy_dict.get("required_tools", []) @@ -2458,9 +2459,13 @@ async def get_policies_config( @app.post("/api/config/policies") async def save_policies_config( request: Request, - current_user: Optional[UserInfo] = Depends(require_auth), + current_user: Optional[UserInfo] = Depends(require_manage_access), ): - """Endpoint to save policies configuration. Use draft collection when X-Use-Draft header is set.""" + """Endpoint to save policies configuration. Use draft collection when X-Use-Draft header is set. + + Security: Requires manage access role. Policy code is executed unsandboxed at runtime, + so only trusted administrators should be allowed to modify policies. + """ if not settings.policy.enabled: return JSONResponse( {"status": "error", "message": "Policy system is disabled in settings"}, diff --git a/src/cuga/backend/tools_env/registry/registry/api_registry.py b/src/cuga/backend/tools_env/registry/registry/api_registry.py index 76bb874d..90c19d3f 100644 --- a/src/cuga/backend/tools_env/registry/registry/api_registry.py +++ b/src/cuga/backend/tools_env/registry/registry/api_registry.py @@ -29,12 +29,25 @@ class ApiRegistry: interacting with the mcp manager """ - def __init__(self, client: MCPManager): + def __init__(self, client: MCPManager, enable_policies: bool = False): + """ + Initialize ApiRegistry. + + Args: + client: MCPManager instance for tool management + enable_policies: Whether to enable policy-based tool validation + """ logger.info("ApiRegistry: Initializing.") self.mcp_client = client self.auth_manager = None self.tavily_client = None self._init_tavily_if_enabled() + + # ToolGuardRuntime support for policy-based tool validation + # ToolGuardRuntime now manages its own policy storage lifecycle + self.tool_guard_runtime = None + self._tool_guard_initialized = False + self._enable_policies = enable_policies def _init_tavily_if_enabled(self): """Initialize Tavily client if web search is enabled.""" @@ -60,6 +73,62 @@ async def start_servers(self): """Start servers and load tools""" await self.mcp_client.load_tools() logger.info("ApiRegistry: Servers started successfully.") + + # Initialize ToolGuardRuntime after tools are loaded + await self._initialize_tool_guard_runtime() + + async def _initialize_tool_guard_runtime(self): + """ + Initialize ToolGuardRuntime after tools are loaded. + + This is called after load_tools() to ensure both tools and policies are available. + ToolGuardRuntime now manages its own policy storage lifecycle. + """ + if self._tool_guard_initialized: + return + + self._tool_guard_initialized = True + + if not self._enable_policies: + logger.debug("ToolGuardRuntime: Policy enforcement disabled, skipping initialization") + return + + try: + from cuga.backend.cuga_graph.policy.tool_guard.tool_guard_runtime import ToolGuardRuntime + + # ToolGuardRuntime now manages policy storage internally + # We just pass enable_policies flag and it handles the rest + self.tool_guard_runtime = ToolGuardRuntime( + tool_provider=self.mcp_client, + enable_policies=self._enable_policies + ) + + await self.tool_guard_runtime.initialize() + + guarded_tools = self.tool_guard_runtime.get_guarded_tools() + logger.info( + f"✅ ToolGuardRuntime initialized with guards for {len(guarded_tools)} tools" + ) + if guarded_tools: + logger.debug(f" Guarded tools: {', '.join(guarded_tools)}") + + except Exception as e: + logger.error(f"Failed to initialize ToolGuardRuntime: {e}", exc_info=True) + # Fail closed: if policy enforcement is enabled but ToolGuardRuntime fails, + # don't allow the service to start without policy validation + raise RuntimeError( + f"ToolGuardRuntime failed to initialize. Policy enforcement cannot be bypassed. Error: {e}" + ) from e + + async def cleanup(self): + """Cleanup resources including ToolGuardRuntime.""" + if self.tool_guard_runtime is not None: + try: + await self.tool_guard_runtime.shutdown() + logger.debug("ToolGuardRuntime shutdown complete") + except Exception as e: + logger.warning(f"Error shutting down ToolGuardRuntime: {e}") + self.tool_guard_runtime = None async def show_applications(self) -> List[AppDefinition]: """Lists application names and their descriptions.""" @@ -179,9 +248,50 @@ async def call_function( self, app_name: str, function_name: str, arguments: Dict[str, Any], auth_config=None ) -> Dict[str, Any]: """Calls a function via the mcp_client.""" + + # Use arguments as-is - do not unwrap 'params' unconditionally + # Transport-layer wrappers should be normalized at the request boundary, + # not here where it affects both guards and tools + unwrapped_args = arguments + + # Validate tool call against ToolGuard policies + if self.tool_guard_runtime and self.tool_guard_runtime.is_initialized: + try: + error_message = await self.tool_guard_runtime.guard_tool_call( + app_name=app_name, + function_name=function_name, + arguments=unwrapped_args if isinstance(unwrapped_args, dict) else {} + ) + + if error_message: + # Guard validation failed - return error without executing tool + logger.warning( + f"🛡️ Tool guard blocked call to '{function_name}': {error_message}" + ) + return { + "status": "exception", + "status_code": 403, + "message": f"Tool guard policy violation: {error_message}", + "error_type": "ToolGuardViolation", + "function_name": function_name, + } + else: + logger.debug(f"✅ Tool guard validation passed for '{function_name}'") + except Exception as e: + # Fail-closed: treat exceptions from guard_tool_call as policy violations + # to honor ToolGuardRuntime's fail-closed contract + error_msg = f"Exception during tool guard execution for '{function_name}': {e}" + logger.error(error_msg, exc_info=True) + return { + "status": "exception", + "status_code": 403, + "message": f"Tool guard policy violation: {error_msg}", + "error_type": "ToolGuardViolation", + "function_name": function_name, + } + if app_name == "web" and function_name == "search_web" and self._is_web_search_enabled(): - args = arguments.get('params', arguments) if isinstance(arguments, dict) else arguments - query = args.get('query') if isinstance(args, dict) else str(args) + query = unwrapped_args.get('query') if isinstance(unwrapped_args, dict) else str(unwrapped_args) if not query: return { "status": "exception", @@ -284,13 +394,12 @@ async def call_function( f"ApiRegistry: call_function(function_name='{function_name}', arguments={arguments}, headers={headers}) called." ) try: - # Delegate the call to the client - args = arguments['params'] if 'params' in arguments else arguments + # Delegate the call to the client (use already-unwrapped args) if self.auth_manager: headers["_tokens"] = json.dumps(self.auth_manager.get_stored_tokens()) result = await self.mcp_client.call_tool( tool_name=function_name, - args=args, + args=unwrapped_args, headers=headers, ) logger.debug("Response:", result) diff --git a/src/cuga/backend/tools_env/registry/registry/api_registry_server.py b/src/cuga/backend/tools_env/registry/registry/api_registry_server.py index 87538221..86887ccc 100644 --- a/src/cuga/backend/tools_env/registry/registry/api_registry_server.py +++ b/src/cuga/backend/tools_env/registry/registry/api_registry_server.py @@ -150,8 +150,24 @@ async def _get_or_create_registry( logger.debug(f"Knowledge MCP transport override skipped: {e}") manager = MCPManager(config=services) - reg = ApiRegistry(client=manager) - await reg.start_servers() + + # ApiRegistry now manages policy storage lifecycle through ToolGuardRuntime + reg = ApiRegistry(client=manager, enable_policies=settings.policy.enabled) + try: + await reg.start_servers() + except Exception: + # Clean up manager if start_servers fails + try: + await manager.shutdown() + except Exception as cleanup_error: + logger.warning(f"Error shutting down manager during cleanup: {cleanup_error}") + + # Clean up registry resources (including ToolGuardRuntime/policy storage) + try: + await reg.cleanup() + except Exception as cleanup_error: + logger.warning(f"Error cleaning up registry during cleanup: {cleanup_error}") + raise agent_registries[agent_id] = (manager, reg) return manager, reg @@ -176,8 +192,24 @@ async def lifespan(app: FastAPI): print(f"Using configuration file: {config_file}") services = load_service_configs(str(config_file)) mcp_manager = MCPManager(config=services) - registry = ApiRegistry(client=mcp_manager) - await registry.start_servers() + + # ApiRegistry now manages policy storage lifecycle through ToolGuardRuntime + registry = ApiRegistry(client=mcp_manager, enable_policies=settings.policy.enabled) + try: + await registry.start_servers() + except Exception: + # Clean up mcp_manager if start_servers fails + try: + await mcp_manager.shutdown() + except Exception as cleanup_error: + logger.warning(f"Error shutting down mcp_manager during cleanup: {cleanup_error}") + + # Clean up registry resources (including ToolGuardRuntime/policy storage) + try: + await registry.cleanup() + except Exception as cleanup_error: + logger.warning(f"Error cleaning up registry during cleanup: {cleanup_error}") + raise yield @@ -185,8 +217,20 @@ async def lifespan(app: FastAPI): for agent_id, (mgr, reg) in agent_registries.items(): logger.info(f"Cleaning up registry for agent: {agent_id}") await mgr.shutdown() - if not database_mode and 'mcp_manager' in globals(): - await mcp_manager.shutdown() + # Cleanup registry resources (including ToolGuardRuntime/policy storage) + try: + await reg.cleanup() + except Exception as e: + logger.warning(f"Error cleaning up registry for agent {agent_id}: {e}") + if not database_mode: + # In YAML mode, also clean up the global registry + if 'registry' in globals() and registry is not None: + try: + await registry.cleanup() + except Exception as e: + logger.warning(f"Error cleaning up global registry: {e}") + if 'mcp_manager' in globals(): + await mcp_manager.shutdown() # --- FastAPI Server Setup --- @@ -483,8 +527,12 @@ async def reload_config( logger.info(f"Reloading from database for agent: {agent_id}") # Clear cache for this agent if agent_id in agent_registries: - old_mgr, _ = agent_registries.pop(agent_id) + old_mgr, old_reg = agent_registries.pop(agent_id) await old_mgr.shutdown() + try: + await old_reg.cleanup() + except Exception as e: + logger.warning(f"Error cleaning up old registry: {e}") # Recreate registry for this agent with retry on empty await _get_or_create_registry(agent_id, retry_on_empty=True) # If this is the default agent, update global registry @@ -509,8 +557,12 @@ async def reload_config( # Reload all agents (clear cache) logger.info("Reloading all agents from database") agent_ids = list(agent_registries.keys()) - for old_mgr, _ in agent_registries.values(): + for old_mgr, old_reg in agent_registries.values(): await old_mgr.shutdown() + try: + await old_reg.cleanup() + except Exception as e: + logger.warning(f"Error cleaning up old registry: {e}") agent_registries.clear() # Recreate default agent mcp_manager, registry = await _get_or_create_registry(default_agent_id) @@ -526,8 +578,32 @@ async def reload_config( logger.info(f"Reloading from file: {config_path}") services = load_service_configs(config_path) new_manager = MCPManager(config=services) - new_registry = ApiRegistry(client=new_manager) - await new_registry.start_servers() + + # ApiRegistry now manages policy storage lifecycle through ToolGuardRuntime + new_registry = ApiRegistry(client=new_manager, enable_policies=settings.policy.enabled) + try: + await new_registry.start_servers() + except Exception: + # Clean up new_manager if start_servers fails + try: + await new_manager.shutdown() + except Exception as cleanup_error: + logger.warning(f"Error shutting down new_manager during cleanup: {cleanup_error}") + + # Clean up new registry resources + try: + await new_registry.cleanup() + except Exception as cleanup_error: + logger.warning(f"Error cleaning up new registry during cleanup: {cleanup_error}") + raise + + # Clean up old registry before replacing + if 'registry' in globals() and registry is not None: + try: + await registry.cleanup() + except Exception as e: + logger.warning(f"Error cleaning up old registry: {e}") + if 'mcp_manager' in globals(): await mcp_manager.shutdown() mcp_manager = new_manager @@ -554,16 +630,24 @@ async def clear_agent_cache( try: if agent_id: if agent_id in agent_registries: - mgr, _ = agent_registries.pop(agent_id) + mgr, reg = agent_registries.pop(agent_id) await mgr.shutdown() + try: + await reg.cleanup() + except Exception as e: + logger.warning(f"Error cleaning up registry: {e}") logger.info(f"Cleared cache for agent: {agent_id}") return {"status": "ok", "message": f"Cache cleared for agent: {agent_id}"} else: return {"status": "ok", "message": f"No cache found for agent: {agent_id}"} else: count = len(agent_registries) - for mgr, _ in agent_registries.values(): + for mgr, reg in agent_registries.values(): await mgr.shutdown() + try: + await reg.cleanup() + except Exception as e: + logger.warning(f"Error cleaning up registry: {e}") agent_registries.clear() logger.info(f"Cleared cache for {count} agents") return {"status": "ok", "message": f"Cache cleared for {count} agents"} diff --git a/src/cuga/sdk.py b/src/cuga/sdk.py index 94df8879..1887712d 100644 --- a/src/cuga/sdk.py +++ b/src/cuga/sdk.py @@ -68,7 +68,7 @@ def delete_database(table: str) -> str: ``` """ -from typing import List, Optional, Dict, Any, Union, TYPE_CHECKING +from typing import List, Optional, Dict, Any, Union, TYPE_CHECKING, Tuple import uuid from loguru import logger from pydantic import BaseModel, Field @@ -107,6 +107,7 @@ def delete_database(table: str) -> str: IntentGuard, Playbook, ToolGuide, + ToolGuard, ToolApproval, OutputFormatter, KeywordTrigger, @@ -535,6 +536,107 @@ async def add_tool_guide( logger.info(f"Added Tool Guide policy: {policy.id}") return policy.id + async def update_tool_guard( + self, + policy_id: str, + tool_guards: Dict[str, Dict[str, Any]], + ) -> str: + """ + Update an existing Tool Guide policy with tool_guards. + + Args: + policy_id: ID of the existing Tool Guide policy to update + tool_guards: Dict of tool guards (key: tool_name, value: dict with 'violating_examples', 'compliance_examples', 'policy_code') + + Returns: + Policy ID + + Raises: + ValueError: If policy not found or not a ToolGuide type + + Example: + ```python + await agent.policies.update_tool_guard( + policy_id="tool_guide_abc123", + tool_guards={ + "delete_file": { + "description": "Guard rules for file deletion", + "violating_examples": ["Delete system files"], + "compliance_examples": ["Delete user files with confirmation"], + "policy_code": "" + } + } + ) + ``` + """ + policy_system = await self._ensure_policy_system() + if policy_system is None: + logger.warning("Policy system is disabled - skipping update_tool_guard") + return None + + # Retrieve the existing policy + existing_policy = await policy_system.storage.get_policy(policy_id) + if existing_policy is None: + raise ValueError(f"Policy with ID '{policy_id}' not found") + + # Verify it's a ToolGuide policy + if not isinstance(existing_policy, ToolGuide): + raise ValueError( + f"Policy '{policy_id}' is not a ToolGuide policy (type: {type(existing_policy).__name__})" + ) + + # Merge with existing tool_guards to preserve guards for other tools + tool_guards_obj = dict(existing_policy.tool_guards or {}) + + # Validate that tool_guards keys are in target_tools + target_tools_set = set(existing_policy.target_tools or []) + invalid_tools = set(tool_guards.keys()) - target_tools_set + + if invalid_tools: + raise ValueError( + f"Invalid tool names in tool_guards: {', '.join(sorted(invalid_tools))}. " + f"Must be one of: {', '.join(sorted(target_tools_set))}" + ) + + # Convert and update only the incoming tool_guards + for tool_name, guard_data in tool_guards.items(): + tool_guards_obj[tool_name] = ToolGuard( + violating_examples=guard_data.get("violating_examples", []), + compliance_examples=guard_data.get("compliance_examples", []), + policy_code=guard_data.get("policy_code", ""), + ) + + # Create updated policy with tool_guards + updated_policy = ToolGuide( + id=existing_policy.id, + name=existing_policy.name, + description=existing_policy.description, + triggers=existing_policy.triggers, + target_tools=existing_policy.target_tools, + target_apps=existing_policy.target_apps, + guide_content=existing_policy.guide_content, + tool_guards=tool_guards_obj, + prepend=existing_policy.prepend, + priority=existing_policy.priority, + enabled=existing_policy.enabled, + metadata=existing_policy.metadata, + ) + + # Update in storage + await policy_system.storage.update_policy(updated_policy) + await policy_system.initialize() # Reload policies + + # Save to filesystem if sync is enabled + if self._fs_sync: + try: + self._fs_sync.save_policy_to_file(updated_policy) + logger.debug(f"Saved updated policy '{policy_id}' to filesystem") + except Exception as e: + logger.warning(f"Failed to save updated policy to filesystem: {e}") + + logger.info(f"Updated Tool Guide policy '{policy_id}' with tool_guards") + return policy_id + async def add_tool_approval( self, name: str, @@ -1113,11 +1215,11 @@ async def sync_from_filesystem(self) -> Dict[str, Any]: policy_system = await self._ensure_policy_system() if policy_system is None: logger.warning("Policy system is disabled - skipping sync") - return {"loaded": 0, "removed": 0, "errors": ["Policy system is disabled"]} + return {"loaded": 0, "removed": 0, "errors": ["Policy system is disabled"], "files": []} if not self._fs_sync: logger.warning("Filesystem sync not initialized - skipping") - return {"loaded": 0, "removed": 0, "errors": ["Filesystem sync not initialized"]} + return {"loaded": 0, "removed": 0, "errors": ["Filesystem sync not initialized"], "files": []} try: # Load policies from filesystem @@ -1134,7 +1236,190 @@ async def sync_from_filesystem(self) -> Dict[str, Any]: } except Exception as e: logger.error(f"Failed to sync from filesystem: {e}") - return {"loaded": 0, "removed": 0, "errors": [str(e)]} + return { + "loaded": 0, + "removed": 0, + "errors": [str(e)], + "files": [], + } + + async def generate_tool_guard_examples( + self, + policy_id: str, + target_tool: str + ) -> Tuple[List[str], List[str]]: + """ + Generate violating and compliance examples for a specific tool in a policy. + + This method uses the ToolGuardManager to generate examples that demonstrate + both violations and compliance with the policy guidelines for a specific tool. + + Args: + policy_id: The ID of the policy to generate examples for + target_tool: The specific tool name to generate examples for + + Returns: + Tuple of (violating_examples, compliance_examples) + + Raises: + ValueError: If policy not found, not a ToolGuide, or target_tool not in policy + RuntimeError: If ToolGuardManager initialization fails + + Example: + ```python + agent = CugaAgent(tools=[delete_file]) + + # Add a tool guide policy + policy_id = await agent.policies.add_tool_guide( + name="File Safety", + target_tools=["delete_file"], + content="Never delete system files" + ) + + # Generate examples + violating, compliance = await agent.policies.generate_tool_guard_examples( + policy_id=policy_id, + target_tool="delete_file" + ) + + print(f"Violating: {violating}") + print(f"Compliance: {compliance}") + ``` + """ + from cuga.backend.cuga_graph.policy.tool_guard.manager import ToolGuardManager + from cuga.backend.cuga_graph.policy.models import PolicyType + + # Ensure policy system is initialized + policy_system = await self._ensure_policy_system() + if policy_system is None: + raise RuntimeError("Policy system is disabled") + + # Get the policy + policy_data = await self.get(policy_id) + if policy_data is None: + raise ValueError(f"Policy with ID '{policy_id}' not found") + + policy = policy_data.get('policy') + if policy is None: + raise ValueError(f"Could not retrieve policy object for ID '{policy_id}'") + + # Validate policy type + if policy.type != PolicyType.TOOL_GUIDE: + raise ValueError( + f"Policy must be of type 'tool_guide', got '{policy.type}'. " + f"Only tool_guide policies can generate examples." + ) + + # Create and initialize ToolGuardManager + manager = ToolGuardManager(self._agent) + await manager.initialize() + + # Generate examples using the manager + violating_examples, compliance_examples = await manager.generate_examples( + policy=policy, + target_tool=target_tool + ) + + return violating_examples, compliance_examples + + async def generate_tool_guard_code( + self, + policy_id: str, + target_tool: str, + app_name: str = "cuga_app" + ) -> str: + """ + Generate guard code for a specific tool in a policy. + + This method uses the ToolGuardManager to generate executable guard code + that validates tool usage compliance with the policy guidelines. + + Args: + policy_id: The ID of the policy to generate guard code for + target_tool: The specific tool name to generate guard code for + app_name: Application name for the generated code (default: "cuga_app") + + Returns: + String containing the generated guard code + + Raises: + ValueError: If policy not found, not a ToolGuide, target_tool not in policy, + or if the policy doesn't have examples for the target tool + RuntimeError: If ToolGuardManager initialization fails + + Example: + ```python + agent = CugaAgent(tools=[delete_file]) + + # Add a tool guide policy with examples + policy_id = await agent.policies.add_tool_guide( + name="File Safety", + target_tools=["delete_file"], + content="Never delete system files" + ) + + # Generate examples first + violating, compliance = await agent.policies.generate_tool_guard_examples( + policy_id=policy_id, + target_tool="delete_file" + ) + + # Update policy with examples + await agent.policies.update_tool_guard( + policy_id=policy_id, + tool_guards={ + "delete_file": { + "violating_examples": violating, + "compliance_examples": compliance + } + } + ) + + # Generate guard code + guard_code = await agent.policies.generate_tool_guard_code( + policy_id=policy_id, + target_tool="delete_file" + ) + + print(f"Generated guard code:\n{guard_code}") + ``` + """ + from cuga.backend.cuga_graph.policy.tool_guard.manager import ToolGuardManager + from cuga.backend.cuga_graph.policy.models import PolicyType + + # Ensure policy system is initialized + policy_system = await self._ensure_policy_system() + if policy_system is None: + raise RuntimeError("Policy system is disabled") + + # Get the policy + policy_data = await self.get(policy_id) + if policy_data is None: + raise ValueError(f"Policy with ID '{policy_id}' not found") + + policy = policy_data.get('policy') + if policy is None: + raise ValueError(f"Could not retrieve policy object for ID '{policy_id}'") + + # Validate policy type + if policy.type != PolicyType.TOOL_GUIDE: + raise ValueError( + f"Policy must be of type 'tool_guide', got '{policy.type}'. " + f"Only tool_guide policies can generate guard code." + ) + + # Create and initialize ToolGuardManager + manager = ToolGuardManager(self._agent) + await manager.initialize() + + # Generate guard code using the manager + guard_code = await manager.generate_guard_code( + policy=policy, + target_tool=target_tool, + app_name=app_name + ) + + return guard_code class CugaAgent: diff --git a/src/cuga/sdk_core/debug_crm_finance_policy_e2e.py b/src/cuga/sdk_core/debug_crm_finance_policy_e2e.py new file mode 100644 index 00000000..bf2a3c58 --- /dev/null +++ b/src/cuga/sdk_core/debug_crm_finance_policy_e2e.py @@ -0,0 +1,318 @@ +""" +Debug script for creating and testing CRM finance eligibility tool guard policy. + +This script demonstrates: +1. Creating a CugaAgent with CRM tools from the registry +2. Adding a finance eligibility policy with tool guards +3. Generating examples and guard code for the policy +4. Testing policy enforcement with ToolGuardRuntime + +Prerequisites: +- CRM API server must be running on port 8007 +- Run: cuga start demo_crm (or just the CRM API server) + +Note: This script uses an in-memory policy system for testing. +It does NOT persist policies to the running demo_crm server. +""" + +import asyncio +import os +import tempfile +from cuga import CugaAgent +from cuga.backend.cuga_graph.policy.tool_guard import ToolGuardRuntime +from cuga.backend.cuga_graph.nodes.cuga_lite.combined_tool_provider import CombinedToolProvider + +# Define policy to create +POLICY_CONFIG = { + "name": "Finance eligibility revenue requirements", + "content": """## Finance Industry Revenue Requirements + +### Policy Rules +- Accounts cannot be created for companies from the Finance industry with annual revenue under $100,000 +- This ensures we only onboard financially stable finance companies +- Companies from other industries have no revenue restrictions + +### Validation Requirements +- Always check the industry field before account creation +- If industry is "Finance", verify annual_revenue >= 100000 +- Reject account creation that violates revenue requirements +- Provide clear error messages when restrictions apply +""", + "description": "Accounts cannot be created for companies from the Finance industry with annual revenue under $100,000.", +} + + +async def create_and_process_policy(agent): + """Create policy and generate examples and guard code.""" + print("\nStep 1: Creating and processing policy...") + print("="*60) + + print(f"\n--- Processing Policy: {POLICY_CONFIG['name']} ---") + + # Create policy + print(f"Creating policy...") + policy_id = await agent.policies.add_tool_guide( + name=POLICY_CONFIG["name"], + content=POLICY_CONFIG["content"], + target_tools=["crm_create_account_accounts_post"], + description=POLICY_CONFIG["description"], + ) + print(f"✅ Created policy with ID: {policy_id}") + + # Generate examples + print(f"Generating examples...") + violating_examples, compliance_examples = await agent.policies.generate_tool_guard_examples( + policy_id=policy_id, + target_tool="crm_create_account_accounts_post" + ) + print(f"✅ Generated {len(violating_examples)} violating and {len(compliance_examples)} compliance examples") + + # Update policy with examples + await agent.policies.update_tool_guard( + policy_id=policy_id, + tool_guards={ + "crm_create_account_accounts_post": { + "description": f"Guard rules for {POLICY_CONFIG['name']}", + "violating_examples": violating_examples, + "compliance_examples": compliance_examples, + "policy_code": "" + } + } + ) + + # Generate guard code + print(f"Generating guard code...") + guard_code = await agent.policies.generate_tool_guard_code( + policy_id=policy_id, + target_tool="crm_create_account_accounts_post", + app_name="crm_demo" + ) + print(f"✅ Generated guard code ({len(guard_code)} characters)") + + # Update policy with guard code + await agent.policies.update_tool_guard( + policy_id=policy_id, + tool_guards={ + "crm_create_account_accounts_post": { + "description": f"Guard rules for {POLICY_CONFIG['name']}", + "violating_examples": violating_examples, + "compliance_examples": compliance_examples, + "policy_code": guard_code + } + } + ) + + # Retrieve policy + policy_tool_guide = await agent.policies.get(policy_id) + if policy_tool_guide is None: + raise ValueError(f"Failed to retrieve policy {policy_id}") + + policy = policy_tool_guide["policy"] + + print(f"✅ Policy created and processed successfully") + print("="*60) + + return { + "id": policy_id, + "name": POLICY_CONFIG["name"], + "policy": policy + } + + +async def run_tests(tool_guard_runtime, agent): + """Run test cases to validate policy enforcement.""" + print(f"\n{'='*60}") + print("Step 2: Testing policy enforcement with ToolGuardRuntime") + print(f"{'='*60}") + + print(f"\nRuntime initialized with guards for: {tool_guard_runtime.get_guarded_tools()}") + + test_cases = [ + { + "name": "Test Case 1: Finance with Low Revenue (BLOCKED)", + "args": { + "name": "ACM22 Corporation", + "website": "acm22corporation.com", + "phone": "+1-555-1883", + "address": "94 rue du Gue Jacquet", + "city": "Chatou", + "state": "Île-de-France", + "country": "France", + "region": "Europe", + "annual_revenue": 50000, + "employee_count": 88, + "industry": "Finance" + }, + "expected": "BLOCKED", + "reason": "Finance industry with revenue $50,000 < $100,000" + }, + { + "name": "Test Case 2: Finance with High Revenue (ALLOWED)", + "args": { + "name": "Global Finance Corp", + "website": "globalfinance.com", + "phone": "+1-555-2000", + "address": "123 Wall Street", + "city": "New York", + "state": "NY", + "country": "USA", + "region": "North America", + "annual_revenue": 1500000, + "employee_count": 250, + "industry": "Finance" + }, + "expected": "ALLOWED", + "reason": "Finance industry with revenue $1,500,000 >= $100,000" + }, + { + "name": "Test Case 3: Non-Finance with Low Revenue (ALLOWED)", + "args": { + "name": "Small Law Firm", + "website": "smalllawfirm.com", + "phone": "+1-555-3000", + "address": "456 Main Street", + "city": "Boston", + "state": "MA", + "country": "USA", + "region": "North America", + "annual_revenue": 50000, + "employee_count": 10, + "industry": "Law" + }, + "expected": "ALLOWED", + "reason": "Law industry has no revenue restrictions" + } + ] + + results = [] + for test in test_cases: + print(f"\n--- {test['name']} ---") + print(f"Attempting: crm_create_account_accounts_post(") + for k, v in test['args'].items(): + print(f" {k}={repr(v)},") + print(")") + print(f"Expected: {test['expected']} ({test['reason']})") + + try: + error = await tool_guard_runtime.guard_tool_call( + app_name="crm_demo", + function_name="crm_create_account_accounts_post", + arguments=test["args"] + ) + + actual = "BLOCKED" if error else "ALLOWED" + success = actual == test["expected"] + + if success: + print(f"\n✅ SUCCESS: Tool call was correctly {actual}!") + if error: + print(f"Error message: {error}") + else: + print(f"\n⚠️ WARNING: Tool call was {actual} (expected {test['expected']})") + if error: + print(f"Error message: {error}") + + results.append({"test": test["name"], "success": success, "actual": actual}) + + except Exception as e: + print(f"\n❌ Error during validation: {type(e).__name__}: {e}") + results.append({"test": test["name"], "success": False, "actual": "ERROR"}) + + # Print summary + print(f"\n{'='*60}") + print("Test Summary:") + print(f"{'='*60}") + passed = sum(1 for r in results if r["success"]) + print(f"Passed: {passed}/{len(results)}") + for r in results: + status = "✅" if r["success"] else "❌" + print(f" {status} {r['test']}: {r['actual']}") + print(f"{'='*60}") + + return results + + +async def main(): + """Main workflow for creating and testing CRM finance eligibility tool guard policy.""" + + print("="*60) + print("Initializing CRM Tool Provider...") + print("="*60) + + # Create tool provider with CRM app + tool_provider = CombinedToolProvider(app_names=["crm"]) + await tool_provider.initialize() + + # Verify CRM tools are available + tools = await tool_provider.get_tools(app_name="crm") + crm_tool = next((t for t in tools if t.name == "crm_create_account_accounts_post"), None) + + if not crm_tool: + print("❌ ERROR: CRM tool 'crm_create_account_accounts_post' not found!") + print("Available tools:", [t.name for t in tools]) + print("\nMake sure the CRM API server is running:") + print(" cuga start demo_crm") + print(" OR") + print(" cd src/cuga/demo_tools/crm && python -m crm_api.main") + return + + print(f"✅ Found CRM tool: {crm_tool.name}") + print(f"✅ Total tools available: {len(tools)}") + + # Create agent with CRM tools and in-memory policy system + print("\n" + "="*60) + print("Initializing CugaAgent with in-memory policy system...") + print("="*60) + + # Use a temporary database for this test + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_db: + temp_db_path = tmp_db.name + + try: + agent = CugaAgent(tool_provider=tool_provider) + + # Initialize policy system with temporary database + from cuga.backend.cuga_graph.policy.configurable import PolicyConfigurable + + agent._policy_system = PolicyConfigurable() + await agent._policy_system.initialize(policy_db_path=temp_db_path) + + print(f"✅ Using temporary policy database: {temp_db_path}") + print("✅ Policy system initialized") + + # Get policy system + policy_system = agent._policy_system + if policy_system is None or policy_system.storage is None: + raise ValueError("Policy system storage is not available") + + # Step 1: Create and process policy + policy_data = await create_and_process_policy(agent) + + # Step 2: Initialize runtime and run tests + print(f"\n{'='*60}") + print("Initializing ToolGuardRuntime...") + print(f"{'='*60}") + + tool_guard_runtime = ToolGuardRuntime( + tool_provider=agent.tool_provider, + policy_storage=policy_system.storage + ) + await tool_guard_runtime.initialize() + print("✅ ToolGuardRuntime initialized") + + results = await run_tests(tool_guard_runtime, agent) + + print(f"\n{'='*60}") + print("✅ E2E Test completed successfully!") + print(f"{'='*60}") + + finally: + # Clean up temporary database + if os.path.exists(temp_db_path): + os.unlink(temp_db_path) + print(f"\n🧹 Cleaned up temporary database: {temp_db_path}") + + +if __name__ == "__main__": + asyncio.run(main()) + diff --git a/src/cuga/sdk_core/debug_tool_guard_e2e.py b/src/cuga/sdk_core/debug_tool_guard_e2e.py new file mode 100644 index 00000000..091475d2 --- /dev/null +++ b/src/cuga/sdk_core/debug_tool_guard_e2e.py @@ -0,0 +1,385 @@ +""" +Debug script for creating tool guard policies with code generation and E2E testing. + +This script demonstrates: +1. Creating a CugaAgent with flight booking tools +2. Adding multiple tool guide policies +3. Generating examples and guard code for each policy +4. Testing policies with ToolGuardRuntime +5. Cleaning up test policies at the end + +Configuration: +- Set DELETE_ALL_POLICIES_AT_START = True to delete all existing policies before running +- Set DELETE_ALL_POLICIES_AT_START = False to preserve existing policies (default) +- Set environment variable CUGA_E2E_ALLOW_DESTRUCTIVE=true to enable destructive cleanup +""" + +import asyncio +import os +from pathlib import Path + +from langchain_core.tools import tool + +from cuga import CugaAgent +from cuga.backend.cuga_graph.policy.models import ToolGuide +from cuga.backend.cuga_graph.policy.tool_guard import ToolGuardRuntime + +# ============================================================================ +# CONFIGURATION +# ============================================================================ +# Default to False for safety - require explicit opt-in for destructive operations +DELETE_ALL_POLICIES_AT_START = os.environ.get("CUGA_E2E_ALLOW_DESTRUCTIVE", "").lower() in ("true", "1", "yes") +# ============================================================================ + +# Define policies to create +POLICIES = [ + { + "name": "Flight Booking Membership Policy", + "content": """## Flight Booking Restrictions by Membership Level + +### Policy Rules +- Customers with "regular" membership cannot book a flight with more than 3 passengers +- Gold and silver members have no passenger restrictions +- This policy ensures fair resource allocation and encourages membership upgrades + +### Validation Requirements +- Always check user membership level before booking +- Reject bookings that violate passenger limits +- Provide clear error messages when restrictions apply +""", + "description": "Membership-based restrictions for flight bookings to ensure fair resource allocation", + }, + { + "name": "Flight ID Format Policy", + "content": """## Flight ID Format Requirements + +### Policy Rules +- Flight ID must start with exactly 2 letters +- Flight ID must have a total of exactly 4 characters (2 letters + 2 digits) +- Example valid flight IDs: FL12, AB99, XY01 +- Example invalid flight IDs: F123 (only 1 letter), FLI2 (3 letters), FL1 (only 3 characters total) + +### Validation Requirements +- Always validate flight ID format before booking +- Reject bookings with invalid flight ID format +- Provide clear error messages when format is incorrect +""", + "description": "Flight ID format validation to ensure proper booking system compatibility", + }, +] + + +@tool +def book_flight(user_id: str, flight_id: str, passengers: int) -> str: + """Book a flight for a user with specified number of passengers""" + return f"Flight {flight_id} booked for user {user_id} with {passengers} passengers" + + +@tool +def get_membership(user_id: str) -> str: + """Get the membership level of a user (gold, silver, or regular)""" + memberships = { + "user123": "gold", + "user456": "silver", + "user789": "regular" + } + return memberships.get(user_id, "regular") + + +async def cleanup_all_policies(agent): + """Clean up all existing policies if configured.""" + print("="*60) + print("Step 0: Cleaning up ALL existing policies") + print("="*60) + + policy_system = await agent.policies._ensure_policy_system() + if policy_system is None or policy_system.storage is None: + raise ValueError("Policy system storage is not available") + + await policy_system.initialize() + + # Delete from storage + all_policies = await policy_system.storage.list_policies() + print(f"Found {len(all_policies)} total policies in storage") + + for policy in all_policies: + await policy_system.storage.delete_policy(policy.id) + print(f" Deleted from storage: '{policy.name}' (ID: {policy.id})") + + # Delete from filesystem + if agent.policies._fs_sync: + print("\nCleaning up policy files from filesystem...") + cuga_folder = Path(agent.policies._fs_sync.cuga_folder) + if cuga_folder.exists(): + policy_subfolders = ['playbooks', 'output_formatters', 'tool_guides', + 'intent_guards', 'tool_approvals', 'policies'] + + total_deleted = 0 + for subfolder in policy_subfolders: + subfolder_path = cuga_folder / subfolder + if subfolder_path.exists(): + files = list(subfolder_path.glob("*.md")) + list(subfolder_path.glob("*.json")) + for file in files: + file.unlink() + total_deleted += 1 + + if total_deleted > 0: + print(f"✅ Deleted {total_deleted} policy files from filesystem") + + print("✅ All policies successfully deleted") + print("="*60) + + +async def create_and_process_policies(agent, policy_system): + """Create policies and generate examples and guard code for each.""" + print("\nStep 1: Creating and processing policies...") + print("="*60) + + policy_data = [] + + for idx, policy_config in enumerate(POLICIES, 1): + print(f"\n--- Processing Policy {idx}/{len(POLICIES)}: {policy_config['name']} ---") + + # Create policy + print(f"Creating policy...") + policy_id = await agent.policies.add_tool_guide( + name=policy_config["name"], + content=policy_config["content"], + target_tools=["book_flight"], + description=policy_config["description"], + ) + print(f"✅ Created policy with ID: {policy_id}") + + # Generate examples + print(f"Generating examples...") + violating_examples, compliance_examples = await agent.policies.generate_tool_guard_examples( + policy_id=policy_id, + target_tool="book_flight" + ) + print(f"✅ Generated {len(violating_examples)} violating and {len(compliance_examples)} compliance examples") + + # Update policy with examples + await agent.policies.update_tool_guard( + policy_id=policy_id, + tool_guards={ + "book_flight": { + "description": f"Guard rules for {policy_config['name']}", + "violating_examples": violating_examples, + "compliance_examples": compliance_examples, + "policy_code": "" + } + } + ) + + # Generate guard code + print(f"Generating guard code...") + guard_code = await agent.policies.generate_tool_guard_code( + policy_id=policy_id, + target_tool="book_flight", + app_name="test_app" + ) + print(f"✅ Generated guard code ({len(guard_code)} characters)") + #print(f"✅ Code:\n{guard_code} ") + + # Update policy with guard code + await agent.policies.update_tool_guard( + policy_id=policy_id, + tool_guards={ + "book_flight": { + "description": f"Guard rules for {policy_config['name']}", + "violating_examples": violating_examples, + "compliance_examples": compliance_examples, + "policy_code": guard_code + } + } + ) + + # Save policy + policy_tool_guide = await agent.policies.get(policy_id) + if policy_tool_guide is None: + raise ValueError(f"Failed to retrieve policy {policy_id}") + + policy = policy_tool_guide["policy"] + + + print(f"✅ Policy saved successfully") + + # Store policy data for later use + policy_data.append({ + "id": policy_id, + "name": policy_config["name"], + "policy": policy + }) + + print("\n" + "="*60) + print(f"✅ All {len(POLICIES)} policies created and processed successfully") + print("="*60) + + return policy_data + + +async def run_tests(tool_guard_runtime): + """Run test cases to validate policy enforcement.""" + print(f"\n{'='*60}") + print("Step 2: Testing policy enforcement with ToolGuardRuntime") + print(f"{'='*60}") + + print(f"\nRuntime initialized with guards for: {tool_guard_runtime.get_guarded_tools()}") + + test_cases = [ + { + "name": "Test Case 1: Too Many Passengers", + "args": {"flight_id": "FL12", "user_id": "user789", "passengers": 8}, + "expected": "BLOCKED", + "reason": "user789 is 'regular' member, 8 > 3 passengers" + }, + { + "name": "Test Case 2: Valid Booking", + "args": {"flight_id": "FL45", "user_id": "user789", "passengers": 2}, + "expected": "ALLOWED", + "reason": "user789 is 'regular' member, 2 <= 3 passengers, valid flight ID" + }, + { + "name": "Test Case 3: Gold Member", + "args": {"flight_id": "AB78", "user_id": "user123", "passengers": 10}, + "expected": "ALLOWED", + "reason": "user123 is 'gold' member, no passenger limit" + }, + { + "name": "Test Case 4: Multiple Violations", + "args": {"flight_id": "F123", "user_id": "user789", "passengers": 8}, + "expected": "BLOCKED", + "reason": "8 > 3 passengers AND flight_id 'F123' has only 1 letter" + }, + { + "name": "Test Case 5: Invalid Flight ID Only", + "args": {"flight_id": "ABC1", "user_id": "user789", "passengers": 2}, + "expected": "BLOCKED", + "reason": "flight_id 'ABC1' has 3 letters instead of 2" + }, + ] + + results = [] + for test in test_cases: + print(f"\n--- {test['name']} ---") + print(f"Attempting: book_flight({', '.join(f'{k}={repr(v)}' for k, v in test['args'].items())})") + print(f"Expected: {test['expected']} ({test['reason']})") + + try: + error = await tool_guard_runtime.guard_tool_call( + app_name="test_app", + function_name="book_flight", + arguments=test["args"] + ) + + actual = "BLOCKED" if error else "ALLOWED" + success = actual == test["expected"] + + if success: + print(f"\n✅ SUCCESS: Tool call was correctly {actual}!") + if error: + print(f"Error message: {error}") + else: + # Actually invoke the tool + result = await book_flight.ainvoke(test["args"]) + print(f"Tool result: {result}") + else: + print(f"\n⚠️ WARNING: Tool call was {actual} (expected {test['expected']})") + if error: + print(f"Error message: {error}") + + results.append({"test": test["name"], "success": success, "actual": actual}) + + except Exception as e: + print(f"\n❌ Error during validation: {type(e).__name__}: {e}") + results.append({"test": test["name"], "success": False, "actual": "ERROR"}) + + # Print summary + print(f"\n{'='*60}") + print("Test Summary:") + print(f"{'='*60}") + passed = sum(1 for r in results if r["success"]) + print(f"Passed: {passed}/{len(results)}") + for r in results: + status = "✅" if r["success"] else "❌" + print(f" {status} {r['test']}: {r['actual']}") + print(f"{'='*60}") + + return results + + +async def cleanup_policies(agent, policy_system, policy_data): + """Delete all created policies.""" + print(f"\n{'='*60}") + print("Step 3: Cleaning up test policies") + print(f"{'='*60}") + + try: + for policy_info in policy_data: + # Delete from storage + await policy_system.storage.delete_policy(policy_info["id"]) + print(f"✅ Deleted '{policy_info['name']}' from storage") + + # Delete from filesystem + if agent.policies._fs_sync: + cuga_folder = Path(agent.policies._fs_sync.cuga_folder) + tool_guides_folder = cuga_folder / "tool_guides" + + if tool_guides_folder.exists(): + policy_files = list(tool_guides_folder.glob(f"*{policy_info['id']}*.md")) + \ + list(tool_guides_folder.glob(f"*{policy_info['id']}*.json")) + + for policy_file in policy_files: + policy_file.unlink() + print(f"✅ Deleted file: {policy_file.name}") + + print("✅ All test policies successfully deleted") + + except Exception as e: + print(f"⚠️ Error during cleanup: {type(e).__name__}: {e}") + + print(f"{'='*60}") + + +async def main(): + """Main workflow for creating and testing tool guard policies.""" + + # Step 0: Optional cleanup + agent = CugaAgent(tools=[book_flight, get_membership]) + + if DELETE_ALL_POLICIES_AT_START: + await cleanup_all_policies(agent) + else: + print("="*60) + print("Skipping initial cleanup (DELETE_ALL_POLICIES_AT_START=False)") + print("="*60) + + # Get policy system + policy_system = await agent.policies._ensure_policy_system() + if policy_system is None or policy_system.storage is None: + raise ValueError("Policy system storage is not available") + await policy_system.initialize() + + # Step 1: Create and process policies + policy_data = await create_and_process_policies(agent, policy_system) + + # Step 2: Initialize runtime and run tests + tool_guard_runtime = ToolGuardRuntime( + tool_provider=agent.tool_provider, + policy_storage=policy_system.storage + ) + await tool_guard_runtime.initialize() + + results = await run_tests(tool_guard_runtime) + + # Step 3: Cleanup + await cleanup_policies(agent, policy_system, policy_data) + + print(f"\n{'='*60}") + print("✅ E2E Test completed successfully!") + print(f"{'='*60}") + + +if __name__ == "__main__": + asyncio.run(main()) + diff --git a/src/cuga/sdk_core/tests/policies-export-2025-12-31.json b/src/cuga/sdk_core/tests/policies-export-2025-12-31.json deleted file mode 100644 index 0118f213..00000000 --- a/src/cuga/sdk_core/tests/policies-export-2025-12-31.json +++ /dev/null @@ -1,132 +0,0 @@ -{ - "enablePolicies": true, - "policies": [ - { - "id": "guard_1767135201517", - "name": "What is ALTK", - "description": "Blocks or modifies specific user intents", - "policy_type": "intent_guard", - "enabled": true, - "triggers": [ - { - "type": "natural_language", - "value": ["What is ALTK"], - "target": "intent", - "threshold": 0.7 - } - ], - "priority": 50, - "intent_examples": [], - "response": { - "response_type": "natural_language", - "content": "ALTK is sister project of CUGA", - "status_code": null - }, - "allow_override": false - }, - { - "id": "guard_1767136121417", - "name": "Dangerous actions", - "description": "Prevents dangerous actions that can cause cuga to generate dangrous code", - "policy_type": "intent_guard", - "enabled": true, - "triggers": [ - { - "type": "natural_language", - "value": ["print en variables", "do it 10000 times", "do this loop forever"], - "target": "intent", - "threshold": 0.7 - } - ], - "priority": 50, - "intent_examples": [], - "response": { - "response_type": "natural_language", - "content": "I caught you doing something bad with CUGA!!! oops", - "status_code": null - }, - "allow_override": false - }, - { - "id": "guard_1767136203456", - "name": "Remove", - "description": "Blocks or modifies specific user intents", - "policy_type": "intent_guard", - "enabled": true, - "triggers": [ - { - "type": "keyword", - "value": ["remove", "delete"], - "target": "intent", - "case_sensitive": false, - "operator": "or" - } - ], - "priority": 50, - "intent_examples": [], - "response": { - "response_type": "natural_language", - "content": "This action is not allowed. (delete or remove)", - "status_code": null - }, - "allow_override": false - }, - { - "id": "playbook_1767135234400", - "name": "What is CUGA", - "description": "Step-by-step guidance for a task", - "policy_type": "playbook", - "enabled": true, - "triggers": [ - { - "type": "natural_language", - "value": ["What is CUGA"], - "target": "intent", - "threshold": 0.7 - } - ], - "priority": 50, - "markdown_content": "# Task Guide\n\n## Steps:\n\n1. Answer that cuga is very powerful!", - "steps": [ - { - "step_number": 1, - "instruction": "First step", - "expected_outcome": "Step 1 complete", - "tools_allowed": [] - } - ], - "inject_as_system_prompt": true - }, - { - "id": "tool_approval_1767136273134", - "name": "get accounts approval", - "description": "Require approval before executing specific tools", - "policy_type": "tool_approval", - "enabled": true, - "triggers": [], - "priority": 50, - "required_tools": ["digital_sales_get_my_accounts_my_accounts_get"], - "required_apps": null, - "approval_message": "This tool requires your approval before execution.", - "show_code_preview": true, - "auto_approve_after": null - }, - { - "id": "tool_guide_1767136239089", - "name": "Bobo tool", - "description": "Add additional context to tool descriptions", - "policy_type": "tool_guide", - "enabled": true, - "triggers": [ - { - "type": "always" - } - ], - "priority": 50, - "target_tools": ["digital_sales_get_my_accounts_my_accounts_get"], - "target_apps": null, - "guide_content": "## Additional Guidelines\n\n- Use this tool and get first two accounts when user says 'BoBo'", - "prepend": true - } - ] -} diff --git a/src/cuga/sdk_core/tests/test_sdk_policies.py b/src/cuga/sdk_core/tests/test_sdk_policies.py index 1f591db3..ced00b9b 100644 --- a/src/cuga/sdk_core/tests/test_sdk_policies.py +++ b/src/cuga/sdk_core/tests/test_sdk_policies.py @@ -428,6 +428,7 @@ async def test_tool_guide_wildcard(self): assert policy_id is not None policy_details = await agent.policies.get(policy_id) + assert policy_details is not None assert policy_details["policy"].target_tools == ["*"]