Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
433 changes: 433 additions & 0 deletions tests/e2e/test_checkpoint_e2e.py

Large diffs are not rendered by default.

57 changes: 57 additions & 0 deletions tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,3 +1039,60 @@ def test_controller_kv_retrieve_keys_multiple_partitions(self, ray_setup):
# Clean up
ray.get(tq_controller.clear_partition.remote(partition_1))
ray.get(tq_controller.clear_partition.remote(partition_2))


class TestTransferQueueControllerCheckpoint:
def test_save_creates_file(self, ray_setup, tmp_path):
tq_controller = TransferQueueController.remote()
ray.get(tq_controller.kv_retrieve_meta.remote(keys=["k0", "k1"], partition_id="p0", create=True))

path = str(tmp_path / "ckpt.pkl")
ray.get(tq_controller.save_checkpoint.remote(path))

assert (tmp_path / "ckpt.pkl").exists()
assert (tmp_path / "ckpt.pkl").stat().st_size > 0

def test_save_raises_on_invalid_path(self, ray_setup):
tq_controller = TransferQueueController.remote()

with pytest.raises(ray.exceptions.RayTaskError, match="save checkpoint failed"):
ray.get(tq_controller.save_checkpoint.remote("/nonexistent_dir/ckpt.pkl"))

def test_load_restores_partition(self, ray_setup, tmp_path):
tq_controller = TransferQueueController.remote()
ray.get(tq_controller.kv_retrieve_meta.remote(keys=["k0", "k1"], partition_id="p0", create=True))

path = str(tmp_path / "ckpt.pkl")
ray.get(tq_controller.save_checkpoint.remote(path))
ray.get(tq_controller.clear_partition.remote("p0"))
assert ray.get(tq_controller.list_partitions.remote()) == []

ray.get(tq_controller.load_checkpoint.remote(path))
assert "p0" in ray.get(tq_controller.list_partitions.remote())

def test_load_restores_key_mapping(self, ray_setup, tmp_path):
tq_controller = TransferQueueController.remote()
ray.get(tq_controller.kv_retrieve_meta.remote(keys=["k0", "k1"], partition_id="p0", create=True))

path = str(tmp_path / "ckpt.pkl")
ray.get(tq_controller.save_checkpoint.remote(path))
ray.get(tq_controller.clear_partition.remote("p0"))

ray.get(tq_controller.load_checkpoint.remote(path))
snapshot = ray.get(tq_controller.get_partition_snapshot.remote("p0"))
assert "k0" in snapshot.keys_mapping
assert "k1" in snapshot.keys_mapping

def test_load_raises_on_missing_file(self, ray_setup):
tq_controller = TransferQueueController.remote()

with pytest.raises(ray.exceptions.RayTaskError, match="load checkpoint failed"):
ray.get(tq_controller.load_checkpoint.remote("/nonexistent/ckpt.pkl"))

def test_load_raises_on_corrupt_file(self, ray_setup, tmp_path):
tq_controller = TransferQueueController.remote()
path = tmp_path / "bad.pkl"
path.write_bytes(b"not a pickle")

with pytest.raises(ray.exceptions.RayTaskError, match="load checkpoint failed"):
ray.get(tq_controller.load_checkpoint.remote(str(path)))
58 changes: 58 additions & 0 deletions tests/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,33 @@ def sample(self, ready_indexes: list[int], batch_size: int, **kwargs: Any) -> tu
assert hasattr(sampler, "_states")
assert sampler._states == {}

def test_save_load_checkpoint_roundtrip(self):
"""save_checkpoint / load_checkpoint restore _states."""

class TestSampler(BaseSampler):
def sample(self, ready_indexes: list[int], batch_size: int, **kwargs: Any) -> tuple[list[int], list[int]]:
return ready_indexes[:batch_size], ready_indexes[:batch_size]

sampler = TestSampler()
sampler._states = {"p0": {"task": {0: {0: ([1, 2], [1, 2])}}}}

state = sampler.save_checkpoint()
sampler._states = {}
sampler.load_checkpoint(state)

assert sampler._states == {"p0": {"task": {0: {0: ([1, 2], [1, 2])}}}}

def test_save_checkpoint_returns_states_key(self):
"""save_checkpoint dict must contain '_states'."""

class TestSampler(BaseSampler):
def sample(self, ready_indexes: list[int], batch_size: int, **kwargs: Any) -> tuple[list[int], list[int]]:
return [], []

sampler = TestSampler()
state = sampler.save_checkpoint()
assert "_states" in state


class TestSequentialSampler:
"""Test cases for SequentialSampler."""
Expand Down Expand Up @@ -1061,6 +1088,37 @@ def test_batch_size_not_divisible_by_n_samples_per_prompt(self):

assert "must be a multiple of n_samples_per_prompt" in str(exc_info.value)

def test_save_load_checkpoint_roundtrip(self):
"""save_checkpoint / load_checkpoint restore both _states and _balanced_cache."""
sampler = SeqlenBalancedSampler(n_samples_per_prompt=1, dp_size=2)
ready_indexes = [0, 1, 2, 3]

sampler.sample(ready_indexes, 2, task_name="task", partition_id="p0", dp_rank=0, batch_index=0)

assert len(sampler._balanced_cache) > 0
assert "p0" in sampler._states

state = sampler.save_checkpoint()

