diff --git a/.gitignore b/.gitignore index 8ee7582a..1048e009 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ __pycache__ ibind.egg-info build/docs/.generated-files.txt build/docs/mkdocs.yml +build/lib/ .venv .env @@ -11,4 +12,4 @@ build/docs/mkdocs.yml .vscode .DS_Store venv -.coverage \ No newline at end of file +.coverage diff --git a/TESTING.md b/TESTING.md new file mode 100644 index 00000000..6b503fe1 --- /dev/null +++ b/TESTING.md @@ -0,0 +1,88 @@ +# Testing Guide + +This document defines how tests are chosen, written, and evaluated. + +--- + +## Testing philosophy + +- Use commentary with `## Arrange` `## Act` and `## Assert` sections for test structure. +- Tests exist to lock behaviour, not to chase coverage. +- Avoid tests that duplicate what the language/runtime already guarantees. +- Use fixtures for setup/teardown of test state. +- Mock internal dependencies for unit tests. Mock external dependencies for integration tests. +- Prefer integration tests for verifying component boundaries and data flow. +- Capture and assert on logs for error and warning conditions. + + +## Test Type Structure + +Tests are organized into three main categories: + +``` +test/ +├── unit/ # Fast, isolated tests for core logic +├── integration/ # Multi-component tests +└── manual/ # Manual and performance tests +``` + + +## Test types and boundaries + +Use the lightest test type that still provides confidence. + +### Unit tests +Use when: +- logic is isolated and deterministic +- behaviour can be validated without wiring other components + +Guidelines: +- No network, filesystem, threads, or time dependence. +- Mock only at clear boundaries; do not mock internals of the unit under test. +- Data should be small, synthetic, and explicit. +- Prefer clarity over clever parametrisation. + +### Integration tests +Use when: +- correctness depends on interaction between components +- data flow or ordering matters + +Guidelines: +- Mock only outside the test boundary (eg. broker, network). +- Use realistic but minimal fixtures. +- Allow threads/timers only if they are part of the behaviour being tested. +- Failures should clearly indicate which interaction broke. + +### Manual / performance tests +Use when: +- validating full-system flows +- measuring throughput, latency, or concurrency +- interacting with real or near-real external systems + +Guidelines: +- Never run automatically in CI. +- Keep secrets out of test code. +- Prefer recorded or replayable inputs where possible. +- Treat results as diagnostic, not pass/fail gates. + + +## Choosing what to test + +Test: +- decision logic +- state transitions +- boundary conditions +- error and warning paths +- behaviour that has broken before + +Do not test: +- trivial getters/setters +- pure delegation +- obvious library behaviour +- formatting or logging text unless it signals correctness + + +## Running tests + +- Prefer running the smallest relevant subset while iterating. +- Run broader suites when touching core or high-risk code. diff --git a/examples/rest_06_options_chain.py b/examples/rest_06_options_chain.py index 510e922f..3f11d8af 100644 --- a/examples/rest_06_options_chain.py +++ b/examples/rest_06_options_chain.py @@ -136,4 +136,4 @@ response = client.place_order(order_request, answers, account_id).data -print(response) \ No newline at end of file +print(response) diff --git a/examples/ws_04_ws_v2.py b/examples/ws_04_ws_v2.py new file mode 100644 index 00000000..534ad33b --- /dev/null +++ b/examples/ws_04_ws_v2.py @@ -0,0 +1,117 @@ +""" +WebSocket Intermediate + +In this example we: + +* Demonstrate subscription to multiple channels +* Utilise queue accessors +* Use the 'signal' module to ensure we unsubscribe and shutdown upon the program termination + +Assumes the Gateway is deployed at 'localhost:5000' and the IBIND_ACCOUNT_ID and IBIND_CACERT environment variables have been set. +""" + +import os +import time +from typing import List + +from ibind import events, IbkrWsClientV2, LogSink, QueueSink, CallbackSink, CompositeSink, ibind_logs_initialize +from ibind.subscriptions import MarketDataSubscription, OrdersSubscription, AccountLedgerSubscription, AccountSummarySubscription, PnlSubscription, TradesSubscription, MarketHistorySubscription, SubscriptionHandle + +ibind_logs_initialize(log_to_file=False, log_level='INFO') + +account_id = os.getenv('IBIND_ACCOUNT_ID', '[YOUR_ACCOUNT_ID]') +cacert = os.getenv('IBIND_CACERT', False) # insert your cacert path here + +# Queue Sink - queue-based event consumer +queue_sink = QueueSink() + +# Callback Sink - callback-based event consumer +callback_sink = CallbackSink() + + +def on_market_data(event: events.MarketData): + print(event) + +def on_market_history(event: events.MarketHistory): + print(event) + +def on_lifecycle(event: events.LifecycleEvent): + print(event) + +callback_sink.on(events.MarketData, on_market_data) +callback_sink.on(events.MarketHistory, on_market_data) +callback_sink.on(events.WsOpen, on_lifecycle) +callback_sink.on(events.WsClose, on_lifecycle) +callback_sink.on(events.WsError, on_lifecycle) +callback_sink.on(events.WsAuthenticated, on_lifecycle) +callback_sink.on(events.WsReady, on_lifecycle) +callback_sink.on(events.WsDegraded, on_lifecycle) + +# Log Sink - useful for debugging +log_sink = LogSink() + +# Composite Sink - allows us to use all above sinks at once +composite_sink = CompositeSink(callback_sink, log_sink) + +# ws_client = IbkrWsClient(cacert=cacert, account_id=account_id) +# ws_client = IbkrWsClientV2(cacert=cacert, account_id=account_id, sink=LogSink()) +ws_client = IbkrWsClientV2(cacert=cacert, account_id=account_id, sink=queue_sink) + + +ws_client.start() + +as_sub = AccountSummarySubscription(account_id=account_id) +al_sub = AccountLedgerSubscription(account_id=account_id) +md_sub = MarketDataSubscription(conid='265598', fields=["31", "84", "86"], expiry_seconds=30) +mh_sub = MarketHistorySubscription(conid='265598') +or_sub = OrdersSubscription() +# pl_sub = PriceLadderSubscription(conid='265598', account_id=account_id, exchange='SMART') +pnl_sub = PnlSubscription() +tr_sub = TradesSubscription() +subs = [ + # as_sub, + # al_sub, + md_sub, + # mh_sub, + # or_sub, + # pnl_sub, + # tr_sub +] + +sub_handles: List[SubscriptionHandle] = [] +for sub in subs: + handle = ws_client.subscribe(sub) + handle.wait() + sub_handles.append(handle) + +for handle in sub_handles: + success = handle.wait(timeout=10) + if not success: + print('Subscription not active within 10 seconds') + +try: + while ws_client.is_running(): + for sub in subs: + while not queue_sink.empty(sub.event_type): + ev = queue_sink.get(sub.event_type) + print(ev) + + time.sleep(1) +except KeyboardInterrupt: + print('Interrupt') + +for handle in sub_handles: + unsub_handle = handle.unsubscribe() + success = unsub_handle.wait(timeout=10) + if not success: + print('Subscription not unsubscribed within 10 seconds') + +# unsub_handles = [] +# for sub in subs: +# handle = ws_client.unsubscribe(sub) +# unsub_handles.append(handle) +# +# for handle in unsub_handles: +# handle.wait(10) + +ws_client.shutdown() diff --git a/ibind/__init__.py b/ibind/__init__.py index 6a90b201..7b9f4c68 100644 --- a/ibind/__init__.py +++ b/ibind/__init__.py @@ -10,6 +10,11 @@ from ibind.support.errors import ExternalBrokerError from ibind.support.logs import ibind_logs_initialize from ibind.support.py_utils import execute_in_parallel +from ibind import events, subscriptions +from ibind.ibkr_ws_v2.ibkr_ws_client_v2 import IbkrWsClientV2 +from ibind.ws_v2._ws_events import LogSink, QueueSink, CallbackSink, CompositeSink, NoopSink, EventSink +from ibind.ws_v2.ws_subscriptions import SubscriptionHandle + __all__ = [ 'ibind_logs_initialize', @@ -28,7 +33,17 @@ 'QueueAccessor', 'execute_in_parallel', 'ExternalBrokerError', - 'question_type_to_message_id' + 'question_type_to_message_id', + 'events', + 'subscriptions', + 'IbkrWsClientV2', + 'EventSink', + 'NoopSink', + 'LogSink', + 'QueueSink', + 'CallbackSink', + 'CompositeSink', + 'SubscriptionHandle', ] # patch_dotenv() diff --git a/ibind/base/queue_controller.py b/ibind/base/queue_controller.py index 62b69e0f..5bc217f3 100644 --- a/ibind/base/queue_controller.py +++ b/ibind/base/queue_controller.py @@ -1,10 +1,9 @@ -from enum import Enum from queue import Queue, Empty -from typing import TypeVar, Generic, Any +from typing import TypeVar, Generic, Any, Hashable from ibind.support.py_utils import ensure_list_arg, OneOrMany -T = TypeVar('T', str, Enum) +T = TypeVar('T', bound=Hashable) class QueueAccessor(Generic[T]): # pragma: no cover diff --git a/ibind/base/rest_client.py b/ibind/base/rest_client.py index ec063ec2..d835845a 100644 --- a/ibind/base/rest_client.py +++ b/ibind/base/rest_client.py @@ -49,7 +49,10 @@ def copy(self, data: Optional[Union[list, dict]] = UNDEFINED, request: Optional[ Returns: Result: A new Result instance with the specified modifications. """ - return Result(data=data if data is not UNDEFINED else self.data.copy(), request=request if request is not UNDEFINED else self.request.copy()) + def _copy(value): + return value.copy() if hasattr(value, 'copy') else value + + return Result(data=data if data is not UNDEFINED else _copy(self.data), request=request if request is not UNDEFINED else _copy(self.request)) def pass_result(data: dict, old_result: Result) -> Result: @@ -117,6 +120,7 @@ def __init__( self.use_session = use_session self._auto_recreate_session = auto_recreate_session + self._closed = False if use_session: self.make_session() @@ -139,7 +143,7 @@ def logger(self): self._make_logger() return self._logger - def _get_headers(self, request_method: str, request_url: str): + def _get_headers(self, request_method: str, request_url: str, request_params: dict = None): return {} def get( @@ -224,23 +228,24 @@ def _request( endpoint = endpoint.lstrip('/') url = f'{base_url}{endpoint}' - headers = self._get_headers(request_method=method, request_url=url) - headers = {**headers, **(extra_headers or {})} - # we want to allow default values used by IBKR, so we remove all None parameters kwargs = filter_none(kwargs) - # choose which function should be used to make a reqeust based on use_session field - if self.use_session and self._session is not None: - request_function = self._session.request - else: - request_function = requests.request - - if request_function is None: - _LOGGER.warning(f'{self}: an attempt was made to create a request with no valid session.') + request_params = kwargs.get('params') if method.upper() == 'GET' else None + headers = self._get_headers(request_method=method, request_url=url, request_params=request_params) + headers = {**headers, **(extra_headers or {})} # we repeat the request attempts in case of ReadTimeouts up to max_retries for attempt in range(self._max_retries + 1): + # choose which function should be used to make a request based on the current session + if self.use_session and self._session is not None: + request_function = self._session.request + else: + request_function = requests.request + + if request_function is None: + _LOGGER.warning(f'{self}: an attempt was made to create a request with no valid session.') + if log: self.logger.info(f'{method} {url} {kwargs}{" (attempt: " + str(attempt) + ")" if attempt > 0 else ""}') @@ -308,6 +313,9 @@ def close_session(self): self._session = None def close(self): + if getattr(self, '_closed', False): + return + self._closed = True self.close_session() @@ -328,12 +336,7 @@ def register_shutdown_handler(self): existing_handler_int = signal.getsignal(signal.SIGINT) existing_handler_term = signal.getsignal(signal.SIGTERM) - self._closed = False - def _close_handler(): - if self._closed: - return - self._closed = True self.close() def _signal_handler(signum, frame): @@ -356,4 +359,4 @@ def _signal_handler(signum, frame): atexit.register(_close_handler) def __str__(self): - return f'{self.__class__.__qualname__}' \ No newline at end of file + return f'{self.__class__.__qualname__}' diff --git a/ibind/base/ws_client.py b/ibind/base/ws_client.py index 9da4e4a7..c3d53ade 100644 --- a/ibind/base/ws_client.py +++ b/ibind/base/ws_client.py @@ -1,8 +1,6 @@ import json -import ssl import threading import time -from pathlib import Path from threading import Thread, RLock from typing import Optional, Union, Dict, List @@ -10,7 +8,7 @@ from ibind.base.subscription_controller import SubscriptionController, SubscriptionProcessor from ibind.support.logs import project_logger -from ibind.support.py_utils import exception_to_string, wait_until, tname +from ibind.support.py_utils import exception_to_string, wait_until, tname, make_websocket_sslopt _LOGGER = project_logger(__file__) @@ -85,14 +83,9 @@ def __init__( self._thread = None self._thread_ids = {} self._next_thread_id = 0 + self._last_unanswered_ping_tm = None - if not (cacert is False or Path(cacert).exists()): - raise ValueError(f'{self}: cacert must be a valid Path or False') - - if cacert is None or not cacert: - self._sslopt = {'cert_reqs': ssl.CERT_NONE} - else: - self._sslopt = {'ca_certs': cacert} + self._sslopt = make_websocket_sslopt(cacert) def send(self, payload: str) -> bool: """ @@ -303,6 +296,7 @@ def _handle_on_message(self, wsa: WebSocketApp, message): # pragma: no cover def _handle_on_open(self, wsa: WebSocketApp): _LOGGER.info(f'{self}: Connection open') self._connected = True + self._last_unanswered_ping_tm = None self._on_open(wsa) def _handle_on_error(self, wsa: WebSocketApp, error): # pragma: no cover @@ -463,13 +457,25 @@ def check_ping(self) -> bool: if self._wsa is None: return True - if self._wsa.last_ping_tm == 0: + last_ping_tm = getattr(self._wsa, 'last_ping_tm', 0) + last_pong_tm = getattr(self._wsa, 'last_pong_tm', 0) + + if last_ping_tm == 0: return True - diff = abs(time.time() - self._wsa.last_ping_tm) + if last_pong_tm >= last_ping_tm and last_pong_tm != 0: + self._last_unanswered_ping_tm = None + diff = abs(time.time() - last_pong_tm) + else: + if self._last_unanswered_ping_tm is not None and last_pong_tm >= self._last_unanswered_ping_tm: + self._last_unanswered_ping_tm = None + if self._last_unanswered_ping_tm is None: + self._last_unanswered_ping_tm = last_ping_tm + diff = abs(time.time() - self._last_unanswered_ping_tm) + if diff > self._max_ping_interval: _LOGGER.warning( - f'{self}: Last WebSocket ping happened {diff: .2f} seconds ago, exceeding the max ping interval of {self._max_ping_interval}. Restarting.' + f'{self}: Last WebSocket pong happened {diff: .2f} seconds ago, exceeding the max ping interval of {self._max_ping_interval}. Restarting.' ) self.hard_reset(restart=True) return False diff --git a/ibind/client/ibkr_client.py b/ibind/client/ibkr_client.py index 8f38c40b..a926b9ff 100644 --- a/ibind/client/ibkr_client.py +++ b/ibind/client/ibkr_client.py @@ -134,7 +134,7 @@ def _request(self, method: str, endpoint: str, base_url: str = None, extra_heade raise ExternalBrokerError('IBKR returned 400 Bad Request: no bridge. Try calling `initialize_brokerage_session()` first.') from e raise - def _get_headers(self, request_method: str, request_url: str): + def _get_headers(self, request_method: str, request_url: str, request_params: dict = None): if (not self._use_oauth) or request_url == f'{self.base_url}{self.oauth_config.live_session_token_endpoint}': # No need for extra headers if we don't use oauth or getting live session token return {} @@ -143,7 +143,11 @@ def _get_headers(self, request_method: str, request_url: str): from ibind.oauth.oauth1a import generate_oauth_headers headers = generate_oauth_headers( - oauth_config=self.oauth_config, request_method=request_method, request_url=request_url, live_session_token=self.live_session_token + oauth_config=self.oauth_config, + request_method=request_method, + request_url=request_url, + live_session_token=self.live_session_token, + request_params=request_params, ) return headers @@ -251,6 +255,8 @@ def stop_tickler(self, timeout:float=None): self._tickler.stop(timeout) def close(self): + if getattr(self, '_closed', False): + return if self._use_oauth and self.oauth_config.shutdown_oauth: self.oauth_shutdown() super().close() @@ -320,7 +326,7 @@ def _attempt_health_check(self, method: callable, raise_exceptions: bool = False elif 'An attempt was made to access a socket in a way forbidden by its access permissions' in str(e): _LOGGER.error('Connection to IBKR servers blocked during reauthentication. Check that nothing is blocking connectivity of the application') elif e.status_code == 410 and 'gone' in str(e): - _LOGGER.error(f'OAuth 410 gone: recreate a new live session token, or try a different server, eg. "1.api.ibkr.com", "2.api.ibkr.com", etc.') + _LOGGER.error('OAuth 410 gone: recreate a new live session token, or try a different server, eg. "1.api.ibkr.com", "2.api.ibkr.com", etc.') else: _LOGGER.error(f'Unknown error checking IBKR connection during reauthentication: {exception_to_string(e)}') @@ -330,4 +336,4 @@ def _attempt_health_check(self, method: callable, raise_exceptions: bool = False _LOGGER.error(f'Error reauthenticating OAuth during reauthentication: {exception_to_string(e)}') if raise_exceptions: raise - return False \ No newline at end of file + return False diff --git a/ibind/client/ibkr_client_mixins/marketdata_mixin.py b/ibind/client/ibkr_client_mixins/marketdata_mixin.py index 7522ab30..6c5a1146 100644 --- a/ibind/client/ibkr_client_mixins/marketdata_mixin.py +++ b/ibind/client/ibkr_client_mixins/marketdata_mixin.py @@ -189,7 +189,8 @@ def marketdata_history_by_symbol( start_time (datetime.datetime, optional): Starting date of the request duration. """ - conid = str(self.stock_conid_by_symbol(symbol).data[symbol]) + symbol_key = symbol.symbol if isinstance(symbol, StockQuery) else symbol + conid = str(self.stock_conid_by_symbol(symbol).data[symbol_key]) return self.marketdata_history_by_conid(conid, bar, exchange, period, outside_rth, start_time) def marketdata_history_by_conids( diff --git a/ibind/client/ibkr_client_mixins/session_mixin.py b/ibind/client/ibkr_client_mixins/session_mixin.py index 2b68e733..b87f1f63 100644 --- a/ibind/client/ibkr_client_mixins/session_mixin.py +++ b/ibind/client/ibkr_client_mixins/session_mixin.py @@ -142,4 +142,4 @@ def check_auth_status(self: 'IbkrClient') -> bool: if result.data.get('authenticated', None) is None: raise AttributeError(f'Health check requests returns invalid data: {result}') - return _parse_auth_status(result.data) \ No newline at end of file + return _parse_auth_status(result.data) diff --git a/ibind/client/ibkr_utils.py b/ibind/client/ibkr_utils.py index 21cda07d..45dd0947 100644 --- a/ibind/client/ibkr_utils.py +++ b/ibind/client/ibkr_utils.py @@ -663,13 +663,16 @@ def start(self): This method creates and starts a daemon thread that periodically calls `tickle()` to keep the session alive. """ - if self._thread is not None: + if self._thread is not None and self._thread.is_alive(): _LOGGER.info('Tickler thread already running. Stop the existing thread first by calling Tickler.stop()') - return + return False + + self._thread = None self._stop_event.clear() self._thread = threading.Thread(target=self._worker, daemon=True) self._thread.start() + return True def stop(self, timeout:float=None): """ @@ -682,11 +685,18 @@ def stop(self, timeout:float=None): If None, waits indefinitely. """ if self._thread is None: - return + return True + thread = self._thread self._stop_event.set() # Wake up the sleeping thread immediately - self._thread.join(timeout) + thread.join(timeout) + + if thread.is_alive(): + _LOGGER.error(f'Tickler thread did not stop within timeout={timeout}. Keeping the thread reference to avoid starting a duplicate tickler.') + return False + self._thread = None # Ensure cleanup + return True def cleanup_market_history_responses( @@ -774,4 +784,4 @@ def cleanup_market_history_responses( } ) results[symbol] = records - return results \ No newline at end of file + return results diff --git a/ibind/client/ibkr_ws_client.py b/ibind/client/ibkr_ws_client.py index 30c50cdb..92db3102 100644 --- a/ibind/client/ibkr_ws_client.py +++ b/ibind/client/ibkr_ws_client.py @@ -15,7 +15,7 @@ from ibind.client.ibkr_utils import extract_conid from ibind.support.errors import ExternalBrokerError from ibind.support.logs import project_logger -from ibind.support.py_utils import TimeoutLock, UNDEFINED, wait_until +from ibind.support.py_utils import TimeoutLock, UNDEFINED, wait_until, append_query_params _LOGGER = project_logger(__file__) @@ -278,7 +278,8 @@ def __init__( raise ValueError( 'OAuth access token not found. Please set IBIND_OAUTH1A_ACCESS_TOKEN environment variable or provide it as `access_token` argument.' ) - url += f'?oauth_token={access_token}' + url = append_query_params(url, {'oauth_token': access_token}) + cacert = True if ibkr_client is None: ibkr_client = IbkrClient(account_id=account_id, host=host, port=port, cacert=cacert, use_oauth=use_oauth) @@ -401,12 +402,11 @@ def _handle_authentication_status(self, message, data): _LOGGER.error(f'{self}: Status unauthenticated: {data}') self.set_authenticated(data.get('authenticated')) elif 'competing' in data: - if data.get('competing') is False: - pass - _LOGGER.error(f'{self}: Status competing: {data}') + if data.get('competing') is True: + _LOGGER.error(f'{self}: Status competing: {data}') elif ( # expected status updates that we ignore data == {'message': ''} or - data.get('fail', '') == '' or + data.get('fail') == '' or 'serverName' in data or 'serverVersion' in data or 'username' in data @@ -711,4 +711,4 @@ def ts_changed(): if not wait_until(ts_changed, f'tic timeout, ts={ts}', timeout=5): return None - return self._tic_message \ No newline at end of file + return self._tic_message diff --git a/ibind/events/__init__.py b/ibind/events/__init__.py new file mode 100644 index 00000000..b86fcae3 --- /dev/null +++ b/ibind/events/__init__.py @@ -0,0 +1,33 @@ +from ibind.ws_v2._ws_events import LifecycleEvent, WsOpen, WsAuthenticated, WsDegraded, WsReady, WsClose, WsError, WsEvent +from ibind.ibkr_ws_v2.ibkr_events import GenericIbkrEvent, IbkrError, WaitingForSession, Notification, Bulletin, AccountUpdate, System, AuthenticationStatus, IbkrTopicEvent, AccountSummary, AccountLedger, MarketData, MarketHistory, Orders, PriceLadder, Pnl, Trades, ServerId, Unsubscription + + +__all__ = [ + 'LifecycleEvent', + 'WsEvent', + 'WsOpen', + 'WsAuthenticated', + 'WsDegraded', + 'WsReady', + 'WsClose', + 'WsError', + 'GenericIbkrEvent', + 'IbkrError', + 'WaitingForSession', + 'Notification', + 'Bulletin', + 'AccountUpdate', + 'System', + 'AuthenticationStatus', + 'IbkrTopicEvent', + 'AccountSummary', + 'AccountLedger', + 'MarketData', + 'MarketHistory', + 'Orders', + 'PriceLadder', + 'Pnl', + 'Trades', + 'ServerId', + 'Unsubscription', +] diff --git a/ibind/ibkr_ws_v2/__init__.py b/ibind/ibkr_ws_v2/__init__.py new file mode 100644 index 00000000..e3a13dae --- /dev/null +++ b/ibind/ibkr_ws_v2/__init__.py @@ -0,0 +1 @@ +"""IBKR WebSocket V2 client implementation.""" diff --git a/ibind/ibkr_ws_v2/ibkr_events.py b/ibind/ibkr_ws_v2/ibkr_events.py new file mode 100644 index 00000000..500dbf16 --- /dev/null +++ b/ibind/ibkr_ws_v2/ibkr_events.py @@ -0,0 +1,115 @@ +from typing import ClassVar + +from pydantic import Field + +from ibind.events import WsEvent + + + +class GenericIbkrEvent(WsEvent): + message: dict | None + topic: str | None = None + data: dict | None = None + + +# =================== +# == Unsolicited == +# =================== + +class IbkrError(WsEvent): + message: str + + +class WaitingForSession(WsEvent): + ... + + +class Notification(WsEvent): + message: str + + +class Bulletin(WsEvent): + message: str + + +class AccountUpdate(WsEvent): + data: dict + + +class System(WsEvent): + data: dict + + +class AuthenticationStatus(WsEvent): + data: dict + authenticated: bool | None + competing: bool | None + + +# =================== +# == Topic-based == +# =================== + +class IbkrTopicEvent(WsEvent): + topic: ClassVar[str] + + +class AccountSummary(IbkrTopicEvent): + topic: ClassVar[str] = 'sd' + account_id: str + data: dict + + +class AccountLedger(IbkrTopicEvent): + topic: ClassVar[str] = 'ld' + account_id: str + data: dict + + +class MarketData(IbkrTopicEvent): + topic: ClassVar[str] = 'md' + conid: str + data: dict = Field(default_factory=dict) + + +class MarketHistory(IbkrTopicEvent): + topic: ClassVar[str] = 'mh' + conid: str + data: dict + + +class Orders(IbkrTopicEvent): + topic: ClassVar[str] = 'or' + data: dict + + +class PriceLadder(IbkrTopicEvent): + topic: ClassVar[str] = 'bd' + account_id: str + conid: str + exchange: str + data: dict + + +class Pnl(IbkrTopicEvent): + topic: ClassVar[str] = 'pl' + data: dict + + +class Trades(IbkrTopicEvent): + topic: ClassVar[str] = 'tr' + data: dict + + +# =============== +# == Derived == +# =============== +class ServerId(WsEvent): + target_event_type: type[IbkrTopicEvent] + conid: str + server_id: str + + +class Unsubscription(WsEvent): + target_event_type: type[IbkrTopicEvent] + conid: str | None = None diff --git a/ibind/ibkr_ws_v2/ibkr_router.py b/ibind/ibkr_ws_v2/ibkr_router.py new file mode 100644 index 00000000..e47cdf85 --- /dev/null +++ b/ibind/ibkr_ws_v2/ibkr_router.py @@ -0,0 +1,250 @@ +import json +from collections import defaultdict +from typing import Dict + +from ibind.client import ibkr_definitions +from ibind.client.ibkr_utils import extract_conid + +from ibind import events +from ibind.events import GenericIbkrEvent, IbkrTopicEvent +from ibind.support.logs import project_logger +from ibind.support.py_utils import UNDEFINED, OneOrMany +from ibind.events import WsEvent + +_LOGGER = project_logger('ibkr_ws_client') + + +def get_ibkr_topic_event(topic: str): + topic_to_event_type = { + 'sd': events.AccountSummary, + 'ld': events.AccountLedger, + 'md': events.MarketData, + 'mh': events.MarketHistory, + 'bd': events.PriceLadder, + 'or': events.Orders, + 'pl': events.Pnl, + 'tr': events.Trades, + } + if topic in topic_to_event_type: + return topic_to_event_type[topic] + raise ValueError(f"No Ibkr event associated with topic '{topic}'") + + +def parse_raw_message(raw_message: str): + message = json.loads(raw_message) + topic = message.get('topic', UNDEFINED) + + if topic is UNDEFINED: + return message, None, None + + data = message.get('args', {}) + + return message, topic, data + + +class IbkrRouter: + def __init__(self, log_raw_messages: bool = False, unwrap_market_data: bool = True): + self._log_raw_messages = log_raw_messages + self._unwrap_market_data = unwrap_market_data + self._server_id_conid_pairs: Dict[type[IbkrTopicEvent], Dict[str, str]] = defaultdict(dict) + + def _preprocess_market_data_message(self, data: dict) -> OneOrMany[WsEvent]: + """ + API will only return fields that were updated. If you are not receiving certain fields in the response - means that they remain unchanged. + """ + if 'conid' not in data: # pragma: no cover + # sometimes the ticker message is just an empty update, we ignore it + return [] + + if not self._unwrap_market_data: + return events.MarketData(conid=data['conid'], data=data) + + unwrapped_data = {} + for key, value in data.items(): + if key in ibkr_definitions.snapshot_by_id: + unwrapped_data[ibkr_definitions.snapshot_by_id[key]] = value + return events.MarketData(conid=str(data['conid']), data=unwrapped_data) + + def _preprocess_market_history_message(self, data: dict) -> OneOrMany[WsEvent]: + mh_server_id_conid_pairs = self._server_id_conid_pairs[events.MarketHistory] + rv = [] + conid = extract_conid(data) + if 'serverId' in data and data['serverId'] not in mh_server_id_conid_pairs: + mh_server_id_conid_pairs[data['serverId']] = str(conid) + rv.append(events.ServerId(conid=str(conid), server_id=data['serverId'], target_event_type=events.MarketHistory)) + + rv.append(events.MarketHistory(conid=str(conid), data=data)) + return rv + + def _preprocess_account_ledger(self, data): + rv = [] + for entry in data['result']: + if 'acctCode' not in entry: + continue + event = events.AccountLedger(data=entry, account_id=entry['acctCode']) + rv.append(event) + return rv + + def _preprocess_account_summary(self, data): + summary = {} + timestamp = data['result'][0]['timestamp'] + for entry in data['result']: + key = entry.pop('key') + entry.pop('timestamp') + + if entry == {}: + continue + + summary[key] = entry + + if summary == {}: + return [] + + if 'AccountCode' not in summary or 'value' not in summary['AccountCode']: + _LOGGER.error(f'{self}: Account code not found in account summary: {summary}') + return [] + + account_id = summary['AccountCode']['value'] + summary['timestamp'] = timestamp + + event = events.AccountSummary(data=summary, account_id=account_id) + return event + + def _handle_subscribed_message(self, topic: str, data: dict) -> OneOrMany[WsEvent] | None: + try: + # ibkr_ws_key = IbkrWsKey.from_topic(topic[1:3]) + event_type = get_ibkr_topic_event(topic[1:3]) + except ValueError: + # ValueError means we don't support this topic + return None + + if event_type == events.AccountSummary: + rv = self._preprocess_account_summary(data) + elif event_type == events.AccountLedger: + rv = self._preprocess_account_ledger(data) + elif event_type == events.MarketData: + rv = self._preprocess_market_data_message(data) + elif event_type == events.MarketHistory: + rv = self._preprocess_market_history_message(data) + elif event_type == events.PriceLadder: + rv = events.PriceLadder(data=data) + elif event_type == events.Orders: + rv = events.Orders(data=data) + elif event_type == events.Pnl: + rv = events.Pnl(data=data) + elif event_type == events.Trades: + rv = events.Trades(data=data) + else: + _LOGGER.error(f'{self}: Unhandled subscribed message: {data}') + rv = None + return rv + + def _handle_account_update(self, message, arguments) -> OneOrMany[WsEvent]: + return events.AccountUpdate(data=arguments) + + def _handle_authentication_status(self, message, arguments) -> OneOrMany[WsEvent]: + if 'authenticated' in arguments or 'competing' in arguments: + return events.AuthenticationStatus(data=arguments, authenticated=arguments.get('authenticated'), competing=arguments.get('competing')) + elif ( # expected status updates that we ignore + arguments == {'message': ''} + or arguments.get('fail') == '' + or 'serverName' in arguments + or 'serverVersion' in arguments + or 'username' in arguments + ): + return [] + + _LOGGER.info(f'{self}: Status message: {arguments}') + return GenericIbkrEvent(message=message, topic='sts', data=arguments) + + def _handle_bulletin(self, message) -> OneOrMany[WsEvent]: # pragma: no cover + return events.Bulletin(message=message) + + def _handle_error(self, message) -> OneOrMany[WsEvent]: + _LOGGER.error(f'{self}: on_message error: {message}') + return events.IbkrError(message=message) + + def _handle_notification(self, data) -> OneOrMany[WsEvent]: # pragma: no cover + rv = [] + for notification in data: + rv.append(events.Notification(message=notification)) + return rv + + def _handle_market_history_unsubscribe(self, data) -> OneOrMany[WsEvent]: + server_id = data['message'].split('Unsubscribed ')[-1] + mh_server_id_conid_pairs = self._server_id_conid_pairs[events.MarketHistory] + if server_id in mh_server_id_conid_pairs: + conid = mh_server_id_conid_pairs[server_id] + _LOGGER.info(f'{self}: Received unsubscribing confirmation for server_id={server_id!r}, conid={conid!r}.') + if conid is not None: + return events.Unsubscription(target_event_type=events.MarketHistory, conid=str(conid)) + + _LOGGER.warning(f'{self}: Unknown conid={conid!r}. Cannot mark the subscription as unsubscribed.') + else: + _LOGGER.warning( + f'{self}: Received unsubscribing confirmation for unknown server_id={server_id!r}. Existing server_ids: {mh_server_id_conid_pairs}' + ) + return [] + + def _handle_message_without_topic(self, message: dict) -> OneOrMany[WsEvent]: + if 'message' in message: + if message['message'] == 'waiting for session': + _LOGGER.info(f'{self}: Waiting for an active IBKR session.') + return events.WaitingForSession() + + if 'Unsubscribed' in message['message']: + return self._handle_market_history_unsubscribe(message) + + elif 'result' in message: + if message['result'] == 'unsubscribed from summary': + return events.Unsubscription(target_event_type=events.AccountSummary) + elif message['result'] == 'unsubscribed from ledger': + return events.Unsubscription(target_event_type=events.AccountLedger) + + _LOGGER.error(f'{self}: Unrecognised message without a topic: {message}') + return GenericIbkrEvent(message=message) + + def route(self, raw_message: str) -> OneOrMany[WsEvent]: + if self._log_raw_messages: + _LOGGER.debug(f'{self}: Raw message: {raw_message}') + message, topic, arguments = parse_raw_message(raw_message) + + if 'error' in message: + rv = self._handle_error(message) + + elif topic is None: + # in general most message should carry a topic, other than for few exceptions + rv = self._handle_message_without_topic(message) + + elif topic == 'tic': + # self._tic_message = message + rv = events.System(data=message) + + elif topic == 'system': + rv = events.System(data=message) + + elif topic == 'act': + rv = self._handle_account_update(message, arguments) + + elif topic == 'blt': + rv = self._handle_bulletin(message) + + elif topic == 'ntf': + rv = self._handle_notification(arguments) + + elif topic == 'sts': + rv = self._handle_authentication_status(message, arguments) + + elif topic == 'error': + rv = self._handle_error(message) + + else: + rv = self._handle_subscribed_message(topic, message) + if rv is None: + _LOGGER.error(f'{self}: topic "{topic}" subscribed but lacking a handler. Message: {message}') + rv = GenericIbkrEvent(message=message, topic=topic, data=arguments) + + return rv + + def __str__(self): + return f'{self.__class__.__qualname__}()' diff --git a/ibind/ibkr_ws_v2/ibkr_subscriptions.py b/ibind/ibkr_ws_v2/ibkr_subscriptions.py new file mode 100644 index 00000000..f349bfae --- /dev/null +++ b/ibind/ibkr_ws_v2/ibkr_subscriptions.py @@ -0,0 +1,282 @@ +import json +from typing import Tuple, List + +from pydantic import Field + +from ibind import events +from ibind.events import AccountLedger, MarketData, MarketHistory, Orders, PriceLadder, Pnl, Trades, Unsubscription, AccountSummary, IbkrTopicEvent +from ibind.support.py_utils import filter_none +from ibind.ws_v2.ws_subscriptions import Subscription, SubscriptionResolver + + +def make_binding_key( + event_type: type[IbkrTopicEvent], + conid: str = None, + account_id=None, + exchange=None +): + if event_type in [events.MarketData, events.MarketHistory]: + return f"{event_type.topic}+{conid}" + elif event_type in [events.AccountLedger, events.AccountSummary]: + return f"{event_type.topic}+{account_id}" + elif event_type in [events.PriceLadder]: + return f"{event_type.topic}+{account_id}+{conid}" + (f"+{exchange}" if exchange is not None else '') + elif event_type in [events.Orders, events.Pnl, events.Trades]: + return event_type.topic + else: + raise ValueError(f'Unsupported event type: {event_type}') + + +class IbkrSubscriptionResolver(SubscriptionResolver): + def __init__(self, account_id): + self._account_id = account_id + + def _resolve_subscribing_event(self, event) -> str: + event_type = type(event) + if event_type in [events.MarketData, events.MarketHistory]: + return make_binding_key(event_type, conid=event.conid) + elif event_type in [events.AccountLedger, events.AccountSummary]: + return make_binding_key(event_type, account_id=event.account_id) + elif event_type in [events.PriceLadder]: + return make_binding_key(event_type, conid=event.conid, account_id=event.account_id, exchange=event.exchange) + elif event_type in [events.Orders, events.Pnl, events.Trades]: + return make_binding_key(event_type) + else: + raise ValueError(f'Unsupported event: {event}') + + def _resolve_unsubscribing_event(self, event) -> str: + return make_binding_key(event.target_event_type, event.conid, self._account_id) + + def resolve_binding_key(self, event) -> Tuple[bool, str] | Tuple[None, None]: + if not (isinstance(event, IbkrTopicEvent) or isinstance(event, Unsubscription)): + return None, None + + if isinstance(event, Unsubscription): + return False, self._resolve_unsubscribing_event(event) + else: + return True, self._resolve_subscribing_event(event) + + +class IbkrSubscription(Subscription): + event_type: type[IbkrTopicEvent] + + @property + def topic(self) -> str: + return self.event_type.topic + + +class AccountSummarySubscription(IbkrSubscription): + event_type: type[IbkrTopicEvent] = AccountSummary + account_id: str + + def subscribe_payload(self) -> str: + return f'ssd+{self.account_id}' + + def unsubscribe_payload(self) -> str: + return f'usd+{self.account_id}' + + @property + def confirms_subscribe(self) -> bool: + return True + + @property + def confirms_unsubscribe(self) -> bool: + return True + + def binding_key(self): + return make_binding_key(self.event_type, account_id=self.account_id) + + +class AccountLedgerSubscription(IbkrSubscription): + event_type: type[IbkrTopicEvent] = AccountLedger + account_id: str + + def subscribe_payload(self) -> str: + return f'sld+{self.account_id}' + + def unsubscribe_payload(self) -> str: + return f'uld+{self.account_id}' + + @property + def confirms_subscribe(self) -> bool: + return True + + @property + def confirms_unsubscribe(self) -> bool: + return True + + def binding_key(self): + return make_binding_key(self.event_type, account_id=self.account_id) + + +class MarketDataSubscription(IbkrSubscription): + event_type: type[IbkrTopicEvent] = MarketData + conid: str + fields: List[str] + + def subscribe_payload(self) -> str: + fields_str = json.dumps({"fields": list(self.fields)}, separators=(',', ':')) + return f'smd+{self.conid}+{fields_str}' + + def unsubscribe_payload(self) -> str: + return f'umd+{self.conid}+{{}}' + + @property + def confirms_subscribe(self) -> bool: + return True + + @property + def confirms_unsubscribe(self) -> bool: + return False + + def binding_key(self): + return make_binding_key(self.event_type, conid=self.conid) + + +class MarketHistorySubscription(IbkrSubscription): + event_type: type[IbkrTopicEvent] = MarketHistory + conid: str + exchange: str = None + period: str = None + bar: str = None + outside_rth: bool = None + source: str = None + format: str = None + server_id: list = Field(default_factory=list) # uses list to allow writing despite frozen model + + def subscribe_payload(self) -> str: + data = { + 'exchange': self.exchange, + 'period': self.period, + 'bar': self.bar, + 'outside_rth': self.outside_rth, + 'source': self.source, + 'format': self.format, + } + data = filter_none(data) + return f'smh+{self.conid}+{json.dumps(data, separators=(",", ":"))}' + + def unsubscribe_payload(self) -> str: + server_id = self.get_server_id() + if server_id is None: + raise RuntimeError(f'{self}: Unsubscribing from market history for conid={self.conid!r} without server_id. MarketHistorySubscription must have server_id set before unsubscribing.') + return f'umh+{server_id}' + + @property + def confirms_subscribe(self) -> bool: + return True + + @property + def confirms_unsubscribe(self) -> bool: + return True + + def set_server_id(self, server_id): + if self.has_server_id(): + raise ValueError('Server ID already set') + self.server_id.append(server_id) + + def has_server_id(self) -> bool: + return len(self.server_id) > 0 + + def get_server_id(self): + return self.server_id[0] + + def binding_key(self): + return make_binding_key(self.event_type, conid=self.conid) + + +class OrdersSubscription(IbkrSubscription): + event_type: type[IbkrTopicEvent] = Orders + filter: str = None + + def subscribe_payload(self) -> str: + filter_str = f'{{"filters": ["{self.filter}"]}}' if self.filter is not None else '{}' + return f'sor+{filter_str}' + + def unsubscribe_payload(self) -> str: + return 'uor+{}' + + @property + def confirms_subscribe(self) -> bool: + return False + + @property + def confirms_unsubscribe(self) -> bool: + return False + + def binding_key(self): + return make_binding_key(self.event_type) + + +class PriceLadderSubscription(IbkrSubscription): + event_type: type[IbkrTopicEvent] = PriceLadder + conid: str + account_id: str + exchange: str + + def subscribe_payload(self) -> str: + return f'sbd+{self.account_id}+{self.conid}+{self.exchange}' + + def unsubscribe_payload(self) -> str: + return f'ubd+{self.account_id}' + + @property + def confirms_subscribe(self) -> bool: + return False + + @property + def confirms_unsubscribe(self) -> bool: + return False + + def binding_key(self): + return make_binding_key(self.event_type, conid=self.conid, account_id=self.account_id, exchange=self.exchange) + + +class PnlSubscription(IbkrSubscription): + event_type: type[IbkrTopicEvent] = Pnl + + def subscribe_payload(self) -> str: + return 'spl' + + def unsubscribe_payload(self) -> str: + return 'upl' + + @property + def confirms_subscribe(self) -> bool: + return True + + @property + def confirms_unsubscribe(self) -> bool: + return False + + def binding_key(self): + return make_binding_key(self.event_type) + + +class TradesSubscription(IbkrSubscription): + event_type: type[IbkrTopicEvent] = Trades + realtime_updates_only: bool | None = None + days: int | None = None + + def subscribe_payload(self) -> str: + extra = {} + if self.realtime_updates_only is not None: + extra['realtime_updates_only'] = self.realtime_updates_only + if self.days is not None: + extra['days'] = self.days + extra_str = json.dumps(extra, separators=(',', ':')) + return f'str+{extra_str}' + + def unsubscribe_payload(self) -> str: + return 'utr' + + @property + def confirms_subscribe(self) -> bool: + return True + + @property + def confirms_unsubscribe(self) -> bool: + return False + + def binding_key(self): + return make_binding_key(self.event_type) diff --git a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py new file mode 100644 index 00000000..10d4ff33 --- /dev/null +++ b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py @@ -0,0 +1,210 @@ +import json +from collections import defaultdict +from typing import Union, List, Dict, Type + +from ibind import events +from ibind import var +from ibind.client.ibkr_client import IbkrClient +from ibind.events import IbkrTopicEvent +from ibind.ibkr_ws_v2.ibkr_router import IbkrRouter +from ibind.ibkr_ws_v2.ibkr_subscriptions import IbkrSubscriptionResolver, MarketHistorySubscription +from ibind.support.logs import project_logger +from ibind.support.py_utils import OneOrMany, ensure_list_arg, append_query_params +from ibind.ws_v2._ws_events import EventSink, CallbackSink, Router, AsyncSink, NoopSink +from ibind.ws_v2.ws_subscriptions import Subscription, SubscriptionResolver, SubscriptionHandle, BindingStatus +from ibind.ws_v2.ws_runtime import WsRuntime, WsState + +_LOGGER = project_logger('ibkr_ws_client') + +_DEFAULT_CYCLE_INTERVAL = 0.25 + + +class IbkrWsClientV2(): + def __init__( + self, + account_id: str = var.IBIND_ACCOUNT_ID, + url: str = var.IBIND_WS_URL, + host: str = '127.0.0.1', + port: str = '5000', + base_route: str = '/v1/api/ws', + ibkr_client: IbkrClient = None, + use_oauth: bool = var.IBIND_USE_OAUTH, + access_token: str = var.IBIND_OAUTH1A_ACCESS_TOKEN, + cacert: Union[str, bool] = var.IBIND_CACERT, + cycle_interval: float = _DEFAULT_CYCLE_INTERVAL, + recreate_subscriptions_on_reconnect: bool = True, + sink: EventSink = None, + router: Router = None, + subscription_resolver: SubscriptionResolver = None, + synchronous_output_events: bool = False, + ): + self._account_id = account_id + + url = var.IBIND_OAUTH1A_WS_URL if url is None and use_oauth else url + + if url is None: + url = f'wss://{host}:{port}{base_route}' + + if use_oauth: + if access_token is None: + raise ValueError( + 'OAuth access token not found. Please set IBIND_OAUTH1A_ACCESS_TOKEN environment variable or provide it as `access_token` argument.' + ) + url = append_query_params(url, {'oauth_token': access_token}) + cacert = True + + if ibkr_client is None: + ibkr_client = IbkrClient(account_id=account_id, host=host, port=port, cacert=cacert, use_oauth=use_oauth) + + self._ibkr_client = ibkr_client + self._use_oauth = use_oauth + self._recreate_subscriptions_on_reconnect = recreate_subscriptions_on_reconnect + + # self._queue_controller = QueueController[IbkrWsKey]() + # self._queue_controller.register_queues(list(IbkrWsKey)) + + if sink is None: + # self._queue_controller.register_queues(['LIFECYCLE', 'IBKR']) + # sink = QueueSink(queue_controller=self._queue_controller) + + # sink = LogSink() + sink = NoopSink() + + self._internal_sink = CallbackSink() + self._register_internal_callbacks() + + if synchronous_output_events: + _LOGGER.info(f'{self}: Output events will be emitted synchronously from the runtime thread') + else: + sink = AsyncSink(sink=sink) + + if router is None: + router = IbkrRouter() + + if subscription_resolver is None: + subscription_resolver = IbkrSubscriptionResolver(account_id) + + self._runtime = WsRuntime( + url=url, + cycle_interval=cycle_interval, + ready_state=WsState.AUTHENTICATED, + cacert=cacert, + sink=sink, + internal_sink=self._internal_sink, + router=router, + subscription_resolver=subscription_resolver, + connection_timeout=5, + get_cookie=self._get_cookie, + get_header=self._get_header, + ) + + self._mh_subscriptions: List[MarketHistorySubscription] = [] + self._conid_server_id_pairs: Dict[type[events.IbkrTopicEvent], Dict[str, str]] = defaultdict(dict) + + def _register_internal_callbacks(self): + self._internal_sink.on(events.AuthenticationStatus, self._on_authentication_status) + self._internal_sink.on(events.WaitingForSession, self._set_unauthenticated) + self._internal_sink.on(events.System, self._on_system) + self._internal_sink.on(events.ServerId, self._on_server_id) + + def _set_unauthenticated(self, _): + self._runtime.set_authenticated(False) + + def _on_authentication_status(self, event: events.AuthenticationStatus): + if event.authenticated is False: + _LOGGER.error(f'{self}: Status unauthenticated: {event}') + elif event.competing is True: + _LOGGER.error(f'{self}: Authentication competing: {event}') + + if event.authenticated is not None: + self._runtime.set_authenticated(event.authenticated) + + def _on_system(self, event: events.System): + if 'hb' in event.data: + self._runtime.set_last_heartbeat(int(event.data['hb']) / 1000) + + def _on_server_id(self, event: events.ServerId): + self._conid_server_id_pairs[event.target_event_type][event.conid] = event.server_id + for subscription in self._mh_subscriptions: + if subscription.event_type == event.target_event_type and subscription.conid == event.conid and not subscription.has_server_id(): + subscription.set_server_id(event.server_id) + + def _get_cookie(self): + # try: + status = self._ibkr_client.tickle() + # except TimeoutError as e: + # if 'Reached max retries' in str(e): + # _LOGGER.warning(f'{self}: Acquiring session cookie timed out, connection to the Gateway may be broken.') + # return None + # raise + # except ExternalBrokerError: + # _LOGGER.warning(f'{self}: Acquiring session cookie failed, connection to the Gateway may be broken.') + # return None + session_id = status.data['session'] + if self._use_oauth: + return f'api={session_id}' + payload = {'session': session_id} + return f'api={json.dumps(payload)}' + + def _get_header(self): + return {'User-Agent': 'ClientPortalGW/1'} if self._use_oauth else None + + def start(self): + self._runtime.start() + + def shutdown(self): + self._runtime.stop() + + def hard_reset(self): + self._runtime.hard_reset() + + def reset_websocket_app(self): + self._runtime.reset_websocket_app() + + def subscribe(self, subscription: Subscription) -> SubscriptionHandle: + if isinstance(subscription, MarketHistorySubscription): + self._mh_subscriptions.append(subscription) + return self._runtime.subscription_controller.subscribe(subscription) + + def unsubscribe(self, subscription: Subscription) -> SubscriptionHandle: + if isinstance(subscription, MarketHistorySubscription): + self._handle_mh_unsubscription(subscription) + return self._runtime.subscription_controller.unsubscribe(subscription) + + def get_status(self, binding_key: str) -> BindingStatus: + return self._runtime.subscription_controller.get_status(binding_key) + + def get_server_id(self, event_type: Type[IbkrTopicEvent], conid: str) -> str: + return self._conid_server_id_pairs[event_type][conid] + + def _handle_mh_unsubscription(self, subscription: MarketHistorySubscription): + if subscription.has_server_id(): + return + server_id = self._conid_server_id_pairs.get(subscription.event_type, {}).get(subscription.conid) + if server_id is None: + raise RuntimeError(f'{self}: Unsubscribing from market history for conid={subscription.conid!r} without server_id. Could not find server_id in memory. Ensure at least one MarketHistory event is received before unsubscribing.') + + _LOGGER.warning( + f'{self}: Unsubscribing from market history for conid={subscription.conid!r} without server_id. Setting from memory: {server_id!r}. ' + f'Unsubscribe using the same Subscription instance that was used for subscribing to avoid this warning, ' + f'or set it manually before calling unsubscribe by using ' + f'`subscription.set_server_id(ibkr_ws_client.get_server_id(IbkrWsKey.MARKET_HISTORY, conid))`' + ) + subscription.set_server_id(server_id) + + @ensure_list_arg('subscription_handles') + def wait_all(self, subscription_handles: OneOrMany[SubscriptionHandle], timeout: float | None = None) -> List[SubscriptionHandle]: + failed = [] + for subscription_handle in subscription_handles: + if not (subscription_handle.wait(timeout)): + failed.append(subscription_handle) + return failed + + def is_running(self) -> bool: + return self._runtime.is_running() + + def get_state(self) -> WsState: + return self._runtime.get_state() + + def __str__(self): + return f'{self.__class__.__qualname__}()' diff --git a/ibind/oauth/__init__.py b/ibind/oauth/__init__.py index e0b7d24f..1e151559 100644 --- a/ibind/oauth/__init__.py +++ b/ibind/oauth/__init__.py @@ -55,4 +55,4 @@ def copy(self, **kwargs): if not hasattr(copied, kwarg): raise AttributeError(f'OAuthConfig does not have attribute "{kwarg}"') setattr(copied, kwarg, value) - return copied \ No newline at end of file + return copied diff --git a/ibind/oauth/oauth1a.py b/ibind/oauth/oauth1a.py index 6b849588..82e3f73c 100644 --- a/ibind/oauth/oauth1a.py +++ b/ibind/oauth/oauth1a.py @@ -237,7 +237,6 @@ def generate_oauth_headers( 'Accept-Encoding': 'gzip,deflate', 'Authorization': headers_string, 'Connection': 'keep-alive', - 'Host': 'api.ibkr.com', 'User-Agent': 'ibind', } diff --git a/ibind/subscriptions/__init__.py b/ibind/subscriptions/__init__.py new file mode 100644 index 00000000..164b8f66 --- /dev/null +++ b/ibind/subscriptions/__init__.py @@ -0,0 +1,17 @@ +from ibind.ibkr_ws_v2.ibkr_subscriptions import MarketDataSubscription, OrdersSubscription, AccountLedgerSubscription, AccountSummarySubscription, PnlSubscription, TradesSubscription, MarketHistorySubscription + +from ibind.ws_v2.ws_subscriptions import SubscriptionHandle, BindingStatus, Subscription, SubscriptionResolver + +__all__ = [ + 'Subscription', + 'SubscriptionResolver', + 'SubscriptionHandle', + 'BindingStatus', + 'MarketDataSubscription', + 'OrdersSubscription', + 'AccountLedgerSubscription', + 'AccountSummarySubscription', + 'PnlSubscription', + 'TradesSubscription', + 'MarketHistorySubscription', +] diff --git a/ibind/support/logs.py b/ibind/support/logs.py index 71ed2f5a..9c9349f6 100644 --- a/ibind/support/logs.py +++ b/ibind/support/logs.py @@ -1,5 +1,6 @@ import datetime import logging +import os.path import sys from pathlib import Path from typing import List @@ -44,7 +45,12 @@ def project_logger(filepath=None): Returns: logging.Logger: The project-specific logger instance. """ - return logging.getLogger('ibind' + (f'.{Path(filepath).stem}' if filepath is not None else '')) + logger_name = 'ibind' + if filepath is not None: + child = Path(filepath).stem if os.path.exists(filepath) else str(filepath) + logger_name += f'.{child}' + + return logging.getLogger(logger_name) _LOGGER = project_logger() diff --git a/ibind/support/py_utils.py b/ibind/support/py_utils.py index e5f42958..3e9a0898 100644 --- a/ibind/support/py_utils.py +++ b/ibind/support/py_utils.py @@ -1,6 +1,7 @@ import copy import inspect import os +import ssl import sys import threading import time @@ -10,7 +11,9 @@ from enum import Enum, EnumMeta from functools import wraps from collections.abc import Mapping +from pathlib import Path from typing import List, TypeVar, Union, Dict +from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit from ibind.support.logs import project_logger @@ -20,6 +23,9 @@ S = TypeVar('S') OneOrMany = Union[S, List[S]] +def noop(): + return None + _LOGGER = project_logger(__file__) @@ -196,6 +202,27 @@ def filter_none(d): # pragma: no cover return d +def append_query_params(url: str, params: dict) -> str: + parts = urlsplit(url) + query_items = parse_qsl(parts.query, keep_blank_values=True) + query_items.extend((key, value) for key, value in params.items() if value is not None) + query = urlencode(query_items) + return urlunsplit((parts.scheme, parts.netloc, parts.path, query, parts.fragment)) + + +def make_websocket_sslopt(cacert: Union[str, os.PathLike, bool, None]) -> dict: + if cacert in (False, None): + return {'cert_reqs': ssl.CERT_NONE} + + if cacert is True: + return {'cert_reqs': ssl.CERT_REQUIRED} + + if not Path(cacert).exists(): + raise ValueError(f'cacert must be a valid Path, True, False, or None, found: {cacert}') + + return {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': str(cacert)} + + class TimeoutLock: # pragma: no cover """ A lock with a timeout mechanism, extending the standard threading.RLock. @@ -210,17 +237,19 @@ class TimeoutLock: # pragma: no cover def __init__(self, timeout: int): self._lock = threading.RLock() self._timeout = timeout - self._acquired = False def acquire(self, *args, **kwargs): - self._acquired = self._lock.acquire(*args, timeout=self._timeout, **kwargs) + acquired = self._lock.acquire(*args, timeout=self._timeout, **kwargs) + if not acquired: + raise TimeoutError(f'Could not acquire lock within {self._timeout} seconds') + return True def release(self): - if self._acquired: - self._lock.release() + self._lock.release() def __enter__(self): self.acquire() + return self def __exit__(self, type, value, traceback): self.release() @@ -279,8 +308,8 @@ def wait_until(condition: callable, timeout_message: str = None, timeout: float bool: True if the condition becomes True within the timeout period, False otherwise. """ - deadline = time.time() + timeout - while time.time() < deadline: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: if condition(): return True time.sleep(sleep) @@ -351,4 +380,4 @@ def warn_if_late_load(*args, **kwargs): dotenv.load_dotenv = warn_if_late_load except ImportError: - pass # dotenv is not installed, nothing to patch \ No newline at end of file + pass # dotenv is not installed, nothing to patch diff --git a/ibind/var.py b/ibind/var.py index 246710d8..8d062744 100644 --- a/ibind/var.py +++ b/ibind/var.py @@ -102,6 +102,9 @@ def to_bool(value): IBIND_WS_LOG_RAW_MESSAGES = to_bool(os.environ.get('IBIND_WS_LOG_RAW_MESSAGES', False)) """ Whether raw WebSocket messages should be logged. """ +IBIND_WS_SKIP_UTF8_VALIDATION = to_bool(os.environ.get('IBIND_WS_SKIP_UTF8_VALIDATION', True)) +""" Whether to skip UTF-8 validation for WebSocket messages. """ + ##### OAuth common ##### IBIND_USE_OAUTH = to_bool(os.environ.get('IBIND_USE_OAUTH', False)) diff --git a/ibind/ws_v2/__init__.py b/ibind/ws_v2/__init__.py new file mode 100644 index 00000000..0b087100 --- /dev/null +++ b/ibind/ws_v2/__init__.py @@ -0,0 +1 @@ +"""WebSocket V2 runtime primitives.""" diff --git a/ibind/ws_v2/_ws_events.py b/ibind/ws_v2/_ws_events.py new file mode 100644 index 00000000..5957ec93 --- /dev/null +++ b/ibind/ws_v2/_ws_events.py @@ -0,0 +1,439 @@ +import threading +from datetime import datetime +from queue import Queue, Full, Empty +from threading import Thread, Event +from typing import Protocol, Callable, TypeVar, List, Dict, Any + +from pydantic import BaseModel, ConfigDict, Field + +from ibind.base.queue_controller import QueueAccessor +from ibind.support.logs import project_logger +from ibind.support.py_utils import OneOrMany, exception_to_string, tname + +__all__ = [] + +_LOGGER = project_logger('ibkr_ws_client') + + +# ====================== +# == Events Classes == +# ====================== + + +class WsEvent(BaseModel): # pragma: no cover + """ + Base class for all WebSocket events. + + Immutable event model that tracks when it was received. + """ + + model_config = ConfigDict(frozen=True, extra='forbid') + + received_at: datetime = Field(default_factory=datetime.now) + + def __str__(self): + return self._format() + + def __repr__(self): + return self._format() + + def _format(self): + data = self.model_dump() + + # normalize values + for k, v in data.items(): + if isinstance(v, datetime): + data[k] = v.isoformat() + elif isinstance(v, Exception): + data[k] = str(v) + + # move received_at to the end + items = [(k, v) for k, v in data.items() if k != 'received_at'] + if 'received_at' in data: + items.append(('received_at', data['received_at'])) + + fields = ', '.join(f'{k}={v}' if isinstance(v, str) and 'T' in v else f'{k}={repr(v)}' for k, v in items) + + return f'{self.__class__.__name__}({fields})' + + +class LifecycleEvent(WsEvent): + """Base class for WebSocket connection lifecycle events.""" + + pass + + +class WsOpen(LifecycleEvent): + """Emitted when the WebSocket connection is successfully opened.""" + + pass + + +class WsAuthenticated(LifecycleEvent): + """Emitted when the WebSocket connection is authenticated.""" + + pass + + +class WsDegraded(LifecycleEvent): + """Emitted when the WebSocket connection enters a degraded state.""" + + pass + + +class WsReady(LifecycleEvent): + """Emitted when the WebSocket connection is ready for use.""" + + pass + + +class WsClose(LifecycleEvent): + """Emitted when the WebSocket connection is closed.""" + + close_status_code: int | None + close_msg: str | None + + +class WsError(LifecycleEvent): + """Emitted when a WebSocket error occurs.""" + + model_config = ConfigDict(frozen=True, extra='forbid', arbitrary_types_allowed=True) + error: Exception + + +# ============= +# == Sinks == +# ============= + + +class EventSink(Protocol): # pragma: no cover + """Protocol for objects that can receive and process WebSocket events.""" + + def emit(self, event: 'WsEvent') -> None: + pass + + +class LogSink: # pragma: no cover + """Sink that logs events using the project logger.""" + + def emit(self, event: WsEvent) -> None: + _LOGGER.info(event) + + +class NoopSink: # pragma: no cover + """Sink that discards all events without processing.""" + + def emit(self, event: WsEvent) -> None: + pass + + +T = TypeVar('T', bound=WsEvent) + + +class CallbackSink: + """ + Sink that invokes registered callbacks for specific event types. + + Callbacks are registered per event type and invoked when matching events are emitted. + Exceptions from callbacks are logged but do not propagate. + """ + + def __init__(self): + self._callbacks: Dict[type[WsEvent], List[Callable[[WsEvent], None]]] = {} + + def on(self, event_type: type[WsEvent], callback: Callable[[T], None]) -> None: + """ + Register a callback for a specific event type. + + Args: + event_type (type[WsEvent]): The event type to listen for. + callback (Callable): Function to invoke when events of this type are emitted. + """ + self._callbacks.setdefault(event_type, []).append(callback) + + def emit(self, event: WsEvent) -> None: + """ + Emit an event to all registered callbacks for its type. + + Args: + event (WsEvent): The event to emit. + """ + for callback in self._callbacks.get(type(event), []): + try: + callback(event) + except Exception as e: + _LOGGER.error(f'{self}: Exception emitting event to callback {callback.__name__}: {exception_to_string(e)}') + + def __str__(self): # pragma: no cover + return f'{self.__class__.__qualname__}()' + + +class QueueSink: + """ + Sink that stores events in separate queues per event type. + + Maintains a dictionary of queues, one for each event type. Events can be + retrieved synchronously or asynchronously via queue accessors. + """ + + def __init__(self): + self._queues = {} + + def new_queue_accessor(self, event_type: type[WsEvent]) -> QueueAccessor: + """ + Create a queue accessor for a specific event type. + + Args: + event_type (type[WsEvent]): The event type to access. + + Returns: + QueueAccessor: Accessor for the queue associated with this event type. + """ + return QueueAccessor(self._get_queue(event_type), event_type) + + def _get_queue(self, event_type: type[WsEvent]) -> Queue: # pragma: no cover + try: + return self._queues[event_type] + except KeyError: + self._queues[event_type] = Queue() + return self._queues[event_type] + + def get(self, event_type: type[WsEvent], block: bool = False, timeout: float = None) -> Any: + """ + Retrieve an event from the queue for a specific event type. + + Args: + event_type (type[WsEvent]): The event type to retrieve. + block (bool, optional): Whether to block until an event is available. Default: False. + timeout (float, optional): Maximum time to block in seconds. Default: None. + + Returns: + WsEvent | None: The retrieved event, or None if the queue is empty and block=False. + """ + try: + return self._get_queue(event_type).get(block=block, timeout=timeout) + except Empty: + return None + + def empty(self, event_type: type[WsEvent]) -> bool: + """ + Check if the queue for a specific event type is empty. + + Args: + event_type (type[WsEvent]): The event type to check. + + Returns: + bool: True if the queue is empty, False otherwise. + """ + return self._get_queue(event_type).empty() + + def emit(self, event: WsEvent) -> None: + """ + Emit an event by adding it to the queue for its type. + + Args: + event (WsEvent): The event to emit. + """ + queue = self._get_queue(type(event)) + queue.put(event) + + def __str__(self): # pragma: no cover + return f'{self.__class__.__qualname__}()' + + +class CompositeSink: + """ + Sink that forwards events to multiple child sinks. + + Exceptions from individual sinks are logged but do not prevent other sinks + from receiving the event. + """ + + def __init__(self, *sinks: EventSink): + """ + Create a composite sink. + + Args: + *sinks (EventSink): One or more sinks to forward events to. + """ + self._sinks = sinks + + def emit(self, event: WsEvent) -> None: + """ + Emit an event to all registered sinks. + + Args: + event (WsEvent): The event to emit. + """ + for sink in self._sinks: + try: + sink.emit(event) + except Exception as e: + _LOGGER.error(f'{self}: Exception emitting event to sink: {exception_to_string(e)}') + + def __str__(self): # pragma: no cover + return f'{self.__class__.__qualname__}()' + + +class AsyncSink: + """ + Sink that forwards events to another sink asynchronously via a background thread. + + Events are queued and processed in a separate thread. When the queue is full, + events are dropped according to the drop_oldest policy. + """ + + def __init__( + self, + sink: EventSink, + maxsize: int = 10_000, + drop_oldest: bool = True, + stop_timeout: float = 5, + cycle_interval: float = 0.25, + ): + """ + Create an asynchronous sink. + + Args: + sink (EventSink): The sink to forward events to. + maxsize (int, optional): Maximum queue size. Default: 10,000. + drop_oldest (bool, optional): Whether to drop oldest events when full. + If False, drops newest events. Default: True. + stop_timeout (float, optional): Maximum time to wait for thread to stop in seconds. Default: 5. + cycle_interval (float, optional): Interval between queue processing cycles in seconds. Default: 0.25. + """ + self._sink = sink + self._queue = Queue(maxsize=maxsize) + self._drop_oldest = drop_oldest + self._stop_timeout = stop_timeout + self._cycle_interval = cycle_interval + + self._running = False + self._thread: Thread | None = None + self._wait_event = Event() + + def start(self): + """Start the background thread for processing events.""" + if self._running: + return + if self._thread is not None and self._thread.is_alive(): + _LOGGER.error(f'{self}: Async sink thread is still alive and cannot be restarted yet') + return False + + self._running = True + self._thread = Thread(target=self._cycle, name='async_sink_thread', daemon=True) + self._thread.start() + return True + + def stop(self) -> bool: + """ + Stop the background thread and discard remaining events. + + Returns: + bool: True if the thread stopped successfully, False if it timed out. + + Raises: + RuntimeError: If called from within the async sink thread. + """ + if not self._running: + return True + + if threading.current_thread() == self._thread: + raise RuntimeError(f'{self}: Stopping async sink called from within async sink thread. Ensure it is stopped from a separate thread') + + self._running = False + self._wait_event.set() + + succeeded = True + if self._thread is not None: + self._thread.join(self._stop_timeout) + succeeded = not self._thread.is_alive() + + if succeeded: + self._thread = None + + if self._queue.qsize() > 0: + _LOGGER.warning(f'{self}: Event queue not empty when stopping; discarding {self._queue.qsize()} events') + + return succeeded + + def emit(self, event: WsEvent) -> None: + """ + Queue an event for asynchronous processing. + + Args: + event (WsEvent): The event to emit. + """ + try: + self._queue.put_nowait(event) + self._wait_event.set() + return + except Full: + if not self._drop_oldest: + _LOGGER.warning(f'{self}: Event queue full; dropping newest event: {event}') + return + + try: + dropped = self._queue.get_nowait() + _LOGGER.warning(f'{self}: Event queue full; dropping oldest event: {dropped}') + except Empty: + pass + + try: + self._queue.put_nowait(event) + self._wait_event.set() + except Full: + _LOGGER.warning(f'{self}: Event queue still full; dropping event: {event}') + + def _consume_queue(self): + while True: + try: + event = self._queue.get_nowait() + except Empty: + break + + try: + self._sink.emit(event) + except Exception as e: + _LOGGER.error(f'{self}: Exception emitting event to sink: {exception_to_string(e)}') + + def _cycle(self): + _LOGGER.debug(f'{self}: AsyncSink thread started ({tname()})') + while self._running: + self._wait_event.clear() + self._wait_event.wait(self._cycle_interval) + self._consume_queue() + + self._consume_queue() + _LOGGER.debug(f'{self}: AsyncSink thread stopped ({tname()})') + + def __str__(self): # pragma: no cover + return f'{self.__class__.__qualname__}({self._queue.qsize()})' + + +# ============== +# == Router == +# ============== + + +class Router(Protocol): # pragma: no cover + """ + Protocol for routing raw WebSocket messages to typed events. + + Implementations parse raw messages and convert them to one or more WsEvent instances. + """ + + def route(self, raw_message) -> OneOrMany[WsEvent]: + """ + Route a raw message to one or more events. + + Args: + raw_message: The raw message to route. + + Returns: + OneOrMany[WsEvent]: One or more events, or None to skip the message. + """ + pass + + def __str__(self): + return f'{self.__class__.__qualname__}()' diff --git a/ibind/ws_v2/ws_runtime.py b/ibind/ws_v2/ws_runtime.py new file mode 100644 index 00000000..59954e66 --- /dev/null +++ b/ibind/ws_v2/ws_runtime.py @@ -0,0 +1,479 @@ +import json +import threading +import time +from queue import Queue +from threading import Thread, Event +from typing import Union, List, Dict, Callable, Literal + +from ibind import events +from ibind.events import WsEvent +from ibind.support.logs import project_logger +from ibind.support.py_utils import wait_until, tname, VerboseEnum, exception_to_string, TimeoutLock, OneOrMany, noop, make_websocket_sslopt +from ibind.ws_v2._ws_events import EventSink, Router, CallbackSink, AsyncSink +from ibind.ws_v2.ws_subscriptions import SubscriptionController, SubscriptionResolver +from ibind.ws_v2.ws_transport import WsTransport, TransportEvent, TransportOpened, TransportClosed, TransportError, TransportMessage, TransportReconnect + +_LOGGER = project_logger('ibkr_ws_client') + +_DEFAULT_TIMEOUT = 5 +_MAX_TRANSPORT_EVENT_RETRIES = 5 +_HEALTH_CHECK_INTERVAL = 10 + + +class WsState(VerboseEnum): + STOPPED = 'STOPPED' + STARTING = 'STARTING' + CONNECTING = 'CONNECTING' + OPEN = 'OPEN' + AUTHENTICATED = 'AUTHENTICATED' + CLOSED = 'CLOSED' + DEGRADED = 'DEGRADED' + RECONNECTING = 'RECONNECTING' + STOPPING = 'STOPPING' + + +def make_sslopt(cacert: Union[str, bool]): + return make_websocket_sslopt(cacert) + + +class WsRuntime(): + def __init__( + self, + url: str, + cycle_interval: float, + sink: EventSink, + internal_sink: CallbackSink, + router: Router, + subscription_resolver: SubscriptionResolver, + ready_state: Literal[WsState.OPEN, WsState.AUTHENTICATED] = WsState.OPEN, + cacert: Union[str, bool] = False, + connection_timeout: float = _DEFAULT_TIMEOUT, + reconnect_timeout: float | None = _DEFAULT_TIMEOUT, + max_ping_interval: float = 20, + get_cookie: Callable = noop, + get_header: Callable = noop, + ): + if ready_state not in [WsState.OPEN, WsState.AUTHENTICATED]: + raise ValueError(f'Invalid ready_state: {ready_state}, must be either {WsState.OPEN} or {WsState.AUTHENTICATED}') + self._url = url + self._cycle_interval = cycle_interval + self._sink = sink + self._internal_sink = internal_sink + self._router = router + self._ready_state = ready_state + self._connection_timeout = connection_timeout + self._reconnect_timeout = reconnect_timeout + self._max_ping_interval = max_ping_interval + + self._state = WsState.STOPPED + self._authenticated = False + self._running = False + self._last_heartbeat = None + self._last_tic = time.time() + self._last_health_check = time.time() + + self._transport_thread: Thread | None = None + self._runtime_thread: Thread | None = None + self._transport_queue = Queue() + self._wait_event = Event() + + self._state_lock = TimeoutLock(60) + + self._sslopt = make_sslopt(cacert) + + self._get_cookie = get_cookie + self._get_header = get_header + + self._transport: WsTransport = self._new_transport() + + self.subscription_controller = SubscriptionController(send_payload=self.send, subscription_resolver=subscription_resolver) + + def _new_transport(self): + return WsTransport( + url=self._url, + event_callback=self._transport_callback, + sslopt=self._sslopt, + get_cookie=self._get_cookie, + get_header=self._get_header, + max_ping_interval=self._max_ping_interval, + connection_timeout=self._connection_timeout, + reconnect_timeout=self._reconnect_timeout, + ) + + def _set_state(self, value): + _LOGGER.debug(f'{self}: {self._state.value} -> {value.value}') + with self._state_lock: + self._state = value + + if self._state == self._ready_state: + self._websocket_ready() + + def get_state(self) -> WsState: + return self._state + + def _websocket_ready(self): + self._emit(events.WsReady()) + self._last_heartbeat = time.time() + _LOGGER.info(f'{self}: Websocket ready, setting last_heartbeat to {self._last_heartbeat}') + + def set_authenticated(self, value: bool): + previous_value = self._authenticated + self._authenticated = value + + if value and self._state == WsState.OPEN: + self._emit(events.WsAuthenticated()) + self._set_state(WsState.AUTHENTICATED) + + if value == False and self._state == self._ready_state: + self.state_degraded() + if value != previous_value: + _LOGGER.info(f'{self}: Connection {"authenticated" if value else "unauthenticated"}') + + def state_degraded(self): + was_already_degraded = self._state == WsState.DEGRADED + self._set_state(WsState.DEGRADED) + self.subscription_controller.invalidate_subscriptions() + + if not was_already_degraded: + self._emit(events.WsDegraded()) + + def get_authenticated(self) -> bool: + return self._authenticated + + def _new_transport_thread(self): + self._transport_thread = Thread(target=self._transport.connect, name='ws_transport_thread') + self._transport_thread.daemon = True + self._transport_thread.start() + + def _new_runtime_thread(self): + self._runtime_thread = Thread(target=self._cycle, name='ws_runtime_thread') + self._runtime_thread.daemon = True + self._runtime_thread.start() + + def _stop_transport_thread(self) -> bool: + try: + self._transport.stop() + if self._transport_thread is None: + return True + + _LOGGER.debug(f'{self}: Joining transport thread') + + self._transport_thread.join(self._connection_timeout) + is_alive = self._transport_thread.is_alive() + if not is_alive: + self._transport_thread = None + return not is_alive + except Exception as e: + _LOGGER.error(f'{self}: Failed to stop transport thread: {e}') + + return False + + def start(self): + if self._state != WsState.STOPPED: + return + + if self._runtime_thread is not None and self._runtime_thread.is_alive(): + _LOGGER.error(f'{self}: Runtime thread must be stopped and joined before starting') + return + + _LOGGER.info(f'{self}: Starting WebSocket runtime') + + self._set_state(WsState.STARTING) + self._running = True + + self._new_runtime_thread() + + if isinstance(self._sink, AsyncSink): + self._sink.start() + + connection_success = wait_until(lambda: self._state == self._ready_state, f'{self}: Starting timeout', timeout=self._connection_timeout) + return connection_success + + def stop(self): + if self._state == WsState.STOPPED: + return + + if threading.current_thread() == self._runtime_thread: + raise RuntimeError(f'{self}: Stopping runtime called from within runtime thread. Ensure it is called from a separate thread') + + _LOGGER.info(f'{self}: Stopping WebSocket runtime') + + # wait until one more pass of the runtime thread has occurred to allow unsubscriptions to complete + wait_until(lambda: not self._wait_event.is_set(), timeout=self._connection_timeout) + self._wait_event.set() + wait_until(lambda: not self._wait_event.is_set(), timeout=self._connection_timeout) + + self._set_state(WsState.STOPPING) + transport_thread_stopped = self._stop_transport_thread() + if not transport_thread_stopped: + _LOGGER.error(f'{self}: Failed to stop transport thread; keeping the thread reference to avoid duplicate transport loops') + self._transport.set_degraded(True) + + self._running = False + runtime_thread = self._runtime_thread + if runtime_thread is not None: + runtime_thread.join(self._connection_timeout) + + if runtime_thread is not None and runtime_thread.is_alive(): + _LOGGER.error(f'{self}: Runtime thread failed to stop; keeping the thread reference to avoid duplicate runtime loops') + return + + self._runtime_thread = None + + if isinstance(self._sink, AsyncSink): + self._sink.stop() + + self._set_state(WsState.STOPPED) + + def send(self, payload: str) -> bool: + if self._state != self._ready_state: + _LOGGER.error(f'{self}: State must be {self._ready_state.value} before sending payloads, found {self._state.value}') + return False + + _LOGGER.info(f'{self}: Sending payload: {payload}') + + return self._transport.send(payload) + + def send_json(self, payload: Union[List, Dict]) -> bool: # pragma: no cover + return self.send(json.dumps(payload)) + + def is_running(self) -> bool: + return self._running + + def set_last_heartbeat(self, value: float): + self._last_heartbeat = value + + def hard_reset(self) -> None: + _LOGGER.info(f'{self}: Hard reset') + + if threading.current_thread() in [self._runtime_thread, self._transport_thread]: + raise RuntimeError(f'{self}: Hard reset called from Runtime or Transport thread. Ensure it is called from a separate thread') + + self.stop() + self.start() + + def restart_transport(self): + if threading.current_thread() == self._transport_thread: + raise RuntimeError(f'{self}: Resetting transport thread called from within transport thread. Ensure it is called from a separate thread') + + transport_thread_stopped = self._stop_transport_thread() + if not transport_thread_stopped: + _LOGGER.error(f'{self}: Failed to stop transport thread; restart aborted to avoid duplicate transport loops') + self._transport.set_degraded(True) + return False + + self._transport.set_degraded(True) + self._transport = self._new_transport() + self._new_transport_thread() + return True + + def reset_websocket_app(self): + self._transport.reset_websocket_app() + + def __str__(self): + return f'{self.__class__.__qualname__}({self._state})' + + # ====================== + # == Transport Thread == + # ====================== + + def _transport_callback(self, te: TransportEvent): + self._transport_queue.put(te) + self._wait_event.set() + + # ====================== + # == Runtime Thread == + # ====================== + + def _maintain_transport(self): + # Don't maintain the transport thread if we are stopping + if self._state == WsState.STOPPING: + return + + if self._transport_thread is None or not self._transport_thread.is_alive(): + self._set_state(WsState.CONNECTING) + self._new_transport_thread() + + def _maintain_subscriptions(self): + if self._state != self._ready_state: + return + + self.subscription_controller.reconcile_bindings() + + def check_should_reset(self): + # If WSA is not ready, we don't try to fix health + if not self._transport.is_ready(): + return False + + # If we're not either open or authenticated, we let WSA handle the reconnect first + if self._state not in [WsState.OPEN, WsState.AUTHENTICATED]: + return False + + ping_ok = self._transport.check_ping(self._max_ping_interval) + if not ping_ok: + _LOGGER.warning( + f'{self}: Last WebSocket ping happened {self._transport.get_time_since_last_ping():.2f} seconds ago, ' + f'exceeding the max ping interval of {self._max_ping_interval}.' + ) + # If we have a reconnect timeout, we let WSA handle the reconnect, otherwise let's reset the WSA + return self._reconnect_timeout is None + + heartbeat_ok = True + if self._last_heartbeat is not None: + diff = abs(time.time() - self._last_heartbeat) + if diff > self._max_ping_interval: + _LOGGER.warning( + f'{self}: Last heartbeat happened {diff:.2f} seconds ago, ' + f'exceeding the max ping interval of {self._max_ping_interval}.' + ) + heartbeat_ok = False + + if heartbeat_ok: + return False + + return True + + def health_check(self) -> bool: + if not self.check_should_reset(): + return True + + if not self._running: # return early if runtime got stopped in the meantime + return False + + self.state_degraded() + + _LOGGER.warning(f'{self}: Health check failed, resetting transport websocket') + self.reset_websocket_app() + return False + + def _cycle(self): + _LOGGER.debug(f'{self}: Runtime thread started ({tname()})') + while self._running: + self._maintain_transport() + self._maintain_subscriptions() + + self._process_transport_queue() + + if time.time() - self._last_health_check > _HEALTH_CHECK_INTERVAL: + self._last_health_check = time.time() + self.health_check() + + self._wait_event.clear() + self._wait_event.wait(self._cycle_interval) + + # if not stopped or closed yet, attempt to do one last pass before the thread dies + if self._state not in [WsState.STOPPED, WsState.CLOSED]: + # final pass through the transport queue to flush any remaining events + self._process_transport_queue() + + # final pass through the subscription controller to carry out final unsubscribe events + self.subscription_controller.reconcile_bindings() + + _LOGGER.debug(f'{self}: Runtime thread stopped ({tname()})') + + def _process_transport_queue(self): + retry_events = [] + while not self._transport_queue.empty(): + te = self._transport_queue.get() + try: + self._handle_transport_event(te) + except Exception as e: + _LOGGER.error(f'{self}: Exception processing transport event {te}: {exception_to_string(e)}') + te.add_attempt() + if te.get_attempt() > _MAX_TRANSPORT_EVENT_RETRIES: + _LOGGER.error(f'{self}: Max retries ({_MAX_TRANSPORT_EVENT_RETRIES}) reached for transport event {te}, dropping event.') + continue + retry_events.append(te) + + for event in retry_events: + self._transport_queue.put(event) + + def _handle_transport_event(self, transport_event: TransportEvent): + if isinstance(transport_event, TransportOpened): + self._handle_on_open() + elif isinstance(transport_event, TransportReconnect): + self._handle_on_reconnect() + elif isinstance(transport_event, TransportClosed): + self._handle_on_close(transport_event.close_status_code, transport_event.close_msg) + elif isinstance(transport_event, TransportError): + self._handle_on_error(transport_event.exception) + elif isinstance(transport_event, TransportMessage): + self._handle_on_message(transport_event.message) + else: + _LOGGER.error(f'{self}: Unknown event type: {type(transport_event)}: {transport_event}') + + def _handle_on_message(self, message): # pragma: no cover + events: OneOrMany[WsEvent] = self._router.route(message) + + # Router decided to skip this message + if events is None: + return + + # Handle both lists and individual events + if not isinstance(events, list) and isinstance(events, WsEvent): + events = [events] + + # Propagate events to the sink + for event in events: + try: + self.subscription_controller.observe(event) + except Exception as e: + _LOGGER.error(f'{self}: Exception observing subscription for {event}: {exception_to_string(e)}') + + self._emit(event) + + def _handle_on_open(self): + self._last_heartbeat = None + self._set_state(WsState.OPEN) + _LOGGER.info(f'{self}: Connection open') + if self._state != self._ready_state: + self.set_authenticated(False) + self._emit(events.WsOpen()) + + def _handle_on_reconnect(self): + _LOGGER.info(f'{self}: Connection reopened') + self._last_heartbeat = None + self._set_state(WsState.OPEN) + if self._state != self._ready_state: + self.set_authenticated(False) + self._emit(events.WsOpen()) # we emit Open since reconnect pretty much equivalent + + def _handle_on_error(self, exception: Exception): + _LOGGER.error(f'{self}: Connection error: {exception}') + if str(exception) in ['Connection to remote host was lost.', 'No connection could be made because the target machine actively refused it']: + self.state_degraded() + self.set_authenticated(False) + self._emit(events.WsError(error=exception)) + + def _handle_on_close(self, close_status_code, close_msg): + self._last_heartbeat = None + + if self._state != WsState.STOPPING: + _LOGGER.info(f'{self}: Connection closed') + self.set_authenticated(False) + self.subscription_controller.invalidate_subscriptions() + else: + _LOGGER.info(f'{self}: Connection gracefully closed') + + self._set_state(WsState.CLOSED) + + if close_status_code is not None or close_msg is not None: # this means an error + try: + msg = close_msg.decode('utf-8') + except AttributeError: + msg = close_msg + + _LOGGER.error(f'{self}: on_close error: {close_status_code} | {msg}') + + self._emit(events.WsClose(close_status_code=close_status_code, close_msg=close_msg)) + + def _emit(self, event: WsEvent): + try: + self._internal_sink.emit(event) + except Exception as e: + _LOGGER.error(f'{self}: Internal sink exception for {event}: {exception_to_string(e)}') + + try: + self._sink.emit(event) + except Exception as e: + _LOGGER.error(f'{self}: External sink exception for {event}: {exception_to_string(e)}') diff --git a/ibind/ws_v2/ws_subscriptions.py b/ibind/ws_v2/ws_subscriptions.py new file mode 100644 index 00000000..c1cd469d --- /dev/null +++ b/ibind/ws_v2/ws_subscriptions.py @@ -0,0 +1,550 @@ +import copy +import time +from enum import Enum +from threading import Condition, RLock +from typing import Dict, Optional, Callable, Protocol, Tuple, Literal + +from pydantic import BaseModel, ConfigDict + +from ibind.support.logs import project_logger +from ibind.support.py_utils import exception_to_string +from ibind.events import WsEvent + +_LOGGER = project_logger('ibkr_ws_client') + + +class Subscription(BaseModel): # pragma: no cover + """ + Base class for WebSocket subscriptions. + + Immutable model defining subscription behaviour including payload generation, + confirmation requirements, and expiry settings. Subclasses implement specific + subscription types by overriding abstract methods. + + Attributes: + expiry_seconds (int | None): Time in seconds before subscription expires and + requires renewal. None means no expiry. Default: None. + """ + + model_config = ConfigDict(frozen=True) + expiry_seconds: int | None = None + + @property + def topic(self) -> str: + """Get the subscription topic identifier.""" + raise NotImplementedError + + def subscribe_payload(self) -> str: + """Generate the payload string to send for subscribing.""" + raise NotImplementedError + + def unsubscribe_payload(self) -> str: + """Generate the payload string to send for unsubscribing.""" + raise NotImplementedError + + @property + def confirms_subscribe(self) -> bool: + """Whether the server sends confirmation when subscription succeeds.""" + return True + + @property + def confirms_unsubscribe(self) -> bool: + """Whether the server sends confirmation when unsubscription succeeds.""" + return False + + def binding_key(self): + """Get the unique key identifying this subscription binding.""" + return self.subscribe_payload() + + def __str__(self): + return f'{self.__class__.__qualname__}({self.binding_key()})' + + +class SubscriptionResolver(Protocol): # pragma: no cover + """ + Protocol for resolving subscription binding keys from events. + + Implementations determine which subscription an event belongs to and whether + the event indicates an active or inactive subscription state. + """ + + def resolve_binding_key(self, event: WsEvent) -> Tuple[bool, str]: + """ + Resolve the binding key and active state from an event. + + Args: + event (WsEvent): Event to resolve. + + Returns: + tuple[bool, str]: (is_active, binding_key) where is_active indicates + subscription is active, and binding_key identifies the subscription. + Returns (None, None) if event is not subscription-related. + """ + ... + + +class BindingStatus(Enum): # pragma: no cover + """ + Status of a subscription binding. + + Tracks the lifecycle state of a subscription from initial registration through + activation, failure, or unsubscription. + """ + + NEW = 'NEW' + PENDING = 'PENDING' + ACTIVE = 'ACTIVE' # subscription successful + FAILED = 'FAILED' + DEGRADED = 'DEGRADED' + UNSUBSCRIBED = 'UNSUBSCRIBED' # unsubscription successful + EXPIRED = 'EXPIRED' + + +class Binding(BaseModel): + """ + Internal state tracking for a subscription binding. + + Maintains the desired intent (subscribe or unsubscribe), current status, + and retry state for subscription operations. + + Attributes: + subscription (Subscription): The subscription being tracked. + intent (Literal[BindingStatus.ACTIVE, BindingStatus.UNSUBSCRIBED]): Desired state. + status (BindingStatus): Current state. Default: BindingStatus.NEW. + attempts (int): Number of attempts made. Default: 0. + last_attempt (float): Timestamp of last attempt. Default: 0. + """ + + subscription: Subscription + intent: Literal[BindingStatus.ACTIVE, BindingStatus.UNSUBSCRIBED] + status: BindingStatus = BindingStatus.NEW + attempts: int = 0 + last_attempt: float = 0 + + @property + def done(self) -> bool: + """Whether the binding has reached its intended state.""" + return self.status == self.intent + + def reset(self): # pragma: no cover + """Reset retry state to allow new attempts.""" + self.attempts = 0 + self.last_attempt = 0 + + +class SubscriptionHandle: + """ + Handle for interacting with a subscription. + + Provides methods to query subscription state, wait for completion, and unsubscribe. + Returned by subscribe/unsubscribe operations. + """ + + def __init__(self, controller: 'SubscriptionController', subscription: Subscription): + self._controller = controller + self._subscription = subscription + + @property + def binding_key(self) -> str: + """Get the unique key identifying this subscription.""" + return self._subscription.binding_key() + + @property + def status(self) -> BindingStatus: + """Get the current status of this subscription.""" + return self._controller.get_status(self.binding_key) + + @property + def active(self) -> bool: + """Whether the subscription is currently active.""" + return self.status == BindingStatus.ACTIVE + + @property + def unsubscribed(self) -> bool: + """Whether the subscription has been unsubscribed.""" + return self.status == BindingStatus.UNSUBSCRIBED + + @property + def done(self) -> bool: + """Whether the subscription has reached its intended state.""" + return self._controller.is_done(self.binding_key) + + def wait(self, timeout: float | None = None) -> bool: + """ + Wait for the subscription to reach its intended state. + + Args: + timeout (float | None): Maximum time to wait in seconds, or indefinitely if None. Default: None. + + Returns: + bool: True if subscription reached intended state, False if timed out or failed. + """ + return self._controller.wait_for(self.binding_key, timeout=timeout) + + def unsubscribe(self) -> 'SubscriptionHandle': + """ + Unsubscribe from this subscription. + + Returns: + SubscriptionHandle: This handle for chaining. + """ + self._controller.unsubscribe(self._subscription) + return self + + +class SubscriptionController: + """ + Manages WebSocket subscriptions with automatic retries and state tracking. + + Handles subscription lifecycle including registration, activation, expiry, and unsubscription. + Maintains binding state and coordinates with a resolver to match events to subscriptions. + Thread-safe through internal condition variable. + """ + + def __init__( + self, + send_payload: Callable[[str], bool], + subscription_resolver: SubscriptionResolver, + subscription_retries: int = 5, + subscription_timeout: float = 2, + ): + """ + Create a subscription controller. + + Args: + send_payload (Callable[[str], bool]): Function to send payloads through the WebSocket. + Returns True if sent successfully. + subscription_resolver (SubscriptionResolver): Resolver to match events to subscriptions. + subscription_retries (int, optional): Maximum retry attempts per subscription. Default: 5. + subscription_timeout (float, optional): Seconds to wait between retry attempts. Default: 2. + """ + self._send_payload = send_payload + self._subscription_resolver = subscription_resolver + self._subscription_retries = subscription_retries + self._subscription_timeout = subscription_timeout + + self._bindings: Dict[str, Binding] = {} + self._condition = Condition(RLock()) + + def _send(self, payload) -> bool: + try: + success = self._send_payload(payload) + if not success: + _LOGGER.info(f'{self}: Sending payload unsuccessful: {payload}') + return success + except Exception as e: + _LOGGER.exception(f'{self}: Exception sending payload: {payload}\n{exception_to_string(e)}') + return False + + def observe(self, event: WsEvent): + """ + Process an event to update subscription state. + + Uses the resolver to determine if the event confirms subscription or unsubscription, + then updates the corresponding binding status. + + Args: + event (WsEvent): Event to process. + """ + is_active, binding_key = self._subscription_resolver.resolve_binding_key(event) + + # None means the event is not related to a tracked subscription + if binding_key is None: + return + + with self._condition: + if not self.has_subscription(binding_key): + _LOGGER.warning(f'{self}: Observed a binding_key "{binding_key}" that is missing a subscription. Event: {event}') + return + + if is_active: + self._confirm_subscribed(binding_key) + else: + self._confirm_unsubscribed(binding_key) + + def _make_attempt(self, binding: Binding): + subscription = binding.subscription + if binding.intent == BindingStatus.ACTIVE: + payload = subscription.subscribe_payload() + self._send(payload) + if not subscription.confirms_subscribe: + _LOGGER.info(f'{self}: Subscribed: {payload} without confirmation.') + self._confirm_subscribed(subscription.binding_key()) + + elif binding.intent == BindingStatus.UNSUBSCRIBED: + payload = subscription.unsubscribe_payload() + self._send(payload) + if not subscription.confirms_unsubscribe: + _LOGGER.info(f'{self}: Unsubscribed: {payload} without confirmation.') + self._confirm_unsubscribed(subscription.binding_key()) + + def reconcile_binding(self, binding: Binding): + """ + Reconcile a single binding by checking expiry and retrying if needed. + + Handles expiry checks, retry logic, and failure detection. Called periodically + by the runtime to maintain subscription state. + + Args: + binding (Binding): Binding to reconcile. + """ + now = time.time() + subscription = binding.subscription + + # if either done or failed, return early or check expiration if expiry_seconds is provided + if binding.status in [binding.intent, BindingStatus.FAILED]: + if subscription.expiry_seconds is None: + return + + time_since_last_attempt = now - binding.last_attempt + if time_since_last_attempt < subscription.expiry_seconds: + return + + _LOGGER.info(f'{self}: Subscription expired: {subscription} after {time_since_last_attempt:.1f} seconds') + self._update_status(binding, BindingStatus.EXPIRED) + + # if we've exceeded the number of retries, mark the subscription as failed + if binding.attempts >= self._subscription_retries: + _LOGGER.info(f'{self}: Subscription failed after {self._subscription_retries} attempts: {binding}') + self._update_status(binding, BindingStatus.FAILED) + return + + # wait until timeout has passed since last attempt + if binding.last_attempt + self._subscription_timeout > now: + return + + binding.last_attempt = now + binding.attempts += 1 + self._make_attempt(binding) + + def reconcile_bindings(self): + """Reconcile all bindings by checking expiry and retrying as needed.""" + with self._condition: + for binding in self._bindings.values(): + self.reconcile_binding(binding) + + def subscribe(self, subscription: Subscription) -> SubscriptionHandle: + """ + Register intent to subscribe. + + Creates or updates a binding with ACTIVE intent. The actual subscription + attempt occurs during reconciliation. + + Args: + subscription (Subscription): Subscription to activate. + + Returns: + SubscriptionHandle: Handle for tracking and controlling this subscription. + """ + binding_key = subscription.binding_key() + + with self._condition: + binding = self._bindings.get(binding_key) + + if binding is None: + self._bindings[binding_key] = Binding(subscription=subscription, intent=BindingStatus.ACTIVE) + self._condition.notify_all() + _LOGGER.info(f'{self}: Registered subscription intent: {binding_key}') + + elif binding.intent != BindingStatus.ACTIVE: + binding.intent = BindingStatus.ACTIVE + + # If it had previously completed unsubscribe, it now needs work again. + if binding.status == BindingStatus.UNSUBSCRIBED: + binding.reset() + + self._condition.notify_all() + _LOGGER.info(f'{self}: Updated subscription intent: {binding_key} -> {BindingStatus.ACTIVE.value}') + + return SubscriptionHandle(self, subscription) + + def unsubscribe(self, subscription: Subscription) -> SubscriptionHandle: + """ + Register intent to unsubscribe. + + Creates or updates a binding with UNSUBSCRIBED intent. The actual unsubscription + attempt occurs during reconciliation. + + Args: + subscription (Subscription): Subscription to deactivate. + + Returns: + SubscriptionHandle: Handle for tracking this unsubscription. + """ + binding_key = subscription.binding_key() + + with self._condition: + binding = self._bindings.get(binding_key) + + if binding is None: + self._bindings[binding_key] = Binding(subscription=subscription, intent=BindingStatus.UNSUBSCRIBED) + self._condition.notify_all() + _LOGGER.info(f'{self}: Registered unsubscription intent: {binding_key}') + + elif binding.intent != BindingStatus.UNSUBSCRIBED: + binding.intent = BindingStatus.UNSUBSCRIBED + + # If it had previously completed subscribe, it now needs work again. + if binding.status == BindingStatus.ACTIVE: + binding.reset() + + self._condition.notify_all() + _LOGGER.info(f'{self}: Updated subscription intent: {binding_key} -> {BindingStatus.UNSUBSCRIBED.value}') + + return SubscriptionHandle(self, subscription) + + def invalidate_subscriptions(self): + """Mark all subscriptions as degraded, typically after connection loss.""" + with self._condition: + for binding_key, binding in self._bindings.items(): + if binding.status != BindingStatus.DEGRADED: + self._update_status(binding, BindingStatus.DEGRADED) + + def is_subscription_active(self, binding_key: str) -> Optional[bool]: # pragma: no cover + """Check if a subscription is currently active. + + Args: + binding_key (str): Binding key to check. + + Returns: + bool: True if subscription exists and is active, False otherwise. + """ + with self._condition: + if not self.has_subscription(binding_key): + return False + return self._bindings[binding_key].status == BindingStatus.ACTIVE + + def has_active_subscriptions(self) -> bool: + """ + Check if any subscriptions are currently active. + + Returns: + bool: True if any subscriptions are active, False otherwise. + """ + with self._condition: + for subscription in self._bindings: + if self.is_subscription_active(subscription): + return True + return False + + def has_subscription(self, binding_key: str) -> bool: # pragma: no cover + """Check if a subscription exists. + + Args: + binding_key (str): Binding key to check. + + Returns: + bool: True if subscription exists, False otherwise. + """ + return binding_key in self._bindings + + def get_status(self, binding_key: str) -> BindingStatus | None: # pragma: no cover + """Get the status of a subscription. + + Args: + binding_key (str): Binding key to query. + + Returns: + BindingStatus | None: Current status, or None if subscription doesn't exist. + """ + with self._condition: + if not self.has_subscription(binding_key): + return None + return self._bindings[binding_key].status + + def is_done(self, binding_key: str) -> bool | None: # pragma: no cover + """Check if a subscription has reached its intended state. + + Args: + binding_key (str): Binding key to check. + + Returns: + bool | None: True if done, False if not done, None if subscription doesn't exist. + """ + with self._condition: + if not self.has_subscription(binding_key): + return None + return self._bindings[binding_key].done + + def get_active_subscriptions(self): + """ + Get all active subscriptions. + + Returns: + dict[str, Binding]: Deep copies of active bindings keyed by binding_key. + """ + with self._condition: + return { + binding_key: copy.deepcopy(binding) for binding_key, binding in self._bindings.items() if self.is_subscription_active(binding_key) + } + + def _update_status(self, binding: Binding, status: BindingStatus): + _LOGGER.info(f'{self}: Updated subscription status: {binding.subscription.binding_key()} {binding.status.value} -> {status.value}') + binding.status = status + binding.attempts = 0 + self._condition.notify_all() + + def _confirm_subscribed(self, binding_key: str): + if not self.has_subscription(binding_key): + _LOGGER.warning(f'{self}: Unknown subscription {binding_key} - cannot update status to {BindingStatus.ACTIVE.value}') + return + + binding = self._bindings[binding_key] + + if binding.status == BindingStatus.ACTIVE or binding.intent == BindingStatus.UNSUBSCRIBED: + return + + self._update_status(binding, BindingStatus.ACTIVE) + + def _confirm_unsubscribed(self, binding_key: str): + if not self.has_subscription(binding_key): + _LOGGER.warning(f'{self}: Unknown subscription {binding_key} - cannot update status to {BindingStatus.UNSUBSCRIBED.value}') + return + + binding = self._bindings[binding_key] + + if binding.status == BindingStatus.UNSUBSCRIBED or binding.intent == BindingStatus.ACTIVE: + return + + self._update_status(binding, BindingStatus.UNSUBSCRIBED) + + def wait_for(self, binding_key: str, timeout: float | None = None) -> bool: + """ + Wait for a subscription to reach its intended state. + + Blocks until the binding reaches its intent or fails. Uses a condition variable + for efficient waiting. + + Args: + binding_key (str): Binding key to wait for. + timeout (float | None): Maximum time to wait in seconds. None means wait indefinitely. Default: None. + + Returns: + bool: True if binding reached intended state, False if timed out, failed, or binding not found. + """ + deadline = None if timeout is None else time.monotonic() + timeout + + with self._condition: + while True: + if not self.has_subscription(binding_key): + return False + + binding = self._bindings[binding_key] + + if binding.done: + return True + + if binding.status == BindingStatus.FAILED: + return False + + # wait for the remaining time + remaining = None + if timeout is not None: + remaining = deadline - time.monotonic() + if remaining <= 0: + return False + + self._condition.wait(remaining) + + def __str__(self): # pragma: no cover + return f'{self.__class__.__qualname__}()' diff --git a/ibind/ws_v2/ws_transport.py b/ibind/ws_v2/ws_transport.py new file mode 100644 index 00000000..19fa47f5 --- /dev/null +++ b/ibind/ws_v2/ws_transport.py @@ -0,0 +1,410 @@ +import time +from datetime import datetime +from typing import Callable, Any, cast, List, Union, Dict + +from pydantic import BaseModel, ConfigDict, Field +from websocket import WebSocketApp, STATUS_UNEXPECTED_CONDITION, STATUS_NORMAL + +from ibind import var +from ibind.support.errors import ExternalBrokerError +from ibind.support.logs import project_logger +from ibind.support.py_utils import exception_to_string, tname, wait_until, UNDEFINED, noop + +_LOGGER = project_logger('ibkr_ws_client') + + +class TransportEvent(BaseModel): + """ + Base class for WebSocket transport-level events. + + Tracks when events were received and how many processing attempts have been made. + Uses a list for attempt count to allow mutation despite frozen model. + """ + + model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) + received_at: datetime = Field(default_factory=datetime.now) + attempt: List[int] = Field(default_factory=lambda: [0]) + + def add_attempt(self): + self.attempt[0] += 1 + + def get_attempt(self): + return self.attempt[0] + + def __str__(self): + return f'{self.__class__.__qualname__}()' + + +class TransportOpened(TransportEvent): + """Emitted when the WebSocket connection is successfully opened.""" + + pass + + +class TransportClosed(TransportEvent): + """Emitted when the WebSocket connection is closed.""" + + close_status_code: int | None + close_msg: str | None + + +class TransportError(TransportEvent): + """Emitted when a WebSocket error occurs. Note that currently WebSocketApp only emits this once, due to reconnection logic then skipping this event.""" + + exception: Exception + + +class TransportMessage(TransportEvent): + """Emitted when a message is received from the WebSocket.""" + + message: str + + +class TransportReconnect(TransportEvent): + """Emitted when the WebSocket reconnects after a disconnection.""" + + pass + + +class WsTransport: + """ + Manages low-level WebSocket transport using WebSocketApp. + + Handles connection lifecycle, message sending, cookie validation, and ping monitoring. + Runs in a dedicated thread and communicates via event callbacks. + """ + + def __init__( + self, + url: str, + event_callback: Callable[[TransportEvent], None], + sslopt: Dict[str, Any], + get_cookie: Callable[[], str | None] = noop, + get_header: Callable[[], Dict[str, Any] | None] = noop, + ping_interval: float = 10, + ping_timeout: float = 10, + max_ping_interval: float = 20, + connection_timeout: float = 5, + reconnect_timeout: float = 5, + skip_utf8_validation: bool = var.IBIND_WS_SKIP_UTF8_VALIDATION, + ): + """ + Create a WebSocket transport instance. + + Args: + url (str): WebSocket URL to connect to. + event_callback (Callable): Callback function invoked with TransportEvent instances. + sslopt (dict[str, Any]): SSL options for the WebSocket connection. + get_cookie (Callable, optional): Function to retrieve session cookie. Default: noop. + get_header (Callable, optional): Function to retrieve HTTP headers. Default: noop. + ping_interval (float, optional): Interval in seconds between ping messages. Default: 10. + ping_timeout (float, optional): Timeout in seconds for ping responses. Default: 10. + max_ping_interval (float, optional): Maximum acceptable time since last pong. Default: 20. + connection_timeout (float, optional): Timeout in seconds for connection operations. Default: 5. + reconnect_timeout (float, optional): Timeout in seconds before reconnect attempts. Default: 5. + skip_utf8_validation (bool, optional): Whether to skip UTF-8 validation. Default: True + """ + self._url = url + self._event_callback = event_callback + self._sslopt = sslopt + self._get_cookie = get_cookie + self._get_header = get_header + self._ping_interval = ping_interval + self._ping_timeout = ping_timeout + self._max_ping_interval = max_ping_interval + self._connection_timeout = connection_timeout + self._reconnect_timeout = reconnect_timeout + self._skip_utf8_validation = skip_utf8_validation + + self._running = False + self._wsa: WebSocketApp | None = None + self._degraded = False + self._tname = None + self._last_unanswered_ping_tm = None + + self._session_lacks_authentication = False + + def disconnect(self): + """Gracefully disconnect the WebSocket connection.""" + if self._wsa is None: + _LOGGER.info(f'{self}: WebSocketApp is None, skipping disconnect') + return + self._wsa.close(status=STATUS_NORMAL, timeout=self._connection_timeout) + + def stop(self): + """Stop the transport thread and disconnect the WebSocket.""" + _LOGGER.debug(f'{self}: Stopping transport') + self._running = False + self.disconnect() + + def reset_websocket_app(self) -> bool: + """ + Force close and recreate the WebSocketApp connection. + + Returns: + bool: True if a new WebSocketApp was successfully created, False otherwise. + + Raises: + RuntimeError: If called from within the transport thread. + """ + if tname() == self._tname: + raise RuntimeError(f'{self}: Resetting websocket app called from within transport thread. Ensure it is called from a separate thread') + + if self._wsa is None: + _LOGGER.info(f'{self}: WebSocketApp is None, skipping reset') + return False + + _LOGGER.info(f'{self}: Reset') + + self._wsa.close(status=STATUS_UNEXPECTED_CONDITION, timeout=self._connection_timeout) + + if not wait_until(lambda: self._wsa is None, f'{self}: WebSocket reset close timeout', timeout=self._connection_timeout * 2): + _LOGGER.warning(f'{self}: Abandoning current WebSocketApp that cannot be closed: {self._wsa}') + self._wsa = None + + wait_until(lambda: self._wsa is not None, f'{self}: WebSocket recreation timeout', timeout=self._connection_timeout * 2) + + return self._wsa is not None + + def check_ping(self, max_interval: float = None) -> bool: + """ + Check if the last pong was received within the acceptable interval. + + Args: + max_interval (float, optional): Maximum acceptable seconds since last pong. + Default: self._max_ping_interval. + + Returns: + bool: True if last pong was within the interval or WebSocketApp is not connected, + False if the interval was exceeded. + """ + if self._wsa is None: + return True + + last_ping_tm = getattr(self._wsa, 'last_ping_tm', 0) + last_pong_tm = getattr(self._wsa, 'last_pong_tm', 0) + + if last_ping_tm == 0: + return True + + if max_interval is None: + max_interval = self._max_ping_interval + + now = time.time() + if last_pong_tm >= last_ping_tm and last_pong_tm != 0: + self._last_unanswered_ping_tm = None + return abs(now - last_pong_tm) <= max_interval + + if self._last_unanswered_ping_tm is not None and last_pong_tm >= self._last_unanswered_ping_tm: + self._last_unanswered_ping_tm = None + if self._last_unanswered_ping_tm is None: + self._last_unanswered_ping_tm = last_ping_tm + + return abs(now - self._last_unanswered_ping_tm) <= max_interval + + def get_time_since_last_ping(self) -> float: + """Get seconds elapsed since the latest ping that still needs a fresh pong, or the last pong.""" + last_ping_tm = getattr(self._wsa, 'last_ping_tm', 0) + last_pong_tm = getattr(self._wsa, 'last_pong_tm', 0) + if last_pong_tm >= last_ping_tm and last_pong_tm != 0: + return abs(time.time() - last_pong_tm) + if self._last_unanswered_ping_tm is not None: + return abs(time.time() - self._last_unanswered_ping_tm) + return abs(time.time() - last_ping_tm) + + def fetch_cookie(self) -> Union[str, None]: + """ + Retrieve session cookie using the configured callback. + + Returns: + str | None | UNDEFINED: Cookie value, None if no cookie needed, or UNDEFINED if retrieval failed. + """ + try: + cookie = self._get_cookie() + if self._session_lacks_authentication: + self._session_lacks_authentication = False + return cookie + except Exception as e: + if isinstance(e, TimeoutError): + _LOGGER.info(f'{self}: Timeout retrieving cookie') + return UNDEFINED + if isinstance(e, ExternalBrokerError): + if e.status_code == 401: + if not self._session_lacks_authentication: + self._session_lacks_authentication = True + _LOGGER.info(f'{self}: Failed to retrieve cookie due to lack of authentication. Continuing reattempts silently until authentication is reestablished.') + return UNDEFINED + _LOGGER.error(f'{self}: Failed to retrieve cookie: {exception_to_string(e)}') + return UNDEFINED + + def check_cookie(self) -> bool: + """ + Verify the current cookie matches the stored cookie. + + Returns: + bool: True if cookies match, False if retrieval failed or cookies differ. + """ + cookie = self.fetch_cookie() + if cookie is UNDEFINED: + return False + + if cookie != self._cookie: + _LOGGER.warning(f'{self}: Cookie changed, current: {cookie}, previous: {self._cookie}') + return False + return True + + def set_degraded(self, value): + """Mark the transport as degraded to suppress event callbacks.""" + self._degraded = value + + def is_ready(self) -> bool: + """Check if the WebSocketApp is ready to send messages.""" + return self._wsa is not None and self._wsa.ready and self._wsa.sock is not None and self._wsa.sock.sock is not None + + def send(self, payload: str) -> bool: + """ + Send a message through the WebSocket. + + Args: + payload (str): Message to send. + + Returns: + bool: True if sent successfully, False otherwise. + + Raises: + RuntimeError: If the WebSocketApp is not ready. + """ + if not self.is_ready(): + raise RuntimeError(f'{self}: WebSocketApp socket is not ready') + + try: + self._wsa.send(payload) + except Exception as e: + if 'Connection is already closed' in str(e): + _LOGGER.error(f'{self}: Connection closed while sending payload: {payload}') + else: + _LOGGER.exception(f'{self}: Sending payload failed: {payload}\n{exception_to_string(e)}') + return False + + return True + + def __str__(self): + return f'{self.__class__.__qualname__}({f"degraded:{tname()}" if self._degraded else ""})' + + # ====================== + # == Transport Thread == + # ====================== + + def _wrap_callback(self, f): + def wrapped_f(ws, *args, **kwargs): + try: + f(ws, *args, **kwargs) + except Exception as e: + _LOGGER.exception(f'{self}: Exception executing callback: \n{f} \nwith\n{args=}\n{kwargs=}\n{str(e)}') + + return wrapped_f + + def _on_open(self, wsa: WebSocketApp): + if self._degraded: + return + + if not self.check_cookie(): + self._wsa.close(status=STATUS_UNEXPECTED_CONDITION, timeout=self._connection_timeout) + return + + self._event_callback(TransportOpened()) + + def _on_message(self, wsa: WebSocketApp, message): + if self._degraded: + return + + self._event_callback(TransportMessage(message=message)) + + def _on_close(self, wsa: WebSocketApp, close_status_code, close_msg): + if self._degraded: + return + + self._event_callback(TransportClosed(close_status_code=close_status_code, close_msg=close_msg)) + + def _on_error(self, wsa: WebSocketApp, error): + if self._degraded: + return + + self._event_callback(TransportError(exception=error)) + + def _on_reconnect(self, wsa: WebSocketApp): + if self._degraded: + return + + if not self.check_cookie(): + self._wsa.close(status=STATUS_UNEXPECTED_CONDITION, timeout=self._connection_timeout) + return + + self._event_callback(TransportReconnect()) + + def _new_wsa(self): + """Create a new WebSocketApp instance with current cookie and header.""" + cookie = self.fetch_cookie() + if cookie is UNDEFINED: + return None + + self._cookie = cookie + + try: + self._header = self._get_header() + except Exception as e: + _LOGGER.error(f'{self}: Failed to retrieve header: {exception_to_string(e)}') + return None + + if not self._running: + # Transport got stopped between invocation of this function and creating a WebSocketApp + return None + + wsa = WebSocketApp( + url=self._url, + on_open=self._wrap_callback(self._on_open), + on_message=self._wrap_callback(self._on_message), + on_close=self._wrap_callback(self._on_close), + on_error=self._wrap_callback(self._on_error), + on_reconnect=self._wrap_callback(self._on_reconnect), + cookie=self._cookie, + header=self._header, + ) + _LOGGER.debug(f'{self}: Created new WebSocketApp instance{f", cookie: {cookie}" if cookie is not None else ""}') + + return wsa + + def connect(self): + """Main transport thread loop that maintains the WebSocket connection.""" + _LOGGER.debug(f'{self}: Transport thread started ({tname()})') + + self._tname = tname() + + self._running = True + + while self._running: + if self._wsa is None: + wsa = self._new_wsa() + if wsa is None: + time.sleep(1) + continue + self._wsa = wsa + + try: + self._wsa.run_forever( + ping_interval=self._ping_interval, + ping_timeout=self._ping_interval * 0.95, # the timeout is set to a little sooner than the interval + sslopt=self._sslopt, + reconnect=cast(int, self._reconnect_timeout), # floats are de facto valid, casting only for the linter + skip_utf8_validation=self._skip_utf8_validation, + ) + _LOGGER.debug(f'{self}: WebSocketApp stopped gracefully') + except Exception as e: + if 'url is invalid' in str(e): + _LOGGER.error(f'{self}: URL is invalid: {self._url}') + else: + _LOGGER.exception(f'{self}: Unexpected error while running WebSocketApp: {e}') + finally: + self._wsa = None + + _LOGGER.debug(f'{self}: Transport thread stopped ({tname()})') diff --git a/pyproject.toml b/pyproject.toml index 0b72a5d6..706e8f75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,11 +6,12 @@ authors = [ ] description = "IBind is a REST and WebSocket client library for Interactive Brokers Client Portal Web API." readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.10" license-files=["LICENSE"] dependencies = [ "requests>=2.31", - "websocket-client>=1.7" + "websocket-client>=1.7", + "pydantic>=2.13" ] classifiers = [ "Development Status :: 4 - Beta", @@ -71,4 +72,4 @@ docstring-code-format = true requires = [ "setuptools" ] -build-backend = "setuptools.build_meta" \ No newline at end of file +build-backend = "setuptools.build_meta" diff --git a/requirements.txt b/requirements.txt index 786c8ffd..57b558d0 100644 Binary files a/requirements.txt and b/requirements.txt differ diff --git a/test/integration/base/test_rest_client_i.py b/test/integration/base/test_rest_client_i.py index c7effe2d..a5c3e29e 100644 --- a/test/integration/base/test_rest_client_i.py +++ b/test/integration/base/test_rest_client_i.py @@ -1,11 +1,11 @@ import asyncio -import logging import threading import pytest from unittest.mock import MagicMock from requests import ReadTimeout, Timeout +from requests import ConnectionError as RequestsConnectionError from ibind.client.ibkr_client import IbkrClient from ibind.support.errors import ExternalBrokerError @@ -136,6 +136,47 @@ def test_response_raise_generic(client, result, requests_mock): assert f'RestClient: response error {result.copy(data=None)} :: {response.status_code} :: {response.reason} :: {response.text}' == str(excinfo.value) +def test_result_copy_with_none_data(): + result = Result() + assert result.copy() == Result(data=None, request={}) + + +def test_request_recreates_session_before_retry(default_url, data, requests_mock): + # Arrange + response = MagicMock() + response.json.return_value = data + + first_session = MagicMock() + first_session.request.side_effect = RequestsConnectionError('connection dropped') + + second_session = MagicMock() + second_session.request.return_value = response + + requests_mock.Session.side_effect = [first_session, second_session] + client = RestClient(url=_URL, timeout=_TIMEOUT, max_retries=1, use_session=True) + + # Act + rv = client.get(_DEFAULT_PATH) + + # Assert + assert rv == Result(data=data, request={'url': default_url}) + first_session.close.assert_called_once() + first_session.request.assert_called_once_with('GET', default_url, verify=False, headers={}, timeout=_TIMEOUT) + second_session.request.assert_called_once_with('GET', default_url, verify=False, headers={}, timeout=_TIMEOUT) + + +def test_close_is_idempotent(client): + # Arrange + client.close_session = MagicMock() + + # Act + client.close() + client.close() + + # Assert + client.close_session.assert_called_once() + + def _worker_in_thread(results: []): try: IbkrClient() @@ -218,4 +259,4 @@ def test_without_thread_async(): # Assert for result in results: if isinstance(result, Exception): - raise result \ No newline at end of file + raise result diff --git a/test/integration/base/test_websocket_client_i.py b/test/integration/base/test_websocket_client_i.py index 40801dd2..df2321b0 100644 --- a/test/integration/base/test_websocket_client_i.py +++ b/test/integration/base/test_websocket_client_i.py @@ -74,7 +74,7 @@ def _logs_exception_starting(error_message: str, thread_mock: MagicMock): def _logs_check_health_error(max_ping_interval: int, time_ago: str): return [ - f'WsClient: Last WebSocket ping happened {time_ago} seconds ago, exceeding the max ping interval of {max_ping_interval}. Restarting.', + f'WsClient: Last WebSocket pong happened {time_ago} seconds ago, exceeding the max ping interval of {max_ping_interval}. Restarting.', 'WsClient: Hard reset, restart=True, self._wsa is None=False', 'WsClient: Hard reset is closing the WebSocketApp', ] @@ -353,7 +353,7 @@ def test_send_without_start(ws_client, **kwargs): @capture_logs( logger_level='DEBUG', expected_errors=[ - 'WsClient: Last WebSocket ping happened', + 'WsClient: Last WebSocket pong happened', 'WsClient: Hard reset close timeout', 'WsClient: Abandoning current WebSocketApp that cannot be closed:', ], @@ -395,4 +395,30 @@ def fake_time(): + _logs_start_success_end() + _logs_shutdown_success() == [r.msg for r in cm.records] - ) \ No newline at end of file + ) + + +def test_check_ping_tracks_first_unanswered_ping(ws_client, wsa_mock, patched_constructors, mocker): + """Fails even if outgoing ping timestamps keep updating without a pong.""" + ## Arrange + current_time = [100.0] + mocker.patch('ibind.base.ws_client.time.time', side_effect=lambda: current_time[0]) + ws_client.start() + + ## Act / Assert + wsa_mock.last_ping_tm = 100.0 + wsa_mock.last_pong_tm = 0 + assert ws_client.check_ping() is True + + current_time[0] = 130.0 + wsa_mock.last_ping_tm = 130.0 + assert ws_client.check_ping() is True + + current_time[0] = 139.0 + wsa_mock.last_ping_tm = 139.0 + ws_client.hard_reset = MagicMock() + assert ws_client.check_ping() is False + ws_client.hard_reset.assert_called_once_with(restart=True) + + ## Cleanup + ws_client.shutdown() diff --git a/test/integration/base/websocketapp_mock.py b/test/integration/base/websocketapp_mock.py index af2458c6..fd0de6de 100644 --- a/test/integration/base/websocketapp_mock.py +++ b/test/integration/base/websocketapp_mock.py @@ -27,6 +27,7 @@ def init_wsa_mock( wsa_mock._on_close.side_effect = on_close wsa_mock.last_ping_tm = 0 + wsa_mock.last_pong_tm = 0 wsa_mock.keep_running = False return wsa_mock @@ -53,4 +54,4 @@ def create_wsa_mock(): wsa_mock.close.side_effect = lambda *args, **kwargs: close(wsa_mock, *args, **kwargs) wsa_mock.run_forever.side_effect = lambda *args, **kwargs: run_forever(wsa_mock, *args, **kwargs) - return wsa_mock \ No newline at end of file + return wsa_mock diff --git a/test/integration/client/test_ibkr_client_i.py b/test/integration/client/test_ibkr_client_i.py index b22ae18a..6d9d5c0c 100644 --- a/test/integration/client/test_ibkr_client_i.py +++ b/test/integration/client/test_ibkr_client_i.py @@ -158,7 +158,7 @@ def _marketdata_request(method, url, *args, **kwargs): if leaf == 'stocks': return MagicMock(json=lambda: ibkr_responses.responses['stocks']) elif leaf == 'history': - conid = kwargs['params']['conid'] + conid = int(kwargs['params']['conid']) history_by_conid = { ibkr_responses.responses['filtered_conids'][key]: value for key, value in ibkr_responses.responses['history'].items() } @@ -219,6 +219,18 @@ def test_marketdata_history_by_symbols(client, requests_mock): assert result['date'] == expected['date'] +def test_marketdata_history_by_symbol_accepts_stock_query(client, requests_mock): + # Arrange + requests_mock.request.side_effect = _marketdata_request + query = StockQuery(symbol='AAPL', contract_conditions={'isUS': False, 'exchange': 'AEQLIT'}, name_match='APPLE') + + # Act + result = client.marketdata_history_by_symbol(query, bar='1min') + + # Assert + assert result.data == ibkr_responses.responses['history']['AAPL'] + + def test_check_health_authenticated_and_connected(client, default_url, requests_mock): # Arrange response_data = {'iserver': {'authStatus': {'authenticated': True, 'competing': False, 'connected': True}}} @@ -358,4 +370,4 @@ def test_marketdata_unsubscribe_raises_exception_on_failure(client, mocker): client.marketdata_unsubscribe(conids) # Assert - assert excinfo.value.status_code == 500 \ No newline at end of file + assert excinfo.value.status_code == 500 diff --git a/test/integration/client/test_ibkr_utils_i.py b/test/integration/client/test_ibkr_utils_i.py index e03569b7..748f4f41 100644 --- a/test/integration/client/test_ibkr_utils_i.py +++ b/test/integration/client/test_ibkr_utils_i.py @@ -1,4 +1,5 @@ from pprint import pformat +import threading from unittest.mock import MagicMock, call import pytest @@ -13,6 +14,7 @@ question_type_to_message_id, OrderRequest, parse_order_request, + Tickler, ) from test.integration.client import ibkr_responses from test.test_utils import CaptureLogsContext @@ -391,4 +393,32 @@ def test_raise_with_conid_and_conidex(): parse_order_request(order_request) - assert "Both 'conidex' and 'conid' are provided. When using 'conidex', specify `conid=None`." == str(cm_err.value) \ No newline at end of file + assert "Both 'conidex' and 'conid' are provided. When using 'conidex', specify `conid=None`." == str(cm_err.value) + + +def test_tickler_keeps_thread_reference_when_stop_times_out(): + ## Arrange + tickle_started = threading.Event() + release_tickle = threading.Event() + client = MagicMock() + + def slow_tickle(): + tickle_started.set() + release_tickle.wait(1) + + client.tickle.side_effect = slow_tickle + tickler = Tickler(client, interval=0) + + ## Act + tickler.start() + assert tickle_started.wait(1) + stopped = tickler.stop(timeout=0.01) + + ## Assert + assert stopped is False + assert tickler._thread is not None + assert tickler._thread.is_alive() + + ## Cleanup + release_tickle.set() + assert tickler.stop(timeout=1) is True diff --git a/test/integration/client/test_ibkr_ws_client_i.py b/test/integration/client/test_ibkr_ws_client_i.py index b3cf9c72..c3aaf399 100644 --- a/test/integration/client/test_ibkr_ws_client_i.py +++ b/test/integration/client/test_ibkr_ws_client_i.py @@ -1,4 +1,5 @@ import json +import ssl from threading import Thread from typing import Optional from unittest.mock import MagicMock, call @@ -94,7 +95,6 @@ def ws_app_factory(wsa_mock): def patched_constructors(mocker, thread_mock, ws_app_factory): mocker.patch('ibind.base.ws_client.WebSocketApp', side_effect=lambda *args, **kwargs: ws_app_factory['fn'](*args, **kwargs)) mocker.patch('ibind.base.ws_client.Thread', return_value=thread_mock) - return None @@ -133,6 +133,32 @@ def override_on_message(wsa_mock: MagicMock, message: str): return rv +def test_oauth_url_preserves_existing_query_and_forces_tls(client_mock): + # Arrange / Act + client = IbkrWsClient( + url='wss://localhost:5000/v1/api/ws?existing=1', + ibkr_client=client_mock, + use_oauth=True, + access_token='TOKEN VALUE', # noqa: S106 + cacert=False, + ) + + # Assert + assert client._url == 'wss://localhost:5000/v1/api/ws?existing=1&oauth_token=TOKEN+VALUE' + assert client._sslopt == {'cert_reqs': ssl.CERT_REQUIRED} + + +def test_auth_status_competing_false_not_logged_as_error(ws_client, mocker): + # Arrange + logger_error = mocker.patch('ibind.client.ibkr_ws_client._LOGGER.error') + + # Act + ws_client._handle_authentication_status({}, {'competing': False}) + + # Assert + logger_error.assert_not_called() + + def _logs_subscriptions(full_channel, data=None, needs_confirmation_sub: bool = False, needs_confirmation_unsub: bool = True): return [ @@ -246,7 +272,7 @@ def test_on_message_sts_authenticated(ws_client, patched_constructors): _send_payload(ws_client, {'topic': 'sts', 'args': {'authenticated': True}}) -@capture_logs(logger_level='DEBUG', expected_errors = [f'IbkrWsClient: Error message:'], partial_match=True) +@capture_logs(logger_level='DEBUG', expected_errors = ['IbkrWsClient: Error message:'], partial_match=True) def test_on_message_error(ws_client, patched_constructors): """Logs error-topic messages as warnings.""" ## Act @@ -536,4 +562,4 @@ def override_init_wsa_mock(wsa_mock: MagicMock, *args, **kwargs): channel_subscribed_log, f'IbkrWsClient: Invalidated subscription: {full_channel}', ] - ) \ No newline at end of file + ) diff --git a/test/test_utils.py b/test/test_utils.py index e33822c2..ca2a269d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -473,4 +473,4 @@ def mock_module_time(target_module, time_sequence=None, start_time=0.0): # time.time() in mymodule will return 1.0, then 2.0, then 3.0 pass """ - return MockTimeController(target_module, time_sequence=time_sequence, start_time=start_time) \ No newline at end of file + return MockTimeController(target_module, time_sequence=time_sequence, start_time=start_time) diff --git a/test/unit/client/test_ibkr_client_u.py b/test/unit/client/test_ibkr_client_u.py index 79a15321..618f508c 100644 --- a/test/unit/client/test_ibkr_client_u.py +++ b/test/unit/client/test_ibkr_client_u.py @@ -1,3 +1,7 @@ +import os +import subprocess +import sys + import pytest from unittest.mock import MagicMock from ibind.client.ibkr_client import IbkrClient @@ -61,4 +65,36 @@ def test_handle_auth_status_not_healthy_oauth_success(client, caplog): ## Assert assert any("IBKR connection is not healthy. Attempting to re-establish OAuth authentication." in r.message for r in caplog.records) client.stop_tickler.assert_called_once_with(15) - client.oauth_init.assert_called_once_with(maintain_oauth=True, init_brokerage_session=True) \ No newline at end of file + client.oauth_init.assert_called_once_with(maintain_oauth=True, init_brokerage_session=True) + + +def test_close_oauth_is_idempotent(client): + ## Arrange + client._use_oauth = True + client.oauth_config.shutdown_oauth = True + client.oauth_shutdown = MagicMock() + + ## Act + client.close() + client.close() + + ## Assert + client.oauth_shutdown.assert_called_once() + + +def test_oauth_config_init_brokerage_session_uses_dedicated_env_var(): + ## Arrange + env = os.environ.copy() + env['IBIND_INIT_OAUTH'] = 'false' + env['IBIND_INIT_BROKERAGE_SESSION'] = 'true' + code = ( + 'from ibind.oauth import OAuthConfig\n' + 'print(OAuthConfig.__dataclass_fields__["init_oauth"].default)\n' + 'print(OAuthConfig.__dataclass_fields__["init_brokerage_session"].default)\n' + ) + + ## Act + completed = subprocess.run([sys.executable, '-c', code], env=env, capture_output=True, text=True, check=True) # noqa: S603 + + ## Assert + assert completed.stdout.splitlines() == ['False', 'True'] diff --git a/test/unit/client/test_oauth1a_u.py b/test/unit/client/test_oauth1a_u.py new file mode 100644 index 00000000..7f812093 --- /dev/null +++ b/test/unit/client/test_oauth1a_u.py @@ -0,0 +1,59 @@ +from urllib.parse import unquote_plus + +import pytest + +pytest.importorskip('Crypto') + +from ibind.client.ibkr_client import IbkrClient +from ibind.oauth.oauth1a import OAuth1aConfig, generate_oauth_headers +from ibind.base.rest_client import Result + + +def test_generate_oauth_headers_omits_static_host_and_includes_request_params(monkeypatch): + ## Arrange + captured = {} + config = OAuth1aConfig(consumer_key='consumer', access_token='access', realm='realm') # noqa: S106 + + monkeypatch.setattr('ibind.oauth.oauth1a.generate_oauth_nonce', lambda: 'nonce') + monkeypatch.setattr('ibind.oauth.oauth1a.generate_request_timestamp', lambda: '123') + + def fake_signature(base_string, live_session_token): + captured['base_string'] = base_string + captured['live_session_token'] = live_session_token + return 'signature' + + monkeypatch.setattr('ibind.oauth.oauth1a.generate_hmac_sha_256_signature', fake_signature) + + ## Act + headers = generate_oauth_headers( + oauth_config=config, + request_method='GET', + request_url='https://1.api.ibkr.com/v1/api/iserver/accounts', + live_session_token='live-token', # noqa: S106 + request_params={'accountId': 'DU123'}, + ) + + ## Assert + assert 'Host' not in headers + assert captured['live_session_token'] == 'live-token' # noqa: S105 + assert 'accountId=DU123' in unquote_plus(captured['base_string']) + + +def test_oauth_get_request_passes_query_params_to_signature(mocker): + ## Arrange + client = IbkrClient(url='https://1.api.ibkr.com/v1/api/', use_oauth=False, use_session=False, auto_register_shutdown=False) + client._use_oauth = True + client.oauth_config = OAuth1aConfig(consumer_key='consumer', access_token='access', realm='realm') # noqa: S106 + client.live_session_token = 'live-token' # noqa: S105 + client._process_response = mocker.MagicMock(return_value=Result(data={'ok': True})) + + generate_headers = mocker.patch('ibind.oauth.oauth1a.generate_oauth_headers', return_value={'Authorization': 'OAuth test'}) + request = mocker.patch('ibind.base.rest_client.requests.request') + request.return_value = mocker.MagicMock() + + ## Act + client.get('iserver/accounts', params={'accountId': 'DU123'}) + + ## Assert + generate_headers.assert_called_once() + assert generate_headers.call_args.kwargs['request_params'] == {'accountId': 'DU123'} diff --git a/test/unit/support/test_py_utils_u.py b/test/unit/support/test_py_utils_u.py index 5fb4a250..685acfda 100644 --- a/test/unit/support/test_py_utils_u.py +++ b/test/unit/support/test_py_utils_u.py @@ -1,9 +1,10 @@ +import ssl import time from unittest.mock import MagicMock import pytest -from ibind.support.py_utils import ensure_list_arg, execute_in_parallel, execute_with_key, wait_until +from ibind.support.py_utils import ensure_list_arg, execute_in_parallel, execute_with_key, wait_until, append_query_params, make_websocket_sslopt @ensure_list_arg('arg') @@ -62,7 +63,7 @@ def test_ensure_list_arg_with_keyword_arg_non_list(): def test_ensure_list_arg_with_missing_arg(): """Raises TypeError when the decorated arg is missing.""" # Arrange - + # Act / Assert with pytest.raises(TypeError): sample_function() @@ -225,4 +226,15 @@ def test_wait_until_timeout(): # Assert assert result is False duration = time.time() - start_time - assert duration == pytest.approx(timeout, abs=0.02) \ No newline at end of file + assert duration == pytest.approx(timeout, abs=0.02) + + +def test_append_query_params_preserves_existing_query(): + url = append_query_params('wss://example.test/ws?existing=1', {'oauth_token': 'abc 123'}) + assert url == 'wss://example.test/ws?existing=1&oauth_token=abc+123' + + +def test_make_websocket_sslopt_supports_bool_modes(): + assert make_websocket_sslopt(False) == {'cert_reqs': ssl.CERT_NONE} + assert make_websocket_sslopt(None) == {'cert_reqs': ssl.CERT_NONE} + assert make_websocket_sslopt(True) == {'cert_reqs': ssl.CERT_REQUIRED} diff --git a/test/unit/test_public_imports_u.py b/test/unit/test_public_imports_u.py new file mode 100644 index 00000000..e848167e --- /dev/null +++ b/test/unit/test_public_imports_u.py @@ -0,0 +1,13 @@ +from setuptools import find_packages + + +def test_public_v2_imports_and_package_discovery(): + import ibind + + packages = find_packages() + + assert 'ibind.ws_v2' in packages + assert 'ibind.ibkr_ws_v2' in packages + assert ibind.IbkrWsClientV2 is not None + assert ibind.events.WsOpen is not None + assert ibind.subscriptions.MarketDataSubscription is not None diff --git a/test/unit/ws_v2/test_ws_events_u.py b/test/unit/ws_v2/test_ws_events_u.py new file mode 100644 index 00000000..2366189f --- /dev/null +++ b/test/unit/ws_v2/test_ws_events_u.py @@ -0,0 +1,496 @@ +import threading +from datetime import datetime +from queue import Empty, Full +from unittest.mock import MagicMock, patch + +import pytest + +from ibind.events import ( + WsOpen, + WsAuthenticated, + WsDegraded, + WsReady, + WsClose, + WsError, +) +from ibind import ( + NoopSink, + CallbackSink, + QueueSink, + CompositeSink, + EventSink, +) +from test.test_utils import capture_logs +from ibind.ws_v2._ws_events import AsyncSink + + +@pytest.fixture +def noop_sink(): + return NoopSink() + + +@pytest.fixture +def callback_sink(): + return CallbackSink() + + +@pytest.fixture +def queue_sink(): + sink = QueueSink() + sink._queues.clear() + return sink + + +@pytest.fixture +def sample_event(): + return WsOpen() + + +class TestWsEvent: + @capture_logs() + def test_immutability(self): + """WsEvent instances are immutable after creation.""" + ## Arrange + event = WsOpen() + + ## Act / Assert + with pytest.raises(Exception): + event.received_at = datetime.now() # NOQA + + @capture_logs() + def test_extra_fields_forbidden(self): + """WsEvent rejects extra fields not in the model.""" + ## Arrange / Act / Assert + with pytest.raises(Exception): + WsOpen(extra_field='value') # NOQA + + +@capture_logs() +def test_lifecycle_events(): + """Lifecycle events can be created with default received_at and optional fields.""" + ## Arrange / Act + ws_open = WsOpen() + ws_authenticated = WsAuthenticated() + ws_degraded = WsDegraded() + ws_ready = WsReady() + ws_close_with_fields = WsClose(close_status_code=1000, close_msg='normal closure') + ws_close_with_none = WsClose(close_status_code=None, close_msg=None) + error = RuntimeError('connection failed') + ws_error = WsError(error=error) + + ## Assert + assert isinstance(ws_open.received_at, datetime) + assert isinstance(ws_authenticated.received_at, datetime) + assert isinstance(ws_degraded.received_at, datetime) + assert isinstance(ws_ready.received_at, datetime) + assert ws_close_with_fields.close_status_code == 1000 + assert ws_close_with_fields.close_msg == 'normal closure' + assert ws_close_with_none.close_status_code is None + assert ws_close_with_none.close_msg is None + assert ws_error.error is error + + +class TestCallbackSink: + @capture_logs() + def test_on_registers_callback(self, callback_sink): + """CallbackSink.on registers a callback for an event type.""" + ## Arrange + callback = MagicMock() + + ## Act + callback_sink.on(WsOpen, callback) + + ## Assert + assert WsOpen in callback_sink._callbacks + assert callback in callback_sink._callbacks[WsOpen] + + @capture_logs() + def test_emit_calls_registered_callback(self, callback_sink, sample_event): + """CallbackSink.emit invokes callbacks registered for the event type.""" + ## Arrange + callback = MagicMock() + callback_sink.on(WsOpen, callback) + + ## Act + callback_sink.emit(sample_event) + + ## Assert + callback.assert_called_once_with(sample_event) + + @capture_logs() + def test_emit_ignores_unregistered_event_types(self, callback_sink): + """CallbackSink.emit does not call callbacks for unregistered event types.""" + ## Arrange + callback = MagicMock() + callback_sink.on(WsOpen, callback) + event = WsClose(close_status_code=1000, close_msg='') + + ## Act + callback_sink.emit(event) + + ## Assert + callback.assert_not_called() + + @capture_logs() + def test_emit_multiple_callbacks(self, callback_sink, sample_event): + """CallbackSink.emit calls all callbacks registered for an event type.""" + ## Arrange + callback1 = MagicMock() + callback2 = MagicMock() + callback_sink.on(WsOpen, callback1) + callback_sink.on(WsOpen, callback2) + + ## Act + callback_sink.emit(sample_event) + + ## Assert + callback1.assert_called_once_with(sample_event) + callback2.assert_called_once_with(sample_event) + + @capture_logs(logger_level='ERROR', expected_errors=['Exception emitting event to callback test_fn'], partial_match=True) + def test_emit_logs_callback_exception(self, callback_sink, sample_event): + """CallbackSink.emit logs exceptions raised by callbacks without propagating.""" + + ## Arrange + def test_fn(event): + raise ValueError('callback error') + + callback_sink.on(WsOpen, test_fn) + + ## Act + callback_sink.emit(sample_event) + + +class TestQueueSink: + @capture_logs() + def test_new_queue_accessor_creates_accessor(self, queue_sink): + """QueueSink.new_queue_accessor returns a QueueAccessor for the event type.""" + ## Act + accessor = queue_sink.new_queue_accessor(WsOpen) + + ## Assert + assert accessor.key == WsOpen + + @capture_logs() + def test_emit_puts_event_in_queue(self, queue_sink, sample_event): + """QueueSink.emit adds the event to the queue for its type.""" + ## Act + queue_sink.emit(sample_event) + + ## Assert + retrieved = queue_sink.get(WsOpen, block=False) + assert retrieved is sample_event + + @capture_logs() + def test_get_returns_none_when_empty(self, queue_sink): + """QueueSink.get returns None when the queue is empty and block=False.""" + ## Act + result = queue_sink.get(WsOpen, block=False) + + ## Assert + assert result is None + + @capture_logs() + def test_empty_returns_true_when_empty(self, queue_sink): + """QueueSink.empty returns True when no events are queued.""" + ## Act + result = queue_sink.empty(WsOpen) + + ## Assert + assert result is True + + @capture_logs() + def test_empty_returns_false_when_not_empty(self, queue_sink): + """QueueSink.empty returns False when events are queued.""" + ## Arrange + queue_sink.emit(WsOpen()) + + ## Act + result = queue_sink.empty(WsOpen) + + ## Assert + assert result is False + + @capture_logs() + def test_separate_queues_per_event_type(self, queue_sink): + """QueueSink maintains separate queues for different event types.""" + ## Arrange + event1 = WsOpen() + event2 = WsClose(close_status_code=1000, close_msg='') + + ## Act + queue_sink.emit(event1) + queue_sink.emit(event2) + + ## Assert + retrieved1 = queue_sink.get(WsOpen, block=False) + retrieved2 = queue_sink.get(WsClose, block=False) + assert isinstance(retrieved1, WsOpen) + assert isinstance(retrieved2, WsClose) + assert queue_sink.get(WsOpen, block=False) is None + + +class TestCompositeSink: + @capture_logs() + def test_emit_calls_all_sinks(self, sample_event): + """CompositeSink.emit forwards the event to all registered sinks.""" + ## Arrange + sink1 = MagicMock() + sink2 = MagicMock() + composite = CompositeSink(sink1, sink2) + + ## Act + composite.emit(sample_event) + + ## Assert + sink1.emit.assert_called_once_with(sample_event) + sink2.emit.assert_called_once_with(sample_event) + + @capture_logs(logger_level='WARNING', expected_errors=['Exception emitting event to sink'], partial_match=True) + def test_emit_logs_sink_exception(self, sample_event): + """CompositeSink.emit logs exceptions from sinks without propagating.""" + ## Arrange + sink1 = MagicMock() + sink1.emit.side_effect = ValueError('sink error') + sink2 = MagicMock() + composite = CompositeSink(sink1, sink2) + + ## Act + composite.emit(sample_event) + + ## Assert + sink2.emit.assert_called_once_with(sample_event) + + +class TestAsyncSink: + @capture_logs() + def test_start_launches_thread(self, noop_sink): + """AsyncSink.start launches a background thread.""" + ## Arrange + sink = AsyncSink(noop_sink) + + ## Act + sink.start() + + ## Assert + assert sink._running is True + assert sink._thread is not None + assert sink._thread.is_alive() + + ## Cleanup + sink.stop() + + @capture_logs() + def test_start_idempotent(self, noop_sink): + """AsyncSink.start does not launch multiple threads if already running.""" + ## Arrange + sink = AsyncSink(noop_sink) + sink.start() + first_thread = sink._thread + + ## Act + sink.start() + + ## Assert + assert sink._thread is first_thread + + ## Cleanup + sink.stop() + + @capture_logs() + def test_stop_terminates_thread(self, noop_sink): + """AsyncSink.stop terminates the background thread.""" + ## Arrange + sink = AsyncSink(noop_sink) + sink.start() + + ## Act + result = sink.stop() + + ## Assert + assert result is True + assert sink._running is False + assert sink._thread is None + + @capture_logs() + def test_stop_idempotent(self, noop_sink): + """AsyncSink.stop returns True when already stopped.""" + ## Arrange + sink = AsyncSink(noop_sink) + + ## Act + result = sink.stop() + + ## Assert + assert result is True + + @capture_logs() + def test_stop_from_same_thread_raises(self, noop_sink): + """AsyncSink.stop raises RuntimeError when called from the sink thread.""" + ## Arrange + sink = AsyncSink(noop_sink) + exception_holder = {'exception': None} + ev = threading.Event() + + def stop_from_thread(): + try: + sink.stop() + except RuntimeError as e: + exception_holder['exception'] = e + ev.set() + + sink._cycle = stop_from_thread + ev.clear() + sink.start() + ev.wait(10) + + ## Assert + assert exception_holder['exception'] is not None + assert 'Stopping async sink called from within async sink thread' in str(exception_holder['exception']) + + ## Cleanup + sink._running = False + + @capture_logs() + def test_emit_queues_event(self, sample_event): + """AsyncSink.emit adds events to the internal queue.""" + ## Arrange + inner_sink = MagicMock() + sink = AsyncSink(inner_sink) + sink.start() + + ## Act + sink.emit(sample_event) + sink._consume_queue() + + ## Assert + inner_sink.emit.assert_called_with(sample_event) + + ## Cleanup + sink.stop() + + @capture_logs(logger_level='WARNING', expected_errors=['dropping newest event'], partial_match=True) + def test_emit_drops_newest_when_full(self): + """AsyncSink.emit drops the newest event when queue is full and drop_oldest=False.""" + ## Arrange + inner_sink = MagicMock() + sink = AsyncSink(inner_sink, maxsize=1, drop_oldest=False) + event1 = WsOpen() + event2 = WsAuthenticated() + + ## Act + + sink.emit(event1) + sink.emit(event2) + + @capture_logs(logger_level='WARNING', expected_errors=['dropping oldest event'], partial_match=True) + def test_emit_drops_oldest_when_full(self): + """AsyncSink.emit drops the oldest event when queue is full and drop_oldest=True.""" + ## Arrange + inner_sink = MagicMock() + sink = AsyncSink(inner_sink, maxsize=1, drop_oldest=True) + event1 = WsOpen() + event2 = WsAuthenticated() + + ## Act + sink.emit(event1) + sink.emit(event2) + + @capture_logs() + def test_consume_queue_forwards_events(self): + """AsyncSink forwards queued events to the inner sink.""" + ## Arrange + inner_sink = MagicMock() + sink = AsyncSink(inner_sink) + sink.start() + event1 = WsOpen() + event2 = WsAuthenticated() + + ## Act + sink.emit(event1) + sink.emit(event2) + sink._consume_queue() + + ## Assert + assert inner_sink.emit.call_count == 2 + inner_sink.emit.assert_any_call(event1) + inner_sink.emit.assert_any_call(event2) + + ## Cleanup + sink.stop() + + @capture_logs(logger_level='ERROR', expected_errors=['sink error'], partial_match=True) + def test_consume_queue_logs_exception(self): + """AsyncSink logs exceptions from the inner sink without stopping.""" + ## Arrange + inner_sink = MagicMock(spec=EventSink) + inner_sink.emit.side_effect = ValueError('sink error') + sink = AsyncSink(inner_sink) + event = WsOpen() + + sink.emit(event) + sink._consume_queue() + + @capture_logs() + def test_cycle_consumes_remaining_events_on_stop(self): + """AsyncSink processes remaining events in queue when stopping.""" + ## Arrange + inner_sink = MagicMock() + sink = AsyncSink(inner_sink, cycle_interval=0.5) + sink.start() + event1 = WsOpen() + event2 = WsAuthenticated() + + ## Act + sink.emit(event1) + sink.emit(event2) + sink.stop() + + ## Assert + assert inner_sink.emit.call_count >= 2 + + @capture_logs(logger_level='WARNING', expected_errors=['Event queue not empty when stopping'], partial_match=True) + def test_stop_warns_when_queue_not_empty(self): + """AsyncSink logs warning when stopping with events still in queue.""" + ## Arrange + inner_sink = MagicMock() + sink = AsyncSink(inner_sink, maxsize=10) + event = WsOpen() + + ## Act + sink._running = True + sink._queue.put(event) + sink.stop() + + def test_emit_handles_empty_exception_when_dropping_oldest(self): + """AsyncSink handles Empty exception when queue becomes empty between full check and get.""" + ## Arrange + inner_sink = MagicMock() + sink = AsyncSink(inner_sink, maxsize=1, drop_oldest=True) + event1 = WsOpen() + + ## Act + with patch.object(sink._queue, 'put_nowait', side_effect=[Full, None]) as mock_put: + with patch.object(sink._queue, 'get_nowait', side_effect=Empty): + sink.emit(event1) + + ## Assert + assert mock_put.call_count == 2 + + @capture_logs( + logger_level='WARNING', + expected_errors=['Event queue full; dropping oldest event', 'Event queue still full; dropping event'], + partial_match=True, + ) + def test_emit_warns_when_queue_still_full_after_drop(self): + """AsyncSink logs warning when queue is still full after dropping oldest event.""" + ## Arrange + inner_sink = MagicMock() + sink = AsyncSink(inner_sink, maxsize=1, drop_oldest=True) + event1 = WsOpen() + event2 = WsAuthenticated() + + ## Act + with patch.object(sink._queue, 'put_nowait', side_effect=[Full, Full]): + with patch.object(sink._queue, 'get_nowait', return_value=event1): + sink.emit(event2) diff --git a/test/unit/ws_v2/test_ws_subscriptions_u.py b/test/unit/ws_v2/test_ws_subscriptions_u.py new file mode 100644 index 00000000..f5391c1b --- /dev/null +++ b/test/unit/ws_v2/test_ws_subscriptions_u.py @@ -0,0 +1,821 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from ibind.events import WsOpen, WsEvent +from ibind.ws_v2.ws_subscriptions import ( + Subscription, + BindingStatus, + Binding, + SubscriptionHandle, + SubscriptionController, +) +from test.test_utils import capture_logs, mock_module_time + + +class MockEvent(WsEvent): + """Mock event for testing subscription confirmation.""" + + binding_key: str + is_active: bool + + +class MockResolver: + """Mock resolver that extracts binding_key and is_active from MockEvent.""" + + def resolve_binding_key(self, event: WsEvent): + if isinstance(event, MockEvent): + return (event.is_active, event.binding_key) + return (False, None) + + +class MockSubscription(Subscription): + topic_value: str = 'test_topic' + payload_value: str = 'sub_payload' + + @property + def topic(self) -> str: + return self.topic_value + + def subscribe_payload(self) -> str: + return self.payload_value + + def unsubscribe_payload(self) -> str: + return f'unsub_{self.payload_value}' + + +class MockSubscriptionNoConfirm(Subscription): + topic_value: str = 'no_confirm' + payload_value: str = 'no_confirm_payload' + + @property + def topic(self) -> str: + return self.topic_value + + def subscribe_payload(self) -> str: + return self.payload_value + + def unsubscribe_payload(self) -> str: + return f'unsub_{self.payload_value}' + + @property + def confirms_subscribe(self) -> bool: + return False + + @property + def confirms_unsubscribe(self) -> bool: + return True + + +@pytest.fixture +def mock_sub(): + return MockSubscription() + + +@pytest.fixture() +def binding_key(mock_sub): + return mock_sub.binding_key() + + +@pytest.fixture +def test_subscription_with_expiry(): + return MockSubscription(expiry_seconds=10) + + +@pytest.fixture +def test_subscription_no_confirm(): + return MockSubscriptionNoConfirm() + + +@pytest.fixture +def mock_send_payload(): + return MagicMock(return_value=True) + + +@pytest.fixture +def sc(mock_send_payload): + return SubscriptionController( + send_payload=mock_send_payload, + subscription_resolver=MockResolver(), + subscription_retries=3, + subscription_timeout=1.0, + ) + + +class TestBinding: + @capture_logs() + def test_binding_done_when_status_matches_intent(self, mock_sub): + """Binding.done returns True when status matches intent.""" + ## Arrange + binding = Binding(subscription=mock_sub, intent=BindingStatus.ACTIVE) + binding.status = BindingStatus.ACTIVE + + ## Act + result = binding.done + + ## Assert + assert result is True + + @capture_logs() + def test_binding_not_done_when_status_differs(self, mock_sub): + """Binding.done returns False when status does not match intent.""" + ## Arrange + binding = Binding(subscription=mock_sub, intent=BindingStatus.ACTIVE) + binding.status = BindingStatus.PENDING + + ## Act + result = binding.done + + ## Assert + assert result is False + + +class TestSubscriptionHandle: + @capture_logs() + def test_status(self, sc, mock_sub): + """SubscriptionHandle.status returns the current binding status.""" + ## Arrange + sc.subscribe(mock_sub) + handle = SubscriptionHandle(sc, mock_sub) + + ## Assert + assert handle.status == BindingStatus.NEW + + @capture_logs() + def test_active_when_status_active(self, sc, mock_sub, binding_key): + """SubscriptionHandle.active returns True when status is ACTIVE.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + handle = SubscriptionHandle(sc, mock_sub) + + ## Assert + assert handle.active is True + + @capture_logs() + def test_active_when_status_not_active(self, sc, mock_sub): + """SubscriptionHandle.active returns False when status is not ACTIVE.""" + ## Arrange + sc.subscribe(mock_sub) + handle = SubscriptionHandle(sc, mock_sub) + + ## Assert + assert handle.active is False + + @capture_logs() + def test_unsubscribed_when_status_unsubscribed(self, sc, mock_sub, binding_key): + """SubscriptionHandle.unsubscribed returns True when status is UNSUBSCRIBED.""" + ## Arrange + sc.unsubscribe(mock_sub) + with sc._condition: + sc._confirm_unsubscribed(binding_key) + handle = SubscriptionHandle(sc, mock_sub) + + ## Assert + assert handle.unsubscribed is True + + @capture_logs() + def test_done_delegates_to_controller(self, sc, mock_sub, binding_key): + """SubscriptionHandle.done delegates to controller.is_done.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + handle = SubscriptionHandle(sc, mock_sub) + + ## Assert + assert handle.done is True + + @capture_logs() + def test_wait_delegates_to_controller(self, sc, mock_sub, binding_key): + """SubscriptionHandle.wait delegates to controller.wait_for.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + handle = SubscriptionHandle(sc, mock_sub) + + ## Act + result = handle.wait(timeout=1.0) + + ## Assert + assert result is True + + @capture_logs() + def test_unsubscribe_delegates_to_controller(self, sc, mock_sub, binding_key): + """SubscriptionHandle.unsubscribe delegates to controller.unsubscribe.""" + ## Arrange + sc.subscribe(mock_sub) + handle = SubscriptionHandle(sc, mock_sub) + + ## Act + result = handle.unsubscribe() + + ## Assert + assert result is handle + assert sc.get_status(binding_key) == BindingStatus.NEW + + +class TestInterface: + @capture_logs() + def test_subscribe_creates_new_binding(self, sc, mock_sub, binding_key): + """SubscriptionController.subscribe creates a new binding with ACTIVE intent.""" + ## Act + handle = sc.subscribe(mock_sub) + + ## Assert + assert isinstance(handle, SubscriptionHandle) + assert sc.has_subscription(binding_key) + binding = sc._bindings[binding_key] + assert binding.intent == BindingStatus.ACTIVE + assert binding.status == BindingStatus.NEW + + @capture_logs() + def test_subscribe_updates_existing_binding_intent(self, sc, mock_sub, binding_key): + """SubscriptionController.subscribe updates intent on existing binding.""" + ## Arrange + sc.unsubscribe(mock_sub) + binding = sc._bindings[binding_key] + assert binding.intent == BindingStatus.UNSUBSCRIBED + + ## Act + sc.subscribe(mock_sub) + + ## Assert + assert binding.intent == BindingStatus.ACTIVE + + @capture_logs() + def test_subscribe_resets_unsubscribed_binding(self, sc, mock_sub, binding_key): + """SubscriptionController.subscribe resets binding when previously UNSUBSCRIBED.""" + ## Arrange + sc.unsubscribe(mock_sub) + with sc._condition: + sc._confirm_unsubscribed(binding_key) + binding = sc._bindings[binding_key] + binding.attempts = 5 + binding.last_attempt = 100.0 + + ## Act + sc.subscribe(mock_sub) + + ## Assert + binding = sc._bindings[binding_key] + assert binding.attempts == 0 + assert binding.last_attempt == 0 + + @capture_logs() + def test_unsubscribe_creates_new_binding(self, sc, mock_sub, binding_key): + """SubscriptionController.unsubscribe creates a new binding with UNSUBSCRIBED intent.""" + ## Act + handle = sc.unsubscribe(mock_sub) + + ## Assert + assert isinstance(handle, SubscriptionHandle) + assert sc.has_subscription(binding_key) + binding = sc._bindings[binding_key] + assert binding.intent == BindingStatus.UNSUBSCRIBED + assert binding.status == BindingStatus.NEW + + @capture_logs() + def test_unsubscribe_updates_existing_binding_intent(self, sc, mock_sub, binding_key): + """SubscriptionController.unsubscribe updates intent on existing binding.""" + ## Arrange + sc.subscribe(mock_sub) + + ## Act + sc.unsubscribe(mock_sub) + + ## Assert + binding = sc._bindings[binding_key] + assert binding.intent == BindingStatus.UNSUBSCRIBED + + @capture_logs() + def test_unsubscribe_resets_active_binding(self, sc, mock_sub, binding_key): + """SubscriptionController.unsubscribe resets binding when previously ACTIVE.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + binding = sc._bindings[binding_key] + binding.attempts = 5 + binding.last_attempt = 100.0 + + ## Act + sc.unsubscribe(mock_sub) + + ## Assert + binding = sc._bindings[binding_key] + assert binding.attempts == 0 + assert binding.last_attempt == 0 + + @capture_logs() + def test_has_active_subscriptions(self, sc, mock_sub, binding_key): + """SubscriptionController.has_active_subscriptions returns True when any subscription is active.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + + ## Act + result = sc.has_active_subscriptions() + + ## Assert + assert result is True + + @capture_logs() + def test_has_active_subscriptions_returns_false_when_none_active(self, sc, mock_sub): + """SubscriptionController.has_active_subscriptions returns False when no subscriptions are active.""" + ## Arrange + sc.subscribe(mock_sub) + + ## Act + result = sc.has_active_subscriptions() + + ## Assert + assert result is False + + @capture_logs() + def test_get_active_subscriptions(self, sc, mock_sub, binding_key): + """SubscriptionController.get_active_subscriptions returns dict of active bindings.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + + ## Act + result = sc.get_active_subscriptions() + + ## Assert + assert binding_key in result + assert result[binding_key].status == BindingStatus.ACTIVE + + @capture_logs() + def test_invalidate_subscriptions(self, sc, mock_sub, binding_key): + """SubscriptionController.invalidate_subscriptions marks all bindings as DEGRADED.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + + ## Act + sc.invalidate_subscriptions() + + ## Assert + assert sc.get_status(binding_key) == BindingStatus.DEGRADED + + +class TestObserve: + @capture_logs() + def test_observe_confirms_subscribed(self, sc, mock_sub, binding_key): + """SubscriptionController.observe confirms subscription when resolver returns active.""" + ## Arrange + sc.subscribe(mock_sub) + event = MockEvent(binding_key=binding_key, is_active=True) + + ## Act + sc.observe(event) + + ## Assert + assert sc.get_status(binding_key) == BindingStatus.ACTIVE + + @capture_logs() + def test_observe_confirms_unsubscribed(self, sc, mock_sub, binding_key): + """SubscriptionController.observe confirms unsubscription when resolver returns inactive.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + sc.unsubscribe(mock_sub) + event = MockEvent(binding_key=binding_key, is_active=False) + + ## Act + sc.observe(event) + + ## Assert + assert sc.get_status(binding_key) == BindingStatus.UNSUBSCRIBED + + @capture_logs() + def test_observe_ignores_unrelated_events(self, sc, mock_sub, binding_key): + """SubscriptionController.observe ignores events with no binding key.""" + ## Arrange + sc.subscribe(mock_sub) + event = WsOpen() + + ## Act + sc.observe(event) + + ## Assert + assert sc.get_status(binding_key) == BindingStatus.NEW, 'observe should not have updated the binding' + + @capture_logs(logger_level='WARNING', expected_errors=['Observed a binding_key'], partial_match=True) + def test_observe_warns_on_missing_subscription(self, sc): + """SubscriptionController.observe logs warning when binding key has no subscription.""" + ## Arrange + event = MockEvent(binding_key='unknown_key', is_active=True) + + ## Act + sc.observe(event) + + +class TestReconcile: + @capture_logs() + def test_reconcile_binding_sends_subscribe_payload(self, sc, mock_sub, mock_send_payload, binding_key): + """SubscriptionController.reconcile_binding sends subscribe payload for ACTIVE intent.""" + ## Arrange + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + + ## Act + sc.reconcile_binding(binding) + + ## Assert + mock_send_payload.assert_called_once_with('sub_payload') + assert binding.attempts == 1 + assert binding.status == BindingStatus.NEW + + @capture_logs() + def test_reconcile_binding_sends_unsubscribe_payload(self, sc, test_subscription_no_confirm, mock_send_payload): + """SubscriptionController.reconcile_binding sends unsubscribe payload for UNSUBSCRIBED intent.""" + ## Arrange + sc.unsubscribe(test_subscription_no_confirm) + binding = sc._bindings[test_subscription_no_confirm.binding_key()] + + ## Act + with sc._condition: + sc.reconcile_binding(binding) + + ## Assert + mock_send_payload.assert_called_once_with('unsub_no_confirm_payload') + assert binding.attempts == 1 + + @capture_logs() + def test_reconcile_binding_auto_confirms_when_no_confirm_subscribe(self, sc, test_subscription_no_confirm, mock_send_payload): + """SubscriptionController.reconcile_binding auto-confirms when confirms_subscribe is False.""" + ## Arrange + sc.subscribe(test_subscription_no_confirm) + binding = sc._bindings[test_subscription_no_confirm.binding_key()] + + ## Act + with sc._condition: + sc.reconcile_binding(binding) + + ## Assert + assert binding.status == BindingStatus.ACTIVE + + @capture_logs() + def test_reconcile_binding_no_auto_confirm_when_confirms_unsubscribe_true(self, sc, test_subscription_no_confirm, mock_send_payload): + """SubscriptionController.reconcile_binding does not auto-confirm when confirms_unsubscribe is True.""" + ## Arrange + sc.unsubscribe(test_subscription_no_confirm) + binding = sc._bindings[test_subscription_no_confirm.binding_key()] + + ## Act + sc.reconcile_binding(binding) + + ## Assert + assert binding.status == BindingStatus.NEW + assert binding.attempts == 1 + + @capture_logs() + def test_reconcile_binding_auto_confirms_when_no_confirm_unsubscribe(self, sc, mock_sub, mock_send_payload, binding_key): + """SubscriptionController.reconcile_binding auto-confirms when confirms_unsubscribe is False.""" + ## Arrange + sc.unsubscribe(mock_sub) + binding = sc._bindings[binding_key] + + ## Act + with sc._condition: + sc.reconcile_binding(binding) + + ## Assert + assert binding.status == BindingStatus.UNSUBSCRIBED + + @capture_logs() + def test_reconcile_binding_respects_timeout(self, sc, mock_sub, binding_key): + """SubscriptionController.reconcile_binding waits for timeout before retrying.""" + ## Arrange + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + + ## Act + with mock_module_time('ibind.ws_v2.ws_subscriptions', time_sequence=[1000.0, 1000.5]): + sc.reconcile_binding(binding) + first_attempt = binding.attempts + sc.reconcile_binding(binding) + + ## Assert + assert binding.attempts == first_attempt + + @capture_logs() + def test_reconcile_binding_retries_after_timeout(self, sc, mock_sub, mock_send_payload, binding_key): + """SubscriptionController.reconcile_binding retries after timeout expires.""" + ## Arrange + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + + ## Act + with mock_module_time('ibind.ws_v2.ws_subscriptions', time_sequence=[1000.0, 1000.0, 1002.0]): + sc.reconcile_binding(binding) + sc.reconcile_binding(binding) # this call should return without making an attempt + sc.reconcile_binding(binding) + + ## Assert + assert binding.attempts == 2, 'should only attempt twice' + assert mock_send_payload.call_count == 2 + + @capture_logs() + def test_reconcile_binding_marks_failed_after_max_retries(self, sc, mock_sub, binding_key): + """SubscriptionController.reconcile_binding marks binding as FAILED after max retries.""" + ## Arrange + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + + ## Act + with mock_module_time('ibind.ws_v2.ws_subscriptions', time_sequence=[1000.0, 1001.1, 1002.2, 1003.3, 1004.4]): + for i in range(4): + with sc._condition: + sc.reconcile_binding(binding) + + ## Assert + assert binding.status == BindingStatus.FAILED + + @capture_logs() + def test_reconcile_binding_marks_expired_when_no_activity(self, sc, test_subscription_with_expiry): + """SubscriptionController.reconcile_binding marks binding as EXPIRED when expiry time passes.""" + ## Arrange + with mock_module_time('ibind.ws_v2.ws_subscriptions', time_sequence=[1000.0, 1011.0]): + sc.subscribe(test_subscription_with_expiry) + binding = sc._bindings[test_subscription_with_expiry.binding_key()] + with sc._condition: + sc.reconcile_binding(binding) + with sc._condition: + sc._confirm_subscribed(test_subscription_with_expiry.binding_key()) + + ## Act + with sc._condition: + sc.reconcile_binding(binding) + + ## Assert + assert binding.status == BindingStatus.EXPIRED + + @capture_logs() + def test_reconcile_binding_does_not_expire_without_expiry_seconds(self, sc, mock_sub, binding_key): + """SubscriptionController.reconcile_binding does not expire when expiry_seconds is None.""" + ## Arrange + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + sc.reconcile_binding(binding) + with sc._condition: + sc._confirm_subscribed(binding_key) + + ## Act + with mock_module_time('ibind.ws_v2.ws_subscriptions', time_sequence=[1000.0, 1002.0]): + sc.reconcile_binding(binding) + + ## Assert + assert binding.status == BindingStatus.ACTIVE + + @capture_logs() + def test_reconcile_binding_does_not_expire_before_expiry_time(self, sc, test_subscription_with_expiry): + """SubscriptionController.reconcile_binding does not expire before expiry_seconds elapses.""" + ## Arrange + with mock_module_time('ibind.ws_v2.ws_subscriptions', time_sequence=[1000.0, 1005.0]): + sc.subscribe(test_subscription_with_expiry) + binding = sc._bindings[test_subscription_with_expiry.binding_key()] + with sc._condition: + sc.reconcile_binding(binding) + with sc._condition: + sc._confirm_subscribed(test_subscription_with_expiry.binding_key()) + + ## Act + with sc._condition: + sc.reconcile_binding(binding) + + ## Assert + assert binding.status == BindingStatus.ACTIVE + + @capture_logs() + def test_reconcile_bindings_processes_all_bindings(self, sc): + """SubscriptionController.reconcile_bindings processes all registered bindings.""" + ## Arrange + sub1 = MockSubscription(topic_value='topic1', payload_value='payload1') + sub2 = MockSubscription(topic_value='topic2', payload_value='payload2') + sc.subscribe(sub1) + sc.subscribe(sub2) + + ## Act + sc.reconcile_bindings() + + ## Assert + binding1 = sc._bindings[sub1.binding_key()] + binding2 = sc._bindings[sub2.binding_key()] + assert binding1.attempts == 1 + assert binding2.attempts == 1 + + +class TestSend: + @capture_logs(logger_level='INFO', expected_errors=['Sending payload unsuccessful'], partial_match=True) + def test_send_logs_when_send_fails(self, sc, mock_sub, mock_send_payload, binding_key): + """SubscriptionController._send logs when send_payload returns False.""" + ## Arrange + mock_send_payload.return_value = False + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + + ## Act + sc.reconcile_binding(binding) + + @capture_logs(logger_level='ERROR', expected_errors=['Exception sending payload'], partial_match=True) + def test_send_logs_exception(self, sc, mock_sub, mock_send_payload, binding_key): + """SubscriptionController._send logs exceptions from send_payload.""" + ## Arrange + mock_send_payload.side_effect = RuntimeError('send error') + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + + ## Act + sc.reconcile_binding(binding) + + +class TestWaitFor: + @capture_logs() + def test_wait_for_returns_true_when_done(self, sc, mock_sub, binding_key): + """SubscriptionController.wait_for returns True when binding is done.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + + ## Act + result = sc.wait_for(binding_key, timeout=1.0) + + ## Assert + assert result is True + + @capture_logs() + def test_wait_for_returns_false_when_failed(self, sc, mock_sub, binding_key): + """SubscriptionController.wait_for returns False when binding is FAILED.""" + ## Arrange + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + binding.status = BindingStatus.FAILED + + ## Act + result = sc.wait_for(binding_key, timeout=1.0) + + ## Assert + assert result is False + + @capture_logs() + def test_wait_for_returns_false_when_missing(self, sc): + """SubscriptionController.wait_for returns False when binding does not exist.""" + ## Act + result = sc.wait_for('nonexistent', timeout=0.1) + + ## Assert + assert result is False + + @capture_logs() + def test_wait_for_returns_false_on_timeout(self, sc, mock_sub, binding_key): + """SubscriptionController.wait_for returns False when timeout expires.""" + ## Arrange + sc.subscribe(mock_sub) + + ## Act + result = sc.wait_for(binding_key, timeout=0.001) + + ## Assert + assert result is False + + @capture_logs() + def test_wait_for_waits_and_unblocks_on_notification(self, sc, mock_sub, binding_key): + """SubscriptionController.wait_for waits for notification and returns when status changes.""" + ## Arrange + sc.subscribe(mock_sub) + event = MockEvent(binding_key=binding_key, is_active=True) + + original_wait = sc._condition.wait + wait_call_count = 0 + + def mock_wait(timeout=None): + nonlocal wait_call_count + wait_call_count += 1 + if wait_call_count == 1: + sc.observe(event) + else: + original_wait(timeout) + + ## Act + with patch.object(sc._condition, 'wait', side_effect=mock_wait): + result = sc.wait_for(binding_key, timeout=5.0) + + ## Assert + assert result is True + assert wait_call_count == 1 + assert sc.get_status(binding_key) == BindingStatus.ACTIVE + + +class TestConfirmSubscribed: + @capture_logs() + def test_confirm_subscribed_updates_status(self, sc, mock_sub, binding_key): + """SubscriptionController._confirm_subscribed updates status to ACTIVE.""" + ## Arrange + sc.subscribe(mock_sub) + + ## Act + with sc._condition: + sc._confirm_subscribed(binding_key) + + ## Assert + assert sc.get_status(binding_key) == BindingStatus.ACTIVE + + @capture_logs() + def test_confirm_subscribed_ignores_when_already_active(self, sc, mock_sub, binding_key): + """SubscriptionController._confirm_subscribed does not update when already ACTIVE.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + binding = sc._bindings[binding_key] + binding.attempts = 5 + + ## Act + with sc._condition: + sc._confirm_subscribed(binding_key) + + ## Assert + assert binding.attempts == 5 + + @capture_logs() + def test_confirm_subscribed_ignores_when_intent_unsubscribed(self, sc, mock_sub, binding_key): + """SubscriptionController._confirm_subscribed does not update when intent is UNSUBSCRIBED.""" + ## Arrange + sc.subscribe(mock_sub) + sc.unsubscribe(mock_sub) + binding = sc._bindings[binding_key] + original_status = binding.status + + ## Act + sc._confirm_subscribed(binding_key) + + ## Assert + assert binding.status == original_status + + @capture_logs(logger_level='WARNING', expected_errors=['Unknown subscription'], partial_match=True) + def test_confirm_subscribed_warns_when_missing(self, sc): + """SubscriptionController._confirm_subscribed logs warning when binding does not exist.""" + ## Act + sc._confirm_subscribed('nonexistent') + + +class TestConfirmUnsubscribed: + @capture_logs() + def test_confirm_unsubscribed_updates_status(self, sc, mock_sub, binding_key): + """SubscriptionController._confirm_unsubscribed updates status to UNSUBSCRIBED.""" + ## Arrange + sc.unsubscribe(mock_sub) + + ## Act + with sc._condition: + sc._confirm_unsubscribed(binding_key) + + ## Assert + assert sc.get_status(binding_key) == BindingStatus.UNSUBSCRIBED + + @capture_logs() + def test_confirm_unsubscribed_ignores_when_already_unsubscribed(self, sc, mock_sub, binding_key): + """SubscriptionController._confirm_unsubscribed does not update when already UNSUBSCRIBED.""" + ## Arrange + sc.unsubscribe(mock_sub) + with sc._condition: + sc._confirm_unsubscribed(binding_key) + binding = sc._bindings[binding_key] + binding.attempts = 5 + + ## Act + with sc._condition: + sc._confirm_unsubscribed(binding_key) + + ## Assert + assert binding.attempts == 5 + + @capture_logs() + def test_confirm_unsubscribed_ignores_when_intent_active(self, sc, mock_sub, binding_key): + """SubscriptionController._confirm_unsubscribed does not update when intent is ACTIVE.""" + ## Arrange + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + original_status = binding.status + + ## Act + sc._confirm_unsubscribed(binding_key) + + ## Assert + assert binding.status == original_status + + @capture_logs(logger_level='WARNING', expected_errors=['Unknown subscription'], partial_match=True) + def test_confirm_unsubscribed_warns_when_missing(self, sc): + """SubscriptionController._confirm_unsubscribed logs warning when binding does not exist.""" + ## Act + sc._confirm_unsubscribed('nonexistent') diff --git a/test/unit/ws_v2/test_ws_v2_regressions_u.py b/test/unit/ws_v2/test_ws_v2_regressions_u.py new file mode 100644 index 00000000..e87842d5 --- /dev/null +++ b/test/unit/ws_v2/test_ws_v2_regressions_u.py @@ -0,0 +1,121 @@ +import ssl +from unittest.mock import MagicMock + +from ibind import events +from ibind.ibkr_ws_v2.ibkr_router import IbkrRouter +from ibind.ibkr_ws_v2.ibkr_ws_client_v2 import IbkrWsClientV2 +from ibind.ws_v2._ws_events import CallbackSink, QueueSink +from ibind.ws_v2.ws_runtime import WsState +from ibind.ws_v2.ws_transport import WsTransport + + +def test_ws_state_values_are_strings(): + assert WsState.OPEN.value == 'OPEN' + assert str(WsState.AUTHENTICATED) == 'AUTHENTICATED' + + +def test_callback_sink_instances_do_not_share_callbacks(): + ## Arrange + first_callback = MagicMock() + second_callback = MagicMock() + first = CallbackSink() + second = CallbackSink() + + ## Act + first.on(events.WsOpen, first_callback) + second.emit(events.WsOpen()) + + ## Assert + first_callback.assert_not_called() + second_callback.assert_not_called() + + +def test_queue_sink_instances_do_not_share_queues(): + ## Arrange + first = QueueSink() + second = QueueSink() + + ## Act + first.emit(events.WsOpen()) + + ## Assert + assert second.get(events.WsOpen) is None + + +def test_transport_liveness_fails_when_pong_predates_ping(mocker): + ## Arrange + transport = WsTransport(url='wss://example.test/ws', event_callback=lambda _: None, sslopt={}) + transport._wsa = MagicMock(last_ping_tm=100, last_pong_tm=90) + mocker.patch('ibind.ws_v2.ws_transport.time.time', return_value=160) + + ## Act / Assert + assert transport.check_ping(max_interval=50) is False + + +def test_transport_liveness_tracks_first_unanswered_ping(mocker): + ## Arrange + current_time = [100.0] + transport = WsTransport(url='wss://example.test/ws', event_callback=lambda _: None, sslopt={}) + transport._wsa = MagicMock(last_ping_tm=100.0, last_pong_tm=0) + mocker.patch('ibind.ws_v2.ws_transport.time.time', side_effect=lambda: current_time[0]) + + ## Act / Assert + assert transport.check_ping(max_interval=50) is True + current_time[0] = 130.0 + transport._wsa.last_ping_tm = 130.0 + assert transport.check_ping(max_interval=50) is True + current_time[0] = 151.0 + transport._wsa.last_ping_tm = 151.0 + assert transport.check_ping(max_interval=50) is False + + +def test_transport_liveness_allows_fresh_pong_after_ping(mocker): + ## Arrange + transport = WsTransport(url='wss://example.test/ws', event_callback=lambda _: None, sslopt={}) + transport._wsa = MagicMock(last_ping_tm=100, last_pong_tm=120) + mocker.patch('ibind.ws_v2.ws_transport.time.time', return_value=140) + + ## Act / Assert + assert transport.check_ping(max_interval=50) is True + + +def test_ibkr_ws_client_v2_oauth_url_preserves_existing_query_and_forces_tls(): + ## Arrange / Act + client = IbkrWsClientV2( + url='wss://localhost:5000/v1/api/ws?existing=1', + ibkr_client=MagicMock(), + use_oauth=True, + access_token='TOKEN VALUE', # noqa: S106 + cacert=False, + ) + + ## Assert + assert client._runtime._url == 'wss://localhost:5000/v1/api/ws?existing=1&oauth_token=TOKEN+VALUE' + assert client._runtime._sslopt == {'cert_reqs': ssl.CERT_REQUIRED} + + +def test_router_unknown_auth_status_is_not_silently_ignored(): + ## Arrange + router = IbkrRouter() + + ## Act + event = router.route('{"topic":"sts","args":{"unexpected":true}}') + + ## Assert + assert isinstance(event, events.GenericIbkrEvent) + assert event.topic == 'sts' + assert event.data == {'unexpected': True} + + +def test_client_v2_competing_false_does_not_change_authentication(mocker): + ## Arrange + client = IbkrWsClientV2(ibkr_client=MagicMock(), use_oauth=False) + client._runtime.set_authenticated = MagicMock() + logger_error = mocker.patch('ibind.ibkr_ws_v2.ibkr_ws_client_v2._LOGGER.error') + + ## Act + client._on_authentication_status(events.AuthenticationStatus(data={'competing': False}, authenticated=None, competing=False)) + + ## Assert + logger_error.assert_not_called() + client._runtime.set_authenticated.assert_not_called()