Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions nemo_curator/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -84,6 +85,9 @@ def process_batch(self, tasks: list[Task]) -> list[Task]:
# Use the batch processing logic
results = self.stage.process_batch(tasks)

# Guarantee every emitted task has a task_id (derived id, or uuid fallback).
results = self._post_process_task_ids(tasks, results)

# Log performance stats and add to result tasks
_, stage_perf_stats = self._timer.log_stats()
# Consume and attach any custom metrics recorded by the stage during this call
Expand All @@ -95,6 +99,75 @@ def process_batch(self, tasks: list[Task]) -> list[Task]:

return results

def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Task | None]) -> list[Task]:
"""Assign a deterministic ``task_id`` to every emitted task.

This is the single place task ids are assigned — it runs for every
stage on every backend (all backend adapters subclass this), so it
makes no difference whether a stage defines ``process`` or overrides
``process_batch``. ``task_id`` is the task's id path (parents + own segment); ids are
re-derived at each stage boundary so the same object passing through
N stages gets N ids.

The input→output mapping decides each output's PARENT; whether the
stage is a source decides each output's SEGMENT (content id vs index)
— the two are independent. ``None`` outputs (Curator's "return None to
filter") are NOT removed before the length check — keeping them in
place preserves positional alignment for filter stages — and are then
dropped from the returned list.

- single input → every output is its child (fan-out): ``parent_<seg>``
- ``len(output) == len(input)`` → positional 1:1: each ``parent_i_<seg>``;
a ``None`` slot just means input ``i`` was filtered.
- any other (ambiguous) cardinality across a batch → a random ``uuid``
prefixed with ``"r"`` (e.g. ``"r3f9a…"``), so ``task_id`` is never
empty even when a derived id is not possible. The ``"r"`` prefix flags
the id as non-deterministic / ancestry-not-tracked (see
``Task.task_id`` docstring).

``seg`` is the output's content id (``Task.get_deterministic_id()``)
for a source stage when available, else the positional index — so a
source partition keeps a stable id across reorderings regardless of
whether the source is 1→N or N→N.

Note: a stage that BOTH filters and fans out within a single batch
(returning a flat list rather than a per-input slot) cannot be mapped
positionally; if its length happens to equal the input length the 1:1
assumption may misattribute parents. That combination is unsupported
until per-slot sentinels (NoneTask/FailedTask) land in a later PR.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are NoneTask's or can we have None?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can have None, we repalce None with NoneTask anyways in the next PR> not sure what is question is though

"""
is_source = getattr(self.stage, "is_source_stage", False)

if len(input_tasks) == 1:
# Fan-out (incl. a source reading from EmptyTask): every non-None
# output is a child of the single input.
parent_id = input_tasks[0].task_id
out: list[Task] = [t for t in output_tasks if t is not None]
for i, task in enumerate(out):
suffix = (task.get_deterministic_id() or i) if is_source else i

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If something was not a source and wanted to still override get_deterministic_id do we not want to support that?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope. I think the current code assumption is that only the source can have a deterministic ID. This makes a few assumptions easy; all other IDs will be numbers except one. Very easy to say which one is the source ID (We do store the source ID as well). Finally, I think there is no specific purpose that deterministic_ids are serving. Later down the road, we might implement a uniqueness check for the source ID, but doing it for 2 ids would be much more difficult.

task._set_task_id(parent_id, suffix)
return out

if len(output_tasks) == len(input_tasks):
# Positional 1:1. None is kept above so a filtered slot still lines
# up with its own parent; drop the None slots from the result.
out = []
for parent, task in zip(input_tasks, output_tasks, strict=True):
if task is None:
continue
suffix = (task.get_deterministic_id() or 0) if is_source else 0
task._set_task_id(parent.task_id, suffix)
out.append(task)
return out

# Ambiguous cardinality across a batch: a derived id is not possible. Use a
# random "r"-prefixed uuid so task_id is non-empty but clearly flagged
# non-deterministic.
out = [t for t in output_tasks if t is not None]
for task in out:
task.task_id = "r" + uuid.uuid4().hex
return out

