From 7169439677d1190b5ec17bcb920ebad9eacc588d Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Sat, 12 Jul 2025 12:49:57 +0000 Subject: [PATCH] chore: Support batch embeddings calculations for mace --- .../models/mace/embeddings.py | 49 +++++++++++++++++-- 1 file changed, 46 insertions(+), 3 deletions(-) diff --git a/src/lematerial_forgebench/models/mace/embeddings.py b/src/lematerial_forgebench/models/mace/embeddings.py index 1367416e..543bd60c 100644 --- a/src/lematerial_forgebench/models/mace/embeddings.py +++ b/src/lematerial_forgebench/models/mace/embeddings.py @@ -1,7 +1,12 @@ """MACE embedding extraction utilities.""" +from typing import Union + import numpy as np +import torch +from mace import data from pymatgen.core.structure import Structure +from torch_geometric.data import Batch, Data from lematerial_forgebench.models.base import BaseEmbeddingExtractor @@ -14,19 +19,57 @@ def __init__(self, calculator, device="cpu"): self.calculator = calculator self.device = device - def extract_node_embeddings(self, structure: Structure) -> np.ndarray: + def extract_node_embeddings( + self, structure: Union[Structure, list[Structure]] + ) -> Union[np.ndarray, list[np.ndarray]]: """Extract per-atom embeddings from MACE model. Parameters ---------- - structure : Structure - Input structure + structure : Union[Structure, list[Structure]] + Input structure or list of structures Returns ------- np.ndarray Node embeddings with shape (n_atoms, descriptor_dim) """ + if isinstance(structure, list): + keyspec = data.KeySpecification( + info_keys={}, arrays_keys={"charges": self.calculator.charges_key} + ) + configs = [ + data.config_from_atoms( + _structure.to_ase_atoms(), + key_specification=keyspec, + head_name=self.calculator.head, + ) + for _structure in structure + ] + atomic_data_list = [ + Data( + **data.AtomicData.from_config( + config, + z_table=self.calculator.z_table, + cutoff=self.calculator.r_max, + heads=self.calculator.available_heads, + ).__dict__ + ) + for config in configs + ] + + batch = Batch.from_data_list(atomic_data_list) + batch = batch.to(self.device) + output = self.calculator.models[0](batch) + node_features = output["node_feats"] + node_features_list = torch.split(node_features, batch.ptr.diff().tolist()) + node_features_list = [ + node_features.detach().cpu().numpy() + for node_features in node_features_list + ] + + return node_features_list + atoms = structure.to_ase_atoms() # Use MACE's built-in descriptor extraction