diff --git a/Dockerfile b/Dockerfile index 785adcb..21d66ed 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.9 AS development +FROM python:3.13 AS development WORKDIR /app COPY requirements-server.txt . diff --git a/Dockerfile.manager b/Dockerfile.manager index 8566351..b9e95c2 100644 --- a/Dockerfile.manager +++ b/Dockerfile.manager @@ -1,4 +1,4 @@ -FROM python:3.9 +FROM python:3.13 WORKDIR /app COPY requirements-rabbit.txt . diff --git a/requirements-dev.txt b/requirements-dev.txt index b622902..087cff3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -18,3 +18,4 @@ types-requests # misc pre-commit codespell +django-silk diff --git a/requirements-rabbit.txt b/requirements-rabbit.txt index 6e6bcca..92f4497 100644 --- a/requirements-rabbit.txt +++ b/requirements-rabbit.txt @@ -1,8 +1,7 @@ -Django +Django==5.1 pika django-cors-headers djangorestframework djangorestframework-simplejwt djangorestframework-api-key==2.* -psycopg2 -django-silk +psycopg[binary, pool] diff --git a/requirements-server.txt b/requirements-server.txt index 41c38c4..f7b6ad4 100644 --- a/requirements-server.txt +++ b/requirements-server.txt @@ -1,12 +1,12 @@ asgiref -Django +Django==5.1 django-cors-headers djangorestframework djangorestframework-simplejwt PyJWT pytz sqlparse -psycopg2-binary +psycopg[binary, pool] djangorestframework-api-key==2.* gunicorn numpy @@ -20,4 +20,3 @@ redis PyJWT==1.7.1 boto3 pika -django-silk diff --git a/server/mq/core/connection_manager.py b/server/mq/core/connection_manager.py index 8404360..e0cacb7 100644 --- a/server/mq/core/connection_manager.py +++ b/server/mq/core/connection_manager.py @@ -1,6 +1,7 @@ import asyncio import logging import os +import threading import urllib import pika @@ -11,6 +12,7 @@ class ConnectionManager: instance = None + _instance_lock = threading.RLock() def __init__(self): if not ConnectionManager.instance._initialized: @@ -20,6 +22,7 @@ def __init__(self): self._connected = False self._loop = None self._url = self._build_amqp_url() + self._connection_thread_id = None self._initialized = True async def connect(self, loop=None): @@ -42,6 +45,9 @@ async def connect(self, loop=None): custom_ioloop=loop, ) + self._loop = loop or asyncio.get_event_loop() + self._connection_thread_id = threading.get_ident() + await self._ready.wait() self._connection = future_connection @@ -84,14 +90,41 @@ async def close(self): if self._connection and not ( self._connection.is_closing or self._connection.is_closed ): - self._connection.close() + current_thread_id = threading.get_ident() + if current_thread_id == self._connection_thread_id: + self._connection.close() + else: + # in a different thread, need to use add_callback_threadsafe + close_event = asyncio.Event() + + def threadsafe_close(): + try: + self._connection.close() + except Exception as e: + logger.error(f"Error closing connection: {str(e)}") + finally: + # signal completion + asyncio.run_coroutine_threadsafe( + self._set_event(close_event), self._loop + ) + + try: + self._connection.add_callback_threadsafe(threadsafe_close) + await close_event.wait() + except Exception as e: + logger.error(f"Failed to schedule connection close: {str(e)}") + self._connected = False + async def _set_event(self, event): + event.set() + def is_connected(self): return self._connected def __new__(cls): - if cls.instance is None: - cls.instance = super(ConnectionManager, cls).__new__(cls) - cls.instance._initialized = False - return cls.instance + with cls._instance_lock: + if cls.instance is None: + cls.instance = super(ConnectionManager, cls).__new__(cls) + cls.instance._initialized = False + return cls.instance diff --git a/server/mq/core/consumer.py b/server/mq/core/consumer.py index f3b0bec..76c2e56 100644 --- a/server/mq/core/consumer.py +++ b/server/mq/core/consumer.py @@ -1,6 +1,7 @@ import asyncio import functools import logging +import threading from typing import Any, Callable, Coroutine from pika.exchange_type import ExchangeType @@ -37,20 +38,24 @@ def __init__( self._channel = None self._closing = False self._consumer_tag = None + self._connection_thread_id = None + self._lock = threading.RLock() async def connect(self, loop=None): - LOGGER.info( - f"Connecting to {self._url} for exchange {self._exchange}, queue {self._queue}" - ) + async with self._lock: + LOGGER.info( + f"Connecting to {self._url} for exchange {self._exchange}, queue {self._queue}" + ) - try: - self._connection = await ConnectionManager().connect(loop=loop) - except Exception as e: - LOGGER.error(f"Failed to create connection for {self._queue}: {str(e)}") - self._connection = None - raise + try: + self._connection = await ConnectionManager().connect(loop=loop) + self._connection_thread_id = threading.get_ident() + except Exception as e: + LOGGER.error(f"Failed to create connection for {self._queue}: {str(e)}") + self._connection = None + raise - self.open_channel() + self.open_channel() def open_channel(self): LOGGER.info(f"Creating a new channel for {self._queue}") @@ -153,8 +158,22 @@ def on_message(self, _unused_channel, basic_deliver, properties, body): def stop_consuming(self): if self._channel: - LOGGER.info(f"Stopping consumption for {self._queue}") - self._channel.basic_cancel(self._consumer_tag, self.on_cancelok) + current_thread_id = threading.get_ident() + if current_thread_id == self._connection_thread_id: + LOGGER.info(f"Stopping consumption for {self._queue}") + self._channel.basic_cancel(self._consumer_tag, self.on_cancelok) + else: + # in a different thread, need to use add_callback_threadsafe + try: + self._connection.add_callback_threadsafe( + functools.partial( + self._channel.basic_cancel, + self._consumer_tag, + self.on_cancelok, + ) + ) + except Exception as e: + LOGGER.error(f"Failed to cancel consumer: {str(e)}") def on_cancelok(self, _unused_frame): """consumption is cancelled""" @@ -164,7 +183,18 @@ def on_cancelok(self, _unused_frame): def close_channel(self): LOGGER.info(f"Closing the channel for {self._queue}") if self._channel: - self._channel.close() + current_thread_id = threading.get_ident() + if current_thread_id == self._connection_thread_id: + self._channel.close() + else: + # in a different thread, need to use add_callback_threadsafe + try: + self._connection.add_callback_threadsafe( + functools.partial(self._channel.close) + ) + except Exception as e: + LOGGER.error(f"Failed to close channel: {str(e)}") + else: LOGGER.warning(f"Channel is already closed for {self._queue}") diff --git a/server/mq/core/producer.py b/server/mq/core/producer.py index 41e30a2..c7f2c85 100644 --- a/server/mq/core/producer.py +++ b/server/mq/core/producer.py @@ -1,5 +1,6 @@ import asyncio import logging +import threading from .connection_manager import ConnectionManager @@ -22,28 +23,37 @@ def __init__(self, amqp_url, exchange, exchange_type, routing_key=None): self._channel = None self._connected = False self._ready = asyncio.Event() + self._connection_thread_id = None + self._lock = threading.RLock() async def connect(self, loop=None): + async with self._lock: + if self._connected: + return self._connection - if self._connected: - return self._connection + self._ready.clear() - self._ready.clear() + self._connection = await ConnectionManager().connect( + loop=loop or asyncio.get_event_loop() + ) - self._connection = await ConnectionManager().connect( - loop=loop or asyncio.get_event_loop() - ) + LOGGER.info( + f"Producer connecting to {self._url} for exchange {self._exchange}" + ) + if not self._connection: + LOGGER.error( + f"Failed to create connection for producer {self._exchange}" + ) + return False - LOGGER.info(f"Producer connecting to {self._url} for exchange {self._exchange}") - if not self._connection: - LOGGER.error(f"Failed to create connection for producer {self._exchange}") - return False - self._connected = True - self.open_channel() + self._connected = True + self.open_channel() - self._ready.set() + self._connection_thread_id = threading.get_ident() - return self._connection + await self._ready.wait() + + return self._connection def open_channel(self): LOGGER.info(f"Creating a new channel for producer {self._exchange}") @@ -116,27 +126,70 @@ async def publish( # convert to bytes message = message.encode("utf-8") + current_thread_id = threading.get_ident() + + if current_thread_id == self._connection_thread_id: + return self._do_publish(message, actual_routing_key, properties, mandatory) + else: + # in a different thread, need to use add_callback_threadsafe + publish_event = asyncio.Event() + result = [False] + + def threadsafe_publish(): + try: + success = self._channel.basic_publish( + exchange=self._exchange, + routing_key=actual_routing_key, + body=message, + properties=properties, + mandatory=mandatory, + ) + result[0] = success + LOGGER.debug( + f"Published message to {self._exchange} with routing key {actual_routing_key}" + ) + except Exception as e: + LOGGER.error(f"Failed to publish message: {str(e)}") + result[0] = False + # mark as disconnected so the health monitor will reconnect it + self._connected = False + + # signal completion + asyncio.run_coroutine_threadsafe( + self._set_event(publish_event), self._connection._loop + ) + + try: + self._connection.add_callback_threadsafe(threadsafe_publish) + await publish_event.wait() + return result[0] + except Exception as e: + LOGGER.error(f"Failed to schedule publish: {str(e)}") + self._connected = False + return False + + def _do_publish(self, message, routing_key, properties, mandatory): try: self._channel.basic_publish( exchange=self._exchange, - routing_key=actual_routing_key, + routing_key=routing_key, body=message, properties=properties, mandatory=mandatory, ) LOGGER.debug( - f"Published message to {self._exchange} with routing key {actual_routing_key}" + f"Published message to {self._exchange} with routing key {routing_key}" ) return True except Exception as e: LOGGER.error(f"Failed to publish message: {str(e)}") # mark as disconnected so the health monitor will reconnect it - # debatable whether this is the right approach, will have - # to see how common failures are and whether they are naturally - # recovered self._connected = False return False + async def _set_event(self, event): + event.set() + async def close(self): LOGGER.info(f"Closing producer for {self._exchange}") if self._channel and self._channel.is_open: diff --git a/server/mq/core/synchronous_producer.py b/server/mq/core/synchronous_producer.py index 347a86d..944df25 100644 --- a/server/mq/core/synchronous_producer.py +++ b/server/mq/core/synchronous_producer.py @@ -1,11 +1,17 @@ +import functools +import logging import os +import threading import pika +logger = logging.getLogger(__name__) + class SynchronousRabbitProducer: _instance = None + _lock = threading.RLock() def __init__(self): if self._initialized: @@ -17,16 +23,58 @@ def __init__(self): pika.ConnectionParameters(host=self.host) ) self._channel = self._connection.channel() + self._connection_thread_id = threading.get_ident() def __new__(cls): - if not cls._instance: - cls._instance = super(SynchronousRabbitProducer, cls).__new__(cls) - cls._instance._initialized = False - return cls._instance + with cls._lock: + if not cls._instance: + cls._instance = super(SynchronousRabbitProducer, cls).__new__(cls) + cls._instance._initialized = False + return cls._instance def publish(self, routing_key, body, exchange="swecc-server-exchange"): - self._channel.basic_publish( - exchange=exchange, - routing_key=routing_key, - body=body, - ) + current_thread_id = threading.get_ident() + + if current_thread_id == self._connection_thread_id: + self._channel.basic_publish( + exchange=exchange, + routing_key=routing_key, + body=body, + ) + else: + # in a different thread, need to use add_callback_threadsafe + try: + self._connection.add_callback_threadsafe( + functools.partial( + self._channel.basic_publish, + exchange=exchange, + routing_key=routing_key, + body=body, + ) + ) + except Exception as e: + logger.error(f"Error publishing message: {str(e)}") + if self._connection.is_closed: # attempt to reconnect + logger.info("Attempting to reconnect...") + self._reopen_connection() + # try again after reconnection + self._connection.add_callback_threadsafe( + functools.partial( + self._channel.basic_publish, + exchange=exchange, + routing_key=routing_key, + body=body, + ) + ) + + def _reopen_connection(self): + with self._lock: + if self._connection.is_closed: + try: + self._connection = pika.BlockingConnection( + pika.ConnectionParameters(host=self.host) + ) + self._channel = self._connection.channel() + self._connection_thread_id = threading.get_ident() + except Exception as e: + logger.error(f"Failed to reopen connection: {str(e)}") diff --git a/server/mq/producers.py b/server/mq/producers.py index 8cfe48c..8440bc4 100644 --- a/server/mq/producers.py +++ b/server/mq/producers.py @@ -1,6 +1,17 @@ +import logging + from mq.core.synchronous_producer import SynchronousRabbitProducer +logger = logging.getLogger(__name__) + def publish_verified_email(discord_id): producer_manager = SynchronousRabbitProducer() producer_manager.publish("server.verified-email", str(discord_id)) + + +def publish_health_check(): + producer_manager = SynchronousRabbitProducer() + producer_manager.publish("server.health-check", "health-check") + + logger.info("Health check message published to RabbitMQ") diff --git a/server/server/settings.py b/server/server/settings.py index d701533..ec3be08 100644 --- a/server/server/settings.py +++ b/server/server/settings.py @@ -81,16 +81,20 @@ "rest_framework_api_key", ] -if DJANGO_DEBUG: - print("DEBUG is enabled, adding debug apps") - INSTALLED_APPS += [ - "silk", - ] +# uncomment to enable silk in debug mode +# if DJANGO_DEBUG: +# print("DEBUG is enabled, adding debug apps") +# INSTALLED_APPS += [ +# "silk", +# ] - SILKY_PYTHON_PROFILER = True # Enable Python profiler - SILKY_ANALYZE_QUERIES = True # Analyze SQL queries - SILKY_MAX_RECORDED_REQUESTS = 10000 # Max number of recorded requests to store +# SILKY_PYTHON_PROFILER = True # Enable Python profiler +# SILKY_ANALYZE_QUERIES = True # Analyze SQL queries +# SILKY_MAX_RECORDED_REQUESTS = 10000 # Max number of recorded requests to store +# if DJANGO_DEBUG: +# print("DEBUG is enabled, adding debug middleware") +# MIDDLEWARE = ["silk.middleware.SilkyMiddleware"] + MIDDLEWARE FILE_UPLOAD_HANDLERS = [ "django.core.files.uploadhandler.MemoryFileUploadHandler", @@ -108,10 +112,6 @@ "django.middleware.clickjacking.XFrameOptionsMiddleware", ] -if DJANGO_DEBUG: - print("DEBUG is enabled, adding debug middleware") - MIDDLEWARE = ["silk.middleware.SilkyMiddleware"] + MIDDLEWARE - ROOT_URLCONF = "server.urls" TEMPLATES = [ @@ -144,6 +144,13 @@ "PASSWORD": DB_PASSWORD, "HOST": DB_HOST, "PORT": DB_PORT, + "OPTIONS": { + "pool": { + "min_size": 1, + "max_size": 5, + "timeout": 30, + } + }, } } diff --git a/server/server/urls.py b/server/server/urls.py index 81ffcad..38f6951 100644 --- a/server/server/urls.py +++ b/server/server/urls.py @@ -1,10 +1,10 @@ import logging from django.urls import include, path +from mq.producers import publish_health_check from rest_framework.decorators import api_view from rest_framework.response import Response -from .settings import DJANGO_DEBUG from .views import ManagementCommandView logger = logging.getLogger(__name__) @@ -26,14 +26,19 @@ path("resume/", include("resume_review.urls")), ] -if DJANGO_DEBUG: - logger.info("DEBUG is enabled, adding debug urls") - urlpatterns += [path("silk/", include("silk.urls", namespace="silk"))] + +# uncomment to enable silk in debug mode + +# from .settings import DJANGO_DEBUG +# if DJANGO_DEBUG: +# logger.info("DEBUG is enabled, adding debug urls") +# urlpatterns += [path("silk/", include("silk.urls", namespace="silk"))] @api_view(["GET"]) def health_check(request): logger.info("Health check") + publish_health_check() return Response({"status": "ok"})