From 6156ed80e7ed9b2a2d9f21cc898159cefeaf1d88 Mon Sep 17 00:00:00 2001 From: Evelynn-V Date: Tue, 10 Feb 2026 11:17:35 +0800 Subject: [PATCH 1/4] move custom_data to samplemeta and move custom_backend_meta to fieldmeta Signed-off-by: Evelynn-V --- tests/test_metadata.py | 530 +++++++++++++++++++++++++------------ transfer_queue/metadata.py | 108 ++------ 2 files changed, 385 insertions(+), 253 deletions(-) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 2bbf40c6..f40b0f8f 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -25,7 +25,7 @@ # Setup path parent_dir = Path(__file__).resolve().parent.parent -sys.path.append(str(parent_dir)) +sys.path.insert(0, str(parent_dir)) from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 @@ -37,15 +37,26 @@ class TestFieldMeta: def test_field_meta_is_ready(self): """Test the is_ready property based on production status.""" field_ready = FieldMeta( - name="test_field", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"test_field": {"dtype": torch.float32}} ) assert field_ready.is_ready is True field_not_ready = FieldMeta( - name="test_field", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.NOT_PRODUCED + name="test_field", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.NOT_PRODUCED, + _custom_backend_meta={"test_field": {"dtype": torch.float32}} ) assert field_not_ready.is_ready is False + def test_filed_meta_complete_integrity(self): + """Test the complete_integrity property based on production status.""" + field_complete = FieldMeta( + name="test_field", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"test_field": {"dtype": torch.float32}} + ) + + assert field_complete.name == "test_field" + assert field_complete._custom_backend_meta["test_field"]["dtype"] == torch.float32 class TestSampleMeta: """SampleMeta learning examples.""" @@ -57,14 +68,16 @@ def test_sample_meta_union(self): "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), } - sample1 = SampleMeta(partition_id="partition_0", global_index=0, fields=fields1) + sample1 = SampleMeta(partition_id="partition_0", global_index=0, fields=fields1, + custom_meta={"fields": "fields1", "global_index": 0}) # Create second sample with additional fields fields2 = { "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,)), } - sample2 = SampleMeta(partition_id="partition_0", global_index=0, fields=fields2) + sample2 = SampleMeta(partition_id="partition_0", global_index=0, fields=fields2, + custom_meta={"fields": "fields2", "global_index": 0}) # Union samples result = sample1.union(sample2) @@ -74,18 +87,25 @@ def test_sample_meta_union(self): assert "field2" in result.fields # From sample2 assert "field3" in result.fields + assert result.custom_meta["fields"] == "fields2" + assert result.custom_meta["global_index"] == 0 + def test_sample_meta_union_validation_error(self): """Example: Union validation catches mismatched global indexes.""" sample1 = SampleMeta( partition_id="partition_0", global_index=0, - fields={"field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,))}, + fields={"field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,), + _custom_backend_meta={"backend_type": "float32_tensor"})}, + custom_meta={"source": "dataset_A"}, ) sample2 = SampleMeta( partition_id="partition_0", global_index=1, # Different global index - fields={"field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,))}, + fields={"field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,), + _custom_backend_meta={"backend_type": "int64_tensor"})}, + custom_meta={"source": "dataset_B"}, ) with pytest.raises(ValueError) as exc_info: @@ -96,30 +116,55 @@ def test_sample_meta_add_fields(self): """Example: Add new fields to a sample.""" initial_fields = { "field1": FieldMeta( - name="field1", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="field1", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"field1": {"dtype": torch.float32}} ) } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=initial_fields) + sample = SampleMeta(partition_id="partition_0", global_index=0, fields=initial_fields, + custom_meta={"fields": "fields1", "global_index": 0}) new_fields = { "field2": FieldMeta( - name="field2", dtype=torch.int64, shape=(3,), production_status=ProductionStatus.READY_FOR_CONSUME + name="field2", dtype=torch.int64, shape=(3,), production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"field2": {"dtype": torch.int64}} ) } sample.add_fields(new_fields) assert "field1" in sample.fields assert "field2" in sample.fields + assert sample.fields["field2"]._custom_backend_meta["field2"]["dtype"] == torch.int64 + assert sample.custom_meta["fields"] == "fields1" assert sample.is_ready is True def test_sample_meta_select_fields(self): """Example: Select specific fields from a sample.""" fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,)), + "field1": FieldMeta( + name="field1", + dtype=torch.float32, + shape=(2,), + _custom_backend_meta={"backend_type": "float32_tensor"} + ), + "field2": FieldMeta( + name="field2", + dtype=torch.int64, + shape=(3,), + _custom_backend_meta={"backend_type": "int64_tensor"} + ), + "field3": FieldMeta( + name="field3", + dtype=torch.bool, + shape=(4,), + _custom_backend_meta={"backend_type": "bool_tensor"} + ), } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=fields) + sample = SampleMeta( + partition_id="partition_0", + global_index=0, + fields=fields, + custom_meta={"source": "dataset_X", "priority": "high"} + ) # Select only field1 and field3 selected_sample = sample.select_fields(["field1", "field3"]) @@ -127,6 +172,9 @@ def test_sample_meta_select_fields(self): assert "field1" in selected_sample.fields assert "field3" in selected_sample.fields assert "field2" not in selected_sample.fields + # Verify custom_backend_meta is preserved in selected fields + assert selected_sample.fields["field1"]._custom_backend_meta["backend_type"] == "float32_tensor" + assert selected_sample.fields["field3"]._custom_backend_meta["backend_type"] == "bool_tensor" # Original sample is unchanged assert len(sample.fields) == 3 # Selected sample has correct metadata @@ -134,14 +182,32 @@ def test_sample_meta_select_fields(self): assert selected_sample.fields["field1"].shape == (2,) assert selected_sample.global_index == 0 assert selected_sample.partition_id == "partition_0" + # Verify custom_meta is deep copied correctly + assert selected_sample.custom_meta == {"source": "dataset_X", "priority": "high"} + assert selected_sample.custom_meta is not sample.custom_meta # Ensure deep copy def test_sample_meta_select_fields_with_nonexistent_fields(self): """Example: Select fields ignores non-existent field names.""" fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), + "field1": FieldMeta( + name="field1", + dtype=torch.float32, + shape=(2,), + _custom_backend_meta={"backend_type": "float32_tensor"} + ), + "field2": FieldMeta( + name="field2", + dtype=torch.int64, + shape=(3,), + _custom_backend_meta={"backend_type": "int64_tensor"} + ), } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=fields) + sample = SampleMeta( + partition_id="partition_0", + global_index=0, + fields=fields, + custom_meta={"valid": True} + ) # Try to select a field that doesn't exist selected_sample = sample.select_fields(["field1", "nonexistent_field"]) @@ -150,14 +216,33 @@ def test_sample_meta_select_fields_with_nonexistent_fields(self): assert "field1" in selected_sample.fields assert "nonexistent_field" not in selected_sample.fields assert "field2" not in selected_sample.fields + # Verify custom_backend_meta is preserved for selected field + assert selected_sample.fields["field1"]._custom_backend_meta["backend_type"] == "float32_tensor" + # Verify custom_meta is preserved + assert selected_sample.custom_meta == {"valid": True} def test_sample_meta_select_fields_empty_list(self): """Example: Select with empty field list returns sample with no fields.""" fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), + "field1": FieldMeta( + name="field1", + dtype=torch.float32, + shape=(2,), + _custom_backend_meta={"backend_type": "float32_tensor"} + ), + "field2": FieldMeta( + name="field2", + dtype=torch.int64, + shape=(3,), + _custom_backend_meta={"backend_type": "int64_tensor"} + ), } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=fields) + sample = SampleMeta( + partition_id="partition_0", + global_index=0, + fields=fields, + custom_meta={"metadata_version": 2} + ) # Select with empty list selected_sample = sample.select_fields([]) @@ -165,24 +250,35 @@ def test_sample_meta_select_fields_empty_list(self): assert len(selected_sample.fields) == 0 assert selected_sample.global_index == 0 assert selected_sample.partition_id == "partition_0" - + # Verify custom_meta is preserved even with no fields + assert selected_sample.custom_meta == {"metadata_version": 2} class TestBatchMeta: """BatchMeta learning examples - Core Operations.""" def test_batch_meta_chunk(self): """Example: Split a batch into multiple chunks.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + # Initialize samples with custom_meta at SampleMeta level and _custom_backend_meta at FieldMeta level + samples = [] + for i in range(10): + fields = { + "test_field": FieldMeta( + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32} # Moved to FieldMeta + ) + } + sample = SampleMeta( + partition_id="partition_0", + global_index=i, + fields=fields, + custom_meta={"uid": i} # Moved to SampleMeta ) - } - samples = [SampleMeta(partition_id="partition_0", global_index=i, fields=fields) for i in range(10)] - batch = BatchMeta( - samples=samples, - custom_meta={i: {"uid": i} for i in range(10)}, - _custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in range(10)}, - ) + samples.append(sample) + + batch = BatchMeta(samples=samples) # Removed custom_meta/_custom_backend_meta params # Chunk into 3 parts chunks = batch.chunk(3) @@ -192,35 +288,45 @@ def test_batch_meta_chunk(self): assert len(chunks[1]) == 3 assert len(chunks[2]) == 3 - # validate custom_meta is chunked - assert 0 in chunks[0].custom_meta - assert 1 in chunks[0].custom_meta - assert 2 in chunks[0].custom_meta - assert 3 in chunks[0].custom_meta - assert 4 not in chunks[0].custom_meta - assert 4 in chunks[1].custom_meta - - # validate _custom_backend_meta is chunked - assert 0 in chunks[0]._custom_backend_meta - assert 1 in chunks[0]._custom_backend_meta - assert 2 in chunks[0]._custom_backend_meta - assert 3 in chunks[0]._custom_backend_meta - assert 4 not in chunks[0]._custom_backend_meta - assert 4 in chunks[1]._custom_backend_meta + assert 0 in chunks[0].global_indexes + assert 1 in chunks[0].global_indexes + assert 2 in chunks[0].global_indexes + assert 3 in chunks[0].global_indexes + assert 4 not in chunks[0].global_indexes + assert 4 in chunks[1].global_indexes + + assert chunks[0].samples[0].custom_meta["uid"] == 0 + assert chunks[0].samples[1].custom_meta["uid"] == 1 + assert chunks[0].samples[2].custom_meta["uid"] == 2 + assert chunks[0].samples[3].custom_meta["uid"] == 3 + assert chunks[1].samples[0].custom_meta["uid"] == 4 + + # Validate _custom_backend_meta is preserved in fields (minimal change: check via fields) + assert chunks[0].samples[0].fields["test_field"]._custom_backend_meta["dtype"] == torch.float32 + assert chunks[1].samples[0].fields["test_field"]._custom_backend_meta["dtype"] == torch.float32 def test_batch_meta_chunk_by_partition(self): - """Example: Split a batch into multiple chunks.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + """Example: Split a batch into multiple chunks by partition.""" + samples = [] + for i in range(10): + fields = { + "test_field": FieldMeta( + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32} + ) + } + sample = SampleMeta( + partition_id=f"partition_{i % 4}", + global_index=i + 10, + fields=fields, + custom_meta={"uid": i + 10} ) - } - samples = [SampleMeta(partition_id=f"partition_{i % 4}", global_index=i + 10, fields=fields) for i in range(10)] - batch = BatchMeta( - samples=samples, - custom_meta={i + 10: {"uid": i + 10} for i in range(10)}, - _custom_backend_meta={i + 10: {"test_field": {"dtype": torch.float32}} for i in range(10)}, - ) + samples.append(sample) + + batch = BatchMeta(samples=samples) # Removed custom_meta/_custom_backend_meta params # Chunk according to partition_id chunks = batch.chunk_by_partition() @@ -239,19 +345,15 @@ def test_batch_meta_chunk_by_partition(self): assert chunks[3].partition_ids == ["partition_3", "partition_3"] assert chunks[3].global_indexes == [13, 17] - # validate custom_meta is chunked - assert 10 in chunks[0].custom_meta - assert 14 in chunks[0].custom_meta - assert 18 in chunks[0].custom_meta - assert 11 not in chunks[0].custom_meta - assert 11 in chunks[1].custom_meta + # Validate custom_meta preserved in samples + assert chunks[0].samples[0].custom_meta == {"uid": 10} + assert chunks[0].samples[1].custom_meta == {"uid": 14} + assert chunks[0].samples[2].custom_meta == {"uid": 18} + assert chunks[1].samples[0].custom_meta == {"uid": 11} - # validate _custom_backend_meta is chunked - assert 10 in chunks[0]._custom_backend_meta - assert 14 in chunks[0]._custom_backend_meta - assert 18 in chunks[0]._custom_backend_meta - assert 11 not in chunks[0]._custom_backend_meta - assert 11 in chunks[1]._custom_backend_meta + # Validate _custom_backend_meta preserved in fields + assert chunks[0].samples[0].fields["test_field"]._custom_backend_meta["dtype"] == torch.float32 + assert chunks[1].samples[0].fields["test_field"]._custom_backend_meta["dtype"] == torch.float32 def test_batch_meta_init_validation_error_different_field_names(self): """Example: Init validation catches samples with different field names.""" @@ -272,27 +374,27 @@ def test_batch_meta_concat(self): """Example: Concatenate multiple batches.""" fields = { "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32} ) } - # Create two batches + # Create two batches with samples containing custom_meta batch1 = BatchMeta( samples=[ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ], - custom_meta={i: {"uid": i} for i in [0, 1]}, - _custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in [0, 1]}, + SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"uid": 0}), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields, custom_meta={"uid": 1}), + ] ) batch2 = BatchMeta( samples=[ - SampleMeta(partition_id="partition_0", global_index=2, fields=fields), - SampleMeta(partition_id="partition_0", global_index=3, fields=fields), - ], - custom_meta={i: {"uid": i} for i in [2, 3]}, - _custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in [2, 3]}, + SampleMeta(partition_id="partition_0", global_index=2, fields=fields, custom_meta={"uid": 2}), + SampleMeta(partition_id="partition_0", global_index=3, fields=fields, custom_meta={"uid": 3}), + ] ) # Concatenate batches @@ -300,8 +402,13 @@ def test_batch_meta_concat(self): assert len(result) == 4 assert result.global_indexes == [0, 1, 2, 3] - assert result.custom_meta == {i: {"uid": i} for i in [0, 1, 2, 3]} - assert result._custom_backend_meta == {i: {"test_field": {"dtype": torch.float32}} for i in [0, 1, 2, 3]} + # Validate custom_meta preserved via samples (minimal change) + assert result.samples[0].custom_meta == {"uid": 0} + assert result.samples[1].custom_meta == {"uid": 1} + assert result.samples[2].custom_meta == {"uid": 2} + assert result.samples[3].custom_meta == {"uid": 3} + # Validate _custom_backend_meta preserved via fields + assert result.samples[0].fields["test_field"]._custom_backend_meta["dtype"] == torch.float32 def test_batch_meta_concat_with_tensor_extra_info(self): """Example: Concat handles tensor extra_info by concatenating along dim=0.""" @@ -396,33 +503,47 @@ def test_batch_meta_concat_with_mixed_types(self): def test_batch_meta_union(self): """Example: Union two batches with matching global indexes.""" fields1 = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), + "field1": FieldMeta( + name="field1", + dtype=torch.float32, + shape=(2,), + _custom_backend_meta={"backend": "float32"} + ), + "field2": FieldMeta( + name="field2", + dtype=torch.int64, + shape=(3,), + _custom_backend_meta={"backend": "int64"} + ), } fields2 = { - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,)), + "field2": FieldMeta( + name="field2", + dtype=torch.int64, + shape=(3,), + _custom_backend_meta={"backend": "int64", "compression": "lz4"} + ), + "field3": FieldMeta( + name="field3", + dtype=torch.bool, + shape=(4,), + _custom_backend_meta={"backend": "bool"} + ), } batch1 = BatchMeta( samples=[ - SampleMeta(partition_id="partition_0", global_index=8, fields=fields1), - SampleMeta(partition_id="partition_0", global_index=9, fields=fields1), - ], - _custom_backend_meta={ - i: {"field1": {"dtype": torch.float32}, "field2": {"dtype": torch.int64}} for i in [8, 9] - }, + SampleMeta(partition_id="partition_0", global_index=8, fields=fields1, custom_meta={"source": "A"}), + SampleMeta(partition_id="partition_0", global_index=9, fields=fields1, custom_meta={"source": "A"}), + ] ) batch1.extra_info["info1"] = "value1" batch2 = BatchMeta( samples=[ - SampleMeta(partition_id="partition_0", global_index=8, fields=fields2), - SampleMeta(partition_id="partition_0", global_index=9, fields=fields2), - ], - _custom_backend_meta={ - i: {"field2": {"dtype": torch.int64}, "field3": {"dtype": torch.bool}} for i in [8, 9] - }, + SampleMeta(partition_id="partition_0", global_index=8, fields=fields2, custom_meta={"source": "B"}), + SampleMeta(partition_id="partition_0", global_index=9, fields=fields2, custom_meta={"source": "B"}), + ] ) batch2.extra_info["info2"] = "value2" @@ -434,15 +555,14 @@ def test_batch_meta_union(self): assert "field1" in sample.fields assert "field2" in sample.fields assert "field3" in sample.fields + # Verify _custom_backend_meta preserved correctly + assert sample.fields["field2"]._custom_backend_meta["backend"] == "int64" + assert sample.fields["field2"]._custom_backend_meta.get("compression") == "lz4" # Extra info is merged assert result.extra_info["info1"] == "value1" assert result.extra_info["info2"] == "value2" - - # _custom_backend_meta is merged - assert result._custom_backend_meta == { - i: {"field1": {"dtype": torch.float32}, "field2": {"dtype": torch.int64}, "field3": {"dtype": torch.bool}} - for i in [8, 9] - } + # Verify custom_meta merged correctly (last wins) + assert result.samples[0].custom_meta["source"] == "B" def test_batch_meta_union_validation(self): """Example: Union validation catches mismatched conditions.""" @@ -469,9 +589,9 @@ def test_batch_meta_reorder(self): ) } samples = [ - SampleMeta(partition_id="partition_0", global_index=4, fields=fields), - SampleMeta(partition_id="partition_0", global_index=5, fields=fields), - SampleMeta(partition_id="partition_0", global_index=6, fields=fields), + SampleMeta(partition_id="partition_0", global_index=4, fields=fields, custom_meta={"pos": 0}), + SampleMeta(partition_id="partition_0", global_index=5, fields=fields, custom_meta={"pos": 1}), + SampleMeta(partition_id="partition_0", global_index=6, fields=fields, custom_meta={"pos": 2}), ] batch = BatchMeta(samples=samples) @@ -483,6 +603,10 @@ def test_batch_meta_reorder(self): assert batch.samples[0].batch_index == 0 assert batch.samples[1].batch_index == 1 assert batch.samples[2].batch_index == 2 + # custom_meta preserved correctly + assert batch.samples[0].custom_meta == {"pos": 2} + assert batch.samples[1].custom_meta == {"pos": 0} + assert batch.samples[2].custom_meta == {"pos": 1} def test_batch_meta_add_fields(self): """Example: Add fields from TensorDict to all samples.""" @@ -507,30 +631,37 @@ def test_batch_meta_add_fields(self): assert "new_field1" in sample.fields assert "new_field2" in sample.fields assert sample.is_ready is True + # Verify new fields have default _custom_backend_meta (empty dict) + assert sample.fields["new_field1"]._custom_backend_meta == {} + assert sample.fields["new_field2"]._custom_backend_meta == {} def test_batch_meta_select_fields(self): """Example: Select specific fields from all samples in a batch.""" fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,)), + "field1": FieldMeta( + name="field1", + dtype=torch.float32, + shape=(2,), + _custom_backend_meta={"precision": "fp32"} + ), + "field2": FieldMeta( + name="field2", + dtype=torch.int64, + shape=(3,), + _custom_backend_meta={"encoding": "varint"} + ), + "field3": FieldMeta( + name="field3", + dtype=torch.bool, + shape=(4,), + _custom_backend_meta={"packing": "bit"} + ), } samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), + SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"version": 1}), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields, custom_meta={"version": 1}), ] - batch = BatchMeta( - samples=samples, - extra_info={"test_key": "test_value"}, - _custom_backend_meta={ - i: { - "field1": {"dtype": torch.float32}, - "field2": {"dtype": torch.int64}, - "field3": {"dtype": torch.bool}, - } - for i in [0, 1] - }, - ) + batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"}) # Select only field1 and field3 selected_batch = batch.select_fields(["field1", "field3"]) @@ -541,21 +672,16 @@ def test_batch_meta_select_fields(self): assert "field1" in sample.fields assert "field3" in sample.fields assert "field2" not in sample.fields + # Verify _custom_backend_meta preserved for selected fields + assert sample.fields["field1"]._custom_backend_meta["precision"] == "fp32" + assert sample.fields["field3"]._custom_backend_meta["packing"] == "bit" # Original batch is unchanged assert len(batch.samples[0].fields) == 3 # Extra info is preserved assert selected_batch.extra_info["test_key"] == "test_value" - # Global indexes are preserved + # Global indexes and custom_meta preserved assert selected_batch.global_indexes == [0, 1] - - # _custom_backend_meta is selected - assert "field1" in selected_batch._custom_backend_meta[0] - assert "field2" not in selected_batch._custom_backend_meta[0] - assert "field3" in selected_batch._custom_backend_meta[0] - assert "field1" in selected_batch._custom_backend_meta[1] - assert "field2" not in selected_batch._custom_backend_meta[1] - assert "field3" in selected_batch._custom_backend_meta[1] - + assert selected_batch.samples[0].custom_meta == {"version": 1} def test_batch_meta_select_fields_with_nonexistent_fields(self): """Example: Select fields ignores non-existent field names in batch.""" fields = { @@ -618,10 +744,18 @@ def test_batch_meta_select_fields_preserves_field_metadata(self): """Example: Selected fields preserve their original metadata.""" fields = { "field1": FieldMeta( - name="field1", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.READY_FOR_CONSUME + name="field1", + dtype=torch.float32, + shape=(2, 3), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"source": "sensor_a"} ), "field2": FieldMeta( - name="field2", dtype=torch.int64, shape=(5,), production_status=ProductionStatus.NOT_PRODUCED + name="field2", + dtype=torch.int64, + shape=(5,), + production_status=ProductionStatus.NOT_PRODUCED, + _custom_backend_meta={"source": "sensor_b"} ), } samples = [ @@ -637,23 +771,34 @@ def test_batch_meta_select_fields_preserves_field_metadata(self): assert selected_field.shape == (2, 3) assert selected_field.production_status == ProductionStatus.READY_FOR_CONSUME assert selected_field.name == "field1" + assert selected_field._custom_backend_meta["source"] == "sensor_a" def test_batch_meta_select_samples(self): """Example: Select specific samples from a batch.""" fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), + "field1": FieldMeta( + name="field1", + dtype=torch.float32, + shape=(2,), + _custom_backend_meta={"backend": "float32"} + ), + "field2": FieldMeta( + name="field2", + dtype=torch.int64, + shape=(3,), + _custom_backend_meta={"backend": "int64"} + ), } samples = [ - SampleMeta(partition_id="partition_0", global_index=4, fields=fields), - SampleMeta(partition_id="partition_0", global_index=5, fields=fields), - SampleMeta(partition_id="partition_0", global_index=6, fields=fields), - SampleMeta(partition_id="partition_0", global_index=7, fields=fields), + SampleMeta(partition_id="partition_0", global_index=4, fields=fields, custom_meta={"sample_id": 4}), + SampleMeta(partition_id="partition_0", global_index=5, fields=fields, custom_meta={"sample_id": 5}), + SampleMeta(partition_id="partition_0", global_index=6, fields=fields, custom_meta={"sample_id": 6}), + SampleMeta(partition_id="partition_0", global_index=7, fields=fields, custom_meta={"sample_id": 7}), ] batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"}) # Select samples at indices [0, 2] - selected_batch = batch.select_samples([0, 2]) # This will select the first two samples with global_index=4/5 + selected_batch = batch.select_samples([0, 2]) # Check number of samples assert len(selected_batch) == 2 @@ -663,6 +808,8 @@ def test_batch_meta_select_samples(self): for sample in selected_batch.samples: assert "field1" in sample.fields assert "field2" in sample.fields + # MINIMAL CHANGE: Verify custom_meta preserved via get_all_custom_meta() + assert sample.global_index in selected_batch.get_all_custom_meta() # Original batch is unchanged assert len(batch) == 4 # Extra info is preserved @@ -672,13 +819,17 @@ def test_batch_meta_select_samples_all_indices(self): """Example: Select all samples using complete index list.""" fields = { "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32} ) } samples = [ - SampleMeta(partition_id="partition_0", global_index=4, fields=fields), - SampleMeta(partition_id="partition_0", global_index=5, fields=fields), - SampleMeta(partition_id="partition_0", global_index=6, fields=fields), + SampleMeta(partition_id="partition_0", global_index=4, fields=fields, custom_meta={"sample_id": 4}), + SampleMeta(partition_id="partition_0", global_index=5, fields=fields, custom_meta={"sample_id": 5}), + SampleMeta(partition_id="partition_0", global_index=6, fields=fields, custom_meta={"sample_id": 6}), ] batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"}) @@ -688,6 +839,10 @@ def test_batch_meta_select_samples_all_indices(self): # All samples are selected assert len(selected_batch) == 3 assert selected_batch.global_indexes == [4, 5, 6] + # MINIMAL CHANGE: Verify all custom_meta preserved + assert 4 in selected_batch.get_all_custom_meta() + assert 5 in selected_batch.get_all_custom_meta() + assert 6 in selected_batch.get_all_custom_meta() # Extra info is preserved assert selected_batch.extra_info["test_key"] == "test_value" @@ -695,13 +850,17 @@ def test_batch_meta_select_samples_single_sample(self): """Example: Select a single sample from batch.""" fields = { "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32} ) } samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - SampleMeta(partition_id="partition_0", global_index=2, fields=fields), + SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"sample_id": 0}), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields, custom_meta={"sample_id": 1}), + SampleMeta(partition_id="partition_0", global_index=2, fields=fields, custom_meta={"sample_id": 2}), ] batch = BatchMeta(samples=samples) @@ -710,18 +869,24 @@ def test_batch_meta_select_samples_single_sample(self): assert len(selected_batch) == 1 assert selected_batch.global_indexes == [1] + # MINIMAL CHANGE: Verify custom_meta preserved + assert 1 in selected_batch.get_all_custom_meta() assert selected_batch.samples[0].batch_index == 0 # New batch index def test_batch_meta_select_samples_empty_list(self): """Example: Select with empty list returns empty batch.""" fields = { "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32} ) } samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), + SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"sample_id": 0}), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields, custom_meta={"sample_id": 1}), ] batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"}) @@ -732,18 +897,24 @@ def test_batch_meta_select_samples_empty_list(self): assert selected_batch.global_indexes == [] # Extra info is still preserved assert selected_batch.extra_info["test_key"] == "test_value" + # MINIMAL CHANGE: get_all_custom_meta returns empty dict for empty batch + assert selected_batch.get_all_custom_meta() == {} def test_batch_meta_select_samples_reverse_order(self): """Example: Select samples in reverse order.""" fields = { "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32} ) } samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - SampleMeta(partition_id="partition_0", global_index=2, fields=fields), + SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"sample_id": 0}), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields, custom_meta={"sample_id": 1}), + SampleMeta(partition_id="partition_0", global_index=2, fields=fields, custom_meta={"sample_id": 2}), ] batch = BatchMeta(samples=samples) @@ -752,6 +923,10 @@ def test_batch_meta_select_samples_reverse_order(self): assert len(selected_batch) == 3 assert selected_batch.global_indexes == [2, 1, 0] + # MINIMAL CHANGE: Verify all custom_meta preserved in new order + assert 2 in selected_batch.get_all_custom_meta() + assert 1 in selected_batch.get_all_custom_meta() + assert 0 in selected_batch.get_all_custom_meta() # Batch indexes are re-assigned assert selected_batch.samples[0].global_index == 2 assert selected_batch.samples[1].global_index == 1 @@ -761,12 +936,16 @@ def test_batch_meta_select_samples_with_extra_info(self): """Example: Select samples preserves all extra info types.""" fields = { "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32} ) } samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), + SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"sample_id": 0}), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields, custom_meta={"sample_id": 1}), ] batch = BatchMeta(samples=samples) @@ -784,6 +963,8 @@ def test_batch_meta_select_samples_with_extra_info(self): assert selected_batch.extra_info["string"] == "test_string" assert selected_batch.extra_info["number"] == 42 assert selected_batch.extra_info["list"] == [1, 2, 3] + # MINIMAL CHANGE: Verify custom_meta preserved + assert 0 in selected_batch.get_all_custom_meta() # ===================================================== # Custom Meta Tests @@ -1091,12 +1272,16 @@ def test_batch_meta_chunk_with_more_chunks_than_samples(self): """Example: Chunking when chunks > samples produces empty chunks.""" fields = { "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32} ) } samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), + SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"sample_id": 0}), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields, custom_meta={"sample_id": 1}), ] batch = BatchMeta(samples=samples) @@ -1111,23 +1296,34 @@ def test_batch_meta_chunk_with_more_chunks_than_samples(self): assert len(chunks[2]) == 0 assert len(chunks[3]) == 0 assert len(chunks[4]) == 0 + # MINIMAL CHANGE: Verify custom_meta preserved in non-empty chunks + if len(chunks[0]) > 0: + assert 0 in chunks[0].get_all_custom_meta() + if len(chunks[1]) > 0: + assert 1 in chunks[1].get_all_custom_meta() def test_batch_meta_concat_with_empty_batches(self): """Example: Concat handles empty batches gracefully.""" fields = { "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32} ) } batch1 = BatchMeta(samples=[]) - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) + batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"sample_id": 0})]) batch3 = BatchMeta(samples=[]) # Empty batches are filtered out result = BatchMeta.concat([batch1, batch2, batch3]) assert len(result) == 1 assert result.global_indexes == [0] + # MINIMAL CHANGE: Verify custom_meta preserved + assert 0 in result.get_all_custom_meta() def test_batch_meta_concat_validation_error(self): """Example: Concat validation catches field name mismatches.""" diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 5e254047..b989d521 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -48,6 +48,7 @@ class FieldMeta: dtype: Optional[Any] # Data type (e.g., torch.float32, numpy.float32) shape: Optional[Any] # Data shape (e.g., torch.Size([3, 224, 224]), (3, 224, 224)) production_status: ProductionStatus = ProductionStatus.NOT_PRODUCED + _custom_backend_meta: dict[str, Any] = dataclasses.field(default_factory=dict) def __str__(self) -> str: return ( @@ -70,6 +71,7 @@ def from_dict(cls, data: dict) -> "FieldMeta": production_status=ProductionStatus(str(data["production_status"])) if isinstance(data["production_status"], int | str) else data["production_status"], + _custom_backend_meta=data.get("_custom_backend_meta", {}), ) @@ -80,6 +82,7 @@ class SampleMeta: partition_id: str # Partition id, used for data versioning global_index: int # Global row index, uniquely identifies a data sample fields: dict[str, FieldMeta] # Fields of interest for this sample + custom_meta: dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): """Initialize is_ready property based on field readiness""" @@ -137,11 +140,11 @@ def select_fields(self, field_names: list[str]) -> "SampleMeta": selected_fields = {name: self.fields[name] for name in field_names if name in self.fields} # construct new SampleMeta instance - # TODO(tianyi): (maybe) move _custom_backend_meta and custom_meta to FieldMeta level? selected_sample_meta = SampleMeta( fields=selected_fields, partition_id=self.partition_id, global_index=self.global_index, + custom_meta=copy.deepcopy(self.custom_meta) ) return selected_sample_meta @@ -168,6 +171,9 @@ def union(self, other: "SampleMeta", validate: bool = True) -> "SampleMeta": # Merge fields self.fields = _union_fields(self.fields, other.fields) + # Merge custom meta + self.custom_meta = {**self.custom_meta, **other.custom_meta} + # Update is_ready property object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values())) return self @@ -193,6 +199,7 @@ def from_dict(cls, data: dict) -> "SampleMeta": partition_id=data["partition_id"], global_index=data["global_index"], fields=fields, + custom_meta=data.get("custom_meta", {}), ) @@ -205,12 +212,6 @@ class BatchMeta: # external meta for non-sample level information extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) - # user-defined meta for each sample - custom_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) - - # internal meta for different storage backends in per-sample per-field level - _custom_backend_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) - def __post_init__(self): """Initialize all computed properties during initialization""" self.samples = copy.deepcopy(self.samples) @@ -235,16 +236,7 @@ def __post_init__(self): object.__setattr__(self, "_partition_ids", [sample.partition_id for sample in self.samples]) - # filter custom_meta and _custom_backend_meta - self.custom_meta = copy.deepcopy( - {k: self.custom_meta[k] for k in self.global_indexes if k in self.custom_meta} - ) - self._custom_backend_meta = copy.deepcopy( - {k: self._custom_backend_meta[k] for k in self.global_indexes if k in self._custom_backend_meta} - ) else: - self.custom_meta = {} - self._custom_backend_meta = {} object.__setattr__(self, "_global_indexes", []) object.__setattr__(self, "_field_names", []) object.__setattr__(self, "_partition_ids", []) @@ -321,7 +313,10 @@ def set_custom_meta(self, global_index: int, meta_dict: dict[str, Any]) -> None: if global_index not in self.global_indexes: raise ValueError(f"key {global_index} not found in global_indexes {self.global_indexes}.") - self.custom_meta[global_index] = copy.deepcopy(meta_dict) + for sample in self.samples: + if sample.global_index == global_index: + sample.custom_meta = copy.deepcopy(meta_dict) + break def get_all_custom_meta(self) -> dict[int, dict[str, Any]]: """ @@ -330,7 +325,7 @@ def get_all_custom_meta(self) -> dict[int, dict[str, Any]]: Returns: A deep copy of the custom_meta dictionary """ - return copy.deepcopy(self.custom_meta) + return {s.global_index: copy.deepcopy(s.custom_meta) for s in self.samples if s.custom_meta} def update_custom_meta(self, new_meta: dict[int, dict[str, Any]]): """ @@ -355,8 +350,12 @@ def update_custom_meta(self, new_meta: dict[int, dict[str, Any]]): f"Trying to update custom_meta with non-exist global_indexes! {non_exist_global_indexes} " f"do not exist in this batch." ) - - self.custom_meta.update(new_meta) + + for global_index, meta_dict in new_meta.items(): + for sample in self.samples: + if sample.global_index == global_index: + sample.custom_meta.update(copy.deepcopy(meta_dict)) + break def clear_custom_meta(self) -> None: """ @@ -364,7 +363,8 @@ def clear_custom_meta(self) -> None: This method removes all entries from the custom_meta dictionary. """ - self.custom_meta.clear() + for sample in self.samples: + sample.custom_meta.clear() def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "BatchMeta": """ @@ -407,18 +407,10 @@ def select_samples(self, indexes: list[int]) -> "BatchMeta": selected_samples = [self.samples[i] for i in indexes] - global_indexes = [self.global_indexes[i] for i in indexes] - selected_custom_meta = {i: self.custom_meta[i] for i in global_indexes if i in self.custom_meta} - selected_custom_backend_meta = { - i: self._custom_backend_meta[i] for i in global_indexes if i in self._custom_backend_meta - } - # construct new BatchMeta instance selected_batch_meta = BatchMeta( samples=selected_samples, extra_info=self.extra_info, - custom_meta=selected_custom_meta, - _custom_backend_meta=selected_custom_backend_meta, ) return selected_batch_meta @@ -437,22 +429,10 @@ def select_fields(self, field_names: list[str]) -> "BatchMeta": # select fields for each SampleMeta new_samples = [sample.select_fields(field_names=field_names) for sample in self.samples] - # select fields in _custom_backend_meta - selected_custom_backend_meta = {} - for idx in self.global_indexes: - if idx in self._custom_backend_meta: - custom_backend_meta_idx = self._custom_backend_meta[idx] - - selected_custom_backend_meta[idx] = { - field: custom_backend_meta_idx[field] for field in field_names if field in custom_backend_meta_idx - } - # construct new BatchMeta instance new_batch_meta = BatchMeta( samples=new_samples, extra_info=self.extra_info, - custom_meta=self.custom_meta, - _custom_backend_meta=selected_custom_backend_meta, ) return new_batch_meta @@ -466,21 +446,9 @@ def __getitem__(self, item): sample_meta = self.samples[item] if self.samples else [] global_idx = self.global_indexes[item] - if global_idx in self.custom_meta: - custom_meta = {global_idx: self.custom_meta[global_idx]} - else: - custom_meta = {} - - if global_idx in self._custom_backend_meta: - custom_backend_meta = {global_idx: self._custom_backend_meta[global_idx]} - else: - custom_backend_meta = {} - return BatchMeta( samples=[sample_meta], extra_info=self.extra_info, - custom_meta=custom_meta, - _custom_backend_meta=custom_backend_meta, ) else: raise TypeError(f"Indexing with {type(item)} is not supported now!") @@ -514,16 +482,9 @@ def chunk(self, chunks: int) -> list["BatchMeta"]: current_chunk_size = base_size + 1 if i < remainder else base_size end = start + current_chunk_size chunk_samples = self.samples[start:end] - global_indexes = self.global_indexes[start:end] - chunk_custom_meta = {i: self.custom_meta[i] for i in global_indexes if i in self.custom_meta} - chunk_custom_backend_meta = { - i: self._custom_backend_meta[i] for i in global_indexes if i in self._custom_backend_meta - } chunk = BatchMeta( samples=chunk_samples, extra_info=self.extra_info, - custom_meta=chunk_custom_meta, - _custom_backend_meta=chunk_custom_backend_meta, ) chunk_list.append(chunk) start = end @@ -585,15 +546,9 @@ def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta": # Merge all extra_info dictionaries from the chunks merged_extra_info = dict() - merged_custom_meta = dict() - merged_custom_backend_meta = dict() values_by_key = defaultdict(list) for chunk in data: - # For the sample-level custom_meta and field-level _custom_backend_meta, we directly update the dict. - merged_custom_meta.update(chunk.custom_meta) - merged_custom_backend_meta.update(chunk._custom_backend_meta) - for key, value in chunk.extra_info.items(): values_by_key[key].append(value) @@ -622,8 +577,6 @@ def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta": return BatchMeta( samples=all_samples, extra_info=merged_extra_info, - custom_meta=merged_custom_meta, - _custom_backend_meta=merged_custom_backend_meta, ) def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMeta"]: @@ -670,24 +623,9 @@ def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMet # Merge extra info dictionaries merged_extra_info = {**self.extra_info, **other.extra_info} - # Merge custom_meta dictionaries - merged_custom_meta = {**self.custom_meta, **other.custom_meta} - - # Merge custom_backend_meta dictionaries - merged_custom_backend_meta = {} - for idx in self.global_indexes: - if idx in self._custom_backend_meta and idx in other._custom_backend_meta: - merged_custom_backend_meta[idx] = {**self._custom_backend_meta[idx], **other._custom_backend_meta[idx]} - elif idx in self._custom_backend_meta: - merged_custom_backend_meta[idx] = {**self._custom_backend_meta[idx]} - elif idx in other._custom_backend_meta: - merged_custom_backend_meta[idx] = {**other._custom_backend_meta[idx]} - return BatchMeta( samples=merged_samples, extra_info=merged_extra_info, - custom_meta=merged_custom_meta, - _custom_backend_meta=merged_custom_backend_meta, ) def reorder(self, indexes: list[int]): @@ -784,7 +722,7 @@ def empty(cls, extra_info: Optional[dict[str, Any]] = None) -> "BatchMeta": """ if extra_info is None: extra_info = {} - return cls(samples=[], extra_info=extra_info, custom_meta={}, _custom_backend_meta={}) + return cls(samples=[], extra_info=extra_info) def __str__(self): sample_strs = ", ".join(str(sample) for sample in self.samples) @@ -803,8 +741,6 @@ def from_dict(cls, data: dict) -> "BatchMeta": return cls( samples=samples, extra_info=data.get("extra_info", {}), - custom_meta=data.get("custom_meta", {}), - _custom_backend_meta=data.get("_custom_backend_meta", {}), ) From 70edf78a0f0307d94c99174450ed45531971156a Mon Sep 17 00:00:00 2001 From: Evelynn-V Date: Tue, 10 Feb 2026 14:35:47 +0800 Subject: [PATCH 2/4] move custom_data to samplemeta and move custom_backend_meta to fieldmeta Signed-off-by: Evelynn-V --- tests/test_metadata.py | 299 +++++++++++------------- transfer_queue/controller.py | 12 +- transfer_queue/metadata.py | 9 +- transfer_queue/storage/managers/base.py | 5 +- 4 files changed, 155 insertions(+), 170 deletions(-) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index f40b0f8f..3a1b115e 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -37,27 +37,37 @@ class TestFieldMeta: def test_field_meta_is_ready(self): """Test the is_ready property based on production status.""" field_ready = FieldMeta( - name="test_field", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.READY_FOR_CONSUME, - _custom_backend_meta={"test_field": {"dtype": torch.float32}} + name="test_field", + dtype=torch.float32, + shape=(2, 3), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"test_field": {"dtype": torch.float32}}, ) assert field_ready.is_ready is True field_not_ready = FieldMeta( - name="test_field", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.NOT_PRODUCED, - _custom_backend_meta={"test_field": {"dtype": torch.float32}} + name="test_field", + dtype=torch.float32, + shape=(2, 3), + production_status=ProductionStatus.NOT_PRODUCED, + _custom_backend_meta={"test_field": {"dtype": torch.float32}}, ) assert field_not_ready.is_ready is False def test_filed_meta_complete_integrity(self): """Test the complete_integrity property based on production status.""" field_complete = FieldMeta( - name="test_field", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.READY_FOR_CONSUME, - _custom_backend_meta={"test_field": {"dtype": torch.float32}} + name="test_field", + dtype=torch.float32, + shape=(2, 3), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"test_field": {"dtype": torch.float32}}, ) assert field_complete.name == "test_field" assert field_complete._custom_backend_meta["test_field"]["dtype"] == torch.float32 + class TestSampleMeta: """SampleMeta learning examples.""" @@ -68,16 +78,24 @@ def test_sample_meta_union(self): "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), } - sample1 = SampleMeta(partition_id="partition_0", global_index=0, fields=fields1, - custom_meta={"fields": "fields1", "global_index": 0}) + sample1 = SampleMeta( + partition_id="partition_0", + global_index=0, + fields=fields1, + custom_meta={"fields": "fields1", "global_index": 0}, + ) # Create second sample with additional fields fields2 = { "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,)), } - sample2 = SampleMeta(partition_id="partition_0", global_index=0, fields=fields2, - custom_meta={"fields": "fields2", "global_index": 0}) + sample2 = SampleMeta( + partition_id="partition_0", + global_index=0, + fields=fields2, + custom_meta={"fields": "fields2", "global_index": 0}, + ) # Union samples result = sample1.union(sample2) @@ -95,16 +113,25 @@ def test_sample_meta_union_validation_error(self): sample1 = SampleMeta( partition_id="partition_0", global_index=0, - fields={"field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,), - _custom_backend_meta={"backend_type": "float32_tensor"})}, + fields={ + "field1": FieldMeta( + name="field1", + dtype=torch.float32, + shape=(2,), + _custom_backend_meta={"backend_type": "float32_tensor"}, + ) + }, custom_meta={"source": "dataset_A"}, ) sample2 = SampleMeta( partition_id="partition_0", global_index=1, # Different global index - fields={"field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,), - _custom_backend_meta={"backend_type": "int64_tensor"})}, + fields={ + "field2": FieldMeta( + name="field2", dtype=torch.int64, shape=(3,), _custom_backend_meta={"backend_type": "int64_tensor"} + ) + }, custom_meta={"source": "dataset_B"}, ) @@ -116,17 +143,27 @@ def test_sample_meta_add_fields(self): """Example: Add new fields to a sample.""" initial_fields = { "field1": FieldMeta( - name="field1", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME, - _custom_backend_meta={"field1": {"dtype": torch.float32}} + name="field1", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"field1": {"dtype": torch.float32}}, ) } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=initial_fields, - custom_meta={"fields": "fields1", "global_index": 0}) + sample = SampleMeta( + partition_id="partition_0", + global_index=0, + fields=initial_fields, + custom_meta={"fields": "fields1", "global_index": 0}, + ) new_fields = { "field2": FieldMeta( - name="field2", dtype=torch.int64, shape=(3,), production_status=ProductionStatus.READY_FOR_CONSUME, - _custom_backend_meta={"field2": {"dtype": torch.int64}} + name="field2", + dtype=torch.int64, + shape=(3,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"field2": {"dtype": torch.int64}}, ) } sample.add_fields(new_fields) @@ -141,29 +178,20 @@ def test_sample_meta_select_fields(self): """Example: Select specific fields from a sample.""" fields = { "field1": FieldMeta( - name="field1", - dtype=torch.float32, - shape=(2,), - _custom_backend_meta={"backend_type": "float32_tensor"} + name="field1", dtype=torch.float32, shape=(2,), _custom_backend_meta={"backend_type": "float32_tensor"} ), "field2": FieldMeta( - name="field2", - dtype=torch.int64, - shape=(3,), - _custom_backend_meta={"backend_type": "int64_tensor"} + name="field2", dtype=torch.int64, shape=(3,), _custom_backend_meta={"backend_type": "int64_tensor"} ), "field3": FieldMeta( - name="field3", - dtype=torch.bool, - shape=(4,), - _custom_backend_meta={"backend_type": "bool_tensor"} + name="field3", dtype=torch.bool, shape=(4,), _custom_backend_meta={"backend_type": "bool_tensor"} ), } sample = SampleMeta( - partition_id="partition_0", - global_index=0, + partition_id="partition_0", + global_index=0, fields=fields, - custom_meta={"source": "dataset_X", "priority": "high"} + custom_meta={"source": "dataset_X", "priority": "high"}, ) # Select only field1 and field3 @@ -190,24 +218,13 @@ def test_sample_meta_select_fields_with_nonexistent_fields(self): """Example: Select fields ignores non-existent field names.""" fields = { "field1": FieldMeta( - name="field1", - dtype=torch.float32, - shape=(2,), - _custom_backend_meta={"backend_type": "float32_tensor"} + name="field1", dtype=torch.float32, shape=(2,), _custom_backend_meta={"backend_type": "float32_tensor"} ), "field2": FieldMeta( - name="field2", - dtype=torch.int64, - shape=(3,), - _custom_backend_meta={"backend_type": "int64_tensor"} + name="field2", dtype=torch.int64, shape=(3,), _custom_backend_meta={"backend_type": "int64_tensor"} ), } - sample = SampleMeta( - partition_id="partition_0", - global_index=0, - fields=fields, - custom_meta={"valid": True} - ) + sample = SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"valid": True}) # Try to select a field that doesn't exist selected_sample = sample.select_fields(["field1", "nonexistent_field"]) @@ -225,23 +242,14 @@ def test_sample_meta_select_fields_empty_list(self): """Example: Select with empty field list returns sample with no fields.""" fields = { "field1": FieldMeta( - name="field1", - dtype=torch.float32, - shape=(2,), - _custom_backend_meta={"backend_type": "float32_tensor"} + name="field1", dtype=torch.float32, shape=(2,), _custom_backend_meta={"backend_type": "float32_tensor"} ), "field2": FieldMeta( - name="field2", - dtype=torch.int64, - shape=(3,), - _custom_backend_meta={"backend_type": "int64_tensor"} + name="field2", dtype=torch.int64, shape=(3,), _custom_backend_meta={"backend_type": "int64_tensor"} ), } sample = SampleMeta( - partition_id="partition_0", - global_index=0, - fields=fields, - custom_meta={"metadata_version": 2} + partition_id="partition_0", global_index=0, fields=fields, custom_meta={"metadata_version": 2} ) # Select with empty list @@ -253,6 +261,7 @@ def test_sample_meta_select_fields_empty_list(self): # Verify custom_meta is preserved even with no fields assert selected_sample.custom_meta == {"metadata_version": 2} + class TestBatchMeta: """BatchMeta learning examples - Core Operations.""" @@ -263,21 +272,21 @@ def test_batch_meta_chunk(self): for i in range(10): fields = { "test_field": FieldMeta( - name="test_field", - dtype=torch.float32, - shape=(2,), + name="test_field", + dtype=torch.float32, + shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME, - _custom_backend_meta={"dtype": torch.float32} # Moved to FieldMeta + _custom_backend_meta={"dtype": torch.float32}, # Moved to FieldMeta ) } sample = SampleMeta( - partition_id="partition_0", - global_index=i, + partition_id="partition_0", + global_index=i, fields=fields, - custom_meta={"uid": i} # Moved to SampleMeta + custom_meta={"uid": i}, # Moved to SampleMeta ) samples.append(sample) - + batch = BatchMeta(samples=samples) # Removed custom_meta/_custom_backend_meta params # Chunk into 3 parts @@ -311,21 +320,18 @@ def test_batch_meta_chunk_by_partition(self): for i in range(10): fields = { "test_field": FieldMeta( - name="test_field", - dtype=torch.float32, - shape=(2,), + name="test_field", + dtype=torch.float32, + shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME, - _custom_backend_meta={"dtype": torch.float32} + _custom_backend_meta={"dtype": torch.float32}, ) } sample = SampleMeta( - partition_id=f"partition_{i % 4}", - global_index=i + 10, - fields=fields, - custom_meta={"uid": i + 10} + partition_id=f"partition_{i % 4}", global_index=i + 10, fields=fields, custom_meta={"uid": i + 10} ) samples.append(sample) - + batch = BatchMeta(samples=samples) # Removed custom_meta/_custom_backend_meta params # Chunk according to partition_id @@ -374,11 +380,11 @@ def test_batch_meta_concat(self): """Example: Concatenate multiple batches.""" fields = { "test_field": FieldMeta( - name="test_field", - dtype=torch.float32, - shape=(2,), + name="test_field", + dtype=torch.float32, + shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME, - _custom_backend_meta={"dtype": torch.float32} + _custom_backend_meta={"dtype": torch.float32}, ) } @@ -504,31 +510,20 @@ def test_batch_meta_union(self): """Example: Union two batches with matching global indexes.""" fields1 = { "field1": FieldMeta( - name="field1", - dtype=torch.float32, - shape=(2,), - _custom_backend_meta={"backend": "float32"} + name="field1", dtype=torch.float32, shape=(2,), _custom_backend_meta={"backend": "float32"} ), "field2": FieldMeta( - name="field2", - dtype=torch.int64, - shape=(3,), - _custom_backend_meta={"backend": "int64"} + name="field2", dtype=torch.int64, shape=(3,), _custom_backend_meta={"backend": "int64"} ), } fields2 = { "field2": FieldMeta( - name="field2", - dtype=torch.int64, + name="field2", + dtype=torch.int64, shape=(3,), - _custom_backend_meta={"backend": "int64", "compression": "lz4"} - ), - "field3": FieldMeta( - name="field3", - dtype=torch.bool, - shape=(4,), - _custom_backend_meta={"backend": "bool"} + _custom_backend_meta={"backend": "int64", "compression": "lz4"}, ), + "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,), _custom_backend_meta={"backend": "bool"}), } batch1 = BatchMeta( @@ -639,23 +634,12 @@ def test_batch_meta_select_fields(self): """Example: Select specific fields from all samples in a batch.""" fields = { "field1": FieldMeta( - name="field1", - dtype=torch.float32, - shape=(2,), - _custom_backend_meta={"precision": "fp32"} + name="field1", dtype=torch.float32, shape=(2,), _custom_backend_meta={"precision": "fp32"} ), "field2": FieldMeta( - name="field2", - dtype=torch.int64, - shape=(3,), - _custom_backend_meta={"encoding": "varint"} - ), - "field3": FieldMeta( - name="field3", - dtype=torch.bool, - shape=(4,), - _custom_backend_meta={"packing": "bit"} + name="field2", dtype=torch.int64, shape=(3,), _custom_backend_meta={"encoding": "varint"} ), + "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,), _custom_backend_meta={"packing": "bit"}), } samples = [ SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"version": 1}), @@ -682,6 +666,7 @@ def test_batch_meta_select_fields(self): # Global indexes and custom_meta preserved assert selected_batch.global_indexes == [0, 1] assert selected_batch.samples[0].custom_meta == {"version": 1} + def test_batch_meta_select_fields_with_nonexistent_fields(self): """Example: Select fields ignores non-existent field names in batch.""" fields = { @@ -744,18 +729,18 @@ def test_batch_meta_select_fields_preserves_field_metadata(self): """Example: Selected fields preserve their original metadata.""" fields = { "field1": FieldMeta( - name="field1", - dtype=torch.float32, - shape=(2, 3), + name="field1", + dtype=torch.float32, + shape=(2, 3), production_status=ProductionStatus.READY_FOR_CONSUME, - _custom_backend_meta={"source": "sensor_a"} + _custom_backend_meta={"source": "sensor_a"}, ), "field2": FieldMeta( - name="field2", - dtype=torch.int64, - shape=(5,), + name="field2", + dtype=torch.int64, + shape=(5,), production_status=ProductionStatus.NOT_PRODUCED, - _custom_backend_meta={"source": "sensor_b"} + _custom_backend_meta={"source": "sensor_b"}, ), } samples = [ @@ -777,16 +762,10 @@ def test_batch_meta_select_samples(self): """Example: Select specific samples from a batch.""" fields = { "field1": FieldMeta( - name="field1", - dtype=torch.float32, - shape=(2,), - _custom_backend_meta={"backend": "float32"} + name="field1", dtype=torch.float32, shape=(2,), _custom_backend_meta={"backend": "float32"} ), "field2": FieldMeta( - name="field2", - dtype=torch.int64, - shape=(3,), - _custom_backend_meta={"backend": "int64"} + name="field2", dtype=torch.int64, shape=(3,), _custom_backend_meta={"backend": "int64"} ), } samples = [ @@ -819,11 +798,11 @@ def test_batch_meta_select_samples_all_indices(self): """Example: Select all samples using complete index list.""" fields = { "test_field": FieldMeta( - name="test_field", - dtype=torch.float32, - shape=(2,), + name="test_field", + dtype=torch.float32, + shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME, - _custom_backend_meta={"dtype": torch.float32} + _custom_backend_meta={"dtype": torch.float32}, ) } samples = [ @@ -850,11 +829,11 @@ def test_batch_meta_select_samples_single_sample(self): """Example: Select a single sample from batch.""" fields = { "test_field": FieldMeta( - name="test_field", - dtype=torch.float32, - shape=(2,), + name="test_field", + dtype=torch.float32, + shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME, - _custom_backend_meta={"dtype": torch.float32} + _custom_backend_meta={"dtype": torch.float32}, ) } samples = [ @@ -877,11 +856,11 @@ def test_batch_meta_select_samples_empty_list(self): """Example: Select with empty list returns empty batch.""" fields = { "test_field": FieldMeta( - name="test_field", - dtype=torch.float32, - shape=(2,), + name="test_field", + dtype=torch.float32, + shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME, - _custom_backend_meta={"dtype": torch.float32} + _custom_backend_meta={"dtype": torch.float32}, ) } samples = [ @@ -904,11 +883,11 @@ def test_batch_meta_select_samples_reverse_order(self): """Example: Select samples in reverse order.""" fields = { "test_field": FieldMeta( - name="test_field", - dtype=torch.float32, - shape=(2,), + name="test_field", + dtype=torch.float32, + shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME, - _custom_backend_meta={"dtype": torch.float32} + _custom_backend_meta={"dtype": torch.float32}, ) } samples = [ @@ -936,11 +915,11 @@ def test_batch_meta_select_samples_with_extra_info(self): """Example: Select samples preserves all extra info types.""" fields = { "test_field": FieldMeta( - name="test_field", - dtype=torch.float32, - shape=(2,), + name="test_field", + dtype=torch.float32, + shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME, - _custom_backend_meta={"dtype": torch.float32} + _custom_backend_meta={"dtype": torch.float32}, ) } samples = [ @@ -1272,11 +1251,11 @@ def test_batch_meta_chunk_with_more_chunks_than_samples(self): """Example: Chunking when chunks > samples produces empty chunks.""" fields = { "test_field": FieldMeta( - name="test_field", - dtype=torch.float32, - shape=(2,), + name="test_field", + dtype=torch.float32, + shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME, - _custom_backend_meta={"dtype": torch.float32} + _custom_backend_meta={"dtype": torch.float32}, ) } samples = [ @@ -1306,16 +1285,20 @@ def test_batch_meta_concat_with_empty_batches(self): """Example: Concat handles empty batches gracefully.""" fields = { "test_field": FieldMeta( - name="test_field", - dtype=torch.float32, - shape=(2,), + name="test_field", + dtype=torch.float32, + shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME, - _custom_backend_meta={"dtype": torch.float32} + _custom_backend_meta={"dtype": torch.float32}, ) } batch1 = BatchMeta(samples=[]) - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"sample_id": 0})]) + batch2 = BatchMeta( + samples=[ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"sample_id": 0}) + ] + ) batch3 = BatchMeta(samples=[]) # Empty batches are filtered out diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 763f3f49..ff2c4a9e 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -1293,6 +1293,9 @@ def generate_batch_meta( if mode not in ["fetch", "insert", "force_fetch"]: raise ValueError(f"Invalid mode: {mode}") + custom_meta = partition.get_custom_meta(batch_global_indexes) + custom_backend_meta = partition.get_field_custom_backend_meta(batch_global_indexes, data_fields) + # Generate sample metadata samples = [] for global_index in batch_global_indexes: @@ -1322,11 +1325,16 @@ def generate_batch_meta( dtype = None shape = None + backend_meta = {} + if global_index in custom_backend_meta and field_name in custom_backend_meta[global_index]: + backend_meta = copy.deepcopy(custom_backend_meta[global_index][field_name]) + fields[field_name] = FieldMeta( name=field_name, dtype=dtype, shape=shape, production_status=production_status, + _custom_backend_meta=backend_meta, ) sample = SampleMeta( @@ -1336,12 +1344,8 @@ def generate_batch_meta( ) samples.append(sample) - custom_meta = partition.get_custom_meta(batch_global_indexes) - custom_backend_meta = partition.get_field_custom_backend_meta(batch_global_indexes, data_fields) - batch_meta = BatchMeta(samples=samples) batch_meta.update_custom_meta(custom_meta) - batch_meta._custom_backend_meta.update(custom_backend_meta) return batch_meta def clear_partition(self, partition_id: str, clear_consumption: bool = True): diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index b989d521..8f2f055f 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -144,7 +144,7 @@ def select_fields(self, field_names: list[str]) -> "SampleMeta": fields=selected_fields, partition_id=self.partition_id, global_index=self.global_index, - custom_meta=copy.deepcopy(self.custom_meta) + custom_meta=copy.deepcopy(self.custom_meta), ) return selected_sample_meta @@ -350,7 +350,7 @@ def update_custom_meta(self, new_meta: dict[int, dict[str, Any]]): f"Trying to update custom_meta with non-exist global_indexes! {non_exist_global_indexes} " f"do not exist in this batch." ) - + for global_index, meta_dict in new_meta.items(): for sample in self.samples: if sample.global_index == global_index: @@ -444,7 +444,6 @@ def __len__(self) -> int: def __getitem__(self, item): if isinstance(item, int | np.integer): sample_meta = self.samples[item] if self.samples else [] - global_idx = self.global_indexes[item] return BatchMeta( samples=[sample_meta], @@ -525,14 +524,14 @@ def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta": """ if not data: logger.warning("Try to concat empty BatchMeta chunks. Returning empty BatchMeta.") - return BatchMeta(samples=[], extra_info={}, custom_meta={}, _custom_backend_meta={}) + return BatchMeta(samples=[], extra_info={}) # skip empty chunks data = [chunk for chunk in data if chunk and len(chunk.samples) > 0] if len(data) == 0: logger.warning("No valid BatchMeta chunks to concatenate. Returning empty BatchMeta.") - return BatchMeta(samples=[], extra_info={}, custom_meta={}, _custom_backend_meta={}) + return BatchMeta(samples=[], extra_info={}) if validate: base_fields = data[0].field_names diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 352b3adc..ec6b296e 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -523,15 +523,14 @@ def _get_shape_type_custom_backend_meta_list(metadata: BatchMeta): shapes = [] dtypes = [] custom_backend_meta_list = [] - all_custom_backend_meta = copy.deepcopy(metadata._custom_backend_meta) for field_name in sorted(metadata.field_names): for index in range(len(metadata)): field = metadata.samples[index].get_field_by_name(field_name) assert field is not None, f"Field {field_name} not found in sample {index}" shapes.append(field.shape) dtypes.append(field.dtype) - global_index = metadata.global_indexes[index] - custom_backend_meta_list.append(all_custom_backend_meta.get(global_index, {}).get(field_name, None)) + backend_meta = field._custom_backend_meta + custom_backend_meta_list.append(copy.deepcopy(backend_meta) if backend_meta else None) return shapes, dtypes, custom_backend_meta_list async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: From f0bdd9aadb9f3ccd1ae0c745a669bb29bec48b87 Mon Sep 17 00:00:00 2001 From: Evelynn-V Date: Tue, 10 Feb 2026 16:01:02 +0800 Subject: [PATCH 3/4] fix test_kv_storage_manager Signed-off-by: Evelynn-V --- tests/test_kv_storage_manager.py | 26 ++++++++++++++++++++++++-- tests/test_metadata.py | 2 +- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/tests/test_kv_storage_manager.py b/tests/test_kv_storage_manager.py index 083e16d6..95b5c019 100644 --- a/tests/test_kv_storage_manager.py +++ b/tests/test_kv_storage_manager.py @@ -202,7 +202,18 @@ def test_get_shape_type_custom_backend_meta_list_with_custom_meta(test_data): 10: {"text": {"key7": "value7"}, "label": {"key8": "value8"}, "mask": {"key9": "value9"}}, } metadata = test_data["metadata"] - metadata._custom_backend_meta.update(custom_backend_meta) + + for global_idx, field_metas in custom_backend_meta.items(): + template = SampleMeta( + partition_id="", + global_index=global_idx, + fields={ + fname: FieldMeta(name=fname, dtype=None, shape=None, _custom_backend_meta=meta) + for fname, meta in field_metas.items() + }, + custom_meta={} + ) + next((s for s in metadata.samples if s.global_index == global_idx), None).union(template, validate=False) shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata) @@ -230,7 +241,18 @@ def test_get_shape_type_custom_backend_meta_list_with_partial_custom_meta(test_d 10: {"label": {"key2": "value2"}, "mask": {"key3": "value3"}}, # label and mask only } metadata = test_data["metadata"] - metadata._custom_backend_meta.update(custom_backend_meta) + + for global_idx, field_metas in custom_backend_meta.items(): + template = SampleMeta( + partition_id="", + global_index=global_idx, + fields={ + fname: FieldMeta(name=fname, dtype=None, shape=None, _custom_backend_meta=meta) + for fname, meta in field_metas.items() + }, + custom_meta={} + ) + next((s for s in metadata.samples if s.global_index == global_idx), None).union(template, validate=False) shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 3a1b115e..adc48468 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -25,7 +25,7 @@ # Setup path parent_dir = Path(__file__).resolve().parent.parent -sys.path.insert(0, str(parent_dir)) +sys.path.append(str(parent_dir)) from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 From f587529025db7acf750672d6725b94ac131d78f8 Mon Sep 17 00:00:00 2001 From: Evelynn-V Date: Tue, 10 Feb 2026 16:04:05 +0800 Subject: [PATCH 4/4] fix test_kv_storage_manager Signed-off-by: Evelynn-V --- tests/test_kv_storage_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_kv_storage_manager.py b/tests/test_kv_storage_manager.py index 95b5c019..62e40137 100644 --- a/tests/test_kv_storage_manager.py +++ b/tests/test_kv_storage_manager.py @@ -211,7 +211,7 @@ def test_get_shape_type_custom_backend_meta_list_with_custom_meta(test_data): fname: FieldMeta(name=fname, dtype=None, shape=None, _custom_backend_meta=meta) for fname, meta in field_metas.items() }, - custom_meta={} + custom_meta={}, ) next((s for s in metadata.samples if s.global_index == global_idx), None).union(template, validate=False) @@ -250,7 +250,7 @@ def test_get_shape_type_custom_backend_meta_list_with_partial_custom_meta(test_d fname: FieldMeta(name=fname, dtype=None, shape=None, _custom_backend_meta=meta) for fname, meta in field_metas.items() }, - custom_meta={} + custom_meta={}, ) next((s for s in metadata.samples if s.global_index == global_idx), None).union(template, validate=False)