From ae4e1555fbe0fef2fdfded11918a54361fddc38c Mon Sep 17 00:00:00 2001
From: Yixiang Zhang
Date: Thu, 4 Jun 2026 17:27:58 +0800
Subject: [PATCH 1/5] [refactor] Merge data status update thread/socket into request handler to eliminate concurrency conflicts

Previously, _update_data_status and _process_request ran on two separate threads listening on two different ROUTER sockets. When used in async frameworks, NOTIFY_DATA_UPDATE and metadata requests (GET_META, CLEAR_META, etc.) could interleave and cause data processing conflicts on shared partition state.

This change consolidates everything onto request_handle_socket and the single _process_request thread, so all controller requests are serialized naturally:

- Drop data_status_update_socket and its port; ZMQServerInfo.ports now exposes only handshake_socket and request_handle_socket.
- Remove _start_process_update_data_status and _update_data_status; move the NOTIFY_DATA_UPDATE branch into _process_request.
- Update storage manager to send NOTIFY_DATA_UPDATE to request_handle_socket.
- Clean up data_status_update_socket port entries in test fixtures.

Note: this is a breaking change to the controller-storage wire protocol (port name removed). All storage managers must be upgraded together.

Signed-off-by: yxstev --- tests/test_async_simple_storage_manager.py | 4 +- tests/test_ray_p2p.py | 1 - transfer_queue/controller.py | 43 ++-------------------- transfer_queue/storage/managers/base.py | 2 +- 4 files changed, 6 insertions(+), 44 deletions(-) diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index 6c4da3d2..da6f48f8 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -53,7 +53,7 @@ async def mock_async_storage_manager(): role=Role.CONTROLLER, id="controller_0", ip="127.0.0.1", - ports={"handshake_socket": 12347, "data_status_update_socket": 12348}, + ports={"handshake_socket": 12347}, ) config = { @@ -158,7 +158,7 @@ async def test_async_storage_manager_error_handling(): role=Role.CONTROLLER, id="controller_0", ip="127.0.0.1", - ports={"handshake_socket": 12346, "data_status_update_socket": 12347}, + ports={"handshake_socket": 12346}, ) config = { diff --git a/tests/test_ray_p2p.py b/tests/test_ray_p2p.py index 353bb926..92f3f9c9 100644 --- a/tests/test_ray_p2p.py +++ b/tests/test_ray_p2p.py @@ -60,7 +60,6 @@ def create_mock_controller(): ip="127.0.0.1", ports={ "request_handle_socket": 9981, - "data_status_update_socket": 9982, "handshake_socket": 9983, }, ) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 9661ba0e..cc0814a6 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -1021,7 +1021,6 @@ def __init__( # Start background processing threads self._start_process_handshake() - self._start_process_update_data_status() self._start_process_request() logger.info(f"TransferQueue Controller {self.controller_id} initialized") @@ -1685,7 +1684,6 @@ def _init_zmq_socket(self): try: self._handshake_socket_port = get_free_port(ip=self._node_ip) self._request_handle_socket_port = get_free_port(ip=self._node_ip) - self._data_status_update_socket_port = get_free_port(ip=self._node_ip) 
self.handshake_socket = create_zmq_socket( ctx=self.zmq_context, @@ -1701,15 +1699,6 @@ def _init_zmq_socket(self): ) self.request_handle_socket.bind(format_zmq_address(self._node_ip, self._request_handle_socket_port)) - self.data_status_update_socket = create_zmq_socket( - ctx=self.zmq_context, - socket_type=zmq.ROUTER, - ip=self._node_ip, - ) - self.data_status_update_socket.bind( - format_zmq_address(self._node_ip, self._data_status_update_socket_port) - ) - break except zmq.ZMQError: logger.warning(f"[{self.controller_id}]: Try to bind ZMQ sockets failed, retrying...") @@ -1722,7 +1711,6 @@ def _init_zmq_socket(self): ports={ "handshake_socket": self._handshake_socket_port, "request_handle_socket": self._request_handle_socket_port, - "data_status_update_socket": self._data_status_update_socket_port, }, ) @@ -1781,15 +1769,6 @@ def _start_process_handshake(self): ) self.wait_connection_thread.start() - def _start_process_update_data_status(self): - """Start the data status update processing thread.""" - self.process_update_data_status_thread = Thread( - target=self._update_data_status, - name="TransferQueueControllerProcessUpdateDataStatusThread", - daemon=True, - ) - self.process_update_data_status_thread.start() - def _start_process_request(self): """Start the request processing thread.""" self.process_request_thread = Thread( @@ -2045,23 +2024,7 @@ def _process_request(self): body={"partition_info": partition_info, "message": message}, ) - self.request_handle_socket.send_multipart([identity, *response_msg.serialize()]) - - def _update_data_status(self): - """Process data status update messages from storage units - adapted for partitions.""" - logger.debug(f"[{self.controller_id}]: start receiving update_data_status requests...") - - perf_monitor = IntervalPerfMonitor(caller_name=self.controller_id) - - while True: - monitor = self._metrics if self._metrics is not None else perf_monitor - - messages = self.data_status_update_socket.recv_multipart(copy=False) - 
identity = messages.pop(0) - serialized_msg = messages - request_msg = ZMQMessage.deserialize(serialized_msg) - - if request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE: + elif request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE: with monitor.measure(op_type="NOTIFY_DATA_UPDATE"): message_data = request_msg.body partition_id = message_data.get("partition_id") @@ -2079,7 +2042,6 @@ def _update_data_status(self): self._metrics.record_samples("NOTIFY_DATA_UPDATE", len(global_indexes)) logger.debug(f"[{self.controller_id}]: Updated production status for partition {partition_id}") - # Send acknowledgment response_msg = ZMQMessage.create( request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK, sender_id=self.controller_id, @@ -2089,7 +2051,8 @@ def _update_data_status(self): "success": success, }, ) - self.data_status_update_socket.send_multipart([identity, *response_msg.serialize()]) + + self.request_handle_socket.send_multipart([identity, *response_msg.serialize()]) def get_zmq_server_info(self) -> ZMQServerInfo: """Get ZMQ server connection information.""" diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 180d466e..7057cc1e 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -210,7 +210,7 @@ async def notify_data_update( sock = create_zmq_socket(self.zmq_context, zmq.DEALER, self.controller_info.ip, identity) try: - sock.connect(self.controller_info.to_addr("data_status_update_socket")) + sock.connect(self.controller_info.to_addr("request_handle_socket")) normalized_field_schema = {} for field_name, field in field_schema.items(): From 78bac46dca20ff086db1da69172517c6c31cc3a9 Mon Sep 17 00:00:00 2001 From: yxstev Date: Thu, 4 Jun 2026 18:49:38 +0800 Subject: [PATCH 2/5] remove data_status_lock logics Signed-off-by: yxstev --- transfer_queue/controller.py | 93 ++++++++++++++---------------------- 1 file changed, 35 insertions(+), 58 deletions(-) diff 
--git a/transfer_queue/controller.py b/transfer_queue/controller.py index cc0814a6..aed48fc2 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -20,7 +20,7 @@ from dataclasses import dataclass, field from itertools import groupby from operator import itemgetter -from threading import Lock, Thread +from threading import Thread from typing import TYPE_CHECKING, Any from uuid import uuid4 @@ -361,10 +361,6 @@ class DataPartitionStatus: keys_mapping: dict[str, int] = field(default_factory=dict) # key -> global_idx revert_keys_mapping: dict[int, str] = field(default_factory=dict) # global_idx -> key - # Threading lock for concurrency control; only for preventing mask operation error when expanding production_status. - # No need to strictly lock for every read/write operation since freshness is not critical. - data_status_lock: Lock = field(default_factory=Lock) - # Dynamic configuration - these are computed from the current state @property def total_samples_num(self) -> int: @@ -409,8 +405,7 @@ def register_pre_allocated_indexes(self, allocated_indexes: list[int]): max_sample_idx = max(allocated_indexes) required_samples = max_sample_idx + 1 - with self.data_status_lock: - self.ensure_samples_capacity(required_samples) + self.ensure_samples_capacity(required_samples) logger.debug(f"Pre-allocated indexes in {self.partition_id}: {allocated_indexes}") @@ -526,9 +521,8 @@ def update_production_status( max_sample_idx = max(global_indices) if global_indices else -1 required_samples = max_sample_idx + 1 - with self.data_status_lock: - # Ensure we have enough rows - self.ensure_samples_capacity(required_samples) + # Ensure we have enough rows + self.ensure_samples_capacity(required_samples) # Register new fields if needed new_fields = [f for f in field_names if f not in self.field_name_mapping] @@ -538,14 +532,12 @@ def update_production_status( self.field_name_mapping[f] = len(self.field_name_mapping) required_fields = len(self.field_name_mapping) 
- with self.data_status_lock: - self.ensure_fields_capacity(required_fields) + self.ensure_fields_capacity(required_fields) - with self.data_status_lock: - # Update production status - if self.production_status is not None and global_indices and field_names: - field_indices = [self.field_name_mapping.get(f) for f in field_names] - self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1 + # Update production status + if self.production_status is not None and global_indices and field_names: + field_indices = [self.field_name_mapping.get(f) for f in field_names] + self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1 # Update field metadata self._update_field_metadata(global_indices, field_schema, custom_backend_meta) @@ -641,8 +633,7 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te if partition_global_index.numel() == 0: empty_status = self.consumption_status[task_name].new_zeros(0) return partition_global_index, empty_status - with self.data_status_lock: - self.ensure_samples_capacity(max(partition_global_index) + 1) + self.ensure_samples_capacity(max(partition_global_index) + 1) consumption_status = self.consumption_status[task_name][partition_global_index] else: consumption_status = self.consumption_status[task_name] @@ -730,23 +721,22 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]: if field_name not in self.field_name_mapping: return [] - with self.data_status_lock: - row_mask = torch.ones(self.allocated_samples_num, dtype=torch.bool) + row_mask = torch.ones(self.allocated_samples_num, dtype=torch.bool) - # Apply consumption filter (exclude already consumed samples) - _, consumption_status = self.get_consumption_status(task_name, mask=False) - if consumption_status is not None: - unconsumed_mask = consumption_status == 0 - row_mask &= unconsumed_mask + # Apply consumption filter (exclude already consumed samples) + 
_, consumption_status = self.get_consumption_status(task_name, mask=False) + if consumption_status is not None: + unconsumed_mask = consumption_status == 0 + row_mask &= unconsumed_mask - # Create column mask for requested fields - col_mask = torch.zeros(self.allocated_fields_num, dtype=torch.bool) - field_indices = [self.field_name_mapping[field] for field in field_names] - if field_indices: - col_mask[field_indices] = True + # Create column mask for requested fields + col_mask = torch.zeros(self.allocated_fields_num, dtype=torch.bool) + field_indices = [self.field_name_mapping[field] for field in field_names] + if field_indices: + col_mask[field_indices] = True - # Filter production status by masks - relevant_status = self.production_status[row_mask][:, col_mask] + # Filter production status by masks + relevant_status = self.production_status[row_mask][:, col_mask] # Check if all required fields are ready for each sample all_fields_ready = torch.all(relevant_status, dim=1) @@ -878,32 +868,20 @@ def to_snapshot(self): Get a snapshot of partition status information. 
Returns: - DataPartitionStatus object without threading.Lock() + DataPartitionStatus object """ - def _perform_copy(): - cls = self.__class__ - snapshot = cls.__new__(cls) - - for name, value in self.__dict__.items(): - if name == "data_status_lock": - continue - - if isinstance(value, Tensor): - new_val = value.clone().detach() - else: - new_val = copy.deepcopy(value) + cls = self.__class__ + snapshot = cls.__new__(cls) - setattr(snapshot, name, new_val) - return snapshot - - lock_obj = getattr(self, "data_status_lock", None) + for name, value in self.__dict__.items(): + if isinstance(value, Tensor): + new_val = value.clone().detach() + else: + new_val = copy.deepcopy(value) - if lock_obj: - with lock_obj: - return _perform_copy() - else: - return _perform_copy() + setattr(snapshot, name, new_val) + return snapshot def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = True): """Clear all production and optionally consumption data for given global_indexes.""" @@ -1069,7 +1047,7 @@ def _get_partition(self, partition_id: str) -> DataPartitionStatus | None: def get_partition_snapshot(self, partition_id: str) -> DataPartitionStatus | None: """ - Get a copy of partition status information, without threading.Lock(). + Get a copy of partition status information. 
Args: partition_id: ID of the partition to retrieve @@ -1622,8 +1600,7 @@ def kv_retrieve_meta( partition.keys_mapping[keys[none_indexes[i]]] = batch_global_indexes[i] partition.revert_keys_mapping[batch_global_indexes[i]] = keys[none_indexes[i]] - with partition.data_status_lock: - partition.ensure_samples_capacity(max(batch_global_indexes) + 1) + partition.ensure_samples_capacity(max(batch_global_indexes) + 1) verified_global_indexes = [idx for idx in global_indexes if idx is not None] assert len(verified_global_indexes) == len(keys) From 6330a54fc8f01b2aaa8374569cf6ee95f61579ff Mon Sep 17 00:00:00 2001 From: yxstev Date: Thu, 4 Jun 2026 18:56:32 +0800 Subject: [PATCH 3/5] clean up data_status_update_socket Signed-off-by: yxstev --- tests/test_async_simple_storage_manager.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index da6f48f8..a6875600 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -68,7 +68,6 @@ async def mock_async_storage_manager(): manager.config = config manager.controller_info = controller_info manager.storage_unit_infos = storage_unit_infos - manager.data_status_update_socket = None manager.controller_handshake_socket = None manager.zmq_context = None @@ -257,7 +256,6 @@ async def test_get_data_routes_from_hash(): manager.storage_manager_id = "test_get" manager.storage_unit_infos = storage_unit_infos manager.controller_info = None - manager.data_status_update_socket = None manager.controller_handshake_socket = None manager.zmq_context = None @@ -310,7 +308,6 @@ async def test_clear_data_routes_from_hash(): manager.storage_manager_id = "test_clear" manager.storage_unit_infos = storage_unit_infos manager.controller_info = None - manager.data_status_update_socket = None manager.controller_handshake_socket = None manager.zmq_context = None @@ -361,7 +358,6 @@ async def 
test_hash_routing_stable_across_batch_sizes(): manager.storage_manager_id = "test_hash_batch" manager.storage_unit_infos = storage_unit_infos manager.controller_info = None - manager.data_status_update_socket = None manager.controller_handshake_socket = None manager.zmq_context = None @@ -422,7 +418,6 @@ async def test_hash_routing_stable_reversed_order(): manager.storage_manager_id = "test_hash_order" manager.storage_unit_infos = storage_unit_infos manager.controller_info = None - manager.data_status_update_socket = None manager.controller_handshake_socket = None manager.zmq_context = None From dd57a642623ce052dfc0df1078fae72bd5e2a2ce Mon Sep 17 00:00:00 2001 From: yxstev Date: Thu, 4 Jun 2026 19:42:46 +0800 Subject: [PATCH 4/5] fix NOTIFY_DATA_UPDATE_ACK response message Signed-off-by: yxstev --- transfer_queue/controller.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index aed48fc2..c31810a3 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -2019,9 +2019,11 @@ def _process_request(self): self._metrics.record_samples("NOTIFY_DATA_UPDATE", len(global_indexes)) logger.debug(f"[{self.controller_id}]: Updated production status for partition {partition_id}") + # Send acknowledgment response_msg = ZMQMessage.create( request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK, sender_id=self.controller_id, + receiver_id=request_msg.sender_id, body={ "controller_id": self.controller_id, "partition_id": partition_id, From 7d284d32ec7b779fca1b40dc2415fdadf66fe108 Mon Sep 17 00:00:00 2001 From: yxstev Date: Fri, 5 Jun 2026 14:28:10 +0800 Subject: [PATCH 5/5] fix request_msg if-else order Signed-off-by: yxstev --- transfer_queue/controller.py | 60 ++++++++++++++++++------------------ 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index c31810a3..1182a44c 100644 --- a/transfer_queue/controller.py +++ 
b/transfer_queue/controller.py @@ -1790,6 +1790,36 @@ def _process_request(self): body={"metadata": metadata}, ) + elif request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE: + with monitor.measure(op_type="NOTIFY_DATA_UPDATE"): + message_data = request_msg.body + partition_id = message_data.get("partition_id") + global_indexes = message_data.get("global_indexes", []) + + # Update production status + success = self.update_production_status( + partition_id=partition_id, + global_indexes=global_indexes, + field_schema=message_data.get("field_schema", {}), + custom_backend_meta=message_data.get("custom_backend_meta", {}), + ) + if success: + if self._metrics is not None: + self._metrics.record_samples("NOTIFY_DATA_UPDATE", len(global_indexes)) + logger.debug(f"[{self.controller_id}]: Updated production status for partition {partition_id}") + + # Send acknowledgment + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body={ + "controller_id": self.controller_id, + "partition_id": partition_id, + "success": success, + }, + ) + elif request_msg.request_type == ZMQRequestType.GET_PARTITION_META: with monitor.measure(op_type="GET_PARTITION_META"): params = request_msg.body @@ -2001,36 +2031,6 @@ def _process_request(self): body={"partition_info": partition_info, "message": message}, ) - elif request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE: - with monitor.measure(op_type="NOTIFY_DATA_UPDATE"): - message_data = request_msg.body - partition_id = message_data.get("partition_id") - global_indexes = message_data.get("global_indexes", []) - - # Update production status - success = self.update_production_status( - partition_id=partition_id, - global_indexes=global_indexes, - field_schema=message_data.get("field_schema", {}), - custom_backend_meta=message_data.get("custom_backend_meta", {}), - ) - if success: - if self._metrics is not None: - 
self._metrics.record_samples("NOTIFY_DATA_UPDATE", len(global_indexes)) - logger.debug(f"[{self.controller_id}]: Updated production status for partition {partition_id}") - - # Send acknowledgment - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK, - sender_id=self.controller_id, - receiver_id=request_msg.sender_id, - body={ - "controller_id": self.controller_id, - "partition_id": partition_id, - "success": success, - }, - ) - self.request_handle_socket.send_multipart([identity, *response_msg.serialize()]) def get_zmq_server_info(self) -> ZMQServerInfo: