From 349d262a832c8bf0fe64953f9717a739cc0e9437 Mon Sep 17 00:00:00 2001 From: MagellaX Date: Thu, 8 Jan 2026 17:41:50 +0530 Subject: [PATCH 1/7] Handle mid-epoch dataset changes on resume --- README.md | 4 ++ src/litdata/streaming/dataloader.py | 73 ++++++++++++++++++- tests/streaming/test_dataloader.py | 105 ++++++++++++++++++++++++++++ 3 files changed, 181 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 2be60537b..9732a884a 100644 --- a/README.md +++ b/README.md @@ -495,6 +495,10 @@ if os.path.isfile("dataloader_state.pt"): state_dict = torch.load("dataloader_state.pt") dataloader.load_state_dict(state_dict) +# If you resume from a checkpoint with a different dataset, use +# StreamingDataLoader(..., dataset_change_policy="next_epoch") to skip the +# remainder of the old epoch and start fresh on the new data. + # Iterate over the data for batch_idx, batch in enumerate(dataloader): diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index d83808e73..21beb2375 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -18,7 +18,7 @@ from copy import deepcopy from importlib import reload from itertools import cycle -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import torch from torch.utils.data import Dataset, IterableDataset @@ -51,6 +51,21 @@ logger = logging.getLogger("litdata.streaming.dataloader") +DatasetChangePolicy = Literal["error", "next_epoch"] + + +def _streaming_dataset_signature(state_dict: Dict[str, Any]) -> Tuple[Any, ...]: + return ( + state_dict.get("input_dir_path"), + state_dict.get("input_dir_url"), + state_dict.get("item_loader"), + state_dict.get("seed"), + state_dict.get("shuffle"), + state_dict.get("drop_last"), + state_dict.get("subsampled_files"), + state_dict.get("region_of_interest"), + ) + def _equal_items(data_1: Any, data_2: Any) -> bool: data_1_flattened, _ = tree_flatten(data_1) @@ -568,6 +583,9 @@ class StreamingDataLoader(DataLoader): profile_skip_batches (int): How many batches to skip before recording profile_batches (int, bool, optional): Whether to record data loading profile and generate a result.json file. profile_dir (int, bool, optional): Where to store the recorded trace when profile_batches is enabled. + dataset_change_policy (str, optional): Behavior when resuming from a checkpoint with a different dataset + configuration. Use ``"next_epoch"`` to skip the remainder of the current epoch and start a new one, + or ``"error"`` (default) to keep the existing behavior. """ @@ -585,6 +603,7 @@ def __init__( prefetch_factor: Optional[int] = None, shuffle: Optional[bool] = None, drop_last: Optional[bool] = None, + dataset_change_policy: DatasetChangePolicy = "error", collate_fn: Optional[Callable] = None, **kwargs: Any, ) -> None: # pyright: ignore @@ -600,6 +619,9 @@ def __init__( if drop_last is not None: dataset.set_drop_last(drop_last) + if dataset_change_policy not in ("error", "next_epoch"): + raise ValueError(f"Invalid dataset_change_policy: {dataset_change_policy}") + dataset.set_batch_size(batch_size) dataset.set_num_workers(num_workers) @@ -624,6 +646,7 @@ def __init__( self._num_samples_yielded_wrapper: Dict[int, List[int]] = {} self._num_cycles: Dict[int, List[int]] = {} self.rng_state: Optional[Any] = None + self._dataset_change_policy = dataset_change_policy self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1))) self._worker_idx_iter: Optional[Any] = None self._latest_worker_idx = 0 @@ -704,6 +727,40 @@ def __len__(self) -> int: return length return len(self._index_sampler) + def _has_dataset_changed(self, state_dict: Dict[str, Any]) -> bool: + if not state_dict: + return False + + if isinstance(self.dataset, StreamingDataset): + current_signature = _streaming_dataset_signature( + self.dataset.state_dict(num_samples_yielded=0, num_workers=self.num_workers, batch_size=self.batch_size) + ) + saved_signature = _streaming_dataset_signature(state_dict) + return current_signature != saved_signature + + if isinstance(self.dataset, (CombinedStreamingDataset, ParallelStreamingDataset)): + if not all(isinstance(d, StreamingDataset) for d in self.dataset._datasets): + return False + + saved_signatures: List[Tuple[Any, ...]] = [] + for dataset_idx in range(len(self.dataset._datasets)): + key = str(dataset_idx) + if key not in state_dict: + return True + saved_signatures.append(_streaming_dataset_signature(state_dict[key])) + + current_signatures = [ + _streaming_dataset_signature( + dataset.state_dict( + num_samples_yielded=0, num_workers=self.num_workers, batch_size=self.batch_size + ) + ) + for dataset in self.dataset._datasets + ] + return current_signatures != saved_signatures + + return False + def state_dict(self) -> Dict[str, Any]: if isinstance(self.dataset, StreamingDataset): assert self.batch_size @@ -757,6 +814,20 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None: """ self.current_epoch = obj["current_epoch"] + if self._dataset_change_policy == "next_epoch" and self._has_dataset_changed(obj["dataset"]): + logger.info( + "Detected a dataset change while resuming. Skipping the remainder of the current epoch and " + "continuing from the next one." + ) + self.restore = False + self.dataset.reset_state_dict() + self._latest_worker_idx = 0 + self._worker_idx_iter = iter(self._worker_idx) + self._num_samples_yielded_streaming = 0 + self._num_samples_yielded_wrapper = {} + self._num_cycles = {} + return + if isinstance(self.dataset, StreamingDataset): self._num_samples_yielded_streaming = obj["num_samples_yielded"] else: diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index 68e8cbdc2..453a205ae 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -342,6 +342,111 @@ def test_resume_dataloader_with_new_dataset(tmpdir): assert dataloader.current_epoch == 2, "Current epoch should be 2" +@pytest.mark.timeout(120) +def test_resume_dataloader_mid_epoch_with_new_dataset(tmpdir): + dataset_1_path = tmpdir.join("dataset_1") + dataset_2_path = tmpdir.join("dataset_2") + for dataset, start in [(dataset_1_path, 0), (dataset_2_path, 100)]: + cache = Cache(input_dir=str(dataset), chunk_bytes="64MB") + for i in range(50): + cache[i] = i + start + cache.done() + cache.merge() + + dataset = StreamingDataset(str(dataset_1_path), shuffle=False) + dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2) + for batch_idx, _ in enumerate(dataloader): + if batch_idx == 2: + break + + dataloader_state = dataloader.state_dict() + dataset = StreamingDataset(str(dataset_2_path), shuffle=False) + dataloader = StreamingDataLoader( + dataset, batch_size=4, num_workers=2, dataset_change_policy="next_epoch" + ) + dataloader.load_state_dict(dataloader_state) + assert not dataloader.restore + + first_batch = next(iter(dataloader)) + assert dataloader.current_epoch == 2, "Current epoch should be 2" + assert (first_batch >= 100).all().item() + + +@pytest.mark.timeout(300) +def test_resume_mid_epoch_with_new_dataset_next_epoch_e2e(tmp_path): + from pytorch_lightning import LightningModule, Trainer + from pytorch_lightning.callbacks import ModelCheckpoint + + def _write_dataset(path, value): + cache = Cache(input_dir=str(path), chunk_size=4) + for i in range(8): + cache[i] = value + cache.done() + cache.merge() + + data_dir_1 = tmp_path / "data_1" + data_dir_2 = tmp_path / "data_2" + _write_dataset(data_dir_1, 0) + _write_dataset(data_dir_2, 1) + + def _make_dataset(path): + dataset = StreamingDataset(str(path), shuffle=False) + + def transform(x, dataset=dataset): + return (x, dataset.current_epoch) + + dataset.transform = transform + return dataset + + def _make_dataloader(path, policy="error"): + return StreamingDataLoader( + _make_dataset(path), batch_size=2, num_workers=0, dataset_change_policy=policy + ) + + class _ValueCheckModel(LightningModule): + def __init__(self, expected_value, expected_epoch=None): + super().__init__() + self.expected_value = expected_value + self.expected_epoch = expected_epoch + self.layer = torch.nn.Linear(1, 1) + + def training_step(self, batch, batch_idx): + values, epochs = batch + assert (values == self.expected_value).all().item() + if self.expected_epoch is not None: + assert (epochs == self.expected_epoch).all().item() + loss = self.layer(values.float().unsqueeze(-1)).mean() + return loss + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.1) + + ckpt_callback = ModelCheckpoint(dirpath=str(tmp_path), save_last=True) + trainer = Trainer( + max_steps=2, + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + callbacks=[ckpt_callback], + ) + trainer.fit(_ValueCheckModel(expected_value=0), train_dataloaders=_make_dataloader(data_dir_1)) + ckpt_path = ckpt_callback.last_model_path + assert ckpt_path + + trainer = Trainer( + max_steps=1, + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + enable_checkpointing=False, + ) + trainer.fit( + _ValueCheckModel(expected_value=1, expected_epoch=2), + train_dataloaders=_make_dataloader(data_dir_2, policy="next_epoch"), + ckpt_path=ckpt_path, + ) + + def test_resume_dataloader_after_some_workers_are_done(tmpdir): # see https://github.com/Lightning-AI/litData/issues/563 dset_path = tmpdir.join("dataset") From c0507783a7f641b886ada1c647b42fabe45c531a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jan 2026 12:46:42 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/streaming/test_dataloader.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index 3f4d3f6c3..240f57724 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -362,9 +362,7 @@ def test_resume_dataloader_mid_epoch_with_new_dataset(tmpdir): dataloader_state = dataloader.state_dict() dataset = StreamingDataset(str(dataset_2_path), shuffle=False) - dataloader = StreamingDataLoader( - dataset, batch_size=4, num_workers=2, dataset_change_policy="next_epoch" - ) + dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2, dataset_change_policy="next_epoch") dataloader.load_state_dict(dataloader_state) assert not dataloader.restore @@ -400,9 +398,7 @@ def transform(x, dataset=dataset): return dataset def _make_dataloader(path, policy="error"): - return StreamingDataLoader( - _make_dataset(path), batch_size=2, num_workers=0, dataset_change_policy=policy - ) + return StreamingDataLoader(_make_dataset(path), batch_size=2, num_workers=0, dataset_change_policy=policy) class _ValueCheckModel(LightningModule): def __init__(self, expected_value, expected_epoch=None): From 967eb85148abe9cec03a37d88c768be2cdadc613 Mon Sep 17 00:00:00 2001 From: MagellaX Date: Thu, 8 Jan 2026 18:26:59 +0530 Subject: [PATCH 3/7] Fix resume tests and dataloader state init --- src/litdata/streaming/dataloader.py | 24 ++++-------------------- tests/streaming/test_dataloader.py | 4 ++-- 2 files changed, 6 insertions(+), 22 deletions(-) diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index 7fd4b9259..f41b279b9 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -604,20 +604,12 @@ def __init__( num_workers: int = 0, profile_batches: bool | int = False, profile_skip_batches: int = 0, - - profile_dir: Optional[str] = None, - prefetch_factor: Optional[int] = None, - shuffle: Optional[bool] = None, - drop_last: Optional[bool] = None, - dataset_change_policy: DatasetChangePolicy = "error", - collate_fn: Optional[Callable] = None, - profile_dir: str | None = None, prefetch_factor: int | None = None, shuffle: bool | None = None, drop_last: bool | None = None, + dataset_change_policy: DatasetChangePolicy = "error", collate_fn: Callable | None = None, - **kwargs: Any, ) -> None: # pyright: ignore if not isinstance(dataset, (StreamingDataset, _BaseStreamingDatasetWrapper)): @@ -656,17 +648,10 @@ def __init__( self._profile_skip_batches = profile_skip_batches self._profile_dir = profile_dir self._num_samples_yielded_streaming = 0 - - self._num_samples_yielded_wrapper: Dict[int, List[int]] = {} - self._num_cycles: Dict[int, List[int]] = {} - self.rng_state: Optional[Any] = None - self._dataset_change_policy = dataset_change_policy - self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1))) - self._worker_idx_iter: Optional[Any] = None - self._num_samples_yielded_wrapper: dict[int, list[int]] = {} self._num_cycles: dict[int, list[int]] = {} self.rng_state: Any | None = None + self._dataset_change_policy = dataset_change_policy self._worker_idx: Any | None = None # Lazily initialized in __iter__ self._worker_idx_iter: Any | None = None @@ -791,9 +776,6 @@ def _has_dataset_changed(self, state_dict: Dict[str, Any]) -> bool: return False def state_dict(self) -> Dict[str, Any]: - - def state_dict(self) -> dict[str, Any]: - if isinstance(self.dataset, StreamingDataset): assert self.batch_size return { @@ -854,6 +836,8 @@ def load_state_dict(self, obj: dict[str, Any]) -> None: self.restore = False self.dataset.reset_state_dict() self._latest_worker_idx = 0 + if self._worker_idx is None: + self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1))) self._worker_idx_iter = iter(self._worker_idx) self._num_samples_yielded_streaming = 0 self._num_samples_yielded_wrapper = {} diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index 240f57724..7281747b3 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -373,8 +373,8 @@ def test_resume_dataloader_mid_epoch_with_new_dataset(tmpdir): @pytest.mark.timeout(300) def test_resume_mid_epoch_with_new_dataset_next_epoch_e2e(tmp_path): - from pytorch_lightning import LightningModule, Trainer - from pytorch_lightning.callbacks import ModelCheckpoint + from lightning.pytorch import LightningModule, Trainer + from lightning.pytorch.callbacks import ModelCheckpoint def _write_dataset(path, value): cache = Cache(input_dir=str(path), chunk_size=4) From c31dfb0f936a6f99d5447720d8243bbf5248be0e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jan 2026 12:57:43 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/dataloader.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index f41b279b9..7d50bb9a3 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -19,11 +19,7 @@ from copy import deepcopy from importlib import reload from itertools import cycle - -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union - -from typing import Any - +from typing import Any, Literal import torch from torch.utils.data import Dataset, IterableDataset @@ -59,7 +55,7 @@ DatasetChangePolicy = Literal["error", "next_epoch"] -def _streaming_dataset_signature(state_dict: Dict[str, Any]) -> Tuple[Any, ...]: +def _streaming_dataset_signature(state_dict: dict[str, Any]) -> tuple[Any, ...]: return ( state_dict.get("input_dir_path"), state_dict.get("input_dir_url"), @@ -740,8 +736,7 @@ def __len__(self) -> int: return length return len(self._index_sampler) - - def _has_dataset_changed(self, state_dict: Dict[str, Any]) -> bool: + def _has_dataset_changed(self, state_dict: dict[str, Any]) -> bool: if not state_dict: return False @@ -756,7 +751,7 @@ def _has_dataset_changed(self, state_dict: Dict[str, Any]) -> bool: if not all(isinstance(d, StreamingDataset) for d in self.dataset._datasets): return False - saved_signatures: List[Tuple[Any, ...]] = [] + saved_signatures: list[tuple[Any, ...]] = [] for dataset_idx in range(len(self.dataset._datasets)): key = str(dataset_idx) if key not in state_dict: @@ -765,9 +760,7 @@ def _has_dataset_changed(self, state_dict: Dict[str, Any]) -> bool: current_signatures = [ _streaming_dataset_signature( - dataset.state_dict( - num_samples_yielded=0, num_workers=self.num_workers, batch_size=self.batch_size - ) + dataset.state_dict(num_samples_yielded=0, num_workers=self.num_workers, batch_size=self.batch_size) ) for dataset in self.dataset._datasets ] @@ -775,7 +768,7 @@ def _has_dataset_changed(self, state_dict: Dict[str, Any]) -> bool: return False - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: if isinstance(self.dataset, StreamingDataset): assert self.batch_size return { From 03a6e9a0073d2cbda855c2a2305b8456946ad708 Mon Sep 17 00:00:00 2001 From: MagellaX Date: Fri, 9 Jan 2026 14:20:53 +0530 Subject: [PATCH 5/7] Reduce workers for macOS resume test --- tests/streaming/test_dataloader.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index 7281747b3..65a551c91 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -355,14 +355,17 @@ def test_resume_dataloader_mid_epoch_with_new_dataset(tmpdir): cache.merge() dataset = StreamingDataset(str(dataset_1_path), shuffle=False) - dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2) + num_workers = 0 if sys.platform == "darwin" else 2 + dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=num_workers) for batch_idx, _ in enumerate(dataloader): if batch_idx == 2: break dataloader_state = dataloader.state_dict() dataset = StreamingDataset(str(dataset_2_path), shuffle=False) - dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2, dataset_change_policy="next_epoch") + dataloader = StreamingDataLoader( + dataset, batch_size=4, num_workers=num_workers, dataset_change_policy="next_epoch" + ) dataloader.load_state_dict(dataloader_state) assert not dataloader.restore From 74e43510c1314e02f4ca737b68c48a42842a73c3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 9 Jan 2026 08:51:11 +0000 Subject: [PATCH 6/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/streaming/test_dataloader.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index 65a551c91..fe919a070 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -363,9 +363,7 @@ def test_resume_dataloader_mid_epoch_with_new_dataset(tmpdir): dataloader_state = dataloader.state_dict() dataset = StreamingDataset(str(dataset_2_path), shuffle=False) - dataloader = StreamingDataLoader( - dataset, batch_size=4, num_workers=num_workers, dataset_change_policy="next_epoch" - ) + dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=num_workers, dataset_change_policy="next_epoch") dataloader.load_state_dict(dataloader_state) assert not dataloader.restore From 444acb6190d72b16ff7c7fb89485bb36e2f05b4e Mon Sep 17 00:00:00 2001 From: MagellaX Date: Sat, 10 Jan 2026 12:39:42 +0530 Subject: [PATCH 7/7] Force CPU for Lightning resume E2E test --- tests/streaming/test_dataloader.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index fe919a070..d49ab433c 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -425,6 +425,8 @@ def configure_optimizers(self): logger=False, enable_model_summary=False, enable_progress_bar=False, + accelerator="cpu", + devices=1, callbacks=[ckpt_callback], ) trainer.fit(_ValueCheckModel(expected_value=0), train_dataloaders=_make_dataloader(data_dir_1)) @@ -437,6 +439,8 @@ def configure_optimizers(self): enable_model_summary=False, enable_progress_bar=False, enable_checkpointing=False, + accelerator="cpu", + devices=1, ) trainer.fit( _ValueCheckModel(expected_value=1, expected_epoch=2),