From 0115fae5d9f4362ef75e5d4215bddfdbfeeab977 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 4 Mar 2026 21:01:54 +0800 Subject: [PATCH 1/5] try: support ipv6 Signed-off-by: 0oshowero0 --- tests/test_async_simple_storage_manager.py | 4 +- transfer_queue/client.py | 4 +- transfer_queue/controller.py | 9 ++-- .../managers/simple_backend_manager.py | 4 +- transfer_queue/storage/simple_backend.py | 30 +++++++------ transfer_queue/utils/zmq_utils.py | 43 ++++++++++++++++++- 6 files changed, 69 insertions(+), 25 deletions(-) diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index 5187e169..8254895a 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -221,7 +221,7 @@ async def test_async_storage_manager_mapping_functions(): # Mock handshake response handshake_response = ZMQMessage.create( - request_type=ZMQRequestType.HANDSHAKE_ACK, + request_type=ZMQRequestType.HANDSHAKE_ACK, # type: ignore[arg-type] sender_id="controller_0", body={"message": "Handshake successful"}, ) @@ -295,7 +295,7 @@ async def test_async_storage_manager_error_handling(): # Mock handshake response handshake_response = ZMQMessage.create( - request_type=ZMQRequestType.HANDSHAKE_ACK, + request_type=ZMQRequestType.HANDSHAKE_ACK, # type: ignore[arg-type] sender_id="controller_0", body={"message": "Handshake successful"}, ) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 05fac154..7988e261 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -124,9 +124,9 @@ async def wrapper(self, *args, **kwargs): raise RuntimeError("No controller registered") context = zmq.asyncio.Context() - address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}" + address = f"tcp://[{server_info.ip}]:{server_info.ports.get(socket_name)}" identity = f"{self.client_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() - sock = create_zmq_socket(context, zmq.DEALER, identity=identity) + sock = create_zmq_socket(context, zmq.DEALER, identity=identity, ip=server_info.ip) try: sock.connect(address) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 15ff873c..8ef0cab6 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -1585,20 +1585,23 @@ def _init_zmq_socket(self): self.handshake_socket = create_zmq_socket( ctx=self.zmq_context, socket_type=zmq.ROUTER, + ip=self._node_ip, ) - self.handshake_socket.bind(f"tcp://{self._node_ip}:{self._handshake_socket_port}") + self.handshake_socket.bind(f"tcp://[{self._node_ip}]:{self._handshake_socket_port}") self.request_handle_socket = create_zmq_socket( ctx=self.zmq_context, socket_type=zmq.ROUTER, + ip=self._node_ip, ) - self.request_handle_socket.bind(f"tcp://{self._node_ip}:{self._request_handle_socket_port}") + self.request_handle_socket.bind(f"tcp://[{self._node_ip}]:{self._request_handle_socket_port}") self.data_status_update_socket = create_zmq_socket( ctx=self.zmq_context, socket_type=zmq.ROUTER, + ip=self._node_ip, ) - self.data_status_update_socket.bind(f"tcp://{self._node_ip}:{self._data_status_update_socket_port}") + self.data_status_update_socket.bind(f"tcp://[{self._node_ip}]:{self._data_status_update_socket_port}") break except zmq.ZMQError: diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 745c04b8..7a057c91 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -150,9 +150,9 @@ async def wrapper(self, *args, **kwargs): raise RuntimeError(f"Server {server_key} not found in registered servers") context = zmq.asyncio.Context() - address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}" + address = f"tcp://[{server_info.ip}]:{server_info.ports.get(socket_name)}" identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() - sock = create_zmq_socket(context, zmq.DEALER, identity=identity) + sock = create_zmq_socket(context, zmq.DEALER, identity=identity, ip=server_info.ip) try: sock.connect(address) diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_backend.py index ed12d547..41d3b73b 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_backend.py @@ -208,15 +208,15 @@ def _init_zmq_socket(self) -> None: - worker_socket (DEALER): Backend socket for worker communication. """ self.zmq_context = zmq.Context() + self._node_ip = get_node_ip_address() # Frontend: ROUTER for receiving client requests - self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER) - self._node_ip = get_node_ip_address() + self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER, ip=self._node_ip) while True: try: self._put_get_socket_port = get_free_port() - self.put_get_socket.bind(f"tcp://{self._node_ip}:{self._put_get_socket_port}") + self.put_get_socket.bind(f"tcp://[{self._node_ip}]:{self._put_get_socket_port}") break except zmq.ZMQError: logger.warning(f"[{self.storage_unit_id}]: Try to bind ZMQ sockets failed, retrying...") @@ -306,18 +306,18 @@ def _worker_routine(self) -> None: logger.debug(f"[{self.storage_unit_id}]: worker received operation: {operation}") # Process request - if operation == ZMQRequestType.PUT_DATA: + if operation == ZMQRequestType.PUT_DATA: # type: ignore[arg-type] with perf_monitor.measure(op_type="PUT_DATA"): response_msg = self._handle_put(request_msg) - elif operation == ZMQRequestType.GET_DATA: + elif operation == ZMQRequestType.GET_DATA: # type: ignore[arg-type] with perf_monitor.measure(op_type="GET_DATA"): response_msg = self._handle_get(request_msg) - elif operation == ZMQRequestType.CLEAR_DATA: + elif operation == ZMQRequestType.CLEAR_DATA: # type: ignore[arg-type] with perf_monitor.measure(op_type="CLEAR_DATA"): response_msg = self._handle_clear(request_msg) else: response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR, + request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "message": f"Storage unit id #{self.storage_unit_id} " @@ -326,7 +326,7 @@ def _worker_routine(self) -> None: ) except Exception as e: response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_GET_ERROR, + request_type=ZMQRequestType.PUT_GET_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "message": f"{self.storage_unit_id}, worker encountered error " @@ -361,13 +361,15 @@ def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage: # After put operation finish, send a message to the client response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_DATA_RESPONSE, sender_id=self.storage_unit_id, body={} + request_type=ZMQRequestType.PUT_DATA_RESPONSE, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={}, # type: ignore[arg-type] ) return response_msg except Exception as e: return ZMQMessage.create( - request_type=ZMQRequestType.PUT_ERROR, + request_type=ZMQRequestType.PUT_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "message": f"Failed to put data into storage unit id " @@ -395,7 +397,7 @@ def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage: result_data = self.storage_data.get_data(fields, local_indexes) response_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_DATA_RESPONSE, + request_type=ZMQRequestType.GET_DATA_RESPONSE, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "data": result_data, @@ -403,7 +405,7 @@ def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage: ) except Exception as e: response_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_ERROR, + request_type=ZMQRequestType.GET_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "message": f"Failed to get data from storage unit id #{self.storage_unit_id}, " @@ -431,13 +433,13 @@ def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage: self.storage_data.clear(local_indexes) response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_DATA_RESPONSE, + request_type=ZMQRequestType.CLEAR_DATA_RESPONSE, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={"message": f"Clear data in storage unit id #{self.storage_unit_id} successfully."}, ) except Exception as e: response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_DATA_ERROR, + request_type=ZMQRequestType.CLEAR_DATA_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "message": f"Failed to clear data in storage unit id #{self.storage_unit_id}, " diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 7d571a5e..fd2e0941 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -123,7 +123,7 @@ def __init__(self, role: TransferQueueRole, id: str, ip: str, ports: dict[str, s def to_addr(self, port_name: str) -> str: """Convert zmq port name to address string.""" - return f"tcp://{self.ip}:{self.ports[port_name]}" + return format_zmq_address(self.ip, self.ports[port_name]) def to_dict(self): """Convert ZMQServerInfo to dict.""" @@ -209,6 +209,33 @@ def deserialize(cls, frames: list) -> "ZMQMessage": return pickle.loads(frames[0]) +def is_ipv6_address(ip: str) -> bool: + """Check if the given IP address is an IPv6 address.""" + try: + socket.inet_pton(socket.AF_INET6, ip) + return True + except OSError: + return False + + +def format_zmq_address(ip: str, port: str | int) -> str: + """ + Format IP and port for ZMQ binding/connecting. + + For IPv6 addresses, ZMQ requires the address to be wrapped in brackets: + - IPv6: tcp://[::1]:port + This format also works for IPv4 addresses. + + Args: + ip: IP address (IPv4 or IPv6) + port: Port number + + Returns: + Formatted ZMQ address string + """ + return f"tcp://[{ip}]:{port}" + + def get_free_port() -> str: """Get free port of the host.""" with socket.socket() as sock: @@ -220,11 +247,23 @@ def create_zmq_socket( ctx: zmq.Context, socket_type: Any, identity: Optional[bytestr] = None, + ip: Optional[str] = None, ) -> zmq.Socket: - """Create ZMQ socket.""" + """Create ZMQ socket. + + Args: + ctx: ZMQ context + socket_type: ZMQ socket type + identity: Optional socket identity + ip: Optional IP address to detect IPv6 and enable IPV6 socket option + """ mem = psutil.virtual_memory() socket = ctx.socket(socket_type) + # Enable IPv6 if the IP address is IPv6 + if ip is not None and is_ipv6_address(ip): + socket.setsockopt(zmq.IPV6, 1) + # Calculate buffer size based on system memory total_mem = mem.total / 1024**3 available_mem = mem.available / 1024**3 From d8da602305fb437bb9c5231da580c5e19f08df7b Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 4 Mar 2026 21:09:18 +0800 Subject: [PATCH 2/5] simplify Signed-off-by: 0oshowero0 --- transfer_queue/utils/zmq_utils.py | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index fd2e0941..484017d7 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -123,7 +123,7 @@ def __init__(self, role: TransferQueueRole, id: str, ip: str, ports: dict[str, s def to_addr(self, port_name: str) -> str: """Convert zmq port name to address string.""" - return format_zmq_address(self.ip, self.ports[port_name]) + return f"tcp://[{self.ip}]:{self.ports[port_name]}" def to_dict(self): """Convert ZMQServerInfo to dict.""" @@ -218,24 +218,6 @@ def is_ipv6_address(ip: str) -> bool: return False -def format_zmq_address(ip: str, port: str | int) -> str: - """ - Format IP and port for ZMQ binding/connecting. - - For IPv6 addresses, ZMQ requires the address to be wrapped in brackets: - - IPv6: tcp://[::1]:port - This format also works for IPv4 addresses. - - Args: - ip: IP address (IPv4 or IPv6) - port: Port number - - Returns: - Formatted ZMQ address string - """ - return f"tcp://[{ip}]:{port}" - - def get_free_port() -> str: """Get free port of the host.""" with socket.socket() as sock: From 39a6e0c2dbc1295418b5774ab0409670ab2f9212 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 4 Mar 2026 21:23:10 +0800 Subject: [PATCH 3/5] fix get_free_port Signed-off-by: 0oshowero0 --- transfer_queue/client.py | 3 +- transfer_queue/controller.py | 9 ++-- .../managers/simple_backend_manager.py | 10 +++- transfer_queue/storage/simple_backend.py | 11 +++- transfer_queue/utils/zmq_utils.py | 51 +++++++++++++++++-- 5 files changed, 72 insertions(+), 12 deletions(-) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 7988e261..2090ad1a 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -39,6 +39,7 @@ ZMQRequestType, ZMQServerInfo, create_zmq_socket, + format_zmq_address, ) logger = logging.getLogger(__name__) @@ -124,7 +125,7 @@ async def wrapper(self, *args, **kwargs): raise RuntimeError("No controller registered") context = zmq.asyncio.Context() - address = f"tcp://[{server_info.ip}]:{server_info.ports.get(socket_name)}" + address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name)) identity = f"{self.client_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() sock = create_zmq_socket(context, zmq.DEALER, identity=identity, ip=server_info.ip) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 8ef0cab6..32185d3b 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -45,6 +45,7 @@ ZMQRequestType, ZMQServerInfo, create_zmq_socket, + format_zmq_address, get_free_port, ) @@ -1587,21 +1588,23 @@ def _init_zmq_socket(self): socket_type=zmq.ROUTER, ip=self._node_ip, ) - self.handshake_socket.bind(f"tcp://[{self._node_ip}]:{self._handshake_socket_port}") + self.handshake_socket.bind(format_zmq_address(self._node_ip, self._handshake_socket_port)) self.request_handle_socket = create_zmq_socket( ctx=self.zmq_context, socket_type=zmq.ROUTER, ip=self._node_ip, ) - self.request_handle_socket.bind(f"tcp://[{self._node_ip}]:{self._request_handle_socket_port}") + self.request_handle_socket.bind(format_zmq_address(self._node_ip, self._request_handle_socket_port)) self.data_status_update_socket = create_zmq_socket( ctx=self.zmq_context, socket_type=zmq.ROUTER, ip=self._node_ip, ) - self.data_status_update_socket.bind(f"tcp://[{self._node_ip}]:{self._data_status_update_socket_port}") + self.data_status_update_socket.bind( + format_zmq_address(self._node_ip, self._data_status_update_socket_port) + ) break except zmq.ZMQError: diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 7a057c91..715b5e14 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -34,7 +34,13 @@ from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory from transfer_queue.storage.simple_backend import StorageMetaGroup from transfer_queue.utils.common import get_env_bool -from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket +from transfer_queue.utils.zmq_utils import ( + ZMQMessage, + ZMQRequestType, + ZMQServerInfo, + create_zmq_socket, + format_zmq_address, +) logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -150,7 +156,7 @@ async def wrapper(self, *args, **kwargs): raise RuntimeError(f"Server {server_key} not found in registered servers") context = zmq.asyncio.Context() - address = f"tcp://[{server_info.ip}]:{server_info.ports.get(socket_name)}" + address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name)) identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() sock = create_zmq_socket(context, zmq.DEALER, identity=identity, ip=server_info.ip) diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_backend.py index 41d3b73b..697388f7 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_backend.py @@ -32,7 +32,14 @@ from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads from transfer_queue.utils.enum_utils import TransferQueueRole from transfer_queue.utils.perf_utils import IntervalPerfMonitor -from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket, get_free_port +from transfer_queue.utils.zmq_utils import ( + ZMQMessage, + ZMQRequestType, + ZMQServerInfo, + create_zmq_socket, + format_zmq_address, + get_free_port, +) logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -216,7 +223,7 @@ def _init_zmq_socket(self) -> None: while True: try: self._put_get_socket_port = get_free_port() - self.put_get_socket.bind(f"tcp://[{self._node_ip}]:{self._put_get_socket_port}") + self.put_get_socket.bind(format_zmq_address(self._node_ip, self._put_get_socket_port)) break except zmq.ZMQError: logger.warning(f"[{self.storage_unit_id}]: Try to bind ZMQ sockets failed, retrying...") diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 484017d7..c77fb6be 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -123,7 +123,7 @@ def __init__(self, role: TransferQueueRole, id: str, ip: str, ports: dict[str, s def to_addr(self, port_name: str) -> str: """Convert zmq port name to address string.""" - return f"tcp://[{self.ip}]:{self.ports[port_name]}" + return format_zmq_address(self.ip, self.ports[port_name]) def to_dict(self): """Convert ZMQServerInfo to dict.""" @@ -218,11 +218,54 @@ def is_ipv6_address(ip: str) -> bool: return False +def format_zmq_address(ip: str, port: str | int) -> str: + """ + Format IP and port for ZMQ binding/connecting. + + For IPv6 addresses, ZMQ requires the address to be wrapped in brackets: + - IPv6: tcp://[::1]:port + - IPv4: tcp://1.2.3.4:port + + Args: + ip: IP address (IPv4 or IPv6) + port: Port number + + Returns: + Formatted ZMQ address string + """ + if is_ipv6_address(ip): + return f"tcp://[{ip}]:{port}" + else: + return f"tcp://{ip}:{port}" + + def get_free_port() -> str: """Get free port of the host.""" - with socket.socket() as sock: - sock.bind(("", 0)) - return sock.getsockname()[1] + + # Prefer IPv6 if supported, fall back to IPv4 + families = [socket.AF_INET6, socket.AF_INET] + last_error = None + for family in families: + try: + with socket.socket(family, socket.SOCK_STREAM) as sock: + # For IPv6, allow both IPv4/IPv6 if the platform uses dual-stack by default + if family == socket.AF_INET6: + # Some OS default to v6-only; explicitly disable that to allow dual-stack. + # Ignore failures on platforms that don't support IPV6_V6ONLY. + try: + sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + except (OSError, AttributeError): + pass + + sock.bind(("", 0)) + return sock.getsockname()[1] + except OSError as e: + last_error = e + # Try next family + continue + + # Both IPv6 and IPv4 failed + raise RuntimeError(f"Failed to get free port: {last_error}") def create_zmq_socket( From eb74aed289bc16b586710d663f28c9faf57ed59f Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 4 Mar 2026 21:36:28 +0800 Subject: [PATCH 4/5] fix type hint Signed-off-by: 0oshowero0 --- transfer_queue/utils/zmq_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index c77fb6be..7001c9e0 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -115,7 +115,7 @@ class ZMQServerInfo: TransferQueue server info class. """ - def __init__(self, role: TransferQueueRole, id: str, ip: str, ports: dict[str, str]): + def __init__(self, role: TransferQueueRole, id: str, ip: str, ports: dict[str, int]): self.role = role self.id = id self.ip = ip @@ -218,7 +218,7 @@ def is_ipv6_address(ip: str) -> bool: return False -def format_zmq_address(ip: str, port: str | int) -> str: +def format_zmq_address(ip: str, port: int) -> str: """ Format IP and port for ZMQ binding/connecting. @@ -239,7 +239,7 @@ def format_zmq_address(ip: str, port: str | int) -> str: return f"tcp://{ip}:{port}" -def get_free_port() -> str: +def get_free_port() -> int: """Get free port of the host.""" # Prefer IPv6 if supported, fall back to IPv4 From 0045d9a2c21ec9416c7a4e342817754fa123caf0 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 5 Mar 2026 11:03:31 +0800 Subject: [PATCH 5/5] fix error and enhance check logics Signed-off-by: 0oshowero0 --- transfer_queue/controller.py | 10 ++-- transfer_queue/storage/managers/base.py | 13 +++-- .../managers/simple_backend_manager.py | 2 +- transfer_queue/storage/simple_backend.py | 14 ++--- transfer_queue/utils/zmq_utils.py | 58 ++++++++++--------- 5 files changed, 51 insertions(+), 46 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 32185d3b..50ff10df 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -29,7 +29,6 @@ import torch import zmq from omegaconf import DictConfig -from ray.util import get_node_ip_address from torch import Tensor from transfer_queue.metadata import ( @@ -47,6 +46,7 @@ create_zmq_socket, format_zmq_address, get_free_port, + get_node_ip_address_raw, ) logger = logging.getLogger(__name__) @@ -1575,13 +1575,13 @@ def kv_retrieve_keys( def _init_zmq_socket(self): """Initialize ZMQ sockets for communication.""" self.zmq_context = zmq.Context() - self._node_ip = get_node_ip_address() + self._node_ip = get_node_ip_address_raw() while True: try: - self._handshake_socket_port = get_free_port() - self._request_handle_socket_port = get_free_port() - self._data_status_update_socket_port = get_free_port() + self._handshake_socket_port = get_free_port(ip=self._node_ip) + self._request_handle_socket_port = get_free_port(ip=self._node_ip) + self._data_status_update_socket_port = get_free_port(ip=self._node_ip) self.handshake_socket = create_zmq_socket( ctx=self.zmq_context, diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index cad91f97..cb714c32 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -83,8 +83,9 @@ def _connect_to_controller(self) -> None: # create zmq socket for handshake (sync, for initial connection) self.controller_handshake_socket = create_zmq_socket( - sync_zmq_context, - zmq.DEALER, + ctx=sync_zmq_context, + socket_type=zmq.DEALER, + ip=self.controller_info.ip, identity=f"{self.storage_manager_id}-controller_handshake_socket-{uuid4().hex[:8]}".encode(), ) @@ -219,13 +220,13 @@ async def notify_data_update( # create dynamic socket identity = f"{self.storage_manager_id}-data_update-{uuid4().hex[:8]}".encode() - sock = create_zmq_socket(self.zmq_context, zmq.DEALER, identity=identity) + sock = create_zmq_socket(self.zmq_context, zmq.DEALER, self.controller_info.ip, identity) try: sock.connect(self.controller_info.to_addr("data_status_update_socket")) request_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, + request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, # type: ignore[arg-type] sender_id=self.storage_manager_id, body={ "partition_id": partition_id, @@ -253,7 +254,7 @@ async def notify_data_update( messages = await asyncio.wait_for(sock.recv_multipart(), timeout=poll_interval) response_msg = ZMQMessage.deserialize(messages) - if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK: + if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK: # type: ignore[arg-type] response_received = True logger.debug( f"[{self.storage_manager_id}]: Get data status update ACK response " @@ -272,7 +273,7 @@ async def notify_data_update( logger.error(f"[{self.storage_manager_id}]: Error during notify_data_update: {e}") try: error_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, + request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, # type: ignore[arg-type] sender_id=self.storage_manager_id, body={"message": f"Failed to notify: {str(e)}"}, ).serialize() diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 715b5e14..608a0827 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -158,7 +158,7 @@ async def wrapper(self, *args, **kwargs): context = zmq.asyncio.Context() address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name)) identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() - sock = create_zmq_socket(context, zmq.DEALER, identity=identity, ip=server_info.ip) + sock = create_zmq_socket(context, zmq.DEALER, server_info.ip, identity) try: sock.connect(address) diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_backend.py index 697388f7..d84abe3b 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_backend.py @@ -26,7 +26,6 @@ import ray import zmq -from ray.util import get_node_ip_address from transfer_queue.metadata import SampleMeta from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads @@ -39,6 +38,7 @@ create_zmq_socket, format_zmq_address, get_free_port, + get_node_ip_address_raw, ) logger = logging.getLogger(__name__) @@ -215,14 +215,14 @@ def _init_zmq_socket(self) -> None: - worker_socket (DEALER): Backend socket for worker communication. """ self.zmq_context = zmq.Context() - self._node_ip = get_node_ip_address() + self._node_ip = get_node_ip_address_raw() # Frontend: ROUTER for receiving client requests - self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER, ip=self._node_ip) + self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER, self._node_ip) while True: try: - self._put_get_socket_port = get_free_port() + self._put_get_socket_port = get_free_port(ip=self._node_ip) self.put_get_socket.bind(format_zmq_address(self._node_ip, self._put_get_socket_port)) break except zmq.ZMQError: @@ -230,7 +230,7 @@ def _init_zmq_socket(self) -> None: continue # Backend: DEALER for worker communication (connected via zmq.proxy) - self.worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER) + self.worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER, self._node_ip) self.worker_socket.bind(self._inproc_addr) self.zmq_server_info = ZMQServerInfo( @@ -276,8 +276,8 @@ def _proxy_routine(self) -> None: def _worker_routine(self) -> None: """Worker thread for processing requests.""" - # Each worker must have its own socket - worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER) + + worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER, self._node_ip) worker_socket.connect(self._inproc_addr) poller = zmq.Poller() diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 7001c9e0..7d0d8d18 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -25,6 +25,7 @@ import psutil import ray import zmq +from ray.util import get_node_ip_address from transfer_queue.utils.common import ( get_env_bool, @@ -239,54 +240,57 @@ def format_zmq_address(ip: str, port: int) -> str: return f"tcp://{ip}:{port}" -def get_free_port() -> int: - """Get free port of the host.""" +def get_node_ip_address_raw() -> str: + """A wrapper around Ray's get_node_ip_address(). - # Prefer IPv6 if supported, fall back to IPv4 - families = [socket.AF_INET6, socket.AF_INET] - last_error = None - for family in families: - try: - with socket.socket(family, socket.SOCK_STREAM) as sock: - # For IPv6, allow both IPv4/IPv6 if the platform uses dual-stack by default - if family == socket.AF_INET6: - # Some OS default to v6-only; explicitly disable that to allow dual-stack. - # Ignore failures on platforms that don't support IPV6_V6ONLY. - try: - sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) - except (OSError, AttributeError): - pass + This function intentionally returns a raw IPv4/IPv6 address WITHOUT brackets. + """ + + return get_node_ip_address().strip("[]") + + +def get_free_port(ip: str) -> int: + """Get free port of the host. - sock.bind(("", 0)) - return sock.getsockname()[1] - except OSError as e: - last_error = e - # Try next family - continue + Args: + ip: IP address to detect IPv6 and enable IPV6 socket option + """ + is_ipv6 = is_ipv6_address(ip) + family = socket.AF_INET6 if is_ipv6 else socket.AF_INET - # Both IPv6 and IPv4 failed - raise RuntimeError(f"Failed to get free port: {last_error}") + with socket.socket(family, socket.SOCK_STREAM) as sock: + if is_ipv6: + # Try to allow dual-stack if the platform supports it. + try: + sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + except (OSError, AttributeError): + # Some platforms don't support IPV6_V6ONLY or this option; + # in that case just ignore and use the default behavior. + pass + + sock.bind(("", 0)) + return sock.getsockname()[1] def create_zmq_socket( ctx: zmq.Context, socket_type: Any, + ip: str, identity: Optional[bytestr] = None, - ip: Optional[str] = None, ) -> zmq.Socket: """Create ZMQ socket. Args: ctx: ZMQ context socket_type: ZMQ socket type + ip: IP address to detect IPv6 and enable IPV6 socket option identity: Optional socket identity - ip: Optional IP address to detect IPv6 and enable IPV6 socket option """ mem = psutil.virtual_memory() socket = ctx.socket(socket_type) # Enable IPv6 if the IP address is IPv6 - if ip is not None and is_ipv6_address(ip): + if is_ipv6_address(ip): socket.setsockopt(zmq.IPV6, 1) # Calculate buffer size based on system memory