Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/test_async_simple_storage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ async def test_async_storage_manager_mapping_functions():

# Mock handshake response
handshake_response = ZMQMessage.create(
request_type=ZMQRequestType.HANDSHAKE_ACK,
request_type=ZMQRequestType.HANDSHAKE_ACK, # type: ignore[arg-type]
sender_id="controller_0",
body={"message": "Handshake successful"},
)
Expand Down Expand Up @@ -295,7 +295,7 @@ async def test_async_storage_manager_error_handling():

# Mock handshake response
handshake_response = ZMQMessage.create(
request_type=ZMQRequestType.HANDSHAKE_ACK,
request_type=ZMQRequestType.HANDSHAKE_ACK, # type: ignore[arg-type]
sender_id="controller_0",
body={"message": "Handshake successful"},
)
Expand Down
5 changes: 3 additions & 2 deletions transfer_queue/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
ZMQRequestType,
ZMQServerInfo,
create_zmq_socket,
format_zmq_address,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -124,9 +125,9 @@ async def wrapper(self, *args, **kwargs):
raise RuntimeError("No controller registered")

context = zmq.asyncio.Context()
address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}"
address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name))
identity = f"{self.client_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode()
sock = create_zmq_socket(context, zmq.DEALER, identity=identity)
sock = create_zmq_socket(context, zmq.DEALER, identity=identity, ip=server_info.ip)

try:
sock.connect(address)
Expand Down
22 changes: 14 additions & 8 deletions transfer_queue/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import torch
import zmq
from omegaconf import DictConfig
from ray.util import get_node_ip_address
from torch import Tensor

