Skip to content
This repository was archived by the owner on Jan 2, 2026. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.9 AS development
FROM python:3.13 AS development

WORKDIR /app
COPY requirements-server.txt .
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile.manager
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.9
FROM python:3.13

WORKDIR /app
COPY requirements-rabbit.txt .
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ types-requests
# misc
pre-commit
codespell
django-silk
5 changes: 2 additions & 3 deletions requirements-rabbit.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
Django
Django==5.1
pika
django-cors-headers
djangorestframework
djangorestframework-simplejwt
djangorestframework-api-key==2.*
psycopg2
django-silk
psycopg[binary, pool]
5 changes: 2 additions & 3 deletions requirements-server.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
asgiref
Django
Django==5.1
django-cors-headers
djangorestframework
djangorestframework-simplejwt
PyJWT
pytz
sqlparse
psycopg2-binary
psycopg[binary, pool]
djangorestframework-api-key==2.*
gunicorn
numpy
Expand All @@ -20,4 +20,3 @@ redis
PyJWT==1.7.1
boto3
pika
django-silk
43 changes: 38 additions & 5 deletions server/mq/core/connection_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import logging
import os
import threading
import urllib

import pika
Expand All @@ -11,6 +12,7 @@

class ConnectionManager:
instance = None
_instance_lock = threading.RLock()

def __init__(self):
if not ConnectionManager.instance._initialized:
Expand All @@ -20,6 +22,7 @@ def __init__(self):
self._connected = False
self._loop = None
self._url = self._build_amqp_url()
self._connection_thread_id = None
self._initialized = True

async def connect(self, loop=None):
Expand All @@ -42,6 +45,9 @@ async def connect(self, loop=None):
custom_ioloop=loop,
)

self._loop = loop or asyncio.get_event_loop()
self._connection_thread_id = threading.get_ident()

await self._ready.wait()

self._connection = future_connection
Expand Down Expand Up @@ -84,14 +90,41 @@ async def close(self):
if self._connection and not (
self._connection.is_closing or self._connection.is_closed
):
self._connection.close()
current_thread_id = threading.get_ident()
if current_thread_id == self._connection_thread_id:
self._connection.close()
else:
# in a different thread, need to use add_callback_threadsafe
close_event = asyncio.Event()

def threadsafe_close():
try:
self._connection.close()
except Exception as e:
logger.error(f"Error closing connection: {str(e)}")
finally:
# signal completion
asyncio.run_coroutine_threadsafe(
self._set_event(close_event), self._loop
)

try:
self._connection.add_callback_threadsafe(threadsafe_close)
await close_event.wait()
except Exception as e:
logger.error(f"Failed to schedule connection close: {str(e)}")

self._connected = False

async def _set_event(self, event):
event.set()

def is_connected(self):
return self._connected

def __new__(cls):
if cls.instance is None:
cls.instance = super(ConnectionManager, cls).__new__(cls)
cls.instance._initialized = False
return cls.instance
with cls._instance_lock:
if cls.instance is None:
cls.instance = super(ConnectionManager, cls).__new__(cls)
cls.instance._initialized = False
return cls.instance
56 changes: 43 additions & 13 deletions server/mq/core/consumer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import functools
import logging
import threading
from typing import Any, Callable, Coroutine

