diff --git a/tests/e2e/test_checkpoint_e2e.py b/tests/e2e/test_checkpoint_e2e.py new file mode 100644 index 00000000..c1850daa --- /dev/null +++ b/tests/e2e/test_checkpoint_e2e.py @@ -0,0 +1,433 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end tests for save_checkpoint and load_checkpoint. + +Run with: + pytest tests/e2e/test_checkpoint_e2e.py -v +""" + +import json +import os + +import pytest +import ray +import torch +from omegaconf import OmegaConf +from tensordict import NonTensorStack, TensorDict + +import transfer_queue as tq + +os.environ["RAY_DEDUP_LOGS"] = "0" + +_TQ_CONFIG = OmegaConf.create( + { + "controller": {"polling_mode": True}, + "backend": { + "storage_backend": "SimpleStorage", + "SimpleStorage": { + "total_storage_size": 200, + "num_data_storage_units": 2, + }, + }, + } +) + + +# --------------------------------------------------------------------------- +# fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def ray_init(): + if not ray.is_initialized(): + ray.init(namespace="TestCheckpointE2E") + yield + if ray.is_initialized(): + ray.shutdown() + + +@pytest.fixture(scope="module") +def tq_system(ray_init): + tq.init(_TQ_CONFIG) + yield + tq.close() + + +@pytest.fixture +def controller(tq_system): + return ray.get_actor("TransferQueueController", 
namespace="transfer_queue") + + +@pytest.fixture(autouse=True) +def cleanup_partitions(controller): + yield + try: + for pid in ray.get(controller.list_partitions.remote()): + ray.get(controller.clear_partition.remote(pid)) + except Exception: + pass + + +@pytest.fixture +def checkpoint_dir(tmp_path): + return tmp_path / "checkpoint" + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + + +def _assert_tensor_equal(a, b, msg=""): + if (isinstance(a, torch.Tensor) and a.is_nested) or (isinstance(b, torch.Tensor) and b.is_nested): + for t1, t2 in zip(list(a), list(b), strict=True): + assert torch.equal(t1, t2), f"{msg} mismatch" + else: + assert torch.equal(a, b), f"{msg} mismatch" + + +# --------------------------------------------------------------------------- +# basic save / load roundtrip +# --------------------------------------------------------------------------- + + +class TestCheckpointRoundtrip: + """Standard data → save → verify files → wipe → load → verify data.""" + + def test_tensor_fields(self, tq_system, checkpoint_dir, controller): + # Define test data + keys = ["k0", "k1"] + partition_id = "p_tensor" + input_ids = torch.tensor([[1, 2], [3, 4]]) + attention_mask = torch.ones(2, 2) + + # Put + tq.kv_batch_put( + keys=keys, + partition_id=partition_id, + fields=TensorDict({"input_ids": input_ids, "attention_mask": attention_mask}, batch_size=len(keys)), + tags=[{} for _ in keys], + ) + + # Save + tq.save_checkpoint(checkpoint_dir) + + # Check saved state: expected files exist + assert (checkpoint_dir / "metadata.json").exists() + assert (checkpoint_dir / "controller_state.pkl").exists() + su_dir = checkpoint_dir / "simple_storage" + assert su_dir.exists() + assert (su_dir / "storage_unit_info.json").exists() + with open(checkpoint_dir / "metadata.json") as f: + meta = json.load(f) + assert meta["storage_saved"] is True + + # Wipe 
controller state so load has real work to do + ray.get(controller.clear_partition.remote(partition_id)) + assert ray.get(controller.list_partitions.remote()) == [] + + # Load + tq.load_checkpoint(checkpoint_dir) + + # Check loaded state: partition and data restored + assert partition_id in ray.get(controller.list_partitions.remote()) + retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id) + _assert_tensor_equal(retrieved["input_ids"], input_ids) + _assert_tensor_equal(retrieved["attention_mask"], attention_mask) + + def test_controller_metadata(self, tq_system, checkpoint_dir, controller): + # Define test data + keys = ["a0", "a1", "a2"] + partition_id = "p_meta" + input_ids = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + tags = [{"idx": i} for i in range(3)] + + # Put + tq.kv_batch_put( + keys=keys, + partition_id=partition_id, + fields=TensorDict({"input_ids": input_ids, "attention_mask": torch.ones(3, 3)}, batch_size=len(keys)), + tags=tags, + ) + + # Save + tq.save_checkpoint(checkpoint_dir) + + # Wipe + ray.get(controller.clear_partition.remote(partition_id)) + + # Load + tq.load_checkpoint(checkpoint_dir) + + # Check loaded state: key mapping and tags intact + snapshot = ray.get(controller.get_partition_snapshot.remote(partition_id)) + for i, key in enumerate(keys): + assert key in snapshot.keys_mapping + gidx = snapshot.keys_mapping[key] + assert snapshot.custom_meta[gidx]["idx"] == i + + def test_multiple_partitions(self, tq_system, checkpoint_dir, controller): + # Define test data + partitions_data = {f"part_{i}": (torch.full((2, 4), i, dtype=torch.long), torch.ones(2, 4)) for i in range(3)} + + # Put + for pid, (iids, mask) in partitions_data.items(): + tq.kv_batch_put( + keys=[f"{pid}_k0", f"{pid}_k1"], + partition_id=pid, + fields=TensorDict({"input_ids": iids, "attention_mask": mask}, batch_size=2), + tags=[{}, {}], + ) + + # Save + tq.save_checkpoint(checkpoint_dir) + + # Wipe + for pid in partitions_data: + 
ray.get(controller.clear_partition.remote(pid)) + + # Load + tq.load_checkpoint(checkpoint_dir) + + # Check loaded state + for pid, (iids, _) in partitions_data.items(): + retrieved = tq.kv_batch_get(keys=[f"{pid}_k0", f"{pid}_k1"], partition_id=pid, select_fields=["input_ids"]) + _assert_tensor_equal(retrieved["input_ids"], iids) + + def test_user_metadata_preserved(self, tq_system, checkpoint_dir): + # Define test data + keys = ["m0"] + + # Put + tq.kv_batch_put( + keys=keys, + partition_id="p_usermeta", + fields=TensorDict( + {"input_ids": torch.tensor([[10, 20]]), "attention_mask": torch.ones(1, 2)}, batch_size=1 + ), + tags=[{}], + ) + + # Save with user metadata + tq.save_checkpoint(checkpoint_dir, metadata={"iteration": 42, "loss": 0.5}) + + # Check saved state: user metadata written correctly + with open(checkpoint_dir / "metadata.json") as f: + meta = json.load(f) + assert meta["user_metadata"]["iteration"] == 42 + assert meta["user_metadata"]["loss"] == pytest.approx(0.5) + + def test_non_tensor_fields(self, tq_system, checkpoint_dir, controller): + # Define test data + keys = ["t0", "t1"] + partition_id = "p_str" + input_ids = torch.tensor([[1, 2], [3, 4]]) + fields = TensorDict( + {"input_ids": input_ids, "text": NonTensorStack("hello", "world")}, + batch_size=2, + ) + + # Put + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}]) + + # Save + tq.save_checkpoint(checkpoint_dir) + + # Wipe + ray.get(controller.clear_partition.remote(partition_id)) + + # Load + tq.load_checkpoint(checkpoint_dir) + + # Check loaded state + retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id, select_fields=["input_ids"]) + _assert_tensor_equal(retrieved["input_ids"], input_ids) + + def test_nested_tensor_fields(self, tq_system, checkpoint_dir, controller): + # Define test data + keys = ["j0", "j1", "j2"] + partition_id = "p_jagged" + + # Put (variable-length sequences) + for i, key in enumerate(keys): + tq.kv_put( + key=key, + 
partition_id=partition_id, + fields=TensorDict({"seq": torch.arange(i + 1, dtype=torch.float).unsqueeze(0)}, batch_size=1), + tag=None, + ) + + # Save + tq.save_checkpoint(checkpoint_dir) + + # Wipe + ray.get(controller.clear_partition.remote(partition_id)) + + # Load + tq.load_checkpoint(checkpoint_dir) + + # Check loaded state + retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id, select_fields=["seq"]) + for i, component in enumerate(retrieved["seq"].unbind()): + _assert_tensor_equal(component, torch.arange(i + 1, dtype=torch.float)) + + +# --------------------------------------------------------------------------- +# include_storage=False (SimpleStorage override) +# --------------------------------------------------------------------------- + + +class TestIncludeStorageFalse: + """For SimpleStorage, include_storage=False is silently forced to True.""" + + def test_storage_saved_is_true(self, tq_system, checkpoint_dir): + # Define test data + Put + tq.kv_batch_put( + keys=["n0"], + partition_id="p_nometa", + fields=TensorDict({"input_ids": torch.tensor([[1, 2]]), "attention_mask": torch.ones(1, 2)}, batch_size=1), + tags=[{}], + ) + + # Save with include_storage=False + tq.save_checkpoint(checkpoint_dir, include_storage=False) + + # Check saved state: storage_saved must be True and directory must exist + with open(checkpoint_dir / "metadata.json") as f: + meta = json.load(f) + assert meta["storage_saved"] is True + assert (checkpoint_dir / "simple_storage").exists() + + def test_both_restored_after_load(self, tq_system, checkpoint_dir, controller): + # Define test data + keys = ["n0", "n1"] + partition_id = "p_nometa2" + input_ids = torch.tensor([[5, 6], [7, 8]]) + + # Put + tq.kv_batch_put( + keys=keys, + partition_id=partition_id, + fields=TensorDict({"input_ids": input_ids, "attention_mask": torch.ones(2, 2)}, batch_size=len(keys)), + tags=[{} for _ in keys], + ) + + # Save + tq.save_checkpoint(checkpoint_dir, include_storage=False) + + # Wipe + 
ray.get(controller.clear_partition.remote(partition_id)) + + # Load + tq.load_checkpoint(checkpoint_dir) + + # Check loaded state: both controller and storage restored + assert partition_id in ray.get(controller.list_partitions.remote()) + snapshot = ray.get(controller.get_partition_snapshot.remote(partition_id)) + for key in keys: + assert key in snapshot.keys_mapping + retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id) + _assert_tensor_equal(retrieved["input_ids"], input_ids) + + +# --------------------------------------------------------------------------- +# error handling +# --------------------------------------------------------------------------- + + +class TestCheckpointErrors: + def test_save_raises_if_not_initialized(self, tmp_path): + import transfer_queue.interface as iface + + original = iface._TQ_CONTROLLER + try: + iface._TQ_CONTROLLER = None + with pytest.raises(RuntimeError, match="not initialized"): + tq.save_checkpoint(tmp_path / "ck") + finally: + iface._TQ_CONTROLLER = original + + def test_load_raises_if_not_initialized(self, tmp_path): + import transfer_queue.interface as iface + + original = iface._TQ_CONTROLLER + try: + iface._TQ_CONTROLLER = None + with pytest.raises(RuntimeError, match="not initialized"): + tq.load_checkpoint(tmp_path / "ck") + finally: + iface._TQ_CONTROLLER = original + + def test_load_raises_if_dir_missing(self, tq_system, tmp_path): + with pytest.raises(FileNotFoundError): + tq.load_checkpoint(tmp_path / "nonexistent") + + def test_load_raises_if_metadata_missing(self, tq_system, tmp_path): + ck = tmp_path / "ck" + ck.mkdir() + with pytest.raises(FileNotFoundError, match="metadata.json"): + tq.load_checkpoint(ck) + + def test_load_raises_on_storage_unit_count_mismatch(self, tq_system, tmp_path, checkpoint_dir, controller): + # Define test data + Put + Save + tq.kv_batch_put( + keys=["e0"], + partition_id="p_err", + fields=TensorDict({"input_ids": torch.tensor([[1, 2]]), "attention_mask": torch.ones(1, 
2)}, batch_size=1), + tags=[{}], + ) + tq.save_checkpoint(checkpoint_dir) + + # Tamper: add a fake extra SU entry so count differs + su_info_path = checkpoint_dir / "simple_storage" / "storage_unit_info.json" + with open(su_info_path) as f: + su_info = json.load(f) + su_info.append({"position": 99, "storage_unit_id": "fake"}) + with open(su_info_path, "w") as f: + json.dump(su_info, f) + + partitions_before = ray.get(controller.list_partitions.remote()) + + with pytest.raises(ValueError, match="count mismatch"): + tq.load_checkpoint(checkpoint_dir) + + # Controller state must not have been modified + assert ray.get(controller.list_partitions.remote()) == partitions_before + + def test_no_partial_state_on_failed_save(self, tq_system, tmp_path): + # Define test data + Put + tq.kv_batch_put( + keys=["f0"], + partition_id="p_fail", + fields=TensorDict({"input_ids": torch.tensor([[1, 2]]), "attention_mask": torch.ones(1, 2)}, batch_size=1), + tags=[{}], + ) + + ck = tmp_path / "ck" + + import unittest.mock as mock + + with mock.patch( + "transfer_queue.client.TransferQueueClient.save_storage_checkpoint", + side_effect=RuntimeError("simulated dump failure"), + ): + with pytest.raises(RuntimeError, match="simulated dump failure"): + tq.save_checkpoint(ck) + + # Check saved state: no partial directory left + assert not ck.exists() + assert not (tmp_path / "ck.tmp").exists() diff --git a/tests/test_controller.py b/tests/test_controller.py index d45c54eb..c717894f 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -1039,3 +1039,60 @@ def test_controller_kv_retrieve_keys_multiple_partitions(self, ray_setup): # Clean up ray.get(tq_controller.clear_partition.remote(partition_1)) ray.get(tq_controller.clear_partition.remote(partition_2)) + + +class TestTransferQueueControllerCheckpoint: + def test_save_creates_file(self, ray_setup, tmp_path): + tq_controller = TransferQueueController.remote() + ray.get(tq_controller.kv_retrieve_meta.remote(keys=["k0", "k1"], 
partition_id="p0", create=True)) + + path = str(tmp_path / "ckpt.pkl") + ray.get(tq_controller.save_checkpoint.remote(path)) + + assert (tmp_path / "ckpt.pkl").exists() + assert (tmp_path / "ckpt.pkl").stat().st_size > 0 + + def test_save_raises_on_invalid_path(self, ray_setup): + tq_controller = TransferQueueController.remote() + + with pytest.raises(ray.exceptions.RayTaskError, match="save checkpoint failed"): + ray.get(tq_controller.save_checkpoint.remote("/nonexistent_dir/ckpt.pkl")) + + def test_load_restores_partition(self, ray_setup, tmp_path): + tq_controller = TransferQueueController.remote() + ray.get(tq_controller.kv_retrieve_meta.remote(keys=["k0", "k1"], partition_id="p0", create=True)) + + path = str(tmp_path / "ckpt.pkl") + ray.get(tq_controller.save_checkpoint.remote(path)) + ray.get(tq_controller.clear_partition.remote("p0")) + assert ray.get(tq_controller.list_partitions.remote()) == [] + + ray.get(tq_controller.load_checkpoint.remote(path)) + assert "p0" in ray.get(tq_controller.list_partitions.remote()) + + def test_load_restores_key_mapping(self, ray_setup, tmp_path): + tq_controller = TransferQueueController.remote() + ray.get(tq_controller.kv_retrieve_meta.remote(keys=["k0", "k1"], partition_id="p0", create=True)) + + path = str(tmp_path / "ckpt.pkl") + ray.get(tq_controller.save_checkpoint.remote(path)) + ray.get(tq_controller.clear_partition.remote("p0")) + + ray.get(tq_controller.load_checkpoint.remote(path)) + snapshot = ray.get(tq_controller.get_partition_snapshot.remote("p0")) + assert "k0" in snapshot.keys_mapping + assert "k1" in snapshot.keys_mapping + + def test_load_raises_on_missing_file(self, ray_setup): + tq_controller = TransferQueueController.remote() + + with pytest.raises(ray.exceptions.RayTaskError, match="load checkpoint failed"): + ray.get(tq_controller.load_checkpoint.remote("/nonexistent/ckpt.pkl")) + + def test_load_raises_on_corrupt_file(self, ray_setup, tmp_path): + tq_controller = TransferQueueController.remote() + 
path = tmp_path / "bad.pkl" + path.write_bytes(b"not a pickle") + + with pytest.raises(ray.exceptions.RayTaskError, match="load checkpoint failed"): + ray.get(tq_controller.load_checkpoint.remote(str(path))) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index c8e8f843..9575c509 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -61,6 +61,33 @@ def sample(self, ready_indexes: list[int], batch_size: int, **kwargs: Any) -> tu assert hasattr(sampler, "_states") assert sampler._states == {} + def test_save_load_checkpoint_roundtrip(self): + """save_checkpoint / load_checkpoint restore _states.""" + + class TestSampler(BaseSampler): + def sample(self, ready_indexes: list[int], batch_size: int, **kwargs: Any) -> tuple[list[int], list[int]]: + return ready_indexes[:batch_size], ready_indexes[:batch_size] + + sampler = TestSampler() + sampler._states = {"p0": {"task": {0: {0: ([1, 2], [1, 2])}}}} + + state = sampler.save_checkpoint() + sampler._states = {} + sampler.load_checkpoint(state) + + assert sampler._states == {"p0": {"task": {0: {0: ([1, 2], [1, 2])}}}} + + def test_save_checkpoint_returns_states_key(self): + """save_checkpoint dict must contain '_states'.""" + + class TestSampler(BaseSampler): + def sample(self, ready_indexes: list[int], batch_size: int, **kwargs: Any) -> tuple[list[int], list[int]]: + return [], [] + + sampler = TestSampler() + state = sampler.save_checkpoint() + assert "_states" in state + class TestSequentialSampler: """Test cases for SequentialSampler.""" @@ -1061,6 +1088,37 @@ def test_batch_size_not_divisible_by_n_samples_per_prompt(self): assert "must be a multiple of n_samples_per_prompt" in str(exc_info.value) + def test_save_load_checkpoint_roundtrip(self): + """save_checkpoint / load_checkpoint restore both _states and _balanced_cache.""" + sampler = SeqlenBalancedSampler(n_samples_per_prompt=1, dp_size=2) + ready_indexes = [0, 1, 2, 3] + + sampler.sample(ready_indexes, 2, task_name="task", 
partition_id="p0", dp_rank=0, batch_index=0) + + assert len(sampler._balanced_cache) > 0 + assert "p0" in sampler._states + + state = sampler.save_checkpoint() + + new_sampler = SeqlenBalancedSampler(n_samples_per_prompt=1, dp_size=2) + assert new_sampler._balanced_cache == {} + assert new_sampler._states == {} + + new_sampler.load_checkpoint(state) + + assert new_sampler._states == sampler._states + assert new_sampler._balanced_cache == sampler._balanced_cache + + def test_save_checkpoint_includes_balanced_cache(self): + """save_checkpoint state dict must contain both '_states' and '_balanced_cache'.""" + sampler = SeqlenBalancedSampler(n_samples_per_prompt=1, dp_size=1) + sampler.sample([0, 1], 2, task_name="task", partition_id="p0", dp_rank=0, batch_index=0) + + state = sampler.save_checkpoint() + + assert "_states" in state + assert "_balanced_cache" in state + class TestKarmarkarKarp: """Test cases for karmarkar_karp and get_seqlen_balanced_partitions utilities.""" diff --git a/tests/test_simple_storage_unit.py b/tests/test_simple_storage_unit.py index 319a46e7..f81bb56b 100644 --- a/tests/test_simple_storage_unit.py +++ b/tests/test_simple_storage_unit.py @@ -64,6 +64,24 @@ def send_clear(self, client_id, global_indexes): self.socket.send_multipart(msg.serialize()) return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False)) + def send_save_checkpoint(self, path): + msg = ZMQMessage.create( + request_type=ZMQRequestType.SAVE_STORAGE_CHECKPOINT, + sender_id="mock_client_ckpt", + body={"path": path}, + ) + self.socket.send_multipart(msg.serialize()) + return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False)) + + def send_load_checkpoint(self, path): + msg = ZMQMessage.create( + request_type=ZMQRequestType.LOAD_STORAGE_CHECKPOINT, + sender_id="mock_client_ckpt", + body={"path": path}, + ) + self.socket.send_multipart(msg.serialize()) + return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False)) + def close(self): 
self.socket.close() self.context.term() @@ -611,3 +629,80 @@ def wrong_len_parser(field_data): assert "data_parser changed the number of elements" in response.body["message"] client.close() + + +def test_save_load_checkpoint_roundtrip(storage_setup, tmp_path): + """Save and load restore data correctly.""" + _, put_get_address = storage_setup + client = MockStorageClient(put_get_address) + + global_indexes = [0, 1] + field_data = { + "input_ids": [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])], + "labels": [torch.tensor([0]), torch.tensor([1])], + } + client.send_put(0, global_indexes, field_data) + + path = str(tmp_path / "su.pkl") + resp = client.send_save_checkpoint(path) + assert resp.request_type == ZMQRequestType.SAVE_STORAGE_CHECKPOINT_RESPONSE + assert resp.body["success"] is True + + resp = client.send_load_checkpoint(path) + assert resp.request_type == ZMQRequestType.LOAD_STORAGE_CHECKPOINT_RESPONSE + assert resp.body["success"] is True + + resp = client.send_get(0, global_indexes, ["input_ids"]) + assert resp.request_type == ZMQRequestType.GET_DATA_RESPONSE + torch.testing.assert_close(resp.body["data"]["input_ids"][0], torch.tensor([1, 2, 3])) + torch.testing.assert_close(resp.body["data"]["input_ids"][1], torch.tensor([4, 5, 6])) + + client.close() + + +def test_save_checkpoint_invalid_path(storage_setup): + """Save to an inaccessible path returns failure response.""" + _, put_get_address = storage_setup + client = MockStorageClient(put_get_address) + + resp = client.send_save_checkpoint("/nonexistent_dir/su.pkl") + assert resp.request_type == ZMQRequestType.SAVE_STORAGE_CHECKPOINT_RESPONSE + assert resp.body["success"] is False + assert "message" in resp.body + + client.close() + + +def test_load_checkpoint_missing_file(storage_setup): + """Load from a non-existent file returns failure response.""" + _, put_get_address = storage_setup + client = MockStorageClient(put_get_address) + + resp = client.send_load_checkpoint("/nonexistent/su.pkl") + assert 
resp.request_type == ZMQRequestType.LOAD_STORAGE_CHECKPOINT_RESPONSE + assert resp.body["success"] is False + assert "message" in resp.body + + client.close() + + +def test_load_checkpoint_clears_existing_data(storage_setup, tmp_path): + """Data added after save should not survive a load.""" + _, put_get_address = storage_setup + client = MockStorageClient(put_get_address) + + client.send_put(0, [0], {"val": [torch.tensor([1])]}) + path = str(tmp_path / "su.pkl") + client.send_save_checkpoint(path) + + # Add index 1 after save + client.send_put(0, [1], {"val": [torch.tensor([99])]}) + + # Restore to checkpoint that only had index 0 + client.send_load_checkpoint(path) + + # Index 1 should be gone + resp = client.send_get(0, [1], ["val"]) + assert resp.request_type != ZMQRequestType.GET_DATA_RESPONSE + + client.close() diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index 06754278..ad8e8ab9 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -34,6 +34,8 @@ kv_clear, kv_list, kv_put, + load_checkpoint, + save_checkpoint, ) from .metadata import BatchMeta, KVBatchMeta from .sampler import BaseSampler @@ -62,6 +64,11 @@ "async_kv_clear", "KVBatchMeta", ] + + [ + # Checkpoint Interface + "save_checkpoint", + "load_checkpoint", + ] + [ # High-Level StreamingDataLoader Interface "StreamingDataset", diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 45e6303a..35845d74 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -1061,6 +1061,139 @@ def close(self) -> None: except Exception as e: logger.warning(f"Error closing storage manager: {e}") + # ==================== Checkpoint API ==================== + @with_controller_socket + async def async_save_controller_checkpoint( + self, + path: str, + socket: zmq.asyncio.Socket | None = None, + ) -> None: + """Asynchronously save controller state to a file via ZMQ RPC. 
+ + Sends a SAVE_CONTROLLER_CHECKPOINT request to the controller and waits + for acknowledgement. The controller serializes its state directly to + ``path`` in-process. + + Args: + path: Absolute path for the output .pkl file. The caller must + ensure this path is writable from the node running the controller. + socket: ZMQ socket injected by @with_controller_socket. + + Raises: + RuntimeError: If the RPC fails or an unexpected response is received. + """ + try: + assert socket is not None + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.SAVE_CONTROLLER_CHECKPOINT, # type: ignore[arg-type] + sender_id=self.client_id, + receiver_id=self._controller.id, + body={"path": path}, + ) + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart(copy=False) + response_msg = ZMQMessage.deserialize(response_serialized) + if response_msg.request_type != ZMQRequestType.SAVE_CONTROLLER_CHECKPOINT_RESPONSE: + raise RuntimeError( + f"[{self.client_id}]: Unexpected response type {response_msg.request_type} " + f"from controller during checkpoint dump" + ) + except Exception as e: + raise RuntimeError(f"[{self.client_id}]: Error in save_controller_checkpoint: {str(e)}") from e + + @with_controller_socket + async def async_load_controller_checkpoint( + self, + path: str, + socket: zmq.asyncio.Socket | None = None, + ) -> None: + """Asynchronously restore controller state from a file via ZMQ RPC. + + Sends a LOAD_CONTROLLER_CHECKPOINT request to the controller and waits + for acknowledgement. The controller deserializes its state directly from + ``path`` in-process, overwriting the current in-memory state. + + Args: + path: Absolute path to a .pkl file previously written by + ``async_save_controller_checkpoint``. The caller must ensure + this path is readable from the node running the controller. + socket: ZMQ socket injected by @with_controller_socket. 
+ + Raises: + RuntimeError: If the RPC fails or an unexpected response is received. + """ + try: + assert socket is not None + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.LOAD_CONTROLLER_CHECKPOINT, # type: ignore[arg-type] + sender_id=self.client_id, + receiver_id=self._controller.id, + body={"path": path}, + ) + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart(copy=False) + response_msg = ZMQMessage.deserialize(response_serialized) + if response_msg.request_type != ZMQRequestType.LOAD_CONTROLLER_CHECKPOINT_RESPONSE: + raise RuntimeError( + f"[{self.client_id}]: Unexpected response type {response_msg.request_type} " + f"from controller during checkpoint restore" + ) + except Exception as e: + raise RuntimeError(f"[{self.client_id}]: Error in load_controller_checkpoint: {str(e)}") from e + + async def async_save_storage_checkpoint(self, checkpoint_dir: str) -> None: + """Asynchronously save storage state to a directory via StorageManager. + + Delegates to the underlying StorageManager, which fans out save requests + to all storage units concurrently. + + Args: + checkpoint_dir: Directory under which storage unit files are written. + The StorageManager creates a ``storage_units/`` subdirectory here. + + Raises: + RuntimeError: If the storage manager is not initialized. + NotImplementedError: If the storage backend does not support checkpoint. + """ + if not hasattr(self, "storage_manager") or self.storage_manager is None: + raise RuntimeError( + f"[{self.client_id}]: Storage manager not initialized. " + "Call initialize_storage_manager() before checkpoint operations." 
+ ) + await self.storage_manager.save_checkpoint(checkpoint_dir) + + async def async_validate_storage_checkpoint(self, checkpoint_dir: str) -> None: + """Validate storage checkpoint compatibility without modifying any state.""" + if not hasattr(self, "storage_manager") or self.storage_manager is None: + raise RuntimeError( + f"[{self.client_id}]: Storage manager not initialized. " + "Call initialize_storage_manager() before checkpoint operations." + ) + await self.storage_manager.validate_checkpoint(checkpoint_dir) + + async def async_load_storage_checkpoint(self, checkpoint_dir: str) -> None: + """Asynchronously restore storage state from a directory via StorageManager. + + Delegates to the underlying StorageManager, which fans out load requests + to all storage units concurrently. Existing in-memory data is overwritten. + + Args: + checkpoint_dir: Directory previously written by + ``async_save_storage_checkpoint``, containing the + ``storage_units/`` subdirectory and manifest. + + Raises: + RuntimeError: If the storage manager is not initialized. + FileNotFoundError: If the storage unit manifest is missing. + ValueError: If the number of storage units does not match the checkpoint. + """ + if not hasattr(self, "storage_manager") or self.storage_manager is None: + raise RuntimeError( + f"[{self.client_id}]: Storage manager not initialized. " + "Call initialize_storage_manager() before checkpoint operations." + ) + await self.storage_manager.load_checkpoint(checkpoint_dir) + class TransferQueueClient(AsyncTransferQueueClient): """Synchronous client wrapper for TransferQueue. 
# ==================== Checkpoint API ====================
def save_controller_checkpoint(self, path: str) -> None:
    """Blocking counterpart of ``async_save_controller_checkpoint``.

    Args:
        path: Absolute path of the .pkl file the controller should write.

    Raises:
        RuntimeError: If the RPC fails or an unexpected response is received.
    """
    blocking_call = self._save_controller_checkpoint
    return blocking_call(path)


