Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions src/nwp500/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,11 @@ def __init__(
# Connection state (simpler than checking _connection_manager)
self._connection: mqtt.Connection | None = None
self._connected = False
# Guards _active_reconnect / _deep_reconnect against re-entrancy.
# While True, _on_connection_interrupted_internal will not forward
# events to the reconnection handler, preventing the intentional
# teardown of the old connection from spawning a competing backoff loop.
self._actively_reconnecting = False

_logger.info(
f"Initialized MQTT client with ID: {self.config.client_id}"
Expand Down Expand Up @@ -276,8 +281,17 @@ def _on_connection_interrupted_internal(
)
)

# Delegate to reconnection handler if available
if self._reconnection_handler and self.config.auto_reconnect:
# Delegate to reconnection handler if available.
# Skip while _actively_reconnecting: the interruption was caused by
# _active_reconnect / _deep_reconnect intentionally closing the old
# connection. Forwarding it would queue a _start_reconnect_task
# coroutine that could fire after the new connection is up and the
# existing backoff task has been cancelled, spawning a competing loop.
if (
self._reconnection_handler
and self.config.auto_reconnect
and not self._actively_reconnecting
):
self._reconnection_handler.on_connection_interrupted(error)

# Record diagnostic event
Expand Down Expand Up @@ -380,8 +394,13 @@ async def _active_reconnect(self) -> None:
_logger.debug("Already connected, skipping reconnection")
return

if self._actively_reconnecting:
_logger.debug("Active reconnection already in progress, skipping")
return

_logger.info("Attempting active reconnection...")

self._actively_reconnecting = True
try:
# Ensure tokens are still valid
await self._auth_client.ensure_valid_token()
Expand All @@ -390,6 +409,9 @@ async def _active_reconnect(self) -> None:
if self._connection_manager:
# Close old connection to stop SDK auto-reconnect and
# prevent two connections with the same client ID.
# _actively_reconnecting suppresses the
# on_connection_interrupted callback that closing triggers,
# preventing a competing backoff loop from being spawned.
_logger.debug("Recreating MQTT connection...")
try:
await self._connection_manager.close()
Expand Down Expand Up @@ -432,6 +454,8 @@ async def _active_reconnect(self) -> None:
f"Error during active reconnection: {e}", exc_info=True
)
raise
finally:
self._actively_reconnecting = False

async def _deep_reconnect(self) -> None:
"""
Expand All @@ -451,13 +475,21 @@ async def _deep_reconnect(self) -> None:
_logger.debug("Already connected, skipping deep reconnection")
return

if self._actively_reconnecting:
_logger.debug("Active reconnection already in progress, skipping")
return

_logger.warning(
"Performing deep reconnection (full rebuild)... "
"This may take longer."
)

self._actively_reconnecting = True
try:
# Step 1: Clean up existing connection if any
# Step 1: Clean up existing connection if any.
# _actively_reconnecting suppresses the on_connection_interrupted
# callback that closing triggers, preventing a competing backoff
# loop from being spawned.
if self._connection_manager:
_logger.debug("Cleaning up old connection...")
try:
Expand Down Expand Up @@ -534,6 +566,8 @@ async def _deep_reconnect(self) -> None:
) as e:
_logger.error(f"Error during deep reconnection: {e}", exc_info=True)
raise
finally:
self._actively_reconnecting = False

async def connect(self) -> bool:
"""
Expand Down
19 changes: 17 additions & 2 deletions src/nwp500/mqtt/reconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,18 @@ def on_connection_interrupted(self, error: Exception) -> None:
"""
_logger.warning(f"Connection interrupted: {error}")

# Start automatic reconnection if enabled
# Start automatic reconnection if enabled.
# Also guard against stale interruption events that arrive after the
# connection has already been restored: these can be queued via
# run_coroutine_threadsafe and fire after on_connection_resumed has
# cancelled _reconnect_task (setting it to None), which would
# otherwise bypass the task-existence check and spawn a new backoff
# loop while the client is perfectly healthy.
if (
self.config.auto_reconnect
and self._enabled
and not self._manual_disconnect
and not self._is_connected_func()
and (not self._reconnect_task or self._reconnect_task.done())
):
_logger.info("Starting automatic reconnection...")
Expand Down Expand Up @@ -132,8 +139,16 @@ async def _start_reconnect_task(self) -> None:

This is a helper method to create the reconnect task from within
a coroutine that's scheduled via _schedule_coroutine.

The is_connected guard is re-checked here because this coroutine may
be queued via run_coroutine_threadsafe and only run after the
connection has already been restored, e.g. after
on_connection_resumed has cancelled _reconnect_task and set it to
None. In that case the task-existence check alone would pass, and
starting a new backoff loop would incorrectly tear down a healthy
connection.
"""
if not self._reconnect_task or self._reconnect_task.done():
if not self._is_connected_func() and (
not self._reconnect_task or self._reconnect_task.done()
):
self._reconnect_task = asyncio.create_task(
self._reconnect_with_backoff()
)
Expand Down
Loading
Loading