From 57ab13702cf8c37d22d56839291336caf414f0d4 Mon Sep 17 00:00:00 2001 From: Andy Xu Date: Mon, 26 May 2025 22:18:01 -0700 Subject: [PATCH 1/4] new stability relaxation metric --- .gitignore | 2 + .../benchmarks/stability_benchmark.py | 130 ++++++++++ .../metrics/stability_metrics.py | 243 ++++++++++++++++++ .../utils/e_above_hull.py | 65 +++++ .../utils/relaxers/__init__.py | 17 ++ .../utils/relaxers/registry.py | 128 +++++++++ .../utils/relaxers/relaxers.py | 205 +++++++++++++++ tests/benchmarks/test_stability_benchmark.py | 174 +++++++++++++ tests/metrics/test_stability_metrics.py | 40 +++ 9 files changed, 1004 insertions(+) create mode 100644 src/lematerial_forgebench/benchmarks/stability_benchmark.py create mode 100644 src/lematerial_forgebench/metrics/stability_metrics.py create mode 100644 src/lematerial_forgebench/utils/e_above_hull.py create mode 100644 src/lematerial_forgebench/utils/relaxers/__init__.py create mode 100644 src/lematerial_forgebench/utils/relaxers/registry.py create mode 100644 src/lematerial_forgebench/utils/relaxers/relaxers.py create mode 100644 tests/benchmarks/test_stability_benchmark.py create mode 100644 tests/metrics/test_stability_metrics.py diff --git a/.gitignore b/.gitignore index a4c6f414..6cc79689 100644 --- a/.gitignore +++ b/.gitignore @@ -170,3 +170,5 @@ cython_debug/ lightning_logs/ src/lematerial_forgebench/.DS_Store +# Ignore .gz +*.gz \ No newline at end of file diff --git a/src/lematerial_forgebench/benchmarks/stability_benchmark.py b/src/lematerial_forgebench/benchmarks/stability_benchmark.py new file mode 100644 index 00000000..004bc9fa --- /dev/null +++ b/src/lematerial_forgebench/benchmarks/stability_benchmark.py @@ -0,0 +1,130 @@ +"""Stability benchmark for material structures. + +This module implements a benchmark that evaluates the stability of +generated material structures using various relaxation methods. 
+""" + +from typing import Any, Dict + +from lematerial_forgebench.benchmarks.base import BaseBenchmark +from lematerial_forgebench.evaluator import EvaluatorConfig +from lematerial_forgebench.metrics.stability_metrics import StabilityMetric + + +class StabilityBenchmark(BaseBenchmark): + """Benchmark for evaluating the stability of generated material structures.""" + + def __init__( + self, + relaxer_type: str = "orb", + relaxer_config: Dict[str, Any] | None = None, + mp_entries_file: str = "src/lematerial_forgebench/utils/relaxers/2023-02-07-ppd-mp.pkl.gz", + name: str = "StabilityBenchmark", + description: str | None = None, + metadata: Dict[str, Any] | None = None, + ): + """Initialize the stability benchmark. + + Parameters + ---------- + relaxer_type : str, default="orb" + Type of relaxer to use (e.g., "orb", "ocp"). + relaxer_config : dict, optional + Configuration for the relaxer. If None, uses default config. + mp_entries_file : str + Path to the Materials Project entries file. + name : str + Name of the benchmark. + description : str, optional + Description of the benchmark. + metadata : dict, optional + Additional metadata for the benchmark. + """ + if description is None: + description = ( + f"Evaluates the stability of crystal structures using {relaxer_type.upper()} " + "relaxation and energy above hull calculations." 
+ ) + + # Set default relaxer config if not provided + if relaxer_config is None: + relaxer_config = {"steps": 500, "fmax": 0.02} + + # Initialize the stability metric + stability_metric = StabilityMetric( + relaxer_type=relaxer_type, + relaxer_config=relaxer_config, + mp_entries_file=mp_entries_file, + ) + + # Set up evaluator config + evaluator_configs = { + "stability": EvaluatorConfig( + name=f"{relaxer_type.upper()} Stability", + description=f"Evaluates structure stability using {relaxer_type.upper()}", + metrics={"stability": stability_metric}, + weights={"stability": 1.0}, + aggregation_method="weighted_mean", + ), + } + + # Create benchmark metadata + benchmark_metadata = { + "version": "0.1.0", + "category": "stability", + "relaxer_type": relaxer_type, + "relaxer_config": relaxer_config, + "mp_entries_file": mp_entries_file, + **(metadata or {}), + } + + super().__init__( + name=name, + description=description, + evaluator_configs=evaluator_configs, + metadata=benchmark_metadata, + ) + + def aggregate_evaluator_results( + self, evaluator_results: Dict[str, Dict[str, Any]] + ) -> Dict[str, float]: + """Aggregate results from multiple evaluators into final scores. + + Parameters + ---------- + evaluator_results : dict[str, dict[str, Any]] + Results from each evaluator, as structured by BaseBenchmark.evaluate. + Example: {"evaluator_name": {"combined_value": 0.X, "metric_name_value": 0.Y}} + + Returns + ------- + dict[str, float] + Final aggregated scores. 
+ """ + final_scores = { + "stability_score": 0.0, + "stable_ratio": 0.0, + "mean_e_above_hull": 0.0, + "metastable_ratio": 0.0, + } + + stability_eval_data: Dict[str, Any] | None = evaluator_results.get("stability") + print("stability_eval_data", stability_eval_data) + if stability_eval_data: + if stability_eval_data.get("combined_value") is not None: + final_scores["stability_score"] = stability_eval_data["combined_value"] + + if stability_eval_data.get("stability_value") is not None: + final_scores["stable_ratio"] = stability_eval_data["stability_value"] + + if stability_eval_data.get("mean_e_above_hull") is not None: + final_scores["mean_e_above_hull"] = stability_eval_data[ + "mean_e_above_hull" + ] + + if stability_eval_data.get("metastable_ratio") is not None: + final_scores["metastable_ratio"] = stability_eval_data[ + "metastable_ratio" + ] + + return final_scores diff --git a/src/lematerial_forgebench/metrics/stability_metrics.py b/src/lematerial_forgebench/metrics/stability_metrics.py new file mode 100644 index 00000000..5f72741a --- /dev/null +++ b/src/lematerial_forgebench/metrics/stability_metrics.py @@ -0,0 +1,243 @@ +"""Relaxation metrics for evaluating material structures. + +This module implements metrics for evaluating the relaxation of +material structures using various relaxation models and calculating +energy above hull. 
+""" + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, Optional + +import numpy as np +from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram +from pymatgen.core import Structure +from pymatgen.entries.compatibility import MaterialsProject2020Compatibility + +from lematerial_forgebench.metrics.base import BaseMetric, MetricConfig +from lematerial_forgebench.utils.e_above_hull import ( + generate_CSE, + get_patched_phase_diagram_mp, +) +from lematerial_forgebench.utils.logging import logger +from lematerial_forgebench.utils.relaxers import BaseRelaxer, get_relaxer + + +@dataclass +class StabilityMetricConfig(MetricConfig): + """Configuration for the StabilityMetric. + + Parameters + ---------- + relaxer_type : str + Type of relaxer to use (e.g., "orb"). + relaxer_config : dict + Configuration for the specific relaxer. + mp_entries_file : str, optional + Path to the Materials Project entries file for e_above_hull calculation. + """ + + relaxer_type: str = "orb" + relaxer_config: Dict[str, Any] = field(default_factory=dict) + mp_entries_file: Optional[str] = None + + +class StabilityMetric(BaseMetric): + """Evaluate structure relaxation and energy above hull. + + This metric handles both the relaxation of structures using various models + and the calculation of energy above hull using the Materials Project database. + + Parameters + ---------- + relaxer_type : str + Type of relaxer to use. + relaxer_config : dict + Configuration for the specific relaxer. + mp_entries_file : str, optional + Path to the Materials Project entries file. + name : str, optional + Custom name for the metric. + description : str, optional + Description of what the metric measures. + lower_is_better : bool, default=False + Whether lower values of the primary metric are considered better. + n_jobs : int, default=1 + Number of parallel jobs to run. 
+ """ + + def __init__( + self, + relaxer_type: str, + relaxer_config: Dict[str, Any], + mp_entries_file: Optional[str] = None, + name: str | None = None, + description: str | None = None, + lower_is_better: bool = False, + n_jobs: int = 1, + ): + super().__init__( + name=name or "StabilityMetric", + description=description or "Evaluates structure stability", + lower_is_better=lower_is_better, + n_jobs=n_jobs, + ) + self.config = StabilityMetricConfig( + name=self.config.name, + description=self.config.description, + lower_is_better=self.config.lower_is_better, + n_jobs=self.config.n_jobs, + relaxer_type=relaxer_type, + relaxer_config=relaxer_config, + mp_entries_file=mp_entries_file, + ) + + # Initialize the relaxer + self.relaxer = get_relaxer(relaxer_type, **relaxer_config) + + # Initialize MP compatibility + self.compatibility = MaterialsProject2020Compatibility(check_potcar=False) + + # Load MP entries if file provided + self.mp_entries = None + # check if path exists + if mp_entries_file and Path(mp_entries_file).exists(): + self.load_mp_entries(mp_entries_file) + + def load_mp_entries(self, mp_entries_file: str) -> None: + """Load Materials Project entries for e_above_hull calculation. + + Parameters + ---------- + mp_entries_file : str + Path to the MP entries file. 
+ """ + self.mp_entries = get_patched_phase_diagram_mp(Path(mp_entries_file)) + # import pandas as pd + + # try: + # df = pd.read_json(mp_entries_file) + # df = df[df["entry"].apply(lambda x: "GGA" in x["entry_id"])] + + # mp_computed_entries = [ + # ComputedEntry.from_dict(x) + # for x in df.entry + # if "GGA" in x["parameters"]["run_type"] + # ] + # self.mp_entries = [ + # entry + # for entry in mp_computed_entries + # if not np.any(["R2SCAN" in a.name for a in entry.energy_adjustments]) + # ] + # logger.info(f"Loaded {len(self.mp_entries)} MP entries") + # except Exception as e: + # logger.error(f"Failed to load MP entries: {str(e)}") + # self.mp_entries = None + + def _get_compute_attributes(self) -> dict[str, Any]: + """Get the attributes for the compute_structure method.""" + return { + "relaxer": self.relaxer, + "compatibility": self.compatibility, + "mp_entries": self.mp_entries, + } + + @staticmethod + def compute_structure( + structure: Structure, + relaxer: BaseRelaxer, + compatibility: MaterialsProject2020Compatibility, + mp_entries: Optional[PatchedPhaseDiagram] = None, + ) -> Dict[str, Any]: + """Compute relaxation and energy above hull for a structure. + + Parameters + ---------- + structure : Structure + A pymatgen Structure object to evaluate. + relaxer : BaseRelaxer + Relaxer object to use. + compatibility : MaterialsProject2020Compatibility + MP compatibility scheme. + mp_entries : PatchedPhaseDiagram, optional + Patched phase diagram of MP entries used for the e_above_hull calculation. + + Returns + ------- + float or None + Energy above hull of the relaxed structure; NaN if relaxation or the computation fails, None if no phase diagram is given. 
+ """ + try: + # Relax structure + relaxation_result = relaxer.relax(structure) + if not relaxation_result.success: + logger.error(f"Relaxation failed: {relaxation_result.message}") + return np.nan + + result = { + "success": True, + "relaxed_energy": relaxation_result.energy, + "relaxed_structure": relaxation_result.structure, + "e_above_hull": None, + } + + # Calculate e_above_hull if MP entries are available + if mp_entries is not None: + # Calculate e_above_hull + cse = generate_CSE( + relaxation_result.structure, relaxation_result.energy + ) + + e_above_hull = mp_entries.get_e_above_hull(cse, allow_negative=True) + result["e_above_hull"] = e_above_hull + return result["e_above_hull"] + except Exception as e: + logger.error(f"Computation failed: {str(e)}") + return np.nan + + def aggregate_results(self, values: list[float]) -> dict[str, Any]: + """Aggregate results into final metric values. + + Parameters + ---------- + values : list[float] + List of computation results for each structure. + + Returns + ------- + dict + Dictionary with aggregated metrics. 
+ """ + # Convert to numpy array for efficient operations + values_array = np.array(values) + + # Filter out NaN values + valid_mask = ~np.isnan(values_array) + e_above_hull_values = values_array[valid_mask] + + if len(e_above_hull_values) > 0: + mean_e_above_hull = np.mean(e_above_hull_values) + e_above_hull_std = ( + np.std(e_above_hull_values) if len(e_above_hull_values) > 1 else 0.0 + ) + # Calculate ratio of stable structures (e_above_hull <= 0) using numpy + stable_ratio = np.sum(e_above_hull_values <= 0) / len(values) + # Calculate ratio of metastable structures (e_above_hull <= 0.1) using numpy + metastable_ratio = np.sum(e_above_hull_values <= 0.1) / len(values) + else: + mean_e_above_hull = 0.0 + e_above_hull_std = 0.0 + stable_ratio = 0.0 + metastable_ratio = 0.0 + + return { + "metrics": { + "mean_e_above_hull": mean_e_above_hull, + "stable_ratio": stable_ratio, + "metastable_ratio": metastable_ratio, + }, + "primary_metric": "stable_ratio", + "uncertainties": { + "e_above_hull_std": {"std": e_above_hull_std}, + }, + } diff --git a/src/lematerial_forgebench/utils/e_above_hull.py b/src/lematerial_forgebench/utils/e_above_hull.py new file mode 100644 index 00000000..da207bce --- /dev/null +++ b/src/lematerial_forgebench/utils/e_above_hull.py @@ -0,0 +1,65 @@ +"""Util functions for e_above_hull calculation.""" + +import gzip +import pickle +import tempfile +from pathlib import Path + +from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram +from pymatgen.core import Structure +from pymatgen.entries.compatibility import MaterialsProject2020Compatibility +from pymatgen.entries.computed_entries import ComputedStructureEntry +from pymatgen.io.vasp.inputs import Incar, Poscar +from pymatgen.io.vasp.sets import MPRelaxSet + + +def get_patched_phase_diagram_mp(path: Path) -> PatchedPhaseDiagram: + # Check if the file has a .gz extension + if path.suffix == ".gz": + # Open as gzip file + with gzip.open(path, "rb") as f: + ppd_mp = pickle.load(f) + else: + 
# Open as regular file + with open(path, "rb") as f: + ppd_mp = pickle.load(f) + return ppd_mp + + +def generate_CSE(structure, energy): + # Write VASP inputs files as if we were going to do a standard MP run + # this is mainly necessary to get the right U values / etc + b = MPRelaxSet(structure) + with tempfile.TemporaryDirectory() as tmpdirname: + b.write_input(f"{tmpdirname}/", potcar_spec=True) + poscar = Poscar.from_file(f"{tmpdirname}/POSCAR") + incar = Incar.from_file(f"{tmpdirname}/INCAR") + clean_structure = Structure.from_file(f"{tmpdirname}/POSCAR") + + # Get the U values and figure out if we should have run a GGA+U calc + param = {"hubbards": {}} + if "LDAUU" in incar: + param["hubbards"] = dict(zip(poscar.site_symbols, incar["LDAUU"])) + param["is_hubbard"] = ( + incar.get("LDAU", True) and sum(param["hubbards"].values()) > 0 + ) + if param["is_hubbard"]: + param["run_type"] = "GGA+U" + + # Make a ComputedStructureEntry without the correction + cse_d = { + "structure": clean_structure, + "energy": energy, + "correction": 0.0, + "parameters": param, + } + + # Apply the MP 2020 correction scheme (anion/+U/etc) + cse = ComputedStructureEntry.from_dict(cse_d) + _ = MaterialsProject2020Compatibility(check_potcar=False).process_entries( + cse, + clean=True, + ) + + # Return the final CSE (notice that the composition/etc is also clean, not things like Fe3+)! + return cse diff --git a/src/lematerial_forgebench/utils/relaxers/__init__.py b/src/lematerial_forgebench/utils/relaxers/__init__.py new file mode 100644 index 00000000..c85faef7 --- /dev/null +++ b/src/lematerial_forgebench/utils/relaxers/__init__.py @@ -0,0 +1,17 @@ +"""Relaxer implementations package. + +This module ensures all relaxer implementations are imported and registered +when the package is used. +""" + +# Import the registry functions +# Import all implementations to trigger registration +from . 
import relaxers +from .registry import BaseRelaxer, RelaxationResult, get_relaxer, register_relaxer + +__all__ = [ + "BaseRelaxer", + "RelaxationResult", + "get_relaxer", + "register_relaxer", +] diff --git a/src/lematerial_forgebench/utils/relaxers/registry.py b/src/lematerial_forgebench/utils/relaxers/registry.py new file mode 100644 index 00000000..23ead105 --- /dev/null +++ b/src/lematerial_forgebench/utils/relaxers/registry.py @@ -0,0 +1,128 @@ +"""Registry for relaxer implementations and base classes.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, Type + +from pymatgen.core import Structure +from pymatgen.entries.computed_entries import ComputedStructureEntry + + +@dataclass +class RelaxationResult: + """Result of a structure relaxation. + + Parameters + ---------- + success : bool + Whether the relaxation was successful. + energy : float | None + Final energy of the relaxed structure. + structure : Structure | None + Final relaxed structure. + message : str | None + Error message if relaxation failed. + """ + + success: bool + energy: float | None = None + structure: Structure | None = None + message: str | None = None + + +class BaseRelaxer(ABC): + """Base class for structure relaxation implementations. + + All relaxer implementations should inherit from this class and implement + the required methods. + """ + + @abstractmethod + def relax(self, structure: Structure) -> RelaxationResult: + """Relax a structure and return the result. + + Parameters + ---------- + structure : Structure + Structure to relax. + + Returns + ------- + RelaxationResult + Result of the relaxation. + """ + pass + + @abstractmethod + def get_computed_entry( + self, structure: Structure, energy: float + ) -> ComputedStructureEntry: + """Create a ComputedStructureEntry from a relaxed structure. + + Parameters + ---------- + structure : Structure + The relaxed structure. + energy : float + The energy of the relaxed structure. 
+ + Returns + ------- + ComputedStructureEntry + The computed structure entry with appropriate corrections applied. + """ + pass + + +_RELAXER_REGISTRY: Dict[str, Type[BaseRelaxer]] = {} + + +def register_relaxer(name: str): + """Register a relaxer implementation. + + Parameters + ---------- + name : str + Name of the relaxer. + + Returns + ------- + callable + Decorator function. + """ + + def decorator(cls: Type[BaseRelaxer]) -> Type[BaseRelaxer]: + if not issubclass(cls, BaseRelaxer): + raise TypeError(f"{cls.__name__} must inherit from BaseRelaxer") + _RELAXER_REGISTRY[name] = cls + return cls + + return decorator + + +def get_relaxer(relaxer_type: str, **kwargs) -> BaseRelaxer: + """Get a relaxer implementation by name. + + Parameters + ---------- + relaxer_type : str + Name of the relaxer. + **kwargs + Additional arguments to pass to the relaxer constructor. + + Returns + ------- + BaseRelaxer + Relaxer instance. + + Raises + ------ + ValueError + If relaxer_type is not registered. + """ + if relaxer_type not in _RELAXER_REGISTRY: + raise ValueError( + f"Unknown relaxer type: {relaxer_type}. 
" + f"Available types: {list(_RELAXER_REGISTRY.keys())}" + ) + return _RELAXER_REGISTRY[relaxer_type](**kwargs) diff --git a/src/lematerial_forgebench/utils/relaxers/relaxers.py b/src/lematerial_forgebench/utils/relaxers/relaxers.py new file mode 100644 index 00000000..c15545bf --- /dev/null +++ b/src/lematerial_forgebench/utils/relaxers/relaxers.py @@ -0,0 +1,205 @@ +"""Implementations for structure relaxation.""" + +import gc +import tempfile +from abc import abstractmethod + +import torch +from ase.filters import FrechetCellFilter +from ase.optimize import FIRE + +# from fairchem.core import OCPCalculator +from orb_models.forcefield import pretrained +from orb_models.forcefield.calculator import ORBCalculator +from pymatgen.core import Structure +from pymatgen.entries.compatibility import MaterialsProject2020Compatibility +from pymatgen.entries.computed_entries import ComputedStructureEntry +from pymatgen.io.ase import AseAtomsAdaptor +from pymatgen.io.vasp.inputs import Incar, Poscar +from pymatgen.io.vasp.sets import MPRelaxSet + +from lematerial_forgebench.utils.logging import logger +from lematerial_forgebench.utils.relaxers.registry import ( + BaseRelaxer, + RelaxationResult, + register_relaxer, +) + + +class BaseVASPRelaxer(BaseRelaxer): + """Base class for relaxers that use VASP-like parameters.""" + + def get_computed_entry( + self, structure: Structure, energy: float + ) -> ComputedStructureEntry: + """Create a ComputedStructureEntry from a relaxed structure. + + Parameters + ---------- + structure : Structure + The relaxed structure. + energy : float + The energy of the relaxed structure. + + Returns + ------- + ComputedStructureEntry + The computed structure entry with MP2020 corrections applied. 
+ """ + with tempfile.TemporaryDirectory() as tmpdirname: + b = MPRelaxSet(structure) + b.write_input(f"{tmpdirname}/", potcar_spec=True) + poscar = Poscar.from_file(f"{tmpdirname}/POSCAR") + incar = Incar.from_file(f"{tmpdirname}/INCAR") + clean_structure = Structure.from_file(f"{tmpdirname}/POSCAR") + + # Get the U values and figure out if we should have run a GGA+U calc + param = {"hubbards": {}} + if "LDAUU" in incar: + param["hubbards"] = dict(zip(poscar.site_symbols, incar["LDAUU"])) + param["is_hubbard"] = ( + incar.get("LDAU", True) and sum(param["hubbards"].values()) > 0 + ) + if param["is_hubbard"]: + param["run_type"] = "GGA+U" + + # Make a ComputedStructureEntry without the correction + cse_d = { + "structure": clean_structure, + "energy": energy, + "correction": 0.0, + "parameters": param, + } + + # Apply the MP 2020 correction scheme (anion/+U/etc) + cse = ComputedStructureEntry.from_dict(cse_d) + _ = MaterialsProject2020Compatibility(check_potcar=False).process_entries( + cse, + clean=True, + ) + + return cse + + +class ASERelaxerBase(BaseVASPRelaxer): + """Base class for ASE-based relaxers.""" + + def __init__( + self, + fmax: float = 0.02, + steps: int = 500, + cpu: bool = True, + **params, + ): + """Initialize ASE relaxer. + + Parameters + ---------- + fmax : float, default=0.02 + Maximum force convergence criterion. + steps : int, default=500 + Maximum number of optimization steps. + cpu : bool, default=True + Whether to use CPU instead of GPU. + **params + Additional parameters specific to each relaxer type. + """ + self.fmax = fmax + self.steps = steps + self.cpu = cpu + self.params = params # Store additional parameters for subclasses + # Abstract calculator setup - must be implemented by subclasses + self.calc = self._setup_calculator() + + @abstractmethod + def _setup_calculator(self): + """Set up the calculator for this relaxer. + + Returns + ------- + Calculator + The initialized calculator object. 
+ """ + pass + + def relax(self, structure: Structure) -> RelaxationResult: + """Relax a structure using calculator. + + Parameters + ---------- + structure : Structure + Structure to relax. + + Returns + ------- + RelaxationResult + Result of the relaxation. + """ + try: + if structure is None or not structure.is_valid(): + print("Skipping structure: Invalid crystal") + return RelaxationResult( + success=False, + message="Invalid crystal", + ) + + # Convert to ASE atoms + atoms = structure.to_ase_atoms() + + atoms.calc = self.calc + + # Relax structure + dyn = FIRE(FrechetCellFilter(atoms), logfile=None) + dyn.run(fmax=self.fmax, steps=self.steps) + + # Get results + final_energy = atoms.get_potential_energy() + + final_structure = AseAtomsAdaptor.get_structure(atoms) + + # Clean up + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return RelaxationResult( + success=True, + energy=final_energy, + structure=final_structure, + ) + except Exception as e: + logger.error(f"ASE relaxation failed: {str(e)}") + return RelaxationResult( + success=False, + message=str(e), + ) + + +@register_relaxer("orb") +class OrbRelaxerImpl(ASERelaxerBase): + """Orb relaxer implementation.""" + + def _setup_calculator(self): + """Set up the Orb calculator.""" + device = "cpu" if self.cpu else "cuda" + if self.params.get("direct", False): + orbff = pretrained.orb_v3_direct_inf_mpa( + device=device, + precision="float32-high", # or "float32-highest" / "float64 + ) + else: + orbff = pretrained.orb_v3_conservative_inf_mpa( + device=device, + precision="float32-high", # or "float32-highest" / "float64 + ) + return ORBCalculator(orbff, device=device) + + +# TODO: Fix Meta Fairchem Relaxers +# @register_relaxer("ocp") +# class OCPRelaxerImpl(ASERelaxerBase): +# """OCP relaxer implementation.""" + +# def _setup_calculator(self): +# """Set up the OCP calculator.""" +# return OCPCalculator(checkpoint_path=self.checkpoint_path, cpu=self.cpu) diff --git 
a/tests/benchmarks/test_stability_benchmark.py b/tests/benchmarks/test_stability_benchmark.py new file mode 100644 index 00000000..72b827f0 --- /dev/null +++ b/tests/benchmarks/test_stability_benchmark.py @@ -0,0 +1,174 @@ +"""Tests for stability benchmark.""" + +from pymatgen.util.testing import PymatgenTest + +from lematerial_forgebench.benchmarks.stability_benchmark import StabilityBenchmark + + +class TestStabilityBenchmark: + """Test suite for StabilityBenchmark class.""" + + def test_initialization_default(self): + """Test initialization with default parameters.""" + benchmark = StabilityBenchmark(mp_entries_file=None) + + # Check name and properties + assert benchmark.config.name == "StabilityBenchmark" + assert "version" in benchmark.config.metadata + assert benchmark.config.metadata["category"] == "stability" + + # Check default relaxer type + assert benchmark.config.metadata["relaxer_type"] == "orb" + + # Check correct evaluator + assert len(benchmark.evaluators) == 1 + assert "stability" in benchmark.evaluators + + def test_initialization_custom(self): + """Test initialization with custom relaxer configuration.""" + relaxer_config = {"steps": 2000, "fmax": 0.01, "direct": True} + + benchmark = StabilityBenchmark( + relaxer_type="orb", + relaxer_config=relaxer_config, + mp_entries_file="custom/path/mp_entries.json", + name="Custom Stability Benchmark", + description="Custom description", + metadata={"test_key": "test_value"}, + ) + + # Check custom values + assert benchmark.config.name == "Custom Stability Benchmark" + assert benchmark.config.description == "Custom description" + assert benchmark.config.metadata["test_key"] == "test_value" + assert benchmark.config.metadata["relaxer_type"] == "orb" + assert benchmark.config.metadata["relaxer_config"] == relaxer_config + assert ( + benchmark.config.metadata["mp_entries_file"] + == "custom/path/mp_entries.json" + ) + + def test_evaluate_with_mp_entries(self): + """Test benchmark evaluation on structures""" 
+ benchmark = StabilityBenchmark() + + # Create test structures + test = PymatgenTest() + structures = [test.get_structure("Si"), test.get_structure("LiFePO4")] + + # Run benchmark + result = benchmark.evaluate(structures) + + # Check result format + assert len(result.evaluator_results) == 1 + assert "stability" in result.evaluator_results + assert "stability_score" in result.final_scores + assert "stable_ratio" in result.final_scores + print(result.final_scores) + + # Check score types + assert isinstance(result.final_scores["stability_score"], (int, float)) + assert isinstance(result.final_scores["stable_ratio"], (int, float)) + assert isinstance(result.final_scores["metastable_ratio"], (int, float)) + assert isinstance(result.final_scores["mean_e_above_hull"], (int, float)) + + def test_empty_structures(self): + """Test behavior with empty structure list.""" + benchmark = StabilityBenchmark() + + # Test behavior with no structures - should not raise error + result = benchmark.evaluate([]) + + # Should get default values + assert result.final_scores["stability_score"] == 0.0 + assert result.final_scores["stable_ratio"] == 0.0 + assert result.final_scores["metastable_ratio"] == 0.0 + assert result.final_scores["mean_e_above_hull"] == 0.0 + + def test_aggregate_evaluator_results(self): + """Test result aggregation logic.""" + benchmark = StabilityBenchmark() + + # Mock evaluator_results as passed by BaseBenchmark.evaluate + # It contains the evaluator's combined_value and the primary metric value + # of each metric configured for that evaluator (e.g., "metric_name_value"). + mock_evaluator_results_from_base = { + "stability": { # Name of the evaluator + "combined_value": 0.75, # Evaluator's combined score + # The StabilityMetric instance within the 'stability' evaluator was named "stability". + # Its primary metric (stable_ratio) is flattened by BaseBenchmark to "stability_value". 
+ "stability_value": 0.6, + } + } + + # Aggregate results + scores = benchmark.aggregate_evaluator_results(mock_evaluator_results_from_base) + + # Check scores + # aggregate_evaluator_results should pick up combined_value as stability_score + # and stability_value as stable_ratio. + # mean_e_above_hull and metastable_ratio will be defaults (nan, 0.0) because + # they are not present in the input dict. + assert scores["stability_score"] == 0.75 + assert scores["stable_ratio"] == 0.6 + assert ( + "mean_e_above_hull" in scores + and scores["mean_e_above_hull"] != scores["mean_e_above_hull"] + ) # Check for NaN + assert scores["metastable_ratio"] == 0.0 + + def test_benchmark_metadata(self): + """Test benchmark metadata structure.""" + benchmark = StabilityBenchmark( + relaxer_type="orb", + relaxer_config={"direct": True}, + mp_entries_file="test.json", + ) + + metadata = benchmark.config.metadata + + # Check required metadata fields + assert metadata["version"] == "0.1.0" + assert metadata["category"] == "stability" + assert metadata["relaxer_type"] == "orb" + assert metadata["relaxer_config"]["direct"] is True + assert metadata["mp_entries_file"] == "test.json" + + +def test_benchmark_description_generation(): + """Test automatic description generation.""" + # Test with ORB relaxer + orb_benchmark = StabilityBenchmark(relaxer_type="orb") + assert "ORB" in orb_benchmark.config.description + assert "relaxation and energy above hull" in orb_benchmark.config.description + + # Test with custom description + custom_benchmark = StabilityBenchmark( + relaxer_type="orb", description="Custom description" + ) + assert custom_benchmark.config.description == "Custom description" + + +def test_evaluator_configuration(): + """Test that evaluator is properly configured.""" + benchmark = StabilityBenchmark(relaxer_type="orb") + + # Check evaluator configuration + stability_evaluator = benchmark.evaluators["stability"] + print(stability_evaluator) + assert 
stability_evaluator.config.name == "stability" + assert "ORB" in stability_evaluator.config.description + assert stability_evaluator.config.weights == {"stability": 1.0} + assert stability_evaluator.config.aggregation_method == "weighted_mean" + + +def test_orb_relaxer_type(): + """Test benchmark with ORB relaxer type.""" + benchmark = StabilityBenchmark(relaxer_type="orb") + + # Check that the relaxer type is properly set + assert benchmark.config.metadata["relaxer_type"] == "orb" + + # Check that evaluator name reflects the relaxer type + expected_name = "stability" + assert benchmark.evaluators["stability"].config.name == expected_name diff --git a/tests/metrics/test_stability_metrics.py b/tests/metrics/test_stability_metrics.py new file mode 100644 index 00000000..ffd82df5 --- /dev/null +++ b/tests/metrics/test_stability_metrics.py @@ -0,0 +1,40 @@ +"""Tests for stability metrics implementation.""" + +import pytest +from pymatgen.util.testing import PymatgenTest + +from lematerial_forgebench.metrics.stability_metrics import StabilityMetric + + +@pytest.fixture +def test_structures(): + """Create test structures for stability evaluation.""" + test = PymatgenTest() + structures = [ + test.get_structure("Si"), + test.get_structure("LiFePO4"), + test.get_structure("CsCl"), + ] + return structures + + +def test_stability_metric(test_structures): + metric = StabilityMetric( + relaxer_type="orb", + relaxer_config={"steps": 500, "fmax": 0.02}, + mp_entries_file="src/lematerial_forgebench/utils/relaxers/2023-02-07-ppd-mp.pkl.gz", + ) + + result = metric(test_structures) + + # Check that computation ran (may have failed indices due to no MP entries) + assert len(result.individual_values) == len(test_structures) + + # Check result structure + assert "mean_e_above_hull" in result.metrics + assert "stable_ratio" in result.metrics + assert "metastable_ratio" in result.metrics + assert result.primary_metric == "stable_ratio" + + assert result.metrics["stable_ratio"] >= 0.0 + 
assert result.metrics["metastable_ratio"] >= 0.0 From 52d35ddb07054f7d51084314a3057a0aef4e77ce Mon Sep 17 00:00:00 2001 From: Andy Xu Date: Mon, 26 May 2025 22:29:58 -0700 Subject: [PATCH 2/4] fix linting --- src/lematerial_forgebench/utils/relaxers/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lematerial_forgebench/utils/relaxers/__init__.py b/src/lematerial_forgebench/utils/relaxers/__init__.py index c85faef7..00e43d6e 100644 --- a/src/lematerial_forgebench/utils/relaxers/__init__.py +++ b/src/lematerial_forgebench/utils/relaxers/__init__.py @@ -6,7 +6,6 @@ # Import the registry functions # Import all implementations to trigger registration -from . import relaxers from .registry import BaseRelaxer, RelaxationResult, get_relaxer, register_relaxer __all__ = [ From d346524223e48e3cd3de4699f58dcc965503b4bd Mon Sep 17 00:00:00 2001 From: Andy Xu Date: Sun, 1 Jun 2025 23:14:18 -0700 Subject: [PATCH 3/4] updated with preprocess energy calc step --- pyproject.toml | 1 + .../benchmarks/stability_benchmark.py | 103 +++-- src/lematerial_forgebench/cli.py | 20 +- .../metrics/stability_metrics.py | 266 +++++------- .../preprocess/__init__.py | 0 src/lematerial_forgebench/preprocess/base.py | 407 ++++++++++++++++++ .../preprocess/stability_preprocess.py | 196 +++++++++ .../utils/relaxers/registry.py | 2 +- .../utils/relaxers/relaxers.py | 9 +- tests/benchmarks/test_stability_benchmark.py | 95 ++-- tests/metrics/test_stability_metrics.py | 78 +++- 11 files changed, 907 insertions(+), 270 deletions(-) create mode 100644 src/lematerial_forgebench/preprocess/__init__.py create mode 100644 src/lematerial_forgebench/preprocess/base.py create mode 100644 src/lematerial_forgebench/preprocess/stability_preprocess.py diff --git a/pyproject.toml b/pyproject.toml index e56e6f7f..6fee2720 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,3 +57,4 @@ material-hasher = { git = "https://github.com/lematerial/material-hasher.git" } [tool.ruff.lint] extend-select = ["I"] 
+ignore = ["F401"] \ No newline at end of file diff --git a/src/lematerial_forgebench/benchmarks/stability_benchmark.py b/src/lematerial_forgebench/benchmarks/stability_benchmark.py index 004bc9fa..d0d8de61 100644 --- a/src/lematerial_forgebench/benchmarks/stability_benchmark.py +++ b/src/lematerial_forgebench/benchmarks/stability_benchmark.py @@ -7,8 +7,11 @@ from typing import Any, Dict from lematerial_forgebench.benchmarks.base import BaseBenchmark -from lematerial_forgebench.evaluator import EvaluatorConfig -from lematerial_forgebench.metrics.stability_metrics import StabilityMetric +from lematerial_forgebench.evaluator import EvaluationResult, EvaluatorConfig +from lematerial_forgebench.metrics.stability_metrics import ( + MetastabilityMetric, + StabilityMetric, +) class StabilityBenchmark(BaseBenchmark): @@ -16,9 +19,6 @@ class StabilityBenchmark(BaseBenchmark): def __init__( self, - relaxer_type: str = "orb", - relaxer_config: Dict[str, Any] | None = None, - mp_entries_file: str = "src/lematerial_forgebench/utils/relaxers/2023-02-07-ppd-mp.pkl.gz", name: str = "StabilityBenchmark", description: str | None = None, metadata: Dict[str, Any] | None = None, @@ -42,39 +42,37 @@ def __init__( """ if description is None: description = ( - f"Evaluates the stability of crystal structures using {relaxer_type.upper()} " - "relaxation and energy above hull calculations." 
+ "Evaluates the stability and metastability of crystal structures" ) - # Set default relaxer config if not provided - if relaxer_config is None: - relaxer_config = {"steps": 500, "fmax": 0.02} - # Initialize the stability metric - stability_metric = StabilityMetric( - relaxer_type=relaxer_type, - relaxer_config=relaxer_config, - mp_entries_file=mp_entries_file, - ) + stability_metric = StabilityMetric() - # Set up evaluator config + # Set up evaluator configs evaluator_configs = { "stability": EvaluatorConfig( - name=f"{relaxer_type.upper()} Stability", - description=f"Evaluates structure stability using {relaxer_type.upper()}", + name="Stability", + description="Evaluates structure stability", metrics={"stability": stability_metric}, weights={"stability": 1.0}, aggregation_method="weighted_mean", ), } + # Always add the metastability evaluator alongside stability + metastability_metric = MetastabilityMetric() + evaluator_configs["metastability"] = EvaluatorConfig( + name="Metastability Analysis", + description="Evaluates metastability from precomputed e_above_hull values", + metrics={"metastability": metastability_metric}, + weights={"metastability": 1.0}, + aggregation_method="weighted_mean", + ) + # Create benchmark metadata benchmark_metadata = { "version": "0.1.0", "category": "stability", - "relaxer_type": relaxer_type, - "relaxer_config": relaxer_config, - "mp_entries_file": mp_entries_file, **(metadata or {}), } @@ -86,45 +84,64 @@ def __init__( ) def aggregate_evaluator_results( - self, evaluator_results: Dict[str, Dict[str, Any]] + self, evaluator_results: Dict[str, EvaluationResult] ) -> Dict[str, float]: """Aggregate results from multiple evaluators into final scores. Parameters ---------- - evaluator_results : dict[str, dict[str, Any]] - Results from each evaluator, as structured by BaseBenchmark.evaluate. - Example: {"evaluator_name": {"combined_value": 0.X, "metric_name_value": 0.Y}} + evaluator_results : dict[str, EvaluationResult] + Results from each evaluator.
Returns ------- dict[str, float] Final aggregated scores. """ + import math + + def safe_float(value, default=0.0): + """Safely convert value to float, handling None and NaN.""" + if value is None: + return default + try: + float_val = float(value) + if math.isnan(float_val): + return default + return float_val + except (TypeError, ValueError): + return default + final_scores = { - "stability_score": 0.0, "stable_ratio": 0.0, - "mean_e_above_hull": 0.0, "metastable_ratio": 0.0, + "mean_e_above_hull": 0.0, } - stability_eval_data: Dict[str, Any] | None = evaluator_results.get("stability") - print("stability_eval_data", stability_eval_data) - if stability_eval_data: - if stability_eval_data.get("combined_value") is not None: - final_scores["stability_score"] = stability_eval_data["combined_value"] + # Extract stability results + stability_results = evaluator_results.get("stability") + if stability_results: + # Main stability ratio + final_scores["stable_ratio"] = safe_float( + stability_results.get("combined_value") + ) - if stability_eval_data.get("stability_value") is not None: - final_scores["stable_ratio"] = stability_eval_data["stability_value"] + # Extract individual metrics from stability metric + stability_metric_results = stability_results.get("metric_results", {}).get( + "stability", {} + ) + stability_metrics = stability_metric_results.get("metrics", {}) - if stability_eval_data.get("mean_e_above_hull") is not None: - final_scores["mean_e_above_hull"] = stability_eval_data[ - "mean_e_above_hull" - ] + final_scores["mean_e_above_hull"] = safe_float( + stability_metrics.get("mean_e_above_hull") + ) - if stability_eval_data.get("metastable_ratio") is not None: - final_scores["metastable_ratio"] = stability_eval_data[ - "metastable_ratio" - ] + # Extract metastability results if available + metastability_results = evaluator_results.get("metastability") + if metastability_results: + # Main metastability score + final_scores["metastable_ratio"] = safe_float( + 
metastability_results.get("combined_value") + ) return final_scores diff --git a/src/lematerial_forgebench/cli.py b/src/lematerial_forgebench/cli.py index 7876c2c3..63f1393a 100644 --- a/src/lematerial_forgebench/cli.py +++ b/src/lematerial_forgebench/cli.py @@ -11,6 +11,7 @@ import yaml from lematerial_forgebench.benchmarks.example import ExampleBenchmark +from lematerial_forgebench.benchmarks.stability_benchmark import StabilityBenchmark from lematerial_forgebench.benchmarks.validity_benchmark import ValidityBenchmark from lematerial_forgebench.data.structure import format_structures from lematerial_forgebench.metrics.validity_metrics import ( @@ -19,6 +20,9 @@ MinimumInteratomicDistanceMetric, PhysicalPlausibilityMetric, ) +from lematerial_forgebench.preprocess.stability_preprocess import ( + StabilityPreprocessor, +) from lematerial_forgebench.utils.logging import logger CONFIGS_DIR = Path(__file__).parent.parent / "config" @@ -169,13 +173,9 @@ def main(input: str, config_name: str, output: str): coord_tolerance = coord_config.get("tolerance", 0.2) # Create custom metrics with configuration - ChargeNeutralityMetric( - tolerance=charge_tolerance, strict=charge_strict - ) + ChargeNeutralityMetric(tolerance=charge_tolerance, strict=charge_strict) - MinimumInteratomicDistanceMetric( - scaling_factor=distance_scaling - ) + MinimumInteratomicDistanceMetric(scaling_factor=distance_scaling) CoordinationEnvironmentMetric( nn_method=coord_nn_method, tolerance=coord_tolerance @@ -196,6 +196,14 @@ def main(input: str, config_name: str, output: str): "metric_configs": metric_configs, }, ) + elif benchmark_type == "stability": + # before running the benchmark, we need to preprocess the structures + stability_preprocessor = StabilityPreprocessor() + # Use the preprocessor to process structures + preprocessor_result = stability_preprocessor(structures) + structures = preprocessor_result.processed_structures + + benchmark = StabilityBenchmark() else: raise ValueError(f"Unknown 
benchmark type: {benchmark_type}") diff --git a/src/lematerial_forgebench/metrics/stability_metrics.py b/src/lematerial_forgebench/metrics/stability_metrics.py index 5f72741a..3f34d547 100644 --- a/src/lematerial_forgebench/metrics/stability_metrics.py +++ b/src/lematerial_forgebench/metrics/stability_metrics.py @@ -15,62 +15,19 @@ from pymatgen.entries.compatibility import MaterialsProject2020Compatibility from lematerial_forgebench.metrics.base import BaseMetric, MetricConfig -from lematerial_forgebench.utils.e_above_hull import ( - generate_CSE, - get_patched_phase_diagram_mp, -) from lematerial_forgebench.utils.logging import logger -from lematerial_forgebench.utils.relaxers import BaseRelaxer, get_relaxer - - -@dataclass -class StabilityMetricConfig(MetricConfig): - """Configuration for the StabilityMetric. - - Parameters - ---------- - relaxer_type : str - Type of relaxer to use (e.g., "chgnet", "eqv2", "esen"). - relaxer_config : dict - Configuration for the specific relaxer. - mp_entries_file : str, optional - Path to the Materials Project entries file for e_above_hull calculation. - """ - - relaxer_type: str = "orb" - relaxer_config: Dict[str, Any] = field(default_factory=dict) - mp_entries_file: Optional[str] = None class StabilityMetric(BaseMetric): - """Evaluate structure relaxation and energy above hull. - - This metric handles both the relaxation of structures using various models - and the calculation of energy above hull using the Materials Project database. - - Parameters - ---------- - relaxer_type : str - Type of relaxer to use. - relaxer_config : dict - Configuration for the specific relaxer. - mp_entries_file : str, optional - Path to the Materials Project entries file. - name : str, optional - Custom name for the metric. - description : str, optional - Description of what the metric measures. - lower_is_better : bool, default=True - Lower energies indicate better stability. - n_jobs : int, default=1 - Number of parallel jobs to run. 
+ """Evaluate structure stability using precomputed e_above_hull values. + + This metric assumes that e_above_hull values have already been computed + and stored in structure.properties['e_above_hull']. It calculates stability + statistics without performing any relaxation or recomputation. """ def __init__( self, - relaxer_type: str, - relaxer_config: Dict[str, Any], - mp_entries_file: Optional[str] = None, name: str | None = None, description: str | None = None, lower_is_better: bool = False, @@ -78,121 +35,143 @@ def __init__( ): super().__init__( name=name or "StabilityMetric", - description=description or "Evaluates structure stability", + description=description + or "Evaluates structure stability from precomputed e_above_hull", lower_is_better=lower_is_better, n_jobs=n_jobs, ) - self.config = StabilityMetricConfig( - name=self.config.name, - description=self.config.description, - lower_is_better=self.config.lower_is_better, - n_jobs=self.config.n_jobs, - relaxer_type=relaxer_type, - relaxer_config=relaxer_config, - mp_entries_file=mp_entries_file, - ) - # Initialize the relaxer - self.relaxer = get_relaxer(relaxer_type, **relaxer_config) + def _get_compute_attributes(self) -> dict[str, Any]: + """Get the attributes for the compute_structure method.""" + return {} + + @staticmethod + def compute_structure(structure: Structure) -> float: + """Extract precomputed e_above_hull from structure properties. + + Parameters + ---------- + structure : Structure + A pymatgen Structure object with e_above_hull in properties. + + Returns + ------- + float + The precomputed e_above_hull value, or NaN if not available.
+ """ + try: + # Extract e_above_hull from structure properties + e_above_hull = structure.properties.get("e_above_hull", None) + + if e_above_hull is None: + logger.warning( + "Structure missing e_above_hull in properties, please compute it first using StabilityMetric" + ) + return np.nan - # Initialize MP compatibility - self.compatibility = MaterialsProject2020Compatibility(check_potcar=False) + return float(e_above_hull) - # Load MP entries if file provided - self.mp_entries = None - # check if path exists - if mp_entries_file and Path(mp_entries_file).exists(): - self.load_mp_entries(mp_entries_file) + except Exception as e: + logger.error(f"Failed to extract e_above_hull: {str(e)}") + return np.nan - def load_mp_entries(self, mp_entries_file: str) -> None: - """Load Materials Project entries for e_above_hull calculation. + def aggregate_results(self, values: list[float]) -> dict[str, Any]: + """Aggregate results into final metric values. Parameters ---------- - mp_entries_file : str - Path to the MP entries file. + values : list[float] + List of e_above_hull values for each structure. + + Returns + ------- + dict + Dictionary with aggregated metrics. 
""" - self.mp_entries = get_patched_phase_diagram_mp(Path(mp_entries_file)) - # import pandas as pd - - # try: - # df = pd.read_json(mp_entries_file) - # df = df[df["entry"].apply(lambda x: "GGA" in x["entry_id"])] - - # mp_computed_entries = [ - # ComputedEntry.from_dict(x) - # for x in df.entry - # if "GGA" in x["parameters"]["run_type"] - # ] - # self.mp_entries = [ - # entry - # for entry in mp_computed_entries - # if not np.any(["R2SCAN" in a.name for a in entry.energy_adjustments]) - # ] - # logger.info(f"Loaded {len(self.mp_entries)} MP entries") - # except Exception as e: - # logger.error(f"Failed to load MP entries: {str(e)}") - # self.mp_entries = None + # Convert to numpy array for efficient operations + values_array = np.array(values) + + # Filter out NaN values + valid_mask = ~np.isnan(values_array) + e_above_hull_values = values_array[valid_mask] + + if len(e_above_hull_values) > 0: + # Calculate ratio of stable structures (e_above_hull <= 0) using numpy + stable_ratio = np.sum(e_above_hull_values <= 0) / len(values) + e_above_hull_std = np.std(e_above_hull_values) + mean_e_above_hull = np.mean(e_above_hull_values) + else: + stable_ratio = 0.0 + e_above_hull_std = 0.0 + mean_e_above_hull = 0.0 - def _get_compute_attributes(self) -> dict[str, Any]: - """Get the attributes for the compute_structure method.""" return { - "relaxer": self.relaxer, - "compatibility": self.compatibility, - "mp_entries": self.mp_entries, + "metrics": { + "stable_ratio": stable_ratio, + "mean_e_above_hull": mean_e_above_hull, + }, + "primary_metric": "stable_ratio", + "uncertainties": { + "e_above_hull_std": {"std": e_above_hull_std}, + }, } + +class MetastabilityMetric(BaseMetric): + """Evaluate structure metastability using precomputed e_above_hull values. + + This metric assumes that e_above_hull values have already been computed + and stored in structure.properties['e_above_hull']. It calculates stability + statistics without performing any relaxation or recomputation. 
+ """ + + def __init__( + self, + name: str | None = None, + description: str | None = None, + lower_is_better: bool = True, + n_jobs: int = 1, + ): + super().__init__( + name=name or "MetastabilityMetric", + description=description + or "Evaluates structure metastability from precomputed e_above_hull", + lower_is_better=lower_is_better, + n_jobs=n_jobs, + ) + + def _get_compute_attributes(self) -> dict[str, Any]: + """Get the attributes for the compute_structure method.""" + return {} + @staticmethod - def compute_structure( - structure: Structure, - relaxer: BaseRelaxer, - compatibility: MaterialsProject2020Compatibility, - mp_entries: Optional[PatchedPhaseDiagram] = None, - ) -> Dict[str, Any]: - """Compute relaxation and energy above hull for a structure. + def compute_structure(structure: Structure) -> float: + """Extract precomputed e_above_hull from structure properties. Parameters ---------- structure : Structure - A pymatgen Structure object to evaluate. - relaxer : BaseRelaxer - Relaxer object to use. - compatibility : MaterialsProject2020Compatibility - MP compatibility scheme. - mp_entries : list[ComputedEntry], optional - MP entries for e_above_hull calculation. + A pymatgen Structure object with e_above_hull in properties. Returns ------- - dict - Dictionary containing relaxation results and e_above_hull. + float + The precomputed e_above_hull value, or NaN if not available. 
""" try: - # Relax structure - relaxation_result = relaxer.relax(structure) - if not relaxation_result.success: - logger.error(f"Relaxation failed: {relaxation_result.message}") - return np.nan + # Extract e_above_hull from structure properties + e_above_hull = structure.properties.get("e_above_hull", None) - result = { - "success": True, - "relaxed_energy": relaxation_result.energy, - "relaxed_structure": relaxation_result.structure, - "e_above_hull": None, - } - - # Calculate e_above_hull if MP entries are available - if mp_entries is not None: - # Calculate e_above_hull - cse = generate_CSE( - relaxation_result.structure, relaxation_result.energy + if e_above_hull is None: + logger.warning( + f"Structure `{structure.formula}` missing e_above_hull in properties, please compute it first using StabilityPreprocessor" ) + return np.nan + + return float(e_above_hull) - e_above_hull = mp_entries.get_e_above_hull(cse, allow_negative=True) - result["e_above_hull"] = e_above_hull - return result["e_above_hull"] except Exception as e: - logger.error(f"Computation failed: {str(e)}") + logger.error(f"Failed to extract e_above_hull: {str(e)}") return np.nan def aggregate_results(self, values: list[float]) -> dict[str, Any]: @@ -201,7 +180,7 @@ def aggregate_results(self, values: list[float]) -> dict[str, Any]: Parameters ---------- values : list[float] - List of computation results for each structure. + List of e_above_hull values for each structure. 
Returns ------- @@ -216,28 +195,15 @@ def aggregate_results(self, values: list[float]) -> dict[str, Any]: e_above_hull_values = values_array[valid_mask] if len(e_above_hull_values) > 0: - mean_e_above_hull = np.mean(e_above_hull_values) - e_above_hull_std = ( - np.std(e_above_hull_values) if len(e_above_hull_values) > 1 else 0.0 - ) - # Calculate ratio of stable structures (e_above_hull <= 0) using numpy - stable_ratio = np.sum(e_above_hull_values <= 0) / len(values) # Calculate ratio of metastable structures (e_above_hull <= 0.1) using numpy metastable_ratio = np.sum(e_above_hull_values <= 0.1) / len(values) else: - mean_e_above_hull = 0.0 - e_above_hull_std = 0.0 - stable_ratio = 0.0 metastable_ratio = 0.0 return { "metrics": { - "mean_e_above_hull": mean_e_above_hull, - "stable_ratio": stable_ratio, "metastable_ratio": metastable_ratio, }, - "primary_metric": "stable_ratio", - "uncertainties": { - "e_above_hull_std": {"std": e_above_hull_std}, - }, + "primary_metric": "metastable_ratio", + "uncertainties": {}, } diff --git a/src/lematerial_forgebench/preprocess/__init__.py b/src/lematerial_forgebench/preprocess/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/lematerial_forgebench/preprocess/base.py b/src/lematerial_forgebench/preprocess/base.py new file mode 100644 index 00000000..629432fb --- /dev/null +++ b/src/lematerial_forgebench/preprocess/base.py @@ -0,0 +1,407 @@ +import time +from abc import ABC, abstractmethod +from concurrent.futures import ProcessPoolExecutor +from dataclasses import dataclass +from pathlib import Path +from typing import Any, List, Optional, TypeVar + +import numpy as np +import pandas as pd +from pymatgen.core.structure import Structure + +from lematerial_forgebench.data.structure import format_structures +from lematerial_forgebench.utils.logging import logger + +PreprocessorClassVar = TypeVar("PreprocessorClassVar", bound="BasePreprocessor") + + +@dataclass +class PreprocessorConfig: + """Base 
configuration for all preprocessors. + + This class defines the common configuration parameters shared by all preprocessors. + Specific preprocessors should inherit from this class and add their own parameters. + + Parameters + ---------- + name : str, optional + Custom name for the preprocessor. If None, the class name will be used. + description : str, optional + Description of what the preprocessor does. + n_jobs : int, default=1 + Number of parallel jobs to run. + """ + + name: str | None = None + description: str | None = None + n_jobs: int = 1 + + def to_dict(self) -> dict[str, Any]: + """Convert the preprocessor configuration to a dictionary for serialization. + + Returns + ------- + dict[str, Any] + """ + return { + "name": self.name, + "description": self.description, + "n_jobs": self.n_jobs, + } + + +@dataclass +class PreprocessorResult: + """Result of a preprocessing computation. + + Parameters + ---------- + processed_structures : list[Structure] + List of successfully processed pymatgen Structure objects. + config : PreprocessorConfig + The configuration used for this preprocessing task. + computation_time : float + The time taken to complete the preprocessing. + n_input_structures : int + The total number of structures provided as input. + failed_indices : list[int] + The indices (from the original input list) of structures that failed processing. + warnings : list[str] + The warnings generated during the computation. + """ + + processed_structures: list[Structure] + config: PreprocessorConfig + computation_time: float + n_input_structures: int + failed_indices: list[int] + warnings: list[str] + + def __post_init__(self): + """Validate the preprocessor result.""" + if self.n_input_structures < ( + len(self.processed_structures) + len(self.failed_indices) + ): + # This check might be too strict if one input can result in multiple outputs or some other complex scenario + # For now, it assumes a one-to-one mapping or failure. 
+ logger.warning( + "Number of input structures is less than processed + failed. This might indicate an issue." + ) + if self.config is None: + raise ValueError("config cannot be None") + + +class BasePreprocessor(ABC): + """Base class for all preprocessors used in materials science workflows. + + This class defines the interface for preprocessors and provides common functionality + like parallelization. + + Parameters + ---------- + name : str, optional + Custom name for the preprocessor. If None, the class name will be used. + description : str, optional + Description of what the preprocessor does. + n_jobs : int, default=1 + Number of parallel jobs to run. + """ + + def __init__( + self, + name: str | None = None, + description: str | None = None, + n_jobs: int = 1, + ): + self.config = PreprocessorConfig( + name=name or self.__class__.__name__, + description=description, + n_jobs=n_jobs, + ) + + @property + def name(self) -> str: + """Get the name of the preprocessor. + + Returns + ------- + str + """ + return self.config.name + + @property + def description(self) -> str: + """Get the description of the preprocessor. + + Returns + ------- + str + """ + return self.config.description or "No description provided." + + @staticmethod + @abstractmethod + def process_structure(structure: Structure, **process_args: Any) -> Structure: + """Process a single structure. + + This is the main method to implement in derived classes. + It should take a pymatgen Structure object, perform operations on it + (e.g., add properties, modify it), and return the processed Structure object. + If processing fails, it should raise an exception. + + Parameters + ---------- + structure : Structure + A pymatgen Structure object to process. + **process_args : Any + Additional keyword arguments that depend on the preprocessor implementation. + + Returns + ------- + Structure + The processed (e.g., annotated or modified) pymatgen Structure object. 
+ + Raises + ------ + Exception + If processing of the structure fails. + """ + pass + + @staticmethod + def _process_batch( + structures: list[Structure], + process_args: dict[str, Any], + preprocessor_class: PreprocessorClassVar, + ) -> tuple[List[Optional[Structure]], List[int], List[str]]: + """Process a batch of structures. + + Parameters + ---------- + structures : list[Structure] + Batch of structures to process. + process_args : dict[str, Any] + Additional keyword arguments for the process_structure method. + preprocessor_class : PreprocessorClassVar + The preprocessor class to use for the computation. + + Returns + ------- + tuple[List[Optional[Structure]], List[int], List[str]] + Tuple containing: + - List of processed Structure objects or None for failures. + - List of indices (local to batch) of structures that failed. + - List of warning messages for failures. + """ + batch_results: List[Optional[Structure]] = [] + failed_indices_in_batch: List[int] = [] + warnings_for_batch: List[str] = [] + + for idx, structure in enumerate(structures): + try: + processed_structure = preprocessor_class.process_structure( + structure, **process_args + ) + batch_results.append(processed_structure) + except Exception as e: + batch_results.append(None) + failed_indices_in_batch.append(idx) + warnings_for_batch.append( + f"Failed to process structure at batch index {idx} (original may vary): {str(e)}" + ) + logger.debug( + f"Failed to process structure in batch for {preprocessor_class.name}", + exc_info=True, + ) + return batch_results, failed_indices_in_batch, warnings_for_batch + + def _split_into_batches( + self, structures: list[Structure], batch_size: int + ) -> list[list[Structure]]: + """Split structures into batches. + + Parameters + ---------- + structures : list[Structure] + List of structures to split. + batch_size : int + Size of each batch. + + Returns + ------- + list[list[Structure]] + List of batches of structures. 
+ """ + batches = [ + structures[i : i + batch_size] + for i in range(0, len(structures), batch_size) + ] + + def _get_process_attributes(self) -> dict[str, Any]: + """Get additional attributes/arguments for the process_structure method. + Subclasses can override this to pass dynamic configuration. + + Returns + ------- + dict[str, Any] + """ + return {} + + def run( + self, + structures: list[Structure], + ) -> PreprocessorResult: + """Run the preprocessing on a list of structures. + + Parameters + ---------- + structures : list[Structure] + List of pymatgen Structure objects to process. + + Returns + ------- + PreprocessorResult + Object containing the processed structures and computation metadata. + """ + start_time = time.time() + n_input = len(structures) + + # This list will hold results in the original order, with None for failures. + all_results_with_nones: List[Optional[Structure]] = [None] * n_input + global_failed_indices: List[int] = [] + global_warnings: List[str] = [] + + process_args = self._get_process_attributes() + + try: + if ( + self.config.n_jobs <= 1 or n_input <= 1 + ): # Also run serially for single structure + # Serial computation + for idx, structure in enumerate(structures): + try: + processed_structure = self.process_structure( + structure, **process_args + ) + all_results_with_nones[idx] = processed_structure + except Exception as e: + global_failed_indices.append(idx) + global_warnings.append( + f"Failed to process structure {idx}: {str(e)}" + ) + logger.warning( + f"Failed to process structure {idx} for {self.name}", + exc_info=True, + ) + else: + # Parallel computation + # Ensure batch_size is at least 1, and doesn't create more batches than structures or jobs + num_workers = min(self.config.n_jobs, n_input) + batch_size = max( + 1, (n_input + num_workers - 1) // num_workers + ) # Distribute as evenly as possible + + batches = self._split_into_batches(structures, batch_size) + + with ProcessPoolExecutor(max_workers=num_workers) as 
executor: + futures = [ + executor.submit( + self._process_batch, batch, process_args, self.__class__ + ) + for batch in batches + ] + + current_original_idx_offset = 0 + for future in futures: + ( + batch_processed_structures_or_nones, + failed_indices_in_batch, + warnings_for_batch, + ) = future.result() + + for i, struct_or_none in enumerate( + batch_processed_structures_or_nones + ): + original_idx = current_original_idx_offset + i + if original_idx < n_input: # Boundary check + all_results_with_nones[original_idx] = struct_or_none + + global_failed_indices.extend( + [ + current_original_idx_offset + i + for i in failed_indices_in_batch + ] + ) + global_warnings.extend(warnings_for_batch) + current_original_idx_offset += len( + batch_processed_structures_or_nones + ) + + # Filter out Nones to get successfully processed structures + final_processed_structures = [ + s for s in all_results_with_nones if s is not None + ] + + except Exception as e: + logger.error(f"Global failure in preprocessor {self.name}", exc_info=True) + return PreprocessorResult( + processed_structures=[], + config=self.config, + computation_time=time.time() - start_time, + n_input_structures=n_input, + failed_indices=list(range(n_input)), # All failed + warnings=[f"Global preprocessing failure for {self.name}: {str(e)}"] + * n_input, + ) + + return PreprocessorResult( + processed_structures=final_processed_structures, + config=self.config, + computation_time=time.time() - start_time, + n_input_structures=n_input, + failed_indices=sorted( + list(set(global_failed_indices)) + ), # Ensure unique and sorted + warnings=global_warnings, + ) + + def __call__( + self, + structures: list[Structure] | list[dict] | pd.DataFrame | str | Path, + ) -> PreprocessorResult: + """Convenient callable interface for running the preprocessor. + + Parameters + ---------- + structures : list[Structure] | list[dict] | pd.DataFrame | str | Path + Structures to process, in various supported formats. 
+ + Returns + ------- + PreprocessorResult + Object containing the processed structures and computation metadata. + """ + structures_list = format_structures(structures) + if not isinstance( + structures_list, list + ): # Ensure it's a list for consistent processing + structures_list = list(structures_list) + return self.run(structures_list) + + @classmethod + def from_config(cls, config: PreprocessorConfig) -> PreprocessorClassVar: + """Create a preprocessor from a configuration. + + Parameters + ---------- + config : PreprocessorConfig + Configuration for the preprocessor. + """ + # Extract only relevant args for BasePreprocessor constructor + # Subclasses might need to handle additional args from their specific configs + base_args = { + k: v + for k, v in config.to_dict().items() + if k in ["name", "description", "n_jobs"] + } + return cls(**base_args) diff --git a/src/lematerial_forgebench/preprocess/stability_preprocess.py b/src/lematerial_forgebench/preprocess/stability_preprocess.py new file mode 100644 index 00000000..e0e1f33f --- /dev/null +++ b/src/lematerial_forgebench/preprocess/stability_preprocess.py @@ -0,0 +1,196 @@ +"""Relaxation metrics for evaluating material structures. + +This module implements metrics for evaluating the relaxation of +material structures using various relaxation models and calculating +energy above hull. 
+""" + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, Optional + +import numpy as np +from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram +from pymatgen.core import Structure +from pymatgen.entries.compatibility import MaterialsProject2020Compatibility + +from lematerial_forgebench.preprocess.base import BasePreprocessor, PreprocessorConfig +from lematerial_forgebench.utils.e_above_hull import ( + generate_CSE, + get_patched_phase_diagram_mp, +) +from lematerial_forgebench.utils.logging import logger +from lematerial_forgebench.utils.relaxers import ( + BaseRelaxer, + get_relaxer, + relaxers, +) + + +@dataclass +class StabilityPreprocessorConfig(PreprocessorConfig): + """Configuration for the StabilityPreprocessor. + + Parameters + ---------- + relaxer_type : str + Type of relaxer to use (e.g., "chgnet", "eqv2", "esen"). + relaxer_config : dict + Configuration for the specific relaxer. + mp_entries_file : str, optional + Path to the Materials Project entries file for e_above_hull calculation. + """ + + relaxer_type: str = "orb" + relaxer_config: Dict[str, Any] = field(default_factory=dict) + mp_entries_file: Optional[str] = None + + +class StabilityPreprocessor(BasePreprocessor): + """Evaluate structure relaxation and energy above hull. + + This metric handles both the relaxation of structures using various models + and the calculation of energy above hull using the Materials Project database. + + Parameters + ---------- + relaxer_type : str + Type of relaxer to use. + relaxer_config : dict + Configuration for the specific relaxer. + mp_entries_file : str, optional + Path to the Materials Project entries file. + name : str, optional + Custom name for the metric. + description : str, optional + Description of what the metric measures. + lower_is_better : bool, default=True + Lower energies indicate better stability. + n_jobs : int, default=1 + Number of parallel jobs to run. 
+ """ + + def __init__( + self, + relaxer_type: str = "orb", + relaxer_config: Dict[str, Any] = {"fmax": 0.02, "steps": 500}, + mp_entries_file: Optional[ + str + ] = "src/lematerial_forgebench/utils/relaxers/2023-02-07-ppd-mp.pkl.gz", + name: str | None = None, + description: str | None = None, + n_jobs: int = 1, + ): + super().__init__( + name=name or "StabilityPreprocessor", + description=description or "Preprocesses structures for stability analysis", + n_jobs=n_jobs, + ) + self.config = StabilityPreprocessorConfig( + name=self.config.name, + description=self.config.description, + n_jobs=self.config.n_jobs, + relaxer_type=relaxer_type, + relaxer_config=relaxer_config, + mp_entries_file=mp_entries_file, + ) + + # Initialize the relaxer + self.relaxer = get_relaxer(relaxer_type, **relaxer_config) + + # Load MP entries if file provided + self.mp_entries = None + # check if path exists + if mp_entries_file and Path(mp_entries_file).exists(): + self.load_mp_entries(mp_entries_file) + + def load_mp_entries(self, mp_entries_file: str) -> None: + """Load Materials Project entries for e_above_hull calculation. + + Parameters + ---------- + mp_entries_file : str + Path to the MP entries file. 
+ """ + self.mp_entries = get_patched_phase_diagram_mp(Path(mp_entries_file)) + # import pandas as pd + + # try: + # df = pd.read_json(mp_entries_file) + # df = df[df["entry"].apply(lambda x: "GGA" in x["entry_id"])] + + # mp_computed_entries = [ + # ComputedEntry.from_dict(x) + # for x in df.entry + # if "GGA" in x["parameters"]["run_type"] + # ] + # self.mp_entries = [ + # entry + # for entry in mp_computed_entries + # if not np.any(["R2SCAN" in a.name for a in entry.energy_adjustments]) + # ] + # logger.info(f"Loaded {len(self.mp_entries)} MP entries") + # except Exception as e: + # logger.error(f"Failed to load MP entries: {str(e)}") + # self.mp_entries = None + + def _get_process_attributes(self) -> dict[str, Any]: + """Get the attributes for the process_structure method.""" + return { + "relaxer": self.relaxer, + "mp_entries": self.mp_entries, + } + + @staticmethod + def process_structure( + structure: Structure, + relaxer: BaseRelaxer, + mp_entries: Optional[PatchedPhaseDiagram] = None, + ) -> Structure: + """Process a single structure by relaxing it and computing e_above_hull. + + Parameters + ---------- + structure : Structure + A pymatgen Structure object to process. + relaxer : BaseRelaxer + Relaxer object to use. + mp_entries : PatchedPhaseDiagram, optional + MP entries for e_above_hull calculation. + + Returns + ------- + Structure + The processed Structure with relaxed geometry and e_above_hull in properties. + + Raises + ------ + Exception + If relaxation fails or other processing errors occur. 
+ """ + # Relax structure + relaxation_result = relaxer.relax(structure, relax=False) + if not relaxation_result.success: + raise RuntimeError(f"Relaxation failed: {relaxation_result.message}") + + processed_structure = relaxation_result.structure + + # Calculate e_above_hull if MP entries are available + if mp_entries is not None: + try: + cse = generate_CSE(processed_structure, relaxation_result.energy) + e_above_hull = mp_entries.get_e_above_hull(cse, allow_negative=True) + processed_structure.properties["e_above_hull"] = e_above_hull + logger.debug( + f"Computed e_above_hull: {e_above_hull:.3f} eV/atom for {processed_structure.formula}" + ) + except Exception as e: + logger.warning( + f"Failed to compute e_above_hull for {processed_structure.formula}: {str(e)}" + ) + # Still return the relaxed structure even if e_above_hull calculation fails + + # Store additional processing metadata + processed_structure.properties["relaxed_energy"] = relaxation_result.energy + + return processed_structure diff --git a/src/lematerial_forgebench/utils/relaxers/registry.py b/src/lematerial_forgebench/utils/relaxers/registry.py index 23ead105..091091cd 100644 --- a/src/lematerial_forgebench/utils/relaxers/registry.py +++ b/src/lematerial_forgebench/utils/relaxers/registry.py @@ -38,7 +38,7 @@ class BaseRelaxer(ABC): """ @abstractmethod - def relax(self, structure: Structure) -> RelaxationResult: + def relax(self, structure: Structure, relax: bool = False) -> RelaxationResult: """Relax a structure and return the result. 
Parameters diff --git a/src/lematerial_forgebench/utils/relaxers/relaxers.py b/src/lematerial_forgebench/utils/relaxers/relaxers.py index c15545bf..142b0b45 100644 --- a/src/lematerial_forgebench/utils/relaxers/relaxers.py +++ b/src/lematerial_forgebench/utils/relaxers/relaxers.py @@ -122,13 +122,15 @@ def _setup_calculator(self): """ pass - def relax(self, structure: Structure) -> RelaxationResult: + def relax(self, structure: Structure, relax: bool = False) -> RelaxationResult: """Relax a structure using calculator. Parameters ---------- structure : Structure Structure to relax. + relax : bool, default=False + Only relaxes the structure if True. Returns ------- @@ -149,8 +151,9 @@ def relax(self, structure: Structure) -> RelaxationResult: atoms.calc = self.calc # Relax structure - dyn = FIRE(FrechetCellFilter(atoms), logfile=None) - dyn.run(fmax=self.fmax, steps=self.steps) + if relax: + dyn = FIRE(FrechetCellFilter(atoms), logfile=None) + dyn.run(fmax=self.fmax, steps=self.steps) # Get results final_energy = atoms.get_potential_energy() diff --git a/tests/benchmarks/test_stability_benchmark.py b/tests/benchmarks/test_stability_benchmark.py index 72b827f0..d59f7d4b 100644 --- a/tests/benchmarks/test_stability_benchmark.py +++ b/tests/benchmarks/test_stability_benchmark.py @@ -3,6 +3,9 @@ from pymatgen.util.testing import PymatgenTest from lematerial_forgebench.benchmarks.stability_benchmark import StabilityBenchmark +from lematerial_forgebench.preprocess.stability_preprocess import ( + StabilityPreprocessor, +) class TestStabilityBenchmark: @@ -10,28 +13,21 @@ class TestStabilityBenchmark: def test_initialization_default(self): """Test initialization with default parameters.""" - benchmark = StabilityBenchmark(mp_entries_file=None) + benchmark = StabilityBenchmark() # Check name and properties assert benchmark.config.name == "StabilityBenchmark" assert "version" in benchmark.config.metadata assert benchmark.config.metadata["category"] == "stability" - # Check 
default relaxer type - assert benchmark.config.metadata["relaxer_type"] == "orb" - # Check correct evaluator - assert len(benchmark.evaluators) == 1 + assert len(benchmark.evaluators) == 2 assert "stability" in benchmark.evaluators def test_initialization_custom(self): """Test initialization with custom relaxer configuration.""" - relaxer_config = {"steps": 2000, "fmax": 0.01, "direct": True} benchmark = StabilityBenchmark( - relaxer_type="orb", - relaxer_config=relaxer_config, - mp_entries_file="custom/path/mp_entries.json", name="Custom Stability Benchmark", description="Custom description", metadata={"test_key": "test_value"}, @@ -41,12 +37,6 @@ def test_initialization_custom(self): assert benchmark.config.name == "Custom Stability Benchmark" assert benchmark.config.description == "Custom description" assert benchmark.config.metadata["test_key"] == "test_value" - assert benchmark.config.metadata["relaxer_type"] == "orb" - assert benchmark.config.metadata["relaxer_config"] == relaxer_config - assert ( - benchmark.config.metadata["mp_entries_file"] - == "custom/path/mp_entries.json" - ) def test_evaluate_with_mp_entries(self): """Test benchmark evaluation on structures""" @@ -56,18 +46,22 @@ def test_evaluate_with_mp_entries(self): test = PymatgenTest() structures = [test.get_structure("Si"), test.get_structure("LiFePO4")] + # first, we need to preprocess the structures + stability_preprocessor = StabilityPreprocessor() + preprocessor_result = stability_preprocessor(structures) + structures = preprocessor_result.processed_structures + print(structures[0].properties) + # Run benchmark result = benchmark.evaluate(structures) # Check result format - assert len(result.evaluator_results) == 1 + assert len(result.evaluator_results) == 2 assert "stability" in result.evaluator_results - assert "stability_score" in result.final_scores assert "stable_ratio" in result.final_scores print(result.final_scores) # Check score types - assert 
isinstance(result.final_scores["stability_score"], (int, float)) assert isinstance(result.final_scores["stable_ratio"], (int, float)) assert isinstance(result.final_scores["metastable_ratio"], (int, float)) assert isinstance(result.final_scores["mean_e_above_hull"], (int, float)) @@ -80,7 +74,6 @@ def test_empty_structures(self): result = benchmark.evaluate([]) # Should get default values - assert result.final_scores["stability_score"] == 0.0 assert result.final_scores["stable_ratio"] == 0.0 assert result.final_scores["metastable_ratio"] == 0.0 assert result.final_scores["mean_e_above_hull"] == 0.0 @@ -95,10 +88,18 @@ def test_aggregate_evaluator_results(self): mock_evaluator_results_from_base = { "stability": { # Name of the evaluator "combined_value": 0.75, # Evaluator's combined score - # The StabilityMetric instance within the 'stability' evaluator was named "stability". - # Its primary metric (stable_ratio) is flattened by BaseBenchmark to "stability_value". - "stability_value": 0.6, - } + "metric_results": { + "stability": { + "metrics": { + "stable_ratio": 0.75, + "mean_e_above_hull": 0.1, + } + } + }, + }, + "metastability": { + "combined_value": 0.85, + }, } # Aggregate results @@ -109,66 +110,28 @@ def test_aggregate_evaluator_results(self): # and stability_value as stable_ratio. # mean_e_above_hull and metastable_ratio will be defaults (nan, 0.0) because # they are not present in the input dict. 
- assert scores["stability_score"] == 0.75 - assert scores["stable_ratio"] == 0.6 - assert ( - "mean_e_above_hull" in scores - and scores["mean_e_above_hull"] != scores["mean_e_above_hull"] - ) # Check for NaN - assert scores["metastable_ratio"] == 0.0 + assert scores["stable_ratio"] == 0.75 + assert scores["metastable_ratio"] == 0.85 + assert scores["mean_e_above_hull"] == 0.1 def test_benchmark_metadata(self): """Test benchmark metadata structure.""" - benchmark = StabilityBenchmark( - relaxer_type="orb", - relaxer_config={"direct": True}, - mp_entries_file="test.json", - ) + benchmark = StabilityBenchmark() metadata = benchmark.config.metadata # Check required metadata fields assert metadata["version"] == "0.1.0" assert metadata["category"] == "stability" - assert metadata["relaxer_type"] == "orb" - assert metadata["relaxer_config"]["direct"] is True - assert metadata["mp_entries_file"] == "test.json" - - -def test_benchmark_description_generation(): - """Test automatic description generation.""" - # Test with ORB relaxer - orb_benchmark = StabilityBenchmark(relaxer_type="orb") - assert "ORB" in orb_benchmark.config.description - assert "relaxation and energy above hull" in orb_benchmark.config.description - - # Test with custom description - custom_benchmark = StabilityBenchmark( - relaxer_type="orb", description="Custom description" - ) - assert custom_benchmark.config.description == "Custom description" def test_evaluator_configuration(): """Test that evaluator is properly configured.""" - benchmark = StabilityBenchmark(relaxer_type="orb") + benchmark = StabilityBenchmark() # Check evaluator configuration stability_evaluator = benchmark.evaluators["stability"] print(stability_evaluator) assert stability_evaluator.config.name == "stability" - assert "ORB" in stability_evaluator.config.description assert stability_evaluator.config.weights == {"stability": 1.0} assert stability_evaluator.config.aggregation_method == "weighted_mean" - - -def 
test_orb_relaxer_type(): - """Test benchmark with ORB relaxer type.""" - benchmark = StabilityBenchmark(relaxer_type="orb") - - # Check that the relaxer type is properly set - assert benchmark.config.metadata["relaxer_type"] == "orb" - - # Check that evaluator name reflects the relaxer type - expected_name = "stability" - assert benchmark.evaluators["stability"].config.name == expected_name diff --git a/tests/metrics/test_stability_metrics.py b/tests/metrics/test_stability_metrics.py index ffd82df5..e7f32288 100644 --- a/tests/metrics/test_stability_metrics.py +++ b/tests/metrics/test_stability_metrics.py @@ -3,7 +3,10 @@ import pytest from pymatgen.util.testing import PymatgenTest -from lematerial_forgebench.metrics.stability_metrics import StabilityMetric +from lematerial_forgebench.metrics.stability_metrics import ( + MetastabilityMetric, + StabilityMetric, +) @pytest.fixture @@ -18,6 +21,27 @@ def test_structures(): return structures +@pytest.fixture +def test_structures_with_precomputed_e_above_hull(): + """Create test structures with precomputed e_above_hull values.""" + test = PymatgenTest() + structures = [ + test.get_structure("Si"), + test.get_structure("LiFePO4"), + test.get_structure("CsCl"), + test.get_structure("NaCl"), + test.get_structure("Fe2O3"), + ] + + # Add precomputed e_above_hull values to structure properties + e_above_hull_values = [0.05, 0.0, 0.15, 0.02, 0.08] # Mix of stable and metastable + + for structure, e_above_hull in zip(structures, e_above_hull_values): + structure.properties["e_above_hull"] = e_above_hull + + return structures + + def test_stability_metric(test_structures): metric = StabilityMetric( relaxer_type="orb", @@ -38,3 +62,55 @@ def test_stability_metric(test_structures): assert result.metrics["stable_ratio"] >= 0.0 assert result.metrics["metastable_ratio"] >= 0.0 + + +def test_metastability_metric(test_structures_with_precomputed_e_above_hull): + """Test MetastabilityMetric with precomputed e_above_hull values.""" + 
metric = MetastabilityMetric() + + result = metric(test_structures_with_precomputed_e_above_hull) + + # Check that computation ran successfully + assert len(result.individual_values) == len( + test_structures_with_precomputed_e_above_hull + ) + + # Check result structure + assert "mean_e_above_hull" in result.metrics + assert "metastable_ratio" in result.metrics + assert result.primary_metric == "metastable_ratio" + + # Check specific values based on our test data + # e_above_hull_values = [0.05, 0.0, 0.15, 0.02, 0.08] + # metastable (≤ 0.1): 4 structures (all except CsCl with 0.15) + + assert result.metrics["metastable_ratio"] == 4 / 5 # 4 out of 5 structures + assert ( + abs(result.metrics["mean_e_above_hull"] - 0.06) < 1e-6 + ) # (0.05+0.0+0.15+0.02+0.08)/5 = 0.06 + + # Check that all individual values were extracted correctly + expected_values = [0.05, 0.0, 0.15, 0.02, 0.08] + assert result.individual_values == expected_values + + +def test_metastability_metric_missing_properties(): + """Test MetastabilityMetric with structures missing e_above_hull properties.""" + test = PymatgenTest() + structures = [ + test.get_structure("Si"), + test.get_structure("LiFePO4"), + ] + + # Don't add e_above_hull properties - should return NaN values + metric = MetastabilityMetric() + result = metric(structures) + + # Should have NaN values for structures without properties + import numpy as np + + assert all(np.isnan(val) for val in result.individual_values) + + # Metrics should be 0 when all values are NaN + assert result.metrics["metastable_ratio"] == 0.0 + assert result.metrics["mean_e_above_hull"] == 0.0 From 7729d8b59bb740b0989eb9d8e6b8a4d6bb3409f4 Mon Sep 17 00:00:00 2001 From: Andy Xu Date: Sun, 1 Jun 2025 23:19:24 -0700 Subject: [PATCH 4/4] linting fixes --- .../metrics/stability_metrics.py | 12 +++--------- src/lematerial_forgebench/preprocess/base.py | 2 +- .../preprocess/stability_preprocess.py | 2 -- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git 
a/src/lematerial_forgebench/metrics/stability_metrics.py b/src/lematerial_forgebench/metrics/stability_metrics.py index 3f34d547..678e393b 100644 --- a/src/lematerial_forgebench/metrics/stability_metrics.py +++ b/src/lematerial_forgebench/metrics/stability_metrics.py @@ -1,20 +1,14 @@ """Relaxation metrics for evaluating material structures. -This module implements metrics for evaluating the relaxation of -material structures using various relaxation models and calculating -energy above hull. +This module implements metrics for evaluating structure stability and metastability. """ -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any import numpy as np -from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram from pymatgen.core import Structure -from pymatgen.entries.compatibility import MaterialsProject2020Compatibility -from lematerial_forgebench.metrics.base import BaseMetric, MetricConfig +from lematerial_forgebench.metrics.base import BaseMetric from lematerial_forgebench.utils.logging import logger diff --git a/src/lematerial_forgebench/preprocess/base.py b/src/lematerial_forgebench/preprocess/base.py index 629432fb..de35399d 100644 --- a/src/lematerial_forgebench/preprocess/base.py +++ b/src/lematerial_forgebench/preprocess/base.py @@ -5,7 +5,6 @@ from pathlib import Path from typing import Any, List, Optional, TypeVar -import numpy as np import pandas as pd from pymatgen.core.structure import Structure @@ -236,6 +235,7 @@ def _split_into_batches( structures[i : i + batch_size] for i in range(0, len(structures), batch_size) ] + return batches def _get_process_attributes(self) -> dict[str, Any]: """Get additional attributes/arguments for the process_structure method. 
diff --git a/src/lematerial_forgebench/preprocess/stability_preprocess.py b/src/lematerial_forgebench/preprocess/stability_preprocess.py index e0e1f33f..6facbaa0 100644 --- a/src/lematerial_forgebench/preprocess/stability_preprocess.py +++ b/src/lematerial_forgebench/preprocess/stability_preprocess.py @@ -9,10 +9,8 @@ from pathlib import Path from typing import Any, Dict, Optional -import numpy as np from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram from pymatgen.core import Structure -from pymatgen.entries.compatibility import MaterialsProject2020Compatibility from lematerial_forgebench.preprocess.base import BasePreprocessor, PreprocessorConfig from lematerial_forgebench.utils.e_above_hull import (