Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions tests/e2e/test_e2e_lifecycle_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,18 @@
},
},
},
"Yuanrong": {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add this test to github actions

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In currently, Yuanrong-datasystem need to be started by hands. We plan to support automatic startup of Yuanrong for tq.init() in next PR.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#50

"controller": {
"polling_mode": True,
},
"backend": {
"storage_backend": "Yuanrong",
"Yuanrong": {
"host": "127.0.0.1",
"port": 31501,
},
},
},
Comment on lines +81 to +92

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ignore it temporarily

}


Expand Down Expand Up @@ -507,15 +519,19 @@ def test_cross_shard_complex_update(e2e_client):
update_positions_in_full = [
i for i, global_index in enumerate(full_meta.global_indexes) if global_index in update_gis
]
update_meta_with_backend = full_meta.select_samples(update_positions_in_full)
# Populate empty schema for fields not yet in field_schema so select_fields can include them
for f in ["new_extra_tensor", "new_extra_non_tensor"]:
if f not in update_meta_with_backend.field_schema:
update_meta_with_backend.field_schema[f] = {}
update_meta_with_backend._field_names = sorted(update_meta_with_backend.field_schema.keys())
extended_meta = update_meta_with_backend.select_fields(
base_fields + ["new_extra_tensor", "new_extra_non_tensor"]
extended_fields = base_fields + ["new_extra_tensor", "new_extra_non_tensor"]
extended_meta = poll_for_meta(
client,
partition_id,
extended_fields,
40,
task_name,
mode="force_fetch",
)
assert extended_meta is not None and extended_meta.size > 0, (
"Failed to fetch extended metadata for update region; poll_for_meta returned no or empty metadata."
)
extended_meta = extended_meta.select_samples(update_positions_in_full).select_fields(extended_fields)
update_region_data = client.get_data(extended_meta)
assert "new_extra_tensor" in update_region_data.keys(), "new_extra_tensor should exist"
assert "new_extra_non_tensor" in update_region_data.keys(), "new_extra_non_tensor should exist"
Expand Down
20 changes: 13 additions & 7 deletions transfer_queue/storage/managers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,13 +566,20 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
Store tensor data in the backend storage and notify the controller.
"""
num_samples = len(metadata.global_indexes)
if num_samples == 0:
if data.batch_size[0] != num_samples:
raise ValueError(f"Batch size of data ({data.batch_size[0]}) does not match expected ({num_samples})")

if data.batch_size[0] == 0:
logger.warning("Attempted to put data with batch size 0. Operation will be skipped.")
return

keys = self._generate_keys(data.keys(), metadata.global_indexes)
# Generate keys and values.
# metadata.field_names is legacy; generate keys/values from the actual data field names instead.
data_field_names = list(sorted(data.keys()))
keys = self._generate_keys(data_field_names, metadata.global_indexes)
values = self._generate_values(data)
loop = asyncio.get_event_loop()

loop = asyncio.get_event_loop()
custom_backend_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values)

field_schema = extract_field_schema(data)
Expand All @@ -588,15 +595,14 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
for global_idx in metadata.global_indexes:
per_field_custom_backend_meta[global_idx] = {}

# FIXME(tianyi): the order of custom backend meta is coupled with keys/values
# FIXME: if put_data is called to partially update/add new fields, the current
# implementation will cause custom_backend_meta losses or mismatch!
for (field_name, global_idx), meta_value in zip(
itertools.product(sorted(metadata.field_names), metadata.global_indexes),
itertools.product(data_field_names, metadata.global_indexes),
custom_backend_meta,
strict=True,
):
per_field_custom_backend_meta[global_idx][field_name] = meta_value
# TODO: There should not visit private property of metadata,
# we should consider to add a public method in BatchMeta to set custom_backend_meta in the future.
metadata._custom_backend_meta[global_index_to_position[global_idx]][field_name] = meta_value

# Get current data partition id
Expand Down
Loading