diff --git a/nemo_curator/stages/deduplication/exact/workflow.py b/nemo_curator/stages/deduplication/exact/workflow.py index 6955913b27..fee61ae8fd 100644 --- a/nemo_curator/stages/deduplication/exact/workflow.py +++ b/nemo_curator/stages/deduplication/exact/workflow.py @@ -30,6 +30,7 @@ ) from nemo_curator.stages.file_partitioning import FilePartitioningStage from nemo_curator.tasks import FileGroupTask +from nemo_curator.utils.file_utils import get_default_file_extensions ID_GENERATOR_OUTPUT_FILENAME = "exact_id_generator.json" @@ -151,7 +152,7 @@ def _create_input_filegroups(self) -> Pipeline: stages=[ FilePartitioningStage( file_paths=self.input_path, - file_extensions=self.input_file_extensions, + file_extensions=(self.input_file_extensions or get_default_file_extensions(self.input_filetype)), blocksize=self.input_blocksize, storage_options=self.read_kwargs.get("storage_options") if self.read_kwargs is not None else None, ), diff --git a/nemo_curator/stages/deduplication/fuzzy/workflow.py b/nemo_curator/stages/deduplication/fuzzy/workflow.py index 9acc86b272..12b47e7c68 100644 --- a/nemo_curator/stages/deduplication/fuzzy/workflow.py +++ b/nemo_curator/stages/deduplication/fuzzy/workflow.py @@ -33,7 +33,7 @@ ) from nemo_curator.stages.file_partitioning import FilePartitioningStage from nemo_curator.tasks import FileGroupTask -from nemo_curator.utils.file_utils import get_fs +from nemo_curator.utils.file_utils import get_default_file_extensions, get_fs ID_GENERATOR_OUTPUT_FILENAME = "fuzzy_id_generator.json" @@ -203,7 +203,7 @@ def _create_minhash_pipeline(self, generate_input_filegroups: bool) -> Pipeline: stages.append( FilePartitioningStage( file_paths=self.input_path, - file_extensions=self.input_file_extensions, + file_extensions=(self.input_file_extensions or get_default_file_extensions(self.input_filetype)), blocksize=self.input_blocksize, storage_options=self.read_kwargs.get("storage_options") if self.read_kwargs is not None else None, ), diff --git a/nemo_curator/stages/deduplication/semantic/kmeans.py b/nemo_curator/stages/deduplication/semantic/kmeans.py index 1a1854a9f8..c27b09721a 100644 --- a/nemo_curator/stages/deduplication/semantic/kmeans.py +++ b/nemo_curator/stages/deduplication/semantic/kmeans.py @@ -25,7 +25,7 @@ from nemo_curator.stages.resources import Resources from nemo_curator.stages.text.embedders.utils import create_list_series_from_1d_or_2d_ar from nemo_curator.tasks import EmptyTask, FileGroupTask -from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS, check_disallowed_kwargs +from nemo_curator.utils.file_utils import check_disallowed_kwargs, get_default_file_extensions from .utils import break_parquet_partition_into_groups, get_array_from_df @@ -542,10 +542,7 @@ def __post_init__(self): def decompose(self) -> list[ProcessingStage]: # Set default file extensions based on input_filetype if not provided - file_extensions = self.input_file_extensions or FILETYPE_TO_DEFAULT_EXTENSIONS.get(self.input_filetype, []) - if not file_extensions: - msg = f"Unsupported filetype: {self.input_filetype}" - raise ValueError(msg) + file_extensions = self.input_file_extensions or get_default_file_extensions(self.input_filetype) return [ FilePartitioningStage( diff --git a/nemo_curator/stages/text/deduplication/removal_workflow.py b/nemo_curator/stages/text/deduplication/removal_workflow.py index 078b2014d3..f32cccecde 100644 --- a/nemo_curator/stages/text/deduplication/removal_workflow.py +++ b/nemo_curator/stages/text/deduplication/removal_workflow.py @@ -23,6 +23,7 @@ from nemo_curator.stages.base import ProcessingStage from nemo_curator.stages.deduplication.id_generator import CURATOR_DEDUP_ID_STR from nemo_curator.tasks import FileGroupTask +from nemo_curator.utils.file_utils import get_default_file_extensions from .removal import TextDuplicatesRemovalStage @@ -84,7 +85,7 @@ def _generate_stages(self, initial_tasks: list[FileGroupTask] | None = None) -> file_paths=self.input_path, files_per_partition=self.input_files_per_partition, blocksize=self.input_blocksize, - file_extensions=self.input_file_extensions, + file_extensions=(self.input_file_extensions or get_default_file_extensions(self.input_filetype)), storage_options=(self.input_kwargs or {}).get("storage_options"), limit=self.input_task_limit, ) diff --git a/nemo_curator/stages/text/deduplication/semantic.py b/nemo_curator/stages/text/deduplication/semantic.py index 34eaa7f651..acb333da2f 100644 --- a/nemo_curator/stages/text/deduplication/semantic.py +++ b/nemo_curator/stages/text/deduplication/semantic.py @@ -45,7 +45,7 @@ from nemo_curator.stages.text.io.reader import JsonlReader, ParquetReader from nemo_curator.stages.text.io.writer import ParquetWriter from nemo_curator.tasks import Task -from nemo_curator.utils.file_utils import create_or_overwrite_dir +from nemo_curator.utils.file_utils import create_or_overwrite_dir, get_default_file_extensions @dataclass @@ -249,7 +249,7 @@ def _run_embedding_generation(self, executor: BaseExecutor) -> list[Task]: + [self.text_field] + (self.metadata_fields or []) ), - file_extensions=self.input_file_extensions, + file_extensions=self.input_file_extensions or get_default_file_extensions(self.input_filetype), _generate_ids=self.use_id_generator, read_kwargs=self.read_kwargs, ) @@ -263,7 +263,7 @@ def _run_embedding_generation(self, executor: BaseExecutor) -> list[Task]: + [self.text_field] + (self.metadata_fields or []) ), - file_extensions=self.input_file_extensions, + file_extensions=self.input_file_extensions or get_default_file_extensions(self.input_filetype), read_kwargs=self.read_kwargs, _generate_ids=self.use_id_generator, ) diff --git a/nemo_curator/utils/file_utils.py b/nemo_curator/utils/file_utils.py index 2359a07bb7..2a5e1bae26 100644 --- a/nemo_curator/utils/file_utils.py +++ b/nemo_curator/utils/file_utils.py @@ -39,6 +39,15 @@ } +def get_default_file_extensions(input_filetype: str) -> list[str]: + """Return default file extensions for an input file type.""" + file_extensions = FILETYPE_TO_DEFAULT_EXTENSIONS.get(input_filetype) + if file_extensions is None: + msg = f"Unsupported filetype: {input_filetype}" + raise ValueError(msg) + return file_extensions + + def get_fs(path: str, storage_options: dict[str, str] | None = None) -> fsspec.AbstractFileSystem: if not storage_options: storage_options = {} diff --git a/tests/stages/deduplication/exact/test_workflow.py b/tests/stages/deduplication/exact/test_workflow.py index d1ecfd7d2e..d7fcf9042a 100644 --- a/tests/stages/deduplication/exact/test_workflow.py +++ b/tests/stages/deduplication/exact/test_workflow.py @@ -186,6 +186,29 @@ def test_no_dedup(self, exact_no_dedup_data_jsonl: list[FileGroupTask], tmpdir: removal_ids_df = cudf.read_parquet(tmpdir / "ExactDuplicateIds") assert len(removal_ids_df) == 0 + def test_input_file_extensions_default_to_input_filetype(self, tmpdir: Path) -> None: + workflow = ExactDeduplicationWorkflow( + input_path="/dummy", + output_path=str(tmpdir), + input_filetype="jsonl", + ) + + stages = workflow._create_input_filegroups().stages + + assert stages[0].file_extensions == [".jsonl", ".json"] + + def test_input_file_extensions_override_default(self, tmpdir: Path) -> None: + workflow = ExactDeduplicationWorkflow( + input_path="/dummy", + output_path=str(tmpdir), + input_filetype="parquet", + input_file_extensions=[".pq"], + ) + + stages = workflow._create_input_filegroups().stages + + assert stages[0].file_extensions == [".pq"] + def test_bad_inputs(self, tmpdir: Path) -> None: with pytest.raises(NotImplementedError, match="Removal is not implemented"): # Removal is not implemented yet diff --git a/tests/stages/deduplication/fuzzy/test_fuzzy_workflow.py b/tests/stages/deduplication/fuzzy/test_fuzzy_workflow.py index ea52929ed6..4cf71e472a 100644 --- a/tests/stages/deduplication/fuzzy/test_fuzzy_workflow.py +++ b/tests/stages/deduplication/fuzzy/test_fuzzy_workflow.py @@ -266,6 +266,31 @@ def test_fuzzy_dedup_no_duplicates( lsh_df = cudf.read_parquet(cache_path / "LSHStage") assert len(lsh_df) == 0 + def test_input_file_extensions_default_to_input_filetype(self, tmp_path: Path) -> None: + workflow = FuzzyDeduplicationWorkflow( + input_path="/dummy", + cache_path=str(tmp_path), + output_path=str(tmp_path), + input_filetype="jsonl", + ) + + stages = workflow._create_minhash_pipeline(generate_input_filegroups=True).stages + + assert stages[0].file_extensions == [".jsonl", ".json"] + + def test_input_file_extensions_override_default(self, tmp_path: Path) -> None: + workflow = FuzzyDeduplicationWorkflow( + input_path="/dummy", + cache_path=str(tmp_path), + output_path=str(tmp_path), + input_filetype="parquet", + input_file_extensions=[".pq"], + ) + + stages = workflow._create_minhash_pipeline(generate_input_filegroups=True).stages + + assert stages[0].file_extensions == [".pq"] + def test_bad_inputs(self, tmp_path: Path) -> None: with pytest.raises(ValueError, match="bands_per_iteration must be between"): # bands_per_iteration must be between 1 and num_bands diff --git a/tests/stages/deduplication/semantic/test_kmeans.py b/tests/stages/deduplication/semantic/test_kmeans.py index 968a6b5217..5061cdfe66 100644 --- a/tests/stages/deduplication/semantic/test_kmeans.py +++ b/tests/stages/deduplication/semantic/test_kmeans.py @@ -144,6 +144,64 @@ def run_single_gpu_baseline( return df.sort_values("id", ignore_index=True)["centroid"].to_numpy() +class TestKMeansStage: + """Unit tests for KMeansStage decomposition.""" + + @pytest.mark.parametrize( + ("input_filetype", "expected_extensions"), + [ + ("parquet", [".parquet"]), + ("jsonl", [".jsonl", ".json"]), + ], + ) + def test_input_file_extensions_default_to_input_filetype( + self, + tmp_path: Path, + input_filetype: Literal["parquet", "jsonl"], + expected_extensions: list[str], + ) -> None: + stage = KMeansStage( + id_field="id", + embedding_field="embeddings", + n_clusters=2, + input_path=str(tmp_path / "input"), + output_path=str(tmp_path / "output"), + input_filetype=input_filetype, + ) + + stages = stage.decompose() + + assert stages[0].file_extensions == expected_extensions + + def test_input_file_extensions_override_default(self, tmp_path: Path) -> None: + stage = KMeansStage( + id_field="id", + embedding_field="embeddings", + n_clusters=2, + input_path=str(tmp_path / "input"), + output_path=str(tmp_path / "output"), + input_filetype="parquet", + input_file_extensions=[".pq"], + ) + + stages = stage.decompose() + + assert stages[0].file_extensions == [".pq"] + + def test_unsupported_input_filetype_raises(self, tmp_path: Path) -> None: + stage = KMeansStage( + id_field="id", + embedding_field="embeddings", + n_clusters=2, + input_path=str(tmp_path / "input"), + output_path=str(tmp_path / "output"), + input_filetype="csv", # type: ignore[arg-type] + ) + + with pytest.raises(ValueError, match="Unsupported filetype: csv"): + stage.decompose() + + @pytest.mark.gpu class TestKMeansStageIntegration: """Integration tests for KMeansStage comparing multi-GPU vs single-GPU results.""" diff --git a/tests/stages/text/deduplication/test_removal_workflow.py b/tests/stages/text/deduplication/test_removal_workflow.py index 2f87ce2e5b..c45abcee61 100644 --- a/tests/stages/text/deduplication/test_removal_workflow.py +++ b/tests/stages/text/deduplication/test_removal_workflow.py @@ -303,9 +303,17 @@ def test_invalid_filetypes(self): with pytest.raises(ValueError, match="Invalid output filetype: invalid"): write_invalid_file_type_workflow._generate_stages(initial_tasks=None) - @pytest.mark.parametrize("input_filetype", ["parquet", "jsonl"]) + @pytest.mark.parametrize( + ("input_filetype", "expected_file_extensions"), + [("parquet", [".parquet"]), ("jsonl", [".jsonl", ".json"])], + ) @pytest.mark.parametrize("id_generator_path", [None, "id_generator_path"]) - def test_reader_stage(self, input_filetype: str, id_generator_path: str | None): + def test_reader_stage( + self, + input_filetype: str, + expected_file_extensions: list[str], + id_generator_path: str | None, + ): workflow = TextDuplicatesRemovalWorkflow( input_path="input_path", ids_to_remove_path="ids_to_remove_path", @@ -322,8 +330,7 @@ def test_reader_stage(self, input_filetype: str, id_generator_path: str | None): assert stages[0].file_paths == "input_path" assert stages[0].files_per_partition is None assert stages[0].blocksize is None - # post init of FilePartitioningStage sets this - assert stages[0].file_extensions == [".jsonl", ".json", ".parquet"] + assert stages[0].file_extensions == expected_file_extensions assert stages[0].storage_options == {} # test for reader stage (stages[1]) @@ -344,6 +351,20 @@ def test_reader_stage(self, input_filetype: str, id_generator_path: str | None): # test for writer stage (stages[3]) - default output_filetype is parquet assert isinstance(stages[3], ParquetWriter) + def test_reader_stage_with_custom_input_file_extensions(self): + workflow = TextDuplicatesRemovalWorkflow( + input_path="input_path", + ids_to_remove_path="ids_to_remove_path", + output_path="output_path", + input_filetype="parquet", + input_file_extensions=[".pq"], + id_generator_path=None, + ) + + stages = workflow._generate_stages(initial_tasks=None) + + assert stages[0].file_extensions == [".pq"] + @pytest.mark.parametrize("output_filetype", ["parquet", "jsonl"]) def test_writer_stage(self, output_filetype: str): workflow = TextDuplicatesRemovalWorkflow( diff --git a/tests/stages/text/deduplication/test_semantic.py b/tests/stages/text/deduplication/test_semantic.py index 8f49aeda98..d0ff593e08 100644 --- a/tests/stages/text/deduplication/test_semantic.py +++ b/tests/stages/text/deduplication/test_semantic.py @@ -15,7 +15,7 @@ import os from contextlib import suppress from pathlib import Path -from typing import Any +from typing import Any, Literal import pandas as pd import pytest @@ -27,12 +27,13 @@ # Suppress GPU-related import errors when running pytest -m "not gpu" with suppress(ImportError): + from nemo_curator.stages.text.deduplication import semantic from nemo_curator.stages.text.deduplication.semantic import TextSemanticDeduplicationWorkflow MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2" -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def ensure_semantic_model_downloaded() -> None: """Pre-download the model once per session to avoid rate limiting in CI.""" try: @@ -71,6 +72,78 @@ def create_data_with_duplicates(input_dir: Path) -> pd.DataFrame: return df +@pytest.mark.parametrize( + ("input_filetype", "expected_extensions"), + [ + ("jsonl", [".jsonl", ".json"]), + ("parquet", [".parquet"]), + ], +) +def test_embedding_reader_extensions_default_to_input_filetype( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + input_filetype: Literal["jsonl", "parquet"], + expected_extensions: list[str], +) -> None: + captured_stages = [] + + def capture_pipeline_run(self, executor) -> list[object]: # noqa: ANN001, ARG001 + captured_stages.extend(self.stages) + return [] + + monkeypatch.setattr(semantic.Pipeline, "run", capture_pipeline_run) + + workflow = TextSemanticDeduplicationWorkflow( + input_path="/dummy", + output_path=str(tmp_path / "output"), + cache_path=str(tmp_path / "cache"), + input_filetype=input_filetype, + ) + + workflow._run_embedding_generation(executor=object()) + + assert captured_stages[0].file_extensions == expected_extensions + + +def test_embedding_reader_extensions_override_default(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + captured_stages = [] + + def capture_pipeline_run(self, executor) -> list[object]: # noqa: ANN001, ARG001 + captured_stages.extend(self.stages) + return [] + + monkeypatch.setattr(semantic.Pipeline, "run", capture_pipeline_run) + + workflow = TextSemanticDeduplicationWorkflow( + input_path="/dummy", + output_path=str(tmp_path / "output"), + cache_path=str(tmp_path / "cache"), + input_filetype="parquet", + input_file_extensions=[".pq"], + ) + + workflow._run_embedding_generation(executor=object()) + + assert captured_stages[0].file_extensions == [".pq"] + + +def test_embedding_reader_unsupported_filetype_error(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + def fail_pipeline_run(self, executor) -> None: # noqa: ANN001, ARG001 + pytest.fail("Pipeline should not run when input_filetype is unsupported") + + monkeypatch.setattr(semantic.Pipeline, "run", fail_pipeline_run) + + workflow = TextSemanticDeduplicationWorkflow( + input_path="/dummy", + output_path=str(tmp_path / "output"), + cache_path=str(tmp_path / "cache"), + input_filetype="csv", # type: ignore[arg-type] + ) + + with pytest.raises(NotImplementedError, match="Input filetype csv not supported yet"): + workflow._run_embedding_generation(executor=object()) + + @pytest.mark.gpu @pytest.mark.parametrize( "test_config", @@ -97,7 +170,10 @@ class TestTextSemanticDeduplicationWorkflow: @pytest.fixture(scope="class", autouse=True) def test_config( - self, request: pytest.FixtureRequest, tmp_path_factory: pytest.TempPathFactory + self, + request: pytest.FixtureRequest, + tmp_path_factory: pytest.TempPathFactory, + ensure_semantic_model_downloaded: None, ) -> "TestTextSemanticDeduplicationWorkflow": """Set up test environment and execute workflow.""" executor_cls, config, use_id_generator = request.param