Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions modnet/models/vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,25 @@

"""

import multiprocessing
import warnings
from collections import defaultdict
from typing import List, Tuple, Dict, Optional, Callable, Any, Union

from pathlib import Path
import multiprocessing
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import pandas as pd
import numpy as np
import warnings
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, mean_squared_error, roc_auc_score
import pandas as pd
import tensorflow as tf
import tqdm
from sklearn.impute import SimpleImputer
from sklearn.metrics import mean_absolute_error, mean_squared_error, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler, StandardScaler

from modnet import __version__
from modnet.preprocessing import MODData
from modnet.utils import LOG
from modnet import __version__

import tqdm

__all__ = ("MODNetModel",)

Expand Down Expand Up @@ -532,9 +530,10 @@ def fit_preset(

"""

from modnet.matbench.benchmark import matbench_kfold_splits
import os

from modnet.matbench.benchmark import matbench_kfold_splits

os.environ["TF_CPP_MIN_LOG_LEVEL"] = (
"2" # many models will be fitted => reduce output
)
Expand Down Expand Up @@ -1358,7 +1357,7 @@ def fit(
history = self.model.fit(**fit_params)
self.history = history.history

def predict(self, test_data: MODData, return_prob=False) -> pd.DataFrame:
def predict(self, test_data: MODData, return_prob=False, **kwargs) -> pd.DataFrame:
"""Predict the target values for the passed MODData.

Parameters:
Expand All @@ -1372,6 +1371,7 @@ class OR only return the most probable class.


"""
# kwargs added for compatibility of this DeprecatedMODNetModel with ensemble (e.g. remap_out_of_bounds could be given but ignored here)
# prevents Nan predictions if some features are inf
x = (
test_data.get_featurized_df()
Expand All @@ -1381,6 +1381,9 @@ class OR only return the most probable class.

# Scale and impute input features:
if self._scale_impute is not None:
self._scale_impute.named_steps["scaler"].clip = (
False # deprecated compatibility
)
x = self._scale_impute.transform(x)

p = np.array(self.model.predict(x))
Expand Down