diff --git a/nemo_curator/backends/base.py b/nemo_curator/backends/base.py index e302a37eb6..94236f2cfc 100644 --- a/nemo_curator/backends/base.py +++ b/nemo_curator/backends/base.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import uuid from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Any @@ -84,6 +85,9 @@ def process_batch(self, tasks: list[Task]) -> list[Task]: # Use the batch processing logic results = self.stage.process_batch(tasks) + # Guarantee every emitted task has a task_id (derived id, or uuid fallback). + results = self._post_process_task_ids(tasks, results) + # Log performance stats and add to result tasks _, stage_perf_stats = self._timer.log_stats() # Consume and attach any custom metrics recorded by the stage during this call @@ -95,6 +99,75 @@ def process_batch(self, tasks: list[Task]) -> list[Task]: return results + def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Task | None]) -> list[Task]: + """Assign a deterministic ``task_id`` to every emitted task. + + This is the single place task ids are assigned — it runs for every + stage on every backend (all backend adapters subclass this), so it + makes no difference whether a stage defines ``process`` or overrides + ``process_batch``. ``task_id`` is the task's id path (parents + own segment); ids are + re-derived at each stage boundary so the same object passing through + N stages gets N ids. + + The input→output mapping decides each output's PARENT; whether the + stage is a source decides each output's SEGMENT (content id vs index) + — the two are independent. ``None`` outputs (Curator's "return None to + filter") are NOT removed before the length check — keeping them in + place preserves positional alignment for filter stages — and are then + dropped from the returned list. 
+ + - single input → every output is its child (fan-out): ``parent_{seg}`` + - ``len(output) == len(input)`` → positional 1:1: each ``parent_i_{seg}``; + a ``None`` slot just means input ``i`` was filtered. + - any other (ambiguous) cardinality across a batch → a random ``uuid`` + prefixed with ``"r"`` (e.g. ``"r3f9a…"``), so ``task_id`` is never + empty even when a derived id is not possible. The ``"r"`` prefix flags + the id as non-deterministic / ancestry-not-tracked (see + ``Task.task_id`` docstring). + + ``seg`` is the output's content id (``Task.get_deterministic_id()``) + for a source stage when available, else the positional index — so a + source partition keeps a stable id across reorderings regardless of + whether the source is 1→N or N→N. + + Note: a stage that BOTH filters and fans out within a single batch + (returning a flat list rather than a per-input slot) cannot be mapped + positionally; if its length happens to equal the input length the 1:1 + assumption may misattribute parents. That combination is unsupported + until per-slot sentinels (NoneTask/FailedTask) land in a later PR. + """ + is_source = getattr(self.stage, "is_source_stage", False) + + if len(input_tasks) == 1: + # Fan-out (incl. a source reading from EmptyTask): every non-None + # output is a child of the single input. + parent_id = input_tasks[0].task_id + out: list[Task] = [t for t in output_tasks if t is not None] + for i, task in enumerate(out): + suffix = (task.get_deterministic_id() or i) if is_source else i + task._set_task_id(parent_id, suffix) + return out + + if len(output_tasks) == len(input_tasks): + # Positional 1:1. None is kept above so a filtered slot still lines + # up with its own parent; drop the None slots from the result. 
+ out = [] + for parent, task in zip(input_tasks, output_tasks, strict=True): + if task is None: + continue + suffix = (task.get_deterministic_id() or 0) if is_source else 0 + task._set_task_id(parent.task_id, suffix) + out.append(task) + return out + + # Ambiguous cardinality across a batch: a derived id is not possible. Use a + # random "r"-prefixed uuid so task_id is non-empty but clearly flagged + # non-deterministic. + out = [t for t in output_tasks if t is not None] + for task in out: + task.task_id = "r" + uuid.uuid4().hex + return out + def setup_on_node(self, node_info: NodeInfo | None = None, worker_metadata: WorkerMetadata | None = None) -> None: """Setup the stage on a node. diff --git a/nemo_curator/models/transnetv2.py b/nemo_curator/models/transnetv2.py index 0f334aa114..e1ae590503 100644 --- a/nemo_curator/models/transnetv2.py +++ b/nemo_curator/models/transnetv2.py @@ -39,6 +39,7 @@ _TRANSNETV2_MODEL_WEIGHTS: Final = "transnetv2-pytorch-weights.pth" _TRANSNETV2_MODEL_REVISION: Final = "db6ceab" + class _TransNetV2(nn.Module): def __init__( # noqa: PLR0913 self, diff --git a/nemo_curator/pipeline/pipeline.py b/nemo_curator/pipeline/pipeline.py index 246ffcffc1..eefc28dc70 100644 --- a/nemo_curator/pipeline/pipeline.py +++ b/nemo_curator/pipeline/pipeline.py @@ -19,6 +19,31 @@ from nemo_curator.backends.base import BaseExecutor from nemo_curator.stages.base import CompositeStage, ProcessingStage from nemo_curator.tasks import Task +from nemo_curator.tasks.tasks import _EmptyTask + + +def assign_root_task_ids(initial_tasks: list[Task]) -> list[Task]: + """Assign root ``task_id``s to user-provided initial tasks. + + Every task in a run descends from the implicit root ``"0"`` (the id of + :class:`_EmptyTask`). User-provided initial tasks are its direct + children, so they get ``"0_0"``, ``"0_1"``, … ``_EmptyTask`` instances + are skipped (already ``"0"``). All downstream ``task_id`` assignment + happens in ``BaseStageAdapter``. 
+ + NOTE: we deliberately use the positional index here, NOT + ``get_deterministic_id()``, even for content-bearing tasks like + ``FileGroupTask``. The source stage is the single place content-based + ids are assigned (to its outputs); hashing here too would put the + content hash at two levels of the id path (``"0_{hash}_{hash}"``). + Passing initial tasks directly is rare; if you need reorder-stable + source ids, let a source stage emit them. + """ + for i, task in enumerate(initial_tasks): + if isinstance(task, _EmptyTask): + continue + task._set_task_id("0", i) + return initial_tasks class Pipeline: @@ -80,6 +105,30 @@ def build(self) -> None: self.stages = execution_stages self.decomposition_info = decomposition_info + # 3. Source / sink defaults: at most one stage may be explicitly + # marked; if none, the first stage is the source and the last is + # the sink. The source flag activates content-based ids in + # ``BaseStageAdapter._post_process_task_ids``; the sink flag is used + # by the resumability layer in a follow-up PR. + self._assign_source_sink_roles() + + def _assign_source_sink_roles(self) -> None: + explicit_sources = [s for s in self.stages if s.is_source_stage] + if len(explicit_sources) > 1: + names = [s.name for s in explicit_sources] + msg = f"Pipeline has multiple source stages marked: {names}. At most one is supported." + raise ValueError(msg) + if not explicit_sources: + self.stages[0].is_source_stage = True + + explicit_sinks = [s for s in self.stages if s.is_sink_stage] + if len(explicit_sinks) > 1: + names = [s.name for s in explicit_sinks] + msg = f"Pipeline has multiple sink stages marked: {names}. At most one is supported." 
+ raise ValueError(msg) + if not explicit_sinks: + self.stages[-1].is_sink_stage = True + def _decompose_stages( self, stages: list[ProcessingStage | CompositeStage] ) -> tuple[list[ProcessingStage], dict[str, list[str]]]: @@ -212,4 +261,7 @@ def run(self, executor: BaseExecutor | None = None, initial_tasks: list[Task] | "The executor will schedule GPU stages on GPUs not held by Serve." ) + if initial_tasks: + assign_root_task_ids(initial_tasks) + return executor.execute(self.stages, initial_tasks) diff --git a/nemo_curator/stages/audio/advanced_pipelines/audio_data_filter/config.py b/nemo_curator/stages/audio/advanced_pipelines/audio_data_filter/config.py index 17b5a6217c..c40f2ca59e 100644 --- a/nemo_curator/stages/audio/advanced_pipelines/audio_data_filter/config.py +++ b/nemo_curator/stages/audio/advanced_pipelines/audio_data_filter/config.py @@ -122,8 +122,15 @@ def _validate(cfg: dict[str, Any]) -> None: # noqa: C901 sigmos = cfg.get("sigmos", {}) if sigmos.get("enable", True): - for key in ("noise_threshold", "ovrl_threshold", "sig_threshold", - "col_threshold", "disc_threshold", "loud_threshold", "reverb_threshold"): + for key in ( + "noise_threshold", + "ovrl_threshold", + "sig_threshold", + "col_threshold", + "disc_threshold", + "loud_threshold", + "reverb_threshold", + ): val = sigmos.get(key) if val is not None and not 0.0 <= val <= _MOS_MAX: msg = f"sigmos.{key} must be in [0, {_MOS_MAX}] (MOS scale), got {val}" diff --git a/nemo_curator/stages/audio/alm/pretrain/extraction.py b/nemo_curator/stages/audio/alm/pretrain/extraction.py index 1fcd93bbd7..1aa70c1661 100644 --- a/nemo_curator/stages/audio/alm/pretrain/extraction.py +++ b/nemo_curator/stages/audio/alm/pretrain/extraction.py @@ -179,9 +179,7 @@ def process(self, task: AudioTask) -> list[AudioTask]: ) return outputs - def _extract_emit( - self, task: AudioTask, plan: list[dict], original_id: str - ) -> list[AudioTask]: + def _extract_emit(self, task: AudioTask, plan: list[dict], original_id: str) 
-> list[AudioTask]: source_path = task.data.get(self.audio_filepath_key) if not source_path or not os.path.exists(source_path): logger.error( @@ -204,9 +202,7 @@ def _extract_emit( outputs.append(self._make_stub_task(task)) return outputs - def _dry_run_emit( - self, task: AudioTask, plan: list[dict], original_id: str - ) -> list[AudioTask]: + def _dry_run_emit(self, task: AudioTask, plan: list[dict], original_id: str) -> list[AudioTask]: """Emit snippet metadata only, without reading or writing audio. ``audio_filepath`` is the tar-internal basename @@ -337,13 +333,10 @@ def _make_snippet_task( # Reset to "" — a downstream pipeline is expected to set this correctly. if "swift_audio_filepath" in new_data: new_data["swift_audio_filepath"] = "" - new_data["segments"] = relativize_segments( - snippet["segments"], snippet["start"], snippet["end"] - ) + new_data["segments"] = relativize_segments(snippet["segments"], snippet["start"], snippet["end"]) if "text" in new_data: new_data["text"] = " ".join(_segment_text(s) for s in snippet["segments"]).strip() return AudioTask( - task_id=f"{task.task_id}::{snippet_id}", dataset_name=task.dataset_name, data=new_data, filepath_key=self.audio_filepath_key, @@ -361,7 +354,6 @@ def _make_stub_task(self, task: AudioTask) -> AudioTask: "segments": [], } return AudioTask( - task_id=f"{task.task_id}::stub", dataset_name=task.dataset_name, data=stub_data, _metadata=copy.deepcopy(task._metadata), diff --git a/nemo_curator/stages/audio/alm/pretrain/io.py b/nemo_curator/stages/audio/alm/pretrain/io.py index ccd013bd91..2f335dcea4 100644 --- a/nemo_curator/stages/audio/alm/pretrain/io.py +++ b/nemo_curator/stages/audio/alm/pretrain/io.py @@ -57,9 +57,7 @@ # ---------------------------------------------------------------------- -def _read_manifest_row_id( - stage_name: str, lineno: int, entry: dict[str, Any], seen_ids: set[str] -) -> str | None: +def _read_manifest_row_id(stage_name: str, lineno: int, entry: dict[str, Any], seen_ids: 
set[str]) -> str | None: # `id` is required by the pipeline contract: downstream snippet ids # embed it, the metrics aggregator keys per-source records on it, and # tar members are named with it. A row without a usable id can't be @@ -202,16 +200,13 @@ def process(self, _: _EmptyTask) -> list[AudioTask]: logger.warning(f"[{self.name}] line {lineno}: missing {self.audio_filepath_key!r}; skipping") continue if self.audio_path_resolution == AUDIO_PATH_RESOLUTION_BASENAME: - _check_duplicate_audio_basename( - self.name, lineno, original_path, row_id, seen_basenames - ) + _check_duplicate_audio_basename(self.name, lineno, original_path, row_id, seen_basenames) entry[self.audio_filepath_key] = _resolve_audio_path( self.audio_dir, original_path, self.audio_path_resolution ) tasks.append( AudioTask( - task_id=row_id, dataset_name=self.dataset_name, data=entry, filepath_key=self.audio_filepath_key, diff --git a/nemo_curator/stages/audio/common.py b/nemo_curator/stages/audio/common.py index 3bcd464999..6c7c8eb837 100644 --- a/nemo_curator/stages/audio/common.py +++ b/nemo_curator/stages/audio/common.py @@ -154,7 +154,6 @@ def process(self, task: FileGroupTask) -> list[AudioTask]: if line.strip(): results.append( AudioTask( - task_id=f"{task.task_id}_{count}", dataset_name=task.dataset_name, data=json.loads(line.strip()), _metadata=task._metadata, @@ -282,7 +281,6 @@ def process(self, task: AudioTask) -> AudioTask: with self._fs.open(self._path, "a", encoding="utf-8") as f: f.write(json.dumps(task.data, ensure_ascii=False) + "\n") return AudioTask( - task_id=task.task_id, dataset_name=task.dataset_name, data=task.data, _metadata=task._metadata, diff --git a/nemo_curator/stages/audio/datasets/fleurs/create_initial_manifest.py b/nemo_curator/stages/audio/datasets/fleurs/create_initial_manifest.py index fe5099ea08..7ad84e0a7c 100644 --- a/nemo_curator/stages/audio/datasets/fleurs/create_initial_manifest.py +++ 
b/nemo_curator/stages/audio/datasets/fleurs/create_initial_manifest.py @@ -97,7 +97,6 @@ def process_transcript(self, file_path: str) -> list[AudioTask]: entries.append( AudioTask( data={self.filepath_key: abs_wav, self.text_key: transcript_text}, - task_id=f"task_id_{abs_wav}", dataset_name=f"Fleurs_{self.lang}_{self.split}_{self.raw_data_dir}", filepath_key=self.filepath_key, ) diff --git a/nemo_curator/stages/audio/datasets/readspeech/create_initial_manifest.py b/nemo_curator/stages/audio/datasets/readspeech/create_initial_manifest.py index 533ada385e..ab8029f5e5 100644 --- a/nemo_curator/stages/audio/datasets/readspeech/create_initial_manifest.py +++ b/nemo_curator/stages/audio/datasets/readspeech/create_initial_manifest.py @@ -340,11 +340,10 @@ def process(self, _: _EmptyTask) -> list[AudioTask]: logger.info(f"Creating manifest with {len(selected_entries)} total samples") audio_tasks = [] - for i, entry in enumerate(selected_entries): + for _i, entry in enumerate(selected_entries): audio_tasks.append( AudioTask( data=entry, - task_id=f"readspeech_{i}", dataset_name="DNS-ReadSpeech", filepath_key=self.filepath_key, ) diff --git a/nemo_curator/stages/audio/filtering/band.py b/nemo_curator/stages/audio/filtering/band.py index 0efd6d9a2e..6ffc0be087 100644 --- a/nemo_curator/stages/audio/filtering/band.py +++ b/nemo_curator/stages/audio/filtering/band.py @@ -158,7 +158,7 @@ def process(self, task: AudioTask) -> AudioTask | list[AudioTask]: if "segments" in task.data: survivors = [] for seg in task.data["segments"]: - temp = AudioTask(data=seg, task_id=task.task_id) + temp = AudioTask(data=seg) result = self._process_single(temp) if result is not None: survivors.append(temp.data) @@ -189,9 +189,7 @@ def _process_single(self, task: AudioTask) -> AudioTask | None: actual = task.data.get("band_prediction", "unknown") if actual != self.band_value: - logger.info( - f"[{task.task_id}] BAND FILTER FAILED: prediction '{actual}' != target '{self.band_value}'" - ) + 
logger.info(f"[{task.task_id}] BAND FILTER FAILED: prediction '{actual}' != target '{self.band_value}'") return None return task diff --git a/nemo_curator/stages/audio/filtering/sigmos.py b/nemo_curator/stages/audio/filtering/sigmos.py index 1667812d17..7cdd84f112 100755 --- a/nemo_curator/stages/audio/filtering/sigmos.py +++ b/nemo_curator/stages/audio/filtering/sigmos.py @@ -288,7 +288,7 @@ def process(self, task: AudioTask) -> AudioTask | list[AudioTask]: if "segments" in task.data: survivors = [] for seg in task.data["segments"]: - temp = AudioTask(data=seg, task_id=task.task_id) + temp = AudioTask(data=seg) result = self._process_single(temp) if result is not None: survivors.append(temp.data) diff --git a/nemo_curator/stages/audio/filtering/utmos.py b/nemo_curator/stages/audio/filtering/utmos.py index d38168f203..c2624f553c 100644 --- a/nemo_curator/stages/audio/filtering/utmos.py +++ b/nemo_curator/stages/audio/filtering/utmos.py @@ -202,7 +202,7 @@ def process(self, task: AudioTask) -> AudioTask | list[AudioTask]: if "segments" in task.data: survivors = [] for seg in task.data["segments"]: - temp = AudioTask(data=seg, task_id=task.task_id) + temp = AudioTask(data=seg) result = self._process_single(temp) if result is not None: survivors.append(temp.data) diff --git a/nemo_curator/stages/audio/inference/speaker_diarization/sortformer.py b/nemo_curator/stages/audio/inference/speaker_diarization/sortformer.py index 3f88b64e1b..5cf95996e6 100644 --- a/nemo_curator/stages/audio/inference/speaker_diarization/sortformer.py +++ b/nemo_curator/stages/audio/inference/speaker_diarization/sortformer.py @@ -230,7 +230,6 @@ def process(self, task: AudioTask) -> AudioTask: output_data[self.diar_segments_key] = segments return AudioTask( - task_id=f"{task.task_id}_sortformer", dataset_name=task.dataset_name, filepath_key=task.filepath_key or self.filepath_key, data=output_data, diff --git a/nemo_curator/stages/audio/io/convert.py b/nemo_curator/stages/audio/io/convert.py 
index 4edaa06a7d..2f13556d4a 100644 --- a/nemo_curator/stages/audio/io/convert.py +++ b/nemo_curator/stages/audio/io/convert.py @@ -81,7 +81,6 @@ def process_batch(self, tasks: list[AudioTask]) -> list[DocumentBatch]: return [ DocumentBatch( data=df, - task_id=",".join(t.task_id for t in tasks), dataset_name=",".join(dict.fromkeys(t.dataset_name for t in tasks)), _stage_perf=perf, ) diff --git a/nemo_curator/stages/audio/io/extract_segments.py b/nemo_curator/stages/audio/io/extract_segments.py index fb8383ac85..c93a7ac368 100644 --- a/nemo_curator/stages/audio/io/extract_segments.py +++ b/nemo_curator/stages/audio/io/extract_segments.py @@ -143,15 +143,17 @@ def _intervals_from_diar_segments(entry: dict) -> list[Interval]: speaker_id = entry.get("speaker_id", "unknown") logger.warning(f" {speaker_id}: no diar_segments, skipping") return [] - return [ - (int(s * 1000), int(e * 1000), e - s) - for s, e in sorted(diar_segments, key=lambda x: x[0]) - ] + return [(int(s * 1000), int(e * 1000), e - s) for s, e in sorted(diar_segments, key=lambda x: x[0])] def _base_metadata( # noqa: PLR0913 - filename: str, original_file: str, entry: dict, - seg_idx: int, start_ms: int, end_ms: int, dur: float, + filename: str, + original_file: str, + entry: dict, + seg_idx: int, + start_ms: int, + end_ms: int, + dur: float, ) -> dict: row: dict = { "filename": filename, @@ -346,10 +348,7 @@ def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: self._all_metadata_rows.extend(metadata_rows) _write_metadata_csv(self.output_dir, self._all_metadata_rows) - logger.info( - f"[{self.name}] Extracted {extracted} segments " - f"({total_dur:.1f}s) from {len(tasks)} entries" - ) + logger.info(f"[{self.name}] Extracted {extracted} segments ({total_dur:.1f}s) from {len(tasks)} entries") if speaker_counts: for speaker, count in sorted(speaker_counts.items()): logger.debug(f" {speaker}: {count} segments") @@ -361,7 +360,8 @@ def process_batch(self, tasks: list[AudioTask]) -> 
list[AudioTask]: # ------------------------------------------------------------------ def _extract_by_timestamps( - self, entries: list[dict], + self, + entries: list[dict], ) -> tuple[int, float, dict[str, int], list[dict]]: """Combo 2: extract by original_start_ms / original_end_ms.""" @@ -378,7 +378,8 @@ def _make_filename(name: str, _entry: dict, _seg_idx: int) -> str: ) def _extract_speaker_diar( - self, entries: list[dict], + self, + entries: list[dict], ) -> tuple[int, float, dict[str, int], list[dict]]: """Combo 3: extract each diar_segment per speaker.""" @@ -396,7 +397,8 @@ def _make_filename(name: str, entry: dict, _seg_idx: int) -> str: ) def _extract_speaker_timestamps( - self, entries: list[dict], + self, + entries: list[dict], ) -> tuple[int, float, dict[str, int], list[dict]]: """Combo 4: extract speaker-segments by timestamps.""" @@ -546,7 +548,9 @@ def extract_from_manifest(self, input_path: str) -> None: def extract_segments_by_timestamps( - entries: list, output_dir: str, output_format: str, + entries: list, + output_dir: str, + output_format: str, ) -> tuple[int, float, dict[str, int], list[dict]]: """Extract segments by original_start_ms / original_end_ms, sorted by start time.""" stage = SegmentExtractionStage(output_dir=output_dir, output_format=output_format) @@ -554,7 +558,9 @@ def extract_segments_by_timestamps( def extract_speaker_diar_segments( - entries: list, output_dir: str, output_format: str, + entries: list, + output_dir: str, + output_format: str, ) -> tuple[int, float, dict[str, int], list[dict]]: """Extract individual speaking intervals from diar_segments per speaker.""" stage = SegmentExtractionStage(output_dir=output_dir, output_format=output_format) @@ -562,7 +568,9 @@ def extract_speaker_diar_segments( def extract_speaker_segments_by_timestamps( - entries: list, output_dir: str, output_format: str, + entries: list, + output_dir: str, + output_format: str, ) -> tuple[int, float, dict[str, int], list[dict]]: """Extract 
speaker-segments using original_start_ms / original_end_ms.""" stage = SegmentExtractionStage(output_dir=output_dir, output_format=output_format) diff --git a/nemo_curator/stages/audio/preprocessing/concatenation.py b/nemo_curator/stages/audio/preprocessing/concatenation.py index 36fcfaa53c..4afd134bef 100755 --- a/nemo_curator/stages/audio/preprocessing/concatenation.py +++ b/nemo_curator/stages/audio/preprocessing/concatenation.py @@ -111,7 +111,7 @@ def process(self, task: AudioTask) -> AudioTask | list[AudioTask]: segments_sorted = sorted(segments, key=self._seg_sort_key) original_file = segments_sorted[0].get("original_file", "unknown") - combined = self._concatenate(original_file, segments_sorted, task.task_id, task.dataset_name) + combined = self._concatenate(original_file, segments_sorted, task.dataset_name) if combined is None: return [] return combined @@ -147,7 +147,6 @@ def _concatenate( self, original_file: str, segments: list[dict[str, Any]], - task_id: str, dataset_name: str, ) -> AudioTask | None: """Concatenate a list of segment dicts from the same source file.""" @@ -166,8 +165,7 @@ def _concatenate( if parts and sr != sample_rate: logger.warning( - f"[SegmentConcat] Sample rate mismatch: " - f"expected {sample_rate}Hz, got {sr}Hz. Skipping segment." + f"[SegmentConcat] Sample rate mismatch: expected {sample_rate}Hz, got {sr}Hz. Skipping segment." 
) continue sample_rate = sr @@ -227,7 +225,6 @@ def _concatenate( result_task = AudioTask( data=output_data, - task_id=task_id, dataset_name=dataset_name, ) result_task._metadata = {"segment_mappings": mappings} diff --git a/nemo_curator/stages/audio/segmentation/speaker_separation.py b/nemo_curator/stages/audio/segmentation/speaker_separation.py index 5ffe71d779..06e1434015 100755 --- a/nemo_curator/stages/audio/segmentation/speaker_separation.py +++ b/nemo_curator/stages/audio/segmentation/speaker_separation.py @@ -189,7 +189,6 @@ def _build_speaker_tasks( } spk_task = AudioTask( data=speaker_data, - task_id=f"{task.task_id}_{speaker_id}", dataset_name=task.dataset_name, ) if task._metadata: diff --git a/nemo_curator/stages/audio/segmentation/speaker_separation_module/speaker_sep.py b/nemo_curator/stages/audio/segmentation/speaker_separation_module/speaker_sep.py index aa0426abcd..1023d5e1c8 100755 --- a/nemo_curator/stages/audio/segmentation/speaker_separation_module/speaker_sep.py +++ b/nemo_curator/stages/audio/segmentation/speaker_separation_module/speaker_sep.py @@ -29,6 +29,7 @@ class SpeakerResult(NamedTuple): duration: float diar_segments: list[tuple[float, float]] + try: from nemo.collections.asr.models import SortformerEncLabelModel except ImportError: diff --git a/nemo_curator/stages/audio/segmentation/vad_segmentation.py b/nemo_curator/stages/audio/segmentation/vad_segmentation.py index 1dda9b1159..d259ead3b9 100755 --- a/nemo_curator/stages/audio/segmentation/vad_segmentation.py +++ b/nemo_curator/stages/audio/segmentation/vad_segmentation.py @@ -272,7 +272,6 @@ def process(self, task: AudioTask) -> AudioTask | list[AudioTask]: seg_data = self._build_segment_item(task.data, waveform, sample_rate, segment, i) seg_task = AudioTask( data=seg_data, - task_id=f"{task.task_id}_seg_{i}", dataset_name=task.dataset_name, ) if task._metadata: diff --git a/nemo_curator/stages/base.py b/nemo_curator/stages/base.py index 5761dfeb18..81dd00a71c 100644 --- 
a/nemo_curator/stages/base.py +++ b/nemo_curator/stages/base.py @@ -87,6 +87,16 @@ class ProcessingStage(ABC, Generic[X, Y], metaclass=StageMeta): batch_size = 1 runtime_env: ClassVar[dict[str, Any] | None] = None + # Source / sink role flags. User-overridable on the stage class or + # instance. If neither is set explicitly on any stage in the pipeline, + # ``Pipeline.build()`` defaults the first stage to source and the last + # to sink. The source flag selects content-based ids from + # ``Task.get_deterministic_id()`` (when the Task subclass implements + # one) for this task's id segment; the sink flag is reserved for the + # resumability layer to mark the counter-decrement boundary. + is_source_stage: bool = False + is_sink_stage: bool = False + @property @final def _name(self) -> str: @@ -182,6 +192,14 @@ def process_batch(self, tasks: list[X]) -> list[Y]: Note: The returned list should have the same length as the input list, with each element corresponding to the result of processing the task at the same index. + + ``task_id`` is framework-owned: stages must NOT set it. The executor + adapter (``BaseStageAdapter._post_process_task_ids``) assigns a + deterministic id to every emitted task — regardless of whether + a stage uses this default or overrides ``process_batch``. Where the + input→output mapping is ambiguous (e.g. a batch aggregation), the + adapter falls back to a random ``"r"``-prefixed id (see + ``Task.task_id``); there is no way for a stage to supply its own. 
""" # Default implementation: process tasks one by one # This is only used as a fallback if a stage doesn't override this method diff --git a/nemo_curator/stages/client_partitioning.py b/nemo_curator/stages/client_partitioning.py index 2d5c0a59a9..a71fda57b1 100644 --- a/nemo_curator/stages/client_partitioning.py +++ b/nemo_curator/stages/client_partitioning.py @@ -74,7 +74,6 @@ def process(self, _: _EmptyTask) -> list[FileGroupTask]: for i, group in enumerate(partitions): tasks.append( FileGroupTask( - task_id=f"file_group_{i}", dataset_name=dataset_name, data=group, _metadata={ diff --git a/nemo_curator/stages/deduplication/exact/identification.py b/nemo_curator/stages/deduplication/exact/identification.py index 23882a8c1f..564d79de88 100644 --- a/nemo_curator/stages/deduplication/exact/identification.py +++ b/nemo_curator/stages/deduplication/exact/identification.py @@ -256,7 +256,6 @@ def extract_and_write(self) -> list[FileGroupTask]: removal_ids.to_parquet(output_file, **write_kwargs) result_tasks.append( FileGroupTask( - task_id=partition_id, dataset_name=f"{self.dataset_name}_{self.name}", data=[output_file], _metadata={ diff --git a/nemo_curator/stages/deduplication/fuzzy/buckets_to_edges.py b/nemo_curator/stages/deduplication/fuzzy/buckets_to_edges.py index 7963724cb1..5aac56bfc6 100644 --- a/nemo_curator/stages/deduplication/fuzzy/buckets_to_edges.py +++ b/nemo_curator/stages/deduplication/fuzzy/buckets_to_edges.py @@ -80,10 +80,9 @@ def process(self, task: FileGroupTask) -> FileGroupTask: pd.DataFrame(edges, columns=[f"{self.document_id_field}_x", f"{self.document_id_field}_y"]) ) - output_path = self.output_fs.sep.join([self.output_path, f"{task._uuid}.parquet"]) + output_path = self.output_fs.sep.join([self.output_path, f"{task.task_id}.parquet"]) pq.write_table(edges, output_path, filesystem=self.output_fs) return FileGroupTask( - task_id=f"{task.task_id}", dataset_name=f"{task.dataset_name}_edges", data=[output_path], _metadata={**task._metadata, 
"storage_options": self.write_storage_options}, diff --git a/nemo_curator/stages/deduplication/fuzzy/connected_components.py b/nemo_curator/stages/deduplication/fuzzy/connected_components.py index 333fbad077..185a084d60 100644 --- a/nemo_curator/stages/deduplication/fuzzy/connected_components.py +++ b/nemo_curator/stages/deduplication/fuzzy/connected_components.py @@ -173,7 +173,7 @@ def process_batch(self, tasks: list[FileGroupTask]) -> list[FileGroupTask]: input_files = [] for task in tasks: input_files.extend(task.data) - output_file = self.output_fs.sep.join([self.output_path, f"{tasks[0]._uuid}.parquet"]) + output_file = self.output_fs.sep.join([self.output_path, f"{tasks[0].task_id}.parquet"]) edgelist_columns = [self.source_field, self.destination_field] dfs = [] for input_file in input_files: @@ -192,7 +192,6 @@ def process_batch(self, tasks: list[FileGroupTask]) -> list[FileGroupTask]: return [ FileGroupTask( dataset_name=tasks[0].dataset_name, - task_id=tasks[0].task_id, data=[output_file], _metadata={ "storage_options": self.write_kwargs.get("storage_options"), diff --git a/nemo_curator/stages/deduplication/fuzzy/identify_duplicates.py b/nemo_curator/stages/deduplication/fuzzy/identify_duplicates.py index 7468b3992a..351dfd6e87 100644 --- a/nemo_curator/stages/deduplication/fuzzy/identify_duplicates.py +++ b/nemo_curator/stages/deduplication/fuzzy/identify_duplicates.py @@ -134,7 +134,6 @@ def extract_and_write(self) -> list[FileGroupTask]: removal_ids.to_parquet(output_file, **write_kwargs) result_tasks.append( FileGroupTask( - task_id=partition_id, dataset_name=self.dataset_name + f"{self.name}", data=[output_file], _metadata={ diff --git a/nemo_curator/stages/deduplication/fuzzy/lsh/stage.py b/nemo_curator/stages/deduplication/fuzzy/lsh/stage.py index 0d94aa0669..426dc4006f 100644 --- a/nemo_curator/stages/deduplication/fuzzy/lsh/stage.py +++ b/nemo_curator/stages/deduplication/fuzzy/lsh/stage.py @@ -155,11 +155,10 @@ def insert_finished(self) -> 
None: def extract_and_write(self) -> list[FileGroupTask]: self._check_actor_obj() - current_band_min, current_band_max = self._current_band_range + _current_band_min, _current_band_max = self._current_band_range partition_dicts = self._actor_obj.extract_and_write() return [ FileGroupTask( - task_id=f"b{current_band_min}_b{current_band_max}_{partition_info['partition_id']}", dataset_name=self.dataset_name + f"{self.name}", data=[partition_info["path"]], _metadata={ diff --git a/nemo_curator/stages/deduplication/fuzzy/minhash.py b/nemo_curator/stages/deduplication/fuzzy/minhash.py index d9d6c65c41..93848372d7 100644 --- a/nemo_curator/stages/deduplication/fuzzy/minhash.py +++ b/nemo_curator/stages/deduplication/fuzzy/minhash.py @@ -306,7 +306,7 @@ def process(self, task: FileGroupTask) -> FileGroupTask: msg = "MinHash processor or ID generator not initialized. Call setup() first." raise RuntimeError(msg) - output_file = self.output_fs.sep.join([self.output_path, f"{task._uuid}.parquet"]) + output_file = self.output_fs.sep.join([self.output_path, f"{task.task_id}.parquet"]) read_kwargs = self.read_kwargs.copy() @@ -327,7 +327,6 @@ def process(self, task: FileGroupTask) -> FileGroupTask: # Return FileGroupTask with output file return FileGroupTask( - task_id=f"{task.task_id}", dataset_name=f"{task.dataset_name}_minhash", data=[output_file], _metadata={ diff --git a/nemo_curator/stages/deduplication/semantic/identify_duplicates.py b/nemo_curator/stages/deduplication/semantic/identify_duplicates.py index 21f51b7d52..cc6820935c 100644 --- a/nemo_curator/stages/deduplication/semantic/identify_duplicates.py +++ b/nemo_curator/stages/deduplication/semantic/identify_duplicates.py @@ -19,9 +19,9 @@ import pandas as pd from nemo_curator.stages.base import ProcessingStage -from nemo_curator.stages.text.io.writer.utils import get_deterministic_hash from nemo_curator.tasks import FileGroupTask from nemo_curator.utils.file_utils import check_disallowed_kwargs +from 
nemo_curator.utils.hash_utils import get_deterministic_hash @dataclass @@ -121,7 +121,6 @@ def process_batch(self, tasks: list[FileGroupTask]) -> list[FileGroupTask]: # Create output task return [ FileGroupTask( - task_id=f"identify_duplicates_{get_deterministic_hash(all_files, tasks[0].task_id)}", dataset_name=tasks[0].dataset_name, data=[output_file], _metadata={**tasks[0]._metadata, "num_removed": len(df)}, diff --git a/nemo_curator/stages/deduplication/semantic/kmeans.py b/nemo_curator/stages/deduplication/semantic/kmeans.py index cb6c7cec3c..9453853b8c 100644 --- a/nemo_curator/stages/deduplication/semantic/kmeans.py +++ b/nemo_curator/stages/deduplication/semantic/kmeans.py @@ -232,7 +232,7 @@ def _process_batch_single_pass(self, tasks: list[FileGroupTask], groups: list[li # Assign distances using the fitted cluster centers df = self._assign_distances(df, self.embedding_field, self.kmeans.cluster_centers_) # noqa: PLW2901 - output_filename = f"{tasks[0]._uuid}_{i}" + output_filename = f"{tasks[0].task_id}_{i}" # Write results for this subgroup self.write_parquet( df, @@ -247,7 +247,6 @@ def _process_batch_single_pass(self, tasks: list[FileGroupTask], groups: list[li # Create result task for this subgroup results.append( _EmptyTask( - task_id=output_filename, dataset_name=f"kmeans_group_{i}", _metadata=None, _stage_perf=[], @@ -278,10 +277,12 @@ def _process_batch_two_pass(self, tasks: list[FileGroupTask], groups: list[list[ """ pass1_read_time = self._fit_pass(groups) results, pass2_read_time, total_rows = self._predict_write_pass(tasks, groups) - self._log_metrics({ - "kmeans_read_time": pass1_read_time + pass2_read_time, - "num_rows": total_rows, - }) + self._log_metrics( + { + "kmeans_read_time": pass1_read_time + pass2_read_time, + "num_rows": total_rows, + } + ) return results def _fit_pass(self, groups: list[list[str]]) -> float: @@ -392,7 +393,7 @@ def _predict_write_pass( df["centroid"] = labels df = self._assign_distances(df, self.embedding_field, 
self.kmeans.cluster_centers_) - output_filename = f"{tasks[0]._uuid}_{i}" + output_filename = f"{tasks[0].task_id}_{i}" self.write_parquet( df, self.output_path, @@ -404,7 +405,6 @@ def _predict_write_pass( ) results.append( _EmptyTask( - task_id=output_filename, dataset_name=f"kmeans_group_{i}", _metadata=None, _stage_perf=[], diff --git a/nemo_curator/stages/deduplication/semantic/pairwise.py b/nemo_curator/stages/deduplication/semantic/pairwise.py index f214b59d33..a6d724ade4 100644 --- a/nemo_curator/stages/deduplication/semantic/pairwise.py +++ b/nemo_curator/stages/deduplication/semantic/pairwise.py @@ -158,7 +158,6 @@ def process(self, task: FileGroupTask) -> FileGroupTask: if not dfs: logger.warning(f"No data found for cluster {cluster_id}") return FileGroupTask( - task_id=task.task_id, dataset_name=task.dataset_name, _metadata=task._metadata, _stage_perf=task._stage_perf, @@ -180,7 +179,6 @@ def process(self, task: FileGroupTask) -> FileGroupTask: result_df, output_path, storage_options=self.output_storage_options, index=False, **self.write_kwargs ) return FileGroupTask( - task_id=task.task_id, dataset_name=task.dataset_name, _metadata={ **task._metadata, @@ -245,7 +243,6 @@ def process(self, task: FileGroupTask) -> FileGroupTask: ) return FileGroupTask( - task_id=task.task_id, dataset_name=task.dataset_name, _metadata={**task._metadata, "centroid_id": cluster_id}, _stage_perf=task._stage_perf, diff --git a/nemo_curator/stages/deduplication/semantic/pairwise_io.py b/nemo_curator/stages/deduplication/semantic/pairwise_io.py index a801d03e66..c019d009ba 100644 --- a/nemo_curator/stages/deduplication/semantic/pairwise_io.py +++ b/nemo_curator/stages/deduplication/semantic/pairwise_io.py @@ -112,7 +112,6 @@ def process(self, _: _EmptyTask) -> list[FileGroupTask]: fs=self.fs, ) pairwise_task = FileGroupTask( - task_id=f"pairwise_centroid_{centroid_id}", dataset_name=dataset_name, data=partition_files, _metadata={ diff --git 
a/nemo_curator/stages/deduplication/shuffle_utils/stage.py b/nemo_curator/stages/deduplication/shuffle_utils/stage.py index 9465b735ae..7a8c58bba2 100644 --- a/nemo_curator/stages/deduplication/shuffle_utils/stage.py +++ b/nemo_curator/stages/deduplication/shuffle_utils/stage.py @@ -130,7 +130,6 @@ def extract_and_write(self) -> list[FileGroupTask]: partition_paths = self._actor_obj.extract_and_write(column_names=self.output_columns) return [ FileGroupTask( - task_id=partition_id, dataset_name=self.dataset_name + f"{self.name}", data=[path], _metadata={ diff --git a/nemo_curator/stages/file_partitioning.py b/nemo_curator/stages/file_partitioning.py index 1e60e17480..53d12ae88e 100644 --- a/nemo_curator/stages/file_partitioning.py +++ b/nemo_curator/stages/file_partitioning.py @@ -175,7 +175,6 @@ def process(self, _: _EmptyTask) -> list[FileGroupTask]: logger.info(f"Reached limit of {self.limit} file groups") break file_task = FileGroupTask( - task_id=f"file_group_{i}", dataset_name=dataset_name, data=file_group, _metadata={ diff --git a/nemo_curator/stages/image/deduplication/removal.py b/nemo_curator/stages/image/deduplication/removal.py index d30f433b29..61537985d8 100644 --- a/nemo_curator/stages/image/deduplication/removal.py +++ b/nemo_curator/stages/image/deduplication/removal.py @@ -58,7 +58,11 @@ def outputs(self) -> tuple[list[str], list[str]]: return ["data"], [] def setup(self, _worker_metadata=None) -> None: # noqa: ANN001 - removal_parquets = [os.path.join(self.removal_parquets_dir, f) for f in os.listdir(self.removal_parquets_dir) if f.endswith(".parquet")] + removal_parquets = [ + os.path.join(self.removal_parquets_dir, f) + for f in os.listdir(self.removal_parquets_dir) + if f.endswith(".parquet") + ] if not removal_parquets: msg = f"No parquet files found in {self.removal_parquets_dir}" logger.error(msg) @@ -73,9 +77,7 @@ def setup(self, _worker_metadata=None) -> None: # noqa: ANN001 self._ids_to_remove.update(ids_array) if self.verbose: - 
logger.debug( - f"Loaded {len(self._ids_to_remove)} IDs to remove from '{self.removal_parquets_dir}'" - ) + logger.debug(f"Loaded {len(self._ids_to_remove)} IDs to remove from '{self.removal_parquets_dir}'") def process(self, task: ImageBatch) -> ImageBatch: original_count = len(task.data) @@ -85,14 +87,12 @@ def process(self, task: ImageBatch) -> ImageBatch: removed_count = original_count - len(filtered_images) if self.verbose: logger.debug( - f"Dedup filtering: kept {len(filtered_images)}/{original_count} images, " - f"removed {removed_count} by ID" + f"Dedup filtering: kept {len(filtered_images)}/{original_count} images, removed {removed_count} by ID" ) return ImageBatch( data=filtered_images, dataset_name=task.dataset_name, - task_id=f"{task.task_id}_{self.name}", _metadata=task._metadata, _stage_perf=task._stage_perf, ) diff --git a/nemo_curator/stages/image/embedders/clip_embedder.py b/nemo_curator/stages/image/embedders/clip_embedder.py index 5f8500a69a..6c0592c725 100644 --- a/nemo_curator/stages/image/embedders/clip_embedder.py +++ b/nemo_curator/stages/image/embedders/clip_embedder.py @@ -33,6 +33,7 @@ class ImageEmbeddingStage(ProcessingStage[ImageBatch, ImageBatch]): embeddings for each image. It assumes image data is already loaded in ImageObject.image_data and stores embeddings in ImageObject.embedding. """ + model_dir: str = None num_gpus_per_worker: float = 0.25 model_inference_batch_size: int = 32 # Number of images to process through model at once @@ -110,8 +111,6 @@ def process(self, task: ImageBatch) -> ImageBatch: image_obj.image_data = None if self.verbose: - logger.info( - f"Generated embeddings for {len(batch)} images." 
- ) + logger.info(f"Generated embeddings for {len(batch)} images.") return task diff --git a/nemo_curator/stages/image/filters/aesthetic_filter.py b/nemo_curator/stages/image/filters/aesthetic_filter.py index a6e1a4432d..ca992fe2f2 100644 --- a/nemo_curator/stages/image/filters/aesthetic_filter.py +++ b/nemo_curator/stages/image/filters/aesthetic_filter.py @@ -31,6 +31,7 @@ class ImageAestheticFilterStage(BaseFilterStage): This class processes image batches through an aesthetic scoring model to generate aesthetic scores for each image. Images with scores below the threshold will be filtered out. """ + model_dir: str = None num_gpus_per_worker: float = 0.25 model_inference_batch_size: int = 32 # Number of images to process through model at once @@ -38,7 +39,9 @@ class ImageAestheticFilterStage(BaseFilterStage): verbose: bool = False name: str = "image_aesthetic_filter" - def setup_on_node(self, _node_info: NodeInfo | None = None, _worker_metadata: WorkerMetadata | None = None) -> None: + def setup_on_node( + self, _node_info: NodeInfo | None = None, _worker_metadata: WorkerMetadata | None = None + ) -> None: """Download aesthetic model weights from HF""" AestheticScorer.download_weights_on_node(self.model_dir) @@ -99,7 +102,6 @@ def process(self, task: ImageBatch) -> ImageBatch: return ImageBatch( data=filtered_images, dataset_name=task.dataset_name, - task_id=f"{task.task_id}_{self.name}", _metadata=task._metadata, _stage_perf=task._stage_perf, ) diff --git a/nemo_curator/stages/image/filters/base.py b/nemo_curator/stages/image/filters/base.py index 282d603e8e..20b4464f96 100644 --- a/nemo_curator/stages/image/filters/base.py +++ b/nemo_curator/stages/image/filters/base.py @@ -29,6 +29,7 @@ class BaseFilterStage(ProcessingStage[ImageBatch, ImageBatch]): This class provides a base class for image filtering stages. 
""" + model_dir: str = None num_gpus_per_worker: float = 0.25 model_inference_batch_size: int = 32 # Number of images to process through model at once diff --git a/nemo_curator/stages/image/filters/nsfw_filter.py b/nemo_curator/stages/image/filters/nsfw_filter.py index de9bfbf0c0..4e39e543ed 100644 --- a/nemo_curator/stages/image/filters/nsfw_filter.py +++ b/nemo_curator/stages/image/filters/nsfw_filter.py @@ -32,10 +32,13 @@ class ImageNSFWFilterStage(BaseFilterStage): NSFW probability scores for each image. Images with scores above the threshold will be filtered out as NSFW content. """ + weights_path: str = None name: str = "image_nsfw_filter" - def setup_on_node(self, _node_info: NodeInfo | None = None, _worker_metadata: WorkerMetadata | None = None) -> None: + def setup_on_node( + self, _node_info: NodeInfo | None = None, _worker_metadata: WorkerMetadata | None = None + ) -> None: """Download NSFW model weights from LAION repository.""" NSFWScorer.download_weights_on_node(self.model_dir) @@ -73,8 +76,7 @@ def process(self, task: ImageBatch) -> ImageBatch: if self.verbose: logger.info( - f"Generated NSFW scores for {len(batch)} images " - f"in batch {i}-{i + self.model_inference_batch_size}" + f"Generated NSFW scores for {len(batch)} images in batch {i}-{i + self.model_inference_batch_size}" ) # Filter images based on NSFW score threshold @@ -94,15 +96,13 @@ def process(self, task: ImageBatch) -> ImageBatch: if self.verbose: logger.info( - f"NSFW filtering: {len(filtered_images)}/{len(task.data)} images passed, " - f"{filtered_count} filtered out" + f"NSFW filtering: {len(filtered_images)}/{len(task.data)} images passed, {filtered_count} filtered out" ) # Return new ImageBatch with filtered images return ImageBatch( data=filtered_images, dataset_name=task.dataset_name, - task_id=f"{task.task_id}_{self.name}", _metadata=task._metadata, _stage_perf=task._stage_perf, ) diff --git a/nemo_curator/stages/image/io/convert.py b/nemo_curator/stages/image/io/convert.py 
index e257a95680..7026cedc43 100644 --- a/nemo_curator/stages/image/io/convert.py +++ b/nemo_curator/stages/image/io/convert.py @@ -28,6 +28,7 @@ class ConvertImageBatchToDocumentBatchStage(ProcessingStage[ImageBatch, Document Args: fields: list of fields of ImageObject to convert to DocumentBatch """ + fields: list[str] = field(default_factory=list) name: str = "convert_image_batch_to_document_batch" @@ -45,7 +46,6 @@ def process(self, task: ImageBatch) -> DocumentBatch: df = pd.DataFrame(data) return DocumentBatch( - task_id=f"{task.task_id}_{self.name}", dataset_name=task.dataset_name, data=df, _metadata=task._metadata, diff --git a/nemo_curator/stages/image/io/image_reader.py b/nemo_curator/stages/image/io/image_reader.py index 9a08789e46..4968494344 100644 --- a/nemo_curator/stages/image/io/image_reader.py +++ b/nemo_curator/stages/image/io/image_reader.py @@ -137,9 +137,8 @@ def _read_tars_with_dali(self, tar_paths: list[pathlib.Path]) -> Generator[list[ def _stream_batches(self, tar_files: list[pathlib.Path]) -> Generator[ImageBatch, None, None]: """Emit one ImageBatch per DALI run across all provided tar files.""" - for batch_id, image_objects in enumerate(self._read_tars_with_dali(tar_files)): + for _batch_id, image_objects in enumerate(self._read_tars_with_dali(tar_files)): yield ImageBatch( - task_id=f"image_batch_{batch_id}", dataset_name="tar_files", data=image_objects, ) diff --git a/nemo_curator/stages/image/io/image_writer.py b/nemo_curator/stages/image/io/image_writer.py index 0de0f8f02c..92855beb11 100644 --- a/nemo_curator/stages/image/io/image_writer.py +++ b/nemo_curator/stages/image/io/image_writer.py @@ -227,7 +227,6 @@ def process(self, task: ImageBatch) -> FileGroupTask: # Return FileGroupTask with produced files return FileGroupTask( - task_id=task.task_id, dataset_name=task.dataset_name, data=[*tar_paths, *parquet_paths], _metadata={ diff --git a/nemo_curator/stages/interleaved/io/readers/parquet.py 
b/nemo_curator/stages/interleaved/io/readers/parquet.py index 0e8260ce3b..988774a010 100644 --- a/nemo_curator/stages/interleaved/io/readers/parquet.py +++ b/nemo_curator/stages/interleaved/io/readers/parquet.py @@ -116,7 +116,7 @@ def process(self, task: FileGroupTask) -> InterleavedBatch | list[InterleavedBat splits = split_table_by_group_max_bytes(combined, "sample_id", self.max_batch_bytes) batches: list[InterleavedBatch] = [] for idx, split in enumerate(splits): - task_id = f"{task.task_id}_processed" if len(splits) == 1 else f"{task.task_id}_processed_{idx:05d}" + f"{task.task_id}_processed" if len(splits) == 1 else f"{task.task_id}_processed_{idx:05d}" metadata: dict[str, Any] = dict(task._metadata) if len(splits) == 1: metadata["source_files"] = list(task.data) @@ -126,7 +126,6 @@ def process(self, task: FileGroupTask) -> InterleavedBatch | list[InterleavedBat metadata["source_storage_options"] = self._storage_options batches.append( InterleavedBatch( - task_id=task_id, dataset_name=task.dataset_name, data=split, _metadata=metadata, diff --git a/nemo_curator/stages/interleaved/io/readers/webdataset.py b/nemo_curator/stages/interleaved/io/readers/webdataset.py index 767d442c7b..95a5d3c7b8 100644 --- a/nemo_curator/stages/interleaved/io/readers/webdataset.py +++ b/nemo_curator/stages/interleaved/io/readers/webdataset.py @@ -456,7 +456,7 @@ def process(self, task: FileGroupTask) -> InterleavedBatch | list[InterleavedBat splits = split_table_by_group_max_bytes(table, "sample_id", self.max_batch_bytes) batches: list[InterleavedBatch] = [] for idx, split in enumerate(splits): - task_id = f"{task.task_id}_processed" if len(splits) == 1 else f"{task.task_id}_processed_{idx:05d}" + f"{task.task_id}_processed" if len(splits) == 1 else f"{task.task_id}_processed_{idx:05d}" metadata = dict(task._metadata) if len(splits) == 1: metadata["source_files"] = list(task.data) @@ -466,7 +466,6 @@ def process(self, task: FileGroupTask) -> InterleavedBatch | list[InterleavedBat 
metadata["source_storage_options"] = self._storage_options batches.append( InterleavedBatch( - task_id=task_id, dataset_name=task.dataset_name, data=split, _metadata=metadata, diff --git a/nemo_curator/stages/interleaved/io/writers/base.py b/nemo_curator/stages/interleaved/io/writers/base.py index 5a8c6bb1f3..ff7ecfafa5 100644 --- a/nemo_curator/stages/interleaved/io/writers/base.py +++ b/nemo_curator/stages/interleaved/io/writers/base.py @@ -24,13 +24,13 @@ from fsspec.core import url_to_fs from loguru import logger -import nemo_curator.stages.text.io.writer.utils as writer_utils from nemo_curator.stages.base import ProcessingStage from nemo_curator.stages.interleaved.utils import materialize_task_binary_content from nemo_curator.stages.interleaved.utils.schema import align_table, reconcile_schema, resolve_schema from nemo_curator.tasks import FileGroupTask, InterleavedBatch from nemo_curator.utils.client_utils import is_remote_url from nemo_curator.utils.file_utils import check_output_mode +from nemo_curator.utils.hash_utils import get_deterministic_hash @dataclass @@ -141,7 +141,7 @@ def write_data(self, task: InterleavedBatch, file_path: str) -> None: def process(self, task: InterleavedBatch) -> FileGroupTask: if source_files := task._metadata.get("source_files"): - filename = writer_utils.get_deterministic_hash(source_files, task.task_id) + filename = get_deterministic_hash(source_files, task.task_id) else: logger.warning("The task does not have source_files in metadata, using UUID for base filename") filename = uuid.uuid4().hex @@ -151,7 +151,6 @@ def process(self, task: InterleavedBatch) -> FileGroupTask: self.write_data(task, file_path_with_protocol) return FileGroupTask( - task_id=task.task_id, dataset_name=task.dataset_name, data=[file_path_with_protocol], _metadata={**task._metadata, "format": self.file_extension}, diff --git a/nemo_curator/stages/interleaved/pdf/nemotron_parse/inference.py 
b/nemo_curator/stages/interleaved/pdf/nemotron_parse/inference.py index 07f89c392c..2eb130530c 100644 --- a/nemo_curator/stages/interleaved/pdf/nemotron_parse/inference.py +++ b/nemo_curator/stages/interleaved/pdf/nemotron_parse/inference.py @@ -259,7 +259,6 @@ def process(self, task: InterleavedBatch) -> InterleavedBatch | None: metadata["model_path"] = self.model_path return InterleavedBatch( - task_id=f"{task.task_id}_inferred", dataset_name=task.dataset_name, data=pa.Table.from_pandas(task_df, preserve_index=False), _metadata=metadata, diff --git a/nemo_curator/stages/interleaved/pdf/nemotron_parse/partitioning.py b/nemo_curator/stages/interleaved/pdf/nemotron_parse/partitioning.py index b6c1aeb21d..6fd844b1c1 100644 --- a/nemo_curator/stages/interleaved/pdf/nemotron_parse/partitioning.py +++ b/nemo_curator/stages/interleaved/pdf/nemotron_parse/partitioning.py @@ -129,10 +129,8 @@ def process(self, _: _EmptyTask) -> list[FileGroupTask]: for i in range(0, len(entries), self.pdfs_per_task): batch = entries[i : i + self.pdfs_per_task] task_idx = i // self.pdfs_per_task - task_id = f"pdf_batch_{task_idx:06d}" tasks.append( FileGroupTask( - task_id=task_id, dataset_name=self.dataset_name, data=batch, _metadata={"source_files": batch, "partition_index": task_idx}, diff --git a/nemo_curator/stages/interleaved/pdf/nemotron_parse/postprocess.py b/nemo_curator/stages/interleaved/pdf/nemotron_parse/postprocess.py index 4e956a570d..b0d7776aa0 100644 --- a/nemo_curator/stages/interleaved/pdf/nemotron_parse/postprocess.py +++ b/nemo_curator/stages/interleaved/pdf/nemotron_parse/postprocess.py @@ -104,7 +104,6 @@ def process(self, task: InterleavedBatch) -> InterleavedBatch | None: final_df[col] = None return InterleavedBatch( - task_id=f"{task.task_id}_postprocessed", dataset_name=task.dataset_name, data=pa.Table.from_pandas(final_df, preserve_index=False), _metadata=task._metadata, diff --git a/nemo_curator/stages/interleaved/pdf/nemotron_parse/preprocess.py 
b/nemo_curator/stages/interleaved/pdf/nemotron_parse/preprocess.py index ba8f07d3e9..b76c98174b 100644 --- a/nemo_curator/stages/interleaved/pdf/nemotron_parse/preprocess.py +++ b/nemo_curator/stages/interleaved/pdf/nemotron_parse/preprocess.py @@ -238,7 +238,6 @@ def process(self, task: FileGroupTask) -> InterleavedBatch | None: pages_df = pd.DataFrame(rows) return InterleavedBatch( - task_id=f"{task.task_id}_preprocessed", dataset_name=task.dataset_name, data=pa.Table.from_pandas(pages_df, preserve_index=False), _metadata=task._metadata, diff --git a/nemo_curator/stages/interleaved/stages.py b/nemo_curator/stages/interleaved/stages.py index fecc7a9f05..436b9af8ba 100644 --- a/nemo_curator/stages/interleaved/stages.py +++ b/nemo_curator/stages/interleaved/stages.py @@ -56,7 +56,6 @@ def process(self, task: InterleavedBatch) -> InterleavedBatch: return task out_df = self.annotate(task, df) return InterleavedBatch( - task_id=f"{task.task_id}_{self.name}", dataset_name=task.dataset_name, data=out_df.reset_index(drop=True), _metadata=task._metadata, @@ -104,7 +103,6 @@ def iter_materialized_bytes( if not masked_indices: return temp_task = InterleavedBatch( - task_id=task.task_id, dataset_name=task.dataset_name, data=df.loc[masked_indices], _metadata=task._metadata, diff --git a/nemo_curator/stages/interleaved/utils/materialization.py b/nemo_curator/stages/interleaved/utils/materialization.py index efb279e11d..d4df54cecf 100644 --- a/nemo_curator/stages/interleaved/utils/materialization.py +++ b/nemo_curator/stages/interleaved/utils/materialization.py @@ -328,7 +328,6 @@ def _build_image_mask( def _task_with_dataframe(task: InterleavedBatch, df: pd.DataFrame) -> InterleavedBatch: return InterleavedBatch( - task_id=task.task_id, dataset_name=task.dataset_name, data=df, _metadata=task._metadata, diff --git a/nemo_curator/stages/math/classifiers/finemath.py b/nemo_curator/stages/math/classifiers/finemath.py index f1ae730a1a..96af374f96 100644 --- 
a/nemo_curator/stages/math/classifiers/finemath.py +++ b/nemo_curator/stages/math/classifiers/finemath.py @@ -67,7 +67,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: ) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, diff --git a/nemo_curator/stages/math/download/extract.py b/nemo_curator/stages/math/download/extract.py index 3bcddef930..7851567ac7 100644 --- a/nemo_curator/stages/math/download/extract.py +++ b/nemo_curator/stages/math/download/extract.py @@ -263,7 +263,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: output_cols = [*output_cols, self.filename_col] return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=pd.DataFrame(records) if records else pd.DataFrame(columns=output_cols), _metadata=batch._metadata, diff --git a/nemo_curator/stages/math/modifiers/chunking.py b/nemo_curator/stages/math/modifiers/chunking.py index 618e3679f1..f2f8167971 100644 --- a/nemo_curator/stages/math/modifiers/chunking.py +++ b/nemo_curator/stages/math/modifiers/chunking.py @@ -45,7 +45,9 @@ def __init__( # noqa: PLR0913 self._tokenizer = None self.name = format_name_with_suffix(self.model_name, suffix="_token_splitter") - def setup_on_node(self, _node_info: NodeInfo | None = None, _worker_metadata: WorkerMetadata | None = None) -> None: + def setup_on_node( + self, _node_info: NodeInfo | None = None, _worker_metadata: WorkerMetadata | None = None + ) -> None: """Download model weights to local cache once per physical node.""" from huggingface_hub import snapshot_download @@ -124,7 +126,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: output_df = pd.DataFrame(columns=output_cols) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=output_df, _metadata=batch._metadata, diff --git a/nemo_curator/stages/math/modifiers/llm_cleanup.py b/nemo_curator/stages/math/modifiers/llm_cleanup.py index 
09c41d6bab..abd1232390 100644 --- a/nemo_curator/stages/math/modifiers/llm_cleanup.py +++ b/nemo_curator/stages/math/modifiers/llm_cleanup.py @@ -143,7 +143,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: df = df[df[self.n_tokens_field] < threshold].copy() if len(df) == 0: return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=pd.DataFrame(columns=df.columns), _metadata=batch._metadata, @@ -200,7 +199,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: output_df[self.output_field] = generated_texts return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=output_df, _metadata=batch._metadata, diff --git a/nemo_curator/stages/math/modifiers/merge_chunks.py b/nemo_curator/stages/math/modifiers/merge_chunks.py index 908513d3bd..3d41cf1446 100644 --- a/nemo_curator/stages/math/modifiers/merge_chunks.py +++ b/nemo_curator/stages/math/modifiers/merge_chunks.py @@ -70,7 +70,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: if df.empty: return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, @@ -91,7 +90,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: if df.empty: logger.info(f"All {rows_before} rows filtered out during chunk merge") return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=pd.DataFrame(columns=df.columns), _metadata=batch._metadata, @@ -133,7 +131,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: logger.info(f"Chunk merge: {rows_before} rows -> {len(merged)} documents") return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=merged.reset_index(drop=True), _metadata=batch._metadata, diff --git a/nemo_curator/stages/synthetic/nemo_data_designer/data_designer.py b/nemo_curator/stages/synthetic/nemo_data_designer/data_designer.py index aff74fc177..a50ea477c5 100644 --- 
a/nemo_curator/stages/synthetic/nemo_data_designer/data_designer.py +++ b/nemo_curator/stages/synthetic/nemo_data_designer/data_designer.py @@ -132,12 +132,12 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: ) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, _stage_perf=batch._stage_perf, ) + # Explicitly export the class __all__ = ["DataDesignerStage"] diff --git a/nemo_curator/stages/synthetic/nemotron_cc/base.py b/nemo_curator/stages/synthetic/nemotron_cc/base.py index bc929d3254..c7b8b6d45e 100644 --- a/nemo_curator/stages/synthetic/nemotron_cc/base.py +++ b/nemo_curator/stages/synthetic/nemotron_cc/base.py @@ -68,7 +68,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: return DocumentBatch( data=df, dataset_name=batch.dataset_name, - task_id=f"{batch.task_id}_{self.name}", _metadata=batch._metadata, _stage_perf=batch._stage_perf, ) @@ -87,6 +86,7 @@ def _process_llm_response(self, response: list[str]) -> str: def _process_sync(self, df: pd.DataFrame) -> list[str]: """Process DataFrame using synchronous sequential processing.""" + def generate_response(row: pd.Series) -> str: prompt = self._process_llm_prompt(row) if self.system_prompt: @@ -95,9 +95,7 @@ def generate_response(row: pd.Series) -> str: {"role": "user", "content": prompt}, ] else: - messages = [ - {"role": "user", "content": prompt} - ] + messages = [{"role": "user", "content": prompt}] response = self.client.query_model( model=self.model_name, messages=messages, @@ -142,9 +140,7 @@ async def generate_response_async(row: pd.Series) -> str: {"role": "user", "content": prompt}, ] else: - messages = [ - {"role": "user", "content": prompt} - ] + messages = [{"role": "user", "content": prompt}] response = await self.client.query_model( model=self.model_name, messages=messages, diff --git a/nemo_curator/stages/synthetic/nemotron_cc/nemo_data_designer/base.py 
b/nemo_curator/stages/synthetic/nemotron_cc/nemo_data_designer/base.py index fd1689d2f9..cdb68a8bfe 100644 --- a/nemo_curator/stages/synthetic/nemotron_cc/nemo_data_designer/base.py +++ b/nemo_curator/stages/synthetic/nemotron_cc/nemo_data_designer/base.py @@ -163,12 +163,12 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: ) raise ValueError(msg) df[_FORMATTED_PROMPT_COL] = df.apply( - lambda row: self._process_llm_prompt(row.to_dict()), axis=1, + lambda row: self._process_llm_prompt(row.to_dict()), + axis=1, ) pre_batch = DocumentBatch( data=df, dataset_name=batch.dataset_name, - task_id=batch.task_id, _metadata=batch._metadata, _stage_perf=batch._stage_perf, ) @@ -193,7 +193,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: return DocumentBatch( data=result_df, dataset_name=batch.dataset_name, - task_id=f"{batch.task_id}_{self.name}", _metadata=batch._metadata, _stage_perf=batch._stage_perf, ) diff --git a/nemo_curator/stages/synthetic/nemotron_cc/nemo_data_designer/nemotron_cc.py b/nemo_curator/stages/synthetic/nemotron_cc/nemo_data_designer/nemotron_cc.py index 7f12201208..d92b5ed790 100644 --- a/nemo_curator/stages/synthetic/nemotron_cc/nemo_data_designer/nemotron_cc.py +++ b/nemo_curator/stages/synthetic/nemotron_cc/nemo_data_designer/nemotron_cc.py @@ -72,5 +72,3 @@ class KnowledgeListStage(NDDBaseSyntheticStage): prompt: str = KNOWLEDGE_LIST_PROMPT_TEMPLATE input_field: str = "text" output_field: str = "knowledge_list" - - diff --git a/nemo_curator/stages/synthetic/nemotron_cc/nemotron_cc.py b/nemo_curator/stages/synthetic/nemotron_cc/nemotron_cc.py index 619eaed577..94a871b376 100644 --- a/nemo_curator/stages/synthetic/nemotron_cc/nemotron_cc.py +++ b/nemo_curator/stages/synthetic/nemotron_cc/nemotron_cc.py @@ -41,6 +41,7 @@ class WikipediaParaphrasingStage(BaseSyntheticStage): output_field: str = "rephrased" name: str = "WikipediaParaphrasing" + @dataclass class DiverseQAStage(BaseSyntheticStage): system_prompt: str = 
NEMOTRON_CC_SYSTEM_PROMPT @@ -114,11 +115,11 @@ def _format_row(row: pd.Series) -> str: return DocumentBatch( data=df, dataset_name=batch.dataset_name, - task_id=f"{batch.task_id}_{self.name}", _metadata=batch._metadata, _stage_perf=batch._stage_perf, ) + @dataclass class DistillStage(BaseSyntheticStage): system_prompt: str = NEMOTRON_CC_DISTILL_SYSTEM_PROMPT @@ -127,6 +128,7 @@ class DistillStage(BaseSyntheticStage): output_field: str = "distill" name: str = "Distill" + @dataclass class ExtractKnowledgeStage(BaseSyntheticStage): system_prompt: str = NEMOTRON_CC_SYSTEM_PROMPT @@ -135,6 +137,7 @@ class ExtractKnowledgeStage(BaseSyntheticStage): output_field: str = "extract_knowledge" name: str = "ExtractKnowledge" + @dataclass class KnowledgeListStage(BaseSyntheticStage): system_prompt: str = NEMOTRON_CC_SYSTEM_PROMPT @@ -143,6 +146,7 @@ class KnowledgeListStage(BaseSyntheticStage): output_field: str = "knowledge_list" name: str = "KnowledgeList" + @dataclass class KnowledgeListPostProcessingStage(ProcessingStage[DocumentBatch, DocumentBatch]): """ @@ -173,7 +177,6 @@ def _format_text(generated_text: str) -> str: return DocumentBatch( data=df, dataset_name=batch.dataset_name, - task_id=f"{batch.task_id}_{self.name}", _metadata=batch._metadata, _stage_perf=batch._stage_perf, ) diff --git a/nemo_curator/stages/synthetic/qa_multilingual_synthetic.py b/nemo_curator/stages/synthetic/qa_multilingual_synthetic.py index b677b5dce1..c9989d1b19 100644 --- a/nemo_curator/stages/synthetic/qa_multilingual_synthetic.py +++ b/nemo_curator/stages/synthetic/qa_multilingual_synthetic.py @@ -58,7 +58,7 @@ def setup(self, _: WorkerMetadata | None = None) -> None: def process(self, _: _EmptyTask) -> DocumentBatch: responses = self._process_async() if self.is_async_client else self._process_sync() - return DocumentBatch(data=pd.DataFrame({"text": responses}), dataset_name="simple_synthetic_data", task_id=1) + return DocumentBatch(data=pd.DataFrame({"text": responses}), 
dataset_name="simple_synthetic_data") def _process_llm_response(self, response: list[str]) -> str: """Process a single response from the LLM.""" diff --git a/nemo_curator/stages/text/classifiers/aegis.py b/nemo_curator/stages/text/classifiers/aegis.py index 3f817c86de..6ca27d2c8a 100644 --- a/nemo_curator/stages/text/classifiers/aegis.py +++ b/nemo_curator/stages/text/classifiers/aegis.py @@ -251,7 +251,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: df = self._wrap_in_prompt(df) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, @@ -360,7 +359,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: df = self._postprocess_responses(df) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, diff --git a/nemo_curator/stages/text/classifiers/utils.py b/nemo_curator/stages/text/classifiers/utils.py index fd7c63d290..b8e366ffe0 100644 --- a/nemo_curator/stages/text/classifiers/utils.py +++ b/nemo_curator/stages/text/classifiers/utils.py @@ -49,7 +49,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: ) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=output, _metadata=batch._metadata, diff --git a/nemo_curator/stages/text/deduplication/removal.py b/nemo_curator/stages/text/deduplication/removal.py index f177d306ab..45d48d9a4a 100644 --- a/nemo_curator/stages/text/deduplication/removal.py +++ b/nemo_curator/stages/text/deduplication/removal.py @@ -95,7 +95,6 @@ def process(self, task: DocumentBatch) -> DocumentBatch: # Create output batch with filtered data return DocumentBatch( - task_id=f"removal_{task.task_id}", dataset_name=task.dataset_name, data=df, _metadata={**task._metadata, "num_removed": len(removal_ids)}, diff --git a/nemo_curator/stages/text/download/base/download.py b/nemo_curator/stages/text/download/base/download.py index 38f91e7434..afad4bbc7d 100644 --- 
a/nemo_curator/stages/text/download/base/download.py +++ b/nemo_curator/stages/text/download/base/download.py @@ -150,7 +150,6 @@ def process(self, task: FileGroupTask) -> FileGroupTask: local_files.append(downloaded_file) return FileGroupTask( - task_id=task.task_id, dataset_name=task.dataset_name, data=local_files, _metadata={ diff --git a/nemo_curator/stages/text/download/base/iterator.py b/nemo_curator/stages/text/download/base/iterator.py index b84ebb6922..6be952a77d 100644 --- a/nemo_curator/stages/text/download/base/iterator.py +++ b/nemo_curator/stages/text/download/base/iterator.py @@ -143,7 +143,6 @@ def process(self, task: FileGroupTask) -> DocumentBatch: df = pd.DataFrame(records) return DocumentBatch( - task_id=task.task_id, dataset_name=task.dataset_name, data=df, _metadata={ diff --git a/nemo_curator/stages/text/download/base/url_generation.py b/nemo_curator/stages/text/download/base/url_generation.py index 278bf88ad0..b91accbdff 100644 --- a/nemo_curator/stages/text/download/base/url_generation.py +++ b/nemo_curator/stages/text/download/base/url_generation.py @@ -69,7 +69,6 @@ def process(self, task: _EmptyTask) -> list[FileGroupTask]: return [ FileGroupTask( - task_id=f"{task.task_id}_{i}", dataset_name=task.dataset_name, data=[url], _metadata={"source_url": url}, diff --git a/nemo_curator/stages/text/download/common_crawl/download.py b/nemo_curator/stages/text/download/common_crawl/download.py index c2abdc984c..a323c599cf 100644 --- a/nemo_curator/stages/text/download/common_crawl/download.py +++ b/nemo_curator/stages/text/download/common_crawl/download.py @@ -420,7 +420,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: logger.info(f"Dropped {dropped_count}/{initial_count} rows due to failed WARC fetch.") return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, diff --git a/nemo_curator/stages/text/embedders/vllm.py b/nemo_curator/stages/text/embedders/vllm.py index 
7ec1009deb..f08ae4d63c 100644 --- a/nemo_curator/stages/text/embedders/vllm.py +++ b/nemo_curator/stages/text/embedders/vllm.py @@ -170,7 +170,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: self._log_metrics(metrics) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, diff --git a/nemo_curator/stages/text/experimental/__init__.py b/nemo_curator/stages/text/experimental/__init__.py index 84dca51c88..8802878ba6 100644 --- a/nemo_curator/stages/text/experimental/__init__.py +++ b/nemo_curator/stages/text/experimental/__init__.py @@ -17,4 +17,3 @@ APIs in this package are subject to change without the same compatibility guarantees as stable text stages. """ - diff --git a/nemo_curator/stages/text/experimental/translation/evaluation/faith.py b/nemo_curator/stages/text/experimental/translation/evaluation/faith.py index f92a5c1456..fcab31e533 100644 --- a/nemo_curator/stages/text/experimental/translation/evaluation/faith.py +++ b/nemo_curator/stages/text/experimental/translation/evaluation/faith.py @@ -217,7 +217,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: df[col] = pd.Series(dtype="float64") df["faith_parse_failed"] = pd.Series(dtype="bool") return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, @@ -233,7 +232,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: df = self._filter_rows(df) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, @@ -519,7 +517,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: ) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=filtered_df, _metadata=batch._metadata, diff --git a/nemo_curator/stages/text/experimental/translation/evaluation/text_quality.py b/nemo_curator/stages/text/experimental/translation/evaluation/text_quality.py index f6c1def51f..a76a6a2020 100644 
--- a/nemo_curator/stages/text/experimental/translation/evaluation/text_quality.py +++ b/nemo_curator/stages/text/experimental/translation/evaluation/text_quality.py @@ -81,7 +81,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: if not self.metrics: df[self.pass_column] = True return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, @@ -116,7 +115,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: df = df[df[self.pass_column]].reset_index(drop=True) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, diff --git a/nemo_curator/stages/text/experimental/translation/stages/format_translation_output.py b/nemo_curator/stages/text/experimental/translation/stages/format_translation_output.py index 9288ac3278..c668d4b96d 100644 --- a/nemo_curator/stages/text/experimental/translation/stages/format_translation_output.py +++ b/nemo_curator/stages/text/experimental/translation/stages/format_translation_output.py @@ -89,7 +89,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: df = df.drop(columns=columns_to_drop) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, diff --git a/nemo_curator/stages/text/experimental/translation/stages/merge_faith_scores.py b/nemo_curator/stages/text/experimental/translation/stages/merge_faith_scores.py index 61c6ae0951..43f7f40204 100644 --- a/nemo_curator/stages/text/experimental/translation/stages/merge_faith_scores.py +++ b/nemo_curator/stages/text/experimental/translation/stages/merge_faith_scores.py @@ -76,7 +76,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: ) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, diff --git a/nemo_curator/stages/text/experimental/translation/stages/reassembly.py 
b/nemo_curator/stages/text/experimental/translation/stages/reassembly.py index e8e1b2a4aa..a462659222 100644 --- a/nemo_curator/stages/text/experimental/translation/stages/reassembly.py +++ b/nemo_curator/stages/text/experimental/translation/stages/reassembly.py @@ -105,7 +105,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: if col not in out_df.columns: out_df[col] = pd.Series(dtype=_OUTPUT_COLUMN_DTYPES.get(col, "object")) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=out_df, _metadata=batch._metadata, @@ -122,7 +121,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: out_df = out_df.reset_index(drop=True) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=out_df, _metadata=batch._metadata, diff --git a/nemo_curator/stages/text/experimental/translation/stages/segmentation.py b/nemo_curator/stages/text/experimental/translation/stages/segmentation.py index 61e696fe64..1bc73922d5 100644 --- a/nemo_curator/stages/text/experimental/translation/stages/segmentation.py +++ b/nemo_curator/stages/text/experimental/translation/stages/segmentation.py @@ -322,7 +322,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: out_df["_seg_metadata"] = pd.Series(dtype="object") out_df["_seg_doc_id"] = pd.Series(dtype="int64") return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=out_df, _metadata=batch._metadata, @@ -361,7 +360,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: out_df = out_df.reset_index(drop=True) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=out_df, _metadata=batch._metadata, diff --git a/nemo_curator/stages/text/experimental/translation/stages/skipped_rows.py b/nemo_curator/stages/text/experimental/translation/stages/skipped_rows.py index 3a58afba7e..2bf5f0ee33 100644 --- a/nemo_curator/stages/text/experimental/translation/stages/skipped_rows.py +++ 
b/nemo_curator/stages/text/experimental/translation/stages/skipped_rows.py @@ -55,7 +55,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: len(df), ) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=metadata, @@ -89,7 +88,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: ) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=remaining_df, _metadata=metadata, @@ -136,7 +134,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: if order_col in df.columns: df = df.drop(columns=[order_col]) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=metadata, @@ -168,7 +165,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: ) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=merged, _metadata=metadata, diff --git a/nemo_curator/stages/text/experimental/translation/stages/translate.py b/nemo_curator/stages/text/experimental/translation/stages/translate.py index 48701ad0e2..08737973ac 100644 --- a/nemo_curator/stages/text/experimental/translation/stages/translate.py +++ b/nemo_curator/stages/text/experimental/translation/stages/translate.py @@ -185,7 +185,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: df["_translation_time"] = [0.0] * len(segments) df["_translation_error"] = [""] * len(segments) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, @@ -202,7 +201,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: df["_translation_error"] = errors return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, diff --git a/nemo_curator/stages/text/filters/score_filter.py b/nemo_curator/stages/text/filters/score_filter.py index 5ea1b76825..16af81f794 100644 --- a/nemo_curator/stages/text/filters/score_filter.py +++ 
b/nemo_curator/stages/text/filters/score_filter.py @@ -112,7 +112,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch | None: # Create output batch return DocumentBatch( - task_id=f"{batch.task_id}_{self.name}", dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, @@ -209,7 +208,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch | None: # Create output batch return DocumentBatch( - task_id=f"{batch.task_id}_{self.name}", dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, @@ -339,7 +337,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch | None: # Create output batch return DocumentBatch( - task_id=f"{batch.task_id}_{self.name}", dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, diff --git a/nemo_curator/stages/text/io/reader/base.py b/nemo_curator/stages/text/io/reader/base.py index 571c116eb5..551dc463be 100644 --- a/nemo_curator/stages/text/io/reader/base.py +++ b/nemo_curator/stages/text/io/reader/base.py @@ -98,7 +98,6 @@ def process(self, task: FileGroupTask) -> DocumentBatch: result = self._assign_ids_func(task.data, result) return DocumentBatch( - task_id=f"{task.task_id}_processed", dataset_name=task.dataset_name, data=result, _metadata=task._metadata, diff --git a/nemo_curator/stages/text/io/writer/base.py b/nemo_curator/stages/text/io/writer/base.py index f6008357d2..316dfd852c 100644 --- a/nemo_curator/stages/text/io/writer/base.py +++ b/nemo_curator/stages/text/io/writer/base.py @@ -20,11 +20,11 @@ from fsspec.core import url_to_fs from loguru import logger -import nemo_curator.stages.text.io.writer.utils as writer_utils from nemo_curator.stages.base import ProcessingStage from nemo_curator.tasks import DocumentBatch, FileGroupTask from nemo_curator.utils.client_utils import is_remote_url from nemo_curator.utils.file_utils import check_output_mode +from nemo_curator.utils.hash_utils import get_deterministic_hash @dataclass @@ -74,7 +74,7 @@ def process(self, task: 
DocumentBatch) -> FileGroupTask: """ # Get source files from metadata for deterministic naming if source_files := task._metadata.get("source_files"): - filename = writer_utils.get_deterministic_hash(source_files, task.task_id) + filename = get_deterministic_hash(source_files, task.task_id) else: logger.warning("The task does not have source_files in metadata, using UUID for base filename") filename = uuid.uuid4().hex @@ -94,7 +94,6 @@ def process(self, task: DocumentBatch) -> FileGroupTask: # Create FileGroupTask with written files using the full protocol-prefixed path return FileGroupTask( - task_id=task.task_id, dataset_name=task.dataset_name, data=[file_path_with_protocol], _metadata={ diff --git a/nemo_curator/stages/text/io/writer/megatron_tokenizer.py b/nemo_curator/stages/text/io/writer/megatron_tokenizer.py index 982d3e4d58..d14911adb5 100644 --- a/nemo_curator/stages/text/io/writer/megatron_tokenizer.py +++ b/nemo_curator/stages/text/io/writer/megatron_tokenizer.py @@ -21,10 +21,10 @@ from loguru import logger from transformers import AutoTokenizer -import nemo_curator.stages.text.io.writer.utils as writer_utils from nemo_curator.backends.base import NodeInfo, WorkerMetadata from nemo_curator.tasks import DocumentBatch, FileGroupTask from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS +from nemo_curator.utils.hash_utils import get_deterministic_hash from .base import BaseWriter from .utils import batched @@ -71,10 +71,7 @@ def setup_on_node(self, _node_info: NodeInfo | None = None, _worker_metadata: Wo try: # download the relevant tokenizer files once _ = AutoTokenizer.from_pretrained( - self.model_identifier, - cache_dir=self.cache_dir, - token=self.hf_token, - **self.transformers_init_kwargs + self.model_identifier, cache_dir=self.cache_dir, token=self.hf_token, **self.transformers_init_kwargs ) except Exception as e: msg = f"Failed to download {self.model_identifier}" @@ -84,10 +81,7 @@ def setup(self, _worker_metadata: 
WorkerMetadata | None = None) -> None: # Load the tokenizer try: self.tokenizer = AutoTokenizer.from_pretrained( - self.model_identifier, - cache_dir=self.cache_dir, - local_files_only=True, - **self.transformers_init_kwargs + self.model_identifier, cache_dir=self.cache_dir, local_files_only=True, **self.transformers_init_kwargs ) except Exception as e: # noqa: BLE001 # Allow this fallback since loading a tokenizer is lightweight @@ -95,17 +89,14 @@ def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None: logger.warning(msg) self.tokenizer = AutoTokenizer.from_pretrained( - self.model_identifier, - cache_dir=self.cache_dir, - token=self.hf_token, - **self.transformers_init_kwargs + self.model_identifier, cache_dir=self.cache_dir, token=self.hf_token, **self.transformers_init_kwargs ) def process(self, task: DocumentBatch) -> FileGroupTask: sequence_lengths: list[int] = [] # Get source files from metadata for deterministic naming if source_files := task._metadata.get("source_files"): - filename = writer_utils.get_deterministic_hash(source_files, task.task_id) + filename = get_deterministic_hash(source_files, task.task_id) else: logger.warning("The task does not have source_files in metadata, using UUID for base filename") filename = uuid.uuid4().hex @@ -161,7 +152,6 @@ def process(self, task: DocumentBatch) -> FileGroupTask: logger.debug(f"Written batch to {file_prefix} with {num_docs} documents ({sum(sequence_lengths)} tokens)") return FileGroupTask( - task_id=task.task_id, dataset_name=task.dataset_name, data=[file_prefix + file_extension for file_extension in self.file_extension], _metadata={ diff --git a/nemo_curator/stages/text/io/writer/utils.py b/nemo_curator/stages/text/io/writer/utils.py index 59c1b6c3b1..ce1e2f1310 100644 --- a/nemo_curator/stages/text/io/writer/utils.py +++ b/nemo_curator/stages/text/io/writer/utils.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import hashlib from collections.abc import Iterable, Iterator from itertools import islice from typing import Any @@ -35,9 +34,3 @@ def batched(iterable: Iterable[Any], n: int) -> Iterator[tuple[Any, ...]]: it = iter(iterable) while batch := tuple(islice(it, n)): yield batch - - -def get_deterministic_hash(inputs: list[str], seed: str = "") -> str: - """Create a deterministic hash from inputs.""" - combined = "|".join(sorted(inputs)) + "|" + seed - return hashlib.sha256(combined.encode()).hexdigest()[:12] diff --git a/nemo_curator/stages/text/models/model.py b/nemo_curator/stages/text/models/model.py index ba4d063ea9..6997fdb0bd 100644 --- a/nemo_curator/stages/text/models/model.py +++ b/nemo_curator/stages/text/models/model.py @@ -197,7 +197,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: df_cpu = df_cpu.sort_values(by=SEQ_ORDER_FIELD, ignore_index=True).drop(columns=[SEQ_ORDER_FIELD]) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df_cpu, _metadata=batch._metadata, diff --git a/nemo_curator/stages/text/models/tokenizer.py b/nemo_curator/stages/text/models/tokenizer.py index 00d7c7ab9b..01819405ca 100644 --- a/nemo_curator/stages/text/models/tokenizer.py +++ b/nemo_curator/stages/text/models/tokenizer.py @@ -124,7 +124,10 @@ def setup_on_node(self, _node_info: NodeInfo | None = None, _worker_metadata: Wo @lru_cache(maxsize=1) # noqa: B019 def load_cfg(self, local_files_only: bool = True) -> AutoConfig: return AutoConfig.from_pretrained( - self.model_identifier, cache_dir=self.cache_dir, local_files_only=local_files_only, **self.transformers_init_kwargs + self.model_identifier, + cache_dir=self.cache_dir, + local_files_only=local_files_only, + **self.transformers_init_kwargs, ) # We use the _setup function to ensure that everything needed for the tokenizer is downloaded and loaded properly @@ -134,7 +137,7 @@ def _setup(self, local_files_only: bool = True) -> None: padding_side=self.padding_side, 
cache_dir=self.cache_dir, local_files_only=local_files_only, - **self.transformers_init_kwargs + **self.transformers_init_kwargs, ) if self.unk_token: self.tokenizer.pad_token = self.tokenizer.unk_token @@ -180,7 +183,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: ) return DocumentBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=output, _metadata=batch._metadata, diff --git a/nemo_curator/stages/text/modifiers/modifier.py b/nemo_curator/stages/text/modifiers/modifier.py index ee14710eec..71add6f428 100644 --- a/nemo_curator/stages/text/modifiers/modifier.py +++ b/nemo_curator/stages/text/modifiers/modifier.py @@ -90,7 +90,6 @@ def process(self, batch: DocumentBatch) -> DocumentBatch | None: df[output_field_i] = [inner_modify_fn(**rec) for rec in df[cols].to_dict("records")] return DocumentBatch( - task_id=f"{batch.task_id}_{self.name}", dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, diff --git a/nemo_curator/stages/text/modules/add_id.py b/nemo_curator/stages/text/modules/add_id.py index f04a824652..9dea418d0a 100644 --- a/nemo_curator/stages/text/modules/add_id.py +++ b/nemo_curator/stages/text/modules/add_id.py @@ -68,13 +68,12 @@ def process(self, batch: DocumentBatch) -> DocumentBatch | None: msg = f"Column '{self.id_field}' already exists. Set overwrite=True to replace it." 
raise ValueError(msg) - uuid_part = str(batch._uuid) + uuid_part = batch.task_id prefix = f"{self.id_prefix}_{uuid_part}" if self.id_prefix else uuid_part df[self.id_field] = [f"{prefix}_{i}" for i in range(len(df))] # Create output batch return DocumentBatch( - task_id=f"{batch.task_id}_{self.name}", dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, diff --git a/nemo_curator/stages/text/modules/joiner.py b/nemo_curator/stages/text/modules/joiner.py index d706a4b3b2..f6a3fcd923 100644 --- a/nemo_curator/stages/text/modules/joiner.py +++ b/nemo_curator/stages/text/modules/joiner.py @@ -184,10 +184,8 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: joined = joined.drop(columns=self.segment_id_field) return DocumentBatch( - task_id=f"{batch.task_id}_{self.name}", dataset_name=batch.dataset_name, data=joined, _metadata=batch._metadata, _stage_perf=batch._stage_perf, ) - diff --git a/nemo_curator/stages/text/modules/splitter.py b/nemo_curator/stages/text/modules/splitter.py index 174af34508..466da6d55d 100644 --- a/nemo_curator/stages/text/modules/splitter.py +++ b/nemo_curator/stages/text/modules/splitter.py @@ -85,10 +85,8 @@ def process(self, batch: DocumentBatch) -> DocumentBatch: df = df.drop(columns=["_split_text"]).reset_index(drop=True) return DocumentBatch( - task_id=f"{batch.task_id}_{self.name}", dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, _stage_perf=batch._stage_perf, ) - diff --git a/nemo_curator/stages/video/clipping/clip_extraction_stages.py b/nemo_curator/stages/video/clipping/clip_extraction_stages.py index 6814f384b7..28d20ea9ef 100644 --- a/nemo_curator/stages/video/clipping/clip_extraction_stages.py +++ b/nemo_curator/stages/video/clipping/clip_extraction_stages.py @@ -149,7 +149,6 @@ def process(self, task: VideoTask) -> VideoTask: for idx in range(len(clip_chunks)): # create subtask for each video task subtask = VideoTask( - task_id=f"{task.task_id}_chunk_{idx}", 
dataset_name=task.dataset_name, data=Video( input_video=video.input_video, diff --git a/nemo_curator/stages/video/io/video_reader.py b/nemo_curator/stages/video/io/video_reader.py index b78145474f..9200e696f3 100644 --- a/nemo_curator/stages/video/io/video_reader.py +++ b/nemo_curator/stages/video/io/video_reader.py @@ -89,7 +89,6 @@ def process(self, task: FileGroupTask) -> VideoTask: raise ValueError(msg) video = Video(input_video=task.data[0]) video_task = VideoTask( - task_id=f"{task.data[0]}_processed", dataset_name=task.dataset_name, data=video, _metadata=deepcopy(task._metadata), diff --git a/nemo_curator/tasks/file_group.py b/nemo_curator/tasks/file_group.py index e3d5beb226..3b3463b515 100644 --- a/nemo_curator/tasks/file_group.py +++ b/nemo_curator/tasks/file_group.py @@ -17,6 +17,8 @@ from loguru import logger +from nemo_curator.utils.hash_utils import get_deterministic_hash + from .tasks import Task @@ -46,3 +48,10 @@ def validate(self) -> bool: err = f"Invalid data type in task {self.task_id}" raise TypeError(err) return True + + def get_deterministic_id(self) -> str: + """Content-based id derived from the sorted file paths. Stable + across runs even if the source stage emits the file group at a + different position (e.g. because new files were added or removed + between runs).""" + return get_deterministic_hash(sorted(self.data)) diff --git a/nemo_curator/tasks/tasks.py b/nemo_curator/tasks/tasks.py index b2836415c1..f3efb20286 100644 --- a/nemo_curator/tasks/tasks.py +++ b/nemo_curator/tasks/tasks.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import uuid from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any, Generic, TypeVar @@ -25,21 +24,35 @@ @dataclass class Task(ABC, Generic[T]): """Abstract base class for tasks in the pipeline. + A task represents a batch of data to be processed. 
Different modalities (text, audio, video) can implement their own task types. + Attributes: - task_id: Unique identifier for this task - dataset_name: Name of the dataset this task belongs to - dataframe_attribute: Name of the attribute that contains the dataframe data. We use this for input/output validations. - _stage_perf: List of stages perfs this task has passed through + task_id: Deterministic identifier for this task. NOT user-settable — + the framework assigns it via ``_set_task_id`` at every stage + boundary. It is an underscore-joined id path through the pipeline + DAG — the parents' ids plus this task's own segment (e.g. + ``"abc123_0_5"`` = source ``abc123``, then child 0, then + grandchild 5). Using the readable path directly (rather than a + hash of it) keeps task ids easy to debug. Empty string until the + first stage runs; two runs of the same pipeline on the same + inputs produce byte-identical ``task_id``s across all tasks. + + A ``task_id`` that starts with ``"r"`` (followed by a uuid) is a + fallback assigned when the parent→child mapping could NOT be + derived — e.g. a stage that overrides ``process_batch`` with an + ambiguous batch fan-out (M inputs → K≠M outputs). Such ids are + NON-deterministic (differ across runs). + dataset_name: Name of the dataset this task belongs to. + _stage_perf: List of stages perfs this task has passed through. 
""" - task_id: str dataset_name: str data: T _stage_perf: list[StagePerfStats] = field(default_factory=list) _metadata: dict[str, Any] = field(default_factory=dict) - _uuid: str = field(init=False, default_factory=lambda: str(uuid.uuid4())) + task_id: str = field(init=False, default="") def __post_init__(self) -> None: """Post-initialization hook.""" @@ -54,6 +67,52 @@ def add_stage_perf(self, perf_stats: StagePerfStats) -> None: """Add performance stats for a stage.""" self._stage_perf.append(perf_stats) + def _set_task_id(self, parent_task_id: str, current_task_id_suffix: str | int) -> None: + """Assign this task's deterministic ``task_id`` from its parent. + + The ``task_id`` is the parent id and this task's own segment joined + by ``"_"`` — e.g. parent ``"abc123"`` + suffix ``0`` → + ``"abc123_0"``. Always overwrites ``task_id``; there is no + idempotency check — each stage transition re-derives it, so the + same physical Python object passing through N stages gets N + distinct ``task_id``s (one per stage boundary). The dedup keys + used by resumability are captured BEFORE this method runs on a + given output, so the rewrite is safe. + + Only a single parent id is taken: the supported mappings (1→1, + 1→N fan-out, N→N positional) each give an output exactly one + parent. N→1 aggregations don't track ancestry — those outputs get + a random ``"r"``-prefixed id in the adapter instead of calling this. + + Args: + parent_task_id: ``task_id`` of the parent. An empty string + (an unassigned / EmptyTask parent) is dropped so it doesn't + contribute a leading ``"_"`` to the path. + current_task_id_suffix: This task's own segment of the id + path — appended after the parent id. Either a positional + index (``int`` → coerced to ``str``) for plain emissions, + or a string id (e.g. a content-based hash from + :py:meth:`get_deterministic_id`) for source-stage emissions + where stability across input reordering matters. 
+ """ + if parent_task_id: + self.task_id = f"{parent_task_id}_{current_task_id_suffix}" + else: + self.task_id = str(current_task_id_suffix) + + def get_deterministic_id(self) -> str | None: + """Return a content-based identifier for this task as a source, + or ``None`` to fall back to the positional index. + + Override in subclasses that have stable content. The canonical + example is :class:`FileGroupTask`, which hashes its sorted file + paths so that adding or removing files between runs doesn't shift + the identifiers of unchanged source partitions. + + Only called by source-stage adapters; non-source stages ignore + this and always use positional indices.""" + return None + def __repr__(self) -> str: subclass_name = self.__class__.__name__ return f"{subclass_name}(task_id={self.task_id}, dataset_name={self.dataset_name})" @@ -65,7 +124,15 @@ def validate(self) -> bool: @dataclass class _EmptyTask(Task[None]): - """Dummy task for testing.""" + """Placeholder input that seeds a pipeline (e.g. for ``ls``/source stages). + + Its ``task_id`` is fixed to ``"0"`` — the implicit root that every task + in a run descends from, so all ``task_id``s + share the ``"0"`` prefix (source partitions become ``"0_{hash}"``, + user-provided initial tasks become ``"0_0"``, ``"0_1"``, …). + """ + + task_id: str = field(init=False, default="0") @property def num_items(self) -> int: @@ -77,4 +144,4 @@ def validate(self) -> bool: # Empty tasks are just used for `ls` stages -EmptyTask = _EmptyTask(task_id="empty", dataset_name="empty", data=None) +EmptyTask = _EmptyTask(dataset_name="empty", data=None) diff --git a/nemo_curator/utils/column_utils.py b/nemo_curator/utils/column_utils.py index bae97ef275..065e3b52b7 100644 --- a/nemo_curator/utils/column_utils.py +++ b/nemo_curator/utils/column_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License.
+ def resolve_filename_column(add_filename_column: bool | str) -> str | None: """Resolve the filename column name based on the input parameter. diff --git a/nemo_curator/utils/hash_utils.py b/nemo_curator/utils/hash_utils.py new file mode 100644 index 0000000000..d7875474ac --- /dev/null +++ b/nemo_curator/utils/hash_utils.py @@ -0,0 +1,25 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib + + +def get_deterministic_hash(inputs: list[str], seed: str = "") -> str: + """Create a deterministic hash from inputs. + + Lives in ``nemo_curator.utils`` (not under ``stages/text``) so that + non-text modalities can use it without pulling in text dependencies. + """ + combined = "|".join(sorted(inputs)) + "|" + seed + return hashlib.sha256(combined.encode()).hexdigest()[:12] diff --git a/nemo_curator/utils/split_large_files.py b/nemo_curator/utils/split_large_files.py index 41b7e35045..0e9fd73f40 100644 --- a/nemo_curator/utils/split_large_files.py +++ b/nemo_curator/utils/split_large_files.py @@ -105,7 +105,11 @@ def split_parquet_file_by_size( if row_group.nbytes > target_size_bytes: # Flush any pending small row groups first to preserve order. 
if row_groups_to_write: - sub_table = row_groups_to_write[0] if len(row_groups_to_write) == 1 else pa.concat_tables(row_groups_to_write) + sub_table = ( + row_groups_to_write[0] + if len(row_groups_to_write) == 1 + else pa.concat_tables(row_groups_to_write) + ) out_file = _join_out_path(output_path, f"{outfile_prefix}_{file_idx}{ext}", so) _write_table_to_file(sub_table, out_file, so) file_idx += 1 @@ -129,7 +133,9 @@ def split_parquet_file_by_size( row_group_idx += 1 if row_groups_to_write: - sub_table = row_groups_to_write[0] if len(row_groups_to_write) == 1 else pa.concat_tables(row_groups_to_write) + sub_table = ( + row_groups_to_write[0] if len(row_groups_to_write) == 1 else pa.concat_tables(row_groups_to_write) + ) out_file = _join_out_path(output_path, f"{outfile_prefix}_{file_idx}{ext}", so) _write_table_to_file(sub_table, out_file, so) file_idx += 1 @@ -207,7 +213,9 @@ def parse_args(args: argparse.ArgumentParser | None = None) -> argparse.Namespac parser.add_argument( "--input-path", type=str, required=True, help="Path to input file, or directory of files, to split" ) - parser.add_argument("--file-type", type=str, required=True, help="Type of file to split", choices=["parquet", "jsonl"]) + parser.add_argument( + "--file-type", type=str, required=True, help="Type of file to split", choices=["parquet", "jsonl"] + ) parser.add_argument("--output-path", type=str, required=True, help="Output directory to store split files") parser.add_argument("--target-size-mb", type=int, default=128, help="Target size (in MB) of split output files") parser.add_argument( @@ -242,7 +250,7 @@ def main(args: argparse.ArgumentParser | None = None) -> None: target_size_mb=args.target_size_mb, storage_options=storage_options, ) - for f in files + for f in files ] ) diff --git a/tests/backends/ray_data/test_max_calls_pid.py b/tests/backends/ray_data/test_max_calls_pid.py index e53530c6d8..694d72140d 100644 --- a/tests/backends/ray_data/test_max_calls_pid.py +++ 
b/tests/backends/ray_data/test_max_calls_pid.py @@ -96,7 +96,6 @@ def ray_stage_spec(self) -> dict: def process(self, task: DocumentBatch) -> DocumentBatch: return DocumentBatch( - task_id=task.task_id, dataset_name=task.dataset_name, data=pd.DataFrame({"worker_pid": [os.getpid()]}), ) diff --git a/tests/backends/test_integration.py b/tests/backends/test_integration.py index f2d046e8c4..b5d3d453f3 100644 --- a/tests/backends/test_integration.py +++ b/tests/backends/test_integration.py @@ -136,6 +136,47 @@ def test_output_tasks(self): "Mismatch in dataset names" ) + def test_task_ids(self): + """task_ids are deterministic id paths assigned as tasks flow + through the pipeline. We can't assert exact id strings — the source + partitions hash temp-dir file paths, so the hash varies per run — but + we can pin the structure: rooted at the EmptyTask root "0", a stable + content-hash source segment, then a clean underscore-joined path.""" + assert self.output_tasks is not None, "Expected output tasks" + + task_ids = [task.task_id for task in self.output_tasks] + + # Every task that made it through the pipeline has an id assigned. + assert all(task_ids), "Every output task should have a non-empty task_id" + + # Task ids are unique per task. + assert len(set(task_ids)) == len(task_ids), "task_ids should be unique" + + for tid in task_ids: + # Clean "_"-joined id path: no empty segments, so no leading/ + # trailing/double underscores (the source's EmptyTask parent, + # whose id is "", is filtered out by _set_task_id). + segments = tid.split("_") + assert all(segments), f"task_id {tid!r} has an empty id segment" + + # Every id descends from the implicit EmptyTask root "0". + assert segments[0] == "0", f"task_id {tid!r} is not rooted at '0'" + + # The source stage (FilePartitioningStage) stamps each partition's + # FileGroupTask content hash as the second segment; downstream + # stages append positional indices. 
The hash is a 12-char hex + # string (see tests/tasks/test_file_group_tasks.py); its value is + # not knowable a priori, so we assert its shape. + source_hash = segments[1] + assert len(source_hash) == 12, f"task_id {tid!r} source-hash segment {source_hash!r} is not 12 chars" + assert all(c in "0123456789abcdef" for c in source_hash), ( + f"task_id {tid!r} source-hash segment {source_hash!r} is not hex" + ) + + # All partitions share the same downstream id structure: every id has + # the same number of segments (one per stage boundary it crossed). + assert len({len(tid.split("_")) for tid in task_ids}) == 1, "task_ids should all have the same path depth" + def test_perf_stats(self): """Test that performance statistics are correctly recorded for all stages.""" # Check content of stage perf stats diff --git a/tests/backends/test_task_id_postprocess.py b/tests/backends/test_task_id_postprocess.py new file mode 100644 index 0000000000..2bfebab8ce --- /dev/null +++ b/tests/backends/test_task_id_postprocess.py @@ -0,0 +1,129 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for ``BaseStageAdapter._post_process_task_ids`` — the single +place every backend assigns a deterministic ``task_id`` to emitted tasks. + +The happy-path flow (fan-out, 1:1, source content ids) is exercised +end-to-end against real backends in tests/backends/test_integration.py +(``test_task_ids``). 
This file keeps only the cases that are awkward or +impossible to trigger through a real pipeline: filter-``None`` positional +alignment, the ambiguous-cardinality ``"r"``-uuid fallback, in-place +re-derivation, and source content-id vs. positional-index selection.""" + +from dataclasses import dataclass + +from nemo_curator.backends.base import BaseStageAdapter +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.tasks import FileGroupTask, Task, _EmptyTask + + +@dataclass +class _NoopStage(ProcessingStage[Task, Task]): + name: str = "noop" + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def process(self, task: Task) -> Task: + return task + + +@dataclass +class _SimpleTask(Task[list[int]]): + @property + def num_items(self) -> int: + return 0 + + def validate(self) -> bool: + return True + + +def _task(task_id: str = "") -> _SimpleTask: + t = _SimpleTask(dataset_name="d", data=[]) + t.task_id = task_id + return t + + +def _assign(tasks: list[Task], results: list[Task | None], *, is_source: bool = False) -> list[Task]: + stage = _NoopStage() + stage.is_source_stage = is_source + return BaseStageAdapter(stage)._post_process_task_ids(tasks, results) + + +class TestPostProcessTaskIds: + def test_filter_stage_keeps_positional_alignment(self) -> None: + # A filter stage returns None in the filtered slot. None is NOT + # dropped before the length check, so the surviving outputs still map + # to their OWN parents (not shifted), then None slots are removed. 
+ p0, p1, p2 = _task("0_0"), _task("0_1"), _task("0_2") + c0, c2 = _task(), _task() + out = _assign([p0, p1, p2], [c0, None, c2]) + assert out == [c0, c2] + assert c0.task_id == "0_0_0" # child of p0, not shifted + assert c2.task_id == "0_2_0" # child of p2, not p1 + + def test_in_place_return_is_reassigned(self) -> None: + # A 1:1 stage that returns its input unchanged still gets a fresh + # segment appended (ids are re-derived at each stage boundary). + t = _task("0_5") + out = _assign([t], [t]) + assert out == [t] + assert t.task_id == "0_5_0" + + def test_ambiguous_batch_fanout_falls_back_to_uuid(self) -> None: + # M inputs → K outputs (K != M, M > 1): mapping is ambiguous, so each + # output gets a random uuid rather than being left empty. + p0, p1 = _task("0_0"), _task("0_1") + c0, c1, c2 = _task(), _task(), _task() + out = _assign([p0, p1], [c0, c1, c2]) + assert len(out) == 3 + assert all(t.task_id for t in out), "no output should be left without an id" + assert len({t.task_id for t in out}) == 3, "uuid ids should be unique" + # Non-deterministic fallback ids are flagged with an "r" prefix. + assert all(t.task_id.startswith("r") for t in out) + assert all("_" not in t.task_id for t in out) + + +class TestSourceStage: + def test_uses_content_id_rooted_at_input(self) -> None: + # FileGroupTask.get_deterministic_id() hashes its files; the source + # output is rooted at the EmptyTask input id "0" → "0_". + empty = _EmptyTask(dataset_name="empty", data=None) + a = FileGroupTask(dataset_name="d", data=["a.parquet"]) + b = FileGroupTask(dataset_name="d", data=["b.parquet"]) + _assign([empty], [a, b], is_source=True) + assert a.task_id == f"0_{a.get_deterministic_id()}" + assert b.task_id == f"0_{b.get_deterministic_id()}" + + def test_n_to_n_source_parents_each_output_by_position(self) -> None: + # A source stage can also be N→N (each input → one partition). 
Each + # output must descend from ITS positional parent, not all from + # tasks[0]; the content id is the segment. + p0, p1 = _task("0_0"), _task("0_1") + a = FileGroupTask(dataset_name="d", data=["a.parquet"]) + b = FileGroupTask(dataset_name="d", data=["b.parquet"]) + _assign([p0, p1], [a, b], is_source=True) + assert a.task_id == f"0_0_{a.get_deterministic_id()}" + assert b.task_id == f"0_1_{b.get_deterministic_id()}" + + def test_non_source_stage_ignores_content_id(self) -> None: + # The same FileGroupTask outputs from a NON-source stage use the + # positional index, not the content id. + parent = _task("0_2") + a = FileGroupTask(dataset_name="d", data=["a.parquet"]) + _assign([parent], [a], is_source=False) + assert a.task_id == "0_2_0" diff --git a/tests/backends/utils.py b/tests/backends/utils.py index c6aef281e8..9b019b2340 100644 --- a/tests/backends/utils.py +++ b/tests/backends/utils.py @@ -120,7 +120,6 @@ def process_batch(self, tasks: list[DocumentBatch]) -> list[DocumentBatch]: df[self.column_name] = df["text"].apply(len) results.append( DocumentBatch( - task_id=input_data.task_id, dataset_name=input_data.dataset_name, data=df, _metadata=input_data._metadata, @@ -180,7 +179,6 @@ def process(self, input_data: DocumentBatch) -> list[DocumentBatch]: row_df = pd.DataFrame([row.to_dict()]) tasks.append( DocumentBatch( - task_id=f"{input_data.task_id}_row_{row['id']}", dataset_name=input_data.dataset_name, data=row_df, _metadata=input_metadata_without_source_files, @@ -232,7 +230,6 @@ def process(self, input_data: DocumentBatch) -> DocumentBatch: df["node_id"] = self.node_id df["random_string"] = self.random_str return DocumentBatch( - task_id=input_data.task_id, dataset_name=input_data.dataset_name, data=df, _metadata=input_data._metadata, diff --git a/tests/core/serve/dynamo/test_integration.py b/tests/core/serve/dynamo/test_integration.py index 1707fc7ee0..21e931da7d 100644 --- a/tests/core/serve/dynamo/test_integration.py +++ 
b/tests/core/serve/dynamo/test_integration.py @@ -194,7 +194,6 @@ def test_pipeline_gpu_stage_uses_different_gpu_than_inference( initial_tasks = [ DocumentBatch( - task_id=f"gpu-sep-{i}", dataset_name="dynamo-coexistence", data=pd.DataFrame({"text": [f"hello {i}"]}), ) diff --git a/tests/core/serve/ray_serve/test_integration.py b/tests/core/serve/ray_serve/test_integration.py index c5f974b051..328065e929 100644 --- a/tests/core/serve/ray_serve/test_integration.py +++ b/tests/core/serve/ray_serve/test_integration.py @@ -140,7 +140,6 @@ def test_pipeline_gpu_stage_uses_different_gpu_than_inference( initial_tasks = [ DocumentBatch( - task_id=f"gpu-sep-{i}", dataset_name="ray-serve-coexistence", data=pd.DataFrame({"text": [f"hello {i}"]}), ) diff --git a/tests/pipelines/test_per_stage_runtime_env.py b/tests/pipelines/test_per_stage_runtime_env.py index 8ad3dc2260..bd71612fa1 100644 --- a/tests/pipelines/test_per_stage_runtime_env.py +++ b/tests/pipelines/test_per_stage_runtime_env.py @@ -65,7 +65,6 @@ def process(self, task: DocumentBatch) -> DocumentBatch: batch[f"{self.name}_version"] = packaging.__version__ batch[f"{self.name}_loguru_available"] = loguru_available return DocumentBatch( - task_id=task.task_id, dataset_name=task.dataset_name, data=batch, _metadata=task._metadata, @@ -75,7 +74,6 @@ def process(self, task: DocumentBatch) -> DocumentBatch: def _make_initial_task() -> DocumentBatch: return DocumentBatch( - task_id="runtime_env_test", dataset_name="test", data=pd.DataFrame({"text": ["hello"]}), ) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 87cc23e324..38511a6607 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -12,13 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from dataclasses import dataclass from unittest.mock import Mock, patch import pytest -from nemo_curator.pipeline.pipeline import Pipeline +from nemo_curator.pipeline.pipeline import Pipeline, assign_root_task_ids from nemo_curator.stages.base import ProcessingStage from nemo_curator.stages.resources import Resources +from nemo_curator.tasks import EmptyTask, Task, _EmptyTask + + +@dataclass +class _NoopStage(ProcessingStage[Task, Task]): + name: str = "noop" + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def process(self, task: Task) -> Task: + return task + + +@dataclass +class _SimpleTask(Task[list[int]]): + @property + def num_items(self) -> int: + return len(self.data) if self.data is not None else 0 + + def validate(self) -> bool: + return True def test_pipeline_uses_xenna_executor_by_default(): @@ -69,3 +95,65 @@ def test_raises_when_ray_serve_active_with_xenna_and_gpu_stages() -> None: with pytest.raises(RuntimeError, match="Cannot run XennaExecutor"): pipeline.run(executor=mock_executor) + + +class TestPipelineBuild: + """Source/sink role assignment performed by ``Pipeline.build``.""" + + def test_default_first_source_last_sink_stage(self) -> None: + """With no explicit marks, the first stage is the source and the + last is the sink; a lone stage is both.""" + s0, s1, s2 = _NoopStage(name="s0"), _NoopStage(name="s1"), _NoopStage(name="s2") + Pipeline(name="t", stages=[s0, s1, s2]).build() + assert [s.is_source_stage for s in (s0, s1, s2)] == [True, False, False] + assert [s.is_sink_stage for s in (s0, s1, s2)] == [False, False, True] + + lone = _NoopStage(name="lone") + Pipeline(name="t", stages=[lone]).build() + assert lone.is_source_stage is True + assert lone.is_sink_stage is True + + def test_explicit_marks_override_defaults(self) -> None: + s0, s1, s2 = _NoopStage(name="s0"), _NoopStage(name="s1"), _NoopStage(name="s2") + s1.is_source_stage = True + 
s1.is_sink_stage = True + Pipeline(name="t", stages=[s0, s1, s2]).build() + # Explicit source/sink win; defaults are not applied elsewhere. + assert [s.is_source_stage for s in (s0, s1, s2)] == [False, True, False] + assert [s.is_sink_stage for s in (s0, s1, s2)] == [False, True, False] + + def test_multiple_explicit_marks_raise(self) -> None: + s0, s1 = _NoopStage(name="s0"), _NoopStage(name="s1") + s0.is_source_stage = True + s1.is_source_stage = True + with pytest.raises(ValueError, match="multiple source stages marked"): + Pipeline(name="t", stages=[s0, s1]).build() + + t0, t1 = _NoopStage(name="t0"), _NoopStage(name="t1") + t0.is_sink_stage = True + t1.is_sink_stage = True + with pytest.raises(ValueError, match="multiple sink stages marked"): + Pipeline(name="t", stages=[t0, t1]).build() + + +class TestRootTaskIds: + """``assign_root_task_ids`` roots user-provided initial tasks under the + implicit ``_EmptyTask`` root id ``"0"``.""" + + def test_empty_task_id_is_zero(self) -> None: + assert EmptyTask.task_id == "0" + assert _EmptyTask(dataset_name="d", data=None).task_id == "0" + + def test_roots_user_tasks_at_zero(self) -> None: + tasks = [_SimpleTask(dataset_name="d", data=[1]) for _ in range(3)] + assign_root_task_ids(tasks) + # User-provided initial tasks are children of root "0", by position. + assert [t.task_id for t in tasks] == ["0_0", "0_1", "0_2"] + + def test_skips_empty_tasks(self) -> None: + et = _EmptyTask(dataset_name="d", data=None) + real = _SimpleTask(dataset_name="d", data=[1]) + assign_root_task_ids([et, real]) + # EmptyTask stays "0"; the real task is rooted by its position. 
+ assert et.task_id == "0" + assert real.task_id == "0_1" diff --git a/tests/stages/audio/advanced_pipelines/test_audio_data_filter.py b/tests/stages/audio/advanced_pipelines/test_audio_data_filter.py index 2bab453a7e..6de8773e87 100644 --- a/tests/stages/audio/advanced_pipelines/test_audio_data_filter.py +++ b/tests/stages/audio/advanced_pipelines/test_audio_data_filter.py @@ -196,22 +196,26 @@ def test_decompose_all_enabled_stage_count(self) -> None: assert len(stages) == 12 def test_decompose_all_disabled_except_mono(self) -> None: - stage = AudioDataFilterStage(config={ - "vad": {"enable": False}, - "band_filter": {"enable": False}, - "utmos": {"enable": False}, - "sigmos": {"enable": False}, - "speaker_separation": {"enable": False}, - }) + stage = AudioDataFilterStage( + config={ + "vad": {"enable": False}, + "band_filter": {"enable": False}, + "utmos": {"enable": False}, + "sigmos": {"enable": False}, + "speaker_separation": {"enable": False}, + } + ) stages = stage.decompose() assert len(stages) == 2 assert isinstance(stages[0], MonoConversionStage) assert isinstance(stages[1], TimestampMapperStage) def test_decompose_no_speaker_no_second_pass(self) -> None: - stage = AudioDataFilterStage(config={ - "speaker_separation": {"enable": False}, - }) + stage = AudioDataFilterStage( + config={ + "speaker_separation": {"enable": False}, + } + ) stages = stage.decompose() assert len(stages) == 6 stage_types = [type(s) for s in stages] @@ -259,10 +263,12 @@ def test_decompose_custom_thresholds(self) -> None: class TestDecomposeEdgeCases: def test_decompose_speaker_without_vad_no_concat(self) -> None: - stage = AudioDataFilterStage(config={ - "vad": {"enable": False}, - "speaker_separation": {"enable": True}, - }) + stage = AudioDataFilterStage( + config={ + "vad": {"enable": False}, + "speaker_separation": {"enable": True}, + } + ) stages = stage.decompose() stage_types = [type(s) for s in stages] assert SegmentConcatenationStage not in stage_types @@ -299,8 +305,14 
@@ class TestDefaultYAMLConsistency: def test_default_yaml_all_stages_have_resources(self) -> None: cfg = load_config(None) stages_with_cpus = [ - "mono_conversion", "vad", "band_filter", "utmos", - "sigmos", "concatenation", "speaker_separation", "timestamp_mapper", + "mono_conversion", + "vad", + "band_filter", + "utmos", + "sigmos", + "concatenation", + "speaker_separation", + "timestamp_mapper", ] for stage_name in stages_with_cpus: assert "cpus" in cfg[stage_name], f"{stage_name} missing 'cpus' in config" @@ -308,4 +320,3 @@ def test_default_yaml_all_stages_have_resources(self) -> None: stages_with_gpus = ["vad", "band_filter", "utmos", "sigmos", "speaker_separation"] for stage_name in stages_with_gpus: assert "gpus" in cfg[stage_name], f"{stage_name} missing 'gpus' in config" - diff --git a/tests/stages/audio/alm/pretrain/test_extraction.py b/tests/stages/audio/alm/pretrain/test_extraction.py index 4b7f4be219..65d3b585ed 100644 --- a/tests/stages/audio/alm/pretrain/test_extraction.py +++ b/tests/stages/audio/alm/pretrain/test_extraction.py @@ -66,7 +66,7 @@ def _task_with_plan(audio_path: Path, plan: list[dict], extras: dict | None = No data = {"id": "X", "audio_filepath": str(audio_path), _PLAN_DATA_KEY: plan} if extras: data.update(extras) - return AudioTask(task_id="t1", dataset_name="ds", data=data) + return AudioTask(dataset_name="ds", data=data) # ---------------------------------------------------------------------- @@ -160,11 +160,7 @@ def test_emitted_metadata_uses_tar_basename(self, tmp_path: Path) -> None: # Source row has audio_sample_rate / audio_num_channels populated, so # the extractor's conditional updates fire. 
- out = stage.process( - _task_with_plan( - src, plan, extras={"audio_sample_rate": 22050, "audio_num_channels": 2} - ) - ) + out = stage.process(_task_with_plan(src, plan, extras={"audio_sample_rate": 22050, "audio_num_channels": 2})) stage.teardown() d = out[0].data # Snippet ID + tar-internal basename (no slashes, no directory prefix) @@ -275,10 +271,6 @@ def test_zero_planned_emits_stub(self, tmp_path: Path) -> None: def test_invalid_output_format_rejected(self, tmp_path: Path) -> None: tar_path = str(tmp_path / "snips.tar") with pytest.raises(ValueError, match="output_format"): - SnippetExtractionStage( - output_dir=str(tmp_path), output_audio_tar_path=tar_path, output_format="m4a" - ) + SnippetExtractionStage(output_dir=str(tmp_path), output_audio_tar_path=tar_path, output_format="m4a") with pytest.raises(ValueError, match="target_sample_rate"): - SnippetExtractionStage( - output_dir=str(tmp_path), output_audio_tar_path=tar_path, target_sample_rate=0 - ) + SnippetExtractionStage(output_dir=str(tmp_path), output_audio_tar_path=tar_path, target_sample_rate=0) diff --git a/tests/stages/audio/alm/pretrain/test_io.py b/tests/stages/audio/alm/pretrain/test_io.py index dc59bf8634..667f3cf2b7 100644 --- a/tests/stages/audio/alm/pretrain/test_io.py +++ b/tests/stages/audio/alm/pretrain/test_io.py @@ -34,8 +34,8 @@ from nemo_curator.tasks import AudioTask, _EmptyTask -def _make_audio_task(data: dict | None = None, *, task_id: str = "t1") -> AudioTask: - return AudioTask(task_id=task_id, dataset_name="ds", data=data or {}) +def _make_audio_task(data: dict | None = None) -> AudioTask: + return AudioTask(dataset_name="ds", data=data or {}) def _ts(start: float, end: float, text: str = "x", text_itn: str | None = None) -> dict: @@ -71,19 +71,19 @@ def manifest_path(tmp_path: Path) -> Path: class TestReadLongFormManifestStage: def test_emits_one_task_per_valid_row(self, tmp_path: Path, manifest_path: Path) -> None: stage = 
ReadLongFormManifestStage(input_manifest=str(manifest_path), audio_dir=str(tmp_path)) - out = stage.process(_EmptyTask(task_id="empty", dataset_name="empty", data=None)) + out = stage.process(_EmptyTask(dataset_name="empty", data=None)) assert len(out) == 1 assert out[0].data["id"] == "A" def test_resolves_audio_path_against_audio_dir(self, tmp_path: Path, manifest_path: Path) -> None: stage = ReadLongFormManifestStage(input_manifest=str(manifest_path), audio_dir="/data") - out = stage.process(_EmptyTask(task_id="empty", dataset_name="empty", data=None)) + out = stage.process(_EmptyTask(dataset_name="empty", data=None)) assert out[0].data["audio_filepath"] == "/data/a.wav" def test_missing_manifest_raises(self, tmp_path: Path) -> None: stage = ReadLongFormManifestStage(input_manifest=str(tmp_path / "nope.jsonl"), audio_dir=str(tmp_path)) with pytest.raises(FileNotFoundError): - stage.process(_EmptyTask(task_id="empty", dataset_name="empty", data=None)) + stage.process(_EmptyTask(dataset_name="empty", data=None)) def test_skips_rows_missing_id(self, tmp_path: Path) -> None: # `id` is required: the metrics aggregator keys per-source records on @@ -100,7 +100,7 @@ def test_skips_rows_missing_id(self, tmp_path: Path) -> None: for r in rows: f.write(json.dumps(r) + "\n") stage = ReadLongFormManifestStage(input_manifest=str(p), audio_dir="/data") - out = stage.process(_EmptyTask(task_id="empty", dataset_name="empty", data=None)) + out = stage.process(_EmptyTask(dataset_name="empty", data=None)) assert [t.data["id"] for t in out] == ["OK"] def test_skips_duplicate_ids(self, tmp_path: Path) -> None: @@ -117,7 +117,7 @@ def test_skips_duplicate_ids(self, tmp_path: Path) -> None: for r in rows: f.write(json.dumps(r) + "\n") stage = ReadLongFormManifestStage(input_manifest=str(p), audio_dir="/data") - out = stage.process(_EmptyTask(task_id="empty", dataset_name="empty", data=None)) + out = stage.process(_EmptyTask(dataset_name="empty", data=None)) assert [t.data["id"] for t 
in out] == ["A", "B"] # The kept "A" row is the first one, not the duplicate. assert out[0].data["audio_filepath"] == "/data/a.wav" @@ -136,7 +136,7 @@ def test_basename_mode_rejects_duplicate_basenames(self, tmp_path: Path) -> None f.write(json.dumps(r) + "\n") stage = ReadLongFormManifestStage(input_manifest=str(p), audio_dir="/data") with pytest.raises(ValueError, match="duplicate audio basename"): - stage.process(_EmptyTask(task_id="empty", dataset_name="empty", data=None)) + stage.process(_EmptyTask(dataset_name="empty", data=None)) def test_relative_mode_preserves_subdirectories(self, tmp_path: Path) -> None: # 'relative' joins audio_dir / audio_filepath verbatim and so does @@ -149,10 +149,8 @@ def test_relative_mode_preserves_subdirectories(self, tmp_path: Path) -> None: with p.open("w", encoding="utf-8") as f: for r in rows: f.write(json.dumps(r) + "\n") - stage = ReadLongFormManifestStage( - input_manifest=str(p), audio_dir="/data", audio_path_resolution="relative" - ) - out = stage.process(_EmptyTask(task_id="empty", dataset_name="empty", data=None)) + stage = ReadLongFormManifestStage(input_manifest=str(p), audio_dir="/data", audio_path_resolution="relative") + out = stage.process(_EmptyTask(dataset_name="empty", data=None)) assert [t.data["audio_filepath"] for t in out] == [ "/data/shard1/foo.wav", "/data/shard2/foo.wav", @@ -167,10 +165,8 @@ def test_as_is_mode_returns_value_unchanged(self, tmp_path: Path) -> None: with p.open("w", encoding="utf-8") as f: for r in rows: f.write(json.dumps(r) + "\n") - stage = ReadLongFormManifestStage( - input_manifest=str(p), audio_dir="/ignored", audio_path_resolution="as_is" - ) - out = stage.process(_EmptyTask(task_id="empty", dataset_name="empty", data=None)) + stage = ReadLongFormManifestStage(input_manifest=str(p), audio_dir="/ignored", audio_path_resolution="as_is") + out = stage.process(_EmptyTask(dataset_name="empty", data=None)) assert [t.data["audio_filepath"] for t in out] == [ "/absolute/a.wav", 
"relative/b.wav", @@ -180,11 +176,9 @@ def test_unknown_resolution_mode_raises(self, tmp_path: Path) -> None: p = tmp_path / "in.jsonl" with p.open("w", encoding="utf-8") as f: f.write(json.dumps({"id": "A", "audio_filepath": "./a.wav", "segments": []}) + "\n") - stage = ReadLongFormManifestStage( - input_manifest=str(p), audio_dir="/data", audio_path_resolution="not_a_mode" - ) + stage = ReadLongFormManifestStage(input_manifest=str(p), audio_dir="/data", audio_path_resolution="not_a_mode") with pytest.raises(ValueError, match="unknown audio_path_resolution"): - stage.process(_EmptyTask(task_id="empty", dataset_name="empty", data=None)) + stage.process(_EmptyTask(dataset_name="empty", data=None)) # ---------------------------------------------------------------------- diff --git a/tests/stages/audio/alm/pretrain/test_pipeline.py b/tests/stages/audio/alm/pretrain/test_pipeline.py index f327763826..d140096d3a 100644 --- a/tests/stages/audio/alm/pretrain/test_pipeline.py +++ b/tests/stages/audio/alm/pretrain/test_pipeline.py @@ -163,7 +163,7 @@ def _run_pipeline_inline( # noqa: PLR0913 aggregator.setup_on_node() aggregator.setup() - for input_task in reader.process(_EmptyTask(task_id="e", dataset_name="e", data=None)): + for input_task in reader.process(_EmptyTask(dataset_name="e", data=None)): filtered = overlap.process(input_task) planned = planner.process(filtered) rep_filtered = rep_filter.process(planned) diff --git a/tests/stages/audio/alm/pretrain/test_planning.py b/tests/stages/audio/alm/pretrain/test_planning.py index 67b61cecbf..cc455aa7c7 100644 --- a/tests/stages/audio/alm/pretrain/test_planning.py +++ b/tests/stages/audio/alm/pretrain/test_planning.py @@ -41,8 +41,8 @@ from pathlib import Path -def _make_audio_task(data: dict | None = None, *, task_id: str = "t1") -> AudioTask: - return AudioTask(task_id=task_id, dataset_name="ds", data=data or {}) +def _make_audio_task(data: dict | None = None) -> AudioTask: + return AudioTask(dataset_name="ds", 
data=data or {}) def _ts(start: float, end: float, text: str = "x", text_itn: str | None = None) -> dict: @@ -153,7 +153,7 @@ def tokenizer_dir(self, tmp_path: Path) -> Path: ) def _make_task_with_plan(self, plan: list[dict]) -> AudioTask: - task = AudioTask(task_id="t1", dataset_name="ds", data={_PLAN_DATA_KEY: plan}) + task = AudioTask(dataset_name="ds", data={_PLAN_DATA_KEY: plan}) task._metadata = {} return task @@ -246,12 +246,8 @@ def test_process_is_idempotent_under_re_execution(self, tokenizer_dir: Path) -> repeat = "thank you for watching " * 10 # Build TWO tasks with the same plan; process() runs once per task. - task1 = self._make_task_with_plan( - [{"start": 0.0, "end": 30.0, "segments": [_ts(0.0, 30.0, repeat)]}] - ) - task2 = self._make_task_with_plan( - [{"start": 0.0, "end": 30.0, "segments": [_ts(0.0, 30.0, repeat)]}] - ) + task1 = self._make_task_with_plan([{"start": 0.0, "end": 30.0, "segments": [_ts(0.0, 30.0, repeat)]}]) + task2 = self._make_task_with_plan([{"start": 0.0, "end": 30.0, "segments": [_ts(0.0, 30.0, repeat)]}]) # First-pass result. 
out1 = stage.process(task1) meta1 = out1._metadata[_PRETRAIN_META_KEY] diff --git a/tests/stages/audio/datasets/test_fleurs_create_initial_manifest.py b/tests/stages/audio/datasets/test_fleurs_create_initial_manifest.py index 6e4ab28206..0066e3f7ea 100644 --- a/tests/stages/audio/datasets/test_fleurs_create_initial_manifest.py +++ b/tests/stages/audio/datasets/test_fleurs_create_initial_manifest.py @@ -109,7 +109,7 @@ def test_process_end_to_end(tmp_path: Path) -> None: patch("nemo_curator.stages.audio.datasets.fleurs.create_initial_manifest.download_file"), patch("nemo_curator.stages.audio.datasets.fleurs.create_initial_manifest.extract_archive"), ): - results = stage.process(_EmptyTask(task_id="empty", dataset_name="test", data=None)) + results = stage.process(_EmptyTask(dataset_name="test", data=None)) assert len(results) == 2 assert results[0].data["text"] == "hello" assert results[1].data["text"] == "world" diff --git a/tests/stages/audio/datasets/test_readspeech_create_initial_manifest.py b/tests/stages/audio/datasets/test_readspeech_create_initial_manifest.py index 69930993f3..ee6bc2930e 100644 --- a/tests/stages/audio/datasets/test_readspeech_create_initial_manifest.py +++ b/tests/stages/audio/datasets/test_readspeech_create_initial_manifest.py @@ -80,12 +80,13 @@ def test_process_end_to_end(tmp_path: Path) -> None: (wav_dir / "book_00001_chp_0002_reader_00200_0_seg_1_seg1.wav").write_bytes(b"\x00") stage = CreateInitialManifestReadSpeechStage( - raw_data_dir=str(tmp_path / "dns_data"), max_samples=-1, auto_download=False, + raw_data_dir=str(tmp_path / "dns_data"), + max_samples=-1, + auto_download=False, ) - results = stage.process(_EmptyTask(task_id="empty", dataset_name="test", data=None)) + results = stage.process(_EmptyTask(dataset_name="test", data=None)) assert len(results) == 2 assert all(isinstance(r, AudioTask) for r in results) - assert results[0].task_id == "readspeech_0" assert results[0].dataset_name == "DNS-ReadSpeech" @@ -93,7 +94,7 @@ def 
test_process_empty_dir(tmp_path: Path) -> None: empty_dir = tmp_path / "empty" empty_dir.mkdir() stage = CreateInitialManifestReadSpeechStage(raw_data_dir=str(empty_dir), auto_download=False) - results = stage.process(_EmptyTask(task_id="e", dataset_name="t", data=None)) + results = stage.process(_EmptyTask(dataset_name="t", data=None)) assert results == [] @@ -104,6 +105,6 @@ def test_auto_download_calls_download(tmp_path: Path) -> None: (wav_dir / "book_00000_chp_0001_reader_00100_0_seg_1_seg1.wav").write_bytes(b"\x00") with patch.object(stage, "download_and_extract", return_value=str(wav_dir)) as mock_dl: - results = stage.process(_EmptyTask(task_id="e", dataset_name="t", data=None)) + results = stage.process(_EmptyTask(dataset_name="t", data=None)) mock_dl.assert_called_once() assert len(results) == 1 diff --git a/tests/stages/audio/filtering/test_band.py b/tests/stages/audio/filtering/test_band.py index 906abade18..44fc1d0684 100644 --- a/tests/stages/audio/filtering/test_band.py +++ b/tests/stages/audio/filtering/test_band.py @@ -25,7 +25,6 @@ def _make_task(waveform: torch.Tensor | None = None, sample_rate: int = 48000) - waveform = torch.randn(1, sample_rate) return AudioTask( data={"waveform": waveform, "sample_rate": sample_rate}, - task_id="test", dataset_name="test", ) @@ -82,7 +81,7 @@ def test_no_waveform_no_filepath_skipped(self, mock_init: MagicMock) -> None: stage = BandFilterStage(band_value="full_band") stage._predictor = MagicMock() - task = AudioTask(data={"some_key": "value"}, task_id="test", dataset_name="test") + task = AudioTask(data={"some_key": "value"}, dataset_name="test") result = stage.process(task) assert result == [] @@ -105,7 +104,6 @@ def predict_side_effect(_waveform: object, _sample_rate: int) -> str: segments = [{"waveform": torch.randn(1, sr), "sample_rate": sr, "segment_num": i} for i in range(4)] task = AudioTask( data={"segments": segments, "original_file": "test.wav"}, - task_id="test", dataset_name="test", ) @@ -126,7 
+124,6 @@ def test_process_nested_all_filtered_returns_empty(self, mock_init: MagicMock) - segments = [{"waveform": torch.randn(1, sr), "sample_rate": sr, "segment_num": i} for i in range(3)] task = AudioTask( data={"segments": segments}, - task_id="test", dataset_name="test", ) diff --git a/tests/stages/audio/filtering/test_sigmos.py b/tests/stages/audio/filtering/test_sigmos.py index 45da1651c1..7fde7efcb7 100644 --- a/tests/stages/audio/filtering/test_sigmos.py +++ b/tests/stages/audio/filtering/test_sigmos.py @@ -44,7 +44,6 @@ def _make_task(duration_s: float = 1.0, sample_rate: int = 48000) -> AudioTask: num_samples = int(duration_s * sample_rate) return AudioTask( data={"waveform": torch.randn(1, num_samples), "sample_rate": sample_rate}, - task_id="test", dataset_name="test", ) @@ -96,15 +95,17 @@ def test_none_thresholds_disable_checks(self, mock_init: MagicMock) -> None: @patch.object(SIGMOSFilterStage, "_initialize_model") def test_partial_threshold_fail(self, mock_init: MagicMock) -> None: stage = SIGMOSFilterStage(noise_threshold=4.0, ovrl_threshold=None) - stage._model = _make_mock_model({ - "MOS_NOISE": 3.0, - "MOS_OVRL": 5.0, - "MOS_SIG": 5.0, - "MOS_COL": 5.0, - "MOS_DISC": 5.0, - "MOS_LOUD": 5.0, - "MOS_REVERB": 5.0, - }) + stage._model = _make_mock_model( + { + "MOS_NOISE": 3.0, + "MOS_OVRL": 5.0, + "MOS_SIG": 5.0, + "MOS_COL": 5.0, + "MOS_DISC": 5.0, + "MOS_LOUD": 5.0, + "MOS_REVERB": 5.0, + } + ) result = stage.process(_make_task()) @@ -134,7 +135,7 @@ def test_no_audio_no_filepath_skipped(self, mock_init: MagicMock) -> None: stage = SIGMOSFilterStage() stage._model = _make_mock_model(_GOOD_SCORES) - task = AudioTask(data={"some_key": "value"}, task_id="test", dataset_name="test") + task = AudioTask(data={"some_key": "value"}, dataset_name="test") result = stage.process(task) assert result == [] @@ -168,7 +169,6 @@ def fake_run(audio: object, sr: int) -> dict: # noqa: ARG001 segments = [{"waveform": torch.randn(1, sr), "sample_rate": sr, 
"segment_num": i} for i in range(4)] task = AudioTask( data={"segments": segments, "original_file": "test.wav"}, - task_id="test", dataset_name="test", ) @@ -189,7 +189,6 @@ def test_process_nested_all_filtered_returns_empty(self, mock_init: MagicMock) - segments = [{"waveform": torch.randn(1, sr), "sample_rate": sr, "segment_num": i} for i in range(3)] task = AudioTask( data={"segments": segments}, - task_id="test", dataset_name="test", ) diff --git a/tests/stages/audio/filtering/test_utmos.py b/tests/stages/audio/filtering/test_utmos.py index bed7d5ba86..75c82ddd52 100644 --- a/tests/stages/audio/filtering/test_utmos.py +++ b/tests/stages/audio/filtering/test_utmos.py @@ -26,7 +26,6 @@ def _make_task(duration_s: float = 1.0, sample_rate: int = 16000) -> AudioTask: num_samples = int(duration_s * sample_rate) return AudioTask( data={"waveform": torch.randn(1, num_samples), "sample_rate": sample_rate}, - task_id="test", dataset_name="test", ) @@ -84,7 +83,7 @@ def test_no_waveform_no_filepath_skipped(self, mock_ensure: MagicMock) -> None: stage = UTMOSFilterStage(mos_threshold=3.0) stage._model = _mock_model(4.0) - task = AudioTask(data={"some_key": "value"}, task_id="test", dataset_name="test") + task = AudioTask(data={"some_key": "value"}, dataset_name="test") result = stage.process(task) assert result == [] @@ -122,7 +121,6 @@ def model_side_effect(_waveform: torch.Tensor, sr: int = 16000) -> torch.Tensor: segments = [{"waveform": torch.randn(1, sr), "sample_rate": sr, "segment_num": i} for i in range(4)] task = AudioTask( data={"segments": segments, "original_file": "test.wav"}, - task_id="test", dataset_name="test", ) @@ -143,7 +141,6 @@ def test_process_nested_all_filtered_returns_empty(self, mock_ensure: MagicMock) segments = [{"waveform": torch.randn(1, sr), "sample_rate": sr, "segment_num": i} for i in range(3)] task = AudioTask( data={"segments": segments}, - task_id="test", dataset_name="test", ) diff --git 
a/tests/stages/audio/inference/speaker_diarization/test_sortformer.py b/tests/stages/audio/inference/speaker_diarization/test_sortformer.py index d50bce6186..8e9b714a41 100644 --- a/tests/stages/audio/inference/speaker_diarization/test_sortformer.py +++ b/tests/stages/audio/inference/speaker_diarization/test_sortformer.py @@ -138,7 +138,6 @@ def test_process_audio_task(self) -> None: {"start": 0.0, "end": 2.7, "speaker": "speaker_0"}, {"start": 0.8, "end": 13.6, "speaker": "speaker_1"}, ] - assert result.task_id.endswith("_sortformer") mock_model.diarize.assert_called_once_with( audio=["/test/audio1.wav"], batch_size=1, diff --git a/tests/stages/audio/inference/test_asr_nemo.py b/tests/stages/audio/inference/test_asr_nemo.py index aed0e5fd6f..1137451dbe 100644 --- a/tests/stages/audio/inference/test_asr_nemo.py +++ b/tests/stages/audio/inference/test_asr_nemo.py @@ -73,8 +73,8 @@ def test_process_batch_success(self) -> None: stage.setup() tasks = [ - AudioTask(data={"audio_filepath": "/test/audio1.wav"}, task_id="t1"), - AudioTask(data={"audio_filepath": "/test/audio2.mp3"}, task_id="t2"), + AudioTask(data={"audio_filepath": "/test/audio1.wav"}), + AudioTask(data={"audio_filepath": "/test/audio2.mp3"}), ] results = stage.process_batch(tasks) diff --git a/tests/stages/audio/io/test_convert.py b/tests/stages/audio/io/test_convert.py index 7ab13ddce6..cadd80b70b 100644 --- a/tests/stages/audio/io/test_convert.py +++ b/tests/stages/audio/io/test_convert.py @@ -21,7 +21,6 @@ def test_audio_to_document_stage_process_raises() -> None: entry = AudioTask( - task_id="t1", dataset_name="ds", data={"audio_filepath": "/a.wav", "text": "hello"}, ) @@ -32,10 +31,7 @@ def test_audio_to_document_stage_process_raises() -> None: def test_process_batch_aggregates_into_single_dataframe() -> None: - tasks = [ - AudioTask(task_id=f"t{i}", dataset_name="ds", data={"audio_filepath": f"/{i}.wav", "text": f"text{i}"}) - for i in range(5) - ] + tasks = [AudioTask(dataset_name="ds", 
data={"audio_filepath": f"/{i}.wav", "text": f"text{i}"}) for i in range(5)] stage = AudioToDocumentStage() result = stage.process_batch(tasks) @@ -47,7 +43,6 @@ def test_process_batch_aggregates_into_single_dataframe() -> None: assert len(doc.data) == 5 assert list(doc.data["audio_filepath"]) == ["/0.wav", "/1.wav", "/2.wav", "/3.wav", "/4.wav"] assert list(doc.data["text"]) == ["text0", "text1", "text2", "text3", "text4"] - assert doc.task_id == "t0,t1,t2,t3,t4" assert doc.dataset_name == "ds" @@ -59,8 +54,8 @@ def test_process_batch_empty() -> None: def test_process_batch_preserves_stage_perf() -> None: tasks = [ - AudioTask(task_id="t1", dataset_name="ds", data={"audio_filepath": "/a.wav"}, _stage_perf=["perf1"]), - AudioTask(task_id="t2", dataset_name="ds", data={"audio_filepath": "/b.wav"}, _stage_perf=["perf2"]), + AudioTask(dataset_name="ds", data={"audio_filepath": "/a.wav"}, _stage_perf=["perf1"]), + AudioTask(dataset_name="ds", data={"audio_filepath": "/b.wav"}, _stage_perf=["perf2"]), ] stage = AudioToDocumentStage() result = stage.process_batch(tasks) @@ -69,9 +64,9 @@ def test_process_batch_preserves_stage_perf() -> None: def test_process_batch_deduplicates_dataset_names() -> None: tasks = [ - AudioTask(task_id="t1", dataset_name="ds_a", data={"audio_filepath": "/a.wav"}), - AudioTask(task_id="t2", dataset_name="ds_b", data={"audio_filepath": "/b.wav"}), - AudioTask(task_id="t3", dataset_name="ds_a", data={"audio_filepath": "/c.wav"}), + AudioTask(dataset_name="ds_a", data={"audio_filepath": "/a.wav"}), + AudioTask(dataset_name="ds_b", data={"audio_filepath": "/b.wav"}), + AudioTask(dataset_name="ds_a", data={"audio_filepath": "/c.wav"}), ] stage = AudioToDocumentStage() result = stage.process_batch(tasks) @@ -79,7 +74,7 @@ def test_process_batch_deduplicates_dataset_names() -> None: def test_process_batch_single_task() -> None: - task = AudioTask(task_id="only", dataset_name="ds", data={"audio_filepath": "/x.wav", "text": "hi"}) + task = 
AudioTask(dataset_name="ds", data={"audio_filepath": "/x.wav", "text": "hi"}) stage = AudioToDocumentStage() result = stage.process_batch([task]) assert len(result) == 1 diff --git a/tests/stages/audio/io/test_extract_segments.py b/tests/stages/audio/io/test_extract_segments.py index 333197cd1a..2229e4629b 100644 --- a/tests/stages/audio/io/test_extract_segments.py +++ b/tests/stages/audio/io/test_extract_segments.py @@ -120,7 +120,9 @@ def test_non_standard_id(self) -> None: class TestIntervalsFromTimestamps: def test_basic(self) -> None: - assert _intervals_from_timestamps({"original_start_ms": 1000, "original_end_ms": 3000, "duration": 2.0}) == [(1000, 3000, 2.0)] + assert _intervals_from_timestamps({"original_start_ms": 1000, "original_end_ms": 3000, "duration": 2.0}) == [ + (1000, 3000, 2.0) + ] def test_computed_duration(self) -> None: assert _intervals_from_timestamps({"original_start_ms": 500, "original_end_ms": 2500}) == [(500, 2500, 2.0)] @@ -146,7 +148,7 @@ class TestReadSegment: def test_reads_correct_slice(self, wav_dir: Path) -> None: filepath = _wav_path(wav_dir) original, sr = sf.read(filepath) - expected = original[int(1.0 * sr):int(2.0 * sr)] + expected = original[int(1.0 * sr) : int(2.0 * sr)] result = _read_segment(filepath, 1000, 2000, sr) np.testing.assert_array_almost_equal(result, expected, decimal=4) @@ -263,7 +265,7 @@ def test_invalid_format_raises(self, tmp_path: Path) -> None: def test_process_raises_not_implemented(self, tmp_path: Path) -> None: stage = SegmentExtractionStage(output_dir=str(tmp_path)) - task = AudioTask(data={"original_file": "/a.wav"}, task_id="t", dataset_name="d") + task = AudioTask(data={"original_file": "/a.wav"}, dataset_name="d") with pytest.raises(NotImplementedError): stage.process(task) @@ -287,8 +289,24 @@ def test_combo2_timestamps(self, wav_dir: Path, tmp_path: Path) -> None: out_dir = str(tmp_path / "extracted") stage = SegmentExtractionStage(output_dir=out_dir) tasks = [ - 
AudioTask(data={"original_file": _wav_path(wav_dir), "original_start_ms": 0, "original_end_ms": 2000, "duration": 2.0}, task_id="t1", dataset_name="test"), - AudioTask(data={"original_file": _wav_path(wav_dir), "original_start_ms": 2500, "original_end_ms": 4500, "duration": 2.0}, task_id="t2", dataset_name="test"), + AudioTask( + data={ + "original_file": _wav_path(wav_dir), + "original_start_ms": 0, + "original_end_ms": 2000, + "duration": 2.0, + }, + dataset_name="test", + ), + AudioTask( + data={ + "original_file": _wav_path(wav_dir), + "original_start_ms": 2500, + "original_end_ms": 4500, + "duration": 2.0, + }, + dataset_name="test", + ), ] result = stage.process_batch(tasks) assert len(result) == 2 @@ -300,8 +318,13 @@ def test_combo3_diar_segments(self, wav_dir: Path, tmp_path: Path) -> None: stage = SegmentExtractionStage(output_dir=out_dir) tasks = [ AudioTask( - data={"original_file": _wav_path(wav_dir), "speaker_id": "speaker_0", "num_speakers": 2, "diar_segments": [[0.5, 1.5], [2.0, 3.0]]}, - task_id="t1", dataset_name="test", + data={ + "original_file": _wav_path(wav_dir), + "speaker_id": "speaker_0", + "num_speakers": 2, + "diar_segments": [[0.5, 1.5], [2.0, 3.0]], + }, + dataset_name="test", ), ] result = stage.process_batch(tasks) @@ -313,8 +336,26 @@ def test_combo4_speaker_timestamps(self, wav_dir: Path, tmp_path: Path) -> None: out_dir = str(tmp_path / "extracted") stage = SegmentExtractionStage(output_dir=out_dir) tasks = [ - AudioTask(data={"original_file": _wav_path(wav_dir), "speaker_id": "speaker_0", "original_start_ms": 0, "original_end_ms": 1000, "duration": 1.0}, task_id="t1", dataset_name="test"), - AudioTask(data={"original_file": _wav_path(wav_dir), "speaker_id": "speaker_1", "original_start_ms": 1500, "original_end_ms": 2500, "duration": 1.0}, task_id="t2", dataset_name="test"), + AudioTask( + data={ + "original_file": _wav_path(wav_dir), + "speaker_id": "speaker_0", + "original_start_ms": 0, + "original_end_ms": 1000, + "duration": 
1.0, + }, + dataset_name="test", + ), + AudioTask( + data={ + "original_file": _wav_path(wav_dir), + "speaker_id": "speaker_1", + "original_start_ms": 1500, + "original_end_ms": 2500, + "duration": 1.0, + }, + dataset_name="test", + ), ] result = stage.process_batch(tasks) assert len(result) == 2 @@ -325,7 +366,15 @@ def test_flac_format(self, wav_dir: Path, tmp_path: Path) -> None: out_dir = str(tmp_path / "extracted") stage = SegmentExtractionStage(output_dir=out_dir, output_format="flac") tasks = [ - AudioTask(data={"original_file": _wav_path(wav_dir), "original_start_ms": 0, "original_end_ms": 1000, "duration": 1.0}, task_id="t1", dataset_name="test"), + AudioTask( + data={ + "original_file": _wav_path(wav_dir), + "original_start_ms": 0, + "original_end_ms": 1000, + "duration": 1.0, + }, + dataset_name="test", + ), ] stage.process_batch(tasks) assert os.path.exists(os.path.join(out_dir, "file_a_segment_000.flac")) @@ -334,7 +383,15 @@ def test_missing_original_file_skipped(self, tmp_path: Path) -> None: out_dir = str(tmp_path / "extracted") stage = SegmentExtractionStage(output_dir=out_dir) tasks = [ - AudioTask(data={"original_file": "/nonexistent/audio.wav", "original_start_ms": 0, "original_end_ms": 1000, "duration": 1.0}, task_id="t1", dataset_name="test"), + AudioTask( + data={ + "original_file": "/nonexistent/audio.wav", + "original_start_ms": 0, + "original_end_ms": 1000, + "duration": 1.0, + }, + dataset_name="test", + ), ] result = stage.process_batch(tasks) assert len(result) == 1 @@ -344,7 +401,16 @@ def test_metadata_csv_written(self, wav_dir: Path, tmp_path: Path) -> None: out_dir = str(tmp_path / "extracted") stage = SegmentExtractionStage(output_dir=out_dir) tasks = [ - AudioTask(data={"original_file": _wav_path(wav_dir), "original_start_ms": 0, "original_end_ms": 1000, "duration": 1.0, "utmos_mos": 4.2}, task_id="t1", dataset_name="test"), + AudioTask( + data={ + "original_file": _wav_path(wav_dir), + "original_start_ms": 0, + "original_end_ms": 
1000, + "duration": 1.0, + "utmos_mos": 4.2, + }, + dataset_name="test", + ), ] stage.process_batch(tasks) with open(os.path.join(out_dir, "metadata.csv")) as f: @@ -354,12 +420,20 @@ def test_metadata_csv_written(self, wav_dir: Path, tmp_path: Path) -> None: def test_audio_content_matches_source(self, wav_dir: Path, tmp_path: Path) -> None: original, sr = sf.read(_wav_path(wav_dir)) - expected_slice = original[int(1.0 * sr):int(2.0 * sr)] + expected_slice = original[int(1.0 * sr) : int(2.0 * sr)] out_dir = str(tmp_path / "extracted") stage = SegmentExtractionStage(output_dir=out_dir) tasks = [ - AudioTask(data={"original_file": _wav_path(wav_dir), "original_start_ms": 1000, "original_end_ms": 2000, "duration": 1.0}, task_id="t1", dataset_name="test"), + AudioTask( + data={ + "original_file": _wav_path(wav_dir), + "original_start_ms": 1000, + "original_end_ms": 2000, + "duration": 1.0, + }, + dataset_name="test", + ), ] stage.process_batch(tasks) @@ -375,7 +449,13 @@ def test_audio_content_matches_source(self, wav_dir: Path, tmp_path: Path) -> No class TestExtractFromManifest: def test_end_to_end(self, wav_dir: Path, tmp_path: Path) -> None: entries = [ - {"original_file": _wav_path(wav_dir), "original_start_ms": 0, "original_end_ms": 2000, "duration": 2.0, "utmos_mos": 4.0}, + { + "original_file": _wav_path(wav_dir), + "original_start_ms": 0, + "original_end_ms": 2000, + "duration": 2.0, + "utmos_mos": 4.0, + }, {"original_file": _wav_path(wav_dir), "original_start_ms": 2500, "original_end_ms": 4500, "duration": 2.0}, ] manifest_path = _write_manifest(tmp_path, entries) @@ -406,7 +486,17 @@ def test_directory_input(self, wav_dir: Path, tmp_path: Path) -> None: manifest_dir.mkdir() for i, start in enumerate([0, 2000]): with open(str(manifest_dir / f"part_{i}.jsonl"), "w") as f: - f.write(json.dumps({"original_file": _wav_path(wav_dir), "original_start_ms": start, "original_end_ms": start + 1000, "duration": 1.0}) + "\n") + f.write( + json.dumps( + { + 
"original_file": _wav_path(wav_dir), + "original_start_ms": start, + "original_end_ms": start + 1000, + "duration": 1.0, + } + ) + + "\n" + ) out_dir = str(tmp_path / "output") stage = SegmentExtractionStage(output_dir=out_dir) diff --git a/tests/stages/audio/postprocessing/test_timestamp_mapper.py b/tests/stages/audio/postprocessing/test_timestamp_mapper.py index 74a077d8a9..7b12bddb01 100644 --- a/tests/stages/audio/postprocessing/test_timestamp_mapper.py +++ b/tests/stages/audio/postprocessing/test_timestamp_mapper.py @@ -24,8 +24,8 @@ from nemo_curator.tasks import AudioTask -def _make_task(data: dict, task_id: str = "test", metadata: dict | None = None) -> AudioTask: - t = AudioTask(data=data, task_id=task_id, dataset_name="test_ds") +def _make_task(data: dict, metadata: dict | None = None) -> AudioTask: + t = AudioTask(data=data, dataset_name="test_ds") if metadata: t._metadata = metadata return t @@ -35,9 +35,27 @@ class TestTranslateToOriginal: """Unit tests for the pure _translate_to_original() function.""" MAPPINGS: ClassVar[list[dict]] = [ - {"concat_start_ms": 0, "concat_end_ms": 2000, "original_file": "a.wav", "original_start_ms": 5000, "original_end_ms": 7000}, - {"concat_start_ms": 2000, "concat_end_ms": 5000, "original_file": "b.wav", "original_start_ms": 0, "original_end_ms": 3000}, - {"concat_start_ms": 5000, "concat_end_ms": 8000, "original_file": "c.wav", "original_start_ms": 10000, "original_end_ms": 13000}, + { + "concat_start_ms": 0, + "concat_end_ms": 2000, + "original_file": "a.wav", + "original_start_ms": 5000, + "original_end_ms": 7000, + }, + { + "concat_start_ms": 2000, + "concat_end_ms": 5000, + "original_file": "b.wav", + "original_start_ms": 0, + "original_end_ms": 3000, + }, + { + "concat_start_ms": 5000, + "concat_end_ms": 8000, + "original_file": "c.wav", + "original_start_ms": 10000, + "original_end_ms": 13000, + }, ] def test_single_mapping_exact_match(self) -> None: @@ -74,8 +92,20 @@ def test_cross_boundary_span(self) -> None: 
def test_silence_gap_no_overlap(self) -> None: """Segment falls entirely in a gap between mappings.""" mappings = [ - {"concat_start_ms": 0, "concat_end_ms": 1000, "original_file": "a.wav", "original_start_ms": 0, "original_end_ms": 1000}, - {"concat_start_ms": 3000, "concat_end_ms": 5000, "original_file": "b.wav", "original_start_ms": 0, "original_end_ms": 2000}, + { + "concat_start_ms": 0, + "concat_end_ms": 1000, + "original_file": "a.wav", + "original_start_ms": 0, + "original_end_ms": 1000, + }, + { + "concat_start_ms": 3000, + "concat_end_ms": 5000, + "original_file": "b.wav", + "original_start_ms": 0, + "original_end_ms": 2000, + }, ] results = _translate_to_original(mappings, 1000, 3000) assert len(results) == 0 @@ -84,7 +114,13 @@ def test_malformed_mapping_missing_key(self) -> None: """Malformed mapping (missing key) is skipped gracefully.""" mappings = [ {"concat_start_ms": 0, "concat_end_ms": 2000}, - {"concat_start_ms": 2000, "concat_end_ms": 4000, "original_file": "b.wav", "original_start_ms": 0, "original_end_ms": 2000}, + { + "concat_start_ms": 2000, + "concat_end_ms": 4000, + "original_file": "b.wav", + "original_start_ms": 0, + "original_end_ms": 2000, + }, ] results = _translate_to_original(mappings, 0, 4000) assert len(results) == 1 @@ -98,7 +134,13 @@ def test_empty_mappings(self) -> None: def test_no_overlap_before_all_mappings(self) -> None: """Segment ends before any mapping starts.""" mappings = [ - {"concat_start_ms": 5000, "concat_end_ms": 8000, "original_file": "a.wav", "original_start_ms": 0, "original_end_ms": 3000}, + { + "concat_start_ms": 5000, + "concat_end_ms": 8000, + "original_file": "a.wav", + "original_start_ms": 0, + "original_end_ms": 3000, + }, ] results = _translate_to_original(mappings, 0, 1000) assert results == [] diff --git a/tests/stages/audio/preprocessing/test_concatenation.py b/tests/stages/audio/preprocessing/test_concatenation.py index 827d285b35..5a76ddae32 100644 --- 
a/tests/stages/audio/preprocessing/test_concatenation.py +++ b/tests/stages/audio/preprocessing/test_concatenation.py @@ -34,7 +34,6 @@ def _make_segment_dict(duration_ms: int = 1000, sample_rate: int = 48000, segmen def _make_nested_task(segments: list[dict]) -> AudioTask: return AudioTask( data={"segments": segments, "original_file": "test.wav"}, - task_id="test_task", dataset_name="ds", ) @@ -92,7 +91,6 @@ def test_silence_duration_in_output(self) -> None: def test_no_waveform_in_tasks(self) -> None: task = AudioTask( data={"segments": [{"other_key": "value"}]}, - task_id="empty", dataset_name="ds", ) stage = SegmentConcatenationStage() @@ -100,7 +98,7 @@ def test_no_waveform_in_tasks(self) -> None: assert result == [] def test_missing_segments_key_raises(self) -> None: - task = AudioTask(data={"other_key": "value"}, task_id="empty", dataset_name="ds") + task = AudioTask(data={"other_key": "value"}, dataset_name="ds") stage = SegmentConcatenationStage() with pytest.raises(ValueError): # noqa: PT011 stage.process(task) diff --git a/tests/stages/audio/preprocessing/test_mono_conversion.py b/tests/stages/audio/preprocessing/test_mono_conversion.py index 61563a1f6f..9ba9e45389 100644 --- a/tests/stages/audio/preprocessing/test_mono_conversion.py +++ b/tests/stages/audio/preprocessing/test_mono_conversion.py @@ -33,7 +33,7 @@ def test_process_stereo_to_mono(self, tmp_path: Path) -> None: with patch(MOCK_TARGET, return_value=(stereo, 48000)), patch(MOCK_EXISTS, return_value=True): stage = MonoConversionStage(output_sample_rate=48000) - task = AudioTask(data={"audio_filepath": wav.as_posix()}, task_id="t1") + task = AudioTask(data={"audio_filepath": wav.as_posix()}) result = stage.process(task) assert isinstance(result, AudioTask) @@ -51,7 +51,7 @@ def test_process_mono_passthrough(self, tmp_path: Path) -> None: with patch(MOCK_TARGET, return_value=(mono, 48000)), patch(MOCK_EXISTS, return_value=True): stage = MonoConversionStage(output_sample_rate=48000) - task = 
AudioTask(data={"audio_filepath": wav.as_posix()}, task_id="t1") + task = AudioTask(data={"audio_filepath": wav.as_posix()}) result = stage.process(task) assert isinstance(result, AudioTask) @@ -66,7 +66,7 @@ def test_strict_sample_rate_rejects_mismatch(self, tmp_path: Path) -> None: with patch(MOCK_TARGET, return_value=(audio, 22050)), patch(MOCK_EXISTS, return_value=True): stage = MonoConversionStage(output_sample_rate=48000, strict_sample_rate=True) - task = AudioTask(data={"audio_filepath": wav.as_posix()}, task_id="t1") + task = AudioTask(data={"audio_filepath": wav.as_posix()}) result = stage.process(task) assert result == [] @@ -79,7 +79,7 @@ def test_non_strict_sample_rate_accepts_any(self, tmp_path: Path) -> None: with patch(MOCK_TARGET, return_value=(audio, 22050)), patch(MOCK_EXISTS, return_value=True): stage = MonoConversionStage(output_sample_rate=48000, strict_sample_rate=False) - task = AudioTask(data={"audio_filepath": wav.as_posix()}, task_id="t1") + task = AudioTask(data={"audio_filepath": wav.as_posix()}) result = stage.process(task) assert isinstance(result, AudioTask) @@ -87,13 +87,13 @@ def test_non_strict_sample_rate_accepts_any(self, tmp_path: Path) -> None: def test_missing_file_skipped(self) -> None: stage = MonoConversionStage() - task = AudioTask(data={"audio_filepath": "/nonexistent/path.wav"}, task_id="t1") + task = AudioTask(data={"audio_filepath": "/nonexistent/path.wav"}) result = stage.process(task) assert result == [] def test_missing_filepath_key_skipped(self) -> None: stage = MonoConversionStage() - task = AudioTask(data={"other_key": "value"}, task_id="t1") + task = AudioTask(data={"other_key": "value"}) result = stage.process(task) assert result == [] @@ -103,7 +103,7 @@ def test_read_exception_skipped(self, tmp_path: Path) -> None: with patch(MOCK_TARGET, side_effect=RuntimeError("bad file")), patch(MOCK_EXISTS, return_value=True): stage = MonoConversionStage() - task = AudioTask(data={"audio_filepath": wav.as_posix()}, 
task_id="t1") + task = AudioTask(data={"audio_filepath": wav.as_posix()}) result = stage.process(task) assert result == [] diff --git a/tests/stages/audio/segmentation/test_speaker_separation.py b/tests/stages/audio/segmentation/test_speaker_separation.py index 43067eb5e2..f2a3efb9e7 100644 --- a/tests/stages/audio/segmentation/test_speaker_separation.py +++ b/tests/stages/audio/segmentation/test_speaker_separation.py @@ -34,7 +34,6 @@ def _make_task(duration_sec: float = 10.0, sample_rate: int = 48000) -> AudioTas num_samples = int(duration_sec * sample_rate) return AudioTask( data={"waveform": torch.randn(1, num_samples), "sample_rate": sample_rate}, - task_id="test", dataset_name="test", ) @@ -119,7 +118,6 @@ def test_no_audio_no_filepath_skipped(self, mock_init: MagicMock) -> None: task = AudioTask( data={"some_key": "value"}, - task_id="test", dataset_name="test", ) result = stage.process(task) diff --git a/tests/stages/audio/segmentation/test_vad_segmentation.py b/tests/stages/audio/segmentation/test_vad_segmentation.py index b4f3d01c24..2b35544b7d 100644 --- a/tests/stages/audio/segmentation/test_vad_segmentation.py +++ b/tests/stages/audio/segmentation/test_vad_segmentation.py @@ -40,7 +40,6 @@ def test_process_returns_segments(self, mock_load_vad: MagicMock, mock_get_ts: M waveform = torch.randn(1, sr * 10) task = AudioTask( data={"waveform": waveform, "sample_rate": sr}, - task_id="test", dataset_name="test", ) @@ -69,7 +68,6 @@ def test_process_output_keys(self, mock_load_vad: MagicMock, mock_get_ts: MagicM waveform = torch.randn(1, sr * 10) task = AudioTask( data={"waveform": waveform, "sample_rate": sr}, - task_id="test", dataset_name="test", ) @@ -91,7 +89,6 @@ def test_empty_speech_returns_empty(self, mock_load_vad: MagicMock, mock_get_ts: waveform = torch.randn(1, 48000 * 5) task = AudioTask( data={"waveform": waveform, "sample_rate": 48000}, - task_id="test", dataset_name="test", ) @@ -117,7 +114,6 @@ def test_segment_numbering(self, mock_load_vad: 
MagicMock, mock_get_ts: MagicMoc waveform = torch.randn(1, sr * 10) task = AudioTask( data={"waveform": waveform, "sample_rate": sr}, - task_id="test", dataset_name="test", ) @@ -136,7 +132,6 @@ def test_missing_waveform_and_filepath_skipped(self, mock_load_vad: MagicMock, m task = AudioTask( data={"some_key": "value"}, - task_id="test", dataset_name="test", ) @@ -161,7 +156,6 @@ def test_nested_mode_returns_single_task(self, mock_load_vad: MagicMock, mock_ge waveform = torch.randn(1, sr * 10) task = AudioTask( data={"waveform": waveform, "sample_rate": sr}, - task_id="test", dataset_name="test", ) @@ -191,7 +185,6 @@ def test_nested_mode_no_speech_returns_task_with_empty_segments( waveform = torch.randn(1, 48000 * 5) task = AudioTask( data={"waveform": waveform, "sample_rate": 48000}, - task_id="test", dataset_name="test", ) diff --git a/tests/stages/audio/test_common.py b/tests/stages/audio/test_common.py index e8d701c57c..0e50ecc1cc 100644 --- a/tests/stages/audio/test_common.py +++ b/tests/stages/audio/test_common.py @@ -45,7 +45,7 @@ def _make_file_group_task(paths: list[str]) -> FileGroupTask: - return FileGroupTask(task_id="test", dataset_name="test", data=paths) + return FileGroupTask(dataset_name="test", data=paths) # --------------------------------------------------------------------------- @@ -371,7 +371,6 @@ def test_writes_entry_to_jsonl(self, tmp_path: Path) -> None: task = AudioTask( data={"audio_filepath": "a.wav", "duration": 1.0}, - task_id="t1", dataset_name="ds", ) writer.process(task) @@ -386,12 +385,11 @@ def test_returns_audio_task(self, tmp_path: Path) -> None: writer.setup_on_node() writer.setup() - task = AudioTask(data={"x": 1}, task_id="t1", dataset_name="ds") + task = AudioTask(data={"x": 1}, dataset_name="ds") result = writer.process(task) assert isinstance(result, AudioTask) assert result.data == {"x": 1} - assert result.task_id == "t1" assert result.dataset_name == "ds" def test_propagates_metadata_and_stage_perf(self, tmp_path: 
Path) -> None: @@ -404,7 +402,6 @@ def test_propagates_metadata_and_stage_perf(self, tmp_path: Path) -> None: stage_perf = [{"stage": "some_stage", "process_time": 0.5}] task = AudioTask( data={"x": 1}, - task_id="t1", dataset_name="ds", _metadata=metadata, _stage_perf=stage_perf, @@ -420,9 +417,9 @@ def test_appends_across_multiple_process_calls(self, tmp_path: Path) -> None: writer.setup_on_node() writer.setup() - writer.process(AudioTask(data={"entry": 1}, task_id="t1")) - writer.process(AudioTask(data={"entry": 2}, task_id="t2")) - writer.process(AudioTask(data={"entry": 3}, task_id="t3")) + writer.process(AudioTask(data={"entry": 1})) + writer.process(AudioTask(data={"entry": 2})) + writer.process(AudioTask(data={"entry": 3})) lines = out.read_text().strip().split("\n") assert len(lines) == 3 @@ -450,7 +447,7 @@ def test_handles_unicode_content(self, tmp_path: Path) -> None: writer.setup_on_node() writer.setup() - task = AudioTask(data={"text": "日本語テスト", "speaker": "Ñoño"}, task_id="t1") + task = AudioTask(data={"text": "日本語テスト", "speaker": "Ñoño"}) writer.process(task) loaded = json.loads(out.read_text().strip()) @@ -470,7 +467,7 @@ def test_preserves_nested_structures(self, tmp_path: Path) -> None: ], "stats": {"lost_bw": 3, "lost_sr": 0}, } - task = AudioTask(data=entry, task_id="t1") + task = AudioTask(data=entry) writer.process(task) loaded = json.loads(out.read_text().strip()) @@ -495,12 +492,12 @@ def test_reader_writer_round_trip(self, sample_entries: list[dict], tmp_path: Pa writer = ManifestWriterStage(output_path=str(out)) writer.setup_on_node() writer.setup() - for i, entry in enumerate(sample_entries): - task = AudioTask(data=entry, task_id=f"t{i}") + for _i, entry in enumerate(sample_entries): + task = AudioTask(data=entry) writer.process(task) reader = ManifestReaderStage() - result = reader.process(FileGroupTask(task_id="rt", dataset_name="rt", data=[str(out)])) + result = reader.process(FileGroupTask(dataset_name="rt", data=[str(out)])) assert 
len(result) == len(sample_entries) for orig, audio_entry in zip(sample_entries, result, strict=True): diff --git a/tests/stages/common/test_base.py b/tests/stages/common/test_base.py index b553eba6e1..dc7c3b0c42 100644 --- a/tests/stages/common/test_base.py +++ b/tests/stages/common/test_base.py @@ -24,7 +24,7 @@ class MockTask(Task[dict]): def __init__(self, data: dict | None = None): self.data = data or {} - super().__init__(task_id="", dataset_name="", data=self.data) + super().__init__(dataset_name="", data=self.data) @property def num_items(self) -> int: diff --git a/tests/stages/common/test_client_partitioning.py b/tests/stages/common/test_client_partitioning.py index 4b919f27ed..0b9ccef18c 100644 --- a/tests/stages/common/test_client_partitioning.py +++ b/tests/stages/common/test_client_partitioning.py @@ -28,7 +28,6 @@ class TestClientPartitioningStage: def empty_task(self) -> _EmptyTask: """Create an empty task for testing.""" return _EmptyTask( - task_id="test_task", dataset_name="test_dataset", data=None, _metadata={"source": "test"}, @@ -109,7 +108,6 @@ def test_process_basic_functionality( assert len(result) == 3 assert isinstance(result[0], FileGroupTask) assert str(result[0].data[0]).endswith("file1.jsonl") - assert result[0].task_id == "file_group_0" assert result[0]._metadata["partition_index"] == 0 assert result[0]._metadata["total_partitions"] == 3 diff --git a/tests/stages/common/test_file_partitioning.py b/tests/stages/common/test_file_partitioning.py index 46ac96c34d..d67de158c9 100644 --- a/tests/stages/common/test_file_partitioning.py +++ b/tests/stages/common/test_file_partitioning.py @@ -51,7 +51,6 @@ def temp_files(self, tmp_path: Path) -> list[str]: def empty_task(self) -> _EmptyTask: """Create an empty task for testing.""" return _EmptyTask( - task_id="test_task", dataset_name="test_dataset", data=None, _metadata={"source": "test"}, @@ -135,7 +134,6 @@ def test_process_with_file_list(self, empty_task: _EmptyTask, tmp_path: Path): assert 
isinstance(result[0], FileGroupTask) assert result[0].data == [test_files[0]] assert result[0].dataset_name == "path" - assert result[0].task_id == "file_group_0" def test_process_with_files_per_partition(self, empty_task: _EmptyTask, tmp_path: Path): """Test processing with files_per_partition setting.""" @@ -167,7 +165,6 @@ def test_process_with_limit(self, empty_task: _EmptyTask, tmp_path: Path): # Verify metadata for i, task in enumerate(result): - assert task.task_id == f"file_group_{i}" assert task._metadata["partition_index"] == i assert task._metadata["total_partitions"] == 5 # Total partitions before limit diff --git a/tests/stages/common/test_function_decorators.py b/tests/stages/common/test_function_decorators.py index b7bdb5fa49..5c227e2e4b 100644 --- a/tests/stages/common/test_function_decorators.py +++ b/tests/stages/common/test_function_decorators.py @@ -24,7 +24,7 @@ class MockTask(Task[int]): """Simple Task subclass for testing the decorator.""" def __init__(self, value: int = 0): - super().__init__(task_id="mock", dataset_name="test", data=value) + super().__init__(dataset_name="test", data=value) @property def num_items(self) -> int: diff --git a/tests/stages/deduplication/exact/test_identification.py b/tests/stages/deduplication/exact/test_identification.py index 5fa80faa58..ed7e3a8d75 100644 --- a/tests/stages/deduplication/exact/test_identification.py +++ b/tests/stages/deduplication/exact/test_identification.py @@ -72,7 +72,6 @@ def exact_dedup_data_parquet(tmp_path: Path) -> list[FileGroupTask]: return [ FileGroupTask( - task_id="exact_dedup_0", dataset_name="exact_dedup_dataset", data=[str(file1)], _metadata={ @@ -82,7 +81,6 @@ def exact_dedup_data_parquet(tmp_path: Path) -> list[FileGroupTask]: }, ), FileGroupTask( - task_id="exact_dedup_1", dataset_name="exact_dedup_dataset", data=[str(file2)], _metadata={ @@ -103,7 +101,6 @@ def exact_no_dedup_data_jsonl(tmp_path: Path) -> list[FileGroupTask]: return [ FileGroupTask( - 
task_id="no_dedup_0", dataset_name="no_dedup_dataset", data=[str(file1)], _metadata={ diff --git a/tests/stages/deduplication/exact/test_workflow.py b/tests/stages/deduplication/exact/test_workflow.py index 3beddbd841..d1ecfd7d2e 100644 --- a/tests/stages/deduplication/exact/test_workflow.py +++ b/tests/stages/deduplication/exact/test_workflow.py @@ -78,7 +78,6 @@ def exact_dedup_data_parquet(tmp_path: Path) -> list[FileGroupTask]: return [ FileGroupTask( - task_id="exact_dedup_0", dataset_name="exact_dedup_dataset", data=[str(file1)], _metadata={ @@ -88,7 +87,6 @@ def exact_dedup_data_parquet(tmp_path: Path) -> list[FileGroupTask]: }, ), FileGroupTask( - task_id="exact_dedup_1", dataset_name="exact_dedup_dataset", data=[str(file2)], _metadata={ @@ -109,7 +107,6 @@ def exact_no_dedup_data_jsonl(tmp_path: Path) -> list[FileGroupTask]: return [ FileGroupTask( - task_id="no_dedup_0", dataset_name="no_dedup_dataset", data=[str(file1)], _metadata={ diff --git a/tests/stages/deduplication/fuzzy/test_buckets_to_edges_stage.py b/tests/stages/deduplication/fuzzy/test_buckets_to_edges_stage.py index 1907d1ae56..74a3467f7b 100644 --- a/tests/stages/deduplication/fuzzy/test_buckets_to_edges_stage.py +++ b/tests/stages/deduplication/fuzzy/test_buckets_to_edges_stage.py @@ -81,7 +81,6 @@ def sample_files(tmp_path: Path, sample_bucket_data: tuple[pd.DataFrame, pd.Data def input_task(sample_files: list[str]) -> FileGroupTask: """Create a FileGroupTask from sample files.""" return FileGroupTask( - task_id="test_task", dataset_name="test_buckets", data=sample_files, _metadata={"batch_id": 0, "total_batches": 1}, @@ -163,7 +162,6 @@ def test_custom_column_name(self, tmp_path: Path) -> None: pq.write_table(table, file) input_task = FileGroupTask( - task_id="test_task_custom", dataset_name="test_buckets_custom", data=[str(file)], _metadata={"batch_id": 0, "total_batches": 1}, @@ -206,7 +204,6 @@ def test_empty_input_handling(self, tmp_path: Path) -> None: pq.write_table(table, 
input_file) input_task = FileGroupTask( - task_id="empty_test", dataset_name="empty_buckets", data=[str(input_file)], _metadata={}, @@ -242,7 +239,6 @@ def test_single_document_buckets(self, tmp_path: Path) -> None: pq.write_table(table, input_file) input_task = FileGroupTask( - task_id="single_doc_test", dataset_name="single_doc_buckets", data=[str(input_file)], _metadata={}, @@ -275,7 +271,6 @@ def test_large_buckets(self, tmp_path: Path) -> None: pq.write_table(table, input_file) input_task = FileGroupTask( - task_id="large_bucket_test", dataset_name="large_buckets", data=[str(input_file)], _metadata={}, diff --git a/tests/stages/deduplication/fuzzy/test_connected_components_stage.py b/tests/stages/deduplication/fuzzy/test_connected_components_stage.py index d7f48b4996..50623830a4 100644 --- a/tests/stages/deduplication/fuzzy/test_connected_components_stage.py +++ b/tests/stages/deduplication/fuzzy/test_connected_components_stage.py @@ -76,12 +76,10 @@ def input_tasks(sample_files: list[str]) -> list[FileGroupTask]: return [ FileGroupTask( dataset_name="test_edges", - task_id="edge_group_0", data=[sample_files[0]], ), FileGroupTask( dataset_name="test_edges", - task_id="edge_group_1", data=[sample_files[1]], ), ] diff --git a/tests/stages/deduplication/fuzzy/test_fuzzy_workflow.py b/tests/stages/deduplication/fuzzy/test_fuzzy_workflow.py index 57ab855402..ea52929ed6 100644 --- a/tests/stages/deduplication/fuzzy/test_fuzzy_workflow.py +++ b/tests/stages/deduplication/fuzzy/test_fuzzy_workflow.py @@ -106,11 +106,7 @@ def fuzzy_dedup_data_jsonl(tmp_path: Path) -> list[FileGroupTask]: df.iloc[3:].to_json(file2, orient="records", lines=True) files = [str(file1), str(file2)] - return [ - FileGroupTask( - task_id="file_group_0", dataset_name="test_dataset", data=files, _metadata={"source_files": files} - ) - ] + return [FileGroupTask(dataset_name="test_dataset", data=files, _metadata={"source_files": files})] @pytest.fixture @@ -125,11 +121,7 @@ def 
fuzzy_dedup_data_parquet(tmp_path: Path) -> list[FileGroupTask]: df.iloc[3:].to_parquet(file2) files = [str(file1), str(file2)] - return [ - FileGroupTask( - task_id="file_group_0", dataset_name="test_dataset", data=files, _metadata={"source_files": files} - ) - ] + return [FileGroupTask(dataset_name="test_dataset", data=files, _metadata={"source_files": files})] @pytest.fixture @@ -154,11 +146,7 @@ def no_duplicates_fuzzy_dedup_data(tmp_path: Path) -> list[FileGroupTask]: df.iloc[2:].to_parquet(file2) files = [str(file1), str(file2)] - return [ - FileGroupTask( - task_id="file_group_0", dataset_name="test_dataset", data=files, _metadata={"source_files": files} - ) - ] + return [FileGroupTask(dataset_name="test_dataset", data=files, _metadata={"source_files": files})] @pytest.mark.gpu diff --git a/tests/stages/deduplication/fuzzy/test_lsh_stage.py b/tests/stages/deduplication/fuzzy/test_lsh_stage.py index 33344ec8d7..a2745442c2 100644 --- a/tests/stages/deduplication/fuzzy/test_lsh_stage.py +++ b/tests/stages/deduplication/fuzzy/test_lsh_stage.py @@ -59,7 +59,6 @@ def minhash_data(self, tmp_path: Path) -> FileGroupTask: # Create FileGroupTask return FileGroupTask( - task_id="test_minhash_0", dataset_name="test_dataset", data=[minhash_file], _metadata={ @@ -193,7 +192,6 @@ def test_custom_column_names( # Create FileGroupTask minhash_data = FileGroupTask( - task_id="test_custom_cols_0", dataset_name="test_dataset", data=[minhash_file], _metadata={ @@ -262,7 +260,6 @@ def test_no_duplicates( # Create FileGroupTask minhash_data = FileGroupTask( - task_id="test_minhash_no_dup_0", dataset_name="test_dataset", data=[minhash_file], _metadata={ @@ -321,7 +318,6 @@ def test_partial_overlap( # Create FileGroupTask minhash_data = FileGroupTask( - task_id="test_minhash_partial_0", dataset_name="test_dataset", data=[minhash_file], _metadata={ diff --git a/tests/stages/deduplication/fuzzy/test_minhash_stage.py b/tests/stages/deduplication/fuzzy/test_minhash_stage.py index 
baedf5e48a..be67a32894 100644 --- a/tests/stages/deduplication/fuzzy/test_minhash_stage.py +++ b/tests/stages/deduplication/fuzzy/test_minhash_stage.py @@ -106,7 +106,6 @@ def input_task(sample_files: tuple[list[str], str]) -> FileGroupTask: """Create a FileGroupTask from sample files.""" files, format_type = sample_files return FileGroupTask( - task_id=f"test_task_{format_type}", dataset_name="test_dataset", data=files, _metadata={"batch_id": 0, "total_batches": 1, "format": format_type}, @@ -215,9 +214,7 @@ def test_error_handling_missing_column(self, tmp_path: Path) -> None: input_file = tmp_path / "bad_schema.jsonl" data.to_json(input_file, orient="records", lines=True) - input_task = FileGroupTask( - task_id="bad_test", dataset_name="bad_dataset", data=[str(input_file)], _metadata={} - ) + input_task = FileGroupTask(dataset_name="bad_dataset", data=[str(input_file)], _metadata={}) stage = MinHashStage( output_path=str(tmp_path / "output"), @@ -240,9 +237,7 @@ def test_empty_input_handling(self, tmp_path: Path) -> None: input_file = tmp_path / "empty.jsonl" data.to_json(input_file, orient="records", lines=True) - input_task = FileGroupTask( - task_id="empty_test", dataset_name="empty_dataset", data=[str(input_file)], _metadata={} - ) + input_task = FileGroupTask(dataset_name="empty_dataset", data=[str(input_file)], _metadata={}) stage = MinHashStage( output_path=str(tmp_path / "output"), @@ -261,9 +256,7 @@ def test_process_without_setup(self, tmp_path: Path) -> None: text_field="text", ) - input_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", data=["dummy.jsonl"], _metadata={} - ) + input_task = FileGroupTask(dataset_name="test_dataset", data=["dummy.jsonl"], _metadata={}) # Should raise error because setup wasn't called with pytest.raises(RuntimeError, match="MinHash processor or ID generator not initialized"): @@ -287,9 +280,7 @@ def test_large_text_handling(self, tmp_path: Path) -> None: input_file = tmp_path / "large_texts.jsonl" 
data.to_json(input_file, orient="records", lines=True) - input_task = FileGroupTask( - task_id="large_test", dataset_name="large_dataset", data=[str(input_file)], _metadata={} - ) + input_task = FileGroupTask(dataset_name="large_dataset", data=[str(input_file)], _metadata={}) stage = MinHashStage( output_path=str(tmp_path / "output"), @@ -329,9 +320,7 @@ def test_special_characters_and_unicode(self, tmp_path: Path) -> None: input_file = tmp_path / "special_chars.jsonl" data.to_json(input_file, orient="records", lines=True) - input_task = FileGroupTask( - task_id="special_test", dataset_name="special_dataset", data=[str(input_file)], _metadata={} - ) + input_task = FileGroupTask(dataset_name="special_dataset", data=[str(input_file)], _metadata={}) stage = MinHashStage( output_path=str(tmp_path / "output"), @@ -376,12 +365,8 @@ def test_setup_idempotency(self, tmp_path: Path) -> None: data = pd.DataFrame({"text": ["Document 1", "Document 2", "Document 3"]}) data.to_json(input_file1, orient="records", lines=True) data.to_json(input_file2, orient="records", lines=True) - input_task1 = FileGroupTask( - task_id="setup_test_1", dataset_name="setup_dataset_1", data=[str(input_file1)], _metadata={} - ) - input_task2 = FileGroupTask( - task_id="setup_test_2", dataset_name="setup_dataset_2", data=[str(input_file2)], _metadata={} - ) + input_task1 = FileGroupTask(dataset_name="setup_dataset_1", data=[str(input_file1)], _metadata={}) + input_task2 = FileGroupTask(dataset_name="setup_dataset_2", data=[str(input_file2)], _metadata={}) # Setup and process first batch stage1.setup() diff --git a/tests/stages/deduplication/semantic/test_identify_duplicates.py b/tests/stages/deduplication/semantic/test_identify_duplicates.py index ab8f91e212..f01fba38a8 100644 --- a/tests/stages/deduplication/semantic/test_identify_duplicates.py +++ b/tests/stages/deduplication/semantic/test_identify_duplicates.py @@ -83,7 +83,7 @@ def test_identify_duplicates_various_cases(self, tmp_path: Path) -> 
None: } self.create_test_similarity_file(str(single_file), single_data) - task_single = FileGroupTask(task_id="single", dataset_name="test", data=[str(single_file)]) + task_single = FileGroupTask(dataset_name="test", data=[str(single_file)]) result_single = stage.process_batch([task_single]) assert len(result_single) == 1 result_df = pd.read_parquet(result_single[0].data[0]) @@ -98,7 +98,7 @@ def test_identify_duplicates_various_cases(self, tmp_path: Path) -> None: } self.create_test_similarity_file(str(no_similar_file), no_similar_data) - task_no_similar = FileGroupTask(task_id="no_similar", dataset_name="test", data=[str(no_similar_file)]) + task_no_similar = FileGroupTask(dataset_name="test", data=[str(no_similar_file)]) result_no_similar = stage.process_batch([task_no_similar]) assert len(result_no_similar) == 1 result_df = pd.read_parquet(result_no_similar[0].data[0]) @@ -115,7 +115,7 @@ def test_identify_duplicates_various_cases(self, tmp_path: Path) -> None: # Strict epsilon (0.01) - threshold = 0.99, should get 0 results stage_strict = IdentifyDuplicatesStage(output_path=str(output_dir), eps=0.01, verbose=True) - task_eps = FileGroupTask(task_id="eps_test", dataset_name="test", data=[str(eps_test_file)]) + task_eps = FileGroupTask(dataset_name="test", data=[str(eps_test_file)]) result_strict = stage_strict.process_batch([task_eps]) result_df = pd.read_parquet(result_strict[0].data[0]) assert len(result_df) == 0, "Strict epsilon should return 0 results" @@ -140,7 +140,7 @@ def test_identify_duplicates_various_cases(self, tmp_path: Path) -> None: } self.create_test_similarity_file(str(basic_file), basic_data) - task_basic = FileGroupTask(task_id="basic", dataset_name="test", data=[str(basic_file)]) + task_basic = FileGroupTask(dataset_name="test", data=[str(basic_file)]) result_basic = stage.process_batch([task_basic]) assert len(result_basic) == 1 output_file = result_basic[0].data[0] @@ -178,9 +178,8 @@ def 
test_identify_duplicates_stage_batch_processing(self, tmp_path: Path) -> Non # Create tasks for each cluster tasks = [] - for i, file_path in enumerate(cluster_files): + for _i, file_path in enumerate(cluster_files): task = FileGroupTask( - task_id=f"test_batch_{i}", dataset_name="test", data=[file_path], ) @@ -221,7 +220,6 @@ def test_identify_duplicates_stage_custom_row_groups(self, tmp_path: Path) -> No ) task = FileGroupTask( - task_id="test_custom_row_groups", dataset_name="test", data=[str(input_file)], ) diff --git a/tests/stages/deduplication/semantic/test_kmeans.py b/tests/stages/deduplication/semantic/test_kmeans.py index 8bc9573fb2..2af3667b64 100644 --- a/tests/stages/deduplication/semantic/test_kmeans.py +++ b/tests/stages/deduplication/semantic/test_kmeans.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re from contextlib import suppress from pathlib import Path from typing import Literal @@ -241,39 +242,38 @@ def test_output_columns(self) -> None: assert cosine_dtype == np.float32, f"Cosine distance should be float, got {cosine_dtype}" def test_output_filenames_and_structure(self) -> None: - """Test that the output files are created with exact expected filenames and partitioning. - - Each actor (we should have two GPU actors) writes files with predictable names: {tasks[0]._uuid}_{subgroup_index}.parquet - Since our test data is small, each actor creates 1 subgroup, so files are named {uuid}_0.parquet + """Output files are written with deterministic, input-derived names and + partitioned by centroid. + + Each GPU actor writes ``{input_task_id}_{subgroup}.parquet`` where the + input task id is the FilePartitioning id (``0_``). + We assert the names match that deterministic pattern (never a random + ``r`` fallback) and that the centroid partitioning is correct. 
+ + Note: the pipeline's result tasks are terminal ``_EmptyTask`` signals + whose ids are framework-assigned (and, for this aggregating stage, the + non-deterministic ``r`` fallback) — they are intentionally NOT + tied to the output filenames, which are derived from the input ids. """ - # Get the expected filenames from pipeline results - # The pipeline returns EmptyTasks with task_id = output_filename = f"{tasks[0]._uuid}_{i}" - expected_filenames = set() - for result_task in self.pipeline_results: - expected_filename = f"{result_task.task_id}.parquet" - expected_filenames.add(expected_filename) - - # Should have exactly 2 result tasks (one per actor) - assert len(expected_filenames) == 2, f"Expected 2 result tasks/filenames, got {len(expected_filenames)}" - - # Collect all actual filenames across all partitions - actual_filenames = set() - centroid_dirs = list(self.output_dir.glob("centroid=*")) + # One terminal result task per actor. + assert len(self.pipeline_results) == 2, f"Expected 2 result tasks, got {len(self.pipeline_results)}" - # Collect filenames from all centroid partitions - for centroid_dir in centroid_dirs: - partition_files = list(centroid_dir.glob("*.parquet")) - for file in partition_files: - actual_filenames.add(file.name) - - # Verify that all expected filenames are present - assert actual_filenames == expected_filenames, ( - f"Expected filenames {expected_filenames}, but found {actual_filenames}. " - f"Missing: {expected_filenames - actual_filenames}, " - f"Extra: {actual_filenames - expected_filenames}" - ) + # Collect all output filenames across centroid partitions. (The same + # file name appears under each centroid=* dir, so dedupe into a set.) + centroid_dirs = list(self.output_dir.glob("centroid=*")) + actual_filenames = {f.name for d in centroid_dirs for f in d.glob("*.parquet")} + + # Two distinct output files (one per actor), each deterministically + # named from its input partition's id: "0__". 
+ assert len(actual_filenames) == 2, f"Expected 2 distinct output files, got {actual_filenames}" + deterministic_name = re.compile(r"^0_[0-9a-f]+_\d+\.parquet$") + for name in actual_filenames: + assert deterministic_name.match(name), ( + f"Output filename {name!r} is not deterministic/input-derived " + f"(an 'r' name would mean ancestry was lost)" + ) - # Verify we have the expected number of centroid partitions (should be exactly N_CLUSTERS) + # Exactly N_CLUSTERS centroid partitions. assert len(centroid_dirs) == N_CLUSTERS, ( f"Expected exactly {N_CLUSTERS} centroid partitions, got {len(centroid_dirs)}" ) @@ -382,7 +382,6 @@ def test_process_batch_read_paths( all_files = [str(input_dir / f"file_{i}.parquet") for i in range(4)] all_tasks = [ FileGroupTask( - task_id=f"test_task_{i}", dataset_name="test_dataset", data=[file], ) @@ -401,7 +400,6 @@ def test_process_batch_read_paths( all_files = [str(input_file)] all_tasks = [ FileGroupTask( - task_id="test_task_jsonl", dataset_name="test_dataset", data=[str(input_file)], ) @@ -538,7 +536,7 @@ def test_fit_data_fraction_validation(self, tmp_path: Path, bad_fraction: float) def test_process_batch_routes_by_fit_data_fraction(self, make_stage: "KMeansReadFitWriteStage") -> None: """fit_data_fraction=None -> single-pass; fraction set -> two-pass.""" # Use jsonl so process_batch skips break_parquet_partition_into_groups (which reads metadata). 
- task = FileGroupTask(task_id="t", dataset_name="d", data=["x.jsonl"]) + task = FileGroupTask(dataset_name="d", data=["x.jsonl"]) for fraction, expect_single in [(None, True), (0.5, False)]: stage = make_stage(fit_data_fraction=fraction, filetype="jsonl") @@ -628,7 +626,7 @@ def test_predict_write_pass_reads_every_group(self, make_stage: "KMeansReadFitWr ] df = cudf.DataFrame({"id": [0, 1], "embeddings": [[1.0, 0.0], [0.0, 1.0]]}) stage.kmeans.predict = Mock(return_value=cp.zeros(len(df), dtype=cp.int32)) - tasks = [FileGroupTask(task_id="t0", dataset_name="d", data=["any.parquet"])] + tasks = [FileGroupTask(dataset_name="d", data=["any.parquet"])] with ( patch.object(stage, "_read_group", return_value=df) as mock_read, patch.object(stage, "write_parquet"), @@ -687,7 +685,7 @@ def test_single_pass_saves_centroids(self, tmp_path: Path, make_stage: "KMeansRe stage = make_stage(fit_data_fraction=None, cache_path=str(cache_path)) df = cudf.DataFrame({"id": [0, 1], "embeddings": [[1.0, 0.0], [0.0, 1.0]]}) stage.kmeans.predict = Mock(return_value=cp.zeros(len(df), dtype=cp.int32)) - tasks = [FileGroupTask(task_id="t", dataset_name="d", data=["any.parquet"])] + tasks = [FileGroupTask(dataset_name="d", data=["any.parquet"])] with ( patch.object(stage, "_read_group", return_value=df), patch.object(stage, "write_parquet"), diff --git a/tests/stages/deduplication/semantic/test_pairwise.py b/tests/stages/deduplication/semantic/test_pairwise.py index e99a409749..96acdcafc2 100644 --- a/tests/stages/deduplication/semantic/test_pairwise.py +++ b/tests/stages/deduplication/semantic/test_pairwise.py @@ -120,7 +120,6 @@ def test_single_item_cluster(self, tmp_path: Path) -> None: # Create task task = FileGroupTask( - task_id="test_single", dataset_name="test", data=[str(input_file)], _metadata={"centroid_id": 0, "filetype": "parquet"}, @@ -188,7 +187,6 @@ def test_multi_item_cluster(self, mock_break_into_groups: patch, tmp_path: Path) # Create task task = FileGroupTask( - 
task_id="test_multi", dataset_name="test", data=[str(input_file)], _metadata={"centroid_id": 1, "filetype": "parquet"}, @@ -265,7 +263,6 @@ def test_pairwise_stage_with_custom_metadata_ranking(self, tmp_path: Path) -> No # Create task task = FileGroupTask( - task_id="test_custom_ranked", dataset_name="test", data=[str(input_file)], _metadata={"centroid_id": 1, "filetype": "parquet"}, @@ -327,7 +324,6 @@ def test_pairwise_stage_ranking_fails_on_missing_columns(self, tmp_path: Path) - # Create task task = FileGroupTask( - task_id="test_fail_missing_cols", dataset_name="test", data=[str(input_file)], _metadata={"centroid_id": 2, "filetype": "parquet"}, @@ -506,7 +502,6 @@ def _run_pairwise_stage_test(self, tmp_path: Path, ranking_kwargs: dict) -> tupl # Create task task = FileGroupTask( - task_id="test_workflow", dataset_name="test", data=[str(input_file) for input_file in input_files], _metadata={"centroid_id": 0, "filetype": "parquet"}, diff --git a/tests/stages/deduplication/semantic/test_pairwise_io.py b/tests/stages/deduplication/semantic/test_pairwise_io.py index e8d16b1d4d..ee1e86af4e 100644 --- a/tests/stages/deduplication/semantic/test_pairwise_io.py +++ b/tests/stages/deduplication/semantic/test_pairwise_io.py @@ -80,7 +80,7 @@ def test_process_finds_all_centroid_files(self, tmp_path: Path): mock_path_normalizer = Mock(side_effect=lambda x: x) stage.path_normalizer = mock_path_normalizer - empty_task = _EmptyTask(task_id="test", dataset_name="test", data=None) + empty_task = _EmptyTask(dataset_name="test", data=None) result = stage.process(empty_task) # Verify path_normalizer was called exactly 3 times (once per centroid directory) @@ -99,14 +99,11 @@ def test_process_finds_all_centroid_files(self, tmp_path: Path): result.sort(key=lambda x: x._metadata["centroid_id"]) # Check each task - assert result[0].task_id == "pairwise_centroid_0" assert result[0]._metadata == {"centroid_id": 0, "filetype": "parquet"} assert result[0].data == [str(centroid_0_dir / 
"file1.parquet"), str(centroid_0_dir / "file2.parquet")] - assert result[1].task_id == "pairwise_centroid_1" assert result[1]._metadata == {"centroid_id": 1, "filetype": "parquet"} assert result[1].data == [str(centroid_1_dir / "file3.parquet")] - assert result[2].task_id == "pairwise_centroid_2" assert result[2]._metadata == {"centroid_id": 2, "filetype": "parquet"} assert result[2].data == [str(centroid_2_dir / "file4.parquet"), str(centroid_2_dir / "file5.parquet")] diff --git a/tests/stages/deduplication/shuffle_utils/test_shuffle_stage.py b/tests/stages/deduplication/shuffle_utils/test_shuffle_stage.py index 74bae750b3..ce833c93d9 100644 --- a/tests/stages/deduplication/shuffle_utils/test_shuffle_stage.py +++ b/tests/stages/deduplication/shuffle_utils/test_shuffle_stage.py @@ -74,7 +74,6 @@ def test_data(self, tmp_path: Path) -> list[FileGroupTask]: tasks.append( FileGroupTask( - task_id=f"test_data_{i}", dataset_name="test_dataset", data=[test_file], _metadata={ diff --git a/tests/stages/image/dedup/test_dedup_filter.py b/tests/stages/image/dedup/test_dedup_filter.py index c451d98550..8e872bef97 100644 --- a/tests/stages/image/dedup/test_dedup_filter.py +++ b/tests/stages/image/dedup/test_dedup_filter.py @@ -31,7 +31,7 @@ def _write_parquet_ids(tmpdir: Path, filename: str, ids: list[str], id_column: s def _make_batch(ids: list[str]) -> ImageBatch: images = [ImageObject(image_id=i) for i in ids] - return ImageBatch(task_id="t0", dataset_name="ds", data=images) + return ImageBatch(dataset_name="ds", data=images) def test_setup_raises_when_no_parquet(tmp_path: Path) -> None: @@ -50,7 +50,6 @@ def test_filters_with_default_id_column(tmp_path: Path) -> None: kept_ids = [img.image_id for img in out.data] assert kept_ids == ["img1", "img3"] - assert out.task_id.endswith(stage._name) assert out.dataset_name == batch.dataset_name diff --git a/tests/stages/image/embedders/test_clip_embedder.py b/tests/stages/image/embedders/test_clip_embedder.py index 
47b24de062..e6fdc98f6e 100644 --- a/tests/stages/image/embedders/test_clip_embedder.py +++ b/tests/stages/image/embedders/test_clip_embedder.py @@ -30,11 +30,7 @@ class TestImageEmbeddingStage: @pytest.fixture def stage(self) -> ImageEmbeddingStage: """Create a test stage instance.""" - return ImageEmbeddingStage( - model_dir="test_models/clip", - model_inference_batch_size=2, - verbose=True - ) + return ImageEmbeddingStage(model_dir="test_models/clip", model_inference_batch_size=2, verbose=True) @pytest.fixture def mock_model(self) -> Mock: @@ -53,34 +49,30 @@ def sample_image_objects(self) -> list[ImageObject]: ImageObject( image_id="img_001", image_path="/path/to/img1.jpg", - image_data=rng.integers(0, 255, (224, 224, 3), dtype=np.uint8) + image_data=rng.integers(0, 255, (224, 224, 3), dtype=np.uint8), ), ImageObject( image_id="img_002", image_path="/path/to/img2.jpg", - image_data=rng.integers(0, 255, (224, 224, 3), dtype=np.uint8) + image_data=rng.integers(0, 255, (224, 224, 3), dtype=np.uint8), ), ImageObject( image_id="img_003", image_path="/path/to/img3.jpg", - image_data=rng.integers(0, 255, (224, 224, 3), dtype=np.uint8) + image_data=rng.integers(0, 255, (224, 224, 3), dtype=np.uint8), ), ImageObject( image_id="img_004", image_path="/path/to/img4.jpg", - image_data=rng.integers(0, 255, (224, 224, 3), dtype=np.uint8) - ) + image_data=rng.integers(0, 255, (224, 224, 3), dtype=np.uint8), + ), ] @pytest.fixture def sample_image_batch(self, sample_image_objects: list[ImageObject]) -> ImageBatch: """Create a sample ImageBatch.""" return ImageBatch( - data=sample_image_objects, - dataset_name="test_dataset", - task_id="test_task_001", - _metadata={"test": "metadata"}, - _stage_perf={} + data=sample_image_objects, dataset_name="test_dataset", _metadata={"test": "metadata"}, _stage_perf={} ) def test_stage_properties(self, stage: ImageEmbeddingStage) -> None: @@ -110,9 +102,8 @@ def test_setup(self, mock_clip_embeddings: Mock, stage: ImageEmbeddingStage) -> 
mock_clip_embeddings.assert_called_once() call_args, call_kwargs = mock_clip_embeddings.call_args - assert ( - (len(call_args) >= 1 and call_args[0] == "test_models/clip") - or (call_kwargs.get("model_dir") == "test_models/clip") + assert (len(call_args) >= 1 and call_args[0] == "test_models/clip") or ( + call_kwargs.get("model_dir") == "test_models/clip" ) mock_model.setup.assert_called_once() assert stage.model == mock_model @@ -183,13 +174,15 @@ def test_batch_processing( rng = np.random.default_rng(42) images = [] for i in range(5): - images.append(ImageObject( - image_id=f"img_{i:03d}", - image_path=f"/path/to/img{i}.jpg", - image_data=rng.integers(0, 255, (224, 224, 3), dtype=np.uint8) - )) + images.append( + ImageObject( + image_id=f"img_{i:03d}", + image_path=f"/path/to/img{i}.jpg", + image_data=rng.integers(0, 255, (224, 224, 3), dtype=np.uint8), + ) + ) - batch = ImageBatch(data=images, task_id="test_batch", dataset_name="test_dataset") + batch = ImageBatch(data=images, dataset_name="test_dataset") # Mock processor to return appropriate tensor sizes # The processor returns the same structure regardless of input size @@ -211,7 +204,7 @@ def test_batch_processing( @patch("nemo_curator.stages.image.embedders.clip_embedder.CLIPImageEmbeddings") def test_empty_batch(self, mock_clip_embeddings: Mock, stage: ImageEmbeddingStage) -> None: """Test processing empty image batch.""" - empty_batch = ImageBatch(data=[], task_id="empty_test", dataset_name="test_dataset") + empty_batch = ImageBatch(data=[], dataset_name="test_dataset") mock_clip_embeddings.return_value = Mock() stage.setup() @@ -241,14 +234,17 @@ def test_verbose_logging( stage.process(sample_image_batch) # Should log embedding generation - embedding_calls = [call for call in mock_logger.info.call_args_list - if "Generated embeddings for" in str(call)] + embedding_calls = [call for call in mock_logger.info.call_args_list if "Generated embeddings for" in str(call)] assert len(embedding_calls) > 0 
@patch("nemo_curator.stages.image.embedders.clip_embedder.CLIPImageEmbeddings") @patch("transformers.CLIPProcessor.from_pretrained") def test_preserves_other_image_attributes( - self, mock_processor: Mock, mock_clip_embeddings: Mock, stage: ImageEmbeddingStage, sample_image_batch: ImageBatch + self, + mock_processor: Mock, + mock_clip_embeddings: Mock, + stage: ImageEmbeddingStage, + sample_image_batch: ImageBatch, ) -> None: """Test that processing preserves other image attributes.""" mock_clip_embeddings.return_value = Mock() @@ -332,9 +328,8 @@ def test_processor_integration( # Verify the model was instantiated and setup was called mock_clip_embeddings.assert_called_once() call_args, call_kwargs = mock_clip_embeddings.call_args - assert ( - (len(call_args) >= 1 and call_args[0] == "test_models/clip") - or (call_kwargs.get("model_dir") == "test_models/clip") + assert (len(call_args) >= 1 and call_args[0] == "test_models/clip") or ( + call_kwargs.get("model_dir") == "test_models/clip" ) mock_model_instance.setup.assert_called_once() @@ -364,16 +359,16 @@ def test_embedding_shape_consistency(self, stage: ImageEmbeddingStage) -> None: ImageObject( image_id="small_img", image_path="/path/to/small.jpg", - image_data=rng.integers(0, 255, (100, 100, 3), dtype=np.uint8) + image_data=rng.integers(0, 255, (100, 100, 3), dtype=np.uint8), ), ImageObject( image_id="large_img", image_path="/path/to/large.jpg", - image_data=rng.integers(0, 255, (500, 500, 3), dtype=np.uint8) - ) + image_data=rng.integers(0, 255, (500, 500, 3), dtype=np.uint8), + ), ] - batch = ImageBatch(data=different_sized_images, task_id="shape_test", dataset_name="test_dataset") + batch = ImageBatch(data=different_sized_images, dataset_name="test_dataset") # Mock consistent outputs regardless of input size mock_processor_instance.return_value = {"pixel_values": torch.randn(2, 3, 224, 224)} @@ -421,7 +416,7 @@ def __call__(self, batch_numpy: np.ndarray | list[np.ndarray]) -> torch.Tensor: ) for i in range(4) 
] - batch = ImageBatch(data=images, dataset_name="ds", task_id="t0") + batch = ImageBatch(data=images, dataset_name="ds") stage = ImageEmbeddingStage(model_dir="/does/not/matter", model_inference_batch_size=2, verbose=False) with patch( diff --git a/tests/stages/image/filters/test_aesthetic_filter.py b/tests/stages/image/filters/test_aesthetic_filter.py index 3c39d1cfc8..88d9c1cb0c 100644 --- a/tests/stages/image/filters/test_aesthetic_filter.py +++ b/tests/stages/image/filters/test_aesthetic_filter.py @@ -80,7 +80,6 @@ def sample_image_batch(self, sample_image_objects: list[ImageObject]) -> ImageBa return ImageBatch( data=sample_image_objects, dataset_name="test_dataset", - task_id="test_task_001", _metadata={"test": "metadata"}, _stage_perf={}, ) @@ -144,9 +143,6 @@ def test_process_filtering( elif img.image_id == "img_004": assert abs(img.aesthetic_score - 0.8) < 1e-5 - # Check that the task has updated ID - assert result.task_id == f"{sample_image_batch.task_id}_{stage.name}" - @patch("nemo_curator.stages.image.filters.aesthetic_filter.AestheticScorer") def test_threshold_variations( self, mock_aesthetic_scorer: Mock, sample_image_batch: ImageBatch, mock_model: Mock @@ -204,7 +200,7 @@ def test_verbose_logging( @patch("nemo_curator.stages.image.filters.aesthetic_filter.AestheticScorer") def test_empty_batch(self, mock_aesthetic_scorer: Mock, stage: ImageAestheticFilterStage) -> None: """Test processing empty image batch.""" - empty_batch = ImageBatch(data=[], task_id="empty_test", dataset_name="test_dataset") + empty_batch = ImageBatch(data=[], dataset_name="test_dataset") mock_aesthetic_scorer.return_value = Mock() stage.setup() @@ -229,7 +225,7 @@ def test_no_embeddings(self, mock_aesthetic_scorer: Mock) -> None: ) ] - batch = ImageBatch(data=images_no_embeddings, task_id="no_embed_test", dataset_name="test_dataset") + batch = ImageBatch(data=images_no_embeddings, dataset_name="test_dataset") mock_aesthetic_scorer.return_value = Mock() stage = 
ImageAestheticFilterStage( @@ -402,9 +398,7 @@ def test_large_batch_processing(self, sample_image_batch: ImageBatch) -> None: """Test processing with many images.""" # Create a larger batch by replicating existing images large_data = sample_image_batch.data * 25 # 100 images - large_batch = ImageBatch( - data=large_data, dataset_name=sample_image_batch.dataset_name, task_id=sample_image_batch.task_id - ) + large_batch = ImageBatch(data=large_data, dataset_name=sample_image_batch.dataset_name) stage = ImageAestheticFilterStage( model_dir="test_models/aesthetics", score_threshold=0.5, model_inference_batch_size=10 @@ -516,7 +510,7 @@ def __call__(self, embeddings_numpy: np.ndarray) -> torch.Tensor: ) for i in range(6) ] - batch = ImageBatch(data=images, dataset_name="ds", task_id="t0") + batch = ImageBatch(data=images, dataset_name="ds") stage = ImageAestheticFilterStage(model_dir="/unused", model_inference_batch_size=3, score_threshold=0.3) diff --git a/tests/stages/image/filters/test_nsfw_filter.py b/tests/stages/image/filters/test_nsfw_filter.py index a99c3da364..41d4a9960f 100644 --- a/tests/stages/image/filters/test_nsfw_filter.py +++ b/tests/stages/image/filters/test_nsfw_filter.py @@ -30,11 +30,7 @@ class TestImageNSFWFilterStage: @pytest.fixture def stage(self) -> ImageNSFWFilterStage: """Create a test stage instance.""" - return ImageNSFWFilterStage( - model_dir="test_models/nsfw", - score_threshold=0.5, - model_inference_batch_size=2 - ) + return ImageNSFWFilterStage(model_dir="test_models/nsfw", score_threshold=0.5, model_inference_batch_size=2) @pytest.fixture def mock_model(self) -> Mock: @@ -80,11 +76,7 @@ def sample_image_objects(self) -> list[ImageObject]: def sample_image_batch(self, sample_image_objects: list[ImageObject]) -> ImageBatch: """Create a sample ImageBatch.""" return ImageBatch( - data=sample_image_objects, - dataset_name="test_dataset", - task_id="test_task_001", - _metadata={"test": "metadata"}, - _stage_perf={} + 
data=sample_image_objects, dataset_name="test_dataset", _metadata={"test": "metadata"}, _stage_perf={} ) def test_stage_properties(self, stage: ImageNSFWFilterStage) -> None: @@ -126,7 +118,7 @@ def test_process_filtering( # So keep img1 (0.3), img3 (0.2), filter out img2 (0.7), img4 (0.8) mock_model.side_effect = [ torch.tensor([0.3, 0.7]), # First batch - torch.tensor([0.2, 0.8]) # Second batch + torch.tensor([0.2, 0.8]), # Second batch ] result = stage.process(sample_image_batch) @@ -148,9 +140,6 @@ def test_process_filtering( elif img.image_id == "img_003": assert abs(img.nsfw_score - 0.2) < 1e-5 - # Check that the task has updated ID - assert result.task_id == f"{sample_image_batch.task_id}_{stage.name}" - @patch("nemo_curator.stages.image.filters.nsfw_filter.NSFWScorer") def test_process_high_nsfw_filtering( self, @@ -167,7 +156,7 @@ def test_process_high_nsfw_filtering( # All images have high NSFW scores (above threshold) mock_model.side_effect = [ torch.tensor([0.8, 0.9]), # First batch - torch.tensor([0.7, 0.6]) # Second batch + torch.tensor([0.7, 0.6]), # Second batch ] result = stage.process(sample_image_batch) @@ -206,17 +195,13 @@ def test_threshold_boundary_cases( """Test boundary cases at threshold.""" mock_nsfw_scorer.return_value = mock_model - stage = ImageNSFWFilterStage( - model_dir="test_models/nsfw", - score_threshold=0.5, - model_inference_batch_size=2 - ) + stage = ImageNSFWFilterStage(model_dir="test_models/nsfw", score_threshold=0.5, model_inference_batch_size=2) stage.setup() # Test scores around threshold (0.5) mock_model.side_effect = [ - torch.tensor([0.5, 0.49]), # First batch: exactly at and just below - torch.tensor([0.51, 0.499]) # Second batch: just above and just below + torch.tensor([0.5, 0.49]), # First batch: exactly at and just below + torch.tensor([0.51, 0.499]), # Second batch: just above and just below ] result = stage.process(sample_image_batch) @@ -236,24 +221,19 @@ def test_all_images_filtered( """Test when all images 
are filtered out.""" mock_nsfw_scorer.return_value = mock_model - stage = ImageNSFWFilterStage( - model_dir="test_models/nsfw", - score_threshold=0.5, - model_inference_batch_size=2 - ) + stage = ImageNSFWFilterStage(model_dir="test_models/nsfw", score_threshold=0.5, model_inference_batch_size=2) stage.setup() # All high NSFW scores mock_model.side_effect = [ torch.tensor([0.9, 0.8]), # First batch - torch.tensor([0.7, 0.6]) # Second batch + torch.tensor([0.7, 0.6]), # Second batch ] result = stage.process(sample_image_batch) assert len(result.data) == 0 assert result.dataset_name == sample_image_batch.dataset_name - assert result.task_id == f"{sample_image_batch.task_id}_{stage.name}" @patch("nemo_curator.stages.image.filters.nsfw_filter.NSFWScorer") def test_no_images_filtered( @@ -265,17 +245,13 @@ def test_no_images_filtered( """Test when no images are filtered out.""" mock_nsfw_scorer.return_value = mock_model - stage = ImageNSFWFilterStage( - model_dir="test_models/nsfw", - score_threshold=0.5, - model_inference_batch_size=2 - ) + stage = ImageNSFWFilterStage(model_dir="test_models/nsfw", score_threshold=0.5, model_inference_batch_size=2) stage.setup() # All low NSFW scores mock_model.side_effect = [ torch.tensor([0.1, 0.2]), # First batch - torch.tensor([0.3, 0.4]) # Second batch + torch.tensor([0.3, 0.4]), # Second batch ] result = stage.process(sample_image_batch) @@ -326,27 +302,26 @@ def test_verbose_logging( model_dir="test_models/nsfw", score_threshold=0.5, model_inference_batch_size=2, # Match the mock data structure - verbose=True + verbose=True, ) verbose_stage.setup() verbose_stage.model = mock_model mock_model.side_effect = [ torch.tensor([0.3, 0.7]), # First batch: one pass, one fail - torch.tensor([0.2, 0.8]) # Second batch: one pass, one fail + torch.tensor([0.2, 0.8]), # Second batch: one pass, one fail ] verbose_stage.process(sample_image_batch) # Should log filtering results - filtering_calls = [call for call in 
mock_logger.info.call_args_list - if "NSFW" in str(call)] + filtering_calls = [call for call in mock_logger.info.call_args_list if "NSFW" in str(call)] assert len(filtering_calls) > 0 @patch("nemo_curator.stages.image.filters.nsfw_filter.NSFWScorer") def test_empty_batch(self, mock_nsfw_scorer: Mock, stage: ImageNSFWFilterStage) -> None: """Test processing empty image batch.""" - empty_batch = ImageBatch(data=[], task_id="empty_test", dataset_name="test_dataset") + empty_batch = ImageBatch(data=[], dataset_name="test_dataset") mock_nsfw_scorer.return_value = Mock() stage.setup() @@ -385,10 +360,12 @@ def __call__(self, embeddings_numpy: np.ndarray) -> torch.Tensor: tmp_dir = tempfile.gettempdir() images = [ - ImageObject(image_id=f"img_{i}", image_path=f"{tmp_dir}/{i}.jpg", embedding=rng.normal(size=(8,)).astype(np.float32)) + ImageObject( + image_id=f"img_{i}", image_path=f"{tmp_dir}/{i}.jpg", embedding=rng.normal(size=(8,)).astype(np.float32) + ) for i in range(6) ] - batch = ImageBatch(data=images, dataset_name="ds", task_id="t0") + batch = ImageBatch(data=images, dataset_name="ds") stage = ImageNSFWFilterStage(model_dir="/unused", model_inference_batch_size=3, score_threshold=0.5) diff --git a/tests/stages/image/io/test_convert.py b/tests/stages/image/io/test_convert.py index 8733564d5f..766a3f543f 100644 --- a/tests/stages/image/io/test_convert.py +++ b/tests/stages/image/io/test_convert.py @@ -47,7 +47,6 @@ def image_batch_with_embeddings(rng: np.random.Generator, tmp_path: Path) -> Ima return ImageBatch( data=images, dataset_name="ds_test", - task_id="task_123", _metadata={"foo": "bar"}, _stage_perf={"stage": 1.23}, ) @@ -64,7 +63,6 @@ def test_default_fields_outputs_image_id_only(self, image_batch_with_embeddings: assert df["image_id"].tolist() == [img.image_id for img in image_batch_with_embeddings.data] # Metadata and identifiers preserved - assert out.task_id == f"{image_batch_with_embeddings.task_id}_{stage.name}" assert out.dataset_name == 
image_batch_with_embeddings.dataset_name assert out._metadata == image_batch_with_embeddings._metadata assert out._stage_perf == image_batch_with_embeddings._stage_perf @@ -86,7 +84,7 @@ def test_custom_fields_include_embeddings(self, image_batch_with_embeddings: Ima def test_empty_input_default_fields(self) -> None: stage = ConvertImageBatchToDocumentBatchStage() - empty_batch = ImageBatch(data=[], dataset_name="ds", task_id="t0") + empty_batch = ImageBatch(data=[], dataset_name="ds") out = stage.process(empty_batch) df = out.to_pandas() assert isinstance(out, DocumentBatch) @@ -95,7 +93,7 @@ def test_empty_input_default_fields(self) -> None: def test_empty_input_with_custom_fields(self) -> None: stage = ConvertImageBatchToDocumentBatchStage(fields=["image_id", "embedding", "image_path"]) - empty_batch = ImageBatch(data=[], dataset_name="ds", task_id="t0") + empty_batch = ImageBatch(data=[], dataset_name="ds") out = stage.process(empty_batch) df = out.to_pandas() assert list(df.columns) == ["image_id", "embedding", "image_path"] @@ -115,7 +113,7 @@ def test_missing_attribute_yields_none(self, rng: np.random.Generator, tmp_path: image_data=rng.integers(0, 255, (8, 8, 3), dtype=np.uint8), ), ] - batch = ImageBatch(data=images, dataset_name="dsx", task_id="t1") + batch = ImageBatch(data=images, dataset_name="dsx") stage = ConvertImageBatchToDocumentBatchStage(fields=["image_id", "embedding"]) # 'embedding' may be missing out = stage.process(batch) df = out.to_pandas() diff --git a/tests/stages/image/io/test_image_reader.py b/tests/stages/image/io/test_image_reader.py index 424d34ef5f..e23ceea45e 100644 --- a/tests/stages/image/io/test_image_reader.py +++ b/tests/stages/image/io/test_image_reader.py @@ -35,9 +35,7 @@ class _FakeTensorList: """Minimal stand-in for a DALI TensorList returned by Pipeline.run().""" def __init__(self, batch_size: int, height: int = 8, width: int = 8) -> None: - self._arrays: list[np.ndarray] = [ - np.zeros((height, width, 3), 
dtype=np.uint8) for _ in range(batch_size) - ] + self._arrays: list[np.ndarray] = [np.zeros((height, width, 3), dtype=np.uint8) for _ in range(batch_size)] def as_cpu(self) -> _FakeTensorList: return self @@ -121,8 +119,10 @@ class _Types: sys.modules["nvidia.dali"] = dali sys.modules["nvidia.dali.pipeline"] = pipeline + def test_inputs_outputs_and_name() -> None: from nemo_curator.stages.image.io.image_reader import ImageReaderStage + with patch("torch.cuda.is_available", return_value=True): stage = ImageReaderStage(dali_batch_size=3, verbose=False) assert stage.inputs() == ([], []) @@ -132,6 +132,7 @@ def test_inputs_outputs_and_name() -> None: def test_init_allows_cpu_when_no_cuda() -> None: from nemo_curator.stages.image.io.image_reader import ImageReaderStage + # When CUDA is unavailable, the stage should initialize and use CPU DALI with patch("torch.cuda.is_available", return_value=False): stage = ImageReaderStage(dali_batch_size=2, verbose=False) @@ -140,9 +141,9 @@ def test_init_allows_cpu_when_no_cuda() -> None: def test_process_streams_batches_from_dali() -> None: from nemo_curator.stages.image.io.image_reader import ImageReaderStage + # Two tar files; each has 5 total samples, emitted in batches of 2 (2,2,1) task = FileGroupTask( - task_id="t1", dataset_name="ds", data=["/data/a.tar", "/data/b.tar"], ) @@ -168,7 +169,8 @@ def test_process_streams_batches_from_dali() -> None: def test_process_raises_on_empty_task() -> None: from nemo_curator.stages.image.io.image_reader import ImageReaderStage - empty = FileGroupTask(task_id="e1", dataset_name="ds", data=[]) + + empty = FileGroupTask(dataset_name="ds", data=[]) with patch("torch.cuda.is_available", return_value=True): stage = ImageReaderStage(dali_batch_size=2, verbose=False) @@ -177,9 +179,9 @@ def test_process_raises_on_empty_task() -> None: stage.process(empty) - def test_resources_with_cuda_available() -> None: from nemo_curator.stages.image.io.image_reader import ImageReaderStage + # Instantiate with 
CUDA available so __post_init__ passes with patch("torch.cuda.is_available", return_value=True): stage = ImageReaderStage(dali_batch_size=2, verbose=False) @@ -191,6 +193,7 @@ def test_resources_with_cuda_available() -> None: def test_resources_without_cuda() -> None: from nemo_curator.stages.image.io.image_reader import ImageReaderStage + # Create the stage without CUDA available with patch("torch.cuda.is_available", return_value=False): stage = ImageReaderStage(dali_batch_size=2, verbose=False) @@ -215,7 +218,7 @@ def test_dali_image_reader_on_gpu() -> None: from nemo_curator.tasks import FileGroupTask stage = ImageReaderStage(dali_batch_size=2, num_threads=2, verbose=False) - task = FileGroupTask(task_id="t0", dataset_name="ds", data=[str(tar_path)]) + task = FileGroupTask(dataset_name="ds", data=[str(tar_path)]) batches = stage.process(task) diff --git a/tests/stages/image/io/test_image_writer.py b/tests/stages/image/io/test_image_writer.py index a3b6f40390..7be24efc99 100644 --- a/tests/stages/image/io/test_image_writer.py +++ b/tests/stages/image/io/test_image_writer.py @@ -92,7 +92,7 @@ def _capture_parquet(_self: object, base_name: str, rows: list[dict]) -> str: for i in range(5) ] - batch = ImageBatch(task_id="t1", dataset_name="ds", data=images) + batch = ImageBatch(dataset_name="ds", data=images) out = stage.process(batch) # Validate output task @@ -137,7 +137,6 @@ def test_process_raises_on_missing_image_data(tmp_path: pathlib.Path) -> None: stage.setup() bad = ImageBatch( - task_id="bad", dataset_name="ds", data=[ImageObject(image_id="x", image_path="/p/x.jpg", image_data=None)], ) @@ -151,7 +150,7 @@ def test_process_handles_empty_batch(tmp_path: pathlib.Path) -> None: stage = image_writer_stage_cls(output_dir=str(tmp_path), images_per_tar=3) stage.setup() - empty = ImageBatch(task_id="e", dataset_name="ds", data=[]) + empty = ImageBatch(dataset_name="ds", data=[]) out = stage.process(empty) assert out.data == [] @@ -169,8 +168,8 @@ def 
test_construct_base_name_deterministic_and_random(monkeypatch: pytest.Monkey ImageObject(image_path="/p/a.jpg"), ] imgs2 = list(reversed(imgs1)) - b1 = stage_det.construct_base_name(ImageBatch(task_id="T", dataset_name="ds", data=imgs1)) - b2 = stage_det.construct_base_name(ImageBatch(task_id="T", dataset_name="ds", data=imgs2)) + b1 = stage_det.construct_base_name(ImageBatch(dataset_name="ds", data=imgs1)) + b2 = stage_det.construct_base_name(ImageBatch(dataset_name="ds", data=imgs2)) assert b1 == b2 assert b1.startswith("images-") @@ -181,7 +180,7 @@ def __init__(self, hex_image_id: str) -> None: monkeypatch.setattr(module.uuid, "uuid4", lambda: _FakeUUID("deadbeefcafebabe0123456789abcdef")) stage_rand = image_writer_stage_cls(output_dir=str(tmp_path), deterministic_name=False) - b3 = stage_rand.construct_base_name(ImageBatch(task_id="T2", dataset_name="ds", data=imgs1)) + b3 = stage_rand.construct_base_name(ImageBatch(dataset_name="ds", data=imgs1)) assert b3 == "images-deadbeefcafebabe" @@ -329,7 +328,7 @@ def test_process_respects_remove_image_data_flag( for i in range(4) ] - batch = ImageBatch(task_id="t1", dataset_name="ds", data=images) + batch = ImageBatch(dataset_name="ds", data=images) _out = stage.process(batch) # Image data should be removed only when the flag is True diff --git a/tests/stages/interleaved/conftest.py b/tests/stages/interleaved/conftest.py index a7f849ac86..8ac0d14d4a 100644 --- a/tests/stages/interleaved/conftest.py +++ b/tests/stages/interleaved/conftest.py @@ -91,10 +91,9 @@ def write_tar(tar_path: Path, members: dict[str, bytes]) -> str: return str(tar_path) -def task_for_tar(tar_path: str, task_id: str = "file_group_0", dataset_name: str = "mint_test") -> FileGroupTask: +def task_for_tar(tar_path: str, dataset_name: str = "mint_test") -> FileGroupTask: """Build a ``FileGroupTask`` wrapping a single tar path.""" return FileGroupTask( - task_id=task_id, dataset_name=dataset_name, data=[tar_path], _metadata={"source_files": 
[tar_path]}, @@ -224,12 +223,11 @@ def make_image_task(rows: list[dict], metadata: dict | None = None) -> Interleav Primarily used by materialization and classify-rows tests. """ table = pa.Table.from_pylist(rows, schema=INTERLEAVED_SCHEMA) - return InterleavedBatch(task_id="test", dataset_name="d", data=table, _metadata=metadata or {}) + return InterleavedBatch(dataset_name="d", data=table, _metadata=metadata or {}) def make_interleaved_batch( num_samples: int = 2, - task_id: str = "test_batch", include_images: bool = True, schema: pa.Schema = INTERLEAVED_SCHEMA, ) -> InterleavedBatch: @@ -257,7 +255,6 @@ def make_interleaved_batch( rows.append(make_row(sid, 1, "image", binary_content=b"fake-jpeg-bytes")) table = pa.Table.from_pylist(rows, schema=schema) return InterleavedBatch( - task_id=task_id, dataset_name="test", data=table, _metadata={"source_files": ["test.tar"]}, @@ -326,4 +323,4 @@ def single_row_table() -> pa.Table: @pytest.fixture def single_row_task(single_row_table: pa.Table) -> InterleavedBatch: """``InterleavedBatch`` wrapping the ``single_row_table`` fixture.""" - return InterleavedBatch(task_id="t1", dataset_name="d1", data=single_row_table) + return InterleavedBatch(dataset_name="d1", data=single_row_table) diff --git a/tests/stages/interleaved/filter/conftest.py b/tests/stages/interleaved/filter/conftest.py index a55d117f21..b285c8b4b6 100644 --- a/tests/stages/interleaved/filter/conftest.py +++ b/tests/stages/interleaved/filter/conftest.py @@ -24,7 +24,7 @@ def interleaved_task(rows: list[dict]) -> InterleavedBatch: table = pa.Table.from_pylist(rows, schema=INTERLEAVED_SCHEMA) - return InterleavedBatch(task_id="test", dataset_name="d", data=table) + return InterleavedBatch(dataset_name="d", data=table) def make_jpeg_bytes(width: int = 32, height: int = 32, sharp: bool = True) -> bytes: diff --git a/tests/stages/interleaved/pdf/nemotron_parse/test_stages.py b/tests/stages/interleaved/pdf/nemotron_parse/test_stages.py index a327b87b7b..c9bfef9c04 
100644 --- a/tests/stages/interleaved/pdf/nemotron_parse/test_stages.py +++ b/tests/stages/interleaved/pdf/nemotron_parse/test_stages.py @@ -35,7 +35,7 @@ def _empty_task() -> _EmptyTask: - return _EmptyTask(task_id="empty", dataset_name="test", data=None) + return _EmptyTask(dataset_name="test", data=None) class TestPDFPartitioningStage: @@ -171,7 +171,6 @@ def test_pdf_dir_mode(self, tmp_path: Path): from nemo_curator.tasks import FileGroupTask task = FileGroupTask( - task_id="test_task", dataset_name="test", data=[entry], ) @@ -193,7 +192,6 @@ def test_missing_pdf_returns_none(self, tmp_path: Path): from nemo_curator.tasks import FileGroupTask task = FileGroupTask( - task_id="test_task", dataset_name="test", data=[entry], ) @@ -213,7 +211,7 @@ def test_zip_mode(self, tmp_path: Path): entry = json.dumps({"file_name": "0001234.pdf", "url": "http://test"}) from nemo_curator.tasks import FileGroupTask - task = FileGroupTask(task_id="test_task", dataset_name="test", data=[entry]) + task = FileGroupTask(dataset_name="test", data=[entry]) stage = PDFPreprocessStage(zip_base_dir=str(tmp_path)) result = stage.process(task) assert result is not None @@ -234,7 +232,7 @@ def test_jsonl_mode_with_byte_offset(self, tmp_path: Path): ) from nemo_curator.tasks import FileGroupTask - task = FileGroupTask(task_id="test_task", dataset_name="test", data=[entry]) + task = FileGroupTask(dataset_name="test", data=[entry]) stage = PDFPreprocessStage(jsonl_base_dir=str(jsonl_dir)) result = stage.process(task) assert result is not None @@ -256,7 +254,7 @@ def test_jsonl_mode_with_line_idx(self, tmp_path: Path): entry = json.dumps({"file_name": "test.pdf", "url": "http://test", "jsonl_file": "data.jsonl", "line_idx": 1}) from nemo_curator.tasks import FileGroupTask - task = FileGroupTask(task_id="test_task", dataset_name="test", data=[entry]) + task = FileGroupTask(dataset_name="test", data=[entry]) stage = PDFPreprocessStage(jsonl_base_dir=str(jsonl_dir)) result = stage.process(task) 
assert result is not None @@ -266,7 +264,7 @@ def test_no_mode_raises_value_error(self): entry = json.dumps({"file_name": "test.pdf"}) from nemo_curator.tasks import FileGroupTask - task = FileGroupTask(task_id="t", dataset_name="test", data=[entry]) + task = FileGroupTask(dataset_name="test", data=[entry]) stage = PDFPreprocessStage() with pytest.raises(ValueError, match="One of"): stage.process(task) @@ -300,7 +298,6 @@ def test_postprocess_basic(self): ) task = InterleavedBatch( - task_id="test", dataset_name="test", data=pa.Table.from_pandas(result_df), _metadata={"proc_size": [100, 100], "model_path": "v1.2"}, @@ -344,7 +341,6 @@ def test_no_valid_output_returns_none(self): ) task = InterleavedBatch( - task_id="test", dataset_name="test", data=pa.Table.from_pandas(result_df), _metadata={"proc_size": [100, 100], "model_path": "v1.2"}, diff --git a/tests/stages/interleaved/test_base_writer.py b/tests/stages/interleaved/test_base_writer.py index f089160792..0e2f96dd79 100644 --- a/tests/stages/interleaved/test_base_writer.py +++ b/tests/stages/interleaved/test_base_writer.py @@ -41,7 +41,6 @@ def _make_task(rows: list[dict[str, Any]]) -> InterleavedBatch: return InterleavedBatch( - task_id="base_writer_test", dataset_name="test", data=pd.DataFrame(rows), _metadata={"source_files": ["x.tar"]}, @@ -195,7 +194,6 @@ def test_process_no_source_files_uses_uuid(tmp_path: Path, metadata: dict) -> No """source_files absent or empty → UUID fallback (non-deterministic), matching text BaseWriter.""" rows = [{**_BASE_ROW}] task = InterleavedBatch( - task_id="no-source", dataset_name="test", data=pd.DataFrame(rows), _metadata=metadata, diff --git a/tests/stages/interleaved/test_interleaved_task.py b/tests/stages/interleaved/test_interleaved_task.py index 2069880e6a..90b5e37e58 100644 --- a/tests/stages/interleaved/test_interleaved_task.py +++ b/tests/stages/interleaved/test_interleaved_task.py @@ -34,7 +34,7 @@ def _make_batch(data: pa.Table | pd.DataFrame) -> 
InterleavedBatch: - return InterleavedBatch(task_id="t", dataset_name="d", data=data) + return InterleavedBatch(dataset_name="d", data=data) # --- to_pyarrow --- diff --git a/tests/stages/interleaved/test_materialization.py b/tests/stages/interleaved/test_materialization.py index f2944fb1e2..420172a1e8 100644 --- a/tests/stages/interleaved/test_materialization.py +++ b/tests/stages/interleaved/test_materialization.py @@ -409,7 +409,6 @@ def test_materialize_with_only_missing_binary_false(tmp_path: Path) -> None: } ] task = InterleavedBatch( - task_id="re_mat", dataset_name="d", data=pa.Table.from_pylist(rows, schema=INTERLEAVED_SCHEMA), ) @@ -445,7 +444,6 @@ def test_materialize_preserves_passthrough_columns_with_src_prefix(tmp_path: Pat pa.field("_src_metadata", pa.string()) ) task = InterleavedBatch( - task_id="passthrough_test", dataset_name="d", data=pa.Table.from_pylist(rows, schema=schema_with_passthrough), ) diff --git a/tests/stages/interleaved/test_multimodal_core.py b/tests/stages/interleaved/test_multimodal_core.py index b1d13adf58..2f9fe1d4a3 100644 --- a/tests/stages/interleaved/test_multimodal_core.py +++ b/tests/stages/interleaved/test_multimodal_core.py @@ -237,7 +237,6 @@ def test_materialize_mixed_strategies(tmp_path: Path) -> None: def test_materialize_empty_task() -> None: task = InterleavedBatch( - task_id="empty", dataset_name="d", data=pa.table( { @@ -272,7 +271,7 @@ def test_materialize_no_image_rows() -> None: ], schema=INTERLEAVED_SCHEMA, ) - task = InterleavedBatch(task_id="no_img", dataset_name="d", data=table) + task = InterleavedBatch(dataset_name="d", data=table) result = materialize_task_binary_content(task) assert result.num_items == 1 @@ -313,7 +312,7 @@ def test_aspect_ratio_filter_handles_non_default_dataframe_index() -> None: ] ) df.index = pd.Index([10, 42]) - task = InterleavedBatch(task_id="non_default_index", dataset_name="d1", data=df) + task = InterleavedBatch(dataset_name="d1", data=df) stage = 
InterleavedAspectRatioFilterStage(drop_invalid_rows=False) out = stage.process(task).to_pandas() assert len(out) == 1 @@ -366,7 +365,7 @@ def test_aspect_ratio_filter_works_on_png_images() -> None: }, ] ) - task = InterleavedBatch(task_id="png_test", dataset_name="d1", data=df) + task = InterleavedBatch(dataset_name="d1", data=df) stage = InterleavedAspectRatioFilterStage(min_aspect_ratio=0.2, max_aspect_ratio=5.0, drop_invalid_rows=False) out = stage.process(task).to_pandas() assert len(out) == 2 @@ -480,7 +479,6 @@ def content_keep_mask(self, task: InterleavedBatch, df: pd.DataFrame) -> pd.Seri }, ] task = InterleavedBatch( - task_id="pos_test", dataset_name="d", data=pa.Table.from_pylist(rows, schema=INTERLEAVED_SCHEMA), ) @@ -536,7 +534,6 @@ def _row(sample_id: str, position: int, modality: str, text: str | None = None) _row("s1", 4, "text", "end"), ] task = InterleavedBatch( - task_id="interleave_test", dataset_name="d", data=pa.Table.from_pylist(rows, schema=INTERLEAVED_SCHEMA), ) @@ -591,7 +588,6 @@ def _row(sample_id: str, position: int, modality: str, text: str | None = None) _row("s1", 3, "image"), ] task = InterleavedBatch( - task_id="noninterleaved_row_order", dataset_name="d", data=pa.Table.from_pylist(rows, schema=INTERLEAVED_SCHEMA), ) @@ -635,7 +631,6 @@ def _row(sample_id: str, position: int, modality: str, text: str | None = None) _row("s2", 1, "text", "dropped2"), ] task = InterleavedBatch( - task_id="orphan_test", dataset_name="d", data=pa.Table.from_pylist(rows, schema=INTERLEAVED_SCHEMA), ) @@ -688,7 +683,7 @@ def test_count_and_num_items() -> None: ], schema=INTERLEAVED_SCHEMA, ) - task = InterleavedBatch(task_id="cnt", dataset_name="d", data=table) + task = InterleavedBatch(dataset_name="d", data=table) assert task.num_items == 2 assert task.count() == 3 assert task.count(modality="text") == 2 @@ -722,7 +717,7 @@ def test_count_with_pandas_data() -> None: ], schema=INTERLEAVED_SCHEMA, ) - task = InterleavedBatch(task_id="pd_cnt", 
dataset_name="d", data=table.to_pandas()) + task = InterleavedBatch(dataset_name="d", data=table.to_pandas()) assert task.num_items == 1 assert task.count() == 2 assert task.count(modality="image") == 1 @@ -821,7 +816,6 @@ def test_iter_materialized_bytes_only_yields_masked_rows(tmp_path: Path) -> None }, ] task = InterleavedBatch( - task_id="iter_test", dataset_name="d", data=pa.Table.from_pylist(rows, schema=INTERLEAVED_SCHEMA), ) @@ -861,7 +855,7 @@ def test_iter_materialized_bytes_preserves_original_indices(tmp_path: Path) -> N ] df = pd.DataFrame(rows) df.index = pd.Index([99]) - task = InterleavedBatch(task_id="idx_test", dataset_name="d", data=df) + task = InterleavedBatch(dataset_name="d", data=df) stage = InterleavedAspectRatioFilterStage() mask = pd.Series([True], index=df.index) @@ -901,7 +895,6 @@ def test_materialize_extracts_individual_tiff_frames(tmp_path: Path) -> None: } ) task = InterleavedBatch( - task_id="tiff_mat", dataset_name="d", data=pa.Table.from_pylist(rows, schema=INTERLEAVED_SCHEMA), ) @@ -931,7 +924,7 @@ def annotate(self, task: InterleavedBatch, df: pd.DataFrame) -> pd.DataFrame: return df empty_table = pa.Table.from_pylist([], schema=INTERLEAVED_SCHEMA) - task = InterleavedBatch(task_id="empty", dataset_name="d", data=empty_table) + task = InterleavedBatch(dataset_name="d", data=empty_table) result = _Passthrough().process(task) assert result is task @@ -988,7 +981,6 @@ def content_keep_mask(self, task: InterleavedBatch, df: pd.DataFrame) -> pd.Seri }, ] task = InterleavedBatch( - task_id="drop_test", dataset_name="d", data=pa.Table.from_pylist(rows, schema=INTERLEAVED_SCHEMA), ) @@ -1013,7 +1005,6 @@ def test_iter_materialized_bytes_empty_mask() -> None: }, ] task = InterleavedBatch( - task_id="empty_mask", dataset_name="d", data=pa.Table.from_pylist(rows, schema=INTERLEAVED_SCHEMA), ) @@ -1055,7 +1046,6 @@ def content_keep_mask(self, task: InterleavedBatch, df: pd.DataFrame) -> pd.Seri }, ] task = InterleavedBatch( - 
task_id="meta_only", dataset_name="d", data=pa.Table.from_pylist(rows, schema=INTERLEAVED_SCHEMA), ) @@ -1089,7 +1079,6 @@ def test_aspect_ratio_filter_no_image_rows() -> None: }, ] task = InterleavedBatch( - task_id="no_img", dataset_name="d", data=pa.Table.from_pylist(rows, schema=INTERLEAVED_SCHEMA), ) @@ -1138,7 +1127,7 @@ def _materialized_bytes(binary_content: object) -> list[tuple[int, bytes | None] else: df_mat = df.copy() df_mat["binary_content"] = binary_content - fake = InterleavedBatch(task_id="t", dataset_name="d", data=df_mat, _metadata=task._metadata) + fake = InterleavedBatch(dataset_name="d", data=df_mat, _metadata=task._metadata) with patch("nemo_curator.stages.interleaved.stages.materialize_task_binary_content", return_value=fake): return list(_PassthroughFilter().iter_materialized_bytes(task, df, df["modality"] == "image")) diff --git a/tests/stages/interleaved/test_multimodal_reader.py b/tests/stages/interleaved/test_multimodal_reader.py index 416698bfe0..4484d9e7d9 100644 --- a/tests/stages/interleaved/test_multimodal_reader.py +++ b/tests/stages/interleaved/test_multimodal_reader.py @@ -144,7 +144,7 @@ def test_reader_supports_custom_field_mapping(tmp_path: Path) -> None: image_name="custom-image.jpg", image_bytes=image_bytes, ) - task = task_for_tar(str(tar_path), "file_group_custom") + task = task_for_tar(str(tar_path)) reader = InterleavedWebdatasetReaderStage( sample_id_field="doc_id", texts_field="captions", @@ -180,7 +180,7 @@ def test_reader_reads_all_fields_by_default(tmp_path: Path) -> None: "aux": {"page": 3}, } _write_tar_sample(tar_path, payload, json_name="sample.meta.json") - task = task_for_tar(str(tar_path), "all_fields") + task = task_for_tar(str(tar_path)) reader = InterleavedWebdatasetReaderStage( sample_id_field="doc_id", texts_field="captions", @@ -220,7 +220,7 @@ def test_reader_uses_resolved_content_key_for_content_type(tmp_path: Path) -> No jpg_info.size = 3 tf.addfile(jpg_info, BytesIO(b"jpg")) - task = 
task_for_tar(str(tar_path), "content_type_resolve") + task = task_for_tar(str(tar_path)) reader = InterleavedWebdatasetReaderStage( sample_id_field="doc_id", texts_field="captions", @@ -242,7 +242,7 @@ def test_reader_image_tokens_with_frame_index(tmp_path: Path) -> None: "images": [None, "page_0_image_15", "page_1_image_22"], } _write_tar_sample(tar_path, payload, json_name="sample.json", image_name="doc.pdf.tiff", image_bytes=b"TIFF_DATA") - task = task_for_tar(str(tar_path), "sub_image_test") + task = task_for_tar(str(tar_path)) reader = InterleavedWebdatasetReaderStage( sample_id_field="pdf_name", image_extensions=(".tiff",), @@ -277,7 +277,7 @@ def test_reader_interleaved_positions_do_not_overlap(tmp_path: Path) -> None: "images": [None, "page_img", None, "chart_img", None], } _write_tar_sample(tar_path, payload, image_name="interleaved.pdf.jpg", image_bytes=b"\xff\xd8\xff") - task = task_for_tar(str(tar_path), "interleaved_test") + task = task_for_tar(str(tar_path)) reader = InterleavedWebdatasetReaderStage(sample_id_field="pdf_name") df = _as_df(reader.process(task)) @@ -304,7 +304,7 @@ def test_reader_empty_output_schema_includes_requested_passthrough_fields(tmp_pa img_info.size = 3 tf.addfile(img_info, BytesIO(b"abc")) - task = task_for_tar(str(tar_path), "empty_schema") + task = task_for_tar(str(tar_path)) reader = InterleavedWebdatasetReaderStage(fields=("p_hash",)) df = _as_df(reader.process(task)) assert "p_hash" in df.columns @@ -314,7 +314,7 @@ def test_reader_fields_reserved_key_raises(tmp_path: Path) -> None: tar_path = tmp_path / "reserved_key.tar" payload = {"pdf_name": "doc.pdf", "texts": ["t"], "images": []} _write_tar_sample(tar_path, payload) - task = task_for_tar(str(tar_path), "reserved_key") + task = task_for_tar(str(tar_path)) reader = InterleavedWebdatasetReaderStage(fields=("sample_id",)) with pytest.raises(ValueError, match="fields contains reserved keys"): _ = reader.process(task) @@ -324,7 +324,7 @@ def 
test_reader_fields_missing_key_warns_and_fills_none(tmp_path: Path, caplog: tar_path = tmp_path / "missing_key.tar" payload = {"pdf_name": "doc.pdf", "texts": ["t"], "images": []} _write_tar_sample(tar_path, payload) - task = task_for_tar(str(tar_path), "missing_key") + task = task_for_tar(str(tar_path)) reader = InterleavedWebdatasetReaderStage(fields=("p_hash",)) with caplog.at_level("WARNING"): result = reader.process(task) @@ -345,7 +345,7 @@ def test_reader_per_image_fields_distributed_to_image_rows(tmp_path: Path) -> No "image_metadata": [{"height": 100, "width": 200}], } _write_tar_sample(tar_path, payload) - task = task_for_tar(str(tar_path), "per_image") + task = task_for_tar(str(tar_path)) reader = InterleavedWebdatasetReaderStage( per_image_fields=("image_metadata",), ) @@ -374,7 +374,7 @@ def test_reader_per_text_fields_distributed_to_text_rows(tmp_path: Path) -> None "text_scores": [0.95, 0.42], } _write_tar_sample(tar_path, payload) - task = task_for_tar(str(tar_path), "per_text") + task = task_for_tar(str(tar_path)) reader = InterleavedWebdatasetReaderStage( per_text_fields=("text_scores",), ) @@ -406,7 +406,7 @@ def test_reader_per_image_and_per_text_fields_together(tmp_path: Path) -> None: "url": "https://example.com", } _write_tar_sample(tar_path, payload) - task = task_for_tar(str(tar_path), "both_per_modality") + task = task_for_tar(str(tar_path)) reader = InterleavedWebdatasetReaderStage( per_image_fields=("image_metadata",), per_text_fields=("text_lang",), @@ -440,7 +440,7 @@ def test_reader_per_modality_fields_excluded_from_sample_passthrough(tmp_path: P "url": "https://example.com", } _write_tar_sample(tar_path, payload) - task = task_for_tar(str(tar_path), "exclude_pt") + task = task_for_tar(str(tar_path)) reader = InterleavedWebdatasetReaderStage( per_image_fields=("image_metadata",), per_text_fields=("text_scores",), @@ -462,7 +462,7 @@ def test_reader_per_modality_field_missing_warns(tmp_path: Path, caplog: pytest. 
"images": [], } _write_tar_sample(tar_path, payload) - task = task_for_tar(str(tar_path), "missing_per_field") + task = task_for_tar(str(tar_path)) reader = InterleavedWebdatasetReaderStage( per_image_fields=("image_metadata",), ) @@ -483,7 +483,7 @@ def test_reader_raises_on_non_list_per_modality_field(tmp_path: Path) -> None: "image_metadata": "not-a-list", } _write_tar_sample(tar_path, payload) - task = task_for_tar(str(tar_path), "non_list_field") + task = task_for_tar(str(tar_path)) reader = InterleavedWebdatasetReaderStage( per_image_fields=("image_metadata",), ) @@ -512,7 +512,7 @@ def test_reader_materialize_on_read_extracts_individual_tiff_frames(tmp_path: Pa tmp_path / "tiff-frames.tar", {"sample.json": json.dumps(payload).encode(), "doc.pdf.tiff": tiff_bytes}, ) - task = task_for_tar(tar_path, "tiff_frame_test") + task = task_for_tar(tar_path) reader = InterleavedWebdatasetReaderStage( sample_id_field="pdf_name", image_extensions=(".tiff",), @@ -559,7 +559,7 @@ def test_reader_materialize_on_read_records_error_for_missing_member(tmp_path: P tmp_path / "corrupt-member.tar", {"sample.json": json.dumps(payload).encode(), "doc.pdf.tiff": tiff_bytes}, ) - task = task_for_tar(tar_path, "corrupt_member_test") + task = task_for_tar(tar_path) reader = InterleavedWebdatasetReaderStage( sample_id_field="pdf_name", image_extensions=(".tiff",), @@ -589,7 +589,7 @@ def test_reader_frame_counter_resets_per_content_key(tmp_path: Path) -> None: tmp_path / "multi-tiff.tar", {"sample.json": json.dumps(payload).encode(), "a.tiff": tiff_a, "b.tiff": tiff_b}, ) - task = task_for_tar(tar_path, "multi_tiff_test") + task = task_for_tar(tar_path) reader = InterleavedWebdatasetReaderStage( sample_id_field="pdf_name", image_extensions=(".tiff",), @@ -622,7 +622,7 @@ def test_reader_materialize_preserves_raw_bytes_on_frame_extraction_failure(tmp_ tmp_path / "oob-frame.tar", {"sample.json": json.dumps(payload).encode(), "doc.pdf.tiff": tiff_bytes}, ) - task = task_for_tar(tar_path, 
"oob_frame_test") + task = task_for_tar(tar_path) reader = InterleavedWebdatasetReaderStage( sample_id_field="pdf_name", image_extensions=(".tiff",), @@ -688,7 +688,7 @@ def test_reader_materialize_on_read_jpeg_png_bytes_preserved( image_member: image_bytes, }, ) - task = task_for_tar(tar_path, f"{fmt.lower()}_test") + task = task_for_tar(tar_path) reader = InterleavedWebdatasetReaderStage(materialize_on_read=True) df = _as_df(reader.process(task)) @@ -728,7 +728,6 @@ def test_reader_empty_tar(tmp_path: Path) -> None: img_info.size = 3 tf.addfile(img_info, BytesIO(b"abc")) task = FileGroupTask( - task_id="empty", dataset_name="d", data=[str(tar_path)], _metadata={"source_files": [str(tar_path)]}, @@ -749,7 +748,6 @@ def test_reader_multi_tar(tmp_path: Path) -> None: {f"{sample_id}.json": json.dumps(payload).encode(), f"{sample_id}.jpg": b"img"}, ) task = FileGroupTask( - task_id="multi", dataset_name="d", data=[str(tmp_path / "shard1.tar"), str(tmp_path / "shard2.tar")], _metadata={"source_files": ["shard1.tar", "shard2.tar"]}, @@ -773,7 +771,6 @@ def test_reader_max_batch_bytes_splits(tmp_path: Path) -> None: {f"{sample_id}.json": json.dumps(payload).encode()}, ) task = FileGroupTask( - task_id="split", dataset_name="d", data=[str(tmp_path / "doc1.tar"), str(tmp_path / "doc2.tar")], _metadata={"source_files": ["doc1.tar", "doc2.tar"]}, @@ -782,8 +779,6 @@ def test_reader_max_batch_bytes_splits(tmp_path: Path) -> None: result = reader.process(task) assert isinstance(result, list) assert len(result) >= 2 - for batch in result: - assert "_processed_" in batch.task_id def test_reader_source_files_per_split_only_contributing_tars(tmp_path: Path) -> None: @@ -797,7 +792,6 @@ def test_reader_source_files_per_split_only_contributing_tars(tmp_path: Path) -> write_tar(Path(tar_path), {f"{sample_id}.json": json.dumps(payload).encode()}) task = FileGroupTask( - task_id="sf_split", dataset_name="d", data=[tar1, tar2], _metadata={"source_files": [tar1, tar2]}, @@ -908,7 +902,7 
@@ def test_parquet_reader_roundtrip(tmp_path: Path) -> None: batch = make_interleaved_batch(num_samples=2, include_images=False) pq_path = _write_parquet_task(batch, tmp_path / "out") - task = FileGroupTask(task_id="pq_rt", dataset_name="d", data=[pq_path]) + task = FileGroupTask(dataset_name="d", data=[pq_path]) reader = InterleavedParquetReaderStage() result = reader.process(task) assert isinstance(result, InterleavedBatch) @@ -935,7 +929,7 @@ def test_parquet_reader_missing_columns_filled_with_null(tmp_path: Path) -> None pq_path = tmp_path / "minimal.parquet" pq.write_table(minimal, pq_path) - task = FileGroupTask(task_id="minimal", dataset_name="d", data=[str(pq_path)]) + task = FileGroupTask(dataset_name="d", data=[str(pq_path)]) result = InterleavedParquetReaderStage().process(task) assert isinstance(result, InterleavedBatch) df = result.to_pandas() @@ -949,7 +943,7 @@ def test_parquet_reader_fields_subset(tmp_path: Path) -> None: batch = make_interleaved_batch(num_samples=1, include_images=False) pq_path = _write_parquet_task(batch, tmp_path / "out") - task = FileGroupTask(task_id="fields_sub", dataset_name="d", data=[pq_path]) + task = FileGroupTask(dataset_name="d", data=[pq_path]) result = InterleavedParquetReaderStage(fields=("text_content",)).process(task) assert isinstance(result, InterleavedBatch) df = result.to_pandas() @@ -962,7 +956,7 @@ def test_parquet_reader_fields_null_fill_missing(tmp_path: Path) -> None: batch = make_interleaved_batch(num_samples=1, include_images=False) pq_path = _write_parquet_task(batch, tmp_path / "out") - task = FileGroupTask(task_id="null_fill", dataset_name="d", data=[pq_path]) + task = FileGroupTask(dataset_name="d", data=[pq_path]) result = InterleavedParquetReaderStage(fields=("nonexistent_field",)).process(task) assert isinstance(result, InterleavedBatch) df = result.to_pandas() @@ -977,9 +971,7 @@ def test_parquet_reader_extra_column_passthrough_by_default(tmp_path: Path) -> N pq_path = str(tmp_path / 
"extra.parquet") pq.write_table(pa.Table.from_pylist(_make_aligned_rows(fake_jpg)), pq_path) - result = InterleavedParquetReaderStage().process( - FileGroupTask(task_id="extra_col", dataset_name="d", data=[pq_path]) - ) + result = InterleavedParquetReaderStage().process(FileGroupTask(dataset_name="d", data=[pq_path])) assert isinstance(result, InterleavedBatch) _assert_field_alignment( result.to_pandas(), @@ -997,7 +989,7 @@ def test_parquet_reader_extra_column_excluded_when_fields_set(tmp_path: Path) -> # fields=("text_metadata",) → only text_metadata + reserved cols; others are NOT read result = InterleavedParquetReaderStage(fields=("text_metadata",)).process( - FileGroupTask(task_id="fields_excl", dataset_name="d", data=[pq_path]) + FileGroupTask(dataset_name="d", data=[pq_path]) ) assert isinstance(result, InterleavedBatch) df = result.to_pandas() @@ -1045,9 +1037,7 @@ def test_parquet_reader_wds_to_pq_to_wds_roundtrip(tmp_path: Path) -> None: path=str(tmp_path / "pq_out"), materialize_on_write=False, mode="overwrite" ).process(batch_wds) - batch_pq = InterleavedParquetReaderStage().process( - FileGroupTask(task_id="pq_rt", dataset_name="d", data=[pq_task.data[0]]) - ) + batch_pq = InterleavedParquetReaderStage().process(FileGroupTask(dataset_name="d", data=[pq_task.data[0]])) assert isinstance(batch_pq, InterleavedBatch) df_pq = batch_pq.to_pandas() for col in ("sample_metadata", "text_metadata", "image_metadata"): @@ -1059,7 +1049,7 @@ def test_parquet_reader_wds_to_pq_to_wds_roundtrip(tmp_path: Path) -> None: wds2_task = InterleavedWebdatasetWriterStage( path=str(tmp_path / "wds_out"), materialize_on_write=False, mode="overwrite" ).process(batch_pq) - batch_final = wds_reader.process(FileGroupTask(task_id="final", dataset_name="d", data=wds2_task.data)) + batch_final = wds_reader.process(FileGroupTask(dataset_name="d", data=wds2_task.data)) assert isinstance(batch_final, InterleavedBatch) _assert_field_alignment( batch_final.to_pandas(), @@ -1073,7 +1063,6 
@@ def test_parquet_reader_pq_to_wds_to_pq_roundtrip(tmp_path: Path) -> None: """PQ→WDS→PQ: sample/text/image extra columns survive a full write-to-WDS and back.""" fake_jpg = b"\xff\xd8\xff\xe0fake" batch0 = InterleavedBatch( - task_id="pq0", dataset_name="d", data=pa.Table.from_pylist(_make_aligned_rows(fake_jpg)), _metadata={"source_files": ["source.parquet"]}, @@ -1085,7 +1074,7 @@ def test_parquet_reader_pq_to_wds_to_pq_roundtrip(tmp_path: Path) -> None: .data[0] ) - batch1 = InterleavedParquetReaderStage().process(FileGroupTask(task_id="pq1", dataset_name="d", data=[pq_path0])) + batch1 = InterleavedParquetReaderStage().process(FileGroupTask(dataset_name="d", data=[pq_path0])) assert isinstance(batch1, InterleavedBatch) _assert_field_alignment( batch1.to_pandas(), @@ -1103,7 +1092,7 @@ def test_parquet_reader_pq_to_wds_to_pq_roundtrip(tmp_path: Path) -> None: per_image_fields=("image_metadata",), per_text_fields=("text_metadata",), ) - batch2 = wds_reader.process(FileGroupTask(task_id="wds1", dataset_name="d", data=wds_task.data)) + batch2 = wds_reader.process(FileGroupTask(dataset_name="d", data=wds_task.data)) assert isinstance(batch2, InterleavedBatch) _assert_field_alignment( batch2.to_pandas(), @@ -1117,9 +1106,7 @@ def test_parquet_reader_pq_to_wds_to_pq_roundtrip(tmp_path: Path) -> None: .process(batch2) .data[0] ) - batch_final = InterleavedParquetReaderStage().process( - FileGroupTask(task_id="pq_final", dataset_name="d", data=[pq_path1]) - ) + batch_final = InterleavedParquetReaderStage().process(FileGroupTask(dataset_name="d", data=[pq_path1])) assert isinstance(batch_final, InterleavedBatch) _assert_field_alignment( batch_final.to_pandas(), @@ -1132,8 +1119,8 @@ def test_parquet_reader_pq_to_wds_to_pq_roundtrip(tmp_path: Path) -> None: def test_parquet_reader_max_batch_bytes_splits(tmp_path: Path) -> None: """Two parquet files, one sample each; max_batch_bytes=1 → 2 splits, each split's source_files lists only its contributing file.""" - batch_a = 
make_interleaved_batch(num_samples=1, task_id="a", include_images=False) - batch_b = make_interleaved_batch(num_samples=1, task_id="b", include_images=False) + batch_a = make_interleaved_batch(num_samples=1, include_images=False) + batch_b = make_interleaved_batch(num_samples=1, include_images=False) # Give distinct sample_ids rows_a = batch_a.to_pandas().copy() rows_a["sample_id"] = "doc_a" @@ -1143,11 +1130,11 @@ def test_parquet_reader_max_batch_bytes_splits(tmp_path: Path) -> None: out_a = tmp_path / "a" out_b = tmp_path / "b" writer = InterleavedParquetWriterStage(path=str(out_a), materialize_on_write=False, mode="overwrite") - pq_a = writer.process(InterleavedBatch(task_id="a", dataset_name="d", data=rows_a)).data[0] + pq_a = writer.process(InterleavedBatch(dataset_name="d", data=rows_a)).data[0] writer2 = InterleavedParquetWriterStage(path=str(out_b), materialize_on_write=False, mode="overwrite") - pq_b = writer2.process(InterleavedBatch(task_id="b", dataset_name="d", data=rows_b)).data[0] + pq_b = writer2.process(InterleavedBatch(dataset_name="d", data=rows_b)).data[0] - task = FileGroupTask(task_id="split_test", dataset_name="d", data=[pq_a, pq_b]) + task = FileGroupTask(dataset_name="d", data=[pq_a, pq_b]) result = InterleavedParquetReaderStage(max_batch_bytes=1).process(task) assert isinstance(result, list) @@ -1168,7 +1155,7 @@ def test_parquet_reader_empty_file(tmp_path: Path) -> None: pq_path = tmp_path / "empty.parquet" pq.write_table(empty, pq_path) - task = FileGroupTask(task_id="empty", dataset_name="d", data=[str(pq_path)]) + task = FileGroupTask(dataset_name="d", data=[str(pq_path)]) result = InterleavedParquetReaderStage().process(task) assert isinstance(result, InterleavedBatch) assert len(result.to_pandas()) == 0 @@ -1184,7 +1171,7 @@ def test_parquet_reader_composite_decompose(tmp_path: Path) -> None: def test_parquet_reader_empty_file_list_returns_empty_batch() -> None: - result = 
InterleavedParquetReaderStage().process(FileGroupTask(task_id="t", dataset_name="d", data=[])) + result = InterleavedParquetReaderStage().process(FileGroupTask(dataset_name="d", data=[])) assert isinstance(result, InterleavedBatch) df = result.to_pandas() assert len(df) == 0 diff --git a/tests/stages/interleaved/test_multimodal_writer.py b/tests/stages/interleaved/test_multimodal_writer.py index 03e1596353..9c549fd987 100644 --- a/tests/stages/interleaved/test_multimodal_writer.py +++ b/tests/stages/interleaved/test_multimodal_writer.py @@ -62,7 +62,6 @@ def test_writer_marks_materialize_error_on_bad_source_path(tmp_path: Path, input first_image_idx = df[image_mask].index[0] df.loc[first_image_idx, "source_ref"] = _source_ref("/definitely/missing/path.tar", "abc123.tiff") bad_batch = InterleavedBatch( - task_id=batch.task_id, dataset_name=batch.dataset_name, data=df, _metadata=batch._metadata, @@ -104,7 +103,6 @@ def test_writer_materializes_direct_content_path_without_key(tmp_path: Path) -> schema=INTERLEAVED_SCHEMA, ) task = InterleavedBatch( - task_id="direct_content_path", dataset_name="mint_test", data=table, _metadata={"source_files": [str(raw_path)]}, @@ -135,7 +133,7 @@ def test_writer_does_not_persist_dataframe_index(tmp_path: Path) -> None: ] ) df.index = pd.Index([99]) - task = InterleavedBatch(task_id="idx_task", dataset_name="mint_test", data=df) + task = InterleavedBatch(dataset_name="mint_test", data=df) writer = InterleavedParquetWriterStage( path=str(tmp_path / "out_idx"), materialize_on_write=False, mode="overwrite" ) @@ -187,7 +185,7 @@ def _row(sample_id: str, position: int, modality: str, text: str | None = None) _row("s1", 4, "text", "end"), ] table = pa.Table.from_pylist(rows, schema=INTERLEAVED_SCHEMA) - task = InterleavedBatch(task_id="e2e_order", dataset_name="d", data=table) + task = InterleavedBatch(dataset_name="d", data=table) filter_stage = _DropSecondImage(drop_invalid_rows=False) filtered_task = filter_stage.process(task) @@ -225,7 
+223,7 @@ def test_writer_write_kwargs_cannot_override_index_false(tmp_path: Path) -> None ] ) df.index = pd.Index([42]) - task = InterleavedBatch(task_id="kwargs_override", dataset_name="test", data=df) + task = InterleavedBatch(dataset_name="test", data=df) writer = InterleavedParquetWriterStage( path=str(tmp_path / "override_out"), materialize_on_write=False, @@ -277,8 +275,8 @@ def test_heterogeneous_passthrough_fields_combine_as_nullable(tmp_path: Path) -> ) reader = InterleavedWebdatasetReaderStage() - batch_a = reader.process(FileGroupTask(task_id="a", dataset_name="d", data=[shard_a])) - batch_b = reader.process(FileGroupTask(task_id="b", dataset_name="d", data=[shard_b])) + batch_a = reader.process(FileGroupTask(dataset_name="d", data=[shard_a])) + batch_b = reader.process(FileGroupTask(dataset_name="d", data=[shard_b])) assert isinstance(batch_a, InterleavedBatch) assert isinstance(batch_b, InterleavedBatch) @@ -338,7 +336,7 @@ def test_writer_uses_uuid_when_no_source_files(tmp_path: Path) -> None: } ] ) - task = InterleavedBatch(task_id="no_source", dataset_name="test", data=df, _metadata={}) + task = InterleavedBatch(dataset_name="test", data=df, _metadata={}) out_dir = tmp_path / "uuid_out" writer = InterleavedParquetWriterStage( path=str(out_dir), @@ -368,7 +366,6 @@ def test_writer_no_materialize_preserves_null_binary(tmp_path: Path) -> None: schema=INTERLEAVED_SCHEMA, ) task = InterleavedBatch( - task_id="no_mat", dataset_name="test", data=table, _metadata={"source_files": ["/fake/img.jpg"]}, @@ -404,7 +401,6 @@ def test_writer_custom_compression(tmp_path: Path, compression: str) -> None: ] ) task = InterleavedBatch( - task_id="comp", dataset_name="test", data=df, _metadata={"source_files": ["test.tar"]}, @@ -473,7 +469,6 @@ def _make_wds_batch( } ) return InterleavedBatch( - task_id=f"wds_{sample_id}", dataset_name="test", data=pd.DataFrame(rows), _metadata={"source_files": source_files or ["test.tar"]}, @@ -536,7 +531,7 @@ def 
test_wds_writer_roundtrip(tmp_path: Path) -> None: tar_path = write_task.data[0] assert tar_path.endswith(".tar") - read_task = FileGroupTask(task_id="rt", dataset_name="test", data=[tar_path]) + read_task = FileGroupTask(dataset_name="test", data=[tar_path]) result = InterleavedWebdatasetReaderStage(sample_id_field="sample_id").process(read_task) assert isinstance(result, InterleavedBatch) df = result.to_pandas() @@ -579,7 +574,7 @@ def test_wds_writer_unsupported_modality_raises(tmp_path: Path) -> None: } ] ) - task = InterleavedBatch(task_id="vid", dataset_name="t", data=df, _metadata={"source_files": ["x.tar"]}) + task = InterleavedBatch(dataset_name="t", data=df, _metadata={"source_files": ["x.tar"]}) writer = InterleavedWebdatasetWriterStage( path=str(tmp_path / "vid_out"), materialize_on_write=False, mode="overwrite" ) @@ -601,7 +596,7 @@ def test_wds_writer_key_escaping(tmp_path: Path) -> None: assert "/" not in name.split(".json")[0], f"unescaped slash in member name: {name}" assert ":" not in name.split(".json")[0], f"unescaped colon in member name: {name}" - read_task = FileGroupTask(task_id="esc_rt", dataset_name="test", data=[write_task.data[0]]) + read_task = FileGroupTask(dataset_name="test", data=[write_task.data[0]]) result = InterleavedWebdatasetReaderStage(sample_id_field="sample_id").process(read_task) assert isinstance(result, InterleavedBatch) assert sample_id in result.to_pandas()["sample_id"].tolist() @@ -633,9 +628,7 @@ def test_wds_writer_passthrough_columns_in_json(tmp_path: Path) -> None: "url": None, }, ] - task = InterleavedBatch( - task_id="pt", dataset_name="t", data=pd.DataFrame(rows), _metadata={"source_files": ["x.tar"]} - ) + task = InterleavedBatch(dataset_name="t", data=pd.DataFrame(rows), _metadata={"source_files": ["x.tar"]}) writer = InterleavedWebdatasetWriterStage( path=str(tmp_path / "pt_out"), materialize_on_write=False, mode="overwrite" ) @@ -671,9 +664,7 @@ def test_wds_writer_null_binary_skips_member(tmp_path: Path) 
-> None: ], schema=INTERLEAVED_SCHEMA, ) - task = InterleavedBatch( - task_id="null_bin", dataset_name="t", data=table, _metadata={"source_files": ["/fake/img.png"]} - ) + task = InterleavedBatch(dataset_name="t", data=table, _metadata={"source_files": ["/fake/img.png"]}) writer = InterleavedWebdatasetWriterStage( path=str(tmp_path / "null_bin_out"), materialize_on_write=False, mode="overwrite" ) @@ -688,9 +679,8 @@ def test_wds_writer_null_binary_skips_member(tmp_path: Path) -> None: def test_wds_writer_deterministic_filename(tmp_path: Path) -> None: """Same source_files + task_id → same output filename across two writer instances.""" source = ["shard-00000.tar"] - task_id = "fixed_task" + batch = InterleavedBatch( - task_id=task_id, dataset_name="test", data=_make_wds_batch(sample_id="s1", source_files=source).to_pandas(), _metadata={"source_files": source}, @@ -730,9 +720,7 @@ def test_wds_writer_per_image_fields_roundtrip(tmp_path: Path) -> None: image_alt_text="a dog in a park", ), ] - task = InterleavedBatch( - task_id="per_img", dataset_name="t", data=pd.DataFrame(rows), _metadata={"source_files": ["x.tar"]} - ) + task = InterleavedBatch(dataset_name="t", data=pd.DataFrame(rows), _metadata={"source_files": ["x.tar"]}) writer = InterleavedWebdatasetWriterStage( path=str(tmp_path / "per_img_out"), materialize_on_write=False, mode="overwrite" ) @@ -748,7 +736,7 @@ def test_wds_writer_per_image_fields_roundtrip(tmp_path: Path) -> None: sample_id_field="sample_id", per_image_fields=("image_alt_text",), ) - result = reader.process(FileGroupTask(task_id="rt", dataset_name="t", data=write_task.data)) + result = reader.process(FileGroupTask(dataset_name="t", data=write_task.data)) assert isinstance(result, InterleavedBatch) df = result.to_pandas() @@ -765,9 +753,7 @@ def test_wds_writer_per_text_fields_roundtrip(tmp_path: Path) -> None: make_row("s1", 1, "text", text_content="second paragraph", text_confidence=None), # gap make_row("s1", 2, "text", 
text_content="third paragraph", text_confidence=0.72), ] - task = InterleavedBatch( - task_id="per_txt", dataset_name="t", data=pd.DataFrame(rows), _metadata={"source_files": ["x.tar"]} - ) + task = InterleavedBatch(dataset_name="t", data=pd.DataFrame(rows), _metadata={"source_files": ["x.tar"]}) writer = InterleavedWebdatasetWriterStage( path=str(tmp_path / "per_txt_out"), materialize_on_write=False, mode="overwrite" ) @@ -781,7 +767,7 @@ def test_wds_writer_per_text_fields_roundtrip(tmp_path: Path) -> None: sample_id_field="sample_id", per_text_fields=("text_confidence",), ) - result = reader.process(FileGroupTask(task_id="rt2", dataset_name="t", data=write_task.data)) + result = reader.process(FileGroupTask(dataset_name="t", data=write_task.data)) assert isinstance(result, InterleavedBatch) df = result.to_pandas() @@ -800,9 +786,7 @@ def test_wds_writer_non_ascii_per_image_field_roundtrip(tmp_path: Path) -> None: make_row("s1", 1, "image", content_type="image/png", binary_content=fake_png, image_alt_text=alt_texts[1]), make_row("s1", 2, "image", content_type="image/png", binary_content=fake_png, image_alt_text=alt_texts[2]), ] - task = InterleavedBatch( - task_id="non_ascii", dataset_name="t", data=pd.DataFrame(rows), _metadata={"source_files": ["x.tar"]} - ) + task = InterleavedBatch(dataset_name="t", data=pd.DataFrame(rows), _metadata={"source_files": ["x.tar"]}) writer = InterleavedWebdatasetWriterStage( path=str(tmp_path / "non_ascii_out"), materialize_on_write=False, mode="overwrite" ) @@ -818,7 +802,7 @@ def test_wds_writer_non_ascii_per_image_field_roundtrip(tmp_path: Path) -> None: sample_id_field="sample_id", per_image_fields=("image_alt_text",), ) - result = reader.process(FileGroupTask(task_id="rt3", dataset_name="t", data=write_task.data)) + result = reader.process(FileGroupTask(dataset_name="t", data=write_task.data)) assert isinstance(result, InterleavedBatch) df = result.to_pandas() image_rows = df[df["modality"] == 
"image"].sort_values("position") @@ -838,9 +822,7 @@ def test_wds_writer_mixed_modality_field_written_as_position_aligned_list(tmp_pa make_row("s1", 1, "image", content_type="image/png", binary_content=fake_png, confidence="img_conf_B"), make_row("s1", 2, "text", text_content="outro", confidence=None), ] - task = InterleavedBatch( - task_id="mixed", dataset_name="t", data=pd.DataFrame(rows), _metadata={"source_files": ["x.tar"]} - ) + task = InterleavedBatch(dataset_name="t", data=pd.DataFrame(rows), _metadata={"source_files": ["x.tar"]}) write_task = InterleavedWebdatasetWriterStage( path=str(tmp_path / "mixed_out"), materialize_on_write=False, mode="overwrite" ).process(task) @@ -850,7 +832,7 @@ def test_wds_writer_mixed_modality_field_written_as_position_aligned_list(tmp_pa # Without per_image/per_text declaration, reader treats it as a sample-level passthrough reader = InterleavedWebdatasetReaderStage(sample_id_field="sample_id") - result = reader.process(FileGroupTask(task_id="rt_mixed", dataset_name="t", data=write_task.data)) + result = reader.process(FileGroupTask(dataset_name="t", data=write_task.data)) assert isinstance(result, InterleavedBatch) df = result.to_pandas() meta_row = df[df["modality"] == "metadata"].iloc[0] diff --git a/tests/stages/interleaved/test_validation_utils.py b/tests/stages/interleaved/test_validation_utils.py index cedf19ee0c..a5d0d15014 100644 --- a/tests/stages/interleaved/test_validation_utils.py +++ b/tests/stages/interleaved/test_validation_utils.py @@ -27,7 +27,7 @@ def _make_task(metadata: dict | None = None) -> InterleavedBatch: table = pa.Table.from_pylist([], schema=INTERLEAVED_SCHEMA) - return InterleavedBatch(task_id="t", dataset_name="d", data=table, _metadata=metadata or {}) + return InterleavedBatch(dataset_name="d", data=table, _metadata=metadata or {}) # --- resolve_storage_options --- diff --git a/tests/stages/math_stages/classifiers/test_finemath_classifier.py 
b/tests/stages/math_stages/classifiers/test_finemath_classifier.py index 4af265a3d2..08efa71eea 100644 --- a/tests/stages/math_stages/classifiers/test_finemath_classifier.py +++ b/tests/stages/math_stages/classifiers/test_finemath_classifier.py @@ -60,7 +60,7 @@ def test_process_with_cropping(self) -> None: # Create test data with long text long_text = "0123456789ABCDEFGHIJ" # 20 characters, mid=10 df = pd.DataFrame({"text": [long_text, "short"]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -79,7 +79,7 @@ def test_process_no_cropping_needed(self) -> None: stage = CenterCropTextStage(center_crop_chars=100) df = pd.DataFrame({"text": ["Short text", "Another short text"]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -94,7 +94,7 @@ def test_process_zero_crop_chars(self) -> None: stage = CenterCropTextStage(center_crop_chars=0) df = pd.DataFrame({"text": ["Any text here"]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -106,7 +106,7 @@ def test_process_missing_text_field(self) -> None: stage = CenterCropTextStage(text_field="missing_field") df = pd.DataFrame({"other_field": ["Some text"]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -282,7 +282,6 @@ def math_dataset(self) -> DocumentBatch: df = pd.DataFrame({"text": text}) return DocumentBatch( data=df, - task_id="math_batch_1", dataset_name="math_test_1", ) @@ -309,7 +308,7 @@ def test_classifier_with_different_text_field(self) -> None: # Create dataset with different text field df = pd.DataFrame({"content": ["Mathematical equation: E = mc²"]}) - dataset = 
DocumentBatch(data=df, task_id="test", dataset_name="test") + dataset = DocumentBatch(data=df, dataset_name="test") # Check that input columns match input_columns = classifier.inputs()[1] @@ -322,7 +321,7 @@ def test_edge_case_empty_dataset(self) -> None: # Create empty dataset df = pd.DataFrame({"text": []}) - empty_dataset = DocumentBatch(data=df, task_id="empty", dataset_name="empty") + empty_dataset = DocumentBatch(data=df, dataset_name="empty") # Should still have correct input/output structure input_columns = classifier.inputs()[1] diff --git a/tests/stages/math_stages/download/test_extract_stage.py b/tests/stages/math_stages/download/test_extract_stage.py index 2a6d1ace98..ec689dea86 100644 --- a/tests/stages/math_stages/download/test_extract_stage.py +++ b/tests/stages/math_stages/download/test_extract_stage.py @@ -76,9 +76,7 @@ def test_process_content_types(self, url: str, expected_type: str, expected_text # Create input DataFrame with single record input_data = pd.DataFrame([{"binary_content": b"test content", "url": url, "mime_type": "test/type"}]) - input_task = DocumentBatch( - task_id="test_content_type", dataset_name="test_dataset", data=input_data, _metadata={} - ) + input_task = DocumentBatch(dataset_name="test_dataset", data=input_data, _metadata={}) result = stage.process(input_task) @@ -110,9 +108,7 @@ def test_process_with_extraction_failures(self) -> None: ] ) - input_task = DocumentBatch( - task_id="test_extraction_failures", dataset_name="test_dataset", data=input_data, _metadata={} - ) + input_task = DocumentBatch(dataset_name="test_dataset", data=input_data, _metadata={}) result = stage.process(input_task) @@ -164,9 +160,7 @@ def magic_side_effect(binary_content: bytes) -> str: ] ) - input_task = DocumentBatch( - task_id="test_magic_failures", dataset_name="test_dataset", data=input_data, _metadata={} - ) + input_task = DocumentBatch(dataset_name="test_dataset", data=input_data, _metadata={}) result = stage.process(input_task) @@ -190,14 
+184,11 @@ def test_process_empty_batch(self) -> None: stage = MathExtractStage(extractor=extractor, add_filename_column=False) input_data = pd.DataFrame() - input_task = DocumentBatch( - task_id="empty_task", dataset_name="test_dataset", data=input_data, _metadata={"source": "test"} - ) + input_task = DocumentBatch(dataset_name="test_dataset", data=input_data, _metadata={"source": "test"}) result = stage.process(input_task) assert isinstance(result, DocumentBatch) - assert result.task_id == "empty_task" assert result.dataset_name == "test_dataset" assert len(result.data) == 0 assert result._metadata == {"source": "test"} @@ -217,7 +208,7 @@ def test_stage_with_mock_extractor_smoke(self) -> None: [{"binary_content": b"test", "url": "http://example.com/test.html", "mime_type": "text/html"}] ) - input_task = DocumentBatch(task_id="smoke_test", dataset_name="test", data=input_data, _metadata={}) + input_task = DocumentBatch(dataset_name="test", data=input_data, _metadata={}) result = stage.process(input_task) @@ -245,9 +236,7 @@ def test_process_with_filename_column(self) -> None: ] ) - input_task = DocumentBatch( - task_id="test_with_filename", dataset_name="test_dataset", data=input_data, _metadata={} - ) + input_task = DocumentBatch(dataset_name="test_dataset", data=input_data, _metadata={}) # Mock only external system boundaries with mock.patch("magic.Magic") as mock_magic_class: @@ -281,9 +270,7 @@ def test_process_real_notebook_content(self, complex_notebook_json: str) -> None ] ) - input_task = DocumentBatch( - task_id="test_real_notebook", dataset_name="test_dataset", data=input_data, _metadata={} - ) + input_task = DocumentBatch(dataset_name="test_dataset", data=input_data, _metadata={}) # Only mock external system boundaries, not internal methods with mock.patch("magic.Magic") as mock_magic_class: @@ -325,9 +312,7 @@ def test_process_real_html_with_math(self, math_html: str) -> None: ] ) - input_task = DocumentBatch( - task_id="test_html_math", 
dataset_name="test_dataset", data=input_data, _metadata={} - ) + input_task = DocumentBatch(dataset_name="test_dataset", data=input_data, _metadata={}) # Mock external systems only - lynx and magic mock_lynx = mock.Mock() diff --git a/tests/stages/math_stages/download/test_lynx_extractor.py b/tests/stages/math_stages/download/test_lynx_extractor.py index 8366d511ee..920b0931b1 100644 --- a/tests/stages/math_stages/download/test_lynx_extractor.py +++ b/tests/stages/math_stages/download/test_lynx_extractor.py @@ -23,7 +23,9 @@ class TestLynxExtractor: @mock.patch("shutil.which", return_value="/usr/bin/lynx") @mock.patch("subprocess.run") - def test_lynx_extractor_extract_text_success(self, mock_run: mock.Mock, mock_which: mock.Mock, html_with_content: str) -> None: + def test_lynx_extractor_extract_text_success( + self, mock_run: mock.Mock, mock_which: mock.Mock, html_with_content: str + ) -> None: """Test successful lynx text extraction.""" # Mock successful subprocess call mock_process = mock.Mock() @@ -57,7 +59,9 @@ def test_lynx_extractor_extract_text_success(self, mock_run: mock.Mock, mock_whi @mock.patch("shutil.which", return_value="/usr/bin/lynx") @mock.patch("subprocess.run") - def test_lynx_extractor_extract_text_timeout(self, mock_run: mock.Mock, mock_which: mock.Mock, simple_html: str) -> None: + def test_lynx_extractor_extract_text_timeout( + self, mock_run: mock.Mock, mock_which: mock.Mock, simple_html: str + ) -> None: """Test LynxExtractor timeout handling.""" mock_run.side_effect = subprocess.TimeoutExpired(["lynx"], timeout=20) @@ -70,7 +74,9 @@ def test_lynx_extractor_extract_text_timeout(self, mock_run: mock.Mock, mock_whi @mock.patch("shutil.which", return_value="/usr/bin/lynx") @mock.patch("subprocess.run") - def test_lynx_extractor_extract_text_failure(self, mock_run: mock.Mock, mock_which: mock.Mock, simple_html: str) -> None: + def test_lynx_extractor_extract_text_failure( + self, mock_run: mock.Mock, mock_which: mock.Mock, simple_html: str + 
) -> None: """Test LynxExtractor when lynx returns non-zero exit code.""" mock_process = mock.Mock() mock_process.returncode = 1 @@ -96,7 +102,9 @@ def test_lynx_extractor_extract_text_empty_input(self, mock_run: mock.Mock, mock @mock.patch("shutil.which", return_value="/usr/bin/lynx") @mock.patch("subprocess.run") - def test_lynx_extractor_extract_text_decode_error(self, mock_run: mock.Mock, mock_which: mock.Mock, simple_html: str) -> None: + def test_lynx_extractor_extract_text_decode_error( + self, mock_run: mock.Mock, mock_which: mock.Mock, simple_html: str + ) -> None: """Test LynxExtractor with decode error handling.""" mock_process = mock.Mock() mock_process.returncode = 0 @@ -114,7 +122,9 @@ def test_lynx_extractor_extract_text_decode_error(self, mock_run: mock.Mock, moc @mock.patch("shutil.which", return_value="/usr/bin/lynx") @mock.patch("subprocess.run") - def test_lynx_extractor_extract_text_with_math_content(self, mock_run: mock.Mock, mock_which: mock.Mock, math_html: str) -> None: + def test_lynx_extractor_extract_text_with_math_content( + self, mock_run: mock.Mock, mock_which: mock.Mock, math_html: str + ) -> None: """Test LynxExtractor with mathematical content.""" # Simulate lynx extracting LaTeX/math content mock_process = mock.Mock() @@ -132,7 +142,9 @@ def test_lynx_extractor_extract_text_with_math_content(self, mock_run: mock.Mock @mock.patch("shutil.which", return_value="/usr/bin/lynx") @mock.patch("subprocess.run") - def test_lynx_extractor_subprocess_error(self, mock_run: mock.Mock, mock_which: mock.Mock, simple_html: str) -> None: + def test_lynx_extractor_subprocess_error( + self, mock_run: mock.Mock, mock_which: mock.Mock, simple_html: str + ) -> None: """Test LynxExtractor with subprocess error handling.""" mock_run.side_effect = subprocess.SubprocessError("Subprocess failed") diff --git a/tests/stages/math_stages/modifiers/test_chunking.py b/tests/stages/math_stages/modifiers/test_chunking.py index e3e3eeb839..a2848b58f4 100644 --- 
a/tests/stages/math_stages/modifiers/test_chunking.py +++ b/tests/stages/math_stages/modifiers/test_chunking.py @@ -54,7 +54,7 @@ def test_process_single_short_text(self): stage.setup() df = pd.DataFrame({"text": ["Short text here"]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -72,7 +72,7 @@ def test_process_text_with_chunking(self): # Create text with multiple paragraphs that will exceed token limit text = "Paragraph one.\n\nParagraph two.\n\nParagraph three.\n\nParagraph four." df = pd.DataFrame({"text": [text]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -97,7 +97,7 @@ def test_process_preserves_metadata(self): "metadata": ["extra info"], } ) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -119,7 +119,7 @@ def test_process_multiple_documents(self): "doc_id": [1, 2, 3], } ) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -134,7 +134,7 @@ def test_process_empty_text(self): stage.setup() df = pd.DataFrame({"text": [""]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -147,7 +147,7 @@ def test_process_text_with_only_whitespace(self): stage.setup() df = pd.DataFrame({"text": [" \n\n \n\n "]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -161,7 +161,7 @@ def test_process_custom_separator(self): text = "Line one\nLine two\nLine three\nLine four" df = pd.DataFrame({"text": [text]}) - batch = DocumentBatch(data=df, 
task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -180,7 +180,7 @@ def test_process_chunk_id_sequential(self): text = "Para one.\n\nPara two.\n\nPara three.\n\nPara four.\n\nPara five." df = pd.DataFrame({"text": [text]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -194,7 +194,7 @@ def test_process_n_tokens_calculated(self): stage.setup() df = pd.DataFrame({"text": ["Some text here"]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -212,7 +212,7 @@ def test_process_last_paragraph_no_separator(self): text = "First paragraph.\n\nSecond paragraph." df = pd.DataFrame({"text": [text]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -226,7 +226,7 @@ def test_process_missing_text_field(self): stage.setup() df = pd.DataFrame({"other_field": ["Some text"]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) diff --git a/tests/stages/math_stages/modifiers/test_llm_cleanup.py b/tests/stages/math_stages/modifiers/test_llm_cleanup.py index 54f82a7d5f..5f82ed21e9 100644 --- a/tests/stages/math_stages/modifiers/test_llm_cleanup.py +++ b/tests/stages/math_stages/modifiers/test_llm_cleanup.py @@ -174,7 +174,7 @@ def test_process_basic_cleanup(self): stage.setup() df = pd.DataFrame({"text": ["Original text here"]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -189,7 +189,7 @@ def test_process_classification_mode(self): stage.setup() df = pd.DataFrame({"text": ["Some text 
to classify"]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -203,7 +203,7 @@ def test_process_multiple_texts(self): stage.setup() df = pd.DataFrame({"text": ["First text", "Second text", "Third text"]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -224,7 +224,7 @@ def test_process_preserves_metadata(self): "metadata": ["extra info"], } ) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -239,7 +239,7 @@ def test_process_null_text(self): stage.setup() df = pd.DataFrame({"text": [None, pd.NA, "Valid text"]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -266,7 +266,7 @@ def test_process_filter_by_n_tokens(self): "n_tokens": [100, 900], # Second exceeds 80% of 1000 = 800 } ) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -289,7 +289,7 @@ def test_process_filter_by_n_tokens_all_filtered(self): "n_tokens": [900, 950], } ) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -313,7 +313,7 @@ def test_process_sort_by_n_tokens(self): "n_tokens": [300, 100, 200], } ) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) @@ -328,7 +328,7 @@ def test_process_prompt_formatting(self): stage.setup() df = pd.DataFrame({"text": ["Input text"]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = 
DocumentBatch(data=df, dataset_name="test") # Mock the generate method to capture prompts original_generate = stage._model.generate @@ -361,7 +361,7 @@ def failing_generate(prompts: list[str]) -> None: # noqa: ARG001 stage._model.generate = failing_generate df = pd.DataFrame({"text": ["Some text"]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") with pytest.raises(RuntimeError, match="LLM generation failed"): stage.process(batch) @@ -378,7 +378,7 @@ def empty_generate(prompts: list[str]) -> list[str]: # noqa: ARG001 stage._model.generate = empty_generate df = pd.DataFrame({"text": ["Some text"]}) - batch = DocumentBatch(data=df, task_id="test", dataset_name="test") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) diff --git a/tests/stages/math_stages/modifiers/test_merge_chunks.py b/tests/stages/math_stages/modifiers/test_merge_chunks.py index fe3fd1a349..d3ddf5e2ab 100644 --- a/tests/stages/math_stages/modifiers/test_merge_chunks.py +++ b/tests/stages/math_stages/modifiers/test_merge_chunks.py @@ -19,7 +19,7 @@ def _make_batch(df: pd.DataFrame) -> DocumentBatch: - return DocumentBatch(data=df, task_id="test", dataset_name="test") + return DocumentBatch(data=df, dataset_name="test") class TestChunkMergeStage: diff --git a/tests/stages/synthetic/nemo_data_designer/test_data_designer.py b/tests/stages/synthetic/nemo_data_designer/test_data_designer.py index a4dab1aebb..d41523032e 100644 --- a/tests/stages/synthetic/nemo_data_designer/test_data_designer.py +++ b/tests/stages/synthetic/nemo_data_designer/test_data_designer.py @@ -159,7 +159,6 @@ def test_process(self) -> None: batch = DocumentBatch( data=input_df, dataset_name="ds1", - task_id="task-1", _metadata=original_metadata, _stage_perf=original_stage_perf, ) @@ -174,7 +173,6 @@ def test_process(self) -> None: stage.data_designer.preview.assert_called_once_with(real_builder, num_records=1) assert 
isinstance(out_batch, DocumentBatch) - assert out_batch.task_id == "task-1" assert out_batch.dataset_name == "ds1" assert out_batch.data is output_df # Preserve metadata and stage_perf (same assertion style as video reader, URL generation, image convert) @@ -200,7 +198,6 @@ def test_process_preserves_metadata(self) -> None: batch = DocumentBatch( data=pd.DataFrame([{"text": "hello"}]), dataset_name="ds1", - task_id="task-1", _metadata=original_metadata, _stage_perf=original_stage_perf, ) @@ -220,12 +217,11 @@ def test_process_empty_batch(self) -> None: return_value=PreviewResults(config_builder=real_builder, dataset=output_df) ) - batch = DocumentBatch(data=pd.DataFrame(), dataset_name="ds", task_id="t1") + batch = DocumentBatch(data=pd.DataFrame(), dataset_name="ds") out_batch = stage.process(batch) stage.data_designer.preview.assert_called_once_with(real_builder, num_records=0) assert len(out_batch.data) == 0 - assert out_batch.task_id == "t1" def test_process_logs_metrics(self) -> None: """process logs ndd_running_time, num_input_records, num_output_records.""" @@ -239,7 +235,7 @@ def test_process_logs_metrics(self) -> None: return_value=PreviewResults(config_builder=real_builder, dataset=output_df) ) - batch = DocumentBatch(data=input_df, dataset_name="ds", task_id="t1") + batch = DocumentBatch(data=input_df, dataset_name="ds") stage.process(batch) assert hasattr(stage, "_custom_metrics") @@ -285,12 +281,10 @@ def test_process_with_mock_llm_endpoint(self, httpserver: pytest_httpserver.HTTP batch = DocumentBatch( data=pd.DataFrame([{"x": 1}]), dataset_name="ds", - task_id="t1", ) out_batch = stage.process(batch) assert isinstance(out_batch, DocumentBatch) - assert out_batch.task_id == "t1" assert out_batch.data is not None assert hasattr(stage, "_custom_metrics") assert "ndd_running_time" in stage._custom_metrics @@ -340,7 +334,6 @@ def test_pipeline_run_end_to_end(self, httpserver: pytest_httpserver.HTTPServer) DocumentBatch( data=pd.DataFrame([{"x": 1}]), 
dataset_name="integration", - task_id="e2e-1", ) ] executor = XennaExecutor(config={"execution_mode": "streaming"}) @@ -350,7 +343,6 @@ def test_pipeline_run_end_to_end(self, httpserver: pytest_httpserver.HTTPServer) assert len(result_tasks) == 1 out = result_tasks[0] assert isinstance(out, DocumentBatch) - assert out.task_id == "e2e-1" assert out.dataset_name == "integration" assert out.data is not None expected_rows = len(initial_tasks[0].data) diff --git a/tests/stages/synthetic/nemotron_cc/nemo_data_designer/test_base.py b/tests/stages/synthetic/nemotron_cc/nemo_data_designer/test_base.py index 16b6400661..685dc196ea 100644 --- a/tests/stages/synthetic/nemotron_cc/nemo_data_designer/test_base.py +++ b/tests/stages/synthetic/nemotron_cc/nemo_data_designer/test_base.py @@ -143,10 +143,12 @@ def test_process(self) -> None: stage = _make_stage() stage.setup() - output_df = pd.DataFrame([ - {"text": "a", _FORMATTED_PROMPT_COL: "Rephrase: a", "result": "out_a"}, - {"text": "b", _FORMATTED_PROMPT_COL: "Rephrase: b", "result": "out_b"}, - ]) + output_df = pd.DataFrame( + [ + {"text": "a", _FORMATTED_PROMPT_COL: "Rephrase: a", "result": "out_a"}, + {"text": "b", _FORMATTED_PROMPT_COL: "Rephrase: b", "result": "out_b"}, + ] + ) stage.data_designer.preview = MagicMock( return_value=PreviewResults(config_builder=stage.config_builder, dataset=output_df) ) @@ -156,14 +158,12 @@ def test_process(self) -> None: batch = DocumentBatch( data=pd.DataFrame([{"text": "a"}, {"text": "b"}]), dataset_name="ds", - task_id="t1", _metadata=original_metadata, _stage_perf=original_stage_perf, ) out = stage.process(batch) assert isinstance(out, DocumentBatch) - assert out.task_id == "t1_DataDesignerStage" assert out.dataset_name == "ds" assert out.data["result"].tolist() == ["out_a", "out_b"] assert _FORMATTED_PROMPT_COL not in out.data.columns @@ -179,7 +179,7 @@ def test_process_no_output_field_in_result(self) -> None: stage.data_designer.preview = MagicMock( 
return_value=PreviewResults(config_builder=stage.config_builder, dataset=output_df) ) - batch = DocumentBatch(data=pd.DataFrame([{"text": "hello"}]), dataset_name="ds", task_id="t1") + batch = DocumentBatch(data=pd.DataFrame([{"text": "hello"}]), dataset_name="ds") out = stage.process(batch) assert isinstance(out, DocumentBatch) @@ -215,7 +215,7 @@ def test_process_with_mock_llm_endpoint(self, httpserver: pytest_httpserver.HTTP ) stage.setup() - batch = DocumentBatch(data=pd.DataFrame([{"text": "hello"}]), dataset_name="ds", task_id="t1") + batch = DocumentBatch(data=pd.DataFrame([{"text": "hello"}]), dataset_name="ds") out = stage.process(batch) assert isinstance(out, DocumentBatch) @@ -267,9 +267,7 @@ def test_pipeline_run_end_to_end(self, httpserver: pytest_httpserver.HTTPServer) description="NDDBaseSyntheticStage via pipeline.run()", stages=[stage], ) - initial_tasks = [ - DocumentBatch(data=pd.DataFrame([{"text": "hello"}]), dataset_name="integration", task_id="e2e-1") - ] + initial_tasks = [DocumentBatch(data=pd.DataFrame([{"text": "hello"}]), dataset_name="integration")] executor = XennaExecutor(config={"execution_mode": "streaming"}) result_tasks = pipeline.run(executor, initial_tasks=initial_tasks) diff --git a/tests/stages/synthetic/nemotron_cc/nemo_data_designer/test_nemotron_cc.py b/tests/stages/synthetic/nemotron_cc/nemo_data_designer/test_nemotron_cc.py index d143da0d90..53b940cb35 100644 --- a/tests/stages/synthetic/nemotron_cc/nemo_data_designer/test_nemotron_cc.py +++ b/tests/stages/synthetic/nemotron_cc/nemo_data_designer/test_nemotron_cc.py @@ -94,16 +94,20 @@ def test_process_smoke(self, stage_cls: type, output_field: str) -> None: stage = _make_stage(stage_cls) stage.setup() - output_df = pd.DataFrame([{ - "text": "doc", - _FORMATTED_PROMPT_COL: "prompt", - output_field: "generated", - }]) + output_df = pd.DataFrame( + [ + { + "text": "doc", + _FORMATTED_PROMPT_COL: "prompt", + output_field: "generated", + } + ] + ) stage.data_designer.preview 
= MagicMock( return_value=PreviewResults(config_builder=stage.config_builder, dataset=output_df) ) - batch = DocumentBatch(data=pd.DataFrame([{"text": "doc"}]), dataset_name="ds", task_id="t1") + batch = DocumentBatch(data=pd.DataFrame([{"text": "doc"}]), dataset_name="ds") out = stage.process(batch) assert isinstance(out, DocumentBatch) @@ -137,7 +141,7 @@ def test_process_with_mock_llm_endpoint(self, httpserver: pytest_httpserver.HTTP ) stage.setup() - batch = DocumentBatch(data=pd.DataFrame([{"text": "hello"}]), dataset_name="ds", task_id="t1") + batch = DocumentBatch(data=pd.DataFrame([{"text": "hello"}]), dataset_name="ds") out = stage.process(batch) assert isinstance(out, DocumentBatch) @@ -192,9 +196,7 @@ def test_pipeline_run_end_to_end( description=f"{stage_cls.__name__} via pipeline.run()", stages=[stage], ) - initial_tasks = [ - DocumentBatch(data=pd.DataFrame([{"text": "hello"}]), dataset_name="integration", task_id="e2e-1") - ] + initial_tasks = [DocumentBatch(data=pd.DataFrame([{"text": "hello"}]), dataset_name="integration")] result_tasks = pipeline.run(XennaExecutor(config={"execution_mode": "streaming"}), initial_tasks=initial_tasks) assert len(result_tasks) == 1 @@ -205,9 +207,7 @@ def test_pipeline_run_end_to_end( assert _FORMATTED_PROMPT_COL not in out.data.columns assert len(out.data) == 1 - def test_pipeline_e2e_reader_ndd_writer( - self, httpserver: pytest_httpserver.HTTPServer, tmp_path: Path - ) -> None: + def test_pipeline_e2e_reader_ndd_writer(self, httpserver: pytest_httpserver.HTTPServer, tmp_path: Path) -> None: """JsonlReader -> WikipediaParaphrasingStage -> JsonlWriter. 
Verifies files, _metadata, _stage_perf.""" from nemo_curator.backends.xenna import XennaExecutor diff --git a/tests/stages/synthetic/nemotron_cc/test_base.py b/tests/stages/synthetic/nemotron_cc/test_base.py index 9be3f4ea74..55bff231f8 100644 --- a/tests/stages/synthetic/nemotron_cc/test_base.py +++ b/tests/stages/synthetic/nemotron_cc/test_base.py @@ -117,12 +117,11 @@ def test_process_sync_single_row_no_system_prompt() -> None: model_name="test-model", ) df = pd.DataFrame([{"text": "hello"}]) - batch = DocumentBatch(data=df, dataset_name="ds", task_id="t1") + batch = DocumentBatch(data=df, dataset_name="ds") out_batch = stage.process(batch) assert isinstance(out_batch, DocumentBatch) assert out_batch.dataset_name == "ds" - assert out_batch.task_id.endswith(stage.name) assert "out" in out_batch.data.columns assert out_batch.data["out"].iloc[0] == "resp-1" # Ensure user-only message when no system prompt @@ -144,7 +143,7 @@ def test_process_sync_with_system_prompt() -> None: model_name="test-model", ) df = pd.DataFrame([{"text": "abc"}]) - batch = DocumentBatch(data=df, dataset_name="ds", task_id="t2") + batch = DocumentBatch(data=df, dataset_name="ds") out_batch = stage.process(batch) assert out_batch.data["out"].iloc[0] == "ok" @@ -169,7 +168,7 @@ def test_process_async_multiple_rows() -> None: model_name="test-model", ) df = pd.DataFrame([{"text": "x"}, {"text": "y"}, {"text": "z"}]) - batch = DocumentBatch(data=df, dataset_name="ds", task_id="t3") + batch = DocumentBatch(data=df, dataset_name="ds") out_batch = stage.process(batch) assert len(out_batch.data) == 3 diff --git a/tests/stages/synthetic/nemotron_cc/test_nemotron_cc.py b/tests/stages/synthetic/nemotron_cc/test_nemotron_cc.py index 1537437d8a..42e063daa4 100644 --- a/tests/stages/synthetic/nemotron_cc/test_nemotron_cc.py +++ b/tests/stages/synthetic/nemotron_cc/test_nemotron_cc.py @@ -110,10 +110,11 @@ def test_diverseqa_post_processing_basic() -> None: pp = DiverseQAPostProcessingStage() 
generated_text = _build_diverseqa_response(pp.prefix) df = pd.DataFrame([{"text": "DOC", "diverse_qa": generated_text}]) - batch = DocumentBatch(data=df, dataset_name="ds", task_id="t0") + batch = DocumentBatch(data=df, dataset_name="ds") # Deterministic behavior: no shuffle and pick 2 pairs - with patch("nemo_curator.stages.synthetic.nemotron_cc.nemotron_cc.random.shuffle", lambda _: None), patch( - "nemo_curator.stages.synthetic.nemotron_cc.nemotron_cc.random.randint", return_value=2 + with ( + patch("nemo_curator.stages.synthetic.nemotron_cc.nemotron_cc.random.shuffle", lambda _: None), + patch("nemo_curator.stages.synthetic.nemotron_cc.nemotron_cc.random.randint", return_value=2), ): out_batch = pp.process(batch) out = out_batch.data["diverse_qa"].iloc[0] @@ -131,9 +132,10 @@ def test_diverseqa_sync_end_to_end() -> None: ) pp = DiverseQAPostProcessingStage() df = pd.DataFrame([{"text": "DOC"}]) - batch = DocumentBatch(data=df, dataset_name="ds", task_id="t1") - with patch("nemo_curator.stages.synthetic.nemotron_cc.nemotron_cc.random.shuffle", lambda _: None), patch( - "nemo_curator.stages.synthetic.nemotron_cc.nemotron_cc.random.randint", return_value=1 + batch = DocumentBatch(data=df, dataset_name="ds") + with ( + patch("nemo_curator.stages.synthetic.nemotron_cc.nemotron_cc.random.shuffle", lambda _: None), + patch("nemo_curator.stages.synthetic.nemotron_cc.nemotron_cc.random.randint", return_value=1), ): raw_batch = stage.process(batch) out_batch = pp.process(raw_batch) @@ -153,9 +155,10 @@ def test_diverseqa_async_multiple_rows() -> None: ) pp = DiverseQAPostProcessingStage() df = pd.DataFrame([{"text": "D1"}, {"text": "D2"}, {"text": "D3"}]) - batch = DocumentBatch(data=df, dataset_name="ds", task_id="t2") - with patch("nemo_curator.stages.synthetic.nemotron_cc.nemotron_cc.random.shuffle", lambda _: None), patch( - "nemo_curator.stages.synthetic.nemotron_cc.nemotron_cc.random.randint", return_value=1 + batch = DocumentBatch(data=df, dataset_name="ds") + 
with ( + patch("nemo_curator.stages.synthetic.nemotron_cc.nemotron_cc.random.shuffle", lambda _: None), + patch("nemo_curator.stages.synthetic.nemotron_cc.nemotron_cc.random.randint", return_value=1), ): raw_batch = stage.process(batch) out_batch = pp.process(raw_batch) @@ -168,14 +171,9 @@ def test_diverseqa_async_multiple_rows() -> None: def test_knowledge_list_process_llm_response() -> None: pp = KnowledgeListPostProcessingStage() # First line not starting with "-" should be skipped - generated = ( - "Header line\n" - "- item one\n" - " continuation\n" - "- item two" - ) + generated = "Header line\n- item one\n continuation\n- item two" df = pd.DataFrame([{"knowledge_list": generated}]) - batch = DocumentBatch(data=df, dataset_name="ds", task_id="tkl") + batch = DocumentBatch(data=df, dataset_name="ds") out_batch = pp.process(batch) assert out_batch.data["knowledge_list"].iloc[0] == "item one\ncontinuation\nitem two" @@ -184,7 +182,7 @@ def test_wikipedia_paraphrasing_smoke() -> None: client = MockSyncLLMClient(responses=[["rephrased"]]) stage = WikipediaParaphrasingStage(client=client, model_name="m") df = pd.DataFrame([{"text": "original"}]) - batch = DocumentBatch(data=df, dataset_name="ds", task_id="t3") + batch = DocumentBatch(data=df, dataset_name="ds") out_batch = stage.process(batch) assert out_batch.data["rephrased"].iloc[0] == "rephrased" @@ -193,7 +191,7 @@ def test_distill_stage_smoke() -> None: client = MockSyncLLMClient(responses=[["distilled"]]) stage = DistillStage(client=client, model_name="m") df = pd.DataFrame([{"text": "doc"}]) - batch = DocumentBatch(data=df, dataset_name="ds", task_id="t4") + batch = DocumentBatch(data=df, dataset_name="ds") out_batch = stage.process(batch) assert out_batch.data["distill"].iloc[0] == "distilled" # Ensure system prompt is present in messages @@ -205,7 +203,7 @@ def test_extract_knowledge_stage_smoke() -> None: client = MockSyncLLMClient(responses=[["facts"]]) stage = ExtractKnowledgeStage(client=client, 
model_name="m") df = pd.DataFrame([{"text": "doc"}]) - batch = DocumentBatch(data=df, dataset_name="ds", task_id="t5") + batch = DocumentBatch(data=df, dataset_name="ds") out_batch = stage.process(batch) assert out_batch.data["extract_knowledge"].iloc[0] == "facts" @@ -216,6 +214,6 @@ def test_knowledge_list_stage_smoke() -> None: client = MockSyncLLMClient(responses=[[generated]]) stage = KnowledgeListStage(client=client, model_name="m") df = pd.DataFrame([{"text": "doc"}]) - batch = DocumentBatch(data=df, dataset_name="ds", task_id="t6") + batch = DocumentBatch(data=df, dataset_name="ds") out_batch = stage.process(batch) assert out_batch.data["knowledge_list"].iloc[0] == generated diff --git a/tests/stages/synthetic/test_qa_multilingual_synthetic.py b/tests/stages/synthetic/test_qa_multilingual_synthetic.py index 3f9e212c66..2e218dc086 100644 --- a/tests/stages/synthetic/test_qa_multilingual_synthetic.py +++ b/tests/stages/synthetic/test_qa_multilingual_synthetic.py @@ -126,12 +126,11 @@ def test_process_sync_single_sample(self) -> None: num_samples=1, ) - task = _EmptyTask(task_id="test", dataset_name="test", data=None) + task = _EmptyTask(dataset_name="test", data=None) result = stage.process(task) assert isinstance(result, DocumentBatch) assert result.dataset_name == "simple_synthetic_data" - assert result.task_id == 1 assert isinstance(result.data, pd.DataFrame) assert len(result.data) == 1 assert "text" in result.data.columns @@ -150,7 +149,7 @@ def test_process_sync_multiple_samples(self) -> None: num_samples=3, ) - task = _EmptyTask(task_id="test", dataset_name="test", data=None) + task = _EmptyTask(dataset_name="test", data=None) with patch("nemo_curator.models.client.llm_client.logger"): # Suppress log statements result = stage.process(task) @@ -170,7 +169,7 @@ def test_process_async_single_sample(self) -> None: num_samples=1, ) - task = _EmptyTask(task_id="test", dataset_name="test", data=None) + task = _EmptyTask(dataset_name="test", data=None) result 
= stage.process(task) assert isinstance(result, DocumentBatch) @@ -190,7 +189,7 @@ def test_process_async_multiple_samples(self) -> None: num_samples=5, ) - task = _EmptyTask(task_id="test", dataset_name="test", data=None) + task = _EmptyTask(dataset_name="test", data=None) result = stage.process(task) assert isinstance(result, DocumentBatch) @@ -212,7 +211,7 @@ def test_process_sync_with_generation_config(self) -> None: generation_config=config, ) - task = _EmptyTask(task_id="test", dataset_name="test", data=None) + task = _EmptyTask(dataset_name="test", data=None) with patch("nemo_curator.models.client.llm_client.logger"): result = stage.process(task) @@ -241,7 +240,7 @@ def query_model(self, *, messages: Iterable, model: str, **kwargs: object) -> li ) with patch("nemo_curator.models.client.llm_client.logger"), patch("secrets.choice", return_value="Japanese"): - task = _EmptyTask(task_id="test", dataset_name="test", data=None) + task = _EmptyTask(dataset_name="test", data=None) stage.process(task) assert len(captured_prompts) == 2 @@ -254,7 +253,7 @@ def test_process_sync_with_response_asterisks(self) -> None: prompt="Test {language}", languages=["English"], client=client, model_name="test-model", num_samples=1 ) - task = _EmptyTask(task_id="test", dataset_name="test", data=None) + task = _EmptyTask(dataset_name="test", data=None) with patch("nemo_curator.models.client.llm_client.logger"): result = stage.process(task) @@ -267,7 +266,7 @@ def test_process_async_with_response_asterisks(self) -> None: prompt="Test {language}", languages=["English"], client=client, model_name="test-model", num_samples=1 ) - task = _EmptyTask(task_id="test", dataset_name="test", data=None) + task = _EmptyTask(dataset_name="test", data=None) result = stage.process(task) assert result.data["text"].iloc[0] == "Another styled response" @@ -298,7 +297,7 @@ def query_model(self, *, messages: Iterable, model: str, **kwargs: object) -> li ) with 
patch("nemo_curator.models.client.llm_client.logger"): - task = _EmptyTask(task_id="test", dataset_name="test", data=None) + task = _EmptyTask(dataset_name="test", data=None) stage.process(task) # Should have captured 10 languages diff --git a/tests/stages/text/classifiers/test_classifiers.py b/tests/stages/text/classifiers/test_classifiers.py index fb8ebe8ffa..d67135312c 100644 --- a/tests/stages/text/classifiers/test_classifiers.py +++ b/tests/stages/text/classifiers/test_classifiers.py @@ -47,7 +47,6 @@ def domain_dataset() -> DocumentBatch: df = pd.DataFrame({"text": text}) return DocumentBatch( data=df, - task_id="batch_1", dataset_name="test_1", ) @@ -132,7 +131,6 @@ def test_quality_classifier() -> None: df = pd.DataFrame({"text": text}) input_dataset = DocumentBatch( data=df, - task_id="batch_1", dataset_name="test_1", ) @@ -173,7 +171,6 @@ def test_aegis_classifier(aegis_variant: str, filter_by: list[str] | None) -> No df = pd.DataFrame({"text": text}) input_dataset = DocumentBatch( data=df, - task_id="batch_1", dataset_name="test_1", ) @@ -324,7 +321,6 @@ def test_instruction_data_guard_classifier(filter_by: list[str] | None) -> None: df = pd.DataFrame({"text": text}) input_dataset = DocumentBatch( data=df, - task_id="batch_1", dataset_name="test_1", ) @@ -366,7 +362,6 @@ def test_multilingual_domain_classifier() -> None: df = pd.DataFrame({"text": text}) input_dataset = DocumentBatch( data=df, - task_id="batch_1", dataset_name="test_1", ) @@ -389,7 +384,6 @@ def test_content_type_classifier() -> None: df = pd.DataFrame({"text": text}) input_dataset = DocumentBatch( data=df, - task_id="batch_1", dataset_name="test_1", ) @@ -413,7 +407,6 @@ def test_prompt_task_complexity_classifier(filter_by: list[str] | None) -> None: df = pd.DataFrame({"text": text}) input_dataset = DocumentBatch( data=df, - task_id="batch_1", dataset_name="test_1", ) diff --git a/tests/stages/text/classifiers/test_utils.py b/tests/stages/text/classifiers/test_utils.py index 
a1663eb290..a5af2ecede 100644 --- a/tests/stages/text/classifiers/test_utils.py +++ b/tests/stages/text/classifiers/test_utils.py @@ -22,7 +22,6 @@ class TestSortByLengthStage: def test_process(self): batch = DocumentBatch( - task_id="test", dataset_name="test", data=pd.DataFrame({"attention_mask": [[1, 1, 1, 1, 0], [1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]}), ) @@ -33,7 +32,6 @@ def test_process(self): def test_process_no_op(self): batch = DocumentBatch( - task_id="test", dataset_name="test", data=pd.DataFrame( # Set the SEQ_ORDER_FIELD to a random order diff --git a/tests/stages/text/deduplication/test_removal.py b/tests/stages/text/deduplication/test_removal.py index df029d16d6..1d1fd24a32 100644 --- a/tests/stages/text/deduplication/test_removal.py +++ b/tests/stages/text/deduplication/test_removal.py @@ -40,7 +40,6 @@ def sample_document_batch(self) -> DocumentBatch: } ) return DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, _metadata={"source": "test"}, @@ -82,7 +81,6 @@ def test_process_removes_duplicates( # Verify result assert isinstance(result, DocumentBatch) - assert result.task_id == "removal_test_batch" assert result.dataset_name == "test_dataset" result_df = result.to_pandas() @@ -162,7 +160,6 @@ def test_process_with_custom_id_fields(self, mock_read_parquet: MagicMock, remov } ) doc_batch = DocumentBatch( - task_id="custom_test", dataset_name="test", data=df, ) diff --git a/tests/stages/text/deduplication/test_removal_workflow.py b/tests/stages/text/deduplication/test_removal_workflow.py index 09f8b0163e..2f87ce2e5b 100644 --- a/tests/stages/text/deduplication/test_removal_workflow.py +++ b/tests/stages/text/deduplication/test_removal_workflow.py @@ -228,7 +228,7 @@ def test_initial_tasks_partitioning(self, test_config: "TestTextDuplicateRemoval initial_tasks = [] for i in range(0, len(test_config.input_file_paths), 5): task_files = test_config.input_file_paths[i : i + 5] - 
initial_tasks.append(FileGroupTask(task_id=f"file_group_{i // 5}", dataset_name="input", data=task_files)) + initial_tasks.append(FileGroupTask(dataset_name="input", data=task_files)) assert len(initial_tasks) == 20 # 100 files / 5 per group = 20 tasks diff --git a/tests/stages/text/download/arxiv/test_download.py b/tests/stages/text/download/arxiv/test_download.py index 1e1ee69ee9..15a8993d53 100644 --- a/tests/stages/text/download/arxiv/test_download.py +++ b/tests/stages/text/download/arxiv/test_download.py @@ -38,7 +38,9 @@ class TestArxivDownloader: @mock.patch("nemo_curator.stages.text.download.arxiv.download.check_s5cmd_installed", return_value=True) @mock.patch("subprocess.run", return_value=mock.Mock(returncode=0)) @pytest.mark.parametrize("verbose", [True, False]) - def test_download_to_path(self, mock_run: mock.Mock, mock_s5cmd_check: mock.Mock, tmp_path: Path, verbose: bool) -> None: + def test_download_to_path( + self, mock_run: mock.Mock, mock_s5cmd_check: mock.Mock, tmp_path: Path, verbose: bool + ) -> None: """Test _download_to_path with s5cmd.""" downloader = ArxivDownloader(str(tmp_path), verbose=verbose) diff --git a/tests/stages/text/download/base/test_download.py b/tests/stages/text/download/base/test_download.py index eb02618d62..5197e38bde 100644 --- a/tests/stages/text/download/base/test_download.py +++ b/tests/stages/text/download/base/test_download.py @@ -209,7 +209,6 @@ def test_process_successful_downloads(self, tmp_path: Path) -> None: # Create input task with multiple URLs urls = ["http://example.com/file1.txt", "http://example.com/file2.txt", "http://example.com/file3.txt"] input_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", data=urls, _metadata={"source": "test", "count": 3}, @@ -219,7 +218,6 @@ def test_process_successful_downloads(self, tmp_path: Path) -> None: # Verify result structure assert isinstance(result, FileGroupTask) - assert result.task_id == "test_task" assert result.dataset_name == 
"test_dataset" assert result._metadata == { "source": "test", @@ -255,7 +253,6 @@ def side_effect(url: str) -> str | None: "http://example.com/file3.txt", ] input_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", data=urls, _metadata={"source": "test"}, @@ -277,7 +274,6 @@ def test_process_empty_file_group(self, tmp_path: Path) -> None: stage = DocumentDownloadStage(downloader=downloader) input_task = FileGroupTask( - task_id="empty_task", dataset_name="test_dataset", data=[], _metadata={"source": "test"}, @@ -286,7 +282,6 @@ def test_process_empty_file_group(self, tmp_path: Path) -> None: result = stage.process(input_task) assert isinstance(result, FileGroupTask) - assert result.task_id == "empty_task" assert result.dataset_name == "test_dataset" assert result.data == [] assert result._metadata == {"source": "test", "source_files": []} @@ -299,7 +294,6 @@ def test_process_all_downloads_fail(self, mock_download: mock.Mock, tmp_path: Pa urls = ["http://example.com/file1.txt", "http://example.com/file2.txt"] input_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", data=urls, _metadata={"source": "test"}, @@ -309,7 +303,6 @@ def test_process_all_downloads_fail(self, mock_download: mock.Mock, tmp_path: Pa # Should return empty data list when all downloads fail assert len(result.data) == 0 - assert result.task_id == "test_task" assert result.dataset_name == "test_dataset" assert result._metadata == {"source": "test", "source_files": []} diff --git a/tests/stages/text/download/base/test_iterator.py b/tests/stages/text/download/base/test_iterator.py index cfc8ca50c2..518a765cb0 100644 --- a/tests/stages/text/download/base/test_iterator.py +++ b/tests/stages/text/download/base/test_iterator.py @@ -164,7 +164,6 @@ def test_process_successful_iteration(self, tmp_path: Path) -> None: # Create input task input_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", data=[str(file1), str(file2)], 
_metadata={"source": "test"}, @@ -174,7 +173,6 @@ def test_process_successful_iteration(self, tmp_path: Path) -> None: # Verify result structure assert isinstance(result, DocumentBatch) - assert result.task_id == "test_task" assert result.dataset_name == "test_dataset" assert result._metadata == {"source": "test"} @@ -210,7 +208,6 @@ def test_process_successful_extraction(self, tmp_path: Path) -> None: # Create input task input_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", data=[str(file1), str(file2), str(file3)], _metadata={"source": "test"}, @@ -220,7 +217,6 @@ def test_process_successful_extraction(self, tmp_path: Path) -> None: # Verify result structure assert isinstance(result, DocumentBatch) - assert result.task_id == "test_task" assert result.dataset_name == "test_dataset" assert result._metadata == {"source": "test"} @@ -254,7 +250,6 @@ def test_process_with_record_limit(self, tmp_path: Path) -> None: test_file.write_text("content") input_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", data=[str(test_file)], _metadata={}, @@ -283,7 +278,6 @@ def test_process_with_filtered_records(self, tmp_path: Path) -> None: # Create input task input_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", data=[str(file1), str(file2), str(file3)], _metadata={"source": "test"}, @@ -308,7 +302,6 @@ def test_process_iterate_without_filename_column(self, tmp_path: Path) -> None: test_file.write_text("content") input_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", data=[str(test_file)], _metadata={}, @@ -332,7 +325,6 @@ def test_process_extract_without_filename_column(self, tmp_path: Path) -> None: test_file.write_text("content") input_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", data=[str(test_file)], _metadata={}, @@ -356,7 +348,6 @@ def test_process_iterate_with_custom_filename_column(self, tmp_path: Path) -> No test_file.write_text("content") 
input_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", data=[str(test_file)], _metadata={}, @@ -380,7 +371,6 @@ def test_process_extract_with_custom_filename_column(self, tmp_path: Path) -> No test_file.write_text("content") input_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", data=[str(test_file)], _metadata={}, @@ -399,7 +389,6 @@ def test_process_empty_file_group(self) -> None: stage = DocumentIterateExtractStage(iterator=iterator) input_task = FileGroupTask( - task_id="empty_task", dataset_name="test_dataset", data=[], _metadata={"source": "test"}, @@ -408,7 +397,6 @@ def test_process_empty_file_group(self) -> None: result = stage.process(input_task) assert isinstance(result, DocumentBatch) - assert result.task_id == "empty_task" assert result.dataset_name == "test_dataset" assert len(result.data) == 0 assert result._metadata == {"source": "test"} @@ -424,7 +412,6 @@ def test_process_iterator_returns_none(self, mock_iterate: mock.Mock, tmp_path: test_file.write_text("content") input_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", data=[str(test_file)], _metadata={}, @@ -450,7 +437,6 @@ def test_process_all_records_filtered(self, tmp_path: Path) -> None: file2.write_text("world_skip") input_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", data=[str(file1), str(file2)], _metadata={}, @@ -460,7 +446,6 @@ def test_process_all_records_filtered(self, tmp_path: Path) -> None: # Should return empty DataFrame when all records are filtered assert len(result.data) == 0 - assert result.task_id == "test_task" assert result.dataset_name == "test_dataset" def test_process_with_file_errors(self, tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None: @@ -477,7 +462,6 @@ def test_process_with_file_errors(self, tmp_path: Path, caplog: pytest.LogCaptur error_file.write_text("content") input_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", 
data=[str(good_file), str(error_file)], _metadata={}, @@ -506,7 +490,6 @@ def test_process_all_files_fail(self, tmp_path: Path, caplog: pytest.LogCaptureF error_file.write_text("content") input_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", data=[str(error_file)], _metadata={}, @@ -516,7 +499,6 @@ def test_process_all_files_fail(self, tmp_path: Path, caplog: pytest.LogCaptureF # Should return empty DataFrame when all files fail assert len(result.data) == 0 - assert result.task_id == "test_task" assert result.dataset_name == "test_dataset" # Check that error was logged diff --git a/tests/stages/text/download/base/test_url_generation.py b/tests/stages/text/download/base/test_url_generation.py index 20bf1e1354..f6fb5c8b0e 100644 --- a/tests/stages/text/download/base/test_url_generation.py +++ b/tests/stages/text/download/base/test_url_generation.py @@ -117,7 +117,6 @@ def test_process_successful_generation(self) -> None: # Create input task input_task = _EmptyTask( - task_id="test_task", dataset_name="test_dataset", data=None, _metadata={"source": "test"}, @@ -132,7 +131,6 @@ def test_process_successful_generation(self) -> None: # Check each generated task for i, task in enumerate(result): assert isinstance(task, FileGroupTask) - assert task.task_id == f"test_task_{i}" assert task.dataset_name == "test_dataset" assert task.data == [urls[i]] assert task._metadata == {"source_url": urls[i]} @@ -150,7 +148,6 @@ def test_process_with_limit(self) -> None: stage = URLGenerationStage(url_generator=generator, limit=3) input_task = _EmptyTask( - task_id="test_task", dataset_name="test_dataset", data=None, _metadata={}, @@ -171,7 +168,6 @@ def test_process_empty_url_list(self) -> None: stage = URLGenerationStage(url_generator=generator) input_task = _EmptyTask( - task_id="test_task", dataset_name="test_dataset", data=None, _metadata={}, @@ -189,7 +185,6 @@ def test_process_limit_larger_than_urls(self) -> None: stage = 
URLGenerationStage(url_generator=generator, limit=10) input_task = _EmptyTask( - task_id="test_task", dataset_name="test_dataset", data=None, _metadata={}, @@ -206,7 +201,6 @@ def test_process_limit_zero(self) -> None: stage = URLGenerationStage(url_generator=generator, limit=0) input_task = _EmptyTask( - task_id="test_task", dataset_name="test_dataset", data=None, _metadata={}, @@ -226,7 +220,6 @@ def test_process_generation_failure(self, mock_generate: mock.Mock) -> None: stage = URLGenerationStage(url_generator=generator) input_task = _EmptyTask( - task_id="test_task", dataset_name="test_dataset", data=None, _metadata={}, @@ -243,7 +236,6 @@ def test_process_task_metadata_propagation(self) -> None: stage = URLGenerationStage(url_generator=generator) input_task = _EmptyTask( - task_id="test_task", dataset_name="test_dataset", data=None, _metadata={"original": "metadata"}, @@ -263,7 +255,6 @@ def test_process_single_url_per_task(self) -> None: stage = URLGenerationStage(url_generator=generator) input_task = _EmptyTask( - task_id="test_task", dataset_name="test_dataset", data=None, _metadata={}, @@ -287,17 +278,14 @@ def test_process_task_id_generation(self) -> None: stage = URLGenerationStage(url_generator=generator) input_task = _EmptyTask( - task_id="parent_task", dataset_name="test_dataset", data=None, _metadata={}, ) - result = stage.process(input_task) + stage.process(input_task) # Check task ID generation - assert result[0].task_id == "parent_task_0" - assert result[1].task_id == "parent_task_1" def test_process_metadata_per_task(self) -> None: """Test that each task gets correct source URL metadata.""" @@ -306,7 +294,6 @@ def test_process_metadata_per_task(self) -> None: stage = URLGenerationStage(url_generator=generator) input_task = _EmptyTask( - task_id="test_task", dataset_name="test_dataset", data=None, _metadata={}, @@ -325,7 +312,6 @@ def test_process_with_no_urls_generated(self, mock_generate: mock.Mock) -> None: stage = 
URLGenerationStage(url_generator=generator) input_task = _EmptyTask( - task_id="test_task", dataset_name="test_dataset", data=None, _metadata={}, diff --git a/tests/stages/text/embedders/test_base.py b/tests/stages/text/embedders/test_base.py index dad61f3e84..d387fc8e6d 100644 --- a/tests/stages/text/embedders/test_base.py +++ b/tests/stages/text/embedders/test_base.py @@ -134,7 +134,6 @@ def test_process_end_to_end(self, mock_auto_model: Mock, pooling_strategy: str) # Create sample data with tokenized inputs sample_data = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=pd.DataFrame( { @@ -212,7 +211,7 @@ def sample_data(self) -> DocumentBatch: """Create sample text data for testing.""" texts = ["Hello world", "Test text"] data = pd.DataFrame({"text": texts}) - return DocumentBatch(task_id="test_batch", dataset_name="test_dataset", data=data) + return DocumentBatch(dataset_name="test_dataset", data=data) def test_embedding_creator_stage_initialization_and_decomposition(self) -> None: """Test initialization, decomposition, and parameter passing to decomposed stages.""" diff --git a/tests/stages/text/embedders/test_vllm.py b/tests/stages/text/embedders/test_vllm.py index f4f8792bf0..e63eefe98d 100644 --- a/tests/stages/text/embedders/test_vllm.py +++ b/tests/stages/text/embedders/test_vllm.py @@ -38,7 +38,7 @@ def sample_data() -> DocumentBatch: """Create sample text data for testing.""" texts = ["Hello world", "This is a test", "Machine learning is great"] data = pd.DataFrame({"text": texts}) - return DocumentBatch(task_id="test_batch", dataset_name="test_dataset", data=data) + return DocumentBatch(dataset_name="test_dataset", data=data) @pytest.fixture(scope="module") diff --git a/tests/stages/text/experimental/translation/conftest.py b/tests/stages/text/experimental/translation/conftest.py index ae18a58496..74a38053f4 100644 --- a/tests/stages/text/experimental/translation/conftest.py +++ 
b/tests/stages/text/experimental/translation/conftest.py @@ -95,7 +95,7 @@ def sample_batch() -> DocumentBatch: "id": [1, 2], } ) - return DocumentBatch(data=df, dataset_name="test", task_id="1") + return DocumentBatch(data=df, dataset_name="test") @pytest.fixture @@ -127,7 +127,7 @@ def messages_batch() -> DocumentBatch: "id": [10, 20], } ) - return DocumentBatch(data=df, dataset_name="messages-test", task_id="1") + return DocumentBatch(data=df, dataset_name="messages-test") @pytest.fixture @@ -155,4 +155,4 @@ def batch_with_existing_translations() -> DocumentBatch: "id": [100, 200, 300], } ) - return DocumentBatch(data=df, dataset_name="resume-test", task_id="1") + return DocumentBatch(data=df, dataset_name="resume-test") diff --git a/tests/stages/text/experimental/translation/test_pipeline.py b/tests/stages/text/experimental/translation/test_pipeline.py index 346641d6d4..dd60e71dd1 100644 --- a/tests/stages/text/experimental/translation/test_pipeline.py +++ b/tests/stages/text/experimental/translation/test_pipeline.py @@ -267,7 +267,7 @@ def test_composite_stage_process_raises(self, mock_client: MockAsyncLLMClient) - model_name="m", ) df = pd.DataFrame({"text": ["hello"]}) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") with pytest.raises(RuntimeError, match="should not be executed directly"): pipeline.process(batch) @@ -408,7 +408,7 @@ def test_filter_process_drops_low_scores(self, mock_client: MockAsyncLLMClient) "translated_text": ["Hallo Welt."], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -432,7 +432,7 @@ def test_filter_process_keeps_high_scores(self, mock_client: MockAsyncLLMClient) "translated_text": ["Hallo Welt.", "Zweites Dok."], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, 
dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -453,7 +453,7 @@ def test_filter_process_empty_batch(self, mock_client: MockAsyncLLMClient) -> No target_lang="de", ) df = pd.DataFrame({"text": pd.Series(dtype="str"), "translated_text": pd.Series(dtype="str")}) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) assert result.to_pandas().empty @@ -523,7 +523,7 @@ def test_full_e2e_with_faith_eval(self, mock_client: MockAsyncLLMClient) -> None "id": [1, 2], } ) - batch = DocumentBatch(data=df, dataset_name="e2e-test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="e2e-test") pipeline = TranslationStage( source_lang="en", @@ -566,7 +566,7 @@ def test_full_e2e_empty_segment_not_sent_to_llm(self, mock_client: MockAsyncLLMC "id": [10, 20], } ) - batch = DocumentBatch(data=df, dataset_name="e2e-empty-test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="e2e-empty-test") pipeline = TranslationStage( source_lang="en", @@ -598,7 +598,7 @@ def test_full_e2e_with_non_contiguous_index(self, mock_client: MockAsyncLLMClien }, index=[5, 10, 15], # Non-contiguous index ) - batch = DocumentBatch(data=df, dataset_name="e2e-index-test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="e2e-index-test") pipeline = TranslationStage( source_lang="en", @@ -647,7 +647,7 @@ def test_faith_eval_score_without_filtering(self, mock_client: MockAsyncLLMClien "translated_text": ["Hallo Welt.", "Zweites Dok."], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -677,7 +677,7 @@ def test_faith_eval_filter_enabled_true_drops_rows(self, mock_client: MockAsyncL "translated_text": ["Hallo."], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, 
dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -737,7 +737,7 @@ def test_dry_run_returns_empty_translations(self, mock_client: MockAsyncLLMClien "id": [1, 2, 3], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -761,7 +761,7 @@ def test_dry_run_produces_timing_columns(self, mock_client: MockAsyncLLMClient) stage._initialized = True df = pd.DataFrame({"_seg_segments": ["Hello"], "id": [1]}) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -839,7 +839,7 @@ def test_skip_translated_all_rows_already_translated(self, mock_client: MockAsyn "translated_text": ["Bereits uebersetzt", "Auch bereits uebersetzt"], } ) - batch = DocumentBatch(data=df, dataset_name="resume-test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="resume-test") pipeline = TranslationStage( source_lang="en", target_lang="de", @@ -868,7 +868,7 @@ def test_merge_skipped_reads_batch_metadata(self) -> None: "translated_text": ["Bereits uebersetzt", "", ""], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") skip_stage = SkipExistingTranslationsStage() skipped_batch = skip_stage.process(batch) @@ -878,7 +878,6 @@ def test_merge_skipped_reads_batch_metadata(self) -> None: translated_batch = DocumentBatch( data=remaining_df, dataset_name=skipped_batch.dataset_name, - task_id=skipped_batch.task_id, _metadata=skipped_batch._metadata, ) @@ -907,7 +906,7 @@ def test_output_mode_both(self, mock_client: MockAsyncLLMClient) -> None: output_mode="both", ) df = pd.DataFrame({"text": ["Hello world."], "id": [1]}) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, 
dataset_name="test") stages = pipeline.decompose() result = batch @@ -950,7 +949,7 @@ def test_partial_failure_does_not_crash(self, mock_client: MockAsyncLLMClient) - "id": [1, 2, 3], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -976,7 +975,7 @@ def test_filter_stage_drops_low_scores(self) -> None: "faith_parse_failed": [False, False], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -1010,7 +1009,7 @@ def test_reassembly_aggregates_segment_scores(self) -> None: "faith_parse_failed": [False, True], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") stage = ReassemblyStage(aggregate_faith_scores=True) result = stage.process(batch).to_pandas() @@ -1036,7 +1035,7 @@ def test_filter_stage_keeps_parse_failed_rows(self) -> None: "faith_parse_failed": [True, False], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -1054,7 +1053,7 @@ def test_filter_stage_keeps_not_scored_rows(self) -> None: "faith_segment_scores": ["[]", '[{"Fluency": 3.0}]'], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -1085,7 +1084,7 @@ def test_raw_mode_creates_metadata_drops_translated(self) -> None: "id": [1], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -1111,7 +1110,7 @@ def test_both_mode_keeps_both_columns(self) -> None: 
"id": [1], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -1133,7 +1132,7 @@ def test_raw_mode_uses_reassembly_helper_maps(self) -> None: "_segmented_translation_map": [json.dumps({"question": [{"src": "Hello", "tgt": "Hallo"}]})], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -1158,7 +1157,7 @@ def test_replaced_mode_no_metadata(self) -> None: "id": [1], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -1190,7 +1189,7 @@ def test_merge_scores_into_metadata(self) -> None: "faith_avg": [4.2], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -1205,7 +1204,7 @@ def test_merge_scores_no_faith_columns(self) -> None: metadata = json.dumps({"target_lang": "de"}) df = pd.DataFrame({"translation_metadata": [metadata], "text": ["Hello"]}) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) # Should return the batch unmodified assert result.to_pandas()["translation_metadata"].iloc[0] == metadata diff --git a/tests/stages/text/experimental/translation/test_reassembly.py b/tests/stages/text/experimental/translation/test_reassembly.py index d711e10c9d..4f435a035f 100644 --- a/tests/stages/text/experimental/translation/test_reassembly.py +++ b/tests/stages/text/experimental/translation/test_reassembly.py @@ -42,7 +42,7 @@ def _make_batch(texts: list[str], **extra_columns: list) -> DocumentBatch: data = {"text": texts} 
data.update(extra_columns) df = pd.DataFrame(data) - return DocumentBatch(data=df, dataset_name="test", task_id="1") + return DocumentBatch(data=df, dataset_name="test") def _segment_and_add_translations( @@ -59,7 +59,7 @@ def _segment_and_add_translations( if extra_columns: data.update(extra_columns) df = pd.DataFrame(data) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") seg_stage = SegmentationStage(source_lang="en", mode=mode) segmented = seg_stage.process(batch) @@ -74,7 +74,6 @@ def _segment_and_add_translations( return DocumentBatch( data=seg_df, dataset_name=segmented.dataset_name, - task_id=segmented.task_id, _metadata=segmented._metadata, _stage_perf=segmented._stage_perf, ) @@ -266,7 +265,7 @@ def test_empty_batch(self) -> None: "text": pd.Series(dtype="str"), } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") stage = ReassemblyStage() result = stage.process(batch) @@ -317,7 +316,7 @@ def test_wildcard_list_field_replaced_in_place(self) -> None: ] } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") segmented = SegmentationStage( source_lang="en", @@ -329,7 +328,6 @@ def test_wildcard_list_field_replaced_in_place(self) -> None: translated_batch = DocumentBatch( data=seg_df, dataset_name=segmented.dataset_name, - task_id=segmented.task_id, _metadata=segmented._metadata, _stage_perf=segmented._stage_perf, ) diff --git a/tests/stages/text/experimental/translation/test_segmentation.py b/tests/stages/text/experimental/translation/test_segmentation.py index be7feca23d..b397c5da92 100644 --- a/tests/stages/text/experimental/translation/test_segmentation.py +++ b/tests/stages/text/experimental/translation/test_segmentation.py @@ -42,7 +42,7 @@ def _make_batch(texts: list[str], **extra_columns: list) -> DocumentBatch: data = {"text": texts} 
data.update(extra_columns) df = pd.DataFrame(data) - return DocumentBatch(data=df, dataset_name="test", task_id="1") + return DocumentBatch(data=df, dataset_name="test") def _seg_metadata(batch: DocumentBatch, row: int = 0) -> dict: @@ -386,7 +386,7 @@ async def _query_model_impl(self, **kwargs: object) -> list[str]: # type: ignor tag_only = "
" json_blob = '{"tool":"lookup","payload":{"model":"DeepSeek V3"}}' df = pd.DataFrame({"_seg_segments": [code_block, tag_only, json_blob]}) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -421,7 +421,7 @@ async def _query_model_impl(self, **kwargs: object) -> list[str]: # type: ignor stage._initialized = True df = pd.DataFrame({"_seg_segments": ["Hello world"]}) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() diff --git a/tests/stages/text/experimental/translation/test_translate.py b/tests/stages/text/experimental/translation/test_translate.py index 46affaab25..de91843128 100644 --- a/tests/stages/text/experimental/translation/test_translate.py +++ b/tests/stages/text/experimental/translation/test_translate.py @@ -146,7 +146,7 @@ def test_process_llm_backend(self) -> None: "id": [1, 2], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -198,7 +198,7 @@ async def _fake_async(texts: list[str], src: str, tgt: str) -> list[str]: "id": [1, 2], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() @@ -241,7 +241,7 @@ async def _fake_async(texts: list[str], src: str, tgt: str) -> list[str]: "id": [1, 2, 3], } ) - batch = DocumentBatch(data=df, dataset_name="test", task_id="1") + batch = DocumentBatch(data=df, dataset_name="test") result = stage.process(batch) result_df = result.to_pandas() diff --git a/tests/stages/text/io/reader/test_jsonl.py b/tests/stages/text/io/reader/test_jsonl.py index d51eae9ce0..a90b5c9b05 100644 --- 
a/tests/stages/text/io/reader/test_jsonl.py +++ b/tests/stages/text/io/reader/test_jsonl.py @@ -40,7 +40,7 @@ def sample_jsonl_files(tmp_path: Path) -> list[str]: def file_group_tasks(sample_jsonl_files: list[str]) -> list[FileGroupTask]: """Create multiple FileGroupTasks.""" return [ - FileGroupTask(task_id=f"task_{i}", dataset_name="test_dataset", data=[file_path], _metadata={}) + FileGroupTask(dataset_name="test_dataset", data=[file_path], _metadata={}) for i, file_path in enumerate(sample_jsonl_files) ] @@ -73,7 +73,7 @@ def test_storage_options_via_read_kwargs(self, tmp_path: Path, monkeypatch: pyte pd.DataFrame({"a": [1]}).to_json(file_path, orient="records", lines=True) # Reader uses read_kwargs storage options - task = FileGroupTask(task_id="t1", dataset_name="ds", data=[str(file_path)], _metadata={}) + task = FileGroupTask(dataset_name="ds", data=[str(file_path)], _metadata={}) stage = JsonlReaderStage(read_kwargs={"storage_options": {"auto_mkdir": True}}) seen: dict[str, object] = {} @@ -114,7 +114,7 @@ def fake_read_json(_path: object, *_args: object, **kwargs: object) -> pd.DataFr return pd.DataFrame({"x": [1, 2]}) monkeypatch.setattr(pd, "read_json", fake_read_json) - task = FileGroupTask(task_id="t2", dataset_name="ds", data=[str(f)], _metadata={}) + task = FileGroupTask(dataset_name="ds", data=[str(f)], _metadata={}) stage = JsonlReaderStage(read_kwargs={"storage_options": {"auto_mkdir": True}}) out = stage.process(task) assert seen["storage_options"] == {"auto_mkdir": True} diff --git a/tests/stages/text/io/reader/test_parquet.py b/tests/stages/text/io/reader/test_parquet.py index 27c6169007..1d5b39fd04 100644 --- a/tests/stages/text/io/reader/test_parquet.py +++ b/tests/stages/text/io/reader/test_parquet.py @@ -41,7 +41,7 @@ def sample_parquet_files(tmp_path: Path) -> list[str]: def parquet_file_group_tasks(sample_parquet_files: list[str]) -> list[FileGroupTask]: """Create multiple FileGroupTasks for parquet files.""" return [ - 
FileGroupTask(task_id=f"task_{i}", dataset_name="test_dataset", data=[file_path], _metadata={}) + FileGroupTask(dataset_name="test_dataset", data=[file_path], _metadata={}) for i, file_path in enumerate(sample_parquet_files) ] @@ -65,7 +65,6 @@ def _sample_records(start: int = 0, n: int = 2) -> list[dict]: def _make_file_group_task(files: list[str]) -> FileGroupTask: return FileGroupTask( - task_id="fg1", dataset_name="ds", data=files, reader_config={}, diff --git a/tests/stages/text/io/writer/conftest.py b/tests/stages/text/io/writer/conftest.py index 7d5c276fde..c27f168a87 100644 --- a/tests/stages/text/io/writer/conftest.py +++ b/tests/stages/text/io/writer/conftest.py @@ -31,7 +31,6 @@ def pandas_document_batch() -> DocumentBatch: } ) return DocumentBatch( - task_id="test_pandas_batch", dataset_name="test_dataset", data=df, _metadata={"dummy_key": "dummy_value"}, @@ -51,7 +50,6 @@ def pandas_document_batch() -> DocumentBatch: def pyarrow_document_batch(pandas_document_batch: DocumentBatch) -> DocumentBatch: """Fixture providing a pyarrow Table for testing.""" return DocumentBatch( - task_id="test_pyarrow_batch", dataset_name="test_dataset", data=pandas_document_batch.to_pyarrow(), _metadata={"dummy_key": "dummy_value"}, diff --git a/tests/stages/text/io/writer/test_jsonl.py b/tests/stages/text/io/writer/test_jsonl.py index 66f51c8a2a..32fcac8d0b 100644 --- a/tests/stages/text/io/writer/test_jsonl.py +++ b/tests/stages/text/io/writer/test_jsonl.py @@ -20,7 +20,6 @@ import pandas as pd import pytest -import nemo_curator.stages.text.io.writer.utils as writer_utils from nemo_curator.stages.text.io.writer import JsonlWriter from nemo_curator.stages.text.io.writer import base as writer_base from nemo_curator.tasks import DocumentBatch @@ -49,7 +48,7 @@ def test_jsonl_writer( # Process with ( mock.patch.object( - writer_utils, "get_deterministic_hash", return_value="_TEST_FILE_HASH" + writer_base, "get_deterministic_hash", return_value="_TEST_FILE_HASH" ) as 
mock_get_deterministic_hash, mock.patch.object(uuid, "uuid4", return_value=mock.Mock(hex="_TEST_FILE_HASH")) as mock_uuid4, ): @@ -62,12 +61,12 @@ def test_jsonl_writer( assert mock_get_deterministic_hash.call_count == 1 # Verify get_deterministic_hash was called with correct arguments mock_get_deterministic_hash.assert_called_once_with(source_files, document_batch.task_id) - # because we call it once for task, and that should be the only one - assert mock_uuid4.call_count <= 1 + # consistent path uses the content hash for the filename; uuid is unused + assert mock_uuid4.call_count == 0 else: assert mock_get_deterministic_hash.call_count == 0 - # because we call it once for task, and once for the filename - assert mock_uuid4.call_count == 2 + # non-consistent path uses a single uuid for the filename + assert mock_uuid4.call_count == 1 # Verify file was created assert result.task_id == document_batch.task_id # Task ID should match input @@ -107,12 +106,11 @@ def test_jsonl_writer_with_columns_subset(self, pandas_document_batch: DocumentB expected = pandas_document_batch.to_pandas()[["text", "score"]] pd.testing.assert_frame_equal(df, expected) - def test_jsonl_writer_defaults_to_utf8_output(self, tmpdir: str): """JsonlWriter should preserve non-ASCII characters by default.""" output_dir = os.path.join(tmpdir, "jsonl_utf8") writer = JsonlWriter(path=output_dir) - batch = DocumentBatch(data=pd.DataFrame({"text": ["你好, 世界"], "id": [1]}), task_id="utf8", dataset_name="test") + batch = DocumentBatch(data=pd.DataFrame({"text": ["你好, 世界"], "id": [1]}), dataset_name="test") writer.setup() result = writer.process(batch) @@ -127,7 +125,7 @@ def test_jsonl_writer_allows_force_ascii_override(self, tmpdir: str): """JsonlWriter should still honor user-provided force_ascii=True.""" output_dir = os.path.join(tmpdir, "jsonl_ascii") writer = JsonlWriter(path=output_dir, write_kwargs={"force_ascii": True}) - batch = DocumentBatch(data=pd.DataFrame({"text": ["你好, 世界"], "id": [1]}), 
task_id="ascii", dataset_name="test") + batch = DocumentBatch(data=pd.DataFrame({"text": ["你好, 世界"], "id": [1]}), dataset_name="test") writer.setup() result = writer.process(batch) diff --git a/tests/stages/text/io/writer/test_megatron_tokenizer.py b/tests/stages/text/io/writer/test_megatron_tokenizer.py index 31e5bd2d6b..ef68b841cf 100644 --- a/tests/stages/text/io/writer/test_megatron_tokenizer.py +++ b/tests/stages/text/io/writer/test_megatron_tokenizer.py @@ -23,7 +23,6 @@ import numpy as np import pytest -import nemo_curator.stages.text.io.writer.utils as writer_utils from nemo_curator.stages.text.io.writer.megatron_tokenizer import _INDEX_HEADER, MegatronTokenizerWriter from nemo_curator.tasks import DocumentBatch @@ -113,8 +112,9 @@ def test_megatron_tokenizer_writer( # Process with ( - mock.patch.object( - writer_utils, "get_deterministic_hash", return_value="_TEST_FILE_HASH" + mock.patch( + "nemo_curator.stages.text.io.writer.megatron_tokenizer.get_deterministic_hash", + return_value="_TEST_FILE_HASH", ) as mock_get_deterministic_hash, mock.patch.object(uuid, "uuid4", return_value=mock.Mock(hex="_TEST_FILE_HASH")) as mock_uuid4, ): @@ -127,12 +127,12 @@ def test_megatron_tokenizer_writer( assert mock_get_deterministic_hash.call_count == 1 # Verify get_deterministic_hash was called with correct arguments mock_get_deterministic_hash.assert_called_once_with(source_files, document_batch.task_id) - # because we call it once for task, and that should be the only one - assert mock_uuid4.call_count <= 1 + # consistent path uses the content hash for the filename; uuid is unused + assert mock_uuid4.call_count == 0 else: assert mock_get_deterministic_hash.call_count == 0 - # because we call it once for task, and once for the filename - assert mock_uuid4.call_count == 2 + # non-consistent path uses a single uuid for the filename + assert mock_uuid4.call_count == 1 # Verify file was created assert result.task_id == document_batch.task_id # Task ID should match input 
diff --git a/tests/stages/text/io/writer/test_parquet.py b/tests/stages/text/io/writer/test_parquet.py index e47703b56b..37a2dd1cbb 100644 --- a/tests/stages/text/io/writer/test_parquet.py +++ b/tests/stages/text/io/writer/test_parquet.py @@ -22,7 +22,6 @@ from nemo_curator.stages.text.io.writer import ParquetWriter from nemo_curator.stages.text.io.writer import base as writer_base -from nemo_curator.stages.text.io.writer import utils as writer_utils from nemo_curator.tasks import DocumentBatch @@ -49,7 +48,7 @@ def test_parquet_writer( # Process with ( mock.patch.object( - writer_utils, "get_deterministic_hash", return_value="_TEST_FILE_HASH" + writer_base, "get_deterministic_hash", return_value="_TEST_FILE_HASH" ) as mock_get_deterministic_hash, mock.patch.object(uuid, "uuid4", return_value=mock.Mock(hex="_TEST_FILE_HASH")) as mock_uuid4, ): @@ -62,12 +61,12 @@ def test_parquet_writer( assert mock_get_deterministic_hash.call_count == 1 # Verify get_deterministic_hash was called with correct arguments mock_get_deterministic_hash.assert_called_once_with(source_files, document_batch.task_id) - # because we call it once for task, and that should be the only one - assert mock_uuid4.call_count <= 1 + # consistent path uses the content hash for the filename; uuid is unused + assert mock_uuid4.call_count == 0 else: assert mock_get_deterministic_hash.call_count == 0 - # because we call it once for task, and once for the filename - assert mock_uuid4.call_count == 2 + # non-consistent path uses a single uuid for the filename + assert mock_uuid4.call_count == 1 # Verify file was created assert result.task_id == document_batch.task_id # Task ID should match input diff --git a/tests/stages/text/models/test_model.py b/tests/stages/text/models/test_model.py index b4959988d6..5774435b4d 100644 --- a/tests/stages/text/models/test_model.py +++ b/tests/stages/text/models/test_model.py @@ -166,7 +166,7 @@ def test_process_with_seq_order(self): stage.setup() df = 
self.create_sample_dataframe(4, include_seq_order=True) - batch = DocumentBatch(task_id="test_task", dataset_name="test_dataset", data=df) + batch = DocumentBatch(dataset_name="test_dataset", data=df) result = stage.process(batch).to_pandas() @@ -185,7 +185,7 @@ def test_process_with_max_seq_length_right_padding(self): stage.setup() df = self.create_sample_dataframe(4, include_seq_order=False) - batch = DocumentBatch(task_id="test_task", dataset_name="test_dataset", data=df) + batch = DocumentBatch(dataset_name="test_dataset", data=df) result = stage.process(batch).to_pandas() @@ -200,7 +200,7 @@ def test_process_with_max_seq_length_left_padding(self): stage.setup() df = self.create_sample_dataframe(4, include_seq_order=False) - batch = DocumentBatch(task_id="test_task", dataset_name="test_dataset", data=df) + batch = DocumentBatch(dataset_name="test_dataset", data=df) result = stage.process(batch).to_pandas() @@ -213,7 +213,7 @@ def test_process_without_seq_order(self): stage.setup() df = self.create_sample_dataframe(4, include_seq_order=False) - batch = DocumentBatch(task_id="test_task", dataset_name="test_dataset", data=df) + batch = DocumentBatch(dataset_name="test_dataset", data=df) result = stage.process(batch).to_pandas() @@ -239,7 +239,7 @@ def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None: stage.setup() df = self.create_sample_dataframe(2) - batch = DocumentBatch(task_id="test_task", dataset_name="test_dataset", data=df) + batch = DocumentBatch(dataset_name="test_dataset", data=df) _ = stage.process(batch) diff --git a/tests/stages/text/models/test_tokenizer.py b/tests/stages/text/models/test_tokenizer.py index b6bbcbb557..3cbde91c33 100644 --- a/tests/stages/text/models/test_tokenizer.py +++ b/tests/stages/text/models/test_tokenizer.py @@ -82,7 +82,7 @@ def sample_document_batch() -> DocumentBatch: } ) - return DocumentBatch(task_id="test_task", dataset_name="test_dataset", data=data) + return 
DocumentBatch(dataset_name="test_dataset", data=data) @pytest.fixture(autouse=True) @@ -173,7 +173,7 @@ def test_tokenizer_stage_max_chars_truncation(): data = pd.DataFrame( {"text": ["This is a very long text that should be truncated when max_chars is set to a small value"]} ) - batch = DocumentBatch(task_id="test_task", dataset_name="test_dataset", data=data) + batch = DocumentBatch(dataset_name="test_dataset", data=data) stage = TokenizerStage(model_identifier="test/model", max_chars=20, sort_by_length=False, text_field="text") diff --git a/tests/stages/text/modules/test_add_id.py b/tests/stages/text/modules/test_add_id.py index 5b2ca1f9ed..1ce8b3d181 100644 --- a/tests/stages/text/modules/test_add_id.py +++ b/tests/stages/text/modules/test_add_id.py @@ -22,7 +22,7 @@ def _sample_batch() -> DocumentBatch: """Create a simple three-row batch for tests.""" df = pd.DataFrame({"text": ["first", "second", "third"]}) - return DocumentBatch(data=df, task_id="batch_1", dataset_name="test_ds") + return DocumentBatch(data=df, dataset_name="test_ds") class TestAddIdStage: @@ -35,7 +35,7 @@ def test_add_id_basic(self) -> None: result = stage.process(batch) assert result is not None, "Stage returned None" - prefix = str(batch._uuid) + prefix = batch.task_id expected_ids = [f"{prefix}_{i}" for i in range(len(batch.to_pandas()))] # Check column creation and values @@ -44,9 +44,6 @@ def test_add_id_basic(self) -> None: # Original data should remain unchanged pd.testing.assert_series_equal(batch.data["text"], result.data["text"]) - # Task id should include the stage name - assert result.task_id == f"{batch.task_id}_{stage.name}" - def test_io_spec(self) -> None: """The declared inputs/outputs match the contract.""" stage = AddId(id_field="custom_id") @@ -54,9 +51,16 @@ def test_io_spec(self) -> None: assert stage.outputs() == (["data"], ["custom_id"]) def test_unique_ids_across_batches(self) -> None: - """Ensure IDs are unique across different batches.""" + """Ensure IDs are 
unique across different batches. + + AddId derives doc ids from ``batch.task_id``. In a pipeline the + executor adapter gives each batch a unique task_id (its id + path); simulate that here so the two batches don't collide. + """ batch1 = _sample_batch() batch2 = _sample_batch() + batch1._set_task_id("", 0) + batch2._set_task_id("", 1) stage = AddId(id_field="id") @@ -79,7 +83,7 @@ def test_id_prefix_is_applied(self) -> None: result = stage.process(batch) - prefix = f"custom_{batch._uuid}" + prefix = f"custom_{batch.task_id}" expected_ids = [f"{prefix}_{i}" for i in range(len(batch.to_pandas()))] assert list(result.data["uid"]) == expected_ids @@ -101,7 +105,7 @@ def test_overwrite_true_replaces_column(self) -> None: stage = AddId(id_field="my_id", overwrite=True) result = stage.process(batch) - prefix = str(batch._uuid) + prefix = batch.task_id expected_ids = [f"{prefix}_{i}" for i in range(len(batch.to_pandas()))] assert list(result.data["my_id"]) == expected_ids # Ensure the old values are gone diff --git a/tests/stages/text/modules/test_filters.py b/tests/stages/text/modules/test_filters.py index 1ca1359fd2..486b43706f 100644 --- a/tests/stages/text/modules/test_filters.py +++ b/tests/stages/text/modules/test_filters.py @@ -170,7 +170,6 @@ def list_to_dataset(documents: list[str], col_name: str = "text") -> DocumentBat return DocumentBatch( data=pdf, - task_id="batch_1", dataset_name="test_1", ) @@ -179,7 +178,6 @@ def list_to_dataset(documents: list[str], col_name: str = "text") -> DocumentBat def letter_count_data() -> DocumentBatch: return DocumentBatch( data=pd.DataFrame({"documents": ["Two aa", "a a Three a", "Five aaa aa", "aaaSeven aaaa"]}), - task_id="batch_1", dataset_name="test_1", ) @@ -193,7 +191,6 @@ def test_score_filter(self, letter_count_data: DocumentBatch) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"documents": ["Five aaa aa", "aaaSeven aaaa"]}), - task_id="batch_1_letter_count", dataset_name="test_1", ) @@ -238,7 +235,6 @@ def 
test_retain_score_filter(self, letter_count_data: DocumentBatch) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"documents": ["Five aaa aa", "aaaSeven aaaa"]}), - task_id="batch_1_letter_count", dataset_name="test_1", ) expected_data.data[score_field] = pd.Series([5, 7]) @@ -261,7 +257,6 @@ def test_filter_document(self, letter_count_data: DocumentBatch) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"documents": ["Five aaa aa", "aaaSeven aaaa"]}), - task_id="batch_1_score_fn_filter_fn", dataset_name="test_1", ) expected_data.data[score_field] = pd.Series([5, 7]) @@ -284,7 +279,6 @@ def test_filter(self, letter_count_data: DocumentBatch) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"documents": ["Five aaa aa", "aaaSeven aaaa"]}), - task_id="batch_1_letter_count_letter_count", dataset_name="test_1", ) expected_data.data[score_field] = pd.Series([5, 7]) @@ -298,7 +292,6 @@ def test_invert(self, letter_count_data: DocumentBatch) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"documents": ["Two aa", "a a Three a"]}), - task_id="batch_1_letter_count", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -329,7 +322,6 @@ def test_score_filter_chain(self, letter_count_data: DocumentBatch, score_field: expected_data = DocumentBatch( data=expected_df, - task_id="batch_1_score_filter_chain_of_letter_count_letter_count", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -363,7 +355,6 @@ def test_score_chain(self, letter_count_data: DocumentBatch, score_field: list[s expected_data = DocumentBatch( data=expected_df, - task_id="batch_1_score_chain_of_letter_count_letter_count", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -378,7 +369,6 @@ def test_filter_chain(self, filter_field: list[str] | None) 
-> None: "e_count": [0, 2, 1, 2], } ), - task_id="batch_1", dataset_name="test_1", ) @@ -407,7 +397,6 @@ def test_filter_chain(self, filter_field: list[str] | None) -> None: expected_data = DocumentBatch( data=expected_df, - task_id="batch_1_filter_chain_of_letter_count_letter_count", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -432,7 +421,6 @@ def test_score_filter_all_rows(self, letter_count_data: DocumentBatch) -> None: expected_data = DocumentBatch( data=expected_df, - task_id="batch_1_letter_count", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -451,7 +439,6 @@ def test_score_filter_all_rows_chain(self, letter_count_data: DocumentBatch) -> expected_data = DocumentBatch( data=expected_df, - task_id="batch_1_score_filter_chain_of_letter_count_letter_count", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -490,7 +477,6 @@ def test_nonalpha(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["This is a test case.", "$aaa"]}), - task_id="batch_1_alpha_numeric", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -510,7 +496,6 @@ def test_symbolswords(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["full of words", "barely ok 3 4 5 6 7 8 9 #"]}), - task_id="batch_1_symbol_to_word", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -523,7 +508,6 @@ def test_numbers(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["purely letters", "$!@$@!$!@", "abcdefghi1"]}), - task_id="batch_1_numbers_ratio", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got 
{filtered_data}" @@ -552,7 +536,6 @@ def test_urls(self) -> None: ] } ), - task_id="batch_1_urls_ratio", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -596,7 +579,6 @@ def test_urls_filter_accepts_custom_regex(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["no urls here!", "https://www.nvidia.com/en-us/"]}), - task_id="batch_1_urls_ratio", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -633,7 +615,6 @@ def test_bullets(self) -> None: ] } ), - task_id="batch_1_bullet_ratio", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -646,7 +627,6 @@ def test_whitespace(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["good", "123\b"]}), - task_id="batch_1_white_space", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -659,7 +639,6 @@ def test_parentheses(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["this is completely absolutely fine", "123456789("]}), - task_id="batch_1_parentheses_ratio", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -672,7 +651,6 @@ def test_longword(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["tiny"]}), - task_id="batch_1_max_word_length", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -685,7 +663,6 @@ def test_wordcount(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["two words", "$#@$ %$@$#@ !#@!"]}), - task_id="batch_1_word_count", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got 
{filtered_data}" @@ -698,7 +675,6 @@ def test_wordcount_zh(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["你好。", "我喜欢学习中文。"]}), - task_id="batch_1_word_count", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -712,7 +688,6 @@ def test_wordcount_ja(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["猫が寝ます。", "私は日本語のテキストを分割します。"]}), - task_id="batch_1_word_count_ja", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -738,7 +713,6 @@ def test_boilerplate(self) -> None: ] } ), - task_id="batch_1_boilerplate_string_ratio", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -759,7 +733,6 @@ def test_meanwordlength(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["superlongword short", "evenly balanced"]}), - task_id="batch_1_mean_word_length", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -772,7 +745,6 @@ def test_repeatedlines(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["totally unique"]}), - task_id="batch_1_repeated_lines", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -785,7 +757,6 @@ def test_repeatedparagraphs(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["totally unique"]}), - task_id="batch_1_repeated_paragraphs", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -805,7 +776,6 @@ def test_repeatedlineschar(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["totally unique", "a.\na.\nvery very very short duplicate."]}), - 
task_id="batch_1_repeated_lines_char", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -825,7 +795,6 @@ def test_repeatedparagraphschar(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["totally unique", "a.\n\n a.\n\n very very very short duplicate."]}), - task_id="batch_1_repeated_paragraphs_char", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -852,7 +821,6 @@ def test_repeatingtopngrams(self) -> None: ] } ), - task_id="batch_1_repeating_top_2grams", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -865,7 +833,6 @@ def test_repeatingduplicatengrams(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["totally fine", "a a a a this should be fine as well"]}), - task_id="batch_1_repeating_dup_2gram", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -878,7 +845,6 @@ def test_punctuation(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["good.", "just\n barely\n fine\n ok\n yep."]}), - task_id="batch_1_punctuation", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -891,7 +857,6 @@ def test_ellipsis(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["good.", "just...\n barely...\n fine...\n ok...\n yep."]}), - task_id="batch_1_ellipsis", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -904,7 +869,6 @@ def test_commonenglishwords(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["the and", "the and and of to"]}), - task_id="batch_1_common_english_words", dataset_name="test_1", ) assert 
all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -917,7 +881,6 @@ def test_wordswithoutalphabets(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["totally fine", "good good good good !"]}), - task_id="batch_1_words_without_alphabets", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -936,7 +899,6 @@ def test_pornographicurls(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["no url", "fine url https://www.nvidia.com/en-us/"]}), - task_id="batch_1_PornographicUrlsFilter", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -968,7 +930,6 @@ def test_histogram(self) -> None: ] } ), - task_id="batch_1_histogram", dataset_name="test_1", ) expected_data2 = DocumentBatch( @@ -981,7 +942,6 @@ def test_histogram(self) -> None: ] } ), - task_id="batch_1_histogram", dataset_name="test_1", ) assert all_equal(expected_data1, filtered_data1), f"Expected {expected_data1} but got {filtered_data1}" @@ -1025,7 +985,6 @@ def test_filter_dataset(self) -> None: # We expect to keep only the documents with exactly 2 or 3 tokens. expected_dataset = DocumentBatch( data=pd.DataFrame({"text": ["hello world", "another test case"]}), - task_id="batch_1_token_count", dataset_name="test_1", ) assert all_equal(expected_dataset, filtered_dataset) @@ -1050,7 +1009,6 @@ def test_filter_dataset_default(self) -> None: # We expect to keep all documents. expected_dataset = DocumentBatch( data=pd.DataFrame({"text": docs}), - task_id="batch_1_token_count", dataset_name="test_1", ) assert all_equal(expected_dataset, filtered_dataset) @@ -1112,7 +1070,6 @@ def test_filter_dataset_prefix(self) -> None: # Expect only those records where the text starts with "Hello". 
expected_dataset = DocumentBatch( data=pd.DataFrame({"text": ["Hello world", "Hello everyone"]}), - task_id="batch_1_SubstringFilter", dataset_name="test_1", ) @@ -1135,7 +1092,6 @@ def test_filter_dataset_suffix(self) -> None: # Expect only those records that end with "end". expected_dataset = DocumentBatch( data=pd.DataFrame({"text": ["This is the end", "Not matching end", "The end"]}), - task_id="batch_1_SubstringFilter", dataset_name="test_1", ) @@ -1152,7 +1108,6 @@ def test_filter_dataset_any(self) -> None: # Expect documents that contain "test" anywhere. expected_dataset = DocumentBatch( data=pd.DataFrame({"text": ["test case", "This is a testcase", "another test"]}), - task_id="batch_1_SubstringFilter", dataset_name="test_1", ) @@ -1171,7 +1126,6 @@ def test_python_comment_to_code(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": [doc_1, doc_4]}), - task_id="batch_1_python_comment_ratio", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -1188,7 +1142,6 @@ def test_general_commment_to_code(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": [doc_1, doc_4]}), - task_id="batch_1_comment_ratio", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -1208,7 +1161,6 @@ def test_number_lines_code(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": [doc_2]}), - task_id="batch_1_num_lines", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -1221,7 +1173,6 @@ def test_xml_header(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["no header"]}), - task_id="batch_1_xml_header", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -1234,7 +1185,6 @@ def test_alpha(self) -> None: 
expected_data = DocumentBatch( data=pd.DataFrame({"text": ["full of alphabet", "mixed <>"]}), - task_id="batch_1_alpha_filter", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -1282,7 +1232,6 @@ def test_html_boilerplate(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": [good_doc]}), - task_id="batch_1_html_boilerplate", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -1320,7 +1269,6 @@ def test_per_extension_filter(self, per_extension_filter: PerExtensionFilter) -> expected_data = DocumentBatch( data=pd.DataFrame({"text": [good_cpp]}), - task_id="batch_1_per_extension_filter", dataset_name="test_1", ) @@ -1351,7 +1299,6 @@ def test_fake_quality_filter(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["b", "c", "d"]}), - task_id="batch_1_FakeQualityFilter", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" @@ -1364,7 +1311,6 @@ def test_fake_langid_filter(self) -> None: expected_data = DocumentBatch( data=pd.DataFrame({"text": ["a", "b", "d"]}), - task_id="batch_1_FakeLangId", dataset_name="test_1", ) assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" diff --git a/tests/stages/text/modules/test_joiner.py b/tests/stages/text/modules/test_joiner.py index 5e0e5b0cdf..0b1c29e445 100644 --- a/tests/stages/text/modules/test_joiner.py +++ b/tests/stages/text/modules/test_joiner.py @@ -31,7 +31,6 @@ def test_basic_join(self): } ) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, ) @@ -63,7 +62,6 @@ def test_custom_separator(self): } ) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, ) @@ -85,7 +83,6 @@ def test_custom_fields(self): } ) batch = DocumentBatch( - 
task_id="test_batch", dataset_name="test_dataset", data=df, ) @@ -112,7 +109,6 @@ def test_keep_segment_id_field(self): } ) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, ) @@ -136,7 +132,6 @@ def test_max_length_single_segment_exceeds(self): } ) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, ) @@ -165,7 +160,6 @@ def test_multiple_documents(self): } ) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, ) @@ -191,7 +185,6 @@ def test_preserve_additional_columns(self): } ) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, ) @@ -218,7 +211,6 @@ def test_out_of_order_segments(self): } ) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, ) @@ -234,7 +226,6 @@ def test_empty_batch(self): """Test handling of empty batch.""" df = pd.DataFrame(columns=["id", "segment_id", "text"]) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, ) @@ -254,7 +245,6 @@ def test_single_segment_document(self): } ) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, ) @@ -320,7 +310,6 @@ def test_metadata_preservation(self): } ) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, _metadata={"source": "test_source", "version": "1.0"}, @@ -344,7 +333,6 @@ def test_roundtrip_with_splitter(self): } ) original_batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=original_df, ) diff --git a/tests/stages/text/modules/test_modifiers.py b/tests/stages/text/modules/test_modifiers.py index a0c685b56a..17e097d4ef 100644 --- a/tests/stages/text/modules/test_modifiers.py +++ b/tests/stages/text/modules/test_modifiers.py @@ -31,7 +31,7 @@ def list_to_doc_batch(documents: list[str], col_name: str = "text") -> DocumentBatch: df = pd.DataFrame({col_name: documents}) - return DocumentBatch(data=df, 
task_id="test_id", dataset_name="test_ds") + return DocumentBatch(data=df, dataset_name="test_ds") def run_modify(modifier: DocumentModifier, doc_batch: DocumentBatch) -> DocumentBatch: @@ -564,7 +564,7 @@ def join(a: str, b: str) -> str: return f"{a}-{b}" df = pd.DataFrame({"a": ["x", "hello"], "b": ["y", "world"]}) - batch = DocumentBatch(task_id="t", dataset_name="ds", data=df) + batch = DocumentBatch(dataset_name="ds", data=df) m = Modify(join, input_fields=[["a", "b"]], output_fields="joined") m.setup() @@ -591,7 +591,7 @@ def concat(a: str, b: str) -> str: return a + b df = pd.DataFrame({"a": [" a ", "b "], "b": ["x", "y"]}) - batch = DocumentBatch(task_id="t", dataset_name="ds", data=df) + batch = DocumentBatch(dataset_name="ds", data=df) m = Modify( [strip_a, concat], diff --git a/tests/stages/text/modules/test_splitter.py b/tests/stages/text/modules/test_splitter.py index e71cd77d99..3c91fa37b4 100644 --- a/tests/stages/text/modules/test_splitter.py +++ b/tests/stages/text/modules/test_splitter.py @@ -29,7 +29,6 @@ def test_basic_split(self): } ) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, ) @@ -69,7 +68,6 @@ def test_custom_separator(self): } ) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, ) @@ -92,7 +90,6 @@ def test_custom_text_field(self): } ) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, ) @@ -115,7 +112,6 @@ def test_custom_segment_id_field(self): } ) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, ) @@ -138,7 +134,6 @@ def test_no_split_needed(self): } ) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, ) @@ -158,7 +153,6 @@ def test_empty_segments(self): } ) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, ) @@ -183,7 +177,6 @@ def test_metadata_preservation(self): } ) batch = DocumentBatch( - task_id="test_batch", 
dataset_name="test_dataset", data=df, _metadata={"source": "test_source"}, @@ -221,7 +214,6 @@ def test_validate_input(self): } ) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, ) @@ -236,7 +228,6 @@ def test_validate_input(self): } ) batch_missing = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df_missing, ) @@ -251,7 +242,6 @@ def test_reconstruction_with_unique_ids(self): } ) batch = DocumentBatch( - task_id="test_batch", dataset_name="test_dataset", data=df, ) diff --git a/tests/stages/video/caption/test_caption_enhancement.py b/tests/stages/video/caption/test_caption_enhancement.py index e31d8acd0c..2b16227f62 100644 --- a/tests/stages/video/caption/test_caption_enhancement.py +++ b/tests/stages/video/caption/test_caption_enhancement.py @@ -162,7 +162,7 @@ def _create_test_video_task_with_captions(self) -> VideoTask: clip2.windows = [window3] video.clips = [clip1, clip2] - return VideoTask(task_id="test", dataset_name="test", data=video) + return VideoTask(dataset_name="test", data=video) def _create_test_video_task_empty_clips(self) -> VideoTask: """Create a test VideoTask with empty clips.""" @@ -174,7 +174,7 @@ def _create_test_video_task_empty_clips(self) -> VideoTask: clip = Clip(uuid=uuid4(), source_video="test.mp4", span=(0.0, 10.0), buffer=b"test_buffer") clip.windows = [] video.clips = [clip] - return VideoTask(task_id="test", dataset_name="test", data=video) + return VideoTask(dataset_name="test", data=video) def _create_test_video_task_no_captions(self) -> VideoTask: """Create a test VideoTask with windows but no captions.""" @@ -191,7 +191,7 @@ def _create_test_video_task_no_captions(self) -> VideoTask: clip.windows = [window] video.clips = [clip] - return VideoTask(task_id="test", dataset_name="test", data=video) + return VideoTask(dataset_name="test", data=video) @patch("nemo_curator.stages.video.caption.caption_enhancement.logger") def test_process_with_valid_captions(self, 
_mock_logger: Mock): # noqa: PT019 @@ -305,7 +305,7 @@ def test_process_batch_processing(self): clip.windows.append(window) video.clips = [clip] - task = VideoTask(task_id="test", dataset_name="test", data=video) + task = VideoTask(dataset_name="test", data=video) self.stage.process(task) @@ -351,7 +351,7 @@ def test_process_with_verbose_logging(self, mock_logger: Mock): window.enhanced_caption = {} clip.windows = [window] video.clips = [clip] - task = VideoTask(task_id="test", dataset_name="test", data=video) + task = VideoTask(dataset_name="test", data=video) self.stage.process(task) @@ -424,7 +424,7 @@ def _make_task_with_captions(captions: list[str], task_id: str = "enhancement-te clip.windows = windows video = Video(input_video=Path(task_id + ".mp4")) video.clips = [clip] - return VideoTask(task_id=task_id, dataset_name="integration", data=video) + return VideoTask(dataset_name="integration", data=video) @pytest.mark.gpu diff --git a/tests/stages/video/caption/test_caption_generation.py b/tests/stages/video/caption/test_caption_generation.py index 84d861bb52..f977853464 100644 --- a/tests/stages/video/caption/test_caption_generation.py +++ b/tests/stages/video/caption/test_caption_generation.py @@ -176,7 +176,7 @@ def _create_test_video_task(self) -> VideoTask: video.clips = [clip1, clip2] - return VideoTask(task_id="test", dataset_name="test", data=video) + return VideoTask(dataset_name="test", data=video) def test_process_successful_generation(self): """Test successful caption generation process.""" @@ -228,7 +228,7 @@ def test_process_with_verbose_logging(self): ) clip.windows = [window] video.clips = [clip] - task = VideoTask(task_id="test", dataset_name="test", data=video) + task = VideoTask(dataset_name="test", data=video) self.stage.process(task) @@ -245,7 +245,7 @@ def test_process_empty_windows(self, mock_logger: Mock): clip = Clip(uuid=uuid4(), source_video="test.mp4", span=(0.0, 5.0), buffer=b"test_buffer") # No windows added video.clips = [clip] - 
task = VideoTask(task_id="test", dataset_name="test", data=video) + task = VideoTask(dataset_name="test", data=video) self.stage.process(task) @@ -271,7 +271,7 @@ def test_process_window_without_input(self, mock_logger: Mock): window = _Window(start_frame=0, end_frame=5) clip.windows = [window] video.clips = [clip] - task = VideoTask(task_id="test", dataset_name="test", data=video) + task = VideoTask(dataset_name="test", data=video) self.stage.process(task) @@ -347,7 +347,7 @@ def test_process_with_stage2_caption(self): ) clip.windows = [window] video.clips = [clip] - task = VideoTask(task_id="test", dataset_name="test", data=video) + task = VideoTask(dataset_name="test", data=video) self.stage.process(task) @@ -437,7 +437,7 @@ def test_process_nemotron_with_generic_llm_inputs(self): clip.windows = [window1, window2] video.clips = [clip] - task = VideoTask(task_id="test", dataset_name="test", data=video) + task = VideoTask(dataset_name="test", data=video) result = stage.process(task) @@ -465,7 +465,7 @@ def _make_task(video_bytes: bytes, task_id: str = "integration-test") -> VideoTa ) video = Video(input_video=Path(task_id + ".mp4")) video.clips = [clip] - return VideoTask(task_id=task_id, dataset_name="integration", data=video) + return VideoTask(dataset_name="integration", data=video) @pytest.fixture(scope="module") @@ -538,7 +538,7 @@ def run_pipeline( ) -> None: """Run prep→generation once and store state on the class for all tests.""" video_bytes = video_fixture_path.read_bytes() - task = _make_task(video_bytes, task_id="pipeline-test") + task = _make_task(video_bytes) # --- preparation stage --- task = preparation_stage.process(task) diff --git a/tests/stages/video/caption/test_caption_preparation.py b/tests/stages/video/caption/test_caption_preparation.py index 2f12356ba1..e6ee90af40 100644 --- a/tests/stages/video/caption/test_caption_preparation.py +++ b/tests/stages/video/caption/test_caption_preparation.py @@ -146,7 +146,7 @@ def 
_create_test_video_task(self) -> VideoTask: video.clips = [clip1, clip2] - return VideoTask(task_id="test", dataset_name="test", data=video) + return VideoTask(dataset_name="test", data=video) @patch("nemo_curator.stages.video.caption.caption_preparation.windowing_utils.split_video_into_windows") @patch("nemo_curator.stages.video.caption.caption_preparation._get_prompt") @@ -237,7 +237,7 @@ def test_process_clip_without_buffer(self, mock_logger: Mock): # Mock the id attribute since original code uses clip.id but Clip only has uuid clip.id = clip.uuid video.clips = [clip] - task = VideoTask(task_id="test", dataset_name="test", data=video) + task = VideoTask(dataset_name="test", data=video) # Setup formatter self.stage.prompt_formatter = Mock() @@ -346,7 +346,7 @@ def test_process_multiple_windows_different_frame_ranges(self, mock_get_prompt: # Mock attributes for original code bugs/quirks clip.id = clip.uuid video.clips = [clip] - task = VideoTask(task_id="test", dataset_name="test", data=video) + task = VideoTask(dataset_name="test", data=video) result = self.stage.process(task) diff --git a/tests/stages/video/clipping/test_clip_frame_extraction.py b/tests/stages/video/clipping/test_clip_frame_extraction.py index 67ed65ae62..a5f11f177a 100644 --- a/tests/stages/video/clipping/test_clip_frame_extraction.py +++ b/tests/stages/video/clipping/test_clip_frame_extraction.py @@ -68,7 +68,7 @@ def setup_method(self) -> None: clips=self.mock_clips, ) - self.mock_task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=self.mock_video) + self.mock_task = VideoTask(dataset_name="test_dataset", data=self.mock_video) def test_name_property(self) -> None: """Test that the name property returns the correct value.""" @@ -178,7 +178,6 @@ def test_process_successful_extraction(self, mock_extract_frames: Mock) -> None: # With target_fps=[2, 4], LCM=4, so extract_frames is called once per clip with buffer assert mock_extract_frames.call_count == 2 # 2 clips with buffers 
assert isinstance(result, VideoTask) - assert result.task_id == "test_task" # Check that extracted frames are stored with correct signatures processed_clips = result.data.clips @@ -249,7 +248,6 @@ def test_process_with_lcm_optimization(self, mock_extract_frames: Mock) -> None: # Create task with clips that have buffers task = VideoTask( - task_id="test_task", dataset_name="test_dataset", data=Video( input_video="test_video.mp4", @@ -295,7 +293,6 @@ def test_process_without_lcm_optimization(self, mock_extract_frames: Mock) -> No # Create task with clips that have buffers task = VideoTask( - task_id="test_task", dataset_name="test_dataset", data=Video( input_video="test_video.mp4", @@ -329,7 +326,6 @@ def test_process_multiple_extraction_policies(self, mock_extract_frames: Mock) - # Create task with clips that have buffers task = VideoTask( - task_id="test_task", dataset_name="test_dataset", data=Video( input_video="test_video.mp4", @@ -372,7 +368,6 @@ def test_process_verbose_logging(self, mock_extract_frames: Mock) -> None: # Create task with clips that have buffers task = VideoTask( - task_id="test_task", dataset_name="test_dataset", data=Video( input_video="test_video.mp4", @@ -392,9 +387,7 @@ def test_process_verbose_logging(self, mock_extract_frames: Mock) -> None: def test_process_no_clips(self) -> None: """Test processing when video has no clips.""" - task = VideoTask( - task_id="test_task", dataset_name="test_dataset", data=Video(input_video="test_video.mp4", clips=[]) - ) + task = VideoTask(dataset_name="test_dataset", data=Video(input_video="test_video.mp4", clips=[])) self.stage.setup() result = self.stage.process(task) @@ -419,7 +412,6 @@ def test_process_different_error_types(self) -> None: # Create task with one clip task = VideoTask( - task_id="test_task", dataset_name="test_dataset", data=Video( input_video="test_video.mp4", @@ -503,9 +495,7 @@ def test_clip_uuid_preservation(self) -> None: original_uuid = uuid.uuid4() clip = Clip(uuid=original_uuid, 
source_video="test_video.mp4", span=(0.0, 5.0), buffer=b"fake_video_data") - task = VideoTask( - task_id="test_task", dataset_name="test_dataset", data=Video(input_video="test_video.mp4", clips=[clip]) - ) + task = VideoTask(dataset_name="test_dataset", data=Video(input_video="test_video.mp4", clips=[clip])) stage = ClipFrameExtractionStage(target_fps=[2]) stage.setup() diff --git a/tests/stages/video/clipping/test_clip_transcoding_stage.py b/tests/stages/video/clipping/test_clip_transcoding_stage.py index 35a054adf5..d6b8c23fac 100644 --- a/tests/stages/video/clipping/test_clip_transcoding_stage.py +++ b/tests/stages/video/clipping/test_clip_transcoding_stage.py @@ -80,7 +80,7 @@ def setup_method(self) -> None: clips=copy.deepcopy(self.mock_clips), ) - self.mock_task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=self.mock_video) + self.mock_task = VideoTask(dataset_name="test_dataset", data=self.mock_video) def test_name_property(self) -> None: """Test that the name property returns the correct value.""" @@ -180,7 +180,6 @@ def test_process_multiple_chunks(self, mock_split: MagicMock, mock_temp_dir: Mag # Verify task properties for i, task in enumerate(result): assert isinstance(task, VideoTask) - assert task.task_id == f"test_task_chunk_{i}" assert task.data.num_total_clips == len(self.mock_clips) assert task.data.num_clip_chunks == 2 assert task.data.clip_chunk_index == i @@ -563,7 +562,7 @@ def test_edge_case_empty_clips(self) -> None: clips=[], ) - empty_task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=empty_video) + empty_task = VideoTask(dataset_name="test_dataset", data=empty_video) with patch("nemo_curator.stages.video.clipping.clip_extraction_stages.logger") as mock_logger: result = stage.process(empty_task) diff --git a/tests/stages/video/clipping/test_fixed_stride_extractor_stage.py b/tests/stages/video/clipping/test_fixed_stride_extractor_stage.py index d43fc3ecb5..b40ab893fa 100644 --- 
a/tests/stages/video/clipping/test_fixed_stride_extractor_stage.py +++ b/tests/stages/video/clipping/test_fixed_stride_extractor_stage.py @@ -49,7 +49,7 @@ def setup_method(self): clips=[], ) - self.mock_task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=self.mock_video) + self.mock_task = VideoTask(dataset_name="test_dataset", data=self.mock_video) def test_name_property(self): """Test that the name property returns the correct value.""" @@ -68,7 +68,6 @@ def test_process_successful_extraction(self): result = self.stage.process(self.mock_task) assert isinstance(result, VideoTask) - assert result.task_id == "test_task" assert result.dataset_name == "test_dataset" assert len(result.data.clips) > 0 @@ -161,7 +160,7 @@ def test_clip_generation_logic(self): clips=[], ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) result = self.stage.process(task) @@ -217,7 +216,7 @@ def test_min_clip_length_filtering(self): clips=[], ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) result = stage.process(task) @@ -252,7 +251,7 @@ def test_limit_clips_enforcement(self): clip = Clip(uuid=uuid.uuid4(), source_video="test_video.mp4", span=(i * 1.0, (i + 1) * 1.0)) video.clips.append(clip) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) result = stage.process(task) @@ -291,7 +290,7 @@ def test_edge_case_very_short_video(self): clips=[], ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) result = self.stage.process(task) @@ -315,7 +314,7 @@ def test_edge_case_exact_clip_length(self): clips=[], ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = 
VideoTask(dataset_name="test_dataset", data=video) result = self.stage.process(task) @@ -391,7 +390,7 @@ def test_different_parameter_combinations(self): clips=[], ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) result = stage.process(task) @@ -427,7 +426,7 @@ def test_limit_clips_generation(self): clips=[], ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) result = stage.process(task) @@ -452,7 +451,7 @@ def test_metadata_validation_edge_cases(self): clips=[], ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) result = self.stage.process(task) @@ -478,7 +477,7 @@ def test_negative_duration_calculation(self): clips=[], ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) result = self.stage.process(task) diff --git a/tests/stages/video/clipping/test_transnetv2_extraction.py b/tests/stages/video/clipping/test_transnetv2_extraction.py index 51b37d21e7..98d571152d 100644 --- a/tests/stages/video/clipping/test_transnetv2_extraction.py +++ b/tests/stages/video/clipping/test_transnetv2_extraction.py @@ -72,7 +72,7 @@ def setup_method(self): frame_array=rng.integers(0, 255, (900, 27, 48, 3), dtype=np.uint8), ) - self.mock_task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=self.mock_video) + self.mock_task = VideoTask(dataset_name="test_dataset", data=self.mock_video) def test_name_property(self): """Test the name property.""" @@ -187,7 +187,7 @@ def test_process_no_metadata(self): frame_array=rng.integers(0, 255, (100, 27, 48, 3), dtype=np.uint8), ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video_without_metadata) + task = 
VideoTask(dataset_name="test_dataset", data=video_without_metadata) with patch("nemo_curator.models.transnetv2.TransNetV2"): self.stage.setup() @@ -220,7 +220,7 @@ def test_process_no_framerate(self): frame_array=rng.integers(0, 255, (100, 27, 48, 3), dtype=np.uint8), ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video_no_framerate) + task = VideoTask(dataset_name="test_dataset", data=video_no_framerate) with patch("nemo_curator.models.transnetv2.TransNetV2"): self.stage.setup() @@ -250,7 +250,7 @@ def test_process_no_frame_array(self): frame_array=None, ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video_no_frames) + task = VideoTask(dataset_name="test_dataset", data=video_no_frames) with patch("nemo_curator.models.transnetv2.TransNetV2"): self.stage.setup() @@ -279,7 +279,7 @@ def test_process_wrong_frame_shape(self): frame_array=rng.integers(0, 255, (100, 28, 48, 3), dtype=np.uint8), # Wrong height ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video_wrong_shape) + task = VideoTask(dataset_name="test_dataset", data=video_wrong_shape) with patch("nemo_curator.models.transnetv2.TransNetV2"): self.stage.setup() @@ -685,7 +685,7 @@ def test_complete_pipeline_integration(self, mock_transnetv2_class: Mock): frame_array=frames, ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) # Mock model mock_model = Mock() @@ -694,7 +694,9 @@ def test_complete_pipeline_integration(self, mock_transnetv2_class: Mock): # Setup and process stage.setup() - with patch("nemo_curator.stages.video.clipping.transnetv2_extraction._get_predictions") as mock_get_predictions: + with patch( + "nemo_curator.stages.video.clipping.transnetv2_extraction._get_predictions" + ) as mock_get_predictions: # Mock predictions to create some transitions mock_get_predictions.return_value = np.array([[0], [1], [0], [0], [1], [0]] * 
25, dtype=np.uint8) diff --git a/tests/stages/video/clipping/test_video_frame_extraction.py b/tests/stages/video/clipping/test_video_frame_extraction.py index 9f68a02223..2a1600edb4 100644 --- a/tests/stages/video/clipping/test_video_frame_extraction.py +++ b/tests/stages/video/clipping/test_video_frame_extraction.py @@ -330,7 +330,7 @@ def test_process_no_pynvc_extractor(self, mock_logger: Any, mock_get_frames: Any framerate=30, width=640, height=480, duration=10.0, video_codec="h264", num_frames=300 ), ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) result = stage.process(task) @@ -355,7 +355,7 @@ def test_process_no_source_bytes(self) -> None: framerate=30, width=640, height=480, duration=10.0, video_codec="h264", num_frames=300 ), ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) with pytest.raises(ValueError, match="Video source bytes are not available"): stage.process(task) @@ -371,7 +371,7 @@ def test_process_incomplete_metadata(self, mock_logger: Any) -> None: source_bytes=b"fake_video_data", metadata=VideoMetadata(framerate=30, width=640, height=480), # Missing required fields ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) result = stage.process(task) @@ -408,7 +408,7 @@ def test_process_pynvc_success(self, mock_temp_file: Any) -> None: framerate=30, width=640, height=480, duration=10.0, video_codec="h264", num_frames=300 ), ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) result = stage.process(task) @@ -449,7 +449,7 @@ def test_process_pynvc_exception_fallback( framerate=30, width=640, height=480, duration=10.0, video_codec="h264", num_frames=300 ), ) - task = 
VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) result = stage.process(task) @@ -497,7 +497,7 @@ def test_process_ffmpeg_mode(self, mock_logger: Any, mock_get_frames: Any, mock_ framerate=30, width=640, height=480, duration=10.0, video_codec="h264", num_frames=300 ), ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) result = stage.process(task) @@ -544,7 +544,7 @@ def test_process_ffmpeg_gpu_mode(self, mock_logger: Any, mock_get_frames: Any, m framerate=30, width=640, height=480, duration=10.0, video_codec="h264", num_frames=300 ), ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) result = stage.process(task) @@ -592,7 +592,7 @@ def test_process_frame_extraction_failure( framerate=30, width=640, height=480, duration=10.0, video_codec="h264", num_frames=300 ), ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) result = stage.process(task) @@ -629,7 +629,7 @@ def test_process_verbose_mode(self, mock_logger: Any, mock_get_frames: Any, mock framerate=30, width=640, height=480, duration=10.0, video_codec="h264", num_frames=300 ), ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) result = stage.process(task) diff --git a/tests/stages/video/embedding/test_cosmos_embed1.py b/tests/stages/video/embedding/test_cosmos_embed1.py index 10b7143b07..fb0a1e8f93 100644 --- a/tests/stages/video/embedding/test_cosmos_embed1.py +++ b/tests/stages/video/embedding/test_cosmos_embed1.py @@ -72,7 +72,7 @@ def setup_method(self): ], ) - self.mock_task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=self.mock_video) + 
self.mock_task = VideoTask(dataset_name="test_dataset", data=self.mock_video) def test_name_property(self): """Test the name property.""" @@ -302,7 +302,7 @@ def setup_method(self): ], ) - self.mock_task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=self.mock_video) + self.mock_task = VideoTask(dataset_name="test_dataset", data=self.mock_video) def test_name_property(self): """Test the name property.""" @@ -499,7 +499,7 @@ def test_process_without_text_verification(self, mock_cosmos_embed1: "MagicMock" ], ) - fresh_task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=fresh_video) + fresh_task = VideoTask(dataset_name="test_dataset", data=fresh_video) result = stage.process(fresh_task) diff --git a/tests/stages/video/filtering/test_clip_aesthetic_filter.py b/tests/stages/video/filtering/test_clip_aesthetic_filter.py index fbfaa9f5ea..5e6430b9ae 100644 --- a/tests/stages/video/filtering/test_clip_aesthetic_filter.py +++ b/tests/stages/video/filtering/test_clip_aesthetic_filter.py @@ -96,7 +96,7 @@ def setup_method(self) -> None: clip_chunk_index=0, ) - self.mock_task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=self.mock_video) + self.mock_task = VideoTask(dataset_name="test_dataset", data=self.mock_video) def test_stage_initialization_defaults(self) -> None: """Test stage initialization with default parameters.""" @@ -317,7 +317,7 @@ def test_process_clip_without_buffer(self) -> None: clip_stats=ClipStats(), clip_chunk_index=0, ) - task = VideoTask(task_id="test", dataset_name="test", data=video) + task = VideoTask(dataset_name="test", data=video) result = self.stage.process(task) @@ -353,7 +353,7 @@ def test_process_clip_without_extracted_frames(self) -> None: clip_stats=ClipStats(), clip_chunk_index=0, ) - task = VideoTask(task_id="test", dataset_name="test", data=video) + task = VideoTask(dataset_name="test", data=video) result = self.stage.process(task) @@ -443,7 +443,7 @@ def 
test_process_empty_clips_list(self) -> None: clip_stats=ClipStats(), clip_chunk_index=0, ) - empty_task = VideoTask(task_id="test", dataset_name="test", data=empty_video) + empty_task = VideoTask(dataset_name="test", data=empty_video) self.stage.model = Mock() self.stage.reduction_fn = np.min diff --git a/tests/stages/video/filtering/test_motion_filter.py b/tests/stages/video/filtering/test_motion_filter.py index f08d30b42f..02a9d02f04 100644 --- a/tests/stages/video/filtering/test_motion_filter.py +++ b/tests/stages/video/filtering/test_motion_filter.py @@ -70,7 +70,7 @@ def setup_method(self): ], ) - self.mock_task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=self.mock_video) + self.mock_task = VideoTask(dataset_name="test_dataset", data=self.mock_video) def test_name_property(self): """Test the name property.""" @@ -250,7 +250,7 @@ def setup_method(self): clip_stats=ClipStats(), ) - self.mock_task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=self.mock_video) + self.mock_task = VideoTask(dataset_name="test_dataset", data=self.mock_video) def test_name_property(self): """Test the name property.""" @@ -493,7 +493,6 @@ def test_process_integration_with_video_task(self): # Check that the task structure is maintained assert isinstance(result, VideoTask) - assert result.task_id == "test_task" assert result.dataset_name == "test_dataset" assert isinstance(result.data, Video) assert result.data.input_video == pathlib.Path("test_video.mp4") diff --git a/tests/stages/video/io/test_clip_writer.py b/tests/stages/video/io/test_clip_writer.py index ad2829c96c..0b94e52f09 100644 --- a/tests/stages/video/io/test_clip_writer.py +++ b/tests/stages/video/io/test_clip_writer.py @@ -127,7 +127,6 @@ def setup_method(self): # Create mock task self.mock_task = VideoTask( - task_id="test_task", dataset_name="test_dataset", data=self.mock_video, ) @@ -760,7 +759,6 @@ def test_edge_cases_empty_clips(self): clip_chunk_index=0, ) empty_task = 
VideoTask( - task_id="empty_task", dataset_name="test_dataset", data=empty_video, ) diff --git a/tests/stages/video/io/test_video_reader.py b/tests/stages/video/io/test_video_reader.py index 6ab9a05ace..cbafeeaba4 100644 --- a/tests/stages/video/io/test_video_reader.py +++ b/tests/stages/video/io/test_video_reader.py @@ -153,7 +153,7 @@ def test_log_video_info(self) -> None: def test_process_success(self) -> None: """Test process method with successful execution.""" file_path = "/test/video.mp4" - file_group_task = FileGroupTask(task_id="test_task", dataset_name="test_dataset", data=[file_path]) + file_group_task = FileGroupTask(dataset_name="test_dataset", data=[file_path]) with ( patch.object(VideoReaderStage, "_download_video_bytes", return_value=True), @@ -163,7 +163,6 @@ def test_process_success(self) -> None: result = stage.process(file_group_task) assert isinstance(result, VideoTask) - assert result.task_id == f"{file_path}_processed" assert result.dataset_name == "test_dataset" assert isinstance(result.data, Video) assert result.data.input_video == file_path @@ -171,19 +170,18 @@ def test_process_success(self) -> None: def test_process_download_failure(self) -> None: """Test process method when download fails.""" file_path = "/test/video.mp4" - file_group_task = FileGroupTask(task_id="test_task", dataset_name="test_dataset", data=[file_path]) + file_group_task = FileGroupTask(dataset_name="test_dataset", data=[file_path]) with patch.object(VideoReaderStage, "_download_video_bytes", return_value=False): stage = VideoReaderStage() result = stage.process(file_group_task) assert isinstance(result, VideoTask) - assert result.task_id == f"{file_path}_processed" def test_process_metadata_failure(self) -> None: """Test process method when metadata extraction fails.""" file_path = "/test/video.mp4" - file_group_task = FileGroupTask(task_id="test_task", dataset_name="test_dataset", data=[file_path]) + file_group_task = FileGroupTask(dataset_name="test_dataset", 
data=[file_path]) with ( patch.object(VideoReaderStage, "_download_video_bytes", return_value=True), @@ -193,7 +191,6 @@ def test_process_metadata_failure(self) -> None: result = stage.process(file_group_task) assert isinstance(result, VideoTask) - assert result.task_id == f"{file_path}_processed" def test_process_preserves_metadata(self) -> None: """Test process method preserves task metadata and stage performance.""" @@ -202,7 +199,6 @@ def test_process_preserves_metadata(self) -> None: original_stage_perf = [{"stage": "prev_stage", "time": 1.0}] file_group_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", data=[file_path], _metadata=original_metadata, @@ -222,7 +218,7 @@ def test_process_preserves_metadata(self) -> None: def test_process_with_verbose_logging(self) -> None: """Test process method enables verbose logging when configured.""" file_path = "/test/video.mp4" - file_group_task = FileGroupTask(task_id="test_task", dataset_name="test_dataset", data=[file_path]) + file_group_task = FileGroupTask(dataset_name="test_dataset", data=[file_path]) with ( patch.object(VideoReaderStage, "_download_video_bytes", return_value=True), @@ -407,21 +403,19 @@ def test_process_creates_correct_task_id(self) -> None: ] for file_path in test_cases: - file_group_task = FileGroupTask(task_id="original_task", dataset_name="test_dataset", data=[file_path]) + file_group_task = FileGroupTask(dataset_name="test_dataset", data=[file_path]) with ( patch.object(VideoReaderStage, "_download_video_bytes", return_value=True), patch.object(VideoReaderStage, "_extract_and_validate_metadata", return_value=True), ): stage = VideoReaderStage() - result = stage.process(file_group_task) - - assert result.task_id == f"{file_path}_processed" + stage.process(file_group_task) def test_process_without_verbose_no_logging(self) -> None: """Test process method doesn't call _log_video_info when verbose is False.""" file_path = "/test/video.mp4" - file_group_task = 
FileGroupTask(task_id="test_task", dataset_name="test_dataset", data=[file_path]) + file_group_task = FileGroupTask(dataset_name="test_dataset", data=[file_path]) with ( patch.object(VideoReaderStage, "_download_video_bytes", return_value=True), @@ -447,7 +441,7 @@ def test_stage_default_verbose_setting(self) -> None: def test_video_task_data_structure(self) -> None: """Test that created VideoTask has correct data structure.""" file_path = "/test/video.mp4" - file_group_task = FileGroupTask(task_id="test_task", dataset_name="test_dataset", data=[file_path]) + file_group_task = FileGroupTask(dataset_name="test_dataset", data=[file_path]) with ( patch.object(VideoReaderStage, "_download_video_bytes", return_value=True), @@ -488,7 +482,7 @@ def test_metadata_extraction_failure_logging(self) -> None: def test_process_with_various_file_extensions(self, file_extension: str) -> None: """Test process method works with various video file extensions.""" file_path = f"/test/video{file_extension}" - file_group_task = FileGroupTask(task_id="test_task", dataset_name="test_dataset", data=[file_path]) + file_group_task = FileGroupTask(dataset_name="test_dataset", data=[file_path]) with ( patch.object(VideoReaderStage, "_download_video_bytes", return_value=True), @@ -507,7 +501,6 @@ def test_deepcopy_preservation(self) -> None: nested_stage_perf = [{"stage": "prev", "nested": {"time": 1.0}}] file_group_task = FileGroupTask( - task_id="test_task", dataset_name="test_dataset", data=[file_path], _metadata=nested_metadata, @@ -544,8 +537,9 @@ def test_nonexistent_video_dir_raises(self) -> None: """Test that VideoReader raises FileNotFoundError for non-existent local path.""" mock_instance = mock.Mock() mock_instance.exists.return_value = False - with patch("nemo_curator.stages.video.io.video_reader.Path", return_value=mock_instance), pytest.raises( - FileNotFoundError, match="Video directory does not exist" + with ( + patch("nemo_curator.stages.video.io.video_reader.Path", 
return_value=mock_instance), + pytest.raises(FileNotFoundError, match="Video directory does not exist"), ): VideoReader(input_video_path="/nonexistent/path") @@ -555,8 +549,9 @@ def test_empty_video_dir_raises(self) -> None: mock_instance.exists.return_value = True mock_instance.is_file.return_value = False mock_instance.rglob.side_effect = lambda *_: iter([]) - with patch("nemo_curator.stages.video.io.video_reader.Path", return_value=mock_instance), pytest.raises( - FileNotFoundError, match="No video files found" + with ( + patch("nemo_curator.stages.video.io.video_reader.Path", return_value=mock_instance), + pytest.raises(FileNotFoundError, match="No video files found"), ): VideoReader(input_video_path="/empty/dir") @@ -576,8 +571,9 @@ def test_single_non_video_file_raises(self) -> None: mock_instance.exists.return_value = True mock_instance.is_file.return_value = True mock_instance.suffix = ".txt" - with patch("nemo_curator.stages.video.io.video_reader.Path", return_value=mock_instance), pytest.raises( - FileNotFoundError, match=r"Not a supported video file.*Supported formats" + with ( + patch("nemo_curator.stages.video.io.video_reader.Path", return_value=mock_instance), + pytest.raises(FileNotFoundError, match=r"Not a supported video file.*Supported formats"), ): VideoReader(input_video_path="/data/document.txt") @@ -613,7 +609,7 @@ def test_stage_properties(self) -> None: # Test that it's a composite stage (should raise error when trying to process) from nemo_curator.tasks import _EmptyTask - empty_task = _EmptyTask(task_id="test", dataset_name="test", data=None) + empty_task = _EmptyTask(dataset_name="test", data=None) with pytest.raises(RuntimeError, match="Composite stage 'video_reader' should not be executed directly"): stage.process(empty_task) diff --git a/tests/stages/video/preview/test_preview.py b/tests/stages/video/preview/test_preview.py index 49c50d048f..7eae3ab31e 100644 --- a/tests/stages/video/preview/test_preview.py +++ 
b/tests/stages/video/preview/test_preview.py @@ -86,7 +86,7 @@ def test_process_with_adequate_metadata(self): video = Video(input_video=pathlib.Path("test.mp4"), metadata=video_metadata, clips=[clip]) # Create video task - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) stage = PreviewStage() @@ -110,7 +110,7 @@ def test_process_with_low_framerate_warning(self): video = Video(input_video=pathlib.Path("test.mp4"), metadata=video_metadata, clips=[]) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) stage = PreviewStage() @@ -134,7 +134,7 @@ def test_process_with_low_height_warning(self): video = Video(input_video=pathlib.Path("test.mp4"), metadata=video_metadata, clips=[]) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) stage = PreviewStage() @@ -165,7 +165,7 @@ def test_process_multiple_clips_and_windows(self): clips=[clip1, clip2], ) - task = VideoTask(task_id="test_task", dataset_name="test_dataset", data=video) + task = VideoTask(dataset_name="test_dataset", data=video) stage = PreviewStage() diff --git a/tests/tasks/test_file_group_tasks.py b/tests/tasks/test_file_group_tasks.py new file mode 100644 index 0000000000..c317a0378d --- /dev/null +++ b/tests/tasks/test_file_group_tasks.py @@ -0,0 +1,36 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_curator.tasks import FileGroupTask + + +class TestFileGroupTask: + def test_deterministic_ids(self) -> None: + """``get_deterministic_id`` hashes the sorted file paths, so it is: + order-independent, distinct for distinct file sets, and a 12-char + hex string.""" + # Order-independent: same files in different orders → same id. + a = FileGroupTask(dataset_name="d", data=["b.parquet", "a.parquet"]) + b = FileGroupTask(dataset_name="d", data=["a.parquet", "b.parquet"]) + assert a.get_deterministic_id() == b.get_deterministic_id() + + # Distinct file sets → distinct ids. + c = FileGroupTask(dataset_name="d", data=["c.parquet"]) + assert a.get_deterministic_id() != c.get_deterministic_id() + + # 12-char hex string. + result = c.get_deterministic_id() + assert isinstance(result, str) + assert len(result) == 12 + assert all(ch in "0123456789abcdef" for ch in result) diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index e76959c8e6..526a0aabc6 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -44,28 +44,68 @@ def outputs(self) -> tuple[list[str], list[str]]: return [], [] def process(self, task: SimpleTask) -> list[SimpleTask]: - # Important: construct fresh Task objects so each gets a fresh _uuid + # Construct fresh Task objects. task_id is assigned later by the + # executor adapter, not by process()/process_batch. 
return [ SimpleTask( - task_id=f"{task.task_id}_{i}", dataset_name=task.dataset_name, data=task.data, _metadata=task._metadata.copy(), _stage_perf=task._stage_perf.copy(), ) - for i in range(self.times) + for _ in range(self.times) ] def _sample_task() -> SimpleTask: - return SimpleTask(task_id="t0", dataset_name="test", data=[1, 2, 3]) + return SimpleTask(dataset_name="test", data=[1, 2, 3]) -def test_fanout_tasks_have_unique_uuid(): +def test_default_process_batch_does_not_assign_task_id(): + """``process_batch`` (and ``process``) do not touch ``task_id`` — that's + the executor adapter's job (``BaseStageAdapter._post_process_task_ids``). + So fanned-out children come back here with empty ids; task_id assignment + is covered in tests/backends/test_task_id_postprocess.py.""" task = _sample_task() - stage = Repeat(times=3) - output = stage.process(task) + output = Repeat(times=3).process_batch([task]) assert len(output) == 3 - uuids = [t._uuid for t in output] - assert len(set(uuids)) == 3, f"Expected unique _uuid per task, got {uuids}" + assert all(t.task_id == "" for t in output) + + +class TestSetTaskId: + """``Task._set_task_id``: the id is the parent id and this task's own + segment joined by ``"_"`` (no hashing).""" + + def test_no_parent_uses_suffix_only(self) -> None: + t = _sample_task() + t._set_task_id("", 3) + # An empty parent id is dropped, so no leading "_". 
+ assert t.task_id == "3" + + def test_joins_parent_and_suffix(self) -> None: + t = _sample_task() + t._set_task_id("0", 7) + assert t.task_id == "0_7" + + def test_always_overwrites(self) -> None: + """No idempotency — each stage boundary re-derives the id, so one + object passing through N stages gets N distinct task_ids.""" + t = _sample_task() + t._set_task_id("", 0) + t._set_task_id("0", 7) + assert t.task_id == "0_7" + + def test_string_suffix(self) -> None: + """Source stages pass a content-based hash (str) as the suffix + instead of a positional index.""" + t = _sample_task() + t._set_task_id("root", "abc123") + assert t.task_id == "root_abc123" + + +def test_get_deterministic_id_defaults_to_none(): + """Base ``Task`` has no content identity, so source stages fall back to + the positional index. ``FileGroupTask`` overrides this — see + tests/tasks/test_file_group_tasks.py.""" + assert _sample_task().get_deterministic_id() is None diff --git a/tests/tasks/test_utils.py b/tests/tasks/test_utils.py index 62f3120abc..b475d4c343 100644 --- a/tests/tasks/test_utils.py +++ b/tests/tasks/test_utils.py @@ -22,7 +22,7 @@ def make_dummy_task(stage_name: str, process_time: float, custom: float = 0.0) -> _EmptyTask: perf = StagePerfStats(stage_name=stage_name, process_time=process_time, custom_metrics={"io": custom}) - return _EmptyTask(task_id=f"{stage_name}_{process_time}", dataset_name="test", data=None, _stage_perf=[perf]) + return _EmptyTask(dataset_name="test", data=None, _stage_perf=[perf]) class TestTaskPerfUtils: diff --git a/tests/tasks/test_video.py b/tests/tasks/test_video.py index a1442e162f..b1ce7119ae 100644 --- a/tests/tasks/test_video.py +++ b/tests/tasks/test_video.py @@ -537,12 +537,10 @@ def test_video_task_initialization(self) -> None: """Test VideoTask initialization.""" video_data = Video(input_video=pathlib.Path("test.mp4")) task = VideoTask( - task_id="test_task", dataset_name="test_dataset", data=video_data, ) - assert task.task_id == 
"test_task" assert task.dataset_name == "test_dataset" assert isinstance(task.data, Video) @@ -550,7 +548,6 @@ def test_video_task_initialization_with_data(self) -> None: """Test VideoTask initialization with video data.""" video_data = Video(input_video=pathlib.Path("test.mp4")) task = VideoTask( - task_id="test_task", dataset_name="test_dataset", data=video_data, ) @@ -565,7 +562,6 @@ def test_validate_existing_file(self) -> None: try: video_data = Video(input_video=tmp_path) task = VideoTask( - task_id="test_task", dataset_name="test_dataset", data=video_data, ) @@ -578,7 +574,6 @@ def test_validate_non_existing_file(self) -> None: """Test validate method with non-existing file.""" video_data = Video(input_video=pathlib.Path("non_existing_file.mp4")) task = VideoTask( - task_id="test_task", dataset_name="test_dataset", data=video_data, ) @@ -590,7 +585,6 @@ def test_num_items_property(self) -> None: """Test num_items property.""" video_data = Video(input_video=pathlib.Path("test.mp4")) task = VideoTask( - task_id="test_task", dataset_name="test_dataset", data=video_data, ) diff --git a/tests/stages/text/io/writer/test_utils.py b/tests/utils/test_hash_utils.py similarity index 97% rename from tests/stages/text/io/writer/test_utils.py rename to tests/utils/test_hash_utils.py index a074aa3965..932bb53b46 100644 --- a/tests/stages/text/io/writer/test_utils.py +++ b/tests/utils/test_hash_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from nemo_curator.stages.text.io.writer.utils import get_deterministic_hash +from nemo_curator.utils.hash_utils import get_deterministic_hash class TestGetDeterministicHash: diff --git a/tests/utils/test_merge_file_prefixes.py b/tests/utils/test_merge_file_prefixes.py index 131c67893a..85c3a8ceb7 100644 --- a/tests/utils/test_merge_file_prefixes.py +++ b/tests/utils/test_merge_file_prefixes.py @@ -64,13 +64,12 @@ def setup_mocks(mock_tokenizer: Mock): yield {"auto_tokenizer": mock_auto_tokenizer} -def _make_batch(task_id: str, texts: list[str]) -> DocumentBatch: - df = pd.DataFrame({"text": texts, "id": [f"{task_id}_{i}" for i in range(len(texts))]}) +def _make_batch(prefix: str, texts: list[str]) -> DocumentBatch: + df = pd.DataFrame({"text": texts, "id": [f"{prefix}_{i}" for i in range(len(texts))]}) return DocumentBatch( - task_id=task_id, dataset_name="test_dataset", data=df, - _metadata={"source_files": [f"{task_id}.jsonl"]}, + _metadata={"source_files": [f"{prefix}.jsonl"]}, ) @@ -112,7 +111,7 @@ def test_merge_file_prefixes_produces_expected_sizes( # Build batches with varying per-document token counts so sequence_lengths differ. batches = [ _make_batch( - task_id=f"batch_{i}", + prefix=f"batch_{i}", texts=[f"batch {i} doc {j} " + ("word " * (j + i + 1)) for j in range(3)], ) for i in range(num_batches) @@ -188,7 +187,7 @@ def test_merge_file_prefixes_single_prefix(self, tmpdir: str): os.makedirs(input_dir, exist_ok=True) batch = _make_batch( - task_id="only_batch", + prefix="single", texts=["hello world", "this is a longer test document", "tiny"], ) @@ -223,7 +222,7 @@ def test_merge_file_prefixes_missing_pair_raises(self, tmpdir: str): os.makedirs(input_dir, exist_ok=True) # A valid pair to prove the guard is checking each prefix, not just the directory as a whole. 
- batch = _make_batch(task_id="valid", texts=["some text here", "another doc"]) + batch = _make_batch(prefix="pair", texts=["some text here", "another doc"]) writer = MegatronTokenizerWriter(path=input_dir, model_identifier="test/model") writer.setup() writer.process(batch) diff --git a/tests/utils/test_split_large_files.py b/tests/utils/test_split_large_files.py index 8fb87c5c94..455f0dccd6 100644 --- a/tests/utils/test_split_large_files.py +++ b/tests/utils/test_split_large_files.py @@ -57,9 +57,7 @@ def _(num_row_groups: int = 1) -> pathlib.Path: def test_default_target_size(parquet_file_factory: Callable, tmp_path: pathlib.Path): parquet_file = parquet_file_factory() - args = parse_args( - ["--input-path", str(parquet_file), "--output-path", str(tmp_path), "--file-type", "parquet"] - ) + args = parse_args(["--input-path", str(parquet_file), "--output-path", str(tmp_path), "--file-type", "parquet"]) assert args.target_size_mb == 128 @@ -70,7 +68,9 @@ def test_split_parquet_file_by_size(parquet_file_factory: Callable, tmp_path: pa target_size_mb = size_original_mb / 3 output_path = tmp_path / "out" output_path.mkdir(exist_ok=True) - split_parquet_file_by_size._function(input_file=str(parquet_file), output_path=str(output_path), target_size_mb=target_size_mb) + split_parquet_file_by_size._function( + input_file=str(parquet_file), output_path=str(output_path), target_size_mb=target_size_mb + ) expected = pd.read_parquet(parquet_file) result = pd.read_parquet(output_path) @@ -96,7 +96,9 @@ def test_split_jsonl_file_by_size(tmp_path: pathlib.Path): target_size_mb = max(size_original_mb / 4, 1e-6) output_path = tmp_path / "out" output_path.mkdir(exist_ok=True) - split_jsonl_file_by_size._function(input_file=str(jsonl_file), output_path=str(output_path), target_size_mb=target_size_mb) + split_jsonl_file_by_size._function( + input_file=str(jsonl_file), output_path=str(output_path), target_size_mb=target_size_mb + ) files = sorted(output_path.glob("data_*.jsonl")) assert 
len(files) >= 2