Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 85 additions & 2 deletions nemo_curator/backends/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,19 +12,79 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import json
import os
import socket
import tempfile
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any

from loguru import logger

from nemo_curator.core.utils import ignore_ray_head_node
from nemo_curator.tasks import Task
from nemo_curator.tasks.sentinels import FailedTask, NoneTask
from nemo_curator.utils.performance_utils import StageTimer

if TYPE_CHECKING:
from nemo_curator.stages.base import ProcessingStage


FAILED_TASKS_DIR_ENV_VAR = "NEMO_CURATOR_FAILED_TASKS_DIR"


def _safe_filename_token(value: object) -> str:
    """Return ``str(value)`` with filename-unfriendly characters replaced by ``_``.

    Characters for which ``str.isalnum`` is true, plus ``.``, ``_`` and ``-``,
    are kept unchanged; every other character becomes a single underscore.
    The result always has the same length as ``str(value)``.
    """
    kept_punctuation = "._-"
    sanitized: list[str] = []
    for ch in str(value):
        if ch.isalnum() or ch in kept_punctuation:
            sanitized.append(ch)
        else:
            sanitized.append("_")
    return "".join(sanitized)


def _write_failed_task_marker(marker_dir: Path, stage_name: str, task: FailedTask) -> None:
    """Write one JSON marker file describing ``task`` into ``marker_dir``.

    The marker records the creation time (UTC, ISO 8601), the stage name, the
    task's ``task_id`` and ``dataset_name``, the task's class name, and the
    hostname and PID of the writing process. The file name embeds sanitized
    stage/task tokens, the PID, a microsecond timestamp and a random UUID.

    The payload is first written to a hidden temporary file in the same
    directory and then moved to its final name with ``os.replace``, so the
    final path only appears once its content is complete.

    Args:
        marker_dir: Directory that receives the marker; created if missing.
        stage_name: Name of the stage that produced the failed task.
        task: The failed task; ``task_id`` and ``dataset_name`` are read from it.

    Raises:
        Exception: Any error from writing or renaming is re-raised after the
            temporary file (if one was created) has been removed.
    """
    now_utc = datetime.datetime.now(datetime.UTC)
    pid = os.getpid()
    payload: dict[str, str | int] = {
        "created_at": now_utc.isoformat(),
        "stage_name": stage_name,
        "task_id": task.task_id,
        "dataset_name": task.dataset_name,
        "task_type": type(task).__name__,
        "hostname": socket.gethostname(),
        "pid": pid,
    }

    marker_dir.mkdir(parents=True, exist_ok=True)

    # Joined with "_" this yields:
    # failed_task_stage-<stage>_task-<task>_pid-<pid>_<timestamp>_<uuid>.json
    name_parts = [
        "failed_task",
        f"stage-{_safe_filename_token(stage_name)}",
        f"task-{_safe_filename_token(task.task_id)}",
        f"pid-{pid}",
        now_utc.strftime("%Y%m%dT%H%M%S%fZ"),
        f"{uuid.uuid4().hex}.json",
    ]
    filename = "_".join(name_parts)
    final_path = marker_dir / filename

    # Stays None until the temporary file exists, so cleanup only runs when
    # there is actually something to delete.
    tmp_path: Path | None = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w",
            encoding="utf-8",
            dir=marker_dir,
            prefix=f".{filename}.",
            suffix=".tmp",
            delete=False,
        ) as handle:
            tmp_path = Path(handle.name)
            json.dump(payload, handle, indent=2, sort_keys=True)
            handle.write("\n")

        # Rename only after the handle is closed and fully flushed.
        os.replace(tmp_path, final_path)
    except Exception:
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)
        raise


