diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3c837ba..8039182 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,8 +2,70 @@ Changelog ========= -Unreleased (8.x) -================ +Unreleased +========== + +Version 8.1.0 (2026-05-16) +========================== + +Bug Fixes +--------- +- **Fix MQTT connection flapping after reconnect**: When ``_active_reconnect()`` + created a new ``MqttConnection``, the old connection was never closed. The old + SDK connection's built-in auto-reconnect would eventually succeed, creating two + active connections sharing the same client ID. Because AWS IoT allows only one + connection per client ID, the broker would kick one off, triggering + ``on_connection_interrupted`` and starting yet another reconnection — an + infinite connect/disconnect loop. Fixed by adding ``MqttConnection.close()`` + (unconditional teardown regardless of ``_connected`` state) and calling it + before creating the replacement connection in both ``_active_reconnect()`` and + ``_deep_reconnect()``. + +- **Thread-safety race in ``ensure_device_info_cached``**: The ``future.done()`` + check and ``future.set_result()`` were performed in the AWS SDK callback thread + without synchronisation, creating a race against the asyncio event loop thread. + Moved both operations inside a ``call_soon_threadsafe`` callback so they execute + atomically on the event loop thread. + +- **ZeroDivisionError when ``deep_reconnect_threshold`` is 0**: Config validation + now clamps ``deep_reconnect_threshold`` to a minimum of 1, preventing a + ``ZeroDivisionError`` in the exponential-backoff reconnection logic. + +- **Reconnect counter never incremented**: ``total_reconnect_attempts`` in + diagnostics was not incremented on connection drops, so it always reported 0 + despite active reconnections. Counter is now incremented on each + ``on_connection_interrupted`` event. + +- **``shortest_session_seconds`` not JSON-serialisable**: The diagnostics + ``to_dict()`` method used ``float('inf')`` as the initial value for + ``shortest_session_seconds``, which is not valid JSON. Changed to ``None`` + so serialisation succeeds when no session has completed yet. + +- **``wait_for()`` future not bound to running loop**: ``wait_for()`` created a + bare ``asyncio.Future()`` rather than + ``asyncio.get_running_loop().create_future()``, which could bind the future to + a different loop in multi-loop test setups. + +- **Reservation temperature validation was US-only**: ``build_reservation_entry`` + validated set-point temperatures against hardcoded Fahrenheit bounds (95–150 °F) + regardless of the active unit system. Validation now uses the current unit system + context: 35–65 °C in metric mode, 95–150 °F in US mode. Celsius users previously + received spurious ``ValueError`` rejections for valid temperatures. + +- **Malformed reservation data silently dropped**: ``build_reservation_entry`` now + logs a warning when reservation hex data contains unexpected trailing bytes + instead of silently dropping partial entries. + +- **Unknown ``PeriodicRequestType`` silently ignored**: The periodic-request handler + now logs an error and breaks when it encounters an unknown request type instead of + doing nothing. + +- **Memory leak in device info cache**: ``get_all_cached()`` only filtered expired + entries from its return value but left them in the cache dictionary. Expired + entries are now evicted during ``get_all_cached()`` to prevent unbounded growth. + +Version 8.0.0 (2026-05-13) +=========================== Bug Fixes --------- diff --git a/scripts/bump_version.py b/scripts/bump_version.py index 800ee0f..2e37224 100755 --- a/scripts/bump_version.py +++ b/scripts/bump_version.py @@ -22,6 +22,8 @@ import re import subprocess import sys +from datetime import date +from pathlib import Path def run_git_command(args: list) -> str: @@ -140,6 +142,75 @@ def check_working_directory_clean() -> None: sys.exit(1) +def update_changelog(version: str) -> None: + """Insert a version heading into CHANGELOG.rst below the Unreleased section. + + Transforms: + + Unreleased + ========== + + + + into: + + Unreleased + ========== + + Version X.Y.Z (YYYY-MM-DD) + =========================== + + + """ + changelog_path = Path("CHANGELOG.rst") + if not changelog_path.exists(): + print("Warning: CHANGELOG.rst not found, skipping changelog update.") + return + + content = changelog_path.read_text(encoding="utf-8") + + heading = f"Version {version} ({date.today().isoformat()})" + underline = "=" * len(heading) + version_block = f"{heading}\n{underline}\n" + + # Match "Unreleased\n==========\n" (any number of = signs) followed by + # one or more blank lines, then insert the version block after them. + pattern = re.compile( + r"(Unreleased\n=+\n)" # group 1: Unreleased heading + r"(\n+)", # group 2: blank line(s) separator + re.MULTILINE, + ) + + match = pattern.search(content) + if not match: + print( + "Warning: Could not find 'Unreleased' section in CHANGELOG.rst. " + "Skipping changelog update.", + file=sys.stderr, + ) + return + + # Insert the version block after the blank lines that follow "Unreleased" + new_content = ( + content[: match.end()] + + version_block + + "\n" + + content[match.end() :] + ) + + changelog_path.write_text(new_content, encoding="utf-8") + print(f"[OK] Updated CHANGELOG.rst with {heading}") + + +def commit_changelog(version: str) -> None: + """Stage and commit the CHANGELOG.rst update.""" + run_git_command(["add", "CHANGELOG.rst"]) + run_git_command( + ["commit", "-m", f"Update changelog for v{version}"] + ) + print("[OK] Committed changelog update") + + def create_tag(version: str, message: str = None) -> None: """Create a git tag for the version.""" tag_name = f"v{version}" @@ -223,6 +294,11 @@ def main() -> None: # Validate version progression validate_version_progression(current_version, new_version) + # Update CHANGELOG.rst and commit, then create the tag + print("\nUpdating CHANGELOG.rst...") + update_changelog(new_version) + commit_changelog(new_version) + # Create the tag print(f"\nCreating tag v{new_version}...") create_tag(new_version) @@ -230,6 +306,7 @@ def main() -> None: print("\n[OK] Version bump complete!") print("\nNext steps:") print(f" 1. Push the tag: git push origin v{new_version}") + print(" (also push the changelog commit: git push origin HEAD)") print(" 2. Build release: make build") print(" 3. Test on TestPyPI: make publish-test") print(" 4. Publish to PyPI: make publish") diff --git a/src/nwp500/device_info_cache.py b/src/nwp500/device_info_cache.py index f131f2b..a1f5865 100644 --- a/src/nwp500/device_info_cache.py +++ b/src/nwp500/device_info_cache.py @@ -137,12 +137,15 @@ async def get_all_cached(self) -> dict[str, DeviceFeature]: Dictionary mapping MAC addresses to DeviceFeature objects """ async with self._lock: - # Filter out expired entries - return { - mac: features - for mac, (features, timestamp) in self._cache.items() - if not self.is_expired(timestamp) - } + # Filter out expired entries and purge them from cache + expired_keys = [ + mac + for mac, (_, timestamp) in self._cache.items() + if self.is_expired(timestamp) + ] + for mac in expired_keys: + del self._cache[mac] + return {mac: features for mac, (features, _) in self._cache.items()} async def get_cache_info( self, diff --git a/src/nwp500/encoding.py b/src/nwp500/encoding.py index 5f783dd..2566c21 100644 --- a/src/nwp500/encoding.py +++ b/src/nwp500/encoding.py @@ -8,11 +8,14 @@ from __future__ import annotations +import logging from collections.abc import Iterable from numbers import Real from .exceptions import ParameterValidationError, RangeValidationError +_logger = logging.getLogger(__name__) + # MGPP Week Bitfield Encoding (from NaviLink APK KDEnum.MgppReservationWeek). # Uses a single byte where bits 1-7 represent days; bit 0 is unused. # @@ -342,14 +345,18 @@ def decode_reservation_hex(hex_string: str) -> list[dict[str, int]]: data = bytes.fromhex(hex_string) reservations = [] + if len(data) % 6 != 0: + _logger.warning( + "Reservation hex data length %d is not a multiple of 6; " + "trailing %d bytes will be ignored", + len(data), + len(data) % 6, + ) + # Process 6 bytes at a time - for i in range(0, len(data), 6): + for i in range(0, len(data) - (len(data) % 6), 6): chunk = data[i : i + 6] - # Ensure we have a full 6-byte entry - if len(chunk) != 6: - break - # Skip empty entries (all zeros) if all(b == 0 for b in chunk): continue @@ -425,11 +432,26 @@ def build_reservation_entry( """ # Import here to avoid circular import from .models import preferred_to_half_celsius + from .unit_system import get_unit_system + + # Read unit system once to keep min/max bounds consistent + unit_system = get_unit_system() # Use device-provided limits if available, otherwise use defaults - # Defaults are conservative: 95°F / 35°C minimum, 150°F / 65°C maximum - min_temp = temperature_min if temperature_min is not None else 95 - max_temp = temperature_max if temperature_max is not None else 150 + # in the user's preferred unit system. + if temperature_min is not None: + min_temp = temperature_min + elif unit_system == "metric": + min_temp = 35.0 # ~35°C + else: + min_temp = 95.0 # 95°F + + if temperature_max is not None: + max_temp = temperature_max + elif unit_system == "metric": + max_temp = 65.0 # ~65°C + else: + max_temp = 150.0 # 150°F if not 0 <= hour <= 23: raise RangeValidationError( diff --git a/src/nwp500/events.py b/src/nwp500/events.py index 5101669..b6dc46e 100644 --- a/src/nwp500/events.py +++ b/src/nwp500/events.py @@ -396,7 +396,7 @@ async def wait_for( current_temp = temperature_event.new_temperature """ future: asyncio.Future[tuple[tuple[Any, ...], dict[str, Any]]] = ( - asyncio.Future() + asyncio.get_running_loop().create_future() ) def handler(*args: Any, **kwargs: Any) -> None: diff --git a/src/nwp500/mqtt/client.py b/src/nwp500/mqtt/client.py index 3a3a6f1..6a331a3 100644 --- a/src/nwp500/mqtt/client.py +++ b/src/nwp500/mqtt/client.py @@ -371,7 +371,10 @@ async def _active_reconnect(self) -> None: reconnect instead of passively waiting for AWS IoT SDK. Note: This creates a new connection while preserving subscriptions - and configuration. + and configuration. The old connection is closed first to prevent + its SDK auto-reconnect from creating a competing connection with + the same client ID (which causes the broker to kick one off, + leading to an infinite connect/disconnect loop). """ if self._connected: _logger.debug("Already connected, skipping reconnection") @@ -385,12 +388,15 @@ async def _active_reconnect(self) -> None: # If we have a connection manager, try to reconnect using it if self._connection_manager: - # The connection might be in a bad state, so we need to - # recreate the underlying connection + # Close old connection to stop SDK auto-reconnect and + # prevent two connections with the same client ID. _logger.debug("Recreating MQTT connection...") + try: + await self._connection_manager.close() + except (AwsCrtError, RuntimeError) as e: + _logger.debug(f"Old connection cleanup (benign): {e}") # Create a new connection manager with same config - old_connection_manager = self._connection_manager self._connection_manager = MqttConnection( config=self.config, auth_client=self._auth_client, @@ -415,9 +421,6 @@ async def _active_reconnect(self) -> None: _logger.info("Active reconnection successful") else: - # Restore old connection manager and connection reference - self._connection_manager = old_connection_manager - self._connection = old_connection_manager.connection _logger.warning("Active reconnection failed") else: _logger.warning( @@ -458,8 +461,7 @@ async def _deep_reconnect(self) -> None: if self._connection_manager: _logger.debug("Cleaning up old connection...") try: - if self._connection_manager.is_connected: - await self._connection_manager.disconnect() + await self._connection_manager.close() except (AwsCrtError, RuntimeError) as e: # Expected: connection already dead or in bad state _logger.debug(f"Error during cleanup: {e} (expected)") @@ -1294,14 +1296,20 @@ async def ensure_device_info_cached( return True # Not cached, request and wait - future: asyncio.Future[DeviceFeature] = ( - asyncio.get_running_loop().create_future() - ) + loop = asyncio.get_running_loop() + future: asyncio.Future[DeviceFeature] = loop.create_future() def on_feature(feature: DeviceFeature) -> None: - if not future.done(): - _logger.info(f"Device feature received for {redacted_mac}") - future.set_result(feature) + # Called from AWS SDK thread — schedule onto the event loop + # thread-safely. The done() check is inside the scheduled + # callback so it runs on the event loop thread, eliminating + # the race between the check and set_result. + def _set_result() -> None: + if not future.done(): + _logger.info(f"Device feature received for {redacted_mac}") + future.set_result(feature) + + loop.call_soon_threadsafe(_set_result) _logger.info(f"Ensuring device info cached for {redacted_mac}") await self.subscribe_device_feature(device, on_feature) diff --git a/src/nwp500/mqtt/connection.py b/src/nwp500/mqtt/connection.py index 4493ddb..f48ecc8 100644 --- a/src/nwp500/mqtt/connection.py +++ b/src/nwp500/mqtt/connection.py @@ -245,6 +245,44 @@ async def disconnect(self) -> None: _logger.error(f"Error during disconnect: {e}") raise + async def close(self) -> None: + """Unconditionally close the underlying SDK connection. + + Unlike :meth:`disconnect`, this method closes the connection + regardless of the ``_connected`` flag. After a connection + interruption, ``_connected`` is ``False`` but the SDK connection + object is still alive and its built-in auto-reconnect can still + fire. Calling ``close()`` ensures the SDK connection is fully + torn down so its callbacks and auto-reconnect cannot interfere + with a replacement connection. + + This method is safe to call multiple times or on already-closed + connections. + """ + connection = self._connection + self._connection = None + self._connected = False + + if connection is None: + return + + _logger.debug("Closing underlying SDK connection...") + try: + disconnect_future = cast( + asyncio.Future[Any], connection.disconnect() + ) + await asyncio.shield(asyncio.wrap_future(disconnect_future)) + _logger.debug("SDK connection closed") + except (AwsCrtError, RuntimeError) as e: + # Expected when connection is already dead or in bad state + _logger.debug(f"SDK connection close (benign): {e}") + except asyncio.CancelledError: + _logger.debug( + "Close operation cancelled but SDK disconnect " + "will complete in background" + ) + raise + async def subscribe( self, topic: str, diff --git a/src/nwp500/mqtt/diagnostics.py b/src/nwp500/mqtt/diagnostics.py index d5ca647..4df6a32 100644 --- a/src/nwp500/mqtt/diagnostics.py +++ b/src/nwp500/mqtt/diagnostics.py @@ -95,7 +95,11 @@ class MqttMetrics: def to_dict(self) -> dict[str, Any]: """Convert to dictionary for JSON serialization.""" - return asdict(self) + d = asdict(self) + # Replace inf with None for JSON compatibility + if d.get("shortest_session_seconds") == float("inf"): + d["shortest_session_seconds"] = None + return d class MqttDiagnosticsCollector: @@ -213,6 +217,7 @@ async def record_connection_drop( # Update metrics self._metrics.total_connection_drops += 1 + self._metrics.total_reconnect_attempts += 1 if error_name: self._metrics.connection_drops_by_error[error_name] = ( self._metrics.connection_drops_by_error.get(error_name, 0) + 1 diff --git a/src/nwp500/mqtt/periodic.py b/src/nwp500/mqtt/periodic.py index 24b6f7d..013b3ea 100644 --- a/src/nwp500/mqtt/periodic.py +++ b/src/nwp500/mqtt/periodic.py @@ -173,6 +173,12 @@ async def periodic_request() -> None: await self._request_device_info(device) elif request_type == PeriodicRequestType.DEVICE_STATUS: await self._request_device_status(device) + else: + _logger.error( + "Unknown periodic request type: %s", + request_type, + ) + break _logger.debug( "Sent periodic %s request for %s", diff --git a/src/nwp500/mqtt/utils.py b/src/nwp500/mqtt/utils.py index 9d677bf..8e126e3 100644 --- a/src/nwp500/mqtt/utils.py +++ b/src/nwp500/mqtt/utils.py @@ -233,11 +233,13 @@ class MqttConnectionConfig: max_queued_commands: int = 100 def __post_init__(self) -> None: - """Generate client ID if not provided.""" + """Generate client ID if not provided and validate settings.""" if not self.client_id: object.__setattr__( self, "client_id", f"navien-client-{uuid.uuid4().hex[:8]}" ) + if self.deep_reconnect_threshold < 1: + object.__setattr__(self, "deep_reconnect_threshold", 1) @dataclass diff --git a/tests/test_bug_fixes.py b/tests/test_bug_fixes.py new file mode 100644 index 0000000..08765ea --- /dev/null +++ b/tests/test_bug_fixes.py @@ -0,0 +1,198 @@ +"""Tests for bug fixes: diagnostics, config validation, encoding, cache.""" + +from __future__ import annotations + +import asyncio +import json +from unittest.mock import patch + +import pytest + +from nwp500.encoding import build_reservation_entry, decode_reservation_hex +from nwp500.events import EventEmitter +from nwp500.mqtt.diagnostics import MqttDiagnosticsCollector, MqttMetrics +from nwp500.mqtt.utils import MqttConnectionConfig + + +class TestMqttConnectionConfigValidation: + """Tests for MqttConnectionConfig validation.""" + + def test_deep_reconnect_threshold_zero_clamped(self): + """deep_reconnect_threshold=0 should be clamped to 1.""" + config = MqttConnectionConfig(deep_reconnect_threshold=0) + assert config.deep_reconnect_threshold == 1 + + def test_deep_reconnect_threshold_negative_clamped(self): + """Negative deep_reconnect_threshold should be clamped to 1.""" + config = MqttConnectionConfig(deep_reconnect_threshold=-5) + assert config.deep_reconnect_threshold == 1 + + def test_deep_reconnect_threshold_valid_preserved(self): + """Valid deep_reconnect_threshold should be preserved.""" + config = MqttConnectionConfig(deep_reconnect_threshold=5) + assert config.deep_reconnect_threshold == 5 + + def test_default_deep_reconnect_threshold(self): + """Default deep_reconnect_threshold should be 10.""" + config = MqttConnectionConfig() + assert config.deep_reconnect_threshold == 10 + + +class TestDiagnosticsReconnectCounter: + """Tests for total_reconnect_attempts counter.""" + + @pytest.mark.asyncio(loop_scope="function") + async def test_reconnect_attempts_incremented_on_drop(self): + """total_reconnect_attempts should increment on each drop.""" + collector = MqttDiagnosticsCollector() + assert collector._metrics.total_reconnect_attempts == 0 + + await collector.record_connection_drop(error=RuntimeError("test")) + assert collector._metrics.total_reconnect_attempts == 1 + + await collector.record_connection_drop(error=RuntimeError("test2")) + assert collector._metrics.total_reconnect_attempts == 2 + + @pytest.mark.asyncio(loop_scope="function") + async def test_drop_increments_both_counters(self): + """Both total_connection_drops and total_reconnect_attempts update.""" + collector = MqttDiagnosticsCollector() + await collector.record_connection_drop(error=RuntimeError("test")) + assert collector._metrics.total_connection_drops == 1 + assert collector._metrics.total_reconnect_attempts == 1 + + +class TestMqttMetricsSerialization: + """Tests for MqttMetrics JSON serialization.""" + + def test_to_dict_replaces_inf(self): + """to_dict should replace inf with None for JSON compatibility.""" + metrics = MqttMetrics() + d = metrics.to_dict() + assert d["shortest_session_seconds"] is None + + def test_to_dict_preserves_real_value(self): + """to_dict should preserve real shortest_session_seconds values.""" + metrics = MqttMetrics(shortest_session_seconds=42.5) + d = metrics.to_dict() + assert d["shortest_session_seconds"] == 42.5 + + def test_to_dict_json_serializable(self): + """Default MqttMetrics should be JSON-serializable.""" + metrics = MqttMetrics() + # Should not raise + result = json.dumps(metrics.to_dict()) + assert isinstance(result, str) + + +class TestEventEmitterFuture: + """Tests for EventEmitter.wait_for future creation.""" + + @pytest.mark.asyncio(loop_scope="function") + async def test_wait_for_uses_running_loop(self): + """wait_for should create future bound to running loop.""" + emitter = EventEmitter() + + async def emit_soon(): + await asyncio.sleep(0.01) + await emitter.emit("test_event", "data") + + asyncio.create_task(emit_soon()) + result = await emitter.wait_for("test_event", timeout=1.0) + assert result == ("data",) + + +class TestDecodeReservationHex: + """Tests for decode_reservation_hex partial data handling.""" + + def test_valid_hex_decoded(self): + """Valid 6-byte entries should be decoded.""" + result = decode_reservation_hex("013e061e0478") + assert len(result) == 1 + assert result[0]["enable"] == 1 + + def test_partial_entry_logged_and_skipped(self): + """Partial trailing bytes should be skipped with warning.""" + # 6 valid bytes + 3 trailing bytes + result = decode_reservation_hex("013e061e0478aabbcc") + assert len(result) == 1 # Only the complete entry + + def test_empty_hex(self): + """Empty hex string should return empty list.""" + assert decode_reservation_hex("") == [] + + +class TestBuildReservationEntryTempValidation: + """Tests for unit-aware temperature validation.""" + + @patch("nwp500.unit_system.get_unit_system", return_value="metric") + def test_celsius_defaults(self, _mock_unit): + """Metric mode should use Celsius defaults (35-65).""" + # 50°C is valid in metric mode + result = build_reservation_entry( + enabled=True, + days=["Monday"], + hour=6, + minute=30, + mode_id=3, + temperature=50.0, + ) + assert "param" in result + + @patch("nwp500.unit_system.get_unit_system", return_value="metric") + def test_celsius_rejects_fahrenheit_values(self, _mock_unit): + """Values outside Celsius range should be rejected in metric mode.""" + from nwp500.exceptions import RangeValidationError + + with pytest.raises(RangeValidationError): + build_reservation_entry( + enabled=True, + days=["Monday"], + hour=6, + minute=30, + mode_id=3, + temperature=140.0, # Fahrenheit value, too high for Celsius + ) + + @patch("nwp500.unit_system.get_unit_system", return_value="us_customary") + def test_fahrenheit_defaults(self, _mock_unit): + """US customary mode should use Fahrenheit defaults (95-150).""" + result = build_reservation_entry( + enabled=True, + days=["Monday"], + hour=6, + minute=30, + mode_id=3, + temperature=140.0, + ) + assert "param" in result + + @patch("nwp500.unit_system.get_unit_system", return_value="us_customary") + def test_fahrenheit_rejects_low_celsius(self, _mock_unit): + """Values outside Fahrenheit range should be rejected in US mode.""" + from nwp500.exceptions import RangeValidationError + + with pytest.raises(RangeValidationError): + build_reservation_entry( + enabled=True, + days=["Monday"], + hour=6, + minute=30, + mode_id=3, + temperature=50.0, # Celsius value, too low for Fahrenheit + ) + + @patch("nwp500.unit_system.get_unit_system", return_value="metric") + def test_explicit_limits_override_defaults(self, _mock_unit): + """Explicit temperature_min/max should override unit defaults.""" + result = build_reservation_entry( + enabled=True, + days=["Monday"], + hour=6, + minute=30, + mode_id=3, + temperature=80.0, + temperature_min=70.0, + temperature_max=90.0, + ) + assert "param" in result diff --git a/tests/test_mqtt_reconnection.py b/tests/test_mqtt_reconnection.py new file mode 100644 index 0000000..d5f0382 --- /dev/null +++ b/tests/test_mqtt_reconnection.py @@ -0,0 +1,130 @@ +"""Tests for MQTT reconnection: old connection cleanup.""" + +from __future__ import annotations + +import concurrent.futures +from unittest.mock import MagicMock + +import pytest + +from nwp500.auth import AuthenticationResponse, AuthTokens, UserInfo +from nwp500.mqtt.connection import MqttConnection +from nwp500.mqtt.utils import MqttConnectionConfig + + +@pytest.fixture +def mock_auth_client(): + """Create a mock auth client with valid tokens.""" + from nwp500.auth import NavienAuthClient + + client = NavienAuthClient("test@example.com", "password") + valid_tokens = AuthTokens( + id_token="test_id", + access_token="test_access", + refresh_token="test_refresh", + authentication_expires_in=3600, + access_key_id="test_key_id", + secret_key="test_secret_key", + session_token="test_session", + authorization_expires_in=3600, + ) + client._auth_response = AuthenticationResponse( + user_info=UserInfo(user_first_name="Test", user_last_name="User"), + tokens=valid_tokens, + ) + return client + + +@pytest.fixture +def config(): + return MqttConnectionConfig(client_id="test-client") + + +def _make_cf_future(result=None, exception=None): + """Return a resolved concurrent.futures.Future matching SDK behaviour.""" + f = concurrent.futures.Future() + if exception is not None: + f.set_exception(exception) + else: + f.set_result(result) + return f + + +class TestMqttConnectionClose: + """Tests for MqttConnection.close() method.""" + + @pytest.mark.asyncio(loop_scope="function") + async def test_close_on_none_connection(self, config, mock_auth_client): + """close() should be safe to call when _connection is None.""" + conn = MqttConnection(config, mock_auth_client) + assert conn._connection is None + # Should not raise + await conn.close() + assert conn._connected is False + assert conn._connection is None + + @pytest.mark.asyncio(loop_scope="function") + async def test_close_clears_state(self, config, mock_auth_client): + """close() should clear _connection and _connected regardless. + + Uses concurrent.futures.Future to match what the AWS IoT SDK + returns from disconnect(), exercising the wrap_future() path. + """ + conn = MqttConnection(config, mock_auth_client) + mock_sdk_conn = MagicMock() + mock_sdk_conn.disconnect.return_value = _make_cf_future() + conn._connection = mock_sdk_conn + conn._connected = False # Interrupted state + + await conn.close() + + assert conn._connection is None + assert conn._connected is False + mock_sdk_conn.disconnect.assert_called_once() + + @pytest.mark.asyncio(loop_scope="function") + async def test_disconnect_skips_when_not_connected( + self, config, mock_auth_client + ): + """disconnect() should skip when _connected is False.""" + conn = MqttConnection(config, mock_auth_client) + mock_sdk_conn = MagicMock() + conn._connection = mock_sdk_conn + conn._connected = False # Interrupted state + + await conn.disconnect() + + # disconnect() should NOT call the SDK disconnect + mock_sdk_conn.disconnect.assert_not_called() + + @pytest.mark.asyncio(loop_scope="function") + async def test_close_handles_already_dead_connection( + self, config, mock_auth_client + ): + """close() should handle errors from SDK disconnect gracefully.""" + from awscrt.exceptions import AwsCrtError + + conn = MqttConnection(config, mock_auth_client) + mock_sdk_conn = MagicMock() + mock_sdk_conn.disconnect.return_value = _make_cf_future( + exception=AwsCrtError( + code=0, + name="AWS_ERROR_MQTT_CONNECTION_DESTROYED", + message="Connection destroyed", + ) + ) + conn._connection = mock_sdk_conn + conn._connected = False + + # Should not raise + await conn.close() + assert conn._connection is None + + @pytest.mark.asyncio(loop_scope="function") + async def test_close_idempotent(self, config, mock_auth_client): + """close() should be safe to call multiple times.""" + conn = MqttConnection(config, mock_auth_client) + # Call twice - should not raise + await conn.close() + await conn.close() + assert conn._connection is None