diff --git a/nemo_curator/backends/base.py b/nemo_curator/backends/base.py index 94236f2cfc..bb2e8431e5 100644 --- a/nemo_curator/backends/base.py +++ b/nemo_curator/backends/base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,19 +12,79 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime +import json +import os +import socket +import tempfile import uuid from abc import ABC, abstractmethod from dataclasses import dataclass +from pathlib import Path from typing import TYPE_CHECKING, Any +from loguru import logger + from nemo_curator.core.utils import ignore_ray_head_node from nemo_curator.tasks import Task +from nemo_curator.tasks.sentinels import FailedTask, NoneTask from nemo_curator.utils.performance_utils import StageTimer if TYPE_CHECKING: from nemo_curator.stages.base import ProcessingStage +FAILED_TASKS_DIR_ENV_VAR = "NEMO_CURATOR_FAILED_TASKS_DIR" + + +def _safe_filename_token(value: object) -> str: + return "".join(ch if ch.isalnum() or ch in "._-" else "_" for ch in str(value)) + + +def _write_failed_task_marker(marker_dir: Path, stage_name: str, task: FailedTask) -> None: + created_at = datetime.datetime.now(datetime.UTC) + timestamp = created_at.strftime("%Y%m%dT%H%M%S%fZ") + payload: dict[str, str | int] = { + "created_at": created_at.isoformat(), + "stage_name": stage_name, + "task_id": task.task_id, + "dataset_name": task.dataset_name, + "task_type": type(task).__name__, + "hostname": socket.gethostname(), + "pid": os.getpid(), + } + + marker_dir.mkdir(parents=True, exist_ok=True) + filename = ( + "failed_task_" + f"stage-{_safe_filename_token(stage_name)}_" + f"task-{_safe_filename_token(task.task_id)}_" + f"pid-{os.getpid()}_" + f"{timestamp}_{uuid.uuid4().hex}.json" + ) + final_path = marker_dir / filename + + tmp_path: Path | None = None + try: + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + dir=marker_dir, + prefix=f".{filename}.", + suffix=".tmp", + delete=False, + ) as tmp: + tmp_path = Path(tmp.name) + json.dump(payload, tmp, indent=2, sort_keys=True) + tmp.write("\n") + + os.replace(tmp_path, final_path) + except Exception: + if tmp_path is not None: + tmp_path.unlink(missing_ok=True) + raise + + @dataclass class NodeInfo: """Generic node information for setup_on_node calls across backends. @@ -85,9 +145,20 @@ def process_batch(self, tasks: list[Task]) -> list[Task]: # Use the batch processing logic results = self.stage.process_batch(tasks) + # A returned ``None`` ("filter this slot") becomes a NoneTask so every + # output is a real Task that gets a task_id. Sentinels (NoneTask / + # FailedTask) carry no identity and are stripped again before this + # method returns. + results = [NoneTask() if r is None else r for r in results] + # Guarantee every emitted task has a task_id (derived id, or uuid fallback). results = self._post_process_task_ids(tasks, results) + self._record_failed_tasks([r for r in results if isinstance(r, FailedTask)]) + + # Sentinels never propagate to the next stage. + results = [r for r in results if not isinstance(r, (NoneTask, FailedTask))] + # Log performance stats and add to result tasks _, stage_perf_stats = self._timer.log_stats() # Consume and attach any custom metrics recorded by the stage during this call @@ -99,6 +170,18 @@ def process_batch(self, tasks: list[Task]) -> list[Task]: return results + def _record_failed_tasks(self, failed_tasks: list[FailedTask]) -> None: + marker_dir = os.environ.get(FAILED_TASKS_DIR_ENV_VAR) + if not marker_dir or not failed_tasks: + return + + marker_path = Path(marker_dir) + for task in failed_tasks: + try: + _write_failed_task_marker(marker_path, self.stage.name, task) + except Exception as e: # noqa: BLE001 + logger.warning(f"Failed to write FailedTask marker to {marker_path}: {e}") + def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Task | None]) -> list[Task]: """Assign a deterministic ``task_id`` to every emitted task. @@ -134,7 +217,7 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas (returning a flat list rather than a per-input slot) cannot be mapped positionally; if its length happens to equal the input length the 1:1 assumption may misattribute parents. That combination is unsupported - until per-slot sentinels (NoneTask/FailedTask) land in a later PR. + unless the stage preserves an unambiguous input→output mapping. """ is_source = getattr(self.stage, "is_source_stage", False) diff --git a/nemo_curator/stages/audio/common.py b/nemo_curator/stages/audio/common.py index bedebd066e..d2bc2db5da 100644 --- a/nemo_curator/stages/audio/common.py +++ b/nemo_curator/stages/audio/common.py @@ -192,6 +192,13 @@ class ManifestReader(CompositeStage[EmptyTask, AudioTask]): blocksize: Target size per partition (e.g., "100MB"). Ignored if files_per_partition is set. file_extensions: File extensions to filter. Defaults to [".jsonl", ".json"]. storage_options: Storage options for cloud paths (S3, GCS credentials, endpoints). + enable_array_partitioning: Whether to enable array partitioning (e.g., partition files across multiple Slurm jobs). + shard_index: The index of the shard to process. Can be an integer representing the shard index or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_ID environment variable. + total_shards: The total number of shards. Can be an integer representing the total number of shards or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_COUNT environment variable. + minimum_shard_index: The minimum shard index to process. Can be an integer representing the minimum shard index or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to 0. """ manifest_path: str | list[str] @@ -200,6 +207,10 @@ class ManifestReader(CompositeStage[EmptyTask, AudioTask]): blocksize: int | str | None = None file_extensions: list[str] = field(default_factory=lambda: [".jsonl", ".json"]) storage_options: dict[str, Any] | None = None + enable_array_partitioning: bool = False + shard_index: int | str | None = None + total_shards: int | str | None = None + minimum_shard_index: int | str = 0 def __post_init__(self) -> None: super().__init__() @@ -215,6 +226,10 @@ def decompose(self) -> list[ProcessingStage]: blocksize=self.blocksize, file_extensions=self.file_extensions, storage_options=self.storage_options, + enable_array_partitioning=self.enable_array_partitioning, + shard_index=self.shard_index, + total_shards=self.total_shards, + minimum_shard_index=self.minimum_shard_index, ), ManifestReaderStage(), ] diff --git a/nemo_curator/stages/file_partitioning.py b/nemo_curator/stages/file_partitioning.py index 75d6906501..b45d11a468 100644 --- a/nemo_curator/stages/file_partitioning.py +++ b/nemo_curator/stages/file_partitioning.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import hashlib +import os from dataclasses import dataclass from typing import Any @@ -29,6 +31,26 @@ ) +def _get_int_or_env_var(input_value: int | str | None, default_name: str | None = None) -> int: + if type(input_value) is int: + return input_value + + env_var = input_value if type(input_value) is str else default_name + if env_var is None: + msg = f"Invalid input value: {input_value}, must be an integer or a string" + raise ValueError(msg) + + env_value = os.environ.get(env_var) + if env_value is None: + msg = f"Environment variable {env_var} is not set" + raise ValueError(msg) + try: + return int(env_value) + except ValueError as e: + msg = f"Environment variable {env_var} must contain an integer, got {env_value!r}" + raise ValueError(msg) from e + + @dataclass class FilePartitioningStage(ProcessingStage[EmptyTask, FileGroupTask]): """Stage that partitions input file paths into FileGroupTasks. @@ -55,6 +77,18 @@ class FilePartitioningStage(ProcessingStage[EmptyTask, FileGroupTask]): Storage options to pass to the file system. limit: int | None = None Maximum number of partitions to create. + enable_array_partitioning: bool = False + Whether to enable array partitioning (e.g., partition files across multiple Slurm jobs). + Intended for use with Slurm job arrays via the `sbatch --array` option. + shard_index: int | str | None = None + The index of the shard to process. Can be an integer representing the shard index or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_ID environment variable. + total_shards: int | str | None = None + The total number of shards. Can be an integer representing the total number of shards or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_COUNT environment variable. + minimum_shard_index: int | str = 0 + The minimum shard index to process. Can be an integer representing the minimum shard index or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to 0. """ file_paths: str | list[str] @@ -63,6 +97,10 @@ class FilePartitioningStage(ProcessingStage[EmptyTask, FileGroupTask]): file_extensions: list[str] | None = None storage_options: dict[str, Any] | None = None limit: int | None = None + enable_array_partitioning: bool = False + shard_index: int | str | None = None + total_shards: int | str | None = None + minimum_shard_index: int | str = 0 name: str = "file_partitioning" def __post_init__(self): @@ -91,6 +129,25 @@ def __post_init__(self): self.resources = Resources(cpus=0.5) + if self.enable_array_partitioning: + self.shard_index = _get_int_or_env_var(self.shard_index, "SLURM_ARRAY_TASK_ID") + self.total_shards = _get_int_or_env_var(self.total_shards, "SLURM_ARRAY_TASK_COUNT") + self.minimum_shard_index = _get_int_or_env_var(self.minimum_shard_index) + if self.total_shards <= 0: + msg = f"total_shards must be greater than 0, got {self.total_shards}" + raise ValueError(msg) + min_assignable_shard_index = self.minimum_shard_index + max_assignable_shard_index = self.minimum_shard_index + self.total_shards - 1 + if not min_assignable_shard_index <= self.shard_index <= max_assignable_shard_index: + logger.warning( + "shard_index={} is outside the assignable shard range [{}, {}]. " + "This task will not receive any partitions.", + self.shard_index, + min_assignable_shard_index, + max_assignable_shard_index, + ) + self.name = "array_file_partitioning" + def inputs(self) -> tuple[list[str], list[str]]: return [], [] @@ -106,7 +163,7 @@ def ray_stage_spec(self) -> dict[str, Any]: def xenna_stage_spec(self) -> dict[str, Any]: return {"num_workers_per_node": 1} - def process(self, _: EmptyTask) -> list[FileGroupTask]: + def _process(self, _: EmptyTask) -> list[FileGroupTask]: """Process the initial task to create file group tasks. This stage expects a simple Task with file paths information @@ -189,6 +246,35 @@ def process(self, _: EmptyTask) -> list[FileGroupTask]: logger.info(f"Created {len(tasks)} file groups from {len(files)} files") return tasks + def _process_array(self, task: EmptyTask) -> list[FileGroupTask]: + all_tasks = self._process(task) + assigned_tasks = [] + + for ft in all_tasks: + source_files = list(ft._metadata.get("source_files") or ft.data) + # Hash the source files to get a unique identifier for the partition + digest = hashlib.sha256("|".join(sorted(source_files)).encode("utf-8")).hexdigest() + # Assign the partition to the shard + assigned = int(digest[:16], 16) % self.total_shards + # Add the minimum shard index to the assigned shard index + assigned += self.minimum_shard_index + # Add the partition to the assigned tasks + if assigned == self.shard_index: + assigned_tasks.append(ft) + + msg = f"Shard {self.shard_index}/{self.total_shards}: assigned {len(assigned_tasks)} of {len(all_tasks)} partitions" + if len(assigned_tasks) == 0 and len(all_tasks) > 0: + logger.warning(msg) + else: + logger.info(msg) + return assigned_tasks + + def process(self, task: EmptyTask) -> list[FileGroupTask]: + if self.enable_array_partitioning: + return self._process_array(task) + else: + return self._process(task) + def _get_file_list_with_sizes(self, sort_by_size: bool = True) -> list[tuple[str, int]]: """ Get the list of files to process. diff --git a/nemo_curator/stages/interleaved/io/reader.py b/nemo_curator/stages/interleaved/io/reader.py index 6ba7790b6c..57a8b638fa 100644 --- a/nemo_curator/stages/interleaved/io/reader.py +++ b/nemo_curator/stages/interleaved/io/reader.py @@ -43,6 +43,10 @@ class InterleavedWebdatasetReader(CompositeStage[EmptyTask, InterleavedBatch]): blocksize: int | str | None = None max_batch_bytes: int | None = None read_kwargs: dict[str, Any] = field(default_factory=dict) + enable_array_partitioning: bool = False + shard_index: int | str | None = None + total_shards: int | str | None = None + minimum_shard_index: int | str = 0 materialize_on_read: bool = False file_extensions: list[str] = field(default_factory=lambda: list(DEFAULT_WEBDATASET_EXTENSIONS)) json_extensions: list[str] = field(default_factory=lambda: list(DEFAULT_JSON_EXTENSIONS)) @@ -68,6 +72,10 @@ def decompose(self) -> list: blocksize=self.blocksize, file_extensions=self.file_extensions, storage_options=self.storage_options, + enable_array_partitioning=self.enable_array_partitioning, + shard_index=self.shard_index, + total_shards=self.total_shards, + minimum_shard_index=self.minimum_shard_index, ), InterleavedWebdatasetReaderStage( read_kwargs=self.read_kwargs, @@ -96,6 +104,10 @@ class InterleavedParquetReader(CompositeStage[EmptyTask, InterleavedBatch]): fields: tuple[str, ...] | None = None max_batch_bytes: int | None = None read_kwargs: dict[str, Any] = field(default_factory=dict) + enable_array_partitioning: bool = False + shard_index: int | str | None = None + total_shards: int | str | None = None + minimum_shard_index: int | str = 0 schema: pa.Schema | None = None schema_overrides: dict[str, pa.DataType] | None = None file_extensions: list[str] = field(default_factory=lambda: [".parquet"]) @@ -113,6 +125,10 @@ def decompose(self) -> list: blocksize=self.blocksize, file_extensions=self.file_extensions, storage_options=self.storage_options, + enable_array_partitioning=self.enable_array_partitioning, + shard_index=self.shard_index, + total_shards=self.total_shards, + minimum_shard_index=self.minimum_shard_index, ), InterleavedParquetReaderStage( read_kwargs=self.read_kwargs, diff --git a/nemo_curator/stages/text/deduplication/removal_workflow.py b/nemo_curator/stages/text/deduplication/removal_workflow.py index 078b2014d3..af6fd260bd 100644 --- a/nemo_curator/stages/text/deduplication/removal_workflow.py +++ b/nemo_curator/stages/text/deduplication/removal_workflow.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nemo_curator/stages/text/deduplication/semantic.py b/nemo_curator/stages/text/deduplication/semantic.py index 34eaa7f651..11b67b2a4d 100644 --- a/nemo_curator/stages/text/deduplication/semantic.py +++ b/nemo_curator/stages/text/deduplication/semantic.py @@ -76,6 +76,10 @@ class TextSemanticDeduplicationWorkflow: embedding_vllm_init_kwargs: dict[str, Any] | None = None hf_token: str | None = None model_cache_dir: str | None = None + enable_array_partitioning: bool = False + shard_index: int | str | None = None + total_shards: int | str | None = None + minimum_shard_index: int | str = 0 # Semantic deduplication parameters n_clusters: int = 100 id_field: str = CURATOR_DEDUP_ID_STR @@ -132,6 +136,13 @@ class TextSemanticDeduplicationWorkflow: embedding_vllm_init_kwargs: Additional kwargs passed to vLLM's LLM initializer hf_token: HuggingFace token for private models model_cache_dir: Directory to cache model weights + enable_array_partitioning: Whether to enable array partitioning (e.g., partition files across multiple Slurm jobs). + shard_index: The index of the shard to process. Can be an integer representing the shard index or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_ID environment variable. + total_shards: The total number of shards. Can be an integer representing the total number of shards or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_COUNT environment variable. + minimum_shard_index: The minimum shard index to process. Can be an integer representing the minimum shard index or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to 0. # Semantic deduplication parameters n_clusters: Number of clusters for K-means @@ -252,6 +263,10 @@ def _run_embedding_generation(self, executor: BaseExecutor) -> list[Task]: file_extensions=self.input_file_extensions, _generate_ids=self.use_id_generator, read_kwargs=self.read_kwargs, + enable_array_partitioning=self.enable_array_partitioning, + shard_index=self.shard_index, + total_shards=self.total_shards, + minimum_shard_index=self.minimum_shard_index, ) elif self.input_filetype == "parquet": reader = ParquetReader( @@ -266,6 +281,10 @@ def _run_embedding_generation(self, executor: BaseExecutor) -> list[Task]: file_extensions=self.input_file_extensions, read_kwargs=self.read_kwargs, _generate_ids=self.use_id_generator, + enable_array_partitioning=self.enable_array_partitioning, + shard_index=self.shard_index, + total_shards=self.total_shards, + minimum_shard_index=self.minimum_shard_index, ) else: msg = f"Input filetype {self.input_filetype} not supported yet" diff --git a/nemo_curator/stages/text/io/reader/jsonl.py b/nemo_curator/stages/text/io/reader/jsonl.py index 7ae4c81003..9e107f2888 100644 --- a/nemo_curator/stages/text/io/reader/jsonl.py +++ b/nemo_curator/stages/text/io/reader/jsonl.py @@ -94,6 +94,10 @@ class JsonlReader(CompositeStage[EmptyTask, DocumentBatch]): blocksize: int | str | None = None fields: list[str] | None = None # If specified, only read these columns read_kwargs: dict[str, Any] | None = None + enable_array_partitioning: bool = False + shard_index: int | str | None = None + total_shards: int | str | None = None + minimum_shard_index: int | str = 0 task_type: Literal["document", "image", "video", "audio"] = "document" file_extensions: list[str] = field(default_factory=lambda: FILETYPE_TO_DEFAULT_EXTENSIONS["jsonl"]) _generate_ids: bool = False @@ -121,6 +125,10 @@ def decompose(self) -> list[JsonlReaderStage]: storage_options=self.read_kwargs.get("storage_options", None) if self.read_kwargs is not None else None, + enable_array_partitioning=self.enable_array_partitioning, + shard_index=self.shard_index, + total_shards=self.total_shards, + minimum_shard_index=self.minimum_shard_index, ), JsonlReaderStage( fields=self.fields, diff --git a/nemo_curator/stages/text/io/reader/parquet.py b/nemo_curator/stages/text/io/reader/parquet.py index 0654b0f62c..47499e8610 100644 --- a/nemo_curator/stages/text/io/reader/parquet.py +++ b/nemo_curator/stages/text/io/reader/parquet.py @@ -79,6 +79,10 @@ class ParquetReader(CompositeStage[EmptyTask, DocumentBatch]): blocksize: int | str | None = None fields: list[str] | None = None # If specified, only read these columns read_kwargs: dict[str, Any] | None = None + enable_array_partitioning: bool = False + shard_index: int | str | None = None + total_shards: int | str | None = None + minimum_shard_index: int | str = 0 file_extensions: list[str] = field(default_factory=lambda: FILETYPE_TO_DEFAULT_EXTENSIONS["parquet"]) task_type: Literal["document", "image", "video", "audio"] = "document" _generate_ids: bool = False @@ -105,6 +109,10 @@ def decompose(self) -> list[ParquetReaderStage]: blocksize=self.blocksize, file_extensions=self.file_extensions, storage_options=self.read_kwargs.get("storage_options", {}) if self.read_kwargs is not None else None, + enable_array_partitioning=self.enable_array_partitioning, + shard_index=self.shard_index, + total_shards=self.total_shards, + minimum_shard_index=self.minimum_shard_index, ), # Second stage: process file groups into document batches ParquetReaderStage( diff --git a/nemo_curator/stages/video/io/video_reader.py b/nemo_curator/stages/video/io/video_reader.py index 9d5300fb5e..5a0c75501f 100644 --- a/nemo_curator/stages/video/io/video_reader.py +++ b/nemo_curator/stages/video/io/video_reader.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -244,11 +244,22 @@ class VideoReader(CompositeStage[EmptyTask, VideoTask]): input_video_path: Path to the directory containing video files video_limit: Maximum number of videos to process (None for unlimited) verbose: Whether to enable verbose logging during download/processing + enable_array_partitioning: Whether to enable array partitioning (e.g., partition files across multiple Slurm jobs). + shard_index: The index of the shard to process. Can be an integer representing the shard index or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_ID environment variable. + total_shards: The total number of shards. Can be an integer representing the total number of shards or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_COUNT environment variable. + minimum_shard_index: The minimum shard index to process. Can be an integer representing the minimum shard index or a string representing the environment variable name. + Only used if enable_array_partitioning is True. If not provided, it will be set to 0. """ input_video_path: str video_limit: int | None = None verbose: bool = False + enable_array_partitioning: bool = False + shard_index: int | str | None = None + total_shards: int | str | None = None + minimum_shard_index: int | str = 0 def __post_init__(self): """Initialize the parent CompositeStage after dataclass initialization.""" @@ -276,6 +287,9 @@ def decompose(self) -> list[ProcessingStage]: List of processing stages: [FilePartitioningStage, VideoReaderStage] """ if is_remote_url(self.input_video_path): + if self.enable_array_partitioning: + msg = "enable_array_partitioning is not supported for ClientPartitioningStage" + raise NotImplementedError(msg) reader_stage = ClientPartitioningStage( file_paths=self.input_video_path, files_per_partition=1, @@ -288,6 +302,10 @@ def decompose(self) -> list[ProcessingStage]: files_per_partition=1, file_extensions=[".mp4", ".mov", ".avi", ".mkv", ".webm"], limit=self.video_limit, + enable_array_partitioning=self.enable_array_partitioning, + shard_index=self.shard_index, + total_shards=self.total_shards, + minimum_shard_index=self.minimum_shard_index, ) download_stage = VideoReaderStage(input_path=self.input_video_path, verbose=self.verbose) diff --git a/nemo_curator/tasks/sentinels.py b/nemo_curator/tasks/sentinels.py index 84896dd963..ed1ab8b572 100644 --- a/nemo_curator/tasks/sentinels.py +++ b/nemo_curator/tasks/sentinels.py @@ -13,9 +13,18 @@ # limitations under the License. """Payload-less marker tasks. -``EmptyTask`` seeds a pipeline (the implicit root id ``"0"``). All markers -share the :class:`SentinelTask` base and carry no payload (``data is None``). -Construct one with ``EmptyTask()``. +``EmptyTask`` seeds a pipeline (the implicit root id ``"0"``). The resumability +layer adds two more markers on the same :class:`SentinelTask` base: + +- ``NoneTask`` — this slot was intentionally filtered. The resumability counter + treats it as a consumed branch (decrements). The adapter auto-wraps a + returned ``None`` as a ``NoneTask``. +- ``FailedTask`` — this slot failed and should be retried on resume. The counter + is NOT decremented, so its source stays pending and reruns. + +All carry no payload (``data is None``) and get their ``task_id`` assigned by +the executor adapter; sentinels are stripped before the next stage. Construct +with ``EmptyTask()`` / ``NoneTask()`` / ``FailedTask()``. """ from dataclasses import dataclass, field @@ -52,3 +61,17 @@ class EmptyTask(SentinelTask): dataset_name: str = "empty" task_id: str = field(init=False, default="0") + + +@dataclass +class NoneTask(SentinelTask): + """Marks a slot as intentionally filtered (resumability counter decrements).""" + + dataset_name: str = "none" + + +@dataclass +class FailedTask(SentinelTask): + """Marks a slot as failed → retried on resume (counter does NOT decrement).""" + + dataset_name: str = "failed" diff --git a/tests/backends/test_base_stage_adapter.py b/tests/backends/test_base_stage_adapter.py new file mode 100644 index 0000000000..d139cbe46a --- /dev/null +++ b/tests/backends/test_base_stage_adapter.py @@ -0,0 +1,88 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from dataclasses import dataclass +from pathlib import Path + +from pytest import MonkeyPatch + +from nemo_curator.backends.base import FAILED_TASKS_DIR_ENV_VAR, BaseStageAdapter +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.tasks import Task +from nemo_curator.tasks.sentinels import FailedTask + + +@dataclass +class _FailedStage(ProcessingStage[Task, Task]): + name: str = "failed" + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def process(self, task: Task) -> Task: + return FailedTask() + + +@dataclass +class _SimpleTask(Task[list[int]]): + @property + def num_items(self) -> int: + return 0 + + def validate(self) -> bool: + return True + + +def _task(task_id: str = "") -> _SimpleTask: + task = _SimpleTask(dataset_name="d", data=[]) + task.task_id = task_id + return task + + +class TestBaseStageAdapter: + def test_process_batch_writes_failed_task_marker_when_enabled( + self, tmp_path: Path, monkeypatch: MonkeyPatch + ) -> None: + marker_dir = tmp_path / "failed-tasks" + monkeypatch.setenv(FAILED_TASKS_DIR_ENV_VAR, str(marker_dir)) + + output = BaseStageAdapter(_FailedStage()).process_batch([_task("0_7")]) + + assert output == [] + marker_files = list(marker_dir.glob("failed_task_*.json")) + assert len(marker_files) == 1 + + payload = json.loads(marker_files[0].read_text()) + assert payload["stage_name"] == "failed" + assert payload["task_id"] == "0_7_0" + assert payload["dataset_name"] == "failed" + assert payload["task_type"] == "FailedTask" + assert isinstance(payload["hostname"], str) + assert isinstance(payload["pid"], int) + assert isinstance(payload["created_at"], str) + + def test_process_batch_does_not_write_failed_task_marker_by_default( + self, tmp_path: Path, monkeypatch: MonkeyPatch + ) -> None: + marker_dir = tmp_path / "failed-tasks" + monkeypatch.delenv(FAILED_TASKS_DIR_ENV_VAR, raising=False) + + output = BaseStageAdapter(_FailedStage()).process_batch([_task("0_7")]) + + assert output == [] + assert not marker_dir.exists() diff --git a/tests/stages/common/test_file_partitioning.py b/tests/stages/common/test_file_partitioning.py index aea5d838f5..0aa58194a7 100644 --- a/tests/stages/common/test_file_partitioning.py +++ b/tests/stages/common/test_file_partitioning.py @@ -271,3 +271,193 @@ def test_task_metadata(self, empty_task: EmptyTask, tmp_path: Path): assert task._metadata["total_partitions"] == 2 assert task._metadata["source_files"] == [test_files[0]] assert task.reader_config == {} + + def test_enable_array_partitioning_with_explicit_values(self, monkeypatch: pytest.MonkeyPatch): + """Test array partitioning initialization with explicit shard values.""" + monkeypatch.setenv("SLURM_ARRAY_TASK_ID", "7") + monkeypatch.setenv("SLURM_ARRAY_TASK_COUNT", "11") + + stage = FilePartitioningStage( + file_paths="/test/path", + enable_array_partitioning=True, + shard_index=2, + total_shards=10, + minimum_shard_index=1, + ) + + assert stage.name == "array_file_partitioning" + assert stage.shard_index == 2 + assert stage.total_shards == 10 + assert stage.minimum_shard_index == 1 + + def test_enable_array_partitioning_reads_slurm_env_vars( + self, + monkeypatch: pytest.MonkeyPatch, + ): + """Test that array partitioning defaults to Slurm array env vars.""" + monkeypatch.setenv("SLURM_ARRAY_TASK_ID", "7") + monkeypatch.setenv("SLURM_ARRAY_TASK_COUNT", "11") + + stage = FilePartitioningStage( + file_paths="/test/path", + enable_array_partitioning=True, + ) + + assert stage.shard_index == 7 + assert stage.total_shards == 11 + assert stage.minimum_shard_index == 0 + + def test_enable_array_partitioning_supports_custom_env_var_names( + self, + monkeypatch: pytest.MonkeyPatch, + ): + """Test that shard parameters can be read from custom env var names.""" + monkeypatch.setenv("CUSTOM_SHARD_INDEX", "3") + monkeypatch.setenv("CUSTOM_TOTAL_SHARDS", "8") + monkeypatch.setenv("CUSTOM_MINIMUM_SHARD_INDEX", "1") + + stage = FilePartitioningStage( + file_paths="/test/path", + enable_array_partitioning=True, + shard_index="CUSTOM_SHARD_INDEX", + total_shards="CUSTOM_TOTAL_SHARDS", + minimum_shard_index="CUSTOM_MINIMUM_SHARD_INDEX", + ) + + assert stage.shard_index == 3 + assert stage.total_shards == 8 + assert stage.minimum_shard_index == 1 + + def test_enable_array_partitioning_requires_slurm_env_vars_by_default( + self, + monkeypatch: pytest.MonkeyPatch, + ): + """Test that missing default Slurm env vars raise a clear error.""" + monkeypatch.delenv("SLURM_ARRAY_TASK_ID", raising=False) + monkeypatch.delenv("SLURM_ARRAY_TASK_COUNT", raising=False) + + with pytest.raises(ValueError, match="SLURM_ARRAY_TASK_ID"): + FilePartitioningStage( + file_paths="/test/path", + enable_array_partitioning=True, + ) + + def test_enable_array_partitioning_rejects_non_integer_env_var( + self, + monkeypatch: pytest.MonkeyPatch, + ): + """Test that non-integer shard env vars raise a contextual error.""" + monkeypatch.setenv("CUSTOM_SHARD_INDEX", "not-an-int") + + with pytest.raises(ValueError, match=r"CUSTOM_SHARD_INDEX.*not-an-int"): + FilePartitioningStage( + file_paths="/test/path", + enable_array_partitioning=True, + shard_index="CUSTOM_SHARD_INDEX", + total_shards=4, + ) + + def test_enable_array_partitioning_requires_positive_total_shards(self): + """Test that array partitioning rejects non-positive shard counts.""" + with pytest.raises(ValueError, match="total_shards must be greater than 0"): + FilePartitioningStage( + file_paths="/test/path", + enable_array_partitioning=True, + shard_index=0, + total_shards=0, + ) + + def test_enable_array_partitioning_warns_for_out_of_range_shard( + self, + caplog: pytest.LogCaptureFixture, + ): + """Test that shard IDs outside the assignable range produce a clear warning.""" + with caplog.at_level("WARNING"): + stage = FilePartitioningStage( + file_paths="/test/path", + enable_array_partitioning=True, + shard_index=0, + total_shards=10, + minimum_shard_index=1, + ) + + assert stage.shard_index == 0 + assert stage.total_shards == 10 + assert stage.minimum_shard_index == 1 + assert "outside the assignable shard range [1, 10]" in caplog.text + + def test_enable_array_partitioning_warns_when_no_partitions_assigned( + self, + caplog: pytest.LogCaptureFixture, + empty_task: EmptyTask, + tmp_path: Path, + ): + """Test that non-empty input with zero assigned partitions is visible in logs.""" + test_files = _create_test_jsonl_files(tmp_path, num_files=2, subdir="path") + stage = FilePartitioningStage( + file_paths=test_files, + enable_array_partitioning=True, + shard_index=2, + total_shards=2, + ) + + with caplog.at_level("WARNING"): + result = stage.process(empty_task) + + assert result == [] + assert "assigned 0 of 2 partitions" in caplog.text + + def test_enable_array_partitioning_assigns_each_partition_to_one_shard( + self, + empty_task: EmptyTask, + tmp_path: Path, + ): + """Test that array partitioning covers all partitions exactly once.""" + test_files = _create_test_jsonl_files(tmp_path, num_files=8, subdir="path") + expected_partitions = { + tuple(test_files[i : i + 2]) + for i in range(0, len(test_files), 2) + } + assigned_partitions = [] + + for shard_index in range(3): + stage = FilePartitioningStage( + file_paths=test_files, + files_per_partition=2, + enable_array_partitioning=True, + shard_index=shard_index, + total_shards=3, + ) + + assigned_partitions.extend(tuple(task.data) for task in stage.process(empty_task)) + + assert set(assigned_partitions) == expected_partitions + assert len(assigned_partitions) == len(expected_partitions) + + def test_enable_array_partitioning_supports_minimum_shard_index( + self, + empty_task: EmptyTask, + tmp_path: Path, + ): + """Test non-zero Slurm arrays by offsetting hash-assigned shard IDs.""" + test_files = _create_test_jsonl_files(tmp_path, num_files=8, subdir="path") + zero_indexed_stage = FilePartitioningStage( + file_paths=test_files, + files_per_partition=2, + enable_array_partitioning=True, + shard_index=0, + total_shards=3, + ) + one_indexed_stage = FilePartitioningStage( + file_paths=test_files, + files_per_partition=2, + enable_array_partitioning=True, + shard_index=1, + total_shards=3, + minimum_shard_index=1, + ) + + zero_indexed_result = [task.data for task in zero_indexed_stage.process(empty_task)] + one_indexed_result = [task.data for task in one_indexed_stage.process(empty_task)] + + assert one_indexed_result == zero_indexed_result diff --git a/tutorials/slurm/README.md b/tutorials/slurm/README.md index 05565de2a0..2ab43f548a 100644 --- a/tutorials/slurm/README.md +++ b/tutorials/slurm/README.md @@ -9,6 +9,8 @@ This tutorial shows how to scale a NeMo Curator pipeline from a single laptop to | `pipeline.py` | A simple CPU-only pipeline (word-count + node-tag) that runs locally or on SLURM | | `submit.sh` | `sbatch` script for bare-metal clusters with a shared virtualenv | | `submit_container.sh` | `sbatch` script using the official NGC container (Pyxis/enroot) | +| `array_pipeline.py` | Generic JSONL/Parquet pipeline that processes one Slurm array shard | +| `submit_array.sh` | `sbatch --array` script for splitting many input files across independent jobs | --- @@ -184,6 +186,157 @@ tail -f logs/slurm_demo_.log --- +## SLURM job arrays — JSONL or Parquet file sharding + +Use `submit_array.sh` when you already have a large directory of text data files and want to split the file set across many independent Slurm jobs. Each array task starts its own Curator pipeline, hashes the input file partitions deterministically, and processes only the partitions assigned to that task. + +This pattern is useful when the dataset is naturally represented as many JSONL or Parquet files and you want simple horizontal scaling without coordination between jobs. + +### 1. Build the virtualenv on a shared filesystem + +The array example uses the official NGC container for the base environment, then activates your local checkout inside the container so unreleased source changes are picked up: + +```bash +cd /path/to/Curator +python -m venv .venv +source .venv/bin/activate +pip install -e . +``` + +Make sure `CURATOR_DIR`, `INPUT_DIR`, `OUTPUT_DIR`, and `CHECKPOINT_PATH` are visible from every compute node, either because they are on a shared filesystem or because you set `CONTAINER_MOUNTS` to expose the right host paths inside the container. + +### 2. Submit a JSONL array job + +By default, `submit_array.sh` reads JSONL files and writes JSONL output: + +```bash +export CURATOR_DIR=/path/to/Curator +export INPUT_DIR=/shared/data/my-jsonl-dataset +export OUTPUT_DIR=/shared/output/my-jsonl-dataset + +# 20 array tasks, task IDs 0-19 +sbatch --array=0-19 tutorials/slurm/submit_array.sh +``` + +For example, if the input directory contains 2000 files and `FILES_PER_PARTITION=1`, each of the 20 array tasks receives roughly 100 file partitions. Assignment is hash-based rather than contiguous, so work remains stable if Slurm retries a task. + +Single-node array tasks use `RayClient`. If you override the allocation to use more than one node per array task, `submit_array.sh` automatically passes `--slurm` to `array_pipeline.py`, which switches that task to `SlurmRayClient` so the nodes form one Ray cluster: + +```bash +sbatch --array=0-9 --nodes=2 --cpus-per-task=32 tutorials/slurm/submit_array.sh +``` + +### 3. Use Parquet instead + +Set the input and output file types to `parquet`: + +```bash +export INPUT_DIR=/shared/data/my-parquet-dataset +export OUTPUT_DIR=/shared/output/my-parquet-dataset +export INPUT_FILE_TYPE=parquet +export OUTPUT_FILE_TYPE=parquet + +sbatch --array=0-19 tutorials/slurm/submit_array.sh +``` + +### 4. Edit sharding logic + +If your array does not start at zero, set `MINIMUM_SHARD_INDEX` to the first task ID: + +```bash +MINIMUM_SHARD_INDEX=1 sbatch --array=1-20 tutorials/slurm/submit_array.sh +``` + +If your cluster limits the number of tasks in a single Slurm array, you can still use a larger logical shard count by overriding `TOTAL_SHARDS` and submitting the shard ID range in multiple windows. For example, if you want 10,000 logical shards but the cluster allows only 1,000 array tasks per submission: + +```bash +export TOTAL_SHARDS=10000 + +sbatch --array=0-999 tutorials/slurm/submit_array.sh +sbatch --array=1000-1999 tutorials/slurm/submit_array.sh +sbatch --array=2000-2999 tutorials/slurm/submit_array.sh +# ... +sbatch --array=9000-9999 tutorials/slurm/submit_array.sh +``` + +In this mode, keep `MINIMUM_SHARD_INDEX=0` because the Slurm array task IDs are already the global shard IDs. Each partition is assigned by `hash(partition) % TOTAL_SHARDS`, so the full set of windowed submissions covers shards `0` through `9999` exactly once. Some individual tasks may receive no files if `TOTAL_SHARDS` is larger than the number of file partitions. + +Some clusters enforce the maximum array index rather than just the number of tasks per submitted array. If `--array=1000-1999` is rejected, use `SHARD_INDEX_OFFSET` instead of higher Slurm task IDs. + +For those clusters, submit each window with Slurm task IDs `0-999` and set `SHARD_INDEX_OFFSET` so the script computes the global shard ID as `SLURM_ARRAY_TASK_ID + SHARD_INDEX_OFFSET`: + +```bash +export TOTAL_SHARDS=10000 + +SHARD_INDEX_OFFSET=0 sbatch --array=0-999 tutorials/slurm/submit_array.sh +SHARD_INDEX_OFFSET=1000 sbatch --array=0-999 tutorials/slurm/submit_array.sh +SHARD_INDEX_OFFSET=2000 sbatch --array=0-999 tutorials/slurm/submit_array.sh +# ... +SHARD_INDEX_OFFSET=9000 sbatch --array=0-999 tutorials/slurm/submit_array.sh +``` + +Keep `MINIMUM_SHARD_INDEX=0` for this offset mode too. `SHARD_INDEX_OFFSET` changes the logical shard ID passed to the pipeline; `MINIMUM_SHARD_INDEX` changes the assignable shard range used by the partitioning stage. + +### 5. Retry failed array tasks only + +When `submit_array.sh` launches `array_pipeline.py`, it passes `--checkpoint-path`. At startup, the driver process for each array task creates a pending retry manifest under: + +```bash +${CHECKPOINT_PATH:-$OUTPUT_DIR}/.nemo_curator_metadata/.slurm_array_retry/ +``` + +In other words, retries are tracked at `checkpoint_path/.nemo_curator_metadata/.slurm_array_retry/`. + +`submit_array.sh` also sets `NEMO_CURATOR_FAILED_TASKS_DIR` so the backend can record `FailedTask` sentinels produced by stages. By default, each Slurm job gets its own marker directory: + +```bash +${CHECKPOINT_PATH:-$OUTPUT_DIR}/.nemo_curator_metadata/.failed_tasks/slurm_job_${SLURM_JOB_ID}/array_task_${SLURM_ARRAY_TASK_ID}/shard_${SHARD_INDEX}/ +``` + +The marker directory is per Slurm job/task. For `--nodes>1`, all workers in that one array task share the same directory, so any worker can record a `FailedTask` and the driver can inspect the directory after `pipeline.run()` returns. + +If the shard completes successfully and no `FailedTask` markers were written, that shard's matching retry manifests are removed. If the process fails, is preempted, or reaches the Slurm time limit before cleanup runs, the manifest remains in the retry directory. Caught Python exceptions update the manifest with `status="failed"` and the error message; hard termination may leave `status="pending"`, which should still be treated as retryable after the original Slurm array has finished. + +If the pipeline completes without raising but one or more `FailedTask` marker files exist, `array_pipeline.py` keeps the retry manifest and updates it with `status="failed_tasks"` plus marker metadata. The retry collection below only needs to read retry manifests; it does not need to inspect the `FailedTask` marker directories directly. + +Retry manifests and `FailedTask` markers are uniquely named JSON files written with an atomic rename, so multiple array tasks or workers can write without coordinating through a shared database. + +Each manifest records the failed `shard_index`, plus the `total_shards` and `minimum_shard_index` values used for the original run. To retry only the failed shards, rebuild a Slurm array list from those manifests and preserve the original shard settings. For example, using `jq`: + +```bash +export CHECKPOINT_PATH="${CHECKPOINT_PATH:-$OUTPUT_DIR}" +RETRY_DIR="${CHECKPOINT_PATH}/.nemo_curator_metadata/.slurm_array_retry" + +shopt -s nullglob +MANIFEST_FILES=("${RETRY_DIR}"/manifest_*.json) +shopt -u nullglob + +if (( ${#MANIFEST_FILES[@]} == 0 )); then + echo "No failed shards found in ${RETRY_DIR}" >&2 + exit 1 +fi + +FAILED_SHARDS=$(jq -r '.shard_index' "${MANIFEST_FILES[@]}" | sort -n -u | paste -sd, -) +TOTAL_SHARDS_VALUES=$(jq -r '.total_shards' "${MANIFEST_FILES[@]}" | sort -n -u) +MINIMUM_SHARD_INDEX_VALUES=$(jq -r '.minimum_shard_index' "${MANIFEST_FILES[@]}" | sort -n -u) + +if [[ "${TOTAL_SHARDS_VALUES}" == *$'\n'* || "${MINIMUM_SHARD_INDEX_VALUES}" == *$'\n'* ]]; then + echo "Retry manifests contain multiple shard configurations; split them by run." >&2 + exit 1 +fi + +export TOTAL_SHARDS="${TOTAL_SHARDS_VALUES}" +export MINIMUM_SHARD_INDEX="${MINIMUM_SHARD_INDEX_VALUES}" + +sbatch --array="${FAILED_SHARDS}" tutorials/slurm/submit_array.sh +``` + +The `TOTAL_SHARDS` override is important. On a retry array like `--array=3,17,42`, Slurm sets `SLURM_ARRAY_TASK_COUNT=3`, but the data was originally assigned using the full logical shard count. Reusing the original `TOTAL_SHARDS` keeps `hash(partition) % total_shards` identical to the first run. + +Run this retry collection after the original Slurm array has finished, otherwise still-running tasks will still have pending manifests. Use one `CHECKPOINT_PATH` per logical array run, or move old retry manifests aside after building `FAILED_SHARDS`, so later retries do not include failures that already succeeded. If you override `NEMO_CURATOR_FAILED_TASKS_DIR`, keep it unique per Slurm job or clean it before reuse; stale `FailedTask` markers make an otherwise successful shard look retryable. + +--- + ## Configuration reference ### SlurmRayClient parameters diff --git a/tutorials/slurm/array_pipeline.py b/tutorials/slurm/array_pipeline.py new file mode 100644 index 0000000000..49fc070658 --- /dev/null +++ b/tutorials/slurm/array_pipeline.py @@ -0,0 +1,476 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Slurm array tutorial: split a large file collection across array tasks. + +Each Slurm array task processes its own slice of the input files. +With 2000 JSONL or Parquet files and --array=0-19, each of the 20 jobs gets +~100 files. + +Array partitioning parameters +------------------------------ +shard_index Which shard this job processes. +total_shards Total number of shards (i.e. array width). +minimum_shard_index Offset added to the hash-assigned shard before + comparing with shard_index. Use when the array does + not start at 0. E.g. --array=1-20 requires + minimum_shard_index=1 so shard IDs 1-20 match task IDs 1-20. + Default: 0. + +Usage (local smoke test against a small sample directory):: + + # Simulate task 0 of 4 locally (zero-indexed array) + python tutorials/slurm/array_pipeline.py \\ + --input-dir /path/to/input/directory \\ + --output-dir /path/to/output/directory \\ + --shard-index 0 --total-shards 4 + + # Non-zero-indexed array: tasks 1-4, minimum_shard_index=1 + python tutorials/slurm/array_pipeline.py \\ + --input-dir /path/to/input/directory \\ + --output-dir /path/to/output/directory \\ + --shard-index 1 --total-shards 4 --minimum-shard-index 1 + + # Use Parquet input/output instead of the default JSONL: + python tutorials/slurm/array_pipeline.py \\ + --input-dir /path/to/input/directory \\ + --input-file-type parquet \\ + --output-dir /path/to/output/directory \\ + --output-file-type parquet \\ + --shard-index 0 --total-shards 4 + + # Or let the sbatch script read Slurm env vars and pass explicit args: + sbatch --array=0-19 tutorials/slurm/submit_array.sh +""" + +from __future__ import annotations + +import argparse +import contextlib +import datetime +import json +import os +import socket +import tempfile +import uuid +from pathlib import Path + +from loguru import logger + +from nemo_curator.backends.base import FAILED_TASKS_DIR_ENV_VAR +from nemo_curator.core.client import RayClient, SlurmRayClient +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.text.io.reader import JsonlReader, ParquetReader +from nemo_curator.stages.text.io.writer import JsonlWriter, ParquetWriter + +METADATA_DIRNAME = ".nemo_curator_metadata" +SLURM_ARRAY_RETRY_DIRNAME = ".slurm_array_retry" +FAILED_TASK_MARKER_PATTERN = "failed_task_*.json" +MAX_FAILED_TASK_MARKERS_IN_MANIFEST = 10 + + +def _safe_token(value: object) -> str: + """Convert a value to a conservative filename token.""" + return "".join(char if char.isalnum() or char in "._-" else "_" for char in str(value)) + + +def _parse_int_or_env_name(value: str) -> int | str: + """Parse an integer value or keep an environment variable name.""" + try: + return int(value) + except ValueError: + return value + + +def _resolve_int_or_env_name(value: int | str, label: str) -> int: + """Resolve an integer or an environment variable name containing an integer.""" + if isinstance(value, int): + return value + + env_value = os.environ.get(value) + if env_value is None: + msg = f"{label} references environment variable {value}, but it is not set" + raise ValueError(msg) + try: + return int(env_value) + except ValueError as e: + msg = f"{label} references environment variable {value}, which must contain an integer, got {env_value!r}" + raise ValueError(msg) from e + + +def _is_driver_process(use_slurm: bool) -> bool: + """Return True for the process that should run the pipeline and own retry metadata.""" + return not use_slurm or os.environ.get("SLURM_NODEID", "0") == "0" + + +def _retry_manifest_prefix( + shard_index: int, + total_shards: int, + minimum_shard_index: int, +) -> str: + return ( + f"manifest_shard-{_safe_token(shard_index)}_" + f"total-{_safe_token(total_shards)}_" + f"min-{_safe_token(minimum_shard_index)}_" + ) + + +def _retry_manifest_payload( # noqa: PLR0913 + shard_index: int, + total_shards: int, + minimum_shard_index: int, + status: str, + created_at: datetime.datetime, + error: BaseException | None = None, + extra: dict[str, object] | None = None, +) -> dict[str, object]: + payload = { + "shard_index": shard_index, + "total_shards": total_shards, + "minimum_shard_index": minimum_shard_index, + "status": status, + "created_at": created_at.isoformat(), + "slurm_job_id": os.environ.get("SLURM_JOB_ID"), + "slurm_array_job_id": os.environ.get("SLURM_ARRAY_JOB_ID"), + "slurm_array_task_id": os.environ.get("SLURM_ARRAY_TASK_ID"), + "slurm_node_id": os.environ.get("SLURM_NODEID"), + "hostname": socket.gethostname(), + "pid": os.getpid(), + } + + if error is not None: + payload["error_type"] = type(error).__name__ + payload["error"] = str(error) + + if extra is not None: + payload.update(extra) + + return payload + + +def write_retry_manifest( # noqa: PLR0913 + checkpoint_path: str, + shard_index: int | None, + total_shards: int | None, + minimum_shard_index: int, + status: str, + error: BaseException | None = None, + manifest_file: Path | None = None, + extra: dict[str, object] | None = None, +) -> Path: + """Write a retry manifest using a unique name and atomic rename.""" + retry_dir = Path(checkpoint_path, METADATA_DIRNAME, SLURM_ARRAY_RETRY_DIRNAME).absolute() + retry_dir.mkdir(parents=True, exist_ok=True) + + created_at = datetime.datetime.now(datetime.UTC) + manifest = _retry_manifest_payload( + shard_index=shard_index, + total_shards=total_shards, + minimum_shard_index=minimum_shard_index, + status=status, + created_at=created_at, + error=error, + extra=extra, + ) + + if manifest_file is None: + timestamp = created_at.strftime("%Y%m%d%H%M%S%f") + manifest_file = retry_dir / ( + _retry_manifest_prefix(shard_index, total_shards, minimum_shard_index) + + f"job-{_safe_token(os.environ.get('SLURM_JOB_ID', 'local'))}_" + + f"task-{_safe_token(os.environ.get('SLURM_ARRAY_TASK_ID', 'local'))}_" + + f"node-{_safe_token(os.environ.get('SLURM_NODEID', '0'))}_" + + f"pid-{os.getpid()}_" + + f"{timestamp}_" + + f"{uuid.uuid4().hex}.json" + ) + + tmp_path = None + try: + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + dir=retry_dir, + prefix=f".{manifest_file.name}.", + suffix=".tmp", + delete=False, + ) as tmp_file: + tmp_path = Path(tmp_file.name) + json.dump(manifest, tmp_file, indent=2, sort_keys=True) + tmp_file.write("\n") + tmp_file.flush() + os.fsync(tmp_file.fileno()) + os.replace(tmp_path, manifest_file) + except Exception: + if tmp_path is not None: + with contextlib.suppress(FileNotFoundError): + tmp_path.unlink() + raise + + return manifest_file + + +def failed_task_marker_files() -> list[Path]: + """Return FailedTask marker files written by BaseStageAdapter for this job.""" + failed_tasks_dir = os.environ.get(FAILED_TASKS_DIR_ENV_VAR) + if not failed_tasks_dir: + return [] + + marker_dir = Path(failed_tasks_dir).absolute() + if not marker_dir.exists(): + return [] + + return sorted(path for path in marker_dir.glob(FAILED_TASK_MARKER_PATTERN) if path.is_file()) + + +def failed_task_manifest_metadata(marker_files: list[Path]) -> dict[str, object]: + marker_dir = marker_files[0].parent if marker_files else os.environ.get(FAILED_TASKS_DIR_ENV_VAR) + sample_marker_files = marker_files[:MAX_FAILED_TASK_MARKERS_IN_MANIFEST] + return { + "failed_task_marker_dir": str(marker_dir), + "failed_task_marker_count": len(marker_files), + "failed_task_marker_files": [str(path) for path in sample_marker_files], + "failed_task_marker_files_truncated": len(marker_files) > len(sample_marker_files), + } + + +def remove_retry_manifests( + checkpoint_path: str, + shard_index: int | None, + total_shards: int | None, + minimum_shard_index: int, +) -> None: + """Remove retry manifests for a shard after successful completion.""" + retry_dir = Path(checkpoint_path, METADATA_DIRNAME, SLURM_ARRAY_RETRY_DIRNAME).absolute() + if not retry_dir.exists(): + return + + pattern = _retry_manifest_prefix(shard_index, total_shards, minimum_shard_index) + "*.json" + for manifest_file in retry_dir.glob(pattern): + with contextlib.suppress(FileNotFoundError): + manifest_file.unlink() + + +def build_pipeline( # noqa: PLR0913 + input_dir: str, + input_file_type: str, + output_dir: str, + output_file_type: str, + files_per_partition: int, + shard_index: int | None, + total_shards: int | None, + minimum_shard_index: int, +) -> Pipeline: + pipeline = Pipeline( + name="slurm_array_demo", + description=( + "Read files from input directory assigned to this Slurm array task and write them out to output directory." + ), + ) + + # submit_array.sh maps Slurm array env vars into explicit shard arguments. + # Direct users may also pass env var names; those are resolved before the + # pipeline is built so retry manifests record concrete shard values. + if input_file_type == "jsonl": + pipeline.add_stage( + JsonlReader( + file_paths=input_dir, + files_per_partition=files_per_partition, + enable_array_partitioning=True, + shard_index=shard_index, + total_shards=total_shards, + minimum_shard_index=minimum_shard_index, + ) + ) + elif input_file_type == "parquet": + pipeline.add_stage( + ParquetReader( + file_paths=input_dir, + files_per_partition=files_per_partition, + enable_array_partitioning=True, + shard_index=shard_index, + total_shards=total_shards, + minimum_shard_index=minimum_shard_index, + ) + ) + else: + msg = f"Unsupported input file type: {input_file_type}" + raise ValueError(msg) + + if output_file_type == "jsonl": + pipeline.add_stage(JsonlWriter(output_dir)) + elif output_file_type == "parquet": + pipeline.add_stage(ParquetWriter(output_dir)) + else: + msg = f"Unsupported output file type: {output_file_type}" + raise ValueError(msg) + + return pipeline + + +def main() -> None: # noqa: PLR0915 + parser = argparse.ArgumentParser(description="Slurm array file-partitioning demo") + parser.add_argument("--input-dir", required=True, help="Directory containing input files") + parser.add_argument( + "--input-file-type", + choices=["jsonl", "parquet"], + default="jsonl", + help="Type of input files (default: jsonl)", + ) + parser.add_argument("--output-dir", required=True, help="Directory to write output files") + parser.add_argument( + "--output-file-type", + choices=["jsonl", "parquet"], + default="jsonl", + help="Type of output files (default: jsonl)", + ) + parser.add_argument( + "--files-per-partition", + type=int, + default=1, + help="Files grouped into each FileGroupTask (default: 1)", + ) + parser.add_argument( + "--shard-index", + type=_parse_int_or_env_name, + required=True, + help="Shard to process, as an integer or environment variable name.", + ) + parser.add_argument( + "--total-shards", + type=_parse_int_or_env_name, + required=True, + help="Total number of shards, as an integer or environment variable name.", + ) + parser.add_argument( + "--minimum-shard-index", + type=_parse_int_or_env_name, + default=0, + help=( + "Offset added to the hash-assigned shard before comparison. " + "Set to match the first task ID when the array does not start at 0 " + "(e.g. --array=1-20 requires --minimum-shard-index=1). Default: 0." + ), + ) + parser.add_argument( + "--checkpoint-path", + dest="checkpoint_path", + type=str, + default=None, + help=( + "Path for checkpoint metadata. Slurm array retry manifests are written under " + "/.nemo_curator_metadata/.slurm_array_retry/. Defaults to None." + ), + ) + parser.add_argument( + "--slurm", + action="store_true", + help="Use SlurmRayClient for multi-node srun jobs.", + ) + args = parser.parse_args() + + shard_index = _resolve_int_or_env_name(args.shard_index, "shard_index") + total_shards = _resolve_int_or_env_name(args.total_shards, "total_shards") + minimum_shard_index = _resolve_int_or_env_name(args.minimum_shard_index, "minimum_shard_index") + + ray_client = SlurmRayClient() if args.slurm else RayClient() + retry_manifest_file = None + is_driver_process = _is_driver_process(args.slurm) + should_manage_retry_manifest = args.checkpoint_path is not None and is_driver_process + + try: + if should_manage_retry_manifest: + retry_manifest_file = write_retry_manifest( + checkpoint_path=args.checkpoint_path, + shard_index=shard_index, + total_shards=total_shards, + minimum_shard_index=minimum_shard_index, + status="pending", + ) + logger.info(f"Wrote pending Slurm array retry manifest to {retry_manifest_file}") + + ray_client.start() + + pipeline = build_pipeline( + args.input_dir, + args.input_file_type, + args.output_dir, + args.output_file_type, + args.files_per_partition, + shard_index=shard_index, + total_shards=total_shards, + minimum_shard_index=minimum_shard_index, + ) + logger.info(f"\n{pipeline.describe()}") + + pipeline.run() + + failed_task_markers = failed_task_marker_files() + if failed_task_markers: + if should_manage_retry_manifest: + manifest_file = write_retry_manifest( + checkpoint_path=args.checkpoint_path, + shard_index=shard_index, + total_shards=total_shards, + minimum_shard_index=minimum_shard_index, + status="failed_tasks", + manifest_file=retry_manifest_file, + extra=failed_task_manifest_metadata(failed_task_markers), + ) + logger.warning( + "Pipeline completed without raising, but found " + f"{len(failed_task_markers)} FailedTask marker(s). " + f"Keeping retry manifest at {manifest_file}." + ) + else: + logger.warning( + f"Pipeline completed without raising, but found {len(failed_task_markers)} FailedTask marker(s)." + ) + return + + if should_manage_retry_manifest: + try: + remove_retry_manifests( + checkpoint_path=args.checkpoint_path, + shard_index=shard_index, + total_shards=total_shards, + minimum_shard_index=minimum_shard_index, + ) + except Exception as cleanup_error: # noqa: BLE001 + logger.error(f"Pipeline succeeded but failed to remove retry manifest: {cleanup_error}") + except Exception as e: + if should_manage_retry_manifest: + try: + manifest_file = write_retry_manifest( + checkpoint_path=args.checkpoint_path, + shard_index=shard_index, + total_shards=total_shards, + minimum_shard_index=minimum_shard_index, + status="failed", + error=e, + manifest_file=retry_manifest_file, + ) + logger.error(f"Wrote Slurm array retry manifest to {manifest_file}") + except Exception as manifest_error: # noqa: BLE001 + logger.error(f"Failed to write Slurm array retry manifest: {manifest_error}") + + logger.error(f"Error running pipeline: {e}") + raise + finally: + if is_driver_process: + ray_client.stop() + + +if __name__ == "__main__": + main() diff --git a/tutorials/slurm/submit_array.sh b/tutorials/slurm/submit_array.sh new file mode 100644 index 0000000000..8060a38586 --- /dev/null +++ b/tutorials/slurm/submit_array.sh @@ -0,0 +1,198 @@ +#!/bin/bash +# ============================================================================= +# NeMo Curator — Slurm array submit script +# +# Splits a large set of JSONL or Parquet files across multiple Slurm array +# tasks so that each job independently processes its assigned slice of the input. +# +# Example: 2000 input files, --array=0-19 -> 20 jobs x ~100 files each. +# +# How it works: +# - FilePartitioningStage groups all files into partitions of files_per_partition. +# - Each array task reads SLURM_ARRAY_TASK_ID / SLURM_ARRAY_TASK_COUNT and +# selects only the partitions assigned to it via deterministic SHA-256 hashing. +# - Jobs run in parallel with no coordination between them. +# +# Prerequisites: +# - NeMo Curator source checked out on a shared filesystem (Lustre, NFS, etc.) +# - A virtualenv built at ${CURATOR_DIR}/.venv with NeMo Curator installed +# - INPUT_DIR set to a directory of JSONL or Parquet files visible from all compute nodes +# - OUTPUT_DIR set to a writable shared directory +# - CHECKPOINT_PATH set to a writable shared path for retry manifests +# +# Usage: +# # 20 jobs (task IDs 0-19), ~100 files per job with a 2000-file dataset +# sbatch --array=0-19 tutorials/slurm/submit_array.sh +# +# # Override array size or resources at submission time: +# sbatch --array=0-9 --cpus-per-task=32 tutorials/slurm/submit_array.sh +# sbatch --array=0-39 --time=02:00:00 tutorials/slurm/submit_array.sh +# +# Array indexing: +# shard_index = SLURM_ARRAY_TASK_ID (set automatically by Slurm) +# total_shards = SLURM_ARRAY_TASK_COUNT (set automatically by Slurm) +# minimum_shard_index defaults to 0 — no env var fallback. +# shard_index_offset defaults to 0 and is added to SLURM_ARRAY_TASK_ID +# only when SHARD_INDEX is not explicitly set. +# +# If your array does not start at 0 (e.g. --array=1-20), set: +# MINIMUM_SHARD_INDEX=1 sbatch --array=1-20 tutorials/slurm/submit_array.sh +# ============================================================================= + +#SBATCH --job-name=curator-array +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=16 +#SBATCH --gpus-per-node=0 +#SBATCH --time=01:00:00 +#SBATCH --output=array_%A_%a.log +#SBATCH --error=array_%A_%a.log + +set -euo pipefail + +# --------------------------------------------------------------------------- +# Paths — adjust to your environment +# --------------------------------------------------------------------------- +CURATOR_DIR="${CURATOR_DIR:-$(cd "$(dirname "$0")/../.." && pwd)}" + +# Input and output directories +INPUT_DIR="${INPUT_DIR:-/path/to/your/input/directory}" +OUTPUT_DIR="${OUTPUT_DIR:-/path/to/your/output/directory}" + +# Retry manifests are written under: +# ${CHECKPOINT_PATH}/.nemo_curator_metadata/.slurm_array_retry/ +# FailedTask marker files are written under: +# ${CHECKPOINT_PATH}/.nemo_curator_metadata/.failed_tasks//// +# Defaults to OUTPUT_DIR. If you override CONTAINER_MOUNTS, make sure it still +# includes CHECKPOINT_PATH. +CHECKPOINT_PATH="${CHECKPOINT_PATH:-${OUTPUT_DIR}}" + +# Official NeMo Curator container from NGC. +# Browse available tags: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo-curator +CONTAINER_IMAGE="${CONTAINER_IMAGE:-nvcr.io/nvidia/nemo-curator:26.02}" + +# Mount the shared filesystem paths that contain your code and data. +# Format: :[,:] +# Override this if your cluster expects filesystem roots, e.g. /lustre:/lustre,/home:/home. +DEFAULT_CONTAINER_MOUNTS="${CURATOR_DIR}:${CURATOR_DIR},${INPUT_DIR}:${INPUT_DIR},${OUTPUT_DIR}:${OUTPUT_DIR}" +if [[ "${CHECKPOINT_PATH}" != "${OUTPUT_DIR}" ]]; then + DEFAULT_CONTAINER_MOUNTS="${DEFAULT_CONTAINER_MOUNTS},${CHECKPOINT_PATH}:${CHECKPOINT_PATH}" +fi +CONTAINER_MOUNTS="${CONTAINER_MOUNTS:-${DEFAULT_CONTAINER_MOUNTS}}" + +# Input and output file types +INPUT_FILE_TYPE="${INPUT_FILE_TYPE:-jsonl}" +OUTPUT_FILE_TYPE="${OUTPUT_FILE_TYPE:-jsonl}" + +# Number of files to read into a single DocumentBatch +FILES_PER_PARTITION="${FILES_PER_PARTITION:-1}" + +# Shard index and total shards. +# +# SHARD_INDEX_OFFSET is useful on clusters that limit the maximum Slurm array +# index. For example, submit --array=0-999 with SHARD_INDEX_OFFSET=1000 to +# process logical shards 1000-1999. +SHARD_INDEX_OFFSET="${SHARD_INDEX_OFFSET:-0}" +SHARD_INDEX="${SHARD_INDEX:-$((SLURM_ARRAY_TASK_ID + SHARD_INDEX_OFFSET))}" +TOTAL_SHARDS="${TOTAL_SHARDS:-${SLURM_ARRAY_TASK_COUNT}}" + +# Offset between 0-indexed hash assignments and SLURM_ARRAY_TASK_ID. +# Leave at 0 for --array=0-N. Set to the array start value for --array=K-N. +MINIMUM_SHARD_INDEX="${MINIMUM_SHARD_INDEX:-0}" + +# BaseStageAdapter writes one marker JSON per FailedTask when this env var is +# set. Keep one directory per Slurm array job/task so retries can inspect just +# the FailedTasks from that job, while multi-node workers for that same job +# write into the same directory. +FAILED_TASKS_DIR="${FAILED_TASKS_DIR:-${CHECKPOINT_PATH}/.nemo_curator_metadata/.failed_tasks/slurm_job_${SLURM_JOB_ID:-local}/array_task_${SLURM_ARRAY_TASK_ID:-local}/shard_${SHARD_INDEX}}" +NEMO_CURATOR_FAILED_TASKS_DIR="${NEMO_CURATOR_FAILED_TASKS_DIR:-${FAILED_TASKS_DIR}}" + +# Use SlurmRayClient only when this array task spans multiple nodes. Single-node +# array tasks can use the regular RayClient. +NUM_NODES="${SLURM_JOB_NUM_NODES:-${SLURM_NNODES:-1}}" +USE_SLURM_RAY=0 +if (( NUM_NODES > 1 )); then + USE_SLURM_RAY=1 +fi + +mkdir -p "${CURATOR_DIR}/logs" "${OUTPUT_DIR}" "${CHECKPOINT_PATH}" "${NEMO_CURATOR_FAILED_TASKS_DIR}" + +export CURATOR_DIR +export INPUT_DIR +export OUTPUT_DIR +export CHECKPOINT_PATH +export FAILED_TASKS_DIR +export NEMO_CURATOR_FAILED_TASKS_DIR +export INPUT_FILE_TYPE +export OUTPUT_FILE_TYPE +export FILES_PER_PARTITION +export SHARD_INDEX_OFFSET +export SHARD_INDEX +export TOTAL_SHARDS +export MINIMUM_SHARD_INDEX +export USE_SLURM_RAY + +echo "==================================================" +echo " NeMo Curator — Slurm Array Demo" +echo "==================================================" +echo " Job array ID : ${SLURM_ARRAY_JOB_ID}" +echo " Array task ID : ${SLURM_ARRAY_TASK_ID}" +echo " Array task cnt : ${SLURM_ARRAY_TASK_COUNT}" +echo " Shard index : ${SHARD_INDEX}" +echo " Shard offset : ${SHARD_INDEX_OFFSET}" +echo " Total shards : ${TOTAL_SHARDS}" +echo " Nodes : ${NUM_NODES}" +echo " Ray client : $([[ "${USE_SLURM_RAY}" == "1" ]] && echo SlurmRayClient || echo RayClient)" +echo " Node : $(hostname)" +echo " Container : ${CONTAINER_IMAGE}" +echo " Mounts : ${CONTAINER_MOUNTS}" +echo " Dir : ${CURATOR_DIR}" +echo " Checkpoint path: ${CHECKPOINT_PATH}" +echo " FailedTask dir : ${NEMO_CURATOR_FAILED_TASKS_DIR}" +echo "==================================================" + +# Each array task processes only the file partitions hashed to its +# SLURM_ARRAY_TASK_ID. With --nodes=1, the task uses a local RayClient. With +# --nodes>1, the same Python entrypoint runs on every node and --slurm enables +# SlurmRayClient so workers join the head-node Ray cluster. +srun \ + --ntasks-per-node=1 \ + --container-image="${CONTAINER_IMAGE}" \ + --container-mounts="${CONTAINER_MOUNTS}" \ + --container-workdir="${CURATOR_DIR}" \ + bash -c ' +set -euo pipefail + +export RAY_TMPDIR="/tmp/ray_${SLURM_JOB_ID}" +export RAY_PORT_BROADCAST_DIR="${CURATOR_DIR}/logs" + +# Activate the local virtualenv so the latest Curator code (from this +# checkout) is used instead of the version bundled in the container image. +source "${CURATOR_DIR}/.venv/bin/activate" + +echo "[$(hostname)] SLURM_NODEID=${SLURM_NODEID:-0} python=$(python --version 2>&1)" +nvidia-smi --query-gpu=index,name,memory.total --format=csv,noheader 2>/dev/null \ + | sed "s/^/ [$(hostname)] GPU /" || echo " [$(hostname)] no GPUs" + +pipeline_args=( + --input-dir "${INPUT_DIR}" + --input-file-type "${INPUT_FILE_TYPE}" + --output-dir "${OUTPUT_DIR}" + --output-file-type "${OUTPUT_FILE_TYPE}" + --files-per-partition "${FILES_PER_PARTITION}" + --shard-index "${SHARD_INDEX}" + --total-shards "${TOTAL_SHARDS}" + --minimum-shard-index "${MINIMUM_SHARD_INDEX}" + --checkpoint-path "${CHECKPOINT_PATH}" +) + +if [[ "${USE_SLURM_RAY}" == "1" ]]; then + pipeline_args+=(--slurm) +fi + +python "${CURATOR_DIR}/tutorials/slurm/array_pipeline.py" "${pipeline_args[@]}" +' + +echo "==================================================" +echo " Array task ${SLURM_ARRAY_TASK_ID} DONE" +echo "=================================================="