Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions nemo_curator/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from loguru import logger

from nemo_curator.core.utils import ignore_ray_head_node
from nemo_curator.tasks import Task
from nemo_curator.tasks.sentinels import FailedTask, NoneTask
from nemo_curator.utils.performance_utils import StageTimer
from nemo_curator.utils.resumability_client import _flush_deltas, _is_active, _skip_completed_sources

if TYPE_CHECKING:
from nemo_curator.stages.base import ProcessingStage
Expand Down Expand Up @@ -85,9 +89,23 @@ def process_batch(self, tasks: list[Task]) -> list[Task]:
# Use the batch processing logic
results = self.stage.process_batch(tasks)

# A returned ``None`` ("filter this slot") becomes a NoneTask so every
# output is a real Task that gets a task_id. Sentinels (NoneTask /
# FailedTask) carry no identity and are stripped again before this
# method returns.
results = [NoneTask() if r is None else r for r in results]

# Guarantee every emitted task has a task_id (derived id, or uuid fallback).
results = self._post_process_task_ids(tasks, results)

# Opt-in resumability: fire per-source counter deltas. A no-op (the
# client helpers self-disable) when no resumability actor is registered.
if _is_active():
results = self._apply_resumability_counters(tasks, results)

# Sentinels never propagate to the next stage.
results = [r for r in results if not isinstance(r, (NoneTask, FailedTask))]

# Log performance stats and add to result tasks
_, stage_perf_stats = self._timer.log_stats()
# Consume and attach any custom metrics recorded by the stage during this call
Expand Down Expand Up @@ -168,6 +186,82 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas
task.task_id = "r" + uuid.uuid4().hex
return out

# ------------------------------------------------------------------ #
# Resumability (opt-in). Runs only when a resumability actor is
# registered. task_ids are already assigned by _post_process_task_ids;
# this layer only stamps _source_id, fires per-source counter deltas, and
# drops already-completed sources. Sentinels are stripped by the caller.
# ------------------------------------------------------------------ #
def _apply_resumability_counters(self, input_tasks: list[Task], output_tasks: list[Task]) -> list[Task]: # noqa: C901
stage = self.stage
if getattr(stage, "is_source_stage", False):
return self._source_counters(output_tasks)

# Pre-source stages: inputs carry no _source_id, so there's nothing to
# track yet. Leave outputs untouched.
if all(not t._source_id for t in input_tasks):
return output_tasks

is_sink = stage.is_sink_stage

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General question: why would the user ever want something other than source stage being the first stage and sink stage being the last stage of the pipeline? Like if the last stage failed but the second to last stage was the sink stage, they just don't want to rerun the last stage?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For sync stage not being the last stage, I think PDF pipelines are a good example. I have this metadata stage called stagePerfLogging at the end, which is needed because I cannot do stuff similar to benchmarking, since the pipeline.run never returns.

As for the source stage, I don't have an example in mind, but we don't need to force this assumption hence, I would prefer for us to keep it relaxed. From a user's perspective, if they don't specify, the default is that the first stage is source and the last stage is sync.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also for source we have a case where user might provide initial task and no source stage is defined.

per_task: list[tuple[str, str, int]] = []
real = [t for t in output_tasks if not isinstance(t, (NoneTask, FailedTask))]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like since real is only used in the if len(input_tasks) == 1 and len(output_tasks) != 1: block, let's move it there for easier reading?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. WIll make this edit if we have other edits as well in bulk


if len(input_tasks) == 1 and len(output_tasks) != 1:
# Genuine fan-out (1 -> N, N != 1): every real output descends from
# the single input. (The 1 -> 1 case falls through to the positional
# branch so a lone FailedTask is handled as "no delta".)
parent = input_tasks[0]
delta = -1 if is_sink else (len(real) - 1)
per_task.append((parent.task_id, parent._source_id, delta))
for c in real:
if not c._source_id:
c._source_id = parent._source_id
Comment on lines +209 to +218

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Fan-out branch silently completes sources that have FailedTask outputs

