From 118fffa7a174d52780617a3c2292bb0f915f531d Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 21 Dec 2025 10:17:35 +0000 Subject: [PATCH 01/31] Feat: Create unittest to pytest migration plan This change creates a detailed migration plan to convert all existing unittest-based tests to pytest. The plan is structured into chunks to allow for parallel work by multiple AI agents. The plan is located in `test/migration_plan.md` and includes: - A method for identifying all unittest files. - General guidelines for the migration. - A file-by-file breakdown with specific migration steps and potential challenges. --- test/migration_plan.md | 72 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 test/migration_plan.md diff --git a/test/migration_plan.md b/test/migration_plan.md new file mode 100644 index 00000000..d4a73d3c --- /dev/null +++ b/test/migration_plan.md @@ -0,0 +1,72 @@ +# Unittest to Pytest Migration Plan + +This document outlines the roadmap for migrating our existing `unittest`-based tests to `pytest`. The goal of this migration is to modernize our testing suite, improve readability, and take advantage of `pytest`'s powerful features, such as fixtures and improved assertions. + +This plan is designed to be executed by multiple AI agents in parallel, with each agent working on a separate file. + +## Identifying `unittest` Files + +To ensure a complete migration, a systematic search was performed across the `test/` directory to identify all files using the `unittest` framework. This was accomplished by running the following command: + +```bash +grep -r "import unittest" test/ +``` + +This command recursively searches for the string `"import unittest"` in all files within the `test/` directory. The output of this command is the definitive list of files that need to be migrated. + +## General Guidelines + +When migrating tests, please adhere to the following principles: + +- **Test Classes:** Convert `unittest.TestCase` subclasses into plain test functions. If a class structure is still beneficial for grouping related tests, you can use a class without inheriting from `unittest.TestCase`. +- **`setUp` and `tearDown`:** Replace `setUp` and `tearDown` methods with `pytest` fixtures. This is the preferred way to manage test setup and teardown in `pytest`. +- **Assertions:** Convert all `self.assert...` methods to plain `assert` statements. For example, `self.assertEqual(a, b)` becomes `assert a == b`. `pytest` provides detailed output for failing assertions. +- **Exception Handling:** Replace `with self.assertRaises(...)` with `with pytest.raises(...)`. +- **Logging:** Use the built-in `caplog` fixture to test log messages. This is the standard `pytest` way to handle logging. +- **Arrange, Act, Assert:** Structure your tests using the Arrange, Act, Assert pattern to improve readability and maintainability. +- **Parametrization:** Use `@pytest.mark.parametrize` to run the same test with different inputs. This is a powerful feature for reducing code duplication. + +## Migration Chunks + +The following files need to be migrated. Each file can be worked on by a separate AI agent. + +### `test/test_utils.py` + +**Current Structure:** +- This file contains several helper functions and classes for testing, including `SafeAssertLogs`, `RaiseLogsContext`, `TestCaseWithRaiseLogs`, and a decorator `raise_logs`. +- These utilities are tightly coupled with the `unittest` framework. + +**Migration Steps:** +1. **Refactor `SafeAssertLogs`:** This class can be replaced with the built-in `caplog` fixture in `pytest`. The tests that use this class will need to be updated to use `caplog`. +2. **Refactor `RaiseLogsContext`:** This context manager can also be replaced with the `caplog` fixture. The logic for raising an exception on unexpected log messages will need to be implemented within the tests themselves. +3. **Refactor `raise_logs` decorator:** This decorator should be removed. The tests that use it will need to be updated to use `caplog` and the logic for checking for unexpected log messages. +4. **Refactor `TestCaseWithRaiseLogs`:** This class should be removed. The tests that inherit from it will need to be converted to plain test functions. +5. **Remove `unittest` imports:** Once all the `unittest`-dependent code has been refactored, the `import unittest` statement can be removed. + +**Potential Challenges:** +- The logic in `RaiseLogsContext` for raising exceptions on unexpected log messages is complex. This will need to be carefully replicated in the tests that use this context manager. + +### `test/unit/support/test_py_utils_u.py` + +**Current Structure:** +- This file contains several `unittest.TestCase` subclasses with `setUp` methods and `self.assert...` statements. +- It tests the functions in `ibind/support/py_utils.py`. + +**Migration Steps:** +1. **Convert `TestEnsureListArgU`:** + - Convert the class to a set of test functions. + - Use `assert` for all assertions. + - Use `@pytest.mark.parametrize` to reduce code duplication for the different test cases. +2. **Convert `TestExecuteInParallelU`:** + - Convert the class to a set of test functions. + - Use a `pytest` fixture to replace the `setUp` method. + - Use `assert` for all assertions. + - Use `@pytest.mark.parametrize` where appropriate. +3. **Convert `TestWaitUntilU`:** + - Convert the class to a set of test functions. + - Use `assert` for all assertions. + - Use the `caplog` fixture to test the log message for the timeout. +4. **Remove `unittest` imports:** Once all the `unittest`-dependent code has been refactored, the `import unittest` statement can be removed. + +**Potential Challenges:** +- There are no significant challenges expected for this file. The migration should be straightforward. From b3a3fed4dbc2b78bf1771b4a89d3ff0a5b126f42 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 21 Dec 2025 10:26:43 +0000 Subject: [PATCH 02/31] Feat: Create unittest to pytest migration plan This change creates a detailed migration plan to convert all existing unittest-based tests to pytest. The plan is structured into chunks to allow for parallel work. The plan is located in `test/migration_plan.md` and includes: - A method for identifying all unittest files. - General guidelines for the migration. - A file-by-file breakdown with specific migration steps and potential challenges. --- test/migration_plan.md | 58 ++++++++++-------------------------------- 1 file changed, 13 insertions(+), 45 deletions(-) diff --git a/test/migration_plan.md b/test/migration_plan.md index d4a73d3c..2191536f 100644 --- a/test/migration_plan.md +++ b/test/migration_plan.md @@ -2,17 +2,15 @@ This document outlines the roadmap for migrating our existing `unittest`-based tests to `pytest`. The goal of this migration is to modernize our testing suite, improve readability, and take advantage of `pytest`'s powerful features, such as fixtures and improved assertions. -This plan is designed to be executed by multiple AI agents in parallel, with each agent working on a separate file. - ## Identifying `unittest` Files To ensure a complete migration, a systematic search was performed across the `test/` directory to identify all files using the `unittest` framework. This was accomplished by running the following command: ```bash -grep -r "import unittest" test/ +grep -r -E "import unittest|from unittest" test/ ``` -This command recursively searches for the string `"import unittest"` in all files within the `test/` directory. The output of this command is the definitive list of files that need to be migrated. +This command recursively searches for `unittest` imports in all files within the `test/` directory. The output of this command is the definitive list of files that need to be migrated. ## General Guidelines @@ -22,51 +20,21 @@ When migrating tests, please adhere to the following principles: - **`setUp` and `tearDown`:** Replace `setUp` and `tearDown` methods with `pytest` fixtures. This is the preferred way to manage test setup and teardown in `pytest`. - **Assertions:** Convert all `self.assert...` methods to plain `assert` statements. For example, `self.assertEqual(a, b)` becomes `assert a == b`. `pytest` provides detailed output for failing assertions. - **Exception Handling:** Replace `with self.assertRaises(...)` with `with pytest.raises(...)`. -- **Logging:** Use the built-in `caplog` fixture to test log messages. This is the standard `pytest` way to handle logging. +- **Logging:** The `test_utils.py` file will be updated manually to provide a `capture_logs` fixture. This fixture will replace the `SafeAssertLogs`, `RaiseLogsContext`, `TestCaseWithRaiseLogs`, and `raise_logs` decorator. Use the `capture_logs` fixture to test log messages. - **Arrange, Act, Assert:** Structure your tests using the Arrange, Act, Assert pattern to improve readability and maintainability. - **Parametrization:** Use `@pytest.mark.parametrize` to run the same test with different inputs. This is a powerful feature for reducing code duplication. ## Migration Chunks -The following files need to be migrated. Each file can be worked on by a separate AI agent. - -### `test/test_utils.py` - -**Current Structure:** -- This file contains several helper functions and classes for testing, including `SafeAssertLogs`, `RaiseLogsContext`, `TestCaseWithRaiseLogs`, and a decorator `raise_logs`. -- These utilities are tightly coupled with the `unittest` framework. - -**Migration Steps:** -1. **Refactor `SafeAssertLogs`:** This class can be replaced with the built-in `caplog` fixture in `pytest`. The tests that use this class will need to be updated to use `caplog`. -2. **Refactor `RaiseLogsContext`:** This context manager can also be replaced with the `caplog` fixture. The logic for raising an exception on unexpected log messages will need to be implemented within the tests themselves. -3. **Refactor `raise_logs` decorator:** This decorator should be removed. The tests that use it will need to be updated to use `caplog` and the logic for checking for unexpected log messages. -4. **Refactor `TestCaseWithRaiseLogs`:** This class should be removed. The tests that inherit from it will need to be converted to plain test functions. -5. **Remove `unittest` imports:** Once all the `unittest`-dependent code has been refactored, the `import unittest` statement can be removed. - -**Potential Challenges:** -- The logic in `RaiseLogsContext` for raising exceptions on unexpected log messages is complex. This will need to be carefully replicated in the tests that use this context manager. - -### `test/unit/support/test_py_utils_u.py` - -**Current Structure:** -- This file contains several `unittest.TestCase` subclasses with `setUp` methods and `self.assert...` statements. -- It tests the functions in `ibind/support/py_utils.py`. +The following files need to be migrated. Each file can be worked on independently. -**Migration Steps:** -1. **Convert `TestEnsureListArgU`:** - - Convert the class to a set of test functions. - - Use `assert` for all assertions. - - Use `@pytest.mark.parametrize` to reduce code duplication for the different test cases. -2. **Convert `TestExecuteInParallelU`:** - - Convert the class to a set of test functions. - - Use a `pytest` fixture to replace the `setUp` method. - - Use `assert` for all assertions. - - Use `@pytest.mark.parametrize` where appropriate. -3. **Convert `TestWaitUntilU`:** - - Convert the class to a set of test functions. - - Use `assert` for all assertions. - - Use the `caplog` fixture to test the log message for the timeout. -4. **Remove `unittest` imports:** Once all the `unittest`-dependent code has been refactored, the `import unittest` statement can be removed. +1. [ ] `test/e2e/xtest_ibkr_client_e.py` +2. [ ] `test/integration/base/test_rest_client_i.py` +3. [ ] `test/integration/base/test_websocket_client_i.py` +4. [ ] `test/integration/base/websocketapp_mock.py` +5. [ ] `test/integration/client/test_ibkr_client_i.py` +6. [ ] `test/integration/client/test_ibkr_utils_i.py` +7. [ ] `test/integration/client/test_ibkr_ws_client_i.py` +8. [ ] `test/unit/support/test_py_utils_u.py` -**Potential Challenges:** -- There are no significant challenges expected for this file. The migration should be straightforward. +**Note:** `test/test_utils.py` will be updated manually by a human to provide a `capture_logs` fixture. This fixture will be used in the migrated tests. From 19b2af85596c2774e291aebb24047d066437a99c Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 21 Dec 2025 10:42:37 +0000 Subject: [PATCH 03/31] Feat: Create detailed unittest to pytest migration plan This change creates a comprehensive, file-by-file migration plan to convert all existing unittest-based tests to pytest. The plan is structured into individually actionable chunks. The plan is located in `test/migration_plan.md` and includes: - General guidelines for the migration. - A detailed analysis of each unittest file. - Specific migration steps and potential challenges for each file. --- test/migration_plan.md | 98 +++++++++++++++++++++++++++++++++--------- 1 file changed, 78 insertions(+), 20 deletions(-) diff --git a/test/migration_plan.md b/test/migration_plan.md index 2191536f..76f3ad6d 100644 --- a/test/migration_plan.md +++ b/test/migration_plan.md @@ -2,16 +2,6 @@ This document outlines the roadmap for migrating our existing `unittest`-based tests to `pytest`. The goal of this migration is to modernize our testing suite, improve readability, and take advantage of `pytest`'s powerful features, such as fixtures and improved assertions. -## Identifying `unittest` Files - -To ensure a complete migration, a systematic search was performed across the `test/` directory to identify all files using the `unittest` framework. This was accomplished by running the following command: - -```bash -grep -r -E "import unittest|from unittest" test/ -``` - -This command recursively searches for `unittest` imports in all files within the `test/` directory. The output of this command is the definitive list of files that need to be migrated. - ## General Guidelines When migrating tests, please adhere to the following principles: @@ -20,7 +10,7 @@ When migrating tests, please adhere to the following principles: - **`setUp` and `tearDown`:** Replace `setUp` and `tearDown` methods with `pytest` fixtures. This is the preferred way to manage test setup and teardown in `pytest`. - **Assertions:** Convert all `self.assert...` methods to plain `assert` statements. For example, `self.assertEqual(a, b)` becomes `assert a == b`. `pytest` provides detailed output for failing assertions. - **Exception Handling:** Replace `with self.assertRaises(...)` with `with pytest.raises(...)`. -- **Logging:** The `test_utils.py` file will be updated manually to provide a `capture_logs` fixture. This fixture will replace the `SafeAssertLogs`, `RaiseLogsContext`, `TestCaseWithRaiseLogs`, and `raise_logs` decorator. Use the `capture_logs` fixture to test log messages. +- **Logging:** The `test_utils.py` file will be updated manually to provide a `capture_logs` fixture. This fixture will replace the `SafeAssertLogs`, `RaiseLogsContext`, `TestCaseWithRaiseLogs`, and `raise_logs` decorator. Use the `capture_logs` fixture to test log messages. The built-in `caplog` fixture can also be used for simple cases. - **Arrange, Act, Assert:** Structure your tests using the Arrange, Act, Assert pattern to improve readability and maintainability. - **Parametrization:** Use `@pytest.mark.parametrize` to run the same test with different inputs. This is a powerful feature for reducing code duplication. @@ -28,13 +18,81 @@ When migrating tests, please adhere to the following principles: The following files need to be migrated. Each file can be worked on independently. -1. [ ] `test/e2e/xtest_ibkr_client_e.py` -2. [ ] `test/integration/base/test_rest_client_i.py` -3. [ ] `test/integration/base/test_websocket_client_i.py` -4. [ ] `test/integration/base/websocketapp_mock.py` -5. [ ] `test/integration/client/test_ibkr_client_i.py` -6. [ ] `test/integration/client/test_ibkr_utils_i.py` -7. [ ] `test/integration/client/test_ibkr_ws_client_i.py` -8. [ ] `test/unit/support/test_py_utils_u.py` +**Note:** `test/test_utils.py` will be updated manually to provide a `capture_logs` fixture. This fixture will be used in the migrated tests. + +--- + +### 1. [ ] `test/integration/base/test_rest_client_i.py` + +- **Current Structure:** Contains three `unittest.TestCase` subclasses: `TestRestClientI`, `TestRestClientInThread`, and `TestRestClientAsync`. It uses a class-level `@patch` decorator, `setUp` methods, and various `self.assert...` methods, including `self.assertLogs` and `self.assertRaises`. +- **Migration Steps:** + 1. Convert the `TestRestClientI` class into a series of test functions. + 2. Replace the `setUp` method's logic with a `pytest` fixture that provides a configured `RestClient` instance. + 3. Convert all `self.assertEqual` and `self.assertRaises` calls to plain `assert` statements and `with pytest.raises(...)`. + 4. Replace `with self.assertLogs(...)` with the `caplog` fixture for log capture and assertion. + 5. Refactor the class-level `@patch('ibind.base.rest_client.requests')` to use the `mocker` fixture from `pytest-mock` within each test function that needs it. + 6. Convert the `TestRestClientInThread` and `TestRestClientAsync` classes to simple test functions; their internal logic does not require a class structure. +- **Potential Challenges:** The class-level patching needs to be carefully applied to each test function that relies on it, likely using the `mocker.patch` method. + +--- + +### 2. [ ] `test/integration/base/test_websocket_client_i.py` + +- **Current Structure:** Contains a single `TestWsClient(TestCase)` class with a complex `setUp` method. It heavily relies on a custom `run_in_test_context` helper method that sets up multiple patches and log handlers (`self.assertLogs`, `RaiseLogsContext`). +- **Migration Steps:** + 1. Convert the `TestWsClient` class into a series of test functions. + 2. The logic within the `setUp` method should be moved into one or more `pytest` fixtures. + 3. The `run_in_test_context` helper method must be refactored. Its functionality (patching, log capturing) should be moved into a dedicated fixture. + 4. Replace `self.assertLogs` and the custom `RaiseLogsContext` with the new `capture_logs` fixture. + 5. Convert all `self.assertTrue` and `self.assertFalse` calls to plain `assert` statements. +- **Potential Challenges:** The `run_in_test_context` method is complex. Migrating its logic into a `pytest` fixture that correctly manages setup and teardown of patches will be the most challenging part of this file's migration. + +--- + +### 3. [ ] `test/integration/client/test_ibkr_client_i.py` + +- **Current Structure:** Consists of a single `TestIbkrClientI(TestCase)` class that uses a class-level `@patch`, a `setUp` method, and a wide variety of `self.assert...` methods. +- **Migration Steps:** + 1. Convert the `TestIbkrClientI` class into a series of test functions. + 2. Move the `setUp` logic into a `pytest` fixture. + 3. Replace all `self.assert...` calls (e.g., `assertEqual`, `assertIn`, `assertRaises`, `assertAlmostEqual`, `assertTrue`) with plain `assert` statements and `pytest.raises`. + 4. Replace `with self.assertLogs(...)` and `RaiseLogsContext` with the `capture_logs` fixture or `caplog`. + 5. Handle the class-level patch using the `mocker` fixture in each relevant test function. +- **Potential Challenges:** The `test_marketdata_history_by_symbols` test has a complex mock side effect (`_marketdata_request`). This logic should be extracted into a helper function or a fixture to maintain readability. + +--- + +### 4. [ ] `test/integration/client/test_ibkr_utils_i.py` + +- **Current Structure:** Contains four `TestCase` subclasses: `TestIbkrUtilsI`, `TestFindAnswer`, `TestHandleQuestionsI`, and `TestParseOrderRequestI`. These classes use `setUp` methods and various assertions. +- **Migration Steps:** + 1. Convert all four classes into separate sets of test functions. The class names can be used as prefixes for the function names to maintain grouping (e.g., `test_ibkr_utils_filter_stocks`). + 2. Move `setUp` logic into fixtures where applicable. + 3. Convert all `self.assert...` calls to plain `assert` and `pytest.raises`. + 4. Replace `with self.assertLogs(...)` with the `caplog` fixture. +- **Potential Challenges:** This file appears to be a straightforward migration with no significant challenges. + +--- + +### 5. [ ] `test/integration/client/test_ibkr_ws_client_i.py` + +- **Current Structure:** Contains two `TestCase` subclasses: `TestPreprocessRawMessage` and `TestIbkrWsClient`. The `TestIbkrWsClient` class is complex, with a detailed `setUp` method and a `run_in_test_context` helper method similar to the one in `test_websocket_client_i.py`. +- **Migration Steps:** + 1. Convert both `TestCase` subclasses into sets of test functions. + 2. Move the extensive `setUp` logic from `TestIbkrWsClient` into `pytest` fixtures. + 3. Refactor the `run_in_test_context` helper method into a dedicated fixture that handles patching and log capturing. + 4. Replace `SafeAssertLogs` and `RaiseLogsContext` with the new `capture_logs` fixture. + 5. Convert all `self.assert...` calls to plain `assert` statements. +- **Potential Challenges:** Similar to `test_websocket_client_i.py`, the primary challenge is refactoring the `run_in_test_context` method into a robust and readable `pytest` fixture. + +--- + +### 6. [ ] `test/unit/support/test_py_utils_u.py` -**Note:** `test/test_utils.py` will be updated manually by a human to provide a `capture_logs` fixture. This fixture will be used in the migrated tests. +- **Current Structure:** Contains three `TestCase` subclasses: `TestEnsureListArgU`, `TestExecuteInParallelU`, and `TestWaitUntilU`. It uses `setUp`, a variety of `self.assert...` methods, and `with self.assertRaises`. +- **Migration Steps:** + 1. Convert all three classes into separate sets of test functions. + 2. Move the `setUp` method from `TestExecuteInParallelU` into a fixture. + 3. Convert all `self.assert...` methods and `with self.assertRaises` to plain `assert` statements and `with pytest.raises(...)`. + 4. The `@patch` decorator in `test_wait_until_timeout_message` can be replaced with the `mocker` fixture. +- **Potential Challenges:** This file should be a straightforward migration. From 9d37a0104b238482e9ef6a275a035e392b414a91 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 21 Dec 2025 10:56:17 +0000 Subject: [PATCH 04/31] feat: Add pytest test utils and unittest migration plan Adds a new pytest-friendly test utility file at `test/test_utils_new.py`. This file includes the `capture_logs` context manager and decorator for advanced log testing. Creates a detailed migration plan in `test/migration_plan.md` to guide the conversion of all existing unittest files to pytest. The plan instructs developers to create new test files for side-by-side comparison and to use the new utilities. --- test/migration_plan.md | 86 ++++--- test/test_utils_new.py | 498 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 536 insertions(+), 48 deletions(-) create mode 100644 test/test_utils_new.py diff --git a/test/migration_plan.md b/test/migration_plan.md index 76f3ad6d..90f99bdb 100644 --- a/test/migration_plan.md +++ b/test/migration_plan.md @@ -1,98 +1,88 @@ # Unittest to Pytest Migration Plan -This document outlines the roadmap for migrating our existing `unittest`-based tests to `pytest`. The goal of this migration is to modernize our testing suite, improve readability, and take advantage of `pytest`'s powerful features, such as fixtures and improved assertions. +This document outlines the roadmap for migrating our existing `unittest`-based tests to `pytest`. The goal is to modernize our testing suite, improve readability, and take advantage of `pytest`'s powerful features. ## General Guidelines When migrating tests, please adhere to the following principles: +- **New Test Files:** To compare test coverage before and after the migration, create a new test file for the migrated tests. For example, `test/integration/base/test_rest_client_i.py` should be migrated to `test/integration/base/test_rest_client_i_new.py`. - **Test Classes:** Convert `unittest.TestCase` subclasses into plain test functions. If a class structure is still beneficial for grouping related tests, you can use a class without inheriting from `unittest.TestCase`. -- **`setUp` and `tearDown`:** Replace `setUp` and `tearDown` methods with `pytest` fixtures. This is the preferred way to manage test setup and teardown in `pytest`. -- **Assertions:** Convert all `self.assert...` methods to plain `assert` statements. For example, `self.assertEqual(a, b)` becomes `assert a == b`. `pytest` provides detailed output for failing assertions. +- **`setUp` and `tearDown`:** Replace `setUp` and `tearDown` methods with `pytest` fixtures. +- **Assertions:** Convert all `self.assert...` methods to plain `assert` statements. For example, `self.assertEqual(a, b)` becomes `assert a == b`. - **Exception Handling:** Replace `with self.assertRaises(...)` with `with pytest.raises(...)`. -- **Logging:** The `test_utils.py` file will be updated manually to provide a `capture_logs` fixture. This fixture will replace the `SafeAssertLogs`, `RaiseLogsContext`, `TestCaseWithRaiseLogs`, and `raise_logs` decorator. Use the `capture_logs` fixture to test log messages. The built-in `caplog` fixture can also be used for simple cases. -- **Arrange, Act, Assert:** Structure your tests using the Arrange, Act, Assert pattern to improve readability and maintainability. -- **Parametrization:** Use `@pytest.mark.parametrize` to run the same test with different inputs. This is a powerful feature for reducing code duplication. +- **Logging:** Use the new `capture_logs` utility from `test_utils_new.py`. It can be used as a context manager (`with capture_logs(...) as cm:`) or as a decorator (`@capture_logs(...)`). This replaces all previous `unittest`-based logging helpers. The returned watcher object has methods like `exact_log`, `partial_log`, and `log_excludes` for assertions. +- **Arrange, Act, Assert:** Structure your tests using the Arrange, Act, Assert pattern. +- **Parametrization:** Use `@pytest.mark.parametrize` to run the same test with different inputs. ## Migration Chunks The following files need to be migrated. Each file can be worked on independently. -**Note:** `test/test_utils.py` will be updated manually to provide a `capture_logs` fixture. This fixture will be used in the migrated tests. - --- ### 1. [ ] `test/integration/base/test_rest_client_i.py` -- **Current Structure:** Contains three `unittest.TestCase` subclasses: `TestRestClientI`, `TestRestClientInThread`, and `TestRestClientAsync`. It uses a class-level `@patch` decorator, `setUp` methods, and various `self.assert...` methods, including `self.assertLogs` and `self.assertRaises`. - **Migration Steps:** - 1. Convert the `TestRestClientI` class into a series of test functions. - 2. Replace the `setUp` method's logic with a `pytest` fixture that provides a configured `RestClient` instance. - 3. Convert all `self.assertEqual` and `self.assertRaises` calls to plain `assert` statements and `with pytest.raises(...)`. - 4. Replace `with self.assertLogs(...)` with the `caplog` fixture for log capture and assertion. - 5. Refactor the class-level `@patch('ibind.base.rest_client.requests')` to use the `mocker` fixture from `pytest-mock` within each test function that needs it. - 6. Convert the `TestRestClientInThread` and `TestRestClientAsync` classes to simple test functions; their internal logic does not require a class structure. -- **Potential Challenges:** The class-level patching needs to be carefully applied to each test function that relies on it, likely using the `mocker.patch` method. + 1. Create a new file: `test/integration/base/test_rest_client_i_new.py`. + 2. In the new file, convert all `TestCase` subclasses into simple test functions. + 3. Replace the `setUp` method's logic with a `pytest` fixture. + 4. Convert all `self.assert...` calls and `with self.assertRaises` to `assert` and `with pytest.raises(...)`. + 5. Replace `with self.assertLogs(...)` with the `capture_logs` context manager from `test_utils_new.py`. + 6. Refactor the class-level patch to use the `mocker` fixture within each test function. --- ### 2. [ ] `test/integration/base/test_websocket_client_i.py` -- **Current Structure:** Contains a single `TestWsClient(TestCase)` class with a complex `setUp` method. It heavily relies on a custom `run_in_test_context` helper method that sets up multiple patches and log handlers (`self.assertLogs`, `RaiseLogsContext`). - **Migration Steps:** - 1. Convert the `TestWsClient` class into a series of test functions. - 2. The logic within the `setUp` method should be moved into one or more `pytest` fixtures. - 3. The `run_in_test_context` helper method must be refactored. Its functionality (patching, log capturing) should be moved into a dedicated fixture. - 4. Replace `self.assertLogs` and the custom `RaiseLogsContext` with the new `capture_logs` fixture. - 5. Convert all `self.assertTrue` and `self.assertFalse` calls to plain `assert` statements. -- **Potential Challenges:** The `run_in_test_context` method is complex. Migrating its logic into a `pytest` fixture that correctly manages setup and teardown of patches will be the most challenging part of this file's migration. + 1. Create a new file: `test/integration/base/test_websocket_client_i_new.py`. + 2. In the new file, convert the `TestWsClient` class into a series of test functions. + 3. Move the `setUp` logic into one or more `pytest` fixtures. + 4. Eliminate the complex `run_in_test_context` helper. Use the `mocker` fixture for patching and decorate tests with `@capture_logs(...)` from `test_utils_new.py` for logging. + 5. Convert all `self.assert...` calls to plain `assert` statements. --- ### 3. [ ] `test/integration/client/test_ibkr_client_i.py` -- **Current Structure:** Consists of a single `TestIbkrClientI(TestCase)` class that uses a class-level `@patch`, a `setUp` method, and a wide variety of `self.assert...` methods. - **Migration Steps:** - 1. Convert the `TestIbkrClientI` class into a series of test functions. - 2. Move the `setUp` logic into a `pytest` fixture. - 3. Replace all `self.assert...` calls (e.g., `assertEqual`, `assertIn`, `assertRaises`, `assertAlmostEqual`, `assertTrue`) with plain `assert` statements and `pytest.raises`. - 4. Replace `with self.assertLogs(...)` and `RaiseLogsContext` with the `capture_logs` fixture or `caplog`. - 5. Handle the class-level patch using the `mocker` fixture in each relevant test function. -- **Potential Challenges:** The `test_marketdata_history_by_symbols` test has a complex mock side effect (`_marketdata_request`). This logic should be extracted into a helper function or a fixture to maintain readability. + 1. Create a new file: `test/integration/client/test_ibkr_client_i_new.py`. + 2. In the new file, convert the class into a series of test functions. + 3. Move the `setUp` logic into a `pytest` fixture. + 4. Replace all `self.assert...` calls with plain `assert` statements and `pytest.raises`. + 5. Replace the `SafeAssertLogs` and `RaiseLogsContext` with the `capture_logs` utility from `test_utils_new.py`. + 6. Handle the class-level patch using the `mocker` fixture. --- ### 4. [ ] `test/integration/client/test_ibkr_utils_i.py` -- **Current Structure:** Contains four `TestCase` subclasses: `TestIbkrUtilsI`, `TestFindAnswer`, `TestHandleQuestionsI`, and `TestParseOrderRequestI`. These classes use `setUp` methods and various assertions. - **Migration Steps:** - 1. Convert all four classes into separate sets of test functions. The class names can be used as prefixes for the function names to maintain grouping (e.g., `test_ibkr_utils_filter_stocks`). - 2. Move `setUp` logic into fixtures where applicable. - 3. Convert all `self.assert...` calls to plain `assert` and `pytest.raises`. - 4. Replace `with self.assertLogs(...)` with the `caplog` fixture. -- **Potential Challenges:** This file appears to be a straightforward migration with no significant challenges. + 1. Create a new file: `test/integration/client/test_ibkr_utils_i_new.py`. + 2. In the new file, convert all four classes into separate sets of test functions. + 3. Move `setUp` logic into fixtures where applicable. + 4. Convert all `self.assert...` calls to plain `assert` and `pytest.raises`. + 5. Replace `with self.assertLogs(...)` with the `capture_logs` context manager from `test_utils_new.py`. --- ### 5. [ ] `test/integration/client/test_ibkr_ws_client_i.py` -- **Current Structure:** Contains two `TestCase` subclasses: `TestPreprocessRawMessage` and `TestIbkrWsClient`. The `TestIbkrWsClient` class is complex, with a detailed `setUp` method and a `run_in_test_context` helper method similar to the one in `test_websocket_client_i.py`. - **Migration Steps:** - 1. Convert both `TestCase` subclasses into sets of test functions. - 2. Move the extensive `setUp` logic from `TestIbkrWsClient` into `pytest` fixtures. - 3. Refactor the `run_in_test_context` helper method into a dedicated fixture that handles patching and log capturing. - 4. Replace `SafeAssertLogs` and `RaiseLogsContext` with the new `capture_logs` fixture. + 1. Create a new file: `test/integration/client/test_ibkr_ws_client_i_new.py`. + 2. In the new file, convert both `TestCase` subclasses into sets of test functions. + 3. Move the extensive `setUp` logic into `pytest` fixtures. + 4. Eliminate the `run_in_test_context` helper. Use the `mocker` fixture for patching and `@capture_logs(...)` from `test_utils_new.py` for logging. 5. Convert all `self.assert...` calls to plain `assert` statements. -- **Potential Challenges:** Similar to `test_websocket_client_i.py`, the primary challenge is refactoring the `run_in_test_context` method into a robust and readable `pytest` fixture. --- ### 6. [ ] `test/unit/support/test_py_utils_u.py` -- **Current Structure:** Contains three `TestCase` subclasses: `TestEnsureListArgU`, `TestExecuteInParallelU`, and `TestWaitUntilU`. It uses `setUp`, a variety of `self.assert...` methods, and `with self.assertRaises`. - **Migration Steps:** - 1. Convert all three classes into separate sets of test functions. - 2. Move the `setUp` method from `TestExecuteInParallelU` into a fixture. - 3. Convert all `self.assert...` methods and `with self.assertRaises` to plain `assert` statements and `with pytest.raises(...)`. - 4. The `@patch` decorator in `test_wait_until_timeout_message` can be replaced with the `mocker` fixture. -- **Potential Challenges:** This file should be a straightforward migration. + 1. Create a new file: `test/unit/support/test_py_utils_u_new.py`. + 2. In the new file, convert all three classes into separate sets of test functions. + 3. Move the `setUp` method into a fixture. + 4. Convert all `self.assert...` methods and `with self.assertRaises` to plain `assert` statements and `with pytest.raises(...)`. + 5. Replace the `@patch` decorator with the `mocker` fixture. diff --git a/test/test_utils_new.py b/test/test_utils_new.py new file mode 100644 index 00000000..6de4830d --- /dev/null +++ b/test/test_utils_new.py @@ -0,0 +1,498 @@ +import functools +import importlib +import inspect +import logging +import os +import traceback +from pathlib import Path + +from support.slog import get_logger_children, PrettyFormatter +from utils.context_utils import make_clean_stack, accepts_kwargs +from utils.py_utils import UNDEFINED, OneOrMany + +_NAME_TO_LEVEL = logging.getLevelNamesMapping() + + + + +class LoggingWatcher: + """ + Helper class for capturing and asserting logs during testing. + + Attributes: + logger: The logger instance being watched. + records: List to store log records. + output: List to store log output messages. + + """ + + def __init__(self, logger): + """ + Initialize the LoggingWatcher. + + Args: + logger: The logger instance to watch. + """ + self.logger = logger + self.records = [] + self.output = [] + + def _process_logs(self, expected_messages: OneOrMany[str], comparison: callable = lambda x, y: x == y): + """ + Assert that all expected messages appear in the captured logs, using the given comparison function. + + Args: + expected_messages (OneOrMany[str]): Message(s) expected in the logs. + comparison (callable): Function to compare expected and actual messages (default: exact match). + + Raises: + AssertionError: If any expected message is not found in the logs according to the comparison. + """ + + if not isinstance(expected_messages, list): + expected_messages = [expected_messages] + + if not self.output: + return [], expected_messages + + messages = [msg for msg in self.output] + missing_expected = expected_messages.copy() + found = [] + for i, expected_msg in enumerate(expected_messages): + for msg in messages: + if comparison(expected_msg, msg): + found.append(msg) + missing_expected.remove(expected_msg) + break + + return found, missing_expected + + def exact_log(self, expected_messages: OneOrMany[str]): + """ + Assert that all expected messages appear in the captured logs. + + Args: + expected_messages (OneOrMany[str]): Message(s) expected in the logs. + """ + found, missing_expected = self._process_logs(expected_messages, lambda x, y: x == y) + + if len(missing_expected) > 0: + raise AssertionError("Expected exact log(s) not found:\n\t{}\n\nActual logs:\n{}\n".format('\n\t'.join(missing_expected), self.format_logs())) + + def partial_log(self, expected_messages: OneOrMany[str]): + """ + Assert that each expected message is a substring of at least one captured log message. + + Args: + expected_messages (OneOrMany[str]): Message(s) expected to be partially present in the logs. + """ + found, missing_expected = self._process_logs(expected_messages, lambda x, y: x in y) + + if len(missing_expected) > 0: + raise AssertionError("Expected partial log(s) not found:\n\t{}\n\nActual logs:\n{}\n".format('\n\t'.join(missing_expected), self.format_logs())) + + def log_excludes(self, expected_messages: OneOrMany[str]): + """ + Assert that none of the expected messages appear in any captured log message. + + Args: + expected_messages (OneOrMany[str]): Message(s) that must not be present in the logs. + """ + found, missing_expected = self._process_logs(expected_messages, lambda x, y: x in y) + if found: + raise AssertionError("Unexpected log(s) found:\n\t{}\n\nCurrent logs:\n{}\n".format('\n\t'.join(found), self.format_logs())) + + def format_logs(self): + """ + Return a formatted string of all captured log messages. + + Returns: + str: Formatted log output. + """ + return f"\n{self} captured {len(self.output)} logs:\n[\n\t{'\n\t'.join(self.output)}\n]" + + def count_occurrences(self, msg: str): + """ + Count the number of occurrences of a message in the captured logs. + + Args: + msg (str): Message to count occurrences of. + + Returns: + int: Number of occurrences of the message. + """ + return sum(1 for log in self.output if msg in log) + + def print(self): + """ + Print the formatted logs. + """ + print(self.format_logs()) + + def __str__(self): + return f'LoggingWatcher({self.logger.name})' + + +class _CapturingHandler(logging.Handler): + """ + A logging handler capturing all (raw and formatted) logging output. + """ + + def __init__(self, logger): + logging.Handler.__init__(self) + self.watcher = LoggingWatcher(logger) + + def flush(self): + pass + + def emit(self, record): + self.watcher.records.append(record) + msg = self.format(record) + self.watcher.output.append(msg) + + +class CaptureLogsContext: + """ + Flexible context manager for log assertion and raising on unexpected logs. + + - If no_logs is True: asserts that no logs are emitted at or above the specified level. + - If no_logs is False: asserts that logs are emitted, and all logs must match expected_errors (if provided), otherwise raises. + """ + LOGGING_FORMAT = "%(message)s" + + def __init__( + self, + logger='slog', + level='DEBUG', + logger_level: str = None, + error_level='WARNING', + no_logs=UNDEFINED, + expected_errors=None, + partial_match=False, + attach_stack=True, + ): + self._logger = logger + self.level = getattr(logging, level) if isinstance(level, str) else level + self.logger_level = getattr(logging, logger_level) if isinstance(logger_level, str) else logger_level + self.no_logs = no_logs + self.expected_errors = expected_errors or [] + self.partial_match = partial_match + self.comparison = (lambda x, y: x in y) if partial_match else (lambda x, y: x == y) + self.attach_stack = attach_stack + + # for warning/error logs we specify the minimum level separate from the main logger + self.error_level = getattr(logging, error_level) if isinstance(error_level, str) else (error_level if error_level is not None else self.level) + + if not isinstance(self.expected_errors, list): + self.expected_errors = [self.expected_errors] + + def _monkey_patch_log(self, logger): + original_log = logger._log + + def new_log(level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1): + # Attach cleaned stack trace + if extra is None: + extra = {} + extra['manual_trace'] = make_clean_stack(extra_filters=[os.path.join('support', 'slog.py')])[:-2] + return original_log(level, msg, args, exc_info, extra, stack_info, stacklevel) + + logger.__old_log_method__ = original_log + logger._log = new_log + + def _monkey_patch_loggers(self, loggers): + for logger in loggers: + self._monkey_patch_log(logger) + + def _restore_loggers(self, loggers): + for logger in loggers: + if hasattr(logger, '__old_log_method__'): + logger._log = logger.__old_log_method__ + + def logger_name(self): + if isinstance(self._logger, logging.Logger): + return self._logger.name + else: + return self._logger + + def acquire(self) -> LoggingWatcher: + if isinstance(self._logger, logging.Logger): + self.logger = self._logger + else: + self.logger = logging.getLogger(self._logger) + self.old_handlers = self.logger.handlers[:] + self.old_level = self.logger.level + self.old_propagate = self.logger.propagate + + formatter = PrettyFormatter(self.LOGGING_FORMAT, datefmt='%H:%M:%S', use_tags=False, print_ctx=False) + handler = _CapturingHandler(self.logger) + handler.setFormatter(formatter) + self.watcher = handler.watcher + self.logger.handlers = [handler] + handler.setLevel(self.level) + self.logger.propagate = False + if self.logger_level is not None: + self.logger.setLevel(self.logger_level) + + # Monkey-patch for stack traces + if self.attach_stack: + loggers_to_patch = [self.logger] + get_logger_children(self.logger) + self._monkey_patch_loggers(loggers_to_patch) + self._loggers_to_patch = loggers_to_patch + else: + self._loggers_to_patch = [] + + return self.watcher + + + + def _raise_unexpected_log(self, record): + if hasattr(record, 'manual_trace'): + raise RuntimeError( + '\n' + ''.join(traceback.format_list(record.manual_trace)) + + f'Logger {self.logger} logged an unexpected message:\n{record.msg}' + ) + + # Fallback to at least log the line at which the log was created + raise RuntimeError( + f'\n...\nFile "{record.pathname}", line {record.lineno} in {record.funcName}\n{record.msg}' + ) + + def _process_exit_logs(self): + records = self.watcher.records + + # 1. If no_logs: fail if any logs found + if self.no_logs is not UNDEFINED and self.no_logs: + if records: + self._raise_unexpected_log(records[0]) + return True + + # 2. If logs are expected: fail if no logs found + if self.no_logs is not UNDEFINED and not records: + raise AssertionError( + f"no logs of level {logging.getLevelName(self.level)} or higher triggered on {self.logger.name}" + ) + + # 3. Check all logs against expected_errors, but only for logs at or above error_level + for record in records: + if record.levelno < self.error_level: + continue + + # find and skip expected errors + found = any(self.comparison(expected, record.msg) for expected in self.expected_errors) + if found: + continue + + # raise any unexpected logs + self._raise_unexpected_log(record) + + if self.partial_match: + self.watcher.partial_log(self.expected_errors) + else: + self.watcher.exact_log(self.expected_errors) + + def release(self, exc_type=None, exc_val=None, exc_tb=None): + self.logger.handlers = self.old_handlers + self.logger.propagate = self.old_propagate + self.logger.setLevel(self.old_level) + + if self._loggers_to_patch: + self._restore_loggers(self._loggers_to_patch) + + self._process_exit_logs() + + if exc_type is not None: + # raise exc_type(exc_val) + return False # propagate exceptions + + return True # suppress exceptions if no error + + + def __enter__(self) -> LoggingWatcher: + return self.acquire() + + def __exit__(self, exc_type, exc_val, exc_tb): + return self.release(exc_type, exc_val, exc_tb) + + +def capture_logs(**ctx_kwargs): + """ + Wrapper around CaptureLogsContext to make it easier to use as a decorator for the whole test function. + """ + + def decorator(test_func): + @functools.wraps(test_func) + def wrapper(*args, **kwargs): + capture_log_context = CaptureLogsContext(**ctx_kwargs) + logger_name = f'_cm_{capture_log_context.logger_name()}' + fn_exc = None + log_exc = None + # try: + # with capture_log_context as cm: + # for key, val in kwargs.items(): + # # Dict containing LoggingWatcher(s) + # if inspect.isgenerator(val): + # val = next(val) + # kwargs[key] = val + # + # if isinstance(val, dict) and logger_name in val and isinstance(val[logger_name], LoggingWatcher): + # capture_log_context.watcher.output.extend(val[logger_name].output) + # capture_log_context.watcher.records.extend(val[logger_name].records) + # del val[logger_name] + + # Pass the context manager to the test if kwargs are accepted + cm = capture_log_context.acquire() + if accepts_kwargs(test_func): + kwargs[logger_name] = cm + + try: + rv = test_func(*args, **kwargs) + except Exception as e: + rv = None + fn_exc = e + + try: + capture_log_context.release() + except Exception as e2: + log_exc = e2 + + if fn_exc is not None: + if log_exc is not None: + print(f'Unexpected log found in test:') + traceback.print_exception(log_exc) + raise fn_exc + elif log_exc is not None: + raise log_exc + + return rv + # except Exception as e: + # raise + # fn_exc = e + # if fn_exc is not None: + # raise RuntimeError() from fn_exc + + + return wrapper + + return decorator + + +def make_data_dir(): + current_frame = inspect.currentframe() + current_frame = current_frame.f_back + return Path(current_frame.f_code.co_filename).parent / 'data' + + +class MockTimeController: + """ + A utility class to control time.time() calls within specific modules for testing. + + This allows tests to manually control the passage of time in specific modules, + preventing tests from hanging on timeout conditions while not affecting other modules. + """ + + def __init__(self, target_module, time_sequence=None, start_time=0.0): + """ + Initialize the mock time controller. + + Args: + target_module (str): Module path to patch (eg., 'utils.py_utils') + time_sequence (list): List of time values to return on successive calls to time.time() + start_time (float): Starting time value if time_sequence is not provided + """ + self.target_module = target_module + if time_sequence is not None: + self.time_sequence = list(time_sequence) # Make a copy + self.call_index = 0 + else: + self.time_sequence = None + self.current_time = start_time + self.original_time_module = None + + def advance_time(self, seconds): + """Advance the mocked time by the specified number of seconds.""" + if self.time_sequence is not None: + raise ValueError("Cannot advance time when using time_sequence. Use time_sequence parameter instead.") + self.current_time += seconds + + def set_time(self, time_value): + """Set the mocked time to a specific value.""" + if self.time_sequence is not None: + raise ValueError("Cannot set time when using time_sequence. Use time_sequence parameter instead.") + self.current_time = time_value + + def mock_time(self): + """Return the current mocked time.""" + if self.time_sequence is not None: + # Return values from sequence, cycling back to the last value if we run out + if self.call_index < len(self.time_sequence): + time_value = self.time_sequence[self.call_index] + self.call_index += 1 + return time_value + else: + # Return the last value repeatedly if we've exhausted the sequence + return self.time_sequence[-1] + else: + return self.current_time + + def __enter__(self): + """Context manager entry - patch time module reference in target module only.""" + # Dynamically import the target module + target_module_obj = __import__(self.target_module, fromlist=['']) + + # Store the original time module reference from target module + self.original_time_module = target_module_obj.time + + # Create a mock time module that only replaces the time() function + class MockTimeModule: + def __init__(self, original_module, mock_time_func): + self.original_module = original_module + self.time = mock_time_func + + def __getattr__(self, name): + # Delegate all other attributes to the original time module + return getattr(self.original_module, name) + + # Replace the time module reference in target module with our mock + target_module_obj.time = MockTimeModule(self.original_time_module, self.mock_time) + self.target_module_obj = target_module_obj # Store reference for cleanup + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - restore original time module reference.""" + self.target_module_obj.time = self.original_time_module + +def mock_module_time(target_module, time_sequence=None, start_time=0.0): + """ + Context manager to mock time.time() calls within any specified module. + + Usage: + # Mock time in a specific module + with mock_module_time('some.module', [0.0, 1.0, 2.0]) as time_controller: + # Code that uses time.time() in 'some.module' will get mocked values + pass + + # Mock time in multiple modules (use multiple context managers) + with mock_module_time('module1', [0.0, 1.0]), \ + mock_module_time('module2', [0.0, 2.0]): + # Both modules will have their time mocked independently + pass + + Args: + target_module (str): Module path to patch (eg., 'utils.py_utils', 'some.other.module') + time_sequence (list): List of time values to return on successive calls to time.time() + start_time (float): Initial time value to start with (ignored if time_sequence is provided) + + Returns: + MockTimeController: Controller object to manipulate time + """ + return MockTimeController(target_module, time_sequence=time_sequence, start_time=start_time) + +def import_all_modules(): + os.environ['DOTENV_PATH'] = 'UNDEFINED' # disable loading .env files + engine_dir = Path(__file__).parent.parent / 'engine' + for py_path in engine_dir.rglob('*.py'): + if py_path.name == '__init__.py' or '__pycache__' in py_path.parts: + continue + rel_path = py_path.relative_to(engine_dir.parent) + module_name = '.'.join(rel_path.with_suffix('').parts) + importlib.import_module(module_name) From 84bfb03dfc04ffb0ab7adeac58d45d23b68b1126 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 21 Dec 2025 10:58:35 +0000 Subject: [PATCH 05/31] feat: Add pytest test utils and unittest migration plan Adds a new pytest-friendly test utility file at `test/test_utils_new.py`. This file includes the `capture_logs` context manager and decorator for advanced log testing. Creates a detailed migration plan in `test/migration_plan.md` to guide the conversion of all existing unittest files to pytest. The plan instructs developers to create new test files for side-by-side comparison and to use the new utilities. From e2e89bad26c9a0ac29afb27a2f0275ac11ad9784 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 21 Dec 2025 10:59:11 +0000 Subject: [PATCH 06/31] feat: Add pytest test utils and unittest migration plan Adds a new pytest-friendly test utility file at `test/test_utils_new.py`. This file includes the `capture_logs` context manager and decorator for advanced log testing. Creates a detailed migration plan in `test/migration_plan.md` to guide the conversion of all existing unittest files to pytest. The plan instructs developers to create new test files for side-by-side comparison and to use the new utilities. From c897643f1641ccd569adfa5ee1854112192dce94 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 21 Dec 2025 11:08:02 +0000 Subject: [PATCH 07/31] feat: Add pytest test utils and unittest migration plan Adds a new pytest-friendly test utility file at `test/test_utils_new.py`. This file includes the `capture_logs` context manager and decorator for advanced log testing, along with other test helpers. Creates a detailed migration plan in `test/migration_plan.md` to guide the conversion of all existing unittest files to pytest. The plan instructs developers to create new test files for side-by-side comparison and to use the new utilities. --- test/test_utils_new.py | 292 +++++++++-------------------------------- 1 file changed, 59 insertions(+), 233 deletions(-) diff --git a/test/test_utils_new.py b/test/test_utils_new.py index 6de4830d..af71225a 100644 --- a/test/test_utils_new.py +++ b/test/test_utils_new.py @@ -1,54 +1,59 @@ import functools -import importlib import inspect import logging import os import traceback from pathlib import Path +from typing import List, TypeVar -from support.slog import get_logger_children, PrettyFormatter -from utils.context_utils import make_clean_stack, accepts_kwargs -from utils.py_utils import UNDEFINED, OneOrMany +from ibind.support.py_utils import make_clean_stack _NAME_TO_LEVEL = logging.getLevelNamesMapping() +# --- New Functions and Types --- +def accepts_kwargs(func): + """Returns True if func accepts **kwargs, else False.""" + sig = inspect.signature(func) + for param in sig.parameters.values(): + if param.kind == inspect.Parameter.VAR_KEYWORD: + return True + return False +UNDEFINED = object() -class LoggingWatcher: +S = TypeVar('S') +OneOrMany = S | List[S] + +def get_logger_children(main_logger) -> List[logging.Logger]: """ - Helper class for capturing and asserting logs during testing. + Gets child loggers. Added as a support compat for Python version 3.11 and below. + Source: https://github.com/python/cpython/blob/3.12/Lib/logging/__init__.py#L1831 + """ + if hasattr(main_logger, 'getChildren'): + return list(main_logger.getChildren()) - Attributes: - logger: The logger instance being watched. - records: List to store log records. - output: List to store log output messages. + def _hierlevel(logger): + if logger is logger.manager.root: + return 0 + return 1 + logger.name.count('.') - """ + d = main_logger.manager.loggerDict + return [item for item in d.values() + if isinstance(item, logging.Logger) and item.parent is main_logger and + _hierlevel(item) == 1 + _hierlevel(item.parent)] - def __init__(self, logger): - """ - Initialize the LoggingWatcher. +# --- Logging Utilities --- + +class LoggingWatcher: + """Helper class for capturing and asserting logs during testing.""" - Args: - logger: The logger instance to watch. - """ + def __init__(self, logger): self.logger = logger self.records = [] self.output = [] def _process_logs(self, expected_messages: OneOrMany[str], comparison: callable = lambda x, y: x == y): - """ - Assert that all expected messages appear in the captured logs, using the given comparison function. - - Args: - expected_messages (OneOrMany[str]): Message(s) expected in the logs. - comparison (callable): Function to compare expected and actual messages (default: exact match). - - Raises: - AssertionError: If any expected message is not found in the logs according to the comparison. - """ - if not isinstance(expected_messages, list): expected_messages = [expected_messages] @@ -64,80 +69,43 @@ def _process_logs(self, expected_messages: OneOrMany[str], comparison: callable found.append(msg) missing_expected.remove(expected_msg) break - return found, missing_expected def exact_log(self, expected_messages: OneOrMany[str]): - """ - Assert that all expected messages appear in the captured logs. - - Args: - expected_messages (OneOrMany[str]): Message(s) expected in the logs. - """ + """Assert that all expected messages appear in the captured logs.""" found, missing_expected = self._process_logs(expected_messages, lambda x, y: x == y) - if len(missing_expected) > 0: - raise AssertionError("Expected exact log(s) not found:\n\t{}\n\nActual logs:\n{}\n".format('\n\t'.join(missing_expected), self.format_logs())) + raise AssertionError(f"Expected exact log(s) not found:\n\t{'\n\t'.join(missing_expected)}\n\nActual logs:\n{self.format_logs()}\n") def partial_log(self, expected_messages: OneOrMany[str]): - """ - Assert that each expected message is a substring of at least one captured log message. - - Args: - expected_messages (OneOrMany[str]): Message(s) expected to be partially present in the logs. - """ + """Assert that each expected message is a substring of at least one captured log message.""" found, missing_expected = self._process_logs(expected_messages, lambda x, y: x in y) - if len(missing_expected) > 0: - raise AssertionError("Expected partial log(s) not found:\n\t{}\n\nActual logs:\n{}\n".format('\n\t'.join(missing_expected), self.format_logs())) + raise AssertionError(f"Expected partial log(s) not found:\n\t{'\n\t'.join(missing_expected)}\n\nActual logs:\n{self.format_logs()}\n") def log_excludes(self, expected_messages: OneOrMany[str]): - """ - Assert that none of the expected messages appear in any captured log message. - - Args: - expected_messages (OneOrMany[str]): Message(s) that must not be present in the logs. - """ - found, missing_expected = self._process_logs(expected_messages, lambda x, y: x in y) + """Assert that none of the expected messages appear in any captured log message.""" + found, _ = self._process_logs(expected_messages, lambda x, y: x in y) if found: - raise AssertionError("Unexpected log(s) found:\n\t{}\n\nCurrent logs:\n{}\n".format('\n\t'.join(found), self.format_logs())) + raise AssertionError(f"Unexpected log(s) found:\n\t{'\n\t'.join(found)}\n\nCurrent logs:\n{self.format_logs()}\n") def format_logs(self): - """ - Return a formatted string of all captured log messages. - - Returns: - str: Formatted log output. - """ + """Return a formatted string of all captured log messages.""" return f"\n{self} captured {len(self.output)} logs:\n[\n\t{'\n\t'.join(self.output)}\n]" def count_occurrences(self, msg: str): - """ - Count the number of occurrences of a message in the captured logs. - - Args: - msg (str): Message to count occurrences of. - - Returns: - int: Number of occurrences of the message. - """ + """Count the number of occurrences of a message in the captured logs.""" return sum(1 for log in self.output if msg in log) def print(self): - """ - Print the formatted logs. - """ + """Print the formatted logs.""" print(self.format_logs()) def __str__(self): return f'LoggingWatcher({self.logger.name})' - class _CapturingHandler(logging.Handler): - """ - A logging handler capturing all (raw and formatted) logging output. - """ - + """A logging handler capturing all (raw and formatted) logging output.""" def __init__(self, logger): logging.Handler.__init__(self) self.watcher = LoggingWatcher(logger) @@ -150,14 +118,7 @@ def emit(self, record): msg = self.format(record) self.watcher.output.append(msg) - class CaptureLogsContext: - """ - Flexible context manager for log assertion and raising on unexpected logs. - - - If no_logs is True: asserts that no logs are emitted at or above the specified level. - - If no_logs is False: asserts that logs are emitted, and all logs must match expected_errors (if provided), otherwise raises. - """ LOGGING_FORMAT = "%(message)s" def __init__( @@ -179,18 +140,13 @@ def __init__( self.partial_match = partial_match self.comparison = (lambda x, y: x in y) if partial_match else (lambda x, y: x == y) self.attach_stack = attach_stack - - # for warning/error logs we specify the minimum level separate from the main logger self.error_level = getattr(logging, error_level) if isinstance(error_level, str) else (error_level if error_level is not None else self.level) - if not isinstance(self.expected_errors, list): self.expected_errors = [self.expected_errors] def _monkey_patch_log(self, logger): original_log = logger._log - def new_log(level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1): - # Attach cleaned stack trace if extra is None: extra = {} extra['manual_trace'] = make_clean_stack(extra_filters=[os.path.join('support', 'slog.py')])[:-2] @@ -209,21 +165,15 @@ def _restore_loggers(self, loggers): logger._log = logger.__old_log_method__ def logger_name(self): - if isinstance(self._logger, logging.Logger): - return self._logger.name - else: - return self._logger + return self._logger.name if isinstance(self._logger, logging.Logger) else self._logger def acquire(self) -> LoggingWatcher: - if isinstance(self._logger, logging.Logger): - self.logger = self._logger - else: - self.logger = logging.getLogger(self._logger) + self.logger = logging.getLogger(self.logger_name()) self.old_handlers = self.logger.handlers[:] self.old_level = self.logger.level self.old_propagate = self.logger.propagate - formatter = PrettyFormatter(self.LOGGING_FORMAT, datefmt='%H:%M:%S', use_tags=False, print_ctx=False) + formatter = logging.Formatter(self.LOGGING_FORMAT, datefmt='%H:%M:%S') handler = _CapturingHandler(self.logger) handler.setFormatter(formatter) self.watcher = handler.watcher @@ -233,7 +183,6 @@ def acquire(self) -> LoggingWatcher: if self.logger_level is not None: self.logger.setLevel(self.logger_level) - # Monkey-patch for stack traces if self.attach_stack: loggers_to_patch = [self.logger] + get_logger_children(self.logger) self._monkey_patch_loggers(loggers_to_patch) @@ -243,46 +192,26 @@ def acquire(self) -> LoggingWatcher: return self.watcher - - def _raise_unexpected_log(self, record): if hasattr(record, 'manual_trace'): - raise RuntimeError( - '\n' + ''.join(traceback.format_list(record.manual_trace)) - + f'Logger {self.logger} logged an unexpected message:\n{record.msg}' - ) - - # Fallback to at least log the line at which the log was created - raise RuntimeError( - f'\n...\nFile "{record.pathname}", line {record.lineno} in {record.funcName}\n{record.msg}' - ) + raise RuntimeError(f'\n{"".join(traceback.format_list(record.manual_trace))}Logger {self.logger} logged an unexpected message:\n{record.msg}') + raise RuntimeError(f'\n...\nFile "{record.pathname}", line {record.lineno} in {record.funcName}\n{record.msg}') def _process_exit_logs(self): records = self.watcher.records - - # 1. If no_logs: fail if any logs found if self.no_logs is not UNDEFINED and self.no_logs: if records: self._raise_unexpected_log(records[0]) return True - # 2. If logs are expected: fail if no logs found if self.no_logs is not UNDEFINED and not records: - raise AssertionError( - f"no logs of level {logging.getLevelName(self.level)} or higher triggered on {self.logger.name}" - ) + raise AssertionError(f"no logs of level {logging.getLevelName(self.level)} or higher triggered on {self.logger.name}") - # 3. Check all logs against expected_errors, but only for logs at or above error_level for record in records: if record.levelno < self.error_level: continue - - # find and skip expected errors - found = any(self.comparison(expected, record.msg) for expected in self.expected_errors) - if found: + if any(self.comparison(expected, record.msg) for expected in self.expected_errors): continue - - # raise any unexpected logs self._raise_unexpected_log(record) if self.partial_match: @@ -294,18 +223,10 @@ def release(self, exc_type=None, exc_val=None, exc_tb=None): self.logger.handlers = self.old_handlers self.logger.propagate = self.old_propagate self.logger.setLevel(self.old_level) - if self._loggers_to_patch: self._restore_loggers(self._loggers_to_patch) - self._process_exit_logs() - - if exc_type is not None: - # raise exc_type(exc_val) - return False # propagate exceptions - - return True # suppress exceptions if no error - + return exc_type is None def __enter__(self) -> LoggingWatcher: return self.acquire() @@ -313,12 +234,7 @@ def __enter__(self) -> LoggingWatcher: def __exit__(self, exc_type, exc_val, exc_tb): return self.release(exc_type, exc_val, exc_tb) - def capture_logs(**ctx_kwargs): - """ - Wrapper around CaptureLogsContext to make it easier to use as a decorator for the whole test function. - """ - def decorator(test_func): @functools.wraps(test_func) def wrapper(*args, **kwargs): @@ -326,20 +242,7 @@ def wrapper(*args, **kwargs): logger_name = f'_cm_{capture_log_context.logger_name()}' fn_exc = None log_exc = None - # try: - # with capture_log_context as cm: - # for key, val in kwargs.items(): - # # Dict containing LoggingWatcher(s) - # if inspect.isgenerator(val): - # val = next(val) - # kwargs[key] = val - # - # if isinstance(val, dict) and logger_name in val and isinstance(val[logger_name], LoggingWatcher): - # capture_log_context.watcher.output.extend(val[logger_name].output) - # capture_log_context.watcher.records.extend(val[logger_name].records) - # del val[logger_name] - - # Pass the context manager to the test if kwargs are accepted + cm = capture_log_context.acquire() if accepts_kwargs(test_func): kwargs[logger_name] = cm @@ -357,51 +260,23 @@ def wrapper(*args, **kwargs): if fn_exc is not None: if log_exc is not None: - print(f'Unexpected log found in test:') + print('Unexpected log found in test:') traceback.print_exception(log_exc) raise fn_exc elif log_exc is not None: raise log_exc return rv - # except Exception as e: - # raise - # fn_exc = e - # if fn_exc is not None: - # raise RuntimeError() from fn_exc - - return wrapper - return decorator - -def make_data_dir(): - current_frame = inspect.currentframe() - current_frame = current_frame.f_back - return Path(current_frame.f_code.co_filename).parent / 'data' - +# --- Time Mocking Utilities --- class MockTimeController: - """ - A utility class to control time.time() calls within specific modules for testing. - - This allows tests to manually control the passage of time in specific modules, - preventing tests from hanging on timeout conditions while not affecting other modules. - """ - def __init__(self, target_module, time_sequence=None, start_time=0.0): - """ - Initialize the mock time controller. - - Args: - target_module (str): Module path to patch (eg., 'utils.py_utils') - time_sequence (list): List of time values to return on successive calls to time.time() - start_time (float): Starting time value if time_sequence is not provided - """ self.target_module = target_module if time_sequence is not None: - self.time_sequence = list(time_sequence) # Make a copy + self.time_sequence = list(time_sequence) self.call_index = 0 else: self.time_sequence = None @@ -409,90 +284,41 @@ def __init__(self, target_module, time_sequence=None, start_time=0.0): self.original_time_module = None def advance_time(self, seconds): - """Advance the mocked time by the specified number of seconds.""" if self.time_sequence is not None: - raise ValueError("Cannot advance time when using time_sequence. Use time_sequence parameter instead.") + raise ValueError("Cannot advance time when using time_sequence.") self.current_time += seconds def set_time(self, time_value): - """Set the mocked time to a specific value.""" if self.time_sequence is not None: - raise ValueError("Cannot set time when using time_sequence. Use time_sequence parameter instead.") + raise ValueError("Cannot set time when using time_sequence.") self.current_time = time_value def mock_time(self): - """Return the current mocked time.""" if self.time_sequence is not None: - # Return values from sequence, cycling back to the last value if we run out if self.call_index < len(self.time_sequence): time_value = self.time_sequence[self.call_index] self.call_index += 1 return time_value else: - # Return the last value repeatedly if we've exhausted the sequence return self.time_sequence[-1] else: return self.current_time def __enter__(self): - """Context manager entry - patch time module reference in target module only.""" - # Dynamically import the target module target_module_obj = __import__(self.target_module, fromlist=['']) - - # Store the original time module reference from target module self.original_time_module = target_module_obj.time - - # Create a mock time module that only replaces the time() function class MockTimeModule: def __init__(self, original_module, mock_time_func): self.original_module = original_module self.time = mock_time_func - def __getattr__(self, name): - # Delegate all other attributes to the original time module return getattr(self.original_module, name) - - # Replace the time module reference in target module with our mock target_module_obj.time = MockTimeModule(self.original_time_module, self.mock_time) - self.target_module_obj = target_module_obj # Store reference for cleanup + self.target_module_obj = target_module_obj return self def __exit__(self, exc_type, exc_val, exc_tb): - """Context manager exit - restore original time module reference.""" self.target_module_obj.time = self.original_time_module def mock_module_time(target_module, time_sequence=None, start_time=0.0): - """ - Context manager to mock time.time() calls within any specified module. - - Usage: - # Mock time in a specific module - with mock_module_time('some.module', [0.0, 1.0, 2.0]) as time_controller: - # Code that uses time.time() in 'some.module' will get mocked values - pass - - # Mock time in multiple modules (use multiple context managers) - with mock_module_time('module1', [0.0, 1.0]), \ - mock_module_time('module2', [0.0, 2.0]): - # Both modules will have their time mocked independently - pass - - Args: - target_module (str): Module path to patch (eg., 'utils.py_utils', 'some.other.module') - time_sequence (list): List of time values to return on successive calls to time.time() - start_time (float): Initial time value to start with (ignored if time_sequence is provided) - - Returns: - MockTimeController: Controller object to manipulate time - """ return MockTimeController(target_module, time_sequence=time_sequence, start_time=start_time) - -def import_all_modules(): - os.environ['DOTENV_PATH'] = 'UNDEFINED' # disable loading .env files - engine_dir = Path(__file__).parent.parent / 'engine' - for py_path in engine_dir.rglob('*.py'): - if py_path.name == '__init__.py' or '__pycache__' in py_path.parts: - continue - rel_path = py_path.relative_to(engine_dir.parent) - module_name = '.'.join(rel_path.with_suffix('').parts) - importlib.import_module(module_name) From 45e550a9b146e479ac6fd6a95ba0015624b29e6d Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 21 Dec 2025 11:21:37 +0000 Subject: [PATCH 08/31] feat: Add pytest test utils and unittest migration plan Adds a new pytest-friendly test utility file at `test/test_utils_new.py`. This file includes the `capture_logs` context manager and decorator for advanced log testing, along with other test helpers. Moves the `get_logger_children` function from the test utilities to `ibind/support/logs.py` to make it a general-purpose utility. Creates a detailed migration plan in `test/migration_plan.md` to guide the conversion of all existing unittest files to pytest. The plan instructs developers to create new test files for side-by-side comparison and to use the new utilities. --- ibind/support/logs.py | 22 +++++++++++++++++++++- test/test_utils_new.py | 26 ++------------------------ 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/ibind/support/logs.py b/ibind/support/logs.py index 11a94226..71ed2f5a 100644 --- a/ibind/support/logs.py +++ b/ibind/support/logs.py @@ -2,6 +2,7 @@ import logging import sys from pathlib import Path +from typing import List from ibind import var @@ -11,6 +12,25 @@ _log_to_file = False +def get_logger_children(main_logger) -> List[logging.Logger]: + """ + Gets child loggers. Added as a support compat for Python version 3.11 and below. + Source: https://github.com/python/cpython/blob/3.12/Lib/logging/__init__.py#L1831 + """ + if hasattr(main_logger, 'getChildren'): + return list(main_logger.getChildren()) + + def _hierlevel(logger): + if logger is logger.manager.root: + return 0 + return 1 + logger.name.count('.') + + d = main_logger.manager.loggerDict + return [item for item in d.values() + if isinstance(item, logging.Logger) and item.parent is main_logger and + _hierlevel(item) == 1 + _hierlevel(item.parent)] + + def project_logger(filepath=None): """ Returns a project-specific logger instance. @@ -152,4 +172,4 @@ def emit(self, record): self.close() self.stream = self._open() - super().emit(record) \ No newline at end of file + super().emit(record) diff --git a/test/test_utils_new.py b/test/test_utils_new.py index af71225a..76eba94a 100644 --- a/test/test_utils_new.py +++ b/test/test_utils_new.py @@ -6,7 +6,8 @@ from pathlib import Path from typing import List, TypeVar -from ibind.support.py_utils import make_clean_stack +from ibind.support.logs import get_logger_children +from ibind.support.py_utils import make_clean_stack, OneOrMany, UNDEFINED _NAME_TO_LEVEL = logging.getLevelNamesMapping() @@ -20,29 +21,6 @@ def accepts_kwargs(func): return True return False -UNDEFINED = object() - -S = TypeVar('S') -OneOrMany = S | List[S] - -def get_logger_children(main_logger) -> List[logging.Logger]: - """ - Gets child loggers. Added as a support compat for Python version 3.11 and below. - Source: https://github.com/python/cpython/blob/3.12/Lib/logging/__init__.py#L1831 - """ - if hasattr(main_logger, 'getChildren'): - return list(main_logger.getChildren()) - - def _hierlevel(logger): - if logger is logger.manager.root: - return 0 - return 1 + logger.name.count('.') - - d = main_logger.manager.loggerDict - return [item for item in d.values() - if isinstance(item, logging.Logger) and item.parent is main_logger and - _hierlevel(item) == 1 + _hierlevel(item.parent)] - # --- Logging Utilities --- class LoggingWatcher: From 81dc7d734bbb363575fed8ff40a82a0fca82d46d Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 21 Dec 2025 15:53:26 +0000 Subject: [PATCH 09/31] feat: Migrate rest_client_i.py to pytest Migrated the unittest-based tests in test/integration/base/test_rest_client_i.py to pytest-style tests in a new file, test/integration/base/test_rest_client_i_new.py. This migration includes: - Converting unittest.TestCase to pytest test functions. - Replacing setUp with pytest fixtures. - Replacing unittest assertions with pytest assertions. - Using the mocker fixture for patching. - Using the CaptureLogsContext for logging assertions. Also includes a defensive fix to test/test_utils_new.py to handle a TypeError when calling the make_clean_stack function. --- .../base/test_rest_client_i_new.py | 161 ++++++++++++++++++ test/test_utils_new.py | 7 +- 2 files changed, 167 insertions(+), 1 deletion(-) create mode 100644 test/integration/base/test_rest_client_i_new.py diff --git a/test/integration/base/test_rest_client_i_new.py b/test/integration/base/test_rest_client_i_new.py new file mode 100644 index 00000000..1fed210b --- /dev/null +++ b/test/integration/base/test_rest_client_i_new.py @@ -0,0 +1,161 @@ +import asyncio +import logging +import threading + +import pytest +from unittest.mock import MagicMock + +from requests import ReadTimeout, Timeout + +from ibind.client.ibkr_client import IbkrClient +from ibind.support.errors import ExternalBrokerError +from ibind.base.rest_client import Result, RestClient +from ibind.support.logs import ibind_logs_initialize +from test.test_utils_new import CaptureLogsContext + + +@pytest.fixture +def client_fixture(): + ibind_logs_initialize(log_to_console=True) + url = 'https://localhost:5000' + timeout = 8 + max_retries = 4 + client = RestClient( + url=url, + timeout=timeout, + max_retries=max_retries, + use_session=False, + ) + data = {'Test key': 'Test value'} + response = MagicMock() + response.json.return_value = data + default_path = 'test/api/route' + default_url = f'{url}/{default_path}' + result = Result(data=data, request={'url': default_url}) + return client, response, default_path, default_url, result, timeout, max_retries + + +def test_default_rest(client_fixture, mocker): + client, response, default_path, default_url, result, timeout, _ = client_fixture + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = response + + rv = client.get(default_path) + assert result == rv + requests_mock.request.assert_called_with('GET', default_url, verify=False, headers={}, timeout=timeout) + + test_post_kwargs = {'field1': 'value1', 'field2': 'value2'} + test_json = {'json': {**test_post_kwargs}} + rv = client.post(default_path, params=test_post_kwargs) + assert result.copy(request={'url': default_url, **test_json}) == rv + requests_mock.request.assert_called_with('POST', default_url, verify=False, headers={}, timeout=timeout, **test_json) + + rv = client.delete(default_path) + assert result == rv + requests_mock.request.assert_called_with('DELETE', default_url, verify=False, headers={}, timeout=timeout) + + +def test_request_retries(client_fixture, mocker): + client, _, default_path, default_url, _, _, max_retries = client_fixture + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.side_effect = ReadTimeout() + + with CaptureLogsContext('ibind.rest_client', level='INFO') as cm, pytest.raises(TimeoutError) as excinfo: + client.get(default_path) + + for i in range(max_retries): + assert f'RestClient: Timeout for GET {default_url} {{}}, retrying attempt {i + 1}/{max_retries}' in cm.output + + assert f'RestClient: Reached max retries ({max_retries}) for GET {default_url} {{}}' == str(excinfo.value) + + +def test_response_raise_timeout(client_fixture, mocker): + client, response, default_path, _, _, timeout, _ = client_fixture + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = response + response.raise_for_status.side_effect = Timeout() + + with pytest.raises(ExternalBrokerError) as excinfo: + client.get(default_path) + + assert f'RestClient: Timeout error ({timeout}S)' == str(excinfo.value) + + +def test_response_raise_generic(client_fixture, mocker): + client, response, default_path, _, result, _, _ = client_fixture + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = response + response.status_code = 400 + response.reason = 'Test reason' + response.text = 'Test text' + response.raise_for_status.side_effect = ValueError('Test generic error') + + with pytest.raises(ExternalBrokerError) as excinfo: + client.get(default_path) + + assert f'RestClient: response error {result.copy(data=None)} :: {response.status_code} :: {response.reason} :: {response.text}' == str(excinfo.value) + + +def _worker_in_thread(results: []): + try: + IbkrClient() + except Exception as e: + results.append(e) + + +def test_in_thread(): + """Run in thread ensuring client still is constructed without an exception.""" + results = [] + t = threading.Thread(target=_worker_in_thread, args=(results,)) + t.daemon = True + t.start() + t.join(1) + for result in results: + if isinstance(result, Exception): + raise result + + +def test_without_thread(): + """Run without a thread to ensure it still works as expected.""" + results = [] + _worker_in_thread(results) + for result in results: + if isinstance(result, Exception): + raise result + + +async def _async_worker(results: []): + """Async version of the worker function to run in an asyncio event loop.""" + try: + IbkrClient() + except Exception as e: + results.append(e) + + +def _worker_in_async_thread(results: []): + """Runs the async test inside a new thread to check if signal handling breaks.""" + try: + asyncio.run(_async_worker(results)) + except Exception as e: + results.append(e) + + +def test_in_thread_async(): + """Test that IbkrClient() does not break in an asyncio thread.""" + results = [] + t = threading.Thread(target=_worker_in_async_thread, args=(results,)) + t.daemon = True + t.start() + t.join(1) + for result in results: + if isinstance(result, Exception): + raise result + + +def test_without_thread_async(): + """Test that IbkrClient() does not break in the main asyncio event loop.""" + results = [] + asyncio.run(_async_worker(results)) + for result in results: + if isinstance(result, Exception): + raise result diff --git a/test/test_utils_new.py b/test/test_utils_new.py index 76eba94a..3c0fed41 100644 --- a/test/test_utils_new.py +++ b/test/test_utils_new.py @@ -127,7 +127,12 @@ def _monkey_patch_log(self, logger): def new_log(level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1): if extra is None: extra = {} - extra['manual_trace'] = make_clean_stack(extra_filters=[os.path.join('support', 'slog.py')])[:-2] + # Check if make_clean_stack accepts extra_filters + if 'extra_filters' in inspect.signature(make_clean_stack).parameters: + extra['manual_trace'] = make_clean_stack(extra_filters=[os.path.join('support', 'slog.py')])[:-2] + else: + extra['manual_trace'] = make_clean_stack()[:-2] + return original_log(level, msg, args, exc_info, extra, stack_info, stacklevel) logger.__old_log_method__ = original_log From 71f72c2b8c103d6bbfee9f1155686e1512b63f87 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 21 Dec 2025 16:11:53 +0000 Subject: [PATCH 10/31] feat: Migrate rest_client_i.py to pytest and improve test readability Migrated the unittest-based tests in test/integration/base/test_rest_client_i.py to pytest-style tests in a new file, test/integration/base/test_rest_client_i_new.py. This migration includes: - Converting unittest.TestCase to pytest test functions. - Replacing setUp with pytest fixtures. - Replacing unittest assertions with pytest assertions. - Using the mocker fixture for patching. - Using the CaptureLogsContext for logging assertions. Added ## Arrange, ## Act, and ## Assert comments to all tests in the new file to improve readability. Split the `test_default_rest` function into three separate tests for better test isolation. Includes a defensive fix to test/test_utils_new.py to handle a TypeError when calling the make_clean_stack function. --- .../base/test_rest_client_i_new.py | 55 ++++++++++++++++++- test/test_utils_new.py | 3 +- 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/test/integration/base/test_rest_client_i_new.py b/test/integration/base/test_rest_client_i_new.py index 1fed210b..c6243da0 100644 --- a/test/integration/base/test_rest_client_i_new.py +++ b/test/integration/base/test_rest_client_i_new.py @@ -35,34 +35,61 @@ def client_fixture(): return client, response, default_path, default_url, result, timeout, max_retries -def test_default_rest(client_fixture, mocker): +def test_default_rest_get(client_fixture, mocker): + # Arrange client, response, default_path, default_url, result, timeout, _ = client_fixture requests_mock = mocker.patch('ibind.base.rest_client.requests') requests_mock.request.return_value = response + # Act rv = client.get(default_path) + + # Assert assert result == rv requests_mock.request.assert_called_with('GET', default_url, verify=False, headers={}, timeout=timeout) + +def test_default_rest_post(client_fixture, mocker): + # Arrange + client, response, default_path, default_url, result, timeout, _ = client_fixture + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = response test_post_kwargs = {'field1': 'value1', 'field2': 'value2'} test_json = {'json': {**test_post_kwargs}} + + # Act rv = client.post(default_path, params=test_post_kwargs) + + # Assert assert result.copy(request={'url': default_url, **test_json}) == rv requests_mock.request.assert_called_with('POST', default_url, verify=False, headers={}, timeout=timeout, **test_json) + +def test_default_rest_delete(client_fixture, mocker): + # Arrange + client, response, default_path, default_url, result, timeout, _ = client_fixture + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = response + + # Act rv = client.delete(default_path) + + # Assert assert result == rv requests_mock.request.assert_called_with('DELETE', default_url, verify=False, headers={}, timeout=timeout) def test_request_retries(client_fixture, mocker): + # Arrange client, _, default_path, default_url, _, _, max_retries = client_fixture requests_mock = mocker.patch('ibind.base.rest_client.requests') requests_mock.request.side_effect = ReadTimeout() + # Act with CaptureLogsContext('ibind.rest_client', level='INFO') as cm, pytest.raises(TimeoutError) as excinfo: client.get(default_path) + # Assert for i in range(max_retries): assert f'RestClient: Timeout for GET {default_url} {{}}, retrying attempt {i + 1}/{max_retries}' in cm.output @@ -70,18 +97,22 @@ def test_request_retries(client_fixture, mocker): def test_response_raise_timeout(client_fixture, mocker): + # Arrange client, response, default_path, _, _, timeout, _ = client_fixture requests_mock = mocker.patch('ibind.base.rest_client.requests') requests_mock.request.return_value = response response.raise_for_status.side_effect = Timeout() + # Act with pytest.raises(ExternalBrokerError) as excinfo: client.get(default_path) + # Assert assert f'RestClient: Timeout error ({timeout}S)' == str(excinfo.value) def test_response_raise_generic(client_fixture, mocker): + # Arrange client, response, default_path, _, result, _, _ = client_fixture requests_mock = mocker.patch('ibind.base.rest_client.requests') requests_mock.request.return_value = response @@ -90,9 +121,11 @@ def test_response_raise_generic(client_fixture, mocker): response.text = 'Test text' response.raise_for_status.side_effect = ValueError('Test generic error') + # Act with pytest.raises(ExternalBrokerError) as excinfo: client.get(default_path) + # Assert assert f'RestClient: response error {result.copy(data=None)} :: {response.status_code} :: {response.reason} :: {response.text}' == str(excinfo.value) @@ -105,11 +138,16 @@ def _worker_in_thread(results: []): def test_in_thread(): """Run in thread ensuring client still is constructed without an exception.""" + # Arrange results = [] t = threading.Thread(target=_worker_in_thread, args=(results,)) t.daemon = True + + # Act t.start() t.join(1) + + # Assert for result in results: if isinstance(result, Exception): raise result @@ -117,8 +155,13 @@ def test_in_thread(): def test_without_thread(): """Run without a thread to ensure it still works as expected.""" + # Arrange results = [] + + # Act _worker_in_thread(results) + + # Assert for result in results: if isinstance(result, Exception): raise result @@ -142,11 +185,16 @@ def _worker_in_async_thread(results: []): def test_in_thread_async(): """Test that IbkrClient() does not break in an asyncio thread.""" + # Arrange results = [] t = threading.Thread(target=_worker_in_async_thread, args=(results,)) t.daemon = True + + # Act t.start() t.join(1) + + # Assert for result in results: if isinstance(result, Exception): raise result @@ -154,8 +202,13 @@ def test_in_thread_async(): def test_without_thread_async(): """Test that IbkrClient() does not break in the main asyncio event loop.""" + # Arrange results = [] + + # Act asyncio.run(_async_worker(results)) + + # Assert for result in results: if isinstance(result, Exception): raise result diff --git a/test/test_utils_new.py b/test/test_utils_new.py index 3c0fed41..6ec520c3 100644 --- a/test/test_utils_new.py +++ b/test/test_utils_new.py @@ -127,7 +127,8 @@ def _monkey_patch_log(self, logger): def new_log(level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1): if extra is None: extra = {} - # Check if make_clean_stack accepts extra_filters + # Check if make_clean_stack accepts extra_filters. This is necessary + # because the signature of make_clean_stack is unstable. if 'extra_filters' in inspect.signature(make_clean_stack).parameters: extra['manual_trace'] = make_clean_stack(extra_filters=[os.path.join('support', 'slog.py')])[:-2] else: From d47822bcd3753c2935803aef30d0e92e919117ec Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 21 Dec 2025 16:17:19 +0000 Subject: [PATCH 11/31] feat: Migrate rest_client_i.py to pytest and improve test readability Migrated the unittest-based tests in test/integration/base/test_rest_client_i.py to pytest-style tests in a new file, test/integration/base/test_rest_client_i_new.py. This migration includes: - Converting unittest.TestCase to pytest test functions. - Replacing setUp with pytest fixtures. - Replacing unittest assertions with pytest assertions. - Using the mocker fixture for patching. - Using the CaptureLogsContext for logging assertions. Improved test readability by: - Adding ## Arrange, ## Act, and ## Assert comments to all tests. - Splitting the `test_default_rest` function into three separate tests for better test isolation. Also includes a defensive fix to test/test_utils_new.py to handle a TypeError when calling the make_clean_stack function. --- coverage_new.txt | 44 ++++++++++++++++++++++++++++++++++++++++++++ coverage_old.txt | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 coverage_new.txt create mode 100644 coverage_old.txt diff --git a/coverage_new.txt b/coverage_new.txt new file mode 100644 index 00000000..1272343d --- /dev/null +++ b/coverage_new.txt @@ -0,0 +1,44 @@ +============================= test session starts ============================== +platform linux -- Python 3.12.12, pytest-8.4.2, pluggy-1.6.0 +rootdir: /app +configfile: pytest.ini +plugins: mock-3.15.1, cov-5.0.0 +collected 10 items + +test/integration/base/test_rest_client_i_new.py .......... [100%] + +---------- coverage: platform linux, python 3.12.12-final-0 ---------- +Name Stmts Miss Cover +------------------------------------------------------------------------- +ibind/__init__.py 13 0 100% +ibind/base/__init__.py 0 0 100% +ibind/base/queue_controller.py 18 7 61% +ibind/base/rest_client.py 152 36 76% +ibind/base/subscription_controller.py 125 104 17% +ibind/base/ws_client.py 217 184 15% +ibind/client/__init__.py 0 0 100% +ibind/client/ibkr_client.py 119 73 39% +ibind/client/ibkr_client_mixins/__init__.py 0 0 100% +ibind/client/ibkr_client_mixins/accounts_mixin.py 4 0 100% +ibind/client/ibkr_client_mixins/contract_mixin.py 25 14 44% +ibind/client/ibkr_client_mixins/marketdata_mixin.py 61 45 26% +ibind/client/ibkr_client_mixins/order_mixin.py 22 12 45% +ibind/client/ibkr_client_mixins/portfolio_mixin.py 5 0 100% +ibind/client/ibkr_client_mixins/scanner_mixin.py 5 0 100% +ibind/client/ibkr_client_mixins/session_mixin.py 39 29 26% +ibind/client/ibkr_client_mixins/watchlist_mixin.py 4 0 100% +ibind/client/ibkr_definitions.py 6 1 83% +ibind/client/ibkr_utils.py 226 131 42% +ibind/client/ibkr_ws_client.py 238 177 26% +ibind/oauth/__init__.py 26 26 0% +ibind/oauth/oauth1a.py 164 164 0% +ibind/support/__init__.py 0 0 100% +ibind/support/errors.py 4 0 100% +ibind/support/logs.py 82 13 84% +ibind/support/py_utils.py 87 63 28% +ibind/var.py 88 4 95% +------------------------------------------------------------------------- +TOTAL 1730 1083 37% + + +============================== 10 passed in 1.23s ============================== diff --git a/coverage_old.txt b/coverage_old.txt new file mode 100644 index 00000000..01bbe7f6 --- /dev/null +++ b/coverage_old.txt @@ -0,0 +1,44 @@ +============================= test session starts ============================== +platform linux -- Python 3.12.12, pytest-8.4.2, pluggy-1.6.0 +rootdir: /app +configfile: pytest.ini +plugins: mock-3.15.1, cov-5.0.0 +collected 8 items + +test/integration/base/test_rest_client_i.py ........ [100%] + +---------- coverage: platform linux, python 3.12.12-final-0 ---------- +Name Stmts Miss Cover +------------------------------------------------------------------------- +ibind/__init__.py 13 0 100% +ibind/base/__init__.py 0 0 100% +ibind/base/queue_controller.py 18 7 61% +ibind/base/rest_client.py 152 36 76% +ibind/base/subscription_controller.py 125 104 17% +ibind/base/ws_client.py 217 184 15% +ibind/client/__init__.py 0 0 100% +ibind/client/ibkr_client.py 119 73 39% +ibind/client/ibkr_client_mixins/__init__.py 0 0 100% +ibind/client/ibkr_client_mixins/accounts_mixin.py 4 0 100% +ibind/client/ibkr_client_mixins/contract_mixin.py 25 14 44% +ibind/client/ibkr_client_mixins/marketdata_mixin.py 61 45 26% +ibind/client/ibkr_client_mixins/order_mixin.py 22 12 45% +ibind/client/ibkr_client_mixins/portfolio_mixin.py 5 0 100% +ibind/client/ibkr_client_mixins/scanner_mixin.py 5 0 100% +ibind/client/ibkr_client_mixins/session_mixin.py 39 29 26% +ibind/client/ibkr_client_mixins/watchlist_mixin.py 4 0 100% +ibind/client/ibkr_definitions.py 6 1 83% +ibind/client/ibkr_utils.py 226 131 42% +ibind/client/ibkr_ws_client.py 238 177 26% +ibind/oauth/__init__.py 26 26 0% +ibind/oauth/oauth1a.py 164 164 0% +ibind/support/__init__.py 0 0 100% +ibind/support/errors.py 4 0 100% +ibind/support/logs.py 82 57 30% +ibind/support/py_utils.py 87 63 28% +ibind/var.py 88 4 95% +------------------------------------------------------------------------- +TOTAL 1730 1127 35% + + +============================== 8 passed in 1.26s =============================== From c4136177ac5f6142120c627970ea1aeefed4b4db Mon Sep 17 00:00:00 2001 From: voyz Date: Tue, 23 Dec 2025 09:54:15 +0100 Subject: [PATCH 12/31] test: migrated test_ibkr_client_i.py to pytest --- .../client/test_ibkr_client_i_new.py | 364 ++++++++++++++++++ 1 file changed, 364 insertions(+) create mode 100644 test/integration/client/test_ibkr_client_i_new.py diff --git a/test/integration/client/test_ibkr_client_i_new.py b/test/integration/client/test_ibkr_client_i_new.py new file mode 100644 index 00000000..a4fccd0a --- /dev/null +++ b/test/integration/client/test_ibkr_client_i_new.py @@ -0,0 +1,364 @@ +import datetime +from pprint import pformat +import pytest +from unittest.mock import MagicMock + +from requests import ConnectTimeout + +from ibind.base.rest_client import Result +from ibind.client.ibkr_client import IbkrClient +from ibind.client.ibkr_utils import StockQuery, filter_stocks +from ibind.support.errors import ExternalBrokerError +from ibind.support.logs import ibind_logs_initialize +from test.integration.client import ibkr_responses +from test.test_utils_new import CaptureLogsContext + + +@pytest.fixture +def client_fixture(mocker): + ibind_logs_initialize(log_to_console=True) + mocker.patch('ibind.base.rest_client.requests') + url = 'https://localhost:5000' + account_id = 'TEST_ACCOUNT_ID' + timeout = 8 + max_retries = 4 + client = IbkrClient( + url=url, + account_id=account_id, + timeout=timeout, + max_retries=max_retries, + use_session=False, + ) + data = {'Test key': 'Test value'} + response = MagicMock() + response.json.return_value = data + default_path = '/test/api/route' + default_url = f'{url}/{default_path}' + result = Result(data=data, request={'url': default_url}) + return client, response, default_path, default_url, result, timeout, max_retries + + +def test_get_conids(client_fixture, mocker): + # Arrange + client, response, _, _, _, _, _ = client_fixture + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = response + response.json.return_value = ibkr_responses.responses['stocks'] + + queries = [ + StockQuery(symbol='AAPL', contract_conditions={'isUS': False, 'exchange': 'AEQLIT'}, name_match='APPLE'), + StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), + StockQuery(symbol='CDN', contract_conditions={'isUS': False}), + StockQuery(symbol='CFC', contract_conditions={}), + StockQuery(symbol='GOOG', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={'chineseName': 'Alphabet公司'}), + 'HUBS', + StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={}), + StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), + StockQuery(symbol='SAN', name_match='SANTANDER', contract_conditions={'isUS': True}), + StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NYSE'}), + StockQuery(symbol='TEAM', name_match='ATLASSIAN'), + StockQuery(symbol='INVALID_SYMBOL') + ] + + # Act + rv = client.stock_conid_by_symbol(queries, default_filtering=False) + + # Assert + for symbol, conid in rv.data.items(): + assert symbol in ibkr_responses.responses['filtered_conids'] + assert conid == ibkr_responses.responses['filtered_conids'][symbol] + + +def test_get_conids_exception(client_fixture, mocker): + # Arrange + client, response, _, _, _, _, _ = client_fixture + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = response + response.json.return_value = ibkr_responses.responses['stocks'] + + symbol = 'AAPL' + query = StockQuery(symbol=symbol, contract_conditions={'isUS': False}, name_match='APPLE') + + instruments = filter_stocks(query, Result(data={symbol: ibkr_responses.responses['stocks'][symbol]}), default_filtering=False).data[symbol] + + # Act and Assert + with pytest.raises(RuntimeError) as excinfo: + client.stock_conid_by_symbol(query, default_filtering=False) + + assert str(excinfo.value) == f'Filtering stock "{symbol}" returned 2 instruments and 2 contracts using following query: {query}.' \ + f'\nPlease use filters to ensure that only one instrument and one contract per symbol is selected in order to avoid conid ambiguity.' \ + f'\nBe aware that contracts are filtered as {{"isUS": True}} by default. Set default_filtering=False to prevent this default filtering or specify custom filters. See inline documentation for more details.' \ + f'\nInstruments returned:\n{pformat(instruments)}' + + +def test_get_live_orders_no_filters(client_fixture): + # Arrange + client, _, _, _, result, _, _ = client_fixture + client.get = MagicMock(return_value=result) + + # Act + client.live_orders() + + # Assert + client.get.assert_called_with('iserver/account/orders', params=None) + + +def test_get_live_orders_with_valid_filters(client_fixture): + # Arrange + client, _, _, _, result, _, _ = client_fixture + client.get = MagicMock(return_value=result) + filters = ['inactive', 'filled'] + + # Act + client.live_orders(filters=filters) + + # Assert + client.get.assert_called_with('iserver/account/orders', params={'filters': 'inactive,filled'}) + + +def test_get_live_orders_with_single_filter(client_fixture): + # Arrange + client, _, _, _, result, _, _ = client_fixture + client.get = MagicMock(return_value=result) + + # Act + client.live_orders(filters='submitted') + + # Assert + client.get.assert_called_with('iserver/account/orders', params={'filters': 'submitted'}) + + +def test_get_live_orders_with_incorrect_filter_type(client_fixture): + # Arrange + client, _, _, _, result, _, _ = client_fixture + client.get = MagicMock(return_value=result) + + # Act and Assert + with pytest.raises(TypeError): + client.live_orders(filters=123) # Non-list, non-string filter + client.get.assert_not_called() + + +def _marketdata_request(method, url, *args, **kwargs): + leaf = url.split('/')[-1] + if leaf == 'stocks': + return MagicMock(json=lambda: ibkr_responses.responses['stocks']) + elif leaf == 'history': + conid = kwargs['params']['conid'] + history_by_conid = { + ibkr_responses.responses['filtered_conids'][key]: value for key, value in ibkr_responses.responses['history'].items() + } + return MagicMock(json=lambda: history_by_conid[conid]) + + +def test_marketdata_history_by_symbols(client_fixture, mocker): + # Arrange + client, _, _, _, _, _, _ = client_fixture + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.side_effect = _marketdata_request + + queries = [ + StockQuery(symbol='AAPL', contract_conditions={'isUS': False, 'exchange': 'AEQLIT'}, name_match='APPLE'), + StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), + StockQuery(symbol='CDN', contract_conditions={'isUS': False}), + StockQuery(symbol='CFC', contract_conditions={}), + StockQuery(symbol='GOOG', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={'chineseName': 'Alphabet公司'}), + StockQuery(symbol='HUBS'), + StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={}), + StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), + StockQuery(symbol='SAN', name_match='SANTANDER', contract_conditions={'isUS': True}), + StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NYSE'}), + StockQuery(symbol='TEAM', name_match='ATLASSIAN'), + ] + + expected_results = {} + for query in queries: + data = ibkr_responses.responses['history'][query.symbol]['data'][0] + output = { + 'conid': ibkr_responses.responses['filtered_conids'][query.symbol], + 'symbol': query.symbol, + 'open': data['o'], + 'high': data['h'], + 'low': data['l'], + 'close': data['c'], + 'volume': data['v'], + 'date': datetime.datetime.fromtimestamp(data['t'] / 1000, tz=datetime.timezone.utc), + } + expected_results[query.symbol] = output + + expected_errors = ['Market data for CDN is not live: Delayed', 'Market data for CFC is not live: Delayed'] + + # Act + with CaptureLogsContext('ibind', level='INFO', logger_level='DEBUG', expected_errors=expected_errors, partial_match=True) as cm: + results = client.marketdata_history_by_symbols(queries) + + # Assert + for msg in expected_errors: + assert msg in cm.output + + for symbol, expected in expected_results.items(): + result = results[symbol][-1] + assert symbol in results + assert result['open'] == pytest.approx(expected['open']) + assert result['high'] == pytest.approx(expected['high']) + assert result['low'] == pytest.approx(expected['low']) + assert result['close'] == pytest.approx(expected['close']) + assert result['volume'] == pytest.approx(expected['volume']) + assert result['date'] == expected['date'] + + +def test_check_health_authenticated_and_connected(client_fixture, mocker): + # Arrange + client, _, _, default_url, _, _, _ = client_fixture + response_data = {'iserver': {'authStatus': {'authenticated': True, 'competing': False, 'connected': True}}} + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = MagicMock(json=lambda: response_data) + client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) + + # Act + health_status = client.check_health() + + # Assert + assert health_status is True + client.tickle.assert_called_once() + + +def test_check_health_not_authenticated(client_fixture, mocker): + # Arrange + client, _, _, default_url, _, _, _ = client_fixture + response_data = {'iserver': {'authStatus': {'authenticated': False, 'competing': False, 'connected': True}}} + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = MagicMock(json=lambda: response_data) + client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) + + # Act + health_status = client.check_health() + + # Assert + assert health_status is False + + +def test_check_health_competing_connection(client_fixture, mocker): + # Arrange + client, _, _, default_url, _, _, _ = client_fixture + response_data = {'iserver': {'authStatus': {'authenticated': True, 'competing': True, 'connected': True}}} + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = MagicMock(json=lambda: response_data) + client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) + + # Act + health_status = client.check_health() + + # Assert + assert health_status is False + + +def test_check_health_connection_error(client_fixture, mocker): + # Arrange + client, _, _, _, _, _, _ = client_fixture + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.side_effect = ConnectTimeout + client.tickle = MagicMock(side_effect=ConnectTimeout) + + # Act + with CaptureLogsContext( + 'ibind.session_mixin', + level='ERROR', + expected_errors=['ConnectTimeout raised when communicating with the Gateway'], + partial_match=True, + ) as cm: + health_status = client.check_health() + + # Assert + assert health_status is False + assert 'ConnectTimeout raised when communicating with the Gateway' in cm.output[0] + + +def test_check_health_external_broker_error_unauthenticated(client_fixture, mocker): + # Arrange + client, _, _, _, _, _, _ = client_fixture + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.side_effect = ExternalBrokerError(status_code=401) + client.tickle = MagicMock(side_effect=ExternalBrokerError(status_code=401)) + + # Act + with CaptureLogsContext('ibind.session_mixin', level='INFO', expected_errors=['Gateway session is not authenticated.']) as cm: + health_status = client.check_health() + + # Assert + assert health_status is False + assert 'Gateway session is not authenticated.' in cm.output[0] + + +def test_check_health_invalid_data(client_fixture, mocker): + # Arrange + client, _, _, default_url, _, _, _ = client_fixture + response_data = {} # Invalid data format + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = MagicMock(json=lambda: response_data) + client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) + + # Act and Assert + with pytest.raises(AttributeError) as excinfo: + client.check_health() + assert 'Health check requests returns invalid data' in str(excinfo.value) + + +def test_marketdata_unsubscribe_success(client_fixture, mocker): + # Arrange + client, _, _, _, _, _, _ = client_fixture + conids = [12345, 67890] + + def post_side_effect(url, *args, **kwargs): + conid = kwargs['params']['conid'] + if conid in conids: + return Result(data={'success': True}, request={'url': url}) + raise ExternalBrokerError(status_code=404) + + client.post = MagicMock(side_effect=post_side_effect, __name__='client_post_mock') + + # Act + results = client.marketdata_unsubscribe(conids) + + # Assert + for conid, result in results.items(): + assert int(conid) in conids + assert isinstance(result, Result) + assert result.data['success'] is True + + +def test_marketdata_unsubscribe_with_error(client_fixture, mocker): + # Arrange + client, _, _, _, _, _, _ = client_fixture + conids = [12345, 67890] + + def post_side_effect(url, *args, **kwargs): + conid = kwargs['params']['conid'] + if conid == 12345: + raise ExternalBrokerError(status_code=404) + return Result(data={'success': True}, request={'url': url}) + + client.post = MagicMock(side_effect=post_side_effect, __name__='client_post_mock') + + # Act + results = client.marketdata_unsubscribe(conids) + + # Assert + assert 12345 in results + assert 67890 in results + assert results[67890].data['success'] is True + assert isinstance(results[12345], ExternalBrokerError) + + +def test_marketdata_unsubscribe_raises_exception_on_failure(client_fixture, mocker): + # Arrange + client, _, _, _, _, _, _ = client_fixture + conids = [12345] + client.post = MagicMock(side_effect=ExternalBrokerError(status_code=500), __name__='client_post_mock') + + # Act + with pytest.raises(ExternalBrokerError) as excinfo: + client.marketdata_unsubscribe(conids) + + # Assert + assert excinfo.value.status_code == 500 \ No newline at end of file From 2aa9ca4478677bc2743664e061edc738d27cc450 Mon Sep 17 00:00:00 2001 From: voyz Date: Tue, 23 Dec 2025 09:54:29 +0100 Subject: [PATCH 13/31] chore: added pytest-mock to requirements-dev.txt --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index f5fd6ce1..8e14ad3c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,4 +2,4 @@ ruff>=0.9.4,<0.10.0 bandit>=1.8.2,<2.0.0 pytest>=7.0.0,<9.0.0 pytest-cov>=4.0.0,<6.0.0 - +pytest-mock>=3.0.0,<4.0.0 \ No newline at end of file From 0c8b1ea5e47c518101a4578f16a98d4ecd7ea74a Mon Sep 17 00:00:00 2001 From: voyz Date: Tue, 23 Dec 2025 10:18:37 +0100 Subject: [PATCH 14/31] tests: refactor pytest fixtures and autouse requests mocks Split tuple fixtures into constants + granular fixtures; add autouse requests_mock with default return_value to reduce boilerplate in migrated tests. --- .../base/test_rest_client_i_new.py | 109 +++++++++-------- .../client/test_ibkr_client_i_new.py | 113 +++++++++--------- 2 files changed, 113 insertions(+), 109 deletions(-) diff --git a/test/integration/base/test_rest_client_i_new.py b/test/integration/base/test_rest_client_i_new.py index c6243da0..1feda7b8 100644 --- a/test/integration/base/test_rest_client_i_new.py +++ b/test/integration/base/test_rest_client_i_new.py @@ -14,108 +14,115 @@ from test.test_utils_new import CaptureLogsContext +_URL = 'https://localhost:5000' +_TIMEOUT = 8 +_MAX_RETRIES = 4 +_DEFAULT_PATH = 'test/api/route' + + @pytest.fixture -def client_fixture(): +def client(): ibind_logs_initialize(log_to_console=True) - url = 'https://localhost:5000' - timeout = 8 - max_retries = 4 - client = RestClient( - url=url, - timeout=timeout, - max_retries=max_retries, + return RestClient( + url=_URL, + timeout=_TIMEOUT, + max_retries=_MAX_RETRIES, use_session=False, ) - data = {'Test key': 'Test value'} + + +@pytest.fixture +def data(): + return {'Test key': 'Test value'} + + +@pytest.fixture +def response(data): response = MagicMock() response.json.return_value = data - default_path = 'test/api/route' - default_url = f'{url}/{default_path}' - result = Result(data=data, request={'url': default_url}) - return client, response, default_path, default_url, result, timeout, max_retries + return response -def test_default_rest_get(client_fixture, mocker): - # Arrange - client, response, default_path, default_url, result, timeout, _ = client_fixture +@pytest.fixture(autouse=True) +def requests_mock(mocker, response): requests_mock = mocker.patch('ibind.base.rest_client.requests') requests_mock.request.return_value = response + return requests_mock + + +@pytest.fixture +def default_url(): + return f'{_URL}/{_DEFAULT_PATH}' + + +@pytest.fixture +def result(data, default_url): + return Result(data=data, request={'url': default_url}) + +def test_default_rest_get(client, default_url, result, requests_mock): + # Arrange # Act - rv = client.get(default_path) + rv = client.get(_DEFAULT_PATH) # Assert assert result == rv - requests_mock.request.assert_called_with('GET', default_url, verify=False, headers={}, timeout=timeout) + requests_mock.request.assert_called_with('GET', default_url, verify=False, headers={}, timeout=_TIMEOUT) -def test_default_rest_post(client_fixture, mocker): +def test_default_rest_post(client, default_url, result, requests_mock): # Arrange - client, response, default_path, default_url, result, timeout, _ = client_fixture - requests_mock = mocker.patch('ibind.base.rest_client.requests') - requests_mock.request.return_value = response test_post_kwargs = {'field1': 'value1', 'field2': 'value2'} test_json = {'json': {**test_post_kwargs}} # Act - rv = client.post(default_path, params=test_post_kwargs) + rv = client.post(_DEFAULT_PATH, params=test_post_kwargs) # Assert assert result.copy(request={'url': default_url, **test_json}) == rv - requests_mock.request.assert_called_with('POST', default_url, verify=False, headers={}, timeout=timeout, **test_json) + requests_mock.request.assert_called_with('POST', default_url, verify=False, headers={}, timeout=_TIMEOUT, **test_json) -def test_default_rest_delete(client_fixture, mocker): +def test_default_rest_delete(client, default_url, result, requests_mock): # Arrange - client, response, default_path, default_url, result, timeout, _ = client_fixture - requests_mock = mocker.patch('ibind.base.rest_client.requests') - requests_mock.request.return_value = response - # Act - rv = client.delete(default_path) + rv = client.delete(_DEFAULT_PATH) # Assert assert result == rv - requests_mock.request.assert_called_with('DELETE', default_url, verify=False, headers={}, timeout=timeout) + requests_mock.request.assert_called_with('DELETE', default_url, verify=False, headers={}, timeout=_TIMEOUT) -def test_request_retries(client_fixture, mocker): +def test_request_retries(client, default_url, requests_mock): # Arrange - client, _, default_path, default_url, _, _, max_retries = client_fixture - requests_mock = mocker.patch('ibind.base.rest_client.requests') requests_mock.request.side_effect = ReadTimeout() # Act with CaptureLogsContext('ibind.rest_client', level='INFO') as cm, pytest.raises(TimeoutError) as excinfo: - client.get(default_path) + client.get(_DEFAULT_PATH) # Assert - for i in range(max_retries): - assert f'RestClient: Timeout for GET {default_url} {{}}, retrying attempt {i + 1}/{max_retries}' in cm.output + for i in range(_MAX_RETRIES): + assert f'RestClient: Timeout for GET {default_url} {{}}, retrying attempt {i + 1}/{_MAX_RETRIES}' in cm.output - assert f'RestClient: Reached max retries ({max_retries}) for GET {default_url} {{}}' == str(excinfo.value) + assert f'RestClient: Reached max retries ({_MAX_RETRIES}) for GET {default_url} {{}}' == str(excinfo.value) -def test_response_raise_timeout(client_fixture, mocker): +def test_response_raise_timeout(client, requests_mock): # Arrange - client, response, default_path, _, _, timeout, _ = client_fixture - requests_mock = mocker.patch('ibind.base.rest_client.requests') - requests_mock.request.return_value = response - response.raise_for_status.side_effect = Timeout() + requests_mock.request.return_value.raise_for_status.side_effect = Timeout() # Act with pytest.raises(ExternalBrokerError) as excinfo: - client.get(default_path) + client.get(_DEFAULT_PATH) # Assert - assert f'RestClient: Timeout error ({timeout}S)' == str(excinfo.value) + assert f'RestClient: Timeout error ({_TIMEOUT}S)' == str(excinfo.value) -def test_response_raise_generic(client_fixture, mocker): +def test_response_raise_generic(client, result, requests_mock): # Arrange - client, response, default_path, _, result, _, _ = client_fixture - requests_mock = mocker.patch('ibind.base.rest_client.requests') - requests_mock.request.return_value = response + response = requests_mock.request.return_value response.status_code = 400 response.reason = 'Test reason' response.text = 'Test text' @@ -123,7 +130,7 @@ def test_response_raise_generic(client_fixture, mocker): # Act with pytest.raises(ExternalBrokerError) as excinfo: - client.get(default_path) + client.get(_DEFAULT_PATH) # Assert assert f'RestClient: response error {result.copy(data=None)} :: {response.status_code} :: {response.reason} :: {response.text}' == str(excinfo.value) @@ -211,4 +218,4 @@ def test_without_thread_async(): # Assert for result in results: if isinstance(result, Exception): - raise result + raise result \ No newline at end of file diff --git a/test/integration/client/test_ibkr_client_i_new.py b/test/integration/client/test_ibkr_client_i_new.py index a4fccd0a..f95f35c1 100644 --- a/test/integration/client/test_ibkr_client_i_new.py +++ b/test/integration/client/test_ibkr_client_i_new.py @@ -14,35 +14,56 @@ from test.test_utils_new import CaptureLogsContext +_URL = 'https://localhost:5000' +_TIMEOUT = 8 +_MAX_RETRIES = 4 +_DEFAULT_PATH = '/test/api/route' +_ACCOUNT_ID = 'TEST_ACCOUNT_ID' + + @pytest.fixture -def client_fixture(mocker): +def client(): ibind_logs_initialize(log_to_console=True) - mocker.patch('ibind.base.rest_client.requests') - url = 'https://localhost:5000' - account_id = 'TEST_ACCOUNT_ID' - timeout = 8 - max_retries = 4 - client = IbkrClient( - url=url, - account_id=account_id, - timeout=timeout, - max_retries=max_retries, + return IbkrClient( + url=_URL, + account_id=_ACCOUNT_ID, + timeout=_TIMEOUT, + max_retries=_MAX_RETRIES, use_session=False, ) - data = {'Test key': 'Test value'} + + +@pytest.fixture +def data(): + return {'Test key': 'Test value'} + + +@pytest.fixture +def response(data): response = MagicMock() response.json.return_value = data - default_path = '/test/api/route' - default_url = f'{url}/{default_path}' - result = Result(data=data, request={'url': default_url}) - return client, response, default_path, default_url, result, timeout, max_retries + return response -def test_get_conids(client_fixture, mocker): - # Arrange - client, response, _, _, _, _, _ = client_fixture +@pytest.fixture(autouse=True) +def requests_mock(mocker, response): requests_mock = mocker.patch('ibind.base.rest_client.requests') requests_mock.request.return_value = response + return requests_mock + + +@pytest.fixture +def default_url(): + return f'{_URL}/{_DEFAULT_PATH}' + + +@pytest.fixture +def result(data, default_url): + return Result(data=data, request={'url': default_url}) + + +def test_get_conids(client, response): + # Arrange response.json.return_value = ibkr_responses.responses['stocks'] queries = [ @@ -69,11 +90,8 @@ def test_get_conids(client_fixture, mocker): assert conid == ibkr_responses.responses['filtered_conids'][symbol] -def test_get_conids_exception(client_fixture, mocker): +def test_get_conids_exception(client, response): # Arrange - client, response, _, _, _, _, _ = client_fixture - requests_mock = mocker.patch('ibind.base.rest_client.requests') - requests_mock.request.return_value = response response.json.return_value = ibkr_responses.responses['stocks'] symbol = 'AAPL' @@ -91,9 +109,8 @@ def test_get_conids_exception(client_fixture, mocker): f'\nInstruments returned:\n{pformat(instruments)}' -def test_get_live_orders_no_filters(client_fixture): +def test_get_live_orders_no_filters(client, result): # Arrange - client, _, _, _, result, _, _ = client_fixture client.get = MagicMock(return_value=result) # Act @@ -103,9 +120,8 @@ def test_get_live_orders_no_filters(client_fixture): client.get.assert_called_with('iserver/account/orders', params=None) -def test_get_live_orders_with_valid_filters(client_fixture): +def test_get_live_orders_with_valid_filters(client, result): # Arrange - client, _, _, _, result, _, _ = client_fixture client.get = MagicMock(return_value=result) filters = ['inactive', 'filled'] @@ -116,9 +132,8 @@ def test_get_live_orders_with_valid_filters(client_fixture): client.get.assert_called_with('iserver/account/orders', params={'filters': 'inactive,filled'}) -def test_get_live_orders_with_single_filter(client_fixture): +def test_get_live_orders_with_single_filter(client, result): # Arrange - client, _, _, _, result, _, _ = client_fixture client.get = MagicMock(return_value=result) # Act @@ -128,9 +143,8 @@ def test_get_live_orders_with_single_filter(client_fixture): client.get.assert_called_with('iserver/account/orders', params={'filters': 'submitted'}) -def test_get_live_orders_with_incorrect_filter_type(client_fixture): +def test_get_live_orders_with_incorrect_filter_type(client, result): # Arrange - client, _, _, _, result, _, _ = client_fixture client.get = MagicMock(return_value=result) # Act and Assert @@ -151,10 +165,8 @@ def _marketdata_request(method, url, *args, **kwargs): return MagicMock(json=lambda: history_by_conid[conid]) -def test_marketdata_history_by_symbols(client_fixture, mocker): +def test_marketdata_history_by_symbols(client, requests_mock): # Arrange - client, _, _, _, _, _, _ = client_fixture - requests_mock = mocker.patch('ibind.base.rest_client.requests') requests_mock.request.side_effect = _marketdata_request queries = [ @@ -207,11 +219,9 @@ def test_marketdata_history_by_symbols(client_fixture, mocker): assert result['date'] == expected['date'] -def test_check_health_authenticated_and_connected(client_fixture, mocker): +def test_check_health_authenticated_and_connected(client, default_url, requests_mock): # Arrange - client, _, _, default_url, _, _, _ = client_fixture response_data = {'iserver': {'authStatus': {'authenticated': True, 'competing': False, 'connected': True}}} - requests_mock = mocker.patch('ibind.base.rest_client.requests') requests_mock.request.return_value = MagicMock(json=lambda: response_data) client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) @@ -223,11 +233,9 @@ def test_check_health_authenticated_and_connected(client_fixture, mocker): client.tickle.assert_called_once() -def test_check_health_not_authenticated(client_fixture, mocker): +def test_check_health_not_authenticated(client, default_url, requests_mock): # Arrange - client, _, _, default_url, _, _, _ = client_fixture response_data = {'iserver': {'authStatus': {'authenticated': False, 'competing': False, 'connected': True}}} - requests_mock = mocker.patch('ibind.base.rest_client.requests') requests_mock.request.return_value = MagicMock(json=lambda: response_data) client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) @@ -238,11 +246,9 @@ def test_check_health_not_authenticated(client_fixture, mocker): assert health_status is False -def test_check_health_competing_connection(client_fixture, mocker): +def test_check_health_competing_connection(client, default_url, requests_mock): # Arrange - client, _, _, default_url, _, _, _ = client_fixture response_data = {'iserver': {'authStatus': {'authenticated': True, 'competing': True, 'connected': True}}} - requests_mock = mocker.patch('ibind.base.rest_client.requests') requests_mock.request.return_value = MagicMock(json=lambda: response_data) client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) @@ -253,10 +259,8 @@ def test_check_health_competing_connection(client_fixture, mocker): assert health_status is False -def test_check_health_connection_error(client_fixture, mocker): +def test_check_health_connection_error(client, requests_mock): # Arrange - client, _, _, _, _, _, _ = client_fixture - requests_mock = mocker.patch('ibind.base.rest_client.requests') requests_mock.request.side_effect = ConnectTimeout client.tickle = MagicMock(side_effect=ConnectTimeout) @@ -274,10 +278,8 @@ def test_check_health_connection_error(client_fixture, mocker): assert 'ConnectTimeout raised when communicating with the Gateway' in cm.output[0] -def test_check_health_external_broker_error_unauthenticated(client_fixture, mocker): +def test_check_health_external_broker_error_unauthenticated(client, requests_mock): # Arrange - client, _, _, _, _, _, _ = client_fixture - requests_mock = mocker.patch('ibind.base.rest_client.requests') requests_mock.request.side_effect = ExternalBrokerError(status_code=401) client.tickle = MagicMock(side_effect=ExternalBrokerError(status_code=401)) @@ -290,11 +292,9 @@ def test_check_health_external_broker_error_unauthenticated(client_fixture, mock assert 'Gateway session is not authenticated.' in cm.output[0] -def test_check_health_invalid_data(client_fixture, mocker): +def test_check_health_invalid_data(client, default_url, requests_mock): # Arrange - client, _, _, default_url, _, _, _ = client_fixture response_data = {} # Invalid data format - requests_mock = mocker.patch('ibind.base.rest_client.requests') requests_mock.request.return_value = MagicMock(json=lambda: response_data) client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) @@ -304,9 +304,8 @@ def test_check_health_invalid_data(client_fixture, mocker): assert 'Health check requests returns invalid data' in str(excinfo.value) -def test_marketdata_unsubscribe_success(client_fixture, mocker): +def test_marketdata_unsubscribe_success(client, mocker): # Arrange - client, _, _, _, _, _, _ = client_fixture conids = [12345, 67890] def post_side_effect(url, *args, **kwargs): @@ -327,9 +326,8 @@ def post_side_effect(url, *args, **kwargs): assert result.data['success'] is True -def test_marketdata_unsubscribe_with_error(client_fixture, mocker): +def test_marketdata_unsubscribe_with_error(client, mocker): # Arrange - client, _, _, _, _, _, _ = client_fixture conids = [12345, 67890] def post_side_effect(url, *args, **kwargs): @@ -350,9 +348,8 @@ def post_side_effect(url, *args, **kwargs): assert isinstance(results[12345], ExternalBrokerError) -def test_marketdata_unsubscribe_raises_exception_on_failure(client_fixture, mocker): +def test_marketdata_unsubscribe_raises_exception_on_failure(client, mocker): # Arrange - client, _, _, _, _, _, _ = client_fixture conids = [12345] client.post = MagicMock(side_effect=ExternalBrokerError(status_code=500), __name__='client_post_mock') From 1c8fa0ae1d7567d5ae4e8aa574a3e7a4bf035f7a Mon Sep 17 00:00:00 2001 From: voyz Date: Tue, 23 Dec 2025 10:56:43 +0100 Subject: [PATCH 15/31] test: updated the migration plan --- test/migration_plan.md | 75 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 66 insertions(+), 9 deletions(-) diff --git a/test/migration_plan.md b/test/migration_plan.md index 90f99bdb..0f61570e 100644 --- a/test/migration_plan.md +++ b/test/migration_plan.md @@ -7,33 +7,90 @@ This document outlines the roadmap for migrating our existing `unittest`-based t When migrating tests, please adhere to the following principles: - **New Test Files:** To compare test coverage before and after the migration, create a new test file for the migrated tests. For example, `test/integration/base/test_rest_client_i.py` should be migrated to `test/integration/base/test_rest_client_i_new.py`. +- **Post-migration check:** run the old and new test files *separately* with `--cov= --cov-report=term-missing` and confirm the covered/missing lines are identical (or document any differences). - **Test Classes:** Convert `unittest.TestCase` subclasses into plain test functions. If a class structure is still beneficial for grouping related tests, you can use a class without inheriting from `unittest.TestCase`. - **`setUp` and `tearDown`:** Replace `setUp` and `tearDown` methods with `pytest` fixtures. - **Assertions:** Convert all `self.assert...` methods to plain `assert` statements. For example, `self.assertEqual(a, b)` becomes `assert a == b`. - **Exception Handling:** Replace `with self.assertRaises(...)` with `with pytest.raises(...)`. - **Logging:** Use the new `capture_logs` utility from `test_utils_new.py`. It can be used as a context manager (`with capture_logs(...) as cm:`) or as a decorator (`@capture_logs(...)`). This replaces all previous `unittest`-based logging helpers. The returned watcher object has methods like `exact_log`, `partial_log`, and `log_excludes` for assertions. -- **Arrange, Act, Assert:** Structure your tests using the Arrange, Act, Assert pattern. +- **Arrange, Act, Assert:** Structure your tests using the ##Arrange, ##Act, ##Assert pattern. - **Parametrization:** Use `@pytest.mark.parametrize` to run the same test with different inputs. +## Additional Rules (learned from first few migrations) + +The following rules help avoid common migration pitfalls and reduce boilerplate. See: + +- `test/integration/base/test_rest_client_i_new.py` +- `test/integration/client/test_ibkr_client_i_new.py` + +### Fixtures and constants + +- **Prefer module constants for stable configuration** + - Put stable values such as `_URL`, `_TIMEOUT`, `_DEFAULT_PATH`, `_MAX_RETRIES` at module scope. + - Keep fixtures focused on objects with lifecycle/state (clients, mocks, results). + +- **Avoid “mega fixtures” that return tuples** + - If a `setUp` method created many objects, migrate it into multiple fixtures. + +### Patching (replacing class-level @patch) + +- **Use an autouse `requests_mock` fixture for common patching** + - When the original unittest test patched a whole `TestCase` class (e.g. `@patch('...requests')`), replicate it with a single `@pytest.fixture(autouse=True)`. + + Example pattern: + + ```python + @pytest.fixture(autouse=True) + def requests_mock(mocker, response): + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = response + return requests_mock + ``` + + Tests can still override behavior locally: + + - `requests_mock.request.side_effect = ReadTimeout()` + - `requests_mock.request.return_value = MagicMock(...)` + +### Preserve unittest semantics + +- **Float comparisons** + - `self.assertAlmostEqual(...)` should migrate to `pytest.approx(...)`. + +- **Logging expectations** + - Do not assert *more* than the unittest test asserted. + - If unittest checked a substring (e.g. `assertIn`), migrate to `partial_match=True` or explicit substring checks. + +- **Exceptions vs return values** + - Verify whether the production code *raises* or *returns* exceptions. + - A common pitfall is migrating a test to “return exception in results” when the implementation actually raises (or ignores) specific errors. + +- **Key types / coercions** + - Be careful with dict keys and parameter conversions. + - If production code casts IDs (e.g. `int(conid)`), results may be keyed by `int` even if the input looked like a string. + +- **Naming parity** + - Keep test names close to the original unittest names to make 1:1 mapping and review easier. + ## Migration Chunks The following files need to be migrated. Each file can be worked on independently. --- -### 1. [ ] `test/integration/base/test_rest_client_i.py` +### 1. [✔] `test/integration/base/test_rest_client_i.py` - **Migration Steps:** 1. Create a new file: `test/integration/base/test_rest_client_i_new.py`. 2. In the new file, convert all `TestCase` subclasses into simple test functions. - 3. Replace the `setUp` method's logic with a `pytest` fixture. + 3. Replace the `setUp` method's logic with granular fixtures and module constants (avoid tuple-returning fixtures). 4. Convert all `self.assert...` calls and `with self.assertRaises` to `assert` and `with pytest.raises(...)`. 5. Replace `with self.assertLogs(...)` with the `capture_logs` context manager from `test_utils_new.py`. - 6. Refactor the class-level patch to use the `mocker` fixture within each test function. + 6. Refactor the class-level patch into an autouse fixture (e.g. `requests_mock`) so tests don't repeat patch boilerplate. --- -### 2. [ ] `test/integration/base/test_websocket_client_i.py` +### 2. [] `test/integration/base/test_websocket_client_i.py` - **Migration Steps:** 1. Create a new file: `test/integration/base/test_websocket_client_i_new.py`. @@ -44,15 +101,15 @@ The following files need to be migrated. Each file can be worked on independentl --- -### 3. [ ] `test/integration/client/test_ibkr_client_i.py` +### 3. [✔] `test/integration/client/test_ibkr_client_i.py` - **Migration Steps:** 1. Create a new file: `test/integration/client/test_ibkr_client_i_new.py`. 2. In the new file, convert the class into a series of test functions. - 3. Move the `setUp` logic into a `pytest` fixture. + 3. Move the `setUp` logic into granular fixtures and module constants (avoid tuple-returning fixtures). 4. Replace all `self.assert...` calls with plain `assert` statements and `pytest.raises`. 5. Replace the `SafeAssertLogs` and `RaiseLogsContext` with the `capture_logs` utility from `test_utils_new.py`. - 6. Handle the class-level patch using the `mocker` fixture. + 6. Handle the class-level patch using an autouse fixture (e.g. `requests_mock`) so tests don't repeat patch boilerplate. --- @@ -85,4 +142,4 @@ The following files need to be migrated. Each file can be worked on independentl 2. In the new file, convert all three classes into separate sets of test functions. 3. Move the `setUp` method into a fixture. 4. Convert all `self.assert...` methods and `with self.assertRaises` to plain `assert` statements and `with pytest.raises(...)`. - 5. Replace the `@patch` decorator with the `mocker` fixture. + 5. Replace the `@patch` decorator with the `mocker` fixture. \ No newline at end of file From 9c77b81d3c5ccfea425f97c09d1e3efcf60fc88f Mon Sep 17 00:00:00 2001 From: voyz Date: Tue, 23 Dec 2025 11:27:51 +0100 Subject: [PATCH 16/31] test: migrated test_py_utils_u.py to pytest --- test/unit/support/test_py_utils_u_new.py | 228 +++++++++++++++++++++++ 1 file changed, 228 insertions(+) create mode 100644 test/unit/support/test_py_utils_u_new.py diff --git a/test/unit/support/test_py_utils_u_new.py b/test/unit/support/test_py_utils_u_new.py new file mode 100644 index 00000000..5fb4a250 --- /dev/null +++ b/test/unit/support/test_py_utils_u_new.py @@ -0,0 +1,228 @@ +import time +from unittest.mock import MagicMock + +import pytest + +from ibind.support.py_utils import ensure_list_arg, execute_in_parallel, execute_with_key, wait_until + + +@ensure_list_arg('arg') +def sample_function(arg): + return arg + + +def test_ensure_list_arg_with_list(): + """Wraps list args without altering the list.""" + # Arrange + input_arg = [1, 2, 3] + + # Act + result = sample_function(input_arg) + + # Assert + assert result == input_arg + + +def test_ensure_list_arg_with_non_list(): + """Wraps a non-list arg into a single-item list.""" + # Arrange + input_arg = 1 + + # Act + result = sample_function(input_arg) + + # Assert + assert result == [input_arg] + + +def test_ensure_list_arg_with_keyword_arg_list(): + """Preserves list input when passed as a keyword arg.""" + # Arrange + input_arg = [1, 2, 3] + + # Act + result = sample_function(arg=input_arg) + + # Assert + assert result == input_arg + + +def test_ensure_list_arg_with_keyword_arg_non_list(): + """Wraps a non-list keyword arg into a single-item list.""" + # Arrange + input_arg = 1 + + # Act + result = sample_function(arg=input_arg) + + # Assert + assert result == [input_arg] + + +def test_ensure_list_arg_with_missing_arg(): + """Raises TypeError when the decorated arg is missing.""" + # Arrange + + # Act / Assert + with pytest.raises(TypeError): + sample_function() + + +@pytest.fixture +def parallel_setup(): + state = {'delay': 0} + + def _func(v1, v2): + if v1 == 1: + time.sleep(state['delay']) + return 'result1' + elif v2 == 2: + return 'result2' + else: + return 'unknown' + + func = MagicMock(side_effect=_func) + func.__name__ = 'TEST_FUNCTION' + requests_dict = {'req1': {'args': [1, 0], 'kwargs': {}}, 'req2': {'args': [0], 'kwargs': {'v2': 2}}} + requests_list = [{'args': [1, 0], 'kwargs': {}}, {'args': [0], 'kwargs': {'v2': 2}}] + + return { + 'state': state, + 'func': func, + 'requests_dict': requests_dict, + 'requests_list': requests_list, + } + + +def test_execute_in_parallel_with_dict(parallel_setup): + """Executes requests in parallel when passed a dict of requests.""" + # Arrange + func = parallel_setup['func'] + requests = parallel_setup['requests_dict'] + + # Act + results = execute_in_parallel(func, requests) + + # Assert + assert results == {'req1': 'result1', 'req2': 'result2'} + assert func.call_count == 2 + + +def test_execute_in_parallel_with_list(parallel_setup): + """Executes requests in parallel when passed a list of requests.""" + # Arrange + func = parallel_setup['func'] + requests = parallel_setup['requests_list'] + parallel_setup['state']['delay'] = 0.1 + + # Act + results = execute_in_parallel(func, requests) + + # Assert + assert results == ['result1', 'result2'] + assert func.call_count == 2 + + +def test_execute_with_key_success(parallel_setup): + """Returns (key, result) when the wrapped function succeeds.""" + # Arrange + func = parallel_setup['func'] + + # Act + result = execute_with_key('key', func, 1, v2=2) + + # Assert + func.assert_called_with(1, v2=2) + assert result == ('key', 'result1') + + +def test_execute_with_key_exception(parallel_setup): + """Returns (key, exception) when the wrapped function raises.""" + # Arrange + func = parallel_setup['func'] + func.side_effect = Exception('error') + + # Act + result = execute_with_key('key', func, 1, v2=2) + + # Assert + assert isinstance(result[1], Exception) + + +def test_execute_in_parallel_rate_limiting(): + """Applies max_per_second rate limiting across parallel executions.""" + # Arrange + start_time = time.time() + + # Simulate a slow function to test rate limiting + def slow_func(): + time.sleep(0.05) + return 'slow_result' + + requests = {i: {'args': [], 'kwargs': {}} for i in range(20)} # 10 requests + max_per_second = 10 # Limit to 5 requests per second + + # Act + results = execute_in_parallel(slow_func, requests, max_per_second=max_per_second) + + # Assert + duration = time.time() - start_time + assert duration >= 1.05 # Should take at least 1.1 seconds to complete all requests + assert len(results) == 20 + + +def test_wait_until_condition_met(): + """Returns True immediately when the condition is already met.""" + # Arrange + condition = MagicMock(return_value=True) + + # Act + result = wait_until(condition) + + # Assert + assert result is True + condition.assert_called() + + +def test_wait_until_condition_not_met(): + """Returns False when the condition is not met before timeout.""" + # Arrange + condition = MagicMock(return_value=False) + + # Act + result = wait_until(condition, timeout=0.1) + + # Assert + assert result is False + condition.assert_called() + + +def test_wait_until_timeout_message(mocker): + """Logs the timeout_message when the deadline is reached.""" + # Arrange + mock_logger_error = mocker.patch('ibind.support.py_utils._LOGGER.error') + condition = MagicMock(return_value=False) + timeout_message = 'Condition not met within timeout' + + # Act + result = wait_until(condition, timeout_message=timeout_message, timeout=0.1) + + # Assert + assert result is False + mock_logger_error.assert_called_with(timeout_message) + + +def test_wait_until_timeout(): + """Waits roughly the specified timeout duration before returning False.""" + # Arrange + start_time = time.time() + condition = MagicMock(return_value=False) + timeout = 0.1 + + # Act + result = wait_until(condition, timeout=timeout) + + # Assert + assert result is False + duration = time.time() - start_time + assert duration == pytest.approx(timeout, abs=0.02) \ No newline at end of file From 95a6c3e8d07a610c6a9b40c97e4bbc002cf4033f Mon Sep 17 00:00:00 2001 From: voyz Date: Tue, 23 Dec 2025 12:05:51 +0100 Subject: [PATCH 17/31] test: migrated test_ibkr_utils_i.py to pytest --- .../client/test_ibkr_utils_i_new.py | 394 ++++++++++++++++++ 1 file changed, 394 insertions(+) create mode 100644 test/integration/client/test_ibkr_utils_i_new.py diff --git a/test/integration/client/test_ibkr_utils_i_new.py b/test/integration/client/test_ibkr_utils_i_new.py new file mode 100644 index 00000000..2eb92395 --- /dev/null +++ b/test/integration/client/test_ibkr_utils_i_new.py @@ -0,0 +1,394 @@ +from pprint import pformat +from unittest.mock import MagicMock, call + +import pytest + +from ibind.base.rest_client import Result +from ibind.client.ibkr_utils import ( + StockQuery, + filter_stocks, + find_answer, + QuestionType, + handle_questions, + question_type_to_message_id, + OrderRequest, + parse_order_request, +) +from test.integration.client import ibkr_responses +from test.test_utils_new import CaptureLogsContext + + +# -------------------------------------------------------------------------------------- +# Stock filtering +# -------------------------------------------------------------------------------------- + + +@pytest.fixture +def instruments(): + return ibkr_responses.responses['stocks'] + + +@pytest.fixture +def instruments_result(instruments): + return Result(data=instruments) + + +def test_filter_stocks(instruments, instruments_result): + """Filters instruments for multiple stock queries and logs missing symbols.""" + ## Arrange + queries = [ + StockQuery(symbol='AAPL', contract_conditions={'isUS': False}, name_match='APPLE'), + StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), + StockQuery(symbol='CDN', contract_conditions={'isUS': True}), + StockQuery(symbol='CFC', contract_conditions={}), + StockQuery( + symbol='GOOG', + contract_conditions={'isUS': False}, + instrument_conditions={'chineseName': 'Alphabet公司'}, + ), + 'HUBS', + StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False}, instrument_conditions={}), + StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), + StockQuery(symbol='SAN', name_match='SANTANDER'), + StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NASDAQ'}), + StockQuery(symbol='TEAM', name_match='ATLASSIAN'), + StockQuery(symbol='INVALID_SYMBOL'), + ] # fmt: skip + + ## Act + with CaptureLogsContext('ibind', level='INFO', error_level='CRITICAL', attach_stack=False) as cm: + rv = filter_stocks(queries, instruments_result, default_filtering=False) + + ## Assert + expected_error = ( + f'Error getting stocks. Could not find valid instruments INVALID_SYMBOL in result: {instruments_result}. ' + f'Skipping query={queries[-1]}.' + ) + assert expected_error in cm.output + + assert [ + { + 'assetClass': 'STK', + 'chineseName': '苹果公司', + 'contracts': [ + {'conid': 38708077, 'exchange': 'MEXI', 'isUS': False}, + {'conid': 273982664, 'exchange': 'EBS', 'isUS': False}, + ], + 'name': 'APPLE INC', + }, + { + 'assetClass': 'STK', + 'chineseName': '苹果公司', + 'contracts': [{'conid': 532640894, 'exchange': 'AEQLIT', 'isUS': False}], + 'name': 'APPLE INC-CDR', + }, + ] == rv.data['AAPL'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': '西班牙对外银行', + 'contracts': [{'conid': 4815, 'exchange': 'NYSE', 'isUS': True}], + 'name': 'BANCO BILBAO VIZCAYA-SP ADR', + }, + ] == rv.data['BBVA'] + + assert [] == rv.data['CDN'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': None, + 'contracts': [{'conid': 42001300, 'exchange': 'IBIS', 'isUS': False}], + 'name': 'UET UNITED ELECTRONIC TECHNO', + } + ] == rv.data['CFC'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': 'Alphabet公司', + 'contracts': [ + {'conid': 210810667, 'exchange': 'MEXI', 'isUS': False}, + ], + 'name': 'ALPHABET INC-CL C', + }, + { + 'assetClass': 'STK', + 'chineseName': 'Alphabet公司', + 'contracts': [{'conid': 532638805, 'exchange': 'AEQLIT', 'isUS': False}], + 'name': 'ALPHABET INC - CDR', + }, + ] == rv.data['GOOG'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': 'HubSpot公司', + 'contracts': [{'conid': 169544810, 'exchange': 'NYSE', 'isUS': True}], + 'name': 'HUBSPOT INC', + } + ] == rv.data['HUBS'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': 'Meta平台股份有限公司', + 'contracts': [ + {'conid': 114922621, 'exchange': 'MEXI', 'isUS': False}, + ], + 'name': 'META PLATFORMS INC-CLASS A', + }, + { + 'assetClass': 'STK', + 'chineseName': 'Meta平台股份有限公司', + 'contracts': [{'conid': 530091499, 'exchange': 'AEQLIT', 'isUS': False}], + 'name': 'META PLATFORMS INC-CDR', + }, + ] == rv.data['META'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': '微软公司', + 'contracts': [ + {'conid': 272093, 'exchange': 'NASDAQ', 'isUS': True}, + ], + 'name': 'MICROSOFT CORP', + }, + ] == rv.data['MSFT'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': '桑坦德', + 'contracts': [ + {'conid': 38708867, 'exchange': 'MEXI', 'isUS': False}, + {'conid': 385055564, 'exchange': 'WSE', 'isUS': False}, + ], + 'name': 'BANCO SANTANDER SA', + }, + { + 'assetClass': 'STK', + 'chineseName': '桑坦德', + 'contracts': [{'conid': 12442, 'exchange': 'NYSE', 'isUS': True}], + 'name': 'BANCO SANTANDER SA-SPON ADR', + }, + { + 'assetClass': 'STK', + 'chineseName': '桑坦德英国公共有限公司', + 'contracts': [{'conid': 80993135, 'exchange': 'LSE', 'isUS': False}], + 'name': 'SANTANDER UK PLC', + }, + ] == rv.data['SAN'] + + assert [] == rv.data['SCHW'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': None, + 'contracts': [{'conid': 589316251, 'exchange': 'NASDAQ', 'isUS': True}], + 'name': 'ATLASSIAN CORP-CL A', + }, + ] == rv.data['TEAM'] + + +def test_question_type_to_message_id_successful(): + """Maps a QuestionType to its expected IBKR message id.""" + ## Arrange + question_type = QuestionType.PRICE_PERCENTAGE_CONSTRAINT + + ## Act + message_id = question_type_to_message_id(question_type) + + ## Assert + assert message_id == 'o163' + + +# -------------------------------------------------------------------------------------- +# Finding answers +# -------------------------------------------------------------------------------------- + + +@pytest.fixture +def answers(): + return {QuestionType.PRICE_PERCENTAGE_CONSTRAINT: True} + + +def test_valid_question(answers): + """Returns True when a known question type is found in the question string.""" + ## Arrange + question = f'Some {QuestionType.PRICE_PERCENTAGE_CONSTRAINT} specific question' + + ## Act + answer = find_answer(question, answers) + + ## Assert + assert answer is True + + +def test_invalid_question(answers): + """Raises when no answer matches the provided question string.""" + ## Arrange + question = 'Nonexistent question type' + + ## Act & Assert + with pytest.raises(ValueError): + find_answer(question, answers) + + +# -------------------------------------------------------------------------------------- +# Handling interactive questions +# -------------------------------------------------------------------------------------- + + +@pytest.fixture +def original_result(): + return Result( + data=[{'id': '12345', 'message': ['price exceeds the Percentage constraint of 3%.']}], + request={'url': 'test_url'}, + ) + + +@pytest.fixture +def reply_callback(): + return MagicMock() + + +def test_successful_handling(mocker, original_result, reply_callback): + """Replies to a sequence of questions and returns the final result.""" + ## Arrange + question_type_mock = mocker.patch('ibind.client.ibkr_utils.QuestionType') + + question_type_mock.PRICE_PERCENTAGE_CONSTRAINT.__str__.return_value = 'price exceeds the Percentage constraint of 3%.' + question_type_mock.ADDITIONAL_QUESTION_TYPE.__str__.return_value = 'This is an additional question.' + + answers = {question_type_mock.PRICE_PERCENTAGE_CONSTRAINT: True, question_type_mock.ADDITIONAL_QUESTION_TYPE: True} + + replies = [ + Result(data=[{'id': '12346', 'message': ['This is an additional question.']}], request={'url': 'another_question_url'}), + Result(data=[{'id': '12347'}], request={'url': 'final_url'}), + ] + reply_callback.side_effect = replies + + ## Act + result = handle_questions(original_result, answers, reply_callback) + + ## Assert + assert result.request['url'] == original_result.request['url'] + assert len(reply_callback.call_args_list) == 2 + + expected_calls = [ + call(original_result.data[0]['id'], answers[question_type_mock.PRICE_PERCENTAGE_CONSTRAINT]), + call(replies[0].data[0]['id'], answers[question_type_mock.ADDITIONAL_QUESTION_TYPE]), + ] + + assert expected_calls == reply_callback.call_args_list + + +def test_too_many_questions(original_result, answers, reply_callback): + """Raises when the question loop exceeds the maximum number of attempts.""" + ## Arrange + reply_callback.side_effect = [original_result] * 21 + + ## Act & Assert + with pytest.raises(RuntimeError) as cm_err: + handle_questions(original_result, answers, reply_callback) + + assert 'Too many questions' in str(cm_err.value) + + +def test_negative_reply(original_result, answers, reply_callback): + """Raises when a question is answered negatively.""" + ## Arrange + answers[QuestionType.PRICE_PERCENTAGE_CONSTRAINT] = False + + ## Act & Assert + with pytest.raises(RuntimeError) as cm_err: + handle_questions(original_result, answers, reply_callback) + + assert ( + f'A question was not given a positive reply. Question: "{original_result.data[0]["message"][0]}". Answers: \n{answers}\n. Request: {original_result.request}' + == str(cm_err.value) + ) + + +def test_multiple_orders_returned(original_result, answers, reply_callback): + """Logs a message when multiple orders are returned while handling questions.""" + ## Arrange + original_result.data = [ + {'id': '12345', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT)]}, + {'id': '12346', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT)]}, + ] + reply_callback.return_value = original_result.copy(data=[{}]) + + expected = 'While handling questions multiple orders were returned: ' + pformat(original_result.data) + + ## Act & Assert + with CaptureLogsContext('ibind', level='INFO', expected_errors=[expected], attach_stack=False): + handle_questions(original_result, answers, reply_callback) + + +def test_multiple_messages_returned(original_result, answers, reply_callback): + """Logs a message when multiple messages are returned for a single order.""" + ## Arrange + original_result.data = [{'id': '12345', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT), 'Message 2']}] + reply_callback.return_value = original_result.copy(data=[{}]) + + expected = 'While handling questions multiple messages were returned: ' + pformat(original_result.data[0]['message']) + + ## Act & Assert + with CaptureLogsContext('ibind', level='INFO', expected_errors=[expected], attach_stack=False): + handle_questions(original_result, answers, reply_callback) + + +# -------------------------------------------------------------------------------------- +# Order request parsing +# -------------------------------------------------------------------------------------- + + +def test_parse_both_with_conidex(): + """Parses OrderRequest with conid=None and conidex set into API payload.""" + ## Arrange + order_request = OrderRequest( + conid=None, + side='BUY', + quantity=321, + order_type='MKT', + acct_id='DU1234567', + conidex='33333', + ) + + ## Act + d = parse_order_request(order_request) + + ## Assert + assert { + 'side': 'BUY', + 'quantity': 321, + 'orderType': 'MKT', + 'acctId': 'DU1234567', + 'conidex': '33333', + 'tif': 'GTC', + } == d + + +def test_raise_with_conid_and_conidex(): + """Raises when both conid and conidex are provided.""" + ## Arrange + + ## Act & Assert + with pytest.raises(ValueError) as cm_err: + order_request = OrderRequest( + conid=123, + side='BUY', + quantity=321, + order_type='MKT', + acct_id='DU1234567', + conidex='33333', + ) + + parse_order_request(order_request) + + assert "Both 'conidex' and 'conid' are provided. When using 'conidex', specify `conid=None`." == str(cm_err.value) \ No newline at end of file From 9be11047310f956252961e84cafb73f06391e064 Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 24 Dec 2025 11:00:49 +0100 Subject: [PATCH 18/31] test: migrated test_ibkr_ws_client_i.py and test_websocket_client_i.py to pytest --- .../base/test_websocket_client_i_new.py | 396 +++++++++++++ test/integration/base/websocketapp_mock.py | 4 +- .../client/test_ibkr_ws_client_i_new.py | 539 ++++++++++++++++++ test/test_utils_new.py | 4 +- 4 files changed, 939 insertions(+), 4 deletions(-) create mode 100644 test/integration/base/test_websocket_client_i_new.py create mode 100644 test/integration/client/test_ibkr_ws_client_i_new.py diff --git a/test/integration/base/test_websocket_client_i_new.py b/test/integration/base/test_websocket_client_i_new.py new file mode 100644 index 00000000..3a9b2fac --- /dev/null +++ b/test/integration/base/test_websocket_client_i_new.py @@ -0,0 +1,396 @@ +from threading import Thread +from typing import Optional +from unittest.mock import MagicMock + +import pytest + +from ibind.base.ws_client import WsClient +from ibind.support.py_utils import tname +from test.integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock +from test.test_utils_new import capture_logs + +_URL = 'wss://localhost:5000/v1/api/ws' +_MAX_RECONNECT_ATTEMPTS = 4 +_MAX_PING_INTERVAL = 38 +_ERROR_MESSAGE = 'TEST_ERROR' + + +# -------------------------------------------------------------------------------------- +# Log expectations +# -------------------------------------------------------------------------------------- + + +def _logs_start_success_beginning(): + return [ + 'WsClient: Starting', + 'WsClient: Trying to connect', + ] + + +def _logs_start_success_end(): + return [ + 'WsClient: Creating new WebSocketApp', + f'WsClient: Thread started ({tname()})', + 'WsClient: Connection open', + f'WsClient: Thread stopped ({tname()})', + ] + + +def _logs_failed_attempt(max_reconnect_attempts: int, attempt: Optional[int]): + logs = [ + 'WsClient: Creating new WebSocketApp', + 'WsClient: New WebSocketApp connection timeout', + 'WsClient: on_close', + 'WsClient: on_close event while disconnected', + ] + if attempt is not None: + logs.append(f'WsClient: Connect reattempt {attempt}/{max_reconnect_attempts}') + return logs + + +def _logs_shutdown_success(): + return [ + 'WsClient: Shutting down', + 'WsClient: on_close', + 'WsClient: Connection closed', + 'WsClient: Gracefully stopped', + ] + + +def _logs_exception_starting(error_message: str, thread_mock: MagicMock): + return [ + 'WsClient: Creating new WebSocketApp', + f'WsClient: Thread started ({tname()})', + f'WsClient: Unexpected error while running WebSocketApp: {error_message}', + 'WsClient: Hard reset, restart=False, self._wsa is None=False', + 'WsClient: Forced restart', + 'WsClient: Reconnecting', + f'WsClient: Thread already running: {thread_mock.name}-{thread_mock.ident}', + f'WsClient: Thread stopped ({tname()})', + 'WsClient: Reconnecting', + 'WsClient: Trying to connect', + ] + + +def _logs_check_health_error(max_ping_interval: int, time_ago: str): + return [ + f'WsClient: Last WebSocket ping happened {time_ago} seconds ago, exceeding the max ping interval of {max_ping_interval}. Restarting.', + 'WsClient: Hard reset, restart=True, self._wsa is None=False', + 'WsClient: Hard reset is closing the WebSocketApp', + ] + + +def _logs_hard_restart_error(wsa_mock: MagicMock): + return [ + 'WsClient: Hard reset close timeout', + f'WsClient: Abandoning current WebSocketApp that cannot be closed: {wsa_mock}', + 'WsClient: Forced restart', + 'WsClient: Reconnecting', + 'WsClient: Trying to connect', + ] + + +def _verify_started(ws_client: WsClient, wsa_mock: MagicMock): + wsa_mock.run_forever.assert_called_with( + sslopt=ws_client._sslopt, + ping_interval=ws_client._ping_interval, + ping_timeout=0.95 * ws_client._ping_interval, + ) + wsa_mock._on_open.assert_called_with(wsa_mock) + + +def _verify_failed_starting(wsa_mock: MagicMock): + wsa_mock.run_forever.assert_not_called() + wsa_mock._on_open.assert_not_called() + wsa_mock.close.assert_called() + + +# -------------------------------------------------------------------------------------- +# Test setup +# -------------------------------------------------------------------------------------- + + +@pytest.fixture +def ws_client(): + return WsClient( + subscription_processor=None, + url=_URL, + cacert=False, + timeout=0.01, + max_connection_attempts=_MAX_RECONNECT_ATTEMPTS, + max_ping_interval=_MAX_PING_INTERVAL, + ) + + +@pytest.fixture +def wsa_mock(): + return create_wsa_mock() + + +@pytest.fixture +def thread_mock(ws_client, wsa_mock): + thread_mock = MagicMock(spec=Thread) + thread_mock.start.side_effect = lambda: ws_client._run_websocket(wsa_mock) + return thread_mock + + +@pytest.fixture +def wsa_ctor_mock(mocker, wsa_mock): + return mocker.patch( + 'ibind.base.ws_client.WebSocketApp', + side_effect=lambda *args, **kwargs: init_wsa_mock(wsa_mock, *args, **kwargs), + ) + + +@pytest.fixture +def thread_ctor_mock(mocker, thread_mock): + return mocker.patch('ibind.base.ws_client.Thread', return_value=thread_mock) + + +@pytest.fixture +def patched_constructors(wsa_ctor_mock, thread_ctor_mock): + return None + + +# -------------------------------------------------------------------------------------- +# Start / reconnect behavior +# -------------------------------------------------------------------------------------- + +@capture_logs(logger_level='DEBUG') +def test_start_success(ws_client, wsa_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Starts successfully and logs the expected connection sequence.""" + ## Arrange + cm = kwargs['_cm_ibind'] + + ## Act + success = ws_client.start() + + ## Assert + assert success is True + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + _verify_started(ws_client, wsa_mock) + assert _logs_start_success_beginning() + _logs_start_success_end() == [r.msg for r in cm.records] + + +@capture_logs(logger_level='DEBUG', expected_errors=['WsClient: New WebSocketApp connection timeout']) +def test_start_success_on_second_attempt(ws_client, wsa_mock, thread_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Reconnects and succeeds on the second attempt after a timeout on the first.""" + ## Arrange + cm = kwargs['_cm_ibind'] + counter = [0] + + def delayed_start(): + if counter[0] >= 1: + ws_client._run_websocket(wsa_mock) + counter[0] += 1 + + thread_mock.start.side_effect = delayed_start + + ## Act + success = ws_client.start() + + ## Assert + assert success is True + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + _verify_started(ws_client, wsa_mock) + assert ( + _logs_start_success_beginning() + + _logs_failed_attempt(_MAX_RECONNECT_ATTEMPTS, 2) + + _logs_start_success_end() + == [r.msg for r in cm.records] + ) + thread_mock.join.assert_called_with(60) + + +@capture_logs( + logger_level='DEBUG', + expected_errors=[ + 'WsClient: New WebSocketApp connection timeout', + f'WsClient: Connection failed after {_MAX_RECONNECT_ATTEMPTS} attempts', + ], +) +def test_start_reattempt_failure(ws_client, wsa_mock, thread_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Fails after exhausting reconnect attempts and closes the WebSocketApp.""" + ## Arrange + cm = kwargs['_cm_ibind'] + thread_mock.start.side_effect = lambda: None + + ## Act + success = ws_client.start() + + ## Assert + assert success is False + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + + _verify_failed_starting(wsa_mock) + + expected_logs = _logs_start_success_beginning() + for i in range(_MAX_RECONNECT_ATTEMPTS): + if i < _MAX_RECONNECT_ATTEMPTS - 1: + expected_logs += _logs_failed_attempt(_MAX_RECONNECT_ATTEMPTS, i + 2) + else: + expected_logs += _logs_failed_attempt(_MAX_RECONNECT_ATTEMPTS, None) + expected_logs.append(f"WsClient: Connection failed after {_MAX_RECONNECT_ATTEMPTS} attempts") + + assert expected_logs == [r.msg for r in cm.records] + assert wsa_mock.keep_running is False + + +# -------------------------------------------------------------------------------------- +# Error handling +# -------------------------------------------------------------------------------------- + + +@capture_logs( + logger_level='DEBUG', + expected_errors=[ + f"WsClient: Unexpected error while running WebSocketApp: {_ERROR_MESSAGE}", + 'WsClient: Thread already running:', + ], + partial_match=True, +) +def test_open_exception(ws_client, wsa_mock, thread_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Hard-resets and reconnects when WebSocketApp.run_forever raises an exception.""" + ## Arrange + cm = kwargs['_cm_ibind'] + old_run_forever = wsa_mock.run_forever.side_effect + + def run_forever_exception( + wsa_mock: MagicMock, + sslopt: dict = None, + ping_interval: float = 0, + ping_timeout: Optional[float] = None, + ): + wsa_mock.run_forever.side_effect = old_run_forever + raise RuntimeError(_ERROR_MESSAGE) + + wsa_mock.run_forever.side_effect = lambda *args, **kwargs: run_forever_exception(wsa_mock, *args, **kwargs) + + ## Act + ws_client.start() + ws_client.shutdown() + + ## Assert + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + assert ( + _logs_start_success_beginning() + + _logs_exception_starting(_ERROR_MESSAGE, thread_mock) + + _logs_start_success_end() + + _logs_shutdown_success() + == [r.msg for r in cm.records] + ) + + +# -------------------------------------------------------------------------------------- +# Shutdown +# -------------------------------------------------------------------------------------- + + +@capture_logs(logger_level='DEBUG') +def test_open_and_close(ws_client, wsa_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Shuts down cleanly after a successful start.""" + ## Arrange + cm = kwargs['_cm_ibind'] + + ## Act + success = ws_client.start() + ws_client.shutdown() + + ## Assert + assert success is True + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + assert _logs_start_success_beginning() + _logs_start_success_end() + _logs_shutdown_success() == [r.msg for r in cm.records] + + +# -------------------------------------------------------------------------------------- +# Sending payloads +# -------------------------------------------------------------------------------------- + + +@capture_logs(logger_level='DEBUG') +def test_send(ws_client, wsa_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Delivers outbound payloads to the on_message callback (mocked echo).""" + ## Arrange + cm = kwargs['_cm_ibind'] + + ws_client._on_message = MagicMock() + + ## Act + success = ws_client.start() + ws_client.send('test') + ws_client.shutdown() + + ## Assert + assert success is True + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + ws_client._on_message.assert_called_once_with(wsa_mock, 'test') + assert _logs_start_success_beginning() + _logs_start_success_end() + _logs_shutdown_success() == [r.msg for r in cm.records] + + +@capture_logs(logger_level='DEBUG', expected_errors=['WsClient: Must be started before sending payloads']) +def test_send_without_start(ws_client, **kwargs): + """Logs an error when trying to send before calling start().""" + ## Arrange + cm = kwargs['_cm_ibind'] + + ws_client._on_message = MagicMock() + + ## Act + ws_client.send('test') + ws_client.shutdown() + + ## Assert + assert ['WsClient: Must be started before sending payloads'] == [r.msg for r in cm.records] + + +# -------------------------------------------------------------------------------------- +# Health checks +# -------------------------------------------------------------------------------------- + + +@capture_logs( + logger_level='DEBUG', + expected_errors=[ + 'WsClient: Last WebSocket ping happened', + 'WsClient: Hard reset close timeout', + 'WsClient: Abandoning current WebSocketApp that cannot be closed:', + ], + partial_match=True, +) +def test_check_ping(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Triggers a hard reset when the last ping exceeds max_ping_interval.""" + ## Arrange + cm = kwargs['_cm_ibind'] + start_time = [100] + + def fake_time(): + start_time[0] += 100 + return start_time[0] + + ws_client._on_message = MagicMock() + + ## Act + ws_client.start() + ws_client.check_ping() + + # Simulate that closing the WebSocketApp doesn't work since we have connectivity issues + wsa_mock._on_close.side_effect = lambda x, y, z: None + + time_mock = mocker.patch('ibind.base.ws_client.time') + time_mock.time.side_effect = fake_time + + wsa_mock.last_ping_tm = _MAX_PING_INTERVAL + ws_client.check_ping() + assert ws_client.ready() is True + ws_client.shutdown() + + ## Assert + assert ( + _logs_start_success_beginning() + + _logs_start_success_end() + + _logs_check_health_error(_MAX_PING_INTERVAL, '162.00') + + _logs_hard_restart_error(wsa_mock) + + _logs_start_success_end() + + _logs_shutdown_success() + == [r.msg for r in cm.records] + ) \ No newline at end of file diff --git a/test/integration/base/websocketapp_mock.py b/test/integration/base/websocketapp_mock.py index 670b1205..5961f7a4 100644 --- a/test/integration/base/websocketapp_mock.py +++ b/test/integration/base/websocketapp_mock.py @@ -50,7 +50,7 @@ def create_wsa_mock(): wsa_mock = MagicMock() wsa_mock.send.side_effect = lambda *args, **kwargs: send(wsa_mock, *args, **kwargs) - wsa_mock.close.side_effect = lambda status=None: close(wsa_mock, status) + wsa_mock.close.side_effect = lambda *args, **kwargs: close(wsa_mock, *args, **kwargs) wsa_mock.run_forever.side_effect = lambda *args, **kwargs: run_forever(wsa_mock, *args, **kwargs) - return wsa_mock + return wsa_mock \ No newline at end of file diff --git a/test/integration/client/test_ibkr_ws_client_i_new.py b/test/integration/client/test_ibkr_ws_client_i_new.py new file mode 100644 index 00000000..84df04fa --- /dev/null +++ b/test/integration/client/test_ibkr_ws_client_i_new.py @@ -0,0 +1,539 @@ +import json +from threading import Thread +from typing import Optional +from unittest.mock import MagicMock, call + +import pytest +import requests + +from ibind import Result +from ibind.client.ibkr_client import IbkrClient +from ibind.client.ibkr_ws_client import IbkrWsClient, IbkrSubscriptionProcessor, IbkrWsKey +from test.integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock +from test.test_utils_new import capture_logs + +_URL_WS = 'wss://localhost:5000/v1/api/ws' +_URL_REST = 'https://localhost:5000' +_ACCOUNT_ID = 'TEST_ACCOUNT_ID' +_TIMEOUT_REST = 8 +_MAX_RETRIES_REST = 4 +_MAX_RECONNECT_ATTEMPTS = 4 +_MAX_PING_INTERVAL = 38 +_SUBSCRIPTION_RETRIES = 3 +_CONID = 265598 +_UPDATE_TIME = 5678765456 + + +# -------------------------------------------------------------------------------------- +# Test setup +# -------------------------------------------------------------------------------------- + + +@pytest.fixture +def preprocess_ws_client(): + return IbkrWsClient( + url=_URL_WS, + ibkr_client=None, + account_id=None, + subscription_processor_class=lambda: None, + ) + + +@pytest.fixture +def client_mock(): + client = MagicMock( + spec=IbkrClient( + url=_URL_REST, + account_id=_ACCOUNT_ID, + timeout=_TIMEOUT_REST, + max_retries=_MAX_RETRIES_REST, + ) + ) + client.tickle.return_value.data = {'session': 'TEST_COOKIE'} + return client + + +@pytest.fixture +def ws_client(client_mock): + return IbkrWsClient( + url=_URL_WS, + ibkr_client=client_mock, + account_id=_ACCOUNT_ID, + subscription_processor_class=IbkrSubscriptionProcessor, + subscription_retries=_SUBSCRIPTION_RETRIES, + subscription_timeout=0.01, + cacert=False, + timeout=0.01, + max_connection_attempts=_MAX_RECONNECT_ATTEMPTS, + max_ping_interval=_MAX_PING_INTERVAL, + ) + + + +@pytest.fixture +def wsa_mock(): + return create_wsa_mock() + + +@pytest.fixture +def thread_mock(ws_client, wsa_mock): + thread_mock = MagicMock(spec=Thread) + thread_mock.start.side_effect = lambda: ws_client._run_websocket(wsa_mock) + return thread_mock + + +@pytest.fixture +def ws_app_factory(wsa_mock): + # Use a mutable side-effect so individual tests can temporarily override WebSocketApp behavior. + return { + 'fn': lambda *args, **kwargs: init_wsa_mock(wsa_mock, *args, **kwargs), + } + + +@pytest.fixture +def patched_constructors(mocker, thread_mock, ws_app_factory): + mocker.patch('ibind.base.ws_client.WebSocketApp', side_effect=lambda *args, **kwargs: ws_app_factory['fn'](*args, **kwargs)) + mocker.patch('ibind.base.ws_client.Thread', return_value=thread_mock) + return None + + + +def _send_payload(ws_client, payload: dict): + success = ws_client.start() + ws_client.send(json.dumps(payload)) + ws_client.shutdown() + return success + + +def _subscribe(ws_client, wsa_mock, request: dict, response: Optional[dict]): + def override_on_message(wsa_mock: MagicMock, message: str): + if response is None: + return + raw_message = json.dumps(response) + wsa_mock.__on_message__(wsa_mock, raw_message) + + ws_client.start() + wsa_mock._on_message.side_effect = override_on_message + + rv = ws_client.subscribe( + **{ + 'channel': request.get('channel'), + 'data': request.get('data'), + 'needs_confirmation': request.get('needs_confirmation'), + } + ) + ws_client.unsubscribe( + **{ + 'channel': request.get('channel'), + 'data': request.get('data'), + 'needs_confirmation': request.get('confirms_unsubscription'), + } + ) + ws_client.shutdown() + return rv + + + +def _logs_subscriptions(full_channel, data=None, needs_confirmation_sub: bool = False, needs_confirmation_unsub: bool = True): + return [ + f'IbkrWsClient: Subscribed: s{full_channel}{"" if data is None else f"+{json.dumps(data)}"}{"" if not needs_confirmation_sub else " without confirmation."}', + f'IbkrWsClient: Unsubscribed: u{full_channel}+{json.dumps(data if data is not None else {})}{"" if not needs_confirmation_unsub else " without confirmation."}', + ] + + +# -------------------------------------------------------------------------------------- +# Message preprocessing +# -------------------------------------------------------------------------------------- + + +def test_preprocess_with_well_formed_message(preprocess_ws_client): + """Preprocesses a well-formed raw message into (message, topic, data, subscribed, channel).""" + ## Arrange + raw_message = json.dumps({'topic': 'actABC', 'args': {'key': 'value'}}) + expected_result = ( + {'topic': 'actABC', 'args': {'key': 'value'}}, # message + 'actABC', # topic + {'key': 'value'}, # data + 'a', # subscribed + 'ctABC', # channel + ) + + ## Act + rv = preprocess_ws_client._preprocess_raw_message(raw_message) + + ## Assert + assert rv == expected_result + + +def test_preprocess_with_unsubscribed_message(preprocess_ws_client): + """Returns empty preprocess result for unsubscribed messages.""" + ## Arrange + raw_message = json.dumps({'message': 'Unsubscribed'}) + + ## Act + rv = preprocess_ws_client._preprocess_raw_message(raw_message) + + ## Assert + assert rv == ({'message': 'Unsubscribed'}, None, None, None, None) + + +# -------------------------------------------------------------------------------------- +# On-message handling +# -------------------------------------------------------------------------------------- + + +@capture_logs(logger_level='DEBUG') +def test_on_message_system_heartbeat(ws_client, patched_constructors): + """Updates last heartbeat on system heartbeat message.""" + ## Arrange + hb = 12345678 + + ## Act + _send_payload(ws_client, {'topic': 'system', 'hb': hb}) + + ## Assert + assert ws_client._last_heartbeat == hb + +@capture_logs(logger_level='DEBUG', expected_errors = ["IbkrWsClient: Account ID mismatch: expected=TEST_ACCOUNT_ID, received=['OTHER_ACCOUNT_ID']"]) +def test_on_message_act_account_mismatch(ws_client, patched_constructors): + """Logs a warning when account list in act message mismatches expected account.""" + ## Act + _send_payload(ws_client, {'topic': 'act', 'args': {'accounts': ['OTHER_ACCOUNT_ID']}}) + + +@capture_logs(logger_level='DEBUG') +def test_on_message_blt(ws_client, patched_constructors, mocker): + """Dispatches bulletin messages to _handle_bulletin.""" + ## Arrange + bulletin_message = {'topic': 'blt', 'args': {'bulletin_key': 'some_info'}} + mock_handle_bulletin = mocker.patch.object(ws_client, '_handle_bulletin', MagicMock()) + + ## Act + _send_payload(ws_client, bulletin_message) + + ## Assert + mock_handle_bulletin.assert_called_once_with(bulletin_message) + +@capture_logs(logger_level='DEBUG', expected_errors=[ + "IbkrWsClient: Status unauthenticated: {'authenticated': False}", + 'IbkrWsClient: Not authenticated, closing WebSocketApp', +]) +def test_on_message_sts_unauthenticated(ws_client, client_mock, patched_constructors, mocker): + """On unauthenticated status, refetches session and closes websocket.""" + ## Arrange + message_data = {'topic': 'sts', 'args': {'authenticated': False}} + session_id = 6545676 + + response_mock = MagicMock(spec=requests.Response) + response_mock.status_code = 200 + response_mock.json.return_value = {'session': session_id, 'data_to_be_ignored': '1234'} + + client_mock.tickle.return_value = Result(data=response_mock.json.return_value) + + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = response_mock + + ## Act + _send_payload(ws_client, message_data) + + ## Assert + assert ws_client._authenticated is False + +@capture_logs(logger_level='DEBUG') +def test_on_message_sts_authenticated(ws_client, patched_constructors): + """Accepts authenticated status without logging warnings.""" + ## Act + _send_payload(ws_client, {'topic': 'sts', 'args': {'authenticated': True}}) + + +@capture_logs(logger_level='DEBUG', expected_errors = [f'IbkrWsClient: Error message:'], partial_match=True) +def test_on_message_error(ws_client, patched_constructors): + """Logs error-topic messages as warnings.""" + ## Act + _send_payload(ws_client, {'topic': 'error', 'args': {'error_key': 'error_details'}}) + + + +@capture_logs(logger_level='DEBUG', expected_errors=['unrecognised. Message:'], partial_match=True) +def test_on_message_no_topic_handler(ws_client, patched_constructors): + """Logs a warning when no handler exists for a topic.""" + ## Arrange + message_data = {'topic': 'unrecognized_topic', 'args': {'some_key': 'some_value'}} + + ## Act + _send_payload(ws_client, message_data) + + +@capture_logs(logger_level='DEBUG', expected_errors = [ + 'message that is missing a subscription. Message:' +], partial_match=True) +def test_on_message_handled_without_subscription(ws_client, patched_constructors, mocker): + """Logs a warning if a subscribed message arrives without a known subscription.""" + ## Arrange + mocker.patch.object(ws_client, '_handle_subscribed_message', return_value=True) + + ## Act + _send_payload(ws_client, {'topic': 'some_topic', 'args': {'channel': 'XYZ', 'data': 'info'}}) + + + +# -------------------------------------------------------------------------------------- +# Subscription + channel-specific handling +# -------------------------------------------------------------------------------------- + + +@capture_logs(logger_level='DEBUG') +def test_on_message_market_data_channel_handling(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Routes market data updates into the MARKET_DATA queue.""" + ## Arrange + cm = kwargs['_cm_ibind'] + queue = ws_client.new_queue_accessor(IbkrWsKey.MARKET_DATA) + full_channel = f'{queue.key.channel}+{_CONID}' + request = { + 'channel': f'{full_channel}', + 'data': {'fields': ['55', '71', '84', '86', '88', '85', '87', '7295', '7296', '70']}, + } + response = { + 'topic': f's{full_channel}', + 'conid': _CONID, + '_updated': _UPDATE_TIME, + 55: 'AAPL', + 70: '195.34', + 71: '193.67', + 87: '24.2M', + 7295: '194.10', + 84: '195.25', + 86: '195.26', + 88: '3,500', + 85: '500', + 6508: '&serviceID1=122&serviceID2=123&serviceID3=203&serviceID4=775&serviceID5=204&serviceID6=206&serviceID7=108&serviceID8=109', + } + + assert queue.empty() is True + + mocker.patch.object(ws_client, 'has_subscription', return_value=True) + + ## Act + success = _subscribe(ws_client, wsa_mock, request, response) + + ## Assert + assert success is True + cm.partial_log(_logs_subscriptions(full_channel, request['data'])) + assert ( + { + _CONID: { + '_updated': _UPDATE_TIME, + 'conid': _CONID, + 'topic': f'smd+{_CONID}', + 'ask_price': '195.26', + 'ask_size': '500', + 'bid_price': '195.25', + 'bid_size': '3,500', + 'high': '195.34', + 'low': '193.67', + 'open': '194.10', + 'service_params': '&serviceID1=122&serviceID2=123&serviceID3=203&serviceID4=775&serviceID5=204&serviceID6=206&serviceID7=108&serviceID8=109', + 'symbol': 'AAPL', + 'volume': '24.2M', + } + } + == queue.get() + ) + + +@capture_logs(logger_level='DEBUG') +def test_on_message_market_history_channel_handling(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Routes market history updates into the MARKET_HISTORY queue and tracks server IDs.""" + ## Arrange + cm = kwargs['_cm_ibind'] + queue = ws_client.new_queue_accessor(IbkrWsKey.MARKET_HISTORY) + server_id = 87567 + full_channel = f'{queue.key.channel}+{_CONID}' + request = { + 'channel': f'{full_channel}', + 'data': {'period': '1min', 'bar': '1min', 'outsideRTH': True, 'source': 'trades', 'format': '%o/%c/%h/%l'}, + 'confirms_unsubscription': False, + } + response = { + 'topic': f's{full_channel}', + 'serverId': server_id, + '_updated': _UPDATE_TIME, + 'conid': _CONID, + 'foo': 'bar', + } + + assert queue.empty() is True + + mocker.patch.object(ws_client, 'has_subscription', return_value=True) + + ## Act + success = _subscribe(ws_client, wsa_mock, request, response) + + ## Assert + assert success is True + cm.partial_log(_logs_subscriptions(full_channel, request['data'])) + assert response == queue.get() + assert server_id in ws_client.server_ids(IbkrWsKey.MARKET_HISTORY) + + +@capture_logs(logger_level='DEBUG') +def test_on_message_trade_channel_handling(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Routes trade updates into the TRADES queue.""" + ## Arrange + cm = kwargs['_cm_ibind'] + queue = ws_client.new_queue_accessor(IbkrWsKey.TRADES) + full_channel = f'{queue.key.channel}+{_CONID}' + request = {'channel': f'{full_channel}'} + response = { + 'topic': f's{full_channel}', + '_updated': _UPDATE_TIME, + 'conid': _CONID, + 'args': [{'foo': 'bar'}], + } + + assert queue.empty() is True + + mocker.patch.object(ws_client, 'has_subscription', return_value=True) + + ## Act + success = _subscribe(ws_client, wsa_mock, request, response) + + ## Assert + assert success is True + cm.partial_log(_logs_subscriptions(full_channel)) + assert response == queue.get() + + +@capture_logs(logger_level='DEBUG') +def test_on_message_orders_channel_handling(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Routes order updates into the ORDERS queue.""" + ## Arrange + cm = kwargs['_cm_ibind'] + queue = ws_client.new_queue_accessor(IbkrWsKey.ORDERS) + full_channel = f'{queue.key.channel}+{_CONID}' + request = {'channel': f'{full_channel}'} + response = { + 'topic': f's{full_channel}', + '_updated': _UPDATE_TIME, + 'conid': _CONID, + 'args': [{'foo': 'bar'}], + } + + assert queue.empty() is True + + mocker.patch.object(ws_client, 'has_subscription', return_value=True) + + ## Act + success = _subscribe(ws_client, wsa_mock, request, response) + + ## Assert + assert success is True + cm.partial_log(_logs_subscriptions(full_channel, None, True, True)) + assert response == queue.get() + + +@capture_logs(logger_level='DEBUG') +def test_subscription_without_confirmation(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Subscribes/unsubscribes without confirmation when requested.""" + ## Arrange + cm = kwargs['_cm_ibind'] + channel = 'fake' + full_channel = f'{channel}+{_CONID}' + request = { + 'channel': f'{full_channel}', + 'needs_confirmation': False, + 'confirms_unsubscription': False, + } + response = None + + mocker.patch.object(ws_client, 'has_subscription', return_value=True) + + ## Act + success = _subscribe(ws_client, wsa_mock, request, response) + + ## Assert + assert success is True + cm.partial_log([ + f'IbkrWsClient: Subscribed: s{full_channel} without confirmation.', + f'IbkrWsClient: Unsubscribed: u{full_channel}+{{}} without confirmation.', + ]) + + + +# -------------------------------------------------------------------------------------- +# Health checks +# -------------------------------------------------------------------------------------- + + +@capture_logs(logger_level='DEBUG', expected_errors=[ + f'IbkrWsClient: Last IBKR heartbeat happened 162.00 seconds ago, exceeding the max ping interval of {_MAX_PING_INTERVAL}. Restarting.', +]) +def test_check_health(ws_client, wsa_mock, ws_app_factory, patched_constructors, mocker, **kwargs): + """Restarts and recreates subscriptions when heartbeat exceeds max ping interval.""" + ## Arrange + cm = kwargs['_cm_ibind'] + start_time = [100] + has_active_connection_counter = [0] + + def fake_time(): + start_time[0] += 100 + return start_time[0] + + def has_active_connection(): + has_active_connection_counter[0] += 1 + if has_active_connection_counter[0] <= 2: + return False + return True + + queue = ws_client.new_queue_accessor(IbkrWsKey.TRADES) + full_channel = f'{queue.key.channel}+{_CONID}' + request = {'channel': f'{full_channel}', 'data': {'foo': 'bar'}} + response = { + 'topic': f's{full_channel}', + '_updated': _UPDATE_TIME, + 'conid': _CONID, + 'args': [{'foo': 'bar'}], + } + + ## Act + def override_init_wsa_mock(wsa_mock: MagicMock, *args, **kwargs): + wsa_mock = init_wsa_mock(wsa_mock, *args, **kwargs) + wsa_mock._on_message.side_effect = lambda wsa_mock, message: wsa_mock.__on_message__(wsa_mock, json.dumps(response)) + return wsa_mock + + ws_client.start() + ws_client.check_health() + wsa_mock._on_message.side_effect = lambda wsa_mock, message: wsa_mock.__on_message__(wsa_mock, json.dumps(response)) + + ws_client.subscribe(**request) + + # Override time, ignore ping check, and control active-connection health checks. + time_mock = mocker.patch('ibind.client.ibkr_ws_client.time') + time_mock.time.side_effect = fake_time + + mocker.patch.object(ws_client, 'check_ping', return_value=True) + mocker.patch.object(ws_client, '_has_active_connection', side_effect=has_active_connection) + + # Ensure each reconnect creates a WebSocketApp whose on_message pushes our fake response. + ws_app_factory['fn'] = lambda *args, **kwargs: override_init_wsa_mock(wsa_mock, *args, **kwargs) + + ws_client._last_heartbeat = _MAX_PING_INTERVAL * 1000 + ws_client.check_health() + + assert ws_client.ready() is True + assert [call()] * 6 == ws_client._has_active_connection.call_args_list + + ws_client.shutdown() + + + ## Assert + channel_subscribed_log = f'IbkrWsClient: Subscribed: s{full_channel}+{json.dumps(request["data"])}' + cm.partial_log( + [channel_subscribed_log] + + [ + f'IbkrWsClient: Invalidated subscription: {full_channel}', + f"IbkrWsClient: Recreating 1/1 subscriptions: {{'{full_channel}': {{'status': False, 'data': {request['data']}, 'needs_confirmation': True, 'subscription_processor': None}}}}", + channel_subscribed_log, + f'IbkrWsClient: Invalidated subscription: {full_channel}', + ] + ) \ No newline at end of file diff --git a/test/test_utils_new.py b/test/test_utils_new.py index 6ec520c3..c14e5487 100644 --- a/test/test_utils_new.py +++ b/test/test_utils_new.py @@ -101,7 +101,7 @@ class CaptureLogsContext: def __init__( self, - logger='slog', + logger='ibind', level='DEBUG', logger_level: str = None, error_level='WARNING', @@ -305,4 +305,4 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.target_module_obj.time = self.original_time_module def mock_module_time(target_module, time_sequence=None, start_time=0.0): - return MockTimeController(target_module, time_sequence=time_sequence, start_time=start_time) + return MockTimeController(target_module, time_sequence=time_sequence, start_time=start_time) \ No newline at end of file From b8b99b44e61b87fbdeaa0b86c06d79885515e1eb Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 24 Dec 2025 11:30:29 +0100 Subject: [PATCH 19/31] chore: temporarily added coverage old and new reports --- cov_new.txt | 32 ++++++++++++++++++++++++++++++++ cov_old.txt | 32 ++++++++++++++++++++++++++++++++ test/migration_plan.md | 8 ++++---- 3 files changed, 68 insertions(+), 4 deletions(-) create mode 100644 cov_new.txt create mode 100644 cov_old.txt diff --git a/cov_new.txt b/cov_new.txt new file mode 100644 index 00000000..b69eaaed --- /dev/null +++ b/cov_new.txt @@ -0,0 +1,32 @@ +---------- coverage: platform win32, python 3.13.11-final-0 ---------- +Name Stmts Miss Cover Missing +----------------------------------------------------------------------------------- +ibind\__init__.py 13 0 100% +ibind\base\__init__.py 0 0 100% +ibind\base\queue_controller.py 18 0 100% +ibind\base\rest_client.py 152 35 77% 103, 110, 235, 240, 252, 265-275, 280-284, 296-297, 306-308, 311, 334-337, 340-346, 355 +ibind\base\subscription_controller.py 125 33 74% 65, 67-69, 74-75, 86, 89, 92, 100-101, 154, 173-191, 211, 283, 286, 289-292, 326, 334, 358 +ibind\base\ws_client.py 217 38 82% 64, 90, 95, 118-120, 124-129, 151, 155-156, 168-169, 190, 196-198, 202-204, 238, 247, 254-256, 317-322, 363, 439-440, 457, 470 +ibind\client\__init__.py 0 0 100% +ibind\client\ibkr_client.py 119 69 42% 87-91, 116-119, 132-135, 143-149, 163-164, 195-218, 234-237, 250-251, 254-256, 265-267, 270-272, 286, 289-333 +ibind\client\ibkr_client_mixins\__init__.py 0 0 100% +ibind\client\ibkr_client_mixins\accounts_mixin.py 4 0 100% +ibind\client\ibkr_client_mixins\contract_mixin.py 25 0 100% +ibind\client\ibkr_client_mixins\marketdata_mixin.py 61 28 54% 54-85, 226, 236-241 +ibind\client\ibkr_client_mixins\order_mixin.py 22 12 45% 102-110, 175-183 +ibind\client\ibkr_client_mixins\portfolio_mixin.py 5 0 100% +ibind\client\ibkr_client_mixins\scanner_mixin.py 5 0 100% +ibind\client\ibkr_client_mixins\session_mixin.py 39 13 67% 105, 130-145 +ibind\client\ibkr_client_mixins\watchlist_mixin.py 4 0 100% +ibind\client\ibkr_definitions.py 6 0 100% +ibind\client\ibkr_utils.py 226 42 81% 222, 310-313, 316, 440-441, 446, 603-606, 617-620, 639-642, 645-657, 666-672, 684-689 +ibind\client\ibkr_ws_client.py 238 65 73% 274, 277-281, 321, 326-328, 331, 351, 384, 390-394, 403-416, 422-423, 431-442, 447-460, 480, 484, 488, 491, 504, 517, 535, 538, 551, 702-714 +ibind\oauth\__init__.py 26 26 0% 1-58 +ibind\oauth\oauth1a.py 164 164 0% 1-466 +ibind\support\__init__.py 0 0 100% +ibind\support\errors.py 4 0 100% +ibind\support\logs.py 82 13 84% 23-29, 94, 96, 136, 143, 164, 172-173 +ibind\support\py_utils.py 87 25 71% 143, 153-156, 169, 308, 324-332, 336-354 +ibind\var.py 88 4 95% 24-25, 36-37 +----------------------------------------------------------------------------------- +TOTAL 1730 567 67% \ No newline at end of file diff --git a/cov_old.txt b/cov_old.txt new file mode 100644 index 00000000..40630161 --- /dev/null +++ b/cov_old.txt @@ -0,0 +1,32 @@ +---------- coverage: platform win32, python 3.13.11-final-0 ---------- +Name Stmts Miss Cover Missing +----------------------------------------------------------------------------------- +ibind\__init__.py 13 0 100% +ibind\base\__init__.py 0 0 100% +ibind\base\queue_controller.py 18 0 100% +ibind\base\rest_client.py 152 35 77% 103, 110, 235, 240, 252, 265-275, 280-284, 296-297, 306-308, 311, 334-337, 340-346, 355 +ibind\base\subscription_controller.py 125 33 74% 65, 67-69, 74-75, 86, 89, 92, 100-101, 154, 173-191, 211, 283, 286, 289-292, 326, 334, 358 +ibind\base\ws_client.py 217 38 82% 64, 90, 95, 118-120, 124-129, 151, 155-156, 168-169, 190, 196-198, 202-204, 238, 247, 254-256, 317-322, 363, 439-440, 457, 470 +ibind\client\__init__.py 0 0 100% +ibind\client\ibkr_client.py 119 69 42% 87-91, 116-119, 132-135, 143-149, 163-164, 195-218, 234-237, 250-251, 254-256, 265-267, 270-272, 286, 289-333 +ibind\client\ibkr_client_mixins\__init__.py 0 0 100% +ibind\client\ibkr_client_mixins\accounts_mixin.py 4 0 100% +ibind\client\ibkr_client_mixins\contract_mixin.py 25 0 100% +ibind\client\ibkr_client_mixins\marketdata_mixin.py 61 29 52% 54-85, 226, 236-241, 303 +ibind\client\ibkr_client_mixins\order_mixin.py 22 12 45% 102-110, 175-183 +ibind\client\ibkr_client_mixins\portfolio_mixin.py 5 0 100% +ibind\client\ibkr_client_mixins\scanner_mixin.py 5 0 100% +ibind\client\ibkr_client_mixins\session_mixin.py 39 13 67% 105, 130-145 +ibind\client\ibkr_client_mixins\watchlist_mixin.py 4 0 100% +ibind\client\ibkr_definitions.py 6 0 100% +ibind\client\ibkr_utils.py 226 42 81% 222, 310-313, 316, 440-441, 446, 603-606, 617-620, 639-642, 645-657, 666-672, 684-689 +ibind\client\ibkr_ws_client.py 238 65 73% 274, 277-281, 321, 326-328, 331, 351, 384, 390-394, 403-416, 422-423, 431-442, 447-460, 480, 484, 488, 491, 504, 517, 535, 538, 551, 702-714 +ibind\oauth\__init__.py 26 26 0% 1-58 +ibind\oauth\oauth1a.py 164 164 0% 1-466 +ibind\support\__init__.py 0 0 100% +ibind\support\errors.py 4 0 100% +ibind\support\logs.py 82 57 30% 20-29, 75-96, 121-141, 150-153, 156-157, 160, 163-168, 171-175 +ibind\support\py_utils.py 87 25 71% 143, 153-156, 169, 308, 324-332, 336-354 +ibind\var.py 88 4 95% 24-25, 36-37 +----------------------------------------------------------------------------------- +TOTAL 1730 612 65% \ No newline at end of file diff --git a/test/migration_plan.md b/test/migration_plan.md index 0f61570e..f5e0a7f1 100644 --- a/test/migration_plan.md +++ b/test/migration_plan.md @@ -90,7 +90,7 @@ The following files need to be migrated. Each file can be worked on independentl --- -### 2. [] `test/integration/base/test_websocket_client_i.py` +### 2. [✔] `test/integration/base/test_websocket_client_i.py` - **Migration Steps:** 1. Create a new file: `test/integration/base/test_websocket_client_i_new.py`. @@ -113,7 +113,7 @@ The following files need to be migrated. Each file can be worked on independentl --- -### 4. [ ] `test/integration/client/test_ibkr_utils_i.py` +### 4. [✔] `test/integration/client/test_ibkr_utils_i.py` - **Migration Steps:** 1. Create a new file: `test/integration/client/test_ibkr_utils_i_new.py`. @@ -124,7 +124,7 @@ The following files need to be migrated. Each file can be worked on independentl --- -### 5. [ ] `test/integration/client/test_ibkr_ws_client_i.py` +### 5. [✔] `test/integration/client/test_ibkr_ws_client_i.py` - **Migration Steps:** 1. Create a new file: `test/integration/client/test_ibkr_ws_client_i_new.py`. @@ -135,7 +135,7 @@ The following files need to be migrated. Each file can be worked on independentl --- -### 6. [ ] `test/unit/support/test_py_utils_u.py` +### 6. [✔] `test/unit/support/test_py_utils_u.py` - **Migration Steps:** 1. Create a new file: `test/unit/support/test_py_utils_u_new.py`. From 9b618b2fd0dbbbebfefebbdad70a0cc6023b8024 Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 24 Dec 2025 11:36:52 +0100 Subject: [PATCH 20/31] test: updated pytest.ini --- pytest.ini | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pytest.ini b/pytest.ini index 4a56d57d..eed1abd2 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,7 +1,11 @@ -[tool:pytest] -testpaths = test -pythonpath = . test +[pytest] +pythonpath = . ./test ./ibind +testpaths = + test + test/integration + test/unit addopts = -v --tb=short python_files = test_*.py python_classes = Test* -python_functions = test_* \ No newline at end of file +python_functions = test_* +norecursedirs = .* __pycache__ data .pytest_cache \ No newline at end of file From 8eb4c496c150d1f110a4422bd51b9722357dda30 Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 24 Dec 2025 11:37:04 +0100 Subject: [PATCH 21/31] test: small import fixes to new test files --- test/integration/base/test_websocket_client_i.py | 4 ++-- test/integration/base/test_websocket_client_i_new.py | 4 ++-- test/integration/client/test_ibkr_client_i_new.py | 4 ++-- test/integration/client/test_ibkr_ws_client_i_new.py | 4 ++-- test/test_utils_new.py | 7 +------ 5 files changed, 9 insertions(+), 14 deletions(-) diff --git a/test/integration/base/test_websocket_client_i.py b/test/integration/base/test_websocket_client_i.py index 58529583..dc632cd3 100644 --- a/test/integration/base/test_websocket_client_i.py +++ b/test/integration/base/test_websocket_client_i.py @@ -5,7 +5,7 @@ from ibind.base.ws_client import WsClient from ibind.support.py_utils import tname -from test.integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock +from integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock from test_utils import RaiseLogsContext, exact_log @@ -270,4 +270,4 @@ def run(): self._logs_hard_restart_error() + self._logs_start_success_end() + self._logs_shutdown_success(), - ) + ) \ No newline at end of file diff --git a/test/integration/base/test_websocket_client_i_new.py b/test/integration/base/test_websocket_client_i_new.py index 3a9b2fac..aa3d1221 100644 --- a/test/integration/base/test_websocket_client_i_new.py +++ b/test/integration/base/test_websocket_client_i_new.py @@ -6,8 +6,8 @@ from ibind.base.ws_client import WsClient from ibind.support.py_utils import tname -from test.integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock -from test.test_utils_new import capture_logs +from integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock +from test_utils_new import capture_logs _URL = 'wss://localhost:5000/v1/api/ws' _MAX_RECONNECT_ATTEMPTS = 4 diff --git a/test/integration/client/test_ibkr_client_i_new.py b/test/integration/client/test_ibkr_client_i_new.py index f95f35c1..b29c1a0f 100644 --- a/test/integration/client/test_ibkr_client_i_new.py +++ b/test/integration/client/test_ibkr_client_i_new.py @@ -10,8 +10,8 @@ from ibind.client.ibkr_utils import StockQuery, filter_stocks from ibind.support.errors import ExternalBrokerError from ibind.support.logs import ibind_logs_initialize -from test.integration.client import ibkr_responses -from test.test_utils_new import CaptureLogsContext +from integration.client import ibkr_responses +from test_utils_new import CaptureLogsContext _URL = 'https://localhost:5000' diff --git a/test/integration/client/test_ibkr_ws_client_i_new.py b/test/integration/client/test_ibkr_ws_client_i_new.py index 84df04fa..60da7ab7 100644 --- a/test/integration/client/test_ibkr_ws_client_i_new.py +++ b/test/integration/client/test_ibkr_ws_client_i_new.py @@ -9,8 +9,8 @@ from ibind import Result from ibind.client.ibkr_client import IbkrClient from ibind.client.ibkr_ws_client import IbkrWsClient, IbkrSubscriptionProcessor, IbkrWsKey -from test.integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock -from test.test_utils_new import capture_logs +from integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock +from test_utils_new import capture_logs _URL_WS = 'wss://localhost:5000/v1/api/ws' _URL_REST = 'https://localhost:5000' diff --git a/test/test_utils_new.py b/test/test_utils_new.py index c14e5487..8fb7a233 100644 --- a/test/test_utils_new.py +++ b/test/test_utils_new.py @@ -127,12 +127,7 @@ def _monkey_patch_log(self, logger): def new_log(level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1): if extra is None: extra = {} - # Check if make_clean_stack accepts extra_filters. This is necessary - # because the signature of make_clean_stack is unstable. - if 'extra_filters' in inspect.signature(make_clean_stack).parameters: - extra['manual_trace'] = make_clean_stack(extra_filters=[os.path.join('support', 'slog.py')])[:-2] - else: - extra['manual_trace'] = make_clean_stack()[:-2] + extra['manual_trace'] = make_clean_stack()[:-2] return original_log(level, msg, args, exc_info, extra, stack_info, stacklevel) From 222e6d4bb83a541c370a4abbfa4a494a699172b0 Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 24 Dec 2025 11:38:22 +0100 Subject: [PATCH 22/31] test: replaced unittest test files with the new pytest files --- test/integration/base/test_rest_client_i.py | 352 ++++--- .../base/test_rest_client_i_new.py | 221 ----- .../base/test_websocket_client_i.py | 631 ++++++++----- .../base/test_websocket_client_i_new.py | 396 -------- test/integration/client/test_ibkr_client_i.py | 602 +++++++----- .../client/test_ibkr_client_i_new.py | 361 -------- test/integration/client/test_ibkr_utils_i.py | 657 +++++++------ .../client/test_ibkr_utils_i_new.py | 394 -------- .../client/test_ibkr_ws_client_i.py | 868 ++++++++++-------- .../client/test_ibkr_ws_client_i_new.py | 539 ----------- test/unit/support/test_py_utils_u.py | 303 ++++-- test/unit/support/test_py_utils_u_new.py | 228 ----- 12 files changed, 1999 insertions(+), 3553 deletions(-) delete mode 100644 test/integration/base/test_rest_client_i_new.py delete mode 100644 test/integration/base/test_websocket_client_i_new.py delete mode 100644 test/integration/client/test_ibkr_client_i_new.py delete mode 100644 test/integration/client/test_ibkr_utils_i_new.py delete mode 100644 test/integration/client/test_ibkr_ws_client_i_new.py delete mode 100644 test/unit/support/test_py_utils_u_new.py diff --git a/test/integration/base/test_rest_client_i.py b/test/integration/base/test_rest_client_i.py index 693627af..1feda7b8 100644 --- a/test/integration/base/test_rest_client_i.py +++ b/test/integration/base/test_rest_client_i.py @@ -1,149 +1,221 @@ -import threading -from unittest import TestCase -from unittest.mock import patch, MagicMock import asyncio +import logging +import threading + +import pytest +from unittest.mock import MagicMock from requests import ReadTimeout, Timeout from ibind.client.ibkr_client import IbkrClient from ibind.support.errors import ExternalBrokerError -from ibind.support.logs import project_logger from ibind.base.rest_client import Result, RestClient +from ibind.support.logs import ibind_logs_initialize +from test.test_utils_new import CaptureLogsContext + + +_URL = 'https://localhost:5000' +_TIMEOUT = 8 +_MAX_RETRIES = 4 +_DEFAULT_PATH = 'test/api/route' + + +@pytest.fixture +def client(): + ibind_logs_initialize(log_to_console=True) + return RestClient( + url=_URL, + timeout=_TIMEOUT, + max_retries=_MAX_RETRIES, + use_session=False, + ) + + +@pytest.fixture +def data(): + return {'Test key': 'Test value'} + + +@pytest.fixture +def response(data): + response = MagicMock() + response.json.return_value = data + return response + + +@pytest.fixture(autouse=True) +def requests_mock(mocker, response): + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = response + return requests_mock + + +@pytest.fixture +def default_url(): + return f'{_URL}/{_DEFAULT_PATH}' + + +@pytest.fixture +def result(data, default_url): + return Result(data=data, request={'url': default_url}) + + +def test_default_rest_get(client, default_url, result, requests_mock): + # Arrange + # Act + rv = client.get(_DEFAULT_PATH) + + # Assert + assert result == rv + requests_mock.request.assert_called_with('GET', default_url, verify=False, headers={}, timeout=_TIMEOUT) + + +def test_default_rest_post(client, default_url, result, requests_mock): + # Arrange + test_post_kwargs = {'field1': 'value1', 'field2': 'value2'} + test_json = {'json': {**test_post_kwargs}} + + # Act + rv = client.post(_DEFAULT_PATH, params=test_post_kwargs) + + # Assert + assert result.copy(request={'url': default_url, **test_json}) == rv + requests_mock.request.assert_called_with('POST', default_url, verify=False, headers={}, timeout=_TIMEOUT, **test_json) + + +def test_default_rest_delete(client, default_url, result, requests_mock): + # Arrange + # Act + rv = client.delete(_DEFAULT_PATH) + + # Assert + assert result == rv + requests_mock.request.assert_called_with('DELETE', default_url, verify=False, headers={}, timeout=_TIMEOUT) + + +def test_request_retries(client, default_url, requests_mock): + # Arrange + requests_mock.request.side_effect = ReadTimeout() + + # Act + with CaptureLogsContext('ibind.rest_client', level='INFO') as cm, pytest.raises(TimeoutError) as excinfo: + client.get(_DEFAULT_PATH) + + # Assert + for i in range(_MAX_RETRIES): + assert f'RestClient: Timeout for GET {default_url} {{}}, retrying attempt {i + 1}/{_MAX_RETRIES}' in cm.output + + assert f'RestClient: Reached max retries ({_MAX_RETRIES}) for GET {default_url} {{}}' == str(excinfo.value) + + +def test_response_raise_timeout(client, requests_mock): + # Arrange + requests_mock.request.return_value.raise_for_status.side_effect = Timeout() + + # Act + with pytest.raises(ExternalBrokerError) as excinfo: + client.get(_DEFAULT_PATH) + + # Assert + assert f'RestClient: Timeout error ({_TIMEOUT}S)' == str(excinfo.value) + + +def test_response_raise_generic(client, result, requests_mock): + # Arrange + response = requests_mock.request.return_value + response.status_code = 400 + response.reason = 'Test reason' + response.text = 'Test text' + response.raise_for_status.side_effect = ValueError('Test generic error') + + # Act + with pytest.raises(ExternalBrokerError) as excinfo: + client.get(_DEFAULT_PATH) + + # Assert + assert f'RestClient: response error {result.copy(data=None)} :: {response.status_code} :: {response.reason} :: {response.text}' == str(excinfo.value) + + +def _worker_in_thread(results: []): + try: + IbkrClient() + except Exception as e: + results.append(e) + + +def test_in_thread(): + """Run in thread ensuring client still is constructed without an exception.""" + # Arrange + results = [] + t = threading.Thread(target=_worker_in_thread, args=(results,)) + t.daemon = True + + # Act + t.start() + t.join(1) + + # Assert + for result in results: + if isinstance(result, Exception): + raise result + + +def test_without_thread(): + """Run without a thread to ensure it still works as expected.""" + # Arrange + results = [] + + # Act + _worker_in_thread(results) + + # Assert + for result in results: + if isinstance(result, Exception): + raise result + + +async def _async_worker(results: []): + """Async version of the worker function to run in an asyncio event loop.""" + try: + IbkrClient() + except Exception as e: + results.append(e) + + +def _worker_in_async_thread(results: []): + """Runs the async test inside a new thread to check if signal handling breaks.""" + try: + asyncio.run(_async_worker(results)) + except Exception as e: + results.append(e) + + +def test_in_thread_async(): + """Test that IbkrClient() does not break in an asyncio thread.""" + # Arrange + results = [] + t = threading.Thread(target=_worker_in_async_thread, args=(results,)) + t.daemon = True + + # Act + t.start() + t.join(1) + + # Assert + for result in results: + if isinstance(result, Exception): + raise result + + +def test_without_thread_async(): + """Test that IbkrClient() does not break in the main asyncio event loop.""" + # Arrange + results = [] + # Act + asyncio.run(_async_worker(results)) -@patch('ibind.base.rest_client.requests') -class TestRestClientI(TestCase): - def setUp(self): - self.url = 'https://localhost:5000' - self.account_id = 'TEST_ACCOUNT_ID' - self.timeout = 8 - self.max_retries = 4 - self.client = RestClient( - url=self.url, - timeout=self.timeout, - max_retries=self.max_retries, - use_session=False, - ) - - self.data = {'Test key': 'Test value'} - - self.response = MagicMock() - self.response.json.return_value = self.data - self.default_path = 'test/api/route' - self.default_url = f'{self.url}/{self.default_path}' - self.result = Result(data=self.data, request={'url': self.default_url}) - self.maxDiff = 9999 - - def test_default_rest(self, requests_mock): - requests_mock.request.return_value = self.response - - rv = self.client.get(self.default_path) - self.assertEqual(self.result, rv) - requests_mock.request.assert_called_with('GET', self.default_url, verify=False, headers={}, timeout=self.timeout) - - test_post_kwargs = {'field1': 'value1', 'field2': 'value2'} - test_json = {'json': {**test_post_kwargs}} - rv = self.client.post(self.default_path, params=test_post_kwargs) - self.assertEqual(self.result.copy(request={'url': self.default_url, **test_json}), rv) - requests_mock.request.assert_called_with('POST', self.default_url, verify=False, headers={}, timeout=self.timeout, **test_json) - - rv = self.client.delete(self.default_path) - self.assertEqual(self.result, rv) - requests_mock.request.assert_called_with('DELETE', self.default_url, verify=False, headers={}, timeout=self.timeout) - - def test_request_retries(self, requests_mock): - requests_mock.request.side_effect = ReadTimeout() - - with self.assertLogs(project_logger(), level='INFO') as cm, self.assertRaises(TimeoutError) as cm_err: - self.client.get(self.default_path) - - for i, record in enumerate(cm.records): - self.assertEqual(f'RestClient: Timeout for GET {self.default_url} {{}}, retrying attempt {i + 1}/{self.max_retries}', record.msg) - self.assertEqual(f'RestClient: Reached max retries ({self.max_retries}) for GET {self.default_url} {{}}', str(cm_err.exception)) - - def test_response_raise_timeout(self, requests_mock): - requests_mock.request.return_value = self.response - self.response.raise_for_status.side_effect = Timeout() - - with self.assertRaises(ExternalBrokerError) as cm_err: - self.client.get(self.default_path) - - self.assertEqual(f'RestClient: Timeout error ({self.timeout}S)', str(cm_err.exception)) - - def test_response_raise_generic(self, requests_mock): - requests_mock.request.return_value = self.response - self.response.status_code = 400 - self.response.reason = 'Test reason' - self.response.text = 'Test text' - - self.response.raise_for_status.side_effect = ValueError('Test generic error') - - with self.assertRaises(ExternalBrokerError) as cm_err: - self.client.get(self.default_path) - - self.assertEqual( - f'RestClient: response error {self.result.copy(data=None)} :: {self.response.status_code} :: {self.response.reason} :: {self.response.text}', - str(cm_err.exception), - ) - - -class TestRestClientInThread(TestCase): - def _worker(self, results: []): - try: - IbkrClient() - except Exception as e: - results.append(e) - - def test_in_thread(self): - """Run in thread ensuring client still is constructed without an exception.""" - results = [] - t = threading.Thread(target=self._worker, args=(results,)) - t.daemon = True - t.start() - t.join(1) - for result in results: - if isinstance(result, Exception): - raise result - - def test_without_thread(self): - """Run without a thread to ensure it still works as expected.""" - results = [] - self._worker(results) - for result in results: - if isinstance(result, Exception): - raise result - - -class TestRestClientAsync(TestCase): - def _worker(self, results: []): - """Runs the async test inside a new thread to check if signal handling breaks.""" - try: - asyncio.run(self._async_worker(results)) - except Exception as e: - results.append(e) - - async def _async_worker(self, results: []): - """Async version of the worker function to run in an asyncio event loop.""" - try: - IbkrClient() - except Exception as e: - results.append(e) - - def test_in_thread_async(self): - """Test that IbkrClient() does not break in an asyncio thread.""" - results = [] - t = threading.Thread(target=self._worker, args=(results,)) - t.daemon = True - t.start() - t.join(1) - for result in results: - if isinstance(result, Exception): - raise result - - def test_without_thread_async(self): - """Test that IbkrClient() does not break in the main asyncio event loop.""" - results = [] - asyncio.run(self._async_worker(results)) - for result in results: - if isinstance(result, Exception): - raise result + # Assert + for result in results: + if isinstance(result, Exception): + raise result \ No newline at end of file diff --git a/test/integration/base/test_rest_client_i_new.py b/test/integration/base/test_rest_client_i_new.py deleted file mode 100644 index 1feda7b8..00000000 --- a/test/integration/base/test_rest_client_i_new.py +++ /dev/null @@ -1,221 +0,0 @@ -import asyncio -import logging -import threading - -import pytest -from unittest.mock import MagicMock - -from requests import ReadTimeout, Timeout - -from ibind.client.ibkr_client import IbkrClient -from ibind.support.errors import ExternalBrokerError -from ibind.base.rest_client import Result, RestClient -from ibind.support.logs import ibind_logs_initialize -from test.test_utils_new import CaptureLogsContext - - -_URL = 'https://localhost:5000' -_TIMEOUT = 8 -_MAX_RETRIES = 4 -_DEFAULT_PATH = 'test/api/route' - - -@pytest.fixture -def client(): - ibind_logs_initialize(log_to_console=True) - return RestClient( - url=_URL, - timeout=_TIMEOUT, - max_retries=_MAX_RETRIES, - use_session=False, - ) - - -@pytest.fixture -def data(): - return {'Test key': 'Test value'} - - -@pytest.fixture -def response(data): - response = MagicMock() - response.json.return_value = data - return response - - -@pytest.fixture(autouse=True) -def requests_mock(mocker, response): - requests_mock = mocker.patch('ibind.base.rest_client.requests') - requests_mock.request.return_value = response - return requests_mock - - -@pytest.fixture -def default_url(): - return f'{_URL}/{_DEFAULT_PATH}' - - -@pytest.fixture -def result(data, default_url): - return Result(data=data, request={'url': default_url}) - - -def test_default_rest_get(client, default_url, result, requests_mock): - # Arrange - # Act - rv = client.get(_DEFAULT_PATH) - - # Assert - assert result == rv - requests_mock.request.assert_called_with('GET', default_url, verify=False, headers={}, timeout=_TIMEOUT) - - -def test_default_rest_post(client, default_url, result, requests_mock): - # Arrange - test_post_kwargs = {'field1': 'value1', 'field2': 'value2'} - test_json = {'json': {**test_post_kwargs}} - - # Act - rv = client.post(_DEFAULT_PATH, params=test_post_kwargs) - - # Assert - assert result.copy(request={'url': default_url, **test_json}) == rv - requests_mock.request.assert_called_with('POST', default_url, verify=False, headers={}, timeout=_TIMEOUT, **test_json) - - -def test_default_rest_delete(client, default_url, result, requests_mock): - # Arrange - # Act - rv = client.delete(_DEFAULT_PATH) - - # Assert - assert result == rv - requests_mock.request.assert_called_with('DELETE', default_url, verify=False, headers={}, timeout=_TIMEOUT) - - -def test_request_retries(client, default_url, requests_mock): - # Arrange - requests_mock.request.side_effect = ReadTimeout() - - # Act - with CaptureLogsContext('ibind.rest_client', level='INFO') as cm, pytest.raises(TimeoutError) as excinfo: - client.get(_DEFAULT_PATH) - - # Assert - for i in range(_MAX_RETRIES): - assert f'RestClient: Timeout for GET {default_url} {{}}, retrying attempt {i + 1}/{_MAX_RETRIES}' in cm.output - - assert f'RestClient: Reached max retries ({_MAX_RETRIES}) for GET {default_url} {{}}' == str(excinfo.value) - - -def test_response_raise_timeout(client, requests_mock): - # Arrange - requests_mock.request.return_value.raise_for_status.side_effect = Timeout() - - # Act - with pytest.raises(ExternalBrokerError) as excinfo: - client.get(_DEFAULT_PATH) - - # Assert - assert f'RestClient: Timeout error ({_TIMEOUT}S)' == str(excinfo.value) - - -def test_response_raise_generic(client, result, requests_mock): - # Arrange - response = requests_mock.request.return_value - response.status_code = 400 - response.reason = 'Test reason' - response.text = 'Test text' - response.raise_for_status.side_effect = ValueError('Test generic error') - - # Act - with pytest.raises(ExternalBrokerError) as excinfo: - client.get(_DEFAULT_PATH) - - # Assert - assert f'RestClient: response error {result.copy(data=None)} :: {response.status_code} :: {response.reason} :: {response.text}' == str(excinfo.value) - - -def _worker_in_thread(results: []): - try: - IbkrClient() - except Exception as e: - results.append(e) - - -def test_in_thread(): - """Run in thread ensuring client still is constructed without an exception.""" - # Arrange - results = [] - t = threading.Thread(target=_worker_in_thread, args=(results,)) - t.daemon = True - - # Act - t.start() - t.join(1) - - # Assert - for result in results: - if isinstance(result, Exception): - raise result - - -def test_without_thread(): - """Run without a thread to ensure it still works as expected.""" - # Arrange - results = [] - - # Act - _worker_in_thread(results) - - # Assert - for result in results: - if isinstance(result, Exception): - raise result - - -async def _async_worker(results: []): - """Async version of the worker function to run in an asyncio event loop.""" - try: - IbkrClient() - except Exception as e: - results.append(e) - - -def _worker_in_async_thread(results: []): - """Runs the async test inside a new thread to check if signal handling breaks.""" - try: - asyncio.run(_async_worker(results)) - except Exception as e: - results.append(e) - - -def test_in_thread_async(): - """Test that IbkrClient() does not break in an asyncio thread.""" - # Arrange - results = [] - t = threading.Thread(target=_worker_in_async_thread, args=(results,)) - t.daemon = True - - # Act - t.start() - t.join(1) - - # Assert - for result in results: - if isinstance(result, Exception): - raise result - - -def test_without_thread_async(): - """Test that IbkrClient() does not break in the main asyncio event loop.""" - # Arrange - results = [] - - # Act - asyncio.run(_async_worker(results)) - - # Assert - for result in results: - if isinstance(result, Exception): - raise result \ No newline at end of file diff --git a/test/integration/base/test_websocket_client_i.py b/test/integration/base/test_websocket_client_i.py index dc632cd3..aa3d1221 100644 --- a/test/integration/base/test_websocket_client_i.py +++ b/test/integration/base/test_websocket_client_i.py @@ -1,273 +1,396 @@ from threading import Thread from typing import Optional -from unittest import TestCase -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock + +import pytest from ibind.base.ws_client import WsClient from ibind.support.py_utils import tname from integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock -from test_utils import RaiseLogsContext, exact_log - - -class TestWsClient(TestCase): - def setUp(self): - self.url = 'wss://localhost:5000/v1/api/ws' - self.max_reconnect_attempts = 4 - self.max_ping_interval = 38 - self.error_message = 'TEST_ERROR' - - self.ws_client = WsClient( - subscription_processor=None, - url=self.url, - cacert=False, - timeout=0.01, - max_connection_attempts=self.max_reconnect_attempts, - max_ping_interval=self.max_ping_interval, - ) - - self.wsa_mock = create_wsa_mock() - - self.thread_mock = MagicMock(spec=Thread) - self.thread_mock.start.side_effect = lambda: self.ws_client._run_websocket(self.wsa_mock) - - def run_in_test_context(self, fn, expected_errors: list[str] = None): - with patch('ibind.base.ws_client.WebSocketApp', side_effect=lambda *args, **kwargs: init_wsa_mock(self.wsa_mock, *args, **kwargs)), \ - patch('ibind.base.ws_client.Thread', return_value=self.thread_mock) as new_thread_mock, \ - self.assertLogs('ibind', level='DEBUG') as cm, \ - RaiseLogsContext(self, 'ibind', level='ERROR', expected_errors=expected_errors): # fmt: skip - self.new_thread_mock = new_thread_mock - rv = fn() - - return cm, rv - - def start(self): - success = self.ws_client.start() - self.new_thread_mock.assert_called_with(target=self.ws_client._run_websocket, args=(self.wsa_mock,), name='ws_client_thread') - return success - - def _logs_start_success_beginning(self): - return [ - 'WsClient: Starting', - 'WsClient: Trying to connect', - ] - - def _logs_start_success_end(self): - return [ - 'WsClient: Creating new WebSocketApp', - f'WsClient: Thread started ({tname()})', - 'WsClient: Connection open', - f'WsClient: Thread stopped ({tname()})', - ] - - def _logs_failed_attempt(self, attempt): - s = [ - 'WsClient: Creating new WebSocketApp', - 'WsClient: New WebSocketApp connection timeout', - 'WsClient: on_close', - 'WsClient: on_close event while disconnected', - ] - if attempt: - s.append(f'WsClient: Connect reattempt {attempt}/{self.max_reconnect_attempts}') - return s - - def _logs_shutdown_success(self): - return [ - 'WsClient: Shutting down', - 'WsClient: on_close', - 'WsClient: Connection closed', - 'WsClient: Gracefully stopped', - ] - - def _logs_exception_starting(self, error_message, thread_mock): - return [ - 'WsClient: Creating new WebSocketApp', - f'WsClient: Thread started ({tname()})', - f'WsClient: Unexpected error while running WebSocketApp: {error_message}', - 'WsClient: Hard reset, restart=False, self._wsa is None=False', - 'WsClient: Forced restart', - 'WsClient: Reconnecting', - f'WsClient: Thread already running: {thread_mock.name}-{thread_mock.ident}', - f'WsClient: Thread stopped ({tname()})', - 'WsClient: Reconnecting', - 'WsClient: Trying to connect', - ] - - def _logs_check_health_error(self, time_ago): - return [ - f'WsClient: Last WebSocket ping happened {time_ago} seconds ago, exceeding the max ping interval of {self.max_ping_interval}. Restarting.', - 'WsClient: Hard reset, restart=True, self._wsa is None=False', - 'WsClient: Hard reset is closing the WebSocketApp', - ] - - def _logs_hard_restart_error(self): - return [ - 'WsClient: Hard reset close timeout', - f'WsClient: Abandoning current WebSocketApp that cannot be closed: {self.wsa_mock}', - 'WsClient: Forced restart', - 'WsClient: Reconnecting', - 'WsClient: Trying to connect', - ] - - def _verify_started(self): - self.wsa_mock.run_forever.assert_called_with( - sslopt=self.ws_client._sslopt, ping_interval=self.ws_client._ping_interval, ping_timeout=0.95 * self.ws_client._ping_interval - ) - self.wsa_mock._on_open.assert_called_with(self.wsa_mock) - - def _verify_failed_starting(self): - self.wsa_mock.run_forever.assert_not_called() - self.wsa_mock._on_open.assert_not_called() - self.wsa_mock.close.assert_called() - - def test_start_success(self): - cm, success = self.run_in_test_context(self.start) - - self.assertTrue(success, 'Starting should succeed') - self._verify_started() - exact_log(self, cm, self._logs_start_success_beginning() + self._logs_start_success_end()) - - def test_start_success_on_second_attempt(self): - counter = [0] - - # ensure we fail to do anything on the first attempt, and succeed on the second - def delayed_start(): - if counter[0] >= 1: - self.ws_client._run_websocket(self.wsa_mock) - counter[0] += 1 - - self.thread_mock.start.side_effect = delayed_start - - expected_errors = ['WsClient: New WebSocketApp connection timeout'] - - cm, success = self.run_in_test_context(self.start, expected_errors=expected_errors) - - self._verify_started() - - exact_log(self, cm, self._logs_start_success_beginning() + self._logs_failed_attempt(2) + self._logs_start_success_end()) - self.thread_mock.join.assert_called_with(60) - # print("\n".join([r.msg for r in cm.records])) - - def test_start_reattempt_failure(self): - self.thread_mock.start.side_effect = lambda: None - - expected_errors = ['WsClient: New WebSocketApp connection timeout'] - - cm, success = self.run_in_test_context(self.start, expected_errors=expected_errors) - - self.assertFalse(success, 'Starting not succeed') - - self._verify_failed_starting() - - expected_logs = self._logs_start_success_beginning() - for i in range(self.max_reconnect_attempts): - if i < self.max_reconnect_attempts - 1: - expected_logs += self._logs_failed_attempt(i + 2) - else: - expected_logs += self._logs_failed_attempt(None) - expected_logs.append(f'WsClient: Connection failed after {self.max_reconnect_attempts} attempts') - exact_log(self, cm, expected_logs) - - self.assertFalse(self.wsa_mock.keep_running) - - def test_open_exception(self): - old_run_forever = self.wsa_mock.run_forever.side_effect - - def run(): - success = self.start() - self.ws_client.shutdown() - return success - - def run_forever_exception(wsa_mock: MagicMock, sslopt: dict = None, ping_interval: float = 0, ping_timeout: Optional[float] = None): - self.wsa_mock.run_forever.side_effect = old_run_forever - raise RuntimeError(self.error_message) - - self.wsa_mock.run_forever.side_effect = lambda *args, **kwargs: run_forever_exception(self.wsa_mock, *args, **kwargs) - - expected_errors = [f'WsClient: Unexpected error while running WebSocketApp: {self.error_message}'] - - cm, success = self.run_in_test_context(run, expected_errors=expected_errors) - - exact_log( - self, - cm, - self._logs_start_success_beginning() - + self._logs_exception_starting(self.error_message, self.thread_mock) - + self._logs_start_success_end() - + self._logs_shutdown_success(), - ) - - def test_open_and_close(self): - def run(): - success = self.start() - self.ws_client.shutdown() - return success - - cm, success = self.run_in_test_context(run) - - exact_log(self, cm, self._logs_start_success_beginning() + self._logs_start_success_end() + self._logs_shutdown_success()) - - def test_send(self): - def run(): - success = self.start() - self.ws_client.send('test') - self.ws_client.shutdown() - return success - - self.ws_client._on_message = MagicMock() - - cm, success = self.run_in_test_context(run) - - self.ws_client._on_message.assert_called_once_with(self.wsa_mock, 'test') +from test_utils_new import capture_logs + +_URL = 'wss://localhost:5000/v1/api/ws' +_MAX_RECONNECT_ATTEMPTS = 4 +_MAX_PING_INTERVAL = 38 +_ERROR_MESSAGE = 'TEST_ERROR' + + +# -------------------------------------------------------------------------------------- +# Log expectations +# -------------------------------------------------------------------------------------- + + +def _logs_start_success_beginning(): + return [ + 'WsClient: Starting', + 'WsClient: Trying to connect', + ] + + +def _logs_start_success_end(): + return [ + 'WsClient: Creating new WebSocketApp', + f'WsClient: Thread started ({tname()})', + 'WsClient: Connection open', + f'WsClient: Thread stopped ({tname()})', + ] + + +def _logs_failed_attempt(max_reconnect_attempts: int, attempt: Optional[int]): + logs = [ + 'WsClient: Creating new WebSocketApp', + 'WsClient: New WebSocketApp connection timeout', + 'WsClient: on_close', + 'WsClient: on_close event while disconnected', + ] + if attempt is not None: + logs.append(f'WsClient: Connect reattempt {attempt}/{max_reconnect_attempts}') + return logs - exact_log(self, cm, self._logs_start_success_beginning() + self._logs_start_success_end() + self._logs_shutdown_success()) - def test_send_without_start(self): - def run(): - self.ws_client.send('test') - self.ws_client.shutdown() +def _logs_shutdown_success(): + return [ + 'WsClient: Shutting down', + 'WsClient: on_close', + 'WsClient: Connection closed', + 'WsClient: Gracefully stopped', + ] - self.ws_client._on_message = MagicMock() - expected_errors = ['WsClient: Must be started before sending payloads'] +def _logs_exception_starting(error_message: str, thread_mock: MagicMock): + return [ + 'WsClient: Creating new WebSocketApp', + f'WsClient: Thread started ({tname()})', + f'WsClient: Unexpected error while running WebSocketApp: {error_message}', + 'WsClient: Hard reset, restart=False, self._wsa is None=False', + 'WsClient: Forced restart', + 'WsClient: Reconnecting', + f'WsClient: Thread already running: {thread_mock.name}-{thread_mock.ident}', + f'WsClient: Thread stopped ({tname()})', + 'WsClient: Reconnecting', + 'WsClient: Trying to connect', + ] - cm, success = self.run_in_test_context(run, expected_errors=expected_errors) - exact_log(self, cm, expected_errors) +def _logs_check_health_error(max_ping_interval: int, time_ago: str): + return [ + f'WsClient: Last WebSocket ping happened {time_ago} seconds ago, exceeding the max ping interval of {max_ping_interval}. Restarting.', + 'WsClient: Hard reset, restart=True, self._wsa is None=False', + 'WsClient: Hard reset is closing the WebSocketApp', + ] - def test_check_ping(self): - start_time = [100] - def fake_time(): - start_time[0] += 100 - return start_time[0] +def _logs_hard_restart_error(wsa_mock: MagicMock): + return [ + 'WsClient: Hard reset close timeout', + f'WsClient: Abandoning current WebSocketApp that cannot be closed: {wsa_mock}', + 'WsClient: Forced restart', + 'WsClient: Reconnecting', + 'WsClient: Trying to connect', + ] - def run(): - self.ws_client.start() - self.ws_client.check_ping() - # we simulate that closing the WebSocketApp doesn't work since we have connectivity issues - self.wsa_mock._on_close.side_effect = lambda x, y, z: None - with patch('ibind.base.ws_client.time') as time_mock: - time_mock.time.side_effect = fake_time - self.wsa_mock.last_ping_tm = self.max_ping_interval - self.ws_client.check_ping() - self.assertTrue(self.ws_client.ready()) - self.ws_client.shutdown() - self.ws_client._on_message = MagicMock() +def _verify_started(ws_client: WsClient, wsa_mock: MagicMock): + wsa_mock.run_forever.assert_called_with( + sslopt=ws_client._sslopt, + ping_interval=ws_client._ping_interval, + ping_timeout=0.95 * ws_client._ping_interval, + ) + wsa_mock._on_open.assert_called_with(wsa_mock) + + +def _verify_failed_starting(wsa_mock: MagicMock): + wsa_mock.run_forever.assert_not_called() + wsa_mock._on_open.assert_not_called() + wsa_mock.close.assert_called() + + +# -------------------------------------------------------------------------------------- +# Test setup +# -------------------------------------------------------------------------------------- + + +@pytest.fixture +def ws_client(): + return WsClient( + subscription_processor=None, + url=_URL, + cacert=False, + timeout=0.01, + max_connection_attempts=_MAX_RECONNECT_ATTEMPTS, + max_ping_interval=_MAX_PING_INTERVAL, + ) + + +@pytest.fixture +def wsa_mock(): + return create_wsa_mock() + + +@pytest.fixture +def thread_mock(ws_client, wsa_mock): + thread_mock = MagicMock(spec=Thread) + thread_mock.start.side_effect = lambda: ws_client._run_websocket(wsa_mock) + return thread_mock + + +@pytest.fixture +def wsa_ctor_mock(mocker, wsa_mock): + return mocker.patch( + 'ibind.base.ws_client.WebSocketApp', + side_effect=lambda *args, **kwargs: init_wsa_mock(wsa_mock, *args, **kwargs), + ) + + +@pytest.fixture +def thread_ctor_mock(mocker, thread_mock): + return mocker.patch('ibind.base.ws_client.Thread', return_value=thread_mock) + + +@pytest.fixture +def patched_constructors(wsa_ctor_mock, thread_ctor_mock): + return None + + +# -------------------------------------------------------------------------------------- +# Start / reconnect behavior +# -------------------------------------------------------------------------------------- + +@capture_logs(logger_level='DEBUG') +def test_start_success(ws_client, wsa_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Starts successfully and logs the expected connection sequence.""" + ## Arrange + cm = kwargs['_cm_ibind'] + + ## Act + success = ws_client.start() + + ## Assert + assert success is True + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + _verify_started(ws_client, wsa_mock) + assert _logs_start_success_beginning() + _logs_start_success_end() == [r.msg for r in cm.records] + + +@capture_logs(logger_level='DEBUG', expected_errors=['WsClient: New WebSocketApp connection timeout']) +def test_start_success_on_second_attempt(ws_client, wsa_mock, thread_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Reconnects and succeeds on the second attempt after a timeout on the first.""" + ## Arrange + cm = kwargs['_cm_ibind'] + counter = [0] + + def delayed_start(): + if counter[0] >= 1: + ws_client._run_websocket(wsa_mock) + counter[0] += 1 + + thread_mock.start.side_effect = delayed_start + + ## Act + success = ws_client.start() + + ## Assert + assert success is True + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + _verify_started(ws_client, wsa_mock) + assert ( + _logs_start_success_beginning() + + _logs_failed_attempt(_MAX_RECONNECT_ATTEMPTS, 2) + + _logs_start_success_end() + == [r.msg for r in cm.records] + ) + thread_mock.join.assert_called_with(60) + + +@capture_logs( + logger_level='DEBUG', + expected_errors=[ + 'WsClient: New WebSocketApp connection timeout', + f'WsClient: Connection failed after {_MAX_RECONNECT_ATTEMPTS} attempts', + ], +) +def test_start_reattempt_failure(ws_client, wsa_mock, thread_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Fails after exhausting reconnect attempts and closes the WebSocketApp.""" + ## Arrange + cm = kwargs['_cm_ibind'] + thread_mock.start.side_effect = lambda: None + + ## Act + success = ws_client.start() + + ## Assert + assert success is False + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + + _verify_failed_starting(wsa_mock) + + expected_logs = _logs_start_success_beginning() + for i in range(_MAX_RECONNECT_ATTEMPTS): + if i < _MAX_RECONNECT_ATTEMPTS - 1: + expected_logs += _logs_failed_attempt(_MAX_RECONNECT_ATTEMPTS, i + 2) + else: + expected_logs += _logs_failed_attempt(_MAX_RECONNECT_ATTEMPTS, None) + expected_logs.append(f"WsClient: Connection failed after {_MAX_RECONNECT_ATTEMPTS} attempts") + + assert expected_logs == [r.msg for r in cm.records] + assert wsa_mock.keep_running is False + + +# -------------------------------------------------------------------------------------- +# Error handling +# -------------------------------------------------------------------------------------- + + +@capture_logs( + logger_level='DEBUG', + expected_errors=[ + f"WsClient: Unexpected error while running WebSocketApp: {_ERROR_MESSAGE}", + 'WsClient: Thread already running:', + ], + partial_match=True, +) +def test_open_exception(ws_client, wsa_mock, thread_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Hard-resets and reconnects when WebSocketApp.run_forever raises an exception.""" + ## Arrange + cm = kwargs['_cm_ibind'] + old_run_forever = wsa_mock.run_forever.side_effect + + def run_forever_exception( + wsa_mock: MagicMock, + sslopt: dict = None, + ping_interval: float = 0, + ping_timeout: Optional[float] = None, + ): + wsa_mock.run_forever.side_effect = old_run_forever + raise RuntimeError(_ERROR_MESSAGE) + + wsa_mock.run_forever.side_effect = lambda *args, **kwargs: run_forever_exception(wsa_mock, *args, **kwargs) + + ## Act + ws_client.start() + ws_client.shutdown() + + ## Assert + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + assert ( + _logs_start_success_beginning() + + _logs_exception_starting(_ERROR_MESSAGE, thread_mock) + + _logs_start_success_end() + + _logs_shutdown_success() + == [r.msg for r in cm.records] + ) + + +# -------------------------------------------------------------------------------------- +# Shutdown +# -------------------------------------------------------------------------------------- + + +@capture_logs(logger_level='DEBUG') +def test_open_and_close(ws_client, wsa_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Shuts down cleanly after a successful start.""" + ## Arrange + cm = kwargs['_cm_ibind'] + + ## Act + success = ws_client.start() + ws_client.shutdown() + + ## Assert + assert success is True + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + assert _logs_start_success_beginning() + _logs_start_success_end() + _logs_shutdown_success() == [r.msg for r in cm.records] + + +# -------------------------------------------------------------------------------------- +# Sending payloads +# -------------------------------------------------------------------------------------- + + +@capture_logs(logger_level='DEBUG') +def test_send(ws_client, wsa_mock, thread_ctor_mock, patched_constructors, **kwargs): + """Delivers outbound payloads to the on_message callback (mocked echo).""" + ## Arrange + cm = kwargs['_cm_ibind'] + + ws_client._on_message = MagicMock() + + ## Act + success = ws_client.start() + ws_client.send('test') + ws_client.shutdown() + + ## Assert + assert success is True + thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') + ws_client._on_message.assert_called_once_with(wsa_mock, 'test') + assert _logs_start_success_beginning() + _logs_start_success_end() + _logs_shutdown_success() == [r.msg for r in cm.records] + + +@capture_logs(logger_level='DEBUG', expected_errors=['WsClient: Must be started before sending payloads']) +def test_send_without_start(ws_client, **kwargs): + """Logs an error when trying to send before calling start().""" + ## Arrange + cm = kwargs['_cm_ibind'] + + ws_client._on_message = MagicMock() + + ## Act + ws_client.send('test') + ws_client.shutdown() + + ## Assert + assert ['WsClient: Must be started before sending payloads'] == [r.msg for r in cm.records] + + +# -------------------------------------------------------------------------------------- +# Health checks +# -------------------------------------------------------------------------------------- + + +@capture_logs( + logger_level='DEBUG', + expected_errors=[ + 'WsClient: Last WebSocket ping happened', + 'WsClient: Hard reset close timeout', + 'WsClient: Abandoning current WebSocketApp that cannot be closed:', + ], + partial_match=True, +) +def test_check_ping(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Triggers a hard reset when the last ping exceeds max_ping_interval.""" + ## Arrange + cm = kwargs['_cm_ibind'] + start_time = [100] + + def fake_time(): + start_time[0] += 100 + return start_time[0] + + ws_client._on_message = MagicMock() + + ## Act + ws_client.start() + ws_client.check_ping() + + # Simulate that closing the WebSocketApp doesn't work since we have connectivity issues + wsa_mock._on_close.side_effect = lambda x, y, z: None - expected_errors = ['WsClient: Must be started before sending payloads', 'WsClient: Hard reset close timeout'] + time_mock = mocker.patch('ibind.base.ws_client.time') + time_mock.time.side_effect = fake_time - cm, success = self.run_in_test_context(run, expected_errors=expected_errors) + wsa_mock.last_ping_tm = _MAX_PING_INTERVAL + ws_client.check_ping() + assert ws_client.ready() is True + ws_client.shutdown() - exact_log( - self, - cm, - self._logs_start_success_beginning() - + self._logs_start_success_end() - + self._logs_check_health_error('162.00') - + - # self._logs_start_success_end() + - self._logs_hard_restart_error() - + self._logs_start_success_end() - + self._logs_shutdown_success(), - ) \ No newline at end of file + ## Assert + assert ( + _logs_start_success_beginning() + + _logs_start_success_end() + + _logs_check_health_error(_MAX_PING_INTERVAL, '162.00') + + _logs_hard_restart_error(wsa_mock) + + _logs_start_success_end() + + _logs_shutdown_success() + == [r.msg for r in cm.records] + ) \ No newline at end of file diff --git a/test/integration/base/test_websocket_client_i_new.py b/test/integration/base/test_websocket_client_i_new.py deleted file mode 100644 index aa3d1221..00000000 --- a/test/integration/base/test_websocket_client_i_new.py +++ /dev/null @@ -1,396 +0,0 @@ -from threading import Thread -from typing import Optional -from unittest.mock import MagicMock - -import pytest - -from ibind.base.ws_client import WsClient -from ibind.support.py_utils import tname -from integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock -from test_utils_new import capture_logs - -_URL = 'wss://localhost:5000/v1/api/ws' -_MAX_RECONNECT_ATTEMPTS = 4 -_MAX_PING_INTERVAL = 38 -_ERROR_MESSAGE = 'TEST_ERROR' - - -# -------------------------------------------------------------------------------------- -# Log expectations -# -------------------------------------------------------------------------------------- - - -def _logs_start_success_beginning(): - return [ - 'WsClient: Starting', - 'WsClient: Trying to connect', - ] - - -def _logs_start_success_end(): - return [ - 'WsClient: Creating new WebSocketApp', - f'WsClient: Thread started ({tname()})', - 'WsClient: Connection open', - f'WsClient: Thread stopped ({tname()})', - ] - - -def _logs_failed_attempt(max_reconnect_attempts: int, attempt: Optional[int]): - logs = [ - 'WsClient: Creating new WebSocketApp', - 'WsClient: New WebSocketApp connection timeout', - 'WsClient: on_close', - 'WsClient: on_close event while disconnected', - ] - if attempt is not None: - logs.append(f'WsClient: Connect reattempt {attempt}/{max_reconnect_attempts}') - return logs - - -def _logs_shutdown_success(): - return [ - 'WsClient: Shutting down', - 'WsClient: on_close', - 'WsClient: Connection closed', - 'WsClient: Gracefully stopped', - ] - - -def _logs_exception_starting(error_message: str, thread_mock: MagicMock): - return [ - 'WsClient: Creating new WebSocketApp', - f'WsClient: Thread started ({tname()})', - f'WsClient: Unexpected error while running WebSocketApp: {error_message}', - 'WsClient: Hard reset, restart=False, self._wsa is None=False', - 'WsClient: Forced restart', - 'WsClient: Reconnecting', - f'WsClient: Thread already running: {thread_mock.name}-{thread_mock.ident}', - f'WsClient: Thread stopped ({tname()})', - 'WsClient: Reconnecting', - 'WsClient: Trying to connect', - ] - - -def _logs_check_health_error(max_ping_interval: int, time_ago: str): - return [ - f'WsClient: Last WebSocket ping happened {time_ago} seconds ago, exceeding the max ping interval of {max_ping_interval}. Restarting.', - 'WsClient: Hard reset, restart=True, self._wsa is None=False', - 'WsClient: Hard reset is closing the WebSocketApp', - ] - - -def _logs_hard_restart_error(wsa_mock: MagicMock): - return [ - 'WsClient: Hard reset close timeout', - f'WsClient: Abandoning current WebSocketApp that cannot be closed: {wsa_mock}', - 'WsClient: Forced restart', - 'WsClient: Reconnecting', - 'WsClient: Trying to connect', - ] - - -def _verify_started(ws_client: WsClient, wsa_mock: MagicMock): - wsa_mock.run_forever.assert_called_with( - sslopt=ws_client._sslopt, - ping_interval=ws_client._ping_interval, - ping_timeout=0.95 * ws_client._ping_interval, - ) - wsa_mock._on_open.assert_called_with(wsa_mock) - - -def _verify_failed_starting(wsa_mock: MagicMock): - wsa_mock.run_forever.assert_not_called() - wsa_mock._on_open.assert_not_called() - wsa_mock.close.assert_called() - - -# -------------------------------------------------------------------------------------- -# Test setup -# -------------------------------------------------------------------------------------- - - -@pytest.fixture -def ws_client(): - return WsClient( - subscription_processor=None, - url=_URL, - cacert=False, - timeout=0.01, - max_connection_attempts=_MAX_RECONNECT_ATTEMPTS, - max_ping_interval=_MAX_PING_INTERVAL, - ) - - -@pytest.fixture -def wsa_mock(): - return create_wsa_mock() - - -@pytest.fixture -def thread_mock(ws_client, wsa_mock): - thread_mock = MagicMock(spec=Thread) - thread_mock.start.side_effect = lambda: ws_client._run_websocket(wsa_mock) - return thread_mock - - -@pytest.fixture -def wsa_ctor_mock(mocker, wsa_mock): - return mocker.patch( - 'ibind.base.ws_client.WebSocketApp', - side_effect=lambda *args, **kwargs: init_wsa_mock(wsa_mock, *args, **kwargs), - ) - - -@pytest.fixture -def thread_ctor_mock(mocker, thread_mock): - return mocker.patch('ibind.base.ws_client.Thread', return_value=thread_mock) - - -@pytest.fixture -def patched_constructors(wsa_ctor_mock, thread_ctor_mock): - return None - - -# -------------------------------------------------------------------------------------- -# Start / reconnect behavior -# -------------------------------------------------------------------------------------- - -@capture_logs(logger_level='DEBUG') -def test_start_success(ws_client, wsa_mock, thread_ctor_mock, patched_constructors, **kwargs): - """Starts successfully and logs the expected connection sequence.""" - ## Arrange - cm = kwargs['_cm_ibind'] - - ## Act - success = ws_client.start() - - ## Assert - assert success is True - thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') - _verify_started(ws_client, wsa_mock) - assert _logs_start_success_beginning() + _logs_start_success_end() == [r.msg for r in cm.records] - - -@capture_logs(logger_level='DEBUG', expected_errors=['WsClient: New WebSocketApp connection timeout']) -def test_start_success_on_second_attempt(ws_client, wsa_mock, thread_mock, thread_ctor_mock, patched_constructors, **kwargs): - """Reconnects and succeeds on the second attempt after a timeout on the first.""" - ## Arrange - cm = kwargs['_cm_ibind'] - counter = [0] - - def delayed_start(): - if counter[0] >= 1: - ws_client._run_websocket(wsa_mock) - counter[0] += 1 - - thread_mock.start.side_effect = delayed_start - - ## Act - success = ws_client.start() - - ## Assert - assert success is True - thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') - _verify_started(ws_client, wsa_mock) - assert ( - _logs_start_success_beginning() - + _logs_failed_attempt(_MAX_RECONNECT_ATTEMPTS, 2) - + _logs_start_success_end() - == [r.msg for r in cm.records] - ) - thread_mock.join.assert_called_with(60) - - -@capture_logs( - logger_level='DEBUG', - expected_errors=[ - 'WsClient: New WebSocketApp connection timeout', - f'WsClient: Connection failed after {_MAX_RECONNECT_ATTEMPTS} attempts', - ], -) -def test_start_reattempt_failure(ws_client, wsa_mock, thread_mock, thread_ctor_mock, patched_constructors, **kwargs): - """Fails after exhausting reconnect attempts and closes the WebSocketApp.""" - ## Arrange - cm = kwargs['_cm_ibind'] - thread_mock.start.side_effect = lambda: None - - ## Act - success = ws_client.start() - - ## Assert - assert success is False - thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') - - _verify_failed_starting(wsa_mock) - - expected_logs = _logs_start_success_beginning() - for i in range(_MAX_RECONNECT_ATTEMPTS): - if i < _MAX_RECONNECT_ATTEMPTS - 1: - expected_logs += _logs_failed_attempt(_MAX_RECONNECT_ATTEMPTS, i + 2) - else: - expected_logs += _logs_failed_attempt(_MAX_RECONNECT_ATTEMPTS, None) - expected_logs.append(f"WsClient: Connection failed after {_MAX_RECONNECT_ATTEMPTS} attempts") - - assert expected_logs == [r.msg for r in cm.records] - assert wsa_mock.keep_running is False - - -# -------------------------------------------------------------------------------------- -# Error handling -# -------------------------------------------------------------------------------------- - - -@capture_logs( - logger_level='DEBUG', - expected_errors=[ - f"WsClient: Unexpected error while running WebSocketApp: {_ERROR_MESSAGE}", - 'WsClient: Thread already running:', - ], - partial_match=True, -) -def test_open_exception(ws_client, wsa_mock, thread_mock, thread_ctor_mock, patched_constructors, **kwargs): - """Hard-resets and reconnects when WebSocketApp.run_forever raises an exception.""" - ## Arrange - cm = kwargs['_cm_ibind'] - old_run_forever = wsa_mock.run_forever.side_effect - - def run_forever_exception( - wsa_mock: MagicMock, - sslopt: dict = None, - ping_interval: float = 0, - ping_timeout: Optional[float] = None, - ): - wsa_mock.run_forever.side_effect = old_run_forever - raise RuntimeError(_ERROR_MESSAGE) - - wsa_mock.run_forever.side_effect = lambda *args, **kwargs: run_forever_exception(wsa_mock, *args, **kwargs) - - ## Act - ws_client.start() - ws_client.shutdown() - - ## Assert - thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') - assert ( - _logs_start_success_beginning() - + _logs_exception_starting(_ERROR_MESSAGE, thread_mock) - + _logs_start_success_end() - + _logs_shutdown_success() - == [r.msg for r in cm.records] - ) - - -# -------------------------------------------------------------------------------------- -# Shutdown -# -------------------------------------------------------------------------------------- - - -@capture_logs(logger_level='DEBUG') -def test_open_and_close(ws_client, wsa_mock, thread_ctor_mock, patched_constructors, **kwargs): - """Shuts down cleanly after a successful start.""" - ## Arrange - cm = kwargs['_cm_ibind'] - - ## Act - success = ws_client.start() - ws_client.shutdown() - - ## Assert - assert success is True - thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') - assert _logs_start_success_beginning() + _logs_start_success_end() + _logs_shutdown_success() == [r.msg for r in cm.records] - - -# -------------------------------------------------------------------------------------- -# Sending payloads -# -------------------------------------------------------------------------------------- - - -@capture_logs(logger_level='DEBUG') -def test_send(ws_client, wsa_mock, thread_ctor_mock, patched_constructors, **kwargs): - """Delivers outbound payloads to the on_message callback (mocked echo).""" - ## Arrange - cm = kwargs['_cm_ibind'] - - ws_client._on_message = MagicMock() - - ## Act - success = ws_client.start() - ws_client.send('test') - ws_client.shutdown() - - ## Assert - assert success is True - thread_ctor_mock.assert_called_with(target=ws_client._run_websocket, args=(wsa_mock,), name='ws_client_thread') - ws_client._on_message.assert_called_once_with(wsa_mock, 'test') - assert _logs_start_success_beginning() + _logs_start_success_end() + _logs_shutdown_success() == [r.msg for r in cm.records] - - -@capture_logs(logger_level='DEBUG', expected_errors=['WsClient: Must be started before sending payloads']) -def test_send_without_start(ws_client, **kwargs): - """Logs an error when trying to send before calling start().""" - ## Arrange - cm = kwargs['_cm_ibind'] - - ws_client._on_message = MagicMock() - - ## Act - ws_client.send('test') - ws_client.shutdown() - - ## Assert - assert ['WsClient: Must be started before sending payloads'] == [r.msg for r in cm.records] - - -# -------------------------------------------------------------------------------------- -# Health checks -# -------------------------------------------------------------------------------------- - - -@capture_logs( - logger_level='DEBUG', - expected_errors=[ - 'WsClient: Last WebSocket ping happened', - 'WsClient: Hard reset close timeout', - 'WsClient: Abandoning current WebSocketApp that cannot be closed:', - ], - partial_match=True, -) -def test_check_ping(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): - """Triggers a hard reset when the last ping exceeds max_ping_interval.""" - ## Arrange - cm = kwargs['_cm_ibind'] - start_time = [100] - - def fake_time(): - start_time[0] += 100 - return start_time[0] - - ws_client._on_message = MagicMock() - - ## Act - ws_client.start() - ws_client.check_ping() - - # Simulate that closing the WebSocketApp doesn't work since we have connectivity issues - wsa_mock._on_close.side_effect = lambda x, y, z: None - - time_mock = mocker.patch('ibind.base.ws_client.time') - time_mock.time.side_effect = fake_time - - wsa_mock.last_ping_tm = _MAX_PING_INTERVAL - ws_client.check_ping() - assert ws_client.ready() is True - ws_client.shutdown() - - ## Assert - assert ( - _logs_start_success_beginning() - + _logs_start_success_end() - + _logs_check_health_error(_MAX_PING_INTERVAL, '162.00') - + _logs_hard_restart_error(wsa_mock) - + _logs_start_success_end() - + _logs_shutdown_success() - == [r.msg for r in cm.records] - ) \ No newline at end of file diff --git a/test/integration/client/test_ibkr_client_i.py b/test/integration/client/test_ibkr_client_i.py index 8042b5d6..b29c1a0f 100644 --- a/test/integration/client/test_ibkr_client_i.py +++ b/test/integration/client/test_ibkr_client_i.py @@ -1,7 +1,7 @@ import datetime from pprint import pformat -from unittest import TestCase -from unittest.mock import patch, MagicMock +import pytest +from unittest.mock import MagicMock from requests import ConnectTimeout @@ -9,259 +9,353 @@ from ibind.client.ibkr_client import IbkrClient from ibind.client.ibkr_utils import StockQuery, filter_stocks from ibind.support.errors import ExternalBrokerError -from ibind.support.logs import project_logger -from test.integration.client import ibkr_responses -from test_utils import verify_log, SafeAssertLogs, RaiseLogsContext - - -@patch('ibind.base.rest_client.requests') -class TestIbkrClientI(TestCase): - def setUp(self): - self.url = 'https://localhost:5000' - self.account_id = 'TEST_ACCOUNT_ID' - self.timeout = 8 - self.max_retries = 4 - self.client = IbkrClient( - url=self.url, - account_id=self.account_id, - timeout=self.timeout, - max_retries=self.max_retries, - use_session=False, - ) - - self.data = {'Test key': 'Test value'} - - self.response = MagicMock() - self.response.json.return_value = self.data - self.default_path = '/test/api/route' - self.default_url = f'{self.url}/{self.default_path}' - self.result = Result(data=self.data, request={'url': self.default_url}) - self.maxDiff = 9999 - - def test_get_conids(self, requests_mock): - requests_mock.request.return_value = self.response - self.response.json.return_value = ibkr_responses.responses['stocks'] - - queries = [ - StockQuery(symbol='AAPL', contract_conditions={'isUS': False, 'exchange': 'AEQLIT'}, name_match='APPLE'), - StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), - StockQuery(symbol='CDN', contract_conditions={'isUS': False}), - StockQuery(symbol='CFC', contract_conditions={}), - StockQuery(symbol='GOOG', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={'chineseName': 'Alphabet公司'}), - 'HUBS', - StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={}), - StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), - StockQuery(symbol='SAN', name_match='SANTANDER', contract_conditions={'isUS': True}), - StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NYSE'}), - StockQuery(symbol='TEAM', name_match='ATLASSIAN'), - StockQuery(symbol='INVALID_SYMBOL') - ] # fmt: skip - - with self.assertLogs(project_logger(), level='INFO'): - rv = self.client.stock_conid_by_symbol(queries, default_filtering=False) - - for symbol, conid in rv.data.items(): - self.assertIn(symbol, ibkr_responses.responses['filtered_conids']) - self.assertEqual(conid, ibkr_responses.responses['filtered_conids'][symbol]) - - def test_get_conids_exception(self, requests_mock): - requests_mock.request.return_value = self.response - self.response.json.return_value = ibkr_responses.responses['stocks'] - - symbol = 'AAPL' - query = StockQuery(symbol=symbol, contract_conditions={'isUS': False}, name_match='APPLE') - - instruments = filter_stocks(query, Result(data={symbol: ibkr_responses.responses['stocks'][symbol]}), default_filtering=False).data[symbol] - - with self.assertRaises(RuntimeError) as cm_err: - self.client.stock_conid_by_symbol(query, default_filtering=False) - - self.maxDiff = None - self.assertEqual( - f'Filtering stock "{symbol}" returned 2 instruments and 2 contracts using following query: {query}.\nPlease use filters to ensure that only one instrument and one contract per symbol is selected in order to avoid conid ambiguity.\nBe aware that contracts are filtered as {{"isUS": True}} by default. Set default_filtering=False to prevent this default filtering or specify custom filters. See inline documentation for more details.\nInstruments returned:\n{pformat(instruments)}', - str(cm_err.exception), - ) - - def test_get_live_orders_no_filters(self, requests_mock): - self.client.get = MagicMock(return_value=self.result) - self.client.live_orders() - self.client.get.assert_called_with('iserver/account/orders', params=None) - - def test_get_live_orders_with_valid_filters(self, requests_mock): - self.client.get = MagicMock(return_value=self.result) - filters = ['inactive', 'filled'] - self.client.live_orders(filters=filters) - self.client.get.assert_called_with('iserver/account/orders', params={'filters': 'inactive,filled'}) - - def test_get_live_orders_with_single_filter(self, requests_mock): - self.client.get = MagicMock(return_value=self.result) - self.client.live_orders(filters='submitted') - self.client.get.assert_called_with('iserver/account/orders', params={'filters': 'submitted'}) - - def test_get_live_orders_with_incorrect_filter_type(self, requests_mock): - self.client.get = MagicMock(return_value=self.result) - with self.assertRaises(TypeError): - self.client.live_orders(filters=123) # Non-list, non-string filter - self.client.get.assert_not_called() - - def _marketdata_request(self, method, url, *args, **kwargs): - leaf = url.split('/')[-1] - if leaf == 'stocks': - return MagicMock(json=lambda: ibkr_responses.responses['stocks']) # Mock response for get_conids - elif leaf == 'history': - conid = kwargs['params']['conid'] - return MagicMock(json=lambda: self._history_by_conid[conid]) - - def test_marketdata_history_by_symbols(self, requests_mock): - # Mocking the requests module for external interaction - self._history_by_conid = { +from ibind.support.logs import ibind_logs_initialize +from integration.client import ibkr_responses +from test_utils_new import CaptureLogsContext + + +_URL = 'https://localhost:5000' +_TIMEOUT = 8 +_MAX_RETRIES = 4 +_DEFAULT_PATH = '/test/api/route' +_ACCOUNT_ID = 'TEST_ACCOUNT_ID' + + +@pytest.fixture +def client(): + ibind_logs_initialize(log_to_console=True) + return IbkrClient( + url=_URL, + account_id=_ACCOUNT_ID, + timeout=_TIMEOUT, + max_retries=_MAX_RETRIES, + use_session=False, + ) + + +@pytest.fixture +def data(): + return {'Test key': 'Test value'} + + +@pytest.fixture +def response(data): + response = MagicMock() + response.json.return_value = data + return response + + +@pytest.fixture(autouse=True) +def requests_mock(mocker, response): + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = response + return requests_mock + + +@pytest.fixture +def default_url(): + return f'{_URL}/{_DEFAULT_PATH}' + + +@pytest.fixture +def result(data, default_url): + return Result(data=data, request={'url': default_url}) + + +def test_get_conids(client, response): + # Arrange + response.json.return_value = ibkr_responses.responses['stocks'] + + queries = [ + StockQuery(symbol='AAPL', contract_conditions={'isUS': False, 'exchange': 'AEQLIT'}, name_match='APPLE'), + StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), + StockQuery(symbol='CDN', contract_conditions={'isUS': False}), + StockQuery(symbol='CFC', contract_conditions={}), + StockQuery(symbol='GOOG', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={'chineseName': 'Alphabet公司'}), + 'HUBS', + StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={}), + StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), + StockQuery(symbol='SAN', name_match='SANTANDER', contract_conditions={'isUS': True}), + StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NYSE'}), + StockQuery(symbol='TEAM', name_match='ATLASSIAN'), + StockQuery(symbol='INVALID_SYMBOL') + ] + + # Act + rv = client.stock_conid_by_symbol(queries, default_filtering=False) + + # Assert + for symbol, conid in rv.data.items(): + assert symbol in ibkr_responses.responses['filtered_conids'] + assert conid == ibkr_responses.responses['filtered_conids'][symbol] + + +def test_get_conids_exception(client, response): + # Arrange + response.json.return_value = ibkr_responses.responses['stocks'] + + symbol = 'AAPL' + query = StockQuery(symbol=symbol, contract_conditions={'isUS': False}, name_match='APPLE') + + instruments = filter_stocks(query, Result(data={symbol: ibkr_responses.responses['stocks'][symbol]}), default_filtering=False).data[symbol] + + # Act and Assert + with pytest.raises(RuntimeError) as excinfo: + client.stock_conid_by_symbol(query, default_filtering=False) + + assert str(excinfo.value) == f'Filtering stock "{symbol}" returned 2 instruments and 2 contracts using following query: {query}.' \ + f'\nPlease use filters to ensure that only one instrument and one contract per symbol is selected in order to avoid conid ambiguity.' \ + f'\nBe aware that contracts are filtered as {{"isUS": True}} by default. Set default_filtering=False to prevent this default filtering or specify custom filters. See inline documentation for more details.' \ + f'\nInstruments returned:\n{pformat(instruments)}' + + +def test_get_live_orders_no_filters(client, result): + # Arrange + client.get = MagicMock(return_value=result) + + # Act + client.live_orders() + + # Assert + client.get.assert_called_with('iserver/account/orders', params=None) + + +def test_get_live_orders_with_valid_filters(client, result): + # Arrange + client.get = MagicMock(return_value=result) + filters = ['inactive', 'filled'] + + # Act + client.live_orders(filters=filters) + + # Assert + client.get.assert_called_with('iserver/account/orders', params={'filters': 'inactive,filled'}) + + +def test_get_live_orders_with_single_filter(client, result): + # Arrange + client.get = MagicMock(return_value=result) + + # Act + client.live_orders(filters='submitted') + + # Assert + client.get.assert_called_with('iserver/account/orders', params={'filters': 'submitted'}) + + +def test_get_live_orders_with_incorrect_filter_type(client, result): + # Arrange + client.get = MagicMock(return_value=result) + + # Act and Assert + with pytest.raises(TypeError): + client.live_orders(filters=123) # Non-list, non-string filter + client.get.assert_not_called() + + +def _marketdata_request(method, url, *args, **kwargs): + leaf = url.split('/')[-1] + if leaf == 'stocks': + return MagicMock(json=lambda: ibkr_responses.responses['stocks']) + elif leaf == 'history': + conid = kwargs['params']['conid'] + history_by_conid = { ibkr_responses.responses['filtered_conids'][key]: value for key, value in ibkr_responses.responses['history'].items() } - requests_mock.request.side_effect = self._marketdata_request - - queries = [ - StockQuery(symbol='AAPL', contract_conditions={'isUS': False, 'exchange': 'AEQLIT'}, name_match='APPLE'), - StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), - StockQuery(symbol='CDN', contract_conditions={'isUS': False}), - StockQuery(symbol='CFC', contract_conditions={}), - StockQuery(symbol='GOOG', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={'chineseName': 'Alphabet公司'}), - StockQuery(symbol='HUBS'), - StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={}), - StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), - StockQuery(symbol='SAN', name_match='SANTANDER', contract_conditions={'isUS': True}), - StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NYSE'}), - StockQuery(symbol='TEAM', name_match='ATLASSIAN'), - ] # fmt: skip - - expected_results = {} - - for query in queries: - data = ibkr_responses.responses['history'][query.symbol]['data'][0] - output = { - 'conid': ibkr_responses.responses['filtered_conids'][query.symbol], - 'symbol': query.symbol, - 'open': data['o'], - 'high': data['h'], - 'low': data['l'], - 'close': data['c'], - 'volume': data['v'], - 'date': datetime.datetime.fromtimestamp(data['t'] / 1000, tz=datetime.timezone.utc), - } - expected_results[query.symbol] = output - - expected_errors = ['Market data for CDN is not live: Delayed', 'Market data for CFC is not live: Delayed'] - - with SafeAssertLogs(self, 'ibind', level='INFO', logger_level='DEBUG', no_logs=False) as cm, \ - RaiseLogsContext(self, 'ibind', level='ERROR', expected_errors=expected_errors): # fmt: skip - results = self.client.marketdata_history_by_symbols(queries) - - verify_log(self, cm, expected_errors) - - # Assertions to verify the correctness of each field in the result - for symbol, expected in expected_results.items(): - result = results[symbol][-1] - self.assertIn(symbol, results) - self.assertAlmostEqual(result['open'], expected['open']) - self.assertAlmostEqual(result['high'], expected['high']) - self.assertAlmostEqual(result['low'], expected['low']) - self.assertAlmostEqual(result['close'], expected['close']) - self.assertAlmostEqual(result['volume'], expected['volume']) - self.assertEqual(result['date'], expected['date']) - - def test_check_health_authenticated_and_connected(self, requests_mock): - response_data = {'iserver': {'authStatus': {'authenticated': True, 'competing': False, 'connected': True}}} - requests_mock.request.return_value = MagicMock(json=lambda: response_data) - self.client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': self.default_url})) - - health_status = self.client.check_health() - self.assertTrue(health_status) - self.client.tickle.assert_called_once() - - def test_check_health_not_authenticated(self, requests_mock): - response_data = {'iserver': {'authStatus': {'authenticated': False, 'competing': False, 'connected': True}}} - requests_mock.request.return_value = MagicMock(json=lambda: response_data) - self.client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': self.default_url})) - - health_status = self.client.check_health() - self.assertFalse(health_status) - - def test_check_health_competing_connection(self, requests_mock): - response_data = {'iserver': {'authStatus': {'authenticated': True, 'competing': True, 'connected': True}}} - requests_mock.request.return_value = MagicMock(json=lambda: response_data) - self.client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': self.default_url})) - - health_status = self.client.check_health() - self.assertFalse(health_status) - - def test_check_health_connection_error(self, requests_mock): - requests_mock.request.side_effect = ConnectTimeout - self.client.tickle = MagicMock(side_effect=ConnectTimeout) - - with self.assertLogs(level='ERROR') as cm: - health_status = self.client.check_health() - self.assertFalse(health_status) - self.assertIn('ConnectTimeout raised when communicating with the Gateway', cm.output[0]) - - def test_check_health_external_broker_error_unauthenticated(self, requests_mock): - requests_mock.request.side_effect = ExternalBrokerError(status_code=401) - self.client.tickle = MagicMock(side_effect=ExternalBrokerError(status_code=401)) - - with self.assertLogs(level='INFO') as cm: - health_status = self.client.check_health() - self.assertFalse(health_status) - self.assertIn('Gateway session is not authenticated.', cm.output[0]) - - def test_check_health_invalid_data(self, requests_mock): - response_data = {} # Invalid data format - requests_mock.request.return_value = MagicMock(json=lambda: response_data) - self.client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': self.default_url})) - - with self.assertRaises(AttributeError) as cm: - self.client.check_health() - self.assertIn('Health check requests returns invalid data', str(cm.exception)) - - def test_marketdata_unsubscribe_success(self, requests_mock): - conids = [12345, 67890] - responses = {12345: MagicMock(status_code=200), 67890: MagicMock(status_code=200)} - requests_mock.request.side_effect = lambda method, url, **kwargs: responses[kwargs['json']['conid']] - self.client.get = MagicMock( - side_effect=lambda url, *args, **kwargs: Result(data={'success': True}, request={'url': url}), __name__='client_get_mock' - ) - - results = self.client.marketdata_unsubscribe(conids) - - for conid, result in results.items(): - self.assertIn(conid, conids) - self.assertIsInstance(result, Result) - self.assertTrue(result.data['success']) - - def test_marketdata_unsubscribe_with_error(self, requests_mock): - conids = [12345, 67890] - responses = { - 12345: MagicMock(status_code=404), # Simulate not found error for one conid - 67890: MagicMock(status_code=200), - } - requests_mock.request.side_effect = lambda method, url, **kwargs: responses[kwargs['json']['conid']] - self.client.get = MagicMock( - side_effect=lambda url, *args, **kwargs: Result(data={'success': True}, request={'url': url}) - if '67890' in url - else ExternalBrokerError(status_code=404), - __name__='client_get_mock', - ) - - results = self.client.marketdata_unsubscribe(conids) - - self.assertIn(12345, results) - self.assertIn(67890, results) - self.assertTrue(results[67890].data['success']) - - def test_marketdata_unsubscribe_raises_exception_on_failure(self, requests_mock): - conids = [12345] - responses = { - 12345: MagicMock(status_code=500), # Simulate server error + return MagicMock(json=lambda: history_by_conid[conid]) + + +def test_marketdata_history_by_symbols(client, requests_mock): + # Arrange + requests_mock.request.side_effect = _marketdata_request + + queries = [ + StockQuery(symbol='AAPL', contract_conditions={'isUS': False, 'exchange': 'AEQLIT'}, name_match='APPLE'), + StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), + StockQuery(symbol='CDN', contract_conditions={'isUS': False}), + StockQuery(symbol='CFC', contract_conditions={}), + StockQuery(symbol='GOOG', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={'chineseName': 'Alphabet公司'}), + StockQuery(symbol='HUBS'), + StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={}), + StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), + StockQuery(symbol='SAN', name_match='SANTANDER', contract_conditions={'isUS': True}), + StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NYSE'}), + StockQuery(symbol='TEAM', name_match='ATLASSIAN'), + ] + + expected_results = {} + for query in queries: + data = ibkr_responses.responses['history'][query.symbol]['data'][0] + output = { + 'conid': ibkr_responses.responses['filtered_conids'][query.symbol], + 'symbol': query.symbol, + 'open': data['o'], + 'high': data['h'], + 'low': data['l'], + 'close': data['c'], + 'volume': data['v'], + 'date': datetime.datetime.fromtimestamp(data['t'] / 1000, tz=datetime.timezone.utc), } - requests_mock.request.side_effect = lambda method, url, **kwargs: responses[int(url.split('/')[-2])] - self.client.post = MagicMock(side_effect=lambda url, *args, **kwargs: ExternalBrokerError(status_code=500), __name__='client_get_mock') + expected_results[query.symbol] = output + + expected_errors = ['Market data for CDN is not live: Delayed', 'Market data for CFC is not live: Delayed'] + + # Act + with CaptureLogsContext('ibind', level='INFO', logger_level='DEBUG', expected_errors=expected_errors, partial_match=True) as cm: + results = client.marketdata_history_by_symbols(queries) + + # Assert + for msg in expected_errors: + assert msg in cm.output + + for symbol, expected in expected_results.items(): + result = results[symbol][-1] + assert symbol in results + assert result['open'] == pytest.approx(expected['open']) + assert result['high'] == pytest.approx(expected['high']) + assert result['low'] == pytest.approx(expected['low']) + assert result['close'] == pytest.approx(expected['close']) + assert result['volume'] == pytest.approx(expected['volume']) + assert result['date'] == expected['date'] + + +def test_check_health_authenticated_and_connected(client, default_url, requests_mock): + # Arrange + response_data = {'iserver': {'authStatus': {'authenticated': True, 'competing': False, 'connected': True}}} + requests_mock.request.return_value = MagicMock(json=lambda: response_data) + client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) + + # Act + health_status = client.check_health() + + # Assert + assert health_status is True + client.tickle.assert_called_once() + + +def test_check_health_not_authenticated(client, default_url, requests_mock): + # Arrange + response_data = {'iserver': {'authStatus': {'authenticated': False, 'competing': False, 'connected': True}}} + requests_mock.request.return_value = MagicMock(json=lambda: response_data) + client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) + + # Act + health_status = client.check_health() + + # Assert + assert health_status is False + + +def test_check_health_competing_connection(client, default_url, requests_mock): + # Arrange + response_data = {'iserver': {'authStatus': {'authenticated': True, 'competing': True, 'connected': True}}} + requests_mock.request.return_value = MagicMock(json=lambda: response_data) + client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) + + # Act + health_status = client.check_health() + + # Assert + assert health_status is False + + +def test_check_health_connection_error(client, requests_mock): + # Arrange + requests_mock.request.side_effect = ConnectTimeout + client.tickle = MagicMock(side_effect=ConnectTimeout) + + # Act + with CaptureLogsContext( + 'ibind.session_mixin', + level='ERROR', + expected_errors=['ConnectTimeout raised when communicating with the Gateway'], + partial_match=True, + ) as cm: + health_status = client.check_health() + + # Assert + assert health_status is False + assert 'ConnectTimeout raised when communicating with the Gateway' in cm.output[0] + + +def test_check_health_external_broker_error_unauthenticated(client, requests_mock): + # Arrange + requests_mock.request.side_effect = ExternalBrokerError(status_code=401) + client.tickle = MagicMock(side_effect=ExternalBrokerError(status_code=401)) + + # Act + with CaptureLogsContext('ibind.session_mixin', level='INFO', expected_errors=['Gateway session is not authenticated.']) as cm: + health_status = client.check_health() + + # Assert + assert health_status is False + assert 'Gateway session is not authenticated.' in cm.output[0] + + +def test_check_health_invalid_data(client, default_url, requests_mock): + # Arrange + response_data = {} # Invalid data format + requests_mock.request.return_value = MagicMock(json=lambda: response_data) + client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) + + # Act and Assert + with pytest.raises(AttributeError) as excinfo: + client.check_health() + assert 'Health check requests returns invalid data' in str(excinfo.value) + + +def test_marketdata_unsubscribe_success(client, mocker): + # Arrange + conids = [12345, 67890] + + def post_side_effect(url, *args, **kwargs): + conid = kwargs['params']['conid'] + if conid in conids: + return Result(data={'success': True}, request={'url': url}) + raise ExternalBrokerError(status_code=404) + + client.post = MagicMock(side_effect=post_side_effect, __name__='client_post_mock') + + # Act + results = client.marketdata_unsubscribe(conids) + + # Assert + for conid, result in results.items(): + assert int(conid) in conids + assert isinstance(result, Result) + assert result.data['success'] is True + + +def test_marketdata_unsubscribe_with_error(client, mocker): + # Arrange + conids = [12345, 67890] + + def post_side_effect(url, *args, **kwargs): + conid = kwargs['params']['conid'] + if conid == 12345: + raise ExternalBrokerError(status_code=404) + return Result(data={'success': True}, request={'url': url}) + + client.post = MagicMock(side_effect=post_side_effect, __name__='client_post_mock') + + # Act + results = client.marketdata_unsubscribe(conids) + + # Assert + assert 12345 in results + assert 67890 in results + assert results[67890].data['success'] is True + assert isinstance(results[12345], ExternalBrokerError) + + +def test_marketdata_unsubscribe_raises_exception_on_failure(client, mocker): + # Arrange + conids = [12345] + client.post = MagicMock(side_effect=ExternalBrokerError(status_code=500), __name__='client_post_mock') + + # Act + with pytest.raises(ExternalBrokerError) as excinfo: + client.marketdata_unsubscribe(conids) - with self.assertRaises(ExternalBrokerError): - self.client.marketdata_unsubscribe(conids) \ No newline at end of file + # Assert + assert excinfo.value.status_code == 500 \ No newline at end of file diff --git a/test/integration/client/test_ibkr_client_i_new.py b/test/integration/client/test_ibkr_client_i_new.py deleted file mode 100644 index b29c1a0f..00000000 --- a/test/integration/client/test_ibkr_client_i_new.py +++ /dev/null @@ -1,361 +0,0 @@ -import datetime -from pprint import pformat -import pytest -from unittest.mock import MagicMock - -from requests import ConnectTimeout - -from ibind.base.rest_client import Result -from ibind.client.ibkr_client import IbkrClient -from ibind.client.ibkr_utils import StockQuery, filter_stocks -from ibind.support.errors import ExternalBrokerError -from ibind.support.logs import ibind_logs_initialize -from integration.client import ibkr_responses -from test_utils_new import CaptureLogsContext - - -_URL = 'https://localhost:5000' -_TIMEOUT = 8 -_MAX_RETRIES = 4 -_DEFAULT_PATH = '/test/api/route' -_ACCOUNT_ID = 'TEST_ACCOUNT_ID' - - -@pytest.fixture -def client(): - ibind_logs_initialize(log_to_console=True) - return IbkrClient( - url=_URL, - account_id=_ACCOUNT_ID, - timeout=_TIMEOUT, - max_retries=_MAX_RETRIES, - use_session=False, - ) - - -@pytest.fixture -def data(): - return {'Test key': 'Test value'} - - -@pytest.fixture -def response(data): - response = MagicMock() - response.json.return_value = data - return response - - -@pytest.fixture(autouse=True) -def requests_mock(mocker, response): - requests_mock = mocker.patch('ibind.base.rest_client.requests') - requests_mock.request.return_value = response - return requests_mock - - -@pytest.fixture -def default_url(): - return f'{_URL}/{_DEFAULT_PATH}' - - -@pytest.fixture -def result(data, default_url): - return Result(data=data, request={'url': default_url}) - - -def test_get_conids(client, response): - # Arrange - response.json.return_value = ibkr_responses.responses['stocks'] - - queries = [ - StockQuery(symbol='AAPL', contract_conditions={'isUS': False, 'exchange': 'AEQLIT'}, name_match='APPLE'), - StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), - StockQuery(symbol='CDN', contract_conditions={'isUS': False}), - StockQuery(symbol='CFC', contract_conditions={}), - StockQuery(symbol='GOOG', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={'chineseName': 'Alphabet公司'}), - 'HUBS', - StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={}), - StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), - StockQuery(symbol='SAN', name_match='SANTANDER', contract_conditions={'isUS': True}), - StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NYSE'}), - StockQuery(symbol='TEAM', name_match='ATLASSIAN'), - StockQuery(symbol='INVALID_SYMBOL') - ] - - # Act - rv = client.stock_conid_by_symbol(queries, default_filtering=False) - - # Assert - for symbol, conid in rv.data.items(): - assert symbol in ibkr_responses.responses['filtered_conids'] - assert conid == ibkr_responses.responses['filtered_conids'][symbol] - - -def test_get_conids_exception(client, response): - # Arrange - response.json.return_value = ibkr_responses.responses['stocks'] - - symbol = 'AAPL' - query = StockQuery(symbol=symbol, contract_conditions={'isUS': False}, name_match='APPLE') - - instruments = filter_stocks(query, Result(data={symbol: ibkr_responses.responses['stocks'][symbol]}), default_filtering=False).data[symbol] - - # Act and Assert - with pytest.raises(RuntimeError) as excinfo: - client.stock_conid_by_symbol(query, default_filtering=False) - - assert str(excinfo.value) == f'Filtering stock "{symbol}" returned 2 instruments and 2 contracts using following query: {query}.' \ - f'\nPlease use filters to ensure that only one instrument and one contract per symbol is selected in order to avoid conid ambiguity.' \ - f'\nBe aware that contracts are filtered as {{"isUS": True}} by default. Set default_filtering=False to prevent this default filtering or specify custom filters. See inline documentation for more details.' \ - f'\nInstruments returned:\n{pformat(instruments)}' - - -def test_get_live_orders_no_filters(client, result): - # Arrange - client.get = MagicMock(return_value=result) - - # Act - client.live_orders() - - # Assert - client.get.assert_called_with('iserver/account/orders', params=None) - - -def test_get_live_orders_with_valid_filters(client, result): - # Arrange - client.get = MagicMock(return_value=result) - filters = ['inactive', 'filled'] - - # Act - client.live_orders(filters=filters) - - # Assert - client.get.assert_called_with('iserver/account/orders', params={'filters': 'inactive,filled'}) - - -def test_get_live_orders_with_single_filter(client, result): - # Arrange - client.get = MagicMock(return_value=result) - - # Act - client.live_orders(filters='submitted') - - # Assert - client.get.assert_called_with('iserver/account/orders', params={'filters': 'submitted'}) - - -def test_get_live_orders_with_incorrect_filter_type(client, result): - # Arrange - client.get = MagicMock(return_value=result) - - # Act and Assert - with pytest.raises(TypeError): - client.live_orders(filters=123) # Non-list, non-string filter - client.get.assert_not_called() - - -def _marketdata_request(method, url, *args, **kwargs): - leaf = url.split('/')[-1] - if leaf == 'stocks': - return MagicMock(json=lambda: ibkr_responses.responses['stocks']) - elif leaf == 'history': - conid = kwargs['params']['conid'] - history_by_conid = { - ibkr_responses.responses['filtered_conids'][key]: value for key, value in ibkr_responses.responses['history'].items() - } - return MagicMock(json=lambda: history_by_conid[conid]) - - -def test_marketdata_history_by_symbols(client, requests_mock): - # Arrange - requests_mock.request.side_effect = _marketdata_request - - queries = [ - StockQuery(symbol='AAPL', contract_conditions={'isUS': False, 'exchange': 'AEQLIT'}, name_match='APPLE'), - StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), - StockQuery(symbol='CDN', contract_conditions={'isUS': False}), - StockQuery(symbol='CFC', contract_conditions={}), - StockQuery(symbol='GOOG', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={'chineseName': 'Alphabet公司'}), - StockQuery(symbol='HUBS'), - StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False, 'exchange': 'MEXI'}, instrument_conditions={}), - StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), - StockQuery(symbol='SAN', name_match='SANTANDER', contract_conditions={'isUS': True}), - StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NYSE'}), - StockQuery(symbol='TEAM', name_match='ATLASSIAN'), - ] - - expected_results = {} - for query in queries: - data = ibkr_responses.responses['history'][query.symbol]['data'][0] - output = { - 'conid': ibkr_responses.responses['filtered_conids'][query.symbol], - 'symbol': query.symbol, - 'open': data['o'], - 'high': data['h'], - 'low': data['l'], - 'close': data['c'], - 'volume': data['v'], - 'date': datetime.datetime.fromtimestamp(data['t'] / 1000, tz=datetime.timezone.utc), - } - expected_results[query.symbol] = output - - expected_errors = ['Market data for CDN is not live: Delayed', 'Market data for CFC is not live: Delayed'] - - # Act - with CaptureLogsContext('ibind', level='INFO', logger_level='DEBUG', expected_errors=expected_errors, partial_match=True) as cm: - results = client.marketdata_history_by_symbols(queries) - - # Assert - for msg in expected_errors: - assert msg in cm.output - - for symbol, expected in expected_results.items(): - result = results[symbol][-1] - assert symbol in results - assert result['open'] == pytest.approx(expected['open']) - assert result['high'] == pytest.approx(expected['high']) - assert result['low'] == pytest.approx(expected['low']) - assert result['close'] == pytest.approx(expected['close']) - assert result['volume'] == pytest.approx(expected['volume']) - assert result['date'] == expected['date'] - - -def test_check_health_authenticated_and_connected(client, default_url, requests_mock): - # Arrange - response_data = {'iserver': {'authStatus': {'authenticated': True, 'competing': False, 'connected': True}}} - requests_mock.request.return_value = MagicMock(json=lambda: response_data) - client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) - - # Act - health_status = client.check_health() - - # Assert - assert health_status is True - client.tickle.assert_called_once() - - -def test_check_health_not_authenticated(client, default_url, requests_mock): - # Arrange - response_data = {'iserver': {'authStatus': {'authenticated': False, 'competing': False, 'connected': True}}} - requests_mock.request.return_value = MagicMock(json=lambda: response_data) - client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) - - # Act - health_status = client.check_health() - - # Assert - assert health_status is False - - -def test_check_health_competing_connection(client, default_url, requests_mock): - # Arrange - response_data = {'iserver': {'authStatus': {'authenticated': True, 'competing': True, 'connected': True}}} - requests_mock.request.return_value = MagicMock(json=lambda: response_data) - client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) - - # Act - health_status = client.check_health() - - # Assert - assert health_status is False - - -def test_check_health_connection_error(client, requests_mock): - # Arrange - requests_mock.request.side_effect = ConnectTimeout - client.tickle = MagicMock(side_effect=ConnectTimeout) - - # Act - with CaptureLogsContext( - 'ibind.session_mixin', - level='ERROR', - expected_errors=['ConnectTimeout raised when communicating with the Gateway'], - partial_match=True, - ) as cm: - health_status = client.check_health() - - # Assert - assert health_status is False - assert 'ConnectTimeout raised when communicating with the Gateway' in cm.output[0] - - -def test_check_health_external_broker_error_unauthenticated(client, requests_mock): - # Arrange - requests_mock.request.side_effect = ExternalBrokerError(status_code=401) - client.tickle = MagicMock(side_effect=ExternalBrokerError(status_code=401)) - - # Act - with CaptureLogsContext('ibind.session_mixin', level='INFO', expected_errors=['Gateway session is not authenticated.']) as cm: - health_status = client.check_health() - - # Assert - assert health_status is False - assert 'Gateway session is not authenticated.' in cm.output[0] - - -def test_check_health_invalid_data(client, default_url, requests_mock): - # Arrange - response_data = {} # Invalid data format - requests_mock.request.return_value = MagicMock(json=lambda: response_data) - client.tickle = MagicMock(return_value=Result(data=response_data, request={'url': default_url})) - - # Act and Assert - with pytest.raises(AttributeError) as excinfo: - client.check_health() - assert 'Health check requests returns invalid data' in str(excinfo.value) - - -def test_marketdata_unsubscribe_success(client, mocker): - # Arrange - conids = [12345, 67890] - - def post_side_effect(url, *args, **kwargs): - conid = kwargs['params']['conid'] - if conid in conids: - return Result(data={'success': True}, request={'url': url}) - raise ExternalBrokerError(status_code=404) - - client.post = MagicMock(side_effect=post_side_effect, __name__='client_post_mock') - - # Act - results = client.marketdata_unsubscribe(conids) - - # Assert - for conid, result in results.items(): - assert int(conid) in conids - assert isinstance(result, Result) - assert result.data['success'] is True - - -def test_marketdata_unsubscribe_with_error(client, mocker): - # Arrange - conids = [12345, 67890] - - def post_side_effect(url, *args, **kwargs): - conid = kwargs['params']['conid'] - if conid == 12345: - raise ExternalBrokerError(status_code=404) - return Result(data={'success': True}, request={'url': url}) - - client.post = MagicMock(side_effect=post_side_effect, __name__='client_post_mock') - - # Act - results = client.marketdata_unsubscribe(conids) - - # Assert - assert 12345 in results - assert 67890 in results - assert results[67890].data['success'] is True - assert isinstance(results[12345], ExternalBrokerError) - - -def test_marketdata_unsubscribe_raises_exception_on_failure(client, mocker): - # Arrange - conids = [12345] - client.post = MagicMock(side_effect=ExternalBrokerError(status_code=500), __name__='client_post_mock') - - # Act - with pytest.raises(ExternalBrokerError) as excinfo: - client.marketdata_unsubscribe(conids) - - # Assert - assert excinfo.value.status_code == 500 \ No newline at end of file diff --git a/test/integration/client/test_ibkr_utils_i.py b/test/integration/client/test_ibkr_utils_i.py index 232e800b..2eb92395 100644 --- a/test/integration/client/test_ibkr_utils_i.py +++ b/test/integration/client/test_ibkr_utils_i.py @@ -1,337 +1,394 @@ from pprint import pformat -from unittest import TestCase -from unittest.mock import MagicMock, patch, call +from unittest.mock import MagicMock, call + +import pytest from ibind.base.rest_client import Result -from ibind.client.ibkr_utils import StockQuery, filter_stocks, find_answer, QuestionType, handle_questions, question_type_to_message_id, OrderRequest, parse_order_request -from ibind.support.logs import project_logger +from ibind.client.ibkr_utils import ( + StockQuery, + filter_stocks, + find_answer, + QuestionType, + handle_questions, + question_type_to_message_id, + OrderRequest, + parse_order_request, +) from test.integration.client import ibkr_responses -from test_utils import verify_log - - -class TestIbkrUtilsI(TestCase): - def setUp(self): - self.instruments = ibkr_responses.responses['stocks'] - self.result = Result(data=self.instruments) - self.maxDiff = None - - def test_filter_stocks(self): - queries = [ - StockQuery(symbol='AAPL', contract_conditions={'isUS': False}, name_match='APPLE'), - StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), - StockQuery(symbol='CDN', contract_conditions={'isUS': True}), - StockQuery(symbol='CFC', contract_conditions={}), - StockQuery(symbol='GOOG', contract_conditions={'isUS': False}, instrument_conditions={'chineseName': 'Alphabet公司'}), - 'HUBS', - StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False}, instrument_conditions={}), - StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), - StockQuery(symbol='SAN', name_match='SANTANDER'), - StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NASDAQ'}), - StockQuery(symbol='TEAM', name_match='ATLASSIAN'), - StockQuery(symbol='INVALID_SYMBOL') - ] # fmt: skip - with self.assertLogs(project_logger(), level='INFO') as cm: - rv = filter_stocks(queries, Result(data=self.instruments), default_filtering=False) - - verify_log( - self, cm, [f'Error getting stocks. Could not find valid instruments INVALID_SYMBOL in result: {self.result}. Skipping query={queries[-1]}.'] - ) # fmt: skip - - # pprint(rv) - - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': '苹果公司', - 'contracts': [ - {'conid': 38708077, 'exchange': 'MEXI', 'isUS': False}, - {'conid': 273982664, 'exchange': 'EBS', 'isUS': False}, - ], - 'name': 'APPLE INC', - }, - { - 'assetClass': 'STK', - 'chineseName': '苹果公司', - 'contracts': [{'conid': 532640894, 'exchange': 'AEQLIT', 'isUS': False}], - 'name': 'APPLE INC-CDR', - }, +from test.test_utils_new import CaptureLogsContext + + +# -------------------------------------------------------------------------------------- +# Stock filtering +# -------------------------------------------------------------------------------------- + + +@pytest.fixture +def instruments(): + return ibkr_responses.responses['stocks'] + + +@pytest.fixture +def instruments_result(instruments): + return Result(data=instruments) + + +def test_filter_stocks(instruments, instruments_result): + """Filters instruments for multiple stock queries and logs missing symbols.""" + ## Arrange + queries = [ + StockQuery(symbol='AAPL', contract_conditions={'isUS': False}, name_match='APPLE'), + StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), + StockQuery(symbol='CDN', contract_conditions={'isUS': True}), + StockQuery(symbol='CFC', contract_conditions={}), + StockQuery( + symbol='GOOG', + contract_conditions={'isUS': False}, + instrument_conditions={'chineseName': 'Alphabet公司'}, + ), + 'HUBS', + StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False}, instrument_conditions={}), + StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), + StockQuery(symbol='SAN', name_match='SANTANDER'), + StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NASDAQ'}), + StockQuery(symbol='TEAM', name_match='ATLASSIAN'), + StockQuery(symbol='INVALID_SYMBOL'), + ] # fmt: skip + + ## Act + with CaptureLogsContext('ibind', level='INFO', error_level='CRITICAL', attach_stack=False) as cm: + rv = filter_stocks(queries, instruments_result, default_filtering=False) + + ## Assert + expected_error = ( + f'Error getting stocks. Could not find valid instruments INVALID_SYMBOL in result: {instruments_result}. ' + f'Skipping query={queries[-1]}.' + ) + assert expected_error in cm.output + + assert [ + { + 'assetClass': 'STK', + 'chineseName': '苹果公司', + 'contracts': [ + {'conid': 38708077, 'exchange': 'MEXI', 'isUS': False}, + {'conid': 273982664, 'exchange': 'EBS', 'isUS': False}, ], - rv.data['AAPL'], - ) - - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': '西班牙对外银行', - 'contracts': [{'conid': 4815, 'exchange': 'NYSE', 'isUS': True}], - 'name': 'BANCO BILBAO VIZCAYA-SP ADR', - }, + 'name': 'APPLE INC', + }, + { + 'assetClass': 'STK', + 'chineseName': '苹果公司', + 'contracts': [{'conid': 532640894, 'exchange': 'AEQLIT', 'isUS': False}], + 'name': 'APPLE INC-CDR', + }, + ] == rv.data['AAPL'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': '西班牙对外银行', + 'contracts': [{'conid': 4815, 'exchange': 'NYSE', 'isUS': True}], + 'name': 'BANCO BILBAO VIZCAYA-SP ADR', + }, + ] == rv.data['BBVA'] + + assert [] == rv.data['CDN'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': None, + 'contracts': [{'conid': 42001300, 'exchange': 'IBIS', 'isUS': False}], + 'name': 'UET UNITED ELECTRONIC TECHNO', + } + ] == rv.data['CFC'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': 'Alphabet公司', + 'contracts': [ + {'conid': 210810667, 'exchange': 'MEXI', 'isUS': False}, ], - rv.data['BBVA'], - ) + 'name': 'ALPHABET INC-CL C', + }, + { + 'assetClass': 'STK', + 'chineseName': 'Alphabet公司', + 'contracts': [{'conid': 532638805, 'exchange': 'AEQLIT', 'isUS': False}], + 'name': 'ALPHABET INC - CDR', + }, + ] == rv.data['GOOG'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': 'HubSpot公司', + 'contracts': [{'conid': 169544810, 'exchange': 'NYSE', 'isUS': True}], + 'name': 'HUBSPOT INC', + } + ] == rv.data['HUBS'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': 'Meta平台股份有限公司', + 'contracts': [ + {'conid': 114922621, 'exchange': 'MEXI', 'isUS': False}, + ], + 'name': 'META PLATFORMS INC-CLASS A', + }, + { + 'assetClass': 'STK', + 'chineseName': 'Meta平台股份有限公司', + 'contracts': [{'conid': 530091499, 'exchange': 'AEQLIT', 'isUS': False}], + 'name': 'META PLATFORMS INC-CDR', + }, + ] == rv.data['META'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': '微软公司', + 'contracts': [ + {'conid': 272093, 'exchange': 'NASDAQ', 'isUS': True}, + ], + 'name': 'MICROSOFT CORP', + }, + ] == rv.data['MSFT'] + + assert [ + { + 'assetClass': 'STK', + 'chineseName': '桑坦德', + 'contracts': [ + {'conid': 38708867, 'exchange': 'MEXI', 'isUS': False}, + {'conid': 385055564, 'exchange': 'WSE', 'isUS': False}, + ], + 'name': 'BANCO SANTANDER SA', + }, + { + 'assetClass': 'STK', + 'chineseName': '桑坦德', + 'contracts': [{'conid': 12442, 'exchange': 'NYSE', 'isUS': True}], + 'name': 'BANCO SANTANDER SA-SPON ADR', + }, + { + 'assetClass': 'STK', + 'chineseName': '桑坦德英国公共有限公司', + 'contracts': [{'conid': 80993135, 'exchange': 'LSE', 'isUS': False}], + 'name': 'SANTANDER UK PLC', + }, + ] == rv.data['SAN'] - self.assertEqual([], rv.data['CDN']) + assert [] == rv.data['SCHW'] - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': None, - 'contracts': [{'conid': 42001300, 'exchange': 'IBIS', 'isUS': False}], - 'name': 'UET UNITED ELECTRONIC TECHNO', - } - ], - rv.data['CFC'], - ) + assert [ + { + 'assetClass': 'STK', + 'chineseName': None, + 'contracts': [{'conid': 589316251, 'exchange': 'NASDAQ', 'isUS': True}], + 'name': 'ATLASSIAN CORP-CL A', + }, + ] == rv.data['TEAM'] - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': 'Alphabet公司', - 'contracts': [ - {'conid': 210810667, 'exchange': 'MEXI', 'isUS': False}, - ], - 'name': 'ALPHABET INC-CL C', - }, - { - 'assetClass': 'STK', - 'chineseName': 'Alphabet公司', - 'contracts': [{'conid': 532638805, 'exchange': 'AEQLIT', 'isUS': False}], - 'name': 'ALPHABET INC - CDR', - }, - ], - rv.data['GOOG'], - ) - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': 'HubSpot公司', - 'contracts': [{'conid': 169544810, 'exchange': 'NYSE', 'isUS': True}], - 'name': 'HUBSPOT INC', - } - ], - rv.data['HUBS'], - ) +def test_question_type_to_message_id_successful(): + """Maps a QuestionType to its expected IBKR message id.""" + ## Arrange + question_type = QuestionType.PRICE_PERCENTAGE_CONSTRAINT - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': 'Meta平台股份有限公司', - 'contracts': [ - {'conid': 114922621, 'exchange': 'MEXI', 'isUS': False}, - ], - 'name': 'META PLATFORMS INC-CLASS A', - }, - { - 'assetClass': 'STK', - 'chineseName': 'Meta平台股份有限公司', - 'contracts': [{'conid': 530091499, 'exchange': 'AEQLIT', 'isUS': False}], - 'name': 'META PLATFORMS INC-CDR', - }, - ], - rv.data['META'], - ) + ## Act + message_id = question_type_to_message_id(question_type) - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': '微软公司', - 'contracts': [ - {'conid': 272093, 'exchange': 'NASDAQ', 'isUS': True}, - ], - 'name': 'MICROSOFT CORP', - }, - ], - rv.data['MSFT'], - ) + ## Assert + assert message_id == 'o163' - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': '桑坦德', - 'contracts': [ - {'conid': 38708867, 'exchange': 'MEXI', 'isUS': False}, - {'conid': 385055564, 'exchange': 'WSE', 'isUS': False}, - ], - 'name': 'BANCO SANTANDER SA', - }, - { - 'assetClass': 'STK', - 'chineseName': '桑坦德', - 'contracts': [{'conid': 12442, 'exchange': 'NYSE', 'isUS': True}], - 'name': 'BANCO SANTANDER SA-SPON ADR', - }, - { - 'assetClass': 'STK', - 'chineseName': '桑坦德英国公共有限公司', - 'contracts': [{'conid': 80993135, 'exchange': 'LSE', 'isUS': False}], - 'name': 'SANTANDER UK PLC', - }, - ], - rv.data['SAN'], - ) - self.assertEqual([], rv.data['SCHW']) +# -------------------------------------------------------------------------------------- +# Finding answers +# -------------------------------------------------------------------------------------- - self.assertEqual( - [ - { - 'assetClass': 'STK', - 'chineseName': None, - 'contracts': [{'conid': 589316251, 'exchange': 'NASDAQ', 'isUS': True}], - 'name': 'ATLASSIAN CORP-CL A', - }, - ], - rv.data['TEAM'], - ) - def test_question_type_to_message_id_successful(self): - question_type = QuestionType.PRICE_PERCENTAGE_CONSTRAINT - message_id = question_type_to_message_id(question_type) - self.assertEqual(message_id, 'o163') +@pytest.fixture +def answers(): + return {QuestionType.PRICE_PERCENTAGE_CONSTRAINT: True} -class TestFindAnswer(TestCase): - def setUp(self): - # Setup Answers dictionary here - self.answers = {QuestionType.PRICE_PERCENTAGE_CONSTRAINT: True} +def test_valid_question(answers): + """Returns True when a known question type is found in the question string.""" + ## Arrange + question = f'Some {QuestionType.PRICE_PERCENTAGE_CONSTRAINT} specific question' - def test_valid_question(self): - question = f'Some {QuestionType.PRICE_PERCENTAGE_CONSTRAINT} specific question' - answer = find_answer(question, self.answers) - self.assertTrue(answer) + ## Act + answer = find_answer(question, answers) - def test_invalid_question(self): - question = 'Nonexistent question type' - with self.assertRaises(ValueError): - find_answer(question, self.answers) + ## Assert + assert answer is True -class TestHandleQuestionsI(TestCase): - def setUp(self): - self.original_result = Result( - data=[{'id': '12345', 'message': ['price exceeds the Percentage constraint of 3%.']}], request={'url': 'test_url'} - ) - self.answers = {QuestionType.PRICE_PERCENTAGE_CONSTRAINT: True} - self.reply_callback = MagicMock() - - @patch('ibind.client.ibkr_utils.QuestionType') - def test_successful_handling(self, question_type_mock): - # Mocking the QuestionType enum - question_type_mock.PRICE_PERCENTAGE_CONSTRAINT.__str__.return_value = 'price exceeds the Percentage constraint of 3%.' - question_type_mock.ADDITIONAL_QUESTION_TYPE.__str__.return_value = 'This is an additional question.' - - self.answers = {question_type_mock.PRICE_PERCENTAGE_CONSTRAINT: True, question_type_mock.ADDITIONAL_QUESTION_TYPE: True} - - # Mock reply_callback to simulate the sequence of question-answer interactions - replies = [ - Result(data=[{'id': '12346', 'message': ['This is an additional question.']}], request={'url': 'another_question_url'}), - Result(data=[{'id': '12347'}], request={'url': 'final_url'}), # No more questions - ] - self.reply_callback.side_effect = replies - - result = handle_questions(self.original_result, self.answers, self.reply_callback) - self.assertEqual(result.request['url'], self.original_result.request['url']) - self.assertEqual(len(self.reply_callback.call_args_list), 2) - # Expected calls to self.reply_callback - expected_calls = [ - call( - self.original_result.data[0]['id'], self.answers[question_type_mock.PRICE_PERCENTAGE_CONSTRAINT] - ), # First call with question ID '12346' and reply True - call( - replies[0].data[0]['id'], self.answers[question_type_mock.ADDITIONAL_QUESTION_TYPE] - ), # Second call with question ID '12347' and reply True - ] - - # Check if the calls to self.reply_callback are as expected - self.assertEqual(expected_calls, self.reply_callback.call_args_list) - - def test_too_many_questions(self): - # Simulate repetitive questions to exceed the question limit - self.reply_callback.side_effect = [self.original_result] * 21 - - with self.assertRaises(RuntimeError) as cm_err: - handle_questions(self.original_result, self.answers, self.reply_callback) - - self.assertIn('Too many questions', str(cm_err.exception)) - - def test_negative_reply(self): - # Set a negative answer - self.answers[QuestionType.PRICE_PERCENTAGE_CONSTRAINT] = False - - with self.assertRaises(RuntimeError) as cm_err: - handle_questions(self.original_result, self.answers, self.reply_callback) - self.assertEqual( - f'A question was not given a positive reply. Question: "{self.original_result.data[0]["message"][0]}". Answers: \n{self.answers}\n. Request: {self.original_result.request}', - str(cm_err.exception), - ) +def test_invalid_question(answers): + """Raises when no answer matches the provided question string.""" + ## Arrange + question = 'Nonexistent question type' + + ## Act & Assert + with pytest.raises(ValueError): + find_answer(question, answers) + + +# -------------------------------------------------------------------------------------- +# Handling interactive questions +# -------------------------------------------------------------------------------------- + + +@pytest.fixture +def original_result(): + return Result( + data=[{'id': '12345', 'message': ['price exceeds the Percentage constraint of 3%.']}], + request={'url': 'test_url'}, + ) + + +@pytest.fixture +def reply_callback(): + return MagicMock() + - def test_multiple_orders_returned(self): - # Simulate multiple orders in the data - self.original_result.data = [ - {'id': '12345', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT)]}, - {'id': '12346', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT)]}, - ] - self.reply_callback.return_value = self.original_result.copy(data=[{}]) +def test_successful_handling(mocker, original_result, reply_callback): + """Replies to a sequence of questions and returns the final result.""" + ## Arrange + question_type_mock = mocker.patch('ibind.client.ibkr_utils.QuestionType') - with self.assertLogs(project_logger(), level='INFO') as cm: - handle_questions(self.original_result, self.answers, self.reply_callback) + question_type_mock.PRICE_PERCENTAGE_CONSTRAINT.__str__.return_value = 'price exceeds the Percentage constraint of 3%.' + question_type_mock.ADDITIONAL_QUESTION_TYPE.__str__.return_value = 'This is an additional question.' - verify_log(self, cm, ['While handling questions multiple orders were returned: ' + pformat(self.original_result.data)]) + answers = {question_type_mock.PRICE_PERCENTAGE_CONSTRAINT: True, question_type_mock.ADDITIONAL_QUESTION_TYPE: True} - def test_multiple_messages_returned(self): - # Simulate a single order with multiple messages - self.original_result.data = [{'id': '12345', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT), 'Message 2']}] - self.reply_callback.return_value = self.original_result.copy(data=[{}]) + replies = [ + Result(data=[{'id': '12346', 'message': ['This is an additional question.']}], request={'url': 'another_question_url'}), + Result(data=[{'id': '12347'}], request={'url': 'final_url'}), + ] + reply_callback.side_effect = replies - with self.assertLogs(project_logger(), level='INFO') as cm: - handle_questions(self.original_result, self.answers, self.reply_callback) + ## Act + result = handle_questions(original_result, answers, reply_callback) - verify_log(self, cm, ['While handling questions multiple messages were returned: ' + pformat(self.original_result.data[0]['message'])]) + ## Assert + assert result.request['url'] == original_result.request['url'] + assert len(reply_callback.call_args_list) == 2 -class TestParseOrderRequestI(TestCase): - def test_parse_both_with_conidex(self): + expected_calls = [ + call(original_result.data[0]['id'], answers[question_type_mock.PRICE_PERCENTAGE_CONSTRAINT]), + call(replies[0].data[0]['id'], answers[question_type_mock.ADDITIONAL_QUESTION_TYPE]), + ] + + assert expected_calls == reply_callback.call_args_list + + +def test_too_many_questions(original_result, answers, reply_callback): + """Raises when the question loop exceeds the maximum number of attempts.""" + ## Arrange + reply_callback.side_effect = [original_result] * 21 + + ## Act & Assert + with pytest.raises(RuntimeError) as cm_err: + handle_questions(original_result, answers, reply_callback) + + assert 'Too many questions' in str(cm_err.value) + + +def test_negative_reply(original_result, answers, reply_callback): + """Raises when a question is answered negatively.""" + ## Arrange + answers[QuestionType.PRICE_PERCENTAGE_CONSTRAINT] = False + + ## Act & Assert + with pytest.raises(RuntimeError) as cm_err: + handle_questions(original_result, answers, reply_callback) + + assert ( + f'A question was not given a positive reply. Question: "{original_result.data[0]["message"][0]}". Answers: \n{answers}\n. Request: {original_result.request}' + == str(cm_err.value) + ) + + +def test_multiple_orders_returned(original_result, answers, reply_callback): + """Logs a message when multiple orders are returned while handling questions.""" + ## Arrange + original_result.data = [ + {'id': '12345', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT)]}, + {'id': '12346', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT)]}, + ] + reply_callback.return_value = original_result.copy(data=[{}]) + + expected = 'While handling questions multiple orders were returned: ' + pformat(original_result.data) + + ## Act & Assert + with CaptureLogsContext('ibind', level='INFO', expected_errors=[expected], attach_stack=False): + handle_questions(original_result, answers, reply_callback) + + +def test_multiple_messages_returned(original_result, answers, reply_callback): + """Logs a message when multiple messages are returned for a single order.""" + ## Arrange + original_result.data = [{'id': '12345', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT), 'Message 2']}] + reply_callback.return_value = original_result.copy(data=[{}]) + + expected = 'While handling questions multiple messages were returned: ' + pformat(original_result.data[0]['message']) + + ## Act & Assert + with CaptureLogsContext('ibind', level='INFO', expected_errors=[expected], attach_stack=False): + handle_questions(original_result, answers, reply_callback) + + +# -------------------------------------------------------------------------------------- +# Order request parsing +# -------------------------------------------------------------------------------------- + + +def test_parse_both_with_conidex(): + """Parses OrderRequest with conid=None and conidex set into API payload.""" + ## Arrange + order_request = OrderRequest( + conid=None, + side='BUY', + quantity=321, + order_type='MKT', + acct_id='DU1234567', + conidex='33333', + ) + + ## Act + d = parse_order_request(order_request) + + ## Assert + assert { + 'side': 'BUY', + 'quantity': 321, + 'orderType': 'MKT', + 'acctId': 'DU1234567', + 'conidex': '33333', + 'tif': 'GTC', + } == d + + +def test_raise_with_conid_and_conidex(): + """Raises when both conid and conidex are provided.""" + ## Arrange + + ## Act & Assert + with pytest.raises(ValueError) as cm_err: order_request = OrderRequest( - conid=None, + conid=123, side='BUY', quantity=321, order_type='MKT', acct_id='DU1234567', - conidex='33333' # should cause exception + conidex='33333', ) - d = parse_order_request(order_request) - - self.assertEqual({ - 'side': 'BUY', - 'quantity': 321, - 'orderType': 'MKT', - 'acctId': 'DU1234567', - 'conidex': '33333', - 'tif': 'GTC' - }, d) - - def test_raise_with_conid_and_conidex(self): - with self.assertRaises(ValueError) as cm_err: - order_request = OrderRequest( - conid=123, - side='BUY', - quantity=321, - order_type='MKT', - acct_id='DU1234567', - conidex='33333' # should cause exception - ) - - parse_order_request(order_request) - - self.assertEqual("Both 'conidex' and 'conid' are provided. When using 'conidex', specify `conid=None`.", str(cm_err.exception)) - + parse_order_request(order_request) + assert "Both 'conidex' and 'conid' are provided. When using 'conidex', specify `conid=None`." == str(cm_err.value) \ No newline at end of file diff --git a/test/integration/client/test_ibkr_utils_i_new.py b/test/integration/client/test_ibkr_utils_i_new.py deleted file mode 100644 index 2eb92395..00000000 --- a/test/integration/client/test_ibkr_utils_i_new.py +++ /dev/null @@ -1,394 +0,0 @@ -from pprint import pformat -from unittest.mock import MagicMock, call - -import pytest - -from ibind.base.rest_client import Result -from ibind.client.ibkr_utils import ( - StockQuery, - filter_stocks, - find_answer, - QuestionType, - handle_questions, - question_type_to_message_id, - OrderRequest, - parse_order_request, -) -from test.integration.client import ibkr_responses -from test.test_utils_new import CaptureLogsContext - - -# -------------------------------------------------------------------------------------- -# Stock filtering -# -------------------------------------------------------------------------------------- - - -@pytest.fixture -def instruments(): - return ibkr_responses.responses['stocks'] - - -@pytest.fixture -def instruments_result(instruments): - return Result(data=instruments) - - -def test_filter_stocks(instruments, instruments_result): - """Filters instruments for multiple stock queries and logs missing symbols.""" - ## Arrange - queries = [ - StockQuery(symbol='AAPL', contract_conditions={'isUS': False}, name_match='APPLE'), - StockQuery(symbol='BBVA', contract_conditions={'exchange': 'NYSE'}), - StockQuery(symbol='CDN', contract_conditions={'isUS': True}), - StockQuery(symbol='CFC', contract_conditions={}), - StockQuery( - symbol='GOOG', - contract_conditions={'isUS': False}, - instrument_conditions={'chineseName': 'Alphabet公司'}, - ), - 'HUBS', - StockQuery(symbol='META', name_match='meta ', contract_conditions={'isUS': False}, instrument_conditions={}), - StockQuery(symbol='MSFT', contract_conditions={'exchange': 'NASDAQ'}), - StockQuery(symbol='SAN', name_match='SANTANDER'), - StockQuery(symbol='SCHW', contract_conditions={'exchange': 'NASDAQ'}), - StockQuery(symbol='TEAM', name_match='ATLASSIAN'), - StockQuery(symbol='INVALID_SYMBOL'), - ] # fmt: skip - - ## Act - with CaptureLogsContext('ibind', level='INFO', error_level='CRITICAL', attach_stack=False) as cm: - rv = filter_stocks(queries, instruments_result, default_filtering=False) - - ## Assert - expected_error = ( - f'Error getting stocks. Could not find valid instruments INVALID_SYMBOL in result: {instruments_result}. ' - f'Skipping query={queries[-1]}.' - ) - assert expected_error in cm.output - - assert [ - { - 'assetClass': 'STK', - 'chineseName': '苹果公司', - 'contracts': [ - {'conid': 38708077, 'exchange': 'MEXI', 'isUS': False}, - {'conid': 273982664, 'exchange': 'EBS', 'isUS': False}, - ], - 'name': 'APPLE INC', - }, - { - 'assetClass': 'STK', - 'chineseName': '苹果公司', - 'contracts': [{'conid': 532640894, 'exchange': 'AEQLIT', 'isUS': False}], - 'name': 'APPLE INC-CDR', - }, - ] == rv.data['AAPL'] - - assert [ - { - 'assetClass': 'STK', - 'chineseName': '西班牙对外银行', - 'contracts': [{'conid': 4815, 'exchange': 'NYSE', 'isUS': True}], - 'name': 'BANCO BILBAO VIZCAYA-SP ADR', - }, - ] == rv.data['BBVA'] - - assert [] == rv.data['CDN'] - - assert [ - { - 'assetClass': 'STK', - 'chineseName': None, - 'contracts': [{'conid': 42001300, 'exchange': 'IBIS', 'isUS': False}], - 'name': 'UET UNITED ELECTRONIC TECHNO', - } - ] == rv.data['CFC'] - - assert [ - { - 'assetClass': 'STK', - 'chineseName': 'Alphabet公司', - 'contracts': [ - {'conid': 210810667, 'exchange': 'MEXI', 'isUS': False}, - ], - 'name': 'ALPHABET INC-CL C', - }, - { - 'assetClass': 'STK', - 'chineseName': 'Alphabet公司', - 'contracts': [{'conid': 532638805, 'exchange': 'AEQLIT', 'isUS': False}], - 'name': 'ALPHABET INC - CDR', - }, - ] == rv.data['GOOG'] - - assert [ - { - 'assetClass': 'STK', - 'chineseName': 'HubSpot公司', - 'contracts': [{'conid': 169544810, 'exchange': 'NYSE', 'isUS': True}], - 'name': 'HUBSPOT INC', - } - ] == rv.data['HUBS'] - - assert [ - { - 'assetClass': 'STK', - 'chineseName': 'Meta平台股份有限公司', - 'contracts': [ - {'conid': 114922621, 'exchange': 'MEXI', 'isUS': False}, - ], - 'name': 'META PLATFORMS INC-CLASS A', - }, - { - 'assetClass': 'STK', - 'chineseName': 'Meta平台股份有限公司', - 'contracts': [{'conid': 530091499, 'exchange': 'AEQLIT', 'isUS': False}], - 'name': 'META PLATFORMS INC-CDR', - }, - ] == rv.data['META'] - - assert [ - { - 'assetClass': 'STK', - 'chineseName': '微软公司', - 'contracts': [ - {'conid': 272093, 'exchange': 'NASDAQ', 'isUS': True}, - ], - 'name': 'MICROSOFT CORP', - }, - ] == rv.data['MSFT'] - - assert [ - { - 'assetClass': 'STK', - 'chineseName': '桑坦德', - 'contracts': [ - {'conid': 38708867, 'exchange': 'MEXI', 'isUS': False}, - {'conid': 385055564, 'exchange': 'WSE', 'isUS': False}, - ], - 'name': 'BANCO SANTANDER SA', - }, - { - 'assetClass': 'STK', - 'chineseName': '桑坦德', - 'contracts': [{'conid': 12442, 'exchange': 'NYSE', 'isUS': True}], - 'name': 'BANCO SANTANDER SA-SPON ADR', - }, - { - 'assetClass': 'STK', - 'chineseName': '桑坦德英国公共有限公司', - 'contracts': [{'conid': 80993135, 'exchange': 'LSE', 'isUS': False}], - 'name': 'SANTANDER UK PLC', - }, - ] == rv.data['SAN'] - - assert [] == rv.data['SCHW'] - - assert [ - { - 'assetClass': 'STK', - 'chineseName': None, - 'contracts': [{'conid': 589316251, 'exchange': 'NASDAQ', 'isUS': True}], - 'name': 'ATLASSIAN CORP-CL A', - }, - ] == rv.data['TEAM'] - - -def test_question_type_to_message_id_successful(): - """Maps a QuestionType to its expected IBKR message id.""" - ## Arrange - question_type = QuestionType.PRICE_PERCENTAGE_CONSTRAINT - - ## Act - message_id = question_type_to_message_id(question_type) - - ## Assert - assert message_id == 'o163' - - -# -------------------------------------------------------------------------------------- -# Finding answers -# -------------------------------------------------------------------------------------- - - -@pytest.fixture -def answers(): - return {QuestionType.PRICE_PERCENTAGE_CONSTRAINT: True} - - -def test_valid_question(answers): - """Returns True when a known question type is found in the question string.""" - ## Arrange - question = f'Some {QuestionType.PRICE_PERCENTAGE_CONSTRAINT} specific question' - - ## Act - answer = find_answer(question, answers) - - ## Assert - assert answer is True - - -def test_invalid_question(answers): - """Raises when no answer matches the provided question string.""" - ## Arrange - question = 'Nonexistent question type' - - ## Act & Assert - with pytest.raises(ValueError): - find_answer(question, answers) - - -# -------------------------------------------------------------------------------------- -# Handling interactive questions -# -------------------------------------------------------------------------------------- - - -@pytest.fixture -def original_result(): - return Result( - data=[{'id': '12345', 'message': ['price exceeds the Percentage constraint of 3%.']}], - request={'url': 'test_url'}, - ) - - -@pytest.fixture -def reply_callback(): - return MagicMock() - - -def test_successful_handling(mocker, original_result, reply_callback): - """Replies to a sequence of questions and returns the final result.""" - ## Arrange - question_type_mock = mocker.patch('ibind.client.ibkr_utils.QuestionType') - - question_type_mock.PRICE_PERCENTAGE_CONSTRAINT.__str__.return_value = 'price exceeds the Percentage constraint of 3%.' - question_type_mock.ADDITIONAL_QUESTION_TYPE.__str__.return_value = 'This is an additional question.' - - answers = {question_type_mock.PRICE_PERCENTAGE_CONSTRAINT: True, question_type_mock.ADDITIONAL_QUESTION_TYPE: True} - - replies = [ - Result(data=[{'id': '12346', 'message': ['This is an additional question.']}], request={'url': 'another_question_url'}), - Result(data=[{'id': '12347'}], request={'url': 'final_url'}), - ] - reply_callback.side_effect = replies - - ## Act - result = handle_questions(original_result, answers, reply_callback) - - ## Assert - assert result.request['url'] == original_result.request['url'] - assert len(reply_callback.call_args_list) == 2 - - expected_calls = [ - call(original_result.data[0]['id'], answers[question_type_mock.PRICE_PERCENTAGE_CONSTRAINT]), - call(replies[0].data[0]['id'], answers[question_type_mock.ADDITIONAL_QUESTION_TYPE]), - ] - - assert expected_calls == reply_callback.call_args_list - - -def test_too_many_questions(original_result, answers, reply_callback): - """Raises when the question loop exceeds the maximum number of attempts.""" - ## Arrange - reply_callback.side_effect = [original_result] * 21 - - ## Act & Assert - with pytest.raises(RuntimeError) as cm_err: - handle_questions(original_result, answers, reply_callback) - - assert 'Too many questions' in str(cm_err.value) - - -def test_negative_reply(original_result, answers, reply_callback): - """Raises when a question is answered negatively.""" - ## Arrange - answers[QuestionType.PRICE_PERCENTAGE_CONSTRAINT] = False - - ## Act & Assert - with pytest.raises(RuntimeError) as cm_err: - handle_questions(original_result, answers, reply_callback) - - assert ( - f'A question was not given a positive reply. Question: "{original_result.data[0]["message"][0]}". Answers: \n{answers}\n. Request: {original_result.request}' - == str(cm_err.value) - ) - - -def test_multiple_orders_returned(original_result, answers, reply_callback): - """Logs a message when multiple orders are returned while handling questions.""" - ## Arrange - original_result.data = [ - {'id': '12345', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT)]}, - {'id': '12346', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT)]}, - ] - reply_callback.return_value = original_result.copy(data=[{}]) - - expected = 'While handling questions multiple orders were returned: ' + pformat(original_result.data) - - ## Act & Assert - with CaptureLogsContext('ibind', level='INFO', expected_errors=[expected], attach_stack=False): - handle_questions(original_result, answers, reply_callback) - - -def test_multiple_messages_returned(original_result, answers, reply_callback): - """Logs a message when multiple messages are returned for a single order.""" - ## Arrange - original_result.data = [{'id': '12345', 'message': [str(QuestionType.PRICE_PERCENTAGE_CONSTRAINT), 'Message 2']}] - reply_callback.return_value = original_result.copy(data=[{}]) - - expected = 'While handling questions multiple messages were returned: ' + pformat(original_result.data[0]['message']) - - ## Act & Assert - with CaptureLogsContext('ibind', level='INFO', expected_errors=[expected], attach_stack=False): - handle_questions(original_result, answers, reply_callback) - - -# -------------------------------------------------------------------------------------- -# Order request parsing -# -------------------------------------------------------------------------------------- - - -def test_parse_both_with_conidex(): - """Parses OrderRequest with conid=None and conidex set into API payload.""" - ## Arrange - order_request = OrderRequest( - conid=None, - side='BUY', - quantity=321, - order_type='MKT', - acct_id='DU1234567', - conidex='33333', - ) - - ## Act - d = parse_order_request(order_request) - - ## Assert - assert { - 'side': 'BUY', - 'quantity': 321, - 'orderType': 'MKT', - 'acctId': 'DU1234567', - 'conidex': '33333', - 'tif': 'GTC', - } == d - - -def test_raise_with_conid_and_conidex(): - """Raises when both conid and conidex are provided.""" - ## Arrange - - ## Act & Assert - with pytest.raises(ValueError) as cm_err: - order_request = OrderRequest( - conid=123, - side='BUY', - quantity=321, - order_type='MKT', - acct_id='DU1234567', - conidex='33333', - ) - - parse_order_request(order_request) - - assert "Both 'conidex' and 'conid' are provided. When using 'conidex', specify `conid=None`." == str(cm_err.value) \ No newline at end of file diff --git a/test/integration/client/test_ibkr_ws_client_i.py b/test/integration/client/test_ibkr_ws_client_i.py index 8fe6aafb..60da7ab7 100644 --- a/test/integration/client/test_ibkr_ws_client_i.py +++ b/test/integration/client/test_ibkr_ws_client_i.py @@ -1,415 +1,539 @@ import json -import logging from threading import Thread from typing import Optional -from unittest import TestCase -from unittest.mock import MagicMock, patch, call +from unittest.mock import MagicMock, call +import pytest import requests from ibind import Result from ibind.client.ibkr_client import IbkrClient from ibind.client.ibkr_ws_client import IbkrWsClient, IbkrSubscriptionProcessor, IbkrWsKey -from ibind.support.logs import project_logger -from test.integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock -from test_utils import RaiseLogsContext, SafeAssertLogs +from integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock +from test_utils_new import capture_logs + +_URL_WS = 'wss://localhost:5000/v1/api/ws' +_URL_REST = 'https://localhost:5000' +_ACCOUNT_ID = 'TEST_ACCOUNT_ID' +_TIMEOUT_REST = 8 +_MAX_RETRIES_REST = 4 +_MAX_RECONNECT_ATTEMPTS = 4 +_MAX_PING_INTERVAL = 38 +_SUBSCRIPTION_RETRIES = 3 +_CONID = 265598 +_UPDATE_TIME = 5678765456 + + +# -------------------------------------------------------------------------------------- +# Test setup +# -------------------------------------------------------------------------------------- + + +@pytest.fixture +def preprocess_ws_client(): + return IbkrWsClient( + url=_URL_WS, + ibkr_client=None, + account_id=None, + subscription_processor_class=lambda: None, + ) + + +@pytest.fixture +def client_mock(): + client = MagicMock( + spec=IbkrClient( + url=_URL_REST, + account_id=_ACCOUNT_ID, + timeout=_TIMEOUT_REST, + max_retries=_MAX_RETRIES_REST, + ) + ) + client.tickle.return_value.data = {'session': 'TEST_COOKIE'} + return client -class TestPreprocessRawMessage(TestCase): - def setUp(self): - self.url = 'wss://localhost:5000/v1/api/ws' +@pytest.fixture +def ws_client(client_mock): + return IbkrWsClient( + url=_URL_WS, + ibkr_client=client_mock, + account_id=_ACCOUNT_ID, + subscription_processor_class=IbkrSubscriptionProcessor, + subscription_retries=_SUBSCRIPTION_RETRIES, + subscription_timeout=0.01, + cacert=False, + timeout=0.01, + max_connection_attempts=_MAX_RECONNECT_ATTEMPTS, + max_ping_interval=_MAX_PING_INTERVAL, + ) - self.ws_client = IbkrWsClient( - url=self.url, - ibkr_client=None, - account_id=None, - subscription_processor_class=lambda: None, - ) - def test_preprocess_with_well_formed_message(self): - raw_message = json.dumps({'topic': 'actABC', 'args': {'key': 'value'}}) - expected_result = ( - {'topic': 'actABC', 'args': {'key': 'value'}}, # message - 'actABC', # topic - {'key': 'value'}, # data - 'a', # subscribed - 'ctABC', # channel - ) - self.assertEqual(self.ws_client._preprocess_raw_message(raw_message), expected_result) - - def test_preprocess_with_unsubscribed_message(self): - raw_message = json.dumps({'message': 'Unsubscribed'}) - expected_result = ({'message': 'Unsubscribed'}, None, None, None, None) - self.assertEqual(self.ws_client._preprocess_raw_message(raw_message), expected_result) - - -class TestIbkrWsClient(TestCase): - # Assuming IbkrWsClient is the class containing preprocess_raw_message - - def setUp(self): - # Assuming similar initialization parameters as in WsClient - self.url = 'wss://localhost:5000/v1/api/ws' - self.max_reconnect_attempts = 4 - self.max_ping_interval = 38 - - self.url_rest = 'https://localhost:5000' - self.account_id = 'TEST_ACCOUNT_ID' - self.timeout = 8 - self.max_retries = 4 - self.subscription_retries = 3 - self.client = MagicMock( - spec=IbkrClient( - url=self.url_rest, - account_id=self.account_id, - timeout=self.timeout, - max_retries=self.max_retries, - ) - ) - self.client.tickle.return_value.data = {'session': 'TEST_COOKIE'} - - self.SubscriptionProcessorClass = IbkrSubscriptionProcessor - - # Initialize the IbkrWsClient - self.ws_client = IbkrWsClient( - url=self.url, - ibkr_client=self.client, - account_id=self.account_id, - subscription_processor_class=self.SubscriptionProcessorClass, - subscription_retries=self.subscription_retries, - subscription_timeout=0.01, - cacert=False, - timeout=0.01, - max_connection_attempts=self.max_reconnect_attempts, - max_ping_interval=self.max_ping_interval, - ) +@pytest.fixture +def wsa_mock(): + return create_wsa_mock() - self.wsa_mock = create_wsa_mock() - self.thread_mock = MagicMock(spec=Thread) - self.thread_mock.start.side_effect = lambda: self.ws_client._run_websocket(self.wsa_mock) - - self.conid = 265598 - self.update_time = 5678765456 - - def run_in_test_context(self, fn, expected_errors: list[str] = None, expect_logs: bool = True): - with patch('ibind.base.ws_client.WebSocketApp', side_effect=lambda *args, **kwargs: init_wsa_mock(self.wsa_mock, *args, **kwargs)), \ - patch('ibind.base.ws_client.Thread', return_value=self.thread_mock) as new_thread_mock, \ - SafeAssertLogs(self, 'ibind', level='DEBUG', logger_level='DEBUG', no_logs=not expect_logs) as cm, \ - RaiseLogsContext(self, 'ibind', level='WARNING', expected_errors=expected_errors): # fmt: skip - ws_client_logger = project_logger('ws_client') - old_level = ws_client_logger.getEffectiveLevel() - ws_client_logger.setLevel(logging.WARNING) - - self.new_thread_mock = new_thread_mock - try: - rv = fn() - except: - raise - finally: - ws_client_logger.setLevel(old_level) - - return cm, rv - - def _send_payload(self, payload: dict, expected_errors: list[str] = None, expect_logs: bool = True): - def run(): - success = self.ws_client.start() - raw_payload = json.dumps(payload) - self.ws_client.send(raw_payload) - self.ws_client.shutdown() - return success - - return self.run_in_test_context(run, expected_errors=expected_errors, expect_logs=expect_logs) - - def _subscribe(self, request: dict, response: Optional[dict], expected_errors: list[str] = None, expect_logs: bool = True): - def run(): - def override_on_message(wsa_mock: MagicMock, message: str): - if response is None: - return - raw_message = json.dumps(response) - wsa_mock.__on_message__(wsa_mock, raw_message) - - self.ws_client.start() - self.wsa_mock._on_message.side_effect = override_on_message - rv = self.ws_client.subscribe( - **{'channel': request.get('channel'), 'data': request.get('data'), 'needs_confirmation': request.get('needs_confirmation')} - ) - self.ws_client.unsubscribe( - **{'channel': request.get('channel'), 'data': request.get('data'), 'needs_confirmation': request.get('confirms_unsubscription')} - ) - self.ws_client.shutdown() - return rv - - return self.run_in_test_context(run, expected_errors=expected_errors, expect_logs=expect_logs) - - def test_on_message_system_heartbeat(self): - hb = 12345678 - cm, success = self._send_payload({'topic': 'system', 'hb': hb}, expect_logs=False) - # print("\n".join([r.msg for r in cm.records])) - self.assertEqual(self.ws_client._last_heartbeat, hb) - - def test_on_message_act_account_mismatch(self): - message_data = {'topic': 'act', 'args': {'accounts': ['OTHER_ACCOUNT_ID']}} - expected_errors = ["IbkrWsClient: Account ID mismatch: expected=TEST_ACCOUNT_ID, received=['OTHER_ACCOUNT_ID']"] - - cm, success = self._send_payload(message_data, expected_errors=expected_errors) - self.assertEqual(expected_errors, [r.msg for r in cm.records]) - - def test_on_message_blt(self): - bulletin_message = {'topic': 'blt', 'args': {'bulletin_key': 'some_info'}} - - with patch.object(self.ws_client, '_handle_bulletin', MagicMock()) as mock_handle_bulletin: - cm, success = self._send_payload(bulletin_message, expect_logs=False) - mock_handle_bulletin.assert_called_once_with(bulletin_message) - - def test_on_message_sts_unauthenticated(self): - message_data = {'topic': 'sts', 'args': {'authenticated': False}} - session_id = 6545676 - - expected_errors = ["IbkrWsClient: Status unauthenticated: {'authenticated': False}", 'IbkrWsClient: Not authenticated, closing WebSocketApp'] - - response_mock = MagicMock(spec=requests.Response) - response_mock.status_code = 200 - response_mock.json.return_value = {'session': session_id, 'data_to_be_ignored': '1234'} - - self.client.tickle.return_value = Result(data=response_mock.json.return_value) - - with patch('ibind.base.rest_client.requests') as requests_mock: - requests_mock.request.return_value = response_mock - cm, success = self._send_payload(message_data, expected_errors=expected_errors) - - self.assertEqual(expected_errors, [r.msg for r in cm.records]) - self.assertFalse(self.ws_client._authenticated) - - def test_on_message_sts_authenticated(self): - message_data = {'topic': 'sts', 'args': {'authenticated': True}} - cm, success = self._send_payload(message_data, expect_logs=False) - - def test_on_message_error(self): - message_data = {'topic': 'error', 'args': {'error_key': 'error_details'}} - expected_errors = [f'IbkrWsClient: Error message: {message_data}'] - - cm, success = self._send_payload(message_data, expected_errors=expected_errors) - self.assertEqual(expected_errors, [r.msg for r in cm.records]) - - def test_on_message_no_topic_handler(self): - message_data = {'topic': 'unrecognized_topic', 'args': {'some_key': 'some_value'}} - expected_errors = [f'IbkrWsClient: Topic "{message_data["topic"]}" unrecognised. Message: {message_data}'] - - cm, success = self._send_payload(message_data, expected_errors=expected_errors) - self.assertEqual(expected_errors, [r.msg for r in cm.records]) - - def test_on_message_handled_without_subscription(self): - message_data = {'topic': 'some_topic', 'args': {'channel': 'XYZ', 'data': 'info'}} - expected_errors = [ - f'IbkrWsClient: Handled a channel "{message_data["topic"][1:]}" message that is missing a subscription. Message: {message_data}' - ] - with patch.object(self.ws_client, '_handle_subscribed_message', return_value=True): - cm, success = self._send_payload(message_data, expected_errors=expected_errors) +@pytest.fixture +def thread_mock(ws_client, wsa_mock): + thread_mock = MagicMock(spec=Thread) + thread_mock.start.side_effect = lambda: ws_client._run_websocket(wsa_mock) + return thread_mock - self.assertEqual(expected_errors, [r.msg for r in cm.records]) - def _logs_subscriptions(self, full_channel, data=None, needs_confirmation_sub: bool = False, needs_confirmation_unsub: bool = True): - return [ - f'IbkrWsClient: Subscribed: s{full_channel}{"" if data is None else f"+{json.dumps(data)}"}{"" if not needs_confirmation_sub else " without confirmation."}', - f'IbkrWsClient: Unsubscribed: u{full_channel}+{json.dumps(data if data is not None else {})}{"" if not needs_confirmation_unsub else " without confirmation."}', - ] +@pytest.fixture +def ws_app_factory(wsa_mock): + # Use a mutable side-effect so individual tests can temporarily override WebSocketApp behavior. + return { + 'fn': lambda *args, **kwargs: init_wsa_mock(wsa_mock, *args, **kwargs), + } - def test_on_message_market_data_channel_handling(self): - queue = self.ws_client.new_queue_accessor(IbkrWsKey.MARKET_DATA) - full_channel = f'{queue.key.channel}+{self.conid}' - request = {'channel': f'{full_channel}', 'data': {'fields': ['55', '71', '84', '86', '88', '85', '87', '7295', '7296', '70']}} - response = { - 'topic': f's{full_channel}', - 'conid': self.conid, - '_updated': self.update_time, - 55: 'AAPL', - 70: '195.34', - 71: '193.67', - 87: '24.2M', - 7295: '194.10', - 84: '195.25', - 86: '195.26', - 88: '3,500', - 85: '500', - 6508: '&serviceID1=122&serviceID2=123&serviceID3=203&serviceID4=775&serviceID5=204&serviceID6=206&serviceID7=108&serviceID8=109', - } - self.assertTrue(queue.empty(), 'Queue should be empty') - - with patch.object(self.ws_client, 'has_subscription', return_value=True): - cm, success = self._subscribe(request, response) - self.assertTrue(success) - - self.assertEqual(self._logs_subscriptions(full_channel, request['data']), [r.msg for r in cm.records]) - - self.assertEqual( - { - self.conid: { - '_updated': self.update_time, - 'conid': self.conid, - 'topic': f'smd+{self.conid}', - 'ask_price': '195.26', - 'ask_size': '500', - 'bid_price': '195.25', - 'bid_size': '3,500', - 'high': '195.34', - 'low': '193.67', - 'open': '194.10', - 'service_params': '&serviceID1=122&serviceID2=123&serviceID3=203&serviceID4=775&serviceID5=204&serviceID6=206&serviceID7=108&serviceID8=109', - 'symbol': 'AAPL', - 'volume': '24.2M', - } - }, - queue.get(), - ) +@pytest.fixture +def patched_constructors(mocker, thread_mock, ws_app_factory): + mocker.patch('ibind.base.ws_client.WebSocketApp', side_effect=lambda *args, **kwargs: ws_app_factory['fn'](*args, **kwargs)) + mocker.patch('ibind.base.ws_client.Thread', return_value=thread_mock) + return None + + + +def _send_payload(ws_client, payload: dict): + success = ws_client.start() + ws_client.send(json.dumps(payload)) + ws_client.shutdown() + return success + - def test_on_message_market_history_channel_handling(self): - queue = self.ws_client.new_queue_accessor(IbkrWsKey.MARKET_HISTORY) - server_id = 87567 - full_channel = f'{queue.key.channel}+{self.conid}' - request = { - 'channel': f'{full_channel}', - 'data': {'period': '1min', 'bar': '1min', 'outsideRTH': True, 'source': 'trades', 'format': '%o/%c/%h/%l'}, - 'confirms_unsubscription': False, +def _subscribe(ws_client, wsa_mock, request: dict, response: Optional[dict]): + def override_on_message(wsa_mock: MagicMock, message: str): + if response is None: + return + raw_message = json.dumps(response) + wsa_mock.__on_message__(wsa_mock, raw_message) + + ws_client.start() + wsa_mock._on_message.side_effect = override_on_message + + rv = ws_client.subscribe( + **{ + 'channel': request.get('channel'), + 'data': request.get('data'), + 'needs_confirmation': request.get('needs_confirmation'), } - response = {'topic': f's{full_channel}', 'serverId': server_id, '_updated': self.update_time, 'conid': self.conid, 'foo': 'bar'} + ) + ws_client.unsubscribe( + **{ + 'channel': request.get('channel'), + 'data': request.get('data'), + 'needs_confirmation': request.get('confirms_unsubscription'), + } + ) + ws_client.shutdown() + return rv - self.assertTrue(queue.empty(), 'Queue should be empty') - with patch.object(self.ws_client, 'has_subscription', return_value=True): - cm, success = self._subscribe(request, response) - self.assertTrue(success) - self.assertEqual(self._logs_subscriptions(full_channel, request['data']), [r.msg for r in cm.records]) +def _logs_subscriptions(full_channel, data=None, needs_confirmation_sub: bool = False, needs_confirmation_unsub: bool = True): + return [ + f'IbkrWsClient: Subscribed: s{full_channel}{"" if data is None else f"+{json.dumps(data)}"}{"" if not needs_confirmation_sub else " without confirmation."}', + f'IbkrWsClient: Unsubscribed: u{full_channel}+{json.dumps(data if data is not None else {})}{"" if not needs_confirmation_unsub else " without confirmation."}', + ] - self.assertEqual(response, queue.get()) - self.assertIn(server_id, self.ws_client.server_ids(IbkrWsKey.MARKET_HISTORY)) - def test_on_message_trade_channel_handling(self): - queue = self.ws_client.new_queue_accessor(IbkrWsKey.TRADES) - full_channel = f'{queue.key.channel}+{self.conid}' - request = {'channel': f'{full_channel}'} - response = {'topic': f's{full_channel}', '_updated': self.update_time, 'conid': self.conid, 'args': [{'foo': 'bar'}]} +# -------------------------------------------------------------------------------------- +# Message preprocessing +# -------------------------------------------------------------------------------------- - self.assertTrue(queue.empty(), 'Queue should be empty') - with patch.object(self.ws_client, 'has_subscription', return_value=True): - cm, success = self._subscribe(request, response) - self.assertTrue(success) +def test_preprocess_with_well_formed_message(preprocess_ws_client): + """Preprocesses a well-formed raw message into (message, topic, data, subscribed, channel).""" + ## Arrange + raw_message = json.dumps({'topic': 'actABC', 'args': {'key': 'value'}}) + expected_result = ( + {'topic': 'actABC', 'args': {'key': 'value'}}, # message + 'actABC', # topic + {'key': 'value'}, # data + 'a', # subscribed + 'ctABC', # channel + ) - self.assertEqual(self._logs_subscriptions(full_channel), [r.msg for r in cm.records]) - self.assertEqual(response, queue.get()) + ## Act + rv = preprocess_ws_client._preprocess_raw_message(raw_message) - def test_on_message_orders_channel_handling(self): - queue = self.ws_client.new_queue_accessor(IbkrWsKey.ORDERS) + ## Assert + assert rv == expected_result - full_channel = f'{queue.key.channel}+{self.conid}' - request = {'channel': f'{full_channel}'} - response = {'topic': f's{full_channel}', '_updated': self.update_time, 'conid': self.conid, 'args': [{'foo': 'bar'}]} - self.assertTrue(queue.empty(), 'Queue should be empty') +def test_preprocess_with_unsubscribed_message(preprocess_ws_client): + """Returns empty preprocess result for unsubscribed messages.""" + ## Arrange + raw_message = json.dumps({'message': 'Unsubscribed'}) - with patch.object(self.ws_client, 'has_subscription', return_value=True): - cm, success = self._subscribe(request, response) - self.assertTrue(success) + ## Act + rv = preprocess_ws_client._preprocess_raw_message(raw_message) - self.assertEqual(self._logs_subscriptions(full_channel, None, True, True), [r.msg for r in cm.records]) - self.assertEqual(response, queue.get()) + ## Assert + assert rv == ({'message': 'Unsubscribed'}, None, None, None, None) - def test_subscription_without_confirmation(self): - channel = 'fake' - full_channel = f'{channel}+{self.conid}' - request = {'channel': f'{full_channel}', 'needs_confirmation': False, 'confirms_unsubscription': False} - response = None - expected_errors = [f'IbkrWsClient: Channel subscription timeout: s{full_channel} after {self.subscription_retries} attempts.'] +# -------------------------------------------------------------------------------------- +# On-message handling +# -------------------------------------------------------------------------------------- - with patch.object(self.ws_client, 'has_subscription', return_value=True): - cm, success = self._subscribe(request, response, expected_errors=expected_errors) - self.assertTrue(success) - self.assertEqual( - [ - f'IbkrWsClient: Subscribed: s{full_channel} without confirmation.', - f'IbkrWsClient: Unsubscribed: u{full_channel}+{{}} without confirmation.', - ], - [r.msg for r in cm.records], - ) +@capture_logs(logger_level='DEBUG') +def test_on_message_system_heartbeat(ws_client, patched_constructors): + """Updates last heartbeat on system heartbeat message.""" + ## Arrange + hb = 12345678 - def test_check_health(self): - start_time = [100] - has_active_connection_counter = [0] - - # control time - def fake_time(): - start_time[0] += 100 - return start_time[0] - - # simulate that we don't have ws connection first - def has_active_connection(): - has_active_connection_counter[0] += 1 - if has_active_connection_counter[0] <= 2: - return False - return True - - # prepare a fake subscription - queue = self.ws_client.new_queue_accessor(IbkrWsKey.TRADES) - full_channel = f'{queue.key.channel}+{self.conid}' - request = {'channel': f'{full_channel}', 'data': {'foo': 'bar'}} - response = {'topic': f's{full_channel}', '_updated': self.update_time, 'conid': self.conid, 'args': [{'foo': 'bar'}]} - - def run(): - # ensures each time WebSocketApp's mock is created, we override its on_message method - def override_init_wsa_mock(wsa_mock: MagicMock, *args, **kwargs): - wsa_mock = init_wsa_mock(wsa_mock, *args, **kwargs) - wsa_mock._on_message.side_effect = lambda wsa_mock, message: wsa_mock.__on_message__(wsa_mock, json.dumps(response)) - return wsa_mock - - self.ws_client.start() - self.ws_client.check_health() - self.wsa_mock._on_message.side_effect = lambda wsa_mock, message: wsa_mock.__on_message__(wsa_mock, json.dumps(response)) - - # create the original subscription - self.ws_client.subscribe(**request) - - # we simulate that closing the WebSocket doesn't work since we have connectivity issues - # self.wsa_mock.on_close.side_effect = lambda x, y, z: None - - # override time.time, ignore check_ping and take control of has_active_connection - with patch('ibind.client.ibkr_ws_client.time') as time_mock, \ - patch.object(self.ws_client, 'check_ping', return_value=True), \ - patch('ibind.base.ws_client.WebSocketApp', side_effect=lambda *args, **kwargs: override_init_wsa_mock(self.wsa_mock, *args, **kwargs)), \ - patch.object(self.ws_client, '_has_active_connection', side_effect=has_active_connection) as has_active_connection_mock: # fmt: skip - time_mock.time.side_effect = fake_time - self.ws_client._last_heartbeat = self.max_ping_interval * 1000 - - # this should try to close the connection, fail to do so, abandon the WebSocketApp's mock, - # then recreate a new mock and recreate the connections - self.ws_client.check_health() - - self.assertTrue(self.ws_client.ready()) - self.assertEqual([call()] * 6, has_active_connection_mock.call_args_list) - self.ws_client.shutdown() - - expected_errors = [ - f'IbkrWsClient: Last IBKR heartbeat happened 162.00 seconds ago, exceeding the max ping interval of {self.max_ping_interval}. Restarting.', - # 'IbkrWsClient: Hard reset close timeout', - # f'IbkrWsClient: Abandoning current WebSocketApp that cannot be closed: {self.wsa_mock}' - ] + ## Act + _send_payload(ws_client, {'topic': 'system', 'hb': hb}) - cm, success = self.run_in_test_context(run, expected_errors=expected_errors) + ## Assert + assert ws_client._last_heartbeat == hb - channel_subscribed_log = f'IbkrWsClient: Subscribed: s{full_channel}+{json.dumps(request["data"])}' +@capture_logs(logger_level='DEBUG', expected_errors = ["IbkrWsClient: Account ID mismatch: expected=TEST_ACCOUNT_ID, received=['OTHER_ACCOUNT_ID']"]) +def test_on_message_act_account_mismatch(ws_client, patched_constructors): + """Logs a warning when account list in act message mismatches expected account.""" + ## Act + _send_payload(ws_client, {'topic': 'act', 'args': {'accounts': ['OTHER_ACCOUNT_ID']}}) - self.assertEqual( - [channel_subscribed_log] - + expected_errors - + [ - f'IbkrWsClient: Invalidated subscription: {full_channel}', - f"IbkrWsClient: Recreating 1/1 subscriptions: {{'{full_channel}': {{'status': False, 'data': {request['data']}, 'needs_confirmation': True, 'subscription_processor': None}}}}", - channel_subscribed_log, - f'IbkrWsClient: Invalidated subscription: {full_channel}', - ], - [r.msg for r in cm.records], - ) + +@capture_logs(logger_level='DEBUG') +def test_on_message_blt(ws_client, patched_constructors, mocker): + """Dispatches bulletin messages to _handle_bulletin.""" + ## Arrange + bulletin_message = {'topic': 'blt', 'args': {'bulletin_key': 'some_info'}} + mock_handle_bulletin = mocker.patch.object(ws_client, '_handle_bulletin', MagicMock()) + + ## Act + _send_payload(ws_client, bulletin_message) + + ## Assert + mock_handle_bulletin.assert_called_once_with(bulletin_message) + +@capture_logs(logger_level='DEBUG', expected_errors=[ + "IbkrWsClient: Status unauthenticated: {'authenticated': False}", + 'IbkrWsClient: Not authenticated, closing WebSocketApp', +]) +def test_on_message_sts_unauthenticated(ws_client, client_mock, patched_constructors, mocker): + """On unauthenticated status, refetches session and closes websocket.""" + ## Arrange + message_data = {'topic': 'sts', 'args': {'authenticated': False}} + session_id = 6545676 + + response_mock = MagicMock(spec=requests.Response) + response_mock.status_code = 200 + response_mock.json.return_value = {'session': session_id, 'data_to_be_ignored': '1234'} + + client_mock.tickle.return_value = Result(data=response_mock.json.return_value) + + requests_mock = mocker.patch('ibind.base.rest_client.requests') + requests_mock.request.return_value = response_mock + + ## Act + _send_payload(ws_client, message_data) + + ## Assert + assert ws_client._authenticated is False + +@capture_logs(logger_level='DEBUG') +def test_on_message_sts_authenticated(ws_client, patched_constructors): + """Accepts authenticated status without logging warnings.""" + ## Act + _send_payload(ws_client, {'topic': 'sts', 'args': {'authenticated': True}}) + + +@capture_logs(logger_level='DEBUG', expected_errors = [f'IbkrWsClient: Error message:'], partial_match=True) +def test_on_message_error(ws_client, patched_constructors): + """Logs error-topic messages as warnings.""" + ## Act + _send_payload(ws_client, {'topic': 'error', 'args': {'error_key': 'error_details'}}) + + + +@capture_logs(logger_level='DEBUG', expected_errors=['unrecognised. Message:'], partial_match=True) +def test_on_message_no_topic_handler(ws_client, patched_constructors): + """Logs a warning when no handler exists for a topic.""" + ## Arrange + message_data = {'topic': 'unrecognized_topic', 'args': {'some_key': 'some_value'}} + + ## Act + _send_payload(ws_client, message_data) + + +@capture_logs(logger_level='DEBUG', expected_errors = [ + 'message that is missing a subscription. Message:' +], partial_match=True) +def test_on_message_handled_without_subscription(ws_client, patched_constructors, mocker): + """Logs a warning if a subscribed message arrives without a known subscription.""" + ## Arrange + mocker.patch.object(ws_client, '_handle_subscribed_message', return_value=True) + + ## Act + _send_payload(ws_client, {'topic': 'some_topic', 'args': {'channel': 'XYZ', 'data': 'info'}}) + + + +# -------------------------------------------------------------------------------------- +# Subscription + channel-specific handling +# -------------------------------------------------------------------------------------- + + +@capture_logs(logger_level='DEBUG') +def test_on_message_market_data_channel_handling(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Routes market data updates into the MARKET_DATA queue.""" + ## Arrange + cm = kwargs['_cm_ibind'] + queue = ws_client.new_queue_accessor(IbkrWsKey.MARKET_DATA) + full_channel = f'{queue.key.channel}+{_CONID}' + request = { + 'channel': f'{full_channel}', + 'data': {'fields': ['55', '71', '84', '86', '88', '85', '87', '7295', '7296', '70']}, + } + response = { + 'topic': f's{full_channel}', + 'conid': _CONID, + '_updated': _UPDATE_TIME, + 55: 'AAPL', + 70: '195.34', + 71: '193.67', + 87: '24.2M', + 7295: '194.10', + 84: '195.25', + 86: '195.26', + 88: '3,500', + 85: '500', + 6508: '&serviceID1=122&serviceID2=123&serviceID3=203&serviceID4=775&serviceID5=204&serviceID6=206&serviceID7=108&serviceID8=109', + } + + assert queue.empty() is True + + mocker.patch.object(ws_client, 'has_subscription', return_value=True) + + ## Act + success = _subscribe(ws_client, wsa_mock, request, response) + + ## Assert + assert success is True + cm.partial_log(_logs_subscriptions(full_channel, request['data'])) + assert ( + { + _CONID: { + '_updated': _UPDATE_TIME, + 'conid': _CONID, + 'topic': f'smd+{_CONID}', + 'ask_price': '195.26', + 'ask_size': '500', + 'bid_price': '195.25', + 'bid_size': '3,500', + 'high': '195.34', + 'low': '193.67', + 'open': '194.10', + 'service_params': '&serviceID1=122&serviceID2=123&serviceID3=203&serviceID4=775&serviceID5=204&serviceID6=206&serviceID7=108&serviceID8=109', + 'symbol': 'AAPL', + 'volume': '24.2M', + } + } + == queue.get() + ) + + +@capture_logs(logger_level='DEBUG') +def test_on_message_market_history_channel_handling(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Routes market history updates into the MARKET_HISTORY queue and tracks server IDs.""" + ## Arrange + cm = kwargs['_cm_ibind'] + queue = ws_client.new_queue_accessor(IbkrWsKey.MARKET_HISTORY) + server_id = 87567 + full_channel = f'{queue.key.channel}+{_CONID}' + request = { + 'channel': f'{full_channel}', + 'data': {'period': '1min', 'bar': '1min', 'outsideRTH': True, 'source': 'trades', 'format': '%o/%c/%h/%l'}, + 'confirms_unsubscription': False, + } + response = { + 'topic': f's{full_channel}', + 'serverId': server_id, + '_updated': _UPDATE_TIME, + 'conid': _CONID, + 'foo': 'bar', + } + + assert queue.empty() is True + + mocker.patch.object(ws_client, 'has_subscription', return_value=True) + + ## Act + success = _subscribe(ws_client, wsa_mock, request, response) + + ## Assert + assert success is True + cm.partial_log(_logs_subscriptions(full_channel, request['data'])) + assert response == queue.get() + assert server_id in ws_client.server_ids(IbkrWsKey.MARKET_HISTORY) + + +@capture_logs(logger_level='DEBUG') +def test_on_message_trade_channel_handling(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Routes trade updates into the TRADES queue.""" + ## Arrange + cm = kwargs['_cm_ibind'] + queue = ws_client.new_queue_accessor(IbkrWsKey.TRADES) + full_channel = f'{queue.key.channel}+{_CONID}' + request = {'channel': f'{full_channel}'} + response = { + 'topic': f's{full_channel}', + '_updated': _UPDATE_TIME, + 'conid': _CONID, + 'args': [{'foo': 'bar'}], + } + + assert queue.empty() is True + + mocker.patch.object(ws_client, 'has_subscription', return_value=True) + + ## Act + success = _subscribe(ws_client, wsa_mock, request, response) + + ## Assert + assert success is True + cm.partial_log(_logs_subscriptions(full_channel)) + assert response == queue.get() + + +@capture_logs(logger_level='DEBUG') +def test_on_message_orders_channel_handling(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Routes order updates into the ORDERS queue.""" + ## Arrange + cm = kwargs['_cm_ibind'] + queue = ws_client.new_queue_accessor(IbkrWsKey.ORDERS) + full_channel = f'{queue.key.channel}+{_CONID}' + request = {'channel': f'{full_channel}'} + response = { + 'topic': f's{full_channel}', + '_updated': _UPDATE_TIME, + 'conid': _CONID, + 'args': [{'foo': 'bar'}], + } + + assert queue.empty() is True + + mocker.patch.object(ws_client, 'has_subscription', return_value=True) + + ## Act + success = _subscribe(ws_client, wsa_mock, request, response) + + ## Assert + assert success is True + cm.partial_log(_logs_subscriptions(full_channel, None, True, True)) + assert response == queue.get() + + +@capture_logs(logger_level='DEBUG') +def test_subscription_without_confirmation(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): + """Subscribes/unsubscribes without confirmation when requested.""" + ## Arrange + cm = kwargs['_cm_ibind'] + channel = 'fake' + full_channel = f'{channel}+{_CONID}' + request = { + 'channel': f'{full_channel}', + 'needs_confirmation': False, + 'confirms_unsubscription': False, + } + response = None + + mocker.patch.object(ws_client, 'has_subscription', return_value=True) + + ## Act + success = _subscribe(ws_client, wsa_mock, request, response) + + ## Assert + assert success is True + cm.partial_log([ + f'IbkrWsClient: Subscribed: s{full_channel} without confirmation.', + f'IbkrWsClient: Unsubscribed: u{full_channel}+{{}} without confirmation.', + ]) + + + +# -------------------------------------------------------------------------------------- +# Health checks +# -------------------------------------------------------------------------------------- + + +@capture_logs(logger_level='DEBUG', expected_errors=[ + f'IbkrWsClient: Last IBKR heartbeat happened 162.00 seconds ago, exceeding the max ping interval of {_MAX_PING_INTERVAL}. Restarting.', +]) +def test_check_health(ws_client, wsa_mock, ws_app_factory, patched_constructors, mocker, **kwargs): + """Restarts and recreates subscriptions when heartbeat exceeds max ping interval.""" + ## Arrange + cm = kwargs['_cm_ibind'] + start_time = [100] + has_active_connection_counter = [0] + + def fake_time(): + start_time[0] += 100 + return start_time[0] + + def has_active_connection(): + has_active_connection_counter[0] += 1 + if has_active_connection_counter[0] <= 2: + return False + return True + + queue = ws_client.new_queue_accessor(IbkrWsKey.TRADES) + full_channel = f'{queue.key.channel}+{_CONID}' + request = {'channel': f'{full_channel}', 'data': {'foo': 'bar'}} + response = { + 'topic': f's{full_channel}', + '_updated': _UPDATE_TIME, + 'conid': _CONID, + 'args': [{'foo': 'bar'}], + } + + ## Act + def override_init_wsa_mock(wsa_mock: MagicMock, *args, **kwargs): + wsa_mock = init_wsa_mock(wsa_mock, *args, **kwargs) + wsa_mock._on_message.side_effect = lambda wsa_mock, message: wsa_mock.__on_message__(wsa_mock, json.dumps(response)) + return wsa_mock + + ws_client.start() + ws_client.check_health() + wsa_mock._on_message.side_effect = lambda wsa_mock, message: wsa_mock.__on_message__(wsa_mock, json.dumps(response)) + + ws_client.subscribe(**request) + + # Override time, ignore ping check, and control active-connection health checks. + time_mock = mocker.patch('ibind.client.ibkr_ws_client.time') + time_mock.time.side_effect = fake_time + + mocker.patch.object(ws_client, 'check_ping', return_value=True) + mocker.patch.object(ws_client, '_has_active_connection', side_effect=has_active_connection) + + # Ensure each reconnect creates a WebSocketApp whose on_message pushes our fake response. + ws_app_factory['fn'] = lambda *args, **kwargs: override_init_wsa_mock(wsa_mock, *args, **kwargs) + + ws_client._last_heartbeat = _MAX_PING_INTERVAL * 1000 + ws_client.check_health() + + assert ws_client.ready() is True + assert [call()] * 6 == ws_client._has_active_connection.call_args_list + + ws_client.shutdown() + + + ## Assert + channel_subscribed_log = f'IbkrWsClient: Subscribed: s{full_channel}+{json.dumps(request["data"])}' + cm.partial_log( + [channel_subscribed_log] + + [ + f'IbkrWsClient: Invalidated subscription: {full_channel}', + f"IbkrWsClient: Recreating 1/1 subscriptions: {{'{full_channel}': {{'status': False, 'data': {request['data']}, 'needs_confirmation': True, 'subscription_processor': None}}}}", + channel_subscribed_log, + f'IbkrWsClient: Invalidated subscription: {full_channel}', + ] + ) \ No newline at end of file diff --git a/test/integration/client/test_ibkr_ws_client_i_new.py b/test/integration/client/test_ibkr_ws_client_i_new.py deleted file mode 100644 index 60da7ab7..00000000 --- a/test/integration/client/test_ibkr_ws_client_i_new.py +++ /dev/null @@ -1,539 +0,0 @@ -import json -from threading import Thread -from typing import Optional -from unittest.mock import MagicMock, call - -import pytest -import requests - -from ibind import Result -from ibind.client.ibkr_client import IbkrClient -from ibind.client.ibkr_ws_client import IbkrWsClient, IbkrSubscriptionProcessor, IbkrWsKey -from integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock -from test_utils_new import capture_logs - -_URL_WS = 'wss://localhost:5000/v1/api/ws' -_URL_REST = 'https://localhost:5000' -_ACCOUNT_ID = 'TEST_ACCOUNT_ID' -_TIMEOUT_REST = 8 -_MAX_RETRIES_REST = 4 -_MAX_RECONNECT_ATTEMPTS = 4 -_MAX_PING_INTERVAL = 38 -_SUBSCRIPTION_RETRIES = 3 -_CONID = 265598 -_UPDATE_TIME = 5678765456 - - -# -------------------------------------------------------------------------------------- -# Test setup -# -------------------------------------------------------------------------------------- - - -@pytest.fixture -def preprocess_ws_client(): - return IbkrWsClient( - url=_URL_WS, - ibkr_client=None, - account_id=None, - subscription_processor_class=lambda: None, - ) - - -@pytest.fixture -def client_mock(): - client = MagicMock( - spec=IbkrClient( - url=_URL_REST, - account_id=_ACCOUNT_ID, - timeout=_TIMEOUT_REST, - max_retries=_MAX_RETRIES_REST, - ) - ) - client.tickle.return_value.data = {'session': 'TEST_COOKIE'} - return client - - -@pytest.fixture -def ws_client(client_mock): - return IbkrWsClient( - url=_URL_WS, - ibkr_client=client_mock, - account_id=_ACCOUNT_ID, - subscription_processor_class=IbkrSubscriptionProcessor, - subscription_retries=_SUBSCRIPTION_RETRIES, - subscription_timeout=0.01, - cacert=False, - timeout=0.01, - max_connection_attempts=_MAX_RECONNECT_ATTEMPTS, - max_ping_interval=_MAX_PING_INTERVAL, - ) - - - -@pytest.fixture -def wsa_mock(): - return create_wsa_mock() - - -@pytest.fixture -def thread_mock(ws_client, wsa_mock): - thread_mock = MagicMock(spec=Thread) - thread_mock.start.side_effect = lambda: ws_client._run_websocket(wsa_mock) - return thread_mock - - -@pytest.fixture -def ws_app_factory(wsa_mock): - # Use a mutable side-effect so individual tests can temporarily override WebSocketApp behavior. - return { - 'fn': lambda *args, **kwargs: init_wsa_mock(wsa_mock, *args, **kwargs), - } - - -@pytest.fixture -def patched_constructors(mocker, thread_mock, ws_app_factory): - mocker.patch('ibind.base.ws_client.WebSocketApp', side_effect=lambda *args, **kwargs: ws_app_factory['fn'](*args, **kwargs)) - mocker.patch('ibind.base.ws_client.Thread', return_value=thread_mock) - return None - - - -def _send_payload(ws_client, payload: dict): - success = ws_client.start() - ws_client.send(json.dumps(payload)) - ws_client.shutdown() - return success - - -def _subscribe(ws_client, wsa_mock, request: dict, response: Optional[dict]): - def override_on_message(wsa_mock: MagicMock, message: str): - if response is None: - return - raw_message = json.dumps(response) - wsa_mock.__on_message__(wsa_mock, raw_message) - - ws_client.start() - wsa_mock._on_message.side_effect = override_on_message - - rv = ws_client.subscribe( - **{ - 'channel': request.get('channel'), - 'data': request.get('data'), - 'needs_confirmation': request.get('needs_confirmation'), - } - ) - ws_client.unsubscribe( - **{ - 'channel': request.get('channel'), - 'data': request.get('data'), - 'needs_confirmation': request.get('confirms_unsubscription'), - } - ) - ws_client.shutdown() - return rv - - - -def _logs_subscriptions(full_channel, data=None, needs_confirmation_sub: bool = False, needs_confirmation_unsub: bool = True): - return [ - f'IbkrWsClient: Subscribed: s{full_channel}{"" if data is None else f"+{json.dumps(data)}"}{"" if not needs_confirmation_sub else " without confirmation."}', - f'IbkrWsClient: Unsubscribed: u{full_channel}+{json.dumps(data if data is not None else {})}{"" if not needs_confirmation_unsub else " without confirmation."}', - ] - - -# -------------------------------------------------------------------------------------- -# Message preprocessing -# -------------------------------------------------------------------------------------- - - -def test_preprocess_with_well_formed_message(preprocess_ws_client): - """Preprocesses a well-formed raw message into (message, topic, data, subscribed, channel).""" - ## Arrange - raw_message = json.dumps({'topic': 'actABC', 'args': {'key': 'value'}}) - expected_result = ( - {'topic': 'actABC', 'args': {'key': 'value'}}, # message - 'actABC', # topic - {'key': 'value'}, # data - 'a', # subscribed - 'ctABC', # channel - ) - - ## Act - rv = preprocess_ws_client._preprocess_raw_message(raw_message) - - ## Assert - assert rv == expected_result - - -def test_preprocess_with_unsubscribed_message(preprocess_ws_client): - """Returns empty preprocess result for unsubscribed messages.""" - ## Arrange - raw_message = json.dumps({'message': 'Unsubscribed'}) - - ## Act - rv = preprocess_ws_client._preprocess_raw_message(raw_message) - - ## Assert - assert rv == ({'message': 'Unsubscribed'}, None, None, None, None) - - -# -------------------------------------------------------------------------------------- -# On-message handling -# -------------------------------------------------------------------------------------- - - -@capture_logs(logger_level='DEBUG') -def test_on_message_system_heartbeat(ws_client, patched_constructors): - """Updates last heartbeat on system heartbeat message.""" - ## Arrange - hb = 12345678 - - ## Act - _send_payload(ws_client, {'topic': 'system', 'hb': hb}) - - ## Assert - assert ws_client._last_heartbeat == hb - -@capture_logs(logger_level='DEBUG', expected_errors = ["IbkrWsClient: Account ID mismatch: expected=TEST_ACCOUNT_ID, received=['OTHER_ACCOUNT_ID']"]) -def test_on_message_act_account_mismatch(ws_client, patched_constructors): - """Logs a warning when account list in act message mismatches expected account.""" - ## Act - _send_payload(ws_client, {'topic': 'act', 'args': {'accounts': ['OTHER_ACCOUNT_ID']}}) - - -@capture_logs(logger_level='DEBUG') -def test_on_message_blt(ws_client, patched_constructors, mocker): - """Dispatches bulletin messages to _handle_bulletin.""" - ## Arrange - bulletin_message = {'topic': 'blt', 'args': {'bulletin_key': 'some_info'}} - mock_handle_bulletin = mocker.patch.object(ws_client, '_handle_bulletin', MagicMock()) - - ## Act - _send_payload(ws_client, bulletin_message) - - ## Assert - mock_handle_bulletin.assert_called_once_with(bulletin_message) - -@capture_logs(logger_level='DEBUG', expected_errors=[ - "IbkrWsClient: Status unauthenticated: {'authenticated': False}", - 'IbkrWsClient: Not authenticated, closing WebSocketApp', -]) -def test_on_message_sts_unauthenticated(ws_client, client_mock, patched_constructors, mocker): - """On unauthenticated status, refetches session and closes websocket.""" - ## Arrange - message_data = {'topic': 'sts', 'args': {'authenticated': False}} - session_id = 6545676 - - response_mock = MagicMock(spec=requests.Response) - response_mock.status_code = 200 - response_mock.json.return_value = {'session': session_id, 'data_to_be_ignored': '1234'} - - client_mock.tickle.return_value = Result(data=response_mock.json.return_value) - - requests_mock = mocker.patch('ibind.base.rest_client.requests') - requests_mock.request.return_value = response_mock - - ## Act - _send_payload(ws_client, message_data) - - ## Assert - assert ws_client._authenticated is False - -@capture_logs(logger_level='DEBUG') -def test_on_message_sts_authenticated(ws_client, patched_constructors): - """Accepts authenticated status without logging warnings.""" - ## Act - _send_payload(ws_client, {'topic': 'sts', 'args': {'authenticated': True}}) - - -@capture_logs(logger_level='DEBUG', expected_errors = [f'IbkrWsClient: Error message:'], partial_match=True) -def test_on_message_error(ws_client, patched_constructors): - """Logs error-topic messages as warnings.""" - ## Act - _send_payload(ws_client, {'topic': 'error', 'args': {'error_key': 'error_details'}}) - - - -@capture_logs(logger_level='DEBUG', expected_errors=['unrecognised. Message:'], partial_match=True) -def test_on_message_no_topic_handler(ws_client, patched_constructors): - """Logs a warning when no handler exists for a topic.""" - ## Arrange - message_data = {'topic': 'unrecognized_topic', 'args': {'some_key': 'some_value'}} - - ## Act - _send_payload(ws_client, message_data) - - -@capture_logs(logger_level='DEBUG', expected_errors = [ - 'message that is missing a subscription. Message:' -], partial_match=True) -def test_on_message_handled_without_subscription(ws_client, patched_constructors, mocker): - """Logs a warning if a subscribed message arrives without a known subscription.""" - ## Arrange - mocker.patch.object(ws_client, '_handle_subscribed_message', return_value=True) - - ## Act - _send_payload(ws_client, {'topic': 'some_topic', 'args': {'channel': 'XYZ', 'data': 'info'}}) - - - -# -------------------------------------------------------------------------------------- -# Subscription + channel-specific handling -# -------------------------------------------------------------------------------------- - - -@capture_logs(logger_level='DEBUG') -def test_on_message_market_data_channel_handling(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): - """Routes market data updates into the MARKET_DATA queue.""" - ## Arrange - cm = kwargs['_cm_ibind'] - queue = ws_client.new_queue_accessor(IbkrWsKey.MARKET_DATA) - full_channel = f'{queue.key.channel}+{_CONID}' - request = { - 'channel': f'{full_channel}', - 'data': {'fields': ['55', '71', '84', '86', '88', '85', '87', '7295', '7296', '70']}, - } - response = { - 'topic': f's{full_channel}', - 'conid': _CONID, - '_updated': _UPDATE_TIME, - 55: 'AAPL', - 70: '195.34', - 71: '193.67', - 87: '24.2M', - 7295: '194.10', - 84: '195.25', - 86: '195.26', - 88: '3,500', - 85: '500', - 6508: '&serviceID1=122&serviceID2=123&serviceID3=203&serviceID4=775&serviceID5=204&serviceID6=206&serviceID7=108&serviceID8=109', - } - - assert queue.empty() is True - - mocker.patch.object(ws_client, 'has_subscription', return_value=True) - - ## Act - success = _subscribe(ws_client, wsa_mock, request, response) - - ## Assert - assert success is True - cm.partial_log(_logs_subscriptions(full_channel, request['data'])) - assert ( - { - _CONID: { - '_updated': _UPDATE_TIME, - 'conid': _CONID, - 'topic': f'smd+{_CONID}', - 'ask_price': '195.26', - 'ask_size': '500', - 'bid_price': '195.25', - 'bid_size': '3,500', - 'high': '195.34', - 'low': '193.67', - 'open': '194.10', - 'service_params': '&serviceID1=122&serviceID2=123&serviceID3=203&serviceID4=775&serviceID5=204&serviceID6=206&serviceID7=108&serviceID8=109', - 'symbol': 'AAPL', - 'volume': '24.2M', - } - } - == queue.get() - ) - - -@capture_logs(logger_level='DEBUG') -def test_on_message_market_history_channel_handling(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): - """Routes market history updates into the MARKET_HISTORY queue and tracks server IDs.""" - ## Arrange - cm = kwargs['_cm_ibind'] - queue = ws_client.new_queue_accessor(IbkrWsKey.MARKET_HISTORY) - server_id = 87567 - full_channel = f'{queue.key.channel}+{_CONID}' - request = { - 'channel': f'{full_channel}', - 'data': {'period': '1min', 'bar': '1min', 'outsideRTH': True, 'source': 'trades', 'format': '%o/%c/%h/%l'}, - 'confirms_unsubscription': False, - } - response = { - 'topic': f's{full_channel}', - 'serverId': server_id, - '_updated': _UPDATE_TIME, - 'conid': _CONID, - 'foo': 'bar', - } - - assert queue.empty() is True - - mocker.patch.object(ws_client, 'has_subscription', return_value=True) - - ## Act - success = _subscribe(ws_client, wsa_mock, request, response) - - ## Assert - assert success is True - cm.partial_log(_logs_subscriptions(full_channel, request['data'])) - assert response == queue.get() - assert server_id in ws_client.server_ids(IbkrWsKey.MARKET_HISTORY) - - -@capture_logs(logger_level='DEBUG') -def test_on_message_trade_channel_handling(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): - """Routes trade updates into the TRADES queue.""" - ## Arrange - cm = kwargs['_cm_ibind'] - queue = ws_client.new_queue_accessor(IbkrWsKey.TRADES) - full_channel = f'{queue.key.channel}+{_CONID}' - request = {'channel': f'{full_channel}'} - response = { - 'topic': f's{full_channel}', - '_updated': _UPDATE_TIME, - 'conid': _CONID, - 'args': [{'foo': 'bar'}], - } - - assert queue.empty() is True - - mocker.patch.object(ws_client, 'has_subscription', return_value=True) - - ## Act - success = _subscribe(ws_client, wsa_mock, request, response) - - ## Assert - assert success is True - cm.partial_log(_logs_subscriptions(full_channel)) - assert response == queue.get() - - -@capture_logs(logger_level='DEBUG') -def test_on_message_orders_channel_handling(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): - """Routes order updates into the ORDERS queue.""" - ## Arrange - cm = kwargs['_cm_ibind'] - queue = ws_client.new_queue_accessor(IbkrWsKey.ORDERS) - full_channel = f'{queue.key.channel}+{_CONID}' - request = {'channel': f'{full_channel}'} - response = { - 'topic': f's{full_channel}', - '_updated': _UPDATE_TIME, - 'conid': _CONID, - 'args': [{'foo': 'bar'}], - } - - assert queue.empty() is True - - mocker.patch.object(ws_client, 'has_subscription', return_value=True) - - ## Act - success = _subscribe(ws_client, wsa_mock, request, response) - - ## Assert - assert success is True - cm.partial_log(_logs_subscriptions(full_channel, None, True, True)) - assert response == queue.get() - - -@capture_logs(logger_level='DEBUG') -def test_subscription_without_confirmation(ws_client, wsa_mock, patched_constructors, mocker, **kwargs): - """Subscribes/unsubscribes without confirmation when requested.""" - ## Arrange - cm = kwargs['_cm_ibind'] - channel = 'fake' - full_channel = f'{channel}+{_CONID}' - request = { - 'channel': f'{full_channel}', - 'needs_confirmation': False, - 'confirms_unsubscription': False, - } - response = None - - mocker.patch.object(ws_client, 'has_subscription', return_value=True) - - ## Act - success = _subscribe(ws_client, wsa_mock, request, response) - - ## Assert - assert success is True - cm.partial_log([ - f'IbkrWsClient: Subscribed: s{full_channel} without confirmation.', - f'IbkrWsClient: Unsubscribed: u{full_channel}+{{}} without confirmation.', - ]) - - - -# -------------------------------------------------------------------------------------- -# Health checks -# -------------------------------------------------------------------------------------- - - -@capture_logs(logger_level='DEBUG', expected_errors=[ - f'IbkrWsClient: Last IBKR heartbeat happened 162.00 seconds ago, exceeding the max ping interval of {_MAX_PING_INTERVAL}. Restarting.', -]) -def test_check_health(ws_client, wsa_mock, ws_app_factory, patched_constructors, mocker, **kwargs): - """Restarts and recreates subscriptions when heartbeat exceeds max ping interval.""" - ## Arrange - cm = kwargs['_cm_ibind'] - start_time = [100] - has_active_connection_counter = [0] - - def fake_time(): - start_time[0] += 100 - return start_time[0] - - def has_active_connection(): - has_active_connection_counter[0] += 1 - if has_active_connection_counter[0] <= 2: - return False - return True - - queue = ws_client.new_queue_accessor(IbkrWsKey.TRADES) - full_channel = f'{queue.key.channel}+{_CONID}' - request = {'channel': f'{full_channel}', 'data': {'foo': 'bar'}} - response = { - 'topic': f's{full_channel}', - '_updated': _UPDATE_TIME, - 'conid': _CONID, - 'args': [{'foo': 'bar'}], - } - - ## Act - def override_init_wsa_mock(wsa_mock: MagicMock, *args, **kwargs): - wsa_mock = init_wsa_mock(wsa_mock, *args, **kwargs) - wsa_mock._on_message.side_effect = lambda wsa_mock, message: wsa_mock.__on_message__(wsa_mock, json.dumps(response)) - return wsa_mock - - ws_client.start() - ws_client.check_health() - wsa_mock._on_message.side_effect = lambda wsa_mock, message: wsa_mock.__on_message__(wsa_mock, json.dumps(response)) - - ws_client.subscribe(**request) - - # Override time, ignore ping check, and control active-connection health checks. - time_mock = mocker.patch('ibind.client.ibkr_ws_client.time') - time_mock.time.side_effect = fake_time - - mocker.patch.object(ws_client, 'check_ping', return_value=True) - mocker.patch.object(ws_client, '_has_active_connection', side_effect=has_active_connection) - - # Ensure each reconnect creates a WebSocketApp whose on_message pushes our fake response. - ws_app_factory['fn'] = lambda *args, **kwargs: override_init_wsa_mock(wsa_mock, *args, **kwargs) - - ws_client._last_heartbeat = _MAX_PING_INTERVAL * 1000 - ws_client.check_health() - - assert ws_client.ready() is True - assert [call()] * 6 == ws_client._has_active_connection.call_args_list - - ws_client.shutdown() - - - ## Assert - channel_subscribed_log = f'IbkrWsClient: Subscribed: s{full_channel}+{json.dumps(request["data"])}' - cm.partial_log( - [channel_subscribed_log] - + [ - f'IbkrWsClient: Invalidated subscription: {full_channel}', - f"IbkrWsClient: Recreating 1/1 subscriptions: {{'{full_channel}': {{'status': False, 'data': {request['data']}, 'needs_confirmation': True, 'subscription_processor': None}}}}", - channel_subscribed_log, - f'IbkrWsClient: Invalidated subscription: {full_channel}', - ] - ) \ No newline at end of file diff --git a/test/unit/support/test_py_utils_u.py b/test/unit/support/test_py_utils_u.py index bcd88198..5fb4a250 100644 --- a/test/unit/support/test_py_utils_u.py +++ b/test/unit/support/test_py_utils_u.py @@ -1,113 +1,228 @@ import time -import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock + +import pytest from ibind.support.py_utils import ensure_list_arg, execute_in_parallel, execute_with_key, wait_until -class TestEnsureListArgU(unittest.TestCase): - @ensure_list_arg('arg') - def sample_function(self, arg): - return arg +@ensure_list_arg('arg') +def sample_function(arg): + return arg + + +def test_ensure_list_arg_with_list(): + """Wraps list args without altering the list.""" + # Arrange + input_arg = [1, 2, 3] + + # Act + result = sample_function(input_arg) + + # Assert + assert result == input_arg + + +def test_ensure_list_arg_with_non_list(): + """Wraps a non-list arg into a single-item list.""" + # Arrange + input_arg = 1 + + # Act + result = sample_function(input_arg) + + # Assert + assert result == [input_arg] + + +def test_ensure_list_arg_with_keyword_arg_list(): + """Preserves list input when passed as a keyword arg.""" + # Arrange + input_arg = [1, 2, 3] + + # Act + result = sample_function(arg=input_arg) + + # Assert + assert result == input_arg + - def test_ensure_list_arg_with_list(self): - input_arg = [1, 2, 3] - self.assertEqual(self.sample_function(input_arg), input_arg) +def test_ensure_list_arg_with_keyword_arg_non_list(): + """Wraps a non-list keyword arg into a single-item list.""" + # Arrange + input_arg = 1 - def test_ensure_list_arg_with_non_list(self): - input_arg = 1 - self.assertEqual(self.sample_function(input_arg), [input_arg]) + # Act + result = sample_function(arg=input_arg) - def test_ensure_list_arg_with_keyword_arg_list(self): - input_arg = [1, 2, 3] - self.assertEqual(self.sample_function(arg=input_arg), input_arg) + # Assert + assert result == [input_arg] - def test_ensure_list_arg_with_keyword_arg_non_list(self): - input_arg = 1 - self.assertEqual(self.sample_function(arg=input_arg), [input_arg]) - def test_ensure_list_arg_with_missing_arg(self): - with self.assertRaises(TypeError): - self.sample_function() +def test_ensure_list_arg_with_missing_arg(): + """Raises TypeError when the decorated arg is missing.""" + # Arrange + + # Act / Assert + with pytest.raises(TypeError): + sample_function() -class TestExecuteInParallelU(unittest.TestCase): - def _func(self, v1, v2): +@pytest.fixture +def parallel_setup(): + state = {'delay': 0} + + def _func(v1, v2): if v1 == 1: - time.sleep(self.delay) + time.sleep(state['delay']) return 'result1' elif v2 == 2: return 'result2' else: return 'unknown' - def setUp(self): - self.delay = 0 - self.func = MagicMock(side_effect=self._func) - self.func.__name__ = 'TEST_FUNCTION' - self.requests_dict = {'req1': {'args': [1, 0], 'kwargs': {}}, 'req2': {'args': [0], 'kwargs': {'v2': 2}}} - self.requests_list = [{'args': [1, 0], 'kwargs': {}}, {'args': [0], 'kwargs': {'v2': 2}}] - - def test_execute_in_parallel_with_dict(self): - results = execute_in_parallel(self.func, self.requests_dict) - self.assertEqual(results, {'req1': 'result1', 'req2': 'result2'}) - self.assertEqual(self.func.call_count, 2) - - def test_execute_in_parallel_with_list(self): - self.delay = 0.1 - results = execute_in_parallel(self.func, self.requests_list) - self.assertEqual(results, ['result1', 'result2']) - self.assertEqual(self.func.call_count, 2) - - def test_execute_with_key_success(self): - result = execute_with_key('key', self.func, 1, v2=2) - self.func.assert_called_with(1, v2=2) - self.assertEqual(result, ('key', 'result1')) - - def test_execute_with_key_exception(self): - self.func.side_effect = Exception('error') - result = execute_with_key('key', self.func, 1, v2=2) - self.assertIsInstance(result[1], Exception) - - def test_execute_in_parallel_rate_limiting(self): - start_time = time.time() - - # Simulate a slow function to test rate limiting - def slow_func(): - time.sleep(0.05) - return 'slow_result' - - requests = {i: {'args': [], 'kwargs': {}} for i in range(20)} # 10 requests - max_per_second = 10 # Limit to 5 requests per second - results = execute_in_parallel(slow_func, requests, max_per_second=max_per_second) - - duration = time.time() - start_time - self.assertGreaterEqual(duration, 1.05) # Should take at least 1.1 seconds to complete all requests - self.assertEqual(len(results), 20) - - -class TestWaitUntilU(unittest.TestCase): - def test_wait_until_condition_met(self): - condition = MagicMock(return_value=True) - self.assertTrue(wait_until(condition)) - condition.assert_called() - - def test_wait_until_condition_not_met(self): - condition = MagicMock(return_value=False) - self.assertFalse(wait_until(condition, timeout=0.1)) - condition.assert_called() - - @patch('ibind.support.py_utils._LOGGER.error') - def test_wait_until_timeout_message(self, mock_logger_error): - condition = MagicMock(return_value=False) - timeout_message = 'Condition not met within timeout' - self.assertFalse(wait_until(condition, timeout_message=timeout_message, timeout=0.1)) - mock_logger_error.assert_called_with(timeout_message) - - def test_wait_until_timeout(self): - start_time = time.time() - condition = MagicMock(return_value=False) - timeout = 0.1 - self.assertFalse(wait_until(condition, timeout=timeout)) - duration = time.time() - start_time - self.assertAlmostEqual(duration, timeout, delta=0.02) + func = MagicMock(side_effect=_func) + func.__name__ = 'TEST_FUNCTION' + requests_dict = {'req1': {'args': [1, 0], 'kwargs': {}}, 'req2': {'args': [0], 'kwargs': {'v2': 2}}} + requests_list = [{'args': [1, 0], 'kwargs': {}}, {'args': [0], 'kwargs': {'v2': 2}}] + + return { + 'state': state, + 'func': func, + 'requests_dict': requests_dict, + 'requests_list': requests_list, + } + + +def test_execute_in_parallel_with_dict(parallel_setup): + """Executes requests in parallel when passed a dict of requests.""" + # Arrange + func = parallel_setup['func'] + requests = parallel_setup['requests_dict'] + + # Act + results = execute_in_parallel(func, requests) + + # Assert + assert results == {'req1': 'result1', 'req2': 'result2'} + assert func.call_count == 2 + + +def test_execute_in_parallel_with_list(parallel_setup): + """Executes requests in parallel when passed a list of requests.""" + # Arrange + func = parallel_setup['func'] + requests = parallel_setup['requests_list'] + parallel_setup['state']['delay'] = 0.1 + + # Act + results = execute_in_parallel(func, requests) + + # Assert + assert results == ['result1', 'result2'] + assert func.call_count == 2 + + +def test_execute_with_key_success(parallel_setup): + """Returns (key, result) when the wrapped function succeeds.""" + # Arrange + func = parallel_setup['func'] + + # Act + result = execute_with_key('key', func, 1, v2=2) + + # Assert + func.assert_called_with(1, v2=2) + assert result == ('key', 'result1') + + +def test_execute_with_key_exception(parallel_setup): + """Returns (key, exception) when the wrapped function raises.""" + # Arrange + func = parallel_setup['func'] + func.side_effect = Exception('error') + + # Act + result = execute_with_key('key', func, 1, v2=2) + + # Assert + assert isinstance(result[1], Exception) + + +def test_execute_in_parallel_rate_limiting(): + """Applies max_per_second rate limiting across parallel executions.""" + # Arrange + start_time = time.time() + + # Simulate a slow function to test rate limiting + def slow_func(): + time.sleep(0.05) + return 'slow_result' + + requests = {i: {'args': [], 'kwargs': {}} for i in range(20)} # 10 requests + max_per_second = 10 # Limit to 5 requests per second + + # Act + results = execute_in_parallel(slow_func, requests, max_per_second=max_per_second) + + # Assert + duration = time.time() - start_time + assert duration >= 1.05 # Should take at least 1.1 seconds to complete all requests + assert len(results) == 20 + + +def test_wait_until_condition_met(): + """Returns True immediately when the condition is already met.""" + # Arrange + condition = MagicMock(return_value=True) + + # Act + result = wait_until(condition) + + # Assert + assert result is True + condition.assert_called() + + +def test_wait_until_condition_not_met(): + """Returns False when the condition is not met before timeout.""" + # Arrange + condition = MagicMock(return_value=False) + + # Act + result = wait_until(condition, timeout=0.1) + + # Assert + assert result is False + condition.assert_called() + + +def test_wait_until_timeout_message(mocker): + """Logs the timeout_message when the deadline is reached.""" + # Arrange + mock_logger_error = mocker.patch('ibind.support.py_utils._LOGGER.error') + condition = MagicMock(return_value=False) + timeout_message = 'Condition not met within timeout' + + # Act + result = wait_until(condition, timeout_message=timeout_message, timeout=0.1) + + # Assert + assert result is False + mock_logger_error.assert_called_with(timeout_message) + + +def test_wait_until_timeout(): + """Waits roughly the specified timeout duration before returning False.""" + # Arrange + start_time = time.time() + condition = MagicMock(return_value=False) + timeout = 0.1 + + # Act + result = wait_until(condition, timeout=timeout) + + # Assert + assert result is False + duration = time.time() - start_time + assert duration == pytest.approx(timeout, abs=0.02) \ No newline at end of file diff --git a/test/unit/support/test_py_utils_u_new.py b/test/unit/support/test_py_utils_u_new.py deleted file mode 100644 index 5fb4a250..00000000 --- a/test/unit/support/test_py_utils_u_new.py +++ /dev/null @@ -1,228 +0,0 @@ -import time -from unittest.mock import MagicMock - -import pytest - -from ibind.support.py_utils import ensure_list_arg, execute_in_parallel, execute_with_key, wait_until - - -@ensure_list_arg('arg') -def sample_function(arg): - return arg - - -def test_ensure_list_arg_with_list(): - """Wraps list args without altering the list.""" - # Arrange - input_arg = [1, 2, 3] - - # Act - result = sample_function(input_arg) - - # Assert - assert result == input_arg - - -def test_ensure_list_arg_with_non_list(): - """Wraps a non-list arg into a single-item list.""" - # Arrange - input_arg = 1 - - # Act - result = sample_function(input_arg) - - # Assert - assert result == [input_arg] - - -def test_ensure_list_arg_with_keyword_arg_list(): - """Preserves list input when passed as a keyword arg.""" - # Arrange - input_arg = [1, 2, 3] - - # Act - result = sample_function(arg=input_arg) - - # Assert - assert result == input_arg - - -def test_ensure_list_arg_with_keyword_arg_non_list(): - """Wraps a non-list keyword arg into a single-item list.""" - # Arrange - input_arg = 1 - - # Act - result = sample_function(arg=input_arg) - - # Assert - assert result == [input_arg] - - -def test_ensure_list_arg_with_missing_arg(): - """Raises TypeError when the decorated arg is missing.""" - # Arrange - - # Act / Assert - with pytest.raises(TypeError): - sample_function() - - -@pytest.fixture -def parallel_setup(): - state = {'delay': 0} - - def _func(v1, v2): - if v1 == 1: - time.sleep(state['delay']) - return 'result1' - elif v2 == 2: - return 'result2' - else: - return 'unknown' - - func = MagicMock(side_effect=_func) - func.__name__ = 'TEST_FUNCTION' - requests_dict = {'req1': {'args': [1, 0], 'kwargs': {}}, 'req2': {'args': [0], 'kwargs': {'v2': 2}}} - requests_list = [{'args': [1, 0], 'kwargs': {}}, {'args': [0], 'kwargs': {'v2': 2}}] - - return { - 'state': state, - 'func': func, - 'requests_dict': requests_dict, - 'requests_list': requests_list, - } - - -def test_execute_in_parallel_with_dict(parallel_setup): - """Executes requests in parallel when passed a dict of requests.""" - # Arrange - func = parallel_setup['func'] - requests = parallel_setup['requests_dict'] - - # Act - results = execute_in_parallel(func, requests) - - # Assert - assert results == {'req1': 'result1', 'req2': 'result2'} - assert func.call_count == 2 - - -def test_execute_in_parallel_with_list(parallel_setup): - """Executes requests in parallel when passed a list of requests.""" - # Arrange - func = parallel_setup['func'] - requests = parallel_setup['requests_list'] - parallel_setup['state']['delay'] = 0.1 - - # Act - results = execute_in_parallel(func, requests) - - # Assert - assert results == ['result1', 'result2'] - assert func.call_count == 2 - - -def test_execute_with_key_success(parallel_setup): - """Returns (key, result) when the wrapped function succeeds.""" - # Arrange - func = parallel_setup['func'] - - # Act - result = execute_with_key('key', func, 1, v2=2) - - # Assert - func.assert_called_with(1, v2=2) - assert result == ('key', 'result1') - - -def test_execute_with_key_exception(parallel_setup): - """Returns (key, exception) when the wrapped function raises.""" - # Arrange - func = parallel_setup['func'] - func.side_effect = Exception('error') - - # Act - result = execute_with_key('key', func, 1, v2=2) - - # Assert - assert isinstance(result[1], Exception) - - -def test_execute_in_parallel_rate_limiting(): - """Applies max_per_second rate limiting across parallel executions.""" - # Arrange - start_time = time.time() - - # Simulate a slow function to test rate limiting - def slow_func(): - time.sleep(0.05) - return 'slow_result' - - requests = {i: {'args': [], 'kwargs': {}} for i in range(20)} # 10 requests - max_per_second = 10 # Limit to 5 requests per second - - # Act - results = execute_in_parallel(slow_func, requests, max_per_second=max_per_second) - - # Assert - duration = time.time() - start_time - assert duration >= 1.05 # Should take at least 1.1 seconds to complete all requests - assert len(results) == 20 - - -def test_wait_until_condition_met(): - """Returns True immediately when the condition is already met.""" - # Arrange - condition = MagicMock(return_value=True) - - # Act - result = wait_until(condition) - - # Assert - assert result is True - condition.assert_called() - - -def test_wait_until_condition_not_met(): - """Returns False when the condition is not met before timeout.""" - # Arrange - condition = MagicMock(return_value=False) - - # Act - result = wait_until(condition, timeout=0.1) - - # Assert - assert result is False - condition.assert_called() - - -def test_wait_until_timeout_message(mocker): - """Logs the timeout_message when the deadline is reached.""" - # Arrange - mock_logger_error = mocker.patch('ibind.support.py_utils._LOGGER.error') - condition = MagicMock(return_value=False) - timeout_message = 'Condition not met within timeout' - - # Act - result = wait_until(condition, timeout_message=timeout_message, timeout=0.1) - - # Assert - assert result is False - mock_logger_error.assert_called_with(timeout_message) - - -def test_wait_until_timeout(): - """Waits roughly the specified timeout duration before returning False.""" - # Arrange - start_time = time.time() - condition = MagicMock(return_value=False) - timeout = 0.1 - - # Act - result = wait_until(condition, timeout=timeout) - - # Assert - assert result is False - duration = time.time() - start_time - assert duration == pytest.approx(timeout, abs=0.02) \ No newline at end of file From 22b8ac56aff045858670e7c5c93d6263db7331b1 Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 24 Dec 2025 11:39:16 +0100 Subject: [PATCH 23/31] test: replaced test_utils with test_utils_new --- test/integration/base/test_rest_client_i.py | 2 +- .../base/test_websocket_client_i.py | 2 +- test/integration/client/test_ibkr_client_i.py | 2 +- test/integration/client/test_ibkr_utils_i.py | 2 +- .../client/test_ibkr_ws_client_i.py | 2 +- test/test_utils.py | 505 +++++++++--------- test/test_utils_new.py | 303 ----------- 7 files changed, 267 insertions(+), 551 deletions(-) delete mode 100644 test/test_utils_new.py diff --git a/test/integration/base/test_rest_client_i.py b/test/integration/base/test_rest_client_i.py index 1feda7b8..c7effe2d 100644 --- a/test/integration/base/test_rest_client_i.py +++ b/test/integration/base/test_rest_client_i.py @@ -11,7 +11,7 @@ from ibind.support.errors import ExternalBrokerError from ibind.base.rest_client import Result, RestClient from ibind.support.logs import ibind_logs_initialize -from test.test_utils_new import CaptureLogsContext +from test.test_utils import CaptureLogsContext _URL = 'https://localhost:5000' diff --git a/test/integration/base/test_websocket_client_i.py b/test/integration/base/test_websocket_client_i.py index aa3d1221..4237b648 100644 --- a/test/integration/base/test_websocket_client_i.py +++ b/test/integration/base/test_websocket_client_i.py @@ -7,7 +7,7 @@ from ibind.base.ws_client import WsClient from ibind.support.py_utils import tname from integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock -from test_utils_new import capture_logs +from test_utils import capture_logs _URL = 'wss://localhost:5000/v1/api/ws' _MAX_RECONNECT_ATTEMPTS = 4 diff --git a/test/integration/client/test_ibkr_client_i.py b/test/integration/client/test_ibkr_client_i.py index b29c1a0f..986a70ed 100644 --- a/test/integration/client/test_ibkr_client_i.py +++ b/test/integration/client/test_ibkr_client_i.py @@ -11,7 +11,7 @@ from ibind.support.errors import ExternalBrokerError from ibind.support.logs import ibind_logs_initialize from integration.client import ibkr_responses -from test_utils_new import CaptureLogsContext +from test_utils import CaptureLogsContext _URL = 'https://localhost:5000' diff --git a/test/integration/client/test_ibkr_utils_i.py b/test/integration/client/test_ibkr_utils_i.py index 2eb92395..e03569b7 100644 --- a/test/integration/client/test_ibkr_utils_i.py +++ b/test/integration/client/test_ibkr_utils_i.py @@ -15,7 +15,7 @@ parse_order_request, ) from test.integration.client import ibkr_responses -from test.test_utils_new import CaptureLogsContext +from test.test_utils import CaptureLogsContext # -------------------------------------------------------------------------------------- diff --git a/test/integration/client/test_ibkr_ws_client_i.py b/test/integration/client/test_ibkr_ws_client_i.py index 60da7ab7..4932cb94 100644 --- a/test/integration/client/test_ibkr_ws_client_i.py +++ b/test/integration/client/test_ibkr_ws_client_i.py @@ -10,7 +10,7 @@ from ibind.client.ibkr_client import IbkrClient from ibind.client.ibkr_ws_client import IbkrWsClient, IbkrSubscriptionProcessor, IbkrWsKey from integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock -from test_utils_new import capture_logs +from test_utils import capture_logs _URL_WS = 'wss://localhost:5000/v1/api/ws' _URL_REST = 'https://localhost:5000' diff --git a/test/test_utils.py b/test/test_utils.py index d8523004..8fb7a233 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,284 +1,303 @@ import functools +import inspect import logging -import sys +import os import traceback -import types -import unittest -from unittest import TestCase -from unittest._log import _CapturingHandler, _AssertLogsContext +from pathlib import Path +from typing import List, TypeVar -from ibind.support.py_utils import make_clean_stack +from ibind.support.logs import get_logger_children +from ibind.support.py_utils import make_clean_stack, OneOrMany, UNDEFINED +_NAME_TO_LEVEL = logging.getLevelNamesMapping() -def raise_from_context(cm, level='WARNING'): - for record in cm.records: - if record.levelno >= getattr(logging, level): - raise RuntimeError(record.message) +# --- New Functions and Types --- +def accepts_kwargs(func): + """Returns True if func accepts **kwargs, else False.""" + sig = inspect.signature(func) + for param in sig.parameters.values(): + if param.kind == inspect.Parameter.VAR_KEYWORD: + return True + return False -def verify_log(test_case: TestCase, cm, expected_messages, comparison: callable = lambda x, y: x == y): - messages = [record.msg for record in cm.records] - missing_expected = expected_messages.copy() - for i, expected_msg in enumerate(expected_messages): - for msg in messages: - if comparison(expected_msg, msg): - missing_expected.remove(expected_msg) - break - - if missing_expected: - test_case.fail('Expected log(s) not found:\n\t{}'.format('\n\t'.join(missing_expected))) - - -def verify_log_simple(test_self, cm, expected_messages): - for i, msg in enumerate(expected_messages): - test_self.assertEqual(msg, cm.records[i].msg) - - -def exact_log(test_case, cm, expected_messages): - test_case.assertEqual(expected_messages, [record.msg for record in cm.records]) - +# --- Logging Utilities --- -class SafeAssertLogs(_AssertLogsContext): - """ - The self.assertLogs context manager, that sets log level on the handler instead of logger. +class LoggingWatcher: + """Helper class for capturing and asserting logs during testing.""" - Original docstring: - A context manager used to implement TestCase.assertLogs(). - """ + def __init__(self, logger): + self.logger = logger + self.records = [] + self.output = [] - def __init__(self, *args, logger_level: str = None, **kwargs): - if sys.version_info < (3, 10, 0) and 'no_logs' in kwargs: - del kwargs['no_logs'] + def _process_logs(self, expected_messages: OneOrMany[str], comparison: callable = lambda x, y: x == y): + if not isinstance(expected_messages, list): + expected_messages = [expected_messages] - super().__init__(*args, **kwargs) - self.logger_level = logger_level + if not self.output: + return [], expected_messages - def __enter__(self, include_original_handlers: bool = False): - if isinstance(self.logger_name, logging.Logger): - logger = self.logger = self.logger_name - else: - logger = self.logger = logging.getLogger(self.logger_name) - formatter = logging.Formatter(self.LOGGING_FORMAT) - handler = _CapturingHandler() - handler.setFormatter(formatter) - self.watcher = handler.watcher - self.old_handlers = logger.handlers[:] - self.old_level = logger.level - self.old_propagate = logger.propagate - logger.handlers = [handler] - handler.setLevel(self.level) # this one line is different, originally was `logger.setLevel` - logger.propagate = False - if self.logger_level is not None: - logger.setLevel(getattr(logging, self.logger_level)) - - if include_original_handlers: - logger.handlers += self.old_handlers - logger.propagate = True - return handler.watcher - - -def get_logger_children(main_logger) -> list[logging.Logger]: - """ - Gets child loggers. Added as a support compat for Python version 3.11 and below. - Source: https://github.com/python/cpython/blob/3.12/Lib/logging/__init__.py#L1831 - """ - - def _hierlevel(logger): - if logger is logger.manager.root: - return 0 - return 1 + logger.name.count('.') - - d = main_logger.manager.loggerDict - # exclude PlaceHolders - the last check is to ensure that lower-level - # descendants aren't returned - if there are placeholders, a logger's - # parent field might point to a grandparent or ancestor thereof. - return [ - item - for item in d.values() - if isinstance(item, logging.Logger) and item.parent is main_logger and _hierlevel(item) == 1 + _hierlevel(item.parent) - ] - - -class RaiseLogsContext: - """ - Captures log messages at or above a specified level and raises unexpected ones as exceptions. - - This context manager monitors log messages from a specified logger. Any log messages - at or above the given logging level are recorded. If a message is not explicitly - expected, a `RuntimeError` is raised, including the stack trace of the log call. It ensures - loggers are restored to their original state after use. - - Note: - - When used in conjunction with `self.assertLogs` or `SafeAssertLogs`, ensure this context manager is defined last to properly assert log expectations. - - Args: - test_case (TestCase): The test case instance, typically from `unittest.TestCase`. - logger_name (str | None): The name of the logger to monitor. Defaults to the root logger. - level (str): The logging level threshold (e.g., 'ERROR', 'WARNING'). Logs at or above this level are captured. - expected_errors (list[str] | None): A list of log messages that are expected and should not trigger an exception. - comparison (Callable[[str, str], bool]): A function to compare expected errors with log messages. - Defaults to an exact string match (`lambda x, y: x == y`). - - Example Usage: - >>> with RaiseLogsContext(self, logger_name='my_logger', level='WARNING', expected_errors=['My expected warning']): - ... logging.getLogger('my_logger').warning('My expected warning') # No error - ... logging.getLogger('my_logger').error('Unexpected issue') # Raises RuntimeError - """ + messages = [msg for msg in self.output] + missing_expected = expected_messages.copy() + found = [] + for i, expected_msg in enumerate(expected_messages): + for msg in messages: + if comparison(expected_msg, msg): + found.append(msg) + missing_expected.remove(expected_msg) + break + return found, missing_expected + + def exact_log(self, expected_messages: OneOrMany[str]): + """Assert that all expected messages appear in the captured logs.""" + found, missing_expected = self._process_logs(expected_messages, lambda x, y: x == y) + if len(missing_expected) > 0: + raise AssertionError(f"Expected exact log(s) not found:\n\t{'\n\t'.join(missing_expected)}\n\nActual logs:\n{self.format_logs()}\n") + + def partial_log(self, expected_messages: OneOrMany[str]): + """Assert that each expected message is a substring of at least one captured log message.""" + found, missing_expected = self._process_logs(expected_messages, lambda x, y: x in y) + if len(missing_expected) > 0: + raise AssertionError(f"Expected partial log(s) not found:\n\t{'\n\t'.join(missing_expected)}\n\nActual logs:\n{self.format_logs()}\n") + + def log_excludes(self, expected_messages: OneOrMany[str]): + """Assert that none of the expected messages appear in any captured log message.""" + found, _ = self._process_logs(expected_messages, lambda x, y: x in y) + if found: + raise AssertionError(f"Unexpected log(s) found:\n\t{'\n\t'.join(found)}\n\nCurrent logs:\n{self.format_logs()}\n") + + def format_logs(self): + """Return a formatted string of all captured log messages.""" + return f"\n{self} captured {len(self.output)} logs:\n[\n\t{'\n\t'.join(self.output)}\n]" + + def count_occurrences(self, msg: str): + """Count the number of occurrences of a message in the captured logs.""" + return sum(1 for log in self.output if msg in log) + + def print(self): + """Print the formatted logs.""" + print(self.format_logs()) + + def __str__(self): + return f'LoggingWatcher({self.logger.name})' + +class _CapturingHandler(logging.Handler): + """A logging handler capturing all (raw and formatted) logging output.""" + def __init__(self, logger): + logging.Handler.__init__(self) + self.watcher = LoggingWatcher(logger) + + def flush(self): + pass + + def emit(self, record): + self.watcher.records.append(record) + msg = self.format(record) + self.watcher.output.append(msg) + +class CaptureLogsContext: + LOGGING_FORMAT = "%(message)s" def __init__( self, - test_case: TestCase, - logger_name=None, - level='ERROR', - expected_errors: [str] = None, - comparison: callable = lambda x, y: x == y, + logger='ibind', + level='DEBUG', + logger_level: str = None, + error_level='WARNING', + no_logs=UNDEFINED, + expected_errors=None, + partial_match=False, + attach_stack=True, ): - self._test_case = test_case - self._logger_name = logger_name - self._level = level - self._level_no = getattr(logging, level) - if expected_errors is None: - expected_errors = [] - self._expected_errors = expected_errors - self._comparison = comparison - - def monkey_patch_log(self, original_method): - """Wraps a logger method to attach a manually captured stack trace to log records.""" - - def new_method(msg, *args, **kwargs): - # Store the manually captured stack trace in the log record - stack = make_clean_stack() - if 'extra' not in kwargs: - kwargs['extra'] = {} - kwargs['extra']['manual_trace'] = stack - - # Call the original logging method with the modified arguments - return original_method(msg, *args, **kwargs) - - return new_method - - def monkey_patch_loggers(self, loggers): - """Monkey-patches loggers to attach a stack trace to warning and error messages.""" + self._logger = logger + self.level = getattr(logging, level) if isinstance(level, str) else level + self.logger_level = getattr(logging, logger_level) if isinstance(logger_level, str) else logger_level + self.no_logs = no_logs + self.expected_errors = expected_errors or [] + self.partial_match = partial_match + self.comparison = (lambda x, y: x in y) if partial_match else (lambda x, y: x == y) + self.attach_stack = attach_stack + self.error_level = getattr(logging, error_level) if isinstance(error_level, str) else (error_level if error_level is not None else self.level) + if not isinstance(self.expected_errors, list): + self.expected_errors = [self.expected_errors] + + def _monkey_patch_log(self, logger): + original_log = logger._log + def new_log(level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1): + if extra is None: + extra = {} + extra['manual_trace'] = make_clean_stack()[:-2] + + return original_log(level, msg, args, exc_info, extra, stack_info, stacklevel) + + logger.__old_log_method__ = original_log + logger._log = new_log + + def _monkey_patch_loggers(self, loggers): for logger in loggers: - if self._level_no <= logging.ERROR: - logger.__old_error_method__ = logger.error - logger.error = self.monkey_patch_log(logger.error) + self._monkey_patch_log(logger) - if self._level_no <= logging.WARNING: - logger.__old_warning_method__ = logger.warning - logger.warning = self.monkey_patch_log(logger.warning) - - def restore_loggers(self, loggers): - """Restores the original error and warning logging methods after patching.""" + def _restore_loggers(self, loggers): for logger in loggers: - if self._level_no <= logging.ERROR: - logger.error = logger.__old_error_method__ # Restore the original error method + if hasattr(logger, '__old_log_method__'): + logger._log = logger.__old_log_method__ - if self._level_no <= logging.WARNING: - logger.warning = logger.__old_warning_method__ # Restore the original warning method + def logger_name(self): + return self._logger.name if isinstance(self._logger, logging.Logger) else self._logger - def __enter__(self): - """ - Initializes the logging context by patching loggers and setting up a log watcher. + def acquire(self) -> LoggingWatcher: + self.logger = logging.getLogger(self.logger_name()) + self.old_handlers = self.logger.handlers[:] + self.old_level = self.logger.level + self.old_propagate = self.logger.propagate - This method ensures that logs at the specified level are captured and asserts - that unexpected log messages are raised as errors. - """ + formatter = logging.Formatter(self.LOGGING_FORMAT, datefmt='%H:%M:%S') + handler = _CapturingHandler(self.logger) + handler.setFormatter(formatter) + self.watcher = handler.watcher + self.logger.handlers = [handler] + handler.setLevel(self.level) + self.logger.propagate = False + if self.logger_level is not None: + self.logger.setLevel(self.logger_level) - self._logger = logging.getLogger(self._logger_name) - loggers_to_be_patched = [self._logger] + get_logger_children(self._logger) - self.monkey_patch_loggers(loggers_to_be_patched) # Apply monkey-patching to attach stack traces to logged messages + if self.attach_stack: + loggers_to_patch = [self.logger] + get_logger_children(self.logger) + self._monkey_patch_loggers(loggers_to_patch) + self._loggers_to_patch = loggers_to_patch + else: + self._loggers_to_patch = [] - # Initialize SafeAssertLogs, a helper to capture and assert log records - self._context_manager = SafeAssertLogs(self._test_case, self._logger, level=self._level, no_logs=False) + return self.watcher - # Enter the SafeAssertLogs context, starting log capture and returning the watcher - self._watcher = self._context_manager.__enter__(include_original_handlers=True) - return self._watcher + def _raise_unexpected_log(self, record): + if hasattr(record, 'manual_trace'): + raise RuntimeError(f'\n{"".join(traceback.format_list(record.manual_trace))}Logger {self.logger} logged an unexpected message:\n{record.msg}') + raise RuntimeError(f'\n...\nFile "{record.pathname}", line {record.lineno} in {record.funcName}\n{record.msg}') - def __exit__(self, exc_type, exc_val, exc_tb): - """ - Restores original logger methods and verifies captured log messages. - - This method is called when exiting the context manager. It ensures that: - - Monkey-patched loggers are restored to their original state. - - If an exception occurred inside the `with` block, it is propagated normally. - - If no exception occurred, all captured log messages are checked against expected errors. - - Unexpected log messages result in a `RuntimeError`. - """ - - # Restore original logging methods that were monkey-patched - loggers_to_be_patched = [self._logger] + get_logger_children(self._logger) - self.restore_loggers(loggers_to_be_patched) - - # If an exception occurred inside the 'with' block, return False to let Python re-raise it - if exc_type is not None: - return False - - # If no logs were captured return True to indicate that no errors were encountered and that the context exited cleanly - if len(self._watcher.records) == 0: + def _process_exit_logs(self): + records = self.watcher.records + if self.no_logs is not UNDEFINED and self.no_logs: + if records: + self._raise_unexpected_log(records[0]) return True - for record in self._watcher.records: - found = False - - # Check if the log message matches any of the expected error messages - for expected_error in self._expected_errors: - if self._comparison(expected_error, record.msg): - found = True - break + if self.no_logs is not UNDEFINED and not records: + raise AssertionError(f"no logs of level {logging.getLevelName(self.level)} or higher triggered on {self.logger.name}") - # If the message is expected, move on to the next record - if found: + for record in records: + if record.levelno < self.error_level: continue + if any(self.comparison(expected, record.msg) for expected in self.expected_errors): + continue + self._raise_unexpected_log(record) - # If the log record has a manually stored traceback, raise an error with that traceback - if hasattr(record, 'manual_trace'): - raise RuntimeError( - '\n' + ''.join(traceback.format_list(record.manual_trace)) + f'Logger {self._logger} logged an unexpected message:\n{record.msg}' - ) - - # Otherwise, raise an error using the log record's location - raise RuntimeError(f'\n...\nFile "{record.pathname}", line {record.lineno} in {record.funcName}\n{record.msg}') + if self.partial_match: + self.watcher.partial_log(self.expected_errors) + else: + self.watcher.exact_log(self.expected_errors) + def release(self, exc_type=None, exc_val=None, exc_tb=None): + self.logger.handlers = self.old_handlers + self.logger.propagate = self.old_propagate + self.logger.setLevel(self.old_level) + if self._loggers_to_patch: + self._restore_loggers(self._loggers_to_patch) + self._process_exit_logs() + return exc_type is None -def raise_logs(level='ERROR', logger_name=None): - def _wrapper(fn): - @functools.wraps(fn) - def wrapper(self, *args, **kwargs): - with RaiseLogsContext(self, level=level, logger_name=logger_name): - return fn(self, *args, **kwargs) + def __enter__(self) -> LoggingWatcher: + return self.acquire() + def __exit__(self, exc_type, exc_val, exc_tb): + return self.release(exc_type, exc_val, exc_tb) + +def capture_logs(**ctx_kwargs): + def decorator(test_func): + @functools.wraps(test_func) + def wrapper(*args, **kwargs): + capture_log_context = CaptureLogsContext(**ctx_kwargs) + logger_name = f'_cm_{capture_log_context.logger_name()}' + fn_exc = None + log_exc = None + + cm = capture_log_context.acquire() + if accepts_kwargs(test_func): + kwargs[logger_name] = cm + + try: + rv = test_func(*args, **kwargs) + except Exception as e: + rv = None + fn_exc = e + + try: + capture_log_context.release() + except Exception as e2: + log_exc = e2 + + if fn_exc is not None: + if log_exc is not None: + print('Unexpected log found in test:') + traceback.print_exception(log_exc) + raise fn_exc + elif log_exc is not None: + raise log_exc + + return rv return wrapper + return decorator - return _wrapper - - -def decorate_methods(decorator, starts_with=''): - class DecorateMethods(type): - """Decorate all methods of the class with the decorator provided""" - - def __new__(cls, name, bases, attrs, **kwargs): - exclude = kwargs.get('exclude', []) - - for attr_name, attr_value in attrs.items(): - if ( - isinstance(attr_value, types.FunctionType) - and attr_name.startswith(starts_with) - and attr_name not in exclude - and not hasattr(attr_value, '__exclude_decorator__') - and not attr_name.startswith('__') - ): - attrs[attr_name] = decorator(attr_value) - - return super(DecorateMethods, cls).__new__(cls, name, bases, attrs) - - return DecorateMethods +# --- Time Mocking Utilities --- +class MockTimeController: + def __init__(self, target_module, time_sequence=None, start_time=0.0): + self.target_module = target_module + if time_sequence is not None: + self.time_sequence = list(time_sequence) + self.call_index = 0 + else: + self.time_sequence = None + self.current_time = start_time + self.original_time_module = None + + def advance_time(self, seconds): + if self.time_sequence is not None: + raise ValueError("Cannot advance time when using time_sequence.") + self.current_time += seconds + + def set_time(self, time_value): + if self.time_sequence is not None: + raise ValueError("Cannot set time when using time_sequence.") + self.current_time = time_value + + def mock_time(self): + if self.time_sequence is not None: + if self.call_index < len(self.time_sequence): + time_value = self.time_sequence[self.call_index] + self.call_index += 1 + return time_value + else: + return self.time_sequence[-1] + else: + return self.current_time -class TestCaseWithRaiseLogs(unittest.TestCase, metaclass=decorate_methods(raise_logs(logger_name='ibind'), starts_with='test')): ... + def __enter__(self): + target_module_obj = __import__(self.target_module, fromlist=['']) + self.original_time_module = target_module_obj.time + class MockTimeModule: + def __init__(self, original_module, mock_time_func): + self.original_module = original_module + self.time = mock_time_func + def __getattr__(self, name): + return getattr(self.original_module, name) + target_module_obj.time = MockTimeModule(self.original_time_module, self.mock_time) + self.target_module_obj = target_module_obj + return self + def __exit__(self, exc_type, exc_val, exc_tb): + self.target_module_obj.time = self.original_time_module -def exclude_decorator(fn): - fn.__exclude_decorator__ = True - return fn +def mock_module_time(target_module, time_sequence=None, start_time=0.0): + return MockTimeController(target_module, time_sequence=time_sequence, start_time=start_time) \ No newline at end of file diff --git a/test/test_utils_new.py b/test/test_utils_new.py deleted file mode 100644 index 8fb7a233..00000000 --- a/test/test_utils_new.py +++ /dev/null @@ -1,303 +0,0 @@ -import functools -import inspect -import logging -import os -import traceback -from pathlib import Path -from typing import List, TypeVar - -from ibind.support.logs import get_logger_children -from ibind.support.py_utils import make_clean_stack, OneOrMany, UNDEFINED - -_NAME_TO_LEVEL = logging.getLevelNamesMapping() - -# --- New Functions and Types --- - -def accepts_kwargs(func): - """Returns True if func accepts **kwargs, else False.""" - sig = inspect.signature(func) - for param in sig.parameters.values(): - if param.kind == inspect.Parameter.VAR_KEYWORD: - return True - return False - -# --- Logging Utilities --- - -class LoggingWatcher: - """Helper class for capturing and asserting logs during testing.""" - - def __init__(self, logger): - self.logger = logger - self.records = [] - self.output = [] - - def _process_logs(self, expected_messages: OneOrMany[str], comparison: callable = lambda x, y: x == y): - if not isinstance(expected_messages, list): - expected_messages = [expected_messages] - - if not self.output: - return [], expected_messages - - messages = [msg for msg in self.output] - missing_expected = expected_messages.copy() - found = [] - for i, expected_msg in enumerate(expected_messages): - for msg in messages: - if comparison(expected_msg, msg): - found.append(msg) - missing_expected.remove(expected_msg) - break - return found, missing_expected - - def exact_log(self, expected_messages: OneOrMany[str]): - """Assert that all expected messages appear in the captured logs.""" - found, missing_expected = self._process_logs(expected_messages, lambda x, y: x == y) - if len(missing_expected) > 0: - raise AssertionError(f"Expected exact log(s) not found:\n\t{'\n\t'.join(missing_expected)}\n\nActual logs:\n{self.format_logs()}\n") - - def partial_log(self, expected_messages: OneOrMany[str]): - """Assert that each expected message is a substring of at least one captured log message.""" - found, missing_expected = self._process_logs(expected_messages, lambda x, y: x in y) - if len(missing_expected) > 0: - raise AssertionError(f"Expected partial log(s) not found:\n\t{'\n\t'.join(missing_expected)}\n\nActual logs:\n{self.format_logs()}\n") - - def log_excludes(self, expected_messages: OneOrMany[str]): - """Assert that none of the expected messages appear in any captured log message.""" - found, _ = self._process_logs(expected_messages, lambda x, y: x in y) - if found: - raise AssertionError(f"Unexpected log(s) found:\n\t{'\n\t'.join(found)}\n\nCurrent logs:\n{self.format_logs()}\n") - - def format_logs(self): - """Return a formatted string of all captured log messages.""" - return f"\n{self} captured {len(self.output)} logs:\n[\n\t{'\n\t'.join(self.output)}\n]" - - def count_occurrences(self, msg: str): - """Count the number of occurrences of a message in the captured logs.""" - return sum(1 for log in self.output if msg in log) - - def print(self): - """Print the formatted logs.""" - print(self.format_logs()) - - def __str__(self): - return f'LoggingWatcher({self.logger.name})' - -class _CapturingHandler(logging.Handler): - """A logging handler capturing all (raw and formatted) logging output.""" - def __init__(self, logger): - logging.Handler.__init__(self) - self.watcher = LoggingWatcher(logger) - - def flush(self): - pass - - def emit(self, record): - self.watcher.records.append(record) - msg = self.format(record) - self.watcher.output.append(msg) - -class CaptureLogsContext: - LOGGING_FORMAT = "%(message)s" - - def __init__( - self, - logger='ibind', - level='DEBUG', - logger_level: str = None, - error_level='WARNING', - no_logs=UNDEFINED, - expected_errors=None, - partial_match=False, - attach_stack=True, - ): - self._logger = logger - self.level = getattr(logging, level) if isinstance(level, str) else level - self.logger_level = getattr(logging, logger_level) if isinstance(logger_level, str) else logger_level - self.no_logs = no_logs - self.expected_errors = expected_errors or [] - self.partial_match = partial_match - self.comparison = (lambda x, y: x in y) if partial_match else (lambda x, y: x == y) - self.attach_stack = attach_stack - self.error_level = getattr(logging, error_level) if isinstance(error_level, str) else (error_level if error_level is not None else self.level) - if not isinstance(self.expected_errors, list): - self.expected_errors = [self.expected_errors] - - def _monkey_patch_log(self, logger): - original_log = logger._log - def new_log(level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1): - if extra is None: - extra = {} - extra['manual_trace'] = make_clean_stack()[:-2] - - return original_log(level, msg, args, exc_info, extra, stack_info, stacklevel) - - logger.__old_log_method__ = original_log - logger._log = new_log - - def _monkey_patch_loggers(self, loggers): - for logger in loggers: - self._monkey_patch_log(logger) - - def _restore_loggers(self, loggers): - for logger in loggers: - if hasattr(logger, '__old_log_method__'): - logger._log = logger.__old_log_method__ - - def logger_name(self): - return self._logger.name if isinstance(self._logger, logging.Logger) else self._logger - - def acquire(self) -> LoggingWatcher: - self.logger = logging.getLogger(self.logger_name()) - self.old_handlers = self.logger.handlers[:] - self.old_level = self.logger.level - self.old_propagate = self.logger.propagate - - formatter = logging.Formatter(self.LOGGING_FORMAT, datefmt='%H:%M:%S') - handler = _CapturingHandler(self.logger) - handler.setFormatter(formatter) - self.watcher = handler.watcher - self.logger.handlers = [handler] - handler.setLevel(self.level) - self.logger.propagate = False - if self.logger_level is not None: - self.logger.setLevel(self.logger_level) - - if self.attach_stack: - loggers_to_patch = [self.logger] + get_logger_children(self.logger) - self._monkey_patch_loggers(loggers_to_patch) - self._loggers_to_patch = loggers_to_patch - else: - self._loggers_to_patch = [] - - return self.watcher - - def _raise_unexpected_log(self, record): - if hasattr(record, 'manual_trace'): - raise RuntimeError(f'\n{"".join(traceback.format_list(record.manual_trace))}Logger {self.logger} logged an unexpected message:\n{record.msg}') - raise RuntimeError(f'\n...\nFile "{record.pathname}", line {record.lineno} in {record.funcName}\n{record.msg}') - - def _process_exit_logs(self): - records = self.watcher.records - if self.no_logs is not UNDEFINED and self.no_logs: - if records: - self._raise_unexpected_log(records[0]) - return True - - if self.no_logs is not UNDEFINED and not records: - raise AssertionError(f"no logs of level {logging.getLevelName(self.level)} or higher triggered on {self.logger.name}") - - for record in records: - if record.levelno < self.error_level: - continue - if any(self.comparison(expected, record.msg) for expected in self.expected_errors): - continue - self._raise_unexpected_log(record) - - if self.partial_match: - self.watcher.partial_log(self.expected_errors) - else: - self.watcher.exact_log(self.expected_errors) - - def release(self, exc_type=None, exc_val=None, exc_tb=None): - self.logger.handlers = self.old_handlers - self.logger.propagate = self.old_propagate - self.logger.setLevel(self.old_level) - if self._loggers_to_patch: - self._restore_loggers(self._loggers_to_patch) - self._process_exit_logs() - return exc_type is None - - def __enter__(self) -> LoggingWatcher: - return self.acquire() - - def __exit__(self, exc_type, exc_val, exc_tb): - return self.release(exc_type, exc_val, exc_tb) - -def capture_logs(**ctx_kwargs): - def decorator(test_func): - @functools.wraps(test_func) - def wrapper(*args, **kwargs): - capture_log_context = CaptureLogsContext(**ctx_kwargs) - logger_name = f'_cm_{capture_log_context.logger_name()}' - fn_exc = None - log_exc = None - - cm = capture_log_context.acquire() - if accepts_kwargs(test_func): - kwargs[logger_name] = cm - - try: - rv = test_func(*args, **kwargs) - except Exception as e: - rv = None - fn_exc = e - - try: - capture_log_context.release() - except Exception as e2: - log_exc = e2 - - if fn_exc is not None: - if log_exc is not None: - print('Unexpected log found in test:') - traceback.print_exception(log_exc) - raise fn_exc - elif log_exc is not None: - raise log_exc - - return rv - return wrapper - return decorator - -# --- Time Mocking Utilities --- - -class MockTimeController: - def __init__(self, target_module, time_sequence=None, start_time=0.0): - self.target_module = target_module - if time_sequence is not None: - self.time_sequence = list(time_sequence) - self.call_index = 0 - else: - self.time_sequence = None - self.current_time = start_time - self.original_time_module = None - - def advance_time(self, seconds): - if self.time_sequence is not None: - raise ValueError("Cannot advance time when using time_sequence.") - self.current_time += seconds - - def set_time(self, time_value): - if self.time_sequence is not None: - raise ValueError("Cannot set time when using time_sequence.") - self.current_time = time_value - - def mock_time(self): - if self.time_sequence is not None: - if self.call_index < len(self.time_sequence): - time_value = self.time_sequence[self.call_index] - self.call_index += 1 - return time_value - else: - return self.time_sequence[-1] - else: - return self.current_time - - def __enter__(self): - target_module_obj = __import__(self.target_module, fromlist=['']) - self.original_time_module = target_module_obj.time - class MockTimeModule: - def __init__(self, original_module, mock_time_func): - self.original_module = original_module - self.time = mock_time_func - def __getattr__(self, name): - return getattr(self.original_module, name) - target_module_obj.time = MockTimeModule(self.original_time_module, self.mock_time) - self.target_module_obj = target_module_obj - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.target_module_obj.time = self.original_time_module - -def mock_module_time(target_module, time_sequence=None, start_time=0.0): - return MockTimeController(target_module, time_sequence=time_sequence, start_time=start_time) \ No newline at end of file From 1089a17dea1ea79b897a63ea0dcc7c5cb2595652 Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 24 Dec 2025 11:45:37 +0100 Subject: [PATCH 24/31] chore: removed coverage report files --- cov_new.txt | 32 -------------------------------- cov_old.txt | 32 -------------------------------- coverage_new.txt | 44 -------------------------------------------- coverage_old.txt | 44 -------------------------------------------- 4 files changed, 152 deletions(-) delete mode 100644 cov_new.txt delete mode 100644 cov_old.txt delete mode 100644 coverage_new.txt delete mode 100644 coverage_old.txt diff --git a/cov_new.txt b/cov_new.txt deleted file mode 100644 index b69eaaed..00000000 --- a/cov_new.txt +++ /dev/null @@ -1,32 +0,0 @@ ----------- coverage: platform win32, python 3.13.11-final-0 ---------- -Name Stmts Miss Cover Missing ------------------------------------------------------------------------------------ -ibind\__init__.py 13 0 100% -ibind\base\__init__.py 0 0 100% -ibind\base\queue_controller.py 18 0 100% -ibind\base\rest_client.py 152 35 77% 103, 110, 235, 240, 252, 265-275, 280-284, 296-297, 306-308, 311, 334-337, 340-346, 355 -ibind\base\subscription_controller.py 125 33 74% 65, 67-69, 74-75, 86, 89, 92, 100-101, 154, 173-191, 211, 283, 286, 289-292, 326, 334, 358 -ibind\base\ws_client.py 217 38 82% 64, 90, 95, 118-120, 124-129, 151, 155-156, 168-169, 190, 196-198, 202-204, 238, 247, 254-256, 317-322, 363, 439-440, 457, 470 -ibind\client\__init__.py 0 0 100% -ibind\client\ibkr_client.py 119 69 42% 87-91, 116-119, 132-135, 143-149, 163-164, 195-218, 234-237, 250-251, 254-256, 265-267, 270-272, 286, 289-333 -ibind\client\ibkr_client_mixins\__init__.py 0 0 100% -ibind\client\ibkr_client_mixins\accounts_mixin.py 4 0 100% -ibind\client\ibkr_client_mixins\contract_mixin.py 25 0 100% -ibind\client\ibkr_client_mixins\marketdata_mixin.py 61 28 54% 54-85, 226, 236-241 -ibind\client\ibkr_client_mixins\order_mixin.py 22 12 45% 102-110, 175-183 -ibind\client\ibkr_client_mixins\portfolio_mixin.py 5 0 100% -ibind\client\ibkr_client_mixins\scanner_mixin.py 5 0 100% -ibind\client\ibkr_client_mixins\session_mixin.py 39 13 67% 105, 130-145 -ibind\client\ibkr_client_mixins\watchlist_mixin.py 4 0 100% -ibind\client\ibkr_definitions.py 6 0 100% -ibind\client\ibkr_utils.py 226 42 81% 222, 310-313, 316, 440-441, 446, 603-606, 617-620, 639-642, 645-657, 666-672, 684-689 -ibind\client\ibkr_ws_client.py 238 65 73% 274, 277-281, 321, 326-328, 331, 351, 384, 390-394, 403-416, 422-423, 431-442, 447-460, 480, 484, 488, 491, 504, 517, 535, 538, 551, 702-714 -ibind\oauth\__init__.py 26 26 0% 1-58 -ibind\oauth\oauth1a.py 164 164 0% 1-466 -ibind\support\__init__.py 0 0 100% -ibind\support\errors.py 4 0 100% -ibind\support\logs.py 82 13 84% 23-29, 94, 96, 136, 143, 164, 172-173 -ibind\support\py_utils.py 87 25 71% 143, 153-156, 169, 308, 324-332, 336-354 -ibind\var.py 88 4 95% 24-25, 36-37 ------------------------------------------------------------------------------------ -TOTAL 1730 567 67% \ No newline at end of file diff --git a/cov_old.txt b/cov_old.txt deleted file mode 100644 index 40630161..00000000 --- a/cov_old.txt +++ /dev/null @@ -1,32 +0,0 @@ ----------- coverage: platform win32, python 3.13.11-final-0 ---------- -Name Stmts Miss Cover Missing ------------------------------------------------------------------------------------ -ibind\__init__.py 13 0 100% -ibind\base\__init__.py 0 0 100% -ibind\base\queue_controller.py 18 0 100% -ibind\base\rest_client.py 152 35 77% 103, 110, 235, 240, 252, 265-275, 280-284, 296-297, 306-308, 311, 334-337, 340-346, 355 -ibind\base\subscription_controller.py 125 33 74% 65, 67-69, 74-75, 86, 89, 92, 100-101, 154, 173-191, 211, 283, 286, 289-292, 326, 334, 358 -ibind\base\ws_client.py 217 38 82% 64, 90, 95, 118-120, 124-129, 151, 155-156, 168-169, 190, 196-198, 202-204, 238, 247, 254-256, 317-322, 363, 439-440, 457, 470 -ibind\client\__init__.py 0 0 100% -ibind\client\ibkr_client.py 119 69 42% 87-91, 116-119, 132-135, 143-149, 163-164, 195-218, 234-237, 250-251, 254-256, 265-267, 270-272, 286, 289-333 -ibind\client\ibkr_client_mixins\__init__.py 0 0 100% -ibind\client\ibkr_client_mixins\accounts_mixin.py 4 0 100% -ibind\client\ibkr_client_mixins\contract_mixin.py 25 0 100% -ibind\client\ibkr_client_mixins\marketdata_mixin.py 61 29 52% 54-85, 226, 236-241, 303 -ibind\client\ibkr_client_mixins\order_mixin.py 22 12 45% 102-110, 175-183 -ibind\client\ibkr_client_mixins\portfolio_mixin.py 5 0 100% -ibind\client\ibkr_client_mixins\scanner_mixin.py 5 0 100% -ibind\client\ibkr_client_mixins\session_mixin.py 39 13 67% 105, 130-145 -ibind\client\ibkr_client_mixins\watchlist_mixin.py 4 0 100% -ibind\client\ibkr_definitions.py 6 0 100% -ibind\client\ibkr_utils.py 226 42 81% 222, 310-313, 316, 440-441, 446, 603-606, 617-620, 639-642, 645-657, 666-672, 684-689 -ibind\client\ibkr_ws_client.py 238 65 73% 274, 277-281, 321, 326-328, 331, 351, 384, 390-394, 403-416, 422-423, 431-442, 447-460, 480, 484, 488, 491, 504, 517, 535, 538, 551, 702-714 -ibind\oauth\__init__.py 26 26 0% 1-58 -ibind\oauth\oauth1a.py 164 164 0% 1-466 -ibind\support\__init__.py 0 0 100% -ibind\support\errors.py 4 0 100% -ibind\support\logs.py 82 57 30% 20-29, 75-96, 121-141, 150-153, 156-157, 160, 163-168, 171-175 -ibind\support\py_utils.py 87 25 71% 143, 153-156, 169, 308, 324-332, 336-354 -ibind\var.py 88 4 95% 24-25, 36-37 ------------------------------------------------------------------------------------ -TOTAL 1730 612 65% \ No newline at end of file diff --git a/coverage_new.txt b/coverage_new.txt deleted file mode 100644 index 1272343d..00000000 --- a/coverage_new.txt +++ /dev/null @@ -1,44 +0,0 @@ -============================= test session starts ============================== -platform linux -- Python 3.12.12, pytest-8.4.2, pluggy-1.6.0 -rootdir: /app -configfile: pytest.ini -plugins: mock-3.15.1, cov-5.0.0 -collected 10 items - -test/integration/base/test_rest_client_i_new.py .......... [100%] - ----------- coverage: platform linux, python 3.12.12-final-0 ---------- -Name Stmts Miss Cover -------------------------------------------------------------------------- -ibind/__init__.py 13 0 100% -ibind/base/__init__.py 0 0 100% -ibind/base/queue_controller.py 18 7 61% -ibind/base/rest_client.py 152 36 76% -ibind/base/subscription_controller.py 125 104 17% -ibind/base/ws_client.py 217 184 15% -ibind/client/__init__.py 0 0 100% -ibind/client/ibkr_client.py 119 73 39% -ibind/client/ibkr_client_mixins/__init__.py 0 0 100% -ibind/client/ibkr_client_mixins/accounts_mixin.py 4 0 100% -ibind/client/ibkr_client_mixins/contract_mixin.py 25 14 44% -ibind/client/ibkr_client_mixins/marketdata_mixin.py 61 45 26% -ibind/client/ibkr_client_mixins/order_mixin.py 22 12 45% -ibind/client/ibkr_client_mixins/portfolio_mixin.py 5 0 100% -ibind/client/ibkr_client_mixins/scanner_mixin.py 5 0 100% -ibind/client/ibkr_client_mixins/session_mixin.py 39 29 26% -ibind/client/ibkr_client_mixins/watchlist_mixin.py 4 0 100% -ibind/client/ibkr_definitions.py 6 1 83% -ibind/client/ibkr_utils.py 226 131 42% -ibind/client/ibkr_ws_client.py 238 177 26% -ibind/oauth/__init__.py 26 26 0% -ibind/oauth/oauth1a.py 164 164 0% -ibind/support/__init__.py 0 0 100% -ibind/support/errors.py 4 0 100% -ibind/support/logs.py 82 13 84% -ibind/support/py_utils.py 87 63 28% -ibind/var.py 88 4 95% -------------------------------------------------------------------------- -TOTAL 1730 1083 37% - - -============================== 10 passed in 1.23s ============================== diff --git a/coverage_old.txt b/coverage_old.txt deleted file mode 100644 index 01bbe7f6..00000000 --- a/coverage_old.txt +++ /dev/null @@ -1,44 +0,0 @@ -============================= test session starts ============================== -platform linux -- Python 3.12.12, pytest-8.4.2, pluggy-1.6.0 -rootdir: /app -configfile: pytest.ini -plugins: mock-3.15.1, cov-5.0.0 -collected 8 items - -test/integration/base/test_rest_client_i.py ........ [100%] - ----------- coverage: platform linux, python 3.12.12-final-0 ---------- -Name Stmts Miss Cover -------------------------------------------------------------------------- -ibind/__init__.py 13 0 100% -ibind/base/__init__.py 0 0 100% -ibind/base/queue_controller.py 18 7 61% -ibind/base/rest_client.py 152 36 76% -ibind/base/subscription_controller.py 125 104 17% -ibind/base/ws_client.py 217 184 15% -ibind/client/__init__.py 0 0 100% -ibind/client/ibkr_client.py 119 73 39% -ibind/client/ibkr_client_mixins/__init__.py 0 0 100% -ibind/client/ibkr_client_mixins/accounts_mixin.py 4 0 100% -ibind/client/ibkr_client_mixins/contract_mixin.py 25 14 44% -ibind/client/ibkr_client_mixins/marketdata_mixin.py 61 45 26% -ibind/client/ibkr_client_mixins/order_mixin.py 22 12 45% -ibind/client/ibkr_client_mixins/portfolio_mixin.py 5 0 100% -ibind/client/ibkr_client_mixins/scanner_mixin.py 5 0 100% -ibind/client/ibkr_client_mixins/session_mixin.py 39 29 26% -ibind/client/ibkr_client_mixins/watchlist_mixin.py 4 0 100% -ibind/client/ibkr_definitions.py 6 1 83% -ibind/client/ibkr_utils.py 226 131 42% -ibind/client/ibkr_ws_client.py 238 177 26% -ibind/oauth/__init__.py 26 26 0% -ibind/oauth/oauth1a.py 164 164 0% -ibind/support/__init__.py 0 0 100% -ibind/support/errors.py 4 0 100% -ibind/support/logs.py 82 57 30% -ibind/support/py_utils.py 87 63 28% -ibind/var.py 88 4 95% -------------------------------------------------------------------------- -TOTAL 1730 1127 35% - - -============================== 8 passed in 1.26s =============================== From ccb1df1f749a794cc13dfff6c45e45287625277e Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 24 Dec 2025 11:46:26 +0100 Subject: [PATCH 25/31] chore: removed migration_plan.md --- test/migration_plan.md | 145 ----------------------------------------- 1 file changed, 145 deletions(-) delete mode 100644 test/migration_plan.md diff --git a/test/migration_plan.md b/test/migration_plan.md deleted file mode 100644 index f5e0a7f1..00000000 --- a/test/migration_plan.md +++ /dev/null @@ -1,145 +0,0 @@ -# Unittest to Pytest Migration Plan - -This document outlines the roadmap for migrating our existing `unittest`-based tests to `pytest`. The goal is to modernize our testing suite, improve readability, and take advantage of `pytest`'s powerful features. - -## General Guidelines - -When migrating tests, please adhere to the following principles: - -- **New Test Files:** To compare test coverage before and after the migration, create a new test file for the migrated tests. For example, `test/integration/base/test_rest_client_i.py` should be migrated to `test/integration/base/test_rest_client_i_new.py`. -- **Post-migration check:** run the old and new test files *separately* with `--cov= --cov-report=term-missing` and confirm the covered/missing lines are identical (or document any differences). -- **Test Classes:** Convert `unittest.TestCase` subclasses into plain test functions. If a class structure is still beneficial for grouping related tests, you can use a class without inheriting from `unittest.TestCase`. -- **`setUp` and `tearDown`:** Replace `setUp` and `tearDown` methods with `pytest` fixtures. -- **Assertions:** Convert all `self.assert...` methods to plain `assert` statements. For example, `self.assertEqual(a, b)` becomes `assert a == b`. -- **Exception Handling:** Replace `with self.assertRaises(...)` with `with pytest.raises(...)`. -- **Logging:** Use the new `capture_logs` utility from `test_utils_new.py`. It can be used as a context manager (`with capture_logs(...) as cm:`) or as a decorator (`@capture_logs(...)`). This replaces all previous `unittest`-based logging helpers. The returned watcher object has methods like `exact_log`, `partial_log`, and `log_excludes` for assertions. -- **Arrange, Act, Assert:** Structure your tests using the ##Arrange, ##Act, ##Assert pattern. -- **Parametrization:** Use `@pytest.mark.parametrize` to run the same test with different inputs. - -## Additional Rules (learned from first few migrations) - -The following rules help avoid common migration pitfalls and reduce boilerplate. See: - -- `test/integration/base/test_rest_client_i_new.py` -- `test/integration/client/test_ibkr_client_i_new.py` - -### Fixtures and constants - -- **Prefer module constants for stable configuration** - - Put stable values such as `_URL`, `_TIMEOUT`, `_DEFAULT_PATH`, `_MAX_RETRIES` at module scope. - - Keep fixtures focused on objects with lifecycle/state (clients, mocks, results). - -- **Avoid “mega fixtures” that return tuples** - - If a `setUp` method created many objects, migrate it into multiple fixtures. - -### Patching (replacing class-level @patch) - -- **Use an autouse `requests_mock` fixture for common patching** - - When the original unittest test patched a whole `TestCase` class (e.g. `@patch('...requests')`), replicate it with a single `@pytest.fixture(autouse=True)`. - - Example pattern: - - ```python - @pytest.fixture(autouse=True) - def requests_mock(mocker, response): - requests_mock = mocker.patch('ibind.base.rest_client.requests') - requests_mock.request.return_value = response - return requests_mock - ``` - - Tests can still override behavior locally: - - - `requests_mock.request.side_effect = ReadTimeout()` - - `requests_mock.request.return_value = MagicMock(...)` - -### Preserve unittest semantics - -- **Float comparisons** - - `self.assertAlmostEqual(...)` should migrate to `pytest.approx(...)`. - -- **Logging expectations** - - Do not assert *more* than the unittest test asserted. - - If unittest checked a substring (e.g. `assertIn`), migrate to `partial_match=True` or explicit substring checks. - -- **Exceptions vs return values** - - Verify whether the production code *raises* or *returns* exceptions. - - A common pitfall is migrating a test to “return exception in results” when the implementation actually raises (or ignores) specific errors. - -- **Key types / coercions** - - Be careful with dict keys and parameter conversions. - - If production code casts IDs (e.g. `int(conid)`), results may be keyed by `int` even if the input looked like a string. - -- **Naming parity** - - Keep test names close to the original unittest names to make 1:1 mapping and review easier. - -## Migration Chunks - -The following files need to be migrated. Each file can be worked on independently. - ---- - -### 1. [✔] `test/integration/base/test_rest_client_i.py` - -- **Migration Steps:** - 1. Create a new file: `test/integration/base/test_rest_client_i_new.py`. - 2. In the new file, convert all `TestCase` subclasses into simple test functions. - 3. Replace the `setUp` method's logic with granular fixtures and module constants (avoid tuple-returning fixtures). - 4. Convert all `self.assert...` calls and `with self.assertRaises` to `assert` and `with pytest.raises(...)`. - 5. Replace `with self.assertLogs(...)` with the `capture_logs` context manager from `test_utils_new.py`. - 6. Refactor the class-level patch into an autouse fixture (e.g. `requests_mock`) so tests don't repeat patch boilerplate. - ---- - -### 2. [✔] `test/integration/base/test_websocket_client_i.py` - -- **Migration Steps:** - 1. Create a new file: `test/integration/base/test_websocket_client_i_new.py`. - 2. In the new file, convert the `TestWsClient` class into a series of test functions. - 3. Move the `setUp` logic into one or more `pytest` fixtures. - 4. Eliminate the complex `run_in_test_context` helper. Use the `mocker` fixture for patching and decorate tests with `@capture_logs(...)` from `test_utils_new.py` for logging. - 5. Convert all `self.assert...` calls to plain `assert` statements. - ---- - -### 3. [✔] `test/integration/client/test_ibkr_client_i.py` - -- **Migration Steps:** - 1. Create a new file: `test/integration/client/test_ibkr_client_i_new.py`. - 2. In the new file, convert the class into a series of test functions. - 3. Move the `setUp` logic into granular fixtures and module constants (avoid tuple-returning fixtures). - 4. Replace all `self.assert...` calls with plain `assert` statements and `pytest.raises`. - 5. Replace the `SafeAssertLogs` and `RaiseLogsContext` with the `capture_logs` utility from `test_utils_new.py`. - 6. Handle the class-level patch using an autouse fixture (e.g. `requests_mock`) so tests don't repeat patch boilerplate. - ---- - -### 4. [✔] `test/integration/client/test_ibkr_utils_i.py` - -- **Migration Steps:** - 1. Create a new file: `test/integration/client/test_ibkr_utils_i_new.py`. - 2. In the new file, convert all four classes into separate sets of test functions. - 3. Move `setUp` logic into fixtures where applicable. - 4. Convert all `self.assert...` calls to plain `assert` and `pytest.raises`. - 5. Replace `with self.assertLogs(...)` with the `capture_logs` context manager from `test_utils_new.py`. - ---- - -### 5. [✔] `test/integration/client/test_ibkr_ws_client_i.py` - -- **Migration Steps:** - 1. Create a new file: `test/integration/client/test_ibkr_ws_client_i_new.py`. - 2. In the new file, convert both `TestCase` subclasses into sets of test functions. - 3. Move the extensive `setUp` logic into `pytest` fixtures. - 4. Eliminate the `run_in_test_context` helper. Use the `mocker` fixture for patching and `@capture_logs(...)` from `test_utils_new.py` for logging. - 5. Convert all `self.assert...` calls to plain `assert` statements. - ---- - -### 6. [✔] `test/unit/support/test_py_utils_u.py` - -- **Migration Steps:** - 1. Create a new file: `test/unit/support/test_py_utils_u_new.py`. - 2. In the new file, convert all three classes into separate sets of test functions. - 3. Move the `setUp` method into a fixture. - 4. Convert all `self.assert...` methods and `with self.assertRaises` to plain `assert` statements and `with pytest.raises(...)`. - 5. Replace the `@patch` decorator with the `mocker` fixture. \ No newline at end of file From 38e079958965c56d14acc9692c2b085c5bb39c16 Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 24 Dec 2025 11:59:12 +0100 Subject: [PATCH 26/31] test: fixed Python<3.11 issues with \n\t by replacing \t with ` ` --- test/test_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 8fb7a233..8f5e889f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -53,23 +53,23 @@ def exact_log(self, expected_messages: OneOrMany[str]): """Assert that all expected messages appear in the captured logs.""" found, missing_expected = self._process_logs(expected_messages, lambda x, y: x == y) if len(missing_expected) > 0: - raise AssertionError(f"Expected exact log(s) not found:\n\t{'\n\t'.join(missing_expected)}\n\nActual logs:\n{self.format_logs()}\n") + raise AssertionError(f"Expected exact log(s) not found:\n {'\n '.join(missing_expected)}\n\nActual logs:\n{self.format_logs()}\n") def partial_log(self, expected_messages: OneOrMany[str]): """Assert that each expected message is a substring of at least one captured log message.""" found, missing_expected = self._process_logs(expected_messages, lambda x, y: x in y) if len(missing_expected) > 0: - raise AssertionError(f"Expected partial log(s) not found:\n\t{'\n\t'.join(missing_expected)}\n\nActual logs:\n{self.format_logs()}\n") + raise AssertionError(f"Expected partial log(s) not found:\n {'\n '.join(missing_expected)}\n\nActual logs:\n{self.format_logs()}\n") def log_excludes(self, expected_messages: OneOrMany[str]): """Assert that none of the expected messages appear in any captured log message.""" found, _ = self._process_logs(expected_messages, lambda x, y: x in y) if found: - raise AssertionError(f"Unexpected log(s) found:\n\t{'\n\t'.join(found)}\n\nCurrent logs:\n{self.format_logs()}\n") + raise AssertionError(f"Unexpected log(s) found:\n {' '.join(found)}\n\nCurrent logs:\n{self.format_logs()}\n") def format_logs(self): """Return a formatted string of all captured log messages.""" - return f"\n{self} captured {len(self.output)} logs:\n[\n\t{'\n\t'.join(self.output)}\n]" + return f"\n{self} captured {len(self.output)} logs:\n[\n {'\n '.join(self.output)}\n]" def count_occurrences(self, msg: str): """Count the number of occurrences of a message in the captured logs.""" From 8a37f3edf8961e10cdce90958a455e4b41f0f561 Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 24 Dec 2025 12:16:30 +0100 Subject: [PATCH 27/31] test: fixed Python<3.11 issues with nested {...{\n\t ...}} by constructing the inner strings outside of f-string --- test/test_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 8f5e889f..ed0ca649 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -53,23 +53,27 @@ def exact_log(self, expected_messages: OneOrMany[str]): """Assert that all expected messages appear in the captured logs.""" found, missing_expected = self._process_logs(expected_messages, lambda x, y: x == y) if len(missing_expected) > 0: - raise AssertionError(f"Expected exact log(s) not found:\n {'\n '.join(missing_expected)}\n\nActual logs:\n{self.format_logs()}\n") + missing_expected_str = '\n\t'.join(missing_expected) + raise AssertionError(f"Expected exact log(s) not found:\n\t{missing_expected_str}\n\nActual logs:\n{self.format_logs()}\n") def partial_log(self, expected_messages: OneOrMany[str]): """Assert that each expected message is a substring of at least one captured log message.""" found, missing_expected = self._process_logs(expected_messages, lambda x, y: x in y) if len(missing_expected) > 0: - raise AssertionError(f"Expected partial log(s) not found:\n {'\n '.join(missing_expected)}\n\nActual logs:\n{self.format_logs()}\n") + missing_expected_str = '\n\t'.join(missing_expected) + raise AssertionError(f"Expected partial log(s) not found:\n\t{missing_expected_str}\n\nActual logs:\n{self.format_logs()}\n") def log_excludes(self, expected_messages: OneOrMany[str]): """Assert that none of the expected messages appear in any captured log message.""" found, _ = self._process_logs(expected_messages, lambda x, y: x in y) if found: - raise AssertionError(f"Unexpected log(s) found:\n {' '.join(found)}\n\nCurrent logs:\n{self.format_logs()}\n") + found_str = '\n\t'.join(found) + raise AssertionError(f"Unexpected log(s) found:\n\t{found_str}\n\nCurrent logs:\n{self.format_logs()}\n") def format_logs(self): """Return a formatted string of all captured log messages.""" - return f"\n{self} captured {len(self.output)} logs:\n[\n {'\n '.join(self.output)}\n]" + output_str = '\n\t'.join(self.output) + return f"\n{self} captured {len(self.output)} logs:\n[\n\t{output_str}\n]" def count_occurrences(self, msg: str): """Count the number of occurrences of a message in the captured logs.""" From 3b12c228b595c2d6c96e53eb46f9dd226d17a4b7 Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 24 Dec 2025 12:21:10 +0100 Subject: [PATCH 28/31] chore: removed leftover _NAME_TO_LEVEL = logging.getLevelNamesMapping() --- test/test_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index ed0ca649..329360aa 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -9,7 +9,6 @@ from ibind.support.logs import get_logger_children from ibind.support.py_utils import make_clean_stack, OneOrMany, UNDEFINED -_NAME_TO_LEVEL = logging.getLevelNamesMapping() # --- New Functions and Types --- From ec1e3d486e8d4f19bd3ab2f64b8537857522ebf6 Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 7 Jan 2026 09:35:09 +0100 Subject: [PATCH 29/31] chore(test): made test imports uniform --- test/integration/base/test_websocket_client_i.py | 2 +- test/integration/client/test_ibkr_client_i.py | 4 ++-- test/integration/client/test_ibkr_ws_client_i.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/integration/base/test_websocket_client_i.py b/test/integration/base/test_websocket_client_i.py index 4237b648..b14d5e75 100644 --- a/test/integration/base/test_websocket_client_i.py +++ b/test/integration/base/test_websocket_client_i.py @@ -7,7 +7,7 @@ from ibind.base.ws_client import WsClient from ibind.support.py_utils import tname from integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock -from test_utils import capture_logs +from test.test_utils import capture_logs _URL = 'wss://localhost:5000/v1/api/ws' _MAX_RECONNECT_ATTEMPTS = 4 diff --git a/test/integration/client/test_ibkr_client_i.py b/test/integration/client/test_ibkr_client_i.py index 986a70ed..b22ae18a 100644 --- a/test/integration/client/test_ibkr_client_i.py +++ b/test/integration/client/test_ibkr_client_i.py @@ -10,8 +10,8 @@ from ibind.client.ibkr_utils import StockQuery, filter_stocks from ibind.support.errors import ExternalBrokerError from ibind.support.logs import ibind_logs_initialize -from integration.client import ibkr_responses -from test_utils import CaptureLogsContext +from test.integration.client import ibkr_responses +from test.test_utils import CaptureLogsContext _URL = 'https://localhost:5000' diff --git a/test/integration/client/test_ibkr_ws_client_i.py b/test/integration/client/test_ibkr_ws_client_i.py index 4932cb94..b3cf9c72 100644 --- a/test/integration/client/test_ibkr_ws_client_i.py +++ b/test/integration/client/test_ibkr_ws_client_i.py @@ -9,8 +9,8 @@ from ibind import Result from ibind.client.ibkr_client import IbkrClient from ibind.client.ibkr_ws_client import IbkrWsClient, IbkrSubscriptionProcessor, IbkrWsKey -from integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock -from test_utils import capture_logs +from test.integration.base.websocketapp_mock import create_wsa_mock, init_wsa_mock +from test.test_utils import capture_logs _URL_WS = 'wss://localhost:5000/v1/api/ws' _URL_REST = 'https://localhost:5000' From f73d79c8734bf6eed51a7c793e49744246a5c74d Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 7 Jan 2026 09:46:25 +0100 Subject: [PATCH 30/31] docs(test): added docstring to test_utils.py --- test/test_utils.py | 214 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 192 insertions(+), 22 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 329360aa..e02a24af 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,29 +1,36 @@ import functools import inspect import logging -import os import traceback -from pathlib import Path -from typing import List, TypeVar +from typing import List, Union from ibind.support.logs import get_logger_children from ibind.support.py_utils import make_clean_stack, OneOrMany, UNDEFINED -# --- New Functions and Types --- +def _accepts_kwargs(func): + """ + Check if a function accepts **kwargs. -def accepts_kwargs(func): - """Returns True if func accepts **kwargs, else False.""" + Args: + func: A callable to inspect. + + Returns: + bool: True if the function accepts **kwargs, False otherwise. + """ sig = inspect.signature(func) for param in sig.parameters.values(): if param.kind == inspect.Parameter.VAR_KEYWORD: return True return False + # --- Logging Utilities --- class LoggingWatcher: - """Helper class for capturing and asserting logs during testing.""" + """ + Captures and asserts on log messages during testing. + """ def __init__(self, logger): self.logger = logger @@ -49,44 +56,87 @@ def _process_logs(self, expected_messages: OneOrMany[str], comparison: callable return found, missing_expected def exact_log(self, expected_messages: OneOrMany[str]): - """Assert that all expected messages appear in the captured logs.""" + """ + Assert that all expected messages appear exactly in the captured logs. + + Args: + expected_messages: A single message string or list of message strings to match. + + Raises: + AssertionError: If any expected message is not found in the captured logs. + """ found, missing_expected = self._process_logs(expected_messages, lambda x, y: x == y) if len(missing_expected) > 0: missing_expected_str = '\n\t'.join(missing_expected) raise AssertionError(f"Expected exact log(s) not found:\n\t{missing_expected_str}\n\nActual logs:\n{self.format_logs()}\n") def partial_log(self, expected_messages: OneOrMany[str]): - """Assert that each expected message is a substring of at least one captured log message.""" + """ + Assert that each expected message is a substring of at least one captured log. + + Args: + expected_messages: A single message string or list of message strings to match as substrings. + + Raises: + AssertionError: If any expected message is not found as a substring in the captured logs. + """ found, missing_expected = self._process_logs(expected_messages, lambda x, y: x in y) if len(missing_expected) > 0: missing_expected_str = '\n\t'.join(missing_expected) raise AssertionError(f"Expected partial log(s) not found:\n\t{missing_expected_str}\n\nActual logs:\n{self.format_logs()}\n") def log_excludes(self, expected_messages: OneOrMany[str]): - """Assert that none of the expected messages appear in any captured log message.""" + """ + Assert that none of the expected messages appear in any captured log. + + Args: + expected_messages: A single message string or list of message strings to exclude. + + Raises: + AssertionError: If any expected message is found in the captured logs. + """ found, _ = self._process_logs(expected_messages, lambda x, y: x in y) if found: found_str = '\n\t'.join(found) raise AssertionError(f"Unexpected log(s) found:\n\t{found_str}\n\nCurrent logs:\n{self.format_logs()}\n") def format_logs(self): - """Return a formatted string of all captured log messages.""" + """ + Return a formatted string of all captured log messages. + + Returns: + str: A formatted string containing all captured logs. + """ output_str = '\n\t'.join(self.output) return f"\n{self} captured {len(self.output)} logs:\n[\n\t{output_str}\n]" def count_occurrences(self, msg: str): - """Count the number of occurrences of a message in the captured logs.""" + """ + Count occurrences of a message in the captured logs. + + Args: + msg: The message substring to count. + + Returns: + int: The number of logs containing the message substring. + """ return sum(1 for log in self.output if msg in log) def print(self): - """Print the formatted logs.""" + """ + Print the formatted logs to stdout. + """ print(self.format_logs()) def __str__(self): return f'LoggingWatcher({self.logger.name})' + class _CapturingHandler(logging.Handler): - """A logging handler capturing all (raw and formatted) logging output.""" + """ + Internal logging handler that captures all logging output. + """ + def __init__(self, logger): logging.Handler.__init__(self) self.watcher = LoggingWatcher(logger) @@ -99,20 +149,38 @@ def emit(self, record): msg = self.format(record) self.watcher.output.append(msg) + class CaptureLogsContext: + """ + Context manager for capturing and validating log output during tests. + """ LOGGING_FORMAT = "%(message)s" def __init__( self, - logger='ibind', - level='DEBUG', + logger: str = 'ibind', + level: str = 'DEBUG', logger_level: str = None, - error_level='WARNING', - no_logs=UNDEFINED, - expected_errors=None, - partial_match=False, - attach_stack=True, + error_level: str = 'WARNING', + no_logs: Union[bool, UNDEFINED] = UNDEFINED, + expected_errors: List[str] = None, + partial_match: bool = False, + attach_stack: bool = True, ): + """ + Initialize a log capture context. + + Args: + logger (str): Logger name to capture. Defaults to 'ibind'. + level (str): Logging level to capture. Defaults to 'DEBUG'. + logger_level (str): Optional logger-specific level override. + error_level (str): Logging level threshold for unexpected logs. Defaults to 'WARNING'. + no_logs (bool): If True, assert no logs are produced. If False, assert logs are produced. + Defaults to UNDEFINED (no assertion). + expected_errors (list): List of expected error messages to match. + partial_match (bool): If True, match expected errors as substrings. Defaults to False. + attach_stack (bool): If True, attach stack traces to logs. Defaults to True. + """ self._logger = logger self.level = getattr(logging, level) if isinstance(level, str) else level self.logger_level = getattr(logging, logger_level) if isinstance(logger_level, str) else logger_level @@ -127,6 +195,7 @@ def __init__( def _monkey_patch_log(self, logger): original_log = logger._log + def new_log(level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1): if extra is None: extra = {} @@ -147,9 +216,21 @@ def _restore_loggers(self, loggers): logger._log = logger.__old_log_method__ def logger_name(self): + """ + Get the logger name. + + Returns: + str: The name of the logger. + """ return self._logger.name if isinstance(self._logger, logging.Logger) else self._logger def acquire(self) -> LoggingWatcher: + """ + Acquire and configure the logger for capturing. + + Returns: + LoggingWatcher: A watcher object for asserting on captured logs. + """ self.logger = logging.getLogger(self.logger_name()) self.old_handlers = self.logger.handlers[:] self.old_level = self.logger.level @@ -202,6 +283,17 @@ def _process_exit_logs(self): self.watcher.exact_log(self.expected_errors) def release(self, exc_type=None, exc_val=None, exc_tb=None): + """ + Release and restore the logger to its original state. + + Args: + exc_type: Exception type if an exception occurred. + exc_val: Exception value if an exception occurred. + exc_tb: Exception traceback if an exception occurred. + + Returns: + bool: True if no exception occurred, False otherwise. + """ self.logger.handlers = self.old_handlers self.logger.propagate = self.old_propagate self.logger.setLevel(self.old_level) @@ -216,7 +308,25 @@ def __enter__(self) -> LoggingWatcher: def __exit__(self, exc_type, exc_val, exc_tb): return self.release(exc_type, exc_val, exc_tb) + def capture_logs(**ctx_kwargs): + """ + Decorator to capture and validate logs in a test function. + + Args: + **ctx_kwargs: Keyword arguments passed to CaptureLogsContext. + Common options: logger, level, error_level, expected_errors, partial_match. + + Returns: + callable: A decorator that wraps a test function to capture logs. + + Example: + @capture_logs(logger='myapp', expected_errors=['Error occurred']) + def test_something(): + # test code that logs + pass + """ + def decorator(test_func): @functools.wraps(test_func) def wrapper(*args, **kwargs): @@ -226,7 +336,7 @@ def wrapper(*args, **kwargs): log_exc = None cm = capture_log_context.acquire() - if accepts_kwargs(test_func): + if _accepts_kwargs(test_func): kwargs[logger_name] = cm try: @@ -249,13 +359,29 @@ def wrapper(*args, **kwargs): raise log_exc return rv + return wrapper + return decorator + # --- Time Mocking Utilities --- class MockTimeController: + """ + Mock time module for testing time-dependent code. + """ + def __init__(self, target_module, time_sequence=None, start_time=0.0): + """ + Initialize a mock time controller. + + Args: + target_module (str): Module name to inject the mock time into (eg. 'mymodule.submodule'). + time_sequence (list): Optional sequence of time values to return on successive calls. + If provided, time_sequence takes precedence over start_time. + start_time (float): Initial time value. Defaults to 0.0. Ignored if time_sequence is provided. + """ self.target_module = target_module if time_sequence is not None: self.time_sequence = list(time_sequence) @@ -266,16 +392,40 @@ def __init__(self, target_module, time_sequence=None, start_time=0.0): self.original_time_module = None def advance_time(self, seconds): + """ + Advance the mock time by the specified number of seconds. + + Args: + seconds (float): Number of seconds to advance. + + Raises: + ValueError: If using time_sequence mode. + """ if self.time_sequence is not None: raise ValueError("Cannot advance time when using time_sequence.") self.current_time += seconds def set_time(self, time_value): + """ + Set the mock time to a specific value. + + Args: + time_value (float): The time value to set. + + Raises: + ValueError: If using time_sequence mode. + """ if self.time_sequence is not None: raise ValueError("Cannot set time when using time_sequence.") self.current_time = time_value def mock_time(self): + """ + Get the current mock time value. + + Returns: + float: The current time value. If using time_sequence, returns the next value in the sequence. + """ if self.time_sequence is not None: if self.call_index < len(self.time_sequence): time_value = self.time_sequence[self.call_index] @@ -289,12 +439,15 @@ def mock_time(self): def __enter__(self): target_module_obj = __import__(self.target_module, fromlist=['']) self.original_time_module = target_module_obj.time + class MockTimeModule: def __init__(self, original_module, mock_time_func): self.original_module = original_module self.time = mock_time_func + def __getattr__(self, name): return getattr(self.original_module, name) + target_module_obj.time = MockTimeModule(self.original_time_module, self.mock_time) self.target_module_obj = target_module_obj return self @@ -302,5 +455,22 @@ def __getattr__(self, name): def __exit__(self, exc_type, exc_val, exc_tb): self.target_module_obj.time = self.original_time_module + def mock_module_time(target_module, time_sequence=None, start_time=0.0): + """ + Create a mock time controller for a target module. + + Args: + target_module (str): Module name to inject the mock time into. + time_sequence (list): Optional sequence of time values to return on successive calls. + start_time (float): Initial time value. Defaults to 0.0. + + Returns: + MockTimeController: A context manager for mocking time in the target module. + + Example: + with mock_module_time('mymodule', time_sequence=[1.0, 2.0, 3.0]): + # time.time() in mymodule will return 1.0, then 2.0, then 3.0 + pass + """ return MockTimeController(target_module, time_sequence=time_sequence, start_time=start_time) \ No newline at end of file From b577f69512d9f3c6b1c9abf3f3e6e065fe7df159 Mon Sep 17 00:00:00 2001 From: voyz Date: Wed, 7 Jan 2026 09:49:21 +0100 Subject: [PATCH 31/31] chore(test): fixed no_logs type hint in CaptureLogsContext --- test/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index e02a24af..e33822c2 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -162,7 +162,7 @@ def __init__( level: str = 'DEBUG', logger_level: str = None, error_level: str = 'WARNING', - no_logs: Union[bool, UNDEFINED] = UNDEFINED, + no_logs: Union[bool, object] = UNDEFINED, expected_errors: List[str] = None, partial_match: bool = False, attach_stack: bool = True,