Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions nemo_curator/stages/audio/alm/alm_manifest_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,6 @@ def process(self, task: FileGroupTask) -> list[AudioBatch]:
for entry in entries
]

def ray_stage_spec(self) -> dict[str, Any]:
return {"is_fanout_stage": True}


@dataclass
class ALMManifestReader(CompositeStage[_EmptyTask, AudioBatch]):
"""Composite stage for reading ALM JSONL manifests.
Expand Down
22 changes: 21 additions & 1 deletion nemo_curator/stages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import time
from abc import ABC, ABCMeta, abstractmethod
from inspect import isabstract
from typing import TYPE_CHECKING, Any, Generic, TypeVar, final
from typing import TYPE_CHECKING, Any, Generic, TypeVar, final, get_origin, get_type_hints

from loguru import logger

Expand Down Expand Up @@ -302,8 +302,28 @@ def ray_stage_spec(self) -> dict[str, Any]:
Returns (dict[str, Any]):
Dictionary containing Ray-specific configuration
"""
if self._process_returns_list():
return {"is_fanout_stage": True}
return {}

@classmethod
def _process_returns_list(cls) -> bool:
"""Return whether the stage's process annotation can return a list."""
try:
return_annotation = get_type_hints(cls.process).get("return")
except (NameError, TypeError):
return_annotation = cls.process.__annotations__.get("return")
Comment on lines +312 to +315

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Adding AttributeError to the except tuple ensures the fallback path is reached when get_type_hints encounters a dotted-attribute annotation referencing a missing symbol (e.g. some_module.MissingType). Without it, the exception propagates out of ray_stage_spec().

Suggested change
try:
return_annotation = get_type_hints(cls.process).get("return")
except (NameError, TypeError):
return_annotation = cls.process.__annotations__.get("return")
try:
return_annotation = get_type_hints(cls.process).get("return")
except (NameError, TypeError, AttributeError):
return_annotation = cls.process.__annotations__.get("return")

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


return cls._annotation_includes_list(return_annotation)

@staticmethod
def _annotation_includes_list(annotation: object) -> bool:
"""Return whether an annotation is or includes a list type."""
if isinstance(annotation, str):
return annotation.strip().startswith(("list[", "List[", "typing.List["))

return get_origin(annotation) is list

# --- Custom per-stage metrics helpers ---
def _log_metrics(self, metrics: dict[str, float]) -> None:
"""Record custom metrics for this stage (e.g., sub-stage timings)."""
Expand Down
7 changes: 0 additions & 7 deletions nemo_curator/stages/deduplication/semantic/pairwise_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from loguru import logger

from nemo_curator.backends.base import WorkerMetadata
from nemo_curator.backends.experimental.utils import RayStageSpecKeys
from nemo_curator.stages.base import ProcessingStage
from nemo_curator.stages.resources import Resources
from nemo_curator.tasks import FileGroupTask, _EmptyTask
Expand Down Expand Up @@ -65,12 +64,6 @@ def setup(self, _: WorkerMetadata | None = None) -> None:
self.fs = get_fs(self.input_path, storage_options=self.storage_options)
self.path_normalizer = self.fs.unstrip_protocol if is_remote_url(self.input_path) else (lambda x: x)

def ray_stage_spec(self) -> dict[str, Any]:
"""Ray stage specification for this stage."""
return {
RayStageSpecKeys.IS_FANOUT_STAGE: True,
}

