From 255daf88a4184a2dcac98b6b87ca739eaee95f37 Mon Sep 17 00:00:00 2001 From: yxstev Date: Fri, 12 Jun 2026 15:24:34 +0800 Subject: [PATCH 1/6] provide save/load checkpoint interfaces Signed-off-by: yxstev --- tests/e2e/test_checkpoint_e2e.py | 395 ++++++++++++++++++ transfer_queue/__init__.py | 7 + transfer_queue/client.py | 90 ++++ transfer_queue/controller.py | 84 ++++ transfer_queue/interface.py | 130 ++++++ transfer_queue/storage/managers/base.py | 19 + .../managers/simple_storage_manager.py | 119 ++++++ transfer_queue/storage/simple_storage.py | 94 ++++- transfer_queue/utils/zmq_utils.py | 10 + 9 files changed, 947 insertions(+), 1 deletion(-) create mode 100644 tests/e2e/test_checkpoint_e2e.py diff --git a/tests/e2e/test_checkpoint_e2e.py b/tests/e2e/test_checkpoint_e2e.py new file mode 100644 index 00000000..3425ff95 --- /dev/null +++ b/tests/e2e/test_checkpoint_e2e.py @@ -0,0 +1,395 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end tests for save_checkpoint and load_checkpoint. 
+ +Run with: + pytest tests/e2e/test_checkpoint_e2e.py -v +""" + +import json +import os + +import pytest +import ray +import torch +from omegaconf import OmegaConf +from tensordict import TensorDict + +import transfer_queue as tq + +os.environ["RAY_DEDUP_LOGS"] = "0" + +_TQ_CONFIG = OmegaConf.create( + { + "controller": {"polling_mode": True}, + "backend": { + "storage_backend": "SimpleStorage", + "SimpleStorage": { + "total_storage_size": 200, + "num_data_storage_units": 2, + }, + }, + } +) + + +@pytest.fixture(scope="module") +def ray_init(): + if not ray.is_initialized(): + ray.init(namespace="TestCheckpointE2E") + yield + if ray.is_initialized(): + ray.shutdown() + + +@pytest.fixture(scope="module") +def tq_system(ray_init): + tq.init(_TQ_CONFIG) + yield + tq.close() + + +@pytest.fixture +def controller(tq_system): + return ray.get_actor("TransferQueueController", namespace="transfer_queue") + + +@pytest.fixture(autouse=True) +def cleanup_partitions(controller): + yield + try: + for pid in ray.get(controller.list_partitions.remote()): + ray.get(controller.clear_partition.remote(pid)) + except Exception: + pass + + +@pytest.fixture +def checkpoint_dir(tmp_path): + return tmp_path / "checkpoint" + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + + +def _put_batch(keys, partition_id, input_ids, attention_mask, tags=None): + fields = TensorDict( + {"input_ids": input_ids, "attention_mask": attention_mask}, + batch_size=len(keys), + ) + if tags is None: + tags = [{} for _ in keys] + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=tags) + + +def _get_batch(keys, partition_id): + return tq.kv_batch_get(keys=keys, partition_id=partition_id) + + +def assert_tensor_equal(tensor_a, tensor_b, msg=""): + """Assert two tensors are equal, handling nested vs dense comparisons.""" + if (isinstance(tensor_a, torch.Tensor) and 
tensor_a.is_nested) or ( + isinstance(tensor_b, torch.Tensor) and tensor_b.is_nested + ): + seq_a = list(tensor_a) + seq_b = list(tensor_b) + assert len(seq_a) == len(seq_b), f"{msg} Length mismatch: {len(seq_a)} vs {len(seq_b)}" + for t1, t2 in zip(seq_a, seq_b, strict=True): + assert torch.equal(t1, t2), f"{msg} Tensors are not equal: {tensor_a} vs {tensor_b}" + else: + assert torch.equal(tensor_a, tensor_b), f"{msg} Tensors are not equal: {tensor_a} vs {tensor_b}" + + +# --------------------------------------------------------------------------- +# basic save / load roundtrip +# --------------------------------------------------------------------------- + + +class TestCheckpointRoundtrip: + def test_save_creates_expected_files(self, tq_system, checkpoint_dir): + keys = ["k0", "k1"] + partition_id = "p0" + _put_batch(keys, partition_id, torch.tensor([[1, 2], [3, 4]]), torch.ones(2, 2)) + + tq.save_checkpoint(checkpoint_dir) + + assert (checkpoint_dir / "metadata.json").exists() + assert (checkpoint_dir / "controller_state.pkl").exists() + + with open(checkpoint_dir / "metadata.json") as f: + info = json.load(f) + + assert info["storage_saved"] is True + su_dir = checkpoint_dir / "storage_units" + assert su_dir.exists() + assert (su_dir / "su_info.json").exists() + + def test_metadata_json_content(self, tq_system, checkpoint_dir): + keys = ["m0"] + _put_batch(keys, "p_meta", torch.tensor([[10, 20]]), torch.ones(1, 2)) + + tq.save_checkpoint(checkpoint_dir, metadata={"iteration": 42, "loss": 0.5}) + + with open(checkpoint_dir / "metadata.json") as f: + meta = json.load(f) + + assert meta["user_metadata"]["iteration"] == 42 + assert meta["user_metadata"]["loss"] == pytest.approx(0.5) + assert "storage_saved" in meta + + def test_load_restores_controller_partitions(self, tq_system, checkpoint_dir, controller): + keys = ["a0", "a1", "a2"] + partition_id = "p_ctrl" + input_ids = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + tags = [{"idx": i} for i in range(3)] + 
_put_batch(keys, partition_id, input_ids, torch.ones(3, 3), tags) + + tq.save_checkpoint(checkpoint_dir) + + # wipe controller state + ray.get(controller.clear_partition.remote(partition_id)) + assert ray.get(controller.list_partitions.remote()) == [] + + tq.load_checkpoint(checkpoint_dir) + + # partition must be back + partitions = ray.get(controller.list_partitions.remote()) + assert partition_id in partitions + + # key-to-global-index mapping must be intact + snapshot = ray.get(controller.get_partition_snapshot.remote(partition_id)) + for key in keys: + assert key in snapshot.keys_mapping + + # tags must be intact + for i, key in enumerate(keys): + gidx = snapshot.keys_mapping[key] + assert snapshot.custom_meta[gidx]["idx"] == i + + def test_load_restores_storage_data(self, tq_system, checkpoint_dir, controller): + keys = ["s0", "s1"] + partition_id = "p_storage" + input_ids = torch.tensor([[10, 20], [30, 40]]) + attention_mask = torch.ones(2, 2) + _put_batch(keys, partition_id, input_ids, attention_mask) + + tq.save_checkpoint(checkpoint_dir) + + # clear both controller and storage state so load has to restore from scratch + ray.get(controller.clear_partition.remote(partition_id)) + + tq.load_checkpoint(checkpoint_dir) + + retrieved = _get_batch(keys, partition_id) + assert_tensor_equal(retrieved["input_ids"], input_ids) + assert_tensor_equal(retrieved["attention_mask"], attention_mask) + + def test_load_restores_multiple_partitions(self, tq_system, checkpoint_dir, controller): + for i in range(3): + _put_batch( + [f"p{i}_k0", f"p{i}_k1"], + f"part_{i}", + torch.full((2, 4), i, dtype=torch.long), + torch.ones(2, 4), + ) + + tq.save_checkpoint(checkpoint_dir) + + for i in range(3): + ray.get(controller.clear_partition.remote(f"part_{i}")) + + tq.load_checkpoint(checkpoint_dir) + + for i in range(3): + retrieved = tq.kv_batch_get( + keys=[f"p{i}_k0", f"p{i}_k1"], + partition_id=f"part_{i}", + select_fields=["input_ids"], + ) + 
assert_tensor_equal(retrieved["input_ids"], torch.full((2, 4), i, dtype=torch.long)) + + +# --------------------------------------------------------------------------- +# include_storage=False +# --------------------------------------------------------------------------- + + +class TestCheckpointMetadataOnly: + def test_save_include_storage_false_simplestorage_override(self, tq_system, checkpoint_dir): + """For SimpleStorage, include_storage=False is overridden to True because in-memory + data would be lost on restart. storage_saved must be True and storage_units must exist.""" + _put_batch(["n0"], "p_nometa", torch.tensor([[1, 2]]), torch.ones(1, 2)) + + tq.save_checkpoint(checkpoint_dir, include_storage=False) + + with open(checkpoint_dir / "metadata.json") as f: + info = json.load(f) + + assert info["storage_saved"] is True + assert (checkpoint_dir / "storage_units").exists() + + def test_load_after_include_storage_false_restores_both(self, tq_system, checkpoint_dir, controller): + """With SimpleStorage, include_storage=False is forced True, so both controller and + storage are saved and restored.""" + keys = ["n0", "n1"] + partition_id = "p_nometa2" + input_ids = torch.tensor([[5, 6], [7, 8]]) + _put_batch(keys, partition_id, input_ids, torch.ones(2, 2)) + + tq.save_checkpoint(checkpoint_dir, include_storage=False) + + ray.get(controller.clear_partition.remote(partition_id)) + + tq.load_checkpoint(checkpoint_dir) + + partitions = ray.get(controller.list_partitions.remote()) + assert partition_id in partitions + + snapshot = ray.get(controller.get_partition_snapshot.remote(partition_id)) + for key in keys: + assert key in snapshot.keys_mapping + + retrieved = _get_batch(keys, partition_id) + assert_tensor_equal(retrieved["input_ids"], input_ids) + + +# --------------------------------------------------------------------------- +# error handling +# --------------------------------------------------------------------------- + + +class TestCheckpointErrors: + def 
test_save_raises_if_not_initialized(self, tmp_path): + # call save_checkpoint before tq.init() in a fresh module state + import transfer_queue.interface as iface + + original = iface._TQ_CONTROLLER + try: + iface._TQ_CONTROLLER = None + with pytest.raises(RuntimeError, match="not initialized"): + tq.save_checkpoint(tmp_path / "ck") + finally: + iface._TQ_CONTROLLER = original + + def test_load_raises_if_not_initialized(self, tmp_path): + import transfer_queue.interface as iface + + original = iface._TQ_CONTROLLER + try: + iface._TQ_CONTROLLER = None + with pytest.raises(RuntimeError, match="not initialized"): + tq.load_checkpoint(tmp_path / "ck") + finally: + iface._TQ_CONTROLLER = original + + def test_load_raises_if_dir_missing(self, tq_system, tmp_path): + with pytest.raises(FileNotFoundError): + tq.load_checkpoint(tmp_path / "nonexistent") + + def test_load_raises_if_metadata_missing(self, tq_system, tmp_path): + ck = tmp_path / "ck" + ck.mkdir() + with pytest.raises(FileNotFoundError, match="metadata.json"): + tq.load_checkpoint(ck) + + def test_load_raises_on_storage_unit_count_mismatch(self, tq_system, tmp_path, checkpoint_dir): + _put_batch(["e0"], "p_err", torch.tensor([[1, 2]]), torch.ones(1, 2)) + tq.save_checkpoint(checkpoint_dir) + + # tamper: add a fake extra SU entry in su_info.json so count differs + su_info_path = checkpoint_dir / "storage_units" / "su_info.json" + with open(su_info_path) as f: + su_info = json.load(f) + su_info.append({"position": 99, "storage_unit_id": "fake"}) + with open(su_info_path, "w") as f: + json.dump(su_info, f) + + with pytest.raises(ValueError, match="count mismatch"): + tq.load_checkpoint(checkpoint_dir) + + def test_no_partial_state_on_failed_save(self, tq_system, tmp_path): + """A failed save must not leave a partial directory.""" + _put_batch(["f0"], "p_fail", torch.tensor([[1, 2]]), torch.ones(1, 2)) + + ck = tmp_path / "ck" + + import unittest.mock as mock + + with mock.patch( + 
"transfer_queue.client.TransferQueueClient.save_storage_checkpoint", + side_effect=RuntimeError("simulated dump failure"), + ): + with pytest.raises(RuntimeError, match="simulated dump failure"): + tq.save_checkpoint(ck) + + assert not ck.exists(), "Partial checkpoint directory should have been cleaned up" + assert not (tmp_path / "ck.tmp").exists(), "Temp directory should have been cleaned up" + + +# --------------------------------------------------------------------------- +# data variety +# --------------------------------------------------------------------------- + + +class TestCheckpointDataVariety: + def test_non_tensor_fields_roundtrip(self, tq_system, checkpoint_dir, controller): + """String fields should survive save/load.""" + from tensordict import NonTensorStack + + keys = ["t0", "t1"] + partition_id = "p_str" + fields = TensorDict( + { + "input_ids": torch.tensor([[1, 2], [3, 4]]), + "text": NonTensorStack("hello", "world"), + }, + batch_size=2, + ) + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}]) + + tq.save_checkpoint(checkpoint_dir) + + ray.get(controller.clear_partition.remote(partition_id)) + + tq.load_checkpoint(checkpoint_dir) + + retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id, select_fields=["input_ids"]) + assert_tensor_equal(retrieved["input_ids"], torch.tensor([[1, 2], [3, 4]])) + + def test_nested_tensor_fields_roundtrip(self, tq_system, checkpoint_dir, controller): + """Variable-length (jagged) tensor fields should survive save/load.""" + keys = ["j0", "j1", "j2"] + partition_id = "p_jagged" + for i, key in enumerate(keys): + seq = torch.arange(i + 1, dtype=torch.float).unsqueeze(0) + tq.kv_put( + key=key, + partition_id=partition_id, + fields=TensorDict({"seq": seq}, batch_size=1), + tag=None, + ) + + tq.save_checkpoint(checkpoint_dir) + + ray.get(controller.clear_partition.remote(partition_id)) + + tq.load_checkpoint(checkpoint_dir) + + retrieved = tq.kv_batch_get(keys=keys, 
partition_id=partition_id, select_fields=["seq"]) + for i, component in enumerate(retrieved["seq"].unbind()): + assert_tensor_equal(component, torch.arange(i + 1, dtype=torch.float)) diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index 06754278..ad8e8ab9 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -34,6 +34,8 @@ kv_clear, kv_list, kv_put, + load_checkpoint, + save_checkpoint, ) from .metadata import BatchMeta, KVBatchMeta from .sampler import BaseSampler @@ -62,6 +64,11 @@ "async_kv_clear", "KVBatchMeta", ] + + [ + # Checkpoint Interface + "save_checkpoint", + "load_checkpoint", + ] + [ # High-Level StreamingDataLoader Interface "StreamingDataset", diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 45e6303a..5d919954 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -1061,6 +1061,75 @@ def close(self) -> None: except Exception as e: logger.warning(f"Error closing storage manager: {e}") + # ==================== Checkpoint API ==================== + @with_controller_socket + async def async_save_controller_checkpoint( + self, + path: str, + socket: zmq.asyncio.Socket | None = None, + ) -> None: + """Send SAVE_CHECKPOINT to controller and wait for response.""" + assert socket is not None + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.SAVE_CONTROLLER_CHECKPOINT, # type: ignore[arg-type] + sender_id=self.client_id, + receiver_id=self._controller.id, + body={"path": path}, + ) + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart(copy=False) + response_msg = ZMQMessage.deserialize(response_serialized) + if response_msg.request_type != ZMQRequestType.SAVE_CONTROLLER_CHECKPOINT_RESPONSE: + raise RuntimeError( + f"[{self.client_id}]: Unexpected response type {response_msg.request_type} " + f"from controller during checkpoint dump" + ) + if not response_msg.body.get("success"): + raise 
RuntimeError(f"[{self.client_id}]: Controller failed to dump checkpoint to {path}") + + @with_controller_socket + async def async_load_controller_checkpoint( + self, + path: str, + socket: zmq.asyncio.Socket | None = None, + ) -> None: + """Send LOAD_CHECKPOINT to controller and wait for response.""" + assert socket is not None + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.LOAD_CONTROLLER_CHECKPOINT, # type: ignore[arg-type] + sender_id=self.client_id, + receiver_id=self._controller.id, + body={"path": path}, + ) + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart(copy=False) + response_msg = ZMQMessage.deserialize(response_serialized) + if response_msg.request_type != ZMQRequestType.LOAD_CONTROLLER_CHECKPOINT_RESPONSE: + raise RuntimeError( + f"[{self.client_id}]: Unexpected response type {response_msg.request_type} " + f"from controller during checkpoint restore" + ) + if not response_msg.body.get("success"): + raise RuntimeError(f"[{self.client_id}]: Controller failed to restore checkpoint from {path}") + + async def async_save_storage_checkpoint(self, checkpoint_dir: str) -> None: + """Save storage state into checkpoint_dir via StorageManager.""" + if not hasattr(self, "storage_manager") or self.storage_manager is None: + raise RuntimeError( + f"[{self.client_id}]: Storage manager not initialized. " + "Call initialize_storage_manager() before checkpoint operations." + ) + await self.storage_manager.save_checkpoint(checkpoint_dir) + + async def async_load_storage_checkpoint(self, checkpoint_dir: str) -> None: + """Restore storage state from checkpoint_dir via StorageManager.""" + if not hasattr(self, "storage_manager") or self.storage_manager is None: + raise RuntimeError( + f"[{self.client_id}]: Storage manager not initialized. " + "Call initialize_storage_manager() before checkpoint operations." 
+ ) + await self.storage_manager.load_checkpoint(checkpoint_dir) + class TransferQueueClient(AsyncTransferQueueClient): """Synchronous client wrapper for TransferQueue. @@ -1129,6 +1198,10 @@ def wrapper(*args, **kwargs): self._kv_retrieve_meta = _make_sync(self.async_kv_retrieve_meta) self._kv_retrieve_keys = _make_sync(self.async_kv_retrieve_keys) self._kv_list = _make_sync(self.async_kv_list) + self._save_controller_checkpoint = _make_sync(self.async_save_controller_checkpoint) + self._load_controller_checkpoint = _make_sync(self.async_load_controller_checkpoint) + self._save_storage_checkpoint = _make_sync(self.async_save_storage_checkpoint) + self._load_storage_checkpoint = _make_sync(self.async_load_storage_checkpoint) # ==================== Basic API ==================== def get_meta( @@ -1561,6 +1634,23 @@ def kv_list( return self._kv_list(partition_id=partition_id) + # ==================== Checkpoint API ==================== + def save_controller_checkpoint(self, path: str) -> None: + """Synchronously dump controller state to a file via ZMQ RPC.""" + return self._save_controller_checkpoint(path) + + def load_controller_checkpoint(self, path: str) -> None: + """Synchronously restore controller state from a file via ZMQ RPC.""" + return self._load_controller_checkpoint(path) + + def save_storage_checkpoint(self, checkpoint_dir: str) -> None: + """Synchronously save storage state into checkpoint_dir.""" + return self._save_storage_checkpoint(checkpoint_dir) + + def load_storage_checkpoint(self, checkpoint_dir: str) -> None: + """Synchronously restore storage state from checkpoint_dir.""" + return self._load_storage_checkpoint(checkpoint_dir) + def close(self) -> None: """Close the client and cleanup resources including event loop and thread.""" diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 1182a44c..d7c1abf9 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -15,6 +15,7 @@ import copy import os 
+import pickle import time from collections import defaultdict from dataclasses import dataclass, field @@ -2031,6 +2032,26 @@ def _process_request(self): body={"partition_info": partition_info, "message": message}, ) + elif request_msg.request_type == ZMQRequestType.SAVE_CONTROLLER_CHECKPOINT: + path = request_msg.body["path"] + success = self.save_checkpoint(path) + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.SAVE_CONTROLLER_CHECKPOINT_RESPONSE, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body={"success": success}, + ) + + elif request_msg.request_type == ZMQRequestType.LOAD_CONTROLLER_CHECKPOINT: + path = request_msg.body["path"] + success = self.load_checkpoint(path) + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.LOAD_CONTROLLER_CHECKPOINT_RESPONSE, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body={"success": success}, + ) + self.request_handle_socket.send_multipart([identity, *response_msg.serialize()]) def get_zmq_server_info(self) -> ZMQServerInfo: @@ -2045,6 +2066,69 @@ def get_config(self) -> DictConfig: """Retrieve the global config of TransferQueue.""" return self.tq_config + def save_checkpoint(self, path: str) -> bool: + """Serialize controller state directly to a file. + + Writes in-process to avoid transmitting the payload back over the + Ray object store — only a bool ACK is returned to the caller. + + Args: + path: Absolute path for the output .pkl file. + + Returns: + True on success, False on failure. 
+ """ + try: + state = { + "controller_id": self.controller_id, + "partitions": {pid: p.to_snapshot() for pid, p in self.partitions.items()}, + "index_manager": { + "partition_to_indexes": dict(copy.deepcopy(self.index_manager.partition_to_indexes)), + "reusable_indexes": list(self.index_manager.reusable_indexes), + "global_index_counter": self.index_manager.global_index_counter, + "allocated_indexes": set(self.index_manager.allocated_indexes), + }, + "sampler": self.sampler.get_state() if hasattr(self.sampler, "get_state") else None, + } + with open(path, "wb") as f: + pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL) + logger.info(f"[{self.controller_id}]: dumped to {path}") + return True + except Exception as e: + logger.error(f"[{self.controller_id}]: dump_to_file failed: {e}") + return False + + def load_checkpoint(self, path: str) -> bool: + """Restore controller state directly from a file. + + Args: + path: Absolute path to a .pkl file previously written by save_checkpoint(). + + Returns: + True on success, False on failure. 
+ """ + try: + with open(path, "rb") as f: + state = pickle.load(f) + + self.controller_id = state["controller_id"] + self.partitions = state["partitions"] + + im = state["index_manager"] + self.index_manager.partition_to_indexes = defaultdict(set, im["partition_to_indexes"]) + self.index_manager.reusable_indexes = im["reusable_indexes"] + self.index_manager.global_index_counter = im["global_index_counter"] + self.index_manager.allocated_indexes = im["allocated_indexes"] + + if state["sampler"] is not None and hasattr(self.sampler, "restore_state"): + self.sampler.restore_state(state["sampler"]) + + logger.info(f"[{self.controller_id}]: restored from {path}") + return True + except Exception as e: + logger.error(f"[{self.controller_id}]: restore_from_file failed: {e}") + return False + def register_sampler( self, sampler: BaseSampler | type[BaseSampler] = SequentialSampler, diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 82ceacaf..a0e7e4ef 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -13,10 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json import os +import shutil import subprocess import time from importlib import resources +from pathlib import Path from typing import Any, Callable import ray @@ -31,6 +34,7 @@ from transfer_queue.sampler import * # noqa: F401 from transfer_queue.sampler import BaseSampler from transfer_queue.storage.bootstrap import StorageBootstrapProvider +from transfer_queue.storage.managers.simple_storage_manager import AsyncSimpleStorageManager from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.yuanrong_utils import cleanup_yuanrong_resources from transfer_queue.utils.zmq_utils import process_zmq_server_info @@ -1047,3 +1051,129 @@ def get_client(): """Get a TransferQueueClient for using low-level API""" assert _TQ_CLIENT is not None, "Please initialize the TransferQueue first by calling `tq.init()`!" return _TQ_CLIENT + + +# ==================== Checkpoint API ==================== + +_METADATA_FILE = "metadata.json" +_CONTROLLER_FILE = "controller_state.pkl" + + +def save_checkpoint( + checkpoint_dir: str | Path, + *, + include_storage: bool = True, + metadata: dict[str, Any] | None = None, +) -> None: + """Save a full checkpoint of the TransferQueue system state. + + .. note:: + **Multi-node limitation**: checkpoint_dir must reside on a shared network + filesystem (e.g. NFS, GPFS, Lustre) accessible from all nodes running + storage units. Single-node deployments have no such requirement. + + Args: + checkpoint_dir: Directory to save the checkpoint (created if not exists). + include_storage: Whether to save storage backend state. Set to False to + save controller metadata only (e.g. for KV backends where + data persists externally and does not need to be re-saved). + metadata: User-defined key-value pairs written into metadata.json. + Example: {"time_stamp": 1234567.891234, "step": 10} + + Raises: + RuntimeError: TransferQueue is not initialized. + OSError: Failed to write checkpoint files. 
+ """ + if _TQ_CONTROLLER is None: + raise RuntimeError("TransferQueue is not initialized. Call tq.init() first.") + + checkpoint_dir = Path(checkpoint_dir) + tmp_dir = checkpoint_dir.parent / (checkpoint_dir.name + ".tmp") + + if tmp_dir.exists(): + shutil.rmtree(tmp_dir) + tmp_dir.mkdir(parents=True) + + try: + client = _maybe_create_tq_client() + + controller_path = tmp_dir / _CONTROLLER_FILE + client.save_controller_checkpoint(str(controller_path)) + logger.info("Controller state saved.") + + if not include_storage and isinstance(client.storage_manager, AsyncSimpleStorageManager): + logger.warning( + "include_storage=False has no effect for SimpleStorage: " + "in-memory data would be lost on restart. Forcing include_storage=True." + ) + include_storage = True + + if include_storage: + try: + client.save_storage_checkpoint(str(tmp_dir)) + except NotImplementedError: + logger.warning("Storage backend does not support checkpoint; storage data will not be saved.") + include_storage = False + + meta_content = { + "storage_saved": include_storage, + "user_metadata": metadata or {}, + } + with open(tmp_dir / _METADATA_FILE, "w") as f: + json.dump(meta_content, f, indent=2) + + if checkpoint_dir.exists(): + shutil.rmtree(checkpoint_dir) + tmp_dir.rename(checkpoint_dir) + + logger.info(f"Checkpoint saved to {checkpoint_dir}") + + except Exception: + if tmp_dir.exists(): + shutil.rmtree(tmp_dir) + raise + + +def load_checkpoint( + checkpoint_dir: str | Path, +) -> None: + """Restore TransferQueue system state from a checkpoint. + + .. note:: + **Multi-node limitation**: checkpoint_dir must reside on a shared network + filesystem (e.g. NFS, GPFS, Lustre) accessible from all nodes running + storage units. Single-node deployments have no such requirement. + + Args: + checkpoint_dir: Path to the checkpoint directory. + + Raises: + FileNotFoundError: Checkpoint directory or required files do not exist. + RuntimeError: TransferQueue is not initialized, or restore fails. 
+ """ + if _TQ_CONTROLLER is None: + raise RuntimeError("TransferQueue is not initialized. Call tq.init() first.") + + checkpoint_dir = Path(checkpoint_dir) + if not checkpoint_dir.exists(): + raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}") + + metadata_path = checkpoint_dir / _METADATA_FILE + if not metadata_path.exists(): + raise FileNotFoundError(f"{_METADATA_FILE} not found in {checkpoint_dir}") + + with open(metadata_path) as f: + meta = json.load(f) + + client = _maybe_create_tq_client() + + controller_path = checkpoint_dir / _CONTROLLER_FILE + if not controller_path.exists(): + raise FileNotFoundError(f"{_CONTROLLER_FILE} not found in {checkpoint_dir}") + + client.load_controller_checkpoint(str(controller_path)) + + if meta.get("storage_saved"): + client.load_storage_checkpoint(str(checkpoint_dir)) + + logger.info(f"Checkpoint loaded from {checkpoint_dir}") diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index f4d545da..e6b0faf4 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -352,6 +352,25 @@ async def clear_data(self, metadata: BatchMeta) -> None: """ raise NotImplementedError("Subclasses must implement clear_data") + async def save_checkpoint(self, checkpoint_dir: str) -> None: + """Save storage state into checkpoint_dir. + + The implementation is backend-specific: each backend decides what to + persist and how to lay out files within checkpoint_dir. + + Raises: + NotImplementedError: If this storage backend does not support checkpoint. + """ + raise NotImplementedError(f"{self.__class__.__name__} does not support checkpoint") + + async def load_checkpoint(self, checkpoint_dir: str) -> None: + """Restore storage state from checkpoint_dir. + + Raises: + NotImplementedError: If this storage backend does not support checkpoint. 
+ """ + raise NotImplementedError(f"{self.__class__.__name__} does not support checkpoint") + def close(self) -> None: """Close all ZMQ sockets/contexts and stop the notify loop.""" diff --git a/transfer_queue/storage/managers/simple_storage_manager.py b/transfer_queue/storage/managers/simple_storage_manager.py index 80803522..5e2746a0 100644 --- a/transfer_queue/storage/managers/simple_storage_manager.py +++ b/transfer_queue/storage/managers/simple_storage_manager.py @@ -19,6 +19,7 @@ from collections import defaultdict from collections.abc import Mapping from operator import itemgetter +from pathlib import Path from typing import Any, Callable, NamedTuple import torch @@ -40,6 +41,9 @@ TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT", 200)) # seconds +_SU_SUBDIR = "storage_units" +_SU_INFO_FILE = "su_info.json" + # Pre-bound decorator for storage-unit socket operations. with_storage_unit_socket = with_zmq_socket( "put_get_socket", @@ -524,6 +528,121 @@ def get_zmq_server_info(self) -> dict[str, ZMQServerInfo]: """ return self.storage_unit_infos + @with_storage_unit_socket + async def _save_single_storage_unit( + self, + path: str, + target_storage_unit: str, + socket: zmq.Socket = None, + ): + try: + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.SAVE_STORAGE_CHECKPOINT, # type: ignore[arg-type] + sender_id=self.storage_manager_id, + receiver_id=target_storage_unit, + body={"path": path}, + ) + await socket.send_multipart(request_msg.serialize(), copy=False) + messages = await socket.recv_multipart(copy=False) + response_msg = ZMQMessage.deserialize(messages) + if ( + response_msg.request_type != ZMQRequestType.SAVE_STORAGE_CHECKPOINT_RESPONSE + or not response_msg.body.get("success") + ): + raise RuntimeError( + f"Storage unit {target_storage_unit} failed to save checkpoint to {path}: " + f"{response_msg.body.get('message', 'unknown error')}" + ) + except Exception as e: + 
logger.error(f"[{self.storage_manager_id}]: Error dumping storage unit {target_storage_unit}: {str(e)}") + raise + + @with_storage_unit_socket + async def _load_single_storage_unit( + self, + path: str, + target_storage_unit: str, + socket: zmq.Socket = None, + ): + try: + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.LOAD_STORAGE_CHECKPOINT, # type: ignore[arg-type] + sender_id=self.storage_manager_id, + receiver_id=target_storage_unit, + body={"path": path}, + ) + await socket.send_multipart(request_msg.serialize(), copy=False) + messages = await socket.recv_multipart(copy=False) + response_msg = ZMQMessage.deserialize(messages) + if ( + response_msg.request_type != ZMQRequestType.LOAD_STORAGE_CHECKPOINT_RESPONSE + or not response_msg.body.get("success") + ): + raise RuntimeError( + f"Storage unit {target_storage_unit} failed to load checkpoint from {path}: " + f"{response_msg.body.get('message', 'unknown error')}" + ) + except Exception as e: + logger.error( + f"[{self.storage_manager_id}]: Error restoring for storage unit {target_storage_unit}: {str(e)}" + ) + raise + + async def save_checkpoint(self, checkpoint_dir: str) -> None: + """Dump all storage units to the storage_units/ subdirectory of checkpoint_dir. + + Writes a storage_units/su_info.json manifest for load_checkpoint to use. + load_checkpoint validates that the current SU count matches this manifest. 
+ """ + import json + + su_dir = Path(checkpoint_dir) / _SU_SUBDIR + su_dir.mkdir(parents=True, exist_ok=True) + + su_ids = list(self.storage_unit_infos.keys()) + paths = [str(su_dir / f"su_{pos}_{su_id}.pkl") for pos, su_id in enumerate(su_ids)] + tasks = [ + self._save_single_storage_unit(path, target_storage_unit=su_id) + for su_id, path in zip(su_ids, paths, strict=False) + ] + await asyncio.gather(*tasks) + + su_info_list = [{"position": pos, "storage_unit_id": su_id} for pos, su_id in enumerate(su_ids)] + with open(su_dir / _SU_INFO_FILE, "w") as f: + json.dump(su_info_list, f) + + logger.info(f"[{self.storage_manager_id}]: saved {len(su_ids)} storage units to {su_dir}") + + async def load_checkpoint(self, checkpoint_dir: str) -> None: + """Restore all storage units from the storage_units/ subdirectory of checkpoint_dir.""" + import json + + su_dir = Path(checkpoint_dir) / _SU_SUBDIR + su_info_path = su_dir / _SU_INFO_FILE + if not su_info_path.exists(): + raise FileNotFoundError(f"Storage unit manifest not found: {su_info_path}") + + with open(su_info_path) as f: + su_info_list = json.load(f) + + su_ids = list(self.storage_unit_infos.keys()) + if len(su_ids) != len(su_info_list): + raise ValueError( + f"Storage unit count mismatch: checkpoint has {len(su_info_list)}, current system has {len(su_ids)}." 
+ ) + + entries = sorted(su_info_list, key=lambda e: e["position"]) + tasks = [ + self._load_single_storage_unit( + str(su_dir / f"su_{entry['position']}_{entry['storage_unit_id']}.pkl"), + target_storage_unit=su_ids[entry["position"]], + ) + for entry in entries + ] + await asyncio.gather(*tasks) + + logger.info(f"[{self.storage_manager_id}]: restored {len(su_ids)} storage units from {su_dir}") + def close(self) -> None: """Close all ZMQ sockets and context to prevent resource leaks.""" super().close() diff --git a/transfer_queue/storage/simple_storage.py b/transfer_queue/storage/simple_storage.py index e70648ea..90378f78 100644 --- a/transfer_queue/storage/simple_storage.py +++ b/transfer_queue/storage/simple_storage.py @@ -14,6 +14,7 @@ # limitations under the License. import os +import pickle import time import weakref from threading import Event, Thread @@ -308,6 +309,10 @@ def _worker_routine(self) -> None: response_msg = self._handle_clear(request_msg) elif operation == ZMQRequestType.GET_METRICS: # type: ignore[arg-type] response_msg = self._handle_get_metrics() + elif operation == ZMQRequestType.SAVE_STORAGE_CHECKPOINT: # type: ignore[arg-type] + response_msg = self._handle_save_checkpoint(request_msg) + elif operation == ZMQRequestType.LOAD_STORAGE_CHECKPOINT: # type: ignore[arg-type] + response_msg = self._handle_load_checkpoint(request_msg) else: response_msg = ZMQMessage.create( request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR, # type: ignore[arg-type] @@ -535,11 +540,98 @@ def _handle_get_metrics(self) -> ZMQMessage: metrics["op_stats"] = op_stats return ZMQMessage.create( - request_type=ZMQRequestType.METRICS_RESPONSE, + request_type=ZMQRequestType.METRICS_RESPONSE, # type: ignore[arg-type] sender_id=self.storage_unit_id, body=metrics, ) + def _handle_save_checkpoint(self, data_parts) -> ZMQMessage: + """ + Serialize storage unit data directly to a file. 
+ + Writes data in-process to avoid transmitting the payload back over the + Ray object store — only a bool ACK is returned to the caller. + + Args: + path: data_parts: ZMQMessage from client, including + absolute path for the output .pkl file. + The caller must ensure this path is reachable from the node + running this actor (shared filesystem required for multi-node setups). + + Returns: + Checkpoint dump success response ZMQMessage. + """ + path = data_parts.body["path"] + try: + state = { + "storage_unit_id": self.storage_unit_id, + "storage_unit_size": self.storage_unit_size, + "field_data": self.storage_data.field_data, + "active_keys": self.storage_data._active_keys, + } + with open(path, "wb") as f: + pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL) + logger.info(f"[{self.storage_unit_id}]: saved checkpoint to {path}") + return ZMQMessage.create( + request_type=ZMQRequestType.SAVE_STORAGE_CHECKPOINT_RESPONSE, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={"success": True}, + ) + except Exception as e: + logger.error(f"[{self.storage_unit_id}]: save checkpoint failed: {e}") + return ZMQMessage.create( + request_type=ZMQRequestType.SAVE_STORAGE_CHECKPOINT_RESPONSE, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={"success": False, "message": str(e)}, + ) + + def _handle_load_checkpoint(self, data_parts) -> ZMQMessage: + """ + Restore storage unit data directly from a file. + + Args: + path: data_parts: ZMQMessage from client, including + absolute path for the output .pkl file. + The caller must ensure this path is reachable from the node + running this actor (shared filesystem required for multi-node setups). + + Returns: + Checkpoint restore success response ZMQMessage. 
+ """ + path = data_parts.body["path"] + try: + with open(path, "rb") as f: + data = pickle.load(f) + + if data["storage_unit_size"] != self.storage_unit_size: + logger.warning( + f"[{self.storage_unit_id}]: storage_unit_size mismatch — " + f"checkpoint={data['storage_unit_size']}, current={self.storage_unit_size}" + ) + + self.storage_data.field_data.clear() + self.storage_data._active_keys.clear() + self.storage_data.field_data = data["field_data"] + self.storage_data._active_keys = data["active_keys"] + + logger.info( + f"[{self.storage_unit_id}]: loaded checkpoint from {path} — " + f"{len(data['active_keys'])} keys, {len(data['field_data'])} fields" + ) + return ZMQMessage.create( + request_type=ZMQRequestType.LOAD_STORAGE_CHECKPOINT_RESPONSE, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={"success": True}, + ) + + except Exception as e: + logger.error(f"[{self.storage_unit_id}]: load checkpoint failed: {e}") + return ZMQMessage.create( + request_type=ZMQRequestType.LOAD_STORAGE_CHECKPOINT_RESPONSE, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={"success": False, "message": str(e)}, + ) + @staticmethod def _cumulative_bucket_counts(hist) -> list[float]: """Build cumulative counts from a prometheus_client Histogram's non-cumulative buckets.""" diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 4fe32f0a..dd594dbb 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -101,6 +101,16 @@ class ZMQRequestType(ExplicitEnum): GET_METRICS = "GET_METRICS" METRICS_RESPONSE = "METRICS_RESPONSE" + # CHECKPOINT + SAVE_CONTROLLER_CHECKPOINT = "SAVE_CONTROLLER_CHECKPOINT" + SAVE_CONTROLLER_CHECKPOINT_RESPONSE = "SAVE_CONTROLLER_CHECKPOINT_RESPONSE" + SAVE_STORAGE_CHECKPOINT = "SAVE_STORAGE_CHECKPOINT" + SAVE_STORAGE_CHECKPOINT_RESPONSE = "SAVE_STORAGE_CHECKPOINT_RESPONSE" + LOAD_CONTROLLER_CHECKPOINT = "LOAD_CONTROLLER_CHECKPOINT" + 
LOAD_CONTROLLER_CHECKPOINT_RESPONSE = "LOAD_CONTROLLER_CHECKPOINT_RESPONSE" + LOAD_STORAGE_CHECKPOINT = "LOAD_STORAGE_CHECKPOINT" + LOAD_STORAGE_CHECKPOINT_RESPONSE = "LOAD_STORAGE_CHECKPOINT_RESPONSE" + class ZMQServerInfo: """ From 505e331d41e0f9120e1eb059eb0db5b3124e88d9 Mon Sep 17 00:00:00 2001 From: yxstev Date: Wed, 17 Jun 2026 10:41:57 +0800 Subject: [PATCH 2/6] fix bug Signed-off-by: yxstev --- transfer_queue/controller.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index d7c1abf9..d20424fb 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -2032,21 +2032,21 @@ def _process_request(self): body={"partition_info": partition_info, "message": message}, ) - elif request_msg.request_type == ZMQRequestType.SAVE_CHECKPOINT: + elif request_msg.request_type == ZMQRequestType.SAVE_CONTROLLER_CHECKPOINT: path = request_msg.body["path"] success = self.save_checkpoint(path) response_msg = ZMQMessage.create( - request_type=ZMQRequestType.SAVE_CHECKPOINT_RESPONSE, + request_type=ZMQRequestType.SAVE_CONTROLLER_CHECKPOINT_RESPONSE, sender_id=self.controller_id, receiver_id=request_msg.sender_id, body={"success": success}, ) - elif request_msg.request_type == ZMQRequestType.LOAD_CHECKPOINT: + elif request_msg.request_type == ZMQRequestType.LOAD_CONTROLLER_CHECKPOINT: path = request_msg.body["path"] success = self.load_checkpoint(path) response_msg = ZMQMessage.create( - request_type=ZMQRequestType.LOAD_CHECKPOINT_RESPONSE, + request_type=ZMQRequestType.LOAD_CONTROLLER_CHECKPOINT_RESPONSE, sender_id=self.controller_id, receiver_id=request_msg.sender_id, body={"success": success}, From 2338283ce7e4f22887ec4ae81d1b19d023d47462 Mon Sep 17 00:00:00 2001 From: yxstev Date: Wed, 17 Jun 2026 15:22:03 +0800 Subject: [PATCH 3/6] resolve comments Signed-off-by: yxstev --- tests/e2e/test_checkpoint_e2e.py | 363 ++++++++++-------- transfer_queue/client.py | 167 ++++++-- 
transfer_queue/controller.py | 37 +- transfer_queue/interface.py | 11 +- transfer_queue/sampler/base.py | 26 ++ .../sampler/seqlen_balanced_sampler.py | 11 + .../managers/simple_storage_manager.py | 14 +- transfer_queue/storage/simple_storage.py | 37 +- 8 files changed, 412 insertions(+), 254 deletions(-) diff --git a/tests/e2e/test_checkpoint_e2e.py b/tests/e2e/test_checkpoint_e2e.py index 3425ff95..c544a0b4 100644 --- a/tests/e2e/test_checkpoint_e2e.py +++ b/tests/e2e/test_checkpoint_e2e.py @@ -26,7 +26,7 @@ import ray import torch from omegaconf import OmegaConf -from tensordict import TensorDict +from tensordict import NonTensorStack, TensorDict import transfer_queue as tq @@ -46,6 +46,11 @@ ) +# --------------------------------------------------------------------------- +# fixtures +# --------------------------------------------------------------------------- + + @pytest.fixture(scope="module") def ray_init(): if not ray.is_initialized(): @@ -87,32 +92,12 @@ def checkpoint_dir(tmp_path): # --------------------------------------------------------------------------- -def _put_batch(keys, partition_id, input_ids, attention_mask, tags=None): - fields = TensorDict( - {"input_ids": input_ids, "attention_mask": attention_mask}, - batch_size=len(keys), - ) - if tags is None: - tags = [{} for _ in keys] - tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=tags) - - -def _get_batch(keys, partition_id): - return tq.kv_batch_get(keys=keys, partition_id=partition_id) - - -def assert_tensor_equal(tensor_a, tensor_b, msg=""): - """Assert two tensors are equal, handling nested vs dense comparisons.""" - if (isinstance(tensor_a, torch.Tensor) and tensor_a.is_nested) or ( - isinstance(tensor_b, torch.Tensor) and tensor_b.is_nested - ): - seq_a = list(tensor_a) - seq_b = list(tensor_b) - assert len(seq_a) == len(seq_b), f"{msg} Length mismatch: {len(seq_a)} vs {len(seq_b)}" - for t1, t2 in zip(seq_a, seq_b, strict=True): - assert torch.equal(t1, t2), 
f"{msg} Tensors are not equal: {tensor_a} vs {tensor_b}" +def _assert_tensor_equal(a, b, msg=""): + if (isinstance(a, torch.Tensor) and a.is_nested) or (isinstance(b, torch.Tensor) and b.is_nested): + for t1, t2 in zip(list(a), list(b), strict=True): + assert torch.equal(t1, t2), f"{msg} mismatch" else: - assert torch.equal(tensor_a, tensor_b), f"{msg} Tensors are not equal: {tensor_a} vs {tensor_b}" + assert torch.equal(a, b), f"{msg} mismatch" # --------------------------------------------------------------------------- @@ -121,151 +106,242 @@ def assert_tensor_equal(tensor_a, tensor_b, msg=""): class TestCheckpointRoundtrip: - def test_save_creates_expected_files(self, tq_system, checkpoint_dir): + """Standard data → save → verify files → wipe → load → verify data.""" + + def test_tensor_fields(self, tq_system, checkpoint_dir, controller): + # Define test data keys = ["k0", "k1"] - partition_id = "p0" - _put_batch(keys, partition_id, torch.tensor([[1, 2], [3, 4]]), torch.ones(2, 2)) + partition_id = "p_tensor" + input_ids = torch.tensor([[1, 2], [3, 4]]) + attention_mask = torch.ones(2, 2) + + # Put + tq.kv_batch_put( + keys=keys, + partition_id=partition_id, + fields=TensorDict({"input_ids": input_ids, "attention_mask": attention_mask}, batch_size=len(keys)), + tags=[{} for _ in keys], + ) + # Save tq.save_checkpoint(checkpoint_dir) + # Check saved state: expected files exist assert (checkpoint_dir / "metadata.json").exists() assert (checkpoint_dir / "controller_state.pkl").exists() - - with open(checkpoint_dir / "metadata.json") as f: - info = json.load(f) - - assert info["storage_saved"] is True su_dir = checkpoint_dir / "storage_units" assert su_dir.exists() assert (su_dir / "su_info.json").exists() - - def test_metadata_json_content(self, tq_system, checkpoint_dir): - keys = ["m0"] - _put_batch(keys, "p_meta", torch.tensor([[10, 20]]), torch.ones(1, 2)) - - tq.save_checkpoint(checkpoint_dir, metadata={"iteration": 42, "loss": 0.5}) - with 
open(checkpoint_dir / "metadata.json") as f: meta = json.load(f) + assert meta["storage_saved"] is True - assert meta["user_metadata"]["iteration"] == 42 - assert meta["user_metadata"]["loss"] == pytest.approx(0.5) - assert "storage_saved" in meta + # Wipe controller state so load has real work to do + ray.get(controller.clear_partition.remote(partition_id)) + assert ray.get(controller.list_partitions.remote()) == [] - def test_load_restores_controller_partitions(self, tq_system, checkpoint_dir, controller): + # Load + tq.load_checkpoint(checkpoint_dir) + + # Check loaded state: partition and data restored + assert partition_id in ray.get(controller.list_partitions.remote()) + retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id) + _assert_tensor_equal(retrieved["input_ids"], input_ids) + _assert_tensor_equal(retrieved["attention_mask"], attention_mask) + + def test_controller_metadata(self, tq_system, checkpoint_dir, controller): + # Define test data keys = ["a0", "a1", "a2"] - partition_id = "p_ctrl" + partition_id = "p_meta" input_ids = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) tags = [{"idx": i} for i in range(3)] - _put_batch(keys, partition_id, input_ids, torch.ones(3, 3), tags) + # Put + tq.kv_batch_put( + keys=keys, + partition_id=partition_id, + fields=TensorDict({"input_ids": input_ids, "attention_mask": torch.ones(3, 3)}, batch_size=len(keys)), + tags=tags, + ) + + # Save tq.save_checkpoint(checkpoint_dir) - # wipe controller state + # Wipe ray.get(controller.clear_partition.remote(partition_id)) - assert ray.get(controller.list_partitions.remote()) == [] + # Load tq.load_checkpoint(checkpoint_dir) - # partition must be back - partitions = ray.get(controller.list_partitions.remote()) - assert partition_id in partitions - - # key-to-global-index mapping must be intact + # Check loaded state: key mapping and tags intact snapshot = ray.get(controller.get_partition_snapshot.remote(partition_id)) - for key in keys: - assert key in 
snapshot.keys_mapping - - # tags must be intact for i, key in enumerate(keys): + assert key in snapshot.keys_mapping gidx = snapshot.keys_mapping[key] assert snapshot.custom_meta[gidx]["idx"] == i - def test_load_restores_storage_data(self, tq_system, checkpoint_dir, controller): - keys = ["s0", "s1"] - partition_id = "p_storage" - input_ids = torch.tensor([[10, 20], [30, 40]]) - attention_mask = torch.ones(2, 2) - _put_batch(keys, partition_id, input_ids, attention_mask) + def test_multiple_partitions(self, tq_system, checkpoint_dir, controller): + # Define test data + partitions_data = {f"part_{i}": (torch.full((2, 4), i, dtype=torch.long), torch.ones(2, 4)) for i in range(3)} + + # Put + for pid, (iids, mask) in partitions_data.items(): + tq.kv_batch_put( + keys=[f"{pid}_k0", f"{pid}_k1"], + partition_id=pid, + fields=TensorDict({"input_ids": iids, "attention_mask": mask}, batch_size=2), + tags=[{}, {}], + ) + + # Save + tq.save_checkpoint(checkpoint_dir) + + # Wipe + for pid in partitions_data: + ray.get(controller.clear_partition.remote(pid)) + + # Load + tq.load_checkpoint(checkpoint_dir) + + # Check loaded state + for pid, (iids, _) in partitions_data.items(): + retrieved = tq.kv_batch_get(keys=[f"{pid}_k0", f"{pid}_k1"], partition_id=pid, select_fields=["input_ids"]) + _assert_tensor_equal(retrieved["input_ids"], iids) + + def test_user_metadata_preserved(self, tq_system, checkpoint_dir): + # Define test data + keys = ["m0"] + + # Put + tq.kv_batch_put( + keys=keys, + partition_id="p_usermeta", + fields=TensorDict( + {"input_ids": torch.tensor([[10, 20]]), "attention_mask": torch.ones(1, 2)}, batch_size=1 + ), + tags=[{}], + ) + # Save with user metadata + tq.save_checkpoint(checkpoint_dir, metadata={"iteration": 42, "loss": 0.5}) + + # Check saved state: user metadata written correctly + with open(checkpoint_dir / "metadata.json") as f: + meta = json.load(f) + assert meta["user_metadata"]["iteration"] == 42 + assert meta["user_metadata"]["loss"] == 
pytest.approx(0.5) + + def test_non_tensor_fields(self, tq_system, checkpoint_dir, controller): + # Define test data + keys = ["t0", "t1"] + partition_id = "p_str" + input_ids = torch.tensor([[1, 2], [3, 4]]) + fields = TensorDict( + {"input_ids": input_ids, "text": NonTensorStack("hello", "world")}, + batch_size=2, + ) + + # Put + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}]) + + # Save tq.save_checkpoint(checkpoint_dir) - # clear both controller and storage state so load has to restore from scratch + # Wipe ray.get(controller.clear_partition.remote(partition_id)) + # Load tq.load_checkpoint(checkpoint_dir) - retrieved = _get_batch(keys, partition_id) - assert_tensor_equal(retrieved["input_ids"], input_ids) - assert_tensor_equal(retrieved["attention_mask"], attention_mask) - - def test_load_restores_multiple_partitions(self, tq_system, checkpoint_dir, controller): - for i in range(3): - _put_batch( - [f"p{i}_k0", f"p{i}_k1"], - f"part_{i}", - torch.full((2, 4), i, dtype=torch.long), - torch.ones(2, 4), + # Check loaded state + retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id, select_fields=["input_ids"]) + _assert_tensor_equal(retrieved["input_ids"], input_ids) + + def test_nested_tensor_fields(self, tq_system, checkpoint_dir, controller): + # Define test data + keys = ["j0", "j1", "j2"] + partition_id = "p_jagged" + + # Put (variable-length sequences) + for i, key in enumerate(keys): + tq.kv_put( + key=key, + partition_id=partition_id, + fields=TensorDict({"seq": torch.arange(i + 1, dtype=torch.float).unsqueeze(0)}, batch_size=1), + tag=None, ) + # Save tq.save_checkpoint(checkpoint_dir) - for i in range(3): - ray.get(controller.clear_partition.remote(f"part_{i}")) + # Wipe + ray.get(controller.clear_partition.remote(partition_id)) + # Load tq.load_checkpoint(checkpoint_dir) - for i in range(3): - retrieved = tq.kv_batch_get( - keys=[f"p{i}_k0", f"p{i}_k1"], - partition_id=f"part_{i}", - 
select_fields=["input_ids"], - ) - assert_tensor_equal(retrieved["input_ids"], torch.full((2, 4), i, dtype=torch.long)) + # Check loaded state + retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id, select_fields=["seq"]) + for i, component in enumerate(retrieved["seq"].unbind()): + _assert_tensor_equal(component, torch.arange(i + 1, dtype=torch.float)) # --------------------------------------------------------------------------- -# include_storage=False +# include_storage=False (SimpleStorage override) # --------------------------------------------------------------------------- -class TestCheckpointMetadataOnly: - def test_save_include_storage_false_simplestorage_override(self, tq_system, checkpoint_dir): - """For SimpleStorage, include_storage=False is overridden to True because in-memory - data would be lost on restart. storage_saved must be True and storage_units must exist.""" - _put_batch(["n0"], "p_nometa", torch.tensor([[1, 2]]), torch.ones(1, 2)) +class TestIncludeStorageFalse: + """For SimpleStorage, include_storage=False is silently forced to True.""" + + def test_storage_saved_is_true(self, tq_system, checkpoint_dir): + # Define test data + Put + tq.kv_batch_put( + keys=["n0"], + partition_id="p_nometa", + fields=TensorDict({"input_ids": torch.tensor([[1, 2]]), "attention_mask": torch.ones(1, 2)}, batch_size=1), + tags=[{}], + ) + # Save with include_storage=False tq.save_checkpoint(checkpoint_dir, include_storage=False) + # Check saved state: storage_saved must be True and directory must exist with open(checkpoint_dir / "metadata.json") as f: - info = json.load(f) - - assert info["storage_saved"] is True + meta = json.load(f) + assert meta["storage_saved"] is True assert (checkpoint_dir / "storage_units").exists() - def test_load_after_include_storage_false_restores_both(self, tq_system, checkpoint_dir, controller): - """With SimpleStorage, include_storage=False is forced True, so both controller and - storage are saved and restored.""" + 
def test_both_restored_after_load(self, tq_system, checkpoint_dir, controller): + # Define test data keys = ["n0", "n1"] partition_id = "p_nometa2" input_ids = torch.tensor([[5, 6], [7, 8]]) - _put_batch(keys, partition_id, input_ids, torch.ones(2, 2)) + # Put + tq.kv_batch_put( + keys=keys, + partition_id=partition_id, + fields=TensorDict({"input_ids": input_ids, "attention_mask": torch.ones(2, 2)}, batch_size=len(keys)), + tags=[{} for _ in keys], + ) + + # Save tq.save_checkpoint(checkpoint_dir, include_storage=False) + # Wipe ray.get(controller.clear_partition.remote(partition_id)) + # Load tq.load_checkpoint(checkpoint_dir) - partitions = ray.get(controller.list_partitions.remote()) - assert partition_id in partitions - + # Check loaded state: both controller and storage restored + assert partition_id in ray.get(controller.list_partitions.remote()) snapshot = ray.get(controller.get_partition_snapshot.remote(partition_id)) for key in keys: assert key in snapshot.keys_mapping - - retrieved = _get_batch(keys, partition_id) - assert_tensor_equal(retrieved["input_ids"], input_ids) + retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id) + _assert_tensor_equal(retrieved["input_ids"], input_ids) # --------------------------------------------------------------------------- @@ -275,7 +351,6 @@ def test_load_after_include_storage_false_restores_both(self, tq_system, checkpo class TestCheckpointErrors: def test_save_raises_if_not_initialized(self, tmp_path): - # call save_checkpoint before tq.init() in a fresh module state import transfer_queue.interface as iface original = iface._TQ_CONTROLLER @@ -308,10 +383,16 @@ def test_load_raises_if_metadata_missing(self, tq_system, tmp_path): tq.load_checkpoint(ck) def test_load_raises_on_storage_unit_count_mismatch(self, tq_system, tmp_path, checkpoint_dir): - _put_batch(["e0"], "p_err", torch.tensor([[1, 2]]), torch.ones(1, 2)) + # Define test data + Put + Save + tq.kv_batch_put( + keys=["e0"], + 
partition_id="p_err", + fields=TensorDict({"input_ids": torch.tensor([[1, 2]]), "attention_mask": torch.ones(1, 2)}, batch_size=1), + tags=[{}], + ) tq.save_checkpoint(checkpoint_dir) - # tamper: add a fake extra SU entry in su_info.json so count differs + # Tamper: add a fake extra SU entry so count differs su_info_path = checkpoint_dir / "storage_units" / "su_info.json" with open(su_info_path) as f: su_info = json.load(f) @@ -323,8 +404,13 @@ def test_load_raises_on_storage_unit_count_mismatch(self, tq_system, tmp_path, c tq.load_checkpoint(checkpoint_dir) def test_no_partial_state_on_failed_save(self, tq_system, tmp_path): - """A failed save must not leave a partial directory.""" - _put_batch(["f0"], "p_fail", torch.tensor([[1, 2]]), torch.ones(1, 2)) + # Define test data + Put + tq.kv_batch_put( + keys=["f0"], + partition_id="p_fail", + fields=TensorDict({"input_ids": torch.tensor([[1, 2]]), "attention_mask": torch.ones(1, 2)}, batch_size=1), + tags=[{}], + ) ck = tmp_path / "ck" @@ -337,59 +423,6 @@ def test_no_partial_state_on_failed_save(self, tq_system, tmp_path): with pytest.raises(RuntimeError, match="simulated dump failure"): tq.save_checkpoint(ck) - assert not ck.exists(), "Partial checkpoint directory should have been cleaned up" - assert not (tmp_path / "ck.tmp").exists(), "Temp directory should have been cleaned up" - - -# --------------------------------------------------------------------------- -# data variety -# --------------------------------------------------------------------------- - - -class TestCheckpointDataVariety: - def test_non_tensor_fields_roundtrip(self, tq_system, checkpoint_dir, controller): - """String fields should survive save/load.""" - from tensordict import NonTensorStack - - keys = ["t0", "t1"] - partition_id = "p_str" - fields = TensorDict( - { - "input_ids": torch.tensor([[1, 2], [3, 4]]), - "text": NonTensorStack("hello", "world"), - }, - batch_size=2, - ) - tq.kv_batch_put(keys=keys, partition_id=partition_id, 
fields=fields, tags=[{}, {}]) - - tq.save_checkpoint(checkpoint_dir) - - ray.get(controller.clear_partition.remote(partition_id)) - - tq.load_checkpoint(checkpoint_dir) - - retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id, select_fields=["input_ids"]) - assert_tensor_equal(retrieved["input_ids"], torch.tensor([[1, 2], [3, 4]])) - - def test_nested_tensor_fields_roundtrip(self, tq_system, checkpoint_dir, controller): - """Variable-length (jagged) tensor fields should survive save/load.""" - keys = ["j0", "j1", "j2"] - partition_id = "p_jagged" - for i, key in enumerate(keys): - seq = torch.arange(i + 1, dtype=torch.float).unsqueeze(0) - tq.kv_put( - key=key, - partition_id=partition_id, - fields=TensorDict({"seq": seq}, batch_size=1), - tag=None, - ) - - tq.save_checkpoint(checkpoint_dir) - - ray.get(controller.clear_partition.remote(partition_id)) - - tq.load_checkpoint(checkpoint_dir) - - retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id, select_fields=["seq"]) - for i, component in enumerate(retrieved["seq"].unbind()): - assert_tensor_equal(component, torch.arange(i + 1, dtype=torch.float)) + # Check saved state: no partial directory left + assert not ck.exists() + assert not (tmp_path / "ck.tmp").exists() diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 5d919954..34f35757 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -1068,24 +1068,38 @@ async def async_save_controller_checkpoint( path: str, socket: zmq.asyncio.Socket | None = None, ) -> None: - """Send SAVE_CHECKPOINT to controller and wait for response.""" - assert socket is not None - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.SAVE_CONTROLLER_CHECKPOINT, # type: ignore[arg-type] - sender_id=self.client_id, - receiver_id=self._controller.id, - body={"path": path}, - ) - await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart(copy=False) - response_msg = 
ZMQMessage.deserialize(response_serialized) - if response_msg.request_type != ZMQRequestType.SAVE_CONTROLLER_CHECKPOINT_RESPONSE: - raise RuntimeError( - f"[{self.client_id}]: Unexpected response type {response_msg.request_type} " - f"from controller during checkpoint dump" + """Asynchronously save controller state to a file via ZMQ RPC. + + Sends a SAVE_CONTROLLER_CHECKPOINT request to the controller and waits + for acknowledgement. The controller serializes its state directly to + ``path`` in-process. + + Args: + path: Absolute path for the output .pkl file. The caller must + ensure this path is writable from the node running the controller. + socket: ZMQ socket injected by @with_controller_socket. + + Raises: + RuntimeError: If the RPC fails or an unexpected response is received. + """ + try: + assert socket is not None + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.SAVE_CONTROLLER_CHECKPOINT, # type: ignore[arg-type] + sender_id=self.client_id, + receiver_id=self._controller.id, + body={"path": path}, ) - if not response_msg.body.get("success"): - raise RuntimeError(f"[{self.client_id}]: Controller failed to dump checkpoint to {path}") + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart(copy=False) + response_msg = ZMQMessage.deserialize(response_serialized) + if response_msg.request_type != ZMQRequestType.SAVE_CONTROLLER_CHECKPOINT_RESPONSE: + raise RuntimeError( + f"[{self.client_id}]: Unexpected response type {response_msg.request_type} " + f"from controller during checkpoint dump" + ) + except Exception as e: + raise RuntimeError(f"[{self.client_id}]: Error in save_controller_checkpoint: {str(e)}") from e @with_controller_socket async def async_load_controller_checkpoint( @@ -1093,27 +1107,54 @@ async def async_load_controller_checkpoint( path: str, socket: zmq.asyncio.Socket | None = None, ) -> None: - """Send LOAD_CHECKPOINT to controller and wait for response.""" - assert socket is 
not None - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.LOAD_CONTROLLER_CHECKPOINT, # type: ignore[arg-type] - sender_id=self.client_id, - receiver_id=self._controller.id, - body={"path": path}, - ) - await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart(copy=False) - response_msg = ZMQMessage.deserialize(response_serialized) - if response_msg.request_type != ZMQRequestType.LOAD_CONTROLLER_CHECKPOINT_RESPONSE: - raise RuntimeError( - f"[{self.client_id}]: Unexpected response type {response_msg.request_type} " - f"from controller during checkpoint restore" + """Asynchronously restore controller state from a file via ZMQ RPC. + + Sends a LOAD_CONTROLLER_CHECKPOINT request to the controller and waits + for acknowledgement. The controller deserializes its state directly from + ``path`` in-process, overwriting the current in-memory state. + + Args: + path: Absolute path to a .pkl file previously written by + ``async_save_controller_checkpoint``. The caller must ensure + this path is readable from the node running the controller. + socket: ZMQ socket injected by @with_controller_socket. + + Raises: + RuntimeError: If the RPC fails or an unexpected response is received. 
+ """ + try: + assert socket is not None + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.LOAD_CONTROLLER_CHECKPOINT, # type: ignore[arg-type] + sender_id=self.client_id, + receiver_id=self._controller.id, + body={"path": path}, ) - if not response_msg.body.get("success"): - raise RuntimeError(f"[{self.client_id}]: Controller failed to restore checkpoint from {path}") + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart(copy=False) + response_msg = ZMQMessage.deserialize(response_serialized) + if response_msg.request_type != ZMQRequestType.LOAD_CONTROLLER_CHECKPOINT_RESPONSE: + raise RuntimeError( + f"[{self.client_id}]: Unexpected response type {response_msg.request_type} " + f"from controller during checkpoint restore" + ) + except Exception as e: + raise RuntimeError(f"[{self.client_id}]: Error in load_controller_checkpoint: {str(e)}") from e async def async_save_storage_checkpoint(self, checkpoint_dir: str) -> None: - """Save storage state into checkpoint_dir via StorageManager.""" + """Asynchronously save storage state to a directory via StorageManager. + + Delegates to the underlying StorageManager, which fans out save requests + to all storage units concurrently. + + Args: + checkpoint_dir: Directory under which storage unit files are written. + The StorageManager creates a ``storage_units/`` subdirectory here. + + Raises: + RuntimeError: If the storage manager is not initialized. + NotImplementedError: If the storage backend does not support checkpoint. + """ if not hasattr(self, "storage_manager") or self.storage_manager is None: raise RuntimeError( f"[{self.client_id}]: Storage manager not initialized. 
" @@ -1122,7 +1163,21 @@ async def async_save_storage_checkpoint(self, checkpoint_dir: str) -> None: await self.storage_manager.save_checkpoint(checkpoint_dir) async def async_load_storage_checkpoint(self, checkpoint_dir: str) -> None: - """Restore storage state from checkpoint_dir via StorageManager.""" + """Asynchronously restore storage state from a directory via StorageManager. + + Delegates to the underlying StorageManager, which fans out load requests + to all storage units concurrently. Existing in-memory data is overwritten. + + Args: + checkpoint_dir: Directory previously written by + ``async_save_storage_checkpoint``, containing the + ``storage_units/`` subdirectory and manifest. + + Raises: + RuntimeError: If the storage manager is not initialized. + FileNotFoundError: If the storage unit manifest is missing. + ValueError: If the number of storage units does not match the checkpoint. + """ if not hasattr(self, "storage_manager") or self.storage_manager is None: raise RuntimeError( f"[{self.client_id}]: Storage manager not initialized. " @@ -1636,19 +1691,51 @@ def kv_list( # ==================== Checkpoint API ==================== def save_controller_checkpoint(self, path: str) -> None: - """Synchronously dump controller state to a file via ZMQ RPC.""" + """Synchronously save controller state to a file via ZMQ RPC. + + Args: + path: Absolute path for the output .pkl file. + + Raises: + RuntimeError: If the RPC fails or an unexpected response is received. + """ return self._save_controller_checkpoint(path) def load_controller_checkpoint(self, path: str) -> None: - """Synchronously restore controller state from a file via ZMQ RPC.""" + """Synchronously restore controller state from a file via ZMQ RPC. + + Args: + path: Absolute path to a .pkl file previously written by + ``save_controller_checkpoint``. + + Raises: + RuntimeError: If the RPC fails or an unexpected response is received. 
+ """ return self._load_controller_checkpoint(path) def save_storage_checkpoint(self, checkpoint_dir: str) -> None: - """Synchronously save storage state into checkpoint_dir.""" + """Synchronously save storage state to a directory via StorageManager. + + Args: + checkpoint_dir: Directory under which storage unit files are written. + + Raises: + RuntimeError: If the storage manager is not initialized. + NotImplementedError: If the storage backend does not support checkpoint. + """ return self._save_storage_checkpoint(checkpoint_dir) def load_storage_checkpoint(self, checkpoint_dir: str) -> None: - """Synchronously restore storage state from checkpoint_dir.""" + """Synchronously restore storage state from a directory via StorageManager. + + Args: + checkpoint_dir: Directory previously written by ``save_storage_checkpoint``. + + Raises: + RuntimeError: If the storage manager is not initialized. + FileNotFoundError: If the storage unit manifest is missing. + ValueError: If the number of storage units does not match the checkpoint. 
+ """ return self._load_storage_checkpoint(checkpoint_dir) def close(self) -> None: diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index d20424fb..3fbedd25 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -2034,22 +2034,22 @@ def _process_request(self): elif request_msg.request_type == ZMQRequestType.SAVE_CONTROLLER_CHECKPOINT: path = request_msg.body["path"] - success = self.save_checkpoint(path) + self.save_checkpoint(path) response_msg = ZMQMessage.create( request_type=ZMQRequestType.SAVE_CONTROLLER_CHECKPOINT_RESPONSE, sender_id=self.controller_id, receiver_id=request_msg.sender_id, - body={"success": success}, + body={"success": True}, ) elif request_msg.request_type == ZMQRequestType.LOAD_CONTROLLER_CHECKPOINT: path = request_msg.body["path"] - success = self.load_checkpoint(path) + self.load_checkpoint(path) response_msg = ZMQMessage.create( request_type=ZMQRequestType.LOAD_CONTROLLER_CHECKPOINT_RESPONSE, sender_id=self.controller_id, receiver_id=request_msg.sender_id, - body={"success": success}, + body={"success": True}, ) self.request_handle_socket.send_multipart([identity, *response_msg.serialize()]) @@ -2066,17 +2066,17 @@ def get_config(self) -> DictConfig: """Retrieve the global config of TransferQueue.""" return self.tq_config - def save_checkpoint(self, path: str) -> bool: + def save_checkpoint(self, path: str) -> None: """Serialize controller state directly to a file. Writes in-process to avoid transmitting the payload back over the - Ray object store — only a bool ACK is returned to the caller. + Ray object store. Args: path: Absolute path for the output .pkl file. - Returns: - True on success, False on failure. + Raises: + Exception: If serialization or file I/O fails. 
""" try: state = { @@ -2088,24 +2088,22 @@ def save_checkpoint(self, path: str) -> bool: "global_index_counter": self.index_manager.global_index_counter, "allocated_indexes": set(self.index_manager.allocated_indexes), }, - "sampler": self.sampler.get_state() if hasattr(self.sampler, "get_state") else None, + "sampler": self.sampler.save_checkpoint(), } with open(path, "wb") as f: pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL) logger.info(f"[{self.controller_id}]: dumped to {path}") - return True except Exception as e: - logger.error(f"[{self.controller_id}]: dump_to_file failed: {e}") - return False + raise RuntimeError(f"[{self.controller_id}]: save checkpoint failed: {e}") from e - def load_checkpoint(self, path: str) -> bool: + def load_checkpoint(self, path: str) -> None: """Restore controller state directly from a file. Args: - path: Absolute path to a .pkl file previously written by dump_to_file(). + path: Absolute path to a .pkl file previously written by save_checkpoint(). - Returns: - True on success, False on failure. + Raises: + Exception: If deserialization or file I/O fails. 
""" try: with open(path, "rb") as f: @@ -2120,14 +2118,11 @@ def load_checkpoint(self, path: str) -> bool: self.index_manager.global_index_counter = im["global_index_counter"] self.index_manager.allocated_indexes = im["allocated_indexes"] - if state["sampler"] is not None and hasattr(self.sampler, "restore_state"): - self.sampler.restore_state(state["sampler"]) + self.sampler.load_checkpoint(state["sampler"]) logger.info(f"[{self.controller_id}]: restored from {path}") - return True except Exception as e: - logger.error(f"[{self.controller_id}]: restore_from_file failed: {e}") - return False + raise RuntimeError(f"[{self.controller_id}]: load checkpoint failed: {e}") from e def register_sampler( self, diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index a0e7e4ef..ed48f0d3 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -1069,14 +1069,17 @@ def save_checkpoint( .. note:: **Multi-node limitation**: checkpoint_dir must reside on a shared network - filesystem (e.g. NFS, GPFS, Lustre) accessible from all nodes running - storage units. Single-node deployments have no such requirement. + filesystem (e.g. NFS, GPFS, Lustre) accessible from all nodes. + Single-node deployments have no such requirement. Args: checkpoint_dir: Directory to save the checkpoint (created if not exists). include_storage: Whether to save storage backend state. Set to False to save controller metadata only (e.g. for KV backends where data persists externally and does not need to be re-saved). + For SimpleStorage (in-memory), this flag is ignored and + storage is always saved — skipping it would cause data loss + on restart. metadata: User-defined key-value pairs written into metadata.json. Example: {"time_stamp": 1234567.891234, "step": 10} @@ -1141,8 +1144,8 @@ def load_checkpoint( .. note:: **Multi-node limitation**: checkpoint_dir must reside on a shared network - filesystem (e.g. 
NFS, GPFS, Lustre) accessible from all nodes running - storage units. Single-node deployments have no such requirement. + filesystem (e.g. NFS, GPFS, Lustre) accessible from all nodes. + Single-node deployments have no such requirement. Args: checkpoint_dir: Path to the checkpoint directory. diff --git a/transfer_queue/sampler/base.py b/transfer_queue/sampler/base.py index c831ecba..173b19da 100644 --- a/transfer_queue/sampler/base.py +++ b/transfer_queue/sampler/base.py @@ -124,3 +124,29 @@ def clear_cache(self, partition_id: str): """ if partition_id in self._states.keys(): self._states.pop(partition_id) + + def save_checkpoint(self) -> dict: + """Return the sampler's serializable state for checkpointing. + + By default this returns ``_states``, the shared cache used by all + samplers. Subclasses that maintain additional runtime state (e.g. + a separate balanced cache) should override this method and include + their extra fields. + + Returns: + A dict that can be passed to ``load_checkpoint`` to recreate the + current sampling state. + """ + return {"_states": self._states} + + def load_checkpoint(self, state: dict) -> None: + """Restore the sampler's state from a checkpoint. + + By default this restores ``_states``. Subclasses that override + ``save_checkpoint`` must also override this method to restore their extra + fields. + + Args: + state: A dict previously returned by ``save_checkpoint``.
+ """ + self._states = state["_states"] diff --git a/transfer_queue/sampler/seqlen_balanced_sampler.py b/transfer_queue/sampler/seqlen_balanced_sampler.py index 1dd10740..20c904d7 100644 --- a/transfer_queue/sampler/seqlen_balanced_sampler.py +++ b/transfer_queue/sampler/seqlen_balanced_sampler.py @@ -196,6 +196,17 @@ def clear_cache(self, partition_id: str): for k in keys_to_remove: del self._balanced_cache[k] + def save_checkpoint(self) -> dict: + """Return sampler state including the balanced cache.""" + state = super().save_checkpoint() + state["_balanced_cache"] = self._balanced_cache + return state + + def load_checkpoint(self, state: dict) -> None: + """Restore sampler state including the balanced cache.""" + super().load_checkpoint(state) + self._balanced_cache = state["_balanced_cache"] + # Copied from https://github.com/volcengine/verl/blob/468adf22c43b744348051fccd7a5d830c6c3c36a/verl/utils/seqlen_balancing.py def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool): diff --git a/transfer_queue/storage/managers/simple_storage_manager.py b/transfer_queue/storage/managers/simple_storage_manager.py index 5e2746a0..0b6777fb 100644 --- a/transfer_queue/storage/managers/simple_storage_manager.py +++ b/transfer_queue/storage/managers/simple_storage_manager.py @@ -41,8 +41,8 @@ TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT", 200)) # seconds -_SU_SUBDIR = "storage_units" -_SU_INFO_FILE = "su_info.json" +_SU_SUBDIR = "simple_storage" +_SU_INFO_FILE = "storage_unit_info.json" # Pre-bound decorator for storage-unit socket operations. 
with_storage_unit_socket = with_zmq_socket( @@ -554,8 +554,9 @@ async def _save_single_storage_unit( f"{response_msg.body.get('message', 'unknown error')}" ) except Exception as e: - logger.error(f"[{self.storage_manager_id}]: Error dumping storage unit {target_storage_unit}: {str(e)}") - raise + raise RuntimeError( + f"[{self.storage_manager_id}]: Error dumping storage unit {target_storage_unit}: {str(e)}" + ) from e @with_storage_unit_socket async def _load_single_storage_unit( @@ -583,10 +584,9 @@ async def _load_single_storage_unit( f"{response_msg.body.get('message', 'unknown error')}" ) except Exception as e: - logger.error( + raise RuntimeError( f"[{self.storage_manager_id}]: Error restoring for storage unit {target_storage_unit}: {str(e)}" - ) - raise + ) from e async def save_checkpoint(self, checkpoint_dir: str) -> None: """Dump all storage units to the storage_units/ subdirectory of checkpoint_dir. diff --git a/transfer_queue/storage/simple_storage.py b/transfer_queue/storage/simple_storage.py index 90378f78..e9328d07 100644 --- a/transfer_queue/storage/simple_storage.py +++ b/transfer_queue/storage/simple_storage.py @@ -546,20 +546,17 @@ def _handle_get_metrics(self) -> ZMQMessage: ) def _handle_save_checkpoint(self, data_parts) -> ZMQMessage: - """ - Serialize storage unit data directly to a file. - - Writes data in-process to avoid transmitting the payload back over the - Ray object store — only a bool ACK is returned to the caller. + """Serialize storage unit data directly to a file. Args: - path: data_parts: ZMQMessage from client, including - absolute path for the output .pkl file. - The caller must ensure this path is reachable from the node - running this actor (shared filesystem required for multi-node setups). + data_parts: ZMQMessage from client, containing ``path`` in body: + absolute path for the output .pkl file. 
The caller must ensure + this path is reachable from the node running this actor + (shared filesystem required for multi-node setups). Returns: - Checkpoint dump success response ZMQMessage. + ZMQMessage with ``success=True`` on success, or ``success=False`` + and ``message`` containing the error string on failure. """ path = data_parts.body["path"] try: @@ -586,17 +583,18 @@ def _handle_save_checkpoint(self, data_parts) -> ZMQMessage: ) def _handle_load_checkpoint(self, data_parts) -> ZMQMessage: - """ - Restore storage unit data directly from a file. + """Restore storage unit data directly from a file. Args: - path: data_parts: ZMQMessage from client, including - absolute path for the output .pkl file. - The caller must ensure this path is reachable from the node - running this actor (shared filesystem required for multi-node setups). + data_parts: ZMQMessage from client, containing ``path`` in body: + absolute path to a .pkl file previously written by + ``_handle_save_checkpoint``. The caller must ensure this path + is reachable from the node running this actor + (shared filesystem required for multi-node setups). Returns: - Checkpoint restore success response ZMQMessage. + ZMQMessage with ``success=True`` on success, or ``success=False`` + and ``message`` containing the error string on failure. 
""" path = data_parts.body["path"] try: @@ -609,6 +607,11 @@ def _handle_load_checkpoint(self, data_parts) -> ZMQMessage: f"checkpoint={data['storage_unit_size']}, current={self.storage_unit_size}" ) + if self.storage_data._active_keys: + logger.warning( + f"[{self.storage_unit_id}]: overwriting {len(self.storage_data._active_keys)} " + f"existing keys with checkpoint data from {path}" + ) self.storage_data.field_data.clear() self.storage_data._active_keys.clear() self.storage_data.field_data = data["field_data"] From 741aa5ae84365b0178b96f65a858ab396512c9fe Mon Sep 17 00:00:00 2001 From: yxstev Date: Thu, 18 Jun 2026 11:02:48 +0800 Subject: [PATCH 4/6] fix test Signed-off-by: yxstev --- tests/e2e/test_checkpoint_e2e.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/e2e/test_checkpoint_e2e.py b/tests/e2e/test_checkpoint_e2e.py index c544a0b4..f2a8da92 100644 --- a/tests/e2e/test_checkpoint_e2e.py +++ b/tests/e2e/test_checkpoint_e2e.py @@ -129,9 +129,9 @@ def test_tensor_fields(self, tq_system, checkpoint_dir, controller): # Check saved state: expected files exist assert (checkpoint_dir / "metadata.json").exists() assert (checkpoint_dir / "controller_state.pkl").exists() - su_dir = checkpoint_dir / "storage_units" + su_dir = checkpoint_dir / "simple_storage" assert su_dir.exists() - assert (su_dir / "su_info.json").exists() + assert (su_dir / "storage_unit_info.json").exists() with open(checkpoint_dir / "metadata.json") as f: meta = json.load(f) assert meta["storage_saved"] is True @@ -310,7 +310,7 @@ def test_storage_saved_is_true(self, tq_system, checkpoint_dir): with open(checkpoint_dir / "metadata.json") as f: meta = json.load(f) assert meta["storage_saved"] is True - assert (checkpoint_dir / "storage_units").exists() + assert (checkpoint_dir / "simple_storage").exists() def test_both_restored_after_load(self, tq_system, checkpoint_dir, controller): # Define test data @@ -393,7 +393,7 @@ def 
test_load_raises_on_storage_unit_count_mismatch(self, tq_system, tmp_path, c tq.save_checkpoint(checkpoint_dir) # Tamper: add a fake extra SU entry so count differs - su_info_path = checkpoint_dir / "storage_units" / "su_info.json" + su_info_path = checkpoint_dir / "simple_storage" / "storage_unit_info.json" with open(su_info_path) as f: su_info = json.load(f) su_info.append({"position": 99, "storage_unit_id": "fake"}) From 6968ca21cc295885b1fc14977875aaf56be4121b Mon Sep 17 00:00:00 2001 From: yxstev Date: Thu, 18 Jun 2026 11:45:46 +0800 Subject: [PATCH 5/6] add ut Signed-off-by: yxstev --- tests/test_controller.py | 57 +++++++++++++++++++ tests/test_samplers.py | 58 +++++++++++++++++++ tests/test_simple_storage_unit.py | 95 +++++++++++++++++++++++++++++++ 3 files changed, 210 insertions(+) diff --git a/tests/test_controller.py b/tests/test_controller.py index d45c54eb..c717894f 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -1039,3 +1039,60 @@ def test_controller_kv_retrieve_keys_multiple_partitions(self, ray_setup): # Clean up ray.get(tq_controller.clear_partition.remote(partition_1)) ray.get(tq_controller.clear_partition.remote(partition_2)) + + +class TestTransferQueueControllerCheckpoint: + def test_save_creates_file(self, ray_setup, tmp_path): + tq_controller = TransferQueueController.remote() + ray.get(tq_controller.kv_retrieve_meta.remote(keys=["k0", "k1"], partition_id="p0", create=True)) + + path = str(tmp_path / "ckpt.pkl") + ray.get(tq_controller.save_checkpoint.remote(path)) + + assert (tmp_path / "ckpt.pkl").exists() + assert (tmp_path / "ckpt.pkl").stat().st_size > 0 + + def test_save_raises_on_invalid_path(self, ray_setup): + tq_controller = TransferQueueController.remote() + + with pytest.raises(ray.exceptions.RayTaskError, match="save checkpoint failed"): + ray.get(tq_controller.save_checkpoint.remote("/nonexistent_dir/ckpt.pkl")) + + def test_load_restores_partition(self, ray_setup, tmp_path): + tq_controller = 
TransferQueueController.remote() + ray.get(tq_controller.kv_retrieve_meta.remote(keys=["k0", "k1"], partition_id="p0", create=True)) + + path = str(tmp_path / "ckpt.pkl") + ray.get(tq_controller.save_checkpoint.remote(path)) + ray.get(tq_controller.clear_partition.remote("p0")) + assert ray.get(tq_controller.list_partitions.remote()) == [] + + ray.get(tq_controller.load_checkpoint.remote(path)) + assert "p0" in ray.get(tq_controller.list_partitions.remote()) + + def test_load_restores_key_mapping(self, ray_setup, tmp_path): + tq_controller = TransferQueueController.remote() + ray.get(tq_controller.kv_retrieve_meta.remote(keys=["k0", "k1"], partition_id="p0", create=True)) + + path = str(tmp_path / "ckpt.pkl") + ray.get(tq_controller.save_checkpoint.remote(path)) + ray.get(tq_controller.clear_partition.remote("p0")) + + ray.get(tq_controller.load_checkpoint.remote(path)) + snapshot = ray.get(tq_controller.get_partition_snapshot.remote("p0")) + assert "k0" in snapshot.keys_mapping + assert "k1" in snapshot.keys_mapping + + def test_load_raises_on_missing_file(self, ray_setup): + tq_controller = TransferQueueController.remote() + + with pytest.raises(ray.exceptions.RayTaskError, match="load checkpoint failed"): + ray.get(tq_controller.load_checkpoint.remote("/nonexistent/ckpt.pkl")) + + def test_load_raises_on_corrupt_file(self, ray_setup, tmp_path): + tq_controller = TransferQueueController.remote() + path = tmp_path / "bad.pkl" + path.write_bytes(b"not a pickle") + + with pytest.raises(ray.exceptions.RayTaskError, match="load checkpoint failed"): + ray.get(tq_controller.load_checkpoint.remote(str(path))) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index c8e8f843..9575c509 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -61,6 +61,33 @@ def sample(self, ready_indexes: list[int], batch_size: int, **kwargs: Any) -> tu assert hasattr(sampler, "_states") assert sampler._states == {} + def test_save_load_checkpoint_roundtrip(self): + 
"""save_checkpoint / load_checkpoint restore _states.""" + + class TestSampler(BaseSampler): + def sample(self, ready_indexes: list[int], batch_size: int, **kwargs: Any) -> tuple[list[int], list[int]]: + return ready_indexes[:batch_size], ready_indexes[:batch_size] + + sampler = TestSampler() + sampler._states = {"p0": {"task": {0: {0: ([1, 2], [1, 2])}}}} + + state = sampler.save_checkpoint() + sampler._states = {} + sampler.load_checkpoint(state) + + assert sampler._states == {"p0": {"task": {0: {0: ([1, 2], [1, 2])}}}} + + def test_save_checkpoint_returns_states_key(self): + """save_checkpoint dict must contain '_states'.""" + + class TestSampler(BaseSampler): + def sample(self, ready_indexes: list[int], batch_size: int, **kwargs: Any) -> tuple[list[int], list[int]]: + return [], [] + + sampler = TestSampler() + state = sampler.save_checkpoint() + assert "_states" in state + class TestSequentialSampler: """Test cases for SequentialSampler.""" @@ -1061,6 +1088,37 @@ def test_batch_size_not_divisible_by_n_samples_per_prompt(self): assert "must be a multiple of n_samples_per_prompt" in str(exc_info.value) + def test_save_load_checkpoint_roundtrip(self): + """save_checkpoint / load_checkpoint restore both _states and _balanced_cache.""" + sampler = SeqlenBalancedSampler(n_samples_per_prompt=1, dp_size=2) + ready_indexes = [0, 1, 2, 3] + + sampler.sample(ready_indexes, 2, task_name="task", partition_id="p0", dp_rank=0, batch_index=0) + + assert len(sampler._balanced_cache) > 0 + assert "p0" in sampler._states + + state = sampler.save_checkpoint() + + new_sampler = SeqlenBalancedSampler(n_samples_per_prompt=1, dp_size=2) + assert new_sampler._balanced_cache == {} + assert new_sampler._states == {} + + new_sampler.load_checkpoint(state) + + assert new_sampler._states == sampler._states + assert new_sampler._balanced_cache == sampler._balanced_cache + + def test_save_checkpoint_includes_balanced_cache(self): + """save_checkpoint state dict must contain both '_states' 
and '_balanced_cache'.""" + sampler = SeqlenBalancedSampler(n_samples_per_prompt=1, dp_size=1) + sampler.sample([0, 1], 2, task_name="task", partition_id="p0", dp_rank=0, batch_index=0) + + state = sampler.save_checkpoint() + + assert "_states" in state + assert "_balanced_cache" in state + class TestKarmarkarKarp: """Test cases for karmarkar_karp and get_seqlen_balanced_partitions utilities.""" diff --git a/tests/test_simple_storage_unit.py b/tests/test_simple_storage_unit.py index 319a46e7..f81bb56b 100644 --- a/tests/test_simple_storage_unit.py +++ b/tests/test_simple_storage_unit.py @@ -64,6 +64,24 @@ def send_clear(self, client_id, global_indexes): self.socket.send_multipart(msg.serialize()) return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False)) + def send_save_checkpoint(self, path): + msg = ZMQMessage.create( + request_type=ZMQRequestType.SAVE_STORAGE_CHECKPOINT, + sender_id="mock_client_ckpt", + body={"path": path}, + ) + self.socket.send_multipart(msg.serialize()) + return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False)) + + def send_load_checkpoint(self, path): + msg = ZMQMessage.create( + request_type=ZMQRequestType.LOAD_STORAGE_CHECKPOINT, + sender_id="mock_client_ckpt", + body={"path": path}, + ) + self.socket.send_multipart(msg.serialize()) + return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False)) + def close(self): self.socket.close() self.context.term() @@ -611,3 +629,80 @@ def wrong_len_parser(field_data): assert "data_parser changed the number of elements" in response.body["message"] client.close() + + +def test_save_load_checkpoint_roundtrip(storage_setup, tmp_path): + """Save and load restore data correctly.""" + _, put_get_address = storage_setup + client = MockStorageClient(put_get_address) + + global_indexes = [0, 1] + field_data = { + "input_ids": [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])], + "labels": [torch.tensor([0]), torch.tensor([1])], + } + client.send_put(0, global_indexes, 
field_data) + + path = str(tmp_path / "su.pkl") + resp = client.send_save_checkpoint(path) + assert resp.request_type == ZMQRequestType.SAVE_STORAGE_CHECKPOINT_RESPONSE + assert resp.body["success"] is True + + resp = client.send_load_checkpoint(path) + assert resp.request_type == ZMQRequestType.LOAD_STORAGE_CHECKPOINT_RESPONSE + assert resp.body["success"] is True + + resp = client.send_get(0, global_indexes, ["input_ids"]) + assert resp.request_type == ZMQRequestType.GET_DATA_RESPONSE + torch.testing.assert_close(resp.body["data"]["input_ids"][0], torch.tensor([1, 2, 3])) + torch.testing.assert_close(resp.body["data"]["input_ids"][1], torch.tensor([4, 5, 6])) + + client.close() + + +def test_save_checkpoint_invalid_path(storage_setup): + """Save to an inaccessible path returns failure response.""" + _, put_get_address = storage_setup + client = MockStorageClient(put_get_address) + + resp = client.send_save_checkpoint("/nonexistent_dir/su.pkl") + assert resp.request_type == ZMQRequestType.SAVE_STORAGE_CHECKPOINT_RESPONSE + assert resp.body["success"] is False + assert "message" in resp.body + + client.close() + + +def test_load_checkpoint_missing_file(storage_setup): + """Load from a non-existent file returns failure response.""" + _, put_get_address = storage_setup + client = MockStorageClient(put_get_address) + + resp = client.send_load_checkpoint("/nonexistent/su.pkl") + assert resp.request_type == ZMQRequestType.LOAD_STORAGE_CHECKPOINT_RESPONSE + assert resp.body["success"] is False + assert "message" in resp.body + + client.close() + + +def test_load_checkpoint_clears_existing_data(storage_setup, tmp_path): + """Data added after save should not survive a load.""" + _, put_get_address = storage_setup + client = MockStorageClient(put_get_address) + + client.send_put(0, [0], {"val": [torch.tensor([1])]}) + path = str(tmp_path / "su.pkl") + client.send_save_checkpoint(path) + + # Add index 1 after save + client.send_put(0, [1], {"val": [torch.tensor([99])]}) + + 
# Restore to checkpoint that only had index 0 + client.send_load_checkpoint(path) + + # Index 1 should be gone + resp = client.send_get(0, [1], ["val"]) + assert resp.request_type != ZMQRequestType.GET_DATA_RESPONSE + + client.close() From dfe8032b7e3cb13b0db270353f2dc7193a7137c9 Mon Sep 17 00:00:00 2001 From: yxstev Date: Fri, 26 Jun 2026 10:21:01 +0800 Subject: [PATCH 6/6] [fix] Validate storage checkpoint before restoring controller state Previously, load_checkpoint restored controller state first, then validated storage unit count. A mismatch would raise ValueError after the controller had already been overwritten, leaving the system in an inconsistent state. Fix: introduce validate_checkpoint on the storage manager that checks file existence and SU count without touching any state. interface.py now runs this validation before any restoration begins. Also adds an assertion to test_load_raises_on_storage_unit_count_mismatch to verify controller state is unchanged after a failed load. 
Signed-off-by: yxstev --- tests/e2e/test_checkpoint_e2e.py | 7 ++++++- transfer_queue/client.py | 14 ++++++++++++++ transfer_queue/interface.py | 3 +++ transfer_queue/storage/managers/base.py | 11 +++++++++++ .../storage/managers/simple_storage_manager.py | 15 +++++++++++++-- 5 files changed, 47 insertions(+), 3 deletions(-) diff --git a/tests/e2e/test_checkpoint_e2e.py b/tests/e2e/test_checkpoint_e2e.py index f2a8da92..c1850daa 100644 --- a/tests/e2e/test_checkpoint_e2e.py +++ b/tests/e2e/test_checkpoint_e2e.py @@ -382,7 +382,7 @@ def test_load_raises_if_metadata_missing(self, tq_system, tmp_path): with pytest.raises(FileNotFoundError, match="metadata.json"): tq.load_checkpoint(ck) - def test_load_raises_on_storage_unit_count_mismatch(self, tq_system, tmp_path, checkpoint_dir): + def test_load_raises_on_storage_unit_count_mismatch(self, tq_system, tmp_path, checkpoint_dir, controller): # Define test data + Put + Save tq.kv_batch_put( keys=["e0"], @@ -400,9 +400,14 @@ def test_load_raises_on_storage_unit_count_mismatch(self, tq_system, tmp_path, c with open(su_info_path, "w") as f: json.dump(su_info, f) + partitions_before = ray.get(controller.list_partitions.remote()) + with pytest.raises(ValueError, match="count mismatch"): tq.load_checkpoint(checkpoint_dir) + # Controller state must not have been modified + assert ray.get(controller.list_partitions.remote()) == partitions_before + def test_no_partial_state_on_failed_save(self, tq_system, tmp_path): # Define test data + Put tq.kv_batch_put( diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 34f35757..35845d74 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -1162,6 +1162,15 @@ async def async_save_storage_checkpoint(self, checkpoint_dir: str) -> None: ) await self.storage_manager.save_checkpoint(checkpoint_dir) + async def async_validate_storage_checkpoint(self, checkpoint_dir: str) -> None: + """Validate storage checkpoint compatibility without modifying any state.""" + 
if not hasattr(self, "storage_manager") or self.storage_manager is None: + raise RuntimeError( + f"[{self.client_id}]: Storage manager not initialized. " + "Call initialize_storage_manager() before checkpoint operations." + ) + await self.storage_manager.validate_checkpoint(checkpoint_dir) + async def async_load_storage_checkpoint(self, checkpoint_dir: str) -> None: """Asynchronously restore storage state from a directory via StorageManager. @@ -1257,6 +1266,7 @@ def wrapper(*args, **kwargs): self._load_controller_checkpoint = _make_sync(self.async_load_controller_checkpoint) self._save_storage_checkpoint = _make_sync(self.async_save_storage_checkpoint) self._load_storage_checkpoint = _make_sync(self.async_load_storage_checkpoint) + self._validate_storage_checkpoint = _make_sync(self.async_validate_storage_checkpoint) # ==================== Basic API ==================== def get_meta( @@ -1738,6 +1748,10 @@ def load_storage_checkpoint(self, checkpoint_dir: str) -> None: """ return self._load_storage_checkpoint(checkpoint_dir) + def validate_storage_checkpoint(self, checkpoint_dir: str) -> None: + """Synchronously validate storage checkpoint compatibility without modifying any state.""" + return self._validate_storage_checkpoint(checkpoint_dir) + def close(self) -> None: """Close the client and cleanup resources including event loop and thread.""" diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index ed48f0d3..fdb8e310 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -1174,6 +1174,9 @@ def load_checkpoint( if not controller_path.exists(): raise FileNotFoundError(f"{_CONTROLLER_FILE} not found in {checkpoint_dir}") + if meta.get("storage_saved"): + client.validate_storage_checkpoint(str(checkpoint_dir)) + client.load_controller_checkpoint(str(controller_path)) if meta.get("storage_saved"): diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index e6b0faf4..2b377990 100644 
--- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -363,6 +363,17 @@ async def save_checkpoint(self, checkpoint_dir: str) -> None: """ raise NotImplementedError(f"{self.__class__.__name__} does not support checkpoint") + async def validate_checkpoint(self, checkpoint_dir: str) -> None: + """Validate that checkpoint_dir is compatible with the current system configuration. + + Must be called before load_checkpoint to catch mismatches (e.g. storage unit + count) before any state is modified. + + Raises: + NotImplementedError: If this storage backend does not support checkpoint. + """ + raise NotImplementedError(f"{self.__class__.__name__} does not support checkpoint") + async def load_checkpoint(self, checkpoint_dir: str) -> None: """Restore storage state from checkpoint_dir. diff --git a/transfer_queue/storage/managers/simple_storage_manager.py b/transfer_queue/storage/managers/simple_storage_manager.py index 0b6777fb..9c508a37 100644 --- a/transfer_queue/storage/managers/simple_storage_manager.py +++ b/transfer_queue/storage/managers/simple_storage_manager.py @@ -613,8 +613,8 @@ async def save_checkpoint(self, checkpoint_dir: str) -> None: logger.info(f"[{self.storage_manager_id}]: saved {len(su_ids)} storage units to {su_dir}") - async def load_checkpoint(self, checkpoint_dir: str) -> None: - """Restore all storage units from the storage_units/ subdirectory of checkpoint_dir.""" + async def validate_checkpoint(self, checkpoint_dir: str) -> None: + """Validate that the checkpoint is compatible before any state is modified.""" import json su_dir = Path(checkpoint_dir) / _SU_SUBDIR @@ -631,7 +631,18 @@ async def load_checkpoint(self, checkpoint_dir: str) -> None: f"Storage unit count mismatch: checkpoint has {len(su_info_list)}, current system has {len(su_ids)}." 
) + async def load_checkpoint(self, checkpoint_dir: str) -> None: + """Restore all storage units from the simple_storage/ subdirectory of checkpoint_dir.""" + import json + + su_dir = Path(checkpoint_dir) / _SU_SUBDIR + su_info_path = su_dir / _SU_INFO_FILE + + with open(su_info_path) as f: + su_info_list = json.load(f) + entries = sorted(su_info_list, key=lambda e: e["position"]) + su_ids = list(self.storage_unit_infos.keys()) tasks = [ self._load_single_storage_unit( str(su_dir / f"su_{entry['position']}_{entry['storage_unit_id']}.pkl"),