def setup_on_node(self, node_info: NodeInfo | None = None, worker_metadata: WorkerMetadata | None = None) -> None:
"""Setup the stage on a node.

Expand Down
1 change: 1 addition & 0 deletions nemo_curator/models/transnetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
_TRANSNETV2_MODEL_WEIGHTS: Final = "transnetv2-pytorch-weights.pth"
_TRANSNETV2_MODEL_REVISION: Final = "db6ceab"


class _TransNetV2(nn.Module):
def __init__( # noqa: PLR0913
self,
Expand Down
52 changes: 52 additions & 0 deletions nemo_curator/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,31 @@
from nemo_curator.backends.base import BaseExecutor
from nemo_curator.stages.base import CompositeStage, ProcessingStage
from nemo_curator.tasks import Task
from nemo_curator.tasks.tasks import _EmptyTask


def assign_root_task_ids(initial_tasks: list[Task]) -> list[Task]:
"""Assign root ``task_id``s to user-provided initial tasks.

Every task in a run descends from the implicit root ``"0"`` (the id of
:class:`_EmptyTask`). User-provided initial tasks are its direct
children, so they get ``"0_0"``, ``"0_1"``, … ``_EmptyTask`` instances
are skipped (already ``"0"``). All downstream ``task_id`` assignment
happens in ``BaseStageAdapter``.

NOTE: we deliberately use the positional index here, NOT
``get_deterministic_id()``, even for content-bearing tasks like
``FileGroupTask``. The source stage is the single place content-based
ids are assigned (to its outputs); hashing here too would put the
content hash at two levels of the id path (``"0_<hashA>_<hashB>"``).
Passing initial tasks directly is rare; if you need reorder-stable
source ids, let a source stage emit them.
"""
for i, task in enumerate(initial_tasks):
if isinstance(task, _EmptyTask):
Comment thread
abhinavg4 marked this conversation as resolved.
continue
task._set_task_id("0", i)
return initial_tasks


class Pipeline:
Expand Down Expand Up @@ -80,6 +105,30 @@ def build(self) -> None:
self.stages = execution_stages
self.decomposition_info = decomposition_info

# 3. Source / sink defaults: at most one stage may be explicitly
# marked; if none, the first stage is the source and the last is
# the sink. The source flag activates content-based ids in the
# default ``process_batch``; the sink flag is used by the
# resumability layer in a follow-up PR.
self._assign_source_sink_roles()

def _assign_source_sink_roles(self) -> None:
explicit_sources = [s for s in self.stages if s.is_source_stage]
if len(explicit_sources) > 1:
names = [s.name for s in explicit_sources]
msg = f"Pipeline has multiple source stages marked: {names}. At most one is supported."
raise ValueError(msg)
if not explicit_sources:
self.stages[0].is_source_stage = True

explicit_sinks = [s for s in self.stages if s.is_sink_stage]
if len(explicit_sinks) > 1:
names = [s.name for s in explicit_sinks]
msg = f"Pipeline has multiple sink stages marked: {names}. At most one is supported."
raise ValueError(msg)
if not explicit_sinks:
self.stages[-1].is_sink_stage = True

def _decompose_stages(
self, stages: list[ProcessingStage | CompositeStage]
) -> tuple[list[ProcessingStage], dict[str, list[str]]]:
Expand Down Expand Up @@ -212,4 +261,7 @@ def run(self, executor: BaseExecutor | None = None, initial_tasks: list[Task] |
"The executor will schedule GPU stages on GPUs not held by Serve."
)

if initial_tasks:
assign_root_task_ids(initial_tasks)

return executor.execute(self.stages, initial_tasks)
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,15 @@ def _validate(cfg: dict[str, Any]) -> None: # noqa: C901

sigmos = cfg.get("sigmos", {})
if sigmos.get("enable", True):
for key in ("noise_threshold", "ovrl_threshold", "sig_threshold",
"col_threshold", "disc_threshold", "loud_threshold", "reverb_threshold"):
for key in (
"noise_threshold",
"ovrl_threshold",
"sig_threshold",
"col_threshold",
"disc_threshold",
"loud_threshold",
"reverb_threshold",
):
val = sigmos.get(key)
if val is not None and not 0.0 <= val <= _MOS_MAX:
msg = f"sigmos.{key} must be in [0, {_MOS_MAX}] (MOS scale), got {val}"
Expand Down
14 changes: 3 additions & 11 deletions nemo_curator/stages/audio/alm/pretrain/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,7 @@ def process(self, task: AudioTask) -> list[AudioTask]:
)
return outputs

def _extract_emit(
self, task: AudioTask, plan: list[dict], original_id: str
) -> list[AudioTask]:
def _extract_emit(self, task: AudioTask, plan: list[dict], original_id: str) -> list[AudioTask]:
source_path = task.data.get(self.audio_filepath_key)
if not source_path or not os.path.exists(source_path):
logger.error(
Expand All @@ -204,9 +202,7 @@ def _extract_emit(
outputs.append(self._make_stub_task(task))
return outputs

def _dry_run_emit(
self, task: AudioTask, plan: list[dict], original_id: str
) -> list[AudioTask]:
def _dry_run_emit(self, task: AudioTask, plan: list[dict], original_id: str) -> list[AudioTask]:
"""Emit snippet metadata only, without reading or writing audio.

