Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
276 changes: 40 additions & 236 deletions scripts/run_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Comprehensive benchmark runner for material generation evaluation.

This script:
1. Takes a list of CIF files as input
1. Takes an ASE database (.aselmdb) file as input
2. Loads a configuration specifying which benchmark families to run
3. ALWAYS runs validity benchmark and preprocessor first (mandatory)
4. Filters to only valid structures for subsequent processing
Expand All @@ -12,8 +12,8 @@
7. Saves results to JSON files in the results/ directory

Usage:
uv run scripts/run_benchmarks.py --cifs path/to/cifs.txt --config comprehensive --name my_run
uv run scripts/run_benchmarks.py --cifs path/to/cifs.txt --config comprehensive --name test_run
uv run scripts/run_benchmarks.py --asedb path/to/structures.aselmdb --config comprehensive --name my_run
uv run scripts/run_benchmarks.py --asedb path/to/structures.aselmdb --config comprehensive --name test_run
"""

import argparse
Expand Down Expand Up @@ -138,204 +138,54 @@ def cleanup_after_benchmark(benchmark_name: str, monitor_memory: bool = False):
log_memory_usage(f"after {benchmark_name} cleanup", force_log=monitor_memory)


def load_cif_files(input_path: str) -> List[str]:
"""Load list of CIF file paths from a text file or directory.
def load_ase_database(db_path: str) -> List[Any]:
"""Load structures from an ASE database file.

Parameters
----------
input_path : str
Path to either:
- A text file containing CIF file paths (one per line)
- A directory containing CIF files
db_path : str
Path to the ASE database file (.aselmdb)

Returns
-------
List[str]
List of CIF file paths
List[Any]
List of structures loaded from the database
"""
input_path_obj = Path(input_path)
from ase.db import connect
from pymatgen.io.ase import AseAtomsAdaptor

if input_path_obj.is_dir():
# Directory mode: find all CIF files in the directory
logger.info(f"Scanning directory for CIF files: {input_path}")
cif_paths = []

# Find all .cif files in the directory (recursive)
for cif_file in input_path_obj.rglob("*.cif"):
cif_paths.append(str(cif_file))

if not cif_paths:
raise FileNotFoundError(f"No CIF files found in directory: {input_path}")

logger.info(f"Found {len(cif_paths)} CIF files in directory")
return cif_paths

elif input_path_obj.is_file():
# File mode: read CIF paths from text file
logger.info(f"Loading CIF file list from: {input_path}")
with open(input_path, "r") as f:
cif_paths = [
line.strip() for line in f if line.strip() and not line.startswith("#")
]

# Validate that files exist
missing_files = [path for path in cif_paths if not Path(path).exists()]
if missing_files:
raise FileNotFoundError(f"Missing CIF files: {missing_files}")

return cif_paths
else:
raise FileNotFoundError(f"Path does not exist: {input_path}")


def load_structures_from_wycoff_csv(
csv_path: str, respect_validity_flags: bool = True
) -> List:
"""Load structures from a CSV file with proper validation handling.

This function handles the different validation behaviors between Structure.from_file()
and pre-computed validity flags from crystal generation pipelines like WyckoffTransformer.

Parameters
----------
csv_path : str
Path to CSV file containing structures
respect_validity_flags : bool, default=True
If True, skip structures marked as invalid in CSV validity columns
(structural_validity, smact_validity). These flags are pre-computed using
the same validation criteria as Structure.from_file() (0.5 Å minimum distance)

Returns
-------
List
List of pymatgen Structure objects

Raises
------
ValueError
If no structure column found or no valid structures loaded

Examples
--------
# Match Structure.from_file() behavior (using pre-computed validity flags)
structures = load_structures_from_wycoff_csv("data.csv", respect_validity_flags=True)

# Load everything possible (permissive, ignores validity flags)
structures = load_structures_from_wycoff_csv("data.csv", respect_validity_flags=False)
"""
import json

import pandas as pd
from pymatgen.core import Structure

from lemat_genbench.utils.logging import logger

logger.info(f"Loading structures from CSV: {csv_path}")

# Read CSV file
df = pd.read_csv(csv_path)

# Coerce validity flags to bools if present (handles 'true'/'false', 1/0, yes/no)
def to_bool_series(s):
return (
s.astype(str)
.str.strip()
.str.lower()
.map(
{
"true": True,
"1": True,
"yes": True,
"y": True,
"false": False,
"0": False,
"no": False,
"n": False,
}
)
.fillna(False)
)

if respect_validity_flags:
if "structural_validity" in df.columns:
df["structural_validity"] = to_bool_series(df["structural_validity"])
if "smact_validity" in df.columns:
df["smact_validity"] = to_bool_series(df["smact_validity"])

# Find structure column (try different possible names)
structure_column = None
for col_name in ["structure", "LeMatStructs", "cif_string"]:
if col_name in df.columns:
structure_column = col_name
break

if structure_column is None:
raise ValueError(
"CSV file must contain a 'structure', 'LeMatStructs', or 'cif_string' column"
)
db_path_obj = Path(db_path)
if not db_path_obj.exists():
raise FileNotFoundError(f"ASE database file not found: {db_path}")

