Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 52 additions & 32 deletions transfer_queue/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import threading
from functools import wraps
from typing import Any, Callable, Optional
from uuid import uuid4

import torch
import zmq
Expand All @@ -34,11 +33,14 @@
TransferQueueStorageManagerFactory,
)
from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads
from transfer_queue.utils.socket_pool import (
SocketPoolManager,
invoke_with_pool,
)
from transfer_queue.utils.zmq_utils import (
ZMQMessage,
ZMQRequestType,
ZMQServerInfo,
create_zmq_socket,
format_zmq_address,
)

Expand Down Expand Up @@ -78,6 +80,9 @@ def __init__(
raise TypeError(f"controller_info must be ZMQServerInfo, got {type(controller_info)}")
self.client_id = client_id
self._controller: ZMQServerInfo = controller_info
# Owns the long-lived DEALER pools used by ``dynamic_socket``;
# released by ``close()``.
self._socket_pool_manager = SocketPoolManager()
logger.info(f"[{self.client_id}]: Registered Controller server {controller_info.id} at {controller_info.ip}")

def initialize_storage_manager(
Expand All @@ -99,22 +104,29 @@ def initialize_storage_manager(
manager_type, controller_info=self._controller, config=config
)

# TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong.
@staticmethod
def dynamic_socket(socket_name: str):
"""Decorator to auto-manage ZMQ sockets for Controller/Storage servers.
"""Decorator: route each call through a long-lived DEALER pool.

Handles socket lifecycle: create -> connect -> inject -> close.
The pool is keyed by ``(current_loop, controller_id, socket_name)``
and grows lazily up to ``TQ_POOL_SIZE``. Each call is wrapped in
``asyncio.wait_for(TQ_REQUEST_TIMEOUT_S)`` and retried up to
``TQ_REQUEST_MAX_ATTEMPTS`` times on failure, with the suspect
socket dropped between attempts. See ``transfer_queue.utils.socket_pool``
for the rationale (TIME_WAIT exhaustion under high-throughput
async RL training, and ROUTER reply mis-routing protection).

Args:
socket_name: Port name from server config to use for ZMQ connection (e.g., "data_req_port")
socket_name: Port name from server config to use for ZMQ
connection (e.g., ``"request_handle_socket"``).

Decorated Function Requirements:
1. Must be an async class method (needs `self`)
2. `self` must have:
- `_controller`: Server registry
- `client_id`: Unique client ID for socket identity
3. Receives ZMQ socket via `socket` keyword argument (injected by decorator)
1. Must be an async class method (needs ``self``).
2. ``self`` must have:
- ``_controller``: ZMQServerInfo of the controller.
- ``client_id``: Unique client ID for socket identity.
3. Receives ZMQ socket via ``socket`` keyword argument
(injected by decorator).
"""

def decorator(func: Callable):
Expand All @@ -124,31 +136,34 @@ async def wrapper(self, *args, **kwargs):
if not server_info:
raise RuntimeError("No controller registered")

context = zmq.asyncio.Context()
# ``loop_id`` MUST be in the identity prefix. Some callers
# drive the same client instance from two asyncio loops
# (e.g. a bg loop for sync wrappers + a shared loop for
# async calls). Without loop_id, both pools' "first
# socket" would share the identity ``{client_id}_to_
# {server_id}-0`` and the ROUTER would route replies
# non-deterministically between them — one side's recv
# then hangs forever.
loop_id = id(asyncio.get_running_loop())
identity_prefix = f"{self.client_id}_to_{server_info.id}_loop{loop_id}"
address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name))
identity = f"{self.client_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode()
sock = create_zmq_socket(context, zmq.DEALER, identity=identity, ip=server_info.ip)

try:
sock.connect(address)
logger.debug(
f"[{self.client_id}]: Connected to Controller {server_info.id} at {address} "
f"with identity {identity.decode()}"
)
pool = self._socket_pool_manager.get_or_create(
pool_key=(server_info.id, socket_name),
address=address,
ip=server_info.ip,
identity_prefix=identity_prefix,
)

async def _call(sock):
kwargs["socket"] = sock
return await func(self, *args, **kwargs)
except Exception as e:
logger.error(f"[{self.client_id}]: Error in socket operation with Controller {server_info.id}: {e}")
raise
finally:
try:
if not sock.closed:
sock.close(linger=-1)
except Exception as e:
logger.warning(f"[{self.client_id}]: Error closing socket to Controller {server_info.id}: {e}")

context.term()

return await invoke_with_pool(
pool,
_call,
label=f"{self.client_id} {socket_name}.{func.__name__}",
)

return wrapper

Expand Down Expand Up @@ -1236,7 +1251,12 @@ async def async_kv_list(
raise RuntimeError(f"[{self.client_id}]: Error in kv_list: {str(e)}") from e

def close(self) -> None:
"""Close the client and cleanup resources including storage manager."""
"""Close the client and cleanup resources including socket pools and storage manager."""
try:
if hasattr(self, "_socket_pool_manager"):
self._socket_pool_manager.close()
except Exception as e:
logger.warning(f"[{self.client_id}]: Error closing socket pools: {e}")
try:
if hasattr(self, "storage_manager") and self.storage_manager:
if hasattr(self.storage_manager, "close"):
Expand Down
99 changes: 59 additions & 40 deletions transfer_queue/storage/managers/simple_backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from functools import wraps
from operator import itemgetter
from typing import Any, Callable, NamedTuple, Optional
from uuid import uuid4

import torch
import zmq
Expand All @@ -32,11 +31,14 @@
from transfer_queue.metadata import BatchMeta, extract_field_schema
from transfer_queue.storage.managers.base import TransferQueueStorageManager
from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory
from transfer_queue.utils.socket_pool import (
SocketPoolManager,
invoke_with_pool,
)
from transfer_queue.utils.zmq_utils import (
ZMQMessage,
ZMQRequestType,
ZMQServerInfo,
create_zmq_socket,
format_zmq_address,
)

Expand Down Expand Up @@ -86,6 +88,9 @@ def __init__(self, controller_info: ZMQServerInfo, config: DictConfig):
raise ValueError("AsyncSimpleStorageManager requires non-empty 'zmq_info' in config.")

self.storage_unit_infos = self._register_servers(server_infos)
# Owns the long-lived DEALER pools used by
# ``dynamic_storage_manager_socket``; released by ``close()``.
self._socket_pool_manager = SocketPoolManager()

def _register_servers(self, server_infos: "ZMQServerInfo | dict[Any, ZMQServerInfo]"):
"""Register and validate server information.
Expand Down Expand Up @@ -114,22 +119,36 @@ def _register_servers(self, server_infos: "ZMQServerInfo | dict[Any, ZMQServerIn

return server_infos_transform

# TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong.
@staticmethod
def dynamic_storage_manager_socket(socket_name: str, timeout: int):
"""Decorator to auto-manage ZMQ sockets for Controller/Storage servers (create -> connect -> inject -> close).
"""Decorator: route each call through a long-lived DEALER pool.

The pool is keyed by ``(current_loop, storage_unit_key, socket_name)``
and grows lazily up to ``TQ_POOL_SIZE``. Each call is wrapped in
``asyncio.wait_for(timeout)`` and retried up to
``TQ_REQUEST_MAX_ATTEMPTS`` times on failure, with the suspect
socket dropped between attempts. See ``transfer_queue.utils.socket_pool``
for the rationale (TIME_WAIT exhaustion under high-throughput
async RL training, and ROUTER reply mis-routing protection).

The ``timeout`` argument is applied at both layers:
* ``asyncio.wait_for`` (asyncio-level, around the whole call)
* ZMQ ``RCVTIMEO``/``SNDTIMEO`` on each pooled socket (libzmq-
level, so a runaway recv cannot block libzmq IO either)

Args:
socket_name (str): Port name (from server config) to use for ZMQ connection (e.g., "data_req_port").
timeout (float): Timeout in seconds for ZMQ connection (in seconds).
socket_name (str): Port name (from server config) to use for
ZMQ connection (e.g., ``"put_get_socket"``).
timeout (int): Per-call timeout in seconds.

Decorated Function Rules:
1. Must be an async class method (needs `self`).
2. `self` requires:
- `storage_unit_infos: storage unit infos (ZMQServerInfo | dict[Any, ZMQServerInfo]).
3. Specify target server via:
- `target_storage_unit` arg.
4. Receives ZMQ socket via `socket` keyword arg (injected by decorator).
1. Must be an async class method (needs ``self``).
2. ``self`` requires:
- ``storage_unit_infos``: storage unit infos
(ZMQServerInfo | dict[Any, ZMQServerInfo]).
3. Specify target server via the ``target_storage_unit`` arg.
4. Receives ZMQ socket via ``socket`` keyword arg (injected
by decorator).
"""

def decorator(func: Callable):
Expand All @@ -143,44 +162,39 @@ async def wrapper(self, *args, **kwargs):
break

server_info = self.storage_unit_infos.get(server_key)

if not server_info:
raise RuntimeError(f"Server {server_key} not found in registered servers")

context = zmq.asyncio.Context()
# See ``AsyncTransferQueueClient.dynamic_socket`` for why
# ``loop_id`` must be in the identity prefix: cross-loop
# DEALER identity collision otherwise misroutes replies.
loop_id = id(asyncio.get_running_loop())
identity_prefix = f"{self.storage_manager_id}_to_{server_info.id}_loop{loop_id}"
address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name))
identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode()
sock = create_zmq_socket(context, zmq.DEALER, server_info.ip, identity)

try:
sock.connect(address)
# Timeouts to avoid indefinite await on recv/send
def _on_create(sock):
# libzmq-level fallback in addition to asyncio.wait_for.
sock.setsockopt(zmq.RCVTIMEO, timeout * 1000)
sock.setsockopt(zmq.SNDTIMEO, timeout * 1000)
logger.debug(
f"[{self.storage_manager_id}]: Connected to StorageUnit {server_info.id} at {address} "
f"with identity {identity.decode()}"
)

pool = self._socket_pool_manager.get_or_create(
pool_key=(server_key, socket_name),
address=address,
ip=server_info.ip,
identity_prefix=identity_prefix,
on_create=_on_create,
)

async def _call(sock):
kwargs["socket"] = sock
return await func(self, *args, **kwargs)
except Exception as e:
logger.error(
f"[{self.storage_manager_id}]: Error in socket operation with "
f"StorageUnit {server_info.id} at {address}: "
f"{type(e).__name__}: {e}"
)
raise
finally:
try:
if not sock.closed:
sock.close(linger=-1)
except Exception as e:
logger.warning(
f"[{self.storage_manager_id}]: Error closing socket to StorageUnit {server_info.id}: {e}"
)

context.term()

return await invoke_with_pool(
pool,
_call,
timeout=float(timeout),
label=f"{self.storage_manager_id} {socket_name}.{func.__name__}",
)

return wrapper

Expand Down Expand Up @@ -590,4 +604,9 @@ def get_zmq_server_info(self) -> dict[str, ZMQServerInfo]:

def close(self) -> None:
"""Close all ZMQ sockets and context to prevent resource leaks."""
try:
if hasattr(self, "_socket_pool_manager"):
self._socket_pool_manager.close()
except Exception as e:
logger.warning(f"[{self.storage_manager_id}]: Error closing socket pools: {e}")
super().close()
Loading
Loading