From ddef1cf650900876e81f6148714cb6077de19ae1 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 16 Mar 2026 17:00:36 +0800 Subject: [PATCH 1/9] remove unnecessary copy Signed-off-by: 0oshowero0 --- tests/test_async_simple_storage_manager.py | 4 +--- tests/test_client.py | 4 ++-- tests/test_simple_storage_unit.py | 6 +++--- transfer_queue/controller.py | 6 +++--- transfer_queue/storage/managers/base.py | 2 +- .../storage/managers/simple_backend_manager.py | 12 +++--------- 6 files changed, 13 insertions(+), 21 deletions(-) diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index 7d697849..0108bdce 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -128,10 +128,8 @@ async def test_async_storage_manager_mock_operations(mock_async_storage_manager) manager._put_to_single_storage_unit = AsyncMock() manager._get_from_single_storage_unit = AsyncMock( return_value=( - [0, 1], ["test_field"], {"test_field": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]}, - b"this is the serialized message object.", ) ) manager._clear_single_storage_unit = AsyncMock() @@ -286,7 +284,7 @@ async def fake_get(global_indexes, fields, target_storage_unit=None, **kwargs): su = target_storage_unit called_with[su] = list(global_indexes) tensors = [torch.zeros(2) for _ in global_indexes] - return global_indexes, fields, {"f": tensors}, b"" + return fields, {"f": tensors} manager._get_from_single_storage_unit = fake_get diff --git a/tests/test_client.py b/tests/test_client.py index ccf8e8b9..b9d10397 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -91,7 +91,7 @@ def _handle_requests(self): try: socks = dict(poller.poll(100)) # 100ms timeout if self.request_socket in socks: - messages = self.request_socket.recv_multipart() + messages = self.request_socket.recv_multipart(copy=False) identity = messages.pop(0) serialized_msg = messages request_msg = 
ZMQMessage.deserialize(serialized_msg) @@ -332,7 +332,7 @@ def _handle_data_requests(self): try: socks = dict(poller.poll(100)) # 100ms timeout if self.data_socket in socks: - messages = self.data_socket.recv_multipart() + messages = self.data_socket.recv_multipart(copy=False) identity = messages.pop(0) serialized_msg = messages msg = ZMQMessage.deserialize(serialized_msg) diff --git a/tests/test_simple_storage_unit.py b/tests/test_simple_storage_unit.py index b18f8dd1..01ef2027 100644 --- a/tests/test_simple_storage_unit.py +++ b/tests/test_simple_storage_unit.py @@ -47,7 +47,7 @@ def send_put(self, client_id, global_indexes, field_data): body={"global_indexes": global_indexes, "data": field_data}, ) self.socket.send_multipart(msg.serialize()) - return ZMQMessage.deserialize(self.socket.recv_multipart()) + return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False)) def send_get(self, client_id, global_indexes, fields): msg = ZMQMessage.create( @@ -56,7 +56,7 @@ def send_get(self, client_id, global_indexes, fields): body={"global_indexes": global_indexes, "fields": fields}, ) self.socket.send_multipart(msg.serialize()) - return ZMQMessage.deserialize(self.socket.recv_multipart()) + return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False)) def send_clear(self, client_id, global_indexes): msg = ZMQMessage.create( @@ -65,7 +65,7 @@ def send_clear(self, client_id, global_indexes): body={"global_indexes": global_indexes}, ) self.socket.send_multipart(msg.serialize()) - return ZMQMessage.deserialize(self.socket.recv_multipart()) + return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False)) def close(self): self.socket.close() diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index b0ef7572..f1410480 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -1715,7 +1715,7 @@ def _wait_connection(self): if self.handshake_socket in socks: try: - messages = self.handshake_socket.recv_multipart() + 
messages = self.handshake_socket.recv_multipart(copy=False) identity = messages.pop(0) serialized_msg = messages request_msg = ZMQMessage.deserialize(serialized_msg) @@ -1784,7 +1784,7 @@ def _process_request(self): perf_monitor = IntervalPerfMonitor(caller_name=self.controller_id) while True: - messages = self.request_handle_socket.recv_multipart() + messages = self.request_handle_socket.recv_multipart(copy=False) identity = messages.pop(0) serialized_msg = messages request_msg = ZMQMessage.deserialize(serialized_msg) @@ -2027,7 +2027,7 @@ def _update_data_status(self): perf_monitor = IntervalPerfMonitor(caller_name=self.controller_id) while True: - messages = self.data_status_update_socket.recv_multipart() + messages = self.data_status_update_socket.recv_multipart(copy=False) identity = messages.pop(0) serialized_msg = messages request_msg = ZMQMessage.deserialize(serialized_msg) diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index d33c591f..bf6cdb9d 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -158,7 +158,7 @@ def _do_handshake_with_controller(self) -> None: if (socks.get(self.controller_handshake_socket, 0) & zmq.POLLIN) and pending_connection: try: - response_msg = ZMQMessage.deserialize(self.controller_handshake_socket.recv_multipart()) + response_msg = ZMQMessage.deserialize(self.controller_handshake_socket.recv_multipart(copy=False)) if response_msg.request_type == ZMQRequestType.HANDSHAKE_ACK: is_connected = True diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 3d66dcb8..2312950d 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -24,7 +24,6 @@ from typing import Any, Callable, NamedTuple from uuid import uuid4 -import numpy as np import torch import zmq from omegaconf import 
DictConfig @@ -345,11 +344,6 @@ def _pack_field_values(values: list) -> torch.Tensor | NonTensorStack: if all(v.shape == values[0].shape for v in values): return torch.stack(values) return torch.nested.as_nested_tensor(values, layout=torch.jagged) - if all(isinstance(v, np.ndarray) for v in values): - # Detach numpy arrays from ZMQ frame memory (copy=False path). - # Use per-element .copy() instead of np.stack because string-dtype - # arrays may have heterogeneous shapes. - return NonTensorStack(*[v.copy() for v in values]) return NonTensorStack(*values) async def get_data(self, metadata: BatchMeta) -> TensorDict: @@ -392,7 +386,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: n = len(metadata.global_indexes) ordered_data: dict[str, list] = {field: [None] * n for field in metadata.field_names} - for (su_id, group), (_, fields, su_data, _) in zip(routing.items(), results, strict=True): + for (su_id, group), (fields, su_data) in zip(routing.items(), results, strict=True): for field in fields: for i, pos in enumerate(group.batch_positions): ordered_data[field][pos] = su_data[field][i] @@ -423,7 +417,7 @@ async def _get_from_single_storage_unit( if response_msg.request_type == ZMQRequestType.GET_DATA_RESPONSE: storage_unit_data = response_msg.body["data"] - return global_indexes, fields, storage_unit_data, messages + return fields, storage_unit_data else: raise RuntimeError( f"Failed to get data from storage unit {target_storage_unit}: " @@ -484,7 +478,7 @@ async def _clear_single_storage_unit(self, global_indexes, target_storage_unit=N ) await socket.send_multipart(request_msg.serialize()) - messages = await socket.recv_multipart() + messages = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(messages) if response_msg.request_type != ZMQRequestType.CLEAR_DATA_RESPONSE: From 121339be1d5fcc2f7f625307b1e5f0d83bba3e95 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 16 Mar 2026 17:25:30 +0800 Subject: [PATCH 2/9] fix 
comments Signed-off-by: 0oshowero0 --- transfer_queue/storage/managers/simple_backend_manager.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 2312950d..1fc82bf1 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -331,10 +331,11 @@ async def _put_to_single_storage_unit( @staticmethod def _pack_field_values(values: list) -> torch.Tensor | NonTensorStack: - """Pack a list of per-sample values into a batched container. + """ + Pack a list of per-sample values into a batched container. - A memory copy is intentional here: it detaches received tensors from - zero-copy buffers, gives them their own lifetime, and ensures writability. + For tensor values, this performs a memory copy via stacking or nested tensor creation. + Non-tensor values are grouped into a ``NonTensorStack`` without copying. """ if not values: raise ValueError("_pack_field_values received empty values list; caller should filter empty batches") From 7297090def01e8cde8ba793f479dcf6e94466f00 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 17 Mar 2026 14:51:39 +0800 Subject: [PATCH 3/9] try: fix metadata error Signed-off-by: 0oshowero0 --- tests/test_metadata.py | 14 ----------- transfer_queue/metadata.py | 49 ++++++++++++++++++++------------------ 2 files changed, 26 insertions(+), 37 deletions(-) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index c6f7828b..195eaf3b 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -136,20 +136,6 @@ def test_size_property(self): assert batch.size == 5 assert len(batch) == 5 - def test_add_fields_empty_batch_is_non_tensor_unknown(self): - """add_fields with empty field value leaves is_non_tensor as None (unknown). - - When a field has zero samples, we cannot determine the field type from data. 
- is_non_tensor must not default to False (which would incorrectly imply Tensor). - """ - from tensordict import TensorDict - - batch = BatchMeta.empty() - # TensorDict with an empty tensor of batch_size=0 - empty_td = TensorDict({"empty_field": torch.empty(0, 2)}, batch_size=0) - batch.add_fields(empty_td) - assert batch.field_schema["empty_field"]["is_non_tensor"] is None - def test_pickle_roundtrip_preserves_batchmeta(self): """BatchMeta must survive pickle round-trip with all fields intact.""" import pickle diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 24248e97..9c04e23d 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -146,47 +146,46 @@ def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]: field_schema: dict[str, dict[str, Any]] = {} batch_size = data.batch_size[0] + if batch_size == 0: + logger.warning("Trying to extract field schema for empty batch. No action is taken.") + for field_name, value in data.items(): is_tensor = isinstance(value, torch.Tensor) is_nested = is_tensor and value.is_nested - first_item = None if is_nested: unbound = value.unbind() - first_item = unbound[0] if unbound else None + if len(unbound) != batch_size: + raise ValueError( + f"Inconsistent batch dimension for field '{field_name}': " + f"expected batch_size[0]={batch_size}, got nested tensor composed of {len(unbound)} tensors" + ) + first_item = unbound[0] elif is_tensor: - first_item = value[0] if value.shape[0] > 0 else None - else: - first_item = value[0] if len(value) > 0 else None - - # Determine is_non_tensor: when first_item is None (empty field), cannot determine type - if first_item is None: - is_non_tensor = None - else: - is_non_tensor = not is_tensor - - # Determine the shape of each sample (excluding batch dimension) - # When TensorDict converts a Python list to tensor, the first dimension equals batch_size - # We need to strip this batch dimension to get per-sample shape - if isinstance(value, 
torch.Tensor) and not is_nested and value.shape[0] > 0: if value.shape[0] != batch_size: raise ValueError( f"Inconsistent batch dimension for field '{field_name}': " f"expected batch_size[0]={batch_size}, got value.shape[0]={value.shape[0]}" ) - if len(value.shape) > 1: - sample_shape = value.shape[1:] - else: - # When input is 1D tensor, manually set to torch.Size([1]). - sample_shape = torch.Size([1]) + first_item = value[0] + else: + if len(value) != batch_size: + raise ValueError( + f"Inconsistent batch dimension for field '{field_name}': " + f"expected batch_size[0]={batch_size}, got len(value)={len(value)}" + ) + first_item = value[0] + + if is_tensor or isinstance(first_item, np.ndarray): + sample_shape = first_item.shape else: - sample_shape = getattr(first_item, "shape", None) if first_item is not None else None + sample_shape = None field_meta = { "dtype": getattr(first_item, "dtype", type(first_item) if first_item is not None else None), "shape": sample_shape, "is_nested": is_nested, - "is_non_tensor": is_non_tensor, + "is_non_tensor": not is_tensor, } # For nested tensors, record per-sample shapes @@ -441,6 +440,10 @@ def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "Ba set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME. Default is True. """ batch_size = tensor_dict.batch_size[0] + + if batch_size == 0: + logger.warning(f"Input TensorDict is empty with batch_size={batch_size}. 
No action is taken.") + if batch_size != self.size: raise ValueError(f"add_fields batch size mismatch: self.size={self.size} vs tensor_dict={batch_size}") From fc0ca4a97f76d426df9c6d737519823b438a2ed7 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 17 Mar 2026 19:44:05 +0800 Subject: [PATCH 4/9] fix more metadata error Signed-off-by: 0oshowero0 --- transfer_queue/controller.py | 93 ++++++++++++++----------- transfer_queue/metadata.py | 2 +- transfer_queue/storage/managers/base.py | 7 ++ 3 files changed, 61 insertions(+), 41 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index f1410480..537f7fc9 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -202,6 +202,7 @@ class FieldMeta: indexed by global_idx, O(B_nested) storage. """ + global_indexes: set[int] = field(default_factory=set) dtype: Any = None shape: Optional[tuple] = None # None when is_nested=True is_nested: bool = False @@ -209,15 +210,16 @@ class FieldMeta: per_sample_shapes: dict[int, tuple] = field(default_factory=dict) # {global_idx: shape} - def update(self, incoming: dict[str, Any]) -> None: + def update(self, incoming: dict[str, Any], incoming_global_indexes: list[int]) -> None: """Update this field's metadata from an incoming schema dict. Encapsulates dtype consistency check, shape conflict detection, and automatic is_nested inference. Args: - incoming: Schema dict with optional keys: dtype, shape, is_nested, is_non_tensor. - + incoming: Schema dict with optional keys: + global_indexes, dtype, shape, is_nested, is_non_tensor, per_sample_shape + incoming_global_indexes: global indexes of the inupt meta Raises: ValueError: If incoming dtype conflicts with existing dtype. """ @@ -232,24 +234,55 @@ def update(self, incoming: dict[str, Any]) -> None: f"All batches for the same field must have the same dtype." 
) - # shape consistency check → is_nested inference - new_shape = incoming.get("shape") - if new_shape is not None: - if self.shape is None and not self.is_nested: - self.shape = new_shape - elif self.shape is not None and self.shape != new_shape: + new_is_nested = incoming.get("is_nested") + new_is_non_tensor = incoming.get("is_non_tensor") + + if new_is_nested: + new_per_sample_shapes = incoming.get("per_sample_shapes", None) + if new_per_sample_shapes is None: + raise ValueError("Receiving a nested field without 'per_sample_shapes'!") + if not self.is_nested: + # new input is nested, but original is regular tensor. + # We need to write old shape into per_sample_shampes + assert self.shape is not None + for gi in self.global_indexes: + self.per_sample_shapes[gi] = self.shape self.is_nested = True self.shape = None - # explicit is_nested flag overrides inference - if incoming.get("is_nested"): - self.is_nested = True - self.shape = None + # Update newly provided per_sample_shapes + self.per_sample_shapes.update(new_per_sample_shapes) + + else: + if not new_is_non_tensor: + # newly input is regular tensor + new_shape = incoming.get("shape", None) + if new_shape is None: + raise ValueError("Receiving a regular tensor without 'shape'!") + if self.is_nested: + # we need to update incoming shape into per_sample_shapes + for gi in incoming_global_indexes: + self.per_sample_shapes[gi] = new_shape + else: + if not self.is_non_tensor: + # original data is also regular tensor + assert self.shape is not None + if self.shape != new_shape: + for gi in self.global_indexes: + self.per_sample_shapes[gi] = self.shape + for gi in incoming_global_indexes: + self.per_sample_shapes[gi] = new_shape + + self.shape = None + self.is_nested = True + + self.global_indexes.update(incoming_global_indexes) def remove_samples(self, indexes: list[int]): """Remove sample-level data for the given indexes.""" for idx in indexes: self.per_sample_shapes.pop(idx, None) + 
self.global_indexes.discard(idx) # After removing samples, check if we can update is_nested and shape # If per_sample_shapes is empty or all remaining shapes are the same, @@ -260,7 +293,9 @@ def remove_samples(self, indexes: list[int]): self.shape = None else: # Check if all remaining shapes are the same - remaining_shapes = set(self.per_sample_shapes.values()) + remaining_shapes = set( + tuple(shape) if isinstance(shape, list) else shape for shape in self.per_sample_shapes.values() + ) if len(remaining_shapes) == 1: # All remaining samples have the same shape - update to non-nested self.is_nested = False @@ -278,6 +313,7 @@ def to_batch_schema(self, batch_global_indexes: list[int]) -> dict[str, Any]: } if self.is_nested and self.per_sample_shapes: schema["per_sample_shapes"] = [self.per_sample_shapes.get(gi) for gi in batch_global_indexes] + return schema @@ -540,38 +576,15 @@ def _update_field_metadata( for field_name, meta in field_schema.items(): if field_name not in self.field_metadata: self.field_metadata[field_name] = FieldMeta( + global_indexes=set(global_indexes), dtype=meta.get("dtype"), shape=meta.get("shape"), is_nested=meta.get("is_nested", False), is_non_tensor=meta.get("is_non_tensor", False), + per_sample_shapes=meta.get("per_sample_shapes", {}), ) else: - # Track if is_nested changed from False to True during update - was_not_nested = not self.field_metadata[field_name].is_nested - # Save old shape before update (for filling per_sample_shapes of existing samples) - old_shape = self.field_metadata[field_name].shape - self.field_metadata[field_name].update(meta) - # If is_nested became True due to shape mismatch, capture shapes for all samples - if was_not_nested and self.field_metadata[field_name].is_nested: - col_meta = self.field_metadata[field_name] - new_shape = meta.get("shape") - # Fill new samples with new shape - if new_shape is not None: - for gi in global_indexes: - col_meta.per_sample_shapes[gi] = new_shape - # Fill existing samples with 
old shape - if old_shape is not None: - for gi in self.global_indexes: - if gi not in col_meta.per_sample_shapes: - col_meta.per_sample_shapes[gi] = old_shape - - # nested per-sample shapes - per_sample_shapes = meta.get("per_sample_shapes") - if per_sample_shapes: - col_meta = self.field_metadata[field_name] - for gi, shape in zip(global_indexes, per_sample_shapes, strict=False): - if shape is not None: - col_meta.per_sample_shapes[gi] = shape + self.field_metadata[field_name].update(meta, global_indexes) # custom_backend_meta remains row-oriented storage if custom_backend_meta: diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 9c04e23d..dbd00755 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -176,7 +176,7 @@ def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]: ) first_item = value[0] - if is_tensor or isinstance(first_item, np.ndarray): + if is_tensor: sample_shape = first_item.shape else: sample_shape = None diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index bf6cdb9d..b60d1e8d 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -219,6 +219,13 @@ async def notify_data_update( try: sock.connect(self.controller_info.to_addr("data_status_update_socket")) + # FIXME: convert per_sample_shapes into dict + for field in field_schema.values(): + per_sample_shapes = field.get("per_sample_shapes", None) + if per_sample_shapes: + per_sample_shapes = {global_indexes[i]: per_sample_shapes[i] for i in range(len(global_indexes))} + field["per_sample_shapes"] = per_sample_shapes + request_msg = ZMQMessage.create( request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, # type: ignore[arg-type] sender_id=self.storage_manager_id, From 7c551146a7659975bcd35a79e43887cdee784389 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 17 Mar 2026 20:21:28 +0800 Subject: [PATCH 5/9] fix more metadata error Signed-off-by: 
0oshowero0 --- tests/test_controller_data_partitions.py | 144 ++++++++++++++++++++++- 1 file changed, 139 insertions(+), 5 deletions(-) diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index 904f4320..72b0f319 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -521,7 +521,14 @@ def test_shape_conflict_promotes_to_nested(self): def test_nested_per_sample_shapes(self): partition = self._make_partition() - schema = {"f3": {"dtype": "torch.float32", "shape": None, "is_nested": True, "per_sample_shapes": [(3,), (5,)]}} + schema = { + "f3": { + "dtype": "torch.float32", + "shape": None, + "is_nested": True, + "per_sample_shapes": {10: (3,), 11: (5,)}, + } + } partition._update_field_metadata([10, 11], schema) assert partition.field_metadata["f3"].is_nested is True assert partition.field_metadata["f3"].per_sample_shapes == {10: (3,), 11: (5,)} @@ -1056,10 +1063,12 @@ def test_remove_samples(self): fm = FieldMeta(is_nested=True) fm.per_sample_shapes = {0: (3,), 1: (5,), 2: (7,)} + fm.global_indexes = {0, 1, 2} fm.remove_samples([0, 2]) assert fm.per_sample_shapes == {} assert fm.shape == (5,) assert not fm.is_nested + assert fm.global_indexes == {1} # Removing non-existent index should not raise fm.remove_samples([99]) @@ -1099,13 +1108,138 @@ def test_update_dtype_conflict(self): from transfer_queue.controller import FieldMeta fm = FieldMeta(dtype="torch.int32", shape=(16,)) + fm.global_indexes = {0} with pytest.raises(ValueError, match="dtype mismatch"): - fm.update({"dtype": "torch.float64"}) + fm.update({"dtype": "torch.float64"}, [1]) + + def test_update_regular_to_regular_different_shape_becomes_nested(self): + """Test that two regular tensor updates with different shapes promotes to nested.""" + from transfer_queue.controller import FieldMeta + + # Start with a regular tensor + fm = FieldMeta(dtype="torch.float32", shape=(256,), is_nested=False) + fm.global_indexes = 
{0} + + # Update with a different shape regular tensor + fm.update({"dtype": "torch.float32", "shape": (128,)}, [1]) + + # Should now be nested with both shapes saved + assert fm.is_nested is True + assert fm.shape is None + assert fm.per_sample_shapes[0] == (256,) + assert fm.per_sample_shapes[1] == (128,) + assert fm.global_indexes == {0, 1} + + def test_update_regular_to_nested_promotes_nested(self): + """Test that updating from regular tensor to nested tensor correctly promotes.""" + from transfer_queue.controller import FieldMeta + + # Start with a regular tensor + fm = FieldMeta(dtype="torch.float32", shape=(256,), is_nested=False) + fm.global_indexes = {0, 1, 2} + + # Update with a nested tensor (different per_sample_shapes) + incoming = {"dtype": "torch.float32", "is_nested": True, "per_sample_shapes": {3: (128,), 4: (512,)}} + fm.update(incoming, [3, 4]) + + # Should now be nested + assert fm.is_nested is True + # Original shape should be saved in per_sample_shapes + assert fm.per_sample_shapes[0] == (256,) + assert fm.per_sample_shapes[1] == (256,) + assert fm.per_sample_shapes[2] == (256,) + # New shapes should be added + assert fm.per_sample_shapes[3] == (128,) + assert fm.per_sample_shapes[4] == (512,) + # shape should be None for nested + assert fm.shape is None + # global_indexes should be updated + assert fm.global_indexes == {0, 1, 2, 3, 4} - def test_update_shape_conflict_promotes_nested(self): + def test_update_nested_to_regular_merges_shapes(self): + """Test that updating from nested to regular tensor adds new shapes to per_sample_shapes.""" from transfer_queue.controller import FieldMeta - fm = FieldMeta(dtype="torch.float32", shape=(256,)) - fm.update({"dtype": "torch.float32", "shape": (128,)}) + # Start with a nested tensor + fm = FieldMeta(dtype="torch.float32", shape=None, is_nested=True) + fm.per_sample_shapes = {0: (128,), 1: (512,)} + fm.global_indexes = {0, 1} + + # Update with a regular tensor (all same shape) + incoming = {"dtype": 
"torch.float32", "is_nested": False, "shape": (256,)} + fm.update(incoming, [2, 3]) + + # Once nested, stays nested (historical data is nested) assert fm.is_nested is True + # Old shapes should remain unchanged + assert fm.per_sample_shapes[0] == (128,) + assert fm.per_sample_shapes[1] == (512,) + # New shapes should be added to per_sample_shapes + assert fm.per_sample_shapes[2] == (256,) + assert fm.per_sample_shapes[3] == (256,) + # global_indexes should be updated + assert fm.global_indexes == {0, 1, 2, 3} + + def test_remove_samples_different_removed_becomes_regular(self): + """Test that removing samples with different shapes converts back to + regular tensor when remaining shapes are same. + """ + from transfer_queue.controller import FieldMeta + + # Start with nested field having different shapes: {0: (256,), 1: (128,), 2: (256,)} + fm = FieldMeta(dtype="torch.float32", shape=None, is_nested=True) + fm.per_sample_shapes = {0: (256,), 1: (128,), 2: (256,)} + fm.global_indexes = {0, 1, 2} + + # Remove the sample with different shape (index 1) + fm.remove_samples([1]) + + # After removing index 1, remaining shapes are all (256,) + # Should convert back to non-nested + assert fm.is_nested is False + assert fm.shape == (256,) + # per_sample_shapes should be cleared + assert fm.per_sample_shapes == {} + assert fm.global_indexes == {0, 2} + + def test_remove_samples_all_removed_resets_state(self): + """Test that removing all samples resets the field meta.""" + from transfer_queue.controller import FieldMeta + + fm = FieldMeta(dtype="torch.float32", shape=None, is_nested=True) + fm.per_sample_shapes = {0: (256,), 1: (128,)} + fm.global_indexes = {0, 1} + + fm.remove_samples([0, 1]) + + # All samples removed - should reset + assert fm.is_nested is False assert fm.shape is None + assert fm.per_sample_shapes == {} + assert fm.global_indexes == set() + + def test_update_nested_with_partial_overlap(self): + """Test update with nested tensor where some indexes already 
exist.""" + from transfer_queue.controller import FieldMeta + + # Start with a regular tensor + fm = FieldMeta(dtype="torch.float32", shape=(256,), is_nested=False) + fm.global_indexes = {0, 1} + + # Update with nested tensor that includes overlapping indexes + incoming = { + "dtype": "torch.float32", + "is_nested": True, + "per_sample_shapes": {1: (128,), 2: (512,)}, # 1 overlaps + } + fm.update(incoming, [1, 2]) + + # Should now be nested + assert fm.is_nested is True + # Original shape for index 0 should be saved + assert fm.per_sample_shapes[0] == (256,) + # Index 1 should be updated with new shape + assert fm.per_sample_shapes[1] == (128,) + # Index 2 is new + assert fm.per_sample_shapes[2] == (512,) + assert fm.global_indexes == {0, 1, 2} From 7afd72692ea406f93d780e21d3553207539ed742 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 17 Mar 2026 20:34:27 +0800 Subject: [PATCH 6/9] fix mooncake backend Signed-off-by: 0oshowero0 --- transfer_queue/metadata.py | 3 +++ transfer_queue/storage/clients/mooncake_client.py | 1 + 2 files changed, 4 insertions(+) diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index dbd00755..ac23cb73 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -167,6 +167,9 @@ def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]: f"Inconsistent batch dimension for field '{field_name}': " f"expected batch_size[0]={batch_size}, got value.shape[0]={value.shape[0]}" ) + if len(value.shape) == 1: + logger.warning(f"Receiving 1D tensor for field '{field_name}'. 
Unsqueeze the last dimension.") + value = value.unsqueeze(-1) first_item = value[0] else: if len(value) != batch_size: diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index a6273210..6ab610ee 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -43,6 +43,7 @@ class MooncakeStoreClient(TransferQueueStorageKVClient): """ def __init__(self, config: dict[str, Any]): + super().__init__(config) if not MOONCAKE_STORE_IMPORTED: raise ImportError("Mooncake Store not installed. Please install via: pip install mooncake-transfer-engine") From e8fca14d5f0fc2a49af40c14a4406e33fa1fca38 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 17 Mar 2026 21:40:23 +0800 Subject: [PATCH 7/9] fix Signed-off-by: 0oshowero0 --- transfer_queue/controller.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 537f7fc9..33db3615 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -203,10 +203,10 @@ class FieldMeta: """ global_indexes: set[int] = field(default_factory=set) - dtype: Any = None + dtype: Optional[Any] = None shape: Optional[tuple] = None # None when is_nested=True - is_nested: bool = False - is_non_tensor: bool = False + is_nested: Optional[bool] = None + is_non_tensor: Optional[bool] = None per_sample_shapes: dict[int, tuple] = field(default_factory=dict) # {global_idx: shape} @@ -241,7 +241,7 @@ def update(self, incoming: dict[str, Any], incoming_global_indexes: list[int]) - new_per_sample_shapes = incoming.get("per_sample_shapes", None) if new_per_sample_shapes is None: raise ValueError("Receiving a nested field without 'per_sample_shapes'!") - if not self.is_nested: + if self.is_nested is not None and not self.is_nested: # new input is nested, but original is regular tensor. 
# We need to write old shape into per_sample_shampes assert self.shape is not None @@ -264,7 +264,7 @@ def update(self, incoming: dict[str, Any], incoming_global_indexes: list[int]) - for gi in incoming_global_indexes: self.per_sample_shapes[gi] = new_shape else: - if not self.is_non_tensor: + if self.is_non_tensor is not None and not self.is_non_tensor: # original data is also regular tensor assert self.shape is not None if self.shape != new_shape: @@ -287,21 +287,24 @@ def remove_samples(self, indexes: list[int]): # After removing samples, check if we can update is_nested and shape # If per_sample_shapes is empty or all remaining shapes are the same, # we should reset is_nested to False and update shape accordingly - if not self.per_sample_shapes: + if len(self.global_indexes) == 0: # All samples removed - reset to non-nested state self.is_nested = False self.shape = None + self.is_non_tensor = None + self.dtype = None else: - # Check if all remaining shapes are the same - remaining_shapes = set( - tuple(shape) if isinstance(shape, list) else shape for shape in self.per_sample_shapes.values() - ) - if len(remaining_shapes) == 1: - # All remaining samples have the same shape - update to non-nested - self.is_nested = False - self.shape = next(iter(remaining_shapes)) - # Clear per-sample shapes since we are no longer nested - self.per_sample_shapes.clear() + if self.is_nested: + # Check if all remaining shapes are the same + remaining_shapes = set( + tuple(shape) if isinstance(shape, list) else shape for shape in self.per_sample_shapes.values() + ) + if len(remaining_shapes) == 1: + # All remaining samples have the same shape - update to non-nested + self.is_nested = False + self.shape = next(iter(remaining_shapes)) + # Clear per-sample shapes since we are no longer nested + self.per_sample_shapes.clear() def to_batch_schema(self, batch_global_indexes: list[int]) -> dict[str, Any]: """Export as a BatchMeta.field_schema-compatible dict for generate_batch_meta.""" 
From bd3891266a43e566a41cebe7a6a67fd76c32f214 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 18 Mar 2026 09:20:25 +0800 Subject: [PATCH 8/9] fully remove FieldMeta if empty Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 --- tests/test_controller_data_partitions.py | 491 +++++++++++++++-------- transfer_queue/controller.py | 14 +- transfer_queue/metadata.py | 2 + 3 files changed, 333 insertions(+), 174 deletions(-) diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index 72b0f319..02bacbb5 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -1055,191 +1055,346 @@ def test_kv_retrieve_keys_partial_match(self): assert keys == ["key_1", "key_3"] -class TestFieldMeta: - """Unit tests for FieldMeta dataclass.""" - - def test_remove_samples(self): - from transfer_queue.controller import FieldMeta - - fm = FieldMeta(is_nested=True) - fm.per_sample_shapes = {0: (3,), 1: (5,), 2: (7,)} - fm.global_indexes = {0, 1, 2} - fm.remove_samples([0, 2]) - assert fm.per_sample_shapes == {} - assert fm.shape == (5,) - assert not fm.is_nested - assert fm.global_indexes == {1} - # Removing non-existent index should not raise - fm.remove_samples([99]) - - def test_to_batch_schema_regular(self): - from transfer_queue.controller import FieldMeta - - fm = FieldMeta(dtype="torch.float32", shape=(512,), is_nested=False, is_non_tensor=False) - schema = fm.to_batch_schema([0, 1, 2]) - assert schema == { - "dtype": "torch.float32", - "shape": (512,), - "is_nested": False, - "is_non_tensor": False, - } - assert "per_sample_shapes" not in schema +class TestFieldMetaIntegration: + """Unit tests for DataPartitionStatus integration with FieldMeta. + + Tests that _update_field_metadata correctly updates underlying FieldMeta state, + and clear_data properly handles FieldMeta when partition becomes empty or partially empty. 
+ """ - def test_to_batch_schema_nested(self): - from transfer_queue.controller import FieldMeta + def _make_partition(self): + from transfer_queue.controller import DataPartitionStatus - fm = FieldMeta(dtype="torch.float32", shape=None, is_nested=True) - fm.per_sample_shapes = {0: (3,), 1: (5,), 2: (7,)} - schema = fm.to_batch_schema([0, 2, 1]) - assert schema["is_nested"] is True - assert schema["per_sample_shapes"] == [(3,), (7,), (5,)] + return DataPartitionStatus(partition_id="fieldmeta_integration_test") - def test_to_batch_schema_nested_missing_sample(self): - from transfer_queue.controller import FieldMeta + def test_update_field_metadata_creates_fieldmeta(self): + """Test that _update_field_metadata creates FieldMeta for new fields.""" + partition = self._make_partition() - fm = FieldMeta(dtype="torch.float32", shape=None, is_nested=True) - fm.per_sample_shapes = {0: (3,)} - schema = fm.to_batch_schema([0, 1]) - assert schema["per_sample_shapes"] == [(3,), None] + # Update with some field metadata + partition._update_field_metadata( + global_indexes=[0, 1, 2], + field_schema={ + "input_ids": {"dtype": "torch.int32", "shape": (512,), "is_nested": False, "is_non_tensor": False}, + "attention_mask": {"dtype": "torch.bool", "shape": (512,), "is_nested": False, "is_non_tensor": False}, + }, + ) - def test_update_dtype_conflict(self): + # Verify FieldMeta was created for both fields + assert "input_ids" in partition.field_metadata + assert "attention_mask" in partition.field_metadata + + # Verify FieldMeta properties + input_ids_meta = partition.field_metadata["input_ids"] + assert input_ids_meta.dtype == "torch.int32" + assert input_ids_meta.shape == (512,) + assert input_ids_meta.is_nested is False + assert input_ids_meta.is_non_tensor is False + assert input_ids_meta.global_indexes == {0, 1, 2} + + attention_mask_meta = partition.field_metadata["attention_mask"] + assert attention_mask_meta.dtype == "torch.bool" + assert attention_mask_meta.shape == (512,) + 
+ def test_update_field_metadata_incremental_add(self): + """Test that _update_field_metadata correctly handles incremental field additions.""" + partition = self._make_partition() + + # First update + partition._update_field_metadata( + global_indexes=[0, 1], + field_schema={"field_a": {"dtype": "torch.int32", "shape": (16,)}}, + ) + + field_meta = partition.field_metadata["field_a"] + assert field_meta.dtype == "torch.int32" + assert field_meta.shape == (16,) + assert field_meta.global_indexes == {0, 1} + + # Second update with new indexes + partition._update_field_metadata( + global_indexes=[2, 3], + field_schema={"field_a": {"dtype": "torch.int32", "shape": (16,)}}, + ) + + # Verify indexes were added + assert field_meta.global_indexes == {0, 1, 2, 3} + assert field_meta.dtype == "torch.int32" + assert field_meta.shape == (16,) + + def test_update_field_metadata_dtype_conflict_raises(self): + """Test that _update_field_metadata raises error on dtype conflict.""" import pytest - from transfer_queue.controller import FieldMeta + partition = self._make_partition() + + # First update + partition._update_field_metadata( + global_indexes=[0], + field_schema={"field_x": {"dtype": "torch.int32", "shape": (16,)}}, + ) - fm = FieldMeta(dtype="torch.int32", shape=(16,)) - fm.global_indexes = {0} + # Second update with conflicting dtype should raise with pytest.raises(ValueError, match="dtype mismatch"): - fm.update({"dtype": "torch.float64"}, [1]) + partition._update_field_metadata( + global_indexes=[1], + field_schema={"field_x": {"dtype": "torch.float64", "shape": (16,)}}, + ) + + def test_update_field_metadata_shape_conflict_promotes_nested(self): + """Test that shape conflict promotes field to nested.""" + partition = self._make_partition() - def test_update_regular_to_regular_different_shape_becomes_nested(self): - """Test that two regular tensor updates with different shapes promotes to nested.""" - from transfer_queue.controller import FieldMeta + # First update 
with shape (256,) + partition._update_field_metadata( + global_indexes=[0], + field_schema={"field_nested": {"dtype": "torch.float32", "shape": (256,)}}, + ) - # Start with a regular tensor - fm = FieldMeta(dtype="torch.float32", shape=(256,), is_nested=False) - fm.global_indexes = {0} + # Second update with different shape (128,) + partition._update_field_metadata( + global_indexes=[1], + field_schema={"field_nested": {"dtype": "torch.float32", "shape": (128,)}}, + ) - # Update with a different shape regular tensor - fm.update({"dtype": "torch.float32", "shape": (128,)}, [1]) + field_meta = partition.field_metadata["field_nested"] + # Should now be nested + assert field_meta.is_nested is True + assert field_meta.shape is None + # Both shapes should be tracked + assert 0 in field_meta.per_sample_shapes + assert 1 in field_meta.per_sample_shapes + assert field_meta.per_sample_shapes[0] == (256,) + assert field_meta.per_sample_shapes[1] == (128,) + + def test_update_field_metadata_with_custom_backend_meta(self): + """Test that _update_field_metadata correctly stores custom_backend_meta.""" + partition = self._make_partition() - # Should now be nested with both shapes saved - assert fm.is_nested is True - assert fm.shape is None - assert fm.per_sample_shapes[0] == (256,) - assert fm.per_sample_shapes[1] == (128,) - assert fm.global_indexes == {0, 1} + partition._update_field_metadata( + global_indexes=[0, 1, 2], + field_schema={"field_a": {"dtype": "torch.int32"}}, + custom_backend_meta={ + 0: {"field_a": {"token_count": 100}}, + 1: {"field_a": {"token_count": 200}}, + 2: {"field_a": {"token_count": 300}}, + }, + ) - def test_update_regular_to_nested_promotes_nested(self): - """Test that updating from regular tensor to nested tensor correctly promotes.""" - from transfer_queue.controller import FieldMeta + # Verify custom_backend_meta was stored + assert 0 in partition.field_custom_backend_meta + assert partition.field_custom_backend_meta[0]["field_a"]["token_count"] 
== 100 + assert partition.field_custom_backend_meta[1]["field_a"]["token_count"] == 200 + assert partition.field_custom_backend_meta[2]["field_a"]["token_count"] == 300 - # Start with a regular tensor - fm = FieldMeta(dtype="torch.float32", shape=(256,), is_nested=False) - fm.global_indexes = {0, 1, 2} + def test_update_field_metadata_empty_indexes_is_noop(self): + """Test that _update_field_metadata with empty indexes does nothing.""" + partition = self._make_partition() - # Update with a nested tensor (different per_sample_shapes) - incoming = {"dtype": "torch.float32", "is_nested": True, "per_sample_shapes": {3: (128,), 4: (512,)}} - fm.update(incoming, [3, 4]) + # Empty update should not raise and should not create any field_metadata + partition._update_field_metadata( + global_indexes=[], + field_schema={}, + ) - # Should now be nested - assert fm.is_nested is True - # Original shape should be saved in per_sample_shapes - assert fm.per_sample_shapes[0] == (256,) - assert fm.per_sample_shapes[1] == (256,) - assert fm.per_sample_shapes[2] == (256,) - # New shapes should be added - assert fm.per_sample_shapes[3] == (128,) - assert fm.per_sample_shapes[4] == (512,) - # shape should be None for nested - assert fm.shape is None - # global_indexes should be updated - assert fm.global_indexes == {0, 1, 2, 3, 4} - - def test_update_nested_to_regular_merges_shapes(self): - """Test that updating from nested to regular tensor adds new shapes to per_sample_shapes.""" - from transfer_queue.controller import FieldMeta - - # Start with a nested tensor - fm = FieldMeta(dtype="torch.float32", shape=None, is_nested=True) - fm.per_sample_shapes = {0: (128,), 1: (512,)} - fm.global_indexes = {0, 1} - - # Update with a regular tensor (all same shape) - incoming = {"dtype": "torch.float32", "is_nested": False, "shape": (256,)} - fm.update(incoming, [2, 3]) - - # Once nested, stays nested (historical data is nested) - assert fm.is_nested is True - # Old shapes should remain unchanged 
- assert fm.per_sample_shapes[0] == (128,) - assert fm.per_sample_shapes[1] == (512,) - # New shapes should be added to per_sample_shapes - assert fm.per_sample_shapes[2] == (256,) - assert fm.per_sample_shapes[3] == (256,) - # global_indexes should be updated - assert fm.global_indexes == {0, 1, 2, 3} - - def test_remove_samples_different_removed_becomes_regular(self): - """Test that removing samples with different shapes converts back to - regular tensor when remaining shapes are same. - """ - from transfer_queue.controller import FieldMeta - - # Start with nested field having different shapes: {0: (256,), 1: (128,), 2: (256,)} - fm = FieldMeta(dtype="torch.float32", shape=None, is_nested=True) - fm.per_sample_shapes = {0: (256,), 1: (128,), 2: (256,)} - fm.global_indexes = {0, 1, 2} - - # Remove the sample with different shape (index 1) - fm.remove_samples([1]) - - # After removing index 1, remaining shapes are all (256,) - # Should convert back to non-nested - assert fm.is_nested is False - assert fm.shape == (256,) - # per_sample_shapes should be cleared - assert fm.per_sample_shapes == {} - assert fm.global_indexes == {0, 2} - - def test_remove_samples_all_removed_resets_state(self): - """Test that removing all samples resets the field meta.""" - from transfer_queue.controller import FieldMeta - - fm = FieldMeta(dtype="torch.float32", shape=None, is_nested=True) - fm.per_sample_shapes = {0: (256,), 1: (128,)} - fm.global_indexes = {0, 1} - - fm.remove_samples([0, 1]) - - # All samples removed - should reset - assert fm.is_nested is False - assert fm.shape is None - assert fm.per_sample_shapes == {} - assert fm.global_indexes == set() - - def test_update_nested_with_partial_overlap(self): - """Test update with nested tensor where some indexes already exist.""" - from transfer_queue.controller import FieldMeta - - # Start with a regular tensor - fm = FieldMeta(dtype="torch.float32", shape=(256,), is_nested=False) - fm.global_indexes = {0, 1} - - # Update with 
nested tensor that includes overlapping indexes - incoming = { + assert partition.field_metadata == {} + + def test_clear_data_removes_samples_from_fieldmeta(self): + """Test that clear_data correctly removes samples from FieldMeta.""" + partition = self._make_partition() + + # Set up some data using update_production_status (which properly initializes global_indexes) + partition.update_production_status( + global_indices=[0, 1, 2, 3, 4], + field_names=["field_a", "field_b"], + field_schema={ + "field_a": {"dtype": "torch.int32", "shape": (16,)}, + "field_b": {"dtype": "torch.float32", "shape": (32,)}, + }, + ) + + # Verify initial state + assert partition.field_metadata["field_a"].global_indexes == {0, 1, 2, 3, 4} + assert partition.field_metadata["field_b"].global_indexes == {0, 1, 2, 3, 4} + assert partition.global_indexes == {0, 1, 2, 3, 4} + + # Clear some samples (0, 2, 4) + partition.clear_data([0, 2, 4], clear_consumption=False) + + # Verify samples were removed from FieldMeta + assert partition.field_metadata["field_a"].global_indexes == {1, 3} + assert partition.field_metadata["field_b"].global_indexes == {1, 3} + + def test_clear_data_all_samples_clears_fieldmeta_when_empty_partition(self): + """Test that clear_data clears all FieldMeta when partition becomes empty.""" + partition = self._make_partition() + + # Set up some data using update_production_status + partition.update_production_status( + global_indices=[0, 1, 2], + field_names=["field_a", "field_b"], + field_schema={ + "field_a": {"dtype": "torch.int32", "shape": (16,)}, + "field_b": {"dtype": "torch.float32", "shape": (32,)}, + }, + ) + + # Verify initial state + assert len(partition.field_metadata) == 2 + assert partition.global_indexes == {0, 1, 2} + + # Clear all samples - should clear field_metadata when partition is empty + partition.clear_data([0, 1, 2], clear_consumption=False) + + # After clearing all samples, field_metadata should be cleared + assert partition.field_metadata == {} + + 
def test_clear_data_nested_field_becomes_regular(self): + """Test that nested FieldMeta becomes regular when remaining samples have same shape.""" + partition = self._make_partition() + + # Create nested field with different shapes using update_production_status + partition.update_production_status( + global_indices=[0, 1], + field_names=["nested_field"], + field_schema={"nested_field": {"dtype": "torch.float32", "shape": (256,)}}, + ) + partition.update_production_status( + global_indices=[2], + field_names=["nested_field"], + field_schema={"nested_field": {"dtype": "torch.float32", "shape": (128,)}}, + ) + + # Verify it's nested + assert partition.field_metadata["nested_field"].is_nested is True + assert partition.field_metadata["nested_field"].global_indexes == {0, 1, 2} + assert partition.field_metadata["nested_field"].per_sample_shapes == {0: (256,), 1: (256,), 2: (128,)} + + # Clear the sample with different shape (index 2) + partition.clear_data([2], clear_consumption=False) + + # Now only samples 0, 1 remain with same shape (256,) + # FieldMeta should become non-nested + field_meta = partition.field_metadata["nested_field"] + assert field_meta.is_nested is False + assert field_meta.shape == (256,) + assert field_meta.global_indexes == {0, 1} + assert field_meta.per_sample_shapes == {} + + def test_update_production_status_updates_field_metadata(self): + """Test that update_production_status correctly updates field_metadata via _update_field_metadata.""" + partition = self._make_partition() + + # Use update_production_status (which internally calls _update_field_metadata) + partition.update_production_status( + global_indices=[0, 1, 2], + field_names=["input_ids", "attention_mask"], + field_schema={ + "input_ids": {"dtype": "torch.int32", "shape": (512,), "is_nested": False, "is_non_tensor": False}, + "attention_mask": {"dtype": "torch.bool", "shape": (512,), "is_nested": False, "is_non_tensor": False}, + }, + ) + + # Verify field_metadata was updated + 
assert "input_ids" in partition.field_metadata + assert "attention_mask" in partition.field_metadata + assert partition.field_metadata["input_ids"].dtype == "torch.int32" + assert partition.field_metadata["attention_mask"].dtype == "torch.bool" + + def test_fieldmeta_global_indexes_in_sync_with_partition(self): + """Test that FieldMeta global_indexes stays in sync with partition's global_indexes.""" + partition = self._make_partition() + + # Add data + partition.update_production_status( + global_indices=[0, 1, 2, 3, 4], + field_names=["field_a"], + field_schema={"field_a": {"dtype": "torch.int32", "shape": (16,)}}, + ) + + # Verify sync + assert partition.global_indexes == {0, 1, 2, 3, 4} + assert partition.field_metadata["field_a"].global_indexes == {0, 1, 2, 3, 4} + + def test_fieldmeta_to_batch_schema_regular(self): + """Test that FieldMeta.to_batch_schema works correctly for regular tensors.""" + partition = self._make_partition() + + # Create a regular field + partition.update_production_status( + global_indices=[0, 1, 2], + field_names=["regular_field"], + field_schema={"regular_field": {"dtype": "torch.float32", "shape": (512,), "is_nested": False}}, + ) + + field_meta = partition.field_metadata["regular_field"] + schema = field_meta.to_batch_schema([0, 1, 2]) + + assert schema == { "dtype": "torch.float32", - "is_nested": True, - "per_sample_shapes": {1: (128,), 2: (512,)}, # 1 overlaps + "shape": (512,), + "is_nested": False, + "is_non_tensor": False, } - fm.update(incoming, [1, 2]) + assert "per_sample_shapes" not in schema - # Should now be nested - assert fm.is_nested is True - # Original shape for index 0 should be saved - assert fm.per_sample_shapes[0] == (256,) - # Index 1 should be updated with new shape - assert fm.per_sample_shapes[1] == (128,) - # Index 2 is new - assert fm.per_sample_shapes[2] == (512,) - assert fm.global_indexes == {0, 1, 2} + def test_fieldmeta_to_batch_schema_nested(self): + """Test that FieldMeta.to_batch_schema works 
correctly for nested tensors.""" + partition = self._make_partition() + + # Create nested field + partition.update_production_status( + global_indices=[0, 1], + field_names=["nested_field"], + field_schema={ + "nested_field": {"dtype": "torch.float32", "is_nested": True, "per_sample_shapes": {0: (3,), 1: (5,)}} + }, + ) + + field_meta = partition.field_metadata["nested_field"] + schema = field_meta.to_batch_schema([0, 1]) + + assert schema["is_nested"] is True + assert schema["per_sample_shapes"] == [(3,), (5,)] + + def test_fieldmeta_to_batch_schema_nested_different_order(self): + """Test that FieldMeta.to_batch_schema returns shapes in requested order.""" + partition = self._make_partition() + + # Create nested field + partition.update_production_status( + global_indices=[0, 1, 2], + field_names=["nested_field"], + field_schema={ + "nested_field": { + "dtype": "torch.float32", + "is_nested": True, + "per_sample_shapes": {0: (3,), 1: (5,), 2: (7,)}, + } + }, + ) + + field_meta = partition.field_metadata["nested_field"] + # Request in different order + schema = field_meta.to_batch_schema([2, 0, 1]) + + assert schema["per_sample_shapes"] == [(7,), (3,), (5,)] + + def test_fieldmeta_to_batch_schema_nested_missing_sample(self): + """Test that FieldMeta.to_batch_schema returns None for missing samples.""" + partition = self._make_partition() + + # Create nested field with only one sample + partition.update_production_status( + global_indices=[0], + field_names=["nested_field"], + field_schema={ + "nested_field": {"dtype": "torch.float32", "is_nested": True, "per_sample_shapes": {0: (3,)}} + }, + ) + + field_meta = partition.field_metadata["nested_field"] + # Request samples where one doesn't exist + schema = field_meta.to_batch_schema([0, 1]) + + assert schema["per_sample_shapes"] == [(3,), None] diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 33db3615..38320777 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py 
@@ -285,14 +285,13 @@ def remove_samples(self, indexes: list[int]): self.global_indexes.discard(idx) # After removing samples, check if we can update is_nested and shape - # If per_sample_shapes is empty or all remaining shapes are the same, - # we should reset is_nested to False and update shape accordingly if len(self.global_indexes) == 0: - # All samples removed - reset to non-nested state - self.is_nested = False - self.shape = None + # If no samples remain, fully reset field-level metadata. + self.is_nested = None self.is_non_tensor = None + self.shape = None self.dtype = None + self.per_sample_shapes.clear() else: if self.is_nested: # Check if all remaining shapes are the same @@ -928,8 +927,11 @@ def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = Tr consumption_tensor[indexes_to_release] = 0 self.global_indexes.difference_update(indexes_to_release) - for field_meta in self.field_metadata.values(): + + for field_name, field_meta in self.field_metadata.items(): field_meta.remove_samples(indexes_to_release) + if len(self.global_indexes) == 0: + self.field_metadata.clear() for idx in indexes_to_release: self.field_custom_backend_meta.pop(idx, None) self.custom_meta.pop(idx, None) diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index ac23cb73..62d4a4f5 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -148,6 +148,7 @@ def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]: if batch_size == 0: logger.warning("Trying to extract field schema for empty batch. No action is taken.") + return field_schema for field_name, value in data.items(): is_tensor = isinstance(value, torch.Tensor) @@ -446,6 +447,7 @@ def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "Ba if batch_size == 0: logger.warning(f"Input TensorDict is empty with batch_size={batch_size}. 
No action is taken.") + return self if batch_size != self.size: raise ValueError(f"add_fields batch size mismatch: self.size={self.size} vs tensor_dict={batch_size}") From 0f92c1e4e3d1609ea867a43cb7548290796f074a Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 18 Mar 2026 12:40:07 +0800 Subject: [PATCH 9/9] solve comments Signed-off-by: 0oshowero0 --- transfer_queue/controller.py | 15 ++++++++++++--- transfer_queue/metadata.py | 4 +++- transfer_queue/storage/managers/base.py | 23 +++++++++++++++++++++-- 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 38320777..7b8ad1bb 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -210,6 +210,7 @@ class FieldMeta: per_sample_shapes: dict[int, tuple] = field(default_factory=dict) # {global_idx: shape} + # TODO: FieldMeta needs to be refactored to simplify this complicated and fragile logic def update(self, incoming: dict[str, Any], incoming_global_indexes: list[int]) -> None: """Update this field's metadata from an incoming schema dict. @@ -218,8 +219,8 @@ def update(self, incoming: dict[str, Any], incoming_global_indexes: list[int]) - Args: incoming: Schema dict with optional keys: - global_indexes, dtype, shape, is_nested, is_non_tensor, per_sample_shape - incoming_global_indexes: global indexes of the inupt meta + global_indexes, dtype, shape, is_nested, is_non_tensor, per_sample_shapes + incoming_global_indexes: global indexes of the input meta Raises: ValueError: If incoming dtype conflicts with existing dtype. """ @@ -243,7 +244,7 @@ def update(self, incoming: dict[str, Any], incoming_global_indexes: list[int]) - raise ValueError("Receiving a nested field without 'per_sample_shapes'!") if self.is_nested is not None and not self.is_nested: # new input is nested, but original is regular tensor.
- # We need to write old shape into per_sample_shampes + # We need to write old shape into per_sample_shapes assert self.shape is not None for gi in self.global_indexes: self.per_sample_shapes[gi] = self.shape @@ -928,10 +929,18 @@ def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = Tr self.global_indexes.difference_update(indexes_to_release) + empty_fields = [] for field_name, field_meta in self.field_metadata.items(): field_meta.remove_samples(indexes_to_release) + if len(field_meta.global_indexes) == 0: + empty_fields.append(field_name) if len(self.global_indexes) == 0: + # clear the whole field_meta if the whole partition is empty self.field_metadata.clear() + else: + # only clear empty fields + for field_name in empty_fields: + self.field_metadata.pop(field_name) for idx in indexes_to_release: self.field_custom_backend_meta.pop(idx, None) self.custom_meta.pop(idx, None) diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 62d4a4f5..9c91a6bf 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -182,11 +182,13 @@ def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]: if is_tensor: sample_shape = first_item.shape + dtype = getattr(first_item, "dtype", None) else: sample_shape = None + dtype = None field_meta = { - "dtype": getattr(first_item, "dtype", type(first_item) if first_item is not None else None), + "dtype": dtype, "shape": sample_shape, "is_nested": is_nested, "is_non_tensor": not is_tensor, diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index b60d1e8d..aa51d055 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -219,7 +219,26 @@ async def notify_data_update( try: sock.connect(self.controller_info.to_addr("data_status_update_socket")) - # FIXME: convert per_sample_shapes into dict + normalized_field_schema = {} + for field_name, field in field_schema.items(): + # Work 
on a shallow copy to avoid mutating caller-provided schema + field_copy = field.copy() + per_sample_shapes = field_copy.get("per_sample_shapes", None) + if isinstance(per_sample_shapes, list | tuple): + if len(per_sample_shapes) != len(global_indexes): + raise ValueError( + f"per_sample_shapes length ({len(per_sample_shapes)}) does not match " + f"number of global_indexes ({len(global_indexes)}) for field '{field_name}'; " + f"skipping per_sample_shapes normalization." + ) + else: + field_copy["per_sample_shapes"] = { + global_indexes[i]: per_sample_shapes[i] for i in range(len(global_indexes)) + } + + normalized_field_schema[field_name] = field_copy + + # convert per_sample_shapes into dict for field in field_schema.values(): per_sample_shapes = field.get("per_sample_shapes", None) if per_sample_shapes: @@ -232,7 +251,7 @@ async def notify_data_update( body={ "partition_id": partition_id, "global_indexes": global_indexes, - "field_schema": field_schema, + "field_schema": normalized_field_schema, "custom_backend_meta": custom_backend_meta, }, ).serialize()