diff --git a/src/jabs/classifier/__init__.py b/src/jabs/classifier/__init__.py index 4c576d1e..e4751575 100644 --- a/src/jabs/classifier/__init__.py +++ b/src/jabs/classifier/__init__.py @@ -10,16 +10,20 @@ from .multi_class_classifier import MultiClassClassifier from .protocols import ClassifierProtocol from .training_report import ( + BinaryCVResult, CrossValidationResult, + MultiClassCVResult, TrainingReportData, generate_markdown_report, save_training_report, ) __all__ = [ + "BinaryCVResult", "Classifier", "ClassifierProtocol", "CrossValidationResult", + "MultiClassCVResult", "MultiClassClassifier", "TrainingReportData", "generate_markdown_report", diff --git a/src/jabs/classifier/base.py b/src/jabs/classifier/base.py new file mode 100644 index 00000000..c6b1ce24 --- /dev/null +++ b/src/jabs/classifier/base.py @@ -0,0 +1,328 @@ +"""Shared infrastructure for behavior classifiers. + +``BaseClassifier`` consolidates persistence, identity properties, factory +dispatch, feature cleaning, and feature-importance reporting that are common +to both binary and multi-class classifiers. Subclasses provide the train and +predict implementations that determine the actual learning behavior. + +The class is concrete (not abstract): subclasses are not required to +override anything to instantiate. Public surface for classifier *consumers* +is governed by :class:`jabs.classifier.ClassifierProtocol`, which both +subclasses satisfy structurally. +""" + +from __future__ import annotations + +import typing +import warnings +from pathlib import Path +from typing import ClassVar + +import joblib +import numpy as np +import numpy.typing as npt +import pandas as pd +from sklearn.exceptions import InconsistentVersionWarning + +from jabs.core.enums import ClassifierType +from jabs.core.utils import hash_file + +from . import classifier_utils, factories + + +class BaseClassifier: + """Shared persistence and identity machinery for JABS classifiers. + + Class attributes that subclasses must set: + ``_VERSION``: pickled-format version integer for this subclass. + ``_MULTICLASS``: True if this subclass operates in multi-class mode. + ``_PERSISTED_REQUIRED``: tuple of instance attribute names that ``load`` + must restore from the pickled instance. + ``_PERSISTED_OPTIONAL``: tuple of instance attribute names that + ``load`` should restore if present on the pickled instance (default + to ``None`` otherwise). Used to support older pickles that may not + have all attributes the live class now declares. + """ + + _VERSION: ClassVar[int] = 0 + _MULTICLASS: ClassVar[bool] = False + _PERSISTED_REQUIRED: ClassVar[tuple[str, ...]] = () + _PERSISTED_OPTIONAL: ClassVar[tuple[str, ...]] = () + + def __init__(self, classifier_type: ClassifierType, n_jobs: int = 1) -> None: + self._classifier_type = classifier_type + self._classifier: typing.Any = None + self._project_settings: dict | None = None + self._feature_names: list[str] | None = None + self._n_jobs = n_jobs + self._version = self._VERSION + + self._classifier_file: str | None = None + self._classifier_hash: str | None = None + self._classifier_source: str | None = None + + self._supported_classifiers = self._supported_classifier_choices() + if classifier_type not in self._supported_classifiers: + raise ValueError("Invalid classifier type") + + @property + def classifier_name(self) -> str: + """Return the name of the underlying algorithm.""" + return self._classifier_type.value + + @property + def classifier_type(self) -> ClassifierType: + """Return the underlying classifier algorithm enum value.""" + return self._classifier_type + + @property + def classifier_file(self) -> str | None: + """Return the filename of the saved classifier, if any.""" + return self._classifier_file + + @property + def classifier_hash(self) -> str | None: + """Return the content hash of the saved classifier, if any.""" + return self._classifier_hash + + @property + def project_settings(self) -> dict: + """Return a copy of the classifier's training settings.""" + if self._project_settings is not None: + return dict(self._project_settings) + return {} + + @property + def version(self) -> int: + """Return the serialized classifier format version.""" + return self._version + + @property + def feature_names(self) -> list[str] | None: + """Return the list of feature names used to train this classifier.""" + return self._feature_names + + @classmethod + def _supported_classifier_choices(cls) -> set[ClassifierType]: + """Return classifier types available in the current environment. + + Resolved per-call so that test code can patch + :func:`jabs.classifier.factories.supported_classifier_types` or this + method on the subclass without freezing state at import time. + """ + return factories.supported_classifier_types(multiclass=cls._MULTICLASS) + + def set_classifier(self, classifier: ClassifierType) -> None: + """Switch the underlying classifier algorithm. + + Args: + classifier: The classifier type to switch to. + + Raises: + ValueError: If the classifier type is not supported. + """ + if classifier not in self._supported_classifier_choices(): + raise ValueError("Invalid Classifier Type") + self._classifier_type = classifier + + def set_dict_settings(self, settings: dict) -> None: + """Assign classifier settings from a dictionary. + + Args: + settings: dict of settings (same structure as + ``project.settings_manager.get_behavior``). + """ + self._project_settings = dict(settings) + + def classifier_choices(self) -> dict[ClassifierType, str]: + """Return the available classifier types as a sorted display map. + + Returns: + dict mapping ``ClassifierType`` enum values to their string names. + """ + return {t: t.value for t in sorted(self._supported_classifiers, key=lambda t: t.value)} + + def _create_classifier(self, random_seed: int | None = None) -> typing.Any: + """Instantiate the underlying sklearn/xgboost/catboost classifier.""" + factory = factories.get_factory(self._classifier_type, multiclass=self._MULTICLASS) + return factory(self._n_jobs, random_seed) + + def _clean_features(self, features: pd.DataFrame) -> pd.DataFrame: + """Replace ±inf/NaN in feature matrix per classifier type.""" + return classifier_utils.clean_features(features, self._classifier_type) + + def _get_features_to_classify(self, features: pd.DataFrame) -> pd.DataFrame: + """Reorder/select feature columns to match the trained model. + + Args: + features: DataFrame of feature data to filter. + + Returns: + DataFrame containing only the columns the trained model expects, + in the order the model expects them. + + Raises: + RuntimeError: If feature names cannot be obtained from the model. + """ + if self._classifier_type == ClassifierType.XGBOOST: + classifier_columns = self._classifier.get_booster().feature_names + elif hasattr(self._classifier, "feature_names_in_"): + classifier_columns = list(self._classifier.feature_names_in_) + elif hasattr(self._classifier, "feature_names_"): + classifier_columns = list(self._classifier.feature_names_) + else: + raise RuntimeError("Error obtaining feature names from classifier.") + return features[classifier_columns] + + @staticmethod + def combine_data(per_frame: pd.DataFrame, window: pd.DataFrame) -> pd.DataFrame: + """Combine per-frame and window feature DataFrames into one.""" + return classifier_utils.combine_data(per_frame, window) + + @staticmethod + def derive_predictions( + probabilities: npt.NDArray[np.floating], + ) -> tuple[npt.NDArray[np.int8], npt.NDArray[np.floating]]: + """Derive class predictions and confidence from class probabilities. + + Args: + probabilities: Array of shape ``(n_frames, n_classes)`` of predicted + class probabilities. + + Returns: + Tuple ``(predictions, confidence)`` where ``predictions`` is the + argmax class index per frame (``-1`` if confidence is zero, + indicating no pose) and ``confidence`` is the probability of the + chosen class. + """ + predictions = np.argmax(probabilities, axis=1).astype(np.int8) + confidence = probabilities[np.arange(len(probabilities)), predictions] + predictions[confidence == 0] = -1 + return predictions, confidence + + def get_feature_importance(self, limit: int = 20) -> list[tuple[str, float]]: + """Return ranked feature importances, highest first. + + Args: + limit: Maximum number of features to return. + + Returns: + List of ``(feature_name, importance)`` tuples sorted by importance + descending. Returns an empty list if the classifier is untrained or + does not expose feature importances. + """ + if self._classifier is None or self._feature_names is None: + return [] + if not hasattr(self._classifier, "feature_importances_"): + return [] + importances = list(np.asarray(self._classifier.feature_importances_).reshape(-1)) + feature_importance = [ + (feature, round(importance, 2)) + for feature, importance in zip(self._feature_names, importances, strict=True) + ] + feature_importance.sort(key=lambda x: x[1], reverse=True) + return feature_importance[:limit] + + def save(self, path: Path) -> None: + """Serialize the classifier to disk using joblib. + + Args: + path: Destination file path. + """ + joblib.dump(self, path) + if self._classifier_file is None: + self._classifier_file = Path(path).name + self._classifier_hash = hash_file(Path(path)) + self._classifier_source = "serialized" + + @classmethod + def from_pickle(cls, path: Path) -> BaseClassifier: + """Load a classifier from a pickle file with full validation and metadata backfill. + + Applies the same version, classifier-type, and metadata checks as + :meth:`load`, but as a classmethod factory so no dummy instance is + required. The class of the returned object is determined by the + calling class - ``Classifier.from_pickle(...)`` rejects pickled + ``MultiClassClassifier`` instances and vice versa. + + Args: + path: Path to the saved classifier pickle file. + + Returns: + Loaded and validated classifier instance of type ``cls``. + + Raises: + ValueError: If the file is not an instance of ``cls``, was trained + with an incompatible sklearn or JABS version, or uses an + unsupported classifier type. + """ + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always", InconsistentVersionWarning) + c = joblib.load(path) + for warning in caught_warnings: + if issubclass(warning.category, InconsistentVersionWarning): + raise ValueError("Classifier trained with different version of sklearn.") + warnings.warn(warning.message, warning.category, stacklevel=2) + + if not isinstance(c, cls): + raise ValueError(f"{path} is not an instance of {cls.__name__}") + + if c._version != cls._VERSION: + raise ValueError( + f"Unable to deserialize pickled classifier. " + f"File version {c._version}, expected {cls._VERSION}." + ) + + if c._classifier_type not in cls._supported_classifier_choices(): + raise ValueError("Invalid classifier type") + + if c._classifier_file is None: + c._classifier_file = Path(path).name + c._classifier_hash = hash_file(Path(path)) + c._classifier_source = "pickle" + + return c + + def load(self, path: Path) -> None: + """Deserialize a classifier from disk, updating this instance in place. + + Args: + path: Source file path. + + Raises: + ValueError: If the file is not an instance of this class, was saved + with a different version, or uses an unsupported classifier type. + """ + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always", InconsistentVersionWarning) + c = joblib.load(path) + for warning in caught_warnings: + if issubclass(warning.category, InconsistentVersionWarning): + raise ValueError("Classifier trained with different version of sklearn.") + warnings.warn(warning.message, warning.category, stacklevel=2) + + if not isinstance(c, type(self)): + raise ValueError(f"{path} is not an instance of {type(self).__name__}") + + if c._version != self._VERSION: + raise ValueError( + f"Unable to deserialize pickled classifier. " + f"File version {c._version}, expected {self._VERSION}." + ) + + if c._classifier_type not in self._supported_classifiers: + raise ValueError("Invalid classifier type") + + for attr in self._PERSISTED_REQUIRED: + setattr(self, attr, getattr(c, attr)) + for attr in self._PERSISTED_OPTIONAL: + setattr(self, attr, getattr(c, attr, None)) + + if c._classifier_file is not None: + self._classifier_file = c._classifier_file + self._classifier_hash = c._classifier_hash + self._classifier_source = c._classifier_source + else: + self._classifier_file = Path(path).name + self._classifier_hash = hash_file(Path(path)) + self._classifier_source = "pickle" diff --git a/src/jabs/classifier/classifier.py b/src/jabs/classifier/classifier.py index bc21a58f..fdb9ab5a 100644 --- a/src/jabs/classifier/classifier.py +++ b/src/jabs/classifier/classifier.py @@ -1,12 +1,12 @@ +"""Binary behavior classifier (behavior vs. not-behavior).""" + import logging -import typing import warnings from pathlib import Path +from typing import ClassVar -import joblib import numpy as np import pandas as pd -from sklearn.exceptions import InconsistentVersionWarning from jabs.core.enums import ( DEFAULT_CV_GROUPING_STRATEGY, @@ -14,63 +14,53 @@ CrossValidationGroupingStrategy, ) from jabs.core.utils import hash_file -from jabs.project import Project, TrackLabels, load_training_data +from jabs.project import Project, load_training_data from . import classifier_utils -from .factories import XGBOOST_AVAILABLE, make_catboost, make_random_forest, make_xgboost - -_VERSION = 11 +from .base import BaseClassifier -# _CLASSIFIER_FACTORIES serves as both the single source of truth for classifiers -# supported by the current JABS environment, in addition to the mapping of ClassifierTypes -# to factory functions that produce instantiated classifiers for that type. -# XGBoost availability is detected once in factories.py; import XGBOOST_AVAILABLE here -# rather than re-probing so that the warning is emitted exactly once. -_CLASSIFIER_FACTORIES: dict[ClassifierType, typing.Callable[[int, int | None], typing.Any]] = { - ClassifierType.RANDOM_FOREST: make_random_forest, - ClassifierType.CATBOOST: make_catboost, -} +logger = logging.getLogger(__name__) -if XGBOOST_AVAILABLE: - _CLASSIFIER_FACTORIES[ClassifierType.XGBOOST] = make_xgboost +class Classifier(BaseClassifier): + """A binary behavior classifier (behavior vs. not-behavior). -class Classifier: - """A machine learning classifier for behavior classification tasks. - - This class supports training, evaluating, saving, and loading classifiers - for behavioral data using Random Forest or XGBoost algorithms. - It provides utilities for data splitting, balancing, augmentation, and feature management. + Supports training, evaluating, saving, and loading classifiers for + behavioral data using Random Forest, CatBoost, or XGBoost algorithms. + Persistence and identity machinery are inherited from + :class:`BaseClassifier`. Attributes: - LABEL_THRESHOLD (int): Minimum number of labels required per group. + LABEL_THRESHOLD: Minimum number of labels required per group. """ - LABEL_THRESHOLD = 20 - - def __init__(self, classifier: ClassifierType = ClassifierType.RANDOM_FOREST, n_jobs: int = 1): - self._classifier_type = classifier - self._classifier = None - self._project_settings = None - self._behavior = None - self._feature_names = None - self._n_jobs = n_jobs - self._version = _VERSION - - self._classifier_file = None - self._classifier_hash = None - self._classifier_source = None - self._supported_classifiers = self._supported_classifier_choices() - - # make sure the value passed for the classifier parameter is valid - if classifier not in self._supported_classifiers: - raise ValueError("Invalid classifier type") + LABEL_THRESHOLD: ClassVar[int] = classifier_utils.LABEL_THRESHOLD + + _VERSION: ClassVar[int] = 11 + _MULTICLASS: ClassVar[bool] = False + _PERSISTED_REQUIRED: ClassVar[tuple[str, ...]] = ( + "_classifier", + "_behavior", + "_project_settings", + "_classifier_type", + "_feature_names", + ) + + def __init__( + self, + classifier: ClassifierType = ClassifierType.RANDOM_FOREST, + n_jobs: int = 1, + ) -> None: + super().__init__(classifier_type=classifier, n_jobs=n_jobs) + self._behavior: str | None = None @classmethod - def from_training_file(cls, path: Path, classifier_type: ClassifierType | None = None): + def from_training_file( + cls, path: Path, classifier_type: ClassifierType | None = None + ) -> "Classifier": """Initialize a classifier from an exported training data file. - This method will load the training data and train a classifier. + This method loads the training data and trains a classifier. Args: path: exported training data file @@ -78,7 +68,7 @@ def from_training_file(cls, path: Path, classifier_type: ClassifierType | None = file. If ``None``, the type recorded in the file is used. Returns: - trained classifier object + trained Classifier object """ loaded_training_data, _ = load_training_data(path) behavior = loaded_training_data["behavior"] @@ -91,8 +81,10 @@ def from_training_file(cls, path: Path, classifier_type: ClassifierType | None = if effective_type in classifier._supported_classifiers: classifier.set_classifier(effective_type) else: - logging.warning( - f"Specified classifier type {effective_type.name} is unavailable, using default: {classifier.classifier_type.name}" + logger.warning( + "Specified classifier type %s is unavailable, using default: %s", + effective_type.name, + classifier.classifier_type.name, ) training_features = classifier.combine_data( loaded_training_data["per_frame"], loaded_training_data["window"] @@ -111,94 +103,41 @@ def from_training_file(cls, path: Path, classifier_type: ClassifierType | None = return classifier - @property - def classifier_name(self) -> str: - """return the name of the classifier used as a string""" - return self._classifier_type.value - - @property - def classifier_type(self) -> ClassifierType: - """return classifier type""" - return self._classifier_type - - @property - def classifier_file(self) -> str | None: - """return the filename of the saved classifier""" - return self._classifier_file - - @property - def classifier_hash(self) -> str | None: - """return the hash of the classifier file""" - return self._classifier_hash - - @property - def project_settings(self) -> dict: - """return a copy of dictionary of project settings for this classifier""" - if self._project_settings is not None: - return dict(self._project_settings) - return {} - @property def behavior_name(self) -> str | None: - """return the behavior name property""" + """Return the behavior name property.""" return self._behavior @behavior_name.setter - def behavior_name(self, value) -> None: - """set the behavior name property""" + def behavior_name(self, value: str | None) -> None: + """Set the behavior name property.""" self._behavior = value - @property - def version(self) -> int: - """return the classifier format version""" - return self._version - - @property - def feature_names(self) -> list[str] | None: - """returns the list of feature names used when training this classifier""" - return self._feature_names - @staticmethod - def get_leave_one_group_out_max(labels, groups): - """counts the number of possible leave one out groups for k-fold cross validation + def get_leave_one_group_out_max(labels: np.ndarray, groups: np.ndarray) -> int: + """Count the number of possible leave-one-group-out splits. Args: - labels: labels to check if they were above the threshold - groups: group id corresponding to the labels + labels: Labels to check against the per-class threshold. + groups: Group id corresponding to each label. Returns: - int of the maximum number of cross validation to use + Number of groups that can serve as a valid test split. Note: labels excludes label for frames with no identity. """ - labels = np.asarray(labels) - groups = np.asarray(groups) - unique_groups = np.unique(groups) - count = 0 - for g in unique_groups: - test_mask = groups == g - test_labels = labels[test_mask] - train_labels = labels[~test_mask] - # Test split must have both classes above threshold. - test_ok = ( - np.sum(test_labels == TrackLabels.Label.BEHAVIOR) >= Classifier.LABEL_THRESHOLD - and np.sum(test_labels == TrackLabels.Label.NOT_BEHAVIOR) - >= Classifier.LABEL_THRESHOLD - ) - # Training split must also have both classes above threshold so the - # model can learn every class regardless of which group is held out. - train_ok = ( - np.sum(train_labels == TrackLabels.Label.BEHAVIOR) >= Classifier.LABEL_THRESHOLD - and np.sum(train_labels == TrackLabels.Label.NOT_BEHAVIOR) - >= Classifier.LABEL_THRESHOLD - ) - if test_ok and train_ok: - count += 1 - return count + return classifier_utils.count_valid_logo_splits( + labels, groups, label_threshold=Classifier.LABEL_THRESHOLD + ) @staticmethod - def leave_one_group_out(per_frame_features, window_features, labels, groups): - """implements "leave one group out" data splitting strategy + def leave_one_group_out( + per_frame_features: pd.DataFrame, + window_features: pd.DataFrame, + labels: np.ndarray, + groups: np.ndarray, + ): + """Yield "leave one group out" train/test splits. Args: per_frame_features: per frame features for all labeled data @@ -206,15 +145,9 @@ def leave_one_group_out(per_frame_features, window_features, labels, groups): labels: labels corresponding to each feature row groups: group id corresponding to each feature row - Returns: - dictionary of training and test data and labels: - { - 'training_data': list of numpy arrays, - 'test_data': list of numpy arrays, - 'training_labels': numpy array, - 'test_labels': numpy_array, - 'feature_names': list of feature names - } + Yields: + Dict with training_data, test_data, training_labels, test_labels, + and feature_names. """ yield from classifier_utils.leave_one_group_out( per_frame_features, @@ -225,109 +158,55 @@ def leave_one_group_out(per_frame_features, window_features, labels, groups): ) @staticmethod - def downsample_balance(features, labels, random_seed=None): - """downsamples features and labels such that labels are equally distributed - - Args: - features: features to downsample - labels: labels to downsample - random_seed: optional random seed - - Returns: - tuple of downsampled features, labels - """ + def downsample_balance( + features: pd.DataFrame, labels: np.ndarray, random_seed: int | None = None + ): + """Downsample features and labels to an equal class distribution.""" return classifier_utils.downsample_balance(features, labels, random_seed) @staticmethod - def augment_symmetric(features, labels, random_str="ASygRQDZJD"): - """augments the features to include L-R and R-L duplicates - - This requires 'left' or 'right' to be in the feature name to be swapped - Features that don't include these terms will not be swapped - - Args: - features: features to augment - labels: labels to augment - random_str: a random string to use as a temporary - replacement when swapping left/right - - Returns: - tuple of augmented features, labels - """ + def augment_symmetric( + features: pd.DataFrame, labels: np.ndarray, random_str: str = "ASygRQDZJD" + ): + """Augment features with left/right reflected duplicates.""" return classifier_utils.augment_symmetric(features, labels, random_str) - def set_classifier(self, classifier: ClassifierType): - """change the type of the classifier being used""" - if classifier not in self._supported_classifiers: - raise ValueError("Invalid Classifier Type") - self._classifier_type = classifier + def set_project_settings(self, project: Project) -> None: + """Assign project settings to the classifier. - def set_project_settings(self, project: Project): - """assign project settings to the classifier + If no behavior is currently set, uses project defaults; otherwise looks + up the behavior-scoped settings from the project's settings manager. Args: - project: project to copy classifier-relevant settings from for the current behavior - - if no behavior is currently set will use project defaults + project: Project to copy classifier-relevant settings from. """ if self._behavior is None: self._project_settings = project.get_project_defaults() else: self._project_settings = project.settings_manager.get_behavior(self._behavior) - def set_dict_settings(self, settings: dict): - """assign project settings via a dict to the classifier + def train(self, data: dict, random_seed: int | None = None) -> None: + """Train the classifier. Args: - settings: dict of project settings. Must be same structure as project.settings_manager.get_behavior - - TODO: Add checks to enforce conformity to project settings - """ - self._project_settings = dict(settings) - - def classifier_choices(self): - """get the available classifier types - - Returns: - dict where keys are ClassifierType enum values, and the - values are string names for the classifiers. - """ - return {t: t.value for t in sorted(self._supported_classifiers, key=lambda t: t.value)} - - def _create_classifier(self, random_seed: int | None = None): - """Instantiate the underlying classifier for the current classifier type.""" - try: - factory = _CLASSIFIER_FACTORIES[self._classifier_type] - except KeyError: - raise ValueError(f"Unsupported classifier type: {self._classifier_type!r}") from None - return factory(self._n_jobs, random_seed) - - def train(self, data, random_seed: int | None = None): - """train the classifier + data: dict returned from train_test_split(). + random_seed: optional random seed for reproducibility. - Args: - data: dict returned from train_test_split() - random_seed: optional random seed (used when we want - reproducible results between trainings) - - Returns: - None - - raises ValueError for having either unset project settings or an unset classifier + Raises: + ValueError: If project settings are unset. """ if self._project_settings is None: raise ValueError("Project settings for classifier unset, cannot train classifier.") - # Assume that feature names is provided, otherwise extract it from the dataframe if "feature_names" in data: self._feature_names = data["feature_names"] else: self._feature_names = data["training_data"].columns.to_list() - # Obtain the feature and label matrices features = data["training_data"] labels = data["training_labels"] - # Symmetric augmentation should occur before balancing so that the class with more labels can sample from the whole set + # Symmetric augmentation should occur before balancing so that the + # class with more labels can sample from the whole set. if self._project_settings.get("symmetric_behavior", False): features, labels = self.augment_symmetric(features, labels) if self._project_settings.get("balance_labels", False): @@ -339,45 +218,28 @@ def train(self, data, random_seed: int | None = None): warnings.simplefilter("ignore", category=FutureWarning) self._classifier = classifier.fit(cleaned_features, labels) - # Classifier may have been re-used from a prior training, blank the logging attributes self._classifier_file = None self._classifier_hash = None self._classifier_source = None - def get_features_to_classify(self, features: pd.DataFrame) -> pd.DataFrame: - """gets features for classification, handling classifier-specific quirks.""" - if self.classifier_type == ClassifierType.XGBOOST: - # XGBoost feature names are obtained from the booster - classifier_columns = self._classifier.get_booster().feature_names - else: - # For other classifiers, use the feature names from the underlying model - if hasattr(self._classifier, "feature_names_in_"): - classifier_columns = list(self._classifier.feature_names_in_) - elif hasattr(self._classifier, "feature_names_"): - classifier_columns = list(self._classifier.feature_names_) - else: - raise RuntimeError("Error obtaining feature names from classifier.") - - return features[classifier_columns] - def predict( self, features: pd.DataFrame, frame_indexes: np.ndarray | None = None ) -> np.ndarray: - """predict classes for a given set of features + """Predict classes for a given set of features. Args: - features: DataFrame of feature data to classify - frame_indexes: frame indexes to classify (default all) + features: DataFrame of feature data to classify. + frame_indexes: Frame indexes to classify (default all). Returns: - predicted class vector + Predicted class vector. Frames absent from ``frame_indexes`` are + assigned -1. """ - cleaned_features = self.get_features_to_classify(self._clean_features(features)) + cleaned_features = self._get_features_to_classify(self._clean_features(features)) with warnings.catch_warnings(): warnings.simplefilter("ignore", category=FutureWarning) result = self._classifier.predict(cleaned_features) - # Insert -1s into class prediction when no prediction is made if frame_indexes is not None: result_adjusted = np.full(result.shape, -1, dtype=np.int8) result_adjusted[frame_indexes] = result[frame_indexes] @@ -388,21 +250,21 @@ def predict( def predict_proba( self, features: pd.DataFrame, frame_indexes: np.ndarray | None = None ) -> np.ndarray: - """predict probabilities for a given set of features. + """Predict probabilities for a given set of features. Args: - features: DataFrame of feature data to classify - frame_indexes: frame indexes to classify (default all) + features: DataFrame of feature data to classify. + frame_indexes: Frame indexes to classify (default all). Returns: - prediction probability matrix + Prediction probability matrix. Frames absent from ``frame_indexes`` + are assigned zero probabilities. """ - cleaned_features = self.get_features_to_classify(self._clean_features(features)) + cleaned_features = self._get_features_to_classify(self._clean_features(features)) with warnings.catch_warnings(): warnings.simplefilter("ignore", category=FutureWarning) result = self._classifier.predict_proba(cleaned_features) - # Insert 0 probabilities when no prediction is made if frame_indexes is not None: result_adjusted = np.full(result.shape, 0, dtype=np.float32) result_adjusted[frame_indexes] = result[frame_indexes] @@ -410,209 +272,52 @@ def predict_proba( return result - def save(self, path: Path): - """save the classifier to a file - - Uses joblib to serialize the classifier object to a file. - """ - joblib.dump(self, path) - - # If the classifier was not generated from exported training data - # we can hash the serialized classifier. - # Note that this hash changes every time the "train" button is - # pressed, regardless of whether the training data changes. - if self._classifier_file is None: - self._classifier_file = Path(path).name - self._classifier_hash = hash_file(Path(path)) - self._classifier_source = "serialized" - - @classmethod - def from_pickle(cls, path: Path) -> "Classifier": - """Load a Classifier from a pickle file with full validation and metadata backfill. - - Applies the same version, classifier-type, and metadata checks as ``load()``, - but as a classmethod factory so no dummy instance is required. + def print_feature_importance(self, limit: int = 20) -> None: + """Print the most important features and their importance. Args: - path: Path to the saved classifier pickle file. - - Returns: - Loaded and validated ``Classifier`` instance. - - Raises: - ValueError: If the file is not a ``Classifier``, was trained with an - incompatible sklearn or JABS version, or uses an unsupported - classifier type. - """ - with warnings.catch_warnings(record=True) as caught_warnings: - warnings.simplefilter("always", InconsistentVersionWarning) - c = joblib.load(path) - for warning in caught_warnings: - if issubclass(warning.category, InconsistentVersionWarning): - raise ValueError("Classifier trained with different version of sklearn.") - else: - warnings.warn(warning.message, warning.category, stacklevel=2) - - if not isinstance(c, cls): - raise ValueError(f"{path} is not an instance of Classifier") - - if c.version != _VERSION: - raise ValueError( - f"Unable to deserialize pickled classifier. File version {c.version}, expected {_VERSION}." - ) - - if c._classifier_type not in cls._supported_classifier_choices(): - raise ValueError("Invalid classifier type") - - if c._classifier_file is None: - c._classifier_file = Path(path).name - c._classifier_hash = hash_file(Path(path)) - c._classifier_source = "pickle" - - return c - - def load(self, path: Path): - """load a classifier from a file - - Uses joblib to deserialize the classifier object that was previously saved - using the joblib.dump() method. + limit: Maximum number of features to print. """ - with warnings.catch_warnings(record=True) as caught_warnings: - warnings.simplefilter("always", InconsistentVersionWarning) - c = joblib.load(path) - for warning in caught_warnings: - if issubclass(warning.category, InconsistentVersionWarning): - raise ValueError("Classifier trained with different version of sklearn.") - else: - warnings.warn(warning.message, warning.category, stacklevel=2) - - if not isinstance(c, Classifier): - raise ValueError(f"{path} is not instance of Classifier") - - if c.version != _VERSION: - raise ValueError( - f"Unable to deserialize pickled classifier. File version {c.version}, expected {_VERSION}." - ) - - # make sure the value passed for the classifier parameter is valid - if c._classifier_type not in self._supported_classifiers: - raise ValueError("Invalid classifier type") - - self._classifier = c._classifier - self._behavior = c._behavior - self._project_settings = c._project_settings - self._classifier_type = c._classifier_type - if c._classifier_file is not None: - self._classifier_file = c._classifier_file - self._classifier_hash = c._classifier_hash - self._classifier_source = c._classifier_source - else: - self._classifier_file = Path(path).name - self._classifier_hash = hash_file(Path(path)) - self._classifier_source = "pickle" + feature_importance = self.get_feature_importance(limit=limit) + print(f"{'Feature Name':100} Importance") + print("-" * 120) + for feature, importance in feature_importance[:limit]: + print(f"{feature:100} {importance:0.2f}") @staticmethod - def accuracy_score(truth, predictions): - """return accuracy score""" + def accuracy_score(truth: np.ndarray, predictions: np.ndarray) -> float: + """Return accuracy score.""" return classifier_utils.accuracy_score(truth, predictions) @staticmethod - def precision_recall_score(truth, predictions): - """return precision recall score""" + def precision_recall_score(truth: np.ndarray, predictions: np.ndarray): + """Return precision/recall/f-score/support.""" return classifier_utils.precision_recall_score(truth, predictions) @staticmethod - def confusion_matrix(truth, predictions): - """return the confusion matrix using sklearn's confusion_matrix function""" + def confusion_matrix(truth: np.ndarray, predictions: np.ndarray) -> np.ndarray: + """Return the confusion matrix.""" return classifier_utils.confusion_matrix(truth, predictions) - @staticmethod - def combine_data(per_frame, window): - """combine feature sets together - - Args: - per_frame: per frame features dataframe - window: window feature dataframe - - Returns: - merged dataframe - """ - return classifier_utils.combine_data(per_frame, window) - - def get_feature_importance(self, limit=20) -> list[tuple[str, float]]: - """get the most important features and their importance - - Args: - limit: maximum number of features to return, defaults to 20 - - Returns: - list of tuples of feature name and importance - """ - # Get numerical feature importance - importances = list(self._classifier.feature_importances_) - # List of tuples with variable and importance - feature_importance = [ - (feature, round(importance, 2)) - for feature, importance in zip(self._feature_names, importances, strict=True) - ] - # Sort the feature importance by most important first - feature_importance = sorted(feature_importance, key=lambda x: x[1], reverse=True) - return feature_importance[:limit] - - def print_feature_importance(self, limit=20): - """print the most important features and their importance - - Args: - limit: maximum number of features to print, defaults to 20 - """ - feature_importance = self.get_feature_importance(limit=limit) - # Print out the feature and importance - print(f"{'Feature Name':100} Importance") - print("-" * 120) - for feature, importance in feature_importance[:limit]: - print(f"{feature:100} {importance:0.2f}") - @staticmethod def count_label_threshold( all_counts: dict, cv_grouping_strategy: CrossValidationGroupingStrategy = DEFAULT_CV_GROUPING_STRATEGY, ) -> int: - """counts the number of groups that meet label threshold criteria + """Count groups that meet the label-threshold criteria. Args: - all_counts: labeled frame and bout counts for the entire - project - - - all_counts is a dict with the following form - { - '