From ec8f145d05f2a44b32b56f4b6eba6b20bd19e8ee Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 29 Apr 2026 18:16:58 +0200 Subject: [PATCH 01/32] feat: added version 2 of IbkrWsClient (still WIP) - refactoring the threading/lifecycle model, making (un)subscribing actions idempotent and introducing Pydantic models at input and output --- examples/ws_04_ws_v2.py | 87 ++++++ ibind/ibkr_ws_v2/ibkr_events.py | 152 ++++++++++ ibind/ibkr_ws_v2/ibkr_router.py | 266 +++++++++++++++++ ibind/ibkr_ws_v2/ibkr_subscriptions.py | 228 +++++++++++++++ ibind/ibkr_ws_v2/ibkr_ws_client_v2.py | 155 ++++++++++ ibind/ws_v2/events.py | 143 ++++++++++ ibind/ws_v2/subscription_controller.py | 235 +++++++++++++++ ibind/ws_v2/ws_runtime.py | 380 +++++++++++++++++++++++++ ibind/ws_v2/ws_transport.py | 164 +++++++++++ 9 files changed, 1810 insertions(+) create mode 100644 examples/ws_04_ws_v2.py create mode 100644 ibind/ibkr_ws_v2/ibkr_events.py create mode 100644 ibind/ibkr_ws_v2/ibkr_router.py create mode 100644 ibind/ibkr_ws_v2/ibkr_subscriptions.py create mode 100644 ibind/ibkr_ws_v2/ibkr_ws_client_v2.py create mode 100644 ibind/ws_v2/events.py create mode 100644 ibind/ws_v2/subscription_controller.py create mode 100644 ibind/ws_v2/ws_runtime.py create mode 100644 ibind/ws_v2/ws_transport.py diff --git a/examples/ws_04_ws_v2.py b/examples/ws_04_ws_v2.py new file mode 100644 index 00000000..927eba50 --- /dev/null +++ b/examples/ws_04_ws_v2.py @@ -0,0 +1,87 @@ +""" +WebSocket Intermediate + +In this example we: + +* Demonstrate subscription to multiple channels +* Utilise queue accessors +* Use the 'signal' module to ensure we unsubscribe and shutdown upon the program termination + +Assumes the Gateway is deployed at 'localhost:5000' and the IBIND_ACCOUNT_ID and IBIND_CACERT environment variables have been set. +""" + +import os +import signal +import time + +from ibind import IbkrWsKey, IbkrWsClient, ibind_logs_initialize +from ibkr_ws_v2.ibkr_subscriptions import MarketDataSubscription, OrdersSubscription, AccountLedgerSubscription +from ibkr_ws_v2.ibkr_ws_client_v2 import IbkrWsClientV2 + +ibind_logs_initialize(log_to_file=False, log_level='DEBUG') + +account_id = os.getenv('IBIND_ACCOUNT_ID', '[YOUR_ACCOUNT_ID]') +cacert = os.getenv('IBIND_CACERT', False) # insert your cacert path here + +# ws_client = IbkrWsClient(cacert=cacert, account_id=account_id) +ws_client = IbkrWsClientV2(cacert=cacert, account_id=account_id) + +# def stop(_, _1): +# print('exit') +# ws_client.shutdown() +# print('done') +# return False +# +# signal.signal(signal.SIGINT, stop) +# signal.signal(signal.SIGTERM, stop) + +ws_client.start() + +md_sub = MarketDataSubscription(conid='265598') +or_sub = OrdersSubscription() +al_sub = AccountLedgerSubscription(account_id=account_id) + +ws_client.subscribe(md_sub) +ws_client.subscribe(or_sub) +ws_client.subscribe(al_sub) + +try: + while ws_client.is_running(): + time.sleep(1) +except KeyboardInterrupt: + print('Interrupt') + +ws_client.unsubscribe(md_sub) +ws_client.unsubscribe(or_sub) +ws_client.unsubscribe(al_sub) +ws_client.shutdown() + +# requests = [ +# {'channel': 'md+265598', 'data': {'fields': ['55', '71', '84', '86', '88', '85', '87', '7295', '7296', '70']}}, +# {'channel': 'or'}, +# {'channel': 'tr'}, +# {'channel': f'sd+{account_id}'}, +# {'channel': f'ld+{account_id}'}, +# {'channel': 'pl'}, +# ] +# +# +# + +# +# for request in requests: +# while not ws_client.subscribe(**request): +# time.sleep(1) +# +# while ws_client.running: +# try: +# for qa in queue_accessors: +# while not qa.empty(): +# print(str(qa), qa.get()) +# +# time.sleep(1) +# except KeyboardInterrupt: +# print('KeyboardInterrupt') +# break +# +# stop(None, None) \ No newline at end of file diff --git a/ibind/ibkr_ws_v2/ibkr_events.py b/ibind/ibkr_ws_v2/ibkr_events.py new file mode 100644 index 00000000..0c4b6d26 --- /dev/null +++ b/ibind/ibkr_ws_v2/ibkr_events.py @@ -0,0 +1,152 @@ +from enum import Enum +from typing import Any + +from ws_v2.events import WsEvent + + +class IbkrWsKey(Enum): + # generic + UNCLASSIFIED = 'UNCLASSIFIED' + GENERIC = 'GENERIC' + UNSUBSCRIPTION = 'UNSUBSCRIPTION' + + # unsolicited + ACCOUNT_UPDATE = 'ACCOUNT_UPDATE' + AUTHENTICATION_STATUS = 'AUTHENTICATION_STATUS' + BULLETIN = 'BULLETIN' + ERROR = 'ERROR' + SYSTEM = 'SYSTEM' + NOTIFICATION = 'NOTIFICATION' + + # subscription-based + ACCOUNT_SUMMARY = 'ACCOUNT_SUMMARY' + ACCOUNT_LEDGER = 'ACCOUNT_LEDGER' + MARKET_DATA = 'MARKET_DATA' + MARKET_HISTORY = 'MARKET_HISTORY' + PRICE_LADDER = 'PRICE_LADDER' + ORDERS = 'ORDERS' + PNL = 'PNL' + TRADES = 'TRADES' + + # internal + CLIENT_INTERNAL = 'CLIENT_INTERNAL' + + @classmethod + def from_channel(cls, channel): + channel_to_key = { + 'sd': IbkrWsKey.ACCOUNT_SUMMARY, + 'ld': IbkrWsKey.ACCOUNT_LEDGER, + 'md': IbkrWsKey.MARKET_DATA, + 'mh': IbkrWsKey.MARKET_HISTORY, + 'bd': IbkrWsKey.PRICE_LADDER, + 'or': IbkrWsKey.ORDERS, + 'pl': IbkrWsKey.PNL, + 'tr': IbkrWsKey.TRADES, + } + if channel in channel_to_key: + return channel_to_key[channel] + raise ValueError(f"No enum member associated with channel '{channel}'") + + + +class ParsedIbkrMessage(WsEvent): + key: str = IbkrWsKey.UNCLASSIFIED + message: dict | None + topic: str | None = None + data: dict | None = None + subscribed: str | None = None + channel: str | None = None + + +# =================== +# == Unsolicited == +# =================== + +class IbkrError(WsEvent): + key: IbkrWsKey = IbkrWsKey.ERROR + message: str + + +class WaitingForSession(WsEvent): + key: IbkrWsKey = IbkrWsKey.GENERIC + + +class Notification(WsEvent): + key: IbkrWsKey = IbkrWsKey.NOTIFICATION + message: str + + +class Bulletin(WsEvent): + key: IbkrWsKey = IbkrWsKey.BULLETIN + message: str + + +class AccountUpdate(WsEvent): + key: IbkrWsKey = IbkrWsKey.ACCOUNT_UPDATE + data: dict + + +class System(WsEvent): + key: IbkrWsKey = IbkrWsKey.SYSTEM + data: dict + + +class AuthenticationStatus(WsEvent): + key: IbkrWsKey = IbkrWsKey.AUTHENTICATION_STATUS + data: dict + authenticated: bool | None + competing: bool | None + + +# ========================== +# == Subscription-based == +# ========================== + +class Unsubscription(WsEvent): + key: IbkrWsKey = IbkrWsKey.UNSUBSCRIPTION + target_key: IbkrWsKey + conid: int | None = None + + +class AccountSummary(WsEvent): + key: IbkrWsKey = IbkrWsKey.ACCOUNT_SUMMARY + data: dict + + +class AccountLedger(WsEvent): + key: IbkrWsKey = IbkrWsKey.ACCOUNT_LEDGER + account_id: str + data: dict + + +class MarketData(WsEvent): + key: IbkrWsKey = IbkrWsKey.MARKET_DATA + conid: str + data: dict = {} + fields: dict[str, Any] = {} + + +class MarketHistory(WsEvent): + key: IbkrWsKey = IbkrWsKey.MARKET_HISTORY + conid: str + data: dict + + +class Orders(WsEvent): + key: IbkrWsKey = IbkrWsKey.ORDERS + data: dict + + +class PriceLadder(WsEvent): + key: IbkrWsKey = IbkrWsKey.PRICE_LADDER + data: dict + + +class Pnl(WsEvent): + key: IbkrWsKey = IbkrWsKey.PNL + data: dict + + +class Trades(WsEvent): + key: IbkrWsKey = IbkrWsKey.TRADES + data: dict \ No newline at end of file diff --git a/ibind/ibkr_ws_v2/ibkr_router.py b/ibind/ibkr_ws_v2/ibkr_router.py new file mode 100644 index 00000000..36fe7b64 --- /dev/null +++ b/ibind/ibkr_ws_v2/ibkr_router.py @@ -0,0 +1,266 @@ +import json +from collections import defaultdict +from typing import Dict + +from client import ibkr_definitions +from client.ibkr_utils import extract_conid +from ibkr_ws_v2 import ibkr_events +from ibkr_ws_v2.ibkr_events import ParsedIbkrMessage, IbkrWsKey +from support.logs import project_logger +from support.py_utils import UNDEFINED, OneOrMany +from ws_v2.events import WsEvent + +_LOGGER = project_logger(__file__) + + +def parse_raw_message(raw_message: str): + message = json.loads(raw_message) + # print(message) + topic = message.get('topic', UNDEFINED) + + if topic is UNDEFINED: + return message, None, None, None, None + + data = message.get('args', {}) + + # subscribed is the indicator of whether it was a subscription or unsubscription, defined by the first letter + # channel is the actual channel we received the information about + subscribed, channel = topic[0], topic[1:] + + return message, topic, data, subscribed, channel + + +class IbkrRouter(): + def __init__( + self, + log_raw_messages: bool = False, + unwrap_market_data: bool = True + ): + self._log_raw_messages = log_raw_messages + self._unwrap_market_data = unwrap_market_data + self._server_id_conid_pairs: Dict[IbkrWsKey, Dict[str, int]] = defaultdict(dict) + + def _preprocess_market_data_message(self, data: dict) -> OneOrMany[WsEvent]: + """ + API will only return fields that were updated. If you are not receiving certain fields in the response - means that they remain unchanged. + """ + if 'conid' not in data: # pragma: no cover + # sometimes the ticker message is just an empty update, we ignore it + return [] + + if not self._unwrap_market_data: + return ibkr_events.MarketData(conid=data['conid'], data=data) + # return {data['conid']: data} + + # result = {'conid': data['conid'], '_updated': data['_updated'], 'topic': data['topic']} + fields = {} + for key, value in data.items(): + if key in ibkr_definitions.snapshot_by_id: + # result[ibkr_definitions.snapshot_by_id[key]] = value + fields[ibkr_definitions.snapshot_by_id[key]] = value + return ibkr_events.MarketData(conid=str(data['conid']), fields=fields) + # return {data['conid']: result} + + def _preprocess_market_history_message(self, data: dict) -> OneOrMany[WsEvent]: + mh_server_id_conid_pairs = self._server_id_conid_pairs[IbkrWsKey.MARKET_HISTORY] + if 'serverId' in data and data['serverId'] not in mh_server_id_conid_pairs: + mh_server_id_conid_pairs[data['serverId']] = extract_conid(data) + + return ibkr_events.MarketHistory(conid=str(data['conid']), data=data) + + def _preprocess_account_leger(self, data): + events = [] + for entry in data['result']: + if 'acctCode' not in entry: + continue + event = ibkr_events.AccountLedger(data=entry, account_id=entry['acctCode']) + events.append(event) + return events + + def _handle_subscribed_message(self, channel: str, data: dict) -> OneOrMany[WsEvent] | None: + try: + ibkr_ws_key = IbkrWsKey.from_channel(channel[:2]) + except ValueError: + # ValueError means we don't support this channel + return None + + if ibkr_ws_key == IbkrWsKey.ACCOUNT_SUMMARY: + return ibkr_events.AccountSummary(data=data) + elif ibkr_ws_key == IbkrWsKey.ACCOUNT_LEDGER: + return self._preprocess_account_leger(data) + elif ibkr_ws_key == IbkrWsKey.MARKET_DATA: + return self._preprocess_market_data_message(data) + elif ibkr_ws_key == IbkrWsKey.MARKET_HISTORY: + return self._preprocess_market_history_message(data) + elif ibkr_ws_key == IbkrWsKey.PRICE_LADDER: + return ibkr_events.PriceLadder(data=data) + elif ibkr_ws_key == IbkrWsKey.ORDERS: + return ibkr_events.Orders(data=data) + elif ibkr_ws_key == IbkrWsKey.PNL: + return ibkr_events.Pnl(data=data) + elif ibkr_ws_key == IbkrWsKey.TRADES: + return ibkr_events.Trades(data=data) + else: + _LOGGER.error(f'{self}: Unhandled subscribed message: {data}') + return None + + def _handle_account_update(self, message, arguments) -> OneOrMany[WsEvent]: + # if 'accounts' in data and self._account_id not in data['accounts']: + # _LOGGER.error(f'{self}: Account ID mismatch: expected={self._account_id}, received={data["accounts"]}') + # if 'acctProps' in data: # expected account update that we ignore + # return [] + + _LOGGER.info(f'{self}: Account update: {arguments}') + return ibkr_events.AccountUpdate(data=arguments) + + def _handle_authentication_status(self, message, arguments) -> OneOrMany[WsEvent]: + # if 'authenticated' in arguments: + # if arguments.get('authenticated') is False: + # _LOGGER.error(f'{self}: Status unauthenticated: {arguments}') + # + # # TODO: this needs to be handled in IbkrWsClient or WsRuntime + # # self.set_authenticated(data.get('authenticated')) + # elif 'competing' in arguments: + # if arguments.get('competing') is False: + # pass + # _LOGGER.error(f'{self}: Authentication competing: {arguments}') + + if 'authenticated' in arguments: + _LOGGER.info(f'{self}: Authentication status: {arguments}') + return ibkr_events.AuthenticationStatus(data=arguments, authenticated=arguments.get('authenticated'), competing=arguments.get('competing')) + elif ( # expected status updates that we ignore + arguments == {'message': ''} or + arguments.get('fail', '') == '' or + 'serverName' in arguments or + 'serverVersion' in arguments or + 'username' in arguments + ): + _LOGGER.info(f'{self}: Authentication silenced: {arguments}') + pass + + return [] + + def _handle_bulletin(self, message) -> OneOrMany[WsEvent]: # pragma: no cover + return ibkr_events.Bulletin(message=message) + + def _handle_error(self, message) -> OneOrMany[WsEvent]: + _LOGGER.error(f'{self}: on_message error: {message}') + return ibkr_events.IbkrError(message=message) + + def _handle_notification(self, data) -> OneOrMany[WsEvent]: # pragma: no cover + events = [] + for notification in data: + _LOGGER.info(f'{self}: IBKR notification: {notification}') + events.append(ibkr_events.Notification(message=notification)) + return events + + def _handle_market_history_unsubscribe(self, data) -> OneOrMany[WsEvent]: + server_id = data['message'].split('Unsubscribed ')[-1] + mh_server_id_conid_pairs = self._server_id_conid_pairs[IbkrWsKey.MARKET_HISTORY] + if server_id in mh_server_id_conid_pairs: + conid = mh_server_id_conid_pairs[server_id] + _LOGGER.info(f'{self}: Received unsubscribing confirmation for server_id={server_id!r}/conid={conid!r}.') + if conid is not None: + return ibkr_events.Unsubscription(target_key=IbkrWsKey.MARKET_HISTORY, conid=conid) + # self.modify_subscription(f'mh+{conid}', status=False) + + _LOGGER.warning(f'{self}: Unknown conid={conid!r}. Cannot mark the subscription as unsubscribed.') + else: + _LOGGER.warning( + f'{self}: Received unsubscribing confirmation for unknown server_id={server_id!r}. Existing server_ids: {mh_server_id_conid_pairs}' + ) + return [] + + def _handle_message_without_topic(self, message: dict) -> OneOrMany[WsEvent]: + if 'message' in message: + if message['message'] == 'waiting for session': + _LOGGER.info(f'{self}: Waiting for an active IBKR session.') + return ibkr_events.WaitingForSession() + + if 'Unsubscribed' in message['message']: + return self._handle_market_history_unsubscribe(message) + + elif 'result' in message: + if message['result'] == 'unsubscribed from summary': + return ibkr_events.Unsubscription(target_key=IbkrWsKey.ACCOUNT_SUMMARY) + # return self.modify_subscription(f'sd+{self._account_id}', status=False) + elif message['result'] == 'unsubscribed from ledger': + return ibkr_events.Unsubscription(target_key=IbkrWsKey.ACCOUNT_LEDGER) + # return self.modify_subscription(f'ld+{self._account_id}', status=False) + + _LOGGER.error(f'{self}: Unrecognised message without a topic: {message}') + return ParsedIbkrMessage(message=message) + + def _preprocess_raw_message(self, raw_message: str): + message = json.loads(raw_message) + # print(message) + topic = message.get('topic', UNDEFINED) + + if topic is UNDEFINED: + return message, None, None, None, None + + data = message.get('args', {}) + + # subscribed is the indicator of whether it was a subscription or unsubscription, defined by the first letter + # channel is the actual channel we received the information about + subscribed, channel = topic[0], topic[1:] + + return message, topic, data, subscribed, channel + + def route(self, raw_message: str) -> OneOrMany[WsEvent]: + if self._log_raw_messages: + _LOGGER.debug(f'{self}: Raw message: {raw_message}') + message, topic, arguments, subscribed, channel = parse_raw_message(raw_message) + + if 'error' in message: + return self._handle_error(message) + + elif topic is None: + # in general most message should carry a topic, other than for few exceptions + return self._handle_message_without_topic(message) + + elif topic == 'tic': + self._tic_message = message + + elif topic == 'system': + if 'hb' in message: + self._last_heartbeat = message['hb'] + return ibkr_events.System(data=message) + + elif topic == 'act': + return self._handle_account_update(message, arguments) + + elif topic == 'blt': + return self._handle_bulletin(message) + + elif topic == 'ntf': + return self._handle_notification(arguments) + + elif topic == 'sts': + return self._handle_authentication_status(message, arguments) + + elif topic == 'error': + return self._handle_error(message) + # _LOGGER.error(f'{self}: Error message: {message}') + + # elif self.has_subscription(channel): + # if not self.is_subscription_active(channel): + # self.modify_subscription(channel, status=True) + else: + events = self._handle_subscribed_message(channel, message) + if events is None: + _LOGGER.error(f'{self}: Channel "{channel}" subscribed but lacking a handler. Message: {message}') + events = ParsedIbkrMessage(message=message, topic=topic, data=arguments, subscribed=subscribed, channel=channel) + return events + # _LOGGER.warning(f'{self}: Handled a channel "{channel}" message that is missing a subscription. Message: {message}') + + _LOGGER.error(f'{self}: Topic "{topic}" unrecognised. Message: {message}') + return ParsedIbkrMessage(message=message, topic=topic, data=arguments, subscribed=subscribed, channel=channel) + + # def route(self, raw_message) -> List[WsEvent]: + # _LOGGER.debug(f'{self}: Routing message: {raw_message}') + # message, topic, data, subscribed, channel = parse_raw_message(raw_message) + # return [ParsedIbkrMessage(message=message, topic=topic, data=data, subscribed=subscribed, channel=channel)] + + def __str__(self): + return f'{self.__class__.__qualname__}()' \ No newline at end of file diff --git a/ibind/ibkr_ws_v2/ibkr_subscriptions.py b/ibind/ibkr_ws_v2/ibkr_subscriptions.py new file mode 100644 index 00000000..a7f4abca --- /dev/null +++ b/ibind/ibkr_ws_v2/ibkr_subscriptions.py @@ -0,0 +1,228 @@ +import json +from typing import Literal, Any + +from ibkr_ws_v2.ibkr_events import IbkrWsKey, AccountSummary, AccountLedger, MarketData, MarketHistory, Orders, PriceLadder, Pnl, Trades +from ws_v2.events import WsEvent +from ws_v2.subscription_controller import Subscription + + +def ibkr_payload(op: Literal['s', 'u'], topic: str, target: str | None = None, data: dict[str, Any] | None = None) -> str: + payload = f"{op}{topic}" + if target is not None: + payload += f"+{target}" + if data is not None: + payload += f"+{json.dumps(data, separators=(',', ':'))}" + return payload + + +class AccountSummarySubscription(Subscription): + @property + def key(self) -> IbkrWsKey: + return IbkrWsKey.ACCOUNT_SUMMARY + + @property + def topic(self) -> str: + return '' + + def subscribe_payload(self) -> str: + ... + + def unsubscribe_payload(self) -> str: + ... + + def confirms_subscribe(self) -> bool: + return True + + def confirms_unsubscribe(self) -> bool: + return True + + +class AccountLedgerSubscription(Subscription): + account_id: str + @property + def key(self) -> IbkrWsKey: + return IbkrWsKey.ACCOUNT_LEDGER + + @property + def topic(self) -> str: + return 'ld' + + def subscribe_payload(self) -> str: + return ibkr_payload("s", "ld", self.account_id) + + + def unsubscribe_payload(self) -> str: + return ibkr_payload("u", "ld", self.account_id) + + + def confirms_subscribe(self) -> bool: + return True + + def confirms_unsubscribe(self) -> bool: + return True + + def make_hash(self): + return self.topic + "+" + self.account_id + +class MarketDataSubscription(Subscription): + conid: str + fields: tuple[str, ...] = ("31", "84", "86") + + @property + def key(self) -> IbkrWsKey: + return IbkrWsKey.MARKET_DATA + + @property + def topic(self) -> str: + return "md" + + def subscribe_payload(self) -> str: + return ibkr_payload("s", "md", self.conid, {"fields": list(self.fields)}) + + def unsubscribe_payload(self) -> str: + return ibkr_payload("u", "md", self.conid, {}) + + @property + def confirms_subscribe(self) -> bool: + return True + + @property + def confirms_unsubscribe(self) -> bool: + return False + + def make_hash(self): + return self.topic + "+" + self.conid + + +class MarketHistorySubscription(Subscription): + conid: str + + @property + def key(self) -> IbkrWsKey: + return IbkrWsKey.MARKET_HISTORY + + @property + def topic(self) -> str: + return '' + + def subscribe_payload(self) -> str: + ... + + def unsubscribe_payload(self) -> str: + ... + + def confirms_subscribe(self) -> bool: + return True + + def confirms_unsubscribe(self) -> bool: + return True + + +class OrdersSubscription(Subscription): + @property + def key(self) -> IbkrWsKey: + return IbkrWsKey.ORDERS + + @property + def topic(self) -> str: + return "or" + + def subscribe_payload(self) -> str: + return ibkr_payload("s", "or") + + def unsubscribe_payload(self) -> str: + return ibkr_payload("u", "or", data={}) + + @property + def confirms_subscribe(self) -> bool: + return False + + @property + def confirms_unsubscribe(self) -> bool: + return False + + +class PriceLadderSubscription(Subscription): + @property + def key(self) -> IbkrWsKey: + return IbkrWsKey.PRICE_LADDER + + @property + def topic(self) -> str: + return '' + + def subscribe_payload(self) -> str: + ... + + def unsubscribe_payload(self) -> str: + ... + + def confirms_subscribe(self) -> bool: + return False + + def confirms_unsubscribe(self) -> bool: + return False + + +class PnlSubscription(Subscription): + @property + def key(self) -> IbkrWsKey: + return IbkrWsKey.PNL + + @property + def topic(self) -> str: + return '' + + def subscribe_payload(self) -> str: + ... + + def unsubscribe_payload(self) -> str: + ... + + def confirms_subscribe(self) -> bool: + return True + + def confirms_unsubscribe(self) -> bool: + return False + + +class TradesSubscription(Subscription): + @property + def key(self) -> IbkrWsKey: + return IbkrWsKey.TRADES + + @property + def topic(self) -> str: + return '' + + def subscribe_payload(self) -> str: + ... + + def unsubscribe_payload(self) -> str: + ... + + def confirms_subscribe(self) -> bool: + return True + + def confirms_unsubscribe(self) -> bool: + return False + +def event_to_subscription(event:WsEvent): + if isinstance(event, AccountSummary): + return AccountSummarySubscription() + elif isinstance(event, AccountLedger): + return AccountLedgerSubscription(account_id=event.account_id) + elif isinstance(event, MarketData): + return MarketDataSubscription(conid=event.conid) + elif isinstance(event, MarketHistory): + return MarketHistorySubscription(conid=event.conid) + elif isinstance(event, Orders): + return OrdersSubscription() + elif isinstance(event, PriceLadder): + return PriceLadderSubscription() + elif isinstance(event, Pnl): + return PnlSubscription() + elif isinstance(event, Trades): + return TradesSubscription() + else: + raise ValueError(f'Unsupported event: {event}') \ No newline at end of file diff --git a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py new file mode 100644 index 00000000..ce8c527b --- /dev/null +++ b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py @@ -0,0 +1,155 @@ +import json +from typing import Union + +import var +from ibind import ExternalBrokerError, IbkrClient +from ibkr_ws_v2 import ibkr_events +from ibkr_ws_v2.ibkr_router import IbkrRouter +from ibkr_ws_v2.ibkr_subscriptions import event_to_subscription +from support.logs import project_logger +from ws_v2 import events +from ws_v2.events import EventSink, LogSink, CallbackSink, CompositeSink, Router +from ws_v2.subscription_controller import Subscription +from ws_v2.ws_runtime import WsRuntime, WsState + +_LOGGER = project_logger(__file__) + +_DEFAULT_CYCLE_INTERVAL = 0.25 + + +class IbkrWsClientV2(): + def __init__( + self, + account_id: str = var.IBIND_ACCOUNT_ID, + url: str = var.IBIND_WS_URL, + host: str = '127.0.0.1', + port: str = '5000', + base_route: str = '/v1/api/ws', + ibkr_client: IbkrClient = None, + use_oauth: bool = var.IBIND_USE_OAUTH, + access_token: str = var.IBIND_OAUTH1A_ACCESS_TOKEN, + cacert: Union[str, bool] = var.IBIND_CACERT, + cycle_interval: float = _DEFAULT_CYCLE_INTERVAL, + recreate_subscriptions_on_reconnect: bool = True, + sink: EventSink = None, + router: Router = None, + ): + self._account_id = account_id + + url = var.IBIND_OAUTH1A_WS_URL if url is None and use_oauth else url + + if url is None: + url = f'wss://{host}:{port}{base_route}' + + if use_oauth: + if access_token is None: + raise ValueError( + 'OAuth access token not found. Please set IBIND_OAUTH1A_ACCESS_TOKEN environment variable or provide it as `access_token` argument.' + ) + url += f'?oauth_token={access_token}' + + if ibkr_client is None: + ibkr_client = IbkrClient(account_id=account_id, host=host, port=port, cacert=cacert, use_oauth=use_oauth) + + self._ibkr_client = ibkr_client + self._use_oauth = use_oauth + self._recreate_subscriptions_on_reconnect = recreate_subscriptions_on_reconnect + + if sink is None: + # self._queue_controller = QueueController[IbkrWsKey]() + # self._queue_controller.register_queues(['CLIENT_INTERNAL', 'IBKR']) + # sink = QueueSink(queue_controller=self._queue_controller) + + sink = LogSink() + + self._internal_sink = CallbackSink() + self._register_internal_callbacks() + sink = CompositeSink(self._internal_sink, sink) + + if router is None: + router = IbkrRouter() + + self._runtime = WsRuntime( + url=url, + cycle_interval=cycle_interval, + ready_state=WsState.AUTHENTICATED, + cacert=cacert, + sink=sink, + router=router, + get_cookie=self._get_cookie, + get_header=self._get_header, + ) + + # self._subscription_controller = SubscriptionController( + # send_payload=self._runtime.send, + # is_running=self._runtime.is_running, + # ) + + def _register_internal_callbacks(self): + self._internal_sink.on(ibkr_events.AuthenticationStatus, self._on_authentication_status) + + self._internal_sink.on(ibkr_events.Unsubscription, self._on_unsubscription_confirmation) + + self._internal_sink.on(ibkr_events.AccountSummary, self._on_subscription_confirmation) + self._internal_sink.on(ibkr_events.AccountLedger, self._on_subscription_confirmation) + self._internal_sink.on(ibkr_events.MarketData, self._on_subscription_confirmation) + self._internal_sink.on(ibkr_events.MarketHistory, self._on_subscription_confirmation) + self._internal_sink.on(ibkr_events.Pnl, self._on_subscription_confirmation) + self._internal_sink.on(ibkr_events.Trades, self._on_subscription_confirmation) + + self._internal_sink.on(ibkr_events.WaitingForSession, self._set_unauthenticated) + + def _set_unauthenticated(self, _): + self._runtime.set_authenticated(False) + + def _on_authentication_status(self, event: ibkr_events.AuthenticationStatus): + if event.authenticated is False: + _LOGGER.error(f'{self}: Status unauthenticated: {event}') + elif event.competing is True: + _LOGGER.error(f'{self}: Authentication competing: {event}') + + self._runtime.set_authenticated(event.authenticated) + + def _on_subscription_confirmation(self, event: events.WsEvent): + subscription = event_to_subscription(event) + self._runtime.subscription_controller.set_subscription_active(subscription) + + def _on_unsubscription_confirmation(self, event: events.WsEvent): + subscription = event_to_subscription(event) + self._runtime.subscription_controller.set_subscription_unsubscribed(subscription) + + def _get_cookie(self): + try: + status = self._ibkr_client.tickle() + except ExternalBrokerError: + _LOGGER.warning('Acquiring session cookie failed, connection to the Gateway may be broken.') + return None + session_id = status.data['session'] + if self._use_oauth: + return f'api={session_id}' + payload = {'session': session_id} + return f'api={json.dumps(payload)}' + + def _get_header(self): + return {'User-Agent': 'ClientPortalGW/1'} if self._use_oauth else None + + def start(self): + self._runtime.start() + + def shutdown(self): + self._runtime.stop() + + def hard_reset(self): + self._runtime.hard_reset() + + def subscribe(self, subscription: Subscription) -> bool: + return self._runtime.subscription_controller.subscribe(subscription) + + def unsubscribe(self, subscription: Subscription) -> bool: + return self._runtime.subscription_controller.unsubscribe(subscription) + + def is_running(self) -> bool: + return self._runtime.is_running() + + def __str__(self): + return f'{self.__class__.__qualname__}()' \ No newline at end of file diff --git a/ibind/ws_v2/events.py b/ibind/ws_v2/events.py new file mode 100644 index 00000000..fc4f135c --- /dev/null +++ b/ibind/ws_v2/events.py @@ -0,0 +1,143 @@ +from collections import defaultdict +from datetime import datetime +from typing import Hashable, Protocol, Callable + +from pydantic import BaseModel, ConfigDict, Field + +from base.queue_controller import QueueController +from support.logs import project_logger +from support.py_utils import OneOrMany + +_LOGGER = project_logger(__file__) + + +# ====================== +# == Events Classes == +# ====================== + +class WsEvent(BaseModel): + model_config = ConfigDict(frozen=True, extra="forbid") + + received_at: datetime = Field(default_factory=datetime.now) + key: Hashable + + def __str__(self): + return self._format() + + def __repr__(self): + return self._format() + + def _format(self): + data = self.model_dump() + + # remove key (already logged elsewhere) + data.pop("key", None) + + # normalize values + for k, v in data.items(): + if isinstance(v, datetime): + data[k] = v.isoformat() + elif isinstance(v, Exception): + data[k] = str(v) + + # move received_at to the end + items = [(k, v) for k, v in data.items() if k != "received_at"] + if "received_at" in data: + items.append(("received_at", data["received_at"])) + + fields = ", ".join( + f"{k}={v}" if isinstance(v, str) and "T" in v else f"{k}={repr(v)}" + for k, v in items + ) + + return f"{self.__class__.__name__}({fields})" + + +class ClientInternalEvent(WsEvent): + key: str = 'CLIENT_INTERNAL' + + +class WsOpen(ClientInternalEvent): + ... + + +class WsAuthenticated(ClientInternalEvent): + ... + + +class WsReady(ClientInternalEvent): + ... + + +class WsReconnect(ClientInternalEvent): + ... + + +class WsClose(ClientInternalEvent): + close_status_code: int | None + close_msg: str | None + + +class WsError(ClientInternalEvent): + model_config = ConfigDict(frozen=True, extra="forbid", arbitrary_types_allowed=True) + error: Exception + + +class WsCritical(ClientInternalEvent): + model_config = ConfigDict(frozen=True, extra="forbid", arbitrary_types_allowed=True) + exception: Exception + + +# ============= +# == Sinks == +# ============= + +class EventSink(Protocol): + def emit(self, event: "WsEvent") -> None: + ... + + +class LogSink: + def emit(self, event: WsEvent) -> None: + _LOGGER.debug(f'{event.key}: {str(event)}') + + +class CallbackSink: + def __init__(self): + self._callbacks: dict[type[WsEvent], list[Callable[[WsEvent], None]]] = defaultdict(list) + + def on(self, event_type: type[WsEvent], callback: Callable[[WsEvent], None]) -> None: + self._callbacks[event_type].append(callback) + + def emit(self, event: WsEvent) -> None: + for callback in self._callbacks[type(event)]: + callback(event) + + +class QueueSink: + def __init__(self, queue_controller: QueueController): + self._queue_controller = queue_controller + + def emit(self, event: WsEvent) -> None: + self._queue_controller.put_to_queue(event.key, event) + + +class CompositeSink: + def __init__(self, *sinks: EventSink): + self._sinks = sinks + + def emit(self, event: WsEvent) -> None: + for sink in self._sinks: + sink.emit(event) + + +# ============== +# == Router == +# ============== + +class Router(Protocol): + def route(self, raw_message) -> OneOrMany[WsEvent]: + ... + + def __str__(self): + return f'{self.__class__.__qualname__}()' \ No newline at end of file diff --git a/ibind/ws_v2/subscription_controller.py b/ibind/ws_v2/subscription_controller.py new file mode 100644 index 00000000..00254d15 --- /dev/null +++ b/ibind/ws_v2/subscription_controller.py @@ -0,0 +1,235 @@ +import copy +import time +from enum import Enum +from typing import Dict, Optional, Callable + +from pydantic import BaseModel, ConfigDict + +from ibind.support.logs import project_logger +from ibind.support.py_utils import TimeoutLock, exception_to_string + +_LOGGER = project_logger(__file__) + + +class Subscription(BaseModel): + model_config = ConfigDict(frozen=True) + + @property + def key(self) -> str: + raise NotImplementedError + + @property + def topic(self) -> str: + raise NotImplementedError + + def subscribe_payload(self) -> str: + raise NotImplementedError + + def unsubscribe_payload(self) -> str: + raise NotImplementedError + + @property + def confirms_subscribe(self) -> bool: + return True + + @property + def confirms_unsubscribe(self) -> bool: + return False + + def make_hash(self): + return self.subscribe_payload() + + def __hash__(self): + if hasattr(self, '_hash'): + return self._hash + _hash = hash(self.make_hash()) + setattr(self, '_hash', _hash) + return _hash + + def __str__(self): + return f'{self.__class__.__qualname__}({self.make_hash()})' + + +class BindingStatus(Enum): + NEW = "NEW" + PENDING = "PENDING" + ACTIVE = "ACTIVE" + FAILED = "FAILED" + DEGRADED = "DEGRADED" + UNSUBSCRIBED = "UNSUBSCRIBED" + RECONNECTING = "RECONNECTING" + + +class Binding(BaseModel): + subscription: Subscription + intent: BindingStatus + status: BindingStatus = BindingStatus.NEW + attempts: int = 0 + last_attempt: float = 0 + + +class SubscriptionController: + """ + Mixin which manages subscriptions to different channels using the WsClient. + + This class handles the logic for subscribing and unsubscribing to various channels. It maintains a + record of active subscriptions and provides methods to modify them. The class relies on a + SubscriptionProcessor to create subscription and unsubscription payloads. + + Constructor Parameters: + subscription_processor (SubscriptionProcessor): The processor to create subscription payloads. + subscription_retries (int, optional): The number of retries for subscription requests. Defaults to 5. + subscription_timeout (float, optional): The timeout in seconds for subscription requests. Defaults to 2. + """ + + def __init__( + self, + send_payload: Callable[[str], bool], + subscription_retries: int = 5, + subscription_timeout: float = 2, + ): + self._send_payload = send_payload + self._subscription_retries = subscription_retries + self._subscription_timeout = subscription_timeout + + # self._subscriptions: Dict[str, dict] = {} + self._bindings: Dict[Subscription, Binding] = {} + self._operational_lock = TimeoutLock(60) + + def _send(self, payload) -> bool: + try: + success = self._send_payload(payload) + if not success: + _LOGGER.info(f'{self}: Sending payload unsuccessful: {payload}') + return success + except Exception as e: + _LOGGER.exception(f'{self}: Exception sending payload: {payload}\n{exception_to_string(e)}') + return False + + def parse_binding(self, subscription: Subscription, binding: Binding): + if binding.status == binding.intent: + return + + if binding.last_attempt + self._subscription_timeout > time.time(): + return + + binding.last_attempt = time.time() + + if binding.attempts >= self._subscription_retries: + _LOGGER.info(f'{self}: Subscription failed after {self._subscription_retries} attempts: {subscription}') + binding.status = BindingStatus.FAILED + binding.attempts = 0 + return + + if binding.intent == BindingStatus.ACTIVE: + payload = binding.subscription.subscribe_payload() + self._send(payload) + if not subscription.confirms_subscribe: + _LOGGER.info(f'{self}: Subscribed: {payload} without confirmation.') + self.set_subscription_active(subscription) + elif binding.intent == BindingStatus.UNSUBSCRIBED: + payload = binding.subscription.unsubscribe_payload() + self._send(payload) + if not subscription.confirms_unsubscribe: + _LOGGER.info(f'{self}: Unsubscribed: {payload} without confirmation.') + self.set_subscription_unsubscribed(subscription) + + def parse_bindings(self): + for subscription, binding in self._bindings.items(): + self.parse_binding(subscription, binding) + + def subscribe(self, subscription: Subscription) -> bool: + with self._operational_lock: + if self.is_subscription_active(subscription): # do nothing if subscription is present and active + return True + + # store a new binding + if not self.has_subscription(subscription): + self._bindings[subscription] = Binding(subscription=subscription, intent=BindingStatus.ACTIVE) + + def unsubscribe(self, subscription: Subscription) -> bool: + with self._operational_lock: + if not self.has_subscription(subscription): + binding = Binding(subscription=subscription, intent=BindingStatus.UNSUBSCRIBED) + self._bindings[subscription] = binding + else: + self._bindings[subscription].intent = BindingStatus.UNSUBSCRIBED + + # def invalidate_subscriptions(self): + # for channel in self._subscriptions: + # if self._subscriptions[channel].get('status', False): + # self._subscriptions[channel]['status'] = False + # _LOGGER.info(f'{self}: Invalidated subscription: {channel}') + + def invalidate_subscriptions(self): + for subscription, binding in self._bindings.items(): + if binding.status == BindingStatus.ACTIVE: + binding.status = BindingStatus.DEGRADED + _LOGGER.info(f'{self}: Invalidated subscription: {subscription}') + + # def is_subscription_active(self, channel: str) -> Optional[bool]: # pragma: no cover + # return self._subscriptions.get(channel, {}).get('status', None) + + def is_subscription_active(self, subscription: Subscription) -> Optional[bool]: # pragma: no cover + if not self.has_subscription(subscription): + return False + return self._bindings.get(subscription).status == BindingStatus.ACTIVE + + # def has_active_subscriptions(self) -> bool: # pragma: no cover + # for channel in self._subscriptions: + # if self.is_subscription_active(channel): + # return True + # return False + + def has_active_subscriptions(self) -> bool: # pragma: no cover + for subscription in self._bindings: + if self.is_subscription_active(subscription): + return True + return False + + # def has_subscription(self, channel: str) -> bool: # pragma: no cover + # return channel in self._subscriptions + + def has_subscription(self, subscription: Subscription) -> bool: # pragma: no cover + return subscription in self._bindings + + # def get_active_subscriptions(self): + # return {channel: copy.deepcopy(subscription) for channel, subscription in self._subscriptions.items() if self.is_subscription_active(channel)} + + def get_active_subscriptions(self): + return { + subscription: copy.deepcopy(binding) + for subscription, binding in self._bindings.items() + if self.is_subscription_active(subscription) + } + + def set_subscription_active(self, subscription: Subscription): + if not self.has_subscription(subscription): + _LOGGER.warning(f'{self}: Unknown subscription {subscription} - cannot update status to {BindingStatus.ACTIVE.value}') + return + + binding = self._bindings[subscription] + + if binding.status == BindingStatus.ACTIVE or binding.intent == BindingStatus.UNSUBSCRIBED: + return + + binding.status = BindingStatus.ACTIVE + binding.attempts = 0 + _LOGGER.info(f'{self}: Updated subscription status: {subscription} -> {BindingStatus.ACTIVE.value}') + + def set_subscription_unsubscribed(self, subscription: Subscription): + if not self.has_subscription(subscription): + _LOGGER.warning(f'{self}: Unknown subscription {subscription} - cannot update status to {BindingStatus.UNSUBSCRIBED.value}') + return + + binding = self._bindings[subscription] + + if binding.status == BindingStatus.UNSUBSCRIBED or binding.intent == BindingStatus.ACTIVE: + return + + binding.status = BindingStatus.UNSUBSCRIBED + binding.attempts = 0 + _LOGGER.info(f'{self}: Updated subscription status: {subscription} -> {BindingStatus.UNSUBSCRIBED.value}') + + def __str__(self): + return f'{self.__class__.__qualname__}()' \ No newline at end of file diff --git a/ibind/ws_v2/ws_runtime.py b/ibind/ws_v2/ws_runtime.py new file mode 100644 index 00000000..1276200c --- /dev/null +++ b/ibind/ws_v2/ws_runtime.py @@ -0,0 +1,380 @@ +import json +import ssl +import threading +from enum import Enum +from pathlib import Path +from queue import Queue +from threading import Thread, Event +from typing import Union, List, Dict, Callable, Literal + +from websocket import WebSocketApp, STATUS_UNEXPECTED_CONDITION + +from support.logs import project_logger +from support.py_utils import wait_until, tname, VerboseEnum, exception_to_string +from ws_v2 import events +from ws_v2.events import WsOpen, WsEvent, EventSink, Router +from ws_v2.subscription_controller import SubscriptionController +from ws_v2.ws_transport import WsTransport, TransportEvent, TransportOpened, TransportClosed, TransportError, TransportMessage, TransportCritical, TransportReconnect + +_LOGGER = project_logger(__file__) + +_NOOP = lambda: None + +_DEFAULT_TIMEOUT = 5 + + +class WsState(VerboseEnum): + STOPPED = 'STOPPED', + STARTING = 'STARTING', + CONNECTING = 'CONNECTING', + OPEN = 'OPEN', + AUTHENTICATED = 'AUTHENTICATED', + CLOSED = 'CLOSED', + DEGRADED = 'DEGRADED', + RECONNECTING = 'RECONNECTING', + STOPPING = 'STOPPING', + + +class WsRuntime(): + def __init__( + self, + url: str, + cycle_interval: float, + sink:EventSink, + router: Router, + ready_state: Literal[WsState.OPEN, WsState.AUTHENTICATED] = WsState.OPEN, + cacert: Union[str, bool] = False, + connection_timeout: float = _DEFAULT_TIMEOUT, + restart_on_close: bool = True, + restart_on_critical: bool = True, + get_cookie: Callable = _NOOP, + get_header: Callable = _NOOP + ): + self._url = url + self._cycle_interval = cycle_interval + self._sink = sink + self._router = router + self._ready_state = ready_state + self._connection_timeout = connection_timeout + self._restart_on_close = restart_on_close + self._restart_on_critical = restart_on_critical + + self._state = WsState.STOPPED + self._authenticated = False + + self._transport_thread = None + self._runtime_thread = None + self._transport_queue = Queue() + self._wait_event = Event() + + if not (cacert is False or Path(cacert).exists()): + raise ValueError(f'{self}: cacert must be a valid Path or False') + + if cacert is None or not cacert: + sslopt = {'cert_reqs': ssl.CERT_NONE} + else: + sslopt = {'ca_certs': cacert} + + + self._transport = WsTransport( + url=url, + event_callback=self._transport_callback, + sslopt=sslopt, + get_cookie=get_cookie, + get_header=get_header, + ) + + self.subscription_controller = SubscriptionController(send_payload=self.send) + + @property + def state(self): + _LOGGER.debug(f'{self}: State: {self._state.value}') + return self._state + + @state.setter + def state(self, value): + _LOGGER.debug(f'{self}: {self._state.value} -> {value.value}') + self._state = value + if self._state == self._ready_state: + self._sink.emit(events.WsReady()) + + def set_authenticated(self, value:bool): + if value != self._authenticated: + _LOGGER.debug(f'{self}: Authenticated: {value}') + self._authenticated = value + + if value and self._state == WsState.OPEN: + self._sink.emit(events.WsAuthenticated()) + self.state = WsState.AUTHENTICATED + + if value == False: + self.subscription_controller.invalidate_subscriptions() + + def get_authenticated(self) -> bool: + return self._authenticated + + + def _new_transport_thread(self): + self._transport_thread = Thread(target=self._transport.connect, name='ws_transport_thread') + self._transport_thread.daemon = True + self._transport_thread.start() + + def _new_runtime_thread(self): + self._runtime_thread = Thread(target=self._cycle, name='ws_runtime_thread') + self._runtime_thread.daemon = True + self._runtime_thread.start() + + def start(self): + if self.state != WsState.STOPPED: + return + + if self._runtime_thread is not None and self._runtime_thread.is_alive(): + _LOGGER.error(f'{self}: Runtime thread is not stopped') + return + + self.state = WsState.STARTING + self._running = True + + self._new_runtime_thread() + + connection_success = wait_until(lambda: self._state == self._ready_state, f'{self}: Starting timeout', timeout=self._connection_timeout) + return connection_success + + def stop(self): + if self.state == WsState.STOPPED: + return + + # wait until one more pass of the runtime thread has occurred to allow unsubscriptions to complete + wait_until(lambda: not self._wait_event.is_set(), timeout=self._connection_timeout) + self._wait_event.set() + wait_until(lambda: not self._wait_event.is_set(), timeout=self._connection_timeout) + + # TODO: decide which thread should stop first - transport or runtime + self.state = WsState.STOPPING + try: + self._transport.disconnect() + self._transport_thread.join(self._connection_timeout) + except Exception as e: + _LOGGER.error(f'{self}: Failed to disconnect: {e}') + # TODO: decide what to do if transport disconnect fails + + self._running = False + self._runtime_thread.join(self._connection_timeout) + + self.state = WsState.STOPPED + + def send(self, payload: str) -> bool: + if self._state != self._ready_state: + _LOGGER.error(f'{self}: State must be {self._ready_state.value} before sending payloads, found {self._state.value}') + return False + + _LOGGER.debug(f'{self}: Sending payload: {payload}') + + return self._transport.send(payload) + + def send_json(self, payload: Union[List, Dict]) -> bool: # pragma: no cover + return self.send(json.dumps(payload)) + + def is_running(self) -> bool: + return self._running + + def __str__(self): + return f'{self.__class__.__qualname__}({self._state})' + + # ====================== + # == Transport Thread == + # ====================== + + def _transport_callback(self, te: TransportEvent): + # _LOGGER.debug(f'{self}: {te}') + self._transport_queue.put(te) + self._wait_event.set() + + # ====================== + # == Runtime Thread == + # ====================== + + def _maintain_transport(self): + # Don't maintain the transport thread if we are stopping + if self._state == WsState.STOPPING: + return + + if self._transport_thread is None or not self._transport_thread.is_alive(): + _LOGGER.debug(f'{self}: Starting new transport thread') + self.state = WsState.CONNECTING + self._new_transport_thread() + + def _maintain_subscriptions(self): + if self._state != self._ready_state: + return + + self.subscription_controller.parse_bindings() + + def _cycle(self): + _LOGGER.debug(f'{self}: Runtime thread started ({tname()})') + while self._running: + self._maintain_transport() + self._maintain_subscriptions() + + self.process_transport_queue() + + self._wait_event.clear() + self._wait_event.wait(self._cycle_interval) + + # final pass through the router queue to flush any remaining events + self.process_transport_queue() + # final pass through the subscription controller to carry out final unsubscribe events + self.subscription_controller.parse_bindings() + _LOGGER.debug(f'{self}: Runtime thread stopped ({tname()})') + + def process_transport_queue(self): + while not self._transport_queue.empty(): + te = self._transport_queue.get() + try: + self._handle_transport_event(te) + except Exception as e: + _LOGGER.error(f'{self}: Exception processing transport event: {exception_to_string(e)} for {te}') + + def _handle_transport_event(self, te: TransportEvent): + if isinstance(te, TransportOpened): + self._handle_on_open(te.wsa) + elif isinstance(te, TransportClosed): + self._handle_on_close(te.wsa, te.close_status_code, te.close_msg) + elif isinstance(te, TransportError): + self._handle_on_error(te.wsa, te.error) + elif isinstance(te, TransportMessage): + self._handle_on_message(te.wsa, te.message) + elif isinstance(te, TransportCritical): + self._handle_on_critical(te.wsa, te.exception) + elif isinstance(te, TransportReconnect): + self._handle_on_reconnect(te.wsa) + else: + _LOGGER.error(f'{self}: Unknown event type: {type(te)}: {te}') + + def _handle_on_message(self, wsa: WebSocketApp, message): # pragma: no cover + events = self._router.route(message) + + # Router decided to skip this message + if events is None: + return + + # Handle lists and individual events + if not isinstance(events, list) and isinstance(events, WsEvent): + events = [events] + + # Propagate events to the sink + for event in events: + try: + self._sink.emit(event) + except Exception as e: + _LOGGER.error(f'{self}: Exception propagating event: {exception_to_string(e)} for {event}') + + def _handle_on_open(self, wsa: WebSocketApp): + _LOGGER.info(f'{self}: Connection open') + self.state = WsState.OPEN ## connected = True + self._sink.emit(events.WsOpen()) + + def _handle_on_error(self, wsa: WebSocketApp, exception:Exception): # pragma: no cover + _LOGGER.error(f'{self}: on_error: {exception}') + if str(exception) in ['Connection to remote host was lost.', 'No connection could be made because the target machine actively refused it']: + self.state = WsState.DEGRADED + self._sink.emit(events.WsError(error=exception)) + + def _handle_on_reconnect(self, wsa: WebSocketApp): # pragma: no cover + _LOGGER.error(f'{self}: on_reconnect') + self.set_authenticated(False) + self.state = WsState.OPEN + self._sink.emit(events.WsReconnect()) + + def _handle_on_critical(self, wsa: WebSocketApp, exception): # pragma: no cover + self._sink.emit(events.WsCritical(exception=exception)) + if self._restart_on_critical: + # TODO: following comment is not true - no restarting in on_close takes place + # if restart_on_close is set, restarting will happen in on_close callback + self.hard_reset(restart=not self._restart_on_close) + + def _handle_on_close(self, wsa: WebSocketApp, close_status_code, close_msg): + _LOGGER.info(f'{self}: on_close') + self.subscription_controller.invalidate_subscriptions() + self._sink.emit(events.WsClose(close_status_code=close_status_code, close_msg=close_msg)) + # if we're not connected we shouldn't need to do anything + if self.state not in [self._ready_state, WsState.OPEN, WsState.STOPPING]: ## not self._connected: + _LOGGER.info(f'{self}: Unexpected on_close event while not open') + return + + if close_status_code is not None or close_msg is not None: # this means an error + try: + msg = close_msg.decode('utf-8') + except AttributeError: + msg = close_msg + + _LOGGER.error(f'{self}: on_close error: {close_status_code} | {msg}') + + else: # otherwise it's a close success confirmation + _LOGGER.info(f'{self}: Connection closed') + + if self.state == WsState.STOPPING: + _LOGGER.info(f'{self}: Gracefully closed') + + self.state = WsState.CLOSED ## self._connected = False + + # if not self._running: # if close happened due to shutting down, acknowledge and return + # _LOGGER.info(f'{self}: Gracefully closed') + # return + + + def hard_reset(self, restart: bool = False) -> None: + """ + Performs a hard reset of the WebSocket connection. + + This method forcefully closes the current WebSocketApp connection and optionally restarts it. It is + used to handle scenarios where the connection is unresponsive or encounters a critical error. + + This method cannot be called from the transport thread. + + Parameters: + restart (bool, optional): Specifies whether to restart the WebSocketApp connection after resetting. + Defaults to False. + + Note: + - Closes the current WebSocketApp connection, if any, and clears related resources. + - If the WebSocketApp is unresponsive or cannot be closed, it will be abandoned and the connection will be reset. + - If 'restart' is True, the method attempts to re-establish a new WebSocketApp connection after resetting. + """ + _LOGGER.info(f'{self}: Hard reset, {restart=}, {self._wsa is None=}') + + # we want the websocket closed before reconnecting + if self._wsa is not None: + if not self._connected: + # this means that we get a bad error before we could even get a connection confirmation + # which shouldn't really happen, but if it does the original WebSocketApp is bad + # so let's drop it anyway. + self._wsa = None + restart = True # since we've abandoned the WebSocketApp, let's ensure we restart + else: + _LOGGER.info(f'{self}: Hard reset is closing the WebSocketApp') + # check if current thread is the same as _transport_thread + if threading.current_thread() == self._transport_thread: + raise RuntimeError(f'{self}: Hard reset called from transport thread. Ensure it is started from a separate thread') + + self._wsa.close(status=STATUS_UNEXPECTED_CONDITION) + + # ensure the websocket is closed and abandoned + if not wait_until(lambda: self._wsa is None, f'{self}: Hard reset close timeout', timeout=self._timeout): + _LOGGER.warning(f'{self}: Abandoning current WebSocketApp that cannot be closed: {self._wsa}') + self._wsa = None + restart = True # since we've abandoned the WebSocketApp, let's ensure we restart + + # in some cases, closing the websocket will cause the restart elsewhere, therefore only closing it is enough + if restart: + _LOGGER.info(f'{self}: Forced restart') + self._reconnect() + + def _reconnect(self): + with self._reconnect_lock: + if self.state not in [WsState.OPEN, self._ready_state]: ## not self._has_active_connection(): + _LOGGER.info(f'{self}: Reconnecting') + self._try_connecting() + + if self._has_active_connection(): + self._on_reconnect() \ No newline at end of file diff --git a/ibind/ws_v2/ws_transport.py b/ibind/ws_v2/ws_transport.py new file mode 100644 index 00000000..a5da5362 --- /dev/null +++ b/ibind/ws_v2/ws_transport.py @@ -0,0 +1,164 @@ +from datetime import datetime +from typing import Callable, Any + +from pydantic import BaseModel, ConfigDict, Field +from websocket import WebSocketApp + +from support.logs import project_logger +from support.py_utils import exception_to_string, tname + +_LOGGER = project_logger(__file__) + +_NOOP = lambda: None + + +class TransportEvent(BaseModel): + model_config = ConfigDict(frozen=True, extra="forbid", arbitrary_types_allowed=True) + + received_at: datetime = Field(default_factory=datetime.now) + wsa: WebSocketApp + + def __str__(self): + return f'{self.__class__.__qualname__}()' + + +class TransportOpened(TransportEvent): + ... + + +class TransportClosed(TransportEvent): + close_status_code: int | None + close_msg: str | None + + +class TransportError(TransportEvent): + model_config = ConfigDict(frozen=True, extra="forbid", arbitrary_types_allowed=True) + error: Exception + + +class TransportMessage(TransportEvent): + message: str + +class TransportReconnect(TransportEvent): + ... + +class TransportCritical(TransportEvent): + model_config = ConfigDict(frozen=True, extra="forbid", arbitrary_types_allowed=True) + exception: Exception + + +class WsTransport(): + + def __init__( + self, + url: str, + event_callback: Callable, + sslopt: dict[str, Any], + get_cookie: Callable = _NOOP, + get_header: Callable = _NOOP, + ping_interval: float = 10, + ping_timeout: float = 10, + ): + self._url = url + self._event_callback = event_callback + self._get_cookie = get_cookie + self._get_header = get_header + self._ping_interval = ping_interval + self._ping_timeout = ping_timeout + self._sslopt = sslopt + + self._running = False + self._wsa: WebSocketApp | None = None + + def _wrap_callback(self, f): + def wrapped_f(ws, *args, **kwargs): + try: + f(ws, *args, **kwargs) + except Exception as e: + _LOGGER.exception(f'{self}: Exception executing callback: \n{f} \nwith\n{args=}\n{kwargs=}\n{str(e)}') + + return wrapped_f + + def _on_open(self, wsa: WebSocketApp): + self._event_callback(TransportOpened(wsa=wsa)) + + def _on_message(self, wsa: WebSocketApp, message): + self._event_callback(TransportMessage(wsa=wsa, message=message)) + + def _on_close(self, wsa: WebSocketApp, close_status_code, close_msg): + self._event_callback(TransportClosed(wsa=wsa, close_status_code=close_status_code, close_msg=close_msg)) + + def _on_error(self, wsa: WebSocketApp, error): + self._event_callback(TransportError(wsa=wsa, error=error)) + + def _on_reconnect(self, wsa: WebSocketApp): + self._event_callback(TransportReconnect(wsa=wsa)) + + def new_wsa(self): + try: + cookie = self._get_cookie() + except Exception as e: + _LOGGER.error(f'{self}: Failed to retrieve cookie: {exception_to_string(e)}') + cookie = None + + try: + header = self._get_header() + except Exception as e: + _LOGGER.error(f'{self}: Failed to retrieve header: {exception_to_string(e)}') + header = None + + wsa = WebSocketApp( + url=self._url, + on_open=self._wrap_callback(self._on_open), + on_message=self._wrap_callback(self._on_message), + on_close=self._wrap_callback(self._on_close), + on_error=self._wrap_callback(self._on_error), + on_reconnect=self._wrap_callback(self._on_reconnect), + cookie=cookie, + header=header, + ) + + self._wsa = wsa + + def send(self, payload: str) -> bool: + if not self._wsa.ready: + raise RuntimeError(f'{self}: WSA socket is not ready') + + try: + self._wsa.send(payload) + except Exception as e: + if 'Connection is already closed' in str(e): + _LOGGER.error(f'{self}: Connection closed while sending payload: {payload}') + else: + _LOGGER.exception(f'{self}: Sending payload failed: {payload}\n{exception_to_string(e)}') + return False + + return True + + def connect(self): + _LOGGER.debug(f'{self}: Transport thread started ({tname()})') + + if self._wsa is None: + self.new_wsa() + + try: + # the timeout is set to a little sooner than the interval + self._wsa.run_forever(ping_interval=self._ping_interval, ping_timeout=self._ping_interval * 0.95, sslopt=self._sslopt, reconnect=3) + + except ValueError as e: + if 'url is invalid' in str(e): + _LOGGER.error(f'{self}: URL is invalid: {self._url}') + except Exception as e: + _LOGGER.exception(f'{self}: Unexpected error while running WebSocketApp: {e}') + self._event_callback(TransportCritical(wsa=self._wsa, exception=e)) + + _LOGGER.debug(f'{self}: Transport thread stopped ({tname()})') + + # if self._restart_on_close and self._running: + # self._reconnect() + + def disconnect(self): + self._wsa.close() + + def __str__(self): + return f'{self.__class__.__qualname__}()' \ No newline at end of file From ed194563520e77568355d39fb9b5a962cc5de796 Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 29 Apr 2026 18:17:33 +0200 Subject: [PATCH 02/32] chore: updated to requests>=2.33 and added pydantic>=2.13 --- requirements.txt | Bin 76 -> 108 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/requirements.txt b/requirements.txt index 786c8ffd058f76f80446bcee43eff24e7ad9b298..008e657ce57c637ec72ec984a62926c9573234e3 100644 GIT binary patch delta 42 xcmeatnIOYxJW*C(ftP`cp@5;1A%!84A&;SiA(J7Q!H&U}!H7YR!H~h20RXkm2S@+_ delta 9 Qcmd1tnIOYxI8jy~01dJNbN~PV From d263c0fc4343ef3dfbb446e3c7463c98c20acddb Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 29 Apr 2026 18:18:02 +0200 Subject: [PATCH 03/32] refactor(queue_controller): generic T is bound to Hashable instead of expecting only str and Enum --- ibind/base/queue_controller.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ibind/base/queue_controller.py b/ibind/base/queue_controller.py index 62b69e0f..29d0e35e 100644 --- a/ibind/base/queue_controller.py +++ b/ibind/base/queue_controller.py @@ -1,10 +1,9 @@ -from enum import Enum from queue import Queue, Empty -from typing import TypeVar, Generic, Any +from typing import TypeVar, Generic, Any, Hashable from ibind.support.py_utils import ensure_list_arg, OneOrMany -T = TypeVar('T', str, Enum) +T = TypeVar('T', bound=Hashable) class QueueAccessor(Generic[T]): # pragma: no cover @@ -147,4 +146,4 @@ def put_to_queue(self, key: T, data): AttributeError: If no queue exists for the given key. """ queue = self.get_queue(key) - queue.put(data) + queue.put(data) \ No newline at end of file From 3d8fbe84903ed5c7035c43950b5ea3227386a359 Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 29 Apr 2026 20:27:15 +0200 Subject: [PATCH 04/32] refactor(ws_v2): added binding_key (to Subscriptions and resolution) and changed subscription_controller._bindings key from Subscription to new the binding_key. - added SubscriptionResolver which allow SubscriptionController to automatically detect binding_keys that need confirmation on (un)subscriptions - finished implementing ibkr_subscriptions --- examples/ws_04_ws_v2.py | 34 ++-- ibind/ibkr_ws_v2/ibkr_events.py | 19 ++ ibind/ibkr_ws_v2/ibkr_router.py | 30 ++- ibind/ibkr_ws_v2/ibkr_subscriptions.py | 256 ++++++++++++++----------- ibind/ibkr_ws_v2/ibkr_ws_client_v2.py | 32 +--- ibind/ws_v2/subscription_controller.py | 165 ++++++++-------- ibind/ws_v2/ws_runtime.py | 40 ++-- 7 files changed, 334 insertions(+), 242 deletions(-) diff --git a/examples/ws_04_ws_v2.py b/examples/ws_04_ws_v2.py index 927eba50..a35be319 100644 --- a/examples/ws_04_ws_v2.py +++ b/examples/ws_04_ws_v2.py @@ -11,11 +11,10 @@ """ import os -import signal import time -from ibind import IbkrWsKey, IbkrWsClient, ibind_logs_initialize -from ibkr_ws_v2.ibkr_subscriptions import MarketDataSubscription, OrdersSubscription, AccountLedgerSubscription +from ibind import ibind_logs_initialize +from ibkr_ws_v2.ibkr_subscriptions import MarketDataSubscription, OrdersSubscription, AccountLedgerSubscription, AccountSummarySubscription, PriceLadderSubscription, PnlSubscription, TradesSubscription from ibkr_ws_v2.ibkr_ws_client_v2 import IbkrWsClientV2 ibind_logs_initialize(log_to_file=False, log_level='DEBUG') @@ -37,13 +36,24 @@ ws_client.start() -md_sub = MarketDataSubscription(conid='265598') -or_sub = OrdersSubscription() +as_sub = AccountSummarySubscription(account_id=account_id) al_sub = AccountLedgerSubscription(account_id=account_id) - -ws_client.subscribe(md_sub) -ws_client.subscribe(or_sub) -ws_client.subscribe(al_sub) +md_sub = MarketDataSubscription(conid='265598', fields=("31", "84", "86")) +or_sub = OrdersSubscription() +# pl_sub = PriceLadderSubscription(conid='265598', account_id=account_id, exchange='SMART') +pnl_sub = PnlSubscription() +tr_sub = TradesSubscription() +subs = [ + as_sub, + al_sub, + md_sub, + or_sub, + pnl_sub, + tr_sub +] + +for sub in subs: + ws_client.subscribe(sub) try: while ws_client.is_running(): @@ -51,9 +61,9 @@ except KeyboardInterrupt: print('Interrupt') -ws_client.unsubscribe(md_sub) -ws_client.unsubscribe(or_sub) -ws_client.unsubscribe(al_sub) +for sub in subs: + ws_client.unsubscribe(sub) +# time.sleep(5) ws_client.shutdown() # requests = [ diff --git a/ibind/ibkr_ws_v2/ibkr_events.py b/ibind/ibkr_ws_v2/ibkr_events.py index 0c4b6d26..15e09863 100644 --- a/ibind/ibkr_ws_v2/ibkr_events.py +++ b/ibind/ibkr_ws_v2/ibkr_events.py @@ -48,6 +48,24 @@ def from_channel(cls, channel): raise ValueError(f"No enum member associated with channel '{channel}'") + @property + def channel(self): + """ + Gets the solicited channel string associated with the enum member. + + Returns: + str: The channel string corresponding to the enum member. + """ + return { + IbkrWsKey.ACCOUNT_SUMMARY: 'sd', + IbkrWsKey.ACCOUNT_LEDGER: 'ld', + IbkrWsKey.MARKET_DATA: 'md', + IbkrWsKey.MARKET_HISTORY: 'mh', + IbkrWsKey.PRICE_LADDER: 'bd', + IbkrWsKey.ORDERS: 'or', + IbkrWsKey.PNL: 'pl', + IbkrWsKey.TRADES: 'tr', + }[self] class ParsedIbkrMessage(WsEvent): key: str = IbkrWsKey.UNCLASSIFIED @@ -110,6 +128,7 @@ class Unsubscription(WsEvent): class AccountSummary(WsEvent): key: IbkrWsKey = IbkrWsKey.ACCOUNT_SUMMARY + account_id: str data: dict diff --git a/ibind/ibkr_ws_v2/ibkr_router.py b/ibind/ibkr_ws_v2/ibkr_router.py index 36fe7b64..05b0b1c2 100644 --- a/ibind/ibkr_ws_v2/ibkr_router.py +++ b/ibind/ibkr_ws_v2/ibkr_router.py @@ -68,7 +68,7 @@ def _preprocess_market_history_message(self, data: dict) -> OneOrMany[WsEvent]: return ibkr_events.MarketHistory(conid=str(data['conid']), data=data) - def _preprocess_account_leger(self, data): + def _preprocess_account_ledger(self, data): events = [] for entry in data['result']: if 'acctCode' not in entry: @@ -77,6 +77,30 @@ def _preprocess_account_leger(self, data): events.append(event) return events + def _preprocess_account_summary(self, data): + summary = {} + timestamp = data['result'][0]['timestamp'] + for entry in data['result']: + key = entry.pop('key') + entry.pop('timestamp') + + if entry == {}: + continue + + summary[key] = entry + + if summary == {}: + return [] + + if 'AccountCode' not in summary or 'value' not in summary['AccountCode']: + _LOGGER.error(f'{self}: Account code not found in account summary: {summary}') + return [] + account_id = summary['AccountCode']['value'] + summary['timestamp'] = timestamp + + event = ibkr_events.AccountSummary(data=summary, account_id=account_id) + return event + def _handle_subscribed_message(self, channel: str, data: dict) -> OneOrMany[WsEvent] | None: try: ibkr_ws_key = IbkrWsKey.from_channel(channel[:2]) @@ -85,9 +109,9 @@ def _handle_subscribed_message(self, channel: str, data: dict) -> OneOrMany[WsEv return None if ibkr_ws_key == IbkrWsKey.ACCOUNT_SUMMARY: - return ibkr_events.AccountSummary(data=data) + return self._preprocess_account_summary(data) elif ibkr_ws_key == IbkrWsKey.ACCOUNT_LEDGER: - return self._preprocess_account_leger(data) + return self._preprocess_account_ledger(data) elif ibkr_ws_key == IbkrWsKey.MARKET_DATA: return self._preprocess_market_data_message(data) elif ibkr_ws_key == IbkrWsKey.MARKET_HISTORY: diff --git a/ibind/ibkr_ws_v2/ibkr_subscriptions.py b/ibind/ibkr_ws_v2/ibkr_subscriptions.py index a7f4abca..2e6df665 100644 --- a/ibind/ibkr_ws_v2/ibkr_subscriptions.py +++ b/ibind/ibkr_ws_v2/ibkr_subscriptions.py @@ -1,86 +1,132 @@ import json -from typing import Literal, Any - -from ibkr_ws_v2.ibkr_events import IbkrWsKey, AccountSummary, AccountLedger, MarketData, MarketHistory, Orders, PriceLadder, Pnl, Trades -from ws_v2.events import WsEvent -from ws_v2.subscription_controller import Subscription - - -def ibkr_payload(op: Literal['s', 'u'], topic: str, target: str | None = None, data: dict[str, Any] | None = None) -> str: - payload = f"{op}{topic}" - if target is not None: - payload += f"+{target}" - if data is not None: - payload += f"+{json.dumps(data, separators=(',', ':'))}" - return payload - - -class AccountSummarySubscription(Subscription): - @property - def key(self) -> IbkrWsKey: - return IbkrWsKey.ACCOUNT_SUMMARY +from typing import Tuple + +from ibkr_ws_v2.ibkr_events import IbkrWsKey, AccountLedger, MarketData, MarketHistory, Orders, PriceLadder, Pnl, Trades, Unsubscription, AccountSummary +from ws_v2.subscription_controller import Subscription, SubscriptionResolver + + +def make_binding_key( + key: IbkrWsKey, + conid: str = None, + account_id=None, + exchange=None +): + if key in [IbkrWsKey.MARKET_DATA, IbkrWsKey.MARKET_HISTORY]: + return f"{key.channel}+{conid}" + elif key in [IbkrWsKey.ACCOUNT_LEDGER, IbkrWsKey.ACCOUNT_SUMMARY]: + return f"{key.channel}+{account_id}" + elif key in [IbkrWsKey.PRICE_LADDER]: + return f"{key.channel}+{account_id}+{conid}" + (f"+{exchange}" if exchange is not None else '') + elif key in [IbkrWsKey.ORDERS, IbkrWsKey.PNL, IbkrWsKey.TRADES]: + return key.channel + else: + raise ValueError(f'Unsupported key: {key}') + + +class IbkrSubscriptionResolver(SubscriptionResolver): + _register = [ + MarketData, + AccountSummary, + AccountLedger, + MarketHistory, + Orders, + PriceLadder, + Pnl, + Trades, + Unsubscription + ] + + def __init__(self, account_id): + self._account_id = account_id + + def _resolve_subscribing_event(self, event) -> str: + if event.key in [IbkrWsKey.MARKET_DATA, IbkrWsKey.MARKET_HISTORY]: + return make_binding_key(event.key, conid=event.conid) + elif event.key in [IbkrWsKey.ACCOUNT_LEDGER, IbkrWsKey.ACCOUNT_SUMMARY]: + return make_binding_key(event.key, account_id=event.account_id) + elif event.key in [IbkrWsKey.PRICE_LADDER]: + return make_binding_key(event.key, conid=event.conid, account_id=event.account_id, exchange=event.exchange) + elif event.key in [IbkrWsKey.ORDERS, IbkrWsKey.PNL, IbkrWsKey.TRADES]: + return make_binding_key(event.key) + else: + raise ValueError(f'Unsupported event: {event}') + + def _resolve_unsubscribing_event(self, event) -> str: + return make_binding_key(event.target_key, event.conid, self._account_id) + + def resolve_binding_key(self, event) -> Tuple[bool, str] | Tuple[None, None]: + if type(event) not in self._register: + return None, None + + if isinstance(event, Unsubscription): + return False, self._resolve_unsubscribing_event(event) + else: + return True, self._resolve_subscribing_event(event) + + +class IbkrSubscription(Subscription): + key: IbkrWsKey @property def topic(self) -> str: - return '' + return self.key.channel + + +class AccountSummarySubscription(IbkrSubscription): + key: IbkrWsKey = IbkrWsKey.ACCOUNT_SUMMARY + account_id: str def subscribe_payload(self) -> str: - ... + return f'ssd+{self.account_id}' def unsubscribe_payload(self) -> str: - ... + return f'usd+{self.account_id}' + @property def confirms_subscribe(self) -> bool: return True + @property def confirms_unsubscribe(self) -> bool: return True + def binding_key(self): + return make_binding_key(self.key, account_id=self.account_id) -class AccountLedgerSubscription(Subscription): - account_id: str - @property - def key(self) -> IbkrWsKey: - return IbkrWsKey.ACCOUNT_LEDGER - @property - def topic(self) -> str: - return 'ld' +class AccountLedgerSubscription(IbkrSubscription): + key: IbkrWsKey = IbkrWsKey.ACCOUNT_LEDGER + account_id: str def subscribe_payload(self) -> str: - return ibkr_payload("s", "ld", self.account_id) - + return f'sld+{self.account_id}' def unsubscribe_payload(self) -> str: - return ibkr_payload("u", "ld", self.account_id) - + return f'uld+{self.account_id}' + @property def confirms_subscribe(self) -> bool: return True + @property def confirms_unsubscribe(self) -> bool: return True - def make_hash(self): - return self.topic + "+" + self.account_id + def binding_key(self): + return make_binding_key(self.key, account_id=self.account_id) -class MarketDataSubscription(Subscription): - conid: str - fields: tuple[str, ...] = ("31", "84", "86") - @property - def key(self) -> IbkrWsKey: - return IbkrWsKey.MARKET_DATA - - @property - def topic(self) -> str: - return "md" +class MarketDataSubscription(IbkrSubscription): + key: IbkrWsKey = IbkrWsKey.MARKET_DATA + conid: str + fields: tuple[str, ...] def subscribe_payload(self) -> str: - return ibkr_payload("s", "md", self.conid, {"fields": list(self.fields)}) + fields_str = json.dumps({"fields": list(self.fields)}, separators=(',', ':')) + return f'smd+{self.conid}+{fields_str}' def unsubscribe_payload(self) -> str: - return ibkr_payload("u", "md", self.conid, {}) + return f'umd+{self.conid}+{{}}' @property def confirms_subscribe(self) -> bool: @@ -90,48 +136,45 @@ def confirms_subscribe(self) -> bool: def confirms_unsubscribe(self) -> bool: return False - def make_hash(self): - return self.topic + "+" + self.conid + def binding_key(self): + return make_binding_key(self.key, conid=self.conid) -class MarketHistorySubscription(Subscription): +class MarketHistorySubscription(IbkrSubscription): conid: str @property def key(self) -> IbkrWsKey: return IbkrWsKey.MARKET_HISTORY - @property - def topic(self) -> str: - return '' - def subscribe_payload(self) -> str: ... def unsubscribe_payload(self) -> str: ... + @property def confirms_subscribe(self) -> bool: return True + @property def confirms_unsubscribe(self) -> bool: return True + def binding_key(self): + return make_binding_key(self.key, conid=self.conid) -class OrdersSubscription(Subscription): - @property - def key(self) -> IbkrWsKey: - return IbkrWsKey.ORDERS - @property - def topic(self) -> str: - return "or" +class OrdersSubscription(IbkrSubscription): + key: IbkrWsKey = IbkrWsKey.ORDERS + filter: str = None def subscribe_payload(self) -> str: - return ibkr_payload("s", "or") + filter_str = f'{{"filters": ["{self.filter}"]}}' if self.filter is not None else '{}' + return f'sor+{filter_str}' def unsubscribe_payload(self) -> str: - return ibkr_payload("u", "or", data={}) + return 'uor+{}' @property def confirms_subscribe(self) -> bool: @@ -141,88 +184,79 @@ def confirms_subscribe(self) -> bool: def confirms_unsubscribe(self) -> bool: return False + def binding_key(self): + return make_binding_key(self.key) -class PriceLadderSubscription(Subscription): - @property - def key(self) -> IbkrWsKey: - return IbkrWsKey.PRICE_LADDER - @property - def topic(self) -> str: - return '' +class PriceLadderSubscription(IbkrSubscription): + key: IbkrWsKey = IbkrWsKey.PRICE_LADDER + conid: str + account_id: str + exchange: str def subscribe_payload(self) -> str: - ... + return f'sbd+{self.account_id}+{self.conid}+{self.exchange}' def unsubscribe_payload(self) -> str: - ... + return f'ubd+{self.account_id}' + @property def confirms_subscribe(self) -> bool: return False + @property def confirms_unsubscribe(self) -> bool: return False + def binding_key(self): + return make_binding_key(self.key, conid=self.conid, account_id=self.account_id, exchange=self.exchange) -class PnlSubscription(Subscription): - @property - def key(self) -> IbkrWsKey: - return IbkrWsKey.PNL - @property - def topic(self) -> str: - return '' +class PnlSubscription(IbkrSubscription): + key: IbkrWsKey = IbkrWsKey.PNL def subscribe_payload(self) -> str: - ... + return 'spl' def unsubscribe_payload(self) -> str: - ... + return 'upl' + @property def confirms_subscribe(self) -> bool: return True + @property def confirms_unsubscribe(self) -> bool: return False + def binding_key(self): + return make_binding_key(self.key) -class TradesSubscription(Subscription): - @property - def key(self) -> IbkrWsKey: - return IbkrWsKey.TRADES - @property - def topic(self) -> str: - return '' +class TradesSubscription(IbkrSubscription): + key: IbkrWsKey = IbkrWsKey.TRADES + realtime_updates_only: bool = False + days: int = 1 def subscribe_payload(self) -> str: - ... + extra = {} + if self.realtime_updates_only: + extra['realtime_updates_only'] = self.realtime_updates_only + if self.days: + extra['days'] = self.days + extra_str = json.dumps(extra, separators=(',', ':')) + return f'str+{extra_str}' def unsubscribe_payload(self) -> str: - ... + return 'utr' + @property def confirms_subscribe(self) -> bool: return True + @property def confirms_unsubscribe(self) -> bool: return False -def event_to_subscription(event:WsEvent): - if isinstance(event, AccountSummary): - return AccountSummarySubscription() - elif isinstance(event, AccountLedger): - return AccountLedgerSubscription(account_id=event.account_id) - elif isinstance(event, MarketData): - return MarketDataSubscription(conid=event.conid) - elif isinstance(event, MarketHistory): - return MarketHistorySubscription(conid=event.conid) - elif isinstance(event, Orders): - return OrdersSubscription() - elif isinstance(event, PriceLadder): - return PriceLadderSubscription() - elif isinstance(event, Pnl): - return PnlSubscription() - elif isinstance(event, Trades): - return TradesSubscription() - else: - raise ValueError(f'Unsupported event: {event}') \ No newline at end of file + def binding_key(self): + return make_binding_key(self.key) \ No newline at end of file diff --git a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py index ce8c527b..8f3e985e 100644 --- a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py +++ b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py @@ -5,11 +5,10 @@ from ibind import ExternalBrokerError, IbkrClient from ibkr_ws_v2 import ibkr_events from ibkr_ws_v2.ibkr_router import IbkrRouter -from ibkr_ws_v2.ibkr_subscriptions import event_to_subscription +from ibkr_ws_v2.ibkr_subscriptions import IbkrSubscriptionResolver from support.logs import project_logger -from ws_v2 import events from ws_v2.events import EventSink, LogSink, CallbackSink, CompositeSink, Router -from ws_v2.subscription_controller import Subscription +from ws_v2.subscription_controller import Subscription, SubscriptionResolver from ws_v2.ws_runtime import WsRuntime, WsState _LOGGER = project_logger(__file__) @@ -33,6 +32,7 @@ def __init__( recreate_subscriptions_on_reconnect: bool = True, sink: EventSink = None, router: Router = None, + subscription_resolver: SubscriptionResolver = None, ): self._account_id = account_id @@ -69,6 +69,9 @@ def __init__( if router is None: router = IbkrRouter() + if subscription_resolver is None: + subscription_resolver = IbkrSubscriptionResolver(account_id) + self._runtime = WsRuntime( url=url, cycle_interval=cycle_interval, @@ -76,27 +79,13 @@ def __init__( cacert=cacert, sink=sink, router=router, + subscription_resolver=subscription_resolver, get_cookie=self._get_cookie, get_header=self._get_header, ) - # self._subscription_controller = SubscriptionController( - # send_payload=self._runtime.send, - # is_running=self._runtime.is_running, - # ) - def _register_internal_callbacks(self): self._internal_sink.on(ibkr_events.AuthenticationStatus, self._on_authentication_status) - - self._internal_sink.on(ibkr_events.Unsubscription, self._on_unsubscription_confirmation) - - self._internal_sink.on(ibkr_events.AccountSummary, self._on_subscription_confirmation) - self._internal_sink.on(ibkr_events.AccountLedger, self._on_subscription_confirmation) - self._internal_sink.on(ibkr_events.MarketData, self._on_subscription_confirmation) - self._internal_sink.on(ibkr_events.MarketHistory, self._on_subscription_confirmation) - self._internal_sink.on(ibkr_events.Pnl, self._on_subscription_confirmation) - self._internal_sink.on(ibkr_events.Trades, self._on_subscription_confirmation) - self._internal_sink.on(ibkr_events.WaitingForSession, self._set_unauthenticated) def _set_unauthenticated(self, _): @@ -110,13 +99,6 @@ def _on_authentication_status(self, event: ibkr_events.AuthenticationStatus): self._runtime.set_authenticated(event.authenticated) - def _on_subscription_confirmation(self, event: events.WsEvent): - subscription = event_to_subscription(event) - self._runtime.subscription_controller.set_subscription_active(subscription) - - def _on_unsubscription_confirmation(self, event: events.WsEvent): - subscription = event_to_subscription(event) - self._runtime.subscription_controller.set_subscription_unsubscribed(subscription) def _get_cookie(self): try: diff --git a/ibind/ws_v2/subscription_controller.py b/ibind/ws_v2/subscription_controller.py index 00254d15..5bb52ae3 100644 --- a/ibind/ws_v2/subscription_controller.py +++ b/ibind/ws_v2/subscription_controller.py @@ -1,22 +1,20 @@ import copy import time from enum import Enum -from typing import Dict, Optional, Callable +from typing import Dict, Optional, Callable, Protocol, Tuple, Hashable, Literal from pydantic import BaseModel, ConfigDict from ibind.support.logs import project_logger from ibind.support.py_utils import TimeoutLock, exception_to_string +from ws_v2.events import WsEvent _LOGGER = project_logger(__file__) class Subscription(BaseModel): model_config = ConfigDict(frozen=True) - - @property - def key(self) -> str: - raise NotImplementedError + key: Hashable @property def topic(self) -> str: @@ -36,18 +34,23 @@ def confirms_subscribe(self) -> bool: def confirms_unsubscribe(self) -> bool: return False - def make_hash(self): + def binding_key(self): return self.subscribe_payload() def __hash__(self): if hasattr(self, '_hash'): return self._hash - _hash = hash(self.make_hash()) + _hash = hash(self.binding_key()) setattr(self, '_hash', _hash) return _hash def __str__(self): - return f'{self.__class__.__qualname__}({self.make_hash()})' + return f'{self.__class__.__qualname__}({self.binding_key()})' + + +class SubscriptionResolver(Protocol): + def resolve_binding_key(self, event) -> Tuple[bool, str]: + ... class BindingStatus(Enum): @@ -62,7 +65,7 @@ class BindingStatus(Enum): class Binding(BaseModel): subscription: Subscription - intent: BindingStatus + intent: Literal[BindingStatus.ACTIVE, BindingStatus.UNSUBSCRIBED] status: BindingStatus = BindingStatus.NEW attempts: int = 0 last_attempt: float = 0 @@ -85,15 +88,16 @@ class SubscriptionController: def __init__( self, send_payload: Callable[[str], bool], + subscription_resolver: SubscriptionResolver, subscription_retries: int = 5, subscription_timeout: float = 2, ): self._send_payload = send_payload + self._subscription_resolver = subscription_resolver self._subscription_retries = subscription_retries self._subscription_timeout = subscription_timeout - # self._subscriptions: Dict[str, dict] = {} - self._bindings: Dict[Subscription, Binding] = {} + self._bindings: Dict[str, Binding] = {} self._operational_lock = TimeoutLock(60) def _send(self, payload) -> bool: @@ -106,130 +110,141 @@ def _send(self, payload) -> bool: _LOGGER.exception(f'{self}: Exception sending payload: {payload}\n{exception_to_string(e)}') return False - def parse_binding(self, subscription: Subscription, binding: Binding): - if binding.status == binding.intent: + def observe(self, event: WsEvent): + is_active, binding_key = self._subscription_resolver.resolve_binding_key(event) + + # None means the event is not related to a tracked subscription + if binding_key is None: return + with self._operational_lock: + if not self.has_subscription(binding_key): + _LOGGER.warning(f'{self}: Observed a binding_key "{binding_key}" that is missing a subscription. Event: {event}') + return + + if is_active: + self._confirm_subscribed(binding_key) + else: + self._confirm_unsubscribed(binding_key) + + def parse_binding(self, binding: Binding): + # wait until timeout has passed since last attempt if binding.last_attempt + self._subscription_timeout > time.time(): return - binding.last_attempt = time.time() + # if we've exceeded the number of retries, mark the subscription as failed if binding.attempts >= self._subscription_retries: - _LOGGER.info(f'{self}: Subscription failed after {self._subscription_retries} attempts: {subscription}') + _LOGGER.info(f'{self}: Subscription failed after {self._subscription_retries} attempts: {binding}') binding.status = BindingStatus.FAILED binding.attempts = 0 return + subscription = binding.subscription + if binding.intent == BindingStatus.ACTIVE: - payload = binding.subscription.subscribe_payload() + payload = subscription.subscribe_payload() self._send(payload) if not subscription.confirms_subscribe: _LOGGER.info(f'{self}: Subscribed: {payload} without confirmation.') - self.set_subscription_active(subscription) + self._confirm_subscribed(subscription.binding_key()) + elif binding.intent == BindingStatus.UNSUBSCRIBED: - payload = binding.subscription.unsubscribe_payload() + payload = subscription.unsubscribe_payload() self._send(payload) if not subscription.confirms_unsubscribe: _LOGGER.info(f'{self}: Unsubscribed: {payload} without confirmation.') - self.set_subscription_unsubscribed(subscription) + self._confirm_unsubscribed(subscription.binding_key()) def parse_bindings(self): - for subscription, binding in self._bindings.items(): - self.parse_binding(subscription, binding) + with self._operational_lock: + for subscription, binding in self._bindings.items(): + if binding.status == binding.intent: + continue + + self.parse_binding(binding) def subscribe(self, subscription: Subscription) -> bool: with self._operational_lock: - if self.is_subscription_active(subscription): # do nothing if subscription is present and active + if self.is_subscription_active(subscription.binding_key()): # do nothing if subscription is present and active return True # store a new binding - if not self.has_subscription(subscription): - self._bindings[subscription] = Binding(subscription=subscription, intent=BindingStatus.ACTIVE) + if self.has_subscription(subscription.binding_key()): + return + + self._bindings[subscription.binding_key()] = Binding(subscription=subscription, intent=BindingStatus.ACTIVE) + _LOGGER.info(f'{self}: Registered subscription intent: {subscription.binding_key()}') def unsubscribe(self, subscription: Subscription) -> bool: with self._operational_lock: - if not self.has_subscription(subscription): - binding = Binding(subscription=subscription, intent=BindingStatus.UNSUBSCRIBED) - self._bindings[subscription] = binding + if self.has_subscription(subscription.binding_key()): + binding = self._bindings[subscription.binding_key()] + if binding.status == BindingStatus.UNSUBSCRIBED: + return + self._bindings[subscription.binding_key()].intent = BindingStatus.UNSUBSCRIBED else: - self._bindings[subscription].intent = BindingStatus.UNSUBSCRIBED - - # def invalidate_subscriptions(self): - # for channel in self._subscriptions: - # if self._subscriptions[channel].get('status', False): - # self._subscriptions[channel]['status'] = False - # _LOGGER.info(f'{self}: Invalidated subscription: {channel}') + binding = Binding(subscription=subscription, intent=BindingStatus.UNSUBSCRIBED) + self._bindings[subscription.binding_key()] = binding + _LOGGER.info(f'{self}: Registered unsubscription intent: {subscription.binding_key()}') def invalidate_subscriptions(self): - for subscription, binding in self._bindings.items(): + for binding_key, binding in self._bindings.items(): if binding.status == BindingStatus.ACTIVE: binding.status = BindingStatus.DEGRADED - _LOGGER.info(f'{self}: Invalidated subscription: {subscription}') - - # def is_subscription_active(self, channel: str) -> Optional[bool]: # pragma: no cover - # return self._subscriptions.get(channel, {}).get('status', None) + _LOGGER.info(f'{self}: Invalidated subscription: {binding}') - def is_subscription_active(self, subscription: Subscription) -> Optional[bool]: # pragma: no cover - if not self.has_subscription(subscription): + def is_subscription_active(self, binding_key: str) -> Optional[bool]: # pragma: no cover + if not self.has_subscription(binding_key): return False - return self._bindings.get(subscription).status == BindingStatus.ACTIVE - - # def has_active_subscriptions(self) -> bool: # pragma: no cover - # for channel in self._subscriptions: - # if self.is_subscription_active(channel): - # return True - # return False + return self._bindings[binding_key].status == BindingStatus.ACTIVE def has_active_subscriptions(self) -> bool: # pragma: no cover - for subscription in self._bindings: - if self.is_subscription_active(subscription): - return True + with self._operational_lock: + for subscription in self._bindings: + if self.is_subscription_active(subscription): + return True return False - # def has_subscription(self, channel: str) -> bool: # pragma: no cover - # return channel in self._subscriptions - - def has_subscription(self, subscription: Subscription) -> bool: # pragma: no cover - return subscription in self._bindings - - # def get_active_subscriptions(self): - # return {channel: copy.deepcopy(subscription) for channel, subscription in self._subscriptions.items() if self.is_subscription_active(channel)} + def has_subscription(self, binding_key: str) -> bool: # pragma: no cover + with self._operational_lock: + return binding_key in self._bindings def get_active_subscriptions(self): - return { - subscription: copy.deepcopy(binding) - for subscription, binding in self._bindings.items() - if self.is_subscription_active(subscription) - } - - def set_subscription_active(self, subscription: Subscription): - if not self.has_subscription(subscription): - _LOGGER.warning(f'{self}: Unknown subscription {subscription} - cannot update status to {BindingStatus.ACTIVE.value}') + with self._operational_lock: + return { + binding_key: copy.deepcopy(binding) + for binding_key, binding in self._bindings.items() + if self.is_subscription_active(binding_key) + } + + def _confirm_subscribed(self, binding_key: str): + if not self.has_subscription(binding_key): + _LOGGER.warning(f'{self}: Unknown subscription {binding_key} - cannot update status to {BindingStatus.ACTIVE.value}') return - binding = self._bindings[subscription] + binding = self._bindings[binding_key] if binding.status == BindingStatus.ACTIVE or binding.intent == BindingStatus.UNSUBSCRIBED: return binding.status = BindingStatus.ACTIVE binding.attempts = 0 - _LOGGER.info(f'{self}: Updated subscription status: {subscription} -> {BindingStatus.ACTIVE.value}') + _LOGGER.info(f'{self}: Updated subscription status: {binding_key} -> {BindingStatus.ACTIVE.value}') - def set_subscription_unsubscribed(self, subscription: Subscription): - if not self.has_subscription(subscription): - _LOGGER.warning(f'{self}: Unknown subscription {subscription} - cannot update status to {BindingStatus.UNSUBSCRIBED.value}') + def _confirm_unsubscribed(self, binding_key: str): + if not self.has_subscription(binding_key): + _LOGGER.warning(f'{self}: Unknown subscription {binding_key} - cannot update status to {BindingStatus.UNSUBSCRIBED.value}') return - binding = self._bindings[subscription] + binding = self._bindings[binding_key] if binding.status == BindingStatus.UNSUBSCRIBED or binding.intent == BindingStatus.ACTIVE: return binding.status = BindingStatus.UNSUBSCRIBED binding.attempts = 0 - _LOGGER.info(f'{self}: Updated subscription status: {subscription} -> {BindingStatus.UNSUBSCRIBED.value}') + _LOGGER.info(f'{self}: Updated subscription status: {binding_key} -> {BindingStatus.UNSUBSCRIBED.value}') def __str__(self): return f'{self.__class__.__qualname__}()' \ No newline at end of file diff --git a/ibind/ws_v2/ws_runtime.py b/ibind/ws_v2/ws_runtime.py index 1276200c..359b8504 100644 --- a/ibind/ws_v2/ws_runtime.py +++ b/ibind/ws_v2/ws_runtime.py @@ -1,7 +1,6 @@ import json import ssl import threading -from enum import Enum from pathlib import Path from queue import Queue from threading import Thread, Event @@ -10,10 +9,10 @@ from websocket import WebSocketApp, STATUS_UNEXPECTED_CONDITION from support.logs import project_logger -from support.py_utils import wait_until, tname, VerboseEnum, exception_to_string +from support.py_utils import wait_until, tname, VerboseEnum, exception_to_string, TimeoutLock from ws_v2 import events -from ws_v2.events import WsOpen, WsEvent, EventSink, Router -from ws_v2.subscription_controller import SubscriptionController +from ws_v2.events import WsEvent, EventSink, Router +from ws_v2.subscription_controller import SubscriptionController, SubscriptionResolver from ws_v2.ws_transport import WsTransport, TransportEvent, TransportOpened, TransportClosed, TransportError, TransportMessage, TransportCritical, TransportReconnect _LOGGER = project_logger(__file__) @@ -40,8 +39,9 @@ def __init__( self, url: str, cycle_interval: float, - sink:EventSink, + sink: EventSink, router: Router, + subscription_resolver: SubscriptionResolver, ready_state: Literal[WsState.OPEN, WsState.AUTHENTICATED] = WsState.OPEN, cacert: Union[str, bool] = False, connection_timeout: float = _DEFAULT_TIMEOUT, @@ -54,6 +54,7 @@ def __init__( self._cycle_interval = cycle_interval self._sink = sink self._router = router + self._subscription_resolver = subscription_resolver self._ready_state = ready_state self._connection_timeout = connection_timeout self._restart_on_close = restart_on_close @@ -67,6 +68,8 @@ def __init__( self._transport_queue = Queue() self._wait_event = Event() + self._state_lock = TimeoutLock(60) + if not (cacert is False or Path(cacert).exists()): raise ValueError(f'{self}: cacert must be a valid Path or False') @@ -75,7 +78,6 @@ def __init__( else: sslopt = {'ca_certs': cacert} - self._transport = WsTransport( url=url, event_callback=self._transport_callback, @@ -84,21 +86,24 @@ def __init__( get_header=get_header, ) - self.subscription_controller = SubscriptionController(send_payload=self.send) + self.subscription_controller = SubscriptionController(send_payload=self.send, subscription_resolver=self._subscription_resolver) @property def state(self): _LOGGER.debug(f'{self}: State: {self._state.value}') - return self._state + with self._state_lock: + return self._state @state.setter def state(self, value): _LOGGER.debug(f'{self}: {self._state.value} -> {value.value}') - self._state = value + with self._state_lock: + self._state = value + if self._state == self._ready_state: self._sink.emit(events.WsReady()) - def set_authenticated(self, value:bool): + def set_authenticated(self, value: bool): if value != self._authenticated: _LOGGER.debug(f'{self}: Authenticated: {value}') self._authenticated = value @@ -113,7 +118,6 @@ def set_authenticated(self, value:bool): def get_authenticated(self) -> bool: return self._authenticated - def _new_transport_thread(self): self._transport_thread = Thread(target=self._transport.connect, name='ws_transport_thread') self._transport_thread.daemon = True @@ -264,6 +268,11 @@ def _handle_on_message(self, wsa: WebSocketApp, message): # pragma: no cover # Propagate events to the sink for event in events: + try: + self.subscription_controller.observe(event) + except Exception as e: + _LOGGER.error(f'{self}: Exception observing subscription: {exception_to_string(e)} for {event}') + try: self._sink.emit(event) except Exception as e: @@ -274,7 +283,7 @@ def _handle_on_open(self, wsa: WebSocketApp): self.state = WsState.OPEN ## connected = True self._sink.emit(events.WsOpen()) - def _handle_on_error(self, wsa: WebSocketApp, exception:Exception): # pragma: no cover + def _handle_on_error(self, wsa: WebSocketApp, exception: Exception): # pragma: no cover _LOGGER.error(f'{self}: on_error: {exception}') if str(exception) in ['Connection to remote host was lost.', 'No connection could be made because the target machine actively refused it']: self.state = WsState.DEGRADED @@ -298,7 +307,7 @@ def _handle_on_close(self, wsa: WebSocketApp, close_status_code, close_msg): self.subscription_controller.invalidate_subscriptions() self._sink.emit(events.WsClose(close_status_code=close_status_code, close_msg=close_msg)) # if we're not connected we shouldn't need to do anything - if self.state not in [self._ready_state, WsState.OPEN, WsState.STOPPING]: ## not self._connected: + if self.state not in [self._ready_state, WsState.OPEN, WsState.STOPPING]: ## not self._connected: _LOGGER.info(f'{self}: Unexpected on_close event while not open') return @@ -316,13 +325,12 @@ def _handle_on_close(self, wsa: WebSocketApp, close_status_code, close_msg): if self.state == WsState.STOPPING: _LOGGER.info(f'{self}: Gracefully closed') - self.state = WsState.CLOSED ## self._connected = False + self.state = WsState.CLOSED ## self._connected = False # if not self._running: # if close happened due to shutting down, acknowledge and return # _LOGGER.info(f'{self}: Gracefully closed') # return - def hard_reset(self, restart: bool = False) -> None: """ Performs a hard reset of the WebSocket connection. @@ -372,7 +380,7 @@ def hard_reset(self, restart: bool = False) -> None: def _reconnect(self): with self._reconnect_lock: - if self.state not in [WsState.OPEN, self._ready_state]: ## not self._has_active_connection(): + if self.state not in [WsState.OPEN, self._ready_state]: ## not self._has_active_connection(): _LOGGER.info(f'{self}: Reconnecting') self._try_connecting() From 6c424770d8f99ab23915482e2edf153875cb781c Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 29 Apr 2026 20:34:25 +0200 Subject: [PATCH 05/32] fix(ws_v2): small fixes --- examples/ws_02_intermediate.py | 2 +- ibind/ibkr_ws_v2/ibkr_events.py | 7 ++++++- ibind/ibkr_ws_v2/ibkr_subscriptions.py | 5 +---- ibind/ws_v2/subscription_controller.py | 2 ++ 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/ws_02_intermediate.py b/examples/ws_02_intermediate.py index dc1ba8a0..ca498c2f 100644 --- a/examples/ws_02_intermediate.py +++ b/examples/ws_02_intermediate.py @@ -68,4 +68,4 @@ def stop(_, _1): print('KeyboardInterrupt') break -stop(None, None) +stop(None, None) \ No newline at end of file diff --git a/ibind/ibkr_ws_v2/ibkr_events.py b/ibind/ibkr_ws_v2/ibkr_events.py index 15e09863..919b3a24 100644 --- a/ibind/ibkr_ws_v2/ibkr_events.py +++ b/ibind/ibkr_ws_v2/ibkr_events.py @@ -1,6 +1,8 @@ from enum import Enum from typing import Any +from pydantic import Field + from ws_v2.events import WsEvent @@ -142,7 +144,7 @@ class MarketData(WsEvent): key: IbkrWsKey = IbkrWsKey.MARKET_DATA conid: str data: dict = {} - fields: dict[str, Any] = {} + fields: dict[str, Any] = Field(default_factory=dict) class MarketHistory(WsEvent): @@ -158,6 +160,9 @@ class Orders(WsEvent): class PriceLadder(WsEvent): key: IbkrWsKey = IbkrWsKey.PRICE_LADDER + account_id: str + conid: str + exchange: str data: dict diff --git a/ibind/ibkr_ws_v2/ibkr_subscriptions.py b/ibind/ibkr_ws_v2/ibkr_subscriptions.py index 2e6df665..d35d99f7 100644 --- a/ibind/ibkr_ws_v2/ibkr_subscriptions.py +++ b/ibind/ibkr_ws_v2/ibkr_subscriptions.py @@ -141,12 +141,9 @@ def binding_key(self): class MarketHistorySubscription(IbkrSubscription): + key: IbkrWsKey = IbkrWsKey.MARKET_HISTORY conid: str - @property - def key(self) -> IbkrWsKey: - return IbkrWsKey.MARKET_HISTORY - def subscribe_payload(self) -> str: ... diff --git a/ibind/ws_v2/subscription_controller.py b/ibind/ws_v2/subscription_controller.py index 5bb52ae3..b147f34a 100644 --- a/ibind/ws_v2/subscription_controller.py +++ b/ibind/ws_v2/subscription_controller.py @@ -140,6 +140,8 @@ def parse_binding(self, binding: Binding): binding.attempts = 0 return + binding.attempts += 1 + subscription = binding.subscription if binding.intent == BindingStatus.ACTIVE: From 772d9857f8c37d1afd988d1f5b06564f4a185865 Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 29 Apr 2026 20:34:25 +0200 Subject: [PATCH 06/32] fix(ws_v2): small fixes --- examples/ws_02_intermediate.py | 2 +- ibind/ibkr_ws_v2/ibkr_events.py | 13 ++++++++---- ibind/ibkr_ws_v2/ibkr_router.py | 8 +++---- ibind/ibkr_ws_v2/ibkr_subscriptions.py | 13 +++++------- ibind/ibkr_ws_v2/ibkr_ws_client_v2.py | 5 +++-- ibind/ws_v2/events.py | 4 ++++ ibind/ws_v2/subscription_controller.py | 4 +++- ibind/ws_v2/ws_runtime.py | 29 ++++++++++++++------------ ibind/ws_v2/ws_transport.py | 2 ++ 9 files changed, 47 insertions(+), 33 deletions(-) diff --git a/examples/ws_02_intermediate.py b/examples/ws_02_intermediate.py index dc1ba8a0..ca498c2f 100644 --- a/examples/ws_02_intermediate.py +++ b/examples/ws_02_intermediate.py @@ -68,4 +68,4 @@ def stop(_, _1): print('KeyboardInterrupt') break -stop(None, None) +stop(None, None) \ No newline at end of file diff --git a/ibind/ibkr_ws_v2/ibkr_events.py b/ibind/ibkr_ws_v2/ibkr_events.py index 15e09863..aa92069f 100644 --- a/ibind/ibkr_ws_v2/ibkr_events.py +++ b/ibind/ibkr_ws_v2/ibkr_events.py @@ -1,6 +1,8 @@ from enum import Enum from typing import Any +from pydantic import Field + from ws_v2.events import WsEvent @@ -47,7 +49,6 @@ def from_channel(cls, channel): return channel_to_key[channel] raise ValueError(f"No enum member associated with channel '{channel}'") - @property def channel(self): """ @@ -67,7 +68,8 @@ def channel(self): IbkrWsKey.TRADES: 'tr', }[self] -class ParsedIbkrMessage(WsEvent): + +class GenericIbkrEvent(WsEvent): key: str = IbkrWsKey.UNCLASSIFIED message: dict | None topic: str | None = None @@ -141,8 +143,8 @@ class AccountLedger(WsEvent): class MarketData(WsEvent): key: IbkrWsKey = IbkrWsKey.MARKET_DATA conid: str - data: dict = {} - fields: dict[str, Any] = {} + data: dict = Field(default_factory=dict) + fields: dict[str, Any] = Field(default_factory=dict) class MarketHistory(WsEvent): @@ -158,6 +160,9 @@ class Orders(WsEvent): class PriceLadder(WsEvent): key: IbkrWsKey = IbkrWsKey.PRICE_LADDER + account_id: str + conid: str + exchange: str data: dict diff --git a/ibind/ibkr_ws_v2/ibkr_router.py b/ibind/ibkr_ws_v2/ibkr_router.py index 05b0b1c2..c445ced8 100644 --- a/ibind/ibkr_ws_v2/ibkr_router.py +++ b/ibind/ibkr_ws_v2/ibkr_router.py @@ -5,7 +5,7 @@ from client import ibkr_definitions from client.ibkr_utils import extract_conid from ibkr_ws_v2 import ibkr_events -from ibkr_ws_v2.ibkr_events import ParsedIbkrMessage, IbkrWsKey +from ibkr_ws_v2.ibkr_events import GenericIbkrEvent, IbkrWsKey from support.logs import project_logger from support.py_utils import UNDEFINED, OneOrMany from ws_v2.events import WsEvent @@ -213,7 +213,7 @@ def _handle_message_without_topic(self, message: dict) -> OneOrMany[WsEvent]: # return self.modify_subscription(f'ld+{self._account_id}', status=False) _LOGGER.error(f'{self}: Unrecognised message without a topic: {message}') - return ParsedIbkrMessage(message=message) + return GenericIbkrEvent(message=message) def _preprocess_raw_message(self, raw_message: str): message = json.loads(raw_message) @@ -274,12 +274,12 @@ def route(self, raw_message: str) -> OneOrMany[WsEvent]: events = self._handle_subscribed_message(channel, message) if events is None: _LOGGER.error(f'{self}: Channel "{channel}" subscribed but lacking a handler. Message: {message}') - events = ParsedIbkrMessage(message=message, topic=topic, data=arguments, subscribed=subscribed, channel=channel) + events = GenericIbkrEvent(message=message, topic=topic, data=arguments, subscribed=subscribed, channel=channel) return events # _LOGGER.warning(f'{self}: Handled a channel "{channel}" message that is missing a subscription. Message: {message}') _LOGGER.error(f'{self}: Topic "{topic}" unrecognised. Message: {message}') - return ParsedIbkrMessage(message=message, topic=topic, data=arguments, subscribed=subscribed, channel=channel) + return GenericIbkrEvent(message=message, topic=topic, data=arguments, subscribed=subscribed, channel=channel) # def route(self, raw_message) -> List[WsEvent]: # _LOGGER.debug(f'{self}: Routing message: {raw_message}') diff --git a/ibind/ibkr_ws_v2/ibkr_subscriptions.py b/ibind/ibkr_ws_v2/ibkr_subscriptions.py index 2e6df665..a4d8b377 100644 --- a/ibind/ibkr_ws_v2/ibkr_subscriptions.py +++ b/ibind/ibkr_ws_v2/ibkr_subscriptions.py @@ -141,12 +141,9 @@ def binding_key(self): class MarketHistorySubscription(IbkrSubscription): + key: IbkrWsKey = IbkrWsKey.MARKET_HISTORY conid: str - @property - def key(self) -> IbkrWsKey: - return IbkrWsKey.MARKET_HISTORY - def subscribe_payload(self) -> str: ... @@ -235,14 +232,14 @@ def binding_key(self): class TradesSubscription(IbkrSubscription): key: IbkrWsKey = IbkrWsKey.TRADES - realtime_updates_only: bool = False - days: int = 1 + realtime_updates_only: bool | None = None + days: int | None = None def subscribe_payload(self) -> str: extra = {} - if self.realtime_updates_only: + if self.realtime_updates_only is not None: extra['realtime_updates_only'] = self.realtime_updates_only - if self.days: + if self.days is not None: extra['days'] = self.days extra_str = json.dumps(extra, separators=(',', ':')) return f'str+{extra_str}' diff --git a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py index 8f3e985e..4f4731e9 100644 --- a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py +++ b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py @@ -7,7 +7,7 @@ from ibkr_ws_v2.ibkr_router import IbkrRouter from ibkr_ws_v2.ibkr_subscriptions import IbkrSubscriptionResolver from support.logs import project_logger -from ws_v2.events import EventSink, LogSink, CallbackSink, CompositeSink, Router +from ws_v2.events import EventSink, LogSink, CallbackSink, CompositeSink, Router, NoopSink from ws_v2.subscription_controller import Subscription, SubscriptionResolver from ws_v2.ws_runtime import WsRuntime, WsState @@ -60,7 +60,8 @@ def __init__( # self._queue_controller.register_queues(['CLIENT_INTERNAL', 'IBKR']) # sink = QueueSink(queue_controller=self._queue_controller) - sink = LogSink() + # sink = LogSink() + sink = NoopSink() self._internal_sink = CallbackSink() self._register_internal_callbacks() diff --git a/ibind/ws_v2/events.py b/ibind/ws_v2/events.py index fc4f135c..802bd915 100644 --- a/ibind/ws_v2/events.py +++ b/ibind/ws_v2/events.py @@ -101,6 +101,10 @@ class LogSink: def emit(self, event: WsEvent) -> None: _LOGGER.debug(f'{event.key}: {str(event)}') +class NoopSink: + def emit(self, event: WsEvent) -> None: + pass + class CallbackSink: def __init__(self): diff --git a/ibind/ws_v2/subscription_controller.py b/ibind/ws_v2/subscription_controller.py index 5bb52ae3..c5079011 100644 --- a/ibind/ws_v2/subscription_controller.py +++ b/ibind/ws_v2/subscription_controller.py @@ -140,6 +140,8 @@ def parse_binding(self, binding: Binding): binding.attempts = 0 return + binding.attempts += 1 + subscription = binding.subscription if binding.intent == BindingStatus.ACTIVE: @@ -158,7 +160,7 @@ def parse_binding(self, binding: Binding): def parse_bindings(self): with self._operational_lock: - for subscription, binding in self._bindings.items(): + for binding in self._bindings.values(): if binding.status == binding.intent: continue diff --git a/ibind/ws_v2/ws_runtime.py b/ibind/ws_v2/ws_runtime.py index 359b8504..8c5a290c 100644 --- a/ibind/ws_v2/ws_runtime.py +++ b/ibind/ws_v2/ws_runtime.py @@ -9,7 +9,7 @@ from websocket import WebSocketApp, STATUS_UNEXPECTED_CONDITION from support.logs import project_logger -from support.py_utils import wait_until, tname, VerboseEnum, exception_to_string, TimeoutLock +from support.py_utils import wait_until, tname, VerboseEnum, exception_to_string, TimeoutLock, OneOrMany from ws_v2 import events from ws_v2.events import WsEvent, EventSink, Router from ws_v2.subscription_controller import SubscriptionController, SubscriptionResolver @@ -62,6 +62,7 @@ def __init__( self._state = WsState.STOPPED self._authenticated = False + self._running = False self._transport_thread = None self._runtime_thread = None @@ -90,13 +91,13 @@ def __init__( @property def state(self): - _LOGGER.debug(f'{self}: State: {self._state.value}') + _LOGGER.info(f'{self}: State: {self._state.value}') with self._state_lock: return self._state @state.setter def state(self, value): - _LOGGER.debug(f'{self}: {self._state.value} -> {value.value}') + _LOGGER.info(f'{self}: {self._state.value} -> {value.value}') with self._state_lock: self._state = value @@ -105,7 +106,7 @@ def state(self, value): def set_authenticated(self, value: bool): if value != self._authenticated: - _LOGGER.debug(f'{self}: Authenticated: {value}') + _LOGGER.info(f'{self}: Authenticated: {value}') self._authenticated = value if value and self._state == WsState.OPEN: @@ -157,13 +158,15 @@ def stop(self): self.state = WsState.STOPPING try: self._transport.disconnect() - self._transport_thread.join(self._connection_timeout) + if self._transport_thread is not None: + self._transport_thread.join(self._connection_timeout) except Exception as e: _LOGGER.error(f'{self}: Failed to disconnect: {e}') # TODO: decide what to do if transport disconnect fails self._running = False - self._runtime_thread.join(self._connection_timeout) + if self._runtime_thread is not None: + self._runtime_thread.join(self._connection_timeout) self.state = WsState.STOPPED @@ -204,7 +207,7 @@ def _maintain_transport(self): return if self._transport_thread is None or not self._transport_thread.is_alive(): - _LOGGER.debug(f'{self}: Starting new transport thread') + _LOGGER.info(f'{self}: Starting new transport thread') self.state = WsState.CONNECTING self._new_transport_thread() @@ -215,23 +218,23 @@ def _maintain_subscriptions(self): self.subscription_controller.parse_bindings() def _cycle(self): - _LOGGER.debug(f'{self}: Runtime thread started ({tname()})') + _LOGGER.info(f'{self}: Runtime thread started ({tname()})') while self._running: self._maintain_transport() self._maintain_subscriptions() - self.process_transport_queue() + self._process_transport_queue() self._wait_event.clear() self._wait_event.wait(self._cycle_interval) # final pass through the router queue to flush any remaining events - self.process_transport_queue() + self._process_transport_queue() # final pass through the subscription controller to carry out final unsubscribe events self.subscription_controller.parse_bindings() - _LOGGER.debug(f'{self}: Runtime thread stopped ({tname()})') + _LOGGER.info(f'{self}: Runtime thread stopped ({tname()})') - def process_transport_queue(self): + def _process_transport_queue(self): while not self._transport_queue.empty(): te = self._transport_queue.get() try: @@ -256,7 +259,7 @@ def _handle_transport_event(self, te: TransportEvent): _LOGGER.error(f'{self}: Unknown event type: {type(te)}: {te}') def _handle_on_message(self, wsa: WebSocketApp, message): # pragma: no cover - events = self._router.route(message) + events: OneOrMany[WsEvent] = self._router.route(message) # Router decided to skip this message if events is None: diff --git a/ibind/ws_v2/ws_transport.py b/ibind/ws_v2/ws_transport.py index a5da5362..8e7d90d3 100644 --- a/ibind/ws_v2/ws_transport.py +++ b/ibind/ws_v2/ws_transport.py @@ -151,6 +151,8 @@ def connect(self): except Exception as e: _LOGGER.exception(f'{self}: Unexpected error while running WebSocketApp: {e}') self._event_callback(TransportCritical(wsa=self._wsa, exception=e)) + finally: + self._wsa = None _LOGGER.debug(f'{self}: Transport thread stopped ({tname()})') From 7d250937be7aafb353401cd35cc2d019504b064e Mon Sep 17 00:00:00 2001 From: voyz Date: Thu, 30 Apr 2026 12:45:11 +0200 Subject: [PATCH 07/32] feat(ws_v2): added subscription handlers --- ibind/ibkr_ws_v2/ibkr_subscriptions.py | 2 +- ibind/ibkr_ws_v2/ibkr_ws_client_v2.py | 6 +- ...ription_controller.py => subscriptions.py} | 172 ++++++++++++++---- ibind/ws_v2/ws_runtime.py | 6 +- 4 files changed, 145 insertions(+), 41 deletions(-) rename ibind/ws_v2/{subscription_controller.py => subscriptions.py} (60%) diff --git a/ibind/ibkr_ws_v2/ibkr_subscriptions.py b/ibind/ibkr_ws_v2/ibkr_subscriptions.py index a4d8b377..bcb16814 100644 --- a/ibind/ibkr_ws_v2/ibkr_subscriptions.py +++ b/ibind/ibkr_ws_v2/ibkr_subscriptions.py @@ -2,7 +2,7 @@ from typing import Tuple from ibkr_ws_v2.ibkr_events import IbkrWsKey, AccountLedger, MarketData, MarketHistory, Orders, PriceLadder, Pnl, Trades, Unsubscription, AccountSummary -from ws_v2.subscription_controller import Subscription, SubscriptionResolver +from ws_v2.subscriptions import Subscription, SubscriptionResolver def make_binding_key( diff --git a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py index 4f4731e9..5eabe098 100644 --- a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py +++ b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py @@ -8,7 +8,7 @@ from ibkr_ws_v2.ibkr_subscriptions import IbkrSubscriptionResolver from support.logs import project_logger from ws_v2.events import EventSink, LogSink, CallbackSink, CompositeSink, Router, NoopSink -from ws_v2.subscription_controller import Subscription, SubscriptionResolver +from ws_v2.subscriptions import Subscription, SubscriptionResolver, SubscriptionHandle from ws_v2.ws_runtime import WsRuntime, WsState _LOGGER = project_logger(__file__) @@ -125,10 +125,10 @@ def shutdown(self): def hard_reset(self): self._runtime.hard_reset() - def subscribe(self, subscription: Subscription) -> bool: + def subscribe(self, subscription: Subscription) -> SubscriptionHandle: return self._runtime.subscription_controller.subscribe(subscription) - def unsubscribe(self, subscription: Subscription) -> bool: + def unsubscribe(self, subscription: Subscription) -> SubscriptionHandle: return self._runtime.subscription_controller.unsubscribe(subscription) def is_running(self) -> bool: diff --git a/ibind/ws_v2/subscription_controller.py b/ibind/ws_v2/subscriptions.py similarity index 60% rename from ibind/ws_v2/subscription_controller.py rename to ibind/ws_v2/subscriptions.py index c5079011..a6944ff7 100644 --- a/ibind/ws_v2/subscription_controller.py +++ b/ibind/ws_v2/subscriptions.py @@ -1,12 +1,13 @@ import copy import time from enum import Enum +from threading import Condition, RLock from typing import Dict, Optional, Callable, Protocol, Tuple, Hashable, Literal from pydantic import BaseModel, ConfigDict from ibind.support.logs import project_logger -from ibind.support.py_utils import TimeoutLock, exception_to_string +from ibind.support.py_utils import exception_to_string from ws_v2.events import WsEvent _LOGGER = project_logger(__file__) @@ -70,6 +71,48 @@ class Binding(BaseModel): attempts: int = 0 last_attempt: float = 0 + @property + def done(self) -> bool: + return self.status == self.intent + + def reset(self): + self.status = BindingStatus.NEW + self.attempts = 0 + self.last_attempt = 0 + + +class SubscriptionHandle: + def __init__(self, controller: "SubscriptionController", subscription: Subscription): + self._controller = controller + self._subscription = subscription + + @property + def binding_key(self) -> str: + return self._subscription.binding_key() + + @property + def status(self) -> BindingStatus: + return self._controller.get_status(self.binding_key) + + @property + def active(self) -> bool: + return self.status == BindingStatus.ACTIVE + + @property + def unsubscribed(self) -> bool: + return self.status == BindingStatus.UNSUBSCRIBED + + @property + def done(self) -> bool: + return self._controller.is_done(self.binding_key) + + def wait(self, timeout: float | None = None) -> bool: + return self._controller.wait_for(self.binding_key, timeout=timeout) + + def unsubscribe(self) -> "SubscriptionHandle": + self._controller.unsubscribe(self._subscription) + return self + class SubscriptionController: """ @@ -98,7 +141,7 @@ def __init__( self._subscription_timeout = subscription_timeout self._bindings: Dict[str, Binding] = {} - self._operational_lock = TimeoutLock(60) + self._condition = Condition(RLock()) def _send(self, payload) -> bool: try: @@ -117,7 +160,7 @@ def observe(self, event: WsEvent): if binding_key is None: return - with self._operational_lock: + with self._condition: if not self.has_subscription(binding_key): _LOGGER.warning(f'{self}: Observed a binding_key "{binding_key}" that is missing a subscription. Event: {event}') return @@ -127,17 +170,18 @@ def observe(self, event: WsEvent): else: self._confirm_unsubscribed(binding_key) - def parse_binding(self, binding: Binding): + def reconcile_binding(self, binding: Binding): # wait until timeout has passed since last attempt - if binding.last_attempt + self._subscription_timeout > time.time(): + if binding.last_attempt + self._subscription_timeout > time.monotonic(): return - binding.last_attempt = time.time() + binding.last_attempt = time.monotonic() # if we've exceeded the number of retries, mark the subscription as failed if binding.attempts >= self._subscription_retries: _LOGGER.info(f'{self}: Subscription failed after {self._subscription_retries} attempts: {binding}') binding.status = BindingStatus.FAILED binding.attempts = 0 + self._condition.notify_all() return binding.attempts += 1 @@ -158,37 +202,59 @@ def parse_binding(self, binding: Binding): _LOGGER.info(f'{self}: Unsubscribed: {payload} without confirmation.') self._confirm_unsubscribed(subscription.binding_key()) - def parse_bindings(self): - with self._operational_lock: + def reconcile_bindings(self): + with self._condition: for binding in self._bindings.values(): if binding.status == binding.intent: continue - self.parse_binding(binding) + self.reconcile_binding(binding) - def subscribe(self, subscription: Subscription) -> bool: - with self._operational_lock: - if self.is_subscription_active(subscription.binding_key()): # do nothing if subscription is present and active - return True + def subscribe(self, subscription: Subscription) -> SubscriptionHandle: + binding_key = subscription.binding_key() - # store a new binding - if self.has_subscription(subscription.binding_key()): - return + with self._condition: + binding = self._bindings.get(binding_key) + + if binding is None: + self._bindings[binding_key] = Binding(subscription=subscription, intent=BindingStatus.ACTIVE) + self._condition.notify_all() + _LOGGER.info(f'{self}: Registered subscription intent: {binding_key}') - self._bindings[subscription.binding_key()] = Binding(subscription=subscription, intent=BindingStatus.ACTIVE) - _LOGGER.info(f'{self}: Registered subscription intent: {subscription.binding_key()}') + elif binding.intent != BindingStatus.ACTIVE: + binding.intent = BindingStatus.ACTIVE - def unsubscribe(self, subscription: Subscription) -> bool: - with self._operational_lock: - if self.has_subscription(subscription.binding_key()): - binding = self._bindings[subscription.binding_key()] + # If it had previously completed unsubscribe, it now needs work again. if binding.status == BindingStatus.UNSUBSCRIBED: - return - self._bindings[subscription.binding_key()].intent = BindingStatus.UNSUBSCRIBED - else: - binding = Binding(subscription=subscription, intent=BindingStatus.UNSUBSCRIBED) - self._bindings[subscription.binding_key()] = binding - _LOGGER.info(f'{self}: Registered unsubscription intent: {subscription.binding_key()}') + binding.reset() + + self._condition.notify_all() + _LOGGER.info(f'{self}: Updated subscription intent: {binding_key} -> {BindingStatus.ACTIVE.value}') + + return SubscriptionHandle(self, subscription) + + def unsubscribe(self, subscription: Subscription) -> SubscriptionHandle: + binding_key = subscription.binding_key() + + with self._condition: + binding = self._bindings.get(binding_key) + + if binding is None: + self._bindings[binding_key] = Binding(subscription=subscription, intent=BindingStatus.UNSUBSCRIBED) + self._condition.notify_all() + _LOGGER.info(f'{self}: Registered unsubscription intent: {binding_key}') + + elif binding.intent != BindingStatus.UNSUBSCRIBED: + binding.intent = BindingStatus.UNSUBSCRIBED + + # If it had previously completed subscribe, it now needs work again. + if binding.status == BindingStatus.ACTIVE: + binding.reset() + + self._condition.notify_all() + _LOGGER.info(f'{self}: Updated subscription intent: {binding_key} -> {BindingStatus.UNSUBSCRIBED.value}') + + return SubscriptionHandle(self, subscription) def invalidate_subscriptions(self): for binding_key, binding in self._bindings.items(): @@ -196,24 +262,36 @@ def invalidate_subscriptions(self): binding.status = BindingStatus.DEGRADED _LOGGER.info(f'{self}: Invalidated subscription: {binding}') - def is_subscription_active(self, binding_key: str) -> Optional[bool]: # pragma: no cover + def is_subscription_active(self, binding_key: str) -> Optional[bool]: if not self.has_subscription(binding_key): return False return self._bindings[binding_key].status == BindingStatus.ACTIVE - def has_active_subscriptions(self) -> bool: # pragma: no cover - with self._operational_lock: + def has_active_subscriptions(self) -> bool: + with self._condition: for subscription in self._bindings: if self.is_subscription_active(subscription): return True return False - def has_subscription(self, binding_key: str) -> bool: # pragma: no cover - with self._operational_lock: + def has_subscription(self, binding_key: str) -> bool: + with self._condition: return binding_key in self._bindings + def get_status(self, binding_key: str) -> BindingStatus | None: + with self._condition: + if not self.has_subscription(binding_key): + return None + return self._bindings[binding_key].status + + def is_done(self, binding_key: str) -> bool | None: + with self._condition: + if not self.has_subscription(binding_key): + return None + return self._bindings[binding_key].done + def get_active_subscriptions(self): - with self._operational_lock: + with self._condition: return { binding_key: copy.deepcopy(binding) for binding_key, binding in self._bindings.items() @@ -233,6 +311,7 @@ def _confirm_subscribed(self, binding_key: str): binding.status = BindingStatus.ACTIVE binding.attempts = 0 _LOGGER.info(f'{self}: Updated subscription status: {binding_key} -> {BindingStatus.ACTIVE.value}') + self._condition.notify_all() def _confirm_unsubscribed(self, binding_key: str): if not self.has_subscription(binding_key): @@ -247,6 +326,31 @@ def _confirm_unsubscribed(self, binding_key: str): binding.status = BindingStatus.UNSUBSCRIBED binding.attempts = 0 _LOGGER.info(f'{self}: Updated subscription status: {binding_key} -> {BindingStatus.UNSUBSCRIBED.value}') + self._condition.notify_all() + + def wait_for(self, binding_key: str, timeout: float | None = None) -> bool: + deadline = None if timeout is None else time.monotonic() + timeout + + with self._condition: + while True: + if not self.has_subscription(binding_key): + return False + + binding = self._bindings[binding_key] + + if binding.done: + return True + + if binding.status == BindingStatus.FAILED: + return False + + remaining = None + if timeout is not None: + remaining = deadline - time.monotonic() + if remaining <= 0: + return False + + self._condition.wait(remaining) def __str__(self): return f'{self.__class__.__qualname__}()' \ No newline at end of file diff --git a/ibind/ws_v2/ws_runtime.py b/ibind/ws_v2/ws_runtime.py index 8c5a290c..88e1e3e5 100644 --- a/ibind/ws_v2/ws_runtime.py +++ b/ibind/ws_v2/ws_runtime.py @@ -12,7 +12,7 @@ from support.py_utils import wait_until, tname, VerboseEnum, exception_to_string, TimeoutLock, OneOrMany from ws_v2 import events from ws_v2.events import WsEvent, EventSink, Router -from ws_v2.subscription_controller import SubscriptionController, SubscriptionResolver +from ws_v2.subscriptions import SubscriptionController, SubscriptionResolver from ws_v2.ws_transport import WsTransport, TransportEvent, TransportOpened, TransportClosed, TransportError, TransportMessage, TransportCritical, TransportReconnect _LOGGER = project_logger(__file__) @@ -215,7 +215,7 @@ def _maintain_subscriptions(self): if self._state != self._ready_state: return - self.subscription_controller.parse_bindings() + self.subscription_controller.reconcile_bindings() def _cycle(self): _LOGGER.info(f'{self}: Runtime thread started ({tname()})') @@ -231,7 +231,7 @@ def _cycle(self): # final pass through the router queue to flush any remaining events self._process_transport_queue() # final pass through the subscription controller to carry out final unsubscribe events - self.subscription_controller.parse_bindings() + self.subscription_controller.reconcile_bindings() _LOGGER.info(f'{self}: Runtime thread stopped ({tname()})') def _process_transport_queue(self): From f481122ce2a0ed4e4dbbf3c892a892b39557d01b Mon Sep 17 00:00:00 2001 From: voyz Date: Fri, 1 May 2026 10:02:05 +0200 Subject: [PATCH 08/32] feat(ws_v2): added health checks handling and resets --- examples/ws_04_ws_v2.py | 27 ++- ibind/ibkr_ws_v2/ibkr_router.py | 59 +----- ibind/ibkr_ws_v2/ibkr_ws_client_v2.py | 29 ++- ibind/ws_v2/events.py | 15 +- ibind/ws_v2/subscriptions.py | 4 +- ibind/ws_v2/ws_runtime.py | 290 +++++++++++++++++--------- ibind/ws_v2/ws_transport.py | 247 +++++++++++++++++----- 7 files changed, 438 insertions(+), 233 deletions(-) diff --git a/examples/ws_04_ws_v2.py b/examples/ws_04_ws_v2.py index a35be319..9cc8fbe9 100644 --- a/examples/ws_04_ws_v2.py +++ b/examples/ws_04_ws_v2.py @@ -12,10 +12,12 @@ import os import time +from typing import List from ibind import ibind_logs_initialize from ibkr_ws_v2.ibkr_subscriptions import MarketDataSubscription, OrdersSubscription, AccountLedgerSubscription, AccountSummarySubscription, PriceLadderSubscription, PnlSubscription, TradesSubscription from ibkr_ws_v2.ibkr_ws_client_v2 import IbkrWsClientV2 +from ws_v2.subscriptions import SubscriptionHandle ibind_logs_initialize(log_to_file=False, log_level='DEBUG') @@ -52,8 +54,16 @@ tr_sub ] +sub_handles: List[SubscriptionHandle] = [] for sub in subs: - ws_client.subscribe(sub) + handle = ws_client.subscribe(sub) + handle.wait() + sub_handles.append(handle) + +for handle in sub_handles: + success = handle.wait(timeout=10) + if not success: + print('Subscription not active within 10 seconds') try: while ws_client.is_running(): @@ -61,8 +71,19 @@ except KeyboardInterrupt: print('Interrupt') -for sub in subs: - ws_client.unsubscribe(sub) +for handle in sub_handles: + unsub_handle = handle.unsubscribe() + success = unsub_handle.wait(timeout=10) + if not success: + print('Subscription not unsubscribed within 10 seconds') + +# unsub_handles = [] +# for sub in subs: +# handle = ws_client.unsubscribe(sub) +# unsub_handles.append(handle) +# +# for handle in unsub_handles: +# handle.wait() # time.sleep(5) ws_client.shutdown() diff --git a/ibind/ibkr_ws_v2/ibkr_router.py b/ibind/ibkr_ws_v2/ibkr_router.py index c445ced8..d4c6220a 100644 --- a/ibind/ibkr_ws_v2/ibkr_router.py +++ b/ibind/ibkr_ws_v2/ibkr_router.py @@ -15,7 +15,6 @@ def parse_raw_message(raw_message: str): message = json.loads(raw_message) - # print(message) topic = message.get('topic', UNDEFINED) if topic is UNDEFINED: @@ -50,16 +49,13 @@ def _preprocess_market_data_message(self, data: dict) -> OneOrMany[WsEvent]: if not self._unwrap_market_data: return ibkr_events.MarketData(conid=data['conid'], data=data) - # return {data['conid']: data} - # result = {'conid': data['conid'], '_updated': data['_updated'], 'topic': data['topic']} fields = {} for key, value in data.items(): if key in ibkr_definitions.snapshot_by_id: # result[ibkr_definitions.snapshot_by_id[key]] = value fields[ibkr_definitions.snapshot_by_id[key]] = value return ibkr_events.MarketData(conid=str(data['conid']), fields=fields) - # return {data['conid']: result} def _preprocess_market_history_message(self, data: dict) -> OneOrMany[WsEvent]: mh_server_id_conid_pairs = self._server_id_conid_pairs[IbkrWsKey.MARKET_HISTORY] @@ -129,27 +125,11 @@ def _handle_subscribed_message(self, channel: str, data: dict) -> OneOrMany[WsEv return None def _handle_account_update(self, message, arguments) -> OneOrMany[WsEvent]: - # if 'accounts' in data and self._account_id not in data['accounts']: - # _LOGGER.error(f'{self}: Account ID mismatch: expected={self._account_id}, received={data["accounts"]}') - # if 'acctProps' in data: # expected account update that we ignore - # return [] - _LOGGER.info(f'{self}: Account update: {arguments}') return ibkr_events.AccountUpdate(data=arguments) def _handle_authentication_status(self, message, arguments) -> OneOrMany[WsEvent]: - # if 'authenticated' in arguments: - # if arguments.get('authenticated') is False: - # _LOGGER.error(f'{self}: Status unauthenticated: {arguments}') - # - # # TODO: this needs to be handled in IbkrWsClient or WsRuntime - # # self.set_authenticated(data.get('authenticated')) - # elif 'competing' in arguments: - # if arguments.get('competing') is False: - # pass - # _LOGGER.error(f'{self}: Authentication competing: {arguments}') - - if 'authenticated' in arguments: + if 'authenticated' in arguments or 'competing' in arguments: _LOGGER.info(f'{self}: Authentication status: {arguments}') return ibkr_events.AuthenticationStatus(data=arguments, authenticated=arguments.get('authenticated'), competing=arguments.get('competing')) elif ( # expected status updates that we ignore @@ -159,7 +139,7 @@ def _handle_authentication_status(self, message, arguments) -> OneOrMany[WsEvent 'serverVersion' in arguments or 'username' in arguments ): - _LOGGER.info(f'{self}: Authentication silenced: {arguments}') + # _LOGGER.info(f'{self}: Authentication silenced: {arguments}') pass return [] @@ -186,7 +166,6 @@ def _handle_market_history_unsubscribe(self, data) -> OneOrMany[WsEvent]: _LOGGER.info(f'{self}: Received unsubscribing confirmation for server_id={server_id!r}/conid={conid!r}.') if conid is not None: return ibkr_events.Unsubscription(target_key=IbkrWsKey.MARKET_HISTORY, conid=conid) - # self.modify_subscription(f'mh+{conid}', status=False) _LOGGER.warning(f'{self}: Unknown conid={conid!r}. Cannot mark the subscription as unsubscribed.') else: @@ -207,29 +186,12 @@ def _handle_message_without_topic(self, message: dict) -> OneOrMany[WsEvent]: elif 'result' in message: if message['result'] == 'unsubscribed from summary': return ibkr_events.Unsubscription(target_key=IbkrWsKey.ACCOUNT_SUMMARY) - # return self.modify_subscription(f'sd+{self._account_id}', status=False) elif message['result'] == 'unsubscribed from ledger': return ibkr_events.Unsubscription(target_key=IbkrWsKey.ACCOUNT_LEDGER) - # return self.modify_subscription(f'ld+{self._account_id}', status=False) _LOGGER.error(f'{self}: Unrecognised message without a topic: {message}') return GenericIbkrEvent(message=message) - def _preprocess_raw_message(self, raw_message: str): - message = json.loads(raw_message) - # print(message) - topic = message.get('topic', UNDEFINED) - - if topic is UNDEFINED: - return message, None, None, None, None - - data = message.get('args', {}) - - # subscribed is the indicator of whether it was a subscription or unsubscription, defined by the first letter - # channel is the actual channel we received the information about - subscribed, channel = topic[0], topic[1:] - - return message, topic, data, subscribed, channel def route(self, raw_message: str) -> OneOrMany[WsEvent]: if self._log_raw_messages: @@ -244,11 +206,10 @@ def route(self, raw_message: str) -> OneOrMany[WsEvent]: return self._handle_message_without_topic(message) elif topic == 'tic': - self._tic_message = message + # self._tic_message = message + return ibkr_events.System(data=message) elif topic == 'system': - if 'hb' in message: - self._last_heartbeat = message['hb'] return ibkr_events.System(data=message) elif topic == 'act': @@ -265,26 +226,14 @@ def route(self, raw_message: str) -> OneOrMany[WsEvent]: elif topic == 'error': return self._handle_error(message) - # _LOGGER.error(f'{self}: Error message: {message}') - # elif self.has_subscription(channel): - # if not self.is_subscription_active(channel): - # self.modify_subscription(channel, status=True) else: events = self._handle_subscribed_message(channel, message) if events is None: _LOGGER.error(f'{self}: Channel "{channel}" subscribed but lacking a handler. Message: {message}') events = GenericIbkrEvent(message=message, topic=topic, data=arguments, subscribed=subscribed, channel=channel) return events - # _LOGGER.warning(f'{self}: Handled a channel "{channel}" message that is missing a subscription. Message: {message}') - - _LOGGER.error(f'{self}: Topic "{topic}" unrecognised. Message: {message}') - return GenericIbkrEvent(message=message, topic=topic, data=arguments, subscribed=subscribed, channel=channel) - # def route(self, raw_message) -> List[WsEvent]: - # _LOGGER.debug(f'{self}: Routing message: {raw_message}') - # message, topic, data, subscribed, channel = parse_raw_message(raw_message) - # return [ParsedIbkrMessage(message=message, topic=topic, data=data, subscribed=subscribed, channel=channel)] def __str__(self): return f'{self.__class__.__qualname__}()' \ No newline at end of file diff --git a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py index 5eabe098..15961794 100644 --- a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py +++ b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py @@ -7,6 +7,7 @@ from ibkr_ws_v2.ibkr_router import IbkrRouter from ibkr_ws_v2.ibkr_subscriptions import IbkrSubscriptionResolver from support.logs import project_logger +from ws_v2 import events from ws_v2.events import EventSink, LogSink, CallbackSink, CompositeSink, Router, NoopSink from ws_v2.subscriptions import Subscription, SubscriptionResolver, SubscriptionHandle from ws_v2.ws_runtime import WsRuntime, WsState @@ -60,8 +61,8 @@ def __init__( # self._queue_controller.register_queues(['CLIENT_INTERNAL', 'IBKR']) # sink = QueueSink(queue_controller=self._queue_controller) - # sink = LogSink() - sink = NoopSink() + sink = LogSink() + # sink = NoopSink() self._internal_sink = CallbackSink() self._register_internal_callbacks() @@ -88,6 +89,12 @@ def __init__( def _register_internal_callbacks(self): self._internal_sink.on(ibkr_events.AuthenticationStatus, self._on_authentication_status) self._internal_sink.on(ibkr_events.WaitingForSession, self._set_unauthenticated) + self._internal_sink.on(ibkr_events.System, self._on_system) + # self._internal_sink.on(events.WsReconnect, self._on_open) + # self._internal_sink.on(events.WsOpen, self._on_open) + + def _on_open(self, event: events.WsOpen): + _LOGGER.info(f'{self}: WSA opened, cookie: {self._get_cookie()}') def _set_unauthenticated(self, _): self._runtime.set_authenticated(False) @@ -100,13 +107,21 @@ def _on_authentication_status(self, event: ibkr_events.AuthenticationStatus): self._runtime.set_authenticated(event.authenticated) + def _on_system(self, event: ibkr_events.System): + if 'hb' in event.data: + self._runtime.set_last_heartbeat(int(event.data['hb']) / 1000) def _get_cookie(self): - try: - status = self._ibkr_client.tickle() - except ExternalBrokerError: - _LOGGER.warning('Acquiring session cookie failed, connection to the Gateway may be broken.') - return None + # try: + status = self._ibkr_client.tickle() + # except TimeoutError as e: + # if 'Reached max retries' in str(e): + # _LOGGER.warning(f'{self}: Acquiring session cookie timed out, connection to the Gateway may be broken.') + # return None + # raise + # except ExternalBrokerError: + # _LOGGER.warning(f'{self}: Acquiring session cookie failed, connection to the Gateway may be broken.') + # return None session_id = status.data['session'] if self._use_oauth: return f'api={session_id}' diff --git a/ibind/ws_v2/events.py b/ibind/ws_v2/events.py index 802bd915..a17c520b 100644 --- a/ibind/ws_v2/events.py +++ b/ibind/ws_v2/events.py @@ -1,6 +1,6 @@ from collections import defaultdict from datetime import datetime -from typing import Hashable, Protocol, Callable +from typing import Hashable, Protocol, Callable, TypeVar, List, Dict from pydantic import BaseModel, ConfigDict, Field @@ -83,11 +83,6 @@ class WsError(ClientInternalEvent): error: Exception -class WsCritical(ClientInternalEvent): - model_config = ConfigDict(frozen=True, extra="forbid", arbitrary_types_allowed=True) - exception: Exception - - # ============= # == Sinks == # ============= @@ -101,16 +96,20 @@ class LogSink: def emit(self, event: WsEvent) -> None: _LOGGER.debug(f'{event.key}: {str(event)}') + class NoopSink: def emit(self, event: WsEvent) -> None: pass +T = TypeVar("T", bound=WsEvent) + + class CallbackSink: def __init__(self): - self._callbacks: dict[type[WsEvent], list[Callable[[WsEvent], None]]] = defaultdict(list) + self._callbacks: Dict[type[WsEvent], List[Callable[[WsEvent], None]]] = defaultdict(list) - def on(self, event_type: type[WsEvent], callback: Callable[[WsEvent], None]) -> None: + def on(self, event_type: type[WsEvent], callback: Callable[[T], None]) -> None: self._callbacks[event_type].append(callback) def emit(self, event: WsEvent) -> None: diff --git a/ibind/ws_v2/subscriptions.py b/ibind/ws_v2/subscriptions.py index a6944ff7..afda7dbc 100644 --- a/ibind/ws_v2/subscriptions.py +++ b/ibind/ws_v2/subscriptions.py @@ -172,9 +172,9 @@ def observe(self, event: WsEvent): def reconcile_binding(self, binding: Binding): # wait until timeout has passed since last attempt - if binding.last_attempt + self._subscription_timeout > time.monotonic(): + if binding.last_attempt + self._subscription_timeout > time.time(): return - binding.last_attempt = time.monotonic() + binding.last_attempt = time.time() # if we've exceeded the number of retries, mark the subscription as failed if binding.attempts >= self._subscription_retries: diff --git a/ibind/ws_v2/ws_runtime.py b/ibind/ws_v2/ws_runtime.py index 88e1e3e5..a4a59fe5 100644 --- a/ibind/ws_v2/ws_runtime.py +++ b/ibind/ws_v2/ws_runtime.py @@ -1,19 +1,20 @@ import json import ssl import threading +import time from pathlib import Path from queue import Queue from threading import Thread, Event from typing import Union, List, Dict, Callable, Literal -from websocket import WebSocketApp, STATUS_UNEXPECTED_CONDITION +from websocket import WebSocketApp from support.logs import project_logger from support.py_utils import wait_until, tname, VerboseEnum, exception_to_string, TimeoutLock, OneOrMany from ws_v2 import events from ws_v2.events import WsEvent, EventSink, Router from ws_v2.subscriptions import SubscriptionController, SubscriptionResolver -from ws_v2.ws_transport import WsTransport, TransportEvent, TransportOpened, TransportClosed, TransportError, TransportMessage, TransportCritical, TransportReconnect +from ws_v2.ws_transport import WsTransport, TransportEvent, TransportOpened, TransportClosed, TransportError, TransportMessage, TransportReconnect _LOGGER = project_logger(__file__) @@ -48,7 +49,8 @@ def __init__( restart_on_close: bool = True, restart_on_critical: bool = True, get_cookie: Callable = _NOOP, - get_header: Callable = _NOOP + get_header: Callable = _NOOP, + max_ping_interval: float = 20, ): self._url = url self._cycle_interval = cycle_interval @@ -59,13 +61,17 @@ def __init__( self._connection_timeout = connection_timeout self._restart_on_close = restart_on_close self._restart_on_critical = restart_on_critical + self._max_ping_interval = max_ping_interval self._state = WsState.STOPPED self._authenticated = False self._running = False + self._last_heartbeat = None + self._last_tic = time.time() + self._last_health_check = time.time() - self._transport_thread = None - self._runtime_thread = None + self._transport_thread: Thread | None = None + self._runtime_thread: Thread | None = None self._transport_queue = Queue() self._wait_event = Event() @@ -79,13 +85,23 @@ def __init__( else: sslopt = {'ca_certs': cacert} - self._transport = WsTransport( - url=url, - event_callback=self._transport_callback, - sslopt=sslopt, - get_cookie=get_cookie, - get_header=get_header, - ) + self._get_cookie = get_cookie + self._get_header = get_header + + self._transport: WsTransport = None + + def _new_transport(): + self._transport = WsTransport( + url=url, + event_callback=self._transport_callback, + sslopt=sslopt, + get_cookie=get_cookie, + get_header=get_header, + max_ping_interval=self._max_ping_interval + ) + + self._new_transport = _new_transport + self._new_transport() self.subscription_controller = SubscriptionController(send_payload=self.send, subscription_resolver=self._subscription_resolver) @@ -129,12 +145,30 @@ def _new_runtime_thread(self): self._runtime_thread.daemon = True self._runtime_thread.start() + def _stop_transport_thread(self) -> bool: + try: + self._transport.stop() + if self._transport_thread is None: + return True + + _LOGGER.debug(f'{self}: Joining transport thread') + + self._transport_thread.join(self._connection_timeout) + is_alive = self._transport_thread.is_alive() + self._transport_thread = None + return not is_alive + except Exception as e: + _LOGGER.error(f'{self}: Failed to stop transport thread: {e}') + # TODO: decide what to do if transport disconnect fails + + return False + def start(self): if self.state != WsState.STOPPED: return if self._runtime_thread is not None and self._runtime_thread.is_alive(): - _LOGGER.error(f'{self}: Runtime thread is not stopped') + _LOGGER.error(f'{self}: Runtime thread must be stopped and joined before starting') return self.state = WsState.STARTING @@ -156,18 +190,17 @@ def stop(self): # TODO: decide which thread should stop first - transport or runtime self.state = WsState.STOPPING - try: - self._transport.disconnect() - if self._transport_thread is not None: - self._transport_thread.join(self._connection_timeout) - except Exception as e: - _LOGGER.error(f'{self}: Failed to disconnect: {e}') - # TODO: decide what to do if transport disconnect fails + self._stop_transport_thread() self._running = False if self._runtime_thread is not None: self._runtime_thread.join(self._connection_timeout) + if self._runtime_thread.is_alive(): + _LOGGER.error(f'{self}: Runtime thread failed to stop, abandoning...') + + self._runtime_thread = None + self.state = WsState.STOPPED def send(self, payload: str) -> bool: @@ -185,6 +218,34 @@ def send_json(self, payload: Union[List, Dict]) -> bool: # pragma: no cover def is_running(self) -> bool: return self._running + def set_last_heartbeat(self, value: float): + self._last_heartbeat = value + + def hard_reset(self) -> None: + _LOGGER.info(f'{self}: Hard reset') + + if threading.current_thread() in [self._runtime_thread, self._transport_thread]: + raise RuntimeError(f'{self}: Hard reset called from Runtime or Transport thread. Ensure it is called from a separate thread') + + self.stop() + self.start() + + def restart_transport(self): + if threading.current_thread() == self._transport_thread: + raise RuntimeError(f'{self}: Resetting transport thread called from within transport thread. Ensure it is called from a separate thread') + + success = self._stop_transport_thread() + if not success: + _LOGGER.error(f'{self}: Failed to stop transport thread, abandoning...') + self._transport_thread = None + + self._transport.set_degraded(True) + self._new_transport() + self._new_transport_thread() + + def reset_transport_websocket(self): + self._transport.reset() + def __str__(self): return f'{self.__class__.__qualname__}({self._state})' @@ -217,6 +278,73 @@ def _maintain_subscriptions(self): self.subscription_controller.reconcile_bindings() + def check_should_restart(self): + # If WSA is not ready, we don't try to fix health + if not self._transport.is_ready(): + return False + + # If we're not either open or authenticated, we let WSA handle the reconnect first + if self._state not in [WsState.OPEN, WsState.AUTHENTICATED]: + return False + + ping_ok = self._transport.check_ping(self._max_ping_interval) + if not ping_ok: + _LOGGER.warning( + f'{self}: Last WebSocket ping happened {self._transport.get_time_since_last_ping():.2f} seconds ago, ' + f'exceeding the max ping interval of {self._max_ping_interval}.' + ) + return False + + # cookie_ok = self._transport.check_cookie() + # if cookie_ok is not None and not cookie_ok: + # _LOGGER.warning(f'{self}: Cookie check failed') + # return True + + heartbeat_ok = True + if self._last_heartbeat is not None: + diff = abs(time.time() - self._last_heartbeat) + if diff > self._max_ping_interval: + _LOGGER.warning( + f'{self}: Last heartbeat happened {diff:.2f} seconds ago, ' + f'exceeding the max ping interval of {self._max_ping_interval}.' + ) + heartbeat_ok = False + + if heartbeat_ok: + return False + + return True + + def health_check(self) -> bool: + if not self.check_should_restart(): + return True + + _LOGGER.warning(f'{self}: Health check failed, resetting transport websocket') + self.state = WsState.DEGRADED + self.set_authenticated(False) + + if not self._running: # return early if runtime got stopped in the meantime + return False + + self.reset_transport_websocket() + + # if wait_until(lambda: self._state == self._ready_state, timeout=self._connection_timeout): + # _LOGGER.info(f'Health recovered by resetting transport WebSocket') + # return True + + # if not self._running: # return early if runtime got stopped in the meantime + # return False + # + # _LOGGER.warning(f'{self}: Resetting transport websocket failed, restarting transport') + # self.restart_transport() + # + # if wait_until(lambda: self._state == self._ready_state, timeout=self._connection_timeout): + # _LOGGER.info(f'Health recovered by resetting transport thread') + # return True + # + # _LOGGER.error(f'{self}: Resetting transport websocket failed') + return False + def _cycle(self): _LOGGER.info(f'{self}: Runtime thread started ({tname()})') while self._running: @@ -225,22 +353,40 @@ def _cycle(self): self._process_transport_queue() + if time.time() - self._last_health_check > 10: + self._last_health_check = time.time() + self.health_check() + + # if time.time() - self._last_tic > 5: + # if self._transport.is_ready(): + # _LOGGER.debug(f'{self}: Sending tic') + # self._transport.send('tic') + # self._last_tic = time.time() + self._wait_event.clear() self._wait_event.wait(self._cycle_interval) - # final pass through the router queue to flush any remaining events - self._process_transport_queue() - # final pass through the subscription controller to carry out final unsubscribe events - self.subscription_controller.reconcile_bindings() + # if not stopped or closed yet, attempt to do one last pass before the thread dies + if self.state not in [WsState.STOPPED, WsState.CLOSED]: + # final pass through the transport queue to flush any remaining events + self._process_transport_queue() + + # final pass through the subscription controller to carry out final unsubscribe events + self.subscription_controller.reconcile_bindings() + _LOGGER.info(f'{self}: Runtime thread stopped ({tname()})') def _process_transport_queue(self): + retry_events = [] while not self._transport_queue.empty(): te = self._transport_queue.get() try: self._handle_transport_event(te) except Exception as e: _LOGGER.error(f'{self}: Exception processing transport event: {exception_to_string(e)} for {te}') + retry_events.append(te) + for event in retry_events: + self._transport_queue.put(event) def _handle_transport_event(self, te: TransportEvent): if isinstance(te, TransportOpened): @@ -248,11 +394,9 @@ def _handle_transport_event(self, te: TransportEvent): elif isinstance(te, TransportClosed): self._handle_on_close(te.wsa, te.close_status_code, te.close_msg) elif isinstance(te, TransportError): - self._handle_on_error(te.wsa, te.error) + self._handle_on_error(te.wsa, te.exception) elif isinstance(te, TransportMessage): self._handle_on_message(te.wsa, te.message) - elif isinstance(te, TransportCritical): - self._handle_on_critical(te.wsa, te.exception) elif isinstance(te, TransportReconnect): self._handle_on_reconnect(te.wsa) else: @@ -283,6 +427,8 @@ def _handle_on_message(self, wsa: WebSocketApp, message): # pragma: no cover def _handle_on_open(self, wsa: WebSocketApp): _LOGGER.info(f'{self}: Connection open') + # self._last_heartbeat = time.time() + self._last_heartbeat = None self.state = WsState.OPEN ## connected = True self._sink.emit(events.WsOpen()) @@ -290,29 +436,29 @@ def _handle_on_error(self, wsa: WebSocketApp, exception: Exception): # pragma: _LOGGER.error(f'{self}: on_error: {exception}') if str(exception) in ['Connection to remote host was lost.', 'No connection could be made because the target machine actively refused it']: self.state = WsState.DEGRADED + self.set_authenticated(False) self._sink.emit(events.WsError(error=exception)) def _handle_on_reconnect(self, wsa: WebSocketApp): # pragma: no cover _LOGGER.error(f'{self}: on_reconnect') + # self._router.reset_last_heartbeat() + # self._last_heartbeat = time.time() + self._last_heartbeat = None self.set_authenticated(False) self.state = WsState.OPEN self._sink.emit(events.WsReconnect()) - def _handle_on_critical(self, wsa: WebSocketApp, exception): # pragma: no cover - self._sink.emit(events.WsCritical(exception=exception)) - if self._restart_on_critical: - # TODO: following comment is not true - no restarting in on_close takes place - # if restart_on_close is set, restarting will happen in on_close callback - self.hard_reset(restart=not self._restart_on_close) - def _handle_on_close(self, wsa: WebSocketApp, close_status_code, close_msg): _LOGGER.info(f'{self}: on_close') + # self._router.reset_last_heartbeat() + self._last_heartbeat = None self.subscription_controller.invalidate_subscriptions() self._sink.emit(events.WsClose(close_status_code=close_status_code, close_msg=close_msg)) + # if we're not connected we shouldn't need to do anything - if self.state not in [self._ready_state, WsState.OPEN, WsState.STOPPING]: ## not self._connected: - _LOGGER.info(f'{self}: Unexpected on_close event while not open') - return + # if self.state not in [self._ready_state, WsState.OPEN, WsState.STOPPING]: ## not self._connected: + # _LOGGER.info(f'{self}: Unexpected on_close event while not open') + # return if close_status_code is not None or close_msg is not None: # this means an error try: @@ -323,69 +469,9 @@ def _handle_on_close(self, wsa: WebSocketApp, close_status_code, close_msg): _LOGGER.error(f'{self}: on_close error: {close_status_code} | {msg}') else: # otherwise it's a close success confirmation - _LOGGER.info(f'{self}: Connection closed') - - if self.state == WsState.STOPPING: - _LOGGER.info(f'{self}: Gracefully closed') - - self.state = WsState.CLOSED ## self._connected = False - - # if not self._running: # if close happened due to shutting down, acknowledge and return - # _LOGGER.info(f'{self}: Gracefully closed') - # return - - def hard_reset(self, restart: bool = False) -> None: - """ - Performs a hard reset of the WebSocket connection. - - This method forcefully closes the current WebSocketApp connection and optionally restarts it. It is - used to handle scenarios where the connection is unresponsive or encounters a critical error. - - This method cannot be called from the transport thread. - - Parameters: - restart (bool, optional): Specifies whether to restart the WebSocketApp connection after resetting. - Defaults to False. - - Note: - - Closes the current WebSocketApp connection, if any, and clears related resources. - - If the WebSocketApp is unresponsive or cannot be closed, it will be abandoned and the connection will be reset. - - If 'restart' is True, the method attempts to re-establish a new WebSocketApp connection after resetting. - """ - _LOGGER.info(f'{self}: Hard reset, {restart=}, {self._wsa is None=}') - - # we want the websocket closed before reconnecting - if self._wsa is not None: - if not self._connected: - # this means that we get a bad error before we could even get a connection confirmation - # which shouldn't really happen, but if it does the original WebSocketApp is bad - # so let's drop it anyway. - self._wsa = None - restart = True # since we've abandoned the WebSocketApp, let's ensure we restart + if self.state == WsState.STOPPING: + _LOGGER.info(f'{self}: Gracefully closed') else: - _LOGGER.info(f'{self}: Hard reset is closing the WebSocketApp') - # check if current thread is the same as _transport_thread - if threading.current_thread() == self._transport_thread: - raise RuntimeError(f'{self}: Hard reset called from transport thread. Ensure it is started from a separate thread') - - self._wsa.close(status=STATUS_UNEXPECTED_CONDITION) - - # ensure the websocket is closed and abandoned - if not wait_until(lambda: self._wsa is None, f'{self}: Hard reset close timeout', timeout=self._timeout): - _LOGGER.warning(f'{self}: Abandoning current WebSocketApp that cannot be closed: {self._wsa}') - self._wsa = None - restart = True # since we've abandoned the WebSocketApp, let's ensure we restart - - # in some cases, closing the websocket will cause the restart elsewhere, therefore only closing it is enough - if restart: - _LOGGER.info(f'{self}: Forced restart') - self._reconnect() - - def _reconnect(self): - with self._reconnect_lock: - if self.state not in [WsState.OPEN, self._ready_state]: ## not self._has_active_connection(): - _LOGGER.info(f'{self}: Reconnecting') - self._try_connecting() - - if self._has_active_connection(): - self._on_reconnect() \ No newline at end of file + _LOGGER.info(f'{self}: Connection closed') + + self.state = WsState.CLOSED ## self._connected = False \ No newline at end of file diff --git a/ibind/ws_v2/ws_transport.py b/ibind/ws_v2/ws_transport.py index 8e7d90d3..e8d5030d 100644 --- a/ibind/ws_v2/ws_transport.py +++ b/ibind/ws_v2/ws_transport.py @@ -1,11 +1,13 @@ +import time from datetime import datetime -from typing import Callable, Any +from typing import Callable, Any, cast from pydantic import BaseModel, ConfigDict, Field -from websocket import WebSocketApp +from websocket import WebSocketApp, STATUS_UNEXPECTED_CONDITION, STATUS_NORMAL +from ibind import ExternalBrokerError from support.logs import project_logger -from support.py_utils import exception_to_string, tname +from support.py_utils import exception_to_string, tname, wait_until, UNDEFINED _LOGGER = project_logger(__file__) @@ -33,19 +35,16 @@ class TransportClosed(TransportEvent): class TransportError(TransportEvent): model_config = ConfigDict(frozen=True, extra="forbid", arbitrary_types_allowed=True) - error: Exception + exception: Exception class TransportMessage(TransportEvent): message: str + class TransportReconnect(TransportEvent): ... -class TransportCritical(TransportEvent): - model_config = ConfigDict(frozen=True, extra="forbid", arbitrary_types_allowed=True) - exception: Exception - class WsTransport(): @@ -58,6 +57,8 @@ def __init__( get_header: Callable = _NOOP, ping_interval: float = 10, ping_timeout: float = 10, + max_ping_interval: float = 20, + connection_timeout: float = 5, ): self._url = url self._event_callback = event_callback @@ -65,10 +66,128 @@ def __init__( self._get_header = get_header self._ping_interval = ping_interval self._ping_timeout = ping_timeout + self._max_ping_interval = max_ping_interval + self._connection_timeout = connection_timeout self._sslopt = sslopt self._running = False self._wsa: WebSocketApp | None = None + self._degraded = False + self._tname = None + + def disconnect(self): + if self._wsa is None: + _LOGGER.info(f'{self}: WSA is None, skipping disconnect') + return + self._wsa.close(status=STATUS_NORMAL, timeout=self._connection_timeout) + + def stop(self): + _LOGGER.info(f'{self}: Stopping') + self._running = False + self.disconnect() + + def reset(self) -> bool: + if tname() == self._tname: + raise RuntimeError(f'{self}: Resetting transport thread called from within transport thread. Ensure it is called from a separate thread') + + if self._wsa is None: + _LOGGER.info(f'{self}: WSA is None, skipping reset') + return False + + _LOGGER.info(f'{self}: Reset') + + self._wsa.close(status=STATUS_UNEXPECTED_CONDITION, timeout=self._connection_timeout) + + if not wait_until(lambda: self._wsa is None, f'{self}: WebSocket reset close timeout', timeout=self._connection_timeout * 2): + _LOGGER.warning(f'{self}: Abandoning current WebSocketApp that cannot be closed: {self._wsa}') + self._wsa = None + + wait_until(lambda: self._wsa is not None, f'{self}: WebSocket recreation timeout', timeout=self._connection_timeout * 2) + + return self._wsa is not None + + def check_ping(self, max_interval: float = None) -> bool: + """ + Checks the last ping response time of the WebSocketApp connection. + + Verifies whether the last ping response from the WebSocketApp was within the acceptable time interval + defined by 'max_ping_interval' parameter. If the last ping response exceeds this interval, a hard reset of the connection is triggered. + + Returns: + bool: True if the last ping was within the acceptable interval or if the WebSocketApp is not connected, + False if the ping interval was exceeded and a hard reset was initiated. + + Note: + - A ping interval exceeding 'max_ping_interval' indicates potential issues with the WebsocketApp connection. + """ + if self._wsa is None: + return True + + if self._wsa.last_pong_tm == 0: + return True + + if max_interval is None: + max_interval = self._max_ping_interval + + return self.get_time_since_last_ping() <= max_interval + + def get_time_since_last_ping(self) -> float: + return abs(time.time() - self._wsa.last_pong_tm) + + def fetch_cookie(self): + """ + Using UNDEFINED since _get_cookie could in fact return a None, and they mean different things + """ + try: + return self._get_cookie() + except Exception as e: + if isinstance(e, TimeoutError): + _LOGGER.info(f'{self}: Timeout retrieving cookie') + return UNDEFINED + if isinstance(e, ExternalBrokerError): + if e.status_code == 401: + _LOGGER.info(f'{self}: Failed to retrieve cookie due to lack of authentication') + return UNDEFINED + _LOGGER.error(f'{self}: Failed to retrieve cookie: {exception_to_string(e)}') + return UNDEFINED + + def check_cookie(self) -> bool: + cookie = self.fetch_cookie() + if cookie is UNDEFINED: + return False + + if cookie != self._cookie: + _LOGGER.warning(f'{self}: Cookie changed, current: {cookie}, previous: {self._cookie}') + return False + return True + + def set_degraded(self, value): + self._degraded = value + + def is_ready(self) -> bool: + return self._wsa is not None and self._wsa.ready and self._wsa.sock is not None and self._wsa.sock.sock is not None + + def send(self, payload: str) -> bool: + if not self.is_ready(): + raise RuntimeError(f'{self}: WSA socket is not ready') + + try: + self._wsa.send(payload) + except Exception as e: + if 'Connection is already closed' in str(e): + _LOGGER.error(f'{self}: Connection closed while sending payload: {payload}') + else: + _LOGGER.exception(f'{self}: Sending payload failed: {payload}\n{exception_to_string(e)}') + return False + + return True + + def __str__(self): + return f'{self.__class__.__qualname__}({f"degraded:{tname()}" if self._degraded else ""})' + + # ====================== + # == Transport Thread == + # ====================== def _wrap_callback(self, f): def wrapped_f(ws, *args, **kwargs): @@ -80,32 +199,54 @@ def wrapped_f(ws, *args, **kwargs): return wrapped_f def _on_open(self, wsa: WebSocketApp): + if self._degraded: + return + if not self.check_cookie(): + self._wsa.close(status=STATUS_UNEXPECTED_CONDITION, timeout=self._connection_timeout) + return self._event_callback(TransportOpened(wsa=wsa)) def _on_message(self, wsa: WebSocketApp, message): + if self._degraded: + return self._event_callback(TransportMessage(wsa=wsa, message=message)) def _on_close(self, wsa: WebSocketApp, close_status_code, close_msg): + if self._degraded: + return self._event_callback(TransportClosed(wsa=wsa, close_status_code=close_status_code, close_msg=close_msg)) def _on_error(self, wsa: WebSocketApp, error): - self._event_callback(TransportError(wsa=wsa, error=error)) + if self._degraded: + return + self._event_callback(TransportError(wsa=wsa, exception=error)) def _on_reconnect(self, wsa: WebSocketApp): + if self._degraded: + return + if not self.check_cookie(): + self._wsa.close(status=STATUS_UNEXPECTED_CONDITION, timeout=self._connection_timeout) + return self._event_callback(TransportReconnect(wsa=wsa)) def new_wsa(self): - try: - cookie = self._get_cookie() - except Exception as e: - _LOGGER.error(f'{self}: Failed to retrieve cookie: {exception_to_string(e)}') - cookie = None + cookie = self.fetch_cookie() + if cookie is UNDEFINED: + return None + self._cookie = cookie + if cookie is not None: + _LOGGER.info(f'{self}: Current cookie: {cookie}') try: - header = self._get_header() + self._header = self._get_header() except Exception as e: _LOGGER.error(f'{self}: Failed to retrieve header: {exception_to_string(e)}') header = None + return None + + if not self._running: + # Transport got stopped between invocation of new_wsa and creating one + return None wsa = WebSocketApp( url=self._url, @@ -114,53 +255,47 @@ def new_wsa(self): on_close=self._wrap_callback(self._on_close), on_error=self._wrap_callback(self._on_error), on_reconnect=self._wrap_callback(self._on_reconnect), - cookie=cookie, - header=header, + cookie=self._cookie, + header=self._header, ) - self._wsa = wsa - - def send(self, payload: str) -> bool: - if not self._wsa.ready: - raise RuntimeError(f'{self}: WSA socket is not ready') - - try: - self._wsa.send(payload) - except Exception as e: - if 'Connection is already closed' in str(e): - _LOGGER.error(f'{self}: Connection closed while sending payload: {payload}') - else: - _LOGGER.exception(f'{self}: Sending payload failed: {payload}\n{exception_to_string(e)}') - return False - - return True + return wsa def connect(self): - _LOGGER.debug(f'{self}: Transport thread started ({tname()})') + _LOGGER.info(f'{self}: Transport thread started ({tname()})') - if self._wsa is None: - self.new_wsa() - - try: - # the timeout is set to a little sooner than the interval - self._wsa.run_forever(ping_interval=self._ping_interval, ping_timeout=self._ping_interval * 0.95, sslopt=self._sslopt, reconnect=3) - - except ValueError as e: - if 'url is invalid' in str(e): - _LOGGER.error(f'{self}: URL is invalid: {self._url}') - except Exception as e: - _LOGGER.exception(f'{self}: Unexpected error while running WebSocketApp: {e}') - self._event_callback(TransportCritical(wsa=self._wsa, exception=e)) - finally: - self._wsa = None + self._tname = tname() - _LOGGER.debug(f'{self}: Transport thread stopped ({tname()})') + self._running = True - # if self._restart_on_close and self._running: - # self._reconnect() + while self._running: + # status, reason = probe_ws_reachability(self._url, sslopt=self._sslopt, timeout=3) + # _LOGGER.debug(f'{self}: Probe result: {status}, {reason}') + # if status != ReachabilityStatus.OK: + # time.sleep(5) + # continue - def disconnect(self): - self._wsa.close() + if self._wsa is None: + wsa = self.new_wsa() + if wsa is None: + time.sleep(1) + continue + self._wsa = wsa - def __str__(self): - return f'{self.__class__.__qualname__}()' \ No newline at end of file + try: + self._wsa.run_forever( + ping_interval=self._ping_interval, + ping_timeout=self._ping_interval * 0.95, # the timeout is set to a little sooner than the interval + sslopt=self._sslopt, + reconnect=cast(int, self._connection_timeout) # floats are accepted, hence casting only for linter + ) + _LOGGER.info(f'{self}: WSA run_forever stopped gracefully') + except Exception as e: + if 'url is invalid' in str(e): + _LOGGER.error(f'{self}: URL is invalid: {self._url}') + else: + _LOGGER.exception(f'{self}: Unexpected error while running WebSocketApp: {e}') + finally: + self._wsa = None + + _LOGGER.info(f'{self}: Transport thread stopped ({tname()})') \ No newline at end of file From 40e8ea505e03d52efbc11d76a7be018f55a68320 Mon Sep 17 00:00:00 2001 From: voyz Date: Fri, 1 May 2026 10:02:34 +0200 Subject: [PATCH 09/32] chore: updated wait_until's time usage from time.time to time.monotonic --- ibind/support/py_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ibind/support/py_utils.py b/ibind/support/py_utils.py index e5f42958..ecfa90e0 100644 --- a/ibind/support/py_utils.py +++ b/ibind/support/py_utils.py @@ -279,8 +279,8 @@ def wait_until(condition: callable, timeout_message: str = None, timeout: float bool: True if the condition becomes True within the timeout period, False otherwise. """ - deadline = time.time() + timeout - while time.time() < deadline: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: if condition(): return True time.sleep(sleep) From 87dbcfff10de2b5cc366e1e2106b39c37c0dbd6c Mon Sep 17 00:00:00 2001 From: voyz Date: Fri, 1 May 2026 11:50:01 +0200 Subject: [PATCH 10/32] feat(ws_v2): added WsDegraded event, removed WsReconnect and unified marking as degraded across WsRuntime --- ibind/ibkr_ws_v2/ibkr_ws_client_v2.py | 12 +-- ibind/support/py_utils.py | 5 +- ibind/ws_v2/events.py | 6 +- ibind/ws_v2/subscriptions.py | 4 +- ibind/ws_v2/ws_runtime.py | 110 ++++++++++++++------------ ibind/ws_v2/ws_transport.py | 13 ++- 6 files changed, 75 insertions(+), 75 deletions(-) diff --git a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py index 15961794..eb0d3e0f 100644 --- a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py +++ b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py @@ -2,13 +2,12 @@ from typing import Union import var -from ibind import ExternalBrokerError, IbkrClient +from ibind import IbkrClient from ibkr_ws_v2 import ibkr_events from ibkr_ws_v2.ibkr_router import IbkrRouter from ibkr_ws_v2.ibkr_subscriptions import IbkrSubscriptionResolver from support.logs import project_logger -from ws_v2 import events -from ws_v2.events import EventSink, LogSink, CallbackSink, CompositeSink, Router, NoopSink +from ws_v2.events import EventSink, LogSink, CallbackSink, CompositeSink, Router from ws_v2.subscriptions import Subscription, SubscriptionResolver, SubscriptionHandle from ws_v2.ws_runtime import WsRuntime, WsState @@ -90,11 +89,6 @@ def _register_internal_callbacks(self): self._internal_sink.on(ibkr_events.AuthenticationStatus, self._on_authentication_status) self._internal_sink.on(ibkr_events.WaitingForSession, self._set_unauthenticated) self._internal_sink.on(ibkr_events.System, self._on_system) - # self._internal_sink.on(events.WsReconnect, self._on_open) - # self._internal_sink.on(events.WsOpen, self._on_open) - - def _on_open(self, event: events.WsOpen): - _LOGGER.info(f'{self}: WSA opened, cookie: {self._get_cookie()}') def _set_unauthenticated(self, _): self._runtime.set_authenticated(False) @@ -150,4 +144,4 @@ def is_running(self) -> bool: return self._runtime.is_running() def __str__(self): - return f'{self.__class__.__qualname__}()' \ No newline at end of file + return f'{self.__class__.__qualname__}()' diff --git a/ibind/support/py_utils.py b/ibind/support/py_utils.py index ecfa90e0..988e4a42 100644 --- a/ibind/support/py_utils.py +++ b/ibind/support/py_utils.py @@ -20,6 +20,9 @@ S = TypeVar('S') OneOrMany = Union[S, List[S]] +def NOOP(): + return None + _LOGGER = project_logger(__file__) @@ -351,4 +354,4 @@ def warn_if_late_load(*args, **kwargs): dotenv.load_dotenv = warn_if_late_load except ImportError: - pass # dotenv is not installed, nothing to patch \ No newline at end of file + pass # dotenv is not installed, nothing to patch diff --git a/ibind/ws_v2/events.py b/ibind/ws_v2/events.py index a17c520b..2ba9e8e8 100644 --- a/ibind/ws_v2/events.py +++ b/ibind/ws_v2/events.py @@ -65,11 +65,11 @@ class WsAuthenticated(ClientInternalEvent): ... -class WsReady(ClientInternalEvent): +class WsDegraded(ClientInternalEvent): ... -class WsReconnect(ClientInternalEvent): +class WsReady(ClientInternalEvent): ... @@ -143,4 +143,4 @@ def route(self, raw_message) -> OneOrMany[WsEvent]: ... def __str__(self): - return f'{self.__class__.__qualname__}()' \ No newline at end of file + return f'{self.__class__.__qualname__}()' diff --git a/ibind/ws_v2/subscriptions.py b/ibind/ws_v2/subscriptions.py index afda7dbc..5596941e 100644 --- a/ibind/ws_v2/subscriptions.py +++ b/ibind/ws_v2/subscriptions.py @@ -260,7 +260,7 @@ def invalidate_subscriptions(self): for binding_key, binding in self._bindings.items(): if binding.status == BindingStatus.ACTIVE: binding.status = BindingStatus.DEGRADED - _LOGGER.info(f'{self}: Invalidated subscription: {binding}') + _LOGGER.info(f'{self}: Invalidated: {binding}') def is_subscription_active(self, binding_key: str) -> Optional[bool]: if not self.has_subscription(binding_key): @@ -353,4 +353,4 @@ def wait_for(self, binding_key: str, timeout: float | None = None) -> bool: self._condition.wait(remaining) def __str__(self): - return f'{self.__class__.__qualname__}()' \ No newline at end of file + return f'{self.__class__.__qualname__}()' diff --git a/ibind/ws_v2/ws_runtime.py b/ibind/ws_v2/ws_runtime.py index a4a59fe5..a76ab891 100644 --- a/ibind/ws_v2/ws_runtime.py +++ b/ibind/ws_v2/ws_runtime.py @@ -10,7 +10,7 @@ from websocket import WebSocketApp from support.logs import project_logger -from support.py_utils import wait_until, tname, VerboseEnum, exception_to_string, TimeoutLock, OneOrMany +from support.py_utils import wait_until, tname, VerboseEnum, exception_to_string, TimeoutLock, OneOrMany, NOOP from ws_v2 import events from ws_v2.events import WsEvent, EventSink, Router from ws_v2.subscriptions import SubscriptionController, SubscriptionResolver @@ -18,8 +18,6 @@ _LOGGER = project_logger(__file__) -_NOOP = lambda: None - _DEFAULT_TIMEOUT = 5 @@ -35,6 +33,16 @@ class WsState(VerboseEnum): STOPPING = 'STOPPING', +def make_sslopt(cacert: Union[str, bool]): + if not (cacert is False or Path(cacert).exists()): + raise ValueError(f'Cacert must be a valid Path or False, found: {cacert}') + + if cacert is None or not cacert: + return {'cert_reqs': ssl.CERT_NONE} + else: + return {'ca_certs': cacert} + + class WsRuntime(): def __init__( self, @@ -46,21 +54,18 @@ def __init__( ready_state: Literal[WsState.OPEN, WsState.AUTHENTICATED] = WsState.OPEN, cacert: Union[str, bool] = False, connection_timeout: float = _DEFAULT_TIMEOUT, - restart_on_close: bool = True, - restart_on_critical: bool = True, - get_cookie: Callable = _NOOP, - get_header: Callable = _NOOP, max_ping_interval: float = 20, + get_cookie: Callable = NOOP, + get_header: Callable = NOOP, ): + if ready_state not in [WsState.OPEN, WsState.AUTHENTICATED]: + raise ValueError(f'Invalid ready_state: {ready_state}, must be either {WsState.OPEN} or {WsState.AUTHENTICATED}') self._url = url self._cycle_interval = cycle_interval self._sink = sink self._router = router - self._subscription_resolver = subscription_resolver self._ready_state = ready_state self._connection_timeout = connection_timeout - self._restart_on_close = restart_on_close - self._restart_on_critical = restart_on_critical self._max_ping_interval = max_ping_interval self._state = WsState.STOPPED @@ -77,33 +82,26 @@ def __init__( self._state_lock = TimeoutLock(60) - if not (cacert is False or Path(cacert).exists()): - raise ValueError(f'{self}: cacert must be a valid Path or False') - - if cacert is None or not cacert: - sslopt = {'cert_reqs': ssl.CERT_NONE} - else: - sslopt = {'ca_certs': cacert} + self._sslopt = make_sslopt(cacert) self._get_cookie = get_cookie self._get_header = get_header - self._transport: WsTransport = None - - def _new_transport(): - self._transport = WsTransport( - url=url, - event_callback=self._transport_callback, - sslopt=sslopt, - get_cookie=get_cookie, - get_header=get_header, - max_ping_interval=self._max_ping_interval - ) + self._transport: WsTransport | None = None - self._new_transport = _new_transport self._new_transport() - self.subscription_controller = SubscriptionController(send_payload=self.send, subscription_resolver=self._subscription_resolver) + self.subscription_controller = SubscriptionController(send_payload=self.send, subscription_resolver=subscription_resolver) + + def _new_transport(self): + self._transport = WsTransport( + url=self._url, + event_callback=self._transport_callback, + sslopt=self._sslopt, + get_cookie=self._get_cookie, + get_header=self._get_header, + max_ping_interval=self._max_ping_interval + ) @property def state(self): @@ -129,8 +127,16 @@ def set_authenticated(self, value: bool): self._sink.emit(events.WsAuthenticated()) self.state = WsState.AUTHENTICATED - if value == False: - self.subscription_controller.invalidate_subscriptions() + if value == False and self._state == self._ready_state: + self.state_degraded() + + def state_degraded(self): + was_already_degraded = self._state == WsState.DEGRADED + self.state = WsState.DEGRADED + self.subscription_controller.invalidate_subscriptions() + + if not was_already_degraded: + self._sink.emit(events.WsDegraded()) def get_authenticated(self) -> bool: return self._authenticated @@ -320,8 +326,7 @@ def health_check(self) -> bool: return True _LOGGER.warning(f'{self}: Health check failed, resetting transport websocket') - self.state = WsState.DEGRADED - self.set_authenticated(False) + self.state_degraded() if not self._running: # return early if runtime got stopped in the meantime return False @@ -418,42 +423,41 @@ def _handle_on_message(self, wsa: WebSocketApp, message): # pragma: no cover try: self.subscription_controller.observe(event) except Exception as e: - _LOGGER.error(f'{self}: Exception observing subscription: {exception_to_string(e)} for {event}') + _LOGGER.error(f'{self}: Exception observing subscription for {event}: {exception_to_string(e)}') try: self._sink.emit(event) except Exception as e: - _LOGGER.error(f'{self}: Exception propagating event: {exception_to_string(e)} for {event}') + _LOGGER.error(f'{self}: Exception propagating event {event}: {exception_to_string(e)}') def _handle_on_open(self, wsa: WebSocketApp): _LOGGER.info(f'{self}: Connection open') # self._last_heartbeat = time.time() self._last_heartbeat = None - self.state = WsState.OPEN ## connected = True + self.state = WsState.OPEN + if self._state != self._ready_state: + self.set_authenticated(False) self._sink.emit(events.WsOpen()) - def _handle_on_error(self, wsa: WebSocketApp, exception: Exception): # pragma: no cover + def _handle_on_error(self, wsa: WebSocketApp, exception: Exception): _LOGGER.error(f'{self}: on_error: {exception}') if str(exception) in ['Connection to remote host was lost.', 'No connection could be made because the target machine actively refused it']: - self.state = WsState.DEGRADED + self.state_degraded() self.set_authenticated(False) self._sink.emit(events.WsError(error=exception)) - def _handle_on_reconnect(self, wsa: WebSocketApp): # pragma: no cover + def _handle_on_reconnect(self, wsa: WebSocketApp): _LOGGER.error(f'{self}: on_reconnect') - # self._router.reset_last_heartbeat() # self._last_heartbeat = time.time() self._last_heartbeat = None - self.set_authenticated(False) self.state = WsState.OPEN - self._sink.emit(events.WsReconnect()) + if self._state != self._ready_state: + self.set_authenticated(False) + self._sink.emit(events.WsOpen()) def _handle_on_close(self, wsa: WebSocketApp, close_status_code, close_msg): _LOGGER.info(f'{self}: on_close') - # self._router.reset_last_heartbeat() self._last_heartbeat = None - self.subscription_controller.invalidate_subscriptions() - self._sink.emit(events.WsClose(close_status_code=close_status_code, close_msg=close_msg)) # if we're not connected we shouldn't need to do anything # if self.state not in [self._ready_state, WsState.OPEN, WsState.STOPPING]: ## not self._connected: @@ -468,10 +472,12 @@ def _handle_on_close(self, wsa: WebSocketApp, close_status_code, close_msg): _LOGGER.error(f'{self}: on_close error: {close_status_code} | {msg}') - else: # otherwise it's a close success confirmation - if self.state == WsState.STOPPING: - _LOGGER.info(f'{self}: Gracefully closed') - else: - _LOGGER.info(f'{self}: Connection closed') + elif self.state == WsState.STOPPING: + _LOGGER.info(f'{self}: Gracefully closed') + else: + _LOGGER.info(f'{self}: Connection closed') - self.state = WsState.CLOSED ## self._connected = False \ No newline at end of file + self.set_authenticated(False) + self.subscription_controller.invalidate_subscriptions() + self.state = WsState.CLOSED + self._sink.emit(events.WsClose(close_status_code=close_status_code, close_msg=close_msg)) diff --git a/ibind/ws_v2/ws_transport.py b/ibind/ws_v2/ws_transport.py index e8d5030d..8961fdec 100644 --- a/ibind/ws_v2/ws_transport.py +++ b/ibind/ws_v2/ws_transport.py @@ -7,12 +7,10 @@ from ibind import ExternalBrokerError from support.logs import project_logger -from support.py_utils import exception_to_string, tname, wait_until, UNDEFINED +from support.py_utils import exception_to_string, tname, wait_until, UNDEFINED, NOOP _LOGGER = project_logger(__file__) -_NOOP = lambda: None - class TransportEvent(BaseModel): model_config = ConfigDict(frozen=True, extra="forbid", arbitrary_types_allowed=True) @@ -53,8 +51,8 @@ def __init__( url: str, event_callback: Callable, sslopt: dict[str, Any], - get_cookie: Callable = _NOOP, - get_header: Callable = _NOOP, + get_cookie: Callable = NOOP, + get_header: Callable = NOOP, ping_interval: float = 10, ping_timeout: float = 10, max_ping_interval: float = 20, @@ -241,7 +239,6 @@ def new_wsa(self): self._header = self._get_header() except Exception as e: _LOGGER.error(f'{self}: Failed to retrieve header: {exception_to_string(e)}') - header = None return None if not self._running: @@ -287,7 +284,7 @@ def connect(self): ping_interval=self._ping_interval, ping_timeout=self._ping_interval * 0.95, # the timeout is set to a little sooner than the interval sslopt=self._sslopt, - reconnect=cast(int, self._connection_timeout) # floats are accepted, hence casting only for linter + reconnect=cast(int, self._connection_timeout) # floats are de facto valid, casting only for the linter ) _LOGGER.info(f'{self}: WSA run_forever stopped gracefully') except Exception as e: @@ -298,4 +295,4 @@ def connect(self): finally: self._wsa = None - _LOGGER.info(f'{self}: Transport thread stopped ({tname()})') \ No newline at end of file + _LOGGER.info(f'{self}: Transport thread stopped ({tname()})') From d693b109834eadeabf1083efee8a55489e54a3cf Mon Sep 17 00:00:00 2001 From: voyz Date: Sat, 2 May 2026 12:52:23 +0200 Subject: [PATCH 11/32] feat(ws_v2): added automated handling of MarketHistory unsubscriptions --- examples/ws_04_ws_v2.py | 14 +++++---- ibind/ibkr_ws_v2/ibkr_events.py | 12 ++++++-- ibind/ibkr_ws_v2/ibkr_router.py | 21 ++++++++------ ibind/ibkr_ws_v2/ibkr_subscriptions.py | 39 ++++++++++++++++++++++++-- ibind/ibkr_ws_v2/ibkr_ws_client_v2.py | 39 ++++++++++++++++++++++++-- 5 files changed, 102 insertions(+), 23 deletions(-) diff --git a/examples/ws_04_ws_v2.py b/examples/ws_04_ws_v2.py index 9cc8fbe9..d27147b7 100644 --- a/examples/ws_04_ws_v2.py +++ b/examples/ws_04_ws_v2.py @@ -41,17 +41,19 @@ as_sub = AccountSummarySubscription(account_id=account_id) al_sub = AccountLedgerSubscription(account_id=account_id) md_sub = MarketDataSubscription(conid='265598', fields=("31", "84", "86")) +mh_sub = MarketHistorySubscription(conid='265598') or_sub = OrdersSubscription() # pl_sub = PriceLadderSubscription(conid='265598', account_id=account_id, exchange='SMART') pnl_sub = PnlSubscription() tr_sub = TradesSubscription() subs = [ - as_sub, - al_sub, - md_sub, - or_sub, - pnl_sub, - tr_sub + # as_sub, + # al_sub, + # md_sub, + mh_sub, + # or_sub, + # pnl_sub, + # tr_sub ] sub_handles: List[SubscriptionHandle] = [] diff --git a/ibind/ibkr_ws_v2/ibkr_events.py b/ibind/ibkr_ws_v2/ibkr_events.py index aa92069f..5c486baa 100644 --- a/ibind/ibkr_ws_v2/ibkr_events.py +++ b/ibind/ibkr_ws_v2/ibkr_events.py @@ -11,6 +11,7 @@ class IbkrWsKey(Enum): UNCLASSIFIED = 'UNCLASSIFIED' GENERIC = 'GENERIC' UNSUBSCRIPTION = 'UNSUBSCRIPTION' + SERVER_ID = 'SERVER_ID' # unsolicited ACCOUNT_UPDATE = 'ACCOUNT_UPDATE' @@ -125,7 +126,7 @@ class AuthenticationStatus(WsEvent): class Unsubscription(WsEvent): key: IbkrWsKey = IbkrWsKey.UNSUBSCRIPTION target_key: IbkrWsKey - conid: int | None = None + conid: str | None = None class AccountSummary(WsEvent): @@ -153,6 +154,13 @@ class MarketHistory(WsEvent): data: dict +class ServerId(WsEvent): + key: IbkrWsKey = IbkrWsKey.SERVER_ID + conid: str + server_id: str + target_key: IbkrWsKey + + class Orders(WsEvent): key: IbkrWsKey = IbkrWsKey.ORDERS data: dict @@ -173,4 +181,4 @@ class Pnl(WsEvent): class Trades(WsEvent): key: IbkrWsKey = IbkrWsKey.TRADES - data: dict \ No newline at end of file + data: dict diff --git a/ibind/ibkr_ws_v2/ibkr_router.py b/ibind/ibkr_ws_v2/ibkr_router.py index d4c6220a..f12dd1d9 100644 --- a/ibind/ibkr_ws_v2/ibkr_router.py +++ b/ibind/ibkr_ws_v2/ibkr_router.py @@ -37,7 +37,7 @@ def __init__( ): self._log_raw_messages = log_raw_messages self._unwrap_market_data = unwrap_market_data - self._server_id_conid_pairs: Dict[IbkrWsKey, Dict[str, int]] = defaultdict(dict) + self._server_id_conid_pairs: Dict[IbkrWsKey, Dict[str, str]] = defaultdict(dict) def _preprocess_market_data_message(self, data: dict) -> OneOrMany[WsEvent]: """ @@ -59,10 +59,14 @@ def _preprocess_market_data_message(self, data: dict) -> OneOrMany[WsEvent]: def _preprocess_market_history_message(self, data: dict) -> OneOrMany[WsEvent]: mh_server_id_conid_pairs = self._server_id_conid_pairs[IbkrWsKey.MARKET_HISTORY] + events = [] + conid = extract_conid(data) if 'serverId' in data and data['serverId'] not in mh_server_id_conid_pairs: - mh_server_id_conid_pairs[data['serverId']] = extract_conid(data) + mh_server_id_conid_pairs[data['serverId']] = str(conid) + events.append(ibkr_events.ServerId(conid=str(conid), server_id=data['serverId'], target_key=IbkrWsKey.MARKET_HISTORY)) - return ibkr_events.MarketHistory(conid=str(data['conid']), data=data) + events.append(ibkr_events.MarketHistory(conid=str(conid), data=data)) + return events def _preprocess_account_ledger(self, data): events = [] @@ -91,6 +95,7 @@ def _preprocess_account_summary(self, data): if 'AccountCode' not in summary or 'value' not in summary['AccountCode']: _LOGGER.error(f'{self}: Account code not found in account summary: {summary}') return [] + account_id = summary['AccountCode']['value'] summary['timestamp'] = timestamp @@ -163,15 +168,13 @@ def _handle_market_history_unsubscribe(self, data) -> OneOrMany[WsEvent]: mh_server_id_conid_pairs = self._server_id_conid_pairs[IbkrWsKey.MARKET_HISTORY] if server_id in mh_server_id_conid_pairs: conid = mh_server_id_conid_pairs[server_id] - _LOGGER.info(f'{self}: Received unsubscribing confirmation for server_id={server_id!r}/conid={conid!r}.') + _LOGGER.info(f'{self}: Received unsubscribing confirmation for server_id={server_id!r}, conid={conid!r}.') if conid is not None: - return ibkr_events.Unsubscription(target_key=IbkrWsKey.MARKET_HISTORY, conid=conid) + return ibkr_events.Unsubscription(target_key=IbkrWsKey.MARKET_HISTORY, conid=str(conid)) _LOGGER.warning(f'{self}: Unknown conid={conid!r}. Cannot mark the subscription as unsubscribed.') else: - _LOGGER.warning( - f'{self}: Received unsubscribing confirmation for unknown server_id={server_id!r}. Existing server_ids: {mh_server_id_conid_pairs}' - ) + _LOGGER.warning(f'{self}: Received unsubscribing confirmation for unknown server_id={server_id!r}. Existing server_ids: {mh_server_id_conid_pairs}') return [] def _handle_message_without_topic(self, message: dict) -> OneOrMany[WsEvent]: @@ -236,4 +239,4 @@ def route(self, raw_message: str) -> OneOrMany[WsEvent]: def __str__(self): - return f'{self.__class__.__qualname__}()' \ No newline at end of file + return f'{self.__class__.__qualname__}()' diff --git a/ibind/ibkr_ws_v2/ibkr_subscriptions.py b/ibind/ibkr_ws_v2/ibkr_subscriptions.py index bcb16814..14a4246d 100644 --- a/ibind/ibkr_ws_v2/ibkr_subscriptions.py +++ b/ibind/ibkr_ws_v2/ibkr_subscriptions.py @@ -1,7 +1,10 @@ import json from typing import Tuple +from pydantic import Field + from ibkr_ws_v2.ibkr_events import IbkrWsKey, AccountLedger, MarketData, MarketHistory, Orders, PriceLadder, Pnl, Trades, Unsubscription, AccountSummary +from support.py_utils import filter_none from ws_v2.subscriptions import Subscription, SubscriptionResolver @@ -143,12 +146,31 @@ def binding_key(self): class MarketHistorySubscription(IbkrSubscription): key: IbkrWsKey = IbkrWsKey.MARKET_HISTORY conid: str + exchange: str = None + period: str = None + bar: str = None + outside_rth: bool = None + source: str = None + format: str = None + server_id: list = Field(default_factory=list) # uses list to allow writing despite frozen model def subscribe_payload(self) -> str: - ... + data = { + 'exchange': self.exchange, + 'period': self.period, + 'bar': self.bar, + 'outside_rth': self.outside_rth, + 'source': self.source, + 'format': self.format, + } + data = filter_none(data) + return f'smh+{self.conid}+{json.dumps(data, separators=(',', ':'))}' def unsubscribe_payload(self) -> str: - ... + server_id = self.get_server_id() + if server_id is None: + raise RuntimeError(f'{self}: Unsubscribing from market history for conid={self.conid!r} without server_id. MarketHistorySubscription must have server_id set before unsubscribing.') + return f'umh+{server_id}' @property def confirms_subscribe(self) -> bool: @@ -158,6 +180,17 @@ def confirms_subscribe(self) -> bool: def confirms_unsubscribe(self) -> bool: return True + def set_server_id(self, server_id): + if self.has_server_id(): + raise ValueError('Server ID already set') + self.server_id.append(server_id) + + def has_server_id(self) -> bool: + return len(self.server_id) > 0 + + def get_server_id(self): + return self.server_id[0] + def binding_key(self): return make_binding_key(self.key, conid=self.conid) @@ -256,4 +289,4 @@ def confirms_unsubscribe(self) -> bool: return False def binding_key(self): - return make_binding_key(self.key) \ No newline at end of file + return make_binding_key(self.key) diff --git a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py index eb0d3e0f..8b099ddd 100644 --- a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py +++ b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py @@ -1,11 +1,12 @@ import json -from typing import Union +from collections import defaultdict +from typing import Union, List, Dict import var -from ibind import IbkrClient +from ibind import IbkrClient, IbkrWsKey from ibkr_ws_v2 import ibkr_events from ibkr_ws_v2.ibkr_router import IbkrRouter -from ibkr_ws_v2.ibkr_subscriptions import IbkrSubscriptionResolver +from ibkr_ws_v2.ibkr_subscriptions import IbkrSubscriptionResolver, MarketHistorySubscription from support.logs import project_logger from ws_v2.events import EventSink, LogSink, CallbackSink, CompositeSink, Router from ws_v2.subscriptions import Subscription, SubscriptionResolver, SubscriptionHandle @@ -85,10 +86,14 @@ def __init__( get_header=self._get_header, ) + self._mh_subscriptions: List[MarketHistorySubscription] = [] + self._conid_server_id_pairs: Dict[IbkrWsKey, Dict[str, str]] = defaultdict(dict) + def _register_internal_callbacks(self): self._internal_sink.on(ibkr_events.AuthenticationStatus, self._on_authentication_status) self._internal_sink.on(ibkr_events.WaitingForSession, self._set_unauthenticated) self._internal_sink.on(ibkr_events.System, self._on_system) + self._internal_sink.on(ibkr_events.ServerId, self._on_server_id) def _set_unauthenticated(self, _): self._runtime.set_authenticated(False) @@ -105,6 +110,12 @@ def _on_system(self, event: ibkr_events.System): if 'hb' in event.data: self._runtime.set_last_heartbeat(int(event.data['hb']) / 1000) + def _on_server_id(self, event: ibkr_events.ServerId): + self._conid_server_id_pairs[event.target_key][event.conid] = event.server_id + for subscription in self._mh_subscriptions: + if subscription.key == event.target_key and subscription.conid == event.conid and not subscription.has_server_id(): + subscription.set_server_id(event.server_id) + def _get_cookie(self): # try: status = self._ibkr_client.tickle() @@ -135,11 +146,33 @@ def hard_reset(self): self._runtime.hard_reset() def subscribe(self, subscription: Subscription) -> SubscriptionHandle: + if isinstance(subscription, MarketHistorySubscription): + self._mh_subscriptions.append(subscription) return self._runtime.subscription_controller.subscribe(subscription) def unsubscribe(self, subscription: Subscription) -> SubscriptionHandle: + if isinstance(subscription, MarketHistorySubscription): + self._handle_mh_unsubscription(subscription) return self._runtime.subscription_controller.unsubscribe(subscription) + def get_server_id(self, key: IbkrWsKey, conid: str) -> str: + return self._conid_server_id_pairs[key][conid] + + def _handle_mh_unsubscription(self, subscription: MarketHistorySubscription): + if subscription.has_server_id(): + return + server_id = self._conid_server_id_pairs.get(subscription.key, {}).get(subscription.conid) + if server_id is None: + raise RuntimeError(f'{self}: Unsubscribing from market history for conid={subscription.conid!r} without server_id. Could not find server_id in memory. Ensure at least one MarketHistory event is received before unsubscribing.') + + _LOGGER.warning( + f'{self}: Unsubscribing from market history for conid={subscription.conid!r} without server_id. Setting from memory: {server_id!r}. ' + f'Unsubscribe using the same Subscription instance that was used for subscribing to avoid this warning, ' + f'or set it manually before calling unsubscribe by using ' + f'`subscription.set_server_id(ibkr_ws_client.get_server_id(IbkrWsKey.MARKET_HISTORY, conid))`' + ) + subscription.set_server_id(server_id) + def is_running(self) -> bool: return self._runtime.is_running() From 8aeb3e7779fb3336d48e54e0f19294fafb36a926 Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 3 May 2026 12:41:43 +0200 Subject: [PATCH 12/32] feat(ws_v2): added AsyncQueue and subscription expires --- examples/ws_04_ws_v2.py | 4 +- ibind/ibkr_ws_v2/ibkr_router.py | 2 +- ibind/ibkr_ws_v2/ibkr_ws_client_v2.py | 17 +++- ibind/ws_v2/events.py | 126 +++++++++++++++++++++++++- ibind/ws_v2/subscriptions.py | 42 +++++---- ibind/ws_v2/ws_runtime.py | 58 ++++++++---- ibind/ws_v2/ws_transport.py | 2 +- 7 files changed, 202 insertions(+), 49 deletions(-) diff --git a/examples/ws_04_ws_v2.py b/examples/ws_04_ws_v2.py index d27147b7..161f859a 100644 --- a/examples/ws_04_ws_v2.py +++ b/examples/ws_04_ws_v2.py @@ -15,7 +15,7 @@ from typing import List from ibind import ibind_logs_initialize -from ibkr_ws_v2.ibkr_subscriptions import MarketDataSubscription, OrdersSubscription, AccountLedgerSubscription, AccountSummarySubscription, PriceLadderSubscription, PnlSubscription, TradesSubscription +from ibkr_ws_v2.ibkr_subscriptions import MarketDataSubscription, OrdersSubscription, AccountLedgerSubscription, AccountSummarySubscription, PriceLadderSubscription, PnlSubscription, TradesSubscription, MarketHistorySubscription from ibkr_ws_v2.ibkr_ws_client_v2 import IbkrWsClientV2 from ws_v2.subscriptions import SubscriptionHandle @@ -40,7 +40,7 @@ as_sub = AccountSummarySubscription(account_id=account_id) al_sub = AccountLedgerSubscription(account_id=account_id) -md_sub = MarketDataSubscription(conid='265598', fields=("31", "84", "86")) +md_sub = MarketDataSubscription(conid='265598', fields=("31", "84", "86"), expiry_seconds=30) mh_sub = MarketHistorySubscription(conid='265598') or_sub = OrdersSubscription() # pl_sub = PriceLadderSubscription(conid='265598', account_id=account_id, exchange='SMART') diff --git a/ibind/ibkr_ws_v2/ibkr_router.py b/ibind/ibkr_ws_v2/ibkr_router.py index f12dd1d9..1ce99fb8 100644 --- a/ibind/ibkr_ws_v2/ibkr_router.py +++ b/ibind/ibkr_ws_v2/ibkr_router.py @@ -10,7 +10,7 @@ from support.py_utils import UNDEFINED, OneOrMany from ws_v2.events import WsEvent -_LOGGER = project_logger(__file__) +_LOGGER = project_logger('websocket') def parse_raw_message(raw_message: str): diff --git a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py index 8b099ddd..4ea96ef3 100644 --- a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py +++ b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py @@ -3,16 +3,17 @@ from typing import Union, List, Dict import var +from base.queue_controller import QueueController from ibind import IbkrClient, IbkrWsKey from ibkr_ws_v2 import ibkr_events from ibkr_ws_v2.ibkr_router import IbkrRouter from ibkr_ws_v2.ibkr_subscriptions import IbkrSubscriptionResolver, MarketHistorySubscription from support.logs import project_logger -from ws_v2.events import EventSink, LogSink, CallbackSink, CompositeSink, Router +from ws_v2.events import EventSink, LogSink, CallbackSink, CompositeSink, Router, AsyncSink from ws_v2.subscriptions import Subscription, SubscriptionResolver, SubscriptionHandle from ws_v2.ws_runtime import WsRuntime, WsState -_LOGGER = project_logger(__file__) +_LOGGER = project_logger('websocket') _DEFAULT_CYCLE_INTERVAL = 0.25 @@ -34,6 +35,7 @@ def __init__( sink: EventSink = None, router: Router = None, subscription_resolver: SubscriptionResolver = None, + synchronous_output_events: bool = False, ): self._account_id = account_id @@ -56,8 +58,10 @@ def __init__( self._use_oauth = use_oauth self._recreate_subscriptions_on_reconnect = recreate_subscriptions_on_reconnect + self._queue_controller = QueueController[IbkrWsKey]() + self._queue_controller.register_queues(list(IbkrWsKey)) + if sink is None: - # self._queue_controller = QueueController[IbkrWsKey]() # self._queue_controller.register_queues(['CLIENT_INTERNAL', 'IBKR']) # sink = QueueSink(queue_controller=self._queue_controller) @@ -66,7 +70,11 @@ def __init__( self._internal_sink = CallbackSink() self._register_internal_callbacks() - sink = CompositeSink(self._internal_sink, sink) + + if synchronous_output_events: + _LOGGER.info(f'{self}: Output events will be emitted synchronously from the runtime thread') + else: + sink = AsyncSink(sink=sink) if router is None: router = IbkrRouter() @@ -80,6 +88,7 @@ def __init__( ready_state=WsState.AUTHENTICATED, cacert=cacert, sink=sink, + internal_sink=self._internal_sink, router=router, subscription_resolver=subscription_resolver, get_cookie=self._get_cookie, diff --git a/ibind/ws_v2/events.py b/ibind/ws_v2/events.py index 2ba9e8e8..21879cc3 100644 --- a/ibind/ws_v2/events.py +++ b/ibind/ws_v2/events.py @@ -1,14 +1,17 @@ +import threading from collections import defaultdict from datetime import datetime +from queue import Queue, Full, Empty +from threading import Thread, Event from typing import Hashable, Protocol, Callable, TypeVar, List, Dict from pydantic import BaseModel, ConfigDict, Field from base.queue_controller import QueueController from support.logs import project_logger -from support.py_utils import OneOrMany +from support.py_utils import OneOrMany, exception_to_string -_LOGGER = project_logger(__file__) +_LOGGER = project_logger('websocket') # ====================== @@ -89,7 +92,13 @@ class WsError(ClientInternalEvent): class EventSink(Protocol): def emit(self, event: "WsEvent") -> None: - ... + pass + + def start(self): + pass + + def stop(self): + pass class LogSink: @@ -114,7 +123,13 @@ def on(self, event_type: type[WsEvent], callback: Callable[[T], None]) -> None: def emit(self, event: WsEvent) -> None: for callback in self._callbacks[type(event)]: - callback(event) + try: + callback(event) + except Exception as e: + _LOGGER.error(f'{self}: Exception emitting event to callback: {exception_to_string(e)}') + + def __str__(self): + return f'{self.__class__.__qualname__}()' class QueueSink: @@ -131,7 +146,108 @@ def __init__(self, *sinks: EventSink): def emit(self, event: WsEvent) -> None: for sink in self._sinks: - sink.emit(event) + try: + sink.emit(event) + except Exception as e: + _LOGGER.error(f'{self}: Exception emitting event to sink: {exception_to_string(e)}') + + def __str__(self): + return f'{self.__class__.__qualname__}()' + + +class AsyncSink: + def __init__( + self, + sink: EventSink, + maxsize: int = 10_000, + drop_oldest: bool = True, + stop_timeout: float = 5, + cycle_interval: float = 0.25, + ): + self._sink = sink + self._queue = Queue(maxsize=maxsize) + self._drop_oldest = drop_oldest + self._stop_timeout = stop_timeout + self._cycle_interval = cycle_interval + + self._running = False + self._thread: Thread | None = None + self._wait_event = Event() + + def start(self): + if self._running: + return + + self._running = True + self._thread = Thread(target=self._cycle, name="ws_sink_thread", daemon=True) + self._thread.start() + + def stop(self) -> bool: + if not self._running: + return True + + if threading.current_thread() == self._thread: + raise RuntimeError(f'{self}: Stopping async sink called from within async sink thread. Ensure it is stopped from a separate thread') + + self._running = False + self._wait_event.set() + + succeeded = True + if self._thread is not None: + self._thread.join(self._stop_timeout) + succeeded = not self._thread.is_alive() + + self._thread = None + + if self._queue.qsize() > 0: + _LOGGER.warning(f'{self}: Event queue not empty when stopping; discarding {self._queue.qsize()} events') + + return succeeded + + def emit(self, event: WsEvent) -> None: + try: + self._queue.put_nowait(event) + self._wait_event.set() + return + except Full: + if not self._drop_oldest: + _LOGGER.warning(f'{self}: Event queue full; dropping newest event: {event}') + return + + try: + dropped = self._queue.get_nowait() + _LOGGER.warning(f'{self}: Event queue full; dropping oldest event: {dropped}') + except Empty: + pass + + try: + self._queue.put_nowait(event) + self._wait_event.set() + except Full: + _LOGGER.warning(f'{self}: Event queue still full; dropping event: {event}') + + def _consume_queue(self): + while True: + try: + event = self._queue.get_nowait() + except Empty: + break + + try: + self._sink.emit(event) + except Exception as e: + _LOGGER.error(f'{self}: Exception emitting event to sink: {exception_to_string(e)}') + + def _cycle(self): + while self._running: + self._wait_event.clear() + self._wait_event.wait(self._cycle_interval) + self._consume_queue() + + self._consume_queue() + + def __str__(self): + return f'{self.__class__.__qualname__}({self._queue.qsize()})' # ============== diff --git a/ibind/ws_v2/subscriptions.py b/ibind/ws_v2/subscriptions.py index 5596941e..455219a8 100644 --- a/ibind/ws_v2/subscriptions.py +++ b/ibind/ws_v2/subscriptions.py @@ -10,12 +10,13 @@ from ibind.support.py_utils import exception_to_string from ws_v2.events import WsEvent -_LOGGER = project_logger(__file__) +_LOGGER = project_logger('websocket') class Subscription(BaseModel): model_config = ConfigDict(frozen=True) key: Hashable + expiry_seconds: int | None = None @property def topic(self) -> str: @@ -61,7 +62,7 @@ class BindingStatus(Enum): FAILED = "FAILED" DEGRADED = "DEGRADED" UNSUBSCRIBED = "UNSUBSCRIBED" - RECONNECTING = "RECONNECTING" + EXPIRED = "EXPIRED" class Binding(BaseModel): @@ -171,10 +172,21 @@ def observe(self, event: WsEvent): self._confirm_unsubscribed(binding_key) def reconcile_binding(self, binding: Binding): + now = time.time() + + + if binding.status == binding.intent: + time_since_last_attempt = now - binding.last_attempt + if binding.subscription.expiry_seconds is None or time_since_last_attempt < binding.subscription.expiry_seconds: + return + + _LOGGER.info(f'{self}: Subscription expired: {binding.subscription} after {time_since_last_attempt:.1f} seconds') + self._update_status(binding, BindingStatus.EXPIRED) + # wait until timeout has passed since last attempt - if binding.last_attempt + self._subscription_timeout > time.time(): + if binding.last_attempt + self._subscription_timeout > now: return - binding.last_attempt = time.time() + binding.last_attempt = now # if we've exceeded the number of retries, mark the subscription as failed if binding.attempts >= self._subscription_retries: @@ -205,9 +217,6 @@ def reconcile_binding(self, binding: Binding): def reconcile_bindings(self): with self._condition: for binding in self._bindings.values(): - if binding.status == binding.intent: - continue - self.reconcile_binding(binding) def subscribe(self, subscription: Subscription) -> SubscriptionHandle: @@ -260,7 +269,8 @@ def invalidate_subscriptions(self): for binding_key, binding in self._bindings.items(): if binding.status == BindingStatus.ACTIVE: binding.status = BindingStatus.DEGRADED - _LOGGER.info(f'{self}: Invalidated: {binding}') + self._update_status(binding, BindingStatus.DEGRADED) + # _LOGGER.info(f'{self}: Invalidated: {binding}') def is_subscription_active(self, binding_key: str) -> Optional[bool]: if not self.has_subscription(binding_key): @@ -298,6 +308,12 @@ def get_active_subscriptions(self): if self.is_subscription_active(binding_key) } + def _update_status(self, binding: Binding, status: BindingStatus): + binding.status = status + binding.attempts = 0 + _LOGGER.info(f'{self}: Updated subscription status: {binding.subscription.binding_key()} -> {status.value}') + self._condition.notify_all() + def _confirm_subscribed(self, binding_key: str): if not self.has_subscription(binding_key): _LOGGER.warning(f'{self}: Unknown subscription {binding_key} - cannot update status to {BindingStatus.ACTIVE.value}') @@ -308,10 +324,7 @@ def _confirm_subscribed(self, binding_key: str): if binding.status == BindingStatus.ACTIVE or binding.intent == BindingStatus.UNSUBSCRIBED: return - binding.status = BindingStatus.ACTIVE - binding.attempts = 0 - _LOGGER.info(f'{self}: Updated subscription status: {binding_key} -> {BindingStatus.ACTIVE.value}') - self._condition.notify_all() + self._update_status(binding, BindingStatus.ACTIVE) def _confirm_unsubscribed(self, binding_key: str): if not self.has_subscription(binding_key): @@ -323,10 +336,7 @@ def _confirm_unsubscribed(self, binding_key: str): if binding.status == BindingStatus.UNSUBSCRIBED or binding.intent == BindingStatus.ACTIVE: return - binding.status = BindingStatus.UNSUBSCRIBED - binding.attempts = 0 - _LOGGER.info(f'{self}: Updated subscription status: {binding_key} -> {BindingStatus.UNSUBSCRIBED.value}') - self._condition.notify_all() + self._update_status(binding, BindingStatus.UNSUBSCRIBED) def wait_for(self, binding_key: str, timeout: float | None = None) -> bool: deadline = None if timeout is None else time.monotonic() + timeout diff --git a/ibind/ws_v2/ws_runtime.py b/ibind/ws_v2/ws_runtime.py index a76ab891..8e194afd 100644 --- a/ibind/ws_v2/ws_runtime.py +++ b/ibind/ws_v2/ws_runtime.py @@ -12,11 +12,11 @@ from support.logs import project_logger from support.py_utils import wait_until, tname, VerboseEnum, exception_to_string, TimeoutLock, OneOrMany, NOOP from ws_v2 import events -from ws_v2.events import WsEvent, EventSink, Router +from ws_v2.events import WsEvent, EventSink, Router, CallbackSink from ws_v2.subscriptions import SubscriptionController, SubscriptionResolver from ws_v2.ws_transport import WsTransport, TransportEvent, TransportOpened, TransportClosed, TransportError, TransportMessage, TransportReconnect -_LOGGER = project_logger(__file__) +_LOGGER = project_logger('websocket') _DEFAULT_TIMEOUT = 5 @@ -49,6 +49,7 @@ def __init__( url: str, cycle_interval: float, sink: EventSink, + internal_sink: CallbackSink, router: Router, subscription_resolver: SubscriptionResolver, ready_state: Literal[WsState.OPEN, WsState.AUTHENTICATED] = WsState.OPEN, @@ -63,6 +64,7 @@ def __init__( self._url = url self._cycle_interval = cycle_interval self._sink = sink + self._internal_sink = internal_sink self._router = router self._ready_state = ready_state self._connection_timeout = connection_timeout @@ -116,7 +118,7 @@ def state(self, value): self._state = value if self._state == self._ready_state: - self._sink.emit(events.WsReady()) + self._emit(events.WsReady()) def set_authenticated(self, value: bool): if value != self._authenticated: @@ -124,7 +126,7 @@ def set_authenticated(self, value: bool): self._authenticated = value if value and self._state == WsState.OPEN: - self._sink.emit(events.WsAuthenticated()) + self._emit(events.WsAuthenticated()) self.state = WsState.AUTHENTICATED if value == False and self._state == self._ready_state: @@ -136,7 +138,7 @@ def state_degraded(self): self.subscription_controller.invalidate_subscriptions() if not was_already_degraded: - self._sink.emit(events.WsDegraded()) + self._emit(events.WsDegraded()) def get_authenticated(self) -> bool: return self._authenticated @@ -181,6 +183,7 @@ def start(self): self._running = True self._new_runtime_thread() + self._sink.start() connection_success = wait_until(lambda: self._state == self._ready_state, f'{self}: Starting timeout', timeout=self._connection_timeout) return connection_success @@ -189,6 +192,9 @@ def stop(self): if self.state == WsState.STOPPED: return + if threading.current_thread() == self._runtime_thread: + raise RuntimeError(f'{self}: Stopping runtime called from within runtime thread. Ensure it is called from a separate thread') + # wait until one more pass of the runtime thread has occurred to allow unsubscriptions to complete wait_until(lambda: not self._wait_event.is_set(), timeout=self._connection_timeout) self._wait_event.set() @@ -198,6 +204,7 @@ def stop(self): self.state = WsState.STOPPING self._stop_transport_thread() + self._running = False if self._runtime_thread is not None: self._runtime_thread.join(self._connection_timeout) @@ -207,6 +214,8 @@ def stop(self): self._runtime_thread = None + self._sink.stop() + self.state = WsState.STOPPED def send(self, payload: str) -> bool: @@ -425,10 +434,7 @@ def _handle_on_message(self, wsa: WebSocketApp, message): # pragma: no cover except Exception as e: _LOGGER.error(f'{self}: Exception observing subscription for {event}: {exception_to_string(e)}') - try: - self._sink.emit(event) - except Exception as e: - _LOGGER.error(f'{self}: Exception propagating event {event}: {exception_to_string(e)}') + self._emit(event) def _handle_on_open(self, wsa: WebSocketApp): _LOGGER.info(f'{self}: Connection open') @@ -437,23 +443,24 @@ def _handle_on_open(self, wsa: WebSocketApp): self.state = WsState.OPEN if self._state != self._ready_state: self.set_authenticated(False) - self._sink.emit(events.WsOpen()) - - def _handle_on_error(self, wsa: WebSocketApp, exception: Exception): - _LOGGER.error(f'{self}: on_error: {exception}') - if str(exception) in ['Connection to remote host was lost.', 'No connection could be made because the target machine actively refused it']: - self.state_degraded() - self.set_authenticated(False) - self._sink.emit(events.WsError(error=exception)) + self._emit(events.WsOpen()) def _handle_on_reconnect(self, wsa: WebSocketApp): - _LOGGER.error(f'{self}: on_reconnect') + _LOGGER.info(f'{self}: on_reconnect') # self._last_heartbeat = time.time() self._last_heartbeat = None self.state = WsState.OPEN if self._state != self._ready_state: self.set_authenticated(False) - self._sink.emit(events.WsOpen()) + self._emit(events.WsOpen()) # we emit Open since reconnect pretty much equivalent + + def _handle_on_error(self, wsa: WebSocketApp, exception: Exception): + _LOGGER.error(f'{self}: on_error: {exception}') + if str(exception) in ['Connection to remote host was lost.', 'No connection could be made because the target machine actively refused it']: + self.state_degraded() + self.set_authenticated(False) + self._emit(events.WsError(error=exception)) + def _handle_on_close(self, wsa: WebSocketApp, close_status_code, close_msg): _LOGGER.info(f'{self}: on_close') @@ -480,4 +487,15 @@ def _handle_on_close(self, wsa: WebSocketApp, close_status_code, close_msg): self.set_authenticated(False) self.subscription_controller.invalidate_subscriptions() self.state = WsState.CLOSED - self._sink.emit(events.WsClose(close_status_code=close_status_code, close_msg=close_msg)) + self._emit(events.WsClose(close_status_code=close_status_code, close_msg=close_msg)) + + def _emit(self, event:WsEvent): + try: + self._internal_sink.emit(event) + except Exception as e: + _LOGGER.error(f'{self}: Internal sink exception for {event}: {exception_to_string(e)}') + + try: + self._sink.emit(event) + except Exception as e: + _LOGGER.error(f'{self}: External sink exception for {event}: {exception_to_string(e)}') diff --git a/ibind/ws_v2/ws_transport.py b/ibind/ws_v2/ws_transport.py index 8961fdec..5783dbba 100644 --- a/ibind/ws_v2/ws_transport.py +++ b/ibind/ws_v2/ws_transport.py @@ -9,7 +9,7 @@ from support.logs import project_logger from support.py_utils import exception_to_string, tname, wait_until, UNDEFINED, NOOP -_LOGGER = project_logger(__file__) +_LOGGER = project_logger('websocket') class TransportEvent(BaseModel): From db5caa65f5be605ff182e4bc71cbef9f83adaf54 Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 3 May 2026 12:42:08 +0200 Subject: [PATCH 13/32] refactor(logs): project_logger now accepts both filepath and a normal string --- ibind/support/logs.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ibind/support/logs.py b/ibind/support/logs.py index 71ed2f5a..edde1a96 100644 --- a/ibind/support/logs.py +++ b/ibind/support/logs.py @@ -44,7 +44,12 @@ def project_logger(filepath=None): Returns: logging.Logger: The project-specific logger instance. """ - return logging.getLogger('ibind' + (f'.{Path(filepath).stem}' if filepath is not None else '')) + logger_name = 'ibind' + if filepath is not None: + child = Path(filepath).stem if isinstance(filepath, Path) else str(filepath) + logger_name += f'.{child}' + + return logging.getLogger(logger_name) _LOGGER = project_logger() From 68db65e590e502b5707f3918b11dbad46a659014 Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 3 May 2026 15:33:28 +0200 Subject: [PATCH 14/32] fix(ws_v2): small fixes --- examples/ws_04_ws_v2.py | 8 +-- ibind/ibkr_ws_v2/ibkr_router.py | 6 +-- ibind/ibkr_ws_v2/ibkr_ws_client_v2.py | 3 ++ ibind/ws_v2/events.py | 6 ++- ibind/ws_v2/subscriptions.py | 77 ++++++++++++++------------- ibind/ws_v2/ws_runtime.py | 45 +++++++++------- ibind/ws_v2/ws_transport.py | 4 +- 7 files changed, 81 insertions(+), 68 deletions(-) diff --git a/examples/ws_04_ws_v2.py b/examples/ws_04_ws_v2.py index 161f859a..ead13f55 100644 --- a/examples/ws_04_ws_v2.py +++ b/examples/ws_04_ws_v2.py @@ -49,8 +49,8 @@ subs = [ # as_sub, # al_sub, - # md_sub, - mh_sub, + md_sub, + # mh_sub, # or_sub, # pnl_sub, # tr_sub @@ -85,7 +85,7 @@ # unsub_handles.append(handle) # # for handle in unsub_handles: -# handle.wait() +# handle.wait(10) # time.sleep(5) ws_client.shutdown() @@ -117,4 +117,4 @@ # print('KeyboardInterrupt') # break # -# stop(None, None) \ No newline at end of file +# stop(None, None) diff --git a/ibind/ibkr_ws_v2/ibkr_router.py b/ibind/ibkr_ws_v2/ibkr_router.py index 1ce99fb8..d11f29b8 100644 --- a/ibind/ibkr_ws_v2/ibkr_router.py +++ b/ibind/ibkr_ws_v2/ibkr_router.py @@ -130,12 +130,12 @@ def _handle_subscribed_message(self, channel: str, data: dict) -> OneOrMany[WsEv return None def _handle_account_update(self, message, arguments) -> OneOrMany[WsEvent]: - _LOGGER.info(f'{self}: Account update: {arguments}') + # _LOGGER.info(f'{self}: Account update: {arguments}') return ibkr_events.AccountUpdate(data=arguments) def _handle_authentication_status(self, message, arguments) -> OneOrMany[WsEvent]: if 'authenticated' in arguments or 'competing' in arguments: - _LOGGER.info(f'{self}: Authentication status: {arguments}') + # _LOGGER.info(f'{self}: Authentication status: {arguments}') return ibkr_events.AuthenticationStatus(data=arguments, authenticated=arguments.get('authenticated'), competing=arguments.get('competing')) elif ( # expected status updates that we ignore arguments == {'message': ''} or @@ -159,7 +159,7 @@ def _handle_error(self, message) -> OneOrMany[WsEvent]: def _handle_notification(self, data) -> OneOrMany[WsEvent]: # pragma: no cover events = [] for notification in data: - _LOGGER.info(f'{self}: IBKR notification: {notification}') + # _LOGGER.info(f'{self}: IBKR notification: {notification}') events.append(ibkr_events.Notification(message=notification)) return events diff --git a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py index 4ea96ef3..125baed0 100644 --- a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py +++ b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py @@ -154,6 +154,9 @@ def shutdown(self): def hard_reset(self): self._runtime.hard_reset() + def reset_websocket_app(self): + self._runtime.reset_websocket_app() + def subscribe(self, subscription: Subscription) -> SubscriptionHandle: if isinstance(subscription, MarketHistorySubscription): self._mh_subscriptions.append(subscription) diff --git a/ibind/ws_v2/events.py b/ibind/ws_v2/events.py index 21879cc3..4886c6ea 100644 --- a/ibind/ws_v2/events.py +++ b/ibind/ws_v2/events.py @@ -9,7 +9,7 @@ from base.queue_controller import QueueController from support.logs import project_logger -from support.py_utils import OneOrMany, exception_to_string +from support.py_utils import OneOrMany, exception_to_string, tname _LOGGER = project_logger('websocket') @@ -179,7 +179,7 @@ def start(self): return self._running = True - self._thread = Thread(target=self._cycle, name="ws_sink_thread", daemon=True) + self._thread = Thread(target=self._cycle, name="async_sink_thread", daemon=True) self._thread.start() def stop(self) -> bool: @@ -239,12 +239,14 @@ def _consume_queue(self): _LOGGER.error(f'{self}: Exception emitting event to sink: {exception_to_string(e)}') def _cycle(self): + _LOGGER.info(f'{self}: AsyncSink thread started ({tname()})') while self._running: self._wait_event.clear() self._wait_event.wait(self._cycle_interval) self._consume_queue() self._consume_queue() + _LOGGER.info(f'{self}: AsyncSink thread stopped ({tname()})') def __str__(self): return f'{self.__class__.__qualname__}({self._queue.qsize()})' diff --git a/ibind/ws_v2/subscriptions.py b/ibind/ws_v2/subscriptions.py index 455219a8..f4d48d55 100644 --- a/ibind/ws_v2/subscriptions.py +++ b/ibind/ws_v2/subscriptions.py @@ -77,7 +77,6 @@ def done(self) -> bool: return self.status == self.intent def reset(self): - self.status = BindingStatus.NEW self.attempts = 0 self.last_attempt = 0 @@ -171,48 +170,51 @@ def observe(self, event: WsEvent): else: self._confirm_unsubscribed(binding_key) + def _make_attempt(self, binding:Binding): + subscription = binding.subscription + if binding.intent == BindingStatus.ACTIVE: + payload = subscription.subscribe_payload() + self._send(payload) + if not subscription.confirms_subscribe: + _LOGGER.info(f'{self}: Subscribed: {payload} without confirmation.') + self._confirm_subscribed(subscription.binding_key()) + + elif binding.intent == BindingStatus.UNSUBSCRIBED: + payload = subscription.unsubscribe_payload() + self._send(payload) + if not subscription.confirms_unsubscribe: + _LOGGER.info(f'{self}: Unsubscribed: {payload} without confirmation.') + self._confirm_unsubscribed(subscription.binding_key()) + def reconcile_binding(self, binding: Binding): now = time.time() + subscription = binding.subscription + if binding.status == binding.intent or binding.status == BindingStatus.FAILED: + if subscription.expiry_seconds is None: + return - if binding.status == binding.intent: time_since_last_attempt = now - binding.last_attempt - if binding.subscription.expiry_seconds is None or time_since_last_attempt < binding.subscription.expiry_seconds: + if time_since_last_attempt < subscription.expiry_seconds: return - _LOGGER.info(f'{self}: Subscription expired: {binding.subscription} after {time_since_last_attempt:.1f} seconds') + _LOGGER.info(f'{self}: Subscription expired: {subscription} after {time_since_last_attempt:.1f} seconds') self._update_status(binding, BindingStatus.EXPIRED) - # wait until timeout has passed since last attempt - if binding.last_attempt + self._subscription_timeout > now: - return - binding.last_attempt = now - # if we've exceeded the number of retries, mark the subscription as failed if binding.attempts >= self._subscription_retries: _LOGGER.info(f'{self}: Subscription failed after {self._subscription_retries} attempts: {binding}') - binding.status = BindingStatus.FAILED - binding.attempts = 0 - self._condition.notify_all() + self._update_status(binding, BindingStatus.FAILED) return - binding.attempts += 1 - - subscription = binding.subscription + # wait until timeout has passed since last attempt + if binding.last_attempt + self._subscription_timeout > now: + return - if binding.intent == BindingStatus.ACTIVE: - payload = subscription.subscribe_payload() - self._send(payload) - if not subscription.confirms_subscribe: - _LOGGER.info(f'{self}: Subscribed: {payload} without confirmation.') - self._confirm_subscribed(subscription.binding_key()) + binding.last_attempt = now + binding.attempts += 1 + self._make_attempt(binding) - elif binding.intent == BindingStatus.UNSUBSCRIBED: - payload = subscription.unsubscribe_payload() - self._send(payload) - if not subscription.confirms_unsubscribe: - _LOGGER.info(f'{self}: Unsubscribed: {payload} without confirmation.') - self._confirm_unsubscribed(subscription.binding_key()) def reconcile_bindings(self): with self._condition: @@ -266,16 +268,16 @@ def unsubscribe(self, subscription: Subscription) -> SubscriptionHandle: return SubscriptionHandle(self, subscription) def invalidate_subscriptions(self): - for binding_key, binding in self._bindings.items(): - if binding.status == BindingStatus.ACTIVE: - binding.status = BindingStatus.DEGRADED - self._update_status(binding, BindingStatus.DEGRADED) - # _LOGGER.info(f'{self}: Invalidated: {binding}') + with self._condition: + for binding_key, binding in self._bindings.items(): + if binding.status != BindingStatus.DEGRADED: + self._update_status(binding, BindingStatus.DEGRADED) def is_subscription_active(self, binding_key: str) -> Optional[bool]: - if not self.has_subscription(binding_key): - return False - return self._bindings[binding_key].status == BindingStatus.ACTIVE + with self._condition: + if not self.has_subscription(binding_key): + return False + return self._bindings[binding_key].status == BindingStatus.ACTIVE def has_active_subscriptions(self) -> bool: with self._condition: @@ -285,8 +287,7 @@ def has_active_subscriptions(self) -> bool: return False def has_subscription(self, binding_key: str) -> bool: - with self._condition: - return binding_key in self._bindings + return binding_key in self._bindings def get_status(self, binding_key: str) -> BindingStatus | None: with self._condition: @@ -309,9 +310,9 @@ def get_active_subscriptions(self): } def _update_status(self, binding: Binding, status: BindingStatus): + _LOGGER.info(f'{self}: Updated subscription status: {binding.subscription.binding_key()} {binding.status.value} -> {status.value}') binding.status = status binding.attempts = 0 - _LOGGER.info(f'{self}: Updated subscription status: {binding.subscription.binding_key()} -> {status.value}') self._condition.notify_all() def _confirm_subscribed(self, binding_key: str): diff --git a/ibind/ws_v2/ws_runtime.py b/ibind/ws_v2/ws_runtime.py index 8e194afd..b672a394 100644 --- a/ibind/ws_v2/ws_runtime.py +++ b/ibind/ws_v2/ws_runtime.py @@ -118,7 +118,12 @@ def state(self, value): self._state = value if self._state == self._ready_state: - self._emit(events.WsReady()) + self._websocket_ready() + + def _websocket_ready(self): + self._emit(events.WsReady()) + self._last_heartbeat = time.time() + _LOGGER.info(f'{self}: Websocket ready, setting last_heartbeat to {self._last_heartbeat}') def set_authenticated(self, value: bool): if value != self._authenticated: @@ -172,13 +177,15 @@ def _stop_transport_thread(self) -> bool: return False def start(self): - if self.state != WsState.STOPPED: + if self._state != WsState.STOPPED: return if self._runtime_thread is not None and self._runtime_thread.is_alive(): _LOGGER.error(f'{self}: Runtime thread must be stopped and joined before starting') return + _LOGGER.info(f'{self}: Starting runtime') + self.state = WsState.STARTING self._running = True @@ -189,12 +196,14 @@ def start(self): return connection_success def stop(self): - if self.state == WsState.STOPPED: + if self._state == WsState.STOPPED: return if threading.current_thread() == self._runtime_thread: raise RuntimeError(f'{self}: Stopping runtime called from within runtime thread. Ensure it is called from a separate thread') + _LOGGER.info(f'{self}: Stopping runtime') + # wait until one more pass of the runtime thread has occurred to allow unsubscriptions to complete wait_until(lambda: not self._wait_event.is_set(), timeout=self._connection_timeout) self._wait_event.set() @@ -204,7 +213,6 @@ def stop(self): self.state = WsState.STOPPING self._stop_transport_thread() - self._running = False if self._runtime_thread is not None: self._runtime_thread.join(self._connection_timeout) @@ -258,8 +266,8 @@ def restart_transport(self): self._new_transport() self._new_transport_thread() - def reset_transport_websocket(self): - self._transport.reset() + def reset_websocket_app(self): + self._transport.reset_websocket_app() def __str__(self): return f'{self.__class__.__qualname__}({self._state})' @@ -283,7 +291,6 @@ def _maintain_transport(self): return if self._transport_thread is None or not self._transport_thread.is_alive(): - _LOGGER.info(f'{self}: Starting new transport thread') self.state = WsState.CONNECTING self._new_transport_thread() @@ -334,13 +341,13 @@ def health_check(self) -> bool: if not self.check_should_restart(): return True - _LOGGER.warning(f'{self}: Health check failed, resetting transport websocket') - self.state_degraded() - if not self._running: # return early if runtime got stopped in the meantime return False - self.reset_transport_websocket() + self.state_degraded() + + _LOGGER.warning(f'{self}: Health check failed, resetting transport websocket') + self.reset_websocket_app() # if wait_until(lambda: self._state == self._ready_state, timeout=self._connection_timeout): # _LOGGER.info(f'Health recovered by resetting transport WebSocket') @@ -446,7 +453,7 @@ def _handle_on_open(self, wsa: WebSocketApp): self._emit(events.WsOpen()) def _handle_on_reconnect(self, wsa: WebSocketApp): - _LOGGER.info(f'{self}: on_reconnect') + _LOGGER.info(f'{self}: Connection reopened') # self._last_heartbeat = time.time() self._last_heartbeat = None self.state = WsState.OPEN @@ -455,7 +462,7 @@ def _handle_on_reconnect(self, wsa: WebSocketApp): self._emit(events.WsOpen()) # we emit Open since reconnect pretty much equivalent def _handle_on_error(self, wsa: WebSocketApp, exception: Exception): - _LOGGER.error(f'{self}: on_error: {exception}') + _LOGGER.error(f'{self}: Connection error: {exception}') if str(exception) in ['Connection to remote host was lost.', 'No connection could be made because the target machine actively refused it']: self.state_degraded() self.set_authenticated(False) @@ -463,7 +470,7 @@ def _handle_on_error(self, wsa: WebSocketApp, exception: Exception): def _handle_on_close(self, wsa: WebSocketApp, close_status_code, close_msg): - _LOGGER.info(f'{self}: on_close') + _LOGGER.info(f'{self}: Connection closed') self._last_heartbeat = None # if we're not connected we shouldn't need to do anything @@ -479,13 +486,13 @@ def _handle_on_close(self, wsa: WebSocketApp, close_status_code, close_msg): _LOGGER.error(f'{self}: on_close error: {close_status_code} | {msg}') - elif self.state == WsState.STOPPING: - _LOGGER.info(f'{self}: Gracefully closed') + + if self.state != WsState.STOPPING: + self.set_authenticated(False) + self.subscription_controller.invalidate_subscriptions() else: - _LOGGER.info(f'{self}: Connection closed') + _LOGGER.info(f'{self}: Gracefully closed') - self.set_authenticated(False) - self.subscription_controller.invalidate_subscriptions() self.state = WsState.CLOSED self._emit(events.WsClose(close_status_code=close_status_code, close_msg=close_msg)) diff --git a/ibind/ws_v2/ws_transport.py b/ibind/ws_v2/ws_transport.py index 5783dbba..e7f51248 100644 --- a/ibind/ws_v2/ws_transport.py +++ b/ibind/ws_v2/ws_transport.py @@ -84,9 +84,9 @@ def stop(self): self._running = False self.disconnect() - def reset(self) -> bool: + def reset_websocket_app(self) -> bool: if tname() == self._tname: - raise RuntimeError(f'{self}: Resetting transport thread called from within transport thread. Ensure it is called from a separate thread') + raise RuntimeError(f'{self}: Resetting websocket app called from within transport thread. Ensure it is called from a separate thread') if self._wsa is None: _LOGGER.info(f'{self}: WSA is None, skipping reset') From 93c93fa57c3d0328f6e6ccf9bf7429e1416ebe63 Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 3 May 2026 18:33:19 +0200 Subject: [PATCH 15/32] feat(ws_v2): added skip_utf8_validation, added wait_for(handles), renamed 'channel' to 'topic', implemented QueueSink as replacement of QueueController/Accessor chore(ws_v2): cleaned up ws_runtime and ws_transport --- examples/ws_04_ws_v2.py | 16 ++- ibind/ibkr_ws_v2/ibkr_events.py | 26 ++--- ibind/ibkr_ws_v2/ibkr_router.py | 29 ++--- ibind/ibkr_ws_v2/ibkr_subscriptions.py | 10 +- ibind/ibkr_ws_v2/ibkr_ws_client_v2.py | 31 +++-- ibind/var.py | 3 + ibind/ws_v2/events.py | 48 +++++--- ibind/ws_v2/subscriptions.py | 6 +- ibind/ws_v2/ws_runtime.py | 155 ++++++++++--------------- ibind/ws_v2/ws_transport.py | 40 ++++--- 10 files changed, 194 insertions(+), 170 deletions(-) diff --git a/examples/ws_04_ws_v2.py b/examples/ws_04_ws_v2.py index ead13f55..b5c997e4 100644 --- a/examples/ws_04_ws_v2.py +++ b/examples/ws_04_ws_v2.py @@ -15,8 +15,10 @@ from typing import List from ibind import ibind_logs_initialize -from ibkr_ws_v2.ibkr_subscriptions import MarketDataSubscription, OrdersSubscription, AccountLedgerSubscription, AccountSummarySubscription, PriceLadderSubscription, PnlSubscription, TradesSubscription, MarketHistorySubscription +from ibkr_ws_v2.ibkr_events import IbkrWsKey +from ibkr_ws_v2.ibkr_subscriptions import MarketDataSubscription, OrdersSubscription, AccountLedgerSubscription, AccountSummarySubscription, PnlSubscription, TradesSubscription, MarketHistorySubscription from ibkr_ws_v2.ibkr_ws_client_v2 import IbkrWsClientV2 +from ws_v2.events import LogSink, QueueSink from ws_v2.subscriptions import SubscriptionHandle ibind_logs_initialize(log_to_file=False, log_level='DEBUG') @@ -24,8 +26,11 @@ account_id = os.getenv('IBIND_ACCOUNT_ID', '[YOUR_ACCOUNT_ID]') cacert = os.getenv('IBIND_CACERT', False) # insert your cacert path here +queue_sink = QueueSink(list(IbkrWsKey)) + # ws_client = IbkrWsClient(cacert=cacert, account_id=account_id) -ws_client = IbkrWsClientV2(cacert=cacert, account_id=account_id) +# ws_client = IbkrWsClientV2(cacert=cacert, account_id=account_id, sink=LogSink()) +ws_client = IbkrWsClientV2(cacert=cacert, account_id=account_id, sink=queue_sink) # def stop(_, _1): # print('exit') @@ -56,6 +61,8 @@ # tr_sub ] + + sub_handles: List[SubscriptionHandle] = [] for sub in subs: handle = ws_client.subscribe(sub) @@ -69,6 +76,11 @@ try: while ws_client.is_running(): + for sub in subs: + while not queue_sink.empty(sub.key): + ev = queue_sink.get(sub.key) + print(ev) + time.sleep(1) except KeyboardInterrupt: print('Interrupt') diff --git a/ibind/ibkr_ws_v2/ibkr_events.py b/ibind/ibkr_ws_v2/ibkr_events.py index 5c486baa..9dc931c9 100644 --- a/ibind/ibkr_ws_v2/ibkr_events.py +++ b/ibind/ibkr_ws_v2/ibkr_events.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import Any from pydantic import Field @@ -8,7 +7,6 @@ class IbkrWsKey(Enum): # generic - UNCLASSIFIED = 'UNCLASSIFIED' GENERIC = 'GENERIC' UNSUBSCRIPTION = 'UNSUBSCRIPTION' SERVER_ID = 'SERVER_ID' @@ -35,8 +33,8 @@ class IbkrWsKey(Enum): CLIENT_INTERNAL = 'CLIENT_INTERNAL' @classmethod - def from_channel(cls, channel): - channel_to_key = { + def from_topic(cls, topic): + topic_to_key = { 'sd': IbkrWsKey.ACCOUNT_SUMMARY, 'ld': IbkrWsKey.ACCOUNT_LEDGER, 'md': IbkrWsKey.MARKET_DATA, @@ -46,17 +44,17 @@ def from_channel(cls, channel): 'pl': IbkrWsKey.PNL, 'tr': IbkrWsKey.TRADES, } - if channel in channel_to_key: - return channel_to_key[channel] - raise ValueError(f"No enum member associated with channel '{channel}'") + if topic in topic_to_key: + return topic_to_key[topic] + raise ValueError(f"No enum member associated with topic '{topic}'") @property - def channel(self): + def topic(self): """ - Gets the solicited channel string associated with the enum member. + Gets the solicited topic string associated with the enum member. Returns: - str: The channel string corresponding to the enum member. + str: The topic string corresponding to the enum member. """ return { IbkrWsKey.ACCOUNT_SUMMARY: 'sd', @@ -69,14 +67,15 @@ def channel(self): IbkrWsKey.TRADES: 'tr', }[self] + def __str__(self): + return self.value + class GenericIbkrEvent(WsEvent): - key: str = IbkrWsKey.UNCLASSIFIED + key: str = IbkrWsKey.GENERIC message: dict | None topic: str | None = None data: dict | None = None - subscribed: str | None = None - channel: str | None = None # =================== @@ -145,7 +144,6 @@ class MarketData(WsEvent): key: IbkrWsKey = IbkrWsKey.MARKET_DATA conid: str data: dict = Field(default_factory=dict) - fields: dict[str, Any] = Field(default_factory=dict) class MarketHistory(WsEvent): diff --git a/ibind/ibkr_ws_v2/ibkr_router.py b/ibind/ibkr_ws_v2/ibkr_router.py index d11f29b8..bd7d3f44 100644 --- a/ibind/ibkr_ws_v2/ibkr_router.py +++ b/ibind/ibkr_ws_v2/ibkr_router.py @@ -10,7 +10,7 @@ from support.py_utils import UNDEFINED, OneOrMany from ws_v2.events import WsEvent -_LOGGER = project_logger('websocket') +_LOGGER = project_logger('ibkr_ws_client') def parse_raw_message(raw_message: str): @@ -22,11 +22,7 @@ def parse_raw_message(raw_message: str): data = message.get('args', {}) - # subscribed is the indicator of whether it was a subscription or unsubscription, defined by the first letter - # channel is the actual channel we received the information about - subscribed, channel = topic[0], topic[1:] - - return message, topic, data, subscribed, channel + return message, topic, data class IbkrRouter(): @@ -50,12 +46,11 @@ def _preprocess_market_data_message(self, data: dict) -> OneOrMany[WsEvent]: if not self._unwrap_market_data: return ibkr_events.MarketData(conid=data['conid'], data=data) - fields = {} + unwrapped_data = {} for key, value in data.items(): if key in ibkr_definitions.snapshot_by_id: - # result[ibkr_definitions.snapshot_by_id[key]] = value - fields[ibkr_definitions.snapshot_by_id[key]] = value - return ibkr_events.MarketData(conid=str(data['conid']), fields=fields) + unwrapped_data[ibkr_definitions.snapshot_by_id[key]] = value + return ibkr_events.MarketData(conid=str(data['conid']), data=unwrapped_data) def _preprocess_market_history_message(self, data: dict) -> OneOrMany[WsEvent]: mh_server_id_conid_pairs = self._server_id_conid_pairs[IbkrWsKey.MARKET_HISTORY] @@ -102,11 +97,11 @@ def _preprocess_account_summary(self, data): event = ibkr_events.AccountSummary(data=summary, account_id=account_id) return event - def _handle_subscribed_message(self, channel: str, data: dict) -> OneOrMany[WsEvent] | None: + def _handle_subscribed_message(self, topic: str, data: dict) -> OneOrMany[WsEvent] | None: try: - ibkr_ws_key = IbkrWsKey.from_channel(channel[:2]) + ibkr_ws_key = IbkrWsKey.from_topic(topic[1:3]) except ValueError: - # ValueError means we don't support this channel + # ValueError means we don't support this topic return None if ibkr_ws_key == IbkrWsKey.ACCOUNT_SUMMARY: @@ -199,7 +194,7 @@ def _handle_message_without_topic(self, message: dict) -> OneOrMany[WsEvent]: def route(self, raw_message: str) -> OneOrMany[WsEvent]: if self._log_raw_messages: _LOGGER.debug(f'{self}: Raw message: {raw_message}') - message, topic, arguments, subscribed, channel = parse_raw_message(raw_message) + message, topic, arguments = parse_raw_message(raw_message) if 'error' in message: return self._handle_error(message) @@ -231,10 +226,10 @@ def route(self, raw_message: str) -> OneOrMany[WsEvent]: return self._handle_error(message) else: - events = self._handle_subscribed_message(channel, message) + events = self._handle_subscribed_message(topic, message) if events is None: - _LOGGER.error(f'{self}: Channel "{channel}" subscribed but lacking a handler. Message: {message}') - events = GenericIbkrEvent(message=message, topic=topic, data=arguments, subscribed=subscribed, channel=channel) + _LOGGER.error(f'{self}: topic "{topic}" subscribed but lacking a handler. Message: {message}') + events = GenericIbkrEvent(message=message, topic=topic, data=arguments) return events diff --git a/ibind/ibkr_ws_v2/ibkr_subscriptions.py b/ibind/ibkr_ws_v2/ibkr_subscriptions.py index 14a4246d..a6322b9a 100644 --- a/ibind/ibkr_ws_v2/ibkr_subscriptions.py +++ b/ibind/ibkr_ws_v2/ibkr_subscriptions.py @@ -15,13 +15,13 @@ def make_binding_key( exchange=None ): if key in [IbkrWsKey.MARKET_DATA, IbkrWsKey.MARKET_HISTORY]: - return f"{key.channel}+{conid}" + return f"{key.topic}+{conid}" elif key in [IbkrWsKey.ACCOUNT_LEDGER, IbkrWsKey.ACCOUNT_SUMMARY]: - return f"{key.channel}+{account_id}" + return f"{key.topic}+{account_id}" elif key in [IbkrWsKey.PRICE_LADDER]: - return f"{key.channel}+{account_id}+{conid}" + (f"+{exchange}" if exchange is not None else '') + return f"{key.topic}+{account_id}+{conid}" + (f"+{exchange}" if exchange is not None else '') elif key in [IbkrWsKey.ORDERS, IbkrWsKey.PNL, IbkrWsKey.TRADES]: - return key.channel + return key.topic else: raise ValueError(f'Unsupported key: {key}') @@ -72,7 +72,7 @@ class IbkrSubscription(Subscription): @property def topic(self) -> str: - return self.key.channel + return self.key.topic class AccountSummarySubscription(IbkrSubscription): diff --git a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py index 125baed0..1db5f031 100644 --- a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py +++ b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py @@ -9,11 +9,12 @@ from ibkr_ws_v2.ibkr_router import IbkrRouter from ibkr_ws_v2.ibkr_subscriptions import IbkrSubscriptionResolver, MarketHistorySubscription from support.logs import project_logger -from ws_v2.events import EventSink, LogSink, CallbackSink, CompositeSink, Router, AsyncSink -from ws_v2.subscriptions import Subscription, SubscriptionResolver, SubscriptionHandle +from support.py_utils import OneOrMany, ensure_list_arg +from ws_v2.events import EventSink, LogSink, CallbackSink, Router, AsyncSink, NoopSink +from ws_v2.subscriptions import Subscription, SubscriptionResolver, SubscriptionHandle, BindingStatus from ws_v2.ws_runtime import WsRuntime, WsState -_LOGGER = project_logger('websocket') +_LOGGER = project_logger('ibkr_ws_client') _DEFAULT_CYCLE_INTERVAL = 0.25 @@ -58,15 +59,15 @@ def __init__( self._use_oauth = use_oauth self._recreate_subscriptions_on_reconnect = recreate_subscriptions_on_reconnect - self._queue_controller = QueueController[IbkrWsKey]() - self._queue_controller.register_queues(list(IbkrWsKey)) + # self._queue_controller = QueueController[IbkrWsKey]() + # self._queue_controller.register_queues(list(IbkrWsKey)) if sink is None: # self._queue_controller.register_queues(['CLIENT_INTERNAL', 'IBKR']) # sink = QueueSink(queue_controller=self._queue_controller) - sink = LogSink() - # sink = NoopSink() + # sink = LogSink() + sink = NoopSink() self._internal_sink = CallbackSink() self._register_internal_callbacks() @@ -91,6 +92,7 @@ def __init__( internal_sink=self._internal_sink, router=router, subscription_resolver=subscription_resolver, + connection_timeout=5, get_cookie=self._get_cookie, get_header=self._get_header, ) @@ -167,6 +169,9 @@ def unsubscribe(self, subscription: Subscription) -> SubscriptionHandle: self._handle_mh_unsubscription(subscription) return self._runtime.subscription_controller.unsubscribe(subscription) + def get_status(self, binding_key: str) -> BindingStatus: + return self._runtime.subscription_controller.get_status(binding_key) + def get_server_id(self, key: IbkrWsKey, conid: str) -> str: return self._conid_server_id_pairs[key][conid] @@ -185,8 +190,20 @@ def _handle_mh_unsubscription(self, subscription: MarketHistorySubscription): ) subscription.set_server_id(server_id) + @ensure_list_arg('subscription_handles') + def wait_all(self, subscription_handles: OneOrMany[SubscriptionHandle], timeout: float | None = None) -> List[SubscriptionHandle]: + failed = [] + for subscription_handle in subscription_handles: + if not (subscription_handle.wait(timeout)): + failed.append(subscription_handle) + return failed + def is_running(self) -> bool: return self._runtime.is_running() + + def get_state(self) -> WsState: + return self._runtime.get_state() + def __str__(self): return f'{self.__class__.__qualname__}()' diff --git a/ibind/var.py b/ibind/var.py index 246710d8..8d062744 100644 --- a/ibind/var.py +++ b/ibind/var.py @@ -102,6 +102,9 @@ def to_bool(value): IBIND_WS_LOG_RAW_MESSAGES = to_bool(os.environ.get('IBIND_WS_LOG_RAW_MESSAGES', False)) """ Whether raw WebSocket messages should be logged. """ +IBIND_WS_SKIP_UTF8_VALIDATION = to_bool(os.environ.get('IBIND_WS_SKIP_UTF8_VALIDATION', True)) +""" Whether to skip UTF-8 validation for WebSocket messages. """ + ##### OAuth common ##### IBIND_USE_OAUTH = to_bool(os.environ.get('IBIND_USE_OAUTH', False)) diff --git a/ibind/ws_v2/events.py b/ibind/ws_v2/events.py index 4886c6ea..3823aa51 100644 --- a/ibind/ws_v2/events.py +++ b/ibind/ws_v2/events.py @@ -3,15 +3,15 @@ from datetime import datetime from queue import Queue, Full, Empty from threading import Thread, Event -from typing import Hashable, Protocol, Callable, TypeVar, List, Dict +from typing import Hashable, Protocol, Callable, TypeVar, List, Dict, Any from pydantic import BaseModel, ConfigDict, Field -from base.queue_controller import QueueController +from base.queue_controller import QueueAccessor from support.logs import project_logger -from support.py_utils import OneOrMany, exception_to_string, tname +from support.py_utils import OneOrMany, exception_to_string, tname, ensure_list_arg -_LOGGER = project_logger('websocket') +_LOGGER = project_logger('ibkr_ws_client') # ====================== @@ -94,12 +94,6 @@ class EventSink(Protocol): def emit(self, event: "WsEvent") -> None: pass - def start(self): - pass - - def stop(self): - pass - class LogSink: def emit(self, event: WsEvent) -> None: @@ -133,11 +127,39 @@ def __str__(self): class QueueSink: - def __init__(self, queue_controller: QueueController): - self._queue_controller = queue_controller + def __init__(self, event_types: List[Hashable]): + self._queues = { + 'CLIENT_INTERNAL': Queue() + } + self.register_queues(event_types) + + @ensure_list_arg('keys') + def register_queues(self, keys: OneOrMany[Hashable]): + for key in keys: + if key not in self._queues: + self._queues[str(key)] = Queue() + + def new_queue_accessor(self, key: Hashable) -> QueueAccessor: + return QueueAccessor(self._get_queue(key), key) + + def _get_queue(self, key: Hashable) -> Queue: # pragma: no cover + try: + return self._queues[str(key)] + except KeyError: + raise AttributeError(f'Invalid queue key: "{key}", expected: {list(self._queues.keys())}') + + def get(self, key: Hashable, block: bool = False, timeout=None) -> Any: + try: + return self._get_queue(key).get(block=block, timeout=timeout) + except Empty: + return None + + def empty(self, key: Hashable) -> bool: + return self._get_queue(key).empty() def emit(self, event: WsEvent) -> None: - self._queue_controller.put_to_queue(event.key, event) + queue = self._get_queue(event.key) + queue.put(event) class CompositeSink: diff --git a/ibind/ws_v2/subscriptions.py b/ibind/ws_v2/subscriptions.py index f4d48d55..593bca72 100644 --- a/ibind/ws_v2/subscriptions.py +++ b/ibind/ws_v2/subscriptions.py @@ -10,7 +10,7 @@ from ibind.support.py_utils import exception_to_string from ws_v2.events import WsEvent -_LOGGER = project_logger('websocket') +_LOGGER = project_logger('ibkr_ws_client') class Subscription(BaseModel): @@ -116,9 +116,9 @@ def unsubscribe(self) -> "SubscriptionHandle": class SubscriptionController: """ - Mixin which manages subscriptions to different channels using the WsClient. + Mixin which manages subscriptions to different topics using the WsClient. - This class handles the logic for subscribing and unsubscribing to various channels. It maintains a + This class handles the logic for subscribing and unsubscribing to various topics. It maintains a record of active subscriptions and provides methods to modify them. The class relies on a SubscriptionProcessor to create subscription and unsubscription payloads. diff --git a/ibind/ws_v2/ws_runtime.py b/ibind/ws_v2/ws_runtime.py index b672a394..4efcec2a 100644 --- a/ibind/ws_v2/ws_runtime.py +++ b/ibind/ws_v2/ws_runtime.py @@ -7,18 +7,18 @@ from threading import Thread, Event from typing import Union, List, Dict, Callable, Literal -from websocket import WebSocketApp - from support.logs import project_logger from support.py_utils import wait_until, tname, VerboseEnum, exception_to_string, TimeoutLock, OneOrMany, NOOP from ws_v2 import events -from ws_v2.events import WsEvent, EventSink, Router, CallbackSink +from ws_v2.events import WsEvent, EventSink, Router, CallbackSink, AsyncSink from ws_v2.subscriptions import SubscriptionController, SubscriptionResolver from ws_v2.ws_transport import WsTransport, TransportEvent, TransportOpened, TransportClosed, TransportError, TransportMessage, TransportReconnect -_LOGGER = project_logger('websocket') +_LOGGER = project_logger('ibkr_ws_client') _DEFAULT_TIMEOUT = 5 +_MAX_TRANSPORT_EVENT_RETRIES = 5 +_HEALTH_CHECK_INTERVAL = 10 class WsState(VerboseEnum): @@ -55,6 +55,7 @@ def __init__( ready_state: Literal[WsState.OPEN, WsState.AUTHENTICATED] = WsState.OPEN, cacert: Union[str, bool] = False, connection_timeout: float = _DEFAULT_TIMEOUT, + reconnect_timeout: float | None = _DEFAULT_TIMEOUT, max_ping_interval: float = 20, get_cookie: Callable = NOOP, get_header: Callable = NOOP, @@ -68,6 +69,7 @@ def __init__( self._router = router self._ready_state = ready_state self._connection_timeout = connection_timeout + self._reconnect_timeout = reconnect_timeout self._max_ping_interval = max_ping_interval self._state = WsState.STOPPED @@ -89,30 +91,23 @@ def __init__( self._get_cookie = get_cookie self._get_header = get_header - self._transport: WsTransport | None = None - - self._new_transport() + self._transport: WsTransport = self._new_transport() self.subscription_controller = SubscriptionController(send_payload=self.send, subscription_resolver=subscription_resolver) def _new_transport(self): - self._transport = WsTransport( + return WsTransport( url=self._url, event_callback=self._transport_callback, sslopt=self._sslopt, get_cookie=self._get_cookie, get_header=self._get_header, - max_ping_interval=self._max_ping_interval + max_ping_interval=self._max_ping_interval, + connection_timeout=self._connection_timeout, + reconnect_timeout=self._reconnect_timeout, ) - @property - def state(self): - _LOGGER.info(f'{self}: State: {self._state.value}') - with self._state_lock: - return self._state - - @state.setter - def state(self, value): + def _set_state(self, value): _LOGGER.info(f'{self}: {self._state.value} -> {value.value}') with self._state_lock: self._state = value @@ -120,6 +115,9 @@ def state(self, value): if self._state == self._ready_state: self._websocket_ready() + def get_state(self) -> WsState: + return self._state + def _websocket_ready(self): self._emit(events.WsReady()) self._last_heartbeat = time.time() @@ -132,14 +130,14 @@ def set_authenticated(self, value: bool): if value and self._state == WsState.OPEN: self._emit(events.WsAuthenticated()) - self.state = WsState.AUTHENTICATED + self._set_state(WsState.AUTHENTICATED) if value == False and self._state == self._ready_state: self.state_degraded() def state_degraded(self): was_already_degraded = self._state == WsState.DEGRADED - self.state = WsState.DEGRADED + self._set_state(WsState.DEGRADED) self.subscription_controller.invalidate_subscriptions() if not was_already_degraded: @@ -186,11 +184,13 @@ def start(self): _LOGGER.info(f'{self}: Starting runtime') - self.state = WsState.STARTING + self._set_state(WsState.STARTING) self._running = True self._new_runtime_thread() - self._sink.start() + + if isinstance(self._sink, AsyncSink): + self._sink.start() connection_success = wait_until(lambda: self._state == self._ready_state, f'{self}: Starting timeout', timeout=self._connection_timeout) return connection_success @@ -210,7 +210,7 @@ def stop(self): wait_until(lambda: not self._wait_event.is_set(), timeout=self._connection_timeout) # TODO: decide which thread should stop first - transport or runtime - self.state = WsState.STOPPING + self._set_state(WsState.STOPPING) self._stop_transport_thread() self._running = False @@ -222,9 +222,10 @@ def stop(self): self._runtime_thread = None - self._sink.stop() + if isinstance(self._sink, AsyncSink): + self._sink.stop() - self.state = WsState.STOPPED + self._set_state(WsState.STOPPED) def send(self, payload: str) -> bool: if self._state != self._ready_state: @@ -263,7 +264,7 @@ def restart_transport(self): self._transport_thread = None self._transport.set_degraded(True) - self._new_transport() + self._transport = self._new_transport() self._new_transport_thread() def reset_websocket_app(self): @@ -291,7 +292,7 @@ def _maintain_transport(self): return if self._transport_thread is None or not self._transport_thread.is_alive(): - self.state = WsState.CONNECTING + self._set_state(WsState.CONNECTING) self._new_transport_thread() def _maintain_subscriptions(self): @@ -300,7 +301,7 @@ def _maintain_subscriptions(self): self.subscription_controller.reconcile_bindings() - def check_should_restart(self): + def check_should_reset(self): # If WSA is not ready, we don't try to fix health if not self._transport.is_ready(): return False @@ -315,12 +316,8 @@ def check_should_restart(self): f'{self}: Last WebSocket ping happened {self._transport.get_time_since_last_ping():.2f} seconds ago, ' f'exceeding the max ping interval of {self._max_ping_interval}.' ) - return False - - # cookie_ok = self._transport.check_cookie() - # if cookie_ok is not None and not cookie_ok: - # _LOGGER.warning(f'{self}: Cookie check failed') - # return True + # If we have a reconnect timeout, we let WSA handle the reconnect, otherwise let's reset the WSA + return self._reconnect_timeout is None heartbeat_ok = True if self._last_heartbeat is not None: @@ -338,7 +335,7 @@ def check_should_restart(self): return True def health_check(self) -> bool: - if not self.check_should_restart(): + if not self.check_should_reset(): return True if not self._running: # return early if runtime got stopped in the meantime @@ -348,22 +345,6 @@ def health_check(self) -> bool: _LOGGER.warning(f'{self}: Health check failed, resetting transport websocket') self.reset_websocket_app() - - # if wait_until(lambda: self._state == self._ready_state, timeout=self._connection_timeout): - # _LOGGER.info(f'Health recovered by resetting transport WebSocket') - # return True - - # if not self._running: # return early if runtime got stopped in the meantime - # return False - # - # _LOGGER.warning(f'{self}: Resetting transport websocket failed, restarting transport') - # self.restart_transport() - # - # if wait_until(lambda: self._state == self._ready_state, timeout=self._connection_timeout): - # _LOGGER.info(f'Health recovered by resetting transport thread') - # return True - # - # _LOGGER.error(f'{self}: Resetting transport websocket failed') return False def _cycle(self): @@ -374,21 +355,15 @@ def _cycle(self): self._process_transport_queue() - if time.time() - self._last_health_check > 10: + if time.time() - self._last_health_check > _HEALTH_CHECK_INTERVAL: self._last_health_check = time.time() self.health_check() - # if time.time() - self._last_tic > 5: - # if self._transport.is_ready(): - # _LOGGER.debug(f'{self}: Sending tic') - # self._transport.send('tic') - # self._last_tic = time.time() - self._wait_event.clear() self._wait_event.wait(self._cycle_interval) # if not stopped or closed yet, attempt to do one last pass before the thread dies - if self.state not in [WsState.STOPPED, WsState.CLOSED]: + if self._state not in [WsState.STOPPED, WsState.CLOSED]: # final pass through the transport queue to flush any remaining events self._process_transport_queue() @@ -404,33 +379,38 @@ def _process_transport_queue(self): try: self._handle_transport_event(te) except Exception as e: - _LOGGER.error(f'{self}: Exception processing transport event: {exception_to_string(e)} for {te}') + _LOGGER.error(f'{self}: Exception processing transport event {te}: {exception_to_string(e)}') + te.attempt += 1 + if te.attempt > _MAX_TRANSPORT_EVENT_RETRIES: + _LOGGER.error(f'{self}: Max retries ({_MAX_TRANSPORT_EVENT_RETRIES}) reached for transport event {te}, dropping event.') + continue retry_events.append(te) + for event in retry_events: self._transport_queue.put(event) - def _handle_transport_event(self, te: TransportEvent): - if isinstance(te, TransportOpened): - self._handle_on_open(te.wsa) - elif isinstance(te, TransportClosed): - self._handle_on_close(te.wsa, te.close_status_code, te.close_msg) - elif isinstance(te, TransportError): - self._handle_on_error(te.wsa, te.exception) - elif isinstance(te, TransportMessage): - self._handle_on_message(te.wsa, te.message) - elif isinstance(te, TransportReconnect): - self._handle_on_reconnect(te.wsa) + def _handle_transport_event(self, transport_event: TransportEvent): + if isinstance(transport_event, TransportOpened): + self._handle_on_open() + elif isinstance(transport_event, TransportReconnect): + self._handle_on_reconnect() + elif isinstance(transport_event, TransportClosed): + self._handle_on_close(transport_event.close_status_code, transport_event.close_msg) + elif isinstance(transport_event, TransportError): + self._handle_on_error(transport_event.exception) + elif isinstance(transport_event, TransportMessage): + self._handle_on_message(transport_event.message) else: - _LOGGER.error(f'{self}: Unknown event type: {type(te)}: {te}') + _LOGGER.error(f'{self}: Unknown event type: {type(transport_event)}: {transport_event}') - def _handle_on_message(self, wsa: WebSocketApp, message): # pragma: no cover + def _handle_on_message(self, message): # pragma: no cover events: OneOrMany[WsEvent] = self._router.route(message) # Router decided to skip this message if events is None: return - # Handle lists and individual events + # Handle both lists and individual events if not isinstance(events, list) and isinstance(events, WsEvent): events = [events] @@ -443,41 +423,33 @@ def _handle_on_message(self, wsa: WebSocketApp, message): # pragma: no cover self._emit(event) - def _handle_on_open(self, wsa: WebSocketApp): + def _handle_on_open(self): _LOGGER.info(f'{self}: Connection open') - # self._last_heartbeat = time.time() self._last_heartbeat = None - self.state = WsState.OPEN + self._set_state(WsState.OPEN) if self._state != self._ready_state: self.set_authenticated(False) self._emit(events.WsOpen()) - def _handle_on_reconnect(self, wsa: WebSocketApp): + def _handle_on_reconnect(self): _LOGGER.info(f'{self}: Connection reopened') - # self._last_heartbeat = time.time() self._last_heartbeat = None - self.state = WsState.OPEN + self._set_state(WsState.OPEN) if self._state != self._ready_state: self.set_authenticated(False) - self._emit(events.WsOpen()) # we emit Open since reconnect pretty much equivalent + self._emit(events.WsOpen()) # we emit Open since reconnect pretty much equivalent - def _handle_on_error(self, wsa: WebSocketApp, exception: Exception): + def _handle_on_error(self, exception: Exception): _LOGGER.error(f'{self}: Connection error: {exception}') if str(exception) in ['Connection to remote host was lost.', 'No connection could be made because the target machine actively refused it']: self.state_degraded() self.set_authenticated(False) self._emit(events.WsError(error=exception)) - - def _handle_on_close(self, wsa: WebSocketApp, close_status_code, close_msg): + def _handle_on_close(self, close_status_code, close_msg): _LOGGER.info(f'{self}: Connection closed') self._last_heartbeat = None - # if we're not connected we shouldn't need to do anything - # if self.state not in [self._ready_state, WsState.OPEN, WsState.STOPPING]: ## not self._connected: - # _LOGGER.info(f'{self}: Unexpected on_close event while not open') - # return - if close_status_code is not None or close_msg is not None: # this means an error try: msg = close_msg.decode('utf-8') @@ -486,17 +458,16 @@ def _handle_on_close(self, wsa: WebSocketApp, close_status_code, close_msg): _LOGGER.error(f'{self}: on_close error: {close_status_code} | {msg}') - - if self.state != WsState.STOPPING: + if self._state != WsState.STOPPING: self.set_authenticated(False) self.subscription_controller.invalidate_subscriptions() else: _LOGGER.info(f'{self}: Gracefully closed') - self.state = WsState.CLOSED + self._set_state(WsState.CLOSED) self._emit(events.WsClose(close_status_code=close_status_code, close_msg=close_msg)) - def _emit(self, event:WsEvent): + def _emit(self, event: WsEvent): try: self._internal_sink.emit(event) except Exception as e: diff --git a/ibind/ws_v2/ws_transport.py b/ibind/ws_v2/ws_transport.py index e7f51248..86dc699a 100644 --- a/ibind/ws_v2/ws_transport.py +++ b/ibind/ws_v2/ws_transport.py @@ -5,18 +5,18 @@ from pydantic import BaseModel, ConfigDict, Field from websocket import WebSocketApp, STATUS_UNEXPECTED_CONDITION, STATUS_NORMAL +import var from ibind import ExternalBrokerError from support.logs import project_logger from support.py_utils import exception_to_string, tname, wait_until, UNDEFINED, NOOP -_LOGGER = project_logger('websocket') +_LOGGER = project_logger('ibkr_ws_client') class TransportEvent(BaseModel): model_config = ConfigDict(frozen=True, extra="forbid", arbitrary_types_allowed=True) - received_at: datetime = Field(default_factory=datetime.now) - wsa: WebSocketApp + attempt: int = 0 def __str__(self): return f'{self.__class__.__qualname__}()' @@ -32,7 +32,6 @@ class TransportClosed(TransportEvent): class TransportError(TransportEvent): - model_config = ConfigDict(frozen=True, extra="forbid", arbitrary_types_allowed=True) exception: Exception @@ -57,16 +56,20 @@ def __init__( ping_timeout: float = 10, max_ping_interval: float = 20, connection_timeout: float = 5, + reconnect_timeout: float = 5, + skip_utf8_validation: bool = var.IBIND_WS_SKIP_UTF8_VALIDATION, ): self._url = url self._event_callback = event_callback + self._sslopt = sslopt self._get_cookie = get_cookie self._get_header = get_header self._ping_interval = ping_interval self._ping_timeout = ping_timeout self._max_ping_interval = max_ping_interval self._connection_timeout = connection_timeout - self._sslopt = sslopt + self._reconnect_timeout = reconnect_timeout + self._skip_utf8_validation = skip_utf8_validation self._running = False self._wsa: WebSocketApp | None = None @@ -199,38 +202,46 @@ def wrapped_f(ws, *args, **kwargs): def _on_open(self, wsa: WebSocketApp): if self._degraded: return + if not self.check_cookie(): self._wsa.close(status=STATUS_UNEXPECTED_CONDITION, timeout=self._connection_timeout) return - self._event_callback(TransportOpened(wsa=wsa)) + + self._event_callback(TransportOpened()) def _on_message(self, wsa: WebSocketApp, message): if self._degraded: return - self._event_callback(TransportMessage(wsa=wsa, message=message)) + + self._event_callback(TransportMessage(message=message)) def _on_close(self, wsa: WebSocketApp, close_status_code, close_msg): if self._degraded: return - self._event_callback(TransportClosed(wsa=wsa, close_status_code=close_status_code, close_msg=close_msg)) + + self._event_callback(TransportClosed(close_status_code=close_status_code, close_msg=close_msg)) def _on_error(self, wsa: WebSocketApp, error): if self._degraded: return - self._event_callback(TransportError(wsa=wsa, exception=error)) + + self._event_callback(TransportError(exception=error)) def _on_reconnect(self, wsa: WebSocketApp): if self._degraded: return + if not self.check_cookie(): self._wsa.close(status=STATUS_UNEXPECTED_CONDITION, timeout=self._connection_timeout) return - self._event_callback(TransportReconnect(wsa=wsa)) + + self._event_callback(TransportReconnect()) def new_wsa(self): cookie = self.fetch_cookie() if cookie is UNDEFINED: return None + self._cookie = cookie if cookie is not None: _LOGGER.info(f'{self}: Current cookie: {cookie}') @@ -266,12 +277,6 @@ def connect(self): self._running = True while self._running: - # status, reason = probe_ws_reachability(self._url, sslopt=self._sslopt, timeout=3) - # _LOGGER.debug(f'{self}: Probe result: {status}, {reason}') - # if status != ReachabilityStatus.OK: - # time.sleep(5) - # continue - if self._wsa is None: wsa = self.new_wsa() if wsa is None: @@ -284,7 +289,8 @@ def connect(self): ping_interval=self._ping_interval, ping_timeout=self._ping_interval * 0.95, # the timeout is set to a little sooner than the interval sslopt=self._sslopt, - reconnect=cast(int, self._connection_timeout) # floats are de facto valid, casting only for the linter + reconnect=cast(int, self._reconnect_timeout), # floats are de facto valid, casting only for the linter + skip_utf8_validation=self._skip_utf8_validation ) _LOGGER.info(f'{self}: WSA run_forever stopped gracefully') except Exception as e: From da690bf1ea845bf6f3a330276b1d61d7003b4590 Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 3 May 2026 20:17:25 +0200 Subject: [PATCH 16/32] feat(ws_v2): deprecated IbkrWsKey - WsEvent types are used instead fix(ws_v2): fixed TransportEvent attempts refactor(ws_v2): renamed ClientInternalEvents to LifecycleEvents --- ibind/ibkr_ws_v2/ibkr_events.py | 221 +++++++++++++------------ ibind/ibkr_ws_v2/ibkr_router.py | 39 +++-- ibind/ibkr_ws_v2/ibkr_subscriptions.py | 94 +++++------ ibind/ibkr_ws_v2/ibkr_ws_client_v2.py | 21 ++- ibind/ws_v2/events.py | 61 +++---- ibind/ws_v2/subscriptions.py | 15 +- ibind/ws_v2/ws_runtime.py | 4 +- ibind/ws_v2/ws_transport.py | 12 +- 8 files changed, 229 insertions(+), 238 deletions(-) diff --git a/ibind/ibkr_ws_v2/ibkr_events.py b/ibind/ibkr_ws_v2/ibkr_events.py index 9dc931c9..138a5b43 100644 --- a/ibind/ibkr_ws_v2/ibkr_events.py +++ b/ibind/ibkr_ws_v2/ibkr_events.py @@ -1,78 +1,78 @@ -from enum import Enum +from typing import ClassVar from pydantic import Field from ws_v2.events import WsEvent -class IbkrWsKey(Enum): - # generic - GENERIC = 'GENERIC' - UNSUBSCRIPTION = 'UNSUBSCRIPTION' - SERVER_ID = 'SERVER_ID' - - # unsolicited - ACCOUNT_UPDATE = 'ACCOUNT_UPDATE' - AUTHENTICATION_STATUS = 'AUTHENTICATION_STATUS' - BULLETIN = 'BULLETIN' - ERROR = 'ERROR' - SYSTEM = 'SYSTEM' - NOTIFICATION = 'NOTIFICATION' - - # subscription-based - ACCOUNT_SUMMARY = 'ACCOUNT_SUMMARY' - ACCOUNT_LEDGER = 'ACCOUNT_LEDGER' - MARKET_DATA = 'MARKET_DATA' - MARKET_HISTORY = 'MARKET_HISTORY' - PRICE_LADDER = 'PRICE_LADDER' - ORDERS = 'ORDERS' - PNL = 'PNL' - TRADES = 'TRADES' - - # internal - CLIENT_INTERNAL = 'CLIENT_INTERNAL' - - @classmethod - def from_topic(cls, topic): - topic_to_key = { - 'sd': IbkrWsKey.ACCOUNT_SUMMARY, - 'ld': IbkrWsKey.ACCOUNT_LEDGER, - 'md': IbkrWsKey.MARKET_DATA, - 'mh': IbkrWsKey.MARKET_HISTORY, - 'bd': IbkrWsKey.PRICE_LADDER, - 'or': IbkrWsKey.ORDERS, - 'pl': IbkrWsKey.PNL, - 'tr': IbkrWsKey.TRADES, - } - if topic in topic_to_key: - return topic_to_key[topic] - raise ValueError(f"No enum member associated with topic '{topic}'") - - @property - def topic(self): - """ - Gets the solicited topic string associated with the enum member. - - Returns: - str: The topic string corresponding to the enum member. - """ - return { - IbkrWsKey.ACCOUNT_SUMMARY: 'sd', - IbkrWsKey.ACCOUNT_LEDGER: 'ld', - IbkrWsKey.MARKET_DATA: 'md', - IbkrWsKey.MARKET_HISTORY: 'mh', - IbkrWsKey.PRICE_LADDER: 'bd', - IbkrWsKey.ORDERS: 'or', - IbkrWsKey.PNL: 'pl', - IbkrWsKey.TRADES: 'tr', - }[self] - - def __str__(self): - return self.value +# class IbkrWsKey(Enum): +# # generic +# GENERIC = 'GENERIC' +# UNSUBSCRIPTION = 'UNSUBSCRIPTION' +# SERVER_ID = 'SERVER_ID' +# WAITING_FOR_SESSION = 'WAITING_FOR_SESSION' +# +# # unsolicited +# ACCOUNT_UPDATE = 'ACCOUNT_UPDATE' +# AUTHENTICATION_STATUS = 'AUTHENTICATION_STATUS' +# BULLETIN = 'BULLETIN' +# ERROR = 'ERROR' +# SYSTEM = 'SYSTEM' +# NOTIFICATION = 'NOTIFICATION' +# +# # subscription-based +# ACCOUNT_SUMMARY = 'ACCOUNT_SUMMARY' +# ACCOUNT_LEDGER = 'ACCOUNT_LEDGER' +# MARKET_DATA = 'MARKET_DATA' +# MARKET_HISTORY = 'MARKET_HISTORY' +# PRICE_LADDER = 'PRICE_LADDER' +# ORDERS = 'ORDERS' +# PNL = 'PNL' +# TRADES = 'TRADES' +# +# # internal +# LIFECYCLE = 'LIFECYCLE' +# +# @classmethod +# def from_topic(cls, topic): +# topic_to_key = { +# 'sd': IbkrWsKey.ACCOUNT_SUMMARY, +# 'ld': IbkrWsKey.ACCOUNT_LEDGER, +# 'md': IbkrWsKey.MARKET_DATA, +# 'mh': IbkrWsKey.MARKET_HISTORY, +# 'bd': IbkrWsKey.PRICE_LADDER, +# 'or': IbkrWsKey.ORDERS, +# 'pl': IbkrWsKey.PNL, +# 'tr': IbkrWsKey.TRADES, +# } +# if topic in topic_to_key: +# return topic_to_key[topic] +# raise ValueError(f"No enum member associated with topic '{topic}'") +# +# @property +# def topic(self): +# """ +# Gets the solicited topic string associated with the enum member. +# +# Returns: +# str: The topic string corresponding to the enum member. +# """ +# return { +# IbkrWsKey.ACCOUNT_SUMMARY: 'sd', +# IbkrWsKey.ACCOUNT_LEDGER: 'ld', +# IbkrWsKey.MARKET_DATA: 'md', +# IbkrWsKey.MARKET_HISTORY: 'mh', +# IbkrWsKey.PRICE_LADDER: 'bd', +# IbkrWsKey.ORDERS: 'or', +# IbkrWsKey.PNL: 'pl', +# IbkrWsKey.TRADES: 'tr', +# }[self] +# +# def __str__(self): +# return self.value class GenericIbkrEvent(WsEvent): - key: str = IbkrWsKey.GENERIC message: dict | None topic: str | None = None data: dict | None = None @@ -83,100 +83,115 @@ class GenericIbkrEvent(WsEvent): # =================== class IbkrError(WsEvent): - key: IbkrWsKey = IbkrWsKey.ERROR message: str class WaitingForSession(WsEvent): - key: IbkrWsKey = IbkrWsKey.GENERIC + ... class Notification(WsEvent): - key: IbkrWsKey = IbkrWsKey.NOTIFICATION message: str class Bulletin(WsEvent): - key: IbkrWsKey = IbkrWsKey.BULLETIN message: str class AccountUpdate(WsEvent): - key: IbkrWsKey = IbkrWsKey.ACCOUNT_UPDATE data: dict class System(WsEvent): - key: IbkrWsKey = IbkrWsKey.SYSTEM data: dict class AuthenticationStatus(WsEvent): - key: IbkrWsKey = IbkrWsKey.AUTHENTICATION_STATUS data: dict authenticated: bool | None competing: bool | None -# ========================== -# == Subscription-based == -# ========================== +# =================== +# == Topic-based == +# =================== -class Unsubscription(WsEvent): - key: IbkrWsKey = IbkrWsKey.UNSUBSCRIPTION - target_key: IbkrWsKey - conid: str | None = None +class IbkrTopicEvent(WsEvent): + topic: ClassVar[str] -class AccountSummary(WsEvent): - key: IbkrWsKey = IbkrWsKey.ACCOUNT_SUMMARY +class AccountSummary(IbkrTopicEvent): + topic: ClassVar[str] = 'sd' account_id: str data: dict -class AccountLedger(WsEvent): - key: IbkrWsKey = IbkrWsKey.ACCOUNT_LEDGER +class AccountLedger(IbkrTopicEvent): + topic: ClassVar[str] = 'ld' account_id: str data: dict -class MarketData(WsEvent): - key: IbkrWsKey = IbkrWsKey.MARKET_DATA +class MarketData(IbkrTopicEvent): + topic: ClassVar[str] = 'md' conid: str data: dict = Field(default_factory=dict) -class MarketHistory(WsEvent): - key: IbkrWsKey = IbkrWsKey.MARKET_HISTORY +class MarketHistory(IbkrTopicEvent): + topic: ClassVar[str] = 'mh' conid: str data: dict -class ServerId(WsEvent): - key: IbkrWsKey = IbkrWsKey.SERVER_ID - conid: str - server_id: str - target_key: IbkrWsKey - - -class Orders(WsEvent): - key: IbkrWsKey = IbkrWsKey.ORDERS +class Orders(IbkrTopicEvent): + topic: ClassVar[str] = 'or' data: dict -class PriceLadder(WsEvent): - key: IbkrWsKey = IbkrWsKey.PRICE_LADDER +class PriceLadder(IbkrTopicEvent): + topic: ClassVar[str] = 'bd' account_id: str conid: str exchange: str data: dict -class Pnl(WsEvent): - key: IbkrWsKey = IbkrWsKey.PNL +class Pnl(IbkrTopicEvent): + topic: ClassVar[str] = 'pl' data: dict -class Trades(WsEvent): - key: IbkrWsKey = IbkrWsKey.TRADES +class Trades(IbkrTopicEvent): + topic: ClassVar[str] = 'tr' data: dict + + +# =============== +# == Derived == +# =============== +class ServerId(WsEvent): + target_event_type: type['IbkrTopicEvent'] + conid: str + server_id: str + + +class Unsubscription(WsEvent): + target_event_type: type['IbkrTopicEvent'] + conid: str | None = None + + +def get_ibkr_topic_event(topic: str): + topic_to_event_type = { + 'sd': AccountSummary, + 'ld': AccountLedger, + 'md': MarketData, + 'mh': MarketHistory, + 'bd': PriceLadder, + 'or': Orders, + 'pl': Pnl, + 'tr': Trades, + } + if topic in topic_to_event_type: + return topic_to_event_type[topic] + raise ValueError(f"No Ibkr event associated with topic '{topic}'") diff --git a/ibind/ibkr_ws_v2/ibkr_router.py b/ibind/ibkr_ws_v2/ibkr_router.py index bd7d3f44..2bb83995 100644 --- a/ibind/ibkr_ws_v2/ibkr_router.py +++ b/ibind/ibkr_ws_v2/ibkr_router.py @@ -5,7 +5,7 @@ from client import ibkr_definitions from client.ibkr_utils import extract_conid from ibkr_ws_v2 import ibkr_events -from ibkr_ws_v2.ibkr_events import GenericIbkrEvent, IbkrWsKey +from ibkr_ws_v2.ibkr_events import GenericIbkrEvent, get_ibkr_topic_event, IbkrTopicEvent from support.logs import project_logger from support.py_utils import UNDEFINED, OneOrMany from ws_v2.events import WsEvent @@ -18,7 +18,7 @@ def parse_raw_message(raw_message: str): topic = message.get('topic', UNDEFINED) if topic is UNDEFINED: - return message, None, None, None, None + return message, None, None data = message.get('args', {}) @@ -33,7 +33,7 @@ def __init__( ): self._log_raw_messages = log_raw_messages self._unwrap_market_data = unwrap_market_data - self._server_id_conid_pairs: Dict[IbkrWsKey, Dict[str, str]] = defaultdict(dict) + self._server_id_conid_pairs: Dict[type[IbkrTopicEvent], Dict[str, str]] = defaultdict(dict) def _preprocess_market_data_message(self, data: dict) -> OneOrMany[WsEvent]: """ @@ -53,12 +53,12 @@ def _preprocess_market_data_message(self, data: dict) -> OneOrMany[WsEvent]: return ibkr_events.MarketData(conid=str(data['conid']), data=unwrapped_data) def _preprocess_market_history_message(self, data: dict) -> OneOrMany[WsEvent]: - mh_server_id_conid_pairs = self._server_id_conid_pairs[IbkrWsKey.MARKET_HISTORY] + mh_server_id_conid_pairs = self._server_id_conid_pairs[ibkr_events.MarketHistory] events = [] conid = extract_conid(data) if 'serverId' in data and data['serverId'] not in mh_server_id_conid_pairs: mh_server_id_conid_pairs[data['serverId']] = str(conid) - events.append(ibkr_events.ServerId(conid=str(conid), server_id=data['serverId'], target_key=IbkrWsKey.MARKET_HISTORY)) + events.append(ibkr_events.ServerId(conid=str(conid), server_id=data['serverId'], target_event_type=ibkr_events.MarketHistory)) events.append(ibkr_events.MarketHistory(conid=str(conid), data=data)) return events @@ -99,26 +99,27 @@ def _preprocess_account_summary(self, data): def _handle_subscribed_message(self, topic: str, data: dict) -> OneOrMany[WsEvent] | None: try: - ibkr_ws_key = IbkrWsKey.from_topic(topic[1:3]) + # ibkr_ws_key = IbkrWsKey.from_topic(topic[1:3]) + event_type = get_ibkr_topic_event(topic[1:3]) except ValueError: # ValueError means we don't support this topic return None - if ibkr_ws_key == IbkrWsKey.ACCOUNT_SUMMARY: + if event_type == ibkr_events.AccountSummary: return self._preprocess_account_summary(data) - elif ibkr_ws_key == IbkrWsKey.ACCOUNT_LEDGER: + elif event_type == ibkr_events.AccountLedger: return self._preprocess_account_ledger(data) - elif ibkr_ws_key == IbkrWsKey.MARKET_DATA: + elif event_type == ibkr_events.MarketData: return self._preprocess_market_data_message(data) - elif ibkr_ws_key == IbkrWsKey.MARKET_HISTORY: + elif event_type == ibkr_events.MarketHistory: return self._preprocess_market_history_message(data) - elif ibkr_ws_key == IbkrWsKey.PRICE_LADDER: + elif event_type == ibkr_events.PriceLadder: return ibkr_events.PriceLadder(data=data) - elif ibkr_ws_key == IbkrWsKey.ORDERS: + elif event_type == ibkr_events.Orders: return ibkr_events.Orders(data=data) - elif ibkr_ws_key == IbkrWsKey.PNL: + elif event_type == ibkr_events.Pnl: return ibkr_events.Pnl(data=data) - elif ibkr_ws_key == IbkrWsKey.TRADES: + elif event_type == ibkr_events.Trades: return ibkr_events.Trades(data=data) else: _LOGGER.error(f'{self}: Unhandled subscribed message: {data}') @@ -160,12 +161,12 @@ def _handle_notification(self, data) -> OneOrMany[WsEvent]: # pragma: no cover def _handle_market_history_unsubscribe(self, data) -> OneOrMany[WsEvent]: server_id = data['message'].split('Unsubscribed ')[-1] - mh_server_id_conid_pairs = self._server_id_conid_pairs[IbkrWsKey.MARKET_HISTORY] + mh_server_id_conid_pairs = self._server_id_conid_pairs[ibkr_events.MarketHistory] if server_id in mh_server_id_conid_pairs: conid = mh_server_id_conid_pairs[server_id] _LOGGER.info(f'{self}: Received unsubscribing confirmation for server_id={server_id!r}, conid={conid!r}.') if conid is not None: - return ibkr_events.Unsubscription(target_key=IbkrWsKey.MARKET_HISTORY, conid=str(conid)) + return ibkr_events.Unsubscription(target_event_type=ibkr_events.MarketHistory, conid=str(conid)) _LOGGER.warning(f'{self}: Unknown conid={conid!r}. Cannot mark the subscription as unsubscribed.') else: @@ -183,14 +184,13 @@ def _handle_message_without_topic(self, message: dict) -> OneOrMany[WsEvent]: elif 'result' in message: if message['result'] == 'unsubscribed from summary': - return ibkr_events.Unsubscription(target_key=IbkrWsKey.ACCOUNT_SUMMARY) + return ibkr_events.Unsubscription(target_event_type=ibkr_events.AccountSummary) elif message['result'] == 'unsubscribed from ledger': - return ibkr_events.Unsubscription(target_key=IbkrWsKey.ACCOUNT_LEDGER) + return ibkr_events.Unsubscription(target_event_type=ibkr_events.AccountLedger) _LOGGER.error(f'{self}: Unrecognised message without a topic: {message}') return GenericIbkrEvent(message=message) - def route(self, raw_message: str) -> OneOrMany[WsEvent]: if self._log_raw_messages: _LOGGER.debug(f'{self}: Raw message: {raw_message}') @@ -232,6 +232,5 @@ def route(self, raw_message: str) -> OneOrMany[WsEvent]: events = GenericIbkrEvent(message=message, topic=topic, data=arguments) return events - def __str__(self): return f'{self.__class__.__qualname__}()' diff --git a/ibind/ibkr_ws_v2/ibkr_subscriptions.py b/ibind/ibkr_ws_v2/ibkr_subscriptions.py index a6322b9a..01fadaf3 100644 --- a/ibind/ibkr_ws_v2/ibkr_subscriptions.py +++ b/ibind/ibkr_ws_v2/ibkr_subscriptions.py @@ -3,62 +3,52 @@ from pydantic import Field -from ibkr_ws_v2.ibkr_events import IbkrWsKey, AccountLedger, MarketData, MarketHistory, Orders, PriceLadder, Pnl, Trades, Unsubscription, AccountSummary +from ibkr_ws_v2 import ibkr_events +from ibkr_ws_v2.ibkr_events import AccountLedger, MarketData, MarketHistory, Orders, PriceLadder, Pnl, Trades, Unsubscription, AccountSummary, IbkrTopicEvent from support.py_utils import filter_none from ws_v2.subscriptions import Subscription, SubscriptionResolver def make_binding_key( - key: IbkrWsKey, + event_type: type[IbkrTopicEvent], conid: str = None, account_id=None, exchange=None ): - if key in [IbkrWsKey.MARKET_DATA, IbkrWsKey.MARKET_HISTORY]: - return f"{key.topic}+{conid}" - elif key in [IbkrWsKey.ACCOUNT_LEDGER, IbkrWsKey.ACCOUNT_SUMMARY]: - return f"{key.topic}+{account_id}" - elif key in [IbkrWsKey.PRICE_LADDER]: - return f"{key.topic}+{account_id}+{conid}" + (f"+{exchange}" if exchange is not None else '') - elif key in [IbkrWsKey.ORDERS, IbkrWsKey.PNL, IbkrWsKey.TRADES]: - return key.topic + if event_type in [ibkr_events.MarketData, ibkr_events.MarketHistory]: + return f"{event_type.topic}+{conid}" + elif event_type in [ibkr_events.AccountLedger, ibkr_events.AccountSummary]: + return f"{event_type.topic}+{account_id}" + elif event_type in [ibkr_events.PriceLadder]: + return f"{event_type.topic}+{account_id}+{conid}" + (f"+{exchange}" if exchange is not None else '') + elif event_type in [ibkr_events.Orders, ibkr_events.Pnl, ibkr_events.Trades]: + return event_type.topic else: - raise ValueError(f'Unsupported key: {key}') + raise ValueError(f'Unsupported event type: {event_type}') class IbkrSubscriptionResolver(SubscriptionResolver): - _register = [ - MarketData, - AccountSummary, - AccountLedger, - MarketHistory, - Orders, - PriceLadder, - Pnl, - Trades, - Unsubscription - ] - def __init__(self, account_id): self._account_id = account_id def _resolve_subscribing_event(self, event) -> str: - if event.key in [IbkrWsKey.MARKET_DATA, IbkrWsKey.MARKET_HISTORY]: - return make_binding_key(event.key, conid=event.conid) - elif event.key in [IbkrWsKey.ACCOUNT_LEDGER, IbkrWsKey.ACCOUNT_SUMMARY]: - return make_binding_key(event.key, account_id=event.account_id) - elif event.key in [IbkrWsKey.PRICE_LADDER]: - return make_binding_key(event.key, conid=event.conid, account_id=event.account_id, exchange=event.exchange) - elif event.key in [IbkrWsKey.ORDERS, IbkrWsKey.PNL, IbkrWsKey.TRADES]: - return make_binding_key(event.key) + event_type = type(event) + if event_type in [ibkr_events.MarketData, ibkr_events.MarketHistory]: + return make_binding_key(event_type, conid=event.conid) + elif event_type in [ibkr_events.AccountLedger, ibkr_events.AccountSummary]: + return make_binding_key(event_type, account_id=event.account_id) + elif event_type in [ibkr_events.PriceLadder]: + return make_binding_key(event_type, conid=event.conid, account_id=event.account_id, exchange=event.exchange) + elif event_type in [ibkr_events.Orders, ibkr_events.Pnl, ibkr_events.Trades]: + return make_binding_key(event_type) else: raise ValueError(f'Unsupported event: {event}') def _resolve_unsubscribing_event(self, event) -> str: - return make_binding_key(event.target_key, event.conid, self._account_id) + return make_binding_key(event.target_event_type, event.conid, self._account_id) def resolve_binding_key(self, event) -> Tuple[bool, str] | Tuple[None, None]: - if type(event) not in self._register: + if not (isinstance(event, IbkrTopicEvent) or isinstance(event, Unsubscription)): return None, None if isinstance(event, Unsubscription): @@ -68,15 +58,15 @@ def resolve_binding_key(self, event) -> Tuple[bool, str] | Tuple[None, None]: class IbkrSubscription(Subscription): - key: IbkrWsKey + event_type: type[IbkrTopicEvent] @property def topic(self) -> str: - return self.key.topic + return self.event_type.topic class AccountSummarySubscription(IbkrSubscription): - key: IbkrWsKey = IbkrWsKey.ACCOUNT_SUMMARY + event_type: type[IbkrTopicEvent] = AccountSummary account_id: str def subscribe_payload(self) -> str: @@ -94,11 +84,11 @@ def confirms_unsubscribe(self) -> bool: return True def binding_key(self): - return make_binding_key(self.key, account_id=self.account_id) + return make_binding_key(self.event_type, account_id=self.account_id) class AccountLedgerSubscription(IbkrSubscription): - key: IbkrWsKey = IbkrWsKey.ACCOUNT_LEDGER + event_type: type[IbkrTopicEvent] = AccountLedger account_id: str def subscribe_payload(self) -> str: @@ -116,11 +106,11 @@ def confirms_unsubscribe(self) -> bool: return True def binding_key(self): - return make_binding_key(self.key, account_id=self.account_id) + return make_binding_key(self.event_type, account_id=self.account_id) class MarketDataSubscription(IbkrSubscription): - key: IbkrWsKey = IbkrWsKey.MARKET_DATA + event_type: type[IbkrTopicEvent] = MarketData conid: str fields: tuple[str, ...] @@ -140,11 +130,11 @@ def confirms_unsubscribe(self) -> bool: return False def binding_key(self): - return make_binding_key(self.key, conid=self.conid) + return make_binding_key(self.event_type, conid=self.conid) class MarketHistorySubscription(IbkrSubscription): - key: IbkrWsKey = IbkrWsKey.MARKET_HISTORY + event_type: type[IbkrTopicEvent] = MarketHistory conid: str exchange: str = None period: str = None @@ -152,7 +142,7 @@ class MarketHistorySubscription(IbkrSubscription): outside_rth: bool = None source: str = None format: str = None - server_id: list = Field(default_factory=list) # uses list to allow writing despite frozen model + server_id: list = Field(default_factory=list) # uses list to allow writing despite frozen model def subscribe_payload(self) -> str: data = { @@ -192,11 +182,11 @@ def get_server_id(self): return self.server_id[0] def binding_key(self): - return make_binding_key(self.key, conid=self.conid) + return make_binding_key(self.event_type, conid=self.conid) class OrdersSubscription(IbkrSubscription): - key: IbkrWsKey = IbkrWsKey.ORDERS + event_type: type[IbkrTopicEvent] = Orders filter: str = None def subscribe_payload(self) -> str: @@ -215,11 +205,11 @@ def confirms_unsubscribe(self) -> bool: return False def binding_key(self): - return make_binding_key(self.key) + return make_binding_key(self.event_type) class PriceLadderSubscription(IbkrSubscription): - key: IbkrWsKey = IbkrWsKey.PRICE_LADDER + event_type: type[IbkrTopicEvent] = PriceLadder conid: str account_id: str exchange: str @@ -239,11 +229,11 @@ def confirms_unsubscribe(self) -> bool: return False def binding_key(self): - return make_binding_key(self.key, conid=self.conid, account_id=self.account_id, exchange=self.exchange) + return make_binding_key(self.event_type, conid=self.conid, account_id=self.account_id, exchange=self.exchange) class PnlSubscription(IbkrSubscription): - key: IbkrWsKey = IbkrWsKey.PNL + event_type: type[IbkrTopicEvent] = Pnl def subscribe_payload(self) -> str: return 'spl' @@ -260,11 +250,11 @@ def confirms_unsubscribe(self) -> bool: return False def binding_key(self): - return make_binding_key(self.key) + return make_binding_key(self.event_type) class TradesSubscription(IbkrSubscription): - key: IbkrWsKey = IbkrWsKey.TRADES + event_type: type[IbkrTopicEvent] = Trades realtime_updates_only: bool | None = None days: int | None = None @@ -289,4 +279,4 @@ def confirms_unsubscribe(self) -> bool: return False def binding_key(self): - return make_binding_key(self.key) + return make_binding_key(self.event_type) diff --git a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py index 1db5f031..6ccd8a00 100644 --- a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py +++ b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py @@ -3,14 +3,14 @@ from typing import Union, List, Dict import var -from base.queue_controller import QueueController -from ibind import IbkrClient, IbkrWsKey +from ibind import IbkrClient from ibkr_ws_v2 import ibkr_events +from ibkr_ws_v2.ibkr_events import IbkrTopicEvent from ibkr_ws_v2.ibkr_router import IbkrRouter from ibkr_ws_v2.ibkr_subscriptions import IbkrSubscriptionResolver, MarketHistorySubscription from support.logs import project_logger from support.py_utils import OneOrMany, ensure_list_arg -from ws_v2.events import EventSink, LogSink, CallbackSink, Router, AsyncSink, NoopSink +from ws_v2.events import EventSink, CallbackSink, Router, AsyncSink, NoopSink from ws_v2.subscriptions import Subscription, SubscriptionResolver, SubscriptionHandle, BindingStatus from ws_v2.ws_runtime import WsRuntime, WsState @@ -63,7 +63,7 @@ def __init__( # self._queue_controller.register_queues(list(IbkrWsKey)) if sink is None: - # self._queue_controller.register_queues(['CLIENT_INTERNAL', 'IBKR']) + # self._queue_controller.register_queues(['LIFECYCLE', 'IBKR']) # sink = QueueSink(queue_controller=self._queue_controller) # sink = LogSink() @@ -98,7 +98,7 @@ def __init__( ) self._mh_subscriptions: List[MarketHistorySubscription] = [] - self._conid_server_id_pairs: Dict[IbkrWsKey, Dict[str, str]] = defaultdict(dict) + self._conid_server_id_pairs: Dict[type[ibkr_events.IbkrTopicEvent], Dict[str, str]] = defaultdict(dict) def _register_internal_callbacks(self): self._internal_sink.on(ibkr_events.AuthenticationStatus, self._on_authentication_status) @@ -122,9 +122,9 @@ def _on_system(self, event: ibkr_events.System): self._runtime.set_last_heartbeat(int(event.data['hb']) / 1000) def _on_server_id(self, event: ibkr_events.ServerId): - self._conid_server_id_pairs[event.target_key][event.conid] = event.server_id + self._conid_server_id_pairs[event.target_event_type][event.conid] = event.server_id for subscription in self._mh_subscriptions: - if subscription.key == event.target_key and subscription.conid == event.conid and not subscription.has_server_id(): + if subscription.event_type == event.target_event_type and subscription.conid == event.conid and not subscription.has_server_id(): subscription.set_server_id(event.server_id) def _get_cookie(self): @@ -172,13 +172,13 @@ def unsubscribe(self, subscription: Subscription) -> SubscriptionHandle: def get_status(self, binding_key: str) -> BindingStatus: return self._runtime.subscription_controller.get_status(binding_key) - def get_server_id(self, key: IbkrWsKey, conid: str) -> str: - return self._conid_server_id_pairs[key][conid] + def get_server_id(self, event_type: type[IbkrTopicEvent], conid: str) -> str: + return self._conid_server_id_pairs[event_type][conid] def _handle_mh_unsubscription(self, subscription: MarketHistorySubscription): if subscription.has_server_id(): return - server_id = self._conid_server_id_pairs.get(subscription.key, {}).get(subscription.conid) + server_id = self._conid_server_id_pairs.get(subscription.event_type, {}).get(subscription.conid) if server_id is None: raise RuntimeError(f'{self}: Unsubscribing from market history for conid={subscription.conid!r} without server_id. Could not find server_id in memory. Ensure at least one MarketHistory event is received before unsubscribing.') @@ -201,7 +201,6 @@ def wait_all(self, subscription_handles: OneOrMany[SubscriptionHandle], timeout: def is_running(self) -> bool: return self._runtime.is_running() - def get_state(self) -> WsState: return self._runtime.get_state() diff --git a/ibind/ws_v2/events.py b/ibind/ws_v2/events.py index 3823aa51..8b0dcbd0 100644 --- a/ibind/ws_v2/events.py +++ b/ibind/ws_v2/events.py @@ -3,13 +3,13 @@ from datetime import datetime from queue import Queue, Full, Empty from threading import Thread, Event -from typing import Hashable, Protocol, Callable, TypeVar, List, Dict, Any +from typing import Protocol, Callable, TypeVar, List, Dict, Any from pydantic import BaseModel, ConfigDict, Field from base.queue_controller import QueueAccessor from support.logs import project_logger -from support.py_utils import OneOrMany, exception_to_string, tname, ensure_list_arg +from support.py_utils import OneOrMany, exception_to_string, tname _LOGGER = project_logger('ibkr_ws_client') @@ -22,7 +22,6 @@ class WsEvent(BaseModel): model_config = ConfigDict(frozen=True, extra="forbid") received_at: datetime = Field(default_factory=datetime.now) - key: Hashable def __str__(self): return self._format() @@ -56,32 +55,32 @@ def _format(self): return f"{self.__class__.__name__}({fields})" -class ClientInternalEvent(WsEvent): - key: str = 'CLIENT_INTERNAL' +class LifecycleEvent(WsEvent): + ... -class WsOpen(ClientInternalEvent): +class WsOpen(LifecycleEvent): ... -class WsAuthenticated(ClientInternalEvent): +class WsAuthenticated(LifecycleEvent): ... -class WsDegraded(ClientInternalEvent): +class WsDegraded(LifecycleEvent): ... -class WsReady(ClientInternalEvent): +class WsReady(LifecycleEvent): ... -class WsClose(ClientInternalEvent): +class WsClose(LifecycleEvent): close_status_code: int | None close_msg: str | None -class WsError(ClientInternalEvent): +class WsError(LifecycleEvent): model_config = ConfigDict(frozen=True, extra="forbid", arbitrary_types_allowed=True) error: Exception @@ -97,7 +96,7 @@ def emit(self, event: "WsEvent") -> None: class LogSink: def emit(self, event: WsEvent) -> None: - _LOGGER.debug(f'{event.key}: {str(event)}') + _LOGGER.debug(event) class NoopSink: @@ -127,38 +126,30 @@ def __str__(self): class QueueSink: - def __init__(self, event_types: List[Hashable]): - self._queues = { - 'CLIENT_INTERNAL': Queue() - } - self.register_queues(event_types) - - @ensure_list_arg('keys') - def register_queues(self, keys: OneOrMany[Hashable]): - for key in keys: - if key not in self._queues: - self._queues[str(key)] = Queue() - - def new_queue_accessor(self, key: Hashable) -> QueueAccessor: - return QueueAccessor(self._get_queue(key), key) - - def _get_queue(self, key: Hashable) -> Queue: # pragma: no cover + def __init__(self): + self._queues = {} + + def new_queue_accessor(self, event_type: type[WsEvent]) -> QueueAccessor: + return QueueAccessor(self._get_queue(event_type), event_type) + + def _get_queue(self, event_type: type[WsEvent]) -> Queue: # pragma: no cover try: - return self._queues[str(key)] + return self._queues[event_type] except KeyError: - raise AttributeError(f'Invalid queue key: "{key}", expected: {list(self._queues.keys())}') + self._queues[event_type] = Queue() + return self._queues[event_type] - def get(self, key: Hashable, block: bool = False, timeout=None) -> Any: + def get(self, event_type: type[WsEvent], block: bool = False, timeout=None) -> Any: try: - return self._get_queue(key).get(block=block, timeout=timeout) + return self._get_queue(event_type).get(block=block, timeout=timeout) except Empty: return None - def empty(self, key: Hashable) -> bool: - return self._get_queue(key).empty() + def empty(self, event_type: type[WsEvent]) -> bool: + return self._get_queue(event_type).empty() def emit(self, event: WsEvent) -> None: - queue = self._get_queue(event.key) + queue = self._get_queue(type(event)) queue.put(event) diff --git a/ibind/ws_v2/subscriptions.py b/ibind/ws_v2/subscriptions.py index 593bca72..af80ac4d 100644 --- a/ibind/ws_v2/subscriptions.py +++ b/ibind/ws_v2/subscriptions.py @@ -2,7 +2,7 @@ import time from enum import Enum from threading import Condition, RLock -from typing import Dict, Optional, Callable, Protocol, Tuple, Hashable, Literal +from typing import Dict, Optional, Callable, Protocol, Tuple, Literal from pydantic import BaseModel, ConfigDict @@ -15,7 +15,6 @@ class Subscription(BaseModel): model_config = ConfigDict(frozen=True) - key: Hashable expiry_seconds: int | None = None @property @@ -39,19 +38,12 @@ def confirms_unsubscribe(self) -> bool: def binding_key(self): return self.subscribe_payload() - def __hash__(self): - if hasattr(self, '_hash'): - return self._hash - _hash = hash(self.binding_key()) - setattr(self, '_hash', _hash) - return _hash - def __str__(self): return f'{self.__class__.__qualname__}({self.binding_key()})' class SubscriptionResolver(Protocol): - def resolve_binding_key(self, event) -> Tuple[bool, str]: + def resolve_binding_key(self, event: WsEvent) -> Tuple[bool, str]: ... @@ -170,7 +162,7 @@ def observe(self, event: WsEvent): else: self._confirm_unsubscribed(binding_key) - def _make_attempt(self, binding:Binding): + def _make_attempt(self, binding: Binding): subscription = binding.subscription if binding.intent == BindingStatus.ACTIVE: payload = subscription.subscribe_payload() @@ -215,7 +207,6 @@ def reconcile_binding(self, binding: Binding): binding.attempts += 1 self._make_attempt(binding) - def reconcile_bindings(self): with self._condition: for binding in self._bindings.values(): diff --git a/ibind/ws_v2/ws_runtime.py b/ibind/ws_v2/ws_runtime.py index 4efcec2a..0485364f 100644 --- a/ibind/ws_v2/ws_runtime.py +++ b/ibind/ws_v2/ws_runtime.py @@ -380,8 +380,8 @@ def _process_transport_queue(self): self._handle_transport_event(te) except Exception as e: _LOGGER.error(f'{self}: Exception processing transport event {te}: {exception_to_string(e)}') - te.attempt += 1 - if te.attempt > _MAX_TRANSPORT_EVENT_RETRIES: + te.add_attempt() + if te.get_attempt() > _MAX_TRANSPORT_EVENT_RETRIES: _LOGGER.error(f'{self}: Max retries ({_MAX_TRANSPORT_EVENT_RETRIES}) reached for transport event {te}, dropping event.') continue retry_events.append(te) diff --git a/ibind/ws_v2/ws_transport.py b/ibind/ws_v2/ws_transport.py index 86dc699a..edebe978 100644 --- a/ibind/ws_v2/ws_transport.py +++ b/ibind/ws_v2/ws_transport.py @@ -1,6 +1,6 @@ import time from datetime import datetime -from typing import Callable, Any, cast +from typing import Callable, Any, cast, List from pydantic import BaseModel, ConfigDict, Field from websocket import WebSocketApp, STATUS_UNEXPECTED_CONDITION, STATUS_NORMAL @@ -14,9 +14,15 @@ class TransportEvent(BaseModel): - model_config = ConfigDict(frozen=True, extra="forbid", arbitrary_types_allowed=True) + model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) received_at: datetime = Field(default_factory=datetime.now) - attempt: int = 0 + attempt: List[int] = Field(default_factory=lambda: [0]) + + def add_attempt(self): + self.attempt[0] += 1 + + def get_attempt(self): + return self.attempt[0] def __str__(self): return f'{self.__class__.__qualname__}()' From 90b9b22f31afec04bdb52d93e8765606585ca933 Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 3 May 2026 20:39:49 +0200 Subject: [PATCH 17/32] chore(ws_v2): updated API with ws_v2 --- examples/ws_04_ws_v2.py | 87 ++++++++++++++------------------- ibind/__init__.py | 15 +++++- ibind/events/__init__.py | 32 ++++++++++++ ibind/subscriptions/__init__.py | 17 +++++++ 4 files changed, 99 insertions(+), 52 deletions(-) create mode 100644 ibind/events/__init__.py create mode 100644 ibind/subscriptions/__init__.py diff --git a/examples/ws_04_ws_v2.py b/examples/ws_04_ws_v2.py index b5c997e4..69a280c9 100644 --- a/examples/ws_04_ws_v2.py +++ b/examples/ws_04_ws_v2.py @@ -14,32 +14,49 @@ import time from typing import List -from ibind import ibind_logs_initialize -from ibkr_ws_v2.ibkr_events import IbkrWsKey -from ibkr_ws_v2.ibkr_subscriptions import MarketDataSubscription, OrdersSubscription, AccountLedgerSubscription, AccountSummarySubscription, PnlSubscription, TradesSubscription, MarketHistorySubscription -from ibkr_ws_v2.ibkr_ws_client_v2 import IbkrWsClientV2 -from ws_v2.events import LogSink, QueueSink -from ws_v2.subscriptions import SubscriptionHandle +from ibind import events, IbkrWsClientV2, LogSink, QueueSink, CallbackSink, CompositeSink, ibind_logs_initialize +from ibind.subscriptions import MarketDataSubscription, OrdersSubscription, AccountLedgerSubscription, AccountSummarySubscription, PnlSubscription, TradesSubscription, MarketHistorySubscription, SubscriptionHandle ibind_logs_initialize(log_to_file=False, log_level='DEBUG') account_id = os.getenv('IBIND_ACCOUNT_ID', '[YOUR_ACCOUNT_ID]') cacert = os.getenv('IBIND_CACERT', False) # insert your cacert path here -queue_sink = QueueSink(list(IbkrWsKey)) +# Queue Sink - queue-based event consumer +queue_sink = QueueSink() + +# Callback Sink - callback-based event consumer +callback_sink = CallbackSink() + + +def on_market_data(event: events.MarketData): + print(event) + +def on_market_history(event: events.MarketHistory): + print(event) + +def on_lifecycle(event: events.LifecycleEvent): + print(event) + +callback_sink.on(events.MarketData, on_market_data) +callback_sink.on(events.MarketHistory, on_market_data) +callback_sink.on(events.WsOpen, on_lifecycle) +callback_sink.on(events.WsClose, on_lifecycle) +callback_sink.on(events.WsError, on_lifecycle) +callback_sink.on(events.WsAuthenticated, on_lifecycle) +callback_sink.on(events.WsReady, on_lifecycle) +callback_sink.on(events.WsDegraded, on_lifecycle) + +# Log Sink - useful for debugging +log_sink = LogSink() + +# Composite Sink - allows us to use all above sinks at once +composite_sink = CompositeSink(callback_sink, log_sink) # ws_client = IbkrWsClient(cacert=cacert, account_id=account_id) # ws_client = IbkrWsClientV2(cacert=cacert, account_id=account_id, sink=LogSink()) -ws_client = IbkrWsClientV2(cacert=cacert, account_id=account_id, sink=queue_sink) +ws_client = IbkrWsClientV2(cacert=cacert, account_id=account_id, sink=composite_sink) -# def stop(_, _1): -# print('exit') -# ws_client.shutdown() -# print('done') -# return False -# -# signal.signal(signal.SIGINT, stop) -# signal.signal(signal.SIGTERM, stop) ws_client.start() @@ -61,8 +78,6 @@ # tr_sub ] - - sub_handles: List[SubscriptionHandle] = [] for sub in subs: handle = ws_client.subscribe(sub) @@ -77,8 +92,8 @@ try: while ws_client.is_running(): for sub in subs: - while not queue_sink.empty(sub.key): - ev = queue_sink.get(sub.key) + while not queue_sink.empty(sub.event_type): + ev = queue_sink.get(sub.event_type) print(ev) time.sleep(1) @@ -98,35 +113,5 @@ # # for handle in unsub_handles: # handle.wait(10) -# time.sleep(5) -ws_client.shutdown() - -# requests = [ -# {'channel': 'md+265598', 'data': {'fields': ['55', '71', '84', '86', '88', '85', '87', '7295', '7296', '70']}}, -# {'channel': 'or'}, -# {'channel': 'tr'}, -# {'channel': f'sd+{account_id}'}, -# {'channel': f'ld+{account_id}'}, -# {'channel': 'pl'}, -# ] -# -# -# -# -# for request in requests: -# while not ws_client.subscribe(**request): -# time.sleep(1) -# -# while ws_client.running: -# try: -# for qa in queue_accessors: -# while not qa.empty(): -# print(str(qa), qa.get()) -# -# time.sleep(1) -# except KeyboardInterrupt: -# print('KeyboardInterrupt') -# break -# -# stop(None, None) +ws_client.shutdown() diff --git a/ibind/__init__.py b/ibind/__init__.py index 6a90b201..de7b0797 100644 --- a/ibind/__init__.py +++ b/ibind/__init__.py @@ -10,6 +10,11 @@ from ibind.support.errors import ExternalBrokerError from ibind.support.logs import ibind_logs_initialize from ibind.support.py_utils import execute_in_parallel +from ibind import events, subscriptions +from ibind.ibkr_ws_v2.ibkr_ws_client_v2 import IbkrWsClientV2 +from ibind.ws_v2.events import LogSink, QueueSink, CallbackSink, CompositeSink +from ibind.ws_v2.subscriptions import SubscriptionHandle + __all__ = [ 'ibind_logs_initialize', @@ -28,7 +33,15 @@ 'QueueAccessor', 'execute_in_parallel', 'ExternalBrokerError', - 'question_type_to_message_id' + 'question_type_to_message_id', + 'events', + 'subscriptions', + 'IbkrWsClientV2', + 'LogSink', + 'QueueSink', + 'CallbackSink', + 'CompositeSink', + 'SubscriptionHandle', ] # patch_dotenv() diff --git a/ibind/events/__init__.py b/ibind/events/__init__.py new file mode 100644 index 00000000..b82f22ce --- /dev/null +++ b/ibind/events/__init__.py @@ -0,0 +1,32 @@ +from ibind.ws_v2.events import LifecycleEvent, WsOpen, WsAuthenticated, WsDegraded, WsReady, WsClose, WsError +from ibind.ibkr_ws_v2.ibkr_events import GenericIbkrEvent, IbkrError, WaitingForSession, Notification, Bulletin, AccountUpdate, System, AuthenticationStatus, IbkrTopicEvent, AccountSummary, AccountLedger, MarketData, MarketHistory, Orders, PriceLadder, Pnl, Trades, ServerId, Unsubscription + + +__all__ = [ + 'LifecycleEvent', + 'WsOpen', + 'WsAuthenticated', + 'WsDegraded', + 'WsReady', + 'WsClose', + 'WsError', + 'GenericIbkrEvent', + 'IbkrError', + 'WaitingForSession', + 'Notification', + 'Bulletin', + 'AccountUpdate', + 'System', + 'AuthenticationStatus', + 'IbkrTopicEvent', + 'AccountSummary', + 'AccountLedger', + 'MarketData', + 'MarketHistory', + 'Orders', + 'PriceLadder', + 'Pnl', + 'Trades', + 'ServerId', + 'Unsubscription', +] diff --git a/ibind/subscriptions/__init__.py b/ibind/subscriptions/__init__.py new file mode 100644 index 00000000..74f7a744 --- /dev/null +++ b/ibind/subscriptions/__init__.py @@ -0,0 +1,17 @@ +from ibkr_ws_v2.ibkr_subscriptions import MarketDataSubscription, OrdersSubscription, AccountLedgerSubscription, AccountSummarySubscription, PnlSubscription, TradesSubscription, MarketHistorySubscription + +from ws_v2.subscriptions import SubscriptionHandle, BindingStatus, Subscription, SubscriptionResolver + +__all__ = [ + 'Subscription', + 'SubscriptionResolver', + 'SubscriptionHandle', + 'BindingStatus', + 'MarketDataSubscription', + 'OrdersSubscription', + 'AccountLedgerSubscription', + 'AccountSummarySubscription', + 'PnlSubscription', + 'TradesSubscription', + 'MarketHistorySubscription', +] From a68135dcc7a8ad3467fb33524500fbf22ef1d1c6 Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 3 May 2026 20:48:27 +0200 Subject: [PATCH 18/32] chore(ws_v2): small reformatting --- ibind/ibkr_ws_v2/ibkr_subscriptions.py | 2 +- ibind/support/py_utils.py | 2 +- ibind/ws_v2/subscriptions.py | 2 +- ibind/ws_v2/ws_runtime.py | 6 +++--- ibind/ws_v2/ws_transport.py | 6 +++--- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ibind/ibkr_ws_v2/ibkr_subscriptions.py b/ibind/ibkr_ws_v2/ibkr_subscriptions.py index 01fadaf3..c77534ef 100644 --- a/ibind/ibkr_ws_v2/ibkr_subscriptions.py +++ b/ibind/ibkr_ws_v2/ibkr_subscriptions.py @@ -154,7 +154,7 @@ def subscribe_payload(self) -> str: 'format': self.format, } data = filter_none(data) - return f'smh+{self.conid}+{json.dumps(data, separators=(',', ':'))}' + return f'smh+{self.conid}+{json.dumps(data, separators=(",", ":"))}' def unsubscribe_payload(self) -> str: server_id = self.get_server_id() diff --git a/ibind/support/py_utils.py b/ibind/support/py_utils.py index 988e4a42..d35e1609 100644 --- a/ibind/support/py_utils.py +++ b/ibind/support/py_utils.py @@ -20,7 +20,7 @@ S = TypeVar('S') OneOrMany = Union[S, List[S]] -def NOOP(): +def noop(): return None _LOGGER = project_logger(__file__) diff --git a/ibind/ws_v2/subscriptions.py b/ibind/ws_v2/subscriptions.py index af80ac4d..a0d0b60f 100644 --- a/ibind/ws_v2/subscriptions.py +++ b/ibind/ws_v2/subscriptions.py @@ -182,7 +182,7 @@ def reconcile_binding(self, binding: Binding): now = time.time() subscription = binding.subscription - if binding.status == binding.intent or binding.status == BindingStatus.FAILED: + if binding.status in [binding.intent, BindingStatus.FAILED]: if subscription.expiry_seconds is None: return diff --git a/ibind/ws_v2/ws_runtime.py b/ibind/ws_v2/ws_runtime.py index 0485364f..eb9c9fe2 100644 --- a/ibind/ws_v2/ws_runtime.py +++ b/ibind/ws_v2/ws_runtime.py @@ -8,7 +8,7 @@ from typing import Union, List, Dict, Callable, Literal from support.logs import project_logger -from support.py_utils import wait_until, tname, VerboseEnum, exception_to_string, TimeoutLock, OneOrMany, NOOP +from support.py_utils import wait_until, tname, VerboseEnum, exception_to_string, TimeoutLock, OneOrMany, noop from ws_v2 import events from ws_v2.events import WsEvent, EventSink, Router, CallbackSink, AsyncSink from ws_v2.subscriptions import SubscriptionController, SubscriptionResolver @@ -57,8 +57,8 @@ def __init__( connection_timeout: float = _DEFAULT_TIMEOUT, reconnect_timeout: float | None = _DEFAULT_TIMEOUT, max_ping_interval: float = 20, - get_cookie: Callable = NOOP, - get_header: Callable = NOOP, + get_cookie: Callable = noop, + get_header: Callable = noop, ): if ready_state not in [WsState.OPEN, WsState.AUTHENTICATED]: raise ValueError(f'Invalid ready_state: {ready_state}, must be either {WsState.OPEN} or {WsState.AUTHENTICATED}') diff --git a/ibind/ws_v2/ws_transport.py b/ibind/ws_v2/ws_transport.py index edebe978..e67cc918 100644 --- a/ibind/ws_v2/ws_transport.py +++ b/ibind/ws_v2/ws_transport.py @@ -8,7 +8,7 @@ import var from ibind import ExternalBrokerError from support.logs import project_logger -from support.py_utils import exception_to_string, tname, wait_until, UNDEFINED, NOOP +from support.py_utils import exception_to_string, tname, wait_until, UNDEFINED, noop _LOGGER = project_logger('ibkr_ws_client') @@ -56,8 +56,8 @@ def __init__( url: str, event_callback: Callable, sslopt: dict[str, Any], - get_cookie: Callable = NOOP, - get_header: Callable = NOOP, + get_cookie: Callable = noop, + get_header: Callable = noop, ping_interval: float = 10, ping_timeout: float = 10, max_ping_interval: float = 20, From 8749c2bf82952aaffe0160d12ab97615c0fcac69 Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 3 May 2026 20:49:15 +0200 Subject: [PATCH 19/32] requirements: updated websocket-client to >=1.9 (from >=1.7) --- requirements.txt | Bin 108 -> 112 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/requirements.txt b/requirements.txt index 008e657ce57c637ec72ec984a62926c9573234e3..57b558d066544a449983e2f1f868bac544e8c6b3 100644 GIT binary patch delta 14 Vcmc~PnBc`|Ini5@g_nVg0RSIX0|@{C delta 9 QcmXTOnc&4}KG9nd01!|D>;M1& From 77931ff1805a314a5cd21b421269056ec50bc31c Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 3 May 2026 20:55:26 +0200 Subject: [PATCH 20/32] fix(logs): fixed project_logger incorrectly testing filepath --- ibind/support/logs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibind/support/logs.py b/ibind/support/logs.py index edde1a96..d6d18bf0 100644 --- a/ibind/support/logs.py +++ b/ibind/support/logs.py @@ -46,7 +46,7 @@ def project_logger(filepath=None): """ logger_name = 'ibind' if filepath is not None: - child = Path(filepath).stem if isinstance(filepath, Path) else str(filepath) + child = Path(filepath).stem if os.path.exists(filepath) else str(filepath) logger_name += f'.{child}' return logging.getLogger(logger_name) From d33ddc0b90a900b79e2bf96d7b14690d13ac2e01 Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 3 May 2026 20:55:58 +0200 Subject: [PATCH 21/32] fix(logs): fixed project_logger incorrectly testing filepath 2 --- ibind/support/logs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ibind/support/logs.py b/ibind/support/logs.py index d6d18bf0..9c9349f6 100644 --- a/ibind/support/logs.py +++ b/ibind/support/logs.py @@ -1,5 +1,6 @@ import datetime import logging +import os.path import sys from pathlib import Path from typing import List From ee6a6013b0d696aa3b776f183389fc1276016fff Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 3 May 2026 20:56:21 +0200 Subject: [PATCH 22/32] chore: small reformats --- ibind/base/queue_controller.py | 2 +- ibind/base/rest_client.py | 2 +- ibind/client/ibkr_client.py | 4 ++-- ibind/client/ibkr_client_mixins/session_mixin.py | 2 +- ibind/client/ibkr_utils.py | 2 +- ibind/client/ibkr_ws_client.py | 2 +- ibind/oauth/__init__.py | 2 +- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/ibind/base/queue_controller.py b/ibind/base/queue_controller.py index 29d0e35e..5bc217f3 100644 --- a/ibind/base/queue_controller.py +++ b/ibind/base/queue_controller.py @@ -146,4 +146,4 @@ def put_to_queue(self, key: T, data): AttributeError: If no queue exists for the given key. """ queue = self.get_queue(key) - queue.put(data) \ No newline at end of file + queue.put(data) diff --git a/ibind/base/rest_client.py b/ibind/base/rest_client.py index ec063ec2..a524a1b0 100644 --- a/ibind/base/rest_client.py +++ b/ibind/base/rest_client.py @@ -356,4 +356,4 @@ def _signal_handler(signum, frame): atexit.register(_close_handler) def __str__(self): - return f'{self.__class__.__qualname__}' \ No newline at end of file + return f'{self.__class__.__qualname__}' diff --git a/ibind/client/ibkr_client.py b/ibind/client/ibkr_client.py index 8f38c40b..15be8e36 100644 --- a/ibind/client/ibkr_client.py +++ b/ibind/client/ibkr_client.py @@ -320,7 +320,7 @@ def _attempt_health_check(self, method: callable, raise_exceptions: bool = False elif 'An attempt was made to access a socket in a way forbidden by its access permissions' in str(e): _LOGGER.error('Connection to IBKR servers blocked during reauthentication. Check that nothing is blocking connectivity of the application') elif e.status_code == 410 and 'gone' in str(e): - _LOGGER.error(f'OAuth 410 gone: recreate a new live session token, or try a different server, eg. "1.api.ibkr.com", "2.api.ibkr.com", etc.') + _LOGGER.error('OAuth 410 gone: recreate a new live session token, or try a different server, eg. "1.api.ibkr.com", "2.api.ibkr.com", etc.') else: _LOGGER.error(f'Unknown error checking IBKR connection during reauthentication: {exception_to_string(e)}') @@ -330,4 +330,4 @@ def _attempt_health_check(self, method: callable, raise_exceptions: bool = False _LOGGER.error(f'Error reauthenticating OAuth during reauthentication: {exception_to_string(e)}') if raise_exceptions: raise - return False \ No newline at end of file + return False diff --git a/ibind/client/ibkr_client_mixins/session_mixin.py b/ibind/client/ibkr_client_mixins/session_mixin.py index 2b68e733..b87f1f63 100644 --- a/ibind/client/ibkr_client_mixins/session_mixin.py +++ b/ibind/client/ibkr_client_mixins/session_mixin.py @@ -142,4 +142,4 @@ def check_auth_status(self: 'IbkrClient') -> bool: if result.data.get('authenticated', None) is None: raise AttributeError(f'Health check requests returns invalid data: {result}') - return _parse_auth_status(result.data) \ No newline at end of file + return _parse_auth_status(result.data) diff --git a/ibind/client/ibkr_utils.py b/ibind/client/ibkr_utils.py index 21cda07d..bacbd4d9 100644 --- a/ibind/client/ibkr_utils.py +++ b/ibind/client/ibkr_utils.py @@ -774,4 +774,4 @@ def cleanup_market_history_responses( } ) results[symbol] = records - return results \ No newline at end of file + return results diff --git a/ibind/client/ibkr_ws_client.py b/ibind/client/ibkr_ws_client.py index 30c50cdb..74918a95 100644 --- a/ibind/client/ibkr_ws_client.py +++ b/ibind/client/ibkr_ws_client.py @@ -711,4 +711,4 @@ def ts_changed(): if not wait_until(ts_changed, f'tic timeout, ts={ts}', timeout=5): return None - return self._tic_message \ No newline at end of file + return self._tic_message diff --git a/ibind/oauth/__init__.py b/ibind/oauth/__init__.py index e0b7d24f..1e151559 100644 --- a/ibind/oauth/__init__.py +++ b/ibind/oauth/__init__.py @@ -55,4 +55,4 @@ def copy(self, **kwargs): if not hasattr(copied, kwarg): raise AttributeError(f'OAuthConfig does not have attribute "{kwarg}"') setattr(copied, kwarg, value) - return copied \ No newline at end of file + return copied From 68ecb8fe57c379000977a38e223296dc38d7f4c4 Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 10 May 2026 13:46:28 +0200 Subject: [PATCH 23/32] refactor(ws_v2): normalize event imports through `ibind.events`, re-export `WsEvent`, and relocate IBKR topic-to-event resolution into the router for clearer ownership fix(ws_v2): harden runtime shutdown and transport state improve runtime stop/close flow to set closed state consistently, separate graceful vs unexpected disconnects, and mark transport degraded when thread shutdown fails reduce websocket lifecycle log noise by moving thread start/stop logs to debug while keeping key connection, auth, and send events visible --- examples/ws_04_ws_v2.py | 6 +- ibind/events/__init__.py | 3 +- ibind/ibkr_ws_v2/ibkr_events.py | 90 +------------- ibind/ibkr_ws_v2/ibkr_router.py | 164 ++++++++++++++----------- ibind/ibkr_ws_v2/ibkr_subscriptions.py | 24 ++-- ibind/ibkr_ws_v2/ibkr_ws_client_v2.py | 24 ++-- ibind/ws_v2/events.py | 32 ++--- ibind/ws_v2/subscriptions.py | 2 +- ibind/ws_v2/ws_runtime.py | 54 ++++---- ibind/ws_v2/ws_transport.py | 146 +++++++++++++++++----- 10 files changed, 278 insertions(+), 267 deletions(-) diff --git a/examples/ws_04_ws_v2.py b/examples/ws_04_ws_v2.py index 69a280c9..534ad33b 100644 --- a/examples/ws_04_ws_v2.py +++ b/examples/ws_04_ws_v2.py @@ -17,7 +17,7 @@ from ibind import events, IbkrWsClientV2, LogSink, QueueSink, CallbackSink, CompositeSink, ibind_logs_initialize from ibind.subscriptions import MarketDataSubscription, OrdersSubscription, AccountLedgerSubscription, AccountSummarySubscription, PnlSubscription, TradesSubscription, MarketHistorySubscription, SubscriptionHandle -ibind_logs_initialize(log_to_file=False, log_level='DEBUG') +ibind_logs_initialize(log_to_file=False, log_level='INFO') account_id = os.getenv('IBIND_ACCOUNT_ID', '[YOUR_ACCOUNT_ID]') cacert = os.getenv('IBIND_CACERT', False) # insert your cacert path here @@ -55,14 +55,14 @@ def on_lifecycle(event: events.LifecycleEvent): # ws_client = IbkrWsClient(cacert=cacert, account_id=account_id) # ws_client = IbkrWsClientV2(cacert=cacert, account_id=account_id, sink=LogSink()) -ws_client = IbkrWsClientV2(cacert=cacert, account_id=account_id, sink=composite_sink) +ws_client = IbkrWsClientV2(cacert=cacert, account_id=account_id, sink=queue_sink) ws_client.start() as_sub = AccountSummarySubscription(account_id=account_id) al_sub = AccountLedgerSubscription(account_id=account_id) -md_sub = MarketDataSubscription(conid='265598', fields=("31", "84", "86"), expiry_seconds=30) +md_sub = MarketDataSubscription(conid='265598', fields=["31", "84", "86"], expiry_seconds=30) mh_sub = MarketHistorySubscription(conid='265598') or_sub = OrdersSubscription() # pl_sub = PriceLadderSubscription(conid='265598', account_id=account_id, exchange='SMART') diff --git a/ibind/events/__init__.py b/ibind/events/__init__.py index b82f22ce..b62de9f1 100644 --- a/ibind/events/__init__.py +++ b/ibind/events/__init__.py @@ -1,9 +1,10 @@ -from ibind.ws_v2.events import LifecycleEvent, WsOpen, WsAuthenticated, WsDegraded, WsReady, WsClose, WsError +from ibind.ws_v2.events import LifecycleEvent, WsOpen, WsAuthenticated, WsDegraded, WsReady, WsClose, WsError, WsEvent from ibind.ibkr_ws_v2.ibkr_events import GenericIbkrEvent, IbkrError, WaitingForSession, Notification, Bulletin, AccountUpdate, System, AuthenticationStatus, IbkrTopicEvent, AccountSummary, AccountLedger, MarketData, MarketHistory, Orders, PriceLadder, Pnl, Trades, ServerId, Unsubscription __all__ = [ 'LifecycleEvent', + 'WsEvent', 'WsOpen', 'WsAuthenticated', 'WsDegraded', diff --git a/ibind/ibkr_ws_v2/ibkr_events.py b/ibind/ibkr_ws_v2/ibkr_events.py index 138a5b43..500dbf16 100644 --- a/ibind/ibkr_ws_v2/ibkr_events.py +++ b/ibind/ibkr_ws_v2/ibkr_events.py @@ -2,74 +2,8 @@ from pydantic import Field -from ws_v2.events import WsEvent - - -# class IbkrWsKey(Enum): -# # generic -# GENERIC = 'GENERIC' -# UNSUBSCRIPTION = 'UNSUBSCRIPTION' -# SERVER_ID = 'SERVER_ID' -# WAITING_FOR_SESSION = 'WAITING_FOR_SESSION' -# -# # unsolicited -# ACCOUNT_UPDATE = 'ACCOUNT_UPDATE' -# AUTHENTICATION_STATUS = 'AUTHENTICATION_STATUS' -# BULLETIN = 'BULLETIN' -# ERROR = 'ERROR' -# SYSTEM = 'SYSTEM' -# NOTIFICATION = 'NOTIFICATION' -# -# # subscription-based -# ACCOUNT_SUMMARY = 'ACCOUNT_SUMMARY' -# ACCOUNT_LEDGER = 'ACCOUNT_LEDGER' -# MARKET_DATA = 'MARKET_DATA' -# MARKET_HISTORY = 'MARKET_HISTORY' -# PRICE_LADDER = 'PRICE_LADDER' -# ORDERS = 'ORDERS' -# PNL = 'PNL' -# TRADES = 'TRADES' -# -# # internal -# LIFECYCLE = 'LIFECYCLE' -# -# @classmethod -# def from_topic(cls, topic): -# topic_to_key = { -# 'sd': IbkrWsKey.ACCOUNT_SUMMARY, -# 'ld': IbkrWsKey.ACCOUNT_LEDGER, -# 'md': IbkrWsKey.MARKET_DATA, -# 'mh': IbkrWsKey.MARKET_HISTORY, -# 'bd': IbkrWsKey.PRICE_LADDER, -# 'or': IbkrWsKey.ORDERS, -# 'pl': IbkrWsKey.PNL, -# 'tr': IbkrWsKey.TRADES, -# } -# if topic in topic_to_key: -# return topic_to_key[topic] -# raise ValueError(f"No enum member associated with topic '{topic}'") -# -# @property -# def topic(self): -# """ -# Gets the solicited topic string associated with the enum member. -# -# Returns: -# str: The topic string corresponding to the enum member. -# """ -# return { -# IbkrWsKey.ACCOUNT_SUMMARY: 'sd', -# IbkrWsKey.ACCOUNT_LEDGER: 'ld', -# IbkrWsKey.MARKET_DATA: 'md', -# IbkrWsKey.MARKET_HISTORY: 'mh', -# IbkrWsKey.PRICE_LADDER: 'bd', -# IbkrWsKey.ORDERS: 'or', -# IbkrWsKey.PNL: 'pl', -# IbkrWsKey.TRADES: 'tr', -# }[self] -# -# def __str__(self): -# return self.value +from ibind.events import WsEvent + class GenericIbkrEvent(WsEvent): @@ -171,27 +105,11 @@ class Trades(IbkrTopicEvent): # == Derived == # =============== class ServerId(WsEvent): - target_event_type: type['IbkrTopicEvent'] + target_event_type: type[IbkrTopicEvent] conid: str server_id: str class Unsubscription(WsEvent): - target_event_type: type['IbkrTopicEvent'] + target_event_type: type[IbkrTopicEvent] conid: str | None = None - - -def get_ibkr_topic_event(topic: str): - topic_to_event_type = { - 'sd': AccountSummary, - 'ld': AccountLedger, - 'md': MarketData, - 'mh': MarketHistory, - 'bd': PriceLadder, - 'or': Orders, - 'pl': Pnl, - 'tr': Trades, - } - if topic in topic_to_event_type: - return topic_to_event_type[topic] - raise ValueError(f"No Ibkr event associated with topic '{topic}'") diff --git a/ibind/ibkr_ws_v2/ibkr_router.py b/ibind/ibkr_ws_v2/ibkr_router.py index 2bb83995..c6b527ca 100644 --- a/ibind/ibkr_ws_v2/ibkr_router.py +++ b/ibind/ibkr_ws_v2/ibkr_router.py @@ -4,15 +4,33 @@ from client import ibkr_definitions from client.ibkr_utils import extract_conid -from ibkr_ws_v2 import ibkr_events -from ibkr_ws_v2.ibkr_events import GenericIbkrEvent, get_ibkr_topic_event, IbkrTopicEvent -from support.logs import project_logger -from support.py_utils import UNDEFINED, OneOrMany -from ws_v2.events import WsEvent + +# from ibkr_ws_v2 import ibkr_events +from ibind import events +from ibind.events import GenericIbkrEvent, IbkrTopicEvent +from ibind.support.logs import project_logger +from ibind.support.py_utils import UNDEFINED, OneOrMany +from ibind.events import WsEvent _LOGGER = project_logger('ibkr_ws_client') +def get_ibkr_topic_event(topic: str): + topic_to_event_type = { + 'sd': events.AccountSummary, + 'ld': events.AccountLedger, + 'md': events.MarketData, + 'mh': events.MarketHistory, + 'bd': events.PriceLadder, + 'or': events.Orders, + 'pl': events.Pnl, + 'tr': events.Trades, + } + if topic in topic_to_event_type: + return topic_to_event_type[topic] + raise ValueError(f"No Ibkr event associated with topic '{topic}'") + + def parse_raw_message(raw_message: str): message = json.loads(raw_message) topic = message.get('topic', UNDEFINED) @@ -25,12 +43,8 @@ def parse_raw_message(raw_message: str): return message, topic, data -class IbkrRouter(): - def __init__( - self, - log_raw_messages: bool = False, - unwrap_market_data: bool = True - ): +class IbkrRouter: + def __init__(self, log_raw_messages: bool = False, unwrap_market_data: bool = True): self._log_raw_messages = log_raw_messages self._unwrap_market_data = unwrap_market_data self._server_id_conid_pairs: Dict[type[IbkrTopicEvent], Dict[str, str]] = defaultdict(dict) @@ -44,33 +58,33 @@ def _preprocess_market_data_message(self, data: dict) -> OneOrMany[WsEvent]: return [] if not self._unwrap_market_data: - return ibkr_events.MarketData(conid=data['conid'], data=data) + return events.MarketData(conid=data['conid'], data=data) unwrapped_data = {} for key, value in data.items(): if key in ibkr_definitions.snapshot_by_id: unwrapped_data[ibkr_definitions.snapshot_by_id[key]] = value - return ibkr_events.MarketData(conid=str(data['conid']), data=unwrapped_data) + return events.MarketData(conid=str(data['conid']), data=unwrapped_data) def _preprocess_market_history_message(self, data: dict) -> OneOrMany[WsEvent]: - mh_server_id_conid_pairs = self._server_id_conid_pairs[ibkr_events.MarketHistory] - events = [] + mh_server_id_conid_pairs = self._server_id_conid_pairs[events.MarketHistory] + rv = [] conid = extract_conid(data) if 'serverId' in data and data['serverId'] not in mh_server_id_conid_pairs: mh_server_id_conid_pairs[data['serverId']] = str(conid) - events.append(ibkr_events.ServerId(conid=str(conid), server_id=data['serverId'], target_event_type=ibkr_events.MarketHistory)) + rv.append(events.ServerId(conid=str(conid), server_id=data['serverId'], target_event_type=events.MarketHistory)) - events.append(ibkr_events.MarketHistory(conid=str(conid), data=data)) - return events + rv.append(events.MarketHistory(conid=str(conid), data=data)) + return rv def _preprocess_account_ledger(self, data): - events = [] + rv = [] for entry in data['result']: if 'acctCode' not in entry: continue - event = ibkr_events.AccountLedger(data=entry, account_id=entry['acctCode']) - events.append(event) - return events + event = events.AccountLedger(data=entry, account_id=entry['acctCode']) + rv.append(event) + return rv def _preprocess_account_summary(self, data): summary = {} @@ -94,7 +108,7 @@ def _preprocess_account_summary(self, data): account_id = summary['AccountCode']['value'] summary['timestamp'] = timestamp - event = ibkr_events.AccountSummary(data=summary, account_id=account_id) + event = events.AccountSummary(data=summary, account_id=account_id) return event def _handle_subscribed_message(self, topic: str, data: dict) -> OneOrMany[WsEvent] | None: @@ -105,88 +119,87 @@ def _handle_subscribed_message(self, topic: str, data: dict) -> OneOrMany[WsEven # ValueError means we don't support this topic return None - if event_type == ibkr_events.AccountSummary: - return self._preprocess_account_summary(data) - elif event_type == ibkr_events.AccountLedger: - return self._preprocess_account_ledger(data) - elif event_type == ibkr_events.MarketData: - return self._preprocess_market_data_message(data) - elif event_type == ibkr_events.MarketHistory: - return self._preprocess_market_history_message(data) - elif event_type == ibkr_events.PriceLadder: - return ibkr_events.PriceLadder(data=data) - elif event_type == ibkr_events.Orders: - return ibkr_events.Orders(data=data) - elif event_type == ibkr_events.Pnl: - return ibkr_events.Pnl(data=data) - elif event_type == ibkr_events.Trades: - return ibkr_events.Trades(data=data) + if event_type == events.AccountSummary: + rv = self._preprocess_account_summary(data) + elif event_type == events.AccountLedger: + rv = self._preprocess_account_ledger(data) + elif event_type == events.MarketData: + rv = self._preprocess_market_data_message(data) + elif event_type == events.MarketHistory: + rv = self._preprocess_market_history_message(data) + elif event_type == events.PriceLadder: + rv = events.PriceLadder(data=data) + elif event_type == events.Orders: + rv = events.Orders(data=data) + elif event_type == events.Pnl: + rv = events.Pnl(data=data) + elif event_type == events.Trades: + rv = events.Trades(data=data) else: _LOGGER.error(f'{self}: Unhandled subscribed message: {data}') - return None + rv = None + return rv def _handle_account_update(self, message, arguments) -> OneOrMany[WsEvent]: - # _LOGGER.info(f'{self}: Account update: {arguments}') - return ibkr_events.AccountUpdate(data=arguments) + return events.AccountUpdate(data=arguments) def _handle_authentication_status(self, message, arguments) -> OneOrMany[WsEvent]: if 'authenticated' in arguments or 'competing' in arguments: - # _LOGGER.info(f'{self}: Authentication status: {arguments}') - return ibkr_events.AuthenticationStatus(data=arguments, authenticated=arguments.get('authenticated'), competing=arguments.get('competing')) + return events.AuthenticationStatus(data=arguments, authenticated=arguments.get('authenticated'), competing=arguments.get('competing')) elif ( # expected status updates that we ignore - arguments == {'message': ''} or - arguments.get('fail', '') == '' or - 'serverName' in arguments or - 'serverVersion' in arguments or - 'username' in arguments + arguments == {'message': ''} + or arguments.get('fail', '') == '' + or 'serverName' in arguments + or 'serverVersion' in arguments + or 'username' in arguments ): - # _LOGGER.info(f'{self}: Authentication silenced: {arguments}') pass return [] def _handle_bulletin(self, message) -> OneOrMany[WsEvent]: # pragma: no cover - return ibkr_events.Bulletin(message=message) + return events.Bulletin(message=message) def _handle_error(self, message) -> OneOrMany[WsEvent]: _LOGGER.error(f'{self}: on_message error: {message}') - return ibkr_events.IbkrError(message=message) + return events.IbkrError(message=message) def _handle_notification(self, data) -> OneOrMany[WsEvent]: # pragma: no cover - events = [] + rv = [] for notification in data: - # _LOGGER.info(f'{self}: IBKR notification: {notification}') - events.append(ibkr_events.Notification(message=notification)) - return events + rv.append(events.Notification(message=notification)) + return rv def _handle_market_history_unsubscribe(self, data) -> OneOrMany[WsEvent]: server_id = data['message'].split('Unsubscribed ')[-1] - mh_server_id_conid_pairs = self._server_id_conid_pairs[ibkr_events.MarketHistory] + mh_server_id_conid_pairs = self._server_id_conid_pairs[events.MarketHistory] if server_id in mh_server_id_conid_pairs: conid = mh_server_id_conid_pairs[server_id] _LOGGER.info(f'{self}: Received unsubscribing confirmation for server_id={server_id!r}, conid={conid!r}.') if conid is not None: - return ibkr_events.Unsubscription(target_event_type=ibkr_events.MarketHistory, conid=str(conid)) + return events.Unsubscription(target_event_type=events.MarketHistory, conid=str(conid)) _LOGGER.warning(f'{self}: Unknown conid={conid!r}. Cannot mark the subscription as unsubscribed.') else: - _LOGGER.warning(f'{self}: Received unsubscribing confirmation for unknown server_id={server_id!r}. Existing server_ids: {mh_server_id_conid_pairs}') + _LOGGER.warning( + f'{self}: Received unsubscribing confirmation for unknown server_id={server_id!r}. Existing server_ids: {mh_server_id_conid_pairs}' + ) return [] def _handle_message_without_topic(self, message: dict) -> OneOrMany[WsEvent]: if 'message' in message: if message['message'] == 'waiting for session': _LOGGER.info(f'{self}: Waiting for an active IBKR session.') - return ibkr_events.WaitingForSession() + return events.WaitingForSession() if 'Unsubscribed' in message['message']: return self._handle_market_history_unsubscribe(message) elif 'result' in message: if message['result'] == 'unsubscribed from summary': - return ibkr_events.Unsubscription(target_event_type=ibkr_events.AccountSummary) + return events.Unsubscription(target_event_type=events.AccountSummary) elif message['result'] == 'unsubscribed from ledger': - return ibkr_events.Unsubscription(target_event_type=ibkr_events.AccountLedger) + return events.Unsubscription(target_event_type=events.AccountLedger) _LOGGER.error(f'{self}: Unrecognised message without a topic: {message}') return GenericIbkrEvent(message=message) @@ -197,40 +210,41 @@ def route(self, raw_message: str) -> OneOrMany[WsEvent]: message, topic, arguments = parse_raw_message(raw_message) if 'error' in message: - return self._handle_error(message) + rv = self._handle_error(message) elif topic is None: # in general most message should carry a topic, other than for few exceptions - return self._handle_message_without_topic(message) + rv = self._handle_message_without_topic(message) elif topic == 'tic': # self._tic_message = message - return ibkr_events.System(data=message) + rv = events.System(data=message) elif topic == 'system': - return ibkr_events.System(data=message) + rv = events.System(data=message) elif topic == 'act': - return self._handle_account_update(message, arguments) + rv = self._handle_account_update(message, arguments) elif topic == 'blt': - return self._handle_bulletin(message) + rv = self._handle_bulletin(message) elif topic == 'ntf': - return self._handle_notification(arguments) + rv = self._handle_notification(arguments) elif topic == 'sts': - return self._handle_authentication_status(message, arguments) + rv = self._handle_authentication_status(message, arguments) elif topic == 'error': - return self._handle_error(message) + rv = self._handle_error(message) else: - events = self._handle_subscribed_message(topic, message) - if events is None: + rv = self._handle_subscribed_message(topic, message) + if rv is None: _LOGGER.error(f'{self}: topic "{topic}" subscribed but lacking a handler. Message: {message}') - events = GenericIbkrEvent(message=message, topic=topic, data=arguments) - return events + rv = GenericIbkrEvent(message=message, topic=topic, data=arguments) + + return rv def __str__(self): return f'{self.__class__.__qualname__}()' diff --git a/ibind/ibkr_ws_v2/ibkr_subscriptions.py b/ibind/ibkr_ws_v2/ibkr_subscriptions.py index c77534ef..05f60bb6 100644 --- a/ibind/ibkr_ws_v2/ibkr_subscriptions.py +++ b/ibind/ibkr_ws_v2/ibkr_subscriptions.py @@ -1,10 +1,10 @@ import json -from typing import Tuple +from typing import Tuple, List from pydantic import Field -from ibkr_ws_v2 import ibkr_events -from ibkr_ws_v2.ibkr_events import AccountLedger, MarketData, MarketHistory, Orders, PriceLadder, Pnl, Trades, Unsubscription, AccountSummary, IbkrTopicEvent +from ibind import events +from ibind.events import AccountLedger, MarketData, MarketHistory, Orders, PriceLadder, Pnl, Trades, Unsubscription, AccountSummary, IbkrTopicEvent from support.py_utils import filter_none from ws_v2.subscriptions import Subscription, SubscriptionResolver @@ -15,13 +15,13 @@ def make_binding_key( account_id=None, exchange=None ): - if event_type in [ibkr_events.MarketData, ibkr_events.MarketHistory]: + if event_type in [events.MarketData, events.MarketHistory]: return f"{event_type.topic}+{conid}" - elif event_type in [ibkr_events.AccountLedger, ibkr_events.AccountSummary]: + elif event_type in [events.AccountLedger, events.AccountSummary]: return f"{event_type.topic}+{account_id}" - elif event_type in [ibkr_events.PriceLadder]: + elif event_type in [events.PriceLadder]: return f"{event_type.topic}+{account_id}+{conid}" + (f"+{exchange}" if exchange is not None else '') - elif event_type in [ibkr_events.Orders, ibkr_events.Pnl, ibkr_events.Trades]: + elif event_type in [events.Orders, events.Pnl, events.Trades]: return event_type.topic else: raise ValueError(f'Unsupported event type: {event_type}') @@ -33,13 +33,13 @@ def __init__(self, account_id): def _resolve_subscribing_event(self, event) -> str: event_type = type(event) - if event_type in [ibkr_events.MarketData, ibkr_events.MarketHistory]: + if event_type in [events.MarketData, events.MarketHistory]: return make_binding_key(event_type, conid=event.conid) - elif event_type in [ibkr_events.AccountLedger, ibkr_events.AccountSummary]: + elif event_type in [events.AccountLedger, events.AccountSummary]: return make_binding_key(event_type, account_id=event.account_id) - elif event_type in [ibkr_events.PriceLadder]: + elif event_type in [events.PriceLadder]: return make_binding_key(event_type, conid=event.conid, account_id=event.account_id, exchange=event.exchange) - elif event_type in [ibkr_events.Orders, ibkr_events.Pnl, ibkr_events.Trades]: + elif event_type in [events.Orders, events.Pnl, events.Trades]: return make_binding_key(event_type) else: raise ValueError(f'Unsupported event: {event}') @@ -112,7 +112,7 @@ def binding_key(self): class MarketDataSubscription(IbkrSubscription): event_type: type[IbkrTopicEvent] = MarketData conid: str - fields: tuple[str, ...] + fields: List[str] def subscribe_payload(self) -> str: fields_str = json.dumps({"fields": list(self.fields)}, separators=(',', ':')) diff --git a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py index 6ccd8a00..4230f88c 100644 --- a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py +++ b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py @@ -1,11 +1,11 @@ import json from collections import defaultdict -from typing import Union, List, Dict +from typing import Union, List, Dict, Type +from ibind import events import var from ibind import IbkrClient -from ibkr_ws_v2 import ibkr_events -from ibkr_ws_v2.ibkr_events import IbkrTopicEvent +from ibind.events import IbkrTopicEvent from ibkr_ws_v2.ibkr_router import IbkrRouter from ibkr_ws_v2.ibkr_subscriptions import IbkrSubscriptionResolver, MarketHistorySubscription from support.logs import project_logger @@ -98,18 +98,18 @@ def __init__( ) self._mh_subscriptions: List[MarketHistorySubscription] = [] - self._conid_server_id_pairs: Dict[type[ibkr_events.IbkrTopicEvent], Dict[str, str]] = defaultdict(dict) + self._conid_server_id_pairs: Dict[type[events.IbkrTopicEvent], Dict[str, str]] = defaultdict(dict) def _register_internal_callbacks(self): - self._internal_sink.on(ibkr_events.AuthenticationStatus, self._on_authentication_status) - self._internal_sink.on(ibkr_events.WaitingForSession, self._set_unauthenticated) - self._internal_sink.on(ibkr_events.System, self._on_system) - self._internal_sink.on(ibkr_events.ServerId, self._on_server_id) + self._internal_sink.on(events.AuthenticationStatus, self._on_authentication_status) + self._internal_sink.on(events.WaitingForSession, self._set_unauthenticated) + self._internal_sink.on(events.System, self._on_system) + self._internal_sink.on(events.ServerId, self._on_server_id) def _set_unauthenticated(self, _): self._runtime.set_authenticated(False) - def _on_authentication_status(self, event: ibkr_events.AuthenticationStatus): + def _on_authentication_status(self, event: events.AuthenticationStatus): if event.authenticated is False: _LOGGER.error(f'{self}: Status unauthenticated: {event}') elif event.competing is True: @@ -117,11 +117,11 @@ def _on_authentication_status(self, event: ibkr_events.AuthenticationStatus): self._runtime.set_authenticated(event.authenticated) - def _on_system(self, event: ibkr_events.System): + def _on_system(self, event: events.System): if 'hb' in event.data: self._runtime.set_last_heartbeat(int(event.data['hb']) / 1000) - def _on_server_id(self, event: ibkr_events.ServerId): + def _on_server_id(self, event: events.ServerId): self._conid_server_id_pairs[event.target_event_type][event.conid] = event.server_id for subscription in self._mh_subscriptions: if subscription.event_type == event.target_event_type and subscription.conid == event.conid and not subscription.has_server_id(): @@ -172,7 +172,7 @@ def unsubscribe(self, subscription: Subscription) -> SubscriptionHandle: def get_status(self, binding_key: str) -> BindingStatus: return self._runtime.subscription_controller.get_status(binding_key) - def get_server_id(self, event_type: type[IbkrTopicEvent], conid: str) -> str: + def get_server_id(self, event_type: Type[IbkrTopicEvent], conid: str) -> str: return self._conid_server_id_pairs[event_type][conid] def _handle_mh_unsubscription(self, subscription: MarketHistorySubscription): diff --git a/ibind/ws_v2/events.py b/ibind/ws_v2/events.py index 8b0dcbd0..bb2825f0 100644 --- a/ibind/ws_v2/events.py +++ b/ibind/ws_v2/events.py @@ -1,5 +1,4 @@ import threading -from collections import defaultdict from datetime import datetime from queue import Queue, Full, Empty from threading import Thread, Event @@ -32,9 +31,6 @@ def __repr__(self): def _format(self): data = self.model_dump() - # remove key (already logged elsewhere) - data.pop("key", None) - # normalize values for k, v in data.items(): if isinstance(v, datetime): @@ -56,23 +52,23 @@ def _format(self): class LifecycleEvent(WsEvent): - ... + pass class WsOpen(LifecycleEvent): - ... + pass class WsAuthenticated(LifecycleEvent): - ... + pass class WsDegraded(LifecycleEvent): - ... + pass class WsReady(LifecycleEvent): - ... + pass class WsClose(LifecycleEvent): @@ -96,7 +92,7 @@ def emit(self, event: "WsEvent") -> None: class LogSink: def emit(self, event: WsEvent) -> None: - _LOGGER.debug(event) + _LOGGER.info(event) class NoopSink: @@ -108,14 +104,13 @@ def emit(self, event: WsEvent) -> None: class CallbackSink: - def __init__(self): - self._callbacks: Dict[type[WsEvent], List[Callable[[WsEvent], None]]] = defaultdict(list) + _callbacks: Dict[type[WsEvent], List[Callable[[WsEvent], None]]] = {} def on(self, event_type: type[WsEvent], callback: Callable[[T], None]) -> None: - self._callbacks[event_type].append(callback) + self._callbacks.setdefault(event_type, []).append(callback) def emit(self, event: WsEvent) -> None: - for callback in self._callbacks[type(event)]: + for callback in self._callbacks.get(type(event), []): try: callback(event) except Exception as e: @@ -126,8 +121,7 @@ def __str__(self): class QueueSink: - def __init__(self): - self._queues = {} + _queues = {} def new_queue_accessor(self, event_type: type[WsEvent]) -> QueueAccessor: return QueueAccessor(self._get_queue(event_type), event_type) @@ -252,14 +246,14 @@ def _consume_queue(self): _LOGGER.error(f'{self}: Exception emitting event to sink: {exception_to_string(e)}') def _cycle(self): - _LOGGER.info(f'{self}: AsyncSink thread started ({tname()})') + _LOGGER.debug(f'{self}: AsyncSink thread started ({tname()})') while self._running: self._wait_event.clear() self._wait_event.wait(self._cycle_interval) self._consume_queue() self._consume_queue() - _LOGGER.info(f'{self}: AsyncSink thread stopped ({tname()})') + _LOGGER.debug(f'{self}: AsyncSink thread stopped ({tname()})') def __str__(self): return f'{self.__class__.__qualname__}({self._queue.qsize()})' @@ -271,7 +265,7 @@ def __str__(self): class Router(Protocol): def route(self, raw_message) -> OneOrMany[WsEvent]: - ... + pass def __str__(self): return f'{self.__class__.__qualname__}()' diff --git a/ibind/ws_v2/subscriptions.py b/ibind/ws_v2/subscriptions.py index a0d0b60f..c1604300 100644 --- a/ibind/ws_v2/subscriptions.py +++ b/ibind/ws_v2/subscriptions.py @@ -8,7 +8,7 @@ from ibind.support.logs import project_logger from ibind.support.py_utils import exception_to_string -from ws_v2.events import WsEvent +from ibind.events import WsEvent _LOGGER = project_logger('ibkr_ws_client') diff --git a/ibind/ws_v2/ws_runtime.py b/ibind/ws_v2/ws_runtime.py index eb9c9fe2..d54615e8 100644 --- a/ibind/ws_v2/ws_runtime.py +++ b/ibind/ws_v2/ws_runtime.py @@ -9,8 +9,9 @@ from support.logs import project_logger from support.py_utils import wait_until, tname, VerboseEnum, exception_to_string, TimeoutLock, OneOrMany, noop -from ws_v2 import events -from ws_v2.events import WsEvent, EventSink, Router, CallbackSink, AsyncSink +from ibind import events +from ibind.events import WsEvent +from ws_v2.events import EventSink, Router, CallbackSink, AsyncSink from ws_v2.subscriptions import SubscriptionController, SubscriptionResolver from ws_v2.ws_transport import WsTransport, TransportEvent, TransportOpened, TransportClosed, TransportError, TransportMessage, TransportReconnect @@ -108,7 +109,7 @@ def _new_transport(self): ) def _set_state(self, value): - _LOGGER.info(f'{self}: {self._state.value} -> {value.value}') + _LOGGER.debug(f'{self}: {self._state.value} -> {value.value}') with self._state_lock: self._state = value @@ -124,8 +125,7 @@ def _websocket_ready(self): _LOGGER.info(f'{self}: Websocket ready, setting last_heartbeat to {self._last_heartbeat}') def set_authenticated(self, value: bool): - if value != self._authenticated: - _LOGGER.info(f'{self}: Authenticated: {value}') + previous_value = self._authenticated self._authenticated = value if value and self._state == WsState.OPEN: @@ -134,6 +134,8 @@ def set_authenticated(self, value: bool): if value == False and self._state == self._ready_state: self.state_degraded() + if value != previous_value: + _LOGGER.info(f'{self}: Connection {"authenticated" if value else "unauthenticated"}') def state_degraded(self): was_already_degraded = self._state == WsState.DEGRADED @@ -170,7 +172,6 @@ def _stop_transport_thread(self) -> bool: return not is_alive except Exception as e: _LOGGER.error(f'{self}: Failed to stop transport thread: {e}') - # TODO: decide what to do if transport disconnect fails return False @@ -182,7 +183,7 @@ def start(self): _LOGGER.error(f'{self}: Runtime thread must be stopped and joined before starting') return - _LOGGER.info(f'{self}: Starting runtime') + _LOGGER.info(f'{self}: Starting WebSocket runtime') self._set_state(WsState.STARTING) self._running = True @@ -202,16 +203,19 @@ def stop(self): if threading.current_thread() == self._runtime_thread: raise RuntimeError(f'{self}: Stopping runtime called from within runtime thread. Ensure it is called from a separate thread') - _LOGGER.info(f'{self}: Stopping runtime') + _LOGGER.info(f'{self}: Stopping WebSocket runtime') # wait until one more pass of the runtime thread has occurred to allow unsubscriptions to complete wait_until(lambda: not self._wait_event.is_set(), timeout=self._connection_timeout) self._wait_event.set() wait_until(lambda: not self._wait_event.is_set(), timeout=self._connection_timeout) - # TODO: decide which thread should stop first - transport or runtime self._set_state(WsState.STOPPING) - self._stop_transport_thread() + transport_thread_stopped = self._stop_transport_thread() + if not transport_thread_stopped: + _LOGGER.error(f'{self}: Failed to stop transport thread, abandoning...') + self._transport_thread = None + self._transport.set_degraded(True) self._running = False if self._runtime_thread is not None: @@ -232,7 +236,7 @@ def send(self, payload: str) -> bool: _LOGGER.error(f'{self}: State must be {self._ready_state.value} before sending payloads, found {self._state.value}') return False - _LOGGER.debug(f'{self}: Sending payload: {payload}') + _LOGGER.info(f'{self}: Sending payload: {payload}') return self._transport.send(payload) @@ -258,8 +262,8 @@ def restart_transport(self): if threading.current_thread() == self._transport_thread: raise RuntimeError(f'{self}: Resetting transport thread called from within transport thread. Ensure it is called from a separate thread') - success = self._stop_transport_thread() - if not success: + transport_thread_stopped = self._stop_transport_thread() + if not transport_thread_stopped: _LOGGER.error(f'{self}: Failed to stop transport thread, abandoning...') self._transport_thread = None @@ -278,7 +282,6 @@ def __str__(self): # ====================== def _transport_callback(self, te: TransportEvent): - # _LOGGER.debug(f'{self}: {te}') self._transport_queue.put(te) self._wait_event.set() @@ -348,7 +351,7 @@ def health_check(self) -> bool: return False def _cycle(self): - _LOGGER.info(f'{self}: Runtime thread started ({tname()})') + _LOGGER.debug(f'{self}: Runtime thread started ({tname()})') while self._running: self._maintain_transport() self._maintain_subscriptions() @@ -370,7 +373,7 @@ def _cycle(self): # final pass through the subscription controller to carry out final unsubscribe events self.subscription_controller.reconcile_bindings() - _LOGGER.info(f'{self}: Runtime thread stopped ({tname()})') + _LOGGER.debug(f'{self}: Runtime thread stopped ({tname()})') def _process_transport_queue(self): retry_events = [] @@ -424,9 +427,9 @@ def _handle_on_message(self, message): # pragma: no cover self._emit(event) def _handle_on_open(self): - _LOGGER.info(f'{self}: Connection open') self._last_heartbeat = None self._set_state(WsState.OPEN) + _LOGGER.info(f'{self}: Connection open') if self._state != self._ready_state: self.set_authenticated(False) self._emit(events.WsOpen()) @@ -447,9 +450,17 @@ def _handle_on_error(self, exception: Exception): self._emit(events.WsError(error=exception)) def _handle_on_close(self, close_status_code, close_msg): - _LOGGER.info(f'{self}: Connection closed') self._last_heartbeat = None + if self._state != WsState.STOPPING: + _LOGGER.info(f'{self}: Connection closed') + self.set_authenticated(False) + self.subscription_controller.invalidate_subscriptions() + else: + _LOGGER.info(f'{self}: Connection gracefully closed') + + self._set_state(WsState.CLOSED) + if close_status_code is not None or close_msg is not None: # this means an error try: msg = close_msg.decode('utf-8') @@ -458,13 +469,6 @@ def _handle_on_close(self, close_status_code, close_msg): _LOGGER.error(f'{self}: on_close error: {close_status_code} | {msg}') - if self._state != WsState.STOPPING: - self.set_authenticated(False) - self.subscription_controller.invalidate_subscriptions() - else: - _LOGGER.info(f'{self}: Gracefully closed') - - self._set_state(WsState.CLOSED) self._emit(events.WsClose(close_status_code=close_status_code, close_msg=close_msg)) def _emit(self, event: WsEvent): diff --git a/ibind/ws_v2/ws_transport.py b/ibind/ws_v2/ws_transport.py index e67cc918..4098d5c8 100644 --- a/ibind/ws_v2/ws_transport.py +++ b/ibind/ws_v2/ws_transport.py @@ -1,6 +1,6 @@ import time from datetime import datetime -from typing import Callable, Any, cast, List +from typing import Callable, Any, cast, List, Union, Dict from pydantic import BaseModel, ConfigDict, Field from websocket import WebSocketApp, STATUS_UNEXPECTED_CONDITION, STATUS_NORMAL @@ -14,6 +14,13 @@ class TransportEvent(BaseModel): + """ + Base class for WebSocket transport-level events. + + Tracks when events were received and how many processing attempts have been made. + Uses a list for attempt count to allow mutation despite frozen model. + """ + model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) received_at: datetime = Field(default_factory=datetime.now) attempt: List[int] = Field(default_factory=lambda: [0]) @@ -29,35 +36,51 @@ def __str__(self): class TransportOpened(TransportEvent): - ... + """Emitted when the WebSocket connection is successfully opened.""" + + pass class TransportClosed(TransportEvent): + """Emitted when the WebSocket connection is closed.""" + close_status_code: int | None close_msg: str | None class TransportError(TransportEvent): + """Emitted when a WebSocket error occurs. Note that currently WebSocketApp only emits this once, due to reconnection logic then skipping this event.""" + exception: Exception class TransportMessage(TransportEvent): + """Emitted when a message is received from the WebSocket.""" + message: str class TransportReconnect(TransportEvent): - ... + """Emitted when the WebSocket reconnects after a disconnection.""" + + pass + +class WsTransport: + """ + Manages low-level WebSocket transport using WebSocketApp. -class WsTransport(): + Handles connection lifecycle, message sending, cookie validation, and ping monitoring. + Runs in a dedicated thread and communicates via event callbacks. + """ def __init__( self, url: str, - event_callback: Callable, - sslopt: dict[str, Any], - get_cookie: Callable = noop, - get_header: Callable = noop, + event_callback: Callable[[TransportEvent], None], + sslopt: Dict[str, Any], + get_cookie: Callable[[], str | None] = noop, + get_header: Callable[[], Dict[str, Any] | None] = noop, ping_interval: float = 10, ping_timeout: float = 10, max_ping_interval: float = 20, @@ -65,6 +88,22 @@ def __init__( reconnect_timeout: float = 5, skip_utf8_validation: bool = var.IBIND_WS_SKIP_UTF8_VALIDATION, ): + """ + Create a WebSocket transport instance. + + Args: + url (str): WebSocket URL to connect to. + event_callback (Callable): Callback function invoked with TransportEvent instances. + sslopt (dict[str, Any]): SSL options for the WebSocket connection. + get_cookie (Callable, optional): Function to retrieve session cookie. Default: noop. + get_header (Callable, optional): Function to retrieve HTTP headers. Default: noop. + ping_interval (float, optional): Interval in seconds between ping messages. Default: 10. + ping_timeout (float, optional): Timeout in seconds for ping responses. Default: 10. + max_ping_interval (float, optional): Maximum acceptable time since last pong. Default: 20. + connection_timeout (float, optional): Timeout in seconds for connection operations. Default: 5. + reconnect_timeout (float, optional): Timeout in seconds before reconnect attempts. Default: 5. + skip_utf8_validation (bool, optional): Whether to skip UTF-8 validation. Default: True + """ self._url = url self._event_callback = event_callback self._sslopt = sslopt @@ -82,23 +121,36 @@ def __init__( self._degraded = False self._tname = None + self._session_lacks_authentication = False + def disconnect(self): + """Gracefully disconnect the WebSocket connection.""" if self._wsa is None: - _LOGGER.info(f'{self}: WSA is None, skipping disconnect') + _LOGGER.info(f'{self}: WebSocketApp is None, skipping disconnect') return self._wsa.close(status=STATUS_NORMAL, timeout=self._connection_timeout) def stop(self): - _LOGGER.info(f'{self}: Stopping') + """Stop the transport thread and disconnect the WebSocket.""" + _LOGGER.debug(f'{self}: Stopping transport') self._running = False self.disconnect() def reset_websocket_app(self) -> bool: + """ + Force close and recreate the WebSocketApp connection. + + Returns: + bool: True if a new WebSocketApp was successfully created, False otherwise. + + Raises: + RuntimeError: If called from within the transport thread. + """ if tname() == self._tname: raise RuntimeError(f'{self}: Resetting websocket app called from within transport thread. Ensure it is called from a separate thread') if self._wsa is None: - _LOGGER.info(f'{self}: WSA is None, skipping reset') + _LOGGER.info(f'{self}: WebSocketApp is None, skipping reset') return False _LOGGER.info(f'{self}: Reset') @@ -115,17 +167,15 @@ def reset_websocket_app(self) -> bool: def check_ping(self, max_interval: float = None) -> bool: """ - Checks the last ping response time of the WebSocketApp connection. + Check if the last pong was received within the acceptable interval. - Verifies whether the last ping response from the WebSocketApp was within the acceptable time interval - defined by 'max_ping_interval' parameter. If the last ping response exceeds this interval, a hard reset of the connection is triggered. + Args: + max_interval (float, optional): Maximum acceptable seconds since last pong. + Default: self._max_ping_interval. Returns: - bool: True if the last ping was within the acceptable interval or if the WebSocketApp is not connected, - False if the ping interval was exceeded and a hard reset was initiated. - - Note: - - A ping interval exceeding 'max_ping_interval' indicates potential issues with the WebsocketApp connection. + bool: True if last pong was within the interval or WebSocketApp is not connected, + False if the interval was exceeded. """ if self._wsa is None: return True @@ -139,26 +189,41 @@ def check_ping(self, max_interval: float = None) -> bool: return self.get_time_since_last_ping() <= max_interval def get_time_since_last_ping(self) -> float: + """Get seconds elapsed since the last pong was received.""" return abs(time.time() - self._wsa.last_pong_tm) - def fetch_cookie(self): + def fetch_cookie(self) -> Union[str, None, UNDEFINED]: """ - Using UNDEFINED since _get_cookie could in fact return a None, and they mean different things + Retrieve session cookie using the configured callback. + + Returns: + str | None | UNDEFINED: Cookie value, None if no cookie needed, or UNDEFINED if retrieval failed. """ try: - return self._get_cookie() + cookie = self._get_cookie() + if self._session_lacks_authentication: + self._session_lacks_authentication = False + return cookie except Exception as e: if isinstance(e, TimeoutError): _LOGGER.info(f'{self}: Timeout retrieving cookie') return UNDEFINED if isinstance(e, ExternalBrokerError): if e.status_code == 401: - _LOGGER.info(f'{self}: Failed to retrieve cookie due to lack of authentication') + if not self._session_lacks_authentication: + self._session_lacks_authentication = True + _LOGGER.info(f'{self}: Failed to retrieve cookie due to lack of authentication. Continuing reattempts silently until authentication is reestablished.') return UNDEFINED _LOGGER.error(f'{self}: Failed to retrieve cookie: {exception_to_string(e)}') return UNDEFINED def check_cookie(self) -> bool: + """ + Verify the current cookie matches the stored cookie. + + Returns: + bool: True if cookies match, False if retrieval failed or cookies differ. + """ cookie = self.fetch_cookie() if cookie is UNDEFINED: return False @@ -169,14 +234,28 @@ def check_cookie(self) -> bool: return True def set_degraded(self, value): + """Mark the transport as degraded to suppress event callbacks.""" self._degraded = value def is_ready(self) -> bool: + """Check if the WebSocketApp is ready to send messages.""" return self._wsa is not None and self._wsa.ready and self._wsa.sock is not None and self._wsa.sock.sock is not None def send(self, payload: str) -> bool: + """ + Send a message through the WebSocket. + + Args: + payload (str): Message to send. + + Returns: + bool: True if sent successfully, False otherwise. + + Raises: + RuntimeError: If the WebSocketApp is not ready. + """ if not self.is_ready(): - raise RuntimeError(f'{self}: WSA socket is not ready') + raise RuntimeError(f'{self}: WebSocketApp socket is not ready') try: self._wsa.send(payload) @@ -243,14 +322,13 @@ def _on_reconnect(self, wsa: WebSocketApp): self._event_callback(TransportReconnect()) - def new_wsa(self): + def _new_wsa(self): + """Create a new WebSocketApp instance with current cookie and header.""" cookie = self.fetch_cookie() if cookie is UNDEFINED: return None self._cookie = cookie - if cookie is not None: - _LOGGER.info(f'{self}: Current cookie: {cookie}') try: self._header = self._get_header() @@ -259,7 +337,7 @@ def new_wsa(self): return None if not self._running: - # Transport got stopped between invocation of new_wsa and creating one + # Transport got stopped between invocation of this function and creating a WebSocketApp return None wsa = WebSocketApp( @@ -272,11 +350,13 @@ def new_wsa(self): cookie=self._cookie, header=self._header, ) + _LOGGER.debug(f'{self}: Created new WebSocketApp instance{f", cookie: {cookie}" if cookie is not None else ""}') return wsa def connect(self): - _LOGGER.info(f'{self}: Transport thread started ({tname()})') + """Main transport thread loop that maintains the WebSocket connection.""" + _LOGGER.debug(f'{self}: Transport thread started ({tname()})') self._tname = tname() @@ -284,7 +364,7 @@ def connect(self): while self._running: if self._wsa is None: - wsa = self.new_wsa() + wsa = self._new_wsa() if wsa is None: time.sleep(1) continue @@ -296,9 +376,9 @@ def connect(self): ping_timeout=self._ping_interval * 0.95, # the timeout is set to a little sooner than the interval sslopt=self._sslopt, reconnect=cast(int, self._reconnect_timeout), # floats are de facto valid, casting only for the linter - skip_utf8_validation=self._skip_utf8_validation + skip_utf8_validation=self._skip_utf8_validation, ) - _LOGGER.info(f'{self}: WSA run_forever stopped gracefully') + _LOGGER.debug(f'{self}: WebSocketApp stopped gracefully') except Exception as e: if 'url is invalid' in str(e): _LOGGER.error(f'{self}: URL is invalid: {self._url}') @@ -307,4 +387,4 @@ def connect(self): finally: self._wsa = None - _LOGGER.info(f'{self}: Transport thread stopped ({tname()})') + _LOGGER.debug(f'{self}: Transport thread stopped ({tname()})') From bffdbe55ca31102b118be89972c65844579d93f6 Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 10 May 2026 13:52:39 +0200 Subject: [PATCH 24/32] fix(ws_transport): fixed return type of fetch_cookie for Python <=3.10 --- ibind/ws_v2/ws_transport.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibind/ws_v2/ws_transport.py b/ibind/ws_v2/ws_transport.py index 4098d5c8..704ced61 100644 --- a/ibind/ws_v2/ws_transport.py +++ b/ibind/ws_v2/ws_transport.py @@ -192,7 +192,7 @@ def get_time_since_last_ping(self) -> float: """Get seconds elapsed since the last pong was received.""" return abs(time.time() - self._wsa.last_pong_tm) - def fetch_cookie(self) -> Union[str, None, UNDEFINED]: + def fetch_cookie(self) -> Union[str, None]: """ Retrieve session cookie using the configured callback. From 4f88a354fbc8ee90dd61dcec284de2b48b50f405 Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 10 May 2026 15:06:28 +0200 Subject: [PATCH 25/32] refactor(ws_v2): ws_v2.events was renamed to ws_v2._ws_events to indicate privacy, expecting imports to point at ibind.events. Also added docstrings --- ibind/__init__.py | 4 +- ibind/events/__init__.py | 2 +- ibind/ibkr_ws_v2/ibkr_ws_client_v2.py | 2 +- ibind/ws_v2/{events.py => _ws_events.py} | 207 ++++++++++++++++++++--- ibind/ws_v2/ws_runtime.py | 2 +- 5 files changed, 190 insertions(+), 27 deletions(-) rename ibind/ws_v2/{events.py => _ws_events.py} (51%) diff --git a/ibind/__init__.py b/ibind/__init__.py index de7b0797..65b8f627 100644 --- a/ibind/__init__.py +++ b/ibind/__init__.py @@ -12,8 +12,8 @@ from ibind.support.py_utils import execute_in_parallel from ibind import events, subscriptions from ibind.ibkr_ws_v2.ibkr_ws_client_v2 import IbkrWsClientV2 -from ibind.ws_v2.events import LogSink, QueueSink, CallbackSink, CompositeSink from ibind.ws_v2.subscriptions import SubscriptionHandle +from ibind.ws_v2._ws_events import LogSink, QueueSink, CallbackSink, CompositeSink, NoopSink, EventSink __all__ = [ @@ -37,6 +37,8 @@ 'events', 'subscriptions', 'IbkrWsClientV2', + 'EventSink', + 'NoopSink', 'LogSink', 'QueueSink', 'CallbackSink', diff --git a/ibind/events/__init__.py b/ibind/events/__init__.py index b62de9f1..b86fcae3 100644 --- a/ibind/events/__init__.py +++ b/ibind/events/__init__.py @@ -1,4 +1,4 @@ -from ibind.ws_v2.events import LifecycleEvent, WsOpen, WsAuthenticated, WsDegraded, WsReady, WsClose, WsError, WsEvent +from ibind.ws_v2._ws_events import LifecycleEvent, WsOpen, WsAuthenticated, WsDegraded, WsReady, WsClose, WsError, WsEvent from ibind.ibkr_ws_v2.ibkr_events import GenericIbkrEvent, IbkrError, WaitingForSession, Notification, Bulletin, AccountUpdate, System, AuthenticationStatus, IbkrTopicEvent, AccountSummary, AccountLedger, MarketData, MarketHistory, Orders, PriceLadder, Pnl, Trades, ServerId, Unsubscription diff --git a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py index 4230f88c..b1d75643 100644 --- a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py +++ b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py @@ -10,8 +10,8 @@ from ibkr_ws_v2.ibkr_subscriptions import IbkrSubscriptionResolver, MarketHistorySubscription from support.logs import project_logger from support.py_utils import OneOrMany, ensure_list_arg -from ws_v2.events import EventSink, CallbackSink, Router, AsyncSink, NoopSink from ws_v2.subscriptions import Subscription, SubscriptionResolver, SubscriptionHandle, BindingStatus +from ws_v2._ws_events import EventSink, CallbackSink, Router, AsyncSink, NoopSink from ws_v2.ws_runtime import WsRuntime, WsState _LOGGER = project_logger('ibkr_ws_client') diff --git a/ibind/ws_v2/events.py b/ibind/ws_v2/_ws_events.py similarity index 51% rename from ibind/ws_v2/events.py rename to ibind/ws_v2/_ws_events.py index bb2825f0..9890c93b 100644 --- a/ibind/ws_v2/events.py +++ b/ibind/ws_v2/_ws_events.py @@ -10,6 +10,8 @@ from support.logs import project_logger from support.py_utils import OneOrMany, exception_to_string, tname +__all__ = [] + _LOGGER = project_logger('ibkr_ws_client') @@ -17,8 +19,15 @@ # == Events Classes == # ====================== -class WsEvent(BaseModel): - model_config = ConfigDict(frozen=True, extra="forbid") + +class WsEvent(BaseModel): # pragma: no cover + """ + Base class for all WebSocket events. + + Immutable event model that tracks when it was received. + """ + + model_config = ConfigDict(frozen=True, extra='forbid') received_at: datetime = Field(default_factory=datetime.now) @@ -39,45 +48,56 @@ def _format(self): data[k] = str(v) # move received_at to the end - items = [(k, v) for k, v in data.items() if k != "received_at"] - if "received_at" in data: - items.append(("received_at", data["received_at"])) + items = [(k, v) for k, v in data.items() if k != 'received_at'] + if 'received_at' in data: + items.append(('received_at', data['received_at'])) - fields = ", ".join( - f"{k}={v}" if isinstance(v, str) and "T" in v else f"{k}={repr(v)}" - for k, v in items - ) + fields = ', '.join(f'{k}={v}' if isinstance(v, str) and 'T' in v else f'{k}={repr(v)}' for k, v in items) - return f"{self.__class__.__name__}({fields})" + return f'{self.__class__.__name__}({fields})' class LifecycleEvent(WsEvent): + """Base class for WebSocket connection lifecycle events.""" + pass class WsOpen(LifecycleEvent): + """Emitted when the WebSocket connection is successfully opened.""" + pass class WsAuthenticated(LifecycleEvent): + """Emitted when the WebSocket connection is authenticated.""" + pass class WsDegraded(LifecycleEvent): + """Emitted when the WebSocket connection enters a degraded state.""" + pass class WsReady(LifecycleEvent): + """Emitted when the WebSocket connection is ready for use.""" + pass class WsClose(LifecycleEvent): + """Emitted when the WebSocket connection is closed.""" + close_status_code: int | None close_msg: str | None class WsError(LifecycleEvent): - model_config = ConfigDict(frozen=True, extra="forbid", arbitrary_types_allowed=True) + """Emitted when a WebSocket error occurs.""" + + model_config = ConfigDict(frozen=True, extra='forbid', arbitrary_types_allowed=True) error: Exception @@ -85,45 +105,88 @@ class WsError(LifecycleEvent): # == Sinks == # ============= -class EventSink(Protocol): - def emit(self, event: "WsEvent") -> None: + +class EventSink(Protocol): # pragma: no cover + """Protocol for objects that can receive and process WebSocket events.""" + + def emit(self, event: 'WsEvent') -> None: pass -class LogSink: +class LogSink: # pragma: no cover + """Sink that logs events using the project logger.""" + def emit(self, event: WsEvent) -> None: _LOGGER.info(event) -class NoopSink: +class NoopSink: # pragma: no cover + """Sink that discards all events without processing.""" + def emit(self, event: WsEvent) -> None: pass -T = TypeVar("T", bound=WsEvent) +T = TypeVar('T', bound=WsEvent) class CallbackSink: + """ + Sink that invokes registered callbacks for specific event types. + + Callbacks are registered per event type and invoked when matching events are emitted. + Exceptions from callbacks are logged but do not propagate. + """ + _callbacks: Dict[type[WsEvent], List[Callable[[WsEvent], None]]] = {} def on(self, event_type: type[WsEvent], callback: Callable[[T], None]) -> None: + """ + Register a callback for a specific event type. + + Args: + event_type (type[WsEvent]): The event type to listen for. + callback (Callable): Function to invoke when events of this type are emitted. + """ self._callbacks.setdefault(event_type, []).append(callback) def emit(self, event: WsEvent) -> None: + """ + Emit an event to all registered callbacks for its type. + + Args: + event (WsEvent): The event to emit. + """ for callback in self._callbacks.get(type(event), []): try: callback(event) except Exception as e: - _LOGGER.error(f'{self}: Exception emitting event to callback: {exception_to_string(e)}') + _LOGGER.error(f'{self}: Exception emitting event to callback {callback.__name__}: {exception_to_string(e)}') - def __str__(self): + def __str__(self): # pragma: no cover return f'{self.__class__.__qualname__}()' class QueueSink: + """ + Sink that stores events in separate queues per event type. + + Maintains a dictionary of queues, one for each event type. Events can be + retrieved synchronously or asynchronously via queue accessors. + """ + _queues = {} def new_queue_accessor(self, event_type: type[WsEvent]) -> QueueAccessor: + """ + Create a queue accessor for a specific event type. + + Args: + event_type (type[WsEvent]): The event type to access. + + Returns: + QueueAccessor: Accessor for the queue associated with this event type. + """ return QueueAccessor(self._get_queue(event_type), event_type) def _get_queue(self, event_type: type[WsEvent]) -> Queue: # pragma: no cover @@ -133,36 +196,91 @@ def _get_queue(self, event_type: type[WsEvent]) -> Queue: # pragma: no cover self._queues[event_type] = Queue() return self._queues[event_type] - def get(self, event_type: type[WsEvent], block: bool = False, timeout=None) -> Any: + def get(self, event_type: type[WsEvent], block: bool = False, timeout: float = None) -> Any: + """ + Retrieve an event from the queue for a specific event type. + + Args: + event_type (type[WsEvent]): The event type to retrieve. + block (bool, optional): Whether to block until an event is available. Default: False. + timeout (float, optional): Maximum time to block in seconds. Default: None. + + Returns: + WsEvent | None: The retrieved event, or None if the queue is empty and block=False. + """ try: return self._get_queue(event_type).get(block=block, timeout=timeout) except Empty: return None def empty(self, event_type: type[WsEvent]) -> bool: + """ + Check if the queue for a specific event type is empty. + + Args: + event_type (type[WsEvent]): The event type to check. + + Returns: + bool: True if the queue is empty, False otherwise. + """ return self._get_queue(event_type).empty() def emit(self, event: WsEvent) -> None: + """ + Emit an event by adding it to the queue for its type. + + Args: + event (WsEvent): The event to emit. + """ queue = self._get_queue(type(event)) queue.put(event) + def __str__(self): # pragma: no cover + return f'{self.__class__.__qualname__}()' + class CompositeSink: + """ + Sink that forwards events to multiple child sinks. + + Exceptions from individual sinks are logged but do not prevent other sinks + from receiving the event. + """ + def __init__(self, *sinks: EventSink): + """ + Create a composite sink. + + Args: + *sinks (EventSink): One or more sinks to forward events to. + """ self._sinks = sinks def emit(self, event: WsEvent) -> None: + """ + Emit an event to all registered sinks. + + Args: + event (WsEvent): The event to emit. + """ for sink in self._sinks: try: sink.emit(event) except Exception as e: _LOGGER.error(f'{self}: Exception emitting event to sink: {exception_to_string(e)}') - def __str__(self): + def __str__(self): # pragma: no cover return f'{self.__class__.__qualname__}()' class AsyncSink: + """ + Sink that forwards events to another sink asynchronously via a background thread. + + Events are queued and processed in a separate thread. When the queue is full, + events are dropped according to the drop_oldest policy. + """ + def __init__( self, sink: EventSink, @@ -171,6 +289,17 @@ def __init__( stop_timeout: float = 5, cycle_interval: float = 0.25, ): + """ + Create an asynchronous sink. + + Args: + sink (EventSink): The sink to forward events to. + maxsize (int, optional): Maximum queue size. Default: 10,000. + drop_oldest (bool, optional): Whether to drop oldest events when full. + If False, drops newest events. Default: True. + stop_timeout (float, optional): Maximum time to wait for thread to stop in seconds. Default: 5. + cycle_interval (float, optional): Interval between queue processing cycles in seconds. Default: 0.25. + """ self._sink = sink self._queue = Queue(maxsize=maxsize) self._drop_oldest = drop_oldest @@ -182,14 +311,24 @@ def __init__( self._wait_event = Event() def start(self): + """Start the background thread for processing events.""" if self._running: return self._running = True - self._thread = Thread(target=self._cycle, name="async_sink_thread", daemon=True) + self._thread = Thread(target=self._cycle, name='async_sink_thread', daemon=True) self._thread.start() def stop(self) -> bool: + """ + Stop the background thread and discard remaining events. + + Returns: + bool: True if the thread stopped successfully, False if it timed out. + + Raises: + RuntimeError: If called from within the async sink thread. + """ if not self._running: return True @@ -212,6 +351,12 @@ def stop(self) -> bool: return succeeded def emit(self, event: WsEvent) -> None: + """ + Queue an event for asynchronous processing. + + Args: + event (WsEvent): The event to emit. + """ try: self._queue.put_nowait(event) self._wait_event.set() @@ -255,7 +400,7 @@ def _cycle(self): self._consume_queue() _LOGGER.debug(f'{self}: AsyncSink thread stopped ({tname()})') - def __str__(self): + def __str__(self): # pragma: no cover return f'{self.__class__.__qualname__}({self._queue.qsize()})' @@ -263,8 +408,24 @@ def __str__(self): # == Router == # ============== -class Router(Protocol): + +class Router(Protocol): # pragma: no cover + """ + Protocol for routing raw WebSocket messages to typed events. + + Implementations parse raw messages and convert them to one or more WsEvent instances. + """ + def route(self, raw_message) -> OneOrMany[WsEvent]: + """ + Route a raw message to one or more events. + + Args: + raw_message: The raw message to route. + + Returns: + OneOrMany[WsEvent]: One or more events, or None to skip the message. + """ pass def __str__(self): diff --git a/ibind/ws_v2/ws_runtime.py b/ibind/ws_v2/ws_runtime.py index d54615e8..c17ae75b 100644 --- a/ibind/ws_v2/ws_runtime.py +++ b/ibind/ws_v2/ws_runtime.py @@ -11,8 +11,8 @@ from support.py_utils import wait_until, tname, VerboseEnum, exception_to_string, TimeoutLock, OneOrMany, noop from ibind import events from ibind.events import WsEvent -from ws_v2.events import EventSink, Router, CallbackSink, AsyncSink from ws_v2.subscriptions import SubscriptionController, SubscriptionResolver +from ws_v2._ws_events import EventSink, Router, CallbackSink, AsyncSink from ws_v2.ws_transport import WsTransport, TransportEvent, TransportOpened, TransportClosed, TransportError, TransportMessage, TransportReconnect _LOGGER = project_logger('ibkr_ws_client') From d864d689b3e656819fd659166829a7e77639b00d Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 10 May 2026 15:08:32 +0200 Subject: [PATCH 26/32] refactor(ws_v2): ws_v2.subscriptions was renamed to ws_v2.ws_subscriptions for consistency --- ibind/__init__.py | 2 +- ibind/ibkr_ws_v2/ibkr_subscriptions.py | 2 +- ibind/ibkr_ws_v2/ibkr_ws_client_v2.py | 2 +- ibind/subscriptions/__init__.py | 2 +- ibind/ws_v2/ws_runtime.py | 2 +- ibind/ws_v2/{subscriptions.py => ws_subscriptions.py} | 0 6 files changed, 5 insertions(+), 5 deletions(-) rename ibind/ws_v2/{subscriptions.py => ws_subscriptions.py} (100%) diff --git a/ibind/__init__.py b/ibind/__init__.py index 65b8f627..7b9f4c68 100644 --- a/ibind/__init__.py +++ b/ibind/__init__.py @@ -12,8 +12,8 @@ from ibind.support.py_utils import execute_in_parallel from ibind import events, subscriptions from ibind.ibkr_ws_v2.ibkr_ws_client_v2 import IbkrWsClientV2 -from ibind.ws_v2.subscriptions import SubscriptionHandle from ibind.ws_v2._ws_events import LogSink, QueueSink, CallbackSink, CompositeSink, NoopSink, EventSink +from ibind.ws_v2.ws_subscriptions import SubscriptionHandle __all__ = [ diff --git a/ibind/ibkr_ws_v2/ibkr_subscriptions.py b/ibind/ibkr_ws_v2/ibkr_subscriptions.py index 05f60bb6..72f16ece 100644 --- a/ibind/ibkr_ws_v2/ibkr_subscriptions.py +++ b/ibind/ibkr_ws_v2/ibkr_subscriptions.py @@ -6,7 +6,7 @@ from ibind import events from ibind.events import AccountLedger, MarketData, MarketHistory, Orders, PriceLadder, Pnl, Trades, Unsubscription, AccountSummary, IbkrTopicEvent from support.py_utils import filter_none -from ws_v2.subscriptions import Subscription, SubscriptionResolver +from ws_v2.ws_subscriptions import Subscription, SubscriptionResolver def make_binding_key( diff --git a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py index b1d75643..b05f9c7d 100644 --- a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py +++ b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py @@ -10,8 +10,8 @@ from ibkr_ws_v2.ibkr_subscriptions import IbkrSubscriptionResolver, MarketHistorySubscription from support.logs import project_logger from support.py_utils import OneOrMany, ensure_list_arg -from ws_v2.subscriptions import Subscription, SubscriptionResolver, SubscriptionHandle, BindingStatus from ws_v2._ws_events import EventSink, CallbackSink, Router, AsyncSink, NoopSink +from ws_v2.ws_subscriptions import Subscription, SubscriptionResolver, SubscriptionHandle, BindingStatus from ws_v2.ws_runtime import WsRuntime, WsState _LOGGER = project_logger('ibkr_ws_client') diff --git a/ibind/subscriptions/__init__.py b/ibind/subscriptions/__init__.py index 74f7a744..4160e9a3 100644 --- a/ibind/subscriptions/__init__.py +++ b/ibind/subscriptions/__init__.py @@ -1,6 +1,6 @@ from ibkr_ws_v2.ibkr_subscriptions import MarketDataSubscription, OrdersSubscription, AccountLedgerSubscription, AccountSummarySubscription, PnlSubscription, TradesSubscription, MarketHistorySubscription -from ws_v2.subscriptions import SubscriptionHandle, BindingStatus, Subscription, SubscriptionResolver +from ws_v2.ws_subscriptions import SubscriptionHandle, BindingStatus, Subscription, SubscriptionResolver __all__ = [ 'Subscription', diff --git a/ibind/ws_v2/ws_runtime.py b/ibind/ws_v2/ws_runtime.py index c17ae75b..86c7468d 100644 --- a/ibind/ws_v2/ws_runtime.py +++ b/ibind/ws_v2/ws_runtime.py @@ -11,8 +11,8 @@ from support.py_utils import wait_until, tname, VerboseEnum, exception_to_string, TimeoutLock, OneOrMany, noop from ibind import events from ibind.events import WsEvent -from ws_v2.subscriptions import SubscriptionController, SubscriptionResolver from ws_v2._ws_events import EventSink, Router, CallbackSink, AsyncSink +from ws_v2.ws_subscriptions import SubscriptionController, SubscriptionResolver from ws_v2.ws_transport import WsTransport, TransportEvent, TransportOpened, TransportClosed, TransportError, TransportMessage, TransportReconnect _LOGGER = project_logger('ibkr_ws_client') diff --git a/ibind/ws_v2/subscriptions.py b/ibind/ws_v2/ws_subscriptions.py similarity index 100% rename from ibind/ws_v2/subscriptions.py rename to ibind/ws_v2/ws_subscriptions.py From 295f2530f20bb3dc6a296679c6ff780e23d87805 Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 10 May 2026 15:08:58 +0200 Subject: [PATCH 27/32] docs: add TESTING.md --- TESTING.md | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 TESTING.md diff --git a/TESTING.md b/TESTING.md new file mode 100644 index 00000000..6b503fe1 --- /dev/null +++ b/TESTING.md @@ -0,0 +1,88 @@ +# Testing Guide + +This document defines how tests are chosen, written, and evaluated. + +--- + +## Testing philosophy + +- Use commentary with `## Arrange` `## Act` and `## Assert` sections for test structure. +- Tests exist to lock behaviour, not to chase coverage. +- Avoid tests that duplicate what the language/runtime already guarantees. +- Use fixtures for setup/teardown of test state. +- Mock internal dependencies for unit tests. Mock external dependencies for integration tests. +- Prefer integration tests for verifying component boundaries and data flow. +- Capture and assert on logs for error and warning conditions. + + +## Test Type Structure + +Tests are organized into three main categories: + +``` +test/ +├── unit/ # Fast, isolated tests for core logic +├── integration/ # Multi-component tests +└── manual/ # Manual and performance tests +``` + + +## Test types and boundaries + +Use the lightest test type that still provides confidence. + +### Unit tests +Use when: +- logic is isolated and deterministic +- behaviour can be validated without wiring other components + +Guidelines: +- No network, filesystem, threads, or time dependence. +- Mock only at clear boundaries; do not mock internals of the unit under test. +- Data should be small, synthetic, and explicit. +- Prefer clarity over clever parametrisation. + +### Integration tests +Use when: +- correctness depends on interaction between components +- data flow or ordering matters + +Guidelines: +- Mock only outside the test boundary (eg. broker, network). +- Use realistic but minimal fixtures. +- Allow threads/timers only if they are part of the behaviour being tested. +- Failures should clearly indicate which interaction broke. + +### Manual / performance tests +Use when: +- validating full-system flows +- measuring throughput, latency, or concurrency +- interacting with real or near-real external systems + +Guidelines: +- Never run automatically in CI. +- Keep secrets out of test code. +- Prefer recorded or replayable inputs where possible. +- Treat results as diagnostic, not pass/fail gates. + + +## Choosing what to test + +Test: +- decision logic +- state transitions +- boundary conditions +- error and warning paths +- behaviour that has broken before + +Do not test: +- trivial getters/setters +- pure delegation +- obvious library behaviour +- formatting or logging text unless it signals correctness + + +## Running tests + +- Prefer running the smallest relevant subset while iterating. +- Run broader suites when touching core or high-risk code. From e6120760e46552c57260f9159436978c288b768a Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 10 May 2026 15:09:18 +0200 Subject: [PATCH 28/32] test: add test_ws_events.py --- test/unit/ws_v2/test_ws_events_u.py | 497 ++++++++++++++++++++++++++++ 1 file changed, 497 insertions(+) create mode 100644 test/unit/ws_v2/test_ws_events_u.py diff --git a/test/unit/ws_v2/test_ws_events_u.py b/test/unit/ws_v2/test_ws_events_u.py new file mode 100644 index 00000000..30244c42 --- /dev/null +++ b/test/unit/ws_v2/test_ws_events_u.py @@ -0,0 +1,497 @@ +import threading +import time +from datetime import datetime +from queue import Empty, Full +from unittest.mock import MagicMock, patch + +import pytest + +from ibind.events import ( + WsOpen, + WsAuthenticated, + WsDegraded, + WsReady, + WsClose, + WsError, +) +from ibind import ( + NoopSink, + CallbackSink, + QueueSink, + CompositeSink, + EventSink, +) +from test.test_utils import capture_logs +from ws_v2._ws_events import AsyncSink + + +@pytest.fixture +def noop_sink(): + return NoopSink() + + +@pytest.fixture +def callback_sink(): + return CallbackSink() + + +@pytest.fixture +def queue_sink(): + sink = QueueSink() + sink._queues.clear() + return sink + + +@pytest.fixture +def sample_event(): + return WsOpen() + + +class TestWsEvent: + @capture_logs() + def test_immutability(self): + """WsEvent instances are immutable after creation.""" + ## Arrange + event = WsOpen() + + ## Act / Assert + with pytest.raises(Exception): + event.received_at = datetime.now() # NOQA + + @capture_logs() + def test_extra_fields_forbidden(self): + """WsEvent rejects extra fields not in the model.""" + ## Arrange / Act / Assert + with pytest.raises(Exception): + WsOpen(extra_field='value') # NOQA + + +@capture_logs() +def test_lifecycle_events(): + """Lifecycle events can be created with default received_at and optional fields.""" + ## Arrange / Act + ws_open = WsOpen() + ws_authenticated = WsAuthenticated() + ws_degraded = WsDegraded() + ws_ready = WsReady() + ws_close_with_fields = WsClose(close_status_code=1000, close_msg='normal closure') + ws_close_with_none = WsClose(close_status_code=None, close_msg=None) + error = RuntimeError('connection failed') + ws_error = WsError(error=error) + + ## Assert + assert isinstance(ws_open.received_at, datetime) + assert isinstance(ws_authenticated.received_at, datetime) + assert isinstance(ws_degraded.received_at, datetime) + assert isinstance(ws_ready.received_at, datetime) + assert ws_close_with_fields.close_status_code == 1000 + assert ws_close_with_fields.close_msg == 'normal closure' + assert ws_close_with_none.close_status_code is None + assert ws_close_with_none.close_msg is None + assert ws_error.error is error + + +class TestCallbackSink: + @capture_logs() + def test_on_registers_callback(self, callback_sink): + """CallbackSink.on registers a callback for an event type.""" + ## Arrange + callback = MagicMock() + + ## Act + callback_sink.on(WsOpen, callback) + + ## Assert + assert WsOpen in callback_sink._callbacks + assert callback in callback_sink._callbacks[WsOpen] + + @capture_logs() + def test_emit_calls_registered_callback(self, callback_sink, sample_event): + """CallbackSink.emit invokes callbacks registered for the event type.""" + ## Arrange + callback = MagicMock() + callback_sink.on(WsOpen, callback) + + ## Act + callback_sink.emit(sample_event) + + ## Assert + callback.assert_called_once_with(sample_event) + + @capture_logs() + def test_emit_ignores_unregistered_event_types(self, callback_sink): + """CallbackSink.emit does not call callbacks for unregistered event types.""" + ## Arrange + callback = MagicMock() + callback_sink.on(WsOpen, callback) + event = WsClose(close_status_code=1000, close_msg='') + + ## Act + callback_sink.emit(event) + + ## Assert + callback.assert_not_called() + + @capture_logs() + def test_emit_multiple_callbacks(self, callback_sink, sample_event): + """CallbackSink.emit calls all callbacks registered for an event type.""" + ## Arrange + callback1 = MagicMock() + callback2 = MagicMock() + callback_sink.on(WsOpen, callback1) + callback_sink.on(WsOpen, callback2) + + ## Act + callback_sink.emit(sample_event) + + ## Assert + callback1.assert_called_once_with(sample_event) + callback2.assert_called_once_with(sample_event) + + @capture_logs(logger_level='ERROR', expected_errors=['Exception emitting event to callback test_fn'], partial_match=True) + def test_emit_logs_callback_exception(self, callback_sink, sample_event): + """CallbackSink.emit logs exceptions raised by callbacks without propagating.""" + ## Arrange + def test_fn(event): + raise ValueError('callback error') + callback_sink.on(WsOpen, test_fn) + + ## Act + callback_sink.emit(sample_event) + + +class TestQueueSink: + @capture_logs() + def test_new_queue_accessor_creates_accessor(self, queue_sink): + """QueueSink.new_queue_accessor returns a QueueAccessor for the event type.""" + ## Act + accessor = queue_sink.new_queue_accessor(WsOpen) + + ## Assert + assert accessor.key == WsOpen + + @capture_logs() + def test_emit_puts_event_in_queue(self, queue_sink, sample_event): + """QueueSink.emit adds the event to the queue for its type.""" + ## Act + queue_sink.emit(sample_event) + + ## Assert + retrieved = queue_sink.get(WsOpen, block=False) + assert retrieved is sample_event + + @capture_logs() + def test_get_returns_none_when_empty(self, queue_sink): + """QueueSink.get returns None when the queue is empty and block=False.""" + ## Act + result = queue_sink.get(WsOpen, block=False) + + ## Assert + assert result is None + + + @capture_logs() + def test_empty_returns_true_when_empty(self, queue_sink): + """QueueSink.empty returns True when no events are queued.""" + ## Act + result = queue_sink.empty(WsOpen) + + ## Assert + assert result is True + + @capture_logs() + def test_empty_returns_false_when_not_empty(self, queue_sink): + """QueueSink.empty returns False when events are queued.""" + ## Arrange + queue_sink.emit(WsOpen()) + + ## Act + result = queue_sink.empty(WsOpen) + + ## Assert + assert result is False + + @capture_logs() + def test_separate_queues_per_event_type(self, queue_sink): + """QueueSink maintains separate queues for different event types.""" + ## Arrange + event1 = WsOpen() + event2 = WsClose(close_status_code=1000, close_msg='') + + ## Act + queue_sink.emit(event1) + queue_sink.emit(event2) + + ## Assert + retrieved1 = queue_sink.get(WsOpen, block=False) + retrieved2 = queue_sink.get(WsClose, block=False) + assert isinstance(retrieved1, WsOpen) + assert isinstance(retrieved2, WsClose) + assert queue_sink.get(WsOpen, block=False) is None + + +class TestCompositeSink: + @capture_logs() + def test_emit_calls_all_sinks(self, sample_event): + """CompositeSink.emit forwards the event to all registered sinks.""" + ## Arrange + sink1 = MagicMock() + sink2 = MagicMock() + composite = CompositeSink(sink1, sink2) + + ## Act + composite.emit(sample_event) + + ## Assert + sink1.emit.assert_called_once_with(sample_event) + sink2.emit.assert_called_once_with(sample_event) + + @capture_logs(logger_level='WARNING', expected_errors=['Exception emitting event to sink'], partial_match=True) + def test_emit_logs_sink_exception(self, sample_event): + """CompositeSink.emit logs exceptions from sinks without propagating.""" + ## Arrange + sink1 = MagicMock() + sink1.emit.side_effect = ValueError('sink error') + sink2 = MagicMock() + composite = CompositeSink(sink1, sink2) + + ## Act + composite.emit(sample_event) + + ## Assert + sink2.emit.assert_called_once_with(sample_event) + + +class TestAsyncSink: + @capture_logs() + def test_start_launches_thread(self, noop_sink): + """AsyncSink.start launches a background thread.""" + ## Arrange + sink = AsyncSink(noop_sink) + + ## Act + sink.start() + + ## Assert + assert sink._running is True + assert sink._thread is not None + assert sink._thread.is_alive() + + ## Cleanup + sink.stop() + + @capture_logs() + def test_start_idempotent(self, noop_sink): + """AsyncSink.start does not launch multiple threads if already running.""" + ## Arrange + sink = AsyncSink(noop_sink) + sink.start() + first_thread = sink._thread + + ## Act + sink.start() + + ## Assert + assert sink._thread is first_thread + + ## Cleanup + sink.stop() + + @capture_logs() + def test_stop_terminates_thread(self, noop_sink): + """AsyncSink.stop terminates the background thread.""" + ## Arrange + sink = AsyncSink(noop_sink) + sink.start() + + ## Act + result = sink.stop() + + ## Assert + assert result is True + assert sink._running is False + assert sink._thread is None + + @capture_logs() + def test_stop_idempotent(self, noop_sink): + """AsyncSink.stop returns True when already stopped.""" + ## Arrange + sink = AsyncSink(noop_sink) + + ## Act + result = sink.stop() + + ## Assert + assert result is True + + @capture_logs() + def test_stop_from_same_thread_raises(self, noop_sink): + """AsyncSink.stop raises RuntimeError when called from the sink thread.""" + ## Arrange + sink = AsyncSink(noop_sink) + exception_holder = {'exception': None} + ev = threading.Event() + + def stop_from_thread(): + try: + sink.stop() + except RuntimeError as e: + exception_holder['exception'] = e + ev.set() + + sink._cycle = stop_from_thread + ev.clear() + sink.start() + ev.wait(10) + + ## Assert + assert exception_holder['exception'] is not None + assert 'Stopping async sink called from within async sink thread' in str(exception_holder['exception']) + + ## Cleanup + sink._running = False + + @capture_logs() + def test_emit_queues_event(self, sample_event): + """AsyncSink.emit adds events to the internal queue.""" + ## Arrange + inner_sink = MagicMock() + sink = AsyncSink(inner_sink) + sink.start() + + ## Act + sink.emit(sample_event) + sink._consume_queue() + + ## Assert + inner_sink.emit.assert_called_with(sample_event) + + ## Cleanup + sink.stop() + + @capture_logs(logger_level='WARNING', expected_errors=['dropping newest event'], partial_match=True) + def test_emit_drops_newest_when_full(self): + """AsyncSink.emit drops the newest event when queue is full and drop_oldest=False.""" + ## Arrange + inner_sink = MagicMock() + sink = AsyncSink(inner_sink, maxsize=1, drop_oldest=False) + event1 = WsOpen() + event2 = WsAuthenticated() + + ## Act + + sink.emit(event1) + sink.emit(event2) + + @capture_logs(logger_level='WARNING', expected_errors=['dropping oldest event'], partial_match=True) + def test_emit_drops_oldest_when_full(self): + """AsyncSink.emit drops the oldest event when queue is full and drop_oldest=True.""" + ## Arrange + inner_sink = MagicMock() + sink = AsyncSink(inner_sink, maxsize=1, drop_oldest=True) + event1 = WsOpen() + event2 = WsAuthenticated() + + ## Act + sink.emit(event1) + sink.emit(event2) + + @capture_logs() + def test_consume_queue_forwards_events(self): + """AsyncSink forwards queued events to the inner sink.""" + ## Arrange + inner_sink = MagicMock() + sink = AsyncSink(inner_sink) + sink.start() + event1 = WsOpen() + event2 = WsAuthenticated() + + ## Act + sink.emit(event1) + sink.emit(event2) + sink._consume_queue() + + ## Assert + assert inner_sink.emit.call_count == 2 + inner_sink.emit.assert_any_call(event1) + inner_sink.emit.assert_any_call(event2) + + ## Cleanup + sink.stop() + + @capture_logs(logger_level='ERROR', expected_errors=['sink error'], partial_match=True) + def test_consume_queue_logs_exception(self): + """AsyncSink logs exceptions from the inner sink without stopping.""" + ## Arrange + inner_sink = MagicMock(spec=EventSink) + inner_sink.emit.side_effect = ValueError('sink error') + sink = AsyncSink(inner_sink) + event = WsOpen() + + sink.emit(event) + sink._consume_queue() + + @capture_logs() + def test_cycle_consumes_remaining_events_on_stop(self): + """AsyncSink processes remaining events in queue when stopping.""" + ## Arrange + inner_sink = MagicMock() + sink = AsyncSink(inner_sink, cycle_interval=0.5) + sink.start() + event1 = WsOpen() + event2 = WsAuthenticated() + + ## Act + sink.emit(event1) + sink.emit(event2) + sink.stop() + + ## Assert + assert inner_sink.emit.call_count >= 2 + + @capture_logs(logger_level='WARNING', expected_errors=['Event queue not empty when stopping'], partial_match=True) + def test_stop_warns_when_queue_not_empty(self): + """AsyncSink logs warning when stopping with events still in queue.""" + ## Arrange + inner_sink = MagicMock() + sink = AsyncSink(inner_sink, maxsize=10) + event = WsOpen() + + ## Act + sink._running = True + sink._queue.put(event) + sink.stop() + + def test_emit_handles_empty_exception_when_dropping_oldest(self): + """AsyncSink handles Empty exception when queue becomes empty between full check and get.""" + ## Arrange + inner_sink = MagicMock() + sink = AsyncSink(inner_sink, maxsize=1, drop_oldest=True) + event1 = WsOpen() + event2 = WsAuthenticated() + + ## Act + with patch.object(sink._queue, 'put_nowait', side_effect=[Full, None]) as mock_put: + with patch.object(sink._queue, 'get_nowait', side_effect=Empty): + sink.emit(event1) + + ## Assert + assert mock_put.call_count == 2 + + @capture_logs( + logger_level='WARNING', + expected_errors=['Event queue full; dropping oldest event', 'Event queue still full; dropping event'], + partial_match=True, + ) + def test_emit_warns_when_queue_still_full_after_drop(self): + """AsyncSink logs warning when queue is still full after dropping oldest event.""" + ## Arrange + inner_sink = MagicMock() + sink = AsyncSink(inner_sink, maxsize=1, drop_oldest=True) + event1 = WsOpen() + event2 = WsAuthenticated() + + ## Act + with patch.object(sink._queue, 'put_nowait', side_effect=[Full, Full]) as mock_put: + with patch.object(sink._queue, 'get_nowait', return_value=event1): + sink.emit(event2) From a32e7da1266c9efe40d03570e6c0ca8888bb2042 Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 10 May 2026 18:35:16 +0200 Subject: [PATCH 29/32] docs: added docstrings to ws_subscriptions.py --- ibind/ws_v2/ws_subscriptions.py | 252 ++++++++++++++++++++++++++++---- 1 file changed, 222 insertions(+), 30 deletions(-) diff --git a/ibind/ws_v2/ws_subscriptions.py b/ibind/ws_v2/ws_subscriptions.py index c1604300..c1cd469d 100644 --- a/ibind/ws_v2/ws_subscriptions.py +++ b/ibind/ws_v2/ws_subscriptions.py @@ -13,51 +13,108 @@ _LOGGER = project_logger('ibkr_ws_client') -class Subscription(BaseModel): +class Subscription(BaseModel): # pragma: no cover + """ + Base class for WebSocket subscriptions. + + Immutable model defining subscription behaviour including payload generation, + confirmation requirements, and expiry settings. Subclasses implement specific + subscription types by overriding abstract methods. + + Attributes: + expiry_seconds (int | None): Time in seconds before subscription expires and + requires renewal. None means no expiry. Default: None. + """ + model_config = ConfigDict(frozen=True) expiry_seconds: int | None = None @property def topic(self) -> str: + """Get the subscription topic identifier.""" raise NotImplementedError def subscribe_payload(self) -> str: + """Generate the payload string to send for subscribing.""" raise NotImplementedError def unsubscribe_payload(self) -> str: + """Generate the payload string to send for unsubscribing.""" raise NotImplementedError @property def confirms_subscribe(self) -> bool: + """Whether the server sends confirmation when subscription succeeds.""" return True @property def confirms_unsubscribe(self) -> bool: + """Whether the server sends confirmation when unsubscription succeeds.""" return False def binding_key(self): + """Get the unique key identifying this subscription binding.""" return self.subscribe_payload() def __str__(self): return f'{self.__class__.__qualname__}({self.binding_key()})' -class SubscriptionResolver(Protocol): +class SubscriptionResolver(Protocol): # pragma: no cover + """ + Protocol for resolving subscription binding keys from events. + + Implementations determine which subscription an event belongs to and whether + the event indicates an active or inactive subscription state. + """ + def resolve_binding_key(self, event: WsEvent) -> Tuple[bool, str]: + """ + Resolve the binding key and active state from an event. + + Args: + event (WsEvent): Event to resolve. + + Returns: + tuple[bool, str]: (is_active, binding_key) where is_active indicates + subscription is active, and binding_key identifies the subscription. + Returns (None, None) if event is not subscription-related. + """ ... -class BindingStatus(Enum): - NEW = "NEW" - PENDING = "PENDING" - ACTIVE = "ACTIVE" - FAILED = "FAILED" - DEGRADED = "DEGRADED" - UNSUBSCRIBED = "UNSUBSCRIBED" - EXPIRED = "EXPIRED" +class BindingStatus(Enum): # pragma: no cover + """ + Status of a subscription binding. + + Tracks the lifecycle state of a subscription from initial registration through + activation, failure, or unsubscription. + """ + + NEW = 'NEW' + PENDING = 'PENDING' + ACTIVE = 'ACTIVE' # subscription successful + FAILED = 'FAILED' + DEGRADED = 'DEGRADED' + UNSUBSCRIBED = 'UNSUBSCRIBED' # unsubscription successful + EXPIRED = 'EXPIRED' class Binding(BaseModel): + """ + Internal state tracking for a subscription binding. + + Maintains the desired intent (subscribe or unsubscribe), current status, + and retry state for subscription operations. + + Attributes: + subscription (Subscription): The subscription being tracked. + intent (Literal[BindingStatus.ACTIVE, BindingStatus.UNSUBSCRIBED]): Desired state. + status (BindingStatus): Current state. Default: BindingStatus.NEW. + attempts (int): Number of attempts made. Default: 0. + last_attempt (float): Timestamp of last attempt. Default: 0. + """ + subscription: Subscription intent: Literal[BindingStatus.ACTIVE, BindingStatus.UNSUBSCRIBED] status: BindingStatus = BindingStatus.NEW @@ -66,58 +123,82 @@ class Binding(BaseModel): @property def done(self) -> bool: + """Whether the binding has reached its intended state.""" return self.status == self.intent - def reset(self): + def reset(self): # pragma: no cover + """Reset retry state to allow new attempts.""" self.attempts = 0 self.last_attempt = 0 class SubscriptionHandle: - def __init__(self, controller: "SubscriptionController", subscription: Subscription): + """ + Handle for interacting with a subscription. + + Provides methods to query subscription state, wait for completion, and unsubscribe. + Returned by subscribe/unsubscribe operations. + """ + + def __init__(self, controller: 'SubscriptionController', subscription: Subscription): self._controller = controller self._subscription = subscription @property def binding_key(self) -> str: + """Get the unique key identifying this subscription.""" return self._subscription.binding_key() @property def status(self) -> BindingStatus: + """Get the current status of this subscription.""" return self._controller.get_status(self.binding_key) @property def active(self) -> bool: + """Whether the subscription is currently active.""" return self.status == BindingStatus.ACTIVE @property def unsubscribed(self) -> bool: + """Whether the subscription has been unsubscribed.""" return self.status == BindingStatus.UNSUBSCRIBED @property def done(self) -> bool: + """Whether the subscription has reached its intended state.""" return self._controller.is_done(self.binding_key) def wait(self, timeout: float | None = None) -> bool: + """ + Wait for the subscription to reach its intended state. + + Args: + timeout (float | None): Maximum time to wait in seconds, or indefinitely if None. Default: None. + + Returns: + bool: True if subscription reached intended state, False if timed out or failed. + """ return self._controller.wait_for(self.binding_key, timeout=timeout) - def unsubscribe(self) -> "SubscriptionHandle": + def unsubscribe(self) -> 'SubscriptionHandle': + """ + Unsubscribe from this subscription. + + Returns: + SubscriptionHandle: This handle for chaining. + """ self._controller.unsubscribe(self._subscription) return self class SubscriptionController: """ - Mixin which manages subscriptions to different topics using the WsClient. - - This class handles the logic for subscribing and unsubscribing to various topics. It maintains a - record of active subscriptions and provides methods to modify them. The class relies on a - SubscriptionProcessor to create subscription and unsubscription payloads. + Manages WebSocket subscriptions with automatic retries and state tracking. - Constructor Parameters: - subscription_processor (SubscriptionProcessor): The processor to create subscription payloads. - subscription_retries (int, optional): The number of retries for subscription requests. Defaults to 5. - subscription_timeout (float, optional): The timeout in seconds for subscription requests. Defaults to 2. + Handles subscription lifecycle including registration, activation, expiry, and unsubscription. + Maintains binding state and coordinates with a resolver to match events to subscriptions. + Thread-safe through internal condition variable. """ def __init__( @@ -127,6 +208,16 @@ def __init__( subscription_retries: int = 5, subscription_timeout: float = 2, ): + """ + Create a subscription controller. + + Args: + send_payload (Callable[[str], bool]): Function to send payloads through the WebSocket. + Returns True if sent successfully. + subscription_resolver (SubscriptionResolver): Resolver to match events to subscriptions. + subscription_retries (int, optional): Maximum retry attempts per subscription. Default: 5. + subscription_timeout (float, optional): Seconds to wait between retry attempts. Default: 2. + """ self._send_payload = send_payload self._subscription_resolver = subscription_resolver self._subscription_retries = subscription_retries @@ -146,6 +237,15 @@ def _send(self, payload) -> bool: return False def observe(self, event: WsEvent): + """ + Process an event to update subscription state. + + Uses the resolver to determine if the event confirms subscription or unsubscription, + then updates the corresponding binding status. + + Args: + event (WsEvent): Event to process. + """ is_active, binding_key = self._subscription_resolver.resolve_binding_key(event) # None means the event is not related to a tracked subscription @@ -179,9 +279,19 @@ def _make_attempt(self, binding: Binding): self._confirm_unsubscribed(subscription.binding_key()) def reconcile_binding(self, binding: Binding): + """ + Reconcile a single binding by checking expiry and retrying if needed. + + Handles expiry checks, retry logic, and failure detection. Called periodically + by the runtime to maintain subscription state. + + Args: + binding (Binding): Binding to reconcile. + """ now = time.time() subscription = binding.subscription + # if either done or failed, return early or check expiration if expiry_seconds is provided if binding.status in [binding.intent, BindingStatus.FAILED]: if subscription.expiry_seconds is None: return @@ -208,11 +318,24 @@ def reconcile_binding(self, binding: Binding): self._make_attempt(binding) def reconcile_bindings(self): + """Reconcile all bindings by checking expiry and retrying as needed.""" with self._condition: for binding in self._bindings.values(): self.reconcile_binding(binding) def subscribe(self, subscription: Subscription) -> SubscriptionHandle: + """ + Register intent to subscribe. + + Creates or updates a binding with ACTIVE intent. The actual subscription + attempt occurs during reconciliation. + + Args: + subscription (Subscription): Subscription to activate. + + Returns: + SubscriptionHandle: Handle for tracking and controlling this subscription. + """ binding_key = subscription.binding_key() with self._condition: @@ -236,6 +359,18 @@ def subscribe(self, subscription: Subscription) -> SubscriptionHandle: return SubscriptionHandle(self, subscription) def unsubscribe(self, subscription: Subscription) -> SubscriptionHandle: + """ + Register intent to unsubscribe. + + Creates or updates a binding with UNSUBSCRIBED intent. The actual unsubscription + attempt occurs during reconciliation. + + Args: + subscription (Subscription): Subscription to deactivate. + + Returns: + SubscriptionHandle: Handle for tracking this unsubscription. + """ binding_key = subscription.binding_key() with self._condition: @@ -259,45 +394,88 @@ def unsubscribe(self, subscription: Subscription) -> SubscriptionHandle: return SubscriptionHandle(self, subscription) def invalidate_subscriptions(self): + """Mark all subscriptions as degraded, typically after connection loss.""" with self._condition: for binding_key, binding in self._bindings.items(): if binding.status != BindingStatus.DEGRADED: self._update_status(binding, BindingStatus.DEGRADED) - def is_subscription_active(self, binding_key: str) -> Optional[bool]: + def is_subscription_active(self, binding_key: str) -> Optional[bool]: # pragma: no cover + """Check if a subscription is currently active. + + Args: + binding_key (str): Binding key to check. + + Returns: + bool: True if subscription exists and is active, False otherwise. + """ with self._condition: if not self.has_subscription(binding_key): return False return self._bindings[binding_key].status == BindingStatus.ACTIVE def has_active_subscriptions(self) -> bool: + """ + Check if any subscriptions are currently active. + + Returns: + bool: True if any subscriptions are active, False otherwise. + """ with self._condition: for subscription in self._bindings: if self.is_subscription_active(subscription): return True return False - def has_subscription(self, binding_key: str) -> bool: + def has_subscription(self, binding_key: str) -> bool: # pragma: no cover + """Check if a subscription exists. + + Args: + binding_key (str): Binding key to check. + + Returns: + bool: True if subscription exists, False otherwise. + """ return binding_key in self._bindings - def get_status(self, binding_key: str) -> BindingStatus | None: + def get_status(self, binding_key: str) -> BindingStatus | None: # pragma: no cover + """Get the status of a subscription. + + Args: + binding_key (str): Binding key to query. + + Returns: + BindingStatus | None: Current status, or None if subscription doesn't exist. + """ with self._condition: if not self.has_subscription(binding_key): return None return self._bindings[binding_key].status - def is_done(self, binding_key: str) -> bool | None: + def is_done(self, binding_key: str) -> bool | None: # pragma: no cover + """Check if a subscription has reached its intended state. + + Args: + binding_key (str): Binding key to check. + + Returns: + bool | None: True if done, False if not done, None if subscription doesn't exist. + """ with self._condition: if not self.has_subscription(binding_key): return None return self._bindings[binding_key].done def get_active_subscriptions(self): + """ + Get all active subscriptions. + + Returns: + dict[str, Binding]: Deep copies of active bindings keyed by binding_key. + """ with self._condition: return { - binding_key: copy.deepcopy(binding) - for binding_key, binding in self._bindings.items() - if self.is_subscription_active(binding_key) + binding_key: copy.deepcopy(binding) for binding_key, binding in self._bindings.items() if self.is_subscription_active(binding_key) } def _update_status(self, binding: Binding, status: BindingStatus): @@ -331,6 +509,19 @@ def _confirm_unsubscribed(self, binding_key: str): self._update_status(binding, BindingStatus.UNSUBSCRIBED) def wait_for(self, binding_key: str, timeout: float | None = None) -> bool: + """ + Wait for a subscription to reach its intended state. + + Blocks until the binding reaches its intent or fails. Uses a condition variable + for efficient waiting. + + Args: + binding_key (str): Binding key to wait for. + timeout (float | None): Maximum time to wait in seconds. None means wait indefinitely. Default: None. + + Returns: + bool: True if binding reached intended state, False if timed out, failed, or binding not found. + """ deadline = None if timeout is None else time.monotonic() + timeout with self._condition: @@ -346,6 +537,7 @@ def wait_for(self, binding_key: str, timeout: float | None = None) -> bool: if binding.status == BindingStatus.FAILED: return False + # wait for the remaining time remaining = None if timeout is not None: remaining = deadline - time.monotonic() @@ -354,5 +546,5 @@ def wait_for(self, binding_key: str, timeout: float | None = None) -> bool: self._condition.wait(remaining) - def __str__(self): + def __str__(self): # pragma: no cover return f'{self.__class__.__qualname__}()' From 530438a6057a1b5987b2f2817c029d1c05b14786 Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 10 May 2026 18:36:14 +0200 Subject: [PATCH 30/32] test: fixed ruff checks --- test/unit/ws_v2/test_ws_events_u.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/unit/ws_v2/test_ws_events_u.py b/test/unit/ws_v2/test_ws_events_u.py index 30244c42..7e3f166c 100644 --- a/test/unit/ws_v2/test_ws_events_u.py +++ b/test/unit/ws_v2/test_ws_events_u.py @@ -1,5 +1,4 @@ import threading -import time from datetime import datetime from queue import Empty, Full from unittest.mock import MagicMock, patch @@ -151,9 +150,11 @@ def test_emit_multiple_callbacks(self, callback_sink, sample_event): @capture_logs(logger_level='ERROR', expected_errors=['Exception emitting event to callback test_fn'], partial_match=True) def test_emit_logs_callback_exception(self, callback_sink, sample_event): """CallbackSink.emit logs exceptions raised by callbacks without propagating.""" + ## Arrange def test_fn(event): raise ValueError('callback error') + callback_sink.on(WsOpen, test_fn) ## Act @@ -189,7 +190,6 @@ def test_get_returns_none_when_empty(self, queue_sink): ## Assert assert result is None - @capture_logs() def test_empty_returns_true_when_empty(self, queue_sink): """QueueSink.empty returns True when no events are queued.""" @@ -468,7 +468,6 @@ def test_emit_handles_empty_exception_when_dropping_oldest(self): inner_sink = MagicMock() sink = AsyncSink(inner_sink, maxsize=1, drop_oldest=True) event1 = WsOpen() - event2 = WsAuthenticated() ## Act with patch.object(sink._queue, 'put_nowait', side_effect=[Full, None]) as mock_put: @@ -492,6 +491,6 @@ def test_emit_warns_when_queue_still_full_after_drop(self): event2 = WsAuthenticated() ## Act - with patch.object(sink._queue, 'put_nowait', side_effect=[Full, Full]) as mock_put: + with patch.object(sink._queue, 'put_nowait', side_effect=[Full, Full]): with patch.object(sink._queue, 'get_nowait', return_value=event1): sink.emit(event2) From fcedb06c7da282d41a8779934f2843e3bf871667 Mon Sep 17 00:00:00 2001 From: voyz Date: Sun, 10 May 2026 18:36:27 +0200 Subject: [PATCH 31/32] test: add test_ws_subscriptions_u.py --- test/unit/ws_v2/test_ws_subscriptions_u.py | 821 +++++++++++++++++++++ 1 file changed, 821 insertions(+) create mode 100644 test/unit/ws_v2/test_ws_subscriptions_u.py diff --git a/test/unit/ws_v2/test_ws_subscriptions_u.py b/test/unit/ws_v2/test_ws_subscriptions_u.py new file mode 100644 index 00000000..f5391c1b --- /dev/null +++ b/test/unit/ws_v2/test_ws_subscriptions_u.py @@ -0,0 +1,821 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from ibind.events import WsOpen, WsEvent +from ibind.ws_v2.ws_subscriptions import ( + Subscription, + BindingStatus, + Binding, + SubscriptionHandle, + SubscriptionController, +) +from test.test_utils import capture_logs, mock_module_time + + +class MockEvent(WsEvent): + """Mock event for testing subscription confirmation.""" + + binding_key: str + is_active: bool + + +class MockResolver: + """Mock resolver that extracts binding_key and is_active from MockEvent.""" + + def resolve_binding_key(self, event: WsEvent): + if isinstance(event, MockEvent): + return (event.is_active, event.binding_key) + return (False, None) + + +class MockSubscription(Subscription): + topic_value: str = 'test_topic' + payload_value: str = 'sub_payload' + + @property + def topic(self) -> str: + return self.topic_value + + def subscribe_payload(self) -> str: + return self.payload_value + + def unsubscribe_payload(self) -> str: + return f'unsub_{self.payload_value}' + + +class MockSubscriptionNoConfirm(Subscription): + topic_value: str = 'no_confirm' + payload_value: str = 'no_confirm_payload' + + @property + def topic(self) -> str: + return self.topic_value + + def subscribe_payload(self) -> str: + return self.payload_value + + def unsubscribe_payload(self) -> str: + return f'unsub_{self.payload_value}' + + @property + def confirms_subscribe(self) -> bool: + return False + + @property + def confirms_unsubscribe(self) -> bool: + return True + + +@pytest.fixture +def mock_sub(): + return MockSubscription() + + +@pytest.fixture() +def binding_key(mock_sub): + return mock_sub.binding_key() + + +@pytest.fixture +def test_subscription_with_expiry(): + return MockSubscription(expiry_seconds=10) + + +@pytest.fixture +def test_subscription_no_confirm(): + return MockSubscriptionNoConfirm() + + +@pytest.fixture +def mock_send_payload(): + return MagicMock(return_value=True) + + +@pytest.fixture +def sc(mock_send_payload): + return SubscriptionController( + send_payload=mock_send_payload, + subscription_resolver=MockResolver(), + subscription_retries=3, + subscription_timeout=1.0, + ) + + +class TestBinding: + @capture_logs() + def test_binding_done_when_status_matches_intent(self, mock_sub): + """Binding.done returns True when status matches intent.""" + ## Arrange + binding = Binding(subscription=mock_sub, intent=BindingStatus.ACTIVE) + binding.status = BindingStatus.ACTIVE + + ## Act + result = binding.done + + ## Assert + assert result is True + + @capture_logs() + def test_binding_not_done_when_status_differs(self, mock_sub): + """Binding.done returns False when status does not match intent.""" + ## Arrange + binding = Binding(subscription=mock_sub, intent=BindingStatus.ACTIVE) + binding.status = BindingStatus.PENDING + + ## Act + result = binding.done + + ## Assert + assert result is False + + +class TestSubscriptionHandle: + @capture_logs() + def test_status(self, sc, mock_sub): + """SubscriptionHandle.status returns the current binding status.""" + ## Arrange + sc.subscribe(mock_sub) + handle = SubscriptionHandle(sc, mock_sub) + + ## Assert + assert handle.status == BindingStatus.NEW + + @capture_logs() + def test_active_when_status_active(self, sc, mock_sub, binding_key): + """SubscriptionHandle.active returns True when status is ACTIVE.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + handle = SubscriptionHandle(sc, mock_sub) + + ## Assert + assert handle.active is True + + @capture_logs() + def test_active_when_status_not_active(self, sc, mock_sub): + """SubscriptionHandle.active returns False when status is not ACTIVE.""" + ## Arrange + sc.subscribe(mock_sub) + handle = SubscriptionHandle(sc, mock_sub) + + ## Assert + assert handle.active is False + + @capture_logs() + def test_unsubscribed_when_status_unsubscribed(self, sc, mock_sub, binding_key): + """SubscriptionHandle.unsubscribed returns True when status is UNSUBSCRIBED.""" + ## Arrange + sc.unsubscribe(mock_sub) + with sc._condition: + sc._confirm_unsubscribed(binding_key) + handle = SubscriptionHandle(sc, mock_sub) + + ## Assert + assert handle.unsubscribed is True + + @capture_logs() + def test_done_delegates_to_controller(self, sc, mock_sub, binding_key): + """SubscriptionHandle.done delegates to controller.is_done.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + handle = SubscriptionHandle(sc, mock_sub) + + ## Assert + assert handle.done is True + + @capture_logs() + def test_wait_delegates_to_controller(self, sc, mock_sub, binding_key): + """SubscriptionHandle.wait delegates to controller.wait_for.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + handle = SubscriptionHandle(sc, mock_sub) + + ## Act + result = handle.wait(timeout=1.0) + + ## Assert + assert result is True + + @capture_logs() + def test_unsubscribe_delegates_to_controller(self, sc, mock_sub, binding_key): + """SubscriptionHandle.unsubscribe delegates to controller.unsubscribe.""" + ## Arrange + sc.subscribe(mock_sub) + handle = SubscriptionHandle(sc, mock_sub) + + ## Act + result = handle.unsubscribe() + + ## Assert + assert result is handle + assert sc.get_status(binding_key) == BindingStatus.NEW + + +class TestInterface: + @capture_logs() + def test_subscribe_creates_new_binding(self, sc, mock_sub, binding_key): + """SubscriptionController.subscribe creates a new binding with ACTIVE intent.""" + ## Act + handle = sc.subscribe(mock_sub) + + ## Assert + assert isinstance(handle, SubscriptionHandle) + assert sc.has_subscription(binding_key) + binding = sc._bindings[binding_key] + assert binding.intent == BindingStatus.ACTIVE + assert binding.status == BindingStatus.NEW + + @capture_logs() + def test_subscribe_updates_existing_binding_intent(self, sc, mock_sub, binding_key): + """SubscriptionController.subscribe updates intent on existing binding.""" + ## Arrange + sc.unsubscribe(mock_sub) + binding = sc._bindings[binding_key] + assert binding.intent == BindingStatus.UNSUBSCRIBED + + ## Act + sc.subscribe(mock_sub) + + ## Assert + assert binding.intent == BindingStatus.ACTIVE + + @capture_logs() + def test_subscribe_resets_unsubscribed_binding(self, sc, mock_sub, binding_key): + """SubscriptionController.subscribe resets binding when previously UNSUBSCRIBED.""" + ## Arrange + sc.unsubscribe(mock_sub) + with sc._condition: + sc._confirm_unsubscribed(binding_key) + binding = sc._bindings[binding_key] + binding.attempts = 5 + binding.last_attempt = 100.0 + + ## Act + sc.subscribe(mock_sub) + + ## Assert + binding = sc._bindings[binding_key] + assert binding.attempts == 0 + assert binding.last_attempt == 0 + + @capture_logs() + def test_unsubscribe_creates_new_binding(self, sc, mock_sub, binding_key): + """SubscriptionController.unsubscribe creates a new binding with UNSUBSCRIBED intent.""" + ## Act + handle = sc.unsubscribe(mock_sub) + + ## Assert + assert isinstance(handle, SubscriptionHandle) + assert sc.has_subscription(binding_key) + binding = sc._bindings[binding_key] + assert binding.intent == BindingStatus.UNSUBSCRIBED + assert binding.status == BindingStatus.NEW + + @capture_logs() + def test_unsubscribe_updates_existing_binding_intent(self, sc, mock_sub, binding_key): + """SubscriptionController.unsubscribe updates intent on existing binding.""" + ## Arrange + sc.subscribe(mock_sub) + + ## Act + sc.unsubscribe(mock_sub) + + ## Assert + binding = sc._bindings[binding_key] + assert binding.intent == BindingStatus.UNSUBSCRIBED + + @capture_logs() + def test_unsubscribe_resets_active_binding(self, sc, mock_sub, binding_key): + """SubscriptionController.unsubscribe resets binding when previously ACTIVE.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + binding = sc._bindings[binding_key] + binding.attempts = 5 + binding.last_attempt = 100.0 + + ## Act + sc.unsubscribe(mock_sub) + + ## Assert + binding = sc._bindings[binding_key] + assert binding.attempts == 0 + assert binding.last_attempt == 0 + + @capture_logs() + def test_has_active_subscriptions(self, sc, mock_sub, binding_key): + """SubscriptionController.has_active_subscriptions returns True when any subscription is active.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + + ## Act + result = sc.has_active_subscriptions() + + ## Assert + assert result is True + + @capture_logs() + def test_has_active_subscriptions_returns_false_when_none_active(self, sc, mock_sub): + """SubscriptionController.has_active_subscriptions returns False when no subscriptions are active.""" + ## Arrange + sc.subscribe(mock_sub) + + ## Act + result = sc.has_active_subscriptions() + + ## Assert + assert result is False + + @capture_logs() + def test_get_active_subscriptions(self, sc, mock_sub, binding_key): + """SubscriptionController.get_active_subscriptions returns dict of active bindings.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + + ## Act + result = sc.get_active_subscriptions() + + ## Assert + assert binding_key in result + assert result[binding_key].status == BindingStatus.ACTIVE + + @capture_logs() + def test_invalidate_subscriptions(self, sc, mock_sub, binding_key): + """SubscriptionController.invalidate_subscriptions marks all bindings as DEGRADED.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + + ## Act + sc.invalidate_subscriptions() + + ## Assert + assert sc.get_status(binding_key) == BindingStatus.DEGRADED + + +class TestObserve: + @capture_logs() + def test_observe_confirms_subscribed(self, sc, mock_sub, binding_key): + """SubscriptionController.observe confirms subscription when resolver returns active.""" + ## Arrange + sc.subscribe(mock_sub) + event = MockEvent(binding_key=binding_key, is_active=True) + + ## Act + sc.observe(event) + + ## Assert + assert sc.get_status(binding_key) == BindingStatus.ACTIVE + + @capture_logs() + def test_observe_confirms_unsubscribed(self, sc, mock_sub, binding_key): + """SubscriptionController.observe confirms unsubscription when resolver returns inactive.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + sc.unsubscribe(mock_sub) + event = MockEvent(binding_key=binding_key, is_active=False) + + ## Act + sc.observe(event) + + ## Assert + assert sc.get_status(binding_key) == BindingStatus.UNSUBSCRIBED + + @capture_logs() + def test_observe_ignores_unrelated_events(self, sc, mock_sub, binding_key): + """SubscriptionController.observe ignores events with no binding key.""" + ## Arrange + sc.subscribe(mock_sub) + event = WsOpen() + + ## Act + sc.observe(event) + + ## Assert + assert sc.get_status(binding_key) == BindingStatus.NEW, 'observe should not have updated the binding' + + @capture_logs(logger_level='WARNING', expected_errors=['Observed a binding_key'], partial_match=True) + def test_observe_warns_on_missing_subscription(self, sc): + """SubscriptionController.observe logs warning when binding key has no subscription.""" + ## Arrange + event = MockEvent(binding_key='unknown_key', is_active=True) + + ## Act + sc.observe(event) + + +class TestReconcile: + @capture_logs() + def test_reconcile_binding_sends_subscribe_payload(self, sc, mock_sub, mock_send_payload, binding_key): + """SubscriptionController.reconcile_binding sends subscribe payload for ACTIVE intent.""" + ## Arrange + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + + ## Act + sc.reconcile_binding(binding) + + ## Assert + mock_send_payload.assert_called_once_with('sub_payload') + assert binding.attempts == 1 + assert binding.status == BindingStatus.NEW + + @capture_logs() + def test_reconcile_binding_sends_unsubscribe_payload(self, sc, test_subscription_no_confirm, mock_send_payload): + """SubscriptionController.reconcile_binding sends unsubscribe payload for UNSUBSCRIBED intent.""" + ## Arrange + sc.unsubscribe(test_subscription_no_confirm) + binding = sc._bindings[test_subscription_no_confirm.binding_key()] + + ## Act + with sc._condition: + sc.reconcile_binding(binding) + + ## Assert + mock_send_payload.assert_called_once_with('unsub_no_confirm_payload') + assert binding.attempts == 1 + + @capture_logs() + def test_reconcile_binding_auto_confirms_when_no_confirm_subscribe(self, sc, test_subscription_no_confirm, mock_send_payload): + """SubscriptionController.reconcile_binding auto-confirms when confirms_subscribe is False.""" + ## Arrange + sc.subscribe(test_subscription_no_confirm) + binding = sc._bindings[test_subscription_no_confirm.binding_key()] + + ## Act + with sc._condition: + sc.reconcile_binding(binding) + + ## Assert + assert binding.status == BindingStatus.ACTIVE + + @capture_logs() + def test_reconcile_binding_no_auto_confirm_when_confirms_unsubscribe_true(self, sc, test_subscription_no_confirm, mock_send_payload): + """SubscriptionController.reconcile_binding does not auto-confirm when confirms_unsubscribe is True.""" + ## Arrange + sc.unsubscribe(test_subscription_no_confirm) + binding = sc._bindings[test_subscription_no_confirm.binding_key()] + + ## Act + sc.reconcile_binding(binding) + + ## Assert + assert binding.status == BindingStatus.NEW + assert binding.attempts == 1 + + @capture_logs() + def test_reconcile_binding_auto_confirms_when_no_confirm_unsubscribe(self, sc, mock_sub, mock_send_payload, binding_key): + """SubscriptionController.reconcile_binding auto-confirms when confirms_unsubscribe is False.""" + ## Arrange + sc.unsubscribe(mock_sub) + binding = sc._bindings[binding_key] + + ## Act + with sc._condition: + sc.reconcile_binding(binding) + + ## Assert + assert binding.status == BindingStatus.UNSUBSCRIBED + + @capture_logs() + def test_reconcile_binding_respects_timeout(self, sc, mock_sub, binding_key): + """SubscriptionController.reconcile_binding waits for timeout before retrying.""" + ## Arrange + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + + ## Act + with mock_module_time('ibind.ws_v2.ws_subscriptions', time_sequence=[1000.0, 1000.5]): + sc.reconcile_binding(binding) + first_attempt = binding.attempts + sc.reconcile_binding(binding) + + ## Assert + assert binding.attempts == first_attempt + + @capture_logs() + def test_reconcile_binding_retries_after_timeout(self, sc, mock_sub, mock_send_payload, binding_key): + """SubscriptionController.reconcile_binding retries after timeout expires.""" + ## Arrange + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + + ## Act + with mock_module_time('ibind.ws_v2.ws_subscriptions', time_sequence=[1000.0, 1000.0, 1002.0]): + sc.reconcile_binding(binding) + sc.reconcile_binding(binding) # this call should return without making an attempt + sc.reconcile_binding(binding) + + ## Assert + assert binding.attempts == 2, 'should only attempt twice' + assert mock_send_payload.call_count == 2 + + @capture_logs() + def test_reconcile_binding_marks_failed_after_max_retries(self, sc, mock_sub, binding_key): + """SubscriptionController.reconcile_binding marks binding as FAILED after max retries.""" + ## Arrange + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + + ## Act + with mock_module_time('ibind.ws_v2.ws_subscriptions', time_sequence=[1000.0, 1001.1, 1002.2, 1003.3, 1004.4]): + for i in range(4): + with sc._condition: + sc.reconcile_binding(binding) + + ## Assert + assert binding.status == BindingStatus.FAILED + + @capture_logs() + def test_reconcile_binding_marks_expired_when_no_activity(self, sc, test_subscription_with_expiry): + """SubscriptionController.reconcile_binding marks binding as EXPIRED when expiry time passes.""" + ## Arrange + with mock_module_time('ibind.ws_v2.ws_subscriptions', time_sequence=[1000.0, 1011.0]): + sc.subscribe(test_subscription_with_expiry) + binding = sc._bindings[test_subscription_with_expiry.binding_key()] + with sc._condition: + sc.reconcile_binding(binding) + with sc._condition: + sc._confirm_subscribed(test_subscription_with_expiry.binding_key()) + + ## Act + with sc._condition: + sc.reconcile_binding(binding) + + ## Assert + assert binding.status == BindingStatus.EXPIRED + + @capture_logs() + def test_reconcile_binding_does_not_expire_without_expiry_seconds(self, sc, mock_sub, binding_key): + """SubscriptionController.reconcile_binding does not expire when expiry_seconds is None.""" + ## Arrange + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + sc.reconcile_binding(binding) + with sc._condition: + sc._confirm_subscribed(binding_key) + + ## Act + with mock_module_time('ibind.ws_v2.ws_subscriptions', time_sequence=[1000.0, 1002.0]): + sc.reconcile_binding(binding) + + ## Assert + assert binding.status == BindingStatus.ACTIVE + + @capture_logs() + def test_reconcile_binding_does_not_expire_before_expiry_time(self, sc, test_subscription_with_expiry): + """SubscriptionController.reconcile_binding does not expire before expiry_seconds elapses.""" + ## Arrange + with mock_module_time('ibind.ws_v2.ws_subscriptions', time_sequence=[1000.0, 1005.0]): + sc.subscribe(test_subscription_with_expiry) + binding = sc._bindings[test_subscription_with_expiry.binding_key()] + with sc._condition: + sc.reconcile_binding(binding) + with sc._condition: + sc._confirm_subscribed(test_subscription_with_expiry.binding_key()) + + ## Act + with sc._condition: + sc.reconcile_binding(binding) + + ## Assert + assert binding.status == BindingStatus.ACTIVE + + @capture_logs() + def test_reconcile_bindings_processes_all_bindings(self, sc): + """SubscriptionController.reconcile_bindings processes all registered bindings.""" + ## Arrange + sub1 = MockSubscription(topic_value='topic1', payload_value='payload1') + sub2 = MockSubscription(topic_value='topic2', payload_value='payload2') + sc.subscribe(sub1) + sc.subscribe(sub2) + + ## Act + sc.reconcile_bindings() + + ## Assert + binding1 = sc._bindings[sub1.binding_key()] + binding2 = sc._bindings[sub2.binding_key()] + assert binding1.attempts == 1 + assert binding2.attempts == 1 + + +class TestSend: + @capture_logs(logger_level='INFO', expected_errors=['Sending payload unsuccessful'], partial_match=True) + def test_send_logs_when_send_fails(self, sc, mock_sub, mock_send_payload, binding_key): + """SubscriptionController._send logs when send_payload returns False.""" + ## Arrange + mock_send_payload.return_value = False + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + + ## Act + sc.reconcile_binding(binding) + + @capture_logs(logger_level='ERROR', expected_errors=['Exception sending payload'], partial_match=True) + def test_send_logs_exception(self, sc, mock_sub, mock_send_payload, binding_key): + """SubscriptionController._send logs exceptions from send_payload.""" + ## Arrange + mock_send_payload.side_effect = RuntimeError('send error') + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + + ## Act + sc.reconcile_binding(binding) + + +class TestWaitFor: + @capture_logs() + def test_wait_for_returns_true_when_done(self, sc, mock_sub, binding_key): + """SubscriptionController.wait_for returns True when binding is done.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + + ## Act + result = sc.wait_for(binding_key, timeout=1.0) + + ## Assert + assert result is True + + @capture_logs() + def test_wait_for_returns_false_when_failed(self, sc, mock_sub, binding_key): + """SubscriptionController.wait_for returns False when binding is FAILED.""" + ## Arrange + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + binding.status = BindingStatus.FAILED + + ## Act + result = sc.wait_for(binding_key, timeout=1.0) + + ## Assert + assert result is False + + @capture_logs() + def test_wait_for_returns_false_when_missing(self, sc): + """SubscriptionController.wait_for returns False when binding does not exist.""" + ## Act + result = sc.wait_for('nonexistent', timeout=0.1) + + ## Assert + assert result is False + + @capture_logs() + def test_wait_for_returns_false_on_timeout(self, sc, mock_sub, binding_key): + """SubscriptionController.wait_for returns False when timeout expires.""" + ## Arrange + sc.subscribe(mock_sub) + + ## Act + result = sc.wait_for(binding_key, timeout=0.001) + + ## Assert + assert result is False + + @capture_logs() + def test_wait_for_waits_and_unblocks_on_notification(self, sc, mock_sub, binding_key): + """SubscriptionController.wait_for waits for notification and returns when status changes.""" + ## Arrange + sc.subscribe(mock_sub) + event = MockEvent(binding_key=binding_key, is_active=True) + + original_wait = sc._condition.wait + wait_call_count = 0 + + def mock_wait(timeout=None): + nonlocal wait_call_count + wait_call_count += 1 + if wait_call_count == 1: + sc.observe(event) + else: + original_wait(timeout) + + ## Act + with patch.object(sc._condition, 'wait', side_effect=mock_wait): + result = sc.wait_for(binding_key, timeout=5.0) + + ## Assert + assert result is True + assert wait_call_count == 1 + assert sc.get_status(binding_key) == BindingStatus.ACTIVE + + +class TestConfirmSubscribed: + @capture_logs() + def test_confirm_subscribed_updates_status(self, sc, mock_sub, binding_key): + """SubscriptionController._confirm_subscribed updates status to ACTIVE.""" + ## Arrange + sc.subscribe(mock_sub) + + ## Act + with sc._condition: + sc._confirm_subscribed(binding_key) + + ## Assert + assert sc.get_status(binding_key) == BindingStatus.ACTIVE + + @capture_logs() + def test_confirm_subscribed_ignores_when_already_active(self, sc, mock_sub, binding_key): + """SubscriptionController._confirm_subscribed does not update when already ACTIVE.""" + ## Arrange + sc.subscribe(mock_sub) + with sc._condition: + sc._confirm_subscribed(binding_key) + binding = sc._bindings[binding_key] + binding.attempts = 5 + + ## Act + with sc._condition: + sc._confirm_subscribed(binding_key) + + ## Assert + assert binding.attempts == 5 + + @capture_logs() + def test_confirm_subscribed_ignores_when_intent_unsubscribed(self, sc, mock_sub, binding_key): + """SubscriptionController._confirm_subscribed does not update when intent is UNSUBSCRIBED.""" + ## Arrange + sc.subscribe(mock_sub) + sc.unsubscribe(mock_sub) + binding = sc._bindings[binding_key] + original_status = binding.status + + ## Act + sc._confirm_subscribed(binding_key) + + ## Assert + assert binding.status == original_status + + @capture_logs(logger_level='WARNING', expected_errors=['Unknown subscription'], partial_match=True) + def test_confirm_subscribed_warns_when_missing(self, sc): + """SubscriptionController._confirm_subscribed logs warning when binding does not exist.""" + ## Act + sc._confirm_subscribed('nonexistent') + + +class TestConfirmUnsubscribed: + @capture_logs() + def test_confirm_unsubscribed_updates_status(self, sc, mock_sub, binding_key): + """SubscriptionController._confirm_unsubscribed updates status to UNSUBSCRIBED.""" + ## Arrange + sc.unsubscribe(mock_sub) + + ## Act + with sc._condition: + sc._confirm_unsubscribed(binding_key) + + ## Assert + assert sc.get_status(binding_key) == BindingStatus.UNSUBSCRIBED + + @capture_logs() + def test_confirm_unsubscribed_ignores_when_already_unsubscribed(self, sc, mock_sub, binding_key): + """SubscriptionController._confirm_unsubscribed does not update when already UNSUBSCRIBED.""" + ## Arrange + sc.unsubscribe(mock_sub) + with sc._condition: + sc._confirm_unsubscribed(binding_key) + binding = sc._bindings[binding_key] + binding.attempts = 5 + + ## Act + with sc._condition: + sc._confirm_unsubscribed(binding_key) + + ## Assert + assert binding.attempts == 5 + + @capture_logs() + def test_confirm_unsubscribed_ignores_when_intent_active(self, sc, mock_sub, binding_key): + """SubscriptionController._confirm_unsubscribed does not update when intent is ACTIVE.""" + ## Arrange + sc.subscribe(mock_sub) + binding = sc._bindings[binding_key] + original_status = binding.status + + ## Act + sc._confirm_unsubscribed(binding_key) + + ## Assert + assert binding.status == original_status + + @capture_logs(logger_level='WARNING', expected_errors=['Unknown subscription'], partial_match=True) + def test_confirm_unsubscribed_warns_when_missing(self, sc): + """SubscriptionController._confirm_unsubscribed logs warning when binding does not exist.""" + ## Act + sc._confirm_unsubscribed('nonexistent') From 1e49a0c8246ab0c451a899aaf895ecffe18286af Mon Sep 17 00:00:00 2001 From: Shreyas Zanpure Date: Wed, 13 May 2026 11:54:50 +0100 Subject: [PATCH 32/32] Fix OAuth and websocket v2 regressions --- .gitignore | 3 +- examples/rest_06_options_chain.py | 2 +- examples/ws_02_intermediate.py | 2 +- ibind/base/rest_client.py | 39 +++--- ibind/base/ws_client.py | 32 +++-- ibind/client/ibkr_client.py | 10 +- .../ibkr_client_mixins/marketdata_mixin.py | 3 +- ibind/client/ibkr_utils.py | 18 ++- ibind/client/ibkr_ws_client.py | 12 +- ibind/ibkr_ws_v2/__init__.py | 1 + ibind/ibkr_ws_v2/ibkr_router.py | 12 +- ibind/ibkr_ws_v2/ibkr_subscriptions.py | 4 +- ibind/ibkr_ws_v2/ibkr_ws_client_v2.py | 24 ++-- ibind/oauth/oauth1a.py | 1 - ibind/subscriptions/__init__.py | 4 +- ibind/support/py_utils.py | 34 ++++- ibind/ws_v2/__init__.py | 1 + ibind/ws_v2/_ws_events.py | 19 ++- ibind/ws_v2/ws_runtime.py | 60 ++++----- ibind/ws_v2/ws_transport.py | 36 ++++-- pyproject.toml | 7 +- test/integration/base/test_rest_client_i.py | 45 ++++++- .../base/test_websocket_client_i.py | 32 ++++- test/integration/base/websocketapp_mock.py | 3 +- test/integration/client/test_ibkr_client_i.py | 16 ++- test/integration/client/test_ibkr_utils_i.py | 32 ++++- .../client/test_ibkr_ws_client_i.py | 32 ++++- test/test_utils.py | 2 +- test/unit/client/test_ibkr_client_u.py | 38 +++++- test/unit/client/test_oauth1a_u.py | 59 +++++++++ test/unit/support/test_py_utils_u.py | 18 ++- test/unit/test_public_imports_u.py | 13 ++ test/unit/ws_v2/test_ws_events_u.py | 2 +- test/unit/ws_v2/test_ws_v2_regressions_u.py | 121 ++++++++++++++++++ 34 files changed, 597 insertions(+), 140 deletions(-) create mode 100644 ibind/ibkr_ws_v2/__init__.py create mode 100644 ibind/ws_v2/__init__.py create mode 100644 test/unit/client/test_oauth1a_u.py create mode 100644 test/unit/test_public_imports_u.py create mode 100644 test/unit/ws_v2/test_ws_v2_regressions_u.py diff --git a/.gitignore b/.gitignore index 8ee7582a..1048e009 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ __pycache__ ibind.egg-info build/docs/.generated-files.txt build/docs/mkdocs.yml +build/lib/ .venv .env @@ -11,4 +12,4 @@ build/docs/mkdocs.yml .vscode .DS_Store venv -.coverage \ No newline at end of file +.coverage diff --git a/examples/rest_06_options_chain.py b/examples/rest_06_options_chain.py index 510e922f..3f11d8af 100644 --- a/examples/rest_06_options_chain.py +++ b/examples/rest_06_options_chain.py @@ -136,4 +136,4 @@ response = client.place_order(order_request, answers, account_id).data -print(response) \ No newline at end of file +print(response) diff --git a/examples/ws_02_intermediate.py b/examples/ws_02_intermediate.py index ca498c2f..dc1ba8a0 100644 --- a/examples/ws_02_intermediate.py +++ b/examples/ws_02_intermediate.py @@ -68,4 +68,4 @@ def stop(_, _1): print('KeyboardInterrupt') break -stop(None, None) \ No newline at end of file +stop(None, None) diff --git a/ibind/base/rest_client.py b/ibind/base/rest_client.py index a524a1b0..d835845a 100644 --- a/ibind/base/rest_client.py +++ b/ibind/base/rest_client.py @@ -49,7 +49,10 @@ def copy(self, data: Optional[Union[list, dict]] = UNDEFINED, request: Optional[ Returns: Result: A new Result instance with the specified modifications. """ - return Result(data=data if data is not UNDEFINED else self.data.copy(), request=request if request is not UNDEFINED else self.request.copy()) + def _copy(value): + return value.copy() if hasattr(value, 'copy') else value + + return Result(data=data if data is not UNDEFINED else _copy(self.data), request=request if request is not UNDEFINED else _copy(self.request)) def pass_result(data: dict, old_result: Result) -> Result: @@ -117,6 +120,7 @@ def __init__( self.use_session = use_session self._auto_recreate_session = auto_recreate_session + self._closed = False if use_session: self.make_session() @@ -139,7 +143,7 @@ def logger(self): self._make_logger() return self._logger - def _get_headers(self, request_method: str, request_url: str): + def _get_headers(self, request_method: str, request_url: str, request_params: dict = None): return {} def get( @@ -224,23 +228,24 @@ def _request( endpoint = endpoint.lstrip('/') url = f'{base_url}{endpoint}' - headers = self._get_headers(request_method=method, request_url=url) - headers = {**headers, **(extra_headers or {})} - # we want to allow default values used by IBKR, so we remove all None parameters kwargs = filter_none(kwargs) - # choose which function should be used to make a reqeust based on use_session field - if self.use_session and self._session is not None: - request_function = self._session.request - else: - request_function = requests.request - - if request_function is None: - _LOGGER.warning(f'{self}: an attempt was made to create a request with no valid session.') + request_params = kwargs.get('params') if method.upper() == 'GET' else None + headers = self._get_headers(request_method=method, request_url=url, request_params=request_params) + headers = {**headers, **(extra_headers or {})} # we repeat the request attempts in case of ReadTimeouts up to max_retries for attempt in range(self._max_retries + 1): + # choose which function should be used to make a request based on the current session + if self.use_session and self._session is not None: + request_function = self._session.request + else: + request_function = requests.request + + if request_function is None: + _LOGGER.warning(f'{self}: an attempt was made to create a request with no valid session.') + if log: self.logger.info(f'{method} {url} {kwargs}{" (attempt: " + str(attempt) + ")" if attempt > 0 else ""}') @@ -308,6 +313,9 @@ def close_session(self): self._session = None def close(self): + if getattr(self, '_closed', False): + return + self._closed = True self.close_session() @@ -328,12 +336,7 @@ def register_shutdown_handler(self): existing_handler_int = signal.getsignal(signal.SIGINT) existing_handler_term = signal.getsignal(signal.SIGTERM) - self._closed = False - def _close_handler(): - if self._closed: - return - self._closed = True self.close() def _signal_handler(signum, frame): diff --git a/ibind/base/ws_client.py b/ibind/base/ws_client.py index 9da4e4a7..c3d53ade 100644 --- a/ibind/base/ws_client.py +++ b/ibind/base/ws_client.py @@ -1,8 +1,6 @@ import json -import ssl import threading import time -from pathlib import Path from threading import Thread, RLock from typing import Optional, Union, Dict, List @@ -10,7 +8,7 @@ from ibind.base.subscription_controller import SubscriptionController, SubscriptionProcessor from ibind.support.logs import project_logger -from ibind.support.py_utils import exception_to_string, wait_until, tname +from ibind.support.py_utils import exception_to_string, wait_until, tname, make_websocket_sslopt _LOGGER = project_logger(__file__) @@ -85,14 +83,9 @@ def __init__( self._thread = None self._thread_ids = {} self._next_thread_id = 0 + self._last_unanswered_ping_tm = None - if not (cacert is False or Path(cacert).exists()): - raise ValueError(f'{self}: cacert must be a valid Path or False') - - if cacert is None or not cacert: - self._sslopt = {'cert_reqs': ssl.CERT_NONE} - else: - self._sslopt = {'ca_certs': cacert} + self._sslopt = make_websocket_sslopt(cacert) def send(self, payload: str) -> bool: """ @@ -303,6 +296,7 @@ def _handle_on_message(self, wsa: WebSocketApp, message): # pragma: no cover def _handle_on_open(self, wsa: WebSocketApp): _LOGGER.info(f'{self}: Connection open') self._connected = True + self._last_unanswered_ping_tm = None self._on_open(wsa) def _handle_on_error(self, wsa: WebSocketApp, error): # pragma: no cover @@ -463,13 +457,25 @@ def check_ping(self) -> bool: if self._wsa is None: return True - if self._wsa.last_ping_tm == 0: + last_ping_tm = getattr(self._wsa, 'last_ping_tm', 0) + last_pong_tm = getattr(self._wsa, 'last_pong_tm', 0) + + if last_ping_tm == 0: return True - diff = abs(time.time() - self._wsa.last_ping_tm) + if last_pong_tm >= last_ping_tm and last_pong_tm != 0: + self._last_unanswered_ping_tm = None + diff = abs(time.time() - last_pong_tm) + else: + if self._last_unanswered_ping_tm is not None and last_pong_tm >= self._last_unanswered_ping_tm: + self._last_unanswered_ping_tm = None + if self._last_unanswered_ping_tm is None: + self._last_unanswered_ping_tm = last_ping_tm + diff = abs(time.time() - self._last_unanswered_ping_tm) + if diff > self._max_ping_interval: _LOGGER.warning( - f'{self}: Last WebSocket ping happened {diff: .2f} seconds ago, exceeding the max ping interval of {self._max_ping_interval}. Restarting.' + f'{self}: Last WebSocket pong happened {diff: .2f} seconds ago, exceeding the max ping interval of {self._max_ping_interval}. Restarting.' ) self.hard_reset(restart=True) return False diff --git a/ibind/client/ibkr_client.py b/ibind/client/ibkr_client.py index 15be8e36..a926b9ff 100644 --- a/ibind/client/ibkr_client.py +++ b/ibind/client/ibkr_client.py @@ -134,7 +134,7 @@ def _request(self, method: str, endpoint: str, base_url: str = None, extra_heade raise ExternalBrokerError('IBKR returned 400 Bad Request: no bridge. Try calling `initialize_brokerage_session()` first.') from e raise - def _get_headers(self, request_method: str, request_url: str): + def _get_headers(self, request_method: str, request_url: str, request_params: dict = None): if (not self._use_oauth) or request_url == f'{self.base_url}{self.oauth_config.live_session_token_endpoint}': # No need for extra headers if we don't use oauth or getting live session token return {} @@ -143,7 +143,11 @@ def _get_headers(self, request_method: str, request_url: str): from ibind.oauth.oauth1a import generate_oauth_headers headers = generate_oauth_headers( - oauth_config=self.oauth_config, request_method=request_method, request_url=request_url, live_session_token=self.live_session_token + oauth_config=self.oauth_config, + request_method=request_method, + request_url=request_url, + live_session_token=self.live_session_token, + request_params=request_params, ) return headers @@ -251,6 +255,8 @@ def stop_tickler(self, timeout:float=None): self._tickler.stop(timeout) def close(self): + if getattr(self, '_closed', False): + return if self._use_oauth and self.oauth_config.shutdown_oauth: self.oauth_shutdown() super().close() diff --git a/ibind/client/ibkr_client_mixins/marketdata_mixin.py b/ibind/client/ibkr_client_mixins/marketdata_mixin.py index 7522ab30..6c5a1146 100644 --- a/ibind/client/ibkr_client_mixins/marketdata_mixin.py +++ b/ibind/client/ibkr_client_mixins/marketdata_mixin.py @@ -189,7 +189,8 @@ def marketdata_history_by_symbol( start_time (datetime.datetime, optional): Starting date of the request duration. """ - conid = str(self.stock_conid_by_symbol(symbol).data[symbol]) + symbol_key = symbol.symbol if isinstance(symbol, StockQuery) else symbol + conid = str(self.stock_conid_by_symbol(symbol).data[symbol_key]) return self.marketdata_history_by_conid(conid, bar, exchange, period, outside_rth, start_time) def marketdata_history_by_conids( diff --git a/ibind/client/ibkr_utils.py b/ibind/client/ibkr_utils.py index bacbd4d9..45dd0947 100644 --- a/ibind/client/ibkr_utils.py +++ b/ibind/client/ibkr_utils.py @@ -663,13 +663,16 @@ def start(self): This method creates and starts a daemon thread that periodically calls `tickle()` to keep the session alive. """ - if self._thread is not None: + if self._thread is not None and self._thread.is_alive(): _LOGGER.info('Tickler thread already running. Stop the existing thread first by calling Tickler.stop()') - return + return False + + self._thread = None self._stop_event.clear() self._thread = threading.Thread(target=self._worker, daemon=True) self._thread.start() + return True def stop(self, timeout:float=None): """ @@ -682,11 +685,18 @@ def stop(self, timeout:float=None): If None, waits indefinitely. """ if self._thread is None: - return + return True + thread = self._thread self._stop_event.set() # Wake up the sleeping thread immediately - self._thread.join(timeout) + thread.join(timeout) + + if thread.is_alive(): + _LOGGER.error(f'Tickler thread did not stop within timeout={timeout}. Keeping the thread reference to avoid starting a duplicate tickler.') + return False + self._thread = None # Ensure cleanup + return True def cleanup_market_history_responses( diff --git a/ibind/client/ibkr_ws_client.py b/ibind/client/ibkr_ws_client.py index 74918a95..92db3102 100644 --- a/ibind/client/ibkr_ws_client.py +++ b/ibind/client/ibkr_ws_client.py @@ -15,7 +15,7 @@ from ibind.client.ibkr_utils import extract_conid from ibind.support.errors import ExternalBrokerError from ibind.support.logs import project_logger -from ibind.support.py_utils import TimeoutLock, UNDEFINED, wait_until +from ibind.support.py_utils import TimeoutLock, UNDEFINED, wait_until, append_query_params _LOGGER = project_logger(__file__) @@ -278,7 +278,8 @@ def __init__( raise ValueError( 'OAuth access token not found. Please set IBIND_OAUTH1A_ACCESS_TOKEN environment variable or provide it as `access_token` argument.' ) - url += f'?oauth_token={access_token}' + url = append_query_params(url, {'oauth_token': access_token}) + cacert = True if ibkr_client is None: ibkr_client = IbkrClient(account_id=account_id, host=host, port=port, cacert=cacert, use_oauth=use_oauth) @@ -401,12 +402,11 @@ def _handle_authentication_status(self, message, data): _LOGGER.error(f'{self}: Status unauthenticated: {data}') self.set_authenticated(data.get('authenticated')) elif 'competing' in data: - if data.get('competing') is False: - pass - _LOGGER.error(f'{self}: Status competing: {data}') + if data.get('competing') is True: + _LOGGER.error(f'{self}: Status competing: {data}') elif ( # expected status updates that we ignore data == {'message': ''} or - data.get('fail', '') == '' or + data.get('fail') == '' or 'serverName' in data or 'serverVersion' in data or 'username' in data diff --git a/ibind/ibkr_ws_v2/__init__.py b/ibind/ibkr_ws_v2/__init__.py new file mode 100644 index 00000000..e3a13dae --- /dev/null +++ b/ibind/ibkr_ws_v2/__init__.py @@ -0,0 +1 @@ +"""IBKR WebSocket V2 client implementation.""" diff --git a/ibind/ibkr_ws_v2/ibkr_router.py b/ibind/ibkr_ws_v2/ibkr_router.py index c6b527ca..e47cdf85 100644 --- a/ibind/ibkr_ws_v2/ibkr_router.py +++ b/ibind/ibkr_ws_v2/ibkr_router.py @@ -2,10 +2,9 @@ from collections import defaultdict from typing import Dict -from client import ibkr_definitions -from client.ibkr_utils import extract_conid +from ibind.client import ibkr_definitions +from ibind.client.ibkr_utils import extract_conid -# from ibkr_ws_v2 import ibkr_events from ibind import events from ibind.events import GenericIbkrEvent, IbkrTopicEvent from ibind.support.logs import project_logger @@ -148,14 +147,15 @@ def _handle_authentication_status(self, message, arguments) -> OneOrMany[WsEvent return events.AuthenticationStatus(data=arguments, authenticated=arguments.get('authenticated'), competing=arguments.get('competing')) elif ( # expected status updates that we ignore arguments == {'message': ''} - or arguments.get('fail', '') == '' + or arguments.get('fail') == '' or 'serverName' in arguments or 'serverVersion' in arguments or 'username' in arguments ): - pass + return [] - return [] + _LOGGER.info(f'{self}: Status message: {arguments}') + return GenericIbkrEvent(message=message, topic='sts', data=arguments) def _handle_bulletin(self, message) -> OneOrMany[WsEvent]: # pragma: no cover return events.Bulletin(message=message) diff --git a/ibind/ibkr_ws_v2/ibkr_subscriptions.py b/ibind/ibkr_ws_v2/ibkr_subscriptions.py index 72f16ece..f349bfae 100644 --- a/ibind/ibkr_ws_v2/ibkr_subscriptions.py +++ b/ibind/ibkr_ws_v2/ibkr_subscriptions.py @@ -5,8 +5,8 @@ from ibind import events from ibind.events import AccountLedger, MarketData, MarketHistory, Orders, PriceLadder, Pnl, Trades, Unsubscription, AccountSummary, IbkrTopicEvent -from support.py_utils import filter_none -from ws_v2.ws_subscriptions import Subscription, SubscriptionResolver +from ibind.support.py_utils import filter_none +from ibind.ws_v2.ws_subscriptions import Subscription, SubscriptionResolver def make_binding_key( diff --git a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py index b05f9c7d..10d4ff33 100644 --- a/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py +++ b/ibind/ibkr_ws_v2/ibkr_ws_client_v2.py @@ -3,16 +3,16 @@ from typing import Union, List, Dict, Type from ibind import events -import var -from ibind import IbkrClient +from ibind import var +from ibind.client.ibkr_client import IbkrClient from ibind.events import IbkrTopicEvent -from ibkr_ws_v2.ibkr_router import IbkrRouter -from ibkr_ws_v2.ibkr_subscriptions import IbkrSubscriptionResolver, MarketHistorySubscription -from support.logs import project_logger -from support.py_utils import OneOrMany, ensure_list_arg -from ws_v2._ws_events import EventSink, CallbackSink, Router, AsyncSink, NoopSink -from ws_v2.ws_subscriptions import Subscription, SubscriptionResolver, SubscriptionHandle, BindingStatus -from ws_v2.ws_runtime import WsRuntime, WsState +from ibind.ibkr_ws_v2.ibkr_router import IbkrRouter +from ibind.ibkr_ws_v2.ibkr_subscriptions import IbkrSubscriptionResolver, MarketHistorySubscription +from ibind.support.logs import project_logger +from ibind.support.py_utils import OneOrMany, ensure_list_arg, append_query_params +from ibind.ws_v2._ws_events import EventSink, CallbackSink, Router, AsyncSink, NoopSink +from ibind.ws_v2.ws_subscriptions import Subscription, SubscriptionResolver, SubscriptionHandle, BindingStatus +from ibind.ws_v2.ws_runtime import WsRuntime, WsState _LOGGER = project_logger('ibkr_ws_client') @@ -50,7 +50,8 @@ def __init__( raise ValueError( 'OAuth access token not found. Please set IBIND_OAUTH1A_ACCESS_TOKEN environment variable or provide it as `access_token` argument.' ) - url += f'?oauth_token={access_token}' + url = append_query_params(url, {'oauth_token': access_token}) + cacert = True if ibkr_client is None: ibkr_client = IbkrClient(account_id=account_id, host=host, port=port, cacert=cacert, use_oauth=use_oauth) @@ -115,7 +116,8 @@ def _on_authentication_status(self, event: events.AuthenticationStatus): elif event.competing is True: _LOGGER.error(f'{self}: Authentication competing: {event}') - self._runtime.set_authenticated(event.authenticated) + if event.authenticated is not None: + self._runtime.set_authenticated(event.authenticated) def _on_system(self, event: events.System): if 'hb' in event.data: diff --git a/ibind/oauth/oauth1a.py b/ibind/oauth/oauth1a.py index 6b849588..82e3f73c 100644 --- a/ibind/oauth/oauth1a.py +++ b/ibind/oauth/oauth1a.py @@ -237,7 +237,6 @@ def generate_oauth_headers( 'Accept-Encoding': 'gzip,deflate', 'Authorization': headers_string, 'Connection': 'keep-alive', - 'Host': 'api.ibkr.com', 'User-Agent': 'ibind', } diff --git a/ibind/subscriptions/__init__.py b/ibind/subscriptions/__init__.py index 4160e9a3..164b8f66 100644 --- a/ibind/subscriptions/__init__.py +++ b/ibind/subscriptions/__init__.py @@ -1,6 +1,6 @@ -from ibkr_ws_v2.ibkr_subscriptions import MarketDataSubscription, OrdersSubscription, AccountLedgerSubscription, AccountSummarySubscription, PnlSubscription, TradesSubscription, MarketHistorySubscription +from ibind.ibkr_ws_v2.ibkr_subscriptions import MarketDataSubscription, OrdersSubscription, AccountLedgerSubscription, AccountSummarySubscription, PnlSubscription, TradesSubscription, MarketHistorySubscription -from ws_v2.ws_subscriptions import SubscriptionHandle, BindingStatus, Subscription, SubscriptionResolver +from ibind.ws_v2.ws_subscriptions import SubscriptionHandle, BindingStatus, Subscription, SubscriptionResolver __all__ = [ 'Subscription', diff --git a/ibind/support/py_utils.py b/ibind/support/py_utils.py index d35e1609..3e9a0898 100644 --- a/ibind/support/py_utils.py +++ b/ibind/support/py_utils.py @@ -1,6 +1,7 @@ import copy import inspect import os +import ssl import sys import threading import time @@ -10,7 +11,9 @@ from enum import Enum, EnumMeta from functools import wraps from collections.abc import Mapping +from pathlib import Path from typing import List, TypeVar, Union, Dict +from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit from ibind.support.logs import project_logger @@ -199,6 +202,27 @@ def filter_none(d): # pragma: no cover return d +def append_query_params(url: str, params: dict) -> str: + parts = urlsplit(url) + query_items = parse_qsl(parts.query, keep_blank_values=True) + query_items.extend((key, value) for key, value in params.items() if value is not None) + query = urlencode(query_items) + return urlunsplit((parts.scheme, parts.netloc, parts.path, query, parts.fragment)) + + +def make_websocket_sslopt(cacert: Union[str, os.PathLike, bool, None]) -> dict: + if cacert in (False, None): + return {'cert_reqs': ssl.CERT_NONE} + + if cacert is True: + return {'cert_reqs': ssl.CERT_REQUIRED} + + if not Path(cacert).exists(): + raise ValueError(f'cacert must be a valid Path, True, False, or None, found: {cacert}') + + return {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': str(cacert)} + + class TimeoutLock: # pragma: no cover """ A lock with a timeout mechanism, extending the standard threading.RLock. @@ -213,17 +237,19 @@ class TimeoutLock: # pragma: no cover def __init__(self, timeout: int): self._lock = threading.RLock() self._timeout = timeout - self._acquired = False def acquire(self, *args, **kwargs): - self._acquired = self._lock.acquire(*args, timeout=self._timeout, **kwargs) + acquired = self._lock.acquire(*args, timeout=self._timeout, **kwargs) + if not acquired: + raise TimeoutError(f'Could not acquire lock within {self._timeout} seconds') + return True def release(self): - if self._acquired: - self._lock.release() + self._lock.release() def __enter__(self): self.acquire() + return self def __exit__(self, type, value, traceback): self.release() diff --git a/ibind/ws_v2/__init__.py b/ibind/ws_v2/__init__.py new file mode 100644 index 00000000..0b087100 --- /dev/null +++ b/ibind/ws_v2/__init__.py @@ -0,0 +1 @@ +"""WebSocket V2 runtime primitives.""" diff --git a/ibind/ws_v2/_ws_events.py b/ibind/ws_v2/_ws_events.py index 9890c93b..5957ec93 100644 --- a/ibind/ws_v2/_ws_events.py +++ b/ibind/ws_v2/_ws_events.py @@ -6,9 +6,9 @@ from pydantic import BaseModel, ConfigDict, Field -from base.queue_controller import QueueAccessor -from support.logs import project_logger -from support.py_utils import OneOrMany, exception_to_string, tname +from ibind.base.queue_controller import QueueAccessor +from ibind.support.logs import project_logger +from ibind.support.py_utils import OneOrMany, exception_to_string, tname __all__ = [] @@ -138,7 +138,8 @@ class CallbackSink: Exceptions from callbacks are logged but do not propagate. """ - _callbacks: Dict[type[WsEvent], List[Callable[[WsEvent], None]]] = {} + def __init__(self): + self._callbacks: Dict[type[WsEvent], List[Callable[[WsEvent], None]]] = {} def on(self, event_type: type[WsEvent], callback: Callable[[T], None]) -> None: """ @@ -175,7 +176,8 @@ class QueueSink: retrieved synchronously or asynchronously via queue accessors. """ - _queues = {} + def __init__(self): + self._queues = {} def new_queue_accessor(self, event_type: type[WsEvent]) -> QueueAccessor: """ @@ -314,10 +316,14 @@ def start(self): """Start the background thread for processing events.""" if self._running: return + if self._thread is not None and self._thread.is_alive(): + _LOGGER.error(f'{self}: Async sink thread is still alive and cannot be restarted yet') + return False self._running = True self._thread = Thread(target=self._cycle, name='async_sink_thread', daemon=True) self._thread.start() + return True def stop(self) -> bool: """ @@ -343,7 +349,8 @@ def stop(self) -> bool: self._thread.join(self._stop_timeout) succeeded = not self._thread.is_alive() - self._thread = None + if succeeded: + self._thread = None if self._queue.qsize() > 0: _LOGGER.warning(f'{self}: Event queue not empty when stopping; discarding {self._queue.qsize()} events') diff --git a/ibind/ws_v2/ws_runtime.py b/ibind/ws_v2/ws_runtime.py index 86c7468d..59954e66 100644 --- a/ibind/ws_v2/ws_runtime.py +++ b/ibind/ws_v2/ws_runtime.py @@ -1,19 +1,17 @@ import json -import ssl import threading import time -from pathlib import Path from queue import Queue from threading import Thread, Event from typing import Union, List, Dict, Callable, Literal -from support.logs import project_logger -from support.py_utils import wait_until, tname, VerboseEnum, exception_to_string, TimeoutLock, OneOrMany, noop from ibind import events from ibind.events import WsEvent -from ws_v2._ws_events import EventSink, Router, CallbackSink, AsyncSink -from ws_v2.ws_subscriptions import SubscriptionController, SubscriptionResolver -from ws_v2.ws_transport import WsTransport, TransportEvent, TransportOpened, TransportClosed, TransportError, TransportMessage, TransportReconnect +from ibind.support.logs import project_logger +from ibind.support.py_utils import wait_until, tname, VerboseEnum, exception_to_string, TimeoutLock, OneOrMany, noop, make_websocket_sslopt +from ibind.ws_v2._ws_events import EventSink, Router, CallbackSink, AsyncSink +from ibind.ws_v2.ws_subscriptions import SubscriptionController, SubscriptionResolver +from ibind.ws_v2.ws_transport import WsTransport, TransportEvent, TransportOpened, TransportClosed, TransportError, TransportMessage, TransportReconnect _LOGGER = project_logger('ibkr_ws_client') @@ -23,25 +21,19 @@ class WsState(VerboseEnum): - STOPPED = 'STOPPED', - STARTING = 'STARTING', - CONNECTING = 'CONNECTING', - OPEN = 'OPEN', - AUTHENTICATED = 'AUTHENTICATED', - CLOSED = 'CLOSED', - DEGRADED = 'DEGRADED', - RECONNECTING = 'RECONNECTING', - STOPPING = 'STOPPING', + STOPPED = 'STOPPED' + STARTING = 'STARTING' + CONNECTING = 'CONNECTING' + OPEN = 'OPEN' + AUTHENTICATED = 'AUTHENTICATED' + CLOSED = 'CLOSED' + DEGRADED = 'DEGRADED' + RECONNECTING = 'RECONNECTING' + STOPPING = 'STOPPING' def make_sslopt(cacert: Union[str, bool]): - if not (cacert is False or Path(cacert).exists()): - raise ValueError(f'Cacert must be a valid Path or False, found: {cacert}') - - if cacert is None or not cacert: - return {'cert_reqs': ssl.CERT_NONE} - else: - return {'ca_certs': cacert} + return make_websocket_sslopt(cacert) class WsRuntime(): @@ -168,7 +160,8 @@ def _stop_transport_thread(self) -> bool: self._transport_thread.join(self._connection_timeout) is_alive = self._transport_thread.is_alive() - self._transport_thread = None + if not is_alive: + self._transport_thread = None return not is_alive except Exception as e: _LOGGER.error(f'{self}: Failed to stop transport thread: {e}') @@ -213,16 +206,17 @@ def stop(self): self._set_state(WsState.STOPPING) transport_thread_stopped = self._stop_transport_thread() if not transport_thread_stopped: - _LOGGER.error(f'{self}: Failed to stop transport thread, abandoning...') - self._transport_thread = None + _LOGGER.error(f'{self}: Failed to stop transport thread; keeping the thread reference to avoid duplicate transport loops') self._transport.set_degraded(True) self._running = False - if self._runtime_thread is not None: - self._runtime_thread.join(self._connection_timeout) + runtime_thread = self._runtime_thread + if runtime_thread is not None: + runtime_thread.join(self._connection_timeout) - if self._runtime_thread.is_alive(): - _LOGGER.error(f'{self}: Runtime thread failed to stop, abandoning...') + if runtime_thread is not None and runtime_thread.is_alive(): + _LOGGER.error(f'{self}: Runtime thread failed to stop; keeping the thread reference to avoid duplicate runtime loops') + return self._runtime_thread = None @@ -264,12 +258,14 @@ def restart_transport(self): transport_thread_stopped = self._stop_transport_thread() if not transport_thread_stopped: - _LOGGER.error(f'{self}: Failed to stop transport thread, abandoning...') - self._transport_thread = None + _LOGGER.error(f'{self}: Failed to stop transport thread; restart aborted to avoid duplicate transport loops') + self._transport.set_degraded(True) + return False self._transport.set_degraded(True) self._transport = self._new_transport() self._new_transport_thread() + return True def reset_websocket_app(self): self._transport.reset_websocket_app() diff --git a/ibind/ws_v2/ws_transport.py b/ibind/ws_v2/ws_transport.py index 704ced61..19fa47f5 100644 --- a/ibind/ws_v2/ws_transport.py +++ b/ibind/ws_v2/ws_transport.py @@ -5,10 +5,10 @@ from pydantic import BaseModel, ConfigDict, Field from websocket import WebSocketApp, STATUS_UNEXPECTED_CONDITION, STATUS_NORMAL -import var -from ibind import ExternalBrokerError -from support.logs import project_logger -from support.py_utils import exception_to_string, tname, wait_until, UNDEFINED, noop +from ibind import var +from ibind.support.errors import ExternalBrokerError +from ibind.support.logs import project_logger +from ibind.support.py_utils import exception_to_string, tname, wait_until, UNDEFINED, noop _LOGGER = project_logger('ibkr_ws_client') @@ -120,6 +120,7 @@ def __init__( self._wsa: WebSocketApp | None = None self._degraded = False self._tname = None + self._last_unanswered_ping_tm = None self._session_lacks_authentication = False @@ -180,17 +181,36 @@ def check_ping(self, max_interval: float = None) -> bool: if self._wsa is None: return True - if self._wsa.last_pong_tm == 0: + last_ping_tm = getattr(self._wsa, 'last_ping_tm', 0) + last_pong_tm = getattr(self._wsa, 'last_pong_tm', 0) + + if last_ping_tm == 0: return True if max_interval is None: max_interval = self._max_ping_interval - return self.get_time_since_last_ping() <= max_interval + now = time.time() + if last_pong_tm >= last_ping_tm and last_pong_tm != 0: + self._last_unanswered_ping_tm = None + return abs(now - last_pong_tm) <= max_interval + + if self._last_unanswered_ping_tm is not None and last_pong_tm >= self._last_unanswered_ping_tm: + self._last_unanswered_ping_tm = None + if self._last_unanswered_ping_tm is None: + self._last_unanswered_ping_tm = last_ping_tm + + return abs(now - self._last_unanswered_ping_tm) <= max_interval def get_time_since_last_ping(self) -> float: - """Get seconds elapsed since the last pong was received.""" - return abs(time.time() - self._wsa.last_pong_tm) + """Get seconds elapsed since the latest ping that still needs a fresh pong, or the last pong.""" + last_ping_tm = getattr(self._wsa, 'last_ping_tm', 0) + last_pong_tm = getattr(self._wsa, 'last_pong_tm', 0) + if last_pong_tm >= last_ping_tm and last_pong_tm != 0: + return abs(time.time() - last_pong_tm) + if self._last_unanswered_ping_tm is not None: + return abs(time.time() - self._last_unanswered_ping_tm) + return abs(time.time() - last_ping_tm) def fetch_cookie(self) -> Union[str, None]: """ diff --git a/pyproject.toml b/pyproject.toml index 0b72a5d6..706e8f75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,11 +6,12 @@ authors = [ ] description = "IBind is a REST and WebSocket client library for Interactive Brokers Client Portal Web API." readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.10" license-files=["LICENSE"] dependencies = [ "requests>=2.31", - "websocket-client>=1.7" + "websocket-client>=1.7", + "pydantic>=2.13" ] classifiers = [ "Development Status :: 4 - Beta", @@ -71,4 +72,4 @@ docstring-code-format = true requires = [ "setuptools" ] -build-backend = "setuptools.build_meta" \ No newline at end of file +build-backend = "setuptools.build_meta" diff --git a/test/integration/base/test_rest_client_i.py b/test/integration/base/test_rest_client_i.py index c7effe2d..a5c3e29e 100644 --- a/test/integration/base/test_rest_client_i.py +++ b/test/integration/base/test_rest_client_i.py @@ -1,11 +1,11 @@ import asyncio -import logging import threading import pytest from unittest.mock import MagicMock from requests import ReadTimeout, Timeout +from requests import ConnectionError as RequestsConnectionError from ibind.client.ibkr_client import IbkrClient from ibind.support.errors import ExternalBrokerError @@ -136,6 +136,47 @@ def test_response_raise_generic(client, result, requests_mock): assert f'RestClient: response error {result.copy(data=None)} :: {response.status_code} :: {response.reason} :: {response.text}' == str(excinfo.value) +def test_result_copy_with_none_data(): + result = Result() + assert result.copy() == Result(data=None, request={}) + + +def test_request_recreates_session_before_retry(default_url, data, requests_mock): + # Arrange + response = MagicMock() + response.json.return_value = data + + first_session = MagicMock() + first_session.request.side_effect = RequestsConnectionError('connection dropped') + + second_session = MagicMock() + second_session.request.return_value = response + + requests_mock.Session.side_effect = [first_session, second_session] + client = RestClient(url=_URL, timeout=_TIMEOUT, max_retries=1, use_session=True) + + # Act + rv = client.get(_DEFAULT_PATH) + + # Assert + assert rv == Result(data=data, request={'url': default_url}) + first_session.close.assert_called_once() + first_session.request.assert_called_once_with('GET', default_url, verify=False, headers={}, timeout=_TIMEOUT) + second_session.request.assert_called_once_with('GET', default_url, verify=False, headers={}, timeout=_TIMEOUT) + + +def test_close_is_idempotent(client): + # Arrange + client.close_session = MagicMock() + + # Act + client.close() + client.close() + + # Assert + client.close_session.assert_called_once() + + def _worker_in_thread(results: []): try: IbkrClient() @@ -218,4 +259,4 @@ def test_without_thread_async(): # Assert for result in results: if isinstance(result, Exception): - raise result \ No newline at end of file + raise result diff --git a/test/integration/base/test_websocket_client_i.py b/test/integration/base/test_websocket_client_i.py index 40801dd2..df2321b0 100644 --- a/test/integration/base/test_websocket_client_i.py +++ b/test/integration/base/test_websocket_client_i.py @@ -74,7 +74,7 @@ def _logs_exception_starting(error_message: str, thread_mock: MagicMock): def _logs_check_health_error(max_ping_interval: int, time_ago: str): return [ - f'WsClient: Last WebSocket ping happened {time_ago} seconds ago, exceeding the max ping interval of {max_ping_interval}. Restarting.', + f'WsClient: Last WebSocket pong happened {time_ago} seconds ago, exceeding the max ping interval of {max_ping_interval}. Restarting.', 'WsClient: Hard reset, restart=True, self._wsa is None=False', 'WsClient: Hard reset is closing the WebSocketApp', ] @@ -353,7 +353,7 @@ def test_send_without_start(ws_client, **kwargs): @capture_logs( logger_level='DEBUG', expected_errors=[ - 'WsClient: Last WebSocket ping happened', + 'WsClient: Last WebSocket pong happened', 'WsClient: Hard reset close timeout', 'WsClient: Abandoning current WebSocketApp that cannot be closed:', ], @@ -395,4 +395,30 @@ def fake_time(): + _logs_start_success_end() + _logs_shutdown_success() == [r.msg for r in cm.records] - ) \ No newline at end of file + ) + + +def test_check_ping_tracks_first_unanswered_ping(ws_client, wsa_mock, patched_constructors, mocker): + """Fails even if outgoing ping timestamps keep updating without a pong.""" + ## Arrange + current_time = [100.0] + mocker.patch('ibind.base.ws_client.time.time', side_effect=lambda: current_time[0]) + ws_client.start() + + ## Act / Assert + wsa_mock.last_ping_tm = 100.0 + wsa_mock.last_pong_tm = 0 + assert ws_client.check_ping() is True + + current_time[0] = 130.0 + wsa_mock.last_ping_tm = 130.0 + assert ws_client.check_ping() is True + + current_time[0] = 139.0 + wsa_mock.last_ping_tm = 139.0 + ws_client.hard_reset = MagicMock() + assert ws_client.check_ping() is False + ws_client.hard_reset.assert_called_once_with(restart=True) + + ## Cleanup + ws_client.shutdown() diff --git a/test/integration/base/websocketapp_mock.py b/test/integration/base/websocketapp_mock.py index af2458c6..fd0de6de 100644 --- a/test/integration/base/websocketapp_mock.py +++ b/test/integration/base/websocketapp_mock.py @@ -27,6 +27,7 @@ def init_wsa_mock( wsa_mock._on_close.side_effect = on_close wsa_mock.last_ping_tm = 0 + wsa_mock.last_pong_tm = 0 wsa_mock.keep_running = False return wsa_mock @@ -53,4 +54,4 @@ def create_wsa_mock(): wsa_mock.close.side_effect = lambda *args, **kwargs: close(wsa_mock, *args, **kwargs) wsa_mock.run_forever.side_effect = lambda *args, **kwargs: run_forever(wsa_mock, *args, **kwargs) - return wsa_mock \ No newline at end of file + return wsa_mock diff --git a/test/integration/client/test_ibkr_client_i.py b/test/integration/client/test_ibkr_client_i.py index b22ae18a..6d9d5c0c 100644 --- a/test/integration/client/test_ibkr_client_i.py +++ b/test/integration/client/test_ibkr_client_i.py @@ -158,7 +158,7 @@ def _marketdata_request(method, url, *args, **kwargs): if leaf == 'stocks': return MagicMock(json=lambda: ibkr_responses.responses['stocks']) elif leaf == 'history': - conid = kwargs['params']['conid'] + conid = int(kwargs['params']['conid']) history_by_conid = { ibkr_responses.responses['filtered_conids'][key]: value for key, value in ibkr_responses.responses['history'].items() } @@ -219,6 +219,18 @@ def test_marketdata_history_by_symbols(client, requests_mock): assert result['date'] == expected['date'] +def test_marketdata_history_by_symbol_accepts_stock_query(client, requests_mock): + # Arrange + requests_mock.request.side_effect = _marketdata_request + query = StockQuery(symbol='AAPL', contract_conditions={'isUS': False, 'exchange': 'AEQLIT'}, name_match='APPLE') + + # Act + result = client.marketdata_history_by_symbol(query, bar='1min') + + # Assert + assert result.data == ibkr_responses.responses['history']['AAPL'] + + def test_check_health_authenticated_and_connected(client, default_url, requests_mock): # Arrange response_data = {'iserver': {'authStatus': {'authenticated': True, 'competing': False, 'connected': True}}} @@ -358,4 +370,4 @@ def test_marketdata_unsubscribe_raises_exception_on_failure(client, mocker): client.marketdata_unsubscribe(conids) # Assert - assert excinfo.value.status_code == 500 \ No newline at end of file + assert excinfo.value.status_code == 500 diff --git a/test/integration/client/test_ibkr_utils_i.py b/test/integration/client/test_ibkr_utils_i.py index e03569b7..748f4f41 100644 --- a/test/integration/client/test_ibkr_utils_i.py +++ b/test/integration/client/test_ibkr_utils_i.py @@ -1,4 +1,5 @@ from pprint import pformat +import threading from unittest.mock import MagicMock, call import pytest @@ -13,6 +14,7 @@ question_type_to_message_id, OrderRequest, parse_order_request, + Tickler, ) from test.integration.client import ibkr_responses from test.test_utils import CaptureLogsContext @@ -391,4 +393,32 @@ def test_raise_with_conid_and_conidex(): parse_order_request(order_request) - assert "Both 'conidex' and 'conid' are provided. When using 'conidex', specify `conid=None`." == str(cm_err.value) \ No newline at end of file + assert "Both 'conidex' and 'conid' are provided. When using 'conidex', specify `conid=None`." == str(cm_err.value) + + +def test_tickler_keeps_thread_reference_when_stop_times_out(): + ## Arrange + tickle_started = threading.Event() + release_tickle = threading.Event() + client = MagicMock() + + def slow_tickle(): + tickle_started.set() + release_tickle.wait(1) + + client.tickle.side_effect = slow_tickle + tickler = Tickler(client, interval=0) + + ## Act + tickler.start() + assert tickle_started.wait(1) + stopped = tickler.stop(timeout=0.01) + + ## Assert + assert stopped is False + assert tickler._thread is not None + assert tickler._thread.is_alive() + + ## Cleanup + release_tickle.set() + assert tickler.stop(timeout=1) is True diff --git a/test/integration/client/test_ibkr_ws_client_i.py b/test/integration/client/test_ibkr_ws_client_i.py index b3cf9c72..c3aaf399 100644 --- a/test/integration/client/test_ibkr_ws_client_i.py +++ b/test/integration/client/test_ibkr_ws_client_i.py @@ -1,4 +1,5 @@ import json +import ssl from threading import Thread from typing import Optional from unittest.mock import MagicMock, call @@ -94,7 +95,6 @@ def ws_app_factory(wsa_mock): def patched_constructors(mocker, thread_mock, ws_app_factory): mocker.patch('ibind.base.ws_client.WebSocketApp', side_effect=lambda *args, **kwargs: ws_app_factory['fn'](*args, **kwargs)) mocker.patch('ibind.base.ws_client.Thread', return_value=thread_mock) - return None @@ -133,6 +133,32 @@ def override_on_message(wsa_mock: MagicMock, message: str): return rv +def test_oauth_url_preserves_existing_query_and_forces_tls(client_mock): + # Arrange / Act + client = IbkrWsClient( + url='wss://localhost:5000/v1/api/ws?existing=1', + ibkr_client=client_mock, + use_oauth=True, + access_token='TOKEN VALUE', # noqa: S106 + cacert=False, + ) + + # Assert + assert client._url == 'wss://localhost:5000/v1/api/ws?existing=1&oauth_token=TOKEN+VALUE' + assert client._sslopt == {'cert_reqs': ssl.CERT_REQUIRED} + + +def test_auth_status_competing_false_not_logged_as_error(ws_client, mocker): + # Arrange + logger_error = mocker.patch('ibind.client.ibkr_ws_client._LOGGER.error') + + # Act + ws_client._handle_authentication_status({}, {'competing': False}) + + # Assert + logger_error.assert_not_called() + + def _logs_subscriptions(full_channel, data=None, needs_confirmation_sub: bool = False, needs_confirmation_unsub: bool = True): return [ @@ -246,7 +272,7 @@ def test_on_message_sts_authenticated(ws_client, patched_constructors): _send_payload(ws_client, {'topic': 'sts', 'args': {'authenticated': True}}) -@capture_logs(logger_level='DEBUG', expected_errors = [f'IbkrWsClient: Error message:'], partial_match=True) +@capture_logs(logger_level='DEBUG', expected_errors = ['IbkrWsClient: Error message:'], partial_match=True) def test_on_message_error(ws_client, patched_constructors): """Logs error-topic messages as warnings.""" ## Act @@ -536,4 +562,4 @@ def override_init_wsa_mock(wsa_mock: MagicMock, *args, **kwargs): channel_subscribed_log, f'IbkrWsClient: Invalidated subscription: {full_channel}', ] - ) \ No newline at end of file + ) diff --git a/test/test_utils.py b/test/test_utils.py index e33822c2..ca2a269d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -473,4 +473,4 @@ def mock_module_time(target_module, time_sequence=None, start_time=0.0): # time.time() in mymodule will return 1.0, then 2.0, then 3.0 pass """ - return MockTimeController(target_module, time_sequence=time_sequence, start_time=start_time) \ No newline at end of file + return MockTimeController(target_module, time_sequence=time_sequence, start_time=start_time) diff --git a/test/unit/client/test_ibkr_client_u.py b/test/unit/client/test_ibkr_client_u.py index 79a15321..618f508c 100644 --- a/test/unit/client/test_ibkr_client_u.py +++ b/test/unit/client/test_ibkr_client_u.py @@ -1,3 +1,7 @@ +import os +import subprocess +import sys + import pytest from unittest.mock import MagicMock from ibind.client.ibkr_client import IbkrClient @@ -61,4 +65,36 @@ def test_handle_auth_status_not_healthy_oauth_success(client, caplog): ## Assert assert any("IBKR connection is not healthy. Attempting to re-establish OAuth authentication." in r.message for r in caplog.records) client.stop_tickler.assert_called_once_with(15) - client.oauth_init.assert_called_once_with(maintain_oauth=True, init_brokerage_session=True) \ No newline at end of file + client.oauth_init.assert_called_once_with(maintain_oauth=True, init_brokerage_session=True) + + +def test_close_oauth_is_idempotent(client): + ## Arrange + client._use_oauth = True + client.oauth_config.shutdown_oauth = True + client.oauth_shutdown = MagicMock() + + ## Act + client.close() + client.close() + + ## Assert + client.oauth_shutdown.assert_called_once() + + +def test_oauth_config_init_brokerage_session_uses_dedicated_env_var(): + ## Arrange + env = os.environ.copy() + env['IBIND_INIT_OAUTH'] = 'false' + env['IBIND_INIT_BROKERAGE_SESSION'] = 'true' + code = ( + 'from ibind.oauth import OAuthConfig\n' + 'print(OAuthConfig.__dataclass_fields__["init_oauth"].default)\n' + 'print(OAuthConfig.__dataclass_fields__["init_brokerage_session"].default)\n' + ) + + ## Act + completed = subprocess.run([sys.executable, '-c', code], env=env, capture_output=True, text=True, check=True) # noqa: S603 + + ## Assert + assert completed.stdout.splitlines() == ['False', 'True'] diff --git a/test/unit/client/test_oauth1a_u.py b/test/unit/client/test_oauth1a_u.py new file mode 100644 index 00000000..7f812093 --- /dev/null +++ b/test/unit/client/test_oauth1a_u.py @@ -0,0 +1,59 @@ +from urllib.parse import unquote_plus + +import pytest + +pytest.importorskip('Crypto') + +from ibind.client.ibkr_client import IbkrClient +from ibind.oauth.oauth1a import OAuth1aConfig, generate_oauth_headers +from ibind.base.rest_client import Result + + +def test_generate_oauth_headers_omits_static_host_and_includes_request_params(monkeypatch): + ## Arrange + captured = {} + config = OAuth1aConfig(consumer_key='consumer', access_token='access', realm='realm') # noqa: S106 + + monkeypatch.setattr('ibind.oauth.oauth1a.generate_oauth_nonce', lambda: 'nonce') + monkeypatch.setattr('ibind.oauth.oauth1a.generate_request_timestamp', lambda: '123') + + def fake_signature(base_string, live_session_token): + captured['base_string'] = base_string + captured['live_session_token'] = live_session_token + return 'signature' + + monkeypatch.setattr('ibind.oauth.oauth1a.generate_hmac_sha_256_signature', fake_signature) + + ## Act + headers = generate_oauth_headers( + oauth_config=config, + request_method='GET', + request_url='https://1.api.ibkr.com/v1/api/iserver/accounts', + live_session_token='live-token', # noqa: S106 + request_params={'accountId': 'DU123'}, + ) + + ## Assert + assert 'Host' not in headers + assert captured['live_session_token'] == 'live-token' # noqa: S105 + assert 'accountId=DU123' in unquote_plus(captured['base_string']) + + +def test_oauth_get_request_passes_query_params_to_signature(mocker): + ## Arrange + client = IbkrClient(url='https://1.api.ibkr.com/v1/api/', use_oauth=False, use_session=False, auto_register_shutdown=False) + client._use_oauth = True + client.oauth_config = OAuth1aConfig(consumer_key='consumer', access_token='access', realm='realm') # noqa: S106 + client.live_session_token = 'live-token' # noqa: S105 + client._process_response = mocker.MagicMock(return_value=Result(data={'ok': True})) + + generate_headers = mocker.patch('ibind.oauth.oauth1a.generate_oauth_headers', return_value={'Authorization': 'OAuth test'}) + request = mocker.patch('ibind.base.rest_client.requests.request') + request.return_value = mocker.MagicMock() + + ## Act + client.get('iserver/accounts', params={'accountId': 'DU123'}) + + ## Assert + generate_headers.assert_called_once() + assert generate_headers.call_args.kwargs['request_params'] == {'accountId': 'DU123'} diff --git a/test/unit/support/test_py_utils_u.py b/test/unit/support/test_py_utils_u.py index 5fb4a250..685acfda 100644 --- a/test/unit/support/test_py_utils_u.py +++ b/test/unit/support/test_py_utils_u.py @@ -1,9 +1,10 @@ +import ssl import time from unittest.mock import MagicMock import pytest -from ibind.support.py_utils import ensure_list_arg, execute_in_parallel, execute_with_key, wait_until +from ibind.support.py_utils import ensure_list_arg, execute_in_parallel, execute_with_key, wait_until, append_query_params, make_websocket_sslopt @ensure_list_arg('arg') @@ -62,7 +63,7 @@ def test_ensure_list_arg_with_keyword_arg_non_list(): def test_ensure_list_arg_with_missing_arg(): """Raises TypeError when the decorated arg is missing.""" # Arrange - + # Act / Assert with pytest.raises(TypeError): sample_function() @@ -225,4 +226,15 @@ def test_wait_until_timeout(): # Assert assert result is False duration = time.time() - start_time - assert duration == pytest.approx(timeout, abs=0.02) \ No newline at end of file + assert duration == pytest.approx(timeout, abs=0.02) + + +def test_append_query_params_preserves_existing_query(): + url = append_query_params('wss://example.test/ws?existing=1', {'oauth_token': 'abc 123'}) + assert url == 'wss://example.test/ws?existing=1&oauth_token=abc+123' + + +def test_make_websocket_sslopt_supports_bool_modes(): + assert make_websocket_sslopt(False) == {'cert_reqs': ssl.CERT_NONE} + assert make_websocket_sslopt(None) == {'cert_reqs': ssl.CERT_NONE} + assert make_websocket_sslopt(True) == {'cert_reqs': ssl.CERT_REQUIRED} diff --git a/test/unit/test_public_imports_u.py b/test/unit/test_public_imports_u.py new file mode 100644 index 00000000..e848167e --- /dev/null +++ b/test/unit/test_public_imports_u.py @@ -0,0 +1,13 @@ +from setuptools import find_packages + + +def test_public_v2_imports_and_package_discovery(): + import ibind + + packages = find_packages() + + assert 'ibind.ws_v2' in packages + assert 'ibind.ibkr_ws_v2' in packages + assert ibind.IbkrWsClientV2 is not None + assert ibind.events.WsOpen is not None + assert ibind.subscriptions.MarketDataSubscription is not None diff --git a/test/unit/ws_v2/test_ws_events_u.py b/test/unit/ws_v2/test_ws_events_u.py index 7e3f166c..2366189f 100644 --- a/test/unit/ws_v2/test_ws_events_u.py +++ b/test/unit/ws_v2/test_ws_events_u.py @@ -21,7 +21,7 @@ EventSink, ) from test.test_utils import capture_logs -from ws_v2._ws_events import AsyncSink +from ibind.ws_v2._ws_events import AsyncSink @pytest.fixture diff --git a/test/unit/ws_v2/test_ws_v2_regressions_u.py b/test/unit/ws_v2/test_ws_v2_regressions_u.py new file mode 100644 index 00000000..e87842d5 --- /dev/null +++ b/test/unit/ws_v2/test_ws_v2_regressions_u.py @@ -0,0 +1,121 @@ +import ssl +from unittest.mock import MagicMock + +from ibind import events +from ibind.ibkr_ws_v2.ibkr_router import IbkrRouter +from ibind.ibkr_ws_v2.ibkr_ws_client_v2 import IbkrWsClientV2 +from ibind.ws_v2._ws_events import CallbackSink, QueueSink +from ibind.ws_v2.ws_runtime import WsState +from ibind.ws_v2.ws_transport import WsTransport + + +def test_ws_state_values_are_strings(): + assert WsState.OPEN.value == 'OPEN' + assert str(WsState.AUTHENTICATED) == 'AUTHENTICATED' + + +def test_callback_sink_instances_do_not_share_callbacks(): + ## Arrange + first_callback = MagicMock() + second_callback = MagicMock() + first = CallbackSink() + second = CallbackSink() + + ## Act + first.on(events.WsOpen, first_callback) + second.emit(events.WsOpen()) + + ## Assert + first_callback.assert_not_called() + second_callback.assert_not_called() + + +def test_queue_sink_instances_do_not_share_queues(): + ## Arrange + first = QueueSink() + second = QueueSink() + + ## Act + first.emit(events.WsOpen()) + + ## Assert + assert second.get(events.WsOpen) is None + + +def test_transport_liveness_fails_when_pong_predates_ping(mocker): + ## Arrange + transport = WsTransport(url='wss://example.test/ws', event_callback=lambda _: None, sslopt={}) + transport._wsa = MagicMock(last_ping_tm=100, last_pong_tm=90) + mocker.patch('ibind.ws_v2.ws_transport.time.time', return_value=160) + + ## Act / Assert + assert transport.check_ping(max_interval=50) is False + + +def test_transport_liveness_tracks_first_unanswered_ping(mocker): + ## Arrange + current_time = [100.0] + transport = WsTransport(url='wss://example.test/ws', event_callback=lambda _: None, sslopt={}) + transport._wsa = MagicMock(last_ping_tm=100.0, last_pong_tm=0) + mocker.patch('ibind.ws_v2.ws_transport.time.time', side_effect=lambda: current_time[0]) + + ## Act / Assert + assert transport.check_ping(max_interval=50) is True + current_time[0] = 130.0 + transport._wsa.last_ping_tm = 130.0 + assert transport.check_ping(max_interval=50) is True + current_time[0] = 151.0 + transport._wsa.last_ping_tm = 151.0 + assert transport.check_ping(max_interval=50) is False + + +def test_transport_liveness_allows_fresh_pong_after_ping(mocker): + ## Arrange + transport = WsTransport(url='wss://example.test/ws', event_callback=lambda _: None, sslopt={}) + transport._wsa = MagicMock(last_ping_tm=100, last_pong_tm=120) + mocker.patch('ibind.ws_v2.ws_transport.time.time', return_value=140) + + ## Act / Assert + assert transport.check_ping(max_interval=50) is True + + +def test_ibkr_ws_client_v2_oauth_url_preserves_existing_query_and_forces_tls(): + ## Arrange / Act + client = IbkrWsClientV2( + url='wss://localhost:5000/v1/api/ws?existing=1', + ibkr_client=MagicMock(), + use_oauth=True, + access_token='TOKEN VALUE', # noqa: S106 + cacert=False, + ) + + ## Assert + assert client._runtime._url == 'wss://localhost:5000/v1/api/ws?existing=1&oauth_token=TOKEN+VALUE' + assert client._runtime._sslopt == {'cert_reqs': ssl.CERT_REQUIRED} + + +def test_router_unknown_auth_status_is_not_silently_ignored(): + ## Arrange + router = IbkrRouter() + + ## Act + event = router.route('{"topic":"sts","args":{"unexpected":true}}') + + ## Assert + assert isinstance(event, events.GenericIbkrEvent) + assert event.topic == 'sts' + assert event.data == {'unexpected': True} + + +def test_client_v2_competing_false_does_not_change_authentication(mocker): + ## Arrange + client = IbkrWsClientV2(ibkr_client=MagicMock(), use_oauth=False) + client._runtime.set_authenticated = MagicMock() + logger_error = mocker.patch('ibind.ibkr_ws_v2.ibkr_ws_client_v2._LOGGER.error') + + ## Act + client._on_authentication_status(events.AuthenticationStatus(data={'competing': False}, authenticated=None, competing=False)) + + ## Assert + logger_error.assert_not_called() + client._runtime.set_authenticated.assert_not_called()