Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion nemo_curator/stages/deduplication/exact/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from nemo_curator.stages.file_partitioning import FilePartitioningStage
from nemo_curator.tasks import FileGroupTask
from nemo_curator.utils.file_utils import get_default_file_extensions

ID_GENERATOR_OUTPUT_FILENAME = "exact_id_generator.json"

Expand Down Expand Up @@ -151,7 +152,7 @@ def _create_input_filegroups(self) -> Pipeline:
stages=[
FilePartitioningStage(
file_paths=self.input_path,
file_extensions=self.input_file_extensions,
file_extensions=(self.input_file_extensions or get_default_file_extensions(self.input_filetype)),
blocksize=self.input_blocksize,
storage_options=self.read_kwargs.get("storage_options") if self.read_kwargs is not None else None,
),
Expand Down
4 changes: 2 additions & 2 deletions nemo_curator/stages/deduplication/fuzzy/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
from nemo_curator.stages.file_partitioning import FilePartitioningStage
from nemo_curator.tasks import FileGroupTask
from nemo_curator.utils.file_utils import get_fs
from nemo_curator.utils.file_utils import get_default_file_extensions, get_fs

ID_GENERATOR_OUTPUT_FILENAME = "fuzzy_id_generator.json"

Expand Down Expand Up @@ -203,7 +203,7 @@ def _create_minhash_pipeline(self, generate_input_filegroups: bool) -> Pipeline:
stages.append(
FilePartitioningStage(
file_paths=self.input_path,
file_extensions=self.input_file_extensions,
file_extensions=(self.input_file_extensions or get_default_file_extensions(self.input_filetype)),
blocksize=self.input_blocksize,
storage_options=self.read_kwargs.get("storage_options") if self.read_kwargs is not None else None,
),
Expand Down
7 changes: 2 additions & 5 deletions nemo_curator/stages/deduplication/semantic/kmeans.py

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a test for this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added focused KMeansStage decomposition tests in tests/stages/deduplication/semantic/test_kmeans.py covering default extensions for parquet/jsonl, custom extension overrides, and unsupported filetype errors.

Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from nemo_curator.stages.resources import Resources
from nemo_curator.stages.text.embedders.utils import create_list_series_from_1d_or_2d_ar
from nemo_curator.tasks import EmptyTask, FileGroupTask
from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS, check_disallowed_kwargs
from nemo_curator.utils.file_utils import check_disallowed_kwargs, get_default_file_extensions

from .utils import break_parquet_partition_into_groups, get_array_from_df

Expand Down Expand Up @@ -542,10 +542,7 @@ def __post_init__(self):

def decompose(self) -> list[ProcessingStage]:
# Set default file extensions based on input_filetype if not provided
file_extensions = self.input_file_extensions or FILETYPE_TO_DEFAULT_EXTENSIONS.get(self.input_filetype, [])
if not file_extensions:
msg = f"Unsupported filetype: {self.input_filetype}"
raise ValueError(msg)
file_extensions = self.input_file_extensions or get_default_file_extensions(self.input_filetype)

return [
FilePartitioningStage(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from nemo_curator.stages.base import ProcessingStage
from nemo_curator.stages.deduplication.id_generator import CURATOR_DEDUP_ID_STR
from nemo_curator.tasks import FileGroupTask
from nemo_curator.utils.file_utils import get_default_file_extensions

from .removal import TextDuplicatesRemovalStage

Expand Down Expand Up @@ -84,7 +85,7 @@ def _generate_stages(self, initial_tasks: list[FileGroupTask] | None = None) ->
file_paths=self.input_path,
files_per_partition=self.input_files_per_partition,
blocksize=self.input_blocksize,
file_extensions=self.input_file_extensions,
file_extensions=(self.input_file_extensions or get_default_file_extensions(self.input_filetype)),
storage_options=(self.input_kwargs or {}).get("storage_options"),
limit=self.input_task_limit,
)
Expand Down
6 changes: 3 additions & 3 deletions nemo_curator/stages/text/deduplication/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from nemo_curator.stages.text.io.reader import JsonlReader, ParquetReader
from nemo_curator.stages.text.io.writer import ParquetWriter
from nemo_curator.tasks import Task
from nemo_curator.utils.file_utils import create_or_overwrite_dir
from nemo_curator.utils.file_utils import create_or_overwrite_dir, get_default_file_extensions


@dataclass
Expand Down Expand Up @@ -249,7 +249,7 @@ def _run_embedding_generation(self, executor: BaseExecutor) -> list[Task]:
+ [self.text_field]
+ (self.metadata_fields or [])
),
file_extensions=self.input_file_extensions,
file_extensions=self.input_file_extensions or get_default_file_extensions(self.input_filetype),
_generate_ids=self.use_id_generator,
read_kwargs=self.read_kwargs,
)
Expand All @@ -263,7 +263,7 @@ def _run_embedding_generation(self, executor: BaseExecutor) -> list[Task]:
+ [self.text_field]
+ (self.metadata_fields or [])
),
file_extensions=self.input_file_extensions,
file_extensions=self.input_file_extensions or get_default_file_extensions(self.input_filetype),
read_kwargs=self.read_kwargs,
_generate_ids=self.use_id_generator,
)
Expand Down
9 changes: 9 additions & 0 deletions nemo_curator/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@
}