def load_controller_checkpoint(self, path: str) -> None:
    """Blocking counterpart of ``async_load_controller_checkpoint``.

    Args:
        path: Absolute path of a .pkl file produced earlier by
            ``save_controller_checkpoint``.

    Raises:
        RuntimeError: If the RPC fails or an unexpected response is received.
    """
    blocking_call = self._load_controller_checkpoint
    return blocking_call(path)


def save_storage_checkpoint(self, checkpoint_dir: str) -> None:
    """Blocking counterpart of ``async_save_storage_checkpoint``.

    Args:
        checkpoint_dir: Directory that receives the storage unit files.

    Raises:
        RuntimeError: If the storage manager is not initialized.
        NotImplementedError: If the storage backend does not support checkpoint.
    """
    blocking_call = self._save_storage_checkpoint
    return blocking_call(checkpoint_dir)


def load_storage_checkpoint(self, checkpoint_dir: str) -> None:
    """Blocking counterpart of ``async_load_storage_checkpoint``.

    Args:
        checkpoint_dir: Directory produced earlier by ``save_storage_checkpoint``.

    Raises:
        RuntimeError: If the storage manager is not initialized.
        FileNotFoundError: If the storage unit manifest is missing.
        ValueError: If the number of storage units does not match the checkpoint.
    """
    blocking_call = self._load_storage_checkpoint
    return blocking_call(checkpoint_dir)


