diff --git a/nemo_curator/backends/base.py b/nemo_curator/backends/base.py index 94236f2cfc..0c4b98670c 100644 --- a/nemo_curator/backends/base.py +++ b/nemo_curator/backends/base.py @@ -17,9 +17,13 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any +from loguru import logger + from nemo_curator.core.utils import ignore_ray_head_node from nemo_curator.tasks import Task +from nemo_curator.tasks.sentinels import FailedTask, NoneTask from nemo_curator.utils.performance_utils import StageTimer +from nemo_curator.utils.resumability_client import _flush_deltas, _is_active, _skip_completed_sources if TYPE_CHECKING: from nemo_curator.stages.base import ProcessingStage @@ -85,9 +89,23 @@ def process_batch(self, tasks: list[Task]) -> list[Task]: # Use the batch processing logic results = self.stage.process_batch(tasks) + # A returned ``None`` ("filter this slot") becomes a NoneTask so every + # output is a real Task that gets a task_id. Sentinels (NoneTask / + # FailedTask) carry no identity and are stripped again before this + # method returns. + results = [NoneTask() if r is None else r for r in results] + # Guarantee every emitted task has a task_id (derived id, or uuid fallback). results = self._post_process_task_ids(tasks, results) + # Opt-in resumability: fire per-source counter deltas. A no-op (the + # client helpers self-disable) when no resumability actor is registered. + if _is_active(): + results = self._apply_resumability_counters(tasks, results) + + # Sentinels never propagate to the next stage. + results = [r for r in results if not isinstance(r, (NoneTask, FailedTask))] + # Log performance stats and add to result tasks _, stage_perf_stats = self._timer.log_stats() # Consume and attach any custom metrics recorded by the stage during this call @@ -168,6 +186,82 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas task.task_id = "r" + uuid.uuid4().hex return out + # ------------------------------------------------------------------ # + # Resumability (opt-in). Runs only when a resumability actor is + # registered. task_ids are already assigned by _post_process_task_ids; + # this layer only stamps _source_id, fires per-source counter deltas, and + # drops already-completed sources. Sentinels are stripped by the caller. + # ------------------------------------------------------------------ # + def _apply_resumability_counters(self, input_tasks: list[Task], output_tasks: list[Task]) -> list[Task]: # noqa: C901 + stage = self.stage + if getattr(stage, "is_source_stage", False): + return self._source_counters(output_tasks) + + # Pre-source stages: inputs carry no _source_id, so there's nothing to + # track yet. Leave outputs untouched. + if all(not t._source_id for t in input_tasks): + return output_tasks + + is_sink = stage.is_sink_stage + per_task: list[tuple[str, str, int]] = [] + real = [t for t in output_tasks if not isinstance(t, (NoneTask, FailedTask))] + + if len(input_tasks) == 1 and len(output_tasks) != 1: + # Genuine fan-out (1 -> N, N != 1): every real output descends from + # the single input. (The 1 -> 1 case falls through to the positional + # branch so a lone FailedTask is handled as "no delta".) + parent = input_tasks[0] + delta = -1 if is_sink else (len(real) - 1) + per_task.append((parent.task_id, parent._source_id, delta)) + for c in real: + if not c._source_id: + c._source_id = parent._source_id + elif len(output_tasks) == len(input_tasks): + # Positional 1:1, including filtered (NoneTask) / failed slots. + for parent, r in zip(input_tasks, output_tasks, strict=True): + sid = parent._source_id + if isinstance(r, FailedTask): + # No delta: the input stays pending so its source reruns. + continue + if isinstance(r, NoneTask): + per_task.append((parent.task_id, sid, -1)) + continue + per_task.append((parent.task_id, sid, -1 if is_sink else 0)) + if not r._source_id: + r._source_id = sid + else: + # M inputs -> K outputs (K != M): the parent of each output can't be + # determined, so the counter can't be updated correctly. Skip + # (the source counter stays pending -> reprocessed on resume). + logger.warning( + f"resumability: {type(stage).__name__} produced {len(output_tasks)} outputs " + f"for {len(input_tasks)} inputs; can't attribute sources, skipping counter " + f"update for this batch." + ) + return output_tasks + + _flush_deltas(per_task) + return output_tasks + + def _source_counters(self, output_tasks: list[Task]) -> list[Task]: + """Source stage: each output is a source partition. Its ``_source_id`` + is its own (last) id segment — the content id or index assigned by + ``_post_process_task_ids``. Already-completed sources are dropped; each + surviving source fires a ``+1``.""" + sources = [t for t in output_tasks if not isinstance(t, (NoneTask, FailedTask))] + for t in sources: + t._source_id = t.task_id.rsplit("_", 1)[-1] + completed = _skip_completed_sources([t._source_id for t in sources]) + per_task: list[tuple[str, str, int]] = [] + survivors: list[Task] = [] + for t in sources: + if t._source_id in completed: + continue + per_task.append((t.task_id, t._source_id, +1)) + survivors.append(t) + _flush_deltas(per_task) + return survivors + def setup_on_node(self, node_info: NodeInfo | None = None, worker_metadata: WorkerMetadata | None = None) -> None: """Setup the stage on a node. diff --git a/nemo_curator/pipeline/pipeline.py b/nemo_curator/pipeline/pipeline.py index 961ae33c6f..2b36033c37 100644 --- a/nemo_curator/pipeline/pipeline.py +++ b/nemo_curator/pipeline/pipeline.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path from typing import Any from loguru import logger @@ -222,18 +223,35 @@ def describe(self) -> str: return "\n".join(lines) - def run(self, executor: BaseExecutor | None = None, initial_tasks: list[Task] | None = None) -> list[Task] | None: + def run( + self, + executor: BaseExecutor | None = None, + initial_tasks: list[Task] | None = None, + checkpoint_path: str | Path | None = None, + ) -> list[Task] | None: """Run the pipeline. Args: executor (BaseExecutor): Executor to use initial_tasks (list[Task], optional): Initial tasks to start the pipeline with. Defaults to None. + checkpoint_path (str | Path, optional): Directory used for + resumability. When set, completed source partitions are tracked + across runs and skipped on rerun; the tracking state lives in a + ``.nemo_curator_metadata`` subdirectory. Multiple independent + runs (e.g. the tasks of a SLURM array) may point at the same + directory — each writes its own LMDB file, so there is no + shared-file contention. The actor lifecycle is owned by this + method; executors are not modified. Returns: list[Task] | None: List of tasks """ self.build() + if checkpoint_path is not None: + checkpoint_path = Path(checkpoint_path).absolute() + checkpoint_path.mkdir(parents=True, exist_ok=True) + if executor is None: from nemo_curator.backends.xenna import XennaExecutor @@ -263,4 +281,46 @@ def run(self, executor: BaseExecutor | None = None, initial_tasks: list[Task] | if initial_tasks: assign_root_task_ids(initial_tasks) - return executor.execute(self.stages, initial_tasks) + if checkpoint_path is None: + return executor.execute(self.stages, initial_tasks) + return self._run_with_resumability(executor, initial_tasks, checkpoint_path) + + def _run_with_resumability( + self, + executor: BaseExecutor, + initial_tasks: list[Task] | None, + checkpoint_path: Path, + ) -> list[Task] | None: + """Owns the full resumability-actor lifecycle. Per-backend executors + are not modified — the actor is spawned ``lifetime="detached"`` so + it survives executor-local ``ray.shutdown()`` calls. + + The actor never raises (see ``ResumabilityActor.apply_deltas``), so + there's no watchdog and no error propagation path here — just spawn, + run, close. + """ + import ray + + from nemo_curator.utils.resumability_actor import ResumabilityActor + from nemo_curator.utils.resumability_client import ACTOR_NAME + + ray.init(ignore_reinit_error=True) + ResumabilityActor.options( # type: ignore[attr-defined] + name=ACTOR_NAME, + lifetime="detached", + get_if_exists=True, + max_pending_calls=100, + ).remote(str(checkpoint_path)) + + try: + return executor.execute(self.stages, initial_tasks) + finally: + # The executor's ray.shutdown() may have run in its own + # finally:; reconnect to clean up the detached actor. + try: + ray.init(ignore_reinit_error=True) + actor_handle = ray.get_actor(ACTOR_NAME) + ray.get(actor_handle.close.remote(), timeout=10) # type: ignore[attr-defined] + ray.kill(actor_handle) + except Exception as e: # noqa: BLE001 + logger.warning(f"resumability actor cleanup failed: {e}") diff --git a/nemo_curator/tasks/__init__.py b/nemo_curator/tasks/__init__.py index bf8d32a150..d45f3b778d 100644 --- a/nemo_curator/tasks/__init__.py +++ b/nemo_curator/tasks/__init__.py @@ -17,17 +17,19 @@ from .file_group import FileGroupTask from .image import ImageBatch, ImageObject from .interleaved import InterleavedBatch -from .sentinels import EmptyTask, SentinelTask +from .sentinels import EmptyTask, FailedTask, NoneTask, SentinelTask from .tasks import Task __all__ = [ "AudioTask", "DocumentBatch", "EmptyTask", + "FailedTask", "FileGroupTask", "ImageBatch", "ImageObject", "InterleavedBatch", + "NoneTask", "SentinelTask", "Task", ] diff --git a/nemo_curator/tasks/sentinels.py b/nemo_curator/tasks/sentinels.py index 84896dd963..ed1ab8b572 100644 --- a/nemo_curator/tasks/sentinels.py +++ b/nemo_curator/tasks/sentinels.py @@ -13,9 +13,18 @@ # limitations under the License. """Payload-less marker tasks. -``EmptyTask`` seeds a pipeline (the implicit root id ``"0"``). All markers -share the :class:`SentinelTask` base and carry no payload (``data is None``). -Construct one with ``EmptyTask()``. +``EmptyTask`` seeds a pipeline (the implicit root id ``"0"``). The resumability +layer adds two more markers on the same :class:`SentinelTask` base: + +- ``NoneTask`` — this slot was intentionally filtered. The resumability counter + treats it as a consumed branch (decrements). The adapter auto-wraps a + returned ``None`` as a ``NoneTask``. +- ``FailedTask`` — this slot failed and should be retried on resume. The counter + is NOT decremented, so its source stays pending and reruns. + +All carry no payload (``data is None``) and get their ``task_id`` assigned by +the executor adapter; sentinels are stripped before the next stage. Construct +with ``EmptyTask()`` / ``NoneTask()`` / ``FailedTask()``. """ from dataclasses import dataclass, field @@ -52,3 +61,17 @@ class EmptyTask(SentinelTask): dataset_name: str = "empty" task_id: str = field(init=False, default="0") + + +@dataclass +class NoneTask(SentinelTask): + """Marks a slot as intentionally filtered (resumability counter decrements).""" + + dataset_name: str = "none" + + +@dataclass +class FailedTask(SentinelTask): + """Marks a slot as failed → retried on resume (counter does NOT decrement).""" + + dataset_name: str = "failed" diff --git a/nemo_curator/tasks/tasks.py b/nemo_curator/tasks/tasks.py index 04bfb5caf0..fe2a2fbe70 100644 --- a/nemo_curator/tasks/tasks.py +++ b/nemo_curator/tasks/tasks.py @@ -46,6 +46,10 @@ class Task(ABC, Generic[T]): NON-deterministic (differ across runs). dataset_name: Name of the dataset this task belongs to. _stage_perf: List of stages perfs this task has passed through. + _source_id: Identifier of the source (input partition) this task + descends from. Stamped at the source stage and inherited + downstream; used only by the (opt-in) resumability layer to + track which sources have completed. Empty for pre-source tasks. """ dataset_name: str @@ -53,6 +57,7 @@ class Task(ABC, Generic[T]): _stage_perf: list[StagePerfStats] = field(default_factory=list) _metadata: dict[str, Any] = field(default_factory=dict) task_id: str = field(init=False, default="") + _source_id: str = field(init=False, default="") def __post_init__(self) -> None: """Post-initialization hook.""" diff --git a/nemo_curator/utils/resumability_actor.py b/nemo_curator/utils/resumability_actor.py new file mode 100644 index 0000000000..01d7523a20 --- /dev/null +++ b/nemo_curator/utils/resumability_actor.py @@ -0,0 +1,261 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Per-writer LMDB owner that tracks per-source pending counters for +resumability. + +Resumability state lives in a shared ``.nemo_curator_metadata`` directory so +that independent runs pointed at the same checkpoint location can record +completed sources without corrupting a shared file. The motivating case is a +SLURM array: each array task is its own job → its own Ray cluster → its own +actor, often on a different node. + +LMDB cannot be safely written by multiple processes on different hosts against +a single file — its inter-process lock lives in a memory-mapped lock file that +is not shared across nodes on a networked filesystem (e.g. Lustre). So instead +of one shared file: + +- each actor WRITES ONLY to its own file ``/-.mdb`` (it is the + sole writer of that file, so no cross-process locking is needed), and +- on startup it READS THE UNION of completed sources across every ``*.mdb`` + file in the directory. + +A rerun therefore sees everything every prior writer finished and skips it. +This assumes each writer owns a disjoint set of sources (the usual +shard-per-task model); two writers completing the *same* source id is harmless +(idempotent — the source is simply marked done). + +Workers fire ``apply_deltas`` fire-and-forget. The actor: + +- Maintains ``_pending: dict[source_id, int]`` (counter per in-flight source) +- Maintains ``_applied: dict[task_hash, delta]`` (Ray-retry dedup; on a retry + firing a *different* delta, rewrites the pending counter to reflect the + newest observation rather than raising) +- Writes a single LMDB row to its own file per source when its counter hits zero + +By design, ``apply_deltas`` **never raises**. The two anomaly cases we detect — +same task hash producing a different delta on retry, and a delta arriving for a +source that is already completed — are handled in-place: the first by +rewriting, the second by logging a warning and skipping. Resumability is never +the cause of a pipeline failure. +""" + +from __future__ import annotations + +import os +import socket +from pathlib import Path +from typing import TYPE_CHECKING + +import lmdb +import ray +from loguru import logger + +if TYPE_CHECKING: + from collections.abc import Iterable + + +_COMPLETED_DB = b"completed_sources" +_DEFAULT_MAP_SIZE = 1 << 30 # 1 GiB; sparse on Linux so effectively free +# Subdirectory (under the user-provided checkpoint dir) that holds the +# per-writer LMDB files. Hidden so it sits unobtrusively next to outputs. +METADATA_DIRNAME = ".nemo_curator_metadata" + + +@ray.remote(num_cpus=0, max_concurrency=1) +class ResumabilityActor: + """Per-writer counter + LMDB owner. + + Workers fire ``apply_deltas`` via ``.remote()`` and do NOT ``ray.get`` the + returned ref. The actor never raises; anomalies are logged and handled + inline (see ``apply_deltas`` docstring). + + Spawned by ``Pipeline.run`` with ``lifetime="detached"`` so it survives + executor-local ``ray.shutdown`` calls. ``Pipeline.run`` closes it + explicitly at end-of-run. + + Writes go only to this actor's own file in the shared metadata directory; + reads (``are_completed``) reflect the union of all writers' files as of + this actor's startup, plus this actor's own progress. + """ + + def __init__(self, base_dir: str, map_size: int = _DEFAULT_MAP_SIZE, writer_id: str | None = None): + # base_dir is the user-provided checkpoint directory; the per-writer + # LMDB files live in /.nemo_curator_metadata/. + self._dir = Path(base_dir).absolute() / METADATA_DIRNAME + self._dir.mkdir(parents=True, exist_ok=True) + # This actor's own file — the ONLY file it ever writes to. Keyed by + # this writer's identity (default host+pid, unique among concurrently + # running writers: distinct hosts, or distinct pids on one host). A + # rerun whose pid recycles simply reopens and appends to its old file, + # which is safe (sequential in time). + wid = writer_id or f"{socket.gethostname()}-{os.getpid()}" + self._path = str(self._dir / f"{wid}.mdb") + self._env = lmdb.open( + self._path, + subdir=False, + lock=False, # sole writer of this file → no inter-process lock needed + max_dbs=1, + map_size=map_size, + metasync=False, + sync=True, + readahead=False, + ) + self._db = self._env.open_db(_COMPLETED_DB) + self._pending: dict[str, int] = {} + # Union of completed sources across ALL writer files in the dir. + self._completed: set[str] = self._load_completed() + # task_hash -> delta we previously applied for this task. + # Dual-purpose: + # 1. Dedup: a Ray retry firing the same delta is a silent skip. + # 2. Rewrite-on-conflict: a retry firing a *different* delta + # replaces the old delta. The pending counter is adjusted by + # `(-old_delta + new_delta)` so the latest observation wins. + self._applied: dict[str, int] = {} + + def _read_completed_from(self, env: lmdb.Environment) -> set[str]: + """Read the completed-source ids from an already-open LMDB env. Returns + an empty set if this env has no completed-sources sub-db yet (a writer + that has only recorded in-flight, not-yet-finished sources).""" + try: + db = env.open_db(_COMPLETED_DB) + except lmdb.Error: + return set() + with env.begin() as txn, txn.cursor(db=db) as cur: + return {k.decode() for k, _ in cur} + + def _load_completed(self) -> set[str]: + """Union of completed source ids across every writer's LMDB file in the + metadata dir (other writers' files read read-only; our own via the open + write handle). A file that cannot be opened — e.g. it is mid-write by a + live writer, or already open in THIS process during tests — is skipped + with a warning rather than failing the run.""" + done = self._read_completed_from(self._env) # our own (possibly reused) file + for mdb in sorted(self._dir.glob("*.mdb")): + if str(mdb) == self._path: + continue + try: + env = lmdb.open(str(mdb), subdir=False, readonly=True, lock=False, max_dbs=1) + except lmdb.Error as e: + logger.warning(f"resumability: skipping unreadable checkpoint {mdb}: {e}") + continue + try: + done |= self._read_completed_from(env) + finally: + env.close() + return done + + # ------------------------------------------------------------ read + + def are_completed(self, source_ids: list[str]) -> list[bool]: + """Returns a parallel bool list indicating which source_ids are + already marked complete (and thus should be skipped on rerun). Reflects + the union of all writers' files as of startup, plus this actor's own + completions since.""" + return [sid in self._completed for sid in source_ids] + + # ------------------------------------------------------------ write + + def apply_deltas(self, per_task: list[tuple[str, str, int]]) -> None: + """Apply per-task counter deltas. Workers call this fire-and-forget + (no ``ray.get`` on the returned ref). + + Each tuple is ``(task_hash, source_id, delta)``. Behavior: + + - ``task_hash`` already in ``_applied`` with the same delta → + silent skip (Ray retry idempotency). + - ``task_hash`` already in ``_applied`` with a different delta and + ``source_id`` NOT in ``_completed`` → rewrite: adjust + ``_pending[source_id]`` by ``(-old + new)`` and update the + recorded delta. Reflects the latest observation. + - ``task_hash`` not in ``_applied`` and ``source_id`` already in + ``_completed`` → log a loud warning **and remove the source + from the completed set (in-memory + LMDB)** so it will be + reprocessed on the next run. Same treatment when a different + delta arrives for a task whose source is already completed. + These two cases indicate a bug — the cleanest recovery is to + un-complete the source rather than silently drop the update. + - Otherwise → normal apply. + + Never raises. + """ + newly_done: list[str] = [] + for task_hash, sid, d in per_task: + existing = self._applied.get(task_hash) + if existing is not None: + if existing == d: + continue # idempotent re-fire + if sid in self._completed: + # Source already finalized but we're getting a different + # delta for one of its tasks — the source wasn't actually + # done. Un-complete it so it reruns next launch. + logger.warning( + f"resumability: task {task_hash} delta changed from " + f"{existing} to {d} but source {sid!r} is already " + f"completed. Removing {sid!r} from the completed set " + f"so it will be reprocessed on the next run. Please " + f"file an issue at " + f"https://github.com/NVIDIA-NeMo/Curator if this is " + f"unexpected." + ) + self._remove_from_completed(sid) + continue + # Rewrite-on-conflict: the newest delta wins. + self._applied[task_hash] = d + self._pending[sid] = self._pending.get(sid, 0) + (-existing) + d + else: + # New task hash. + if sid in self._completed: + logger.warning( + f"resumability: source {sid!r} got update for new " # noqa: S608 + f"task {task_hash} (delta={d}) after being completed. " + f"Removing {sid!r} from the completed set so it will " + f"be reprocessed on the next run. Please file an " + f"issue at https://github.com/NVIDIA-NeMo/Curator." + ) + self._remove_from_completed(sid) + continue + self._applied[task_hash] = d + self._pending[sid] = self._pending.get(sid, 0) + d + if self._pending[sid] == 0: + newly_done.append(sid) + if newly_done: + self._persist_completed(newly_done) + for sid in newly_done: + self._completed.add(sid) + self._pending.pop(sid, None) + + def _persist_completed(self, sids: Iterable[str]) -> None: + with self._env.begin(write=True) as txn: + for sid in sids: + txn.put(sid.encode(), b"1", db=self._db, overwrite=True) + + def _remove_from_completed(self, sid: str) -> None: + """Remove ``sid`` from the in-memory completed set and from our own + LMDB file. Used when we detect that a source was prematurely marked + complete (a late delta arrives after completion); the safest recovery + is to un-complete so it reruns on next launch. Note: if ``sid`` was + completed by a *different* writer's file we can't delete it there (we + only ever write our own file), so it may reappear from the union on the + next startup — acceptable for this rare anomaly path.""" + self._completed.discard(sid) + with self._env.begin(write=True) as txn: + txn.delete(sid.encode(), db=self._db) + + def close(self) -> None: + if self._env is not None: + try: + self._env.close() + except Exception as e: # noqa: BLE001 + logger.warning(f"failed to close LMDB env: {e}") + self._env = None # type: ignore[assignment] diff --git a/nemo_curator/utils/resumability_client.py b/nemo_curator/utils/resumability_client.py new file mode 100644 index 0000000000..5b5940338e --- /dev/null +++ b/nemo_curator/utils/resumability_client.py @@ -0,0 +1,68 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module-level helpers that workers use to talk to the resumability +actor. All helpers are no-ops when the actor isn't registered, so +unchecked pipelines pay nothing. +""" + +from __future__ import annotations + +import ray + +# Name of the detached resumability actor. Defined here — NOT imported from +# resumability_actor — so the always-imported worker path +# (BaseStageAdapter -> this module) never pulls in resumability_actor, which +# imports lmdb. lmdb is only needed once resumability is actually used (the +# actor process and Pipeline._run_with_resumability). +ACTOR_NAME = "nemo_curator_resumability" + + +def _actor() -> ray.actor.ActorHandle | None: + """Return the resumability actor handle, or None if Ray is not + initialized or no such actor is registered.""" + if not ray.is_initialized(): + return None + try: + return ray.get_actor(ACTOR_NAME) + except ValueError: + return None + + +def _is_active() -> bool: + """True if a resumability actor is registered in this Ray cluster.""" + return _actor() is not None + + +def _flush_deltas(per_task: list[tuple[str, str, int]]) -> None: + """Fire-and-forget per-task counter deltas to the actor. + + Each entry is ``(task_hash, source_id, delta)``. Workers do NOT + ``ray.get`` the returned ref — errors surface to the executor via the + actor's watchdog poll, not synchronously on this call. Backpressure + is handled by Ray's ``_max_pending_calls`` cap on the actor. + """ + a = _actor() + if a is not None and per_task: + a.apply_deltas.remote(per_task) # type: ignore[attr-defined] + + +def _skip_completed_sources(source_ids: list[str]) -> set[str]: + """Synchronous lookup of which source_ids are already marked complete + in LMDB. Used by the source-stage adapter to drop already-done + sources from its output before downstream stages see them.""" + a = _actor() + if a is None or not source_ids: + return set() + flags = ray.get(a.are_completed.remote(source_ids)) # type: ignore[attr-defined] + return {sid for sid, done in zip(source_ids, flags, strict=True) if done} diff --git a/pyproject.toml b/pyproject.toml index bd10a5337b..5e823c8f43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ dependencies = [ "fsspec", "hydra-core", "jieba==0.42.1", + "lmdb>=1.4", "loguru", "mecab-python3", "omegaconf", diff --git a/tests/backends/test_resumability_adapter.py b/tests/backends/test_resumability_adapter.py new file mode 100644 index 0000000000..2c87987b13 --- /dev/null +++ b/tests/backends/test_resumability_adapter.py @@ -0,0 +1,204 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the resumability counter step in ``BaseStageAdapter``. + +task_id assignment is PR-A's job (``_post_process_task_ids``) and is covered +in ``tests/backends/test_task_id_postprocess.py``. Here we test only the +opt-in counter layer (``_apply_resumability_counters``) and the +``None``->``NoneTask`` normalization, with the actor RPCs mocked out. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from unittest.mock import patch + +from nemo_curator.backends.base import BaseStageAdapter +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.tasks import FailedTask, NoneTask, Task + + +@dataclass +class _NoopStage(ProcessingStage[Task, Task]): + name: str = "noop" + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def process(self, task: Task) -> Task: + return task + + +@dataclass +class _DropStage(ProcessingStage[Task, Task]): + """A non-source stage that filters every input (returns ``None``).""" + + name: str = "drop" + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def process(self, task: Task) -> None: + return None + + +@dataclass +class _SimpleTask(Task[list[int]]): + @property + def num_items(self) -> int: + return 0 + + def validate(self) -> bool: + return True + + +def _task(task_id: str = "", source_id: str = "") -> _SimpleTask: + t = _SimpleTask(dataset_name="d", data=[]) + t.task_id = task_id # pretend _post_process_task_ids already ran + t._source_id = source_id + return t + + +def _counters( + stage: ProcessingStage, + input_tasks: list[Task], + output_tasks: list[Any], + *, + completed: set[str] | None = None, +) -> tuple[list[Task], list[tuple[str, str, int]]]: + """Run ``_apply_resumability_counters`` with the actor RPCs patched. + Returns ``(surviving_outputs, captured_deltas)``.""" + captured: list[tuple[str, str, int]] = [] + + with ( + patch("nemo_curator.backends.base._flush_deltas", side_effect=captured.extend), + patch("nemo_curator.backends.base._skip_completed_sources", return_value=completed or set()), + ): + out = BaseStageAdapter(stage)._apply_resumability_counters(input_tasks, output_tasks) + return out, captured + + +def _process( + stage: ProcessingStage, + tasks: list[Task], + *, + completed: set[str] | None = None, +) -> tuple[list[Task], list[tuple[str, str, int]]]: + """Run the full ``process_batch`` with the resumability actor patched + active. Returns ``(surviving_outputs, captured_deltas)``.""" + captured: list[tuple[str, str, int]] = [] + with ( + patch("nemo_curator.backends.base._is_active", return_value=True), + patch("nemo_curator.backends.base._flush_deltas", side_effect=captured.extend), + patch("nemo_curator.backends.base._skip_completed_sources", return_value=completed or set()), + ): + out = BaseStageAdapter(stage).process_batch(tasks) + return out, captured + + +class TestNoneNormalization: + """A returned ``None`` is normalized to a ``NoneTask`` inside + ``process_batch``: it decrements its slot's source counter and is then + stripped, so it never reaches the next stage.""" + + def test_returned_none_decrements_and_is_stripped(self) -> None: + parent = _task("s_0", source_id="s") + out, captured = _process(_DropStage(), [parent]) + assert out == [] # the NoneTask sentinel is stripped from the output + assert captured == [("s_0", "s", -1)] # the filtered slot is consumed + + +class TestSourceStage: + def _src_stage(self) -> _NoopStage: + s = _NoopStage() + s.is_source_stage = True + return s + + def test_stamps_source_id_and_fires_plus_one(self) -> None: + empty = _task("0") # EmptyTask-like root + a, b = _task("0_aaa"), _task("0_bbb") + out, captured = _counters(self._src_stage(), [empty], [a, b]) + assert out == [a, b] + # _source_id is the task_id's last segment (its content id / index). + assert a._source_id == "aaa" + assert b._source_id == "bbb" + assert sorted(captured) == [("0_aaa", "aaa", 1), ("0_bbb", "bbb", 1)] + + def test_drops_already_completed_sources(self) -> None: + empty = _task("0") + a, b, c = _task("0_a"), _task("0_b"), _task("0_c") + out, captured = _counters(self._src_stage(), [empty], [a, b, c], completed={"b"}) + assert out == [a, c] + assert {sid for _, sid, _ in captured} == {"a", "c"} + + +class TestNonSourceStage: + def test_pre_source_is_noop(self) -> None: + # Inputs carry no _source_id yet -> nothing tracked, outputs untouched. + a = _task("0_0") + out, captured = _counters(_NoopStage(), [a], [a]) + assert out == [a] + assert captured == [] + + def test_one_to_one_nonsink_zero_delta(self) -> None: + stage = _NoopStage() + stage.is_sink_stage = False + parent = _task("s_0", source_id="s") + child = _task("s_0_0") + _out, captured = _counters(stage, [parent], [child]) + assert captured == [("s_0", "s", 0)] + assert child._source_id == "s" # inherited + + def test_one_to_one_sink_minus_one(self) -> None: + stage = _NoopStage() + stage.is_sink_stage = True + parent = _task("s_0", source_id="s") + _out, captured = _counters(stage, [parent], [_task("s_0_0")]) + assert captured == [("s_0", "s", -1)] + + def test_nonetask_slot_decrements(self) -> None: + # The counter keys on the PARENT's identity, so the NoneTask itself is + # bare (no id / source_id) — it just marks "this slot was filtered". + parent = _task("s_0", source_id="s") + _out, captured = _counters(_NoopStage(), [parent], [NoneTask()]) + assert captured == [("s_0", "s", -1)] + + def test_failedtask_slot_no_delta(self) -> None: + parent = _task("s_0", source_id="s") + _out, captured = _counters(_NoopStage(), [parent], [FailedTask()]) + assert captured == [] + + def test_fanout_grows_counter(self) -> None: + stage = _NoopStage() + stage.is_sink_stage = False + parent = _task("s_0", source_id="s") + c0, c1, c2 = _task("s_0_0"), _task("s_0_1"), _task("s_0_2") + _out, captured = _counters(stage, [parent], [c0, c1, c2]) + # 1 input -> 3 children: net +2 for the source. + assert captured == [("s_0", "s", 2)] + assert all(c._source_id == "s" for c in (c0, c1, c2)) + + def test_ambiguous_batch_skips_counters(self) -> None: + # 2 inputs -> 3 outputs: can't attribute, so no deltas are fired. + p0, p1 = _task("s_0", source_id="s"), _task("s_1", source_id="s") + out, captured = _counters(_NoopStage(), [p0, p1], [_task(), _task(), _task()]) + assert captured == [] + assert len(out) == 3 diff --git a/tests/tasks/test_sentinels.py b/tests/tasks/test_sentinels.py new file mode 100644 index 0000000000..07b9bc785b --- /dev/null +++ b/tests/tasks/test_sentinels.py @@ -0,0 +1,77 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the payload-less sentinel tasks. + +``SentinelTask`` is the shared base; ``EmptyTask`` seeds a pipeline while +``NoneTask`` / ``FailedTask`` are the resumability markers. They carry no +payload and their ``task_id`` is framework-assigned (``EmptyTask`` is fixed +to the root ``"0"``; the others default empty until the adapter sets them). +""" + +from __future__ import annotations + +import pytest + +from nemo_curator.tasks import EmptyTask, FailedTask, NoneTask, SentinelTask, Task + + +class TestSentinelBase: + def test_subclasses_are_tasks(self) -> None: + for obj in (SentinelTask(dataset_name="s"), EmptyTask(), NoneTask(), FailedTask()): + assert isinstance(obj, Task) + assert isinstance(obj, SentinelTask) + + def test_carry_no_data(self) -> None: + for obj in (SentinelTask(dataset_name="s"), EmptyTask(), NoneTask(), FailedTask()): + assert obj.data is None + + def test_num_items_is_zero(self) -> None: + for obj in (SentinelTask(dataset_name="s"), EmptyTask(), NoneTask(), FailedTask()): + assert obj.num_items == 0 + + def test_validate_is_true(self) -> None: + for obj in (SentinelTask(dataset_name="s"), EmptyTask(), NoneTask(), FailedTask()): + assert obj.validate() is True + + def test_rejects_payload(self) -> None: + # The base asserts ``data is None`` so a sentinel can never carry data. + with pytest.raises(AssertionError): + SentinelTask(dataset_name="s", data="oops") + + +class TestEmptyTask: + def test_is_rooted_at_zero(self) -> None: + # EmptyTask is the implicit root every task descends from. + assert EmptyTask().task_id == "0" + assert EmptyTask().dataset_name == "empty" + + def test_task_id_is_not_user_settable(self) -> None: + # ``task_id`` is init=False, so it cannot be passed positionally/kw. + with pytest.raises(TypeError): + EmptyTask(task_id="5") # type: ignore[call-arg] + + +class TestResumabilityMarkers: + def test_dataset_names(self) -> None: + assert NoneTask().dataset_name == "none" + assert FailedTask().dataset_name == "failed" + + def test_task_id_unset_until_assigned(self) -> None: + # Unlike EmptyTask, these get their id from the adapter; default empty. + assert NoneTask().task_id == "" + assert FailedTask().task_id == "" + + def test_none_and_failed_are_distinct(self) -> None: + assert not isinstance(NoneTask(), FailedTask) + assert not isinstance(FailedTask(), NoneTask) diff --git a/tests/utils/test_resumability_actor.py b/tests/utils/test_resumability_actor.py new file mode 100644 index 0000000000..b73281ccb4 --- /dev/null +++ b/tests/utils/test_resumability_actor.py @@ -0,0 +1,290 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for :class:`ResumabilityActor` (counter math, dedup, +rewrite-on-conflict, LMDB persistence). + +These tests instantiate the actor class directly (without going through +``@ray.remote``) so they're fast and don't require a live Ray cluster. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import patch + +from nemo_curator.utils.resumability_actor import ResumabilityActor + +if TYPE_CHECKING: + from pathlib import Path + + +def _new_actor(tmp_path: Path, writer_id: str | None = None) -> ResumabilityActor: + """Bypass ``@ray.remote`` and instantiate the actor class directly. + + Ray's ``@ray.remote`` decorator stashes the original class on + ``__ray_metadata__.modified_class``. ``tmp_path`` is the checkpoint + directory; the actor keeps its LMDB file under + ``tmp_path/.nemo_curator_metadata/``. ``writer_id`` distinguishes writers + sharing that directory (defaults to host+pid in production); pass distinct + ids to simulate concurrent runs / SLURM-array tasks. + """ + cls = ResumabilityActor.__ray_metadata__.modified_class # type: ignore[attr-defined] + return cls(str(tmp_path), writer_id=writer_id) + + +class TestApplyDeltasCounterMath: + def test_source_emit_increments_pending(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1), ("h1", "1", +1)]) + assert actor._pending == {"0": 1, "1": 1} + assert actor._completed == set() + actor.close() + + def test_counter_reaches_zero_persists_to_lmdb(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + actor.apply_deltas([("h_sink", "0", -1)]) + assert actor._completed == {"0"} + assert "0" not in actor._pending + actor.close() + + # Reopen the actor and confirm "0" survives in LMDB. + actor2 = _new_actor(tmp_path) + assert actor2._completed == {"0"} + actor2.close() + + def test_nonsink_real_task_is_zero_delta(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + actor.apply_deltas([("h_passthrough", "0", 0)]) + assert actor._pending == {"0": 1} + assert actor._completed == set() + actor.close() + + def test_nonetask_decrements(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + actor.apply_deltas([("h_filter", "0", -1)]) + assert actor._completed == {"0"} + actor.close() + + def test_fanout_grows_counter(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + # Fan-out 1→3 emits delta = (3-1) = +2 on the parent's source. + actor.apply_deltas([("h_fanout", "0", +2)]) + assert actor._pending == {"0": 3} + actor.close() + + +class TestDedupAndRewrite: + def test_same_task_same_delta_is_idempotent(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + actor.apply_deltas([("h_t", "0", -1)]) + # Second identical fire — should be a no-op (Ray retry idempotency). + actor.apply_deltas([("h_t", "0", -1)]) + assert actor._completed == {"0"} + actor.close() + + def test_same_task_different_delta_rewrites(self, tmp_path: Path) -> None: + """When a Ray retry fires a different delta for the same task hash, + the actor adjusts pending by (-old + new) so the latest observation + wins. Never raises.""" + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + # First the worker says delta=0 (real Task passed through). + actor.apply_deltas([("h_t", "0", 0)]) + assert actor._pending == {"0": 1} + + # Retry says delta=-1 (NoneTask this time). Rewrite: pending += -0 + -1. + actor.apply_deltas([("h_t", "0", -1)]) + assert actor._completed == {"0"} + # And the recorded delta is updated to the new value. + assert actor._applied["h_t"] == -1 + actor.close() + + def test_rewrite_does_not_raise(self, tmp_path: Path) -> None: + """apply_deltas never raises; rewrite is silent.""" + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + # Multiple conflicting deltas for the same task: should not raise. + actor.apply_deltas([("h_t", "0", 0)]) + actor.apply_deltas([("h_t", "0", +5)]) + actor.apply_deltas([("h_t", "0", -1)]) + # Final state reflects the last delta. + assert actor._applied["h_t"] == -1 + actor.close() + + +class TestUncompleteOnAnomaly: + def test_new_task_after_source_completed_warns_and_uncompletes(self, tmp_path: Path) -> None: + """If a delta arrives for a never-seen task on an already-completed + source, the source wasn't actually done. Un-complete it (in-memory + and in LMDB) so it reruns next launch.""" + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1), ("h_t", "0", -1)]) + assert actor._completed == {"0"} + + with patch("nemo_curator.utils.resumability_actor.logger") as mock_logger: + actor.apply_deltas([("h_late", "0", -1)]) + mock_logger.warning.assert_called_once() + warn_msg = mock_logger.warning.call_args[0][0] + assert "Removing" in warn_msg + assert "completed set" in warn_msg + + # Source has been removed from the in-memory completed set. + assert "0" not in actor._completed + # And from LMDB — reopen and confirm. + actor.close() + actor2 = _new_actor(tmp_path) + assert "0" not in actor2._completed + actor2.close() + + def test_rewrite_attempt_after_source_completed_warns_and_uncompletes(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + actor.apply_deltas([("h_t", "0", -1)]) + assert actor._completed == {"0"} + + # Same task tries to rewrite to a different delta after completion. + with patch("nemo_curator.utils.resumability_actor.logger") as mock_logger: + actor.apply_deltas([("h_t", "0", 0)]) + mock_logger.warning.assert_called_once() + warn_msg = mock_logger.warning.call_args[0][0] + assert "Removing" in warn_msg + + # Source has been uncompleted. + assert "0" not in actor._completed + actor.close() + actor2 = _new_actor(tmp_path) + assert "0" not in actor2._completed + actor2.close() + + def test_apply_deltas_never_raises(self, tmp_path: Path) -> None: + """The whole point of removing the error machinery — no path through + apply_deltas should raise.""" + actor = _new_actor(tmp_path) + # Throw lots of weird stuff at it. + actor.apply_deltas([("h0", "0", +1)]) + actor.apply_deltas([("h_t", "0", -1)]) # completes 0 + actor.apply_deltas([("h_t", "0", +5)]) # rewrite on completed source: warn + uncomplete + actor.apply_deltas([("h_new", "0", -1)]) # new hash, source no longer in completed + actor.apply_deltas([("h1", "1", +1), ("h_t1", "1", -5)]) # negative pending + # Reached this line without raising — pass. + actor.close() + + +class TestAreCompleted: + def test_returns_parallel_bool_list(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1), ("h1", "1", +1)]) + actor.apply_deltas([("h_t", "0", -1)]) + assert actor.are_completed(["0", "1", "unknown"]) == [True, False, False] + actor.close() + + def test_loads_from_lmdb_on_construction(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.apply_deltas([("h_a", "a", +1), ("h_t", "a", -1)]) + assert actor._completed == {"a"} + actor.close() + + actor2 = _new_actor(tmp_path) + assert actor2.are_completed(["a", "b"]) == [True, False] + actor2.close() + + +class TestLifecycle: + def test_close_is_idempotent(self, tmp_path: Path) -> None: + actor = _new_actor(tmp_path) + actor.close() + actor.close() # second close is a no-op + + def test_one_lmdb_write_per_completed_source(self, tmp_path: Path) -> None: + """Sanity-check the 'write only when a counter hits zero' contract: + a still-pending source is never persisted; once it completes it is. + + We verify via close/reopen rather than a concurrent second reader: + lmdb refuses to open the same env file twice in one process, and in + production a single detached actor owns each checkpoint file anyway. + """ + actor = _new_actor(tmp_path) + actor.apply_deltas([("h0", "0", +1)]) + # Source 0 is pending (counter != 0) — nothing recorded as completed. + assert actor._completed == set() + # Counter hits zero — now it's recorded. + actor.apply_deltas([("h_t", "0", -1)]) + assert actor._completed == {"0"} + actor.close() + + # A fresh actor loads exactly the one completed source from LMDB. + actor_c = _new_actor(tmp_path) + assert actor_c._completed == {"0"} + actor_c.close() + + +def test_no_lmdb_writes_for_pending_only_deltas(tmp_path: Path) -> None: + """Pending counters change in-memory only; LMDB is touched solely + when a counter hits zero.""" + actor = _new_actor(tmp_path) + # Lots of activity, but no source resolves. + actor.apply_deltas([("h0", "0", +1), ("h1", "1", +1), ("h_fanout_0", "0", +2)]) + actor.close() + + # Fresh actor: nothing persisted. + actor2 = _new_actor(tmp_path) + assert actor2._completed == set() + actor2.close() + + +class TestMultipleWriters: + """Shared metadata dir with one LMDB file per writer (the SLURM-array + model): each writer records ONLY its own completions; later writers read + the union across all writers' files on startup.""" + + def test_union_of_completed_across_writers(self, tmp_path: Path) -> None: + # Writer A finishes source "0". + a = _new_actor(tmp_path, writer_id="hostA-1") + a.apply_deltas([("hA", "0", +1), ("hA_sink", "0", -1)]) + assert a._completed == {"0"} + a.close() + + # Writer B starts later, sees A's completion in the union, finishes "1". + b = _new_actor(tmp_path, writer_id="hostB-2") + assert b.are_completed(["0", "1"]) == [True, False] + b.apply_deltas([("hB", "1", +1), ("hB_sink", "1", -1)]) + b.close() + + # A fresh writer sees the union of everything finished so far. + c = _new_actor(tmp_path, writer_id="hostC-3") + assert c.are_completed(["0", "1", "2"]) == [True, True, False] + c.close() + + # Each writer wrote its OWN file — nothing is shared. + files = sorted(p.name for p in (tmp_path / ".nemo_curator_metadata").glob("*.mdb")) + assert files == ["hostA-1.mdb", "hostB-2.mdb", "hostC-3.mdb"] + + def test_writer_does_not_write_other_writers_files(self, tmp_path: Path) -> None: + # A finishes "s"; B finishes nothing. B must not have touched A's file, + # and A's completion is still readable on its own. + a = _new_actor(tmp_path, writer_id="A") + a.apply_deltas([("h", "s", +1), ("h_sink", "s", -1)]) + a.close() + + b = _new_actor(tmp_path, writer_id="B") # finishes nothing + b.close() + + reader = _new_actor(tmp_path, writer_id="reader") + assert reader.are_completed(["s"]) == [True] + reader.close() diff --git a/tests/utils/test_resumability_client.py b/tests/utils/test_resumability_client.py new file mode 100644 index 0000000000..5b0810ede3 --- /dev/null +++ b/tests/utils/test_resumability_client.py @@ -0,0 +1,98 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the worker-side resumability client helpers. + +These talk to the detached resumability actor over Ray. ``ray`` is mocked so +the helpers' control flow (actor lookup, no-op when inactive, delta fire, +completed-source lookup) is exercised without a live cluster. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from nemo_curator.utils import resumability_client as rc + + +class TestActorLookup: + def test_none_when_ray_not_initialized(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = False + assert rc._actor() is None + assert rc._is_active() is False + ray.get_actor.assert_not_called() + + def test_none_when_no_actor_registered(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = True + ray.get_actor.side_effect = ValueError("no such actor") + assert rc._actor() is None + assert rc._is_active() is False + + def test_returns_handle_when_registered(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = True + handle = MagicMock() + ray.get_actor.return_value = handle + assert rc._actor() is handle + assert rc._is_active() is True + ray.get_actor.assert_called_with(rc.ACTOR_NAME) + + +class TestFlushDeltas: + def test_fires_when_active_and_nonempty(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = True + handle = MagicMock() + ray.get_actor.return_value = handle + deltas = [("t0", "s0", 1), ("t1", "s0", -1)] + rc._flush_deltas(deltas) + handle.apply_deltas.remote.assert_called_once_with(deltas) + + def test_noop_when_no_deltas(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = True + handle = MagicMock() + ray.get_actor.return_value = handle + rc._flush_deltas([]) + handle.apply_deltas.remote.assert_not_called() + + def test_noop_when_inactive(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = False + # Must not raise even though there are deltas to send. + rc._flush_deltas([("t0", "s0", 1)]) + + +class TestSkipCompletedSources: + def test_returns_completed_subset(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = True + handle = MagicMock() + ray.get_actor.return_value = handle + ray.get.return_value = [True, False, True] + assert rc._skip_completed_sources(["a", "b", "c"]) == {"a", "c"} + handle.are_completed.remote.assert_called_once_with(["a", "b", "c"]) + + def test_empty_when_inactive(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = False + assert rc._skip_completed_sources(["a"]) == set() + + def test_empty_when_no_sources(self) -> None: + with patch.object(rc, "ray") as ray: + ray.is_initialized.return_value = True + ray.get_actor.return_value = MagicMock() + assert rc._skip_completed_sources([]) == set() + ray.get.assert_not_called() diff --git a/uv.lock b/uv.lock index 7509d39c76..35200b30ed 100644 --- a/uv.lock +++ b/uv.lock @@ -4357,6 +4357,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/ef/11292bb0b85cf4c93447cab5a29f64576ed14d3ab4280e35ddd23486594a/lm_format_enforcer-0.11.3-py3-none-any.whl", hash = "sha256:cf586350875def1ae7a8fba84fcbbfc8371424b6c9d05c1fcba70aa233fbf06f", size = 45418, upload-time = "2025-08-24T19:37:46.325Z" }, ] +[[package]] +name = "lmdb" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3d/fa/ddef3e433950e23844fd9d82fa045637cbe84140f482120bbdf6abe6be92/lmdb-2.2.1.tar.gz", hash = "sha256:b201b416f7d6cea9bd2f977277a5f51d6e52a434d6ec511a8b34990df2b1a9c5", size = 938665, upload-time = "2026-06-04T04:46:31.461Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/7f/0ed305faf932595d364af9a3046c044f9277273db9e1f033a66fbf2c5b77/lmdb-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:211cad947bc361cbe3c19ef6800d4e1dcb8f2f15e3e5b9bad34cc2818431d268", size = 115968, upload-time = "2026-06-04T04:45:50.068Z" }, + { url = "https://files.pythonhosted.org/packages/30/1e/712864753e331ecf2d93569a6a6d3d1f2a9dcb54feb11a2ace590e32f989/lmdb-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:090c498f57883d69420e4c6a6ec5726471e6ca35e183fe8f032165348c7d49b3", size = 114871, upload-time = "2026-06-04T04:45:51.35Z" }, + { url = "https://files.pythonhosted.org/packages/02/89/7570997080a4e778e6e066c829e722d73ebbc25c269982001b9ce8a26abf/lmdb-2.2.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:aa4115c7fc86ca6ee654f931ceba9e410e83f3296e64cb73125020286be54eb2", size = 326436, upload-time = "2026-06-04T04:45:52.672Z" }, + { url = "https://files.pythonhosted.org/packages/af/97/dc5716d168d652cb2f04bef856a88d51652c42a09c20d23d2e08d4b7704a/lmdb-2.2.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9c145f6a67cc10c0c055cf4b9ce16274fb850c4d9690fef5428cb588f0694be1", size = 329516, upload-time = "2026-06-04T04:45:54.233Z" }, + { url = "https://files.pythonhosted.org/packages/63/74/a8701f8e74ced8ec82de63fa0ac098c9fea41e4c57121ca9724790f7ef55/lmdb-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:7d39273c9cd561a7a084090ba33c008b668257c9202c15aa7d9f9c550f44d030", size = 113705, upload-time = "2026-06-04T04:45:55.482Z" }, + { url = "https://files.pythonhosted.org/packages/98/9a/a1304e1cdb991de6f250f5723a90558b17d4f34a0f1a7315cfa6cb301fee/lmdb-2.2.1-cp311-cp311-win_arm64.whl", hash = "sha256:2e5104ae83edf2e04e54ef9b85b07f080e982ea6c3d5c701b4bca2653ee160f1", size = 107498, upload-time = "2026-06-04T04:45:56.806Z" }, + { url = "https://files.pythonhosted.org/packages/1b/93/4796573d885dbc0dd94ed712d070c6919a019acd12754c4708ba8a47732d/lmdb-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:e6957c9346ce9e9300ca2b75625e681b9868bbaf4d257626ec96d221e8200fc4", size = 116824, upload-time = "2026-06-04T04:45:58.058Z" }, + { url = "https://files.pythonhosted.org/packages/33/20/d3e48f1af18d67e56c2f42f82a598c2586d7d47dca7c8edda4f479e108b4/lmdb-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bd3f3ab6feed2d4ca87d9d9063d2e371c8cc6d72879d54ae160a1c32758d26c0", size = 115341, upload-time = "2026-06-04T04:45:59.352Z" }, + { url = "https://files.pythonhosted.org/packages/5e/3e/6c3d2aa3b2250220d664a3ebb137519b6c33f94e27bf62e903130fac2cb4/lmdb-2.2.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9129a78af25dd1316784d689fefbd88bda6a756c82847a72b7f423bc1282dbd0", size = 333528, upload-time = "2026-06-04T04:46:00.748Z" }, + { url = "https://files.pythonhosted.org/packages/cf/72/64588fb1359b9a8d2fc6d3bfd98cd6a7f22adcd5fffa4252874529e72794/lmdb-2.2.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:13438ad327f8bca47f1415671335eec500b653459d269556eb2cf2470cecec30", size = 338288, upload-time = "2026-06-04T04:46:02.097Z" }, + { url = "https://files.pythonhosted.org/packages/35/19/bf3466f65c7795d44b6119cd62fa505a1fd3ebb50d71bd20b823e2b1485c/lmdb-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:e54f8705489f8b6668b648333fbd90875c06878b3226a64f3f1af58af01c3d00", size = 113598, upload-time = "2026-06-04T04:46:03.593Z" }, + { url = "https://files.pythonhosted.org/packages/a9/7f/214172bc46f67ec58ee0ec0cda3cf6b27ceeaef614be25c863b7da35f9a8/lmdb-2.2.1-cp312-cp312-win_arm64.whl", hash = "sha256:84468990d6b7f50243a1eb19e7f9fbaead93eb7de0eb854b7dacc7f893c699ea", size = 107614, upload-time = "2026-06-04T04:46:04.834Z" }, + { url = "https://files.pythonhosted.org/packages/55/ea/65df850c0f371856eb495c018b13b16da229cb072a06236021130ce6c2f7/lmdb-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d468fa89da30515979bf35c3e5b4db0ded560f9c39449c11459559c9f85bb820", size = 117352, upload-time = "2026-06-04T04:46:06.103Z" }, + { url = "https://files.pythonhosted.org/packages/1f/88/94a079be5dc482cb9971da32a82046bdcf2124646e4d84c5b4412ccb8d78/lmdb-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:881e8cdde83d9130b9cf75faf3202c16cbdeb54da7ec58a0856e8adfff5d5c25", size = 115703, upload-time = "2026-06-04T04:46:07.42Z" }, + { url = "https://files.pythonhosted.org/packages/a3/73/e360c13279ea523d0caf2d231dd581c9fd0e4c6b49f33acde8613f0b653c/lmdb-2.2.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d54bb7ef49241602599f6fee8547ba14765b896ec459dad9620940235c550ab6", size = 336991, upload-time = "2026-06-04T04:46:08.706Z" }, + { url = "https://files.pythonhosted.org/packages/9f/de/e36baf673fb218b17c0c7a8050d1aad7bd49eb7b8fcf8cf0268ddc06507e/lmdb-2.2.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:12b84c38d091bb283853d8af38951338bf3eb729d8e79f0381291b098c0616f6", size = 340692, upload-time = "2026-06-04T04:46:10.326Z" }, + { url = "https://files.pythonhosted.org/packages/c0/de/9e13991db388343ca59caf684e1572705d9d89bc5cc681cfa912cd3b9106/lmdb-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:f68a203f45d7442527c9cc8cd9a7e10666e38b64a71775870bf5b54c30a15661", size = 113526, upload-time = "2026-06-04T04:46:11.73Z" }, + { url = "https://files.pythonhosted.org/packages/4b/83/2c27f9544034387badbadf577a716cf5681afd79f5fb762c2038b62af70b/lmdb-2.2.1-cp313-cp313-win_arm64.whl", hash = "sha256:6f783cd75835eb7d4676be5b0d38f68a31961f07d74126fd6424377005fb4d04", size = 107682, upload-time = "2026-06-04T04:46:12.981Z" }, +] + [[package]] name = "locket" version = "1.0.0" @@ -5106,6 +5132,7 @@ dependencies = [ { name = "fsspec" }, { name = "hydra-core" }, { name = "jieba" }, + { name = "lmdb" }, { name = "loguru" }, { name = "mecab-python3" }, { name = "omegaconf" }, @@ -5548,6 +5575,7 @@ requires-dist = [ { name = "jieba", specifier = "==0.42.1" }, { name = "justext", marker = "extra == 'text-cpu'" }, { name = "librosa", marker = "extra == 'audio-common'" }, + { name = "lmdb", specifier = ">=1.4" }, { name = "loguru" }, { name = "lxml", marker = "extra == 'text-cpu'" }, { name = "matplotlib", marker = "extra == 'interleaved-cpu'" },