diff --git a/charge3net_ft/data.py b/charge3net_ft/data.py index 34ffa8d..544097f 100644 --- a/charge3net_ft/data.py +++ b/charge3net_ft/data.py @@ -10,6 +10,7 @@ opened tables per chunk file so each file is read from disk only once per worker. """ +import collections import json import sys from functools import partial @@ -55,10 +56,20 @@ # --------------------------------------------------------------------------- _SYMBOL_TO_Z = {s: z for z, s in enumerate(ase.data.chemical_symbols)} -# Process-local table cache: keyed by file index, populated on first access. -# Each DataLoader worker process has its own cache, so each chunk file is read -# from disk at most once per worker instead of once per __getitem__ call. -_TABLE_CACHE: dict = {} +# Process-local LRU table cache: keyed by file index, populated on first access. +# Each DataLoader worker has its own cache (workers fork the parent), so each +# chunk file is read from disk at most once per worker per cache cycle. +# +# Bounded LRU because the previous unbounded version OOM-killed jobs 4971293 +# and 4971343 at MaxRSS=35 GB/rank. Per-chunk decompressed pyarrow tables +# weigh ~2 GB (the compressed_charge_density JSON strings inflate 6x from +# disk). With 8 workers x 4 DDP ranks = 32 workers, an unbounded cache grew +# to ~140 GB total in 6 h. +# +# Cap of 5 chunks per worker keeps each worker's cache around 10 GB worst +# case, well under any per-rank memory budget. OrderedDict gives O(1) LRU. 
+_TABLE_CACHE_MAX_CHUNKS = 5 +_TABLE_CACHE: "collections.OrderedDict[int, object]" = collections.OrderedDict() def _parse_grid_json(json_str: str) -> np.ndarray: @@ -131,7 +142,9 @@ def _build_parquet_index(parquet_dir: Path) -> tuple: index.append((fi, ri)) n_valid = len(index) - print(f"LeMatRhoDataset: {n_valid}/{n_total} valid rows indexed from {len(file_paths)} files") + print( + f"LeMatRhoDataset: {n_valid}/{n_total} valid rows indexed from {len(file_paths)} files" + ) return file_paths, index @@ -186,11 +199,21 @@ def _read_row(self, idx: int) -> dict: """ Read a single row from disk via its index entry. - Uses a process-local cache (_TABLE_CACHE) so each chunk file is - loaded from disk only once per worker, not on every __getitem__ call. + Uses a process-local LRU cache (_TABLE_CACHE) so each chunk file is + loaded from disk at most once per worker per cache cycle. Cache is + capped at _TABLE_CACHE_MAX_CHUNKS entries; on a miss past capacity + the least-recently-used chunk is evicted. Re-access of a present + entry promotes it to most-recent so the running shuffled-access + pattern from RandomSampler doesn't constantly thrash. """ fi, ri = self._index[idx] - if fi not in _TABLE_CACHE: + if fi in _TABLE_CACHE: + # Hit: bump to most-recent and return. + _TABLE_CACHE.move_to_end(fi) + else: + # Miss: evict LRU if at capacity, then read. + if len(_TABLE_CACHE) >= _TABLE_CACHE_MAX_CHUNKS: + _TABLE_CACHE.popitem(last=False) _TABLE_CACHE[fi] = pq.read_table(self._file_paths[fi], columns=_COLUMNS) table = _TABLE_CACHE[fi] row = {} @@ -230,6 +253,7 @@ def build_dataloaders( num_workers: int = 4, seed: int = 42, pin_memory: bool = False, + distributed: bool = False, ) -> tuple: """ Build train, validation, and test DataLoaders. @@ -298,10 +322,27 @@ def build_dataloaders( collate_fn = partial(collate_list_of_dicts, pin_memory=pin_memory) + # DDP path: shard the training set across ranks via DistributedSampler. 
+ # Val/test stay non-distributed (each rank evaluates the whole set; only + # rank 0 reports). This wastes V+T compute but keeps eval simple and + # rank-agnostic. The data is tiny (5%+5% of 65k) so it's fine. + train_sampler = None + if distributed: + from torch.utils.data.distributed import DistributedSampler + + train_sampler = DistributedSampler( + train_subset, + shuffle=True, + seed=seed, + drop_last=True, + ) + train_loader = DataLoader( train_subset, batch_size=batch_size, - shuffle=True, + # shuffle and sampler are mutually exclusive in DataLoader. + shuffle=(train_sampler is None), + sampler=train_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, diff --git a/charge3net_ft/train.py b/charge3net_ft/train.py index c30b277..5004f3f 100644 --- a/charge3net_ft/train.py +++ b/charge3net_ft/train.py @@ -47,9 +47,49 @@ from .model import ChargE3NetWrapper # noqa: E402 +# --------------------------------------------------------------------------- +# Distributed training helpers +# --------------------------------------------------------------------------- +def _is_ddp() -> bool: + """True if SLURM/torchrun has set up multi-process training.""" + return int(os.environ.get("WORLD_SIZE", "1")) > 1 + + +def _setup_ddp() -> tuple[int, int, int]: + """Initialize the process group and return (rank, local_rank, world_size). + + No-op (returns 0, 0, 1) if we're not in a distributed environment. + + The submit script is expected to export the standard torch env vars from + SLURM: + WORLD_SIZE = $SLURM_NTASKS + RANK = $SLURM_PROCID + LOCAL_RANK = $SLURM_LOCALID + MASTER_ADDR = $(scontrol show hostname $SLURM_NODELIST | head -1) + MASTER_PORT = some unused port (e.g. 29500) + """ + if not _is_ddp(): + return 0, 0, 1 + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + # nccl works on AMD ROCm because PyTorch routes it through RCCL. 
+ torch.distributed.init_process_group(backend="nccl") + torch.cuda.set_device(local_rank) + return rank, local_rank, world_size + + +def _is_main(rank: int) -> bool: + """True on rank 0; used to gate prints, wandb, and checkpoint saves.""" + return rank == 0 + + def _probe_mask(targets: torch.Tensor, num_probes: torch.Tensor) -> torch.Tensor: """Boolean mask [B, max_probes], True for real probe points (not padding).""" - return torch.arange(targets.shape[1], device=targets.device)[None] < num_probes[:, None] + return ( + torch.arange(targets.shape[1], device=targets.device)[None] + < num_probes[:, None] + ) def compute_nmape( @@ -108,8 +148,16 @@ def compute_nrmse( return (rmse / (mean_abs + 1e-10) * 100.0).mean() -def train_one_epoch(model, train_loader, optimizer, scheduler, device, global_step, - log_every=50, use_wandb=False): +def train_one_epoch( + model, + train_loader, + optimizer, + scheduler, + device, + global_step, + log_every=50, + use_wandb=False, +): """Run one training epoch, return (average loss, updated global_step).""" model.train() total_loss = 0.0 @@ -135,7 +183,7 @@ def train_one_epoch(model, train_loader, optimizer, scheduler, device, global_st if (i + 1) % log_every == 0: lr = optimizer.param_groups[0]["lr"] - print(f" step {i+1}: loss={loss.item():.6f} lr={lr:.2e}") + print(f" step {i + 1}: loss={loss.item():.6f} lr={lr:.2e}") if use_wandb: wandb.log({"train/loss_step": loss.item(), "lr": lr}, step=global_step) @@ -177,12 +225,22 @@ def validate(model, loader, device): } +def _unwrap(model): + """Return the underlying ChargE3NetWrapper regardless of DDP wrapping. + + DistributedDataParallel wraps the user model in a ``.module`` attribute; + state_dict() and load_state_dict() should always target the inner model + so checkpoints are interchangeable between single-GPU and DDP runs. 
+ """ + return model.module if hasattr(model, "module") else model + + def save_checkpoint(model, optimizer, scheduler, epoch, best_nmape, global_step, path): - """Save training checkpoint.""" + """Save training checkpoint (rank 0 should be the only caller in DDP).""" torch.save( { "epoch": epoch, - "model": model.model.state_dict(), + "model": _unwrap(model).model.state_dict(), "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict(), "best_nmape": best_nmape, @@ -195,7 +253,7 @@ def save_checkpoint(model, optimizer, scheduler, epoch, best_nmape, global_step, def load_checkpoint(path, model, optimizer, scheduler, device): """Load training checkpoint, return (start_epoch, best_nmape, global_step).""" ckpt = torch.load(path, map_location=device, weights_only=False) - model.model.load_state_dict(ckpt["model"]) + _unwrap(model).model.load_state_dict(ckpt["model"]) optimizer.load_state_dict(ckpt["optimizer"]) scheduler.load_state_dict(ckpt["scheduler"]) start_epoch = ckpt["epoch"] + 1 @@ -222,19 +280,35 @@ def main(): "Defaults to $LEMATRHO_DATA_DIR env var." 
), ) - parser.add_argument("--ckpt-path", type=str, default=None, help="Pre-trained checkpoint (.pt)") - parser.add_argument("--save-dir", type=str, default="./checkpoints", help="Save directory") + parser.add_argument( + "--ckpt-path", type=str, default=None, help="Pre-trained checkpoint (.pt)" + ) + parser.add_argument( + "--save-dir", type=str, default="./checkpoints", help="Save directory" + ) parser.add_argument("--cutoff", type=float, default=4.0, help="Neighbor cutoff (A)") - parser.add_argument("--train-probes", type=int, default=200, help="Probes per sample (train)") - parser.add_argument("--val-probes", type=int, default=1000, help="Probes per sample (val/test)") + parser.add_argument( + "--train-probes", type=int, default=200, help="Probes per sample (train)" + ) + parser.add_argument( + "--val-probes", type=int, default=1000, help="Probes per sample (val/test)" + ) parser.add_argument("--batch-size", type=int, default=4, help="Batch size") parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate") parser.add_argument("--epochs", type=int, default=50, help="Number of epochs") - parser.add_argument("--val-frac", type=float, default=0.05, - help="Validation fraction. Do not change after first run.") - parser.add_argument("--test-frac", type=float, default=0.05, - help="Test fraction (held out, evaluated once at end). " - "Do not change after first run.") + parser.add_argument( + "--val-frac", + type=float, + default=0.05, + help="Validation fraction. Do not change after first run.", + ) + parser.add_argument( + "--test-frac", + type=float, + default=0.05, + help="Test fraction (held out, evaluated once at end). 
" + "Do not change after first run.", + ) parser.add_argument("--num-workers", type=int, default=4, help="DataLoader workers") parser.add_argument("--seed", type=int, default=42, help="Random seed") parser.add_argument("--log-every", type=int, default=50, help="Log every N steps") @@ -252,14 +326,22 @@ def main(): default=None, help="Force device (cpu, cuda, mps). Auto-detect if not set.", ) - parser.add_argument("--resume-from", type=str, default=None, - help="Path to training checkpoint (latest.pt) to resume from") + parser.add_argument( + "--resume-from", + type=str, + default=None, + help="Path to training checkpoint (latest.pt) to resume from", + ) parser.add_argument("--wandb-project", type=str, default="lemat-rho-charge3net") parser.add_argument("--wandb-entity", type=str, default="dtts") parser.add_argument("--no-wandb", action="store_true", help="Disable W&B logging") - parser.add_argument("--wandb-mode", type=str, default="online", - choices=["online", "offline", "disabled"], - help="W&B mode (use 'offline' on air-gapped clusters)") + parser.add_argument( + "--wandb-mode", + type=str, + default="online", + choices=["online", "offline", "disabled"], + help="W&B mode (use 'offline' on air-gapped clusters)", + ) args = parser.parse_args() if args.parquet_dir is None: @@ -274,28 +356,48 @@ def main(): if torch.cuda.is_available(): torch.cuda.manual_seed_all(args.seed) + # DDP setup (no-op when WORLD_SIZE=1). Must happen before device + # selection because each rank pins itself to its own GPU via local_rank. 
+ rank, local_rank, world_size = _setup_ddp() + is_main = _is_main(rank) + # Device if args.device: device = torch.device(args.device) + elif _is_ddp(): + device = torch.device(f"cuda:{local_rank}") elif torch.cuda.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") - print(f"Using device: {device}") - - # W&B - use_wandb = not args.no_wandb and not args.smoke_test + if is_main: + print(f"Using device: {device}; world_size={world_size}") + + # W&B (rank 0 only). Soft-fail: if init times out (e.g. compute node + # can't reach api.wandb.ai through the cluster proxy), degrade to + # disabled mode and keep training. Used to be fatal — caused the + # 1h47m job 4969727 timeout-then-crash on Adastra. + use_wandb = (not args.no_wandb and not args.smoke_test) and is_main if use_wandb: - wandb.init( - project=args.wandb_project, - entity=args.wandb_entity, - config=vars(args), - settings=wandb.Settings(init_timeout=300), - mode=args.wandb_mode, - ) + try: + wandb.init( + project=args.wandb_project, + entity=args.wandb_entity, + config=vars(args), + settings=wandb.Settings(init_timeout=300), + mode=args.wandb_mode, + ) + except Exception as e: # noqa: BLE001 — really do want broad here + print( + f"WARNING: wandb.init failed ({type(e).__name__}: {e}); " + "continuing with wandb disabled. Training output is still " + "saved to checkpoints + stdout." 
+ ) + use_wandb = False # Data - print("Building dataloaders...") + if is_main: + print("Building dataloaders...") train_loader, val_loader, test_loader = build_dataloaders( parquet_dir=args.parquet_dir, cutoff=args.cutoff, @@ -306,26 +408,40 @@ def main(): test_frac=args.test_frac, num_workers=args.num_workers, seed=args.seed, + distributed=_is_ddp(), ) - print( - f"Train: {len(train_loader.dataset)} samples, " - f"Val: {len(val_loader.dataset)} samples, " - f"Test: {len(test_loader.dataset)} samples" - ) + if is_main: + print( + f"Train: {len(train_loader.dataset)} samples, " + f"Val: {len(val_loader.dataset)} samples, " + f"Test: {len(test_loader.dataset)} samples" + ) - # Model - print("Initializing ChargE3Net...") + # Model. Loaded on every rank (each gets its own copy of the weights); + # DDP will sync gradients across ranks at backward. + if is_main: + print("Initializing ChargE3Net...") model = ChargE3NetWrapper(ckpt_path=args.ckpt_path, cutoff=args.cutoff) model = model.to(device) - n_params = sum(p.numel() for p in model.parameters()) - print(f"Model parameters: {n_params:,}") + if _is_ddp(): + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[local_rank], output_device=local_rank + ) + n_params = sum( + p.numel() for p in (model.module if _is_ddp() else model).parameters() + ) + if is_main: + print(f"Model parameters: {n_params:,}") # Smoke test: just run one forward pass if args.smoke_test: print("\n--- Smoke test ---") model.eval() batch = next(iter(train_loader)) - batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + batch = { + k: v.to(device) if isinstance(v, torch.Tensor) else v + for k, v in batch.items() + } print(f"Batch keys: {list(batch.keys())}") for k, v in batch.items(): if isinstance(v, torch.Tensor): @@ -376,13 +492,15 @@ def main(): f"NMAPE={nmape.item():.2f}% RMSE={rmse.item():.4f} NRMSE={nrmse.item():.2f}%" ) if use_wandb: - wandb.log({ - "overfit/L1": loss.item(), - 
"overfit/NMAPE": nmape.item(), - "overfit/RMSE": rmse.item(), - "overfit/NRMSE": nrmse.item(), - "epoch": epoch, - }) + wandb.log( + { + "overfit/L1": loss.item(), + "overfit/NMAPE": nmape.item(), + "overfit/RMSE": rmse.item(), + "overfit/NRMSE": nrmse.item(), + "epoch": epoch, + } + ) print("\nOverfit test complete.") if use_wandb: @@ -405,53 +523,86 @@ def main(): if args.resume_from: start_epoch, best_nmape, global_step = load_checkpoint( - args.resume_from, model, optimizer, scheduler, device, + args.resume_from, + model, + optimizer, + scheduler, + device, ) - print(f"\nStarting training from epoch {start_epoch + 1} to {args.epochs}...") + if is_main: + print(f"\nStarting training from epoch {start_epoch + 1} to {args.epochs}...") for epoch in range(start_epoch, args.epochs): + # DDP requires set_epoch on the sampler each epoch for proper shuffling. + if _is_ddp() and hasattr(train_loader.sampler, "set_epoch"): + train_loader.sampler.set_epoch(epoch) t0 = time.time() train_loss, global_step = train_one_epoch( - model, train_loader, optimizer, scheduler, device, global_step, - log_every=args.log_every, use_wandb=use_wandb, + model, + train_loader, + optimizer, + scheduler, + device, + global_step, + log_every=args.log_every, + use_wandb=use_wandb, ) val = validate(model, val_loader, device) elapsed = time.time() - t0 - print( - f"Epoch {epoch+1}/{args.epochs} " - f"train_L1={train_loss:.6f} " - f"val_L1={val['L1']:.6f} " - f"val_NMAPE={val['NMAPE']:.2f}% " - f"val_RMSE={val['RMSE']:.4f} " - f"val_NRMSE={val['NRMSE']:.2f}% " - f"time={elapsed:.0f}s" - ) + if is_main: + print( + f"Epoch {epoch + 1}/{args.epochs} " + f"train_L1={train_loss:.6f} " + f"val_L1={val['L1']:.6f} " + f"val_NMAPE={val['NMAPE']:.2f}% " + f"val_RMSE={val['RMSE']:.4f} " + f"val_NRMSE={val['NRMSE']:.2f}% " + f"time={elapsed:.0f}s" + ) if use_wandb: - wandb.log({ - "train/L1": train_loss, - "val/L1": val["L1"], - "val/NMAPE": val["NMAPE"], - "val/RMSE": val["RMSE"], - "val/NRMSE": 
val["NRMSE"], - "epoch": epoch + 1, - }, step=global_step) - - # Save best checkpoint (selected on val NMAPE) - if val["NMAPE"] < best_nmape: + wandb.log( + { + "train/L1": train_loss, + "val/L1": val["L1"], + "val/NMAPE": val["NMAPE"], + "val/RMSE": val["RMSE"], + "val/NRMSE": val["NRMSE"], + "epoch": epoch + 1, + }, + step=global_step, + ) + + # Save best checkpoint (selected on val NMAPE). Only rank 0 writes. + if is_main and val["NMAPE"] < best_nmape: best_nmape = val["NMAPE"] save_checkpoint( - model, optimizer, scheduler, epoch, best_nmape, global_step, + model, + optimizer, + scheduler, + epoch, + best_nmape, + global_step, save_dir / "best.pt", ) print(f" -> New best val NMAPE: {best_nmape:.2f}%") - # Save latest checkpoint every epoch (for SLURM resumption) - save_checkpoint( - model, optimizer, scheduler, epoch, best_nmape, global_step, - save_dir / "latest.pt", - ) + # Save latest checkpoint every epoch (for SLURM resumption). + if is_main: + save_checkpoint( + model, + optimizer, + scheduler, + epoch, + best_nmape, + global_step, + save_dir / "latest.pt", + ) + + # Keep ranks in lockstep so a slow saver doesn't get lapped. + if _is_ddp(): + torch.distributed.barrier() # ----------------------------------------------------------------------- # Test set evaluation — run once at the end using the best checkpoint. @@ -471,17 +622,22 @@ def main(): f"RMSE={test['RMSE']:.4f} NRMSE={test['NRMSE']:.2f}%" ) if use_wandb: - wandb.log({ - "test/L1": test["L1"], - "test/NMAPE": test["NMAPE"], - "test/RMSE": test["RMSE"], - "test/NRMSE": test["NRMSE"], - }) - - print(f"\nTraining complete. Best val NMAPE: {best_nmape:.2f}%") - print(f"Checkpoints saved to {save_dir}") + wandb.log( + { + "test/L1": test["L1"], + "test/NMAPE": test["NMAPE"], + "test/RMSE": test["RMSE"], + "test/NRMSE": test["NRMSE"], + } + ) + + if is_main: + print(f"\nTraining complete. 
Best val NMAPE: {best_nmape:.2f}%") + print(f"Checkpoints saved to {save_dir}") if use_wandb: wandb.finish() + if _is_ddp(): + torch.distributed.destroy_process_group() if __name__ == "__main__": diff --git a/deepdft_ft/__init__.py b/deepdft_ft/__init__.py new file mode 100644 index 0000000..da6fdd4 --- /dev/null +++ b/deepdft_ft/__init__.py @@ -0,0 +1,5 @@ +"""DeepDFT (peterbjorgensen/DeepDFT) fine-tuning glue for LeMat-Rho. + +Mirrors ``charge3net_ft/`` in structure: the data loader reuses ``charge3net_ft``'s +parquet helpers and adapts the per-sample shape to DeepDFT's dict contract. +""" diff --git a/deepdft_ft/data.py b/deepdft_ft/data.py new file mode 100644 index 0000000..03acbb2 --- /dev/null +++ b/deepdft_ft/data.py @@ -0,0 +1,139 @@ +"""LeMat-Rho → DeepDFT data adapter. + +DeepDFT's ``runner.py`` expects a ``torch.utils.data.Dataset`` that yields +per-sample dicts of the form:: + + { + "density": np.ndarray (Nx, Ny, Nz), + "atoms": ase.Atoms, + "origin": np.ndarray (3,), + "grid_position": np.ndarray (Nx, Ny, Nz, 3), + "metadata": {"filename": str, ...}, + } + +That dict is fed into DeepDFT's ``CollateFuncRandomSample`` which samples +random probe points, builds the atom/probe graph via asap3, and pads the +batch. The only thing we provide is a path from a directory of LeMat-Rho +parquet chunks to that dict shape. + +The parquet schema, the index building, and the row → (atoms, density, origin) +conversion live in ``charge3net_ft.data`` and are reused verbatim. Keeping a +single source of truth for the input pipeline means a future Bader/extra-column +addition only needs one regression test. 
+""" + +from __future__ import annotations + +from pathlib import Path +from typing import Optional + +import numpy as np +import pyarrow.parquet as pq +from torch.utils.data import Dataset + +from charge3net_ft.data import ( + _COLUMNS, + _build_parquet_index, + _row_to_atoms_and_density, +) + + +# Per-worker cache, separate from charge3net_ft's so the two pipelines don't +# step on each other when running side by side in the same process. +_DEEPDFT_TABLE_CACHE: dict = {} + + +def _calculate_grid_pos(density: np.ndarray, origin: np.ndarray, cell) -> np.ndarray: + """Cartesian probe positions for an (Nx, Ny, Nz) density grid. + + Same formula DeepDFT uses internally (see DeepDFT/dataset.py:_calculate_grid_pos). + Kept here so we don't need DeepDFT importable at test time. + + Parameters + ---------- + density : np.ndarray of shape (Nx, Ny, Nz) + Used only for its shape. + origin : np.ndarray of shape (3,) + Cell-frame origin in Cartesian coordinates. + cell : ASE Cell or 3x3 array + Lattice vectors as rows. + + Returns + ------- + grid_pos : np.ndarray of shape (Nx, Ny, Nz, 3) + Cartesian coordinates of every grid point. + """ + ngridpts = np.array(density.shape) + grid_pos = np.meshgrid( + np.arange(ngridpts[0]) / density.shape[0], + np.arange(ngridpts[1]) / density.shape[1], + np.arange(ngridpts[2]) / density.shape[2], + indexing="ij", + ) + grid_pos = np.stack(grid_pos, 3) + grid_pos = np.dot(grid_pos, np.asarray(cell)) + grid_pos = grid_pos + origin + return grid_pos + + +class LeMatRhoDeepDFTDataset(Dataset): + """Iterate LeMat-Rho parquet chunks as DeepDFT-shaped sample dicts. + + Parameters + ---------- + parquet_dir : str or Path + Directory containing ``chunk_*.parquet`` files. + _shared_index : tuple, optional + Internal: pre-built (file_paths, index) tuple shared between + train/val splits to avoid scanning files twice. 
+ """ + + def __init__( + self, + parquet_dir: str | Path | None = None, + _shared_index: Optional[tuple] = None, + ): + if _shared_index is not None: + self._file_paths, self._index = _shared_index + else: + if parquet_dir is None: + raise ValueError("Must provide parquet_dir or _shared_index") + self._file_paths, self._index = _build_parquet_index(Path(parquet_dir)) + + def __len__(self) -> int: + return len(self._index) + + def _read_row(self, idx: int) -> dict: + """Lazy per-worker chunk caching, mirrors charge3net_ft.data. + + Cache is keyed by the absolute parquet path (not the integer ``fi``) + so multiple ``LeMatRhoDeepDFTDataset`` instances pointing at different + directories don't collide on ``fi=0``. + """ + fi, ri = self._index[idx] + key = str(self._file_paths[fi].resolve()) + if key not in _DEEPDFT_TABLE_CACHE: + _DEEPDFT_TABLE_CACHE[key] = pq.read_table( + self._file_paths[fi], columns=_COLUMNS + ) + table = _DEEPDFT_TABLE_CACHE[key] + return {col: table.column(col)[ri].as_py() for col in _COLUMNS} + + def __getitem__(self, idx: int) -> dict: + row = self._read_row(idx) + atoms, density, origin = _row_to_atoms_and_density(row) + grid_pos = _calculate_grid_pos(density, origin, atoms.get_cell()) + + # Index-derived filename so DeepDFT logs stay distinguishable across + # samples. Format mirrors the tar member names DeepDFT normally sees. + fi, ri = self._index[idx] + chunk_stem = Path(self._file_paths[fi]).stem # e.g. "chunk_000017" + filename = f"{chunk_stem}_row{ri:06d}.parquet" + + return { + "density": density, + "atoms": atoms, + "origin": origin, + "grid_position": grid_pos, + "metadata": {"filename": filename}, + } diff --git a/deepdft_ft/runner.py b/deepdft_ft/runner.py new file mode 100644 index 0000000..6140edf --- /dev/null +++ b/deepdft_ft/runner.py @@ -0,0 +1,557 @@ +"""DeepDFT training runner — vendored from peterbjorgensen/DeepDFT@main. 
+ +Vendored rather than monkey-patched because the DDP integration touches +many points throughout `main()` (dataset construction, model wrap, +sampler, checkpoint save, logging gates). Keeping the patched copy here +makes the delta auditable and the code testable. + +Diff vs upstream: +- Adds DDP setup via `_setup_ddp`/`_is_main` helpers (mirrors the pattern + used in `charge3net_ft/train.py`). DDP activates iff `WORLD_SIZE>1`. +- Detects parquet directories and uses `LeMatRhoDeepDFTDataset` instead + of `dataset.DensityData`. Other arg formats are passed through to + upstream unchanged so the runner still works on the original tar/dir + datasets. +- `RandomSampler` swapped for `DistributedSampler` when DDP active. +- Model wrapped in `DistributedDataParallel`; checkpoint save/load unwraps + via `_unwrap`. +- Logging + checkpoint writes gated on rank 0. +""" + +from __future__ import annotations + +import os +import sys +import json +import argparse +import math +import logging +import itertools +import timeit +from pathlib import Path + +import numpy as np +import torch +import torch.utils.data +from torch.utils.data.distributed import DistributedSampler + +torch.set_num_threads(1) # Try to avoid thread overload on cluster + +# --------------------------------------------------------------------------- +# Make the DeepDFT sibling repo importable. 
Expected layout (mirrors +# how charge3net is set up): +# <repo>/ <-- LeMat-Rho +# <repo>/../DeepDFT/ <-- peterbjorgensen/DeepDFT clone +# --------------------------------------------------------------------------- +_DEEPDFT_ROOT = Path(__file__).resolve().parent.parent.parent / "DeepDFT" +if not _DEEPDFT_ROOT.exists(): + raise RuntimeError( + f"DeepDFT repo not found at {_DEEPDFT_ROOT}.\n" + "Clone it with: git clone https://github.com/peterbjorgensen/DeepDFT " + f"{_DEEPDFT_ROOT}" + ) +if str(_DEEPDFT_ROOT) not in sys.path: + sys.path.insert(0, str(_DEEPDFT_ROOT)) + +# --------------------------------------------------------------------------- +# Stub `asap3` if it isn't available. Building asap3 from source requires +# Python.h which isn't installed on Adastra (and getting it would need +# admin). Upstream DeepDFT supports an ASE-based fallback via +# `AseNeigborListWrapper`; we expose the same interface from `asap3.FullNeighborList` +# so the upstream `import asap3 ; asap3.FullNeighborList(...)` calls work. +# --------------------------------------------------------------------------- +try: + import asap3 # noqa: F401 +except ImportError: + import types + + import ase.neighborlist + import numpy as np + + _asap3_stub = types.ModuleType("asap3") + + class _AseFullNeighborList: + """Drop-in `asap3.FullNeighborList` replacement using ASE primitives. + + Behaviourally equivalent for DeepDFT's use case: ``get_neighbors(i, cutoff)`` + returns ``(indices, rel_positions, dist2)`` arrays. Much slower than real + asap3 but works without C++ headers. 
+ """ + + def __init__(self, cutoff, atoms): + self._cutoff = cutoff + self._positions = atoms.get_positions() + self._cell = np.asarray(atoms.get_cell()) + nl = ase.neighborlist.NewPrimitiveNeighborList( + cutoff, skin=0.0, self_interaction=False, bothways=True + ) + nl.build(atoms.get_pbc(), atoms.get_cell(), atoms.get_positions()) + self._nl = nl + + def get_neighbors(self, i, cutoff): + assert cutoff == self._cutoff, ( + "cutoff must match the one used at FullNeighborList init" + ) + indices, offsets = self._nl.get_neighbors(i) + rel_positions = ( + self._positions[indices] + offsets @ self._cell - self._positions[i] + ) + dist2 = (rel_positions**2).sum(axis=1) + return indices, rel_positions, dist2 + + _asap3_stub.FullNeighborList = _AseFullNeighborList + sys.modules["asap3"] = _asap3_stub + +import densitymodel # noqa: E402 (upstream module) +import dataset # noqa: E402 (upstream module) + +from deepdft_ft.data import LeMatRhoDeepDFTDataset # noqa: E402 + + +# --------------------------------------------------------------------------- +# Distributed-training helpers (same pattern as charge3net_ft/train.py). +# --------------------------------------------------------------------------- +def _is_ddp() -> bool: + return int(os.environ.get("WORLD_SIZE", "1")) > 1 + + +def _setup_ddp() -> tuple[int, int, int]: + """Returns (rank, local_rank, world_size). No-op when WORLD_SIZE=1.""" + if not _is_ddp(): + return 0, 0, 1 + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + # nccl routes through RCCL on AMD ROCm builds. 
+ torch.distributed.init_process_group(backend="nccl") + torch.cuda.set_device(local_rank) + return rank, local_rank, world_size + + +def _is_main(rank: int) -> bool: + return rank == 0 + + +def _unwrap(model: torch.nn.Module) -> torch.nn.Module: + """Strip DistributedDataParallel for state_dict access.""" + return model.module if hasattr(model, "module") else model + + +def _is_parquet_dir(path: str | Path) -> bool: + """LeMat-Rho parquet dirs contain ``chunk_*.parquet``; tar/cube paths don't.""" + p = Path(path) + return p.is_dir() and any(p.glob("chunk_*.parquet")) + + +def get_arguments(arg_list=None): + parser = argparse.ArgumentParser( + description="Train graph convolution network", fromfile_prefix_chars="+" + ) + parser.add_argument( + "--load_model", + type=str, + default=None, + help="Load model parameters from previous run", + ) + parser.add_argument( + "--cutoff", + type=float, + default=5.0, + help="Atomic interaction cutoff distance [Å]", + ) + parser.add_argument( + "--split_file", + type=str, + default=None, + help="Train/test/validation split file json", + ) + parser.add_argument( + "--num_interactions", + type=int, + default=3, + help="Number of interaction layers used", + ) + parser.add_argument( + "--node_size", type=int, default=64, help="Size of hidden node states" + ) + parser.add_argument( + "--output_dir", + type=str, + default="runs/model_output", + help="Path to output directory", + ) + parser.add_argument( + "--dataset", + type=str, + default="data/qm9.db", + help="Path to ASE database", + ) + parser.add_argument( + "--max_steps", + type=int, + default=int(1e6), + help="Maximum number of optimisation steps", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Set which device to use for training e.g. 
'cuda' or 'cpu'", + ) + + parser.add_argument( + "--use_painn_model", + action="store_true", + help="Enable equivariant message passing model (PaiNN)", + ) + + parser.add_argument( + "--ignore_pbc", + action="store_true", + help="If flag is given, disable periodic boundary conditions (force to False) in atoms data", + ) + + parser.add_argument( + "--force_pbc", + action="store_true", + help="If flag is given, force periodic boundary conditions to True in atoms data", + ) + + return parser.parse_args(arg_list) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +def split_data(dataset, args): + # Load or generate splits + if args.split_file: + with open(args.split_file, "r") as fp: + splits = json.load(fp) + else: + datalen = len(dataset) + num_validation = int(math.ceil(datalen * 0.05)) + indices = np.random.permutation(len(dataset)) + splits = { + "train": indices[num_validation:].tolist(), + "validation": indices[:num_validation].tolist(), + } + + # Save split file + with open(os.path.join(args.output_dir, "datasplits.json"), "w") as f: + json.dump(splits, f) + + # Split the dataset + datasplits = {} + for key, indices in splits.items(): + datasplits[key] = torch.utils.data.Subset(dataset, indices) + return datasplits + + +def eval_model(model, dataloader, device): + with torch.no_grad(): + running_ae = torch.tensor(0.0, device=device) + running_se = torch.tensor(0.0, device=device) + running_count = torch.tensor(0.0, device=device) + for batch in dataloader: + device_batch = { + k: v.to(device=device, 
def main():
    """Train a density model, optionally under DistributedDataParallel.

    Flow: parse arguments, initialise DDP, set up logging and persist the
    arguments, build the dataset and its train/validation split, build the
    loaders and the model, then run an open-ended training loop. Every
    ``log_interval`` steps (and on the last step) the model is validated and
    rank 0 saves ``best_model.pth`` if the validation MAE improved. The
    process exits via ``sys.exit(0)`` once ``args.max_steps`` is reached.

    Raises
    ------
    ValueError
        If both ``--ignore_pbc`` and ``--force_pbc`` are given.
    """
    args = get_arguments()

    # DDP setup (no-op when WORLD_SIZE=1). Must precede device + dataset
    # construction; each rank pins itself to its own GCD via local_rank.
    rank, local_rank, world_size = _setup_ddp()
    is_main = _is_main(rank)

    # Override device for DDP runs.
    if _is_ddp():
        args.device = f"cuda:{local_rank}"

    # Setup logging
    os.makedirs(args.output_dir, exist_ok=True)
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s [%(levelname)-5.5s] %(message)s",
        handlers=[
            logging.FileHandler(
                os.path.join(args.output_dir, "printlog.txt"), mode="w"
            ),
            logging.StreamHandler(),
        ],
    )

    # Save command line args
    with open(os.path.join(args.output_dir, "commandline_args.txt"), "w") as f:
        f.write("\n".join(sys.argv[1:]))
    # Save parsed command line arguments
    with open(os.path.join(args.output_dir, "arguments.json"), "w") as f:
        json.dump(vars(args), f)

    # Setup dataset and loader. If args.dataset points at a directory of
    # LeMat-Rho chunk_*.parquet files, use our adapter; otherwise fall
    # through to upstream's tar/cube/dir loader unchanged.
    if _is_parquet_dir(args.dataset):
        if is_main:
            logging.info("loading LeMat-Rho parquet dir %s", args.dataset)
        densitydata = LeMatRhoDeepDFTDataset(parquet_dir=args.dataset)
    else:
        if args.dataset.endswith(".txt"):
            # Text file contains list of datafiles
            with open(args.dataset, "r") as datasetfiles:
                filelist = [
                    os.path.join(os.path.dirname(args.dataset), line.strip("\n"))
                    for line in datasetfiles
                ]
        else:
            filelist = [args.dataset]
        if is_main:
            logging.info("loading data %s", args.dataset)
        densitydata = torch.utils.data.ConcatDataset(
            [dataset.DensityData(path) for path in filelist]
        )

    # Split data into train and validation sets
    datasplits = split_data(densitydata, args)
    datasplits["train"] = dataset.RotatingPoolData(datasplits["train"], 20)

    if args.ignore_pbc and args.force_pbc:
        raise ValueError(
            "ignore_pbc and force_pbc are mutually exclusive and can't both be set at the same time"
        )
    elif args.ignore_pbc:
        set_pbc = False
    elif args.force_pbc:
        set_pbc = True
    else:
        set_pbc = None

    # Setup loaders. With DDP, the train sampler shards data across ranks
    # so each rank sees a disjoint subset per epoch. Val stays
    # non-distributed and only rank 0 actually uses it.
    if _is_ddp():
        train_sampler = DistributedSampler(
            datasplits["train"], shuffle=True, drop_last=True
        )
    else:
        train_sampler = torch.utils.data.RandomSampler(datasplits["train"])
    train_loader = torch.utils.data.DataLoader(
        datasplits["train"],
        2,
        num_workers=4,
        sampler=train_sampler,
        collate_fn=dataset.CollateFuncRandomSample(
            args.cutoff, 1000, pin_memory=False, set_pbc_to=set_pbc
        ),
    )
    val_loader = torch.utils.data.DataLoader(
        datasplits["validation"],
        2,
        collate_fn=dataset.CollateFuncRandomSample(
            args.cutoff, 5000, pin_memory=False, set_pbc_to=set_pbc
        ),
        num_workers=0,
    )
    # Upstream materialised the full val_loader into a list at startup for
    # speed ("Preloading validation batch"). Their NMC/QM9/ethyleneCarbonate
    # val sets are ~100 materials so that's cheap. Ours is ~3.3 k materials
    # x 5 000 probes/material -> ~150 GB if eagerly preloaded, which OOM-killed
    # job 4971720. Leave val_loader as a streaming DataLoader instead; the
    # data-loading overhead per val pass is negligible compared to DDP
    # gradient sync (when DDP is enabled). Hyperparameters are unchanged.

    # Initialise model
    device = torch.device(args.device)
    if args.use_painn_model:
        net = densitymodel.PainnDensityModel(
            args.num_interactions, args.node_size, args.cutoff
        )
    else:
        net = densitymodel.DensityModel(
            args.num_interactions, args.node_size, args.cutoff
        )
    if is_main:
        logging.debug("model has %d parameters", count_parameters(net))
    net = net.to(device)
    if _is_ddp():
        net = torch.nn.parallel.DistributedDataParallel(
            net, device_ids=[local_rank], output_device=local_rank
        )

    # Setup optimizer
    optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)
    criterion = torch.nn.MSELoss()
    scheduler_fn = lambda step: 0.96 ** (step / 100000)  # noqa: E731 (vendored)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_fn)

    log_interval = 5000
    running_loss = torch.tensor(0.0, device=device)
    running_loss_count = torch.tensor(0, device=device)
    best_val_mae = np.inf
    step = 0
    # Restore checkpoint
    if args.load_model:
        state_dict = torch.load(args.load_model, map_location=device)
        _unwrap(net).load_state_dict(state_dict["model"])
        step = state_dict["step"]
        best_val_mae = state_dict["best_val_mae"]
        optimizer.load_state_dict(state_dict["optimizer"])
        scheduler.load_state_dict(state_dict["scheduler"])

    if is_main:
        logging.info("start training")

    data_timer = AverageMeter("data_timer")
    transfer_timer = AverageMeter("transfer_timer")
    train_timer = AverageMeter("train_timer")
    eval_timer = AverageMeter("eval_time")

    endtime = timeit.default_timer()
    for epoch in itertools.count():
        # DistributedSampler derives its shuffle from (seed, epoch). Without
        # set_epoch every pass over the loader replays the same ordering, so
        # it must be called before each new iterator is created.
        if isinstance(train_sampler, DistributedSampler):
            train_sampler.set_epoch(epoch)

        for batch_host in train_loader:
            data_timer.update(timeit.default_timer() - endtime)
            tstart = timeit.default_timer()
            # Transfer to 'device'
            batch = {
                k: v.to(device=device, non_blocking=True)
                for (k, v) in batch_host.items()
            }
            transfer_timer.update(timeit.default_timer() - tstart)

            tstart = timeit.default_timer()
            # Reset gradient
            optimizer.zero_grad()

            # Forward, backward and optimize
            outputs = net(batch)
            loss = criterion(outputs, batch["probe_target"])
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                # MSELoss is a mean over all elements; multiply by the
                # element count to accumulate a sum of squared errors.
                running_loss += (
                    loss
                    * batch["probe_target"].shape[0]
                    * batch["probe_target"].shape[1]
                )
                running_loss_count += torch.sum(batch["num_probes"])

            train_timer.update(timeit.default_timer() - tstart)

            # Validate and save model
            if (step % log_interval == 0) or ((step + 1) == args.max_steps):
                tstart = timeit.default_timer()
                with torch.no_grad():
                    train_loss = (running_loss / running_loss_count).item()
                    running_loss = running_loss_count = 0

                val_mae, val_rmse = eval_model(net, val_loader, device)

                if is_main:
                    logging.info(
                        "step=%d, val_mae=%g, val_rmse=%g, sqrt(train_loss)=%g",
                        step,
                        val_mae,
                        val_rmse,
                        math.sqrt(train_loss),
                    )

                # Save checkpoint (rank 0 only). _unwrap so the state_dict
                # is interchangeable between single-GPU and DDP runs.
                if is_main and val_mae < best_val_mae:
                    best_val_mae = val_mae
                    torch.save(
                        {
                            "model": _unwrap(net).state_dict(),
                            "optimizer": optimizer.state_dict(),
                            "scheduler": scheduler.state_dict(),
                            "step": step,
                            "best_val_mae": best_val_mae,
                        },
                        os.path.join(args.output_dir, "best_model.pth"),
                    )

                eval_timer.update(timeit.default_timer() - tstart)
                logging.debug(
                    "%s %s %s %s"
                    % (data_timer, transfer_timer, train_timer, eval_timer)
                )
            step += 1

            scheduler.step()

            if step >= args.max_steps:
                if is_main:
                    logging.info("Max steps reached, exiting")
                if _is_ddp():
                    torch.distributed.destroy_process_group()
                sys.exit(0)

            endtime = timeit.default_timer()
@dataclass(frozen=True)
class BasisSpec:
    """Configuration of the atom-centered Gaussian × Y_lm basis.

    Parameters
    ----------
    max_l :
        Maximum angular momentum, inclusive. Real spherical harmonics
        Y_lm with ``l = 0..max_l`` and ``m = -l..l`` are used.
    n_radial :
        Number of radial channels. Must match ``len(sigma)``.
    sigma :
        Gaussian widths (Angstrom), one per radial channel. Any sequence
        is accepted and stored as a tuple, so a spec built from a list
        (e.g. decoded JSON) stays hashable and compares equal to the same
        spec built from a tuple.
    cutoff :
        Radial cutoff (Angstrom) beyond which basis functions are zero.
        Should match the cutoff used by the neighbor-list / graph
        constructor of the downstream ML model.

    Raises
    ------
    ValueError
        If any field is out of range (see ``__post_init__``).
    """

    max_l: int = 4
    n_radial: int = 4
    sigma: tuple[float, ...] = field(default=(0.5, 1.0, 2.0, 4.0))
    cutoff: float = 4.0

    def __post_init__(self) -> None:
        # All validation goes here so a malformed spec raises at construction
        # time, not deep inside a tensor op three PRs from now.
        if not isinstance(self.sigma, tuple):
            # The dataclass is frozen, so bypass its __setattr__ guard to
            # normalise the stored value once at construction time.
            object.__setattr__(self, "sigma", tuple(self.sigma))
        if self.max_l < 0:
            raise ValueError(
                f"max_l must be >= 0; got {self.max_l}. "
                "Use max_l=0 for an s-only basis."
            )
        if self.n_radial < 1:
            raise ValueError(
                f"n_radial must be >= 1; got {self.n_radial}. "
                "A basis with zero radial channels has no expressive power."
            )
        if len(self.sigma) != self.n_radial:
            raise ValueError(
                f"n_radial ({self.n_radial}) must equal len(sigma) "
                f"({len(self.sigma)}); each radial channel needs its own width."
            )
        # ``not s > 0`` (rather than ``s <= 0``) so NaN is rejected as well.
        if any(not s > 0 for s in self.sigma):
            raise ValueError(
                f"sigma values must be positive (Gaussian widths); got {self.sigma}."
            )
        if not self.cutoff > 0:
            raise ValueError(
                f"cutoff must be > 0; got {self.cutoff}. "
                "A nonpositive cutoff makes the basis identically zero."
            )

    @property
    def n_angular_components(self) -> int:
        """Number of real-Ylm components for l = 0..max_l: sum_l (2l + 1) = (max_l + 1)^2."""
        return (self.max_l + 1) ** 2

    @property
    def n_coeffs_per_atom(self) -> int:
        """Coefficients per atom: n_radial channels × angular components."""
        return self.n_radial * self.n_angular_components

    def total_coeffs_shape(self, n_atoms: int) -> tuple[int, int]:
        """Shape of the per-structure coefficients tensor."""
        return (n_atoms, self.n_coeffs_per_atom)
def read_chgcar(path: str | Path) -> tuple[np.ndarray, ase.Atoms]:
    """Load a CHGCAR file as ``(density, atoms)``.

    Parameters
    ----------
    path :
        Location of the CHGCAR file.

    Returns
    -------
    density : np.ndarray
        float64 array: the file's ``"total"`` data divided by the cell
        volume, i.e. the inverse of the scaling applied by ``write_chgcar``.
    atoms : ase.Atoms
        The structure stored in the file, converted from pymatgen.
    """
    # pymatgen is imported lazily so importing this module does not need it.
    from pymatgen.io.ase import AseAtomsAdaptor
    from pymatgen.io.vasp.outputs import Chgcar

    parsed = Chgcar.from_file(str(path))
    structure = parsed.structure
    volume = float(structure.lattice.volume)
    # The file holds density * volume; divide it back out.
    stored_total = np.asarray(parsed.data["total"], dtype=np.float64)
    density = stored_total / volume
    return density, AseAtomsAdaptor.get_atoms(structure)
class SALTEDModel:
    """Predict atom-centered basis coefficients for a structure.

    Parameters
    ----------
    basis_spec :
        The basis the coefficients are defined against. Must match the
        spec the trained checkpoint was trained on.
    ckpt_path :
        Path to a rholearn checkpoint. If ``None`` (default), the model
        runs in stub mode: deterministic, position-dependent fake
        coefficients useful for testing the surrounding pipeline.
    """

    def __init__(
        self, basis_spec: BasisSpec, ckpt_path: str | Path | None = None
    ) -> None:
        self.basis_spec = basis_spec
        self.ckpt_path = Path(ckpt_path) if ckpt_path is not None else None
        if self.ckpt_path is not None:
            # Only insist on the sibling rholearn repo when it is needed.
            _ensure_rholearn_importable()
        # Placeholder; the heavy load is deferred to inference call sites.
        self._rholearn_model = None

    def __call__(self, atoms: ase.Atoms) -> np.ndarray:
        """Predict coefficients for ``atoms``.

        Returns
        -------
        np.ndarray of shape ``(n_atoms, basis_spec.n_coeffs_per_atom)``,
        float64, deterministic, finite.

        Raises
        ------
        NotImplementedError
            If the model was built with a checkpoint path (the real
            rholearn forward pass is not implemented yet).
        """
        if self.ckpt_path is None:
            return self._stub_predict(atoms)
        return self._rholearn_predict(atoms)

    def reconstruct_density(
        self, atoms: ase.Atoms, grid_shape: tuple[int, int, int]
    ) -> np.ndarray:
        """Predict coefficients, then reconstruct the real-space density.

        Equivalent to::

            c = model(atoms)
            reconstruct_grid_from_basis(c, atoms, grid_shape, basis_spec)

        Provided as a convenience for the VASP comparison pipeline,
        which always wants the grid form.
        """
        coeffs = self(atoms)
        return reconstruct_grid_from_basis(coeffs, atoms, grid_shape, self.basis_spec)

    # ------------------------------------------------------------------
    # Implementations
    # ------------------------------------------------------------------
    def _stub_predict(self, atoms: ase.Atoms) -> np.ndarray:
        """Deterministic position-dependent coefficients without rholearn.

        Recipe: seed a NumPy random generator with a SHA-256 digest of the
        atomic positions, atomic numbers, and ``str(basis_spec)``. Same
        inputs -> same coefficients; changing any atom's position, any
        atomic number, or the basis spec changes the coefficients.

        The numbers are small (order 1e-3) so reconstructed densities
        don't blow up the metric ranges in downstream tests.
        """
        import hashlib

        n_atoms = len(atoms)
        n_coeffs = self.basis_spec.n_coeffs_per_atom
        positions = atoms.get_positions()
        numbers = atoms.get_atomic_numbers()

        seed_bytes = (
            positions.astype(np.float64).tobytes()
            + numbers.astype(np.int64).tobytes()
            + str(self.basis_spec).encode("utf-8")
        )
        # Hash the WHOLE payload. Truncating the raw bytes to 16 would keep
        # only the first atom's x and y coordinates, so every other input
        # would be ignored.
        digest = hashlib.sha256(seed_bytes).digest()
        seed_int = int.from_bytes(digest[:16], byteorder="little", signed=False)
        rng = np.random.default_rng(seed_int)
        return rng.standard_normal((n_atoms, n_coeffs), dtype=np.float64) * 1e-3

    def _rholearn_predict(self, atoms: ase.Atoms) -> np.ndarray:
        """Real rholearn forward pass. Lands in PR gamma-prime."""
        raise NotImplementedError(
            "Real rholearn forward pass is deferred to PR gamma-prime. "
            "Construct SALTEDModel with ckpt_path=None for stub mode."
        )
def project_directory(
    input_dir: str | Path,
    output_dir: str | Path,
    basis_spec: BasisSpec,
) -> None:
    """Run :func:`project_chunk` over every ``chunk_*.parquet`` in ``input_dir``.

    Idempotent: a chunk whose output already exists (and is non-empty) is
    left untouched so partially-completed runs can resume cheaply.

    Each chunk is written to ``<name>.tmp`` first and renamed into place
    only after :func:`project_chunk` returns. An interrupted run therefore
    never leaves a truncated ``chunk_*.parquet`` that a later resume would
    mistake for a finished one.

    Parameters
    ----------
    input_dir :
        Directory holding the source ``chunk_*.parquet`` files.
    output_dir :
        Directory for the projected chunks; created if missing.
    basis_spec :
        Basis passed through to :func:`project_chunk`.

    Raises
    ------
    FileNotFoundError
        If ``input_dir`` contains no ``chunk_*.parquet`` file.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    inputs = sorted(input_dir.glob("chunk_*.parquet"))
    if not inputs:
        raise FileNotFoundError(f"no chunk_*.parquet files under {input_dir}")

    for in_path in inputs:
        out_path = output_dir / in_path.name
        if out_path.exists() and out_path.stat().st_size > 0:
            continue
        # The ".tmp" suffix keeps the scratch file out of the chunk_*.parquet
        # glob and out of the "already done" check above.
        tmp_path = out_path.with_name(out_path.name + ".tmp")
        try:
            project_chunk(in_path, tmp_path, basis_spec)
            tmp_path.replace(out_path)
        finally:
            # No-op after a successful rename; cleans up after a failure.
            tmp_path.unlink(missing_ok=True)
See the memo +``plan_salted_graph2mat_basis_choice_may_20_pm.md`` (Phase A) for why we +have to build this layer ourselves. + +Math +---- + +The basis expansion is +:: + + rho(r) ~= sum_i sum_n sum_{l,m} c_{i,n,l,m} phi_n(|r - r_i|) Y_lm(rhat) + +where ``i`` indexes atoms, ``n`` is the radial channel, ``(l, m)`` are the +real spherical harmonic indices, ``phi_n`` is a Gaussian of width +``sigma_n``, and ``Y_lm`` is a real spherical harmonic. + +We use the **orthonormal-approximation projection**: each coefficient is +the inner product of the density with the corresponding basis function, +normalized by the basis function's L2 norm. This is exact iff the basis +is orthonormal; for our Gaussians it's a v1 stand-in for a proper +overlap-matrix least-squares solve, which lands in a follow-up PR. + +Reconstruction is the literal sum on the right-hand side. + +Both maps are linear in their input (linearity is a pinned test). + +PBC +--- + +Minimum-image convention via the cell inverse. Each grid point sees each +atom at its closest periodic image. Adequate for cells where 2*cutoff +fits inside the smallest cell vector; for very small cells we'd want +full Ewald-style supercell expansion. Not in scope for PR beta. +""" + +from __future__ import annotations + +import ase +import numpy as np + +from salted_ft.basis import BasisSpec + + +# --------------------------------------------------------------------------- +# Grid-position generation (matches charge3net's `calculate_grid_pos` plus +# `deepdft_ft.data._calculate_grid_pos` so the three pipelines agree on +# where grid point (i, j, k) lives in space). +# --------------------------------------------------------------------------- +def _grid_positions(grid_shape: tuple[int, int, int], cell: np.ndarray) -> np.ndarray: + """Cartesian coordinates of every grid point. 
+ + Parameters + ---------- + grid_shape : (Nx, Ny, Nz) + cell : (3, 3) lattice matrix with rows as vectors + + Returns + ------- + (Nx * Ny * Nz, 3) Cartesian coordinates, ``[i, j, k]`` order matching + ``np.ravel`` of an array of that shape. + """ + # Silence harmless RuntimeWarnings from intermediate matmul reductions. + with np.errstate(divide="ignore", invalid="ignore", over="ignore"): + Nx, Ny, Nz = grid_shape + fx = np.arange(Nx, dtype=np.float64) / Nx + fy = np.arange(Ny, dtype=np.float64) / Ny + fz = np.arange(Nz, dtype=np.float64) / Nz + fX, fY, fZ = np.meshgrid(fx, fy, fz, indexing="ij") + frac = np.stack([fX.ravel(), fY.ravel(), fZ.ravel()], axis=-1) + return frac @ cell # (n_grid, 3) + + +# --------------------------------------------------------------------------- +# Real spherical harmonics. We hand-roll real Y_lm for lmax up to 4 +# (covers our default lmax=4) because the alternatives are either heavy +# (e3nn/torch in a pure-numpy module) or complex-valued (scipy.special). +# --------------------------------------------------------------------------- +_SQRT_PI = np.sqrt(np.pi) + + +def _real_sph_harm(rhat: np.ndarray, lmax: int) -> np.ndarray: + """Real spherical harmonics on unit vectors, l = 0..lmax inclusive. + + Returns an array of shape ``(..., (lmax + 1) ** 2)`` where the last + axis is ordered ``[Y_00, Y_1{-1}, Y_10, Y_11, Y_2{-2}, ..., Y_l l]`` + (the standard SOAP / SALTED ordering). + + Parameters + ---------- + rhat : (..., 3) array + Unit vectors. Zero-length inputs are handled by the caller. + lmax : + Maximum angular momentum, inclusive. + """ + if lmax > 4: + raise NotImplementedError( + f"_real_sph_harm only implements l = 0..4 (lmax={lmax} requested). " + "Extend or swap in e3nn.o3.spherical_harmonics for higher lmax." 
+ ) + x, y, z = rhat[..., 0], rhat[..., 1], rhat[..., 2] + n_lm = (lmax + 1) ** 2 + out = np.empty(rhat.shape[:-1] + (n_lm,), dtype=np.float64) + + # l = 0 + out[..., 0] = 0.5 / _SQRT_PI + + if lmax >= 1: + # l = 1: Y_1{-1} ~ y, Y_10 ~ z, Y_11 ~ x + c1 = 0.5 * np.sqrt(3.0 / np.pi) + out[..., 1] = c1 * y + out[..., 2] = c1 * z + out[..., 3] = c1 * x + + if lmax >= 2: + # l = 2 + c2_xy = 0.5 * np.sqrt(15.0 / np.pi) # Y_2{-2}, Y_21, Y_2{-1} prefactors + c2_z2 = 0.25 * np.sqrt(5.0 / np.pi) + c2_x2y2 = 0.25 * np.sqrt(15.0 / np.pi) + out[..., 4] = c2_xy * x * y # Y_2{-2} + out[..., 5] = c2_xy * y * z # Y_2{-1} + out[..., 6] = c2_z2 * (3 * z * z - 1) # Y_20 + out[..., 7] = c2_xy * x * z # Y_21 + out[..., 8] = c2_x2y2 * (x * x - y * y) # Y_22 + + if lmax >= 3: + # l = 3 + c3a = 0.25 * np.sqrt(35.0 / (2.0 * np.pi)) + c3b = 0.5 * np.sqrt(105.0 / np.pi) + c3c = 0.25 * np.sqrt(21.0 / (2.0 * np.pi)) + c3d = 0.25 * np.sqrt(7.0 / np.pi) + out[..., 9] = c3a * y * (3 * x * x - y * y) # Y_3{-3} + out[..., 10] = c3b * x * y * z # Y_3{-2} + out[..., 11] = c3c * y * (5 * z * z - 1) # Y_3{-1} + out[..., 12] = c3d * z * (5 * z * z - 3) # Y_30 + out[..., 13] = c3c * x * (5 * z * z - 1) # Y_31 + out[..., 14] = 0.25 * np.sqrt(105.0 / np.pi) * z * (x * x - y * y) # Y_32 + out[..., 15] = c3a * x * (x * x - 3 * y * y) # Y_33 + + if lmax >= 4: + # l = 4 + c4a = 0.75 * np.sqrt(35.0 / np.pi) + c4b = 0.75 * np.sqrt(35.0 / (2.0 * np.pi)) + c4c = 0.75 * np.sqrt(5.0 / np.pi) + c4d = 0.75 * np.sqrt(5.0 / (2.0 * np.pi)) + c4e = 3.0 / 16.0 * np.sqrt(1.0 / np.pi) + out[..., 16] = c4a * x * y * (x * x - y * y) # Y_4{-4} + out[..., 17] = c4b * y * z * (3 * x * x - y * y) # Y_4{-3} + out[..., 18] = c4c * x * y * (7 * z * z - 1) # Y_4{-2} + out[..., 19] = c4d * y * z * (7 * z * z - 3) # Y_4{-1} + out[..., 20] = c4e * (35 * z**4 - 30 * z * z + 3) # Y_40 + out[..., 21] = c4d * x * z * (7 * z * z - 3) # Y_41 + out[..., 22] = ( + 0.375 * np.sqrt(5.0 / np.pi) * (x * x - y * y) * (7 * z * z - 1) + ) # Y_42 + 
out[..., 23] = c4b * x * z * (x * x - 3 * y * y) # Y_43 + out[..., 24] = ( + 0.1875 + * np.sqrt(35.0 / np.pi) + * (x * x * (x * x - 3 * y * y) - y * y * (3 * x * x - y * y)) + ) # Y_44 + + return out + + +# --------------------------------------------------------------------------- +# Per-atom basis-function evaluation at grid points +# --------------------------------------------------------------------------- +def _eval_basis_at_grid( + atom_position: np.ndarray, + grid_positions: np.ndarray, + cell: np.ndarray, + basis_spec: BasisSpec, +) -> np.ndarray: + """Evaluate every basis function centered on ``atom_position`` at every + grid point, using minimum-image convention. + + Returns ``(n_grid, n_coeffs_per_atom)`` array of basis-function values. + """ + # The masked points outside the cutoff intentionally produce some + # 0/0 and large-magnitude intermediates whose results we throw away + # via ``mask``. Silence the harmless RuntimeWarnings to keep test + # output readable. + with np.errstate(divide="ignore", invalid="ignore", over="ignore"): + inv_cell = np.linalg.inv(cell) + rel = grid_positions - atom_position[None, :] # (n_grid, 3) + # Minimum-image: wrap fractional displacement to [-0.5, 0.5] + frac_disp = rel @ inv_cell + frac_disp = frac_disp - np.round(frac_disp) + rel = frac_disp @ cell # (n_grid, 3) in Cartesian, wrapped + + r = np.linalg.norm(rel, axis=-1) # (n_grid,) + mask = r < basis_spec.cutoff + r_safe = np.where(r > 0, r, 1.0) + rhat = rel / r_safe[:, None] + + # Real spherical harmonics, (n_grid, (lmax+1)^2) + ylm = _real_sph_harm(rhat, basis_spec.max_l) + + n_grid = grid_positions.shape[0] + n_lm = ylm.shape[-1] + n_radial = basis_spec.n_radial + out = np.empty((n_grid, n_radial * n_lm), dtype=np.float64) + + for n_idx, sigma in enumerate(basis_spec.sigma): + radial = np.exp(-0.5 * (r / sigma) ** 2) * mask # (n_grid,) + # block layout: [n=0 lm=0..nlm-1, n=1 lm=0..nlm-1, ...] 
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def project_chgcar_to_basis(
    density_grid: np.ndarray,
    atoms: ase.Atoms,
    basis_spec: BasisSpec,
) -> np.ndarray:
    """Project a real-space density grid onto the atom-centered basis.

    Builds the full per-structure design matrix (every basis function of
    every atom evaluated at every grid point) and solves ONE linear
    least-squares problem for all atoms' coefficients simultaneously.

    Parameters
    ----------
    density_grid : (Nx, Ny, Nz) array
        Real-space density on the grid (CHGCAR-like).
    atoms : ase.Atoms
        Periodic structure.  Provides positions and cell.
    basis_spec : BasisSpec
        Basis to project onto.

    Returns
    -------
    (n_atoms, n_coeffs_per_atom) float64 array of coefficients.

    Raises
    ------
    ValueError
        If ``density_grid`` is not 3-dimensional.
    """
    rho = np.asarray(density_grid, dtype=np.float64)
    if rho.ndim != 3:
        # Fail early with a clear message instead of a tuple-unpacking
        # error deep inside _grid_positions.
        raise ValueError(
            f"density_grid must be 3-dimensional (Nx, Ny, Nz), got shape {rho.shape}"
        )

    grid_shape = rho.shape
    cell = np.asarray(atoms.get_cell())
    grid_pos = _grid_positions(grid_shape, cell)  # (n_grid, 3)
    rho_flat = rho.ravel()  # (n_grid,)

    n_atoms = len(atoms)
    n_per_atom = basis_spec.n_coeffs_per_atom
    positions = atoms.get_positions()

    # Build the full per-structure design matrix B_global of shape
    # (n_grid, n_atoms * n_coeffs_per_atom) and solve a single least-
    # squares system for ALL atoms' coefficients simultaneously.  This
    # is the correct way to handle the strong overlap between our
    # Gaussian basis functions (sigma ~ cutoff means heavy overlap).
    #
    # The previous orthonormal-approx (numer/denom per channel)
    # produced ~1000% NMAPE on real LeMat-Rho rows because it
    # overcounted contributions from overlapping basis functions
    # (recorded in D1 sanity check, 2026-05-21).
    B_global = np.empty((grid_pos.shape[0], n_atoms * n_per_atom), dtype=np.float64)
    with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
        for i, pos in enumerate(positions):
            # Column block i holds atom i's basis functions, in the same
            # per-atom order that _eval_basis_at_grid returns.
            B_global[:, i * n_per_atom : (i + 1) * n_per_atom] = _eval_basis_at_grid(
                pos, grid_pos, cell, basis_spec
            )

    # The system is overdetermined whenever n_grid > n_atoms * n_per_atom.
    # lstsq returns a minimum-residual fit; it is unique only if B_global
    # has full column rank, otherwise lstsq picks the minimum-norm one.
    c_flat, *_ = np.linalg.lstsq(B_global, rho_flat, rcond=None)
    return c_flat.reshape(n_atoms, n_per_atom)
+ with np.errstate(divide="ignore", invalid="ignore", over="ignore"): + for i, pos in enumerate(positions): + B = _eval_basis_at_grid(pos, grid_pos, cell, basis_spec) + rho_flat += B @ coefficients[i] + + return rho_flat.reshape(grid_shape) diff --git a/salted_ft/rholearn_adapter.py b/salted_ft/rholearn_adapter.py new file mode 100644 index 0000000..94ff5e5 --- /dev/null +++ b/salted_ft/rholearn_adapter.py @@ -0,0 +1,205 @@ +"""SALTED -> rholearn data-format adapter. + +rholearn's training loop consumes basis-coefficient vectors in +metatensor ``TensorMap`` format, with a specific flat-vector layout +that differs from our internal one: + +================== =================================================== +Our layout atom (outer) -> n (radial) -> lambda -> mu + (this is what ``project_chgcar_to_basis`` returns) +rholearn layout atom (outer) -> lambda -> n (radial) -> mu + (see ``rholearn/utils/convert.py::_get_flat_index``) +================== =================================================== + +This module provides three things: + +1. ``build_lmax_nmax(basis_spec, species)`` -- our uniform BasisSpec + expanded into rholearn's per-species ``lmax`` / ``nmax`` dicts. +2. ``dense_to_rholearn_flat`` / ``rholearn_flat_to_dense`` -- the + permutation between the two layouts, ndarray <-> ndarray. Roundtrip + is exact and pinned by tests. +3. ``dense_to_tensormap`` -- the full path that calls rholearn's + ``convert.coeff_vector_ndarray_to_tensormap`` to produce a + ``metatensor.TensorMap``. Lazy-imports rholearn / metatensor. + +The permutation is the load-bearing piece. Get it wrong and rholearn +trains on misordered data; the value at index k of the flat vector +no longer corresponds to the (lambda, n, mu) channel rholearn thinks +it does. +""" + +from __future__ import annotations + +import sys +from pathlib import Path +from typing import Iterable + +import numpy as np + +from salted_ft.basis import BasisSpec + + +# Path setup for lazy rholearn import. 
# Same pattern as charge3net_ft/model.py and deepdft_ft/runner.py.
_RHOLEARN_ROOT = Path(__file__).resolve().parent.parent.parent / "rholearn"


def _ensure_rholearn_importable() -> None:
    """Make the rholearn checkout at ``_RHOLEARN_ROOT`` importable.

    Raises ``RuntimeError`` (with clone instructions) when the directory
    does not exist; otherwise puts it at the front of ``sys.path`` unless
    it is already listed there.
    """
    if not _RHOLEARN_ROOT.exists():
        raise RuntimeError(
            f"rholearn repo not found at {_RHOLEARN_ROOT}.\n"
            "Clone it with: git clone https://github.com/lab-cosmo/rholearn "
            f"{_RHOLEARN_ROOT}"
        )
    root = str(_RHOLEARN_ROOT)
    if root in sys.path:
        return
    sys.path.insert(0, root)


# ---------------------------------------------------------------------------
# Basis spec dict builder
# ---------------------------------------------------------------------------
def build_lmax_nmax(
    basis_spec: BasisSpec, species: Iterable[str]
) -> tuple[dict[str, int], dict[tuple[str, int], int]]:
    """Spell out the uniform BasisSpec as rholearn's per-species dicts.

    The same ``max_l`` and ``n_radial`` are assigned to every species
    because our BasisSpec does not vary by element.

    Returns
    -------
    lmax : ``{species: max_l}`` for every species in ``species``
    nmax : ``{(species, lambda): n_radial}`` for every (species, lambda)
        with ``lambda`` running over ``0..max_l`` inclusive
    """
    # Materialize first: ``species`` may be a one-shot iterator and is
    # walked twice below.
    names = list(species)
    lmax = dict.fromkeys(names, basis_spec.max_l)
    nmax = {}
    for name in names:
        for lam in range(basis_spec.max_l + 1):
            nmax[(name, lam)] = basis_spec.n_radial
    return lmax, nmax
def dense_to_rholearn_flat(
    coeffs: np.ndarray,
    basis_spec: BasisSpec,
    symbols: Iterable[str],
) -> np.ndarray:
    """Convert our dense ``(n_atoms, n_coeffs_per_atom)`` coefficients to
    rholearn's flat per-structure vector.

    Each atom's row is reordered from our ``n -> lambda -> mu`` layout to
    rholearn's ``lambda -> n -> mu`` layout, then the rows are concatenated
    atom by atom.

    Output length: ``n_atoms * n_coeffs_per_atom``.  ``symbols`` is
    accepted for API symmetry with the inverse and species-aware
    extensions; today the permutation is species-independent because
    our BasisSpec is uniform across species, so ``symbols`` is not read.

    Raises
    ------
    ValueError
        If ``coeffs`` is not 2-D with ``n_coeffs_per_atom`` columns.
    """
    coeffs = np.asarray(coeffs)
    n_coeffs = basis_spec.n_coeffs_per_atom
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``
    # and a silently mis-shaped array would be permuted into garbage.  This
    # also matches the ValueError raised by ``rholearn_flat_to_dense``.
    if coeffs.ndim != 2 or coeffs.shape[1] != n_coeffs:
        raise ValueError(
            f"coeffs shape {coeffs.shape} mismatches "
            f"(n_atoms, n_coeffs_per_atom={n_coeffs})"
        )
    perm = _our_to_rholearn_permutation(basis_spec)
    # ``coeffs[:, perm]`` reorders each atom's row from our layout to rholearn's
    return coeffs[:, perm].ravel().astype(np.float64)
# ---------------------------------------------------------------------------
# Full TensorMap path
# ---------------------------------------------------------------------------
def dense_to_tensormap(
    coeffs: np.ndarray,
    basis_spec: BasisSpec,
    symbols: Iterable[str],
    positions: np.ndarray,
    cell: np.ndarray,
    structure_idx: int = 0,
):
    """Convert dense coefficients to a ``metatensor.TensorMap`` using
    rholearn's converter.

    Lazy-imports rholearn + chemfiles so this module is importable
    without those deps installed (the permutation helpers are pure numpy).

    Parameters
    ----------
    coeffs : (n_atoms, n_coeffs_per_atom) array in our dense layout
    basis_spec : BasisSpec
    symbols : one chemical symbol per atom; any iterable, including a
        one-shot generator
    positions : one position per atom, paired with ``symbols`` in order
    cell : lattice passed to ``chemfiles.UnitCell``
    structure_idx : forwarded unchanged to rholearn's converter

    Returns
    -------
    Whatever ``rholearn.utils.convert.coeff_vector_ndarray_to_tensormap``
    returns for this structure.

    Raises
    ------
    ValueError
        If the number of symbols differs from the number of coefficient
        rows or (via ``zip(..., strict=True)``) from the number of positions.
    RuntimeError
        If the rholearn checkout cannot be found.
    """
    # ``symbols`` is only promised to be Iterable but is needed three times
    # below (flat conversion, species set, frame construction).  Materialize
    # it once so a generator is not exhausted before the frame is built.
    symbols = list(symbols)
    if len(coeffs) != len(symbols):
        raise ValueError(
            f"coeffs has {len(coeffs)} rows but {len(symbols)} symbols were given"
        )

    _ensure_rholearn_importable()
    import chemfiles  # used directly below to build the Frame
    from rholearn.utils import convert  # type: ignore[import-not-found]

    flat = dense_to_rholearn_flat(coeffs, basis_spec, symbols)
    # dict.fromkeys de-duplicates like set() but keeps first-seen order, so
    # the lmax / nmax dicts are built in a reproducible order across runs.
    lmax, nmax = build_lmax_nmax(basis_spec, dict.fromkeys(symbols))

    # Build a chemfiles Frame from the structure (rholearn's converter
    # expects one).
    frame = chemfiles.Frame()
    frame.cell = chemfiles.UnitCell(np.asarray(cell, dtype=np.float64))
    for sym, pos in zip(symbols, np.asarray(positions), strict=True):
        atom = chemfiles.Atom(sym)
        frame.add_atom(atom, list(pos))

    return convert.coeff_vector_ndarray_to_tensormap(
        frame,
        coeff_vector=flat,
        lmax=lmax,
        nmax=nmax,
        structure_idx=structure_idx,
        tests=0,
    )
+#SBATCH --job-name=charge3net_ft +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=4 +#SBATCH --account=c1816212 +#SBATCH --constraint=MI250 +#SBATCH --gpus-per-node=4 +#SBATCH --cpus-per-task=16 +# No --mem here on purpose: SLURM allocates memory proportional to our CPU +# share (64 of 128 logical CPUs = ~128 GB out of the 256 GB node). The +# earlier --mem=125000M was being read as "asking for half the node memory" +# and contributed to SLURM auto-bumping us to EXCLUSIVE mode. Letting SLURM +# pick lets the other half of the node stay schedulable for other jobs. +#SBATCH --time=06:00:00 +#SBATCH --output=%x_%j.out +#SBATCH --error=%x_%j.err + +set -eo pipefail + +# --- Paths --- +# Submit dir must be on a scratch with inode headroom (cad16353 currently); the +# account (--account=c1816212 above) handles billing independently. See ADASTRA.md. +SETUP="${LEMATRHO_ADASTRA_SETUP:-/lus/scratch/CT10/cad16353/msiron/charge3net_setup}" +WORK_DIR="$SETUP/LeMat-Rho" +DATA_DIR="$SETUP/charge3net_data" +MP_CKPT="$SETUP/charge3net/models/charge3net_mp.pt" + +# --- Training mode ----------------------------------------------------------- +TRAINING_MODE="${LEMATRHO_TRAINING_MODE:-pretrained}" +case "$TRAINING_MODE" in + pretrained) + CKPT_PATH="$MP_CKPT" + CKPT_DIR="$SETUP/charge3net_checkpoints" + export WANDB_NAME="pretrained_mp" + ;; + from_scratch) + CKPT_PATH="" # no --ckpt-path -> ChargE3NetWrapper inits from random + CKPT_DIR="$SETUP/charge3net_checkpoints_fromscratch" + export WANDB_NAME="from_scratch" + ;; + *) + echo "ERROR: LEMATRHO_TRAINING_MODE must be 'pretrained' or 'from_scratch'," \ + "got '$TRAINING_MODE'" >&2 + exit 2 + ;; +esac + +mkdir -p "$CKPT_DIR" 2>/dev/null || true + +# --- Build train command ----------------------------------------------------- +# Constructed early so LEMATRHO_DRY_RUN can short-circuit before sourcing venv. 
+TRAIN_ARGS=( + --parquet-dir "$DATA_DIR" + --save-dir "$CKPT_DIR" + --epochs 50 + --batch-size 16 + --lr 5e-4 + --train-probes 200 + --val-probes 1000 + # num-workers=2 (down from 8): with 4 DDP ranks each forking workers, the + # previous setting created 32 worker processes total and the per-worker + # _TABLE_CACHE in data.py OOM-killed jobs 4971293/4971343 at ~140 GB + # cumulative RSS. The LRU eviction we landed in data.py would help on + # its own, but lowering worker count further drops cache pressure with + # zero loss in throughput at this dataset/grid size. + --num-workers 2 + --wandb-project lemat-rho-charge3net + --wandb-entity dtts + --wandb-mode offline +) +if [ -n "$CKPT_PATH" ]; then + TRAIN_ARGS+=(--ckpt-path "$CKPT_PATH") +fi +if [ -f "$CKPT_DIR/latest.pt" ]; then + TRAIN_ARGS+=(--resume-from "$CKPT_DIR/latest.pt") +fi + +if [ "${LEMATRHO_DRY_RUN:-0}" = "1" ]; then + echo "WANDB_NAME=$WANDB_NAME" + echo "TRAINING_MODE=$TRAINING_MODE" + echo "CKPT_DIR=$CKPT_DIR" + printf 'python -m charge3net_ft.train' + for arg in "${TRAIN_ARGS[@]}"; do + printf ' %s' "$arg" + done + printf '\n' + exit 0 +fi + +# --- Environment ------------------------------------------------------------- +# Proxy is required for any outbound HTTP (pip, HF, W&B). Already in ~/.bashrc +# on Adastra but we re-export here so the job script is self contained. +export HTTP_PROXY=http://proxy-l-adastra.cines.fr:3128 +export HTTPS_PROXY=$HTTP_PROXY +export http_proxy=$HTTP_PROXY +export https_proxy=$HTTP_PROXY + +source "$SETUP/venv311/bin/activate" + +export PYTHONPATH="$WORK_DIR:$SETUP/charge3net:$PYTHONPATH" +export PYTHONUNBUFFERED=1 + +# Load W&B key from .env if present. +if [ -f "$WORK_DIR/.env" ]; then + set -a + source "$WORK_DIR/.env" + set +a +fi + +# --- NCCL / DDP reliability tweaks --- +# Job 4977567 (2026-05-21) ran 2h41m, then died from NCCL TCPStore +# "Broken pipe / should dump flag" on the DDP heartbeat. Memory was +# fine (14 GB/task with the LRU cache fix). 
The crash is on the +# inter-rank communication channel, not the model. These three env +# vars expand the timeout budget so a transient slow rank doesn't +# tear down the whole job. +# NCCL_TIMEOUT per-collective timeout (seconds) +# NCCL_ASYNC_ERROR_HANDLING=1 clean shutdown on rank failure +# (no cascading hangs) +# TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC how long a rank can stall +# before HeartbeatMonitor tears +# down the process group +export NCCL_TIMEOUT=3600 +export NCCL_ASYNC_ERROR_HANDLING=1 +export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=1800 +export TORCH_NCCL_TRACE_BUFFER_SIZE=1000 # capture more debug info on next crash + +# --- Distributed-training env vars (read by train.py's _setup_ddp) --- +# SLURM sets SLURM_NTASKS, SLURM_PROCID, SLURM_LOCALID for us via srun. +# torch.distributed wants WORLD_SIZE / RANK / LOCAL_RANK plus MASTER_ADDR +# / MASTER_PORT. We export them once here, srun propagates to each task. +export WORLD_SIZE=$SLURM_NTASKS +export MASTER_ADDR=$(scontrol show hostname "$SLURM_NODELIST" | head -n 1) +export MASTER_PORT=29500 +# RANK / LOCAL_RANK are per-task — set in the wrapper srun command below. + +echo "Node: $(hostname)" +echo "Account: ${SLURM_JOB_ACCOUNT:-unknown}" +echo "Job dir: $WORK_DIR" +echo "Training mode: $TRAINING_MODE (wandb name: $WANDB_NAME)" +echo "Checkpoint dir: $CKPT_DIR" +echo "WORLD_SIZE=$WORLD_SIZE MASTER_ADDR=$MASTER_ADDR MASTER_PORT=$MASTER_PORT" +rocm-smi || true + +python3 -c " +import torch +print(f'torch: {torch.__version__}') +print(f'CUDA/ROCm available: {torch.cuda.is_available()}') +print(f'device count: {torch.cuda.device_count()}') +" + +cd "$WORK_DIR" + +# --- Train ------------------------------------------------------------------ +# srun launches 4 tasks (--ntasks-per-node=4 from #SBATCH). Each task sees +# SLURM_PROCID = global rank, SLURM_LOCALID = local rank within node. +# The TRAIN_ARGS array is exported as a quoted string so the srun-spawned +# bash can reconstruct it. 
+TRAIN_ARGS_QUOTED="" +for arg in "${TRAIN_ARGS[@]}"; do + TRAIN_ARGS_QUOTED+=" $(printf '%q' "$arg")" +done +export TRAIN_ARGS_QUOTED + +srun --kill-on-bad-exit=1 bash -c ' + export RANK=$SLURM_PROCID + export LOCAL_RANK=$SLURM_LOCALID + # Each task sees ALL 4 GCDs the job was allocated; torch.cuda.set_device(local_rank) + # inside _setup_ddp picks the right one. Restricting visibility per-task here + # would make every task target the same "GCD 0" within its own visibility set. + echo "task RANK=$RANK LOCAL_RANK=$LOCAL_RANK on $(hostname) (will use cuda:$LOCAL_RANK)" + eval "python3 -m charge3net_ft.train $TRAIN_ARGS_QUOTED" +' + +echo "Done. Exit code: $?" diff --git a/submit_deepdft_adastra.sh b/submit_deepdft_adastra.sh new file mode 100644 index 0000000..5fd169c --- /dev/null +++ b/submit_deepdft_adastra.sh @@ -0,0 +1,136 @@ +#!/bin/bash +# DeepDFT training on Adastra (CINES, AMD MI250X), single-GPU paper-faithful. +# +# Faithful to peterbjorgensen/DeepDFT paper settings: +# - 1 GCD (paper used 1x RTX 3090; we use 1x MI250X) +# - batch=2 materials, train=1000 probes/material, val=5000 probes/material +# (hardcoded in deepdft_ft/runner.py, same as upstream) +# - cutoff=4 A, num_interactions=3, node_size=128, PaiNN model +# - max_steps=10,000,000 +# +# Single-GPU keeps the gradient-step semantics identical to the paper. +# DDP code paths in runner.py only fire when WORLD_SIZE>1 -- we leave them +# out here on purpose. If we ever want DDP for DeepDFT we'd also need to +# sweep the LR (effective batch grows with world_size). 
+# +# Env vars: +# LEMATRHO_ADASTRA_SETUP override $SETUP (default: cad16353 scratch) +# LEMATRHO_DEEPDFT_VARIANT painn (default) | schnet +# LEMATRHO_DRY_RUN 1 to print resolved cmd + exit +# +# Submit examples: +# sbatch submit_deepdft_adastra.sh # PaiNN +# sbatch --export=ALL,LEMATRHO_DEEPDFT_VARIANT=schnet submit_deepdft_adastra.sh # SchNet +# +#SBATCH --job-name=deepdft_ft +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --account=c1816212 +#SBATCH --constraint=MI250 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=64000M +#SBATCH --time=24:00:00 +#SBATCH --output=%x_%j.out +#SBATCH --error=%x_%j.err + +set -eo pipefail + +# --- Paths --- +SETUP="${LEMATRHO_ADASTRA_SETUP:-/lus/scratch/CT10/cad16353/msiron/charge3net_setup}" +WORK_DIR="$SETUP/LeMat-Rho" +DATA_DIR="$SETUP/charge3net_data" +DEEPDFT_REPO="$SETUP/DeepDFT" + +# --- Model variant --- +VARIANT="${LEMATRHO_DEEPDFT_VARIANT:-painn}" +case "$VARIANT" in + painn) + EXTRA_ARGS=(--use_painn_model) + OUTPUT_DIR="$SETUP/deepdft_runs/painn" + export WANDB_NAME="deepdft_painn" + ;; + schnet) + EXTRA_ARGS=() # SchNet is the default architecture, no flag needed + OUTPUT_DIR="$SETUP/deepdft_runs/schnet" + export WANDB_NAME="deepdft_schnet" + ;; + *) + echo "ERROR: LEMATRHO_DEEPDFT_VARIANT must be 'painn' or 'schnet', got '$VARIANT'" >&2 + exit 2 + ;; +esac + +mkdir -p "$OUTPUT_DIR" 2>/dev/null || true + +# --- Build train command ----------------------------------------------------- +# Hyperparameters lifted from pretrained_models/{nmc,qm9,ethylenecarbonate}_painn +# in the upstream DeepDFT repo. Same values across all three published checkpoints. 
+TRAIN_ARGS=( + --dataset "$DATA_DIR" + --output_dir "$OUTPUT_DIR" + --cutoff 4 + --num_interactions 3 + --node_size 128 + --max_steps 10000000 + --device cuda + "${EXTRA_ARGS[@]}" +) +if [ -f "$OUTPUT_DIR/best_model.pth" ]; then + TRAIN_ARGS+=(--load_model "$OUTPUT_DIR/best_model.pth") +fi + +if [ "${LEMATRHO_DRY_RUN:-0}" = "1" ]; then + echo "WANDB_NAME=$WANDB_NAME" + echo "VARIANT=$VARIANT" + echo "OUTPUT_DIR=$OUTPUT_DIR" + printf 'python -m deepdft_ft.runner' + for arg in "${TRAIN_ARGS[@]}"; do + printf ' %s' "$arg" + done + printf '\n' + exit 0 +fi + +# --- Environment ------------------------------------------------------------- +export HTTP_PROXY=http://proxy-l-adastra.cines.fr:3128 +export HTTPS_PROXY=$HTTP_PROXY +export http_proxy=$HTTP_PROXY +export https_proxy=$HTTP_PROXY + +source "$SETUP/venv311/bin/activate" + +export PYTHONPATH="$WORK_DIR:$DEEPDFT_REPO:$PYTHONPATH" +export PYTHONUNBUFFERED=1 + +if [ -f "$WORK_DIR/.env" ]; then + set -a + source "$WORK_DIR/.env" + set +a +fi + +# Pin to GCD 0 (single-GPU paper-faithful). Do NOT set WORLD_SIZE so that +# runner.py's _setup_ddp returns the single-process tuple (0, 0, 1). +export HIP_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=0 + +echo "Node: $(hostname)" +echo "Account: ${SLURM_JOB_ACCOUNT:-unknown}" +echo "Variant: $VARIANT (wandb name: $WANDB_NAME)" +echo "Output dir: $OUTPUT_DIR" +echo "Single-GPU mode (WORLD_SIZE unset)" +rocm-smi || true + +python3 -c " +import torch +print(f'torch: {torch.__version__}') +print(f'CUDA/ROCm available: {torch.cuda.is_available()}') +print(f'device count: {torch.cuda.device_count()}') +" + +cd "$WORK_DIR" + +# --- Train (single GPU, no srun) -------------------------------------------- +python3 -m deepdft_ft.runner "${TRAIN_ARGS[@]}" + +echo "Done. Exit code: $?" 
#!/bin/bash
# Phase D2: project the LeMat-Rho parquet dataset onto the SALTED basis.
#
# One-time CPU job. Reads $SETUP/charge3net_data/chunk_*.parquet,
# writes $SETUP/salted_projected_coefficients/chunk_*.parquet via
# salted_ft.project_dataset (one LSQR per row, ~75 ms per row).
#
# Adastra smoke test (1 chunk, 956 valid rows) timed at 71 s wall.
# Full dataset (69 chunks, ~65k rows) extrapolates to ~80 min.
# Budget 2 h with slack.
#
# Env vars
#   LEMATRHO_ADASTRA_SETUP   override $SETUP (default: cad16353 scratch)
#
#SBATCH --job-name=salted_project_dataset
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --account=c1816212
#SBATCH --constraint=GENOA
#SBATCH --cpus-per-task=16
#SBATCH --time=02:00:00
#SBATCH --output=%x_%j.out
#SBATCH --error=%x_%j.err

set -eo pipefail

SETUP="${LEMATRHO_ADASTRA_SETUP:-/lus/scratch/CT10/cad16353/msiron/charge3net_setup}"
WORK_DIR="$SETUP/LeMat-Rho"
INPUT_DIR="$SETUP/charge3net_data"
OUTPUT_DIR="$SETUP/salted_projected_coefficients"

mkdir -p "$OUTPUT_DIR" 2>/dev/null || true

source "$SETUP/venv311/bin/activate"
# Append the old PYTHONPATH only when it is non-empty: a trailing ':' would
# add an empty entry, which Python treats as the current directory.
export PYTHONPATH="$WORK_DIR${PYTHONPATH:+:$PYTHONPATH}"
export PYTHONUNBUFFERED=1

# numpy / lstsq is already multi-threaded via BLAS; cap thread count
# to match the SLURM allocation so we do not oversubscribe the node.
# Falls back to 1 thread when the script is run outside SLURM.
N_THREADS="${SLURM_CPUS_ON_NODE:-1}"
export OMP_NUM_THREADS=$N_THREADS
export OPENBLAS_NUM_THREADS=$N_THREADS
export MKL_NUM_THREADS=$N_THREADS

echo "Node:    $(hostname)"
echo "Account: ${SLURM_JOB_ACCOUNT:-unknown}"
echo "Input:   $INPUT_DIR"
echo "Output:  $OUTPUT_DIR"
echo "CPUs:    $N_THREADS"

cd "$WORK_DIR"

# Capture the exit status explicitly.  Under `set -e` a plain failing command
# would abort the script before the summary below, and a trailing
# `echo "... $?"` could only ever print 0.
rc=0
python -m salted_ft.project_dataset \
    --input-dir "$INPUT_DIR" \
    --output-dir "$OUTPUT_DIR" || rc=$?

echo "Done. Exit code: $rc"
echo "Counting output chunks:"
# find prints nothing (instead of failing like `ls` on an unmatched glob)
# when no chunk was written; `|| true` keeps pipefail from masking $rc.
find "$OUTPUT_DIR" -maxdepth 1 -name 'chunk_*.parquet' | wc -l || true

# Propagate the projection's status so SLURM records a failed job as failed.
exit "$rc"
test_density_shape(self): from charge3net_ft.data import _row_to_atoms_and_density + _, density, _ = _row_to_atoms_and_density(self._make_row()) assert density.shape == (10, 10, 10) def test_origin_is_zero(self): from charge3net_ft.data import _row_to_atoms_and_density + _, _, origin = _row_to_atoms_and_density(self._make_row()) np.testing.assert_array_equal(origin, [0.0, 0.0, 0.0]) def test_unknown_species_raises(self): from charge3net_ft.data import _row_to_atoms_and_density + row = self._make_row() row["species_at_sites"] = ["Xx"] # invalid symbol with pytest.raises(KeyError): @@ -111,16 +123,24 @@ def _write_chunk(self, path: Path, n_valid: int, n_null: int): """Write a synthetic chunk_*.parquet file.""" valid = [json.dumps(np.ones((10, 10, 10)).tolist())] * n_valid null = [None] * n_null - table = pa.table({ - "compressed_charge_density": pa.array(valid + null, type=pa.string()), - "species_at_sites": pa.array([["Fe"]] * (n_valid + n_null)), - "cartesian_site_positions": pa.array([[[0.0, 0.0, 0.0]]] * (n_valid + n_null)), - "lattice_vectors": pa.array([[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]] * (n_valid + n_null)), - }) + table = pa.table( + { + "compressed_charge_density": pa.array(valid + null, type=pa.string()), + "species_at_sites": pa.array([["Fe"]] * (n_valid + n_null)), + "cartesian_site_positions": pa.array( + [[[0.0, 0.0, 0.0]]] * (n_valid + n_null) + ), + "lattice_vectors": pa.array( + [[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]] + * (n_valid + n_null) + ), + } + ) pq.write_table(table, path) def test_counts_valid_rows(self): from charge3net_ft.data import _build_parquet_index + with tempfile.TemporaryDirectory() as tmp: d = Path(tmp) self._write_chunk(d / "chunk_000.parquet", n_valid=5, n_null=2) @@ -131,6 +151,7 @@ def test_counts_valid_rows(self): def test_index_entries_reference_correct_file(self): from charge3net_ft.data import _build_parquet_index + with tempfile.TemporaryDirectory() as tmp: d = Path(tmp) 
self._write_chunk(d / "chunk_000.parquet", n_valid=3, n_null=0) @@ -142,6 +163,177 @@ def test_index_entries_reference_correct_file(self): def test_raises_on_empty_dir(self): from charge3net_ft.data import _build_parquet_index + with tempfile.TemporaryDirectory() as tmp: with pytest.raises(FileNotFoundError): _build_parquet_index(Path(tmp)) + + def test_ignores_extra_columns(self): + """Newer LeMat-Rho dataset versions add Bader-analysis columns (e.g. + bader_charges, bader_volumes) alongside the four required columns. + _build_parquet_index and _row_to_atoms_and_density should ignore the + extras transparently: data.py:46 declares an explicit _COLUMNS allowlist + and pq.read_table is called with columns=_COLUMNS. + """ + from charge3net_ft.data import _build_parquet_index, _row_to_atoms_and_density + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + n = 3 + grid = json.dumps(np.ones((10, 10, 10)).tolist()) + table = pa.table( + { + # required columns + "compressed_charge_density": pa.array([grid] * n, type=pa.string()), + "species_at_sites": pa.array([["Fe"]] * n), + "cartesian_site_positions": pa.array([[[0.0, 0.0, 0.0]]] * n), + "lattice_vectors": pa.array( + [[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]] * n + ), + # extras analogous to what Entalpic/lemat-rho-v1 added in 2026: + "bader_charges": pa.array([[0.42]] * n), + "bader_volumes": pa.array([[11.7]] * n), + "material_id": pa.array([f"mat_{i}" for i in range(n)]), + } + ) + pq.write_table(table, d / "chunk_000.parquet") + + # build_parquet_index should still find all 3 valid rows + file_paths, index = _build_parquet_index(d) + assert len(index) == n + assert len(file_paths) == 1 + + # _row_to_atoms_and_density should produce a usable atoms+density + # even when the row dict contains the extras (it indexes the + # required keys directly, so the extras are dead weight). 
+ row = { + "species_at_sites": ["Fe"], + "cartesian_site_positions": [[0.0, 0.0, 0.0]], + "lattice_vectors": [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]], + "compressed_charge_density": grid, + "bader_charges": [0.42], + "bader_volumes": [11.7], + "material_id": "mat_0", + } + atoms, density, origin = _row_to_atoms_and_density(row) + assert len(atoms) == 1 + assert density.shape == (10, 10, 10) + np.testing.assert_array_equal(origin, np.zeros(3)) + + +# --------------------------------------------------------------------------- +# LRU eviction for the per-worker parquet table cache. +# +# Why this is here (regression test for the OOM that killed jobs 4971293 and +# 4971343): without eviction, each DataLoader worker accumulates every chunk +# it has ever read. With 8 workers per rank x 4 DDP ranks = 32 workers, and +# ~2 GB of pyarrow-decompressed table per chunk, the cache alone can grow to +# ~140 GB on a long run. The OOM hit at MaxRSS=35 GB per rank x 4 = 140 GB, +# above our 125 GB --mem budget. +# +# The fix: cap the cache. A small LRU bounded by `_TABLE_CACHE_MAX_CHUNKS` +# evicts the least-recently-used chunk before adding a new one. +# --------------------------------------------------------------------------- + + +class TestTableCacheLRU: + """LeMatRhoDataset's _TABLE_CACHE must evict to stay below a bounded size.""" + + def _write_n_chunks(self, d: Path, n: int): + for i in range(n): + _write_one_row_chunk(d / f"chunk_{i:03d}.parquet") + + def test_cache_size_is_bounded(self): + """After reading from many chunks, the cache must not contain all of them.""" + from charge3net_ft import data as data_mod + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + n_chunks = 10 + self._write_n_chunks(d, n_chunks) + + # Force a small cap so the test is fast and unambiguous. 
+ original_max = getattr(data_mod, "_TABLE_CACHE_MAX_CHUNKS", None) + data_mod._TABLE_CACHE_MAX_CHUNKS = 3 + data_mod._TABLE_CACHE.clear() + try: + ds = data_mod.LeMatRhoDataset(parquet_dir=d, num_probes=None) + for i in range(len(ds)): + _ = ds._read_row(i) + assert len(data_mod._TABLE_CACHE) <= 3, ( + "cache grew beyond _TABLE_CACHE_MAX_CHUNKS=3; " + f"actual size {len(data_mod._TABLE_CACHE)}" + ) + finally: + if original_max is not None: + data_mod._TABLE_CACHE_MAX_CHUNKS = original_max + data_mod._TABLE_CACHE.clear() + + def test_cache_evicts_least_recently_used(self): + """When the cache is full, the next miss should drop the LRU entry.""" + from charge3net_ft import data as data_mod + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + self._write_n_chunks(d, 5) + data_mod._TABLE_CACHE_MAX_CHUNKS = 2 + data_mod._TABLE_CACHE.clear() + try: + ds = data_mod.LeMatRhoDataset(parquet_dir=d, num_probes=None) + # Touch chunks 0, 1 -> cache holds {0, 1} + ds._read_row(0) + ds._read_row(1) + assert set(data_mod._TABLE_CACHE.keys()) == {0, 1} + # Touch chunk 2 -> the LRU (0) should evict, cache holds {1, 2} + ds._read_row(2) + assert set(data_mod._TABLE_CACHE.keys()) == {1, 2}, ( + f"expected LRU eviction of chunk 0, got cache keys " + f"{set(data_mod._TABLE_CACHE.keys())}" + ) + # Re-access 1 -> bumps 1 to most-recent; cache still {1, 2} + ds._read_row(1) + # Touch 3 -> 2 is now LRU, evict 2, cache holds {1, 3} + ds._read_row(3) + assert set(data_mod._TABLE_CACHE.keys()) == {1, 3}, ( + f"expected LRU eviction of chunk 2 after re-access of 1; " + f"got cache keys {set(data_mod._TABLE_CACHE.keys())}" + ) + finally: + data_mod._TABLE_CACHE.clear() + + def test_cache_max_default_is_reasonable(self): + """The default cap must be > 0 and small enough that 8 workers x cap + worth of cached chunks fits well below per-rank memory budgets. 
+ + With ~2 GB per chunk and ~8 workers per rank, a default of 5 caps + the per-rank cache at 8 x 5 x 2 GB = ~80 GB worst case (only chunks the + worker actually saw count; in practice well under). We pick 5 so that + typical use fits a 32-GB-per-rank shared-mode budget; the worst case does not. + """ + from charge3net_ft import data as data_mod + + assert hasattr(data_mod, "_TABLE_CACHE_MAX_CHUNKS"), ( + "_TABLE_CACHE_MAX_CHUNKS must be defined for the LRU to work" + ) + assert 1 <= data_mod._TABLE_CACHE_MAX_CHUNKS <= 20, ( + f"_TABLE_CACHE_MAX_CHUNKS={data_mod._TABLE_CACHE_MAX_CHUNKS} is " + "outside the sensible range [1, 20]; very small evicts too " + "aggressively for shuffled access, very large defeats the cap" + ) + + +def _write_one_row_chunk(path: Path): + """Helper: one valid row per chunk; used by the LRU eviction tests.""" + table = pa.table( + { + "compressed_charge_density": pa.array( + [json.dumps(np.ones((10, 10, 10)).tolist())], type=pa.string() + ), + "species_at_sites": pa.array([["Fe"]]), + "cartesian_site_positions": pa.array([[[0.0, 0.0, 0.0]]]), + "lattice_vectors": pa.array( + [[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]] + ), + } + ) + pq.write_table(table, path) diff --git a/tests/test_deepdft_data.py b/tests/test_deepdft_data.py new file mode 100644 index 0000000..57ae3dc --- /dev/null +++ b/tests/test_deepdft_data.py @@ -0,0 +1,194 @@ +"""TDD tests for the LeMat-Rho → DeepDFT data adapter. + +DeepDFT (peterbjorgensen/DeepDFT) consumes a per-sample dict of the form:: + + { + "density": np.ndarray (Nx, Ny, Nz), + "atoms": ase.Atoms, + "origin": np.ndarray (3,), + "grid_position": np.ndarray (Nx, Ny, Nz, 3), + "metadata": dict, # must contain "filename" + } + +Our adapter ``LeMatRhoDeepDFTDataset`` reuses the existing +``_row_to_atoms_and_density`` and ``_build_parquet_index`` helpers in +``charge3net_ft.data`` (so the input pipeline is shared between models) and +returns DeepDFT's dict shape directly. No tar/CHGCAR conversion needed.
+""" + +from __future__ import annotations + +import json +import tempfile +from pathlib import Path + +import ase +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + + +# --------------------------------------------------------------------------- +# Helpers — write a synthetic chunk_*.parquet with the same schema the real +# LeMat-Rho data has, plus the Bader columns it gained in v1. +# --------------------------------------------------------------------------- +def _write_synthetic_chunk(path: Path, n_valid: int = 3) -> None: + grid = json.dumps(np.ones((10, 10, 10), dtype=np.float32).tolist()) + table = pa.table( + { + "compressed_charge_density": pa.array([grid] * n_valid, type=pa.string()), + "species_at_sites": pa.array([["Fe"]] * n_valid), + "cartesian_site_positions": pa.array([[[0.0, 0.0, 0.0]]] * n_valid), + "lattice_vectors": pa.array( + [[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]] * n_valid + ), + # extras DeepDFT must ignore + "bader_charges": pa.array([[0.42]] * n_valid), + "material_id": pa.array([f"mat_{i}" for i in range(n_valid)]), + } + ) + pq.write_table(table, path) + + +class TestLeMatRhoDeepDFTDataset: + """Adapter __getitem__ returns DeepDFT's exact dict contract.""" + + def test_length_matches_valid_rows(self): + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=5) + _write_synthetic_chunk(d / "chunk_001.parquet", n_valid=3) + ds = LeMatRhoDeepDFTDataset(parquet_dir=d) + assert len(ds) == 8 + + def test_item_has_all_required_keys(self): + """DeepDFT's collate_fn reads density, atoms, origin, grid_position, metadata.""" + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=1) + ds = LeMatRhoDeepDFTDataset(parquet_dir=d) + sample = ds[0] + for key in 
("density", "atoms", "origin", "grid_position", "metadata"): + assert key in sample, ( + f"DeepDFT expects key {key!r}; got {list(sample.keys())}" + ) + + def test_item_density_is_3d_numpy(self): + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=1) + sample = LeMatRhoDeepDFTDataset(parquet_dir=d)[0] + assert isinstance(sample["density"], np.ndarray) + assert sample["density"].shape == (10, 10, 10), ( + f"expected (10, 10, 10) density; got {sample['density'].shape}" + ) + + def test_item_atoms_is_ase_atoms_with_pbc(self): + """Periodic boundary conditions matter for any solid-state density.""" + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=1) + sample = LeMatRhoDeepDFTDataset(parquet_dir=d)[0] + assert isinstance(sample["atoms"], ase.Atoms) + assert all(sample["atoms"].pbc), ( + "LeMat-Rho cells are periodic; atoms.pbc must be (True, True, True)" + ) + + def test_item_origin_is_3vec_zeros(self): + """LeMat-Rho stores grids at fractional (0, 0, 0); the adapter mirrors that.""" + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=1) + sample = LeMatRhoDeepDFTDataset(parquet_dir=d)[0] + assert isinstance(sample["origin"], np.ndarray) + np.testing.assert_array_equal(sample["origin"], np.zeros(3)) + + def test_item_grid_position_shape_matches_density(self): + """grid_position is (Nx, Ny, Nz, 3) Cartesian probe coordinates.""" + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=1) + sample = LeMatRhoDeepDFTDataset(parquet_dir=d)[0] + assert sample["grid_position"].shape == (10, 10, 10, 
3), ( + f"grid_position must be (Nx, Ny, Nz, 3); got {sample['grid_position'].shape}" + ) + + def test_grid_position_origin_is_zero(self): + """grid_position[0, 0, 0] must be the cell origin (0, 0, 0).""" + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=1) + sample = LeMatRhoDeepDFTDataset(parquet_dir=d)[0] + np.testing.assert_allclose(sample["grid_position"][0, 0, 0], np.zeros(3)) + + def test_grid_position_uses_cell_matrix(self): + """grid_position[1, 0, 0] should be one step along the a vector. + + For our synthetic 10×10×10 grid with a 4-Å cubic cell: + frac coord at index (1, 0, 0) = (1/10, 0, 0) + Cartesian = frac @ cell = (4/10, 0, 0) = (0.4, 0, 0) + """ + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=1) + sample = LeMatRhoDeepDFTDataset(parquet_dir=d)[0] + np.testing.assert_allclose( + sample["grid_position"][1, 0, 0], [0.4, 0.0, 0.0], atol=1e-5 + ) + + def test_item_metadata_has_filename(self): + """DeepDFT logs reference filename — must be a stable string per sample.""" + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=2) + ds = LeMatRhoDeepDFTDataset(parquet_dir=d) + for i in range(len(ds)): + meta = ds[i]["metadata"] + assert "filename" in meta, f"metadata missing 'filename'; got {meta}" + assert isinstance(meta["filename"], str) + # Filenames should differ across samples so DeepDFT logs don't collide. + assert ds[0]["metadata"]["filename"] != ds[1]["metadata"]["filename"] + + def test_ignores_extra_columns(self): + """Bader / material_id columns added to LeMat-Rho v1 are dead weight here. 
+ + Same regression we already pinned for charge3net_ft.data; mirroring it + on the DeepDFT path keeps the two adapters honest in lockstep. + """ + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=1) + sample = LeMatRhoDeepDFTDataset(parquet_dir=d)[0] + # The synthetic chunk includes bader_charges + material_id columns. + # The adapter should successfully ingest the row regardless. + assert sample["density"].shape == (10, 10, 10) + + +class TestRaisesOnEmptyDir: + def test_no_chunks_in_dir_raises(self): + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + with pytest.raises(FileNotFoundError): + LeMatRhoDeepDFTDataset(parquet_dir=Path(tmp)) diff --git a/tests/test_equivariance.py b/tests/test_equivariance.py new file mode 100644 index 0000000..5b21770 --- /dev/null +++ b/tests/test_equivariance.py @@ -0,0 +1,164 @@ +"""Structural equivariance test for ChargE3Net. + +ChargE3Net predicts the scalar charge density ρ(r). For the model to be +rotationally equivariant (i.e. ρ(R·r; R·atoms) == ρ(r; atoms)), the output +irreps of the probe-side network must contain only ℓ=0 even-parity components +("0e", pure scalars). This is the e3nn-level guarantee: as long as the final +representation is a scalar irrep, the model's output is invariant under SO(3) +acting on the input frame. + +A runtime equivariance check (rotate inputs, predict, compare to predictions +on the unrotated inputs) is the gold standard but requires a real forward +pass on the production-sized model, which is too slow for a CPU unit test. +The structural test here covers the same property at the architecture level. + +Skipped automatically when the upstream charge3net repo isn't on disk. 
+""" + +from __future__ import annotations + +import sys +from pathlib import Path + +import pytest +import torch + +# --------------------------------------------------------------------------- +# Skip if the sibling charge3net repo isn't installed locally +# --------------------------------------------------------------------------- +_CHARGE3NET_ROOT = Path(__file__).resolve().parent.parent.parent / "charge3net" +if not _CHARGE3NET_ROOT.exists(): + pytest.skip( + f"charge3net repo not at {_CHARGE3NET_ROOT}; " + "clone github.com/AIforGreatGood/charge3net there to run this test", + allow_module_level=True, + ) +if str(_CHARGE3NET_ROOT) not in sys.path: + sys.path.insert(0, str(_CHARGE3NET_ROOT)) + +from e3nn import o3 # noqa: E402 +from src.charge3net.models.e3 import E3DensityModel # noqa: E402 + + +@pytest.fixture(scope="module") +def production_model(): + """Build a model with the MP-checkpoint hyperparameters. + + Module-scoped so the (slow) construction happens once for all assertions. + """ + torch.manual_seed(0) + model = E3DensityModel( + num_interactions=3, + num_neighbors=20, + mul=500, + lmax=4, + cutoff=4.0, + basis="gaussian", + num_basis=20, + ) + model.train(False) + return model + + +def test_param_count_matches_mp_checkpoint(production_model): + """Sanity check: the model has the 1.9M params we expect. + + Guards against silently changing the architecture in a way that breaks + checkpoint loading from charge3net_mp.pt. + """ + n_params = sum(p.numel() for p in production_model.parameters()) + assert 1_900_000 <= n_params <= 1_920_000, ( + f"Architecture drift: expected ~1.91M params (MP checkpoint), got {n_params:,}" + ) + + +def test_atom_model_uses_higher_order_irreps(production_model): + """ChargE3Net's atom representation must include ℓ>0 irreps to be 'higher-order'. + + The paper's central claim is that going from ℓ_max=1 to ℓ_max=4 produces + substantially better densities on systems with subtle bonding. 
If someone + accidentally drops the higher-l components (e.g. by passing lmax=0), the + model degenerates to a scalar-only network and silently regresses to a + much weaker baseline. + """ + atom_irreps = production_model.atom_model.atom_irreps_sequence + assert len(atom_irreps) > 0, "atom_irreps_sequence is empty" + final_irreps = atom_irreps[-1] + max_l = max(ir.l for _mul, ir in final_irreps) + assert max_l >= 4, ( + f"Atom representation max ℓ is {max_l}; ChargE3Net's " + f"higher-order claim requires ℓ_max ≥ 4. Got {final_irreps}." + ) + + +def test_atom_model_has_both_parities(production_model): + """The atom representation should include both even (+) and odd (-) parity irreps. + + Without odd-parity components the model can't represent any vector- or + pseudovector-valued atom features, which the higher-order convolutions + need internally. The default get_irreps(mul, lmax) function in e3.py + generates both; this test pins that down. + """ + final_irreps = production_model.atom_model.atom_irreps_sequence[-1] + parities = {ir.p for _mul, ir in final_irreps} + assert parities == {-1, 1}, ( + f"Atom irreps should include both even (p=+1) and odd (p=-1) parities; " + f"got parities {parities}: {final_irreps}" + ) + + +def test_get_irreps_helper_is_balanced(): + """The get_irreps helper in e3.py should produce roughly balanced channel counts. + + This is the function used to construct atom_irreps. If it ever returns + zero-multiplicity for any (l, p) pair at production hyperparameters, the + architecture breaks silently (some irreps disappear). Tests the helper + directly to fail fast. 
+ """ + from src.charge3net.models.e3 import get_irreps + + irreps = get_irreps(500, lmax=4) + multiplicities = [mul for mul, _ in irreps] + assert all(mul > 0 for mul in multiplicities), ( + f"get_irreps(500, 4) produced a zero-multiplicity irrep: {irreps}" + ) + # 5 ℓ levels × 2 parities = 10 entries + assert len(irreps) == 10, ( + f"Expected 10 irreps (5 ℓ × 2 parity), got {len(irreps)}: {irreps}" + ) + + +def test_atom_irreps_sequence_length_matches_num_interactions(production_model): + """One irreps entry per convolution layer (plus the input embedding).""" + seq = production_model.atom_model.atom_irreps_sequence + # num_interactions=3 → 3 convolutions; the sequence holds the post-conv + # representations. Length will be 3 or 4 depending on whether the input + # embedding is included; both are valid, but we pin a sane range. + assert 3 <= len(seq) <= 5, ( + f"atom_irreps_sequence length {len(seq)} is outside the expected " + f"range [3, 5] for num_interactions=3" + ) + + +def test_atom_model_uses_cutoff_consistent_with_kdtree(production_model): + """The cutoff baked into the atom model must match what the dataset uses. + + `KdTreeGraphConstructor` in LeMatRhoDataset uses cutoff=4.0; if the model + is built with a different cutoff, edges fed in at training time won't + match what the convolution layer expects. + """ + assert production_model.atom_model.cutoff == pytest.approx(4.0) + + +def test_e3nn_o3_irreps_are_proper_objects(production_model): + """The atom representation must use e3nn's o3.Irreps wrapper. + + Equivariance is enforced by the o3.Irreps abstraction (which carries + parity information and is consumed by FullyConnectedTensorProduct). If + someone replaces it with a plain list, equivariance silently breaks even + though the forward pass still produces output. 
+ """ + final_irreps = production_model.atom_model.atom_irreps_sequence[-1] + assert isinstance(final_irreps, o3.Irreps), ( + f"Expected o3.Irreps for atom_irreps_sequence[-1]; got {type(final_irreps)}" + ) diff --git a/tests/test_salted_basis.py b/tests/test_salted_basis.py new file mode 100644 index 0000000..b8ce7f4 --- /dev/null +++ b/tests/test_salted_basis.py @@ -0,0 +1,155 @@ +"""TDD tests for the SALTED-arm BasisSpec dataclass. + +Locks down the basis numbers chosen in +``plan_salted_graph2mat_basis_choice_may_20_pm.md`` (Phase A4): + +* ``max_l = 4`` +* ``n_radial = 4`` (uniform across species in v1) +* ``sigma = (0.5, 1.0, 2.0, 4.0)`` Å — geometric radial-width ladder +* ``cutoff = 4.0`` Å — matches ChargE3Net's KdTree cutoff +* ``n_coeffs_per_atom == n_radial * (max_l + 1) ** 2`` == 100 + +These numbers are referenced by every downstream PR (projection, +reconstruction, model wrapper, VASP I/O). Pinning them here means a +later edit shows up as a single failing test, not a silent drift. +""" + +from __future__ import annotations + +import pytest + + +class TestBasisSpecDefaults: + """Default BasisSpec must match the A4 lockdown.""" + + def test_default_max_l_is_four(self): + from salted_ft.basis import BasisSpec + + assert BasisSpec().max_l == 4 + + def test_default_n_radial_is_four(self): + from salted_ft.basis import BasisSpec + + assert BasisSpec().n_radial == 4 + + def test_default_sigma_ladder(self): + from salted_ft.basis import BasisSpec + + # Geometric ladder over tight + valence + diffuse regimes. + assert BasisSpec().sigma == (0.5, 1.0, 2.0, 4.0) + + def test_default_cutoff_matches_charge3net(self): + from salted_ft.basis import BasisSpec + + # ChargE3Net's KdTreeGraphConstructor uses cutoff=4.0; the SALTED-arm + # uses the same so atom-neighbor structure is identical between models. 
+ assert BasisSpec().cutoff == pytest.approx(4.0) + + def test_default_n_coeffs_per_atom_is_100(self): + """4 radial * (4+1)^2 angular = 100 coefficients per atom.""" + from salted_ft.basis import BasisSpec + + assert BasisSpec().n_coeffs_per_atom == 100 + + +class TestBasisSpecArithmetic: + """n_coeffs_per_atom must equal n_radial * (max_l + 1)^2 for any valid spec.""" + + @pytest.mark.parametrize( + "max_l,n_radial,expected", + [ + (0, 1, 1), # one s function + (1, 2, 8), # 2 * (1 + 3) = 8 + (2, 3, 27), # 3 * (1 + 3 + 5) = 27 + (4, 4, 100), # the production default + (6, 4, 196), # 4 * (1 + 3 + 5 + 7 + 9 + 11 + 13) = 196 + ], + ) + def test_n_coeffs_formula(self, max_l, n_radial, expected): + from salted_ft.basis import BasisSpec + + spec = BasisSpec( + max_l=max_l, + n_radial=n_radial, + sigma=tuple(0.5 * 2**i for i in range(n_radial)), + cutoff=5.0, + ) + assert spec.n_coeffs_per_atom == expected + + def test_n_radial_matches_sigma_length(self): + """sigma is the per-radial-channel width; len(sigma) must equal n_radial.""" + from salted_ft.basis import BasisSpec + + with pytest.raises(ValueError, match=r"n_radial.*sigma"): + BasisSpec(max_l=2, n_radial=3, sigma=(0.5, 1.0), cutoff=4.0) + + +class TestBasisSpecValidation: + """Reject malformed specs at construction time, not at use time.""" + + def test_negative_max_l_rejected(self): + from salted_ft.basis import BasisSpec + + with pytest.raises(ValueError, match=r"max_l"): + BasisSpec(max_l=-1, n_radial=4, sigma=(0.5, 1.0, 2.0, 4.0), cutoff=4.0) + + def test_zero_n_radial_rejected(self): + from salted_ft.basis import BasisSpec + + with pytest.raises(ValueError, match=r"n_radial"): + BasisSpec(max_l=4, n_radial=0, sigma=(), cutoff=4.0) + + def test_negative_sigma_rejected(self): + """sigma is a Gaussian width; nonpositive widths are nonphysical.""" + from salted_ft.basis import BasisSpec + + with pytest.raises(ValueError, match=r"sigma"): + BasisSpec(max_l=2, n_radial=2, sigma=(0.5, -1.0), cutoff=4.0) + + def 
test_nonpositive_cutoff_rejected(self): + from salted_ft.basis import BasisSpec + + with pytest.raises(ValueError, match=r"cutoff"): + BasisSpec(max_l=2, n_radial=2, sigma=(0.5, 1.0), cutoff=0.0) + + +class TestBasisSpecShapes: + """Shape helpers for downstream tensor allocation.""" + + def test_n_angular_components_per_radial(self): + """(max_l + 1)^2 real spherical harmonic components per radial channel.""" + from salted_ft.basis import BasisSpec + + # l=0,1,2,3,4 -> 1+3+5+7+9 = 25 angular components per radial channel + assert BasisSpec().n_angular_components == 25 + + def test_total_coeffs_shape(self): + """coeffs tensor shape for a structure: (n_atoms, n_coeffs_per_atom).""" + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + assert spec.total_coeffs_shape(n_atoms=5) == (5, 100) + assert spec.total_coeffs_shape(n_atoms=1) == (1, 100) + + +class TestBasisSpecImmutable: + """BasisSpec must be hashable + immutable so it can key caches / metric runs.""" + + def test_is_hashable(self): + from salted_ft.basis import BasisSpec + + # Two specs with identical fields hash to the same value. + a = BasisSpec() + b = BasisSpec() + assert hash(a) == hash(b) + assert a == b + + def test_mutation_rejected(self): + """Frozen dataclass — assigning to a field raises FrozenInstanceError.""" + from dataclasses import FrozenInstanceError + + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + with pytest.raises(FrozenInstanceError): + spec.max_l = 6 # type: ignore[misc] diff --git a/tests/test_salted_io.py b/tests/test_salted_io.py new file mode 100644 index 0000000..61a00e1 --- /dev/null +++ b/tests/test_salted_io.py @@ -0,0 +1,207 @@ +"""TDD tests for VASP CHGCAR I/O wrapper (PR delta). + +The wrapper exposes ``write_chgcar(density, atoms, path)`` so a +reconstructed real-space density grid can be persisted as a VASP +CHGCAR file. 
That file is then the input to a paired SCF run +(``ICHARG=1``) for the speedup comparison vs the +``ICHARG=2``-from-superposition baseline. + +Locked contract: + +* ``write_chgcar(density, atoms, path, n_electrons=None)`` + Writes a pymatgen ``Chgcar``-compatible file at ``path``. If + ``n_electrons`` is given, rescales the density so that + ``sum(density) * cell_volume / N_grid == n_electrons``. +* The written file round-trips through ``Chgcar.from_file`` and + preserves shape, atom species, and cell. +* ``read_chgcar(path)`` is the inverse: returns + ``(density: np.ndarray, atoms: ase.Atoms)``. + +End-to-end SCF speedup test is gated on the entalsim +``StructureVASPSinglePoint`` maker landing; pinned here as an +``importorskip`` placeholder so it auto-activates when the +dependency arrives. +""" + +from __future__ import annotations + +import tempfile +from pathlib import Path + +import ase +import numpy as np +import pytest + + +def _cubic_atoms(symbols=("Fe",), fractional=((0.5, 0.5, 0.5),), a=4.0): + cell = np.eye(3) * a + cart = np.array(fractional) @ cell + return ase.Atoms(symbols=list(symbols), positions=cart, cell=cell, pbc=True) + + +class TestWriteChgcar: + def test_writes_file(self): + from salted_ft.io import write_chgcar + + atoms = _cubic_atoms() + density = np.ones((8, 8, 8), dtype=np.float64) + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + write_chgcar(density, atoms, path) + assert path.exists() + assert path.stat().st_size > 0 + + def test_normalizes_to_total_electron_count(self): + """When ``n_electrons`` is set, the *integrated* density of the + written file must equal ``n_electrons`` to within ``1e-4 * n_electrons``. + That's what VASP reads as N_electrons on ICHARG=1. + """ + from salted_ft.io import read_chgcar, write_chgcar + + atoms = _cubic_atoms() + # Density that integrates to something arbitrary; write_chgcar + # should rescale to the requested electron count.
+ density = np.ones((8, 8, 8), dtype=np.float64) * 0.5 + target_n = 26.0 # Fe atomic number, i.e. all-electron count (not valence) + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + write_chgcar(density, atoms, path, n_electrons=target_n) + read_density, _ = read_chgcar(path) + # CHGCAR convention: density * volume / N_grid integrates to N_electrons + cell_volume = np.linalg.det(atoms.get_cell()) + n_grid = np.prod(read_density.shape) + total_e = read_density.sum() * cell_volume / n_grid + assert abs(total_e - target_n) / target_n < 1e-4, ( + f"integrated density {total_e:.6f} differs from target {target_n} " + "by more than 1e-4; CHGCAR normalization is wrong" + ) + + def test_rejects_non_3d_density(self): + from salted_ft.io import write_chgcar + + atoms = _cubic_atoms() + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + with pytest.raises(ValueError, match=r"3D"): + write_chgcar(np.ones((8, 8)), atoms, path) + + def test_rejects_negative_n_electrons(self): + from salted_ft.io import write_chgcar + + atoms = _cubic_atoms() + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + with pytest.raises(ValueError, match=r"n_electrons"): + write_chgcar(np.ones((8, 8, 8)), atoms, path, n_electrons=-1.0) + + +class TestReadChgcar: + def test_returns_density_and_atoms(self): + from salted_ft.io import read_chgcar, write_chgcar + + atoms = _cubic_atoms() + density = np.ones((8, 8, 8), dtype=np.float64) * 0.1 + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + write_chgcar(density, atoms, path) + read_density, read_atoms = read_chgcar(path) + assert read_density.shape == (8, 8, 8) + assert isinstance(read_atoms, ase.Atoms) + + def test_preserves_atom_species(self): + from salted_ft.io import read_chgcar, write_chgcar + + atoms = _cubic_atoms( + symbols=("Fe", "O"), fractional=((0.0, 0.0, 0.0), (0.5, 0.5, 0.5)) + ) + density = np.ones((8, 8, 8), dtype=np.float64) * 0.1 + with
tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + write_chgcar(density, atoms, path) + _, read_atoms = read_chgcar(path) + # Order may differ but the multiset of species must match. + assert sorted(read_atoms.get_chemical_symbols()) == sorted( + atoms.get_chemical_symbols() + ) + + def test_preserves_cell(self): + from salted_ft.io import read_chgcar, write_chgcar + + atoms = _cubic_atoms(a=5.0) + density = np.ones((4, 4, 4), dtype=np.float64) * 0.05 + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + write_chgcar(density, atoms, path) + _, read_atoms = read_chgcar(path) + np.testing.assert_allclose( + np.asarray(read_atoms.get_cell()), + np.asarray(atoms.get_cell()), + atol=1e-6, + ) + + +class TestRoundtrip: + def test_density_roundtrip_within_tolerance(self): + """Write then read: shape exact, values within VASP-precision tolerance. + + VASP CHGCAR uses 5-decimal scientific notation per value, so + we expect ~1e-5 relative precision. + """ + from salted_ft.io import read_chgcar, write_chgcar + + atoms = _cubic_atoms() + rng = np.random.default_rng(7) + density = rng.random((8, 8, 8)).astype(np.float64) * 0.1 + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + write_chgcar(density, atoms, path) + read_density, _ = read_chgcar(path) + assert read_density.shape == density.shape + np.testing.assert_allclose(read_density, density, rtol=1e-3, atol=1e-5) + + +class TestSALTEDModelToChgcar: + """End-to-end: predict via SALTEDModel, reconstruct, write CHGCAR.""" + + def test_predicted_density_writes_to_chgcar(self): + from salted_ft.basis import BasisSpec + from salted_ft.io import read_chgcar, write_chgcar + from salted_ft.model import SALTEDModel + + atoms = _cubic_atoms() + model = SALTEDModel(basis_spec=BasisSpec()) + density = model.reconstruct_density(atoms, (8, 8, 8)) + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + write_chgcar(density, atoms, path) + assert 
path.exists() + read_density, _ = read_chgcar(path) + assert read_density.shape == (8, 8, 8) + + +# --------------------------------------------------------------------------- +# Forward-looking placeholder for the entalsim integration. +# +# Once Entalpic/entalsim PR #56's PR 2 (StructureVASPSinglePoint maker) +# lands and is installable, this test will auto-activate. Until then it +# skips cleanly so the suite stays green. +# --------------------------------------------------------------------------- +class TestVASPSinglePointHook: + def test_chgcar_consumed_by_entalsim_single_point_maker(self): + # Skips until entalsim ships the maker. + pytest.importorskip("entalsim.dft.tasks.single_point") + from entalsim.dft.tasks.single_point import StructureVASPSinglePoint + + from salted_ft.basis import BasisSpec + from salted_ft.io import write_chgcar + from salted_ft.model import SALTEDModel + + atoms = _cubic_atoms() + model = SALTEDModel(basis_spec=BasisSpec()) + density = model.reconstruct_density(atoms, (8, 8, 8)) + with tempfile.TemporaryDirectory() as tmp: + chgcar = Path(tmp) / "CHGCAR" + write_chgcar(density, atoms, chgcar) + # Maker should accept the written CHGCAR for ICHARG=1. + maker = StructureVASPSinglePoint(initial_chgcar=chgcar) + assert maker.initial_chgcar == chgcar diff --git a/tests/test_salted_model.py b/tests/test_salted_model.py new file mode 100644 index 0000000..a7895fe --- /dev/null +++ b/tests/test_salted_model.py @@ -0,0 +1,228 @@ +"""TDD tests for the SALTEDModel wrapper (PR gamma). + +The wrapper exposes ``__call__(atoms) -> coefficients`` so SALTED-style +predictions plug into the projection / reconstruction layer from PR beta. + +Locked contract: + +* ``SALTEDModel(basis_spec, ckpt_path=None)`` — construct. When + ``ckpt_path`` is None the wrapper produces deterministic + position-dependent stub coefficients (lets us run tests + the + reconstruction pipeline without a real rholearn checkpoint). 
+ +* ``model(atoms)`` returns ``np.ndarray (n_atoms, n_coeffs_per_atom)``, + float64, finite, deterministic for fixed inputs. + +* ``model.reconstruct_density(atoms, grid_shape)`` returns the density + grid in the same shape ``reconstruct_grid_from_basis`` would have + produced from the predicted coefficients. Convenience method for the + VASP comparison pipeline. + +* Metric integration: the predicted density grid feeds into + ``compute_nmape`` / ``compute_rmse`` / ``compute_nrmse`` from + ``charge3net_ft.train`` and they return finite scalars. Pinned per the + brief: "Keep the metric calculations identical to our ChargE3Net pipeline." +""" + +from __future__ import annotations + +import ase +import numpy as np +import torch + + +def _cubic_atoms(symbols=("Fe",), fractional=((0.0, 0.0, 0.0),), a=4.0): + cell = np.eye(3) * a + cart = np.array(fractional) @ cell + return ase.Atoms(symbols=list(symbols), positions=cart, cell=cell, pbc=True) + + +class TestSALTEDModelConstruct: + def test_constructs_with_basis_spec(self): + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + spec = BasisSpec() + m = SALTEDModel(basis_spec=spec) + assert m.basis_spec is spec + + def test_default_ckpt_is_none(self): + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + m = SALTEDModel(basis_spec=BasisSpec()) + assert m.ckpt_path is None + + +class TestSALTEDModelOutputShape: + def test_single_atom_output_shape(self): + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + spec = BasisSpec() + m = SALTEDModel(basis_spec=spec) + coeffs = m(_cubic_atoms()) + assert coeffs.shape == (1, spec.n_coeffs_per_atom) + + def test_multi_atom_output_shape(self): + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + spec = BasisSpec() + m = SALTEDModel(basis_spec=spec) + atoms = _cubic_atoms( + symbols=("Fe", "O", "Fe"), + fractional=((0.0, 0.0, 0.0), (0.5, 0.5, 0.5), (0.25, 
0.25, 0.25)), + ) + coeffs = m(atoms) + assert coeffs.shape == (3, spec.n_coeffs_per_atom) + + def test_output_dtype_is_float64(self): + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + m = SALTEDModel(basis_spec=BasisSpec()) + coeffs = m(_cubic_atoms()) + assert coeffs.dtype == np.float64 + + def test_output_is_finite(self): + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + m = SALTEDModel(basis_spec=BasisSpec()) + coeffs = m(_cubic_atoms()) + assert np.isfinite(coeffs).all() + + +class TestSALTEDModelDeterminism: + def test_same_input_gives_same_output(self): + """Reproducibility: critical for CI + regression tests.""" + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + m = SALTEDModel(basis_spec=BasisSpec()) + atoms = _cubic_atoms( + symbols=("Fe", "Fe"), fractional=((0.1, 0.2, 0.3), (0.4, 0.5, 0.6)) + ) + c1 = m(atoms) + c2 = m(atoms) + np.testing.assert_array_equal(c1, c2) + + def test_different_positions_give_different_coefficients(self): + """A degenerate stub that always returned zeros would pass shape + + determinism but be useless. Require some position-dependent + variation so downstream tests have signal to work with. 
+ """ + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + m = SALTEDModel(basis_spec=BasisSpec()) + atoms_a = _cubic_atoms(fractional=((0.0, 0.0, 0.0),)) + atoms_b = _cubic_atoms(fractional=((0.5, 0.5, 0.5),)) + c_a = m(atoms_a) + c_b = m(atoms_b) + assert not np.allclose(c_a, c_b), ( + "predicted coefficients must depend on atom positions; the stub " + "appears to return position-independent constants" + ) + + +class TestSALTEDModelReconstructDensity: + def test_reconstruct_density_shape(self): + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + m = SALTEDModel(basis_spec=BasisSpec()) + grid = m.reconstruct_density(_cubic_atoms(), (8, 8, 8)) + assert grid.shape == (8, 8, 8) + + def test_reconstruct_density_matches_explicit_path(self): + """``model.reconstruct_density(atoms, shape)`` must equal calling + ``model(atoms)`` then ``reconstruct_grid_from_basis(c, ...)``. + Convenience method is just sugar. + """ + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + from salted_ft.projection import reconstruct_grid_from_basis + + spec = BasisSpec() + atoms = _cubic_atoms( + symbols=("Fe", "O"), fractional=((0.0, 0.0, 0.0), (0.5, 0.5, 0.5)) + ) + m = SALTEDModel(basis_spec=spec) + c = m(atoms) + expected = reconstruct_grid_from_basis(c, atoms, (8, 8, 8), spec) + got = m.reconstruct_density(atoms, (8, 8, 8)) + np.testing.assert_array_equal(got, expected) + + def test_reconstruct_density_dtype_and_finite(self): + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + m = SALTEDModel(basis_spec=BasisSpec()) + grid = m.reconstruct_density(_cubic_atoms(), (8, 8, 8)) + assert grid.dtype == np.float64 + assert np.isfinite(grid).all() + + +class TestMetricIntegration: + """Predicted density grid feeds the existing ChargE3Net metric functions.""" + + def _to_torch_batch(self, grid: np.ndarray) -> torch.Tensor: + """Flatten a (Nx, Ny, Nz) grid into a (B=1, 
N_probes) torch tensor. + + ChargE3Net's compute_nmape signature is (preds, targets, num_probes). + For full-grid evaluation we use B=1 and num_probes=None. + """ + return torch.from_numpy(grid.astype(np.float32).reshape(1, -1)) + + def test_compute_nmape_returns_finite_scalar(self): + from charge3net_ft.train import compute_nmape + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + atoms = _cubic_atoms(fractional=((0.5, 0.5, 0.5),)) + m = SALTEDModel(basis_spec=BasisSpec()) + preds = self._to_torch_batch(m.reconstruct_density(atoms, (8, 8, 8))) + # Synthetic target: same shape, non-zero so the NMAPE denominator is positive + targets = torch.ones_like(preds) + nmape = compute_nmape(preds, targets, num_probes=None) + assert nmape.numel() == 1 + assert torch.isfinite(nmape).all() + + def test_compute_rmse_returns_finite_scalar(self): + from charge3net_ft.train import compute_rmse + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + atoms = _cubic_atoms(fractional=((0.5, 0.5, 0.5),)) + m = SALTEDModel(basis_spec=BasisSpec()) + preds = self._to_torch_batch(m.reconstruct_density(atoms, (8, 8, 8))) + targets = torch.ones_like(preds) + rmse = compute_rmse(preds, targets, num_probes=None) + assert torch.isfinite(rmse).all() + + def test_compute_nrmse_returns_finite_scalar(self): + from charge3net_ft.train import compute_nrmse + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + atoms = _cubic_atoms(fractional=((0.5, 0.5, 0.5),)) + m = SALTEDModel(basis_spec=BasisSpec()) + preds = self._to_torch_batch(m.reconstruct_density(atoms, (8, 8, 8))) + targets = torch.ones_like(preds) + nrmse = compute_nrmse(preds, targets, num_probes=None) + assert torch.isfinite(nrmse).all() + + def test_perfect_prediction_gives_zero_nmape(self): + """Sanity check: NMAPE of a tensor against itself is zero.""" + from charge3net_ft.train import compute_nmape + from salted_ft.basis import BasisSpec + 
from salted_ft.model import SALTEDModel + + atoms = _cubic_atoms(fractional=((0.5, 0.5, 0.5),)) + m = SALTEDModel(basis_spec=BasisSpec()) + preds = self._to_torch_batch(m.reconstruct_density(atoms, (8, 8, 8))) + # Self-similarity: target identical to prediction => zero error. + nmape = compute_nmape(preds, preds.clone(), num_probes=None) + assert nmape.item() == 0.0 diff --git a/tests/test_salted_project_dataset.py b/tests/test_salted_project_dataset.py new file mode 100644 index 0000000..f9c8253 --- /dev/null +++ b/tests/test_salted_project_dataset.py @@ -0,0 +1,236 @@ +"""TDD tests for the Phase D2 dataset-projection module. + +Locks the contract for ``salted_ft.project_dataset.project_chunk``, +which reads a LeMat-Rho-format parquet chunk, runs +``project_chgcar_to_basis`` row by row, and writes a parallel parquet +chunk of projected coefficients. + +Output schema per row:: + + { + "row_index": int (matches the original chunk row index), + "material_id": str (carried through if present, else "" ), + "n_atoms": int, + "atomic_numbers": list[int], + "lattice_vectors": list[list[float]], # 3x3 + "n_electrons": float (integrated density * cell_volume / n_grid), + "grid_shape": list[int], # [Nx, Ny, Nz] + "coefficients": list[list[float]], # (n_atoms, n_coeffs_per_atom) + "basis_set_NMAPE": float (basis-ceiling NMAPE for this row), + } + +The basis_set_NMAPE column is the per-row reconstruction error from +roundtripping; we keep it so downstream sanity-checks can know each +sample's basis ceiling. 
+""" + +from __future__ import annotations + +import json +import tempfile +from pathlib import Path + +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq + + +def _write_synthetic_chunk(path: Path, n_rows: int = 3) -> None: + """Write a LeMat-Rho-format chunk for use by the projection script.""" + rng = np.random.default_rng(42) + grids = [ + json.dumps(rng.random((10, 10, 10), dtype=np.float64).tolist()) + for _ in range(n_rows) + ] + table = pa.table( + { + "compressed_charge_density": pa.array(grids, type=pa.string()), + "species_at_sites": pa.array([["Fe"]] * n_rows), + "cartesian_site_positions": pa.array([[[2.0, 2.0, 2.0]]] * n_rows), + "lattice_vectors": pa.array( + [[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]] * n_rows + ), + # extras: confirm they get ignored + "bader_charges": pa.array([[0.4]] * n_rows), + "material_id": pa.array([f"mat_{i:03d}" for i in range(n_rows)]), + } + ) + pq.write_table(table, path) + + +class TestProjectChunkContract: + """``project_chunk(in_path, out_path, basis_spec)`` -> None. + + Reads ``in_path`` (LeMat-Rho format), projects each row, writes + ``out_path`` in the schema documented at the top of this file. 
+ """ + + def test_output_file_written(self): + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_chunk + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "in.parquet", n_rows=2) + out = d / "out.parquet" + project_chunk(d / "in.parquet", out, BasisSpec()) + assert out.exists() + assert out.stat().st_size > 0 + + def test_row_count_matches_valid_input(self): + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_chunk + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "in.parquet", n_rows=3) + out = d / "out.parquet" + project_chunk(d / "in.parquet", out, BasisSpec()) + t = pq.read_table(out) + assert len(t) == 3 + + def test_required_columns_present(self): + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_chunk + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "in.parquet", n_rows=2) + out = d / "out.parquet" + project_chunk(d / "in.parquet", out, BasisSpec()) + t = pq.read_table(out) + required = { + "row_index", + "material_id", + "n_atoms", + "atomic_numbers", + "lattice_vectors", + "n_electrons", + "grid_shape", + "coefficients", + "basis_set_NMAPE", + } + missing = required - set(t.column_names) + assert not missing, f"missing required columns: {missing}" + + def test_coefficient_shape_per_row(self): + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_chunk + + spec = BasisSpec() + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "in.parquet", n_rows=2) + out = d / "out.parquet" + project_chunk(d / "in.parquet", out, spec) + t = pq.read_table(out).to_pydict() + for c, n_atoms in zip(t["coefficients"], t["n_atoms"], strict=True): + # Each row has its own coefficient block; first dim is n_atoms, + # second is n_coeffs_per_atom. 
+ arr = np.asarray(c) + assert arr.shape == (n_atoms, spec.n_coeffs_per_atom), ( + f"row coefficient shape mismatch: got {arr.shape}, " + f"expected ({n_atoms}, {spec.n_coeffs_per_atom})" + ) + + def test_basis_set_NMAPE_is_finite_and_nonnegative(self): + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_chunk + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "in.parquet", n_rows=3) + out = d / "out.parquet" + project_chunk(d / "in.parquet", out, BasisSpec()) + t = pq.read_table(out).to_pydict() + for x in t["basis_set_NMAPE"]: + assert np.isfinite(x) + assert x >= 0.0 + + def test_material_id_preserved(self): + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_chunk + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "in.parquet", n_rows=3) + out = d / "out.parquet" + project_chunk(d / "in.parquet", out, BasisSpec()) + t = pq.read_table(out).to_pydict() + assert t["material_id"] == ["mat_000", "mat_001", "mat_002"] + + def test_handles_null_charge_density_rows(self): + """Real LeMat-Rho chunks have some rows with NULL density (failed + DFT extraction). Those should be skipped, not crash the projection. 
+ """ + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_chunk + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + grids = [ + json.dumps(np.ones((10, 10, 10)).tolist()), + None, # null density - should be skipped + json.dumps(np.ones((10, 10, 10)).tolist()), + ] + table = pa.table( + { + "compressed_charge_density": pa.array(grids, type=pa.string()), + "species_at_sites": pa.array([["Fe"]] * 3), + "cartesian_site_positions": pa.array([[[2.0, 2.0, 2.0]]] * 3), + "lattice_vectors": pa.array( + [[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]] * 3 + ), + "material_id": pa.array(["a", "b", "c"]), + } + ) + pq.write_table(table, d / "in.parquet") + out = d / "out.parquet" + project_chunk(d / "in.parquet", out, BasisSpec()) + t = pq.read_table(out).to_pydict() + assert len(t["row_index"]) == 2 + assert t["row_index"] == [0, 2] + + +class TestProjectDirectory: + """Driver that runs project_chunk over every chunk_*.parquet in a dir.""" + + def test_processes_all_chunks(self): + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_directory + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + in_d = d / "in" + in_d.mkdir() + out_d = d / "out" + for i in range(3): + _write_synthetic_chunk(in_d / f"chunk_{i:06d}.parquet", n_rows=2) + project_directory(in_d, out_d, BasisSpec()) + outputs = sorted(out_d.glob("chunk_*.parquet")) + assert len(outputs) == 3 + for out in outputs: + assert pq.read_table(out).num_rows == 2 + + def test_skips_existing_outputs(self): + """Idempotent: a re-run does not re-project chunks that already exist. + + Lets us resume a partially-completed projection job after an + interruption without paying the LSQR cost again. 
+ """ + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_directory + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + in_d = d / "in" + in_d.mkdir() + out_d = d / "out" + _write_synthetic_chunk(in_d / "chunk_000000.parquet", n_rows=2) + # First run + project_directory(in_d, out_d, BasisSpec()) + first_mtime = (out_d / "chunk_000000.parquet").stat().st_mtime + # Second run should be a no-op + project_directory(in_d, out_d, BasisSpec()) + second_mtime = (out_d / "chunk_000000.parquet").stat().st_mtime + assert first_mtime == second_mtime diff --git a/tests/test_salted_projection.py b/tests/test_salted_projection.py new file mode 100644 index 0000000..fc79801 --- /dev/null +++ b/tests/test_salted_projection.py @@ -0,0 +1,225 @@ +"""TDD tests for VASP CHGCAR <-> SALTED basis projection / reconstruction. + +These two operations are the DIY bridge layer between VASP plane-wave +densities and the rholearn/SALTED localized-basis world (see the +``plan_salted_graph2mat_basis_choice_may_20_pm.md`` memo for context). + +Locked contracts here: + +* ``project_chgcar_to_basis(density, atoms, basis_spec)`` + -> ``np.ndarray (n_atoms, n_coeffs_per_atom)`` float64. + Zero density gives zero coefficients. Linear in the input density. + +* ``reconstruct_grid_from_basis(coefficients, atoms, grid_shape, basis_spec)`` + -> ``np.ndarray (Nx, Ny, Nz)`` float64. + Zero coefficients give a zero grid. Linear in the coefficients. + A single-atom, l=0, n=0 unit coefficient produces a Gaussian peaked + at the atom position. + +The roundtrip is intentionally NOT pinned to high accuracy in this PR. +A simple orthonormal-approximation projection is enough to land the +contract; a future PR will swap in least-squares solving against the +full basis overlap matrix for tight roundtrip accuracy. 
+""" + +from __future__ import annotations + +import ase +import numpy as np + + +# --------------------------------------------------------------------------- +# Helpers — small synthetic structures so tests stay fast and inspectable. +# --------------------------------------------------------------------------- +def _cubic_atoms(symbols=("Fe",), fractional=((0.0, 0.0, 0.0),), a=4.0): + """Single-cell ase.Atoms with the requested species/positions in fractional coords.""" + cell = np.eye(3) * a + cart = np.array(fractional) @ cell + return ase.Atoms(symbols=list(symbols), positions=cart, cell=cell, pbc=True) + + +def _zero_grid(shape=(8, 8, 8)) -> np.ndarray: + return np.zeros(shape, dtype=np.float32) + + +def _random_grid(shape=(8, 8, 8), seed: int = 0) -> np.ndarray: + rng = np.random.default_rng(seed) + return rng.random(shape, dtype=np.float32) + + +# --------------------------------------------------------------------------- +# Projection: density grid -> coefficients +# --------------------------------------------------------------------------- +class TestProjectChgcarToBasis: + def test_output_shape_is_n_atoms_by_n_coeffs(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import project_chgcar_to_basis + + spec = BasisSpec() + atoms = _cubic_atoms( + symbols=("Fe", "Fe"), fractional=((0.0, 0.0, 0.0), (0.5, 0.5, 0.5)) + ) + coeffs = project_chgcar_to_basis(_zero_grid(), atoms, spec) + assert coeffs.shape == (2, spec.n_coeffs_per_atom) + + def test_zero_density_gives_zero_coefficients(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import project_chgcar_to_basis + + coeffs = project_chgcar_to_basis(_zero_grid(), _cubic_atoms(), BasisSpec()) + np.testing.assert_array_equal(coeffs, 0.0) + + def test_output_dtype_is_float64(self): + """float64 because we'll feed these to scipy/least-squares downstream.""" + from salted_ft.basis import BasisSpec + from salted_ft.projection import project_chgcar_to_basis + + 
coeffs = project_chgcar_to_basis(_random_grid(), _cubic_atoms(), BasisSpec()) + assert coeffs.dtype == np.float64 + + def test_linearity_in_density(self): + """project(alpha * rho) == alpha * project(rho); a basic sanity check + since both projection and reconstruction must be linear maps. + """ + from salted_ft.basis import BasisSpec + from salted_ft.projection import project_chgcar_to_basis + + atoms = _cubic_atoms() + spec = BasisSpec() + rho = _random_grid(seed=1) + c1 = project_chgcar_to_basis(rho, atoms, spec) + c_scaled = project_chgcar_to_basis(2.5 * rho, atoms, spec) + np.testing.assert_allclose(c_scaled, 2.5 * c1, rtol=1e-5, atol=1e-8) + + def test_additivity_in_density(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import project_chgcar_to_basis + + atoms = _cubic_atoms() + spec = BasisSpec() + rho1 = _random_grid(seed=2) + rho2 = _random_grid(seed=3) + c1 = project_chgcar_to_basis(rho1, atoms, spec) + c2 = project_chgcar_to_basis(rho2, atoms, spec) + c_sum = project_chgcar_to_basis(rho1 + rho2, atoms, spec) + np.testing.assert_allclose(c_sum, c1 + c2, rtol=1e-5, atol=1e-8) + + def test_output_is_finite(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import project_chgcar_to_basis + + coeffs = project_chgcar_to_basis(_random_grid(), _cubic_atoms(), BasisSpec()) + assert np.isfinite(coeffs).all() + + +# --------------------------------------------------------------------------- +# Reconstruction: coefficients -> density grid +# --------------------------------------------------------------------------- +class TestReconstructGridFromBasis: + def test_output_shape_matches_grid_shape(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import reconstruct_grid_from_basis + + spec = BasisSpec() + atoms = _cubic_atoms() + coeffs = np.zeros((1, spec.n_coeffs_per_atom)) + grid = reconstruct_grid_from_basis(coeffs, atoms, (8, 8, 8), spec) + assert grid.shape == (8, 8, 8) + + def 
test_zero_coefficients_give_zero_grid(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import reconstruct_grid_from_basis + + spec = BasisSpec() + atoms = _cubic_atoms() + coeffs = np.zeros((1, spec.n_coeffs_per_atom)) + grid = reconstruct_grid_from_basis(coeffs, atoms, (8, 8, 8), spec) + np.testing.assert_array_equal(grid, 0.0) + + def test_output_dtype_is_float64(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import reconstruct_grid_from_basis + + spec = BasisSpec() + atoms = _cubic_atoms() + rng = np.random.default_rng(4) + coeffs = rng.standard_normal((1, spec.n_coeffs_per_atom)) + grid = reconstruct_grid_from_basis(coeffs, atoms, (8, 8, 8), spec) + assert grid.dtype == np.float64 + + def test_linearity_in_coefficients(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import reconstruct_grid_from_basis + + spec = BasisSpec() + atoms = _cubic_atoms() + rng = np.random.default_rng(5) + c = rng.standard_normal((1, spec.n_coeffs_per_atom)) + g1 = reconstruct_grid_from_basis(c, atoms, (8, 8, 8), spec) + g_scaled = reconstruct_grid_from_basis(3.0 * c, atoms, (8, 8, 8), spec) + np.testing.assert_allclose(g_scaled, 3.0 * g1, rtol=1e-5, atol=1e-8) + + def test_single_atom_l0_n0_peaks_at_atom_position(self): + """Unit s-coefficient on the first radial channel: density should peak + at the atom position (not somewhere else in the cell).""" + from salted_ft.basis import BasisSpec + from salted_ft.projection import reconstruct_grid_from_basis + + spec = BasisSpec() + # Atom at the (0.5, 0.5, 0.5) interior point, away from cell edges. + atoms = _cubic_atoms(fractional=((0.5, 0.5, 0.5),), a=4.0) + coeffs = np.zeros((1, spec.n_coeffs_per_atom)) + coeffs[0, 0] = 1.0 # l=0, m=0, n=0 (the most localized s channel) + grid = reconstruct_grid_from_basis(coeffs, atoms, (16, 16, 16), spec) + + # Peak index in (i, j, k) integer grid should be near the center. 
+ peak_idx = np.unravel_index(np.argmax(grid), grid.shape) + center = (8, 8, 8) # fractional 0.5 on a 16-point grid + for actual, expected in zip(peak_idx, center, strict=True): + assert abs(actual - expected) <= 1, ( + f"density peak {peak_idx} is far from atom (expected near {center}); " + "either the atom-position lookup or the basis evaluation is wrong" + ) + + def test_output_is_finite(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import reconstruct_grid_from_basis + + spec = BasisSpec() + atoms = _cubic_atoms() + rng = np.random.default_rng(6) + coeffs = rng.standard_normal((1, spec.n_coeffs_per_atom)) + grid = reconstruct_grid_from_basis(coeffs, atoms, (8, 8, 8), spec) + assert np.isfinite(grid).all() + + +# --------------------------------------------------------------------------- +# Roundtrip: project then reconstruct (and vice versa). +# --------------------------------------------------------------------------- +class TestProjectionReconstructionRoundtrip: + def test_roundtrip_of_zero_density_is_zero(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import ( + project_chgcar_to_basis, + reconstruct_grid_from_basis, + ) + + atoms = _cubic_atoms() + spec = BasisSpec() + coeffs = project_chgcar_to_basis(_zero_grid(), atoms, spec) + roundtrip = reconstruct_grid_from_basis(coeffs, atoms, (8, 8, 8), spec) + np.testing.assert_array_equal(roundtrip, 0.0) + + def test_roundtrip_of_zero_coefficients_is_zero(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import ( + project_chgcar_to_basis, + reconstruct_grid_from_basis, + ) + + atoms = _cubic_atoms() + spec = BasisSpec() + c = np.zeros((1, spec.n_coeffs_per_atom)) + grid = reconstruct_grid_from_basis(c, atoms, (8, 8, 8), spec) + c_back = project_chgcar_to_basis(grid, atoms, spec) + np.testing.assert_array_equal(c_back, 0.0) diff --git a/tests/test_salted_rholearn_adapter.py b/tests/test_salted_rholearn_adapter.py new file mode 100644 
index 0000000..4e77775 --- /dev/null +++ b/tests/test_salted_rholearn_adapter.py @@ -0,0 +1,235 @@ +"""TDD tests for the SALTED -> rholearn data adapter (Phase D3). + +rholearn's training loop consumes basis-coefficient vectors in a +specific flat layout (see ``rholearn/utils/convert.py::_get_flat_index``): + + atom (outer) -> o3_lambda -> n (radial, INNER to lambda) -> o3_mu (innermost) + +Our ``salted_ft.projection`` layout differs: + + atom (outer) -> n (radial, OUTER to lambda) -> (lambda, mu) packed + +The adapter functions tested here move between the two layouts and +produce the ``lmax`` / ``nmax`` dicts rholearn's metatensor converter +needs in order to know the basis shape. +""" + +from __future__ import annotations + +import numpy as np +import pytest + + +# --------------------------------------------------------------------------- +# rholearn's lmax / nmax dict format (from rholearn/utils/convert.py docstrings) +# +# lmax = {"H": 1, "C": 2} per-species max lambda +# nmax = {("H", 0): 2, ("H", 1): 3, ("C", 0): 4, ...} per-species per-lambda n +# +# Our uniform BasisSpec has max_l and n_radial constant across species. The +# adapter expands that into rholearn's per-species dicts so the same basis +# spec works for arbitrary species sets. 
+# --------------------------------------------------------------------------- + + +class TestBuildLmaxNmaxDicts: + """Convert our uniform BasisSpec into rholearn's per-species dicts.""" + + def test_lmax_contains_every_species(self): + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import build_lmax_nmax + + lmax, nmax = build_lmax_nmax(BasisSpec(), species=("H", "O", "Fe")) + assert set(lmax) == {"H", "O", "Fe"} + + def test_lmax_value_matches_basis_spec(self): + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import build_lmax_nmax + + spec = BasisSpec() + lmax, _ = build_lmax_nmax(spec, species=("Fe",)) + assert lmax["Fe"] == spec.max_l + + def test_nmax_keyed_by_species_and_lambda(self): + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import build_lmax_nmax + + spec = BasisSpec() + _, nmax = build_lmax_nmax(spec, species=("H", "Fe")) + # Both species share the same n_radial at every lambda + for s in ("H", "Fe"): + for lam in range(spec.max_l + 1): + assert nmax[(s, lam)] == spec.n_radial, ( + f"nmax[({s!r}, {lam})] must be {spec.n_radial}, " + f"got {nmax[(s, lam)]}" + ) + + def test_total_per_atom_coefficients_match(self): + """Sum of ``(2*l + 1) * nmax[(s, l)]`` across l must equal + ``BasisSpec.n_coeffs_per_atom``. If this drifts the flat vector + produced by the adapter will be the wrong length. + """ + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import build_lmax_nmax + + spec = BasisSpec() + lmax, nmax = build_lmax_nmax(spec, species=("Fe",)) + total = sum((2 * lam + 1) * nmax[("Fe", lam)] for lam in range(lmax["Fe"] + 1)) + assert total == spec.n_coeffs_per_atom + + +# --------------------------------------------------------------------------- +# Reordering: our (atom, n_outer, lm_packed) <-> rholearn (atom, l, n, mu). +# Pure ndarray math, no metatensor required. 
+# --------------------------------------------------------------------------- + + +class TestDenseToRholearnFlat: + """``dense_to_rholearn_flat(coeffs, basis_spec, symbols) -> np.ndarray``.""" + + def test_output_length_matches_total_basis(self): + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import dense_to_rholearn_flat + + spec = BasisSpec() + atoms = ("Fe", "Fe") + coeffs = np.zeros((2, spec.n_coeffs_per_atom)) + flat = dense_to_rholearn_flat(coeffs, spec, atoms) + assert flat.shape == (2 * spec.n_coeffs_per_atom,) + + def test_zero_in_gives_zero_out(self): + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import dense_to_rholearn_flat + + spec = BasisSpec() + flat = dense_to_rholearn_flat( + np.zeros((1, spec.n_coeffs_per_atom)), spec, ("Fe",) + ) + np.testing.assert_array_equal(flat, 0.0) + + def test_dtype_preserved(self): + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import dense_to_rholearn_flat + + spec = BasisSpec() + rng = np.random.default_rng(0) + coeffs = rng.standard_normal((1, spec.n_coeffs_per_atom)).astype(np.float64) + flat = dense_to_rholearn_flat(coeffs, spec, ("Fe",)) + assert flat.dtype == np.float64 + + def test_concatenates_atoms_in_order(self): + """Per-atom blocks must appear in input order (atom 0 first, then 1, ...).""" + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import dense_to_rholearn_flat + + spec = BasisSpec() + # Use distinguishable per-atom values + coeffs = np.zeros((2, spec.n_coeffs_per_atom)) + coeffs[0, :] = 1.0 + coeffs[1, :] = 2.0 + flat = dense_to_rholearn_flat(coeffs, spec, ("Fe", "Fe")) + per_atom = spec.n_coeffs_per_atom + assert np.allclose(flat[:per_atom], 1.0) + assert np.allclose(flat[per_atom:], 2.0) + + +class TestRoundtrip: + """dense -> rholearn-flat -> dense must be exactly the identity.""" + + def test_roundtrip_random_single_atom(self): + from salted_ft.basis import BasisSpec + from 
salted_ft.rholearn_adapter import ( + dense_to_rholearn_flat, + rholearn_flat_to_dense, + ) + + spec = BasisSpec() + rng = np.random.default_rng(1) + coeffs = rng.standard_normal((1, spec.n_coeffs_per_atom)) + flat = dense_to_rholearn_flat(coeffs, spec, ("Fe",)) + restored = rholearn_flat_to_dense(flat, spec, ("Fe",)) + np.testing.assert_array_equal(restored, coeffs) + + def test_roundtrip_random_multi_atom(self): + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import ( + dense_to_rholearn_flat, + rholearn_flat_to_dense, + ) + + spec = BasisSpec() + rng = np.random.default_rng(2) + symbols = ("Fe", "O", "Fe", "H") + coeffs = rng.standard_normal((len(symbols), spec.n_coeffs_per_atom)) + flat = dense_to_rholearn_flat(coeffs, spec, symbols) + restored = rholearn_flat_to_dense(flat, spec, symbols) + np.testing.assert_array_equal(restored, coeffs) + + def test_permutation_is_actually_nontrivial(self): + """The reordering must MOVE values around -- if dense -> flat were + the identity that would mean we'd silently fed misordered data to + rholearn. Pinning this catches a future 'simplification' that + accidentally drops the permutation. + """ + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import dense_to_rholearn_flat + + spec = BasisSpec() + # Distinguishable per-channel values via arange + coeffs = np.arange(spec.n_coeffs_per_atom, dtype=np.float64).reshape( + 1, spec.n_coeffs_per_atom + ) + flat = dense_to_rholearn_flat(coeffs, spec, ("Fe",)) + # rholearn's ordering is atom -> lambda -> n -> mu; ours is + # atom -> n -> lambda -> mu. So flat[0] is c[atom=0, lambda=0, n=0, mu=0] + # which in OUR layout is at position [n=0, lm=0] = 0. So flat[0] == 0. + # Then n advances before lambda: flat[1] is c[atom=0, lambda=0, n=1, mu=0] + # (assuming n_radial >= 2 and max_l >= 1), which in OUR layout is at + # [n=1, lm=0], one full (lambda, mu) block later, not 1. The same mismatch + # holds at flat[3]: whichever (lambda, n, mu) lands there is not our index 3. 
+ # So flat[3] != coeffs[0, 3] is the load-bearing check. + assert flat[3] != coeffs[0, 3], ( + "ordering is trivial; the reordering should move values around" + ) + + +# --------------------------------------------------------------------------- +# Smoke test for the full TensorMap path. Heavier dependency on metatensor +# but the test is short. +# --------------------------------------------------------------------------- + + +class TestDenseToTensorMap: + """``dense_to_tensormap`` returns a metatensor TensorMap with the right keys. + + Requires the rholearn sibling repo at ``../rholearn/`` (auto-skips + when absent). On Adastra (where rholearn IS installed) this test + activates and exercises the full conversion path. + """ + + def test_tensormap_has_o3_lambda_center_type_keys(self): + pytest.importorskip("metatensor") + pytest.importorskip("chemfiles") + + from pathlib import Path + + if not (Path(__file__).resolve().parent.parent.parent / "rholearn").exists(): + pytest.skip("rholearn sibling repo not present; skipping live conversion") + + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import dense_to_tensormap + + spec = BasisSpec() + positions = np.array([[0.0, 0.0, 0.0], [2.0, 2.0, 2.0]]) + cell = np.eye(3) * 4.0 + symbols = ("Fe", "Fe") + rng = np.random.default_rng(3) + coeffs = rng.standard_normal((2, spec.n_coeffs_per_atom)) + tmap = dense_to_tensormap( + coeffs, spec, symbols, positions, cell, structure_idx=0 + ) + # Keys must contain ``o3_lambda`` and ``center_type`` per rholearn's + # convention (see rholearn/utils/convert.py docstrings). + names = list(tmap.keys.names) + assert "o3_lambda" in names + assert "center_type" in names diff --git a/tests/test_submit_script.py b/tests/test_submit_script.py new file mode 100644 index 0000000..419fd38 --- /dev/null +++ b/tests/test_submit_script.py @@ -0,0 +1,184 @@ +"""TDD tests for the parameterized Adastra submit script. 
+ +The script `submit_charge3net_adastra.sh` is now configurable via two env +vars: + + LEMATRHO_TRAINING_MODE "pretrained" (default) or "from_scratch" + LEMATRHO_DRY_RUN "1" prints the resolved train command and exits + +These tests pin the contract. + +They don't depend on Adastra. The script is sourced under bash with +LEMATRHO_DRY_RUN=1 so the venv activate / rocm-smi / srun calls are +skipped and the train invocation is printed instead of executed. +""" + +from __future__ import annotations + +import os +import shutil +import subprocess +from pathlib import Path + +import pytest + + +SUBMIT_SCRIPT = Path(__file__).resolve().parent.parent / "submit_charge3net_adastra.sh" + + +def _run(env_extra: dict) -> subprocess.CompletedProcess: + """Run the submit script under bash with LEMATRHO_DRY_RUN=1.""" + if shutil.which("bash") is None: + pytest.skip("bash not available in test environment") + env = { + **os.environ, + "LEMATRHO_DRY_RUN": "1", + # Avoid touching the user's real Adastra setup or W&B credentials. + "LEMATRHO_ADASTRA_SETUP": "/tmp/fake_setup_for_tests", + # SLURM env vars that the script would normally inherit. 
+ "SLURM_NTASKS": "4", + "SLURM_NODELIST": "g0001", + "SLURM_JOB_ACCOUNT": "c1816212_mi250", + **env_extra, + } + return subprocess.run( + ["bash", str(SUBMIT_SCRIPT)], + env=env, + capture_output=True, + text=True, + check=False, + ) + + +def test_dry_run_mode_prints_train_command(): + """LEMATRHO_DRY_RUN=1 must print the resolved train command and exit 0.""" + result = _run({}) + assert result.returncode == 0, ( + f"dry-run exited {result.returncode}; stderr={result.stderr}" + ) + assert "charge3net_ft.train" in result.stdout, ( + f"dry-run output missing the train invocation; stdout={result.stdout}" + ) + + +def test_default_mode_is_pretrained(): + """Unset LEMATRHO_TRAINING_MODE -> pretrained MP checkpoint path is used.""" + result = _run({}) + assert result.returncode == 0 + out = result.stdout + assert "--ckpt-path" in out, ( + f"default (pretrained) run must pass --ckpt-path; stdout={out}" + ) + assert "charge3net_mp.pt" in out, ( + f"default run must point --ckpt-path at the MP checkpoint; stdout={out}" + ) + + +def test_pretrained_mode_uses_default_save_dir(): + """Pretrained mode writes to charge3net_checkpoints/ (no fromscratch suffix).""" + result = _run({"LEMATRHO_TRAINING_MODE": "pretrained"}) + assert result.returncode == 0 + assert ( + "charge3net_checkpoints " in (result.stdout + " ") + or "charge3net_checkpoints\n" in result.stdout + or "/charge3net_checkpoints" in result.stdout + ) + assert "charge3net_checkpoints_fromscratch" not in result.stdout, ( + f"pretrained mode must NOT use the fromscratch save dir; stdout={result.stdout}" + ) + + +def test_from_scratch_mode_drops_ckpt_path(): + """LEMATRHO_TRAINING_MODE=from_scratch -> no --ckpt-path flag at all. + + Without --ckpt-path, ChargE3NetWrapper.__init__ initializes weights + fresh (no MP transfer). This is the comparison arm for the + pretrained vs from-scratch experiment. 
+ """ + result = _run({"LEMATRHO_TRAINING_MODE": "from_scratch"}) + assert result.returncode == 0, ( + f"from_scratch run exited {result.returncode}; stderr={result.stderr}" + ) + out = result.stdout + assert "--ckpt-path" not in out, ( + f"from_scratch must not pass --ckpt-path; stdout={out}" + ) + # also confirm charge3net_mp.pt isn't referenced anywhere in the + # resolved command (defense against accidental partial passing) + assert "charge3net_mp.pt" not in out, ( + f"from_scratch must not reference the MP checkpoint; stdout={out}" + ) + + +def test_from_scratch_mode_uses_separate_save_dir(): + """From-scratch run writes to a different dir so checkpoints don't collide + with the pretrained run. + """ + result = _run({"LEMATRHO_TRAINING_MODE": "from_scratch"}) + assert result.returncode == 0 + out = result.stdout + assert "charge3net_checkpoints_fromscratch" in out, ( + f"from_scratch must write to charge3net_checkpoints_fromscratch/; stdout={out}" + ) + + +def test_from_scratch_mode_uses_distinct_wandb_name(): + """W&B run name differs between the two modes so the dashboard tells them apart.""" + # WANDB_NAME is what wandb reads at init time when no --name is passed. + pretrained = _run({"LEMATRHO_TRAINING_MODE": "pretrained"}).stdout + fromscratch = _run({"LEMATRHO_TRAINING_MODE": "from_scratch"}).stdout + # Both must mention WANDB_NAME or set it somehow. + assert "WANDB_NAME" in pretrained or "wandb-run-name" in pretrained, ( + f"pretrained mode must set the wandb run name; stdout={pretrained}" + ) + assert "WANDB_NAME" in fromscratch or "wandb-run-name" in fromscratch, ( + f"from_scratch mode must set the wandb run name; stdout={fromscratch}" + ) + + # And they must differ. + # Extract WANDB_NAME value from each (simple regex-free parsing). 
+ def _wandb_name(blob: str) -> str: + for line in blob.splitlines(): + if "WANDB_NAME=" in line: + return line.split("WANDB_NAME=", 1)[1].split()[0].strip("'\"") + return "" + + p_name = _wandb_name(pretrained) + f_name = _wandb_name(fromscratch) + assert p_name and f_name and p_name != f_name, ( + f"WANDB_NAME must differ between modes; pretrained={p_name!r}, fromscratch={f_name!r}" + ) + + +def test_invalid_mode_exits_with_clear_error(): + """An unrecognized mode must fail fast with a helpful message.""" + result = _run({"LEMATRHO_TRAINING_MODE": "garbage"}) + assert result.returncode != 0, ( + f"invalid mode must exit non-zero; stdout={result.stdout} stderr={result.stderr}" + ) + combined = (result.stdout + " " + result.stderr).lower() + assert "garbage" in combined or "training_mode" in combined or "mode" in combined, ( + f"error message should mention the bad value or the env var; " + f"stdout={result.stdout} stderr={result.stderr}" + ) + + +def test_batch_size_and_val_probes_match_paper(): + """Regression test: per-GPU batch=16, val_probes=1000 match the upstream paper.""" + result = _run({}) + assert "--batch-size 16" in result.stdout, ( + f"per-GPU batch must be 16 (paper); stdout={result.stdout}" + ) + assert "--val-probes 1000" in result.stdout, ( + f"val_probes must be 1000 (paper); stdout={result.stdout}" + ) + + +def test_wandb_mode_is_offline(): + """W&B must default to offline; api.wandb.ai is unreachable from + Adastra compute nodes (caused job 4969727 to crash after 1h47m). 
+ """ + result = _run({}) + assert "--wandb-mode offline" in result.stdout, ( + f"wandb-mode must default to offline; stdout={result.stdout}" + ) diff --git a/uv.lock b/uv.lock index ace9d7d..0865408 100644 --- a/uv.lock +++ b/uv.lock @@ -502,6 +502,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/20/94/c5790835a017658cbfabd07f3bfb549140c3ac458cfc196323996b10095a/charset_normalizer-3.4.2-py3-none-any.whl", hash = "sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0", size = 52626, upload-time = "2025-05-02T08:34:40.053Z" }, ] +[[package]] +name = "chemfiles" +version = "0.10.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c2/51/35538663b6384add778945735478da66b7c3095649654325d001922f30f8/chemfiles-0.10.4.tar.gz", hash = "sha256:f9e5ece3fcc8b63fdc2708d4ecc2ba5862ae2ab6790447bffc10c1b34ef2f445", size = 3575412, upload-time = "2023-05-23T10:49:17.227Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/0d/e5a214dddec845c425cda2cb2273a95b2c5f77be9404d02c4f48b4e6992b/chemfiles-0.10.4-1-py2.py3-none-win_amd64.whl", hash = "sha256:5c1b50a7fd56d014f930e38a838c92098bd047a3e989ba4b89ff657c6d16e38a", size = 1129225, upload-time = "2023-05-24T15:02:46.683Z" }, + { url = "https://files.pythonhosted.org/packages/84/0e/409d1fe39dc24f3ac47dd384e78462fc4eb0435a169afe5b488cf6ded39b/chemfiles-0.10.4-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:10a4e641605db56321316310f620746db350691d7c9edc433fe2a65984e2278b", size = 1497588, upload-time = "2023-05-23T10:49:04.561Z" }, + { url = "https://files.pythonhosted.org/packages/78/5f/d7d7347db0d1a92577aa27d9412adea002295263d52cca57ff14c92cde56/chemfiles-0.10.4-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:626725b0ea907d995cbbba99df1d19c474f8ebecdea8d0d390b7f3eaf2c91039", size = 1350827, upload-time = "2023-05-23T10:49:07.125Z" }, + { url = 
"https://files.pythonhosted.org/packages/3a/d5/beb71f372e650ba75e3eac246a17daa09a08aeed46580b62af35234d01f2/chemfiles-0.10.4-py2.py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4dbf6fa7ad5b2a1ad1415fbca905ce3a02c71cc2aa7fbce18a2b7d13c01a3664", size = 1751189, upload-time = "2023-05-23T10:49:10.237Z" }, + { url = "https://files.pythonhosted.org/packages/50/4c/380de5755146e27236cdecf02b7fe5da4c1f3786716baee5b3a245026acb/chemfiles-0.10.4-py2.py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef8f2b9fa65885658088180bb33971d1337bc8542220c710d1f6f3c1a6d661d4", size = 1632279, upload-time = "2023-05-23T10:49:12.365Z" }, +] + [[package]] name = "click" version = "8.2.1" @@ -1048,6 +1064,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3a/5d/b645a1e7c71ba562cf31987ee7499f603b6b49f67ccab521b3b600f53a1e/gemmi-0.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:402a71c935cab167ac6a7a29045e47a972388ef6f62fa3f477d8b0241fe53d4e", size = 1928436, upload-time = "2025-03-24T19:20:03.183Z" }, ] +[[package]] +name = "gitdb" +version = "4.0.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "smmap" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684, upload-time = "2025-01-02T07:20:46.413Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794, upload-time = "2025-01-02T07:20:43.624Z" }, +] + +[[package]] +name = "gitpython" +version = "3.1.50" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gitdb" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/33/f6/354ae6491228b5eb40e10d89c4d13c651fe1cf7556e35ebdded50cff57ce/gitpython-3.1.50.tar.gz", hash = "sha256:80da2d12504d52e1f998772dc5baf6e553f8d2fcfe1fcc226c9d9a2ee3372dcc", size = 219798, upload-time = "2026-05-06T04:01:26.571Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/7a/1c6e3562dfd8950adbb11ffbc65d21e7c89d01a6e4f137fa981056de25c5/gitpython-3.1.50-py3-none-any.whl", hash = "sha256:d352abe2908d07355014abdd21ddf798c2a961469239afec4962e9da884858f9", size = 212507, upload-time = "2026-05-06T04:01:23.799Z" }, +] + [[package]] name = "gunicorn" version = "25.1.0" @@ -1381,28 +1421,38 @@ source = { virtual = "." } dependencies = [ { name = "ase" }, { name = "atomate2" }, + { name = "chemfiles" }, { name = "e3nn" }, { name = "fireworks" }, { name = "ipykernel" }, { name = "lz4" }, { name = "material-hasher" }, - { name = "pandas" }, + { name = "metatensor" }, + { name = "numpy" }, { name = "pyarrow" }, + { name = "python-dotenv" }, { name = "scipy" }, + { name = "torch" }, + { name = "wandb" }, ] [package.metadata] requires-dist = [ { name = "ase", specifier = ">=3.25.0" }, { name = "atomate2" }, - { name = "e3nn", specifier = ">=0.6.0" }, + { name = "chemfiles", specifier = ">=0.10.4" }, + { name = "e3nn", specifier = ">=0.5.0" }, { name = "fireworks" }, { name = "ipykernel", specifier = ">=6.29.5" }, - { name = "lz4", specifier = ">=4.4.5" }, + { name = "lz4", specifier = ">=4.0.0" }, { name = "material-hasher", git = "https://github.com/LeMaterial/lematerial-hasher" }, - { name = "pandas", specifier = ">=2.3.0" }, - { name = "pyarrow", specifier = ">=20.0.0" }, - { name = "scipy", specifier = ">=1.16.0" }, + { name = "metatensor", specifier = ">=0.2.0" }, + { name = "numpy", specifier = ">=1.24" }, + { name = "pyarrow", specifier = ">=14.0.0" }, + { name = "python-dotenv", specifier = ">=1.0.0" }, + { name = "scipy", specifier = ">=1.10.0" }, + { name = "torch", specifier = ">=2.0" }, + { name = 
"wandb", specifier = ">=0.16.0" }, ] [[package]] @@ -1639,6 +1689,61 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899, upload-time = "2024-04-15T13:44:43.265Z" }, ] +[[package]] +name = "metatensor" +version = "0.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "metatensor-core" }, + { name = "metatensor-learn" }, + { name = "metatensor-operations" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3e/58/172e96ccdca4d8d572579adc69b593dad79b74497c116ed86979257a5cbd/metatensor-0.2.0.tar.gz", hash = "sha256:ce3f8a34796d2aaa7e74b2d1392f64a05e85d1ca3e3878c1e9259e6a6a7a8138", size = 5373, upload-time = "2024-01-26T17:27:15.203Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f7/28/fd3f02ccb23764af794e953262127a7f2aed35073f460da6f279fe1c2b15/metatensor-0.2.0-py3-none-any.whl", hash = "sha256:60008fee73f49b349350d9d93dec63ea4e1cf30beceae17d543561d69a7ac393", size = 3702, upload-time = "2024-01-26T17:26:59.518Z" }, +] + +[[package]] +name = "metatensor-core" +version = "0.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/34/d5/18f05f73a0af0517dbbf441e673abf88bccfec6a92a1beeebbc9df9d5ed9/metatensor_core-0.2.0.tar.gz", hash = "sha256:30200451eb70e635fdef5dfd46476d0303b1757b1e34c23f9c9e568c9d188545", size = 177741, upload-time = "2026-05-13T15:45:51.837Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/99/4a81ad15c63b82be70e8e9ca1ae95b31b7c91d512b684c8a26fb0671a746/metatensor_core-0.2.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c5e82760244c7233c41547d6c015f38caf7f3af589e0a7f827cad4a0c0ef0bbf", size = 549924, upload-time = "2026-05-13T15:45:08.494Z" }, + { url = 
"https://files.pythonhosted.org/packages/f0/11/8cd0fea97a5be6793596f573bb2fabf5dfd00a67884f9c77e6c7331c3921/metatensor_core-0.2.0-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:286f477f96520c046dff35dbc3a40ac3cfdef540e1c7bc071e91769f68dbb8f8", size = 582626, upload-time = "2026-05-13T15:45:18.982Z" }, + { url = "https://files.pythonhosted.org/packages/b4/09/91e7f49401597f0858087a3e603f98bb78d900895510b799fa445e1a4a8e/metatensor_core-0.2.0-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f0529b6d3966fff6ad85e988443c2acf22d0251f52be38d4dce6fa4d617c0e81", size = 594606, upload-time = "2026-05-13T15:45:33.145Z" }, + { url = "https://files.pythonhosted.org/packages/bc/50/e090f6a2c56a6c822bac818ca5d900568a17df8ea6a2d1bf9f8d8cde9fc0/metatensor_core-0.2.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dbdf693cdb0436736e8d678e2d45dec5f8e47df18c7a4f775eb546c0106fb867", size = 634966, upload-time = "2026-05-13T15:45:39.684Z" }, + { url = "https://files.pythonhosted.org/packages/55/65/84df97b3922d50954644b06397e337e4a52da98ddd92f52a1532329d1378/metatensor_core-0.2.0-py3-none-win_amd64.whl", hash = "sha256:2b7dfc59c920b1d06dbebd2e7afa0a2395ea1ef01e437ad7a0e4d213f2034ce1", size = 533600, upload-time = "2026-05-13T15:45:44.907Z" }, +] + +[[package]] +name = "metatensor-learn" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "metatensor-core" }, + { name = "metatensor-operations" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/40/bd/0fd1901b44635a24f40528a6244b5889143747ddeb841ae0201255c1f22e/metatensor_learn-0.5.0.tar.gz", hash = "sha256:0b1d30ed217d70de7851ed1d48421515d9c6a1be7f50d9b1b43f92a689be51d0", size = 25221, upload-time = "2026-05-13T15:45:54.582Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/82/85/f8e2061c58cf4ea22681be48f5aecf0074abd9717fcb8f05dd3ea6e370fc/metatensor_learn-0.5.0-py3-none-any.whl", hash = 
"sha256:ad8863dac144f03c9ca80ec625c9e35b87ceb82438a0a80c0bf14e9dcc1b607c", size = 32888, upload-time = "2026-05-13T15:45:49.25Z" }, +] + +[[package]] +name = "metatensor-operations" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "metatensor-core" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3a/98/83e132e8aca5bc05ffaffd342566ab4abd8e7bb579de6df1fde8b8602abb/metatensor_operations-0.5.0.tar.gz", hash = "sha256:e1cb0a8c358842e94ac3680fa9ec6f7a006cb519b6950ed1bb7001a209087cfc", size = 57735, upload-time = "2026-05-13T15:45:53.568Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/31/18d10b7d6d2ef5829a33c52cb0148730a951f0b3ad13aac5c4fae510ccfd/metatensor_operations-0.5.0-py3-none-any.whl", hash = "sha256:9536562c9e02a5c723fc118be671e8ff37e8e69caf2dc4a2bd97fca5271ec510", size = 79354, upload-time = "2026-05-13T15:45:47.855Z" }, +] + [[package]] name = "mongomock" version = "4.3.0" @@ -2414,6 +2519,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cc/35/cc0aaecf278bb4575b8555f2b137de5ab821595ddae9da9d3cd1da4072c7/propcache-0.3.2-py3-none-any.whl", hash = "sha256:98f1ec44fb675f5052cccc8e609c46ed23a35a1cfd18545ad4e29002d858a43f", size = 12663, upload-time = "2025-06-09T22:56:04.484Z" }, ] +[[package]] +name = "protobuf" +version = "7.35.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/60/fd/5b1491d9e4b586d621c54f4c36b888714164b6875f8d6afa3f9072906a51/protobuf-7.35.0.tar.gz", hash = "sha256:a2efd84605f41e559f1881b0912b44099d0a2ac9bf46b3474823f10fb393b0e6", size = 458677, upload-time = "2026-05-19T23:02:29.197Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/ee/93d06e358a4aa32280b00e722d3ea0a1f25fc3cc5778d80581c9cca2c10e/protobuf-7.35.0-cp310-abi3-macosx_10_9_universal2.whl", hash = "sha256:66be6c513931c794fa92c080ffee41671390da3d79da219cf9c0c0907f035dda", size = 433225, upload-time = 
"2026-05-19T23:02:19.884Z" }, + { url = "https://files.pythonhosted.org/packages/8b/39/1c76c2da93f3c507e958e0aecee2391cc44d4625de6c728bbc555195b5a8/protobuf-7.35.0-cp310-abi3-manylinux2014_aarch64.whl", hash = "sha256:fcbe42a4ac09d3ec9c987ddfcd956afd0b15f1ff613bd8371bde9405ffd5c8e5", size = 328847, upload-time = "2026-05-19T23:02:22.3Z" }, + { url = "https://files.pythonhosted.org/packages/91/1a/39f7ce90a238c1a987a4d81ec26379e02ca0aff367de68e4a1fa474215b9/protobuf-7.35.0-cp310-abi3-manylinux2014_s390x.whl", hash = "sha256:4cbf5cc286130e06a6c9bbefac442431173906dfcc979712183d4adcc01b37ee", size = 344030, upload-time = "2026-05-19T23:02:23.591Z" }, + { url = "https://files.pythonhosted.org/packages/70/5b/6baf9008817964454055ff3fe65f1de0b5f1e26c80c82f7fb108b7cd4ea3/protobuf-7.35.0-cp310-abi3-manylinux2014_x86_64.whl", hash = "sha256:6c0f98f10c8a05ea30f8993dfef2de093d27b490fdae78bb60c8343795d55011", size = 327130, upload-time = "2026-05-19T23:02:24.637Z" }, + { url = "https://files.pythonhosted.org/packages/8e/e5/e46adb0badc388bfb84877a5f9f026aff63f60e611016cf64dbe77e05446/protobuf-7.35.0-cp310-abi3-win32.whl", hash = "sha256:4c4617b83ade0e279d1d2bfe04025a1adb87f9ed657de038620dc0ff959357f6", size = 428946, upload-time = "2026-05-19T23:02:25.741Z" }, + { url = "https://files.pythonhosted.org/packages/a7/ab/547fbd9e16d879dd13c167478f8ae0a83a428008ca07a5e06acdc23ad473/protobuf-7.35.0-cp310-abi3-win_amd64.whl", hash = "sha256:f05bcadf9a2a6b8dda047007075135fb7d08c73d9177aabc067e1be46881a201", size = 439996, upload-time = "2026-05-19T23:02:26.808Z" }, + { url = "https://files.pythonhosted.org/packages/b8/ef/50433d346c56657a70d27f156c7b349ac59a068b01de4eb796e747eecc43/protobuf-7.35.0-py3-none-any.whl", hash = "sha256:c13f325cf242bad135c350629eeb5d54b24228eb472fb3e2e9ebbd4c5dc20ca0", size = 171659, upload-time = "2026-05-19T23:02:27.842Z" }, +] + [[package]] name = "psutil" version = "7.0.0" @@ -3199,6 +3319,19 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/49/65/dea992c6a97074f6d8ff9eab34741298cac2ce23e2b6c74fb7d08afdf85c/sentinels-1.1.1-py3-none-any.whl", hash = "sha256:835d3b28f3b47f5284afa4bf2db6e00f2dc5f80f9923d4b7e7aeeeccf6146a11", size = 3744, upload-time = "2025-08-12T07:57:48.858Z" }, ] +[[package]] +name = "sentry-sdk" +version = "2.60.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/54/a2/2e6c090db384cc515069f4f85542bd5baf6786852073020ea73d4a76d3ea/sentry_sdk-2.60.0.tar.gz", hash = "sha256:0bd25e54e78ca02d0be512529fa644bbbf9e8470d7b26371294012d4ca93c978", size = 452946, upload-time = "2026-05-13T13:34:52.516Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/41/f2b800b7f12a05dd48c2a6280d4dd812d1425fc66ed3fe3fd99420c41d1a/sentry_sdk-2.60.0-py3-none-any.whl", hash = "sha256:28a536c03291c8bcb363cf35c611b32738ec118ff64d8d6383b096448ac4c803", size = 475616, upload-time = "2026-05-13T13:34:50.259Z" }, +] + [[package]] name = "setuptools" version = "80.9.0" @@ -3217,6 +3350,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "smmap" +version = "5.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1f/ea/49c993d6dfdd7338c9b1000a0f36817ed7ec84577ae2e52f890d1a4ff909/smmap-5.0.3.tar.gz", hash = "sha256:4d9debb8b99007ae47165abc08670bd74cb74b5227dda7f643eccc4e9eb5642c", size = 22506, upload-time = "2026-03-09T03:43:26.1Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/d4/59e74daffcb57a07668852eeeb6035af9f32cbfd7a1d2511f17d2fe6a738/smmap-5.0.3-py3-none-any.whl", hash = 
"sha256:c106e05d5a61449cf6ba9a1e650227ecfb141590d2a98412103ff35d89fc7b2f", size = 24390, upload-time = "2026-03-09T03:43:24.361Z" }, +] + [[package]] name = "spglib" version = "2.6.0" @@ -3464,6 +3606,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, ] +[[package]] +name = "wandb" +version = "0.27.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "gitpython" }, + { name = "packaging" }, + { name = "platformdirs" }, + { name = "protobuf" }, + { name = "pydantic" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "sentry-sdk" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8e/31/fe53d06b75ef0a7f2f0ee5931a89f7aedc27d233840b1839616860fed256/wandb-0.27.0.tar.gz", hash = "sha256:579e75300173059f9334e1f513a79ef15f6d9ea5c74e20d695633648cdd02031", size = 41090732, upload-time = "2026-05-14T03:44:08.894Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/5e/2c199e70e636ecfd217cde0bc7469f4511e1d03d0685eb92bfdfce391430/wandb-0.27.0-py3-none-macosx_12_0_arm64.whl", hash = "sha256:c156be4851485f3c4160cb6eb2e8991b4cdeffbccefc5636d33cf5e254847365", size = 24886476, upload-time = "2026-05-14T03:43:27.569Z" }, + { url = "https://files.pythonhosted.org/packages/0b/cd/a617c871cd304a9804e56a7ec2ec2c65685bf0091a2b9f91910175a149e2/wandb-0.27.0-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:20179f38afb0158859a4141d29ac650d3fdbd0cf801a74ce25565c934f03776c", size = 26045779, upload-time = "2026-05-14T03:43:31.999Z" }, + { url = "https://files.pythonhosted.org/packages/10/0a/d3f159a201530b84b72ca5f98c68d1f351c2d9a1864558ed76c811407fae/wandb-0.27.0-py3-none-manylinux_2_28_aarch64.whl", hash = 
"sha256:626497d7975fa898d0a4a239da7a510483495ca3514510dbe75004a25963af4d", size = 25480764, upload-time = "2026-05-14T03:43:35.922Z" }, + { url = "https://files.pythonhosted.org/packages/5f/6a/8721fcdf71d42639191040a77a585d2982402b1754700cb2ecfc2ca1470a/wandb-0.27.0-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:f772da7005cc26a2a32b729a16982a583dc68b3d493df6a09d0aa5c5ca5a2060", size = 27256204, upload-time = "2026-05-14T03:43:39.765Z" }, + { url = "https://files.pythonhosted.org/packages/00/5e/279d167ba79fb7a8a43401c9f25efd0f6663ee9bd1eaf5a8578530198888/wandb-0.27.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:63acfc5b994e4a90e4a2fbdee6d45e664da3dd865bb1419942c8995c06c41cf1", size = 25647469, upload-time = "2026-05-14T03:43:44.817Z" }, + { url = "https://files.pythonhosted.org/packages/94/51/a69ac59300e3c813939d0764348959ed2a21e14c668cb1cebcb04010da6a/wandb-0.27.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:17aae6e4a88cd05c00ea8f546220918e3ebb6f8c1c36b70ef04a5ac75f0d7160", size = 27599005, upload-time = "2026-05-14T03:43:50.926Z" }, + { url = "https://files.pythonhosted.org/packages/5f/40/bf510c8758727df020f83b717ebc1fcc1739ed7f6ae1796ebef60bf6f592/wandb-0.27.0-py3-none-win32.whl", hash = "sha256:0bd5659417e386bf6538b5e2ffe6885774c6197f0e4853bfed517d5b0db457f1", size = 25036164, upload-time = "2026-05-14T03:43:54.839Z" }, + { url = "https://files.pythonhosted.org/packages/54/ff/69f88e7d90c22b79bcb911143c13e59742ee192080b21015ff83a5a1f60a/wandb-0.27.0-py3-none-win_amd64.whl", hash = "sha256:89d584b73166eecee96fb446f18d0e45b1aa45aba6a3696296f3f06d7454516b", size = 25036170, upload-time = "2026-05-14T03:43:59.227Z" }, + { url = "https://files.pythonhosted.org/packages/f6/38/f7efd7a87297a55c7e9a331a1dbb5b19e54aeacc11fe6f43f8636a73987c/wandb-0.27.0-py3-none-win_arm64.whl", hash = "sha256:a6c129c311edf210a2b4f2f4acc557eff522628125f5f28ed27df19c16c07079", size = 22972710, upload-time = "2026-05-14T03:44:03.275Z" }, +] + [[package]] name = "wcwidth" 
version = "0.2.13"