@dataclass
class NodeInfo:
"""Generic node information for setup_on_node calls across backends.
Expand Down Expand Up @@ -85,9 +145,20 @@ def process_batch(self, tasks: list[Task]) -> list[Task]:
# Use the batch processing logic
results = self.stage.process_batch(tasks)

# A returned ``None`` ("filter this slot") becomes a NoneTask so every
# output is a real Task that gets a task_id. Sentinels (NoneTask /
# FailedTask) carry no identity and are stripped again before this
# method returns.
results = [NoneTask() if r is None else r for r in results]

# Guarantee every emitted task has a task_id (derived id, or uuid fallback).
results = self._post_process_task_ids(tasks, results)

self._record_failed_tasks([r for r in results if isinstance(r, FailedTask)])

@sarahyurick sarahyurick Jun 12, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed with @abhinavg4. For now, the PR keeps track of FailedTask instances by checking for a user-set environment variable, NEMO_CURATOR_FAILED_TASKS_DIR (exposed in code as FAILED_TASKS_DIR_ENV_VAR), and writing one JSON file per failed task into the specified directory.

I went with the environment-variable-plus-file-write approach because it seems more reliable than trying to manage a global Python variable. It is an environment variable so that BaseStageAdapter does not have to propagate an additional parameter to every single stage (which I think would also require updating the executors). Open to other suggestions.


# Sentinels never propagate to the next stage.
results = [r for r in results if not isinstance(r, (NoneTask, FailedTask))]

# Log performance stats and add to result tasks
_, stage_perf_stats = self._timer.log_stats()
# Consume and attach any custom metrics recorded by the stage during this call
Expand All @@ -99,6 +170,18 @@ def process_batch(self, tasks: list[Task]) -> list[Task]:

return results

def _record_failed_tasks(self, failed_tasks: list[FailedTask]) -> None:
    """Write a marker file for each failed task when a marker directory is configured.

    The directory is read from the environment variable named by
    ``FAILED_TASKS_DIR_ENV_VAR``. Nothing happens when that variable is unset
    or empty, or when ``failed_tasks`` is empty. A failure to write one marker
    is logged as a warning and does not stop the remaining markers from being
    written.

    Args:
        failed_tasks: The failed tasks to record.
    """
    configured_dir = os.environ.get(FAILED_TASKS_DIR_ENV_VAR)
    if not (configured_dir and failed_tasks):
        return

    marker_path = Path(configured_dir)
    for failed in failed_tasks:
        # Best effort: the try sits inside the loop on purpose so that one bad
        # write does not prevent the other markers from being recorded.
        try:
            _write_failed_task_marker(marker_path, self.stage.name, failed)
        except Exception as e:  # noqa: BLE001
            logger.warning(f"Failed to write FailedTask marker to {marker_path}: {e}")

def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Task | None]) -> list[Task]:
"""Assign a deterministic ``task_id`` to every emitted task.

Expand Down Expand Up @@ -134,7 +217,7 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas
(returning a flat list rather than a per-input slot) cannot be mapped
positionally; if its length happens to equal the input length the 1:1
assumption may misattribute parents. That combination is unsupported
until per-slot sentinels (NoneTask/FailedTask) land in a later PR.
unless the stage preserves an unambiguous input→output mapping.
"""
is_source = getattr(self.stage, "is_source_stage", False)

Expand Down
15 changes: 15 additions & 0 deletions nemo_curator/stages/audio/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,13 @@ class ManifestReader(CompositeStage[EmptyTask, AudioTask]):
blocksize: Target size per partition (e.g., "100MB"). Ignored if files_per_partition is set.
file_extensions: File extensions to filter. Defaults to [".jsonl", ".json"].
storage_options: Storage options for cloud paths (S3, GCS credentials, endpoints).
enable_array_partitioning: Whether to enable array partitioning (e.g., partition files across multiple Slurm jobs).
shard_index: The index of the shard to process. Can be an integer representing the shard index or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_ID environment variable.
total_shards: The total number of shards. Can be an integer representing the total number of shards or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_COUNT environment variable.
minimum_shard_index: The minimum shard index to process. Can be an integer representing the minimum shard index or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to 0.
"""

manifest_path: str | list[str]
Expand All @@ -200,6 +207,10 @@ class ManifestReader(CompositeStage[EmptyTask, AudioTask]):
blocksize: int | str | None = None
file_extensions: list[str] = field(default_factory=lambda: [".jsonl", ".json"])
storage_options: dict[str, Any] | None = None
enable_array_partitioning: bool = False
shard_index: int | str | None = None
total_shards: int | str | None = None
minimum_shard_index: int | str = 0

def __post_init__(self) -> None:
super().__init__()
Expand All @@ -215,6 +226,10 @@ def decompose(self) -> list[ProcessingStage]:
blocksize=self.blocksize,
file_extensions=self.file_extensions,
storage_options=self.storage_options,
enable_array_partitioning=self.enable_array_partitioning,
shard_index=self.shard_index,
total_shards=self.total_shards,
minimum_shard_index=self.minimum_shard_index,
),
ManifestReaderStage(),
]
Expand Down
88 changes: 87 additions & 1 deletion nemo_curator/stages/file_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import os
from dataclasses import dataclass
from typing import Any