def validate_storage_checkpoint(self, checkpoint_dir: str) -> None:
    """Blocking counterpart of ``async_validate_storage_checkpoint``.

    Checks checkpoint compatibility without modifying any state.

    Args:
        checkpoint_dir: Directory produced earlier by ``save_storage_checkpoint``.
    """
    blocking_call = self._validate_storage_checkpoint
    return blocking_call(checkpoint_dir)
def save_checkpoint(self, path: str) -> None:
    """Serialize controller state directly to a file.

    Writes in-process to avoid transmitting the payload back over the
    Ray object store. The pickle is first written to ``<path>.tmp`` and
    then renamed over ``path``, so an interrupted dump never leaves a
    truncated file at ``path``.

    The snapshot contains the controller id, one ``to_snapshot()`` result
    per partition, the index manager bookkeeping, and the sampler state
    returned by ``self.sampler.save_checkpoint()``.

    Args:
        path: Absolute path for the output .pkl file.

    Raises:
        RuntimeError: If building the snapshot, serialization, or file I/O
            fails. The original exception is chained as the cause.
    """
    tmp_path = f"{path}.tmp"
    try:
        state = {
            "controller_id": self.controller_id,
            "partitions": {pid: p.to_snapshot() for pid, p in self.partitions.items()},
            "index_manager": {
                # Copies decouple the snapshot from the live bookkeeping structures.
                "partition_to_indexes": dict(copy.deepcopy(self.index_manager.partition_to_indexes)),
                "reusable_indexes": list(self.index_manager.reusable_indexes),
                "global_index_counter": self.index_manager.global_index_counter,
                "allocated_indexes": set(self.index_manager.allocated_indexes),
            },
            "sampler": self.sampler.save_checkpoint(),
        }
        with open(tmp_path, "wb") as f:
            pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL)
        # os.replace overwrites an existing destination in a single step.
        os.replace(tmp_path, path)
        logger.info(f"[{self.controller_id}]: dumped to {path}")
    except Exception as e:
        # Best-effort removal of the partial temp file; the real error is re-raised below.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise RuntimeError(f"[{self.controller_id}]: save checkpoint failed: {e}") from e