logger.info(f"Loading structures from ASE database: {db_path}")
structures = []
skipped_invalid = 0
skipped_errors = 0

for idx, row in df.iterrows():
# Check validity flags first (if respecting them)
if respect_validity_flags:
if "structural_validity" in df.columns and not row["structural_validity"]:
logger.debug(
f"Skipping structure {idx + 1}: marked as structurally invalid"
)
skipped_invalid += 1
continue
if "smact_validity" in df.columns and not row["smact_validity"]:
logger.debug(f"Skipping structure {idx + 1}: marked as SMACT invalid")
skipped_invalid += 1
continue

try:
structure_data = row[structure_column]

# Skip rows with missing structure cells
if pd.isna(structure_data):
logger.debug(f"Skipping structure {idx + 1}: missing structure data")
skipped_errors += 1

# Connect to the database and load all structures
with connect(db_path) as db:
n_total = len(db)
logger.info(f"Found {n_total} entries in ASE database")

# Convert each ASE Atoms object to a pymatgen Structure
adaptor = AseAtomsAdaptor()
for row in db.select():
try:
atoms = row.toatoms()
structure = adaptor.get_structure(atoms)
structures.append(structure)
except Exception as e:
logger.warning(f"Failed to convert structure {row.id}: {str(e)}")
continue

# Parse structure based on data format
if isinstance(structure_data, str) and structure_data.strip().startswith(
"{"
):
try:
# Try to parse as JSON first (for pymatgen Structure dict format)
structure_dict = json.loads(structure_data)
structure = Structure.from_dict(structure_dict)
except json.JSONDecodeError:
# If not valid JSON, try as CIF string
structure = Structure.from_str(structure_data, fmt="cif")
else:
# Try as CIF string
structure = Structure.from_str(structure_data, fmt="cif")

structures.append(structure)
logger.debug(f"✅ Loaded structure {idx + 1} from CSV")

except Exception as e:
# In permissive mode, log and skip
logger.warning(f"Failed to load structure {idx + 1} from CSV: {str(e)}")
skipped_errors += 1

if not structures:
raise ValueError("No valid structures loaded from CSV file")

logger.info(f"✅ Loaded {len(structures)} structures from CSV")
if skipped_invalid > 0:
logger.info(f"⚠️ Skipped {skipped_invalid} structures marked as invalid")
if skipped_errors > 0:
logger.info(f"⚠️ Skipped {skipped_errors} structures due to loading errors")
raise ValueError("No valid structures loaded from ASE database")

logger.info(f"Successfully loaded {len(structures)} structures from ASE database")
return structures




def load_benchmark_config(config_name: str) -> Dict[str, Any]:
"""Load benchmark configuration from YAML file."""
config_dir = Path(__file__).parent.parent / "src" / "config"
Expand Down Expand Up @@ -823,14 +673,12 @@ def save_results(
def main():
"""Main function to run benchmarks."""
parser = argparse.ArgumentParser(
description="Run material generation benchmarks with original novelty/uniqueness/SUN (validity ALWAYS mandatory)"
description="Run material generation benchmarks on structures from an ASE database with original novelty/uniqueness/SUN (validity ALWAYS mandatory)"
)
parser.add_argument(
"--cifs",
help="Path to text file containing CIF file paths OR directory containing CIF files",
)
parser.add_argument(
"--csv", help="Path to CSV file containing structures in LeMatStructs column"
"--asedb",
help="Path to ASE database file (.aselmdb) containing structures",
required=True,
)
parser.add_argument(
"--config",
Expand Down Expand Up @@ -863,56 +711,12 @@ def main():

args = parser.parse_args()

# Validate input arguments
if not args.cifs and not args.csv:
parser.error("Either --cifs or --csv must be provided")
if args.cifs and args.csv:
parser.error("Only one of --cifs or --csv can be provided")

try:
# Log initial memory usage
log_memory_usage("start of benchmark run", force_log=args.monitor_memory)

# Load structures based on input type
if args.csv:
# Load structures from CSV
structures = load_structures_from_wycoff_csv(args.csv)
else:
# Load CIF files
logger.info(f"Loading CIF files from: {args.cifs}")
cif_paths = load_cif_files(args.cifs)
logger.info(f"✅ Loaded {len(cif_paths)} CIF files")

# Load structures from CIF files
logger.info("Converting CIF files to structures...")
structures = []

# Add progress bar for structure loading
with tqdm(cif_paths, desc="Loading CIF structures", unit="file") as pbar:
for cif_path in pbar:
try:
# Load CIF file using pymatgen
from pymatgen.core import Structure

structure = Structure.from_file(cif_path)
structures.append(structure)
pbar.set_postfix(
{
"loaded": len(structures),
"failed": len(cif_paths) - len(structures),
}
)
except Exception as e:
logger.warning(f"Failed to load {cif_path}: {str(e)}")
pbar.set_postfix(
{
"loaded": len(structures),
"failed": len(cif_paths) - len(structures),
}
)

if not structures:
raise ValueError("No valid structures loaded from CIF files")
# Load structures from ASE database
structures = load_ase_database(args.asedb)

n_total_structures = len(structures)
logger.info(f"✅ Loaded {n_total_structures} structures")
Expand Down