Expand All @@ -29,6 +31,26 @@
)


def _get_int_or_env_var(input_value: int | str | None, default_name: str | None = None) -> int:
if type(input_value) is int:
return input_value

env_var = input_value if type(input_value) is str else default_name
if env_var is None:
msg = f"Invalid input value: {input_value}, must be an integer or a string"
raise ValueError(msg)

env_value = os.environ.get(env_var)
if env_value is None:
msg = f"Environment variable {env_var} is not set"
raise ValueError(msg)
try:
return int(env_value)
except ValueError as e:
msg = f"Environment variable {env_var} must contain an integer, got {env_value!r}"
raise ValueError(msg) from e


@dataclass
class FilePartitioningStage(ProcessingStage[EmptyTask, FileGroupTask]):
"""Stage that partitions input file paths into FileGroupTasks.
Expand All @@ -55,6 +77,18 @@ class FilePartitioningStage(ProcessingStage[EmptyTask, FileGroupTask]):
Storage options to pass to the file system.
limit: int | None = None
Maximum number of partitions to create.
enable_array_partitioning: bool = False
Whether to enable array partitioning (e.g., partition files across multiple Slurm jobs).
Intended for use with Slurm job arrays via the `sbatch --array` option.
shard_index: int | str | None = None
The index of the shard to process. Can be an integer representing the shard index or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_ID environment variable.
total_shards: int | str | None = None
The total number of shards. Can be an integer representing the total number of shards or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to the value of the SLURM_ARRAY_TASK_COUNT environment variable.
minimum_shard_index: int | str = 0
The minimum shard index to process. Can be an integer representing the minimum shard index or a string representing the environment variable name.
Only used if enable_array_partitioning is True. If not provided, it will be set to 0.
"""

file_paths: str | list[str]
Expand All @@ -63,6 +97,10 @@ class FilePartitioningStage(ProcessingStage[EmptyTask, FileGroupTask]):
file_extensions: list[str] | None = None
storage_options: dict[str, Any] | None = None
limit: int | None = None
enable_array_partitioning: bool = False
shard_index: int | str | None = None
total_shards: int | str | None = None
minimum_shard_index: int | str = 0
name: str = "file_partitioning"

def __post_init__(self):
Expand Down Expand Up @@ -91,6 +129,25 @@ def __post_init__(self):

self.resources = Resources(cpus=0.5)

if self.enable_array_partitioning:
self.shard_index = _get_int_or_env_var(self.shard_index, "SLURM_ARRAY_TASK_ID")
self.total_shards = _get_int_or_env_var(self.total_shards, "SLURM_ARRAY_TASK_COUNT")
self.minimum_shard_index = _get_int_or_env_var(self.minimum_shard_index)
if self.total_shards <= 0:
msg = f"total_shards must be greater than 0, got {self.total_shards}"
raise ValueError(msg)
min_assignable_shard_index = self.minimum_shard_index
max_assignable_shard_index = self.minimum_shard_index + self.total_shards - 1
if not min_assignable_shard_index <= self.shard_index <= max_assignable_shard_index:
logger.warning(
"shard_index={} is outside the assignable shard range [{}, {}]. "
"This task will not receive any partitions.",
self.shard_index,
min_assignable_shard_index,
max_assignable_shard_index,
)
self.name = "array_file_partitioning"

def inputs(self) -> tuple[list[str], list[str]]:
    """Return a pair of empty string lists."""
    first: list[str] = []
    second: list[str] = []
    return first, second

Expand All @@ -106,7 +163,7 @@ def ray_stage_spec(self) -> dict[str, Any]:
def xenna_stage_spec(self) -> dict[str, Any]:
    """Return a spec dict that sets ``num_workers_per_node`` to 1."""
    spec: dict[str, Any] = {}
    spec["num_workers_per_node"] = 1
    return spec