def load_checkpoint(self, path: str) -> None:
    """Restore controller state directly from a file.

    Every required entry is read out of the checkpoint before any
    controller attribute is assigned, so a checkpoint with missing keys
    fails without leaving the controller partially restored.

    Args:
        path: Absolute path to a .pkl file previously written by save_checkpoint().

    Raises:
        RuntimeError: If deserialization, file I/O, or restoring the state
            fails. The original exception is chained as the cause.
    """
    try:
        # NOTE(security): pickle.load can execute arbitrary code from the file.
        # Only load checkpoints that come from a trusted location.
        with open(path, "rb") as f:
            state = pickle.load(f)

        # Phase 1: extract and validate presence of everything we need.
        controller_id = state["controller_id"]
        partitions = state["partitions"]
        im = state["index_manager"]
        partition_to_indexes = defaultdict(set, im["partition_to_indexes"])
        reusable_indexes = im["reusable_indexes"]
        global_index_counter = im["global_index_counter"]
        allocated_indexes = im["allocated_indexes"]
        sampler_state = state["sampler"]

        # Phase 2: commit to the live controller.
        # NOTE(review): this replaces the live controller_id with the checkpointed
        # one; confirm that peers addressing the controller by id are unaffected.
        self.controller_id = controller_id
        # NOTE(review): assumes to_snapshot() results are usable as live
        # partition objects; confirm against the partition class.
        self.partitions = partitions

        self.index_manager.partition_to_indexes = partition_to_indexes
        self.index_manager.reusable_indexes = reusable_indexes
        self.index_manager.global_index_counter = global_index_counter
        self.index_manager.allocated_indexes = allocated_indexes

        self.sampler.load_checkpoint(sampler_state)

        logger.info(f"[{self.controller_id}]: restored from {path}")
    except Exception as e:
        raise RuntimeError(f"[{self.controller_id}]: load checkpoint failed: {e}") from e