def xenna_stage_spec(self) -> dict[str, Any]:
return {
"num_workers_per_node": 1,
Expand Down
7 changes: 0 additions & 7 deletions nemo_curator/stages/file_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from loguru import logger

from nemo_curator.backends.experimental.utils import RayStageSpecKeys
from nemo_curator.stages.base import ProcessingStage
from nemo_curator.stages.resources import Resources
from nemo_curator.tasks import FileGroupTask, _EmptyTask
Expand Down Expand Up @@ -85,12 +84,6 @@ def inputs(self) -> tuple[list[str], list[str]]:
def outputs(self) -> tuple[list[str], list[str]]:
return [], []

def ray_stage_spec(self) -> dict[str, Any]:
"""Ray stage specification for this stage."""
return {
RayStageSpecKeys.IS_FANOUT_STAGE: True,
}

def xenna_stage_spec(self) -> dict[str, Any]:
return {
"num_workers_per_node": 1,
Expand Down
5 changes: 0 additions & 5 deletions nemo_curator/stages/text/download/base/url_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,6 @@ def process(self, task: _EmptyTask) -> list[FileGroupTask]:
for i, url in enumerate(urls)
]

def ray_stage_spec(self) -> dict[str, Any]:
return {
"is_fanout_stage": True,
}

def xenna_stage_spec(self) -> dict[str, Any]:
return {
"num_workers_per_node": 1,
Expand Down
12 changes: 2 additions & 10 deletions nemo_curator/stages/video/clipping/clip_extraction_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@
import subprocess
import uuid
from dataclasses import dataclass
from typing import Any

from cosmos_xenna.ray_utils.resources import _get_local_gpu_info, _make_gpu_resources_from_gpu_name
from loguru import logger

from nemo_curator.backends.base import WorkerMetadata
from nemo_curator.backends.experimental.utils import RayStageSpecKeys
from nemo_curator.stages.base import ProcessingStage
from nemo_curator.stages.resources import Resources, _get_gpu_memory_gb
from nemo_curator.tasks.video import Clip, Video, VideoTask
Expand Down Expand Up @@ -96,19 +94,13 @@ def inputs(self) -> tuple[list[str], list[str]]:
def outputs(self) -> tuple[list[str], list[str]]:
return ["data"], []

def ray_stage_spec(self) -> dict[str, Any]:
"""Ray stage specification for this stage."""
return {
RayStageSpecKeys.IS_FANOUT_STAGE: True,
}

def process(self, task: VideoTask) -> VideoTask:
def process(self, task: VideoTask) -> list[VideoTask]:
video = task.data

if not video.clips:
logger.warning(f"No clips to transcode for {video.input_video}. Skipping...")
video.source_bytes = None
return task
return [task]

with make_pipeline_temporary_dir(sub_dir="transcode") as tmp_dir:
# write video to file
Expand Down
5 changes: 5 additions & 0 deletions tests/stages/audio/alm/test_alm_manifest_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def _make_file_group_task(paths: list[str]) -> FileGroupTask:
class TestALMManifestReaderStage:
"""Unit tests for ALMManifestReaderStage (low-level stage)."""

def test_ray_stage_spec_is_fanout_stage(self) -> None:
stage = ALMManifestReaderStage()

assert stage.ray_stage_spec() == {"is_fanout_stage": True}

def test_reads_single_manifest(self, tmp_path: Path) -> None:
entries = [
{"audio_filepath": "a.wav", "audio_sample_rate": 16000, "segments": []},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ def test_get_fleurs_url_list_builds_urls() -> None:
assert urls[1].endswith("/hy_am/audio/dev.tar.gz")


def test_create_initial_manifest_stage_is_ray_fanout_stage(tmp_path: Path) -> None:
stage_cls, _ = _import_stage_module()
stage = stage_cls(lang="hy_am", split="dev", raw_data_dir=tmp_path.as_posix())

assert stage.ray_stage_spec() == {"is_fanout_stage": True}


def test_process_transcript_parses_tsv(tmp_path: Path) -> None:
stage_cls, _ = _import_stage_module()
# Arrange: create fake dev.tsv and expected wav layout
Expand Down
6 changes: 6 additions & 0 deletions tests/stages/audio/io/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
from nemo_curator.tasks import AudioBatch


def test_audio_to_document_stage_is_ray_fanout_stage() -> None:
stage = AudioToDocumentStage()

assert stage.ray_stage_spec() == {"is_fanout_stage": True}


def test_audio_to_document_stage_converts_batch() -> None:
audio = AudioBatch(
task_id="t1",
Expand Down
68 changes: 68 additions & 0 deletions tests/stages/common/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,45 @@ def outputs(self) -> tuple[list[str], list[str]]:
return [], []


class FanoutProcessingStage(ProcessingStage[MockTask, MockTask]):
"""ProcessingStage that returns multiple tasks."""

name = "FanoutProcessingStage"

def process(self, task: MockTask) -> list[MockTask]:
return [task]


class MaybeFanoutProcessingStage(ProcessingStage[MockTask, MockTask]):
"""ProcessingStage that may return multiple tasks."""

name = "MaybeFanoutProcessingStage"

def process(self, task: MockTask) -> MockTask | list[MockTask]:
return task


class ExplicitMaybeFanoutProcessingStage(MaybeFanoutProcessingStage):
"""Maybe-fanout stage that opts into Ray fanout explicitly."""

name = "ExplicitMaybeFanoutProcessingStage"

def ray_stage_spec(self) -> dict[str, bool]:
return {"is_fanout_stage": True}


class StringAnnotatedFanoutProcessingStage(ProcessingStage[MockTask, MockTask]):
"""ProcessingStage with a string return annotation."""

name = "StringAnnotatedFanoutProcessingStage"

def process(self, task: MockTask) -> list[MockTask]:
return [task]


StringAnnotatedFanoutProcessingStage.process.__annotations__["return"] = "list[MissingTask]"


class TestProcessingStageWith:
"""Test the with_ method for ProcessingStage."""

Expand Down Expand Up @@ -265,6 +304,35 @@ def process(self, task: MockTask) -> MockTask:
assert stage_with_custom2.resources == Resources(cpus=7.0)


class TestProcessingStageRaySpec:
"""Test Ray stage spec defaults."""

def test_default_ray_stage_spec_empty_for_single_task_stage(self):
stage = ConcreteProcessingStage()

assert stage.ray_stage_spec() == {}

def test_ray_stage_spec_detects_fanout_stage(self):
stage = FanoutProcessingStage()

assert stage.ray_stage_spec() == {"is_fanout_stage": True}

def test_ray_stage_spec_does_not_infer_optional_fanout_stage(self):
stage = MaybeFanoutProcessingStage()

assert stage.ray_stage_spec() == {}

def test_ray_stage_spec_allows_optional_fanout_stage_to_opt_in(self):
stage = ExplicitMaybeFanoutProcessingStage()

assert stage.ray_stage_spec() == {"is_fanout_stage": True}

def test_ray_stage_spec_detects_string_annotated_fanout_stage(self):
stage = StringAnnotatedFanoutProcessingStage()

assert stage.ray_stage_spec() == {"is_fanout_stage": True}


class TestProcessingStageOverriddenProperties:
"""Test that ProcessingStage raises an error if a derived class overrides the _name, _resources, or _batch_size property."""

Expand Down
5 changes: 5 additions & 0 deletions tests/stages/deduplication/semantic/test_pairwise_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def test_setup(self):
assert stage.path_normalizer is not None
assert stage.path_normalizer("/test/path") == "/test/path"

def test_ray_stage_spec_is_fanout_stage(self):
stage = ClusterWiseFilePartitioningStage("/test/path")

assert stage.ray_stage_spec() == {"is_fanout_stage": True}

def test_process_finds_all_centroid_files(self, tmp_path: Path):
"""Test that process method finds all files in centroid directories."""

Expand Down
8 changes: 8 additions & 0 deletions tests/stages/image/io/test_image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,14 @@ def test_process_raises_on_empty_task() -> None:
stage.process(empty)


def test_ray_stage_spec_is_fanout_stage() -> None:
from nemo_curator.stages.image.io.image_reader import ImageReaderStage

with patch("torch.cuda.is_available", return_value=False):
stage = ImageReaderStage(dali_batch_size=2, verbose=False)

assert stage.ray_stage_spec() == {"is_fanout_stage": True}


def test_resources_with_cuda_available() -> None:
from nemo_curator.stages.image.io.image_reader import ImageReaderStage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,10 @@ def test_process_no_clips(self) -> None:
# Should return early and log warning
mock_logger.warning.assert_called_once()
assert "No clips to transcode" in mock_logger.warning.call_args[0][0]
assert result.data.source_bytes is None
assert isinstance(result, list)
assert len(result) == 1
assert result[0] is self.mock_task
assert result[0].data.source_bytes is None

@patch("nemo_curator.stages.video.clipping.clip_extraction_stages.make_pipeline_temporary_dir")
@patch("nemo_curator.stages.video.clipping.clip_extraction_stages.grouping.split_by_chunk_size")
Expand Down