diff --git a/pyproject.toml b/pyproject.toml index 99cc008..ac5f68e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "robosystems-client" -version = "0.1.13" +version = "0.1.14" description = "Python Client for RoboSystems financial graph database API" readme = "README.md" requires-python = ">=3.10" diff --git a/robosystems_client/api/copy/copy_data_to_graph.py b/robosystems_client/api/copy/copy_data_to_graph.py index 204915e..5fb1d4a 100644 --- a/robosystems_client/api/copy/copy_data_to_graph.py +++ b/robosystems_client/api/copy/copy_data_to_graph.py @@ -1,5 +1,5 @@ from http import HTTPStatus -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast import httpx @@ -50,11 +50,32 @@ def _get_kwargs( def _parse_response( *, client: Union[AuthenticatedClient, Client], response: httpx.Response -) -> Optional[Union[CopyResponse, HTTPValidationError]]: +) -> Optional[Union[Any, CopyResponse, HTTPValidationError]]: if response.status_code == 200: response_200 = CopyResponse.from_dict(response.json()) return response_200 + if response.status_code == 202: + response_202 = cast(Any, None) + return response_202 + if response.status_code == 400: + response_400 = cast(Any, None) + return response_400 + if response.status_code == 403: + response_403 = cast(Any, None) + return response_403 + if response.status_code == 408: + response_408 = cast(Any, None) + return response_408 + if response.status_code == 429: + response_429 = cast(Any, None) + return response_429 + if response.status_code == 500: + response_500 = cast(Any, None) + return response_500 + if response.status_code == 503: + response_503 = cast(Any, None) + return response_503 if response.status_code == 422: response_422 = HTTPValidationError.from_dict(response.json()) @@ -67,7 +88,7 @@ def _parse_response( def _build_response( *, client: Union[AuthenticatedClient, Client], response: httpx.Response -) -> Response[Union[CopyResponse, 
HTTPValidationError]]: +) -> Response[Union[Any, CopyResponse, HTTPValidationError]]: return Response( status_code=HTTPStatus(response.status_code), content=response.content, @@ -83,7 +104,7 @@ def sync_detailed( body: Union["DataFrameCopyRequest", "S3CopyRequest", "URLCopyRequest"], authorization: Union[None, Unset, str] = UNSET, auth_token: Union[None, Unset, str] = UNSET, -) -> Response[Union[CopyResponse, HTTPValidationError]]: +) -> Response[Union[Any, CopyResponse, HTTPValidationError]]: """Copy Data to Graph Copy data from external sources into the graph database. @@ -105,10 +126,46 @@ def sync_detailed( - Premium: 100GB max file size, 60 min timeout **Copy Options:** - - `ignore_errors`: Skip duplicate/invalid rows (enables upsert-like behavior) + - `ignore_errors`: Skip duplicate/invalid rows (enables upsert-like behavior). Note: When enabled, + row counts may not be accurately reported - `extended_timeout`: Use extended timeout for large datasets - `validate_schema`: Validate source schema against target table + **Asynchronous Execution with SSE:** + For large data imports, this endpoint returns immediately with an operation ID + and SSE monitoring endpoint. 
Connect to the returned stream URL for real-time updates: + + ```javascript + const eventSource = new EventSource('/v1/operations/{operation_id}/stream'); + eventSource.onmessage = (event) => { + const data = JSON.parse(event.data); + console.log('Progress:', data.message); + }; + ``` + + **SSE Events Emitted:** + - `operation_started`: Copy operation begins + - `operation_progress`: Progress updates during data transfer + - `operation_completed`: Copy successful with statistics + - `operation_error`: Copy failed with error details + + **SSE Connection Limits:** + - Maximum 5 concurrent SSE connections per user + - Rate limited to 10 new connections per minute + - Automatic circuit breaker for Redis failures + - Graceful degradation if event system unavailable + + **Error Handling:** + - `403 Forbidden`: Attempted copy to shared repository + - `408 Request Timeout`: Operation exceeded timeout limit + - `429 Too Many Requests`: Rate limit exceeded + - `503 Service Unavailable`: Circuit breaker open or service unavailable + - Clients should implement exponential backoff on errors + + **Note:** + Copy operations are FREE - no credit consumption required. + All copy operations are performed asynchronously with progress monitoring. + Args: graph_id (str): Target graph identifier (user graphs only - shared repositories not allowed) @@ -121,7 +178,7 @@ def sync_detailed( httpx.TimeoutException: If the request takes longer than Client.timeout. 
Returns: - Response[Union[CopyResponse, HTTPValidationError]] + Response[Union[Any, CopyResponse, HTTPValidationError]] """ kwargs = _get_kwargs( @@ -145,7 +202,7 @@ def sync( body: Union["DataFrameCopyRequest", "S3CopyRequest", "URLCopyRequest"], authorization: Union[None, Unset, str] = UNSET, auth_token: Union[None, Unset, str] = UNSET, -) -> Optional[Union[CopyResponse, HTTPValidationError]]: +) -> Optional[Union[Any, CopyResponse, HTTPValidationError]]: """Copy Data to Graph Copy data from external sources into the graph database. @@ -167,10 +224,46 @@ def sync( - Premium: 100GB max file size, 60 min timeout **Copy Options:** - - `ignore_errors`: Skip duplicate/invalid rows (enables upsert-like behavior) + - `ignore_errors`: Skip duplicate/invalid rows (enables upsert-like behavior). Note: When enabled, + row counts may not be accurately reported - `extended_timeout`: Use extended timeout for large datasets - `validate_schema`: Validate source schema against target table + **Asynchronous Execution with SSE:** + For large data imports, this endpoint returns immediately with an operation ID + and SSE monitoring endpoint. 
Connect to the returned stream URL for real-time updates: + + ```javascript + const eventSource = new EventSource('/v1/operations/{operation_id}/stream'); + eventSource.onmessage = (event) => { + const data = JSON.parse(event.data); + console.log('Progress:', data.message); + }; + ``` + + **SSE Events Emitted:** + - `operation_started`: Copy operation begins + - `operation_progress`: Progress updates during data transfer + - `operation_completed`: Copy successful with statistics + - `operation_error`: Copy failed with error details + + **SSE Connection Limits:** + - Maximum 5 concurrent SSE connections per user + - Rate limited to 10 new connections per minute + - Automatic circuit breaker for Redis failures + - Graceful degradation if event system unavailable + + **Error Handling:** + - `403 Forbidden`: Attempted copy to shared repository + - `408 Request Timeout`: Operation exceeded timeout limit + - `429 Too Many Requests`: Rate limit exceeded + - `503 Service Unavailable`: Circuit breaker open or service unavailable + - Clients should implement exponential backoff on errors + + **Note:** + Copy operations are FREE - no credit consumption required. + All copy operations are performed asynchronously with progress monitoring. + Args: graph_id (str): Target graph identifier (user graphs only - shared repositories not allowed) @@ -183,7 +276,7 @@ def sync( httpx.TimeoutException: If the request takes longer than Client.timeout. Returns: - Union[CopyResponse, HTTPValidationError] + Union[Any, CopyResponse, HTTPValidationError] """ return sync_detailed( @@ -202,7 +295,7 @@ async def asyncio_detailed( body: Union["DataFrameCopyRequest", "S3CopyRequest", "URLCopyRequest"], authorization: Union[None, Unset, str] = UNSET, auth_token: Union[None, Unset, str] = UNSET, -) -> Response[Union[CopyResponse, HTTPValidationError]]: +) -> Response[Union[Any, CopyResponse, HTTPValidationError]]: """Copy Data to Graph Copy data from external sources into the graph database. 
@@ -224,10 +317,46 @@ async def asyncio_detailed( - Premium: 100GB max file size, 60 min timeout **Copy Options:** - - `ignore_errors`: Skip duplicate/invalid rows (enables upsert-like behavior) + - `ignore_errors`: Skip duplicate/invalid rows (enables upsert-like behavior). Note: When enabled, + row counts may not be accurately reported - `extended_timeout`: Use extended timeout for large datasets - `validate_schema`: Validate source schema against target table + **Asynchronous Execution with SSE:** + For large data imports, this endpoint returns immediately with an operation ID + and SSE monitoring endpoint. Connect to the returned stream URL for real-time updates: + + ```javascript + const eventSource = new EventSource('/v1/operations/{operation_id}/stream'); + eventSource.onmessage = (event) => { + const data = JSON.parse(event.data); + console.log('Progress:', data.message); + }; + ``` + + **SSE Events Emitted:** + - `operation_started`: Copy operation begins + - `operation_progress`: Progress updates during data transfer + - `operation_completed`: Copy successful with statistics + - `operation_error`: Copy failed with error details + + **SSE Connection Limits:** + - Maximum 5 concurrent SSE connections per user + - Rate limited to 10 new connections per minute + - Automatic circuit breaker for Redis failures + - Graceful degradation if event system unavailable + + **Error Handling:** + - `403 Forbidden`: Attempted copy to shared repository + - `408 Request Timeout`: Operation exceeded timeout limit + - `429 Too Many Requests`: Rate limit exceeded + - `503 Service Unavailable`: Circuit breaker open or service unavailable + - Clients should implement exponential backoff on errors + + **Note:** + Copy operations are FREE - no credit consumption required. + All copy operations are performed asynchronously with progress monitoring. 
+ Args: graph_id (str): Target graph identifier (user graphs only - shared repositories not allowed) @@ -240,7 +369,7 @@ async def asyncio_detailed( httpx.TimeoutException: If the request takes longer than Client.timeout. Returns: - Response[Union[CopyResponse, HTTPValidationError]] + Response[Union[Any, CopyResponse, HTTPValidationError]] """ kwargs = _get_kwargs( @@ -262,7 +391,7 @@ async def asyncio( body: Union["DataFrameCopyRequest", "S3CopyRequest", "URLCopyRequest"], authorization: Union[None, Unset, str] = UNSET, auth_token: Union[None, Unset, str] = UNSET, -) -> Optional[Union[CopyResponse, HTTPValidationError]]: +) -> Optional[Union[Any, CopyResponse, HTTPValidationError]]: """Copy Data to Graph Copy data from external sources into the graph database. @@ -284,10 +413,46 @@ async def asyncio( - Premium: 100GB max file size, 60 min timeout **Copy Options:** - - `ignore_errors`: Skip duplicate/invalid rows (enables upsert-like behavior) + - `ignore_errors`: Skip duplicate/invalid rows (enables upsert-like behavior). Note: When enabled, + row counts may not be accurately reported - `extended_timeout`: Use extended timeout for large datasets - `validate_schema`: Validate source schema against target table + **Asynchronous Execution with SSE:** + For large data imports, this endpoint returns immediately with an operation ID + and SSE monitoring endpoint. 
Connect to the returned stream URL for real-time updates: + + ```javascript + const eventSource = new EventSource('/v1/operations/{operation_id}/stream'); + eventSource.onmessage = (event) => { + const data = JSON.parse(event.data); + console.log('Progress:', data.message); + }; + ``` + + **SSE Events Emitted:** + - `operation_started`: Copy operation begins + - `operation_progress`: Progress updates during data transfer + - `operation_completed`: Copy successful with statistics + - `operation_error`: Copy failed with error details + + **SSE Connection Limits:** + - Maximum 5 concurrent SSE connections per user + - Rate limited to 10 new connections per minute + - Automatic circuit breaker for Redis failures + - Graceful degradation if event system unavailable + + **Error Handling:** + - `403 Forbidden`: Attempted copy to shared repository + - `408 Request Timeout`: Operation exceeded timeout limit + - `429 Too Many Requests`: Rate limit exceeded + - `503 Service Unavailable`: Circuit breaker open or service unavailable + - Clients should implement exponential backoff on errors + + **Note:** + Copy operations are FREE - no credit consumption required. + All copy operations are performed asynchronously with progress monitoring. + Args: graph_id (str): Target graph identifier (user graphs only - shared repositories not allowed) @@ -300,7 +465,7 @@ async def asyncio( httpx.TimeoutException: If the request takes longer than Client.timeout. 
Returns: - Union[CopyResponse, HTTPValidationError] + Union[Any, CopyResponse, HTTPValidationError] """ return ( diff --git a/robosystems_client/extensions/README.md b/robosystems_client/extensions/README.md index 2253ed3..6cdb6f8 100644 --- a/robosystems_client/extensions/README.md +++ b/robosystems_client/extensions/README.md @@ -11,6 +11,7 @@ The RoboSystems Python Client Extensions provide enhanced functionality for the - **Server-Sent Events (SSE)** streaming with automatic reconnection - **Smart Query Execution** with automatic strategy selection +- **Data Copy Operations** with S3 import and real-time progress tracking - **Operation Monitoring** for long-running operations - **Connection Pooling** and intelligent resource management - **Result Processing** and format conversion utilities @@ -84,6 +85,62 @@ async def main(): asyncio.run(main()) ``` +### Data Copy Operations + +```python +from robosystems_client.extensions import CopyClient, CopyOptions +from robosystems_client.models.s3_copy_request import S3CopyRequest +from robosystems_client.models.s3_copy_request_file_format import S3CopyRequestFileFormat + +# Initialize copy client +copy_client = CopyClient({ + "base_url": "https://api.robosystems.ai", + "api_key": "your-api-key", +}) + +# Create S3 copy request +request = S3CopyRequest( + table_name="companies", + s3_path="s3://my-bucket/data/companies.csv", + s3_access_key_id="AWS_ACCESS_KEY", + s3_secret_access_key="AWS_SECRET_KEY", + s3_region="us-east-1", + file_format=S3CopyRequestFileFormat.CSV, + ignore_errors=False, # Stop on first error +) + +# Set up progress callbacks +def on_progress(message, percent): + if percent: + print(f"Progress: {message} ({percent}%)") + else: + print(f"Progress: {message}") + +def on_warning(warning): + print(f"Warning: {warning}") + +options = CopyOptions( + on_progress=on_progress, + on_warning=on_warning, +) + +# Execute copy with progress monitoring +result = copy_client.copy_from_s3("your_graph_id", request, 
options) + +# Check results +if result.status == "completed": + print(f"✅ Successfully imported {result.rows_imported:,} rows") + stats = copy_client.calculate_statistics(result) + if stats: + print(f"Throughput: {stats.throughput:.2f} rows/second") +elif result.status == "partial": + print(f"⚠️ Imported {result.rows_imported:,} rows, skipped {result.rows_skipped:,}") +else: + print(f"❌ Copy failed: {result.error}") + +copy_client.close() +``` + ## 🔐 Authentication ### API Key Authentication (Recommended) @@ -141,6 +198,79 @@ dev_ext = create_extensions( ## 🛠 Advanced Features +### Copy Operations with Advanced Features + +```python +from robosystems_client.extensions import CopyClient, CopySourceType + +# Batch copy multiple tables +copy_client = CopyClient({ + "base_url": "https://api.robosystems.ai", + "api_key": "your-api-key", +}) + +copies = [ + { + "request": S3CopyRequest( + table_name="companies", + s3_path="s3://bucket/companies.csv", + s3_access_key_id="KEY", + s3_secret_access_key="SECRET", + file_format=S3CopyRequestFileFormat.CSV, + ), + }, + { + "request": S3CopyRequest( + table_name="transactions", + s3_path="s3://bucket/transactions.parquet", + s3_access_key_id="KEY", + s3_secret_access_key="SECRET", + file_format=S3CopyRequestFileFormat.PARQUET, + ignore_errors=True, # Continue on errors + ), + }, +] + +# Execute batch copy +results = copy_client.batch_copy_from_s3("graph_id", copies) + +for i, result in enumerate(results): + table_name = copies[i]["request"].table_name + print(f"{table_name}: {result.status}") + if result.rows_imported: + print(f" Imported: {result.rows_imported:,} rows") + +# Copy with retry logic for resilient operations +result = copy_client.copy_with_retry( + graph_id="graph_id", + request=S3CopyRequest( + table_name="large_dataset", + s3_path="s3://bucket/large-dataset.csv", + s3_access_key_id="KEY", + s3_secret_access_key="SECRET", + max_file_size_gb=50, + extended_timeout=True, + ), + source_type=CopySourceType.S3, + 
max_retries=3, + options=CopyOptions( + on_progress=lambda msg, _: print(msg) + ), +) + +# Monitor multiple concurrent copy operations +operation_ids = ["op-123", "op-456", "op-789"] +results = copy_client.monitor_multiple_copies(operation_ids, options) + +for op_id, result in results.items(): + print(f"Operation {op_id}: {result.status}") + if result.status == "completed": + stats = copy_client.calculate_statistics(result) + print(f" Throughput: {stats.throughput:.2f} rows/sec") + +copy_client.close() +``` + ### Query Builder Build complex Cypher queries programmatically: @@ -273,6 +403,87 @@ client.close() ## 📊 Examples +### Data Import with Real-Time Monitoring + +```python +from robosystems_client.extensions import CopyClient, CopyOptions +import time + +def import_financial_data(): + """Import financial data with comprehensive monitoring""" + + copy_client = CopyClient({ + "base_url": "https://api.robosystems.ai", + "api_key": "your-api-key", + }) + + # Track progress history + progress_history = [] + warnings_count = 0 + + def on_progress(message, percent): + timestamp = time.strftime("%H:%M:%S") + progress_history.append({ + "time": timestamp, + "message": message, + "percent": percent, + }) + print(f"[{timestamp}] {message}" + (f" ({percent}%)" if percent else "")) + + def on_warning(warning): + nonlocal warnings_count + warnings_count += 1 + print(f"⚠️ Warning #{warnings_count}: {warning}") + + def on_queue_update(position, wait_time): + print(f"📊 Queue position: {position} (ETA: {wait_time}s)") + + # Configure copy with all callbacks + options = CopyOptions( + on_progress=on_progress, + on_warning=on_warning, + on_queue_update=on_queue_update, + timeout=1800000, # 30 minutes + ) + + # Execute copy operation + start_time = time.time() + + result = copy_client.copy_s3( + graph_id="financial_graph", + table_name="quarterly_reports", + s3_path="s3://financial-data/reports-2024-q1.parquet", + access_key_id="AWS_KEY", + secret_access_key="AWS_SECRET", + 
file_format="parquet", + ignore_errors=True, # Continue on validation errors + ) + + # Print summary + elapsed = time.time() - start_time + + print("\n" + "="*50) + print("📈 IMPORT SUMMARY") + print("="*50) + print(f"Status: {result.status.upper()}") + print(f"Rows Imported: {result.rows_imported or 0:,}") + print(f"Rows Skipped: {result.rows_skipped or 0:,}") + print(f"Warnings: {warnings_count}") + print(f"Execution Time: {elapsed:.2f} seconds") + + if result.status == "completed": + stats = copy_client.calculate_statistics(result) + if stats: + print(f"Throughput: {stats.throughput:.2f} rows/second") + print(f"Data Processed: {stats.bytes_processed / (1024*1024):.2f} MB") + + copy_client.close() + return result + +# Run the import +result = import_financial_data() +``` + ### Financial Data Analysis ```python diff --git a/robosystems_client/extensions/__init__.py b/robosystems_client/extensions/__init__.py index 88b8141..efe3d03 100644 --- a/robosystems_client/extensions/__init__.py +++ b/robosystems_client/extensions/__init__.py @@ -20,6 +20,14 @@ OperationProgress, OperationResult, ) +from .copy_client import ( + CopyClient, + AsyncCopyClient, + CopySourceType, + CopyOptions, + CopyResult, + CopyStatistics, +) from .extensions import ( RoboSystemsExtensions, RoboSystemsExtensionConfig, @@ -68,6 +76,13 @@ "OperationStatus", "OperationProgress", "OperationResult", + # Copy Client + "CopyClient", + "AsyncCopyClient", + "CopySourceType", + "CopyOptions", + "CopyResult", + "CopyStatistics", # Utilities "QueryBuilder", "ResultProcessor", @@ -106,3 +121,17 @@ def execute_query(graph_id: str, query: str, parameters=None): def stream_query(graph_id: str, query: str, parameters=None, chunk_size=None): """Stream a query using the default extensions instance""" return extensions.query.stream_query(graph_id, query, parameters, chunk_size) + + +def copy_from_s3( + graph_id: str, + table_name: str, + s3_path: str, + access_key_id: str, + secret_access_key: str, + **kwargs, 
+): + """Copy data from S3 using the default extensions instance""" + return extensions.copy_from_s3( + graph_id, table_name, s3_path, access_key_id, secret_access_key, **kwargs + ) diff --git a/robosystems_client/extensions/copy_client.py b/robosystems_client/extensions/copy_client.py new file mode 100644 index 0000000..95ece87 --- /dev/null +++ b/robosystems_client/extensions/copy_client.py @@ -0,0 +1,469 @@ +"""Enhanced Copy Client with SSE support + +Provides intelligent data copy operations with progress monitoring. +""" + +from dataclasses import dataclass +from typing import Dict, Any, Optional, Callable, Union, List +from enum import Enum +import time +import logging + +from ..api.copy.copy_data_to_graph import sync_detailed as copy_data_to_graph +from ..models.s3_copy_request import S3CopyRequest +from ..models.url_copy_request import URLCopyRequest +from ..models.data_frame_copy_request import DataFrameCopyRequest +from ..models.copy_response import CopyResponse +from ..models.copy_response_status import CopyResponseStatus +from ..models.s3_copy_request_file_format import S3CopyRequestFileFormat +from .sse_client import SSEClient, AsyncSSEClient, SSEConfig, EventType + +logger = logging.getLogger(__name__) + + +class CopySourceType(Enum): + """Types of copy sources""" + + S3 = "s3" + URL = "url" + DATAFRAME = "dataframe" + + +@dataclass +class CopyOptions: + """Options for copy operations""" + + on_progress: Optional[Callable[[str, Optional[float]], None]] = None + on_queue_update: Optional[Callable[[int, int], None]] = None + on_warning: Optional[Callable[[str], None]] = None + timeout: Optional[int] = None + test_mode: Optional[bool] = None + + +@dataclass +class CopyResult: + """Result from copy operation""" + + status: str # 'completed', 'failed', 'partial', 'accepted' + rows_imported: Optional[int] = None + rows_skipped: Optional[int] = None + bytes_processed: Optional[int] = None + execution_time_ms: Optional[float] = None + warnings: 
Optional[List[str]] = None + error: Optional[str] = None + operation_id: Optional[str] = None + sse_url: Optional[str] = None + message: Optional[str] = None + + +@dataclass +class CopyStatistics: + """Statistics from copy operation""" + + total_rows: int + imported_rows: int + skipped_rows: int + bytes_processed: int + duration: float # seconds + throughput: float # rows per second + + +class CopyClient: + """Enhanced copy client with SSE streaming support""" + + def __init__(self, config: Dict[str, Any]): + self.config = config + self.base_url = config["base_url"] + self.sse_client: Optional[SSEClient] = None + + # Get client authentication if provided + self.auth_token = config.get("auth_token") + self.api_key = config.get("api_key") + + def copy_from_s3( + self, graph_id: str, request: S3CopyRequest, options: Optional[CopyOptions] = None + ) -> CopyResult: + """Copy data from S3 to graph database""" + return self._execute_copy(graph_id, request, CopySourceType.S3, options) + + def copy_from_url( + self, graph_id: str, request: URLCopyRequest, options: Optional[CopyOptions] = None + ) -> CopyResult: + """Copy data from URL to graph database (when available)""" + return self._execute_copy(graph_id, request, CopySourceType.URL, options) + + def copy_from_dataframe( + self, + graph_id: str, + request: DataFrameCopyRequest, + options: Optional[CopyOptions] = None, + ) -> CopyResult: + """Copy data from DataFrame to graph database (when available)""" + return self._execute_copy(graph_id, request, CopySourceType.DATAFRAME, options) + + def _execute_copy( + self, + graph_id: str, + request: Union[S3CopyRequest, URLCopyRequest, DataFrameCopyRequest], + source_type: CopySourceType, + options: Optional[CopyOptions] = None, + ) -> CopyResult: + """Execute copy operation with automatic SSE monitoring for long-running operations""" + if options is None: + options = CopyOptions() + + start_time = time.time() + + # Import client here to avoid circular imports + from ..client 
import AuthenticatedClient + + # Create authenticated client + client = AuthenticatedClient( + base_url=self.base_url, + token=self.auth_token, + headers={"X-API-Key": self.api_key} if self.api_key else None, + ) + + try: + # Execute the copy request + response = copy_data_to_graph(graph_id=graph_id, client=client, body=request) + + if response.parsed: + response_data: CopyResponse = response.parsed + + # Check if this is an accepted (async) operation + if ( + response_data.status == CopyResponseStatus.ACCEPTED + and response_data.operation_id + ): + # This is a long-running operation with SSE monitoring + if options.on_progress: + options.on_progress("Copy operation started. Monitoring progress...", None) + + # If SSE URL is provided, use it for monitoring + if response_data.sse_url: + return self._monitor_copy_operation( + response_data.operation_id, options, start_time + ) + + # Otherwise return the accepted response + return CopyResult( + status="accepted", + operation_id=response_data.operation_id, + sse_url=response_data.sse_url, + message=response_data.message, + ) + + # This is a synchronous response - operation completed immediately + return self._build_copy_result(response_data, time.time() - start_time) + else: + return CopyResult( + status="failed", + error="No response data received", + execution_time_ms=(time.time() - start_time) * 1000, + ) + + except Exception as e: + return CopyResult( + status="failed", + error=str(e), + execution_time_ms=(time.time() - start_time) * 1000, + ) + + def _monitor_copy_operation( + self, operation_id: str, options: CopyOptions, start_time: float + ) -> CopyResult: + """Monitor a copy operation using SSE until a terminal event arrives or options.timeout (milliseconds) elapses""" + timeout_ms = options.timeout or 3600000 # Default 1 hour for copy operations + timeout_time = time.time() + (timeout_ms / 1000) + + result = CopyResult(status="pending") # "pending" is a non-terminal sentinel: starting at "failed" made the listen loop below break after its first pass + warnings: List[str] = [] + + # Set up SSE connection + sse_config = SSEConfig(base_url=self.base_url, timeout=timeout_ms // 1000) + sse_client =
SSEClient(sse_config) + + try: + sse_client.connect(operation_id) + + # Set up event handlers + def on_queue_update(data): + if options.on_queue_update: + position = data.get("position", data.get("queue_position", 0)) + wait_time = data.get("estimated_wait_seconds", 0) + options.on_queue_update(position, wait_time) + + def on_progress(data): + if options.on_progress: + message = data.get("message", data.get("status", "Processing...")) + progress_percent = data.get("progress_percent", data.get("progress")) + options.on_progress(message, progress_percent) + + # Check for warnings in progress updates + if "warnings" in data and data["warnings"]: + warnings.extend(data["warnings"]) + if options.on_warning: + for warning in data["warnings"]: + options.on_warning(warning) + + def on_completed(data): + nonlocal result + completion_data = data.get("result", data) + result = CopyResult( + status=completion_data.get("status", "completed"), + rows_imported=completion_data.get("rows_imported"), + rows_skipped=completion_data.get("rows_skipped"), + bytes_processed=completion_data.get("bytes_processed"), + execution_time_ms=(time.time() - start_time) * 1000, + warnings=warnings if warnings else completion_data.get("warnings"), + message=completion_data.get("message"), + ) + + def on_error(data): + nonlocal result + result = CopyResult( + status="failed", + error=data.get("message", data.get("error", "Copy operation failed")), + execution_time_ms=(time.time() - start_time) * 1000, + warnings=warnings if warnings else None, + ) + + def on_cancelled(data): + nonlocal result + result = CopyResult( + status="failed", + error="Copy operation cancelled", + execution_time_ms=(time.time() - start_time) * 1000, + warnings=warnings if warnings else None, + ) + + # Register event handlers + sse_client.on(EventType.QUEUE_UPDATE.value, on_queue_update) + sse_client.on(EventType.OPERATION_PROGRESS.value, on_progress) + sse_client.on(EventType.OPERATION_COMPLETED.value, on_completed) + 
sse_client.on(EventType.OPERATION_ERROR.value, on_error) + sse_client.on(EventType.OPERATION_CANCELLED.value, on_cancelled) + + # Listen for events until completion or timeout + while time.time() < timeout_time: + sse_client.listen(timeout=1) # Process events for 1 second + + # Check if operation is complete + if result.status in ["completed", "failed", "partial"]: + break + + if time.time() >= timeout_time: + result = CopyResult( + status="failed", + error=f"Copy operation timeout after {timeout_ms}ms", + execution_time_ms=(time.time() - start_time) * 1000, + ) + + finally: + sse_client.close() + + return result + + def _build_copy_result( + self, response_data: CopyResponse, execution_time: float + ) -> CopyResult: + """Build copy result from response data""" + return CopyResult( + status=response_data.status.value, + rows_imported=response_data.rows_imported, + rows_skipped=response_data.rows_skipped, + bytes_processed=response_data.bytes_processed, + execution_time_ms=response_data.execution_time_ms or (execution_time * 1000), + warnings=response_data.warnings, + message=response_data.message, + error=str(response_data.error_details) if response_data.error_details else None, + ) + + def calculate_statistics(self, result: CopyResult) -> Optional[CopyStatistics]: + """Calculate copy statistics from result""" + if result.status == "failed" or not result.rows_imported: + return None + + total_rows = (result.rows_imported or 0) + (result.rows_skipped or 0) + duration = (result.execution_time_ms or 0) / 1000 # Convert to seconds + throughput = (result.rows_imported or 0) / duration if duration > 0 else 0 + + return CopyStatistics( + total_rows=total_rows, + imported_rows=result.rows_imported or 0, + skipped_rows=result.rows_skipped or 0, + bytes_processed=result.bytes_processed or 0, + duration=duration, + throughput=throughput, + ) + + def copy_s3( + self, + graph_id: str, + table_name: str, + s3_path: str, + access_key_id: str, + secret_access_key: str, + region: 
str = "us-east-1", + file_format: Optional[str] = None, + ignore_errors: bool = False, + ) -> CopyResult: + """Convenience method for simple S3 copy with default options""" + + # Map string format to enum (a missing or unrecognised format falls back to PARQUET) + format_enum = S3CopyRequestFileFormat.PARQUET + if file_format: + format_map = { + "csv": S3CopyRequestFileFormat.CSV, + "parquet": S3CopyRequestFileFormat.PARQUET, + "json": S3CopyRequestFileFormat.JSON, + "delta": S3CopyRequestFileFormat.DELTA, + "iceberg": S3CopyRequestFileFormat.ICEBERG, + } + format_enum = format_map.get(file_format.lower(), S3CopyRequestFileFormat.PARQUET) + + request = S3CopyRequest( + table_name=table_name, + s3_path=s3_path, + s3_access_key_id=access_key_id, + s3_secret_access_key=secret_access_key, + s3_region=region, + file_format=format_enum, + ignore_errors=ignore_errors, + ) + + return self.copy_from_s3(graph_id, request) + + def monitor_multiple_copies( + self, operation_ids: List[str], options: Optional[CopyOptions] = None + ) -> Dict[str, CopyResult]: + """Monitor multiple copy operations one after another, in list order; returns results keyed by operation ID""" + results = {} + for operation_id in operation_ids: + result = self._monitor_copy_operation( + operation_id, options or CopyOptions(), time.time() + ) + results[operation_id] = result + return results + + def batch_copy_from_s3( + self, graph_id: str, copies: List[Dict[str, Any]] + ) -> List[CopyResult]: + """Batch copy multiple tables from S3, one at a time; entries without a 'request' are skipped, so the result list can be shorter than copies""" + results = [] + for copy_config in copies: + request = copy_config.get("request") + options = copy_config.get("options") + if request: + result = self.copy_from_s3(graph_id, request, options) + results.append(result) + return results + + def copy_with_retry( + self, + graph_id: str, + request: Union[S3CopyRequest, URLCopyRequest, DataFrameCopyRequest], + source_type: CopySourceType, + max_retries: int = 3, + options: Optional[CopyOptions] = None, + ) -> CopyResult: + """Copy with retry logic for transient failures""" + if options is None: + options = CopyOptions() + + last_error:
Optional[Exception] = None + attempt = 0 + + while attempt < max_retries: + attempt += 1 + + try: + result = self._execute_copy(graph_id, request, source_type, options) + + # If successful or partially successful, return + if result.status in ["completed", "partial"]: + return result + + # If failed, check if it's retryable + if result.status == "failed": + is_retryable = self._is_retryable_error(result.error) + if not is_retryable or attempt == max_retries: + return result + + # Wait before retry with exponential backoff + wait_time = min(1000 * (2 ** (attempt - 1)), 30000) / 1000 + if options.on_progress: + options.on_progress( + f"Retrying copy operation (attempt {attempt}/{max_retries}) in {wait_time}s...", + None, + ) + time.sleep(wait_time) + + except Exception as e: + last_error = e + + if attempt == max_retries: + raise last_error + + # Wait before retry + wait_time = min(1000 * (2 ** (attempt - 1)), 30000) / 1000 + if options.on_progress: + options.on_progress( + f"Retrying after error (attempt {attempt}/{max_retries}) in {wait_time}s...", + None, + ) + time.sleep(wait_time) + + raise last_error or Exception("Copy operation failed after all retries") + + def _is_retryable_error(self, error: Optional[str]) -> bool: + """Check if an error is retryable""" + if not error: + return False + + retryable_patterns = [ + "timeout", + "network", + "connection", + "temporary", + "unavailable", + "rate limit", + "throttl", + ] + + lower_error = error.lower() + return any(pattern in lower_error for pattern in retryable_patterns) + + def close(self): + """Cancel any active SSE connections""" + if self.sse_client: + self.sse_client.close() + self.sse_client = None + + +class AsyncCopyClient: + """Async version of CopyClient for async/await usage""" + + def __init__(self, config: Dict[str, Any]): + self.config = config + self.base_url = config["base_url"] + self.sse_client: Optional[AsyncSSEClient] = None + + # Get client authentication if provided + self.auth_token = 
config.get("auth_token") + self.api_key = config.get("api_key") + + async def copy_from_s3( + self, graph_id: str, request: S3CopyRequest, options: Optional[CopyOptions] = None + ) -> CopyResult: + """Copy data from S3 to graph database asynchronously""" + # Async implementation would go here + # For now, this is a placeholder + raise NotImplementedError("Async copy client not yet implemented") + + async def close(self): + """Close any active connections""" + if self.sse_client: + await self.sse_client.close() + self.sse_client = None diff --git a/robosystems_client/extensions/extensions.py b/robosystems_client/extensions/extensions.py index 7e31898..bc85250 100644 --- a/robosystems_client/extensions/extensions.py +++ b/robosystems_client/extensions/extensions.py @@ -8,6 +8,7 @@ from .query_client import QueryClient from .operation_client import OperationClient +from .copy_client import CopyClient from .sse_client import SSEClient @@ -39,6 +40,7 @@ def __init__(self, config: RoboSystemsExtensionConfig = None): } # Initialize clients + self.copy = CopyClient(self.config) self.query = QueryClient(self.config) self.operations = OperationClient(self.config) @@ -67,6 +69,7 @@ def create_sse_client(self) -> SSEClient: def close(self): """Clean up all active connections""" + self.copy.close() self.query.close() self.operations.close_all() @@ -93,6 +96,20 @@ def cancel_operation(self, operation_id: str): """Cancel an operation using the operation client""" return self.operations.cancel_operation(operation_id) + def copy_from_s3( + self, + graph_id: str, + table_name: str, + s3_path: str, + access_key_id: str, + secret_access_key: str, + **kwargs, + ): + """Copy data from S3 using the copy client""" + return self.copy.copy_s3( + graph_id, table_name, s3_path, access_key_id, secret_access_key, **kwargs + ) + class AsyncRoboSystemsExtensions: """Async version of the extensions class""" diff --git a/robosystems_client/models/copy_response.py 
b/robosystems_client/models/copy_response.py index 2929606..dfe19e0 100644 --- a/robosystems_client/models/copy_response.py +++ b/robosystems_client/models/copy_response.py @@ -21,8 +21,10 @@ class CopyResponse: Attributes: status (CopyResponseStatus): Operation status source_type (str): Type of source that was copied from - execution_time_ms (float): Total execution time in milliseconds message (str): Human-readable status message + operation_id (Union[None, Unset, str]): Operation ID for SSE monitoring (for long-running operations) + sse_url (Union[None, Unset, str]): SSE endpoint URL for monitoring operation progress + execution_time_ms (Union[None, Unset, float]): Total execution time in milliseconds (for synchronous operations) rows_imported (Union[None, Unset, int]): Number of rows successfully imported rows_skipped (Union[None, Unset, int]): Number of rows skipped due to errors (when ignore_errors=true) warnings (Union[None, Unset, list[str]]): List of warnings encountered during import @@ -33,8 +35,10 @@ class CopyResponse: status: CopyResponseStatus source_type: str - execution_time_ms: float message: str + operation_id: Union[None, Unset, str] = UNSET + sse_url: Union[None, Unset, str] = UNSET + execution_time_ms: Union[None, Unset, float] = UNSET rows_imported: Union[None, Unset, int] = UNSET rows_skipped: Union[None, Unset, int] = UNSET warnings: Union[None, Unset, list[str]] = UNSET @@ -51,10 +55,26 @@ def to_dict(self) -> dict[str, Any]: source_type = self.source_type - execution_time_ms = self.execution_time_ms - message = self.message + operation_id: Union[None, Unset, str] + if isinstance(self.operation_id, Unset): + operation_id = UNSET + else: + operation_id = self.operation_id + + sse_url: Union[None, Unset, str] + if isinstance(self.sse_url, Unset): + sse_url = UNSET + else: + sse_url = self.sse_url + + execution_time_ms: Union[None, Unset, float] + if isinstance(self.execution_time_ms, Unset): + execution_time_ms = UNSET + else: + 
execution_time_ms = self.execution_time_ms + rows_imported: Union[None, Unset, int] if isinstance(self.rows_imported, Unset): rows_imported = UNSET @@ -96,10 +116,15 @@ def to_dict(self) -> dict[str, Any]: { "status": status, "source_type": source_type, - "execution_time_ms": execution_time_ms, "message": message, } ) + if operation_id is not UNSET: + field_dict["operation_id"] = operation_id + if sse_url is not UNSET: + field_dict["sse_url"] = sse_url + if execution_time_ms is not UNSET: + field_dict["execution_time_ms"] = execution_time_ms if rows_imported is not UNSET: field_dict["rows_imported"] = rows_imported if rows_skipped is not UNSET: @@ -124,10 +149,35 @@ def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: source_type = d.pop("source_type") - execution_time_ms = d.pop("execution_time_ms") - message = d.pop("message") + def _parse_operation_id(data: object) -> Union[None, Unset, str]: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(Union[None, Unset, str], data) + + operation_id = _parse_operation_id(d.pop("operation_id", UNSET)) + + def _parse_sse_url(data: object) -> Union[None, Unset, str]: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(Union[None, Unset, str], data) + + sse_url = _parse_sse_url(d.pop("sse_url", UNSET)) + + def _parse_execution_time_ms(data: object) -> Union[None, Unset, float]: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(Union[None, Unset, float], data) + + execution_time_ms = _parse_execution_time_ms(d.pop("execution_time_ms", UNSET)) + def _parse_rows_imported(data: object) -> Union[None, Unset, int]: if data is None: return data @@ -194,8 +244,10 @@ def _parse_bytes_processed(data: object) -> Union[None, Unset, int]: copy_response = cls( status=status, source_type=source_type, - execution_time_ms=execution_time_ms, message=message, + operation_id=operation_id, + sse_url=sse_url, + 
execution_time_ms=execution_time_ms, rows_imported=rows_imported, rows_skipped=rows_skipped, warnings=warnings, diff --git a/robosystems_client/models/copy_response_status.py b/robosystems_client/models/copy_response_status.py index 5fc6134..4e15f1c 100644 --- a/robosystems_client/models/copy_response_status.py +++ b/robosystems_client/models/copy_response_status.py @@ -2,6 +2,7 @@ class CopyResponseStatus(str, Enum): + ACCEPTED = "accepted" COMPLETED = "completed" FAILED = "failed" PARTIAL = "partial" diff --git a/robosystems_client/models/s3_copy_request.py b/robosystems_client/models/s3_copy_request.py index f55ed1a..f3f2cf6 100644 --- a/robosystems_client/models/s3_copy_request.py +++ b/robosystems_client/models/s3_copy_request.py @@ -15,28 +15,31 @@ class S3CopyRequest: r"""Request model for S3 copy operations. - Attributes: - table_name (str): Target Kuzu table name - s3_path (str): Full S3 path (s3://bucket/key or s3://bucket/prefix/*.parquet) - s3_access_key_id (str): AWS access key ID for S3 access - s3_secret_access_key (str): AWS secret access key for S3 access - ignore_errors (Union[Unset, bool]): Skip duplicate/invalid rows (enables upsert-like behavior) Default: True. - extended_timeout (Union[Unset, bool]): Use extended timeout for large datasets Default: False. - validate_schema (Union[Unset, bool]): Validate source schema against target table Default: True. - source_type (Union[Literal['s3'], Unset]): Source type identifier Default: 's3'. - s3_session_token (Union[None, Unset, str]): AWS session token (for temporary credentials) - s3_region (Union[None, Unset, str]): S3 region Default: 'us-east-1'. - s3_endpoint (Union[None, Unset, str]): Custom S3 endpoint (for S3-compatible storage) - s3_url_style (Union[None, S3CopyRequestS3UrlStyleType0, Unset]): S3 URL style (vhost or path) - file_format (Union[Unset, S3CopyRequestFileFormat]): File format of the S3 data Default: - S3CopyRequestFileFormat.PARQUET. 
- csv_delimiter (Union[None, Unset, str]): CSV delimiter Default: ','. - csv_header (Union[None, Unset, bool]): CSV has header row Default: True. - csv_quote (Union[None, Unset, str]): CSV quote character Default: '\\"'. - csv_escape (Union[None, Unset, str]): CSV escape character Default: '\\'. - csv_skip (Union[None, Unset, int]): Number of rows to skip Default: 0. - allow_moved_paths (Union[None, Unset, bool]): Allow moved paths for Iceberg tables Default: False. - max_file_size_gb (Union[None, Unset, int]): Maximum total file size limit in GB Default: 10. + Copies data from S3 buckets into graph database tables using user-provided + AWS credentials. Supports various file formats and bulk loading options. + + Attributes: + table_name (str): Target Kuzu table name + s3_path (str): Full S3 path (s3://bucket/key or s3://bucket/prefix/*.parquet) + s3_access_key_id (str): AWS access key ID for S3 access + s3_secret_access_key (str): AWS secret access key for S3 access + ignore_errors (Union[Unset, bool]): Skip duplicate/invalid rows (enables upsert-like behavior) Default: True. + extended_timeout (Union[Unset, bool]): Use extended timeout for large datasets Default: False. + validate_schema (Union[Unset, bool]): Validate source schema against target table Default: True. + source_type (Union[Literal['s3'], Unset]): Source type identifier Default: 's3'. + s3_session_token (Union[None, Unset, str]): AWS session token (for temporary credentials) + s3_region (Union[None, Unset, str]): S3 region Default: 'us-east-1'. + s3_endpoint (Union[None, Unset, str]): Custom S3 endpoint (for S3-compatible storage) + s3_url_style (Union[None, S3CopyRequestS3UrlStyleType0, Unset]): S3 URL style (vhost or path) + file_format (Union[Unset, S3CopyRequestFileFormat]): File format of the S3 data Default: + S3CopyRequestFileFormat.PARQUET. + csv_delimiter (Union[None, Unset, str]): CSV delimiter Default: ','. + csv_header (Union[None, Unset, bool]): CSV has header row Default: True. 
+ csv_quote (Union[None, Unset, str]): CSV quote character Default: '\\"'. + csv_escape (Union[None, Unset, str]): CSV escape character Default: '\\'. + csv_skip (Union[None, Unset, int]): Number of rows to skip Default: 0. + allow_moved_paths (Union[None, Unset, bool]): Allow moved paths for Iceberg tables Default: False. + max_file_size_gb (Union[None, Unset, int]): Maximum total file size limit in GB Default: 10. """ table_name: str