From 6a77302334e54c81b162ca9e03b22ce24dff56cc Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Date: Fri, 24 Oct 2025 15:57:03 -0400 Subject: [PATCH 1/3] Adding multisample feature along with testcases --- src/litdata/streaming/dataset.py | 45 +++++++++++-- tests/streaming/test_dataset.py | 111 +++++++++++++++++++++++++++++++ 2 files changed, 150 insertions(+), 6 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 1dca6c5b..e7925b2c 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -62,6 +62,7 @@ def __init__( index_path: Optional[str] = None, force_override_state_dict: bool = False, transform: Optional[Union[Callable, list[Callable]]] = None, + is_multisample: bool = False, ) -> None: """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class. @@ -89,6 +90,7 @@ def __init__( If `index_path` is a full file path, it will use that directly. force_override_state_dict: Boolean flag for allowing local arguments to override a loaded state dict. transform: Optional transformation function or list of functions to apply to each item in the dataset. + is_multisample: If True, each index access returns multiple samples transformed by the list of functions. """ _check_version_and_prompt_upgrade(__version__) @@ -209,6 +211,9 @@ def __init__( raise ValueError(f"Transform should be a callable. Found {t}") self.transform = transform self._on_demand_bytes = True # true by default, when iterating, turn this off to store the chunks in the cache + self.is_multisample = is_multisample + if self.is_multisample and not transform: + raise ValueError("When using `is_multisample=True`, `transform` must be a list of callables.") @property def on_demand_bytes(self) -> bool: @@ -282,7 +287,8 @@ def _create_shuffler(self, cache: Cache) -> Shuffle: return FullShuffle(cache, seed, drop_last) if self.shuffle else NoShuffle(cache, seed, drop_last) def __len__(self) -> int: - return self.get_len(self.num_workers, self.batch_size if self.batch_size else 1) + original_len = self.get_len(self.num_workers, self.batch_size if self.batch_size else 1) + return original_len if not self.is_multisample else original_len * len(self.transform) def set_batch_size(self, batch_size: int) -> None: self.batch_size = batch_size @@ -323,8 +329,13 @@ def __iter__(self) -> "StreamingDataset": self.worker_chunks = workers_chunks[worker_rank] self.worker_intervals = workers_intervals[worker_rank] + # multiply the interval by the multisample factor if multisampling is enabled + self.multisample_factor = len(self.transform) if self.is_multisample else 1 + # The max number of samples to return from `__next__` (in worker) - self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals) + self.stop_length = ( + sum(interval[2] - interval[1] for interval in self.worker_intervals) * self.multisample_factor + ) # Handle restart if self._state_dict: @@ -407,7 +418,8 @@ def _resume(self, workers_chunks: list[list[int]], workers_intervals: list[Any]) # replay the indexes for the current chunks interval = self.worker_intervals[self.worker_next_chunk_index] - current_indexes = np.arange(interval[1], interval[2]) + # multiply the interval by the multisample factor if multisampling is enabled + current_indexes = np.arange(interval[1] * self.multisample_factor, interval[2] * self.multisample_factor) # re-shuffle the indexes current_indexes = self.shuffler( @@ -424,6 +436,21 @@ def _resume(self, workers_chunks: list[list[int]], workers_intervals: list[Any]) self.worker_next_chunk_index += 1 def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any: + # Deflate index for multisample case + if self.is_multisample: + if not self.transform: + raise ValueError("When using `is_multisample=True`, `transform` must be a list of callables.") + if not all(callable(fn) for fn in self.transform): + raise ValueError("All elements in `transform` must be callable when using `is_multisample=True`.") + if isinstance(index, int): + sample_idx = index % len(self.transform) + index = index // len(self.transform) + elif isinstance(index, ChunkedIndex): + sample_idx = index.index % len(self.transform) + index.index = index.index // len(self.transform) + else: + raise ValueError("Slices are not supported when using `is_multisample=True`.") + if self.cache is None: self.worker_env = _WorkerEnv.detect() self.cache = self._create_cache(worker_env=self.worker_env) @@ -437,16 +464,21 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any: _my_cache_indices = [ChunkedIndex(*self.cache._get_chunk_index_from_index(idx)) for idx in _my_indices] return [self.cache[chnk_idx] for chnk_idx in _my_cache_indices] item = self.cache[index] + if hasattr(self, "transform"): if isinstance(self.transform, list): - for transform_fn in self.transform: - item = transform_fn(item) + if not self.is_multisample: + for transform_fn in self.transform: + item = transform_fn(item) + else: + item = self.transform[sample_idx](item) # apply the specific transform for multisample else: item = self.transform(item) return item def __next__(self) -> Any: + # print(self.worker_next_chunk_index, self.num_chunks) # check if we have reached the end of the dataset (i.e., all the chunks have been processed) if self.global_index >= self.stop_length: # global_index: total number of samples processed by the current worker across all chunks @@ -476,7 +508,8 @@ def __next__(self) -> Any: # `next_worker_chunks_index` is the index of the chunk that we will be working on now interval = self.worker_intervals[self.worker_next_chunk_index] - current_indexes = np.arange(interval[1], interval[2]) + + current_indexes = np.arange(interval[1] * self.multisample_factor, interval[2] * self.multisample_factor) assert self.shuffler is not None assert self.num_chunks is not None diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index c3cd4637..1103adac 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -1813,3 +1813,114 @@ def transform(self, x, *args, **kwargs): # Verify that the transform is applied correctly for i, item in enumerate(complete_data): assert item == i * 2, f"Expected {i * 2}, got {item}" + + +def test_dataset_multisample(tmpdir): + """Test if the dataset transform is applied correctly.""" + # Create a simple dataset + # Create directories for cache and data + cache_dir = os.path.join(tmpdir, "cache_dir") + data_dir = os.path.join(tmpdir, "data_dir") + os.makedirs(cache_dir) + os.makedirs(data_dir) + + # Create a dataset with 100 items, 20 items per chunk + cache = Cache(str(data_dir), chunk_size=20) + for i in range(100): + cache[i] = i + cache.done() + cache.merge() + + # Define simple transform functions + def transform_fn_sq(x, *args, **kwargs): + """A simple transform function that doubles the input.""" + return x * 2 + + def transform_fn_add(x): + """A simple transform function that adds 3 to the input.""" + return x + 3 + + def transform_fn_identity(x): + """A simple transform function that returns the input as is.""" + return x + + dataset = StreamingDataset( + data_dir, + cache_dir=str(cache_dir), + shuffle=False, + transform=[transform_fn_sq, transform_fn_add, transform_fn_identity], + is_multisample=True, + ) + dataset_length = len(dataset) + assert dataset_length == 300 + + # ASSERT + # Verify that the transform functions are applied correctly + for i, item in enumerate(dataset): + assert item is not None + if i % 3 == 0: + assert item == (i // len(dataset.transform)) * 2, ( + f"Expected {(i // len(dataset.transform)) * 2}, got {item}" + ) + elif i % 3 == 1: + assert item == (i // len(dataset.transform)) + 3, ( + f"Expected {(i // len(dataset.transform)) + 3}, got {item}" + ) + else: + assert item == (i // len(dataset.transform)), f"Expected {(i // len(dataset.transform))}, got {item}" + + +def test_dataset_multisample_single_transform(tmpdir): + """Test if the dataset transform is applied correctly.""" + # Create a simple dataset + # Create directories for cache and data + cache_dir = os.path.join(tmpdir, "cache_dir") + data_dir = os.path.join(tmpdir, "data_dir") + os.makedirs(cache_dir) + os.makedirs(data_dir) + + # Create a dataset with 100 items, 20 items per chunk + cache = Cache(str(data_dir), chunk_size=20) + for i in range(100): + cache[i] = i + cache.done() + cache.merge() + + # Define simple transform functions + def transform_fn_sq(x, *args, **kwargs): + """A simple transform function that doubles the input.""" + return x * 2 + + dataset = StreamingDataset( + data_dir, cache_dir=str(cache_dir), shuffle=False, transform=transform_fn_sq, is_multisample=True + ) + dataset_length = len(dataset) + assert dataset_length == 100 + + # ASSERT + # Verify that the transform function is applied correctly + for i, item in enumerate(dataset): + assert item is not None + assert item == (i * 2), f"Expected {(i * 2)}, got {item}" + + +def test_dataset_multisample_nonlist_transform_error(tmpdir): + """Test if the dataset raises an error when is_multisample is True but transform is not a list.""" + # Create a simple dataset + # Create directories for cache and data + cache_dir = os.path.join(tmpdir, "cache_dir") + data_dir = os.path.join(tmpdir, "data_dir") + os.makedirs(cache_dir) + os.makedirs(data_dir) + + # Create a dataset with 100 items, 20 items per chunk + cache = Cache(str(data_dir), chunk_size=20) + for i in range(100): + cache[i] = i + cache.done() + cache.merge() + + # ASSERT + # Verify that ValueError is raised when transform is not given + with pytest.raises(ValueError, match="When using `is_multisample=True`, `transform` must be a list of callables."): + StreamingDataset(data_dir, cache_dir=str(cache_dir), shuffle=False, is_multisample=True) From 494e1481ce9844079e39673259f68c34490ca631 Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Date: Mon, 3 Nov 2025 15:44:59 -0500 Subject: [PATCH 2/3] Modifying the feature for only single transform case --- README.md | 38 +++++++++ src/litdata/streaming/dataset.py | 65 ++++++++-------- tests/streaming/test_dataloader.py | 121 +++++++++++++++++++++++++++++ tests/streaming/test_dataset.py | 108 ++++++++++++------------- 4 files changed, 242 insertions(+), 90 deletions(-) diff --git a/README.md b/README.md index adcf8ded..76d42c8d 100644 --- a/README.md +++ b/README.md @@ -1082,6 +1082,44 @@ dataset = StreamingDatasetWithTransform(data_dir, cache_dir=str(cache_dir), shuf + +
+ ✅ Multi-Sample Transform datasets while Streaming 🔗 +  + +Sometimes you need to return a sub-sample batch for a given batch while adding subtle variations to the samples. The multi-sample feature allows you to apply multi-sample transformation while streaming, without the need to store intermediate results. + +```python +def transform_fn(x, sample_idx): + """ + Apply different rotation for each sample based on sample_idx. + """ + + angles = [0, 15, -15, 30] + angle = angles[sample_idx % len(angles)] + + torch_transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.Lambda(lambda x: transforms.functional.rotate(x, angle)), # apply rotation + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ) + ]) + return torch_transform(x) + +dataset = StreamingDataset( +data_dir, +cache_dir=str(cache_dir), +shuffle=False, +transform=[transform_fn], +sample_count=4 # Generate 4 transformed samples per input +) +``` + +
+
✅ Split datasets for train, val, test diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index e7925b2c..9a049363 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -13,6 +13,7 @@ import logging import os +from inspect import signature from time import time from typing import Any, Callable, Optional, Union @@ -62,7 +63,7 @@ def __init__( index_path: Optional[str] = None, force_override_state_dict: bool = False, transform: Optional[Union[Callable, list[Callable]]] = None, - is_multisample: bool = False, + sample_count: int = 1, ) -> None: """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class. @@ -90,7 +91,7 @@ def __init__( If `index_path` is a full file path, it will use that directly. force_override_state_dict: Boolean flag for allowing local arguments to override a loaded state dict. transform: Optional transformation function or list of functions to apply to each item in the dataset. - is_multisample: If True, each index access returns multiple samples transformed by the list of functions. + sample_count: Number of samples to return for each index access. """ _check_version_and_prompt_upgrade(__version__) @@ -204,16 +205,28 @@ def __init__( self.storage_options = storage_options self.session_options = session_options self.max_pre_download = max_pre_download + self.sample_count = sample_count if transform is not None: transform = transform if isinstance(transform, list) else [transform] for t in transform: if not callable(t): raise ValueError(f"Transform should be a callable. Found {t}") self.transform = transform + + # define invalid transform conditions for multisample case + invalid_transform = self.sample_count > 1 and ( + not hasattr(self, "transform") + or len(self.transform) > 1 + or "sample_idx" not in signature(self.transform[0]).parameters + ) + if invalid_transform: + logger.warning( + "Invalid transform configuration detected. " + "Either no transform, multiple transforms, or missing `sample_idx` parameter. " + "Reverting `sample_count` to 1 and returning data as-is." + ) + self.sample_count = 1 self._on_demand_bytes = True # true by default, when iterating, turn this off to store the chunks in the cache - self.is_multisample = is_multisample - if self.is_multisample and not transform: - raise ValueError("When using `is_multisample=True`, `transform` must be a list of callables.") @property def on_demand_bytes(self) -> bool: @@ -287,8 +300,7 @@ def _create_shuffler(self, cache: Cache) -> Shuffle: return FullShuffle(cache, seed, drop_last) if self.shuffle else NoShuffle(cache, seed, drop_last) def __len__(self) -> int: - original_len = self.get_len(self.num_workers, self.batch_size if self.batch_size else 1) - return original_len if not self.is_multisample else original_len * len(self.transform) + return self.get_len(self.num_workers, self.batch_size if self.batch_size else 1) * self.sample_count def set_batch_size(self, batch_size: int) -> None: self.batch_size = batch_size @@ -329,13 +341,8 @@ def __iter__(self) -> "StreamingDataset": self.worker_chunks = workers_chunks[worker_rank] self.worker_intervals = workers_intervals[worker_rank] - # multiply the interval by the multisample factor if multisampling is enabled - self.multisample_factor = len(self.transform) if self.is_multisample else 1 - # The max number of samples to return from `__next__` (in worker) - self.stop_length = ( - sum(interval[2] - interval[1] for interval in self.worker_intervals) * self.multisample_factor - ) + self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals) * self.sample_count # Handle restart if self._state_dict: @@ -418,8 +425,9 @@ def _resume(self, workers_chunks: list[list[int]], workers_intervals: list[Any]) # replay the indexes for the current chunks interval = self.worker_intervals[self.worker_next_chunk_index] - # multiply the interval by the multisample factor if multisampling is enabled - current_indexes = np.arange(interval[1] * self.multisample_factor, interval[2] * self.multisample_factor) + + # multiply the interval by the sample_count for multisample case + current_indexes = np.arange(interval[1] * self.sample_count, interval[2] * self.sample_count) # re-shuffle the indexes current_indexes = self.shuffler( @@ -437,19 +445,15 @@ def _resume(self, workers_chunks: list[list[int]], workers_intervals: list[Any]) def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any: # Deflate index for multisample case - if self.is_multisample: - if not self.transform: - raise ValueError("When using `is_multisample=True`, `transform` must be a list of callables.") - if not all(callable(fn) for fn in self.transform): - raise ValueError("All elements in `transform` must be callable when using `is_multisample=True`.") + if self.sample_count > 1: if isinstance(index, int): - sample_idx = index % len(self.transform) - index = index // len(self.transform) + sample_idx = index % self.sample_count + index = index // self.sample_count elif isinstance(index, ChunkedIndex): - sample_idx = index.index % len(self.transform) - index.index = index.index // len(self.transform) + sample_idx = index.index % self.sample_count + index.index = index.index // self.sample_count else: - raise ValueError("Slices are not supported when using `is_multisample=True`.") + raise ValueError("Slices are not supported when using `sample_count > 1`.") if self.cache is None: self.worker_env = _WorkerEnv.detect() @@ -467,18 +471,14 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any: if hasattr(self, "transform"): if isinstance(self.transform, list): - if not self.is_multisample: - for transform_fn in self.transform: - item = transform_fn(item) - else: - item = self.transform[sample_idx](item) # apply the specific transform for multisample + for transform_fn in self.transform: + item = transform_fn(item) if self.sample_count == 1 else transform_fn(item, sample_idx) else: item = self.transform(item) return item def __next__(self) -> Any: - # print(self.worker_next_chunk_index, self.num_chunks) # check if we have reached the end of the dataset (i.e., all the chunks have been processed) if self.global_index >= self.stop_length: # global_index: total number of samples processed by the current worker across all chunks @@ -509,7 +509,8 @@ def __next__(self) -> Any: # `next_worker_chunks_index` is the index of the chunk that we will be working on now interval = self.worker_intervals[self.worker_next_chunk_index] - current_indexes = np.arange(interval[1] * self.multisample_factor, interval[2] * self.multisample_factor) + # current_indexes = np.arange(interval[1] * self.multisample_factor, interval[2] * self.multisample_factor) + current_indexes = np.arange(interval[1] * self.sample_count, interval[2] * self.sample_count) assert self.shuffler is not None assert self.num_chunks is not None diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index ddb517e4..7d073706 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -1,3 +1,4 @@ +import logging import os import sys @@ -496,3 +497,123 @@ def test_dataloader_dataset_transform_inheritance(tmpdir, shuffle): # Verify that the transform is applied correctly for i, item in enumerate(complete_data): assert item == i * 2, f"Expected {i * 2}, got {item}" + + +# Define a simple transform function +def multisample_transform_fn(x, sample_idx, *args, **kwargs): + """A simple transform function that doubles the input.""" + return x * sample_idx + + +def test_dataloader_dataset_transform_multisample(tmpdir): + """Test if the dataset's transform is applied correctly with dataloader.""" + # Create a simple dataset + # Create directories for cache and data + cache_dir = os.path.join(tmpdir, "cache_dir") + data_dir = os.path.join(tmpdir, "data_dir") + os.makedirs(cache_dir) + os.makedirs(data_dir) + + # Create a dataset with 100 items, 20 items per chunk + cache = Cache(str(data_dir), chunk_size=20) + for i in range(100): + cache[i] = i + cache.done() + cache.merge() + + dataset = StreamingDataset( + data_dir, cache_dir=str(cache_dir), shuffle=False, transform=multisample_transform_fn, sample_count=3 + ) + dataset_length = len(dataset) + assert dataset_length == 300 + + # ACT + dl = StreamingDataLoader(dataset, batch_size=10, num_workers=1, shuffle=False) + + complete_data = [] + for batch in dl: + complete_data.extend(batch) + + # ASSERT + # Verify that the multisample transform is applied correctly + for i, item in enumerate(complete_data): + if i % 3 == 0: + assert item == (i // 3) * 0, f"Expected {i * 0}, got {item}" + elif i % 3 == 1: + assert item == (i // 3) * 1, f"Expected {i * 1}, got {item}" + else: + assert item == (i // 3) * 2, f"Expected {i * 2}, got {item}" + + +def test_dataloader_dataset_transform_invalid_config(tmpdir, caplog): + """Test if the dataset's transform is applied correctly with dataloader.""" + # Create a simple dataset + # Create directories for cache and data + cache_dir = os.path.join(tmpdir, "cache_dir") + data_dir = os.path.join(tmpdir, "data_dir") + os.makedirs(cache_dir) + os.makedirs(data_dir) + + # Define simple transform functions + def transform_fn_sq(x, sample_idx): + """A simple transform function that doubles the input.""" + return x * sample_idx + + def transform_fn_add(x, sample_idx): + """A simple transform function that adds the sample_idx to the input.""" + return x + sample_idx + + def transform_fn_no_sample_idx(x): + """A simple transform function that doubles the input.""" + return x + + # Create a dataset with 100 items, 20 items per chunk + cache = Cache(str(data_dir), chunk_size=20) + for i in range(100): + cache[i] = i + cache.done() + cache.merge() + + # Verify that logger warning happens when transform is not given + with caplog.at_level(logging.WARNING): + dataset = StreamingDataset(data_dir, cache_dir=str(cache_dir), shuffle=False, sample_count=4) + + assert "Invalid transform configuration detected." in caplog.text + dataset_length = len(dataset) + assert dataset_length == 100 + + # Verify that logger warning happens when multiple transforms are given + with caplog.at_level(logging.WARNING): + dataset = StreamingDataset( + data_dir, + cache_dir=str(cache_dir), + shuffle=False, + sample_count=4, + transform=[transform_fn_sq, transform_fn_add], + ) + + assert "Invalid transform configuration detected." in caplog.text + dataset_length = len(dataset) + assert dataset_length == 100 + + # Verify that logger warning happens when sample_idx parameter is missing + with caplog.at_level(logging.WARNING): + dataset = StreamingDataset( + data_dir, cache_dir=str(cache_dir), shuffle=False, sample_count=4, transform=transform_fn_no_sample_idx + ) + + assert "Invalid transform configuration detected." in caplog.text + dataset_length = len(dataset) + assert dataset_length == 100 + + # ACT + dl = StreamingDataLoader(dataset, batch_size=10, num_workers=1, shuffle=False) + + complete_data = [] + for batch in dl: + complete_data.extend(batch) + + # ASSERT + # Verify that the multisample transform is applied correctly + for i, item in enumerate(complete_data): + assert item == i, f"Expected {i}, got {item}" diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 1103adac..f5aea6fa 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -1815,7 +1815,7 @@ def transform(self, x, *args, **kwargs): assert item == i * 2, f"Expected {i * 2}, got {item}" -def test_dataset_multisample(tmpdir): +def test_dataset_transform_multisample(tmpdir): """Test if the dataset transform is applied correctly.""" # Create a simple dataset # Create directories for cache and data @@ -1832,24 +1832,13 @@ def test_dataset_multisample(tmpdir): cache.merge() # Define simple transform functions - def transform_fn_sq(x, *args, **kwargs): + def transform_fn_sq(x, sample_idx): """A simple transform function that doubles the input.""" - return x * 2 - - def transform_fn_add(x): - """A simple transform function that adds 3 to the input.""" - return x + 3 - - def transform_fn_identity(x): - """A simple transform function that returns the input as is.""" - return x + return x * sample_idx + sample_count = 3 dataset = StreamingDataset( - data_dir, - cache_dir=str(cache_dir), - shuffle=False, - transform=[transform_fn_sq, transform_fn_add, transform_fn_identity], - is_multisample=True, + data_dir, cache_dir=str(cache_dir), shuffle=False, transform=transform_fn_sq, sample_count=sample_count ) dataset_length = len(dataset) assert dataset_length == 300 @@ -1858,20 +1847,16 @@ def transform_fn_identity(x): # Verify that the transform functions are applied correctly for i, item in enumerate(dataset): assert item is not None - if i % 3 == 0: - assert item == (i // len(dataset.transform)) * 2, ( - f"Expected {(i // len(dataset.transform)) * 2}, got {item}" - ) - elif i % 3 == 1: - assert item == (i // len(dataset.transform)) + 3, ( - f"Expected {(i // len(dataset.transform)) + 3}, got {item}" - ) + if i % sample_count == 0: + assert item == (i // sample_count) * 0, f"Expected {(i // sample_count) * 0}, got {item}" + elif i % sample_count == 1: + assert item == (i // sample_count) * 1, f"Expected {(i // sample_count) * 1}, got {item}" else: - assert item == (i // len(dataset.transform)), f"Expected {(i // len(dataset.transform))}, got {item}" + assert item == (i // sample_count) * 2, f"Expected {(i // sample_count) * 2}, got {item}" -def test_dataset_multisample_single_transform(tmpdir): - """Test if the dataset transform is applied correctly.""" +def test_dataset_transform_multisample_invalid_config(tmpdir, caplog): + """Test if the dataset raises an error when is_multisample is True but transform is not a list.""" # Create a simple dataset # Create directories for cache and data cache_dir = os.path.join(tmpdir, "cache_dir") @@ -1879,6 +1864,19 @@ def test_dataset_multisample_single_transform(tmpdir): os.makedirs(cache_dir) os.makedirs(data_dir) + # Define simple transform functions + def transform_fn_sq(x, sample_idx): + """A simple transform function that doubles the input.""" + return x * sample_idx + + def transform_fn_add(x, sample_idx): + """A simple transform function that adds the sample_idx to the input.""" + return x + sample_idx + + def transform_fn_no_sample_idx(x): + """A simple transform function that misses the sample_idx parameter.""" + return x + # Create a dataset with 100 items, 20 items per chunk cache = Cache(str(data_dir), chunk_size=20) for i in range(100): @@ -1886,41 +1884,35 @@ def test_dataset_multisample_single_transform(tmpdir): cache.done() cache.merge() - # Define simple transform functions - def transform_fn_sq(x, *args, **kwargs): - """A simple transform function that doubles the input.""" - return x * 2 + # ASSERT + # Verify that logger warning happens when transform is not given + with caplog.at_level(logging.WARNING): + dataset = StreamingDataset(data_dir, cache_dir=str(cache_dir), shuffle=False, sample_count=4) - dataset = StreamingDataset( - data_dir, cache_dir=str(cache_dir), shuffle=False, transform=transform_fn_sq, is_multisample=True - ) + assert "Invalid transform configuration detected." in caplog.text dataset_length = len(dataset) assert dataset_length == 100 - # ASSERT - # Verify that the transform function is applied correctly - for i, item in enumerate(dataset): - assert item is not None - assert item == (i * 2), f"Expected {(i * 2)}, got {item}" - + # Verify that logger warning happens when multiple transforms are given + with caplog.at_level(logging.WARNING): + dataset = StreamingDataset( + data_dir, + cache_dir=str(cache_dir), + shuffle=False, + sample_count=4, + transform=[transform_fn_sq, transform_fn_add], + ) -def test_dataset_multisample_nonlist_transform_error(tmpdir): - """Test if the dataset raises an error when is_multisample is True but transform is not a list.""" - # Create a simple dataset - # Create directories for cache and data - cache_dir = os.path.join(tmpdir, "cache_dir") - data_dir = os.path.join(tmpdir, "data_dir") - os.makedirs(cache_dir) - os.makedirs(data_dir) + assert "Invalid transform configuration detected." in caplog.text + dataset_length = len(dataset) + assert dataset_length == 100 - # Create a dataset with 100 items, 20 items per chunk - cache = Cache(str(data_dir), chunk_size=20) - for i in range(100): - cache[i] = i - cache.done() - cache.merge() + # Verify that logger warning happens when sample_idx parameter is missing + with caplog.at_level(logging.WARNING): + dataset = StreamingDataset( + data_dir, cache_dir=str(cache_dir), shuffle=False, sample_count=4, transform=transform_fn_no_sample_idx + ) - # ASSERT - # Verify that ValueError is raised when transform is not given - with pytest.raises(ValueError, match="When using `is_multisample=True`, `transform` must be a list of callables."): - StreamingDataset(data_dir, cache_dir=str(cache_dir), shuffle=False, is_multisample=True) + assert "Invalid transform configuration detected." in caplog.text + dataset_length = len(dataset) + assert dataset_length == 100 From 87f69b8dc63d967f496d78f6bf7c832939e2272c Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Date: Mon, 3 Nov 2025 15:53:33 -0500 Subject: [PATCH 3/3] Moving transform fns outside test functions --- tests/streaming/test_dataloader.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index 7d073706..2c2e9c92 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -545,6 +545,22 @@ def test_dataloader_dataset_transform_multisample(tmpdir): assert item == (i // 3) * 2, f"Expected {i * 2}, got {item}" +# Define simple transform functions +def transform_fn_sq(x, sample_idx): + """A simple transform function that doubles the input.""" + return x * sample_idx + + +def transform_fn_add(x, sample_idx): + """A simple transform function that adds the sample_idx to the input.""" + return x + sample_idx + + +def transform_fn_no_sample_idx(x): + """A simple transform function that doubles the input.""" + return x + + def test_dataloader_dataset_transform_invalid_config(tmpdir, caplog): """Test if the dataset's transform is applied correctly with dataloader.""" # Create a simple dataset @@ -554,19 +570,6 @@ def test_dataloader_dataset_transform_invalid_config(tmpdir, caplog): os.makedirs(cache_dir) os.makedirs(data_dir) - # Define simple transform functions - def transform_fn_sq(x, sample_idx): - """A simple transform function that doubles the input.""" - return x * sample_idx - - def transform_fn_add(x, sample_idx): - """A simple transform function that adds the sample_idx to the input.""" - return x + sample_idx - - def transform_fn_no_sample_idx(x): - """A simple transform function that doubles the input.""" - return x - # Create a dataset with 100 items, 20 items per chunk cache = Cache(str(data_dir), chunk_size=20) for i in range(100):