From 31ba333f80c2a41e263438501b131b16b98f3b6f Mon Sep 17 00:00:00 2001 From: nightcityblade Date: Thu, 14 May 2026 23:08:42 +0800 Subject: [PATCH 1/3] fix: default workflow input extensions by filetype Signed-off-by: nightcityblade --- .../stages/deduplication/exact/workflow.py | 5 +++- .../stages/deduplication/fuzzy/workflow.py | 6 ++-- .../text/deduplication/removal_workflow.py | 5 +++- .../stages/text/deduplication/semantic.py | 6 ++-- .../deduplication/exact/test_workflow.py | 23 +++++++++++++++ .../fuzzy/test_fuzzy_workflow.py | 25 ++++++++++++++++ .../deduplication/test_removal_workflow.py | 29 ++++++++++++++++--- 7 files changed, 88 insertions(+), 11 deletions(-) diff --git a/nemo_curator/stages/deduplication/exact/workflow.py b/nemo_curator/stages/deduplication/exact/workflow.py index 7f14ad02e2..123e681f78 100644 --- a/nemo_curator/stages/deduplication/exact/workflow.py +++ b/nemo_curator/stages/deduplication/exact/workflow.py @@ -30,6 +30,7 @@ ) from nemo_curator.stages.file_partitioning import FilePartitioningStage from nemo_curator.tasks import FileGroupTask +from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS ID_GENERATOR_OUTPUT_FILENAME = "exact_id_generator.json" @@ -151,7 +152,9 @@ def _create_input_filegroups(self) -> Pipeline: stages=[ FilePartitioningStage( file_paths=self.input_path, - file_extensions=self.input_file_extensions, + file_extensions=( + self.input_file_extensions or FILETYPE_TO_DEFAULT_EXTENSIONS[self.input_filetype] + ), blocksize=self.input_blocksize, storage_options=self.read_kwargs.get("storage_options") if self.read_kwargs is not None else None, ), diff --git a/nemo_curator/stages/deduplication/fuzzy/workflow.py b/nemo_curator/stages/deduplication/fuzzy/workflow.py index 4928389139..32bb3b9d8f 100644 --- a/nemo_curator/stages/deduplication/fuzzy/workflow.py +++ b/nemo_curator/stages/deduplication/fuzzy/workflow.py @@ -33,7 +33,7 @@ ) from nemo_curator.stages.file_partitioning import FilePartitioningStage from nemo_curator.tasks import FileGroupTask -from nemo_curator.utils.file_utils import get_fs +from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS, get_fs ID_GENERATOR_OUTPUT_FILENAME = "fuzzy_id_generator.json" @@ -183,7 +183,9 @@ def _create_minhash_pipeline(self, generate_input_filegroups: bool) -> Pipeline: stages.append( FilePartitioningStage( file_paths=self.input_path, - file_extensions=self.input_file_extensions, + file_extensions=( + self.input_file_extensions or FILETYPE_TO_DEFAULT_EXTENSIONS[self.input_filetype] + ), blocksize=self.input_blocksize, storage_options=self.read_kwargs.get("storage_options") if self.read_kwargs is not None else None, ), diff --git a/nemo_curator/stages/text/deduplication/removal_workflow.py b/nemo_curator/stages/text/deduplication/removal_workflow.py index 078b2014d3..2b054bf003 100644 --- a/nemo_curator/stages/text/deduplication/removal_workflow.py +++ b/nemo_curator/stages/text/deduplication/removal_workflow.py @@ -23,6 +23,7 @@ from nemo_curator.stages.base import ProcessingStage from nemo_curator.stages.deduplication.id_generator import CURATOR_DEDUP_ID_STR from nemo_curator.tasks import FileGroupTask +from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS from .removal import TextDuplicatesRemovalStage @@ -84,7 +85,9 @@ def _generate_stages(self, initial_tasks: list[FileGroupTask] | None = None) -> file_paths=self.input_path, files_per_partition=self.input_files_per_partition, blocksize=self.input_blocksize, - file_extensions=self.input_file_extensions, + file_extensions=( + self.input_file_extensions or FILETYPE_TO_DEFAULT_EXTENSIONS[self.input_filetype] + ), storage_options=(self.input_kwargs or {}).get("storage_options"), limit=self.input_task_limit, ) diff --git a/nemo_curator/stages/text/deduplication/semantic.py b/nemo_curator/stages/text/deduplication/semantic.py index a50cfd160e..5bb8729031 100644 --- a/nemo_curator/stages/text/deduplication/semantic.py +++ b/nemo_curator/stages/text/deduplication/semantic.py @@ -45,7 +45,7 @@ from nemo_curator.stages.text.io.reader import JsonlReader, ParquetReader from nemo_curator.stages.text.io.writer import ParquetWriter from nemo_curator.tasks import Task -from nemo_curator.utils.file_utils import create_or_overwrite_dir +from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS, create_or_overwrite_dir @dataclass @@ -245,7 +245,7 @@ def _run_embedding_generation(self, executor: BaseExecutor) -> list[Task]: + [self.text_field] + (self.metadata_fields or []) ), - file_extensions=self.input_file_extensions, + file_extensions=self.input_file_extensions or FILETYPE_TO_DEFAULT_EXTENSIONS[self.input_filetype], _generate_ids=self.use_id_generator, read_kwargs=self.read_kwargs, ) @@ -259,7 +259,7 @@ def _run_embedding_generation(self, executor: BaseExecutor) -> list[Task]: + [self.text_field] + (self.metadata_fields or []) ), - file_extensions=self.input_file_extensions, + file_extensions=self.input_file_extensions or FILETYPE_TO_DEFAULT_EXTENSIONS[self.input_filetype], read_kwargs=self.read_kwargs, _generate_ids=self.use_id_generator, ) diff --git a/tests/stages/deduplication/exact/test_workflow.py b/tests/stages/deduplication/exact/test_workflow.py index dc9139d240..bfdeaee820 100644 --- a/tests/stages/deduplication/exact/test_workflow.py +++ b/tests/stages/deduplication/exact/test_workflow.py @@ -182,6 +182,29 @@ def test_no_dedup(self, exact_no_dedup_data_jsonl: list[FileGroupTask], tmpdir: removal_ids_df = cudf.read_parquet(tmpdir / "ExactDuplicateIds") assert len(removal_ids_df) == 0 + def test_input_file_extensions_default_to_input_filetype(self, tmpdir: Path) -> None: + workflow = ExactDeduplicationWorkflow( + input_path="/dummy", + output_path=str(tmpdir), + input_filetype="jsonl", + ) + + stages = workflow._create_input_filegroups().stages + + assert stages[0].file_extensions == [".jsonl", ".json"] + + def test_input_file_extensions_override_default(self, tmpdir: Path) -> None: + workflow = ExactDeduplicationWorkflow( + input_path="/dummy", + output_path=str(tmpdir), + input_filetype="parquet", + input_file_extensions=[".pq"], + ) + + stages = workflow._create_input_filegroups().stages + + assert stages[0].file_extensions == [".pq"] + def test_bad_inputs(self, tmpdir: Path) -> None: with pytest.raises(NotImplementedError, match="Removal is not implemented"): # Removal is not implemented yet diff --git a/tests/stages/deduplication/fuzzy/test_fuzzy_workflow.py b/tests/stages/deduplication/fuzzy/test_fuzzy_workflow.py index 068fbe649a..b7786271c8 100644 --- a/tests/stages/deduplication/fuzzy/test_fuzzy_workflow.py +++ b/tests/stages/deduplication/fuzzy/test_fuzzy_workflow.py @@ -267,6 +267,31 @@ def test_fuzzy_dedup_no_duplicates( lsh_df = cudf.read_parquet(cache_path / "LSHStage") assert len(lsh_df) == 0 + def test_input_file_extensions_default_to_input_filetype(self, tmp_path: Path) -> None: + workflow = FuzzyDeduplicationWorkflow( + input_path="/dummy", + cache_path=str(tmp_path), + output_path=str(tmp_path), + input_filetype="jsonl", + ) + + stages = workflow._create_minhash_pipeline(generate_input_filegroups=True).stages + + assert stages[0].file_extensions == [".jsonl", ".json"] + + def test_input_file_extensions_override_default(self, tmp_path: Path) -> None: + workflow = FuzzyDeduplicationWorkflow( + input_path="/dummy", + cache_path=str(tmp_path), + output_path=str(tmp_path), + input_filetype="parquet", + input_file_extensions=[".pq"], + ) + + stages = workflow._create_minhash_pipeline(generate_input_filegroups=True).stages + + assert stages[0].file_extensions == [".pq"] + def test_bad_inputs(self, tmp_path: Path) -> None: with pytest.raises(ValueError, match="bands_per_iteration must be between"): # bands_per_iteration must be between 1 and num_bands diff --git a/tests/stages/text/deduplication/test_removal_workflow.py b/tests/stages/text/deduplication/test_removal_workflow.py index 5682931675..5002c8cf52 100644 --- a/tests/stages/text/deduplication/test_removal_workflow.py +++ b/tests/stages/text/deduplication/test_removal_workflow.py @@ -303,9 +303,17 @@ def test_invalid_filetypes(self): with pytest.raises(ValueError, match="Invalid output filetype: invalid"): write_invalid_file_type_workflow._generate_stages(initial_tasks=None) - @pytest.mark.parametrize("input_filetype", ["parquet", "jsonl"]) + @pytest.mark.parametrize( + ("input_filetype", "expected_file_extensions"), + [("parquet", [".parquet"]), ("jsonl", [".jsonl", ".json"])], + ) @pytest.mark.parametrize("id_generator_path", [None, "id_generator_path"]) - def test_reader_stage(self, input_filetype: str, id_generator_path: str | None): + def test_reader_stage( + self, + input_filetype: str, + expected_file_extensions: list[str], + id_generator_path: str | None, + ): workflow = TextDuplicatesRemovalWorkflow( input_path="input_path", ids_to_remove_path="ids_to_remove_path", @@ -322,8 +330,7 @@ def test_reader_stage(self, input_filetype: str, id_generator_path: str | None): assert stages[0].file_paths == "input_path" assert stages[0].files_per_partition is None assert stages[0].blocksize is None - # post init of FilePartitioningStage sets this - assert stages[0].file_extensions == [".jsonl", ".json", ".parquet"] + assert stages[0].file_extensions == expected_file_extensions assert stages[0].storage_options == {} # test for reader stage (stages[1]) @@ -344,6 +351,20 @@ def test_reader_stage(self, input_filetype: str, id_generator_path: str | None): # test for writer stage (stages[3]) - default output_filetype is parquet assert isinstance(stages[3], ParquetWriter) + def test_reader_stage_with_custom_input_file_extensions(self): + workflow = TextDuplicatesRemovalWorkflow( + input_path="input_path", + ids_to_remove_path="ids_to_remove_path", + output_path="output_path", + input_filetype="parquet", + input_file_extensions=[".pq"], + id_generator_path=None, + ) + + stages = workflow._generate_stages(initial_tasks=None) + + assert stages[0].file_extensions == [".pq"] + @pytest.mark.parametrize("output_filetype", ["parquet", "jsonl"]) def test_writer_stage(self, output_filetype: str): workflow = TextDuplicatesRemovalWorkflow( From aff052987b2d2416a93a746827e01e436bd1d21b Mon Sep 17 00:00:00 2001 From: nightcityblade Date: Wed, 3 Jun 2026 23:12:52 +0800 Subject: [PATCH 2/3] fix: validate default input file extensions Signed-off-by: nightcityblade --- .../stages/deduplication/exact/workflow.py | 6 +- .../stages/deduplication/fuzzy/workflow.py | 6 +- .../stages/deduplication/semantic/kmeans.py | 4 +- .../text/deduplication/removal_workflow.py | 6 +- .../stages/text/deduplication/semantic.py | 6 +- nemo_curator/utils/file_utils.py | 9 ++ .../text/deduplication/test_semantic.py | 82 ++++++++++++++++++- 7 files changed, 99 insertions(+), 20 deletions(-) diff --git a/nemo_curator/stages/deduplication/exact/workflow.py b/nemo_curator/stages/deduplication/exact/workflow.py index 123e681f78..5e37e3612a 100644 --- a/nemo_curator/stages/deduplication/exact/workflow.py +++ b/nemo_curator/stages/deduplication/exact/workflow.py @@ -30,7 +30,7 @@ ) from nemo_curator.stages.file_partitioning import FilePartitioningStage from nemo_curator.tasks import FileGroupTask -from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS +from nemo_curator.utils.file_utils import get_default_file_extensions ID_GENERATOR_OUTPUT_FILENAME = "exact_id_generator.json" @@ -152,9 +152,7 @@ def _create_input_filegroups(self) -> Pipeline: stages=[ FilePartitioningStage( file_paths=self.input_path, - file_extensions=( - self.input_file_extensions or FILETYPE_TO_DEFAULT_EXTENSIONS[self.input_filetype] - ), + file_extensions=(self.input_file_extensions or get_default_file_extensions(self.input_filetype)), blocksize=self.input_blocksize, storage_options=self.read_kwargs.get("storage_options") if self.read_kwargs is not None else None, ), diff --git a/nemo_curator/stages/deduplication/fuzzy/workflow.py b/nemo_curator/stages/deduplication/fuzzy/workflow.py index 32bb3b9d8f..f2157c17d1 100644 --- a/nemo_curator/stages/deduplication/fuzzy/workflow.py +++ b/nemo_curator/stages/deduplication/fuzzy/workflow.py @@ -33,7 +33,7 @@ ) from nemo_curator.stages.file_partitioning import FilePartitioningStage from nemo_curator.tasks import FileGroupTask -from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS, get_fs +from nemo_curator.utils.file_utils import get_default_file_extensions, get_fs ID_GENERATOR_OUTPUT_FILENAME = "fuzzy_id_generator.json" @@ -183,9 +183,7 @@ def _create_minhash_pipeline(self, generate_input_filegroups: bool) -> Pipeline: stages.append( FilePartitioningStage( file_paths=self.input_path, - file_extensions=( - self.input_file_extensions or FILETYPE_TO_DEFAULT_EXTENSIONS[self.input_filetype] - ), + file_extensions=(self.input_file_extensions or get_default_file_extensions(self.input_filetype)), blocksize=self.input_blocksize, storage_options=self.read_kwargs.get("storage_options") if self.read_kwargs is not None else None, ), diff --git a/nemo_curator/stages/deduplication/semantic/kmeans.py b/nemo_curator/stages/deduplication/semantic/kmeans.py index d530d18a61..5c39bf36f0 100644 --- a/nemo_curator/stages/deduplication/semantic/kmeans.py +++ b/nemo_curator/stages/deduplication/semantic/kmeans.py @@ -25,7 +25,7 @@ from nemo_curator.stages.resources import Resources from nemo_curator.stages.text.embedders.utils import create_list_series_from_1d_or_2d_ar from nemo_curator.tasks import FileGroupTask, _EmptyTask -from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS, check_disallowed_kwargs +from nemo_curator.utils.file_utils import check_disallowed_kwargs, get_default_file_extensions from .utils import break_parquet_partition_into_groups, get_array_from_df @@ -344,7 +344,7 @@ def __post_init__(self): def decompose(self) -> list[ProcessingStage]: # Set default file extensions based on input_filetype if not provided - file_extensions = self.input_file_extensions or FILETYPE_TO_DEFAULT_EXTENSIONS.get(self.input_filetype, []) + file_extensions = self.input_file_extensions or get_default_file_extensions(self.input_filetype) if not file_extensions: msg = f"Unsupported filetype: {self.input_filetype}" raise ValueError(msg) diff --git a/nemo_curator/stages/text/deduplication/removal_workflow.py b/nemo_curator/stages/text/deduplication/removal_workflow.py index 2b054bf003..f32cccecde 100644 --- a/nemo_curator/stages/text/deduplication/removal_workflow.py +++ b/nemo_curator/stages/text/deduplication/removal_workflow.py @@ -23,7 +23,7 @@ from nemo_curator.stages.base import ProcessingStage from nemo_curator.stages.deduplication.id_generator import CURATOR_DEDUP_ID_STR from nemo_curator.tasks import FileGroupTask -from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS +from nemo_curator.utils.file_utils import get_default_file_extensions from .removal import TextDuplicatesRemovalStage @@ -85,9 +85,7 @@ def _generate_stages(self, initial_tasks: list[FileGroupTask] | None = None) -> file_paths=self.input_path, files_per_partition=self.input_files_per_partition, blocksize=self.input_blocksize, - file_extensions=( - self.input_file_extensions or FILETYPE_TO_DEFAULT_EXTENSIONS[self.input_filetype] - ), + file_extensions=(self.input_file_extensions or get_default_file_extensions(self.input_filetype)), storage_options=(self.input_kwargs or {}).get("storage_options"), limit=self.input_task_limit, ) diff --git a/nemo_curator/stages/text/deduplication/semantic.py b/nemo_curator/stages/text/deduplication/semantic.py index 5bb8729031..7f4d69c583 100644 --- a/nemo_curator/stages/text/deduplication/semantic.py +++ b/nemo_curator/stages/text/deduplication/semantic.py @@ -45,7 +45,7 @@ from nemo_curator.stages.text.io.reader import JsonlReader, ParquetReader from nemo_curator.stages.text.io.writer import ParquetWriter from nemo_curator.tasks import Task -from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS, create_or_overwrite_dir +from nemo_curator.utils.file_utils import create_or_overwrite_dir, get_default_file_extensions @dataclass @@ -245,7 +245,7 @@ def _run_embedding_generation(self, executor: BaseExecutor) -> list[Task]: + [self.text_field] + (self.metadata_fields or []) ), - file_extensions=self.input_file_extensions or FILETYPE_TO_DEFAULT_EXTENSIONS[self.input_filetype], + file_extensions=self.input_file_extensions or get_default_file_extensions(self.input_filetype), _generate_ids=self.use_id_generator, read_kwargs=self.read_kwargs, ) @@ -259,7 +259,7 @@ def _run_embedding_generation(self, executor: BaseExecutor) -> list[Task]: + [self.text_field] + (self.metadata_fields or []) ), - file_extensions=self.input_file_extensions or FILETYPE_TO_DEFAULT_EXTENSIONS[self.input_filetype], + file_extensions=self.input_file_extensions or get_default_file_extensions(self.input_filetype), read_kwargs=self.read_kwargs, _generate_ids=self.use_id_generator, ) diff --git a/nemo_curator/utils/file_utils.py b/nemo_curator/utils/file_utils.py index 1c3341c6f0..8611909f0a 100644 --- a/nemo_curator/utils/file_utils.py +++ b/nemo_curator/utils/file_utils.py @@ -39,6 +39,15 @@ } +def get_default_file_extensions(input_filetype: str) -> list[str]: + """Return default file extensions for an input file type.""" + file_extensions = FILETYPE_TO_DEFAULT_EXTENSIONS.get(input_filetype) + if file_extensions is None: + msg = f"Unsupported filetype: {input_filetype}" + raise ValueError(msg) + return file_extensions + + def get_fs(path: str, storage_options: dict[str, str] | None = None) -> fsspec.AbstractFileSystem: if not storage_options: storage_options = {} diff --git a/tests/stages/text/deduplication/test_semantic.py b/tests/stages/text/deduplication/test_semantic.py index d9907e8d90..b8a58512e6 100644 --- a/tests/stages/text/deduplication/test_semantic.py +++ b/tests/stages/text/deduplication/test_semantic.py @@ -15,7 +15,7 @@ import os from contextlib import suppress from pathlib import Path -from typing import Any +from typing import Any, Literal import pandas as pd import pytest @@ -27,10 +27,11 @@ # Suppress GPU-related import errors when running pytest -m "not gpu" with suppress(ImportError): + from nemo_curator.stages.text.deduplication import semantic from nemo_curator.stages.text.deduplication.semantic import TextSemanticDeduplicationWorkflow -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def ensure_semantic_model_downloaded() -> None: """Pre-download the model once per session to avoid rate limiting in CI.""" try: @@ -69,6 +70,78 @@ def create_data_with_duplicates(input_dir: Path) -> pd.DataFrame: return df +@pytest.mark.parametrize( + ("input_filetype", "expected_extensions"), + [ + ("jsonl", [".jsonl", ".json"]), + ("parquet", [".parquet"]), + ], +) +def test_embedding_reader_extensions_default_to_input_filetype( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + input_filetype: Literal["jsonl", "parquet"], + expected_extensions: list[str], +) -> None: + captured_stages = [] + + def capture_pipeline_run(self, executor) -> list[object]: # noqa: ANN001, ARG001 + captured_stages.extend(self.stages) + return [] + + monkeypatch.setattr(semantic.Pipeline, "run", capture_pipeline_run) + + workflow = TextSemanticDeduplicationWorkflow( + input_path="/dummy", + output_path=str(tmp_path / "output"), + cache_path=str(tmp_path / "cache"), + input_filetype=input_filetype, + ) + + workflow._run_embedding_generation(executor=object()) + + assert captured_stages[0].file_extensions == expected_extensions + + +def test_embedding_reader_extensions_override_default(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + captured_stages = [] + + def capture_pipeline_run(self, executor) -> list[object]: # noqa: ANN001, ARG001 + captured_stages.extend(self.stages) + return [] + + monkeypatch.setattr(semantic.Pipeline, "run", capture_pipeline_run) + + workflow = TextSemanticDeduplicationWorkflow( + input_path="/dummy", + output_path=str(tmp_path / "output"), + cache_path=str(tmp_path / "cache"), + input_filetype="parquet", + input_file_extensions=[".pq"], + ) + + workflow._run_embedding_generation(executor=object()) + + assert captured_stages[0].file_extensions == [".pq"] + + +def test_embedding_reader_unsupported_filetype_error(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + def fail_pipeline_run(self, executor) -> None: # noqa: ANN001, ARG001 + pytest.fail("Pipeline should not run when input_filetype is unsupported") + + monkeypatch.setattr(semantic.Pipeline, "run", fail_pipeline_run) + + workflow = TextSemanticDeduplicationWorkflow( + input_path="/dummy", + output_path=str(tmp_path / "output"), + cache_path=str(tmp_path / "cache"), + input_filetype="csv", # type: ignore[arg-type] + ) + + with pytest.raises(NotImplementedError, match="Input filetype csv not supported yet"): + workflow._run_embedding_generation(executor=object()) + + @pytest.mark.gpu @pytest.mark.parametrize( "test_config", @@ -95,7 +168,10 @@ class TestTextSemanticDeduplicationWorkflow: @pytest.fixture(scope="class", autouse=True) def test_config( - self, request: pytest.FixtureRequest, tmp_path_factory: pytest.TempPathFactory + self, + request: pytest.FixtureRequest, + tmp_path_factory: pytest.TempPathFactory, + ensure_semantic_model_downloaded: None, ) -> "TestTextSemanticDeduplicationWorkflow": """Set up test environment and execute workflow.""" executor_cls, config, use_id_generator = request.param From 19a7290b1c7ca9fae01029673508432b830faedc Mon Sep 17 00:00:00 2001 From: nightcityblade Date: Thu, 4 Jun 2026 23:07:55 +0800 Subject: [PATCH 3/3] test: cover kmeans input file extension defaults --- .../stages/deduplication/semantic/kmeans.py | 3 - .../deduplication/semantic/test_kmeans.py | 58 +++++++++++++++++++ 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/nemo_curator/stages/deduplication/semantic/kmeans.py b/nemo_curator/stages/deduplication/semantic/kmeans.py index 5c39bf36f0..d298a4d63f 100644 --- a/nemo_curator/stages/deduplication/semantic/kmeans.py +++ b/nemo_curator/stages/deduplication/semantic/kmeans.py @@ -345,9 +345,6 @@ def __post_init__(self): def decompose(self) -> list[ProcessingStage]: # Set default file extensions based on input_filetype if not provided file_extensions = self.input_file_extensions or get_default_file_extensions(self.input_filetype) - if not file_extensions: - msg = f"Unsupported filetype: {self.input_filetype}" - raise ValueError(msg) return [ FilePartitioningStage( diff --git a/tests/stages/deduplication/semantic/test_kmeans.py b/tests/stages/deduplication/semantic/test_kmeans.py index efadfce2fa..8f9ddd1790 100644 --- a/tests/stages/deduplication/semantic/test_kmeans.py +++ b/tests/stages/deduplication/semantic/test_kmeans.py @@ -143,6 +143,64 @@ def run_single_gpu_baseline( return df.sort_values("id", ignore_index=True)["centroid"].to_numpy() +class TestKMeansStage: + """Unit tests for KMeansStage decomposition.""" + + @pytest.mark.parametrize( + ("input_filetype", "expected_extensions"), + [ + ("parquet", [".parquet"]), + ("jsonl", [".jsonl", ".json"]), + ], + ) + def test_input_file_extensions_default_to_input_filetype( + self, + tmp_path: Path, + input_filetype: Literal["parquet", "jsonl"], + expected_extensions: list[str], + ) -> None: + stage = KMeansStage( + id_field="id", + embedding_field="embeddings", + n_clusters=2, + input_path=str(tmp_path / "input"), + output_path=str(tmp_path / "output"), + input_filetype=input_filetype, + ) + + stages = stage.decompose() + + assert stages[0].file_extensions == expected_extensions + + def test_input_file_extensions_override_default(self, tmp_path: Path) -> None: + stage = KMeansStage( + id_field="id", + embedding_field="embeddings", + n_clusters=2, + input_path=str(tmp_path / "input"), + output_path=str(tmp_path / "output"), + input_filetype="parquet", + input_file_extensions=[".pq"], + ) + + stages = stage.decompose() + + assert stages[0].file_extensions == [".pq"] + + def test_unsupported_input_filetype_raises(self, tmp_path: Path) -> None: + stage = KMeansStage( + id_field="id", + embedding_field="embeddings", + n_clusters=2, + input_path=str(tmp_path / "input"), + output_path=str(tmp_path / "output"), + input_filetype="csv", # type: ignore[arg-type] + ) + + with pytest.raises(ValueError, match="Unsupported filetype: csv"): + stage.decompose() + + @pytest.mark.gpu class TestKMeansStageIntegration: """Integration tests for KMeansStage comparing multi-GPU vs single-GPU results."""