Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions nemo_curator/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from loguru import logger

from nemo_curator.core.utils import ignore_ray_head_node
from nemo_curator.tasks import Task
from nemo_curator.tasks.sentinels import FailedTask, NoneTask
from nemo_curator.utils.performance_utils import StageTimer
from nemo_curator.utils.resumability_client import _flush_deltas, _is_active, _skip_completed_sources

if TYPE_CHECKING:
from nemo_curator.stages.base import ProcessingStage
Expand Down Expand Up @@ -85,9 +89,23 @@ def process_batch(self, tasks: list[Task]) -> list[Task]:
# Use the batch processing logic
results = self.stage.process_batch(tasks)

# A returned ``None`` ("filter this slot") becomes a NoneTask so every
# output is a real Task that gets a task_id. Sentinels (NoneTask /
# FailedTask) carry no identity and are stripped again before this
# method returns.
results = [NoneTask() if r is None else r for r in results]

# Guarantee every emitted task has a task_id (derived id, or uuid fallback).
results = self._post_process_task_ids(tasks, results)

# Opt-in resumability: fire per-source counter deltas. A no-op (the
# client helpers self-disable) when no resumability actor is registered.
if _is_active():
results = self._apply_resumability_counters(tasks, results)

# Sentinels never propagate to the next stage.
results = [r for r in results if not isinstance(r, (NoneTask, FailedTask))]

# Log performance stats and add to result tasks
_, stage_perf_stats = self._timer.log_stats()
# Consume and attach any custom metrics recorded by the stage during this call
Expand Down Expand Up @@ -168,6 +186,82 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas
task.task_id = "r" + uuid.uuid4().hex
return out

# ------------------------------------------------------------------ #
# Resumability (opt-in). Runs only when a resumability actor is
# registered. task_ids are already assigned by _post_process_task_ids;
# this layer only stamps _source_id, fires per-source counter deltas, and
# drops already-completed sources. Sentinels are stripped by the caller.
# ------------------------------------------------------------------ #
def _apply_resumability_counters(self, input_tasks: list[Task], output_tasks: list[Task]) -> list[Task]: # noqa: C901
stage = self.stage
if getattr(stage, "is_source_stage", False):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General question: why would the user ever want something other than source stage being the first stage and sink stage being the last stage of the pipeline? Like if the last stage failed but the second to last stage was the sink stage, they just don't want to rerun the last stage?

return self._source_counters(output_tasks)

# Pre-source stages: inputs carry no _source_id, so there's nothing to
# track yet. Leave outputs untouched.
if all(not t._source_id for t in input_tasks):
return output_tasks

Comment on lines +195 to +204

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Single-stage pipeline never marks sources complete

_apply_resumability_counters short-circuits to _source_counters whenever is_source_stage is True, returning before the is_sink_stage branch can fire a −1 delta. Pipeline.build() assigns both flags to the same stage when only one stage is present (stages[0] == stages[-1]). As a result, every source fires +1 but nothing ever fires −1, so _pending stays at 1 for every source indefinitely and the checkpoint is never written. Every resume reprocesses the full pipeline.

is_sink = stage.is_sink_stage
per_task: list[tuple[str, str, int]] = []
real = [t for t in output_tasks if not isinstance(t, (NoneTask, FailedTask))]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like since real is only used in the if len(input_tasks) == 1 and len(output_tasks) != 1: block, let's move it there for easier reading?


if len(input_tasks) == 1 and len(output_tasks) != 1:
# Genuine fan-out (1 -> N, N != 1): every real output descends from
# the single input. (The 1 -> 1 case falls through to the positional
# branch so a lone FailedTask is handled as "no delta".)
parent = input_tasks[0]
delta = -1 if is_sink else (len(real) - 1)
per_task.append((parent.task_id, parent._source_id, delta))
for c in real:
if not c._source_id:
c._source_id = parent._source_id
elif len(output_tasks) == len(input_tasks):
# Positional 1:1, including filtered (NoneTask) / failed slots.
for parent, r in zip(input_tasks, output_tasks, strict=True):
sid = parent._source_id
if isinstance(r, FailedTask):
# No delta: the input stays pending so its source reruns.
continue
if isinstance(r, NoneTask):
per_task.append((parent.task_id, sid, -1))
continue
per_task.append((parent.task_id, sid, -1 if is_sink else 0))
if not r._source_id:
r._source_id = sid
else:
# M inputs -> K outputs (K != M): the parent of each output can't be
# determined, so the counter can't be updated correctly. Skip
# (the source counter stays pending -> reprocessed on resume).
logger.warning(
f"resumability: {type(stage).__name__} produced {len(output_tasks)} outputs "
f"for {len(input_tasks)} inputs; can't attribute sources, skipping counter "
f"update for this batch."
)
return output_tasks

_flush_deltas(per_task)
return output_tasks

