From 72f68cec03e41aa292d75e3c3d3be55608e771fd Mon Sep 17 00:00:00 2001 From: Emmanuel Levijarvi Date: Sat, 16 May 2026 14:11:29 -0700 Subject: [PATCH 1/8] fix: close old MQTT connection before reconnect to prevent flapping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When _active_reconnect() creates a new MqttConnection, the old connection was never disconnected. The old SDK connection's built-in auto-reconnect would eventually succeed, creating two active connections with the same client ID. AWS IoT only allows one connection per client ID, so the broker kicks one off, triggering on_connection_interrupted and starting another reconnection — causing an infinite connect/disconnect loop. Changes: - Add MqttConnection.close() method that unconditionally tears down the underlying SDK connection regardless of _connected state (unlike disconnect() which skips when _connected is False after interruption) - _active_reconnect(): close old connection before creating replacement - _deep_reconnect(): use close() unconditionally instead of checking is_connected before calling disconnect() - Add tests for close() method behavior Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/nwp500/mqtt/client.py | 22 +++--- src/nwp500/mqtt/connection.py | 38 +++++++++++ tests/test_mqtt_reconnection.py | 115 ++++++++++++++++++++++++++++++++ 3 files changed, 166 insertions(+), 9 deletions(-) create mode 100644 tests/test_mqtt_reconnection.py diff --git a/src/nwp500/mqtt/client.py b/src/nwp500/mqtt/client.py index 3a3a6f1..cd3dba2 100644 --- a/src/nwp500/mqtt/client.py +++ b/src/nwp500/mqtt/client.py @@ -371,7 +371,10 @@ async def _active_reconnect(self) -> None: reconnect instead of passively waiting for AWS IoT SDK. Note: This creates a new connection while preserving subscriptions - and configuration. + and configuration. 
The old connection is closed first to prevent + its SDK auto-reconnect from creating a competing connection with + the same client ID (which causes the broker to kick one off, + leading to an infinite connect/disconnect loop). """ if self._connected: _logger.debug("Already connected, skipping reconnection") @@ -385,12 +388,17 @@ async def _active_reconnect(self) -> None: # If we have a connection manager, try to reconnect using it if self._connection_manager: - # The connection might be in a bad state, so we need to - # recreate the underlying connection + # Close old connection to stop SDK auto-reconnect and + # prevent two connections with the same client ID. _logger.debug("Recreating MQTT connection...") + try: + await self._connection_manager.close() + except (AwsCrtError, RuntimeError) as e: + _logger.debug( + f"Old connection cleanup (benign): {e}" + ) # Create a new connection manager with same config - old_connection_manager = self._connection_manager self._connection_manager = MqttConnection( config=self.config, auth_client=self._auth_client, @@ -415,9 +423,6 @@ async def _active_reconnect(self) -> None: _logger.info("Active reconnection successful") else: - # Restore old connection manager and connection reference - self._connection_manager = old_connection_manager - self._connection = old_connection_manager.connection _logger.warning("Active reconnection failed") else: _logger.warning( @@ -458,8 +463,7 @@ async def _deep_reconnect(self) -> None: if self._connection_manager: _logger.debug("Cleaning up old connection...") try: - if self._connection_manager.is_connected: - await self._connection_manager.disconnect() + await self._connection_manager.close() except (AwsCrtError, RuntimeError) as e: # Expected: connection already dead or in bad state _logger.debug(f"Error during cleanup: {e} (expected)") diff --git a/src/nwp500/mqtt/connection.py b/src/nwp500/mqtt/connection.py index 4493ddb..f48ecc8 100644 --- a/src/nwp500/mqtt/connection.py +++ 
b/src/nwp500/mqtt/connection.py @@ -245,6 +245,44 @@ async def disconnect(self) -> None: _logger.error(f"Error during disconnect: {e}") raise + async def close(self) -> None: + """Unconditionally close the underlying SDK connection. + + Unlike :meth:`disconnect`, this method closes the connection + regardless of the ``_connected`` flag. After a connection + interruption, ``_connected`` is ``False`` but the SDK connection + object is still alive and its built-in auto-reconnect can still + fire. Calling ``close()`` ensures the SDK connection is fully + torn down so its callbacks and auto-reconnect cannot interfere + with a replacement connection. + + This method is safe to call multiple times or on already-closed + connections. + """ + connection = self._connection + self._connection = None + self._connected = False + + if connection is None: + return + + _logger.debug("Closing underlying SDK connection...") + try: + disconnect_future = cast( + asyncio.Future[Any], connection.disconnect() + ) + await asyncio.shield(asyncio.wrap_future(disconnect_future)) + _logger.debug("SDK connection closed") + except (AwsCrtError, RuntimeError) as e: + # Expected when connection is already dead or in bad state + _logger.debug(f"SDK connection close (benign): {e}") + except asyncio.CancelledError: + _logger.debug( + "Close operation cancelled but SDK disconnect " + "will complete in background" + ) + raise + async def subscribe( self, topic: str, diff --git a/tests/test_mqtt_reconnection.py b/tests/test_mqtt_reconnection.py new file mode 100644 index 0000000..50be9af --- /dev/null +++ b/tests/test_mqtt_reconnection.py @@ -0,0 +1,115 @@ +"""Tests for MQTT reconnection: old connection cleanup.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nwp500.auth import AuthTokens, AuthenticationResponse, UserInfo +from nwp500.mqtt.connection import MqttConnection +from nwp500.mqtt.utils import 
MqttConnectionConfig + + +@pytest.fixture +def mock_auth_client(): + """Create a mock auth client with valid tokens.""" + from nwp500.auth import NavienAuthClient + + client = NavienAuthClient("test@example.com", "password") + valid_tokens = AuthTokens( + id_token="test_id", + access_token="test_access", + refresh_token="test_refresh", + authentication_expires_in=3600, + access_key_id="test_key_id", + secret_key="test_secret_key", + session_token="test_session", + authorization_expires_in=3600, + ) + client._auth_response = AuthenticationResponse( + user_info=UserInfo(user_first_name="Test", user_last_name="User"), + tokens=valid_tokens, + ) + return client + + +@pytest.fixture +def config(): + return MqttConnectionConfig(client_id="test-client") + + +class TestMqttConnectionClose: + """Tests for MqttConnection.close() method.""" + + def test_close_on_none_connection(self, config, mock_auth_client): + """close() should be safe to call when _connection is None.""" + conn = MqttConnection(config, mock_auth_client) + assert conn._connection is None + # Should not raise + asyncio.get_event_loop().run_until_complete(conn.close()) + assert conn._connected is False + assert conn._connection is None + + def test_close_clears_state(self, config, mock_auth_client): + """close() should clear _connection and _connected regardless.""" + conn = MqttConnection(config, mock_auth_client) + # Simulate a connection that was interrupted (_connected=False + # but _connection still exists) + mock_sdk_conn = MagicMock() + future = asyncio.Future() + future.set_result(None) + mock_sdk_conn.disconnect.return_value = future + conn._connection = mock_sdk_conn + conn._connected = False # Interrupted state + + asyncio.get_event_loop().run_until_complete(conn.close()) + + assert conn._connection is None + assert conn._connected is False + mock_sdk_conn.disconnect.assert_called_once() + + def test_disconnect_skips_when_not_connected(self, config, mock_auth_client): + """disconnect() should skip 
when _connected is False (existing behavior).""" + conn = MqttConnection(config, mock_auth_client) + mock_sdk_conn = MagicMock() + conn._connection = mock_sdk_conn + conn._connected = False # Interrupted state + + asyncio.get_event_loop().run_until_complete(conn.disconnect()) + + # disconnect() should NOT call the SDK disconnect + mock_sdk_conn.disconnect.assert_not_called() + + def test_close_handles_already_dead_connection( + self, config, mock_auth_client + ): + """close() should handle errors from SDK disconnect gracefully.""" + from awscrt.exceptions import AwsCrtError + + conn = MqttConnection(config, mock_auth_client) + mock_sdk_conn = MagicMock() + future = asyncio.Future() + future.set_exception( + AwsCrtError( + code=0, + name="AWS_ERROR_MQTT_CONNECTION_DESTROYED", + message="Connection destroyed", + ) + ) + mock_sdk_conn.disconnect.return_value = future + conn._connection = mock_sdk_conn + conn._connected = False + + # Should not raise + asyncio.get_event_loop().run_until_complete(conn.close()) + assert conn._connection is None + + def test_close_idempotent(self, config, mock_auth_client): + """close() should be safe to call multiple times.""" + conn = MqttConnection(config, mock_auth_client) + # Call twice - should not raise + asyncio.get_event_loop().run_until_complete(conn.close()) + asyncio.get_event_loop().run_until_complete(conn.close()) + assert conn._connection is None From e8944bcdb70993ba692f96732a6937d4db917214 Mon Sep 17 00:00:00 2001 From: Emmanuel Levijarvi Date: Sat, 16 May 2026 14:24:53 -0700 Subject: [PATCH 2/8] fix: multiple bug fixes from comprehensive code audit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Thread safety: - ensure_device_info_cached: use loop.call_soon_threadsafe for Future resolution from AWS SDK callback thread (prevents race/crash) MQTT reconnection: - Clamp deep_reconnect_threshold to minimum of 1 in config validation to prevent ZeroDivisionError in reconnection backoff 
logic Diagnostics: - Increment total_reconnect_attempts counter on each connection drop (was always 0 despite reconnections occurring) - Replace float('inf') default for shortest_session_seconds with None in to_dict() to prevent JSON serialization errors Events: - Use asyncio.get_running_loop().create_future() instead of bare asyncio.Future() in wait_for() for proper loop binding Encoding: - Make build_reservation_entry temperature validation unit-aware: defaults are now 35-65°C in metric mode, 95-150°F in US mode (was hardcoded to Fahrenheit, breaking Celsius users) - Log warning on malformed reservation hex data with trailing bytes instead of silently dropping partial entries Periodic requests: - Log error and break on unknown PeriodicRequestType instead of silently doing nothing Cache: - Purge expired entries from device_info_cache during get_all_cached() instead of only filtering them from results (prevents memory leak) All 486 tests pass (23 new tests added). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/nwp500/device_info_cache.py | 12 +- src/nwp500/encoding.py | 35 ++++-- src/nwp500/events.py | 2 +- src/nwp500/mqtt/client.py | 8 +- src/nwp500/mqtt/diagnostics.py | 7 +- src/nwp500/mqtt/periodic.py | 6 + src/nwp500/mqtt/utils.py | 4 +- tests/test_bug_fixes.py | 199 ++++++++++++++++++++++++++++++++ 8 files changed, 255 insertions(+), 18 deletions(-) create mode 100644 tests/test_bug_fixes.py diff --git a/src/nwp500/device_info_cache.py b/src/nwp500/device_info_cache.py index f131f2b..da2fa69 100644 --- a/src/nwp500/device_info_cache.py +++ b/src/nwp500/device_info_cache.py @@ -137,11 +137,17 @@ async def get_all_cached(self) -> dict[str, DeviceFeature]: Dictionary mapping MAC addresses to DeviceFeature objects """ async with self._lock: - # Filter out expired entries + # Filter out expired entries and purge them from cache + expired_keys = [ + mac + for mac, (_, timestamp) in self._cache.items() + if self.is_expired(timestamp) + ] 
+ for mac in expired_keys: + del self._cache[mac] return { mac: features - for mac, (features, timestamp) in self._cache.items() - if not self.is_expired(timestamp) + for mac, (features, _) in self._cache.items() } async def get_cache_info( diff --git a/src/nwp500/encoding.py b/src/nwp500/encoding.py index 5f783dd..f20abc8 100644 --- a/src/nwp500/encoding.py +++ b/src/nwp500/encoding.py @@ -8,11 +8,14 @@ from __future__ import annotations +import logging from collections.abc import Iterable from numbers import Real from .exceptions import ParameterValidationError, RangeValidationError +_logger = logging.getLogger(__name__) + # MGPP Week Bitfield Encoding (from NaviLink APK KDEnum.MgppReservationWeek). # Uses a single byte where bits 1-7 represent days; bit 0 is unused. # @@ -342,14 +345,18 @@ def decode_reservation_hex(hex_string: str) -> list[dict[str, int]]: data = bytes.fromhex(hex_string) reservations = [] + if len(data) % 6 != 0: + _logger.warning( + "Reservation hex data length %d is not a multiple of 6; " + "trailing %d bytes will be ignored", + len(data), + len(data) % 6, + ) + # Process 6 bytes at a time - for i in range(0, len(data), 6): + for i in range(0, len(data) - (len(data) % 6), 6): chunk = data[i : i + 6] - # Ensure we have a full 6-byte entry - if len(chunk) != 6: - break - # Skip empty entries (all zeros) if all(b == 0 for b in chunk): continue @@ -425,11 +432,23 @@ def build_reservation_entry( """ # Import here to avoid circular import from .models import preferred_to_half_celsius + from .unit_system import get_unit_system # Use device-provided limits if available, otherwise use defaults - # Defaults are conservative: 95°F / 35°C minimum, 150°F / 65°C maximum - min_temp = temperature_min if temperature_min is not None else 95 - max_temp = temperature_max if temperature_max is not None else 150 + # in the user's preferred unit system. 
+ if temperature_min is not None: + min_temp = temperature_min + elif get_unit_system() == "metric": + min_temp = 35.0 # ~35°C + else: + min_temp = 95.0 # 95°F + + if temperature_max is not None: + max_temp = temperature_max + elif get_unit_system() == "metric": + max_temp = 65.0 # ~65°C + else: + max_temp = 150.0 # 150°F if not 0 <= hour <= 23: raise RangeValidationError( diff --git a/src/nwp500/events.py b/src/nwp500/events.py index 5101669..b6dc46e 100644 --- a/src/nwp500/events.py +++ b/src/nwp500/events.py @@ -396,7 +396,7 @@ async def wait_for( current_temp = temperature_event.new_temperature """ future: asyncio.Future[tuple[tuple[Any, ...], dict[str, Any]]] = ( - asyncio.Future() + asyncio.get_running_loop().create_future() ) def handler(*args: Any, **kwargs: Any) -> None: diff --git a/src/nwp500/mqtt/client.py b/src/nwp500/mqtt/client.py index cd3dba2..eeb3910 100644 --- a/src/nwp500/mqtt/client.py +++ b/src/nwp500/mqtt/client.py @@ -1298,14 +1298,14 @@ async def ensure_device_info_cached( return True # Not cached, request and wait - future: asyncio.Future[DeviceFeature] = ( - asyncio.get_running_loop().create_future() - ) + loop = asyncio.get_running_loop() + future: asyncio.Future[DeviceFeature] = loop.create_future() def on_feature(feature: DeviceFeature) -> None: + # Called from AWS SDK thread — must use thread-safe method if not future.done(): _logger.info(f"Device feature received for {redacted_mac}") - future.set_result(feature) + loop.call_soon_threadsafe(future.set_result, feature) _logger.info(f"Ensuring device info cached for {redacted_mac}") await self.subscribe_device_feature(device, on_feature) diff --git a/src/nwp500/mqtt/diagnostics.py b/src/nwp500/mqtt/diagnostics.py index d5ca647..4df6a32 100644 --- a/src/nwp500/mqtt/diagnostics.py +++ b/src/nwp500/mqtt/diagnostics.py @@ -95,7 +95,11 @@ class MqttMetrics: def to_dict(self) -> dict[str, Any]: """Convert to dictionary for JSON serialization.""" - return asdict(self) + d = asdict(self) + # 
Replace inf with None for JSON compatibility + if d.get("shortest_session_seconds") == float("inf"): + d["shortest_session_seconds"] = None + return d class MqttDiagnosticsCollector: @@ -213,6 +217,7 @@ async def record_connection_drop( # Update metrics self._metrics.total_connection_drops += 1 + self._metrics.total_reconnect_attempts += 1 if error_name: self._metrics.connection_drops_by_error[error_name] = ( self._metrics.connection_drops_by_error.get(error_name, 0) + 1 diff --git a/src/nwp500/mqtt/periodic.py b/src/nwp500/mqtt/periodic.py index 24b6f7d..013b3ea 100644 --- a/src/nwp500/mqtt/periodic.py +++ b/src/nwp500/mqtt/periodic.py @@ -173,6 +173,12 @@ async def periodic_request() -> None: await self._request_device_info(device) elif request_type == PeriodicRequestType.DEVICE_STATUS: await self._request_device_status(device) + else: + _logger.error( + "Unknown periodic request type: %s", + request_type, + ) + break _logger.debug( "Sent periodic %s request for %s", diff --git a/src/nwp500/mqtt/utils.py b/src/nwp500/mqtt/utils.py index 9d677bf..8e126e3 100644 --- a/src/nwp500/mqtt/utils.py +++ b/src/nwp500/mqtt/utils.py @@ -233,11 +233,13 @@ class MqttConnectionConfig: max_queued_commands: int = 100 def __post_init__(self) -> None: - """Generate client ID if not provided.""" + """Generate client ID if not provided and validate settings.""" if not self.client_id: object.__setattr__( self, "client_id", f"navien-client-{uuid.uuid4().hex[:8]}" ) + if self.deep_reconnect_threshold < 1: + object.__setattr__(self, "deep_reconnect_threshold", 1) @dataclass diff --git a/tests/test_bug_fixes.py b/tests/test_bug_fixes.py new file mode 100644 index 0000000..5dfc0d6 --- /dev/null +++ b/tests/test_bug_fixes.py @@ -0,0 +1,199 @@ +"""Tests for bug fixes: diagnostics, config validation, encoding, cache.""" + +from __future__ import annotations + +import asyncio +import json +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock, MagicMock, patch + 
+import pytest + +from nwp500.encoding import build_reservation_entry, decode_reservation_hex +from nwp500.events import EventEmitter +from nwp500.mqtt.diagnostics import MqttDiagnosticsCollector, MqttMetrics +from nwp500.mqtt.utils import MqttConnectionConfig + + +class TestMqttConnectionConfigValidation: + """Tests for MqttConnectionConfig validation.""" + + def test_deep_reconnect_threshold_zero_clamped(self): + """deep_reconnect_threshold=0 should be clamped to 1.""" + config = MqttConnectionConfig(deep_reconnect_threshold=0) + assert config.deep_reconnect_threshold == 1 + + def test_deep_reconnect_threshold_negative_clamped(self): + """Negative deep_reconnect_threshold should be clamped to 1.""" + config = MqttConnectionConfig(deep_reconnect_threshold=-5) + assert config.deep_reconnect_threshold == 1 + + def test_deep_reconnect_threshold_valid_preserved(self): + """Valid deep_reconnect_threshold should be preserved.""" + config = MqttConnectionConfig(deep_reconnect_threshold=5) + assert config.deep_reconnect_threshold == 5 + + def test_default_deep_reconnect_threshold(self): + """Default deep_reconnect_threshold should be 10.""" + config = MqttConnectionConfig() + assert config.deep_reconnect_threshold == 10 + + +class TestDiagnosticsReconnectCounter: + """Tests for total_reconnect_attempts counter.""" + + @pytest.mark.asyncio(loop_scope="function") + async def test_reconnect_attempts_incremented_on_drop(self): + """total_reconnect_attempts should increment on each drop.""" + collector = MqttDiagnosticsCollector() + assert collector._metrics.total_reconnect_attempts == 0 + + await collector.record_connection_drop(error=RuntimeError("test")) + assert collector._metrics.total_reconnect_attempts == 1 + + await collector.record_connection_drop(error=RuntimeError("test2")) + assert collector._metrics.total_reconnect_attempts == 2 + + @pytest.mark.asyncio(loop_scope="function") + async def test_drop_increments_both_counters(self): + """Both total_connection_drops 
and total_reconnect_attempts update.""" + collector = MqttDiagnosticsCollector() + await collector.record_connection_drop(error=RuntimeError("test")) + assert collector._metrics.total_connection_drops == 1 + assert collector._metrics.total_reconnect_attempts == 1 + + +class TestMqttMetricsSerialization: + """Tests for MqttMetrics JSON serialization.""" + + def test_to_dict_replaces_inf(self): + """to_dict should replace inf with None for JSON compatibility.""" + metrics = MqttMetrics() + d = metrics.to_dict() + assert d["shortest_session_seconds"] is None + + def test_to_dict_preserves_real_value(self): + """to_dict should preserve real shortest_session_seconds values.""" + metrics = MqttMetrics(shortest_session_seconds=42.5) + d = metrics.to_dict() + assert d["shortest_session_seconds"] == 42.5 + + def test_to_dict_json_serializable(self): + """Default MqttMetrics should be JSON-serializable.""" + metrics = MqttMetrics() + # Should not raise + result = json.dumps(metrics.to_dict()) + assert isinstance(result, str) + + +class TestEventEmitterFuture: + """Tests for EventEmitter.wait_for future creation.""" + + @pytest.mark.asyncio(loop_scope="function") + async def test_wait_for_uses_running_loop(self): + """wait_for should create future bound to running loop.""" + emitter = EventEmitter() + + async def emit_soon(): + await asyncio.sleep(0.01) + await emitter.emit("test_event", "data") + + asyncio.create_task(emit_soon()) + result = await emitter.wait_for("test_event", timeout=1.0) + assert result == ("data",) + + +class TestDecodeReservationHex: + """Tests for decode_reservation_hex partial data handling.""" + + def test_valid_hex_decoded(self): + """Valid 6-byte entries should be decoded.""" + result = decode_reservation_hex("013e061e0478") + assert len(result) == 1 + assert result[0]["enable"] == 1 + + def test_partial_entry_logged_and_skipped(self): + """Partial trailing bytes should be skipped with warning.""" + # 6 valid bytes + 3 trailing bytes + result = 
decode_reservation_hex("013e061e0478aabbcc") + assert len(result) == 1 # Only the complete entry + + def test_empty_hex(self): + """Empty hex string should return empty list.""" + assert decode_reservation_hex("") == [] + + +class TestBuildReservationEntryTempValidation: + """Tests for unit-aware temperature validation.""" + + @patch("nwp500.unit_system.get_unit_system", return_value="metric") + def test_celsius_defaults(self, _mock_unit): + """Metric mode should use Celsius defaults (35-65).""" + # 50°C is valid in metric mode + result = build_reservation_entry( + enabled=True, + days=["Monday"], + hour=6, + minute=30, + mode_id=3, + temperature=50.0, + ) + assert "param" in result + + @patch("nwp500.unit_system.get_unit_system", return_value="metric") + def test_celsius_rejects_fahrenheit_values(self, _mock_unit): + """Values outside Celsius range should be rejected in metric mode.""" + from nwp500.exceptions import RangeValidationError + + with pytest.raises(RangeValidationError): + build_reservation_entry( + enabled=True, + days=["Monday"], + hour=6, + minute=30, + mode_id=3, + temperature=140.0, # Fahrenheit value, too high for Celsius + ) + + @patch("nwp500.unit_system.get_unit_system", return_value="us_customary") + def test_fahrenheit_defaults(self, _mock_unit): + """US customary mode should use Fahrenheit defaults (95-150).""" + result = build_reservation_entry( + enabled=True, + days=["Monday"], + hour=6, + minute=30, + mode_id=3, + temperature=140.0, + ) + assert "param" in result + + @patch("nwp500.unit_system.get_unit_system", return_value="us_customary") + def test_fahrenheit_rejects_low_celsius(self, _mock_unit): + """Values outside Fahrenheit range should be rejected in US mode.""" + from nwp500.exceptions import RangeValidationError + + with pytest.raises(RangeValidationError): + build_reservation_entry( + enabled=True, + days=["Monday"], + hour=6, + minute=30, + mode_id=3, + temperature=50.0, # Celsius value, too low for Fahrenheit + ) + + 
@patch("nwp500.unit_system.get_unit_system", return_value="metric") + def test_explicit_limits_override_defaults(self, _mock_unit): + """Explicit temperature_min/max should override unit defaults.""" + result = build_reservation_entry( + enabled=True, + days=["Monday"], + hour=6, + minute=30, + mode_id=3, + temperature=80.0, + temperature_min=70.0, + temperature_max=90.0, + ) + assert "param" in result From 3f461e07f4ad004bcd9b964cbb713184aefedfae Mon Sep 17 00:00:00 2001 From: Emmanuel Levijarvi Date: Sat, 16 May 2026 14:28:58 -0700 Subject: [PATCH 3/8] fix: address CI failures (ruff lint + Python 3.14 compatibility) - Remove unused imports from test_bug_fixes.py (ruff F401) - Convert test_mqtt_reconnection.py from asyncio.get_event_loop() to @pytest.mark.asyncio (as of Python 3.14, asyncio.get_event_loop() no longer creates an event loop implicitly and raises RuntimeError when no current loop is set) - Fix import sorting in test_mqtt_reconnection.py (ruff I001) - Shorten docstring line to fit 80-char limit (ruff E501) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/test_bug_fixes.py | 3 +-- tests/test_mqtt_reconnection.py | 41 ++++++++++++++++++++------------- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/tests/test_bug_fixes.py b/tests/test_bug_fixes.py index 5dfc0d6..08765ea 100644 --- a/tests/test_bug_fixes.py +++ b/tests/test_bug_fixes.py @@ -4,8 +4,7 @@ import asyncio import json -from datetime import UTC, datetime, timedelta -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import patch import pytest diff --git a/tests/test_mqtt_reconnection.py b/tests/test_mqtt_reconnection.py index 50be9af..9c3d2d0 100644 --- a/tests/test_mqtt_reconnection.py +++ b/tests/test_mqtt_reconnection.py @@ -3,11 +3,11 @@ from __future__ import annotations import asyncio -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock import pytest -from nwp500.auth import AuthTokens, AuthenticationResponse, UserInfo +from nwp500.auth import AuthenticationResponse, AuthTokens, 
UserInfo from nwp500.mqtt.connection import MqttConnection from nwp500.mqtt.utils import MqttConnectionConfig @@ -43,46 +43,53 @@ def config(): class TestMqttConnectionClose: """Tests for MqttConnection.close() method.""" - def test_close_on_none_connection(self, config, mock_auth_client): + @pytest.mark.asyncio(loop_scope="function") + async def test_close_on_none_connection(self, config, mock_auth_client): """close() should be safe to call when _connection is None.""" conn = MqttConnection(config, mock_auth_client) assert conn._connection is None # Should not raise - asyncio.get_event_loop().run_until_complete(conn.close()) + await conn.close() assert conn._connected is False assert conn._connection is None - def test_close_clears_state(self, config, mock_auth_client): + @pytest.mark.asyncio(loop_scope="function") + async def test_close_clears_state(self, config, mock_auth_client): """close() should clear _connection and _connected regardless.""" conn = MqttConnection(config, mock_auth_client) # Simulate a connection that was interrupted (_connected=False # but _connection still exists) mock_sdk_conn = MagicMock() - future = asyncio.Future() + loop = asyncio.get_running_loop() + future = loop.create_future() future.set_result(None) mock_sdk_conn.disconnect.return_value = future conn._connection = mock_sdk_conn conn._connected = False # Interrupted state - asyncio.get_event_loop().run_until_complete(conn.close()) + await conn.close() assert conn._connection is None assert conn._connected is False mock_sdk_conn.disconnect.assert_called_once() - def test_disconnect_skips_when_not_connected(self, config, mock_auth_client): - """disconnect() should skip when _connected is False (existing behavior).""" + @pytest.mark.asyncio(loop_scope="function") + async def test_disconnect_skips_when_not_connected( + self, config, mock_auth_client + ): + """disconnect() should skip when _connected is False.""" conn = MqttConnection(config, mock_auth_client) mock_sdk_conn = 
MagicMock() conn._connection = mock_sdk_conn conn._connected = False # Interrupted state - asyncio.get_event_loop().run_until_complete(conn.disconnect()) + await conn.disconnect() # disconnect() should NOT call the SDK disconnect mock_sdk_conn.disconnect.assert_not_called() - def test_close_handles_already_dead_connection( + @pytest.mark.asyncio(loop_scope="function") + async def test_close_handles_already_dead_connection( self, config, mock_auth_client ): """close() should handle errors from SDK disconnect gracefully.""" @@ -90,7 +97,8 @@ def test_close_handles_already_dead_connection( conn = MqttConnection(config, mock_auth_client) mock_sdk_conn = MagicMock() - future = asyncio.Future() + loop = asyncio.get_running_loop() + future = loop.create_future() future.set_exception( AwsCrtError( code=0, @@ -103,13 +111,14 @@ def test_close_handles_already_dead_connection( conn._connected = False # Should not raise - asyncio.get_event_loop().run_until_complete(conn.close()) + await conn.close() assert conn._connection is None - def test_close_idempotent(self, config, mock_auth_client): + @pytest.mark.asyncio(loop_scope="function") + async def test_close_idempotent(self, config, mock_auth_client): """close() should be safe to call multiple times.""" conn = MqttConnection(config, mock_auth_client) # Call twice - should not raise - asyncio.get_event_loop().run_until_complete(conn.close()) - asyncio.get_event_loop().run_until_complete(conn.close()) + await conn.close() + await conn.close() assert conn._connection is None From 19f51e5355ff55ef6dfdf4d663b6763b076f9375 Mon Sep 17 00:00:00 2001 From: Emmanuel Levijarvi Date: Sat, 16 May 2026 14:31:01 -0700 Subject: [PATCH 4/8] style: apply ruff format to device_info_cache.py and mqtt/client.py Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/nwp500/device_info_cache.py | 5 +---- src/nwp500/mqtt/client.py | 4 +--- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git 
a/src/nwp500/device_info_cache.py b/src/nwp500/device_info_cache.py index da2fa69..a1f5865 100644 --- a/src/nwp500/device_info_cache.py +++ b/src/nwp500/device_info_cache.py @@ -145,10 +145,7 @@ async def get_all_cached(self) -> dict[str, DeviceFeature]: ] for mac in expired_keys: del self._cache[mac] - return { - mac: features - for mac, (features, _) in self._cache.items() - } + return {mac: features for mac, (features, _) in self._cache.items()} async def get_cache_info( self, diff --git a/src/nwp500/mqtt/client.py b/src/nwp500/mqtt/client.py index eeb3910..eaeae5d 100644 --- a/src/nwp500/mqtt/client.py +++ b/src/nwp500/mqtt/client.py @@ -394,9 +394,7 @@ async def _active_reconnect(self) -> None: try: await self._connection_manager.close() except (AwsCrtError, RuntimeError) as e: - _logger.debug( - f"Old connection cleanup (benign): {e}" - ) + _logger.debug(f"Old connection cleanup (benign): {e}") # Create a new connection manager with same config self._connection_manager = MqttConnection( From 29ad8a07abb3f789f48dfd20c82cd448fc866e81 Mon Sep 17 00:00:00 2001 From: Emmanuel Levijarvi Date: Sat, 16 May 2026 14:39:28 -0700 Subject: [PATCH 5/8] Address PR review comments - client.py on_feature: move future.done() check inside call_soon_threadsafe callback to eliminate race between done-check and set_result across the SDK thread / event loop thread boundary - tests/test_mqtt_reconnection.py: replace asyncio.Future with concurrent.futures.Future to match what the AWS IoT SDK actually returns, exercising the wrap_future() conversion path in close() - encoding.py build_reservation_entry: read get_unit_system() once into a local variable instead of calling it twice Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/nwp500/encoding.py | 7 +++++-- src/nwp500/mqtt/client.py | 14 ++++++++++---- tests/test_mqtt_reconnection.py | 32 +++++++++++++++++++------------- 3 files changed, 34 insertions(+), 19 deletions(-) diff --git a/src/nwp500/encoding.py 
b/src/nwp500/encoding.py index f20abc8..2566c21 100644 --- a/src/nwp500/encoding.py +++ b/src/nwp500/encoding.py @@ -434,18 +434,21 @@ def build_reservation_entry( from .models import preferred_to_half_celsius from .unit_system import get_unit_system + # Read unit system once to keep min/max bounds consistent + unit_system = get_unit_system() + # Use device-provided limits if available, otherwise use defaults # in the user's preferred unit system. if temperature_min is not None: min_temp = temperature_min - elif get_unit_system() == "metric": + elif unit_system == "metric": min_temp = 35.0 # ~35°C else: min_temp = 95.0 # 95°F if temperature_max is not None: max_temp = temperature_max - elif get_unit_system() == "metric": + elif unit_system == "metric": max_temp = 65.0 # ~65°C else: max_temp = 150.0 # 150°F diff --git a/src/nwp500/mqtt/client.py b/src/nwp500/mqtt/client.py index eaeae5d..6a331a3 100644 --- a/src/nwp500/mqtt/client.py +++ b/src/nwp500/mqtt/client.py @@ -1300,10 +1300,16 @@ async def ensure_device_info_cached( future: asyncio.Future[DeviceFeature] = loop.create_future() def on_feature(feature: DeviceFeature) -> None: - # Called from AWS SDK thread — must use thread-safe method - if not future.done(): - _logger.info(f"Device feature received for {redacted_mac}") - loop.call_soon_threadsafe(future.set_result, feature) + # Called from AWS SDK thread — schedule onto the event loop + # thread-safely. The done() check is inside the scheduled + # callback so it runs on the event loop thread, eliminating + # the race between the check and set_result. 
+ def _set_result() -> None: + if not future.done(): + _logger.info(f"Device feature received for {redacted_mac}") + future.set_result(feature) + + loop.call_soon_threadsafe(_set_result) _logger.info(f"Ensuring device info cached for {redacted_mac}") await self.subscribe_device_feature(device, on_feature) diff --git a/tests/test_mqtt_reconnection.py b/tests/test_mqtt_reconnection.py index 9c3d2d0..d5f0382 100644 --- a/tests/test_mqtt_reconnection.py +++ b/tests/test_mqtt_reconnection.py @@ -2,7 +2,7 @@ from __future__ import annotations -import asyncio +import concurrent.futures from unittest.mock import MagicMock import pytest @@ -40,6 +40,16 @@ def config(): return MqttConnectionConfig(client_id="test-client") +def _make_cf_future(result=None, exception=None): + """Return a resolved concurrent.futures.Future matching SDK behaviour.""" + f = concurrent.futures.Future() + if exception is not None: + f.set_exception(exception) + else: + f.set_result(result) + return f + + class TestMqttConnectionClose: """Tests for MqttConnection.close() method.""" @@ -55,15 +65,14 @@ async def test_close_on_none_connection(self, config, mock_auth_client): @pytest.mark.asyncio(loop_scope="function") async def test_close_clears_state(self, config, mock_auth_client): - """close() should clear _connection and _connected regardless.""" + """close() should clear _connection and _connected regardless. + + Uses concurrent.futures.Future to match what the AWS IoT SDK + returns from disconnect(), exercising the wrap_future() path. 
+ """ conn = MqttConnection(config, mock_auth_client) - # Simulate a connection that was interrupted (_connected=False - # but _connection still exists) mock_sdk_conn = MagicMock() - loop = asyncio.get_running_loop() - future = loop.create_future() - future.set_result(None) - mock_sdk_conn.disconnect.return_value = future + mock_sdk_conn.disconnect.return_value = _make_cf_future() conn._connection = mock_sdk_conn conn._connected = False # Interrupted state @@ -97,16 +106,13 @@ async def test_close_handles_already_dead_connection( conn = MqttConnection(config, mock_auth_client) mock_sdk_conn = MagicMock() - loop = asyncio.get_running_loop() - future = loop.create_future() - future.set_exception( - AwsCrtError( + mock_sdk_conn.disconnect.return_value = _make_cf_future( + exception=AwsCrtError( code=0, name="AWS_ERROR_MQTT_CONNECTION_DESTROYED", message="Connection destroyed", ) ) - mock_sdk_conn.disconnect.return_value = future conn._connection = mock_sdk_conn conn._connected = False From 3f94e3e2ef40b2dc47fb988fcfda7a688422d179 Mon Sep 17 00:00:00 2001 From: Emmanuel Levijarvi Date: Sat, 16 May 2026 17:01:15 -0700 Subject: [PATCH 6/8] fix versioning --- CHANGELOG.rst | 63 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3c837ba..5d48961 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,8 +2,67 @@ Changelog ========= -Unreleased (8.x) -================ +Unreleased +========== + +Bug Fixes +--------- +- **Fix MQTT connection flapping after reconnect**: When ``_active_reconnect()`` + created a new ``MqttConnection``, the old connection was never closed. The old + SDK connection's built-in auto-reconnect would eventually succeed, creating two + active connections sharing the same client ID. 
Because AWS IoT allows only one + connection per client ID, the broker would kick one off, triggering + ``on_connection_interrupted`` and starting yet another reconnection — an + infinite connect/disconnect loop. Fixed by adding ``MqttConnection.close()`` + (unconditional teardown regardless of ``_connected`` state) and calling it + before creating the replacement connection in both ``_active_reconnect()`` and + ``_deep_reconnect()``. + +- **Thread-safety race in ``ensure_device_info_cached``**: The ``future.done()`` + check and ``future.set_result()`` were performed in the AWS SDK callback thread + without synchronisation, creating a race against the asyncio event loop thread. + Moved both operations inside a ``call_soon_threadsafe`` callback so they execute + atomically on the event loop thread. + +- **ZeroDivisionError when ``deep_reconnect_threshold`` is 0**: Config validation + now clamps ``deep_reconnect_threshold`` to a minimum of 1, preventing a + ``ZeroDivisionError`` in the exponential-backoff reconnection logic. + +- **Reconnect counter never incremented**: ``total_reconnect_attempts`` in + diagnostics was not incremented on connection drops, so it always reported 0 + despite active reconnections. Counter is now incremented on each + ``on_connection_interrupted`` event. + +- **``shortest_session_seconds`` not JSON-serialisable**: The diagnostics + ``to_dict()`` method used ``float('inf')`` as the initial value for + ``shortest_session_seconds``, which is not valid JSON. Changed to ``None`` + so serialisation succeeds when no session has completed yet. + +- **``wait_for()`` future not bound to running loop**: ``wait_for()`` created a + bare ``asyncio.Future()`` rather than + ``asyncio.get_running_loop().create_future()``, which could bind the future to + a different loop in multi-loop test setups. 
+ +- **Reservation temperature validation was US-only**: ``build_reservation_entry`` + validated set-point temperatures against hardcoded Fahrenheit bounds (95–150 °F) + regardless of the active unit system. Validation now uses the current unit system + context: 35–65 °C in metric mode, 95–150 °F in US mode. Celsius users previously + received spurious ``ValueError`` rejections for valid temperatures. + +- **Malformed reservation data silently dropped**: ``build_reservation_entry`` now + logs a warning when reservation hex data contains unexpected trailing bytes + instead of silently dropping partial entries. + +- **Unknown ``PeriodicRequestType`` silently ignored**: The periodic-request handler + now logs an error and breaks when it encounters an unknown request type instead of + doing nothing. + +- **Memory leak in device info cache**: ``get_all_cached()`` only filtered expired + entries from its return value but left them in the cache dictionary. Expired + entries are now evicted during ``get_all_cached()`` to prevent unbounded growth. + +Version 8.0.0 (2026-05-13) +=========================== Bug Fixes --------- From 852d8a0fec8727e9e2b7a10387379479e19fe9c4 Mon Sep 17 00:00:00 2001 From: Emmanuel Levijarvi Date: Sat, 16 May 2026 17:04:47 -0700 Subject: [PATCH 7/8] handle the changelog with bumpversion --- scripts/bump_version.py | 77 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/scripts/bump_version.py b/scripts/bump_version.py index 800ee0f..2e37224 100755 --- a/scripts/bump_version.py +++ b/scripts/bump_version.py @@ -22,6 +22,8 @@ import re import subprocess import sys +from datetime import date +from pathlib import Path def run_git_command(args: list) -> str: @@ -140,6 +142,75 @@ def check_working_directory_clean() -> None: sys.exit(1) +def update_changelog(version: str) -> None: + """Insert a version heading into CHANGELOG.rst below the Unreleased section. 
+ + Transforms: + + Unreleased + ========== + + + + into: + + Unreleased + ========== + + Version X.Y.Z (YYYY-MM-DD) + =========================== + + + """ + changelog_path = Path("CHANGELOG.rst") + if not changelog_path.exists(): + print("Warning: CHANGELOG.rst not found, skipping changelog update.") + return + + content = changelog_path.read_text(encoding="utf-8") + + heading = f"Version {version} ({date.today().isoformat()})" + underline = "=" * len(heading) + version_block = f"{heading}\n{underline}\n" + + # Match "Unreleased\n==========\n" (any number of = signs) followed by + # one or more blank lines, then insert the version block after them. + pattern = re.compile( + r"(Unreleased\n=+\n)" # group 1: Unreleased heading + r"(\n+)", # group 2: blank line(s) separator + re.MULTILINE, + ) + + match = pattern.search(content) + if not match: + print( + "Warning: Could not find 'Unreleased' section in CHANGELOG.rst. " + "Skipping changelog update.", + file=sys.stderr, + ) + return + + # Insert the version block after the blank lines that follow "Unreleased" + new_content = ( + content[: match.end()] + + version_block + + "\n" + + content[match.end() :] + ) + + changelog_path.write_text(new_content, encoding="utf-8") + print(f"[OK] Updated CHANGELOG.rst with {heading}") + + +def commit_changelog(version: str) -> None: + """Stage and commit the CHANGELOG.rst update.""" + run_git_command(["add", "CHANGELOG.rst"]) + run_git_command( + ["commit", "-m", f"Update changelog for v{version}"] + ) + print("[OK] Committed changelog update") + + def create_tag(version: str, message: str = None) -> None: """Create a git tag for the version.""" tag_name = f"v{version}" @@ -223,6 +294,11 @@ def main() -> None: # Validate version progression validate_version_progression(current_version, new_version) + # Update CHANGELOG.rst and commit, then create the tag + print("\nUpdating CHANGELOG.rst...") + update_changelog(new_version) + commit_changelog(new_version) + # Create the tag 
print(f"\nCreating tag v{new_version}...") create_tag(new_version) @@ -230,6 +306,7 @@ def main() -> None: print("\n[OK] Version bump complete!") print("\nNext steps:") print(f" 1. Push the tag: git push origin v{new_version}") + print(" (also push the changelog commit: git push origin HEAD)") print(" 2. Build release: make build") print(" 3. Test on TestPyPI: make publish-test") print(" 4. Publish to PyPI: make publish") From 83141493874cee7c03ad268b41605e0ac14be19b Mon Sep 17 00:00:00 2001 From: Emmanuel Levijarvi Date: Sat, 16 May 2026 17:05:29 -0700 Subject: [PATCH 8/8] Update changelog for v8.1.0 --- CHANGELOG.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5d48961..8039182 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -5,6 +5,9 @@ Changelog Unreleased ========== +Version 8.1.0 (2026-05-16) +========================== + Bug Fixes --------- - **Fix MQTT connection flapping after reconnect**: When ``_active_reconnect()``