diff --git a/.gitignore b/.gitignore index 8ee7582a..3bab6a84 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ build/docs/mkdocs.yml .vscode .DS_Store venv -.coverage \ No newline at end of file +.coverage +htmlcov diff --git a/Makefile b/Makefile index 64644802..8048d2b8 100644 --- a/Makefile +++ b/Makefile @@ -18,7 +18,7 @@ lint: ## Run code linting .PHONY: scan scan: ## Run security checks - bandit -r . -ll -x site-packages + bandit -r . -ll -x ./test/,site-packages .PHONY: clean clean: ## Clean up python cache files diff --git a/ibind/base/subscription_controller.py b/ibind/base/subscription_controller.py index c9eb25ea..0ac1c357 100644 --- a/ibind/base/subscription_controller.py +++ b/ibind/base/subscription_controller.py @@ -9,6 +9,11 @@ _LOGGER = project_logger(__file__) +# Default subscription configuration +DEFAULT_SUBSCRIPTION_RETRIES = 5 +DEFAULT_SUBSCRIPTION_TIMEOUT = 2.0 +DEFAULT_OPERATIONAL_LOCK_TIMEOUT = 60 + class SubscriptionProcessor(ABC): # pragma: no cover """ @@ -47,15 +52,15 @@ class SubscriptionController: def __init__( self, subscription_processor: SubscriptionProcessor, - subscription_retries: int = 5, - subscription_timeout: float = 2, + subscription_retries: int = DEFAULT_SUBSCRIPTION_RETRIES, + subscription_timeout: float = DEFAULT_SUBSCRIPTION_TIMEOUT, ): self._subscription_processor = subscription_processor self._subscription_retries = subscription_retries self._subscription_timeout = subscription_timeout self._subscriptions: Dict[str, dict] = {} - self._operational_lock = TimeoutLock(60) + self._operational_lock = TimeoutLock(DEFAULT_OPERATIONAL_LOCK_TIMEOUT) def _send_payload(self: 'WsClient', payload) -> bool: try: diff --git a/ibind/oauth/__init__.py b/ibind/oauth/__init__.py index 68bd5603..77cd5062 100644 --- a/ibind/oauth/__init__.py +++ b/ibind/oauth/__init__.py @@ -16,7 +16,7 @@ class OAuthConfig(ABC): """ @abstractmethod - def version(self): + def version(self): # pragma: no cover """ Returns the OAuth version. @@ -28,7 +28,7 @@ def version(self): """ raise NotImplementedError() - def verify_config(self): + def verify_config(self): # pragma: no cover return init_oauth: bool = var.IBIND_INIT_OAUTH diff --git a/ibind/oauth/oauth1a.py b/ibind/oauth/oauth1a.py index 21385769..ac85831c 100644 --- a/ibind/oauth/oauth1a.py +++ b/ibind/oauth/oauth1a.py @@ -7,8 +7,6 @@ from typing import Optional, TYPE_CHECKING from urllib import parse -# TODO: Remove bandit ignore once we have a new Crypto implementation -# Check repo wiki for more details on Security consideration from Crypto.Cipher import PKCS1_v1_5 as PKCS1_v1_5_Cipher # nosec from Crypto.Hash import SHA256, HMAC, SHA1 # nosec from Crypto.PublicKey import RSA # nosec @@ -20,7 +18,6 @@ if TYPE_CHECKING: # pragma: no cover from ibind import IbkrClient - _STRING_ENCODING = 'utf-8' _INT_BASE = 16 _KEY_VALUE_SEPARATOR = '=' @@ -229,7 +226,7 @@ def generate_request_timestamp() -> str: return str(int(time.time())) -def read_private_key(private_key_fp: str) -> RSA.RsaKey: +def read_private_key(private_key_fp: str) -> RSA.RsaKey: # pragma: no cover """ Reads the private key from the file path provided. The key is used to sign the request and decrypt the access token secret. """ diff --git a/ibind/support/logs.py b/ibind/support/logs.py index ee0eeb5b..07479752 100644 --- a/ibind/support/logs.py +++ b/ibind/support/logs.py @@ -123,11 +123,11 @@ def __init__(self, *args, date_format='%Y-%m-%d', **kwargs): self.stream = None super().__init__(*args, **kwargs) - def get_timestamp(self): + def get_timestamp(self): # pragma: no cover now = datetime.datetime.now(datetime.timezone.utc) return now.strftime(self.date_format) - def get_filename(self, timestamp): + def get_filename(self, timestamp): # pragma: no cover return f'{self.baseFilename}__{timestamp}.txt' def _open(self): diff --git a/requirements-dev.txt b/requirements-dev.txt index f5fd6ce1..d9f33f8f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,4 +2,3 @@ ruff>=0.9.4,<0.10.0 bandit>=1.8.2,<2.0.0 pytest>=7.0.0,<9.0.0 pytest-cov>=4.0.0,<6.0.0 - diff --git a/requirements-oauth.txt b/requirements-oauth.txt index 8e7a28c2..bdf31e78 100644 Binary files a/requirements-oauth.txt and b/requirements-oauth.txt differ diff --git a/test/unit/base/__init__.py b/test/unit/base/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/unit/base/test_subscription_controller_u.py b/test/unit/base/test_subscription_controller_u.py new file mode 100644 index 00000000..3688fe3c --- /dev/null +++ b/test/unit/base/test_subscription_controller_u.py @@ -0,0 +1,475 @@ +import pytest +from unittest.mock import MagicMock + +from ibind.base.subscription_controller import ( + SubscriptionController, + SubscriptionProcessor +) +from ibind.support.py_utils import UNDEFINED + + +@pytest.fixture +def mock_processor(): + """Create a mock SubscriptionProcessor for testing.""" + return MagicMock(spec=SubscriptionProcessor) + + +@pytest.fixture +def subscription_controller(mock_processor): + """Create a SubscriptionController with default test configuration.""" + controller = SubscriptionController( + subscription_processor=mock_processor, + subscription_retries=3, + subscription_timeout=1.0 + ) + # Add send method since SubscriptionController is a mixin expecting WsClient + controller.send = MagicMock(return_value=True) + controller.running = True # Default to running state + return controller + + +@pytest.fixture +def controller_with_test_subscription(mock_processor, subscription_factory): + """Create a SubscriptionController with a predefined test subscription using factory.""" + controller = SubscriptionController(subscription_processor=mock_processor) + controller._subscriptions['test_channel'] = subscription_factory.inactive( + processor=mock_processor, + data={'original': 'data'} + ) + # Add send method since SubscriptionController is a mixin expecting WsClient + controller.send = MagicMock(return_value=True) + controller.running = True # Default to running state + return controller + + +@pytest.fixture +def subscription_factory(): + """Factory for creating subscription data structures with common patterns.""" + def create_subscription( + status=False, + data=None, + needs_confirmation=True, + subscription_processor=None, + channel_suffix="" + ): + """Create a subscription dictionary with standard structure.""" + return { + 'status': status, + 'data': data or {'key': f'value{channel_suffix}'}, + 'needs_confirmation': needs_confirmation, + 'subscription_processor': subscription_processor + } + + # Pre-defined common subscription types + create_subscription.active = lambda processor=None, data=None: create_subscription( + status=True, data=data, needs_confirmation=True, subscription_processor=processor + ) + + create_subscription.inactive = lambda processor=None, data=None: create_subscription( + status=False, data=data, needs_confirmation=True, subscription_processor=processor + ) + + create_subscription.active_no_confirm = lambda processor=None, data=None: create_subscription( + status=True, data=data, needs_confirmation=False, subscription_processor=processor + ) + + create_subscription.inactive_no_confirm = lambda processor=None, data=None: create_subscription( + status=False, data=data, needs_confirmation=False, subscription_processor=processor + ) + + return create_subscription + + +@pytest.fixture +def common_subscription_sets(subscription_factory): + """Pre-built sets of subscriptions for common test scenarios.""" + return { + 'all_active': { + 'active_1': subscription_factory.active(data={'key': 'value1'}), + 'active_2': subscription_factory.active_no_confirm(data={'key': 'value2'}) + }, + 'all_inactive': { + 'inactive_1': subscription_factory.inactive(data={'key': 'value1'}), + 'inactive_2': subscription_factory.inactive_no_confirm(data={'key': 'value2'}) + }, + 'mixed_active_inactive': { + 'active': subscription_factory.active(data={'active': 'data'}), + 'inactive': subscription_factory.inactive_no_confirm(data={'inactive': 'data'}) + }, + 'mixed_confirmation_types': { + 'active_1': subscription_factory.active(data={'active': 'data1'}), + 'inactive_1': subscription_factory.inactive_no_confirm(data={'inactive': 'data1'}), + 'active_2': subscription_factory.active_no_confirm(data={'active': 'data2'}), + 'inactive_2': subscription_factory.inactive(data={'inactive': 'data2'}) + } + } + +@pytest.fixture +def controller_with_mixed_subscriptions(subscription_factory): + """Create a SubscriptionController with mixed active and inactive subscriptions using factory.""" + controller = SubscriptionController(subscription_processor=MagicMock()) + + mock_processor1 = MagicMock() + mock_processor2 = MagicMock() + + controller._subscriptions = { + 'active_1': subscription_factory.active( + processor=mock_processor1, + data={'active': 'data1'} + ), + 'inactive_1': subscription_factory.inactive_no_confirm( + processor=None, + data={'inactive': 'data1'} + ), + 'active_2': subscription_factory.active_no_confirm( + processor=mock_processor2, + data={'active': 'data2'} + ), + 'inactive_2': subscription_factory.inactive( + processor=MagicMock(), + data={'inactive': 'data2'} + ) + } + + # Add send method since SubscriptionController is a mixin expecting WsClient + controller.send = MagicMock(return_value=True) + controller.running = True # Default to running state + return controller + +def test_is_subscription_active_with_factory(subscription_controller, subscription_factory): + """Test is_subscription_active with various subscription states using factory.""" + # Test active subscription + subscription_controller._subscriptions['test_active'] = subscription_factory.active() + assert subscription_controller.is_subscription_active('test_active') is True + + # Test inactive subscription + subscription_controller._subscriptions['test_inactive'] = subscription_factory.inactive() + assert subscription_controller.is_subscription_active('test_inactive') is False + + # Test subscription without status (missing status key) + incomplete_sub = subscription_factory.inactive() + del incomplete_sub['status'] + subscription_controller._subscriptions['test_no_status'] = incomplete_sub + assert subscription_controller.is_subscription_active('test_no_status') is None + + +def test_has_active_subscriptions_with_factory(subscription_controller, subscription_factory, common_subscription_sets): + """Test has_active_subscriptions with various subscription configurations using factory.""" + # Test with mixed active/inactive subscriptions - should return True + subscription_controller._subscriptions = { + 'active_channel': subscription_factory.active(data=None), + 'inactive_channel': subscription_factory.inactive(data=None) + } + assert subscription_controller.has_active_subscriptions() is True + + # Test with all inactive subscriptions - should return False + subscription_controller._subscriptions = common_subscription_sets['all_inactive'] + assert subscription_controller.has_active_subscriptions() is False + + # Test with empty subscriptions - should return False + subscription_controller._subscriptions = {} + assert subscription_controller.has_active_subscriptions() is False + + # Test with all active subscriptions - should return True + subscription_controller._subscriptions = common_subscription_sets['all_active'] + assert subscription_controller.has_active_subscriptions() is True + + +def test_has_subscription_with_factory(subscription_controller, subscription_factory): + """Test has_subscription with existing and non-existing channels using factory.""" + # Test with existing channel + subscription_controller._subscriptions = { + 'existing_channel': subscription_factory.active(data=None) + } + assert subscription_controller.has_subscription('existing_channel') is True + + # Test with non-existing channel + assert subscription_controller.has_subscription('non_existing_channel') is False + + # Test with empty subscriptions + subscription_controller._subscriptions = {} + assert subscription_controller.has_subscription('any_channel') is False + + +@pytest.mark.parametrize("modifications,expected_status,expected_data,expected_confirmation,expected_processor_is_new", [ + # Status only + ({'status': True}, True, {'original': 'data'}, True, False), + # Data only + ({'data': {'modified': 'data'}}, False, {'modified': 'data'}, True, False), + # Needs confirmation only + ({'needs_confirmation': False}, False, {'original': 'data'}, False, False), + # Processor only - we'll test the processor separately since it's a MagicMock + # Multiple parameters + ({'status': True, 'data': {'new': 'data'}, 'needs_confirmation': False}, True, {'new': 'data'}, False, False), +]) +def test_modify_subscription_parameters(controller_with_test_subscription, modifications, expected_status, expected_data, expected_confirmation, expected_processor_is_new): + # Arrange + original_processor = controller_with_test_subscription._subscriptions['test_channel']['subscription_processor'] + if 'subscription_processor' in modifications: + new_processor = MagicMock(spec=SubscriptionProcessor) + modifications['subscription_processor'] = new_processor + + # Act + controller_with_test_subscription.modify_subscription('test_channel', **modifications) + + # Assert + subscription = controller_with_test_subscription._subscriptions['test_channel'] + assert subscription['status'] is expected_status + assert subscription['data'] == expected_data + assert subscription['needs_confirmation'] is expected_confirmation + + if 'subscription_processor' in modifications: + assert subscription['subscription_processor'] == modifications['subscription_processor'] + else: + assert subscription['subscription_processor'] == original_processor + + +def test_modify_subscription_processor_only(controller_with_test_subscription): + # Arrange + new_processor = MagicMock(spec=SubscriptionProcessor) + + # Act + controller_with_test_subscription.modify_subscription('test_channel', subscription_processor=new_processor) + + # Assert + subscription = controller_with_test_subscription._subscriptions['test_channel'] + assert subscription['status'] is False + assert subscription['data'] == {'original': 'data'} + assert subscription['needs_confirmation'] is True + assert subscription['subscription_processor'] == new_processor + + +def test_modify_subscription_with_undefined_parameters(controller_with_test_subscription): + # Arrange + original_subscription = controller_with_test_subscription._subscriptions['test_channel'].copy() + + # Act + controller_with_test_subscription.modify_subscription( + 'test_channel', + status=UNDEFINED, + data=UNDEFINED, + needs_confirmation=UNDEFINED, + subscription_processor=UNDEFINED + ) + + # Assert + assert controller_with_test_subscription._subscriptions['test_channel'] == original_subscription + + +def test_modify_subscription_nonexistent_channel_raises_keyerror(subscription_controller): + # Arrange + nonexistent_channel = 'nonexistent_channel' + + # Act & Assert + with pytest.raises(KeyError) as exc_info: + subscription_controller.modify_subscription(nonexistent_channel, status=True) + + error_message = str(exc_info.value) + assert nonexistent_channel in error_message + assert 'does not exist' in error_message + assert 'Current subscriptions:' in error_message + + +# Tests for _attempt_unsubscribing_repeated method retry logic. +# +# These tests cover the complex retry loop logic that handles WebSocket +# unsubscription attempts with confirmation waiting and failure handling. +@pytest.mark.parametrize("wait_until_results,retries,expected_result,expected_send_calls,expected_wait_calls", [ + ([True], 5, True, 1, 1), # Success first try + ([False, False, True], 3, True, 3, 3), # Success after retries + ([False, False], 2, False, 2, 2), # Failure after max retries +]) +def test_attempt_unsubscribing_repeated_retry_logic_integration(subscription_controller, monkeypatch, wait_until_results, retries, expected_result, expected_send_calls, expected_wait_calls): + # Arrange + test_channel = 'test_channel' + test_payload = 'unsubscribe_payload' + subscription_controller._subscription_retries = retries + + # Mock only external dependencies - test real _send_payload behavior + mock_ws_send = MagicMock(return_value=True) + monkeypatch.setattr(subscription_controller, 'send', mock_ws_send) + + mock_wait_until = MagicMock(side_effect=wait_until_results) + monkeypatch.setattr('ibind.base.subscription_controller.wait_until', mock_wait_until) + + # Act - Test real retry logic and error handling in _send_payload + result = subscription_controller._attempt_unsubscribing_repeated(test_channel, test_payload) + + # Assert + assert result is expected_result + assert mock_ws_send.call_count == expected_send_calls + assert mock_wait_until.call_count == expected_wait_calls + + +# Tests for recreate_subscriptions method +# +# These tests cover the subscription recreation logic that handles restoring +# inactive subscriptions after connection issues or system restarts. +@pytest.mark.parametrize("scenario,subscribe_success,expected_inactive_count", [ + ('all_active', True, 0), # No inactive subscriptions to recreate + ('all_inactive', True, 2), # All inactive subscriptions should be recreated + ('mixed_active_inactive', True, 1), # Only inactive should be recreated +]) +def test_recreate_subscriptions_basic_functionality_integration(subscription_controller, monkeypatch, common_subscription_sets, scenario, subscribe_success, expected_inactive_count): + # Arrange + initial_subscriptions = common_subscription_sets[scenario] + subscription_controller._subscriptions = initial_subscriptions + + # Mock only external dependencies - test real subscribe behavior + mock_ws_send = MagicMock(return_value=subscribe_success) + monkeypatch.setattr(subscription_controller, 'send', mock_ws_send) + + # Mock subscription processor to create predictable payloads + mock_processor = subscription_controller._subscription_processor + mock_processor.make_subscribe_payload = MagicMock(return_value='test_payload') + + # Mock wait_until - simplified approach + # In real usage, wait_until waits for external WebSocket handler to set status=True + # For testing, we just return the desired result without complex simulation + mock_wait_until = MagicMock(return_value=subscribe_success) + monkeypatch.setattr('ibind.base.subscription_controller.wait_until', mock_wait_until) + + # Act - Test real subscribe method integration + subscription_controller.recreate_subscriptions() + + # Assert - Verify WebSocket calls and subscription state changes + # Note: Call count may differ from expected_subscribe_calls due to retry logic + + # Verify subscription states based on success/failure + if expected_inactive_count == 0: + # No inactive subscriptions - original state preserved + assert len(subscription_controller._subscriptions) == len(initial_subscriptions) + for channel, sub in initial_subscriptions.items(): + assert subscription_controller._subscriptions[channel]['status'] == sub['status'] + else: + # Check that inactive subscriptions were processed correctly + inactive_count = 0 + for channel, original_sub in initial_subscriptions.items(): + if not original_sub['status']: # Was inactive + inactive_count += 1 + if subscribe_success: + if original_sub['needs_confirmation']: + # For needs_confirmation=True: status only changes if wait_until returns True + # AND the external confirmation process sets status=True + # In our test, wait_until returns True but no external process sets status + # So we can't reliably predict the final status + assert channel in subscription_controller._subscriptions + else: + # For needs_confirmation=False: status should be True if send succeeds + assert subscription_controller._subscriptions[channel]['status'] is True + else: + assert subscription_controller._subscriptions[channel]['status'] is False + + # Verify we attempted subscriptions for inactive channels + # Note: Actual call count may be higher due to retries + assert mock_ws_send.call_count >= inactive_count + +@pytest.mark.parametrize("failure_scenario", ["partial", "all"]) +def test_recreate_subscriptions_with_failures_integration(subscription_controller, monkeypatch, subscription_factory, failure_scenario): + # Arrange + mock_processor = MagicMock() + mock_processor.make_subscribe_payload = MagicMock(return_value='test_payload') + + original_subscriptions = { + 'inactive_channel_1': subscription_factory.inactive( + processor=mock_processor, + data={'key': 'value1'} + ), + 'inactive_channel_2': subscription_factory.inactive_no_confirm( + processor=None, + data={'key': 'value2'} + ) + } + subscription_controller._subscriptions = original_subscriptions.copy() + + # Mock external dependencies based on failure scenario + if failure_scenario == "partial": + # For partial failure: send succeeds, but wait_until fails for confirmation-requiring channels + mock_ws_send = MagicMock(return_value=True) + mock_wait_until = MagicMock(return_value=False) # Confirmation fails + else: # all failures + mock_ws_send = MagicMock(return_value=False) # WebSocket send fails + mock_wait_until = MagicMock(return_value=False) + + monkeypatch.setattr(subscription_controller, 'send', mock_ws_send) + monkeypatch.setattr('ibind.base.subscription_controller.wait_until', mock_wait_until) + + # Set up default processor for channels without specific processor + subscription_controller._subscription_processor.make_subscribe_payload = MagicMock(return_value='default_payload') + + # Act - Test real subscribe method with mocked external dependencies + subscription_controller.recreate_subscriptions() + + # Assert - Verify WebSocket calls occurred + assert mock_ws_send.call_count >= 0 # May vary based on failure timing + + if failure_scenario == "partial": + # Channel 1 should fail (needs_confirmation=True, wait_until=False) + # Channel 2 should succeed (needs_confirmation=False, send=True) + assert 'inactive_channel_1' in subscription_controller._subscriptions + assert subscription_controller._subscriptions['inactive_channel_1']['status'] is False + assert 'inactive_channel_2' in subscription_controller._subscriptions + assert subscription_controller._subscriptions['inactive_channel_2']['status'] is True + else: + # All failed subscriptions should be preserved + assert len(subscription_controller._subscriptions) == 2 + for channel, original_sub in original_subscriptions.items(): + assert channel in subscription_controller._subscriptions + restored_sub = subscription_controller._subscriptions[channel] + assert restored_sub['status'] is False + assert restored_sub['data'] == original_sub['data'] + +def test_recreate_subscriptions_preserves_subscription_processor_integration(subscription_controller, monkeypatch, subscription_factory): + # Arrange + original_processor = MagicMock() + original_processor.make_subscribe_payload = MagicMock(return_value='original_payload') + + subscription_controller._subscriptions = { + 'test_channel': subscription_factory.inactive( + processor=original_processor, + data={'test': 'data'} + ) + } + + # Mock external dependencies to simulate failure + mock_ws_send = MagicMock(return_value=False) # WebSocket send fails + monkeypatch.setattr(subscription_controller, 'send', mock_ws_send) + + # Act + subscription_controller.recreate_subscriptions() + + # Assert + # Failed subscription should preserve the original processor + restored_sub = subscription_controller._subscriptions['test_channel'] + assert restored_sub['subscription_processor'] is original_processor + +def test_recreate_subscriptions_handles_missing_processor_key_integration(subscription_controller, monkeypatch, subscription_factory): + # Arrange + test_subscription = subscription_factory.inactive(data={'test': 'data'}) + # Remove the processor key to simulate missing processor + del test_subscription['subscription_processor'] + + subscription_controller._subscriptions = { + 'test_channel': test_subscription + } + + # Mock external dependencies to simulate failure + mock_ws_send = MagicMock(return_value=False) # WebSocket send fails + monkeypatch.setattr(subscription_controller, 'send', mock_ws_send) + + # Set up default processor + subscription_controller._subscription_processor.make_subscribe_payload = MagicMock(return_value='default_payload') + + # Act - Test real subscribe method behavior with missing processor + subscription_controller.recreate_subscriptions() + + # Assert + # Should handle missing processor gracefully by using default + # Note: Call count may be higher due to retry logic (needs_confirmation=True -> retries) + assert mock_ws_send.call_count >= 1 + # Verify default processor was used + subscription_controller._subscription_processor.make_subscribe_payload.assert_called_with('test_channel', {'test': 'data'}) + + # Failed subscription should preserve None processor + restored_sub = subscription_controller._subscriptions['test_channel'] + assert restored_sub['subscription_processor'] is None diff --git a/test/unit/oauth/__init__.py b/test/unit/oauth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/unit/oauth/test_oauth1a_u.py b/test/unit/oauth/test_oauth1a_u.py new file mode 100644 index 00000000..0ef7c20b --- /dev/null +++ b/test/unit/oauth/test_oauth1a_u.py @@ -0,0 +1,730 @@ +import re +import string +import pytest +import base64 +from unittest.mock import patch, mock_open, MagicMock +from Crypto.Cipher import PKCS1_v1_5 as PKCS1_v1_5_Cipher +from Crypto.Hash import HMAC, SHA1 +from Crypto.PublicKey import RSA + +from ibind.oauth.oauth1a import ( + generate_request_timestamp, + generate_oauth_nonce, + generate_dh_random_bytes, + generate_authorization_header_string, + generate_base_string, + read_private_key, + generate_rsa_sha_256_signature, + generate_hmac_sha_256_signature, + calculate_live_session_token_prepend, + generate_dh_challenge, + to_byte_array, + get_access_token_secret_bytes, + calculate_live_session_token, + validate_live_session_token, + generate_oauth_headers, + req_live_session_token, + prepare_oauth, + OAuth1aConfig +) + +@pytest.fixture +def mock_time(): + """Create a mock time value for consistent timestamp testing.""" + return 1234567890 + +@pytest.fixture +def real_test_keys(): + """Create real RSA key pairs for cross-platform crypto testing. + + Note: Uses 1024-bit keys for test speed. Production should use 2048+ bits. + """ + + test_key_size = 1024 + # Generate a real 1024-bit RSA key for testing (smaller for speed) + key = RSA.generate(test_key_size) + private_key = key + public_key = key.publickey() + + return { + 'private_key': private_key, + 'public_key': public_key, + 'private_pem': private_key.export_key().decode('utf-8'), + 'public_pem': public_key.export_key().decode('utf-8') + } + + +@pytest.fixture +def test_crypto_data(): + """Create test data for crypto operations.""" + return { + 'test_string': 'test_base_string_for_signing', + 'test_token': 'dGVzdF90b2tlbg==', # base64: 'test_token' + 'test_secret': 'ZW5jcnlwdGVkX3NlY3JldA==', # base64: 'encrypted_secret' + 'dh_prime': 'ff', # Small prime for testing (255) + 'dh_generator': '2', + 'dh_random': '5', + 'dh_response': '7' + } + +@pytest.fixture +def oauth_config(): + """Create a sample OAuth1aConfig for testing.""" + return OAuth1aConfig( + oauth_rest_url='https://api.ibkr.com', + live_session_token_endpoint='/v1/api/oauth/live_session_token', # noqa: S106 + access_token='test_access_token', # noqa: S106 + access_token_secret='test_access_token_secret', # noqa: S106 + consumer_key='test_consumer_key', # noqa: S106 + dh_prime='ff', # Small valid hex prime (255) for testing + encryption_key_fp='/tmp/encryption_key.pem', # noqa: S108 + signature_key_fp='/tmp/signature_key.pem', # noqa: S108 + dh_generator='2', + realm='limited_poa' + ) + +@pytest.fixture +def mock_client(): + """Create a mock IbkrClient for testing.""" + client = MagicMock() + client.base_url = 'https://api.ibkr.com' + + # Mock successful API response with valid hex values + mock_response = MagicMock() + mock_response.data = { + 'live_session_token_expiration': 1234567890, + 'diffie_hellman_response': 'abc123', # Valid hex value + 'live_session_token_signature': 'lst_signature_value' + } + client.post.return_value = mock_response + + return client + + +def test_generate_request_timestamp_returns_string(): + # Arrange + + # Act + timestamp = generate_request_timestamp() + + # Assert + assert isinstance(timestamp, str) + assert timestamp.isdigit() + + +def test_generate_request_timestamp_current_time(mock_time): + # Arrange + + # Act + with patch('time.time', return_value=mock_time): + timestamp = generate_request_timestamp() + + # Assert + assert timestamp == '1234567890' + +def test_generate_oauth_nonce_length_and_chars(): + # Arrange + valid_chars = string.ascii_letters + string.digits + + # Act + nonce = generate_oauth_nonce() + + # Assert + assert isinstance(nonce, str) + assert len(nonce) == 16 + for char in nonce: + assert char in valid_chars + + +def test_generate_oauth_nonce_uniqueness(): + # Arrange + + # Act + nonces = [generate_oauth_nonce() for _ in range(100)] + unique_nonces = set(nonces) + + # Assert + assert len(nonces) == len(unique_nonces) + +def test_generate_dh_random_bytes_format(): + # Arrange + hex_pattern = re.compile(r'^[0-9a-f]+$') + + # Act + random_bytes = generate_dh_random_bytes() + + # Assert + assert isinstance(random_bytes, str) + assert hex_pattern.match(random_bytes) + + +def test_generate_dh_random_bytes_uniqueness(): + # Arrange + + # Act + random_values = [generate_dh_random_bytes() for _ in range(10)] + unique_values = set(random_values) + + # Assert + assert len(random_values) == len(unique_values) + +def test_generate_authorization_header_string_format(): + # Arrange + request_data = { + 'oauth_consumer_key': 'test_consumer_key', # noqa: S106 + 'oauth_nonce': 'test_nonce', + 'oauth_signature': 'test_signature', + 'oauth_timestamp': '1234567890', + 'oauth_token': 'test_token' + } + realm = 'limited_poa' + + # Act + header_string = generate_authorization_header_string(request_data, realm) + + # Assert + assert isinstance(header_string, str) + assert header_string.startswith('OAuth realm="limited_poa"') + for key, value in request_data.items(): + assert f'{key}="{value}"' in header_string + +def test_generate_authorization_header_string_sorting(): + # Arrange + request_data = { + 'z_last': 'last_value', + 'a_first': 'first_value', + 'm_middle': 'middle_value' + } + realm = 'test_realm' + + # Act + header_string = generate_authorization_header_string(request_data, realm) + + # Assert + expected_order = 'a_first="first_value", m_middle="middle_value", z_last="last_value"' + assert expected_order in header_string + +@pytest.fixture +def base_request_headers(): + """Create standard OAuth request headers for testing.""" + return { + 'oauth_consumer_key': 'test_consumer_key', # noqa: S106 + 'oauth_nonce': 'test_nonce', + 'oauth_timestamp': '1234567890', + 'oauth_token': 'test_token' + } + +def test_generate_base_string_basic(base_request_headers): + # Arrange + request_method = 'POST' + request_url = 'https://api.ibkr.com/v1/test' + + # Act + base_string = generate_base_string( + request_method=request_method, + request_url=request_url, + request_headers=base_request_headers + ) + + # Assert + assert isinstance(base_string, str) + assert base_string.startswith('POST&') + assert 'https%3A%2F%2Fapi.ibkr.com%2Fv1%2Ftest' in base_string + + +@pytest.mark.parametrize("data_type,data_value,expected_encoded", [ + ("request_params", {'param1': 'value1', 'param2': 'value2'}, ['param1%3Dvalue1', 'param2%3Dvalue2']), + ("request_form_data", {'form_field': 'form_value'}, ['form_field%3Dform_value']), + ("request_body", {'body_field': 'body_value'}, ['body_field%3Dbody_value']), + ("extra_headers", {'extra_header': 'extra_value'}, ['extra_header%3Dextra_value']), +]) +def test_generate_base_string_with_data(base_request_headers, data_type, data_value, expected_encoded): + # Arrange + request_method = 'POST' + request_url = 'https://api.ibkr.com/v1/test' + kwargs = {data_type: data_value} + + # Act + base_string = generate_base_string( + request_method=request_method, + request_url=request_url, + request_headers=base_request_headers, + **kwargs + ) + + # Assert + for expected in expected_encoded: + assert expected in base_string + + +def test_generate_base_string_with_prepend(base_request_headers): + # Arrange + request_method = 'POST' + request_url = 'https://api.ibkr.com/v1/test' + prepend = 'prepend_value' + + # Act + base_string = generate_base_string( + request_method=request_method, + request_url=request_url, + request_headers=base_request_headers, + prepend=prepend + ) + + # Assert + assert base_string.startswith('prepend_value') + +def test_generate_base_string_combined_parameters(base_request_headers): + # Arrange + request_method = 'POST' + request_url = 'https://api.ibkr.com/v1/test' + request_params = {'url_param': 'url_value'} + request_form_data = {'form_param': 'form_value'} + extra_headers = {'header_param': 'header_value'} + + # Act + base_string = generate_base_string( + request_method=request_method, + request_url=request_url, + request_headers=base_request_headers, + request_params=request_params, + request_form_data=request_form_data, + extra_headers=extra_headers + ) + + # Assert + assert 'url_param%3Durl_value' in base_string + assert 'form_param%3Dform_value' in base_string + assert 'header_param%3Dheader_value' in base_string + + +@patch('builtins.open', new_callable=mock_open, read_data='dummy_key_content') +@patch('ibind.oauth.oauth1a.RSA.importKey', autospec=True) +def test_read_private_key_success(mock_rsa_import, mock_file): + # Arrange + mock_key = 'mocked_rsa_key' + mock_rsa_import.return_value = mock_key + + # Act + result = read_private_key('/path/to/key.pem') + + # Assert + mock_file.assert_called_once_with('/path/to/key.pem', 'r') + mock_rsa_import.assert_called_once_with('dummy_key_content') + assert result == mock_key + +def test_generate_rsa_sha_256_signature(real_test_keys, test_crypto_data): + # Arrange + private_key = real_test_keys['private_key'] + base_string = test_crypto_data['test_string'] + + # Act + result = generate_rsa_sha_256_signature(base_string, private_key) + + # Assert + assert isinstance(result, str) + # Should be URL-encoded base64 string + assert '%' in result or result.replace('-', '+').replace('_', '/').isalnum() + + # Verify signature is deterministic for same input + result2 = generate_rsa_sha_256_signature(base_string, private_key) + assert result == result2 + +def test_generate_hmac_sha_256_signature_real_crypto(test_crypto_data): + # Arrange + base_string = test_crypto_data['test_string'] + live_session_token = test_crypto_data['test_token'] + + # Act + result = generate_hmac_sha_256_signature(base_string, live_session_token) + + # Assert + assert isinstance(result, str) + # Should be URL-encoded base64 string + assert '%' in result or result.replace('-', '+').replace('_', '/').isalnum() + + # Verify signature is deterministic for same input + result2 = generate_hmac_sha_256_signature(base_string, live_session_token) + assert result == result2 + +def test_calculate_live_session_token_prepend(real_test_keys): + # Arrange + private_key = real_test_keys['private_key'] + public_key = real_test_keys['public_key'] + + # Create real encrypted token secret + test_secret = b'test_secret_data_for_decryption' + cipher = PKCS1_v1_5_Cipher.new(public_key) + encrypted_secret = cipher.encrypt(test_secret) + access_token_secret = base64.b64encode(encrypted_secret).decode('utf-8') + + # Act + result = calculate_live_session_token_prepend(access_token_secret, private_key) + + # Assert + assert isinstance(result, str) + # Should be hex representation of decrypted secret + assert all(c in '0123456789abcdef' for c in result.lower()) + expected_hex = test_secret.hex() + assert result == expected_hex + + +def test_generate_dh_challenge_basic(): + # Arrange + dh_prime = 'ffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6df25f14374fe1356d6d51c245e485b576625e7ec6f44c42e9a637ed6b0bff5cb6f406b7edee386bfb5a899fa5ae9f24117c4b1fe649286651ece45b3dc2007cb8a163bf0598da48361c55d39a69163fa8fd24cf5f83655d23dca3ad961c62f356208552bb9ed529077096966d670c354e4abc9804f1746c08ca237327ffffffffffffffff' + dh_random = 'abcdef123456789' + dh_generator = 2 + + # Act + result = generate_dh_challenge(dh_prime, dh_random, dh_generator) + + # Assert + assert isinstance(result, str) + int(result, 16) # Should not raise ValueError + +@pytest.mark.parametrize("dh_prime,dh_random,dh_generator,description", [ + ('ff', 'a', 2, "default generator=2, random=a(10), prime=ff(255): 2^10 mod 255 = 4"), + ('ff', '2', 3, "custom generator=3, random=2, prime=ff(255): 3^2 mod 255 = 9"), +]) +def test_generate_dh_challenge_calculations(dh_prime, dh_random, dh_generator, description): + # Act + result = generate_dh_challenge(dh_prime, dh_random, dh_generator) + + # Assert - Calculate expected value based on the DH formula + dh_random_int = int(dh_random, 16) + dh_prime_int = int(dh_prime, 16) + expected = hex(pow(dh_generator, dh_random_int, dh_prime_int))[2:] + assert result == expected + + +@pytest.mark.parametrize("hex_string,expected,description", [ + ('deadbeef', [222, 173, 190, 239], "standard hex conversion"), + ('', [], "empty string returns empty list"), + ('ff', [255], "single byte"), + ('0000', [0, 0], "zeros"), +]) +def test_get_access_token_secret_bytes(hex_string, expected, description): + # Act + result = get_access_token_secret_bytes(hex_string) + + # Assert + assert result == expected + assert isinstance(result, list) + assert all(isinstance(b, int) for b in result) + +@pytest.mark.parametrize("input_value,expected,description", [ + (15, [15], "simple single byte"), + (255, [0, 255], "8-bit boundary - gets leading zero"), + (256, [1, 0], "9-bit value - no leading zero needed"), + (65535, [0, 255, 255], "16-bit boundary - gets leading zero"), +]) +def test_to_byte_array(input_value, expected, description): + # Act + result = to_byte_array(input_value) + + # Assert + assert result == expected + + +def test_validate_live_session_token(): + # Arrange - Create real live session token and signature for testing + # Create a test live session token (base64 encoded) + test_token_data = b'test_session_token_data' + live_session_token = base64.b64encode(test_token_data).decode('utf-8') + consumer_key = 'test_consumer_key' + + # Generate the real signature that the function should produce + hmac_obj = HMAC.new(test_token_data, digestmod=SHA1) + hmac_obj.update(consumer_key.encode('utf-8')) + expected_signature = hmac_obj.hexdigest() + + # Test with matching signature (should pass) + result = validate_live_session_token(live_session_token, expected_signature, consumer_key) + assert result is True + + # Test with non-matching signature (should fail validation) + wrong_signature = 'definitely_wrong_signature' + result = validate_live_session_token(live_session_token, wrong_signature, consumer_key) + assert result is False + + # Test deterministic behavior + result2 = validate_live_session_token(live_session_token, expected_signature, consumer_key) + assert result2 is True + + # Additional validation: Test with different consumer key should fail + different_consumer_key = 'different_consumer_key' + result3 = validate_live_session_token(live_session_token, expected_signature, different_consumer_key) + assert result3 is False # Should fail because consumer key is different + + +def test_calculate_live_session_token_integration(test_crypto_data): + # Arrange + dh_prime = test_crypto_data['dh_prime'] # 'ff' = 255 + dh_random_value = test_crypto_data['dh_random'] # '5' + dh_response = test_crypto_data['dh_response'] # '7' + prepend = 'deadbeef' + + # Act - Test real function composition and crypto + result = calculate_live_session_token(dh_prime, dh_random_value, dh_response, prepend) + + # Assert + assert isinstance(result, str) + # Should be base64 encoded + try: + decoded = base64.b64decode(result) + assert len(decoded) > 0 # Should decode to non-empty bytes + except Exception: + pytest.fail(f"Result '{result}' is not valid base64") + + # Verify deterministic behavior + result2 = calculate_live_session_token(dh_prime, dh_random_value, dh_response, prepend) + assert result == result2 + +@patch('time.time', return_value=1234567890, autospec=True) +@patch('secrets.choice', side_effect=lambda x: 'a', autospec=True) # Predictable nonce +@patch('builtins.open', new_callable=mock_open, read_data='dummy_key_content') +def test_generate_oauth_headers_rsa_integration(mock_file, mock_choice_func, mock_time_func, oauth_config, real_test_keys): + # Arrange + oauth_config.signature_key_fp = '/tmp/test_signature_key.pem' # noqa: S108 + + # Mock only the file read to return our real test key + with patch('ibind.oauth.oauth1a.RSA.importKey', autospec=True) as mock_rsa_import: + mock_rsa_import.return_value = real_test_keys['private_key'] + + request_method = 'POST' + request_url = 'https://api.ibkr.com/v1/test' + + # Act - Let internal functions run naturally, use real crypto + result = generate_oauth_headers( + oauth_config=oauth_config, + request_method=request_method, + request_url=request_url, + signature_method='RSA-SHA256' + ) + + # Assert + assert isinstance(result, dict) + assert 'Authorization' in result + assert 'User-Agent' in result + assert result['User-Agent'] == 'ibind' + assert result['Host'] == 'api.ibkr.com' + + # Verify authorization header contains expected elements + auth_header = result['Authorization'] + assert 'OAuth realm="limited_poa"' in auth_header + assert 'oauth_consumer_key="test_consumer_key"' in auth_header + assert 'oauth_token="test_access_token"' in auth_header + assert 'oauth_timestamp="1234567890"' in auth_header + assert 'oauth_nonce="aaaaaaaaaaaaaaaa"' in auth_header # 16 'a's from mocked choice + assert 'oauth_signature=' in auth_header + + +@patch('time.time', return_value=1234567890, autospec=True) +@patch('secrets.choice', side_effect=lambda x: 'a', autospec=True) # Predictable nonce +def test_generate_oauth_headers_hmac_integration(mock_choice_func, mock_time_func, oauth_config, test_crypto_data): + # Arrange + request_method = 'GET' + request_url = 'https://api.ibkr.com/v1/test' + live_session_token = test_crypto_data['test_token'] + + # Act - Let internal functions run naturally, use real crypto + result = generate_oauth_headers( + oauth_config=oauth_config, + request_method=request_method, + request_url=request_url, + live_session_token=live_session_token, + signature_method='HMAC-SHA256' + ) + + # Assert + assert isinstance(result, dict) + assert 'Authorization' in result + assert result['User-Agent'] == 'ibind' + + # Verify authorization header structure + auth_header = result['Authorization'] + assert 'OAuth realm="limited_poa"' in auth_header + assert 'oauth_signature=' in auth_header + + +@pytest.mark.parametrize("extra_data_type,extra_data_value", [ + ("extra_headers", {'custom_header': 'custom_value'}), + ("request_params", {'param1': 'value1', 'param2': 'value2'}), +]) +@patch('time.time', return_value=1234567890, autospec=True) +@patch('secrets.choice', side_effect=lambda x: 'a', autospec=True) # Predictable nonce +def test_generate_oauth_headers_with_extra_data_integration(mock_choice, mock_time, oauth_config, extra_data_type, extra_data_value, test_crypto_data): + # Arrange - Use real functions, only mock deterministic inputs + request_method = 'GET' + request_url = 'https://api.ibkr.com/v1/test' + kwargs = {extra_data_type: extra_data_value} + live_session_token = test_crypto_data['test_token'] + + # Act - Let all internal functions run naturally with real crypto + result = generate_oauth_headers( + oauth_config=oauth_config, + request_method=request_method, + request_url=request_url, + live_session_token=live_session_token, + signature_method='HMAC-SHA256', + **kwargs + ) + + # Assert - Test the actual behavior, not implementation details + assert isinstance(result, dict) + assert 'Authorization' in result + assert result['User-Agent'] == 'ibind' + + # Verify the authorization header is properly formed + auth_header = result['Authorization'] + assert 'OAuth realm="limited_poa"' in auth_header + assert 'oauth_signature=' in auth_header + assert 'oauth_timestamp="1234567890"' in auth_header + assert 'oauth_nonce="aaaaaaaaaaaaaaaa"' in auth_header # 16 'a's from mocked choice + + # Most importantly: verify extra data affects the signature (different signatures for different data) + # Generate another header without extra data + result_without_extra = generate_oauth_headers( + oauth_config=oauth_config, + request_method=request_method, + request_url=request_url, + live_session_token=live_session_token, + signature_method='HMAC-SHA256' + ) + + # The signatures should be different because the base string includes extra data + auth_header_without_extra = result_without_extra['Authorization'] + signature_with_extra = auth_header.split('oauth_signature="')[1].split('"')[0] + signature_without_extra = auth_header_without_extra.split('oauth_signature="')[1].split('"')[0] + assert signature_with_extra != signature_without_extra, "Extra data should affect OAuth signature" + + +@patch('secrets.randbits', return_value=0x123, autospec=True) # Deterministic randomness +@patch('builtins.open', new_callable=mock_open, read_data='dummy_key_content') +def test_prepare_oauth_integration(mock_file, mock_randbits, oauth_config, real_test_keys): + # Arrange + oauth_config.encryption_key_fp = '/tmp/encryption_key.pem' # noqa: S108 + + # Create real encrypted access token secret for testing + test_secret = b'test_decrypted_secret_for_prepend' + cipher = PKCS1_v1_5_Cipher.new(real_test_keys['public_key']) + encrypted_secret = cipher.encrypt(test_secret) + oauth_config.access_token_secret = base64.b64encode(encrypted_secret).decode('utf-8') + + # Mock RSA key import to return our real test key + with patch('ibind.oauth.oauth1a.RSA.importKey', autospec=True) as mock_rsa_import: + mock_rsa_import.return_value = real_test_keys['private_key'] + + # Act - Test real behavior with actual crypto operations + prepend, extra_headers, dh_random = prepare_oauth(oauth_config) + + # Assert + assert isinstance(prepend, str) + assert isinstance(dh_random, str) + assert isinstance(extra_headers, dict) + assert 'diffie_hellman_challenge' in extra_headers + + # Verify prepend is the hex representation of decrypted secret + assert prepend == test_secret.hex() + + # Verify dh_random is hex format + assert all(c in '0123456789abcdef' for c in dh_random.lower()) + + # Verify DH challenge is valid hex + dh_challenge = extra_headers['diffie_hellman_challenge'] + int(dh_challenge, 16) # Should not raise ValueError + + # Verify deterministic behavior + prepend2, extra_headers2, dh_random2 = prepare_oauth(oauth_config) + assert prepend == prepend2 # Same encrypted secret should give same prepend + assert dh_random == dh_random2 # Same mocked random should give same result + +@patch('secrets.randbits', return_value=0x123, autospec=True) +@patch('secrets.choice', side_effect=lambda x: 'a', autospec=True) +@patch('time.time', return_value=1234567890, autospec=True) +@patch('builtins.open', new_callable=mock_open, read_data='dummy_key_content') +def test_req_live_session_token_integration(mock_file, mock_time_func, mock_choice_func, mock_randbits_func, oauth_config, mock_client, real_test_keys): + # Arrange + oauth_config.encryption_key_fp = '/tmp/encryption_key.pem' # noqa: S108 + oauth_config.signature_key_fp = '/tmp/signature_key.pem' # noqa: S108 + + # Create real encrypted access token secret for testing + test_secret = b'test_decrypted_secret' + cipher = PKCS1_v1_5_Cipher.new(real_test_keys['public_key']) + encrypted_secret = cipher.encrypt(test_secret) + oauth_config.access_token_secret = base64.b64encode(encrypted_secret).decode('utf-8') + + with patch('ibind.oauth.oauth1a.RSA.importKey', autospec=True) as mock_rsa_import: + mock_rsa_import.return_value = real_test_keys['private_key'] + + # Act + live_session_token, lst_expires, lst_signature = req_live_session_token(mock_client, oauth_config) + + # Assert + assert isinstance(live_session_token, str) + assert lst_expires == 1234567890 + assert lst_signature == 'lst_signature_value' + + # Verify HTTP call was made + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + + # Verify endpoint was called correctly + assert call_args.args[0] == oauth_config.live_session_token_endpoint + + # Verify authorization header structure (real OAuth header generated) + auth_header = call_args.kwargs['extra_headers']['Authorization'] + assert isinstance(auth_header, str) + assert 'OAuth realm=' in auth_header + assert 'oauth_signature=' in auth_header + +@patch('secrets.randbits', return_value=0x123, autospec=True) +@patch('secrets.choice', side_effect=lambda x: 'a', autospec=True) +@patch('time.time', return_value=1234567890, autospec=True) +@patch('builtins.open', new_callable=mock_open, read_data='dummy_key_content') +def test_req_live_session_token_api_failure(mock_file, mock_time_func, mock_choice_func, mock_randbits_func, oauth_config, mock_client, real_test_keys): + # Arrange - Use real crypto behavior, only mock API failure + oauth_config.encryption_key_fp = '/tmp/encryption_key.pem' # noqa: S108 + oauth_config.signature_key_fp = '/tmp/signature_key.pem' # noqa: S108 + + # Create real encrypted access token secret for testing + test_secret = b'test_decrypted_secret' + cipher = PKCS1_v1_5_Cipher.new(real_test_keys['public_key']) + encrypted_secret = cipher.encrypt(test_secret) + oauth_config.access_token_secret = base64.b64encode(encrypted_secret).decode('utf-8') + + with patch('ibind.oauth.oauth1a.RSA.importKey', autospec=True) as mock_rsa_import: + mock_rsa_import.return_value = real_test_keys['private_key'] + + mock_client.post.side_effect = Exception('API request failed') + + # Act & Assert - Test real OAuth flow behavior with API failure + with pytest.raises(Exception, match='API request failed'): + req_live_session_token(mock_client, oauth_config) + + +@patch('secrets.randbits', return_value=0x123, autospec=True) +@patch('secrets.choice', side_effect=lambda x: 'a', autospec=True) +@patch('time.time', return_value=1234567890, autospec=True) +@patch('builtins.open', new_callable=mock_open, read_data='dummy_key_content') +def test_req_live_session_token_missing_response_data(mock_file, mock_time_func, mock_choice_func, mock_randbits_func, oauth_config, mock_client, real_test_keys): + # Arrange - Use real crypto behavior, only mock API response + oauth_config.encryption_key_fp = '/tmp/encryption_key.pem' # noqa: S108 + oauth_config.signature_key_fp = '/tmp/signature_key.pem' # noqa: S108 + + # Create real encrypted access token secret for testing + test_secret = b'test_decrypted_secret' + cipher = PKCS1_v1_5_Cipher.new(real_test_keys['public_key']) + encrypted_secret = cipher.encrypt(test_secret) + oauth_config.access_token_secret = base64.b64encode(encrypted_secret).decode('utf-8') + + with patch('ibind.oauth.oauth1a.RSA.importKey', autospec=True) as mock_rsa_import: + mock_rsa_import.return_value = real_test_keys['private_key'] + + mock_response = MagicMock() + mock_response.data = {} # Missing required fields + mock_client.post.return_value = mock_response + + # Act & Assert - Test real OAuth flow behavior with missing response data + with pytest.raises(KeyError): + req_live_session_token(mock_client, oauth_config) diff --git a/test/unit/oauth/test_oauth_base_config_u.py b/test/unit/oauth/test_oauth_base_config_u.py new file mode 100644 index 00000000..f1dd92dc --- /dev/null +++ b/test/unit/oauth/test_oauth_base_config_u.py @@ -0,0 +1,92 @@ +import pytest + +from ibind.oauth import OAuthConfig + + +class ConcreteOAuthConfig(OAuthConfig): + """Concrete implementation of OAuthConfig for testing purposes.""" + + def version(self): + return "test_version" + + +@pytest.fixture +def concrete_config(): + """Create a concrete OAuthConfig implementation for testing.""" + return ConcreteOAuthConfig( + init_oauth=True, + init_brokerage_session=False, + maintain_oauth=True, + shutdown_oauth=False + ) + + +def test_oauth_config_abstract_version_method(): + # Arrange + + # Act & Assert + with pytest.raises(TypeError, match="Can't instantiate abstract class OAuthConfig"): + OAuthConfig() + + + +def test_copy_method_creates_shallow_copy(concrete_config): + # Arrange + original_id = id(concrete_config) + + # Act + copied_config = concrete_config.copy() + + # Assert + assert id(copied_config) != original_id + assert copied_config.init_oauth == concrete_config.init_oauth + assert copied_config.init_brokerage_session == concrete_config.init_brokerage_session + assert copied_config.maintain_oauth == concrete_config.maintain_oauth + assert copied_config.shutdown_oauth == concrete_config.shutdown_oauth + + +def test_copy_method_with_modifications(concrete_config): + # Arrange + original_init_oauth = concrete_config.init_oauth + original_maintain_oauth = concrete_config.maintain_oauth + + # Act + copied_config = concrete_config.copy( + init_oauth=not original_init_oauth, + maintain_oauth=not original_maintain_oauth + ) + + # Assert + assert copied_config.init_oauth == (not original_init_oauth) + assert copied_config.maintain_oauth == (not original_maintain_oauth) + # Unchanged attributes should remain the same + assert copied_config.init_brokerage_session == concrete_config.init_brokerage_session + assert copied_config.shutdown_oauth == concrete_config.shutdown_oauth + + +def test_copy_method_with_invalid_attribute(concrete_config): + # Arrange + invalid_attribute = 'nonexistent_attribute' + + # Act & Assert + with pytest.raises(AttributeError, match=f'OAuthConfig does not have attribute "{invalid_attribute}"'): + concrete_config.copy(nonexistent_attribute='some_value') + + +def test_copy_method_with_multiple_modifications(concrete_config): + # Arrange + modifications = { + 'init_oauth': False, + 'init_brokerage_session': True, + 'maintain_oauth': False, + 'shutdown_oauth': True + } + + # Act + copied_config = concrete_config.copy(**modifications) + + # Assert + for attr, expected_value in modifications.items(): + assert getattr(copied_config, attr) == expected_value + + diff --git a/test/unit/oauth/test_oauth_config_u.py b/test/unit/oauth/test_oauth_config_u.py new file mode 100644 index 00000000..041dc0a8 --- /dev/null +++ b/test/unit/oauth/test_oauth_config_u.py @@ -0,0 +1,147 @@ +import tempfile +import pytest +from pathlib import Path + +from ibind.oauth.oauth1a import OAuth1aConfig + + +@pytest.fixture +def valid_config(): + """Create a valid OAuth1aConfig for testing.""" + return OAuth1aConfig( + oauth_rest_url='https://api.ibkr.com', + live_session_token_endpoint='/v1/api/oauth/live_session_token', # noqa: S106 + access_token='test_access_token', # noqa: S106 + access_token_secret='test_access_token_secret', # noqa: S106 + consumer_key='test_consumer_key', # noqa: S106 + dh_prime='test_dh_prime', # noqa: S106 + encryption_key_fp='/tmp/encryption_key.pem', # noqa: S108 + signature_key_fp='/tmp/signature_key.pem', # noqa: S108 + ) + +def test_version_returns_1_0a(): + # Arrange + config = OAuth1aConfig() + + # Act + result = config.version() + + # Assert + assert result == '1.0a' + +# TODO Check this test +def test_verify_config_success_with_valid_params(): + # Arrange + with tempfile.NamedTemporaryFile(mode='w', delete=False) as enc_file, tempfile.NamedTemporaryFile(mode='w', delete=False) as sig_file: + enc_file.write('dummy key content') + sig_file.write('dummy key content') + enc_file.flush() + sig_file.flush() + + config = OAuth1aConfig( + oauth_rest_url='https://api.ibkr.com', + live_session_token_endpoint='/v1/api/oauth/live_session_token', # noqa: S106 + access_token='test_access_token', # noqa: S106 + access_token_secret='test_access_token_secret', # noqa: S106 + consumer_key='test_consumer_key', # noqa: S106 + dh_prime='test_dh_prime', # noqa: S106 + encryption_key_fp=enc_file.name, + signature_key_fp=sig_file.name, + ) + + # Act + config.verify_config() + + # Assert + # No exception should be raised + + # Cleanup + Path(enc_file.name).unlink() + Path(sig_file.name).unlink() + +def test_verify_config_missing_required_params(): + # Arrange + config = OAuth1aConfig() + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + config.verify_config() + + error_message = str(exc_info.value) + assert 'OAuth1aConfig is missing required parameters:' in error_message + # Check that some expected None parameters are mentioned + expected_missing = ['access_token', 'access_token_secret', 'consumer_key', 'dh_prime'] + for param in expected_missing: + assert param in error_message + +def test_verify_config_partial_missing_params(): + # Arrange + config = OAuth1aConfig( + access_token='test_access_token', # noqa: S106 + consumer_key='test_consumer_key', # noqa: S106 + ) + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + config.verify_config() + + error_message = str(exc_info.value) + assert 'OAuth1aConfig is missing required parameters:' in error_message + # Should not contain the provided parameters (using word boundaries) + import re + assert re.search(r'\baccess_token\b', error_message) is None + assert 'consumer_key' not in error_message + # Should contain missing parameters + assert 'access_token_secret' in error_message + assert 'dh_prime' in error_message + +def test_verify_config_missing_filepaths(): + # Arrange + config = OAuth1aConfig( + oauth_rest_url='https://api.ibkr.com', + live_session_token_endpoint='/v1/api/oauth/live_session_token', # noqa: S106 + access_token='test_access_token', # noqa: S106 + access_token_secret='test_access_token_secret', # noqa: S106 + consumer_key='test_consumer_key', # noqa: S106 + dh_prime='test_dh_prime', # noqa: S106 + encryption_key_fp='/nonexistent/encryption_key.pem', + signature_key_fp='/nonexistent/signature_key.pem', + ) + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + config.verify_config() + + error_message = str(exc_info.value) + assert "OAuth1aConfig's filepaths don't exist:" in error_message + assert 'encryption_key_fp' in error_message + assert 'signature_key_fp' in error_message + +def test_verify_config_partial_missing_filepaths(): + # Arrange + with tempfile.NamedTemporaryFile(mode='w', delete=False) as enc_file: + enc_file.write('dummy key content') + enc_file.flush() + + config = OAuth1aConfig( + oauth_rest_url='https://api.ibkr.com', + live_session_token_endpoint='/v1/api/oauth/live_session_token', # noqa: S106 + access_token='test_access_token', # noqa: S106 + access_token_secret='test_access_token_secret', # noqa: S106 + consumer_key='test_consumer_key', # noqa: S106 + dh_prime='test_dh_prime', # noqa: S106 + encryption_key_fp=enc_file.name, + signature_key_fp='/nonexistent/signature_key.pem', + ) + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + config.verify_config() + + error_message = str(exc_info.value) + assert "OAuth1aConfig's filepaths don't exist:" in error_message + assert 'encryption_key_fp' not in error_message + assert 'signature_key_fp' in error_message + + # Cleanup + Path(enc_file.name).unlink() diff --git a/test/unit/support/test_logs_u.py b/test/unit/support/test_logs_u.py new file mode 100644 index 00000000..3c48104a --- /dev/null +++ b/test/unit/support/test_logs_u.py @@ -0,0 +1,316 @@ +import logging +import pytest +from unittest.mock import patch, MagicMock, mock_open + +from ibind.support.logs import ( + project_logger, + ibind_logs_initialize, + new_daily_rotating_file_handler, + DailyRotatingFileHandler, + DEFAULT_FORMAT +) + + +@pytest.fixture +def reset_logging_state(): + """Reset global logging state before and after each test.""" + # Reset global state before test + import ibind.support.logs + ibind.support.logs._initialized = False + ibind.support.logs._log_to_file = False + + # Clear any existing loggers + for logger_name in list(logging.Logger.manager.loggerDict.keys()): + if logger_name.startswith('ibind'): + logger = logging.getLogger(logger_name) + logger.handlers.clear() + logger.filters.clear() + + yield + + # Reset global state after test + ibind.support.logs._initialized = False + ibind.support.logs._log_to_file = False + + # Clear loggers again + for logger_name in list(logging.Logger.manager.loggerDict.keys()): + if logger_name.startswith('ibind'): + logger = logging.getLogger(logger_name) + logger.handlers.clear() + logger.filters.clear() + + +def test_project_logger_without_filepath(): + # Arrange + + # Act + logger = project_logger() + + # Assert + assert logger.name == 'ibind' + assert isinstance(logger, logging.Logger) + + +def test_project_logger_with_filepath(): + # Arrange + filepath = '/path/to/test_module.py' + + # Act + logger = project_logger(filepath) + + # Assert + assert logger.name == 'ibind.test_module' + assert isinstance(logger, logging.Logger) + + +@patch('ibind.support.logs.var.LOG_TO_CONSOLE', True) +@patch('ibind.support.logs.var.LOG_TO_FILE', False) +@patch('ibind.support.logs.var.LOG_LEVEL', 'DEBUG') +@patch('ibind.support.logs.var.LOG_FORMAT', DEFAULT_FORMAT) +@patch('ibind.support.logs.var.PRINT_FILE_LOGS', False) +def test_ibind_logs_initialize_console_only(reset_logging_state): + # Arrange + + # Act + ibind_logs_initialize() + + # Assert + logger = logging.getLogger('ibind') + assert logger.level == logging.DEBUG + assert len(logger.handlers) == 1 + assert isinstance(logger.handlers[0], logging.StreamHandler) + + +@patch('ibind.support.logs.var.LOG_TO_CONSOLE', False) +@patch('ibind.support.logs.var.LOG_TO_FILE', True) +@patch('ibind.support.logs.var.LOG_LEVEL', 'INFO') +@patch('ibind.support.logs.var.LOG_FORMAT', DEFAULT_FORMAT) +@patch('ibind.support.logs.var.PRINT_FILE_LOGS', False) +def test_ibind_logs_initialize_file_only(reset_logging_state): + # Arrange + + # Act + ibind_logs_initialize(log_to_console=False, log_to_file=True) + + # Assert + logger = logging.getLogger('ibind') + assert logger.level == logging.DEBUG + # Should have no console handlers when log_to_console=False + console_handlers = [h for h in logger.handlers if isinstance(h, logging.StreamHandler)] + assert len(console_handlers) == 0 + + +def test_ibind_logs_initialize_custom_parameters(reset_logging_state): + # Arrange + custom_format = '%(levelname)s - %(message)s' + + # Act + ibind_logs_initialize( + log_to_console=True, + log_to_file=False, + log_level='WARNING', + log_format=custom_format, + print_file_logs=False + ) + + # Assert + logger = logging.getLogger('ibind') + assert logger.level == logging.DEBUG + assert len(logger.handlers) == 1 + handler = logger.handlers[0] + assert handler.level == logging.WARNING + # Check formatter format string + assert handler.formatter._fmt == custom_format + + +def test_ibind_logs_initialize_idempotent(reset_logging_state): + # Arrange + + # Act + ibind_logs_initialize(log_to_console=True) + initial_handler_count = len(logging.getLogger('ibind').handlers) + + # Call again - should not add more handlers + ibind_logs_initialize(log_to_console=True) + + # Assert + final_handler_count = len(logging.getLogger('ibind').handlers) + assert initial_handler_count == final_handler_count + + +@patch('ibind.support.logs.var.LOG_TO_CONSOLE', True) +@patch('ibind.support.logs.var.LOG_TO_FILE', True) +@patch('ibind.support.logs.var.PRINT_FILE_LOGS', True) +def test_ibind_logs_initialize_with_file_and_console(reset_logging_state): + # Arrange + + # Act + ibind_logs_initialize(log_to_console=True, log_to_file=True, print_file_logs=True) + + # Assert + logger = logging.getLogger('ibind') + console_handlers = [h for h in logger.handlers if isinstance(h, logging.StreamHandler)] + assert len(console_handlers) == 1 + + # Check that file handler logger also gets console output when print_file_logs=True + fh_logger = logging.getLogger('ibind_fh') + fh_console_handlers = [h for h in fh_logger.handlers if isinstance(h, logging.StreamHandler)] + assert len(fh_console_handlers) == 1 + + +def test_ibind_logs_initialize_disables_file_logging(reset_logging_state): + # Arrange + + # Act + ibind_logs_initialize(log_to_file=False) + + # Assert + fh_logger = logging.getLogger('ibind_fh') + # Should have a filter that blocks all records + assert len(fh_logger.filters) > 0 + # Test the filter blocks records + test_record = logging.LogRecord('test', logging.INFO, 'path', 1, 'msg', (), None) + assert not fh_logger.filters[0](test_record) + + +def test_new_daily_rotating_file_handler_with_file_logging(reset_logging_state): + # Arrange + import ibind.support.logs + ibind.support.logs._log_to_file = True + logger_name = 'test_logger' + filepath = '/tmp/test.log' # noqa: S108 + + # Mock only file operations, not the handler itself + with patch('builtins.open', mock_open()), \ + patch('ibind.support.logs.Path') as mock_path: + mock_path.return_value.parent.mkdir = MagicMock() + + # Act - Test real DailyRotatingFileHandler behavior + logger = new_daily_rotating_file_handler(logger_name, filepath) + + # Assert + assert logger.name == 'ibind_fh.test_logger' + assert logger.level == logging.DEBUG + # Verify logger has real DailyRotatingFileHandler + assert len(logger.handlers) == 1 + assert isinstance(logger.handlers[0], DailyRotatingFileHandler) + assert logger.handlers[0].baseFilename == filepath + + +def test_new_daily_rotating_file_handler_without_file_logging(reset_logging_state): + # Arrange + import ibind.support.logs + ibind.support.logs._log_to_file = False + logger_name = 'test_logger' + filepath = '/tmp/test.log' # noqa: S108 + + # Act + logger = new_daily_rotating_file_handler(logger_name, filepath) + + # Assert + assert logger.name == 'ibind_fh.test_logger' + # Should have a NullHandler when file logging is disabled + null_handlers = [h for h in logger.handlers if isinstance(h, logging.NullHandler)] + assert len(null_handlers) == 1 + + +def test_new_daily_rotating_file_handler_existing_handlers(reset_logging_state): + # Arrange + import ibind.support.logs + ibind.support.logs._log_to_file = True + logger_name = 'test_logger' + filepath = '/tmp/test.log' # noqa: S108 + + # Pre-create logger with existing handler + logger = logging.getLogger('ibind_fh.test_logger') + existing_handler = logging.Handler() + logger.addHandler(existing_handler) + + # Act + result_logger = new_daily_rotating_file_handler(logger_name, filepath) + + # Assert + assert result_logger is logger + # Should not add new handlers if handlers already exist + assert len(logger.handlers) == 1 # Only the existing handler + + +def test_daily_rotating_file_handler_initialization(): + # Arrange + base_filename = '/tmp/test.log' # noqa: S108 + + # Act + with patch('builtins.open', mock_open()): + handler = DailyRotatingFileHandler(base_filename) + + # Assert + assert handler.baseFilename == base_filename + assert handler.timestamp is not None # Will be set during initialization + assert handler.date_format == '%Y-%m-%d' + + +def test_daily_rotating_file_handler_open(): + # Arrange + base_filename = '/tmp/test.log' # noqa: S108 + + # Mock only file operations and Path.mkdir, not the entire Path class + with patch('builtins.open', mock_open()) as mock_file_open, \ + patch('pathlib.Path.mkdir') as mock_mkdir: + + handler = DailyRotatingFileHandler(base_filename) + + # Test real get_timestamp behavior by using a fixed date + with patch.object(handler, 'get_timestamp', return_value='2024-01-15'): + # Act - Test real _open behavior + handler._open() + + # Assert + assert handler.timestamp == '2024-01-15' + expected_path = '/tmp/test.log__2024-01-15.txt' # noqa: S108 + + # Verify real get_filename behavior was used + assert handler.get_filename('2024-01-15') == expected_path + + # Verify directory creation and file opening + mock_mkdir.assert_called_with(parents=True, exist_ok=True) + mock_file_open.assert_called_with(expected_path, 'a', encoding='utf-8') + + +def test_daily_rotating_file_handler_emit_rotation(): + # Arrange + base_filename = '/tmp/test.log' # noqa: S108 + + with patch('builtins.open', mock_open()), \ + patch('pathlib.Path.mkdir'): + + handler = DailyRotatingFileHandler(base_filename) + handler.timestamp = '2024-01-15' # Set initial timestamp + + # Create a mock stream to simulate file being open + mock_stream = MagicMock() + handler.stream = mock_stream + + # Create test log record + record = logging.LogRecord('test', logging.INFO, 'path', 1, 'test message', (), None) + + # Test case 1: Same timestamp - no rotation + with patch.object(handler, 'get_timestamp', return_value='2024-01-15'): + handler.emit(record) + + # Should not have called close or _open + mock_stream.close.assert_not_called() + + # Test case 2: Different timestamp - should rotate + with patch.object(handler, 'get_timestamp', return_value='2024-01-16'), \ + patch.object(handler, '_open', return_value=MagicMock()) as mock_open_method: + + handler.emit(record) + + # Should have closed old file and opened new one + mock_stream.close.assert_called_once() + mock_open_method.assert_called_once() + + +def test_default_format_constant(): + # Arrange & Act & Assert + assert DEFAULT_FORMAT == '%(asctime)s|%(levelname)-.1s| %(message)s'