The fan-out delta len(real) - 1 excludes FailedTask items from real, but does not compensate for the missing "pending debt" those failed slots represent. Concretely: for 1 input → [real_A, failed_B] at a non-sink stage, len(real) = 1, delta = 0, counter stays at 1; when real_A reaches the sink it fires -1 and the counter hits 0 — the source is marked complete, but failed_B was supposed to keep it pending for retry. For sink fan-out the bug is even more immediate: is_sink forces delta to -1 regardless of real, so the counter zeros out the moment the batch emits even if every output is a FailedTask.

The 1:1 positional branch handles FailedTask correctly with an explicit continue (no delta). The fan-out branch has no analogous guard. The simplest conservative fix is to bail out (no delta fired) whenever any fan-out output is a FailedTask — the source stays pending and will reprocess on resume, consistent with FailedTask's documented semantics.

elif len(output_tasks) == len(input_tasks):
# Positional 1:1, including filtered (NoneTask) / failed slots.
for parent, r in zip(input_tasks, output_tasks, strict=True):
sid = parent._source_id
if isinstance(r, FailedTask):
# No delta: the input stays pending so its source reruns.
continue
if isinstance(r, NoneTask):
per_task.append((parent.task_id, sid, -1))
continue
per_task.append((parent.task_id, sid, -1 if is_sink else 0))
if not r._source_id:
r._source_id = sid
else:
# M inputs -> K outputs (K != M): the parent of each output can't be
# determined, so the counter can't be updated correctly. Skip
# (the source counter stays pending -> reprocessed on resume).
logger.warning(
f"resumability: {type(stage).__name__} produced {len(output_tasks)} outputs "
f"for {len(input_tasks)} inputs; can't attribute sources, skipping counter "
f"update for this batch."
)
return output_tasks

_flush_deltas(per_task)
return output_tasks

def _source_counters(self, output_tasks: list[Task]) -> list[Task]:
"""Source stage: each output is a source partition. Its ``_source_id``
is its own (last) id segment — the content id or index assigned by
``_post_process_task_ids``. Already-completed sources are dropped; each
surviving source fires a ``+1``."""
sources = [t for t in output_tasks if not isinstance(t, (NoneTask, FailedTask))]
for t in sources:
t._source_id = t.task_id.rsplit("_", 1)[-1]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 rsplit strips only the last segment, not the full suffix

t._source_id = t.task_id.rsplit("_", 1)[-1] takes only the last underscore-delimited token. For a source with task_id = "0_abc_def" (produced when get_deterministic_id() returns "abc_def"), this yields "def" instead of "abc_def". Two unrelated sources whose deterministic IDs differ only in a prefix (e.g. "shard1_42" and "shard2_42") would get the same _source_id = "42", and the first to complete would cause the second to be silently skipped on the next run. Since get_deterministic_id() is explicitly documented as overridable, this is a real footgun.

The correct extraction for sources is to strip just the parent prefix ("0_") using split("_", 1)[1] so that the entire content-based suffix is preserved as the source identity.

completed = _skip_completed_sources([t._source_id for t in sources])
per_task: list[tuple[str, str, int]] = []
survivors: list[Task] = []
for t in sources:
if t._source_id in completed:
continue
per_task.append((t.task_id, t._source_id, +1))
survivors.append(t)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a little confused why the source ID here is the last part of the X_Y_Z chain? I guess maybe I am confused about how task_id versus _source_id are constructed/formatted.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is since source stage was the 3rd stage in your example. So each task_id is of one of the following format:

0_1_2_{sid}_2_1
0_1_2_0_1_1
r{uuid}_1_2

For the {sid}, the index purely depends on what index does source stage happens. It always starts with zero since an empty task. For most cases, the task_id will have this format:

0_{sid}_0_0_0_0: Single fanout at source and then filters.

_flush_deltas(per_task)
return survivors