def get_default_file_extensions(input_filetype: str) -> list[str]:
"""Return default file extensions for an input file type."""
file_extensions = FILETYPE_TO_DEFAULT_EXTENSIONS.get(input_filetype)
if file_extensions is None:
msg = f"Unsupported filetype: {input_filetype}"
raise ValueError(msg)
return file_extensions


def get_fs(path: str, storage_options: dict[str, str] | None = None) -> fsspec.AbstractFileSystem:
if not storage_options:
storage_options = {}
Expand Down
23 changes: 23 additions & 0 deletions tests/stages/deduplication/exact/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,29 @@ def test_no_dedup(self, exact_no_dedup_data_jsonl: list[FileGroupTask], tmpdir:
removal_ids_df = cudf.read_parquet(tmpdir / "ExactDuplicateIds")
assert len(removal_ids_df) == 0

def test_input_file_extensions_default_to_input_filetype(self, tmpdir: Path) -> None:
workflow = ExactDeduplicationWorkflow(
input_path="/dummy",
output_path=str(tmpdir),
input_filetype="jsonl",
)

stages = workflow._create_input_filegroups().stages

assert stages[0].file_extensions == [".jsonl", ".json"]

def test_input_file_extensions_override_default(self, tmpdir: Path) -> None:
workflow = ExactDeduplicationWorkflow(
input_path="/dummy",
output_path=str(tmpdir),
input_filetype="parquet",
input_file_extensions=[".pq"],
)

stages = workflow._create_input_filegroups().stages

assert stages[0].file_extensions == [".pq"]

def test_bad_inputs(self, tmpdir: Path) -> None:
with pytest.raises(NotImplementedError, match="Removal is not implemented"):
# Removal is not implemented yet
Expand Down
25 changes: 25 additions & 0 deletions tests/stages/deduplication/fuzzy/test_fuzzy_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,31 @@ def test_fuzzy_dedup_no_duplicates(
lsh_df = cudf.read_parquet(cache_path / "LSHStage")
assert len(lsh_df) == 0

def test_input_file_extensions_default_to_input_filetype(self, tmp_path: Path) -> None:
workflow = FuzzyDeduplicationWorkflow(
input_path="/dummy",
cache_path=str(tmp_path),
output_path=str(tmp_path),
input_filetype="jsonl",
)

stages = workflow._create_minhash_pipeline(generate_input_filegroups=True).stages

assert stages[0].file_extensions == [".jsonl", ".json"]

def test_input_file_extensions_override_default(self, tmp_path: Path) -> None:
workflow = FuzzyDeduplicationWorkflow(
input_path="/dummy",
cache_path=str(tmp_path),
output_path=str(tmp_path),
input_filetype="parquet",
input_file_extensions=[".pq"],
)

stages = workflow._create_minhash_pipeline(generate_input_filegroups=True).stages

assert stages[0].file_extensions == [".pq"]

def test_bad_inputs(self, tmp_path: Path) -> None:
with pytest.raises(ValueError, match="bands_per_iteration must be between"):
# bands_per_iteration must be between 1 and num_bands
Expand Down
58 changes: 58 additions & 0 deletions tests/stages/deduplication/semantic/test_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,64 @@ def run_single_gpu_baseline(
return df.sort_values("id", ignore_index=True)["centroid"].to_numpy()


class TestKMeansStage:
"""Unit tests for KMeansStage decomposition."""

@pytest.mark.parametrize(
("input_filetype", "expected_extensions"),
[
("parquet", [".parquet"]),
("jsonl", [".jsonl", ".json"]),
],
)
def test_input_file_extensions_default_to_input_filetype(
self,
tmp_path: Path,
input_filetype: Literal["parquet", "jsonl"],
expected_extensions: list[str],
) -> None:
stage = KMeansStage(
id_field="id",
embedding_field="embeddings",
n_clusters=2,
input_path=str(tmp_path / "input"),
output_path=str(tmp_path / "output"),
input_filetype=input_filetype,
)

stages = stage.decompose()

assert stages[0].file_extensions == expected_extensions

def test_input_file_extensions_override_default(self, tmp_path: Path) -> None:
stage = KMeansStage(
id_field="id",
embedding_field="embeddings",
n_clusters=2,
input_path=str(tmp_path / "input"),
output_path=str(tmp_path / "output"),
input_filetype="parquet",
input_file_extensions=[".pq"],
)

stages = stage.decompose()

assert stages[0].file_extensions == [".pq"]

def test_unsupported_input_filetype_raises(self, tmp_path: Path) -> None:
stage = KMeansStage(
id_field="id",
embedding_field="embeddings",
n_clusters=2,
input_path=str(tmp_path / "input"),
output_path=str(tmp_path / "output"),
input_filetype="csv", # type: ignore[arg-type]
)

with pytest.raises(ValueError, match="Unsupported filetype: csv"):
stage.decompose()


@pytest.mark.gpu
class TestKMeansStageIntegration:
"""Integration tests for KMeansStage comparing multi-GPU vs single-GPU results."""
Expand Down
29 changes: 25 additions & 4 deletions tests/stages/text/deduplication/test_removal_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,17 @@ def test_invalid_filetypes(self):
with pytest.raises(ValueError, match="Invalid output filetype: invalid"):
write_invalid_file_type_workflow._generate_stages(initial_tasks=None)

@pytest.mark.parametrize("input_filetype", ["parquet", "jsonl"])
@pytest.mark.parametrize(
("input_filetype", "expected_file_extensions"),
[("parquet", [".parquet"]), ("jsonl", [".jsonl", ".json"])],
)
@pytest.mark.parametrize("id_generator_path", [None, "id_generator_path"])
def test_reader_stage(self, input_filetype: str, id_generator_path: str | None):
def test_reader_stage(
self,
input_filetype: str,
expected_file_extensions: list[str],
id_generator_path: str | None,
):
workflow = TextDuplicatesRemovalWorkflow(
input_path="input_path",
ids_to_remove_path="ids_to_remove_path",
Expand All @@ -322,8 +330,7 @@ def test_reader_stage(self, input_filetype: str, id_generator_path: str | None):
assert stages[0].file_paths == "input_path"
assert stages[0].files_per_partition is None
assert stages[0].blocksize is None
# post init of FilePartitioningStage sets this
assert stages[0].file_extensions == [".jsonl", ".json", ".parquet"]
assert stages[0].file_extensions == expected_file_extensions
assert stages[0].storage_options == {}

# test for reader stage (stages[1])
Expand All @@ -344,6 +351,20 @@ def test_reader_stage(self, input_filetype: str, id_generator_path: str | None):
# test for writer stage (stages[3]) - default output_filetype is parquet
assert isinstance(stages[3], ParquetWriter)

def test_reader_stage_with_custom_input_file_extensions(self):
workflow = TextDuplicatesRemovalWorkflow(
input_path="input_path",
ids_to_remove_path="ids_to_remove_path",
output_path="output_path",
input_filetype="parquet",
input_file_extensions=[".pq"],
id_generator_path=None,
)

stages = workflow._generate_stages(initial_tasks=None)

assert stages[0].file_extensions == [".pq"]

@pytest.mark.parametrize("output_filetype", ["parquet", "jsonl"])
def test_writer_stage(self, output_filetype: str):
workflow = TextDuplicatesRemovalWorkflow(
Expand Down
Loading
Loading