From 388ae17a1f9484c7c1caa3cd2e181180f43a3434 Mon Sep 17 00:00:00 2001 From: Venkata Ramireddy Seelam Date: Wed, 20 May 2026 14:57:29 -0700 Subject: [PATCH] fix(client): normalize WebSocket URI path to prevent double suffix (#737) _get_websocket_uri previously concatenated /task/service unconditionally, producing double-slash or a duplicated path segment when the caller passed a URI with a trailing slash or an already-correct path. Now strips trailing slashes then appends the suffix only when absent, and extracts the suffix string into a module-level constant. --- .../src/rocketride/mixins/connection.py | 529 +++++++++--------- 1 file changed, 271 insertions(+), 258 deletions(-) diff --git a/packages/client-python/src/rocketride/mixins/connection.py b/packages/client-python/src/rocketride/mixins/connection.py index 2d2f15383..dd72e6020 100644 --- a/packages/client-python/src/rocketride/mixins/connection.py +++ b/packages/client-python/src/rocketride/mixins/connection.py @@ -24,25 +24,38 @@ Connection Management for RocketRide Client. This module handles connecting to and disconnecting from RocketRide servers. -It manages the WebSocket connection lifecycle, authentication, and status tracking. +It manages the WebSocket connection lifecycle and provides status checking. The connection system automatically handles: - WebSocket connection establishment -- Authentication with your credential (API key, Zitadel access_token, or rr_ user token) +- Authentication with your API key - Connection status tracking - Automatic reconnection on disconnects (when persist=True) - Graceful disconnection and cleanup Usage: - client = RocketRideClient(uri="http://localhost:8080") - result = await client.connect("your_api_key") - # result is a ConnectResult with full user identity and organizations + # Manual connection management + client = RocketRideClient(auth="your_api_key", uri="https://cloud.rocketride.ai") + await client.connect() + # Check if connected if client.is_connected(): # Do work with the client pass await client.disconnect() + + # Automatic connection management (recommended) + async with RocketRideClient(auth="your_api_key") as client: + # Client automatically connects here + # Do work with connected client + pass + # Client automatically disconnects here + + # Persistent connection with auto-reconnect + client = RocketRideClient(auth="your_api_key", persist=True) + await client.connect() + # Connection will automatically reconnect if dropped (exponential backoff) """ # Design: Physical connect/disconnect live in _internal_connect and _internal_disconnect. @@ -51,12 +64,13 @@ # reconnect-on-disconnect are scheduled via _schedule_reconnect / _attempt_reconnect. import asyncio -import os +import time import urllib.parse from typing import Any, Dict, Optional from ..core import DAPClient, TransportWebSocket, CONST_DEFAULT_WEB_PORT, CONST_DEFAULT_WEB_PROTOCOL from ..core.exceptions import AuthenticationException -from ..types.client import ConnectResult + +_TASK_SERVICE_SUFFIX = "/task/service" class ConnectionMixin(DAPClient): @@ -69,10 +83,15 @@ class ConnectionMixin(DAPClient): Key Features: - Establishes secure WebSocket connections to RocketRide servers - - Single connect(credential) call authenticates and returns ConnectResult + - Authenticates using your API key or access token - Tracks connection status for reliable operations - Automatic reconnection on disconnect (when persist=True) - Provides graceful connection cleanup + - Supports both manual and automatic connection management + + This is automatically included when you use RocketRideClient, so you can + call client.connect() and client.disconnect() directly without needing + to import this mixin. """ def __init__(self, persist: bool = False, max_retry_time: Optional[float] = None, **kwargs): @@ -80,288 +99,224 @@ def __init__(self, persist: bool = False, max_retry_time: Optional[float] = None Initialize connection management. Args: - persist: Enable automatic reconnection on disconnect. - max_retry_time: Deprecated — accepted but ignored. Reconnection - uses linear backoff (0.25s increments, 15s cap) and never gives up. - **kwargs: Additional arguments passed to parent class. + persist: Enable automatic reconnection on disconnect + max_retry_time: Max total time in ms to keep retrying (None = forever) + **kwargs: Additional arguments passed to parent class """ super().__init__(**kwargs) self._persist = persist - # Desired state model — replaces old flag soup - self._desired_state: str = 'detached' # 'detached' | 'attached' | 'authenticated' - self._authenticated: bool = False - self._reconnect_timer: Optional[asyncio.Task] = None - self._current_reconnect_delay: float = 0.25 # seconds; +0.25 per failure, cap 15s + self._max_retry_time = max_retry_time # ms; None = retry forever + self._retry_start_time: Optional[float] = None # when first failure occurred; used to enforce max_retry_time + self._current_reconnect_delay: float = 0.25 # seconds until next retry; doubled each failure, capped at 2.5s + self._manual_disconnect = False # True only after user calls disconnect(); stops on_disconnected from scheduling reconnect + self._reconnect_task: Optional[asyncio.Task] = None # task that sleeps then calls _attempt_connection + self._did_notify_connected = False # True after we called on_connected; gates whether we invoke user on_disconnected async def on_connected(self, connection_info: Optional[str] = None) -> None: - """Handle transport-level connection event (before auth).""" - await super().on_connected(connection_info) - - async def on_disconnected(self, reason: Optional[str] = None, has_error: bool = False) -> None: """ - Handle transport disconnection. + Handle connection established event. - Clears transport and auth state, chains to parent, then consults - ``_desired_state`` to decide whether to reconnect. + Resets manual disconnect flag and delegates to parent. """ - self._transport = None - self._connect_result = None - self._authenticated = False + # We just connected successfully; clear "user asked to disconnect" so future drops can trigger reconnect + self._manual_disconnect = False - await super().on_disconnected(reason, has_error) - - # Reconnect engine: honour _desired_state - if self._desired_state == 'detached': - return - if not self._persist: - self._desired_state = 'detached' - return - if self._reconnect_timer and not self._reconnect_timer.done(): - return # engine already active + # Record that we notified connected so on_disconnected only invokes user callback if we had connected + self._did_notify_connected = True + # Reset backoff so the next reconnect (if any) starts from the initial delay self._current_reconnect_delay = 0.25 - self._schedule_reconnect() + self._retry_start_time = None + await super().on_connected(connection_info) - # ========================================================================= - # INTERNAL HELPERS - # ========================================================================= + async def on_disconnected(self, reason: Optional[str] = None, has_error: bool = False) -> None: + """ + Handle disconnection event. - async def _internal_attach(self, timeout: Optional[float] = None) -> None: - """Create transport if needed and open the WebSocket. No auth.""" + Only invokes the user's on_disconnected if on_connected had previously been called. + Schedules reconnection if persist is enabled and not a manual disconnect. + """ + # Only tell the user we disconnected if we had previously told them we connected + if self._did_notify_connected: + self._did_notify_connected = False + await super().on_disconnected(reason, has_error) + + # Transport called us because the connection closed. If user didn't ask to disconnect + # and we're in persist mode, schedule a reconnect (after backoff delay). + if self._persist and not self._manual_disconnect: + await self._schedule_reconnect() + + # --- Single place for physical connect: create transport if needed, connect, auth, on_connected --- + async def _internal_connect(self, timeout: Optional[float] = None) -> None: + """ + Create transport if needed, connect, send auth, and notify on_connected. + Single place for physical connection. Raises on failure. + """ + # Reuse existing transport if we have one (e.g. retry after failure); otherwise create with current uri/auth if self._transport is None: - self._transport = TransportWebSocket(uri=self._uri) + self._transport = TransportWebSocket(uri=self._uri, auth=self._apikey) self._bind_transport(self._transport) - await DAPClient.connect(self, timeout) - async def _internal_login(self, timeout: Optional[float] = None) -> Dict[str, Any]: - """Send the ``auth`` DAP command over the open transport.""" - auth_args: Dict[str, Any] = {'auth': self._apikey or ''} - if getattr(self, '_client_display_name', None): - auth_args['clientName'] = self._client_display_name - if getattr(self, '_client_display_version', None): - auth_args['clientVersion'] = self._client_display_version - - request = { - 'type': 'request', - 'command': 'auth', - 'seq': self._next_seq(), - 'arguments': auth_args, - } - try: - response = await self.request(request, timeout=timeout) - except Exception: - raise - if not response.get('success', False): - message = response.get('message', 'Authentication failed') - raise AuthenticationException({'message': message}) - - auth_body = response.get('body') or {} - self._connect_result = auth_body # type: ignore[assignment] - self._authenticated = True - - # Store userToken for future reconnects - if auth_body.get('userToken'): - self._apikey = auth_body['userToken'] - - connection_info = self._transport.get_connection_info() if self._transport else None - await super().on_connected(connection_info) - return auth_body + # DAPClient.connect does: transport.connect() (socket), then auth request, then on_connected() + await DAPClient.connect(self, timeout) - async def _internal_logout(self) -> None: - """Send ``deauth`` DAP command to revert to unauthenticated.""" - if not self._authenticated or not self._transport or not self._transport.is_connected(): - return - try: - request = {'type': 'request', 'command': 'deauth', 'seq': self._next_seq(), 'arguments': {}} - await self.request(request) - except Exception: - pass # Best-effort - self._connect_result = None - self._authenticated = False - - async def _internal_disconnect(self) -> None: - """Close and clean up the transport.""" + # --- Single place for physical disconnect: close transport; it will call on_disconnected --- + async def _internal_disconnect(self, reason: Optional[str] = None, has_error: bool = False) -> None: + """ + Close and clean up the transport. Transport invokes on_disconnected when it closes. + Single place for physical disconnect. + """ if self._transport is None: return - await self._transport.disconnect() - def _clear_reconnect_timer(self) -> None: - """Cancel the reconnect task if active.""" - if self._reconnect_timer and not self._reconnect_timer.done(): - self._reconnect_timer.cancel() - self._reconnect_timer = None + # Transport will close the socket and then call our on_disconnected callback + await self._transport.disconnect(reason, has_error) - def _schedule_reconnect(self) -> None: - """Schedule a reconnect attempt driven by ``_desired_state``.""" - self._debug_message(f'Scheduling reconnect in {self._current_reconnect_delay}s') - self._reconnect_timer = asyncio.create_task(self._do_reconnect()) - - async def _do_reconnect(self) -> None: - """Reconnect engine: sleep, re-attach, optionally re-login.""" - await asyncio.sleep(self._current_reconnect_delay) + # --- Persist-mode: one attempt; on failure notify and maybe reschedule with backoff --- + async def _attempt_connection(self, timeout: Optional[float] = None) -> None: + """ + Try _internal_connect; on auth error notify and stop; on other error notify and reschedule with backoff. + Used by persist-mode connect() and by the reconnect task. + """ try: - await self._internal_attach() - if self._desired_state == 'detached': - self._reconnect_timer = None - return - - if self._desired_state == 'authenticated': - await self._internal_login() - if self._desired_state == 'detached': - self._reconnect_timer = None - return - - # Success — reset backoff - self._reconnect_timer = None - self._current_reconnect_delay = 0.25 - self._debug_message('Reconnect successful') + await self._internal_connect(timeout) + # on_connected (invoked by _internal_connect) already resets backoff and retry clock + self._reconnect_task = None # clear completed task reference + self._debug_message('Reconnection successful') except AuthenticationException as e: - # Auth rejected — downgrade to attached, stop retrying auth - if self._desired_state == 'detached': - self._reconnect_timer = None - return - self._desired_state = 'attached' - self._reconnect_timer = None + self._debug_message(f'Reconnection failed (auth): {e}') await self.on_connect_error(e) + + # Auth failures won't fix themselves; don't reschedule + return except Exception as e: - if self._desired_state == 'detached': - self._reconnect_timer = None - return - # Transient failure — linear backoff, cap at 15s - self._current_reconnect_delay = min(self._current_reconnect_delay + 0.25, 15.0) + self._debug_message(f'Reconnection failed: {e}') await self.on_connect_error(e) - self._schedule_reconnect() # replaces timer - # ========================================================================= - # PUBLIC API — TRANSPORT - # ========================================================================= + # Start the retry clock on first failure so we can enforce max_retry_time + if self._retry_start_time is None: + self._retry_start_time = time.monotonic() - async def attach(self, uri: Optional[str] = None, *, timeout: Optional[float] = None) -> None: - """ - Attach to a RocketRide server (open WebSocket, no auth). + # Stop retrying if we've exceeded the total retry window + if self._max_retry_time is not None: + if time.monotonic() - self._retry_start_time >= self._max_retry_time / 1000.0: + return - If ``uri`` is provided and differs from the current URI, detaches - first. If already attached to the same URI, this is a no-op. - """ - if uri: - normalised = self._get_websocket_uri(uri) if hasattr(self, '_get_websocket_uri') else uri - if normalised != self._uri: - if self.is_attached(): - await self.detach() - self._set_uri(normalised) - if self.is_attached(): - if self._desired_state == 'detached': - self._desired_state = 'attached' - return - self._desired_state = 'attached' - await self._internal_attach(timeout) - - async def detach(self) -> None: - """Detach from the server (close WebSocket, cancel reconnection).""" - self._desired_state = 'detached' - self._clear_reconnect_timer() - self._authenticated = False - self._connect_result = None - if self._transport and self._transport.is_connected(): - await self._internal_disconnect() + # Exponential backoff: next attempt will wait longer (cap at 2.5s) + self._current_reconnect_delay = min(self._current_reconnect_delay * 2, 2.5) + await self._schedule_reconnect() + + async def _schedule_reconnect(self) -> None: + """Schedule a reconnection attempt with exponential backoff.""" + # Only one reconnect task at a time; cancel any existing one before scheduling again + if self._reconnect_task and not self._reconnect_task.done(): + self._reconnect_task.cancel() + + # If we've been retrying longer than max_retry_time, give up and notify the user + if self._max_retry_time is not None and self._retry_start_time is not None: + if time.monotonic() - self._retry_start_time >= self._max_retry_time / 1000.0: + await self.on_connect_error(Exception('Max retry time exceeded')) + return - def is_attached(self) -> bool: - """True when the WebSocket transport is connected (regardless of auth).""" - return self._transport is not None and self._transport.is_connected() + # Run _attempt_reconnect after _current_reconnect_delay seconds (backoff) + self._debug_message(f'Scheduling reconnection in {self._current_reconnect_delay}s') + self._reconnect_task = asyncio.create_task(self._attempt_reconnect()) - # ========================================================================= - # PUBLIC API — AUTH - # ========================================================================= + async def _attempt_reconnect(self) -> None: + """Sleep then call _attempt_connection (used by scheduled reconnect).""" + # Wait before retrying so we don't hammer the server (exponential backoff) + await asyncio.sleep(self._current_reconnect_delay) + + # User may have called disconnect() while we were sleeping; only try if still persist and not manual disconnect + if self._persist and not self._manual_disconnect: + self._debug_message('Attempting to reconnect...') + await self._attempt_connection() - async def login( + async def connect( self, - credential: Optional[str] = None, - *, uri: Optional[str] = None, + auth: Optional[str] = None, timeout: Optional[float] = None, - ) -> ConnectResult: + ) -> None: """ - Authenticate over an attached transport. + Connect to the RocketRide server. - If ``uri`` is provided and differs, detaches and re-attaches first. - If ``credential`` is provided and differs from the current credential, - logs out (best-effort) before logging in with the new credential. - If already authenticated with the same credential, this is a no-op. - """ - resolved = credential or self._apikey or os.environ.get('ROCKETRIDE_APIKEY', '') - - # URI change → detach + re-attach - if uri: - normalised = self._get_websocket_uri(uri) if hasattr(self, '_get_websocket_uri') else uri - if normalised != self._uri: - await self.detach() - self._set_uri(normalised) - await self._internal_attach(timeout) - - # Ensure attached - if not self.is_attached(): - await self._internal_attach(timeout) - - # Auth change → logout first (best-effort) - if resolved != self._apikey and self._authenticated: + Must be called before executing pipelines or other operations. + In persist mode, enables automatic reconnection and retries from initial failure + (calls on_connect_error on each failed attempt and keeps retrying). + + Args: + uri: Optional; if provided, updates the server URI before connecting. + auth: Optional; if provided, updates the API key before connecting. + timeout: Optional overall timeout in ms for the connect + auth handshake. + + Examples: + # Manual connection management + await client.connect() try: - await self._internal_logout() - except Exception: + # do work pass - self._set_auth(resolved) - - # Already authenticated with same credential → no-op - if self._authenticated: - self._desired_state = 'authenticated' - return self._connect_result or {} # type: ignore[return-value] + finally: + await client.disconnect() - self._desired_state = 'authenticated' - return await self._internal_login(timeout) + # Automatic connection management (preferred) + async with client: + # connection automatically managed + pass + """ + # Apply optional params so they're used for this connect (and by any new transport we create) + if uri is not None: + self._set_uri(uri) + if auth is not None: + self._set_auth(auth) + + # Fresh connect: we're not in "user asked to disconnect" state, and backoff starts from initial delay + self._manual_disconnect = False + self._current_reconnect_delay = 0.25 + self._retry_start_time = None - async def logout(self) -> None: - """Deauthenticate: sends ``deauth`` to server, clears client auth state.""" - await self._internal_logout() - self._desired_state = 'attached' + # Idempotent connect: if already connected, disconnect first so we reconnect with current (maybe new) params + if self.is_connected(): + await self._internal_disconnect() - def is_authenticated(self) -> bool: - """True when the auth handshake has succeeded on the current connection.""" - return self._authenticated + if self._persist: + # Cancel any pending reconnect from a previous drop; we're doing an explicit connect now + if self._reconnect_task and not self._reconnect_task.done(): + self._reconnect_task.cancel() + self._reconnect_task = None - # ========================================================================= - # COMPAT API — connect() / disconnect() - # ========================================================================= + # First attempt runs here; if it fails, _attempt_connection will schedule the next try + await self._attempt_connection(timeout) + else: + # Non-persist: one shot; no retry scheduling + await self._internal_connect(timeout) - async def connect( - self, - credential: Optional[str] = None, - *, - timeout: Optional[float] = None, - ) -> ConnectResult: + async def disconnect(self) -> None: """ - Connect and authenticate in a single call (backward compatible). + Disconnect from the RocketRide server and stop automatic reconnection. - Wraps ``attach()`` + ``login()``. Sends the credential as the first - DAP message and returns the full ConnectResult on success. + Should be called when finished with the client to clean up resources. + Context managers handle this automatically. """ - self._current_reconnect_delay = 0.25 - await self.attach(timeout=timeout) - return await self.login(credential, timeout=timeout) - - def get_account_info(self) -> Optional[ConnectResult]: - """Return the ConnectResult from the last successful login().""" - return self._connect_result - - async def disconnect(self) -> None: - """Disconnect (backward compatible). Wraps ``logout()`` + ``detach()``.""" - await self.logout() - await self.detach() - - # ========================================================================= - # HELPERS - # ========================================================================= + # Set before we disconnect so that when the transport closes and calls on_disconnected, + # we won't schedule a reconnect (user explicitly asked to disconnect) + self._manual_disconnect = True + + # Stop any scheduled reconnect; user said disconnect + if self._reconnect_task and not self._reconnect_task.done(): + self._reconnect_task.cancel() + self._reconnect_task = None + if self._transport is not None and self.is_connected(): + await self._internal_disconnect() def get_connection_info(self) -> dict: - """Return current connection state and URI.""" + """ + Return current connection state and URI. + + Returns a dict with ``connected`` (bool), ``transport`` (str), + and ``uri`` (str). Useful for debugging or displaying + "Connected to …" in the UI. + """ return { 'connected': self.is_connected(), 'transport': 'WebSocket', @@ -369,40 +324,55 @@ def get_connection_info(self) -> dict: } def get_apikey(self) -> Optional[str]: - """Return the API key in use. For debugging only.""" + """ + Return the API key in use. + + For debugging only; avoid logging or exposing in production. + """ return getattr(self, '_apikey', None) @staticmethod def normalize_uri(uri: str) -> str: - """Normalize a user-provided URI into a fully-formed HTTP/HTTPS URL.""" + """Normalize a user-provided URI into a fully-formed HTTP/HTTPS URL. + + - Bare hostnames (e.g. "localhost", "my-server:5565") get ``http://`` prepended. + - Non-cloud URIs without a port default to 5565. + + Use this when you need a parseable URL from free-form user input before + passing it to the client or doing your own validation. + """ if uri and '://' not in uri: uri = f'{CONST_DEFAULT_WEB_PROTOCOL}{uri}' parsed = urllib.parse.urlparse(uri) if not parsed.port and 'rocketride.ai' not in (parsed.hostname or ''): - hostname = parsed.hostname - if not hostname: - raise ValueError(f"Invalid URI '{uri}': missing hostname") - parsed = parsed._replace(netloc=f'{hostname}:{CONST_DEFAULT_WEB_PORT}') + parsed = parsed._replace(netloc=f'{parsed.hostname}:{CONST_DEFAULT_WEB_PORT}') return parsed.geturl() @staticmethod - def _get_websocket_uri(uri: str, ws_path: str = '/task/service') -> str: + def _get_websocket_uri(uri: str) -> str: """Normalize a user-provided URI into a fully-formed WebSocket address. - Args: - uri: Raw URI (bare host:port, http://, https://, ws://, wss://). - ws_path: WebSocket endpoint path (default: '/task/service'). - Use '/models' for the model server. + Builds on normalize_uri, then converts to ws/wss and appends /task/service + only when the path does not already end with it. + + Fixes #737: previously the suffix was unconditionally concatenated, causing + double-slash when the input had a trailing slash, and path duplication when + /task/service was already present. """ normalized = ConnectionMixin.normalize_uri(uri) parsed = urllib.parse.urlparse(normalized) ws_scheme = 'wss' if parsed.scheme in ('https', 'wss') else 'ws' - normalized_ws_path = f'/{ws_path.lstrip("/")}' - ws_uri = parsed._replace(scheme=ws_scheme, path=normalized_ws_path, params='', query='', fragment='') + + # Normalize the path: strip trailing slashes, then append suffix if absent. + path = parsed.path.rstrip('/') + if not path.endswith(_TASK_SERVICE_SUFFIX): + path = path + _TASK_SERVICE_SUFFIX + + ws_uri = parsed._replace(scheme=ws_scheme, path=path) return ws_uri.geturl() def _set_uri(self, uri: str) -> None: @@ -413,10 +383,53 @@ def _set_auth(self, auth: str) -> None: """Update the authentication credential (internal).""" self._apikey = auth - def set_env(self, env: Dict[str, str]) -> None: - """Update the environment variables used for pipeline substitution.""" - self._env = dict(env) + async def set_connection_params( + self, + uri: Optional[str] = None, + auth: Optional[str] = None, + ) -> None: + """ + Update server URI and/or auth. If currently connected, disconnects and + reconnects with the new params. In persist mode, reconnection is scheduled + only if we were connected (no auto-connect when params are set on a never-connected client). + In non-persist mode, reconnects only if we were connected. + """ + # --- Update params, tear down existing connection/transport, then reconnect (or schedule) only if appropriate --- + if uri is not None: + self._set_uri(uri) + if auth is not None: + self._set_auth(auth) + + # Remember whether we were connected so we know to disconnect and whether to reconnect at the end + was_already_connected = self.is_connected() + + # Prevent on_disconnected (from the disconnect below) from scheduling a reconnect during teardown + self._manual_disconnect = True + if self._reconnect_task and not self._reconnect_task.done(): + self._reconnect_task.cancel() + self._reconnect_task = None + if was_already_connected: + await self._internal_disconnect() + + # Drop the transport so the next connect() builds a new one with the new uri/auth + self._transport = None + if self._persist and was_already_connected: + # Schedule a single reconnect attempt (after backoff); only if we were connected (no auto-connect on param set) + await self._schedule_reconnect() + elif was_already_connected: + # Non-persist: only reconnect if we had been connected (same as connect() semantics) + await self._internal_connect() + + # We're done; clear so future disconnects (e.g. drop) can trigger reconnect if persist + self._manual_disconnect = False async def request(self, request: Dict[str, Any], timeout: Optional[float] = None) -> Dict[str, Any]: - """Send a request to the RocketRide server.""" + """ + Send a request to the RocketRide server. + + Args: + request: The DAP message to send + timeout: Optional per-request timeout in ms. Overrides the default request_timeout. + """ + # Delegate to parent class for actual request processing return await super().request(request, timeout=timeout)