diff --git a/tests/e2e/test_e2e_lifecycle_consistency.py b/tests/e2e/test_e2e_lifecycle_consistency.py index 39b0b91e..c2717860 100644 --- a/tests/e2e/test_e2e_lifecycle_consistency.py +++ b/tests/e2e/test_e2e_lifecycle_consistency.py @@ -78,6 +78,18 @@ }, }, }, + "Yuanrong": { + "controller": { + "polling_mode": True, + }, + "backend": { + "storage_backend": "Yuanrong", + "Yuanrong": { + "host": "127.0.0.1", + "port": 31501, + }, + }, + }, } @@ -507,15 +519,19 @@ def test_cross_shard_complex_update(e2e_client): update_positions_in_full = [ i for i, global_index in enumerate(full_meta.global_indexes) if global_index in update_gis ] - update_meta_with_backend = full_meta.select_samples(update_positions_in_full) - # Populate empty schema for fields not yet in field_schema so select_fields can include them - for f in ["new_extra_tensor", "new_extra_non_tensor"]: - if f not in update_meta_with_backend.field_schema: - update_meta_with_backend.field_schema[f] = {} - update_meta_with_backend._field_names = sorted(update_meta_with_backend.field_schema.keys()) - extended_meta = update_meta_with_backend.select_fields( - base_fields + ["new_extra_tensor", "new_extra_non_tensor"] + extended_fields = base_fields + ["new_extra_tensor", "new_extra_non_tensor"] + extended_meta = poll_for_meta( + client, + partition_id, + extended_fields, + 40, + task_name, + mode="force_fetch", + ) + assert extended_meta is not None and extended_meta.size > 0, ( + "Failed to fetch extended metadata for update region; poll_for_meta returned no or empty metadata." 
) + extended_meta = extended_meta.select_samples(update_positions_in_full).select_fields(extended_fields) update_region_data = client.get_data(extended_meta) assert "new_extra_tensor" in update_region_data.keys(), "new_extra_tensor should exist" assert "new_extra_non_tensor" in update_region_data.keys(), "new_extra_non_tensor should exist" diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index aa51d055..42f17db8 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -566,13 +566,20 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: Store tensor data in the backend storage and notify the controller. """ num_samples = len(metadata.global_indexes) - if num_samples == 0: + if data.batch_size[0] != num_samples: + raise ValueError(f"Batch size of data ({data.batch_size[0]}) does not match expected ({num_samples})") + + if data.batch_size[0] == 0: + logger.warning("Attempted to put data with batch size 0. Operation will be skipped.") return - keys = self._generate_keys(data.keys(), metadata.global_indexes) + # Generate keys and values. + # metadata.field_names is legacy; generate keys/values from the actual data field names instead. 
+ data_field_names = list(sorted(data.keys())) + keys = self._generate_keys(data_field_names, metadata.global_indexes) values = self._generate_values(data) - loop = asyncio.get_event_loop() + loop = asyncio.get_event_loop() custom_backend_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values) field_schema = extract_field_schema(data) @@ -588,15 +595,14 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: for global_idx in metadata.global_indexes: per_field_custom_backend_meta[global_idx] = {} - # FIXME(tianyi): the order of custom backend meta is coupled with keys/values - # FIXME: if put_data is called to partially update/add new fields, the current - # implementation will cause custom_backend_meta losses or mismatch! for (field_name, global_idx), meta_value in zip( - itertools.product(sorted(metadata.field_names), metadata.global_indexes), + itertools.product(data_field_names, metadata.global_indexes), custom_backend_meta, strict=True, ): per_field_custom_backend_meta[global_idx][field_name] = meta_value + # TODO: Avoid accessing the private attribute of metadata here; + # consider adding a public method to BatchMeta for setting custom_backend_meta in the future. metadata._custom_backend_meta[global_index_to_position[global_idx]][field_name] = meta_value # Get current data partition id