def _source_counters(self, output_tasks: list[Task]) -> list[Task]:
"""Source stage: each output is a source partition. Its ``_source_id``
is its own (last) id segment — the content id or index assigned by
``_post_process_task_ids``. Already-completed sources are dropped; each
surviving source fires a ``+1``."""
sources = [t for t in output_tasks if not isinstance(t, (NoneTask, FailedTask))]
for t in sources:
t._source_id = t.task_id.rsplit("_", 1)[-1]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a little confused why the source ID here is the last part of the X_Y_Z chain? I guess maybe I am confused about how task_id versus _source_id are constructed/formatted.

completed = _skip_completed_sources([t._source_id for t in sources])
per_task: list[tuple[str, str, int]] = []
survivors: list[Task] = []
for t in sources:
if t._source_id in completed:
continue
per_task.append((t.task_id, t._source_id, +1))
survivors.append(t)
_flush_deltas(per_task)
return survivors

def setup_on_node(self, node_info: NodeInfo | None = None, worker_metadata: WorkerMetadata | None = None) -> None:
"""Setup the stage on a node.

Expand Down
64 changes: 62 additions & 2 deletions nemo_curator/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Any

from loguru import logger
Expand Down Expand Up @@ -222,18 +223,35 @@ def describe(self) -> str:

return "\n".join(lines)

def run(self, executor: BaseExecutor | None = None, initial_tasks: list[Task] | None = None) -> list[Task] | None:
def run(
self,
executor: BaseExecutor | None = None,
initial_tasks: list[Task] | None = None,
checkpoint_path: str | Path | None = None,
) -> list[Task] | None:
"""Run the pipeline.

Args:
executor (BaseExecutor): Executor to use
initial_tasks (list[Task], optional): Initial tasks to start the pipeline with. Defaults to None.
checkpoint_path (str | Path, optional): Directory used for
resumability. When set, completed source partitions are tracked
across runs and skipped on rerun; the tracking state lives in a
``.nemo_curator_metadata`` subdirectory. Multiple independent
runs (e.g. the tasks of a SLURM array) may point at the same
directory — each writes its own LMDB file, so there is no
shared-file contention. The actor lifecycle is owned by this
method; executors are not modified.

Returns:
list[Task] | None: List of tasks
"""
self.build()

if checkpoint_path is not None:
checkpoint_path = Path(checkpoint_path).absolute()
checkpoint_path.mkdir(parents=True, exist_ok=True)

if executor is None:
from nemo_curator.backends.xenna import XennaExecutor

Expand Down Expand Up @@ -263,4 +281,46 @@ def run(self, executor: BaseExecutor | None = None, initial_tasks: list[Task] |
if initial_tasks:
assign_root_task_ids(initial_tasks)

return executor.execute(self.stages, initial_tasks)
if checkpoint_path is None:
return executor.execute(self.stages, initial_tasks)
return self._run_with_resumability(executor, initial_tasks, checkpoint_path)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not like this, the reason this is being done is probably coz we call ray init inside the executor and not here, so we need to differentiate that. Not ideal for sure. Maybe we can make a function for _start_resumability_actor and call it inside each executor. Design choice.


def _run_with_resumability(
self,
executor: BaseExecutor,
initial_tasks: list[Task] | None,
checkpoint_path: Path,
) -> list[Task] | None:
"""Owns the full resumability-actor lifecycle. Per-backend executors
are not modified — the actor is spawned ``lifetime="detached"`` so
it survives executor-local ``ray.shutdown()`` calls.