def setup_on_node(self, node_info: NodeInfo | None = None, worker_metadata: WorkerMetadata | None = None) -> None:
"""Setup the stage on a node.

Expand Down
64 changes: 62 additions & 2 deletions nemo_curator/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Any

from loguru import logger
Expand Down Expand Up @@ -222,18 +223,35 @@ def describe(self) -> str:

return "\n".join(lines)

def run(self, executor: BaseExecutor | None = None, initial_tasks: list[Task] | None = None) -> list[Task] | None:
def run(
self,
executor: BaseExecutor | None = None,
initial_tasks: list[Task] | None = None,
checkpoint_path: str | Path | None = None,
) -> list[Task] | None:
"""Run the pipeline.

Args:
executor (BaseExecutor): Executor to use
initial_tasks (list[Task], optional): Initial tasks to start the pipeline with. Defaults to None.
checkpoint_path (str | Path, optional): Directory used for
resumability. When set, completed source partitions are tracked
across runs and skipped on rerun; the tracking state lives in a
``.nemo_curator_metadata`` subdirectory. Multiple independent
runs (e.g. the tasks of a SLURM array) may point at the same
directory — each writes its own LMDB file, so there is no
shared-file contention. The actor lifecycle is owned by this
method; executors are not modified.

Returns:
list[Task] | None: List of tasks
"""
self.build()

if checkpoint_path is not None:
checkpoint_path = Path(checkpoint_path).absolute()
checkpoint_path.mkdir(parents=True, exist_ok=True)

if executor is None:
from nemo_curator.backends.xenna import XennaExecutor

Expand Down Expand Up @@ -263,4 +281,46 @@ def run(self, executor: BaseExecutor | None = None, initial_tasks: list[Task] |
if initial_tasks:
assign_root_task_ids(initial_tasks)

return executor.execute(self.stages, initial_tasks)
if checkpoint_path is None:
return executor.execute(self.stages, initial_tasks)
return self._run_with_resumability(executor, initial_tasks, checkpoint_path)

def _run_with_resumability(
self,
executor: BaseExecutor,
initial_tasks: list[Task] | None,
checkpoint_path: Path,
) -> list[Task] | None:
"""Owns the full resumability-actor lifecycle. Per-backend executors
are not modified — the actor is spawned ``lifetime="detached"`` so
it survives executor-local ``ray.shutdown()`` calls.

The actor never raises (see ``ResumabilityActor.apply_deltas``), so
there's no watchdog and no error propagation path here — just spawn,
run, close.
"""
import ray

from nemo_curator.utils.resumability_actor import ResumabilityActor
from nemo_curator.utils.resumability_client import ACTOR_NAME

ray.init(ignore_reinit_error=True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? Shouldn't there have already been a RayClient().start() before pipeline.run() was called?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not enforce that right now, right? Like if the user forgets to include that, still our pipelines run. Ideally, I can add a check in the pipeline.run saying please either start a Ray client with RayClient.start() or SlurmClient with SLurmClient.start(). Would you prefer that? I would personally prefer that tbh.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah good point, it might be nice to enforce it in pipeline.run().

ResumabilityActor.options( # type: ignore[attr-defined]
name=ACTOR_NAME,
lifetime="detached",
get_if_exists=True,
max_pending_calls=100,
).remote(str(checkpoint_path))

try:
return executor.execute(self.stages, initial_tasks)
finally:
# The executor's ray.shutdown() may have run in its own
# finally:; reconnect to clean up the detached actor.
try:
ray.init(ignore_reinit_error=True)
actor_handle = ray.get_actor(ACTOR_NAME)
ray.get(actor_handle.close.remote(), timeout=10) # type: ignore[attr-defined]
ray.kill(actor_handle)
except Exception as e: # noqa: BLE001
logger.warning(f"resumability actor cleanup failed: {e}")
Comment on lines +317 to +326

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Fire-and-forget deltas from the last batch can be discarded before close() drains them

executor.execute() returns when all task-processing workers are done, but the workers' fire-and-forget apply_deltas.remote() calls are async Ray messages that may still be in transit. Ray does not guarantee cross-actor ordering: the driver's close.remote() message can arrive at the actor before the workers' final apply_deltas messages, because message ordering is only guaranteed per sender-receiver pair.

With max_concurrency=1 the actor processes one call at a time, so once close() runs it sets self._env = None. Any apply_deltas messages that arrive after that and happen to call _persist_completed will hit an AttributeError on None.begin() inside the actor. ray.kill immediately afterward discards all remaining queued messages, so completions from the pipeline's final batch may not be persisted to LMDB.

The safest fix is to enqueue a no-op drain call after executor.execute() returns and before close(), ensuring all prior fire-and-forget messages are ahead of it in the mailbox.

4 changes: 3 additions & 1 deletion nemo_curator/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,19 @@
from .file_group import FileGroupTask
from .image import ImageBatch, ImageObject
from .interleaved import InterleavedBatch
from .sentinels import EmptyTask, SentinelTask
from .sentinels import EmptyTask, FailedTask, NoneTask, SentinelTask
from .tasks import Task

__all__ = [
"AudioTask",
"DocumentBatch",
"EmptyTask",
"FailedTask",
"FileGroupTask",
"ImageBatch",
"ImageObject",
"InterleavedBatch",
"NoneTask",
"SentinelTask",
"Task",
]
29 changes: 26 additions & 3 deletions nemo_curator/tasks/sentinels.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,18 @@
# limitations under the License.
"""Payload-less marker tasks.

