diff --git a/tests/test_kv_storage_manager.py b/tests/test_kv_storage_manager.py index 083e16d6..62e40137 100644 --- a/tests/test_kv_storage_manager.py +++ b/tests/test_kv_storage_manager.py @@ -202,7 +202,18 @@ def test_get_shape_type_custom_backend_meta_list_with_custom_meta(test_data): 10: {"text": {"key7": "value7"}, "label": {"key8": "value8"}, "mask": {"key9": "value9"}}, } metadata = test_data["metadata"] - metadata._custom_backend_meta.update(custom_backend_meta) + + for global_idx, field_metas in custom_backend_meta.items(): + template = SampleMeta( + partition_id="", + global_index=global_idx, + fields={ + fname: FieldMeta(name=fname, dtype=None, shape=None, _custom_backend_meta=meta) + for fname, meta in field_metas.items() + }, + custom_meta={}, + ) + next((s for s in metadata.samples if s.global_index == global_idx), None).union(template, validate=False) shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata) @@ -230,7 +241,18 @@ def test_get_shape_type_custom_backend_meta_list_with_partial_custom_meta(test_d 10: {"label": {"key2": "value2"}, "mask": {"key3": "value3"}}, # label and mask only } metadata = test_data["metadata"] - metadata._custom_backend_meta.update(custom_backend_meta) + + for global_idx, field_metas in custom_backend_meta.items(): + template = SampleMeta( + partition_id="", + global_index=global_idx, + fields={ + fname: FieldMeta(name=fname, dtype=None, shape=None, _custom_backend_meta=meta) + for fname, meta in field_metas.items() + }, + custom_meta={}, + ) + next((s for s in metadata.samples if s.global_index == global_idx), None).union(template, validate=False) shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 2bbf40c6..adc48468 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -37,15 +37,36 @@ class TestFieldMeta: def 
test_field_meta_is_ready(self): """Test the is_ready property based on production status.""" field_ready = FieldMeta( - name="test_field", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", + dtype=torch.float32, + shape=(2, 3), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"test_field": {"dtype": torch.float32}}, ) assert field_ready.is_ready is True field_not_ready = FieldMeta( - name="test_field", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.NOT_PRODUCED + name="test_field", + dtype=torch.float32, + shape=(2, 3), + production_status=ProductionStatus.NOT_PRODUCED, + _custom_backend_meta={"test_field": {"dtype": torch.float32}}, ) assert field_not_ready.is_ready is False + def test_filed_meta_complete_integrity(self): + """Test that FieldMeta keeps the name and _custom_backend_meta it was constructed with.""" + field_complete = FieldMeta( + name="test_field", + dtype=torch.float32, + shape=(2, 3), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"test_field": {"dtype": torch.float32}}, + ) + + assert field_complete.name == "test_field" + assert field_complete._custom_backend_meta["test_field"]["dtype"] == torch.float32 + class TestSampleMeta: """SampleMeta learning examples.""" @@ -57,14 +78,24 @@ def test_sample_meta_union(self): "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), } - sample1 = SampleMeta(partition_id="partition_0", global_index=0, fields=fields1) + sample1 = SampleMeta( + partition_id="partition_0", + global_index=0, + fields=fields1, + custom_meta={"fields": "fields1", "global_index": 0}, + ) # Create second sample with additional fields fields2 = { "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,)), } - sample2 = SampleMeta(partition_id="partition_0",
global_index=0, fields=fields2) + sample2 = SampleMeta( + partition_id="partition_0", + global_index=0, + fields=fields2, + custom_meta={"fields": "fields2", "global_index": 0}, + ) # Union samples result = sample1.union(sample2) @@ -74,18 +105,34 @@ def test_sample_meta_union(self): assert "field2" in result.fields # From sample2 assert "field3" in result.fields + assert result.custom_meta["fields"] == "fields2" + assert result.custom_meta["global_index"] == 0 + def test_sample_meta_union_validation_error(self): """Example: Union validation catches mismatched global indexes.""" sample1 = SampleMeta( partition_id="partition_0", global_index=0, - fields={"field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,))}, + fields={ + "field1": FieldMeta( + name="field1", + dtype=torch.float32, + shape=(2,), + _custom_backend_meta={"backend_type": "float32_tensor"}, + ) + }, + custom_meta={"source": "dataset_A"}, ) sample2 = SampleMeta( partition_id="partition_0", global_index=1, # Different global index - fields={"field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,))}, + fields={ + "field2": FieldMeta( + name="field2", dtype=torch.int64, shape=(3,), _custom_backend_meta={"backend_type": "int64_tensor"} + ) + }, + custom_meta={"source": "dataset_B"}, ) with pytest.raises(ValueError) as exc_info: @@ -96,30 +143,56 @@ def test_sample_meta_add_fields(self): """Example: Add new fields to a sample.""" initial_fields = { "field1": FieldMeta( - name="field1", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="field1", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"field1": {"dtype": torch.float32}}, ) } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=initial_fields) + sample = SampleMeta( + partition_id="partition_0", + global_index=0, + fields=initial_fields, + custom_meta={"fields": "fields1", "global_index": 0}, + ) 
new_fields = { "field2": FieldMeta( - name="field2", dtype=torch.int64, shape=(3,), production_status=ProductionStatus.READY_FOR_CONSUME + name="field2", + dtype=torch.int64, + shape=(3,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"field2": {"dtype": torch.int64}}, ) } sample.add_fields(new_fields) assert "field1" in sample.fields assert "field2" in sample.fields + assert sample.fields["field2"]._custom_backend_meta["field2"]["dtype"] == torch.int64 + assert sample.custom_meta["fields"] == "fields1" assert sample.is_ready is True def test_sample_meta_select_fields(self): """Example: Select specific fields from a sample.""" fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,)), + "field1": FieldMeta( + name="field1", dtype=torch.float32, shape=(2,), _custom_backend_meta={"backend_type": "float32_tensor"} + ), + "field2": FieldMeta( + name="field2", dtype=torch.int64, shape=(3,), _custom_backend_meta={"backend_type": "int64_tensor"} + ), + "field3": FieldMeta( + name="field3", dtype=torch.bool, shape=(4,), _custom_backend_meta={"backend_type": "bool_tensor"} + ), } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=fields) + sample = SampleMeta( + partition_id="partition_0", + global_index=0, + fields=fields, + custom_meta={"source": "dataset_X", "priority": "high"}, + ) # Select only field1 and field3 selected_sample = sample.select_fields(["field1", "field3"]) @@ -127,6 +200,9 @@ def test_sample_meta_select_fields(self): assert "field1" in selected_sample.fields assert "field3" in selected_sample.fields assert "field2" not in selected_sample.fields + # Verify custom_backend_meta is preserved in selected fields + assert selected_sample.fields["field1"]._custom_backend_meta["backend_type"] == "float32_tensor" + assert 
selected_sample.fields["field3"]._custom_backend_meta["backend_type"] == "bool_tensor" # Original sample is unchanged assert len(sample.fields) == 3 # Selected sample has correct metadata @@ -134,14 +210,21 @@ def test_sample_meta_select_fields(self): assert selected_sample.fields["field1"].shape == (2,) assert selected_sample.global_index == 0 assert selected_sample.partition_id == "partition_0" + # Verify custom_meta is deep copied correctly + assert selected_sample.custom_meta == {"source": "dataset_X", "priority": "high"} + assert selected_sample.custom_meta is not sample.custom_meta # Ensure deep copy def test_sample_meta_select_fields_with_nonexistent_fields(self): """Example: Select fields ignores non-existent field names.""" fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), + "field1": FieldMeta( + name="field1", dtype=torch.float32, shape=(2,), _custom_backend_meta={"backend_type": "float32_tensor"} + ), + "field2": FieldMeta( + name="field2", dtype=torch.int64, shape=(3,), _custom_backend_meta={"backend_type": "int64_tensor"} + ), } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=fields) + sample = SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"valid": True}) # Try to select a field that doesn't exist selected_sample = sample.select_fields(["field1", "nonexistent_field"]) @@ -150,14 +233,24 @@ def test_sample_meta_select_fields_with_nonexistent_fields(self): assert "field1" in selected_sample.fields assert "nonexistent_field" not in selected_sample.fields assert "field2" not in selected_sample.fields + # Verify custom_backend_meta is preserved for selected field + assert selected_sample.fields["field1"]._custom_backend_meta["backend_type"] == "float32_tensor" + # Verify custom_meta is preserved + assert selected_sample.custom_meta == {"valid": True} def 
test_sample_meta_select_fields_empty_list(self): """Example: Select with empty field list returns sample with no fields.""" fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), + "field1": FieldMeta( + name="field1", dtype=torch.float32, shape=(2,), _custom_backend_meta={"backend_type": "float32_tensor"} + ), + "field2": FieldMeta( + name="field2", dtype=torch.int64, shape=(3,), _custom_backend_meta={"backend_type": "int64_tensor"} + ), } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=fields) + sample = SampleMeta( + partition_id="partition_0", global_index=0, fields=fields, custom_meta={"metadata_version": 2} + ) # Select with empty list selected_sample = sample.select_fields([]) @@ -165,6 +258,8 @@ def test_sample_meta_select_fields_empty_list(self): assert len(selected_sample.fields) == 0 assert selected_sample.global_index == 0 assert selected_sample.partition_id == "partition_0" + # Verify custom_meta is preserved even with no fields + assert selected_sample.custom_meta == {"metadata_version": 2} class TestBatchMeta: @@ -172,17 +267,27 @@ class TestBatchMeta: def test_batch_meta_chunk(self): """Example: Split a batch into multiple chunks.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + # Initialize samples with custom_meta at SampleMeta level and _custom_backend_meta at FieldMeta level + samples = [] + for i in range(10): + fields = { + "test_field": FieldMeta( + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32}, # Moved to FieldMeta + ) + } + sample = SampleMeta( + partition_id="partition_0", + global_index=i, + fields=fields, + custom_meta={"uid": i}, # Moved to SampleMeta ) - } - samples = 
[SampleMeta(partition_id="partition_0", global_index=i, fields=fields) for i in range(10)] - batch = BatchMeta( - samples=samples, - custom_meta={i: {"uid": i} for i in range(10)}, - _custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in range(10)}, - ) + samples.append(sample) + + batch = BatchMeta(samples=samples) # Removed custom_meta/_custom_backend_meta params # Chunk into 3 parts chunks = batch.chunk(3) @@ -192,35 +297,42 @@ def test_batch_meta_chunk(self): assert len(chunks[1]) == 3 assert len(chunks[2]) == 3 - # validate custom_meta is chunked - assert 0 in chunks[0].custom_meta - assert 1 in chunks[0].custom_meta - assert 2 in chunks[0].custom_meta - assert 3 in chunks[0].custom_meta - assert 4 not in chunks[0].custom_meta - assert 4 in chunks[1].custom_meta - - # validate _custom_backend_meta is chunked - assert 0 in chunks[0]._custom_backend_meta - assert 1 in chunks[0]._custom_backend_meta - assert 2 in chunks[0]._custom_backend_meta - assert 3 in chunks[0]._custom_backend_meta - assert 4 not in chunks[0]._custom_backend_meta - assert 4 in chunks[1]._custom_backend_meta + assert 0 in chunks[0].global_indexes + assert 1 in chunks[0].global_indexes + assert 2 in chunks[0].global_indexes + assert 3 in chunks[0].global_indexes + assert 4 not in chunks[0].global_indexes + assert 4 in chunks[1].global_indexes + + assert chunks[0].samples[0].custom_meta["uid"] == 0 + assert chunks[0].samples[1].custom_meta["uid"] == 1 + assert chunks[0].samples[2].custom_meta["uid"] == 2 + assert chunks[0].samples[3].custom_meta["uid"] == 3 + assert chunks[1].samples[0].custom_meta["uid"] == 4 + + # Validate _custom_backend_meta is preserved in fields (minimal change: check via fields) + assert chunks[0].samples[0].fields["test_field"]._custom_backend_meta["dtype"] == torch.float32 + assert chunks[1].samples[0].fields["test_field"]._custom_backend_meta["dtype"] == torch.float32 def test_batch_meta_chunk_by_partition(self): - """Example: Split a batch into 
multiple chunks.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + """Example: Split a batch into multiple chunks by partition.""" + samples = [] + for i in range(10): + fields = { + "test_field": FieldMeta( + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32}, + ) + } + sample = SampleMeta( + partition_id=f"partition_{i % 4}", global_index=i + 10, fields=fields, custom_meta={"uid": i + 10} ) - } - samples = [SampleMeta(partition_id=f"partition_{i % 4}", global_index=i + 10, fields=fields) for i in range(10)] - batch = BatchMeta( - samples=samples, - custom_meta={i + 10: {"uid": i + 10} for i in range(10)}, - _custom_backend_meta={i + 10: {"test_field": {"dtype": torch.float32}} for i in range(10)}, - ) + samples.append(sample) + + batch = BatchMeta(samples=samples) # Removed custom_meta/_custom_backend_meta params # Chunk according to partition_id chunks = batch.chunk_by_partition() @@ -239,19 +351,15 @@ def test_batch_meta_chunk_by_partition(self): assert chunks[3].partition_ids == ["partition_3", "partition_3"] assert chunks[3].global_indexes == [13, 17] - # validate custom_meta is chunked - assert 10 in chunks[0].custom_meta - assert 14 in chunks[0].custom_meta - assert 18 in chunks[0].custom_meta - assert 11 not in chunks[0].custom_meta - assert 11 in chunks[1].custom_meta + # Validate custom_meta preserved in samples + assert chunks[0].samples[0].custom_meta == {"uid": 10} + assert chunks[0].samples[1].custom_meta == {"uid": 14} + assert chunks[0].samples[2].custom_meta == {"uid": 18} + assert chunks[1].samples[0].custom_meta == {"uid": 11} - # validate _custom_backend_meta is chunked - assert 10 in chunks[0]._custom_backend_meta - assert 14 in chunks[0]._custom_backend_meta - assert 18 in chunks[0]._custom_backend_meta - assert 11 not in 
chunks[0]._custom_backend_meta - assert 11 in chunks[1]._custom_backend_meta + # Validate _custom_backend_meta preserved in fields + assert chunks[0].samples[0].fields["test_field"]._custom_backend_meta["dtype"] == torch.float32 + assert chunks[1].samples[0].fields["test_field"]._custom_backend_meta["dtype"] == torch.float32 def test_batch_meta_init_validation_error_different_field_names(self): """Example: Init validation catches samples with different field names.""" @@ -272,27 +380,27 @@ def test_batch_meta_concat(self): """Example: Concatenate multiple batches.""" fields = { "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32}, ) } - # Create two batches + # Create two batches with samples containing custom_meta batch1 = BatchMeta( samples=[ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ], - custom_meta={i: {"uid": i} for i in [0, 1]}, - _custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in [0, 1]}, + SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"uid": 0}), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields, custom_meta={"uid": 1}), + ] ) batch2 = BatchMeta( samples=[ - SampleMeta(partition_id="partition_0", global_index=2, fields=fields), - SampleMeta(partition_id="partition_0", global_index=3, fields=fields), - ], - custom_meta={i: {"uid": i} for i in [2, 3]}, - _custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in [2, 3]}, + SampleMeta(partition_id="partition_0", global_index=2, fields=fields, custom_meta={"uid": 2}), + SampleMeta(partition_id="partition_0", global_index=3, fields=fields, custom_meta={"uid": 3}), + ] ) # 
Concatenate batches @@ -300,8 +408,13 @@ def test_batch_meta_concat(self): assert len(result) == 4 assert result.global_indexes == [0, 1, 2, 3] - assert result.custom_meta == {i: {"uid": i} for i in [0, 1, 2, 3]} - assert result._custom_backend_meta == {i: {"test_field": {"dtype": torch.float32}} for i in [0, 1, 2, 3]} + # Validate custom_meta preserved via samples (minimal change) + assert result.samples[0].custom_meta == {"uid": 0} + assert result.samples[1].custom_meta == {"uid": 1} + assert result.samples[2].custom_meta == {"uid": 2} + assert result.samples[3].custom_meta == {"uid": 3} + # Validate _custom_backend_meta preserved via fields + assert result.samples[0].fields["test_field"]._custom_backend_meta["dtype"] == torch.float32 def test_batch_meta_concat_with_tensor_extra_info(self): """Example: Concat handles tensor extra_info by concatenating along dim=0.""" @@ -396,33 +509,36 @@ def test_batch_meta_concat_with_mixed_types(self): def test_batch_meta_union(self): """Example: Union two batches with matching global indexes.""" fields1 = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), + "field1": FieldMeta( + name="field1", dtype=torch.float32, shape=(2,), _custom_backend_meta={"backend": "float32"} + ), + "field2": FieldMeta( + name="field2", dtype=torch.int64, shape=(3,), _custom_backend_meta={"backend": "int64"} + ), } fields2 = { - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,)), + "field2": FieldMeta( + name="field2", + dtype=torch.int64, + shape=(3,), + _custom_backend_meta={"backend": "int64", "compression": "lz4"}, + ), + "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,), _custom_backend_meta={"backend": "bool"}), } batch1 = BatchMeta( samples=[ - SampleMeta(partition_id="partition_0", global_index=8, fields=fields1), - SampleMeta(partition_id="partition_0", 
global_index=9, fields=fields1), - ], - _custom_backend_meta={ - i: {"field1": {"dtype": torch.float32}, "field2": {"dtype": torch.int64}} for i in [8, 9] - }, + SampleMeta(partition_id="partition_0", global_index=8, fields=fields1, custom_meta={"source": "A"}), + SampleMeta(partition_id="partition_0", global_index=9, fields=fields1, custom_meta={"source": "A"}), + ] ) batch1.extra_info["info1"] = "value1" batch2 = BatchMeta( samples=[ - SampleMeta(partition_id="partition_0", global_index=8, fields=fields2), - SampleMeta(partition_id="partition_0", global_index=9, fields=fields2), - ], - _custom_backend_meta={ - i: {"field2": {"dtype": torch.int64}, "field3": {"dtype": torch.bool}} for i in [8, 9] - }, + SampleMeta(partition_id="partition_0", global_index=8, fields=fields2, custom_meta={"source": "B"}), + SampleMeta(partition_id="partition_0", global_index=9, fields=fields2, custom_meta={"source": "B"}), + ] ) batch2.extra_info["info2"] = "value2" @@ -434,15 +550,14 @@ def test_batch_meta_union(self): assert "field1" in sample.fields assert "field2" in sample.fields assert "field3" in sample.fields + # Verify _custom_backend_meta preserved correctly + assert sample.fields["field2"]._custom_backend_meta["backend"] == "int64" + assert sample.fields["field2"]._custom_backend_meta.get("compression") == "lz4" # Extra info is merged assert result.extra_info["info1"] == "value1" assert result.extra_info["info2"] == "value2" - - # _custom_backend_meta is merged - assert result._custom_backend_meta == { - i: {"field1": {"dtype": torch.float32}, "field2": {"dtype": torch.int64}, "field3": {"dtype": torch.bool}} - for i in [8, 9] - } + # Verify custom_meta merged correctly (last wins) + assert result.samples[0].custom_meta["source"] == "B" def test_batch_meta_union_validation(self): """Example: Union validation catches mismatched conditions.""" @@ -469,9 +584,9 @@ def test_batch_meta_reorder(self): ) } samples = [ - SampleMeta(partition_id="partition_0", global_index=4, 
fields=fields), - SampleMeta(partition_id="partition_0", global_index=5, fields=fields), - SampleMeta(partition_id="partition_0", global_index=6, fields=fields), + SampleMeta(partition_id="partition_0", global_index=4, fields=fields, custom_meta={"pos": 0}), + SampleMeta(partition_id="partition_0", global_index=5, fields=fields, custom_meta={"pos": 1}), + SampleMeta(partition_id="partition_0", global_index=6, fields=fields, custom_meta={"pos": 2}), ] batch = BatchMeta(samples=samples) @@ -483,6 +598,10 @@ def test_batch_meta_reorder(self): assert batch.samples[0].batch_index == 0 assert batch.samples[1].batch_index == 1 assert batch.samples[2].batch_index == 2 + # custom_meta preserved correctly + assert batch.samples[0].custom_meta == {"pos": 2} + assert batch.samples[1].custom_meta == {"pos": 0} + assert batch.samples[2].custom_meta == {"pos": 1} def test_batch_meta_add_fields(self): """Example: Add fields from TensorDict to all samples.""" @@ -507,30 +626,26 @@ def test_batch_meta_add_fields(self): assert "new_field1" in sample.fields assert "new_field2" in sample.fields assert sample.is_ready is True + # Verify new fields have default _custom_backend_meta (empty dict) + assert sample.fields["new_field1"]._custom_backend_meta == {} + assert sample.fields["new_field2"]._custom_backend_meta == {} def test_batch_meta_select_fields(self): """Example: Select specific fields from all samples in a batch.""" fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,)), + "field1": FieldMeta( + name="field1", dtype=torch.float32, shape=(2,), _custom_backend_meta={"precision": "fp32"} + ), + "field2": FieldMeta( + name="field2", dtype=torch.int64, shape=(3,), _custom_backend_meta={"encoding": "varint"} + ), + "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,), _custom_backend_meta={"packing": "bit"}), } 
samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), + SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"version": 1}), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields, custom_meta={"version": 1}), ] - batch = BatchMeta( - samples=samples, - extra_info={"test_key": "test_value"}, - _custom_backend_meta={ - i: { - "field1": {"dtype": torch.float32}, - "field2": {"dtype": torch.int64}, - "field3": {"dtype": torch.bool}, - } - for i in [0, 1] - }, - ) + batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"}) # Select only field1 and field3 selected_batch = batch.select_fields(["field1", "field3"]) @@ -541,20 +656,16 @@ def test_batch_meta_select_fields(self): assert "field1" in sample.fields assert "field3" in sample.fields assert "field2" not in sample.fields + # Verify _custom_backend_meta preserved for selected fields + assert sample.fields["field1"]._custom_backend_meta["precision"] == "fp32" + assert sample.fields["field3"]._custom_backend_meta["packing"] == "bit" # Original batch is unchanged assert len(batch.samples[0].fields) == 3 # Extra info is preserved assert selected_batch.extra_info["test_key"] == "test_value" - # Global indexes are preserved + # Global indexes and custom_meta preserved assert selected_batch.global_indexes == [0, 1] - - # _custom_backend_meta is selected - assert "field1" in selected_batch._custom_backend_meta[0] - assert "field2" not in selected_batch._custom_backend_meta[0] - assert "field3" in selected_batch._custom_backend_meta[0] - assert "field1" in selected_batch._custom_backend_meta[1] - assert "field2" not in selected_batch._custom_backend_meta[1] - assert "field3" in selected_batch._custom_backend_meta[1] + assert selected_batch.samples[0].custom_meta == {"version": 1} def test_batch_meta_select_fields_with_nonexistent_fields(self): """Example: Select 
fields ignores non-existent field names in batch.""" @@ -618,10 +729,18 @@ def test_batch_meta_select_fields_preserves_field_metadata(self): """Example: Selected fields preserve their original metadata.""" fields = { "field1": FieldMeta( - name="field1", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.READY_FOR_CONSUME + name="field1", + dtype=torch.float32, + shape=(2, 3), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"source": "sensor_a"}, ), "field2": FieldMeta( - name="field2", dtype=torch.int64, shape=(5,), production_status=ProductionStatus.NOT_PRODUCED + name="field2", + dtype=torch.int64, + shape=(5,), + production_status=ProductionStatus.NOT_PRODUCED, + _custom_backend_meta={"source": "sensor_b"}, ), } samples = [ @@ -637,23 +756,28 @@ def test_batch_meta_select_fields_preserves_field_metadata(self): assert selected_field.shape == (2, 3) assert selected_field.production_status == ProductionStatus.READY_FOR_CONSUME assert selected_field.name == "field1" + assert selected_field._custom_backend_meta["source"] == "sensor_a" def test_batch_meta_select_samples(self): """Example: Select specific samples from a batch.""" fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), + "field1": FieldMeta( + name="field1", dtype=torch.float32, shape=(2,), _custom_backend_meta={"backend": "float32"} + ), + "field2": FieldMeta( + name="field2", dtype=torch.int64, shape=(3,), _custom_backend_meta={"backend": "int64"} + ), } samples = [ - SampleMeta(partition_id="partition_0", global_index=4, fields=fields), - SampleMeta(partition_id="partition_0", global_index=5, fields=fields), - SampleMeta(partition_id="partition_0", global_index=6, fields=fields), - SampleMeta(partition_id="partition_0", global_index=7, fields=fields), + SampleMeta(partition_id="partition_0", global_index=4, fields=fields, custom_meta={"sample_id": 4}), + 
SampleMeta(partition_id="partition_0", global_index=5, fields=fields, custom_meta={"sample_id": 5}), + SampleMeta(partition_id="partition_0", global_index=6, fields=fields, custom_meta={"sample_id": 6}), + SampleMeta(partition_id="partition_0", global_index=7, fields=fields, custom_meta={"sample_id": 7}), ] batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"}) # Select samples at indices [0, 2] - selected_batch = batch.select_samples([0, 2]) # This will select the first two samples with global_index=4/5 + selected_batch = batch.select_samples([0, 2]) # Check number of samples assert len(selected_batch) == 2 @@ -663,6 +787,8 @@ def test_batch_meta_select_samples(self): for sample in selected_batch.samples: assert "field1" in sample.fields assert "field2" in sample.fields + # MINIMAL CHANGE: Verify custom_meta preserved via get_all_custom_meta() + assert sample.global_index in selected_batch.get_all_custom_meta() # Original batch is unchanged assert len(batch) == 4 # Extra info is preserved @@ -672,13 +798,17 @@ def test_batch_meta_select_samples_all_indices(self): """Example: Select all samples using complete index list.""" fields = { "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32}, ) } samples = [ - SampleMeta(partition_id="partition_0", global_index=4, fields=fields), - SampleMeta(partition_id="partition_0", global_index=5, fields=fields), - SampleMeta(partition_id="partition_0", global_index=6, fields=fields), + SampleMeta(partition_id="partition_0", global_index=4, fields=fields, custom_meta={"sample_id": 4}), + SampleMeta(partition_id="partition_0", global_index=5, fields=fields, custom_meta={"sample_id": 5}), + SampleMeta(partition_id="partition_0", global_index=6, fields=fields, 
custom_meta={"sample_id": 6}), ] batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"}) @@ -688,6 +818,10 @@ def test_batch_meta_select_samples_all_indices(self): # All samples are selected assert len(selected_batch) == 3 assert selected_batch.global_indexes == [4, 5, 6] + # MINIMAL CHANGE: Verify all custom_meta preserved + assert 4 in selected_batch.get_all_custom_meta() + assert 5 in selected_batch.get_all_custom_meta() + assert 6 in selected_batch.get_all_custom_meta() # Extra info is preserved assert selected_batch.extra_info["test_key"] == "test_value" @@ -695,13 +829,17 @@ def test_batch_meta_select_samples_single_sample(self): """Example: Select a single sample from batch.""" fields = { "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32}, ) } samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - SampleMeta(partition_id="partition_0", global_index=2, fields=fields), + SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"sample_id": 0}), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields, custom_meta={"sample_id": 1}), + SampleMeta(partition_id="partition_0", global_index=2, fields=fields, custom_meta={"sample_id": 2}), ] batch = BatchMeta(samples=samples) @@ -710,18 +848,24 @@ def test_batch_meta_select_samples_single_sample(self): assert len(selected_batch) == 1 assert selected_batch.global_indexes == [1] + # MINIMAL CHANGE: Verify custom_meta preserved + assert 1 in selected_batch.get_all_custom_meta() assert selected_batch.samples[0].batch_index == 0 # New batch index def test_batch_meta_select_samples_empty_list(self): """Example: Select with 
empty list returns empty batch.""" fields = { "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32}, ) } samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), + SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"sample_id": 0}), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields, custom_meta={"sample_id": 1}), ] batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"}) @@ -732,18 +876,24 @@ def test_batch_meta_select_samples_empty_list(self): assert selected_batch.global_indexes == [] # Extra info is still preserved assert selected_batch.extra_info["test_key"] == "test_value" + # MINIMAL CHANGE: get_all_custom_meta returns empty dict for empty batch + assert selected_batch.get_all_custom_meta() == {} def test_batch_meta_select_samples_reverse_order(self): """Example: Select samples in reverse order.""" fields = { "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32}, ) } samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - SampleMeta(partition_id="partition_0", global_index=2, fields=fields), + SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"sample_id": 0}), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields, custom_meta={"sample_id": 1}), + 
SampleMeta(partition_id="partition_0", global_index=2, fields=fields, custom_meta={"sample_id": 2}), ] batch = BatchMeta(samples=samples) @@ -752,6 +902,10 @@ def test_batch_meta_select_samples_reverse_order(self): assert len(selected_batch) == 3 assert selected_batch.global_indexes == [2, 1, 0] + # MINIMAL CHANGE: Verify all custom_meta preserved in new order + assert 2 in selected_batch.get_all_custom_meta() + assert 1 in selected_batch.get_all_custom_meta() + assert 0 in selected_batch.get_all_custom_meta() # Batch indexes are re-assigned assert selected_batch.samples[0].global_index == 2 assert selected_batch.samples[1].global_index == 1 @@ -761,12 +915,16 @@ def test_batch_meta_select_samples_with_extra_info(self): """Example: Select samples preserves all extra info types.""" fields = { "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32}, ) } samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), + SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"sample_id": 0}), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields, custom_meta={"sample_id": 1}), ] batch = BatchMeta(samples=samples) @@ -784,6 +942,8 @@ def test_batch_meta_select_samples_with_extra_info(self): assert selected_batch.extra_info["string"] == "test_string" assert selected_batch.extra_info["number"] == 42 assert selected_batch.extra_info["list"] == [1, 2, 3] + # MINIMAL CHANGE: Verify custom_meta preserved + assert 0 in selected_batch.get_all_custom_meta() # ===================================================== # Custom Meta Tests @@ -1091,12 +1251,16 @@ def 
test_batch_meta_chunk_with_more_chunks_than_samples(self): """Example: Chunking when chunks > samples produces empty chunks.""" fields = { "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32}, ) } samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), + SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"sample_id": 0}), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields, custom_meta={"sample_id": 1}), ] batch = BatchMeta(samples=samples) @@ -1111,23 +1275,38 @@ def test_batch_meta_chunk_with_more_chunks_than_samples(self): assert len(chunks[2]) == 0 assert len(chunks[3]) == 0 assert len(chunks[4]) == 0 + # MINIMAL CHANGE: Verify custom_meta preserved in non-empty chunks + if len(chunks[0]) > 0: + assert 0 in chunks[0].get_all_custom_meta() + if len(chunks[1]) > 0: + assert 1 in chunks[1].get_all_custom_meta() def test_batch_meta_concat_with_empty_batches(self): """Example: Concat handles empty batches gracefully.""" fields = { "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + name="test_field", + dtype=torch.float32, + shape=(2,), + production_status=ProductionStatus.READY_FOR_CONSUME, + _custom_backend_meta={"dtype": torch.float32}, ) } batch1 = BatchMeta(samples=[]) - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) + batch2 = BatchMeta( + samples=[ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields, custom_meta={"sample_id": 0}) + ] + ) batch3 = BatchMeta(samples=[]) # Empty batches are filtered out result = 
BatchMeta.concat([batch1, batch2, batch3]) assert len(result) == 1 assert result.global_indexes == [0] + # MINIMAL CHANGE: Verify custom_meta preserved + assert 0 in result.get_all_custom_meta() def test_batch_meta_concat_validation_error(self): """Example: Concat validation catches field name mismatches.""" diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 763f3f49..ff2c4a9e 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -1293,6 +1293,9 @@ def generate_batch_meta( if mode not in ["fetch", "insert", "force_fetch"]: raise ValueError(f"Invalid mode: {mode}") + custom_meta = partition.get_custom_meta(batch_global_indexes) + custom_backend_meta = partition.get_field_custom_backend_meta(batch_global_indexes, data_fields) + # Generate sample metadata samples = [] for global_index in batch_global_indexes: @@ -1322,11 +1325,16 @@ def generate_batch_meta( dtype = None shape = None + backend_meta = {} + if global_index in custom_backend_meta and field_name in custom_backend_meta[global_index]: + backend_meta = copy.deepcopy(custom_backend_meta[global_index][field_name]) + fields[field_name] = FieldMeta( name=field_name, dtype=dtype, shape=shape, production_status=production_status, + _custom_backend_meta=backend_meta, ) sample = SampleMeta( @@ -1336,12 +1344,8 @@ def generate_batch_meta( ) samples.append(sample) - custom_meta = partition.get_custom_meta(batch_global_indexes) - custom_backend_meta = partition.get_field_custom_backend_meta(batch_global_indexes, data_fields) - batch_meta = BatchMeta(samples=samples) batch_meta.update_custom_meta(custom_meta) - batch_meta._custom_backend_meta.update(custom_backend_meta) return batch_meta def clear_partition(self, partition_id: str, clear_consumption: bool = True): diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 5e254047..8f2f055f 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -48,6 +48,7 @@ class FieldMeta: 
dtype: Optional[Any] # Data type (e.g., torch.float32, numpy.float32) shape: Optional[Any] # Data shape (e.g., torch.Size([3, 224, 224]), (3, 224, 224)) production_status: ProductionStatus = ProductionStatus.NOT_PRODUCED + _custom_backend_meta: dict[str, Any] = dataclasses.field(default_factory=dict) def __str__(self) -> str: return ( @@ -70,6 +71,7 @@ def from_dict(cls, data: dict) -> "FieldMeta": production_status=ProductionStatus(str(data["production_status"])) if isinstance(data["production_status"], int | str) else data["production_status"], + _custom_backend_meta=data.get("_custom_backend_meta", {}), ) @@ -80,6 +82,7 @@ class SampleMeta: partition_id: str # Partition id, used for data versioning global_index: int # Global row index, uniquely identifies a data sample fields: dict[str, FieldMeta] # Fields of interest for this sample + custom_meta: dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): """Initialize is_ready property based on field readiness""" @@ -137,11 +140,11 @@ def select_fields(self, field_names: list[str]) -> "SampleMeta": selected_fields = {name: self.fields[name] for name in field_names if name in self.fields} # construct new SampleMeta instance - # TODO(tianyi): (maybe) move _custom_backend_meta and custom_meta to FieldMeta level? 
selected_sample_meta = SampleMeta( fields=selected_fields, partition_id=self.partition_id, global_index=self.global_index, + custom_meta=copy.deepcopy(self.custom_meta), ) return selected_sample_meta @@ -168,6 +171,9 @@ def union(self, other: "SampleMeta", validate: bool = True) -> "SampleMeta": # Merge fields self.fields = _union_fields(self.fields, other.fields) + # Merge custom meta + self.custom_meta = {**self.custom_meta, **other.custom_meta} + # Update is_ready property object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values())) return self @@ -193,6 +199,7 @@ def from_dict(cls, data: dict) -> "SampleMeta": partition_id=data["partition_id"], global_index=data["global_index"], fields=fields, + custom_meta=data.get("custom_meta", {}), ) @@ -205,12 +212,6 @@ class BatchMeta: # external meta for non-sample level information extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) - # user-defined meta for each sample - custom_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) - - # internal meta for different storage backends in per-sample per-field level - _custom_backend_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) - def __post_init__(self): """Initialize all computed properties during initialization""" self.samples = copy.deepcopy(self.samples) @@ -235,16 +236,7 @@ def __post_init__(self): object.__setattr__(self, "_partition_ids", [sample.partition_id for sample in self.samples]) - # filter custom_meta and _custom_backend_meta - self.custom_meta = copy.deepcopy( - {k: self.custom_meta[k] for k in self.global_indexes if k in self.custom_meta} - ) - self._custom_backend_meta = copy.deepcopy( - {k: self._custom_backend_meta[k] for k in self.global_indexes if k in self._custom_backend_meta} - ) else: - self.custom_meta = {} - self._custom_backend_meta = {} object.__setattr__(self, "_global_indexes", []) object.__setattr__(self, "_field_names", []) 
object.__setattr__(self, "_partition_ids", []) @@ -321,7 +313,10 @@ def set_custom_meta(self, global_index: int, meta_dict: dict[str, Any]) -> None: if global_index not in self.global_indexes: raise ValueError(f"key {global_index} not found in global_indexes {self.global_indexes}.") - self.custom_meta[global_index] = copy.deepcopy(meta_dict) + for sample in self.samples: + if sample.global_index == global_index: + sample.custom_meta = copy.deepcopy(meta_dict) + break def get_all_custom_meta(self) -> dict[int, dict[str, Any]]: """ @@ -330,7 +325,7 @@ def get_all_custom_meta(self) -> dict[int, dict[str, Any]]: Returns: A deep copy of the custom_meta dictionary """ - return copy.deepcopy(self.custom_meta) + return {s.global_index: copy.deepcopy(s.custom_meta) for s in self.samples if s.custom_meta} def update_custom_meta(self, new_meta: dict[int, dict[str, Any]]): """ @@ -356,7 +351,11 @@ def update_custom_meta(self, new_meta: dict[int, dict[str, Any]]): f"do not exist in this batch." ) - self.custom_meta.update(new_meta) + for global_index, meta_dict in new_meta.items(): + for sample in self.samples: + if sample.global_index == global_index: + sample.custom_meta.update(copy.deepcopy(meta_dict)) + break def clear_custom_meta(self) -> None: """ @@ -364,7 +363,8 @@ def clear_custom_meta(self) -> None: This method removes all entries from the custom_meta dictionary. 
""" - self.custom_meta.clear() + for sample in self.samples: + sample.custom_meta.clear() def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "BatchMeta": """ @@ -407,18 +407,10 @@ def select_samples(self, indexes: list[int]) -> "BatchMeta": selected_samples = [self.samples[i] for i in indexes] - global_indexes = [self.global_indexes[i] for i in indexes] - selected_custom_meta = {i: self.custom_meta[i] for i in global_indexes if i in self.custom_meta} - selected_custom_backend_meta = { - i: self._custom_backend_meta[i] for i in global_indexes if i in self._custom_backend_meta - } - # construct new BatchMeta instance selected_batch_meta = BatchMeta( samples=selected_samples, extra_info=self.extra_info, - custom_meta=selected_custom_meta, - _custom_backend_meta=selected_custom_backend_meta, ) return selected_batch_meta @@ -437,22 +429,10 @@ def select_fields(self, field_names: list[str]) -> "BatchMeta": # select fields for each SampleMeta new_samples = [sample.select_fields(field_names=field_names) for sample in self.samples] - # select fields in _custom_backend_meta - selected_custom_backend_meta = {} - for idx in self.global_indexes: - if idx in self._custom_backend_meta: - custom_backend_meta_idx = self._custom_backend_meta[idx] - - selected_custom_backend_meta[idx] = { - field: custom_backend_meta_idx[field] for field in field_names if field in custom_backend_meta_idx - } - # construct new BatchMeta instance new_batch_meta = BatchMeta( samples=new_samples, extra_info=self.extra_info, - custom_meta=self.custom_meta, - _custom_backend_meta=selected_custom_backend_meta, ) return new_batch_meta @@ -464,23 +444,10 @@ def __len__(self) -> int: def __getitem__(self, item): if isinstance(item, int | np.integer): sample_meta = self.samples[item] if self.samples else [] - global_idx = self.global_indexes[item] - - if global_idx in self.custom_meta: - custom_meta = {global_idx: self.custom_meta[global_idx]} - else: - custom_meta = {} - - if 
global_idx in self._custom_backend_meta: - custom_backend_meta = {global_idx: self._custom_backend_meta[global_idx]} - else: - custom_backend_meta = {} return BatchMeta( samples=[sample_meta], extra_info=self.extra_info, - custom_meta=custom_meta, - _custom_backend_meta=custom_backend_meta, ) else: raise TypeError(f"Indexing with {type(item)} is not supported now!") @@ -514,16 +481,9 @@ def chunk(self, chunks: int) -> list["BatchMeta"]: current_chunk_size = base_size + 1 if i < remainder else base_size end = start + current_chunk_size chunk_samples = self.samples[start:end] - global_indexes = self.global_indexes[start:end] - chunk_custom_meta = {i: self.custom_meta[i] for i in global_indexes if i in self.custom_meta} - chunk_custom_backend_meta = { - i: self._custom_backend_meta[i] for i in global_indexes if i in self._custom_backend_meta - } chunk = BatchMeta( samples=chunk_samples, extra_info=self.extra_info, - custom_meta=chunk_custom_meta, - _custom_backend_meta=chunk_custom_backend_meta, ) chunk_list.append(chunk) start = end @@ -564,14 +524,14 @@ def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta": """ if not data: logger.warning("Try to concat empty BatchMeta chunks. Returning empty BatchMeta.") - return BatchMeta(samples=[], extra_info={}, custom_meta={}, _custom_backend_meta={}) + return BatchMeta(samples=[], extra_info={}) # skip empty chunks data = [chunk for chunk in data if chunk and len(chunk.samples) > 0] if len(data) == 0: logger.warning("No valid BatchMeta chunks to concatenate. 
Returning empty BatchMeta.") - return BatchMeta(samples=[], extra_info={}, custom_meta={}, _custom_backend_meta={}) + return BatchMeta(samples=[], extra_info={}) if validate: base_fields = data[0].field_names @@ -585,15 +545,9 @@ def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta": # Merge all extra_info dictionaries from the chunks merged_extra_info = dict() - merged_custom_meta = dict() - merged_custom_backend_meta = dict() values_by_key = defaultdict(list) for chunk in data: - # For the sample-level custom_meta and field-level _custom_backend_meta, we directly update the dict. - merged_custom_meta.update(chunk.custom_meta) - merged_custom_backend_meta.update(chunk._custom_backend_meta) - for key, value in chunk.extra_info.items(): values_by_key[key].append(value) @@ -622,8 +576,6 @@ def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta": return BatchMeta( samples=all_samples, extra_info=merged_extra_info, - custom_meta=merged_custom_meta, - _custom_backend_meta=merged_custom_backend_meta, ) def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMeta"]: @@ -670,24 +622,9 @@ def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMet # Merge extra info dictionaries merged_extra_info = {**self.extra_info, **other.extra_info} - # Merge custom_meta dictionaries - merged_custom_meta = {**self.custom_meta, **other.custom_meta} - - # Merge custom_backend_meta dictionaries - merged_custom_backend_meta = {} - for idx in self.global_indexes: - if idx in self._custom_backend_meta and idx in other._custom_backend_meta: - merged_custom_backend_meta[idx] = {**self._custom_backend_meta[idx], **other._custom_backend_meta[idx]} - elif idx in self._custom_backend_meta: - merged_custom_backend_meta[idx] = {**self._custom_backend_meta[idx]} - elif idx in other._custom_backend_meta: - merged_custom_backend_meta[idx] = {**other._custom_backend_meta[idx]} - return BatchMeta( 
samples=merged_samples, extra_info=merged_extra_info, - custom_meta=merged_custom_meta, - _custom_backend_meta=merged_custom_backend_meta, ) def reorder(self, indexes: list[int]): @@ -784,7 +721,7 @@ def empty(cls, extra_info: Optional[dict[str, Any]] = None) -> "BatchMeta": """ if extra_info is None: extra_info = {} - return cls(samples=[], extra_info=extra_info, custom_meta={}, _custom_backend_meta={}) + return cls(samples=[], extra_info=extra_info) def __str__(self): sample_strs = ", ".join(str(sample) for sample in self.samples) @@ -803,8 +740,6 @@ def from_dict(cls, data: dict) -> "BatchMeta": return cls( samples=samples, extra_info=data.get("extra_info", {}), - custom_meta=data.get("custom_meta", {}), - _custom_backend_meta=data.get("_custom_backend_meta", {}), ) diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 352b3adc..ec6b296e 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -523,15 +523,14 @@ def _get_shape_type_custom_backend_meta_list(metadata: BatchMeta): shapes = [] dtypes = [] custom_backend_meta_list = [] - all_custom_backend_meta = copy.deepcopy(metadata._custom_backend_meta) for field_name in sorted(metadata.field_names): for index in range(len(metadata)): field = metadata.samples[index].get_field_by_name(field_name) assert field is not None, f"Field {field_name} not found in sample {index}" shapes.append(field.shape) dtypes.append(field.dtype) - global_index = metadata.global_indexes[index] - custom_backend_meta_list.append(all_custom_backend_meta.get(global_index, {}).get(field_name, None)) + backend_meta = field._custom_backend_meta + custom_backend_meta_list.append(copy.deepcopy(backend_meta) if backend_meta else None) return shapes, dtypes, custom_backend_meta_list async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: