diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index 5187e169..8254895a 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -221,7 +221,7 @@ async def test_async_storage_manager_mapping_functions(): # Mock handshake response handshake_response = ZMQMessage.create( - request_type=ZMQRequestType.HANDSHAKE_ACK, + request_type=ZMQRequestType.HANDSHAKE_ACK, # type: ignore[arg-type] sender_id="controller_0", body={"message": "Handshake successful"}, ) @@ -295,7 +295,7 @@ async def test_async_storage_manager_error_handling(): # Mock handshake response handshake_response = ZMQMessage.create( - request_type=ZMQRequestType.HANDSHAKE_ACK, + request_type=ZMQRequestType.HANDSHAKE_ACK, # type: ignore[arg-type] sender_id="controller_0", body={"message": "Handshake successful"}, ) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 05fac154..2090ad1a 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -39,6 +39,7 @@ ZMQRequestType, ZMQServerInfo, create_zmq_socket, + format_zmq_address, ) logger = logging.getLogger(__name__) @@ -124,9 +125,9 @@ async def wrapper(self, *args, **kwargs): raise RuntimeError("No controller registered") context = zmq.asyncio.Context() - address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}" + address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name)) identity = f"{self.client_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() - sock = create_zmq_socket(context, zmq.DEALER, identity=identity) + sock = create_zmq_socket(context, zmq.DEALER, identity=identity, ip=server_info.ip) try: sock.connect(address) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 15ff873c..50ff10df 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -29,7 +29,6 @@ import torch import zmq from omegaconf import DictConfig -from ray.util import get_node_ip_address from torch import Tensor from transfer_queue.metadata import ( @@ -45,7 +44,9 @@ ZMQRequestType, ZMQServerInfo, create_zmq_socket, + format_zmq_address, get_free_port, + get_node_ip_address_raw, ) logger = logging.getLogger(__name__) @@ -1574,31 +1575,36 @@ def kv_retrieve_keys( def _init_zmq_socket(self): """Initialize ZMQ sockets for communication.""" self.zmq_context = zmq.Context() - self._node_ip = get_node_ip_address() + self._node_ip = get_node_ip_address_raw() while True: try: - self._handshake_socket_port = get_free_port() - self._request_handle_socket_port = get_free_port() - self._data_status_update_socket_port = get_free_port() + self._handshake_socket_port = get_free_port(ip=self._node_ip) + self._request_handle_socket_port = get_free_port(ip=self._node_ip) + self._data_status_update_socket_port = get_free_port(ip=self._node_ip) self.handshake_socket = create_zmq_socket( ctx=self.zmq_context, socket_type=zmq.ROUTER, + ip=self._node_ip, ) - self.handshake_socket.bind(f"tcp://{self._node_ip}:{self._handshake_socket_port}") + self.handshake_socket.bind(format_zmq_address(self._node_ip, self._handshake_socket_port)) self.request_handle_socket = create_zmq_socket( ctx=self.zmq_context, socket_type=zmq.ROUTER, + ip=self._node_ip, ) - self.request_handle_socket.bind(f"tcp://{self._node_ip}:{self._request_handle_socket_port}") + self.request_handle_socket.bind(format_zmq_address(self._node_ip, self._request_handle_socket_port)) self.data_status_update_socket = create_zmq_socket( ctx=self.zmq_context, socket_type=zmq.ROUTER, + ip=self._node_ip, + ) + self.data_status_update_socket.bind( + format_zmq_address(self._node_ip, self._data_status_update_socket_port) ) - self.data_status_update_socket.bind(f"tcp://{self._node_ip}:{self._data_status_update_socket_port}") break except zmq.ZMQError: diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index cad91f97..cb714c32 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -83,8 +83,9 @@ def _connect_to_controller(self) -> None: # create zmq socket for handshake (sync, for initial connection) self.controller_handshake_socket = create_zmq_socket( - sync_zmq_context, - zmq.DEALER, + ctx=sync_zmq_context, + socket_type=zmq.DEALER, + ip=self.controller_info.ip, identity=f"{self.storage_manager_id}-controller_handshake_socket-{uuid4().hex[:8]}".encode(), ) @@ -219,13 +220,13 @@ async def notify_data_update( # create dynamic socket identity = f"{self.storage_manager_id}-data_update-{uuid4().hex[:8]}".encode() - sock = create_zmq_socket(self.zmq_context, zmq.DEALER, identity=identity) + sock = create_zmq_socket(self.zmq_context, zmq.DEALER, self.controller_info.ip, identity) try: sock.connect(self.controller_info.to_addr("data_status_update_socket")) request_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, + request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, # type: ignore[arg-type] sender_id=self.storage_manager_id, body={ "partition_id": partition_id, @@ -253,7 +254,7 @@ async def notify_data_update( messages = await asyncio.wait_for(sock.recv_multipart(), timeout=poll_interval) response_msg = ZMQMessage.deserialize(messages) - if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK: + if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK: # type: ignore[arg-type] response_received = True logger.debug( f"[{self.storage_manager_id}]: Get data status update ACK response " @@ -272,7 +273,7 @@ async def notify_data_update( logger.error(f"[{self.storage_manager_id}]: Error during notify_data_update: {e}") try: error_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, + request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, # type: ignore[arg-type] sender_id=self.storage_manager_id, body={"message": f"Failed to notify: {str(e)}"}, ).serialize() diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 745c04b8..608a0827 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -34,7 +34,13 @@ from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory from transfer_queue.storage.simple_backend import StorageMetaGroup from transfer_queue.utils.common import get_env_bool -from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket +from transfer_queue.utils.zmq_utils import ( + ZMQMessage, + ZMQRequestType, + ZMQServerInfo, + create_zmq_socket, + format_zmq_address, +) logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -150,9 +156,9 @@ async def wrapper(self, *args, **kwargs): raise RuntimeError(f"Server {server_key} not found in registered servers") context = zmq.asyncio.Context() - address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}" + address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name)) identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() - sock = create_zmq_socket(context, zmq.DEALER, identity=identity) + sock = create_zmq_socket(context, zmq.DEALER, server_info.ip, identity) try: sock.connect(address) diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_backend.py index ed12d547..d84abe3b 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_backend.py @@ -26,13 +26,20 @@ import ray import zmq -from ray.util import get_node_ip_address from transfer_queue.metadata import SampleMeta from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads from transfer_queue.utils.enum_utils import TransferQueueRole from transfer_queue.utils.perf_utils import IntervalPerfMonitor -from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket, get_free_port +from transfer_queue.utils.zmq_utils import ( + ZMQMessage, + ZMQRequestType, + ZMQServerInfo, + create_zmq_socket, + format_zmq_address, + get_free_port, + get_node_ip_address_raw, +) logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -208,22 +215,22 @@ def _init_zmq_socket(self) -> None: - worker_socket (DEALER): Backend socket for worker communication. """ self.zmq_context = zmq.Context() + self._node_ip = get_node_ip_address_raw() # Frontend: ROUTER for receiving client requests - self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER) - self._node_ip = get_node_ip_address() + self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER, self._node_ip) while True: try: - self._put_get_socket_port = get_free_port() - self.put_get_socket.bind(f"tcp://{self._node_ip}:{self._put_get_socket_port}") + self._put_get_socket_port = get_free_port(ip=self._node_ip) + self.put_get_socket.bind(format_zmq_address(self._node_ip, self._put_get_socket_port)) break except zmq.ZMQError: logger.warning(f"[{self.storage_unit_id}]: Try to bind ZMQ sockets failed, retrying...") continue # Backend: DEALER for worker communication (connected via zmq.proxy) - self.worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER) + self.worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER, self._node_ip) self.worker_socket.bind(self._inproc_addr) self.zmq_server_info = ZMQServerInfo( @@ -269,8 +276,8 @@ def _proxy_routine(self) -> None: def _worker_routine(self) -> None: """Worker thread for processing requests.""" - # Each worker must have its own socket - worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER) + + worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER, self._node_ip) worker_socket.connect(self._inproc_addr) poller = zmq.Poller() @@ -306,18 +313,18 @@ def _worker_routine(self) -> None: logger.debug(f"[{self.storage_unit_id}]: worker received operation: {operation}") # Process request - if operation == ZMQRequestType.PUT_DATA: + if operation == ZMQRequestType.PUT_DATA: # type: ignore[arg-type] with perf_monitor.measure(op_type="PUT_DATA"): response_msg = self._handle_put(request_msg) - elif operation == ZMQRequestType.GET_DATA: + elif operation == ZMQRequestType.GET_DATA: # type: ignore[arg-type] with perf_monitor.measure(op_type="GET_DATA"): response_msg = self._handle_get(request_msg) - elif operation == ZMQRequestType.CLEAR_DATA: + elif operation == ZMQRequestType.CLEAR_DATA: # type: ignore[arg-type] with perf_monitor.measure(op_type="CLEAR_DATA"): response_msg = self._handle_clear(request_msg) else: response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR, + request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "message": f"Storage unit id #{self.storage_unit_id} " @@ -326,7 +333,7 @@ def _worker_routine(self) -> None: ) except Exception as e: response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_GET_ERROR, + request_type=ZMQRequestType.PUT_GET_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "message": f"{self.storage_unit_id}, worker encountered error " @@ -361,13 +368,15 @@ def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage: # After put operation finish, send a message to the client response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_DATA_RESPONSE, sender_id=self.storage_unit_id, body={} + request_type=ZMQRequestType.PUT_DATA_RESPONSE, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={}, # type: ignore[arg-type] ) return response_msg except Exception as e: return ZMQMessage.create( - request_type=ZMQRequestType.PUT_ERROR, + request_type=ZMQRequestType.PUT_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "message": f"Failed to put data into storage unit id " @@ -395,7 +404,7 @@ def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage: result_data = self.storage_data.get_data(fields, local_indexes) response_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_DATA_RESPONSE, + request_type=ZMQRequestType.GET_DATA_RESPONSE, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "data": result_data, @@ -403,7 +412,7 @@ def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage: ) except Exception as e: response_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_ERROR, + request_type=ZMQRequestType.GET_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "message": f"Failed to get data from storage unit id #{self.storage_unit_id}, " @@ -431,13 +440,13 @@ def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage: self.storage_data.clear(local_indexes) response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_DATA_RESPONSE, + request_type=ZMQRequestType.CLEAR_DATA_RESPONSE, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={"message": f"Clear data in storage unit id #{self.storage_unit_id} successfully."}, ) except Exception as e: response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_DATA_ERROR, + request_type=ZMQRequestType.CLEAR_DATA_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "message": f"Failed to clear data in storage unit id #{self.storage_unit_id}, " diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 7d571a5e..7d0d8d18 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -25,6 +25,7 @@ import psutil import ray import zmq +from ray.util import get_node_ip_address from transfer_queue.utils.common import ( get_env_bool, @@ -115,7 +116,7 @@ class ZMQServerInfo: TransferQueue server info class. """ - def __init__(self, role: TransferQueueRole, id: str, ip: str, ports: dict[str, str]): + def __init__(self, role: TransferQueueRole, id: str, ip: str, ports: dict[str, int]): self.role = role self.id = id self.ip = ip @@ -123,7 +124,7 @@ def __init__(self, role: TransferQueueRole, id: str, ip: str, ports: dict[str, s def to_addr(self, port_name: str) -> str: """Convert zmq port name to address string.""" - return f"tcp://{self.ip}:{self.ports[port_name]}" + return format_zmq_address(self.ip, self.ports[port_name]) def to_dict(self): """Convert ZMQServerInfo to dict.""" @@ -209,9 +210,64 @@ def deserialize(cls, frames: list) -> "ZMQMessage": return pickle.loads(frames[0]) -def get_free_port() -> str: - """Get free port of the host.""" - with socket.socket() as sock: +def is_ipv6_address(ip: str) -> bool: + """Check if the given IP address is an IPv6 address.""" + try: + socket.inet_pton(socket.AF_INET6, ip) + return True + except OSError: + return False + + +def format_zmq_address(ip: str, port: int) -> str: + """ + Format IP and port for ZMQ binding/connecting. + + For IPv6 addresses, ZMQ requires the address to be wrapped in brackets: + - IPv6: tcp://[::1]:port + - IPv4: tcp://1.2.3.4:port + + Args: + ip: IP address (IPv4 or IPv6) + port: Port number + + Returns: + Formatted ZMQ address string + """ + if is_ipv6_address(ip): + return f"tcp://[{ip}]:{port}" + else: + return f"tcp://{ip}:{port}" + + +def get_node_ip_address_raw() -> str: + """A wrapper around Ray's get_node_ip_address(). + + This function intentionally returns a raw IPv4/IPv6 address WITHOUT brackets. + """ + + return get_node_ip_address().strip("[]") + + +def get_free_port(ip: str) -> int: + """Get free port of the host. + + Args: + ip: IP address to detect IPv6 and enable IPV6 socket option + """ + is_ipv6 = is_ipv6_address(ip) + family = socket.AF_INET6 if is_ipv6 else socket.AF_INET + + with socket.socket(family, socket.SOCK_STREAM) as sock: + if is_ipv6: + # Try to allow dual-stack if the platform supports it. + try: + sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + except (OSError, AttributeError): + # Some platforms don't support IPV6_V6ONLY or this option; + # in that case just ignore and use the default behavior. + pass + sock.bind(("", 0)) return sock.getsockname()[1] @@ -219,12 +275,24 @@ def get_free_port() -> str: def create_zmq_socket( ctx: zmq.Context, socket_type: Any, + ip: str, identity: Optional[bytestr] = None, ) -> zmq.Socket: - """Create ZMQ socket.""" + """Create ZMQ socket. + + Args: + ctx: ZMQ context + socket_type: ZMQ socket type + ip: IP address to detect IPv6 and enable IPV6 socket option + identity: Optional socket identity + """ mem = psutil.virtual_memory() socket = ctx.socket(socket_type) + # Enable IPv6 if the IP address is IPv6 + if is_ipv6_address(ip): + socket.setsockopt(zmq.IPV6, 1) + # Calculate buffer size based on system memory total_mem = mem.total / 1024**3 available_mem = mem.available / 1024**3