diff --git a/ibind/support/logs.py b/ibind/support/logs.py index 11a94226..71ed2f5a 100644 --- a/ibind/support/logs.py +++ b/ibind/support/logs.py @@ -2,6 +2,7 @@ import logging import sys from pathlib import Path +from typing import List from ibind import var @@ -11,6 +12,25 @@ _log_to_file = False +def get_logger_children(main_logger) -> List[logging.Logger]: + """ + Gets child loggers. Added as a support compat for Python version 3.11 and below. + Source: https://github.com/python/cpython/blob/3.12/Lib/logging/__init__.py#L1831 + """ + if hasattr(main_logger, 'getChildren'): + return list(main_logger.getChildren()) + + def _hierlevel(logger): + if logger is logger.manager.root: + return 0 + return 1 + logger.name.count('.') + + d = main_logger.manager.loggerDict + return [item for item in d.values() + if isinstance(item, logging.Logger) and item.parent is main_logger and + _hierlevel(item) == 1 + _hierlevel(item.parent)] + + def project_logger(filepath=None): """ Returns a project-specific logger instance. @@ -152,4 +172,4 @@ def emit(self, record): self.close() self.stream = self._open() - super().emit(record) \ No newline at end of file + super().emit(record) diff --git a/pytest.ini b/pytest.ini index 4a56d57d..eed1abd2 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,7 +1,11 @@ -[tool:pytest] -testpaths = test -pythonpath = . test +[pytest] +pythonpath = . ./test ./ibind +testpaths = + test + test/integration + test/unit addopts = -v --tb=short python_files = test_*.py python_classes = Test* -python_functions = test_* \ No newline at end of file +python_functions = test_* +norecursedirs = .* __pycache__ data .pytest_cache \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index f5fd6ce1..8e14ad3c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,4 +2,4 @@ ruff>=0.9.4,<0.10.0 bandit>=1.8.2,<2.0.0 pytest>=7.0.0,<9.0.0 pytest-cov>=4.0.0,<6.0.0 - +pytest-mock>=3.0.0,<4.0.0 \ No newline at end of file diff --git a/test/integration/base/test_rest_client_i.py b/test/integration/base/test_rest_client_i.py index 693627af..c7effe2d 100644 --- a/test/integration/base/test_rest_client_i.py +++ b/test/integration/base/test_rest_client_i.py @@ -1,149 +1,221 @@ -import threading -from unittest import TestCase -from unittest.mock import patch, MagicMock import asyncio +import logging +import threading + +import pytest +from unittest.mock import MagicMock from requests import ReadTimeout, Timeout from ibind.client.ibkr_client import IbkrClient from ibind.support.errors import ExternalBrokerError -from ibind.support.logs import project_logger from ibind.base.rest_client import Result, RestClient +from ibind.support.logs import ibind_logs_initialize +from test.test_utils import CaptureLogsContext + + +_URL = 'https://localhost:5000' +_TIMEOUT = 8 +_MAX_RETRIES = 4 +_DEFAULT_PATH = 'test/api/route' + + +@pytest.fixture +def client(): + ibind_logs_initialize(log_to_console=True) + return RestClient( + url=_URL, + timeout=_TIMEOUT, + max_retries=_MAX_RETRIES, + use_session=False, + ) + + +@pytest.fixture +def data(): + return {'Test key': 'Test value'} + + +@pytest.fixture +def response(data): + response = MagicMock() + response.json.return_value = data + return response + + +@pytest.fixture(autouse=True) +def requests_mock(mocker, response): + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = response + return requests_mock + + +@pytest.fixture +def default_url(): + return f'{_URL}/{_DEFAULT_PATH}' + + +@pytest.fixture +def result(data, default_url): + return Result(data=data, request={'url': default_url}) + + +def test_default_rest_get(client, default_url, result, requests_mock): + # Arrange + # Act + rv = client.get(_DEFAULT_PATH) + + # Assert + assert result == rv + requests_mock.request.assert_called_with('GET', default_url, verify=False, headers={}, timeout=_TIMEOUT) + + +def test_default_rest_post(client, default_url, result, requests_mock): + # Arrange + test_post_kwargs = {'field1': 'value1', 'field2': 'value2'} + test_json = {'json': {**test_post_kwargs}} + + # Act + rv = client.post(_DEFAULT_PATH, params=test_post_kwargs) + + # Assert + assert result.copy(request={'url': default_url, **test_json}) == rv + requests_mock.request.assert_called_with('POST', default_url, verify=False, headers={}, timeout=_TIMEOUT, **test_json) + + +def test_default_rest_delete(client, default_url, result, requests_mock): + # Arrange + # Act + rv = client.delete(_DEFAULT_PATH) + + # Assert + assert result == rv + requests_mock.request.assert_called_with('DELETE', default_url, verify=False, headers={}, timeout=_TIMEOUT) + + +def test_request_retries(client, default_url, requests_mock): + # Arrange + requests_mock.request.side_effect = ReadTimeout() + + # Act + with CaptureLogsContext('ibind.rest_client', level='INFO') as cm, pytest.raises(TimeoutError) as excinfo: + client.get(_DEFAULT_PATH) + + # Assert + for i in range(_MAX_RETRIES): + assert f'RestClient: Timeout for GET {default_url} {{}}, retrying attempt {i + 1}/{_MAX_RETRIES}' in cm.output + + assert f'RestClient: Reached max retries ({_MAX_RETRIES}) for GET {default_url} {{}}' == str(excinfo.value) + + +def test_response_raise_timeout(client, requests_mock): + # Arrange + requests_mock.request.return_value.raise_for_status.side_effect = Timeout() + + # Act + with pytest.raises(ExternalBrokerError) as excinfo: + client.get(_DEFAULT_PATH) + + # Assert + assert f'RestClient: Timeout error ({_TIMEOUT}S)' == str(excinfo.value) + + +def test_response_raise_generic(client, result, requests_mock): + # Arrange + response = requests_mock.request.return_value + response.status_code = 400 + response.reason = 'Test reason' + response.text = 'Test text' + response.raise_for_status.side_effect = ValueError('Test generic error') + + # Act + with pytest.raises(ExternalBrokerError) as excinfo: + client.get(_DEFAULT_PATH) + + # Assert + assert f'RestClient: response error {result.copy(data=None)} :: {response.status_code} :: {response.reason} :: {response.text}' == str(excinfo.value) + + +def _worker_in_thread(results: []): + try: + IbkrClient() + except Exception as e: + results.append(e) + + +def test_in_thread(): + """Run in thread ensuring client still is constructed without an exception.""" + # Arrange + results = [] + t = threading.Thread(target=_worker_in_thread, args=(results,)) + t.daemon = True + + # Act + t.start() + t.join(1) + + # Assert + for result in results: + if isinstance(result, Exception): + raise result + + +def test_without_thread(): + """Run without a thread to ensure it still works as expected.""" + # Arrange + results = [] + + # Act + _worker_in_thread(results) + + # Assert + for result in results: + if isinstance(result, Exception): + raise result + + +async def _async_worker(results: []): + """Async version of the worker function to run in an asyncio event loop.""" + try: + IbkrClient() + except Exception as e: + results.append(e) + + +def _worker_in_async_thread(results: []): + """Runs the async test inside a new thread to check if signal handling breaks.""" + try: + asyncio.run(_async_worker(results)) + except Exception as e: + results.append(e) + + +def test_in_thread_async(): + """Test that IbkrClient() does not break in an asyncio thread.""" + # Arrange + results = [] + t = threading.Thread(target=_worker_in_async_thread, args=(results,)) + t.daemon = True + + # Act + t.start() + t.join(1) + + # Assert + for result in results: + if isinstance(result, Exception): + raise result + + +def test_without_thread_async(): + """Test that IbkrClient() does not break in the main asyncio event loop.""" + # Arrange + results = [] + # Act + asyncio.run(_async_worker(results)) -@patch('ibind.base.rest_client.requests') -class TestRestClientI(TestCase): - def setUp(self): - self.url = 'https://localhost:5000' - self.account_id = 'TEST_ACCOUNT_ID' - self.timeout = 8 - self.max_retries = 4 - self.client = RestClient( - url=self.url, - timeout=self.timeout, - max_retries=self.max_retries, - use_session=False, - ) - - self.data = {'Test key': 'Test value'} - - self.response = MagicMock() - self.response.json.return_value = self.data - self.default_path = 'test/api/route' - self.default_url = f'{self.url}/{self.default_path}' - self.result = Result(data=self.data, request={'url': self.default_url}) - self.maxDiff = 9999 - - def test_default_rest(self, requests_mock): - requests_mock.request.return_value = self.response - - rv = self.client.get(self.default_path) - self.assertEqual(self.result, rv) - requests_mock.request.assert_called_with('GET', self.default_url, verify=False, headers={}, timeout=self.timeout) - - test_post_kwargs = {'field1': 'value1', 'field2': 'value2'} - test_json = {'json': {**test_post_kwargs}} - rv = self.client.post(self.default_path, params=test_post_kwargs) - self.assertEqual(self.result.copy(request={'url': self.default_url, **test_json}), rv) - requests_mock.request.assert_called_with('POST', self.default_url, verify=False, headers={}, timeout=self.timeout, **test_json) - - rv = self.client.delete(self.default_path) - self.assertEqual(self.result, rv) - requests_mock.request.assert_called_with('DELETE', self.default_url, verify=False, headers={}, timeout=self.timeout) - - def test_request_retries(self, requests_mock): - requests_mock.request.side_effect = ReadTimeout() - - with self.assertLogs(project_logger(), level='INFO') as cm, self.assertRaises(TimeoutError) as cm_err: - self.client.get(self.default_path) - - for i, record in enumerate(cm.records): - self.assertEqual(f'RestClient: Timeout for GET {self.default_url} {{}}, retrying attempt {i + 1}/{self.max_retries}', record.msg) - self.assertEqual(f'RestClient: Reached max retries ({self.max_retries}) for GET {self.default_url} {{}}', str(cm_err.exception)) - - def test_response_raise_timeout(self, requests_mock): - requests_mock.request.return_value = self.response - self.response.raise_for_status.side_effect = Timeout() - - with self.assertRaises(ExternalBrokerError) as cm_err: - self.client.get(self.default_path) - - self.assertEqual(f'RestClient: Timeout error ({self.timeout}S)', str(cm_err.exception)) - - def test_response_raise_generic(self, requests_mock): - requests_mock.request.return_value = self.response - self.response.status_code = 400 - self.response.reason = 'Test reason' - self.response.text = 'Test text' - - self.response.raise_for_status.side_effect = ValueError('Test generic error') - - with self.assertRaises(ExternalBrokerError) as cm_err: - self.client.get(self.default_path) - - self.assertEqual( - f'RestClient: response error {self.result.copy(data=None)} :: {self.response.status_code} :: {self.response.reason} :: {self.response.text}', - str(cm_err.exception), - ) - - -class TestRestClientInThread(TestCase): - def _worker(self, results: []): - try: - IbkrClient() - except Exception as e: - results.append(e) - - def test_in_thread(self): - """Run in thread ensuring client still is constructed without an exception.""" - results = [] - t = threading.Thread(target=self._worker, args=(results,)) - t.daemon = True - t.start() - t.join(1) - for result in results: - if isinstance(result, Exception): - raise result - - def test_without_thread(self): - """Run without a thread to ensure it still works as expected.""" - results = [] - self._worker(results) - for result in results: - if isinstance(result, Exception): - raise result - - -class TestRestClientAsync(TestCase): - def _worker(self, results: []): - """Runs the async test inside a new thread to check if signal handling breaks.""" - try: - asyncio.run(self._async_worker(results)) - except Exception as e: - results.append(e) - - async def _async_worker(self, results: []): - """Async version of the worker function to run in an asyncio event loop.""" - try: - IbkrClient() - except Exception as e: - results.append(e) - - def test_in_thread_async(self): - """Test that IbkrClient() does not break in an asyncio thread.""" - results = [] - t = threading.Thread(target=self._worker, args=(results,)) - t.daemon = True - t.start() - t.join(1) - for result in results: - if isinstance(result, Exception): - raise result - - def test_without_thread_async(self): - """Test that IbkrClient() does not break in the main asyncio event loop.""" - results = [] - asyncio.run(self._async_worker(results)) - for result in results: - if isinstance(result, Exception): - raise result + # Assert + for result in results: + if isinstance(result, Exception): + raise result \ No newline at end of file diff --git a/test/integration/base/test_websocket_client_i.py b/test/integration/base/test_websocket_client_i.py index 58529583..b14d5e75 100644 --- a/test/integration/base/test_websocket_client_i.py +++ b/test/integration/base/test_websocket_client_i.py @@ -1,273 +1,396 @@ from threading import Thread from typing import Optional -from unittest import TestCase -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock + +import pytest from ibind.base.ws_client import WsClient from ibind.support.py_utils import tname -from test.integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock -from test_utils import RaiseLogsContext, exact_log - - -class TestWsClient(TestCase): - def setUp(self): - self.url = 'wss://localhost:5000/v1/api/ws' - self.max_reconnect_attempts = 4 - self.max_ping_interval = 38 - self.error_message = 'TEST_ERROR' - - self.ws_client = WsClient( - subscription_processor=None, - url=self.url, - cacert=False, - timeout=0.01, - max_connection_attempts=self.max_reconnect_attempts, - max_ping_interval=self.max_ping_interval, - ) - - self.wsa_mock = create_wsa_mock() - - self.thread_mock = MagicMock(spec=Thread) - self.thread_mock.start.side_effect = lambda: self.ws_client._run_websocket(self.wsa_mock) - - def run_in_test_context(self, fn, expected_errors: list[str] = None): - with patch('ibind.base.ws_client.WebSocketApp', side_effect=lambda *args, **kwargs: init_wsa_mock(self.wsa_mock, *args, **kwargs)), \ - patch('ibind.base.ws_client.Thread', return_value=self.thread_mock) as new_thread_mock, \ - self.assertLogs('ibind', level='DEBUG') as cm, \ - RaiseLogsContext(self, 'ibind', level='ERROR', expected_errors=expected_errors): # fmt: skip - self.new_thread_mock = new_thread_mock - rv = fn() - - return cm, rv - - def start(self): - success = self.ws_client.start() - self.new_thread_mock.assert_called_with(target=self.ws_client._run_websocket, args=(self.wsa_mock,), name='ws_client_thread') - return success - - def _logs_start_success_beginning(self): - return [ - 'WsClient: Starting', - 'WsClient: Trying to connect', - ] - - def _logs_start_success_end(self): - return [ - 'WsClient: Creating new WebSocketApp', - f'WsClient: Thread started ({tname()})', - 'WsClient: Connection open', - f'WsClient: Thread stopped ({tname()})', - ] - - def _logs_failed_attempt(self, attempt): - s = [ - 'WsClient: Creating new WebSocketApp', - 'WsClient: New WebSocketApp connection timeout', - 'WsClient: on_close', - 'WsClient: on_close event while disconnected', - ] - if attempt: - s.append(f'WsClient: Connect reattempt {attempt}/{self.max_reconnect_attempts}') - return s - - def _logs_shutdown_success(self): - return [ - 'WsClient: Shutting down', - 'WsClient: on_close', - 'WsClient: Connection closed', - 'WsClient: Gracefully stopped', - ] - - def _logs_exception_starting(self, error_message, thread_mock): - return [ - 'WsClient: Creating new WebSocketApp', - f'WsClient: Thread started ({tname()})', - f'WsClient: Unexpected error while running WebSocketApp: {error_message}', - 'WsClient: Hard reset, restart=False, self._wsa is None=False', - 'WsClient: Forced restart', - 'WsClient: Reconnecting', - f'WsClient: Thread already running: {thread_mock.name}-{thread_mock.ident}', - f'WsClient: Thread stopped ({tname()})', - 'WsClient: Reconnecting', - 'WsClient: Trying to connect', - ] - - def _logs_check_health_error(self, time_ago): - return [ - f'WsClient: Last WebSocket ping happened {time_ago} seconds ago, exceeding the max ping interval of {self.max_ping_interval}. Restarting.', - 'WsClient: Hard reset, restart=True, self._wsa is None=False', - 'WsClient: Hard reset is closing the WebSocketApp', - ] - - def _logs_hard_restart_error(self): - return [ - 'WsClient: Hard reset close timeout', - f'WsClient: Abandoning current WebSocketApp that cannot be closed: {self.wsa_mock}', - 'WsClient: Forced restart', - 'WsClient: Reconnecting', - 'WsClient: Trying to connect', - ] - - def _verify_started(self): - self.wsa_mock.run_forever.assert_called_with( - sslopt=self.ws_client._sslopt, ping_interval=self.ws_client._ping_interval, ping_timeout=0.95 * self.ws_client._ping_interval - ) - self.wsa_mock._on_open.assert_called_with(self.wsa_mock) - - def _verify_failed_starting(self): - self.wsa_mock.run_forever.assert_not_called() - self.wsa_mock._on_open.assert_not_called() - self.wsa_mock.close.assert_called() - - def test_start_success(self): - cm, success = self.run_in_test_context(self.start) - - self.assertTrue(success, 'Starting should succeed') - self._verify_started() - exact_log(self, cm, self._logs_start_success_beginning() + self._logs_start_success_end()) - - def test_start_success_on_second_attempt(self): - counter = [0] - - # ensure we fail to do anything on the first attempt, and succeed on the second - def delayed_start(): - if counter[0] >= 1: - self.ws_client._run_websocket(self.wsa_mock) - counter[0] += 1 - - self.thread_mock.start.side_effect = delayed_start - - expected_errors = ['WsClient: New WebSocketApp connection timeout'] - - cm, success = self.run_in_test_context(self.start, expected_errors=expected_errors) - - self._verify_started() - - exact_log(self, cm, self._logs_start_success_beginning() + self._logs_failed_attempt(2) + self._logs_start_success_end()) - self.thread_mock.join.assert_called_with(60) - # print("\n".join([r.msg for r in cm.records])) - - def test_start_reattempt_failure(self): - self.thread_mock.start.side_effect = lambda: None - - expected_errors = ['WsClient: New WebSocketApp connection timeout'] - - cm, success = self.run_in_test_context(self.start, expected_errors=expected_errors) - - self.assertFalse(success, 'Starting not succeed') - - self._verify_failed_starting() - - expected_logs = self._logs_start_success_beginning() - for i in range(self.max_reconnect_attempts): - if i < self.max_reconnect_attempts - 1: - expected_logs += self._logs_failed_attempt(i + 2) - else: - expected_logs += self._logs_failed_attempt(None) - expected_logs.append(f'WsClient: Connection failed after {self.max_reconnect_attempts} attempts') - exact_log(self, cm, expected_logs) - - self.assertFalse(self.wsa_mock.keep_running) - - def test_open_exception(self): - old_run_forever = self.wsa_mock.run_forever.side_effect - - def run(): - success = self.start() - self.ws_client.shutdown() - return success - - def run_forever_exception(wsa_mock: MagicMock, sslopt: dict = None, ping_interval: float = 0, ping_timeout: Optional[float] = None): - self.wsa_mock.run_forever.side_effect = old_run_forever - raise RuntimeError(self.error_message) - - self.wsa_mock.run_forever.side_effect = lambda *args, **kwargs: run_forever_exception(self.wsa_mock, *args, **kwargs) - - expected_errors = [f'WsClient: Unexpected error while running WebSocketApp: {self.error_message}'] - - cm, success = self.run_in_test_context(run, expected_errors=expected_errors) - - exact_log( - self, - cm, - self._logs_start_success_beginning() - + self._logs_exception_starting(self.error_message, self.thread_mock) - + self._logs_start_success_end() - + self._logs_shutdown_success(), - ) - - def test_open_and_close(self): - def run(): - success = self.start() - self.ws_client.shutdown() - return success - - cm, success = self.run_in_test_context(run) - - exact_log(self, cm, self._logs_start_success_beginning() + self._logs_start_success_end() + self._logs_shutdown_success()) - - def test_send(self): - def run(): - success = self.start() - self.ws_client.send('test') - self.ws_client.shutdown() - return success - - self.ws_client._on_message = MagicMock() - - cm, success = self.run_in_test_context(run) - - self.ws_client._on_message.assert_called_once_with(self.wsa_mock, 'test') +from integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock +from test.test_utils import capture_logs + +_URL = 'wss://localhost:5000/v1/api/ws' +_MAX_RECONNECT_ATTEMPTS = 4 +_MAX_PING_INTERVAL = 38 +_ERROR_MESSAGE = 'TEST_ERROR' + + +# -------------------------------------------------------------------------------------- +# Log expectations +# -------------------------------------------------------------------------------------- + + +def _logs_start_success_beginning(): + return [ + 'WsClient: Starting', + 'WsClient: Trying to connect', + ] + + +def _logs_start_success_end(): + return [ + 'WsClient: Creating new WebSocketApp', + f'WsClient: Thread started ({tname()})', + 'WsClient: Connection open', + f'WsClient: Thread stopped ({tname()})', + ] + + +def _logs_failed_attempt(max_reconnect_attempts: int, attempt: Optional[int]): + logs = [ + 'WsClient: Creating new WebSocketApp', + 'WsClient: New WebSocketApp connection timeout', + 'WsClient: on_close', + 'WsClient: on_close event while disconnected', + ] + if attempt is not None: + logs.append(f'WsClient: Connect reattempt {attempt}/{max_reconnect_attempts}') + return logs - exact_log(self, cm, self._logs_start_success_beginning() + self._logs_start_success_end() + self._logs_shutdown_success()) - def test_send_without_start(self): - def run(): - self.ws_client.send('test') - self.ws_client.shutdown() +def _logs_shutdown_success(): + return [ + 'WsClient: Shutting down', + 'WsClient: on_close', + 'WsClient: Connection closed', + 'WsClient: Gracefully stopped', + ] - self.ws_client._on_message = MagicMock() - expected_errors = ['WsClient: Must be started before sending payloads'] +def _logs_exception_starting(error_message: str, thread_mock: MagicMock): + return [ + 'WsClient: Creating new WebSocketApp', + f'WsClient: Thread started ({tname()})', + f'WsClient: Unexpected error while running WebSocketApp: {error_message}', + 'WsClient: Hard reset, restart=False, self._wsa is None=False', + 'WsClient: Forced restart', + 'WsClient: Reconnecting', + f'WsClient: Thread already running: {thread_mock.name}-{thread_mock.ident}', + f'WsClient: Thread stopped ({tname()})', + 'WsClient: Reconnecting', + 'WsClient: Trying to connect', + ] - cm, success = self.run_in_test_context(run, expected_errors=expected_errors) - exact_log(self, cm, expected_errors) +def _logs_check_health_error(max_ping_interval: int, time_ago: str): + return [ + f'WsClient: Last WebSocket ping happened {time_ago} seconds ago, exceeding the max ping interval of {max_ping_interval}. Restarting.', + 'WsClient: Hard reset, restart=True, self._wsa is None=False', + 'WsClient: Hard reset is closing the WebSocketApp', + ] - def test_check_ping(self): - start_time = [100] - def fake_time(): - start_time[0] += 100 - return start_time[0] +def _logs_hard_restart_error(wsa_mock: MagicMock): + return [ + 'WsClient: Hard reset close timeout', + f'WsClient: Abandoning current WebSocketApp that cannot be closed: {wsa_mock}', + 'WsClient: Forced restart', + 'WsClient: Reconnecting', + 'WsClient: Trying to connect', + ] - def run(): - self.ws_client.start() - self.ws_client.check_ping() - # we simulate that closing the WebSocketApp doesn't work since we have connectivity issues - self.wsa_mock._on_close.side_effect = lambda x, y, z: None - with patch('ibind.base.ws_client.time') as time_mock: - time_mock.time.side_effect = fake_time - self.wsa_mock.last_ping_tm = self.max_ping_interval - self.ws_client.check_ping() - self.assertTrue(self.ws_client.ready()) - self.ws_client.shutdown() - self.ws_client._on_message = MagicMock() +def _verify_started(ws_client: WsClient, wsa_mock: MagicMock): + wsa_mock.run_forever.assert_called_with( + sslopt=ws_client._sslopt, + ping_interval=ws_client._ping_interval, + ping_timeout=0.95 * ws_client._ping_interval, + ) + wsa_mock._on_open.assert_called_with(wsa_mock) + + +def _verify_failed_starting(wsa_mock: MagicMock): + wsa_mock.run_forever.assert_not_called() + wsa_mock._on_open.assert_not_called() + wsa_mock.close.assert_called() + + +# -------------------------------------------------------------------------------------- +# Test setup +# -------------------------------------------------------------------------------------- + + +@pytest.fixture +def ws_client(): + return WsClient( + subscription_processor=None, + url=_URL, + cacert=False, + timeout=0.01, + max_connection_attempts=_MAX_RECONNECT_ATTEMPTS, + max_ping_interval=_MAX_PING_INTERVAL, + ) + + +@pytest.fixture +def wsa_mock(): + return create_wsa_mock() + + +@pytest.fixture +def thread_mock(ws_client, wsa_mock): + thread_mock = MagicMock(spec=Thread) + thread_mock.start.side_effect = lambda: ws_client._run_websocket(wsa_mock) + return thread_mock + + +@pytest.fixture +def wsa_ctor_mock(mocker, wsa_mock): + return mocker.patch( + 'ibind.base.ws_client.WebSocketApp', + side_effect=lambda *args, **kwargs: init_wsa_mock(wsa_mock, *args, **kwargs), + ) + + +@pytest.fixture +def thread_ctor_mock(mocker, thread_mock): + return mocker.patch('ibind.base.ws_client.Thread', return_value=thread_mock) + + +@pytest.fixture +def patched_constructors(wsa_ctor_mock, thread_ctor_mock): + return None + + +# -------------------------------------------------------------------------------------- +# Start / reconnect behavior +# -------------------------------------------------------------------------------------- + +@capture_logs(logger_level='DEBUG') +def test_start_success(ws_client, wsa_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Starts successfully and logs the expected connection sequence.""" + ## Arrange + cm = kwargs['_cm_ibind'] + + ## Act + success = ws_client.start() + + ## Assert + assert success is True + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + _verify_started(ws_client, wsa_mock) + assert _logs_start_success_beginning() + _logs_start_success_end() == [r.msg for r in cm.records] + + +@capture_logs(logger_level='DEBUG', expected_errors=['WsClient: New WebSocketApp connection timeout']) +def test_start_success_on_second_attempt(ws_client, wsa_mock, thread_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Reconnects and succeeds on the second attempt after a timeout on the first.""" + ## Arrange + cm = kwargs['_cm_ibind'] + counter = [0] + + def delayed_start(): + if counter[0] >= 1: + ws_client._run_websocket(wsa_mock) + counter[0] += 1 + + thread_mock.start.side_effect = delayed_start + + ## Act + success = ws_client.start() + + ## Assert + assert success is True + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + _verify_started(ws_client, wsa_mock) + assert ( + _logs_start_success_beginning() + + _logs_failed_attempt(_MAX_RECONNECT_ATTEMPTS, 2) + + _logs_start_success_end() + == [r.msg for r in cm.records] + ) + thread_mock.join.assert_called_with(60) + + +@capture_logs( + logger_level='DEBUG', + expected_errors=[ + 'WsClient: New WebSocketApp connection timeout', + f'WsClient: Connection failed after {_MAX_RECONNECT_ATTEMPTS} attempts', + ], +) +def test_start_reattempt_failure(ws_client, wsa_mock, thread_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Fails after exhausting reconnect attempts and closes the WebSocketApp.""" + ## Arrange + cm = kwargs['_cm_ibind'] + thread_mock.start.side_effect = lambda: None + + ## Act + success = ws_client.start() + + ## Assert + assert success is False + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + + _verify_failed_starting(wsa_mock) + + expected_logs = _logs_start_success_beginning() + for i in range(_MAX_RECONNECT_ATTEMPTS): + if i < _MAX_RECONNECT_ATTEMPTS - 1: + expected_logs += _logs_failed_attempt(_MAX_RECONNECT_ATTEMPTS, i + 2) + else: + expected_logs += _logs_failed_attempt(_MAX_RECONNECT_ATTEMPTS, None) + expected_logs.append(f"WsClient: Connection failed after {_MAX_RECONNECT_ATTEMPTS} attempts") + + assert expected_logs == [r.msg for r in cm.records] + assert wsa_mock.keep_running is False + + +# -------------------------------------------------------------------------------------- +# Error handling +# -------------------------------------------------------------------------------------- + + +@capture_logs( + logger_level='DEBUG', + expected_errors=[ + f"WsClient: Unexpected error while running WebSocketApp: {_ERROR_MESSAGE}", + 'WsClient: Thread already running:', + ], + partial_match=True, +) +def test_open_exception(ws_client, wsa_mock, thread_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Hard-resets and reconnects when WebSocketApp.run_forever raises an exception.""" + ## Arrange + cm = kwargs['_cm_ibind'] + old_run_forever = wsa_mock.run_forever.side_effect + + def run_forever_exception( + wsa_mock: MagicMock, + sslopt: dict = None, + ping_interval: float = 0, + ping_timeout: Optional[float] = None, + ): + wsa_mock.run_forever.side_effect = old_run_forever + raise RuntimeError(_ERROR_MESSAGE) + + wsa_mock.run_forever.side_effect = lambda *args, **kwargs: run_forever_exception(wsa_mock, *args, **kwargs) + + ## Act + ws_client.start() + ws_client.shutdown() + + ## Assert + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + assert ( + _logs_start_success_beginning() + + _logs_exception_starting(_ERROR_MESSAGE, thread_mock) + + _logs_start_success_end() + + _logs_shutdown_success() + == [r.msg for r in cm.records] + ) + + +# -------------------------------------------------------------------------------------- +# Shutdown +# -------------------------------------------------------------------------------------- + + +@capture_logs(logger_level='DEBUG') +def test_open_and_close(ws_client, wsa_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Shuts down cleanly after a successful start.""" + ## Arrange + cm = kwargs['_cm_ibind'] + + ## Act + success = ws_client.start() + ws_client.shutdown() + + ## Assert + assert success is True + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + assert _logs_start_success_beginning() + _logs_start_success_end() + _logs_shutdown_success() == [r.msg for r in cm.records] + + +# -------------------------------------------------------------------------------------- +# Sending payloads +# -------------------------------------------------------------------------------------- + + +@capture_logs(logger_level='DEBUG') +def test_send(ws_client, wsa_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Delivers outbound payloads to the on_message callback (mocked echo).""" + ## Arrange + cm = kwargs['_cm_ibind'] + + ws_client._on_message = MagicMock() + + ## Act + success = ws_client.start() + ws_client.send('test') + ws_client.shutdown() + + ## Assert + assert success is True + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + ws_client._on_message.assert_called_once_with(wsa_mock, 'test') + assert _logs_start_success_beginning() + _logs_start_success_end() + _logs_shutdown_success() == [r.msg for r in cm.records] + + +@capture_logs(logger_level='DEBUG', expected_errors=['WsClient: Must be started before sending payloads']) +def test_send_without_start(ws_client, **kwargs): + """Logs an error when trying to send before calling start().""" + ## Arrange + cm = kwargs['_cm_ibind'] + + ws_client._on_message = MagicMock() + + ## Act + ws_client.send('test') + ws_client.shutdown() + + ## Assert + assert ['WsClient: Must be started before sending payloads'] == [r.msg for r in cm.records] + + +# -------------------------------------------------------------------------------------- +# Health checks +# -------------------------------------------------------------------------------------- + + +@capture_logs( + logger_level='DEBUG', + expected_errors=[ + 'WsClient: Last WebSocket ping happened', + 'WsClient: Hard reset close timeout', + 'WsClient: Abandoning current WebSocketApp that cannot be closed:', + ], + partial_match=True, +) +def test_check_ping(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Triggers a hard reset when the last ping exceeds max_ping_interval.""" + ## Arrange + cm = kwargs['_cm_ibind'] + start_time = [100] + + def fake_time(): + start_time[0] += 100 + return start_time[0] + + ws_client._on_message = MagicMock() + + ## Act + ws_client.start() + ws_client.check_ping() + + # Simulate that closing the WebSocketApp doesn't work since we have connectivity issues + wsa_mock._on_close.side_effect = lambda x, y, z: None - expected_errors = ['WsClient: Must be started before sending payloads', 'WsClient: Hard reset close timeout'] + time_mock = mocker.patch('ibind.base.ws_client.time') + time_mock.time.side_effect = fake_time - cm, success = self.run_in_test_context(run, expected_errors=expected_errors) + wsa_mock.last_ping_tm = _MAX_PING_INTERVAL + ws_client.check_ping() + assert ws_client.ready() is True + ws_client.shutdown() - exact_log( - self, - cm, - self._logs_start_success_beginning() - + self._logs_start_success_end() - + self._logs_check_health_error('162.00') - + - # self._logs_start_success_end() + - self._logs_hard_restart_error() - + self._logs_start_success_end() - + self._logs_shutdown_success(), - ) + ## Assert + assert ( + _logs_start_success_beginning() + + _logs_start_success_end() + + _logs_check_health_error(_MAX_PING_INTERVAL, '162.00') + + _logs_hard_restart_error(wsa_mock) + + _logs_start_success_end() + + _logs_shutdown_success() + == [r.msg for r in cm.records] + ) \ No newline at end of file diff --git a/test/integration/base/websocketapp_mock.py b/test/integration/base/websocketapp_mock.py index 670b1205..5961f7a4 100644 --- a/test/integration/base/websocketapp_mock.py +++ b/test/integration/base/websocketapp_mock.py @@ -50,7 +50,7 @@ def create_wsa_mock(): wsa_mock = MagicMock() wsa_mock.send.side_effect = lambda *args, **kwargs: send(wsa_mock, *args, **kwargs) - wsa_mock.close.side_effect = lambda status=None: close(wsa_mock, status) + wsa_mock.close.side_effect = lambda *args, **kwargs: close(wsa_mock, *args, **kwargs) wsa_mock.run_forever.side_effect = lambda *args, **kwargs: run_forever(wsa_mock, *args, **kwargs) - return wsa_mock + return wsa_mock \ No newline at end of file diff --git a/test/integration/client/test_ibkr_client_i.py b/test/integration/client/test_ibkr_client_i.py index 8042b5d6..b22ae18a 100644 --- a/test/integration/client/test_ibkr_client_i.py +++ b/test/integration/client/test_ibkr_client_i.py @@ -1,7 +1,7 @@ import datetime from pprint import pformat -from unittest import TestCase -from unittest.mock import patch, MagicMock +import pytest +from unittest.mock import MagicMock from requests import ConnectTimeout @@ -9,259 +9,353 @@ from ibind.client.ibkr_client import IbkrClient from ibind.client.ibkr_utils import StockQuery, filter_stocks from ibind.support.errors import ExternalBrokerError -from ibind.support.logs import project_logger +from ibind.support.logs import ibind_logs_initialize from test.integration.client import ibkr_responses -from test_utils import verify_log, SafeAssertLogs, RaiseLogsContext - - -@patch('ibind.base.rest_client.requests') -class TestIbkrClientI(TestCase): - def setUp(self): - self.url = 'https://localhost:5000' - self.account_id = 'TEST_ACCOUNT_ID' - self.timeout = 8 - self.max_retries = 4 - self.client = IbkrClient( - url=self.url, - account_id=self.account_id, - timeout=self.timeout, - max_retries=self.max_retries, - use_session=False, - ) - - self.data = {'Test key': 'Test value'} - - self.response = MagicMock() - self.response.json.return_value = self.data - self.default_path = '/test/api/route' - self.default_url = f'{self.url}/{self.default_path}' - self.result = Result(data=self.data, request={'url': self.default_url}) - self.maxDiff = 9999 - - def test_get_conids(self, requests_mock): - requests_mock.request.return_value = self.response - self.response.json.return_value = ibkr_responses.responses['stocks'] - - queries = [ - StockQuery(symbol='AAPL', contract_conditions={'isUS': False, 'exchange': 'AEQLIT'}, name_match='APPLE'), - StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), - StockQuery(symbol='CDN', contract_conditions={'isUS': False}), - StockQuery(symbol='CFC', contract_conditions={}), - StockQuery(symbol='GOOG', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={'chineseName': 'Alphabet公司'}), - 'HUBS', - StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={}), - StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), - StockQuery(symbol='SAN', name_match='SANTANDER', contract_conditions={'isUS': True}), - StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NYSE'}), - StockQuery(symbol='TEAM', name_match='ATLASSIAN'), - StockQuery(symbol='INVALID_SYMBOL') - ] # fmt: skip - - with self.assertLogs(project_logger(), level='INFO'): - rv = self.client.stock_conid_by_symbol(queries, default_filtering=False) - - for symbol, conid in rv.data.items(): - self.assertIn(symbol, ibkr_responses.responses['filtered_conids']) - self.assertEqual(conid, ibkr_responses.responses['filtered_conids'][symbol]) - - def test_get_conids_exception(self, requests_mock): - requests_mock.request.return_value = self.response - self.response.json.return_value = ibkr_responses.responses['stocks'] - - symbol = 'AAPL' - query = StockQuery(symbol=symbol, contract_conditions={'isUS': False}, name_match='APPLE') - - instruments = filter_stocks(query, Result(data={symbol: ibkr_responses.responses['stocks'][symbol]}), default_filtering=False).data[symbol] - - with self.assertRaises(RuntimeError) as cm_err: - self.client.stock_conid_by_symbol(query, default_filtering=False) - - self.maxDiff = None - self.assertEqual( - f'Filtering stock "{symbol}" returned 2 instruments and 2 contracts using following query: {query}.\nPlease use filters to ensure that only one instrument and one contract per symbol is selected in order to avoid conid ambiguity.\nBe aware that contracts are filtered as {{"isUS": True}} by default. Set default_filtering=False to prevent this default filtering or specify custom filters. See inline documentation for more details.\nInstruments returned:\n{pformat(instruments)}', - str(cm_err.exception), - ) - - def test_get_live_orders_no_filters(self, requests_mock): - self.client.get = MagicMock(return_value=self.result) - self.client.live_orders() - self.client.get.assert_called_with('iserver/account/orders', params=None) - - def test_get_live_orders_with_valid_filters(self, requests_mock): - self.client.get = MagicMock(return_value=self.result) - filters = ['inactive', 'filled'] - self.client.live_orders(filters=filters) - self.client.get.assert_called_with('iserver/account/orders', params={'filters': 'inactive,filled'}) - - def test_get_live_orders_with_single_filter(self, requests_mock): - self.client.get = MagicMock(return_value=self.result) - self.client.live_orders(filters='submitted') - self.client.get.assert_called_with('iserver/account/orders', params={'filters': 'submitted'}) - - def test_get_live_orders_with_incorrect_filter_type(self, requests_mock): - self.client.get = MagicMock(return_value=self.result) - with self.assertRaises(TypeError): - self.client.live_orders(filters=123) # Non-list, non-string filter - self.client.get.assert_not_called() - - def _marketdata_request(self, method, url, *args, **kwargs): - leaf = url.split('/')[-1] - if leaf == 'stocks': - return MagicMock(json=lambda: ibkr_responses.responses['stocks']) # Mock response for get_conids - elif leaf == 'history': - conid = kwargs['params']['conid'] - return MagicMock(json=lambda: self._history_by_conid[conid]) - - def test_marketdata_history_by_symbols(self, requests_mock): - # Mocking the requests module for external interaction - self._history_by_conid = { +from test.test_utils import CaptureLogsContext + + +_URL = 'https://localhost:5000' +_TIMEOUT = 8 +_MAX_RETRIES = 4 +_DEFAULT_PATH = '/test/api/route' +_ACCOUNT_ID = 'TEST_ACCOUNT_ID' + + +@pytest.fixture +def client(): + ibind_logs_initialize(log_to_console=True) + return IbkrClient( + url=_URL, + account_id=_ACCOUNT_ID, + timeout=_TIMEOUT, + max_retries=_MAX_RETRIES, + use_session=False, + ) + + +@pytest.fixture +def data(): + return {'Test key': 'Test value'} + + +@pytest.fixture +def response(data): + response = MagicMock() + response.json.return_value = data + return response + + +@pytest.fixture(autouse=True) +def requests_mock(mocker, response): + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = response + return requests_mock + + +@pytest.fixture +def default_url(): + return f'{_URL}/{_DEFAULT_PATH}' + + +@pytest.fixture +def result(data, default_url): + return Result(data=data, request={'url': default_url}) + + +def test_get_conids(client, response): + # Arrange + response.json.return_value = ibkr_responses.responses['stocks'] + + queries = [ + StockQuery(symbol='AAPL', contract_conditions={'isUS': False, 'exchange': 'AEQLIT'}, name_match='APPLE'), + StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), + StockQuery(symbol='CDN', contract_conditions={'isUS': False}), + StockQuery(symbol='CFC', contract_conditions={}), + StockQuery(symbol='GOOG', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={'chineseName': 'Alphabet公司'}), + 'HUBS', + StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={}), + StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), + StockQuery(symbol='SAN', name_match='SANTANDER', contract_conditions={'isUS': True}), + StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NYSE'}), + StockQuery(symbol='TEAM', name_match='ATLASSIAN'), + StockQuery(symbol='INVALID_SYMBOL') + ] + + # Act + rv = client.stock_conid_by_symbol(queries, default_filtering=False) + + # Assert + for symbol, conid in rv.data.items(): + assert symbol in ibkr_responses.responses['filtered_conids'] + assert conid == ibkr_responses.responses['filtered_conids'][symbol] + + +def test_get_conids_exception(client, response): + # Arrange + response.json.return_value = ibkr_responses.responses['stocks'] + + symbol = 'AAPL' + query = StockQuery(symbol=symbol, contract_conditions={'isUS': False}, name_match='APPLE') + + instruments = filter_stocks(query, Result(data={symbol: ibkr_responses.responses['stocks'][symbol]}), default_filtering=False).data[symbol] + + # Act and Assert + with pytest.raises(RuntimeError) as excinfo: + client.stock_conid_by_symbol(query, default_filtering=False) + + assert str(excinfo.value) == f'Filtering stock "{symbol}" returned 2 instruments and 2 contracts using following query: {query}.' \ + f'\nPlease use filters to ensure that only one instrument and one contract per symbol is selected in order to avoid conid ambiguity.' \ + f'\nBe aware that contracts are filtered as {{"isUS": True}} by default. Set default_filtering=False to prevent this default filtering or specify custom filters. See inline documentation for more details.' \ + f'\nInstruments returned:\n{pformat(instruments)}' + + +def test_get_live_orders_no_filters(client, result): + # Arrange + client.get = MagicMock(return_value=result) + + # Act + client.live_orders() + + # Assert + client.get.assert_called_with('iserver/account/orders', params=None) + + +def test_get_live_orders_with_valid_filters(client, result): + # Arrange + client.get = MagicMock(return_value=result) + filters = ['inactive', 'filled'] + + # Act + client.live_orders(filters=filters) + + # Assert + client.get.assert_called_with('iserver/account/orders', params={'filters': 'inactive,filled'}) + + +def test_get_live_orders_with_single_filter(client, result): + # Arrange + client.get = MagicMock(return_value=result) + + # Act + client.live_orders(filters='submitted') + + # Assert + client.get.assert_called_with('iserver/account/orders', params={'filters': 'submitted'}) + + +def test_get_live_orders_with_incorrect_filter_type(client, result): + # Arrange + client.get = MagicMock(return_value=result) + + # Act and Assert + with pytest.raises(TypeError): + client.live_orders(filters=123) # Non-list, non-string filter + client.get.assert_not_called() + + +def _marketdata_request(method, url, *args, **kwargs): + leaf = url.split('/')[-1] + if leaf == 'stocks': + return MagicMock(json=lambda: ibkr_responses.responses['stocks']) + elif leaf == 'history': + conid = kwargs['params']['conid'] + history_by_conid = { ibkr_responses.responses['filtered_conids'][key]: value for key, value in ibkr_responses.responses['history'].items() } - requests_mock.request.side_effect = self._marketdata_request - - queries = [ - StockQuery(symbol='AAPL', contract_conditions={'isUS': False, 'exchange': 'AEQLIT'}, name_match='APPLE'), - StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), - StockQuery(symbol='CDN', contract_conditions={'isUS': False}), - StockQuery(symbol='CFC', contract_conditions={}), - StockQuery(symbol='GOOG', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={'chineseName': 'Alphabet公司'}), - StockQuery(symbol='HUBS'), - StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={}), - StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), - StockQuery(symbol='SAN', name_match='SANTANDER', contract_conditions={'isUS': True}), - StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NYSE'}), - StockQuery(symbol='TEAM', name_match='ATLASSIAN'), - ] # fmt: skip - - expected_results = {} - - for query in queries: - data = ibkr_responses.responses['history'][query.symbol]['data'][0] - output = { - 'conid': ibkr_responses.responses['filtered_conids'][query.symbol], - 'symbol': query.symbol, - 'open': data['o'], - 'high': data['h'], - 'low': data['l'], - 'close': data['c'], - 'volume': data['v'], - 'date': datetime.datetime.fromtimestamp(data['t'] / 1000, tz=datetime.timezone.utc), - } - expected_results[query.symbol] = output - - expected_errors = ['Market data for CDN is not live: Delayed', 'Market data for CFC is not live: Delayed'] - - with SafeAssertLogs(self, 'ibind', level='INFO', logger_level='DEBUG', no_logs=False) as cm, \ - RaiseLogsContext(self, 'ibind', level='ERROR', expected_errors=expected_errors): # fmt: skip - results = self.client.marketdata_history_by_symbols(queries) - - verify_log(self, cm, expected_errors) - - # Assertions to verify the correctness of each field in the result - for symbol, expected in expected_results.items(): - result = results[symbol][-1] - self.assertIn(symbol, results) - self.assertAlmostEqual(result['open'], expected['open']) - self.assertAlmostEqual(result['high'], expected['high']) - self.assertAlmostEqual(result['low'], expected['low']) - self.assertAlmostEqual(result['close'], expected['close']) - self.assertAlmostEqual(result['volume'], expected['volume']) - self.assertEqual(result['date'], expected['date']) - - def test_check_health_authenticated_and_connected(self, requests_mock): - response_data = {'iserver': {'authStatus': {'authenticated': True, 'competing': False, 'connected': True}}} - requests_mock.request.return_value = MagicMock(json=lambda: response_data) - self.client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': self.default_url})) - - health_status = self.client.check_health() - self.assertTrue(health_status) - self.client.tickle.assert_called_once() - - def test_check_health_not_authenticated(self, requests_mock): - response_data = {'iserver': {'authStatus': {'authenticated': False, 'competing': False, 'connected': True}}} - requests_mock.request.return_value = MagicMock(json=lambda: response_data) - self.client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': self.default_url})) - - health_status = self.client.check_health() - self.assertFalse(health_status) - - def test_check_health_competing_connection(self, requests_mock): - response_data = {'iserver': {'authStatus': {'authenticated': True, 'competing': True, 'connected': True}}} - requests_mock.request.return_value = MagicMock(json=lambda: response_data) - self.client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': self.default_url})) - - health_status = self.client.check_health() - self.assertFalse(health_status) - - def test_check_health_connection_error(self, requests_mock): - requests_mock.request.side_effect = ConnectTimeout - self.client.tickle = MagicMock(side_effect=ConnectTimeout) - - with self.assertLogs(level='ERROR') as cm: - health_status = self.client.check_health() - self.assertFalse(health_status) - self.assertIn('ConnectTimeout raised when communicating with the Gateway', cm.output[0]) - - def test_check_health_external_broker_error_unauthenticated(self, requests_mock): - requests_mock.request.side_effect = ExternalBrokerError(status_code=401) - self.client.tickle = MagicMock(side_effect=ExternalBrokerError(status_code=401)) - - with self.assertLogs(level='INFO') as cm: - health_status = self.client.check_health() - self.assertFalse(health_status) - self.assertIn('Gateway session is not authenticated.', cm.output[0]) - - def test_check_health_invalid_data(self, requests_mock): - response_data = {} # Invalid data format - requests_mock.request.return_value = MagicMock(json=lambda: response_data) - self.client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': self.default_url})) - - with self.assertRaises(AttributeError) as cm: - self.client.check_health() - self.assertIn('Health check requests returns invalid data', str(cm.exception)) - - def test_marketdata_unsubscribe_success(self, requests_mock): - conids = [12345, 67890] - responses = {12345: MagicMock(status_code=200), 67890: MagicMock(status_code=200)} - requests_mock.request.side_effect = lambda method, url, **kwargs: responses[kwargs['json']['conid']] - self.client.get = MagicMock( - side_effect=lambda url, *args, **kwargs: Result(data={'success': True}, request={'url': url}), __name__='client_get_mock' - ) - - results = self.client.marketdata_unsubscribe(conids) - - for conid, result in results.items(): - self.assertIn(conid, conids) - self.assertIsInstance(result, Result) - self.assertTrue(result.data['success']) - - def test_marketdata_unsubscribe_with_error(self, requests_mock): - conids = [12345, 67890] - responses = { - 12345: MagicMock(status_code=404), # Simulate not found error for one conid - 67890: MagicMock(status_code=200), - } - requests_mock.request.side_effect = lambda method, url, **kwargs: responses[kwargs['json']['conid']] - self.client.get = MagicMock( - side_effect=lambda url, *args, **kwargs: Result(data={'success': True}, request={'url': url}) - if '67890' in url - else ExternalBrokerError(status_code=404), - __name__='client_get_mock', - ) - - results = self.client.marketdata_unsubscribe(conids) - - self.assertIn(12345, results) - self.assertIn(67890, results) - self.assertTrue(results[67890].data['success']) - - def test_marketdata_unsubscribe_raises_exception_on_failure(self, requests_mock): - conids = [12345] - responses = { - 12345: MagicMock(status_code=500), # Simulate server error + return MagicMock(json=lambda: history_by_conid[conid]) + + +def test_marketdata_history_by_symbols(client, requests_mock): + # Arrange + requests_mock.request.side_effect = _marketdata_request + + queries = [ + StockQuery(symbol='AAPL', contract_conditions={'isUS': False, 'exchange': 'AEQLIT'}, name_match='APPLE'), + StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), + StockQuery(symbol='CDN', contract_conditions={'isUS': False}), + StockQuery(symbol='CFC', contract_conditions={}), + StockQuery(symbol='GOOG', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={'chineseName': 'Alphabet公司'}), + StockQuery(symbol='HUBS'), + StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={}), + StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), + StockQuery(symbol='SAN', name_match='SANTANDER', contract_conditions={'isUS': True}), + StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NYSE'}), + StockQuery(symbol='TEAM', name_match='ATLASSIAN'), + ] + + expected_results = {} + for query in queries: + data = ibkr_responses.responses['history'][query.symbol]['data'][0] + output = { + 'conid': ibkr_responses.responses['filtered_conids'][query.symbol], + 'symbol': query.symbol, + 'open': data['o'], + 'high': data['h'], + 'low': data['l'], + 'close': data['c'], + 'volume': data['v'], + 'date': datetime.datetime.fromtimestamp(data['t'] / 1000, tz=datetime.timezone.utc), } - requests_mock.request.side_effect = lambda method, url, **kwargs: responses[int(url.split('/')[-2])] - self.client.post = MagicMock(side_effect=lambda url, *args, **kwargs: ExternalBrokerError(status_code=500), __name__='client_get_mock') + expected_results[query.symbol] = output + + expected_errors = ['Market data for CDN is not live: Delayed', 'Market data for CFC is not live: Delayed'] + + # Act + with CaptureLogsContext('ibind', level='INFO', logger_level='DEBUG', expected_errors=expected_errors, partial_match=True) as cm: + results = client.marketdata_history_by_symbols(queries) + + # Assert + for msg in expected_errors: + assert msg in cm.output + + for symbol, expected in expected_results.items(): + result = results[symbol][-1] + assert symbol in results + assert result['open'] == pytest.approx(expected['open']) + assert result['high'] == pytest.approx(expected['high']) + assert result['low'] == pytest.approx(expected['low']) + assert result['close'] == pytest.approx(expected['close']) + assert result['volume'] == pytest.approx(expected['volume']) + assert result['date'] == expected['date'] + + +def test_check_health_authenticated_and_connected(client, default_url, requests_mock): + # Arrange + response_data = {'iserver': {'authStatus': {'authenticated': True, 'competing': False, 'connected': True}}} + requests_mock.request.return_value = MagicMock(json=lambda: response_data) + client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) + + # Act + health_status = client.check_health() + + # Assert + assert health_status is True + client.tickle.assert_called_once() + + +def test_check_health_not_authenticated(client, default_url, requests_mock): + # Arrange + response_data = {'iserver': {'authStatus': {'authenticated': False, 'competing': False, 'connected': True}}} + requests_mock.request.return_value = MagicMock(json=lambda: response_data) + client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) + + # Act + health_status = client.check_health() + + # Assert + assert health_status is False + + +def test_check_health_competing_connection(client, default_url, requests_mock): + # Arrange + response_data = {'iserver': {'authStatus': {'authenticated': True, 'competing': True, 'connected': True}}} + requests_mock.request.return_value = MagicMock(json=lambda: response_data) + client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) + + # Act + health_status = client.check_health() + + # Assert + assert health_status is False + + +def test_check_health_connection_error(client, requests_mock): + # Arrange + requests_mock.request.side_effect = ConnectTimeout + client.tickle = MagicMock(side_effect=ConnectTimeout) + + # Act + with CaptureLogsContext( + 'ibind.session_mixin', + level='ERROR', + expected_errors=['ConnectTimeout raised when communicating with the Gateway'], + partial_match=True, + ) as cm: + health_status = client.check_health() + + # Assert + assert health_status is False + assert 'ConnectTimeout raised when communicating with the Gateway' in cm.output[0] + + +def test_check_health_external_broker_error_unauthenticated(client, requests_mock): + # Arrange + requests_mock.request.side_effect = ExternalBrokerError(status_code=401) + client.tickle = MagicMock(side_effect=ExternalBrokerError(status_code=401)) + + # Act + with CaptureLogsContext('ibind.session_mixin', level='INFO', expected_errors=['Gateway session is not authenticated.']) as cm: + health_status = client.check_health() + + # Assert + assert health_status is False + assert 'Gateway session is not authenticated.' in cm.output[0] + + +def test_check_health_invalid_data(client, default_url, requests_mock): + # Arrange + response_data = {} # Invalid data format + requests_mock.request.return_value = MagicMock(json=lambda: response_data) + client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) + + # Act and Assert + with pytest.raises(AttributeError) as excinfo: + client.check_health() + assert 'Health check requests returns invalid data' in str(excinfo.value) + + +def test_marketdata_unsubscribe_success(client, mocker): + # Arrange + conids = [12345, 67890] + + def post_side_effect(url, *args, **kwargs): + conid = kwargs['params']['conid'] + if conid in conids: + return Result(data={'success': True}, request={'url': url}) + raise ExternalBrokerError(status_code=404) + + client.post = MagicMock(side_effect=post_side_effect, __name__='client_post_mock') + + # Act + results = client.marketdata_unsubscribe(conids) + + # Assert + for conid, result in results.items(): + assert int(conid) in conids + assert isinstance(result, Result) + assert result.data['success'] is True + + +def test_marketdata_unsubscribe_with_error(client, mocker): + # Arrange + conids = [12345, 67890] + + def post_side_effect(url, *args, **kwargs): + conid = kwargs['params']['conid'] + if conid == 12345: + raise ExternalBrokerError(status_code=404) + return Result(data={'success': True}, request={'url': url}) + + client.post = MagicMock(side_effect=post_side_effect, __name__='client_post_mock') + + # Act + results = client.marketdata_unsubscribe(conids) + + # Assert + assert 12345 in results + assert 67890 in results + assert results[67890].data['success'] is True + assert isinstance(results[12345], ExternalBrokerError) + + +def test_marketdata_unsubscribe_raises_exception_on_failure(client, mocker): + # Arrange + conids = [12345] + client.post = MagicMock(side_effect=ExternalBrokerError(status_code=500), __name__='client_post_mock') + + # Act + with pytest.raises(ExternalBrokerError) as excinfo: + client.marketdata_unsubscribe(conids) - with self.assertRaises(ExternalBrokerError): - self.client.marketdata_unsubscribe(conids) \ No newline at end of file + # Assert + assert excinfo.value.status_code == 500 \ No newline at end of file diff --git a/test/integration/client/test_ibkr_utils_i.py b/test/integration/client/test_ibkr_utils_i.py index 232e800b..e03569b7 100644 --- a/test/integration/client/test_ibkr_utils_i.py +++ b/test/integration/client/test_ibkr_utils_i.py @@ -1,337 +1,394 @@ from pprint import pformat -from unittest import TestCase -from unittest.mock import MagicMock, patch, call +from unittest.mock import MagicMock, call + +import pytest from ibind.base.rest_client import Result -from ibind.client.ibkr_utils import StockQuery, filter_stocks, find_answer, QuestionType, handle_questions, question_type_to_message_id, OrderRequest, parse_order_request -from ibind.support.logs import project_logger +from ibind.client.ibkr_utils import ( + StockQuery, + filter_stocks, + find_answer, + QuestionType, + handle_questions, + question_type_to_message_id, + OrderRequest, + parse_order_request, +) from test.integration.client import ibkr_responses -from test_utils import verify_log - - -class TestIbkrUtilsI(TestCase): - def setUp(self): - self.instruments = ibkr_responses.responses['stocks'] - self.result = Result(data=self.instruments) - self.maxDiff = None - - def test_filter_stocks(self): - queries = [ - StockQuery(symbol='AAPL', contract_conditions={'isUS': False}, name_match='APPLE'), - StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), - StockQuery(symbol='CDN', contract_conditions={'isUS': True}), - StockQuery(symbol='CFC', contract_conditions={}), - StockQuery(symbol='GOOG', contract_conditions={'isUS': False}, instrument_conditions={'chineseName': 'Alphabet公司'}), - 'HUBS', - StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False}, instrument_conditions={}), - StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), - StockQuery(symbol='SAN', name_match='SANTANDER'), - StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NASDAQ'}), - StockQuery(symbol='TEAM', name_match='ATLASSIAN'), - StockQuery(symbol='INVALID_SYMBOL') - ] # fmt: skip - with self.assertLogs(project_logger(), level='INFO') as cm: - rv = filter_stocks(queries, Result(data=self.instruments), default_filtering=False) - - verify_log( - self, cm, [f'Error getting stocks. Could not find valid instruments INVALID_SYMBOL in result: {self.result}. Skipping query={queries[-1]}.'] - ) # fmt: skip - - # pprint(rv) - - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': '苹果公司', - 'contracts': [ - {'conid': 38708077, 'exchange': 'MEXI', 'isUS': False}, - {'conid': 273982664, 'exchange': 'EBS', 'isUS': False}, - ], - 'name': 'APPLE INC', - }, - { - 'assetClass': 'STK', - 'chineseName': '苹果公司', - 'contracts': [{'conid': 532640894, 'exchange': 'AEQLIT', 'isUS': False}], - 'name': 'APPLE INC-CDR', - }, +from test.test_utils import CaptureLogsContext + + +# -------------------------------------------------------------------------------------- +# Stock filtering +# -------------------------------------------------------------------------------------- + + +@pytest.fixture +def instruments(): + return ibkr_responses.responses['stocks'] + + +@pytest.fixture +def instruments_result(instruments): + return Result(data=instruments) + + +def test_filter_stocks(instruments, instruments_result): + """Filters instruments for multiple stock queries and logs missing symbols.""" + ## Arrange + queries = [ + StockQuery(symbol='AAPL', contract_conditions={'isUS': False}, name_match='APPLE'), + StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), + StockQuery(symbol='CDN', contract_conditions={'isUS': True}), + StockQuery(symbol='CFC', contract_conditions={}), + StockQuery( + symbol='GOOG', + contract_conditions={'isUS': False}, + instrument_conditions={'chineseName': 'Alphabet公司'}, + ), + 'HUBS', + StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False}, instrument_conditions={}), + StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), + StockQuery(symbol='SAN', name_match='SANTANDER'), + StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NASDAQ'}), + StockQuery(symbol='TEAM', name_match='ATLASSIAN'), + StockQuery(symbol='INVALID_SYMBOL'), + ] # fmt: skip + + ## Act + with CaptureLogsContext('ibind', level='INFO', error_level='CRITICAL', attach_stack=False) as cm: + rv = filter_stocks(queries, instruments_result, default_filtering=False) + + ## Assert + expected_error = ( + f'Error getting stocks. Could not find valid instruments INVALID_SYMBOL in result: {instruments_result}. ' + f'Skipping query={queries[-1]}.' + ) + assert expected_error in cm.output + + assert [ + { + 'assetClass': 'STK', + 'chineseName': '苹果公司', + 'contracts': [ + {'conid': 38708077, 'exchange': 'MEXI', 'isUS': False}, + {'conid': 273982664, 'exchange': 'EBS', 'isUS': False}, ], - rv.data['AAPL'], - ) - - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': '西班牙对外银行', - 'contracts': [{'conid': 4815, 'exchange': 'NYSE', 'isUS': True}], - 'name': 'BANCO BILBAO VIZCAYA-SP ADR', - }, + 'name': 'APPLE INC', + }, + { + 'assetClass': 'STK', + 'chineseName': '苹果公司', + 'contracts': [{'conid': 532640894, 'exchange': 'AEQLIT', 'isUS': False}], + 'name': 'APPLE INC-CDR', + }, + ] == rv.data['AAPL'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': '西班牙对外银行', + 'contracts': [{'conid': 4815, 'exchange': 'NYSE', 'isUS': True}], + 'name': 'BANCO BILBAO VIZCAYA-SP ADR', + }, + ] == rv.data['BBVA'] + + assert [] == rv.data['CDN'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': None, + 'contracts': [{'conid': 42001300, 'exchange': 'IBIS', 'isUS': False}], + 'name': 'UET UNITED ELECTRONIC TECHNO', + } + ] == rv.data['CFC'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': 'Alphabet公司', + 'contracts': [ + {'conid': 210810667, 'exchange': 'MEXI', 'isUS': False}, ], - rv.data['BBVA'], - ) + 'name': 'ALPHABET INC-CL C', + }, + { + 'assetClass': 'STK', + 'chineseName': 'Alphabet公司', + 'contracts': [{'conid': 532638805, 'exchange': 'AEQLIT', 'isUS': False}], + 'name': 'ALPHABET INC - CDR', + }, + ] == rv.data['GOOG'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': 'HubSpot公司', + 'contracts': [{'conid': 169544810, 'exchange': 'NYSE', 'isUS': True}], + 'name': 'HUBSPOT INC', + } + ] == rv.data['HUBS'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': 'Meta平台股份有限公司', + 'contracts': [ + {'conid': 114922621, 'exchange': 'MEXI', 'isUS': False}, + ], + 'name': 'META PLATFORMS INC-CLASS A', + }, + { + 'assetClass': 'STK', + 'chineseName': 'Meta平台股份有限公司', + 'contracts': [{'conid': 530091499, 'exchange': 'AEQLIT', 'isUS': False}], + 'name': 'META PLATFORMS INC-CDR', + }, + ] == rv.data['META'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': '微软公司', + 'contracts': [ + {'conid': 272093, 'exchange': 'NASDAQ', 'isUS': True}, + ], + 'name': 'MICROSOFT CORP', + }, + ] == rv.data['MSFT'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': '桑坦德', + 'contracts': [ + {'conid': 38708867, 'exchange': 'MEXI', 'isUS': False}, + {'conid': 385055564, 'exchange': 'WSE', 'isUS': False}, + ], + 'name': 'BANCO SANTANDER SA', + }, + { + 'assetClass': 'STK', + 'chineseName': '桑坦德', + 'contracts': [{'conid': 12442, 'exchange': 'NYSE', 'isUS': True}], + 'name': 'BANCO SANTANDER SA-SPON ADR', + }, + { + 'assetClass': 'STK', + 'chineseName': '桑坦德英国公共有限公司', + 'contracts': [{'conid': 80993135, 'exchange': 'LSE', 'isUS': False}], + 'name': 'SANTANDER UK PLC', + }, + ] == rv.data['SAN'] - self.assertEqual([], rv.data['CDN']) + assert [] == rv.data['SCHW'] - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': None, - 'contracts': [{'conid': 42001300, 'exchange': 'IBIS', 'isUS': False}], - 'name': 'UET UNITED ELECTRONIC TECHNO', - } - ], - rv.data['CFC'], - ) + assert [ + { + 'assetClass': 'STK', + 'chineseName': None, + 'contracts': [{'conid': 589316251, 'exchange': 'NASDAQ', 'isUS': True}], + 'name': 'ATLASSIAN CORP-CL A', + }, + ] == rv.data['TEAM'] - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': 'Alphabet公司', - 'contracts': [ - {'conid': 210810667, 'exchange': 'MEXI', 'isUS': False}, - ], - 'name': 'ALPHABET INC-CL C', - }, - { - 'assetClass': 'STK', - 'chineseName': 'Alphabet公司', - 'contracts': [{'conid': 532638805, 'exchange': 'AEQLIT', 'isUS': False}], - 'name': 'ALPHABET INC - CDR', - }, - ], - rv.data['GOOG'], - ) - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': 'HubSpot公司', - 'contracts': [{'conid': 169544810, 'exchange': 'NYSE', 'isUS': True}], - 'name': 'HUBSPOT INC', - } - ], - rv.data['HUBS'], - ) +def test_question_type_to_message_id_successful(): + """Maps a QuestionType to its expected IBKR message id.""" + ## Arrange + question_type = QuestionType.PRICE_PERCENTAGE_CONSTRAINT - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': 'Meta平台股份有限公司', - 'contracts': [ - {'conid': 114922621, 'exchange': 'MEXI', 'isUS': False}, - ], - 'name': 'META PLATFORMS INC-CLASS A', - }, - { - 'assetClass': 'STK', - 'chineseName': 'Meta平台股份有限公司', - 'contracts': [{'conid': 530091499, 'exchange': 'AEQLIT', 'isUS': False}], - 'name': 'META PLATFORMS INC-CDR', - }, - ], - rv.data['META'], - ) + ## Act + message_id = question_type_to_message_id(question_type) - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': '微软公司', - 'contracts': [ - {'conid': 272093, 'exchange': 'NASDAQ', 'isUS': True}, - ], - 'name': 'MICROSOFT CORP', - }, - ], - rv.data['MSFT'], - ) + ## Assert + assert message_id == 'o163' - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': '桑坦德', - 'contracts': [ - {'conid': 38708867, 'exchange': 'MEXI', 'isUS': False}, - {'conid': 385055564, 'exchange': 'WSE', 'isUS': False}, - ], - 'name': 'BANCO SANTANDER SA', - }, - { - 'assetClass': 'STK', - 'chineseName': '桑坦德', - 'contracts': [{'conid': 12442, 'exchange': 'NYSE', 'isUS': True}], - 'name': 'BANCO SANTANDER SA-SPON ADR', - }, - { - 'assetClass': 'STK', - 'chineseName': '桑坦德英国公共有限公司', - 'contracts': [{'conid': 80993135, 'exchange': 'LSE', 'isUS': False}], - 'name': 'SANTANDER UK PLC', - }, - ], - rv.data['SAN'], - ) - self.assertEqual([], rv.data['SCHW']) +# -------------------------------------------------------------------------------------- +# Finding answers +# -------------------------------------------------------------------------------------- - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': None, - 'contracts': [{'conid': 589316251, 'exchange': 'NASDAQ', 'isUS': True}], - 'name': 'ATLASSIAN CORP-CL A', - }, - ], - rv.data['TEAM'], - ) - def test_question_type_to_message_id_successful(self): - question_type = QuestionType.PRICE_PERCENTAGE_CONSTRAINT - message_id = question_type_to_message_id(question_type) - self.assertEqual(message_id, 'o163') +@pytest.fixture +def answers(): + return {QuestionType.PRICE_PERCENTAGE_CONSTRAINT: True} -class TestFindAnswer(TestCase): - def setUp(self): - # Setup Answers dictionary here - self.answers = {QuestionType.PRICE_PERCENTAGE_CONSTRAINT: True} +def test_valid_question(answers): + """Returns True when a known question type is found in the question string.""" + ## Arrange + question = f'Some {QuestionType.PRICE_PERCENTAGE_CONSTRAINT} specific question' - def test_valid_question(self): - question = f'Some {QuestionType.PRICE_PERCENTAGE_CONSTRAINT} specific question' - answer = find_answer(question, self.answers) - self.assertTrue(answer) + ## Act + answer = find_answer(question, answers) - def test_invalid_question(self): - question = 'Nonexistent question type' - with self.assertRaises(ValueError): - find_answer(question, self.answers) + ## Assert + assert answer is True -class TestHandleQuestionsI(TestCase): - def setUp(self): - self.original_result = Result( - data=[{'id': '12345', 'message': ['price exceeds the Percentage constraint of 3%.']}], request={'url': 'test_url'} - ) - self.answers = {QuestionType.PRICE_PERCENTAGE_CONSTRAINT: True} - self.reply_callback = MagicMock() - - @patch('ibind.client.ibkr_utils.QuestionType') - def test_successful_handling(self, question_type_mock): - # Mocking the QuestionType enum - question_type_mock.PRICE_PERCENTAGE_CONSTRAINT.__str__.return_value = 'price exceeds the Percentage constraint of 3%.' - question_type_mock.ADDITIONAL_QUESTION_TYPE.__str__.return_value = 'This is an additional question.' - - self.answers = {question_type_mock.PRICE_PERCENTAGE_CONSTRAINT: True, question_type_mock.ADDITIONAL_QUESTION_TYPE: True} - - # Mock reply_callback to simulate the sequence of question-answer interactions - replies = [ - Result(data=[{'id': '12346', 'message': ['This is an additional question.']}], request={'url': 'another_question_url'}), - Result(data=[{'id': '12347'}], request={'url': 'final_url'}), # No more questions - ] - self.reply_callback.side_effect = replies - - result = handle_questions(self.original_result, self.answers, self.reply_callback) - self.assertEqual(result.request['url'], self.original_result.request['url']) - self.assertEqual(len(self.reply_callback.call_args_list), 2) - # Expected calls to self.reply_callback - expected_calls = [ - call( - self.original_result.data[0]['id'], self.answers[question_type_mock.PRICE_PERCENTAGE_CONSTRAINT] - ), # First call with question ID '12346' and reply True - call( - replies[0].data[0]['id'], self.answers[question_type_mock.ADDITIONAL_QUESTION_TYPE] - ), # Second call with question ID '12347' and reply True - ] - - # Check if the calls to self.reply_callback are as expected - self.assertEqual(expected_calls, self.reply_callback.call_args_list) - - def test_too_many_questions(self): - # Simulate repetitive questions to exceed the question limit - self.reply_callback.side_effect = [self.original_result] * 21 - - with self.assertRaises(RuntimeError) as cm_err: - handle_questions(self.original_result, self.answers, self.reply_callback) - - self.assertIn('Too many questions', str(cm_err.exception)) - - def test_negative_reply(self): - # Set a negative answer - self.answers[QuestionType.PRICE_PERCENTAGE_CONSTRAINT] = False - - with self.assertRaises(RuntimeError) as cm_err: - handle_questions(self.original_result, self.answers, self.reply_callback) - self.assertEqual( - f'A question was not given a positive reply. Question: "{self.original_result.data[0]["message"][0]}". Answers: \n{self.answers}\n. Request: {self.original_result.request}', - str(cm_err.exception), - ) +def test_invalid_question(answers): + """Raises when no answer matches the provided question string.""" + ## Arrange + question = 'Nonexistent question type' + + ## Act & Assert + with pytest.raises(ValueError): + find_answer(question, answers) + + +# -------------------------------------------------------------------------------------- +# Handling interactive questions +# -------------------------------------------------------------------------------------- + + +@pytest.fixture +def original_result(): + return Result( + data=[{'id': '12345', 'message': ['price exceeds the Percentage constraint of 3%.']}], + request={'url': 'test_url'}, + ) + + +@pytest.fixture +def reply_callback(): + return MagicMock() + - def test_multiple_orders_returned(self): - # Simulate multiple orders in the data - self.original_result.data = [ - {'id': '12345', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT)]}, - {'id': '12346', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT)]}, - ] - self.reply_callback.return_value = self.original_result.copy(data=[{}]) +def test_successful_handling(mocker, original_result, reply_callback): + """Replies to a sequence of questions and returns the final result.""" + ## Arrange + question_type_mock = mocker.patch('ibind.client.ibkr_utils.QuestionType') - with self.assertLogs(project_logger(), level='INFO') as cm: - handle_questions(self.original_result, self.answers, self.reply_callback) + question_type_mock.PRICE_PERCENTAGE_CONSTRAINT.__str__.return_value = 'price exceeds the Percentage constraint of 3%.' + question_type_mock.ADDITIONAL_QUESTION_TYPE.__str__.return_value = 'This is an additional question.' - verify_log(self, cm, ['While handling questions multiple orders were returned: ' + pformat(self.original_result.data)]) + answers = {question_type_mock.PRICE_PERCENTAGE_CONSTRAINT: True, question_type_mock.ADDITIONAL_QUESTION_TYPE: True} - def test_multiple_messages_returned(self): - # Simulate a single order with multiple messages - self.original_result.data = [{'id': '12345', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT), 'Message 2']}] - self.reply_callback.return_value = self.original_result.copy(data=[{}]) + replies = [ + Result(data=[{'id': '12346', 'message': ['This is an additional question.']}], request={'url': 'another_question_url'}), + Result(data=[{'id': '12347'}], request={'url': 'final_url'}), + ] + reply_callback.side_effect = replies - with self.assertLogs(project_logger(), level='INFO') as cm: - handle_questions(self.original_result, self.answers, self.reply_callback) + ## Act + result = handle_questions(original_result, answers, reply_callback) - verify_log(self, cm, ['While handling questions multiple messages were returned: ' + pformat(self.original_result.data[0]['message'])]) + ## Assert + assert result.request['url'] == original_result.request['url'] + assert len(reply_callback.call_args_list) == 2 -class TestParseOrderRequestI(TestCase): - def test_parse_both_with_conidex(self): + expected_calls = [ + call(original_result.data[0]['id'], answers[question_type_mock.PRICE_PERCENTAGE_CONSTRAINT]), + call(replies[0].data[0]['id'], answers[question_type_mock.ADDITIONAL_QUESTION_TYPE]), + ] + + assert expected_calls == reply_callback.call_args_list + + +def test_too_many_questions(original_result, answers, reply_callback): + """Raises when the question loop exceeds the maximum number of attempts.""" + ## Arrange + reply_callback.side_effect = [original_result] * 21 + + ## Act & Assert + with pytest.raises(RuntimeError) as cm_err: + handle_questions(original_result, answers, reply_callback) + + assert 'Too many questions' in str(cm_err.value) + + +def test_negative_reply(original_result, answers, reply_callback): + """Raises when a question is answered negatively.""" + ## Arrange + answers[QuestionType.PRICE_PERCENTAGE_CONSTRAINT] = False + + ## Act & Assert + with pytest.raises(RuntimeError) as cm_err: + handle_questions(original_result, answers, reply_callback) + + assert ( + f'A question was not given a positive reply. Question: "{original_result.data[0]["message"][0]}". Answers: \n{answers}\n. Request: {original_result.request}' + == str(cm_err.value) + ) + + +def test_multiple_orders_returned(original_result, answers, reply_callback): + """Logs a message when multiple orders are returned while handling questions.""" + ## Arrange + original_result.data = [ + {'id': '12345', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT)]}, + {'id': '12346', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT)]}, + ] + reply_callback.return_value = original_result.copy(data=[{}]) + + expected = 'While handling questions multiple orders were returned: ' + pformat(original_result.data) + + ## Act & Assert + with CaptureLogsContext('ibind', level='INFO', expected_errors=[expected], attach_stack=False): + handle_questions(original_result, answers, reply_callback) + + +def test_multiple_messages_returned(original_result, answers, reply_callback): + """Logs a message when multiple messages are returned for a single order.""" + ## Arrange + original_result.data = [{'id': '12345', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT), 'Message 2']}] + reply_callback.return_value = original_result.copy(data=[{}]) + + expected = 'While handling questions multiple messages were returned: ' + pformat(original_result.data[0]['message']) + + ## Act & Assert + with CaptureLogsContext('ibind', level='INFO', expected_errors=[expected], attach_stack=False): + handle_questions(original_result, answers, reply_callback) + + +# -------------------------------------------------------------------------------------- +# Order request parsing +# -------------------------------------------------------------------------------------- + + +def test_parse_both_with_conidex(): + """Parses OrderRequest with conid=None and conidex set into API payload.""" + ## Arrange + order_request = OrderRequest( + conid=None, + side='BUY', + quantity=321, + order_type='MKT', + acct_id='DU1234567', + conidex='33333', + ) + + ## Act + d = parse_order_request(order_request) + + ## Assert + assert { + 'side': 'BUY', + 'quantity': 321, + 'orderType': 'MKT', + 'acctId': 'DU1234567', + 'conidex': '33333', + 'tif': 'GTC', + } == d + + +def test_raise_with_conid_and_conidex(): + """Raises when both conid and conidex are provided.""" + ## Arrange + + ## Act & Assert + with pytest.raises(ValueError) as cm_err: order_request = OrderRequest( - conid=None, + conid=123, side='BUY', quantity=321, order_type='MKT', acct_id='DU1234567', - conidex='33333' # should cause exception + conidex='33333', ) - d = parse_order_request(order_request) - - self.assertEqual({ - 'side': 'BUY', - 'quantity': 321, - 'orderType': 'MKT', - 'acctId': 'DU1234567', - 'conidex': '33333', - 'tif': 'GTC' - }, d) - - def test_raise_with_conid_and_conidex(self): - with self.assertRaises(ValueError) as cm_err: - order_request = OrderRequest( - conid=123, - side='BUY', - quantity=321, - order_type='MKT', - acct_id='DU1234567', - conidex='33333' # should cause exception - ) - - parse_order_request(order_request) - - self.assertEqual("Both 'conidex' and 'conid' are provided. When using 'conidex', specify `conid=None`.", str(cm_err.exception)) - + parse_order_request(order_request) + assert "Both 'conidex' and 'conid' are provided. When using 'conidex', specify `conid=None`." == str(cm_err.value) \ No newline at end of file diff --git a/test/integration/client/test_ibkr_ws_client_i.py b/test/integration/client/test_ibkr_ws_client_i.py index 8fe6aafb..b3cf9c72 100644 --- a/test/integration/client/test_ibkr_ws_client_i.py +++ b/test/integration/client/test_ibkr_ws_client_i.py @@ -1,415 +1,539 @@ import json -import logging from threading import Thread from typing import Optional -from unittest import TestCase -from unittest.mock import MagicMock, patch, call +from unittest.mock import MagicMock, call +import pytest import requests from ibind import Result from ibind.client.ibkr_client import IbkrClient from ibind.client.ibkr_ws_client import IbkrWsClient, IbkrSubscriptionProcessor, IbkrWsKey -from ibind.support.logs import project_logger from test.integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock -from test_utils import RaiseLogsContext, SafeAssertLogs +from test.test_utils import capture_logs + +_URL_WS = 'wss://localhost:5000/v1/api/ws' +_URL_REST = 'https://localhost:5000' +_ACCOUNT_ID = 'TEST_ACCOUNT_ID' +_TIMEOUT_REST = 8 +_MAX_RETRIES_REST = 4 +_MAX_RECONNECT_ATTEMPTS = 4 +_MAX_PING_INTERVAL = 38 +_SUBSCRIPTION_RETRIES = 3 +_CONID = 265598 +_UPDATE_TIME = 5678765456 + + +# -------------------------------------------------------------------------------------- +# Test setup +# -------------------------------------------------------------------------------------- + + +@pytest.fixture +def preprocess_ws_client(): + return IbkrWsClient( + url=_URL_WS, + ibkr_client=None, + account_id=None, + subscription_processor_class=lambda: None, + ) + + +@pytest.fixture +def client_mock(): + client = MagicMock( + spec=IbkrClient( + url=_URL_REST, + account_id=_ACCOUNT_ID, + timeout=_TIMEOUT_REST, + max_retries=_MAX_RETRIES_REST, + ) + ) + client.tickle.return_value.data = {'session': 'TEST_COOKIE'} + return client -class TestPreprocessRawMessage(TestCase): - def setUp(self): - self.url = 'wss://localhost:5000/v1/api/ws' +@pytest.fixture +def ws_client(client_mock): + return IbkrWsClient( + url=_URL_WS, + ibkr_client=client_mock, + account_id=_ACCOUNT_ID, + subscription_processor_class=IbkrSubscriptionProcessor, + subscription_retries=_SUBSCRIPTION_RETRIES, + subscription_timeout=0.01, + cacert=False, + timeout=0.01, + max_connection_attempts=_MAX_RECONNECT_ATTEMPTS, + max_ping_interval=_MAX_PING_INTERVAL, + ) - self.ws_client = IbkrWsClient( - url=self.url, - ibkr_client=None, - account_id=None, - subscription_processor_class=lambda: None, - ) - def test_preprocess_with_well_formed_message(self): - raw_message = json.dumps({'topic': 'actABC', 'args': {'key': 'value'}}) - expected_result = ( - {'topic': 'actABC', 'args': {'key': 'value'}}, # message - 'actABC', # topic - {'key': 'value'}, # data - 'a', # subscribed - 'ctABC', # channel - ) - self.assertEqual(self.ws_client._preprocess_raw_message(raw_message), expected_result) - - def test_preprocess_with_unsubscribed_message(self): - raw_message = json.dumps({'message': 'Unsubscribed'}) - expected_result = ({'message': 'Unsubscribed'}, None, None, None, None) - self.assertEqual(self.ws_client._preprocess_raw_message(raw_message), expected_result) - - -class TestIbkrWsClient(TestCase): - # Assuming IbkrWsClient is the class containing preprocess_raw_message - - def setUp(self): - # Assuming similar initialization parameters as in WsClient - self.url = 'wss://localhost:5000/v1/api/ws' - self.max_reconnect_attempts = 4 - self.max_ping_interval = 38 - - self.url_rest = 'https://localhost:5000' - self.account_id = 'TEST_ACCOUNT_ID' - self.timeout = 8 - self.max_retries = 4 - self.subscription_retries = 3 - self.client = MagicMock( - spec=IbkrClient( - url=self.url_rest, - account_id=self.account_id, - timeout=self.timeout, - max_retries=self.max_retries, - ) - ) - self.client.tickle.return_value.data = {'session': 'TEST_COOKIE'} - - self.SubscriptionProcessorClass = IbkrSubscriptionProcessor - - # Initialize the IbkrWsClient - self.ws_client = IbkrWsClient( - url=self.url, - ibkr_client=self.client, - account_id=self.account_id, - subscription_processor_class=self.SubscriptionProcessorClass, - subscription_retries=self.subscription_retries, - subscription_timeout=0.01, - cacert=False, - timeout=0.01, - max_connection_attempts=self.max_reconnect_attempts, - max_ping_interval=self.max_ping_interval, - ) +@pytest.fixture +def wsa_mock(): + return create_wsa_mock() - self.wsa_mock = create_wsa_mock() - self.thread_mock = MagicMock(spec=Thread) - self.thread_mock.start.side_effect = lambda: self.ws_client._run_websocket(self.wsa_mock) - - self.conid = 265598 - self.update_time = 5678765456 - - def run_in_test_context(self, fn, expected_errors: list[str] = None, expect_logs: bool = True): - with patch('ibind.base.ws_client.WebSocketApp', side_effect=lambda *args, **kwargs: init_wsa_mock(self.wsa_mock, *args, **kwargs)), \ - patch('ibind.base.ws_client.Thread', return_value=self.thread_mock) as new_thread_mock, \ - SafeAssertLogs(self, 'ibind', level='DEBUG', logger_level='DEBUG', no_logs=not expect_logs) as cm, \ - RaiseLogsContext(self, 'ibind', level='WARNING', expected_errors=expected_errors): # fmt: skip - ws_client_logger = project_logger('ws_client') - old_level = ws_client_logger.getEffectiveLevel() - ws_client_logger.setLevel(logging.WARNING) - - self.new_thread_mock = new_thread_mock - try: - rv = fn() - except: - raise - finally: - ws_client_logger.setLevel(old_level) - - return cm, rv - - def _send_payload(self, payload: dict, expected_errors: list[str] = None, expect_logs: bool = True): - def run(): - success = self.ws_client.start() - raw_payload = json.dumps(payload) - self.ws_client.send(raw_payload) - self.ws_client.shutdown() - return success - - return self.run_in_test_context(run, expected_errors=expected_errors, expect_logs=expect_logs) - - def _subscribe(self, request: dict, response: Optional[dict], expected_errors: list[str] = None, expect_logs: bool = True): - def run(): - def override_on_message(wsa_mock: MagicMock, message: str): - if response is None: - return - raw_message = json.dumps(response) - wsa_mock.__on_message__(wsa_mock, raw_message) - - self.ws_client.start() - self.wsa_mock._on_message.side_effect = override_on_message - rv = self.ws_client.subscribe( - **{'channel': request.get('channel'), 'data': request.get('data'), 'needs_confirmation': request.get('needs_confirmation')} - ) - self.ws_client.unsubscribe( - **{'channel': request.get('channel'), 'data': request.get('data'), 'needs_confirmation': request.get('confirms_unsubscription')} - ) - self.ws_client.shutdown() - return rv - - return self.run_in_test_context(run, expected_errors=expected_errors, expect_logs=expect_logs) - - def test_on_message_system_heartbeat(self): - hb = 12345678 - cm, success = self._send_payload({'topic': 'system', 'hb': hb}, expect_logs=False) - # print("\n".join([r.msg for r in cm.records])) - self.assertEqual(self.ws_client._last_heartbeat, hb) - - def test_on_message_act_account_mismatch(self): - message_data = {'topic': 'act', 'args': {'accounts': ['OTHER_ACCOUNT_ID']}} - expected_errors = ["IbkrWsClient: Account ID mismatch: expected=TEST_ACCOUNT_ID, received=['OTHER_ACCOUNT_ID']"] - - cm, success = self._send_payload(message_data, expected_errors=expected_errors) - self.assertEqual(expected_errors, [r.msg for r in cm.records]) - - def test_on_message_blt(self): - bulletin_message = {'topic': 'blt', 'args': {'bulletin_key': 'some_info'}} - - with patch.object(self.ws_client, '_handle_bulletin', MagicMock()) as mock_handle_bulletin: - cm, success = self._send_payload(bulletin_message, expect_logs=False) - mock_handle_bulletin.assert_called_once_with(bulletin_message) - - def test_on_message_sts_unauthenticated(self): - message_data = {'topic': 'sts', 'args': {'authenticated': False}} - session_id = 6545676 - - expected_errors = ["IbkrWsClient: Status unauthenticated: {'authenticated': False}", 'IbkrWsClient: Not authenticated, closing WebSocketApp'] - - response_mock = MagicMock(spec=requests.Response) - response_mock.status_code = 200 - response_mock.json.return_value = {'session': session_id, 'data_to_be_ignored': '1234'} - - self.client.tickle.return_value = Result(data=response_mock.json.return_value) - - with patch('ibind.base.rest_client.requests') as requests_mock: - requests_mock.request.return_value = response_mock - cm, success = self._send_payload(message_data, expected_errors=expected_errors) - - self.assertEqual(expected_errors, [r.msg for r in cm.records]) - self.assertFalse(self.ws_client._authenticated) - - def test_on_message_sts_authenticated(self): - message_data = {'topic': 'sts', 'args': {'authenticated': True}} - cm, success = self._send_payload(message_data, expect_logs=False) - - def test_on_message_error(self): - message_data = {'topic': 'error', 'args': {'error_key': 'error_details'}} - expected_errors = [f'IbkrWsClient: Error message: {message_data}'] - - cm, success = self._send_payload(message_data, expected_errors=expected_errors) - self.assertEqual(expected_errors, [r.msg for r in cm.records]) - - def test_on_message_no_topic_handler(self): - message_data = {'topic': 'unrecognized_topic', 'args': {'some_key': 'some_value'}} - expected_errors = [f'IbkrWsClient: Topic "{message_data["topic"]}" unrecognised. Message: {message_data}'] - - cm, success = self._send_payload(message_data, expected_errors=expected_errors) - self.assertEqual(expected_errors, [r.msg for r in cm.records]) - - def test_on_message_handled_without_subscription(self): - message_data = {'topic': 'some_topic', 'args': {'channel': 'XYZ', 'data': 'info'}} - expected_errors = [ - f'IbkrWsClient: Handled a channel "{message_data["topic"][1:]}" message that is missing a subscription. Message: {message_data}' - ] - with patch.object(self.ws_client, '_handle_subscribed_message', return_value=True): - cm, success = self._send_payload(message_data, expected_errors=expected_errors) +@pytest.fixture +def thread_mock(ws_client, wsa_mock): + thread_mock = MagicMock(spec=Thread) + thread_mock.start.side_effect = lambda: ws_client._run_websocket(wsa_mock) + return thread_mock - self.assertEqual(expected_errors, [r.msg for r in cm.records]) - def _logs_subscriptions(self, full_channel, data=None, needs_confirmation_sub: bool = False, needs_confirmation_unsub: bool = True): - return [ - f'IbkrWsClient: Subscribed: s{full_channel}{"" if data is None else f"+{json.dumps(data)}"}{"" if not needs_confirmation_sub else " without confirmation."}', - f'IbkrWsClient: Unsubscribed: u{full_channel}+{json.dumps(data if data is not None else {})}{"" if not needs_confirmation_unsub else " without confirmation."}', - ] +@pytest.fixture +def ws_app_factory(wsa_mock): + # Use a mutable side-effect so individual tests can temporarily override WebSocketApp behavior. + return { + 'fn': lambda *args, **kwargs: init_wsa_mock(wsa_mock, *args, **kwargs), + } - def test_on_message_market_data_channel_handling(self): - queue = self.ws_client.new_queue_accessor(IbkrWsKey.MARKET_DATA) - full_channel = f'{queue.key.channel}+{self.conid}' - request = {'channel': f'{full_channel}', 'data': {'fields': ['55', '71', '84', '86', '88', '85', '87', '7295', '7296', '70']}} - response = { - 'topic': f's{full_channel}', - 'conid': self.conid, - '_updated': self.update_time, - 55: 'AAPL', - 70: '195.34', - 71: '193.67', - 87: '24.2M', - 7295: '194.10', - 84: '195.25', - 86: '195.26', - 88: '3,500', - 85: '500', - 6508: '&serviceID1=122&serviceID2=123&serviceID3=203&serviceID4=775&serviceID5=204&serviceID6=206&serviceID7=108&serviceID8=109', - } - self.assertTrue(queue.empty(), 'Queue should be empty') - - with patch.object(self.ws_client, 'has_subscription', return_value=True): - cm, success = self._subscribe(request, response) - self.assertTrue(success) - - self.assertEqual(self._logs_subscriptions(full_channel, request['data']), [r.msg for r in cm.records]) - - self.assertEqual( - { - self.conid: { - '_updated': self.update_time, - 'conid': self.conid, - 'topic': f'smd+{self.conid}', - 'ask_price': '195.26', - 'ask_size': '500', - 'bid_price': '195.25', - 'bid_size': '3,500', - 'high': '195.34', - 'low': '193.67', - 'open': '194.10', - 'service_params': '&serviceID1=122&serviceID2=123&serviceID3=203&serviceID4=775&serviceID5=204&serviceID6=206&serviceID7=108&serviceID8=109', - 'symbol': 'AAPL', - 'volume': '24.2M', - } - }, - queue.get(), - ) +@pytest.fixture +def patched_constructors(mocker, thread_mock, ws_app_factory): + mocker.patch('ibind.base.ws_client.WebSocketApp', side_effect=lambda *args, **kwargs: ws_app_factory['fn'](*args, **kwargs)) + mocker.patch('ibind.base.ws_client.Thread', return_value=thread_mock) + return None + + + +def _send_payload(ws_client, payload: dict): + success = ws_client.start() + ws_client.send(json.dumps(payload)) + ws_client.shutdown() + return success + - def test_on_message_market_history_channel_handling(self): - queue = self.ws_client.new_queue_accessor(IbkrWsKey.MARKET_HISTORY) - server_id = 87567 - full_channel = f'{queue.key.channel}+{self.conid}' - request = { - 'channel': f'{full_channel}', - 'data': {'period': '1min', 'bar': '1min', 'outsideRTH': True, 'source': 'trades', 'format': '%o/%c/%h/%l'}, - 'confirms_unsubscription': False, +def _subscribe(ws_client, wsa_mock, request: dict, response: Optional[dict]): + def override_on_message(wsa_mock: MagicMock, message: str): + if response is None: + return + raw_message = json.dumps(response) + wsa_mock.__on_message__(wsa_mock, raw_message) + + ws_client.start() + wsa_mock._on_message.side_effect = override_on_message + + rv = ws_client.subscribe( + **{ + 'channel': request.get('channel'), + 'data': request.get('data'), + 'needs_confirmation': request.get('needs_confirmation'), } - response = {'topic': f's{full_channel}', 'serverId': server_id, '_updated': self.update_time, 'conid': self.conid, 'foo': 'bar'} + ) + ws_client.unsubscribe( + **{ + 'channel': request.get('channel'), + 'data': request.get('data'), + 'needs_confirmation': request.get('confirms_unsubscription'), + } + ) + ws_client.shutdown() + return rv - self.assertTrue(queue.empty(), 'Queue should be empty') - with patch.object(self.ws_client, 'has_subscription', return_value=True): - cm, success = self._subscribe(request, response) - self.assertTrue(success) - self.assertEqual(self._logs_subscriptions(full_channel, request['data']), [r.msg for r in cm.records]) +def _logs_subscriptions(full_channel, data=None, needs_confirmation_sub: bool = False, needs_confirmation_unsub: bool = True): + return [ + f'IbkrWsClient: Subscribed: s{full_channel}{"" if data is None else f"+{json.dumps(data)}"}{"" if not needs_confirmation_sub else " without confirmation."}', + f'IbkrWsClient: Unsubscribed: u{full_channel}+{json.dumps(data if data is not None else {})}{"" if not needs_confirmation_unsub else " without confirmation."}', + ] - self.assertEqual(response, queue.get()) - self.assertIn(server_id, self.ws_client.server_ids(IbkrWsKey.MARKET_HISTORY)) - def test_on_message_trade_channel_handling(self): - queue = self.ws_client.new_queue_accessor(IbkrWsKey.TRADES) - full_channel = f'{queue.key.channel}+{self.conid}' - request = {'channel': f'{full_channel}'} - response = {'topic': f's{full_channel}', '_updated': self.update_time, 'conid': self.conid, 'args': [{'foo': 'bar'}]} +# -------------------------------------------------------------------------------------- +# Message preprocessing +# -------------------------------------------------------------------------------------- - self.assertTrue(queue.empty(), 'Queue should be empty') - with patch.object(self.ws_client, 'has_subscription', return_value=True): - cm, success = self._subscribe(request, response) - self.assertTrue(success) +def test_preprocess_with_well_formed_message(preprocess_ws_client): + """Preprocesses a well-formed raw message into (message, topic, data, subscribed, channel).""" + ## Arrange + raw_message = json.dumps({'topic': 'actABC', 'args': {'key': 'value'}}) + expected_result = ( + {'topic': 'actABC', 'args': {'key': 'value'}}, # message + 'actABC', # topic + {'key': 'value'}, # data + 'a', # subscribed + 'ctABC', # channel + ) - self.assertEqual(self._logs_subscriptions(full_channel), [r.msg for r in cm.records]) - self.assertEqual(response, queue.get()) + ## Act + rv = preprocess_ws_client._preprocess_raw_message(raw_message) - def test_on_message_orders_channel_handling(self): - queue = self.ws_client.new_queue_accessor(IbkrWsKey.ORDERS) + ## Assert + assert rv == expected_result - full_channel = f'{queue.key.channel}+{self.conid}' - request = {'channel': f'{full_channel}'} - response = {'topic': f's{full_channel}', '_updated': self.update_time, 'conid': self.conid, 'args': [{'foo': 'bar'}]} - self.assertTrue(queue.empty(), 'Queue should be empty') +def test_preprocess_with_unsubscribed_message(preprocess_ws_client): + """Returns empty preprocess result for unsubscribed messages.""" + ## Arrange + raw_message = json.dumps({'message': 'Unsubscribed'}) - with patch.object(self.ws_client, 'has_subscription', return_value=True): - cm, success = self._subscribe(request, response) - self.assertTrue(success) + ## Act + rv = preprocess_ws_client._preprocess_raw_message(raw_message) - self.assertEqual(self._logs_subscriptions(full_channel, None, True, True), [r.msg for r in cm.records]) - self.assertEqual(response, queue.get()) + ## Assert + assert rv == ({'message': 'Unsubscribed'}, None, None, None, None) - def test_subscription_without_confirmation(self): - channel = 'fake' - full_channel = f'{channel}+{self.conid}' - request = {'channel': f'{full_channel}', 'needs_confirmation': False, 'confirms_unsubscription': False} - response = None - expected_errors = [f'IbkrWsClient: Channel subscription timeout: s{full_channel} after {self.subscription_retries} attempts.'] +# -------------------------------------------------------------------------------------- +# On-message handling +# -------------------------------------------------------------------------------------- - with patch.object(self.ws_client, 'has_subscription', return_value=True): - cm, success = self._subscribe(request, response, expected_errors=expected_errors) - self.assertTrue(success) - self.assertEqual( - [ - f'IbkrWsClient: Subscribed: s{full_channel} without confirmation.', - f'IbkrWsClient: Unsubscribed: u{full_channel}+{{}} without confirmation.', - ], - [r.msg for r in cm.records], - ) +@capture_logs(logger_level='DEBUG') +def test_on_message_system_heartbeat(ws_client, patched_constructors): + """Updates last heartbeat on system heartbeat message.""" + ## Arrange + hb = 12345678 - def test_check_health(self): - start_time = [100] - has_active_connection_counter = [0] - - # control time - def fake_time(): - start_time[0] += 100 - return start_time[0] - - # simulate that we don't have ws connection first - def has_active_connection(): - has_active_connection_counter[0] += 1 - if has_active_connection_counter[0] <= 2: - return False - return True - - # prepare a fake subscription - queue = self.ws_client.new_queue_accessor(IbkrWsKey.TRADES) - full_channel = f'{queue.key.channel}+{self.conid}' - request = {'channel': f'{full_channel}', 'data': {'foo': 'bar'}} - response = {'topic': f's{full_channel}', '_updated': self.update_time, 'conid': self.conid, 'args': [{'foo': 'bar'}]} - - def run(): - # ensures each time WebSocketApp's mock is created, we override its on_message method - def override_init_wsa_mock(wsa_mock: MagicMock, *args, **kwargs): - wsa_mock = init_wsa_mock(wsa_mock, *args, **kwargs) - wsa_mock._on_message.side_effect = lambda wsa_mock, message: wsa_mock.__on_message__(wsa_mock, json.dumps(response)) - return wsa_mock - - self.ws_client.start() - self.ws_client.check_health() - self.wsa_mock._on_message.side_effect = lambda wsa_mock, message: wsa_mock.__on_message__(wsa_mock, json.dumps(response)) - - # create the original subscription - self.ws_client.subscribe(**request) - - # we simulate that closing the WebSocket doesn't work since we have connectivity issues - # self.wsa_mock.on_close.side_effect = lambda x, y, z: None - - # override time.time, ignore check_ping and take control of has_active_connection - with patch('ibind.client.ibkr_ws_client.time') as time_mock, \ - patch.object(self.ws_client, 'check_ping', return_value=True), \ - patch('ibind.base.ws_client.WebSocketApp', side_effect=lambda *args, **kwargs: override_init_wsa_mock(self.wsa_mock, *args, **kwargs)), \ - patch.object(self.ws_client, '_has_active_connection', side_effect=has_active_connection) as has_active_connection_mock: # fmt: skip - time_mock.time.side_effect = fake_time - self.ws_client._last_heartbeat = self.max_ping_interval * 1000 - - # this should try to close the connection, fail to do so, abandon the WebSocketApp's mock, - # then recreate a new mock and recreate the connections - self.ws_client.check_health() - - self.assertTrue(self.ws_client.ready()) - self.assertEqual([call()] * 6, has_active_connection_mock.call_args_list) - self.ws_client.shutdown() - - expected_errors = [ - f'IbkrWsClient: Last IBKR heartbeat happened 162.00 seconds ago, exceeding the max ping interval of {self.max_ping_interval}. Restarting.', - # 'IbkrWsClient: Hard reset close timeout', - # f'IbkrWsClient: Abandoning current WebSocketApp that cannot be closed: {self.wsa_mock}' - ] + ## Act + _send_payload(ws_client, {'topic': 'system', 'hb': hb}) - cm, success = self.run_in_test_context(run, expected_errors=expected_errors) + ## Assert + assert ws_client._last_heartbeat == hb - channel_subscribed_log = f'IbkrWsClient: Subscribed: s{full_channel}+{json.dumps(request["data"])}' +@capture_logs(logger_level='DEBUG', expected_errors = ["IbkrWsClient: Account ID mismatch: expected=TEST_ACCOUNT_ID, received=['OTHER_ACCOUNT_ID']"]) +def test_on_message_act_account_mismatch(ws_client, patched_constructors): + """Logs a warning when account list in act message mismatches expected account.""" + ## Act + _send_payload(ws_client, {'topic': 'act', 'args': {'accounts': ['OTHER_ACCOUNT_ID']}}) - self.assertEqual( - [channel_subscribed_log] - + expected_errors - + [ - f'IbkrWsClient: Invalidated subscription: {full_channel}', - f"IbkrWsClient: Recreating 1/1 subscriptions: {{'{full_channel}': {{'status': False, 'data': {request['data']}, 'needs_confirmation': True, 'subscription_processor': None}}}}", - channel_subscribed_log, - f'IbkrWsClient: Invalidated subscription: {full_channel}', - ], - [r.msg for r in cm.records], - ) + +@capture_logs(logger_level='DEBUG') +def test_on_message_blt(ws_client, patched_constructors, mocker): + """Dispatches bulletin messages to _handle_bulletin.""" + ## Arrange + bulletin_message = {'topic': 'blt', 'args': {'bulletin_key': 'some_info'}} + mock_handle_bulletin = mocker.patch.object(ws_client, '_handle_bulletin', MagicMock()) + + ## Act + _send_payload(ws_client, bulletin_message) + + ## Assert + mock_handle_bulletin.assert_called_once_with(bulletin_message) + +@capture_logs(logger_level='DEBUG', expected_errors=[ + "IbkrWsClient: Status unauthenticated: {'authenticated': False}", + 'IbkrWsClient: Not authenticated, closing WebSocketApp', +]) +def test_on_message_sts_unauthenticated(ws_client, client_mock, patched_constructors, mocker): + """On unauthenticated status, refetches session and closes websocket.""" + ## Arrange + message_data = {'topic': 'sts', 'args': {'authenticated': False}} + session_id = 6545676 + + response_mock = MagicMock(spec=requests.Response) + response_mock.status_code = 200 + response_mock.json.return_value = {'session': session_id, 'data_to_be_ignored': '1234'} + + client_mock.tickle.return_value = Result(data=response_mock.json.return_value) + + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = response_mock + + ## Act + _send_payload(ws_client, message_data) + + ## Assert + assert ws_client._authenticated is False + +@capture_logs(logger_level='DEBUG') +def test_on_message_sts_authenticated(ws_client, patched_constructors): + """Accepts authenticated status without logging warnings.""" + ## Act + _send_payload(ws_client, {'topic': 'sts', 'args': {'authenticated': True}}) + + +@capture_logs(logger_level='DEBUG', expected_errors = [f'IbkrWsClient: Error message:'], partial_match=True) +def test_on_message_error(ws_client, patched_constructors): + """Logs error-topic messages as warnings.""" + ## Act + _send_payload(ws_client, {'topic': 'error', 'args': {'error_key': 'error_details'}}) + + + +@capture_logs(logger_level='DEBUG', expected_errors=['unrecognised. Message:'], partial_match=True) +def test_on_message_no_topic_handler(ws_client, patched_constructors): + """Logs a warning when no handler exists for a topic.""" + ## Arrange + message_data = {'topic': 'unrecognized_topic', 'args': {'some_key': 'some_value'}} + + ## Act + _send_payload(ws_client, message_data) + + +@capture_logs(logger_level='DEBUG', expected_errors = [ + 'message that is missing a subscription. Message:' +], partial_match=True) +def test_on_message_handled_without_subscription(ws_client, patched_constructors, mocker): + """Logs a warning if a subscribed message arrives without a known subscription.""" + ## Arrange + mocker.patch.object(ws_client, '_handle_subscribed_message', return_value=True) + + ## Act + _send_payload(ws_client, {'topic': 'some_topic', 'args': {'channel': 'XYZ', 'data': 'info'}}) + + + +# -------------------------------------------------------------------------------------- +# Subscription + channel-specific handling +# -------------------------------------------------------------------------------------- + + +@capture_logs(logger_level='DEBUG') +def test_on_message_market_data_channel_handling(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Routes market data updates into the MARKET_DATA queue.""" + ## Arrange + cm = kwargs['_cm_ibind'] + queue = ws_client.new_queue_accessor(IbkrWsKey.MARKET_DATA) + full_channel = f'{queue.key.channel}+{_CONID}' + request = { + 'channel': f'{full_channel}', + 'data': {'fields': ['55', '71', '84', '86', '88', '85', '87', '7295', '7296', '70']}, + } + response = { + 'topic': f's{full_channel}', + 'conid': _CONID, + '_updated': _UPDATE_TIME, + 55: 'AAPL', + 70: '195.34', + 71: '193.67', + 87: '24.2M', + 7295: '194.10', + 84: '195.25', + 86: '195.26', + 88: '3,500', + 85: '500', + 6508: '&serviceID1=122&serviceID2=123&serviceID3=203&serviceID4=775&serviceID5=204&serviceID6=206&serviceID7=108&serviceID8=109', + } + + assert queue.empty() is True + + mocker.patch.object(ws_client, 'has_subscription', return_value=True) + + ## Act + success = _subscribe(ws_client, wsa_mock, request, response) + + ## Assert + assert success is True + cm.partial_log(_logs_subscriptions(full_channel, request['data'])) + assert ( + { + _CONID: { + '_updated': _UPDATE_TIME, + 'conid': _CONID, + 'topic': f'smd+{_CONID}', + 'ask_price': '195.26', + 'ask_size': '500', + 'bid_price': '195.25', + 'bid_size': '3,500', + 'high': '195.34', + 'low': '193.67', + 'open': '194.10', + 'service_params': '&serviceID1=122&serviceID2=123&serviceID3=203&serviceID4=775&serviceID5=204&serviceID6=206&serviceID7=108&serviceID8=109', + 'symbol': 'AAPL', + 'volume': '24.2M', + } + } + == queue.get() + ) + + +@capture_logs(logger_level='DEBUG') +def test_on_message_market_history_channel_handling(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Routes market history updates into the MARKET_HISTORY queue and tracks server IDs.""" + ## Arrange + cm = kwargs['_cm_ibind'] + queue = ws_client.new_queue_accessor(IbkrWsKey.MARKET_HISTORY) + server_id = 87567 + full_channel = f'{queue.key.channel}+{_CONID}' + request = { + 'channel': f'{full_channel}', + 'data': {'period': '1min', 'bar': '1min', 'outsideRTH': True, 'source': 'trades', 'format': '%o/%c/%h/%l'}, + 'confirms_unsubscription': False, + } + response = { + 'topic': f's{full_channel}', + 'serverId': server_id, + '_updated': _UPDATE_TIME, + 'conid': _CONID, + 'foo': 'bar', + } + + assert queue.empty() is True + + mocker.patch.object(ws_client, 'has_subscription', return_value=True) + + ## Act + success = _subscribe(ws_client, wsa_mock, request, response) + + ## Assert + assert success is True + cm.partial_log(_logs_subscriptions(full_channel, request['data'])) + assert response == queue.get() + assert server_id in ws_client.server_ids(IbkrWsKey.MARKET_HISTORY) + + +@capture_logs(logger_level='DEBUG') +def test_on_message_trade_channel_handling(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Routes trade updates into the TRADES queue.""" + ## Arrange + cm = kwargs['_cm_ibind'] + queue = ws_client.new_queue_accessor(IbkrWsKey.TRADES) + full_channel = f'{queue.key.channel}+{_CONID}' + request = {'channel': f'{full_channel}'} + response = { + 'topic': f's{full_channel}', + '_updated': _UPDATE_TIME, + 'conid': _CONID, + 'args': [{'foo': 'bar'}], + } + + assert queue.empty() is True + + mocker.patch.object(ws_client, 'has_subscription', return_value=True) + + ## Act + success = _subscribe(ws_client, wsa_mock, request, response) + + ## Assert + assert success is True + cm.partial_log(_logs_subscriptions(full_channel)) + assert response == queue.get() + + +@capture_logs(logger_level='DEBUG') +def test_on_message_orders_channel_handling(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Routes order updates into the ORDERS queue.""" + ## Arrange + cm = kwargs['_cm_ibind'] + queue = ws_client.new_queue_accessor(IbkrWsKey.ORDERS) + full_channel = f'{queue.key.channel}+{_CONID}' + request = {'channel': f'{full_channel}'} + response = { + 'topic': f's{full_channel}', + '_updated': _UPDATE_TIME, + 'conid': _CONID, + 'args': [{'foo': 'bar'}], + } + + assert queue.empty() is True + + mocker.patch.object(ws_client, 'has_subscription', return_value=True) + + ## Act + success = _subscribe(ws_client, wsa_mock, request, response) + + ## Assert + assert success is True + cm.partial_log(_logs_subscriptions(full_channel, None, True, True)) + assert response == queue.get() + + +@capture_logs(logger_level='DEBUG') +def test_subscription_without_confirmation(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Subscribes/unsubscribes without confirmation when requested.""" + ## Arrange + cm = kwargs['_cm_ibind'] + channel = 'fake' + full_channel = f'{channel}+{_CONID}' + request = { + 'channel': f'{full_channel}', + 'needs_confirmation': False, + 'confirms_unsubscription': False, + } + response = None + + mocker.patch.object(ws_client, 'has_subscription', return_value=True) + + ## Act + success = _subscribe(ws_client, wsa_mock, request, response) + + ## Assert + assert success is True + cm.partial_log([ + f'IbkrWsClient: Subscribed: s{full_channel} without confirmation.', + f'IbkrWsClient: Unsubscribed: u{full_channel}+{{}} without confirmation.', + ]) + + + +# -------------------------------------------------------------------------------------- +# Health checks +# -------------------------------------------------------------------------------------- + + +@capture_logs(logger_level='DEBUG', expected_errors=[ + f'IbkrWsClient: Last IBKR heartbeat happened 162.00 seconds ago, exceeding the max ping interval of {_MAX_PING_INTERVAL}. Restarting.', +]) +def test_check_health(ws_client, wsa_mock, ws_app_factory, patched_constructors, mocker, **kwargs): + """Restarts and recreates subscriptions when heartbeat exceeds max ping interval.""" + ## Arrange + cm = kwargs['_cm_ibind'] + start_time = [100] + has_active_connection_counter = [0] + + def fake_time(): + start_time[0] += 100 + return start_time[0] + + def has_active_connection(): + has_active_connection_counter[0] += 1 + if has_active_connection_counter[0] <= 2: + return False + return True + + queue = ws_client.new_queue_accessor(IbkrWsKey.TRADES) + full_channel = f'{queue.key.channel}+{_CONID}' + request = {'channel': f'{full_channel}', 'data': {'foo': 'bar'}} + response = { + 'topic': f's{full_channel}', + '_updated': _UPDATE_TIME, + 'conid': _CONID, + 'args': [{'foo': 'bar'}], + } + + ## Act + def override_init_wsa_mock(wsa_mock: MagicMock, *args, **kwargs): + wsa_mock = init_wsa_mock(wsa_mock, *args, **kwargs) + wsa_mock._on_message.side_effect = lambda wsa_mock, message: wsa_mock.__on_message__(wsa_mock, json.dumps(response)) + return wsa_mock + + ws_client.start() + ws_client.check_health() + wsa_mock._on_message.side_effect = lambda wsa_mock, message: wsa_mock.__on_message__(wsa_mock, json.dumps(response)) + + ws_client.subscribe(**request) + + # Override time, ignore ping check, and control active-connection health checks. + time_mock = mocker.patch('ibind.client.ibkr_ws_client.time') + time_mock.time.side_effect = fake_time + + mocker.patch.object(ws_client, 'check_ping', return_value=True) + mocker.patch.object(ws_client, '_has_active_connection', side_effect=has_active_connection) + + # Ensure each reconnect creates a WebSocketApp whose on_message pushes our fake response. + ws_app_factory['fn'] = lambda *args, **kwargs: override_init_wsa_mock(wsa_mock, *args, **kwargs) + + ws_client._last_heartbeat = _MAX_PING_INTERVAL * 1000 + ws_client.check_health() + + assert ws_client.ready() is True + assert [call()] * 6 == ws_client._has_active_connection.call_args_list + + ws_client.shutdown() + + + ## Assert + channel_subscribed_log = f'IbkrWsClient: Subscribed: s{full_channel}+{json.dumps(request["data"])}' + cm.partial_log( + [channel_subscribed_log] + + [ + f'IbkrWsClient: Invalidated subscription: {full_channel}', + f"IbkrWsClient: Recreating 1/1 subscriptions: {{'{full_channel}': {{'status': False, 'data': {request['data']}, 'needs_confirmation': True, 'subscription_processor': None}}}}", + channel_subscribed_log, + f'IbkrWsClient: Invalidated subscription: {full_channel}', + ] + ) \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py index d8523004..e33822c2 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,284 +1,476 @@ import functools +import inspect import logging -import sys import traceback -import types -import unittest -from unittest import TestCase -from unittest._log import _CapturingHandler, _AssertLogsContext +from typing import List, Union -from ibind.support.py_utils import make_clean_stack +from ibind.support.logs import get_logger_children +from ibind.support.py_utils import make_clean_stack, OneOrMany, UNDEFINED -def raise_from_context(cm, level='WARNING'): - for record in cm.records: - if record.levelno >= getattr(logging, level): - raise RuntimeError(record.message) +def _accepts_kwargs(func): + """ + Check if a function accepts **kwargs. + Args: + func: A callable to inspect. -def verify_log(test_case: TestCase, cm, expected_messages, comparison: callable = lambda x, y: x == y): - messages = [record.msg for record in cm.records] - missing_expected = expected_messages.copy() - for i, expected_msg in enumerate(expected_messages): - for msg in messages: - if comparison(expected_msg, msg): - missing_expected.remove(expected_msg) - break + Returns: + bool: True if the function accepts **kwargs, False otherwise. + """ + sig = inspect.signature(func) + for param in sig.parameters.values(): + if param.kind == inspect.Parameter.VAR_KEYWORD: + return True + return False - if missing_expected: - test_case.fail('Expected log(s) not found:\n\t{}'.format('\n\t'.join(missing_expected))) +# --- Logging Utilities --- -def verify_log_simple(test_self, cm, expected_messages): - for i, msg in enumerate(expected_messages): - test_self.assertEqual(msg, cm.records[i].msg) +class LoggingWatcher: + """ + Captures and asserts on log messages during testing. + """ + def __init__(self, logger): + self.logger = logger + self.records = [] + self.output = [] + + def _process_logs(self, expected_messages: OneOrMany[str], comparison: callable = lambda x, y: x == y): + if not isinstance(expected_messages, list): + expected_messages = [expected_messages] + + if not self.output: + return [], expected_messages + + messages = [msg for msg in self.output] + missing_expected = expected_messages.copy() + found = [] + for i, expected_msg in enumerate(expected_messages): + for msg in messages: + if comparison(expected_msg, msg): + found.append(msg) + missing_expected.remove(expected_msg) + break + return found, missing_expected -def exact_log(test_case, cm, expected_messages): - test_case.assertEqual(expected_messages, [record.msg for record in cm.records]) + def exact_log(self, expected_messages: OneOrMany[str]): + """ + Assert that all expected messages appear exactly in the captured logs. + Args: + expected_messages: A single message string or list of message strings to match. -class SafeAssertLogs(_AssertLogsContext): - """ - The self.assertLogs context manager, that sets log level on the handler instead of logger. + Raises: + AssertionError: If any expected message is not found in the captured logs. + """ + found, missing_expected = self._process_logs(expected_messages, lambda x, y: x == y) + if len(missing_expected) > 0: + missing_expected_str = '\n\t'.join(missing_expected) + raise AssertionError(f"Expected exact log(s) not found:\n\t{missing_expected_str}\n\nActual logs:\n{self.format_logs()}\n") - Original docstring: - A context manager used to implement TestCase.assertLogs(). - """ + def partial_log(self, expected_messages: OneOrMany[str]): + """ + Assert that each expected message is a substring of at least one captured log. - def __init__(self, *args, logger_level: str = None, **kwargs): - if sys.version_info < (3, 10, 0) and 'no_logs' in kwargs: - del kwargs['no_logs'] + Args: + expected_messages: A single message string or list of message strings to match as substrings. - super().__init__(*args, **kwargs) - self.logger_level = logger_level + Raises: + AssertionError: If any expected message is not found as a substring in the captured logs. + """ + found, missing_expected = self._process_logs(expected_messages, lambda x, y: x in y) + if len(missing_expected) > 0: + missing_expected_str = '\n\t'.join(missing_expected) + raise AssertionError(f"Expected partial log(s) not found:\n\t{missing_expected_str}\n\nActual logs:\n{self.format_logs()}\n") - def __enter__(self, include_original_handlers: bool = False): - if isinstance(self.logger_name, logging.Logger): - logger = self.logger = self.logger_name - else: - logger = self.logger = logging.getLogger(self.logger_name) - formatter = logging.Formatter(self.LOGGING_FORMAT) - handler = _CapturingHandler() - handler.setFormatter(formatter) - self.watcher = handler.watcher - self.old_handlers = logger.handlers[:] - self.old_level = logger.level - self.old_propagate = logger.propagate - logger.handlers = [handler] - handler.setLevel(self.level) # this one line is different, originally was `logger.setLevel` - logger.propagate = False - if self.logger_level is not None: - logger.setLevel(getattr(logging, self.logger_level)) + def log_excludes(self, expected_messages: OneOrMany[str]): + """ + Assert that none of the expected messages appear in any captured log. - if include_original_handlers: - logger.handlers += self.old_handlers - logger.propagate = True - return handler.watcher + Args: + expected_messages: A single message string or list of message strings to exclude. + Raises: + AssertionError: If any expected message is found in the captured logs. + """ + found, _ = self._process_logs(expected_messages, lambda x, y: x in y) + if found: + found_str = '\n\t'.join(found) + raise AssertionError(f"Unexpected log(s) found:\n\t{found_str}\n\nCurrent logs:\n{self.format_logs()}\n") -def get_logger_children(main_logger) -> list[logging.Logger]: - """ - Gets child loggers. Added as a support compat for Python version 3.11 and below. - Source: https://github.com/python/cpython/blob/3.12/Lib/logging/__init__.py#L1831 - """ + def format_logs(self): + """ + Return a formatted string of all captured log messages. + + Returns: + str: A formatted string containing all captured logs. + """ + output_str = '\n\t'.join(self.output) + return f"\n{self} captured {len(self.output)} logs:\n[\n\t{output_str}\n]" - def _hierlevel(logger): - if logger is logger.manager.root: - return 0 - return 1 + logger.name.count('.') + def count_occurrences(self, msg: str): + """ + Count occurrences of a message in the captured logs. + + Args: + msg: The message substring to count. + + Returns: + int: The number of logs containing the message substring. + """ + return sum(1 for log in self.output if msg in log) + + def print(self): + """ + Print the formatted logs to stdout. + """ + print(self.format_logs()) - d = main_logger.manager.loggerDict - # exclude PlaceHolders - the last check is to ensure that lower-level - # descendants aren't returned - if there are placeholders, a logger's - # parent field might point to a grandparent or ancestor thereof. - return [ - item - for item in d.values() - if isinstance(item, logging.Logger) and item.parent is main_logger and _hierlevel(item) == 1 + _hierlevel(item.parent) - ] + def __str__(self): + return f'LoggingWatcher({self.logger.name})' -class RaiseLogsContext: +class _CapturingHandler(logging.Handler): + """ + Internal logging handler that captures all logging output. """ - Captures log messages at or above a specified level and raises unexpected ones as exceptions. - This context manager monitors log messages from a specified logger. Any log messages - at or above the given logging level are recorded. If a message is not explicitly - expected, a `RuntimeError` is raised, including the stack trace of the log call. It ensures - loggers are restored to their original state after use. + def __init__(self, logger): + logging.Handler.__init__(self) + self.watcher = LoggingWatcher(logger) - Note: - - When used in conjunction with `self.assertLogs` or `SafeAssertLogs`, ensure this context manager is defined last to properly assert log expectations. + def flush(self): + pass - Args: - test_case (TestCase): The test case instance, typically from `unittest.TestCase`. - logger_name (str | None): The name of the logger to monitor. Defaults to the root logger. - level (str): The logging level threshold (e.g., 'ERROR', 'WARNING'). Logs at or above this level are captured. - expected_errors (list[str] | None): A list of log messages that are expected and should not trigger an exception. - comparison (Callable[[str, str], bool]): A function to compare expected errors with log messages. - Defaults to an exact string match (`lambda x, y: x == y`). - - Example Usage: - >>> with RaiseLogsContext(self, logger_name='my_logger', level='WARNING', expected_errors=['My expected warning']): - ... logging.getLogger('my_logger').warning('My expected warning') # No error - ... logging.getLogger('my_logger').error('Unexpected issue') # Raises RuntimeError + def emit(self, record): + self.watcher.records.append(record) + msg = self.format(record) + self.watcher.output.append(msg) + + +class CaptureLogsContext: + """ + Context manager for capturing and validating log output during tests. """ + LOGGING_FORMAT = "%(message)s" def __init__( self, - test_case: TestCase, - logger_name=None, - level='ERROR', - expected_errors: [str] = None, - comparison: callable = lambda x, y: x == y, + logger: str = 'ibind', + level: str = 'DEBUG', + logger_level: str = None, + error_level: str = 'WARNING', + no_logs: Union[bool, object] = UNDEFINED, + expected_errors: List[str] = None, + partial_match: bool = False, + attach_stack: bool = True, ): - self._test_case = test_case - self._logger_name = logger_name - self._level = level - self._level_no = getattr(logging, level) - if expected_errors is None: - expected_errors = [] - self._expected_errors = expected_errors - self._comparison = comparison - - def monkey_patch_log(self, original_method): - """Wraps a logger method to attach a manually captured stack trace to log records.""" - - def new_method(msg, *args, **kwargs): - # Store the manually captured stack trace in the log record - stack = make_clean_stack() - if 'extra' not in kwargs: - kwargs['extra'] = {} - kwargs['extra']['manual_trace'] = stack - - # Call the original logging method with the modified arguments - return original_method(msg, *args, **kwargs) - - return new_method - - def monkey_patch_loggers(self, loggers): - """Monkey-patches loggers to attach a stack trace to warning and error messages.""" + """ + Initialize a log capture context. + + Args: + logger (str): Logger name to capture. Defaults to 'ibind'. + level (str): Logging level to capture. Defaults to 'DEBUG'. + logger_level (str): Optional logger-specific level override. + error_level (str): Logging level threshold for unexpected logs. Defaults to 'WARNING'. + no_logs (bool): If True, assert no logs are produced. If False, assert logs are produced. + Defaults to UNDEFINED (no assertion). + expected_errors (list): List of expected error messages to match. + partial_match (bool): If True, match expected errors as substrings. Defaults to False. + attach_stack (bool): If True, attach stack traces to logs. Defaults to True. + """ + self._logger = logger + self.level = getattr(logging, level) if isinstance(level, str) else level + self.logger_level = getattr(logging, logger_level) if isinstance(logger_level, str) else logger_level + self.no_logs = no_logs + self.expected_errors = expected_errors or [] + self.partial_match = partial_match + self.comparison = (lambda x, y: x in y) if partial_match else (lambda x, y: x == y) + self.attach_stack = attach_stack + self.error_level = getattr(logging, error_level) if isinstance(error_level, str) else (error_level if error_level is not None else self.level) + if not isinstance(self.expected_errors, list): + self.expected_errors = [self.expected_errors] + + def _monkey_patch_log(self, logger): + original_log = logger._log + + def new_log(level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1): + if extra is None: + extra = {} + extra['manual_trace'] = make_clean_stack()[:-2] + + return original_log(level, msg, args, exc_info, extra, stack_info, stacklevel) + + logger.__old_log_method__ = original_log + logger._log = new_log + + def _monkey_patch_loggers(self, loggers): for logger in loggers: - if self._level_no <= logging.ERROR: - logger.__old_error_method__ = logger.error - logger.error = self.monkey_patch_log(logger.error) - - if self._level_no <= logging.WARNING: - logger.__old_warning_method__ = logger.warning - logger.warning = self.monkey_patch_log(logger.warning) + self._monkey_patch_log(logger) - def restore_loggers(self, loggers): - """Restores the original error and warning logging methods after patching.""" + def _restore_loggers(self, loggers): for logger in loggers: - if self._level_no <= logging.ERROR: - logger.error = logger.__old_error_method__ # Restore the original error method + if hasattr(logger, '__old_log_method__'): + logger._log = logger.__old_log_method__ - if self._level_no <= logging.WARNING: - logger.warning = logger.__old_warning_method__ # Restore the original warning method + def logger_name(self): + """ + Get the logger name. - def __enter__(self): + Returns: + str: The name of the logger. """ - Initializes the logging context by patching loggers and setting up a log watcher. + return self._logger.name if isinstance(self._logger, logging.Logger) else self._logger - This method ensures that logs at the specified level are captured and asserts - that unexpected log messages are raised as errors. + def acquire(self) -> LoggingWatcher: """ + Acquire and configure the logger for capturing. - self._logger = logging.getLogger(self._logger_name) - loggers_to_be_patched = [self._logger] + get_logger_children(self._logger) - self.monkey_patch_loggers(loggers_to_be_patched) # Apply monkey-patching to attach stack traces to logged messages + Returns: + LoggingWatcher: A watcher object for asserting on captured logs. + """ + self.logger = logging.getLogger(self.logger_name()) + self.old_handlers = self.logger.handlers[:] + self.old_level = self.logger.level + self.old_propagate = self.logger.propagate - # Initialize SafeAssertLogs, a helper to capture and assert log records - self._context_manager = SafeAssertLogs(self._test_case, self._logger, level=self._level, no_logs=False) + formatter = logging.Formatter(self.LOGGING_FORMAT, datefmt='%H:%M:%S') + handler = _CapturingHandler(self.logger) + handler.setFormatter(formatter) + self.watcher = handler.watcher + self.logger.handlers = [handler] + handler.setLevel(self.level) + self.logger.propagate = False + if self.logger_level is not None: + self.logger.setLevel(self.logger_level) - # Enter the SafeAssertLogs context, starting log capture and returning the watcher - self._watcher = self._context_manager.__enter__(include_original_handlers=True) - return self._watcher + if self.attach_stack: + loggers_to_patch = [self.logger] + get_logger_children(self.logger) + self._monkey_patch_loggers(loggers_to_patch) + self._loggers_to_patch = loggers_to_patch + else: + self._loggers_to_patch = [] - def __exit__(self, exc_type, exc_val, exc_tb): - """ - Restores original logger methods and verifies captured log messages. + return self.watcher - This method is called when exiting the context manager. It ensures that: - - Monkey-patched loggers are restored to their original state. - - If an exception occurred inside the `with` block, it is propagated normally. - - If no exception occurred, all captured log messages are checked against expected errors. - - Unexpected log messages result in a `RuntimeError`. + def _raise_unexpected_log(self, record): + if hasattr(record, 'manual_trace'): + raise RuntimeError(f'\n{"".join(traceback.format_list(record.manual_trace))}Logger {self.logger} logged an unexpected message:\n{record.msg}') + raise RuntimeError(f'\n...\nFile "{record.pathname}", line {record.lineno} in {record.funcName}\n{record.msg}') + + def _process_exit_logs(self): + records = self.watcher.records + if self.no_logs is not UNDEFINED and self.no_logs: + if records: + self._raise_unexpected_log(records[0]) + return True + + if self.no_logs is not UNDEFINED and not records: + raise AssertionError(f"no logs of level {logging.getLevelName(self.level)} or higher triggered on {self.logger.name}") + + for record in records: + if record.levelno < self.error_level: + continue + if any(self.comparison(expected, record.msg) for expected in self.expected_errors): + continue + self._raise_unexpected_log(record) + + if self.partial_match: + self.watcher.partial_log(self.expected_errors) + else: + self.watcher.exact_log(self.expected_errors) + + def release(self, exc_type=None, exc_val=None, exc_tb=None): """ + Release and restore the logger to its original state. - # Restore original logging methods that were monkey-patched - loggers_to_be_patched = [self._logger] + get_logger_children(self._logger) - self.restore_loggers(loggers_to_be_patched) + Args: + exc_type: Exception type if an exception occurred. + exc_val: Exception value if an exception occurred. + exc_tb: Exception traceback if an exception occurred. - # If an exception occurred inside the 'with' block, return False to let Python re-raise it - if exc_type is not None: - return False + Returns: + bool: True if no exception occurred, False otherwise. + """ + self.logger.handlers = self.old_handlers + self.logger.propagate = self.old_propagate + self.logger.setLevel(self.old_level) + if self._loggers_to_patch: + self._restore_loggers(self._loggers_to_patch) + self._process_exit_logs() + return exc_type is None - # If no logs were captured return True to indicate that no errors were encountered and that the context exited cleanly - if len(self._watcher.records) == 0: - return True + def __enter__(self) -> LoggingWatcher: + return self.acquire() - for record in self._watcher.records: - found = False + def __exit__(self, exc_type, exc_val, exc_tb): + return self.release(exc_type, exc_val, exc_tb) - # Check if the log message matches any of the expected error messages - for expected_error in self._expected_errors: - if self._comparison(expected_error, record.msg): - found = True - break - # If the message is expected, move on to the next record - if found: - continue +def capture_logs(**ctx_kwargs): + """ + Decorator to capture and validate logs in a test function. - # If the log record has a manually stored traceback, raise an error with that traceback - if hasattr(record, 'manual_trace'): - raise RuntimeError( - '\n' + ''.join(traceback.format_list(record.manual_trace)) + f'Logger {self._logger} logged an unexpected message:\n{record.msg}' - ) + Args: + **ctx_kwargs: Keyword arguments passed to CaptureLogsContext. + Common options: logger, level, error_level, expected_errors, partial_match. - # Otherwise, raise an error using the log record's location - raise RuntimeError(f'\n...\nFile "{record.pathname}", line {record.lineno} in {record.funcName}\n{record.msg}') + Returns: + callable: A decorator that wraps a test function to capture logs. + Example: + @capture_logs(logger='myapp', expected_errors=['Error occurred']) + def test_something(): + # test code that logs + pass + """ -def raise_logs(level='ERROR', logger_name=None): - def _wrapper(fn): - @functools.wraps(fn) - def wrapper(self, *args, **kwargs): - with RaiseLogsContext(self, level=level, logger_name=logger_name): - return fn(self, *args, **kwargs) + def decorator(test_func): + @functools.wraps(test_func) + def wrapper(*args, **kwargs): + capture_log_context = CaptureLogsContext(**ctx_kwargs) + logger_name = f'_cm_{capture_log_context.logger_name()}' + fn_exc = None + log_exc = None + + cm = capture_log_context.acquire() + if _accepts_kwargs(test_func): + kwargs[logger_name] = cm + + try: + rv = test_func(*args, **kwargs) + except Exception as e: + rv = None + fn_exc = e + + try: + capture_log_context.release() + except Exception as e2: + log_exc = e2 + + if fn_exc is not None: + if log_exc is not None: + print('Unexpected log found in test:') + traceback.print_exception(log_exc) + raise fn_exc + elif log_exc is not None: + raise log_exc + + return rv return wrapper - return _wrapper + return decorator -def decorate_methods(decorator, starts_with=''): - class DecorateMethods(type): - """Decorate all methods of the class with the decorator provided""" +# --- Time Mocking Utilities --- - def __new__(cls, name, bases, attrs, **kwargs): - exclude = kwargs.get('exclude', []) +class MockTimeController: + """ + Mock time module for testing time-dependent code. + """ - for attr_name, attr_value in attrs.items(): - if ( - isinstance(attr_value, types.FunctionType) - and attr_name.startswith(starts_with) - and attr_name not in exclude - and not hasattr(attr_value, '__exclude_decorator__') - and not attr_name.startswith('__') - ): - attrs[attr_name] = decorator(attr_value) + def __init__(self, target_module, time_sequence=None, start_time=0.0): + """ + Initialize a mock time controller. - return super(DecorateMethods, cls).__new__(cls, name, bases, attrs) + Args: + target_module (str): Module name to inject the mock time into (eg. 'mymodule.submodule'). + time_sequence (list): Optional sequence of time values to return on successive calls. + If provided, time_sequence takes precedence over start_time. + start_time (float): Initial time value. Defaults to 0.0. Ignored if time_sequence is provided. + """ + self.target_module = target_module + if time_sequence is not None: + self.time_sequence = list(time_sequence) + self.call_index = 0 + else: + self.time_sequence = None + self.current_time = start_time + self.original_time_module = None - return DecorateMethods + def advance_time(self, seconds): + """ + Advance the mock time by the specified number of seconds. + Args: + seconds (float): Number of seconds to advance. -class TestCaseWithRaiseLogs(unittest.TestCase, metaclass=decorate_methods(raise_logs(logger_name='ibind'), starts_with='test')): ... + Raises: + ValueError: If using time_sequence mode. + """ + if self.time_sequence is not None: + raise ValueError("Cannot advance time when using time_sequence.") + self.current_time += seconds + def set_time(self, time_value): + """ + Set the mock time to a specific value. + + Args: + time_value (float): The time value to set. + + Raises: + ValueError: If using time_sequence mode. + """ + if self.time_sequence is not None: + raise ValueError("Cannot set time when using time_sequence.") + self.current_time = time_value + + def mock_time(self): + """ + Get the current mock time value. + + Returns: + float: The current time value. If using time_sequence, returns the next value in the sequence. + """ + if self.time_sequence is not None: + if self.call_index < len(self.time_sequence): + time_value = self.time_sequence[self.call_index] + self.call_index += 1 + return time_value + else: + return self.time_sequence[-1] + else: + return self.current_time + + def __enter__(self): + target_module_obj = __import__(self.target_module, fromlist=['']) + self.original_time_module = target_module_obj.time + + class MockTimeModule: + def __init__(self, original_module, mock_time_func): + self.original_module = original_module + self.time = mock_time_func -def exclude_decorator(fn): - fn.__exclude_decorator__ = True - return fn + def __getattr__(self, name): + return getattr(self.original_module, name) + + target_module_obj.time = MockTimeModule(self.original_time_module, self.mock_time) + self.target_module_obj = target_module_obj + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.target_module_obj.time = self.original_time_module + + +def mock_module_time(target_module, time_sequence=None, start_time=0.0): + """ + Create a mock time controller for a target module. + + Args: + target_module (str): Module name to inject the mock time into. + time_sequence (list): Optional sequence of time values to return on successive calls. + start_time (float): Initial time value. Defaults to 0.0. + + Returns: + MockTimeController: A context manager for mocking time in the target module. + + Example: + with mock_module_time('mymodule', time_sequence=[1.0, 2.0, 3.0]): + # time.time() in mymodule will return 1.0, then 2.0, then 3.0 + pass + """ + return MockTimeController(target_module, time_sequence=time_sequence, start_time=start_time) \ No newline at end of file diff --git a/test/unit/support/test_py_utils_u.py b/test/unit/support/test_py_utils_u.py index bcd88198..5fb4a250 100644 --- a/test/unit/support/test_py_utils_u.py +++ b/test/unit/support/test_py_utils_u.py @@ -1,113 +1,228 @@ import time -import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock + +import pytest from ibind.support.py_utils import ensure_list_arg, execute_in_parallel, execute_with_key, wait_until -class TestEnsureListArgU(unittest.TestCase): - @ensure_list_arg('arg') - def sample_function(self, arg): - return arg +@ensure_list_arg('arg') +def sample_function(arg): + return arg + + +def test_ensure_list_arg_with_list(): + """Wraps list args without altering the list.""" + # Arrange + input_arg = [1, 2, 3] + + # Act + result = sample_function(input_arg) + + # Assert + assert result == input_arg + + +def test_ensure_list_arg_with_non_list(): + """Wraps a non-list arg into a single-item list.""" + # Arrange + input_arg = 1 + + # Act + result = sample_function(input_arg) + + # Assert + assert result == [input_arg] + + +def test_ensure_list_arg_with_keyword_arg_list(): + """Preserves list input when passed as a keyword arg.""" + # Arrange + input_arg = [1, 2, 3] + + # Act + result = sample_function(arg=input_arg) + + # Assert + assert result == input_arg + - def test_ensure_list_arg_with_list(self): - input_arg = [1, 2, 3] - self.assertEqual(self.sample_function(input_arg), input_arg) +def test_ensure_list_arg_with_keyword_arg_non_list(): + """Wraps a non-list keyword arg into a single-item list.""" + # Arrange + input_arg = 1 - def test_ensure_list_arg_with_non_list(self): - input_arg = 1 - self.assertEqual(self.sample_function(input_arg), [input_arg]) + # Act + result = sample_function(arg=input_arg) - def test_ensure_list_arg_with_keyword_arg_list(self): - input_arg = [1, 2, 3] - self.assertEqual(self.sample_function(arg=input_arg), input_arg) + # Assert + assert result == [input_arg] - def test_ensure_list_arg_with_keyword_arg_non_list(self): - input_arg = 1 - self.assertEqual(self.sample_function(arg=input_arg), [input_arg]) - def test_ensure_list_arg_with_missing_arg(self): - with self.assertRaises(TypeError): - self.sample_function() +def test_ensure_list_arg_with_missing_arg(): + """Raises TypeError when the decorated arg is missing.""" + # Arrange + + # Act / Assert + with pytest.raises(TypeError): + sample_function() -class TestExecuteInParallelU(unittest.TestCase): - def _func(self, v1, v2): +@pytest.fixture +def parallel_setup(): + state = {'delay': 0} + + def _func(v1, v2): if v1 == 1: - time.sleep(self.delay) + time.sleep(state['delay']) return 'result1' elif v2 == 2: return 'result2' else: return 'unknown' - def setUp(self): - self.delay = 0 - self.func = MagicMock(side_effect=self._func) - self.func.__name__ = 'TEST_FUNCTION' - self.requests_dict = {'req1': {'args': [1, 0], 'kwargs': {}}, 'req2': {'args': [0], 'kwargs': {'v2': 2}}} - self.requests_list = [{'args': [1, 0], 'kwargs': {}}, {'args': [0], 'kwargs': {'v2': 2}}] - - def test_execute_in_parallel_with_dict(self): - results = execute_in_parallel(self.func, self.requests_dict) - self.assertEqual(results, {'req1': 'result1', 'req2': 'result2'}) - self.assertEqual(self.func.call_count, 2) - - def test_execute_in_parallel_with_list(self): - self.delay = 0.1 - results = execute_in_parallel(self.func, self.requests_list) - self.assertEqual(results, ['result1', 'result2']) - self.assertEqual(self.func.call_count, 2) - - def test_execute_with_key_success(self): - result = execute_with_key('key', self.func, 1, v2=2) - self.func.assert_called_with(1, v2=2) - self.assertEqual(result, ('key', 'result1')) - - def test_execute_with_key_exception(self): - self.func.side_effect = Exception('error') - result = execute_with_key('key', self.func, 1, v2=2) - self.assertIsInstance(result[1], Exception) - - def test_execute_in_parallel_rate_limiting(self): - start_time = time.time() - - # Simulate a slow function to test rate limiting - def slow_func(): - time.sleep(0.05) - return 'slow_result' - - requests = {i: {'args': [], 'kwargs': {}} for i in range(20)} # 10 requests - max_per_second = 10 # Limit to 5 requests per second - results = execute_in_parallel(slow_func, requests, max_per_second=max_per_second) - - duration = time.time() - start_time - self.assertGreaterEqual(duration, 1.05) # Should take at least 1.1 seconds to complete all requests - self.assertEqual(len(results), 20) - - -class TestWaitUntilU(unittest.TestCase): - def test_wait_until_condition_met(self): - condition = MagicMock(return_value=True) - self.assertTrue(wait_until(condition)) - condition.assert_called() - - def test_wait_until_condition_not_met(self): - condition = MagicMock(return_value=False) - self.assertFalse(wait_until(condition, timeout=0.1)) - condition.assert_called() - - @patch('ibind.support.py_utils._LOGGER.error') - def test_wait_until_timeout_message(self, mock_logger_error): - condition = MagicMock(return_value=False) - timeout_message = 'Condition not met within timeout' - self.assertFalse(wait_until(condition, timeout_message=timeout_message, timeout=0.1)) - mock_logger_error.assert_called_with(timeout_message) - - def test_wait_until_timeout(self): - start_time = time.time() - condition = MagicMock(return_value=False) - timeout = 0.1 - self.assertFalse(wait_until(condition, timeout=timeout)) - duration = time.time() - start_time - self.assertAlmostEqual(duration, timeout, delta=0.02) + func = MagicMock(side_effect=_func) + func.__name__ = 'TEST_FUNCTION' + requests_dict = {'req1': {'args': [1, 0], 'kwargs': {}}, 'req2': {'args': [0], 'kwargs': {'v2': 2}}} + requests_list = [{'args': [1, 0], 'kwargs': {}}, {'args': [0], 'kwargs': {'v2': 2}}] + + return { + 'state': state, + 'func': func, + 'requests_dict': requests_dict, + 'requests_list': requests_list, + } + + +def test_execute_in_parallel_with_dict(parallel_setup): + """Executes requests in parallel when passed a dict of requests.""" + # Arrange + func = parallel_setup['func'] + requests = parallel_setup['requests_dict'] + + # Act + results = execute_in_parallel(func, requests) + + # Assert + assert results == {'req1': 'result1', 'req2': 'result2'} + assert func.call_count == 2 + + +def test_execute_in_parallel_with_list(parallel_setup): + """Executes requests in parallel when passed a list of requests.""" + # Arrange + func = parallel_setup['func'] + requests = parallel_setup['requests_list'] + parallel_setup['state']['delay'] = 0.1 + + # Act + results = execute_in_parallel(func, requests) + + # Assert + assert results == ['result1', 'result2'] + assert func.call_count == 2 + + +def test_execute_with_key_success(parallel_setup): + """Returns (key, result) when the wrapped function succeeds.""" + # Arrange + func = parallel_setup['func'] + + # Act + result = execute_with_key('key', func, 1, v2=2) + + # Assert + func.assert_called_with(1, v2=2) + assert result == ('key', 'result1') + + +def test_execute_with_key_exception(parallel_setup): + """Returns (key, exception) when the wrapped function raises.""" + # Arrange + func = parallel_setup['func'] + func.side_effect = Exception('error') + + # Act + result = execute_with_key('key', func, 1, v2=2) + + # Assert + assert isinstance(result[1], Exception) + + +def test_execute_in_parallel_rate_limiting(): + """Applies max_per_second rate limiting across parallel executions.""" + # Arrange + start_time = time.time() + + # Simulate a slow function to test rate limiting + def slow_func(): + time.sleep(0.05) + return 'slow_result' + + requests = {i: {'args': [], 'kwargs': {}} for i in range(20)} # 10 requests + max_per_second = 10 # Limit to 5 requests per second + + # Act + results = execute_in_parallel(slow_func, requests, max_per_second=max_per_second) + + # Assert + duration = time.time() - start_time + assert duration >= 1.05 # Should take at least 1.1 seconds to complete all requests + assert len(results) == 20 + + +def test_wait_until_condition_met(): + """Returns True immediately when the condition is already met.""" + # Arrange + condition = MagicMock(return_value=True) + + # Act + result = wait_until(condition) + + # Assert + assert result is True + condition.assert_called() + + +def test_wait_until_condition_not_met(): + """Returns False when the condition is not met before timeout.""" + # Arrange + condition = MagicMock(return_value=False) + + # Act + result = wait_until(condition, timeout=0.1) + + # Assert + assert result is False + condition.assert_called() + + +def test_wait_until_timeout_message(mocker): + """Logs the timeout_message when the deadline is reached.""" + # Arrange + mock_logger_error = mocker.patch('ibind.support.py_utils._LOGGER.error') + condition = MagicMock(return_value=False) + timeout_message = 'Condition not met within timeout' + + # Act + result = wait_until(condition, timeout_message=timeout_message, timeout=0.1) + + # Assert + assert result is False + mock_logger_error.assert_called_with(timeout_message) + + +def test_wait_until_timeout(): + """Waits roughly the specified timeout duration before returning False.""" + # Arrange + start_time = time.time() + condition = MagicMock(return_value=False) + timeout = 0.1 + + # Act + result = wait_until(condition, timeout=timeout) + + # Assert + assert result is False + duration = time.time() - start_time + assert duration == pytest.approx(timeout, abs=0.02) \ No newline at end of file