Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions benchmarking/scripts/audio_tagging_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@

from nemo_curator.pipeline import Pipeline
from nemo_curator.stages.audio.common import ManifestReader, ManifestWriterStage
from nemo_curator.stages.audio.inference.speaker_diarization.pyannote import PyAnnoteDiarizationStage
from nemo_curator.stages.audio.tagging.inference.nemo_asr_align import NeMoASRAlignerStage
from nemo_curator.stages.audio.inference.speaker_diarization import DiarizationStage
from nemo_curator.stages.audio.inference.alignment import ForcedAlignmentStage
from nemo_curator.stages.audio.tagging.merge_alignment_diarization import MergeAlignmentDiarizationStage
from nemo_curator.stages.audio.tagging.resample_audio import ResampleAudioStage
from nemo_curator.stages.audio.tagging.split import JoinSplitAudioMetadataStage, SplitLongAudioStage
Expand Down Expand Up @@ -85,12 +85,17 @@ def run_audio_tagging_benchmark( # noqa: PLR0913
).with_(resources=Resources(cpus=cpus))
)

# Speaker diarization and overlap detection (PyAnnote)
# Speaker diarization and overlap detection (DiarizationStage + PyAnnote adapter)
pipeline.add_stage(
PyAnnoteDiarizationStage(
DiarizationStage(
name="PyAnnoteDiarization",
hf_token=hf_token,
max_length=max_segment_length,
adapter_target="nemo_curator.adapters.diarization.PyAnnoteDiarizationAdapter",
model_id="pyannote/speaker-diarization-3.1",
non_speaker_max_length=max_segment_length,
adapter_kwargs={
"hf_token": hf_token,
"max_length": max_segment_length,
},
).with_(resources=Resources(cpus=cpus, gpus=0.5))
)

Expand All @@ -103,13 +108,17 @@ def run_audio_tagging_benchmark( # noqa: PLR0913
).with_(resources=Resources(cpus=cpus))
)