``audio_filepath`` is the tar-internal basename
Expand Down Expand Up @@ -337,13 +333,10 @@ def _make_snippet_task(
# Reset to "" — a downstream pipeline is expected to set this correctly.
if "swift_audio_filepath" in new_data:
new_data["swift_audio_filepath"] = ""
new_data["segments"] = relativize_segments(
snippet["segments"], snippet["start"], snippet["end"]
)
new_data["segments"] = relativize_segments(snippet["segments"], snippet["start"], snippet["end"])
if "text" in new_data:
new_data["text"] = " ".join(_segment_text(s) for s in snippet["segments"]).strip()
return AudioTask(
task_id=f"{task.task_id}::{snippet_id}",
dataset_name=task.dataset_name,
data=new_data,
filepath_key=self.audio_filepath_key,
Expand All @@ -361,7 +354,6 @@ def _make_stub_task(self, task: AudioTask) -> AudioTask:
"segments": [],
}
return AudioTask(
task_id=f"{task.task_id}::stub",
dataset_name=task.dataset_name,
data=stub_data,
_metadata=copy.deepcopy(task._metadata),
Expand Down
9 changes: 2 additions & 7 deletions nemo_curator/stages/audio/alm/pretrain/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@
# ----------------------------------------------------------------------


def _read_manifest_row_id(
stage_name: str, lineno: int, entry: dict[str, Any], seen_ids: set[str]
) -> str | None:
def _read_manifest_row_id(stage_name: str, lineno: int, entry: dict[str, Any], seen_ids: set[str]) -> str | None:
# `id` is required by the pipeline contract: downstream snippet ids
# embed it, the metrics aggregator keys per-source records on it, and
# tar members are named with it. A row without a usable id can't be
Expand Down Expand Up @@ -202,16 +200,13 @@ def process(self, _: _EmptyTask) -> list[AudioTask]:
logger.warning(f"[{self.name}] line {lineno}: missing {self.audio_filepath_key!r}; skipping")
continue
if self.audio_path_resolution == AUDIO_PATH_RESOLUTION_BASENAME:
_check_duplicate_audio_basename(
self.name, lineno, original_path, row_id, seen_basenames
)
_check_duplicate_audio_basename(self.name, lineno, original_path, row_id, seen_basenames)
entry[self.audio_filepath_key] = _resolve_audio_path(
self.audio_dir, original_path, self.audio_path_resolution
)

tasks.append(
AudioTask(
task_id=row_id,
dataset_name=self.dataset_name,
data=entry,
filepath_key=self.audio_filepath_key,
Expand Down
2 changes: 0 additions & 2 deletions nemo_curator/stages/audio/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ def process(self, task: FileGroupTask) -> list[AudioTask]:
if line.strip():
results.append(
AudioTask(
task_id=f"{task.task_id}_{count}",
dataset_name=task.dataset_name,
data=json.loads(line.strip()),
_metadata=task._metadata,
Expand Down Expand Up @@ -282,7 +281,6 @@ def process(self, task: AudioTask) -> AudioTask:
with self._fs.open(self._path, "a", encoding="utf-8") as f:
f.write(json.dumps(task.data, ensure_ascii=False) + "\n")
return AudioTask(
task_id=task.task_id,
dataset_name=task.dataset_name,
data=task.data,
_metadata=task._metadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def process_transcript(self, file_path: str) -> list[AudioTask]:
entries.append(
AudioTask(
data={self.filepath_key: abs_wav, self.text_key: transcript_text},
task_id=f"task_id_{abs_wav}",
dataset_name=f"Fleurs_{self.lang}_{self.split}_{self.raw_data_dir}",
filepath_key=self.filepath_key,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,11 +340,10 @@ def process(self, _: _EmptyTask) -> list[AudioTask]:
logger.info(f"Creating manifest with {len(selected_entries)} total samples")

audio_tasks = []
for i, entry in enumerate(selected_entries):
for _i, entry in enumerate(selected_entries):
audio_tasks.append(
AudioTask(
data=entry,
task_id=f"readspeech_{i}",
dataset_name="DNS-ReadSpeech",
filepath_key=self.filepath_key,
)
Expand Down
6 changes: 2 additions & 4 deletions nemo_curator/stages/audio/filtering/band.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def process(self, task: AudioTask) -> AudioTask | list[AudioTask]:
if "segments" in task.data:
survivors = []
for seg in task.data["segments"]:
temp = AudioTask(data=seg, task_id=task.task_id)
temp = AudioTask(data=seg)
result = self._process_single(temp)
if result is not None:
survivors.append(temp.data)
Expand Down Expand Up @@ -189,9 +189,7 @@ def _process_single(self, task: AudioTask) -> AudioTask | None:

actual = task.data.get("band_prediction", "unknown")
if actual != self.band_value:
logger.info(
f"[{task.task_id}] BAND FILTER FAILED: prediction '{actual}' != target '{self.band_value}'"
)
logger.info(f"[{task.task_id}] BAND FILTER FAILED: prediction '{actual}' != target '{self.band_value}'")
return None

return task
2 changes: 1 addition & 1 deletion nemo_curator/stages/audio/filtering/sigmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def process(self, task: AudioTask) -> AudioTask | list[AudioTask]:
if "segments" in task.data:
survivors = []
for seg in task.data["segments"]:
temp = AudioTask(data=seg, task_id=task.task_id)
temp = AudioTask(data=seg)
result = self._process_single(temp)
if result is not None:
survivors.append(temp.data)
Expand Down
2 changes: 1 addition & 1 deletion nemo_curator/stages/audio/filtering/utmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def process(self, task: AudioTask) -> AudioTask | list[AudioTask]:
if "segments" in task.data:
survivors = []
for seg in task.data["segments"]:
temp = AudioTask(data=seg, task_id=task.task_id)
temp = AudioTask(data=seg)
result = self._process_single(temp)
if result is not None:
survivors.append(temp.data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ def process(self, task: AudioTask) -> AudioTask:
output_data[self.diar_segments_key] = segments

return AudioTask(
task_id=f"{task.task_id}_sortformer",
dataset_name=task.dataset_name,
filepath_key=task.filepath_key or self.filepath_key,
data=output_data,
Expand Down
1 change: 0 additions & 1 deletion nemo_curator/stages/audio/io/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def process_batch(self, tasks: list[AudioTask]) -> list[DocumentBatch]:
return [
DocumentBatch(
data=df,
task_id=",".join(t.task_id for t in tasks),
dataset_name=",".join(dict.fromkeys(t.dataset_name for t in tasks)),
_stage_perf=perf,
)
Expand Down
Loading
Loading