diff --git a/bergson/__init__.py b/bergson/__init__.py index e3357854..85558d32 100644 --- a/bergson/__init__.py +++ b/bergson/__init__.py @@ -26,6 +26,7 @@ from .query.attributor import Attributor from .query.faiss_index import FaissConfig from .score.scorer import Scorer +from .sharding import ShardedMemmap, shard_status from .utils.gradcheck import FiniteDiff from .utils.load_from_optimizer import load_from_optimizer @@ -54,5 +55,7 @@ "Scorer", "ScoreConfig", "QueryConfig", + "ShardedMemmap", + "shard_status", "mix_autocorrelation_matrices", ] diff --git a/bergson/__main__.py b/bergson/__main__.py index 7edcb460..529d956b 100644 --- a/bergson/__main__.py +++ b/bergson/__main__.py @@ -17,6 +17,7 @@ Query, Reduce, Score, + Status, Test_Model_Configuration, Trackstar, Train, @@ -38,6 +39,7 @@ class Main: Query, Reduce, Score, + Status, Trackstar, Train, Test_Model_Configuration, diff --git a/bergson/build.py b/bergson/build.py index f7f9c087..0778065f 100644 --- a/bergson/build.py +++ b/bergson/build.py @@ -24,6 +24,7 @@ ) from bergson.utils.worker_utils import ( create_processor, + publish_shard, setup_data_pipeline, setup_model_and_peft, ) @@ -156,6 +157,12 @@ def build( if index_cfg.debug: setup_reproducibility() + if index_cfg.sharded and preprocess_cfg.aggregation != "none": + raise ValueError( + "Sharded runs do not support gradient aggregation; per-shard " + "aggregates would be concatenated instead of summed." 
+ ) + index_cfg.partial_run_path.mkdir(parents=True, exist_ok=True) ds, _ = setup_data_pipeline(index_cfg) @@ -175,7 +182,10 @@ def build( ) if dist_cfg.rank == 0: - shutil.move(index_cfg.partial_run_path, index_cfg.run_path) + if index_cfg.sharded: + publish_shard(index_cfg, num_items=len(ds)) + else: + shutil.move(index_cfg.partial_run_path, index_cfg.run_path) if dist_cfg.world_size < index_cfg.distributed.world_size: parent_barrier(index_cfg.distributed) diff --git a/bergson/cli/commands.py b/bergson/cli/commands.py index 35ff2f7d..ada85aa7 100644 --- a/bergson/cli/commands.py +++ b/bergson/cli/commands.py @@ -8,7 +8,7 @@ from dataclasses import dataclass -from simple_parsing import Serializable +from simple_parsing import Serializable, field from ..build import build from ..config.config import ( @@ -31,7 +31,7 @@ from ..process_grads import mix_autocorrelation_matrices from ..query.query_index import query from ..score.score import score_dataset -from ..utils.worker_utils import validate_run_path +from ..utils.worker_utils import prepare_shard, validate_run_path @dataclass @@ -52,6 +52,12 @@ class ApproxUnrolling(Serializable): def execute(self): from ..approx_unrolling.pipeline import approx_unrolling_pipeline + if self.index_cfg.sharded: + raise ValueError( + "approx_unrolling does not support sharded runs yet; " + "shard the build and score steps individually." + ) + save_run_config(self, self.index_cfg.run_path) approx_unrolling_pipeline( self.index_cfg, @@ -92,7 +98,17 @@ def execute(self): f"{self.hessian_cfg.method}." ) - validate_run_path(self.index_cfg) + if self.index_cfg.sharded: + if self.hessian_cfg is not None: + raise ValueError( + "Sharded builds do not support simultaneous Hessian " + "estimation; Hessian factors cannot be merged across " + "independent shards yet. Run `bergson hessian` separately." 
+ ) + if prepare_shard(self, self.index_cfg): + return + else: + validate_run_path(self.index_cfg) save_run_config(self, self.index_cfg.partial_run_path) build(self.index_cfg, self.preprocess_cfg, self.hessian_cfg) @@ -115,6 +131,12 @@ class Ekfac(Serializable): def execute(self): from ..hessians.pipeline import hessian_pipeline + if self.index_cfg.sharded: + raise ValueError( + "ekfac does not support sharded runs yet; " + "shard the build and score steps individually." + ) + save_run_config(self, self.index_cfg.run_path) hessian_pipeline( self.index_cfg, @@ -134,6 +156,11 @@ class Hessian(Serializable): def execute(self): """Compute Hessian approximation.""" + if self.index_cfg.sharded: + raise ValueError( + "hessian does not support sharded runs; Hessian factors " + "cannot be merged across independent shards yet." + ) validate_run_path(self.index_cfg) @@ -197,6 +224,12 @@ class Reduce(Serializable): def execute(self): """Reduce a gradient index.""" + if self.index_cfg.sharded: + raise ValueError( + "reduce does not support sharded runs; per-shard aggregates " + "would be concatenated instead of summed." + ) + if self.index_cfg.projection_dim != 0: print(f"Using a projection dimension of {self.index_cfg.projection_dim}. ") @@ -222,7 +255,12 @@ def execute(self): if self.index_cfg.projection_dim != 0: print(f"Using a projection dimension of {self.index_cfg.projection_dim}. ") - validate_run_path(self.index_cfg) + if self.index_cfg.sharded: + if prepare_shard(self, self.index_cfg): + return + else: + validate_run_path(self.index_cfg) + save_run_config(self, self.index_cfg.partial_run_path) score_dataset(self.index_cfg, self.score_cfg, self.preprocess_cfg) @@ -238,6 +276,12 @@ class Trackstar(Serializable): def execute(self): from .trackstar import trackstar + if self.index_cfg.sharded: + raise ValueError( + "trackstar does not support sharded runs yet; " + "shard the build and score steps individually." 
@dataclass
class Status(Serializable):
    """Summarize how far a sharded run has progressed.

    Shards are grouped into published, in progress (or crashed), and not
    yet started, based on the directories found under ``run_path``.
    """

    run_path: str = field(positional=True)

    def execute(self):
        """Print the shard inventory of ``run_path`` to stdout."""
        from ..sharding import shard_status

        done, pending, total = shard_status(self.run_path)

        # No shard directories at all: this is not a sharded run.
        if total is None:
            print(f"{self.run_path} is not a sharded run.")
            return

        # A shard with neither a published nor a .part directory has not
        # been started by anyone yet.
        untouched = [
            shard_id
            for shard_id in range(total)
            if shard_id not in done and shard_id not in pending
        ]

        print(f"{self.run_path}: {len(done)}/{total} shards published")
        if pending:
            print(f" in progress or crashed: {sorted(pending)}")
        if untouched:
            print(f" not started: {untouched}")
        if not (pending or untouched):
            print(" run complete")
@property
def resolved_shard_id(self) -> int:
    """The shard to process, from config or SLURM environment variables.

    An explicit ``shard_id`` wins. Otherwise the first of
    ``SLURM_ARRAY_TASK_ID`` and ``SLURM_PROCID`` that is set is used.

    Raises:
        ValueError: if the chosen environment variable is not an integer,
            is outside ``[0, num_shards)``, or if no source provides a
            shard id at all.
    """
    if self.shard_id is not None:
        return self.shard_id

    for var in ("SLURM_ARRAY_TASK_ID", "SLURM_PROCID"):
        raw = os.environ.get(var)
        if raw is None:
            continue

        try:
            shard_id = int(raw)
        except ValueError:
            # A bare "invalid literal for int()" gives no hint about which
            # variable is malformed; name it explicitly.
            raise ValueError(
                f"{var}={raw!r} is not a valid integer shard id"
            ) from None

        if not 0 <= shard_id < self.num_shards:
            raise ValueError(
                f"{var}={shard_id} is out of range for "
                f"num_shards={self.num_shards}"
            )
        return shard_id

    raise ValueError(
        "num_shards > 1 but no shard id found. Set it with --shard_id, "
        "or via SLURM_ARRAY_TASK_ID/SLURM_PROCID."
    )
def publish_canonical_config(command: Any, run_path: str | Path):
    """Write the ``config.yaml`` shared by all shards of a sharded run.

    The first shard to arrive writes the file; every later shard verifies
    that its own canonical config matches and errors out otherwise, so
    shards built from different configurations are not mixed in one
    ``run_path``.

    Args:
        command: The CLI command object; passed to ``canonical_steps``.
        run_path: Root directory of the sharded run.

    Raises:
        ValueError: if an existing ``config.yaml`` holds different steps.
    """
    # Local import: this is the only user of uuid in the module.
    import uuid

    path = Path(run_path) / CONFIG_FILENAME
    steps = canonical_steps(command)

    if path.exists():
        existing = read_config(path)
        if existing["steps"] != steps:
            raise ValueError(
                f"{path} was written by a run with a different configuration. "
                f"Refusing to add shards to it; use a fresh run_path or "
                f"rerun with the original configuration."
            )
        return

    doc: dict[str, Any] = {"steps": steps, "metadata": make_metadata()}
    path.parent.mkdir(parents=True, exist_ok=True)

    # Concurrent shards may race to create the file. Each racer writes its
    # own temp file and moves it into place, so readers never observe a
    # partially written config. The temp name must be unique across *hosts*:
    # shards typically run on different nodes of a shared filesystem, where
    # PIDs can collide, so a random suffix is used instead of os.getpid().
    tmp_path = path.with_name(f".{CONFIG_FILENAME}.{uuid.uuid4().hex}.tmp")
    try:
        with tmp_path.open("w") as f:
            yaml.safe_dump(doc, f, sort_keys=False)
        # os.replace overwrites an existing target on every platform;
        # os.rename raises on Windows when another racer got there first.
        os.replace(tmp_path, path)
    finally:
        # No-op after a successful replace; removes the temp file if the
        # dump or the move failed.
        tmp_path.unlink(missing_ok=True)
diff --git a/bergson/data.py b/bergson/data.py index 594661df..19465be9 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -26,6 +26,7 @@ from transformers import PreTrainedTokenizerFast, logging from .config import DataConfig +from .sharding import ShardedMemmap, is_sharded_run, published_shard_dirs from .utils.utils import ( assert_type, simple_parse_kwargs_string, @@ -129,7 +130,7 @@ def create_token_index( def load_token_gradients( root_dir: Path | str, -) -> tuple[np.memmap, np.ndarray, np.ndarray]: +) -> tuple[np.memmap | ShardedMemmap, np.ndarray, np.ndarray]: """Load per-token gradients stored by :func:`create_token_index`. Returns @@ -140,6 +141,19 @@ def load_token_gradients( shape ``(num_token_grads[i], total_grad_dim)``. """ root_dir = Path(root_dir) + + # Sharded runs: concatenate the per-shard token indexes. + if not (root_dir / "token_gradients.bin").exists() and is_sharded_run(root_dir): + parts = [ + load_token_gradients(shard_dir) + for shard_dir in published_shard_dirs(root_dir) + ] + mmap = ShardedMemmap([mmap for mmap, _, _ in parts]) + num_token_grads = np.concatenate([counts for _, counts, _ in parts]) + offsets = np.zeros(len(num_token_grads) + 1, dtype=np.int64) + np.cumsum(num_token_grads, out=offsets[1:]) + return mmap, num_token_grads, offsets + with (root_dir / "info.json").open("r") as f: info = json.load(f) @@ -487,9 +501,27 @@ def load_data_string( return ds -def load_gradients(root_dir: Path | str, structured: bool = True) -> np.memmap: - """Map the structured gradients stored in `root_dir` into memory.""" +def load_gradients( + root_dir: Path | str, + structured: bool = True, + allow_partial: bool = False, +) -> np.memmap | ShardedMemmap: + """Map the structured gradients stored in `root_dir` into memory. + + For sharded runs (a ``shards/`` directory of per-shard indexes), the + published shards are presented as one logically concatenated index. 
+ ``allow_partial`` permits reading a sharded run whose shards have not + all been published yet. + """ root_dir = Path(root_dir) + if not (root_dir / "gradients.bin").exists() and is_sharded_run(root_dir): + return ShardedMemmap( + [ + load_gradients(shard_dir, structured=structured) + for shard_dir in published_shard_dirs(root_dir, allow_partial) + ] + ) + with (root_dir / "info.json").open("r") as f: info = json.load(f) @@ -542,14 +574,19 @@ def _to_arrow(arr: np.ndarray) -> pa.Array: if (root_dir / "data.hf").exists(): return load_shard(root_dir) + if is_sharded_run(root_dir): + shard_paths = published_shard_dirs(root_dir) + else: + shard_paths = [path for path in sorted(root_dir.iterdir()) if path.is_dir()] + # Flatten indices to avoid CPU OOM return concatenate_datasets( - [load_shard(path) for path in sorted(root_dir.iterdir()) if path.is_dir()] + [load_shard(path) for path in shard_paths] ).flatten_indices() class Scores: - def __init__(self, mmap: np.memmap, info: dict[str, Any]): + def __init__(self, mmap: np.memmap | ShardedMemmap, info: dict[str, Any]): self.mmap = mmap self.info = info self.num_scores = info["num_scores"] @@ -573,10 +610,24 @@ def is_written(self) -> bool: return all(np.all(self.mmap[f"written_{i}"]) for i in range(self.num_scores)) -def load_scores(path: Path) -> Scores: +def load_scores(path: Path, allow_partial: bool = False) -> Scores: bin_path = path / "scores.bin" info_path = path / "info.json" + # Sharded runs: present the per-shard scores as one concatenated array. 
+ if not bin_path.exists() and is_sharded_run(path): + shards = [ + load_scores(shard_dir) + for shard_dir in published_shard_dirs(path, allow_partial) + ] + if len({shard.num_scores for shard in shards}) != 1: + raise ValueError(f"Shards of {path} disagree on num_scores") + + mmap = ShardedMemmap([shard.mmap for shard in shards]) + info = dict(shards[0].info) + info["num_items"] = len(mmap) + return Scores(mmap, info) + with open(info_path, "r") as f: info = json.load(f) diff --git a/bergson/gradients.py b/bergson/gradients.py index ab6a2d51..ee12c1d1 100644 --- a/bergson/gradients.py +++ b/bergson/gradients.py @@ -9,6 +9,8 @@ from torch import Tensor from transformers.pytorch_utils import Conv1D as HFConv1D +from bergson.sharding import is_sharded_run, published_shard_dirs + NORMALIZER_TYPES: dict[str, type["Normalizer"]] = {} @@ -142,6 +144,15 @@ def load( Load the normalizers and hessians from a file. """ path = Path(path) + + # Sharded runs store identical processor artifacts in every shard; + # read them from the first published one. 
+ if not (path / "processor_config.yaml").exists() and is_sharded_run(path): + shard_dirs = published_shard_dirs(path, allow_partial=True) + if not shard_dirs: + raise FileNotFoundError(f"No published shards in {path}") + path = shard_dirs[0] + cfg_path = path / "processor_config.yaml" norm_path = path / "normalizers.pth" diff --git a/bergson/query/faiss_index.py b/bergson/query/faiss_index.py index faf5b399..544f719e 100644 --- a/bergson/query/faiss_index.py +++ b/bergson/query/faiss_index.py @@ -13,6 +13,7 @@ from bergson.config.config import FaissConfig from bergson.process_grads import precondition_flat_grads +from bergson.sharding import is_sharded_run, published_shard_dirs if TYPE_CHECKING: import faiss # noqa: F401 # pyright: ignore[reportMissingImports] @@ -103,6 +104,9 @@ def load_shard(shard_dir: Path) -> np.memmap: if (root_dir / "info.json").exists(): yield load_shard(root_dir) + elif is_sharded_run(root_dir): + for path in published_shard_dirs(root_dir): + yield load_shard(path) else: for path in sorted(root_dir.iterdir()): if "shard" in path.name: @@ -218,6 +222,11 @@ def create_index( # Write the gradients into an on-disk FAISS index if (gradients_path / "info.json").exists(): info_paths = [gradients_path / "info.json"] + elif is_sharded_run(gradients_path): + info_paths = [ + shard_path / "info.json" + for shard_path in published_shard_dirs(gradients_path) + ] else: info_paths = [ shard_path / "info.json" diff --git a/bergson/score/score.py b/bergson/score/score.py index 364eca1b..456db16b 100644 --- a/bergson/score/score.py +++ b/bergson/score/score.py @@ -40,6 +40,7 @@ ) from bergson.utils.worker_utils import ( create_processor, + publish_shard, setup_data_pipeline, setup_model_and_peft, ) @@ -388,7 +389,11 @@ def score_dataset( ) if dist_cfg.rank == 0: - shutil.move(index_cfg.partial_run_path, index_cfg.run_path) + if index_cfg.sharded: + assert isinstance(ds, Dataset) + publish_shard(index_cfg, num_items=len(ds)) + else: + 
shutil.move(index_cfg.partial_run_path, index_cfg.run_path) if dist_cfg.world_size < index_cfg.distributed.world_size: parent_barrier(index_cfg.distributed) diff --git a/bergson/sharding.py b/bergson/sharding.py new file mode 100644 index 00000000..9820c3fa --- /dev/null +++ b/bergson/sharding.py @@ -0,0 +1,247 @@ +"""Embarrassingly-parallel sharded runs that share one run_path. + +Multiple independent ``bergson build``/``bergson score`` invocations +(typically one per SLURM array task) each process one contiguous slice of +the dataset and publish into the same ``run_path``: + + run_path/ + ├── config.yaml # canonical config, identical across shards + └── shards/ + ├── 00000-of-00064/ # published shard + │ ├── gradients.bin / scores.bin + │ ├── info.json + │ ├── shard.json # provenance: dataset slice, host, timestamp + │ └── ... + └── 00017-of-00064.part/ # in progress, or left over from a crash + +A shard is written under a ``.part`` suffix and atomically renamed into +place on success, so a published shard is always complete and a crashed +shard is rebuilt by simply re-running the same command (a published shard +is skipped, making restarts idempotent). Readers present the published +shards as one logically concatenated index, so no manual stitching is +needed at any point. +""" + +import json +import os +import re +import socket +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Sequence + +import numpy as np + +SHARDS_DIRNAME = "shards" +PART_SUFFIX = ".part" +SHARD_RECORD_FILENAME = "shard.json" + +_SHARD_DIR_RE = re.compile(r"^(\d{5})-of-(\d{5})$") + + +def shard_dir_name(shard_id: int, num_shards: int) -> str: + """Directory name for one shard, e.g. ``00017-of-00064``.""" + return f"{shard_id:05d}-of-{num_shards:05d}" + + +def shard_row_range(total_rows: int, num_shards: int, shard_id: int) -> tuple[int, int]: + """[start, end) row range of a contiguous shard. 
+ + Matches ``datasets.Dataset.shard(num_shards, shard_id, contiguous=True)``. + """ + div, mod = divmod(total_rows, num_shards) + start = shard_id * div + min(shard_id, mod) + end = start + div + (1 if shard_id < mod else 0) + return start, end + + +def is_sharded_run(run_path: Path | str) -> bool: + """Whether ``run_path`` holds per-shard subdirectories.""" + return (Path(run_path) / SHARDS_DIRNAME).is_dir() + + +def make_shard_record( + shard_id: int, + num_shards: int, + split: str, + row_range: tuple[int, int] | None, + num_items: int, +) -> dict[str, Any]: + """Provenance record written to ``shard.json`` inside a shard directory. + + ``row_range`` is the [start, end) slice of the resolved parent split + *before* tokenization; ``num_items`` is the number of index items the + shard produced, which can differ when chunking is enabled. The shard's + own ``info.json`` holds the authoritative gradient counts. + """ + record: dict[str, Any] = { + "shard_id": shard_id, + "num_shards": num_shards, + "split": split, + "num_items": num_items, + "hostname": socket.gethostname(), + "completed_at": datetime.now(timezone.utc).isoformat(timespec="seconds"), + } + if row_range is not None: + record["row_start"], record["row_end"] = row_range + record["num_rows"] = row_range[1] - row_range[0] + for var in ("SLURM_JOB_ID", "SLURM_ARRAY_JOB_ID", "SLURM_ARRAY_TASK_ID"): + if var in os.environ: + record[var.lower()] = os.environ[var] + return record + + +def write_shard_record(shard_dir: Path, record: dict[str, Any]) -> None: + with (Path(shard_dir) / SHARD_RECORD_FILENAME).open("w") as f: + json.dump(record, f, indent=2) + + +def shard_status( + run_path: Path | str, +) -> tuple[dict[int, Path], dict[int, Path], int | None]: + """Inventory of a sharded run. + + Returns ``(published, partial, num_shards)`` where the dicts map + shard_id to its directory. ``num_shards`` is ``None`` if no shard + directories exist yet. 
+ """ + shards_dir = Path(run_path) / SHARDS_DIRNAME + published: dict[int, Path] = {} + partial: dict[int, Path] = {} + num_shards: int | None = None + + if not shards_dir.is_dir(): + return published, partial, num_shards + + for child in sorted(shards_dir.iterdir()): + if not child.is_dir(): + continue + name = child.name.removesuffix(PART_SUFFIX) + match = _SHARD_DIR_RE.match(name) + if match is None: + continue + shard_id, total = int(match.group(1)), int(match.group(2)) + if num_shards is None: + num_shards = total + elif total != num_shards: + raise ValueError( + f"Inconsistent shard counts in {shards_dir}: found shards " + f"of {num_shards} and of {total}. The run_path mixes " + f"runs with different --num_shards." + ) + if child.name.endswith(PART_SUFFIX): + partial[shard_id] = child + else: + published[shard_id] = child + + return published, partial, num_shards + + +def published_shard_dirs( + run_path: Path | str, allow_partial: bool = False +) -> list[Path]: + """Published shard directories in shard order. + + Raises if any shard is missing or unpublished, unless ``allow_partial``. + """ + published, partial, num_shards = shard_status(run_path) + if num_shards is None: + raise FileNotFoundError(f"No shard directories found in {run_path}") + + missing = sorted(set(range(num_shards)) - published.keys()) + if missing and not allow_partial: + in_progress = sorted(partial.keys()) + raise RuntimeError( + f"Sharded run {run_path} is incomplete: missing shards {missing}" + + (f" (in progress or crashed: {in_progress})" if in_progress else "") + + ". Re-run the missing shards, or pass allow_partial=True to " + "read the published subset." + ) + + return [published[i] for i in sorted(published)] + + +class ShardedMemmap: + """Read-only view over per-shard arrays, concatenated along axis 0. + + Quacks enough like an ``np.memmap`` for the index readers: ``len``, + ``shape``, ``dtype``, field access on structured arrays, and + int/slice/fancy indexing. 
Indexing materializes the requested rows as + a regular ndarray while the underlying per-shard memmaps stay lazy, so + avoid full-index reads (``mmap[:]``) on very large runs — iterate + ``shards`` instead. + """ + + def __init__(self, arrays: "Sequence[np.ndarray | ShardedMemmap]"): + if not arrays: + raise ValueError("ShardedMemmap needs at least one array") + head = arrays[0] + for arr in arrays[1:]: + if arr.dtype != head.dtype or arr.shape[1:] != head.shape[1:]: + raise ValueError( + f"Shards disagree on dtype/shape: {head.dtype}{head.shape[1:]} " + f"vs {arr.dtype}{arr.shape[1:]}" + ) + self.shards = list(arrays) + # offsets[i] is the global index of shard i's first row + self._offsets = np.cumsum([0] + [len(a) for a in self.shards]) + + @property + def dtype(self) -> np.dtype: + return self.shards[0].dtype + + @property + def shape(self) -> tuple[int, ...]: + return (int(self._offsets[-1]), *self.shards[0].shape[1:]) + + def __len__(self) -> int: + return int(self._offsets[-1]) + + def copy(self) -> np.ndarray: + """Materialize the full concatenated array.""" + return self[:] + + def reshape(self, *shape: Any) -> np.ndarray: + """Materialize the full concatenated array and reshape it.""" + return self[:].reshape(*shape) + + def __array__(self, dtype: Any = None, copy: Any = None) -> np.ndarray: + out = self[:] + return out if dtype is None else out.astype(dtype) + + def __getitem__(self, key: Any) -> np.ndarray: + if isinstance(key, str): + return np.concatenate([np.asarray(a[key]) for a in self.shards], axis=0) + + if isinstance(key, (int, np.integer)): + idx = int(key) + if idx < 0: + idx += len(self) + if not 0 <= idx < len(self): + raise IndexError(f"index {key} out of range for length {len(self)}") + shard = int(np.searchsorted(self._offsets, idx, side="right")) - 1 + return self.shards[shard][idx - self._offsets[shard]] + + if isinstance(key, slice): + start, stop, step = key.indices(len(self)) + if step != 1: + return self[np.arange(start, stop, 
step)] + pieces = [ + np.asarray(arr[max(start - off, 0) : max(stop - off, 0)]) + for arr, off in zip(self.shards, self._offsets) + ] + return np.concatenate(pieces, axis=0) + + indices = np.asarray(key) + if indices.dtype == bool: + indices = np.nonzero(indices)[0] + indices = np.where(indices < 0, indices + len(self), indices) + if indices.size and (indices.min() < 0 or indices.max() >= len(self)): + raise IndexError(f"index out of range for length {len(self)}") + + out = np.empty((len(indices), *self.shape[1:]), dtype=self.dtype) + shard_ids = np.searchsorted(self._offsets, indices, side="right") - 1 + for shard in np.unique(shard_ids): + mask = shard_ids == shard + out[mask] = self.shards[shard][indices[mask] - self._offsets[shard]] + return out diff --git a/bergson/utils/worker_utils.py b/bergson/utils/worker_utils.py index 9a10de46..13e7701e 100644 --- a/bergson/utils/worker_utils.py +++ b/bergson/utils/worker_utils.py @@ -1,6 +1,7 @@ import shutil import warnings from pathlib import Path +from typing import Any import numpy as np import pandas as pd @@ -32,6 +33,7 @@ IndexConfig, ModelConfig, ) +from bergson.config.config_io import publish_canonical_config from bergson.data import ( load_data_string, tokenize, @@ -39,6 +41,7 @@ ) from bergson.format import apply_format from bergson.gradients import GradientProcessor, Normalizer +from bergson.sharding import make_shard_record, shard_row_range, write_shard_record from bergson.utils import assert_type, get_layer_list, weighted_causal_lm_ce from bergson.utils.utils import get_device, simple_parse_kwargs_string @@ -62,6 +65,78 @@ def validate_run_path(index_cfg: IndexConfig): ) +def prepare_shard(command: Any, index_cfg: IndexConfig) -> bool: + """Prepare the run path for one shard of a sharded run. + + Publishes (or verifies) the canonical ``config.yaml`` shared by all + shards and cleans up this shard's directories. 
def publish_shard(index_cfg: IndexConfig, num_items: int) -> None:
    """Publish this shard's ``.part`` directory under its final name.

    Writes the shard's provenance record into the ``.part`` directory, then
    renames the directory to ``index_cfg.final_run_path``. Because the
    record is written before the rename, a published shard always contains
    its ``shard.json``.

    Args:
        index_cfg: Config of the sharded run; supplies the shard identity
            and the partial/final paths.
        num_items: Number of index items this shard produced.

    Raises:
        OSError: if publishing fails and no published shard exists at the
            final path.
    """
    partial_path = index_cfg.partial_run_path
    final_path = index_cfg.final_run_path

    record = make_shard_record(
        index_cfg.resolved_shard_id,
        index_cfg.num_shards,
        index_cfg.data.split,
        # Stashed on the config by setup_data_pipeline; absent otherwise.
        getattr(index_cfg, "shard_row_range", None),
        num_items,
    )

    try:
        write_shard_record(partial_path, record)
        partial_path.rename(final_path)
    except OSError:
        # A concurrent run of the same shard (e.g. a requeued SLURM task
        # racing its predecessor) published first. Both runs build the same
        # data, so the published shard is complete either way.
        if not final_path.exists():
            raise
        print(
            f"Shard already published by a concurrent run at "
            f"{final_path}; discarding this build."
        )
        # Both runs share the same .part path, so the winner's rename may
        # already have moved it away; a missing directory is not an error.
        shutil.rmtree(partial_path, ignore_errors=True)
code-block:: bash + + # Typically launched as a SLURM job array (sbatch --array=0-63 --requeue); + # --shard_id is inferred from SLURM_ARRAY_TASK_ID when unset. + bergson build runs/my-index \ + --model EleutherAI/pythia-14m \ + --dataset NeelNanda/pile-10k \ + --truncation \ + --num_shards 64 --shard_id $SLURM_ARRAY_TASK_ID + +Each shard processes one contiguous slice of the dataset, writes into +``run_path/shards/{shard_id}-of-{num_shards}.part`` (zero-padded, e.g. ``00003-of-00064.part``), and atomically renames the directory +into place when it finishes. A crashed shard is rebuilt by re-running the +same command; a finished shard is skipped, so requeued jobs are idempotent. +The first shard to arrive writes a canonical ``run_path/config.yaml`` and +every other shard verifies its configuration against it. + +No stitching is needed afterwards: ``load_gradients``, ``load_scores``, +``Attributor``, and friends read the published shards as one concatenated +index. Inspect progress with: + +.. code-block:: bash + + bergson status runs/my-index + +See ``examples/slurm/data_parallel_score.sh`` for a complete job-array +script. Sharded runs do not support simultaneous Hessian estimation or +gradient aggregation; compute those separately. 
diff --git a/examples/slurm/data_parallel_score.sh b/examples/slurm/data_parallel_score.sh index 9b8a47f8..bc8be5e4 100644 --- a/examples/slurm/data_parallel_score.sh +++ b/examples/slurm/data_parallel_score.sh @@ -1,29 +1,36 @@ #!/usr/bin/bash #SBATCH --job-name=bergson_score_data_parallel -#SBATCH --nodes=64 -#SBATCH --ntasks=64 -#SBATCH --ntasks-per-node=1 +#SBATCH --array=0-63 +#SBATCH --requeue +#SBATCH --nodes=1 +#SBATCH --ntasks=1 #SBATCH --gpus-per-node=4 #SBATCH --time=24:00:00 -#SBATCH --output=logs/bergson_score_data_parallel_%A_%N.out -#SBATCH --error=logs/bergson_score_data_parallel_%A_%N.err +#SBATCH --output=logs/bergson_score_%A_%a.out +#SBATCH --error=logs/bergson_score_%A_%a.err + +# Embarrassingly parallel scoring: every array task processes one contiguous +# shard of the dataset and publishes it into the same run_path. A task that +# dies is requeued by SLURM and rebuilds only its own shard; shards that +# already finished are skipped, so requeues are idempotent. No stitching is +# needed afterwards — bergson reads run_path as one index. 
+# +# Check progress at any time with: bergson status runs/$RUN_NAME mkdir -p logs hf auth login --token -# Set number of nodes -NUM_NODES=64 +NUM_SHARDS=64 # keep equal to the array size above RUN_NAME="bergson_score" -TOTAL_EXAMPLES=100_000_000 -EXAMPLES_PER_NODE=$((TOTAL_EXAMPLES / NUM_NODES)) - -# Export variables for the worker script -export TOTAL_EXAMPLES -export EXAMPLES_PER_NODE -export NUM_NODES -export RUN_NAME - -# Run worker script on each node -srun --kill-on-bad-exit=1 --output=logs/bergson_score_%A_%t.out --error=logs/bergson_score_%A_%t.err bash score_worker.sh +# --shard_id is inferred from SLURM_ARRAY_TASK_ID +python -m bergson score \ + runs/$RUN_NAME \ + --num_shards $NUM_SHARDS \ + --query_path runs/$QUERY_RUN_NAME \ + --score individual \ + --dataset NeelNanda/pile-10k \ + --model EleutherAI/pythia-14m \ + --token_batch_size 15500 \ + --truncation diff --git a/examples/slurm/score_worker.sh b/examples/slurm/score_worker.sh deleted file mode 100644 index 5dcc93bb..00000000 --- a/examples/slurm/score_worker.sh +++ /dev/null @@ -1,50 +0,0 @@ -#!/bin/bash - -SHARD_ID=$(printf "%05d" $SLURM_PROCID) - -# Set the dataset starting index for this shard -SHARD_START=$((SLURM_PROCID * EXAMPLES_PER_NODE)) - -# Set the dataset ending index for this shard. -# Include the remainder if it's the final shard. 
-LAST_SLURM_PROCID=$((NUM_NODES - 1)) -if [ $SLURM_PROCID -eq $LAST_SLURM_PROCID ]; then - SHARD_END=$TOTAL_EXAMPLES -else - SHARD_END=$(((SLURM_PROCID + 1) * EXAMPLES_PER_NODE)) -fi - -NODE_EXAMPLES=$((SHARD_END - SHARD_START)) - -echo "Node $SLURM_PROCID (shard $SLURM_PROCID) processing examples $SHARD_START to $SHARD_END (total: $NODE_EXAMPLES)" -echo "[$(date)] Shard $SLURM_PROCID: examples $SHARD_START to $SHARD_END - STARTING" - -# Start GPU monitoring in background -(while true; do - GPU_UTIL=$(nvidia-smi --query-gpu=utilization.gpu --format=csv,noheader,nounits | paste -sd, -) - echo "[$(date)] Node $SLURM_PROCID Shard $SLURM_PROCID GPU util: $GPU_UTIL%" - sleep 60 -done) & -MONITOR_PID=$! - -python -m bergson score \ - runs/$RUN_NAME/shard-$SHARD_ID \ - --split "train[$SHARD_START:$SHARD_END]" \ - --query_path runs/$QUERY_RUN_NAME \ - --score individual \ - --dataset NeelNanda/pile-10k \ - --model EleutherAI/pythia-14m \ - --token_batch_size 15500 \ - --truncation - -EXIT_CODE=$? - -# Kill the monitoring process -kill $MONITOR_PID 2>/dev/null -wait $MONITOR_PID 2>/dev/null - -if [ $EXIT_CODE -eq 0 ]; then - echo "[$(date)] Process $SLURM_PROCID complete!" -else - echo "[$(date)] Process $SLURM_PROCID failed with exit code $EXIT_CODE." -fi diff --git a/tests/test_sharded_runs.py b/tests/test_sharded_runs.py new file mode 100644 index 00000000..6c69017e --- /dev/null +++ b/tests/test_sharded_runs.py @@ -0,0 +1,321 @@ +"""Sharded (embarrassingly parallel) runs that share one run_path. + +Unit tests cover the sharding helpers, the concatenated reader view, and +the canonical config protocol. The GPU integration test simulates the +SLURM job-array workflow on one machine: several independent ``bergson +build`` invocations with ``--num_shards``/``--shard_id`` publish into one +run_path, one of them is killed mid-build and restarted, a finished shard +is re-run to check idempotency, and a config mismatch is rejected. 
The +resulting index must match a non-sharded build of the same dataset. +""" + +import os +import signal +import subprocess +import time +from pathlib import Path + +import numpy as np +import pytest +import torch +import yaml +from datasets import Dataset + +from bergson.config.config import IndexConfig, PreprocessConfig +from bergson.config.config_io import publish_canonical_config +from bergson.data import load_gradient_dataset, load_gradients +from bergson.gradients import GradientProcessor +from bergson.sharding import ( + ShardedMemmap, + published_shard_dirs, + shard_dir_name, + shard_row_range, + shard_status, +) + +# --------------------------------------------------------------------------- +# Unit tests +# --------------------------------------------------------------------------- + + +def test_shard_row_range_matches_hf_contiguous_shard(): + ds = Dataset.from_dict({"x": list(range(103))}) + for num_shards in (1, 4, 7, 103): + previous_end = 0 + for shard_id in range(num_shards): + start, end = shard_row_range(len(ds), num_shards, shard_id) + piece = ds.shard(num_shards=num_shards, index=shard_id, contiguous=True) + assert start == previous_end + assert end - start == len(piece) + assert piece["x"] == list(range(start, end)) + previous_end = end + assert previous_end == len(ds) + + +@pytest.mark.parametrize("structured", [False, True]) +def test_sharded_memmap_matches_concatenate(structured: bool): + rng = np.random.default_rng(0) + if structured: + dtype = np.dtype({"names": ["a", "b"], "formats": ["(3,)f4", "f4"]}) + parts = [] + for n in (3, 1, 4): + arr = np.zeros(n, dtype=dtype) + arr["a"] = rng.normal(size=(n, 3)) + arr["b"] = rng.normal(size=n) + parts.append(arr) + else: + parts = [rng.normal(size=(n, 5)).astype(np.float32) for n in (3, 1, 4)] + + ref = np.concatenate(parts) + view = ShardedMemmap(parts) + + assert len(view) == len(ref) + assert view.shape == ref.shape + assert view.dtype == ref.dtype + assert np.array_equal(view[:], ref) + 
assert np.array_equal(view[2:7], ref[2:7]) + assert np.array_equal(view[::3], ref[::3]) + assert np.array_equal(view[-1], ref[-1]) + assert np.array_equal(view[4], ref[4]) + fancy = [7, 0, 3, 3, -2] + assert np.array_equal(view[fancy], ref[fancy]) + mask = np.arange(len(ref)) % 2 == 0 + assert np.array_equal(view[mask], ref[mask]) + assert np.array_equal(view.copy(), ref) + assert np.array_equal(np.asarray(view), ref) + + if structured: + assert view.dtype.names == ("a", "b") + assert np.array_equal(view["a"], ref["a"]) + + with pytest.raises(IndexError): + view[len(ref)] + + +def test_sharded_memmap_rejects_mismatched_shards(): + with pytest.raises(ValueError, match="disagree"): + ShardedMemmap([np.zeros((2, 3)), np.zeros((2, 4))]) + + +def _make_shard_dirs(run_path: Path, names: list[str]) -> None: + for name in names: + (run_path / "shards" / name).mkdir(parents=True) + + +def test_shard_status_and_published_dirs(tmp_path: Path): + _make_shard_dirs( + tmp_path, ["00000-of-00003", "00001-of-00003.part", "junk", "notes.txt"] + ) + (tmp_path / "shards" / "notes.txt").rmdir() + (tmp_path / "shards" / "notes.txt").touch() + + published, partial, num_shards = shard_status(tmp_path) + assert num_shards == 3 + assert sorted(published) == [0] + assert sorted(partial) == [1] + + # Incomplete runs raise by default and list the missing shards + with pytest.raises(RuntimeError, match=r"missing shards \[1, 2\]"): + published_shard_dirs(tmp_path) + + assert [p.name for p in published_shard_dirs(tmp_path, allow_partial=True)] == [ + "00000-of-00003" + ] + + # Mixing different --num_shards in one run_path is rejected + _make_shard_dirs(tmp_path, ["00000-of-00002"]) + with pytest.raises(ValueError, match="Inconsistent shard counts"): + shard_status(tmp_path) + + +class FakeBuild: + """Minimal stand-in for the Build command dataclass.""" + + def __init__(self, index_cfg: IndexConfig, preprocess_cfg: PreprocessConfig): + self.index_cfg = index_cfg + self.preprocess_cfg = 
preprocess_cfg + + def to_dict(self): + return { + "index_cfg": self.index_cfg.to_dict(), + "preprocess_cfg": self.preprocess_cfg.to_dict(), + } + + +def test_publish_canonical_config(tmp_path: Path): + run_path = tmp_path / "run" + + def make_command(shard_id: int, **kwargs) -> FakeBuild: + index_cfg = IndexConfig( + run_path=str(run_path), num_shards=4, shard_id=shard_id, **kwargs + ) + return FakeBuild(index_cfg, PreprocessConfig()) + + publish_canonical_config(make_command(0), run_path) + config_path = run_path / "config.yaml" + assert config_path.exists() + + with config_path.open() as f: + doc = yaml.safe_load(f) + index_dict = doc["steps"][0]["fakebuild"]["index_cfg"] + assert "shard_id" not in index_dict + assert "overwrite" not in index_dict + assert "node_rank" not in index_dict["distributed"] + assert index_dict["num_shards"] == 4 + + # Another shard with per-invocation differences only: accepted + other = make_command(2) + other.index_cfg.overwrite = True + other.index_cfg.distributed.node_rank = 5 + publish_canonical_config(other, run_path) + + # A shard with a different run configuration: rejected + with pytest.raises(ValueError, match="different configuration"): + publish_canonical_config(make_command(1, projection_dim=99), run_path) + + +def test_sharded_config_validation(): + with pytest.raises(ValueError, match="shard_id requires"): + IndexConfig(run_path="x", shard_id=0) + + with pytest.raises(ValueError, match="shard_id must be in"): + IndexConfig(run_path="x", num_shards=2, shard_id=2) + + cfg = IndexConfig(run_path="x", num_shards=2) + cfg.distributed.nnode = 2 + with pytest.raises(ValueError, match="cannot be combined with nnode"): + cfg.__post_init__() + + +# --------------------------------------------------------------------------- +# GPU integration: the SLURM job-array workflow on one machine +# --------------------------------------------------------------------------- + +NUM_EXAMPLES = 40 +NUM_SHARDS = 3 + + +def 
build_command(run_path: Path, dataset_path: Path, **overrides) -> list[str]: + args = { + "model": "gpt2", + "dataset": str(dataset_path), + "prompt_column": "text", + "projection_dim": "8", + "token_batch_size": "512", + "nproc_per_node": "1", + "force_math_sdp": None, # batch-composition-invariant gradients + **overrides, + } + cmd = ["bergson", "build", str(run_path)] + for key, value in args.items(): + cmd.append(f"--{key}") + if value is not None: + cmd.append(str(value)) + return cmd + + +def run_checked(cmd: list[str]) -> str: + result = subprocess.run(cmd, capture_output=True, text=True) + assert ( + result.returncode == 0 + ), f"{' '.join(cmd)} failed:\n{result.stdout}\n{result.stderr}" + return result.stdout + result.stderr + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_sharded_build_lifecycle(tmp_path: Path): + texts = [ + f"The number {i} comes before the number {i + 1}." for i in range(NUM_EXAMPLES) + ] + dataset_path = tmp_path / "data" + Dataset.from_dict({"text": texts}).save_to_disk(str(dataset_path)) + + single_path = tmp_path / "single" + sharded_path = tmp_path / "sharded" + + # Reference: ordinary non-sharded build + run_checked(build_command(single_path, dataset_path)) + + shard_args = {"num_shards": str(NUM_SHARDS)} + + # ── Crash: kill shard 0 mid-build, before it can publish ──────────────── + crash_cmd = build_command(sharded_path, dataset_path, **shard_args, shard_id="0") + proc = subprocess.Popen( + crash_cmd, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + start_new_session=True, # so we can kill the whole process group + ) + part_dir = sharded_path / "shards" / (shard_dir_name(0, NUM_SHARDS) + ".part") + deadline = time.monotonic() + 120 + while not part_dir.exists(): + assert proc.poll() is None, "shard finished before it could be killed" + assert time.monotonic() < deadline, ".part dir never appeared" + time.sleep(0.05) + os.killpg(proc.pid, signal.SIGKILL) + proc.wait() + + 
published, partial, _ = shard_status(sharded_path) + assert 0 not in published, "killed shard must not be published" + assert 0 in partial, "killed shard should leave a .part dir behind" + + # ── Restart: re-running the same command rebuilds the crashed shard ───── + for shard_id in range(NUM_SHARDS): + run_checked( + build_command( + sharded_path, dataset_path, **shard_args, shard_id=str(shard_id) + ) + ) + + published, partial, num_shards = shard_status(sharded_path) + assert num_shards == NUM_SHARDS and len(published) == NUM_SHARDS and not partial + + # ── Idempotency: re-running a published shard is a no-op ──────────────── + output = run_checked( + build_command(sharded_path, dataset_path, **shard_args, shard_id="1") + ) + assert "already published" in output + + # ── Mismatch: a different config may not add shards to this run_path ──── + bad = subprocess.run( + build_command( + sharded_path, dataset_path, **shard_args, shard_id="2", projection_dim="16" + ), + capture_output=True, + text=True, + ) + assert bad.returncode != 0 + assert "different configuration" in bad.stderr + + # ── The sharded index reads back identical to the non-sharded one ─────── + single = load_gradients(single_path, structured=False) + sharded = load_gradients(sharded_path, structured=False) + assert isinstance(sharded, ShardedMemmap) + assert sharded.shape == single.shape + torch.testing.assert_close( + torch.from_numpy(sharded[:]).float(), + torch.from_numpy(single.copy()).float(), + ) + + # Structured view, dataset view, and processor artifacts all resolve + structured = load_gradients(sharded_path) + assert structured.dtype.names == load_gradients(single_path).dtype.names + ds = load_gradient_dataset(sharded_path, structured=False) + assert len(ds) == NUM_EXAMPLES + GradientProcessor.load(sharded_path) + + # Canonical + per-shard configs and provenance records are in place + assert (sharded_path / "config.yaml").exists() + shard_dirs = published_shard_dirs(sharded_path) + ranges = 
[] + for shard_dir in shard_dirs: + assert (shard_dir / "config.yaml").exists() + with (shard_dir / "shard.json").open() as f: + record = yaml.safe_load(f) + ranges.append((record["row_start"], record["row_end"])) + assert ranges[0][0] == 0 and ranges[-1][1] == NUM_EXAMPLES + assert all(end == nxt for (_, end), (nxt, _) in zip(ranges, ranges[1:])) + + # `bergson status` agrees + status_output = run_checked(["bergson", "status", str(sharded_path)]) + assert f"{NUM_SHARDS}/{NUM_SHARDS} shards published" in status_output