Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions configs/config_app_example.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# MAKE SURE TO RENAME THIS FILE TO config.toml AND PLACE IT IN THE ROOT OF THE PROJECT

node_name = "example_node"

# Router configuration
router_name = "router1"
router_address = "xx.xx.xx.xx" # Replace with actual IP address
Expand Down
217 changes: 217 additions & 0 deletions scripts/coordination_websocket_client
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
#!/usr/bin/env python

"""
WebSocket client for testing the coordination follow_requested_alerts endpoint.
Acts as a proxy for the frontend, automatically responding to follow requests.
"""

import argparse
import asyncio
import logging
import signal
import sys

import websockets
from websockets.exceptions import ConnectionClosedError
from websockets.exceptions import ConnectionClosedOK

# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Default configuration
DEFAULT_HOST = "localhost"
DEFAULT_PORT = 8000
DEFAULT_RESPONSE = "yes" # Default response to follow requests


class CoordinationWebSocketClient:
"""WebSocket client for coordination follow requests."""

def __init__(
self,
host: str = DEFAULT_HOST,
port: int = DEFAULT_PORT,
auto_response: str = DEFAULT_RESPONSE,
interactive: bool = False,
) -> None:
"""
Initialize the WebSocket client.

Parameters
----------
host : str
The host address of the API server.
port : int
The port number of the API server.
auto_response : str
Automatic response to follow requests ('yes', 'no', 'true', 'false').
interactive : bool
If True, prompt user for each response. If False, use auto_response.
"""
self.host = host
self.port = port
self.auto_response = auto_response.lower()
self.interactive = interactive
self.uri = f"ws://{host}:{port}/coordination/follow_requested_alerts"
self.websocket = None
self.running = False

async def connect(self) -> None:
"""Establish connection to the WebSocket endpoint."""
try:
logger.info("Connecting to %s", self.uri)
self.websocket = await websockets.connect(self.uri)
logger.info("Successfully connected to coordination WebSocket")
self.running = True
except Exception as e:
logger.exception("Failed to connect to WebSocket: %s", e)
raise

async def handle_messages(self) -> None:
"""Listen for messages from the server and respond accordingly."""
if not self.websocket:
logger.error("WebSocket not connected")
return

try:
async for message in self.websocket:
logger.info("Received message: %s", message)

# Determine response
if self.interactive:
response = await self._get_user_input(message)
else:
response = self.auto_response
logger.info("Auto-responding with: %s", response)

# Send response
await self.websocket.send(response)
logger.info("Sent response: %s", response)

except ConnectionClosedOK:
logger.info("WebSocket connection closed normally")
except ConnectionClosedError as e:
logger.error("WebSocket connection closed with error: %s", e)
except Exception:
logger.exception("Error while handling messages")
finally:
self.running = False

async def _get_user_input(self, prompt: str) -> str:
"""
Get user input interactively.

Parameters
----------
prompt : str
The prompt message to display to the user.

Returns
-------
str
User's response.
"""
loop = asyncio.get_event_loop()
print(f"\n{prompt}")
print("Enter your response (yes/no): ", end="", flush=True)

# Run blocking input() in executor to not block event loop
response = await loop.run_in_executor(None, sys.stdin.readline)
return response.strip().lower()

async def close(self) -> None:
"""Close the WebSocket connection."""
if self.websocket:
await self.websocket.close()
logger.info("WebSocket connection closed")
self.running = False

async def run(self) -> None:
"""Main run loop for the client."""
try:
await self.connect()
await self.handle_messages()
except KeyboardInterrupt:
logger.info("Received keyboard interrupt, shutting down...")
except Exception:
logger.exception("Unexpected error in run loop")
finally:
await self.close()


async def main() -> None:
"""Main entry point for the script."""
parser = argparse.ArgumentParser(
description="WebSocket client for coordination follow_requested_alerts endpoint"
)
parser.add_argument(
"--host",
type=str,
default=DEFAULT_HOST,
help=f"API server host (default: {DEFAULT_HOST})",
)
parser.add_argument(
"--port",
type=int,
default=DEFAULT_PORT,
help=f"API server port (default: {DEFAULT_PORT})",
)
parser.add_argument(
"--response",
type=str,
default=DEFAULT_RESPONSE,
choices=["yes", "no", "true", "false", "y", "n"],
help=f"Automatic response to follow requests (default: {DEFAULT_RESPONSE})",
)
parser.add_argument(
"--interactive",
action="store_true",
help="Prompt for user input for each follow request instead of auto-responding",
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug logging",
)

args = parser.parse_args()

# Set logging level
if args.debug:
logging.getLogger().setLevel(logging.DEBUG)

# Create and run client
client = CoordinationWebSocketClient(
host=args.host,
port=args.port,
auto_response=args.response,
interactive=args.interactive,
)

# Handle graceful shutdown
loop = asyncio.get_event_loop()

def signal_handler() -> None:
logger.info("Received shutdown signal")
if client.running:
asyncio.create_task(client.close())

for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, signal_handler)

# Run the client
await client.run()


if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Script terminated by user")
except Exception:
logger.exception("Fatal error")
sys.exit(1)
24 changes: 24 additions & 0 deletions src/pqnstack/app/api/deps.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from collections.abc import AsyncGenerator
from functools import lru_cache
from typing import Annotated

import httpx
from fastapi import Depends

from pqnstack.app.core.config import NodeState
from pqnstack.app.core.config import get_state
from pqnstack.app.core.config import logger
from pqnstack.app.core.config import settings
from pqnstack.network.client import Client
from pqnstack.pqn.drivers.rotaryencoder import MockRotaryEncoder
from pqnstack.pqn.drivers.rotaryencoder import RotaryEncoderInstrument
from pqnstack.pqn.drivers.rotaryencoder import SerialRotaryEncoder


async def get_http_client() -> AsyncGenerator[httpx.AsyncClient, None]:
Expand All @@ -25,4 +30,23 @@ async def get_instrument_client() -> AsyncGenerator[Client, None]:

InstrumentClientDep = Annotated[httpx.AsyncClient, Depends(get_instrument_client)]


StateDep = Annotated[NodeState, Depends(get_state)]


@lru_cache
def get_rotary_encoder() -> RotaryEncoderInstrument:
if settings.virtual_rotator:
# Virtual rotator mode enabled, use mock with terminal input
logger.info("Virtual rotator mode enabled")
rotary_encoder: RotaryEncoderInstrument = MockRotaryEncoder()
else:
# Use the real serial rotary encoder
rotary_encoder = SerialRotaryEncoder(
label="rotary_encoder", address=settings.rotary_encoder_address, offset_degrees=0.0
)

return rotary_encoder


SERDep = Annotated[RotaryEncoderInstrument, Depends(get_rotary_encoder)]
4 changes: 4 additions & 0 deletions src/pqnstack/app/api/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from fastapi import APIRouter

from pqnstack.app.api.routes import chsh
from pqnstack.app.api.routes import coordination
from pqnstack.app.api.routes import debug
from pqnstack.app.api.routes import qkd
from pqnstack.app.api.routes import rng
from pqnstack.app.api.routes import serial
Expand All @@ -12,3 +14,5 @@
api_router.include_router(timetagger.router)
api_router.include_router(rng.router)
api_router.include_router(serial.router)
api_router.include_router(coordination.router)
api_router.include_router(debug.router)
Loading