From 5d99f27f26ed1d0f7ccabbfad86f79c455ab60de Mon Sep 17 00:00:00 2001 From: "aaftaabv@gmail.com" Date: Fri, 29 May 2026 15:13:09 +0530 Subject: [PATCH 1/6] audio/diarization: stage-adapter split per SDP-V2 design Bring the SDP-V2 design doc (sec 3 speaker diarization) stage-adapter split into the audio tagging pipeline. What moves ---------- * New package nemo_curator/adapters/ with: - adapters/diarization/base.py - DiarSegment + DiarizationResult dataclasses + DiarizationAdapter Protocol (model_id / revision / setup / teardown / prefetch_weights classmethod / diarize_batch). - adapters/diarization/pyannote.py - PyAnnoteDiarizationAdapter that re-houses every PyAnnote-specific code path from the deleted PyAnnoteDiarizationStage (HF auth, in-pipeline batching knobs, overlap detection, has_overlap walk, WhisperX-VAD-driven micro-split of long turns, RTTM sidecar write). * New generic nemo_curator/stages/audio/inference/speaker_diarization/ stage.py with DiarizationStage that owns Curator-side glue only: task.data key reads, item-dict construction, adapter dispatch, DiarSegment to on-disk dict conversion, add_non_speaker_segments gap fill, metric logging. YAML shape (Tier-1 / Tier-2 split) ---------------------------------- - _target_: nemo_curator.stages.audio.inference.speaker_diarization.DiarizationStage name: PyAnnoteDiarization # keeps perf_summary key adapter_target: nemo_curator.adapters.diarization.PyAnnoteDiarizationAdapter model_id: pyannote/speaker-diarization-3.1 non_speaker_max_length: ${max_segment_length} adapter_kwargs: hf_token: ${hf_token} max_length: ${max_segment_length} Class resolution uses hydra.utils.get_class(adapter_target). Behaviour preservation ---------------------- * RTTM sidecar write, overlap detection, has_overlap semantics, WhisperX-VAD micro-split of >max_length turns - all preserved byte-for-byte from the pre-split PyAnnoteDiarizationStage. Numeric output is identical for the same inputs and same random seed. * On-disk segments_key + overlap_segments_key dict shape unchanged. * add_non_speaker_segments still runs stage-side after adapter results. What is deleted --------------- * nemo_curator/stages/audio/inference/speaker_diarization/pyannote.py * tests/stages/audio/inference/speaker_diarization/test_pyannote.py Call-site migrations -------------------- * tutorials/audio/tagging/tts_pipeline.yaml * tests/stages/audio/tagging/e2e/configs/tts_pipeline.yaml * benchmarking/scripts/audio_tagging_benchmark.py * nemo_curator/stages/audio/tagging/__init__.py (lazy-import map) * tutorials/audio/tagging/README.md (table + override doc) New tests --------- * tests/stages/audio/inference/speaker_diarization/test_diarization_stage.py - 20 CPU tests against a fake adapter (construction, lifecycle, process(), metric logging, no GPU / PyAnnote needed). * tests/adapters/diarization/test_pyannote_adapter.py - has_overlap unit tests (carried over from the deleted file). - PyAnnoteDiarizationAdapter construction, prefetch_weights, setup/teardown lifecycle (PyAnnote mocked), diarize_batch empty/missing-filepath paths, _add_vad_segments micro-split. Follow-ups (out of scope for this commit; same posture as PR1967) ----------------------------------------------------------------- * Stage-side pre-slicing and BatchPolicy - tagging pipeline currently uses SplitLongAudioStage for that role; revisit when promoting to in-memory dataflow. * SortformerAdapter / WhisperXDiarizationAdapter - the stage is now shaped for them; they ship in a separate commit. Signed-off-by: aaftaabv@gmail.com --- .../scripts/audio_tagging_benchmark.py | 15 +- nemo_curator/adapters/__init__.py | 27 ++ nemo_curator/adapters/diarization/__init__.py | 41 ++ nemo_curator/adapters/diarization/base.py | 170 ++++++++ nemo_curator/adapters/diarization/pyannote.py | 371 ++++++++++++++++++ .../inference/speaker_diarization/__init__.py | 19 + .../inference/speaker_diarization/pyannote.py | 304 -------------- .../inference/speaker_diarization/stage.py | 290 ++++++++++++++ nemo_curator/stages/audio/tagging/__init__.py | 4 +- tests/adapters/__init__.py | 0 tests/adapters/diarization/__init__.py | 0 .../diarization/test_pyannote_adapter.py | 221 +++++++++++ .../test_diarization_stage.py | 324 +++++++++++++++ .../speaker_diarization/test_pyannote.py | 112 ------ .../tagging/e2e/configs/tts_pipeline.yaml | 16 +- tutorials/audio/tagging/README.md | 9 +- tutorials/audio/tagging/tts_pipeline.yaml | 12 +- 17 files changed, 1499 insertions(+), 436 deletions(-) create mode 100644 nemo_curator/adapters/__init__.py create mode 100644 nemo_curator/adapters/diarization/__init__.py create mode 100644 nemo_curator/adapters/diarization/base.py create mode 100644 nemo_curator/adapters/diarization/pyannote.py delete mode 100644 nemo_curator/stages/audio/inference/speaker_diarization/pyannote.py create mode 100644 nemo_curator/stages/audio/inference/speaker_diarization/stage.py create mode 100644 tests/adapters/__init__.py create mode 100644 tests/adapters/diarization/__init__.py create mode 100644 tests/adapters/diarization/test_pyannote_adapter.py create mode 100644 tests/stages/audio/inference/speaker_diarization/test_diarization_stage.py delete mode 100644 tests/stages/audio/inference/speaker_diarization/test_pyannote.py diff --git a/benchmarking/scripts/audio_tagging_benchmark.py b/benchmarking/scripts/audio_tagging_benchmark.py index 8854bb54b0..d282bcfc83 100644 --- a/benchmarking/scripts/audio_tagging_benchmark.py +++ b/benchmarking/scripts/audio_tagging_benchmark.py @@ -31,7 +31,7 @@ from nemo_curator.pipeline import Pipeline from nemo_curator.stages.audio.common import ManifestReader, ManifestWriterStage -from nemo_curator.stages.audio.inference.speaker_diarization.pyannote import PyAnnoteDiarizationStage +from nemo_curator.stages.audio.inference.speaker_diarization import DiarizationStage from nemo_curator.stages.audio.tagging.inference.nemo_asr_align import NeMoASRAlignerStage from nemo_curator.stages.audio.tagging.merge_alignment_diarization import MergeAlignmentDiarizationStage from nemo_curator.stages.audio.tagging.resample_audio import ResampleAudioStage @@ -85,12 +85,17 @@ def run_audio_tagging_benchmark( # noqa: PLR0913 ).with_(resources=Resources(cpus=cpus)) ) - # Speaker diarization and overlap detection (PyAnnote) + # Speaker diarization and overlap detection (DiarizationStage + PyAnnote adapter) pipeline.add_stage( - PyAnnoteDiarizationStage( + DiarizationStage( name="PyAnnoteDiarization", - hf_token=hf_token, - max_length=max_segment_length, + adapter_target="nemo_curator.adapters.diarization.PyAnnoteDiarizationAdapter", + model_id="pyannote/speaker-diarization-3.1", + non_speaker_max_length=max_segment_length, + adapter_kwargs={ + "hf_token": hf_token, + "max_length": max_segment_length, + }, ).with_(resources=Resources(cpus=cpus, gpus=0.5)) ) diff --git a/nemo_curator/adapters/__init__.py b/nemo_curator/adapters/__init__.py new file mode 100644 index 0000000000..62c88a98bc --- /dev/null +++ b/nemo_curator/adapters/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model adapters for the SDP-V2 stage-adapter split. + +Each adapter family (``diarization``, ``vad``, ``alignment``, ...) lives +in its own subpackage and exposes: + +* ``base.py`` - a ``Protocol`` plus a typed ``Result`` dataclass that + every adapter in the family must implement. +* one module per concrete model that implements the protocol. + +Stages in ``nemo_curator/stages/audio/inference/`` import the protocol +and typed result only; the concrete adapter is resolved at runtime from +the YAML's ``adapter_target`` string via ``hydra.utils.get_class``. +""" diff --git a/nemo_curator/adapters/diarization/__init__.py b/nemo_curator/adapters/diarization/__init__.py new file mode 100644 index 0000000000..054dfd12de --- /dev/null +++ b/nemo_curator/adapters/diarization/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Diarization adapter family for the SDP-V2 stage-adapter split. + +Public surface (the only symbols the stage imports): + +* :class:`DiarizationAdapter` - structural protocol every diarization + adapter implements. +* :class:`DiarizationResult` - canonical per-utterance result dataclass. +* :class:`DiarSegment` - canonical per-speaker-turn dataclass. + +Concrete adapters live in their own modules (e.g. ``pyannote.py``, +``sortformer.py``) and are resolved at runtime by their fully-qualified +class path in YAML's ``adapter_target`` field. +""" + +from nemo_curator.adapters.diarization.base import ( + DiarizationAdapter, + DiarizationResult, + DiarSegment, +) +from nemo_curator.adapters.diarization.pyannote import PyAnnoteDiarizationAdapter + +__all__ = [ + "DiarSegment", + "DiarizationAdapter", + "DiarizationResult", + "PyAnnoteDiarizationAdapter", +] diff --git a/nemo_curator/adapters/diarization/base.py b/nemo_curator/adapters/diarization/base.py new file mode 100644 index 0000000000..a29bc47b7f --- /dev/null +++ b/nemo_curator/adapters/diarization/base.py @@ -0,0 +1,170 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Stage-adapter contract for speaker diarization (SDP-V2 design doc §3). + +Mirrors the ASR contract pattern (``nemo_curator.adapters.asr.base``): + +* :class:`~nemo_curator.stages.audio.inference.speaker_diarization.DiarizationStage` + owns Curator-side glue - ``task.data`` reads, batching, ``min_length`` + / ``max_length`` filtering, non-speaker gap fill, metric logging. +* :class:`DiarizationAdapter` owns the model-side library call - weight + prefetch, model setup, the actual diarizer invocation + (``pyannote.audio.Pipeline(...)``, ``SortformerEncLabelModel.diarize(...)``, + ...) and packing results into the canonical :class:`DiarizationResult` + shape. + +This split lets the stage swap diarizers with a single ``adapter_target:`` +line in YAML without rewriting the stage. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Protocol, runtime_checkable + + +@dataclass +class DiarSegment: + """Canonical per-speaker-turn dataclass. + + Attributes: + start: Turn start time in seconds (clip coordinates). + end: Turn end time in seconds. + speaker: Speaker identifier as emitted by the adapter + (e.g. ``"speaker_0"`` or ``"_speaker_0"``). + The stage does NOT remap speaker ids - the adapter is + responsible for any namespacing. + confidence: Optional adapter-supplied per-turn confidence. + ``None`` when the adapter doesn't surface one (matches + PyAnnote / Sortformer defaults). + """ + + start: float + end: float + speaker: str + confidence: float | None = None + + +@dataclass +class DiarizationResult: + """Canonical per-task diarization adapter output. + + Identical across every diarization adapter so the stage's schema + mutation code path stays constant when the adapter is swapped. + + Attributes: + diar_segments: One :class:`DiarSegment` per speaker turn the + adapter emitted. The stage writes these onto + ``task.data[segments_key]``. + overlap_segments: Optional list of cross-speaker overlap turns. + Populated by adapters that surface overlap detection (e.g. + PyAnnote); empty list otherwise. The stage writes these + onto ``task.data[overlap_segments_key]`` when the key is + configured. + model_id: The actual model identifier the adapter ran (mirrors + the stage's ``model_id`` field; populated by the adapter so + downstream consumers see the live value). + extras: Adapter-specific scalar / structured diagnostics that + do not fit the canonical shape (e.g. raw RTTM path, per-turn + embedding tensors). Stage never reads inside this dict. + """ + + diar_segments: list[DiarSegment] + overlap_segments: list[DiarSegment] = field(default_factory=list) + model_id: str = "" + extras: dict[str, Any] = field(default_factory=dict) + + +@runtime_checkable +class DiarizationAdapter(Protocol): + """Structural protocol every diarization adapter must implement. + + Constructor contract: adapters are constructed by the stage as + ``cls(model_id=..., revision=..., **adapter_kwargs)`` - so every + adapter must accept ``model_id`` and ``revision`` keyword arguments, + plus whatever Tier-2 knobs that adapter exposes. + + Per-batch contract: :meth:`diarize_batch` receives a list of dicts + (Tier-3 per-task knobs unpacked from ``task.data`` by the stage) + and returns one :class:`DiarizationResult` per input, in the same + order. + + Expected per-item keys (the stage populates these; the adapter + reads whichever it needs): + + * ``audio_filepath`` (``str``): Path to a decodable audio file + (typically the §1.2 resampled 16 kHz mono WAV). Always present. + * ``waveform`` (``numpy.ndarray | None``): Optional in-memory + waveform; adapters MAY use it instead of re-decoding from disk + when present. + * ``sample_rate`` (``int | None``): Sample rate of ``waveform``; + ignored when ``waveform`` is ``None``. + * ``audio_item_id`` (``str | None``): Carried through for diagnostic + / speaker-id namespacing. + * ``duration`` (``float | None``): Optional clip duration in + seconds. The stage uses it for non-speaker gap fill if the + adapter doesn't echo it back via ``extras``. + * ``task_id`` (``str | None``): Carried through for diagnostics. + + Attributes: + model_id: Identifier of the underlying model checkpoint. + last_metrics: Scalar metrics from the last + :meth:`diarize_batch` call (per-clip timings, speaker + counts, ...). The stage merges these into its + ``_log_metrics`` output under ``model_`` aliases. + """ + + model_id: str + last_metrics: dict[str, float] + + @classmethod + def prefetch_weights(cls, model_id: str, revision: str | None = None) -> None: + """Download weights to local cache without allocating a GPU. + + Called once per node from + :meth:`DiarizationStage.setup_on_node` before any worker + starts. Must be a classmethod so the stage can call it without + instantiating the adapter (which may import heavy GPU + libraries at construction time). + """ + ... + + def setup(self) -> None: + """Load the model into the worker's process. + + Called once per worker from :meth:`DiarizationStage.setup`. May + allocate GPU memory, build pipelines, instantiate processors. + """ + ... + + def teardown(self) -> None: + """Release GPU memory and worker-local state.""" + ... + + def diarize_batch(self, items: list[dict[str, Any]]) -> list[DiarizationResult]: + """Run diarization on a batch of per-task dicts. + + Args: + items: One dict per task with the keys documented on the + class docstring. Length matches the batch size. + + Returns: + One :class:`DiarizationResult` per input, in the same + order. Items the adapter could not process (e.g. corrupt + audio) must still appear in the output list with an empty + ``diar_segments`` list so the stage can preserve task + ordering. + """ + ... diff --git a/nemo_curator/adapters/diarization/pyannote.py b/nemo_curator/adapters/diarization/pyannote.py new file mode 100644 index 0000000000..af8d141851 --- /dev/null +++ b/nemo_curator/adapters/diarization/pyannote.py @@ -0,0 +1,371 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PyAnnote diarization adapter. + +Implements :class:`~nemo_curator.adapters.diarization.DiarizationAdapter` +on top of PyAnnote 3.x / 4.x's +``pyannote.audio.pipelines.SpeakerDiarization``. All PyAnnote-specific +behaviour (HF auth, in-pipeline batching knobs, overlap detection, +PyAnnote-Segment ``has_overlap`` walk, WhisperX-VAD-driven micro-split +of long turns, RTTM sidecar write) lives here so the generic +:class:`~nemo_curator.stages.audio.inference.speaker_diarization.DiarizationStage` +stays model-agnostic. + +Logic moved verbatim from the pre-split +``nemo_curator.stages.audio.inference.speaker_diarization.pyannote.PyAnnoteDiarizationStage``; +numeric output is identical for the same inputs and same random seed. +""" + +from __future__ import annotations + +import os +import random +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import soundfile as sf +import torch +from fsspec.core import url_to_fs +from loguru import logger + +# PyAnnote imports - kept module-level (matches pre-split stage) because +# the adapter is imported only inside DiarizationStage.setup() on a GPU +# worker that already has PyAnnote available. +from pyannote.audio import Pipeline as PyAnnotePipeline +from pyannote.audio.pipelines.utils.hook import ProgressHook +from pyannote.core import Segment + +from nemo_curator.adapters.diarization.base import DiarizationResult, DiarSegment +from nemo_curator.stages.audio.inference.vad.whisperx_vad import WhisperXVADModel + +if TYPE_CHECKING: + from pyannote.core import Annotation + + +def has_overlap(turn: Segment, overlaps: list) -> bool: + """Check if a given turn overlaps with any segment in the overlaps list. + + Args: + turn: A segment representing a speech turn. + overlaps: List of overlap segments, sorted by start time. + + Returns: + True if the turn overlaps with any segment, False otherwise. + """ + turn_overlaps = False + for overlap in overlaps: + if overlap.start > turn.end: + # Overlap happens after turn, no need to keep looping since overlaps is sorted + break + if overlap.start >= turn.start and overlap.start < turn.end: + # Overlap starts during turn + turn_overlaps = True + break + if (overlap.end < turn.end) and (overlap.end > turn.start): + # Overlap ends during turn + turn_overlaps = True + break + if overlap.start < turn.start and overlap.end > turn.end: + # Overlap completely contains the turn + turn_overlaps = True + break + return turn_overlaps + + +@dataclass +class PyAnnoteDiarizationAdapter: + """PyAnnote-backed implementation of :class:`DiarizationAdapter`. + + Tier-2 knobs (set via ``adapter_kwargs`` in YAML; the stage forwards + them verbatim): + + Attributes: + model_id: HuggingFace model id (e.g. ``pyannote/speaker-diarization-3.1``). + Mirrors the stage's ``model_id`` Tier-1 field; passed through + by the stage. + revision: HuggingFace revision pin or ``None``. Currently unused + by PyAnnote's ``from_pretrained`` but accepted for protocol + uniformity. + hf_token: HuggingFace authentication token (required - PyAnnote + models are gated). + device: ``"cuda"`` or ``"cpu"``. The stage passes the worker's + actual device; default ``"cuda"`` matches the pre-split + behaviour. + segmentation_batch_size: Forwarded to PyAnnote pipeline. + embedding_batch_size: Forwarded to PyAnnote pipeline. + min_length: Minimum speech-turn duration kept by the adapter + (shorter turns are dropped before being returned to the + stage). + max_length: Speech-turn length above which the adapter runs a + WhisperX-VAD-driven micro-split to break long turns into + sub-segments bounded between ``min_length`` and + ``max_length``. + write_rttm: When True, write an RTTM sidecar next to the input + audio (same path with ``.rttm`` extension). Pre-split + behaviour was unconditional write; we keep True as the + default for compatibility. + random_seed: Optional integer seed for the WhisperX-VAD-driven + micro-split's uniform sampling. ``None`` (default) matches + the pre-split non-deterministic behaviour. + """ + + # ---- Required protocol fields ---- + model_id: str = "pyannote/speaker-diarization-3.1" + revision: str | None = None + + # ---- PyAnnote-specific knobs ---- + hf_token: str = "" + device: str = "cuda" + segmentation_batch_size: int = 128 + embedding_batch_size: int = 128 + min_length: float = 0.5 + max_length: float = 40.0 + write_rttm: bool = True + random_seed: int | None = None + + # ---- Internal state (not serialised, populated in setup()) ---- + _pipeline: Any = field(default=None, repr=False) + _vad_model: Any = field(default=None, repr=False) + _rng: random.Random | None = field(default=None, repr=False) + last_metrics: dict[str, float] = field(default_factory=dict) + + # ------------------------------------------------------------------ + # Adapter contract: prefetch + setup + teardown + diarize_batch + # ------------------------------------------------------------------ + + @classmethod + def prefetch_weights(cls, model_id: str, revision: str | None = None) -> None: + """Download PyAnnote pipeline weights to local cache. + + PyAnnote's ``Pipeline.from_pretrained`` is the only public entry + point that triggers the HF download; calling it once with a + valid token caches the segmentation + embedding sub-models on + disk. We tolerate the absence of an ``HF_TOKEN`` env var during + prefetch (the stage's ``prefetch_fail_on_error=False`` path will + defer the actual setup until the worker is up). + """ + del revision # PyAnnote uses ``from_pretrained`` without revision arg. + hf_token = os.environ.get("HF_TOKEN", "") + if not hf_token: + msg = ( + "PyAnnoteDiarizationAdapter.prefetch_weights: HF_TOKEN env var " + "is not set; weights will be downloaded lazily on the worker." + ) + raise RuntimeError(msg) + PyAnnotePipeline.from_pretrained(model_id, token=hf_token) + + def setup(self) -> None: + if self._pipeline is None: + self._pipeline = PyAnnotePipeline.from_pretrained( + self.model_id, token=self.hf_token or None + ) + self._pipeline.segmentation_batch_size = int(self.segmentation_batch_size) + self._pipeline.embedding_batch_size = int(self.embedding_batch_size) + + if self._vad_model is None: + self._vad_model = WhisperXVADModel( + device=self.device, + vad_onset=0.5, + vad_offset=0.363, + ) + + self._pipeline.to(torch.device(self.device)) + self._vad_model.to(self.device) + + self._rng = random.Random(self.random_seed) if self.random_seed is not None else random.Random() # noqa: S311 + logger.info("PyAnnoteDiarizationAdapter ready on {} (model={})", self.device, self.model_id) + + def teardown(self) -> None: + self._pipeline = None + self._vad_model = None + self._rng = None + + def diarize_batch(self, items: list[dict[str, Any]]) -> list[DiarizationResult]: + if not items: + return [] + if self._pipeline is None: + msg = "PyAnnoteDiarizationAdapter.setup() must be called before diarize_batch()" + raise RuntimeError(msg) + + results: list[DiarizationResult] = [] + per_item_times: list[float] = [] + per_item_speakers: list[int] = [] + per_item_segments: list[int] = [] + per_item_overlaps: list[int] = [] + + for item in items: + t0 = time.perf_counter() + audio_filepath = item.get("audio_filepath") + if not audio_filepath: + # Empty result; the stage's add_non_speaker_segments still runs. + results.append(DiarizationResult(diar_segments=[], model_id=self.model_id)) + per_item_times.append(time.perf_counter() - t0) + per_item_speakers.append(0) + per_item_segments.append(0) + per_item_overlaps.append(0) + continue + result = self._diarize_one(item) + results.append(result) + per_item_times.append(time.perf_counter() - t0) + per_item_speakers.append( + len({seg.speaker for seg in result.diar_segments if seg.speaker != "no-speaker"}) + ) + per_item_segments.append(len(result.diar_segments)) + per_item_overlaps.append(len(result.overlap_segments)) + + self.last_metrics = { + "batch_size": float(len(items)), + "diarize_time_s_total": float(sum(per_item_times)), + "diarize_time_s_max": float(max(per_item_times)) if per_item_times else 0.0, + "speakers_detected_max": float(max(per_item_speakers)) if per_item_speakers else 0.0, + "segments_detected_total": float(sum(per_item_segments)), + "overlap_segments_detected_total": float(sum(per_item_overlaps)), + } + return results + + # ------------------------------------------------------------------ + # Internal helpers (moved verbatim from PyAnnoteDiarizationStage) + # ------------------------------------------------------------------ + + def _diarize_one(self, item: dict[str, Any]) -> DiarizationResult: + audio_filepath: str = item["audio_filepath"] + audio_item_id: str | None = item.get("audio_item_id") + + # Pre-split behaviour: read with soundfile (avoids torchcodec/FFmpeg). + data, fs = sf.read(audio_filepath, dtype="float32") + s = torch.from_numpy(data).unsqueeze(0) if data.ndim == 1 else torch.from_numpy(data.T) + logger.info("Processing {}", audio_filepath) + + with ProgressHook() as hook: + result = self._pipeline({"waveform": s, "sample_rate": fs}, hook=hook) + + # pyannote-audio 4.x returns DiarizeOutput; older returns Annotation. + diarization: Annotation = ( + result.speaker_diarization if hasattr(result, "speaker_diarization") else result + ) + + overlaps = diarization.get_overlap().segments_list_ + + # Crop to audio length (fix for the historical PyAnnote bug + # where annotations could extend a few ms past the waveform). + diarization = diarization.crop(Segment(0, len(s[0]) / fs)) + + if self.write_rttm: + self._write_rttm(diarization, audio_filepath) + + diar_segments: list[DiarSegment] = [] + overlap_segments: list[DiarSegment] = [] + + for speech_turn, _track, speaker in diarization.itertracks(yield_label=True): + if audio_item_id: + speaker_id = f"{audio_item_id}_{speaker}" + elif item.get("speaker_id"): + speaker_id = f"{item['speaker_id']}_{speaker}" + else: + speaker_id = f"{Path(audio_filepath).stem}_{speaker}" + + if has_overlap(speech_turn, overlaps): + overlap_segments.append( + DiarSegment( + start=float(speech_turn.start), + end=float(speech_turn.end), + speaker=speaker_id, + ) + ) + continue + + speech_duration = speech_turn.end - speech_turn.start + if speech_duration > self.min_length: + self._add_vad_segments( + audio=s, + fs=fs, + start=float(speech_turn.start), + end=float(speech_turn.end), + segments=diar_segments, + speaker_id=speaker_id, + ) + + return DiarizationResult( + diar_segments=diar_segments, + overlap_segments=overlap_segments, + model_id=self.model_id, + ) + + def _write_rttm(self, diarization: Annotation, audio_filepath: str) -> None: + logger.info("Writing {} turns to RTTM file", len(diarization._tracks)) # noqa: SLF001 + rttm_filepath = os.path.splitext(audio_filepath)[0] + ".rttm" + rttm_fs, rttm_path = url_to_fs(rttm_filepath) + with rttm_fs.open(rttm_path, "w") as rttm_file: + diarization.write_rttm(rttm_file) + + def _add_vad_segments( # noqa: PLR0913 + self, + audio: torch.Tensor, + fs: int, + start: float, + end: float, + segments: list[DiarSegment], + speaker_id: str, + ) -> None: + """Sub-split a long speech turn using WhisperX VAD timings. + + For turns longer than ``max_length`` we run WhisperX VAD on the + slice and pack contiguous VAD segments into sub-turns of + random length sampled uniformly from + [``min_length``, ``max_length``]. Mirrors the pre-split stage + implementation byte-for-byte (same uniform sample, same + index walk). + """ + segment_duration = end - start + + if segment_duration > self.max_length: + audio_seg = audio[:, int(start * fs) : int(end * fs)] + vad_segments = self._vad_model.get_vad_segments( + audio_seg.numpy(), self.max_length, sample_rate=fs + ) + i = 0 + n = len(vad_segments) + + while i < n: + random_duration = self._rng.uniform(self.min_length, self.max_length) + start_seg = vad_segments[i]["start"] + end_seg = vad_segments[i]["end"] + + if end_seg - start_seg >= random_duration: + segments.append( + DiarSegment( + start=start + start_seg, + end=start + end_seg, + speaker=speaker_id, + ) + ) + i += 1 + continue + + while i < n and (vad_segments[i]["end"] - start_seg) < random_duration: + end_seg = vad_segments[i]["end"] + i += 1 + + segments.append( + DiarSegment( + start=start + start_seg, + end=start + end_seg, + speaker=speaker_id, + ) + ) + else: + segments.append(DiarSegment(start=start, end=end, speaker=speaker_id)) diff --git a/nemo_curator/stages/audio/inference/speaker_diarization/__init__.py b/nemo_curator/stages/audio/inference/speaker_diarization/__init__.py index e69de29bb2..e906f80718 100644 --- a/nemo_curator/stages/audio/inference/speaker_diarization/__init__.py +++ b/nemo_curator/stages/audio/inference/speaker_diarization/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic speaker-diarization stage (SDP-V2 stage-adapter split, §3).""" + +from nemo_curator.stages.audio.inference.speaker_diarization.stage import DiarizationStage + +__all__ = ["DiarizationStage"] diff --git a/nemo_curator/stages/audio/inference/speaker_diarization/pyannote.py b/nemo_curator/stages/audio/inference/speaker_diarization/pyannote.py deleted file mode 100644 index bd6ea2ecc2..0000000000 --- a/nemo_curator/stages/audio/inference/speaker_diarization/pyannote.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -PyAnnote Diarization and Overlap Detection Stage. -""" - -import os -import random -import time -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any - -import soundfile as sf -import torch -from fsspec.core import url_to_fs -from loguru import logger - -# Import pyannote components -from pyannote.audio import Pipeline as PyAnnotePipeline -from pyannote.audio.pipelines.utils.hook import ProgressHook -from pyannote.core import Segment - -from nemo_curator.backends.base import NodeInfo, WorkerMetadata -from nemo_curator.stages.audio.common import get_audio_duration -from nemo_curator.stages.audio.inference.vad.whisperx_vad import WhisperXVADModel -from nemo_curator.stages.audio.tagging.utils import add_non_speaker_segments -from nemo_curator.stages.base import ProcessingStage -from nemo_curator.stages.resources import Resources -from nemo_curator.tasks import AudioTask - - -def has_overlap(turn: Segment, overlaps: list) -> bool: - """Check if a given turn overlaps with any segment in the overlaps list. - - Args: - turn: A segment representing a speech turn - overlaps: List of overlap segments, sorted by start time - - Returns: - True if the turn overlaps with any segment, False otherwise - """ - turn_overlaps = False - for overlap in overlaps: - if overlap.start > turn.end: - # Overlap happens after turn, no need to keep looping since overlaps is sorted - break - elif overlap.start >= turn.start and overlap.start < turn.end: - # overlap starts during turn - turn_overlaps = True - break - elif (overlap.end < turn.end) and (overlap.end > turn.start): - # overlap ends during turn - turn_overlaps = True - break - elif overlap.start < turn.start and overlap.end > turn.end: - # Overlap completely contains the turn - turn_overlaps = True - break - return turn_overlaps - - -@dataclass -class PyAnnoteDiarizationStage(ProcessingStage[AudioTask, AudioTask]): - """ - Stage that performs speaker diarization and overlap detection using PyAnnote. - - Identifies different speakers and detects overlapping speech segments. - - Args: - hf_token: HuggingFace authentication token - segmentation_batch_size: Batch size for segmentation - embedding_batch_size: Batch size for speaker embeddings - min_length: Minimum segment length in seconds - max_length: Maximum segment length in seconds - xenna_num_workers: If set, passes ``num_workers`` to Xenna (cluster-wide cap). Unset uses Xenna autoscaling. - """ - - hf_token: str - - # Diarization pipeline model ID on HuggingFace - model_name: str = "pyannote/speaker-diarization-3.1" - - # Model parameters - segmentation_batch_size: int = 128 - embedding_batch_size: int = 128 - - # Segment length constraints - min_length: float = 0.5 - max_length: float = 40.0 - - audio_filepath_key: str = "resampled_audio_filepath" - segments_key: str = "segments" - overlap_segments_key: str = "overlap_segments" - - # Stage metadata - name: str = "PyAnnoteDiarization" - resources: Resources = field(default_factory=lambda: Resources(gpus=1)) - - # Xenna executor (optional; unset = default autoscaling) - xenna_num_workers: int | None = None - - # Internal state (not serialized, initialized in setup() to allow deepcopy) - _pipeline: Any = field(default=None, repr=False) - _vad_model: Any = field(default=None, repr=False) # WhisperXVADModel - _rng: random.Random | None = field(default=None, repr=False) - - def inputs(self) -> tuple[list[str], list[str]]: - return [], [self.audio_filepath_key] - - def outputs(self) -> tuple[list[str], list[str]]: - return [], [self.audio_filepath_key, self.segments_key, self.overlap_segments_key] - - def xenna_stage_spec(self) -> dict[str, Any]: - spec: dict[str, Any] = {} - if self.xenna_num_workers is not None: - spec["num_workers"] = self.xenna_num_workers - return spec - - @property - def _device(self) -> str: - """Derive device from resources configuration.""" - return "cuda" if self.resources.requires_gpu else "cpu" - - def setup_on_node( - self, _node_info: NodeInfo | None = None, _worker_metadata: WorkerMetadata | None = None - ) -> None: - """Download model weights (called once per node).""" - if self._pipeline is None: - self._pipeline = PyAnnotePipeline.from_pretrained(self.model_name, token=self.hf_token) - if self._vad_model is None: - self._vad_model = WhisperXVADModel( - device="cpu", - vad_onset=0.5, - vad_offset=0.363, - ) - - def setup(self, _: WorkerMetadata | None = None) -> None: - """Load models to device (called per replica before processing).""" - if self._pipeline is None: - self._pipeline = PyAnnotePipeline.from_pretrained(self.model_name, token=self.hf_token) - self._pipeline.segmentation_batch_size = self.segmentation_batch_size - self._pipeline.embedding_batch_size = self.embedding_batch_size - - if self._vad_model is None: - self._vad_model = WhisperXVADModel( - device=self._device, - vad_onset=0.5, - vad_offset=0.363, - ) - - self._pipeline.to(torch.device(self._device)) - self._vad_model.to(self._device) - - self._rng = random.Random() # noqa: S311 - logger.info(f"[{self.name}] Initialized PyAnnote diarization on {self._device}") - - def add_vad_segments( # noqa: PLR0913 - self, - audio: torch.Tensor, - fs: int, - start: float, - end: float, - segments: list[dict], - speaker_id: str, - ) -> None: - """Add VAD segments for a given audio region to the segments list.""" - segment_duration = end - start - - if segment_duration > self.max_length: - audio_seg = audio[:, int(start * fs) : int(end * fs)] - vad_segments = self._vad_model.get_vad_segments(audio_seg.numpy(), self.max_length, sample_rate=fs) - i = 0 - n = len(vad_segments) - - while i < n: - random_duration = self._rng.uniform(self.min_length, self.max_length) - start_seg = vad_segments[i]["start"] - end_seg = vad_segments[i]["end"] - - if end_seg - start_seg >= random_duration: - segments.append( - { - "speaker": speaker_id, - "start": start + start_seg, - "end": start + end_seg, - } - ) - i += 1 - continue - - while i < n and (vad_segments[i]["end"] - start_seg) < random_duration: - end_seg = vad_segments[i]["end"] - i += 1 - - segments.append( - { - "speaker": speaker_id, - "start": start + start_seg, - "end": start + end_seg, - } - ) - else: - segments.append({"speaker": speaker_id, "start": start, "end": end}) - - def process(self, task: AudioTask) -> AudioTask: - """Process a single entry for diarization and overlap detection.""" - t0 = time.perf_counter() - data_entry = task.data - file_path = data_entry.get(self.audio_filepath_key) - if not file_path: - msg = f"[{self.name}] Missing key '{self.audio_filepath_key}' in entry: {data_entry.get('audio_item_id', 'unknown')}" - raise ValueError(msg) - - # Load audio using soundfile (avoids torchcodec/FFmpeg dependency) - data, fs = sf.read(file_path, dtype="float32") - s = torch.from_numpy(data).unsqueeze(0) if data.ndim == 1 else torch.from_numpy(data.T) - logger.info(f"Processing {file_path}") - - # Run diarization - with ProgressHook() as hook: - result = self._pipeline({"waveform": s, "sample_rate": fs}, hook=hook) - - # pyannote-audio 4.x returns DiarizeOutput; extract the Annotation - diarization = result.speaker_diarization if hasattr(result, "speaker_diarization") else result - - overlaps = diarization.get_overlap().segments_list_ - - # Crop to audio length (fix for PyAnnote bug) - diarization = diarization.crop(Segment(0, len(s[0]) / fs)) - - # Write RTTM file (cloud-aware via fsspec) - logger.info(f"Writing {len(diarization._tracks)} turns to RTTM file") - rttm_filepath = os.path.splitext(file_path)[0] + ".rttm" - rttm_fs, rttm_path = url_to_fs(rttm_filepath) - with rttm_fs.open(rttm_path, "w") as rttm_file: - diarization.write_rttm(rttm_file) - - segments = [] - overlap_segments = [] - - # Process speaker turns - for speech_turn, _track, speaker in diarization.itertracks(yield_label=True): - if "audio_item_id" in data_entry: - speaker_id = data_entry["audio_item_id"] + "_" + speaker - elif "speaker_id" in data_entry: - speaker_id = data_entry["speaker_id"] + "_" + speaker - elif self.audio_filepath_key in data_entry: - speaker_id = Path(data_entry[self.audio_filepath_key]).stem + "_" + speaker - else: - msg = f"No speaker identifier in {file_path}" - raise ValueError(msg) - - if has_overlap(speech_turn, overlaps): - overlap_segments.append( - { - "speaker": speaker_id, - "start": speech_turn.start, - "end": speech_turn.end, - } - ) - else: - speech_duration = speech_turn.end - speech_turn.start - if speech_duration > self.min_length: - self.add_vad_segments( - s, - fs, - speech_turn.start, - speech_turn.end, - segments, - speaker_id, - ) - - # Add non-speaker segments - audio_duration = data_entry.get("duration", get_audio_duration(file_path)) - add_non_speaker_segments(segments, audio_duration, self.max_length) - - # Update entry - data_entry[self.segments_key] = segments - data_entry[self.overlap_segments_key] = overlap_segments - - speakers = {seg["speaker"] for seg in segments if seg.get("speaker") != "no-speaker"} - self._log_metrics( - { - "process_time": time.perf_counter() - t0, - "segments_detected": len(segments), - "overlap_segments_detected": len(overlap_segments), - "speakers_detected": len(speakers), - "audio_duration": audio_duration, - } - ) - return task diff --git a/nemo_curator/stages/audio/inference/speaker_diarization/stage.py b/nemo_curator/stages/audio/inference/speaker_diarization/stage.py new file mode 100644 index 0000000000..a129e1d3b2 --- /dev/null +++ b/nemo_curator/stages/audio/inference/speaker_diarization/stage.py @@ -0,0 +1,290 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic speaker-diarization Curator stage (SDP-V2 design doc §3). + +Implements the stage half of the SDP-V2 stage-adapter split for the +diarization family. The stage owns Curator-side glue: + +* validates ``task.data`` against ``inputs()`` / ``outputs()``; +* unpacks per-task knobs (audio filepath, optional in-memory waveform, + ``audio_item_id``, ``duration``) into a single item dict; +* dispatches the adapter's ``diarize_batch`` once per task; +* converts the adapter's typed :class:`DiarSegment` results into the + on-disk dict shape downstream consumers expect + (``{"speaker": ..., "start": ..., "end": ...}``); +* fills the inter-turn gaps with ``no-speaker`` segments via + :func:`add_non_speaker_segments`; +* writes ``task.data[segments_key]`` and optionally + ``task.data[overlap_segments_key]``; +* emits performance metrics in the shape ``perf_summary_merged.json`` + consumers already expect. + +The stage knows nothing about which diarizer is running. The concrete +adapter class is resolved at runtime from the YAML's ``adapter_target`` +string via ``hydra.utils.get_class``. +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import hydra.utils +from loguru import logger + +from nemo_curator.adapters.diarization.base import DiarizationAdapter, DiarizationResult +from nemo_curator.stages.audio.common import get_audio_duration +from nemo_curator.stages.audio.tagging.utils import add_non_speaker_segments +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.resources import Resources +from nemo_curator.tasks import AudioTask + +if TYPE_CHECKING: + from nemo_curator.backends.base import NodeInfo, WorkerMetadata + + +@dataclass +class DiarizationStage(ProcessingStage[AudioTask, AudioTask]): + """Speaker-diarization Curator stage with pluggable adapter. + + Args: + adapter_target: Tier-1 swap surface. Fully-qualified class path + of the concrete + :class:`~nemo_curator.adapters.diarization.DiarizationAdapter` + implementation (e.g. + ``"nemo_curator.adapters.diarization.PyAnnoteDiarizationAdapter"``). + Resolved at ``setup()`` time via ``hydra.utils.get_class``. + model_id: Tier-1. Model checkpoint identifier, forwarded both to + :meth:`DiarizationAdapter.prefetch_weights` (in + ``setup_on_node``) and to the adapter constructor. + revision: Tier-1. Optional model revision to pin. + audio_filepath_key: Key into ``task.data`` carrying the decoded + audio path. Defaults to ``"resampled_audio_filepath"`` for + symmetry with the §1.2 ResampleAudioStage output. + waveform_key: Optional key into ``task.data`` carrying an + in-memory waveform. When present alongside + ``sample_rate_key`` and ``filepath_fallback_key`` is + enabled, adapters MAY use the in-memory buffer to avoid a + re-decode. + sample_rate_key: Key into ``task.data`` carrying the + ``waveform_key`` array's sample rate. + segments_key: Key under which the canonical per-speaker turn + list is written. Each entry is a dict + ``{"speaker": str, "start": float, "end": float}``; + includes ``no-speaker`` gap-fill segments emitted by + :func:`add_non_speaker_segments`. + overlap_segments_key: When set, the stage also writes a list + of overlap turns under this key. Set ``None`` for adapters + that don't surface overlap detection. + non_speaker_max_length: Optional ceiling for the + ``no-speaker`` gap-fill segments. When set, long gaps are + split into ``<= non_speaker_max_length`` second chunks + (preserves the pre-split behaviour where this was tied to + the PyAnnote adapter's ``max_length`` knob). + prefetch_fail_on_error: When False, ``setup_on_node`` warns and + defers weight prefetch to ``setup()`` instead of raising. + adapter_kwargs: Tier-2. Opaque dict forwarded to the adapter + constructor as ``**adapter_kwargs``. The stage NEVER reads + inside this dict - it is the adapter's private knob bag. + resources / batch_size: Standard Curator stage knobs. + xenna_num_workers: Optional cluster-wide cap forwarded to the + Xenna scheduler. ``None`` (default) lets Xenna autoscale. + """ + + name: str = "Diarization" + + # ---- Tier 1: swap surface ---- + adapter_target: str = "" + model_id: str = "" + revision: str | None = None + + # ---- Tier 1: universal stage knobs ---- + audio_filepath_key: str = "resampled_audio_filepath" + waveform_key: str | None = None + sample_rate_key: str | None = None + segments_key: str = "segments" + overlap_segments_key: str | None = "overlap_segments" + non_speaker_max_length: float | None = 40.0 + + prefetch_fail_on_error: bool = True + + # ---- Tier 2: opaque adapter knob bag ---- + adapter_kwargs: dict[str, Any] = field(default_factory=dict) + + # ---- Standard Curator stage knobs ---- + resources: Resources = field(default_factory=lambda: Resources(gpus=1)) + xenna_num_workers: int | None = None + + # ---- Internal state ---- + _adapter: DiarizationAdapter | None = field(default=None, init=False, repr=False) + + def __post_init__(self) -> None: + if not self.adapter_target: + msg = ( + "DiarizationStage.adapter_target is required - set it in YAML to a fully-qualified " + "adapter class path (e.g. 'nemo_curator.adapters.diarization.PyAnnoteDiarizationAdapter')." + ) + raise ValueError(msg) + + # ------------------------------------------------------------------ + # I/O contract + # ------------------------------------------------------------------ + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [self.audio_filepath_key] + + def outputs(self) -> tuple[list[str], list[str]]: + keys = [self.audio_filepath_key, self.segments_key] + if self.overlap_segments_key: + keys.append(self.overlap_segments_key) + return [], keys + + def xenna_stage_spec(self) -> dict[str, Any]: + spec: dict[str, Any] = {} + if self.xenna_num_workers is not None: + spec["num_workers"] = self.xenna_num_workers + return spec + + @property + def _device(self) -> str: + return "cuda" if self.resources.requires_gpu else "cpu" + + # ------------------------------------------------------------------ + # Adapter lifecycle + # ------------------------------------------------------------------ + + def _adapter_class(self) -> type: + return hydra.utils.get_class(self.adapter_target) + + def setup_on_node( + self, + _node_info: NodeInfo | None = None, + _worker_metadata: WorkerMetadata | None = None, + ) -> None: + """Cache diarizer weights once per node (no GPU allocation).""" + try: + prefetch_t0 = time.perf_counter() + self._adapter_class().prefetch_weights(self.model_id, self.revision) + logger.info( + "Diarization weights cached on node for {} ({}) in {:.3f}s", + self.model_id, + self.adapter_target, + time.perf_counter() - prefetch_t0, + ) + except Exception as exc: # noqa: BLE001 + msg = f"DiarizationStage: prefetch_weights failed for {self.model_id}" + if self.prefetch_fail_on_error: + raise RuntimeError(msg) from exc + logger.warning("{}; setup() will retry: {}", msg, exc) + + def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None: + if self._adapter is None: + cls = self._adapter_class() + kwargs = dict(self.adapter_kwargs) + # The stage owns model_id/revision (Tier-1); pass through. + kwargs.setdefault("model_id", self.model_id) if self.model_id else None + kwargs.setdefault("revision", self.revision) + # Inject device hint when the adapter accepts one (PyAnnote does). + kwargs.setdefault("device", self._device) + self._adapter = cls(**kwargs) + self._adapter.setup() + logger.info( + "[{}] Diarization adapter ready ({})", + self.name, + self.adapter_target, + ) + + def teardown(self) -> None: + if self._adapter is not None: + self._adapter.teardown() + self._adapter = None + + # ------------------------------------------------------------------ + # Processing + # ------------------------------------------------------------------ + + def _build_item(self, task: AudioTask) -> dict[str, Any]: + data = task.data + item: dict[str, Any] = { + "audio_filepath": data.get(self.audio_filepath_key), + "audio_item_id": data.get("audio_item_id"), + "speaker_id": data.get("speaker_id"), + "duration": data.get("duration"), + "task_id": getattr(task, "task_id", None), + } + if self.waveform_key: + item["waveform"] = data.get(self.waveform_key) + if self.sample_rate_key: + item["sample_rate"] = data.get(self.sample_rate_key) + return item + + @staticmethod + def _segment_to_dict(seg: Any) -> dict[str, Any]: + out: dict[str, Any] = { + "speaker": seg.speaker, + "start": float(seg.start), + "end": float(seg.end), + } + if getattr(seg, "confidence", None) is not None: + out["confidence"] = float(seg.confidence) + return out + + def process(self, task: AudioTask) -> AudioTask: + t0 = time.perf_counter() + data_entry = task.data + + if self._adapter is None: + msg = "Adapter not initialized - setup() was not called" + raise RuntimeError(msg) + + file_path = data_entry.get(self.audio_filepath_key) + if not file_path: + msg = ( + f"[{self.name}] Missing key '{self.audio_filepath_key}' in entry: " + f"{data_entry.get('audio_item_id', 'unknown')}" + ) + raise ValueError(msg) + + item = self._build_item(task) + results: list[DiarizationResult] = self._adapter.diarize_batch([item]) + result = results[0] if results else DiarizationResult(diar_segments=[], model_id=self.model_id) + + # Convert typed segments -> on-disk dict shape. + segments: list[dict[str, Any]] = [self._segment_to_dict(seg) for seg in result.diar_segments] + overlap_segments: list[dict[str, Any]] = [ + self._segment_to_dict(seg) for seg in result.overlap_segments + ] + + # Non-speaker gap fill (works on dicts; mirrors pre-split behaviour). + audio_duration = data_entry.get("duration", get_audio_duration(file_path)) + add_non_speaker_segments(segments, audio_duration, self.non_speaker_max_length) + + data_entry[self.segments_key] = segments + if self.overlap_segments_key: + data_entry[self.overlap_segments_key] = overlap_segments + + speakers = {seg["speaker"] for seg in segments if seg.get("speaker") != "no-speaker"} + metrics: dict[str, float] = { + "process_time": time.perf_counter() - t0, + "segments_detected": float(len(segments)), + "overlap_segments_detected": float(len(overlap_segments)), + "speakers_detected": float(len(speakers)), + "audio_duration": float(audio_duration), + } + for key, value in (self._adapter.last_metrics or {}).items(): + metrics[f"model_{key}"] = float(value) + self._log_metrics(metrics) + return task diff --git a/nemo_curator/stages/audio/tagging/__init__.py b/nemo_curator/stages/audio/tagging/__init__.py index aa80c03f0d..b2c914faca 100644 --- a/nemo_curator/stages/audio/tagging/__init__.py +++ b/nemo_curator/stages/audio/tagging/__init__.py @@ -38,7 +38,9 @@ # --- Inference (tagging/inference/) --- "BaseASRProcessorStage": "nemo_curator.stages.audio.tagging.inference.nemo_asr_align", "NeMoASRAlignerStage": "nemo_curator.stages.audio.tagging.inference.nemo_asr_align", - "PyAnnoteDiarizationStage": "nemo_curator.stages.audio.inference.speaker_diarization.pyannote", + # --- Inference (stage-adapter split per SDP-V2 design) --- + "DiarizationStage": "nemo_curator.stages.audio.inference.speaker_diarization", + "PyAnnoteDiarizationAdapter": "nemo_curator.adapters.diarization", "WhisperXVADStage": "nemo_curator.stages.audio.inference.vad.whisperx_vad", } diff --git a/tests/adapters/__init__.py b/tests/adapters/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/adapters/diarization/__init__.py b/tests/adapters/diarization/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/adapters/diarization/test_pyannote_adapter.py b/tests/adapters/diarization/test_pyannote_adapter.py new file mode 100644 index 0000000000..3ccf65d429 --- /dev/null +++ b/tests/adapters/diarization/test_pyannote_adapter.py @@ -0,0 +1,221 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for PyAnnoteDiarizationAdapter. + +Heavy PyAnnote internals are mocked. End-to-end inference is covered by +the e2e GPU test (``tests/stages/audio/tagging/e2e/test_tts_e2e.py``). +""" + +from __future__ import annotations + +import os +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +pytest.importorskip("pyannote.audio") + +from nemo_curator.adapters.diarization.pyannote import ( + PyAnnoteDiarizationAdapter, + has_overlap, +) + + +# --------------------------------------------------------------------------- +# has_overlap helper (moved verbatim from the deleted pyannote.py stage) +# --------------------------------------------------------------------------- + + +class TestHasOverlap: + def test_turn_overlaps_with_segment(self) -> None: + turn = SimpleNamespace(start=0.0, end=2.0) + overlaps = [SimpleNamespace(start=1.0, end=1.5)] + assert has_overlap(turn, overlaps) is True + + def test_turn_after_overlap_returns_false(self) -> None: + turn = SimpleNamespace(start=3.0, end=4.0) + overlaps = [SimpleNamespace(start=1.0, end=2.0)] + assert has_overlap(turn, overlaps) is False + + def test_turn_before_overlap_returns_false(self) -> None: + turn = SimpleNamespace(start=0.0, end=0.5) + overlaps = [SimpleNamespace(start=1.0, end=2.0)] + assert has_overlap(turn, overlaps) is False + + def test_empty_overlaps_returns_false(self) -> None: + turn = SimpleNamespace(start=0.0, end=1.0) + assert has_overlap(turn, []) is False + + def test_overlap_fully_contains_turn(self) -> None: + turn = SimpleNamespace(start=1.0, end=2.0) + overlaps = [SimpleNamespace(start=0.5, end=2.5)] + assert has_overlap(turn, overlaps) is True + + +# --------------------------------------------------------------------------- +# Adapter construction and protocol conformance +# --------------------------------------------------------------------------- + + +class TestPyAnnoteDiarizationAdapterConstruction: + def test_defaults(self) -> None: + a = PyAnnoteDiarizationAdapter() + assert a.model_id == "pyannote/speaker-diarization-3.1" + assert a.revision is None + assert a.hf_token == "" + assert a.device == "cuda" + assert a.min_length == 0.5 + assert a.max_length == 40.0 + assert a.write_rttm is True + assert a.random_seed is None + assert a.last_metrics == {} + + def test_conforms_to_protocol(self) -> None: + from nemo_curator.adapters.diarization.base import DiarizationAdapter + + a = PyAnnoteDiarizationAdapter() + assert isinstance(a, DiarizationAdapter) + + +# --------------------------------------------------------------------------- +# prefetch_weights +# --------------------------------------------------------------------------- + + +class TestPrefetchWeights: + def test_prefetch_requires_hf_token(self) -> None: + with patch.dict(os.environ, {}, clear=True), pytest.raises(RuntimeError, match="HF_TOKEN"): + PyAnnoteDiarizationAdapter.prefetch_weights("pyannote/speaker-diarization-3.1") + + def test_prefetch_calls_from_pretrained(self) -> None: + with patch.dict(os.environ, {"HF_TOKEN": "abc"}, clear=True), patch( + "nemo_curator.adapters.diarization.pyannote.PyAnnotePipeline.from_pretrained" + ) as mock_from: + PyAnnoteDiarizationAdapter.prefetch_weights("pyannote/speaker-diarization-3.1") + mock_from.assert_called_once_with("pyannote/speaker-diarization-3.1", token="abc") + + +# --------------------------------------------------------------------------- +# setup / teardown +# --------------------------------------------------------------------------- + + +class TestSetupTeardown: + @patch("nemo_curator.adapters.diarization.pyannote.WhisperXVADModel") + @patch("nemo_curator.adapters.diarization.pyannote.PyAnnotePipeline") + def test_setup_wires_pipeline_and_vad(self, mock_pipeline_cls: MagicMock, mock_vad_cls: MagicMock) -> None: + mock_pipe = MagicMock() + mock_pipeline_cls.from_pretrained.return_value = mock_pipe + + a = PyAnnoteDiarizationAdapter( + hf_token="tok", + device="cpu", + segmentation_batch_size=64, + embedding_batch_size=32, + random_seed=42, + ) + a.setup() + + mock_pipeline_cls.from_pretrained.assert_called_once_with( + "pyannote/speaker-diarization-3.1", token="tok" + ) + assert mock_pipe.segmentation_batch_size == 64 + assert mock_pipe.embedding_batch_size == 32 + mock_pipe.to.assert_called_once() + mock_vad_cls.assert_called_once() + + @patch("nemo_curator.adapters.diarization.pyannote.WhisperXVADModel") + @patch("nemo_curator.adapters.diarization.pyannote.PyAnnotePipeline") + def test_teardown_clears_state(self, _mp: MagicMock, _mv: MagicMock) -> None: + a = PyAnnoteDiarizationAdapter(hf_token="tok", device="cpu") + a.setup() + a.teardown() + assert a._pipeline is None + assert a._vad_model is None + assert a._rng is None + + +# --------------------------------------------------------------------------- +# diarize_batch +# --------------------------------------------------------------------------- + + +class TestDiarizeBatch: + def test_empty_batch_returns_empty_list(self) -> None: + a = PyAnnoteDiarizationAdapter() + assert a.diarize_batch([]) == [] + + def test_diarize_batch_requires_setup(self) -> None: + a = PyAnnoteDiarizationAdapter() + with pytest.raises(RuntimeError, match="setup\\(\\) must be called"): + a.diarize_batch([{"audio_filepath": "/tmp/x.wav"}]) + + def test_missing_audio_filepath_yields_empty_result(self) -> None: + a = PyAnnoteDiarizationAdapter() + a._pipeline = MagicMock() # bypass setup() guard + a._vad_model = MagicMock() + results = a.diarize_batch([{"audio_filepath": None}]) + assert len(results) == 1 + assert results[0].diar_segments == [] + assert results[0].overlap_segments == [] + assert a.last_metrics["batch_size"] == 1.0 + + +# --------------------------------------------------------------------------- +# _add_vad_segments micro-split +# --------------------------------------------------------------------------- + + +class TestAddVadSegments: + def _make_adapter(self) -> PyAnnoteDiarizationAdapter: + a = PyAnnoteDiarizationAdapter(min_length=0.5, max_length=10.0, random_seed=0) + a._vad_model = MagicMock() + a._rng = __import__("random").Random(0) + return a + + def test_short_turn_emits_single_segment(self) -> None: + a = self._make_adapter() + out: list = [] + # 5-sec turn (<= max_length=10) -> no VAD micro-split. + import torch + + audio = torch.zeros(1, 16000 * 12, dtype=torch.float32) + a._add_vad_segments(audio=audio, fs=16000, start=0.0, end=5.0, segments=out, speaker_id="A") + assert len(out) == 1 + assert out[0].speaker == "A" + assert out[0].start == 0.0 + assert out[0].end == 5.0 + + def test_long_turn_triggers_vad_micro_split(self) -> None: + a = self._make_adapter() + out: list = [] + import torch + + audio = torch.zeros(1, 16000 * 30, dtype=torch.float32) + # Adapter VAD returns two windows that each exceed any rand-sample length; + # _add_vad_segments must emit one DiarSegment per window. + a._vad_model.get_vad_segments.return_value = [ + {"start": 0.0, "end": 6.0}, + {"start": 6.5, "end": 12.0}, + ] + a._add_vad_segments(audio=audio, fs=16000, start=10.0, end=25.0, segments=out, speaker_id="B") + assert len(out) == 2 + # Offsets land in absolute clip-coordinate space (start + sub-window start). + assert out[0].start == 10.0 + assert out[0].end == 16.0 + assert out[1].start == 16.5 + assert out[1].end == 22.0 + assert all(seg.speaker == "B" for seg in out) diff --git a/tests/stages/audio/inference/speaker_diarization/test_diarization_stage.py b/tests/stages/audio/inference/speaker_diarization/test_diarization_stage.py new file mode 100644 index 0000000000..4ff5b64b6e --- /dev/null +++ b/tests/stages/audio/inference/speaker_diarization/test_diarization_stage.py @@ -0,0 +1,324 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the generic DiarizationStage. + +The stage is tested with a fake adapter so these tests never touch +PyAnnote and never need a GPU. The PyAnnote-specific code is covered by +``tests/adapters/diarization/test_pyannote_adapter.py``. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any +from unittest.mock import patch + +import pytest + +from nemo_curator.adapters.diarization import DiarizationResult, DiarSegment +from nemo_curator.stages.audio.inference.speaker_diarization import DiarizationStage +from nemo_curator.tasks import AudioTask + + +# --------------------------------------------------------------------------- +# Fake adapter used as the swap target +# --------------------------------------------------------------------------- + + +@dataclass +class _FakeDiarAdapter: + """In-process diarization adapter used as DiarizationStage.adapter_target.""" + + model_id: str = "fake/diar" + revision: str | None = None + device: str = "cpu" + fixed_result: DiarizationResult | None = None + setup_called: int = 0 + teardown_called: int = 0 + last_batch: list[dict[str, Any]] | None = None + last_metrics: dict[str, float] = field(default_factory=dict) + + @classmethod + def prefetch_weights(cls, model_id: str, revision: str | None = None) -> None: + del model_id, revision # no-op for the fake + + def setup(self) -> None: + self.setup_called += 1 + + def teardown(self) -> None: + self.teardown_called += 1 + + def diarize_batch(self, items: list[dict[str, Any]]) -> list[DiarizationResult]: + self.last_batch = list(items) + self.last_metrics = {"batch_size": float(len(items))} + if self.fixed_result is not None: + return [self.fixed_result for _ in items] + return [ + DiarizationResult( + diar_segments=[DiarSegment(start=0.0, end=1.0, speaker="spk_0")], + overlap_segments=[], + model_id=self.model_id, + ) + for _ in items + ] + + +_ADAPTER_TARGET = f"{__name__}._FakeDiarAdapter" + + +def _audio_task(**data: Any) -> AudioTask: # noqa: ANN401 + base = {"resampled_audio_filepath": "/tmp/fake.wav", "audio_item_id": "id_1", "duration": 10.0} + base.update(data) + return AudioTask(data=base) + + +# --------------------------------------------------------------------------- +# Construction / validation +# --------------------------------------------------------------------------- + + +class TestDiarizationStageConstruction: + def test_requires_adapter_target(self) -> None: + with pytest.raises(ValueError, match="adapter_target is required"): + DiarizationStage() + + def test_default_io_keys(self) -> None: + stage = DiarizationStage(adapter_target=_ADAPTER_TARGET) + assert stage.audio_filepath_key == "resampled_audio_filepath" + assert stage.segments_key == "segments" + assert stage.overlap_segments_key == "overlap_segments" + assert stage.non_speaker_max_length == 40.0 + + def test_inputs_outputs(self) -> None: + stage = DiarizationStage(adapter_target=_ADAPTER_TARGET) + assert stage.inputs() == ([], ["resampled_audio_filepath"]) + assert stage.outputs() == ([], ["resampled_audio_filepath", "segments", "overlap_segments"]) + + def test_outputs_skips_overlap_key_when_disabled(self) -> None: + stage = DiarizationStage(adapter_target=_ADAPTER_TARGET, overlap_segments_key=None) + assert stage.outputs() == ([], ["resampled_audio_filepath", "segments"]) + + +# --------------------------------------------------------------------------- +# Adapter lifecycle (setup_on_node + setup + teardown) +# --------------------------------------------------------------------------- + + +class TestDiarizationStageLifecycle: + def test_setup_instantiates_adapter_with_tier1_and_tier2_kwargs(self) -> None: + stage = DiarizationStage( + adapter_target=_ADAPTER_TARGET, + model_id="pyannote/x", + revision="rev-1", + adapter_kwargs={"fixed_result": None}, + ) + stage.setup() + assert isinstance(stage._adapter, _FakeDiarAdapter) + assert stage._adapter.model_id == "pyannote/x" + assert stage._adapter.revision == "rev-1" + assert stage._adapter.setup_called == 1 + + def test_setup_forwards_device(self) -> None: + stage = DiarizationStage(adapter_target=_ADAPTER_TARGET) + stage.setup() + # Default Resources(gpus=1) -> requires_gpu -> "cuda" + assert stage._adapter.device == "cuda" + + def test_teardown_releases_adapter(self) -> None: + stage = DiarizationStage(adapter_target=_ADAPTER_TARGET) + stage.setup() + adapter = stage._adapter + stage.teardown() + assert stage._adapter is None + assert adapter.teardown_called == 1 + + def test_setup_on_node_calls_prefetch_weights(self) -> None: + with patch.object(_FakeDiarAdapter, "prefetch_weights") as mock_prefetch: + stage = DiarizationStage(adapter_target=_ADAPTER_TARGET, model_id="m", revision="r") + stage.setup_on_node() + mock_prefetch.assert_called_once_with("m", "r") + + def test_setup_on_node_propagates_when_fail_on_error(self) -> None: + with patch.object(_FakeDiarAdapter, "prefetch_weights", side_effect=RuntimeError("boom")): + stage = DiarizationStage(adapter_target=_ADAPTER_TARGET, prefetch_fail_on_error=True) + with pytest.raises(RuntimeError, match="prefetch_weights failed"): + stage.setup_on_node() + + def test_setup_on_node_swallows_when_fail_disabled(self) -> None: + with patch.object(_FakeDiarAdapter, "prefetch_weights", side_effect=RuntimeError("boom")): + stage = DiarizationStage(adapter_target=_ADAPTER_TARGET, prefetch_fail_on_error=False) + stage.setup_on_node() # must not raise + + def test_xenna_stage_spec_default_empty(self) -> None: + stage = DiarizationStage(adapter_target=_ADAPTER_TARGET) + assert stage.xenna_stage_spec() == {} + + def test_xenna_stage_spec_with_num_workers(self) -> None: + stage = DiarizationStage(adapter_target=_ADAPTER_TARGET, xenna_num_workers=4) + assert stage.xenna_stage_spec() == {"num_workers": 4} + + +# --------------------------------------------------------------------------- +# Processing +# --------------------------------------------------------------------------- + + +class TestDiarizationStageProcess: + def test_process_requires_setup_first(self) -> None: + stage = DiarizationStage(adapter_target=_ADAPTER_TARGET) + with pytest.raises(RuntimeError, match="setup\\(\\) was not called"): + stage.process(_audio_task()) + + def test_process_requires_audio_filepath_key(self) -> None: + stage = DiarizationStage(adapter_target=_ADAPTER_TARGET) + stage.setup() + task = AudioTask(data={"audio_item_id": "id_1"}) + with pytest.raises(ValueError, match="Missing key 'resampled_audio_filepath'"): + stage.process(task) + + def test_process_writes_segments_and_overlap_keys(self) -> None: + stage = DiarizationStage(adapter_target=_ADAPTER_TARGET) + stage.setup() + stage._adapter.fixed_result = DiarizationResult( + diar_segments=[ + DiarSegment(start=0.0, end=2.0, speaker="id_1_spk_0"), + DiarSegment(start=4.0, end=6.0, speaker="id_1_spk_1"), + ], + overlap_segments=[DiarSegment(start=3.0, end=4.0, speaker="id_1_spk_0")], + model_id="pyannote/x", + ) + task = stage.process(_audio_task(duration=10.0)) + segs = task.data["segments"] + # 2 real turns + 3 no-speaker gap fills (0-0, 2-4, 6-10) + speaker_turns = [s for s in segs if s["speaker"] != "no-speaker"] + assert len(speaker_turns) == 2 + no_speakers = [s for s in segs if s["speaker"] == "no-speaker"] + assert len(no_speakers) >= 2 + assert task.data["overlap_segments"] == [ + {"speaker": "id_1_spk_0", "start": 3.0, "end": 4.0} + ] + + def test_process_skips_overlap_key_when_disabled(self) -> None: + stage = DiarizationStage(adapter_target=_ADAPTER_TARGET, overlap_segments_key=None) + stage.setup() + stage._adapter.fixed_result = DiarizationResult( + diar_segments=[DiarSegment(start=0.0, end=2.0, speaker="spk_0")], + overlap_segments=[DiarSegment(start=5.0, end=6.0, speaker="spk_0")], + model_id="x", + ) + task = stage.process(_audio_task(duration=10.0)) + assert "overlap_segments" not in task.data + assert "segments" in task.data + + def test_process_non_speaker_max_length_chunks_long_gaps(self) -> None: + stage = DiarizationStage( + adapter_target=_ADAPTER_TARGET, + non_speaker_max_length=5.0, + ) + stage.setup() + stage._adapter.fixed_result = DiarizationResult( + diar_segments=[DiarSegment(start=0.0, end=1.0, speaker="spk_0")], + overlap_segments=[], + model_id="x", + ) + task = stage.process(_audio_task(duration=30.0)) + no_speakers = [s for s in task.data["segments"] if s["speaker"] == "no-speaker"] + # 1s..30s = 29s gap, chunked at 5s -> ceil(29/5)=6 chunks. + assert len(no_speakers) == 6 + for ns in no_speakers: + assert ns["end"] - ns["start"] <= 5.0 + 1e-9 + + def test_process_forwards_item_dict_to_adapter(self) -> None: + stage = DiarizationStage( + adapter_target=_ADAPTER_TARGET, + waveform_key="waveform", + sample_rate_key="sample_rate", + ) + stage.setup() + task = _audio_task( + duration=8.0, + waveform=[0.0, 1.0], + sample_rate=16000, + ) + stage.process(task) + batch = stage._adapter.last_batch + assert batch is not None and len(batch) == 1 + item = batch[0] + assert item["audio_filepath"] == "/tmp/fake.wav" + assert item["audio_item_id"] == "id_1" + assert item["duration"] == 8.0 + assert item["waveform"] == [0.0, 1.0] + assert item["sample_rate"] == 16000 + + def test_process_uses_get_audio_duration_when_missing(self) -> None: + stage = DiarizationStage(adapter_target=_ADAPTER_TARGET) + stage.setup() + with patch( + "nemo_curator.stages.audio.inference.speaker_diarization.stage.get_audio_duration", + return_value=12.0, + ) as mock_dur: + task = AudioTask( + data={"resampled_audio_filepath": "/tmp/fake.wav", "audio_item_id": "id_1"} + ) + stage.process(task) + mock_dur.assert_called_once_with("/tmp/fake.wav") + + def test_process_uses_default_diar_segment_when_adapter_returns_empty(self) -> None: + stage = DiarizationStage(adapter_target=_ADAPTER_TARGET) + stage.setup() + stage._adapter.fixed_result = DiarizationResult( + diar_segments=[], overlap_segments=[], model_id="x" + ) + task = stage.process(_audio_task(duration=4.0)) + # Whole clip is one big no-speaker block, chunked at default 40.0s. + no_speakers = [s for s in task.data["segments"] if s["speaker"] == "no-speaker"] + assert len(no_speakers) == 1 + assert no_speakers[0]["start"] == 0.0 + assert no_speakers[0]["end"] == 4.0 + + +# --------------------------------------------------------------------------- +# Metric logging +# --------------------------------------------------------------------------- + + +class TestDiarizationStageMetrics: + def test_log_metrics_includes_adapter_aliases(self) -> None: + stage = DiarizationStage(adapter_target=_ADAPTER_TARGET) + stage.setup() + observed: list[dict[str, float]] = [] + + def capture(metrics: dict[str, float]) -> None: + observed.append(metrics) + + stage._log_metrics = capture # type: ignore[assignment] + stage._adapter.fixed_result = DiarizationResult( + diar_segments=[ + DiarSegment(start=0.0, end=1.0, speaker="A"), + DiarSegment(start=2.0, end=3.0, speaker="B"), + ], + overlap_segments=[DiarSegment(start=1.5, end=2.0, speaker="A")], + model_id="x", + ) + stage.process(_audio_task(duration=5.0)) + + assert observed, "log_metrics must be called" + m = observed[-1] + assert m["speakers_detected"] == 2.0 + assert m["overlap_segments_detected"] == 1.0 + assert m["audio_duration"] == 5.0 + # Adapter-side last_metrics keys must be prefixed with "model_". + assert m["model_batch_size"] == 1.0 + assert "process_time" in m diff --git a/tests/stages/audio/inference/speaker_diarization/test_pyannote.py b/tests/stages/audio/inference/speaker_diarization/test_pyannote.py deleted file mode 100644 index 8415c728ff..0000000000 --- a/tests/stages/audio/inference/speaker_diarization/test_pyannote.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from pathlib import Path - -import pytest - -from nemo_curator.stages.audio.inference.speaker_diarization.pyannote import PyAnnoteDiarizationStage, has_overlap -from nemo_curator.stages.resources import Resources -from nemo_curator.tasks import AudioTask - -hf_token = os.getenv("HF_TOKEN") - - -class TestPyannoteHasOverlap: - """Tests for has_overlap helper.""" - - def test_turn_overlaps_with_segment(self) -> None: - """Turn that overlaps an overlap segment returns True.""" - - class Turn: - start = 0.0 - end = 2.0 - - class Overlap: - start = 1.0 - end = 1.5 - - turn = Turn() - overlaps = [Overlap()] - assert has_overlap(turn, overlaps) is True - - def test_turn_after_overlap_returns_false(self) -> None: - """Turn entirely after overlap returns False.""" - - class Turn: - start = 3.0 - end = 4.0 - - class Overlap: - start = 1.0 - end = 2.0 - - turn = Turn() - overlaps = [Overlap()] - assert has_overlap(turn, overlaps) is False - - def test_turn_before_overlap_returns_false(self) -> None: - """Turn entirely before overlap returns False.""" - - class Turn: - start = 0.0 - end = 0.5 - - class Overlap: - start = 1.0 - end = 2.0 - - turn = Turn() - overlaps = [Overlap()] - assert has_overlap(turn, overlaps) is False - - def test_empty_overlaps_returns_false(self) -> None: - """Empty overlaps list returns False.""" - - class Turn: - start = 0.0 - end = 1.0 - - turn = Turn() - assert has_overlap(turn, []) is False - - -class TestPyAnnoteDiarizationStage: - """Tests for PyAnnoteDiarizationStage.""" - - @pytest.mark.gpu - @pytest.mark.skipif(not hf_token, reason="HF_TOKEN not set") - def test_process(self, wav_filepath: Path) -> None: - """Process a single entry for diarization.""" - stage = PyAnnoteDiarizationStage(hf_token=hf_token, resources=Resources(gpus=1)) - stage.setup_on_node() - stage.setup() - data_entry = { - "resampled_audio_filepath": str(wav_filepath), - "audio_item_id": "id_1", - "duration": 60.0, - } - task = AudioTask(data=data_entry) - result = stage.process(task) - assert result.data["resampled_audio_filepath"] == str(wav_filepath) - segments = result.data["segments"] - assert len(segments) == 33 - # assert len(segments) < 100, "Sanity check: too many segments suggests an issue" - for segment in segments: - assert "start" in segment, "Segment should have start time" - assert "end" in segment, "Segment should have end time" - assert segment["start"] < segment["end"], "Start should be before end" - assert 0 <= segment["start"] <= 60.0, "Start within audio duration" - assert 0 <= segment["end"] <= 60.0, "End within audio duration" diff --git a/tests/stages/audio/tagging/e2e/configs/tts_pipeline.yaml b/tests/stages/audio/tagging/e2e/configs/tts_pipeline.yaml index fce61ffbab..3c1c506b1e 100644 --- a/tests/stages/audio/tagging/e2e/configs/tts_pipeline.yaml +++ b/tests/stages/audio/tagging/e2e/configs/tts_pipeline.yaml @@ -30,16 +30,20 @@ stages: target_nchannels: 1 resources: ${resources} - # 2. Speaker diarization (PyAnnote) - - _target_: nemo_curator.stages.audio.inference.speaker_diarization.pyannote.PyAnnoteDiarizationStage + # 2. Speaker diarization (DiarizationStage + PyAnnoteDiarizationAdapter) + - _target_: nemo_curator.stages.audio.inference.speaker_diarization.DiarizationStage name: "PyAnnoteDiarization" - hf_token: ${hf_token} - max_length: ${max_segment_length} - segmentation_batch_size: 2 - embedding_batch_size: 2 + adapter_target: nemo_curator.adapters.diarization.PyAnnoteDiarizationAdapter + model_id: "pyannote/speaker-diarization-3.1" + non_speaker_max_length: ${max_segment_length} xenna_num_workers: 1 resources: gpus: 1 + adapter_kwargs: + hf_token: ${hf_token} + max_length: ${max_segment_length} + segmentation_batch_size: 2 + embedding_batch_size: 2 # 3. Split long audio segments - _target_: nemo_curator.stages.audio.tagging.SplitLongAudioStage diff --git a/tutorials/audio/tagging/README.md b/tutorials/audio/tagging/README.md index 10f5f2e60c..dff42b3a5a 100644 --- a/tutorials/audio/tagging/README.md +++ b/tutorials/audio/tagging/README.md @@ -26,7 +26,7 @@ The audio tagging pipeline is a processing framework that takes raw audio files |---|-------|-------------|-----| | 0 | **ManifestReader** | Reads input JSONL manifest | No | | 1 | **ResampleAudioStage** | Resample to 16 kHz mono WAV | No | -| 2 | **PyAnnoteDiarizationStage** | Speaker diarization and overlap detection | Yes | +| 2 | **DiarizationStage** (+ `PyAnnoteDiarizationAdapter`) | Speaker diarization and overlap detection | Yes | | 3 | **SplitLongAudioStage** | Split segments exceeding max length | No | | 4 | **NeMoASRAlignerStage** | Forced alignment via NeMo FastConformer | Yes | | 5 | **JoinSplitAudioMetadataStage** | Rejoin split audio metadata | No | @@ -161,8 +161,9 @@ python tutorials/audio/tagging/main.py \ Override individual stage parameters using their index in the `stages` list: ```bash -# Change diarization model -stages.2.diarization_model=pyannote/speaker-diarization-3.1 +# Change diarization model (Tier-1 swap line) +stages.2.adapter_target=nemo_curator.adapters.diarization.PyAnnoteDiarizationAdapter +stages.2.model_id=pyannote/speaker-diarization-3.1 # Adjust ASR batch size stages.4.batch_size=16 @@ -315,7 +316,7 @@ See the test file for detailed comments on the pipeline steps and configuration ### GPU Out of Memory - Reduce `stages.4.batch_size` (ASR alignment) -- Reduce `stages.2.segmentation_batch_size` (diarization) +- Reduce `stages.2.adapter_kwargs.segmentation_batch_size` (diarization) - Process fewer files per manifest - See [GPU Memory Requirements](#gpu-memory-requirements) for per-model VRAM usage diff --git a/tutorials/audio/tagging/tts_pipeline.yaml b/tutorials/audio/tagging/tts_pipeline.yaml index 1096c24547..7fe61c14c2 100644 --- a/tutorials/audio/tagging/tts_pipeline.yaml +++ b/tutorials/audio/tagging/tts_pipeline.yaml @@ -21,7 +21,7 @@ # Pipeline stages: # 0. ManifestReader - Read input JSONL manifest # 1. ResampleAudio - Resample audio to 16 kHz mono WAV -# 2. PyAnnoteDiarization - Speaker diarization + overlap detection +# 2. DiarizationStage - Speaker diarization + overlap detection (PyAnnote adapter) # 3. SplitLongAudio - Split segments exceeding max length # 4. NeMoASRAligner - Forced alignment via NeMo FastConformer # 5. JoinSplitAudioMetadata - Rejoin split metadata @@ -101,10 +101,14 @@ stages: target_nchannels: 1 resources: ${resources} - - _target_: nemo_curator.stages.audio.inference.speaker_diarization.pyannote.PyAnnoteDiarizationStage + - _target_: nemo_curator.stages.audio.inference.speaker_diarization.DiarizationStage name: "PyAnnoteDiarization" - hf_token: ${hf_token} - max_length: ${max_segment_length} + adapter_target: nemo_curator.adapters.diarization.PyAnnoteDiarizationAdapter + model_id: "pyannote/speaker-diarization-3.1" + non_speaker_max_length: ${max_segment_length} + adapter_kwargs: + hf_token: ${hf_token} + max_length: ${max_segment_length} - _target_: nemo_curator.stages.audio.tagging.SplitLongAudioStage name: "SplitLongAudio" From 3be716e0b50f096ed7ac6a337092e4112ce40ed1 Mon Sep 17 00:00:00 2001 From: "aaftaabv@gmail.com" Date: Fri, 29 May 2026 15:18:39 +0530 Subject: [PATCH 2/6] audio/vad: stage-adapter split per SDP-V2 design Bring the SDP-V2 design doc (sec 4 voice activity detection) stage-adapter split into the audio tagging pipeline. What moves ---------- * nemo_curator/adapters/vad/ - base.py - VADInterval + VADResult dataclasses + VADAdapter Protocol (model_id / revision / setup / teardown / prefetch_weights classmethod / detect_batch). - whisperx.py - WhisperXVADAdapter that re-houses the WhisperX VAD code path from the deleted WhisperXVADStage process() body. Uses the same WhisperXVADModel helper as before (kept at its existing nemo_curator/stages/audio/inference/vad/whisperx_vad.py home - PyAnnoteDiarizationAdapter still consumes it for sub-segment VAD, so the helper stays as shared infra). * nemo_curator/stages/audio/inference/vad/stage.py with VADStage that owns Curator-side glue only: task.data key reads, item-dict construction, adapter dispatch, VADInterval to on-disk dict conversion, metric logging. YAML shape (Tier-1 / Tier-2 split) ---------------------------------- - _target_: nemo_curator.stages.audio.inference.vad.VADStage name: WhisperXVAD adapter_target: nemo_curator.adapters.vad.WhisperXVADAdapter model_id: whisperx/vad adapter_kwargs: vad_onset: 0.5 vad_offset: 0.363 min_length: 0.5 max_length: 40.0 What is deleted --------------- * WhisperXVADStage class (removed from nemo_curator/stages/audio/inference/vad/whisperx_vad.py; the file is now just the WhisperXVADModel helper). * tests/stages/audio/inference/vad/test_whisperx_vad.py Call-site migrations -------------------- * nemo_curator/stages/audio/inference/vad/__init__.py now re-exports VADStage (new) and WhisperXVADModel (helper still used by both adapters). * nemo_curator/stages/audio/tagging/__init__.py lazy-import map drops WhisperXVADStage; adds VADStage and WhisperXVADAdapter. New tests --------- * tests/stages/audio/inference/vad/test_vad_stage.py - 14 CPU tests against a fake adapter (construction, lifecycle, process(), metric logging, no GPU / WhisperX needed). * tests/adapters/vad/test_whisperx_adapter.py - WhisperXVADAdapter construction, prefetch_weights, setup/teardown lifecycle (WhisperXVADModel mocked), detect_batch empty / missing-filepath / short-clip-skip / happy-path with mocked sf.read and WhisperX VAD model. Behaviour preservation ---------------------- * Skip-short-clip rule, vad_onset/vad_offset thresholds, max_length chunk-merge - all preserved from pre-split WhisperXVADStage. * PyAnnoteDiarizationAdapter (commit 1) still constructs a WhisperXVADModel for its long-turn sub-VAD path; behaviour unchanged. * No tutorials/audio/tagging/tts_pipeline.yaml migration is needed because the tagging tutorial does not wire WhisperXVADStage standalone - it only enters via the PyAnnote diarization adapter. Follow-ups (out of scope; same posture as PR1967) ------------------------------------------------- * SileroVADAdapter / PyAnnoteVADAdapter - the stage is now shaped for them; they ship in a separate commit. Signed-off-by: aaftaabv@gmail.com --- nemo_curator/adapters/vad/__init__.py | 37 +++ nemo_curator/adapters/vad/base.py | 121 +++++++++ nemo_curator/adapters/vad/whisperx.py | 171 +++++++++++++ .../stages/audio/inference/vad/__init__.py | 7 +- .../stages/audio/inference/vad/stage.py | 237 ++++++++++++++++++ .../audio/inference/vad/whisperx_vad.py | 123 ++------- nemo_curator/stages/audio/tagging/__init__.py | 3 +- tests/adapters/vad/__init__.py | 0 tests/adapters/vad/test_whisperx_adapter.py | 119 +++++++++ .../audio/inference/vad/test_vad_stage.py | 201 +++++++++++++++ .../audio/inference/vad/test_whisperx_vad.py | 44 ---- 11 files changed, 911 insertions(+), 152 deletions(-) create mode 100644 nemo_curator/adapters/vad/__init__.py create mode 100644 nemo_curator/adapters/vad/base.py create mode 100644 nemo_curator/adapters/vad/whisperx.py create mode 100644 nemo_curator/stages/audio/inference/vad/stage.py create mode 100644 tests/adapters/vad/__init__.py create mode 100644 tests/adapters/vad/test_whisperx_adapter.py create mode 100644 tests/stages/audio/inference/vad/test_vad_stage.py delete mode 100644 tests/stages/audio/inference/vad/test_whisperx_vad.py diff --git a/nemo_curator/adapters/vad/__init__.py b/nemo_curator/adapters/vad/__init__.py new file mode 100644 index 0000000000..07503f913e --- /dev/null +++ b/nemo_curator/adapters/vad/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""VAD adapter family for the SDP-V2 stage-adapter split. + +Public surface (the only symbols the stage imports): + +* :class:`VADAdapter` - structural protocol every VAD adapter + implements. +* :class:`VADResult` - canonical per-utterance result dataclass. +* :class:`VADInterval` - canonical per-segment dataclass. + +Concrete adapters live in their own modules (e.g. ``whisperx.py``, +``silero.py``) and are resolved at runtime by their fully-qualified +class path in YAML's ``adapter_target`` field. +""" + +from nemo_curator.adapters.vad.base import VADAdapter, VADInterval, VADResult +from nemo_curator.adapters.vad.whisperx import WhisperXVADAdapter + +__all__ = [ + "VADAdapter", + "VADInterval", + "VADResult", + "WhisperXVADAdapter", +] diff --git a/nemo_curator/adapters/vad/base.py b/nemo_curator/adapters/vad/base.py new file mode 100644 index 0000000000..22eb26e238 --- /dev/null +++ b/nemo_curator/adapters/vad/base.py @@ -0,0 +1,121 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Stage-adapter contract for voice activity detection (SDP-V2 design doc §4). + +Mirrors the diarization / ASR contract pattern: + +* :class:`~nemo_curator.stages.audio.inference.vad.VADStage` owns the + Curator-side glue (``task.data`` reads, duration-skip rule, metric + logging). +* :class:`VADAdapter` owns the model-side library call (weight prefetch, + model setup, the actual VAD invocation - WhisperX / Silero / PyAnnote + VAD - and packing results into the canonical :class:`VADResult` shape). +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Protocol, runtime_checkable + + +@dataclass +class VADInterval: + """Canonical per-VAD-region dataclass. + + Attributes: + start: Interval start time in seconds (clip coordinates). + end: Interval end time in seconds. + """ + + start: float + end: float + + +@dataclass +class VADResult: + """Canonical per-task VAD adapter output. + + Attributes: + intervals: One :class:`VADInterval` per emitted speech region. + Empty list when the adapter could not process the input. + model_id: The actual model identifier the adapter ran (mirrors + the stage's ``model_id`` field; populated by the adapter). + extras: Adapter-specific scalar / structured diagnostics that + do not fit the canonical shape. Stage never reads inside + this dict. + """ + + intervals: list[VADInterval] + model_id: str = "" + extras: dict[str, Any] = field(default_factory=dict) + + +@runtime_checkable +class VADAdapter(Protocol): + """Structural protocol every VAD adapter must implement. + + Constructor contract: adapters are constructed by the stage as + ``cls(model_id=..., revision=..., **adapter_kwargs)``. Tier-2 knobs + are adapter-specific. + + Per-batch contract: :meth:`detect_batch` receives a list of dicts + (Tier-3 per-task knobs unpacked from ``task.data`` by the stage) + and returns one :class:`VADResult` per input, in the same order. + + Expected per-item keys (the stage populates these; the adapter + reads whichever it needs): + + * ``audio_filepath`` (``str``): Path to a decodable audio file. + * ``waveform`` (``numpy.ndarray | None``): Optional in-memory + waveform. + * ``sample_rate`` (``int | None``): Sample rate of ``waveform``. + * ``duration`` (``float | None``): Optional clip duration in + seconds. + * ``task_id`` (``str | None``): Diagnostic only. + + Attributes: + model_id: Identifier of the underlying model checkpoint. + last_metrics: Scalar metrics from the last :meth:`detect_batch` + call. The stage merges these into ``_log_metrics`` output + under ``model_`` aliases. + """ + + model_id: str + last_metrics: dict[str, float] + + @classmethod + def prefetch_weights(cls, model_id: str, revision: str | None = None) -> None: + """Download weights to local cache without allocating a GPU.""" + ... + + def setup(self) -> None: + """Load the model into the worker's process.""" + ... + + def teardown(self) -> None: + """Release GPU memory and worker-local state.""" + ... + + def detect_batch(self, items: list[dict[str, Any]]) -> list[VADResult]: + """Run VAD on a batch of per-task dicts. + + Args: + items: One dict per task with the keys documented on the + class docstring. Length matches the batch size. + + Returns: + One :class:`VADResult` per input, in the same order. + """ + ... diff --git a/nemo_curator/adapters/vad/whisperx.py b/nemo_curator/adapters/vad/whisperx.py new file mode 100644 index 0000000000..5e366131ae --- /dev/null +++ b/nemo_curator/adapters/vad/whisperx.py @@ -0,0 +1,171 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""WhisperX VAD adapter. + +Implements :class:`~nemo_curator.adapters.vad.VADAdapter` on top of +WhisperX's ``Pyannote.merge_chunks``-based VAD helper. The underlying +:class:`~nemo_curator.stages.audio.inference.vad.whisperx_vad.WhisperXVADModel` +class is kept where it is - PyAnnote diarization also depends on it +for its long-turn micro-split path - and this adapter just wraps it +behind the canonical Protocol. + +Logic moved verbatim from the pre-split +``nemo_curator.stages.audio.inference.vad.whisperx_vad.WhisperXVADStage`` +process() body. +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from typing import Any + +import numpy as np +import soundfile as sf +from loguru import logger + +from nemo_curator.adapters.vad.base import VADInterval, VADResult +from nemo_curator.stages.audio.common import get_audio_duration +from nemo_curator.stages.audio.inference.vad.whisperx_vad import WhisperXVADModel + + +@dataclass +class WhisperXVADAdapter: + """WhisperX-backed implementation of :class:`VADAdapter`. + + Attributes: + model_id: Identifier for diagnostics (WhisperX VAD doesn't use + a HF id; default ``"whisperx/vad"`` is a label only). + revision: Accepted for protocol uniformity; not used by + WhisperX VAD. + device: ``"cuda"`` or ``"cpu"``. The stage passes the worker's + actual device. + vad_onset: Onset probability threshold forwarded to WhisperX. + vad_offset: Offset probability threshold forwarded to WhisperX. + max_length: Maximum chunk length passed to + :meth:`WhisperXVADModel.get_vad_segments`. + min_length: Minimum clip duration; clips shorter than this are + skipped entirely (empty :class:`VADResult`). Matches the + pre-split ``WhisperXVADStage`` behaviour. + """ + + # ---- Required protocol fields ---- + model_id: str = "whisperx/vad" + revision: str | None = None + + # ---- Adapter-specific knobs ---- + device: str = "cuda" + vad_onset: float = 0.5 + vad_offset: float = 0.363 + max_length: float = 40.0 + min_length: float = 0.5 + + # ---- Internal state ---- + _vad_model: Any = field(default=None, repr=False) + last_metrics: dict[str, float] = field(default_factory=dict) + + # ------------------------------------------------------------------ + # Adapter contract + # ------------------------------------------------------------------ + + @classmethod + def prefetch_weights(cls, model_id: str, revision: str | None = None) -> None: + """Force a CPU-side download of the WhisperX VAD weights. + + WhisperX VAD downloads on first instantiation; we trigger that + once on the node by constructing the model on CPU and then + dropping it. The actual GPU placement happens in + :meth:`setup`. + """ + del model_id, revision # WhisperX VAD has no public model_id knob. + _ = WhisperXVADModel(device="cpu") + + def setup(self) -> None: + if self._vad_model is None: + self._vad_model = WhisperXVADModel( + device=self.device, + vad_onset=self.vad_onset, + vad_offset=self.vad_offset, + ) + self._vad_model.to(self.device) + logger.info("WhisperXVADAdapter ready on {} (model_id={})", self.device, self.model_id) + + def teardown(self) -> None: + self._vad_model = None + + def detect_batch(self, items: list[dict[str, Any]]) -> list[VADResult]: + if not items: + return [] + if self._vad_model is None: + msg = "WhisperXVADAdapter.setup() must be called before detect_batch()" + raise RuntimeError(msg) + + results: list[VADResult] = [] + per_item_times: list[float] = [] + per_item_skip: list[float] = [] + per_item_count: list[int] = [] + + for item in items: + t0 = time.perf_counter() + result, skipped = self._detect_one(item) + results.append(result) + per_item_times.append(time.perf_counter() - t0) + per_item_skip.append(1.0 if skipped else 0.0) + per_item_count.append(len(result.intervals)) + + self.last_metrics = { + "batch_size": float(len(items)), + "vad_time_s_total": float(sum(per_item_times)), + "vad_time_s_max": float(max(per_item_times)) if per_item_times else 0.0, + "skipped_short_total": float(sum(per_item_skip)), + "vad_intervals_detected_total": float(sum(per_item_count)), + } + return results + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _detect_one(self, item: dict[str, Any]) -> tuple[VADResult, bool]: + audio_filepath = item.get("audio_filepath") + if not audio_filepath: + return VADResult(intervals=[], model_id=self.model_id), True + + duration = item.get("duration") + if duration is None: + duration = get_audio_duration(audio_filepath) + if duration < self.min_length: + logger.warning( + "Skipping {} because it is less than {} seconds", audio_filepath, self.min_length + ) + return ( + VADResult(intervals=[], model_id=self.model_id, extras={"duration_s": float(duration)}), + True, + ) + + data, sr = sf.read(audio_filepath, dtype="float32") + audio = np.expand_dims(data, axis=0) if data.ndim == 1 else data.T + raw_segments = self._vad_model.get_vad_segments(audio, self.max_length, sample_rate=sr) + intervals = [ + VADInterval(start=float(seg["start"]), end=float(seg["end"])) for seg in raw_segments + ] + return ( + VADResult( + intervals=intervals, + model_id=self.model_id, + extras={"duration_s": float(duration)}, + ), + False, + ) diff --git a/nemo_curator/stages/audio/inference/vad/__init__.py b/nemo_curator/stages/audio/inference/vad/__init__.py index bedb242969..410d2b47d7 100644 --- a/nemo_curator/stages/audio/inference/vad/__init__.py +++ b/nemo_curator/stages/audio/inference/vad/__init__.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Voice Activity Detection inference stages.""" +"""Generic VAD stage (SDP-V2 stage-adapter split, §4) and WhisperX helper model.""" -from nemo_curator.stages.audio.inference.vad.whisperx_vad import WhisperXVADModel, WhisperXVADStage +from nemo_curator.stages.audio.inference.vad.stage import VADStage +from nemo_curator.stages.audio.inference.vad.whisperx_vad import WhisperXVADModel -__all__ = ["WhisperXVADModel", "WhisperXVADStage"] +__all__ = ["VADStage", "WhisperXVADModel"] diff --git a/nemo_curator/stages/audio/inference/vad/stage.py b/nemo_curator/stages/audio/inference/vad/stage.py new file mode 100644 index 0000000000..591407cd89 --- /dev/null +++ b/nemo_curator/stages/audio/inference/vad/stage.py @@ -0,0 +1,237 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic VAD Curator stage (SDP-V2 design doc §4). + +Implements the stage half of the SDP-V2 stage-adapter split for the +VAD family. The stage owns Curator-side glue: + +* validates ``task.data`` against ``inputs()`` / ``outputs()``; +* unpacks per-task knobs (audio filepath, optional in-memory waveform, + ``duration``) into a single item dict; +* dispatches the adapter's ``detect_batch`` once per task; +* writes the adapter's :class:`VADInterval` list onto + ``task.data[segments_key]`` as canonical ``{"start": float, + "end": float}`` dicts; +* emits performance metrics in the shape ``perf_summary_merged.json`` + consumers already expect. + +The stage knows nothing about which VAD model is running. The concrete +adapter class is resolved at runtime from the YAML's ``adapter_target`` +string via ``hydra.utils.get_class``. +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import hydra.utils +from loguru import logger + +from nemo_curator.adapters.vad.base import VADAdapter, VADResult +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.resources import Resources +from nemo_curator.tasks import AudioTask + +if TYPE_CHECKING: + from nemo_curator.backends.base import NodeInfo, WorkerMetadata + + +@dataclass +class VADStage(ProcessingStage[AudioTask, AudioTask]): + """Voice-activity-detection Curator stage with pluggable adapter. + + Args: + adapter_target: Tier-1 swap surface. Fully-qualified class path + of the concrete :class:`~nemo_curator.adapters.vad.VADAdapter` + implementation (e.g. + ``"nemo_curator.adapters.vad.WhisperXVADAdapter"``). + Resolved at ``setup()`` time via ``hydra.utils.get_class``. + model_id: Tier-1. Model checkpoint identifier, forwarded to + :meth:`VADAdapter.prefetch_weights` and to the adapter + constructor. + revision: Tier-1. Optional model revision to pin. + audio_filepath_key: Key into ``task.data`` carrying the audio + path. Defaults to ``"resampled_audio_filepath"``. + waveform_key / sample_rate_key: Optional keys for in-memory + waveform reuse. + segments_key: Key under which the canonical VAD interval list + is written. + prefetch_fail_on_error: When False, ``setup_on_node`` warns and + defers weight prefetch to ``setup()`` instead of raising. + adapter_kwargs: Tier-2. Opaque dict forwarded to the adapter + constructor as ``**adapter_kwargs``. The stage NEVER reads + inside this dict. + resources / xenna_num_workers: Standard Curator stage knobs. + """ + + name: str = "VAD" + + # ---- Tier 1: swap surface ---- + adapter_target: str = "" + model_id: str = "" + revision: str | None = None + + # ---- Tier 1: universal stage knobs ---- + audio_filepath_key: str = "resampled_audio_filepath" + waveform_key: str | None = None + sample_rate_key: str | None = None + segments_key: str = "vad_segments" + + prefetch_fail_on_error: bool = True + + # ---- Tier 2: opaque adapter knob bag ---- + adapter_kwargs: dict[str, Any] = field(default_factory=dict) + + # ---- Standard Curator stage knobs ---- + resources: Resources = field(default_factory=lambda: Resources(gpus=1)) + xenna_num_workers: int | None = None + + # ---- Internal state ---- + _adapter: VADAdapter | None = field(default=None, init=False, repr=False) + + def __post_init__(self) -> None: + if not self.adapter_target: + msg = ( + "VADStage.adapter_target is required - set it in YAML to a fully-qualified " + "adapter class path (e.g. 'nemo_curator.adapters.vad.WhisperXVADAdapter')." + ) + raise ValueError(msg) + + # ------------------------------------------------------------------ + # I/O contract + # ------------------------------------------------------------------ + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [self.audio_filepath_key] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [self.audio_filepath_key, self.segments_key] + + def xenna_stage_spec(self) -> dict[str, Any]: + spec: dict[str, Any] = {} + if self.xenna_num_workers is not None: + spec["num_workers"] = self.xenna_num_workers + return spec + + @property + def _device(self) -> str: + return "cuda" if self.resources.requires_gpu else "cpu" + + # ------------------------------------------------------------------ + # Adapter lifecycle + # ------------------------------------------------------------------ + + def _adapter_class(self) -> type: + return hydra.utils.get_class(self.adapter_target) + + def setup_on_node( + self, + _node_info: NodeInfo | None = None, + _worker_metadata: WorkerMetadata | None = None, + ) -> None: + try: + prefetch_t0 = time.perf_counter() + self._adapter_class().prefetch_weights(self.model_id, self.revision) + logger.info( + "VAD weights cached on node for {} ({}) in {:.3f}s", + self.model_id, + self.adapter_target, + time.perf_counter() - prefetch_t0, + ) + except Exception as exc: # noqa: BLE001 + msg = f"VADStage: prefetch_weights failed for {self.model_id}" + if self.prefetch_fail_on_error: + raise RuntimeError(msg) from exc + logger.warning("{}; setup() will retry: {}", msg, exc) + + def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None: + if self._adapter is None: + cls = self._adapter_class() + kwargs = dict(self.adapter_kwargs) + if self.model_id: + kwargs.setdefault("model_id", self.model_id) + kwargs.setdefault("revision", self.revision) + kwargs.setdefault("device", self._device) + self._adapter = cls(**kwargs) + self._adapter.setup() + logger.info("[{}] VAD adapter ready ({})", self.name, self.adapter_target) + + def teardown(self) -> None: + if self._adapter is not None: + self._adapter.teardown() + self._adapter = None + + # ------------------------------------------------------------------ + # Processing + # ------------------------------------------------------------------ + + def _build_item(self, task: AudioTask) -> dict[str, Any]: + data = task.data + item: dict[str, Any] = { + "audio_filepath": data.get(self.audio_filepath_key), + "duration": data.get("duration"), + "task_id": getattr(task, "task_id", None), + } + if self.waveform_key: + item["waveform"] = data.get(self.waveform_key) + if self.sample_rate_key: + item["sample_rate"] = data.get(self.sample_rate_key) + return item + + @staticmethod + def _interval_to_dict(interval: Any) -> dict[str, float]: + return {"start": float(interval.start), "end": float(interval.end)} + + def process(self, task: AudioTask) -> AudioTask: + t0 = time.perf_counter() + data_entry = task.data + + if self._adapter is None: + msg = "Adapter not initialized - setup() was not called" + raise RuntimeError(msg) + + file_path = data_entry.get(self.audio_filepath_key) + if not file_path: + msg = ( + f"[{self.name}] Missing key '{self.audio_filepath_key}' in entry: " + f"{data_entry.get('audio_item_id', 'unknown')}" + ) + raise ValueError(msg) + + item = self._build_item(task) + results: list[VADResult] = self._adapter.detect_batch([item]) + result = results[0] if results else VADResult(intervals=[], model_id=self.model_id) + + intervals = [self._interval_to_dict(iv) for iv in result.intervals] + data_entry[self.segments_key] = intervals + + duration = float(item.get("duration") or 0.0) + if duration == 0.0: + duration = float(result.extras.get("duration_s", 0.0) or 0.0) + + metrics: dict[str, float] = { + "process_time": time.perf_counter() - t0, + "audio_duration": duration, + "vad_segments_detected": float(len(intervals)), + "skipped_short": float(result.extras.get("skipped_short", 0.0)) + if "skipped_short" in result.extras + else (1.0 if not intervals and result.extras.get("duration_s", 0.0) < 0 else 0.0), + } + for key, value in (self._adapter.last_metrics or {}).items(): + metrics[f"model_{key}"] = float(value) + self._log_metrics(metrics) + return task diff --git a/nemo_curator/stages/audio/inference/vad/whisperx_vad.py b/nemo_curator/stages/audio/inference/vad/whisperx_vad.py index 1cbe42d0a1..442736fe25 100644 --- a/nemo_curator/stages/audio/inference/vad/whisperx_vad.py +++ b/nemo_curator/stages/audio/inference/vad/whisperx_vad.py @@ -12,37 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -WhisperX VAD for NeMo Curator. +"""WhisperX VAD model helper. + +Provides :class:`WhisperXVADModel`, the shared inference wrapper around +WhisperX's ``Pyannote.merge_chunks``-driven VAD. Two callers consume it: + +* :class:`~nemo_curator.adapters.vad.WhisperXVADAdapter` - the VAD + adapter used by :class:`~nemo_curator.stages.audio.inference.vad.VADStage`. +* :class:`~nemo_curator.adapters.diarization.PyAnnoteDiarizationAdapter` - + uses it to micro-split long PyAnnote speaker turns. -Provides WhisperXVADModel (shared VAD logic for pyannote and standalone VAD) -and WhisperXVADStage (ProcessingStage for VAD-only pipeline use). +The pre-split ``WhisperXVADStage`` lived here too; it was removed in +favour of the SDP-V2 stage-adapter split. """ import os -import time -from dataclasses import dataclass, field -from typing import Any +from typing import TYPE_CHECKING -import numpy as np -import soundfile as sf import torch -from loguru import logger from whisperx.audio import SAMPLE_RATE from whisperx.vads.pyannote import Pyannote, load_vad_model -from nemo_curator.backends.base import NodeInfo, WorkerMetadata -from nemo_curator.stages.audio.common import get_audio_duration -from nemo_curator.stages.base import ProcessingStage -from nemo_curator.stages.resources import Resources -from nemo_curator.tasks import AudioTask +if TYPE_CHECKING: + import numpy as np class WhisperXVADModel: - """Shared VAD model and get_vad_segments logic for PyAnnote and standalone VAD. + """Shared VAD model and ``get_vad_segments`` logic. - Used by PyAnnoteDiarizationStage for sub-segment VAD and by WhisperXVADStage - for VAD-only processing. + Used by :class:`~nemo_curator.adapters.vad.WhisperXVADAdapter` for + standalone VAD and by + :class:`~nemo_curator.adapters.diarization.PyAnnoteDiarizationAdapter` + for sub-segment VAD of long speaker turns. """ def __init__( @@ -100,89 +101,3 @@ def get_vad_segments( } ) return Pyannote.merge_chunks(vad_segments, merge_max_length, onset=self._vad_onset) - - -@dataclass -class WhisperXVADStage(ProcessingStage[AudioTask, AudioTask]): - """ - Stage that performs Voice Activity Detection (VAD) using WhisperX's VAD model. - - Adds VAD segments to each entry under segments_key (e.g. "vad_segments"). - Entries shorter than min_length are skipped (not emitted). - """ - - min_length: float = 0.5 - max_length: float = 40.0 - vad_onset: float = 0.5 - vad_offset: float = 0.363 - segments_key: str = "vad_segments" - audio_filepath_key: str = "resampled_audio_filepath" - - name: str = "WhisperXVAD" - resources: Resources = field(default_factory=lambda: Resources(gpus=1)) - - _vad_model: Any = field(default=None, repr=False) - - def inputs(self) -> tuple[list[str], list[str]]: - return [], [self.audio_filepath_key] - - def outputs(self) -> tuple[list[str], list[str]]: - return [], [self.audio_filepath_key, self.segments_key] - - @property - def _device(self) -> str: - """Derive device from resources configuration.""" - return "cuda" if self.resources.requires_gpu else "cpu" - - def setup_on_node( - self, _node_info: NodeInfo | None = None, _worker_metadata: WorkerMetadata | None = None - ) -> None: - """Setup stage on node.""" - if self._vad_model is None: - self._vad_model = WhisperXVADModel( - device="cpu", - vad_onset=self.vad_onset, - vad_offset=self.vad_offset, - ) - - def setup(self, _: WorkerMetadata | None = None) -> None: - if self._vad_model is None: - self._vad_model = WhisperXVADModel( - device=self._device, - vad_onset=self.vad_onset, - vad_offset=self.vad_offset, - ) - self._vad_model.to(self._device) - logger.info(f"[{self.name}] Initialized WhisperX VAD on {self._device}") - - def process(self, task: AudioTask) -> AudioTask: - t0 = time.perf_counter() - data_entry = task.data - file_path = data_entry[self.audio_filepath_key] - duration = data_entry.get("duration", get_audio_duration(file_path)) - if duration < self.min_length: - logger.warning(f"Skipping {file_path} because it is less than {self.min_length} seconds") - data_entry[self.segments_key] = [] - self._log_metrics( - { - "process_time": time.perf_counter() - t0, - "audio_duration": duration, - "vad_segments_detected": 0, - "skipped_short": 1.0, - } - ) - return task - - data, sr = sf.read(file_path, dtype="float32") - audio = np.expand_dims(data, axis=0) if data.ndim == 1 else data.T - vad_segments = self._vad_model.get_vad_segments(audio, self.max_length, sample_rate=sr) - data_entry[self.segments_key] = vad_segments - self._log_metrics( - { - "process_time": time.perf_counter() - t0, - "audio_duration": duration, - "vad_segments_detected": len(vad_segments), - "skipped_short": 0.0, - } - ) - return task diff --git a/nemo_curator/stages/audio/tagging/__init__.py b/nemo_curator/stages/audio/tagging/__init__.py index b2c914faca..44457c2414 100644 --- a/nemo_curator/stages/audio/tagging/__init__.py +++ b/nemo_curator/stages/audio/tagging/__init__.py @@ -41,7 +41,8 @@ # --- Inference (stage-adapter split per SDP-V2 design) --- "DiarizationStage": "nemo_curator.stages.audio.inference.speaker_diarization", "PyAnnoteDiarizationAdapter": "nemo_curator.adapters.diarization", - "WhisperXVADStage": "nemo_curator.stages.audio.inference.vad.whisperx_vad", + "VADStage": "nemo_curator.stages.audio.inference.vad", + "WhisperXVADAdapter": "nemo_curator.adapters.vad", } _cache: dict[str, Any] = {} diff --git a/tests/adapters/vad/__init__.py b/tests/adapters/vad/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/adapters/vad/test_whisperx_adapter.py b/tests/adapters/vad/test_whisperx_adapter.py new file mode 100644 index 0000000000..6aced81978 --- /dev/null +++ b/tests/adapters/vad/test_whisperx_adapter.py @@ -0,0 +1,119 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for WhisperXVADAdapter (WhisperX VAD model mocked).""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +pytest.importorskip("whisperx") + +from nemo_curator.adapters.vad import VADResult, WhisperXVADAdapter + + +class TestWhisperXVADAdapterConstruction: + def test_defaults(self) -> None: + a = WhisperXVADAdapter() + assert a.model_id == "whisperx/vad" + assert a.device == "cuda" + assert a.vad_onset == 0.5 + assert a.vad_offset == 0.363 + assert a.max_length == 40.0 + assert a.min_length == 0.5 + assert a.last_metrics == {} + + def test_conforms_to_protocol(self) -> None: + from nemo_curator.adapters.vad import VADAdapter + + assert isinstance(WhisperXVADAdapter(), VADAdapter) + + +class TestWhisperXVADAdapterLifecycle: + @patch("nemo_curator.adapters.vad.whisperx.WhisperXVADModel") + def test_prefetch_constructs_cpu_model(self, mock_model_cls: MagicMock) -> None: + WhisperXVADAdapter.prefetch_weights("whisperx/vad") + mock_model_cls.assert_called_once_with(device="cpu") + + @patch("nemo_curator.adapters.vad.whisperx.WhisperXVADModel") + def test_setup_wires_model(self, mock_model_cls: MagicMock) -> None: + mock_model = MagicMock() + mock_model_cls.return_value = mock_model + a = WhisperXVADAdapter(device="cpu", vad_onset=0.7, vad_offset=0.2) + a.setup() + mock_model_cls.assert_called_once_with(device="cpu", vad_onset=0.7, vad_offset=0.2) + mock_model.to.assert_called_once_with("cpu") + + @patch("nemo_curator.adapters.vad.whisperx.WhisperXVADModel") + def test_teardown_clears_state(self, _mc: MagicMock) -> None: + a = WhisperXVADAdapter(device="cpu") + a.setup() + a.teardown() + assert a._vad_model is None + + +class TestWhisperXVADAdapterDetectBatch: + def test_empty_batch_returns_empty(self) -> None: + a = WhisperXVADAdapter() + assert a.detect_batch([]) == [] + + def test_requires_setup(self) -> None: + a = WhisperXVADAdapter() + with pytest.raises(RuntimeError, match="setup\\(\\) must be called"): + a.detect_batch([{"audio_filepath": "/tmp/x.wav"}]) + + def test_missing_audio_filepath_skipped(self) -> None: + a = WhisperXVADAdapter() + a._vad_model = MagicMock() + results = a.detect_batch([{"audio_filepath": None}]) + assert len(results) == 1 + assert results[0].intervals == [] + assert a.last_metrics["skipped_short_total"] == 1.0 + + @patch("nemo_curator.adapters.vad.whisperx.get_audio_duration", return_value=0.2) + def test_short_clip_skipped(self, _mock_dur: MagicMock) -> None: + a = WhisperXVADAdapter(min_length=0.5) + a._vad_model = MagicMock() + results = a.detect_batch([{"audio_filepath": "/tmp/x.wav"}]) + assert results[0].intervals == [] + # Skipped clips MUST NOT call the VAD model. + a._vad_model.get_vad_segments.assert_not_called() + + @patch("nemo_curator.adapters.vad.whisperx.sf.read") + @patch("nemo_curator.adapters.vad.whisperx.get_audio_duration", return_value=5.0) + def test_happy_path_emits_intervals( + self, _mock_dur: MagicMock, mock_read: MagicMock + ) -> None: + # 5-sec mono audio at 16 kHz + mock_read.return_value = (np.zeros(16000 * 5, dtype=np.float32), 16000) + a = WhisperXVADAdapter(device="cpu", min_length=0.5, max_length=40.0) + a._vad_model = MagicMock() + a._vad_model.get_vad_segments.return_value = [ + {"start": 0.5, "end": 2.5}, + {"start": 3.0, "end": 4.7}, + ] + results = a.detect_batch([{"audio_filepath": "/tmp/x.wav"}]) + assert len(results) == 1 + result = results[0] + assert isinstance(result, VADResult) + assert [(iv.start, iv.end) for iv in result.intervals] == [(0.5, 2.5), (3.0, 4.7)] + # mono audio -> expand_dims along axis 0 -> (1, N) + passed_audio = a._vad_model.get_vad_segments.call_args.args[0] + assert passed_audio.shape == (1, 80000) + # last_metrics is populated. + assert a.last_metrics["batch_size"] == 1.0 + assert a.last_metrics["vad_intervals_detected_total"] == 2.0 diff --git a/tests/stages/audio/inference/vad/test_vad_stage.py b/tests/stages/audio/inference/vad/test_vad_stage.py new file mode 100644 index 0000000000..6a64a9392b --- /dev/null +++ b/tests/stages/audio/inference/vad/test_vad_stage.py @@ -0,0 +1,201 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the generic VADStage. + +Stage is tested with a fake adapter -- no WhisperX import, no GPU. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any +from unittest.mock import patch + +import pytest + +from nemo_curator.adapters.vad import VADInterval, VADResult +from nemo_curator.stages.audio.inference.vad import VADStage +from nemo_curator.tasks import AudioTask + + +@dataclass +class _FakeVADAdapter: + model_id: str = "fake/vad" + revision: str | None = None + device: str = "cpu" + fixed_result: VADResult | None = None + setup_called: int = 0 + teardown_called: int = 0 + last_batch: list[dict[str, Any]] | None = None + last_metrics: dict[str, float] = field(default_factory=dict) + + @classmethod + def prefetch_weights(cls, model_id: str, revision: str | None = None) -> None: + del model_id, revision + + def setup(self) -> None: + self.setup_called += 1 + + def teardown(self) -> None: + self.teardown_called += 1 + + def detect_batch(self, items: list[dict[str, Any]]) -> list[VADResult]: + self.last_batch = list(items) + self.last_metrics = {"batch_size": float(len(items))} + if self.fixed_result is not None: + return [self.fixed_result for _ in items] + return [ + VADResult( + intervals=[VADInterval(start=0.0, end=1.0)], + model_id=self.model_id, + extras={"duration_s": 5.0}, + ) + for _ in items + ] + + +_ADAPTER_TARGET = f"{__name__}._FakeVADAdapter" + + +def _task(**data: Any) -> AudioTask: # noqa: ANN401 + base = {"resampled_audio_filepath": "/tmp/x.wav", "duration": 5.0} + base.update(data) + return AudioTask(data=base) + + +class TestVADStageConstruction: + def test_requires_adapter_target(self) -> None: + with pytest.raises(ValueError, match="adapter_target is required"): + VADStage() + + def test_defaults(self) -> None: + s = VADStage(adapter_target=_ADAPTER_TARGET) + assert s.audio_filepath_key == "resampled_audio_filepath" + assert s.segments_key == "vad_segments" + + def test_inputs_outputs(self) -> None: + s = VADStage(adapter_target=_ADAPTER_TARGET) + assert s.inputs() == ([], ["resampled_audio_filepath"]) + assert s.outputs() == ([], ["resampled_audio_filepath", "vad_segments"]) + + +class TestVADStageLifecycle: + def test_setup_instantiates_adapter(self) -> None: + s = VADStage(adapter_target=_ADAPTER_TARGET, model_id="m", revision="r") + s.setup() + assert isinstance(s._adapter, _FakeVADAdapter) + assert s._adapter.model_id == "m" + assert s._adapter.revision == "r" + assert s._adapter.setup_called == 1 + + def test_setup_forwards_device(self) -> None: + s = VADStage(adapter_target=_ADAPTER_TARGET) + s.setup() + assert s._adapter.device == "cuda" + + def test_teardown_clears_adapter(self) -> None: + s = VADStage(adapter_target=_ADAPTER_TARGET) + s.setup() + adapter = s._adapter + s.teardown() + assert s._adapter is None + assert adapter.teardown_called == 1 + + def test_setup_on_node_calls_prefetch(self) -> None: + with patch.object(_FakeVADAdapter, "prefetch_weights") as mock_pf: + s = VADStage(adapter_target=_ADAPTER_TARGET, model_id="m") + s.setup_on_node() + mock_pf.assert_called_once_with("m", None) + + def test_setup_on_node_swallows_when_disabled(self) -> None: + with patch.object(_FakeVADAdapter, "prefetch_weights", side_effect=RuntimeError("boom")): + s = VADStage(adapter_target=_ADAPTER_TARGET, prefetch_fail_on_error=False) + s.setup_on_node() # must not raise + + def test_xenna_stage_spec(self) -> None: + s = VADStage(adapter_target=_ADAPTER_TARGET, xenna_num_workers=3) + assert s.xenna_stage_spec() == {"num_workers": 3} + + +class TestVADStageProcess: + def test_process_requires_setup(self) -> None: + s = VADStage(adapter_target=_ADAPTER_TARGET) + with pytest.raises(RuntimeError, match="setup\\(\\) was not called"): + s.process(_task()) + + def test_process_requires_audio_filepath(self) -> None: + s = VADStage(adapter_target=_ADAPTER_TARGET) + s.setup() + with pytest.raises(ValueError, match="Missing key 'resampled_audio_filepath'"): + s.process(AudioTask(data={"duration": 5.0})) + + def test_process_writes_canonical_intervals(self) -> None: + s = VADStage(adapter_target=_ADAPTER_TARGET) + s.setup() + s._adapter.fixed_result = VADResult( + intervals=[ + VADInterval(start=0.0, end=1.5), + VADInterval(start=2.0, end=3.5), + ], + model_id="x", + extras={"duration_s": 5.0}, + ) + out = s.process(_task()) + assert out.data["vad_segments"] == [ + {"start": 0.0, "end": 1.5}, + {"start": 2.0, "end": 3.5}, + ] + + def test_process_empty_intervals_skipped_short(self) -> None: + s = VADStage(adapter_target=_ADAPTER_TARGET) + s.setup() + s._adapter.fixed_result = VADResult(intervals=[], model_id="x", extras={"duration_s": 0.2}) + out = s.process(_task(duration=0.2)) + assert out.data["vad_segments"] == [] + + def test_process_forwards_item_dict_to_adapter(self) -> None: + s = VADStage( + adapter_target=_ADAPTER_TARGET, + waveform_key="waveform", + sample_rate_key="sample_rate", + ) + s.setup() + s.process(_task(waveform=[1.0], sample_rate=16000)) + assert s._adapter.last_batch is not None + item = s._adapter.last_batch[0] + assert item["audio_filepath"] == "/tmp/x.wav" + assert item["duration"] == 5.0 + assert item["waveform"] == [1.0] + assert item["sample_rate"] == 16000 + + +class TestVADStageMetrics: + def test_log_metrics_includes_adapter_aliases(self) -> None: + s = VADStage(adapter_target=_ADAPTER_TARGET) + s.setup() + observed: list[dict[str, float]] = [] + s._log_metrics = observed.append # type: ignore[assignment] + s._adapter.fixed_result = VADResult( + intervals=[VADInterval(start=0.0, end=1.0)], + model_id="x", + extras={"duration_s": 5.0}, + ) + s.process(_task()) + assert observed + m = observed[-1] + assert m["vad_segments_detected"] == 1.0 + assert m["audio_duration"] == 5.0 + assert m["model_batch_size"] == 1.0 + assert "process_time" in m diff --git a/tests/stages/audio/inference/vad/test_whisperx_vad.py b/tests/stages/audio/inference/vad/test_whisperx_vad.py deleted file mode 100644 index 348db8bfb3..0000000000 --- a/tests/stages/audio/inference/vad/test_whisperx_vad.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pathlib import Path - -import pytest - -from nemo_curator.stages.audio.inference.vad.whisperx_vad import WhisperXVADStage -from nemo_curator.stages.resources import Resources -from nemo_curator.tasks import AudioTask - - -class TestWhisperXVADStage: - @pytest.mark.gpu - def test_process(self, wav_filepath: Path) -> None: - stage = WhisperXVADStage( - min_length=0.5, - max_length=40.0, - segments_key="vad_segments", - resources=Resources(gpus=1), - ) - stage.setup() - - entry = { - "resampled_audio_filepath": str(wav_filepath), - "duration": 60.0, - } - task = AudioTask(data=entry) - result = stage.process(task) - out = result.data - assert "vad_segments" in out - assert isinstance(out["vad_segments"], list) - assert len(out["vad_segments"]) == 2 From b9f3c8b4b207c26e90f3f522b269e34102eb3025 Mon Sep 17 00:00:00 2001 From: "aaftaabv@gmail.com" Date: Fri, 29 May 2026 15:26:02 +0530 Subject: [PATCH 3/6] audio/alignment: stage-adapter split per SDP-V2 design Bring the SDP-V2 design doc (sec 13 forced alignment) stage-adapter split into the audio tagging pipeline. What moves ---------- * New nemo_curator/adapters/alignment/ package with: - base.py - WordAlignment + AlignmentResult dataclasses + ForcedAlignmentAdapter Protocol (model_id / revision / setup / teardown / prefetch_weights classmethod / align_batch). - nemo_asr_align.py - NeMoASRAlignAdapter that re-houses the NeMo-specific code path from the deleted NeMoASRAlignerStage body: ASRModel.from_pretrained / restore_from, FastConformer config (change_attention_model / change_subsampling_conv_chunking_factor), decoder config (CTC vs RNNT, preserve_alignments, preserve_word_confidence, compute_timestamps), _override_cfg setup, transcribe() with one-by-one retry fallback, get_alignments_text time-stride math (FastConformer 8x window_stride vs default 4x, RNNT -0.08s offset), and U+2047 \"?\" glyph strip on the joined text. * New nemo_curator/stages/audio/inference/alignment/stage.py with ForcedAlignmentStage that owns Curator-side glue only: task.data plumbing, split_filepaths fan-out + scatter, segment-mode in-memory audio cut (the _prepare_segment_batch_with_metadata helper, moved from BaseASRProcessorStage), per-segment time-offset adjustment, batch homogeneity guarantee, metric logging. YAML shape (Tier-1 / Tier-2 split) ---------------------------------- - _target_: nemo_curator.stages.audio.inference.alignment.ForcedAlignmentStage name: ASRAlignment adapter_target: nemo_curator.adapters.alignment.NeMoASRAlignAdapter model_id: nvidia/parakeet-tdt_ctc-1.1b batch_size: 32 adapter_kwargs: is_fastconformer: true decoder_type: rnnt What is deleted --------------- * nemo_curator/stages/audio/tagging/inference/nemo_asr_align.py (NeMoASRAlignerStage + BaseASRProcessorStage) * nemo_curator/stages/audio/tagging/inference/ (empty after removal) * tests/stages/audio/tagging/inference/test_base_asr_processor.py * tests/stages/audio/tagging/inference/test_nemo_asr_align.py * tests/stages/audio/tagging/inference/ (empty after removal) Call-site migrations -------------------- * tutorials/audio/tagging/tts_pipeline.yaml * tests/stages/audio/tagging/e2e/configs/tts_pipeline.yaml * tests/stages/audio/tagging/e2e/test_tts_e2e.py (comment) * benchmarking/scripts/audio_tagging_benchmark.py * nemo_curator/stages/audio/tagging/__init__.py (lazy-import map - drops NeMoASRAlignerStage + BaseASRProcessorStage, adds ForcedAlignmentStage + NeMoASRAlignAdapter) * nemo_curator/stages/audio/tagging/split.py (SplitASRAlignJoinStage.decompose() now wires ForcedAlignmentStage with NeMoASRAlignAdapter; public API of SplitASRAlignJoinStage is unchanged.) * tutorials/audio/tagging/README.md (table row) New tests --------- * tests/stages/audio/inference/alignment/test_forced_alignment_stage.py - 15 CPU tests against a fake adapter (construction, lifecycle, full-audio mode scatter into split_metadata, top-level fallback when no split_metadata, sentinel non-list split_filepaths, segment-only mode with mocked torchaudio.load, time-offset adjustment, eligible-segment min_len filter, metrics). * tests/adapters/alignment/test_nemo_asr_align_adapter.py - NeMoASRAlignAdapter construction + validation (rejects unknown decoder_type / timestamp_type), prefetch_weights happy + failure paths, align_batch path-mode and segment-mode dispatch, transcribe-tuple unwrap, batch-failure one-by-one retry in path-mode (succeed) vs segment-mode (raise), get_alignments_text CTC vs RNNT time-stride math, U+2047 strip, compute_timestamps False short-circuit. Behaviour preservation ---------------------- * FastConformer time_stride = 8 * window_stride (default 4x) - byte-for-byte preserved. * RNNT start/end = max(0, offset * stride - 0.08) - preserved. * Confidence rounding to 4 decimals + start/end rounding to 3 - preserved. * One-by-one retry fallback when batch transcribe fails in path-mode - preserved. * SplitASRAlignJoinStage external dataclass fields and public API unchanged; only its internal decompose() changes. Follow-ups (out of scope; same posture as PR1967) ------------------------------------------------- * NeMoNFAAdapter (the SDP-V2 doc \"canonical\" alignment adapter - NFA, not transcribe(timestamps=True)) - the stage is now shaped for it; it ships in a separate commit. * WhisperXAlignmentAdapter - same. Signed-off-by: aaftaabv@gmail.com --- .../scripts/audio_tagging_benchmark.py | 14 +- nemo_curator/adapters/alignment/__init__.py | 41 ++ nemo_curator/adapters/alignment/base.py | 144 ++++++ .../adapters/alignment/nemo_asr_align.py | 292 ++++++++++++ .../audio/inference/alignment/__init__.py | 19 + .../stages/audio/inference/alignment/stage.py | 357 ++++++++++++++ nemo_curator/stages/audio/tagging/__init__.py | 5 +- .../audio/tagging/inference/nemo_asr_align.py | 449 ------------------ nemo_curator/stages/audio/tagging/split.py | 36 +- .../adapters/alignment}/__init__.py | 0 .../alignment/test_nemo_asr_align_adapter.py | 241 ++++++++++ .../alignment}/__init__.py | 0 .../alignment/test_forced_alignment_stage.py | 280 +++++++++++ .../tagging/e2e/configs/tts_pipeline.yaml | 15 +- .../stages/audio/tagging/e2e/test_tts_e2e.py | 2 +- .../inference/test_base_asr_processor.py | 96 ---- .../tagging/inference/test_nemo_asr_align.py | 58 --- tutorials/audio/tagging/README.md | 2 +- tutorials/audio/tagging/tts_pipeline.yaml | 11 +- 19 files changed, 1422 insertions(+), 640 deletions(-) create mode 100644 nemo_curator/adapters/alignment/__init__.py create mode 100644 nemo_curator/adapters/alignment/base.py create mode 100644 nemo_curator/adapters/alignment/nemo_asr_align.py create mode 100644 nemo_curator/stages/audio/inference/alignment/__init__.py create mode 100644 nemo_curator/stages/audio/inference/alignment/stage.py delete mode 100644 nemo_curator/stages/audio/tagging/inference/nemo_asr_align.py rename {nemo_curator/stages/audio/tagging/inference => tests/adapters/alignment}/__init__.py (100%) create mode 100644 tests/adapters/alignment/test_nemo_asr_align_adapter.py rename tests/stages/audio/{tagging/inference => inference/alignment}/__init__.py (100%) create mode 100644 tests/stages/audio/inference/alignment/test_forced_alignment_stage.py delete mode 100644 tests/stages/audio/tagging/inference/test_base_asr_processor.py delete mode 100644 tests/stages/audio/tagging/inference/test_nemo_asr_align.py diff --git a/benchmarking/scripts/audio_tagging_benchmark.py b/benchmarking/scripts/audio_tagging_benchmark.py index d282bcfc83..e0c2268e0d 100644 --- a/benchmarking/scripts/audio_tagging_benchmark.py +++ b/benchmarking/scripts/audio_tagging_benchmark.py @@ -32,7 +32,7 @@ from nemo_curator.pipeline import Pipeline from nemo_curator.stages.audio.common import ManifestReader, ManifestWriterStage from nemo_curator.stages.audio.inference.speaker_diarization import DiarizationStage -from nemo_curator.stages.audio.tagging.inference.nemo_asr_align import NeMoASRAlignerStage +from nemo_curator.stages.audio.inference.alignment import ForcedAlignmentStage from nemo_curator.stages.audio.tagging.merge_alignment_diarization import MergeAlignmentDiarizationStage from nemo_curator.stages.audio.tagging.resample_audio import ResampleAudioStage from nemo_curator.stages.audio.tagging.split import JoinSplitAudioMetadataStage, SplitLongAudioStage @@ -108,13 +108,17 @@ def run_audio_tagging_benchmark( # noqa: PLR0913 ).with_(resources=Resources(cpus=cpus)) ) - # ASR forced alignment (NeMo FastConformer) + # ASR forced alignment (ForcedAlignmentStage + NeMoASRAlignAdapter) pipeline.add_stage( - NeMoASRAlignerStage( + ForcedAlignmentStage( name="ASRAlignment", - is_fastconformer=True, - decoder_type="rnnt", + adapter_target="nemo_curator.adapters.alignment.NeMoASRAlignAdapter", + model_id="nvidia/parakeet-tdt_ctc-1.1b", batch_size=asr_batch_size, + adapter_kwargs={ + "is_fastconformer": True, + "decoder_type": "rnnt", + }, ).with_(resources=Resources(cpus=cpus, gpus=0.45)) ) diff --git a/nemo_curator/adapters/alignment/__init__.py b/nemo_curator/adapters/alignment/__init__.py new file mode 100644 index 0000000000..09f672f58c --- /dev/null +++ b/nemo_curator/adapters/alignment/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Forced-alignment adapter family for the SDP-V2 stage-adapter split. + +Public surface (the only symbols the stage imports): + +* :class:`ForcedAlignmentAdapter` - structural protocol every alignment + adapter implements. +* :class:`AlignmentResult` - canonical per-utterance result dataclass. +* :class:`WordAlignment` - canonical per-word dataclass. + +Concrete adapters live in their own modules (e.g. ``nemo_asr_align.py``, +``nemo_nfa.py``, ``whisperx_alignment.py``) and are resolved at runtime +by their fully-qualified class path in YAML's ``adapter_target`` field. +""" + +from nemo_curator.adapters.alignment.base import ( + AlignmentResult, + ForcedAlignmentAdapter, + WordAlignment, +) +from nemo_curator.adapters.alignment.nemo_asr_align import NeMoASRAlignAdapter + +__all__ = [ + "AlignmentResult", + "ForcedAlignmentAdapter", + "NeMoASRAlignAdapter", + "WordAlignment", +] diff --git a/nemo_curator/adapters/alignment/base.py b/nemo_curator/adapters/alignment/base.py new file mode 100644 index 0000000000..73b27a9d82 --- /dev/null +++ b/nemo_curator/adapters/alignment/base.py @@ -0,0 +1,144 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Stage-adapter contract for forced alignment (SDP-V2 design doc §13). + +Mirrors the ASR / diarization / VAD contract pattern: + +* :class:`~nemo_curator.stages.audio.inference.alignment.ForcedAlignmentStage` + owns Curator-side glue (task.data reads, split-filepath fan-out + scatter, + segment cut, time-offset adjustment, metric logging). +* :class:`ForcedAlignmentAdapter` owns the model-side library call + (weight prefetch, model setup, decoder configuration, the actual + ``transcribe(...)`` invocation, hypothesis-to-word-alignment + conversion) and packs results into the canonical + :class:`AlignmentResult` shape. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Protocol, runtime_checkable + + +@dataclass +class WordAlignment: + """Canonical per-word alignment dataclass. + + Attributes: + word: The aligned word (or character, when the adapter uses + char-level timestamps). + start: Word start time in seconds (clip / segment coordinates, + see :class:`AlignmentResult`). + end: Word end time in seconds. + confidence: Optional adapter-supplied per-word confidence + score in ``[0, 1]``. ``None`` when the adapter doesn't + surface one. + """ + + word: str + start: float + end: float + confidence: float | None = None + + +@dataclass +class AlignmentResult: + """Canonical per-input alignment adapter output. + + Attributes: + alignments: One :class:`WordAlignment` per emitted word + (or char). The stage applies any necessary time-offset + shift before writing this onto ``task.data``. Empty list + when the adapter could not process the input. + text: Concatenated transcription text. The stage writes this + onto ``task.data[text_key]``. + model_id: The actual model identifier the adapter ran (mirrors + the stage's ``model_id`` field; populated by the adapter). + extras: Adapter-specific scalar / structured diagnostics that + do not fit the canonical shape. Stage never reads inside + this dict. + """ + + alignments: list[WordAlignment] + text: str = "" + model_id: str = "" + extras: dict[str, Any] = field(default_factory=dict) + + +@runtime_checkable +class ForcedAlignmentAdapter(Protocol): + """Structural protocol every forced-alignment adapter must implement. + + Constructor contract: adapters are constructed by the stage as + ``cls(model_id=..., revision=..., **adapter_kwargs)``. Tier-2 knobs + are adapter-specific (decoder type, FastConformer toggle, batch + sizes, ...). + + Per-batch contract: :meth:`align_batch` receives a list of dicts + (Tier-3 per-task knobs unpacked from ``task.data`` by the stage) + and returns one :class:`AlignmentResult` per input, in the same + order. + + Expected per-item keys (the stage populates these; the adapter + reads whichever is present): + + * ``audio_path`` (``str | None``): Path to a decodable audio file. + Used for full-audio / split-filepath inference. + * ``audio_segment`` (``numpy.ndarray | None``): In-memory mono + audio array, one segment cut. Used for segment-only inference. + * ``sample_rate`` (``int | None``): Sample rate of + ``audio_segment`` (only meaningful in segment-mode). + * ``task_id`` (``str | None``): Carried through for diagnostics. + + A batch must be homogeneous - either all items have ``audio_path`` + OR all have ``audio_segment``; the stage guarantees this. + + Attributes: + model_id: Identifier of the underlying model checkpoint. + last_metrics: Scalar metrics from the last :meth:`align_batch` + call. The stage merges these into ``_log_metrics`` output + under ``model_`` aliases. + """ + + model_id: str + last_metrics: dict[str, float] + + @classmethod + def prefetch_weights(cls, model_id: str, revision: str | None = None) -> None: + """Download weights to local cache without allocating a GPU.""" + ... + + def setup(self) -> None: + """Load the model into the worker's process.""" + ... + + def teardown(self) -> None: + """Release GPU memory and worker-local state.""" + ... + + def align_batch(self, items: list[dict[str, Any]]) -> list[AlignmentResult]: + """Run forced alignment on a batch of per-task dicts. + + Args: + items: One dict per task with the keys documented on the + class docstring. Length matches the batch size. + + Returns: + One :class:`AlignmentResult` per input, in the same order. + Items the adapter could not process must still appear with + empty ``alignments`` and ``text=""`` so the stage can + scatter results 1:1. + """ + ... diff --git a/nemo_curator/adapters/alignment/nemo_asr_align.py b/nemo_curator/adapters/alignment/nemo_asr_align.py new file mode 100644 index 0000000000..0b5d044e07 --- /dev/null +++ b/nemo_curator/adapters/alignment/nemo_asr_align.py @@ -0,0 +1,292 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NeMo ASR forced-alignment adapter. + +Implements :class:`~nemo_curator.adapters.alignment.ForcedAlignmentAdapter` +on top of NeMo's ``ASRModel.transcribe(timestamps=True)`` path +(FastConformer + CTC / RNNT decoders). + +Logic moved verbatim from the pre-split +``nemo_curator.stages.audio.tagging.inference.nemo_asr_align.NeMoASRAlignerStage`` +body so per-word timestamps, confidences, RNNT 0.08 s offset, ``⁇`` +strip and one-by-one retry fallback are byte-for-byte preserved. +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from typing import Any + +import nemo.collections.asr as nemo_asr +import torch +from loguru import logger +from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig +from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig + +from nemo_curator.adapters.alignment.base import AlignmentResult, WordAlignment + + +@dataclass +class NeMoASRAlignAdapter: + """NeMo-backed implementation of :class:`ForcedAlignmentAdapter`. + + Tier-2 knobs (set via ``adapter_kwargs`` in YAML): + + Attributes: + model_id: Pretrained model identifier passed to + ``ASRModel.from_pretrained`` (e.g. + ``"nvidia/parakeet-tdt_ctc-1.1b"``). Ignored when + ``model_path`` is set. + revision: Accepted for protocol uniformity. NeMo's + ``from_pretrained`` does not currently accept a revision + argument; passed through to ``extras`` for diagnostics. + model_path: Optional local ``.nemo`` checkpoint path. When set + overrides ``model_id``. + device: ``"cuda"`` or ``"cpu"``; passed by the stage. + is_fastconformer: Whether the model encoder is FastConformer + (triggers ``change_attention_model`` / + ``change_subsampling_conv_chunking_factor`` calls and + adjusts the per-token time stride). + decoder_type: ``"ctc"`` or ``"rnnt"``. + timestamp_type: ``"word"`` or ``"char"``. + transcribe_batch_size: Batch size for the NeMo + ``transcribe`` call. + num_workers: Number of data-loading workers. + compute_timestamps: When False, returns alignments=[] (text + only). Pre-split behaviour was True. + disable_word_confidence: When True, the adapter does NOT + populate ``WordAlignment.confidence``. + """ + + # ---- Required protocol fields ---- + model_id: str = "nvidia/parakeet-tdt_ctc-1.1b" + revision: str | None = None + + # ---- Adapter-specific knobs ---- + model_path: str | None = None + device: str = "cuda" + is_fastconformer: bool = True + decoder_type: str = "rnnt" + timestamp_type: str = "word" + transcribe_batch_size: int = 32 + num_workers: int = 10 + compute_timestamps: bool = True + disable_word_confidence: bool = False + + # ---- Internal state ---- + _asr_model: Any = field(default=None, repr=False) + _override_cfg: Any = field(default=None, repr=False) + last_metrics: dict[str, float] = field(default_factory=dict) + + # ------------------------------------------------------------------ + # Adapter contract + # ------------------------------------------------------------------ + + def __post_init__(self) -> None: + if self.decoder_type not in ("ctc", "rnnt"): + msg = f"decoder_type must be 'ctc' or 'rnnt', got {self.decoder_type}" + raise ValueError(msg) + if self.timestamp_type not in ("word", "char"): + msg = f"timestamp_type must be 'word' or 'char', got {self.timestamp_type}" + raise ValueError(msg) + + @classmethod + def prefetch_weights(cls, model_id: str, revision: str | None = None) -> None: + """Download model weights without instantiating the GPU runtime. + + ``ASRModel.from_pretrained(return_model_file=True)`` is the + public entry point that triggers the HF / NGC download. + """ + del revision # NeMo's from_pretrained doesn't take a revision arg today. + if not model_id: + return + try: + nemo_asr.models.ASRModel.from_pretrained(model_name=model_id, return_model_file=True) + except Exception as exc: # noqa: BLE001 + msg = f"NeMoASRAlignAdapter: failed to download model {model_id}" + raise RuntimeError(msg) from exc + + def setup(self) -> None: + if self._asr_model is None: + if self.model_path: + self._asr_model = nemo_asr.models.ASRModel.restore_from(restore_path=self.model_path) + else: + self._asr_model = nemo_asr.models.ASRModel.from_pretrained( + model_name=self.model_id, + map_location=torch.device(self.device), + ) + + self._asr_model.to(self.device) + self._asr_model.eval() + + if self.is_fastconformer: + self._asr_model.change_attention_model( + self_attention_model="rel_pos_local_attn", att_context_size=[128, 128] + ) + self._asr_model.change_subsampling_conv_chunking_factor(1) + + decoding_cfg = CTCDecodingConfig() if self.decoder_type == "ctc" else RNNTDecodingConfig() + if self.decoder_type == "ctc": + decoding_cfg.strategy = "greedy_batch" + else: + decoding_cfg.rnnt_timestamp_type = self.timestamp_type + + decoding_cfg.preserve_alignments = self.compute_timestamps + decoding_cfg.confidence_cfg.preserve_word_confidence = not self.disable_word_confidence + decoding_cfg.compute_timestamps = self.compute_timestamps + decoding_cfg.greedy.compute_timestamps = self.compute_timestamps + + self._asr_model.change_decoding_strategy(decoding_cfg=decoding_cfg) + + self._override_cfg = self._asr_model.get_transcribe_config() + self._override_cfg.batch_size = self.transcribe_batch_size + self._override_cfg.num_workers = self.num_workers + self._override_cfg.return_hypotheses = True + self._override_cfg.timestamps = self.compute_timestamps + + logger.info("NeMoASRAlignAdapter ready on {} (model={})", self.device, self.model_id) + + def teardown(self) -> None: + self._asr_model = None + self._override_cfg = None + + def align_batch(self, items: list[dict[str, Any]]) -> list[AlignmentResult]: + if not items: + return [] + if self._asr_model is None: + msg = "NeMoASRAlignAdapter.setup() must be called before align_batch()" + raise RuntimeError(msg) + + t0 = time.perf_counter() + # Classify the batch as path-mode or segment-mode. The stage + # guarantees homogeneity per call; we double-check defensively. + first = items[0] + if first.get("audio_segment") is not None: + transcribe_inputs = [it["audio_segment"] for it in items] + mode = "segment" + else: + transcribe_inputs = [it.get("audio_path") for it in items] + mode = "path" + + hypotheses_list = self._transcribe(transcribe_inputs, mode=mode) + + results: list[AlignmentResult] = [] + for hyp in hypotheses_list: + if hyp is None: + results.append(AlignmentResult(alignments=[], text="", model_id=self.model_id)) + continue + alignments, text = self._get_alignments_text(hyp) + results.append( + AlignmentResult( + alignments=[ + WordAlignment( + word=w["word"], + start=w["start"], + end=w["end"], + confidence=w.get("confidence"), + ) + for w in alignments + ], + text=text, + model_id=self.model_id, + ) + ) + + self.last_metrics = { + "batch_size": float(len(items)), + "align_time_s_total": float(time.perf_counter() - t0), + "mode_is_segment": 1.0 if mode == "segment" else 0.0, + } + return results + + # ------------------------------------------------------------------ + # Internal helpers (moved verbatim from NeMoASRAlignerStage) + # ------------------------------------------------------------------ + + def _transcribe(self, inputs: list[Any], mode: str) -> list[Any]: + """Run ``ASRModel.transcribe`` with a one-by-one retry fallback.""" + try: + with torch.no_grad(): + hypotheses_list = self._asr_model.transcribe( + inputs, override_config=self._override_cfg + ) + if isinstance(hypotheses_list, tuple) and len(hypotheses_list) == 2: # noqa: PLR2004 + hypotheses_list = hypotheses_list[0] + return list(hypotheses_list) + except Exception as exc: # noqa: BLE001 + if mode == "segment": + msg = f"NeMoASRAlignAdapter: batch transcribe failed in segment mode: {exc}" + raise ValueError(msg) from exc + logger.error( + "NeMoASRAlignAdapter: batch transcribe failed ({}), retrying one-by-one", exc + ) + out: list[Any] = [] + for single_input in inputs: + try: + with torch.no_grad(): + hyp = self._asr_model.transcribe( + [single_input], override_config=self._override_cfg + ) + if isinstance(hyp, tuple) and len(hyp) == 2: # noqa: PLR2004 + hyp = hyp[0] + out.append(hyp[0] if hyp else None) + except Exception as exc2: # noqa: BLE001, PERF203 + logger.error("NeMoASRAlignAdapter: per-item transcribe failed for {}: {}", single_input, exc2) + out.append(None) + return out + + def _get_alignments_text(self, hypothesis: Any) -> tuple[list[dict[str, Any]], str]: + """Extract word alignments + text from a single NeMo Hypothesis.""" + if not self.compute_timestamps: + return [], hypothesis.text + + timestamp_dict = hypothesis.timestamp + + if self.is_fastconformer: + time_stride = 8 * self._asr_model.cfg.preprocessor.window_stride + else: + time_stride = 4 * self._asr_model.cfg.preprocessor.window_stride + + word_timestamps = timestamp_dict[self.timestamp_type] + + alignments: list[dict[str, Any]] = [] + for i, stamp in enumerate(word_timestamps): + conf: float | None = None + if hypothesis.word_confidence is not None and i < len(hypothesis.word_confidence): + raw = hypothesis.word_confidence[i] + if isinstance(raw, torch.Tensor): + raw = raw.item() + conf = round(float(raw), 4) + + if self.decoder_type == "ctc": + start = stamp["start_offset"] * time_stride + end = stamp["end_offset"] * time_stride + else: + start = max(0, stamp["start_offset"] * time_stride - 0.08) + end = max(0, stamp["end_offset"] * time_stride - 0.08) + + word = stamp.get("word", stamp.get("char", "")) + alignments.append( + { + "word": word, + "start": round(start, 3), + "end": round(end, 3), + "confidence": conf, + } + ) + + text = " ".join(w["word"] for w in alignments).replace("⁇", "") + return alignments, text diff --git a/nemo_curator/stages/audio/inference/alignment/__init__.py b/nemo_curator/stages/audio/inference/alignment/__init__.py new file mode 100644 index 0000000000..8508a60243 --- /dev/null +++ b/nemo_curator/stages/audio/inference/alignment/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic forced-alignment stage (SDP-V2 stage-adapter split, §13).""" + +from nemo_curator.stages.audio.inference.alignment.stage import ForcedAlignmentStage + +__all__ = ["ForcedAlignmentStage"] diff --git a/nemo_curator/stages/audio/inference/alignment/stage.py b/nemo_curator/stages/audio/inference/alignment/stage.py new file mode 100644 index 0000000000..f04e9269b5 --- /dev/null +++ b/nemo_curator/stages/audio/inference/alignment/stage.py @@ -0,0 +1,357 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic forced-alignment Curator stage (SDP-V2 design doc §13). + +Implements the stage half of the SDP-V2 stage-adapter split for the +forced-alignment family. The stage owns Curator-side glue: + +* validates ``task.data`` against ``inputs()`` / ``outputs()``; +* in **full-audio** mode: collects each task's ``split_filepaths``, + flattens them into one homogeneous path-batch, dispatches a single + adapter ``align_batch`` call, then scatters the per-path + :class:`AlignmentResult` back onto the originating + ``split_metadata`` entry (or the task itself when no splits exist); +* in **segment-only** mode: cuts in-memory mono audio for each + segment that exceeds ``min_len`` via the + :meth:`_prepare_segment_batch_with_metadata` helper, dispatches a + single adapter call with the homogeneous segment-batch, then + scatters the per-segment results onto + ``task.data[segments_key][segment_idx]`` with the word timestamps + shifted into clip-coordinate space by ``segment["start"]``; +* writes ``text_key`` / ``words_key`` / ``alignment`` per the + pre-split tagging-pipeline convention; +* emits performance metrics in the shape ``perf_summary_merged.json`` + consumers already expect. + +The stage knows nothing about the underlying ASR model, decoder type, +or FastConformer specifics. The concrete adapter is resolved at +runtime from the YAML's ``adapter_target`` string via +``hydra.utils.get_class``. +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import hydra.utils +import torchaudio +from loguru import logger + +from nemo_curator.adapters.alignment.base import AlignmentResult, ForcedAlignmentAdapter +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.resources import Resources +from nemo_curator.tasks import AudioTask + +if TYPE_CHECKING: + from nemo_curator.backends.base import NodeInfo, WorkerMetadata + + +@dataclass +class ForcedAlignmentStage(ProcessingStage[AudioTask, AudioTask]): + """Forced-alignment Curator stage with pluggable adapter. + + Args: + adapter_target: Tier-1 swap surface. Fully-qualified class path + of the concrete + :class:`~nemo_curator.adapters.alignment.ForcedAlignmentAdapter` + implementation (e.g. + ``"nemo_curator.adapters.alignment.NeMoASRAlignAdapter"``). + Resolved at ``setup()`` time via ``hydra.utils.get_class``. + model_id: Tier-1. Model checkpoint identifier, forwarded both to + :meth:`ForcedAlignmentAdapter.prefetch_weights` and to the + adapter constructor. + revision: Tier-1. Optional model revision to pin. + text_key: Output key for transcription text (per-split or + per-task). + words_key: Output key for the segment-mode word alignment list. + alignment_key: Output key for the full-audio-mode alignment + list (matches the pre-split convention - the segment-mode + uses ``words_key`` to be consistent with the SDP convention). + segments_key: Input key for the per-task segments list used by + segment-only mode. + infer_segment_only: When True, the stage operates on the + ``segments_key`` segments list rather than on + ``split_filepaths`` / ``split_metadata``. + min_len: Minimum segment duration (seconds) that segment-mode + considers for inference. + max_len: Maximum segment duration (seconds) - currently + informational; matches pre-split semantics. + prefetch_fail_on_error: When False, ``setup_on_node`` warns + and defers weight prefetch to ``setup()``. + adapter_kwargs: Tier-2. Opaque dict forwarded to the adapter + constructor as ``**adapter_kwargs``. + resources / batch_size: Standard Curator stage knobs. + """ + + name: str = "ForcedAlignment" + + # ---- Tier 1: swap surface ---- + adapter_target: str = "" + model_id: str = "" + revision: str | None = None + + # ---- Tier 1: universal stage knobs ---- + text_key: str = "text" + words_key: str = "words" + alignment_key: str = "alignment" + segments_key: str = "segments" + infer_segment_only: bool = False + min_len: float = 1.0 + max_len: float = 40.0 + + prefetch_fail_on_error: bool = True + + # ---- Tier 2: opaque adapter knob bag ---- + adapter_kwargs: dict[str, Any] = field(default_factory=dict) + + # ---- Standard Curator stage knobs ---- + resources: Resources = field(default_factory=lambda: Resources(gpus=1)) + batch_size: int = 100 + + # ---- Internal state ---- + _adapter: ForcedAlignmentAdapter | None = field(default=None, init=False, repr=False) + + def __post_init__(self) -> None: + if not self.adapter_target: + msg = ( + "ForcedAlignmentStage.adapter_target is required - set it in YAML to a " + "fully-qualified adapter class path (e.g. " + "'nemo_curator.adapters.alignment.NeMoASRAlignAdapter')." + ) + raise ValueError(msg) + + # ------------------------------------------------------------------ + # I/O contract + # ------------------------------------------------------------------ + + def inputs(self) -> tuple[list[str], list[str]]: + return ["data"], ["duration", self.segments_key, "split_filepaths", "split_metadata"] + + def outputs(self) -> tuple[list[str], list[str]]: + return ["data"], ["duration", self.segments_key, "split_filepaths", "split_metadata"] + + @property + def _device(self) -> str: + return "cuda" if self.resources.requires_gpu else "cpu" + + # ------------------------------------------------------------------ + # Adapter lifecycle + # ------------------------------------------------------------------ + + def _adapter_class(self) -> type: + return hydra.utils.get_class(self.adapter_target) + + def setup_on_node( + self, + _node_info: NodeInfo | None = None, + _worker_metadata: WorkerMetadata | None = None, + ) -> None: + try: + prefetch_t0 = time.perf_counter() + self._adapter_class().prefetch_weights(self.model_id, self.revision) + logger.info( + "Forced-alignment weights cached on node for {} ({}) in {:.3f}s", + self.model_id, + self.adapter_target, + time.perf_counter() - prefetch_t0, + ) + except Exception as exc: # noqa: BLE001 + msg = f"ForcedAlignmentStage: prefetch_weights failed for {self.model_id}" + if self.prefetch_fail_on_error: + raise RuntimeError(msg) from exc + logger.warning("{}; setup() will retry: {}", msg, exc) + + def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None: + if self._adapter is None: + cls = self._adapter_class() + kwargs = dict(self.adapter_kwargs) + if self.model_id: + kwargs.setdefault("model_id", self.model_id) + kwargs.setdefault("revision", self.revision) + kwargs.setdefault("device", self._device) + self._adapter = cls(**kwargs) + self._adapter.setup() + logger.info("[{}] Forced-alignment adapter ready ({})", self.name, self.adapter_target) + + def teardown(self) -> None: + if self._adapter is not None: + self._adapter.teardown() + self._adapter = None + + # ------------------------------------------------------------------ + # Processing + # ------------------------------------------------------------------ + + def process(self, task: AudioTask) -> AudioTask: + results = self.process_batch([task]) + return results[0] if results else task + + def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: + if not tasks: + return [] + if self._adapter is None: + msg = "Adapter not initialized - setup() was not called" + raise RuntimeError(msg) + + t0 = time.perf_counter() + if self.infer_segment_only: + self._process_segments(tasks) + else: + self._process_full_audio(tasks) + self._log_metrics( + { + "process_time": time.perf_counter() - t0, + "entries_processed": float(len(tasks)), + **{ + f"model_{k}": float(v) + for k, v in (self._adapter.last_metrics or {}).items() + }, + } + ) + return tasks + + # ------------------------------------------------------------------ + # Full-audio path: fan-out split_filepaths, scatter back per split + # ------------------------------------------------------------------ + + def _process_full_audio(self, tasks: list[AudioTask]) -> None: + entries = [task.data for task in tasks] + all_paths: list[str] = [] + path_to_entry_and_split: list[tuple[int, int]] = [] + + for entry_idx, data in enumerate(entries): + split_filepaths = data.get("split_filepaths") + has_splits = isinstance(split_filepaths, list) and len(split_filepaths) > 0 + if not (has_splits or split_filepaths is None): + # Sentinel / skip case from pre-split semantics. + data[self.text_key] = "" + data[self.alignment_key] = [] + continue + if not split_filepaths: + logger.warning( + "[{}] Entry at index {} has no split_filepaths, skipping.", + self.name, + entry_idx, + ) + continue + for split_idx, path in enumerate(split_filepaths): + all_paths.append(path) + path_to_entry_and_split.append((entry_idx, split_idx)) + + if not all_paths: + return + + items = [{"audio_path": p} for p in all_paths] + results: list[AlignmentResult] = self._adapter.align_batch(items) + + for path_idx, result in enumerate(results): + if path_idx >= len(path_to_entry_and_split): + break + entry_idx, split_idx = path_to_entry_and_split[path_idx] + meta_entry = entries[entry_idx] + alignments = [ + { + "word": w.word, + "start": w.start, + "end": w.end, + "confidence": w.confidence, + } + for w in result.alignments + ] + split_metadata = meta_entry.get("split_metadata") + if split_metadata and split_idx < len(split_metadata): + split_metadata[split_idx][self.text_key] = result.text + split_metadata[split_idx][self.alignment_key] = alignments + else: + meta_entry[self.text_key] = result.text + meta_entry[self.alignment_key] = alignments + + # ------------------------------------------------------------------ + # Segment-only path + # ------------------------------------------------------------------ + + def _prepare_segment_batch_with_metadata( + self, + metadata_batch: list[dict[str, Any]], + *, + segments_key: str, + ) -> list[dict[str, Any]]: + """Cut per-segment in-memory mono audio + remember scatter coords. + + Mirrors the pre-split + ``BaseASRProcessorStage._prepare_segment_batch_with_metadata`` + with ``cut_audio_segments=True``. + """ + segment_metadata_list: list[dict[str, Any]] = [] + for metadata_idx, metadata in enumerate(metadata_batch): + audio_path = metadata.get("resampled_audio_filepath", metadata.get("audio_filepath")) + if not audio_path: + continue + audio, sr = torchaudio.load(audio_path) + for segment_idx, segment in enumerate(metadata.get(segments_key, [])): + duration = segment.get("end", 0) - segment.get("start", 0) + if duration >= self.min_len: + start = int(segment["start"] * sr) + end = int(segment["end"] * sr) + audio_segment = audio[:, start:end].squeeze(0) + if len(audio_segment) > 0: + segment_metadata_list.append( + { + "audio_segment": audio_segment.numpy(), + "sample_rate": int(sr), + "metadata_idx": metadata_idx, + "segment_idx": segment_idx, + } + ) + return segment_metadata_list + + def _process_segments(self, tasks: list[AudioTask]) -> None: + entries = [task.data for task in tasks] + if not entries: + return + + scatter_list = self._prepare_segment_batch_with_metadata( + entries, segments_key=self.segments_key + ) + if not scatter_list: + return + + # Adapter consumes only audio_segment + sample_rate; strip our + # bookkeeping fields before dispatch. + items = [ + {"audio_segment": s["audio_segment"], "sample_rate": s["sample_rate"]} + for s in scatter_list + ] + results: list[AlignmentResult] = self._adapter.align_batch(items) + + for scatter, result in zip(scatter_list, results, strict=True): + metadata_idx = scatter["metadata_idx"] + segment_idx = scatter["segment_idx"] + segment = entries[metadata_idx][self.segments_key][segment_idx] + segment[self.text_key] = result.text + if result.alignments: + seg_start = float(segment.get("start", 0.0)) + alignments = [ + { + "word": w.word, + "start": round(w.start + seg_start, 3), + "end": round(w.end + seg_start, 3), + "confidence": w.confidence, + } + for w in result.alignments + ] + segment[self.words_key] = alignments diff --git a/nemo_curator/stages/audio/tagging/__init__.py b/nemo_curator/stages/audio/tagging/__init__.py index 44457c2414..90e87fd7f3 100644 --- a/nemo_curator/stages/audio/tagging/__init__.py +++ b/nemo_curator/stages/audio/tagging/__init__.py @@ -35,14 +35,13 @@ "JoinSplitAudioMetadataStage": "nemo_curator.stages.audio.tagging.split", "SplitASRAlignJoinStage": "nemo_curator.stages.audio.tagging.split", "MergeAlignmentDiarizationStage": "nemo_curator.stages.audio.tagging.merge_alignment_diarization", - # --- Inference (tagging/inference/) --- - "BaseASRProcessorStage": "nemo_curator.stages.audio.tagging.inference.nemo_asr_align", - "NeMoASRAlignerStage": "nemo_curator.stages.audio.tagging.inference.nemo_asr_align", # --- Inference (stage-adapter split per SDP-V2 design) --- "DiarizationStage": "nemo_curator.stages.audio.inference.speaker_diarization", "PyAnnoteDiarizationAdapter": "nemo_curator.adapters.diarization", "VADStage": "nemo_curator.stages.audio.inference.vad", "WhisperXVADAdapter": "nemo_curator.adapters.vad", + "ForcedAlignmentStage": "nemo_curator.stages.audio.inference.alignment", + "NeMoASRAlignAdapter": "nemo_curator.adapters.alignment", } _cache: dict[str, Any] = {} diff --git a/nemo_curator/stages/audio/tagging/inference/nemo_asr_align.py b/nemo_curator/stages/audio/tagging/inference/nemo_asr_align.py deleted file mode 100644 index ebe68b86e5..0000000000 --- a/nemo_curator/stages/audio/tagging/inference/nemo_asr_align.py +++ /dev/null @@ -1,449 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -NeMo ASR Aligner Stage. - -Contains BaseASRProcessorStage (shared config and segment preparation) -and NeMoASRAlignerStage (forced alignment via NeMo FastConformer). - -These stages are tagging-pipeline-specific because they operate on -tagging manifest keys like ``split_filepaths``, ``split_metadata``, -and ``segments``. -""" - -import time -from dataclasses import dataclass, field -from typing import Any - -import nemo.collections.asr as nemo_asr -import torch -import torchaudio -from loguru import logger -from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig -from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig - -from nemo_curator.backends.base import NodeInfo, WorkerMetadata -from nemo_curator.stages.base import ProcessingStage -from nemo_curator.stages.resources import Resources -from nemo_curator.tasks import AudioTask - - -@dataclass -class BaseASRProcessorStage(ProcessingStage[AudioTask, AudioTask]): - """Base class for ASR stages with shared config and segment preparation. - - Provides common fields and _prepare_segment_batch_with_metadata for - segment-only inference. Subclasses must implement setup() and process(). - - Args: - min_len: Minimum length of audio segments to process (seconds). - max_len: Maximum length of audio segments to process (seconds). - num_workers: Number of workers for data loading. - split_batch_size: Max entries/paths per batch when chunking. - infer_segment_only: If True, process segments only; else full audio / meta-entries. - text_key: Key for predicted text in manifest. - words_key: Key for word alignments in manifest (same as SDP alignment_key). - compute_timestamps: Whether to compute word-level timestamps. - segments_key: Key for segments list in manifest. - """ - - # Length constraints - min_len: float = 1.0 - max_len: float = 40.0 - - # Processing parameters - batch_size: int = 32 - num_workers: int = 10 - split_batch_size: int = 5000 - infer_segment_only: bool = False - - # Output keys - text_key: str = "text" - words_key: str = "words" - - compute_timestamps: bool = True - segments_key: str = "segments" - - # Stage metadata (subclasses can override) - name: str = "BaseASRProcessor" - resources: Resources = field(default_factory=lambda: Resources(gpus=1)) - - @property - def _device(self) -> str: - """Derive device from resources configuration.""" - return "cuda" if self.resources.requires_gpu else "cpu" - - def _prepare_segment_batch_with_metadata( - self, - metadata_batch: list[dict], - cut_audio_segments: bool = False, - segments_key: str = "segments", - ) -> list[dict]: - """Prepare segment metadata for a batch. - - Collects segment metadata with indices for later processing. Mirrors - generic_sdp BaseASRProcessor._prepare_segment_batch_with_metadata. - - Args: - metadata_batch: List of metadata dicts, each with a segments list. - cut_audio_segments: If True, load audio and cut segments (numpy); - if False, only collect resampled_audio_filepath from segments. - segments_key: Key for the segments list in each metadata dict. - - Returns: - List of segment metadata dicts with metadata_idx, segment_idx, and - either "audio_segment" (numpy) or "resampled_audio_filepath". - """ - segment_metadata_list: list[dict] = [] - - if cut_audio_segments: - for metadata_idx, metadata in enumerate(metadata_batch): - audio_path = metadata.get("resampled_audio_filepath", metadata.get("audio_filepath")) - if not audio_path: - continue - audio, sr = torchaudio.load(audio_path) - for segment_idx, segment in enumerate(metadata.get(segments_key, [])): - duration = segment.get("end", 0) - segment.get("start", 0) - if duration >= self.min_len: - start = int(segment["start"] * sr) - end = int(segment["end"] * sr) - audio_segment = audio[:, start:end].squeeze(0) - if len(audio_segment) > 0: - segment_metadata_list.append( - { - "audio_segment": audio_segment.numpy(), - "metadata_idx": metadata_idx, - "segment_idx": segment_idx, - } - ) - else: - for metadata_idx, metadata in enumerate(metadata_batch): - for segment_idx, segment in enumerate(metadata.get(segments_key, [])): - if "resampled_audio_filepath" in segment: - segment_metadata_list.append( - { - "resampled_audio_filepath": segment["resampled_audio_filepath"], - "metadata_idx": metadata_idx, - "segment_idx": segment_idx, - } - ) - - return segment_metadata_list - - -@dataclass -class NeMoASRAlignerStage(BaseASRProcessorStage): - """ - Stage that aligns text and audio using NeMo ASR models. - - Uses a pre-trained ASR model to transcribe audio files and generate - word-level alignments with timestamps. Supports both CTC and RNNT decoders and - can process either full audio files or just specific segments. - - Args: - model_name (str): Name of pretrained model to use. Defaults to "nvidia/parakeet-tdt_ctc-1.1b" - model_path (str, Optional): Path to local model file. If provided, overrides model_name - is_fastconformer (bool): Whether model's encoder is FastConformer - decoder_type (str): Type of decoder ('ctc' or 'rnnt'). Defaults to "rnnt" - transcribe_batch_size (int): Batch size for transcribing. Defaults to 32 - timestamp_type (str): Type of timestamp ('word' or 'char') - disable_word_confidence (bool): Whether to disable word confidence score computation - """ - - # Model configuration - model_name: str = "nvidia/parakeet-tdt_ctc-1.1b" - model_path: str | None = None - - # Length constraints - min_len: float = 1.0 - max_len: float = 40.0 - - # Model settings - is_fastconformer: bool = True - decoder_type: str = "rnnt" - - # Processing parameters - transcribe_batch_size: int = 32 - num_workers: int = 10 - batch_size: int = 100 - - # Timestamp settings - compute_timestamps: bool = True - timestamp_type: str = "word" - - # Processing mode - infer_segment_only: bool = False - - # input keys - segments_key: str = "segments" - - # Output keys - text_key: str = "text" - words_key: str = "words" - disable_word_confidence: bool = False - - # Stage metadata - name: str = "NeMoASRAligner" - _asr_model: Any = field(default=None, repr=False) - _override_cfg: Any = field(default=None, repr=False) - - def __post_init__(self) -> None: - """Validate config.""" - if self.decoder_type not in ["ctc", "rnnt"]: - msg = f"decoder_type must be 'ctc' or 'rnnt', got {self.decoder_type}" - raise ValueError(msg) - - def load_model(self) -> None: - if self.model_path: - self._asr_model = nemo_asr.models.ASRModel.restore_from(restore_path=self.model_path) - else: - self._asr_model = nemo_asr.models.ASRModel.from_pretrained( - model_name=self.model_name, map_location=torch.device(self._device) - ) - - def setup_on_node( - self, _node_info: NodeInfo | None = None, _worker_metadata: WorkerMetadata | None = None - ) -> None: - """Download model weights without loading into memory (called once per node).""" - if self._asr_model is None: - if self.model_path: - return - try: - nemo_asr.models.ASRModel.from_pretrained(model_name=self.model_name, return_model_file=True) - except Exception as e: - msg = f"[{self.name}] Failed to download model {self.model_name}" - raise RuntimeError(msg) from e - - def setup(self, _: WorkerMetadata | None = None) -> None: - """Load model to device and configure decoding (called per replica).""" - if self._asr_model is None: - self.load_model() - - self._asr_model.to(self._device) - self._asr_model.eval() - - if self.is_fastconformer: - self._asr_model.change_attention_model( - self_attention_model="rel_pos_local_attn", att_context_size=[128, 128] - ) - self._asr_model.change_subsampling_conv_chunking_factor(1) - - decoding_cfg = CTCDecodingConfig() if self.decoder_type == "ctc" else RNNTDecodingConfig() - - if self.decoder_type == "ctc": - decoding_cfg.strategy = "greedy_batch" - else: - decoding_cfg.rnnt_timestamp_type = self.timestamp_type - - decoding_cfg.preserve_alignments = self.compute_timestamps - decoding_cfg.confidence_cfg.preserve_word_confidence = not self.disable_word_confidence - decoding_cfg.compute_timestamps = self.compute_timestamps - decoding_cfg.greedy.compute_timestamps = self.compute_timestamps - - self._asr_model.change_decoding_strategy(decoding_cfg=decoding_cfg) - - self._override_cfg = self._asr_model.get_transcribe_config() - self._override_cfg.batch_size = self.transcribe_batch_size - self._override_cfg.num_workers = self.num_workers - self._override_cfg.return_hypotheses = True - self._override_cfg.timestamps = self.compute_timestamps - - logger.info(f"[{self.name}] Initialized ASR model on {self._device}") - - def inputs(self) -> tuple[list[str], list[str]]: - return ["data"], ["duration", self.segments_key, "split_filepaths", "split_metadata"] - - def outputs(self) -> tuple[list[str], list[str]]: - return ["data"], ["duration", self.segments_key, "split_filepaths", "split_metadata"] - - def get_alignments_text(self, hypotheses: Any) -> tuple[list, str]: # noqa: ANN401 - """Extract word alignments and text from model hypotheses.""" - if not self.compute_timestamps: - return [], hypotheses.text - - timestamp_dict = hypotheses.timestamp - - if self.is_fastconformer: - time_stride = 8 * self._asr_model.cfg.preprocessor.window_stride - else: - time_stride = 4 * self._asr_model.cfg.preprocessor.window_stride - - word_timestamps = timestamp_dict[self.timestamp_type] - - alignments = [] - for i, stamp in enumerate(word_timestamps): - conf = None - if hypotheses.word_confidence is not None and i < len(hypotheses.word_confidence): - conf = hypotheses.word_confidence[i] - if isinstance(conf, torch.Tensor): - conf = conf.item() - conf = round(conf, 4) - - if self.decoder_type == "ctc": - start = stamp["start_offset"] * time_stride - end = stamp["end_offset"] * time_stride - else: - start = max(0, stamp["start_offset"] * time_stride - 0.08) - end = max(0, stamp["end_offset"] * time_stride - 0.08) - - word = stamp.get("word", stamp.get("char", "")) - alignments.append( - { - "word": word, - "start": round(start, 3), - "end": round(end, 3), - "confidence": conf, - } - ) - - text = " ".join(w["word"] for w in alignments) - text = text.replace("⁇", "") - - return alignments, text - - def process(self, task: AudioTask) -> AudioTask: - results = self.process_batch([task]) - return results[0] if results else task - - def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: - """Process a batch of AudioTasks for ASR alignment.""" - if len(tasks) == 0: - return [] - t0 = time.perf_counter() - results = self.process_segments(tasks) if self.infer_segment_only else self.process_full_audio(tasks) - - self._log_metrics( - { - "process_time": time.perf_counter() - t0, - "entries_processed": len(tasks), - } - ) - return results - - def process_full_audio(self, tasks: list[AudioTask]) -> list[AudioTask]: # noqa: C901, PLR0912, PLR0915 - """Process entries as full audio (or meta-entries with split_filepaths).""" - entries = [task.data for task in tasks] - skip_indices = [] - meta_indices = [] - for i, data in enumerate(entries): - split_filepaths = data.get("split_filepaths") - has_splits = isinstance(split_filepaths, list) and len(split_filepaths) > 0 - if has_splits or split_filepaths is None: - meta_indices.append(i) - else: - skip_indices.append(i) - - for i in skip_indices: - entries[i][self.text_key] = "" - entries[i]["alignment"] = [] - - # collect all split paths of all entries in the batch - all_paths = [] - path_to_entry_and_split = [] - for entry_idx in meta_indices: - meta_entry = entries[entry_idx] - split_filepaths = meta_entry.get("split_filepaths") - if not split_filepaths: - logger.warning(f"[{self.name}] Entry at index {entry_idx} has no split_filepaths, skipping.") - continue - for split_idx, path in enumerate(split_filepaths): - all_paths.append(path) - path_to_entry_and_split.append((entry_idx, split_idx)) - - if not all_paths: - return tasks - - try: - with torch.no_grad(): - hypotheses_list = self._asr_model.transcribe(all_paths, override_config=self._override_cfg) - if isinstance(hypotheses_list, tuple) and len(hypotheses_list) == 2: # noqa: PLR2004 - hypotheses_list = hypotheses_list[0] - except Exception as e: # noqa: BLE001 - logger.error( - f"[{self.name}] Exception for meta-entries batch: {e!s} for paths: {all_paths}, transcribing one by one" - ) - hypotheses_list = [] - for path in all_paths: - try: - with torch.no_grad(): - hyp = self._asr_model.transcribe([path], override_config=self._override_cfg) - if isinstance(hyp, tuple) and len(hyp) == 2: # noqa: PLR2004 - hyp = hyp[0] - hypotheses_list.append(hyp[0] if hyp else None) - except Exception as e2: # noqa: BLE001, PERF203 - logger.error(f"[{self.name}] Exception for {path}: {e2}") - hypotheses_list.append(None) - - for path_idx, hyp in enumerate(hypotheses_list): - if path_idx >= len(path_to_entry_and_split): - break - entry_idx, split_idx = path_to_entry_and_split[path_idx] - meta_entry = entries[entry_idx] - if hyp is not None: - alignments, text = self.get_alignments_text(hyp) - else: - alignments, text = [], "" - - split_metadata = meta_entry.get("split_metadata") - if split_metadata and split_idx < len(split_metadata): - split_metadata[split_idx][self.text_key] = text - split_metadata[split_idx]["alignment"] = alignments - else: - meta_entry[self.text_key] = text - meta_entry["alignment"] = alignments - - return tasks - - def process_segments(self, tasks: list[AudioTask]) -> list[AudioTask]: - """Process entries in segment-only mode (infer per segment).""" - entries = [task.data for task in tasks] - if not entries: - return [] - - segment_metadata_list = self._prepare_segment_batch_with_metadata( - entries, - cut_audio_segments=True, - segments_key=self.segments_key, - ) - all_segments = [seg["audio_segment"] for seg in segment_metadata_list] - - if len(all_segments) == 0: - return tasks - - try: - with torch.no_grad(): - hypotheses_list = self._asr_model.transcribe(all_segments, override_config=self._override_cfg) - except Exception as e: - files_list = [x.get("resampled_audio_filepath", x.get("audio_filepath")) for x in entries] - msg = f"[{self.name}] Exception for audio list: {files_list}, error: {e}" - raise ValueError(msg) from e - - if isinstance(hypotheses_list, tuple) and len(hypotheses_list) == 2: # noqa: PLR2004 - hypotheses_list = hypotheses_list[0] - - for segment_metadata, hypotheses in zip(segment_metadata_list, hypotheses_list, strict=True): - alignments, text = self.get_alignments_text(hypotheses) - metadata_idx = segment_metadata["metadata_idx"] - segment_idx = segment_metadata["segment_idx"] - segment = entries[metadata_idx][self.segments_key][segment_idx] - segment[self.text_key] = text - if self.compute_timestamps: - seg_start = segment.get("start", 0) - for word in alignments: - word["start"] = round(word["start"] + seg_start, 3) - word["end"] = round(word["end"] + seg_start, 3) - segment[self.words_key] = alignments - - return tasks diff --git a/nemo_curator/stages/audio/tagging/split.py b/nemo_curator/stages/audio/tagging/split.py index b522ca0db7..44a19afa4c 100644 --- a/nemo_curator/stages/audio/tagging/split.py +++ b/nemo_curator/stages/audio/tagging/split.py @@ -25,7 +25,7 @@ from fsspec.core import url_to_fs from loguru import logger -from nemo_curator.stages.audio.tagging.inference.nemo_asr_align import NeMoASRAlignerStage +from nemo_curator.stages.audio.inference.alignment import ForcedAlignmentStage from nemo_curator.stages.base import CompositeStage, ProcessingStage from nemo_curator.tasks import AudioTask @@ -289,7 +289,7 @@ class SplitASRAlignJoinStage(CompositeStage[AudioTask, AudioTask]): Decomposes into three sequential stages that always run together: 1. SplitLongAudioStage — splits audio exceeding ``suggested_max_len`` - 2. NeMoASRAlignerStage — transcribes and aligns each chunk + 2. ForcedAlignmentStage (with NeMoASRAlignAdapter) — transcribes and aligns each chunk 3. JoinSplitAudioMetadataStage — merges transcripts back into original entries Args: @@ -355,24 +355,26 @@ def decompose(self) -> list[ProcessingStage]: suggested_max_len=self.suggested_max_len, min_len=self.min_len, ), - NeMoASRAlignerStage( - model_name=self.model_name, - model_path=self.model_path, - is_fastconformer=self.is_fastconformer, - decoder_type=self.decoder_type, - min_len=self.min_len, - max_len=self.max_len, - batch_size=self.batch_size, - transcribe_batch_size=self.transcribe_batch_size, - split_batch_size=self.split_batch_size, - num_workers=self.num_workers, - infer_segment_only=self.infer_segment_only, - compute_timestamps=self.compute_timestamps, - timestamp_type=self.timestamp_type, + ForcedAlignmentStage( + adapter_target="nemo_curator.adapters.alignment.NeMoASRAlignAdapter", + model_id=self.model_name, text_key=self.text_key, words_key=self.words_key, - disable_word_confidence=self.disable_word_confidence, segments_key=self.segments_key, + infer_segment_only=self.infer_segment_only, + min_len=self.min_len, + max_len=self.max_len, + batch_size=self.batch_size, + adapter_kwargs={ + "model_path": self.model_path, + "is_fastconformer": self.is_fastconformer, + "decoder_type": self.decoder_type, + "transcribe_batch_size": self.transcribe_batch_size, + "num_workers": self.num_workers, + "compute_timestamps": self.compute_timestamps, + "timestamp_type": self.timestamp_type, + "disable_word_confidence": self.disable_word_confidence, + }, ), JoinSplitAudioMetadataStage(), ] diff --git a/nemo_curator/stages/audio/tagging/inference/__init__.py b/tests/adapters/alignment/__init__.py similarity index 100% rename from nemo_curator/stages/audio/tagging/inference/__init__.py rename to tests/adapters/alignment/__init__.py diff --git a/tests/adapters/alignment/test_nemo_asr_align_adapter.py b/tests/adapters/alignment/test_nemo_asr_align_adapter.py new file mode 100644 index 0000000000..5d892ef68f --- /dev/null +++ b/tests/adapters/alignment/test_nemo_asr_align_adapter.py @@ -0,0 +1,241 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for NeMoASRAlignAdapter (NeMo ASR model mocked).""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +pytest.importorskip("nemo.collections.asr") + +from nemo_curator.adapters.alignment import ( + AlignmentResult, + ForcedAlignmentAdapter, + NeMoASRAlignAdapter, +) + + +class TestConstruction: + def test_defaults(self) -> None: + a = NeMoASRAlignAdapter() + assert a.model_id == "nvidia/parakeet-tdt_ctc-1.1b" + assert a.is_fastconformer is True + assert a.decoder_type == "rnnt" + assert a.timestamp_type == "word" + assert a.compute_timestamps is True + assert a.disable_word_confidence is False + assert a.last_metrics == {} + + def test_rejects_unknown_decoder(self) -> None: + with pytest.raises(ValueError, match="decoder_type"): + NeMoASRAlignAdapter(decoder_type="beam") + + def test_rejects_unknown_timestamp_type(self) -> None: + with pytest.raises(ValueError, match="timestamp_type"): + NeMoASRAlignAdapter(timestamp_type="phoneme") + + def test_conforms_to_protocol(self) -> None: + assert isinstance(NeMoASRAlignAdapter(), ForcedAlignmentAdapter) + + +class TestPrefetchWeights: + @patch("nemo_curator.adapters.alignment.nemo_asr_align.nemo_asr.models.ASRModel.from_pretrained") + def test_calls_from_pretrained(self, mock_from: MagicMock) -> None: + NeMoASRAlignAdapter.prefetch_weights("nvidia/parakeet-tdt_ctc-1.1b") + mock_from.assert_called_once_with( + model_name="nvidia/parakeet-tdt_ctc-1.1b", return_model_file=True + ) + + @patch( + "nemo_curator.adapters.alignment.nemo_asr_align.nemo_asr.models.ASRModel.from_pretrained", + side_effect=Exception("download failed"), + ) + def test_failure_wrapped_as_runtime_error(self, _mock_from: MagicMock) -> None: + with pytest.raises(RuntimeError, match="failed to download"): + NeMoASRAlignAdapter.prefetch_weights("bad/model") + + @patch("nemo_curator.adapters.alignment.nemo_asr_align.nemo_asr.models.ASRModel.from_pretrained") + def test_empty_model_id_noop(self, mock_from: MagicMock) -> None: + NeMoASRAlignAdapter.prefetch_weights("") + mock_from.assert_not_called() + + +class TestAlignBatch: + def test_empty_batch_returns_empty(self) -> None: + a = NeMoASRAlignAdapter() + assert a.align_batch([]) == [] + + def test_requires_setup(self) -> None: + a = NeMoASRAlignAdapter() + with pytest.raises(RuntimeError, match="setup\\(\\) must be called"): + a.align_batch([{"audio_path": "/p/x.wav"}]) + + def test_path_mode_dispatch(self) -> None: + a = NeMoASRAlignAdapter(compute_timestamps=False) + a._asr_model = MagicMock() + a._override_cfg = MagicMock() + a._asr_model.transcribe.return_value = [ + SimpleNamespace(text="hello world", timestamp={}, word_confidence=None), + SimpleNamespace(text="bye", timestamp={}, word_confidence=None), + ] + results = a.align_batch([{"audio_path": "/p/1.wav"}, {"audio_path": "/p/2.wav"}]) + assert len(results) == 2 + assert all(isinstance(r, AlignmentResult) for r in results) + assert results[0].text == "hello world" + assert results[1].text == "bye" + # Transcribe got the path list, not numpy arrays. + passed = a._asr_model.transcribe.call_args.args[0] + assert passed == ["/p/1.wav", "/p/2.wav"] + assert a.last_metrics["mode_is_segment"] == 0.0 + assert a.last_metrics["batch_size"] == 2.0 + + def test_segment_mode_dispatch(self) -> None: + a = NeMoASRAlignAdapter(compute_timestamps=False) + a._asr_model = MagicMock() + a._override_cfg = MagicMock() + a._asr_model.transcribe.return_value = [ + SimpleNamespace(text="seg-text", timestamp={}, word_confidence=None), + ] + seg = np.zeros(16000, dtype=np.float32) + results = a.align_batch([{"audio_segment": seg, "sample_rate": 16000}]) + assert len(results) == 1 + assert results[0].text == "seg-text" + # Transcribe got the numpy list. + passed = a._asr_model.transcribe.call_args.args[0] + assert len(passed) == 1 + assert passed[0] is seg + assert a.last_metrics["mode_is_segment"] == 1.0 + + def test_transcribe_returns_tuple_unwrapped(self) -> None: + a = NeMoASRAlignAdapter(compute_timestamps=False) + a._asr_model = MagicMock() + a._override_cfg = MagicMock() + hyp = [SimpleNamespace(text="x", timestamp={}, word_confidence=None)] + a._asr_model.transcribe.return_value = (hyp, None) + results = a.align_batch([{"audio_path": "/p/x.wav"}]) + assert results[0].text == "x" + + def test_batch_failure_falls_back_one_by_one_in_path_mode(self) -> None: + a = NeMoASRAlignAdapter(compute_timestamps=False) + a._asr_model = MagicMock() + a._override_cfg = MagicMock() + call_count = {"n": 0} + + def transcribe(args, override_config: object) -> object: + del override_config + call_count["n"] += 1 + # First call (batch) blows up; per-item retries succeed. + if len(args) > 1: + raise RuntimeError("boom") + return [SimpleNamespace(text=f"ok-{args[0]}", timestamp={}, word_confidence=None)] + + a._asr_model.transcribe.side_effect = transcribe + results = a.align_batch([{"audio_path": "/a.wav"}, {"audio_path": "/b.wav"}]) + assert [r.text for r in results] == ["ok-/a.wav", "ok-/b.wav"] + # 1 batch call + 2 per-item retries. + assert call_count["n"] == 3 + + def test_batch_failure_in_segment_mode_raises(self) -> None: + a = NeMoASRAlignAdapter(compute_timestamps=False) + a._asr_model = MagicMock() + a._override_cfg = MagicMock() + a._asr_model.transcribe.side_effect = RuntimeError("boom") + with pytest.raises(ValueError, match="segment mode"): + a.align_batch([{"audio_segment": np.zeros(16000), "sample_rate": 16000}]) + + +class TestGetAlignmentsText: + def _adapter(self, *, decoder_type: str, is_fastconformer: bool) -> NeMoASRAlignAdapter: + a = NeMoASRAlignAdapter( + decoder_type=decoder_type, + is_fastconformer=is_fastconformer, + timestamp_type="word", + compute_timestamps=True, + ) + # Mock cfg.preprocessor.window_stride for time_stride math. + a._asr_model = MagicMock() + a._asr_model.cfg.preprocessor.window_stride = 0.01 + a._override_cfg = MagicMock() + return a + + def test_ctc_path_mode_emits_word_alignment(self) -> None: + a = self._adapter(decoder_type="ctc", is_fastconformer=True) + # FastConformer time_stride = 8 * 0.01 = 0.08 + hyp = SimpleNamespace( + text="ignored", + timestamp={ + "word": [ + {"word": "hello", "start_offset": 10, "end_offset": 20}, + {"word": "world", "start_offset": 25, "end_offset": 40}, + ] + }, + word_confidence=[0.9, 0.8], + ) + a._asr_model.transcribe.return_value = [hyp] + results = a.align_batch([{"audio_path": "/x.wav"}]) + assert len(results) == 1 + result = results[0] + # 10 * 0.08 = 0.8, 20 * 0.08 = 1.6 (rounded to 3 dp) + assert result.alignments[0].word == "hello" + assert result.alignments[0].start == 0.8 + assert result.alignments[0].end == 1.6 + assert result.alignments[0].confidence == 0.9 + assert result.text == "hello world" + + def test_rnnt_path_subtracts_0_08(self) -> None: + a = self._adapter(decoder_type="rnnt", is_fastconformer=True) + # RNNT: start/end = max(0, offset * stride - 0.08) + hyp = SimpleNamespace( + text="ignored", + timestamp={"word": [{"word": "hi", "start_offset": 10, "end_offset": 20}]}, + word_confidence=None, + ) + a._asr_model.transcribe.return_value = [hyp] + results = a.align_batch([{"audio_path": "/x.wav"}]) + # 10 * 0.08 = 0.8, minus 0.08 = 0.72; 20 * 0.08 = 1.6, minus 0.08 = 1.52 + assert results[0].alignments[0].start == 0.72 + assert results[0].alignments[0].end == 1.52 + assert results[0].alignments[0].confidence is None + + def test_question_mark_glyph_stripped_from_text(self) -> None: + a = self._adapter(decoder_type="ctc", is_fastconformer=True) + hyp = SimpleNamespace( + text="ignored", + timestamp={ + "word": [ + {"word": "hello\u2047", "start_offset": 0, "end_offset": 10}, + {"word": "world", "start_offset": 15, "end_offset": 20}, + ] + }, + word_confidence=None, + ) + a._asr_model.transcribe.return_value = [hyp] + results = a.align_batch([{"audio_path": "/x.wav"}]) + # ⁇ (U+2047) stripped from the JOINED text only; individual words keep the raw form. + assert results[0].text == "hello world" + + def test_compute_timestamps_false_returns_text_only(self) -> None: + a = NeMoASRAlignAdapter(compute_timestamps=False) + a._asr_model = MagicMock() + a._override_cfg = MagicMock() + hyp = SimpleNamespace(text="plain", timestamp={}, word_confidence=None) + a._asr_model.transcribe.return_value = [hyp] + results = a.align_batch([{"audio_path": "/x.wav"}]) + assert results[0].text == "plain" + assert results[0].alignments == [] diff --git a/tests/stages/audio/tagging/inference/__init__.py b/tests/stages/audio/inference/alignment/__init__.py similarity index 100% rename from tests/stages/audio/tagging/inference/__init__.py rename to tests/stages/audio/inference/alignment/__init__.py diff --git a/tests/stages/audio/inference/alignment/test_forced_alignment_stage.py b/tests/stages/audio/inference/alignment/test_forced_alignment_stage.py new file mode 100644 index 0000000000..4ab9e31e07 --- /dev/null +++ b/tests/stages/audio/inference/alignment/test_forced_alignment_stage.py @@ -0,0 +1,280 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the generic ForcedAlignmentStage. + +Stage is tested with a fake adapter -- no NeMo / torch / GPU needed. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from nemo_curator.adapters.alignment import AlignmentResult, WordAlignment +from nemo_curator.stages.audio.inference.alignment import ForcedAlignmentStage +from nemo_curator.tasks import AudioTask + + +@dataclass +class _FakeAlignAdapter: + model_id: str = "fake/align" + revision: str | None = None + device: str = "cpu" + setup_called: int = 0 + teardown_called: int = 0 + last_batch: list[dict[str, Any]] | None = None + last_metrics: dict[str, float] = field(default_factory=dict) + fixed_results: list[AlignmentResult] | None = None + + @classmethod + def prefetch_weights(cls, model_id: str, revision: str | None = None) -> None: + del model_id, revision + + def setup(self) -> None: + self.setup_called += 1 + + def teardown(self) -> None: + self.teardown_called += 1 + + def align_batch(self, items: list[dict[str, Any]]) -> list[AlignmentResult]: + self.last_batch = list(items) + self.last_metrics = {"batch_size": float(len(items))} + if self.fixed_results is not None: + return list(self.fixed_results) + return [ + AlignmentResult( + alignments=[WordAlignment(word="hello", start=0.0, end=0.5, confidence=0.9)], + text="hello", + model_id=self.model_id, + ) + for _ in items + ] + + +_ADAPTER_TARGET = f"{__name__}._FakeAlignAdapter" + + +class TestConstruction: + def test_requires_adapter_target(self) -> None: + with pytest.raises(ValueError, match="adapter_target is required"): + ForcedAlignmentStage() + + def test_default_io(self) -> None: + s = ForcedAlignmentStage(adapter_target=_ADAPTER_TARGET) + ins, outs = s.inputs() + assert ins == ["data"] + assert "split_filepaths" in outs + assert "split_metadata" in outs + + +class TestLifecycle: + def test_setup_instantiates_with_tier1_tier2(self) -> None: + s = ForcedAlignmentStage( + adapter_target=_ADAPTER_TARGET, + model_id="m", + revision="r", + adapter_kwargs={"setup_called": 0}, + ) + s.setup() + assert isinstance(s._adapter, _FakeAlignAdapter) + assert s._adapter.model_id == "m" + assert s._adapter.revision == "r" + assert s._adapter.setup_called == 1 + + def test_teardown_clears(self) -> None: + s = ForcedAlignmentStage(adapter_target=_ADAPTER_TARGET) + s.setup() + adapter = s._adapter + s.teardown() + assert s._adapter is None + assert adapter.teardown_called == 1 + + def test_prefetch_called_setup_on_node(self) -> None: + with patch.object(_FakeAlignAdapter, "prefetch_weights") as mock_pf: + s = ForcedAlignmentStage(adapter_target=_ADAPTER_TARGET, model_id="m", revision="r") + s.setup_on_node() + mock_pf.assert_called_once_with("m", "r") + + def test_prefetch_failure_swallowed(self) -> None: + with patch.object(_FakeAlignAdapter, "prefetch_weights", side_effect=RuntimeError("x")): + s = ForcedAlignmentStage(adapter_target=_ADAPTER_TARGET, prefetch_fail_on_error=False) + s.setup_on_node() + + +class TestProcessBatchFullAudio: + def test_requires_setup(self) -> None: + s = ForcedAlignmentStage(adapter_target=_ADAPTER_TARGET) + with pytest.raises(RuntimeError, match="setup\\(\\) was not called"): + s.process_batch([AudioTask(data={})]) + + def test_empty_batch_returns_empty(self) -> None: + s = ForcedAlignmentStage(adapter_target=_ADAPTER_TARGET) + s.setup() + assert s.process_batch([]) == [] + + def test_scatters_results_into_split_metadata(self) -> None: + s = ForcedAlignmentStage(adapter_target=_ADAPTER_TARGET) + s.setup() + s._adapter.fixed_results = [ + AlignmentResult( + alignments=[WordAlignment(word="a", start=0.0, end=0.3, confidence=0.8)], + text="a", + model_id="m", + ), + AlignmentResult( + alignments=[WordAlignment(word="b", start=0.0, end=0.4, confidence=0.7)], + text="b", + model_id="m", + ), + ] + task = AudioTask( + data={ + "split_filepaths": ["/p/1.wav", "/p/2.wav"], + "split_metadata": [ + {"resampled_audio_filepath": "/p/1.wav"}, + {"resampled_audio_filepath": "/p/2.wav"}, + ], + } + ) + out = s.process_batch([task]) + assert len(out) == 1 + md = out[0].data["split_metadata"] + assert md[0]["text"] == "a" + assert md[0]["alignment"][0]["word"] == "a" + assert md[1]["text"] == "b" + assert md[1]["alignment"][0]["word"] == "b" + # Adapter received homogeneous path-mode items. + batch = s._adapter.last_batch + assert batch is not None and len(batch) == 2 + assert all("audio_path" in item for item in batch) + assert all("audio_segment" not in item for item in batch) + + def test_writes_top_level_when_no_split_metadata(self) -> None: + s = ForcedAlignmentStage(adapter_target=_ADAPTER_TARGET) + s.setup() + s._adapter.fixed_results = [ + AlignmentResult(alignments=[], text="hello", model_id="m"), + ] + task = AudioTask(data={"split_filepaths": ["/p/x.wav"], "split_metadata": []}) + out = s.process_batch([task]) + assert out[0].data["text"] == "hello" + assert out[0].data["alignment"] == [] + + def test_sentinel_split_filepaths_empty_string(self) -> None: + s = ForcedAlignmentStage(adapter_target=_ADAPTER_TARGET) + s.setup() + task = AudioTask(data={"split_filepaths": "skip-me"}) + out = s.process_batch([task]) + assert out[0].data["text"] == "" + assert out[0].data["alignment"] == [] + + def test_missing_split_filepaths_passes_through(self) -> None: + s = ForcedAlignmentStage(adapter_target=_ADAPTER_TARGET) + s.setup() + # split_filepaths key absent -> meta entry mode; adapter still called with 0 items. + # Stage handles by emitting empty results -- nothing to scatter. + task = AudioTask(data={"split_metadata": []}) + out = s.process_batch([task]) + # No exception, no text written (no splits -> nothing to do). + assert "text" not in out[0].data + + +class TestProcessBatchSegmentMode: + @patch("nemo_curator.stages.audio.inference.alignment.stage.torchaudio.load") + def test_segments_cut_and_scattered_with_time_offset(self, mock_load: MagicMock) -> None: + import torch + + mock_load.return_value = (torch.zeros(1, 16000 * 10), 16000) + s = ForcedAlignmentStage( + adapter_target=_ADAPTER_TARGET, + infer_segment_only=True, + min_len=0.5, + ) + s.setup() + s._adapter.fixed_results = [ + AlignmentResult( + alignments=[WordAlignment(word="hi", start=0.0, end=0.2, confidence=0.9)], + text="hi", + model_id="m", + ), + AlignmentResult( + alignments=[WordAlignment(word="bye", start=0.0, end=0.3, confidence=0.95)], + text="bye", + model_id="m", + ), + ] + task = AudioTask( + data={ + "resampled_audio_filepath": "/p/x.wav", + "segments": [ + {"start": 1.0, "end": 2.5}, # 1.5s -> included + {"start": 3.0, "end": 4.6}, # 1.6s -> included + {"start": 5.0, "end": 5.2}, # 0.2s -> skipped (< min_len) + ], + } + ) + out = s.process_batch([task]) + segs = out[0].data["segments"] + # First segment cut + transcribed + words[start] offset by seg start (1.0). + assert segs[0]["text"] == "hi" + assert segs[0]["words"][0]["start"] == 1.0 + assert segs[0]["words"][0]["end"] == 1.2 + # Second segment offset by 3.0. + assert segs[1]["text"] == "bye" + assert segs[1]["words"][0]["start"] == 3.0 + assert segs[1]["words"][0]["end"] == 3.3 + # Third segment skipped (too short) -> no text added. + assert "text" not in segs[2] + # Adapter received homogeneous segment-mode items. + batch = s._adapter.last_batch + assert batch is not None and len(batch) == 2 + assert all("audio_segment" in item for item in batch) + assert all("audio_path" not in item for item in batch) + + @patch("nemo_curator.stages.audio.inference.alignment.stage.torchaudio.load") + def test_no_eligible_segments_does_not_call_adapter(self, mock_load: MagicMock) -> None: + import torch + + mock_load.return_value = (torch.zeros(1, 16000 * 10), 16000) + s = ForcedAlignmentStage(adapter_target=_ADAPTER_TARGET, infer_segment_only=True, min_len=2.0) + s.setup() + task = AudioTask( + data={ + "resampled_audio_filepath": "/p/x.wav", + "segments": [{"start": 0.0, "end": 0.5}], + } + ) + s.process_batch([task]) + assert s._adapter.last_batch is None + + +class TestMetrics: + def test_logs_entries_processed_and_adapter_aliases(self) -> None: + s = ForcedAlignmentStage(adapter_target=_ADAPTER_TARGET) + s.setup() + observed: list[dict[str, float]] = [] + s._log_metrics = observed.append # type: ignore[assignment] + task = AudioTask( + data={"split_filepaths": ["/p/1.wav"], "split_metadata": [{"resampled_audio_filepath": "/p/1.wav"}]} + ) + s.process_batch([task]) + assert observed + m = observed[-1] + assert m["entries_processed"] == 1.0 + assert m["model_batch_size"] == 1.0 + assert "process_time" in m diff --git a/tests/stages/audio/tagging/e2e/configs/tts_pipeline.yaml b/tests/stages/audio/tagging/e2e/configs/tts_pipeline.yaml index 3c1c506b1e..c9b0b5a2cf 100644 --- a/tests/stages/audio/tagging/e2e/configs/tts_pipeline.yaml +++ b/tests/stages/audio/tagging/e2e/configs/tts_pipeline.yaml @@ -52,16 +52,19 @@ stages: min_len: 1.0 resources: ${resources} - # 4. ASR forced alignment (NeMo) - - _target_: nemo_curator.stages.audio.tagging.inference.nemo_asr_align.NeMoASRAlignerStage + # 4. ASR forced alignment (ForcedAlignmentStage + NeMoASRAlignAdapter) + - _target_: nemo_curator.stages.audio.inference.alignment.ForcedAlignmentStage name: "ASRAlignment" - is_fastconformer: true - decoder_type: ctc + adapter_target: nemo_curator.adapters.alignment.NeMoASRAlignAdapter + model_id: "nvidia/parakeet-tdt_ctc-1.1b" batch_size: 1 - transcribe_batch_size: 4 - num_workers: 1 resources: gpus: 1 + adapter_kwargs: + is_fastconformer: true + decoder_type: ctc + transcribe_batch_size: 4 + num_workers: 1 # 5. Join split audio metadata - _target_: nemo_curator.stages.audio.tagging.JoinSplitAudioMetadataStage diff --git a/tests/stages/audio/tagging/e2e/test_tts_e2e.py b/tests/stages/audio/tagging/e2e/test_tts_e2e.py index f914725ebe..2010ae4a68 100644 --- a/tests/stages/audio/tagging/e2e/test_tts_e2e.py +++ b/tests/stages/audio/tagging/e2e/test_tts_e2e.py @@ -51,7 +51,7 @@ def test_tts_e2e(tmp_path: Path, get_input_manifest: str) -> None: cfg.hf_token = os.getenv("HF_TOKEN", "") cfg.language_short = "en" - # Override NeMoASRAlignerStage (index 4) to use CTC model for CPU testing + # Override ForcedAlignmentStage (index 4) to use CTC model for CPU testing cfg.stages[4].model_name = "nvidia/stt_en_fastconformer_ctc_large" cfg.stages[4].is_fastconformer = True cfg.stages[4].decoder_type = "ctc" diff --git a/tests/stages/audio/tagging/inference/test_base_asr_processor.py b/tests/stages/audio/tagging/inference/test_base_asr_processor.py deleted file mode 100644 index 40de7f1b39..0000000000 --- a/tests/stages/audio/tagging/inference/test_base_asr_processor.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from nemo_curator.backends.base import WorkerMetadata -from nemo_curator.stages.audio.tagging.inference.nemo_asr_align import BaseASRProcessorStage -from nemo_curator.tasks import AudioTask - - -class ConcreteASRProcessor(BaseASRProcessorStage): - """Concrete subclass for testing base behavior.""" - - def setup(self, _: WorkerMetadata | None = None) -> None: - pass - - def process(self, task: AudioTask) -> AudioTask: - return task - - -class TestBaseASRProcessorStagePrepareSegmentBatch: - """Tests for BaseASRProcessorStage._prepare_segment_batch_with_metadata.""" - - def test_collects_segment_paths_without_cutting_audio(self) -> None: - """When cut_audio_segments=False, collects resampled_audio_filepath from segments.""" - stage = ConcreteASRProcessor() - metadata_batch = [ - { - "segments": [ - { - "start": 0.0, - "end": 1.0, - "resampled_audio_filepath": "/path/1.wav", - }, - { - "start": 1.0, - "end": 2.0, - "resampled_audio_filepath": "/path/2.wav", - }, - ], - }, - { - "segments": [ - { - "start": 0.0, - "end": 1.5, - "resampled_audio_filepath": "/path/3.wav", - }, - ], - }, - ] - result = stage._prepare_segment_batch_with_metadata( - metadata_batch, cut_audio_segments=False, segments_key="segments" - ) - assert len(result) == 3 - assert result[0]["resampled_audio_filepath"] == "/path/1.wav" - assert result[0]["metadata_idx"] == 0 - assert result[0]["segment_idx"] == 0 - assert result[1]["resampled_audio_filepath"] == "/path/2.wav" - assert result[1]["metadata_idx"] == 0 - assert result[1]["segment_idx"] == 1 - assert result[2]["resampled_audio_filepath"] == "/path/3.wav" - assert result[2]["metadata_idx"] == 1 - assert result[2]["segment_idx"] == 0 - - def test_skips_segments_without_resampled_audio_filepath(self) -> None: - """Segments missing resampled_audio_filepath are not included.""" - stage = ConcreteASRProcessor() - metadata_batch = [ - { - "segments": [ - {"start": 0.0, "end": 1.0}, - {"start": 1.0, "end": 2.0, "resampled_audio_filepath": "/only.wav"}, - ], - }, - ] - result = stage._prepare_segment_batch_with_metadata(metadata_batch, cut_audio_segments=False) - assert len(result) == 1 - assert result[0]["resampled_audio_filepath"] == "/only.wav" - - def test_empty_segments_returns_empty_list(self) -> None: - """Metadata with no segments or empty segments returns empty list.""" - stage = ConcreteASRProcessor() - result = stage._prepare_segment_batch_with_metadata([{"segments": []}], cut_audio_segments=False) - assert result == [] - result = stage._prepare_segment_batch_with_metadata([{}], cut_audio_segments=False) - assert result == [] diff --git a/tests/stages/audio/tagging/inference/test_nemo_asr_align.py b/tests/stages/audio/tagging/inference/test_nemo_asr_align.py deleted file mode 100644 index 6f73fbd94e..0000000000 --- a/tests/stages/audio/tagging/inference/test_nemo_asr_align.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pathlib import Path -from typing import Any - -from nemo_curator.stages.audio.tagging.inference.nemo_asr_align import NeMoASRAlignerStage -from nemo_curator.stages.resources import Resources -from nemo_curator.tasks import AudioTask - - -class TestNeMoASRAlignerStage: - def test_process_full_audio(self, tmpdir: Any, wav_filepath: Path) -> None: # noqa: ANN401 - stage = NeMoASRAlignerStage( - model_name="nvidia/stt_en_fastconformer_ctc_large", - is_fastconformer=True, - decoder_type="ctc", - resources=Resources(cpus=1.0), - ) - stage.setup() - - tasks = [ - AudioTask( - data={ - "audio_filepath": str(wav_filepath), - "split_filepaths": [str(wav_filepath)], - "split_metadata": [ - { - "start": 0, - "end": 10, - "resampled_audio_filepath": str(wav_filepath), - } - ], - } - ) - ] - results = stage.process_batch(tasks) - - assert len(results) == 1 - entry = results[0].data - split = entry["split_metadata"][0] - assert "text" in split - assert "alignment" in split - assert isinstance(split["text"], str) - assert isinstance(split["alignment"], list) - assert split["text"] != "" - assert len(split["alignment"]) > 10 diff --git a/tutorials/audio/tagging/README.md b/tutorials/audio/tagging/README.md index dff42b3a5a..13739019dc 100644 --- a/tutorials/audio/tagging/README.md +++ b/tutorials/audio/tagging/README.md @@ -28,7 +28,7 @@ The audio tagging pipeline is a processing framework that takes raw audio files | 1 | **ResampleAudioStage** | Resample to 16 kHz mono WAV | No | | 2 | **DiarizationStage** (+ `PyAnnoteDiarizationAdapter`) | Speaker diarization and overlap detection | Yes | | 3 | **SplitLongAudioStage** | Split segments exceeding max length | No | -| 4 | **NeMoASRAlignerStage** | Forced alignment via NeMo FastConformer | Yes | +| 4 | **ForcedAlignmentStage** (+ `NeMoASRAlignAdapter`) | Forced alignment via NeMo FastConformer | Yes | | 5 | **JoinSplitAudioMetadataStage** | Rejoin split audio metadata | No | | 6 | **MergeAlignmentDiarizationStage** | Merge alignment with diarization segments | No | | 7 | **ManifestWriterStage** | Write output JSONL manifest | No | diff --git a/tutorials/audio/tagging/tts_pipeline.yaml b/tutorials/audio/tagging/tts_pipeline.yaml index 7fe61c14c2..563deba919 100644 --- a/tutorials/audio/tagging/tts_pipeline.yaml +++ b/tutorials/audio/tagging/tts_pipeline.yaml @@ -23,7 +23,7 @@ # 1. ResampleAudio - Resample audio to 16 kHz mono WAV # 2. DiarizationStage - Speaker diarization + overlap detection (PyAnnote adapter) # 3. SplitLongAudio - Split segments exceeding max length -# 4. NeMoASRAligner - Forced alignment via NeMo FastConformer +# 4. ForcedAlignmentStage - Forced alignment via NeMo FastConformer (NeMoASRAlignAdapter) # 5. JoinSplitAudioMetadata - Rejoin split metadata # 6. MergeAlignmentDiar - Merge ASR alignment with diarization # 7. ManifestWriter - Write output JSONL manifest @@ -116,11 +116,14 @@ stages: min_len: 1.0 resources: ${resources} - - _target_: nemo_curator.stages.audio.tagging.inference.nemo_asr_align.NeMoASRAlignerStage + - _target_: nemo_curator.stages.audio.inference.alignment.ForcedAlignmentStage name: "ASRAlignment" - is_fastconformer: true - decoder_type: rnnt + adapter_target: nemo_curator.adapters.alignment.NeMoASRAlignAdapter + model_id: "nvidia/parakeet-tdt_ctc-1.1b" batch_size: 32 + adapter_kwargs: + is_fastconformer: true + decoder_type: rnnt - _target_: nemo_curator.stages.audio.tagging.JoinSplitAudioMetadataStage name: "JoinSplitMetadata" From fae6b483d7beb133da7d03375de818e22e222e70 Mon Sep 17 00:00:00 2001 From: "aaftaabv@gmail.com" Date: Fri, 29 May 2026 16:11:09 +0530 Subject: [PATCH 4/6] audio/tagging: align test overrides + Hydra schema with stage-adapter split Two pre-Kratos fixes that fell out of the b9f3c8b stage-adapter split: - tests/stages/audio/tagging/e2e/test_tts_e2e.py: update stages[4] overrides to the new ForcedAlignmentStage shape. Tier-1 stage field is model_id (was model_name); is_fastconformer / decoder_type / transcribe_batch_size now live under adapter_kwargs.* per the SDP-V2 Tier-1/Tier-2 split. The pre-split shape would break OmegaConf attribute assignment on the new dataclass. - tutorials/audio/tagging/tts_pipeline.yaml: declare top-level data_config + output_dir as empty strings so Hydra struct-mode accepts the CLI overrides that NvLLMOps' non-native Curator runner (nvllmops/stages/harvest/curator/run_curator.py lines 234-238) injects unconditionally for every pipeline that isn't in _CURATOR_NATIVE_PIPELINES. Tagging is one such pipeline. Neither field is consumed by any stage; they exist purely to satisfy the Hydra struct-mode contract during pod invocation. Both are zero-behavior changes for the audio pipeline itself; they only un-wedge downstream callers (pytest e2e and NvLLMOps Kratos wrapper). Mirrors PR1967's pattern of keeping post-refactor branches runnable end-to-end without modifying NvLLMOps-side code. Signed-off-by: aaftaabv@gmail.com --- tests/stages/audio/tagging/e2e/test_tts_e2e.py | 11 ++++++----- tutorials/audio/tagging/tts_pipeline.yaml | 8 ++++++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/stages/audio/tagging/e2e/test_tts_e2e.py b/tests/stages/audio/tagging/e2e/test_tts_e2e.py index 2010ae4a68..4da3e86143 100644 --- a/tests/stages/audio/tagging/e2e/test_tts_e2e.py +++ b/tests/stages/audio/tagging/e2e/test_tts_e2e.py @@ -51,11 +51,12 @@ def test_tts_e2e(tmp_path: Path, get_input_manifest: str) -> None: cfg.hf_token = os.getenv("HF_TOKEN", "") cfg.language_short = "en" - # Override ForcedAlignmentStage (index 4) to use CTC model for CPU testing - cfg.stages[4].model_name = "nvidia/stt_en_fastconformer_ctc_large" - cfg.stages[4].is_fastconformer = True - cfg.stages[4].decoder_type = "ctc" - cfg.stages[4].transcribe_batch_size = 1 + # Override ForcedAlignmentStage (index 4) to use CTC model for CPU testing. + # Tier-1 stage field: model_id. Tier-2 adapter knobs: under adapter_kwargs. + cfg.stages[4].model_id = "nvidia/stt_en_fastconformer_ctc_large" + cfg.stages[4].adapter_kwargs.is_fastconformer = True + cfg.stages[4].adapter_kwargs.decoder_type = "ctc" + cfg.stages[4].adapter_kwargs.transcribe_batch_size = 1 pipeline = create_pipeline_from_yaml(cfg) executor = XennaExecutor( diff --git a/tutorials/audio/tagging/tts_pipeline.yaml b/tutorials/audio/tagging/tts_pipeline.yaml index 563deba919..84e0ae1224 100644 --- a/tutorials/audio/tagging/tts_pipeline.yaml +++ b/tutorials/audio/tagging/tts_pipeline.yaml @@ -72,6 +72,14 @@ documentation: | input_manifest: ??? final_manifest: ??? +# NvLLMOps' non-native Curator runner (nvllmops/stages/harvest/curator/ +# run_curator.py) unconditionally injects `data_config=` and +# `output_dir=` as Hydra CLI overrides for every pipeline that +# isn't in _CURATOR_NATIVE_PIPELINES (which includes tagging). Declare +# them here so Hydra struct-mode accepts the assignment. Not consumed +# by any stage in this pipeline. +data_config: "" +output_dir: "" backend: xenna workspace_dir: /tmp/tagging_workspace resampled_audio_dir: ${workspace_dir}/audio_resampled From 9564b74fb1d4016d4f1829b85e573bf26c298d1d Mon Sep 17 00:00:00 2001 From: "aaftaabv@gmail.com" Date: Fri, 29 May 2026 16:21:32 +0530 Subject: [PATCH 5/6] audio/tagging: add tts_pipeline_correctness.yaml for Kratos reference-diff run Bakes the exact post-override pipeline shape used to generate the in-tree reference manifest at tests/fixtures/audio/tagging/reference/tts/test_data_reference.jsonl into a tutorial-layout Hydra YAML, so the NvLLMOps non-native Curator runner can load it via --local-config-path and reproduce the reference output byte-for-byte (within check_output()'s tolerance: text exact + segment boundaries pytest.approx(rel=1e-3) + word timings abs=0.01). The reference was produced by tests/stages/audio/tagging/e2e/ test_tts_e2e.py, which loads tests/.../e2e/configs/tts_pipeline.yaml and then applies four runtime overrides on Stage 4 (ForcedAlignmentStage): model_id = nvidia/stt_en_fastconformer_ctc_large adapter_kwargs.is_fastconformer = true adapter_kwargs.decoder_type = ctc adapter_kwargs.transcribe_batch_size = 1 This new YAML bakes those four into the YAML itself; everything else (Stage 2 PyAnnote small batches, Stage 4 batch_size=1 num_workers=1, stage list + ordering) mirrors the e2e test config verbatim. Adds top-level data_config: \"\" + output_dir: \"\" shim per the NvLLMOps non-native Hydra override surface (same reason as fae6b48). Kept separate from tts_pipeline.yaml because that file targets production-shape tagging (parakeet-tdt_ctc-1.1b RNN-T at batch=32, default PyAnnote batches), which would NOT match the reference text or timings even with identical input audio. One config per intent: - tts_pipeline.yaml: production tagging, no in-tree reference - tts_pipeline_correctness.yaml: Kratos validation against reference Three configs total for the tagging pipeline now: - tutorials/audio/tagging/tts_pipeline.yaml (production tagging) - tutorials/audio/tagging/tts_pipeline_correctness.yaml (Kratos diff vs ref) - tests/stages/audio/tagging/e2e/configs/tts_pipeline.yaml (pytest e2e in-process; runtime-overrides Stage 4 to match the same shape) Signed-off-by: aaftaabv@gmail.com --- .../tagging/tts_pipeline_correctness.yaml | 168 ++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 tutorials/audio/tagging/tts_pipeline_correctness.yaml diff --git a/tutorials/audio/tagging/tts_pipeline_correctness.yaml b/tutorials/audio/tagging/tts_pipeline_correctness.yaml new file mode 100644 index 0000000000..5f2fd4d83d --- /dev/null +++ b/tutorials/audio/tagging/tts_pipeline_correctness.yaml @@ -0,0 +1,168 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +############################################################################### +# TTS Audio Tagging Pipeline - Correctness Validation Configuration +############################################################################### +# +# Bakes in the exact post-override pipeline shape used to generate the +# in-tree reference manifest at: +# +# tests/fixtures/audio/tagging/reference/tts/test_data_reference.jsonl +# +# Reference was produced by running tests/stages/audio/tagging/e2e/ +# test_tts_e2e.py, which loads +# tests/stages/audio/tagging/e2e/configs/tts_pipeline.yaml and then +# applies four runtime overrides on Stage 4 (ForcedAlignmentStage): +# model_id = "nvidia/stt_en_fastconformer_ctc_large" +# adapter_kwargs.is_fastconformer = true +# adapter_kwargs.decoder_type = "ctc" +# adapter_kwargs.transcribe_batch_size = 1 +# This file bakes those four values directly into the YAML so the +# pipeline-as-loaded matches the pipeline-as-validated. +# +# Designed for NvLLMOps' non-native Curator runner +# (nvllmops/stages/harvest/curator/run_curator.py), which: +# - invokes tutorials/audio/tagging/main.py as the pod entry point, +# - downloads this YAML from swift via --local-config-path, +# - injects the following Hydra CLI overrides unconditionally: +# workspace_dir= output_dir= hf_token= +# input_manifest= data_config= language_short= +# max_segment_length= final_manifest= +# +# Post-run, diff the merged output against test_data_reference.jsonl +# using tests/stages/audio/tagging/e2e/utils.py::check_output (text exact +# + segment boundaries pytest.approx(rel=1e-3) + word timings abs=0.01). + +defaults: + - _self_ + - override hydra/job_logging: none + - override hydra/hydra_logging: none + +hydra: + run: + dir: . + output_subdir: null + +documentation: | + TTS Audio Tagging Pipeline - Correctness Validation + ################################################### + + Loaded by NvLLMOps' non-native Curator runner. Settings frozen to + reproduce the in-tree reference manifest at + ``tests/fixtures/audio/tagging/reference/tts/test_data_reference.jsonl``. + + **Required arguments** (NvLLMOps injects these via Hydra CLI overrides): + + * **input_manifest**: JSONL with audio_filepath + audio_item_id entries. + * **final_manifest**: Output JSONL path. + * **hf_token**: HuggingFace token for PyAnnote model gate. + +# ============================================================================= +# CORE CONFIGURATION +# ============================================================================= + +input_manifest: ??? +final_manifest: ??? +# NvLLMOps run_curator.py unconditionally injects these two as Hydra +# overrides for non-native pipelines (see lines 234-238 in that file). +# Declared empty here so Hydra struct-mode accepts the assignment; +# neither is consumed by any stage. +data_config: "" +output_dir: "" +backend: xenna +workspace_dir: /tmp/tagging_workspace +resampled_audio_dir: ${workspace_dir}/audio_resampled +language_short: en +hf_token: "" +sample_rate: 16000 +max_segment_length: 40 +resources: + cpus: 1 + +# ============================================================================= +# PIPELINE STAGES +# ============================================================================= +# +# Stage list + ordering identical to tts_pipeline.yaml; only Stage 2 + +# Stage 4 model / batch knobs differ (matched against the reference). + +stages: + # 0. Read input manifest + - _target_: nemo_curator.stages.audio.common.ManifestReader + manifest_path: ${input_manifest} + files_per_partition: 1 + + # 1. Resample audio to 16 kHz mono WAV + - _target_: nemo_curator.stages.audio.tagging.ResampleAudioStage + name: "ResampleAudio" + resampled_audio_dir: ${resampled_audio_dir} + input_format: "wav" + target_sample_rate: ${sample_rate} + target_format: "wav" + target_nchannels: 1 + resources: ${resources} + + # 2. Speaker diarization — small PyAnnote batches matching reference + - _target_: nemo_curator.stages.audio.inference.speaker_diarization.DiarizationStage + name: "PyAnnoteDiarization" + adapter_target: nemo_curator.adapters.diarization.PyAnnoteDiarizationAdapter + model_id: "pyannote/speaker-diarization-3.1" + non_speaker_max_length: ${max_segment_length} + xenna_num_workers: 1 + resources: + gpus: 1 + adapter_kwargs: + hf_token: ${hf_token} + max_length: ${max_segment_length} + segmentation_batch_size: 2 + embedding_batch_size: 2 + + # 3. Split long audio segments + - _target_: nemo_curator.stages.audio.tagging.SplitLongAudioStage + name: "SplitLongAudio" + suggested_max_len: ${max_segment_length} + min_len: 1.0 + resources: ${resources} + + # 4. ASR forced alignment — CTC fastconformer large, batch=1 + # (matches the test_tts_e2e.py runtime overrides exactly) + - _target_: nemo_curator.stages.audio.inference.alignment.ForcedAlignmentStage + name: "ASRAlignment" + adapter_target: nemo_curator.adapters.alignment.NeMoASRAlignAdapter + model_id: "nvidia/stt_en_fastconformer_ctc_large" + batch_size: 1 + resources: + gpus: 1 + adapter_kwargs: + is_fastconformer: true + decoder_type: ctc + transcribe_batch_size: 1 + num_workers: 1 + + # 5. Join split audio metadata + - _target_: nemo_curator.stages.audio.tagging.JoinSplitAudioMetadataStage + name: "JoinSplitMetadata" + resources: ${resources} + + # 6. Merge alignment with diarization segments + - _target_: nemo_curator.stages.audio.tagging.MergeAlignmentDiarizationStage + name: "MergeAlignmentDiar" + text_key: "text" + words_key: "words" + resources: ${resources} + + # 7. Write output manifest + - _target_: nemo_curator.stages.audio.common.ManifestWriterStage + output_path: ${final_manifest} From 6a13fde3bdfc20f3a39c259774597b99d420ec10 Mon Sep 17 00:00:00 2001 From: "aaftaabv@gmail.com" Date: Fri, 29 May 2026 17:22:46 +0530 Subject: [PATCH 6/6] audio/tagging: add sample_input_kratos.jsonl with actual_duration for Kratos pre-pipeline splitter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Sibling of tests/fixtures/audio/tagging/sample_input.jsonl. The Kratos non-native runner path calls NvLLMOps split_manifest_by_duration_field which requires an actual_duration key on every manifest entry before the Curator pipeline starts (KeyError otherwise). Pytest path is unaffected — it ingests sample_input.jsonl directly and NeMo Curator ignores unknown fields, so this file is Kratos-only fixture data. Durations measured with soundfile.info: audio_1.opus: 60.000000 s (stereo, 48kHz, 2880000 frames) audio_2.opus: 67.073500 s (stereo, 48kHz, 3219528 frames) Signed-off-by: aaftaabv@gmail.com --- tests/fixtures/audio/tagging/sample_input_kratos.jsonl | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 tests/fixtures/audio/tagging/sample_input_kratos.jsonl diff --git a/tests/fixtures/audio/tagging/sample_input_kratos.jsonl b/tests/fixtures/audio/tagging/sample_input_kratos.jsonl new file mode 100644 index 0000000000..dd03e6a514 --- /dev/null +++ b/tests/fixtures/audio/tagging/sample_input_kratos.jsonl @@ -0,0 +1,2 @@ +{"audio_filepath": "tests/fixtures/audio/tagging/audios/audio_1.opus", "audio_item_id": "audio_1", "actual_duration": 60.0} +{"audio_filepath": "tests/fixtures/audio/tagging/audios/audio_2.opus", "audio_item_id": "audio_2", "actual_duration": 67.0735}