from pika.exchange_type import ExchangeType
Expand Down Expand Up @@ -37,20 +38,24 @@ def __init__(
self._channel = None
self._closing = False
self._consumer_tag = None
self._connection_thread_id = None
self._lock = threading.RLock()

async def connect(self, loop=None):
LOGGER.info(
f"Connecting to {self._url} for exchange {self._exchange}, queue {self._queue}"
)
async with self._lock:
LOGGER.info(
f"Connecting to {self._url} for exchange {self._exchange}, queue {self._queue}"
)

try:
self._connection = await ConnectionManager().connect(loop=loop)
except Exception as e:
LOGGER.error(f"Failed to create connection for {self._queue}: {str(e)}")
self._connection = None
raise
try:
self._connection = await ConnectionManager().connect(loop=loop)
self._connection_thread_id = threading.get_ident()
except Exception as e:
LOGGER.error(f"Failed to create connection for {self._queue}: {str(e)}")
self._connection = None
raise

self.open_channel()
self.open_channel()

def open_channel(self):
LOGGER.info(f"Creating a new channel for {self._queue}")
Expand Down Expand Up @@ -153,8 +158,22 @@ def on_message(self, _unused_channel, basic_deliver, properties, body):

def stop_consuming(self):
if self._channel:
LOGGER.info(f"Stopping consumption for {self._queue}")
self._channel.basic_cancel(self._consumer_tag, self.on_cancelok)
current_thread_id = threading.get_ident()
if current_thread_id == self._connection_thread_id:
LOGGER.info(f"Stopping consumption for {self._queue}")
self._channel.basic_cancel(self._consumer_tag, self.on_cancelok)
else:
# in a different thread, need to use add_callback_threadsafe
try:
self._connection.add_callback_threadsafe(
functools.partial(
self._channel.basic_cancel,
self._consumer_tag,
self.on_cancelok,
)
)
except Exception as e:
LOGGER.error(f"Failed to cancel consumer: {str(e)}")

def on_cancelok(self, _unused_frame):
"""consumption is cancelled"""
Expand All @@ -164,7 +183,18 @@ def on_cancelok(self, _unused_frame):
def close_channel(self):
LOGGER.info(f"Closing the channel for {self._queue}")
if self._channel:
self._channel.close()
current_thread_id = threading.get_ident()
if current_thread_id == self._connection_thread_id:
self._channel.close()
else:
# in a different thread, need to use add_callback_threadsafe
try:
self._connection.add_callback_threadsafe(
functools.partial(self._channel.close)
)
except Exception as e:
LOGGER.error(f"Failed to close channel: {str(e)}")

else:
LOGGER.warning(f"Channel is already closed for {self._queue}")

Expand Down
91 changes: 72 additions & 19 deletions server/mq/core/producer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import threading

from .connection_manager import ConnectionManager

Expand All @@ -22,28 +23,37 @@ def __init__(self, amqp_url, exchange, exchange_type, routing_key=None):
self._channel = None
self._connected = False
self._ready = asyncio.Event()
self._connection_thread_id = None
self._lock = threading.RLock()

async def connect(self, loop=None):
async with self._lock:
if self._connected:
return self._connection

if self._connected:
return self._connection
self._ready.clear()

self._ready.clear()
self._connection = await ConnectionManager().connect(
loop=loop or asyncio.get_event_loop()
)

self._connection = await ConnectionManager().connect(
loop=loop or asyncio.get_event_loop()
)
LOGGER.info(
f"Producer connecting to {self._url} for exchange {self._exchange}"
)
if not self._connection:
LOGGER.error(
f"Failed to create connection for producer {self._exchange}"
)
return False

LOGGER.info(f"Producer connecting to {self._url} for exchange {self._exchange}")
if not self._connection:
LOGGER.error(f"Failed to create connection for producer {self._exchange}")
return False
self._connected = True
self.open_channel()
self._connected = True
self.open_channel()

self._ready.set()
self._connection_thread_id = threading.get_ident()

return self._connection
await self._ready.wait()

return self._connection

def open_channel(self):
LOGGER.info(f"Creating a new channel for producer {self._exchange}")
Expand Down Expand Up @@ -116,27 +126,70 @@ async def publish(
# convert to bytes
message = message.encode("utf-8")

current_thread_id = threading.get_ident()

if current_thread_id == self._connection_thread_id:
return self._do_publish(message, actual_routing_key, properties, mandatory)
else:
# in a different thread, need to use add_callback_threadsafe
publish_event = asyncio.Event()
result = [False]

def threadsafe_publish():
try:
success = self._channel.basic_publish(
exchange=self._exchange,
routing_key=actual_routing_key,
body=message,
properties=properties,
mandatory=mandatory,
)
result[0] = success
LOGGER.debug(
f"Published message to {self._exchange} with routing key {actual_routing_key}"
)
except Exception as e:
LOGGER.error(f"Failed to publish message: {str(e)}")
result[0] = False
# mark as disconnected so the health monitor will reconnect it
self._connected = False

# signal completion
asyncio.run_coroutine_threadsafe(
self._set_event(publish_event), self._connection._loop
)

try:
self._connection.add_callback_threadsafe(threadsafe_publish)
await publish_event.wait()
return result[0]
except Exception as e:
LOGGER.error(f"Failed to schedule publish: {str(e)}")
self._connected = False
return False

def _do_publish(self, message, routing_key, properties, mandatory):
try:
self._channel.basic_publish(
exchange=self._exchange,
routing_key=actual_routing_key,
routing_key=routing_key,
body=message,
properties=properties,
mandatory=mandatory,
)
LOGGER.debug(
f"Published message to {self._exchange} with routing key {actual_routing_key}"
f"Published message to {self._exchange} with routing key {routing_key}"
)
return True
except Exception as e:
LOGGER.error(f"Failed to publish message: {str(e)}")
# mark as disconnected so the health monitor will reconnect it
# debatable whether this is the right approach, will have
# to see how common failures are and whether they are naturally
# recovered
self._connected = False
return False

async def _set_event(self, event):
event.set()

async def close(self):
LOGGER.info(f"Closing producer for {self._exchange}")
if self._channel and self._channel.is_open:
Expand Down
Loading