diff --git a/.gitignore b/.gitignore index a4c6f414..6cc79689 100644 --- a/.gitignore +++ b/.gitignore @@ -170,3 +170,5 @@ cython_debug/ lightning_logs/ src/lematerial_forgebench/.DS_Store +# Ignore .gz +*.gz \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index e56e6f7f..6fee2720 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,3 +57,4 @@ material-hasher = { git = "https://github.com/lematerial/material-hasher.git" } [tool.ruff.lint] extend-select = ["I"] +ignore = ["F401"] \ No newline at end of file diff --git a/src/lematerial_forgebench/benchmarks/stability_benchmark.py b/src/lematerial_forgebench/benchmarks/stability_benchmark.py new file mode 100644 index 00000000..d0d8de61 --- /dev/null +++ b/src/lematerial_forgebench/benchmarks/stability_benchmark.py @@ -0,0 +1,147 @@ +"""Stability benchmark for material structures. + +This module implements a benchmark that evaluates the stability of +generated material structures using various relaxation methods. +""" + +from typing import Any, Dict + +from lematerial_forgebench.benchmarks.base import BaseBenchmark +from lematerial_forgebench.evaluator import EvaluationResult, EvaluatorConfig +from lematerial_forgebench.metrics.stability_metrics import ( + MetastabilityMetric, + StabilityMetric, +) + + +class StabilityBenchmark(BaseBenchmark): + """Benchmark for evaluating the stability of generated material structures.""" + + def __init__( + self, + name: str = "StabilityBenchmark", + description: str | None = None, + metadata: Dict[str, Any] | None = None, + ): + """Initialize the stability benchmark. + + Parameters + ---------- + relaxer_type : str, default="orb" + Type of relaxer to use (e.g., "orb", "ocp"). + relaxer_config : dict, optional + Configuration for the relaxer. If None, uses default config. + mp_entries_file : str + Path to the Materials Project entries file. + name : str + Name of the benchmark. + description : str, optional + Description of the benchmark. + metadata : dict, optional + Additional metadata for the benchmark. + """ + if description is None: + description = ( + "Evaluates the stability and metastability of crystal structures" + ) + + # Initialize the stability metric + stability_metric = StabilityMetric() + + # Set up evaluator configs + evaluator_configs = { + "stability": EvaluatorConfig( + name="Stability", + description="Evaluates structure stability", + metrics={"stability": stability_metric}, + weights={"stability": 1.0}, + aggregation_method="weighted_mean", + ), + } + + # Add metastability evaluator if requested + metastability_metric = MetastabilityMetric() + evaluator_configs["metastability"] = EvaluatorConfig( + name="Metastability Analysis", + description="Evaluates metastability from precomputed e_above_hull values", + metrics={"metastability": metastability_metric}, + weights={"metastability": 1.0}, + aggregation_method="weighted_mean", + ) + + # Create benchmark metadata + benchmark_metadata = { + "version": "0.1.0", + "category": "stability", + **(metadata or {}), + } + + super().__init__( + name=name, + description=description, + evaluator_configs=evaluator_configs, + metadata=benchmark_metadata, + ) + + def aggregate_evaluator_results( + self, evaluator_results: Dict[str, EvaluationResult] + ) -> Dict[str, float]: + """Aggregate results from multiple evaluators into final scores. + + Parameters + ---------- + evaluator_results : dict[str, EvaluationResult] + Results from each evaluator. + + Returns + ------- + dict[str, float] + Final aggregated scores. + """ + import math + + def safe_float(value, default=0.0): + """Safely convert value to float, handling None and NaN.""" + if value is None: + return default + try: + float_val = float(value) + if math.isnan(float_val): + return default + return float_val + except (TypeError, ValueError): + return default + + final_scores = { + "stable_ratio": 0.0, + "metastable_ratio": 0.0, + "mean_e_above_hull": 0.0, + } + + # Extract stability results + stability_results = evaluator_results.get("stability") + if stability_results: + # Main stability ratio + final_scores["stable_ratio"] = safe_float( + stability_results.get("combined_value") + ) + + # Extract individual metrics from stability metric + stability_metric_results = stability_results.get("metric_results", {}).get( + "stability", {} + ) + stability_metrics = stability_metric_results.get("metrics", {}) + + final_scores["mean_e_above_hull"] = safe_float( + stability_metrics.get("mean_e_above_hull") + ) + + # Extract metastability results if available + metastability_results = evaluator_results.get("metastability") + if metastability_results: + # Main metastability score + final_scores["metastable_ratio"] = safe_float( + metastability_results.get("combined_value") + ) + + return final_scores diff --git a/src/lematerial_forgebench/cli.py b/src/lematerial_forgebench/cli.py index 7876c2c3..63f1393a 100644 --- a/src/lematerial_forgebench/cli.py +++ b/src/lematerial_forgebench/cli.py @@ -11,6 +11,7 @@ import yaml from lematerial_forgebench.benchmarks.example import ExampleBenchmark +from lematerial_forgebench.benchmarks.stability_benchmark import StabilityBenchmark from lematerial_forgebench.benchmarks.validity_benchmark import ValidityBenchmark from lematerial_forgebench.data.structure import format_structures from lematerial_forgebench.metrics.validity_metrics import ( @@ -19,6 +20,9 @@ MinimumInteratomicDistanceMetric, PhysicalPlausibilityMetric, ) +from lematerial_forgebench.preprocess.stability_preprocess import ( + StabilityPreprocessor, +) from lematerial_forgebench.utils.logging import logger CONFIGS_DIR = Path(__file__).parent.parent / "config" @@ -169,13 +173,9 @@ def main(input: str, config_name: str, output: str): coord_tolerance = coord_config.get("tolerance", 0.2) # Create custom metrics with configuration - ChargeNeutralityMetric( - tolerance=charge_tolerance, strict=charge_strict - ) + ChargeNeutralityMetric(tolerance=charge_tolerance, strict=charge_strict) - MinimumInteratomicDistanceMetric( - scaling_factor=distance_scaling - ) + MinimumInteratomicDistanceMetric(scaling_factor=distance_scaling) CoordinationEnvironmentMetric( nn_method=coord_nn_method, tolerance=coord_tolerance @@ -196,6 +196,14 @@ def main(input: str, config_name: str, output: str): "metric_configs": metric_configs, }, ) + elif benchmark_type == "stability": + # before running the benchmark, we need to preprocess the structures + stability_preprocessor = StabilityPreprocessor() + # Use the preprocessor to process structures + preprocessor_result = stability_preprocessor(structures) + structures = preprocessor_result.processed_structures + + benchmark = StabilityBenchmark() else: raise ValueError(f"Unknown benchmark type: {benchmark_type}") diff --git a/src/lematerial_forgebench/metrics/stability_metrics.py b/src/lematerial_forgebench/metrics/stability_metrics.py new file mode 100644 index 00000000..678e393b --- /dev/null +++ b/src/lematerial_forgebench/metrics/stability_metrics.py @@ -0,0 +1,203 @@ +"""Relaxation metrics for evaluating material structures. + +This module implements metrics for evaluating structure stability and metastability. +""" + +from typing import Any + +import numpy as np +from pymatgen.core import Structure + +from lematerial_forgebench.metrics.base import BaseMetric +from lematerial_forgebench.utils.logging import logger + + +class StabilityMetric(BaseMetric): + """Evaluate structure metastability using precomputed e_above_hull values. + + This metric assumes that e_above_hull values have already been computed + and stored in structure.properties['e_above_hull']. It calculates stability + statistics without performing any relaxation or recomputation. + """ + + def __init__( + self, + name: str | None = None, + description: str | None = None, + lower_is_better: bool = False, + n_jobs: int = 1, + ): + super().__init__( + name=name or "StabilityMetric", + description=description + or "Evaluates structure stability from precomputed e_above_hull", + lower_is_better=lower_is_better, + n_jobs=n_jobs, + ) + + def _get_compute_attributes(self) -> dict[str, Any]: + """Get the attributes for the compute_structure method.""" + return {} + + @staticmethod + def compute_structure(structure: Structure) -> float: + """Extract precomputed e_above_hull from structure properties. + + Parameters + ---------- + structure : Structure + A pymatgen Structure object with e_above_hull in properties. + + Returns + ------- + float + The precomputed e_above_hull value, or NaN if not available. + """ + try: + # Extract e_above_hull from structure properties + e_above_hull = structure.properties.get("e_above_hull", None) + + if e_above_hull is None: + logger.warning( + "Structure missing e_above_hull in properties, please compute it first using StabilityMetric" + ) + return np.nan + + return float(e_above_hull) + + except Exception as e: + logger.error(f"Failed to extract e_above_hull: {str(e)}") + return np.nan + + def aggregate_results(self, values: list[float]) -> dict[str, Any]: + """Aggregate results into final metric values. + + Parameters + ---------- + values : list[float] + List of e_above_hull values for each structure. + + Returns + ------- + dict + Dictionary with aggregated metrics. + """ + # Convert to numpy array for efficient operations + values_array = np.array(values) + + # Filter out NaN values + valid_mask = ~np.isnan(values_array) + e_above_hull_values = values_array[valid_mask] + + if len(e_above_hull_values) > 0: + # Calculate ratio of stable structures (e_above_hull <= 0) using numpy + stable_ratio = np.sum(e_above_hull_values <= 0) / len(values) + e_above_hull_std = np.std(e_above_hull_values) + mean_e_above_hull = np.mean(e_above_hull_values) + else: + stable_ratio = 0.0 + e_above_hull_std = 0.0 + mean_e_above_hull = 0.0 + + return { + "metrics": { + "stable_ratio": stable_ratio, + "mean_e_above_hull": mean_e_above_hull, + }, + "primary_metric": "stable_ratio", + "uncertainties": { + "e_above_hull_std": {"std": e_above_hull_std}, + }, + } + + +class MetastabilityMetric(BaseMetric): + """Evaluate structure metastability using precomputed e_above_hull values. + + This metric assumes that e_above_hull values have already been computed + and stored in structure.properties['e_above_hull']. It calculates stability + statistics without performing any relaxation or recomputation. + """ + + def __init__( + self, + name: str | None = None, + description: str | None = None, + lower_is_better: bool = True, + n_jobs: int = 1, + ): + super().__init__( + name=name or "MetastabilityMetric", + description=description + or "Evaluates structure metastability from precomputed e_above_hull", + lower_is_better=lower_is_better, + n_jobs=n_jobs, + ) + + def _get_compute_attributes(self) -> dict[str, Any]: + """Get the attributes for the compute_structure method.""" + return {} + + @staticmethod + def compute_structure(structure: Structure) -> float: + """Extract precomputed e_above_hull from structure properties. + + Parameters + ---------- + structure : Structure + A pymatgen Structure object with e_above_hull in properties. + + Returns + ------- + float + The precomputed e_above_hull value, or NaN if not available. + """ + try: + # Extract e_above_hull from structure properties + e_above_hull = structure.properties.get("e_above_hull", None) + + if e_above_hull is None: + logger.warning( + f"Structure `{structure.formula}` missing e_above_hull in properties, please compute it first using StabilityPreprocessor" + ) + return np.nan + + return float(e_above_hull) + + except Exception as e: + logger.error(f"Failed to extract e_above_hull: {str(e)}") + return np.nan + + def aggregate_results(self, values: list[float]) -> dict[str, Any]: + """Aggregate results into final metric values. + + Parameters + ---------- + values : list[float] + List of e_above_hull values for each structure. + + Returns + ------- + dict + Dictionary with aggregated metrics. + """ + # Convert to numpy array for efficient operations + values_array = np.array(values) + + # Filter out NaN values + valid_mask = ~np.isnan(values_array) + e_above_hull_values = values_array[valid_mask] + + if len(e_above_hull_values) > 0: + # Calculate ratio of metastable structures (e_above_hull <= 0.1) using numpy + metastable_ratio = np.sum(e_above_hull_values <= 0.1) / len(values) + else: + metastable_ratio = 0.0 + + return { + "metrics": { + "metastable_ratio": metastable_ratio, + }, + "primary_metric": "metastable_ratio", + "uncertainties": {}, + } diff --git a/src/lematerial_forgebench/preprocess/__init__.py b/src/lematerial_forgebench/preprocess/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/lematerial_forgebench/preprocess/base.py b/src/lematerial_forgebench/preprocess/base.py new file mode 100644 index 00000000..de35399d --- /dev/null +++ b/src/lematerial_forgebench/preprocess/base.py @@ -0,0 +1,407 @@ +import time +from abc import ABC, abstractmethod +from concurrent.futures import ProcessPoolExecutor +from dataclasses import dataclass +from pathlib import Path +from typing import Any, List, Optional, TypeVar + +import pandas as pd +from pymatgen.core.structure import Structure + +from lematerial_forgebench.data.structure import format_structures +from lematerial_forgebench.utils.logging import logger + +PreprocessorClassVar = TypeVar("PreprocessorClassVar", bound="BasePreprocessor") + + +@dataclass +class PreprocessorConfig: + """Base configuration for all preprocessors. + + This class defines the common configuration parameters shared by all preprocessors. + Specific preprocessors should inherit from this class and add their own parameters. + + Parameters + ---------- + name : str, optional + Custom name for the preprocessor. If None, the class name will be used. + description : str, optional + Description of what the preprocessor does. + n_jobs : int, default=1 + Number of parallel jobs to run. + """ + + name: str | None = None + description: str | None = None + n_jobs: int = 1 + + def to_dict(self) -> dict[str, Any]: + """Convert the preprocessor configuration to a dictionary for serialization. + + Returns + ------- + dict[str, Any] + """ + return { + "name": self.name, + "description": self.description, + "n_jobs": self.n_jobs, + } + + +@dataclass +class PreprocessorResult: + """Result of a preprocessing computation. + + Parameters + ---------- + processed_structures : list[Structure] + List of successfully processed pymatgen Structure objects. + config : PreprocessorConfig + The configuration used for this preprocessing task. + computation_time : float + The time taken to complete the preprocessing. + n_input_structures : int + The total number of structures provided as input. + failed_indices : list[int] + The indices (from the original input list) of structures that failed processing. + warnings : list[str] + The warnings generated during the computation. + """ + + processed_structures: list[Structure] + config: PreprocessorConfig + computation_time: float + n_input_structures: int + failed_indices: list[int] + warnings: list[str] + + def __post_init__(self): + """Validate the preprocessor result.""" + if self.n_input_structures < ( + len(self.processed_structures) + len(self.failed_indices) + ): + # This check might be too strict if one input can result in multiple outputs or some other complex scenario + # For now, it assumes a one-to-one mapping or failure. + logger.warning( + "Number of input structures is less than processed + failed. This might indicate an issue." + ) + if self.config is None: + raise ValueError("config cannot be None") + + +class BasePreprocessor(ABC): + """Base class for all preprocessors used in materials science workflows. + + This class defines the interface for preprocessors and provides common functionality + like parallelization. + + Parameters + ---------- + name : str, optional + Custom name for the preprocessor. If None, the class name will be used. + description : str, optional + Description of what the preprocessor does. + n_jobs : int, default=1 + Number of parallel jobs to run. + """ + + def __init__( + self, + name: str | None = None, + description: str | None = None, + n_jobs: int = 1, + ): + self.config = PreprocessorConfig( + name=name or self.__class__.__name__, + description=description, + n_jobs=n_jobs, + ) + + @property + def name(self) -> str: + """Get the name of the preprocessor. + + Returns + ------- + str + """ + return self.config.name + + @property + def description(self) -> str: + """Get the description of the preprocessor. + + Returns + ------- + str + """ + return self.config.description or "No description provided." + + @staticmethod + @abstractmethod + def process_structure(structure: Structure, **process_args: Any) -> Structure: + """Process a single structure. + + This is the main method to implement in derived classes. + It should take a pymatgen Structure object, perform operations on it + (e.g., add properties, modify it), and return the processed Structure object. + If processing fails, it should raise an exception. + + Parameters + ---------- + structure : Structure + A pymatgen Structure object to process. + **process_args : Any + Additional keyword arguments that depend on the preprocessor implementation. + + Returns + ------- + Structure + The processed (e.g., annotated or modified) pymatgen Structure object. + + Raises + ------ + Exception + If processing of the structure fails. + """ + pass + + @staticmethod + def _process_batch( + structures: list[Structure], + process_args: dict[str, Any], + preprocessor_class: PreprocessorClassVar, + ) -> tuple[List[Optional[Structure]], List[int], List[str]]: + """Process a batch of structures. + + Parameters + ---------- + structures : list[Structure] + Batch of structures to process. + process_args : dict[str, Any] + Additional keyword arguments for the process_structure method. + preprocessor_class : PreprocessorClassVar + The preprocessor class to use for the computation. + + Returns + ------- + tuple[List[Optional[Structure]], List[int], List[str]] + Tuple containing: + - List of processed Structure objects or None for failures. + - List of indices (local to batch) of structures that failed. + - List of warning messages for failures. + """ + batch_results: List[Optional[Structure]] = [] + failed_indices_in_batch: List[int] = [] + warnings_for_batch: List[str] = [] + + for idx, structure in enumerate(structures): + try: + processed_structure = preprocessor_class.process_structure( + structure, **process_args + ) + batch_results.append(processed_structure) + except Exception as e: + batch_results.append(None) + failed_indices_in_batch.append(idx) + warnings_for_batch.append( + f"Failed to process structure at batch index {idx} (original may vary): {str(e)}" + ) + logger.debug( + f"Failed to process structure in batch for {preprocessor_class.name}", + exc_info=True, + ) + return batch_results, failed_indices_in_batch, warnings_for_batch + + def _split_into_batches( + self, structures: list[Structure], batch_size: int + ) -> list[list[Structure]]: + """Split structures into batches. + + Parameters + ---------- + structures : list[Structure] + List of structures to split. + batch_size : int + Size of each batch. + + Returns + ------- + list[list[Structure]] + List of batches of structures. + """ + batches = [ + structures[i : i + batch_size] + for i in range(0, len(structures), batch_size) + ] + return batches + + def _get_process_attributes(self) -> dict[str, Any]: + """Get additional attributes/arguments for the process_structure method. + Subclasses can override this to pass dynamic configuration. + + Returns + ------- + dict[str, Any] + """ + return {} + + def run( + self, + structures: list[Structure], + ) -> PreprocessorResult: + """Run the preprocessing on a list of structures. + + Parameters + ---------- + structures : list[Structure] + List of pymatgen Structure objects to process. + + Returns + ------- + PreprocessorResult + Object containing the processed structures and computation metadata. + """ + start_time = time.time() + n_input = len(structures) + + # This list will hold results in the original order, with None for failures. + all_results_with_nones: List[Optional[Structure]] = [None] * n_input + global_failed_indices: List[int] = [] + global_warnings: List[str] = [] + + process_args = self._get_process_attributes() + + try: + if ( + self.config.n_jobs <= 1 or n_input <= 1 + ): # Also run serially for single structure + # Serial computation + for idx, structure in enumerate(structures): + try: + processed_structure = self.process_structure( + structure, **process_args + ) + all_results_with_nones[idx] = processed_structure + except Exception as e: + global_failed_indices.append(idx) + global_warnings.append( + f"Failed to process structure {idx}: {str(e)}" + ) + logger.warning( + f"Failed to process structure {idx} for {self.name}", + exc_info=True, + ) + else: + # Parallel computation + # Ensure batch_size is at least 1, and doesn't create more batches than structures or jobs + num_workers = min(self.config.n_jobs, n_input) + batch_size = max( + 1, (n_input + num_workers - 1) // num_workers + ) # Distribute as evenly as possible + + batches = self._split_into_batches(structures, batch_size) + + with ProcessPoolExecutor(max_workers=num_workers) as executor: + futures = [ + executor.submit( + self._process_batch, batch, process_args, self.__class__ + ) + for batch in batches + ] + + current_original_idx_offset = 0 + for future in futures: + ( + batch_processed_structures_or_nones, + failed_indices_in_batch, + warnings_for_batch, + ) = future.result() + + for i, struct_or_none in enumerate( + batch_processed_structures_or_nones + ): + original_idx = current_original_idx_offset + i + if original_idx < n_input: # Boundary check + all_results_with_nones[original_idx] = struct_or_none + + global_failed_indices.extend( + [ + current_original_idx_offset + i + for i in failed_indices_in_batch + ] + ) + global_warnings.extend(warnings_for_batch) + current_original_idx_offset += len( + batch_processed_structures_or_nones + ) + + # Filter out Nones to get successfully processed structures + final_processed_structures = [ + s for s in all_results_with_nones if s is not None + ] + + except Exception as e: + logger.error(f"Global failure in preprocessor {self.name}", exc_info=True) + return PreprocessorResult( + processed_structures=[], + config=self.config, + computation_time=time.time() - start_time, + n_input_structures=n_input, + failed_indices=list(range(n_input)), # All failed + warnings=[f"Global preprocessing failure for {self.name}: {str(e)}"] + * n_input, + ) + + return PreprocessorResult( + processed_structures=final_processed_structures, + config=self.config, + computation_time=time.time() - start_time, + n_input_structures=n_input, + failed_indices=sorted( + list(set(global_failed_indices)) + ), # Ensure unique and sorted + warnings=global_warnings, + ) + + def __call__( + self, + structures: list[Structure] | list[dict] | pd.DataFrame | str | Path, + ) -> PreprocessorResult: + """Convenient callable interface for running the preprocessor. + + Parameters + ---------- + structures : list[Structure] | list[dict] | pd.DataFrame | str | Path + Structures to process, in various supported formats. + + Returns + ------- + PreprocessorResult + Object containing the processed structures and computation metadata. + """ + structures_list = format_structures(structures) + if not isinstance( + structures_list, list + ): # Ensure it's a list for consistent processing + structures_list = list(structures_list) + return self.run(structures_list) + + @classmethod + def from_config(cls, config: PreprocessorConfig) -> PreprocessorClassVar: + """Create a preprocessor from a configuration. + + Parameters + ---------- + config : PreprocessorConfig + Configuration for the preprocessor. + """ + # Extract only relevant args for BasePreprocessor constructor + # Subclasses might need to handle additional args from their specific configs + base_args = { + k: v + for k, v in config.to_dict().items() + if k in ["name", "description", "n_jobs"] + } + return cls(**base_args) diff --git a/src/lematerial_forgebench/preprocess/stability_preprocess.py b/src/lematerial_forgebench/preprocess/stability_preprocess.py new file mode 100644 index 00000000..6facbaa0 --- /dev/null +++ b/src/lematerial_forgebench/preprocess/stability_preprocess.py @@ -0,0 +1,194 @@ +"""Relaxation metrics for evaluating material structures. + +This module implements metrics for evaluating the relaxation of +material structures using various relaxation models and calculating +energy above hull. +""" + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, Optional + +from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram +from pymatgen.core import Structure + +from lematerial_forgebench.preprocess.base import BasePreprocessor, PreprocessorConfig +from lematerial_forgebench.utils.e_above_hull import ( + generate_CSE, + get_patched_phase_diagram_mp, +) +from lematerial_forgebench.utils.logging import logger +from lematerial_forgebench.utils.relaxers import ( + BaseRelaxer, + get_relaxer, + relaxers, +) + + +@dataclass +class StabilityPreprocessorConfig(PreprocessorConfig): + """Configuration for the StabilityPreprocessor. + + Parameters + ---------- + relaxer_type : str + Type of relaxer to use (e.g., "chgnet", "eqv2", "esen"). + relaxer_config : dict + Configuration for the specific relaxer. + mp_entries_file : str, optional + Path to the Materials Project entries file for e_above_hull calculation. + """ + + relaxer_type: str = "orb" + relaxer_config: Dict[str, Any] = field(default_factory=dict) + mp_entries_file: Optional[str] = None + + +class StabilityPreprocessor(BasePreprocessor): + """Evaluate structure relaxation and energy above hull. + + This metric handles both the relaxation of structures using various models + and the calculation of energy above hull using the Materials Project database. + + Parameters + ---------- + relaxer_type : str + Type of relaxer to use. + relaxer_config : dict + Configuration for the specific relaxer. + mp_entries_file : str, optional + Path to the Materials Project entries file. + name : str, optional + Custom name for the metric. + description : str, optional + Description of what the metric measures. + lower_is_better : bool, default=True + Lower energies indicate better stability. + n_jobs : int, default=1 + Number of parallel jobs to run. + """ + + def __init__( + self, + relaxer_type: str = "orb", + relaxer_config: Dict[str, Any] = {"fmax": 0.02, "steps": 500}, + mp_entries_file: Optional[ + str + ] = "src/lematerial_forgebench/utils/relaxers/2023-02-07-ppd-mp.pkl.gz", + name: str | None = None, + description: str | None = None, + n_jobs: int = 1, + ): + super().__init__( + name=name or "StabilityPreprocessor", + description=description or "Preprocesses structures for stability analysis", + n_jobs=n_jobs, + ) + self.config = StabilityPreprocessorConfig( + name=self.config.name, + description=self.config.description, + n_jobs=self.config.n_jobs, + relaxer_type=relaxer_type, + relaxer_config=relaxer_config, + mp_entries_file=mp_entries_file, + ) + + # Initialize the relaxer + self.relaxer = get_relaxer(relaxer_type, **relaxer_config) + + # Load MP entries if file provided + self.mp_entries = None + # check if path exists + if mp_entries_file and Path(mp_entries_file).exists(): + self.load_mp_entries(mp_entries_file) + + def load_mp_entries(self, mp_entries_file: str) -> None: + """Load Materials Project entries for e_above_hull calculation. + + Parameters + ---------- + mp_entries_file : str + Path to the MP entries file. + """ + self.mp_entries = get_patched_phase_diagram_mp(Path(mp_entries_file)) + # import pandas as pd + + # try: + # df = pd.read_json(mp_entries_file) + # df = df[df["entry"].apply(lambda x: "GGA" in x["entry_id"])] + + # mp_computed_entries = [ + # ComputedEntry.from_dict(x) + # for x in df.entry + # if "GGA" in x["parameters"]["run_type"] + # ] + # self.mp_entries = [ + # entry + # for entry in mp_computed_entries + # if not np.any(["R2SCAN" in a.name for a in entry.energy_adjustments]) + # ] + # logger.info(f"Loaded {len(self.mp_entries)} MP entries") + # except Exception as e: + # logger.error(f"Failed to load MP entries: {str(e)}") + # self.mp_entries = None + + def _get_process_attributes(self) -> dict[str, Any]: + """Get the attributes for the process_structure method.""" + return { + "relaxer": self.relaxer, + "mp_entries": self.mp_entries, + } + + @staticmethod + def process_structure( + structure: Structure, + relaxer: BaseRelaxer, + mp_entries: Optional[PatchedPhaseDiagram] = None, + ) -> Structure: + """Process a single structure by relaxing it and computing e_above_hull. + + Parameters + ---------- + structure : Structure + A pymatgen Structure object to process. + relaxer : BaseRelaxer + Relaxer object to use. + mp_entries : PatchedPhaseDiagram, optional + MP entries for e_above_hull calculation. + + Returns + ------- + Structure + The processed Structure with relaxed geometry and e_above_hull in properties. + + Raises + ------ + Exception + If relaxation fails or other processing errors occur. + """ + # Relax structure + relaxation_result = relaxer.relax(structure, relax=False) + if not relaxation_result.success: + raise RuntimeError(f"Relaxation failed: {relaxation_result.message}") + + processed_structure = relaxation_result.structure + + # Calculate e_above_hull if MP entries are available + if mp_entries is not None: + try: + cse = generate_CSE(processed_structure, relaxation_result.energy) + e_above_hull = mp_entries.get_e_above_hull(cse, allow_negative=True) + processed_structure.properties["e_above_hull"] = e_above_hull + logger.debug( + f"Computed e_above_hull: {e_above_hull:.3f} eV/atom for {processed_structure.formula}" + ) + except Exception as e: + logger.warning( + f"Failed to compute e_above_hull for {processed_structure.formula}: {str(e)}" + ) + # Still return the relaxed structure even if e_above_hull calculation fails + + # Store additional processing metadata + processed_structure.properties["relaxed_energy"] = relaxation_result.energy + + return processed_structure diff --git a/src/lematerial_forgebench/utils/e_above_hull.py b/src/lematerial_forgebench/utils/e_above_hull.py new file mode 100644 index 00000000..da207bce --- /dev/null +++ b/src/lematerial_forgebench/utils/e_above_hull.py @@ -0,0 +1,65 @@ +"""Util functions for e_above_hull calculation.""" + +import gzip +import pickle +import tempfile +from pathlib import Path + +from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram +from pymatgen.core import Structure +from pymatgen.entries.compatibility import MaterialsProject2020Compatibility +from pymatgen.entries.computed_entries import ComputedStructureEntry +from pymatgen.io.vasp.inputs import Incar, Poscar +from pymatgen.io.vasp.sets import MPRelaxSet + + +def get_patched_phase_diagram_mp(path: Path) -> PatchedPhaseDiagram: + # Check if the file has a .gz extension + if path.suffix == ".gz": + # Open as gzip file + with gzip.open(path, "rb") as f: + ppd_mp = pickle.load(f) + else: + # Open as regular file + with open(path, "rb") as f: + ppd_mp = pickle.load(f) + return ppd_mp + + +def generate_CSE(structure, energy): + # Write VASP inputs files as if we were going to do a standard MP run + # this is mainly necessary to get the right U values / etc + b = MPRelaxSet(structure) + with tempfile.TemporaryDirectory() as tmpdirname: + b.write_input(f"{tmpdirname}/", potcar_spec=True) + poscar = Poscar.from_file(f"{tmpdirname}/POSCAR") + incar = Incar.from_file(f"{tmpdirname}/INCAR") + clean_structure = Structure.from_file(f"{tmpdirname}/POSCAR") + + # Get the U values and figure out if we should have run a GGA+U calc + param = {"hubbards": {}} + if "LDAUU" in incar: + param["hubbards"] = dict(zip(poscar.site_symbols, incar["LDAUU"])) + param["is_hubbard"] = ( + incar.get("LDAU", True) and sum(param["hubbards"].values()) > 0 + ) + if param["is_hubbard"]: + param["run_type"] = "GGA+U" + + # Make a ComputedStructureEntry without the correction + cse_d = { + "structure": clean_structure, + "energy": energy, + "correction": 0.0, + "parameters": param, + } + + # Apply the MP 2020 correction scheme (anion/+U/etc) + cse = ComputedStructureEntry.from_dict(cse_d) + _ = MaterialsProject2020Compatibility(check_potcar=False).process_entries( + cse, + clean=True, + ) + + # Return the final CSE (notice that the composition/etc is also clean, not things like Fe3+)! + return cse diff --git a/src/lematerial_forgebench/utils/relaxers/__init__.py b/src/lematerial_forgebench/utils/relaxers/__init__.py new file mode 100644 index 00000000..00e43d6e --- /dev/null +++ b/src/lematerial_forgebench/utils/relaxers/__init__.py @@ -0,0 +1,16 @@ +"""Relaxer implementations package. + +This module ensures all relaxer implementations are imported and registered +when the package is used. +""" + +# Import the registry functions +# Import all implementations to trigger registration +from .registry import BaseRelaxer, RelaxationResult, get_relaxer, register_relaxer + +__all__ = [ + "BaseRelaxer", + "RelaxationResult", + "get_relaxer", + "register_relaxer", +] diff --git a/src/lematerial_forgebench/utils/relaxers/registry.py b/src/lematerial_forgebench/utils/relaxers/registry.py new file mode 100644 index 00000000..091091cd --- /dev/null +++ b/src/lematerial_forgebench/utils/relaxers/registry.py @@ -0,0 +1,128 @@ +"""Registry for relaxer implementations and base classes.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, Type + +from pymatgen.core import Structure +from pymatgen.entries.computed_entries import ComputedStructureEntry + + +@dataclass +class RelaxationResult: + """Result of a structure relaxation. + + Parameters + ---------- + success : bool + Whether the relaxation was successful. + energy : float | None + Final energy of the relaxed structure. + structure : Structure | None + Final relaxed structure. + message : str | None + Error message if relaxation failed. + """ + + success: bool + energy: float | None = None + structure: Structure | None = None + message: str | None = None + + +class BaseRelaxer(ABC): + """Base class for structure relaxation implementations. + + All relaxer implementations should inherit from this class and implement + the required methods. + """ + + @abstractmethod + def relax(self, structure: Structure, relax: bool = False) -> RelaxationResult: + """Relax a structure and return the result. + + Parameters + ---------- + structure : Structure + Structure to relax. + + Returns + ------- + RelaxationResult + Result of the relaxation. + """ + pass + + @abstractmethod + def get_computed_entry( + self, structure: Structure, energy: float + ) -> ComputedStructureEntry: + """Create a ComputedStructureEntry from a relaxed structure. + + Parameters + ---------- + structure : Structure + The relaxed structure. + energy : float + The energy of the relaxed structure. + + Returns + ------- + ComputedStructureEntry + The computed structure entry with appropriate corrections applied. + """ + pass + + +_RELAXER_REGISTRY: Dict[str, Type[BaseRelaxer]] = {} + + +def register_relaxer(name: str): + """Register a relaxer implementation. + + Parameters + ---------- + name : str + Name of the relaxer. + + Returns + ------- + callable + Decorator function. + """ + + def decorator(cls: Type[BaseRelaxer]) -> Type[BaseRelaxer]: + if not issubclass(cls, BaseRelaxer): + raise TypeError(f"{cls.__name__} must inherit from BaseRelaxer") + _RELAXER_REGISTRY[name] = cls + return cls + + return decorator + + +def get_relaxer(relaxer_type: str, **kwargs) -> BaseRelaxer: + """Get a relaxer implementation by name. + + Parameters + ---------- + relaxer_type : str + Name of the relaxer. + **kwargs + Additional arguments to pass to the relaxer constructor. + + Returns + ------- + BaseRelaxer + Relaxer instance. + + Raises + ------ + ValueError + If relaxer_type is not registered. + """ + if relaxer_type not in _RELAXER_REGISTRY: + raise ValueError( + f"Unknown relaxer type: {relaxer_type}. " + f"Available types: {list(_RELAXER_REGISTRY.keys())}" + ) + return _RELAXER_REGISTRY[relaxer_type](**kwargs) diff --git a/src/lematerial_forgebench/utils/relaxers/relaxers.py b/src/lematerial_forgebench/utils/relaxers/relaxers.py new file mode 100644 index 00000000..142b0b45 --- /dev/null +++ b/src/lematerial_forgebench/utils/relaxers/relaxers.py @@ -0,0 +1,208 @@ +"""Implementations for structure relaxation.""" + +import gc +import tempfile +from abc import abstractmethod + +import torch +from ase.filters import FrechetCellFilter +from ase.optimize import FIRE + +# from fairchem.core import OCPCalculator +from orb_models.forcefield import pretrained +from orb_models.forcefield.calculator import ORBCalculator +from pymatgen.core import Structure +from pymatgen.entries.compatibility import MaterialsProject2020Compatibility +from pymatgen.entries.computed_entries import ComputedStructureEntry +from pymatgen.io.ase import AseAtomsAdaptor +from pymatgen.io.vasp.inputs import Incar, Poscar +from pymatgen.io.vasp.sets import MPRelaxSet + +from lematerial_forgebench.utils.logging import logger +from lematerial_forgebench.utils.relaxers.registry import ( + BaseRelaxer, + RelaxationResult, + register_relaxer, +) + + +class BaseVASPRelaxer(BaseRelaxer): + """Base class for relaxers that use VASP-like parameters.""" + + def get_computed_entry( + self, structure: Structure, energy: float + ) -> ComputedStructureEntry: + """Create a ComputedStructureEntry from a relaxed structure. + + Parameters + ---------- + structure : Structure + The relaxed structure. + energy : float + The energy of the relaxed structure. + + Returns + ------- + ComputedStructureEntry + The computed structure entry with MP2020 corrections applied. + """ + with tempfile.TemporaryDirectory() as tmpdirname: + b = MPRelaxSet(structure) + b.write_input(f"{tmpdirname}/", potcar_spec=True) + poscar = Poscar.from_file(f"{tmpdirname}/POSCAR") + incar = Incar.from_file(f"{tmpdirname}/INCAR") + clean_structure = Structure.from_file(f"{tmpdirname}/POSCAR") + + # Get the U values and figure out if we should have run a GGA+U calc + param = {"hubbards": {}} + if "LDAUU" in incar: + param["hubbards"] = dict(zip(poscar.site_symbols, incar["LDAUU"])) + param["is_hubbard"] = ( + incar.get("LDAU", True) and sum(param["hubbards"].values()) > 0 + ) + if param["is_hubbard"]: + param["run_type"] = "GGA+U" + + # Make a ComputedStructureEntry without the correction + cse_d = { + "structure": clean_structure, + "energy": energy, + "correction": 0.0, + "parameters": param, + } + + # Apply the MP 2020 correction scheme (anion/+U/etc) + cse = ComputedStructureEntry.from_dict(cse_d) + _ = MaterialsProject2020Compatibility(check_potcar=False).process_entries( + cse, + clean=True, + ) + + return cse + + +class ASERelaxerBase(BaseVASPRelaxer): + """Base class for ASE-based relaxers.""" + + def __init__( + self, + fmax: float = 0.02, + steps: int = 500, + cpu: bool = True, + **params, + ): + """Initialize ASE relaxer. + + Parameters + ---------- + fmax : float, default=0.02 + Maximum force convergence criterion. + steps : int, default=500 + Maximum number of optimization steps. + cpu : bool, default=False + Whether to use CPU instead of GPU. + **params + Additional parameters specific to each relaxer type. + """ + self.fmax = fmax + self.steps = steps + self.cpu = cpu + self.params = params # Store additional parameters for subclasses + # Abstract calculator setup - must be implemented by subclasses + self.calc = self._setup_calculator() + + @abstractmethod + def _setup_calculator(self): + """Set up the calculator for this relaxer. + + Returns + ------- + Calculator + The initialized calculator object. + """ + pass + + def relax(self, structure: Structure, relax: bool = False) -> RelaxationResult: + """Relax a structure using calculator. + + Parameters + ---------- + structure : Structure + Structure to relax. + relax : bool, default=False + Only relaxes the structure if True. + + Returns + ------- + RelaxationResult + Result of the relaxation. + """ + try: + if structure is None or not structure.is_valid(): + print("Skipping structure: Invalid crystal") + return RelaxationResult( + success=False, + message="Invalid crystal", + ) + + # Convert to ASE atoms + atoms = structure.to_ase_atoms() + + atoms.calc = self.calc + + # Relax structure + if relax: + dyn = FIRE(FrechetCellFilter(atoms), logfile=None) + dyn.run(fmax=self.fmax, steps=self.steps) + + # Get results + final_energy = atoms.get_potential_energy() + + final_structure = AseAtomsAdaptor.get_structure(atoms) + + # Clean up + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return RelaxationResult( + success=True, + energy=final_energy, + structure=final_structure, + ) + except Exception as e: + logger.error(f"ASE relaxation failed: {str(e)}") + return RelaxationResult( + success=False, + message=str(e), + ) + + +@register_relaxer("orb") +class OrbRelaxerImpl(ASERelaxerBase): + """Orb relaxer implementation.""" + + def _setup_calculator(self): + """Set up the Orb calculator.""" + device = "cpu" if self.cpu else "cuda" + if self.params.get("direct", False): + orbff = pretrained.orb_v3_direct_inf_mpa( + device=device, + precision="float32-high", # or "float32-highest" / "float64 + ) + else: + orbff = pretrained.orb_v3_conservative_inf_mpa( + device=device, + precision="float32-high", # or "float32-highest" / "float64 + ) + return ORBCalculator(orbff, device=device) + + +# TODO: Fix Meta Fairchem Relaxers +# @register_relaxer("ocp") +# class OCPRelaxerImpl(ASERelaxerBase): +# """OCP relaxer implementation.""" + +# def _setup_calculator(self): +# """Set up the OCP calculator.""" +# return OCPCalculator(checkpoint_path=self.checkpoint_path, cpu=self.cpu) diff --git a/tests/benchmarks/test_stability_benchmark.py b/tests/benchmarks/test_stability_benchmark.py new file mode 100644 index 00000000..d59f7d4b --- /dev/null +++ b/tests/benchmarks/test_stability_benchmark.py @@ -0,0 +1,137 @@ +"""Tests for stability benchmark.""" + +from pymatgen.util.testing import PymatgenTest + +from lematerial_forgebench.benchmarks.stability_benchmark import StabilityBenchmark +from lematerial_forgebench.preprocess.stability_preprocess import ( + StabilityPreprocessor, +) + + +class TestStabilityBenchmark: + """Test suite for StabilityBenchmark class.""" + + def test_initialization_default(self): + """Test initialization with default parameters.""" + benchmark = StabilityBenchmark() + + # Check name and properties + assert benchmark.config.name == "StabilityBenchmark" + assert "version" in benchmark.config.metadata + assert benchmark.config.metadata["category"] == "stability" + + # Check correct evaluator + assert len(benchmark.evaluators) == 2 + assert "stability" in benchmark.evaluators + + def test_initialization_custom(self): + """Test initialization with custom relaxer configuration.""" + + benchmark = StabilityBenchmark( + name="Custom Stability Benchmark", + description="Custom description", + metadata={"test_key": "test_value"}, + ) + + # Check custom values + assert benchmark.config.name == "Custom Stability Benchmark" + assert benchmark.config.description == "Custom description" + assert benchmark.config.metadata["test_key"] == "test_value" + + def test_evaluate_with_mp_entries(self): + """Test benchmark evaluation on structures""" + benchmark = StabilityBenchmark() + + # Create test structures + test = PymatgenTest() + structures = [test.get_structure("Si"), test.get_structure("LiFePO4")] + + # first, we need to preprocess the structures + stability_preprocessor = StabilityPreprocessor() + preprocessor_result = stability_preprocessor(structures) + structures = preprocessor_result.processed_structures + print(structures[0].properties) + + # Run benchmark + result = benchmark.evaluate(structures) + + # Check result format + assert len(result.evaluator_results) == 2 + assert "stability" in result.evaluator_results + assert "stable_ratio" in result.final_scores + print(result.final_scores) + + # Check score types + assert isinstance(result.final_scores["stable_ratio"], (int, float)) + assert isinstance(result.final_scores["metastable_ratio"], (int, float)) + assert isinstance(result.final_scores["mean_e_above_hull"], (int, float)) + + def test_empty_structures(self): + """Test behavior with empty structure list.""" + benchmark = StabilityBenchmark() + + # Test behavior with no structures - should not raise error + result = benchmark.evaluate([]) + + # Should get default values + assert result.final_scores["stable_ratio"] == 0.0 + assert result.final_scores["metastable_ratio"] == 0.0 + assert result.final_scores["mean_e_above_hull"] == 0.0 + + def test_aggregate_evaluator_results(self): + """Test result aggregation logic.""" + benchmark = StabilityBenchmark() + + # Mock evaluator_results as passed by BaseBenchmark.evaluate + # It contains the evaluator's combined_value and the primary metric value + # of each metric configured for that evaluator (e.g., "metric_name_value"). + mock_evaluator_results_from_base = { + "stability": { # Name of the evaluator + "combined_value": 0.75, # Evaluator's combined score + "metric_results": { + "stability": { + "metrics": { + "stable_ratio": 0.75, + "mean_e_above_hull": 0.1, + } + } + }, + }, + "metastability": { + "combined_value": 0.85, + }, + } + + # Aggregate results + scores = benchmark.aggregate_evaluator_results(mock_evaluator_results_from_base) + + # Check scores + # aggregate_evaluator_results should pick up combined_value as stability_score + # and stability_value as stable_ratio. + # mean_e_above_hull and metastable_ratio will be defaults (nan, 0.0) because + # they are not present in the input dict. + assert scores["stable_ratio"] == 0.75 + assert scores["metastable_ratio"] == 0.85 + assert scores["mean_e_above_hull"] == 0.1 + + def test_benchmark_metadata(self): + """Test benchmark metadata structure.""" + benchmark = StabilityBenchmark() + + metadata = benchmark.config.metadata + + # Check required metadata fields + assert metadata["version"] == "0.1.0" + assert metadata["category"] == "stability" + + +def test_evaluator_configuration(): + """Test that evaluator is properly configured.""" + benchmark = StabilityBenchmark() + + # Check evaluator configuration + stability_evaluator = benchmark.evaluators["stability"] + print(stability_evaluator) + assert stability_evaluator.config.name == "stability" + assert stability_evaluator.config.weights == {"stability": 1.0} + assert stability_evaluator.config.aggregation_method == "weighted_mean" diff --git a/tests/metrics/test_stability_metrics.py b/tests/metrics/test_stability_metrics.py new file mode 100644 index 00000000..e7f32288 --- /dev/null +++ b/tests/metrics/test_stability_metrics.py @@ -0,0 +1,116 @@ +"""Tests for stability metrics implementation.""" + +import pytest +from pymatgen.util.testing import PymatgenTest + +from lematerial_forgebench.metrics.stability_metrics import ( + MetastabilityMetric, + StabilityMetric, +) + + +@pytest.fixture +def test_structures(): + """Create test structures for stability evaluation.""" + test = PymatgenTest() + structures = [ + test.get_structure("Si"), + test.get_structure("LiFePO4"), + test.get_structure("CsCl"), + ] + return structures + + +@pytest.fixture +def test_structures_with_precomputed_e_above_hull(): + """Create test structures with precomputed e_above_hull values.""" + test = PymatgenTest() + structures = [ + test.get_structure("Si"), + test.get_structure("LiFePO4"), + test.get_structure("CsCl"), + test.get_structure("NaCl"), + test.get_structure("Fe2O3"), + ] + + # Add precomputed e_above_hull values to structure properties + e_above_hull_values = [0.05, 0.0, 0.15, 0.02, 0.08] # Mix of stable and metastable + + for structure, e_above_hull in zip(structures, e_above_hull_values): + structure.properties["e_above_hull"] = e_above_hull + + return structures + + +def test_stability_metric(test_structures): + metric = StabilityMetric( + relaxer_type="orb", + relaxer_config={"steps": 500, "fmax": 0.02}, + mp_entries_file="src/lematerial_forgebench/utils/relaxers/2023-02-07-ppd-mp.pkl.gz", + ) + + result = metric(test_structures) + + # Check that computation ran (may have failed indices due to no MP entries) + assert len(result.individual_values) == len(test_structures) + + # Check result structure + assert "mean_e_above_hull" in result.metrics + assert "stable_ratio" in result.metrics + assert "metastable_ratio" in result.metrics + assert result.primary_metric == "stable_ratio" + + assert result.metrics["stable_ratio"] >= 0.0 + assert result.metrics["metastable_ratio"] >= 0.0 + + +def test_metastability_metric(test_structures_with_precomputed_e_above_hull): + """Test MetastabilityMetric with precomputed e_above_hull values.""" + metric = MetastabilityMetric() + + result = metric(test_structures_with_precomputed_e_above_hull) + + # Check that computation ran successfully + assert len(result.individual_values) == len( + test_structures_with_precomputed_e_above_hull + ) + + # Check result structure + assert "mean_e_above_hull" in result.metrics + assert "metastable_ratio" in result.metrics + assert result.primary_metric == "metastable_ratio" + + # Check specific values based on our test data + # e_above_hull_values = [0.05, 0.0, 0.15, 0.02, 0.08] + # metastable (≤ 0.1): 4 structures (all except CsCl with 0.15) + + assert result.metrics["metastable_ratio"] == 4 / 5 # 4 out of 5 structures + assert ( + abs(result.metrics["mean_e_above_hull"] - 0.06) < 1e-6 + ) # (0.05+0.0+0.15+0.02+0.08)/5 = 0.06 + + # Check that all individual values were extracted correctly + expected_values = [0.05, 0.0, 0.15, 0.02, 0.08] + assert result.individual_values == expected_values + + +def test_metastability_metric_missing_properties(): + """Test MetastabilityMetric with structures missing e_above_hull properties.""" + test = PymatgenTest() + structures = [ + test.get_structure("Si"), + test.get_structure("LiFePO4"), + ] + + # Don't add e_above_hull properties - should return NaN values + metric = MetastabilityMetric() + result = metric(structures) + + # Should have NaN values for structures without properties + import numpy as np + + assert all(np.isnan(val) for val in result.individual_values) + + # Metrics should be 0 when all values are NaN + assert result.metrics["metastable_ratio"] == 0.0 + assert result.metrics["mean_e_above_hull"] == 0.0