-
Notifications
You must be signed in to change notification settings - Fork 287
Pipeline resumability via source-level counter checkpointing #2063
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,9 +17,13 @@ | |
| from dataclasses import dataclass | ||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| from loguru import logger | ||
|
|
||
| from nemo_curator.core.utils import ignore_ray_head_node | ||
| from nemo_curator.tasks import Task | ||
| from nemo_curator.tasks.sentinels import FailedTask, NoneTask | ||
| from nemo_curator.utils.performance_utils import StageTimer | ||
| from nemo_curator.utils.resumability_client import _flush_deltas, _is_active, _skip_completed_sources | ||
|
|
||
| if TYPE_CHECKING: | ||
| from nemo_curator.stages.base import ProcessingStage | ||
|
|
@@ -85,9 +89,23 @@ def process_batch(self, tasks: list[Task]) -> list[Task]: | |
| # Use the batch processing logic | ||
| results = self.stage.process_batch(tasks) | ||
|
|
||
| # A returned ``None`` ("filter this slot") becomes a NoneTask so every | ||
| # output is a real Task that gets a task_id. Sentinels (NoneTask / | ||
| # FailedTask) carry no identity and are stripped again before this | ||
| # method returns. | ||
| results = [NoneTask() if r is None else r for r in results] | ||
|
|
||
| # Guarantee every emitted task has a task_id (derived id, or uuid fallback). | ||
| results = self._post_process_task_ids(tasks, results) | ||
|
|
||
| # Opt-in resumability: fire per-source counter deltas. A no-op (the | ||
| # client helpers self-disable) when no resumability actor is registered. | ||
| if _is_active(): | ||
| results = self._apply_resumability_counters(tasks, results) | ||
|
|
||
| # Sentinels never propagate to the next stage. | ||
| results = [r for r in results if not isinstance(r, (NoneTask, FailedTask))] | ||
|
|
||
| # Log performance stats and add to result tasks | ||
| _, stage_perf_stats = self._timer.log_stats() | ||
| # Consume and attach any custom metrics recorded by the stage during this call | ||
|
|
@@ -168,6 +186,82 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas | |
| task.task_id = "r" + uuid.uuid4().hex | ||
| return out | ||
|
|
||
| # ------------------------------------------------------------------ # | ||
| # Resumability (opt-in). Runs only when a resumability actor is | ||
| # registered. task_ids are already assigned by _post_process_task_ids; | ||
| # this layer only stamps _source_id, fires per-source counter deltas, and | ||
| # drops already-completed sources. Sentinels are stripped by the caller. | ||
| # ------------------------------------------------------------------ # | ||
| def _apply_resumability_counters(self, input_tasks: list[Task], output_tasks: list[Task]) -> list[Task]: # noqa: C901 | ||
| stage = self.stage | ||
| if getattr(stage, "is_source_stage", False): | ||
| return self._source_counters(output_tasks) | ||
|
|
||
| # Pre-source stages: inputs carry no _source_id, so there's nothing to | ||
| # track yet. Leave outputs untouched. | ||
| if all(not t._source_id for t in input_tasks): | ||
| return output_tasks | ||
|
|
||
| is_sink = stage.is_sink_stage | ||
| per_task: list[tuple[str, str, int]] = [] | ||
| real = [t for t in output_tasks if not isinstance(t, (NoneTask, FailedTask))] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems like since
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure. WIll make this edit if we have other edits as well in bulk |
||
|
|
||
| if len(input_tasks) == 1 and len(output_tasks) != 1: | ||
| # Genuine fan-out (1 -> N, N != 1): every real output descends from | ||
| # the single input. (The 1 -> 1 case falls through to the positional | ||
| # branch so a lone FailedTask is handled as "no delta".) | ||
| parent = input_tasks[0] | ||
| delta = -1 if is_sink else (len(real) - 1) | ||
| per_task.append((parent.task_id, parent._source_id, delta)) | ||
| for c in real: | ||
| if not c._source_id: | ||
| c._source_id = parent._source_id | ||
|
Comment on lines
+209
to
+218
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The fan-out delta The 1:1 positional branch handles |
||
| elif len(output_tasks) == len(input_tasks): | ||
| # Positional 1:1, including filtered (NoneTask) / failed slots. | ||
| for parent, r in zip(input_tasks, output_tasks, strict=True): | ||
| sid = parent._source_id | ||
| if isinstance(r, FailedTask): | ||
| # No delta: the input stays pending so its source reruns. | ||
| continue | ||
| if isinstance(r, NoneTask): | ||
| per_task.append((parent.task_id, sid, -1)) | ||
| continue | ||
| per_task.append((parent.task_id, sid, -1 if is_sink else 0)) | ||
| if not r._source_id: | ||
| r._source_id = sid | ||
| else: | ||
| # M inputs -> K outputs (K != M): the parent of each output can't be | ||
| # determined, so the counter can't be updated correctly. Skip | ||
| # (the source counter stays pending -> reprocessed on resume). | ||
| logger.warning( | ||
| f"resumability: {type(stage).__name__} produced {len(output_tasks)} outputs " | ||
| f"for {len(input_tasks)} inputs; can't attribute sources, skipping counter " | ||
| f"update for this batch." | ||
| ) | ||
| return output_tasks | ||
|
|
||
| _flush_deltas(per_task) | ||
| return output_tasks | ||
|
|
||
| def _source_counters(self, output_tasks: list[Task]) -> list[Task]: | ||
| """Source stage: each output is a source partition. Its ``_source_id`` | ||
| is its own (last) id segment — the content id or index assigned by | ||
| ``_post_process_task_ids``. Already-completed sources are dropped; each | ||
| surviving source fires a ``+1``.""" | ||
| sources = [t for t in output_tasks if not isinstance(t, (NoneTask, FailedTask))] | ||
| for t in sources: | ||
| t._source_id = t.task_id.rsplit("_", 1)[-1] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The correct extraction for sources is to strip just the parent prefix ( |
||
| completed = _skip_completed_sources([t._source_id for t in sources]) | ||
| per_task: list[tuple[str, str, int]] = [] | ||
| survivors: list[Task] = [] | ||
| for t in sources: | ||
| if t._source_id in completed: | ||
| continue | ||
| per_task.append((t.task_id, t._source_id, +1)) | ||
| survivors.append(t) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am a little confused why the source ID here is the last part of the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is since source stage was the 3rd stage in your example. So each task_id is of one of the following format: 0_1_2_{sid}_2_1 For the {sid}, the index purely depends on what index does source stage happens. It always starts with zero since an empty task. For most cases, the task_id will have this format: 0_{sid}_0_0_0_0: Single fanout at source and then filters. |
||
| _flush_deltas(per_task) | ||
| return survivors | ||
|
|
||
| def setup_on_node(self, node_info: NodeInfo | None = None, worker_metadata: WorkerMetadata | None = None) -> None: | ||
| """Setup the stage on a node. | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| from loguru import logger | ||
|
|
@@ -222,18 +223,35 @@ def describe(self) -> str: | |
|
|
||
| return "\n".join(lines) | ||
|
|
||
| def run(self, executor: BaseExecutor | None = None, initial_tasks: list[Task] | None = None) -> list[Task] | None: | ||
| def run( | ||
| self, | ||
| executor: BaseExecutor | None = None, | ||
| initial_tasks: list[Task] | None = None, | ||
| checkpoint_path: str | Path | None = None, | ||
| ) -> list[Task] | None: | ||
| """Run the pipeline. | ||
|
|
||
| Args: | ||
| executor (BaseExecutor): Executor to use | ||
| initial_tasks (list[Task], optional): Initial tasks to start the pipeline with. Defaults to None. | ||
| checkpoint_path (str | Path, optional): Directory used for | ||
| resumability. When set, completed source partitions are tracked | ||
| across runs and skipped on rerun; the tracking state lives in a | ||
| ``.nemo_curator_metadata`` subdirectory. Multiple independent | ||
| runs (e.g. the tasks of a SLURM array) may point at the same | ||
| directory — each writes its own LMDB file, so there is no | ||
| shared-file contention. The actor lifecycle is owned by this | ||
| method; executors are not modified. | ||
|
|
||
| Returns: | ||
| list[Task] | None: List of tasks | ||
| """ | ||
| self.build() | ||
|
|
||
| if checkpoint_path is not None: | ||
| checkpoint_path = Path(checkpoint_path).absolute() | ||
| checkpoint_path.mkdir(parents=True, exist_ok=True) | ||
|
|
||
| if executor is None: | ||
| from nemo_curator.backends.xenna import XennaExecutor | ||
|
|
||
|
|
@@ -263,4 +281,46 @@ def run(self, executor: BaseExecutor | None = None, initial_tasks: list[Task] | | |
| if initial_tasks: | ||
| assign_root_task_ids(initial_tasks) | ||
|
|
||
| return executor.execute(self.stages, initial_tasks) | ||
| if checkpoint_path is None: | ||
| return executor.execute(self.stages, initial_tasks) | ||
| return self._run_with_resumability(executor, initial_tasks, checkpoint_path) | ||
|
|
||
| def _run_with_resumability( | ||
| self, | ||
| executor: BaseExecutor, | ||
| initial_tasks: list[Task] | None, | ||
| checkpoint_path: Path, | ||
| ) -> list[Task] | None: | ||
| """Owns the full resumability-actor lifecycle. Per-backend executors | ||
| are not modified — the actor is spawned ``lifetime="detached"`` so | ||
| it survives executor-local ``ray.shutdown()`` calls. | ||
|
|
||
| The actor never raises (see ``ResumabilityActor.apply_deltas``), so | ||
| there's no watchdog and no error propagation path here — just spawn, | ||
| run, close. | ||
| """ | ||
| import ray | ||
|
|
||
| from nemo_curator.utils.resumability_actor import ResumabilityActor | ||
| from nemo_curator.utils.resumability_client import ACTOR_NAME | ||
|
|
||
| ray.init(ignore_reinit_error=True) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this needed? Shouldn't there have already been a
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We do not enforce that right now, right? Like if the user forgets to include that, still our pipelines run. Ideally, I can add a check in the pipeline.run saying please either start a Ray client with RayClient.start() or SlurmClient with SLurmClient.start(). Would you prefer that? I would personally prefer that tbh.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah good point, it might be nice to enforce it in |
||
| ResumabilityActor.options( # type: ignore[attr-defined] | ||
| name=ACTOR_NAME, | ||
| lifetime="detached", | ||
| get_if_exists=True, | ||
| max_pending_calls=100, | ||
| ).remote(str(checkpoint_path)) | ||
|
|
||
| try: | ||
| return executor.execute(self.stages, initial_tasks) | ||
| finally: | ||
| # The executor's ray.shutdown() may have run in its own | ||
| # finally:; reconnect to clean up the detached actor. | ||
| try: | ||
| ray.init(ignore_reinit_error=True) | ||
| actor_handle = ray.get_actor(ACTOR_NAME) | ||
| ray.get(actor_handle.close.remote(), timeout=10) # type: ignore[attr-defined] | ||
| ray.kill(actor_handle) | ||
| except Exception as e: # noqa: BLE001 | ||
| logger.warning(f"resumability actor cleanup failed: {e}") | ||
|
Comment on lines
+317
to
+326
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
With The safest fix is to enqueue a no-op drain call after |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -46,13 +46,18 @@ class Task(ABC, Generic[T]): | |
| NON-deterministic (differ across runs). | ||
| dataset_name: Name of the dataset this task belongs to. | ||
| _stage_perf: List of stages perfs this task has passed through. | ||
| _source_id: Identifier of the source (input partition) this task | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit-ish but the docstring for
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will fix if there are other updates |
||
| descends from. Stamped at the source stage and inherited | ||
| downstream; used only by the (opt-in) resumability layer to | ||
| track which sources have completed. Empty for pre-source tasks. | ||
| """ | ||
|
|
||
| dataset_name: str | ||
| data: T | ||
| _stage_perf: list[StagePerfStats] = field(default_factory=list) | ||
| _metadata: dict[str, Any] = field(default_factory=dict) | ||
| task_id: str = field(init=False, default="") | ||
| _source_id: str = field(init=False, default="") | ||
|
|
||
| def __post_init__(self) -> None: | ||
| """Post-initialization hook.""" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
General question: why would the user ever want something other than source stage being the first stage and sink stage being the last stage of the pipeline? Like if the last stage failed but the second to last stage was the sink stage, they just don't want to rerun the last stage?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For sync stage not being the last stage, I think PDF pipelines are a good example. I have this metadata stage called stagePerfLogging at the end, which is needed because I cannot do stuff similar to benchmarking, since the pipeline.run never returns.
As for the source stage, I don't have an example in mind, but we don't need to force this assumption hence, I would prefer for us to keep it relaxed. From a user's perspective, if they don't specify, the default is that the first stage is source and the last stage is sync.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also for source we have a case where user might provide initial task and no source stage is defined.