from transfer_queue.metadata import (
Expand All @@ -45,7 +44,9 @@
ZMQRequestType,
ZMQServerInfo,
create_zmq_socket,
format_zmq_address,
get_free_port,
get_node_ip_address_raw,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -1574,31 +1575,36 @@ def kv_retrieve_keys(
def _init_zmq_socket(self):
"""Initialize ZMQ sockets for communication."""
self.zmq_context = zmq.Context()
self._node_ip = get_node_ip_address()
self._node_ip = get_node_ip_address_raw()

while True:
try:
self._handshake_socket_port = get_free_port()
self._request_handle_socket_port = get_free_port()
self._data_status_update_socket_port = get_free_port()
self._handshake_socket_port = get_free_port(ip=self._node_ip)
self._request_handle_socket_port = get_free_port(ip=self._node_ip)
self._data_status_update_socket_port = get_free_port(ip=self._node_ip)

self.handshake_socket = create_zmq_socket(
ctx=self.zmq_context,
socket_type=zmq.ROUTER,
ip=self._node_ip,
)
self.handshake_socket.bind(f"tcp://{self._node_ip}:{self._handshake_socket_port}")
self.handshake_socket.bind(format_zmq_address(self._node_ip, self._handshake_socket_port))

self.request_handle_socket = create_zmq_socket(
ctx=self.zmq_context,
socket_type=zmq.ROUTER,
ip=self._node_ip,
)
self.request_handle_socket.bind(f"tcp://{self._node_ip}:{self._request_handle_socket_port}")
self.request_handle_socket.bind(format_zmq_address(self._node_ip, self._request_handle_socket_port))

self.data_status_update_socket = create_zmq_socket(
ctx=self.zmq_context,
socket_type=zmq.ROUTER,
ip=self._node_ip,
)
self.data_status_update_socket.bind(
format_zmq_address(self._node_ip, self._data_status_update_socket_port)
)
self.data_status_update_socket.bind(f"tcp://{self._node_ip}:{self._data_status_update_socket_port}")

break
except zmq.ZMQError:
Expand Down
13 changes: 7 additions & 6 deletions transfer_queue/storage/managers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ def _connect_to_controller(self) -> None:

# create zmq socket for handshake (sync, for initial connection)
self.controller_handshake_socket = create_zmq_socket(
sync_zmq_context,
zmq.DEALER,
ctx=sync_zmq_context,
socket_type=zmq.DEALER,
ip=self.controller_info.ip,
identity=f"{self.storage_manager_id}-controller_handshake_socket-{uuid4().hex[:8]}".encode(),
)

Expand Down Expand Up @@ -219,13 +220,13 @@ async def notify_data_update(

# create dynamic socket
identity = f"{self.storage_manager_id}-data_update-{uuid4().hex[:8]}".encode()
sock = create_zmq_socket(self.zmq_context, zmq.DEALER, identity=identity)
sock = create_zmq_socket(self.zmq_context, zmq.DEALER, self.controller_info.ip, identity)

try:
sock.connect(self.controller_info.to_addr("data_status_update_socket"))

request_msg = ZMQMessage.create(
request_type=ZMQRequestType.NOTIFY_DATA_UPDATE,
request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, # type: ignore[arg-type]
sender_id=self.storage_manager_id,
body={
"partition_id": partition_id,
Expand Down Expand Up @@ -253,7 +254,7 @@ async def notify_data_update(
messages = await asyncio.wait_for(sock.recv_multipart(), timeout=poll_interval)
response_msg = ZMQMessage.deserialize(messages)

if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK:
if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK: # type: ignore[arg-type]
response_received = True
logger.debug(
f"[{self.storage_manager_id}]: Get data status update ACK response "
Expand All @@ -272,7 +273,7 @@ async def notify_data_update(
logger.error(f"[{self.storage_manager_id}]: Error during notify_data_update: {e}")
try:
error_msg = ZMQMessage.create(
request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR,
request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, # type: ignore[arg-type]
sender_id=self.storage_manager_id,
body={"message": f"Failed to notify: {str(e)}"},
).serialize()
Expand Down
12 changes: 9 additions & 3 deletions transfer_queue/storage/managers/simple_backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@
from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory
from transfer_queue.storage.simple_backend import StorageMetaGroup
from transfer_queue.utils.common import get_env_bool
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket
from transfer_queue.utils.zmq_utils import (
ZMQMessage,
ZMQRequestType,
ZMQServerInfo,
create_zmq_socket,
format_zmq_address,
)

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
Expand Down Expand Up @@ -150,9 +156,9 @@ async def wrapper(self, *args, **kwargs):
raise RuntimeError(f"Server {server_key} not found in registered servers")

context = zmq.asyncio.Context()
address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}"
address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name))
identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode()
sock = create_zmq_socket(context, zmq.DEALER, identity=identity)
sock = create_zmq_socket(context, zmq.DEALER, server_info.ip, identity)

try:
sock.connect(address)
Expand Down
49 changes: 29 additions & 20 deletions transfer_queue/storage/simple_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,20 @@

import ray
import zmq
from ray.util import get_node_ip_address

from transfer_queue.metadata import SampleMeta
from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads
from transfer_queue.utils.enum_utils import TransferQueueRole
from transfer_queue.utils.perf_utils import IntervalPerfMonitor
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket, get_free_port
from transfer_queue.utils.zmq_utils import (
ZMQMessage,
ZMQRequestType,
ZMQServerInfo,
create_zmq_socket,
format_zmq_address,
get_free_port,
get_node_ip_address_raw,
)

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
Expand Down Expand Up @@ -208,22 +215,22 @@ def _init_zmq_socket(self) -> None:
- worker_socket (DEALER): Backend socket for worker communication.
"""
self.zmq_context = zmq.Context()
self._node_ip = get_node_ip_address_raw()

# Frontend: ROUTER for receiving client requests
self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER)
self._node_ip = get_node_ip_address()
self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER, self._node_ip)

while True:
try:
self._put_get_socket_port = get_free_port()
self.put_get_socket.bind(f"tcp://{self._node_ip}:{self._put_get_socket_port}")
self._put_get_socket_port = get_free_port(ip=self._node_ip)
self.put_get_socket.bind(format_zmq_address(self._node_ip, self._put_get_socket_port))
break
except zmq.ZMQError:
logger.warning(f"[{self.storage_unit_id}]: Try to bind ZMQ sockets failed, retrying...")
continue

# Backend: DEALER for worker communication (connected via zmq.proxy)
self.worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER)
self.worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER, self._node_ip)
self.worker_socket.bind(self._inproc_addr)

self.zmq_server_info = ZMQServerInfo(
Expand Down Expand Up @@ -269,8 +276,8 @@ def _proxy_routine(self) -> None:

def _worker_routine(self) -> None:
"""Worker thread for processing requests."""
# Each worker must have its own socket
worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER)

worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER, self._node_ip)
worker_socket.connect(self._inproc_addr)

poller = zmq.Poller()
Expand Down Expand Up @@ -306,18 +313,18 @@ def _worker_routine(self) -> None:
logger.debug(f"[{self.storage_unit_id}]: worker received operation: {operation}")

# Process request
if operation == ZMQRequestType.PUT_DATA:
if operation == ZMQRequestType.PUT_DATA: # type: ignore[arg-type]
with perf_monitor.measure(op_type="PUT_DATA"):
response_msg = self._handle_put(request_msg)
elif operation == ZMQRequestType.GET_DATA:
elif operation == ZMQRequestType.GET_DATA: # type: ignore[arg-type]
with perf_monitor.measure(op_type="GET_DATA"):
response_msg = self._handle_get(request_msg)
elif operation == ZMQRequestType.CLEAR_DATA:
elif operation == ZMQRequestType.CLEAR_DATA: # type: ignore[arg-type]
with perf_monitor.measure(op_type="CLEAR_DATA"):
response_msg = self._handle_clear(request_msg)
else:
response_msg = ZMQMessage.create(
request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR,
request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR, # type: ignore[arg-type]
sender_id=self.storage_unit_id,
body={
"message": f"Storage unit id #{self.storage_unit_id} "
Expand All @@ -326,7 +333,7 @@ def _worker_routine(self) -> None:
)
except Exception as e:
response_msg = ZMQMessage.create(
request_type=ZMQRequestType.PUT_GET_ERROR,
request_type=ZMQRequestType.PUT_GET_ERROR, # type: ignore[arg-type]
sender_id=self.storage_unit_id,
body={
"message": f"{self.storage_unit_id}, worker encountered error "
Expand Down Expand Up @@ -361,13 +368,15 @@ def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage:

# After put operation finish, send a message to the client
response_msg = ZMQMessage.create(
request_type=ZMQRequestType.PUT_DATA_RESPONSE, sender_id=self.storage_unit_id, body={}
request_type=ZMQRequestType.PUT_DATA_RESPONSE, # type: ignore[arg-type]
sender_id=self.storage_unit_id,
body={}, # type: ignore[arg-type]
)

return response_msg
except Exception as e:
return ZMQMessage.create(
request_type=ZMQRequestType.PUT_ERROR,
request_type=ZMQRequestType.PUT_ERROR, # type: ignore[arg-type]
sender_id=self.storage_unit_id,
body={
"message": f"Failed to put data into storage unit id "
Expand Down Expand Up @@ -395,15 +404,15 @@ def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage:
result_data = self.storage_data.get_data(fields, local_indexes)

response_msg = ZMQMessage.create(
request_type=ZMQRequestType.GET_DATA_RESPONSE,
request_type=ZMQRequestType.GET_DATA_RESPONSE, # type: ignore[arg-type]
sender_id=self.storage_unit_id,
body={
"data": result_data,
},
)
except Exception as e:
response_msg = ZMQMessage.create(
request_type=ZMQRequestType.GET_ERROR,
request_type=ZMQRequestType.GET_ERROR, # type: ignore[arg-type]
sender_id=self.storage_unit_id,
body={
"message": f"Failed to get data from storage unit id #{self.storage_unit_id}, "
Expand Down Expand Up @@ -431,13 +440,13 @@ def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage:
self.storage_data.clear(local_indexes)

response_msg = ZMQMessage.create(
request_type=ZMQRequestType.CLEAR_DATA_RESPONSE,
request_type=ZMQRequestType.CLEAR_DATA_RESPONSE, # type: ignore[arg-type]
sender_id=self.storage_unit_id,
body={"message": f"Clear data in storage unit id #{self.storage_unit_id} successfully."},
)
except Exception as e:
response_msg = ZMQMessage.create(
request_type=ZMQRequestType.CLEAR_DATA_ERROR,
request_type=ZMQRequestType.CLEAR_DATA_ERROR, # type: ignore[arg-type]
sender_id=self.storage_unit_id,
body={
"message": f"Failed to clear data in storage unit id #{self.storage_unit_id}, "
Expand Down
Loading