# ASR forced alignment (NeMo FastConformer)
# ASR forced alignment (ForcedAlignmentStage + NeMoASRAlignAdapter)
pipeline.add_stage(
NeMoASRAlignerStage(
ForcedAlignmentStage(
name="ASRAlignment",
is_fastconformer=True,
decoder_type="rnnt",
adapter_target="nemo_curator.adapters.alignment.NeMoASRAlignAdapter",
model_id="nvidia/parakeet-tdt_ctc-1.1b",
batch_size=asr_batch_size,
adapter_kwargs={
"is_fastconformer": True,
"decoder_type": "rnnt",
},
).with_(resources=Resources(cpus=cpus, gpus=0.45))
)

Expand Down
27 changes: 27 additions & 0 deletions nemo_curator/adapters/__init__.py

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be an empty file.

Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model adapters for the SDP-V2 stage-adapter split.

Each adapter family (``diarization``, ``vad``, ``alignment``, ...) lives
in its own subpackage and exposes:

* ``base.py`` - a ``Protocol`` plus a typed ``Result`` dataclass that
every adapter in the family must implement.
* one module per concrete model that implements the protocol.

Stages in ``nemo_curator/stages/audio/inference/`` import the protocol
and typed result only; the concrete adapter is resolved at runtime from
the YAML's ``adapter_target`` string via ``hydra.utils.get_class``.
"""
41 changes: 41 additions & 0 deletions nemo_curator/adapters/alignment/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Forced-alignment adapter family for the SDP-V2 stage-adapter split.

Public surface (the only symbols the stage imports):

* :class:`ForcedAlignmentAdapter` - structural protocol every alignment
adapter implements.
* :class:`AlignmentResult` - canonical per-utterance result dataclass.
* :class:`WordAlignment` - canonical per-word dataclass.

Concrete adapters live in their own modules (e.g. ``nemo_asr_align.py``,
``nemo_nfa.py``, ``whisperx_alignment.py``) and are resolved at runtime
from the fully-qualified class path given in the YAML ``adapter_target`` field.
"""

from nemo_curator.adapters.alignment.base import (
AlignmentResult,
ForcedAlignmentAdapter,
WordAlignment,
)
from nemo_curator.adapters.alignment.nemo_asr_align import NeMoASRAlignAdapter

# Names re-exported as the public surface of this package; these are the
# symbols a ``from nemo_curator.adapters.alignment import *`` brings in.
__all__ = [
    "AlignmentResult",
    "ForcedAlignmentAdapter",
    "NeMoASRAlignAdapter",
    "WordAlignment",
]
144 changes: 144 additions & 0 deletions nemo_curator/adapters/alignment/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Stage-adapter contract for forced alignment (SDP-V2 design doc §13).

Mirrors the ASR / diarization / VAD contract pattern:

* :class:`~nemo_curator.stages.audio.inference.alignment.ForcedAlignmentStage`
owns Curator-side glue (task.data reads, split-filepath fan-out + scatter,
segment cut, time-offset adjustment, metric logging).
* :class:`ForcedAlignmentAdapter` owns the model-side library call
(weight prefetch, model setup, decoder configuration, the actual
``transcribe(...)`` invocation, hypothesis-to-word-alignment
conversion) and packs results into the canonical
:class:`AlignmentResult` shape.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Protocol, runtime_checkable


@dataclass
class WordAlignment:
    """A single aligned token together with its time span.

    Attributes:
        word: Text of the aligned unit. This is a word, or a single
            character when the adapter emits char-level timestamps.
        start: Start of the unit, in seconds, expressed in clip / segment
            coordinates (see :class:`AlignmentResult`).
        end: End of the unit, in seconds.
        confidence: Per-unit confidence in ``[0, 1]`` when the adapter
            provides one; left as ``None`` otherwise.
    """

    # Required positional fields, in constructor order.
    word: str
    start: float
    end: float
    # Optional: adapters that expose no confidence leave this unset.
    confidence: float | None = None


@dataclass
class AlignmentResult:
    """What an alignment adapter returns for one input item.

    Attributes:
        alignments: The emitted :class:`WordAlignment` entries (one per
            word or char). An empty list signals that the adapter could
            not process the input. Any time-offset shift is applied by
            the stage before the value is written onto ``task.data``.
        text: The concatenated transcription; the stage stores it under
            ``task.data[text_key]``.
        model_id: Identifier of the model the adapter actually ran. It is
            filled in by the adapter and mirrors the stage's ``model_id``
            field.
        extras: Free-form adapter-specific diagnostics (scalar or
            structured) that have no place in the canonical shape. The
            stage never inspects the contents of this dict.
    """

    # The only required field; everything below has a default.
    alignments: list[WordAlignment]
    text: str = ""
    model_id: str = ""
    # default_factory gives every instance its own dict.
    extras: dict[str, Any] = field(default_factory=dict)


@runtime_checkable
class ForcedAlignmentAdapter(Protocol):
    """Structural interface shared by all forced-alignment adapters.

    Construction: the stage instantiates an adapter as
    ``cls(model_id=..., revision=..., **adapter_kwargs)``. The Tier-2
    knobs passed through ``adapter_kwargs`` are specific to each adapter
    (decoder type, FastConformer toggle, batch sizes, ...).

    Batching: :meth:`align_batch` is handed a list of dicts (Tier-3
    per-task knobs that the stage unpacked from ``task.data``) and must
    return exactly one :class:`AlignmentResult` per dict, preserving
    input order.

    Keys the stage populates on each item (the adapter reads whichever
    one is present):

    * ``audio_path`` (``str | None``): path to a decodable audio file,
      used for full-audio / split-filepath inference.
    * ``audio_segment`` (``numpy.ndarray | None``): in-memory mono audio
      array holding one segment cut, used for segment-only inference.
    * ``sample_rate`` (``int | None``): sample rate of ``audio_segment``;
      only meaningful in segment mode.
    * ``task_id`` (``str | None``): carried along for diagnostics.

    Batches are homogeneous: every item carries ``audio_path`` or every
    item carries ``audio_segment``. The stage guarantees this.

    Attributes:
        model_id: Identifier of the underlying model checkpoint.
        last_metrics: Scalar metrics recorded by the most recent
            :meth:`align_batch` call; the stage folds them into its
            ``_log_metrics`` output under ``model_<key>`` aliases.
    """

    model_id: str
    last_metrics: dict[str, float]

    @classmethod
    def prefetch_weights(cls, model_id: str, revision: str | None = None) -> None:
        """Fetch the weights into the local cache; no GPU is allocated."""

    def setup(self) -> None:
        """Load the model inside the worker process."""

    def teardown(self) -> None:
        """Free GPU memory and any worker-local state."""

    def align_batch(self, items: list[dict[str, Any]]) -> list[AlignmentResult]:
        """Force-align every item of a batch.

        Args:
            items: Per-task dicts carrying the keys listed in the class
                docstring; there is one dict per task in the batch.

        Returns:
            A list with one :class:`AlignmentResult` per input dict, in
            input order. An item that could not be processed is still
            represented, with empty ``alignments`` and ``text=""``, so
            that the stage can scatter results 1:1.
        """
Loading