``EmptyTask`` seeds a pipeline (the implicit root id ``"0"``). All markers
share the :class:`SentinelTask` base and carry no payload (``data is None``).
Construct one with ``EmptyTask()``.
``EmptyTask`` seeds a pipeline (the implicit root id ``"0"``). The resumability
layer adds two more markers on the same :class:`SentinelTask` base:

- ``NoneTask`` — this slot was intentionally filtered. The resumability counter
treats it as a consumed branch (decrements). The adapter auto-wraps a
returned ``None`` as a ``NoneTask``.
- ``FailedTask`` — this slot failed and should be retried on resume. The counter
is NOT decremented, so its source stays pending and reruns.

All carry no payload (``data is None``) and get their ``task_id`` assigned by
the executor adapter; sentinels are stripped before the next stage. Construct
with ``EmptyTask()`` / ``NoneTask()`` / ``FailedTask()``.
"""

from dataclasses import dataclass, field
Expand Down Expand Up @@ -52,3 +61,17 @@ class EmptyTask(SentinelTask):

dataset_name: str = "empty"
task_id: str = field(init=False, default="0")


@dataclass
class NoneTask(SentinelTask):
"""Marks a slot as intentionally filtered (resumability counter decrements)."""

dataset_name: str = "none"


@dataclass
class FailedTask(SentinelTask):
Comment thread
sarahyurick marked this conversation as resolved.
"""Marks a slot as failed → retried on resume (counter does NOT decrement)."""

dataset_name: str = "failed"
5 changes: 5 additions & 0 deletions nemo_curator/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,18 @@ class Task(ABC, Generic[T]):
NON-deterministic (differ across runs).
dataset_name: Name of the dataset this task belongs to.
_stage_perf: List of stages perfs this task has passed through.
_source_id: Identifier of the source (input partition) this task

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit-ish but the docstring for Task is out of order and does not include data or _metadata.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will fix if there are other updates

descends from. Stamped at the source stage and inherited
downstream; used only by the (opt-in) resumability layer to
track which sources have completed. Empty for pre-source tasks.
"""

dataset_name: str
data: T
_stage_perf: list[StagePerfStats] = field(default_factory=list)
_metadata: dict[str, Any] = field(default_factory=dict)
task_id: str = field(init=False, default="")
_source_id: str = field(init=False, default="")

def __post_init__(self) -> None:
"""Post-initialization hook."""
Expand Down
Loading
Loading