From bde22175def5a5bdad3f457487d5df9781224417 Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Tue, 9 Jun 2026 14:23:00 -0700 Subject: [PATCH 01/11] basic slurm array file partitioning Signed-off-by: Sarah Yurick --- nemo_curator/stages/file_partitioning.py | 69 +++++++++++++++++++++++- 1 file changed, 68 insertions(+), 1 deletion(-) diff --git a/nemo_curator/stages/file_partitioning.py b/nemo_curator/stages/file_partitioning.py index 53d12ae88e..6e0a73c3d2 100644 --- a/nemo_curator/stages/file_partitioning.py +++ b/nemo_curator/stages/file_partitioning.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import hashlib +import os from dataclasses import dataclass from typing import Any @@ -29,6 +31,24 @@ ) +def _get_int_or_env_var(input_value: int | str | None, default_name: str | None = None) -> int: + if type(input_value) is int: + return input_value + elif type(input_value) is str: + if os.environ.get(input_value) is None: + msg = f"Environment variable {input_value} is not set" + raise ValueError(msg) + return int(os.environ.get(input_value)) + elif default_name is not None: + if os.environ.get(default_name) is None: + msg = f"Environment variable {default_name} is not set" + raise ValueError(msg) + return int(os.environ.get(default_name)) + else: + msg = f"Invalid input value: {input_value}, must be an integer or a string" + raise ValueError(msg) + + @dataclass class FilePartitioningStage(ProcessingStage[_EmptyTask, FileGroupTask]): """Stage that partitions input file paths into FileGroupTasks. @@ -55,6 +75,18 @@ class FilePartitioningStage(ProcessingStage[_EmptyTask, FileGroupTask]): Storage options to pass to the file system. limit: int | None = None Maximum number of partitions to create. + enable_array_partitioning: bool = False + Whether to enable array partitioning (e.g., partition files across multiple Slurm jobs). + Intended for use with Slurm job arrays via the `sbatch --array` option. + shard_index: int | str | None = None + The index of the shard to process. Can be an integer representing the shard index or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_ID environment variable. + total_shards: int | str | None = None + The total number of shards. Can be an integer representing the total number of shards or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_COUNT environment variable. + minimum_shard_index: int = 0 + The minimum shard index to process. Can be an integer representing the minimum shard index or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to 0. """ file_paths: str | list[str] @@ -63,6 +95,10 @@ class FilePartitioningStage(ProcessingStage[_EmptyTask, FileGroupTask]): file_extensions: list[str] | None = None storage_options: dict[str, Any] | None = None limit: int | None = None + enable_array_partitioning: bool = False + shard_index: int | str | None = None + total_shards: int | str | None = None + minimum_shard_index: int | str = 0 name: str = "file_partitioning" def __post_init__(self): @@ -91,6 +127,12 @@ def __post_init__(self): self.resources = Resources(cpus=0.5) + if self.enable_array_partitioning: + self.shard_index = _get_int_or_env_var(self.shard_index, "SLURM_ARRAY_TASK_ID") + self.total_shards = _get_int_or_env_var(self.total_shards, "SLURM_ARRAY_TASK_COUNT") + self.minimum_shard_index = _get_int_or_env_var(self.minimum_shard_index) + self.name = "array_file_partitioning" + def inputs(self) -> tuple[list[str], list[str]]: return [], [] @@ -106,7 +148,7 @@ def ray_stage_spec(self) -> dict[str, Any]: def xenna_stage_spec(self) -> dict[str, Any]: return {"num_workers_per_node": 1} - def process(self, _: _EmptyTask) -> list[FileGroupTask]: + def _process(self, _: _EmptyTask) -> list[FileGroupTask]: """Process the initial task to create file group tasks. This stage expects a simple Task with file paths information @@ -189,6 +231,31 @@ def process(self, _: _EmptyTask) -> list[FileGroupTask]: logger.info(f"Created {len(tasks)} file groups from {len(files)} files") return tasks + def _process_array(self, task: _EmptyTask) -> list[FileGroupTask]: + all_tasks = self._process(task) + assigned_tasks = [] + + for ft in all_tasks: + source_files = list(ft._metadata.get("source_files") or ft.data) + # Hash the source files to get a unique identifier for the partition + digest = hashlib.sha256("|".join(sorted(source_files)).encode("utf-8")).hexdigest() + # Assign the partition to the shard + assigned = int(digest[:16], 16) % self.total_shards + # Add the minimum shard index to the assigned shard index + assigned += self.minimum_shard_index + # Add the partition to the assigned tasks + if assigned == self.shard_index: + assigned_tasks.append(ft) + + logger.info(f"Shard {self.shard_index}/{self.total_shards}: assigned {len(assigned_tasks)} of {len(all_tasks)} partitions") + return assigned_tasks + + def process(self, task: _EmptyTask) -> list[FileGroupTask]: + if self.enable_array_partitioning: + return self._process_array(task) + else: + return self._process(task) + def _get_file_list_with_sizes(self, sort_by_size: bool = True) -> list[tuple[str, int]]: """ Get the list of files to process. From a0595f6692b89a0d659d32be7a49d1429f3c3462 Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Tue, 9 Jun 2026 14:54:34 -0700 Subject: [PATCH 02/11] add slurm array params to composite stages using filepartitioningstage Signed-off-by: Sarah Yurick --- nemo_curator/stages/audio/common.py | 15 ++++++++++++++ nemo_curator/stages/interleaved/io/reader.py | 16 +++++++++++++++ .../text/deduplication/removal_workflow.py | 2 +- .../stages/text/deduplication/semantic.py | 19 ++++++++++++++++++ nemo_curator/stages/text/io/reader/jsonl.py | 8 ++++++++ nemo_curator/stages/text/io/reader/parquet.py | 8 ++++++++ nemo_curator/stages/video/io/video_reader.py | 20 ++++++++++++++++++- 7 files changed, 86 insertions(+), 2 deletions(-) diff --git a/nemo_curator/stages/audio/common.py b/nemo_curator/stages/audio/common.py index 6c7c8eb837..3f0f0a7f78 100644 --- a/nemo_curator/stages/audio/common.py +++ b/nemo_curator/stages/audio/common.py @@ -192,6 +192,13 @@ class ManifestReader(CompositeStage[_EmptyTask, AudioTask]): blocksize: Target size per partition (e.g., "100MB"). Ignored if files_per_partition is set. file_extensions: File extensions to filter. Defaults to [".jsonl", ".json"]. storage_options: Storage options for cloud paths (S3, GCS credentials, endpoints). + enable_array_partitioning: Whether to enable array partitioning (e.g., partition files across multiple Slurm jobs). + shard_index: The index of the shard to process. Can be an integer representing the shard index or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_ID environment variable. + total_shards: The total number of shards. Can be an integer representing the total number of shards or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_COUNT environment variable. + minimum_shard_index: The minimum shard index to process. Can be an integer representing the minimum shard index or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to 0. """ manifest_path: str | list[str] @@ -200,6 +207,10 @@ class ManifestReader(CompositeStage[_EmptyTask, AudioTask]): blocksize: int | str | None = None file_extensions: list[str] = field(default_factory=lambda: [".jsonl", ".json"]) storage_options: dict[str, Any] | None = None + enable_array_partitioning: bool = False + shard_index: int | str | None = None + total_shards: int | str | None = None + minimum_shard_index: int | str = 0 def __post_init__(self) -> None: super().__init__() @@ -215,6 +226,10 @@ def decompose(self) -> list[ProcessingStage]: blocksize=self.blocksize, file_extensions=self.file_extensions, storage_options=self.storage_options, + enable_array_partitioning=self.enable_array_partitioning, + shard_index=self.shard_index, + total_shards=self.total_shards, + minimum_shard_index=self.minimum_shard_index, ), ManifestReaderStage(), ] diff --git a/nemo_curator/stages/interleaved/io/reader.py b/nemo_curator/stages/interleaved/io/reader.py index 404c52fc16..9b19d4a513 100644 --- a/nemo_curator/stages/interleaved/io/reader.py +++ b/nemo_curator/stages/interleaved/io/reader.py @@ -43,6 +43,10 @@ class InterleavedWebdatasetReader(CompositeStage[_EmptyTask, InterleavedBatch]): blocksize: int | str | None = None max_batch_bytes: int | None = None read_kwargs: dict[str, Any] = field(default_factory=dict) + enable_array_partitioning: bool = False + shard_index: int | str | None = None + total_shards: int | str | None = None + minimum_shard_index: int | str = 0 materialize_on_read: bool = False file_extensions: list[str] = field(default_factory=lambda: list(DEFAULT_WEBDATASET_EXTENSIONS)) json_extensions: list[str] = field(default_factory=lambda: list(DEFAULT_JSON_EXTENSIONS)) @@ -68,6 +72,10 @@ def decompose(self) -> list: blocksize=self.blocksize, file_extensions=self.file_extensions, storage_options=self.storage_options, + enable_array_partitioning=self.enable_array_partitioning, + shard_index=self.shard_index, + total_shards=self.total_shards, + minimum_shard_index=self.minimum_shard_index, ), InterleavedWebdatasetReaderStage( read_kwargs=self.read_kwargs, @@ -96,6 +104,10 @@ class InterleavedParquetReader(CompositeStage[_EmptyTask, InterleavedBatch]): fields: tuple[str, ...] | None = None max_batch_bytes: int | None = None read_kwargs: dict[str, Any] = field(default_factory=dict) + enable_array_partitioning: bool = False + shard_index: int | str | None = None + total_shards: int | str | None = None + minimum_shard_index: int | str = 0 schema: pa.Schema | None = None schema_overrides: dict[str, pa.DataType] | None = None file_extensions: list[str] = field(default_factory=lambda: [".parquet"]) @@ -113,6 +125,10 @@ def decompose(self) -> list: blocksize=self.blocksize, file_extensions=self.file_extensions, storage_options=self.storage_options, + enable_array_partitioning=self.enable_array_partitioning, + shard_index=self.shard_index, + total_shards=self.total_shards, + minimum_shard_index=self.minimum_shard_index, ), InterleavedParquetReaderStage( read_kwargs=self.read_kwargs, diff --git a/nemo_curator/stages/text/deduplication/removal_workflow.py b/nemo_curator/stages/text/deduplication/removal_workflow.py index 078b2014d3..af6fd260bd 100644 --- a/nemo_curator/stages/text/deduplication/removal_workflow.py +++ b/nemo_curator/stages/text/deduplication/removal_workflow.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nemo_curator/stages/text/deduplication/semantic.py b/nemo_curator/stages/text/deduplication/semantic.py index 34eaa7f651..11b67b2a4d 100644 --- a/nemo_curator/stages/text/deduplication/semantic.py +++ b/nemo_curator/stages/text/deduplication/semantic.py @@ -76,6 +76,10 @@ class TextSemanticDeduplicationWorkflow: embedding_vllm_init_kwargs: dict[str, Any] | None = None hf_token: str | None = None model_cache_dir: str | None = None + enable_array_partitioning: bool = False + shard_index: int | str | None = None + total_shards: int | str | None = None + minimum_shard_index: int | str = 0 # Semantic deduplication parameters n_clusters: int = 100 id_field: str = CURATOR_DEDUP_ID_STR @@ -132,6 +136,13 @@ class TextSemanticDeduplicationWorkflow: embedding_vllm_init_kwargs: Additional kwargs passed to vLLM's LLM initializer hf_token: HuggingFace token for private models model_cache_dir: Directory to cache model weights + enable_array_partitioning: Whether to enable array partitioning (e.g., partition files across multiple Slurm jobs). + shard_index: The index of the shard to process. Can be an integer representing the shard index or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_ID environment variable. + total_shards: The total number of shards. Can be an integer representing the total number of shards or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_COUNT environment variable. + minimum_shard_index: The minimum shard index to process. Can be an integer representing the minimum shard index or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to 0. # Semantic deduplication parameters n_clusters: Number of clusters for K-means @@ -252,6 +263,10 @@ def _run_embedding_generation(self, executor: BaseExecutor) -> list[Task]: file_extensions=self.input_file_extensions, _generate_ids=self.use_id_generator, read_kwargs=self.read_kwargs, + enable_array_partitioning=self.enable_array_partitioning, + shard_index=self.shard_index, + total_shards=self.total_shards, + minimum_shard_index=self.minimum_shard_index, ) elif self.input_filetype == "parquet": reader = ParquetReader( @@ -266,6 +281,10 @@ def _run_embedding_generation(self, executor: BaseExecutor) -> list[Task]: file_extensions=self.input_file_extensions, read_kwargs=self.read_kwargs, _generate_ids=self.use_id_generator, + enable_array_partitioning=self.enable_array_partitioning, + shard_index=self.shard_index, + total_shards=self.total_shards, + minimum_shard_index=self.minimum_shard_index, ) else: msg = f"Input filetype {self.input_filetype} not supported yet" diff --git a/nemo_curator/stages/text/io/reader/jsonl.py b/nemo_curator/stages/text/io/reader/jsonl.py index 24bb3dd0cf..de10064b04 100644 --- a/nemo_curator/stages/text/io/reader/jsonl.py +++ b/nemo_curator/stages/text/io/reader/jsonl.py @@ -94,6 +94,10 @@ class JsonlReader(CompositeStage[_EmptyTask, DocumentBatch]): blocksize: int | str | None = None fields: list[str] | None = None # If specified, only read these columns read_kwargs: dict[str, Any] | None = None + enable_array_partitioning: bool = False + shard_index: int | str | None = None + total_shards: int | str | None = None + minimum_shard_index: int | str = 0 task_type: Literal["document", "image", "video", "audio"] = "document" file_extensions: list[str] = field(default_factory=lambda: FILETYPE_TO_DEFAULT_EXTENSIONS["jsonl"]) _generate_ids: bool = False @@ -121,6 +125,10 @@ def decompose(self) -> list[JsonlReaderStage]: storage_options=self.read_kwargs.get("storage_options", None) if self.read_kwargs is not None else None, + enable_array_partitioning=self.enable_array_partitioning, + shard_index=self.shard_index, + total_shards=self.total_shards, + minimum_shard_index=self.minimum_shard_index, ), JsonlReaderStage( fields=self.fields, diff --git a/nemo_curator/stages/text/io/reader/parquet.py b/nemo_curator/stages/text/io/reader/parquet.py index cc997956b2..c6ca45bcfb 100644 --- a/nemo_curator/stages/text/io/reader/parquet.py +++ b/nemo_curator/stages/text/io/reader/parquet.py @@ -79,6 +79,10 @@ class ParquetReader(CompositeStage[_EmptyTask, DocumentBatch]): blocksize: int | str | None = None fields: list[str] | None = None # If specified, only read these columns read_kwargs: dict[str, Any] | None = None + enable_array_partitioning: bool = False + shard_index: int | str | None = None + total_shards: int | str | None = None + minimum_shard_index: int | str = 0 file_extensions: list[str] = field(default_factory=lambda: FILETYPE_TO_DEFAULT_EXTENSIONS["parquet"]) task_type: Literal["document", "image", "video", "audio"] = "document" _generate_ids: bool = False @@ -105,6 +109,10 @@ def decompose(self) -> list[ParquetReaderStage]: blocksize=self.blocksize, file_extensions=self.file_extensions, storage_options=self.read_kwargs.get("storage_options", {}) if self.read_kwargs is not None else None, + enable_array_partitioning=self.enable_array_partitioning, + shard_index=self.shard_index, + total_shards=self.total_shards, + minimum_shard_index=self.minimum_shard_index, ), # Second stage: process file groups into document batches ParquetReaderStage( diff --git a/nemo_curator/stages/video/io/video_reader.py b/nemo_curator/stages/video/io/video_reader.py index 9200e696f3..8c2c8a9893 100644 --- a/nemo_curator/stages/video/io/video_reader.py +++ b/nemo_curator/stages/video/io/video_reader.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -244,11 +244,22 @@ class VideoReader(CompositeStage[_EmptyTask, VideoTask]): input_video_path: Path to the directory containing video files video_limit: Maximum number of videos to process (None for unlimited) verbose: Whether to enable verbose logging during download/processing + enable_array_partitioning: Whether to enable array partitioning (e.g., partition files across multiple Slurm jobs). + shard_index: The index of the shard to process. Can be an integer representing the shard index or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_ID environment variable. + total_shards: The total number of shards. Can be an integer representing the total number of shards or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_COUNT environment variable. + minimum_shard_index: The minimum shard index to process. Can be an integer representing the minimum shard index or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to 0. """ input_video_path: str video_limit: int | None = None verbose: bool = False + enable_array_partitioning: bool = False + shard_index: int | str | None = None + total_shards: int | str | None = None + minimum_shard_index: int | str = 0 def __post_init__(self): """Initialize the parent CompositeStage after dataclass initialization.""" @@ -276,6 +287,9 @@ def decompose(self) -> list[ProcessingStage]: List of processing stages: [FilePartitioningStage, VideoReaderStage] """ if is_remote_url(self.input_video_path): + if self.enable_array_partitioning: + msg = "enable_array_partitioning is not supported for ClientPartitioningStage" + raise NotImplementedError(msg) reader_stage = ClientPartitioningStage( file_paths=self.input_video_path, files_per_partition=1, @@ -288,6 +302,10 @@ def decompose(self) -> list[ProcessingStage]: files_per_partition=1, file_extensions=[".mp4", ".mov", ".avi", ".mkv", ".webm"], limit=self.video_limit, + enable_array_partitioning=self.enable_array_partitioning, + shard_index=self.shard_index, + total_shards=self.total_shards, + minimum_shard_index=self.minimum_shard_index, ) download_stage = VideoReaderStage(input_path=self.input_video_path, verbose=self.verbose) From 43ee179d206aa2966a67cf07dd46f3659ffb87ec Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Thu, 11 Jun 2026 10:23:54 -0700 Subject: [PATCH 03/11] add tutorial and tests Signed-off-by: Sarah Yurick --- nemo_curator/stages/file_partitioning.py | 10 +- tests/stages/common/test_file_partitioning.py | 147 ++++++- tutorials/slurm/README.md | 125 ++++++ tutorials/slurm/array_pipeline.py | 406 ++++++++++++++++++ tutorials/slurm/submit_array.sh | 175 ++++++++ 5 files changed, 847 insertions(+), 16 deletions(-) create mode 100644 tutorials/slurm/array_pipeline.py create mode 100644 tutorials/slurm/submit_array.sh diff --git a/nemo_curator/stages/file_partitioning.py b/nemo_curator/stages/file_partitioning.py index 6e0a73c3d2..b7513f02e2 100644 --- a/nemo_curator/stages/file_partitioning.py +++ b/nemo_curator/stages/file_partitioning.py @@ -22,7 +22,7 @@ from nemo_curator.backends.utils import RayStageSpecKeys from nemo_curator.stages.base import ProcessingStage from nemo_curator.stages.resources import Resources -from nemo_curator.tasks import FileGroupTask, _EmptyTask +from nemo_curator.tasks import EmptyTask, FileGroupTask from nemo_curator.utils.file_utils import ( _split_files_as_per_blocksize, get_all_file_paths_and_size_under, @@ -50,7 +50,7 @@ def _get_int_or_env_var(input_value: int | str | None, default_name: str | None @dataclass -class FilePartitioningStage(ProcessingStage[_EmptyTask, FileGroupTask]): +class FilePartitioningStage(ProcessingStage[EmptyTask, FileGroupTask]): """Stage that partitions input file paths into FileGroupTasks. This stage runs as a dedicated processing stage (not on the driver) @@ -148,7 +148,7 @@ def ray_stage_spec(self) -> dict[str, Any]: def xenna_stage_spec(self) -> dict[str, Any]: return {"num_workers_per_node": 1} - def _process(self, _: _EmptyTask) -> list[FileGroupTask]: + def _process(self, _: EmptyTask) -> list[FileGroupTask]: """Process the initial task to create file group tasks. This stage expects a simple Task with file paths information @@ -231,7 +231,7 @@ def _process(self, _: _EmptyTask) -> list[FileGroupTask]: logger.info(f"Created {len(tasks)} file groups from {len(files)} files") return tasks - def _process_array(self, task: _EmptyTask) -> list[FileGroupTask]: + def _process_array(self, task: EmptyTask) -> list[FileGroupTask]: all_tasks = self._process(task) assigned_tasks = [] @@ -250,7 +250,7 @@ def _process_array(self, task: _EmptyTask) -> list[FileGroupTask]: logger.info(f"Shard {self.shard_index}/{self.total_shards}: assigned {len(assigned_tasks)} of {len(all_tasks)} partitions") return assigned_tasks - def process(self, task: _EmptyTask) -> list[FileGroupTask]: + def process(self, task: EmptyTask) -> list[FileGroupTask]: if self.enable_array_partitioning: return self._process_array(task) else: diff --git a/tests/stages/common/test_file_partitioning.py b/tests/stages/common/test_file_partitioning.py index d67de158c9..2ec39853de 100644 --- a/tests/stages/common/test_file_partitioning.py +++ b/tests/stages/common/test_file_partitioning.py @@ -18,7 +18,7 @@ import pytest from nemo_curator.stages.file_partitioning import FilePartitioningStage -from nemo_curator.tasks import FileGroupTask, _EmptyTask +from nemo_curator.tasks import EmptyTask, FileGroupTask def _create_test_jsonl_files(base_dir: Path | str, num_files: int, subdir: str | None = None) -> list[str]: @@ -48,9 +48,9 @@ def temp_files(self, tmp_path: Path) -> list[str]: return files @pytest.fixture - def empty_task(self) -> _EmptyTask: + def empty_task(self) -> EmptyTask: """Create an empty task for testing.""" - return _EmptyTask( + return EmptyTask( dataset_name="test_dataset", data=None, _metadata={"source": "test"}, @@ -122,7 +122,7 @@ def test_ray_stage_spec(self): spec = stage.ray_stage_spec() assert spec["is_fanout_stage"] is True - def test_process_with_file_list(self, empty_task: _EmptyTask, tmp_path: Path): + def test_process_with_file_list(self, empty_task: EmptyTask, tmp_path: Path): """Test processing with a list of files.""" # Create these files in the tmp_path: test_files = _create_test_jsonl_files(tmp_path, num_files=3, subdir="path") @@ -135,7 +135,7 @@ def test_process_with_file_list(self, empty_task: _EmptyTask, tmp_path: Path): assert result[0].data == [test_files[0]] assert result[0].dataset_name == "path" - def test_process_with_files_per_partition(self, empty_task: _EmptyTask, tmp_path: Path): + def test_process_with_files_per_partition(self, empty_task: EmptyTask, tmp_path: Path): """Test processing with files_per_partition setting.""" test_files = _create_test_jsonl_files(tmp_path, num_files=4, subdir="path") stage = FilePartitioningStage(file_paths=test_files, files_per_partition=2) @@ -146,7 +146,7 @@ def test_process_with_files_per_partition(self, empty_task: _EmptyTask, tmp_path assert result[0].data == test_files[:2] assert result[1].data == test_files[2:] - def test_process_with_limit(self, empty_task: _EmptyTask, tmp_path: Path): + def test_process_with_limit(self, empty_task: EmptyTask, tmp_path: Path): """Test processing with limit parameter - this is the main test for the limit functionality.""" test_files = _create_test_jsonl_files(tmp_path, num_files=10, subdir="path") stage = FilePartitioningStage( @@ -168,7 +168,7 @@ def test_process_with_limit(self, empty_task: _EmptyTask, tmp_path: Path): assert task._metadata["partition_index"] == i assert task._metadata["total_partitions"] == 5 # Total partitions before limit - def test_process_with_limit_single_partition(self, empty_task: _EmptyTask, tmp_path: Path): + def test_process_with_limit_single_partition(self, empty_task: EmptyTask, tmp_path: Path): """Test limit when all files would be in a single partition.""" test_files = _create_test_jsonl_files(tmp_path, num_files=5, subdir="path") stage = FilePartitioningStage( @@ -180,7 +180,7 @@ def test_process_with_limit_single_partition(self, empty_task: _EmptyTask, tmp_p assert len(result) == 1 assert result[0].data == [test_files[0]] - def test_process_with_limit_zero(self, empty_task: _EmptyTask, tmp_path: Path): + def test_process_with_limit_zero(self, empty_task: EmptyTask, tmp_path: Path): """Test processing with limit set to 0.""" test_files = _create_test_jsonl_files(tmp_path, num_files=5, subdir="path") stage = FilePartitioningStage( @@ -193,7 +193,7 @@ def test_process_with_limit_zero(self, empty_task: _EmptyTask, tmp_path: Path): assert len(result) == 0 - def test_process_with_blocksize(self, empty_task: _EmptyTask, tmp_path: Path): + def test_process_with_blocksize(self, empty_task: EmptyTask, tmp_path: Path): """Test processing with blocksize setting.""" test_files = _create_test_jsonl_files(tmp_path, num_files=6) # Test files are 3 bytes each, so blocksize of 3B should create 6 partitions @@ -223,7 +223,7 @@ def test_both_blocksize_and_files_per_partition_errors(self): blocksize="128MB", ) - def test_process_empty_file_list(self, empty_task: _EmptyTask): + def test_process_empty_file_list(self, empty_task: EmptyTask): """Test processing with empty file list.""" stage = FilePartitioningStage(file_paths=[]) @@ -256,7 +256,7 @@ def test_partition_by_count(self): assert partitions[1] == ["file3", "file4"] assert partitions[2] == ["file5"] - def test_task_metadata(self, empty_task: _EmptyTask, tmp_path: Path): + def test_task_metadata(self, empty_task: EmptyTask, tmp_path: Path): """Test that created tasks have proper metadata.""" test_files = _create_test_jsonl_files(tmp_path, num_files=2, subdir="path") storage_options = {"option1": "value1"} @@ -271,3 +271,128 @@ def test_task_metadata(self, empty_task: _EmptyTask, tmp_path: Path): assert task._metadata["total_partitions"] == 2 assert task._metadata["source_files"] == [test_files[0]] assert task.reader_config == {} + + def test_enable_array_partitioning_with_explicit_values(self, monkeypatch: pytest.MonkeyPatch): + """Test array partitioning initialization with explicit shard values.""" + monkeypatch.setenv("SLURM_ARRAY_TASK_ID", "7") + monkeypatch.setenv("SLURM_ARRAY_TASK_COUNT", "11") + + stage = FilePartitioningStage( + file_paths="/test/path", + enable_array_partitioning=True, + shard_index=2, + total_shards=10, + minimum_shard_index=1, + ) + + assert stage.name == "array_file_partitioning" + assert stage.shard_index == 2 + assert stage.total_shards == 10 + assert stage.minimum_shard_index == 1 + + def test_enable_array_partitioning_reads_slurm_env_vars( + self, + monkeypatch: pytest.MonkeyPatch, + ): + """Test that array partitioning defaults to Slurm array env vars.""" + monkeypatch.setenv("SLURM_ARRAY_TASK_ID", "7") + monkeypatch.setenv("SLURM_ARRAY_TASK_COUNT", "11") + + stage = FilePartitioningStage( + file_paths="/test/path", + enable_array_partitioning=True, + ) + + assert stage.shard_index == 7 + assert stage.total_shards == 11 + assert stage.minimum_shard_index == 0 + + def test_enable_array_partitioning_supports_custom_env_var_names( + self, + monkeypatch: pytest.MonkeyPatch, + ): + """Test that shard parameters can be read from custom env var names.""" + monkeypatch.setenv("CUSTOM_SHARD_INDEX", "3") + monkeypatch.setenv("CUSTOM_TOTAL_SHARDS", "8") + monkeypatch.setenv("CUSTOM_MINIMUM_SHARD_INDEX", "1") + + stage = FilePartitioningStage( + file_paths="/test/path", + enable_array_partitioning=True, + shard_index="CUSTOM_SHARD_INDEX", + total_shards="CUSTOM_TOTAL_SHARDS", + minimum_shard_index="CUSTOM_MINIMUM_SHARD_INDEX", + ) + + assert stage.shard_index == 3 + assert stage.total_shards == 8 + assert stage.minimum_shard_index == 1 + + def test_enable_array_partitioning_requires_slurm_env_vars_by_default( + self, + monkeypatch: pytest.MonkeyPatch, + ): + """Test that missing default Slurm env vars raise a clear error.""" + monkeypatch.delenv("SLURM_ARRAY_TASK_ID", raising=False) + monkeypatch.delenv("SLURM_ARRAY_TASK_COUNT", raising=False) + + with pytest.raises(ValueError, match="SLURM_ARRAY_TASK_ID"): + FilePartitioningStage( + file_paths="/test/path", + enable_array_partitioning=True, + ) + + def test_enable_array_partitioning_assigns_each_partition_to_one_shard( + self, + empty_task: EmptyTask, + tmp_path: Path, + ): + """Test that array partitioning covers all partitions exactly once.""" + test_files = _create_test_jsonl_files(tmp_path, num_files=8, subdir="path") + expected_partitions = { + tuple(test_files[i : i + 2]) + for i in range(0, len(test_files), 2) + } + assigned_partitions = [] + + for shard_index in range(3): + stage = FilePartitioningStage( + file_paths=test_files, + files_per_partition=2, + enable_array_partitioning=True, + shard_index=shard_index, + total_shards=3, + ) + + assigned_partitions.extend(tuple(task.data) for task in stage.process(empty_task)) + + assert set(assigned_partitions) == expected_partitions + assert len(assigned_partitions) == len(expected_partitions) + + def test_enable_array_partitioning_supports_minimum_shard_index( + self, + empty_task: EmptyTask, + tmp_path: Path, + ): + """Test non-zero Slurm arrays by offsetting hash-assigned shard IDs.""" + test_files = _create_test_jsonl_files(tmp_path, num_files=8, subdir="path") + zero_indexed_stage = FilePartitioningStage( + file_paths=test_files, + files_per_partition=2, + enable_array_partitioning=True, + shard_index=0, + total_shards=3, + ) + one_indexed_stage = FilePartitioningStage( + file_paths=test_files, + files_per_partition=2, + enable_array_partitioning=True, + shard_index=1, + total_shards=3, + minimum_shard_index=1, + ) + + zero_indexed_result = [task.data for task in zero_indexed_stage.process(empty_task)] + one_indexed_result = [task.data for task in one_indexed_stage.process(empty_task)] + + assert one_indexed_result == zero_indexed_result diff --git a/tutorials/slurm/README.md b/tutorials/slurm/README.md index 05565de2a0..038b2a4756 100644 --- a/tutorials/slurm/README.md +++ b/tutorials/slurm/README.md @@ -9,6 +9,8 @@ This tutorial shows how to scale a NeMo Curator pipeline from a single laptop to | `pipeline.py` | A simple CPU-only pipeline (word-count + node-tag) that runs locally or on SLURM | | `submit.sh` | `sbatch` script for bare-metal clusters with a shared virtualenv | | `submit_container.sh` | `sbatch` script using the official NGC container (Pyxis/enroot) | +| `array_pipeline.py` | Generic JSONL/Parquet pipeline that processes one Slurm array shard | +| `submit_array.sh` | `sbatch --array` script for splitting many input files across independent jobs | --- @@ -184,6 +186,129 @@ tail -f logs/slurm_demo_.log --- +## SLURM job arrays — JSONL or Parquet file sharding + +Use `submit_array.sh` when you already have a large directory of text data files and want to split the file set across many independent Slurm jobs. Each array task starts its own Curator pipeline, hashes the input file partitions deterministically, and processes only the partitions assigned to that task. + +This pattern is useful when the dataset is naturally represented as many JSONL or Parquet files and you want simple horizontal scaling without coordination between jobs. + +### 1. Build the virtualenv on a shared filesystem + +The array example uses the official NGC container for the base environment, then activates your local checkout inside the container so unreleased source changes are picked up: + +```bash +cd /path/to/Curator +python -m venv .venv +source .venv/bin/activate +pip install -e . +``` + +Make sure `CURATOR_DIR`, `INPUT_DIR`, and `OUTPUT_DIR` are visible from every compute node, either because they are on a shared filesystem or because you set `CONTAINER_MOUNTS` to expose the right host paths inside the container. + +### 2. Submit a JSONL array job + +By default, `submit_array.sh` reads JSONL files and writes JSONL output: + +```bash +export CURATOR_DIR=/path/to/Curator +export INPUT_DIR=/shared/data/my-jsonl-dataset +export OUTPUT_DIR=/shared/output/my-jsonl-dataset + +# 20 array tasks, task IDs 0-19 +sbatch --array=0-19 tutorials/slurm/submit_array.sh +``` + +For example, if the input directory contains 2000 files and `FILES_PER_PARTITION=1`, each of the 20 array tasks receives roughly 100 file partitions. Assignment is hash-based rather than contiguous, so work remains stable if Slurm retries a task. + +Single-node array tasks use `RayClient`. If you override the allocation to use more than one node per array task, `submit_array.sh` automatically passes `--slurm` to `array_pipeline.py`, which switches that task to `SlurmRayClient` so the nodes form one Ray cluster: + +```bash +sbatch --array=0-9 --nodes=2 --cpus-per-task=32 tutorials/slurm/submit_array.sh +``` + +### 3. Use Parquet instead + +Set the input and output file types to `parquet`: + +```bash +export INPUT_DIR=/shared/data/my-parquet-dataset +export OUTPUT_DIR=/shared/output/my-parquet-dataset +export INPUT_FILE_TYPE=parquet +export OUTPUT_FILE_TYPE=parquet + +sbatch --array=0-19 tutorials/slurm/submit_array.sh +``` + +### 4. Edit sharding logic + +If your array does not start at zero, set `MINIMUM_SHARD_INDEX` to the first task ID: + +```bash +MINIMUM_SHARD_INDEX=1 sbatch --array=1-20 tutorials/slurm/submit_array.sh +``` + +If your cluster limits the number of tasks in a single Slurm array, you can still use a larger logical shard count by overriding `TOTAL_SHARDS` and submitting the shard ID range in multiple windows. For example, if you want 10,000 logical shards but the cluster allows only 1,000 array tasks per submission: + +```bash +export TOTAL_SHARDS=10000 + +sbatch --array=0-999 tutorials/slurm/submit_array.sh +sbatch --array=1000-1999 tutorials/slurm/submit_array.sh +sbatch --array=2000-2999 tutorials/slurm/submit_array.sh +# ... +sbatch --array=9000-9999 tutorials/slurm/submit_array.sh +``` + +In this mode, keep `MINIMUM_SHARD_INDEX=0` because the Slurm array task IDs are already the global shard IDs. Each partition is assigned by `hash(partition) % TOTAL_SHARDS`, so the full set of windowed submissions covers shards `0` through `9999` exactly once. Some individual tasks may receive no files if `TOTAL_SHARDS` is larger than the number of file partitions. + +Some clusters enforce the maximum array index rather than just the number of tasks per submitted array. If `--array=1000-1999` is rejected, this windowing pattern needs an explicit shard-index offset in the submission script rather than higher Slurm task IDs. + +### 5. Retry failed array tasks only + +When `submit_array.sh` launches `array_pipeline.py`, it passes `--checkpoint-path`. At startup, the driver process for each array task creates a pending retry manifest under: + +```bash +${CHECKPOINT_PATH:-$OUTPUT_DIR}/.nemo_curator_metadata/.slurm_array_retry/ +``` + +In other words, retries are tracked at `checkpoint_path/.nemo_curator_metadata/.slurm_array_retry/`. + +If the shard completes successfully, that shard's matching retry manifests are removed. If the process fails, is preempted, or reaches the Slurm time limit before cleanup runs, the manifest remains in the retry directory. Caught Python exceptions update the manifest with `status="failed"` and the error message; hard termination may leave `status="pending"`, which should still be treated as retryable after the original Slurm array has finished. + +Retry manifests are uniquely named JSON files written with an atomic rename, so multiple array tasks can write to the same retry directory without coordinating through a shared database. + +Each manifest records the failed `shard_index`, plus the `total_shards` and `minimum_shard_index` values used for the original run. To retry only the failed shards, rebuild a Slurm array list from those manifests and preserve the original shard settings. For example, using `jq`: + +```bash +export CHECKPOINT_PATH="${CHECKPOINT_PATH:-$OUTPUT_DIR}" +RETRY_DIR="${CHECKPOINT_PATH}/.nemo_curator_metadata/.slurm_array_retry" + +FAILED_SHARDS=$(jq -r '.shard_index' "${RETRY_DIR}"/manifest_*.json | sort -n -u | paste -sd, -) +TOTAL_SHARDS_VALUES=$(jq -r '.total_shards' "${RETRY_DIR}"/manifest_*.json | sort -n -u) +MINIMUM_SHARD_INDEX_VALUES=$(jq -r '.minimum_shard_index' "${RETRY_DIR}"/manifest_*.json | sort -n -u) + +if [[ -z "${FAILED_SHARDS}" ]]; then + echo "No failed shards found in ${RETRY_DIR}" >&2 + exit 1 +fi + +if [[ "${TOTAL_SHARDS_VALUES}" == *$'\n'* || "${MINIMUM_SHARD_INDEX_VALUES}" == *$'\n'* ]]; then + echo "Retry manifests contain multiple shard configurations; split them by run." >&2 + exit 1 +fi + +export TOTAL_SHARDS="${TOTAL_SHARDS_VALUES}" +export MINIMUM_SHARD_INDEX="${MINIMUM_SHARD_INDEX_VALUES}" + +sbatch --array="${FAILED_SHARDS}" tutorials/slurm/submit_array.sh +``` + +The `TOTAL_SHARDS` override is important. On a retry array like `--array=3,17,42`, Slurm sets `SLURM_ARRAY_TASK_COUNT=3`, but the data was originally assigned using the full logical shard count. Reusing the original `TOTAL_SHARDS` keeps `hash(partition) % total_shards` identical to the first run. + +Run this retry collection after the original Slurm array has finished, otherwise still-running tasks will still have pending manifests. Use one `CHECKPOINT_PATH` per logical array run, or move old retry manifests aside after building `FAILED_SHARDS`, so later retries do not include failures that already succeeded. + +--- + ## Configuration reference ### SlurmRayClient parameters diff --git a/tutorials/slurm/array_pipeline.py b/tutorials/slurm/array_pipeline.py new file mode 100644 index 0000000000..b3bc5a5dc5 --- /dev/null +++ b/tutorials/slurm/array_pipeline.py @@ -0,0 +1,406 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Slurm array tutorial: split a large file collection across array tasks. + +Each Slurm array task processes its own slice of the input files. +With 2000 JSONL or Parquet files and --array=0-19, each of the 20 jobs gets +~100 files. + +Array partitioning parameters +------------------------------ +shard_index Which shard this job processes. + Default: SLURM_ARRAY_TASK_ID env var. +total_shards Total number of shards (i.e. array width). + Default: SLURM_ARRAY_TASK_COUNT env var. +minimum_shard_index Offset added to the hash-assigned shard before + comparing with shard_index. Use when the array does + not start at 0. E.g. --array=1-20 requires + minimum_shard_index=1 so shard IDs 1-20 match task IDs 1-20. + Default: 0. No env var fallback — must be set explicitly. + +Usage (local smoke test against a small sample directory):: + + # Simulate task 0 of 4 locally (zero-indexed array) + SLURM_ARRAY_TASK_ID=0 SLURM_ARRAY_TASK_COUNT=4 \\ + python tutorials/slurm/array_pipeline.py \\ + --input-dir /path/to/input/directory \\ + --output-dir /path/to/output/directory + + # Non-zero-indexed array: tasks 1-4, minimum_shard_index=1 + python tutorials/slurm/array_pipeline.py \\ + --input-dir /path/to/input/directory \\ + --output-dir /path/to/output/directory \\ + --shard-index 1 --total-shards 4 --minimum-shard-index 1 + + # Use Parquet input/output instead of the default JSONL: + python tutorials/slurm/array_pipeline.py \\ + --input-dir /path/to/input/directory \\ + --input-file-type parquet \\ + --output-dir /path/to/output/directory \\ + --output-file-type parquet \\ + --shard-index 0 --total-shards 4 + + # Or let the sbatch script set the env vars: + sbatch --array=0-19 tutorials/slurm/submit_array.sh +""" + +from __future__ import annotations + +import argparse +import contextlib +import datetime +import json +import os +import socket +import tempfile +import uuid +from pathlib import Path + +from loguru import logger + +from nemo_curator.core.client import RayClient, SlurmRayClient +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.text.io.reader import JsonlReader, ParquetReader +from nemo_curator.stages.text.io.writer import JsonlWriter, ParquetWriter + + +METADATA_DIRNAME = ".nemo_curator_metadata" +SLURM_ARRAY_RETRY_DIRNAME = ".slurm_array_retry" + + +def _safe_token(value: object) -> str: + """Convert a value to a conservative filename token.""" + return "".join(char if char.isalnum() or char in "._-" else "_" for char in str(value)) + + +def _resolve_int_arg(value: int | None, env_var: str) -> int | None: + """Resolve an optional CLI integer from an environment variable.""" + if value is not None: + return value + env_value = os.environ.get(env_var) + return int(env_value) if env_value is not None else None + + +def _is_driver_process(use_slurm: bool) -> bool: + """Return True for the process that should run the pipeline and own retry metadata.""" + return not use_slurm or os.environ.get("SLURM_NODEID", "0") == "0" + + +def _retry_manifest_prefix( + shard_index: int | None, + total_shards: int | None, + minimum_shard_index: int, +) -> str: + return ( + f"manifest_shard-{_safe_token(shard_index)}_" + f"total-{_safe_token(total_shards)}_" + f"min-{_safe_token(minimum_shard_index)}_" + ) + + +def _retry_manifest_payload( + shard_index: int | None, + total_shards: int | None, + minimum_shard_index: int, + status: str, + error: BaseException | None = None, +) -> dict[str, object]: + payload = { + "shard_index": shard_index, + "total_shards": total_shards, + "minimum_shard_index": minimum_shard_index, + "status": status, + "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(), + "slurm_job_id": os.environ.get("SLURM_JOB_ID"), + "slurm_array_job_id": os.environ.get("SLURM_ARRAY_JOB_ID"), + "slurm_array_task_id": os.environ.get("SLURM_ARRAY_TASK_ID"), + "slurm_node_id": os.environ.get("SLURM_NODEID"), + "hostname": socket.gethostname(), + "pid": os.getpid(), + } + + if error is not None: + payload["error_type"] = type(error).__name__ + payload["error"] = str(error) + + return payload + + +def write_retry_manifest( + checkpoint_path: str, + shard_index: int | None, + total_shards: int | None, + minimum_shard_index: int, + status: str, + error: BaseException | None = None, + manifest_file: Path | None = None, +) -> Path: + """Write a retry manifest using a unique name and atomic rename.""" + retry_dir = Path(checkpoint_path, METADATA_DIRNAME, SLURM_ARRAY_RETRY_DIRNAME).absolute() + retry_dir.mkdir(parents=True, exist_ok=True) + + created_at = datetime.datetime.now(datetime.timezone.utc) + manifest = _retry_manifest_payload( + shard_index=shard_index, + total_shards=total_shards, + minimum_shard_index=minimum_shard_index, + status=status, + error=error, + ) + + if manifest_file is None: + timestamp = created_at.strftime("%Y%m%d%H%M%S%f") + manifest_file = retry_dir / ( + _retry_manifest_prefix(shard_index, total_shards, minimum_shard_index) + + f"job-{_safe_token(os.environ.get('SLURM_JOB_ID', 'local'))}_" + + f"task-{_safe_token(os.environ.get('SLURM_ARRAY_TASK_ID', 'local'))}_" + + f"node-{_safe_token(os.environ.get('SLURM_NODEID', '0'))}_" + + f"pid-{os.getpid()}_" + + f"{timestamp}_" + + f"{uuid.uuid4().hex}.json" + ) + + tmp_path = None + try: + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + dir=retry_dir, + prefix=f".{manifest_file.name}.", + suffix=".tmp", + delete=False, + ) as tmp_file: + tmp_path = Path(tmp_file.name) + json.dump(manifest, tmp_file, indent=2, sort_keys=True) + tmp_file.write("\n") + tmp_file.flush() + os.fsync(tmp_file.fileno()) + os.replace(tmp_path, manifest_file) + except Exception: + if tmp_path is not None: + with contextlib.suppress(FileNotFoundError): + tmp_path.unlink() + raise + + return manifest_file + + +def remove_retry_manifests( + checkpoint_path: str, + shard_index: int | None, + total_shards: int | None, + minimum_shard_index: int, +) -> None: + """Remove retry manifests for a shard after successful completion.""" + retry_dir = Path(checkpoint_path, METADATA_DIRNAME, SLURM_ARRAY_RETRY_DIRNAME).absolute() + if not retry_dir.exists(): + return + + pattern = _retry_manifest_prefix(shard_index, total_shards, minimum_shard_index) + "*.json" + for manifest_file in retry_dir.glob(pattern): + with contextlib.suppress(FileNotFoundError): + manifest_file.unlink() + + +def build_pipeline( + input_dir: str, + input_file_type: str, + output_dir: str, + output_file_type: str, + files_per_partition: int, + shard_index: int | None, + total_shards: int | None, + minimum_shard_index: int, +) -> Pipeline: + pipeline = Pipeline( + name="slurm_array_demo", + description=( + "Read files from input directory assigned to this Slurm array task " + "and write them out to output directory." + ), + ) + + # enable_array_partitioning=True reads SLURM_ARRAY_TASK_ID / SLURM_ARRAY_TASK_COUNT + # from the environment by default. Explicit shard_index / total_shards / minimum_shard_index + # override those env vars — useful for non-Slurm schedulers or local testing. + if input_file_type == "jsonl": + pipeline.add_stage( + JsonlReader( + file_paths=input_dir, + files_per_partition=files_per_partition, + enable_array_partitioning=True, + shard_index=shard_index, + total_shards=total_shards, + minimum_shard_index=minimum_shard_index, + ) + ) + elif input_file_type == "parquet": + pipeline.add_stage( + ParquetReader( + file_paths=input_dir, + files_per_partition=files_per_partition, + enable_array_partitioning=True, + shard_index=shard_index, + total_shards=total_shards, + minimum_shard_index=minimum_shard_index, + ) + ) + else: + raise ValueError(f"Unsupported input file type: {input_file_type}") + + if output_file_type == "jsonl": + pipeline.add_stage(JsonlWriter(output_dir)) + elif output_file_type == "parquet": + pipeline.add_stage(ParquetWriter(output_dir)) + else: + raise ValueError(f"Unsupported output file type: {output_file_type}") + + return pipeline + + +def main() -> None: + parser = argparse.ArgumentParser(description="Slurm array file-partitioning demo") + parser.add_argument("--input-dir", required=True, help="Directory containing input files") + parser.add_argument( + "--input-file-type", + choices=["jsonl", "parquet"], + default="jsonl", + help="Type of input files (default: jsonl)", + ) + parser.add_argument("--output-dir", required=True, help="Directory to write output files") + parser.add_argument( + "--output-file-type", + choices=["jsonl", "parquet"], + default="jsonl", + help="Type of output files (default: jsonl)", + ) + parser.add_argument( + "--files-per-partition", + type=int, + default=1, + help="Files grouped into each FileGroupTask (default: 1)", + ) + parser.add_argument( + "--shard-index", + type=int, + default=None, + help="Shard to process. Defaults to SLURM_ARRAY_TASK_ID.", + ) + parser.add_argument( + "--total-shards", + type=int, + default=None, + help="Total number of shards. Defaults to SLURM_ARRAY_TASK_COUNT.", + ) + parser.add_argument( + "--minimum-shard-index", + type=int, + default=0, + help=( + "Offset added to the hash-assigned shard before comparison. " + "Set to match the first task ID when the array does not start at 0 " + "(e.g. --array=1-20 requires --minimum-shard-index=1). Default: 0." + ), + ) + parser.add_argument( + "--checkpoint-path", + dest="checkpoint_path", + type=str, + default=None, + help=( + "Path for checkpoint metadata. Slurm array retry manifests are written under " + "/.nemo_curator_metadata/.slurm_array_retry/. Defaults to None." + ), + ) + parser.add_argument( + "--slurm", + action="store_true", + help="Use SlurmRayClient for multi-node srun jobs.", + ) + args = parser.parse_args() + + ray_client = SlurmRayClient() if args.slurm else RayClient() + shard_index = args.shard_index + total_shards = args.total_shards + minimum_shard_index = args.minimum_shard_index + retry_manifest_file = None + is_driver_process = _is_driver_process(args.slurm) + should_manage_retry_manifest = args.checkpoint_path is not None and is_driver_process + + try: + shard_index = _resolve_int_arg(args.shard_index, "SLURM_ARRAY_TASK_ID") + total_shards = _resolve_int_arg(args.total_shards, "SLURM_ARRAY_TASK_COUNT") + + if should_manage_retry_manifest: + retry_manifest_file = write_retry_manifest( + checkpoint_path=args.checkpoint_path, + shard_index=shard_index, + total_shards=total_shards, + minimum_shard_index=minimum_shard_index, + status="pending", + ) + logger.info(f"Wrote pending Slurm array retry manifest to {retry_manifest_file}") + + ray_client.start() + + pipeline = build_pipeline( + args.input_dir, + args.input_file_type, + args.output_dir, + args.output_file_type, + args.files_per_partition, + shard_index=shard_index, + total_shards=total_shards, + minimum_shard_index=minimum_shard_index, + ) + logger.info(f"\n{pipeline.describe()}") + + pipeline.run() + + if should_manage_retry_manifest: + try: + remove_retry_manifests( + checkpoint_path=args.checkpoint_path, + shard_index=shard_index, + total_shards=total_shards, + minimum_shard_index=minimum_shard_index, + ) + except Exception as cleanup_error: + logger.error(f"Pipeline succeeded but failed to remove retry manifest: {cleanup_error}") + except Exception as e: + if should_manage_retry_manifest: + try: + manifest_file = write_retry_manifest( + checkpoint_path=args.checkpoint_path, + shard_index=shard_index, + total_shards=total_shards, + minimum_shard_index=minimum_shard_index, + status="failed", + error=e, + manifest_file=retry_manifest_file, + ) + logger.error(f"Wrote Slurm array retry manifest to {manifest_file}") + except Exception as manifest_error: + logger.error(f"Failed to write Slurm array retry manifest: {manifest_error}") + + logger.error(f"Error running pipeline: {e}") + raise + finally: + if is_driver_process: + ray_client.stop() + + +if __name__ == "__main__": + main() diff --git a/tutorials/slurm/submit_array.sh b/tutorials/slurm/submit_array.sh new file mode 100644 index 0000000000..2192911465 --- /dev/null +++ b/tutorials/slurm/submit_array.sh @@ -0,0 +1,175 @@ +#!/bin/bash +# ============================================================================= +# NeMo Curator — Slurm array submit script +# +# Splits a large set of JSONL or Parquet files across multiple Slurm array +# tasks so that each job independently processes its assigned slice of the input. +# +# Example: 2000 input files, --array=0-19 -> 20 jobs x ~100 files each. +# +# How it works: +# - FilePartitioningStage groups all files into partitions of files_per_partition. +# - Each array task reads SLURM_ARRAY_TASK_ID / SLURM_ARRAY_TASK_COUNT and +# selects only the partitions assigned to it via deterministic SHA-256 hashing. +# - Jobs run in parallel with no coordination between them. +# +# Prerequisites: +# - NeMo Curator source checked out on a shared filesystem (Lustre, NFS, etc.) +# - A virtualenv built at ${CURATOR_DIR}/.venv with NeMo Curator installed +# - INPUT_DIR set to a directory of JSONL or Parquet files visible from all compute nodes +# - OUTPUT_DIR set to a writable shared directory +# - CHECKPOINT_PATH set to a writable shared path for retry manifests +# +# Usage: +# # 20 jobs (task IDs 0-19), ~100 files per job with a 2000-file dataset +# sbatch --array=0-19 tutorials/slurm/submit_array.sh +# +# # Override array size or resources at submission time: +# sbatch --array=0-9 --cpus-per-task=32 tutorials/slurm/submit_array.sh +# sbatch --array=0-39 --time=02:00:00 tutorials/slurm/submit_array.sh +# +# Array indexing: +# shard_index = SLURM_ARRAY_TASK_ID (set automatically by Slurm) +# total_shards = SLURM_ARRAY_TASK_COUNT (set automatically by Slurm) +# minimum_shard_index defaults to 0 — no env var fallback. +# +# If your array does not start at 0 (e.g. --array=1-20), set: +# MINIMUM_SHARD_INDEX=1 sbatch --array=1-20 tutorials/slurm/submit_array.sh +# ============================================================================= + +#SBATCH --job-name=curator-array +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=16 +#SBATCH --gpus-per-node=0 +#SBATCH --time=01:00:00 +#SBATCH --output=array_%A_%a.log +#SBATCH --error=array_%A_%a.log + +set -euo pipefail + +# --------------------------------------------------------------------------- +# Paths — adjust to your environment +# --------------------------------------------------------------------------- +CURATOR_DIR="${CURATOR_DIR:-$(cd "$(dirname "$0")/../.." && pwd)}" + +# Input and output directories +INPUT_DIR="${INPUT_DIR:-/path/to/your/input/directory}" +OUTPUT_DIR="${OUTPUT_DIR:-/path/to/your/output/directory}" + +# Retry manifests are written under: +# ${CHECKPOINT_PATH}/.nemo_curator_metadata/.slurm_array_retry/ +# Defaults to OUTPUT_DIR. If you override CONTAINER_MOUNTS, make sure it still +# includes CHECKPOINT_PATH. +CHECKPOINT_PATH="${CHECKPOINT_PATH:-${OUTPUT_DIR}}" + +# Official NeMo Curator container from NGC. +# Browse available tags: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo-curator +CONTAINER_IMAGE="${CONTAINER_IMAGE:-nvcr.io/nvidia/nemo-curator:26.02}" + +# Mount the shared filesystem paths that contain your code and data. +# Format: :[,:] +# Override this if your cluster expects filesystem roots, e.g. /lustre:/lustre,/home:/home. +DEFAULT_CONTAINER_MOUNTS="${CURATOR_DIR}:${CURATOR_DIR},${INPUT_DIR}:${INPUT_DIR},${OUTPUT_DIR}:${OUTPUT_DIR}" +if [[ "${CHECKPOINT_PATH}" != "${OUTPUT_DIR}" ]]; then + DEFAULT_CONTAINER_MOUNTS="${DEFAULT_CONTAINER_MOUNTS},${CHECKPOINT_PATH}:${CHECKPOINT_PATH}" +fi +CONTAINER_MOUNTS="${CONTAINER_MOUNTS:-${DEFAULT_CONTAINER_MOUNTS}}" + +# Input and output file types +INPUT_FILE_TYPE="${INPUT_FILE_TYPE:-jsonl}" +OUTPUT_FILE_TYPE="${OUTPUT_FILE_TYPE:-jsonl}" + +# Number of files to read into a single DocumentBatch +FILES_PER_PARTITION="${FILES_PER_PARTITION:-1}" + +# Shard index and total shards +SHARD_INDEX="${SHARD_INDEX:-${SLURM_ARRAY_TASK_ID}}" +TOTAL_SHARDS="${TOTAL_SHARDS:-${SLURM_ARRAY_TASK_COUNT}}" + +# Offset between 0-indexed hash assignments and SLURM_ARRAY_TASK_ID. +# Leave at 0 for --array=0-N. Set to the array start value for --array=K-N. +MINIMUM_SHARD_INDEX="${MINIMUM_SHARD_INDEX:-0}" + +# Use SlurmRayClient only when this array task spans multiple nodes. Single-node +# array tasks can use the regular RayClient. +NUM_NODES="${SLURM_JOB_NUM_NODES:-${SLURM_NNODES:-1}}" +USE_SLURM_RAY=0 +if (( NUM_NODES > 1 )); then + USE_SLURM_RAY=1 +fi + +mkdir -p "${CURATOR_DIR}/logs" "${OUTPUT_DIR}" "${CHECKPOINT_PATH}" + +export CURATOR_DIR +export INPUT_DIR +export OUTPUT_DIR +export CHECKPOINT_PATH +export INPUT_FILE_TYPE +export OUTPUT_FILE_TYPE +export FILES_PER_PARTITION +export SHARD_INDEX +export TOTAL_SHARDS +export MINIMUM_SHARD_INDEX +export USE_SLURM_RAY + +echo "==================================================" +echo " NeMo Curator — Slurm Array Demo" +echo "==================================================" +echo " Job array ID : ${SLURM_ARRAY_JOB_ID}" +echo " Array task ID : ${SLURM_ARRAY_TASK_ID}" +echo " Array task cnt : ${SLURM_ARRAY_TASK_COUNT}" +echo " Nodes : ${NUM_NODES}" +echo " Ray client : $([[ "${USE_SLURM_RAY}" == "1" ]] && echo SlurmRayClient || echo RayClient)" +echo " Node : $(hostname)" +echo " Container : ${CONTAINER_IMAGE}" +echo " Mounts : ${CONTAINER_MOUNTS}" +echo " Dir : ${CURATOR_DIR}" +echo " Checkpoint path: ${CHECKPOINT_PATH}" +echo "==================================================" + +# Each array task processes only the file partitions hashed to its +# SLURM_ARRAY_TASK_ID. With --nodes=1, the task uses a local RayClient. With +# --nodes>1, the same Python entrypoint runs on every node and --slurm enables +# SlurmRayClient so workers join the head-node Ray cluster. +srun \ + --ntasks-per-node=1 \ + --container-image="${CONTAINER_IMAGE}" \ + --container-mounts="${CONTAINER_MOUNTS}" \ + --container-workdir="${CURATOR_DIR}" \ + bash -c ' +set -euo pipefail + +export RAY_TMPDIR="/tmp/ray_${SLURM_JOB_ID}" +export RAY_PORT_BROADCAST_DIR="${CURATOR_DIR}/logs" + +# Activate the local virtualenv so the latest Curator code (from this +# checkout) is used instead of the version bundled in the container image. +source "${CURATOR_DIR}/.venv/bin/activate" + +echo "[$(hostname)] SLURM_NODEID=${SLURM_NODEID:-0} python=$(python --version 2>&1)" +nvidia-smi --query-gpu=index,name,memory.total --format=csv,noheader 2>/dev/null \ + | sed "s/^/ [$(hostname)] GPU /" || echo " [$(hostname)] no GPUs" + +pipeline_args=( + --input-dir "${INPUT_DIR}" + --input-file-type "${INPUT_FILE_TYPE}" + --output-dir "${OUTPUT_DIR}" + --output-file-type "${OUTPUT_FILE_TYPE}" + --files-per-partition "${FILES_PER_PARTITION}" + --shard-index "${SHARD_INDEX}" + --total-shards "${TOTAL_SHARDS}" + --minimum-shard-index "${MINIMUM_SHARD_INDEX}" + --checkpoint-path "${CHECKPOINT_PATH}" +) + +if [[ "${USE_SLURM_RAY}" == "1" ]]; then + pipeline_args+=(--slurm) +fi + +python "${CURATOR_DIR}/tutorials/slurm/array_pipeline.py" "${pipeline_args[@]}" +' + +echo "==================================================" +echo " Array task ${SLURM_ARRAY_TASK_ID} DONE" +echo "==================================================" From acfeceb57ddbba5098fd57a6dde5b7f5f86ef745 Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Thu, 11 Jun 2026 10:29:11 -0700 Subject: [PATCH 04/11] ruff Signed-off-by: Sarah Yurick --- tutorials/slurm/array_pipeline.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tutorials/slurm/array_pipeline.py b/tutorials/slurm/array_pipeline.py index b3bc5a5dc5..02b0532096 100644 --- a/tutorials/slurm/array_pipeline.py +++ b/tutorials/slurm/array_pipeline.py @@ -75,7 +75,6 @@ from nemo_curator.stages.text.io.reader import JsonlReader, ParquetReader from nemo_curator.stages.text.io.writer import JsonlWriter, ParquetWriter - METADATA_DIRNAME = ".nemo_curator_metadata" SLURM_ARRAY_RETRY_DIRNAME = ".slurm_array_retry" @@ -122,7 +121,7 @@ def _retry_manifest_payload( "total_shards": total_shards, "minimum_shard_index": minimum_shard_index, "status": status, - "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(), + "created_at": datetime.datetime.now(datetime.UTC).isoformat(), "slurm_job_id": os.environ.get("SLURM_JOB_ID"), "slurm_array_job_id": os.environ.get("SLURM_ARRAY_JOB_ID"), "slurm_array_task_id": os.environ.get("SLURM_ARRAY_TASK_ID"), @@ -138,7 +137,7 @@ def _retry_manifest_payload( return payload -def write_retry_manifest( +def write_retry_manifest( # noqa: PLR0913 checkpoint_path: str, shard_index: int | None, total_shards: int | None, @@ -151,7 +150,7 @@ def write_retry_manifest( retry_dir = Path(checkpoint_path, METADATA_DIRNAME, SLURM_ARRAY_RETRY_DIRNAME).absolute() retry_dir.mkdir(parents=True, exist_ok=True) - created_at = datetime.datetime.now(datetime.timezone.utc) + created_at = datetime.datetime.now(datetime.UTC) manifest = _retry_manifest_payload( shard_index=shard_index, total_shards=total_shards, @@ -214,7 +213,7 @@ def remove_retry_manifests( manifest_file.unlink() -def build_pipeline( +def build_pipeline( # noqa: PLR0913 input_dir: str, input_file_type: str, output_dir: str, @@ -227,8 +226,7 @@ def build_pipeline( pipeline = Pipeline( name="slurm_array_demo", description=( - "Read files from input directory assigned to this Slurm array task " - "and write them out to output directory." + "Read files from input directory assigned to this Slurm array task and write them out to output directory." ), ) @@ -258,14 +256,16 @@ def build_pipeline( ) ) else: - raise ValueError(f"Unsupported input file type: {input_file_type}") + msg = f"Unsupported input file type: {input_file_type}" + raise ValueError(msg) if output_file_type == "jsonl": pipeline.add_stage(JsonlWriter(output_dir)) elif output_file_type == "parquet": pipeline.add_stage(ParquetWriter(output_dir)) else: - raise ValueError(f"Unsupported output file type: {output_file_type}") + msg = f"Unsupported output file type: {output_file_type}" + raise ValueError(msg) return pipeline @@ -377,7 +377,7 @@ def main() -> None: total_shards=total_shards, minimum_shard_index=minimum_shard_index, ) - except Exception as cleanup_error: + except Exception as cleanup_error: # noqa: BLE001 logger.error(f"Pipeline succeeded but failed to remove retry manifest: {cleanup_error}") except Exception as e: if should_manage_retry_manifest: @@ -392,7 +392,7 @@ def main() -> None: manifest_file=retry_manifest_file, ) logger.error(f"Wrote Slurm array retry manifest to {manifest_file}") - except Exception as manifest_error: + except Exception as manifest_error: # noqa: BLE001 logger.error(f"Failed to write Slurm array retry manifest: {manifest_error}") logger.error(f"Error running pipeline: {e}") From 6eaf95e7ccf746a210b38961b77fb8e81c01b7a8 Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Thu, 11 Jun 2026 11:18:11 -0700 Subject: [PATCH 05/11] address greptile reviews Signed-off-by: Sarah Yurick --- nemo_curator/stages/file_partitioning.py | 39 +++++++--- tests/stages/common/test_file_partitioning.py | 50 ++++++++++++ tutorials/slurm/README.md | 28 +++++-- tutorials/slurm/array_pipeline.py | 78 +++++++++++-------- tutorials/slurm/submit_array.sh | 15 +++- 5 files changed, 157 insertions(+), 53 deletions(-) diff --git a/nemo_curator/stages/file_partitioning.py b/nemo_curator/stages/file_partitioning.py index b7513f02e2..041d669133 100644 --- a/nemo_curator/stages/file_partitioning.py +++ b/nemo_curator/stages/file_partitioning.py @@ -34,20 +34,18 @@ def _get_int_or_env_var(input_value: int | str | None, default_name: str | None = None) -> int: if type(input_value) is int: return input_value - elif type(input_value) is str: - if os.environ.get(input_value) is None: - msg = f"Environment variable {input_value} is not set" - raise ValueError(msg) - return int(os.environ.get(input_value)) - elif default_name is not None: - if os.environ.get(default_name) is None: - msg = f"Environment variable {default_name} is not set" - raise ValueError(msg) - return int(os.environ.get(default_name)) - else: + + env_var = input_value if type(input_value) is str else default_name + if env_var is None: msg = f"Invalid input value: {input_value}, must be an integer or a string" raise ValueError(msg) + env_value = os.environ.get(env_var) + if env_value is None: + msg = f"Environment variable {env_var} is not set" + raise ValueError(msg) + return int(env_value) + @dataclass class FilePartitioningStage(ProcessingStage[EmptyTask, FileGroupTask]): @@ -131,6 +129,19 @@ def __post_init__(self): self.shard_index = _get_int_or_env_var(self.shard_index, "SLURM_ARRAY_TASK_ID") self.total_shards = _get_int_or_env_var(self.total_shards, "SLURM_ARRAY_TASK_COUNT") self.minimum_shard_index = _get_int_or_env_var(self.minimum_shard_index) + if self.total_shards <= 0: + msg = f"total_shards must be greater than 0, got {self.total_shards}" + raise ValueError(msg) + min_assignable_shard_index = self.minimum_shard_index + max_assignable_shard_index = self.minimum_shard_index + self.total_shards - 1 + if not min_assignable_shard_index <= self.shard_index <= max_assignable_shard_index: + logger.warning( + "shard_index={} is outside the assignable shard range [{}, {}]. " + "This task will not receive any partitions.", + self.shard_index, + min_assignable_shard_index, + max_assignable_shard_index, + ) self.name = "array_file_partitioning" def inputs(self) -> tuple[list[str], list[str]]: @@ -247,7 +258,11 @@ def _process_array(self, task: EmptyTask) -> list[FileGroupTask]: if assigned == self.shard_index: assigned_tasks.append(ft) - logger.info(f"Shard {self.shard_index}/{self.total_shards}: assigned {len(assigned_tasks)} of {len(all_tasks)} partitions") + msg = f"Shard {self.shard_index}/{self.total_shards}: assigned {len(assigned_tasks)} of {len(all_tasks)} partitions" + if len(assigned_tasks) == 0 and len(all_tasks) > 0: + logger.warning(msg) + else: + logger.info(msg) return assigned_tasks def process(self, task: EmptyTask) -> list[FileGroupTask]: diff --git a/tests/stages/common/test_file_partitioning.py b/tests/stages/common/test_file_partitioning.py index 2ec39853de..17845e963c 100644 --- a/tests/stages/common/test_file_partitioning.py +++ b/tests/stages/common/test_file_partitioning.py @@ -342,6 +342,56 @@ def test_enable_array_partitioning_requires_slurm_env_vars_by_default( enable_array_partitioning=True, ) + def test_enable_array_partitioning_requires_positive_total_shards(self): + """Test that array partitioning rejects non-positive shard counts.""" + with pytest.raises(ValueError, match="total_shards must be greater than 0"): + FilePartitioningStage( + file_paths="/test/path", + enable_array_partitioning=True, + shard_index=0, + total_shards=0, + ) + + def test_enable_array_partitioning_warns_for_out_of_range_shard( + self, + caplog: pytest.LogCaptureFixture, + ): + """Test that shard IDs outside the assignable range produce a clear warning.""" + with caplog.at_level("WARNING"): + stage = FilePartitioningStage( + file_paths="/test/path", + enable_array_partitioning=True, + shard_index=0, + total_shards=10, + minimum_shard_index=1, + ) + + assert stage.shard_index == 0 + assert stage.total_shards == 10 + assert stage.minimum_shard_index == 1 + assert "outside the assignable shard range [1, 10]" in caplog.text + + def test_enable_array_partitioning_warns_when_no_partitions_assigned( + self, + caplog: pytest.LogCaptureFixture, + empty_task: EmptyTask, + tmp_path: Path, + ): + """Test that non-empty input with zero assigned partitions is visible in logs.""" + test_files = _create_test_jsonl_files(tmp_path, num_files=2, subdir="path") + stage = FilePartitioningStage( + file_paths=test_files, + enable_array_partitioning=True, + shard_index=2, + total_shards=2, + ) + + with caplog.at_level("WARNING"): + result = stage.process(empty_task) + + assert result == [] + assert "assigned 0 of 2 partitions" in caplog.text + def test_enable_array_partitioning_assigns_each_partition_to_one_shard( self, empty_task: EmptyTask, diff --git a/tutorials/slurm/README.md b/tutorials/slurm/README.md index 038b2a4756..3a47a9bd95 100644 --- a/tutorials/slurm/README.md +++ b/tutorials/slurm/README.md @@ -261,7 +261,21 @@ sbatch --array=9000-9999 tutorials/slurm/submit_array.sh In this mode, keep `MINIMUM_SHARD_INDEX=0` because the Slurm array task IDs are already the global shard IDs. Each partition is assigned by `hash(partition) % TOTAL_SHARDS`, so the full set of windowed submissions covers shards `0` through `9999` exactly once. Some individual tasks may receive no files if `TOTAL_SHARDS` is larger than the number of file partitions. -Some clusters enforce the maximum array index rather than just the number of tasks per submitted array. If `--array=1000-1999` is rejected, this windowing pattern needs an explicit shard-index offset in the submission script rather than higher Slurm task IDs. +Some clusters enforce the maximum array index rather than just the number of tasks per submitted array. If `--array=1000-1999` is rejected, use `SHARD_INDEX_OFFSET` instead of higher Slurm task IDs. + +For those clusters, submit each window with Slurm task IDs `0-999` and set `SHARD_INDEX_OFFSET` so the script computes the global shard ID as `SLURM_ARRAY_TASK_ID + SHARD_INDEX_OFFSET`: + +```bash +export TOTAL_SHARDS=10000 + +SHARD_INDEX_OFFSET=0 sbatch --array=0-999 tutorials/slurm/submit_array.sh +SHARD_INDEX_OFFSET=1000 sbatch --array=0-999 tutorials/slurm/submit_array.sh +SHARD_INDEX_OFFSET=2000 sbatch --array=0-999 tutorials/slurm/submit_array.sh +# ... +SHARD_INDEX_OFFSET=9000 sbatch --array=0-999 tutorials/slurm/submit_array.sh +``` + +Keep `MINIMUM_SHARD_INDEX=0` for this offset mode too. `SHARD_INDEX_OFFSET` changes the logical shard ID passed to the pipeline; `MINIMUM_SHARD_INDEX` changes the assignable shard range used by the partitioning stage. ### 5. Retry failed array tasks only @@ -283,15 +297,19 @@ Each manifest records the failed `shard_index`, plus the `total_shards` and `min export CHECKPOINT_PATH="${CHECKPOINT_PATH:-$OUTPUT_DIR}" RETRY_DIR="${CHECKPOINT_PATH}/.nemo_curator_metadata/.slurm_array_retry" -FAILED_SHARDS=$(jq -r '.shard_index' "${RETRY_DIR}"/manifest_*.json | sort -n -u | paste -sd, -) -TOTAL_SHARDS_VALUES=$(jq -r '.total_shards' "${RETRY_DIR}"/manifest_*.json | sort -n -u) -MINIMUM_SHARD_INDEX_VALUES=$(jq -r '.minimum_shard_index' "${RETRY_DIR}"/manifest_*.json | sort -n -u) +shopt -s nullglob +MANIFEST_FILES=("${RETRY_DIR}"/manifest_*.json) +shopt -u nullglob -if [[ -z "${FAILED_SHARDS}" ]]; then +if (( ${#MANIFEST_FILES[@]} == 0 )); then echo "No failed shards found in ${RETRY_DIR}" >&2 exit 1 fi +FAILED_SHARDS=$(jq -r '.shard_index' "${MANIFEST_FILES[@]}" | sort -n -u | paste -sd, -) +TOTAL_SHARDS_VALUES=$(jq -r '.total_shards' "${MANIFEST_FILES[@]}" | sort -n -u) +MINIMUM_SHARD_INDEX_VALUES=$(jq -r '.minimum_shard_index' "${MANIFEST_FILES[@]}" | sort -n -u) + if [[ "${TOTAL_SHARDS_VALUES}" == *$'\n'* || "${MINIMUM_SHARD_INDEX_VALUES}" == *$'\n'* ]]; then echo "Retry manifests contain multiple shard configurations; split them by run." >&2 exit 1 diff --git a/tutorials/slurm/array_pipeline.py b/tutorials/slurm/array_pipeline.py index 02b0532096..08b817acff 100644 --- a/tutorials/slurm/array_pipeline.py +++ b/tutorials/slurm/array_pipeline.py @@ -21,22 +21,20 @@ Array partitioning parameters ------------------------------ shard_index Which shard this job processes. - Default: SLURM_ARRAY_TASK_ID env var. total_shards Total number of shards (i.e. array width). - Default: SLURM_ARRAY_TASK_COUNT env var. minimum_shard_index Offset added to the hash-assigned shard before comparing with shard_index. Use when the array does not start at 0. E.g. --array=1-20 requires minimum_shard_index=1 so shard IDs 1-20 match task IDs 1-20. - Default: 0. No env var fallback — must be set explicitly. + Default: 0. Usage (local smoke test against a small sample directory):: # Simulate task 0 of 4 locally (zero-indexed array) - SLURM_ARRAY_TASK_ID=0 SLURM_ARRAY_TASK_COUNT=4 \\ - python tutorials/slurm/array_pipeline.py \\ - --input-dir /path/to/input/directory \\ - --output-dir /path/to/output/directory + python tutorials/slurm/array_pipeline.py \\ + --input-dir /path/to/input/directory \\ + --output-dir /path/to/output/directory \\ + --shard-index 0 --total-shards 4 # Non-zero-indexed array: tasks 1-4, minimum_shard_index=1 python tutorials/slurm/array_pipeline.py \\ @@ -52,7 +50,7 @@ --output-file-type parquet \\ --shard-index 0 --total-shards 4 - # Or let the sbatch script set the env vars: + # Or let the sbatch script read Slurm env vars and pass explicit args: sbatch --array=0-19 tutorials/slurm/submit_array.sh """ @@ -84,12 +82,24 @@ def _safe_token(value: object) -> str: return "".join(char if char.isalnum() or char in "._-" else "_" for char in str(value)) -def _resolve_int_arg(value: int | None, env_var: str) -> int | None: - """Resolve an optional CLI integer from an environment variable.""" - if value is not None: +def _parse_int_or_env_name(value: str) -> int | str: + """Parse an integer value or keep an environment variable name.""" + try: + return int(value) + except ValueError: return value - env_value = os.environ.get(env_var) - return int(env_value) if env_value is not None else None + + +def _resolve_int_or_env_name(value: int | str, label: str) -> int: + """Resolve an integer or an environment variable name containing an integer.""" + if isinstance(value, int): + return value + + env_value = os.environ.get(value) + if env_value is None: + msg = f"{label} references environment variable {value}, but it is not set" + raise ValueError(msg) + return int(env_value) def _is_driver_process(use_slurm: bool) -> bool: @@ -98,8 +108,8 @@ def _is_driver_process(use_slurm: bool) -> bool: def _retry_manifest_prefix( - shard_index: int | None, - total_shards: int | None, + shard_index: int, + total_shards: int, minimum_shard_index: int, ) -> str: return ( @@ -110,10 +120,11 @@ def _retry_manifest_prefix( def _retry_manifest_payload( - shard_index: int | None, - total_shards: int | None, + shard_index: int, + total_shards: int, minimum_shard_index: int, status: str, + created_at: datetime.datetime, error: BaseException | None = None, ) -> dict[str, object]: payload = { @@ -121,7 +132,7 @@ def _retry_manifest_payload( "total_shards": total_shards, "minimum_shard_index": minimum_shard_index, "status": status, - "created_at": datetime.datetime.now(datetime.UTC).isoformat(), + "created_at": created_at.isoformat(), "slurm_job_id": os.environ.get("SLURM_JOB_ID"), "slurm_array_job_id": os.environ.get("SLURM_ARRAY_JOB_ID"), "slurm_array_task_id": os.environ.get("SLURM_ARRAY_TASK_ID"), @@ -156,6 +167,7 @@ def write_retry_manifest( # noqa: PLR0913 total_shards=total_shards, minimum_shard_index=minimum_shard_index, status=status, + created_at=created_at, error=error, ) @@ -230,9 +242,9 @@ def build_pipeline( # noqa: PLR0913 ), ) - # enable_array_partitioning=True reads SLURM_ARRAY_TASK_ID / SLURM_ARRAY_TASK_COUNT - # from the environment by default. Explicit shard_index / total_shards / minimum_shard_index - # override those env vars — useful for non-Slurm schedulers or local testing. + # submit_array.sh maps Slurm array env vars into explicit shard arguments. + # Direct users may also pass env var names; those are resolved before the + # pipeline is built so retry manifests record concrete shard values. if input_file_type == "jsonl": pipeline.add_stage( JsonlReader( @@ -294,19 +306,19 @@ def main() -> None: ) parser.add_argument( "--shard-index", - type=int, - default=None, - help="Shard to process. Defaults to SLURM_ARRAY_TASK_ID.", + type=_parse_int_or_env_name, + required=True, + help="Shard to process, as an integer or environment variable name.", ) parser.add_argument( "--total-shards", - type=int, - default=None, - help="Total number of shards. Defaults to SLURM_ARRAY_TASK_COUNT.", + type=_parse_int_or_env_name, + required=True, + help="Total number of shards, as an integer or environment variable name.", ) parser.add_argument( "--minimum-shard-index", - type=int, + type=_parse_int_or_env_name, default=0, help=( "Offset added to the hash-assigned shard before comparison. " @@ -331,18 +343,16 @@ def main() -> None: ) args = parser.parse_args() + shard_index = _resolve_int_or_env_name(args.shard_index, "shard_index") + total_shards = _resolve_int_or_env_name(args.total_shards, "total_shards") + minimum_shard_index = _resolve_int_or_env_name(args.minimum_shard_index, "minimum_shard_index") + ray_client = SlurmRayClient() if args.slurm else RayClient() - shard_index = args.shard_index - total_shards = args.total_shards - minimum_shard_index = args.minimum_shard_index retry_manifest_file = None is_driver_process = _is_driver_process(args.slurm) should_manage_retry_manifest = args.checkpoint_path is not None and is_driver_process try: - shard_index = _resolve_int_arg(args.shard_index, "SLURM_ARRAY_TASK_ID") - total_shards = _resolve_int_arg(args.total_shards, "SLURM_ARRAY_TASK_COUNT") - if should_manage_retry_manifest: retry_manifest_file = write_retry_manifest( checkpoint_path=args.checkpoint_path, diff --git a/tutorials/slurm/submit_array.sh b/tutorials/slurm/submit_array.sh index 2192911465..df9dc75a30 100644 --- a/tutorials/slurm/submit_array.sh +++ b/tutorials/slurm/submit_array.sh @@ -32,6 +32,8 @@ # shard_index = SLURM_ARRAY_TASK_ID (set automatically by Slurm) # total_shards = SLURM_ARRAY_TASK_COUNT (set automatically by Slurm) # minimum_shard_index defaults to 0 — no env var fallback. +# shard_index_offset defaults to 0 and is added to SLURM_ARRAY_TASK_ID +# only when SHARD_INDEX is not explicitly set. # # If your array does not start at 0 (e.g. --array=1-20), set: # MINIMUM_SHARD_INDEX=1 sbatch --array=1-20 tutorials/slurm/submit_array.sh @@ -83,8 +85,13 @@ OUTPUT_FILE_TYPE="${OUTPUT_FILE_TYPE:-jsonl}" # Number of files to read into a single DocumentBatch FILES_PER_PARTITION="${FILES_PER_PARTITION:-1}" -# Shard index and total shards -SHARD_INDEX="${SHARD_INDEX:-${SLURM_ARRAY_TASK_ID}}" +# Shard index and total shards. +# +# SHARD_INDEX_OFFSET is useful on clusters that limit the maximum Slurm array +# index. For example, submit --array=0-999 with SHARD_INDEX_OFFSET=1000 to +# process logical shards 1000-1999. +SHARD_INDEX_OFFSET="${SHARD_INDEX_OFFSET:-0}" +SHARD_INDEX="${SHARD_INDEX:-$((SLURM_ARRAY_TASK_ID + SHARD_INDEX_OFFSET))}" TOTAL_SHARDS="${TOTAL_SHARDS:-${SLURM_ARRAY_TASK_COUNT}}" # Offset between 0-indexed hash assignments and SLURM_ARRAY_TASK_ID. @@ -108,6 +115,7 @@ export CHECKPOINT_PATH export INPUT_FILE_TYPE export OUTPUT_FILE_TYPE export FILES_PER_PARTITION +export SHARD_INDEX_OFFSET export SHARD_INDEX export TOTAL_SHARDS export MINIMUM_SHARD_INDEX @@ -119,6 +127,9 @@ echo "==================================================" echo " Job array ID : ${SLURM_ARRAY_JOB_ID}" echo " Array task ID : ${SLURM_ARRAY_TASK_ID}" echo " Array task cnt : ${SLURM_ARRAY_TASK_COUNT}" +echo " Shard index : ${SHARD_INDEX}" +echo " Shard offset : ${SHARD_INDEX_OFFSET}" +echo " Total shards : ${TOTAL_SHARDS}" echo " Nodes : ${NUM_NODES}" echo " Ray client : $([[ "${USE_SLURM_RAY}" == "1" ]] && echo SlurmRayClient || echo RayClient)" echo " Node : $(hostname)" From bb1e30a86637ff13bdbcc72ad6e848278bc015d7 Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Thu, 11 Jun 2026 11:20:52 -0700 Subject: [PATCH 06/11] ruff Signed-off-by: Sarah Yurick --- tutorials/slurm/array_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/slurm/array_pipeline.py b/tutorials/slurm/array_pipeline.py index 08b817acff..333a803e4e 100644 --- a/tutorials/slurm/array_pipeline.py +++ b/tutorials/slurm/array_pipeline.py @@ -119,7 +119,7 @@ def _retry_manifest_prefix( ) -def _retry_manifest_payload( +def _retry_manifest_payload( # noqa: PLR0913 shard_index: int, total_shards: int, minimum_shard_index: int, From 2ccbd3f91878237d6790bb8171f631e63f789c34 Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Thu, 11 Jun 2026 12:53:42 -0700 Subject: [PATCH 07/11] more greptile comments Signed-off-by: Sarah Yurick --- nemo_curator/stages/file_partitioning.py | 6 +++++- tests/stages/common/test_file_partitioning.py | 15 +++++++++++++++ tutorials/slurm/array_pipeline.py | 6 +++++- 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/nemo_curator/stages/file_partitioning.py b/nemo_curator/stages/file_partitioning.py index 041d669133..39796b15a6 100644 --- a/nemo_curator/stages/file_partitioning.py +++ b/nemo_curator/stages/file_partitioning.py @@ -44,7 +44,11 @@ def _get_int_or_env_var(input_value: int | str | None, default_name: str | None if env_value is None: msg = f"Environment variable {env_var} is not set" raise ValueError(msg) - return int(env_value) + try: + return int(env_value) + except ValueError as e: + msg = f"Environment variable {env_var} must contain an integer, got {env_value!r}" + raise ValueError(msg) from e @dataclass diff --git a/tests/stages/common/test_file_partitioning.py b/tests/stages/common/test_file_partitioning.py index 17845e963c..093d4f6a4f 100644 --- a/tests/stages/common/test_file_partitioning.py +++ b/tests/stages/common/test_file_partitioning.py @@ -342,6 +342,21 @@ def test_enable_array_partitioning_requires_slurm_env_vars_by_default( enable_array_partitioning=True, ) + def test_enable_array_partitioning_rejects_non_integer_env_var( + self, + monkeypatch: pytest.MonkeyPatch, + ): + """Test that non-integer shard env vars raise a contextual error.""" + monkeypatch.setenv("CUSTOM_SHARD_INDEX", "not-an-int") + + with pytest.raises(ValueError, match="CUSTOM_SHARD_INDEX.*not-an-int"): + FilePartitioningStage( + file_paths="/test/path", + enable_array_partitioning=True, + shard_index="CUSTOM_SHARD_INDEX", + total_shards=4, + ) + def test_enable_array_partitioning_requires_positive_total_shards(self): """Test that array partitioning rejects non-positive shard counts.""" with pytest.raises(ValueError, match="total_shards must be greater than 0"): diff --git a/tutorials/slurm/array_pipeline.py b/tutorials/slurm/array_pipeline.py index 333a803e4e..151dc42322 100644 --- a/tutorials/slurm/array_pipeline.py +++ b/tutorials/slurm/array_pipeline.py @@ -99,7 +99,11 @@ def _resolve_int_or_env_name(value: int | str, label: str) -> int: if env_value is None: msg = f"{label} references environment variable {value}, but it is not set" raise ValueError(msg) - return int(env_value) + try: + return int(env_value) + except ValueError as e: + msg = f"{label} references environment variable {value}, which must contain an integer, got {env_value!r}" + raise ValueError(msg) from e def _is_driver_process(use_slurm: bool) -> bool: From 1b659ead5158e61b60f41ee83442f8e77e89605e Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Thu, 11 Jun 2026 12:58:12 -0700 Subject: [PATCH 08/11] add nonetask and failedtask sentinels Signed-off-by: Sarah Yurick --- nemo_curator/backends/base.py | 12 +++++++- nemo_curator/tasks/sentinels.py | 29 +++++++++++++++++-- tests/stages/common/test_file_partitioning.py | 2 +- 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/nemo_curator/backends/base.py b/nemo_curator/backends/base.py index 94236f2cfc..6dc1ce6225 100644 --- a/nemo_curator/backends/base.py +++ b/nemo_curator/backends/base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ from nemo_curator.core.utils import ignore_ray_head_node from nemo_curator.tasks import Task +from nemo_curator.tasks.sentinels import FailedTask, NoneTask from nemo_curator.utils.performance_utils import StageTimer if TYPE_CHECKING: @@ -85,9 +86,18 @@ def process_batch(self, tasks: list[Task]) -> list[Task]: # Use the batch processing logic results = self.stage.process_batch(tasks) + # A returned ``None`` ("filter this slot") becomes a NoneTask so every + # output is a real Task that gets a task_id. Sentinels (NoneTask / + # FailedTask) carry no identity and are stripped again before this + # method returns. + results = [NoneTask() if r is None else r for r in results] + # Guarantee every emitted task has a task_id (derived id, or uuid fallback). results = self._post_process_task_ids(tasks, results) + # Sentinels never propagate to the next stage. + results = [r for r in results if not isinstance(r, (NoneTask, FailedTask))] + # Log performance stats and add to result tasks _, stage_perf_stats = self._timer.log_stats() # Consume and attach any custom metrics recorded by the stage during this call diff --git a/nemo_curator/tasks/sentinels.py b/nemo_curator/tasks/sentinels.py index 84896dd963..ed1ab8b572 100644 --- a/nemo_curator/tasks/sentinels.py +++ b/nemo_curator/tasks/sentinels.py @@ -13,9 +13,18 @@ # limitations under the License. """Payload-less marker tasks. -``EmptyTask`` seeds a pipeline (the implicit root id ``"0"``). All markers -share the :class:`SentinelTask` base and carry no payload (``data is None``). -Construct one with ``EmptyTask()``. +``EmptyTask`` seeds a pipeline (the implicit root id ``"0"``). The resumability +layer adds two more markers on the same :class:`SentinelTask` base: + +- ``NoneTask`` — this slot was intentionally filtered. The resumability counter + treats it as a consumed branch (decrements). The adapter auto-wraps a + returned ``None`` as a ``NoneTask``. +- ``FailedTask`` — this slot failed and should be retried on resume. The counter + is NOT decremented, so its source stays pending and reruns. + +All carry no payload (``data is None``) and get their ``task_id`` assigned by +the executor adapter; sentinels are stripped before the next stage. Construct +with ``EmptyTask()`` / ``NoneTask()`` / ``FailedTask()``. """ from dataclasses import dataclass, field @@ -52,3 +61,17 @@ class EmptyTask(SentinelTask): dataset_name: str = "empty" task_id: str = field(init=False, default="0") + + +@dataclass +class NoneTask(SentinelTask): + """Marks a slot as intentionally filtered (resumability counter decrements).""" + + dataset_name: str = "none" + + +@dataclass +class FailedTask(SentinelTask): + """Marks a slot as failed → retried on resume (counter does NOT decrement).""" + + dataset_name: str = "failed" diff --git a/tests/stages/common/test_file_partitioning.py b/tests/stages/common/test_file_partitioning.py index 093d4f6a4f..0aa58194a7 100644 --- a/tests/stages/common/test_file_partitioning.py +++ b/tests/stages/common/test_file_partitioning.py @@ -349,7 +349,7 @@ def test_enable_array_partitioning_rejects_non_integer_env_var( """Test that non-integer shard env vars raise a contextual error.""" monkeypatch.setenv("CUSTOM_SHARD_INDEX", "not-an-int") - with pytest.raises(ValueError, match="CUSTOM_SHARD_INDEX.*not-an-int"): + with pytest.raises(ValueError, match=r"CUSTOM_SHARD_INDEX.*not-an-int"): FilePartitioningStage( file_paths="/test/path", enable_array_partitioning=True, From 3522809351351b4c39ea6e1827a4d00e0f4b2103 Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Thu, 11 Jun 2026 13:44:25 -0700 Subject: [PATCH 09/11] add failedtask detection and repeat Signed-off-by: Sarah Yurick --- nemo_curator/backends/base.py | 75 ++++++++++++++++++++- tests/backends/test_base_stage_adapter.py | 81 +++++++++++++++++++++++ tutorials/slurm/README.md | 18 +++-- tutorials/slurm/array_pipeline.py | 57 ++++++++++++++++ tutorials/slurm/submit_array.sh | 14 +++- 5 files changed, 239 insertions(+), 6 deletions(-) create mode 100644 tests/backends/test_base_stage_adapter.py diff --git a/nemo_curator/backends/base.py b/nemo_curator/backends/base.py index 6dc1ce6225..bff3667dba 100644 --- a/nemo_curator/backends/base.py +++ b/nemo_curator/backends/base.py @@ -12,11 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime +import json +import os +import socket +import tempfile import uuid from abc import ABC, abstractmethod from dataclasses import dataclass +from pathlib import Path from typing import TYPE_CHECKING, Any +from loguru import logger + from nemo_curator.core.utils import ignore_ray_head_node from nemo_curator.tasks import Task from nemo_curator.tasks.sentinels import FailedTask, NoneTask @@ -26,6 +34,57 @@ from nemo_curator.stages.base import ProcessingStage +FAILED_TASKS_DIR_ENV_VAR = "NEMO_CURATOR_FAILED_TASKS_DIR" + + +def _safe_filename_token(value: object) -> str: + return "".join(ch if ch.isalnum() or ch in "._-" else "_" for ch in str(value)) + + +def _write_failed_task_marker(marker_dir: Path, stage_name: str, task: FailedTask) -> None: + created_at = datetime.datetime.now(datetime.UTC) + timestamp = created_at.strftime("%Y%m%dT%H%M%S%fZ") + payload: dict[str, str | int] = { + "created_at": created_at.isoformat(), + "stage_name": stage_name, + "task_id": task.task_id, + "dataset_name": task.dataset_name, + "task_type": type(task).__name__, + "hostname": socket.gethostname(), + "pid": os.getpid(), + } + + marker_dir.mkdir(parents=True, exist_ok=True) + filename = ( + "failed_task_" + f"stage-{_safe_filename_token(stage_name)}_" + f"task-{_safe_filename_token(task.task_id)}_" + f"pid-{os.getpid()}_" + f"{timestamp}_{uuid.uuid4().hex}.json" + ) + final_path = marker_dir / filename + + tmp_path: Path | None = None + try: + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + dir=marker_dir, + prefix=f".{filename}.", + suffix=".tmp", + delete=False, + ) as tmp: + tmp_path = Path(tmp.name) + json.dump(payload, tmp, indent=2, sort_keys=True) + tmp.write("\n") + + os.replace(tmp_path, final_path) + except Exception: # noqa: BLE001 + if tmp_path is not None: + tmp_path.unlink(missing_ok=True) + raise + + @dataclass class NodeInfo: """Generic node information for setup_on_node calls across backends. @@ -95,6 +154,8 @@ def process_batch(self, tasks: list[Task]) -> list[Task]: # Guarantee every emitted task has a task_id (derived id, or uuid fallback). results = self._post_process_task_ids(tasks, results) + self._record_failed_tasks([r for r in results if isinstance(r, FailedTask)]) + # Sentinels never propagate to the next stage. results = [r for r in results if not isinstance(r, (NoneTask, FailedTask))] @@ -109,6 +170,18 @@ def process_batch(self, tasks: list[Task]) -> list[Task]: return results + def _record_failed_tasks(self, failed_tasks: list[FailedTask]) -> None: + marker_dir = os.environ.get(FAILED_TASKS_DIR_ENV_VAR) + if not marker_dir or not failed_tasks: + return + + marker_path = Path(marker_dir) + for task in failed_tasks: + try: + _write_failed_task_marker(marker_path, self.stage.name, task) + except Exception as e: # noqa: BLE001 + logger.warning(f"Failed to write FailedTask marker to {marker_path}: {e}") + def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Task | None]) -> list[Task]: """Assign a deterministic ``task_id`` to every emitted task. @@ -144,7 +217,7 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas (returning a flat list rather than a per-input slot) cannot be mapped positionally; if its length happens to equal the input length the 1:1 assumption may misattribute parents. That combination is unsupported - until per-slot sentinels (NoneTask/FailedTask) land in a later PR. + unless the stage preserves an unambiguous input→output mapping. """ is_source = getattr(self.stage, "is_source_stage", False) diff --git a/tests/backends/test_base_stage_adapter.py b/tests/backends/test_base_stage_adapter.py new file mode 100644 index 0000000000..911ebc18e1 --- /dev/null +++ b/tests/backends/test_base_stage_adapter.py @@ -0,0 +1,81 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from dataclasses import dataclass + +from nemo_curator.backends.base import FAILED_TASKS_DIR_ENV_VAR, BaseStageAdapter +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.tasks import Task +from nemo_curator.tasks.sentinels import FailedTask + + +@dataclass +class _FailedStage(ProcessingStage[Task, Task]): + name: str = "failed" + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def process(self, task: Task) -> Task: + return FailedTask() + + +@dataclass +class _SimpleTask(Task[list[int]]): + @property + def num_items(self) -> int: + return 0 + + def validate(self) -> bool: + return True + + +def _task(task_id: str = "") -> _SimpleTask: + task = _SimpleTask(dataset_name="d", data=[]) + task.task_id = task_id + return task + + +class TestBaseStageAdapter: + def test_process_batch_writes_failed_task_marker_when_enabled(self, tmp_path, monkeypatch) -> None: + marker_dir = tmp_path / "failed-tasks" + monkeypatch.setenv(FAILED_TASKS_DIR_ENV_VAR, str(marker_dir)) + + output = BaseStageAdapter(_FailedStage()).process_batch([_task("0_7")]) + + assert output == [] + marker_files = list(marker_dir.glob("failed_task_*.json")) + assert len(marker_files) == 1 + + payload = json.loads(marker_files[0].read_text()) + assert payload["stage_name"] == "failed" + assert payload["task_id"] == "0_7_0" + assert payload["dataset_name"] == "failed" + assert payload["task_type"] == "FailedTask" + assert isinstance(payload["hostname"], str) + assert isinstance(payload["pid"], int) + assert isinstance(payload["created_at"], str) + + def test_process_batch_does_not_write_failed_task_marker_by_default(self, tmp_path, monkeypatch) -> None: + marker_dir = tmp_path / "failed-tasks" + monkeypatch.delenv(FAILED_TASKS_DIR_ENV_VAR, raising=False) + + output = BaseStageAdapter(_FailedStage()).process_batch([_task("0_7")]) + + assert output == [] + assert not marker_dir.exists() diff --git a/tutorials/slurm/README.md b/tutorials/slurm/README.md index 3a47a9bd95..2ab43f548a 100644 --- a/tutorials/slurm/README.md +++ b/tutorials/slurm/README.md @@ -203,7 +203,7 @@ source .venv/bin/activate pip install -e . ``` -Make sure `CURATOR_DIR`, `INPUT_DIR`, and `OUTPUT_DIR` are visible from every compute node, either because they are on a shared filesystem or because you set `CONTAINER_MOUNTS` to expose the right host paths inside the container. +Make sure `CURATOR_DIR`, `INPUT_DIR`, `OUTPUT_DIR`, and `CHECKPOINT_PATH` are visible from every compute node, either because they are on a shared filesystem or because you set `CONTAINER_MOUNTS` to expose the right host paths inside the container. ### 2. Submit a JSONL array job @@ -287,9 +287,19 @@ ${CHECKPOINT_PATH:-$OUTPUT_DIR}/.nemo_curator_metadata/.slurm_array_retry/ In other words, retries are tracked at `checkpoint_path/.nemo_curator_metadata/.slurm_array_retry/`. -If the shard completes successfully, that shard's matching retry manifests are removed. If the process fails, is preempted, or reaches the Slurm time limit before cleanup runs, the manifest remains in the retry directory. Caught Python exceptions update the manifest with `status="failed"` and the error message; hard termination may leave `status="pending"`, which should still be treated as retryable after the original Slurm array has finished. +`submit_array.sh` also sets `NEMO_CURATOR_FAILED_TASKS_DIR` so the backend can record `FailedTask` sentinels produced by stages. By default, each Slurm job gets its own marker directory: -Retry manifests are uniquely named JSON files written with an atomic rename, so multiple array tasks can write to the same retry directory without coordinating through a shared database. +```bash +${CHECKPOINT_PATH:-$OUTPUT_DIR}/.nemo_curator_metadata/.failed_tasks/slurm_job_${SLURM_JOB_ID}/array_task_${SLURM_ARRAY_TASK_ID}/shard_${SHARD_INDEX}/ +``` + +The marker directory is per Slurm job/task. For `--nodes>1`, all workers in that one array task share the same directory, so any worker can record a `FailedTask` and the driver can inspect the directory after `pipeline.run()` returns. + +If the shard completes successfully and no `FailedTask` markers were written, that shard's matching retry manifests are removed. If the process fails, is preempted, or reaches the Slurm time limit before cleanup runs, the manifest remains in the retry directory. Caught Python exceptions update the manifest with `status="failed"` and the error message; hard termination may leave `status="pending"`, which should still be treated as retryable after the original Slurm array has finished. + +If the pipeline completes without raising but one or more `FailedTask` marker files exist, `array_pipeline.py` keeps the retry manifest and updates it with `status="failed_tasks"` plus marker metadata. The retry collection below only needs to read retry manifests; it does not need to inspect the `FailedTask` marker directories directly. + +Retry manifests and `FailedTask` markers are uniquely named JSON files written with an atomic rename, so multiple array tasks or workers can write without coordinating through a shared database. Each manifest records the failed `shard_index`, plus the `total_shards` and `minimum_shard_index` values used for the original run. To retry only the failed shards, rebuild a Slurm array list from those manifests and preserve the original shard settings. For example, using `jq`: @@ -323,7 +333,7 @@ sbatch --array="${FAILED_SHARDS}" tutorials/slurm/submit_array.sh The `TOTAL_SHARDS` override is important. On a retry array like `--array=3,17,42`, Slurm sets `SLURM_ARRAY_TASK_COUNT=3`, but the data was originally assigned using the full logical shard count. Reusing the original `TOTAL_SHARDS` keeps `hash(partition) % total_shards` identical to the first run. -Run this retry collection after the original Slurm array has finished, otherwise still-running tasks will still have pending manifests. Use one `CHECKPOINT_PATH` per logical array run, or move old retry manifests aside after building `FAILED_SHARDS`, so later retries do not include failures that already succeeded. +Run this retry collection after the original Slurm array has finished, otherwise still-running tasks will still have pending manifests. Use one `CHECKPOINT_PATH` per logical array run, or move old retry manifests aside after building `FAILED_SHARDS`, so later retries do not include failures that already succeeded. If you override `NEMO_CURATOR_FAILED_TASKS_DIR`, keep it unique per Slurm job or clean it before reuse; stale `FailedTask` markers make an otherwise successful shard look retryable. --- diff --git a/tutorials/slurm/array_pipeline.py b/tutorials/slurm/array_pipeline.py index 151dc42322..3fcd752578 100644 --- a/tutorials/slurm/array_pipeline.py +++ b/tutorials/slurm/array_pipeline.py @@ -75,6 +75,9 @@ METADATA_DIRNAME = ".nemo_curator_metadata" SLURM_ARRAY_RETRY_DIRNAME = ".slurm_array_retry" +FAILED_TASKS_DIR_ENV_VAR = "NEMO_CURATOR_FAILED_TASKS_DIR" +FAILED_TASK_MARKER_PATTERN = "failed_task_*.json" +MAX_FAILED_TASK_MARKERS_IN_MANIFEST = 10 def _safe_token(value: object) -> str: @@ -130,6 +133,7 @@ def _retry_manifest_payload( # noqa: PLR0913 status: str, created_at: datetime.datetime, error: BaseException | None = None, + extra: dict[str, object] | None = None, ) -> dict[str, object]: payload = { "shard_index": shard_index, @@ -149,6 +153,9 @@ def _retry_manifest_payload( # noqa: PLR0913 payload["error_type"] = type(error).__name__ payload["error"] = str(error) + if extra is not None: + payload.update(extra) + return payload @@ -160,6 +167,7 @@ def write_retry_manifest( # noqa: PLR0913 status: str, error: BaseException | None = None, manifest_file: Path | None = None, + extra: dict[str, object] | None = None, ) -> Path: """Write a retry manifest using a unique name and atomic rename.""" retry_dir = Path(checkpoint_path, METADATA_DIRNAME, SLURM_ARRAY_RETRY_DIRNAME).absolute() @@ -173,6 +181,7 @@ def write_retry_manifest( # noqa: PLR0913 status=status, created_at=created_at, error=error, + extra=extra, ) if manifest_file is None: @@ -212,6 +221,30 @@ def write_retry_manifest( # noqa: PLR0913 return manifest_file +def failed_task_marker_files() -> list[Path]: + """Return FailedTask marker files written by BaseStageAdapter for this job.""" + failed_tasks_dir = os.environ.get(FAILED_TASKS_DIR_ENV_VAR) + if not failed_tasks_dir: + return [] + + marker_dir = Path(failed_tasks_dir).absolute() + if not marker_dir.exists(): + return [] + + return sorted(path for path in marker_dir.glob(FAILED_TASK_MARKER_PATTERN) if path.is_file()) + + +def failed_task_manifest_metadata(marker_files: list[Path]) -> dict[str, object]: + marker_dir = marker_files[0].parent if marker_files else os.environ.get(FAILED_TASKS_DIR_ENV_VAR) + sample_marker_files = marker_files[:MAX_FAILED_TASK_MARKERS_IN_MANIFEST] + return { + "failed_task_marker_dir": str(marker_dir), + "failed_task_marker_count": len(marker_files), + "failed_task_marker_files": [str(path) for path in sample_marker_files], + "failed_task_marker_files_truncated": len(marker_files) > len(sample_marker_files), + } + + def remove_retry_manifests( checkpoint_path: str, shard_index: int | None, @@ -383,6 +416,30 @@ def main() -> None: pipeline.run() + failed_task_markers = failed_task_marker_files() + if failed_task_markers: + if should_manage_retry_manifest: + manifest_file = write_retry_manifest( + checkpoint_path=args.checkpoint_path, + shard_index=shard_index, + total_shards=total_shards, + minimum_shard_index=minimum_shard_index, + status="failed_tasks", + manifest_file=retry_manifest_file, + extra=failed_task_manifest_metadata(failed_task_markers), + ) + logger.warning( + "Pipeline completed without raising, but found " + f"{len(failed_task_markers)} FailedTask marker(s). " + f"Keeping retry manifest at {manifest_file}." + ) + else: + logger.warning( + "Pipeline completed without raising, but found " + f"{len(failed_task_markers)} FailedTask marker(s)." + ) + return + if should_manage_retry_manifest: try: remove_retry_manifests( diff --git a/tutorials/slurm/submit_array.sh b/tutorials/slurm/submit_array.sh index df9dc75a30..8060a38586 100644 --- a/tutorials/slurm/submit_array.sh +++ b/tutorials/slurm/submit_array.sh @@ -61,6 +61,8 @@ OUTPUT_DIR="${OUTPUT_DIR:-/path/to/your/output/directory}" # Retry manifests are written under: # ${CHECKPOINT_PATH}/.nemo_curator_metadata/.slurm_array_retry/ +# FailedTask marker files are written under: +# ${CHECKPOINT_PATH}/.nemo_curator_metadata/.failed_tasks//// # Defaults to OUTPUT_DIR. If you override CONTAINER_MOUNTS, make sure it still # includes CHECKPOINT_PATH. CHECKPOINT_PATH="${CHECKPOINT_PATH:-${OUTPUT_DIR}}" @@ -98,6 +100,13 @@ TOTAL_SHARDS="${TOTAL_SHARDS:-${SLURM_ARRAY_TASK_COUNT}}" # Leave at 0 for --array=0-N. Set to the array start value for --array=K-N. MINIMUM_SHARD_INDEX="${MINIMUM_SHARD_INDEX:-0}" +# BaseStageAdapter writes one marker JSON per FailedTask when this env var is +# set. Keep one directory per Slurm array job/task so retries can inspect just +# the FailedTasks from that job, while multi-node workers for that same job +# write into the same directory. +FAILED_TASKS_DIR="${FAILED_TASKS_DIR:-${CHECKPOINT_PATH}/.nemo_curator_metadata/.failed_tasks/slurm_job_${SLURM_JOB_ID:-local}/array_task_${SLURM_ARRAY_TASK_ID:-local}/shard_${SHARD_INDEX}}" +NEMO_CURATOR_FAILED_TASKS_DIR="${NEMO_CURATOR_FAILED_TASKS_DIR:-${FAILED_TASKS_DIR}}" + # Use SlurmRayClient only when this array task spans multiple nodes. Single-node # array tasks can use the regular RayClient. NUM_NODES="${SLURM_JOB_NUM_NODES:-${SLURM_NNODES:-1}}" @@ -106,12 +115,14 @@ if (( NUM_NODES > 1 )); then USE_SLURM_RAY=1 fi -mkdir -p "${CURATOR_DIR}/logs" "${OUTPUT_DIR}" "${CHECKPOINT_PATH}" +mkdir -p "${CURATOR_DIR}/logs" "${OUTPUT_DIR}" "${CHECKPOINT_PATH}" "${NEMO_CURATOR_FAILED_TASKS_DIR}" export CURATOR_DIR export INPUT_DIR export OUTPUT_DIR export CHECKPOINT_PATH +export FAILED_TASKS_DIR +export NEMO_CURATOR_FAILED_TASKS_DIR export INPUT_FILE_TYPE export OUTPUT_FILE_TYPE export FILES_PER_PARTITION @@ -137,6 +148,7 @@ echo " Container : ${CONTAINER_IMAGE}" echo " Mounts : ${CONTAINER_MOUNTS}" echo " Dir : ${CURATOR_DIR}" echo " Checkpoint path: ${CHECKPOINT_PATH}" +echo " FailedTask dir : ${NEMO_CURATOR_FAILED_TASKS_DIR}" echo "==================================================" # Each array task processes only the file partitions hashed to its From 717edacac54c76e13fa0687e6bc88afab394d439 Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Thu, 11 Jun 2026 13:49:09 -0700 Subject: [PATCH 10/11] ruff Signed-off-by: Sarah Yurick --- nemo_curator/backends/base.py | 2 +- tests/backends/test_base_stage_adapter.py | 11 +++++++++-- tutorials/slurm/array_pipeline.py | 5 ++--- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/nemo_curator/backends/base.py b/nemo_curator/backends/base.py index bff3667dba..bb2e8431e5 100644 --- a/nemo_curator/backends/base.py +++ b/nemo_curator/backends/base.py @@ -79,7 +79,7 @@ def _write_failed_task_marker(marker_dir: Path, stage_name: str, task: FailedTas tmp.write("\n") os.replace(tmp_path, final_path) - except Exception: # noqa: BLE001 + except Exception: if tmp_path is not None: tmp_path.unlink(missing_ok=True) raise diff --git a/tests/backends/test_base_stage_adapter.py b/tests/backends/test_base_stage_adapter.py index 911ebc18e1..d139cbe46a 100644 --- a/tests/backends/test_base_stage_adapter.py +++ b/tests/backends/test_base_stage_adapter.py @@ -14,6 +14,9 @@ import json from dataclasses import dataclass +from pathlib import Path + +from pytest import MonkeyPatch from nemo_curator.backends.base import FAILED_TASKS_DIR_ENV_VAR, BaseStageAdapter from nemo_curator.stages.base import ProcessingStage @@ -52,7 +55,9 @@ def _task(task_id: str = "") -> _SimpleTask: class TestBaseStageAdapter: - def test_process_batch_writes_failed_task_marker_when_enabled(self, tmp_path, monkeypatch) -> None: + def test_process_batch_writes_failed_task_marker_when_enabled( + self, tmp_path: Path, monkeypatch: MonkeyPatch + ) -> None: marker_dir = tmp_path / "failed-tasks" monkeypatch.setenv(FAILED_TASKS_DIR_ENV_VAR, str(marker_dir)) @@ -71,7 +76,9 @@ def test_process_batch_writes_failed_task_marker_when_enabled(self, tmp_path, mo assert isinstance(payload["pid"], int) assert isinstance(payload["created_at"], str) - def test_process_batch_does_not_write_failed_task_marker_by_default(self, tmp_path, monkeypatch) -> None: + def test_process_batch_does_not_write_failed_task_marker_by_default( + self, tmp_path: Path, monkeypatch: MonkeyPatch + ) -> None: marker_dir = tmp_path / "failed-tasks" monkeypatch.delenv(FAILED_TASKS_DIR_ENV_VAR, raising=False) diff --git a/tutorials/slurm/array_pipeline.py b/tutorials/slurm/array_pipeline.py index 3fcd752578..44ad32069f 100644 --- a/tutorials/slurm/array_pipeline.py +++ b/tutorials/slurm/array_pipeline.py @@ -319,7 +319,7 @@ def build_pipeline( # noqa: PLR0913 return pipeline -def main() -> None: +def main() -> None: # noqa: PLR0915 parser = argparse.ArgumentParser(description="Slurm array file-partitioning demo") parser.add_argument("--input-dir", required=True, help="Directory containing input files") parser.add_argument( @@ -435,8 +435,7 @@ def main() -> None: ) else: logger.warning( - "Pipeline completed without raising, but found " - f"{len(failed_task_markers)} FailedTask marker(s)." + f"Pipeline completed without raising, but found {len(failed_task_markers)} FailedTask marker(s)." ) return From ebba73e331da8666fc7ab724f35bb4930fc7e112 Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Thu, 11 Jun 2026 14:03:36 -0700 Subject: [PATCH 11/11] greptile comments Signed-off-by: Sarah Yurick --- nemo_curator/stages/file_partitioning.py | 2 +- tutorials/slurm/array_pipeline.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo_curator/stages/file_partitioning.py b/nemo_curator/stages/file_partitioning.py index 39796b15a6..b45d11a468 100644 --- a/nemo_curator/stages/file_partitioning.py +++ b/nemo_curator/stages/file_partitioning.py @@ -86,7 +86,7 @@ class FilePartitioningStage(ProcessingStage[EmptyTask, FileGroupTask]): total_shards: int | str | None = None The total number of shards. Can be an integer representing the total number of shards or a string representing the environment variable name. Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_COUNT environment variable. - minimum_shard_index: int = 0 + minimum_shard_index: int | str = 0 The minimum shard index to process. Can be an integer representing the minimum shard index or a string representing the environment variable name. Only used if enable_array_partitioning is True. If not provided, it will be set to 0. """ diff --git a/tutorials/slurm/array_pipeline.py b/tutorials/slurm/array_pipeline.py index 44ad32069f..49fc070658 100644 --- a/tutorials/slurm/array_pipeline.py +++ b/tutorials/slurm/array_pipeline.py @@ -68,6 +68,7 @@ from loguru import logger +from nemo_curator.backends.base import FAILED_TASKS_DIR_ENV_VAR from nemo_curator.core.client import RayClient, SlurmRayClient from nemo_curator.pipeline import Pipeline from nemo_curator.stages.text.io.reader import JsonlReader, ParquetReader @@ -75,7 +76,6 @@ METADATA_DIRNAME = ".nemo_curator_metadata" SLURM_ARRAY_RETRY_DIRNAME = ".slurm_array_retry" -FAILED_TASKS_DIR_ENV_VAR = "NEMO_CURATOR_FAILED_TASKS_DIR" FAILED_TASK_MARKER_PATTERN = "failed_task_*.json" MAX_FAILED_TASK_MARKERS_IN_MANIFEST = 10