def process(self, _: EmptyTask) -> list[FileGroupTask]:
def _process(self, _: EmptyTask) -> list[FileGroupTask]:
"""Process the initial task to create file group tasks.

This stage expects a simple Task with file paths information
Expand Down Expand Up @@ -189,6 +246,35 @@ def process(self, _: EmptyTask) -> list[FileGroupTask]:
logger.info(f"Created {len(tasks)} file groups from {len(files)} files")
return tasks

def _process_array(self, task: EmptyTask) -> list[FileGroupTask]:
    """Return only the partitions that hash to this instance's shard.

    All partitions are first produced by ``self._process``. Each partition is
    then mapped to a shard by hashing its sorted source file list (taken from
    the ``"source_files"`` metadata entry when present and non-empty, otherwise
    from the partition's ``data``), reducing the hash modulo
    ``self.total_shards`` and offsetting it by ``self.minimum_shard_index``.
    Partitions whose shard equals ``self.shard_index`` are kept.

    A summary is logged at warning level when partitions exist but none were
    assigned to this shard, and at info level otherwise.
    """
    all_tasks = self._process(task)

    def _shard_for(group: FileGroupTask) -> int:
        # The mapping depends only on the set of file names, so it does not
        # change with the order in which the files are listed.
        files = list(group._metadata.get("source_files") or group.data)
        fingerprint = "|".join(sorted(files)).encode("utf-8")
        digest = hashlib.sha256(fingerprint).hexdigest()
        # The first 16 hex digits (64 bits) of the digest select the bucket.
        bucket = int(digest[:16], 16) % self.total_shards
        return bucket + self.minimum_shard_index

    assigned_tasks = [group for group in all_tasks if _shard_for(group) == self.shard_index]

    msg = f"Shard {self.shard_index}/{self.total_shards}: assigned {len(assigned_tasks)} of {len(all_tasks)} partitions"
    nothing_assigned = len(assigned_tasks) == 0 and len(all_tasks) > 0
    emit = logger.warning if nothing_assigned else logger.info
    emit(msg)
    return assigned_tasks
Comment thread
sarahyurick marked this conversation as resolved.

def process(self, task: EmptyTask) -> list[FileGroupTask]:
    """Create file group tasks, restricted to this shard when array partitioning is on.

    Delegates to ``self._process_array`` when ``self.enable_array_partitioning``
    is truthy, and to ``self._process`` otherwise.
    """
    handler = self._process_array if self.enable_array_partitioning else self._process
    return handler(task)

def _get_file_list_with_sizes(self, sort_by_size: bool = True) -> list[tuple[str, int]]:
"""
Get the list of files to process.
Expand Down
16 changes: 16 additions & 0 deletions nemo_curator/stages/interleaved/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class InterleavedWebdatasetReader(CompositeStage[EmptyTask, InterleavedBatch]):
blocksize: int | str | None = None
max_batch_bytes: int | None = None
read_kwargs: dict[str, Any] = field(default_factory=dict)
enable_array_partitioning: bool = False
shard_index: int | str | None = None
total_shards: int | str | None = None
minimum_shard_index: int | str = 0
materialize_on_read: bool = False
file_extensions: list[str] = field(default_factory=lambda: list(DEFAULT_WEBDATASET_EXTENSIONS))
json_extensions: list[str] = field(default_factory=lambda: list(DEFAULT_JSON_EXTENSIONS))
Expand All @@ -68,6 +72,10 @@ def decompose(self) -> list:
blocksize=self.blocksize,
file_extensions=self.file_extensions,
storage_options=self.storage_options,
enable_array_partitioning=self.enable_array_partitioning,
shard_index=self.shard_index,
total_shards=self.total_shards,
minimum_shard_index=self.minimum_shard_index,
),
InterleavedWebdatasetReaderStage(
read_kwargs=self.read_kwargs,
Expand Down Expand Up @@ -96,6 +104,10 @@ class InterleavedParquetReader(CompositeStage[EmptyTask, InterleavedBatch]):
fields: tuple[str, ...] | None = None
max_batch_bytes: int | None = None
read_kwargs: dict[str, Any] = field(default_factory=dict)
enable_array_partitioning: bool = False
shard_index: int | str | None = None
total_shards: int | str | None = None
minimum_shard_index: int | str = 0
schema: pa.Schema | None = None
schema_overrides: dict[str, pa.DataType] | None = None
file_extensions: list[str] = field(default_factory=lambda: [".parquet"])
Expand All @@ -113,6 +125,10 @@ def decompose(self) -> list:
blocksize=self.blocksize,
file_extensions=self.file_extensions,
storage_options=self.storage_options,
enable_array_partitioning=self.enable_array_partitioning,
shard_index=self.shard_index,
total_shards=self.total_shards,
minimum_shard_index=self.minimum_shard_index,
),
InterleavedParquetReaderStage(
read_kwargs=self.read_kwargs,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Loading
Loading