# ==================== Checkpoint API ====================

_METADATA_FILE = "metadata.json"
_CONTROLLER_FILE = "controller_state.pkl"


def save_checkpoint(
    checkpoint_dir: str | Path,
    *,
    include_storage: bool = True,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Save a full checkpoint of the TransferQueue system state.

    The checkpoint is assembled in a sibling ``<checkpoint_dir>.tmp``
    directory and only moved into place once every file has been written.
    An existing checkpoint at ``checkpoint_dir`` is parked as
    ``<checkpoint_dir>.old`` during the swap and removed afterwards, so a
    failure while swapping does not destroy the previous checkpoint.

    .. note::
        **Multi-node limitation**: checkpoint_dir must reside on a shared network
        filesystem (e.g. NFS, GPFS, Lustre) accessible from all nodes.
        Single-node deployments have no such requirement.

    Args:
        checkpoint_dir: Directory to save the checkpoint (created if not exists).
        include_storage: Whether to save storage backend state. Set to False to
                         save controller metadata only (e.g. for KV backends where
                         data persists externally and does not need to be re-saved).
                         For SimpleStorage (in-memory), this flag is ignored and
                         storage is always saved — skipping it would cause data loss
                         on restart.
        metadata: User-defined key-value pairs written into metadata.json.
                  Must be JSON-serializable.
                  Example: {"time_stamp": 1234567.891234, "step": 10}

    Raises:
        RuntimeError: TransferQueue is not initialized, or a checkpoint RPC fails.
        TypeError: ``metadata`` is not JSON-serializable.
        OSError: Failed to write checkpoint files.
    """
    if _TQ_CONTROLLER is None:
        raise RuntimeError("TransferQueue is not initialized. Call tq.init() first.")

    user_metadata = metadata or {}
    # Fail fast: detect unserializable user metadata before any state is dumped.
    json.dumps(user_metadata)

    checkpoint_dir = Path(checkpoint_dir)
    tmp_dir = checkpoint_dir.parent / (checkpoint_dir.name + ".tmp")
    backup_dir = checkpoint_dir.parent / (checkpoint_dir.name + ".old")

    # A leftover tmp dir can only come from an earlier interrupted save.
    if tmp_dir.exists():
        shutil.rmtree(tmp_dir)
    tmp_dir.mkdir(parents=True)

    try:
        client = _maybe_create_tq_client()

        controller_path = tmp_dir / _CONTROLLER_FILE
        client.save_controller_checkpoint(str(controller_path))
        logger.info("Controller state saved.")

        if not include_storage and isinstance(client.storage_manager, AsyncSimpleStorageManager):
            logger.warning(
                "include_storage=False has no effect for SimpleStorage: "
                "in-memory data would be lost on restart. Forcing include_storage=True."
            )
            include_storage = True

        if include_storage:
            try:
                client.save_storage_checkpoint(str(tmp_dir))
            except NotImplementedError:
                logger.warning("Storage backend does not support checkpoint; storage data will not be saved.")
                include_storage = False

        meta_content = {
            "storage_saved": include_storage,
            "user_metadata": user_metadata,
        }
        with open(tmp_dir / _METADATA_FILE, "w") as f:
            json.dump(meta_content, f, indent=2)

        # Swap the finished checkpoint into place, keeping the previous one
        # until the new one has actually been renamed.
        if backup_dir.exists():
            shutil.rmtree(backup_dir)
        had_previous = checkpoint_dir.exists()
        if had_previous:
            checkpoint_dir.rename(backup_dir)
        try:
            tmp_dir.rename(checkpoint_dir)
        except Exception:
            # Put the previous checkpoint back so the caller still has one.
            if had_previous and not checkpoint_dir.exists():
                backup_dir.rename(checkpoint_dir)
            raise
        if had_previous:
            shutil.rmtree(backup_dir, ignore_errors=True)

        logger.info(f"Checkpoint saved to {checkpoint_dir}")

    except Exception:
        # ignore_errors keeps a cleanup failure from masking the original error.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise


def load_checkpoint(
    checkpoint_dir: str | Path,
) -> None:
    """Restore TransferQueue system state from a checkpoint.

    All local file checks run first. When the checkpoint metadata says
    storage was saved, the storage checkpoint is validated before the
    controller state is replaced, then controller and storage are loaded
    in that order.

    .. note::
        **Multi-node limitation**: checkpoint_dir must reside on a shared network
        filesystem (e.g. NFS, GPFS, Lustre) accessible from all nodes.
        Single-node deployments have no such requirement.

    Args:
        checkpoint_dir: Path to the checkpoint directory.

    Raises:
        FileNotFoundError: Checkpoint directory or required files do not exist.
        ValueError: metadata.json is not valid JSON or not a JSON object, or
            storage validation reports an incompatible checkpoint.
        RuntimeError: TransferQueue is not initialized, or restore fails.
    """
    if _TQ_CONTROLLER is None:
        raise RuntimeError("TransferQueue is not initialized. Call tq.init() first.")

    checkpoint_dir = Path(checkpoint_dir)
    if not checkpoint_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")

    metadata_path = checkpoint_dir / _METADATA_FILE
    if not metadata_path.exists():
        raise FileNotFoundError(f"{_METADATA_FILE} not found in {checkpoint_dir}")

    with open(metadata_path) as f:
        meta = json.load(f)
    if not isinstance(meta, dict):
        raise ValueError(f"{_METADATA_FILE} in {checkpoint_dir} is not a JSON object")

    controller_path = checkpoint_dir / _CONTROLLER_FILE
    if not controller_path.exists():
        raise FileNotFoundError(f"{_CONTROLLER_FILE} not found in {checkpoint_dir}")

    storage_saved = bool(meta.get("storage_saved"))
    client = _maybe_create_tq_client()

    # Validate storage before touching the controller so an incompatible
    # checkpoint is rejected while the running state is still intact.
    if storage_saved:
        client.validate_storage_checkpoint(str(checkpoint_dir))

    client.load_controller_checkpoint(str(controller_path))

    if storage_saved:
        client.load_storage_checkpoint(str(checkpoint_dir))

    logger.info(f"Checkpoint loaded from {checkpoint_dir}")
def save_checkpoint(self) -> dict:
    """Extend the parent snapshot with this sampler's balanced cache.

    Returns:
        The dict produced by the parent ``save_checkpoint`` with an extra
        ``"_balanced_cache"`` entry.
    """
    snapshot = super().save_checkpoint()
    snapshot.update({"_balanced_cache": self._balanced_cache})
    return snapshot


def load_checkpoint(self, state: dict) -> None:
    """Restore the parent state first, then the balanced cache.

    Args:
        state: A dict previously returned by ``save_checkpoint``; it must
            contain the ``"_balanced_cache"`` key.
    """
    super().load_checkpoint(state)
    restored_cache = state["_balanced_cache"]
    self._balanced_cache = restored_cache
async def validate_checkpoint(self, checkpoint_dir: str) -> None:
    """Check that checkpoint_dir is compatible with the current system configuration.

    Intended to run before load_checkpoint so that mismatches (e.g. storage
    unit count) surface before any state is modified. This base
    implementation always raises; backends with checkpoint support override it.

    Args:
        checkpoint_dir: Directory holding a previously saved storage checkpoint.

    Raises:
        NotImplementedError: If this storage backend does not support checkpoint.
    """
    backend_name = self.__class__.__name__
    message = f"{backend_name} does not support checkpoint"
    raise NotImplementedError(message)


async def load_checkpoint(self, checkpoint_dir: str) -> None:
    """Restore storage state from checkpoint_dir.

    This base implementation always raises; backends with checkpoint
    support override it.

    Args:
        checkpoint_dir: Directory holding a previously saved storage checkpoint.

    Raises:
        NotImplementedError: If this storage backend does not support checkpoint.
    """
    backend_name = self.__class__.__name__
    message = f"{backend_name} does not support checkpoint"
    raise NotImplementedError(message)
@with_storage_unit_socket
async def _save_single_storage_unit(
    self,
    path: str,
    target_storage_unit: str,
    socket: zmq.Socket = None,
):
    """Ask one storage unit to dump its data to ``path``.

    Sends a ``SAVE_STORAGE_CHECKPOINT`` request and waits for the reply.

    Args:
        path: File path the storage unit should write its checkpoint to.
        target_storage_unit: Id of the storage unit to address.
        socket: Supplied by the ``with_storage_unit_socket`` decorator.

    Raises:
        RuntimeError: If the request fails, the reply has an unexpected
            type, or the storage unit reports ``success`` as falsy.
    """
    try:
        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.SAVE_STORAGE_CHECKPOINT,  # type: ignore[arg-type]
            sender_id=self.storage_manager_id,
            receiver_id=target_storage_unit,
            body={"path": path},
        )
        await socket.send_multipart(request_msg.serialize(), copy=False)
        messages = await socket.recv_multipart(copy=False)
        response_msg = ZMQMessage.deserialize(messages)
        if (
            response_msg.request_type != ZMQRequestType.SAVE_STORAGE_CHECKPOINT_RESPONSE
            or not response_msg.body.get("success")
        ):
            raise RuntimeError(
                f"Storage unit {target_storage_unit} failed to save checkpoint to {path}: "
                f"{response_msg.body.get('message', 'unknown error')}"
            )
    except Exception as e:
        raise RuntimeError(
            f"[{self.storage_manager_id}]: Error dumping storage unit {target_storage_unit}: {str(e)}"
        ) from e


@with_storage_unit_socket
async def _load_single_storage_unit(
    self,
    path: str,
    target_storage_unit: str,
    socket: zmq.Socket = None,
):
    """Ask one storage unit to restore its data from ``path``.

    Sends a ``LOAD_STORAGE_CHECKPOINT`` request and waits for the reply.

    Args:
        path: File path of the checkpoint the storage unit should read.
        target_storage_unit: Id of the storage unit to address.
        socket: Supplied by the ``with_storage_unit_socket`` decorator.

    Raises:
        RuntimeError: If the request fails, the reply has an unexpected
            type, or the storage unit reports ``success`` as falsy.
    """
    try:
        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.LOAD_STORAGE_CHECKPOINT,  # type: ignore[arg-type]
            sender_id=self.storage_manager_id,
            receiver_id=target_storage_unit,
            body={"path": path},
        )
        await socket.send_multipart(request_msg.serialize(), copy=False)
        messages = await socket.recv_multipart(copy=False)
        response_msg = ZMQMessage.deserialize(messages)
        if (
            response_msg.request_type != ZMQRequestType.LOAD_STORAGE_CHECKPOINT_RESPONSE
            or not response_msg.body.get("success")
        ):
            raise RuntimeError(
                f"Storage unit {target_storage_unit} failed to load checkpoint from {path}: "
                f"{response_msg.body.get('message', 'unknown error')}"
            )
    except Exception as e:
        raise RuntimeError(
            f"[{self.storage_manager_id}]: Error restoring for storage unit {target_storage_unit}: {str(e)}"
        ) from e


@staticmethod
def _su_checkpoint_file(su_dir: Path, position: int, storage_unit_id: str) -> Path:
    """Return the per-storage-unit checkpoint file path inside ``su_dir``."""
    return su_dir / f"su_{position}_{storage_unit_id}.pkl"


async def save_checkpoint(self, checkpoint_dir: str) -> None:
    """Dump all storage units into the ``simple_storage/`` subdirectory of checkpoint_dir.

    Each storage unit writes one ``su_<position>_<storage_unit_id>.pkl``
    file; all save requests are issued concurrently. Once every unit has
    succeeded, a ``storage_unit_info.json`` manifest listing each unit's
    position and id is written for ``validate_checkpoint`` and
    ``load_checkpoint`` to use.

    Args:
        checkpoint_dir: Directory under which ``simple_storage/`` is created.

    Raises:
        RuntimeError: If any storage unit fails to save.
        OSError: If the directory or manifest cannot be written.
    """
    import json

    su_dir = Path(checkpoint_dir) / _SU_SUBDIR
    su_dir.mkdir(parents=True, exist_ok=True)

    su_ids = list(self.storage_unit_infos.keys())
    tasks = [
        self._save_single_storage_unit(
            str(self._su_checkpoint_file(su_dir, pos, su_id)),
            target_storage_unit=su_id,
        )
        for pos, su_id in enumerate(su_ids)
    ]
    await asyncio.gather(*tasks)

    # The manifest is written last so its presence implies all unit files exist.
    su_info_list = [{"position": pos, "storage_unit_id": su_id} for pos, su_id in enumerate(su_ids)]
    with open(su_dir / _SU_INFO_FILE, "w") as f:
        json.dump(su_info_list, f)

    logger.info(f"[{self.storage_manager_id}]: saved {len(su_ids)} storage units to {su_dir}")


async def validate_checkpoint(self, checkpoint_dir: str) -> None:
    """Validate that the checkpoint is compatible before any state is modified.

    Checks that the manifest exists, that it lists as many storage units
    as the current system has, that the recorded positions are exactly
    ``0..n-1`` (``load_checkpoint`` indexes the current units by
    position), and that every per-unit checkpoint file is present.

    Args:
        checkpoint_dir: Directory previously passed to ``save_checkpoint``.

    Raises:
        FileNotFoundError: If the manifest or a per-unit file is missing.
        ValueError: If the manifest is malformed or does not match the
            current storage unit count.
    """
    import json

    su_dir = Path(checkpoint_dir) / _SU_SUBDIR
    su_info_path = su_dir / _SU_INFO_FILE
    if not su_info_path.exists():
        raise FileNotFoundError(f"Storage unit manifest not found: {su_info_path}")

    with open(su_info_path) as f:
        su_info_list = json.load(f)
    if not isinstance(su_info_list, list):
        raise ValueError(f"Storage unit manifest is not a list: {su_info_path}")

    su_ids = list(self.storage_unit_infos.keys())
    if len(su_ids) != len(su_info_list):
        raise ValueError(
            f"Storage unit count mismatch: checkpoint has {len(su_info_list)}, current system has {len(su_ids)}."
        )

    try:
        recorded = [(entry["position"], entry["storage_unit_id"]) for entry in su_info_list]
    except (KeyError, TypeError) as e:
        raise ValueError(f"Malformed entry in storage unit manifest {su_info_path}: {e}") from e

    positions = [pos for pos, _ in recorded]
    if any(not isinstance(pos, int) for pos in positions) or sorted(positions) != list(range(len(su_ids))):
        raise ValueError(
            f"Storage unit positions in {su_info_path} must be exactly 0..{len(su_ids) - 1}, got {positions}."
        )

    missing = [
        str(self._su_checkpoint_file(su_dir, pos, su_id))
        for pos, su_id in recorded
        if not self._su_checkpoint_file(su_dir, pos, su_id).exists()
    ]
    if missing:
        raise FileNotFoundError(f"Storage unit checkpoint files not found: {missing}")


async def load_checkpoint(self, checkpoint_dir: str) -> None:
    """Restore all storage units from the ``simple_storage/`` subdirectory of checkpoint_dir.

    Reads the ``storage_unit_info.json`` manifest and, for each entry,
    asks the current storage unit at the same position to load the file
    saved for that position. All load requests are issued concurrently.
    Call ``validate_checkpoint`` first; this method does not re-check the
    manifest against the current configuration.

    Args:
        checkpoint_dir: Directory previously passed to ``save_checkpoint``.

    Raises:
        FileNotFoundError: If the manifest is missing.
        RuntimeError: If any storage unit fails to load.
    """
    import json

    su_dir = Path(checkpoint_dir) / _SU_SUBDIR
    su_info_path = su_dir / _SU_INFO_FILE

    with open(su_info_path) as f:
        su_info_list = json.load(f)

    entries = sorted(su_info_list, key=lambda e: e["position"])
    su_ids = list(self.storage_unit_infos.keys())
    tasks = [
        self._load_single_storage_unit(
            str(self._su_checkpoint_file(su_dir, entry["position"], entry["storage_unit_id"])),
            # Units are matched by position, not id: ids recorded at save time
            # are only used to rebuild the file name.
            target_storage_unit=su_ids[entry["position"]],
        )
        for entry in entries
    ]
    await asyncio.gather(*tasks)

    logger.info(f"[{self.storage_manager_id}]: restored {len(su_ids)} storage units from {su_dir}")
def _handle_save_checkpoint(self, data_parts) -> ZMQMessage:
    """Serialize storage unit data directly to a file.

    The pickle is written to ``<path>.tmp`` and then renamed over
    ``path``, so a failure part-way through never leaves a truncated
    checkpoint at ``path``.

    Args:
        data_parts: ZMQMessage from client, containing ``path`` in body:
            absolute path for the output .pkl file. The caller must ensure
            this path is reachable from the node running this actor
            (shared filesystem required for multi-node setups).

    Returns:
        ZMQMessage with ``success=True`` on success, or ``success=False``
        and ``message`` containing the error string on failure.
    """
    path = data_parts.body["path"]
    tmp_path = f"{path}.tmp"
    try:
        state = {
            "storage_unit_id": self.storage_unit_id,
            "storage_unit_size": self.storage_unit_size,
            "field_data": self.storage_data.field_data,
            "active_keys": self.storage_data._active_keys,
        }
        with open(tmp_path, "wb") as f:
            pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL)
        # os.replace overwrites an existing destination in a single step.
        os.replace(tmp_path, path)
        logger.info(f"[{self.storage_unit_id}]: saved checkpoint to {path}")
        return ZMQMessage.create(
            request_type=ZMQRequestType.SAVE_STORAGE_CHECKPOINT_RESPONSE,  # type: ignore[arg-type]
            sender_id=self.storage_unit_id,
            body={"success": True},
        )
    except Exception as e:
        # Best-effort removal of the partial temp file; the failure itself is
        # reported to the caller in the response below.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        logger.error(f"[{self.storage_unit_id}]: save checkpoint failed: {e}")
        return ZMQMessage.create(
            request_type=ZMQRequestType.SAVE_STORAGE_CHECKPOINT_RESPONSE,  # type: ignore[arg-type]
            sender_id=self.storage_unit_id,
            body={"success": False, "message": str(e)},
        )


def _handle_load_checkpoint(self, data_parts) -> ZMQMessage:
    """Restore storage unit data directly from a file.

    All required entries are read from the checkpoint before the current
    data is cleared, so a checkpoint with missing keys is rejected without
    wiping what the unit currently holds. A ``storage_unit_size`` mismatch
    is logged as a warning but does not abort the load.

    Args:
        data_parts: ZMQMessage from client, containing ``path`` in body:
            absolute path to a .pkl file previously written by
            ``_handle_save_checkpoint``. The caller must ensure this path
            is reachable from the node running this actor
            (shared filesystem required for multi-node setups).

    Returns:
        ZMQMessage with ``success=True`` on success, or ``success=False``
        and ``message`` containing the error string on failure.
    """
    path = data_parts.body["path"]
    try:
        # NOTE(security): pickle.load can execute arbitrary code from the file.
        # Only load checkpoints that come from a trusted location.
        with open(path, "rb") as f:
            data = pickle.load(f)

        # Read everything we need up front; a KeyError here leaves the
        # current storage contents untouched.
        saved_size = data["storage_unit_size"]
        field_data = data["field_data"]
        active_keys = data["active_keys"]

        if saved_size != self.storage_unit_size:
            logger.warning(
                f"[{self.storage_unit_id}]: storage_unit_size mismatch — "
                f"checkpoint={saved_size}, current={self.storage_unit_size}"
            )

        if self.storage_data._active_keys:
            logger.warning(
                f"[{self.storage_unit_id}]: overwriting {len(self.storage_data._active_keys)} "
                f"existing keys with checkpoint data from {path}"
            )
            self.storage_data.field_data.clear()
            self.storage_data._active_keys.clear()
        self.storage_data.field_data = field_data
        self.storage_data._active_keys = active_keys

        logger.info(
            f"[{self.storage_unit_id}]: loaded checkpoint from {path} — "
            f"{len(active_keys)} keys, {len(field_data)} fields"
        )
        return ZMQMessage.create(
            request_type=ZMQRequestType.LOAD_STORAGE_CHECKPOINT_RESPONSE,  # type: ignore[arg-type]
            sender_id=self.storage_unit_id,
            body={"success": True},
        )

    except Exception as e:
        logger.error(f"[{self.storage_unit_id}]: load checkpoint failed: {e}")
        return ZMQMessage.create(
            request_type=ZMQRequestType.LOAD_STORAGE_CHECKPOINT_RESPONSE,  # type: ignore[arg-type]
            sender_id=self.storage_unit_id,
            body={"success": False, "message": str(e)},
        )
"SAVE_STORAGE_CHECKPOINT" + SAVE_STORAGE_CHECKPOINT_RESPONSE = "SAVE_STORAGE_CHECKPOINT_RESPONSE" + LOAD_CONTROLLER_CHECKPOINT = "LOAD_CONTROLLER_CHECKPOINT" + LOAD_CONTROLLER_CHECKPOINT_RESPONSE = "LOAD_CONTROLLER_CHECKPOINT_RESPONSE" + LOAD_STORAGE_CHECKPOINT = "LOAD_STORAGE_CHECKPOINT" + LOAD_STORAGE_CHECKPOINT_RESPONSE = "LOAD_STORAGE_CHECKPOINT_RESPONSE" + class ZMQServerInfo: """