From 64f7ca283e1486fac32c6d435ab3263063ea6af7 Mon Sep 17 00:00:00 2001 From: marcosf2 Date: Thu, 11 Dec 2025 13:56:58 -0600 Subject: [PATCH 1/3] Implement WebSocket coordination for follow requests and responses & refactor rotary_encoder_dep bugfix --- configs/config_app_example.toml | 2 + scripts/coordination_websocket_client | 217 ++++++++++++++++++++ src/pqnstack/app/api/deps.py | 24 +++ src/pqnstack/app/api/main.py | 4 + src/pqnstack/app/api/routes/coordination.py | 199 ++++++++++++++++++ src/pqnstack/app/api/routes/debug.py | 18 ++ src/pqnstack/app/api/routes/qkd.py | 13 +- src/pqnstack/app/api/routes/serial.py | 31 +-- src/pqnstack/app/core/config.py | 35 +++- 9 files changed, 510 insertions(+), 33 deletions(-) create mode 100755 scripts/coordination_websocket_client create mode 100644 src/pqnstack/app/api/routes/coordination.py create mode 100644 src/pqnstack/app/api/routes/debug.py diff --git a/configs/config_app_example.toml b/configs/config_app_example.toml index 81c4a476..84c8a813 100644 --- a/configs/config_app_example.toml +++ b/configs/config_app_example.toml @@ -1,5 +1,7 @@ # MAKE SURE TO RENAME THIS FILE TO config.toml AND PLACE IT IN THE ROOT OF THE PROJECT +node_name = "example_node" + # Router configuration router_name = "router1" router_address = "xx.xx.xx.xx" # Replace with actual IP address diff --git a/scripts/coordination_websocket_client b/scripts/coordination_websocket_client new file mode 100755 index 00000000..e31ecc64 --- /dev/null +++ b/scripts/coordination_websocket_client @@ -0,0 +1,217 @@ +#!/usr/bin/env python + +""" +WebSocket client for testing the coordination follow_requested_alerts endpoint. +Acts as a proxy for the frontend, automatically responding to follow requests. +""" + +import argparse +import asyncio +import logging +import signal +import sys + +import websockets +from websockets.exceptions import ConnectionClosedError +from websockets.exceptions import ConnectionClosedOK + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + +# Default configuration +DEFAULT_HOST = "localhost" +DEFAULT_PORT = 8000 +DEFAULT_RESPONSE = "yes" # Default response to follow requests + + +class CoordinationWebSocketClient: + """WebSocket client for coordination follow requests.""" + + def __init__( + self, + host: str = DEFAULT_HOST, + port: int = DEFAULT_PORT, + auto_response: str = DEFAULT_RESPONSE, + interactive: bool = False, + ) -> None: + """ + Initialize the WebSocket client. + + Parameters + ---------- + host : str + The host address of the API server. + port : int + The port number of the API server. + auto_response : str + Automatic response to follow requests ('yes', 'no', 'true', 'false'). + interactive : bool + If True, prompt user for each response. If False, use auto_response. + """ + self.host = host + self.port = port + self.auto_response = auto_response.lower() + self.interactive = interactive + self.uri = f"ws://{host}:{port}/coordination/follow_requested_alerts" + self.websocket = None + self.running = False + + async def connect(self) -> None: + """Establish connection to the WebSocket endpoint.""" + try: + logger.info("Connecting to %s", self.uri) + self.websocket = await websockets.connect(self.uri) + logger.info("Successfully connected to coordination WebSocket") + self.running = True + except Exception as e: + logger.exception("Failed to connect to WebSocket: %s", e) + raise + + async def handle_messages(self) -> None: + """Listen for messages from the server and respond accordingly.""" + if not self.websocket: + logger.error("WebSocket not connected") + return + + try: + async for message in self.websocket: + logger.info("Received message: %s", message) + + # Determine response + if self.interactive: + response = await self._get_user_input(message) + else: + response = self.auto_response + logger.info("Auto-responding with: %s", response) + + # Send response + await self.websocket.send(response) + logger.info("Sent response: %s", response) + + except ConnectionClosedOK: + logger.info("WebSocket connection closed normally") + except ConnectionClosedError as e: + logger.error("WebSocket connection closed with error: %s", e) + except Exception: + logger.exception("Error while handling messages") + finally: + self.running = False + + async def _get_user_input(self, prompt: str) -> str: + """ + Get user input interactively. + + Parameters + ---------- + prompt : str + The prompt message to display to the user. + + Returns + ------- + str + User's response. + """ + loop = asyncio.get_event_loop() + print(f"\n{prompt}") + print("Enter your response (yes/no): ", end="", flush=True) + + # Run blocking input() in executor to not block event loop + response = await loop.run_in_executor(None, sys.stdin.readline) + return response.strip().lower() + + async def close(self) -> None: + """Close the WebSocket connection.""" + if self.websocket: + await self.websocket.close() + logger.info("WebSocket connection closed") + self.running = False + + async def run(self) -> None: + """Main run loop for the client.""" + try: + await self.connect() + await self.handle_messages() + except KeyboardInterrupt: + logger.info("Received keyboard interrupt, shutting down...") + except Exception: + logger.exception("Unexpected error in run loop") + finally: + await self.close() + + +async def main() -> None: + """Main entry point for the script.""" + parser = argparse.ArgumentParser( + description="WebSocket client for coordination follow_requested_alerts endpoint" + ) + parser.add_argument( + "--host", + type=str, + default=DEFAULT_HOST, + help=f"API server host (default: {DEFAULT_HOST})", + ) + parser.add_argument( + "--port", + type=int, + default=DEFAULT_PORT, + help=f"API server port (default: {DEFAULT_PORT})", + ) + parser.add_argument( + "--response", + type=str, + default=DEFAULT_RESPONSE, + choices=["yes", "no", "true", "false", "y", "n"], + help=f"Automatic response to follow requests (default: {DEFAULT_RESPONSE})", + ) + parser.add_argument( + "--interactive", + action="store_true", + help="Prompt for user input for each follow request instead of auto-responding", + ) + parser.add_argument( + "--debug", + action="store_true", + help="Enable debug logging", + ) + + args = parser.parse_args() + + # Set logging level + if args.debug: + logging.getLogger().setLevel(logging.DEBUG) + + # Create and run client + client = CoordinationWebSocketClient( + host=args.host, + port=args.port, + auto_response=args.response, + interactive=args.interactive, + ) + + # Handle graceful shutdown + loop = asyncio.get_event_loop() + + def signal_handler() -> None: + logger.info("Received shutdown signal") + if client.running: + asyncio.create_task(client.close()) + + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, signal_handler) + + # Run the client + await client.run() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Script terminated by user") + except Exception: + logger.exception("Fatal error") + sys.exit(1) \ No newline at end of file diff --git a/src/pqnstack/app/api/deps.py b/src/pqnstack/app/api/deps.py index 463dac67..95d33341 100644 --- a/src/pqnstack/app/api/deps.py +++ b/src/pqnstack/app/api/deps.py @@ -1,4 +1,5 @@ from collections.abc import AsyncGenerator +from functools import lru_cache from typing import Annotated import httpx @@ -6,8 +7,12 @@ from pqnstack.app.core.config import NodeState from pqnstack.app.core.config import get_state +from pqnstack.app.core.config import logger from pqnstack.app.core.config import settings from pqnstack.network.client import Client +from pqnstack.pqn.drivers.rotaryencoder import MockRotaryEncoder +from pqnstack.pqn.drivers.rotaryencoder import RotaryEncoderInstrument +from pqnstack.pqn.drivers.rotaryencoder import SerialRotaryEncoder async def get_http_client() -> AsyncGenerator[httpx.AsyncClient, None]: @@ -25,4 +30,23 @@ async def get_instrument_client() -> AsyncGenerator[Client, None]: InstrumentClientDep = Annotated[httpx.AsyncClient, Depends(get_instrument_client)] + StateDep = Annotated[NodeState, Depends(get_state)] + + +@lru_cache +def get_rotary_encoder() -> RotaryEncoderInstrument: + if settings.virtual_rotator: + # Virtual rotator mode enabled, use mock with terminal input + logger.info("Virtual rotator mode enabled") + rotary_encoder: RotaryEncoderInstrument = MockRotaryEncoder() + else: + # Use the real serial rotary encoder + rotary_encoder = SerialRotaryEncoder( + label="rotary_encoder", address=settings.rotary_encoder_address, offset_degrees=0.0 + ) + + return rotary_encoder + + +SERDep = Annotated[RotaryEncoderInstrument, Depends(get_rotary_encoder)] diff --git a/src/pqnstack/app/api/main.py b/src/pqnstack/app/api/main.py index d50d7c12..175e0899 100644 --- a/src/pqnstack/app/api/main.py +++ b/src/pqnstack/app/api/main.py @@ -1,6 +1,8 @@ from fastapi import APIRouter from pqnstack.app.api.routes import chsh +from pqnstack.app.api.routes import coordination +from pqnstack.app.api.routes import debug from pqnstack.app.api.routes import qkd from pqnstack.app.api.routes import rng from pqnstack.app.api.routes import serial @@ -12,3 +14,5 @@ api_router.include_router(timetagger.router) api_router.include_router(rng.router) api_router.include_router(serial.router) +api_router.include_router(coordination.router) +api_router.include_router(debug.router) diff --git a/src/pqnstack/app/api/routes/coordination.py b/src/pqnstack/app/api/routes/coordination.py new file mode 100644 index 00000000..30891845 --- /dev/null +++ b/src/pqnstack/app/api/routes/coordination.py @@ -0,0 +1,199 @@ +import asyncio +import logging + +from fastapi import APIRouter +from fastapi import HTTPException +from fastapi import Request +from fastapi import WebSocket +from fastapi import WebSocketDisconnect +from fastapi import status +from pydantic import BaseModel + +from pqnstack.app.api.deps import ClientDep +from pqnstack.app.api.deps import StateDep +from pqnstack.app.core.config import ask_user_for_follow_event +from pqnstack.app.core.config import settings +from pqnstack.app.core.config import user_replied_event + +logger = logging.getLogger(__name__) + + +class FollowRequestResponse(BaseModel): + accepted: bool + + +class CollectFollowerResponse(BaseModel): + accepted: bool + + +class ResetCoordinationStateResponse(BaseModel): + message: str = "Coordination state reset successfully" + + +router = APIRouter(prefix="/coordination", tags=["coordination"]) + + +# TODO: Send a disconnection message if I was following/leading someone. +# FIXME: This is technically resetting more than just coordination state. including qkd. +@router.post("/reset_coordination_state") +async def reset_coordination_state(state: StateDep) -> ResetCoordinationStateResponse: + """Reset the coordination state of the node.""" + state.leading = False + state.followers_address = "" + state.following = False + state.following_requested = False + state.following_requested_user_response = None + state.leaders_address = "" + state.leaders_name = "" + state.qkd_emoji_pick = "" + state.qkd_bit_list = [] + state.qkd_question_order = [] + state.qkd_leader_basis_list = [] + state.qkd_follower_basis_list = [] + state.qkd_single_bit_current_index = 0 + state.qkd_resulting_bit_list = [] + state.qkd_request_basis_list = [] + state.qkd_request_bit_list = [] + state.qkd_n_matching_bits = -1 + return ResetCoordinationStateResponse() + + +@router.post("/collect_follower") +async def collect_follower( + request: Request, address: str, state: StateDep, http_client: ClientDep +) -> CollectFollowerResponse: + """ + Endpoint called by a leader node (this one) to request a follower node (other node) to follow it. + + Returns + ------- + CollectFollowerResponse indicating if the follower accepted the request. + """ + logger.info("Requesting client at %s to follow", address) + + # Get the port this server is listening on + server_port = request.scope["server"][1] + + ret = await http_client.post( + f"http://{address}/coordination/follow_requested?leaders_name={settings.node_name}&leaders_port={server_port}" + ) + if ret.status_code != status.HTTP_200_OK: + raise HTTPException(status_code=ret.status_code, detail=ret.text) + + response_data = ret.json() + if response_data.get("accepted") is True: + state.leading = True + state.followers_address = address + logger.info("Successfully collected follower") + return CollectFollowerResponse(accepted=True) + if response_data.get("accepted") is False: + logger.info("Follower rejected follow request") + return CollectFollowerResponse(accepted=False) + + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Could not collect follower for unknown reasons" + ) + + +@router.post("/follow_requested") +async def follow_requested( + request: Request, leaders_name: str, leaders_port: int, state: StateDep +) -> FollowRequestResponse: + """ + Endpoint is called by a leader node (other node) to request this node to follow it. + + Returns + ------- + FollowRequestResponse indicating if the follow request is accepted. + """ + logger.debug("Requesting client at %s to follow", leaders_name) + + if request.client is None: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Request lacks the clients host") + leaders_address = f"{request.client.host}:{leaders_port}" + + # Check if the client is ready to accept a follower request and that node is not already following someone. + if not state.client_listening_for_follower_requests or state.following: + logger.info( + "Request rejected because %s", + ( + "client is not listening for requests" + if not state.client_listening_for_follower_requests + else "this node is already following someone" + ), + ) + return FollowRequestResponse(accepted=False) + + state.following_requested = True + state.leaders_name = leaders_name + state.leaders_address = leaders_address + # Trigger the state change to get the websocket to send question to user + ask_user_for_follow_event.set() + + logger.debug("Asking user to accept follow request from %s (%s)", leaders_name, leaders_address) + await user_replied_event.wait() # Wait for a state change event to see if user accepted + user_replied_event.clear() # Reset the event for the next change + if state.following_requested_user_response: + logger.debug("Follow request from %s accepted.", leaders_address) + state.following = True + state.leaders_name = leaders_name + state.leaders_address = leaders_address + return FollowRequestResponse(accepted=True) + + logger.debug("Follow request from %s rejected.", leaders_address) + # Clean up the state if user rejected + state.leaders_address = "" + state.leaders_name = "" + state.following_requested = False + state.following_requested_user_response = None + return FollowRequestResponse(accepted=False) + + +@router.websocket("/follow_requested_alerts") +async def follow_requested_alert(websocket: WebSocket, state: StateDep) -> None: + """Websocket endpoint is used to alert the client when a follow request is received. It also handles the response from the client.""" + await websocket.accept() + logger.info("Client connected to websocket for multiplayer coordination.") + state.client_listening_for_follower_requests = True + + async def ask_user_for_follow_handler() -> None: + """Task that waits for the ask_user_for_follow_event event and sends a message to the client if a follow request is detected.""" + while True: + try: + await ask_user_for_follow_event.wait() # Wait for a state change event + if state.following_requested: + logger.debug("Websocket detected a follow request, asking user for response.") + if websocket.client_state.name == "CONNECTED": + await websocket.send_text(f"Do you want to accept a connection from {state.leaders_name}?") + else: + logger.debug("WebSocket not connected, cannot send message") + break + ask_user_for_follow_event.clear() # Reset the event for the next change + except WebSocketDisconnect: + logger.info("WebSocket disconnected in ask_user_for_follow_handler") + break + except Exception: + logger.exception("Error in ask_user_for_follow_handler, continuing to listen") + ask_user_for_follow_event.clear() # Reset the event to continue + + async def client_message_handler() -> None: + """Task that waits for a message from the client and handles the response. It also handles the case where the client disconnects.""" + try: + while True: + response = await websocket.receive_text() + state.following_requested_user_response = response.lower() in ["true", "yes", "y"] + state.following_requested = False + logger.debug("Websocket received a response from user: %s", state.following_requested_user_response) + user_replied_event.set() + except WebSocketDisconnect: + logger.info("Client disconnected from websocket for multiplayer coordination.") + state.client_listening_for_follower_requests = False + + state_change_task = asyncio.create_task(ask_user_for_follow_handler()) + client_message_task = asyncio.create_task(client_message_handler()) + + try: + await asyncio.gather(state_change_task, client_message_task) + finally: + state_change_task.cancel() + client_message_task.cancel() diff --git a/src/pqnstack/app/api/routes/debug.py b/src/pqnstack/app/api/routes/debug.py new file mode 100644 index 00000000..3e43d74b --- /dev/null +++ b/src/pqnstack/app/api/routes/debug.py @@ -0,0 +1,18 @@ +from fastapi import APIRouter + +from pqnstack.app.api.deps import StateDep +from pqnstack.app.core.config import NodeState +from pqnstack.app.core.config import Settings +from pqnstack.app.core.config import settings + +router = APIRouter(prefix="/debug", tags=["debug"]) + + +@router.get("/state") +async def get_state(state: StateDep) -> NodeState: + return state + + +@router.get("/settings") +async def get_settings() -> Settings: + return settings diff --git a/src/pqnstack/app/api/routes/qkd.py b/src/pqnstack/app/api/routes/qkd.py index ab32c1e5..81d06573 100644 --- a/src/pqnstack/app/api/routes/qkd.py +++ b/src/pqnstack/app/api/routes/qkd.py @@ -39,7 +39,7 @@ async def _qkd( ) counts = [] - for basis in state.qkd_basis_list: + for basis in state.qkd_leader_basis_list: r = await http_client.post(f"http://{follower_node_address}/qkd/single_bit") if r.status_code != status.HTTP_200_OK: @@ -76,16 +76,19 @@ def get_outcome(state: int, basis: int, choice: int, counts: int) -> int: outcome = [] logger.debug( - "Going for qkd_basis_list: %s, qkd_bit_list: %s, counts: %s", state.qkd_basis_list, state.qkd_bit_list, counts + "Going for qkd_leader_basis_list: %s, qkd_bit_list: %s, counts: %s", + state.qkd_leader_basis_list, + state.qkd_bit_list, + counts, ) - for basis, choice, count in zip(state.qkd_basis_list, state.qkd_bit_list, counts, strict=False): + for basis, choice, count in zip(state.qkd_leader_basis_list, state.qkd_bit_list, counts, strict=False): out = get_outcome(settings.bell_state.value, BasisBool[basis.name].value, choice, count) logger.debug( "Calculating outcome for basis: %s, choice: %s, count: %s, outcome: %s", basis.name, choice, count, out ) outcome.append(out) - basis_list = [basis.name for basis in state.qkd_basis_list] + basis_list = [basis.name for basis in state.qkd_leader_basis_list] # FIXME: Send already binary basis instead of HV/AD. r = await http_client.post(f"http://{follower_node_address}/qkd/request_basis_list", json=basis_list) @@ -112,7 +115,7 @@ async def qkd( timetagger_address: str | None = None, ) -> list[int]: """Perform a QKD protocol with the given follower node.""" - if not state.qkd_basis_list: + if not state.qkd_leader_basis_list: logger.error("QKD basis list is empty") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/src/pqnstack/app/api/routes/serial.py b/src/pqnstack/app/api/routes/serial.py index c86cb9a7..f5811840 100644 --- a/src/pqnstack/app/api/routes/serial.py +++ b/src/pqnstack/app/api/routes/serial.py @@ -1,40 +1,19 @@ import logging -from typing import Annotated +from typing import TYPE_CHECKING from typing import cast from fastapi import APIRouter -from fastapi import Depends from pydantic import BaseModel -from pqnstack.app.core.config import settings -from pqnstack.pqn.drivers.rotaryencoder import MockRotaryEncoder -from pqnstack.pqn.drivers.rotaryencoder import RotaryEncoderInstrument -from pqnstack.pqn.drivers.rotaryencoder import SerialRotaryEncoder +from pqnstack.app.api.deps import SERDep + +if TYPE_CHECKING: + from pqnstack.pqn.drivers.rotaryencoder import MockRotaryEncoder logger = logging.getLogger(__name__) router = APIRouter(prefix="/serial", tags=["measure"]) -def get_rotary_encoder() -> RotaryEncoderInstrument: - if settings.rotary_encoder is None: - if settings.virtual_rotator: - # Virtual rotator mode enabled, use mock with terminal input - logger.info("Virtual rotator mode enabled") - mock_encoder = MockRotaryEncoder() - settings.rotary_encoder = mock_encoder - else: - # Use the real serial rotary encoder - rotary_encoder = SerialRotaryEncoder( - label="rotary_encoder", address=settings.rotary_encoder_address, offset_degrees=0.0 - ) - settings.rotary_encoder = rotary_encoder - - return settings.rotary_encoder - - -SERDep = Annotated[RotaryEncoderInstrument, Depends(get_rotary_encoder)] - - class AngleResponse(BaseModel): theta: float diff --git a/src/pqnstack/app/core/config.py b/src/pqnstack/app/core/config.py index c62f9769..6a1b5368 100644 --- a/src/pqnstack/app/core/config.py +++ b/src/pqnstack/app/core/config.py @@ -1,3 +1,4 @@ +import asyncio import logging from functools import lru_cache @@ -33,6 +34,7 @@ class QKDSettings(BaseModel): class Settings(BaseSettings): + node_name: str = "node1" router_name: str = "router1" router_address: str = "localhost" router_port: int = 5555 @@ -74,9 +76,32 @@ def get_settings() -> Settings: class NodeState(BaseModel): + # Coordination state + # FIXME: Make sure we are checking for the client_listening_for_follower_requests state everywhere. + client_listening_for_follower_requests: bool = False + + # Leader's state + leading: bool = False + followers_address: str = "" + + # Follower's state + following: bool = False + # Other node requested this node to follow it. + following_requested: bool = False + # User's response to the follow request. None if no response yet, True if accepted, False if rejected. + following_requested_user_response: bool | None = None + # The address of the leader this node is following. None if not following anyone. + leaders_address: str = "" + leaders_name: str = "" + + # CHSH state chsh_request_basis: list[float] = [22.5, 67.5] - # FIXME: Use enums for this - qkd_basis_list: list[QKDEncodingBasis] = [ + + # QKD state + # FIXME: At the moment the reset_coordination_state resets this, probably want to refactor that function out. + qkd_question_order: list[int] = [] # Order of questions for QKD + qkd_emoji_pick: str = "" # Emoji chosen for QKD + qkd_leader_basis_list: list[QKDEncodingBasis] = [ QKDEncodingBasis.DA, QKDEncodingBasis.DA, QKDEncodingBasis.DA, @@ -89,13 +114,19 @@ class NodeState(BaseModel): QKDEncodingBasis.HV, QKDEncodingBasis.HV, ] + qkd_follower_basis_list: list[QKDEncodingBasis] = [] + qkd_single_bit_current_index: int = 0 # Current index in follower basis list for single_bit endpoint qkd_bit_list: list[int] = [] qkd_resulting_bit_list: list[int] = [] # Resulting bits after QKD qkd_request_basis_list: list[QKDEncodingBasis] = [] # Basis angles for QKD qkd_request_bit_list: list[int] = [] + qkd_n_matching_bits: int = -1 # Leaders populate this value after qkd is done. Same with the emoji state = NodeState() +ask_user_for_follow_event = asyncio.Event() +user_replied_event = asyncio.Event() +qkd_result_received_event = asyncio.Event() def get_state() -> NodeState: From fa7d45cf922ac86ddde690137446dad00fdf1e4f Mon Sep 17 00:00:00 2001 From: marcosf2 Date: Thu, 22 Jan 2026 22:08:58 -0600 Subject: [PATCH 2/3] Makes following and leading into an enum field. --- src/pqnstack/app/api/routes/coordination.py | 13 ++++++------- src/pqnstack/app/core/config.py | 16 +++++++++------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/pqnstack/app/api/routes/coordination.py b/src/pqnstack/app/api/routes/coordination.py index 30891845..57257fea 100644 --- a/src/pqnstack/app/api/routes/coordination.py +++ b/src/pqnstack/app/api/routes/coordination.py @@ -11,7 +11,7 @@ from pqnstack.app.api.deps import ClientDep from pqnstack.app.api.deps import StateDep -from pqnstack.app.core.config import ask_user_for_follow_event +from pqnstack.app.core.config import ask_user_for_follow_event, NodeRole from pqnstack.app.core.config import settings from pqnstack.app.core.config import user_replied_event @@ -38,9 +38,8 @@ class ResetCoordinationStateResponse(BaseModel): @router.post("/reset_coordination_state") async def reset_coordination_state(state: StateDep) -> ResetCoordinationStateResponse: """Reset the coordination state of the node.""" - state.leading = False + state.role = NodeRole.INDEPENDENT state.followers_address = "" - state.following = False state.following_requested = False state.following_requested_user_response = None state.leaders_address = "" @@ -82,7 +81,7 @@ async def collect_follower( response_data = ret.json() if response_data.get("accepted") is True: - state.leading = True + state.role = NodeRole.LEADER state.followers_address = address logger.info("Successfully collected follower") return CollectFollowerResponse(accepted=True) @@ -113,13 +112,13 @@ async def follow_requested( leaders_address = f"{request.client.host}:{leaders_port}" # Check if the client is ready to accept a follower request and that node is not already following someone. - if not state.client_listening_for_follower_requests or state.following: + if not state.client_listening_for_follower_requests or state.role != NodeRole.INDEPENDENT: logger.info( "Request rejected because %s", ( "client is not listening for requests" if not state.client_listening_for_follower_requests - else "this node is already following someone" + else f"this node is already a {state.role}" ), ) return FollowRequestResponse(accepted=False) @@ -135,7 +134,7 @@ async def follow_requested( user_replied_event.clear() # Reset the event for the next change if state.following_requested_user_response: logger.debug("Follow request from %s accepted.", leaders_address) - state.following = True + state.role = NodeRole.FOLLOWER state.leaders_name = leaders_name state.leaders_address = leaders_address return FollowRequestResponse(accepted=True) diff --git a/src/pqnstack/app/core/config.py b/src/pqnstack/app/core/config.py index 6a1b5368..cea80d06 100644 --- a/src/pqnstack/app/core/config.py +++ b/src/pqnstack/app/core/config.py @@ -1,5 +1,6 @@ import asyncio import logging +from enum import auto, Enum from functools import lru_cache from pydantic import BaseModel @@ -45,8 +46,6 @@ class Settings(BaseSettings): rotary_encoder_address: str = "/dev/ttyACM0" virtual_rotator: bool = False # If True, use terminal input instead of hardware rotary encoder - rotary_encoder: RotaryEncoderInstrument | None = None - model_config = SettingsConfigDict(toml_file="./config.toml", env_file=".env", env_file_encoding="utf-8") @classmethod @@ -74,18 +73,21 @@ def get_settings() -> Settings: settings = get_settings() +class NodeRole(Enum): + """Enum indicating the role of this Node. Enum values are strings to see the role explicitly in logging instead of seeing numeric values.""" + INDEPENDENT = "independent" + LEADER = "leader" + FOLLOWER = "follower" class NodeState(BaseModel): # Coordination state # FIXME: Make sure we are checking for the client_listening_for_follower_requests state everywhere. client_listening_for_follower_requests: bool = False - # Leader's state - leading: bool = False + # Current role of this node. + role: NodeRole = NodeRole.INDEPENDENT + # Address of the Node following this node followers_address: str = "" - - # Follower's state - following: bool = False # Other node requested this node to follow it. following_requested: bool = False # User's response to the follow request. None if no response yet, True if accepted, False if rejected. From 5887d754526b1979f3cff91e1a803c52b8ca2a89 Mon Sep 17 00:00:00 2001 From: marcosf2 Date: Thu, 22 Jan 2026 22:35:02 -0600 Subject: [PATCH 3/3] Fixes collect_follower flow --- src/pqnstack/app/api/routes/coordination.py | 15 ++++++--------- src/pqnstack/app/core/config.py | 6 ++++-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/pqnstack/app/api/routes/coordination.py b/src/pqnstack/app/api/routes/coordination.py index 57257fea..4c368807 100644 --- a/src/pqnstack/app/api/routes/coordination.py +++ b/src/pqnstack/app/api/routes/coordination.py @@ -11,7 +11,8 @@ from pqnstack.app.api.deps import ClientDep from pqnstack.app.api.deps import StateDep -from pqnstack.app.core.config import ask_user_for_follow_event, NodeRole +from pqnstack.app.core.config import NodeRole +from pqnstack.app.core.config import ask_user_for_follow_event from pqnstack.app.core.config import settings from pqnstack.app.core.config import user_replied_event @@ -79,19 +80,15 @@ async def collect_follower( if ret.status_code != status.HTTP_200_OK: raise HTTPException(status_code=ret.status_code, detail=ret.text) - response_data = ret.json() - if response_data.get("accepted") is True: + response_data = FollowRequestResponse(**ret.json()) + if response_data.accepted: state.role = NodeRole.LEADER state.followers_address = address logger.info("Successfully collected follower") return CollectFollowerResponse(accepted=True) - if response_data.get("accepted") is False: - logger.info("Follower rejected follow request") - return CollectFollowerResponse(accepted=False) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Could not collect follower for unknown reasons" - ) + logger.info("Follower rejected follow request") + return CollectFollowerResponse(accepted=False) @router.post("/follow_requested") diff --git a/src/pqnstack/app/core/config.py b/src/pqnstack/app/core/config.py index cea80d06..88c04497 100644 --- a/src/pqnstack/app/core/config.py +++ b/src/pqnstack/app/core/config.py @@ -1,6 +1,6 @@ import asyncio import logging -from enum import auto, Enum +from enum import Enum from functools import lru_cache from pydantic import BaseModel @@ -12,7 +12,6 @@ from pqnstack.constants import BellState from pqnstack.constants import QKDEncodingBasis -from pqnstack.pqn.drivers.rotaryencoder import RotaryEncoderInstrument from pqnstack.pqn.protocols.measurement import MeasurementConfig logger = logging.getLogger(__name__) @@ -73,12 +72,15 @@ def get_settings() -> Settings: settings = get_settings() + class NodeRole(Enum): """Enum indicating the role of this Node. Enum values are strings to see the role explicitly in logging instead of seeing numeric values.""" + INDEPENDENT = "independent" LEADER = "leader" FOLLOWER = "follower" + class NodeState(BaseModel): # Coordination state # FIXME: Make sure we are checking for the client_listening_for_follower_requests state everywhere.