From 69a6ca9e2d3847e0e5a413f340576261a23c5ef6 Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Tue, 6 May 2025 14:20:40 +0200 Subject: [PATCH 1/3] feat: Use dataset store --- src/lematerial_forgebench/create_store.py | 164 ++++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 src/lematerial_forgebench/create_store.py diff --git a/src/lematerial_forgebench/create_store.py b/src/lematerial_forgebench/create_store.py new file mode 100644 index 00000000..b00c6c85 --- /dev/null +++ b/src/lematerial_forgebench/create_store.py @@ -0,0 +1,164 @@ +from concurrent.futures import ProcessPoolExecutor +from functools import partial + +from material_hasher.dataset_store import DatasetStore +from material_hasher.hasher import HASHERS +from material_hasher.hasher.base import HasherBase +from material_hasher.similarity import SIMILARITY_MATCHERS +from pymatgen.core import Structure +from tqdm import tqdm + +from lematerial_forgebench.data.references.huggingface import HFDataset +from lematerial_forgebench.utils.logging import logger + + +# Hasher that extracts the BAWL fingerprint from the HF dataset +# since they are already computed +class ExtractFingerprintHasher(HasherBase): + def __init__(self): + pass + + def get_material_hash(self, structure: Structure) -> str: + return structure.fingerprint + + +HASHERS["extract_fingerprint"] = ExtractFingerprintHasher + + +def process_batch( + batch_num, total_batches, indices, dataset_class, dataset_store_class, hasher_class +): + """Process a batch of indices in a separate process. + + This function reinitializes necessary objects in each process to avoid + pickling issues. + """ + dataset = dataset_class() + dataset_store = dataset_store_class(hasher_class) + + # Select only the needed indices + dataset = dataset.select(indices) + + embeddings = [] + desc = f"Batch {batch_num}/{total_batches}" + for structure in tqdm( + dataset, desc=desc, leave=False, dynamic_ncols=True, position=batch_num + ): + try: + embedding = dataset_store._get_structure_embedding( + structure, dataset_store.equivalence_checker + ) + embeddings.append(embedding) + except Exception as e: + logger.warning(f"Error processing structure: {e}") + continue + + return embeddings + + +def fit_store( + dataset: HFDataset, dataset_store: DatasetStore, store_path: str, n_jobs: int = 1 +): + """Fit a DatasetStore on a HuggingFace dataset. + + Parameters + ---------- + dataset : HFDataset + The HuggingFace dataset to fit the store on. + dataset_store : DatasetStore + The dataset store to fit. + store_path : str + The path to save the store to. + n_jobs : int, optional + The number of jobs to use for fitting the store. + """ + # Calculate batch indices + total_size = len(dataset) + batch_size = total_size // n_jobs + indices = [list(range(i * batch_size, (i + 1) * batch_size)) for i in range(n_jobs)] + + # Handle remainder + if total_size % n_jobs != 0: + indices[-1].extend(range(n_jobs * batch_size, total_size)) + + print(f"Processing {total_size} structures in {n_jobs} batches") + + if n_jobs <= 1: + # Process directly in main process if n_jobs <= 1 + embeddings = process_batch( + 1, + 1, # batch_num, total_batches + list(range(total_size)), + dataset.__class__, + dataset_store.__class__, + dataset_store.equivalence_checker.__class__, + ) + dataset_store.store_embeddings(embeddings) + else: + process_fn = partial( + process_batch, + total_batches=n_jobs, + dataset_class=dataset.__class__, + dataset_store_class=dataset_store.__class__, + hasher_class=dataset_store.equivalence_checker.__class__, + ) + + with ProcessPoolExecutor( + max_workers=n_jobs, + ) as executor: + futures = [] + for batch_num, idx_batch in enumerate(indices, 1): + future = executor.submit(process_fn, batch_num, indices=idx_batch) + futures.append(future) + + # Process results as they complete + total_processed = 0 + with tqdm(total=total_size, desc="Total progress") as pbar: + for future in futures: + try: + result = future.result() + dataset_store.store_embeddings(result) + total_processed += len(result) + pbar.update(len(result)) + except Exception as e: + logger.error(f"Error processing batch: {e}") + raise + + logger.info(f"\nSaving results to {store_path}") + if dataset_store.equivalence_checker_class == ExtractFingerprintHasher: + dataset_store.equivalence_checker_class = HASHERS["BAWL-Legacy"] + dataset_store.save(store_path) + + +def main(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", choices=["lemat_bulk"], default="lemat_bulk") + parser.add_argument( + "--algorithm", + choices=list(HASHERS.keys()) + list(SIMILARITY_MATCHERS.keys()), + required=True, + ) + parser.add_argument("--n_jobs", type=int, default=10) + parser.add_argument("--subsample", type=int, default=None) + args = parser.parse_args() + + hasher_class = ( + HASHERS[args.algorithm] + if args.algorithm in HASHERS + else SIMILARITY_MATCHERS[args.algorithm] + ) + if args.dataset == "lemat_bulk": + dataset = HFDataset(subsample=args.subsample) + else: + raise ValueError(f"Dataset {args.dataset} not supported") + store = DatasetStore(hasher_class) + + fit_store( + dataset, store, f"store_{args.dataset}_{args.algorithm}.npy", n_jobs=args.n_jobs + ) + + +if __name__ == "__main__": + main() From 75ec72ee0979621712fc7aecdd4c57cc0ccfef50 Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Tue, 6 May 2025 14:21:01 +0200 Subject: [PATCH 2/3] chore: Point to dev branch of material-hasher --- pyproject.toml | 2 +- uv.lock | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e9fc56de..9d3e3e9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ dev-dependencies = [ ] [tool.uv.sources] -material-hasher = { git = "https://github.com/LeMaterial/lematerial-hasher.git" } +material-hasher = { git = "https://github.com/LeMaterial/lematerial-hasher.git", rev = "dev_bis" } [tool.ruff.lint] diff --git a/uv.lock b/uv.lock index 18b95554..fbf0f645 100644 --- a/uv.lock +++ b/uv.lock @@ -833,7 +833,7 @@ requires-dist = [ { name = "ase", specifier = ">=3.25.0" }, { name = "click", specifier = ">=8.1.8" }, { name = "datasets", specifier = ">=3.5.0" }, - { name = "material-hasher", git = "https://github.com/LeMaterial/lematerial-hasher.git" }, + { name = "material-hasher", git = "https://github.com/LeMaterial/lematerial-hasher.git?rev=dev_bis" }, { name = "pandas", specifier = ">=2.2.3" }, { name = "pymatgen", specifier = ">=2025.4.20" }, { name = "rich", specifier = ">=14.0.0" }, @@ -1018,7 +1018,7 @@ wheels = [ [[package]] name = "material-hasher" version = "0.1.0" -source = { git = "https://github.com/LeMaterial/lematerial-hasher.git#7f63534e7a22033ea02d694cc7fdf9f4e52450ee" } +source = { git = "https://github.com/LeMaterial/lematerial-hasher.git?rev=dev_bis#9f48de01f5148b62bd9c7ca338f780e44bb078c0" } dependencies = [ { name = "average-minimum-distance" }, { name = "datasets" }, From 51894fa23d3d88c24283f30532cd355190ae09dc Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Tue, 6 May 2025 14:37:57 +0200 Subject: [PATCH 3/3] feat: Add first novelty metric --- src/lematerial_forgebench/metrics/novelty.py | 42 ++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 src/lematerial_forgebench/metrics/novelty.py diff --git a/src/lematerial_forgebench/metrics/novelty.py b/src/lematerial_forgebench/metrics/novelty.py new file mode 100644 index 00000000..ae92e31e --- /dev/null +++ b/src/lematerial_forgebench/metrics/novelty.py @@ -0,0 +1,42 @@ +import numpy as np +from material_hasher.dataset_store import DatasetStore +from material_hasher.hasher import HASHERS +from material_hasher.similarity import SIMILARITY_MATCHERS +from pymatgen.core import Structure + +from lematerial_forgebench.metrics.base import BaseMetric + +ALGORITHMS = { + **HASHERS, + **SIMILARITY_MATCHERS, +} + + +class NoveltyMetric(BaseMetric): + def __init__( + self, dataset_store_path: str, threshold: float | None = None, **kwargs + ): + super().__init__(**kwargs) + self.dataset_store = DatasetStore.load(dataset_store_path) + self.threshold = threshold + + @staticmethod + def compute_structure(structure: Structure, **compute_args) -> float: + # Novelty means that the structure is not in the dataset + # ie it is not equivalent to any of the structures in the dataset + return 1 - np.mean( + compute_args["dataset_store"].is_equivalent( + structure, compute_args["threshold"] + ) + ) + + def _get_compute_attributes(self) -> dict: + return { + "dataset_store": self.dataset_store, + "threshold": self.threshold, + } + + def aggregate_results(self, values: list[float]) -> dict: + return { + "novelty_rate": np.mean(values), + }