The actor never raises (see ``ResumabilityActor.apply_deltas``), so
there's no watchdog and no error propagation path here — just spawn,
run, close.
"""
import ray

from nemo_curator.utils.resumability_actor import ResumabilityActor
from nemo_curator.utils.resumability_client import ACTOR_NAME

ray.init(ignore_reinit_error=True)
ResumabilityActor.options( # type: ignore[attr-defined]
name=ACTOR_NAME,
lifetime="detached",
get_if_exists=True,
max_pending_calls=100,
).remote(str(checkpoint_path))
Comment on lines +308 to +313

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Fire-and-forget delta loss when _max_pending_calls=100 is exceeded

_max_pending_calls=100 on the actor causes Ray to embed a RayActorError in the returned ObjectRef when the queue is full. Because _flush_deltas never calls ray.get() on that ref, the rejection is silently swallowed and the delta is permanently lost. For a pipeline processing more than 100 batches before the actor can drain its queue (trivially possible at high throughput), any source whose counter delta was lost will have _pending stuck above 0 and will never be written to LMDB. The PR description claims "_max_pending_calls provides backpressure", but the actual behavior is silent delta loss, not blocking — the two have opposite throughput implications. A limit of 100 is also very low relative to a typical pipeline's batch concurrency. Consider either raising the limit substantially, adding a ray.get() health check in the watchdog thread to surface errors, or verifying the actual Ray behavior for the fire-and-forget case.

Comment on lines +308 to +313

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Actor handle is discarded; silent resumability bypass on spawn failure

ResumabilityActor.options(...).remote(str(checkpoint_path)) returns an actor handle that is immediately discarded. Ray actor creation is asynchronous — if the actor fails to initialize (e.g., resource pressure, LMDB permission error), the .remote() call still returns a handle, the failure is deferred, and subsequent ray.get_actor(ACTOR_NAME) in _is_active() will raise ValueError returning None. The pipeline then runs without resumability and without any user-visible error; completed sources are never checkpointed.

Capturing the handle and making at least one synchronous ray.get call on a lightweight probe (e.g., are_completed([])) before the executor runs would surface initialization failures eagerly.


try:
return executor.execute(self.stages, initial_tasks)
finally:
# The executor's ray.shutdown() may have run in its own
# finally:; reconnect to clean up the detached actor.
try:
ray.init(ignore_reinit_error=True)
actor_handle = ray.get_actor(ACTOR_NAME)
ray.get(actor_handle.close.remote(), timeout=10) # type: ignore[attr-defined]
ray.kill(actor_handle)
except Exception as e: # noqa: BLE001
logger.warning(f"resumability actor cleanup failed: {e}")
Comment on lines +320 to +326

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 ray.kill is unreachable when the close() RPC times out. ray.get(..., timeout=10) raises GetTimeoutError (a subclass of Exception), which is caught by the outer except Exception block — so ray.kill(actor_handle) on the next line is never executed. The detached actor keeps running and holds the LMDB file open indefinitely. On the next Pipeline.run(checkpoint_path=...) call, get_if_exists=True returns the leaked actor, which is pointing at the original path but may have stale in-memory state.

Separate the kill so it always runs regardless of whether close() succeeds.

Suggested change
try:
ray.init(ignore_reinit_error=True)
actor_handle = ray.get_actor(ACTOR_NAME)
ray.get(actor_handle.close.remote(), timeout=10) # type: ignore[attr-defined]
ray.kill(actor_handle)
except Exception as e: # noqa: BLE001
logger.warning(f"resumability actor cleanup failed: {e}")
try:
ray.init(ignore_reinit_error=True)
actor_handle = ray.get_actor(ACTOR_NAME)
try:
ray.get(actor_handle.close.remote(), timeout=10) # type: ignore[attr-defined]
except Exception as e: # noqa: BLE001
logger.warning(f"resumability actor close failed: {e}")
finally:
ray.kill(actor_handle)
except Exception as e: # noqa: BLE001
logger.warning(f"resumability actor cleanup failed: {e}")

4 changes: 3 additions & 1 deletion nemo_curator/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,19 @@
from .file_group import FileGroupTask
from .image import ImageBatch, ImageObject
from .interleaved import InterleavedBatch
from .sentinels import EmptyTask, SentinelTask
from .sentinels import EmptyTask, FailedTask, NoneTask, SentinelTask
from .tasks import Task

__all__ = [
"AudioTask",
"DocumentBatch",
"EmptyTask",
"FailedTask",
"FileGroupTask",
"ImageBatch",
"ImageObject",
"InterleavedBatch",
"NoneTask",
"SentinelTask",
"Task",
]
29 changes: 26 additions & 3 deletions nemo_curator/tasks/sentinels.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,18 @@
# limitations under the License.
"""Payload-less marker tasks.

``EmptyTask`` seeds a pipeline (the implicit root id ``"0"``). All markers
share the :class:`SentinelTask` base and carry no payload (``data is None``).
Construct one with ``EmptyTask()``.
``EmptyTask`` seeds a pipeline (the implicit root id ``"0"``). The resumability
layer adds two more markers on the same :class:`SentinelTask` base:

- ``NoneTask`` — this slot was intentionally filtered. The resumability counter
treats it as a consumed branch (decrements). The adapter auto-wraps a
returned ``None`` as a ``NoneTask``.
- ``FailedTask`` — this slot failed and should be retried on resume. The counter
is NOT decremented, so its source stays pending and reruns.

All carry no payload (``data is None``) and get their ``task_id`` assigned by
the executor adapter; sentinels are stripped before the next stage. Construct
with ``EmptyTask()`` / ``NoneTask()`` / ``FailedTask()``.
"""

from dataclasses import dataclass, field
Expand Down Expand Up @@ -52,3 +61,17 @@ class EmptyTask(SentinelTask):

dataset_name: str = "empty"
task_id: str = field(init=False, default="0")


@dataclass
class NoneTask(SentinelTask):
"""Marks a slot as intentionally filtered (resumability counter decrements)."""

dataset_name: str = "none"


@dataclass
class FailedTask(SentinelTask):
"""Marks a slot as failed → retried on resume (counter does NOT decrement)."""

dataset_name: str = "failed"
5 changes: 5 additions & 0 deletions nemo_curator/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,18 @@ class Task(ABC, Generic[T]):
NON-deterministic (differ across runs).
dataset_name: Name of the dataset this task belongs to.
_stage_perf: List of stages perfs this task has passed through.
_source_id: Identifier of the source (input partition) this task
descends from. Stamped at the source stage and inherited
downstream; used only by the (opt-in) resumability layer to
track which sources have completed. Empty for pre-source tasks.
"""

dataset_name: str
data: T
_stage_perf: list[StagePerfStats] = field(default_factory=list)
_metadata: dict[str, Any] = field(default_factory=dict)
task_id: str = field(init=False, default="")
_source_id: str = field(init=False, default="")

def __post_init__(self) -> None:
"""Post-initialization hook."""
Expand Down
Loading
Loading