From 1a7606357ea0ad709e68521603377f3f5a88b0d1 Mon Sep 17 00:00:00 2001 From: naamaz Date: Mon, 13 Apr 2026 14:55:24 +0300 Subject: [PATCH 01/24] add fields --- .../backend/cuga_graph/policy/__init__.py | 8 + .../cuga_graph/policy/filesystem_sync.py | 11 +- .../cuga_graph/policy/folder_loader.py | 25 ++- src/cuga/backend/cuga_graph/policy/models.py | 18 ++ .../policy/tests/test_filesystem_sync.py | 52 ++++++ src/cuga/backend/cuga_graph/policy/utils.py | 28 +++- src/cuga/backend/server/main.py | 1 + src/cuga/sdk.py | 91 +++++++++++ src/cuga/sdk_core/tests/debug_tool_guide.py | 154 ++++++++++++++++++ 9 files changed, 373 insertions(+), 15 deletions(-) create mode 100644 src/cuga/sdk_core/tests/debug_tool_guide.py diff --git a/src/cuga/backend/cuga_graph/policy/__init__.py b/src/cuga/backend/cuga_graph/policy/__init__.py index 3dace391..50a224ab 100644 --- a/src/cuga/backend/cuga_graph/policy/__init__.py +++ b/src/cuga/backend/cuga_graph/policy/__init__.py @@ -8,6 +8,10 @@ PolicyType, Playbook, IntentGuard, + ToolGuide, + ToolGuard, + ToolApproval, + OutputFormatter, CustomPolicy, PlaybookStep, IntentGuardResponse, @@ -31,6 +35,10 @@ "PolicyType", "Playbook", "IntentGuard", + "ToolGuide", + "ToolGuard", + "ToolApproval", + "OutputFormatter", "CustomPolicy", "PlaybookStep", "IntentGuardResponse", diff --git a/src/cuga/backend/cuga_graph/policy/filesystem_sync.py b/src/cuga/backend/cuga_graph/policy/filesystem_sync.py index a290b29f..fb2801a3 100644 --- a/src/cuga/backend/cuga_graph/policy/filesystem_sync.py +++ b/src/cuga/backend/cuga_graph/policy/filesystem_sync.py @@ -91,7 +91,7 @@ def _policy_to_markdown(self, policy: Policy) -> str: # Build frontmatter # Get policy type - try 'type' first, then 'policy_type' policy_type = getattr(policy, 'type', None) or getattr(policy, 'policy_type', None) - policy_type_value = policy_type.value if hasattr(policy_type, 'value') else str(policy_type) + policy_type_value = policy_type.value if policy_type is not None and hasattr(policy_type, 'value') else str(policy_type) frontmatter = { 'id': policy.id, @@ -103,10 +103,11 @@ def _policy_to_markdown(self, policy: Policy) -> str: } # Add triggers if present - if hasattr(policy, 'triggers') and policy.triggers: + policy_triggers = getattr(policy, 'triggers', None) + if policy_triggers: triggers_config = {} - for trigger in policy.triggers: + for trigger in policy_triggers: if isinstance(trigger, KeywordTrigger): triggers_config['keywords'] = trigger.value triggers_config['target'] = trigger.target @@ -133,6 +134,10 @@ def _policy_to_markdown(self, policy: Policy) -> str: frontmatter['target_tools'] = policy.target_tools if policy.target_apps: frontmatter['target_apps'] = policy.target_apps + if policy.tool_guards: + frontmatter['tool_guards'] = { + tool_name: guard.model_dump() for tool_name, guard in policy.tool_guards.items() + } frontmatter['prepend'] = policy.prepend content = policy.guide_content or "" elif isinstance(policy, IntentGuard): diff --git a/src/cuga/backend/cuga_graph/policy/folder_loader.py b/src/cuga/backend/cuga_graph/policy/folder_loader.py index a88b1287..636d57bc 100644 --- a/src/cuga/backend/cuga_graph/policy/folder_loader.py +++ b/src/cuga/backend/cuga_graph/policy/folder_loader.py @@ -7,7 +7,7 @@ import os from pathlib import Path -from typing import Dict, Any, List +from typing import Dict, Any, List, cast import yaml from loguru import logger @@ -15,12 +15,14 @@ Playbook, OutputFormatter, ToolGuide, + ToolGuard, IntentGuard, ToolApproval, KeywordTrigger, NaturalLanguageTrigger, AlwaysTrigger, IntentGuardResponse, + Trigger, ) @@ -123,7 +125,7 @@ def create_playbook_from_markdown( raise ValueError(f"Playbook in {file_path} missing 'name' in frontmatter") triggers_config = frontmatter.get('triggers', {}) - triggers = create_triggers_from_metadata(triggers_config) + triggers = cast(List[Trigger], create_triggers_from_metadata(triggers_config)) if not triggers: raise ValueError(f"Playbook {name} must have at least one trigger") @@ -160,10 +162,10 @@ def create_output_formatter_from_markdown( raise ValueError(f"OutputFormatter in {file_path} missing 'name' in frontmatter") triggers_config = frontmatter.get('triggers', {}) - triggers = create_triggers_from_metadata(triggers_config) + triggers = cast(List[Trigger], create_triggers_from_metadata(triggers_config)) if not triggers: - triggers = [AlwaysTrigger()] + triggers = cast(List[Trigger], [AlwaysTrigger()]) format_type = frontmatter.get('format_type', 'markdown') if format_type not in ['markdown', 'json_schema', 'direct']: @@ -205,10 +207,17 @@ def create_tool_guide_from_markdown( raise ValueError(f"ToolGuide {name} must specify 'target_tools'") triggers_config = frontmatter.get('triggers', {}) - triggers = create_triggers_from_metadata(triggers_config) + triggers = cast(List[Trigger], create_triggers_from_metadata(triggers_config)) if not triggers: - triggers = [AlwaysTrigger()] + triggers = cast(List[Trigger], [AlwaysTrigger()]) + + raw_tool_guards = frontmatter.get('tool_guards') + tool_guards = ( + {tool_name: ToolGuard(**guard_config) for tool_name, guard_config in raw_tool_guards.items()} + if isinstance(raw_tool_guards, dict) + else None + ) return ToolGuide( id=frontmatter.get('id', f"tool_guide_{Path(file_path).stem}"), @@ -218,6 +227,7 @@ def create_tool_guide_from_markdown( target_tools=target_tools, target_apps=frontmatter.get('target_apps'), guide_content=content, + tool_guards=tool_guards, prepend=frontmatter.get('prepend', False), priority=frontmatter.get('priority', 50), enabled=frontmatter.get('enabled', True), @@ -244,7 +254,7 @@ def create_intent_guard_from_markdown( raise ValueError(f"IntentGuard in {file_path} missing 'name' in frontmatter") triggers_config = frontmatter.get('triggers', {}) - triggers = create_triggers_from_metadata(triggers_config) + triggers = cast(List[Trigger], create_triggers_from_metadata(triggers_config)) if not triggers: raise ValueError(f"IntentGuard {name} must have at least one trigger") @@ -261,6 +271,7 @@ def create_intent_guard_from_markdown( response=IntentGuardResponse( response_type=response_type, content=content, + status_code=frontmatter.get('status_code'), ), allow_override=frontmatter.get('allow_override', False), priority=frontmatter.get('priority', 50), diff --git a/src/cuga/backend/cuga_graph/policy/models.py b/src/cuga/backend/cuga_graph/policy/models.py index 3c8dc655..538edbd5 100644 --- a/src/cuga/backend/cuga_graph/policy/models.py +++ b/src/cuga/backend/cuga_graph/policy/models.py @@ -165,6 +165,21 @@ class IntentGuardResponse(BaseModel): status_code: Optional[int] = Field(None, description="HTTP status code if applicable") +class ToolGuard(BaseModel): + """Guard configuration for a specific tool with compliance rules.""" + + description: str = Field(..., description="Description of the guard rules for this tool") + violating_examples: List[str] = Field( + default_factory=list, description="Examples of violating usage patterns" + ) + compliance_examples: List[str] = Field( + default_factory=list, description="Examples of compliant usage patterns" + ) + policy_code: str = Field( + default="", description="Python code that validates tool usage compliance" + ) + + class IntentGuard(BaseModel): """Guard that intercepts intents and provides custom responses.""" @@ -214,6 +229,9 @@ class ToolGuide(BaseModel): None, description="List of app names to enrich tools for (optional)" ) guide_content: str = Field(..., description="Markdown content to append to tool descriptions") + tool_guards: Optional[Dict[str, ToolGuard]] = Field( + None, description="Optional guard configurations per tool (key: tool_name, value: ToolGuard)" + ) prepend: bool = Field(False, description="Whether to prepend content instead of appending") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata") priority: int = Field(0, description="Priority when multiple guides match (higher = more important)") diff --git a/src/cuga/backend/cuga_graph/policy/tests/test_filesystem_sync.py b/src/cuga/backend/cuga_graph/policy/tests/test_filesystem_sync.py index 828fc76e..be8c223b 100644 --- a/src/cuga/backend/cuga_graph/policy/tests/test_filesystem_sync.py +++ b/src/cuga/backend/cuga_graph/policy/tests/test_filesystem_sync.py @@ -24,6 +24,7 @@ IntentGuard, Playbook, ToolGuide, + ToolGuard, ToolApproval, OutputFormatter, KeywordTrigger, @@ -164,6 +165,40 @@ async def test_save_policy_to_filesystem( for expected in expected_content: assert expected in content + @pytest.mark.asyncio + async def test_save_tool_guide_with_tool_guards_to_filesystem(self, temp_cuga_folder): + """Test saving tool_guards in ToolGuide frontmatter.""" + fs_sync = PolicyFilesystemSync(cuga_folder=temp_cuga_folder) + policy = ToolGuide( + id="test_guide_with_guards", + name="Guide With Guards", + description="Test tool guide with per-tool guard configuration", + triggers=[AlwaysTrigger()], + target_tools=["test_tool"], + target_apps=None, + guide_content="## Guidelines\n- Be careful", + tool_guards={ + "test_tool": ToolGuard( + description="Only allow safe usage", + violating_examples=["Deleting all records without confirmation"], + compliance_examples=["Delete one record after explicit confirmation"], + policy_code="def validate(call):\n return True", + ) + }, + prepend=False, + priority=0, + enabled=True, + ) + + file_path = fs_sync.save_policy_to_file(policy) + + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + + assert "tool_guards:" in content + assert "test_tool:" in content + assert "Only allow safe usage" in content + @pytest.mark.asyncio async def test_delete_policy_file(self, temp_cuga_folder): """Test deleting a policy file""" @@ -376,7 +411,19 @@ async def test_auto_load_multiple_policy_types(self, temp_cuga_folder): description="Test", triggers=[AlwaysTrigger()], target_tools=["*"], + target_apps=None, guide_content="## Test", + tool_guards={ + "test_tool": ToolGuard( + description="Guard loaded from filesystem", + violating_examples=["bad_example()"], + compliance_examples=["good_example()"], + policy_code="def validate(call):\n return True", + ) + }, + prepend=False, + priority=0, + enabled=True, ) fs_sync.save_policy_to_file(guard) @@ -399,6 +446,11 @@ async def test_auto_load_multiple_policy_types(self, temp_cuga_folder): assert "playbook" in policy_types assert "tool_guide" in policy_types + loaded_guide = next(p for p in policies if p["id"] == "guide_1") + assert loaded_guide["policy"].tool_guards is not None + assert "test_tool" in loaded_guide["policy"].tool_guards + assert loaded_guide["policy"].tool_guards["test_tool"].description == "Guard loaded from filesystem" + @pytest.mark.asyncio async def test_auto_load_disabled(self, temp_cuga_folder): """Test that auto-load can be disabled""" diff --git a/src/cuga/backend/cuga_graph/policy/utils.py b/src/cuga/backend/cuga_graph/policy/utils.py index 2abf20db..e336180e 100644 --- a/src/cuga/backend/cuga_graph/policy/utils.py +++ b/src/cuga/backend/cuga_graph/policy/utils.py @@ -14,6 +14,7 @@ Policy, PolicyType, ToolGuide, + ToolGuard, ToolApproval, ) from cuga.backend.cuga_graph.policy.storage import PolicyStorage @@ -215,6 +216,19 @@ async def apply_policies_data_to_storage( enabled=policy_data.get("enabled", True), ) elif policy_type == "tool_guide": + raw_tool_guards = policy_data.get("tool_guards") + tool_guards = ( + { + tool_name: ( + guard_config + if isinstance(guard_config, ToolGuard) + else ToolGuard(**guard_config) + ) + for tool_name, guard_config in raw_tool_guards.items() + } + if isinstance(raw_tool_guards, dict) + else None + ) policy = ToolGuide( id=policy_data["id"], name=policy_data["name"], @@ -223,6 +237,7 @@ async def apply_policies_data_to_storage( target_tools=policy_data.get("target_tools", []), target_apps=policy_data.get("target_apps"), guide_content=policy_data.get("guide_content", ""), + tool_guards=tool_guards, prepend=policy_data.get("prepend", False), priority=policy_data.get("priority", 50), enabled=policy_data.get("enabled", True), @@ -423,8 +438,8 @@ async def restore_policies(storage: PolicyStorage, backup_dir: str) -> int: for policy_type in PolicyType: backup_file = backup_path / f"policies_{policy_type.value}.json" if backup_file.exists(): - count = await load_policies_from_json(str(backup_file), storage) - total_count += count + result = await load_policies_from_json(str(backup_file), storage) + total_count += result.get("count", 0) logger.info(f"Restored {total_count} policies from {backup_dir}") return total_count @@ -453,7 +468,8 @@ def validate_policy(policy: Policy) -> tuple[bool, List[str]]: errors.append("Policy name is required") if not policy.description: errors.append("Policy description is required") - if not policy.triggers: + policy_triggers = getattr(policy, "triggers", None) + if policy_triggers is not None and not policy_triggers: errors.append("At least one trigger is required") # Type-specific validation @@ -522,16 +538,18 @@ def format_policy_summary(policy: Policy) -> str: Returns: Formatted summary string """ + policy_triggers = getattr(policy, "triggers", None) or [] + lines = [ f"Policy: {policy.name} ({policy.id})", f"Type: {policy.type}", f"Description: {policy.description}", f"Priority: {policy.priority}", f"Enabled: {'Yes' if policy.enabled else 'No'}", - f"Triggers: {len(policy.triggers)}", + f"Triggers: {len(policy_triggers)}", ] - for i, trigger in enumerate(policy.triggers): + for i, trigger in enumerate(policy_triggers): value = getattr(trigger, 'value', 'N/A') if isinstance(value, list): value_str = ', '.join(value) if value else '[]' diff --git a/src/cuga/backend/server/main.py b/src/cuga/backend/server/main.py index b81e8f73..e736debd 100644 --- a/src/cuga/backend/server/main.py +++ b/src/cuga/backend/server/main.py @@ -2365,6 +2365,7 @@ async def get_policies_config( frontend_policy["target_tools"] = policy_dict.get("target_tools", []) frontend_policy["target_apps"] = policy_dict.get("target_apps") frontend_policy["guide_content"] = policy_dict.get("guide_content", "") + frontend_policy["tool_guards"] = policy_dict.get("tool_guards") frontend_policy["prepend"] = policy_dict.get("prepend", False) elif policy_dict["type"] == "tool_approval": frontend_policy["required_tools"] = policy_dict.get("required_tools", []) diff --git a/src/cuga/sdk.py b/src/cuga/sdk.py index 53d56a4f..6a76be26 100644 --- a/src/cuga/sdk.py +++ b/src/cuga/sdk.py @@ -98,6 +98,7 @@ def delete_database(table: str) -> str: IntentGuard, Playbook, ToolGuide, + ToolGuard, ToolApproval, OutputFormatter, KeywordTrigger, @@ -523,6 +524,96 @@ async def add_tool_guide( logger.info(f"Added Tool Guide policy: {policy.id}") return policy.id + async def update_tool_guard( + self, + policy_id: str, + tool_guards: Dict[str, Dict[str, Any]], + ) -> str: + """ + Update an existing Tool Guide policy with tool_guards. + + Args: + policy_id: ID of the existing Tool Guide policy to update + tool_guards: Dict of tool guards (key: tool_name, value: dict with 'description', 'violating_examples', 'compliance_examples', 'policy_code') + + Returns: + Policy ID + + Raises: + ValueError: If policy not found or not a ToolGuide type + + Example: + ```python + await agent.policies.update_tool_guard( + policy_id="tool_guide_abc123", + tool_guards={ + "delete_file": { + "description": "Guard rules for file deletion", + "violating_examples": ["Delete system files"], + "compliance_examples": ["Delete user files with confirmation"], + "policy_code": "" + } + } + ) + ``` + """ + policy_system = await self._ensure_policy_system() + if policy_system is None: + logger.warning("Policy system is disabled - skipping update_tool_guard") + return None + + # Retrieve the existing policy + existing_policy = await policy_system.storage.get_policy(policy_id) + if existing_policy is None: + raise ValueError(f"Policy with ID '{policy_id}' not found") + + # Verify it's a ToolGuide policy + if not isinstance(existing_policy, ToolGuide): + raise ValueError( + f"Policy '{policy_id}' is not a ToolGuide policy (type: {type(existing_policy).__name__})" + ) + + # Convert tool_guards dict to ToolGuard objects + tool_guards_obj = {} + for tool_name, guard_data in tool_guards.items(): + tool_guards_obj[tool_name] = ToolGuard( + description=guard_data.get("description", ""), + violating_examples=guard_data.get("violating_examples", []), + compliance_examples=guard_data.get("compliance_examples", []), + policy_code=guard_data.get("policy_code", ""), + ) + + # Create updated policy with tool_guards + updated_policy = ToolGuide( + id=existing_policy.id, + name=existing_policy.name, + description=existing_policy.description, + triggers=existing_policy.triggers, + target_tools=existing_policy.target_tools, + target_apps=existing_policy.target_apps, + guide_content=existing_policy.guide_content, + tool_guards=tool_guards_obj, + prepend=existing_policy.prepend, + priority=existing_policy.priority, + enabled=existing_policy.enabled, + metadata=existing_policy.metadata, + ) + + # Update in storage + await policy_system.storage.update_policy(updated_policy) + await policy_system.initialize() # Reload policies + + # Save to filesystem if sync is enabled + if self._fs_sync: + try: + self._fs_sync.save_policy_to_file(updated_policy) + logger.debug(f"Saved updated policy '{policy_id}' to filesystem") + except Exception as e: + logger.warning(f"Failed to save updated policy to filesystem: {e}") + + logger.info(f"Updated Tool Guide policy '{policy_id}' with tool_guards") + return policy_id + async def add_tool_approval( self, name: str, diff --git a/src/cuga/sdk_core/tests/debug_tool_guide.py b/src/cuga/sdk_core/tests/debug_tool_guide.py new file mode 100644 index 00000000..c3e1f8ef --- /dev/null +++ b/src/cuga/sdk_core/tests/debug_tool_guide.py @@ -0,0 +1,154 @@ +""" +Debug script for creating, storing, retrieving, and exporting a tool guide policy. + +This script demonstrates: +1. Creating a CugaAgent with a tool +2. Adding a new tool guide policy +3. Saving the policy +4. Retrieving the policy from storage +5. Exporting the policy data to a JSON file +""" + +import asyncio +import json +from pathlib import Path +from datetime import datetime + +from langchain_core.tools import tool + +from cuga import CugaAgent +from cuga.backend.cuga_graph.policy.models import ToolGuide + + +@tool +def delete_file(file_path: str) -> str: + """Delete a file from the system""" + return f"File {file_path} has been deleted" + + +async def main(): + """Main workflow for creating and managing a tool guide policy.""" + + # Step 1: Create a CugaAgent with the delete_file tool + print("Step 1: Creating CugaAgent with delete_file tool...") + agent = CugaAgent(tools=[delete_file]) + + # Step 2: Add a new tool guide policy + print("\nStep 2: Adding new tool guide policy...") + policy_id = await agent.policies.add_tool_guide( + name="File Deletion Safety Policy", + content="## File Deletion Guidelines\n- Never delete system files\n- Always verify file path before deletion\n- Require confirmation for critical files\n- Log all deletion operations", + target_tools=["delete_file"], + description="Safety guidelines for file deletion operations to prevent accidental data loss", + ) + print(f"Created policy with ID: {policy_id}") + + # Step 2b: Update the policy with tool_guards + print("\nStep 2b: Updating policy with tool_guards...") + await agent.policies.update_tool_guard( + policy_id=policy_id, + tool_guards={ + "delete_file": { + "description": "Guard rules for safe file deletion to prevent accidental data loss", + "violating_examples": [ + "Deleting system files like /etc/passwd or C:\\Windows\\System32", + "Deleting files without user confirmation", + "Deleting files in protected directories", + "Bulk deletion without verification" + ], + "compliance_examples": [ + "Verify file path is in user's home directory before deletion", + "Request explicit user confirmation before deleting", + "Check file is not a system file or in protected directory", + "Log all deletion operations with timestamp and user" + ], + "policy_code": "" + } + } + ) + print(f"Updated policy {policy_id} with tool_guards") + + # Step 3: Save the policy + print("\nStep 3: Saving policy...") + policy_tool_guide = await agent.policies.get(policy_id) + if policy_tool_guide is None: + raise ValueError(f"Failed to retrieve policy {policy_id}") + + policy = policy_tool_guide["policy"] + + # Get policy system and update + policy_system = await agent.policies._ensure_policy_system() + if policy_system is None or policy_system.storage is None: + raise ValueError("Policy system storage is not available") + + #await policy_system.storage.update_policy(policy) + await policy_system.initialize() + + # Save to file system if available + if agent.policies._fs_sync: + agent.policies._fs_sync.save_policy_to_file(policy) + + print("Policy saved successfully!") + + # Step 4: Retrieve the policy from storage + print("\nStep 4: Retrieving policy from storage...") + retrieved_policy = await policy_system.storage.get_policy(policy_id) + if retrieved_policy is None: + raise ValueError(f"Failed to retrieve updated policy {policy_id}") + + if not isinstance(retrieved_policy, ToolGuide): + raise TypeError(f"Expected ToolGuide, got {type(retrieved_policy).__name__}") + + print("Policy retrieved successfully!") + + # Step 5: Export to JSON file + print("\nStep 5: Exporting policy data to JSON file...") + + # Convert tool_guards to serializable format + tool_guards_data = None + if retrieved_policy.tool_guards: + tool_guards_data = {} + for tool_name, guard in retrieved_policy.tool_guards.items(): + tool_guards_data[tool_name] = { + "description": guard.description, + "violating_examples": guard.violating_examples, + "compliance_examples": guard.compliance_examples, + "policy_code": guard.policy_code, + } + + output_data = { + "export_timestamp": datetime.utcnow().isoformat(), + "policy_id": policy_id, + "policy_name": retrieved_policy.name, + "policy_description": retrieved_policy.description, + "guide_content": retrieved_policy.guide_content, + "target_tools": retrieved_policy.target_tools, + "target_apps": retrieved_policy.target_apps, + "tool_guards": tool_guards_data, + "prepend": retrieved_policy.prepend, + "priority": retrieved_policy.priority, + "enabled": retrieved_policy.enabled, + "policy_type": type(retrieved_policy).__name__, + "metadata": retrieved_policy.metadata, + } + + # Save to JSON file in current directory + output_path = Path("tool_guide_policy_export.json") + output_path.write_text(json.dumps(output_data, indent=2), encoding="utf-8") + + print(f"\n{'='*60}") + print("SUCCESS! Policy workflow completed:") + print(f"{'='*60}") + print(f"Policy ID: {policy_id}") + print(f"Policy Name: {retrieved_policy.name}") + print(f"Target Tools: {', '.join(retrieved_policy.target_tools)}") + print(f"\nExported to: {output_path.absolute()}") + print(f"{'='*60}") + print("\nExported JSON content:") + print(json.dumps(output_data, indent=2)) + + +if __name__ == "__main__": + asyncio.run(main()) + +# Made with Bob \ No newline at end of file From 855e1d1d69ddffdffa4a4a1406c01daa3aa1474e Mon Sep 17 00:00:00 2001 From: naamaz Date: Tue, 14 Apr 2026 12:47:01 +0300 Subject: [PATCH 02/24] cuga tools 2 oas tool info --- pyproject.toml | 1 + .../cuga_graph/policy/tool_guard/__init__.py | 12 + .../policy/tool_guard/cuga_to_tool_info.py | 483 +++++++++++++ .../policy/tool_guard/example_usage.py | 672 ++++++++++++++++++ 4 files changed, 1168 insertions(+) create mode 100644 src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py create mode 100644 src/cuga/backend/cuga_graph/policy/tool_guard/cuga_to_tool_info.py create mode 100644 src/cuga/backend/cuga_graph/policy/tool_guard/example_usage.py diff --git a/pyproject.toml b/pyproject.toml index b18cec1f..67af385f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dependencies = [ "langchain-text-splitters>=0.3,<0.4", "beautifulsoup4>=4.12", "psycopg[binary]>=3.2", + "toolguard>=0.2.16", ] [project.optional-dependencies] diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py b/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py new file mode 100644 index 00000000..f0a950ab --- /dev/null +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py @@ -0,0 +1,12 @@ +""" +Tool Guard integration for CUGA. + +This module provides integration between CUGA's tool system and Toolguard's +policy enforcement framework. +""" + +from .cuga_to_tool_info import cuga_tools_to_oas + +__all__ = ["cuga_tools_to_oas"] + +# Made with Bob diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_to_tool_info.py b/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_to_tool_info.py new file mode 100644 index 00000000..87bbd56a --- /dev/null +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_to_tool_info.py @@ -0,0 +1,483 @@ +""" +Convert CUGA tools to Toolguard's OpenAPI Specification format. + +This module provides functionality to convert tools from CUGA's ToolProviderInterface +or CombinedToolProvider into Toolguard's OpenAPI format, which can then be converted +to ToolInfo objects for policy enforcement. + +Supports complex return types including: +- Pydantic models (automatically converted to JSON schema) +- List[PydanticModel] (arrays with item schemas) +- Dict[str, Any] (nested object structures) +- Simple types (str, int, bool, etc.) +""" + +from typing import Any, Dict, List, Union, get_origin, get_args +import inspect + +from langchain_core.tools import StructuredTool +from loguru import logger +from pydantic import BaseModel, Field + +from cuga.backend.cuga_graph.nodes.cuga_lite.combined_tool_provider import CombinedToolProvider +from cuga.backend.cuga_graph.nodes.cuga_lite.tool_provider_interface import ToolProviderInterface + +try: + from toolguard.buildtime.utils.open_api import ( + Components, + Info, + MediaType, + OpenAPI, + Operation, + Parameter, + ParameterIn, + PathItem, + RequestBody, + Response, + Server, + ) + from toolguard.buildtime.utils.jschema import JSchema, JSONSchemaTypes +except ImportError as e: + raise ImportError( + "toolguard is required for this module. Install it with: pip install toolguard" + ) from e + + +def _extract_json_schema_from_pydantic(args_schema) -> Dict[str, Any]: + """ + Extract JSON schema from a Pydantic model. + + Args: + args_schema: Pydantic model class + + Returns: + JSON schema dictionary + """ + if args_schema is None: + return {"type": "object", "properties": {}, "required": []} + + try: + # Get the JSON schema from the Pydantic model + schema = args_schema.model_json_schema() + return schema + except Exception as e: + logger.warning(f"Failed to extract JSON schema from args_schema: {e}") + return {"type": "object", "properties": {}, "required": []} + + +def _extract_return_type_schema(tool: StructuredTool) -> Dict[str, Any]: + """ + Extract JSON schema from a tool's return type annotation. + + Supports: + - Pydantic models (BaseModel subclasses) + - List[PydanticModel] + - Dict[str, Any] + - Simple types (str, int, bool, float) + + Args: + tool: StructuredTool instance + + Returns: + JSON schema dictionary representing the return type + """ + try: + # Get the function from the tool + if not hasattr(tool, "func"): + logger.debug(f"Tool {tool.name} has no func attribute") + return {"type": "object", "description": "Tool execution result"} + + func = tool.func + if func is None: + logger.debug(f"Tool {tool.name} func is None") + return {"type": "object", "description": "Tool execution result"} + + # Get the return type annotation + sig = inspect.signature(func) + return_annotation = sig.return_annotation + + if return_annotation is inspect.Signature.empty: + logger.debug(f"Tool {tool.name} has no return type annotation") + return {"type": "object", "description": "Tool execution result"} + + # Handle string annotations (forward references) + if isinstance(return_annotation, str): + logger.debug(f"Tool {tool.name} has string return annotation: {return_annotation}") + return {"type": "object", "description": f"Returns {return_annotation}"} + + # Check if it's a Pydantic model + if inspect.isclass(return_annotation) and issubclass(return_annotation, BaseModel): + logger.debug(f"Tool {tool.name} returns Pydantic model: {return_annotation.__name__}") + schema = return_annotation.model_json_schema() + # Remove $defs if present and inline them + if "$defs" in schema: + defs = schema.pop("$defs") + # Resolve any $ref references + schema = _resolve_schema_refs(schema, defs) + return schema + + # Check if it's a List type + origin = get_origin(return_annotation) + if origin is list or origin is List: + args = get_args(return_annotation) + if args: + item_type = args[0] + # Check if list item is a Pydantic model + if inspect.isclass(item_type) and issubclass(item_type, BaseModel): + logger.debug(f"Tool {tool.name} returns List[{item_type.__name__}]") + item_schema = item_type.model_json_schema() + # Remove $defs and inline + if "$defs" in item_schema: + defs = item_schema.pop("$defs") + item_schema = _resolve_schema_refs(item_schema, defs) + return { + "type": "array", + "items": item_schema, + "description": f"List of {item_type.__name__} objects" + } + else: + # Simple list type + return { + "type": "array", + "items": {"type": _python_type_to_json_type(item_type)}, + "description": f"List of {item_type.__name__ if hasattr(item_type, '__name__') else str(item_type)}" + } + return {"type": "array", "description": "List of items"} + + # Check if it's a Dict type + if origin is dict or origin is Dict: + logger.debug(f"Tool {tool.name} returns Dict") + return { + "type": "object", + "description": "Dictionary with dynamic keys", + "additionalProperties": True + } + + # Handle simple types + simple_type = _python_type_to_json_type(return_annotation) + if simple_type != "object": + logger.debug(f"Tool {tool.name} returns simple type: {simple_type}") + return {"type": simple_type, "description": f"Returns {simple_type}"} + + # Default fallback + logger.debug(f"Tool {tool.name} has unhandled return type: {return_annotation}") + return {"type": "object", "description": "Tool execution result"} + + except Exception as e: + logger.warning(f"Failed to extract return type schema for tool {tool.name}: {e}") + return {"type": "object", "description": "Tool execution result"} + + +def _python_type_to_json_type(python_type) -> str: + """Convert Python type to JSON schema type string.""" + type_mapping = { + str: "string", + int: "integer", + float: "number", + bool: "boolean", + list: "array", + dict: "object", + } + return type_mapping.get(python_type, "object") + + +def _resolve_schema_refs(schema: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, Any]: + """ + Resolve $ref references in a schema by inlining definitions. + + Args: + schema: Schema that may contain $ref + defs: Definitions to resolve references from + + Returns: + Schema with references resolved + """ + if isinstance(schema, dict): + if "$ref" in schema: + # Extract the reference name + ref = schema["$ref"] + if ref.startswith("#/$defs/"): + def_name = ref.split("/")[-1] + if def_name in defs: + # Return the resolved definition (recursively resolve it too) + return _resolve_schema_refs(defs[def_name], defs) + return schema + else: + # Recursively resolve in nested structures + return {k: _resolve_schema_refs(v, defs) for k, v in schema.items()} + elif isinstance(schema, list): + return [_resolve_schema_refs(item, defs) for item in schema] + else: + return schema + + +def _dict_schema_to_jschema(schema_dict: Dict[str, Any]) -> JSchema: + """ + Convert a dictionary JSON schema to a JSchema object. + + Args: + schema_dict: Dictionary containing JSON schema + + Returns: + JSchema object + """ + # Extract basic properties + schema_type = schema_dict.get("type", "object") + description = schema_dict.get("description", "") + + # Map string type to JSONSchemaTypes enum + try: + json_type = JSONSchemaTypes(schema_type) + except ValueError: + json_type = JSONSchemaTypes.object + + # Build JSchema kwargs + jschema_kwargs = { + "type": json_type, + "description": description, + } + + # Handle properties for object types + if "properties" in schema_dict: + properties = {} + for prop_name, prop_schema in schema_dict["properties"].items(): + # Recursively convert nested schemas + if isinstance(prop_schema, dict): + properties[prop_name] = _dict_schema_to_jschema(prop_schema) + else: + properties[prop_name] = prop_schema + jschema_kwargs["properties"] = properties + + # Handle required fields + if "required" in schema_dict: + jschema_kwargs["required"] = schema_dict["required"] + + # Handle array items + if "items" in schema_dict: + items_schema = schema_dict["items"] + if isinstance(items_schema, dict): + jschema_kwargs["items"] = _dict_schema_to_jschema(items_schema) + else: + jschema_kwargs["items"] = items_schema + + # Handle additionalProperties for dynamic objects + if "additionalProperties" in schema_dict: + jschema_kwargs["additionalProperties"] = schema_dict["additionalProperties"] + + return JSchema(**jschema_kwargs) + + +def _convert_tool_to_operation(tool: StructuredTool, app_name: str) -> tuple[str, Operation]: + """ + Convert a single CUGA StructuredTool to an OpenAPI Operation. + + Args: + tool: LangChain StructuredTool instance + app_name: Name of the application/service providing the tool + + Returns: + Tuple of (operation_id, Operation) + """ + # Extract tool metadata + tool_name = tool.name + description = tool.description or "" + + # Get operation_id from tool metadata if available + operation_id = getattr(tool.func, "_operation_id", None) if hasattr(tool, "func") else None + if not operation_id: + operation_id = f"{app_name}_{tool_name}" + + # Extract parameters from the tool's args_schema + parameters: List[Parameter] = [] + request_body: RequestBody | None = None + + if tool.args_schema: + schema = _extract_json_schema_from_pydantic(tool.args_schema) + properties = schema.get("properties", {}) + required_fields = schema.get("required", []) + + # Convert properties to OpenAPI parameters or request body + # For simplicity, we'll use query parameters for simple types + # and request body for complex objects + body_properties = {} + + for param_name, param_schema in properties.items(): + param_type = param_schema.get("type", "string") + param_desc = param_schema.get("description", "") + is_required = param_name in required_fields + + # Simple types go as query parameters + if param_type in ["string", "integer", "number", "boolean"]: + # Map string type to JSONSchemaTypes enum + json_type = JSONSchemaTypes(param_type) + parameters.append( + Parameter( + name=param_name, + description=param_desc, + **{"in": ParameterIn.query}, # Use dict unpacking for 'in' keyword + required=is_required, + schema=JSchema( + type=json_type, + description=param_desc, + ), + ) + ) + else: + # Complex types go in request body + body_properties[param_name] = param_schema + + # Create request body if there are complex parameters + if body_properties: + request_body = RequestBody( + description="Request body parameters", + required=True, + content={ + "application/json": MediaType( + schema=JSchema( + type=JSONSchemaTypes.object, + properties=body_properties, + required=[k for k in body_properties.keys() if k in required_fields], + ) + ) + }, + ) + + # Create response schema by extracting return type information + logger.debug(f"Extracting return type schema for tool: {tool_name}") + return_schema = _extract_return_type_schema(tool) + + # Convert the return schema dict to JSchema + # We need to handle the schema conversion properly + response_jschema = _dict_schema_to_jschema(return_schema) + + # Create response with the extracted schema + responses: Dict[str, Union[Response, Any]] = { + "200": Response( + description="Successful response", + content={ + "application/json": MediaType( + schema=response_jschema + ) + }, + ) + } + + # Create the operation + # Cast parameters list to match expected type + params_list: List[Union[Parameter, Any]] = parameters if parameters else [] + operation = Operation( + operationId=operation_id, + summary=tool_name, + description=description, + parameters=params_list if params_list else None, + requestBody=request_body, + responses=responses, + tags=[app_name], + ) + + return operation_id, operation + + +async def cuga_tools_to_oas( + tool_provider: Union[ToolProviderInterface, CombinedToolProvider], + title: str = "CUGA Tools API", + version: str = "1.0.0", + description: str = "OpenAPI specification for CUGA tools", +) -> OpenAPI: + """ + Convert all tools from a CUGA tool provider to Toolguard's OpenAPI format. + + This function extracts all tools from the provided tool provider and converts + them into a complete OpenAPI specification that can be used with Toolguard + for policy enforcement. + + Args: + tool_provider: CUGA tool provider (ToolProviderInterface or CombinedToolProvider) + title: Title for the OpenAPI specification + version: Version of the API + description: Description of the API + + Returns: + OpenAPI object with all tools converted to operations + + Example: + ```python + from cuga.backend.cuga_graph.nodes.cuga_lite.combined_tool_provider import CombinedToolProvider + from cuga.backend.cuga_graph.policy.tool_guard import cuga_tools_to_oas + + # Create tool provider + provider = CombinedToolProvider() + await provider.initialize() + + # Convert to OpenAPI + oas = await cuga_tools_to_oas(provider) + + # Save to file + oas.save("tools.yaml") + + # Or convert to ToolInfo for Toolguard + from toolguard.buildtime.gen_spec.oas_to_toolinfo import openapi_to_toolinfos + tool_infos = openapi_to_toolinfos(oas) + ``` + """ + # Initialize the tool provider if needed + if not getattr(tool_provider, "initialized", False): + await tool_provider.initialize() + + # Get all apps and their tools + apps = await tool_provider.get_apps() + + # Create OpenAPI structure + paths: Dict[str, Union[PathItem, Any]] = {} + all_tags = [] + + # Process each app + for app in apps: + app_name = app.name + all_tags.append({"name": app_name, "description": app.description or f"Tools from {app_name}"}) + + # Get tools for this app + tools = await tool_provider.get_tools(app_name) + + logger.info(f"Converting {len(tools)} tools from app '{app_name}' to OpenAPI") + + # Convert each tool to an operation + for tool in tools: + try: + operation_id, operation = _convert_tool_to_operation(tool, app_name) + + # Create a path for this tool + # Use the tool name as the path + path = f"/{app_name}/{tool.name}" + + # Create or update PathItem + if path not in paths: + paths[path] = PathItem() + + # Add operation as POST (most tools are actions) + paths[path].post = operation + + logger.debug(f"Converted tool '{tool.name}' to operation '{operation_id}'") + + except Exception as e: + logger.error(f"Failed to convert tool '{tool.name}' from app '{app_name}': {e}") + continue + + # Create the OpenAPI object + openapi = OpenAPI( + openapi="3.1.0", + info=Info( + title=title, + version=version, + description=description, + ), + servers=[Server(url="/")], + paths=paths, + tags=all_tags, + ) + + logger.info(f"Successfully converted {len(paths)} tool paths to OpenAPI specification") + + return openapi + +# Made with Bob diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/example_usage.py b/src/cuga/backend/cuga_graph/policy/tool_guard/example_usage.py new file mode 100644 index 00000000..9faa9aa8 --- /dev/null +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/example_usage.py @@ -0,0 +1,672 @@ +""" +Example usage of the CUGA to Toolguard integration. + +This file demonstrates how to convert CUGA tools to Toolguard's OpenAPI format +and then to ToolInfo objects for policy enforcement. + +Includes examples of: +- Simple return types (str, int) +- Complex return types (Pydantic models, nested dicts, lists of objects) +- How OpenAPI specs reflect these complex structures +""" + +import asyncio +from pathlib import Path +from typing import List, Dict, Any +from datetime import datetime + +from loguru import logger +from langchain_core.tools import tool +from pydantic import BaseModel, Field + +from cuga.backend.cuga_graph.nodes.cuga_lite.direct_langchain_tools_provider import DirectLangChainToolsProvider +from cuga.backend.cuga_graph.policy.tool_guard import cuga_tools_to_oas + + +# ============================================================================ +# Complex Object Models (Pydantic) +# ============================================================================ + +class ContactInfo(BaseModel): + """Contact information model""" + id: int = Field(description="Unique contact identifier") + name: str = Field(description="Full name of the contact") + email: str = Field(description="Email address") + phone: str | None = Field(None, description="Optional phone number") + company: str | None = Field(None, description="Company name") + tags: List[str] = Field(default_factory=list, description="Contact tags") + + +class EmailMetadata(BaseModel): + """Email metadata model""" + sent_at: str = Field(description="Timestamp when email was sent") + message_id: str = Field(description="Unique message identifier") + status: str = Field(description="Delivery status") + + +class EmailResult(BaseModel): + """Email sending result with metadata""" + success: bool = Field(description="Whether email was sent successfully") + recipient: str = Field(description="Email recipient") + subject: str = Field(description="Email subject") + metadata: EmailMetadata = Field(description="Email metadata") + + +class WeatherData(BaseModel): + """Detailed weather information""" + location: str = Field(description="Location name") + temperature: float = Field(description="Current temperature") + units: str = Field(description="Temperature units (celsius/fahrenheit)") + conditions: str = Field(description="Weather conditions") + humidity: int = Field(description="Humidity percentage") + wind_speed: float = Field(description="Wind speed") + forecast: List[Dict[str, Any]] = Field(description="3-day forecast") + + +class SearchResult(BaseModel): + """Search result with pagination""" + query: str = Field(description="Search query used") + total_results: int = Field(description="Total number of results") + page: int = Field(description="Current page number") + results: List[ContactInfo] = Field(description="List of matching contacts") + has_more: bool = Field(description="Whether more results are available") + + +# ============================================================================ +# Simple Tools (String Return Types) +# ============================================================================ + +@tool +def calculate(expression: str) -> str: + """ + Evaluate a mathematical expression. + + Args: + expression: Mathematical expression to evaluate (e.g., "2 + 2", "10 * 5") + + Returns: + Result of the calculation + """ + try: + # Simple eval for demo - in production use a safe math parser + result = eval(expression, {"__builtins__": {}}, {}) + return f"Result: {result}" + except Exception as e: + return f"Error: {str(e)}" + + +# ============================================================================ +# Complex Tools (Pydantic Model Return Types) +# ============================================================================ + +@tool +def search_contacts_advanced(query: str, limit: int = 10, page: int = 1) -> SearchResult: + """ + Search for contacts in the CRM system with advanced filtering. + + This tool returns a complex SearchResult object containing: + - Query metadata (query string, pagination info) + - List of ContactInfo objects with full contact details + - Pagination information + + Args: + query: Search query string to find contacts + limit: Maximum number of results per page (default: 10) + page: Page number for pagination (default: 1) + + Returns: + SearchResult object with matching contacts and metadata + """ + # Simulate search results + contacts = [ + ContactInfo( + id=i, + name=f"Contact {i}", + email=f"contact{i}@example.com", + phone=f"+1-555-000{i}", + company=f"Company {i % 3}", + tags=["customer", "active"] if i % 2 == 0 else ["lead"] + ) + for i in range(1, min(limit + 1, 6)) + ] + + return SearchResult( + query=query, + total_results=50, + page=page, + results=contacts, + has_more=True + ) + + +@tool +def create_contact_advanced( + name: str, + email: str, + phone: str | None = None, + company: str | None = None, + tags: List[str] | None = None +) -> ContactInfo: + """ + Create a new contact in the CRM system. + + This tool returns a ContactInfo Pydantic model with all contact details. + The OpenAPI spec will automatically reflect the structure of ContactInfo. + + Args: + name: Full name of the contact + email: Email address of the contact + phone: Optional phone number + company: Optional company name + tags: Optional list of tags for categorization + + Returns: + ContactInfo object with the created contact details + """ + return ContactInfo( + id=12345, + name=name, + email=email, + phone=phone, + company=company, + tags=tags or [] + ) + + +@tool +def send_email_advanced(to: str, subject: str, body: str) -> EmailResult: + """ + Send an email to a recipient with detailed result tracking. + + This tool returns an EmailResult object containing: + - Success status + - Recipient information + - Email metadata (timestamp, message ID, delivery status) + + Args: + to: Email address of the recipient + subject: Subject line of the email + body: Body content of the email + + Returns: + EmailResult object with sending status and metadata + """ + return EmailResult( + success=True, + recipient=to, + subject=subject, + metadata=EmailMetadata( + sent_at=datetime.now().isoformat(), + message_id=f"msg_{hash(to + subject)}", + status="delivered" + ) + ) + + +@tool +def get_weather_detailed(location: str, units: str = "celsius") -> WeatherData: + """ + Get detailed weather information for a location. + + This tool returns a WeatherData object with comprehensive weather information: + - Current conditions (temperature, humidity, wind) + - 3-day forecast + - Location details + + The OpenAPI spec will show the full nested structure including the forecast array. + + Args: + location: City name or location to get weather for + units: Temperature units - 'celsius' or 'fahrenheit' (default: celsius) + + Returns: + WeatherData object with current conditions and forecast + """ + return WeatherData( + location=location, + temperature=22.5 if units == "celsius" else 72.5, + units=units, + conditions="Partly Cloudy", + humidity=65, + wind_speed=12.5, + forecast=[ + {"day": "Tomorrow", "high": 24, "low": 18, "conditions": "Sunny"}, + {"day": "Day 2", "high": 23, "low": 17, "conditions": "Cloudy"}, + {"day": "Day 3", "high": 21, "low": 16, "conditions": "Rainy"} + ] + ) + + +@tool +def get_contact_list(limit: int = 10) -> List[ContactInfo]: + """ + Get a list of all contacts. + + This tool returns a List of ContactInfo objects. + The OpenAPI spec will represent this as an array of ContactInfo schemas. + + Args: + limit: Maximum number of contacts to return + + Returns: + List of ContactInfo objects + """ + return [ + ContactInfo( + id=i, + name=f"Contact {i}", + email=f"contact{i}@example.com", + phone=f"+1-555-000{i}", + company=f"Company {i % 3}", + tags=["customer"] if i % 2 == 0 else ["lead"] + ) + for i in range(1, min(limit + 1, 11)) + ] + + +@tool +def get_contact_stats() -> Dict[str, Any]: + """ + Get statistics about contacts in the system. + + This tool returns a nested dictionary structure. + The OpenAPI spec will represent this as an object with nested properties. + + Returns: + Dictionary with contact statistics including counts by category and recent activity + """ + return { + "total_contacts": 1250, + "by_category": { + "customers": 800, + "leads": 350, + "partners": 100 + }, + "recent_activity": { + "added_today": 15, + "updated_today": 42, + "emails_sent_today": 127 + }, + "top_companies": [ + {"name": "Acme Corp", "contact_count": 45}, + {"name": "TechStart Inc", "contact_count": 38}, + {"name": "Global Solutions", "contact_count": 32} + ] + } + + +async def example_convert_tools_to_oas(): + """ + Example: Convert CUGA tools to OpenAPI specification. + + This example demonstrates how simple and complex return types are represented + in the generated OpenAPI specification. + """ + logger.info("Creating tool provider with sample tools...") + + # Create a direct tool provider with sample tools (mix of simple and complex) + sample_tools = [ + calculate, # Simple: returns str + search_contacts_advanced, # Complex: returns SearchResult (Pydantic model) + create_contact_advanced, # Complex: returns ContactInfo (Pydantic model) + send_email_advanced, # Complex: returns EmailResult with nested EmailMetadata + get_weather_detailed, # Complex: returns WeatherData with forecast array + get_contact_list, # Complex: returns List[ContactInfo] + get_contact_stats, # Complex: returns nested Dict + ] + + provider = DirectLangChainToolsProvider(tools=sample_tools, app_name="demo_app") + + # Initialize the provider + await provider.initialize() + + logger.info("Converting tools to OpenAPI specification...") + logger.info("Note: Complex return types (Pydantic models, nested dicts, lists) will be") + logger.info(" fully represented in the OpenAPI schema with their complete structure.") + + # Convert all tools to OpenAPI format + oas = await cuga_tools_to_oas( + tool_provider=provider, + title="CUGA Tools API - Complex Objects Demo", + version="1.0.0", + description="OpenAPI specification demonstrating complex return types including Pydantic models, nested objects, and lists", + ) + + # Save to file + output_path = Path("cuga_tools_openapi.yaml") + oas.save(output_path) + logger.info(f"OpenAPI specification saved to {output_path}") + logger.info("Check the file to see how complex objects are represented in the schema!") + + # Also save as JSON + json_path = Path("cuga_tools_openapi.json") + oas.save(json_path) + logger.info(f"OpenAPI specification saved to {json_path}") + + return oas + + +async def example_convert_to_toolinfo(): + """ + Example: Convert CUGA tools to Toolguard ToolInfo objects. + + This demonstrates how complex return types are preserved through the + OpenAPI -> ToolInfo conversion process. + """ + from toolguard.buildtime.gen_spec.oas_to_toolinfo import openapi_to_toolinfos + + logger.info("Creating tool provider with complex return type tools...") + + # Create a direct tool provider with tools that return complex objects + sample_tools = [ + calculate, + search_contacts_advanced, + create_contact_advanced, + send_email_advanced, + get_weather_detailed, + get_contact_list, + get_contact_stats, + ] + + provider = DirectLangChainToolsProvider(tools=sample_tools, app_name="demo_app") + await provider.initialize() + + logger.info("Converting tools to OpenAPI specification...") + + # Convert to OpenAPI + oas = await cuga_tools_to_oas(provider) + + logger.info("Converting OpenAPI to ToolInfo objects...") + + # Convert OpenAPI to ToolInfo objects (for Toolguard policy enforcement) + tool_infos = openapi_to_toolinfos(oas) + + logger.info(f"Converted {len(tool_infos)} tools to ToolInfo objects") + logger.info("Complex return types are preserved in the ToolInfo schema!") + + # Print some information about the tools + for tool_info in tool_infos: + logger.info(f"\nTool: {tool_info.name}") + logger.info(f" Summary: {tool_info.summary}") + logger.info(f" Parameters: {list(tool_info.parameters.keys())}") + # Note: Response schema information is also preserved in the ToolInfo + + return tool_infos + + +async def example_with_multiple_apps(): + """ + Example: Convert tools from multiple app groups with complex return types. + + This shows how different apps can have tools with different complexity levels. + """ + logger.info("Creating tool provider with multiple app groups...") + + # Create CRM tools (complex return types) + crm_tools = [search_contacts_advanced, create_contact_advanced, get_contact_list] + crm_provider = DirectLangChainToolsProvider(tools=crm_tools, app_name="crm") + await crm_provider.initialize() + + # Create email tools (complex return types) + email_tools = [send_email_advanced] + email_provider = DirectLangChainToolsProvider(tools=email_tools, app_name="email") + await email_provider.initialize() + + # Create weather tools (complex return types) + weather_tools = [get_weather_detailed] + weather_provider = DirectLangChainToolsProvider(tools=weather_tools, app_name="weather") + await weather_provider.initialize() + + logger.info("Converting CRM tools to OpenAPI specification...") + + # Convert CRM tools to OpenAPI + crm_oas = await cuga_tools_to_oas( + tool_provider=crm_provider, + title="CRM Tools API", + version="1.0.0", + description="OpenAPI specification for CRM tools with complex ContactInfo and SearchResult models", + ) + + logger.info("Converting Email tools to OpenAPI specification...") + + # Convert Email tools to OpenAPI + email_oas = await cuga_tools_to_oas( + tool_provider=email_provider, + title="Email Tools API", + version="1.0.0", + description="OpenAPI specification for Email tools with EmailResult and nested EmailMetadata", + ) + + logger.info("Converting Weather tools to OpenAPI specification...") + + # Convert Weather tools to OpenAPI + weather_oas = await cuga_tools_to_oas( + tool_provider=weather_provider, + title="Weather Tools API", + version="1.0.0", + description="OpenAPI specification for Weather tools with WeatherData including forecast arrays", + ) + + # Save to files + crm_path = Path("crm_tools_openapi.yaml") + crm_oas.save(crm_path) + logger.info(f"CRM OpenAPI specification saved to {crm_path}") + + email_path = Path("email_tools_openapi.yaml") + email_oas.save(email_path) + logger.info(f"Email OpenAPI specification saved to {email_path}") + + weather_path = Path("weather_tools_openapi.yaml") + weather_oas.save(weather_path) + logger.info(f"Weather OpenAPI specification saved to {weather_path}") + + logger.info("\nAll OpenAPI specs include full schema definitions for complex return types!") + + return crm_oas, email_oas, weather_oas + + +async def example_complex_objects_showcase(): + """ + Example: Showcase how different complex return types are represented in OpenAPI. + + This example demonstrates: + 1. Pydantic models -> Full JSON schema with all fields + 2. List[PydanticModel] -> Array schema with item definitions + 3. Nested objects -> Hierarchical schema structure + 4. Dict[str, Any] -> Object schema with nested properties + """ + logger.info("=" * 80) + logger.info("COMPLEX OBJECT RETURN TYPES SHOWCASE") + logger.info("=" * 80) + + # Create provider with one tool of each type + showcase_tools = [ + create_contact_advanced, # Returns: ContactInfo (Pydantic model) + get_contact_list, # Returns: List[ContactInfo] + send_email_advanced, # Returns: EmailResult with nested EmailMetadata + get_weather_detailed, # Returns: WeatherData with forecast array + get_contact_stats, # Returns: Dict[str, Any] with nested structure + ] + + provider = DirectLangChainToolsProvider(tools=showcase_tools, app_name="showcase") + await provider.initialize() + + logger.info("\nConverting tools with complex return types to OpenAPI...") + + oas = await cuga_tools_to_oas( + tool_provider=provider, + title="Complex Return Types Showcase", + version="1.0.0", + description="Demonstrates how various complex return types are represented in OpenAPI specs", + ) + + # Save the spec + output_path = Path("complex_objects_showcase.yaml") + oas.save(output_path) + + logger.info(f"\n✅ OpenAPI spec saved to: {output_path}") + logger.info("\nWhat to look for in the generated spec:") + logger.info(" 1. 'components/schemas' section - Contains all Pydantic model definitions") + logger.info(" 2. Each tool's response schema - References the component schemas") + logger.info(" 3. Nested objects - Fully expanded with all properties") + logger.info(" 4. Arrays - Properly typed with 'items' schema") + logger.info(" 5. Optional fields - Marked appropriately in the schema") + + logger.info("\n" + "=" * 80) + logger.info("EXAMPLE OUTPUT STRUCTURES:") + logger.info("=" * 80) + + # Show example outputs + logger.info("\n1. ContactInfo (Pydantic Model):") + logger.info(" {") + logger.info(' "id": 12345,') + logger.info(' "name": "John Doe",') + logger.info(' "email": "john@example.com",') + logger.info(' "phone": "+1-555-0001",') + logger.info(' "company": "Acme Corp",') + logger.info(' "tags": ["customer", "vip"]') + logger.info(" }") + + logger.info("\n2. List[ContactInfo]:") + logger.info(" [") + logger.info(" { ContactInfo object 1 },") + logger.info(" { ContactInfo object 2 },") + logger.info(" ...") + logger.info(" ]") + + logger.info("\n3. EmailResult with nested EmailMetadata:") + logger.info(" {") + logger.info(' "success": true,') + logger.info(' "recipient": "user@example.com",') + logger.info(' "subject": "Hello",') + logger.info(' "metadata": {') + logger.info(' "sent_at": "2024-01-01T12:00:00",') + logger.info(' "message_id": "msg_12345",') + logger.info(' "status": "delivered"') + logger.info(" }") + logger.info(" }") + + logger.info("\n4. WeatherData with forecast array:") + logger.info(" {") + logger.info(' "location": "New York",') + logger.info(' "temperature": 22.5,') + logger.info(' "forecast": [') + logger.info(' {"day": "Tomorrow", "high": 24, "low": 18},') + logger.info(" ...") + logger.info(" ]") + logger.info(" }") + + logger.info("\n5. Nested Dict[str, Any]:") + logger.info(" {") + logger.info(' "total_contacts": 1250,') + logger.info(' "by_category": {') + logger.info(' "customers": 800,') + logger.info(' "leads": 350') + logger.info(" },") + logger.info(' "top_companies": [') + logger.info(' {"name": "Acme", "contact_count": 45}') + logger.info(" ]") + logger.info(" }") + + logger.info("\n" + "=" * 80) + logger.info("✅ All these structures are fully represented in the OpenAPI spec!") + logger.info("=" * 80) + + return oas + + +async def main(): + """ + Run all examples demonstrating complex object return types. + """ + logger.info("\n" + "🚀 " * 20) + logger.info("CUGA TO TOOLGUARD - COMPLEX OBJECT EXAMPLES") + logger.info("🚀 " * 20 + "\n") + + logger.info("=" * 80) + logger.info("Example 1: Complex Objects Showcase") + logger.info("=" * 80) + try: + await example_complex_objects_showcase() + except Exception as e: + logger.error(f"Example 1 failed: {e}") + import traceback + traceback.print_exc() + + logger.info("\n" + "=" * 80) + logger.info("Example 2: Convert all tools to OpenAPI") + logger.info("=" * 80) + try: + await example_convert_tools_to_oas() + except Exception as e: + logger.error(f"Example 2 failed: {e}") + import traceback + traceback.print_exc() + + logger.info("\n" + "=" * 80) + logger.info("Example 3: Convert to ToolInfo objects") + logger.info("=" * 80) + try: + await example_convert_to_toolinfo() + except Exception as e: + logger.error(f"Example 3 failed: {e}") + import traceback + traceback.print_exc() + + logger.info("\n" + "=" * 80) + logger.info("Example 4: Convert multiple app groups") + logger.info("=" * 80) + try: + await example_with_multiple_apps() + except Exception as e: + logger.error(f"Example 4 failed: {e}") + import traceback + traceback.print_exc() + + logger.info("\n" + "🎉 " * 20) + logger.info("ALL EXAMPLES COMPLETED!") + logger.info("🎉 " * 20 + "\n") + + +if __name__ == "__main__": + asyncio.run(main()) + + +# ============================================================================ +# SUMMARY: Complex Object Return Types in OpenAPI Specs +# ============================================================================ +# +# This file demonstrates that YES, OpenAPI specs WILL reflect complex object +# return types! Here's what's included: +# +# 1. PYDANTIC MODELS (ContactInfo, EmailResult, WeatherData, etc.) +# - Automatically converted to JSON Schema +# - All fields with proper types (int, str, bool, etc.) +# - Nested models preserved (EmailResult contains EmailMetadata) +# - Optional fields marked appropriately +# +# 2. LIST OF COMPLEX OBJECTS (List[ContactInfo], List[ProductModel]) +# - Represented as arrays in OpenAPI +# - Item schema fully defined with all properties +# +# 3. NESTED DICTIONARIES (get_contact_stats returns Dict[str, Any]) +# - Hierarchical structure preserved +# - Nested objects and arrays maintained +# +# 4. MIXED COMPLEXITY +# - WeatherData includes both simple fields AND a forecast array +# - SearchResult includes metadata AND a list of ContactInfo objects +# +# The generated OpenAPI specification will include: +# - Full schema definitions in 'components/schemas' +# - Response schemas for each tool referencing these components +# - Complete type information for all fields +# - Proper handling of optional/required fields +# +# Run this file to see the generated OpenAPI specs: +# python src/cuga/backend/cuga_graph/policy/tool_guard/example_usage.py +# +# Then check the generated YAML/JSON files to see how complex objects are +# represented in the OpenAPI specification! +# +# Made with Bob From 8979bd623832c30911942f3b423aa48765b1aeb5 Mon Sep 17 00:00:00 2001 From: naamaz Date: Mon, 20 Apr 2026 10:47:15 +0300 Subject: [PATCH 03/24] create examples --- src/cuga/backend/cuga_graph/policy/models.py | 2 +- .../cuga_graph/policy/tool_guard/__init__.py | 3 +- .../policy/tool_guard/cuga_llm_adapter.py | 200 ++++++ .../policy/tool_guard/cuga_to_tool_info.py | 483 ------------- .../tool_guard/example_cuga_llm_usage.py | 210 ++++++ .../policy/tool_guard/example_usage.py | 672 ------------------ .../cuga_graph/policy/tool_guard/manager.py | 190 +++++ src/cuga/sdk_core/tests/debug_tool_guide.py | 14 +- src/cuga/sdk_core/tests/test_sdk_policies.py | 1 + 9 files changed, 617 insertions(+), 1158 deletions(-) create mode 100644 src/cuga/backend/cuga_graph/policy/tool_guard/cuga_llm_adapter.py delete mode 100644 src/cuga/backend/cuga_graph/policy/tool_guard/cuga_to_tool_info.py create mode 100644 src/cuga/backend/cuga_graph/policy/tool_guard/example_cuga_llm_usage.py delete mode 100644 src/cuga/backend/cuga_graph/policy/tool_guard/example_usage.py create mode 100644 src/cuga/backend/cuga_graph/policy/tool_guard/manager.py diff --git a/src/cuga/backend/cuga_graph/policy/models.py b/src/cuga/backend/cuga_graph/policy/models.py index 538edbd5..1ba47448 100644 --- a/src/cuga/backend/cuga_graph/policy/models.py +++ b/src/cuga/backend/cuga_graph/policy/models.py @@ -230,7 +230,7 @@ class ToolGuide(BaseModel): ) guide_content: str = Field(..., description="Markdown content to append to tool descriptions") tool_guards: Optional[Dict[str, ToolGuard]] = Field( - None, description="Optional guard configurations per tool (key: tool_name, value: ToolGuard)" + default=None, description="Optional guard configurations per tool (key: tool_name, value: ToolGuard)" ) prepend: bool = Field(False, description="Whether to prepend content instead of appending") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata") diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py b/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py index f0a950ab..4b886b9a 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py @@ -6,7 +6,8 @@ """ from .cuga_to_tool_info import cuga_tools_to_oas +from .manager import ToolGuardManager -__all__ = ["cuga_tools_to_oas"] +__all__ = ["cuga_tools_to_oas", "ToolGuardManager"] # Made with Bob diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_llm_adapter.py b/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_llm_adapter.py new file mode 100644 index 00000000..0b257f5a --- /dev/null +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_llm_adapter.py @@ -0,0 +1,200 @@ +""" +Adapter to use CUGA's LLM (BaseChatModel) with ToolGuard's I_TG_LLM interface. + +This module provides a bridge between CUGA's LangChain-based LLM system and +ToolGuard's example generation system, allowing any CUGA-configured model +(OpenAI, Groq, Azure, etc.) to be used with ToolGuard. +""" + +from typing import Dict, List +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage +from loguru import logger +from toolguard.buildtime.llm.llm_base import LanguageModelBase + + +class CugaLLMAdapter(LanguageModelBase): + """ + Adapter that wraps CUGA's BaseChatModel to work with ToolGuard's I_TG_LLM interface. + + This allows using any LangChain-compatible model (OpenAI, Groq, Azure, etc.) + with ToolGuard's example generation system. The adapter handles message format + conversion between ToolGuard's dict-based format and LangChain's message objects. + + Inherits from LanguageModelBase which provides: + - chat_json(): Automatic JSON extraction with retry logic + - extract_json_from_string(): JSON parsing utilities + + Example: + ```python + from cuga.sdk import CugaAgent + from cuga.backend.cuga_graph.policy.tool_guard.cuga_llm_adapter import CugaLLMAdapter + + # Create agent with any model + agent = CugaAgent(tools=[my_tool]) + + # Wrap the model for ToolGuard + llm_adapter = CugaLLMAdapter(agent._model) + + # Use with ToolGuard + messages = [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Generate examples"} + ] + response = await llm_adapter.generate(messages) + + # Or use chat_json for structured output + json_response = await llm_adapter.chat_json(messages) + ``` + + Attributes: + model: The underlying LangChain BaseChatModel instance + """ + + def __init__(self, model: BaseChatModel): + """ + Initialize the adapter with a CUGA/LangChain model. + + Args: + model: BaseChatModel instance (e.g., ChatOpenAI, ChatGroq, AzureChatOpenAI, etc.) + This is typically obtained from agent._model in CugaAgent. + + Raises: + TypeError: If model is not a BaseChatModel instance + """ + if not isinstance(model, BaseChatModel): + raise TypeError( + f"Expected BaseChatModel instance, got {type(model).__name__}. " + "Please pass agent._model from a CugaAgent instance." + ) + + self.model = model + logger.info(f"Initialized CugaLLMAdapter with model: {type(model).__name__}") + + def _convert_to_langchain_messages(self, messages: List[Dict]) -> List[BaseMessage]: + """ + Convert ToolGuard message format to LangChain message objects. + + ToolGuard uses a simple dict format: + [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi"}] + + LangChain uses typed message objects: + [HumanMessage(content="Hello"), AIMessage(content="Hi")] + + Args: + messages: List of dicts with 'role' and 'content' keys. + Supported roles: 'system', 'user', 'human', 'assistant', 'ai' + + Returns: + List of LangChain message objects (SystemMessage, HumanMessage, AIMessage) + + Raises: + ValueError: If a message is missing 'role' or 'content' keys + """ + langchain_messages = [] + + for i, msg in enumerate(messages): + if not isinstance(msg, dict): + raise ValueError( + f"Message at index {i} must be a dict, got {type(msg).__name__}" + ) + + if "role" not in msg: + raise ValueError(f"Message at index {i} missing 'role' key: {msg}") + + if "content" not in msg: + raise ValueError(f"Message at index {i} missing 'content' key: {msg}") + + role = msg["role"].lower() + content = msg["content"] + + # Convert to appropriate LangChain message type + if role == "system": + langchain_messages.append(SystemMessage(content=content)) + elif role in ("assistant", "ai"): + langchain_messages.append(AIMessage(content=content)) + elif role in ("user", "human"): + langchain_messages.append(HumanMessage(content=content)) + else: + logger.warning( + f"Unknown role '{role}' at index {i}, treating as user message" + ) + langchain_messages.append(HumanMessage(content=content)) + + return langchain_messages + + async def generate(self, messages: List[Dict]) -> str: + """ + Generate a text response from the model. + + This is the core method required by I_TG_LLM interface. It converts + ToolGuard's message format to LangChain format, invokes the model, + and returns the generated text. + + Args: + messages: List of message dicts in format: + [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Generate examples"}, + {"role": "assistant", "content": "Here are examples..."} + ] + + Returns: + Generated text response as string + + Raises: + ValueError: If messages format is invalid + Exception: If model invocation fails (propagated from LangChain) + + Example: + ```python + adapter = CugaLLMAdapter(model) + messages = [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Say hello"} + ] + response = await adapter.generate(messages) + print(response) # "Hello! How can I help you today?" + ``` + """ + try: + # Convert to LangChain format + langchain_messages = self._convert_to_langchain_messages(messages) + + logger.debug( + f"Invoking {type(self.model).__name__} with {len(langchain_messages)} messages" + ) + + # Invoke the model using LangChain's async interface + response = await self.model.ainvoke(langchain_messages) + + # Extract content from response + # LangChain models return AIMessage or similar with .content attribute + if hasattr(response, 'content'): + content = response.content + # Handle case where content might be a list (e.g., multimodal responses) + if isinstance(content, list): + # Join list items into a single string + content = "\n".join(str(item) for item in content) + elif not isinstance(content, str): + content = str(content) + else: + content = str(response) + + logger.debug(f"Generated response length: {len(content)} characters") + + return content + + except ValueError as e: + # Re-raise validation errors + logger.error(f"Message format error: {e}") + raise + except Exception as e: + # Log and re-raise model invocation errors + logger.error(f"Error generating response with {type(self.model).__name__}: {e}") + raise RuntimeError( + f"Failed to generate response: {e}. " + "Check that your model is properly configured and has valid credentials." + ) from e + +# Made with Bob diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_to_tool_info.py b/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_to_tool_info.py deleted file mode 100644 index 87bbd56a..00000000 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_to_tool_info.py +++ /dev/null @@ -1,483 +0,0 @@ -""" -Convert CUGA tools to Toolguard's OpenAPI Specification format. - -This module provides functionality to convert tools from CUGA's ToolProviderInterface -or CombinedToolProvider into Toolguard's OpenAPI format, which can then be converted -to ToolInfo objects for policy enforcement. - -Supports complex return types including: -- Pydantic models (automatically converted to JSON schema) -- List[PydanticModel] (arrays with item schemas) -- Dict[str, Any] (nested object structures) -- Simple types (str, int, bool, etc.) -""" - -from typing import Any, Dict, List, Union, get_origin, get_args -import inspect - -from langchain_core.tools import StructuredTool -from loguru import logger -from pydantic import BaseModel, Field - -from cuga.backend.cuga_graph.nodes.cuga_lite.combined_tool_provider import CombinedToolProvider -from cuga.backend.cuga_graph.nodes.cuga_lite.tool_provider_interface import ToolProviderInterface - -try: - from toolguard.buildtime.utils.open_api import ( - Components, - Info, - MediaType, - OpenAPI, - Operation, - Parameter, - ParameterIn, - PathItem, - RequestBody, - Response, - Server, - ) - from toolguard.buildtime.utils.jschema import JSchema, JSONSchemaTypes -except ImportError as e: - raise ImportError( - "toolguard is required for this module. Install it with: pip install toolguard" - ) from e - - -def _extract_json_schema_from_pydantic(args_schema) -> Dict[str, Any]: - """ - Extract JSON schema from a Pydantic model. - - Args: - args_schema: Pydantic model class - - Returns: - JSON schema dictionary - """ - if args_schema is None: - return {"type": "object", "properties": {}, "required": []} - - try: - # Get the JSON schema from the Pydantic model - schema = args_schema.model_json_schema() - return schema - except Exception as e: - logger.warning(f"Failed to extract JSON schema from args_schema: {e}") - return {"type": "object", "properties": {}, "required": []} - - -def _extract_return_type_schema(tool: StructuredTool) -> Dict[str, Any]: - """ - Extract JSON schema from a tool's return type annotation. - - Supports: - - Pydantic models (BaseModel subclasses) - - List[PydanticModel] - - Dict[str, Any] - - Simple types (str, int, bool, float) - - Args: - tool: StructuredTool instance - - Returns: - JSON schema dictionary representing the return type - """ - try: - # Get the function from the tool - if not hasattr(tool, "func"): - logger.debug(f"Tool {tool.name} has no func attribute") - return {"type": "object", "description": "Tool execution result"} - - func = tool.func - if func is None: - logger.debug(f"Tool {tool.name} func is None") - return {"type": "object", "description": "Tool execution result"} - - # Get the return type annotation - sig = inspect.signature(func) - return_annotation = sig.return_annotation - - if return_annotation is inspect.Signature.empty: - logger.debug(f"Tool {tool.name} has no return type annotation") - return {"type": "object", "description": "Tool execution result"} - - # Handle string annotations (forward references) - if isinstance(return_annotation, str): - logger.debug(f"Tool {tool.name} has string return annotation: {return_annotation}") - return {"type": "object", "description": f"Returns {return_annotation}"} - - # Check if it's a Pydantic model - if inspect.isclass(return_annotation) and issubclass(return_annotation, BaseModel): - logger.debug(f"Tool {tool.name} returns Pydantic model: {return_annotation.__name__}") - schema = return_annotation.model_json_schema() - # Remove $defs if present and inline them - if "$defs" in schema: - defs = schema.pop("$defs") - # Resolve any $ref references - schema = _resolve_schema_refs(schema, defs) - return schema - - # Check if it's a List type - origin = get_origin(return_annotation) - if origin is list or origin is List: - args = get_args(return_annotation) - if args: - item_type = args[0] - # Check if list item is a Pydantic model - if inspect.isclass(item_type) and issubclass(item_type, BaseModel): - logger.debug(f"Tool {tool.name} returns List[{item_type.__name__}]") - item_schema = item_type.model_json_schema() - # Remove $defs and inline - if "$defs" in item_schema: - defs = item_schema.pop("$defs") - item_schema = _resolve_schema_refs(item_schema, defs) - return { - "type": "array", - "items": item_schema, - "description": f"List of {item_type.__name__} objects" - } - else: - # Simple list type - return { - "type": "array", - "items": {"type": _python_type_to_json_type(item_type)}, - "description": f"List of {item_type.__name__ if hasattr(item_type, '__name__') else str(item_type)}" - } - return {"type": "array", "description": "List of items"} - - # Check if it's a Dict type - if origin is dict or origin is Dict: - logger.debug(f"Tool {tool.name} returns Dict") - return { - "type": "object", - "description": "Dictionary with dynamic keys", - "additionalProperties": True - } - - # Handle simple types - simple_type = _python_type_to_json_type(return_annotation) - if simple_type != "object": - logger.debug(f"Tool {tool.name} returns simple type: {simple_type}") - return {"type": simple_type, "description": f"Returns {simple_type}"} - - # Default fallback - logger.debug(f"Tool {tool.name} has unhandled return type: {return_annotation}") - return {"type": "object", "description": "Tool execution result"} - - except Exception as e: - logger.warning(f"Failed to extract return type schema for tool {tool.name}: {e}") - return {"type": "object", "description": "Tool execution result"} - - -def _python_type_to_json_type(python_type) -> str: - """Convert Python type to JSON schema type string.""" - type_mapping = { - str: "string", - int: "integer", - float: "number", - bool: "boolean", - list: "array", - dict: "object", - } - return type_mapping.get(python_type, "object") - - -def _resolve_schema_refs(schema: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, Any]: - """ - Resolve $ref references in a schema by inlining definitions. - - Args: - schema: Schema that may contain $ref - defs: Definitions to resolve references from - - Returns: - Schema with references resolved - """ - if isinstance(schema, dict): - if "$ref" in schema: - # Extract the reference name - ref = schema["$ref"] - if ref.startswith("#/$defs/"): - def_name = ref.split("/")[-1] - if def_name in defs: - # Return the resolved definition (recursively resolve it too) - return _resolve_schema_refs(defs[def_name], defs) - return schema - else: - # Recursively resolve in nested structures - return {k: _resolve_schema_refs(v, defs) for k, v in schema.items()} - elif isinstance(schema, list): - return [_resolve_schema_refs(item, defs) for item in schema] - else: - return schema - - -def _dict_schema_to_jschema(schema_dict: Dict[str, Any]) -> JSchema: - """ - Convert a dictionary JSON schema to a JSchema object. - - Args: - schema_dict: Dictionary containing JSON schema - - Returns: - JSchema object - """ - # Extract basic properties - schema_type = schema_dict.get("type", "object") - description = schema_dict.get("description", "") - - # Map string type to JSONSchemaTypes enum - try: - json_type = JSONSchemaTypes(schema_type) - except ValueError: - json_type = JSONSchemaTypes.object - - # Build JSchema kwargs - jschema_kwargs = { - "type": json_type, - "description": description, - } - - # Handle properties for object types - if "properties" in schema_dict: - properties = {} - for prop_name, prop_schema in schema_dict["properties"].items(): - # Recursively convert nested schemas - if isinstance(prop_schema, dict): - properties[prop_name] = _dict_schema_to_jschema(prop_schema) - else: - properties[prop_name] = prop_schema - jschema_kwargs["properties"] = properties - - # Handle required fields - if "required" in schema_dict: - jschema_kwargs["required"] = schema_dict["required"] - - # Handle array items - if "items" in schema_dict: - items_schema = schema_dict["items"] - if isinstance(items_schema, dict): - jschema_kwargs["items"] = _dict_schema_to_jschema(items_schema) - else: - jschema_kwargs["items"] = items_schema - - # Handle additionalProperties for dynamic objects - if "additionalProperties" in schema_dict: - jschema_kwargs["additionalProperties"] = schema_dict["additionalProperties"] - - return JSchema(**jschema_kwargs) - - -def _convert_tool_to_operation(tool: StructuredTool, app_name: str) -> tuple[str, Operation]: - """ - Convert a single CUGA StructuredTool to an OpenAPI Operation. - - Args: - tool: LangChain StructuredTool instance - app_name: Name of the application/service providing the tool - - Returns: - Tuple of (operation_id, Operation) - """ - # Extract tool metadata - tool_name = tool.name - description = tool.description or "" - - # Get operation_id from tool metadata if available - operation_id = getattr(tool.func, "_operation_id", None) if hasattr(tool, "func") else None - if not operation_id: - operation_id = f"{app_name}_{tool_name}" - - # Extract parameters from the tool's args_schema - parameters: List[Parameter] = [] - request_body: RequestBody | None = None - - if tool.args_schema: - schema = _extract_json_schema_from_pydantic(tool.args_schema) - properties = schema.get("properties", {}) - required_fields = schema.get("required", []) - - # Convert properties to OpenAPI parameters or request body - # For simplicity, we'll use query parameters for simple types - # and request body for complex objects - body_properties = {} - - for param_name, param_schema in properties.items(): - param_type = param_schema.get("type", "string") - param_desc = param_schema.get("description", "") - is_required = param_name in required_fields - - # Simple types go as query parameters - if param_type in ["string", "integer", "number", "boolean"]: - # Map string type to JSONSchemaTypes enum - json_type = JSONSchemaTypes(param_type) - parameters.append( - Parameter( - name=param_name, - description=param_desc, - **{"in": ParameterIn.query}, # Use dict unpacking for 'in' keyword - required=is_required, - schema=JSchema( - type=json_type, - description=param_desc, - ), - ) - ) - else: - # Complex types go in request body - body_properties[param_name] = param_schema - - # Create request body if there are complex parameters - if body_properties: - request_body = RequestBody( - description="Request body parameters", - required=True, - content={ - "application/json": MediaType( - schema=JSchema( - type=JSONSchemaTypes.object, - properties=body_properties, - required=[k for k in body_properties.keys() if k in required_fields], - ) - ) - }, - ) - - # Create response schema by extracting return type information - logger.debug(f"Extracting return type schema for tool: {tool_name}") - return_schema = _extract_return_type_schema(tool) - - # Convert the return schema dict to JSchema - # We need to handle the schema conversion properly - response_jschema = _dict_schema_to_jschema(return_schema) - - # Create response with the extracted schema - responses: Dict[str, Union[Response, Any]] = { - "200": Response( - description="Successful response", - content={ - "application/json": MediaType( - schema=response_jschema - ) - }, - ) - } - - # Create the operation - # Cast parameters list to match expected type - params_list: List[Union[Parameter, Any]] = parameters if parameters else [] - operation = Operation( - operationId=operation_id, - summary=tool_name, - description=description, - parameters=params_list if params_list else None, - requestBody=request_body, - responses=responses, - tags=[app_name], - ) - - return operation_id, operation - - -async def cuga_tools_to_oas( - tool_provider: Union[ToolProviderInterface, CombinedToolProvider], - title: str = "CUGA Tools API", - version: str = "1.0.0", - description: str = "OpenAPI specification for CUGA tools", -) -> OpenAPI: - """ - Convert all tools from a CUGA tool provider to Toolguard's OpenAPI format. - - This function extracts all tools from the provided tool provider and converts - them into a complete OpenAPI specification that can be used with Toolguard - for policy enforcement. - - Args: - tool_provider: CUGA tool provider (ToolProviderInterface or CombinedToolProvider) - title: Title for the OpenAPI specification - version: Version of the API - description: Description of the API - - Returns: - OpenAPI object with all tools converted to operations - - Example: - ```python - from cuga.backend.cuga_graph.nodes.cuga_lite.combined_tool_provider import CombinedToolProvider - from cuga.backend.cuga_graph.policy.tool_guard import cuga_tools_to_oas - - # Create tool provider - provider = CombinedToolProvider() - await provider.initialize() - - # Convert to OpenAPI - oas = await cuga_tools_to_oas(provider) - - # Save to file - oas.save("tools.yaml") - - # Or convert to ToolInfo for Toolguard - from toolguard.buildtime.gen_spec.oas_to_toolinfo import openapi_to_toolinfos - tool_infos = openapi_to_toolinfos(oas) - ``` - """ - # Initialize the tool provider if needed - if not getattr(tool_provider, "initialized", False): - await tool_provider.initialize() - - # Get all apps and their tools - apps = await tool_provider.get_apps() - - # Create OpenAPI structure - paths: Dict[str, Union[PathItem, Any]] = {} - all_tags = [] - - # Process each app - for app in apps: - app_name = app.name - all_tags.append({"name": app_name, "description": app.description or f"Tools from {app_name}"}) - - # Get tools for this app - tools = await tool_provider.get_tools(app_name) - - logger.info(f"Converting {len(tools)} tools from app '{app_name}' to OpenAPI") - - # Convert each tool to an operation - for tool in tools: - try: - operation_id, operation = _convert_tool_to_operation(tool, app_name) - - # Create a path for this tool - # Use the tool name as the path - path = f"/{app_name}/{tool.name}" - - # Create or update PathItem - if path not in paths: - paths[path] = PathItem() - - # Add operation as POST (most tools are actions) - paths[path].post = operation - - logger.debug(f"Converted tool '{tool.name}' to operation '{operation_id}'") - - except Exception as e: - logger.error(f"Failed to convert tool '{tool.name}' from app '{app_name}': {e}") - continue - - # Create the OpenAPI object - openapi = OpenAPI( - openapi="3.1.0", - info=Info( - title=title, - version=version, - description=description, - ), - servers=[Server(url="/")], - paths=paths, - tags=all_tags, - ) - - logger.info(f"Successfully converted {len(paths)} tool paths to OpenAPI specification") - - return openapi - -# Made with Bob diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/example_cuga_llm_usage.py b/src/cuga/backend/cuga_graph/policy/tool_guard/example_cuga_llm_usage.py new file mode 100644 index 00000000..cb0266a4 --- /dev/null +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/example_cuga_llm_usage.py @@ -0,0 +1,210 @@ +""" +Example usage of CugaLLMAdapter with ToolGuard. + +This demonstrates how to use CUGA's LLM system with ToolGuard's +example generation capabilities. +""" + +import asyncio +from cuga.backend.cuga_graph.nodes.cuga_lite.combined_tool_provider import CombinedToolProvider +from cuga.sdk import CugaAgent +from cuga.backend.cuga_graph.policy.tool_guard.cuga_llm_adapter import CugaLLMAdapter +from langchain_core.tools import tool + + +@tool +def delete_file(path: str) -> str: + """Delete a file from the filesystem""" + return f"Deleted file: {path}" + + +@tool +def read_file(path: str) -> str: + """Read contents of a file""" + return f"Contents of {path}" + + +async def example_basic_usage(): + """Basic example: Using CugaLLMAdapter directly""" + print("\n=== Example 1: Basic CugaLLMAdapter Usage ===\n") + + # Create a CUGA agent with tools + agent = CugaAgent(tools=[delete_file, read_file]) + + # Wrap the agent's model with the adapter + assert agent._model is not None, "Agent model should be initialized" + llm_adapter = CugaLLMAdapter(agent._model) + + # Use the adapter with ToolGuard's message format + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Say hello in a friendly way."} + ] + + response = await llm_adapter.generate(messages) + print(f"Response: {response}\n") + + +async def example_json_output(): + """Example: Using chat_json for structured output""" + print("\n=== Example 2: JSON Output with chat_json ===\n") + + agent = CugaAgent(tools=[delete_file]) + assert agent._model is not None, "Agent model should be initialized" + llm_adapter = CugaLLMAdapter(agent._model) + + # Request JSON output + messages = [ + { + "role": "system", + "content": "You are a helpful assistant that responds in JSON format." + }, + { + "role": "user", + "content": ( + "Generate a JSON object with two fields: 'greeting' and 'language'. " + "Wrap your response in ```json``` code blocks." + ) + } + ] + + # chat_json automatically extracts JSON from the response + json_response = await llm_adapter.chat_json(messages) + print(f"JSON Response: {json_response}\n") + + +async def example_with_toolguard_manager(): + """Example: Using ToolGuardManager with CugaLLMAdapter""" + print("\n=== Example 3: ToolGuardManager Integration ===\n") + + from cuga.backend.cuga_graph.policy.tool_guard.manager import ToolGuardManager + + # Create agent with tools + agent = CugaAgent( + tools=[delete_file, read_file], + cuga_folder="/Users/naamazwerdling/Documents/OASB/policy_validation/cuga_demo/output" + ) + + # Create ToolGuard manager (automatically uses CugaLLMAdapter) + manager = ToolGuardManager(agent) + await manager.initialize() + + tool_count = len(manager.langchain_tools) if manager.langchain_tools else 0 + print(f"✅ ToolGuardManager initialized with {tool_count} tools") + print(f" LLM: {type(manager.llm).__name__}") + print(f" Underlying model: {type(agent._model).__name__}\n") + + # Add a tool guide policy + await agent.policies.add_tool_guide( + name="File Safety Guidelines", + target_tools=["delete_file"], + description="Ensure safe file deletion practices", + content=""" +# File Safety Guidelines + +## Rules +- Never delete system files (e.g., /etc/*, /sys/*, /boot/*) +- Never delete files without user confirmation +- Always validate file paths before deletion + +## Examples +- ❌ BAD: Delete /etc/passwd +- ✅ GOOD: Delete /home/user/temp.txt (with confirmation) + """.strip() + ) + + print("✅ Added tool guide policy for delete_file\n") + + # Get the policy + policies = await agent.policies.list() + if policies: + policy_info = policies[0] + print(f"Policy: {policy_info['name']} (ID: {policy_info['id']})\n") + + # Generate examples using ToolGuard with CUGA's LLM + print("🔄 Generating violating and compliance examples...") + policy_full = await agent.policies.get(policy_info['id']) + + if policy_full and policy_full.get('policy'): + try: + violating, compliance = await manager.generate_examples( + policy=policy_full['policy'], + target_tool="delete_file" + ) + + print("\n📋 Generated Examples:\n") + + print("❌ Violating Examples (actions that break the policy):") + for i, example in enumerate(violating, 1): + print(f" {i}. {example}") + + print("\n✅ Compliance Examples (actions that follow the policy):") + for i, example in enumerate(compliance, 1): + print(f" {i}. {example}") + + print() + + except Exception as e: + print(f"⚠️ Error generating examples: {e}") + import traceback + print(traceback.format_exc()) + print(" This might happen if the policy format is incompatible or LLM fails") + print() + else: + print("⚠️ Could not retrieve full policy details\n") + + +async def example_error_handling(): + """Example: Error handling with CugaLLMAdapter""" + print("\n=== Example 4: Error Handling ===\n") + + agent = CugaAgent(tools=[delete_file]) + assert agent._model is not None, "Agent model should be initialized" + llm_adapter = CugaLLMAdapter(agent._model) + + # Test with invalid message format + try: + invalid_messages = [ + {"content": "Missing role key"} # Missing 'role' + ] + await llm_adapter.generate(invalid_messages) + except ValueError as e: + print(f"✅ Caught expected error: {e}\n") + + # Test with valid messages + try: + valid_messages = [ + {"role": "user", "content": "Hello!"} + ] + response = await llm_adapter.generate(valid_messages) + print(f"✅ Valid request succeeded: {response[:50]}...\n") + except Exception as e: + print(f"❌ Unexpected error: {e}\n") + + +async def main(): + """Run all examples""" + print("=" * 60) + print("CugaLLMAdapter Usage Examples") + print("=" * 60) + + try: + #await example_basic_usage() + #await example_json_output() + await example_with_toolguard_manager() + #await example_error_handling() + + print("=" * 60) + print("✅ All examples completed successfully!") + print("=" * 60) + + except Exception as e: + print(f"\n❌ Error running examples: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + asyncio.run(main()) + +# Made with Bob diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/example_usage.py b/src/cuga/backend/cuga_graph/policy/tool_guard/example_usage.py deleted file mode 100644 index 9faa9aa8..00000000 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/example_usage.py +++ /dev/null @@ -1,672 +0,0 @@ -""" -Example usage of the CUGA to Toolguard integration. - -This file demonstrates how to convert CUGA tools to Toolguard's OpenAPI format -and then to ToolInfo objects for policy enforcement. - -Includes examples of: -- Simple return types (str, int) -- Complex return types (Pydantic models, nested dicts, lists of objects) -- How OpenAPI specs reflect these complex structures -""" - -import asyncio -from pathlib import Path -from typing import List, Dict, Any -from datetime import datetime - -from loguru import logger -from langchain_core.tools import tool -from pydantic import BaseModel, Field - -from cuga.backend.cuga_graph.nodes.cuga_lite.direct_langchain_tools_provider import DirectLangChainToolsProvider -from cuga.backend.cuga_graph.policy.tool_guard import cuga_tools_to_oas - - -# ============================================================================ -# Complex Object Models (Pydantic) -# ============================================================================ - -class ContactInfo(BaseModel): - """Contact information model""" - id: int = Field(description="Unique contact identifier") - name: str = Field(description="Full name of the contact") - email: str = Field(description="Email address") - phone: str | None = Field(None, description="Optional phone number") - company: str | None = Field(None, description="Company name") - tags: List[str] = Field(default_factory=list, description="Contact tags") - - -class EmailMetadata(BaseModel): - """Email metadata model""" - sent_at: str = Field(description="Timestamp when email was sent") - message_id: str = Field(description="Unique message identifier") - status: str = Field(description="Delivery status") - - -class EmailResult(BaseModel): - """Email sending result with metadata""" - success: bool = Field(description="Whether email was sent successfully") - recipient: str = Field(description="Email recipient") - subject: str = Field(description="Email subject") - metadata: EmailMetadata = Field(description="Email metadata") - - -class WeatherData(BaseModel): - """Detailed weather information""" - location: str = Field(description="Location name") - temperature: float = Field(description="Current temperature") - units: str = Field(description="Temperature units (celsius/fahrenheit)") - conditions: str = Field(description="Weather conditions") - humidity: int = Field(description="Humidity percentage") - wind_speed: float = Field(description="Wind speed") - forecast: List[Dict[str, Any]] = Field(description="3-day forecast") - - -class SearchResult(BaseModel): - """Search result with pagination""" - query: str = Field(description="Search query used") - total_results: int = Field(description="Total number of results") - page: int = Field(description="Current page number") - results: List[ContactInfo] = Field(description="List of matching contacts") - has_more: bool = Field(description="Whether more results are available") - - -# ============================================================================ -# Simple Tools (String Return Types) -# ============================================================================ - -@tool -def calculate(expression: str) -> str: - """ - Evaluate a mathematical expression. - - Args: - expression: Mathematical expression to evaluate (e.g., "2 + 2", "10 * 5") - - Returns: - Result of the calculation - """ - try: - # Simple eval for demo - in production use a safe math parser - result = eval(expression, {"__builtins__": {}}, {}) - return f"Result: {result}" - except Exception as e: - return f"Error: {str(e)}" - - -# ============================================================================ -# Complex Tools (Pydantic Model Return Types) -# ============================================================================ - -@tool -def search_contacts_advanced(query: str, limit: int = 10, page: int = 1) -> SearchResult: - """ - Search for contacts in the CRM system with advanced filtering. - - This tool returns a complex SearchResult object containing: - - Query metadata (query string, pagination info) - - List of ContactInfo objects with full contact details - - Pagination information - - Args: - query: Search query string to find contacts - limit: Maximum number of results per page (default: 10) - page: Page number for pagination (default: 1) - - Returns: - SearchResult object with matching contacts and metadata - """ - # Simulate search results - contacts = [ - ContactInfo( - id=i, - name=f"Contact {i}", - email=f"contact{i}@example.com", - phone=f"+1-555-000{i}", - company=f"Company {i % 3}", - tags=["customer", "active"] if i % 2 == 0 else ["lead"] - ) - for i in range(1, min(limit + 1, 6)) - ] - - return SearchResult( - query=query, - total_results=50, - page=page, - results=contacts, - has_more=True - ) - - -@tool -def create_contact_advanced( - name: str, - email: str, - phone: str | None = None, - company: str | None = None, - tags: List[str] | None = None -) -> ContactInfo: - """ - Create a new contact in the CRM system. - - This tool returns a ContactInfo Pydantic model with all contact details. - The OpenAPI spec will automatically reflect the structure of ContactInfo. - - Args: - name: Full name of the contact - email: Email address of the contact - phone: Optional phone number - company: Optional company name - tags: Optional list of tags for categorization - - Returns: - ContactInfo object with the created contact details - """ - return ContactInfo( - id=12345, - name=name, - email=email, - phone=phone, - company=company, - tags=tags or [] - ) - - -@tool -def send_email_advanced(to: str, subject: str, body: str) -> EmailResult: - """ - Send an email to a recipient with detailed result tracking. - - This tool returns an EmailResult object containing: - - Success status - - Recipient information - - Email metadata (timestamp, message ID, delivery status) - - Args: - to: Email address of the recipient - subject: Subject line of the email - body: Body content of the email - - Returns: - EmailResult object with sending status and metadata - """ - return EmailResult( - success=True, - recipient=to, - subject=subject, - metadata=EmailMetadata( - sent_at=datetime.now().isoformat(), - message_id=f"msg_{hash(to + subject)}", - status="delivered" - ) - ) - - -@tool -def get_weather_detailed(location: str, units: str = "celsius") -> WeatherData: - """ - Get detailed weather information for a location. - - This tool returns a WeatherData object with comprehensive weather information: - - Current conditions (temperature, humidity, wind) - - 3-day forecast - - Location details - - The OpenAPI spec will show the full nested structure including the forecast array. - - Args: - location: City name or location to get weather for - units: Temperature units - 'celsius' or 'fahrenheit' (default: celsius) - - Returns: - WeatherData object with current conditions and forecast - """ - return WeatherData( - location=location, - temperature=22.5 if units == "celsius" else 72.5, - units=units, - conditions="Partly Cloudy", - humidity=65, - wind_speed=12.5, - forecast=[ - {"day": "Tomorrow", "high": 24, "low": 18, "conditions": "Sunny"}, - {"day": "Day 2", "high": 23, "low": 17, "conditions": "Cloudy"}, - {"day": "Day 3", "high": 21, "low": 16, "conditions": "Rainy"} - ] - ) - - -@tool -def get_contact_list(limit: int = 10) -> List[ContactInfo]: - """ - Get a list of all contacts. - - This tool returns a List of ContactInfo objects. - The OpenAPI spec will represent this as an array of ContactInfo schemas. - - Args: - limit: Maximum number of contacts to return - - Returns: - List of ContactInfo objects - """ - return [ - ContactInfo( - id=i, - name=f"Contact {i}", - email=f"contact{i}@example.com", - phone=f"+1-555-000{i}", - company=f"Company {i % 3}", - tags=["customer"] if i % 2 == 0 else ["lead"] - ) - for i in range(1, min(limit + 1, 11)) - ] - - -@tool -def get_contact_stats() -> Dict[str, Any]: - """ - Get statistics about contacts in the system. - - This tool returns a nested dictionary structure. - The OpenAPI spec will represent this as an object with nested properties. - - Returns: - Dictionary with contact statistics including counts by category and recent activity - """ - return { - "total_contacts": 1250, - "by_category": { - "customers": 800, - "leads": 350, - "partners": 100 - }, - "recent_activity": { - "added_today": 15, - "updated_today": 42, - "emails_sent_today": 127 - }, - "top_companies": [ - {"name": "Acme Corp", "contact_count": 45}, - {"name": "TechStart Inc", "contact_count": 38}, - {"name": "Global Solutions", "contact_count": 32} - ] - } - - -async def example_convert_tools_to_oas(): - """ - Example: Convert CUGA tools to OpenAPI specification. - - This example demonstrates how simple and complex return types are represented - in the generated OpenAPI specification. - """ - logger.info("Creating tool provider with sample tools...") - - # Create a direct tool provider with sample tools (mix of simple and complex) - sample_tools = [ - calculate, # Simple: returns str - search_contacts_advanced, # Complex: returns SearchResult (Pydantic model) - create_contact_advanced, # Complex: returns ContactInfo (Pydantic model) - send_email_advanced, # Complex: returns EmailResult with nested EmailMetadata - get_weather_detailed, # Complex: returns WeatherData with forecast array - get_contact_list, # Complex: returns List[ContactInfo] - get_contact_stats, # Complex: returns nested Dict - ] - - provider = DirectLangChainToolsProvider(tools=sample_tools, app_name="demo_app") - - # Initialize the provider - await provider.initialize() - - logger.info("Converting tools to OpenAPI specification...") - logger.info("Note: Complex return types (Pydantic models, nested dicts, lists) will be") - logger.info(" fully represented in the OpenAPI schema with their complete structure.") - - # Convert all tools to OpenAPI format - oas = await cuga_tools_to_oas( - tool_provider=provider, - title="CUGA Tools API - Complex Objects Demo", - version="1.0.0", - description="OpenAPI specification demonstrating complex return types including Pydantic models, nested objects, and lists", - ) - - # Save to file - output_path = Path("cuga_tools_openapi.yaml") - oas.save(output_path) - logger.info(f"OpenAPI specification saved to {output_path}") - logger.info("Check the file to see how complex objects are represented in the schema!") - - # Also save as JSON - json_path = Path("cuga_tools_openapi.json") - oas.save(json_path) - logger.info(f"OpenAPI specification saved to {json_path}") - - return oas - - -async def example_convert_to_toolinfo(): - """ - Example: Convert CUGA tools to Toolguard ToolInfo objects. - - This demonstrates how complex return types are preserved through the - OpenAPI -> ToolInfo conversion process. - """ - from toolguard.buildtime.gen_spec.oas_to_toolinfo import openapi_to_toolinfos - - logger.info("Creating tool provider with complex return type tools...") - - # Create a direct tool provider with tools that return complex objects - sample_tools = [ - calculate, - search_contacts_advanced, - create_contact_advanced, - send_email_advanced, - get_weather_detailed, - get_contact_list, - get_contact_stats, - ] - - provider = DirectLangChainToolsProvider(tools=sample_tools, app_name="demo_app") - await provider.initialize() - - logger.info("Converting tools to OpenAPI specification...") - - # Convert to OpenAPI - oas = await cuga_tools_to_oas(provider) - - logger.info("Converting OpenAPI to ToolInfo objects...") - - # Convert OpenAPI to ToolInfo objects (for Toolguard policy enforcement) - tool_infos = openapi_to_toolinfos(oas) - - logger.info(f"Converted {len(tool_infos)} tools to ToolInfo objects") - logger.info("Complex return types are preserved in the ToolInfo schema!") - - # Print some information about the tools - for tool_info in tool_infos: - logger.info(f"\nTool: {tool_info.name}") - logger.info(f" Summary: {tool_info.summary}") - logger.info(f" Parameters: {list(tool_info.parameters.keys())}") - # Note: Response schema information is also preserved in the ToolInfo - - return tool_infos - - -async def example_with_multiple_apps(): - """ - Example: Convert tools from multiple app groups with complex return types. - - This shows how different apps can have tools with different complexity levels. - """ - logger.info("Creating tool provider with multiple app groups...") - - # Create CRM tools (complex return types) - crm_tools = [search_contacts_advanced, create_contact_advanced, get_contact_list] - crm_provider = DirectLangChainToolsProvider(tools=crm_tools, app_name="crm") - await crm_provider.initialize() - - # Create email tools (complex return types) - email_tools = [send_email_advanced] - email_provider = DirectLangChainToolsProvider(tools=email_tools, app_name="email") - await email_provider.initialize() - - # Create weather tools (complex return types) - weather_tools = [get_weather_detailed] - weather_provider = DirectLangChainToolsProvider(tools=weather_tools, app_name="weather") - await weather_provider.initialize() - - logger.info("Converting CRM tools to OpenAPI specification...") - - # Convert CRM tools to OpenAPI - crm_oas = await cuga_tools_to_oas( - tool_provider=crm_provider, - title="CRM Tools API", - version="1.0.0", - description="OpenAPI specification for CRM tools with complex ContactInfo and SearchResult models", - ) - - logger.info("Converting Email tools to OpenAPI specification...") - - # Convert Email tools to OpenAPI - email_oas = await cuga_tools_to_oas( - tool_provider=email_provider, - title="Email Tools API", - version="1.0.0", - description="OpenAPI specification for Email tools with EmailResult and nested EmailMetadata", - ) - - logger.info("Converting Weather tools to OpenAPI specification...") - - # Convert Weather tools to OpenAPI - weather_oas = await cuga_tools_to_oas( - tool_provider=weather_provider, - title="Weather Tools API", - version="1.0.0", - description="OpenAPI specification for Weather tools with WeatherData including forecast arrays", - ) - - # Save to files - crm_path = Path("crm_tools_openapi.yaml") - crm_oas.save(crm_path) - logger.info(f"CRM OpenAPI specification saved to {crm_path}") - - email_path = Path("email_tools_openapi.yaml") - email_oas.save(email_path) - logger.info(f"Email OpenAPI specification saved to {email_path}") - - weather_path = Path("weather_tools_openapi.yaml") - weather_oas.save(weather_path) - logger.info(f"Weather OpenAPI specification saved to {weather_path}") - - logger.info("\nAll OpenAPI specs include full schema definitions for complex return types!") - - return crm_oas, email_oas, weather_oas - - -async def example_complex_objects_showcase(): - """ - Example: Showcase how different complex return types are represented in OpenAPI. - - This example demonstrates: - 1. Pydantic models -> Full JSON schema with all fields - 2. List[PydanticModel] -> Array schema with item definitions - 3. Nested objects -> Hierarchical schema structure - 4. Dict[str, Any] -> Object schema with nested properties - """ - logger.info("=" * 80) - logger.info("COMPLEX OBJECT RETURN TYPES SHOWCASE") - logger.info("=" * 80) - - # Create provider with one tool of each type - showcase_tools = [ - create_contact_advanced, # Returns: ContactInfo (Pydantic model) - get_contact_list, # Returns: List[ContactInfo] - send_email_advanced, # Returns: EmailResult with nested EmailMetadata - get_weather_detailed, # Returns: WeatherData with forecast array - get_contact_stats, # Returns: Dict[str, Any] with nested structure - ] - - provider = DirectLangChainToolsProvider(tools=showcase_tools, app_name="showcase") - await provider.initialize() - - logger.info("\nConverting tools with complex return types to OpenAPI...") - - oas = await cuga_tools_to_oas( - tool_provider=provider, - title="Complex Return Types Showcase", - version="1.0.0", - description="Demonstrates how various complex return types are represented in OpenAPI specs", - ) - - # Save the spec - output_path = Path("complex_objects_showcase.yaml") - oas.save(output_path) - - logger.info(f"\n✅ OpenAPI spec saved to: {output_path}") - logger.info("\nWhat to look for in the generated spec:") - logger.info(" 1. 'components/schemas' section - Contains all Pydantic model definitions") - logger.info(" 2. Each tool's response schema - References the component schemas") - logger.info(" 3. Nested objects - Fully expanded with all properties") - logger.info(" 4. Arrays - Properly typed with 'items' schema") - logger.info(" 5. Optional fields - Marked appropriately in the schema") - - logger.info("\n" + "=" * 80) - logger.info("EXAMPLE OUTPUT STRUCTURES:") - logger.info("=" * 80) - - # Show example outputs - logger.info("\n1. ContactInfo (Pydantic Model):") - logger.info(" {") - logger.info(' "id": 12345,') - logger.info(' "name": "John Doe",') - logger.info(' "email": "john@example.com",') - logger.info(' "phone": "+1-555-0001",') - logger.info(' "company": "Acme Corp",') - logger.info(' "tags": ["customer", "vip"]') - logger.info(" }") - - logger.info("\n2. List[ContactInfo]:") - logger.info(" [") - logger.info(" { ContactInfo object 1 },") - logger.info(" { ContactInfo object 2 },") - logger.info(" ...") - logger.info(" ]") - - logger.info("\n3. EmailResult with nested EmailMetadata:") - logger.info(" {") - logger.info(' "success": true,') - logger.info(' "recipient": "user@example.com",') - logger.info(' "subject": "Hello",') - logger.info(' "metadata": {') - logger.info(' "sent_at": "2024-01-01T12:00:00",') - logger.info(' "message_id": "msg_12345",') - logger.info(' "status": "delivered"') - logger.info(" }") - logger.info(" }") - - logger.info("\n4. WeatherData with forecast array:") - logger.info(" {") - logger.info(' "location": "New York",') - logger.info(' "temperature": 22.5,') - logger.info(' "forecast": [') - logger.info(' {"day": "Tomorrow", "high": 24, "low": 18},') - logger.info(" ...") - logger.info(" ]") - logger.info(" }") - - logger.info("\n5. Nested Dict[str, Any]:") - logger.info(" {") - logger.info(' "total_contacts": 1250,') - logger.info(' "by_category": {') - logger.info(' "customers": 800,') - logger.info(' "leads": 350') - logger.info(" },") - logger.info(' "top_companies": [') - logger.info(' {"name": "Acme", "contact_count": 45}') - logger.info(" ]") - logger.info(" }") - - logger.info("\n" + "=" * 80) - logger.info("✅ All these structures are fully represented in the OpenAPI spec!") - logger.info("=" * 80) - - return oas - - -async def main(): - """ - Run all examples demonstrating complex object return types. - """ - logger.info("\n" + "🚀 " * 20) - logger.info("CUGA TO TOOLGUARD - COMPLEX OBJECT EXAMPLES") - logger.info("🚀 " * 20 + "\n") - - logger.info("=" * 80) - logger.info("Example 1: Complex Objects Showcase") - logger.info("=" * 80) - try: - await example_complex_objects_showcase() - except Exception as e: - logger.error(f"Example 1 failed: {e}") - import traceback - traceback.print_exc() - - logger.info("\n" + "=" * 80) - logger.info("Example 2: Convert all tools to OpenAPI") - logger.info("=" * 80) - try: - await example_convert_tools_to_oas() - except Exception as e: - logger.error(f"Example 2 failed: {e}") - import traceback - traceback.print_exc() - - logger.info("\n" + "=" * 80) - logger.info("Example 3: Convert to ToolInfo objects") - logger.info("=" * 80) - try: - await example_convert_to_toolinfo() - except Exception as e: - logger.error(f"Example 3 failed: {e}") - import traceback - traceback.print_exc() - - logger.info("\n" + "=" * 80) - logger.info("Example 4: Convert multiple app groups") - logger.info("=" * 80) - try: - await example_with_multiple_apps() - except Exception as e: - logger.error(f"Example 4 failed: {e}") - import traceback - traceback.print_exc() - - logger.info("\n" + "🎉 " * 20) - logger.info("ALL EXAMPLES COMPLETED!") - logger.info("🎉 " * 20 + "\n") - - -if __name__ == "__main__": - asyncio.run(main()) - - -# ============================================================================ -# SUMMARY: Complex Object Return Types in OpenAPI Specs -# ============================================================================ -# -# This file demonstrates that YES, OpenAPI specs WILL reflect complex object -# return types! Here's what's included: -# -# 1. PYDANTIC MODELS (ContactInfo, EmailResult, WeatherData, etc.) -# - Automatically converted to JSON Schema -# - All fields with proper types (int, str, bool, etc.) -# - Nested models preserved (EmailResult contains EmailMetadata) -# - Optional fields marked appropriately -# -# 2. LIST OF COMPLEX OBJECTS (List[ContactInfo], List[ProductModel]) -# - Represented as arrays in OpenAPI -# - Item schema fully defined with all properties -# -# 3. NESTED DICTIONARIES (get_contact_stats returns Dict[str, Any]) -# - Hierarchical structure preserved -# - Nested objects and arrays maintained -# -# 4. MIXED COMPLEXITY -# - WeatherData includes both simple fields AND a forecast array -# - SearchResult includes metadata AND a list of ContactInfo objects -# -# The generated OpenAPI specification will include: -# - Full schema definitions in 'components/schemas' -# - Response schemas for each tool referencing these components -# - Complete type information for all fields -# - Proper handling of optional/required fields -# -# Run this file to see the generated OpenAPI specs: -# python src/cuga/backend/cuga_graph/policy/tool_guard/example_usage.py -# -# Then check the generated YAML/JSON files to see how complex objects are -# represented in the OpenAPI specification! -# -# Made with Bob diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py b/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py new file mode 100644 index 00000000..c1183d10 --- /dev/null +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py @@ -0,0 +1,190 @@ +""" +Manager for generating tool guard examples using the toolguard library. + +This module provides integration between CUGA's policy system and toolguard's +example generation capabilities to create violating and compliance examples +for tool usage policies. +""" + +from typing import List, Tuple, Dict, Any +from pathlib import Path +from loguru import logger + +from cuga.backend.cuga_graph.policy.models import ToolGuide, PolicyType +from cuga.backend.cuga_graph.policy.tool_guard.cuga_llm_adapter import CugaLLMAdapter + +from cuga.sdk import CugaAgent +from toolguard.buildtime.buildtime import ToolGuardSpec, generate_guard_examples +from toolguard.extra.langchain_to_oas import langchain_tools_to_openapi +from toolguard.runtime.data_types import ToolGuardSpecItem + + + +class ToolGuardManager: + + def __init__( + self, + agent: CugaAgent, + ): + + self.langchain_tools: List = [] # Store LangChain tools + self.tools_dict: Dict[str, Any] = {} # Store OpenAPI dict for ToolGuard + self._initialized = False + + self.tool_provider = agent.tool_provider + + # Wrap CUGA's LLM with adapter for ToolGuard compatibility + if agent._model is None: + raise ValueError( + "Agent model is not initialized. Ensure the CugaAgent has a valid model " + "before creating ToolGuardManager." + ) + + self.llm = CugaLLMAdapter(agent._model) + logger.info(f"Initialized ToolGuardManager with {type(agent._model).__name__} via CugaLLMAdapter") + + # Use Path to properly handle path concatenation + self.work_dir = str(Path(agent.cuga_folder) / "toolguard") + + + async def initialize( + self, + ) -> None: + + logger.info("Initializing ToolGuardManager...") + + # Get all LangChain tools from the tool provider + self.langchain_tools = await self.tool_provider.get_all_tools() + + # Convert LangChain tools to OpenAPI dict using ToolGuard's utility + self.tools_dict = langchain_tools_to_openapi(self.langchain_tools) + + self._initialized = True + logger.info(f"✅ ToolGuardManager initialized with {len(self.langchain_tools)} tools") + + async def generate_examples( + self, + policy: ToolGuide, + target_tool: str + ) -> Tuple[List[str], List[str]]: + """ + Generate violating and compliance examples for a specific tool in a ToolGuide policy. + + Args: + policy: ToolGuide policy to generate examples for + target_tool: Specific tool name to generate examples for + + Returns: + Tuple of (violating_examples, compliance_examples) + + Raises: + RuntimeError: If manager not initialized + ValueError: If policy is not a ToolGuide or target_tool not in policy.target_tools + """ + # Ensure manager is initialized + if not self._initialized or not self.langchain_tools: + raise RuntimeError( + "ToolGuardManager not initialized. Call initialize() first with a tool provider." + ) + + # Validate policy type + if policy.type != PolicyType.TOOL_GUIDE: + raise ValueError( + f"Policy must be of type 'tool_guide', got '{policy.type}'. " + f"Only tool_guide policies can generate examples." + ) + + # Validate that target_tool is in policy.target_tools or policy has wildcard + if "*" not in policy.target_tools and target_tool not in policy.target_tools: + raise ValueError( + f"Tool '{target_tool}' is not in policy.target_tools. " + f"Policy targets: {policy.target_tools}" + ) + + # Verify tool exists in our LangChain tools + tool_names = [tool.name for tool in self.langchain_tools] + if target_tool not in tool_names: + raise ValueError( + f"Tool '{target_tool}' not found in available tools. " + f"Available tools: {tool_names}" + ) + + logger.info(f"Generating examples for tool '{target_tool}'...") + + # Create ToolGuardSpecItem with policy information + # Concatenate description and guide_content for the description field + description = policy.description + if policy.guide_content: + description = f"{description}\n\n{policy.guide_content}" + + spec_item = ToolGuardSpecItem( + name=policy.name, + description=description, + references=[policy.guide_content] if policy.guide_content else [] + ) + + # Create ToolGuardSpec with the spec item + spec = ToolGuardSpec( + tool_name=target_tool, + policy_items=[spec_item] + ) + + # Generate examples using toolguard + try: + updated_specs = await generate_guard_examples( + tools=self.tools_dict, # Pass the OpenAPI dict + tool_specs=[spec], + llm=self.llm, # type: ignore + work_dir=self.work_dir + ) + + # Extract examples from the updated spec + if updated_specs and len(updated_specs) > 0: + updated_spec = updated_specs[0] + if updated_spec.policy_items and len(updated_spec.policy_items) > 0: + policy_item = updated_spec.policy_items[0] + + violating_examples = policy_item.violation_examples + compliance_examples = policy_item.compliance_examples + + logger.info( + f"✅ Generated {len(violating_examples)} violating and {len(compliance_examples)} " + f"compliance examples for tool '{target_tool}'" + ) + + return violating_examples, compliance_examples + else: + logger.warning(f"No policy items in updated spec for tool '{target_tool}'") + return [], [] + else: + logger.warning(f"No results returned for tool '{target_tool}'") + return [], [] + + except Exception as e: + logger.error( + f"❌ Failed to generate examples for tool '{target_tool}': {e}" + ) + raise + + @property + def is_initialized(self) -> bool: + """Check if the manager has been initialized.""" + return self._initialized + + def get_available_tools(self) -> List[str]: + """ + Get list of available tool names. + + Returns: + List of tool names that can be used in policies + + Raises: + RuntimeError: If manager not initialized + """ + if not self._initialized or not self.langchain_tools: + raise RuntimeError("ToolGuardManager not initialized. Call initialize() first.") + + return [tool.name for tool in self.langchain_tools] + + +# Made with Bob \ No newline at end of file diff --git a/src/cuga/sdk_core/tests/debug_tool_guide.py b/src/cuga/sdk_core/tests/debug_tool_guide.py index c3e1f8ef..d84150ed 100644 --- a/src/cuga/sdk_core/tests/debug_tool_guide.py +++ b/src/cuga/sdk_core/tests/debug_tool_guide.py @@ -42,7 +42,19 @@ async def main(): description="Safety guidelines for file deletion operations to prevent accidental data loss", ) print(f"Created policy with ID: {policy_id}") - + policies = await agent.policies.list() + print(policies[0]["name"]) + print(policies[0]["type"]) + + policy_details = agent.policies.get(policy_id) + for policy in policies: + print(f"ID: {policy['id']}") + print(f"Name: {policy['name']}") + print(f"Type: {policy['type']}") + print(f"Enabled: {policy['enabled']}") + print(f"Priority: {policy['priority']}") + print("Policy updated with tool_guards") + # Step 2b: Update the policy with tool_guards print("\nStep 2b: Updating policy with tool_guards...") await agent.policies.update_tool_guard( diff --git a/src/cuga/sdk_core/tests/test_sdk_policies.py b/src/cuga/sdk_core/tests/test_sdk_policies.py index 1f591db3..ced00b9b 100644 --- a/src/cuga/sdk_core/tests/test_sdk_policies.py +++ b/src/cuga/sdk_core/tests/test_sdk_policies.py @@ -428,6 +428,7 @@ async def test_tool_guide_wildcard(self): assert policy_id is not None policy_details = await agent.policies.get(policy_id) + assert policy_details is not None assert policy_details["policy"].target_tools == ["*"] From b478b3e14d0f514ced3cfe256e5da480a214c1aa Mon Sep 17 00:00:00 2001 From: naamaz Date: Mon, 20 Apr 2026 11:16:13 +0300 Subject: [PATCH 04/24] examples in sdk --- .../cuga_graph/policy/tool_guard/__init__.py | 3 +- src/cuga/sdk.py | 82 ++++++++++++++++++- src/cuga/sdk_core/tests/debug_tool_guide.py | 35 ++++---- 3 files changed, 101 insertions(+), 19 deletions(-) diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py b/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py index 4b886b9a..73312cf8 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py @@ -5,9 +5,8 @@ policy enforcement framework. """ -from .cuga_to_tool_info import cuga_tools_to_oas from .manager import ToolGuardManager -__all__ = ["cuga_tools_to_oas", "ToolGuardManager"] +__all__ = ["ToolGuardManager"] # Made with Bob diff --git a/src/cuga/sdk.py b/src/cuga/sdk.py index 6a76be26..1c746024 100644 --- a/src/cuga/sdk.py +++ b/src/cuga/sdk.py @@ -68,7 +68,7 @@ def delete_database(table: str) -> str: ``` """ -from typing import List, Optional, Dict, Any, Union, TYPE_CHECKING +from typing import List, Optional, Dict, Any, Union, TYPE_CHECKING, Tuple import uuid from loguru import logger from pydantic import BaseModel, Field @@ -1213,7 +1213,85 @@ async def sync_from_filesystem(self) -> Dict[str, Any]: } except Exception as e: logger.error(f"Failed to sync from filesystem: {e}") - return {"loaded": 0, "removed": 0, "errors": [str(e)]} + + async def generate_tool_guard_examples( + self, + policy_id: str, + target_tool: str + ) -> Tuple[List[str], List[str]]: + """ + Generate violating and compliance examples for a specific tool in a policy. + + This method uses the ToolGuardManager to generate examples that demonstrate + both violations and compliance with the policy guidelines for a specific tool. + + Args: + policy_id: The ID of the policy to generate examples for + target_tool: The specific tool name to generate examples for + + Returns: + Tuple of (violating_examples, compliance_examples) + + Raises: + ValueError: If policy not found, not a ToolGuide, or target_tool not in policy + RuntimeError: If ToolGuardManager initialization fails + + Example: + ```python + agent = CugaAgent(tools=[delete_file]) + + # Add a tool guide policy + policy_id = await agent.policies.add_tool_guide( + name="File Safety", + target_tools=["delete_file"], + content="Never delete system files" + ) + + # Generate examples + violating, compliance = await agent.policies.generate_tool_guard_examples( + policy_id=policy_id, + target_tool="delete_file" + ) + + print(f"Violating: {violating}") + print(f"Compliance: {compliance}") + ``` + """ + from cuga.backend.cuga_graph.policy.tool_guard.manager import ToolGuardManager + from cuga.backend.cuga_graph.policy.models import PolicyType + + # Ensure policy system is initialized + policy_system = await self._ensure_policy_system() + if policy_system is None: + raise RuntimeError("Policy system is disabled") + + # Get the policy + policy_data = await self.get(policy_id) + if policy_data is None: + raise ValueError(f"Policy with ID '{policy_id}' not found") + + policy = policy_data.get('policy') + if policy is None: + raise ValueError(f"Could not retrieve policy object for ID '{policy_id}'") + + # Validate policy type + if policy.type != PolicyType.TOOL_GUIDE: + raise ValueError( + f"Policy must be of type 'tool_guide', got '{policy.type}'. " + f"Only tool_guide policies can generate examples." + ) + + # Create and initialize ToolGuardManager + manager = ToolGuardManager(self._agent) + await manager.initialize() + + # Generate examples using the manager + violating_examples, compliance_examples = await manager.generate_examples( + policy=policy, + target_tool=target_tool + ) + + return violating_examples, compliance_examples class CugaAgent: diff --git a/src/cuga/sdk_core/tests/debug_tool_guide.py b/src/cuga/sdk_core/tests/debug_tool_guide.py index d84150ed..7629ec20 100644 --- a/src/cuga/sdk_core/tests/debug_tool_guide.py +++ b/src/cuga/sdk_core/tests/debug_tool_guide.py @@ -55,30 +55,35 @@ async def main(): print(f"Priority: {policy['priority']}") print("Policy updated with tool_guards") - # Step 2b: Update the policy with tool_guards - print("\nStep 2b: Updating policy with tool_guards...") + # Step 2b: Generate examples using the new SDK function + print("\nStep 2b: Generating tool guard examples using SDK...") + violating_examples, compliance_examples = await agent.policies.generate_tool_guard_examples( + policy_id=policy_id, + target_tool="delete_file" + ) + + print(f"Generated {len(violating_examples)} violating examples and {len(compliance_examples)} compliance examples") + print("\nViolating examples:") + for i, example in enumerate(violating_examples, 1): + print(f" {i}. {example}") + print("\nCompliance examples:") + for i, example in enumerate(compliance_examples, 1): + print(f" {i}. {example}") + + # Step 2c: Update the policy with generated tool_guards + print("\nStep 2c: Updating policy with generated tool_guards...") await agent.policies.update_tool_guard( policy_id=policy_id, tool_guards={ "delete_file": { "description": "Guard rules for safe file deletion to prevent accidental data loss", - "violating_examples": [ - "Deleting system files like /etc/passwd or C:\\Windows\\System32", - "Deleting files without user confirmation", - "Deleting files in protected directories", - "Bulk deletion without verification" - ], - "compliance_examples": [ - "Verify file path is in user's home directory before deletion", - "Request explicit user confirmation before deleting", - "Check file is not a system file or in protected directory", - "Log all deletion operations with timestamp and user" - ], + "violating_examples": violating_examples, + "compliance_examples": compliance_examples, "policy_code": "" } } ) - print(f"Updated policy {policy_id} with tool_guards") + print(f"Updated policy {policy_id} with generated tool_guards") # Step 3: Save the policy print("\nStep 3: Saving policy...") From c432d21b78db21197b67bd5a057ce1b87a37660f Mon Sep 17 00:00:00 2001 From: naamaz Date: Mon, 20 Apr 2026 14:38:05 +0300 Subject: [PATCH 05/24] buildtime --- .../cuga_graph/policy/tool_guard/manager.py | 354 +++++++++++++----- src/cuga/sdk.py | 97 +++++ src/cuga/sdk_core/tests/debug_tool_guide.py | 171 --------- .../tests/policies-export-2025-12-31.json | 132 ------- 4 files changed, 356 insertions(+), 398 deletions(-) delete mode 100644 src/cuga/sdk_core/tests/debug_tool_guide.py delete mode 100644 src/cuga/sdk_core/tests/policies-export-2025-12-31.json diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py b/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py index c1183d10..1a8c21e7 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py @@ -6,7 +6,9 @@ for tool usage policies. """ -from typing import List, Tuple, Dict, Any +import shutil +from contextlib import contextmanager +from typing import List, Tuple, Dict, Any, Iterator, Optional from pathlib import Path from loguru import logger @@ -14,176 +16,338 @@ from cuga.backend.cuga_graph.policy.tool_guard.cuga_llm_adapter import CugaLLMAdapter from cuga.sdk import CugaAgent -from toolguard.buildtime.buildtime import ToolGuardSpec, generate_guard_examples +from toolguard.buildtime.buildtime import ToolGuardSpec, generate_guard_examples, generate_guards_code from toolguard.extra.langchain_to_oas import langchain_tools_to_openapi -from toolguard.runtime.data_types import ToolGuardSpecItem - +from toolguard.runtime.data_types import ToolGuardSpecItem, ToolGuardsCodeGenerationResult class ToolGuardManager: - + def __init__( self, agent: CugaAgent, ): - - self.langchain_tools: List = [] # Store LangChain tools + self.langchain_tools: List[Any] = [] # Store LangChain tools self.tools_dict: Dict[str, Any] = {} # Store OpenAPI dict for ToolGuard self._initialized = False - + self.tool_provider = agent.tool_provider - + # Wrap CUGA's LLM with adapter for ToolGuard compatibility if agent._model is None: raise ValueError( "Agent model is not initialized. Ensure the CugaAgent has a valid model " "before creating ToolGuardManager." ) - + self.llm = CugaLLMAdapter(agent._model) logger.info(f"Initialized ToolGuardManager with {type(agent._model).__name__} via CugaLLMAdapter") - - # Use Path to properly handle path concatenation - self.work_dir = str(Path(agent.cuga_folder) / "toolguard") - - + # Use Path to properly handle path concatenation and ensure directory exists + work_dir_path = Path(agent.cuga_folder) / "toolguard" + work_dir_path.mkdir(parents=True, exist_ok=True) + self.work_dir = str(work_dir_path) + logger.debug(f"ToolGuard working directory: {self.work_dir}") + async def initialize( self, ) -> None: - logger.info("Initializing ToolGuardManager...") - + # Get all LangChain tools from the tool provider self.langchain_tools = await self.tool_provider.get_all_tools() - + # Convert LangChain tools to OpenAPI dict using ToolGuard's utility self.tools_dict = langchain_tools_to_openapi(self.langchain_tools) - + self._initialized = True logger.info(f"✅ ToolGuardManager initialized with {len(self.langchain_tools)} tools") - - async def generate_examples( - self, - policy: ToolGuide, - target_tool: str - ) -> Tuple[List[str], List[str]]: - """ - Generate violating and compliance examples for a specific tool in a ToolGuide policy. - - Args: - policy: ToolGuide policy to generate examples for - target_tool: Specific tool name to generate examples for - - Returns: - Tuple of (violating_examples, compliance_examples) - - Raises: - RuntimeError: If manager not initialized - ValueError: If policy is not a ToolGuide or target_tool not in policy.target_tools - """ - # Ensure manager is initialized + + def _ensure_initialized(self) -> None: + """Ensure manager is initialized, raise RuntimeError if not.""" if not self._initialized or not self.langchain_tools: raise RuntimeError( "ToolGuardManager not initialized. Call initialize() first with a tool provider." ) - - # Validate policy type + + def _validate_policy_and_tool(self, policy: ToolGuide, target_tool: str) -> None: + """Validate policy type and target tool.""" if policy.type != PolicyType.TOOL_GUIDE: raise ValueError( f"Policy must be of type 'tool_guide', got '{policy.type}'. " f"Only tool_guide policies can generate examples." ) - - # Validate that target_tool is in policy.target_tools or policy has wildcard + if "*" not in policy.target_tools and target_tool not in policy.target_tools: raise ValueError( f"Tool '{target_tool}' is not in policy.target_tools. " f"Policy targets: {policy.target_tools}" ) - - # Verify tool exists in our LangChain tools + tool_names = [tool.name for tool in self.langchain_tools] if target_tool not in tool_names: raise ValueError( f"Tool '{target_tool}' not found in available tools. " f"Available tools: {tool_names}" ) - - logger.info(f"Generating examples for tool '{target_tool}'...") - - # Create ToolGuardSpecItem with policy information - # Concatenate description and guide_content for the description field + + def _build_description(self, policy: ToolGuide) -> str: + """Build description from policy, concatenating guide_content if present.""" description = policy.description if policy.guide_content: description = f"{description}\n\n{policy.guide_content}" - - spec_item = ToolGuardSpecItem( - name=policy.name, - description=description, - references=[policy.guide_content] if policy.guide_content else [] - ) - + return description + + def _create_spec_item( + self, + policy: ToolGuide, + violating_examples: Optional[List[str]] = None, + compliance_examples: Optional[List[str]] = None + ) -> ToolGuardSpecItem: + """Create ToolGuardSpecItem from policy.""" + description = self._build_description(policy) + + kwargs: Dict[str, Any] = { + "name": policy.name, + "description": description, + "references": [policy.guide_content] if policy.guide_content else [] + } + + if violating_examples is not None: + kwargs["violation_examples"] = violating_examples + if compliance_examples is not None: + kwargs["compliance_examples"] = compliance_examples + + return ToolGuardSpecItem(**kwargs) + + @contextmanager + def _temp_directory(self) -> Iterator[Path]: + """Context manager for temporary toolguard directory.""" + tmp_dir = Path(self.work_dir) / "tmp" + tmp_dir.mkdir(parents=True, exist_ok=True) + try: + yield tmp_dir + finally: + if tmp_dir.exists(): + shutil.rmtree(tmp_dir) + logger.debug(f"Cleaned up temporary directory: {tmp_dir}") + + def _save_domain_files(self, result: ToolGuardsCodeGenerationResult) -> None: + """Save RuntimeDomain files to domain directory.""" + work_dir_path = Path(self.work_dir) + domain_dir = work_dir_path / "domain" + domain_dir.mkdir(parents=True, exist_ok=True) + + for attr_name in ["app_types", "app_api", "app_api_impl"]: + domain_file = getattr(result.domain, attr_name) + domain_file.save(domain_dir) + logger.info(f"Saved {attr_name} to {domain_dir / domain_file.file_name}") + + async def generate_examples( + self, + policy: ToolGuide, + target_tool: str + ) -> Tuple[List[str], List[str]]: + """ + Generate violating and compliance examples for a specific tool in a ToolGuide policy. + + Args: + policy: ToolGuide policy to generate examples for + target_tool: Specific tool name to generate examples for + + Returns: + Tuple of (violating_examples, compliance_examples) + + Raises: + RuntimeError: If manager not initialized + ValueError: If policy is not a ToolGuide or target_tool not in policy.target_tools + """ + self._ensure_initialized() + self._validate_policy_and_tool(policy, target_tool) + + logger.info(f"Generating examples for tool '{target_tool}'...") + + # Create ToolGuardSpecItem with policy information + spec_item = self._create_spec_item(policy) + # Create ToolGuardSpec with the spec item spec = ToolGuardSpec( tool_name=target_tool, policy_items=[spec_item] ) - + # Generate examples using toolguard - try: - updated_specs = await generate_guard_examples( - tools=self.tools_dict, # Pass the OpenAPI dict - tool_specs=[spec], - llm=self.llm, # type: ignore - work_dir=self.work_dir + with self._temp_directory() as tmp_dir: + try: + updated_specs = await generate_guard_examples( + tools=self.tools_dict, # Pass the OpenAPI dict + tool_specs=[spec], + llm=self.llm, # type: ignore + work_dir=str(tmp_dir) + ) + + # Extract examples from the updated spec + if updated_specs: + updated_spec = updated_specs[0] + if updated_spec.policy_items: + policy_item = updated_spec.policy_items[0] + + violating_examples = policy_item.violation_examples + compliance_examples = policy_item.compliance_examples + + logger.info( + f"✅ Generated {len(violating_examples)} violating and {len(compliance_examples)} " + f"compliance examples for tool '{target_tool}'" + ) + + return violating_examples, compliance_examples + else: + logger.warning(f"No policy items in updated spec for tool '{target_tool}'") + return [], [] + else: + logger.warning(f"No results returned for tool '{target_tool}'") + return [], [] + + except Exception as e: + logger.error( + f"❌ Failed to generate examples for tool '{target_tool}': {e}" + ) + raise RuntimeError(f"Failed to generate examples for tool '{target_tool}'") from e + + async def generate_guard_code( + self, + policy: ToolGuide, + target_tool: str, + app_name: str = "cuga_app" + ) -> str: + """ + Generate guard code for a specific tool in a ToolGuide policy. + + This method creates a ToolGuardSpec from the policy, validates it has examples, + calls toolguard's generate_guards_code, saves the RuntimeDomain to a file, + and returns the generated guard code content. + + Args: + policy: ToolGuide policy to generate guard code for + target_tool: Specific tool name to generate guard code for + app_name: Application name for the generated code (default: "cuga_app") + + Returns: + String containing the generated guard code + + Raises: + RuntimeError: If manager not initialized + ValueError: If policy is not a ToolGuide, target_tool not in policy.target_tools, + or if the policy doesn't have examples for the target tool + """ + self._ensure_initialized() + self._validate_policy_and_tool(policy, target_tool) + + logger.info(f"Generating guard code for tool '{target_tool}'...") + + # Check if policy has tool_guards for this specific tool + tool_guard = None + if policy.tool_guards and target_tool in policy.tool_guards: + tool_guard = policy.tool_guards[target_tool] + + # Validate that we have examples (either from tool_guards or need to generate them first) + if tool_guard: + violating_examples = tool_guard.violating_examples + compliance_examples = tool_guard.compliance_examples + else: + violating_examples = [] + compliance_examples = [] + + # Ensure we have examples + if not violating_examples and not compliance_examples: + raise ValueError( + f"Policy for tool '{target_tool}' must have examples before generating guard code. " + f"Call generate_examples() first to create examples, or provide them in the policy's tool_guards." ) - - # Extract examples from the updated spec - if updated_specs and len(updated_specs) > 0: - updated_spec = updated_specs[0] - if updated_spec.policy_items and len(updated_spec.policy_items) > 0: - policy_item = updated_spec.policy_items[0] + + # Create ToolGuardSpecItem with policy information and examples + spec_item = self._create_spec_item( + policy, + violating_examples=violating_examples, + compliance_examples=compliance_examples + ) + + # Create ToolGuardSpec with the spec item + spec = ToolGuardSpec( + tool_name=target_tool, + policy_items=[spec_item] + ) + + # Generate guard code using toolguard + with self._temp_directory() as tmp_dir: + try: + result: ToolGuardsCodeGenerationResult = await generate_guards_code( + tools=self.tools_dict, # Pass the OpenAPI dict + tool_specs=[spec], + work_dir=str(tmp_dir), + llm=self.llm, # type: ignore + app_name=app_name + ) + + # Save RuntimeDomain files directly under toolguard directory (not in tmp) + self._save_domain_files(result) + + # Extract guard code from the result + if target_tool in result.tools: + tool_result = result.tools[target_tool] + + # Get the item guard file content (should be only one) + if not tool_result.item_guard_files: + raise ValueError( + f"No item guard files generated for tool '{target_tool}'" + ) - violating_examples = policy_item.violation_examples - compliance_examples = policy_item.compliance_examples + if len(tool_result.item_guard_files) > 1: + logger.warning( + f"Multiple item guard files found for tool '{target_tool}', using the first one" + ) + item_guard_file = tool_result.item_guard_files[0] + if item_guard_file is None: + raise ValueError( + f"Item guard file is None for tool '{target_tool}'" + ) + + guard_code = item_guard_file.content + logger.info( - f"✅ Generated {len(violating_examples)} violating and {len(compliance_examples)} " - f"compliance examples for tool '{target_tool}'" + f"✅ Generated guard code for tool '{target_tool}' " + f"(guard function: {tool_result.guard_fn_name})" ) - - return violating_examples, compliance_examples + + return guard_code else: - logger.warning(f"No policy items in updated spec for tool '{target_tool}'") - return [], [] - else: - logger.warning(f"No results returned for tool '{target_tool}'") - return [], [] - - except Exception as e: - logger.error( - f"❌ Failed to generate examples for tool '{target_tool}': {e}" - ) - raise - + raise ValueError( + f"Tool '{target_tool}' not found in generation results. " + f"Available tools: {list(result.tools.keys())}" + ) + + except Exception as e: + logger.error( + f"❌ Failed to generate guard code for tool '{target_tool}': {e}" + ) + raise RuntimeError(f"Failed to generate guard code for tool '{target_tool}'") from e + @property def is_initialized(self) -> bool: """Check if the manager has been initialized.""" return self._initialized - + def get_available_tools(self) -> List[str]: """ Get list of available tool names. - + Returns: List of tool names that can be used in policies - + Raises: RuntimeError: If manager not initialized """ - if not self._initialized or not self.langchain_tools: - raise RuntimeError("ToolGuardManager not initialized. Call initialize() first.") - + self._ensure_initialized() return [tool.name for tool in self.langchain_tools] diff --git a/src/cuga/sdk.py b/src/cuga/sdk.py index 1c746024..f0b8176a 100644 --- a/src/cuga/sdk.py +++ b/src/cuga/sdk.py @@ -1292,6 +1292,103 @@ async def generate_tool_guard_examples( ) return violating_examples, compliance_examples + + async def generate_tool_guard_code( + self, + policy_id: str, + target_tool: str, + app_name: str = "cuga_app" + ) -> str: + """ + Generate guard code for a specific tool in a policy. + + This method uses the ToolGuardManager to generate executable guard code + that validates tool usage compliance with the policy guidelines. + + Args: + policy_id: The ID of the policy to generate guard code for + target_tool: The specific tool name to generate guard code for + app_name: Application name for the generated code (default: "cuga_app") + + Returns: + String containing the generated guard code + + Raises: + ValueError: If policy not found, not a ToolGuide, target_tool not in policy, + or if the policy doesn't have examples for the target tool + RuntimeError: If ToolGuardManager initialization fails + + Example: + ```python + agent = CugaAgent(tools=[delete_file]) + + # Add a tool guide policy with examples + policy_id = await agent.policies.add_tool_guide( + name="File Safety", + target_tools=["delete_file"], + content="Never delete system files" + ) + + # Generate examples first + violating, compliance = await agent.policies.generate_tool_guard_examples( + policy_id=policy_id, + target_tool="delete_file" + ) + + # Update policy with examples + await agent.policies.update_tool_guard( + policy_id=policy_id, + target_tool="delete_file", + violating_examples=violating, + compliance_examples=compliance + ) + + # Generate guard code + guard_code = await agent.policies.generate_tool_guard_code( + policy_id=policy_id, + target_tool="delete_file" + ) + + print(f"Generated guard code:\n{guard_code}") + ``` + """ + from cuga.backend.cuga_graph.policy.tool_guard.manager import ToolGuardManager + from cuga.backend.cuga_graph.policy.models import PolicyType + + # Ensure policy system is initialized + policy_system = await self._ensure_policy_system() + if policy_system is None: + raise RuntimeError("Policy system is disabled") + + # Get the policy + policy_data = await self.get(policy_id) + if policy_data is None: + raise ValueError(f"Policy with ID '{policy_id}' not found") + + policy = policy_data.get('policy') + if policy is None: + raise ValueError(f"Could not retrieve policy object for ID '{policy_id}'") + + # Validate policy type + if policy.type != PolicyType.TOOL_GUIDE: + raise ValueError( + f"Policy must be of type 'tool_guide', got '{policy.type}'. " + f"Only tool_guide policies can generate guard code." + ) + + # Create and initialize ToolGuardManager + manager = ToolGuardManager(self._agent) + await manager.initialize() + + # Generate guard code using the manager + guard_code = await manager.generate_guard_code( + policy=policy, + target_tool=target_tool, + app_name=app_name + ) + + return guard_code + return violating_examples, compliance_examples class CugaAgent: diff --git a/src/cuga/sdk_core/tests/debug_tool_guide.py b/src/cuga/sdk_core/tests/debug_tool_guide.py deleted file mode 100644 index 7629ec20..00000000 --- a/src/cuga/sdk_core/tests/debug_tool_guide.py +++ /dev/null @@ -1,171 +0,0 @@ -""" -Debug script for creating, storing, retrieving, and exporting a tool guide policy. - -This script demonstrates: -1. Creating a CugaAgent with a tool -2. Adding a new tool guide policy -3. Saving the policy -4. Retrieving the policy from storage -5. Exporting the policy data to a JSON file -""" - -import asyncio -import json -from pathlib import Path -from datetime import datetime - -from langchain_core.tools import tool - -from cuga import CugaAgent -from cuga.backend.cuga_graph.policy.models import ToolGuide - - -@tool -def delete_file(file_path: str) -> str: - """Delete a file from the system""" - return f"File {file_path} has been deleted" - - -async def main(): - """Main workflow for creating and managing a tool guide policy.""" - - # Step 1: Create a CugaAgent with the delete_file tool - print("Step 1: Creating CugaAgent with delete_file tool...") - agent = CugaAgent(tools=[delete_file]) - - # Step 2: Add a new tool guide policy - print("\nStep 2: Adding new tool guide policy...") - policy_id = await agent.policies.add_tool_guide( - name="File Deletion Safety Policy", - content="## File Deletion Guidelines\n- Never delete system files\n- Always verify file path before deletion\n- Require confirmation for critical files\n- Log all deletion operations", - target_tools=["delete_file"], - description="Safety guidelines for file deletion operations to prevent accidental data loss", - ) - print(f"Created policy with ID: {policy_id}") - policies = await agent.policies.list() - print(policies[0]["name"]) - print(policies[0]["type"]) - - policy_details = agent.policies.get(policy_id) - for policy in policies: - print(f"ID: {policy['id']}") - print(f"Name: {policy['name']}") - print(f"Type: {policy['type']}") - print(f"Enabled: {policy['enabled']}") - print(f"Priority: {policy['priority']}") - print("Policy updated with tool_guards") - - # Step 2b: Generate examples using the new SDK function - print("\nStep 2b: Generating tool guard examples using SDK...") - violating_examples, compliance_examples = await agent.policies.generate_tool_guard_examples( - policy_id=policy_id, - target_tool="delete_file" - ) - - print(f"Generated {len(violating_examples)} violating examples and {len(compliance_examples)} compliance examples") - print("\nViolating examples:") - for i, example in enumerate(violating_examples, 1): - print(f" {i}. {example}") - print("\nCompliance examples:") - for i, example in enumerate(compliance_examples, 1): - print(f" {i}. {example}") - - # Step 2c: Update the policy with generated tool_guards - print("\nStep 2c: Updating policy with generated tool_guards...") - await agent.policies.update_tool_guard( - policy_id=policy_id, - tool_guards={ - "delete_file": { - "description": "Guard rules for safe file deletion to prevent accidental data loss", - "violating_examples": violating_examples, - "compliance_examples": compliance_examples, - "policy_code": "" - } - } - ) - print(f"Updated policy {policy_id} with generated tool_guards") - - # Step 3: Save the policy - print("\nStep 3: Saving policy...") - policy_tool_guide = await agent.policies.get(policy_id) - if policy_tool_guide is None: - raise ValueError(f"Failed to retrieve policy {policy_id}") - - policy = policy_tool_guide["policy"] - - # Get policy system and update - policy_system = await agent.policies._ensure_policy_system() - if policy_system is None or policy_system.storage is None: - raise ValueError("Policy system storage is not available") - - #await policy_system.storage.update_policy(policy) - await policy_system.initialize() - - # Save to file system if available - if agent.policies._fs_sync: - agent.policies._fs_sync.save_policy_to_file(policy) - - print("Policy saved successfully!") - - # Step 4: Retrieve the policy from storage - print("\nStep 4: Retrieving policy from storage...") - retrieved_policy = await policy_system.storage.get_policy(policy_id) - if retrieved_policy is None: - raise ValueError(f"Failed to retrieve updated policy {policy_id}") - - if not isinstance(retrieved_policy, ToolGuide): - raise TypeError(f"Expected ToolGuide, got {type(retrieved_policy).__name__}") - - print("Policy retrieved successfully!") - - # Step 5: Export to JSON file - print("\nStep 5: Exporting policy data to JSON file...") - - # Convert tool_guards to serializable format - tool_guards_data = None - if retrieved_policy.tool_guards: - tool_guards_data = {} - for tool_name, guard in retrieved_policy.tool_guards.items(): - tool_guards_data[tool_name] = { - "description": guard.description, - "violating_examples": guard.violating_examples, - "compliance_examples": guard.compliance_examples, - "policy_code": guard.policy_code, - } - - output_data = { - "export_timestamp": datetime.utcnow().isoformat(), - "policy_id": policy_id, - "policy_name": retrieved_policy.name, - "policy_description": retrieved_policy.description, - "guide_content": retrieved_policy.guide_content, - "target_tools": retrieved_policy.target_tools, - "target_apps": retrieved_policy.target_apps, - "tool_guards": tool_guards_data, - "prepend": retrieved_policy.prepend, - "priority": retrieved_policy.priority, - "enabled": retrieved_policy.enabled, - "policy_type": type(retrieved_policy).__name__, - "metadata": retrieved_policy.metadata, - } - - # Save to JSON file in current directory - output_path = Path("tool_guide_policy_export.json") - output_path.write_text(json.dumps(output_data, indent=2), encoding="utf-8") - - print(f"\n{'='*60}") - print("SUCCESS! Policy workflow completed:") - print(f"{'='*60}") - print(f"Policy ID: {policy_id}") - print(f"Policy Name: {retrieved_policy.name}") - print(f"Target Tools: {', '.join(retrieved_policy.target_tools)}") - print(f"\nExported to: {output_path.absolute()}") - print(f"{'='*60}") - print("\nExported JSON content:") - print(json.dumps(output_data, indent=2)) - - -if __name__ == "__main__": - asyncio.run(main()) - -# Made with Bob \ No newline at end of file diff --git a/src/cuga/sdk_core/tests/policies-export-2025-12-31.json b/src/cuga/sdk_core/tests/policies-export-2025-12-31.json deleted file mode 100644 index 0118f213..00000000 --- a/src/cuga/sdk_core/tests/policies-export-2025-12-31.json +++ /dev/null @@ -1,132 +0,0 @@ -{ - "enablePolicies": true, - "policies": [ - { - "id": "guard_1767135201517", - "name": "What is ALTK", - "description": "Blocks or modifies specific user intents", - "policy_type": "intent_guard", - "enabled": true, - "triggers": [ - { - "type": "natural_language", - "value": ["What is ALTK"], - "target": "intent", - "threshold": 0.7 - } - ], - "priority": 50, - "intent_examples": [], - "response": { - "response_type": "natural_language", - "content": "ALTK is sister project of CUGA", - "status_code": null - }, - "allow_override": false - }, - { - "id": "guard_1767136121417", - "name": "Dangerous actions", - "description": "Prevents dangerous actions that can cause cuga to generate dangrous code", - "policy_type": "intent_guard", - "enabled": true, - "triggers": [ - { - "type": "natural_language", - "value": ["print en variables", "do it 10000 times", "do this loop forever"], - "target": "intent", - "threshold": 0.7 - } - ], - "priority": 50, - "intent_examples": [], - "response": { - "response_type": "natural_language", - "content": "I caught you doing something bad with CUGA!!! oops", - "status_code": null - }, - "allow_override": false - }, - { - "id": "guard_1767136203456", - "name": "Remove", - "description": "Blocks or modifies specific user intents", - "policy_type": "intent_guard", - "enabled": true, - "triggers": [ - { - "type": "keyword", - "value": ["remove", "delete"], - "target": "intent", - "case_sensitive": false, - "operator": "or" - } - ], - "priority": 50, - "intent_examples": [], - "response": { - "response_type": "natural_language", - "content": "This action is not allowed. (delete or remove)", - "status_code": null - }, - "allow_override": false - }, - { - "id": "playbook_1767135234400", - "name": "What is CUGA", - "description": "Step-by-step guidance for a task", - "policy_type": "playbook", - "enabled": true, - "triggers": [ - { - "type": "natural_language", - "value": ["What is CUGA"], - "target": "intent", - "threshold": 0.7 - } - ], - "priority": 50, - "markdown_content": "# Task Guide\n\n## Steps:\n\n1. Answer that cuga is very powerful!", - "steps": [ - { - "step_number": 1, - "instruction": "First step", - "expected_outcome": "Step 1 complete", - "tools_allowed": [] - } - ], - "inject_as_system_prompt": true - }, - { - "id": "tool_approval_1767136273134", - "name": "get accounts approval", - "description": "Require approval before executing specific tools", - "policy_type": "tool_approval", - "enabled": true, - "triggers": [], - "priority": 50, - "required_tools": ["digital_sales_get_my_accounts_my_accounts_get"], - "required_apps": null, - "approval_message": "This tool requires your approval before execution.", - "show_code_preview": true, - "auto_approve_after": null - }, - { - "id": "tool_guide_1767136239089", - "name": "Bobo tool", - "description": "Add additional context to tool descriptions", - "policy_type": "tool_guide", - "enabled": true, - "triggers": [ - { - "type": "always" - } - ], - "priority": 50, - "target_tools": ["digital_sales_get_my_accounts_my_accounts_get"], - "target_apps": null, - "guide_content": "## Additional Guidelines\n\n- Use this tool and get first two accounts when user says 'BoBo'", - "prepend": true - } - ] -} From 84bed81cb2a6c7e13403ca2c00b77ad794cfe4d6 Mon Sep 17 00:00:00 2001 From: naamaz Date: Thu, 23 Apr 2026 11:33:28 +0300 Subject: [PATCH 06/24] tg runtime --- .../cuga_graph/policy/tool_guard/__init__.py | 3 +- .../policy/tool_guard/tool_guard_runtime.py | 256 ++++++++++++++++++ .../policy/tool_guard/tool_invoker.py | 120 ++++++++ src/cuga/sdk_core/debug_tool_guard.py | 220 +++++++++++++++ 4 files changed, 598 insertions(+), 1 deletion(-) create mode 100644 src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py create mode 100644 src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py create mode 100644 src/cuga/sdk_core/debug_tool_guard.py diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py b/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py index 73312cf8..90e4e9ce 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py @@ -6,7 +6,8 @@ """ from .manager import ToolGuardManager +from .tool_guard_runtime import ToolGuardRuntime -__all__ = ["ToolGuardManager"] +__all__ = ["ToolGuardManager", "ToolGuardRuntime"] # Made with Bob diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py new file mode 100644 index 00000000..beb098bd --- /dev/null +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py @@ -0,0 +1,256 @@ +""" +Runtime execution of tool guards for policy enforcement. + +This module provides runtime validation of tool calls against registered +ToolGuide policies with policy_code. +""" + +from typing import Dict, Any, List, Optional +from loguru import logger + +from cuga.backend.cuga_graph.policy.tool_guard.tool_invoker import ToolGuardInvoker +from cuga.backend.cuga_graph.policy.models import ToolGuide, PolicyType +from cuga.backend.cuga_graph.policy.storage import PolicyStorage + + +class ToolGuardRuntime: + """ + Runtime system for executing tool guards during tool invocation. + + This class: + 1. Initializes a ToolGuardInvoker for tool execution + 2. Loads all ToolGuide policies with policy_code + 3. Creates a mapping: tool_name -> List[ToolGuide with code] + 4. Executes guard validation when tools are called + """ + + def __init__(self, tool_provider, policy_storage: PolicyStorage): + """ + Initialize the ToolGuardRuntime. + + Args: + tool_provider: CUGA's tool provider instance + policy_storage: PolicyStorage instance to load policies from + """ + self.tool_provider = tool_provider + self.policy_storage = policy_storage + + # Create ToolGuardInvoker instance + self.invoker = ToolGuardInvoker(tool_provider) + + # Mapping: tool_name -> List[ToolGuide policies with code] + self.tool_to_guards: Dict[str, List[ToolGuide]] = {} + + self._initialized = False + logger.debug("Created ToolGuardRuntime instance") + + async def initialize(self) -> None: + """ + Initialize the runtime by loading all ToolGuide policies with code. + + This method: + 1. Fetches all ToolGuide policies from storage + 2. Filters for policies that have tool_guards with policy_code + 3. Builds the tool_to_guards mapping + """ + logger.info("Initializing ToolGuardRuntime...") + + # Load all ToolGuide policies + policies = await self.policy_storage.list_policies( + policy_type=PolicyType.TOOL_GUIDE, + enabled_only=True + ) + + logger.debug(f"Found {len(policies)} ToolGuide policies") + + # Build mapping: tool_name -> List[ToolGuide with code] + for policy in policies: + # Type guard: ensure we're working with ToolGuide + if not isinstance(policy, ToolGuide): + logger.warning( + f"Expected ToolGuide but got {type(policy).__name__}, skipping" + ) + continue + + if not policy.tool_guards: + logger.debug(f"Policy '{policy.name}' has no tool_guards, skipping") + continue + + # Iterate through each tool's guard configuration + for tool_name, tool_guard in policy.tool_guards.items(): + # Only include if policy_code exists + if tool_guard.policy_code: + if tool_name not in self.tool_to_guards: + self.tool_to_guards[tool_name] = [] + + self.tool_to_guards[tool_name].append(policy) + logger.debug( + f"Registered guard for tool '{tool_name}' " + f"from policy '{policy.name}'" + ) + else: + logger.debug( + f"Tool guard for '{tool_name}' in policy '{policy.name}' " + f"has no policy_code, skipping" + ) + + self._initialized = True + logger.info( + f"✅ ToolGuardRuntime initialized with guards for " + f"{len(self.tool_to_guards)} tools" + ) + + # Log summary of registered guards + for tool_name, guards in self.tool_to_guards.items(): + logger.debug( + f" - Tool '{tool_name}': {len(guards)} guard(s) " + f"({', '.join(g.name for g in guards)})" + ) + + async def guard_tool_call( + self, + app_name: str, + function_name: str, + arguments: Dict[str, Any] + ) -> Optional[str]: + """ + Validate a tool call against registered guards. + + This method executes all registered guards for the specified tool. + If any guard returns an error message, validation fails and that + error message is returned. If all guards pass (return None), + validation succeeds and None is returned. + + Args: + app_name: Name of the application calling the tool + function_name: Name of the tool/function being called + arguments: Arguments being passed to the tool + + Returns: + Error message string if validation fails, None if validation passes + """ + if not self._initialized: + logger.warning("ToolGuardRuntime not initialized, skipping validation") + return None + + # Check if this tool has any guards + if function_name not in self.tool_to_guards: + logger.debug(f"No guards registered for tool '{function_name}'") + return None + + # Get all guards for this tool + guards = self.tool_to_guards[function_name] + + logger.debug( + f"Validating tool call '{function_name}' against " + f"{len(guards)} guard(s)" + ) + + # Execute each guard's policy_code + for policy in guards: + # Ensure policy has tool_guards and the function_name exists + if not policy.tool_guards or function_name not in policy.tool_guards: + logger.warning( + f"Policy '{policy.name}' missing tool_guard for '{function_name}', skipping" + ) + continue + + tool_guard = policy.tool_guards[function_name] + + try: + logger.debug( + f"Executing guard '{policy.name}' for tool '{function_name}'" + ) + + # Create execution context with necessary variables + exec_globals = { + 'app_name': app_name, + 'function_name': function_name, + 'arguments': arguments, + 'invoker': self.invoker, + } + + # Execute the policy code + # The policy_code should define a 'validate' function + exec(tool_guard.policy_code, exec_globals) + + # Check if validate function was defined + if 'validate' not in exec_globals: + logger.warning( + f"Policy code for '{policy.name}' does not define " + f"a 'validate' function, skipping" + ) + continue + + validate_fn = exec_globals['validate'] + + # Call the validate function + # Support both sync and async validate functions + import inspect + if inspect.iscoroutinefunction(validate_fn): + error = await validate_fn( + app_name=app_name, + function_name=function_name, + arguments=arguments + ) + else: + error = validate_fn( + app_name=app_name, + function_name=function_name, + arguments=arguments + ) + + # If validation failed, return the error message + if error: + logger.warning( + f"Tool guard '{policy.name}' blocked call to " + f"'{function_name}': {error}" + ) + return error + + logger.debug( + f"Guard '{policy.name}' passed for tool '{function_name}'" + ) + + except Exception as e: + logger.error( + f"Error executing guard '{policy.name}' for tool " + f"'{function_name}': {e}", + exc_info=True + ) + # Continue to next guard instead of failing completely + # This ensures one broken guard doesn't break all validation + continue + + # All guards passed + logger.debug(f"Tool call '{function_name}' passed all guards") + return None + + @property + def is_initialized(self) -> bool: + """Check if the runtime has been initialized.""" + return self._initialized + + def get_guarded_tools(self) -> List[str]: + """ + Get list of tool names that have guards registered. + + Returns: + List of tool names with active guards + """ + return list(self.tool_to_guards.keys()) + + def get_guards_for_tool(self, tool_name: str) -> List[ToolGuide]: + """ + Get all guards registered for a specific tool. + + Args: + tool_name: Name of the tool + + Returns: + List of ToolGuide policies with guards for this tool + """ + return self.tool_to_guards.get(tool_name, []) + + +# Made with Bob \ No newline at end of file diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py new file mode 100644 index 00000000..662944b5 --- /dev/null +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py @@ -0,0 +1,120 @@ +""" +ToolGuard invoker for CUGA's tool provider. + +This module provides integration between toolguard's runtime validation +and CUGA's tool provider system. +""" + +from typing import Any, Dict, Optional, Type, TypeVar +from loguru import logger + +from toolguard.runtime.data_types import IToolInvoker + +T = TypeVar('T') + + +class ToolGuardInvoker(IToolInvoker): + """ + Tool invoker that uses CUGA's tool provider for executing tools + during toolguard validation. + + This class bridges toolguard's runtime validation with CUGA's + tool execution system, allowing guards to invoke tools for + validation purposes. + + Similar to LangchainToolInvoker and MCPToolInvoker from the toolguard + library, but adapted to work with CUGA's tool provider. + """ + + def __init__(self, tool_provider): + """ + Initialize the ToolGuardInvoker. + + Args: + tool_provider: CUGA's tool provider instance that manages + and executes tools + """ + self.tool_provider = tool_provider + self._tools_cache: Optional[Dict[str, Any]] = None + logger.debug("Initialized ToolGuardInvoker with CUGA tool provider") + + async def _get_tools(self) -> Dict[str, Any]: + """ + Get all available tools from the tool provider. + + Returns: + Dictionary mapping tool names to tool instances + """ + if self._tools_cache is None: + tools_list = await self.tool_provider.get_all_tools() + self._tools_cache = {tool.name: tool for tool in tools_list} + logger.debug(f"Cached {len(self._tools_cache)} tools") + return self._tools_cache + + async def invoke( + self, + toolname: str, + arguments: Dict[str, Any], + return_type: Type[T] + ) -> T: + """ + Invoke a tool by name with the given arguments. + + This method is called by toolguard during guard validation + to execute tools and verify their behavior. + + Args: + toolname: Name of the tool to invoke + arguments: Dictionary of arguments to pass to the tool + return_type: Expected return type for the tool invocation + + Returns: + The result of the tool invocation, cast to the expected type + + Raises: + ValueError: If the tool is not found + RuntimeError: If tool invocation fails + """ + try: + logger.debug(f"Invoking tool '{toolname}' with args: {arguments}") + + # Get the tool from the provider + tools = await self._get_tools() + + if toolname not in tools: + available_tools = list(tools.keys()) + raise ValueError( + f"Tool '{toolname}' not found. " + f"Available tools: {available_tools}" + ) + + tool = tools[toolname] + + # Invoke the tool using LangChain's invoke method + # LangChain tools typically accept a single input or dict of inputs + result = await tool.ainvoke(arguments) + + logger.debug(f"Tool '{toolname}' invocation successful") + return result + + except ValueError: + # Re-raise ValueError as-is (tool not found) + raise + except Exception as e: + logger.error(f"Failed to invoke tool '{toolname}': {e}") + raise RuntimeError( + f"Tool invocation failed for '{toolname}': {str(e)}" + ) from e + + def clear_cache(self) -> None: + """ + Clear the cached tools. + + Call this method if tools are added or removed from the + tool provider after initialization. + """ + self._tools_cache = None + logger.debug("Cleared tools cache") + + +# Made with Bob \ No newline at end of file diff --git a/src/cuga/sdk_core/debug_tool_guard.py b/src/cuga/sdk_core/debug_tool_guard.py new file mode 100644 index 00000000..79df65ba --- /dev/null +++ b/src/cuga/sdk_core/debug_tool_guard.py @@ -0,0 +1,220 @@ +""" +Debug script for creating a tool guard policy with code generation. + +This script demonstrates: +1. Creating a CugaAgent with flight booking tools +2. Adding a tool guide policy for membership-based restrictions +3. Generating examples for the policy +4. Generating guard code for the policy +5. Updating the policy with examples and code +6. Exporting the policy data to a JSON file +""" + +import asyncio +import json +from pathlib import Path +from datetime import datetime + +from langchain_core.tools import tool + +from cuga import CugaAgent +from cuga.backend.cuga_graph.policy.models import ToolGuide + + +@tool +def book_flight(user_id: str, flight_id: str, passengers: int) -> str: + """Book a flight for a user with specified number of passengers""" + return f"Flight {flight_id} booked for user {user_id} with {passengers} passengers" + + +@tool +def get_membership(user_id: str) -> str: + """Get the membership level of a user (gold, silver, or regular)""" + # Simulate membership lookup + memberships = { + "user123": "gold", + "user456": "silver", + "user789": "regular" + } + return memberships.get(user_id, "regular") + + +async def main(): + """Main workflow for creating and managing a tool guard policy with code generation.""" + + # Step 1: Create a CugaAgent with the flight booking tools + print("Step 1: Creating CugaAgent with flight booking tools...") + agent = CugaAgent(tools=[book_flight, get_membership]) + + # Step 2: Add a new tool guide policy + print("\nStep 2: Adding new tool guide policy...") + policy_id = await agent.policies.add_tool_guide( + name="Flight Booking Membership Policy", + content="""## Flight Booking Restrictions by Membership Level + +### Policy Rules +- Customers with "regular" membership cannot book a flight with more than 3 passengers +- Gold and silver members have no passenger restrictions +- This policy ensures fair resource allocation and encourages membership upgrades + +### Validation Requirements +- Always check user membership level before booking +- Reject bookings that violate passenger limits +- Provide clear error messages when restrictions apply +""", + target_tools=["book_flight"], + description="Membership-based restrictions for flight bookings to ensure fair resource allocation", + ) + print(f"Created policy with ID: {policy_id}") + + # Step 3: Generate examples using the SDK function + print("\nStep 3: Generating tool guard examples...") + violating_examples, compliance_examples = await agent.policies.generate_tool_guard_examples( + policy_id=policy_id, + target_tool="book_flight" + ) + + print(f"Generated {len(violating_examples)} violating examples and {len(compliance_examples)} compliance examples") + print("\nViolating examples:") + for i, example in enumerate(violating_examples, 1): + print(f" {i}. {example}") + print("\nCompliance examples:") + for i, example in enumerate(compliance_examples, 1): + print(f" {i}. {example}") + + # Step 4: Update the policy with generated examples + print("\nStep 4: Updating policy with generated examples...") + await agent.policies.update_tool_guard( + policy_id=policy_id, + tool_guards={ + "book_flight": { + "description": "Guard rules for flight booking based on membership level", + "violating_examples": violating_examples, + "compliance_examples": compliance_examples, + "policy_code": "" # Will be populated in next step + } + } + ) + print(f"Updated policy {policy_id} with generated examples") + + # Step 5: Generate guard code + print("\nStep 5: Generating guard code...") + guard_code = await agent.policies.generate_tool_guard_code( + policy_id=policy_id, + target_tool="book_flight", + app_name="flight_booking_app" + ) + + print(f"Generated guard code ({len(guard_code)} characters)") + print("\nGuard code preview (first 500 characters):") + print(guard_code[:500]) + print("...") + + # Step 6: Update the policy with generated code + print("\nStep 6: Updating policy with generated guard code...") + await agent.policies.update_tool_guard( + policy_id=policy_id, + tool_guards={ + "book_flight": { + "description": "Guard rules for flight booking based on membership level", + "violating_examples": violating_examples, + "compliance_examples": compliance_examples, + "policy_code": guard_code + } + } + ) + print(f"Updated policy {policy_id} with generated guard code") + + # Step 7: Save the policy + print("\nStep 7: Saving policy...") + policy_tool_guide = await agent.policies.get(policy_id) + if policy_tool_guide is None: + raise ValueError(f"Failed to retrieve policy {policy_id}") + + policy = policy_tool_guide["policy"] + + # Get policy system and update + policy_system = await agent.policies._ensure_policy_system() + if policy_system is None or policy_system.storage is None: + raise ValueError("Policy system storage is not available") + + await policy_system.initialize() + + # Save to file system if available + if agent.policies._fs_sync: + agent.policies._fs_sync.save_policy_to_file(policy) + + print("Policy saved successfully!") + + # Step 8: Retrieve the policy from storage + print("\nStep 8: Retrieving policy from storage...") + retrieved_policy = await policy_system.storage.get_policy(policy_id) + if retrieved_policy is None: + raise ValueError(f"Failed to retrieve updated policy {policy_id}") + + if not isinstance(retrieved_policy, ToolGuide): + raise TypeError(f"Expected ToolGuide, got {type(retrieved_policy).__name__}") + + print("Policy retrieved successfully!") + + # Step 9: Export to JSON file + print("\nStep 9: Exporting policy data to JSON file...") + + # Convert tool_guards to serializable format + tool_guards_data = None + if retrieved_policy.tool_guards: + tool_guards_data = {} + for tool_name, guard in retrieved_policy.tool_guards.items(): + tool_guards_data[tool_name] = { + "description": guard.description, + "violating_examples": guard.violating_examples, + "compliance_examples": guard.compliance_examples, + "policy_code": guard.policy_code, + } + + output_data = { + "export_timestamp": datetime.utcnow().isoformat(), + "policy_id": policy_id, + "policy_name": retrieved_policy.name, + "policy_description": retrieved_policy.description, + "guide_content": retrieved_policy.guide_content, + "target_tools": retrieved_policy.target_tools, + "target_apps": retrieved_policy.target_apps, + "tool_guards": tool_guards_data, + "prepend": retrieved_policy.prepend, + "priority": retrieved_policy.priority, + "enabled": retrieved_policy.enabled, + "policy_type": type(retrieved_policy).__name__, + "metadata": retrieved_policy.metadata, + } + + # Save to JSON file in current directory + output_path = Path("flight_booking_policy_export.json") + output_path.write_text(json.dumps(output_data, indent=2), encoding="utf-8") + + print(f"\n{'='*60}") + print("SUCCESS! Policy workflow with code generation completed:") + print(f"{'='*60}") + print(f"Policy ID: {policy_id}") + print(f"Policy Name: {retrieved_policy.name}") + print(f"Target Tools: {', '.join(retrieved_policy.target_tools)}") + print(f"Generated Examples: {len(violating_examples)} violating, {len(compliance_examples)} compliance") + print(f"Generated Code: {len(guard_code)} characters") + print(f"\nExported to: {output_path.absolute()}") + print(f"{'='*60}") + + # Print summary of tool_guards + if retrieved_policy.tool_guards: + print("\nTool Guards Summary:") + for tool_name, guard in retrieved_policy.tool_guards.items(): + print(f"\n Tool: {tool_name}") + print(f" Description: {guard.description}") + print(f" Violating Examples: {len(guard.violating_examples)}") + print(f" Compliance Examples: {len(guard.compliance_examples)}") + print(f" Policy Code Length: {len(guard.policy_code)} characters") + + +if __name__ == "__main__": + asyncio.run(main()) + +# Made with Bob \ No newline at end of file From f1c6b89efafae1866763c737b16c0f252236e267 Mon Sep 17 00:00:00 2001 From: naamaz Date: Sun, 26 Apr 2026 14:37:17 +0300 Subject: [PATCH 07/24] runtime 1 --- .../policy/tool_guard/tool_guard_runtime.py | 152 +++++- .../registry/registry/api_registry.py | 79 ++- .../registry/registry/api_registry_server.py | 42 +- src/cuga/sdk_core/debug_tool_guard_e2e.py | 487 ++++++++++++++++++ 4 files changed, 738 insertions(+), 22 deletions(-) create mode 100644 src/cuga/sdk_core/debug_tool_guard_e2e.py diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py index beb098bd..f57ccd53 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py @@ -8,6 +8,8 @@ from typing import Dict, Any, List, Optional from loguru import logger +from toolguard.runtime.data_types import PolicyViolationException + from cuga.backend.cuga_graph.policy.tool_guard.tool_invoker import ToolGuardInvoker from cuga.backend.cuga_graph.policy.models import ToolGuide, PolicyType from cuga.backend.cuga_graph.policy.storage import PolicyStorage @@ -162,42 +164,146 @@ async def guard_tool_call( f"Executing guard '{policy.name}' for tool '{function_name}'" ) - # Create execution context with necessary variables + def rule(name: str): + """Decorator for policy rules - just returns the function.""" + def decorator(func): + return func + return decorator + + def assert_any_condition_met(*conditions): + """Mock function for assert_any_condition_met.""" + pass + + # Create a simple args object from the arguments dict + class Args: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + args_obj = Args(**arguments) + + # Create a dynamic API wrapper that provides method-like access to tools + class DynamicAPIWrapper: + """Wrapper that allows calling tools as methods on the api object.""" + + def __init__(self, invoker): + self._invoker = invoker + + def __getattr__(self, tool_name: str): + """ + Return a callable that invokes the tool with the given name. + This allows code like: api.get_membership(args) + """ + async def tool_caller(args_obj): + # Convert args object to dict + if hasattr(args_obj, '__dict__'): + args_dict = args_obj.__dict__ + else: + args_dict = args_obj + + # Invoke the tool through the invoker + result = await self._invoker.invoke( + toolname=tool_name, + arguments=args_dict, + return_type=str # Default to str, could be improved + ) + return result + + return tool_caller + + api_wrapper = DynamicAPIWrapper(self.invoker) + + # Create execution context with necessary variables and mocks exec_globals = { 'app_name': app_name, 'function_name': function_name, 'arguments': arguments, 'invoker': self.invoker, + # Mock imports for generated code + 'PolicyViolationException': PolicyViolationException, + 'rule': rule, + 'assert_any_condition_met': assert_any_condition_met, + 'typing': __import__('typing'), + 'Any': Any, + 'Dict': Dict, + 'List': List, + 'Optional': Optional, } # Execute the policy code - # The policy_code should define a 'validate' function + # The policy_code may define either: + # 1. A 'validate' function (legacy format) + # 2. A decorated function with @rule (generated format) exec(tool_guard.policy_code, exec_globals) - # Check if validate function was defined - if 'validate' not in exec_globals: + # Find the validation function + validate_fn = None + + # First, check for 'validate' function (legacy format) + if 'validate' in exec_globals: + validate_fn = exec_globals['validate'] + logger.debug(f"Found 'validate' function in policy code") + else: + # Look for any function that starts with 'guard_' (generated format) + for name, obj in exec_globals.items(): + if name.startswith('guard_') and callable(obj): + validate_fn = obj + logger.debug(f"Found guard function '{name}' in policy code") + break + + if validate_fn is None: logger.warning( f"Policy code for '{policy.name}' does not define " - f"a 'validate' function, skipping" + f"a 'validate' or 'guard_*' function, skipping" ) continue - validate_fn = exec_globals['validate'] - # Call the validate function # Support both sync and async validate functions import inspect - if inspect.iscoroutinefunction(validate_fn): - error = await validate_fn( - app_name=app_name, - function_name=function_name, - arguments=arguments - ) - else: - error = validate_fn( - app_name=app_name, - function_name=function_name, - arguments=arguments + + # Determine the function signature to call it correctly + sig = inspect.signature(validate_fn) + params = list(sig.parameters.keys()) + + error: Optional[str] = None + + # Call the validate function and catch PolicyViolationException + try: + if inspect.iscoroutinefunction(validate_fn): + # For generated code format: func(api, args) + if len(params) >= 2 and params[0] in ['api', 'invoker']: + await validate_fn(api_wrapper, args_obj) + # For legacy format: func(app_name, function_name, arguments) + elif 'app_name' in params or 'function_name' in params: + result = await validate_fn( + app_name=app_name, + function_name=function_name, + arguments=arguments + ) + error = str(result) if result else None + else: + # Try calling with api_wrapper and args + await validate_fn(api_wrapper, args_obj) + else: + # Sync function + if len(params) >= 2 and params[0] in ['api', 'invoker']: + validate_fn(api_wrapper, args_obj) + elif 'app_name' in params or 'function_name' in params: + result = validate_fn( + app_name=app_name, + function_name=function_name, + arguments=arguments + ) + error = str(result) if result else None + else: + validate_fn(api_wrapper, args_obj) + + except PolicyViolationException as e: + # Policy violation - this is expected behavior + error = str(e) + logger.debug( + f"Guard '{policy.name}' caught policy violation: {error}" ) # If validation failed, return the error message @@ -212,6 +318,16 @@ async def guard_tool_call( f"Guard '{policy.name}' passed for tool '{function_name}'" ) + except PolicyViolationException as e: + # Policy violation at the outer level - should have been caught above + # but handle it here as well for safety + error = str(e) + logger.warning( + f"Tool guard '{policy.name}' blocked call to " + f"'{function_name}': {error}" + ) + return error + except Exception as e: logger.error( f"Error executing guard '{policy.name}' for tool " diff --git a/src/cuga/backend/tools_env/registry/registry/api_registry.py b/src/cuga/backend/tools_env/registry/registry/api_registry.py index 76bb874d..9018a69e 100644 --- a/src/cuga/backend/tools_env/registry/registry/api_registry.py +++ b/src/cuga/backend/tools_env/registry/registry/api_registry.py @@ -29,12 +29,17 @@ class ApiRegistry: interacting with the mcp manager """ - def __init__(self, client: MCPManager): + def __init__(self, client: MCPManager, policy_storage=None): logger.info("ApiRegistry: Initializing.") self.mcp_client = client self.auth_manager = None self.tavily_client = None self._init_tavily_if_enabled() + + # ToolGuardRuntime support for policy-based tool validation + self.policy_storage = policy_storage + self.tool_guard_runtime = None + self._tool_guard_initialized = False def _init_tavily_if_enabled(self): """Initialize Tavily client if web search is enabled.""" @@ -60,6 +65,49 @@ async def start_servers(self): """Start servers and load tools""" await self.mcp_client.load_tools() logger.info("ApiRegistry: Servers started successfully.") + + # Initialize ToolGuardRuntime after tools are loaded + await self._initialize_tool_guard_runtime() + + async def _initialize_tool_guard_runtime(self): + """ + Initialize ToolGuardRuntime after tools are loaded. + + This is called after load_tools() to ensure both tools and policies are available. + """ + if self._tool_guard_initialized: + return + + self._tool_guard_initialized = True + + if self.policy_storage is None: + logger.debug("ToolGuardRuntime: No policy storage provided, skipping initialization") + return + + try: + from cuga.backend.cuga_graph.policy.tool_guard.tool_guard_runtime import ToolGuardRuntime + + # Note: tool_provider is only needed if guards invoke tools via the invoker + # For most guards that just validate arguments, it's not used + # We pass self.mcp_client which can be wrapped if needed + self.tool_guard_runtime = ToolGuardRuntime( + tool_provider=self.mcp_client, + policy_storage=self.policy_storage + ) + + await self.tool_guard_runtime.initialize() + + guarded_tools = self.tool_guard_runtime.get_guarded_tools() + logger.info( + f"✅ ToolGuardRuntime initialized with guards for {len(guarded_tools)} tools" + ) + if guarded_tools: + logger.debug(f" Guarded tools: {', '.join(guarded_tools)}") + + except Exception as e: + logger.warning(f"Failed to initialize ToolGuardRuntime: {e}") + logger.debug(f"ToolGuardRuntime initialization error details:", exc_info=True) + self.tool_guard_runtime = None async def show_applications(self) -> List[AppDefinition]: """Lists application names and their descriptions.""" @@ -179,6 +227,35 @@ async def call_function( self, app_name: str, function_name: str, arguments: Dict[str, Any], auth_config=None ) -> Dict[str, Any]: """Calls a function via the mcp_client.""" + + # Validate tool call against ToolGuard policies + if self.tool_guard_runtime and self.tool_guard_runtime.is_initialized: + try: + error_message = await self.tool_guard_runtime.guard_tool_call( + app_name=app_name, + function_name=function_name, + arguments=arguments + ) + + if error_message: + # Guard validation failed - return error without executing tool + logger.warning( + f"🛡️ Tool guard blocked call to '{function_name}': {error_message}" + ) + return { + "status": "exception", + "status_code": 403, + "message": f"Tool guard policy violation: {error_message}", + "error_type": "ToolGuardViolation", + "function_name": function_name, + } + else: + logger.debug(f"✅ Tool guard validation passed for '{function_name}'") + except Exception as e: + # Log guard execution error but don't block the tool call + # This ensures a broken guard doesn't break the entire system + logger.error(f"Error executing tool guard for '{function_name}': {e}", exc_info=True) + if app_name == "web" and function_name == "search_web" and self._is_web_search_enabled(): args = arguments.get('params', arguments) if isinstance(arguments, dict) else arguments query = args.get('query') if isinstance(args, dict) else str(args) diff --git a/src/cuga/backend/tools_env/registry/registry/api_registry_server.py b/src/cuga/backend/tools_env/registry/registry/api_registry_server.py index e639b3a1..caa8be2a 100644 --- a/src/cuga/backend/tools_env/registry/registry/api_registry_server.py +++ b/src/cuga/backend/tools_env/registry/registry/api_registry_server.py @@ -140,7 +140,19 @@ async def _get_or_create_registry( logger.debug(f"Knowledge MCP transport override skipped: {e}") manager = MCPManager(config=services) - reg = ApiRegistry(client=manager) + + # Get policy storage if policy system is enabled + policy_storage = None + if settings.policy.enabled: + try: + from cuga.backend.cuga_graph.policy.storage import PolicyStorage + policy_storage = PolicyStorage() + await policy_storage.connect() + logger.info("✅ Connected policy storage for ToolGuardRuntime") + except Exception as e: + logger.warning(f"Failed to connect policy storage: {e}") + + reg = ApiRegistry(client=manager, policy_storage=policy_storage) await reg.start_servers() agent_registries[agent_id] = (manager, reg) @@ -166,7 +178,19 @@ async def lifespan(app: FastAPI): print(f"Using configuration file: {config_file}") services = load_service_configs(str(config_file)) mcp_manager = MCPManager(config=services) - registry = ApiRegistry(client=mcp_manager) + + # Get policy storage if policy system is enabled + policy_storage = None + if settings.policy.enabled: + try: + from cuga.backend.cuga_graph.policy.storage import PolicyStorage + policy_storage = PolicyStorage() + await policy_storage.connect() + logger.info("✅ Connected policy storage for ToolGuardRuntime") + except Exception as e: + logger.warning(f"Failed to connect policy storage: {e}") + + registry = ApiRegistry(client=mcp_manager, policy_storage=policy_storage) await registry.start_servers() yield @@ -516,7 +540,19 @@ async def reload_config( logger.info(f"Reloading from file: {config_path}") services = load_service_configs(config_path) new_manager = MCPManager(config=services) - new_registry = ApiRegistry(client=new_manager) + + # Get policy storage if policy system is enabled + policy_storage = None + if settings.policy.enabled: + try: + from cuga.backend.cuga_graph.policy.storage import PolicyStorage + policy_storage = PolicyStorage() + await policy_storage.connect() + logger.info("✅ Connected policy storage for ToolGuardRuntime") + except Exception as e: + logger.warning(f"Failed to connect policy storage: {e}") + + new_registry = ApiRegistry(client=new_manager, policy_storage=policy_storage) await new_registry.start_servers() if 'mcp_manager' in globals(): await mcp_manager.shutdown() diff --git a/src/cuga/sdk_core/debug_tool_guard_e2e.py b/src/cuga/sdk_core/debug_tool_guard_e2e.py new file mode 100644 index 00000000..43f5ed24 --- /dev/null +++ b/src/cuga/sdk_core/debug_tool_guard_e2e.py @@ -0,0 +1,487 @@ +""" +Debug script for creating a tool guard policy with code generation and E2E testing. + +This script demonstrates: +1. Creating a CugaAgent with flight booking tools +2. Adding a tool guide policy for membership-based restrictions +3. Generating examples for the policy +4. Generating guard code for the policy +5. Updating the policy with examples and code +6. Exporting the policy data to a JSON file +7. Testing the policy by invoking the tool with a violating call using ToolGuardRuntime +8. Cleaning up the test policy at the end + +Configuration: +- Set DELETE_ALL_POLICIES_AT_START = True to delete all existing policies before running +- Set DELETE_ALL_POLICIES_AT_START = False to preserve existing policies (default) +""" + +import asyncio +import json +import shutil +from pathlib import Path +from datetime import datetime + +from langchain_core.tools import tool + +from cuga import CugaAgent +from cuga.backend.cuga_graph.policy.models import ToolGuide +from cuga.backend.cuga_graph.policy.tool_guard import ToolGuardRuntime + +# ============================================================================ +# CONFIGURATION +# ============================================================================ +# Set to True to delete ALL existing policies at startup (clean slate) +# Set to False to preserve existing policies (only delete test policy at end) +DELETE_ALL_POLICIES_AT_START = True +# ============================================================================ + + +@tool +def book_flight(user_id: str, flight_id: str, passengers: int) -> str: + """Book a flight for a user with specified number of passengers""" + return f"Flight {flight_id} booked for user {user_id} with {passengers} passengers" + + +@tool +def get_membership(user_id: str) -> str: + """Get the membership level of a user (gold, silver, or regular)""" + # Simulate membership lookup + memberships = { + "user123": "gold", + "user456": "silver", + "user789": "regular" + } + return memberships.get(user_id, "regular") + + +async def main(): + """Main workflow for creating and managing a tool guard policy with code generation and E2E testing.""" + + # Step 0 (Optional): Clean up ALL existing policies if configured + if DELETE_ALL_POLICIES_AT_START: + print("="*60) + print("Step 0: Cleaning up ALL existing policies (DELETE_ALL_POLICIES_AT_START=True)") + print("="*60) + + agent = CugaAgent(tools=[book_flight, get_membership]) + + # Get policy system + policy_system = await agent.policies._ensure_policy_system() + if policy_system is None or policy_system.storage is None: + raise ValueError("Policy system storage is not available") + + await policy_system.initialize() + + # List all policies + all_policies = await policy_system.storage.list_policies() + print(f"Found {len(all_policies)} total policies in storage") + + # Delete ALL policies from storage + deleted_count = 0 + for policy in all_policies: + await policy_system.storage.delete_policy(policy.id) + print(f" Deleted from storage: '{policy.name}' (ID: {policy.id})") + deleted_count += 1 + + # Also delete policy files from filesystem if filesystem sync is enabled + if agent.policies._fs_sync: + print("\nCleaning up policy files from filesystem...") + + # Get the .cuga folder path + cuga_folder = Path(agent.policies._fs_sync.cuga_folder) + if cuga_folder.exists(): + print(f"Found .cuga folder at: {cuga_folder}") + + # Delete all policy subfolders + policy_subfolders = [ + 'playbooks', + 'output_formatters', + 'tool_guides', + 'intent_guards', + 'tool_approvals', + 'policies' + ] + + total_deleted = 0 + for subfolder in policy_subfolders: + subfolder_path = cuga_folder / subfolder + if subfolder_path.exists(): + # Count files before deletion + files = list(subfolder_path.glob("*.md")) + list(subfolder_path.glob("*.json")) + if files: + print(f" Deleting {len(files)} files from {subfolder}/") + for file in files: + file.unlink() + total_deleted += 1 + + if total_deleted > 0: + print(f"✅ Deleted {total_deleted} policy files from filesystem") + else: + print("No policy files found in filesystem") + else: + print(f".cuga folder does not exist: {cuga_folder}") + + if deleted_count > 0: + print(f"\n✅ Cleaned up {deleted_count} policy/policies from storage") + else: + print("\nNo existing policies to clean up from storage") + + # Verify cleanup + remaining_policies = await policy_system.storage.list_policies() + if remaining_policies: + print(f"⚠️ Warning: {len(remaining_policies)} policies still remain after cleanup") + for p in remaining_policies: + print(f" - {p.name} (ID: {p.id})") + else: + print("✅ All policies successfully deleted from storage") + + print("="*60) + else: + print("="*60) + print("Skipping initial cleanup (DELETE_ALL_POLICIES_AT_START=False)") + print("Existing policies will be preserved") + print("="*60) + + # Step 1: Create a CugaAgent with the flight booking tools + print("\nStep 1: Creating CugaAgent with flight booking tools...") + agent = CugaAgent(tools=[book_flight, get_membership]) + + # Get policy system for later use + policy_system = await agent.policies._ensure_policy_system() + if policy_system is None or policy_system.storage is None: + raise ValueError("Policy system storage is not available") + + await policy_system.initialize() + + # Step 2: Add a new tool guide policy + print("\nStep 2: Adding new tool guide policy...") + policy_id = await agent.policies.add_tool_guide( + name="Flight Booking Membership Policy", + content="""## Flight Booking Restrictions by Membership Level + +### Policy Rules +- Customers with "regular" membership cannot book a flight with more than 3 passengers +- Gold and silver members have no passenger restrictions +- This policy ensures fair resource allocation and encourages membership upgrades + +### Validation Requirements +- Always check user membership level before booking +- Reject bookings that violate passenger limits +- Provide clear error messages when restrictions apply +""", + target_tools=["book_flight"], + description="Membership-based restrictions for flight bookings to ensure fair resource allocation", + ) + print(f"Created policy with ID: {policy_id}") + + # Step 3: Generate examples using the SDK function + print("\nStep 3: Generating tool guard examples...") + violating_examples, compliance_examples = await agent.policies.generate_tool_guard_examples( + policy_id=policy_id, + target_tool="book_flight" + ) + + print(f"Generated {len(violating_examples)} violating examples and {len(compliance_examples)} compliance examples") + print("\nViolating examples:") + for i, example in enumerate(violating_examples, 1): + print(f" {i}. {example}") + print("\nCompliance examples:") + for i, example in enumerate(compliance_examples, 1): + print(f" {i}. {example}") + + # Step 4: Update the policy with generated examples + print("\nStep 4: Updating policy with generated examples...") + await agent.policies.update_tool_guard( + policy_id=policy_id, + tool_guards={ + "book_flight": { + "description": "Guard rules for flight booking based on membership level", + "violating_examples": violating_examples, + "compliance_examples": compliance_examples, + "policy_code": "" # Will be populated in next step + } + } + ) + print(f"Updated policy {policy_id} with generated examples") + + # Step 5: Generate guard code + print("\nStep 5: Generating guard code...") + guard_code = await agent.policies.generate_tool_guard_code( + policy_id=policy_id, + target_tool="book_flight", + app_name="test_app" # Use simple app name for code generation + ) + + print(f"Generated guard code ({len(guard_code)} characters)") + print("\nGuard code preview (first 500 characters):") + print(guard_code[:500]) + print("...") + + # Step 6: Update the policy with generated code + print("\nStep 6: Updating policy with generated guard code...") + await agent.policies.update_tool_guard( + policy_id=policy_id, + tool_guards={ + "book_flight": { + "description": "Guard rules for flight booking based on membership level", + "violating_examples": violating_examples, + "compliance_examples": compliance_examples, + "policy_code": guard_code + } + } + ) + print(f"Updated policy {policy_id} with generated guard code") + + # Step 7: Save the policy + print("\nStep 7: Saving policy...") + policy_tool_guide = await agent.policies.get(policy_id) + if policy_tool_guide is None: + raise ValueError(f"Failed to retrieve policy {policy_id}") + + policy = policy_tool_guide["policy"] + + # Get policy system and update + policy_system = await agent.policies._ensure_policy_system() + if policy_system is None or policy_system.storage is None: + raise ValueError("Policy system storage is not available") + + await policy_system.initialize() + + # Save to file system if available + if agent.policies._fs_sync: + agent.policies._fs_sync.save_policy_to_file(policy) + + print("Policy saved successfully!") + + # Step 8: Retrieve the policy from storage + print("\nStep 8: Retrieving policy from storage...") + retrieved_policy = await policy_system.storage.get_policy(policy_id) + if retrieved_policy is None: + raise ValueError(f"Failed to retrieve updated policy {policy_id}") + + if not isinstance(retrieved_policy, ToolGuide): + raise TypeError(f"Expected ToolGuide, got {type(retrieved_policy).__name__}") + + print("Policy retrieved successfully!") + + # Step 9: Export to JSON file + print("\nStep 9: Exporting policy data to JSON file...") + + # Convert tool_guards to serializable format + tool_guards_data = None + if retrieved_policy.tool_guards: + tool_guards_data = {} + for tool_name, guard in retrieved_policy.tool_guards.items(): + tool_guards_data[tool_name] = { + "description": guard.description, + "violating_examples": guard.violating_examples, + "compliance_examples": guard.compliance_examples, + "policy_code": guard.policy_code, + } + + output_data = { + "export_timestamp": datetime.utcnow().isoformat(), + "policy_id": policy_id, + "policy_name": retrieved_policy.name, + "policy_description": retrieved_policy.description, + "guide_content": retrieved_policy.guide_content, + "target_tools": retrieved_policy.target_tools, + "target_apps": retrieved_policy.target_apps, + "tool_guards": tool_guards_data, + "prepend": retrieved_policy.prepend, + "priority": retrieved_policy.priority, + "enabled": retrieved_policy.enabled, + "policy_type": type(retrieved_policy).__name__, + "metadata": retrieved_policy.metadata, + } + + # Save to JSON file in current directory + output_path = Path("flight_booking_policy_export_e2e.json") + output_path.write_text(json.dumps(output_data, indent=2), encoding="utf-8") + + print(f"\n{'='*60}") + print("SUCCESS! Policy workflow with code generation completed:") + print(f"{'='*60}") + print(f"Policy ID: {policy_id}") + print(f"Policy Name: {retrieved_policy.name}") + print(f"Target Tools: {', '.join(retrieved_policy.target_tools)}") + print(f"Generated Examples: {len(violating_examples)} violating, {len(compliance_examples)} compliance") + print(f"Generated Code: {len(guard_code)} characters") + print(f"\nExported to: {output_path.absolute()}") + print(f"{'='*60}") + + # Print summary of tool_guards + if retrieved_policy.tool_guards: + print("\nTool Guards Summary:") + for tool_name, guard in retrieved_policy.tool_guards.items(): + print(f"\n Tool: {tool_name}") + print(f" Description: {guard.description}") + print(f" Violating Examples: {len(guard.violating_examples)}") + print(f" Compliance Examples: {len(guard.compliance_examples)}") + print(f" Policy Code Length: {len(guard.policy_code)} characters") + + # Step 10: Test the policy by invoking the tool with ToolGuardRuntime + print(f"\n{'='*60}") + print("Step 10: Testing policy enforcement with ToolGuardRuntime") + print(f"{'='*60}") + + # Initialize ToolGuardRuntime + print("\nInitializing ToolGuardRuntime...") + tool_guard_runtime = ToolGuardRuntime( + tool_provider=agent.tool_provider, + policy_storage=policy_system.storage + ) + await tool_guard_runtime.initialize() + + print(f"Runtime initialized with guards for: {tool_guard_runtime.get_guarded_tools()}") + + # Test case 1: Violating call (should be blocked) + print("\n--- Test Case 1: Violating Call ---") + print("Attempting: book_flight(flight_id='fl123', user_id='user789', passengers=8)") + print("Expected: BLOCKED (user789 is 'regular' member, 8 > 3 passengers)") + + try: + error = await tool_guard_runtime.guard_tool_call( + app_name="test_app", # Match the app_name used in code generation + function_name="book_flight", + arguments={ + "flight_id": "fl123", + "user_id": "user789", + "passengers": 8 + } + ) + + if error: + print(f"\n✅ SUCCESS: Tool call was correctly BLOCKED!") + print(f"Error message: {error}") + else: + print("\n⚠️ WARNING: Tool call was ALLOWED (expected to be blocked)") + + except Exception as e: + print(f"\n❌ Error during validation: {type(e).__name__}: {e}") + + # Test case 2: Compliant call (should be allowed) + print("\n--- Test Case 2: Compliant Call ---") + print("Attempting: book_flight(flight_id='fl456', user_id='user789', passengers=2)") + print("Expected: ALLOWED (user789 is 'regular' member, 2 <= 3 passengers)") + + try: + error = await tool_guard_runtime.guard_tool_call( + app_name="test_app", # Match the app_name used in code generation + function_name="book_flight", + arguments={ + "flight_id": "fl456", + "user_id": "user789", + "passengers": 2 + } + ) + + if error: + print(f"\n⚠️ WARNING: Tool call was BLOCKED (expected to be allowed)") + print(f"Error message: {error}") + else: + print("\n✅ SUCCESS: Tool call was correctly ALLOWED!") + + # Actually invoke the tool to show it works + print("\nExecuting the actual tool call...") + result = await book_flight.ainvoke({ + "flight_id": "fl456", + "user_id": "user789", + "passengers": 2 + }) + print(f"Tool result: {result}") + + except Exception as e: + print(f"\n❌ Error during validation: {type(e).__name__}: {e}") + + # Test case 3: Gold member with many passengers (should be allowed) + print("\n--- Test Case 3: Gold Member Call ---") + print("Attempting: book_flight(flight_id='fl789', user_id='user123', passengers=10)") + print("Expected: ALLOWED (user123 is 'gold' member, no passenger limit)") + + try: + error = await tool_guard_runtime.guard_tool_call( + app_name="test_app", # Match the app_name used in code generation + function_name="book_flight", + arguments={ + "flight_id": "fl789", + "user_id": "user123", + "passengers": 10 + } + ) + + if error: + print(f"\n⚠️ WARNING: Tool call was BLOCKED (expected to be allowed)") + print(f"Error message: {error}") + else: + print("\n✅ SUCCESS: Tool call was correctly ALLOWED!") + + # Actually invoke the tool to show it works + print("\nExecuting the actual tool call...") + result = await book_flight.ainvoke({ + "flight_id": "fl789", + "user_id": "user123", + "passengers": 10 + }) + print(f"Tool result: {result}") + + except Exception as e: + print(f"\n❌ Error during validation: {type(e).__name__}: {e}") + + print(f"\n{'='*60}") + print("E2E Test Complete!") + print(f"{'='*60}") + print("\nSummary:") + print(" - Policy created and saved successfully") + print(" - Examples and guard code generated") + print(" - ToolGuardRuntime initialized and tested") + print(" - All test cases executed") + + # Step 11: Clean up - Delete the policy we created + print(f"\n{'='*60}") + print("Step 11: Cleaning up - Deleting test policy") + print(f"{'='*60}") + + try: + # Delete from storage + await policy_system.storage.delete_policy(policy_id) + print(f"✅ Deleted policy from storage: {policy_id}") + + # Delete from filesystem if filesystem sync is enabled + if agent.policies._fs_sync: + # Get the .cuga folder path + cuga_folder = Path(agent.policies._fs_sync.cuga_folder) + tool_guides_folder = cuga_folder / "tool_guides" + + if tool_guides_folder.exists(): + # Find and delete the policy file + policy_files = list(tool_guides_folder.glob(f"*{policy_id}*.md")) + \ + list(tool_guides_folder.glob(f"*{policy_id}*.json")) + + for policy_file in policy_files: + policy_file.unlink() + print(f"✅ Deleted policy file: {policy_file.name}") + + if not policy_files: + print("⚠️ No policy files found to delete from filesystem") + + # Verify deletion + deleted_policy = await policy_system.storage.get_policy(policy_id) + if deleted_policy is None: + print("✅ Policy successfully deleted and verified") + else: + print("⚠️ Warning: Policy still exists in storage after deletion") + + except Exception as e: + print(f"⚠️ Error during cleanup: {type(e).__name__}: {e}") + + print(f"\n{'='*60}") + print("✅ Test completed and cleaned up successfully!") + print(f"{'='*60}") + + +if __name__ == "__main__": + asyncio.run(main()) + +# Made with Bob \ No newline at end of file From 707a1a178481e909348f7a74bcf7accee65673ec Mon Sep 17 00:00:00 2001 From: naamaz Date: Mon, 27 Apr 2026 10:24:28 +0300 Subject: [PATCH 08/24] runtime --- .../policy/tool_guard/tool_guard_runtime.py | 527 ++++++++++-------- 1 file changed, 298 insertions(+), 229 deletions(-) diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py index f57ccd53..82425001 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py @@ -5,10 +5,20 @@ ToolGuide policies with policy_code. """ +from pathlib import Path from typing import Dict, Any, List, Optional from loguru import logger -from toolguard.runtime.data_types import PolicyViolationException +from toolguard.runtime.data_types import ( + FileTwin, + PolicyViolationException, + RuntimeDomain, + ToolGuardCodeResult, + ToolGuardSpec, + ToolGuardSpecItem, + ToolGuardsCodeGenerationResult, +) +from toolguard.runtime.runtime import load_toolguards_from_memory from cuga.backend.cuga_graph.policy.tool_guard.tool_invoker import ToolGuardInvoker from cuga.backend.cuga_graph.policy.models import ToolGuide, PolicyType @@ -18,53 +28,64 @@ class ToolGuardRuntime: """ Runtime system for executing tool guards during tool invocation. - + This class: 1. Initializes a ToolGuardInvoker for tool execution 2. Loads all ToolGuide policies with policy_code 3. Creates a mapping: tool_name -> List[ToolGuide with code] - 4. Executes guard validation when tools are called + 4. Prebuilds umbrella guard modules per tool + 5. Executes guard validation through toolguard runtime """ - + def __init__(self, tool_provider, policy_storage: PolicyStorage): """ Initialize the ToolGuardRuntime. - + Args: tool_provider: CUGA's tool provider instance policy_storage: PolicyStorage instance to load policies from """ self.tool_provider = tool_provider self.policy_storage = policy_storage - + # Create ToolGuardInvoker instance self.invoker = ToolGuardInvoker(tool_provider) - + # Mapping: tool_name -> List[ToolGuide policies with code] self.tool_to_guards: Dict[str, List[ToolGuide]] = {} - + + # ToolGuard in-memory runtime built from umbrella guard modules + self._runtime = None + self._runtime_domain: Optional[RuntimeDomain] = None + self._initialized = False logger.debug("Created ToolGuardRuntime instance") - + async def initialize(self) -> None: """ Initialize the runtime by loading all ToolGuide policies with code. - + This method: 1. Fetches all ToolGuide policies from storage 2. Filters for policies that have tool_guards with policy_code 3. Builds the tool_to_guards mapping + 4. Creates umbrella guard functions per tool + 5. Builds an in-memory ToolGuard runtime """ logger.info("Initializing ToolGuardRuntime...") - + + self.tool_to_guards = {} + self._runtime = None + self._runtime_domain = None + # Load all ToolGuide policies policies = await self.policy_storage.list_policies( policy_type=PolicyType.TOOL_GUIDE, enabled_only=True ) - + logger.debug(f"Found {len(policies)} ToolGuide policies") - + # Build mapping: tool_name -> List[ToolGuide with code] for policy in policies: # Type guard: ensure we're working with ToolGuide @@ -73,18 +94,18 @@ async def initialize(self) -> None: f"Expected ToolGuide but got {type(policy).__name__}, skipping" ) continue - + if not policy.tool_guards: logger.debug(f"Policy '{policy.name}' has no tool_guards, skipping") continue - + # Iterate through each tool's guard configuration for tool_name, tool_guard in policy.tool_guards.items(): # Only include if policy_code exists if tool_guard.policy_code: if tool_name not in self.tool_to_guards: self.tool_to_guards[tool_name] = [] - + self.tool_to_guards[tool_name].append(policy) logger.debug( f"Registered guard for tool '{tool_name}' " @@ -95,20 +116,236 @@ async def initialize(self) -> None: f"Tool guard for '{tool_name}' in policy '{policy.name}' " f"has no policy_code, skipping" ) - + + if self.tool_to_guards: + self._runtime_domain = self._load_runtime_domain() + self._runtime = self._build_runtime() + self._initialized = True logger.info( f"✅ ToolGuardRuntime initialized with guards for " f"{len(self.tool_to_guards)} tools" ) - + # Log summary of registered guards for tool_name, guards in self.tool_to_guards.items(): logger.debug( f" - Tool '{tool_name}': {len(guards)} guard(s) " f"({', '.join(g.name for g in guards)})" ) - + + def _build_runtime(self): + """Build an in-memory ToolGuard runtime from registered guard policies.""" + if self._runtime_domain is None: + raise RuntimeError("ToolGuard runtime domain is not loaded") + + file_twins: List[FileTwin] = [ + self._runtime_domain.app_types, + self._runtime_domain.app_api, + self._runtime_domain.app_api_impl, + ] + tools: Dict[str, ToolGuardCodeResult] = {} + + for tool_name, guards in self.tool_to_guards.items(): + module_name = self._module_name_for_tool(tool_name) + guard_fn_name = self._guard_function_name_for_tool(tool_name) + guard_module_path = Path(*module_name.split(".")).with_suffix(".py") + + module_content = self._build_tool_guard_module( + tool_name=tool_name, + guards=guards, + guard_fn_name=guard_fn_name, + ) + + guard_file = FileTwin( + file_name=guard_module_path, + content=module_content, + ) + file_twins.append(guard_file) + + tools[tool_name] = ToolGuardCodeResult( + tool=ToolGuardSpec( + tool_name=tool_name, + policy_items=[ + ToolGuardSpecItem( + name=policy.name, + description=f"Runtime guard from policy '{policy.name}'", + ) + for policy in guards + ], + ), + guard_fn_name=guard_fn_name, + guard_file=guard_file, + item_guard_files=[], + test_files=[], + ) + + result = ToolGuardsCodeGenerationResult( + out_dir=Path("."), + domain=self._runtime_domain, + tools=tools, + ) + + runtime = load_toolguards_from_memory(result) + runtime.__enter__() + return runtime + + def _load_runtime_domain(self) -> RuntimeDomain: + """Load RuntimeDomain files saved by ToolGuardManager.""" + domain_dir = Path.cwd() / ".cuga" / "toolguard" / "domain" + + if not domain_dir.exists(): + raise RuntimeError( + f"ToolGuard domain directory not found: {domain_dir}. " + "Generate tool guard code first so ToolGuardManager saves the domain files." + ) + + app_dirs = sorted( + [path for path in domain_dir.iterdir() if path.is_dir()], + key=lambda path: path.stat().st_mtime, + reverse=True, + ) + if not app_dirs: + raise RuntimeError( + f"No ToolGuard app directories found under {domain_dir}" + ) + + selected_domain = None + for app_dir in app_dirs: + app_name = app_dir.name + app_types_rel = Path(app_name) / f"{app_name}_types.py" + app_api_rel = Path(app_name) / f"i_{app_name}.py" + app_api_impl_rel = Path(app_name) / f"{app_name}_impl.py" + + candidate_paths = [ + domain_dir / app_types_rel, + domain_dir / app_api_rel, + domain_dir / app_api_impl_rel, + ] + if all(path.exists() for path in candidate_paths): + selected_domain = (app_name, app_types_rel, app_api_rel, app_api_impl_rel) + break + + if selected_domain is None: + raise RuntimeError( + f"No complete ToolGuard domain found under {domain_dir}" + ) + + app_name, app_types_rel, app_api_rel, app_api_impl_rel = selected_domain + + api_content = FileTwin.load_from(domain_dir, app_api_rel).content + api_impl_content = FileTwin.load_from(domain_dir, app_api_impl_rel).content + + app_api_class_name = f"I{''.join(part.capitalize() for part in app_name.split('_'))}" + app_api_impl_class_name = ''.join(part.capitalize() for part in app_name.split('_')) + + for line in api_content.splitlines(): + stripped = line.strip() + if stripped.startswith("class "): + app_api_class_name = stripped.split()[1].split("(")[0].rstrip(":") + break + + for line in api_impl_content.splitlines(): + stripped = line.strip() + if stripped.startswith("class "): + app_api_impl_class_name = stripped.split()[1].split("(")[0].rstrip(":") + break + + return RuntimeDomain( + app_name=app_name, + app_types=FileTwin.load_from(domain_dir, app_types_rel), + app_api_class_name=app_api_class_name, + app_api=FileTwin.load_from(domain_dir, app_api_rel), + app_api_size=0, + app_api_impl_class_name=app_api_impl_class_name, + app_api_impl=FileTwin.load_from(domain_dir, app_api_impl_rel), + ) + + def _build_tool_guard_module( + self, + tool_name: str, + guards: List[ToolGuide], + guard_fn_name: str, + ) -> str: + """Create a module containing a single umbrella guard function for one tool.""" + guard_blocks: List[str] = [] + guard_calls: List[str] = [] + + for index, policy in enumerate(guards): + tool_guard = policy.tool_guards.get(tool_name) if policy.tool_guards else None + if not tool_guard or not tool_guard.policy_code: + logger.warning( + f"Policy '{policy.name}' missing tool_guard for '{tool_name}', skipping" + ) + continue + + compiled_name = f"_compiled_guard_{index}" + validate_alias = f"_guard_validate_{index}" + + guard_blocks.append( + f"# Policy: {policy.name}\n" + f"{tool_guard.policy_code}\n" + f"{compiled_name} = locals().get('validate')\n" + f"if {compiled_name} is None:\n" + f" for _name, _obj in list(locals().items()):\n" + f" if _name.startswith('guard_') and callable(_obj):\n" + f" {compiled_name} = _obj\n" + f" break\n" + f"if {compiled_name} is None:\n" + f" raise ValueError(\n" + f" \"Policy code for '{policy.name}' does not define a 'validate' or 'guard_*' function\"\n" + f" )\n" + f"{validate_alias} = {compiled_name}\n" + ) + + guard_calls.extend( + [ + " try:", + f" await {validate_alias}(api=api, args=args)", + " except PolicyViolationException as e:", + f" violations.append(\"[{policy.name}] \" + str(e))", + ] + ) + + if not guard_calls: + guard_calls.append(" return None") + else: + guard_calls = [ + " violations = []", + *guard_calls, + " if violations:", + " raise PolicyViolationException(\"\\n\".join(violations))", + ] + + return ( + "from toolguard.runtime.data_types import (\n" + " PolicyViolationException,\n" + " assert_any_condition_met,\n" + ")\n" + "from toolguard.runtime.rules import rule\n\n" + f"{''.join(guard_blocks)}\n" + f"async def {guard_fn_name}(api, args):\n" + f"{chr(10).join(guard_calls)}\n" + ) + + def _module_name_for_tool(self, tool_name: str) -> str: + """Convert a tool name to a valid python module name.""" + normalized = "".join( + ch if ch.isalnum() else "_" for ch in tool_name.lower() + ).strip("_") + if not normalized: + normalized = "tool" + return f"cuga_toolguard_runtime.generated.guard_{normalized}" + + def _guard_function_name_for_tool(self, tool_name: str) -> str: + """Convert a tool name to a valid umbrella guard function name.""" + normalized = "".join( + ch if ch.isalnum() else "_" for ch in tool_name.lower() + ).strip("_") + if not normalized: + normalized = "tool" + return f"guard_{normalized}" + async def guard_tool_call( self, app_name: str, @@ -117,252 +354,84 @@ async def guard_tool_call( ) -> Optional[str]: """ Validate a tool call against registered guards. - - This method executes all registered guards for the specified tool. - If any guard returns an error message, validation fails and that - error message is returned. If all guards pass (return None), - validation succeeds and None is returned. - + + This method delegates validation to the ToolGuard runtime using a + prebuilt umbrella guard function for the requested tool. + Args: app_name: Name of the application calling the tool function_name: Name of the tool/function being called arguments: Arguments being passed to the tool - + Returns: Error message string if validation fails, None if validation passes """ if not self._initialized: logger.warning("ToolGuardRuntime not initialized, skipping validation") return None - + # Check if this tool has any guards if function_name not in self.tool_to_guards: logger.debug(f"No guards registered for tool '{function_name}'") return None - - # Get all guards for this tool + + if self._runtime is None: + logger.warning( + f"ToolGuard runtime unavailable for guarded tool '{function_name}', " + "skipping validation" + ) + return None + guards = self.tool_to_guards[function_name] - logger.debug( f"Validating tool call '{function_name}' against " - f"{len(guards)} guard(s)" + f"{len(guards)} guard(s) using umbrella runtime" ) - - # Execute each guard's policy_code - for policy in guards: - # Ensure policy has tool_guards and the function_name exists - if not policy.tool_guards or function_name not in policy.tool_guards: - logger.warning( - f"Policy '{policy.name}' missing tool_guard for '{function_name}', skipping" - ) - continue - - tool_guard = policy.tool_guards[function_name] - - try: - logger.debug( - f"Executing guard '{policy.name}' for tool '{function_name}'" - ) - - def rule(name: str): - """Decorator for policy rules - just returns the function.""" - def decorator(func): - return func - return decorator - - def assert_any_condition_met(*conditions): - """Mock function for assert_any_condition_met.""" - pass - - # Create a simple args object from the arguments dict - class Args: - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - args_obj = Args(**arguments) - - # Create a dynamic API wrapper that provides method-like access to tools - class DynamicAPIWrapper: - """Wrapper that allows calling tools as methods on the api object.""" - - def __init__(self, invoker): - self._invoker = invoker - - def __getattr__(self, tool_name: str): - """ - Return a callable that invokes the tool with the given name. - This allows code like: api.get_membership(args) - """ - async def tool_caller(args_obj): - # Convert args object to dict - if hasattr(args_obj, '__dict__'): - args_dict = args_obj.__dict__ - else: - args_dict = args_obj - - # Invoke the tool through the invoker - result = await self._invoker.invoke( - toolname=tool_name, - arguments=args_dict, - return_type=str # Default to str, could be improved - ) - return result - - return tool_caller - - api_wrapper = DynamicAPIWrapper(self.invoker) - - # Create execution context with necessary variables and mocks - exec_globals = { - 'app_name': app_name, - 'function_name': function_name, - 'arguments': arguments, - 'invoker': self.invoker, - # Mock imports for generated code - 'PolicyViolationException': PolicyViolationException, - 'rule': rule, - 'assert_any_condition_met': assert_any_condition_met, - 'typing': __import__('typing'), - 'Any': Any, - 'Dict': Dict, - 'List': List, - 'Optional': Optional, - } - - # Execute the policy code - # The policy_code may define either: - # 1. A 'validate' function (legacy format) - # 2. A decorated function with @rule (generated format) - exec(tool_guard.policy_code, exec_globals) - - # Find the validation function - validate_fn = None - - # First, check for 'validate' function (legacy format) - if 'validate' in exec_globals: - validate_fn = exec_globals['validate'] - logger.debug(f"Found 'validate' function in policy code") - else: - # Look for any function that starts with 'guard_' (generated format) - for name, obj in exec_globals.items(): - if name.startswith('guard_') and callable(obj): - validate_fn = obj - logger.debug(f"Found guard function '{name}' in policy code") - break - - if validate_fn is None: - logger.warning( - f"Policy code for '{policy.name}' does not define " - f"a 'validate' or 'guard_*' function, skipping" - ) - continue - - # Call the validate function - # Support both sync and async validate functions - import inspect - - # Determine the function signature to call it correctly - sig = inspect.signature(validate_fn) - params = list(sig.parameters.keys()) - - error: Optional[str] = None - - # Call the validate function and catch PolicyViolationException - try: - if inspect.iscoroutinefunction(validate_fn): - # For generated code format: func(api, args) - if len(params) >= 2 and params[0] in ['api', 'invoker']: - await validate_fn(api_wrapper, args_obj) - # For legacy format: func(app_name, function_name, arguments) - elif 'app_name' in params or 'function_name' in params: - result = await validate_fn( - app_name=app_name, - function_name=function_name, - arguments=arguments - ) - error = str(result) if result else None - else: - # Try calling with api_wrapper and args - await validate_fn(api_wrapper, args_obj) - else: - # Sync function - if len(params) >= 2 and params[0] in ['api', 'invoker']: - validate_fn(api_wrapper, args_obj) - elif 'app_name' in params or 'function_name' in params: - result = validate_fn( - app_name=app_name, - function_name=function_name, - arguments=arguments - ) - error = str(result) if result else None - else: - validate_fn(api_wrapper, args_obj) - - except PolicyViolationException as e: - # Policy violation - this is expected behavior - error = str(e) - logger.debug( - f"Guard '{policy.name}' caught policy violation: {error}" - ) - - # If validation failed, return the error message - if error: - logger.warning( - f"Tool guard '{policy.name}' blocked call to " - f"'{function_name}': {error}" - ) - return error - - logger.debug( - f"Guard '{policy.name}' passed for tool '{function_name}'" - ) - - except PolicyViolationException as e: - # Policy violation at the outer level - should have been caught above - # but handle it here as well for safety - error = str(e) - logger.warning( - f"Tool guard '{policy.name}' blocked call to " - f"'{function_name}': {error}" - ) - return error - - except Exception as e: - logger.error( - f"Error executing guard '{policy.name}' for tool " - f"'{function_name}': {e}", - exc_info=True - ) - # Continue to next guard instead of failing completely - # This ensures one broken guard doesn't break all validation - continue - - # All guards passed + + try: + args_obj = type("ToolGuardArgs", (), arguments)() + await self._runtime.guard_toolcall( + tool_name=function_name, + args=arguments | {"args": args_obj}, + delegate=self.invoker, + ) + except PolicyViolationException as e: + error = str(e) + logger.warning( + f"Tool guard blocked call to '{function_name}': {error}" + ) + return error + except Exception as e: + logger.error( + f"Error executing umbrella guard for tool '{function_name}': {e}", + exc_info=True + ) + return None + logger.debug(f"Tool call '{function_name}' passed all guards") return None - + @property def is_initialized(self) -> bool: """Check if the runtime has been initialized.""" return self._initialized - + def get_guarded_tools(self) -> List[str]: """ Get list of tool names that have guards registered. - + Returns: List of tool names with active guards """ return list(self.tool_to_guards.keys()) - + def get_guards_for_tool(self, tool_name: str) -> List[ToolGuide]: """ Get all guards registered for a specific tool. - + Args: tool_name: Name of the tool - + Returns: List of ToolGuide policies with guards for this tool """ From 38c5209f89a2d2c257b2bd9e917b10ab58eccb8b Mon Sep 17 00:00:00 2001 From: naamaz Date: Mon, 27 Apr 2026 12:00:10 +0300 Subject: [PATCH 09/24] test runtime --- .../policy/tool_guard/tool_guard_runtime.py | 35 +- src/cuga/sdk_core/debug_tool_guard.py | 220 ------ src/cuga/sdk_core/debug_tool_guard_e2e.py | 695 ++++++++---------- 3 files changed, 319 insertions(+), 631 deletions(-) delete mode 100644 src/cuga/sdk_core/debug_tool_guard.py diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py index 82425001..9892ce03 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py @@ -282,20 +282,27 @@ def _build_tool_guard_module( compiled_name = f"_compiled_guard_{index}" validate_alias = f"_guard_validate_{index}" + # Extract the guard function name from the policy code + # The generated code has @rule decorator followed by async def guard_xxx + guard_func_name = None + for line in tool_guard.policy_code.split('\n'): + line = line.strip() + if line.startswith('async def guard_'): + # Extract function name: "async def guard_xxx(..." -> "guard_xxx" + guard_func_name = line.split('(')[0].replace('async def ', '').strip() + break + + if not guard_func_name: + logger.warning( + f"Could not find guard function in policy code for '{policy.name}', skipping" + ) + continue + guard_blocks.append( f"# Policy: {policy.name}\n" f"{tool_guard.policy_code}\n" - f"{compiled_name} = locals().get('validate')\n" - f"if {compiled_name} is None:\n" - f" for _name, _obj in list(locals().items()):\n" - f" if _name.startswith('guard_') and callable(_obj):\n" - f" {compiled_name} = _obj\n" - f" break\n" - f"if {compiled_name} is None:\n" - f" raise ValueError(\n" - f" \"Policy code for '{policy.name}' does not define a 'validate' or 'guard_*' function\"\n" - f" )\n" - f"{validate_alias} = {compiled_name}\n" + f"# Assign the specific guard function for this policy\n" + f"{validate_alias} = {guard_func_name}\n" ) guard_calls.extend( @@ -303,7 +310,11 @@ def _build_tool_guard_module( " try:", f" await {validate_alias}(api=api, args=args)", " except PolicyViolationException as e:", - f" violations.append(\"[{policy.name}] \" + str(e))", + f" error_msg = str(e)", + f" # Check if error already contains policy name to avoid duplication", + f" if not error_msg.startswith('[{policy.name}]'):", + f" error_msg = f\"[{policy.name}] {{error_msg}}\"", + f" violations.append(error_msg)", ] ) diff --git a/src/cuga/sdk_core/debug_tool_guard.py b/src/cuga/sdk_core/debug_tool_guard.py deleted file mode 100644 index 79df65ba..00000000 --- a/src/cuga/sdk_core/debug_tool_guard.py +++ /dev/null @@ -1,220 +0,0 @@ -""" -Debug script for creating a tool guard policy with code generation. - -This script demonstrates: -1. Creating a CugaAgent with flight booking tools -2. Adding a tool guide policy for membership-based restrictions -3. Generating examples for the policy -4. Generating guard code for the policy -5. Updating the policy with examples and code -6. Exporting the policy data to a JSON file -""" - -import asyncio -import json -from pathlib import Path -from datetime import datetime - -from langchain_core.tools import tool - -from cuga import CugaAgent -from cuga.backend.cuga_graph.policy.models import ToolGuide - - -@tool -def book_flight(user_id: str, flight_id: str, passengers: int) -> str: - """Book a flight for a user with specified number of passengers""" - return f"Flight {flight_id} booked for user {user_id} with {passengers} passengers" - - -@tool -def get_membership(user_id: str) -> str: - """Get the membership level of a user (gold, silver, or regular)""" - # Simulate membership lookup - memberships = { - "user123": "gold", - "user456": "silver", - "user789": "regular" - } - return memberships.get(user_id, "regular") - - -async def main(): - """Main workflow for creating and managing a tool guard policy with code generation.""" - - # Step 1: Create a CugaAgent with the flight booking tools - print("Step 1: Creating CugaAgent with flight booking tools...") - agent = CugaAgent(tools=[book_flight, get_membership]) - - # Step 2: Add a new tool guide policy - print("\nStep 2: Adding new tool guide policy...") - policy_id = await agent.policies.add_tool_guide( - name="Flight Booking Membership Policy", - content="""## Flight Booking Restrictions by Membership Level - -### Policy Rules -- Customers with "regular" membership cannot book a flight with more than 3 passengers -- Gold and silver members have no passenger restrictions -- This policy ensures fair resource allocation and encourages membership upgrades - -### Validation Requirements -- Always check user membership level before booking -- Reject bookings that violate passenger limits -- Provide clear error messages when restrictions apply -""", - target_tools=["book_flight"], - description="Membership-based restrictions for flight bookings to ensure fair resource allocation", - ) - print(f"Created policy with ID: {policy_id}") - - # Step 3: Generate examples using the SDK function - print("\nStep 3: Generating tool guard examples...") - violating_examples, compliance_examples = await agent.policies.generate_tool_guard_examples( - policy_id=policy_id, - target_tool="book_flight" - ) - - print(f"Generated {len(violating_examples)} violating examples and {len(compliance_examples)} compliance examples") - print("\nViolating examples:") - for i, example in enumerate(violating_examples, 1): - print(f" {i}. {example}") - print("\nCompliance examples:") - for i, example in enumerate(compliance_examples, 1): - print(f" {i}. {example}") - - # Step 4: Update the policy with generated examples - print("\nStep 4: Updating policy with generated examples...") - await agent.policies.update_tool_guard( - policy_id=policy_id, - tool_guards={ - "book_flight": { - "description": "Guard rules for flight booking based on membership level", - "violating_examples": violating_examples, - "compliance_examples": compliance_examples, - "policy_code": "" # Will be populated in next step - } - } - ) - print(f"Updated policy {policy_id} with generated examples") - - # Step 5: Generate guard code - print("\nStep 5: Generating guard code...") - guard_code = await agent.policies.generate_tool_guard_code( - policy_id=policy_id, - target_tool="book_flight", - app_name="flight_booking_app" - ) - - print(f"Generated guard code ({len(guard_code)} characters)") - print("\nGuard code preview (first 500 characters):") - print(guard_code[:500]) - print("...") - - # Step 6: Update the policy with generated code - print("\nStep 6: Updating policy with generated guard code...") - await agent.policies.update_tool_guard( - policy_id=policy_id, - tool_guards={ - "book_flight": { - "description": "Guard rules for flight booking based on membership level", - "violating_examples": violating_examples, - "compliance_examples": compliance_examples, - "policy_code": guard_code - } - } - ) - print(f"Updated policy {policy_id} with generated guard code") - - # Step 7: Save the policy - print("\nStep 7: Saving policy...") - policy_tool_guide = await agent.policies.get(policy_id) - if policy_tool_guide is None: - raise ValueError(f"Failed to retrieve policy {policy_id}") - - policy = policy_tool_guide["policy"] - - # Get policy system and update - policy_system = await agent.policies._ensure_policy_system() - if policy_system is None or policy_system.storage is None: - raise ValueError("Policy system storage is not available") - - await policy_system.initialize() - - # Save to file system if available - if agent.policies._fs_sync: - agent.policies._fs_sync.save_policy_to_file(policy) - - print("Policy saved successfully!") - - # Step 8: Retrieve the policy from storage - print("\nStep 8: Retrieving policy from storage...") - retrieved_policy = await policy_system.storage.get_policy(policy_id) - if retrieved_policy is None: - raise ValueError(f"Failed to retrieve updated policy {policy_id}") - - if not isinstance(retrieved_policy, ToolGuide): - raise TypeError(f"Expected ToolGuide, got {type(retrieved_policy).__name__}") - - print("Policy retrieved successfully!") - - # Step 9: Export to JSON file - print("\nStep 9: Exporting policy data to JSON file...") - - # Convert tool_guards to serializable format - tool_guards_data = None - if retrieved_policy.tool_guards: - tool_guards_data = {} - for tool_name, guard in retrieved_policy.tool_guards.items(): - tool_guards_data[tool_name] = { - "description": guard.description, - "violating_examples": guard.violating_examples, - "compliance_examples": guard.compliance_examples, - "policy_code": guard.policy_code, - } - - output_data = { - "export_timestamp": datetime.utcnow().isoformat(), - "policy_id": policy_id, - "policy_name": retrieved_policy.name, - "policy_description": retrieved_policy.description, - "guide_content": retrieved_policy.guide_content, - "target_tools": retrieved_policy.target_tools, - "target_apps": retrieved_policy.target_apps, - "tool_guards": tool_guards_data, - "prepend": retrieved_policy.prepend, - "priority": retrieved_policy.priority, - "enabled": retrieved_policy.enabled, - "policy_type": type(retrieved_policy).__name__, - "metadata": retrieved_policy.metadata, - } - - # Save to JSON file in current directory - output_path = Path("flight_booking_policy_export.json") - output_path.write_text(json.dumps(output_data, indent=2), encoding="utf-8") - - print(f"\n{'='*60}") - print("SUCCESS! Policy workflow with code generation completed:") - print(f"{'='*60}") - print(f"Policy ID: {policy_id}") - print(f"Policy Name: {retrieved_policy.name}") - print(f"Target Tools: {', '.join(retrieved_policy.target_tools)}") - print(f"Generated Examples: {len(violating_examples)} violating, {len(compliance_examples)} compliance") - print(f"Generated Code: {len(guard_code)} characters") - print(f"\nExported to: {output_path.absolute()}") - print(f"{'='*60}") - - # Print summary of tool_guards - if retrieved_policy.tool_guards: - print("\nTool Guards Summary:") - for tool_name, guard in retrieved_policy.tool_guards.items(): - print(f"\n Tool: {tool_name}") - print(f" Description: {guard.description}") - print(f" Violating Examples: {len(guard.violating_examples)}") - print(f" Compliance Examples: {len(guard.compliance_examples)}") - print(f" Policy Code Length: {len(guard.policy_code)} characters") - - -if __name__ == "__main__": - asyncio.run(main()) - -# Made with Bob \ No newline at end of file diff --git a/src/cuga/sdk_core/debug_tool_guard_e2e.py b/src/cuga/sdk_core/debug_tool_guard_e2e.py index 43f5ed24..eafb93c2 100644 --- a/src/cuga/sdk_core/debug_tool_guard_e2e.py +++ b/src/cuga/sdk_core/debug_tool_guard_e2e.py @@ -1,15 +1,12 @@ """ -Debug script for creating a tool guard policy with code generation and E2E testing. +Debug script for creating tool guard policies with code generation and E2E testing. This script demonstrates: 1. Creating a CugaAgent with flight booking tools -2. Adding a tool guide policy for membership-based restrictions -3. Generating examples for the policy -4. Generating guard code for the policy -5. Updating the policy with examples and code -6. Exporting the policy data to a JSON file -7. Testing the policy by invoking the tool with a violating call using ToolGuardRuntime -8. Cleaning up the test policy at the end +2. Adding multiple tool guide policies +3. Generating examples and guard code for each policy +4. Testing policies with ToolGuardRuntime +5. Cleaning up test policies at the end Configuration: - Set DELETE_ALL_POLICIES_AT_START = True to delete all existing policies before running @@ -17,10 +14,7 @@ """ import asyncio -import json -import shutil from pathlib import Path -from datetime import datetime from langchain_core.tools import tool @@ -31,11 +25,46 @@ # ============================================================================ # CONFIGURATION # ============================================================================ -# Set to True to delete ALL existing policies at startup (clean slate) -# Set to False to preserve existing policies (only delete test policy at end) DELETE_ALL_POLICIES_AT_START = True # ============================================================================ +# Define policies to create +POLICIES = [ + { + "name": "Flight Booking Membership Policy", + "content": """## Flight Booking Restrictions by Membership Level + +### Policy Rules +- Customers with "regular" membership cannot book a flight with more than 3 passengers +- Gold and silver members have no passenger restrictions +- This policy ensures fair resource allocation and encourages membership upgrades + +### Validation Requirements +- Always check user membership level before booking +- Reject bookings that violate passenger limits +- Provide clear error messages when restrictions apply +""", + "description": "Membership-based restrictions for flight bookings to ensure fair resource allocation", + }, + { + "name": "Flight ID Format Policy", + "content": """## Flight ID Format Requirements + +### Policy Rules +- Flight ID must start with exactly 2 letters +- Flight ID must have a total of exactly 4 characters (2 letters + 2 digits) +- Example valid flight IDs: FL12, AB99, XY01 +- Example invalid flight IDs: F123 (only 1 letter), FLI2 (3 letters), FL1 (only 3 characters total) + +### Validation Requirements +- Always validate flight ID format before booking +- Reject bookings with invalid flight ID format +- Provide clear error messages when format is incorrect +""", + "description": "Flight ID format validation to ensure proper booking system compatibility", + }, +] + @tool def book_flight(user_id: str, flight_id: str, passengers: int) -> str: @@ -46,7 +75,6 @@ def book_flight(user_id: str, flight_id: str, passengers: int) -> str: @tool def get_membership(user_id: str) -> str: """Get the membership level of a user (gold, silver, or regular)""" - # Simulate membership lookup memberships = { "user123": "gold", "user456": "silver", @@ -55,429 +83,298 @@ def get_membership(user_id: str) -> str: return memberships.get(user_id, "regular") -async def main(): - """Main workflow for creating and managing a tool guard policy with code generation and E2E testing.""" - - # Step 0 (Optional): Clean up ALL existing policies if configured - if DELETE_ALL_POLICIES_AT_START: - print("="*60) - print("Step 0: Cleaning up ALL existing policies (DELETE_ALL_POLICIES_AT_START=True)") - print("="*60) - - agent = CugaAgent(tools=[book_flight, get_membership]) - - # Get policy system - policy_system = await agent.policies._ensure_policy_system() - if policy_system is None or policy_system.storage is None: - raise ValueError("Policy system storage is not available") - - await policy_system.initialize() - - # List all policies - all_policies = await policy_system.storage.list_policies() - print(f"Found {len(all_policies)} total policies in storage") - - # Delete ALL policies from storage - deleted_count = 0 - for policy in all_policies: - await policy_system.storage.delete_policy(policy.id) - print(f" Deleted from storage: '{policy.name}' (ID: {policy.id})") - deleted_count += 1 - - # Also delete policy files from filesystem if filesystem sync is enabled - if agent.policies._fs_sync: - print("\nCleaning up policy files from filesystem...") - - # Get the .cuga folder path - cuga_folder = Path(agent.policies._fs_sync.cuga_folder) - if cuga_folder.exists(): - print(f"Found .cuga folder at: {cuga_folder}") - - # Delete all policy subfolders - policy_subfolders = [ - 'playbooks', - 'output_formatters', - 'tool_guides', - 'intent_guards', - 'tool_approvals', - 'policies' - ] - - total_deleted = 0 - for subfolder in policy_subfolders: - subfolder_path = cuga_folder / subfolder - if subfolder_path.exists(): - # Count files before deletion - files = list(subfolder_path.glob("*.md")) + list(subfolder_path.glob("*.json")) - if files: - print(f" Deleting {len(files)} files from {subfolder}/") - for file in files: - file.unlink() - total_deleted += 1 - - if total_deleted > 0: - print(f"✅ Deleted {total_deleted} policy files from filesystem") - else: - print("No policy files found in filesystem") - else: - print(f".cuga folder does not exist: {cuga_folder}") - - if deleted_count > 0: - print(f"\n✅ Cleaned up {deleted_count} policy/policies from storage") - else: - print("\nNo existing policies to clean up from storage") - - # Verify cleanup - remaining_policies = await policy_system.storage.list_policies() - if remaining_policies: - print(f"⚠️ Warning: {len(remaining_policies)} policies still remain after cleanup") - for p in remaining_policies: - print(f" - {p.name} (ID: {p.id})") - else: - print("✅ All policies successfully deleted from storage") - - print("="*60) - else: - print("="*60) - print("Skipping initial cleanup (DELETE_ALL_POLICIES_AT_START=False)") - print("Existing policies will be preserved") - print("="*60) +async def cleanup_all_policies(agent): + """Clean up all existing policies if configured.""" + print("="*60) + print("Step 0: Cleaning up ALL existing policies") + print("="*60) - # Step 1: Create a CugaAgent with the flight booking tools - print("\nStep 1: Creating CugaAgent with flight booking tools...") - agent = CugaAgent(tools=[book_flight, get_membership]) - - # Get policy system for later use policy_system = await agent.policies._ensure_policy_system() if policy_system is None or policy_system.storage is None: raise ValueError("Policy system storage is not available") await policy_system.initialize() - # Step 2: Add a new tool guide policy - print("\nStep 2: Adding new tool guide policy...") - policy_id = await agent.policies.add_tool_guide( - name="Flight Booking Membership Policy", - content="""## Flight Booking Restrictions by Membership Level - -### Policy Rules -- Customers with "regular" membership cannot book a flight with more than 3 passengers -- Gold and silver members have no passenger restrictions -- This policy ensures fair resource allocation and encourages membership upgrades - -### Validation Requirements -- Always check user membership level before booking -- Reject bookings that violate passenger limits -- Provide clear error messages when restrictions apply -""", - target_tools=["book_flight"], - description="Membership-based restrictions for flight bookings to ensure fair resource allocation", - ) - print(f"Created policy with ID: {policy_id}") - - # Step 3: Generate examples using the SDK function - print("\nStep 3: Generating tool guard examples...") - violating_examples, compliance_examples = await agent.policies.generate_tool_guard_examples( - policy_id=policy_id, - target_tool="book_flight" - ) - - print(f"Generated {len(violating_examples)} violating examples and {len(compliance_examples)} compliance examples") - print("\nViolating examples:") - for i, example in enumerate(violating_examples, 1): - print(f" {i}. {example}") - print("\nCompliance examples:") - for i, example in enumerate(compliance_examples, 1): - print(f" {i}. {example}") - - # Step 4: Update the policy with generated examples - print("\nStep 4: Updating policy with generated examples...") - await agent.policies.update_tool_guard( - policy_id=policy_id, - tool_guards={ - "book_flight": { - "description": "Guard rules for flight booking based on membership level", - "violating_examples": violating_examples, - "compliance_examples": compliance_examples, - "policy_code": "" # Will be populated in next step - } - } - ) - print(f"Updated policy {policy_id} with generated examples") - - # Step 5: Generate guard code - print("\nStep 5: Generating guard code...") - guard_code = await agent.policies.generate_tool_guard_code( - policy_id=policy_id, - target_tool="book_flight", - app_name="test_app" # Use simple app name for code generation - ) + # Delete from storage + all_policies = await policy_system.storage.list_policies() + print(f"Found {len(all_policies)} total policies in storage") - print(f"Generated guard code ({len(guard_code)} characters)") - print("\nGuard code preview (first 500 characters):") - print(guard_code[:500]) - print("...") + for policy in all_policies: + await policy_system.storage.delete_policy(policy.id) + print(f" Deleted from storage: '{policy.name}' (ID: {policy.id})") - # Step 6: Update the policy with generated code - print("\nStep 6: Updating policy with generated guard code...") - await agent.policies.update_tool_guard( - policy_id=policy_id, - tool_guards={ - "book_flight": { - "description": "Guard rules for flight booking based on membership level", - "violating_examples": violating_examples, - "compliance_examples": compliance_examples, - "policy_code": guard_code - } - } - ) - print(f"Updated policy {policy_id} with generated guard code") - - # Step 7: Save the policy - print("\nStep 7: Saving policy...") - policy_tool_guide = await agent.policies.get(policy_id) - if policy_tool_guide is None: - raise ValueError(f"Failed to retrieve policy {policy_id}") - - policy = policy_tool_guide["policy"] - - # Get policy system and update - policy_system = await agent.policies._ensure_policy_system() - if policy_system is None or policy_system.storage is None: - raise ValueError("Policy system storage is not available") - - await policy_system.initialize() - - # Save to file system if available + # Delete from filesystem if agent.policies._fs_sync: - agent.policies._fs_sync.save_policy_to_file(policy) - - print("Policy saved successfully!") - - # Step 8: Retrieve the policy from storage - print("\nStep 8: Retrieving policy from storage...") - retrieved_policy = await policy_system.storage.get_policy(policy_id) - if retrieved_policy is None: - raise ValueError(f"Failed to retrieve updated policy {policy_id}") - - if not isinstance(retrieved_policy, ToolGuide): - raise TypeError(f"Expected ToolGuide, got {type(retrieved_policy).__name__}") - - print("Policy retrieved successfully!") - - # Step 9: Export to JSON file - print("\nStep 9: Exporting policy data to JSON file...") - - # Convert tool_guards to serializable format - tool_guards_data = None - if retrieved_policy.tool_guards: - tool_guards_data = {} - for tool_name, guard in retrieved_policy.tool_guards.items(): - tool_guards_data[tool_name] = { - "description": guard.description, - "violating_examples": guard.violating_examples, - "compliance_examples": guard.compliance_examples, - "policy_code": guard.policy_code, - } - - output_data = { - "export_timestamp": datetime.utcnow().isoformat(), - "policy_id": policy_id, - "policy_name": retrieved_policy.name, - "policy_description": retrieved_policy.description, - "guide_content": retrieved_policy.guide_content, - "target_tools": retrieved_policy.target_tools, - "target_apps": retrieved_policy.target_apps, - "tool_guards": tool_guards_data, - "prepend": retrieved_policy.prepend, - "priority": retrieved_policy.priority, - "enabled": retrieved_policy.enabled, - "policy_type": type(retrieved_policy).__name__, - "metadata": retrieved_policy.metadata, - } - - # Save to JSON file in current directory - output_path = Path("flight_booking_policy_export_e2e.json") - output_path.write_text(json.dumps(output_data, indent=2), encoding="utf-8") - - print(f"\n{'='*60}") - print("SUCCESS! Policy workflow with code generation completed:") - print(f"{'='*60}") - print(f"Policy ID: {policy_id}") - print(f"Policy Name: {retrieved_policy.name}") - print(f"Target Tools: {', '.join(retrieved_policy.target_tools)}") - print(f"Generated Examples: {len(violating_examples)} violating, {len(compliance_examples)} compliance") - print(f"Generated Code: {len(guard_code)} characters") - print(f"\nExported to: {output_path.absolute()}") - print(f"{'='*60}") - - # Print summary of tool_guards - if retrieved_policy.tool_guards: - print("\nTool Guards Summary:") - for tool_name, guard in retrieved_policy.tool_guards.items(): - print(f"\n Tool: {tool_name}") - print(f" Description: {guard.description}") - print(f" Violating Examples: {len(guard.violating_examples)}") - print(f" Compliance Examples: {len(guard.compliance_examples)}") - print(f" Policy Code Length: {len(guard.policy_code)} characters") - - # Step 10: Test the policy by invoking the tool with ToolGuardRuntime - print(f"\n{'='*60}") - print("Step 10: Testing policy enforcement with ToolGuardRuntime") - print(f"{'='*60}") - - # Initialize ToolGuardRuntime - print("\nInitializing ToolGuardRuntime...") - tool_guard_runtime = ToolGuardRuntime( - tool_provider=agent.tool_provider, - policy_storage=policy_system.storage - ) - await tool_guard_runtime.initialize() + print("\nCleaning up policy files from filesystem...") + cuga_folder = Path(agent.policies._fs_sync.cuga_folder) + if cuga_folder.exists(): + policy_subfolders = ['playbooks', 'output_formatters', 'tool_guides', + 'intent_guards', 'tool_approvals', 'policies'] + + total_deleted = 0 + for subfolder in policy_subfolders: + subfolder_path = cuga_folder / subfolder + if subfolder_path.exists(): + files = list(subfolder_path.glob("*.md")) + list(subfolder_path.glob("*.json")) + for file in files: + file.unlink() + total_deleted += 1 + + if total_deleted > 0: + print(f"✅ Deleted {total_deleted} policy files from filesystem") - print(f"Runtime initialized with guards for: {tool_guard_runtime.get_guarded_tools()}") + print("✅ All policies successfully deleted") + print("="*60) + + +async def create_and_process_policies(agent, policy_system): + """Create policies and generate examples and guard code for each.""" + print("\nStep 1: Creating and processing policies...") + print("="*60) - # Test case 1: Violating call (should be blocked) - print("\n--- Test Case 1: Violating Call ---") - print("Attempting: book_flight(flight_id='fl123', user_id='user789', passengers=8)") - print("Expected: BLOCKED (user789 is 'regular' member, 8 > 3 passengers)") + policy_data = [] - try: - error = await tool_guard_runtime.guard_tool_call( - app_name="test_app", # Match the app_name used in code generation - function_name="book_flight", - arguments={ - "flight_id": "fl123", - "user_id": "user789", - "passengers": 8 - } + for idx, policy_config in enumerate(POLICIES, 1): + print(f"\n--- Processing Policy {idx}/{len(POLICIES)}: {policy_config['name']} ---") + + # Create policy + print(f"Creating policy...") + policy_id = await agent.policies.add_tool_guide( + name=policy_config["name"], + content=policy_config["content"], + target_tools=["book_flight"], + description=policy_config["description"], ) + print(f"✅ Created policy with ID: {policy_id}") - if error: - print(f"\n✅ SUCCESS: Tool call was correctly BLOCKED!") - print(f"Error message: {error}") - else: - print("\n⚠️ WARNING: Tool call was ALLOWED (expected to be blocked)") - - except Exception as e: - print(f"\n❌ Error during validation: {type(e).__name__}: {e}") - - # Test case 2: Compliant call (should be allowed) - print("\n--- Test Case 2: Compliant Call ---") - print("Attempting: book_flight(flight_id='fl456', user_id='user789', passengers=2)") - print("Expected: ALLOWED (user789 is 'regular' member, 2 <= 3 passengers)") - - try: - error = await tool_guard_runtime.guard_tool_call( - app_name="test_app", # Match the app_name used in code generation - function_name="book_flight", - arguments={ - "flight_id": "fl456", - "user_id": "user789", - "passengers": 2 + # Generate examples + print(f"Generating examples...") + violating_examples, compliance_examples = await agent.policies.generate_tool_guard_examples( + policy_id=policy_id, + target_tool="book_flight" + ) + print(f"✅ Generated {len(violating_examples)} violating and {len(compliance_examples)} compliance examples") + + # Update policy with examples + await agent.policies.update_tool_guard( + policy_id=policy_id, + tool_guards={ + "book_flight": { + "description": f"Guard rules for {policy_config['name']}", + "violating_examples": violating_examples, + "compliance_examples": compliance_examples, + "policy_code": "" + } } ) - if error: - print(f"\n⚠️ WARNING: Tool call was BLOCKED (expected to be allowed)") - print(f"Error message: {error}") - else: - print("\n✅ SUCCESS: Tool call was correctly ALLOWED!") - - # Actually invoke the tool to show it works - print("\nExecuting the actual tool call...") - result = await book_flight.ainvoke({ - "flight_id": "fl456", - "user_id": "user789", - "passengers": 2 - }) - print(f"Tool result: {result}") - - except Exception as e: - print(f"\n❌ Error during validation: {type(e).__name__}: {e}") - - # Test case 3: Gold member with many passengers (should be allowed) - print("\n--- Test Case 3: Gold Member Call ---") - print("Attempting: book_flight(flight_id='fl789', user_id='user123', passengers=10)") - print("Expected: ALLOWED (user123 is 'gold' member, no passenger limit)") - - try: - error = await tool_guard_runtime.guard_tool_call( - app_name="test_app", # Match the app_name used in code generation - function_name="book_flight", - arguments={ - "flight_id": "fl789", - "user_id": "user123", - "passengers": 10 + # Generate guard code + print(f"Generating guard code...") + guard_code = await agent.policies.generate_tool_guard_code( + policy_id=policy_id, + target_tool="book_flight", + app_name="test_app" + ) + print(f"✅ Generated guard code ({len(guard_code)} characters)") + print(f"✅ Code:\n{guard_code} ") + + # Update policy with guard code + await agent.policies.update_tool_guard( + policy_id=policy_id, + tool_guards={ + "book_flight": { + "description": f"Guard rules for {policy_config['name']}", + "violating_examples": violating_examples, + "compliance_examples": compliance_examples, + "policy_code": guard_code + } } ) - if error: - print(f"\n⚠️ WARNING: Tool call was BLOCKED (expected to be allowed)") - print(f"Error message: {error}") - else: - print("\n✅ SUCCESS: Tool call was correctly ALLOWED!") + # Save policy + policy_tool_guide = await agent.policies.get(policy_id) + if policy_tool_guide is None: + raise ValueError(f"Failed to retrieve policy {policy_id}") + + policy = policy_tool_guide["policy"] + if agent.policies._fs_sync: + agent.policies._fs_sync.save_policy_to_file(policy) + + print(f"✅ Policy saved successfully") + + # Store policy data for later use + policy_data.append({ + "id": policy_id, + "name": policy_config["name"], + "policy": policy + }) + + print("\n" + "="*60) + print(f"✅ All {len(POLICIES)} policies created and processed successfully") + print("="*60) + + return policy_data + + +async def run_tests(tool_guard_runtime): + """Run test cases to validate policy enforcement.""" + print(f"\n{'='*60}") + print("Step 2: Testing policy enforcement with ToolGuardRuntime") + print(f"{'='*60}") + + print(f"\nRuntime initialized with guards for: {tool_guard_runtime.get_guarded_tools()}") + + test_cases = [ + { + "name": "Test Case 1: Too Many Passengers", + "args": {"flight_id": "FL12", "user_id": "user789", "passengers": 8}, + "expected": "BLOCKED", + "reason": "user789 is 'regular' member, 8 > 3 passengers" + }, + { + "name": "Test Case 2: Valid Booking", + "args": {"flight_id": "FL45", "user_id": "user789", "passengers": 2}, + "expected": "ALLOWED", + "reason": "user789 is 'regular' member, 2 <= 3 passengers, valid flight ID" + }, + { + "name": "Test Case 3: Gold Member", + "args": {"flight_id": "AB78", "user_id": "user123", "passengers": 10}, + "expected": "ALLOWED", + "reason": "user123 is 'gold' member, no passenger limit" + }, + { + "name": "Test Case 4: Multiple Violations", + "args": {"flight_id": "F123", "user_id": "user789", "passengers": 8}, + "expected": "BLOCKED", + "reason": "8 > 3 passengers AND flight_id 'F123' has only 1 letter" + }, + { + "name": "Test Case 5: Invalid Flight ID Only", + "args": {"flight_id": "ABC1", "user_id": "user789", "passengers": 2}, + "expected": "BLOCKED", + "reason": "flight_id 'ABC1' has 3 letters instead of 2" + }, + ] + + results = [] + for test in test_cases: + print(f"\n--- {test['name']} ---") + print(f"Attempting: book_flight({', '.join(f'{k}={repr(v)}' for k, v in test['args'].items())})") + print(f"Expected: {test['expected']} ({test['reason']})") + + try: + error = await tool_guard_runtime.guard_tool_call( + app_name="test_app", + function_name="book_flight", + arguments=test["args"] + ) - # Actually invoke the tool to show it works - print("\nExecuting the actual tool call...") - result = await book_flight.ainvoke({ - "flight_id": "fl789", - "user_id": "user123", - "passengers": 10 - }) - print(f"Tool result: {result}") + actual = "BLOCKED" if error else "ALLOWED" + success = actual == test["expected"] - except Exception as e: - print(f"\n❌ Error during validation: {type(e).__name__}: {e}") + if success: + print(f"\n✅ SUCCESS: Tool call was correctly {actual}!") + if error: + print(f"Error message: {error}") + else: + # Actually invoke the tool + result = await book_flight.ainvoke(test["args"]) + print(f"Tool result: {result}") + else: + print(f"\n⚠️ WARNING: Tool call was {actual} (expected {test['expected']})") + if error: + print(f"Error message: {error}") + + results.append({"test": test["name"], "success": success, "actual": actual}) + + except Exception as e: + print(f"\n❌ Error during validation: {type(e).__name__}: {e}") + results.append({"test": test["name"], "success": False, "actual": "ERROR"}) + # Print summary print(f"\n{'='*60}") - print("E2E Test Complete!") + print("Test Summary:") + print(f"{'='*60}") + passed = sum(1 for r in results if r["success"]) + print(f"Passed: {passed}/{len(results)}") + for r in results: + status = "✅" if r["success"] else "❌" + print(f" {status} {r['test']}: {r['actual']}") print(f"{'='*60}") - print("\nSummary:") - print(" - Policy created and saved successfully") - print(" - Examples and guard code generated") - print(" - ToolGuardRuntime initialized and tested") - print(" - All test cases executed") - # Step 11: Clean up - Delete the policy we created + return results + + +async def cleanup_policies(agent, policy_system, policy_data): + """Delete all created policies.""" print(f"\n{'='*60}") - print("Step 11: Cleaning up - Deleting test policy") + print("Step 3: Cleaning up test policies") print(f"{'='*60}") try: - # Delete from storage - await policy_system.storage.delete_policy(policy_id) - print(f"✅ Deleted policy from storage: {policy_id}") - - # Delete from filesystem if filesystem sync is enabled - if agent.policies._fs_sync: - # Get the .cuga folder path - cuga_folder = Path(agent.policies._fs_sync.cuga_folder) - tool_guides_folder = cuga_folder / "tool_guides" + for policy_info in policy_data: + # Delete from storage + await policy_system.storage.delete_policy(policy_info["id"]) + print(f"✅ Deleted '{policy_info['name']}' from storage") - if tool_guides_folder.exists(): - # Find and delete the policy file - policy_files = list(tool_guides_folder.glob(f"*{policy_id}*.md")) + \ - list(tool_guides_folder.glob(f"*{policy_id}*.json")) - - for policy_file in policy_files: - policy_file.unlink() - print(f"✅ Deleted policy file: {policy_file.name}") + # Delete from filesystem + if agent.policies._fs_sync: + cuga_folder = Path(agent.policies._fs_sync.cuga_folder) + tool_guides_folder = cuga_folder / "tool_guides" - if not policy_files: - print("⚠️ No policy files found to delete from filesystem") + if tool_guides_folder.exists(): + policy_files = list(tool_guides_folder.glob(f"*{policy_info['id']}*.md")) + \ + list(tool_guides_folder.glob(f"*{policy_info['id']}*.json")) + + for policy_file in policy_files: + policy_file.unlink() + print(f"✅ Deleted file: {policy_file.name}") + + print("✅ All test policies successfully deleted") - # Verify deletion - deleted_policy = await policy_system.storage.get_policy(policy_id) - if deleted_policy is None: - print("✅ Policy successfully deleted and verified") - else: - print("⚠️ Warning: Policy still exists in storage after deletion") - except Exception as e: print(f"⚠️ Error during cleanup: {type(e).__name__}: {e}") + print(f"{'='*60}") + + +async def main(): + """Main workflow for creating and testing tool guard policies.""" + + # Step 0: Optional cleanup + agent = CugaAgent(tools=[book_flight, get_membership]) + + if DELETE_ALL_POLICIES_AT_START: + await cleanup_all_policies(agent) + else: + print("="*60) + print("Skipping initial cleanup (DELETE_ALL_POLICIES_AT_START=False)") + print("="*60) + + # Get policy system + policy_system = await agent.policies._ensure_policy_system() + if policy_system is None or policy_system.storage is None: + raise ValueError("Policy system storage is not available") + await policy_system.initialize() + + # Step 1: Create and process policies + policy_data = await create_and_process_policies(agent, policy_system) + + # Step 2: Initialize runtime and run tests + tool_guard_runtime = ToolGuardRuntime( + tool_provider=agent.tool_provider, + policy_storage=policy_system.storage + ) + await tool_guard_runtime.initialize() + + results = await run_tests(tool_guard_runtime) + + # Step 3: Cleanup + await cleanup_policies(agent, policy_system, policy_data) + print(f"\n{'='*60}") - print("✅ Test completed and cleaned up successfully!") + print("✅ E2E Test completed successfully!") print(f"{'='*60}") From 9f0118475b21c05da207adb16cb449968ac56746 Mon Sep 17 00:00:00 2001 From: naamaz Date: Mon, 27 Apr 2026 12:12:58 +0300 Subject: [PATCH 10/24] cleanup --- .../tool_guard/example_cuga_llm_usage.py | 210 --------- .../policy/tool_guard/tool_guard_runtime.py | 413 ++++++++++++------ src/cuga/sdk_core/debug_tool_guard_e2e.py | 2 +- 3 files changed, 285 insertions(+), 340 deletions(-) delete mode 100644 src/cuga/backend/cuga_graph/policy/tool_guard/example_cuga_llm_usage.py diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/example_cuga_llm_usage.py b/src/cuga/backend/cuga_graph/policy/tool_guard/example_cuga_llm_usage.py deleted file mode 100644 index cb0266a4..00000000 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/example_cuga_llm_usage.py +++ /dev/null @@ -1,210 +0,0 @@ -""" -Example usage of CugaLLMAdapter with ToolGuard. - -This demonstrates how to use CUGA's LLM system with ToolGuard's -example generation capabilities. -""" - -import asyncio -from cuga.backend.cuga_graph.nodes.cuga_lite.combined_tool_provider import CombinedToolProvider -from cuga.sdk import CugaAgent -from cuga.backend.cuga_graph.policy.tool_guard.cuga_llm_adapter import CugaLLMAdapter -from langchain_core.tools import tool - - -@tool -def delete_file(path: str) -> str: - """Delete a file from the filesystem""" - return f"Deleted file: {path}" - - -@tool -def read_file(path: str) -> str: - """Read contents of a file""" - return f"Contents of {path}" - - -async def example_basic_usage(): - """Basic example: Using CugaLLMAdapter directly""" - print("\n=== Example 1: Basic CugaLLMAdapter Usage ===\n") - - # Create a CUGA agent with tools - agent = CugaAgent(tools=[delete_file, read_file]) - - # Wrap the agent's model with the adapter - assert agent._model is not None, "Agent model should be initialized" - llm_adapter = CugaLLMAdapter(agent._model) - - # Use the adapter with ToolGuard's message format - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Say hello in a friendly way."} - ] - - response = await llm_adapter.generate(messages) - print(f"Response: {response}\n") - - -async def example_json_output(): - """Example: Using chat_json for structured output""" - print("\n=== Example 2: JSON Output with chat_json ===\n") - - agent = CugaAgent(tools=[delete_file]) - assert agent._model is not None, "Agent model should be initialized" - llm_adapter = CugaLLMAdapter(agent._model) - - # Request JSON output - messages = [ - { - "role": "system", - "content": "You are a helpful assistant that responds in JSON format." - }, - { - "role": "user", - "content": ( - "Generate a JSON object with two fields: 'greeting' and 'language'. " - "Wrap your response in ```json``` code blocks." - ) - } - ] - - # chat_json automatically extracts JSON from the response - json_response = await llm_adapter.chat_json(messages) - print(f"JSON Response: {json_response}\n") - - -async def example_with_toolguard_manager(): - """Example: Using ToolGuardManager with CugaLLMAdapter""" - print("\n=== Example 3: ToolGuardManager Integration ===\n") - - from cuga.backend.cuga_graph.policy.tool_guard.manager import ToolGuardManager - - # Create agent with tools - agent = CugaAgent( - tools=[delete_file, read_file], - cuga_folder="/Users/naamazwerdling/Documents/OASB/policy_validation/cuga_demo/output" - ) - - # Create ToolGuard manager (automatically uses CugaLLMAdapter) - manager = ToolGuardManager(agent) - await manager.initialize() - - tool_count = len(manager.langchain_tools) if manager.langchain_tools else 0 - print(f"✅ ToolGuardManager initialized with {tool_count} tools") - print(f" LLM: {type(manager.llm).__name__}") - print(f" Underlying model: {type(agent._model).__name__}\n") - - # Add a tool guide policy - await agent.policies.add_tool_guide( - name="File Safety Guidelines", - target_tools=["delete_file"], - description="Ensure safe file deletion practices", - content=""" -# File Safety Guidelines - -## Rules -- Never delete system files (e.g., /etc/*, /sys/*, /boot/*) -- Never delete files without user confirmation -- Always validate file paths before deletion - -## Examples -- ❌ BAD: Delete /etc/passwd -- ✅ GOOD: Delete /home/user/temp.txt (with confirmation) - """.strip() - ) - - print("✅ Added tool guide policy for delete_file\n") - - # Get the policy - policies = await agent.policies.list() - if policies: - policy_info = policies[0] - print(f"Policy: {policy_info['name']} (ID: {policy_info['id']})\n") - - # Generate examples using ToolGuard with CUGA's LLM - print("🔄 Generating violating and compliance examples...") - policy_full = await agent.policies.get(policy_info['id']) - - if policy_full and policy_full.get('policy'): - try: - violating, compliance = await manager.generate_examples( - policy=policy_full['policy'], - target_tool="delete_file" - ) - - print("\n📋 Generated Examples:\n") - - print("❌ Violating Examples (actions that break the policy):") - for i, example in enumerate(violating, 1): - print(f" {i}. {example}") - - print("\n✅ Compliance Examples (actions that follow the policy):") - for i, example in enumerate(compliance, 1): - print(f" {i}. {example}") - - print() - - except Exception as e: - print(f"⚠️ Error generating examples: {e}") - import traceback - print(traceback.format_exc()) - print(" This might happen if the policy format is incompatible or LLM fails") - print() - else: - print("⚠️ Could not retrieve full policy details\n") - - -async def example_error_handling(): - """Example: Error handling with CugaLLMAdapter""" - print("\n=== Example 4: Error Handling ===\n") - - agent = CugaAgent(tools=[delete_file]) - assert agent._model is not None, "Agent model should be initialized" - llm_adapter = CugaLLMAdapter(agent._model) - - # Test with invalid message format - try: - invalid_messages = [ - {"content": "Missing role key"} # Missing 'role' - ] - await llm_adapter.generate(invalid_messages) - except ValueError as e: - print(f"✅ Caught expected error: {e}\n") - - # Test with valid messages - try: - valid_messages = [ - {"role": "user", "content": "Hello!"} - ] - response = await llm_adapter.generate(valid_messages) - print(f"✅ Valid request succeeded: {response[:50]}...\n") - except Exception as e: - print(f"❌ Unexpected error: {e}\n") - - -async def main(): - """Run all examples""" - print("=" * 60) - print("CugaLLMAdapter Usage Examples") - print("=" * 60) - - try: - #await example_basic_usage() - #await example_json_output() - await example_with_toolguard_manager() - #await example_error_handling() - - print("=" * 60) - print("✅ All examples completed successfully!") - print("=" * 60) - - except Exception as e: - print(f"\n❌ Error running examples: {e}") - import traceback - traceback.print_exc() - - -if __name__ == "__main__": - asyncio.run(main()) - -# Made with Bob diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py index 9892ce03..7fc9d3cf 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py @@ -6,7 +6,8 @@ """ from pathlib import Path -from typing import Dict, Any, List, Optional +from typing import Any, Dict, List, Optional, Sequence, Tuple + from loguru import logger from toolguard.runtime.data_types import ( @@ -20,9 +21,9 @@ ) from toolguard.runtime.runtime import load_toolguards_from_memory -from cuga.backend.cuga_graph.policy.tool_guard.tool_invoker import ToolGuardInvoker -from cuga.backend.cuga_graph.policy.models import ToolGuide, PolicyType +from cuga.backend.cuga_graph.policy.models import Policy, PolicyType, ToolGuide from cuga.backend.cuga_graph.policy.storage import PolicyStorage +from cuga.backend.cuga_graph.policy.tool_guard.tool_invoker import ToolGuardInvoker class ToolGuardRuntime: @@ -37,7 +38,7 @@ class ToolGuardRuntime: 5. Executes guard validation through toolguard runtime """ - def __init__(self, tool_provider, policy_storage: PolicyStorage): + def __init__(self, tool_provider, policy_storage: PolicyStorage) -> None: """ Initialize the ToolGuardRuntime. @@ -47,17 +48,10 @@ def __init__(self, tool_provider, policy_storage: PolicyStorage): """ self.tool_provider = tool_provider self.policy_storage = policy_storage - - # Create ToolGuardInvoker instance self.invoker = ToolGuardInvoker(tool_provider) - - # Mapping: tool_name -> List[ToolGuide policies with code] self.tool_to_guards: Dict[str, List[ToolGuide]] = {} - - # ToolGuard in-memory runtime built from umbrella guard modules self._runtime = None self._runtime_domain: Optional[RuntimeDomain] = None - self._initialized = False logger.debug("Created ToolGuardRuntime instance") @@ -73,61 +67,76 @@ async def initialize(self) -> None: 5. Builds an in-memory ToolGuard runtime """ logger.info("Initializing ToolGuardRuntime...") + self._reset_state() + + policies = await self.policy_storage.list_policies( + policy_type=PolicyType.TOOL_GUIDE, enabled_only=True + ) + logger.debug(f"Found {len(policies)} ToolGuide policies") + # Filter to ensure we only have ToolGuide instances + tool_guide_policies = [p for p in policies if isinstance(p, ToolGuide)] + self._build_tool_to_guards_mapping(tool_guide_policies) + + if self.tool_to_guards: + self._runtime_domain = self._load_runtime_domain() + self._runtime = self._build_runtime() + + self._initialized = True + self._log_initialization_summary() + + def _reset_state(self) -> None: + """Reset internal state for reinitialization.""" self.tool_to_guards = {} self._runtime = None self._runtime_domain = None - # Load all ToolGuide policies - policies = await self.policy_storage.list_policies( - policy_type=PolicyType.TOOL_GUIDE, - enabled_only=True - ) - - logger.debug(f"Found {len(policies)} ToolGuide policies") + def _build_tool_to_guards_mapping(self, policies: Sequence[ToolGuide]) -> None: + """ + Build mapping from tool names to their guard policies. - # Build mapping: tool_name -> List[ToolGuide with code] + Args: + policies: Sequence of ToolGuide policies to process + """ for policy in policies: - # Type guard: ensure we're working with ToolGuide - if not isinstance(policy, ToolGuide): - logger.warning( - f"Expected ToolGuide but got {type(policy).__name__}, skipping" - ) - continue - if not policy.tool_guards: logger.debug(f"Policy '{policy.name}' has no tool_guards, skipping") continue - # Iterate through each tool's guard configuration - for tool_name, tool_guard in policy.tool_guards.items(): - # Only include if policy_code exists - if tool_guard.policy_code: - if tool_name not in self.tool_to_guards: - self.tool_to_guards[tool_name] = [] - - self.tool_to_guards[tool_name].append(policy) - logger.debug( - f"Registered guard for tool '{tool_name}' " - f"from policy '{policy.name}'" - ) - else: - logger.debug( - f"Tool guard for '{tool_name}' in policy '{policy.name}' " - f"has no policy_code, skipping" - ) + self._register_policy_guards(policy) - if self.tool_to_guards: - self._runtime_domain = self._load_runtime_domain() - self._runtime = self._build_runtime() + def _register_policy_guards(self, policy: ToolGuide) -> None: + """ + Register guards from a policy for all its tools. - self._initialized = True + Args: + policy: ToolGuide policy to register + """ + if not policy.tool_guards: + return + + for tool_name, tool_guard in policy.tool_guards.items(): + if not tool_guard.policy_code: + logger.debug( + f"Tool guard for '{tool_name}' in policy '{policy.name}' " + f"has no policy_code, skipping" + ) + continue + + if tool_name not in self.tool_to_guards: + self.tool_to_guards[tool_name] = [] + + self.tool_to_guards[tool_name].append(policy) + logger.debug( + f"Registered guard for tool '{tool_name}' from policy '{policy.name}'" + ) + + def _log_initialization_summary(self) -> None: + """Log summary of initialization results.""" logger.info( f"✅ ToolGuardRuntime initialized with guards for " f"{len(self.tool_to_guards)} tools" ) - - # Log summary of registered guards for tool_name, guards in self.tool_to_guards.items(): logger.debug( f" - Tool '{tool_name}': {len(guards)} guard(s) " @@ -191,15 +200,57 @@ def _build_runtime(self): return runtime def _load_runtime_domain(self) -> RuntimeDomain: - """Load RuntimeDomain files saved by ToolGuardManager.""" + """ + Load RuntimeDomain files saved by ToolGuardManager. + + Returns: + RuntimeDomain with loaded domain files + + Raises: + RuntimeError: If domain directory or files are not found + """ domain_dir = Path.cwd() / ".cuga" / "toolguard" / "domain" + self._validate_domain_directory(domain_dir) + app_dirs = self._get_sorted_app_directories(domain_dir) + selected_domain = self._find_complete_domain(domain_dir, app_dirs) + + if selected_domain is None: + raise RuntimeError( + f"No complete ToolGuard domain found under {domain_dir}" + ) + + return self._create_runtime_domain(domain_dir, selected_domain) + + def _validate_domain_directory(self, domain_dir: Path) -> None: + """ + Validate that the domain directory exists. + + Args: + domain_dir: Path to domain directory + + Raises: + RuntimeError: If domain directory doesn't exist + """ if not domain_dir.exists(): raise RuntimeError( f"ToolGuard domain directory not found: {domain_dir}. " "Generate tool guard code first so ToolGuardManager saves the domain files." ) + def _get_sorted_app_directories(self, domain_dir: Path) -> List[Path]: + """ + Get app directories sorted by modification time (newest first). + + Args: + domain_dir: Path to domain directory + + Returns: + List of app directory paths + + Raises: + RuntimeError: If no app directories found + """ app_dirs = sorted( [path for path in domain_dir.iterdir() if path.is_dir()], key=lambda path: path.stat().st_mtime, @@ -209,8 +260,21 @@ def _load_runtime_domain(self) -> RuntimeDomain: raise RuntimeError( f"No ToolGuard app directories found under {domain_dir}" ) + return app_dirs + + def _find_complete_domain( + self, domain_dir: Path, app_dirs: List[Path] + ) -> Optional[Tuple[str, Path, Path, Path]]: + """ + Find the first complete domain with all required files. - selected_domain = None + Args: + domain_dir: Path to domain directory + app_dirs: List of app directories to search + + Returns: + Tuple of (app_name, types_path, api_path, impl_path) or None + """ for app_dir in app_dirs: app_name = app_dir.name app_types_rel = Path(app_name) / f"{app_name}_types.py" @@ -223,33 +287,34 @@ def _load_runtime_domain(self) -> RuntimeDomain: domain_dir / app_api_impl_rel, ] if all(path.exists() for path in candidate_paths): - selected_domain = (app_name, app_types_rel, app_api_rel, app_api_impl_rel) - break + return (app_name, app_types_rel, app_api_rel, app_api_impl_rel) - if selected_domain is None: - raise RuntimeError( - f"No complete ToolGuard domain found under {domain_dir}" - ) + return None + + def _create_runtime_domain( + self, domain_dir: Path, selected_domain: Tuple[str, Path, Path, Path] + ) -> RuntimeDomain: + """ + Create RuntimeDomain from selected domain files. + + Args: + domain_dir: Path to domain directory + selected_domain: Tuple of (app_name, types_path, api_path, impl_path) + Returns: + RuntimeDomain instance + """ app_name, app_types_rel, app_api_rel, app_api_impl_rel = selected_domain api_content = FileTwin.load_from(domain_dir, app_api_rel).content api_impl_content = FileTwin.load_from(domain_dir, app_api_impl_rel).content - app_api_class_name = f"I{''.join(part.capitalize() for part in app_name.split('_'))}" - app_api_impl_class_name = ''.join(part.capitalize() for part in app_name.split('_')) - - for line in api_content.splitlines(): - stripped = line.strip() - if stripped.startswith("class "): - app_api_class_name = stripped.split()[1].split("(")[0].rstrip(":") - break - - for line in api_impl_content.splitlines(): - stripped = line.strip() - if stripped.startswith("class "): - app_api_impl_class_name = stripped.split()[1].split("(")[0].rstrip(":") - break + app_api_class_name = self._extract_class_name( + api_content, f"I{''.join(part.capitalize() for part in app_name.split('_'))}" + ) + app_api_impl_class_name = self._extract_class_name( + api_impl_content, ''.join(part.capitalize() for part in app_name.split('_')) + ) return RuntimeDomain( app_name=app_name, @@ -261,65 +326,135 @@ def _load_runtime_domain(self) -> RuntimeDomain: app_api_impl=FileTwin.load_from(domain_dir, app_api_impl_rel), ) + def _extract_class_name(self, content: str, default: str) -> str: + """ + Extract class name from Python source code. + + Args: + content: Python source code + default: Default class name if not found + + Returns: + Extracted or default class name + """ + for line in content.splitlines(): + stripped = line.strip() + if stripped.startswith("class "): + return stripped.split()[1].split("(")[0].rstrip(":") + return default + def _build_tool_guard_module( self, tool_name: str, guards: List[ToolGuide], guard_fn_name: str, ) -> str: - """Create a module containing a single umbrella guard function for one tool.""" + """ + Create a module containing a single umbrella guard function for one tool. + + Args: + tool_name: Name of the tool + guards: List of ToolGuide policies for this tool + guard_fn_name: Name for the umbrella guard function + + Returns: + Generated Python module content as string + """ guard_blocks: List[str] = [] guard_calls: List[str] = [] for index, policy in enumerate(guards): - tool_guard = policy.tool_guards.get(tool_name) if policy.tool_guards else None - if not tool_guard or not tool_guard.policy_code: - logger.warning( - f"Policy '{policy.name}' missing tool_guard for '{tool_name}', skipping" - ) - continue + self._process_policy_guard( + policy, tool_name, index, guard_blocks, guard_calls + ) - compiled_name = f"_compiled_guard_{index}" - validate_alias = f"_guard_validate_{index}" - - # Extract the guard function name from the policy code - # The generated code has @rule decorator followed by async def guard_xxx - guard_func_name = None - for line in tool_guard.policy_code.split('\n'): - line = line.strip() - if line.startswith('async def guard_'): - # Extract function name: "async def guard_xxx(..." -> "guard_xxx" - guard_func_name = line.split('(')[0].replace('async def ', '').strip() - break - - if not guard_func_name: - logger.warning( - f"Could not find guard function in policy code for '{policy.name}', skipping" - ) - continue - - guard_blocks.append( - f"# Policy: {policy.name}\n" - f"{tool_guard.policy_code}\n" - f"# Assign the specific guard function for this policy\n" - f"{validate_alias} = {guard_func_name}\n" + return self._generate_module_content(guard_fn_name, guard_blocks, guard_calls) + + def _process_policy_guard( + self, + policy: ToolGuide, + tool_name: str, + index: int, + guard_blocks: List[str], + guard_calls: List[str], + ) -> None: + """ + Process a single policy guard and add to blocks and calls. + + Args: + policy: ToolGuide policy to process + tool_name: Name of the tool + index: Index of this guard + guard_blocks: List to append guard code blocks to + guard_calls: List to append guard call statements to + """ + tool_guard = policy.tool_guards.get(tool_name) if policy.tool_guards else None + if not tool_guard or not tool_guard.policy_code: + logger.warning( + f"Policy '{policy.name}' missing tool_guard for '{tool_name}', skipping" ) + return - guard_calls.extend( - [ - " try:", - f" await {validate_alias}(api=api, args=args)", - " except PolicyViolationException as e:", - f" error_msg = str(e)", - f" # Check if error already contains policy name to avoid duplication", - f" if not error_msg.startswith('[{policy.name}]'):", - f" error_msg = f\"[{policy.name}] {{error_msg}}\"", - f" violations.append(error_msg)", - ] + guard_func_name = self._extract_guard_function_name(tool_guard.policy_code) + if not guard_func_name: + logger.warning( + f"Could not find guard function in policy code for '{policy.name}', skipping" ) + return + + validate_alias = f"_guard_validate_{index}" + + guard_blocks.append( + f"# Policy: {policy.name}\n" + f"{tool_guard.policy_code}\n" + f"# Assign the specific guard function for this policy\n" + f"{validate_alias} = {guard_func_name}\n" + ) + guard_calls.extend([ + " try:", + f" await {validate_alias}(api=api, args=args)", + " except PolicyViolationException as e:", + " error_msg = str(e)", + f" # Check if error already contains policy name to avoid duplication", + f" if not error_msg.startswith('[{policy.name}]'):", + f" error_msg = f\"[{policy.name}] {{error_msg}}\"", + " violations.append(error_msg)", + ]) + + def _extract_guard_function_name(self, policy_code: str) -> Optional[str]: + """ + Extract guard function name from policy code. + + Args: + policy_code: Generated policy code + + Returns: + Guard function name or None if not found + """ + for line in policy_code.split('\n'): + line = line.strip() + if line.startswith('async def guard_'): + # Extract function name: "async def guard_xxx(..." -> "guard_xxx" + return line.split('(')[0].replace('async def ', '').strip() + return None + + def _generate_module_content( + self, guard_fn_name: str, guard_blocks: List[str], guard_calls: List[str] + ) -> str: + """ + Generate the complete module content. + + Args: + guard_fn_name: Name for the umbrella guard function + guard_blocks: List of guard code blocks + guard_calls: List of guard call statements + + Returns: + Complete module content as string + """ if not guard_calls: - guard_calls.append(" return None") + guard_calls = [" return None"] else: guard_calls = [ " violations = []", @@ -340,22 +475,45 @@ def _build_tool_guard_module( ) def _module_name_for_tool(self, tool_name: str) -> str: - """Convert a tool name to a valid python module name.""" - normalized = "".join( - ch if ch.isalnum() else "_" for ch in tool_name.lower() - ).strip("_") - if not normalized: - normalized = "tool" + """ + Convert a tool name to a valid python module name. + + Args: + tool_name: Name of the tool + + Returns: + Valid Python module name + """ + normalized = self._normalize_name(tool_name) return f"cuga_toolguard_runtime.generated.guard_{normalized}" def _guard_function_name_for_tool(self, tool_name: str) -> str: - """Convert a tool name to a valid umbrella guard function name.""" + """ + Convert a tool name to a valid umbrella guard function name. + + Args: + tool_name: Name of the tool + + Returns: + Valid Python function name + """ + normalized = self._normalize_name(tool_name) + return f"guard_{normalized}" + + def _normalize_name(self, name: str) -> str: + """ + Normalize a name to be a valid Python identifier. + + Args: + name: Name to normalize + + Returns: + Normalized name safe for use as Python identifier + """ normalized = "".join( - ch if ch.isalnum() else "_" for ch in tool_name.lower() + ch if ch.isalnum() else "_" for ch in name.lower() ).strip("_") - if not normalized: - normalized = "tool" - return f"guard_{normalized}" + return normalized if normalized else "tool" async def guard_tool_call( self, @@ -446,7 +604,4 @@ def get_guards_for_tool(self, tool_name: str) -> List[ToolGuide]: Returns: List of ToolGuide policies with guards for this tool """ - return self.tool_to_guards.get(tool_name, []) - - -# Made with Bob \ No newline at end of file + return self.tool_to_guards.get(tool_name, []) \ No newline at end of file diff --git a/src/cuga/sdk_core/debug_tool_guard_e2e.py b/src/cuga/sdk_core/debug_tool_guard_e2e.py index eafb93c2..d66e2d92 100644 --- a/src/cuga/sdk_core/debug_tool_guard_e2e.py +++ b/src/cuga/sdk_core/debug_tool_guard_e2e.py @@ -176,7 +176,7 @@ async def create_and_process_policies(agent, policy_system): app_name="test_app" ) print(f"✅ Generated guard code ({len(guard_code)} characters)") - print(f"✅ Code:\n{guard_code} ") + #print(f"✅ Code:\n{guard_code} ") # Update policy with guard code await agent.policies.update_tool_guard( From 704f3ca7cfe9e23c21f4d5472bc951c878c83994 Mon Sep 17 00:00:00 2001 From: naamaz Date: Mon, 27 Apr 2026 12:32:51 +0300 Subject: [PATCH 11/24] cleanup --- src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py | 2 +- .../backend/cuga_graph/policy/tool_guard/cuga_llm_adapter.py | 2 +- src/cuga/backend/cuga_graph/policy/tool_guard/manager.py | 1 - src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py | 2 -- src/cuga/sdk_core/debug_tool_guard_e2e.py | 4 +--- 5 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py b/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py index 90e4e9ce..841989f8 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/__init__.py @@ -10,4 +10,4 @@ __all__ = ["ToolGuardManager", "ToolGuardRuntime"] -# Made with Bob + diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_llm_adapter.py b/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_llm_adapter.py index 0b257f5a..ae83d64a 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_llm_adapter.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_llm_adapter.py @@ -197,4 +197,4 @@ async def generate(self, messages: List[Dict]) -> str: "Check that your model is properly configured and has valid credentials." ) from e -# Made with Bob + diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py b/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py index 1a8c21e7..58b5ebb4 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py @@ -351,4 +351,3 @@ def get_available_tools(self) -> List[str]: return [tool.name for tool in self.langchain_tools] -# Made with Bob \ No newline at end of file diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py index 662944b5..444a5248 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py @@ -116,5 +116,3 @@ def clear_cache(self) -> None: self._tools_cache = None logger.debug("Cleared tools cache") - -# Made with Bob \ No newline at end of file diff --git a/src/cuga/sdk_core/debug_tool_guard_e2e.py b/src/cuga/sdk_core/debug_tool_guard_e2e.py index d66e2d92..cec02b61 100644 --- a/src/cuga/sdk_core/debug_tool_guard_e2e.py +++ b/src/cuga/sdk_core/debug_tool_guard_e2e.py @@ -197,8 +197,7 @@ async def create_and_process_policies(agent, policy_system): raise ValueError(f"Failed to retrieve policy {policy_id}") policy = policy_tool_guide["policy"] - if agent.policies._fs_sync: - agent.policies._fs_sync.save_policy_to_file(policy) + print(f"✅ Policy saved successfully") @@ -381,4 +380,3 @@ async def main(): if __name__ == "__main__": asyncio.run(main()) -# Made with Bob \ No newline at end of file From 0c4106d3c81e70aafec0a5b63973d73c348e8937 Mon Sep 17 00:00:00 2001 From: naamaz Date: Mon, 27 Apr 2026 14:37:48 +0300 Subject: [PATCH 12/24] fixes --- src/cuga/backend/cuga_graph/policy/models.py | 8 +++- .../policy/tests/test_filesystem_sync.py | 3 +- .../policy/tool_guard/cuga_llm_adapter.py | 7 ++- .../cuga_graph/policy/tool_guard/manager.py | 10 ++++- .../policy/tool_guard/tool_guard_runtime.py | 36 ++++++++++++--- .../policy/tool_guard/tool_invoker.py | 7 ++- src/cuga/backend/server/main.py | 8 +++- .../registry/registry/api_registry.py | 15 ++++--- .../registry/registry/api_registry_server.py | 45 ++++++++++++++++--- src/cuga/sdk.py | 12 ++++- src/cuga/sdk_core/debug_tool_guard_e2e.py | 5 ++- 11 files changed, 126 insertions(+), 30 deletions(-) diff --git a/src/cuga/backend/cuga_graph/policy/models.py b/src/cuga/backend/cuga_graph/policy/models.py index 1ba47448..596ba061 100644 --- a/src/cuga/backend/cuga_graph/policy/models.py +++ b/src/cuga/backend/cuga_graph/policy/models.py @@ -176,7 +176,13 @@ class ToolGuard(BaseModel): default_factory=list, description="Examples of compliant usage patterns" ) policy_code: str = Field( - default="", description="Python code that validates tool usage compliance" + default="", + description=( + "Python code that validates tool usage compliance. " + "WARNING: This code is executed unsandboxed at runtime. " + "Only trusted administrators with manage access should be allowed to modify policy code. " + "Malicious code can compromise system security." + ) ) diff --git a/src/cuga/backend/cuga_graph/policy/tests/test_filesystem_sync.py b/src/cuga/backend/cuga_graph/policy/tests/test_filesystem_sync.py index be8c223b..211dd45f 100644 --- a/src/cuga/backend/cuga_graph/policy/tests/test_filesystem_sync.py +++ b/src/cuga/backend/cuga_graph/policy/tests/test_filesystem_sync.py @@ -446,7 +446,8 @@ async def test_auto_load_multiple_policy_types(self, temp_cuga_folder): assert "playbook" in policy_types assert "tool_guide" in policy_types - loaded_guide = next(p for p in policies if p["id"] == "guide_1") + loaded_guide = await agent.policies.get("guide_1") + assert loaded_guide is not None assert loaded_guide["policy"].tool_guards is not None assert "test_tool" in loaded_guide["policy"].tool_guards assert loaded_guide["policy"].tool_guards["test_tool"].description == "Guard loaded from filesystem" diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_llm_adapter.py b/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_llm_adapter.py index ae83d64a..10711eba 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_llm_adapter.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_llm_adapter.py @@ -105,7 +105,12 @@ def _convert_to_langchain_messages(self, messages: List[Dict]) -> List[BaseMessa if "content" not in msg: raise ValueError(f"Message at index {i} missing 'content' key: {msg}") - role = msg["role"].lower() + raw_role = msg["role"] + if not isinstance(raw_role, str): + raise ValueError( + f"Message at index {i} has non-string 'role': {type(raw_role).__name__}" + ) + role = raw_role.lower() content = msg["content"] # Convert to appropriate LangChain message type diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py b/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py index 58b5ebb4..c04813f8 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py @@ -121,8 +121,14 @@ def _create_spec_item( @contextmanager def _temp_directory(self) -> Iterator[Path]: - """Context manager for temporary toolguard directory.""" - tmp_dir = Path(self.work_dir) / "tmp" + """Context manager for temporary toolguard directory. + + Creates a unique temporary subdirectory per invocation to avoid + race conditions when generate_examples() or generate_guard_code() + are called concurrently. + """ + import uuid + tmp_dir = Path(self.work_dir) / f"tmp_{uuid.uuid4().hex}" tmp_dir.mkdir(parents=True, exist_ok=True) try: yield tmp_dir diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py index 7fc9d3cf..593d7056 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py @@ -87,6 +87,11 @@ async def initialize(self) -> None: def _reset_state(self) -> None: """Reset internal state for reinitialization.""" + if self._runtime is not None: + try: + self._runtime.__exit__(None, None, None) + except Exception: + logger.exception("Error while exiting previous ToolGuard runtime") self.tool_to_guards = {} self._runtime = None self._runtime_domain = None @@ -411,14 +416,19 @@ def _process_policy_guard( f"{validate_alias} = {guard_func_name}\n" ) + # Sanitize policy name for safe embedding in generated Python code + policy_name_literal = repr(policy.name) + guard_calls.extend([ " try:", f" await {validate_alias}(api=api, args=args)", " except PolicyViolationException as e:", " error_msg = str(e)", - f" # Check if error already contains policy name to avoid duplication", - f" if not error_msg.startswith('[{policy.name}]'):", - f" error_msg = f\"[{policy.name}] {{error_msg}}\"", + " # Check if error already contains policy name to avoid duplication", + f" _policy_name = {policy_name_literal}", + " _prefix = f\"[{_policy_name}]\"", + " if not error_msg.startswith(_prefix):", + " error_msg = f\"{_prefix} {error_msg}\"", " violations.append(error_msg)", ]) @@ -575,7 +585,12 @@ async def guard_tool_call( f"Error executing umbrella guard for tool '{function_name}': {e}", exc_info=True ) - return None + # Fail closed: treat internal guard errors as a violation so a buggy + # or malformed guard cannot silently bypass policy enforcement. + return ( + f"Internal guard error for '{function_name}': {e}. " + "Tool call blocked as a safety precaution." + ) logger.debug(f"Tool call '{function_name}' passed all guards") return None @@ -604,4 +619,15 @@ def get_guards_for_tool(self, tool_name: str) -> List[ToolGuide]: Returns: List of ToolGuide policies with guards for this tool """ - return self.tool_to_guards.get(tool_name, []) \ No newline at end of file + return self.tool_to_guards.get(tool_name, []) + + async def shutdown(self) -> None: + """Release in-memory ToolGuard runtime resources.""" + if self._runtime is not None: + try: + self._runtime.__exit__(None, None, None) + except Exception: + logger.exception("Error while shutting down ToolGuard runtime") + self._runtime = None + self._initialized = False + logger.debug("ToolGuardRuntime shutdown complete") \ No newline at end of file diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py index 444a5248..1fe7f5cc 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py @@ -76,7 +76,12 @@ async def invoke( RuntimeError: If tool invocation fails """ try: - logger.debug(f"Invoking tool '{toolname}' with args: {arguments}") + # Redact sensitive arguments before logging + arg_summary = { + k: f"<{type(v).__name__}>" if v is not None else None + for k, v in (arguments.items() if isinstance(arguments, dict) else {}) + } + logger.debug(f"Invoking tool '{toolname}' with arg keys: {list(arg_summary.keys())}") # Get the tool from the provider tools = await self._get_tools() diff --git a/src/cuga/backend/server/main.py b/src/cuga/backend/server/main.py index e736debd..7c8e0697 100644 --- a/src/cuga/backend/server/main.py +++ b/src/cuga/backend/server/main.py @@ -2395,9 +2395,13 @@ async def get_policies_config( @app.post("/api/config/policies") async def save_policies_config( request: Request, - current_user: Optional[UserInfo] = Depends(require_auth), + current_user: Optional[UserInfo] = Depends(require_manage_access), ): - """Endpoint to save policies configuration. Use draft collection when X-Use-Draft header is set.""" + """Endpoint to save policies configuration. Use draft collection when X-Use-Draft header is set. + + Security: Requires manage access role. Policy code is executed unsandboxed at runtime, + so only trusted administrators should be allowed to modify policies. + """ if not settings.policy.enabled: return JSONResponse( {"status": "error", "message": "Policy system is disabled in settings"}, diff --git a/src/cuga/backend/tools_env/registry/registry/api_registry.py b/src/cuga/backend/tools_env/registry/registry/api_registry.py index 9018a69e..71627e1b 100644 --- a/src/cuga/backend/tools_env/registry/registry/api_registry.py +++ b/src/cuga/backend/tools_env/registry/registry/api_registry.py @@ -106,7 +106,7 @@ async def _initialize_tool_guard_runtime(self): except Exception as e: logger.warning(f"Failed to initialize ToolGuardRuntime: {e}") - logger.debug(f"ToolGuardRuntime initialization error details:", exc_info=True) + logger.debug("ToolGuardRuntime initialization error details:", exc_info=True) self.tool_guard_runtime = None async def show_applications(self) -> List[AppDefinition]: @@ -228,13 +228,16 @@ async def call_function( ) -> Dict[str, Any]: """Calls a function via the mcp_client.""" + # Normalize argument shape so guards and the tool see the same payload + unwrapped_args = arguments.get('params', arguments) if isinstance(arguments, dict) and 'params' in arguments else arguments + # Validate tool call against ToolGuard policies if self.tool_guard_runtime and self.tool_guard_runtime.is_initialized: try: error_message = await self.tool_guard_runtime.guard_tool_call( app_name=app_name, function_name=function_name, - arguments=arguments + arguments=unwrapped_args if isinstance(unwrapped_args, dict) else {} ) if error_message: @@ -257,8 +260,7 @@ async def call_function( logger.error(f"Error executing tool guard for '{function_name}': {e}", exc_info=True) if app_name == "web" and function_name == "search_web" and self._is_web_search_enabled(): - args = arguments.get('params', arguments) if isinstance(arguments, dict) else arguments - query = args.get('query') if isinstance(args, dict) else str(args) + query = unwrapped_args.get('query') if isinstance(unwrapped_args, dict) else str(unwrapped_args) if not query: return { "status": "exception", @@ -361,13 +363,12 @@ async def call_function( f"ApiRegistry: call_function(function_name='{function_name}', arguments={arguments}, headers={headers}) called." ) try: - # Delegate the call to the client - args = arguments['params'] if 'params' in arguments else arguments + # Delegate the call to the client (use already-unwrapped args) if self.auth_manager: headers["_tokens"] = json.dumps(self.auth_manager.get_stored_tokens()) result = await self.mcp_client.call_tool( tool_name=function_name, - args=args, + args=unwrapped_args, headers=headers, ) logger.debug("Response:", result) diff --git a/src/cuga/backend/tools_env/registry/registry/api_registry_server.py b/src/cuga/backend/tools_env/registry/registry/api_registry_server.py index caa8be2a..55cac7e3 100644 --- a/src/cuga/backend/tools_env/registry/registry/api_registry_server.py +++ b/src/cuga/backend/tools_env/registry/registry/api_registry_server.py @@ -144,13 +144,18 @@ async def _get_or_create_registry( # Get policy storage if policy system is enabled policy_storage = None if settings.policy.enabled: + from cuga.backend.cuga_graph.policy.storage import PolicyStorage + policy_storage = PolicyStorage() try: - from cuga.backend.cuga_graph.policy.storage import PolicyStorage - policy_storage = PolicyStorage() await policy_storage.connect() logger.info("✅ Connected policy storage for ToolGuardRuntime") except Exception as e: - logger.warning(f"Failed to connect policy storage: {e}") + logger.error(f"Failed to connect policy storage: {e}") + # Fail closed: if policy enforcement is enabled but storage fails, + # don't boot the registry unguarded + raise RuntimeError( + f"Policy system is enabled but PolicyStorage.connect() failed: {e}" + ) from e reg = ApiRegistry(client=manager, policy_storage=policy_storage) await reg.start_servers() @@ -199,6 +204,12 @@ async def lifespan(app: FastAPI): for agent_id, (mgr, reg) in agent_registries.items(): logger.info(f"Cleaning up registry for agent: {agent_id}") await mgr.shutdown() + # Disconnect policy storage if present + if reg.policy_storage is not None: + try: + await reg.policy_storage.disconnect() + except Exception as e: + logger.warning(f"Error disconnecting policy storage for agent {agent_id}: {e}") if not database_mode and 'mcp_manager' in globals(): await mcp_manager.shutdown() @@ -497,8 +508,13 @@ async def reload_config( logger.info(f"Reloading from database for agent: {agent_id}") # Clear cache for this agent if agent_id in agent_registries: - old_mgr, _ = agent_registries.pop(agent_id) + old_mgr, old_reg = agent_registries.pop(agent_id) await old_mgr.shutdown() + if old_reg.policy_storage is not None: + try: + await old_reg.policy_storage.disconnect() + except Exception as e: + logger.warning(f"Error disconnecting policy storage: {e}") # Recreate registry for this agent with retry on empty await _get_or_create_registry(agent_id, retry_on_empty=True) # If this is the default agent, update global registry @@ -523,8 +539,13 @@ async def reload_config( # Reload all agents (clear cache) logger.info("Reloading all agents from database") agent_ids = list(agent_registries.keys()) - for old_mgr, _ in agent_registries.values(): + for old_mgr, old_reg in agent_registries.values(): await old_mgr.shutdown() + if old_reg.policy_storage is not None: + try: + await old_reg.policy_storage.disconnect() + except Exception as e: + logger.warning(f"Error disconnecting policy storage: {e}") agent_registries.clear() # Recreate default agent mcp_manager, registry = await _get_or_create_registry(default_agent_id) @@ -580,16 +601,26 @@ async def clear_agent_cache( try: if agent_id: if agent_id in agent_registries: - mgr, _ = agent_registries.pop(agent_id) + mgr, reg = agent_registries.pop(agent_id) await mgr.shutdown() + if reg.policy_storage is not None: + try: + await reg.policy_storage.disconnect() + except Exception as e: + logger.warning(f"Error disconnecting policy storage: {e}") logger.info(f"Cleared cache for agent: {agent_id}") return {"status": "ok", "message": f"Cache cleared for agent: {agent_id}"} else: return {"status": "ok", "message": f"No cache found for agent: {agent_id}"} else: count = len(agent_registries) - for mgr, _ in agent_registries.values(): + for mgr, reg in agent_registries.values(): await mgr.shutdown() + if reg.policy_storage is not None: + try: + await reg.policy_storage.disconnect() + except Exception as e: + logger.warning(f"Error disconnecting policy storage: {e}") agent_registries.clear() logger.info(f"Cleared cache for {count} agents") return {"status": "ok", "message": f"Cache cleared for {count} agents"} diff --git a/src/cuga/sdk.py b/src/cuga/sdk.py index f0b8176a..5606eba8 100644 --- a/src/cuga/sdk.py +++ b/src/cuga/sdk.py @@ -573,8 +573,10 @@ async def update_tool_guard( f"Policy '{policy_id}' is not a ToolGuide policy (type: {type(existing_policy).__name__})" ) - # Convert tool_guards dict to ToolGuard objects - tool_guards_obj = {} + # Merge with existing tool_guards to preserve guards for other tools + tool_guards_obj = dict(existing_policy.tool_guards or {}) + + # Convert and update only the incoming tool_guards for tool_name, guard_data in tool_guards.items(): tool_guards_obj[tool_name] = ToolGuard( description=guard_data.get("description", ""), @@ -1213,6 +1215,12 @@ async def sync_from_filesystem(self) -> Dict[str, Any]: } except Exception as e: logger.error(f"Failed to sync from filesystem: {e}") + return { + "loaded": 0, + "removed": 0, + "errors": [str(e)], + "files": [], + } async def generate_tool_guard_examples( self, diff --git a/src/cuga/sdk_core/debug_tool_guard_e2e.py b/src/cuga/sdk_core/debug_tool_guard_e2e.py index cec02b61..091475d2 100644 --- a/src/cuga/sdk_core/debug_tool_guard_e2e.py +++ b/src/cuga/sdk_core/debug_tool_guard_e2e.py @@ -11,9 +11,11 @@ Configuration: - Set DELETE_ALL_POLICIES_AT_START = True to delete all existing policies before running - Set DELETE_ALL_POLICIES_AT_START = False to preserve existing policies (default) +- Set environment variable CUGA_E2E_ALLOW_DESTRUCTIVE=true to enable destructive cleanup """ import asyncio +import os from pathlib import Path from langchain_core.tools import tool @@ -25,7 +27,8 @@ # ============================================================================ # CONFIGURATION # ============================================================================ -DELETE_ALL_POLICIES_AT_START = True +# Default to False for safety - require explicit opt-in for destructive operations +DELETE_ALL_POLICIES_AT_START = os.environ.get("CUGA_E2E_ALLOW_DESTRUCTIVE", "").lower() in ("true", "1", "yes") # ============================================================================ # Define policies to create From 07778facfee9b48cf299855bb7840e4b927e4e4f Mon Sep 17 00:00:00 2001 From: naamaz Date: Tue, 28 Apr 2026 12:41:11 +0300 Subject: [PATCH 13/24] fixes --- .../policy/tool_guard/tool_guard_runtime.py | 4 ++-- .../tools_env/registry/registry/api_registry.py | 14 +++++++++++--- .../registry/registry/api_registry_server.py | 14 ++++++++++++-- src/cuga/sdk.py | 1 - 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py index 593d7056..297d38fc 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Tuple - +from types import SimpleNamespace from loguru import logger from toolguard.runtime.data_types import ( @@ -568,7 +568,7 @@ async def guard_tool_call( ) try: - args_obj = type("ToolGuardArgs", (), arguments)() + args_obj = SimpleNamespace(**arguments) await self._runtime.guard_toolcall( tool_name=function_name, args=arguments | {"args": args_obj}, diff --git a/src/cuga/backend/tools_env/registry/registry/api_registry.py b/src/cuga/backend/tools_env/registry/registry/api_registry.py index 71627e1b..a8762d55 100644 --- a/src/cuga/backend/tools_env/registry/registry/api_registry.py +++ b/src/cuga/backend/tools_env/registry/registry/api_registry.py @@ -255,9 +255,17 @@ async def call_function( else: logger.debug(f"✅ Tool guard validation passed for '{function_name}'") except Exception as e: - # Log guard execution error but don't block the tool call - # This ensures a broken guard doesn't break the entire system - logger.error(f"Error executing tool guard for '{function_name}': {e}", exc_info=True) + # Fail-closed: treat exceptions from guard_tool_call as policy violations + # to honor ToolGuardRuntime's fail-closed contract + error_msg = f"Exception during tool guard execution for '{function_name}': {e}" + logger.error(error_msg, exc_info=True) + return { + "status": "exception", + "status_code": 403, + "message": f"Tool guard policy violation: {error_msg}", + "error_type": "ToolGuardViolation", + "function_name": function_name, + } if app_name == "web" and function_name == "search_web" and self._is_web_search_enabled(): query = unwrapped_args.get('query') if isinstance(unwrapped_args, dict) else str(unwrapped_args) diff --git a/src/cuga/backend/tools_env/registry/registry/api_registry_server.py b/src/cuga/backend/tools_env/registry/registry/api_registry_server.py index 55cac7e3..89911497 100644 --- a/src/cuga/backend/tools_env/registry/registry/api_registry_server.py +++ b/src/cuga/backend/tools_env/registry/registry/api_registry_server.py @@ -193,7 +193,12 @@ async def lifespan(app: FastAPI): await policy_storage.connect() logger.info("✅ Connected policy storage for ToolGuardRuntime") except Exception as e: - logger.warning(f"Failed to connect policy storage: {e}") + logger.error(f"Failed to connect policy storage: {e}") + # Fail closed: if policy enforcement is enabled but storage fails, + # don't boot the registry unguarded + raise RuntimeError( + f"Policy system is enabled but PolicyStorage.connect() failed: {e}" + ) from e registry = ApiRegistry(client=mcp_manager, policy_storage=policy_storage) await registry.start_servers() @@ -571,7 +576,12 @@ async def reload_config( await policy_storage.connect() logger.info("✅ Connected policy storage for ToolGuardRuntime") except Exception as e: - logger.warning(f"Failed to connect policy storage: {e}") + logger.error(f"Failed to connect policy storage: {e}") + # Fail closed: if policy enforcement is enabled but storage fails, + # don't reload the registry unguarded + raise RuntimeError( + f"Policy system is enabled but PolicyStorage.connect() failed: {e}" + ) from e new_registry = ApiRegistry(client=new_manager, policy_storage=policy_storage) await new_registry.start_servers() diff --git a/src/cuga/sdk.py b/src/cuga/sdk.py index 65f1e1d7..ec5735a1 100644 --- a/src/cuga/sdk.py +++ b/src/cuga/sdk.py @@ -1408,7 +1408,6 @@ async def generate_tool_guard_code( ) return guard_code - return violating_examples, compliance_examples class CugaAgent: From abebaaea32176b80803726bdb2bb590863fed307 Mon Sep 17 00:00:00 2001 From: naamaz Date: Tue, 28 Apr 2026 14:33:01 +0300 Subject: [PATCH 14/24] fixes --- .../policy/tool_guard/tool_guard_runtime.py | 16 ++++-- .../registry/registry/api_registry.py | 15 ++++-- .../registry/registry/api_registry_server.py | 52 +++++++++++++++++-- src/cuga/sdk.py | 23 ++++++-- 4 files changed, 88 insertions(+), 18 deletions(-) diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py index 297d38fc..6e1ca6a3 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py @@ -512,18 +512,28 @@ def _guard_function_name_for_tool(self, tool_name: str) -> str: def _normalize_name(self, name: str) -> str: """ - Normalize a name to be a valid Python identifier. + Normalize a name to be a valid Python identifier with disambiguation. Args: name: Name to normalize Returns: - Normalized name safe for use as Python identifier + Normalized name safe for use as Python identifier with hash suffix """ + import hashlib + + # Create readable normalized portion normalized = "".join( ch if ch.isalnum() else "_" for ch in name.lower() ).strip("_") - return normalized if normalized else "tool" + + # Use "tool" as base if normalization results in empty string + base = normalized if normalized else "tool" + + # Add short hash suffix for disambiguation + name_hash = hashlib.sha256(name.encode()).hexdigest()[:8] + + return f"{base}_{name_hash}" async def guard_tool_call( self, diff --git a/src/cuga/backend/tools_env/registry/registry/api_registry.py b/src/cuga/backend/tools_env/registry/registry/api_registry.py index a8762d55..af8093e4 100644 --- a/src/cuga/backend/tools_env/registry/registry/api_registry.py +++ b/src/cuga/backend/tools_env/registry/registry/api_registry.py @@ -105,9 +105,12 @@ async def _initialize_tool_guard_runtime(self): logger.debug(f" Guarded tools: {', '.join(guarded_tools)}") except Exception as e: - logger.warning(f"Failed to initialize ToolGuardRuntime: {e}") - logger.debug("ToolGuardRuntime initialization error details:", exc_info=True) - self.tool_guard_runtime = None + logger.error(f"Failed to initialize ToolGuardRuntime: {e}", exc_info=True) + # Fail closed: if policy enforcement is enabled but ToolGuardRuntime fails, + # don't allow the service to start without policy validation + raise RuntimeError( + f"ToolGuardRuntime failed to initialize. Policy enforcement cannot be bypassed. Error: {e}" + ) from e async def show_applications(self) -> List[AppDefinition]: """Lists application names and their descriptions.""" @@ -228,8 +231,10 @@ async def call_function( ) -> Dict[str, Any]: """Calls a function via the mcp_client.""" - # Normalize argument shape so guards and the tool see the same payload - unwrapped_args = arguments.get('params', arguments) if isinstance(arguments, dict) and 'params' in arguments else arguments + # Use arguments as-is - do not unwrap 'params' unconditionally + # Transport-layer wrappers should be normalized at the request boundary, + # not here where it affects both guards and tools + unwrapped_args = arguments # Validate tool call against ToolGuard policies if self.tool_guard_runtime and self.tool_guard_runtime.is_initialized: diff --git a/src/cuga/backend/tools_env/registry/registry/api_registry_server.py b/src/cuga/backend/tools_env/registry/registry/api_registry_server.py index 89911497..009fb18d 100644 --- a/src/cuga/backend/tools_env/registry/registry/api_registry_server.py +++ b/src/cuga/backend/tools_env/registry/registry/api_registry_server.py @@ -158,7 +158,16 @@ async def _get_or_create_registry( ) from e reg = ApiRegistry(client=manager, policy_storage=policy_storage) - await reg.start_servers() + try: + await reg.start_servers() + except Exception: + # Clean up policy_storage if start_servers fails + if policy_storage is not None: + try: + await policy_storage.disconnect() + except Exception as cleanup_error: + logger.warning(f"Error disconnecting policy storage during cleanup: {cleanup_error}") + raise agent_registries[agent_id] = (manager, reg) return manager, reg @@ -201,7 +210,16 @@ async def lifespan(app: FastAPI): ) from e registry = ApiRegistry(client=mcp_manager, policy_storage=policy_storage) - await registry.start_servers() + try: + await registry.start_servers() + except Exception: + # Clean up policy_storage if start_servers fails + if policy_storage is not None: + try: + await policy_storage.disconnect() + except Exception as cleanup_error: + logger.warning(f"Error disconnecting policy storage during cleanup: {cleanup_error}") + raise yield @@ -215,8 +233,15 @@ async def lifespan(app: FastAPI): await reg.policy_storage.disconnect() except Exception as e: logger.warning(f"Error disconnecting policy storage for agent {agent_id}: {e}") - if not database_mode and 'mcp_manager' in globals(): - await mcp_manager.shutdown() + if not database_mode: + # In YAML mode, also clean up the global registry's policy storage + if 'registry' in globals() and registry is not None and registry.policy_storage is not None: + try: + await registry.policy_storage.disconnect() + except Exception as e: + logger.warning(f"Error disconnecting global registry policy storage: {e}") + if 'mcp_manager' in globals(): + await mcp_manager.shutdown() # --- FastAPI Server Setup --- @@ -584,7 +609,24 @@ async def reload_config( ) from e new_registry = ApiRegistry(client=new_manager, policy_storage=policy_storage) - await new_registry.start_servers() + try: + await new_registry.start_servers() + except Exception: + # Clean up new policy_storage if start_servers fails + if policy_storage is not None: + try: + await policy_storage.disconnect() + except Exception as cleanup_error: + logger.warning(f"Error disconnecting policy storage during cleanup: {cleanup_error}") + raise + + # Clean up old registry's policy storage before replacing + if 'registry' in globals() and registry is not None and registry.policy_storage is not None: + try: + await registry.policy_storage.disconnect() + except Exception as e: + logger.warning(f"Error disconnecting old registry policy storage: {e}") + if 'mcp_manager' in globals(): await mcp_manager.shutdown() mcp_manager = new_manager diff --git a/src/cuga/sdk.py b/src/cuga/sdk.py index ec5735a1..c4a01a43 100644 --- a/src/cuga/sdk.py +++ b/src/cuga/sdk.py @@ -588,6 +588,16 @@ async def update_tool_guard( # Merge with existing tool_guards to preserve guards for other tools tool_guards_obj = dict(existing_policy.tool_guards or {}) + # Validate that tool_guards keys are in target_tools + target_tools_set = set(existing_policy.target_tools or []) + invalid_tools = set(tool_guards.keys()) - target_tools_set + + if invalid_tools: + raise ValueError( + f"Invalid tool names in tool_guards: {', '.join(sorted(invalid_tools))}. " + f"Must be one of: {', '.join(sorted(target_tools_set))}" + ) + # Convert and update only the incoming tool_guards for tool_name, guard_data in tool_guards.items(): tool_guards_obj[tool_name] = ToolGuard( @@ -1206,11 +1216,11 @@ async def sync_from_filesystem(self) -> Dict[str, Any]: policy_system = await self._ensure_policy_system() if policy_system is None: logger.warning("Policy system is disabled - skipping sync") - return {"loaded": 0, "removed": 0, "errors": ["Policy system is disabled"]} + return {"loaded": 0, "removed": 0, "errors": ["Policy system is disabled"], "files": []} if not self._fs_sync: logger.warning("Filesystem sync not initialized - skipping") - return {"loaded": 0, "removed": 0, "errors": ["Filesystem sync not initialized"]} + return {"loaded": 0, "removed": 0, "errors": ["Filesystem sync not initialized"], "files": []} try: # Load policies from filesystem @@ -1358,9 +1368,12 @@ async def generate_tool_guard_code( # Update policy with examples await agent.policies.update_tool_guard( policy_id=policy_id, - target_tool="delete_file", - violating_examples=violating, - compliance_examples=compliance + tool_guards={ + "delete_file": { + "violating_examples": violating, + "compliance_examples": compliance + } + } ) # Generate guard code From aaa7b21f9a826075c5a5fc1f93f1fbdba409b772 Mon Sep 17 00:00:00 2001 From: naamaz Date: Tue, 28 Apr 2026 15:04:35 +0300 Subject: [PATCH 15/24] fixes --- .../policy/tool_guard/tool_guard_runtime.py | 213 ++++++++++++++---- .../registry/registry/api_registry_server.py | 18 ++ 2 files changed, 185 insertions(+), 46 deletions(-) diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py index 6e1ca6a3..76756291 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py @@ -50,8 +50,9 @@ def __init__(self, tool_provider, policy_storage: PolicyStorage) -> None: self.policy_storage = policy_storage self.invoker = ToolGuardInvoker(tool_provider) self.tool_to_guards: Dict[str, List[ToolGuide]] = {} - self._runtime = None - self._runtime_domain: Optional[RuntimeDomain] = None + # Per-app runtime mapping to avoid cross-app collisions + self._runtimes_by_app: Dict[str, Any] = {} + self._runtime_domains_by_app: Dict[str, RuntimeDomain] = {} self._initialized = False logger.debug("Created ToolGuardRuntime instance") @@ -63,8 +64,7 @@ async def initialize(self) -> None: 1. Fetches all ToolGuide policies from storage 2. Filters for policies that have tool_guards with policy_code 3. Builds the tool_to_guards mapping - 4. Creates umbrella guard functions per tool - 5. Builds an in-memory ToolGuard runtime + 4. Per-app runtimes will be lazily loaded on first use """ logger.info("Initializing ToolGuardRuntime...") self._reset_state() @@ -78,23 +78,21 @@ async def initialize(self) -> None: tool_guide_policies = [p for p in policies if isinstance(p, ToolGuide)] self._build_tool_to_guards_mapping(tool_guide_policies) - if self.tool_to_guards: - self._runtime_domain = self._load_runtime_domain() - self._runtime = self._build_runtime() - self._initialized = True self._log_initialization_summary() def _reset_state(self) -> None: """Reset internal state for reinitialization.""" - if self._runtime is not None: - try: - self._runtime.__exit__(None, None, None) - except Exception: - logger.exception("Error while exiting previous ToolGuard runtime") + # Clean up all per-app runtimes + for app_name, runtime in self._runtimes_by_app.items(): + if runtime is not None: + try: + runtime.__exit__(None, None, None) + except Exception: + logger.exception(f"Error while exiting ToolGuard runtime for app '{app_name}'") self.tool_to_guards = {} - self._runtime = None - self._runtime_domain = None + self._runtimes_by_app = {} + self._runtime_domains_by_app = {} def _build_tool_to_guards_mapping(self, policies: Sequence[ToolGuide]) -> None: """ @@ -128,12 +126,23 @@ def _register_policy_guards(self, policy: ToolGuide) -> None: ) continue + # Validate that policy_code contains at least one async def guard_* function + guard_func_name = self._extract_guard_function_name(tool_guard.policy_code) + if not guard_func_name: + logger.error( + f"Tool guard for '{tool_name}' in policy '{policy.name}' " + f"has policy_code but no valid 'async def guard_*' function found. " + f"Skipping registration to prevent marking tool as guarded without enforcement." + ) + continue + if tool_name not in self.tool_to_guards: self.tool_to_guards[tool_name] = [] self.tool_to_guards[tool_name].append(policy) logger.debug( - f"Registered guard for tool '{tool_name}' from policy '{policy.name}'" + f"Registered guard for tool '{tool_name}' from policy '{policy.name}' " + f"with guard function '{guard_func_name}'" ) def _log_initialization_summary(self) -> None: @@ -148,19 +157,42 @@ def _log_initialization_summary(self) -> None: f"({', '.join(g.name for g in guards)})" ) - def _build_runtime(self): - """Build an in-memory ToolGuard runtime from registered guard policies.""" - if self._runtime_domain is None: - raise RuntimeError("ToolGuard runtime domain is not loaded") + def _build_runtime(self, app_name: str): + """ + Build an in-memory ToolGuard runtime from registered guard policies for a specific app. + + Args: + app_name: Name of the application to build runtime for + + Returns: + Runtime instance for the specified app + """ + runtime_domain = self._runtime_domains_by_app.get(app_name) + if runtime_domain is None: + raise RuntimeError(f"ToolGuard runtime domain not loaded for app '{app_name}'") file_twins: List[FileTwin] = [ - self._runtime_domain.app_types, - self._runtime_domain.app_api, - self._runtime_domain.app_api_impl, + runtime_domain.app_types, + runtime_domain.app_api, + runtime_domain.app_api_impl, ] tools: Dict[str, ToolGuardCodeResult] = {} - for tool_name, guards in self.tool_to_guards.items(): + for tool_name, all_guards in self.tool_to_guards.items(): + # Filter guards to only those applicable to this app + guards = [ + guard for guard in all_guards + if guard.target_apps is None or not guard.target_apps or app_name in guard.target_apps + ] + + # Skip this tool if no guards apply to this app + if not guards: + logger.debug( + f"Skipping tool '{tool_name}' for app '{app_name}' - " + f"no applicable guards (out of {len(all_guards)} total)" + ) + continue + module_name = self._module_name_for_tool(tool_name) guard_fn_name = self._guard_function_name_for_tool(tool_name) guard_module_path = Path(*module_name.split(".")).with_suffix(".py") @@ -196,7 +228,7 @@ def _build_runtime(self): result = ToolGuardsCodeGenerationResult( out_dir=Path("."), - domain=self._runtime_domain, + domain=runtime_domain, tools=tools, ) @@ -204,12 +236,15 @@ def _build_runtime(self): runtime.__enter__() return runtime - def _load_runtime_domain(self) -> RuntimeDomain: + def _load_runtime_domain(self, app_name: str) -> RuntimeDomain: """ - Load RuntimeDomain files saved by ToolGuardManager. + Load RuntimeDomain files saved by ToolGuardManager for a specific app. + + Args: + app_name: Name of the application to load domain for Returns: - RuntimeDomain with loaded domain files + RuntimeDomain with loaded domain files for the specified app Raises: RuntimeError: If domain directory or files are not found @@ -217,12 +252,18 @@ def _load_runtime_domain(self) -> RuntimeDomain: domain_dir = Path.cwd() / ".cuga" / "toolguard" / "domain" self._validate_domain_directory(domain_dir) - app_dirs = self._get_sorted_app_directories(domain_dir) - selected_domain = self._find_complete_domain(domain_dir, app_dirs) + # Look for the specific app's domain + app_dir = domain_dir / app_name + if not app_dir.exists(): + raise RuntimeError( + f"ToolGuard domain directory not found for app '{app_name}': {app_dir}" + ) + + selected_domain = self._find_complete_domain_for_app(domain_dir, app_name) if selected_domain is None: raise RuntimeError( - f"No complete ToolGuard domain found under {domain_dir}" + f"No complete ToolGuard domain found for app '{app_name}' under {domain_dir}" ) return self._create_runtime_domain(domain_dir, selected_domain) @@ -296,6 +337,33 @@ def _find_complete_domain( return None + def _find_complete_domain_for_app( + self, domain_dir: Path, app_name: str + ) -> Optional[Tuple[str, Path, Path, Path]]: + """ + Find complete domain files for a specific app. + + Args: + domain_dir: Path to domain directory + app_name: Name of the app to find domain for + + Returns: + Tuple of (app_name, types_path, api_path, impl_path) or None + """ + app_types_rel = Path(app_name) / f"{app_name}_types.py" + app_api_rel = Path(app_name) / f"i_{app_name}.py" + app_api_impl_rel = Path(app_name) / f"{app_name}_impl.py" + + candidate_paths = [ + domain_dir / app_types_rel, + domain_dir / app_api_rel, + domain_dir / app_api_impl_rel, + ] + if all(path.exists() for path in candidate_paths): + return (app_name, app_types_rel, app_api_rel, app_api_impl_rel) + + return None + def _create_runtime_domain( self, domain_dir: Path, selected_domain: Tuple[str, Path, Path, Path] ) -> RuntimeDomain: @@ -535,6 +603,41 @@ def _normalize_name(self, name: str) -> str: return f"{base}_{name_hash}" + async def _get_or_create_runtime_for_app(self, app_name: str): + """ + Get or lazily create a runtime for the specified app. + + Args: + app_name: Name of the application + + Returns: + Runtime instance for the app, or None if it cannot be created + """ + # Return cached runtime if available + if app_name in self._runtimes_by_app: + return self._runtimes_by_app[app_name] + + # Try to load and build runtime for this app + try: + logger.info(f"Loading runtime domain for app '{app_name}'...") + runtime_domain = self._load_runtime_domain(app_name) + self._runtime_domains_by_app[app_name] = runtime_domain + + logger.info(f"Building runtime for app '{app_name}'...") + runtime = self._build_runtime(app_name) + self._runtimes_by_app[app_name] = runtime + + logger.info(f"✅ Runtime initialized for app '{app_name}'") + return runtime + except Exception as e: + logger.error( + f"Failed to initialize runtime for app '{app_name}': {e}", + exc_info=True + ) + # Cache None to avoid repeated failed attempts + self._runtimes_by_app[app_name] = None + return None + async def guard_tool_call( self, app_name: str, @@ -564,22 +667,37 @@ async def guard_tool_call( logger.debug(f"No guards registered for tool '{function_name}'") return None - if self._runtime is None: + # Filter guards to only those applicable to this app + all_guards = self.tool_to_guards[function_name] + guards = [ + guard for guard in all_guards + if guard.target_apps is None or not guard.target_apps or app_name in guard.target_apps + ] + + if not guards: + logger.debug( + f"No guards applicable for tool '{function_name}' on app '{app_name}' " + f"(found {len(all_guards)} guard(s) but none match this app)" + ) + return None + + # Get or create app-specific runtime + runtime = await self._get_or_create_runtime_for_app(app_name) + if runtime is None: logger.warning( - f"ToolGuard runtime unavailable for guarded tool '{function_name}', " + f"ToolGuard runtime unavailable for app '{app_name}' and tool '{function_name}', " "skipping validation" ) return None - guards = self.tool_to_guards[function_name] logger.debug( - f"Validating tool call '{function_name}' against " - f"{len(guards)} guard(s) using umbrella runtime" + f"Validating tool call '{function_name}' for app '{app_name}' against " + f"{len(guards)} applicable guard(s) (out of {len(all_guards)} total) using umbrella runtime" ) try: args_obj = SimpleNamespace(**arguments) - await self._runtime.guard_toolcall( + await runtime.guard_toolcall( tool_name=function_name, args=arguments | {"args": args_obj}, delegate=self.invoker, @@ -587,12 +705,12 @@ async def guard_tool_call( except PolicyViolationException as e: error = str(e) logger.warning( - f"Tool guard blocked call to '{function_name}': {error}" + f"Tool guard blocked call to '{function_name}' for app '{app_name}': {error}" ) return error except Exception as e: logger.error( - f"Error executing umbrella guard for tool '{function_name}': {e}", + f"Error executing umbrella guard for tool '{function_name}' in app '{app_name}': {e}", exc_info=True ) # Fail closed: treat internal guard errors as a violation so a buggy @@ -602,7 +720,7 @@ async def guard_tool_call( "Tool call blocked as a safety precaution." ) - logger.debug(f"Tool call '{function_name}' passed all guards") + logger.debug(f"Tool call '{function_name}' for app '{app_name}' passed all guards") return None @property @@ -633,11 +751,14 @@ def get_guards_for_tool(self, tool_name: str) -> List[ToolGuide]: async def shutdown(self) -> None: """Release in-memory ToolGuard runtime resources.""" - if self._runtime is not None: - try: - self._runtime.__exit__(None, None, None) - except Exception: - logger.exception("Error while shutting down ToolGuard runtime") - self._runtime = None + # Clean up all per-app runtimes + for app_name, runtime in self._runtimes_by_app.items(): + if runtime is not None: + try: + runtime.__exit__(None, None, None) + except Exception: + logger.exception(f"Error while shutting down ToolGuard runtime for app '{app_name}'") + self._runtimes_by_app = {} + self._runtime_domains_by_app = {} self._initialized = False logger.debug("ToolGuardRuntime shutdown complete") \ No newline at end of file diff --git a/src/cuga/backend/tools_env/registry/registry/api_registry_server.py b/src/cuga/backend/tools_env/registry/registry/api_registry_server.py index 009fb18d..2a754389 100644 --- a/src/cuga/backend/tools_env/registry/registry/api_registry_server.py +++ b/src/cuga/backend/tools_env/registry/registry/api_registry_server.py @@ -161,6 +161,12 @@ async def _get_or_create_registry( try: await reg.start_servers() except Exception: + # Clean up manager if start_servers fails + try: + await manager.shutdown() + except Exception as cleanup_error: + logger.warning(f"Error shutting down manager during cleanup: {cleanup_error}") + # Clean up policy_storage if start_servers fails if policy_storage is not None: try: @@ -213,6 +219,12 @@ async def lifespan(app: FastAPI): try: await registry.start_servers() except Exception: + # Clean up mcp_manager if start_servers fails + try: + await mcp_manager.shutdown() + except Exception as cleanup_error: + logger.warning(f"Error shutting down mcp_manager during cleanup: {cleanup_error}") + # Clean up policy_storage if start_servers fails if policy_storage is not None: try: @@ -612,6 +624,12 @@ async def reload_config( try: await new_registry.start_servers() except Exception: + # Clean up new_manager if start_servers fails + try: + await new_manager.shutdown() + except Exception as cleanup_error: + logger.warning(f"Error shutting down new_manager during cleanup: {cleanup_error}") + # Clean up new policy_storage if start_servers fails if policy_storage is not None: try: From 27ed4518baab2cc737334ee69170b2c6adbbb6c0 Mon Sep 17 00:00:00 2001 From: DAVID BOAZ Date: Thu, 30 Apr 2026 11:02:19 +0300 Subject: [PATCH 16/24] cuga-toolguard dev --- pyproject.toml | 2 +- .../policy/tool_guard/cuga_llm_adapter.py | 205 ------------------ .../cuga_graph/policy/tool_guard/manager.py | 40 ++-- .../policy/tool_guard/tool_invoker.py | 2 +- 4 files changed, 20 insertions(+), 229 deletions(-) delete mode 100644 src/cuga/backend/cuga_graph/policy/tool_guard/cuga_llm_adapter.py diff --git a/pyproject.toml b/pyproject.toml index 9febd520..6f1369b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ dependencies = [ "langchain-text-splitters>=1.0,<2.0", "beautifulsoup4>=4.12", "psycopg[binary]>=3.2", - "toolguard>=0.2.16", + "toolguard>=0.2.17", "cuga-oak-health; python_version>='3.12'", "aiosmtpd", ] diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_llm_adapter.py b/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_llm_adapter.py deleted file mode 100644 index 10711eba..00000000 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/cuga_llm_adapter.py +++ /dev/null @@ -1,205 +0,0 @@ -""" -Adapter to use CUGA's LLM (BaseChatModel) with ToolGuard's I_TG_LLM interface. - -This module provides a bridge between CUGA's LangChain-based LLM system and -ToolGuard's example generation system, allowing any CUGA-configured model -(OpenAI, Groq, Azure, etc.) to be used with ToolGuard. -""" - -from typing import Dict, List -from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage -from loguru import logger -from toolguard.buildtime.llm.llm_base import LanguageModelBase - - -class CugaLLMAdapter(LanguageModelBase): - """ - Adapter that wraps CUGA's BaseChatModel to work with ToolGuard's I_TG_LLM interface. - - This allows using any LangChain-compatible model (OpenAI, Groq, Azure, etc.) - with ToolGuard's example generation system. The adapter handles message format - conversion between ToolGuard's dict-based format and LangChain's message objects. - - Inherits from LanguageModelBase which provides: - - chat_json(): Automatic JSON extraction with retry logic - - extract_json_from_string(): JSON parsing utilities - - Example: - ```python - from cuga.sdk import CugaAgent - from cuga.backend.cuga_graph.policy.tool_guard.cuga_llm_adapter import CugaLLMAdapter - - # Create agent with any model - agent = CugaAgent(tools=[my_tool]) - - # Wrap the model for ToolGuard - llm_adapter = CugaLLMAdapter(agent._model) - - # Use with ToolGuard - messages = [ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "Generate examples"} - ] - response = await llm_adapter.generate(messages) - - # Or use chat_json for structured output - json_response = await llm_adapter.chat_json(messages) - ``` - - Attributes: - model: The underlying LangChain BaseChatModel instance - """ - - def __init__(self, model: BaseChatModel): - """ - Initialize the adapter with a CUGA/LangChain model. - - Args: - model: BaseChatModel instance (e.g., ChatOpenAI, ChatGroq, AzureChatOpenAI, etc.) - This is typically obtained from agent._model in CugaAgent. - - Raises: - TypeError: If model is not a BaseChatModel instance - """ - if not isinstance(model, BaseChatModel): - raise TypeError( - f"Expected BaseChatModel instance, got {type(model).__name__}. " - "Please pass agent._model from a CugaAgent instance." - ) - - self.model = model - logger.info(f"Initialized CugaLLMAdapter with model: {type(model).__name__}") - - def _convert_to_langchain_messages(self, messages: List[Dict]) -> List[BaseMessage]: - """ - Convert ToolGuard message format to LangChain message objects. - - ToolGuard uses a simple dict format: - [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi"}] - - LangChain uses typed message objects: - [HumanMessage(content="Hello"), AIMessage(content="Hi")] - - Args: - messages: List of dicts with 'role' and 'content' keys. - Supported roles: 'system', 'user', 'human', 'assistant', 'ai' - - Returns: - List of LangChain message objects (SystemMessage, HumanMessage, AIMessage) - - Raises: - ValueError: If a message is missing 'role' or 'content' keys - """ - langchain_messages = [] - - for i, msg in enumerate(messages): - if not isinstance(msg, dict): - raise ValueError( - f"Message at index {i} must be a dict, got {type(msg).__name__}" - ) - - if "role" not in msg: - raise ValueError(f"Message at index {i} missing 'role' key: {msg}") - - if "content" not in msg: - raise ValueError(f"Message at index {i} missing 'content' key: {msg}") - - raw_role = msg["role"] - if not isinstance(raw_role, str): - raise ValueError( - f"Message at index {i} has non-string 'role': {type(raw_role).__name__}" - ) - role = raw_role.lower() - content = msg["content"] - - # Convert to appropriate LangChain message type - if role == "system": - langchain_messages.append(SystemMessage(content=content)) - elif role in ("assistant", "ai"): - langchain_messages.append(AIMessage(content=content)) - elif role in ("user", "human"): - langchain_messages.append(HumanMessage(content=content)) - else: - logger.warning( - f"Unknown role '{role}' at index {i}, treating as user message" - ) - langchain_messages.append(HumanMessage(content=content)) - - return langchain_messages - - async def generate(self, messages: List[Dict]) -> str: - """ - Generate a text response from the model. - - This is the core method required by I_TG_LLM interface. It converts - ToolGuard's message format to LangChain format, invokes the model, - and returns the generated text. - - Args: - messages: List of message dicts in format: - [ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "Generate examples"}, - {"role": "assistant", "content": "Here are examples..."} - ] - - Returns: - Generated text response as string - - Raises: - ValueError: If messages format is invalid - Exception: If model invocation fails (propagated from LangChain) - - Example: - ```python - adapter = CugaLLMAdapter(model) - messages = [ - {"role": "system", "content": "You are helpful"}, - {"role": "user", "content": "Say hello"} - ] - response = await adapter.generate(messages) - print(response) # "Hello! How can I help you today?" - ``` - """ - try: - # Convert to LangChain format - langchain_messages = self._convert_to_langchain_messages(messages) - - logger.debug( - f"Invoking {type(self.model).__name__} with {len(langchain_messages)} messages" - ) - - # Invoke the model using LangChain's async interface - response = await self.model.ainvoke(langchain_messages) - - # Extract content from response - # LangChain models return AIMessage or similar with .content attribute - if hasattr(response, 'content'): - content = response.content - # Handle case where content might be a list (e.g., multimodal responses) - if isinstance(content, list): - # Join list items into a single string - content = "\n".join(str(item) for item in content) - elif not isinstance(content, str): - content = str(content) - else: - content = str(response) - - logger.debug(f"Generated response length: {len(content)} characters") - - return content - - except ValueError as e: - # Re-raise validation errors - logger.error(f"Message format error: {e}") - raise - except Exception as e: - # Log and re-raise model invocation errors - logger.error(f"Error generating response with {type(self.model).__name__}: {e}") - raise RuntimeError( - f"Failed to generate response: {e}. " - "Check that your model is properly configured and has valid credentials." - ) from e - - diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py b/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py index c04813f8..0bb13176 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py @@ -8,17 +8,25 @@ import shutil from contextlib import contextmanager -from typing import List, Tuple, Dict, Any, Iterator, Optional from pathlib import Path -from loguru import logger +from typing import Any, Dict, Iterator, List, Optional, Tuple -from cuga.backend.cuga_graph.policy.models import ToolGuide, PolicyType -from cuga.backend.cuga_graph.policy.tool_guard.cuga_llm_adapter import CugaLLMAdapter +from langchain_core.tools import StructuredTool +from loguru import logger +from toolguard.buildtime.buildtime import ( + ToolGuardSpec, + generate_guard_examples, + generate_guards_code, +) +from toolguard.buildtime.llm import LangchainModelWrapper +from toolguard.extra.langchain_to_oas import langchain_tools_to_openapi +from toolguard.runtime.data_types import ( + ToolGuardsCodeGenerationResult, + ToolGuardSpecItem, +) +from cuga.backend.cuga_graph.policy.models import PolicyType, ToolGuide from cuga.sdk import CugaAgent -from toolguard.buildtime.buildtime import ToolGuardSpec, generate_guard_examples, generate_guards_code -from toolguard.extra.langchain_to_oas import langchain_tools_to_openapi -from toolguard.runtime.data_types import ToolGuardSpecItem, ToolGuardsCodeGenerationResult class ToolGuardManager: @@ -27,7 +35,7 @@ def __init__( self, agent: CugaAgent, ): - self.langchain_tools: List[Any] = [] # Store LangChain tools + self.langchain_tools: List[StructuredTool] = [] # Store LangChain tools self.tools_dict: Dict[str, Any] = {} # Store OpenAPI dict for ToolGuard self._initialized = False @@ -40,8 +48,8 @@ def __init__( "before creating ToolGuardManager." ) - self.llm = CugaLLMAdapter(agent._model) - logger.info(f"Initialized ToolGuardManager with {type(agent._model).__name__} via CugaLLMAdapter") + self.llm = LangchainModelWrapper(agent._model) + logger.info(f"Initialized ToolGuardManager with {type(agent._model).__name__} via LangchainModelWrapper") # Use Path to properly handle path concatenation and ensure directory exists work_dir_path = Path(agent.cuga_folder) / "toolguard" work_dir_path.mkdir(parents=True, exist_ok=True) @@ -343,17 +351,5 @@ def is_initialized(self) -> bool: """Check if the manager has been initialized.""" return self._initialized - def get_available_tools(self) -> List[str]: - """ - Get list of available tool names. - - Returns: - List of tool names that can be used in policies - - Raises: - RuntimeError: If manager not initialized - """ - self._ensure_initialized() - return [tool.name for tool in self.langchain_tools] diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py index 1fe7f5cc..49d85ecf 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py @@ -8,7 +8,7 @@ from typing import Any, Dict, Optional, Type, TypeVar from loguru import logger -from toolguard.runtime.data_types import IToolInvoker +from toolguard.runtime import IToolInvoker T = TypeVar('T') From e8df696ed00a4cddee6fbfe40611102d3384f59f Mon Sep 17 00:00:00 2001 From: naamaz Date: Thu, 30 Apr 2026 11:11:17 +0300 Subject: [PATCH 17/24] fix1 --- .../cuga_graph/policy/tool_guard/tool_guard_runtime.py | 4 ++-- src/cuga/sdk.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py index 593d7056..3fed2c76 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple from loguru import logger - +from types import SimpleNamespace from toolguard.runtime.data_types import ( FileTwin, PolicyViolationException, @@ -568,7 +568,7 @@ async def guard_tool_call( ) try: - args_obj = type("ToolGuardArgs", (), arguments)() + args_obj = SimpleNamespace(**arguments) await self._runtime.guard_toolcall( tool_name=function_name, args=arguments | {"args": args_obj}, diff --git a/src/cuga/sdk.py b/src/cuga/sdk.py index 65f1e1d7..ec5735a1 100644 --- a/src/cuga/sdk.py +++ b/src/cuga/sdk.py @@ -1408,7 +1408,6 @@ async def generate_tool_guard_code( ) return guard_code - return violating_examples, compliance_examples class CugaAgent: From 0d93fa50c59047b9cbbed65776a6ba786a444208 Mon Sep 17 00:00:00 2001 From: naamaz Date: Thu, 30 Apr 2026 11:43:17 +0300 Subject: [PATCH 18/24] fix2 --- .../cuga_graph/policy/tool_guard/manager.py | 59 ++++++++++++++++++- .../policy/tool_guard/tool_invoker.py | 21 ++++++- 2 files changed, 78 insertions(+), 2 deletions(-) diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py b/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py index c04813f8..b7e5c343 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py @@ -6,6 +6,8 @@ for tool usage policies. """ +import asyncio +import re import shutil from contextlib import contextmanager from typing import List, Tuple, Dict, Any, Iterator, Optional @@ -31,6 +33,20 @@ def __init__( self.tools_dict: Dict[str, Any] = {} # Store OpenAPI dict for ToolGuard self._initialized = False + # Validate tool_provider upfront + if agent.tool_provider is None: + raise ValueError( + "Agent tool_provider is not initialized. Ensure the CugaAgent has a valid " + "tool_provider before creating ToolGuardManager." + ) + + # Validate cuga_folder upfront + if not agent.cuga_folder: + raise ValueError( + "Agent cuga_folder is not set. Ensure the CugaAgent has a valid " + "cuga_folder path before creating ToolGuardManager." + ) + self.tool_provider = agent.tool_provider # Wrap CUGA's LLM with adapter for ToolGuard compatibility @@ -90,6 +106,39 @@ def _validate_policy_and_tool(self, policy: ToolGuide, target_tool: str) -> None f"Available tools: {tool_names}" ) + def _validate_app_name(self, app_name: str) -> str: + """ + Validate and sanitize app_name to prevent path traversal attacks. + + Args: + app_name: Application name to validate + + Returns: + The validated app_name + + Raises: + ValueError: If app_name contains unsafe characters or patterns + """ + # Check for path separators and traversal segments + if '/' in app_name or '\\' in app_name: + raise ValueError( + f"Invalid app_name '{app_name}': path separators ('/', '\\') are not allowed" + ) + + if '..' in app_name: + raise ValueError( + f"Invalid app_name '{app_name}': path traversal segments ('..') are not allowed" + ) + + # Validate against safe whitelist: alphanumeric, underscore, and hyphen only + if not re.match(r'^[A-Za-z0-9_-]+$', app_name): + raise ValueError( + f"Invalid app_name '{app_name}': only alphanumeric characters, " + f"underscores, and hyphens are allowed (pattern: /^[A-Za-z0-9_-]+$/)" + ) + + return app_name + def _build_description(self, policy: ToolGuide) -> str: """Build description from policy, concatenating guide_content if present.""" description = policy.description @@ -213,6 +262,8 @@ async def generate_examples( logger.warning(f"No results returned for tool '{target_tool}'") return [], [] + except asyncio.CancelledError: + raise except Exception as e: logger.error( f"❌ Failed to generate examples for tool '{target_tool}': {e}" @@ -243,10 +294,14 @@ async def generate_guard_code( Raises: RuntimeError: If manager not initialized ValueError: If policy is not a ToolGuide, target_tool not in policy.target_tools, - or if the policy doesn't have examples for the target tool + if the policy doesn't have examples for the target tool, + or if app_name contains unsafe characters """ self._ensure_initialized() self._validate_policy_and_tool(policy, target_tool) + + # Validate app_name to prevent path traversal attacks + app_name = self._validate_app_name(app_name) logger.info(f"Generating guard code for tool '{target_tool}'...") @@ -332,6 +387,8 @@ async def generate_guard_code( f"Available tools: {list(result.tools.keys())}" ) + except asyncio.CancelledError: + raise except Exception as e: logger.error( f"❌ Failed to generate guard code for tool '{target_tool}': {e}" diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py index 1fe7f5cc..3ae7182a 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py @@ -5,6 +5,7 @@ and CUGA's tool provider system. """ +import asyncio from typing import Any, Dict, Optional, Type, TypeVar from loguru import logger @@ -44,10 +45,25 @@ async def _get_tools(self) -> Dict[str, Any]: Returns: Dictionary mapping tool names to tool instances + + Raises: + ValueError: If duplicate tool names are detected """ if self._tools_cache is None: tools_list = await self.tool_provider.get_all_tools() - self._tools_cache = {tool.name: tool for tool in tools_list} + + # Check for duplicate tool names before building cache + tools_map: Dict[str, Any] = {} + for tool in tools_list: + if tool.name in tools_map: + raise ValueError( + f"Duplicate tool name detected: '{tool.name}'. " + f"Tool names must be unique across all providers to ensure " + f"correct routing of guards to tools." + ) + tools_map[tool.name] = tool + + self._tools_cache = tools_map logger.debug(f"Cached {len(self._tools_cache)} tools") return self._tools_cache @@ -106,6 +122,9 @@ async def invoke( # Re-raise ValueError as-is (tool not found) raise except Exception as e: + # Re-raise CancelledError immediately to preserve cancellation + if isinstance(e, asyncio.CancelledError): + raise logger.error(f"Failed to invoke tool '{toolname}': {e}") raise RuntimeError( f"Tool invocation failed for '{toolname}': {str(e)}" From c94ab7ce5ed7d815c44b5c36f61692399cf60e36 Mon Sep 17 00:00:00 2001 From: naamaz Date: Mon, 4 May 2026 11:16:28 +0300 Subject: [PATCH 19/24] fix --- .../backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py index 31758e43..cb2f2e91 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py @@ -697,7 +697,7 @@ async def guard_tool_call( try: args_obj = SimpleNamespace(**arguments) - await self._runtime.guard_toolcall( + await runtime.guard_toolcall( tool_name=function_name, args=arguments | {"args": args_obj}, delegate=self.invoker, From 4e84c602a399cff8e7acccccf3fb5ccb89b386a3 Mon Sep 17 00:00:00 2001 From: naamaz Date: Mon, 4 May 2026 11:31:27 +0300 Subject: [PATCH 20/24] fix --- src/cuga/backend/cuga_graph/policy/models.py | 4 ++-- .../cuga_graph/policy/tool_guard/manager.py | 20 ++++++++++++------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/cuga/backend/cuga_graph/policy/models.py b/src/cuga/backend/cuga_graph/policy/models.py index 596ba061..ad672154 100644 --- a/src/cuga/backend/cuga_graph/policy/models.py +++ b/src/cuga/backend/cuga_graph/policy/models.py @@ -179,9 +179,9 @@ class ToolGuard(BaseModel): default="", description=( "Python code that validates tool usage compliance. " - "WARNING: This code is executed unsandboxed at runtime. " + "WARNING: This code is executed in a sandboxed environment using the toolguard library. " "Only trusted administrators with manage access should be allowed to modify policy code. " - "Malicious code can compromise system security." + "While sandboxed, policy code should still be reviewed for correctness and performance." ) ) diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py b/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py index 2482f271..1d022ea7 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/manager.py @@ -40,6 +40,7 @@ def __init__( self.langchain_tools: List[StructuredTool] = [] # Store LangChain tools self.tools_dict: Dict[str, Any] = {} # Store OpenAPI dict for ToolGuard self._initialized = False + self._init_lock = asyncio.Lock() # Validate tool_provider upfront if agent.tool_provider is None: @@ -75,16 +76,21 @@ def __init__( async def initialize( self, ) -> None: - logger.info("Initializing ToolGuardManager...") + async with self._init_lock: + if self._initialized: + logger.debug("ToolGuardManager already initialized, skipping") + return + + logger.info("Initializing ToolGuardManager...") - # Get all LangChain tools from the tool provider - self.langchain_tools = await self.tool_provider.get_all_tools() + # Get all LangChain tools from the tool provider + self.langchain_tools = await self.tool_provider.get_all_tools() - # Convert LangChain tools to OpenAPI dict using ToolGuard's utility - self.tools_dict = langchain_tools_to_openapi(self.langchain_tools) + # Convert LangChain tools to OpenAPI dict using ToolGuard's utility + self.tools_dict = langchain_tools_to_openapi(self.langchain_tools) - self._initialized = True - logger.info(f"✅ ToolGuardManager initialized with {len(self.langchain_tools)} tools") + self._initialized = True + logger.info(f"✅ ToolGuardManager initialized with {len(self.langchain_tools)} tools") def _ensure_initialized(self) -> None: """Ensure manager is initialized, raise RuntimeError if not.""" From 330ed8ee1254faaab37b123e8b72058431576da7 Mon Sep 17 00:00:00 2001 From: naamaz Date: Mon, 4 May 2026 12:05:48 +0300 Subject: [PATCH 21/24] fixes --- .../policy/tool_guard/tool_guard_runtime.py | 117 +++++++++++--- .../registry/registry/api_registry.py | 34 +++- .../registry/registry/api_registry_server.py | 153 ++++++------------ 3 files changed, 172 insertions(+), 132 deletions(-) diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py index cb2f2e91..e96699e3 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_guard_runtime.py @@ -31,22 +31,30 @@ class ToolGuardRuntime: Runtime system for executing tool guards during tool invocation. This class: - 1. Initializes a ToolGuardInvoker for tool execution - 2. Loads all ToolGuide policies with policy_code - 3. Creates a mapping: tool_name -> List[ToolGuide with code] - 4. Prebuilds umbrella guard modules per tool - 5. Executes guard validation through toolguard runtime + 1. Manages policy storage lifecycle (connect/disconnect) + 2. Initializes a ToolGuardInvoker for tool execution + 3. Loads all ToolGuide policies with policy_code + 4. Creates a mapping: tool_name -> List[ToolGuide with code] + 5. Prebuilds umbrella guard modules per tool + 6. Executes guard validation through toolguard runtime """ - def __init__(self, tool_provider, policy_storage: PolicyStorage) -> None: + def __init__( + self, + tool_provider, + enable_policies: bool = False, + policy_storage: Optional[PolicyStorage] = None + ) -> None: """ Initialize the ToolGuardRuntime. Args: tool_provider: CUGA's tool provider instance - policy_storage: PolicyStorage instance to load policies from + enable_policies: Whether to enable policy enforcement + policy_storage: Optional PolicyStorage instance (will be created if None and enable_policies=True) """ self.tool_provider = tool_provider + self.enable_policies = enable_policies self.policy_storage = policy_storage self.invoker = ToolGuardInvoker(tool_provider) self.tool_to_guards: Dict[str, List[ToolGuide]] = {} @@ -54,32 +62,89 @@ def __init__(self, tool_provider, policy_storage: PolicyStorage) -> None: self._runtimes_by_app: Dict[str, Any] = {} self._runtime_domains_by_app: Dict[str, RuntimeDomain] = {} self._initialized = False - logger.debug("Created ToolGuardRuntime instance") + self._policy_storage_owned = False # Track if we created the storage + logger.debug(f"Created ToolGuardRuntime instance (enable_policies={enable_policies})") async def initialize(self) -> None: """ - Initialize the runtime by loading all ToolGuide policies with code. + Initialize the runtime by connecting to policy storage and loading policies. This method: - 1. Fetches all ToolGuide policies from storage - 2. Filters for policies that have tool_guards with policy_code - 3. Builds the tool_to_guards mapping - 4. Per-app runtimes will be lazily loaded on first use + 1. Connects to policy storage if policies are enabled + 2. Fetches all ToolGuide policies from storage + 3. Filters for policies that have tool_guards with policy_code + 4. Builds the tool_to_guards mapping + 5. Per-app runtimes will be lazily loaded on first use + + Raises: + RuntimeError: If policy system is enabled but storage connection fails (fail-closed) """ logger.info("Initializing ToolGuardRuntime...") self._reset_state() - policies = await self.policy_storage.list_policies( - policy_type=PolicyType.TOOL_GUIDE, enabled_only=True - ) - logger.debug(f"Found {len(policies)} ToolGuide policies") + # Connect to policy storage if policies are enabled + if self.enable_policies: + if self.policy_storage is None: + # Create policy storage if not provided + from cuga.backend.cuga_graph.policy.storage import PolicyStorage + self.policy_storage = PolicyStorage() + self._policy_storage_owned = True + logger.debug("Created PolicyStorage instance") + + # Validate policy_storage has required interface + self._validate_policy_storage() + + try: + await self.policy_storage.connect() + logger.info("✅ Connected policy storage for ToolGuardRuntime") + except Exception as e: + logger.error(f"Failed to connect policy storage: {e}") + # Fail closed: if policy enforcement is enabled but storage fails, + # don't allow the service to start without policy validation + raise RuntimeError( + f"Policy system is enabled but PolicyStorage.connect() failed: {e}" + ) from e + + # Load policies if storage is available + if self.policy_storage is not None: + policies = await self.policy_storage.list_policies( + policy_type=PolicyType.TOOL_GUIDE, enabled_only=True + ) + logger.debug(f"Found {len(policies)} ToolGuide policies") - # Filter to ensure we only have ToolGuide instances - tool_guide_policies = [p for p in policies if isinstance(p, ToolGuide)] - self._build_tool_to_guards_mapping(tool_guide_policies) + # Filter to ensure we only have ToolGuide instances + tool_guide_policies = [p for p in policies if isinstance(p, ToolGuide)] + self._build_tool_to_guards_mapping(tool_guide_policies) + else: + logger.debug("No policy storage available, skipping policy loading") self._initialized = True self._log_initialization_summary() + + def _validate_policy_storage(self) -> None: + """ + Validate that policy_storage has the required interface. + + Raises: + ValueError: If policy_storage doesn't implement required methods + """ + if self.policy_storage is None: + return + + required_methods = ['connect', 'disconnect', 'list_policies', 'get_policy'] + missing_methods = [] + + for method in required_methods: + if not hasattr(self.policy_storage, method): + missing_methods.append(method) + + if missing_methods: + raise ValueError( + f"policy_storage must implement the following methods: {', '.join(missing_methods)}. " + f"Provided object type: {type(self.policy_storage).__name__}" + ) + + logger.debug("✅ Policy storage interface validation passed") def _reset_state(self) -> None: """Reset internal state for reinitialization.""" @@ -750,7 +815,7 @@ def get_guards_for_tool(self, tool_name: str) -> List[ToolGuide]: return self.tool_to_guards.get(tool_name, []) async def shutdown(self) -> None: - """Release in-memory ToolGuard runtime resources.""" + """Release in-memory ToolGuard runtime resources and disconnect policy storage.""" # Clean up all per-app runtimes for app_name, runtime in self._runtimes_by_app.items(): if runtime is not None: @@ -760,5 +825,15 @@ async def shutdown(self) -> None: logger.exception(f"Error while shutting down ToolGuard runtime for app '{app_name}'") self._runtimes_by_app = {} self._runtime_domains_by_app = {} + + # Disconnect policy storage if we own it + if self.policy_storage is not None and self._policy_storage_owned: + try: + await self.policy_storage.disconnect() + logger.debug("Disconnected policy storage") + except Exception as e: + logger.warning(f"Error disconnecting policy storage during shutdown: {e}") + self.policy_storage = None + self._initialized = False logger.debug("ToolGuardRuntime shutdown complete") \ No newline at end of file diff --git a/src/cuga/backend/tools_env/registry/registry/api_registry.py b/src/cuga/backend/tools_env/registry/registry/api_registry.py index af8093e4..90c19d3f 100644 --- a/src/cuga/backend/tools_env/registry/registry/api_registry.py +++ b/src/cuga/backend/tools_env/registry/registry/api_registry.py @@ -29,7 +29,14 @@ class ApiRegistry: interacting with the mcp manager """ - def __init__(self, client: MCPManager, policy_storage=None): + def __init__(self, client: MCPManager, enable_policies: bool = False): + """ + Initialize ApiRegistry. + + Args: + client: MCPManager instance for tool management + enable_policies: Whether to enable policy-based tool validation + """ logger.info("ApiRegistry: Initializing.") self.mcp_client = client self.auth_manager = None @@ -37,9 +44,10 @@ def __init__(self, client: MCPManager, policy_storage=None): self._init_tavily_if_enabled() # ToolGuardRuntime support for policy-based tool validation - self.policy_storage = policy_storage + # ToolGuardRuntime now manages its own policy storage lifecycle self.tool_guard_runtime = None self._tool_guard_initialized = False + self._enable_policies = enable_policies def _init_tavily_if_enabled(self): """Initialize Tavily client if web search is enabled.""" @@ -74,25 +82,25 @@ async def _initialize_tool_guard_runtime(self): Initialize ToolGuardRuntime after tools are loaded. This is called after load_tools() to ensure both tools and policies are available. + ToolGuardRuntime now manages its own policy storage lifecycle. """ if self._tool_guard_initialized: return self._tool_guard_initialized = True - if self.policy_storage is None: - logger.debug("ToolGuardRuntime: No policy storage provided, skipping initialization") + if not self._enable_policies: + logger.debug("ToolGuardRuntime: Policy enforcement disabled, skipping initialization") return try: from cuga.backend.cuga_graph.policy.tool_guard.tool_guard_runtime import ToolGuardRuntime - # Note: tool_provider is only needed if guards invoke tools via the invoker - # For most guards that just validate arguments, it's not used - # We pass self.mcp_client which can be wrapped if needed + # ToolGuardRuntime now manages policy storage internally + # We just pass enable_policies flag and it handles the rest self.tool_guard_runtime = ToolGuardRuntime( tool_provider=self.mcp_client, - policy_storage=self.policy_storage + enable_policies=self._enable_policies ) await self.tool_guard_runtime.initialize() @@ -111,6 +119,16 @@ async def _initialize_tool_guard_runtime(self): raise RuntimeError( f"ToolGuardRuntime failed to initialize. Policy enforcement cannot be bypassed. Error: {e}" ) from e + + async def cleanup(self): + """Cleanup resources including ToolGuardRuntime.""" + if self.tool_guard_runtime is not None: + try: + await self.tool_guard_runtime.shutdown() + logger.debug("ToolGuardRuntime shutdown complete") + except Exception as e: + logger.warning(f"Error shutting down ToolGuardRuntime: {e}") + self.tool_guard_runtime = None async def show_applications(self) -> List[AppDefinition]: """Lists application names and their descriptions.""" diff --git a/src/cuga/backend/tools_env/registry/registry/api_registry_server.py b/src/cuga/backend/tools_env/registry/registry/api_registry_server.py index 2a754389..3830712a 100644 --- a/src/cuga/backend/tools_env/registry/registry/api_registry_server.py +++ b/src/cuga/backend/tools_env/registry/registry/api_registry_server.py @@ -141,23 +141,8 @@ async def _get_or_create_registry( manager = MCPManager(config=services) - # Get policy storage if policy system is enabled - policy_storage = None - if settings.policy.enabled: - from cuga.backend.cuga_graph.policy.storage import PolicyStorage - policy_storage = PolicyStorage() - try: - await policy_storage.connect() - logger.info("✅ Connected policy storage for ToolGuardRuntime") - except Exception as e: - logger.error(f"Failed to connect policy storage: {e}") - # Fail closed: if policy enforcement is enabled but storage fails, - # don't boot the registry unguarded - raise RuntimeError( - f"Policy system is enabled but PolicyStorage.connect() failed: {e}" - ) from e - - reg = ApiRegistry(client=manager, policy_storage=policy_storage) + # ApiRegistry now manages policy storage lifecycle through ToolGuardRuntime + reg = ApiRegistry(client=manager, enable_policies=settings.policy.enabled) try: await reg.start_servers() except Exception: @@ -167,12 +152,11 @@ async def _get_or_create_registry( except Exception as cleanup_error: logger.warning(f"Error shutting down manager during cleanup: {cleanup_error}") - # Clean up policy_storage if start_servers fails - if policy_storage is not None: - try: - await policy_storage.disconnect() - except Exception as cleanup_error: - logger.warning(f"Error disconnecting policy storage during cleanup: {cleanup_error}") + # Clean up registry resources (including ToolGuardRuntime/policy storage) + try: + await reg.cleanup() + except Exception as cleanup_error: + logger.warning(f"Error cleaning up registry during cleanup: {cleanup_error}") raise agent_registries[agent_id] = (manager, reg) @@ -199,23 +183,8 @@ async def lifespan(app: FastAPI): services = load_service_configs(str(config_file)) mcp_manager = MCPManager(config=services) - # Get policy storage if policy system is enabled - policy_storage = None - if settings.policy.enabled: - try: - from cuga.backend.cuga_graph.policy.storage import PolicyStorage - policy_storage = PolicyStorage() - await policy_storage.connect() - logger.info("✅ Connected policy storage for ToolGuardRuntime") - except Exception as e: - logger.error(f"Failed to connect policy storage: {e}") - # Fail closed: if policy enforcement is enabled but storage fails, - # don't boot the registry unguarded - raise RuntimeError( - f"Policy system is enabled but PolicyStorage.connect() failed: {e}" - ) from e - - registry = ApiRegistry(client=mcp_manager, policy_storage=policy_storage) + # ApiRegistry now manages policy storage lifecycle through ToolGuardRuntime + registry = ApiRegistry(client=mcp_manager, enable_policies=settings.policy.enabled) try: await registry.start_servers() except Exception: @@ -225,12 +194,11 @@ async def lifespan(app: FastAPI): except Exception as cleanup_error: logger.warning(f"Error shutting down mcp_manager during cleanup: {cleanup_error}") - # Clean up policy_storage if start_servers fails - if policy_storage is not None: - try: - await policy_storage.disconnect() - except Exception as cleanup_error: - logger.warning(f"Error disconnecting policy storage during cleanup: {cleanup_error}") + # Clean up registry resources (including ToolGuardRuntime/policy storage) + try: + await registry.cleanup() + except Exception as cleanup_error: + logger.warning(f"Error cleaning up registry during cleanup: {cleanup_error}") raise yield @@ -239,19 +207,18 @@ async def lifespan(app: FastAPI): for agent_id, (mgr, reg) in agent_registries.items(): logger.info(f"Cleaning up registry for agent: {agent_id}") await mgr.shutdown() - # Disconnect policy storage if present - if reg.policy_storage is not None: - try: - await reg.policy_storage.disconnect() - except Exception as e: - logger.warning(f"Error disconnecting policy storage for agent {agent_id}: {e}") + # Cleanup registry resources (including ToolGuardRuntime/policy storage) + try: + await reg.cleanup() + except Exception as e: + logger.warning(f"Error cleaning up registry for agent {agent_id}: {e}") if not database_mode: - # In YAML mode, also clean up the global registry's policy storage - if 'registry' in globals() and registry is not None and registry.policy_storage is not None: + # In YAML mode, also clean up the global registry + if 'registry' in globals() and registry is not None: try: - await registry.policy_storage.disconnect() + await registry.cleanup() except Exception as e: - logger.warning(f"Error disconnecting global registry policy storage: {e}") + logger.warning(f"Error cleaning up global registry: {e}") if 'mcp_manager' in globals(): await mcp_manager.shutdown() @@ -552,11 +519,10 @@ async def reload_config( if agent_id in agent_registries: old_mgr, old_reg = agent_registries.pop(agent_id) await old_mgr.shutdown() - if old_reg.policy_storage is not None: - try: - await old_reg.policy_storage.disconnect() - except Exception as e: - logger.warning(f"Error disconnecting policy storage: {e}") + try: + await old_reg.cleanup() + except Exception as e: + logger.warning(f"Error cleaning up old registry: {e}") # Recreate registry for this agent with retry on empty await _get_or_create_registry(agent_id, retry_on_empty=True) # If this is the default agent, update global registry @@ -583,11 +549,10 @@ async def reload_config( agent_ids = list(agent_registries.keys()) for old_mgr, old_reg in agent_registries.values(): await old_mgr.shutdown() - if old_reg.policy_storage is not None: - try: - await old_reg.policy_storage.disconnect() - except Exception as e: - logger.warning(f"Error disconnecting policy storage: {e}") + try: + await old_reg.cleanup() + except Exception as e: + logger.warning(f"Error cleaning up old registry: {e}") agent_registries.clear() # Recreate default agent mcp_manager, registry = await _get_or_create_registry(default_agent_id) @@ -604,23 +569,8 @@ async def reload_config( services = load_service_configs(config_path) new_manager = MCPManager(config=services) - # Get policy storage if policy system is enabled - policy_storage = None - if settings.policy.enabled: - try: - from cuga.backend.cuga_graph.policy.storage import PolicyStorage - policy_storage = PolicyStorage() - await policy_storage.connect() - logger.info("✅ Connected policy storage for ToolGuardRuntime") - except Exception as e: - logger.error(f"Failed to connect policy storage: {e}") - # Fail closed: if policy enforcement is enabled but storage fails, - # don't reload the registry unguarded - raise RuntimeError( - f"Policy system is enabled but PolicyStorage.connect() failed: {e}" - ) from e - - new_registry = ApiRegistry(client=new_manager, policy_storage=policy_storage) + # ApiRegistry now manages policy storage lifecycle through ToolGuardRuntime + new_registry = ApiRegistry(client=new_manager, enable_policies=settings.policy.enabled) try: await new_registry.start_servers() except Exception: @@ -630,20 +580,19 @@ async def reload_config( except Exception as cleanup_error: logger.warning(f"Error shutting down new_manager during cleanup: {cleanup_error}") - # Clean up new policy_storage if start_servers fails - if policy_storage is not None: - try: - await policy_storage.disconnect() - except Exception as cleanup_error: - logger.warning(f"Error disconnecting policy storage during cleanup: {cleanup_error}") + # Clean up new registry resources + try: + await new_registry.cleanup() + except Exception as cleanup_error: + logger.warning(f"Error cleaning up new registry during cleanup: {cleanup_error}") raise - # Clean up old registry's policy storage before replacing - if 'registry' in globals() and registry is not None and registry.policy_storage is not None: + # Clean up old registry before replacing + if 'registry' in globals() and registry is not None: try: - await registry.policy_storage.disconnect() + await registry.cleanup() except Exception as e: - logger.warning(f"Error disconnecting old registry policy storage: {e}") + logger.warning(f"Error cleaning up old registry: {e}") if 'mcp_manager' in globals(): await mcp_manager.shutdown() @@ -673,11 +622,10 @@ async def clear_agent_cache( if agent_id in agent_registries: mgr, reg = agent_registries.pop(agent_id) await mgr.shutdown() - if reg.policy_storage is not None: - try: - await reg.policy_storage.disconnect() - except Exception as e: - logger.warning(f"Error disconnecting policy storage: {e}") + try: + await reg.cleanup() + except Exception as e: + logger.warning(f"Error cleaning up registry: {e}") logger.info(f"Cleared cache for agent: {agent_id}") return {"status": "ok", "message": f"Cache cleared for agent: {agent_id}"} else: @@ -686,11 +634,10 @@ async def clear_agent_cache( count = len(agent_registries) for mgr, reg in agent_registries.values(): await mgr.shutdown() - if reg.policy_storage is not None: - try: - await reg.policy_storage.disconnect() - except Exception as e: - logger.warning(f"Error disconnecting policy storage: {e}") + try: + await reg.cleanup() + except Exception as e: + logger.warning(f"Error cleaning up registry: {e}") agent_registries.clear() logger.info(f"Cleared cache for {count} agents") return {"status": "ok", "message": f"Cache cleared for {count} agents"} From af1b3224956871fc63ff802c1815cedb4212701f Mon Sep 17 00:00:00 2001 From: naamaz Date: Thu, 7 May 2026 11:59:12 +0300 Subject: [PATCH 22/24] demo sdk example --- .../sdk_core/debug_crm_finance_policy_e2e.py | 318 ++++++++++++++++++ 1 file changed, 318 insertions(+) create mode 100644 src/cuga/sdk_core/debug_crm_finance_policy_e2e.py diff --git a/src/cuga/sdk_core/debug_crm_finance_policy_e2e.py b/src/cuga/sdk_core/debug_crm_finance_policy_e2e.py new file mode 100644 index 00000000..bf2a3c58 --- /dev/null +++ b/src/cuga/sdk_core/debug_crm_finance_policy_e2e.py @@ -0,0 +1,318 @@ +""" +Debug script for creating and testing CRM finance eligibility tool guard policy. + +This script demonstrates: +1. Creating a CugaAgent with CRM tools from the registry +2. Adding a finance eligibility policy with tool guards +3. Generating examples and guard code for the policy +4. Testing policy enforcement with ToolGuardRuntime + +Prerequisites: +- CRM API server must be running on port 8007 +- Run: cuga start demo_crm (or just the CRM API server) + +Note: This script uses an in-memory policy system for testing. +It does NOT persist policies to the running demo_crm server. +""" + +import asyncio +import os +import tempfile +from cuga import CugaAgent +from cuga.backend.cuga_graph.policy.tool_guard import ToolGuardRuntime +from cuga.backend.cuga_graph.nodes.cuga_lite.combined_tool_provider import CombinedToolProvider + +# Define policy to create +POLICY_CONFIG = { + "name": "Finance eligibility revenue requirements", + "content": """## Finance Industry Revenue Requirements + +### Policy Rules +- Accounts cannot be created for companies from the Finance industry with annual revenue under $100,000 +- This ensures we only onboard financially stable finance companies +- Companies from other industries have no revenue restrictions + +### Validation Requirements +- Always check the industry field before account creation +- If industry is "Finance", verify annual_revenue >= 100000 +- Reject account creation that violates revenue requirements +- Provide clear error messages when restrictions apply +""", + "description": "Accounts cannot be created for companies from the Finance industry with annual revenue under $100,000.", +} + + +async def create_and_process_policy(agent): + """Create policy and generate examples and guard code.""" + print("\nStep 1: Creating and processing policy...") + print("="*60) + + print(f"\n--- Processing Policy: {POLICY_CONFIG['name']} ---") + + # Create policy + print(f"Creating policy...") + policy_id = await agent.policies.add_tool_guide( + name=POLICY_CONFIG["name"], + content=POLICY_CONFIG["content"], + target_tools=["crm_create_account_accounts_post"], + description=POLICY_CONFIG["description"], + ) + print(f"✅ Created policy with ID: {policy_id}") + + # Generate examples + print(f"Generating examples...") + violating_examples, compliance_examples = await agent.policies.generate_tool_guard_examples( + policy_id=policy_id, + target_tool="crm_create_account_accounts_post" + ) + print(f"✅ Generated {len(violating_examples)} violating and {len(compliance_examples)} compliance examples") + + # Update policy with examples + await agent.policies.update_tool_guard( + policy_id=policy_id, + tool_guards={ + "crm_create_account_accounts_post": { + "description": f"Guard rules for {POLICY_CONFIG['name']}", + "violating_examples": violating_examples, + "compliance_examples": compliance_examples, + "policy_code": "" + } + } + ) + + # Generate guard code + print(f"Generating guard code...") + guard_code = await agent.policies.generate_tool_guard_code( + policy_id=policy_id, + target_tool="crm_create_account_accounts_post", + app_name="crm_demo" + ) + print(f"✅ Generated guard code ({len(guard_code)} characters)") + + # Update policy with guard code + await agent.policies.update_tool_guard( + policy_id=policy_id, + tool_guards={ + "crm_create_account_accounts_post": { + "description": f"Guard rules for {POLICY_CONFIG['name']}", + "violating_examples": violating_examples, + "compliance_examples": compliance_examples, + "policy_code": guard_code + } + } + ) + + # Retrieve policy + policy_tool_guide = await agent.policies.get(policy_id) + if policy_tool_guide is None: + raise ValueError(f"Failed to retrieve policy {policy_id}") + + policy = policy_tool_guide["policy"] + + print(f"✅ Policy created and processed successfully") + print("="*60) + + return { + "id": policy_id, + "name": POLICY_CONFIG["name"], + "policy": policy + } + + +async def run_tests(tool_guard_runtime, agent): + """Run test cases to validate policy enforcement.""" + print(f"\n{'='*60}") + print("Step 2: Testing policy enforcement with ToolGuardRuntime") + print(f"{'='*60}") + + print(f"\nRuntime initialized with guards for: {tool_guard_runtime.get_guarded_tools()}") + + test_cases = [ + { + "name": "Test Case 1: Finance with Low Revenue (BLOCKED)", + "args": { + "name": "ACM22 Corporation", + "website": "acm22corporation.com", + "phone": "+1-555-1883", + "address": "94 rue du Gue Jacquet", + "city": "Chatou", + "state": "Île-de-France", + "country": "France", + "region": "Europe", + "annual_revenue": 50000, + "employee_count": 88, + "industry": "Finance" + }, + "expected": "BLOCKED", + "reason": "Finance industry with revenue $50,000 < $100,000" + }, + { + "name": "Test Case 2: Finance with High Revenue (ALLOWED)", + "args": { + "name": "Global Finance Corp", + "website": "globalfinance.com", + "phone": "+1-555-2000", + "address": "123 Wall Street", + "city": "New York", + "state": "NY", + "country": "USA", + "region": "North America", + "annual_revenue": 1500000, + "employee_count": 250, + "industry": "Finance" + }, + "expected": "ALLOWED", + "reason": "Finance industry with revenue $1,500,000 >= $100,000" + }, + { + "name": "Test Case 3: Non-Finance with Low Revenue (ALLOWED)", + "args": { + "name": "Small Law Firm", + "website": "smalllawfirm.com", + "phone": "+1-555-3000", + "address": "456 Main Street", + "city": "Boston", + "state": "MA", + "country": "USA", + "region": "North America", + "annual_revenue": 50000, + "employee_count": 10, + "industry": "Law" + }, + "expected": "ALLOWED", + "reason": "Law industry has no revenue restrictions" + } + ] + + results = [] + for test in test_cases: + print(f"\n--- {test['name']} ---") + print(f"Attempting: crm_create_account_accounts_post(") + for k, v in test['args'].items(): + print(f" {k}={repr(v)},") + print(")") + print(f"Expected: {test['expected']} ({test['reason']})") + + try: + error = await tool_guard_runtime.guard_tool_call( + app_name="crm_demo", + function_name="crm_create_account_accounts_post", + arguments=test["args"] + ) + + actual = "BLOCKED" if error else "ALLOWED" + success = actual == test["expected"] + + if success: + print(f"\n✅ SUCCESS: Tool call was correctly {actual}!") + if error: + print(f"Error message: {error}") + else: + print(f"\n⚠️ WARNING: Tool call was {actual} (expected {test['expected']})") + if error: + print(f"Error message: {error}") + + results.append({"test": test["name"], "success": success, "actual": actual}) + + except Exception as e: + print(f"\n❌ Error during validation: {type(e).__name__}: {e}") + results.append({"test": test["name"], "success": False, "actual": "ERROR"}) + + # Print summary + print(f"\n{'='*60}") + print("Test Summary:") + print(f"{'='*60}") + passed = sum(1 for r in results if r["success"]) + print(f"Passed: {passed}/{len(results)}") + for r in results: + status = "✅" if r["success"] else "❌" + print(f" {status} {r['test']}: {r['actual']}") + print(f"{'='*60}") + + return results + + +async def main(): + """Main workflow for creating and testing CRM finance eligibility tool guard policy.""" + + print("="*60) + print("Initializing CRM Tool Provider...") + print("="*60) + + # Create tool provider with CRM app + tool_provider = CombinedToolProvider(app_names=["crm"]) + await tool_provider.initialize() + + # Verify CRM tools are available + tools = await tool_provider.get_tools(app_name="crm") + crm_tool = next((t for t in tools if t.name == "crm_create_account_accounts_post"), None) + + if not crm_tool: + print("❌ ERROR: CRM tool 'crm_create_account_accounts_post' not found!") + print("Available tools:", [t.name for t in tools]) + print("\nMake sure the CRM API server is running:") + print(" cuga start demo_crm") + print(" OR") + print(" cd src/cuga/demo_tools/crm && python -m crm_api.main") + return + + print(f"✅ Found CRM tool: {crm_tool.name}") + print(f"✅ Total tools available: {len(tools)}") + + # Create agent with CRM tools and in-memory policy system + print("\n" + "="*60) + print("Initializing CugaAgent with in-memory policy system...") + print("="*60) + + # Use a temporary database for this test + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_db: + temp_db_path = tmp_db.name + + try: + agent = CugaAgent(tool_provider=tool_provider) + + # Initialize policy system with temporary database + from cuga.backend.cuga_graph.policy.configurable import PolicyConfigurable + + agent._policy_system = PolicyConfigurable() + await agent._policy_system.initialize(policy_db_path=temp_db_path) + + print(f"✅ Using temporary policy database: {temp_db_path}") + print("✅ Policy system initialized") + + # Get policy system + policy_system = agent._policy_system + if policy_system is None or policy_system.storage is None: + raise ValueError("Policy system storage is not available") + + # Step 1: Create and process policy + policy_data = await create_and_process_policy(agent) + + # Step 2: Initialize runtime and run tests + print(f"\n{'='*60}") + print("Initializing ToolGuardRuntime...") + print(f"{'='*60}") + + tool_guard_runtime = ToolGuardRuntime( + tool_provider=agent.tool_provider, + policy_storage=policy_system.storage + ) + await tool_guard_runtime.initialize() + print("✅ ToolGuardRuntime initialized") + + results = await run_tests(tool_guard_runtime, agent) + + print(f"\n{'='*60}") + print("✅ E2E Test completed successfully!") + print(f"{'='*60}") + + finally: + # Clean up temporary database + if os.path.exists(temp_db_path): + os.unlink(temp_db_path) + print(f"\n🧹 Cleaned up temporary database: {temp_db_path}") + + +if __name__ == "__main__": + asyncio.run(main()) + From 7770eddc5ce63006a26d210c8bed18f29e42ab34 Mon Sep 17 00:00:00 2001 From: naamaz Date: Thu, 14 May 2026 10:20:11 +0300 Subject: [PATCH 23/24] remove tg description --- src/cuga/backend/cuga_graph/policy/models.py | 1 - .../backend/cuga_graph/policy/tests/test_filesystem_sync.py | 4 +--- src/cuga/sdk.py | 3 +-- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/cuga/backend/cuga_graph/policy/models.py b/src/cuga/backend/cuga_graph/policy/models.py index ad672154..77b2be16 100644 --- a/src/cuga/backend/cuga_graph/policy/models.py +++ b/src/cuga/backend/cuga_graph/policy/models.py @@ -168,7 +168,6 @@ class IntentGuardResponse(BaseModel): class ToolGuard(BaseModel): """Guard configuration for a specific tool with compliance rules.""" - description: str = Field(..., description="Description of the guard rules for this tool") violating_examples: List[str] = Field( default_factory=list, description="Examples of violating usage patterns" ) diff --git a/src/cuga/backend/cuga_graph/policy/tests/test_filesystem_sync.py b/src/cuga/backend/cuga_graph/policy/tests/test_filesystem_sync.py index 211dd45f..f0fc27ec 100644 --- a/src/cuga/backend/cuga_graph/policy/tests/test_filesystem_sync.py +++ b/src/cuga/backend/cuga_graph/policy/tests/test_filesystem_sync.py @@ -179,7 +179,6 @@ async def test_save_tool_guide_with_tool_guards_to_filesystem(self, temp_cuga_fo guide_content="## Guidelines\n- Be careful", tool_guards={ "test_tool": ToolGuard( - description="Only allow safe usage", violating_examples=["Deleting all records without confirmation"], compliance_examples=["Delete one record after explicit confirmation"], policy_code="def validate(call):\n return True", @@ -415,7 +414,6 @@ async def test_auto_load_multiple_policy_types(self, temp_cuga_folder): guide_content="## Test", tool_guards={ "test_tool": ToolGuard( - description="Guard loaded from filesystem", violating_examples=["bad_example()"], compliance_examples=["good_example()"], policy_code="def validate(call):\n return True", @@ -450,7 +448,7 @@ async def test_auto_load_multiple_policy_types(self, temp_cuga_folder): assert loaded_guide is not None assert loaded_guide["policy"].tool_guards is not None assert "test_tool" in loaded_guide["policy"].tool_guards - assert loaded_guide["policy"].tool_guards["test_tool"].description == "Guard loaded from filesystem" + # ToolGuard no longer has description field - it's derived from ToolGuide @pytest.mark.asyncio async def test_auto_load_disabled(self, temp_cuga_folder): diff --git a/src/cuga/sdk.py b/src/cuga/sdk.py index 5ba8949c..1887712d 100644 --- a/src/cuga/sdk.py +++ b/src/cuga/sdk.py @@ -546,7 +546,7 @@ async def update_tool_guard( Args: policy_id: ID of the existing Tool Guide policy to update - tool_guards: Dict of tool guards (key: tool_name, value: dict with 'description', 'violating_examples', 'compliance_examples', 'policy_code') + tool_guards: Dict of tool guards (key: tool_name, value: dict with 'violating_examples', 'compliance_examples', 'policy_code') Returns: Policy ID @@ -601,7 +601,6 @@ async def update_tool_guard( # Convert and update only the incoming tool_guards for tool_name, guard_data in tool_guards.items(): tool_guards_obj[tool_name] = ToolGuard( - description=guard_data.get("description", ""), violating_examples=guard_data.get("violating_examples", []), compliance_examples=guard_data.get("compliance_examples", []), policy_code=guard_data.get("policy_code", ""), From d053e193aa2e963aad2496a2fa987eaf2d02121c Mon Sep 17 00:00:00 2001 From: naamaz Date: Thu, 14 May 2026 10:32:24 +0300 Subject: [PATCH 24/24] fixes --- src/cuga/backend/cuga_graph/policy/models.py | 2 +- .../backend/cuga_graph/policy/tool_guard/tool_invoker.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/cuga/backend/cuga_graph/policy/models.py b/src/cuga/backend/cuga_graph/policy/models.py index 77b2be16..33066dcf 100644 --- a/src/cuga/backend/cuga_graph/policy/models.py +++ b/src/cuga/backend/cuga_graph/policy/models.py @@ -178,7 +178,7 @@ class ToolGuard(BaseModel): default="", description=( "Python code that validates tool usage compliance. " - "WARNING: This code is executed in a sandboxed environment using the toolguard library. " + "This code is executed in a sandboxed environment using the toolguard library. " "Only trusted administrators with manage access should be allowed to modify policy code. " "While sandboxed, policy code should still be reviewed for correctness and performance." ) diff --git a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py index ba73fd77..6118a839 100644 --- a/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py +++ b/src/cuga/backend/cuga_graph/policy/tool_guard/tool_invoker.py @@ -121,10 +121,9 @@ async def invoke( except ValueError: # Re-raise ValueError as-is (tool not found) raise + except asyncio.CancelledError: + raise except Exception as e: - # Re-raise CancelledError immediately to preserve cancellation - if isinstance(e, asyncio.CancelledError): - raise logger.error(f"Failed to invoke tool '{toolname}': {e}") raise RuntimeError( f"Tool invocation failed for '{toolname}': {str(e)}"