diff --git a/README.md b/README.md index 49e1dda..10f5104 100644 --- a/README.md +++ b/README.md @@ -105,10 +105,20 @@ The fixture-driven `tests/test_device_decode.py` now asserts that the raw payloa `WorxCloud` accepts a `command_timeout` argument (seconds) that controls how long MQTT command calls wait for a matching mower response before raising `TimeoutException`. +Initial MQTT connection attempts use a separate `mqtt_connect_timeout` argument. This keeps API-backed startup responsive when the cloud MQTT service is unavailable, while preserving a longer timeout for mower command responses. + +The read-only `mqtt_connected` property reports whether the MQTT client is currently active. It is available on both the `WorxCloud` instance and the mapped device objects. It is `False` when startup falls back to API-only mode, after disconnects, and whenever the MQTT client loses its connection. + ```python from pyworxcloud import WorxCloud -cloud = WorxCloud("user@example.com", "secret", "worx", command_timeout=15.0) +cloud = WorxCloud( + "user@example.com", + "secret", + "worx", + command_timeout=15.0, + mqtt_connect_timeout=8.0, +) ``` ## Schedule CRUD diff --git a/pyworxcloud/__init__.py b/pyworxcloud/__init__.py index 1fbfed0..44289f0 100644 --- a/pyworxcloud/__init__.py +++ b/pyworxcloud/__init__.py @@ -18,6 +18,14 @@ from .api import LandroidCloudAPI from .clouds import CloudType +from .const import ( + API_REFRESH_TIME_MAX, + API_REFRESH_TIME_MIN, + DEFAULT_COMMAND_TIMEOUT, + DEFAULT_MQTT_CONNECT_TIMEOUT, + MQTT_RECONNECT_RETRY_SECONDS, + VISION_BORDER_DISTANCE_MM_VALUES, +) from .events import EventHandler, LandroidEvent from .exceptions import ( AuthorizationError, @@ -60,11 +68,6 @@ _LOGGER = logging.getLogger(__name__) -API_REFRESH_TIME_MIN = 5 -API_REFRESH_TIME_MAX = 10 -DEFAULT_COMMAND_TIMEOUT = 30.0 -VISION_BORDER_DISTANCE_MM_VALUES = (50, 100, 150, 200) - class WorxCloud(dict): """ @@ -88,6 +91,7 @@ def __init__( verify_ssl: bool = True, tz: str | None = None, # pylint: disable=invalid-name command_timeout: float = DEFAULT_COMMAND_TIMEOUT, + mqtt_connect_timeout: float = DEFAULT_MQTT_CONNECT_TIMEOUT, deduplicate_inflight_commands: bool = False, ) -> None: """ @@ -146,8 +150,6 @@ def __init__( _LOGGER.debug("Initializing connector...") super().__init__() - self._worx_mqtt_client_id = None - if not isinstance( cloud, ( @@ -175,11 +177,13 @@ def __init__( self._raw = None self._tz = tz - self._save_zones = None self._verify_ssl = verify_ssl if command_timeout <= 0: raise ValueError("command_timeout must be greater than 0") self._command_timeout = float(command_timeout) + if mqtt_connect_timeout <= 0: + raise ValueError("mqtt_connect_timeout must be greater than 0") + self._mqtt_connect_timeout = float(mqtt_connect_timeout) self._deduplicate_inflight_commands = bool(deduplicate_inflight_commands) _LOGGER.debug("Initializing EventHandler ...") self._events = EventHandler() @@ -191,9 +195,8 @@ def __init__( self._mowers_by_uuid: dict[str, dict[str, Any]] = {} self._mowers_by_mac: dict[str, dict[str, Any]] = {} - self._decoding: bool = False - self._api_refresh_task: asyncio.Task | None = None + self._mqtt_retry_task: asyncio.Task | None = None self._disconnecting = asyncio.Event() self._loop: asyncio.AbstractEventLoop | None = None self._sync_loop: asyncio.AbstractEventLoop | None = None @@ -297,39 +300,53 @@ async def disconnect(self) -> None: if self._api_refresh_task is not None: self._api_refresh_task.cancel() self._api_refresh_task = None + if self._mqtt_retry_task is not None: + self._mqtt_retry_task.cancel() + try: + await self._mqtt_retry_task + except asyncio.CancelledError: + pass + self._mqtt_retry_task = None # Disconnect MQTT connection try: - if self.mqtt is not None: - disconnect_failed = False - try: - started = time.perf_counter() - await self.mqtt.adisconnect() - logger.debug( - "MQTT adisconnect completed in %.3fs", - time.perf_counter() - started, - ) - except Exception as err: - disconnect_failed = True - logger.debug("Could not disconnect MQTT cleanly: %s", err) - - try: - started = time.perf_counter() - await self.mqtt.ashutdown() - logger.debug( - "MQTT ashutdown completed in %.3fs", - time.perf_counter() - started, - ) - except Exception as err: - logger.debug("Could not shutdown MQTT cleanly: %s", err) - if not disconnect_failed: - raise + await self._disconnect_mqtt(logger) finally: - self.mqtt = None started = time.perf_counter() await self._api.close() logger.debug("API close completed in %.3fs", time.perf_counter() - started) + async def _disconnect_mqtt(self, logger: logging.Logger) -> None: + """Disconnect and release the MQTT client without closing the API session.""" + if self.mqtt is None: + return + + disconnect_failed = False + try: + started = time.perf_counter() + await self.mqtt.adisconnect() + logger.debug( + "MQTT adisconnect completed in %.3fs", + time.perf_counter() - started, + ) + except Exception as err: + disconnect_failed = True + logger.debug("Could not disconnect MQTT cleanly: %s", err) + + try: + started = time.perf_counter() + await self.mqtt.ashutdown() + logger.debug( + "MQTT ashutdown completed in %.3fs", + time.perf_counter() - started, + ) + except Exception as err: + logger.debug("Could not shutdown MQTT cleanly: %s", err) + if not disconnect_failed: + raise + finally: + self.mqtt = None + async def connect( self, ) -> bool: @@ -354,25 +371,14 @@ async def connect( self._endpoint = self._mowers[0]["mqtt_endpoint"] self._user_id = self._mowers[0]["user_id"] - self._log.debug("Setting up MQTT handler") - # setup MQTT handler - self.mqtt = await asyncio.to_thread( - MQTT, - self._api, - self._cloud.BRAND_PREFIX, - self._endpoint, - self._user_id, - self._log, - self._on_update, - identifier_resolver=self._resolve_mower_identifiers, - deduplicate_inflight_commands=self._deduplicate_inflight_commands, - response_timeout=self._command_timeout, - ) - - await self.mqtt.aconnect() - - for mower in self._mowers: - await self.mqtt.asubscribe(mower["mqtt_topics"]["command_out"], True) + try: + await self._connect_mqtt_once(log_connect_errors=False) + except Exception: + logger.debug( + "MQTT connect failed; continuing with API refresh fallback" + ) + await self._disconnect_mqtt(logger) + self._schedule_mqtt_retry() # Convert time strings to objects. for name, device in self.devices.items(): @@ -397,16 +403,105 @@ async def connect( ) raise + async def _connect_mqtt_once(self, log_connect_errors: bool = True) -> None: + """Create, connect, and subscribe the MQTT client once.""" + self._log.debug("Setting up MQTT handler") + self.mqtt = await asyncio.to_thread( + MQTT, + self._api, + self._cloud.BRAND_PREFIX, + self._endpoint, + self._user_id, + self._log, + self._on_update, + identifier_resolver=self._resolve_mower_identifiers, + deduplicate_inflight_commands=self._deduplicate_inflight_commands, + response_timeout=self._command_timeout, + connect_timeout=self._mqtt_connect_timeout, + ) + self.mqtt._log_connect_errors = log_connect_errors + + await self.mqtt.aconnect() + + for mower in self._mowers: + await self.mqtt.asubscribe(mower["mqtt_topics"]["command_out"], True) + + def _schedule_mqtt_retry(self) -> None: + """Schedule background MQTT reconnect attempts without touching API polling.""" + if self._disconnecting.is_set(): + return + if self._mqtt_retry_task is not None and not self._mqtt_retry_task.done(): + return + self._mqtt_retry_task = asyncio.create_task(self._mqtt_retry_loop()) + + async def _mqtt_retry_loop(self) -> None: + """Retry MQTT in the background until it reconnects or the cloud disconnects.""" + logger = self._log.getChild("MQTT_Retry") + try: + while not self._disconnecting.is_set(): + await asyncio.sleep(MQTT_RECONNECT_RETRY_SECONDS) + if self._disconnecting.is_set(): + return + if self.mqtt is not None and self.mqtt.connected: + return + + try: + await self._disconnect_mqtt(logger) + await self._connect_mqtt_once(log_connect_errors=False) + except Exception: + logger.debug("Background MQTT reconnect failed; will retry") + await self._disconnect_mqtt(logger) + continue + + logger.debug("Background MQTT reconnect completed") + return + except asyncio.CancelledError: + raise + async def _token_updated(self) -> None: """Called when token is updated.""" - if self.mqtt is not None: + if self.mqtt_connected: await self.mqtt.aupdate_token() + elif self.mqtt is not None: + self._schedule_mqtt_retry() + + def _bind_device_mqtt_state(self, device: DeviceHandler) -> DeviceHandler: + """Bind a device object to the current cloud-level MQTT state.""" + device.set_mqtt_connected_resolver(lambda: self.mqtt_connected) + return device @property def auth_result(self) -> bool: """Return current authentication result.""" return self._auth_result + @property + def mqtt_connected(self) -> bool: + """Return whether the MQTT client is currently connected.""" + mqtt_client = self.mqtt + if mqtt_client is None: + return False + return bool(getattr(mqtt_client, "connected", False)) + + def _require_mqtt_connected(self) -> Any: + """Return the MQTT client or raise a connection error.""" + mqtt_client = self.mqtt + if mqtt_client is None or not self.mqtt_connected: + raise NoConnectionError("MQTT connection is not ready") + return mqtt_client + + async def _mqtt_apublish(self, *args: Any, **kwargs: Any) -> None: + """Publish over MQTT when the client is currently connected.""" + await self._require_mqtt_connected().apublish(*args, **kwargs) + + async def _mqtt_aping(self, *args: Any, **kwargs: Any) -> None: + """Ping over MQTT when the client is currently connected.""" + await self._require_mqtt_connected().aping(*args, **kwargs) + + async def _mqtt_acommand(self, *args: Any, **kwargs: Any) -> None: + """Send a command over MQTT when the client is currently connected.""" + await self._require_mqtt_connected().acommand(*args, **kwargs) + def _on_update(self, payload): # , topic, payload, dup, qos, retain, **kwargs): """Triggered when a MQTT message was received.""" logger = self._log.getChild("MQTT_data_in") @@ -577,6 +672,7 @@ async def _fetch(self, forced: bool = False) -> None: try: previous_device = self.devices.get(mower["name"]) device = DeviceHandler(self._api, mower, self._tz, False) + self._bind_device_mqtt_state(device) if not isinstance(mower["last_status"], type(None)): device.raw_data = mower["last_status"]["payload"] @@ -989,7 +1085,7 @@ async def _publish_schedule_payload( raise OfflineError("The device is currently offline, no action was sent.") identifier = mower["serial_number"] if mower["protocol"] == 0 else mower["uuid"] - await self.mqtt.apublish( + await self._mqtt_apublish( identifier, mower["mqtt_topics"]["command_in"], {"sc": sc_payload}, @@ -1052,14 +1148,14 @@ async def update(self, serial_number: str, timeout: float | None = None) -> None _LOGGER.debug("Trying to refresh '%s'", serial_number) try: - await self.mqtt.aping( + await self._mqtt_aping( serial_number if mower["protocol"] == 0 else mower["uuid"], mower["mqtt_topics"]["command_in"], mower["protocol"], timeout=timeout, ) - except NoConnectionError: - raise NoConnectionError from None + except NoConnectionError as err: + raise NoConnectionError(str(err)) from None async def start(self, serial_number: str) -> None: """Start mowing task @@ -1073,7 +1169,7 @@ async def start(self, serial_number: str) -> None: mower = self.get_mower(serial_number) if mower["online"]: _LOGGER.debug("Sending start command to '%s'", serial_number) - await self.mqtt.acommand( + await self._mqtt_acommand( serial_number if mower["protocol"] == 0 else mower["uuid"], mower["mqtt_topics"]["command_in"], Command.START, @@ -1096,7 +1192,7 @@ async def home(self, serial_number: str) -> None: mower = self.get_mower(serial_number) if mower["online"]: - await self.mqtt.acommand( + await self._mqtt_acommand( serial_number if mower["protocol"] == 0 else mower["uuid"], mower["mqtt_topics"]["command_in"], Command.HOME, @@ -1116,7 +1212,7 @@ async def safehome(self, serial_number: str) -> None: """ mower = self.get_mower(serial_number) if mower["online"]: - await self.mqtt.acommand( + await self._mqtt_acommand( serial_number if mower["protocol"] == 0 else mower["uuid"], mower["mqtt_topics"]["command_in"], Command.SAFEHOME, @@ -1136,7 +1232,7 @@ async def pause(self, serial_number: str) -> None: """ mower = self.get_mower(serial_number) if mower["online"]: - await self.mqtt.acommand( + await self._mqtt_acommand( serial_number if mower["protocol"] == 0 else mower["uuid"], mower["mqtt_topics"]["command_in"], Command.PAUSE, @@ -1161,7 +1257,7 @@ async def raindelay(self, serial_number: str, rain_delay: str) -> None: rain_delay, "rain_delay", minimum=0, maximum=1440 ) if mower["protocol"] == 0: - await self.mqtt.apublish( + await self._mqtt_apublish( serial_number, mower["mqtt_topics"]["command_in"], {"rd": rain_delay}, @@ -1169,7 +1265,7 @@ async def raindelay(self, serial_number: str, rain_delay: str) -> None: ) else: # Protocol 1 requires rd to be wrapped in cfg - await self.mqtt.apublish( + await self._mqtt_apublish( mower["uuid"], mower["mqtt_topics"]["command_in"], {"cfg": {"rd": rain_delay}}, @@ -1191,7 +1287,7 @@ async def set_lock(self, serial_number: str, state: bool) -> None: state = self._require_bool(state, "state") mower = self.get_mower(serial_number) if mower["online"]: - await self.mqtt.acommand( + await self._mqtt_acommand( serial_number if mower["protocol"] == 0 else mower["uuid"], mower["mqtt_topics"]["command_in"], Command.LOCK if state else Command.UNLOCK, @@ -1218,7 +1314,7 @@ async def set_party_mode(self, serial_number: str, state: bool) -> None: device = DeviceHandler(self._api, mower, self._tz) if device.capabilities.check(DeviceCapability.PARTY_MODE): if mower["protocol"] == 0: - await self.mqtt.apublish( + await self._mqtt_apublish( serial_number if mower["protocol"] == 0 else mower["uuid"], mower["mqtt_topics"]["command_in"], ( @@ -1229,7 +1325,7 @@ async def set_party_mode(self, serial_number: str, state: bool) -> None: mower["protocol"], ) else: - await self.mqtt.apublish( + await self._mqtt_apublish( serial_number if mower["protocol"] == 0 else mower["uuid"], mower["mqtt_topics"]["command_in"], ( @@ -1273,7 +1369,7 @@ async def set_offlimits(self, serial_number: str, state: bool) -> None: _LOGGER.debug("Setting offlimits") device = DeviceHandler(self._api, mower, self._tz) if device.capabilities.check(DeviceCapability.OFF_LIMITS): - await self.mqtt.apublish( + await self._mqtt_apublish( serial_number if device.protocol == 0 else device.uuid, mower["mqtt_topics"]["command_in"], ( @@ -1320,7 +1416,7 @@ async def set_offlimits_shortcut(self, serial_number: str, state: bool) -> None: _LOGGER.debug("Setting offlimits") device = DeviceHandler(self._api, mower, self._tz) if device.capabilities.check(DeviceCapability.OFF_LIMITS): - await self.mqtt.apublish( + await self._mqtt_apublish( serial_number if device.protocol == 0 else device.uuid, mower["mqtt_topics"]["command_in"], ( @@ -1388,7 +1484,7 @@ async def setzone(self, serial_number: str, zone: str | int) -> None: new_zones.append(current_zones[(offset + i) % no_indices]) device = DeviceHandler(self._api, mower, self._tz) - await self.mqtt.apublish( + await self._mqtt_apublish( serial_number if mower["protocol"] == 0 else mower["uuid"], mower["mqtt_topics"]["command_in"], {"mzv": new_zones}, @@ -1409,7 +1505,7 @@ async def zonetraining(self, serial_number: str) -> None: mower = self.get_mower(serial_number) if mower["online"]: _LOGGER.debug("Sending ZONETRAINING command to %s", mower["name"]) - await self.mqtt.acommand( + await self._mqtt_acommand( serial_number if mower["protocol"] == 0 else mower["uuid"], mower["mqtt_topics"]["command_in"], Command.ZONETRAINING, @@ -1430,7 +1526,7 @@ async def restart(self, serial_number: str): mower = self.get_mower(serial_number) if mower["online"]: _LOGGER.debug("Sending RESTART command to %s", mower["name"]) - await self.mqtt.acommand( + await self._mqtt_acommand( serial_number if mower["protocol"] == 0 else mower["uuid"], mower["mqtt_topics"]["command_in"], Command.RESTART, @@ -1845,7 +1941,7 @@ async def set_torque(self, serial_number: str, torque: int) -> None: mower = self.get_mower(serial_number) if mower["online"]: if mower["protocol"] == 0: - await self.mqtt.apublish( + await self._mqtt_apublish( serial_number, mower["mqtt_topics"]["command_in"], {"tq": torque}, @@ -1853,7 +1949,7 @@ async def set_torque(self, serial_number: str, torque: int) -> None: ) else: # Protocol 1 requires tq to be wrapped in cfg - await self.mqtt.apublish( + await self._mqtt_apublish( mower["uuid"], mower["mqtt_topics"]["command_in"], {"cfg": {"tq": torque}}, @@ -1873,14 +1969,14 @@ async def edgecut(self, serial_number: str) -> None: device = DeviceHandler(self._api, mower, self._tz) if device.capabilities.check(DeviceCapability.EDGE_CUT): if mower["protocol"] == 0: - await self.mqtt.apublish( + await self._mqtt_apublish( serial_number, mower["mqtt_topics"]["command_in"], {"sc": {"ots": {"bc": 1, "wtm": 0}}}, mower["protocol"], ) else: - await self.mqtt.apublish( + await self._mqtt_apublish( mower["uuid"], mower["mqtt_topics"]["command_in"], {"cmd": 101}, @@ -1913,14 +2009,14 @@ async def ots( device = DeviceHandler(self._api, mower, self._tz) if mower["protocol"] == 0: - await self.mqtt.apublish( + await self._mqtt_apublish( serial_number, mower["mqtt_topics"]["command_in"], {"sc": {"ots": {"bc": int(boundary), "wtm": runtime}}}, mower["protocol"], ) else: - await self.mqtt.apublish( + await self._mqtt_apublish( mower["uuid"], mower["mqtt_topics"]["command_in"], { @@ -1980,7 +2076,7 @@ async def _set_border_cut_settings( cut_over_border=cut_over_border, border_distance=border_distance, ) - await self.mqtt.apublish( + await self._mqtt_apublish( mower["uuid"], mower["mqtt_topics"]["command_in"], {"cut": cut_payload}, @@ -2052,7 +2148,7 @@ async def send(self, serial_number: str, data: str) -> None: mower = self.get_mower(serial_number) if mower["online"]: _LOGGER.debug("Sending %s to %s", data, mower["name"]) - await self.mqtt.apublish( + await self._mqtt_apublish( serial_number if mower["protocol"] == 0 else mower["uuid"], mower["mqtt_topics"]["command_in"], json.loads(data), @@ -2139,7 +2235,7 @@ async def set_cutting_height(self, serial_number: str, height: int) -> None: if mower["online"]: device = DeviceHandler(self._api, mower, self._tz) if device.capabilities.check(DeviceCapability.CUTTING_HEIGHT): - await self.mqtt.apublish( + await self._mqtt_apublish( serial_number if mower["protocol"] == 0 else mower["uuid"], mower["mqtt_topics"]["command_in"], {"cmd": 0, "modules": {"EA": {"h": height}}}, @@ -2168,7 +2264,7 @@ async def set_acs(self, serial_number: str, state: bool) -> None: if mower["online"]: device = DeviceHandler(self._api, mower, self._tz) if device.capabilities.check(DeviceCapability.ACS): - await self.mqtt.apublish( + await self._mqtt_apublish( serial_number if mower["protocol"] == 0 else mower["uuid"], mower["mqtt_topics"]["command_in"], {"cmd": 0, "modules": {"US": {"enabled": 1 if state else 0}}}, diff --git a/pyworxcloud/const.py b/pyworxcloud/const.py index 0916743..072fe8d 100644 --- a/pyworxcloud/const.py +++ b/pyworxcloud/const.py @@ -3,6 +3,14 @@ from __future__ import annotations API_BASE = "https://{}/api/v2" +API_REFRESH_TIME_MIN = 5 +API_REFRESH_TIME_MAX = 10 +DEFAULT_COMMAND_TIMEOUT = 30.0 +DEFAULT_MQTT_CONNECT_TIMEOUT = 8.0 +MQTT_RECONNECT_RETRY_SECONDS = 15 * 60 +PAHO_MQTT_RECONNECT_MIN_DELAY_SECONDS = 60 +PAHO_MQTT_RECONNECT_MAX_DELAY_SECONDS = 30 * 60 +VISION_BORDER_DISTANCE_MM_VALUES = (50, 100, 150, 200) UNWANTED_ATTRIBS = [ "distance_covered", diff --git a/pyworxcloud/utils/devices.py b/pyworxcloud/utils/devices.py index 5d6cfc0..96ab6e0 100644 --- a/pyworxcloud/utils/devices.py +++ b/pyworxcloud/utils/devices.py @@ -5,7 +5,7 @@ import json import logging from datetime import datetime, timedelta, timezone -from typing import Any +from typing import Any, Callable from zoneinfo import ZoneInfo, ZoneInfoNotFoundError from ..clouds import CloudType @@ -53,6 +53,7 @@ def __init__( self.mower = mower self._tz = tz self._decode = decode + self._mqtt_connected_resolver: Callable[[], bool] | None = None self.battery = Battery() self.blades = Blades() @@ -115,6 +116,17 @@ def is_decoded(self, value: bool) -> None: """Set decoded flag when dataset was decoded and handled.""" self.__is_decoded = value + @property + def mqtt_connected(self) -> bool: + """Return whether MQTT is currently connected for this device.""" + if self._mqtt_connected_resolver is None: + return False + return bool(self._mqtt_connected_resolver()) + + def set_mqtt_connected_resolver(self, resolver: Callable[[], bool] | None) -> None: + """Set a callback that resolves the current MQTT connection state.""" + self._mqtt_connected_resolver = resolver + def __mapinfo(self, api: Any, data: Any) -> None: """Map information from API.""" diff --git a/pyworxcloud/utils/mqtt.py b/pyworxcloud/utils/mqtt.py index 0440202..1ce57c9 100644 --- a/pyworxcloud/utils/mqtt.py +++ b/pyworxcloud/utils/mqtt.py @@ -18,11 +18,15 @@ from concurrent.futures import TimeoutError as FutureTimeoutError from datetime import datetime, timezone from logging import Logger -from typing import Any, Callable, Optional +from typing import Any, Callable from uuid import uuid4 import paho.mqtt.client as paho_mqtt +from ..const import ( + PAHO_MQTT_RECONNECT_MAX_DELAY_SECONDS, + PAHO_MQTT_RECONNECT_MIN_DELAY_SECONDS, +) from ..events import EventHandler, LandroidEvent from ..exceptions import NoConnectionError, TimeoutException from .landroid_class import LDict @@ -160,6 +164,7 @@ def __init__( identifier_resolver: Callable[[str], set[str]] | None = None, deduplicate_inflight_commands: bool = False, response_timeout: float = DEFAULT_RESPONSE_TIMEOUT, + connect_timeout: float | None = None, ) -> dict: """Initialize the paho-mqtt handler.""" @@ -170,14 +175,12 @@ def __init__( self._deduplicate_inflight_commands = deduplicate_inflight_commands self._endpoint = endpoint self._log = logger.getChild("MQTT") - self._reconnected: bool = False self._topic: list = [] self._api = api self._uuid = uuid4() self._is_connected: bool = False self._brandprefix = brandprefix self._user_id = user_id - self._connection_future: Optional[Any] = None self._command_lock = threading.Lock() self._response_lock = threading.Lock() self._response_event = threading.Event() @@ -195,10 +198,16 @@ def __init__( self._message_id_seq = itertools.count(random.randint(1024, 65535)) self._client_generation = 0 self._active_generation = 0 - self._awaiting_post_resume_message = False if response_timeout <= 0: raise ValueError("response_timeout must be greater than 0") self._response_timeout = float(response_timeout) + if connect_timeout is not None and connect_timeout <= 0: + raise ValueError("connect_timeout must be greater than 0") + self._connect_timeout = ( + self._response_timeout + if connect_timeout is None + else float(connect_timeout) + ) self._client_id = ( f"{self._brandprefix}/USER/{self._user_id}/homeassistant/{self._uuid}" ) @@ -246,7 +255,10 @@ def _create_mqtt_connection(self) -> paho_mqtt.Client: client.username_pw_set(username=username, password=None) client.tls_set(cert_reqs=ssl.CERT_REQUIRED) client.ws_set_options(path=MQTT_WEBSOCKET_PATH) - client.reconnect_delay_set(min_delay=1, max_delay=32) + client.reconnect_delay_set( + min_delay=PAHO_MQTT_RECONNECT_MIN_DELAY_SECONDS, + max_delay=PAHO_MQTT_RECONNECT_MAX_DELAY_SECONDS, + ) client.on_connect = lambda client, userdata, flags, reason_code, *args: ( self._on_paho_connect( client, userdata, flags, reason_code, generation=generation @@ -294,25 +306,6 @@ def _resubscribe_topic(self, topic: str, generation: int | None = None) -> None: raise self.subscribe(topic, False) - def _schedule_reconnect_after_resume(self) -> None: - """Kick off a full client rebuild after a broker-level reconnect.""" - worker = threading.Thread( - target=self._reconnect_after_resume, - name="pyworxcloud-mqtt-resume-reconnect", - daemon=True, - ) - worker.start() - - def _reconnect_after_resume(self) -> None: - """Rebuild the MQTT client after a reconnect callback we do not trust.""" - try: - self.update_token() - except Exception: # pragma: no cover - defensive logging path - self._log.debug( - "Forced reconnect after MQTT resume failed", - exc_info=True, - ) - def _on_paho_connect( self, connection: Any, @@ -332,8 +325,15 @@ def _on_paho_connect( if _reason_code_value(reason_code) == MQTT_CONNECT_ACCEPTED: self._is_connected = True self._connect_error = None + self._get_ready_event().set() + for topic in list(self._topic): + self._log.debug( + "Resubscribing to '%s' after MQTT connect callback", topic + ) + self._resubscribe_topic(topic, generation) else: self._is_connected = False + self._get_ready_event().clear() self._connect_error = NoConnectionError( f"MQTT connection rejected with result {reason_code}" ) @@ -364,6 +364,12 @@ def _on_paho_disconnect( reason_code = args[-2] if len(args) >= 2 else args[-1] if args else 0 if _reason_code_value(reason_code) != MQTT_CONNECT_ACCEPTED: + connect_event = self._get_connect_event() + if not self._is_connected and not connect_event.is_set(): + self._connect_error = NoConnectionError( + f"MQTT connection interrupted before ready: {reason_code}" + ) + connect_event.set() self._on_connection_interrupted( connection, reason_code, generation=generation ) @@ -387,49 +393,9 @@ def _on_connection_interrupted( self._is_connected = False self._get_ready_event().clear() - self._awaiting_post_resume_message = False logger.debug("Connection interrupted. error: %s", error) self._events.call(LandroidEvent.MQTT_CONNECTION, state=False) - def _on_connection_resumed( - self, connection: Any, return_code: Any, session_present: bool, **kwargs: Any - ) -> None: - """Callback when an interrupted connection is re-established.""" - del connection - logger = self._log.getChild("Conn_State") - generation = kwargs.get("generation") - if self._is_stale_generation(generation): - logger.debug( - "Ignoring stale connection resumed callback for generation %s", - generation, - ) - return - - logger.debug( - "Connection resumed. return_code: %s, session_present: %s", - return_code, - session_present, - ) - - self._is_connected = False - self._get_ready_event().clear() - self._awaiting_post_resume_message = False - - if _reason_code_value(return_code) == MQTT_CONNECT_ACCEPTED: - if session_present: - logger.debug( - "Session resumed, but forcing a full MQTT reconnect before trusting it" - ) - else: - logger.debug("Session did not persist; forcing a full MQTT reconnect") - else: - logger.debug( - "Resume returned non-accepted code %s; forcing a full MQTT reconnect", - return_code, - ) - - self._schedule_reconnect_after_resume() - @property def connected(self) -> bool: """Returns the MQTT connection state.""" @@ -447,7 +413,6 @@ def _on_message_received( return msg = payload.decode("utf-8") - self._awaiting_post_resume_message = False self._log.debug("Received MQTT message on topic '%s':\n%s", topic, msg) identifiers, message_ids = self._extract_response_markers(msg) expanded_identifiers = self._expand_identifiers(identifiers) @@ -550,7 +515,6 @@ def connect(self) -> None: connect_event.clear() self._connect_error = None try: - self._connection_future = None result = client.connect( self._endpoint, port=MQTT_PORT, @@ -559,7 +523,7 @@ def connect(self) -> None: _wait_for_operation(result) client.loop_start() - if not connect_event.wait(self._response_timeout): + if not connect_event.wait(self._connect_timeout): raise TimeoutException("Timed out waiting for MQTT connection") if self._connect_error is not None: raise self._connect_error @@ -573,22 +537,19 @@ def connect(self) -> None: return self._is_connected = True - self._reconnected = False - self._awaiting_post_resume_message = False - - for topic in self._topic: - self._log.debug("Subscribing to '%s'", topic) - self._resubscribe_topic(topic, generation) - self._get_ready_event().set() self._events.call(LandroidEvent.MQTT_CONNECTION, state=True) except Exception as exc: self._is_connected = False - self._connection_future = None self._get_ready_event().clear() self._safe_loop_stop(client) - self._log.error("Failed to connect to MQTT: %s", exc) + log_method = ( + self._log.error + if getattr(self, "_log_connect_errors", True) + else self._log.debug + ) + log_method("Failed to connect to MQTT: %s", exc) raise NoConnectionError() from exc async def aconnect(self) -> None: @@ -669,7 +630,6 @@ def disconnect(self, keep_topic: bool = False) -> None: with self._lifecycle_lock: if self._shutdown_event: self._is_connected = False - self._connection_future = None self._get_ready_event().clear() return @@ -707,7 +667,6 @@ def disconnect(self, keep_topic: bool = False) -> None: finally: self._safe_loop_stop(client) self._is_connected = False - self._connection_future = None self._get_ready_event().clear() async def adisconnect(self, keep_topic: bool = False) -> None: @@ -727,7 +686,6 @@ def shutdown(self) -> None: self.client = None self._is_connected = False - self._connection_future = None self._get_ready_event().clear() if client is not None and was_connected: @@ -828,9 +786,7 @@ def publish( raise ValueError("timeout must be greater than 0") should_retry_after_reconnect = ( - not self.connected - or self._awaiting_post_resume_message - or self._token_update_lock.locked() + not self.connected or self._token_update_lock.locked() ) command_signature = ( @@ -960,14 +916,6 @@ def _wait_until_ready(self, timeout: float) -> bool: def _ensure_connection_ready(self, timeout: float) -> None: """Ensure a usable MQTT connection exists before publishing.""" if self.connected: - if self._awaiting_post_resume_message: - self._log.debug( - "No MQTT traffic received after connection resume; rebuilding client before publish" - ) - self.update_token() - if self.connected: - return - raise NoConnectionError("MQTT connection did not recover after resume") return if self._token_update_lock.locked(): diff --git a/test.py b/test.py index 97c1958..91815cc 100644 --- a/test.py +++ b/test.py @@ -39,7 +39,6 @@ async def main() -> None: for _, device in cloud.devices.items(): # await cloud.update(device.serial_number) print(f"{device.name} online: {device.online}") - # await cloud.set_offlimits(device.serial_number, False) # await cloud.set_offlimits_shortcut(device.serial_number, True) # await cloud.set_cutting_height(device.serial_number, 45) diff --git a/tests/test_api_lifecycle.py b/tests/test_api_lifecycle.py index b33fd1d..6103837 100644 --- a/tests/test_api_lifecycle.py +++ b/tests/test_api_lifecycle.py @@ -16,6 +16,7 @@ from pyworxcloud.events import LandroidEvent from pyworxcloud.exceptions import ( AuthorizationError, + NoConnectionError, NoFirmwareAvailableError, NoFirmwareOtaError, NotFoundError, @@ -42,6 +43,7 @@ class DummyMQTT: """Simple MQTT stub.""" def __init__(self) -> None: + self.connected = True self.disconnect_called = False self.shutdown_called = False @@ -99,6 +101,7 @@ class CapturingMQTT: """MQTT constructor stub capturing provided timeout.""" last_response_timeout: float | None = None + last_connect_timeout: float | None = None constructor_thread_id: int | None = None def __init__( @@ -110,12 +113,14 @@ def __init__( _logger: Any, _callback: Any, response_timeout: float, + connect_timeout: float | None = None, identifier_resolver: Any = None, deduplicate_inflight_commands: bool = False, ) -> None: self.identifier_resolver = identifier_resolver self.deduplicate_inflight_commands = deduplicate_inflight_commands self.__class__.last_response_timeout = response_timeout + self.__class__.last_connect_timeout = connect_timeout self.__class__.constructor_thread_id = threading.get_ident() async def aconnect(self) -> None: @@ -145,12 +150,14 @@ def __init__( _logger: Any, _callback: Any, response_timeout: float, + connect_timeout: float | None = None, identifier_resolver: Any = None, deduplicate_inflight_commands: bool = False, ) -> None: self.identifier_resolver = identifier_resolver self.deduplicate_inflight_commands = deduplicate_inflight_commands self.response_timeout = response_timeout + self.connect_timeout = connect_timeout self.disconnect_calls = 0 self.shutdown_calls = 0 self.subscriptions: list[str] = [] @@ -351,6 +358,39 @@ def test_token_updated_is_noop_without_mqtt() -> None: asyncio.run(cloud._token_updated()) +def test_token_updated_refreshes_mqtt_only_when_connected() -> None: + """Token refresh should not force MQTT reconnects while MQTT is down.""" + cloud = WorxCloud("user@example.com", "secret", "worx") + retry_calls = 0 + + class TokenMQTT: + def __init__(self, connected: bool) -> None: + self.connected = connected + self.update_calls = 0 + + async def aupdate_token(self) -> None: + self.update_calls += 1 + + def _schedule_retry() -> None: + nonlocal retry_calls + retry_calls += 1 + + cloud._schedule_mqtt_retry = _schedule_retry # type: ignore[method-assign] + mqtt = TokenMQTT(False) + cloud.mqtt = mqtt + + asyncio.run(cloud._token_updated()) + + assert mqtt.update_calls == 0 + assert retry_calls == 1 + + mqtt.connected = True + asyncio.run(cloud._token_updated()) + + assert mqtt.update_calls == 1 + assert retry_calls == 1 + + def test_get_logger_does_not_accumulate_handlers() -> None: """Repeated logger setup should not attach output handlers or force levels.""" get_logger("pyworxcloud.test_handlers") @@ -439,19 +479,69 @@ def test_on_api_update_dispatches_api_event_callback() -> None: assert received == [{"key": "value"}] +def test_mqtt_connected_reports_current_client_state() -> None: + """MQTT connection property should mirror the current client state.""" + cloud = WorxCloud("user@example.com", "secret", "worx") + + class MQTTState: + def __init__(self, connected: bool) -> None: + self.connected = connected + + assert cloud.mqtt_connected is False + + cloud.mqtt = MQTTState(True) + assert cloud.mqtt_connected is True + + cloud.mqtt.connected = False + assert cloud.mqtt_connected is False + + cloud.mqtt = None + assert cloud.mqtt_connected is False + + +def test_mqtt_commands_raise_connection_error_when_mqtt_is_unavailable() -> None: + """MQTT command paths should fail cleanly in API-only fallback mode.""" + cloud = WorxCloud("user@example.com", "secret", "worx") + cloud._mowers_by_serial = { + "SN-1": { + "serial_number": "SN-1", + "uuid": "UUID-1", + "protocol": 0, + "mqtt_topics": {"command_in": "topic/in"}, + } + } + + with pytest.raises(NoConnectionError, match="MQTT connection is not ready"): + asyncio.run(cloud.update("SN-1")) + + class DisconnectedMQTT: + connected = False + + cloud.mqtt = DisconnectedMQTT() + with pytest.raises(NoConnectionError, match="MQTT connection is not ready"): + asyncio.run(cloud.update("SN-1")) + + def test_constructor_rejects_non_positive_command_timeout() -> None: """WorxCloud should validate command timeout.""" with pytest.raises(ValueError): WorxCloud("user@example.com", "secret", "worx", command_timeout=0) -def test_connect_passes_configured_command_timeout_to_mqtt(monkeypatch) -> None: - """Configured command timeout should be forwarded to MQTT layer.""" +def test_constructor_rejects_non_positive_mqtt_connect_timeout() -> None: + """WorxCloud should validate MQTT connect timeout.""" + with pytest.raises(ValueError): + WorxCloud("user@example.com", "secret", "worx", mqtt_connect_timeout=0) + + +def test_connect_passes_configured_timeouts_to_mqtt(monkeypatch) -> None: + """Configured timeouts should be forwarded to MQTT layer.""" cloud = WorxCloud( "user@example.com", "secret", "worx", command_timeout=12.5, + mqtt_connect_timeout=4.5, ) async def _fake_fetch() -> None: @@ -471,6 +561,7 @@ async def _fake_fetch() -> None: assert asyncio.run(cloud.connect()) is True assert CapturingMQTT.last_response_timeout == 12.5 + assert CapturingMQTT.last_connect_timeout == 4.5 def test_connect_constructs_mqtt_off_event_loop_thread(monkeypatch) -> None: @@ -516,6 +607,8 @@ def test_update_passes_optional_timeout_to_mqtt_ping() -> None: calls: list[tuple[str, str, int, float | None]] = [] class MQTTStub: + connected = True + async def aping( self, serial_number: str, diff --git a/tests/test_connect_cleanup.py b/tests/test_connect_cleanup.py index 51f726f..8a2d2f5 100644 --- a/tests/test_connect_cleanup.py +++ b/tests/test_connect_cleanup.py @@ -5,8 +5,6 @@ import asyncio from typing import Any -import pytest - from pyworxcloud import WorxCloud @@ -42,10 +40,12 @@ def __init__( _logger: Any, _callback: Any, response_timeout: float, + connect_timeout: float | None = None, identifier_resolver: Any = None, deduplicate_inflight_commands: bool = False, ) -> None: self.response_timeout = response_timeout + self.connect_timeout = connect_timeout self.identifier_resolver = identifier_resolver self.deduplicate_inflight_commands = deduplicate_inflight_commands self.disconnect_calls = 0 @@ -65,8 +65,22 @@ async def ashutdown(self) -> None: self.shutdown_calls += 1 -def test_connect_failure_cleans_up_mqtt_and_api_session(monkeypatch) -> None: - """A failed connect should not leave partial MQTT or API resources behind.""" +class TransientMQTT(FailingMQTT): + """MQTT stub that fails once and then reconnects.""" + + async def aconnect(self) -> None: + if len(self.__class__.instances) == 1: + raise RuntimeError("connect failed") + self.connected = True + + async def asubscribe(self, topic: str, _append: bool = True) -> None: + subscriptions = getattr(self, "subscriptions", []) + subscriptions.append(topic) + self.subscriptions = subscriptions + + +def test_mqtt_connect_failure_keeps_api_fallback_running(monkeypatch) -> None: + """MQTT failure should not prevent API-backed connect from succeeding.""" cloud = WorxCloud("user@example.com", "secret", "worx") session = DummySession() FailingMQTT.instances = [] @@ -87,8 +101,16 @@ async def _fake_fetch() -> None: monkeypatch.setattr("pyworxcloud.MQTT", FailingMQTT) monkeypatch.setattr("pyworxcloud.convert_to_time", lambda *_args, **_kwargs: None) - with pytest.raises(RuntimeError, match="connect failed"): - asyncio.run(cloud.connect()) + async def _exercise() -> None: + assert await cloud.connect() is True + assert cloud.mqtt is None + assert cloud._mqtt_retry_task is not None + assert session.close_calls == 0 + assert session.closed is False + assert cloud._disconnecting.is_set() is False + await cloud.disconnect() + + asyncio.run(_exercise()) assert len(FailingMQTT.instances) == 1 assert FailingMQTT.instances[0].disconnect_calls == 1 @@ -97,3 +119,43 @@ async def _fake_fetch() -> None: assert session.close_calls == 1 assert session.closed is True assert cloud._disconnecting.is_set() is True + + +def test_mqtt_background_retry_reconnects_without_api_refetch(monkeypatch) -> None: + """Background MQTT retry should use existing API data and restore MQTT.""" + cloud = WorxCloud("user@example.com", "secret", "worx") + TransientMQTT.instances = [] + fetch_calls = 0 + + async def _fake_fetch() -> None: + nonlocal fetch_calls + fetch_calls += 1 + cloud._mowers = [ + { + "name": "My Mower", + "mqtt_endpoint": "mqtt.example.invalid", + "user_id": 99, + "mqtt_topics": {"command_out": "topic/out"}, + } + ] + cloud.devices = {"My Mower": DummyDevice()} + + monkeypatch.setattr(cloud, "_fetch", _fake_fetch) + monkeypatch.setattr("pyworxcloud.MQTT", TransientMQTT) + monkeypatch.setattr("pyworxcloud.MQTT_RECONNECT_RETRY_SECONDS", 0) + monkeypatch.setattr("pyworxcloud.convert_to_time", lambda *_args, **_kwargs: None) + + async def _exercise() -> None: + assert await cloud.connect() is True + assert cloud._mqtt_retry_task is not None + await asyncio.wait_for(cloud._mqtt_retry_task, timeout=1) + assert cloud.mqtt is TransientMQTT.instances[1] + assert cloud.mqtt.subscriptions == ["topic/out"] + await cloud.disconnect() + + asyncio.run(_exercise()) + + assert fetch_calls == 1 + assert len(TransientMQTT.instances) == 2 + assert TransientMQTT.instances[0].disconnect_calls == 1 + assert TransientMQTT.instances[0].shutdown_calls == 1 diff --git a/tests/test_device_decode.py b/tests/test_device_decode.py index 88e2bae..0d8aa2f 100644 --- a/tests/test_device_decode.py +++ b/tests/test_device_decode.py @@ -98,6 +98,22 @@ def test_devicehandler_maps_module_capabilities() -> None: assert device.capabilities.check(DeviceCapability.ACS) is True +def test_devicehandler_reports_bound_mqtt_connection_state() -> None: + """Device MQTT state should resolve from the bound cloud callback.""" + _, payload = HTTP_FIXTURES[0] + mower = _build_mower(payload, _protocol_from_payload(payload), "Fixture Mower") + device = DeviceHandler(api=object(), mower=mower, tz="UTC") + connected = False + + assert device.mqtt_connected is False + + device.set_mqtt_connected_resolver(lambda: connected) + assert device.mqtt_connected is False + + connected = True + assert device.mqtt_connected is True + + def test_devicehandler_maps_border_cut_api_capability() -> None: """Vision border_cut API capability should expose edge-cut related features.""" payload = { diff --git a/tests/test_mqtt_commands.py b/tests/test_mqtt_commands.py index 4d400f9..7874c0b 100644 --- a/tests/test_mqtt_commands.py +++ b/tests/test_mqtt_commands.py @@ -122,71 +122,6 @@ def test_publish_uses_default_response_timeout(monkeypatch: pytest.MonkeyPatch) ) -def test_publish_rebuilds_connection_before_first_command_after_resume( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """First command after a resumed session should reconnect and retry once.""" - mqtt, dummy = _build_mqtt(monkeypatch, response_timeout=0.05) - reconnects: list[str] = [] - - mqtt._awaiting_post_resume_message = True - - def _update_token() -> None: - reconnects.append("called") - mqtt._awaiting_post_resume_message = False - mqtt._is_connected = True - - mqtt.update_token = _update_token # type: ignore[method-assign] - - with pytest.raises(TimeoutException): - mqtt.publish( - serial_number="SN-1", - topic="topic/in", - message={"cmd": 1}, - protocol=0, - ) - - assert reconnects == ["called", "called"] - assert len(dummy.published) == 2 - - -def test_publish_retries_once_after_connection_recovery_timeout( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """A command started during recovery should get one silent retry.""" - mqtt, dummy = _build_mqtt(monkeypatch, response_timeout=0.01) - reconnects: list[str] = [] - attempts = {"count": 0} - - mqtt._awaiting_post_resume_message = True - - def _ensure_connection_ready(timeout: float) -> None: - attempts["count"] += 1 - if attempts["count"] == 1: - mqtt._awaiting_post_resume_message = False - return None - - def _update_token() -> None: - reconnects.append("called") - mqtt._awaiting_post_resume_message = False - mqtt._is_connected = True - - mqtt._ensure_connection_ready = _ensure_connection_ready # type: ignore[method-assign] - mqtt.update_token = _update_token # type: ignore[method-assign] - - with pytest.raises(TimeoutException): - mqtt.publish( - serial_number="SN-1", - topic="topic/in", - message={"cmd": 1}, - protocol=0, - ) - - assert reconnects == ["called"] - assert attempts["count"] == 2 - assert len(dummy.published) == 2 - - def test_mqtt_rejects_non_positive_default_timeout( monkeypatch: pytest.MonkeyPatch, ) -> None: diff --git a/tests/test_mqtt_lifecycle.py b/tests/test_mqtt_lifecycle.py index c9cd69d..bebb674 100644 --- a/tests/test_mqtt_lifecycle.py +++ b/tests/test_mqtt_lifecycle.py @@ -8,6 +8,11 @@ from concurrent.futures import TimeoutError as FutureTimeoutError from typing import Any +from pyworxcloud.const import ( + PAHO_MQTT_RECONNECT_MAX_DELAY_SECONDS, + PAHO_MQTT_RECONNECT_MIN_DELAY_SECONDS, +) +from pyworxcloud.events import EventHandler from pyworxcloud.utils.mqtt import MQTT, MQTT_CONNECT_ACCEPTED @@ -30,6 +35,7 @@ class _ClientStub: def __init__(self, should_raise: bool = False) -> None: self.disconnect_calls = 0 + self.subscriptions: list[tuple[str, int]] = [] self.should_raise = should_raise self.future: Any = _ImmediateFuture() @@ -39,6 +45,30 @@ def disconnect(self) -> _ImmediateFuture: raise RuntimeError("disconnect failed") return self.future + def subscribe(self, topic: str, qos: int) -> _ImmediateFuture: + self.subscriptions.append((topic, qos)) + return self.future + + +class _PahoConfigStub: + """Client stub that records setup calls.""" + + def __init__(self) -> None: + self.reconnect_delay: tuple[int, int] | None = None + + def username_pw_set(self, username: str, password: str | None = None) -> None: + self.username = username + self.password = password + + def tls_set(self, **_kwargs: Any) -> None: + return None + + def ws_set_options(self, **_kwargs: Any) -> None: + return None + + def reconnect_delay_set(self, min_delay: int, max_delay: int) -> None: + self.reconnect_delay = (min_delay, max_delay) + def _build_mqtt_lifecycle_fixture( *, connected: bool = True, client: Any | None = None @@ -48,7 +78,6 @@ def _build_mqtt_lifecycle_fixture( mqtt._lifecycle_lock = threading.RLock() mqtt._shutdown_event = False mqtt._is_connected = connected - mqtt._connection_future = object() mqtt._shutdown_timeout = 5.0 mqtt._disconnect_timeout = 5.0 mqtt._topic = ["topic/out"] @@ -62,6 +91,24 @@ def _build_mqtt_lifecycle_fixture( return mqtt +def test_create_mqtt_connection_uses_conservative_paho_reconnect_backoff() -> None: + """Paho reconnect backoff should not hammer the broker.""" + mqtt = MQTT.__new__(MQTT) + mqtt._api = type("API", (), {"access_token": "aaa.bbb.ccc"})() + mqtt._client_generation = 0 + mqtt._active_generation = 0 + mqtt._client_id = "client-id" + mqtt._log = logging.getLogger("test") + client = _PahoConfigStub() + mqtt._create_paho_client = lambda: client + + assert mqtt._create_mqtt_connection() is client + assert client.reconnect_delay == ( + PAHO_MQTT_RECONNECT_MIN_DELAY_SECONDS, + PAHO_MQTT_RECONNECT_MAX_DELAY_SECONDS, + ) + + def test_disconnect_is_idempotent_and_safe_with_missing_client() -> None: """Disconnect should be safe to call repeatedly and with no client.""" client = _ClientStub() @@ -72,7 +119,6 @@ def test_disconnect_is_idempotent_and_safe_with_missing_client() -> None: mqtt.disconnect() assert client.disconnect_calls == 1 - assert mqtt._connection_future is None assert mqtt._is_connected is False @@ -94,7 +140,6 @@ def test_disconnect_swallows_disconnect_future_timeout() -> None: mqtt.disconnect() assert client.disconnect_calls == 1 - assert mqtt._connection_future is None assert mqtt._is_connected is False @@ -135,7 +180,6 @@ def loop_stop(self) -> None: assert loop_stop_started.is_set() is True assert time.perf_counter() - started < 1.0 - assert mqtt._connection_future is None assert mqtt._is_connected is False @@ -149,7 +193,6 @@ def test_shutdown_is_idempotent_and_detaches_resources() -> None: assert client.disconnect_calls == 1 assert mqtt.client is None - assert mqtt._connection_future is None assert mqtt._shutdown_event is True assert mqtt._is_connected is False @@ -177,40 +220,67 @@ def test_shutdown_swallows_disconnect_future_timeout() -> None: assert mqtt._shutdown_event is True -def test_connection_resumed_resubscribes_even_when_session_persists() -> None: - """Resume should trigger a full reconnect even when session_present is true.""" +def test_disconnect_before_initial_connect_unblocks_connect_wait() -> None: + """A broker hangup during initial connect should fail the attempt immediately.""" mqtt = _build_mqtt_lifecycle_fixture(connected=False, client=_ClientStub()) + mqtt._events = EventHandler() + mqtt._connect_error = None + connect_event = mqtt._get_connect_event() + connect_event.clear() + + mqtt._on_paho_disconnect(mqtt.client, None, 1) + + assert connect_event.is_set() is True + assert mqtt._connect_error is not None + assert "before ready" in str(mqtt._connect_error) + + +def test_connect_callback_resubscribes_existing_topics() -> None: + """Paho reconnect callbacks should restore subscriptions without extra reconnects.""" + client = _ClientStub() + mqtt = _build_mqtt_lifecycle_fixture(connected=False, client=client) + mqtt._events = EventHandler() + mqtt._ready_event = threading.Event() + mqtt._connect_event = threading.Event() + mqtt._connect_error = None mqtt._topic = ["topic/a", "topic/b"] - mqtt._awaiting_post_resume_message = False - reconnect_calls: list[str] = [] - mqtt._schedule_reconnect_after_resume = lambda: reconnect_calls.append("called") + mqtt._active_generation = 7 - mqtt._on_connection_resumed( + mqtt._on_paho_connect( + client, None, + {"session present": False}, MQTT_CONNECT_ACCEPTED, - True, + generation=7, ) - assert mqtt._is_connected is False - assert mqtt._get_ready_event().is_set() is False - assert mqtt._awaiting_post_resume_message is False - assert reconnect_calls == ["called"] + assert mqtt.connected is True + assert mqtt._get_ready_event().is_set() is True + assert mqtt._get_connect_event().is_set() is True + assert client.subscriptions == [("topic/a", 1), ("topic/b", 1)] -def test_connection_resumed_resubscribes_when_session_is_not_present() -> None: - """Resume should trigger a full reconnect when the session is lost.""" - mqtt = _build_mqtt_lifecycle_fixture(connected=False, client=_ClientStub()) - mqtt._topic = ["topic/out"] - mqtt._awaiting_post_resume_message = False - reconnect_calls: list[str] = [] - mqtt._schedule_reconnect_after_resume = lambda: reconnect_calls.append("called") - - mqtt._on_connection_resumed( +def test_rejected_connect_callback_clears_ready_without_subscribing() -> None: + """Rejected connect callbacks should not leave the MQTT client ready.""" + client = _ClientStub() + mqtt = _build_mqtt_lifecycle_fixture(connected=True, client=client) + mqtt._ready_event = threading.Event() + mqtt._ready_event.set() + mqtt._connect_event = threading.Event() + mqtt._connect_error = None + mqtt._topic = ["topic/a"] + mqtt._active_generation = 7 + + mqtt._on_paho_connect( + client, None, - MQTT_CONNECT_ACCEPTED, - False, + {"session present": False}, + 1, + generation=7, ) - assert mqtt._is_connected is False - assert mqtt._awaiting_post_resume_message is False - assert reconnect_calls == ["called"] + assert mqtt.connected is False + assert mqtt._get_ready_event().is_set() is False + assert mqtt._get_connect_event().is_set() is True + assert mqtt._connect_error is not None + assert client.subscriptions == [] diff --git a/tests/test_mqtt_runtime.py b/tests/test_mqtt_runtime.py index aefe114..f7847f0 100644 --- a/tests/test_mqtt_runtime.py +++ b/tests/test_mqtt_runtime.py @@ -10,7 +10,7 @@ from pyworxcloud.events import EventHandler from pyworxcloud.exceptions import TimeoutException -from pyworxcloud.utils.mqtt import MQTT, MQTT_CONNECT_ACCEPTED +from pyworxcloud.utils.mqtt import MQTT def _build_mqtt() -> MQTT: @@ -27,51 +27,10 @@ def _build_mqtt() -> MQTT: mqtt._client_generation = 2 mqtt._response_timeout = 0.2 mqtt._shutdown_event = False - mqtt._connection_future = None - mqtt._reconnected = False mqtt.client = object() return mqtt -def test_connection_resumed_ignores_stale_generation() -> None: - """Stale AWS callbacks must not flip connection state back to ready.""" - mqtt = _build_mqtt() - subscribe_calls: list[tuple[str, bool, int | None]] = [] - - mqtt.subscribe = lambda topic, append, generation=None: subscribe_calls.append( - (topic, append, generation) - ) - - mqtt._on_connection_resumed( - object(), - MQTT_CONNECT_ACCEPTED, - True, - generation=1, - ) - - assert mqtt.connected is False - assert mqtt._ready_event.is_set() is False - assert subscribe_calls == [] - - -def test_connection_resumed_marks_active_generation_ready() -> None: - """Active resume callbacks should trigger a forced reconnect instead.""" - mqtt = _build_mqtt() - reconnect_calls: list[str] = [] - mqtt._schedule_reconnect_after_resume = lambda: reconnect_calls.append("called") - - mqtt._on_connection_resumed( - object(), - MQTT_CONNECT_ACCEPTED, - True, - generation=2, - ) - - assert mqtt.connected is False - assert mqtt._ready_event.is_set() is False - assert reconnect_calls == ["called"] - - def test_ensure_connection_ready_waits_for_parallel_refresh() -> None: """Publish callers should wait for an in-flight token refresh to finish.""" mqtt = _build_mqtt()