-
Notifications
You must be signed in to change notification settings - Fork 287
Pipeline resumability via source-level counter checkpointing #2033
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,9 +17,13 @@ | |
| from dataclasses import dataclass | ||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| from loguru import logger | ||
|
|
||
| from nemo_curator.core.utils import ignore_ray_head_node | ||
| from nemo_curator.tasks import Task | ||
| from nemo_curator.tasks.sentinels import FailedTask, NoneTask | ||
| from nemo_curator.utils.performance_utils import StageTimer | ||
| from nemo_curator.utils.resumability_client import _flush_deltas, _is_active, _skip_completed_sources | ||
|
|
||
| if TYPE_CHECKING: | ||
| from nemo_curator.stages.base import ProcessingStage | ||
|
|
@@ -85,9 +89,23 @@ def process_batch(self, tasks: list[Task]) -> list[Task]: | |
| # Use the batch processing logic | ||
| results = self.stage.process_batch(tasks) | ||
|
|
||
| # A returned ``None`` ("filter this slot") becomes a NoneTask so every | ||
| # output is a real Task that gets a task_id. Sentinels (NoneTask / | ||
| # FailedTask) carry no identity and are stripped again before this | ||
| # method returns. | ||
| results = [NoneTask() if r is None else r for r in results] | ||
|
|
||
| # Guarantee every emitted task has a task_id (derived id, or uuid fallback). | ||
| results = self._post_process_task_ids(tasks, results) | ||
|
|
||
| # Opt-in resumability: fire per-source counter deltas. A no-op (the | ||
| # client helpers self-disable) when no resumability actor is registered. | ||
| if _is_active(): | ||
| results = self._apply_resumability_counters(tasks, results) | ||
|
|
||
| # Sentinels never propagate to the next stage. | ||
| results = [r for r in results if not isinstance(r, (NoneTask, FailedTask))] | ||
|
|
||
| # Log performance stats and add to result tasks | ||
| _, stage_perf_stats = self._timer.log_stats() | ||
| # Consume and attach any custom metrics recorded by the stage during this call | ||
|
|
@@ -168,6 +186,82 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas | |
| task.task_id = "r" + uuid.uuid4().hex | ||
| return out | ||
|
|
||
| # ------------------------------------------------------------------ # | ||
| # Resumability (opt-in). Runs only when a resumability actor is | ||
| # registered. task_ids are already assigned by _post_process_task_ids; | ||
| # this layer only stamps _source_id, fires per-source counter deltas, and | ||
| # drops already-completed sources. Sentinels are stripped by the caller. | ||
| # ------------------------------------------------------------------ # | ||
| def _apply_resumability_counters(self, input_tasks: list[Task], output_tasks: list[Task]) -> list[Task]: # noqa: C901 | ||
| stage = self.stage | ||
| if getattr(stage, "is_source_stage", False): | ||
| return self._source_counters(output_tasks) | ||
|
|
||
| # Pre-source stages: inputs carry no _source_id, so there's nothing to | ||
| # track yet. Leave outputs untouched. | ||
| if all(not t._source_id for t in input_tasks): | ||
| return output_tasks | ||
|
|
||
|
Comment on lines
+195
to
+204
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| is_sink = stage.is_sink_stage | ||
| per_task: list[tuple[str, str, int]] = [] | ||
| real = [t for t in output_tasks if not isinstance(t, (NoneTask, FailedTask))] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems like since |
||
|
|
||
| if len(input_tasks) == 1 and len(output_tasks) != 1: | ||
| # Genuine fan-out (1 -> N, N != 1): every real output descends from | ||
| # the single input. (The 1 -> 1 case falls through to the positional | ||
| # branch so a lone FailedTask is handled as "no delta".) | ||
| parent = input_tasks[0] | ||
| delta = -1 if is_sink else (len(real) - 1) | ||
| per_task.append((parent.task_id, parent._source_id, delta)) | ||
| for c in real: | ||
| if not c._source_id: | ||
| c._source_id = parent._source_id | ||
| elif len(output_tasks) == len(input_tasks): | ||
| # Positional 1:1, including filtered (NoneTask) / failed slots. | ||
| for parent, r in zip(input_tasks, output_tasks, strict=True): | ||
| sid = parent._source_id | ||
| if isinstance(r, FailedTask): | ||
| # No delta: the input stays pending so its source reruns. | ||
| continue | ||
| if isinstance(r, NoneTask): | ||
| per_task.append((parent.task_id, sid, -1)) | ||
| continue | ||
| per_task.append((parent.task_id, sid, -1 if is_sink else 0)) | ||
| if not r._source_id: | ||
| r._source_id = sid | ||
| else: | ||
| # M inputs -> K outputs (K != M): the parent of each output can't be | ||
| # determined, so the counter can't be updated correctly. Skip | ||
| # (the source counter stays pending -> reprocessed on resume). | ||
| logger.warning( | ||
| f"resumability: {type(stage).__name__} produced {len(output_tasks)} outputs " | ||
| f"for {len(input_tasks)} inputs; can't attribute sources, skipping counter " | ||
| f"update for this batch." | ||
| ) | ||
| return output_tasks | ||
|
|
||
| _flush_deltas(per_task) | ||
| return output_tasks | ||
|
|
||
| def _source_counters(self, output_tasks: list[Task]) -> list[Task]: | ||
| """Source stage: each output is a source partition. Its ``_source_id`` | ||
| is its own (last) id segment — the content id or index assigned by | ||
| ``_post_process_task_ids``. Already-completed sources are dropped; each | ||
| surviving source fires a ``+1``.""" | ||
| sources = [t for t in output_tasks if not isinstance(t, (NoneTask, FailedTask))] | ||
| for t in sources: | ||
| t._source_id = t.task_id.rsplit("_", 1)[-1] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am a little confused why the source ID here is the last part of the |
||
| completed = _skip_completed_sources([t._source_id for t in sources]) | ||
| per_task: list[tuple[str, str, int]] = [] | ||
| survivors: list[Task] = [] | ||
| for t in sources: | ||
| if t._source_id in completed: | ||
| continue | ||
| per_task.append((t.task_id, t._source_id, +1)) | ||
| survivors.append(t) | ||
| _flush_deltas(per_task) | ||
| return survivors | ||
|
|
||
| def setup_on_node(self, node_info: NodeInfo | None = None, worker_metadata: WorkerMetadata | None = None) -> None: | ||
| """Setup the stage on a node. | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |||||||||||||||||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| from pathlib import Path | ||||||||||||||||||||||||||||||||||||||
| from typing import Any | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| from loguru import logger | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -222,18 +223,35 @@ def describe(self) -> str: | |||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| return "\n".join(lines) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def run(self, executor: BaseExecutor | None = None, initial_tasks: list[Task] | None = None) -> list[Task] | None: | ||||||||||||||||||||||||||||||||||||||
| def run( | ||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||
| executor: BaseExecutor | None = None, | ||||||||||||||||||||||||||||||||||||||
| initial_tasks: list[Task] | None = None, | ||||||||||||||||||||||||||||||||||||||
| checkpoint_path: str | Path | None = None, | ||||||||||||||||||||||||||||||||||||||
| ) -> list[Task] | None: | ||||||||||||||||||||||||||||||||||||||
| """Run the pipeline. | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||
| executor (BaseExecutor): Executor to use | ||||||||||||||||||||||||||||||||||||||
| initial_tasks (list[Task], optional): Initial tasks to start the pipeline with. Defaults to None. | ||||||||||||||||||||||||||||||||||||||
| checkpoint_path (str | Path, optional): Directory used for | ||||||||||||||||||||||||||||||||||||||
| resumability. When set, completed source partitions are tracked | ||||||||||||||||||||||||||||||||||||||
| across runs and skipped on rerun; the tracking state lives in a | ||||||||||||||||||||||||||||||||||||||
| ``.nemo_curator_metadata`` subdirectory. Multiple independent | ||||||||||||||||||||||||||||||||||||||
| runs (e.g. the tasks of a SLURM array) may point at the same | ||||||||||||||||||||||||||||||||||||||
| directory — each writes its own LMDB file, so there is no | ||||||||||||||||||||||||||||||||||||||
| shared-file contention. The actor lifecycle is owned by this | ||||||||||||||||||||||||||||||||||||||
| method; executors are not modified. | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||
| list[Task] | None: List of tasks | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| self.build() | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if checkpoint_path is not None: | ||||||||||||||||||||||||||||||||||||||
| checkpoint_path = Path(checkpoint_path).absolute() | ||||||||||||||||||||||||||||||||||||||
| checkpoint_path.mkdir(parents=True, exist_ok=True) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if executor is None: | ||||||||||||||||||||||||||||||||||||||
| from nemo_curator.backends.xenna import XennaExecutor | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
@@ -263,4 +281,46 @@ def run(self, executor: BaseExecutor | None = None, initial_tasks: list[Task] | | |||||||||||||||||||||||||||||||||||||
| if initial_tasks: | ||||||||||||||||||||||||||||||||||||||
| assign_root_task_ids(initial_tasks) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| return executor.execute(self.stages, initial_tasks) | ||||||||||||||||||||||||||||||||||||||
| if checkpoint_path is None: | ||||||||||||||||||||||||||||||||||||||
| return executor.execute(self.stages, initial_tasks) | ||||||||||||||||||||||||||||||||||||||
| return self._run_with_resumability(executor, initial_tasks, checkpoint_path) | ||||||||||||||||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not like this, the reason this is being done is probably coz we call ray init inside the executor and not here, so we need to differentiate that. Not ideal for sure. Maybe we can make a function for _start_resumability_actor and call it inside each executor. Design choice. |
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _run_with_resumability( | ||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||
| executor: BaseExecutor, | ||||||||||||||||||||||||||||||||||||||
| initial_tasks: list[Task] | None, | ||||||||||||||||||||||||||||||||||||||
| checkpoint_path: Path, | ||||||||||||||||||||||||||||||||||||||
| ) -> list[Task] | None: | ||||||||||||||||||||||||||||||||||||||
| """Owns the full resumability-actor lifecycle. Per-backend executors | ||||||||||||||||||||||||||||||||||||||
| are not modified — the actor is spawned ``lifetime="detached"`` so | ||||||||||||||||||||||||||||||||||||||
| it survives executor-local ``ray.shutdown()`` calls. | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| The actor never raises (see ``ResumabilityActor.apply_deltas``), so | ||||||||||||||||||||||||||||||||||||||
| there's no watchdog and no error propagation path here — just spawn, | ||||||||||||||||||||||||||||||||||||||
| run, close. | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| import ray | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| from nemo_curator.utils.resumability_actor import ResumabilityActor | ||||||||||||||||||||||||||||||||||||||
| from nemo_curator.utils.resumability_client import ACTOR_NAME | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| ray.init(ignore_reinit_error=True) | ||||||||||||||||||||||||||||||||||||||
| ResumabilityActor.options( # type: ignore[attr-defined] | ||||||||||||||||||||||||||||||||||||||
| name=ACTOR_NAME, | ||||||||||||||||||||||||||||||||||||||
| lifetime="detached", | ||||||||||||||||||||||||||||||||||||||
| get_if_exists=True, | ||||||||||||||||||||||||||||||||||||||
| max_pending_calls=100, | ||||||||||||||||||||||||||||||||||||||
| ).remote(str(checkpoint_path)) | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+308
to
+313
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Comment on lines
+308
to
+313
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Capturing the handle and making at least one synchronous |
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||
| return executor.execute(self.stages, initial_tasks) | ||||||||||||||||||||||||||||||||||||||
| finally: | ||||||||||||||||||||||||||||||||||||||
| # The executor's ray.shutdown() may have run in its own | ||||||||||||||||||||||||||||||||||||||
| # finally:; reconnect to clean up the detached actor. | ||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||
| ray.init(ignore_reinit_error=True) | ||||||||||||||||||||||||||||||||||||||
| actor_handle = ray.get_actor(ACTOR_NAME) | ||||||||||||||||||||||||||||||||||||||
| ray.get(actor_handle.close.remote(), timeout=10) # type: ignore[attr-defined] | ||||||||||||||||||||||||||||||||||||||
| ray.kill(actor_handle) | ||||||||||||||||||||||||||||||||||||||
| except Exception as e: # noqa: BLE001 | ||||||||||||||||||||||||||||||||||||||
| logger.warning(f"resumability actor cleanup failed: {e}") | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+320
to
+326
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Separate the kill so it always runs regardless of whether
Suggested change
|
||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
General question: why would the user ever want something other than source stage being the first stage and sink stage being the last stage of the pipeline? Like if the last stage failed but the second to last stage was the sink stage, they just don't want to rerun the last stage?