From 370b4d319be74ceb1b9b005ec5c58132622981fa Mon Sep 17 00:00:00 2001 From: nightcityblade Date: Sun, 24 May 2026 21:26:42 +0800 Subject: [PATCH 1/5] fix: auto-detect Ray fanout stages Signed-off-by: nightcityblade --- nemo_curator/stages/base.py | 25 ++++++++++++- .../text/download/base/url_generation.py | 5 --- tests/stages/common/test_base.py | 37 +++++++++++++++++++ 3 files changed, 61 insertions(+), 6 deletions(-) diff --git a/nemo_curator/stages/base.py b/nemo_curator/stages/base.py index cbf7652ac0..20675171e1 100644 --- a/nemo_curator/stages/base.py +++ b/nemo_curator/stages/base.py @@ -19,7 +19,8 @@ import time from abc import ABC, ABCMeta, abstractmethod from inspect import isabstract -from typing import TYPE_CHECKING, Any, Generic, TypeVar, final +from types import UnionType +from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, final, get_args, get_origin, get_type_hints from loguru import logger @@ -302,8 +303,30 @@ def ray_stage_spec(self) -> dict[str, Any]: Returns (dict[str, Any]): Dictionary containing Ray-specific configuration """ + if self._process_returns_list(): + return {"is_fanout_stage": True} return {} + @classmethod + def _process_returns_list(cls) -> bool: + """Return whether the stage's process annotation can return a list.""" + try: + return_annotation = get_type_hints(cls.process).get("return") + except (NameError, TypeError): + return_annotation = cls.process.__annotations__.get("return") + + return cls._annotation_includes_list(return_annotation) + + @staticmethod + def _annotation_includes_list(annotation: object) -> bool: + """Return whether an annotation is or includes a list type.""" + origin = get_origin(annotation) + if origin is list: + return True + if origin in (UnionType, Union) or isinstance(annotation, UnionType): + return any(ProcessingStage._annotation_includes_list(arg) for arg in get_args(annotation)) + return False + # --- Custom per-stage metrics helpers --- def _log_metrics(self, metrics: dict[str, float]) -> None: """Record custom metrics for this stage (e.g., sub-stage timings).""" diff --git a/nemo_curator/stages/text/download/base/url_generation.py b/nemo_curator/stages/text/download/base/url_generation.py index 278bf88ad0..3046c8ba12 100644 --- a/nemo_curator/stages/text/download/base/url_generation.py +++ b/nemo_curator/stages/text/download/base/url_generation.py @@ -77,11 +77,6 @@ def process(self, task: _EmptyTask) -> list[FileGroupTask]: for i, url in enumerate(urls) ] - def ray_stage_spec(self) -> dict[str, Any]: - return { - "is_fanout_stage": True, - } - def xenna_stage_spec(self) -> dict[str, Any]: return { "num_workers_per_node": 1, diff --git a/tests/stages/common/test_base.py b/tests/stages/common/test_base.py index 937bc3d2f5..2a8c9c5d5a 100644 --- a/tests/stages/common/test_base.py +++ b/tests/stages/common/test_base.py @@ -51,6 +51,24 @@ def outputs(self) -> tuple[list[str], list[str]]: return [], [] +class FanoutProcessingStage(ProcessingStage[MockTask, MockTask]): + """ProcessingStage that returns multiple tasks.""" + + name = "FanoutProcessingStage" + + def process(self, task: MockTask) -> list[MockTask]: + return [task] + + +class MaybeFanoutProcessingStage(ProcessingStage[MockTask, MockTask]): + """ProcessingStage that may return multiple tasks.""" + + name = "MaybeFanoutProcessingStage" + + def process(self, task: MockTask) -> MockTask | list[MockTask]: + return task + + class TestProcessingStageWith: """Test the with_ method for ProcessingStage.""" @@ -265,6 +283,25 @@ def process(self, task: MockTask) -> MockTask: assert stage_with_custom2.resources == Resources(cpus=7.0) +class TestProcessingStageRaySpec: + """Test Ray stage spec defaults.""" + + def test_default_ray_stage_spec_empty_for_single_task_stage(self): + stage = ConcreteProcessingStage() + + assert stage.ray_stage_spec() == {} + + def test_ray_stage_spec_detects_fanout_stage(self): + stage = FanoutProcessingStage() + + assert stage.ray_stage_spec() == {"is_fanout_stage": True} + + def test_ray_stage_spec_detects_optional_fanout_stage(self): + stage = MaybeFanoutProcessingStage() + + assert stage.ray_stage_spec() == {"is_fanout_stage": True} + + class TestProcessingStageOverriddenProperties: """Test that ProcessingStage raises an error if a derived class overrides the _name, _resources, or _batch_size property.""" From 1b72885d138f50b483448b66be3804a7bd70042a Mon Sep 17 00:00:00 2001 From: nightcityblade Date: Thu, 4 Jun 2026 23:17:39 +0800 Subject: [PATCH 2/5] fix: refine fanout stage detection Signed-off-by: nightcityblade --- nemo_curator/stages/base.py | 13 +++----- .../audio/alm/test_alm_manifest_reader.py | 5 +++ .../test_fleurs_create_initial_manifest.py | 7 ++++ tests/stages/audio/io/test_convert.py | 6 ++++ tests/stages/common/test_base.py | 33 ++++++++++++++++++- tests/stages/image/io/test_image_reader.py | 8 +++++ 6 files changed, 63 insertions(+), 9 deletions(-) diff --git a/nemo_curator/stages/base.py b/nemo_curator/stages/base.py index 20675171e1..6df28f1ec9 100644 --- a/nemo_curator/stages/base.py +++ b/nemo_curator/stages/base.py @@ -19,8 +19,7 @@ import time from abc import ABC, ABCMeta, abstractmethod from inspect import isabstract -from types import UnionType -from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, final, get_args, get_origin, get_type_hints +from typing import TYPE_CHECKING, Any, Generic, TypeVar, final, get_origin, get_type_hints from loguru import logger @@ -320,12 +319,10 @@ def _process_returns_list(cls) -> bool: @staticmethod def _annotation_includes_list(annotation: object) -> bool: """Return whether an annotation is or includes a list type.""" - origin = get_origin(annotation) - if origin is list: - return True - if origin in (UnionType, Union) or isinstance(annotation, UnionType): - return any(ProcessingStage._annotation_includes_list(arg) for arg in get_args(annotation)) - return False + if isinstance(annotation, str): + return annotation.strip().startswith(("list[", "List[", "typing.List[")) + + return get_origin(annotation) is list # --- Custom per-stage metrics helpers --- def _log_metrics(self, metrics: dict[str, float]) -> None: diff --git a/tests/stages/audio/alm/test_alm_manifest_reader.py b/tests/stages/audio/alm/test_alm_manifest_reader.py index cf829d65f9..77d22e1f87 100644 --- a/tests/stages/audio/alm/test_alm_manifest_reader.py +++ b/tests/stages/audio/alm/test_alm_manifest_reader.py @@ -28,6 +28,11 @@ def _make_file_group_task(paths: list[str]) -> FileGroupTask: class TestALMManifestReaderStage: """Unit tests for ALMManifestReaderStage (low-level stage).""" + def test_ray_stage_spec_is_fanout_stage(self) -> None: + stage = ALMManifestReaderStage() + + assert stage.ray_stage_spec() == {"is_fanout_stage": True} + def test_reads_single_manifest(self, tmp_path: Path) -> None: entries = [ {"audio_filepath": "a.wav", "audio_sample_rate": 16000, "segments": []}, diff --git a/tests/stages/audio/datasets/test_fleurs_create_initial_manifest.py b/tests/stages/audio/datasets/test_fleurs_create_initial_manifest.py index d26e2f02ab..ade09ed628 100644 --- a/tests/stages/audio/datasets/test_fleurs_create_initial_manifest.py +++ b/tests/stages/audio/datasets/test_fleurs_create_initial_manifest.py @@ -38,6 +38,13 @@ def test_get_fleurs_url_list_builds_urls() -> None: assert urls[1].endswith("/hy_am/audio/dev.tar.gz") +def test_create_initial_manifest_stage_is_ray_fanout_stage(tmp_path: Path) -> None: + stage_cls, _ = _import_stage_module() + stage = stage_cls(lang="hy_am", split="dev", raw_data_dir=tmp_path.as_posix()) + + assert stage.ray_stage_spec() == {"is_fanout_stage": True} + + def test_process_transcript_parses_tsv(tmp_path: Path) -> None: stage_cls, _ = _import_stage_module() # Arrange: create fake dev.tsv and expected wav layout diff --git a/tests/stages/audio/io/test_convert.py b/tests/stages/audio/io/test_convert.py index f28a8ca0bb..07d8b8930b 100644 --- a/tests/stages/audio/io/test_convert.py +++ b/tests/stages/audio/io/test_convert.py @@ -18,6 +18,12 @@ from nemo_curator.tasks import AudioBatch +def test_audio_to_document_stage_is_ray_fanout_stage() -> None: + stage = AudioToDocumentStage() + + assert stage.ray_stage_spec() == {"is_fanout_stage": True} + + def test_audio_to_document_stage_converts_batch() -> None: audio = AudioBatch( task_id="t1", diff --git a/tests/stages/common/test_base.py b/tests/stages/common/test_base.py index 2a8c9c5d5a..372f59752a 100644 --- a/tests/stages/common/test_base.py +++ b/tests/stages/common/test_base.py @@ -69,6 +69,27 @@ def process(self, task: MockTask) -> MockTask | list[MockTask]: return task +class ExplicitMaybeFanoutProcessingStage(MaybeFanoutProcessingStage): + """Maybe-fanout stage that opts into Ray fanout explicitly.""" + + name = "ExplicitMaybeFanoutProcessingStage" + + def ray_stage_spec(self) -> dict[str, bool]: + return {"is_fanout_stage": True} + + +class StringAnnotatedFanoutProcessingStage(ProcessingStage[MockTask, MockTask]): + """ProcessingStage with a string return annotation.""" + + name = "StringAnnotatedFanoutProcessingStage" + + def process(self, task: MockTask) -> list[MockTask]: + return [task] + + +StringAnnotatedFanoutProcessingStage.process.__annotations__["return"] = "list[MissingTask]" + + class TestProcessingStageWith: """Test the with_ method for ProcessingStage.""" @@ -296,9 +317,19 @@ def test_ray_stage_spec_detects_fanout_stage(self): assert stage.ray_stage_spec() == {"is_fanout_stage": True} - def test_ray_stage_spec_detects_optional_fanout_stage(self): + def test_ray_stage_spec_does_not_infer_optional_fanout_stage(self): stage = MaybeFanoutProcessingStage() + assert stage.ray_stage_spec() == {} + + def test_ray_stage_spec_allows_optional_fanout_stage_to_opt_in(self): + stage = ExplicitMaybeFanoutProcessingStage() + + assert stage.ray_stage_spec() == {"is_fanout_stage": True} + + def test_ray_stage_spec_detects_string_annotated_fanout_stage(self): + stage = StringAnnotatedFanoutProcessingStage() + assert stage.ray_stage_spec() == {"is_fanout_stage": True} diff --git a/tests/stages/image/io/test_image_reader.py b/tests/stages/image/io/test_image_reader.py index 251f24ccf9..9499e0a07d 100644 --- a/tests/stages/image/io/test_image_reader.py +++ b/tests/stages/image/io/test_image_reader.py @@ -177,6 +177,14 @@ def test_process_raises_on_empty_task() -> None: stage.process(empty) +def test_ray_stage_spec_is_fanout_stage() -> None: + from nemo_curator.stages.image.io.image_reader import ImageReaderStage + + with patch("torch.cuda.is_available", return_value=False): + stage = ImageReaderStage(dali_batch_size=2, verbose=False) + + assert stage.ray_stage_spec() == {"is_fanout_stage": True} + def test_resources_with_cuda_available() -> None: from nemo_curator.stages.image.io.image_reader import ImageReaderStage From ec6a75f7cf9b9cf2c937b3b6a394c587681040c9 Mon Sep 17 00:00:00 2001 From: nightcityblade Date: Fri, 5 Jun 2026 11:08:35 +0800 Subject: [PATCH 3/5] fix: remove redundant fanout overrides Signed-off-by: nightcityblade --- nemo_curator/stages/audio/alm/alm_manifest_reader.py | 4 ---- nemo_curator/stages/deduplication/semantic/pairwise_io.py | 7 ------- nemo_curator/stages/file_partitioning.py | 7 ------- tests/stages/deduplication/semantic/test_pairwise_io.py | 5 +++++ 4 files changed, 5 insertions(+), 18 deletions(-) diff --git a/nemo_curator/stages/audio/alm/alm_manifest_reader.py b/nemo_curator/stages/audio/alm/alm_manifest_reader.py index 16eac99dbd..4be2fdee36 100644 --- a/nemo_curator/stages/audio/alm/alm_manifest_reader.py +++ b/nemo_curator/stages/audio/alm/alm_manifest_reader.py @@ -63,10 +63,6 @@ def process(self, task: FileGroupTask) -> list[AudioBatch]: for entry in entries ] - def ray_stage_spec(self) -> dict[str, Any]: - return {"is_fanout_stage": True} - - @dataclass class ALMManifestReader(CompositeStage[_EmptyTask, AudioBatch]): """Composite stage for reading ALM JSONL manifests. diff --git a/nemo_curator/stages/deduplication/semantic/pairwise_io.py b/nemo_curator/stages/deduplication/semantic/pairwise_io.py index 5dd9dc34a3..22d0ac7ead 100644 --- a/nemo_curator/stages/deduplication/semantic/pairwise_io.py +++ b/nemo_curator/stages/deduplication/semantic/pairwise_io.py @@ -17,7 +17,6 @@ from loguru import logger from nemo_curator.backends.base import WorkerMetadata -from nemo_curator.backends.experimental.utils import RayStageSpecKeys from nemo_curator.stages.base import ProcessingStage from nemo_curator.stages.resources import Resources from nemo_curator.tasks import FileGroupTask, _EmptyTask @@ -65,12 +64,6 @@ def setup(self, _: WorkerMetadata | None = None) -> None: self.fs = get_fs(self.input_path, storage_options=self.storage_options) self.path_normalizer = self.fs.unstrip_protocol if is_remote_url(self.input_path) else (lambda x: x) - def ray_stage_spec(self) -> dict[str, Any]: - """Ray stage specification for this stage.""" - return { - RayStageSpecKeys.IS_FANOUT_STAGE: True, - } - def xenna_stage_spec(self) -> dict[str, Any]: return { "num_workers_per_node": 1, diff --git a/nemo_curator/stages/file_partitioning.py b/nemo_curator/stages/file_partitioning.py index 2829927481..eeb5f3b7b4 100644 --- a/nemo_curator/stages/file_partitioning.py +++ b/nemo_curator/stages/file_partitioning.py @@ -17,7 +17,6 @@ from loguru import logger -from nemo_curator.backends.experimental.utils import RayStageSpecKeys from nemo_curator.stages.base import ProcessingStage from nemo_curator.stages.resources import Resources from nemo_curator.tasks import FileGroupTask, _EmptyTask @@ -85,12 +84,6 @@ def inputs(self) -> tuple[list[str], list[str]]: def outputs(self) -> tuple[list[str], list[str]]: return [], [] - def ray_stage_spec(self) -> dict[str, Any]: - """Ray stage specification for this stage.""" - return { - RayStageSpecKeys.IS_FANOUT_STAGE: True, - } - def xenna_stage_spec(self) -> dict[str, Any]: return { "num_workers_per_node": 1, diff --git a/tests/stages/deduplication/semantic/test_pairwise_io.py b/tests/stages/deduplication/semantic/test_pairwise_io.py index e8d16b1d4d..e6cbe6c747 100644 --- a/tests/stages/deduplication/semantic/test_pairwise_io.py +++ b/tests/stages/deduplication/semantic/test_pairwise_io.py @@ -45,6 +45,11 @@ def test_setup(self): assert stage.path_normalizer is not None assert stage.path_normalizer("/test/path") == "/test/path" + def test_ray_stage_spec_is_fanout_stage(self): + stage = ClusterWiseFilePartitioningStage("/test/path") + + assert stage.ray_stage_spec() == {"is_fanout_stage": True} + def test_process_finds_all_centroid_files(self, tmp_path: Path): """Test that process method finds all files in centroid directories.""" From ee7e80a338c28a80f0d2b2b8cbd6f80c36952e0d Mon Sep 17 00:00:00 2001 From: nightcityblade Date: Sat, 6 Jun 2026 11:24:19 +0800 Subject: [PATCH 4/5] fix: infer clip fanout stage from return type Signed-off-by: nightcityblade --- .../stages/video/clipping/clip_extraction_stages.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/nemo_curator/stages/video/clipping/clip_extraction_stages.py b/nemo_curator/stages/video/clipping/clip_extraction_stages.py index 1bed14b9bf..222e7dc211 100644 --- a/nemo_curator/stages/video/clipping/clip_extraction_stages.py +++ b/nemo_curator/stages/video/clipping/clip_extraction_stages.py @@ -17,13 +17,11 @@ import subprocess import uuid from dataclasses import dataclass -from typing import Any from cosmos_xenna.ray_utils.resources import _get_local_gpu_info, _make_gpu_resources_from_gpu_name from loguru import logger from nemo_curator.backends.base import WorkerMetadata -from nemo_curator.backends.experimental.utils import RayStageSpecKeys from nemo_curator.stages.base import ProcessingStage from nemo_curator.stages.resources import Resources, _get_gpu_memory_gb from nemo_curator.tasks.video import Clip, Video, VideoTask @@ -96,13 +94,7 @@ def inputs(self) -> tuple[list[str], list[str]]: def outputs(self) -> tuple[list[str], list[str]]: return ["data"], [] - def ray_stage_spec(self) -> dict[str, Any]: - """Ray stage specification for this stage.""" - return { - RayStageSpecKeys.IS_FANOUT_STAGE: True, - } - - def process(self, task: VideoTask) -> VideoTask: + def process(self, task: VideoTask) -> list[VideoTask]: video = task.data if not video.clips: From b92b95ffe9ccb3dd051c17630434af1b2071345d Mon Sep 17 00:00:00 2001 From: nightcityblade Date: Mon, 8 Jun 2026 11:13:51 +0800 Subject: [PATCH 5/5] fix: return clip fanout tasks consistently Signed-off-by: nightcityblade --- nemo_curator/stages/video/clipping/clip_extraction_stages.py | 2 +- tests/stages/video/clipping/test_clip_transcoding_stage.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/nemo_curator/stages/video/clipping/clip_extraction_stages.py b/nemo_curator/stages/video/clipping/clip_extraction_stages.py index 222e7dc211..e6f4a2a79b 100644 --- a/nemo_curator/stages/video/clipping/clip_extraction_stages.py +++ b/nemo_curator/stages/video/clipping/clip_extraction_stages.py @@ -100,7 +100,7 @@ def process(self, task: VideoTask) -> list[VideoTask]: if not video.clips: logger.warning(f"No clips to transcode for {video.input_video}. Skipping...") video.source_bytes = None - return task + return [task] with make_pipeline_temporary_dir(sub_dir="transcode") as tmp_dir: # write video to file diff --git a/tests/stages/video/clipping/test_clip_transcoding_stage.py b/tests/stages/video/clipping/test_clip_transcoding_stage.py index bae37a6170..b9c46df588 100644 --- a/tests/stages/video/clipping/test_clip_transcoding_stage.py +++ b/tests/stages/video/clipping/test_clip_transcoding_stage.py @@ -138,7 +138,10 @@ def test_process_no_clips(self) -> None: # Should return early and log warning mock_logger.warning.assert_called_once() assert "No clips to transcode" in mock_logger.warning.call_args[0][0] - assert result.data.source_bytes is None + assert isinstance(result, list) + assert len(result) == 1 + assert result[0] is self.mock_task + assert result[0].data.source_bytes is None @patch("nemo_curator.stages.video.clipping.clip_extraction_stages.make_pipeline_temporary_dir") @patch("nemo_curator.stages.video.clipping.clip_extraction_stages.grouping.split_by_chunk_size")