new_sampler = SeqlenBalancedSampler(n_samples_per_prompt=1, dp_size=2)
assert new_sampler._balanced_cache == {}
assert new_sampler._states == {}

new_sampler.load_checkpoint(state)

assert new_sampler._states == sampler._states
assert new_sampler._balanced_cache == sampler._balanced_cache

def test_save_checkpoint_includes_balanced_cache(self):
"""save_checkpoint state dict must contain both '_states' and '_balanced_cache'."""
sampler = SeqlenBalancedSampler(n_samples_per_prompt=1, dp_size=1)
sampler.sample([0, 1], 2, task_name="task", partition_id="p0", dp_rank=0, batch_index=0)

state = sampler.save_checkpoint()

assert "_states" in state
assert "_balanced_cache" in state


class TestKarmarkarKarp:
"""Test cases for karmarkar_karp and get_seqlen_balanced_partitions utilities."""
Expand Down
95 changes: 95 additions & 0 deletions tests/test_simple_storage_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,24 @@ def send_clear(self, client_id, global_indexes):
self.socket.send_multipart(msg.serialize())
return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False))

def send_save_checkpoint(self, path):
msg = ZMQMessage.create(
request_type=ZMQRequestType.SAVE_STORAGE_CHECKPOINT,
sender_id="mock_client_ckpt",
body={"path": path},
)
self.socket.send_multipart(msg.serialize())
return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False))

def send_load_checkpoint(self, path):
msg = ZMQMessage.create(
request_type=ZMQRequestType.LOAD_STORAGE_CHECKPOINT,
sender_id="mock_client_ckpt",
body={"path": path},
)
self.socket.send_multipart(msg.serialize())
return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False))

def close(self):
self.socket.close()
self.context.term()
Expand Down Expand Up @@ -611,3 +629,80 @@ def wrong_len_parser(field_data):
assert "data_parser changed the number of elements" in response.body["message"]

client.close()


def test_save_load_checkpoint_roundtrip(storage_setup, tmp_path):
"""Save and load restore data correctly."""
_, put_get_address = storage_setup
client = MockStorageClient(put_get_address)

global_indexes = [0, 1]
field_data = {
"input_ids": [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])],
"labels": [torch.tensor([0]), torch.tensor([1])],
}
client.send_put(0, global_indexes, field_data)

path = str(tmp_path / "su.pkl")
resp = client.send_save_checkpoint(path)
assert resp.request_type == ZMQRequestType.SAVE_STORAGE_CHECKPOINT_RESPONSE
assert resp.body["success"] is True

resp = client.send_load_checkpoint(path)
assert resp.request_type == ZMQRequestType.LOAD_STORAGE_CHECKPOINT_RESPONSE
assert resp.body["success"] is True

resp = client.send_get(0, global_indexes, ["input_ids"])
assert resp.request_type == ZMQRequestType.GET_DATA_RESPONSE
torch.testing.assert_close(resp.body["data"]["input_ids"][0], torch.tensor([1, 2, 3]))
torch.testing.assert_close(resp.body["data"]["input_ids"][1], torch.tensor([4, 5, 6]))

client.close()


def test_save_checkpoint_invalid_path(storage_setup):
"""Save to an inaccessible path returns failure response."""
_, put_get_address = storage_setup
client = MockStorageClient(put_get_address)

resp = client.send_save_checkpoint("/nonexistent_dir/su.pkl")
assert resp.request_type == ZMQRequestType.SAVE_STORAGE_CHECKPOINT_RESPONSE
assert resp.body["success"] is False
assert "message" in resp.body

client.close()


def test_load_checkpoint_missing_file(storage_setup):
"""Load from a non-existent file returns failure response."""
_, put_get_address = storage_setup
client = MockStorageClient(put_get_address)

resp = client.send_load_checkpoint("/nonexistent/su.pkl")
assert resp.request_type == ZMQRequestType.LOAD_STORAGE_CHECKPOINT_RESPONSE
assert resp.body["success"] is False
assert "message" in resp.body

client.close()


def test_load_checkpoint_clears_existing_data(storage_setup, tmp_path):
"""Data added after save should not survive a load."""
_, put_get_address = storage_setup
client = MockStorageClient(put_get_address)

client.send_put(0, [0], {"val": [torch.tensor([1])]})
path = str(tmp_path / "su.pkl")
client.send_save_checkpoint(path)

# Add index 1 after save
client.send_put(0, [1], {"val": [torch.tensor([99])]})

# Restore to checkpoint that only had index 0
client.send_load_checkpoint(path)

# Index 1 should be gone
resp = client.send_get(0, [1], ["val"])
assert resp.request_type != ZMQRequestType.GET_DATA_RESPONSE

client.close()
7 changes: 7 additions & 0 deletions transfer_queue/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
kv_clear,
kv_list,
kv_put,
load_checkpoint,
save_checkpoint,
)
from .metadata import BatchMeta, KVBatchMeta
from .sampler import BaseSampler
Expand Down Expand Up @@ -62,6 +64,11 @@
"async_kv_clear",
"KVBatchMeta",
]
+ [
# Checkpoint Interface
"save_checkpoint",
"load_checkpoint",
]
+ [
# High-Level StreamingDataLoader Interface
"StreamingDataset",
Expand Down
Loading
Loading