diff --git a/modnet/models/ensemble.py b/modnet/models/ensemble.py index 907909e4..80fa0e65 100644 --- a/modnet/models/ensemble.py +++ b/modnet/models/ensemble.py @@ -222,7 +222,7 @@ class OR only return the most probable class. arr=np.array(all_predictions), ) - p_std = np.array(all_predictions).std(axis=0) + p_std = np.array(all_predictions, dtype=np.float32).std(axis=0) df_mean = pd.DataFrame(p_mean, index=p.index, columns=p_columns) df_std = pd.DataFrame(p_std, index=p.index, columns=p.columns)