diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..dcd9293 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,74 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + lint: + name: Format & Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + cache: pip + + - name: Install linters + run: pip install black flake8 + + - name: black + run: black --check governs_ai/ tests/ + + - name: flake8 + run: flake8 governs_ai/ tests/ --max-line-length=88 --extend-ignore=E203,W503 + + typecheck: + name: Type Check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + cache: pip + + - name: Install package with dev extras + run: pip install -e ".[dev]" + + - name: mypy + run: mypy governs_ai/ --ignore-missing-imports + + test: + name: Tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + cache: pip + + - name: Install package with dev extras + run: pip install -e ".[dev]" + + - name: pytest + run: pytest tests/ -v --tb=short + + secret-scan: + name: Secret Scan + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: gitleaks/gitleaks-action@v2 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..40464cf --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,30 @@ +name: Publish to PyPI + +on: + push: + tags: + - "v*" + +jobs: + publish: + name: Build & Publish + runs-on: ubuntu-latest + environment: pypi + permissions: + id-token: write # OIDC trusted publishing + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: 
Install build tools + run: pip install build + + - name: Build distribution + run: python -m build + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/README.md b/README.md index 92a059f..43ff07f 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ from governs_ai import GovernsAIClient # Create client with organization context client = GovernsAIClient( api_key="your-api-key", - base_url="http://localhost:3002", + base_url="https://api.governsai.com", org_id="org-456" # Organization context (static) ) diff --git a/docs/README.md b/docs/README.md index 21a7e2f..6816a0d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -38,7 +38,7 @@ from governs_ai import GovernsAIClient # Create client with organization context client = GovernsAIClient( api_key="your-api-key", - base_url="http://localhost:3002", + base_url="https://api.governsai.com", org_id="org-456" # Organization context (static) ) @@ -70,7 +70,7 @@ elif precheck_response.decision == "confirm": ```bash export GOVERNS_API_KEY="your-api-key" -export GOVERNS_BASE_URL="http://localhost:3002" +export GOVERNS_BASE_URL="https://api.governsai.com" export GOVERNS_ORG_ID="org-456" export GOVERNS_TIMEOUT="30000" export GOVERNS_RETRIES="3" @@ -85,7 +85,7 @@ from governs_ai import GovernsAIClient, GovernsAIConfig # Explicit configuration config = GovernsAIConfig( api_key="your-api-key", - base_url="http://localhost:3002", + base_url="https://api.governsai.com", org_id="org-456", timeout=30000, retries=3, @@ -423,7 +423,7 @@ client = GovernsAIClient( ```python GovernsAIClient( api_key: str, - base_url: str = "http://localhost:3002", + base_url: str = "https://api.governsai.com", org_id: str, timeout: int = 30000, retries: int = 3, @@ -770,7 +770,7 @@ import os client = GovernsAIClient( api_key=os.getenv("GOVERNS_API_KEY"), org_id=os.getenv("GOVERNS_ORG_ID"), - base_url=os.getenv("GOVERNS_BASE_URL", "http://localhost:3002") + base_url=os.getenv("GOVERNS_BASE_URL", 
"https://api.governsai.com") ) ``` diff --git a/governs_ai/__init__.py b/governs_ai/__init__.py index e5ddfa1..ae55283 100644 --- a/governs_ai/__init__.py +++ b/governs_ai/__init__.py @@ -12,6 +12,8 @@ from .clients.budget import BudgetClient from .clients.tool import ToolClient from .clients.analytics import AnalyticsClient +from .clients.context import ContextClient +from .clients.documents import DocumentClient from .models import ( PrecheckRequest, PrecheckResponse, @@ -21,6 +23,25 @@ ConfirmationResponse, HealthStatus, Decision, + SaveContextInput, + SaveContextResponse, + ContextLLMResponse, + ConversationSummary, + ConversationItem, + MemoryRecord, + MemorySearchMetadata, + MemorySearchResponse, + ResolvedUserDetails, + ResolvedUser, + DocumentUploadResponse, + DocumentChunk, + DocumentRecord, + DocumentDetails, + DocumentListPagination, + DocumentListResponse, + DocumentSearchSource, + DocumentSearchResult, + DocumentSearchResponse, ) from .exceptions import ( GovernsAIError, @@ -45,6 +66,8 @@ "BudgetClient", "ToolClient", "AnalyticsClient", + "ContextClient", + "DocumentClient", # Data models "PrecheckRequest", "PrecheckResponse", @@ -54,6 +77,25 @@ "ConfirmationResponse", "HealthStatus", "Decision", + "SaveContextInput", + "SaveContextResponse", + "ContextLLMResponse", + "ConversationSummary", + "ConversationItem", + "MemoryRecord", + "MemorySearchMetadata", + "MemorySearchResponse", + "ResolvedUserDetails", + "ResolvedUser", + "DocumentUploadResponse", + "DocumentChunk", + "DocumentRecord", + "DocumentDetails", + "DocumentListPagination", + "DocumentListResponse", + "DocumentSearchSource", + "DocumentSearchResult", + "DocumentSearchResponse", # Exceptions "GovernsAIError", "PrecheckError", diff --git a/governs_ai/client.py b/governs_ai/client.py index b84afed..d00ba71 100644 --- a/governs_ai/client.py +++ b/governs_ai/client.py @@ -24,6 +24,8 @@ from .clients.budget import BudgetClient from .clients.tool import ToolClient from .clients.analytics import 
AnalyticsClient +from .clients.context import ContextClient +from .clients.documents import DocumentClient from .exceptions import GovernsAIError @@ -32,7 +34,7 @@ class GovernsAIConfig: """Configuration for GovernsAI client.""" api_key: str - base_url: str = "http://localhost:3002" + base_url: str = "https://api.governsai.com" org_id: str = "" timeout: int = 30000 retries: int = 3 @@ -83,7 +85,7 @@ def __init__( else: self.config = GovernsAIConfig( api_key=api_key or os.getenv("GOVERNS_API_KEY", ""), - base_url=base_url or os.getenv("GOVERNS_BASE_URL", "http://localhost:3002"), + base_url=base_url or os.getenv("GOVERNS_BASE_URL", "https://api.governsai.com"), org_id=org_id or os.getenv("GOVERNS_ORG_ID", ""), timeout=timeout or int(os.getenv("GOVERNS_TIMEOUT", "30000")), retries=retries or int(os.getenv("GOVERNS_RETRIES", "3")), @@ -116,6 +118,8 @@ def __init__( self.budget = BudgetClient(self.http_client, self.logger) self.tools = ToolClient(self.http_client, self.logger) self.analytics = AnalyticsClient(self.http_client, self.logger) + self.context = ContextClient(self.http_client, self.logger) + self.documents = DocumentClient(self.http_client, self.logger) @classmethod def from_env(cls) -> "GovernsAIClient": diff --git a/governs_ai/clients/__init__.py b/governs_ai/clients/__init__.py index 777ab9f..c0c8432 100644 --- a/governs_ai/clients/__init__.py +++ b/governs_ai/clients/__init__.py @@ -9,6 +9,8 @@ from .budget import BudgetClient from .tool import ToolClient from .analytics import AnalyticsClient +from .context import ContextClient +from .documents import DocumentClient __all__ = [ "PrecheckClient", @@ -16,4 +18,6 @@ "BudgetClient", "ToolClient", "AnalyticsClient", + "ContextClient", + "DocumentClient", ] diff --git a/governs_ai/clients/context.py b/governs_ai/clients/context.py new file mode 100644 index 0000000..4b67eba --- /dev/null +++ b/governs_ai/clients/context.py @@ -0,0 +1,319 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2024 GovernsAI. 
class ContextClient:
    """Client for context memory and retrieval APIs.

    Wraps the ``/api/v1/context`` and ``/api/v1/users`` endpoints.
    Dict inputs with snake_case keys are normalized to the camelCase
    keys the HTTP API expects; all transport failures are re-raised as
    ``GovernsAIError`` with the original exception chained as the cause.
    """

    # snake_case -> camelCase key mapping applied to dict save payloads.
    # A snake key is renamed only when the camel variant is absent, so
    # explicit camelCase input always wins.
    _CAMEL_KEYS = {
        "content_type": "contentType",
        "agent_id": "agentId",
        "agent_name": "agentName",
        "conversation_id": "conversationId",
        "parent_id": "parentId",
        "correlation_id": "correlationId",
        "expires_at": "expiresAt",
    }

    def __init__(self, http_client: HTTPClient, logger: GovernsAILogger):
        """Store the shared HTTP client and logger used by all calls."""
        self.http_client = http_client
        self.logger = logger

    async def save_context_explicit(
        self,
        input_data: Union[SaveContextInput, Dict[str, Any]],
    ) -> SaveContextResponse:
        """Save context explicitly without server-side precheck inference.

        Args:
            input_data: ``SaveContextInput`` instance or an equivalent dict
                (snake_case keys are converted to camelCase).

        Returns:
            ``SaveContextResponse`` carrying the new context id.

        Raises:
            GovernsAIError: If the request fails (original error chained).
        """
        payload = self._to_save_context_payload(input_data)
        try:
            response = await self.http_client.post("/api/v1/context", data=payload)
            return SaveContextResponse.from_dict(response.data)
        except Exception as exc:
            self.logger.error(f"Save context failed: {exc}")
            # Chain the cause so callers can inspect the underlying failure.
            raise GovernsAIError(f"Save context failed: {exc}") from exc

    async def store_context(
        self,
        input_data: Union[SaveContextInput, Dict[str, Any]],
    ) -> SaveContextResponse:
        """Store context (server may apply precheck logic).

        Currently delegates to :meth:`save_context_explicit`.
        """
        return await self.save_context_explicit(input_data)

    async def search_context_llm(self, input_data: Dict[str, Any]) -> ContextLLMResponse:
        """Search context using the LLM-optimized compressed response format.

        Raises:
            GovernsAIError: If the request fails (original error chained).
        """
        try:
            response = await self.http_client.post("/api/v1/context/search/llm", data=input_data)
            return ContextLLMResponse.from_dict(response.data)
        except Exception as exc:
            self.logger.error(f"Context search failed: {exc}")
            raise GovernsAIError(f"Context search failed: {exc}") from exc

    async def search_cross_agent(
        self,
        query: str,
        limit: Optional[int] = None,
        threshold: Optional[float] = None,
        scope: Optional[str] = None,
    ) -> ContextLLMResponse:
        """Search across agents using the LLM-optimized response format.

        Only filters that were actually provided are forwarded to the API.
        """
        payload: Dict[str, Any] = {"query": query}
        if limit is not None:
            payload["limit"] = limit
        if threshold is not None:
            payload["threshold"] = threshold
        if scope:
            payload["scope"] = scope
        return await self.search_context_llm(payload)

    async def get_or_create_conversation(self, input_data: Dict[str, Any]) -> ConversationSummary:
        """Get or create a conversation for a given agent.

        Raises:
            GovernsAIError: If the request fails (original error chained).
        """
        try:
            response = await self.http_client.post("/api/v1/context/conversation", data=input_data)
            return ConversationSummary.from_dict(response.data)
        except Exception as exc:
            self.logger.error(f"Get or create conversation failed: {exc}")
            raise GovernsAIError(f"Get or create conversation failed: {exc}") from exc

    async def get_conversation_context(
        self,
        conversation_id: str,
        agent_id: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[ConversationItem]:
        """Get context entries for a conversation.

        Non-dict entries in the response payload are silently skipped.

        Raises:
            GovernsAIError: If the request fails (original error chained).
        """
        params: Dict[str, Any] = {}
        if agent_id:
            params["agentId"] = agent_id
        if limit is not None:
            params["limit"] = limit
        suffix = f"?{urlencode(params)}" if params else ""

        try:
            endpoint = f"/api/v1/context/conversation/{conversation_id}{suffix}"
            response = await self.http_client.get(endpoint)
            contexts = response.data.get("contexts", []) if isinstance(response.data, dict) else []
            return [
                ConversationItem.from_dict(item)
                for item in contexts
                if isinstance(item, dict)
            ]
        except Exception as exc:
            self.logger.error(f"Get conversation context failed: {exc}")
            raise GovernsAIError(f"Get conversation context failed: {exc}") from exc

    async def get_recent_context(
        self,
        user_id: Optional[str] = None,
        limit: int = 20,
        scope: Optional[str] = None,
    ) -> ContextLLMResponse:
        """Retrieve recent context using the LLM response shape.

        Implemented as an LLM search for the fixed query "recent context".
        """
        payload: Dict[str, Any] = {
            "query": "recent context",
            "limit": limit,
        }
        if user_id:
            payload["userId"] = user_id
        if scope:
            payload["scope"] = scope
        return await self.search_context_llm(payload)

    async def maybe_save_from_precheck(
        self,
        precheck: Union[PrecheckResponse, Dict[str, Any]],
        agent_id: str,
        fallback_content: Optional[str] = None,
        agent_name: Optional[str] = None,
        conversation_id: Optional[str] = None,
        correlation_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        scope: Optional[str] = None,
        visibility: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Best-effort helper to persist context from precheck suggestions.

        Saves only when the precheck payload signals intent (``intent.save``
        is True) or carries a ``context.save`` suggested action. Never
        raises: any failure results in ``{"saved": False}``.

        Returns:
            ``{"saved": True, "contextId": ...}`` on success, otherwise
            ``{"saved": False}``.
        """
        payload = precheck.to_dict() if isinstance(precheck, PrecheckResponse) else precheck
        if not isinstance(payload, dict):
            return {"saved": False}

        intent = payload.get("intent") or {}
        has_intent = isinstance(intent, dict) and intent.get("save") is True

        # Look for an explicit "context.save" suggested action.
        suggested_actions = payload.get("suggestedActions") or []
        save_action: Optional[Dict[str, Any]] = None
        if isinstance(suggested_actions, list):
            for action in suggested_actions:
                if isinstance(action, dict) and action.get("type") == "context.save":
                    save_action = action
                    break

        if not has_intent and save_action is None:
            return {"saved": False}

        # Content resolution order: action content > caller fallback >
        # message text extracted from the precheck payload itself.
        content = None
        if isinstance(save_action, dict):
            content = save_action.get("content")
        if not content:
            content = fallback_content
        if not content:
            content = self._extract_content_from_precheck(payload)

        if not content:
            return {"saved": False}

        # Caller-supplied metadata overrides action metadata on key clash.
        combined_metadata: Dict[str, Any] = {}
        if isinstance(save_action, dict) and isinstance(save_action.get("metadata"), dict):
            combined_metadata.update(save_action["metadata"])
        if isinstance(metadata, dict):
            combined_metadata.update(metadata)

        save_input = SaveContextInput(
            content=content,
            content_type="user_message",
            agent_id=agent_id,
            agent_name=agent_name,
            conversation_id=conversation_id,
            correlation_id=correlation_id,
            metadata=combined_metadata or None,
            scope=scope,
            visibility=visibility,
        )

        try:
            saved = await self.store_context(save_input)
            return {"saved": True, "contextId": saved.context_id}
        except Exception as exc:
            # Best-effort semantics: log and report not-saved, never raise.
            self.logger.error(f"maybe_save_from_precheck failed: {exc}")
            return {"saved": False}

    async def store_memory(self, params: Dict[str, Any]) -> SaveContextResponse:
        """Store memory for an external user identity.

        Applies defaults (contentType="user_message", agentId="external-app",
        externalSource="default", scope="user", visibility="private") and
        drops None-valued fields before sending.

        Raises:
            GovernsAIError: If the request fails (original error chained).
        """
        payload: Dict[str, Any] = {
            "content": params.get("content"),
            "contentType": params.get("contentType", "user_message"),
            "agentId": params.get("agentId", "external-app"),
            "agentName": params.get("agentName"),
            "externalUserId": params.get("externalUserId"),
            "externalSource": params.get("externalSource", "default"),
            "metadata": params.get("metadata"),
            "scope": params.get("scope", "user"),
            "visibility": params.get("visibility", "private"),
            "email": params.get("email"),
            "name": params.get("name"),
        }
        payload = {key: value for key, value in payload.items() if value is not None}

        try:
            response = await self.http_client.post("/api/v1/context", data=payload)
            return SaveContextResponse.from_dict(response.data)
        except Exception as exc:
            self.logger.error(f"Store memory failed: {exc}")
            raise GovernsAIError(f"Store memory failed: {exc}") from exc

    async def search_memory(self, params: Dict[str, Any]) -> MemorySearchResponse:
        """Search memory for an external user identity.

        Applies defaults (externalSource="default", limit=10, threshold=0.5,
        scope="user") and drops None-valued fields before sending.

        Raises:
            GovernsAIError: If the request fails (original error chained).
        """
        payload: Dict[str, Any] = {
            "query": params.get("query"),
            "externalUserId": params.get("externalUserId"),
            "externalSource": params.get("externalSource", "default"),
            "limit": params.get("limit", 10),
            "threshold": params.get("threshold", 0.5),
            "scope": params.get("scope", "user"),
            "agentId": params.get("agentId"),
            "contentTypes": params.get("contentTypes"),
        }
        payload = {key: value for key, value in payload.items() if value is not None}

        try:
            response = await self.http_client.post("/api/v1/context/search", data=payload)
            return MemorySearchResponse.from_dict(response.data)
        except Exception as exc:
            self.logger.error(f"Search memory failed: {exc}")
            raise GovernsAIError(f"Search memory failed: {exc}") from exc

    async def resolve_user(self, params: Dict[str, Any]) -> ResolvedUser:
        """Resolve an external user identifier to a GovernsAI internal user.

        Raises:
            GovernsAIError: If the request fails (original error chained).
        """
        payload: Dict[str, Any] = {
            "externalUserId": params.get("externalUserId"),
            "externalSource": params.get("externalSource", "default"),
            "email": params.get("email"),
            "name": params.get("name"),
        }
        payload = {key: value for key, value in payload.items() if value is not None}

        try:
            response = await self.http_client.post("/api/v1/users/resolve", data=payload)
            return ResolvedUser.from_dict(response.data)
        except Exception as exc:
            self.logger.error(f"Resolve user failed: {exc}")
            raise GovernsAIError(f"Resolve user failed: {exc}") from exc

    async def get_user_by_external_id(
        self,
        external_user_id: str,
        external_source: str = "default",
    ) -> Optional[ResolvedUserDetails]:
        """Look up an external user without auto-creating a user.

        Returns:
            ``ResolvedUserDetails`` if the user exists, else ``None``
            (the API signals "not found" with a 404 status).

        Raises:
            GovernsAIError: For any non-404 failure (original error chained).
        """
        query = urlencode({"externalUserId": external_user_id, "externalSource": external_source})

        try:
            response = await self.http_client.get(f"/api/v1/users/resolve?{query}")
            user_payload = response.data.get("user") if isinstance(response.data, dict) else None
            if not isinstance(user_payload, dict):
                return None
            return ResolvedUserDetails.from_dict(user_payload)
        except Exception as exc:
            # API returns 404 when no user exists for the external id.
            if getattr(exc, "status_code", None) == 404:
                return None
            self.logger.error(f"Get user by external id failed: {exc}")
            raise GovernsAIError(f"Get user by external id failed: {exc}") from exc

    def _to_save_context_payload(
        self, input_data: Union[SaveContextInput, Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Normalize save-context input into the API's camelCase payload.

        Dict input is shallow-copied; snake_case keys are renamed via
        ``_CAMEL_KEYS`` only when the camelCase variant is not already
        present. Non-dict input is assumed to serialize itself.
        """
        if isinstance(input_data, dict):
            payload = dict(input_data)
            for snake, camel in self._CAMEL_KEYS.items():
                if snake in payload and camel not in payload:
                    payload[camel] = payload.pop(snake)
            return payload
        # SaveContextInput knows its own camelCase serialization.
        return input_data.to_dict()

    def _extract_content_from_precheck(self, payload: Dict[str, Any]) -> Optional[str]:
        """Join message texts from a precheck payload, or None if absent.

        Expects ``payload["content"]["messages"]`` to be a list of dicts;
        only string ``content`` values are collected.
        """
        content = payload.get("content")
        if not isinstance(content, dict):
            return None

        messages = content.get("messages")
        if not isinstance(messages, list):
            return None

        fragments: List[str] = []
        for message in messages:
            if isinstance(message, dict) and isinstance(message.get("content"), str):
                fragments.append(message["content"])

        if not fragments:
            return None
        return "\n".join(fragments)
class DocumentClient:
    """Client for document APIs (upload, retrieval, search, deletion).

    All transport failures are re-raised as ``GovernsAIError`` with the
    original exception chained as the cause.
    """

    def __init__(self, http_client: HTTPClient, logger: GovernsAILogger):
        """Store the shared HTTP client and logger used by all calls."""
        self.http_client = http_client
        self.logger = logger

    async def upload_document(
        self,
        file: Union[bytes, bytearray, memoryview, str, Any],
        filename: Optional[str] = None,
        content_type: Optional[str] = None,
        external_user_id: Optional[str] = None,
        external_source: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        scope: Optional[str] = None,
        visibility: Optional[str] = None,
        email: Optional[str] = None,
        name: Optional[str] = None,
        processing_mode: Optional[str] = None,
    ) -> DocumentUploadResponse:
        """Upload a document for OCR and RAG indexing.

        Args:
            file: Raw bytes, a filesystem path, or a file-like object
                (see :meth:`_normalize_file` for accepted forms).
            filename: Overrides the inferred filename when given.
            metadata: Serialized to JSON since multipart fields are strings.
            (Remaining parameters are attached as form fields only when
            provided/truthy.)

        Returns:
            ``DocumentUploadResponse`` for the created document.

        Raises:
            GovernsAIError: If the upload fails (original error chained),
                or if ``file`` cannot be normalized.
        """
        form_data = aiohttp.FormData()
        file_payload, inferred_name = self._normalize_file(file, filename)

        form_data.add_field(
            "file",
            file_payload,
            filename=inferred_name,
            content_type=content_type,
        )
        form_data.add_field("filename", inferred_name)

        if content_type:
            form_data.add_field("contentType", content_type)
        if external_user_id:
            form_data.add_field("externalUserId", external_user_id)
        if external_source:
            form_data.add_field("externalSource", external_source)
        if metadata is not None:
            form_data.add_field("metadata", json.dumps(metadata))
        if scope:
            form_data.add_field("scope", scope)
        if visibility:
            form_data.add_field("visibility", visibility)
        if email:
            form_data.add_field("email", email)
        if name:
            form_data.add_field("name", name)
        if processing_mode:
            form_data.add_field("processingMode", processing_mode)

        try:
            response = await self.http_client.post_form_data("/api/v1/documents", form_data)
            return DocumentUploadResponse.from_dict(response.data)
        except Exception as exc:
            self.logger.error(f"Document upload failed: {exc}")
            # Chain the cause so callers can inspect the underlying failure.
            raise GovernsAIError(f"Document upload failed: {exc}") from exc

    async def get_document(
        self,
        document_id: str,
        include_chunks: Optional[bool] = None,
        include_content: Optional[bool] = None,
    ) -> DocumentDetails:
        """Get a document and optionally include chunks/content.

        Raises:
            GovernsAIError: If the request fails (original error chained).
        """
        params: Dict[str, Any] = {}
        if include_chunks is not None:
            params["includeChunks"] = include_chunks
        if include_content is not None:
            params["includeContent"] = include_content

        try:
            response = await self.http_client.get(f"/api/v1/documents/{document_id}", params=params)
            # Some responses wrap the record under "document"; others are flat.
            payload = response.data.get("document", response.data)
            return DocumentDetails.from_dict(payload)
        except Exception as exc:
            self.logger.error(f"Get document failed: {exc}")
            raise GovernsAIError(f"Get document failed: {exc}") from exc

    async def list_documents(
        self,
        user_id: Optional[str] = None,
        external_user_id: Optional[str] = None,
        external_source: Optional[str] = None,
        status: Optional[str] = None,
        content_type: Optional[str] = None,
        include_archived: Optional[bool] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
    ) -> DocumentListResponse:
        """List documents with optional filters; None filters are dropped.

        Raises:
            GovernsAIError: If the request fails (original error chained).
        """
        params = {
            "userId": user_id,
            "externalUserId": external_user_id,
            "externalSource": external_source,
            "status": status,
            "contentType": content_type,
            "includeArchived": include_archived,
            "limit": limit,
            "offset": offset,
        }
        params = {key: value for key, value in params.items() if value is not None}

        try:
            response = await self.http_client.get("/api/v1/documents", params=params)
            return DocumentListResponse.from_dict(response.data)
        except Exception as exc:
            self.logger.error(f"List documents failed: {exc}")
            raise GovernsAIError(f"List documents failed: {exc}") from exc

    async def search_documents(self, params: Dict[str, Any]) -> DocumentSearchResponse:
        """Vector-search across document chunks.

        Raises:
            GovernsAIError: If the request fails (original error chained).
        """
        try:
            response = await self.http_client.post("/api/v1/documents/search", data=params)
            return DocumentSearchResponse.from_dict(response.data)
        except Exception as exc:
            self.logger.error(f"Search documents failed: {exc}")
            raise GovernsAIError(f"Search documents failed: {exc}") from exc

    async def delete_document(self, document_id: str) -> Dict[str, Any]:
        """Delete a document and associated chunks; returns the raw payload.

        Raises:
            GovernsAIError: If the request fails (original error chained).
        """
        try:
            response = await self.http_client.delete(f"/api/v1/documents/{document_id}")
            return response.data
        except Exception as exc:
            self.logger.error(f"Delete document failed: {exc}")
            raise GovernsAIError(f"Delete document failed: {exc}") from exc

    def _normalize_file(
        self,
        file: Union[bytes, bytearray, memoryview, str, Any],
        filename: Optional[str],
    ) -> Tuple[bytes, str]:
        """Normalize supported file inputs to (bytes, filename).

        Accepted forms, in check order:
        - str: treated as a filesystem path; read in binary mode.
        - bytes/bytearray/memoryview: used as-is; filename defaults to
          "document" when not given.
        - file-like object with ``read()``: str results are UTF-8 encoded;
          the filename falls back to the object's ``name`` attribute.

        Raises:
            GovernsAIError: For a missing path, a file-like object whose
                ``read()`` does not yield bytes/str, or an unsupported type.
        """
        if isinstance(file, str):
            if not os.path.isfile(file):
                raise GovernsAIError(f"Document path does not exist: {file}")
            with open(file, "rb") as file_handle:
                content = file_handle.read()
            inferred_name = filename or os.path.basename(file)
            return content, inferred_name

        if isinstance(file, (bytes, bytearray, memoryview)):
            return bytes(file), filename or "document"

        if hasattr(file, "read") and callable(file.read):
            content = file.read()
            if isinstance(content, str):
                content = content.encode("utf-8")
            if not isinstance(content, (bytes, bytearray, memoryview)):
                raise GovernsAIError("File-like object must return bytes from read()")

            inferred_name = filename or getattr(file, "name", None) or "document"
            # basename guards against path-like .name attributes.
            return bytes(content), os.path.basename(str(inferred_name))

        raise GovernsAIError(
            "Unsupported file type. Use bytes, bytearray, memoryview, file path, or file-like object."
        )
def _as_int(data: Dict[str, Any], camel: str, snake: str, default: int = 0) -> int:
    """Read an int field accepting camelCase or snake_case keys.

    Treats an explicit JSON null (Python None) the same as a missing key,
    returning ``default`` instead of raising ``TypeError``.
    """
    value = data.get(camel, data.get(snake, default))
    return int(value) if value is not None else default


@dataclass
class SaveContextInput:
    """Input payload for creating context records."""

    content: str
    content_type: str
    agent_id: str
    agent_name: Optional[str] = None
    conversation_id: Optional[str] = None
    parent_id: Optional[str] = None
    correlation_id: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
    scope: Optional[str] = None
    visibility: Optional[str] = None
    expires_at: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the API's camelCase payload, dropping None fields."""
        payload: Dict[str, Any] = {
            "content": self.content,
            "contentType": self.content_type,
            "agentId": self.agent_id,
            "agentName": self.agent_name,
            "conversationId": self.conversation_id,
            "parentId": self.parent_id,
            "correlationId": self.correlation_id,
            "metadata": self.metadata,
            "scope": self.scope,
            "visibility": self.visibility,
            "expiresAt": self.expires_at,
        }
        return {key: value for key, value in payload.items() if value is not None}


@dataclass
class SaveContextResponse:
    """Response payload for context writes."""

    context_id: str

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SaveContextResponse":
        """Build from an API payload; accepts camelCase or snake_case."""
        return cls(context_id=data.get("contextId", data.get("context_id", "")))


@dataclass
class ContextLLMResponse:
    """LLM-optimized context search response."""

    success: bool
    context: str
    memory_count: int
    high_confidence: int
    medium_confidence: int
    low_confidence: int
    token_estimate: int

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ContextLLMResponse":
        """Build from an API payload; missing/null counts default to 0."""
        return cls(
            success=bool(data.get("success", False)),
            context=data.get("context", ""),
            memory_count=_as_int(data, "memoryCount", "memory_count"),
            high_confidence=_as_int(data, "highConfidence", "high_confidence"),
            medium_confidence=_as_int(data, "mediumConfidence", "medium_confidence"),
            low_confidence=_as_int(data, "lowConfidence", "low_confidence"),
            token_estimate=_as_int(data, "tokenEstimate", "token_estimate"),
        )


@dataclass
class ConversationSummary:
    """Conversation metadata summary."""

    id: str
    message_count: int
    token_count: int
    scope: str
    title: Optional[str] = None
    last_message_at: Optional[str] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ConversationSummary":
        """Build from an API payload; scope defaults to "user"."""
        return cls(
            id=data.get("id", ""),
            title=data.get("title"),
            message_count=_as_int(data, "messageCount", "message_count"),
            token_count=_as_int(data, "tokenCount", "token_count"),
            last_message_at=data.get("lastMessageAt", data.get("last_message_at")),
            scope=data.get("scope", "user"),
        )


@dataclass
class ConversationItem:
    """Single context item in a conversation."""

    id: str
    content: str
    content_type: str
    created_at: str
    agent_id: Optional[str] = None
    parent_id: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ConversationItem":
        """Build from an API payload; accepts camelCase or snake_case."""
        return cls(
            id=data.get("id", ""),
            content=data.get("content", ""),
            content_type=data.get("contentType", data.get("content_type", "")),
            created_at=data.get("createdAt", data.get("created_at", "")),
            agent_id=data.get("agentId", data.get("agent_id")),
            parent_id=data.get("parentId", data.get("parent_id")),
            metadata=data.get("metadata"),
        )


@dataclass
class MemoryRecord:
    """External memory search result item."""

    id: str
    content: str
    content_type: str
    created_at: str
    summary: Optional[str] = None
    agent_id: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
    similarity: Optional[float] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MemoryRecord":
        """Build from an API payload; similarity stays None when absent."""
        similarity = data.get("similarity")
        return cls(
            id=data.get("id", ""),
            content=data.get("content", ""),
            content_type=data.get("contentType", data.get("content_type", "")),
            created_at=data.get("createdAt", data.get("created_at", "")),
            summary=data.get("summary"),
            agent_id=data.get("agentId", data.get("agent_id")),
            metadata=data.get("metadata"),
            similarity=float(similarity) if similarity is not None else None,
        )


@dataclass
class MemorySearchMetadata:
    """Metadata included with memory search responses."""

    high_confidence: int = 0
    medium_confidence: int = 0
    low_confidence: int = 0
    token_estimate: int = 0

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MemorySearchMetadata":
        """Build from an API payload; missing/null counts default to 0."""
        return cls(
            high_confidence=_as_int(data, "highConfidence", "high_confidence"),
            medium_confidence=_as_int(data, "mediumConfidence", "medium_confidence"),
            low_confidence=_as_int(data, "lowConfidence", "low_confidence"),
            token_estimate=_as_int(data, "tokenEstimate", "token_estimate"),
        )


@dataclass
class MemorySearchResponse:
    """Response payload for external memory search."""

    success: bool
    memories: List[MemoryRecord] = field(default_factory=list)
    count: int = 0
    metadata: Optional[MemorySearchMetadata] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MemorySearchResponse":
        """Build from an API payload.

        Tolerates ``"memories": null`` and a null ``count`` (falls back to
        the number of parsed memories); non-dict memory entries are skipped.
        """
        raw_memories = data.get("memories") or []
        raw_metadata = data.get("metadata")
        raw_count = data.get("count")
        return cls(
            success=bool(data.get("success", False)),
            memories=[MemoryRecord.from_dict(item) for item in raw_memories if isinstance(item, dict)],
            count=int(raw_count) if raw_count is not None else len(raw_memories),
            metadata=MemorySearchMetadata.from_dict(raw_metadata)
            if isinstance(raw_metadata, dict)
            else None,
        )


@dataclass
class ResolvedUserDetails:
    """User details returned by resolve endpoints."""

    id: str
    email: str
    name: Optional[str]
    external_id: Optional[str]
    external_source: Optional[str]

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ResolvedUserDetails":
        """Build from an API payload; accepts camelCase or snake_case."""
        return cls(
            id=data.get("id", ""),
            email=data.get("email", ""),
            name=data.get("name"),
            external_id=data.get("externalId", data.get("external_id")),
            external_source=data.get("externalSource", data.get("external_source")),
        )


@dataclass
class ResolvedUser:
    """Resolved user wrapper for external user mapping."""

    internal_user_id: str
    created: bool
    user: ResolvedUserDetails

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ResolvedUser":
        """Build from an API payload; a missing "user" yields empty details."""
        user_payload = data.get("user") if isinstance(data.get("user"), dict) else {}
        return cls(
            internal_user_id=data.get("internalUserId", data.get("internal_user_id", "")),
            created=bool(data.get("created", False)),
            user=ResolvedUserDetails.from_dict(user_payload),
        )
+""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +@dataclass +class DocumentUploadResponse: + """Response payload for document uploads.""" + + success: bool + document_id: str + status: str + chunk_count: int + file_hash: str + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DocumentUploadResponse": + return cls( + success=bool(data.get("success", False)), + document_id=data.get("documentId", data.get("document_id", "")), + status=data.get("status", "processing"), + chunk_count=int(data.get("chunkCount", data.get("chunk_count", 0))), + file_hash=data.get("fileHash", data.get("file_hash", "")), + ) + + +@dataclass +class DocumentChunk: + """Single chunk generated from OCR/indexing.""" + + id: str + chunk_index: int + content: str + metadata: Dict[str, Any] + created_at: str + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DocumentChunk": + return cls( + id=data.get("id", ""), + chunk_index=int(data.get("chunkIndex", data.get("chunk_index", 0))), + content=data.get("content", ""), + metadata=data.get("metadata", {}), + created_at=data.get("createdAt", data.get("created_at", "")), + ) + + +@dataclass +class DocumentRecord: + """Base document record returned by list/detail endpoints.""" + + id: str + user_id: str + org_id: str + filename: str + content_type: str + file_size: int + file_hash: str + status: str + chunk_count: int + scope: str + visibility: str + is_archived: bool + created_at: str + updated_at: str + external_user_id: Optional[str] = None + external_source: Optional[str] = None + storage_url: Optional[str] = None + error_message: Optional[str] = None + expires_at: Optional[str] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DocumentRecord": + return cls( + id=data.get("id", ""), + user_id=data.get("userId", data.get("user_id", "")), + org_id=data.get("orgId", data.get("org_id", "")), + external_user_id=data.get("externalUserId", 
data.get("external_user_id")), + external_source=data.get("externalSource", data.get("external_source")), + filename=data.get("filename", ""), + content_type=data.get("contentType", data.get("content_type", "")), + file_size=int(data.get("fileSize", data.get("file_size", 0))), + file_hash=data.get("fileHash", data.get("file_hash", "")), + storage_url=data.get("storageUrl", data.get("storage_url")), + status=data.get("status", "processing"), + error_message=data.get("errorMessage", data.get("error_message")), + chunk_count=int(data.get("chunkCount", data.get("chunk_count", 0))), + scope=data.get("scope", "user"), + visibility=data.get("visibility", "private"), + is_archived=bool(data.get("isArchived", data.get("is_archived", False))), + created_at=data.get("createdAt", data.get("created_at", "")), + updated_at=data.get("updatedAt", data.get("updated_at", "")), + expires_at=data.get("expiresAt", data.get("expires_at")), + ) + + +@dataclass +class DocumentDetails(DocumentRecord): + """Expanded document details including content/chunks.""" + + content: Optional[str] = None + chunks: List[DocumentChunk] = field(default_factory=list) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DocumentDetails": + base = DocumentRecord.from_dict(data) + raw_chunks = data.get("chunks", []) + return cls( + **base.__dict__, + content=data.get("content"), + chunks=[DocumentChunk.from_dict(chunk) for chunk in raw_chunks if isinstance(chunk, dict)], + ) + + +@dataclass +class DocumentListPagination: + """Pagination metadata for document list responses.""" + + total: int + limit: int + offset: int + has_more: bool + total_pages: int + current_page: int + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DocumentListPagination": + return cls( + total=int(data.get("total", 0)), + limit=int(data.get("limit", 0)), + offset=int(data.get("offset", 0)), + has_more=bool(data.get("hasMore", data.get("has_more", False))), + total_pages=int(data.get("totalPages", 
data.get("total_pages", 0))), + current_page=int(data.get("currentPage", data.get("current_page", 0))), + ) + + +@dataclass +class DocumentListResponse: + """List documents response payload.""" + + success: bool + documents: List[DocumentRecord] + pagination: DocumentListPagination + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DocumentListResponse": + documents = [ + DocumentRecord.from_dict(item) + for item in data.get("documents", []) + if isinstance(item, dict) + ] + pagination_data = data.get("pagination") if isinstance(data.get("pagination"), dict) else {} + return cls( + success=bool(data.get("success", False)), + documents=documents, + pagination=DocumentListPagination.from_dict(pagination_data), + ) + + +@dataclass +class DocumentSearchSource: + """Document metadata included with each search hit.""" + + filename: str + content_type: str + user_id: str + created_at: str + external_user_id: Optional[str] = None + external_source: Optional[str] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DocumentSearchSource": + return cls( + filename=data.get("filename", ""), + content_type=data.get("contentType", data.get("content_type", "")), + user_id=data.get("userId", data.get("user_id", "")), + created_at=data.get("createdAt", data.get("created_at", "")), + external_user_id=data.get("externalUserId", data.get("external_user_id")), + external_source=data.get("externalSource", data.get("external_source")), + ) + + +@dataclass +class DocumentSearchResult: + """Single vector-search result across document chunks.""" + + document_id: str + chunk_id: str + chunk_index: int + content: str + similarity: float + metadata: Dict[str, Any] + document: DocumentSearchSource + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DocumentSearchResult": + document_payload = data.get("document") if isinstance(data.get("document"), dict) else {} + return cls( + document_id=data.get("documentId", data.get("document_id", "")), + 
chunk_id=data.get("chunkId", data.get("chunk_id", "")), + chunk_index=int(data.get("chunkIndex", data.get("chunk_index", 0))), + content=data.get("content", ""), + similarity=float(data.get("similarity", 0.0)), + metadata=data.get("metadata", {}), + document=DocumentSearchSource.from_dict(document_payload), + ) + + +@dataclass +class DocumentSearchResponse: + """Vector-search response payload for documents.""" + + success: bool + results: List[DocumentSearchResult] + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DocumentSearchResponse": + return cls( + success=bool(data.get("success", False)), + results=[ + DocumentSearchResult.from_dict(item) + for item in data.get("results", []) + if isinstance(item, dict) + ], + ) diff --git a/governs_ai/utils/http.py b/governs_ai/utils/http.py index 36b050c..a371b28 100644 --- a/governs_ai/utils/http.py +++ b/governs_ai/utils/http.py @@ -6,7 +6,7 @@ import aiohttp import asyncio -from typing import Dict, Any, Optional, Union +from typing import Dict, Any, Optional from dataclasses import dataclass from ..exceptions.base import NetworkError, AuthenticationError, AuthorizationError, RateLimitError @@ -71,13 +71,18 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): """Async context manager exit.""" await self.close() - def _get_headers(self, additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: + def _get_headers( + self, + additional_headers: Optional[Dict[str, str]] = None, + content_type: Optional[str] = "application/json", + ) -> Dict[str, str]: """Get default headers for requests.""" headers = { "X-Governs-Key": self.api_key, - "Content-Type": "application/json", "User-Agent": "governs-ai-python-sdk/1.0.0", } + if content_type is not None: + headers["Content-Type"] = content_type if additional_headers: headers.update(additional_headers) return headers @@ -99,28 +104,33 @@ def _handle_response_error(self, response: HTTPResponse) -> None: ) elif response.is_client_error: error_msg = 
response.data.get("message", "Client error") - raise NetworkError(f"Client error: {error_msg}") + raise NetworkError(f"Client error: {error_msg}", status_code=response.status_code) elif response.is_server_error: error_msg = response.data.get("message", "Server error") - raise NetworkError(f"Server error: {error_msg}") + raise NetworkError(f"Server error: {error_msg}", status_code=response.status_code) async def request( self, method: str, endpoint: str, data: Optional[Dict[str, Any]] = None, + form_data: Optional[aiohttp.FormData] = None, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, ) -> HTTPResponse: """Make HTTP request.""" url = f"{self.base_url}/{endpoint.lstrip('/')}" - request_headers = self._get_headers(headers) + request_headers = self._get_headers( + headers, + content_type=None if form_data is not None else "application/json", + ) try: async with self.session.request( method=method, url=url, - json=data, + json=data if form_data is None else None, + data=form_data, params=params, headers=request_headers, ) as response: @@ -167,6 +177,22 @@ async def post( """Make POST request.""" return await self.request("POST", endpoint, data=data, params=params, headers=headers) + async def post_form_data( + self, + endpoint: str, + form_data: aiohttp.FormData, + params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + ) -> HTTPResponse: + """Make POST request with multipart/form-data payload.""" + return await self.request( + "POST", + endpoint, + form_data=form_data, + params=params, + headers=headers, + ) + async def put( self, endpoint: str, diff --git a/setup.py b/setup.py deleted file mode 100644 index 59272b3..0000000 --- a/setup.py +++ /dev/null @@ -1,45 +0,0 @@ -from setuptools import setup, find_packages - -with open("README.md", "r", encoding="utf-8") as fh: - long_description = fh.read() - -setup( - name="governs-ai-sdk", - version="1.0.0", - author="GovernsAI", - 
author_email="support@governs.ai", - description="Python SDK for GovernsAI - AI governance and compliance platform", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://github.com/governs-ai/python-sdk", - packages=find_packages(), - classifiers=[ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - ], - python_requires=">=3.8", - install_requires=[ - "requests>=2.25.0", - "pydantic>=1.8.0", - "typing-extensions>=3.10.0", - "aiohttp>=3.8.0", - "asyncio-throttle>=1.0.0", - ], - extras_require={ - "dev": [ - "pytest>=6.0", - "pytest-asyncio>=0.18.0", - "black>=22.0", - "flake8>=4.0", - "mypy>=0.950", - ], - }, -) diff --git a/test_imports.py b/test_imports.py deleted file mode 100644 index c230dc3..0000000 --- a/test_imports.py +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to verify all imports work correctly. 
-""" - -def test_imports(): - """Test all major imports.""" - try: - # Test main client import - from governs_ai import GovernsAIClient, GovernsAIConfig - print("Main client imports successful") - - # Test feature clients - from governs_ai import PrecheckClient, ConfirmationClient, BudgetClient, ToolClient, AnalyticsClient - print("Feature client imports successful") - - # Test data models - from governs_ai import ( - PrecheckRequest, PrecheckResponse, Decision, - BudgetContext, UsageRecord, ConfirmationRequest, ConfirmationResponse, - HealthStatus - ) - print("Data model imports successful") - - # Test exceptions - from governs_ai import ( - GovernsAIError, PrecheckError, ConfirmationError, - BudgetError, ToolError, AnalyticsError - ) - print("Exception imports successful") - - # Test utilities - from governs_ai.utils import with_retry, RetryConfig, HTTPClient, GovernsAILogger - print("Utility imports successful") - - print("\nAll imports successful! The SDK is properly structured.") - return True - - except ImportError as e: - print(f"Import error: {e}") - return False - except Exception as e: - print(f"Unexpected error: {e}") - return False - -if __name__ == "__main__": - test_imports() \ No newline at end of file diff --git a/tests/test_client.py b/tests/test_client.py index e74745f..5460049 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -32,6 +32,11 @@ def client(self, mock_http_client): ) return GovernsAIClient(config=config) + def test_default_base_url_is_production(self): + """Default base URL should target managed API, not localhost.""" + config = GovernsAIConfig(api_key="test-key", org_id="test-org") + assert config.base_url == "https://api.governsai.com" + @pytest.mark.asyncio async def test_test_connection_success(self, client, mock_http_client): """Test successful connection test.""" @@ -146,3 +151,8 @@ def test_get_config(self, client): assert isinstance(config, GovernsAIConfig) assert config.api_key == "test-key" assert config.org_id 
== "test-org" + + def test_context_and_document_clients_available(self, client): + """Feature parity clients should be initialized on main client.""" + assert client.context is not None + assert client.documents is not None diff --git a/tests/test_context_document_clients.py b/tests/test_context_document_clients.py new file mode 100644 index 0000000..3af4507 --- /dev/null +++ b/tests/test_context_document_clients.py @@ -0,0 +1,140 @@ +""" +Tests for ContextClient and DocumentClient. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from governs_ai.clients.context import ContextClient +from governs_ai.clients.documents import DocumentClient +from governs_ai.models.context import SaveContextInput + + +@pytest.fixture +def mock_http_client(): + client = AsyncMock() + client.post = AsyncMock() + client.get = AsyncMock() + client.delete = AsyncMock() + client.post_form_data = AsyncMock() + return client + + +@pytest.fixture +def logger(): + return MagicMock() + + +@pytest.mark.asyncio +async def test_context_save_context_explicit(mock_http_client, logger): + mock_response = MagicMock() + mock_response.data = {"contextId": "ctx-123"} + mock_http_client.post.return_value = mock_response + + context_client = ContextClient(mock_http_client, logger) + result = await context_client.save_context_explicit( + SaveContextInput( + content="remember this", + content_type="user_message", + agent_id="agent-1", + ) + ) + + assert result.context_id == "ctx-123" + mock_http_client.post.assert_called_once_with( + "/api/v1/context", + data={ + "content": "remember this", + "contentType": "user_message", + "agentId": "agent-1", + }, + ) + + +@pytest.mark.asyncio +async def test_context_search_memory(mock_http_client, logger): + mock_response = MagicMock() + mock_response.data = { + "success": True, + "count": 1, + "memories": [ + { + "id": "mem-1", + "content": "User likes blue widgets", + "contentType": "user_message", + "createdAt": "2026-02-28T00:00:00Z", + } + ], 
+ } + mock_http_client.post.return_value = mock_response + + context_client = ContextClient(mock_http_client, logger) + result = await context_client.search_memory( + { + "query": "preferences", + "externalUserId": "user-1", + "externalSource": "shopify", + } + ) + + assert result.success is True + assert result.count == 1 + assert result.memories[0].id == "mem-1" + mock_http_client.post.assert_called_once() + + +@pytest.mark.asyncio +async def test_document_upload_uses_multipart(mock_http_client, logger): + mock_response = MagicMock() + mock_response.data = { + "success": True, + "documentId": "doc-1", + "status": "processing", + "chunkCount": 0, + "fileHash": "abc123", + } + mock_http_client.post_form_data.return_value = mock_response + + document_client = DocumentClient(mock_http_client, logger) + result = await document_client.upload_document( + file=b"hello world", + filename="hello.txt", + content_type="text/plain", + ) + + assert result.document_id == "doc-1" + mock_http_client.post_form_data.assert_called_once() + call_args = mock_http_client.post_form_data.call_args + assert call_args.args[0] == "/api/v1/documents" + + +@pytest.mark.asyncio +async def test_document_search(mock_http_client, logger): + mock_response = MagicMock() + mock_response.data = { + "success": True, + "results": [ + { + "documentId": "doc-1", + "chunkId": "chunk-1", + "chunkIndex": 0, + "content": "Matching content", + "similarity": 0.92, + "metadata": {}, + "document": { + "filename": "policy.pdf", + "contentType": "application/pdf", + "userId": "user-1", + "createdAt": "2026-02-28T00:00:00Z", + }, + } + ], + } + mock_http_client.post.return_value = mock_response + + document_client = DocumentClient(mock_http_client, logger) + result = await document_client.search_documents({"query": "policy", "limit": 5}) + + assert result.success is True + assert len(result.results) == 1 + assert result.results[0].document_id == "doc-1" diff --git a/tests/test_http_client.py 
b/tests/test_http_client.py index 1e3394a..ca5ba7a 100644 --- a/tests/test_http_client.py +++ b/tests/test_http_client.py @@ -22,3 +22,19 @@ def test_default_auth_header_uses_x_governs_key(): assert headers["X-Governs-Key"] == "test-key" assert "Authorization" not in headers + + +def test_multipart_headers_do_not_set_content_type(): + """Multipart requests should let aiohttp set the boundary header.""" + mock_session = MagicMock() + mock_session.closed = True + + client = HTTPClient( + base_url="http://example.com", + api_key="test-key", + session=mock_session, + ) + + headers = client._get_headers(content_type=None) + + assert "Content-Type" not in headers