diff --git a/nextflow/modules/pose_qc.nf b/nextflow/modules/pose_qc.nf index 1a480d40..bdce391c 100644 --- a/nextflow/modules/pose_qc.nf +++ b/nextflow/modules/pose_qc.nf @@ -5,10 +5,12 @@ * @param in_pose The input pose file * * @return Rendered video + * + * @publish ./qc Rendered pose video */ process RENDER_POSE { label "tracking" - publishDir "compressed/pose/", mode:'copy' + publishDir "${params.pubdir}/qc", mode:'copy' input: tuple path(in_video), path(in_pose) @@ -18,6 +20,32 @@ process RENDER_POSE { script: """ - python3 /kumar_lab_models/mouse-tracking-runtime/render_pose.py --in-vid ${in_video} --in-pose ${in_pose} --out-vid ${in_video.baseName}_pose.mp4 + mouse-tracking utils render-pose ${in_video} ${in_pose} ${in_video.baseName}_pose.mp4 + """ +} + +/** + * Render fecal boli on a frame + * + * @param in_video The input video file + * @param in_pose The input pose file + * + * @return Rendered fecal boli video + * + * @publish ./qc Rendered fecal boli video + */ +process RENDER_BOLI { + label "tracking" + publishDir "${params.pubdir}/qc", mode:'copy' + + input: + tuple path(in_video), path(in_pose) + + output: + path "${in_video.baseName}_boli.avi" + + script: + """ + mouse-tracking utils render-fecal-boli ${in_video} ${in_pose} ${in_video.baseName}_boli.avi """ } diff --git a/nextflow/modules/single_mouse.nf b/nextflow/modules/single_mouse.nf index 40b4cbb0..3058829f 100644 --- a/nextflow/modules/single_mouse.nf +++ b/nextflow/modules/single_mouse.nf @@ -17,7 +17,7 @@ process PREDICT_SINGLE_MOUSE_SEGMENTATION { label "gpu" label "tracking" label "r_single_seg" - + input: tuple path(video_file), path(in_pose_file) @@ -90,6 +90,35 @@ process QC_SINGLE_MOUSE { """ } +/** + * Modifies a pose file to filter out large poses. + * + * @param tuple + * - in_video The input video file + * - in_pose_file The input pose file to modify + * + * @return tuple files + * - Path to the video file. + * - Path to the filtered pose file. 
+ */ +process FILTER_LARGE_POSES { + label "tracking" + + input: + tuple path(in_video), path(in_pose_file) + + output: + tuple path("${in_video.baseName}_filtered.${in_video.extension}"), path("${in_video.baseName}_filtered.h5"), emit: files + + script: + """ + cp ${in_pose_file} ${in_video.baseName}_filtered.h5 + ln -s ${in_video} ${in_video.baseName}_filtered.${in_video.extension} + + mouse-tracking utils filter-large-area-pose ${in_video.baseName}_filtered.h5 + """ +} + /** * Clips a video and its corresponding pose file to a specified duration from the start. * diff --git a/nextflow/workflows/single_mouse_pipeline.nf b/nextflow/workflows/single_mouse_pipeline.nf index aec558d7..49b2207c 100644 --- a/nextflow/workflows/single_mouse_pipeline.nf +++ b/nextflow/workflows/single_mouse_pipeline.nf @@ -2,7 +2,10 @@ * This module contains the single mouse tracking pipeline. * It processes video input to track a single mouse. */ -include { PREDICT_SINGLE_MOUSE_SEGMENTATION; PREDICT_SINGLE_MOUSE_KEYPOINTS; CLIP_VIDEO_AND_POSE } from "${projectDir}/nextflow/modules/single_mouse" +include { PREDICT_SINGLE_MOUSE_SEGMENTATION; + PREDICT_SINGLE_MOUSE_KEYPOINTS; + CLIP_VIDEO_AND_POSE; + FILTER_LARGE_POSES; } from "${projectDir}/nextflow/modules/single_mouse" include { PREDICT_ARENA_CORNERS } from "${projectDir}/nextflow/modules/static_objects" include { PREDICT_FECAL_BOLI } from "${projectDir}/nextflow/modules/fecal_boli" include { QC_SINGLE_MOUSE } from "${projectDir}/nextflow/modules/single_mouse" @@ -40,14 +43,15 @@ workflow SINGLE_MOUSE_TRACKING { main: // Generate pose files pose_init = VIDEO_TO_POSE(input_video).files - // Pose v2 is output from this step + // Pose v2 is output from keypoint prediction step pose_v2_data = PREDICT_SINGLE_MOUSE_KEYPOINTS(pose_init).files if (params.align_videos) { pose_v2_data = CLIP_VIDEO_AND_POSE(pose_v2_data, params.clip_duration).files } + // Valid Pose v6 is produced when segmentation is added. 
pose_and_seg_data = PREDICT_SINGLE_MOUSE_SEGMENTATION(pose_v2_data).files - // Completed Pose v6 is output from this step - pose_with_corners = PREDICT_ARENA_CORNERS(pose_and_seg_data).files + filtered_pose_v6 = FILTER_LARGE_POSES(pose_and_seg_data).files + pose_with_corners = PREDICT_ARENA_CORNERS(filtered_pose_v6).files pose_v6_data = PREDICT_FECAL_BOLI(pose_with_corners).files // Publish the pose v2 results diff --git a/pyproject.toml b/pyproject.toml index 65ba3540..fb4c3c53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "h5py>=3.11.0", "pydantic-settings>=2.10.1", "yacs>=0.1.8", + "plotnine>=0.12.0", ] [project.optional-dependencies] @@ -28,6 +29,7 @@ gpu = [ "torch==2.6.0", "torchvision==0.21.0", "torchaudio==2.6.0", + "nvidia-cusparselt-cu12==0.6.3", ] # CPU-only convenience for local tests (unchanged idea) diff --git a/src/mouse_tracking/cli/qa.py b/src/mouse_tracking/cli/qa.py index 48d0f7b0..f50fe299 100644 --- a/src/mouse_tracking/cli/qa.py +++ b/src/mouse_tracking/cli/qa.py @@ -37,6 +37,25 @@ def single_pose( ) +@app.command() +def single_feature( + pose: Path = typer.Argument(..., help="Path to the pose file to inspect"), + behavior: Path = typer.Argument(..., help="Path to the behavior table to inspect"), + output: Path | None = typer.Option( + None, help="Output filename. Will append row if already exists." 
+ ), +): + """Run single mouse feature inspection.""" + # Dynamically set the output filename if not provided + if not output: + output = Path( + f"QA_features_{pose.stem}_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.pdf" + ) + + # TODO implement desired plots of feature data for more in-depth inspection + raise NotImplementedError("Feature inspection is not yet implemented.") + + @app.command() def multi_pose(): """Run multi pose quality assurance.""" diff --git a/src/mouse_tracking/cli/utils.py b/src/mouse_tracking/cli/utils.py index b2f57e3e..17638ba0 100644 --- a/src/mouse_tracking/cli/utils.py +++ b/src/mouse_tracking/cli/utils.py @@ -6,13 +6,19 @@ from rich import print from mouse_tracking import __version__ +from mouse_tracking.core.config.pose_utils import PoseUtilsConfig from mouse_tracking.matching.match_predictions import match_predictions from mouse_tracking.pose import render -from mouse_tracking.pose.convert import downgrade_pose_file from mouse_tracking.utils import fecal_boli, static_objects from mouse_tracking.utils.clip_video import clip_video_auto, clip_video_manual +from mouse_tracking.utils.writers import ( + downgrade_pose_file, + filter_large_contours, + filter_large_keypoints, +) app = typer.Typer() +CONFIG = PoseUtilsConfig() def version_callback(value: bool) -> None: @@ -54,6 +60,27 @@ def aggregate_fecal_boli( result.to_csv(output, index=False) +@app.command() +def render_fecal_boli_video( + in_video: Path = typer.Option( + ..., "--in-video", help="Path to the input video file" + ), + in_pose: Path = typer.Option( + ..., "--in-pose", help="Path to the input HDF5 pose file" + ), + out_video: Path = typer.Option( + ..., "--out-video", help="Path to the output video file" + ), +): + """ + Render fecal boli on video frames. + + This command renders fecal boli from the pose file onto the input video. + Video playback is 1fps with original frame timestamp overlayed. 
+ """ + fecal_boli.render_fecal_boli_video(str(in_video), str(in_pose), str(out_video)) + + clip_video_app = typer.Typer(help="Produce a video and pose clip aligned to criteria.") @@ -227,3 +254,26 @@ def stitch_tracklets( This command stitches tracklets from the specified source. """ match_predictions(in_pose) + + +@app.command() +def filter_large_area_pose( + in_pose: Path = typer.Argument(..., help="Input HDF5 pose file"), + max_area: int = typer.Option( + CONFIG.OFA_MAX_EXPECTED_AREA_PX, + help="Maximum area a pose can have, using a bounding box on keypoint pose.", + ), +): + """ + Filter pose by area. + + This command unmarks identity of pose (both keypoint and segmentation) with large areas. + """ + filter_large_keypoints( + in_pose, + max_area, + ) + filter_large_contours( + in_pose, + max_area, + ) diff --git a/src/mouse_tracking/core/config/pose_utils.py b/src/mouse_tracking/core/config/pose_utils.py index 21309e4a..9eeea3a2 100644 --- a/src/mouse_tracking/core/config/pose_utils.py +++ b/src/mouse_tracking/core/config/pose_utils.py @@ -35,6 +35,9 @@ class PoseUtilsConfig(BaseSettings): MIN_JABS_CONFIDENCE: float = 0.3 MIN_JABS_KEYPOINTS: int = 3 + # Large animals are rarely larger than 100px in our OFA + OFA_MAX_EXPECTED_AREA_PX: int = 150 * 150 + # Colors MOUSE_COLORS: list[tuple[int, int, int]] = [ (228, 26, 28), # Red diff --git a/src/mouse_tracking/pose/convert.py b/src/mouse_tracking/pose/convert.py index 5450ee31..8dae5c57 100644 --- a/src/mouse_tracking/pose/convert.py +++ b/src/mouse_tracking/pose/convert.py @@ -1,14 +1,8 @@ """Pose data conversion utilities.""" -import os -import re - -import h5py import numpy as np -from mouse_tracking.core.exceptions import InvalidPoseFileException from mouse_tracking.utils.run_length_encode import run_length_encode -from mouse_tracking.utils.writers import write_pixel_per_cm_attr, write_pose_v2_data def v2_to_v3(pose_data, conf_data, threshold: float = 0.3): @@ -95,53 +89,3 @@ def multi_to_v2(pose_data, 
conf_data, identity_data): return_list.append((cur_id, single_pose, single_conf)) return return_list - - -def downgrade_pose_file(pose_h5_path, disable_id: bool = False): - """Downgrades a multi-mouse pose file into multiple single mouse pose files. - - Args: - pose_h5_path: input pose file - disable_id: bool to disable identity embedding tracks (if available) and use tracklet data instead - """ - if not os.path.isfile(pose_h5_path): - raise FileNotFoundError(f"ERROR: missing file: {pose_h5_path}") - # Read in all the necessary data - with h5py.File(pose_h5_path, "r") as pose_h5: - if "version" in pose_h5["poseest"].attrs: - major_version = pose_h5["poseest"].attrs["version"][0] - else: - raise InvalidPoseFileException( - f"Pose file {pose_h5_path} did not have a valid version." - ) - if major_version == 2: - print(f"Pose file {pose_h5_path} is already v2. Exiting.") - exit(0) - - all_points = pose_h5["poseest/points"][:] - all_confidence = pose_h5["poseest/confidence"][:] - if major_version >= 4 and not disable_id: - all_track_id = pose_h5["poseest/instance_embed_id"][:] - elif major_version >= 3: - all_track_id = pose_h5["poseest/instance_track_id"][:] - try: - config_str = pose_h5["poseest/points"].attrs["config"] - model_str = pose_h5["poseest/points"].attrs["model"] - except (KeyError, AttributeError): - config_str = "unknown" - model_str = "unknown" - pose_attrs = pose_h5["poseest"].attrs - if "cm_per_pixel" in pose_attrs and "cm_per_pixel_source" in pose_attrs: - pixel_scaling = True - px_per_cm = pose_h5["poseest"].attrs["cm_per_pixel"] - source = pose_h5["poseest"].attrs["cm_per_pixel_source"] - else: - pixel_scaling = False - - downgraded_pose_data = multi_to_v2(all_points, all_confidence, all_track_id) - new_file_base = re.sub("_pose_est_v[0-9]+\\.h5", "", pose_h5_path) - for animal_id, pose_data, conf_data in downgraded_pose_data: - out_fname = f"{new_file_base}_animal_{animal_id}_pose_est_v2.h5" - write_pose_v2_data(out_fname, pose_data, conf_data, 
config_str, model_str) - if pixel_scaling: - write_pixel_per_cm_attr(out_fname, px_per_cm, source) diff --git a/src/mouse_tracking/pose/inspect.py b/src/mouse_tracking/pose/inspect.py index 130529ce..a3d1ec7e 100644 --- a/src/mouse_tracking/pose/inspect.py +++ b/src/mouse_tracking/pose/inspect.py @@ -13,6 +13,60 @@ CONFIG = PoseUtilsConfig() +def get_keypoint_bounding_box( + pose_data: np.ndarray, + pose_confidence: np.ndarray, + include_tail: bool = False, + min_confidence: float = CONFIG.MIN_JABS_CONFIDENCE, +) -> np.ma.array: + """Calculates bounding boxes of pose data. + + Args: + pose_data: keypoint data of shape [frame, n_animal, 12, 2] + pose_confidence: confidence values for pose_data in shape [frame, n_animal, 12] + include_tail: indicator to include tail keypoints in calculated bounding boxes + min_confidence: confidence threshold to consider points present + + Returns: + np.ma.array of shape [frame, n_animal, 2, 2] containing the bounding boxes. Boxes are stored as top-left, bottom-right with the final dimension matching order of pose data. + """ + masked_pose_data = np.ma.array( + pose_data, + mask=~np.repeat( + np.expand_dims(pose_confidence > min_confidence, -1), 2, axis=-1 + ), + ) + if not include_tail: + masked_pose_data = np.delete( + masked_pose_data, [CONFIG.MID_TAIL_INDEX, CONFIG.TIP_TAIL_INDEX], axis=-2 + ) + pose_boxes_br = np.max(masked_pose_data, axis=-2) + pose_boxes_tl = np.min(masked_pose_data, axis=-2) + return np.ma.stack([pose_boxes_tl, pose_boxes_br], axis=-2) + + +def get_contour_bounding_box( + contour_data: np.ndarray, + pad_value: int = -1, +) -> np.ma.array: + """Calculates bounding boxes for segmentation contour data. + + Args: + contour_data: segmentation data of shape [frame, n_animal, n_contour, n_points, 2] + pad_value: padding value used for making full matrices + + Returns: + np.ma.array of shape [frame, n_animal, 2, 2] containing the bounding boxes. 
Boxes are stored as top-left, bottom-right with the final dimension matching the order of contour data. + """ + masked_contours = np.ma.array( + contour_data, + mask=contour_data == pad_value, + ) + boxes_br = np.max(np.max(masked_contours, axis=-2), axis=-2) + boxes_tl = np.min(np.min(masked_contours, axis=-2), axis=-2) + return np.ma.stack([boxes_tl, boxes_br], axis=-2) + + def inspect_pose_v2(pose_file, pad: int = 150, duration: int = 108000) -> dict: """Inspects a single mouse pose file v2 for coverage metrics. @@ -28,6 +82,7 @@ def inspect_pose_v2(pose_file, pad: int = 150, duration: int = 108000) -> dict: pose_counts: total number of poses predicted missing_poses: missing poses in the primary duration of the video missing_keypoint_frames: number of frames which don't contain 12 keypoints in the primary duration + large_poses: number of poses that are larger than a typical mouse should be """ with h5py.File(pose_file, "r") as f: pose_version = f["poseest"].attrs["version"][0] @@ -35,12 +90,19 @@ def inspect_pose_v2(pose_file, pad: int = 150, duration: int = 108000) -> dict: msg = f"Only v2 pose files are supported for inspection. 
{pose_file} is version {pose_version}" raise ValueError(msg) pose_quality = f["poseest/confidence"][:] + pose_data = f["poseest/points"][:] num_keypoints = np.sum(pose_quality > CONFIG.MIN_JABS_CONFIDENCE, axis=1) high_conf_keypoints = np.all( pose_quality > CONFIG.MIN_HIGH_CONFIDENCE, axis=2 ).squeeze(1) + pose_boxes = get_keypoint_bounding_box(pose_data, pose_quality, True) + # Cast to float because uint16 subtraction and multiplication is not what we want + pose_boxes = pose_boxes.astype(float) + pose_box_size = pose_boxes[:, :, 1] - pose_boxes[:, :, 0] + pose_box_area = np.squeeze(pose_box_size[:, :, 0] * pose_box_size[:, :, 1]) + return { "first_frame_pose": safe_find_first(high_conf_keypoints), "first_frame_full_high_conf": safe_find_first(high_conf_keypoints), @@ -48,6 +110,9 @@ def inspect_pose_v2(pose_file, pad: int = 150, duration: int = 108000) -> dict: "missing_poses": duration - np.sum((num_keypoints > CONFIG.MIN_JABS_CONFIDENCE)[pad : pad + duration]), "missing_keypoint_frames": np.sum(num_keypoints[pad : pad + duration] != 12), + "large_poses": np.sum( + pose_box_area[pad : pad + duration] > CONFIG.OFA_MAX_EXPECTED_AREA_PX + ), } @@ -74,6 +139,7 @@ def inspect_pose_v6(pose_file, pad: int = 150, duration: int = 108000) -> dict: pose_counts: Total number of poses predicted seg_counts: Total number of segmentations matched with poses missing_poses: Missing poses in the observation duration of the video + large_poses: number of poses that are larger than a typical mouse should be missing_segs: Missing segmentations in the observation duration of the video pose_tracklets: Number of tracklets in the observation duration missing_keypoint_frames: Number of frames which don't contain 12 keypoints in the observation duration @@ -87,6 +153,7 @@ def inspect_pose_v6(pose_file, pad: int = 150, duration: int = 108000) -> dict: if np.max(pose_counts) > 1: msg = f"Only single mouse pose files are supported for inspection. 
{pose_file} contains multiple instances" raise ValueError(msg) + pose_data = f["poseest/points"][:] pose_quality = f["poseest/confidence"][:] pose_tracks = f["poseest/instance_track_id"][:] seg_ids = f["poseest/longterm_seg_id"][:] @@ -119,6 +186,12 @@ def inspect_pose_v6(pose_file, pad: int = 150, duration: int = 108000) -> dict: axis=2, ).squeeze(1) + pose_boxes = get_keypoint_bounding_box(pose_data, pose_quality, True) + # Cast to float because uint16 subtraction and multiplication is not what we want + pose_boxes = pose_boxes.astype(float) + pose_box_size = pose_boxes[:, :, 1] - pose_boxes[:, :, 0] + pose_box_area = np.squeeze(pose_box_size[:, :, 0] * pose_box_size[:, :, 1]) + return { "pose_file": Path(pose_file).name, "pose_hash": hash_file(Path(pose_file)), @@ -136,6 +209,9 @@ def inspect_pose_v6(pose_file, pad: int = 150, duration: int = 108000) -> dict: "pose_counts": np.sum(pose_counts), "seg_counts": np.sum(seg_ids > 0), "missing_poses": duration - np.sum(pose_counts[pad : pad + duration]), + "large_poses": np.sum( + pose_box_area[pad : pad + duration] > CONFIG.OFA_MAX_EXPECTED_AREA_PX + ), "missing_segs": duration - np.sum(seg_ids[pad : pad + duration] > 0), "pose_tracklets": len( np.unique( @@ -146,3 +222,53 @@ def inspect_pose_v6(pose_file, pad: int = 150, duration: int = 108000) -> dict: ), "missing_keypoint_frames": np.sum(num_keypoints[pad : pad + duration] != 12), } + + +def find_first_pose( + confidence, confidence_threshold: float = 0.3, num_keypoints: int = 12 +): + """Detects the first pose with all the keypoints. + + Args: + confidence: confidence matrix + confidence_threshold: minimum confidence to be considered a valid keypoint. See `convert_v2_to_v3` for additional notes on confidences + num_keypoints: number of keypoints + + Returns: + integer indicating the first frame when the pose was observed. 
+ In the case of multi-animal, the first frame when any full pose was found + + Raises: + ValueError if no pose meets the criteria + """ + valid_keypoints = confidence > confidence_threshold + num_keypoints_in_pose = np.sum(valid_keypoints, axis=-1) + # Multi-mouse + if num_keypoints_in_pose.ndim == 2: + num_keypoints_in_pose = np.max(num_keypoints_in_pose, axis=-1) + + completed_pose_frames = np.argwhere(num_keypoints_in_pose >= num_keypoints) + if len(completed_pose_frames) == 0: + msg = f"No poses detected with {num_keypoints} keypoints and confidence threshold {confidence_threshold}" + raise ValueError(msg) + + return completed_pose_frames[0][0] + + +def find_first_pose_file( + pose_file, confidence_threshold: float = 0.3, num_keypoints: int = 12 +): + """Lazy wrapper for `find_first_pose` that reads in file data. + + Args: + pose_file: pose file to read confidence matrix from + confidence_threshold: see `find_first_pose` + num_keypoints: see `find_first_pose` + + Returns: + see `find_first_pose` + """ + with h5py.File(pose_file, "r") as f: + confidences = f["poseest/confidence"][...] 
+ + return find_first_pose(confidences, confidence_threshold, num_keypoints) diff --git a/src/mouse_tracking/pytorch_inference/multi_pose.py b/src/mouse_tracking/pytorch_inference/multi_pose.py index 89e954d9..165a168d 100644 --- a/src/mouse_tracking/pytorch_inference/multi_pose.py +++ b/src/mouse_tracking/pytorch_inference/multi_pose.py @@ -11,10 +11,10 @@ import torch.backends.cudnn as cudnn from mouse_tracking.models.model_definitions import MULTI_MOUSE_POSE +from mouse_tracking.pose.render import render_pose_overlay from mouse_tracking.pytorch_inference.hrnet.config import cfg from mouse_tracking.pytorch_inference.hrnet.models import pose_hrnet from mouse_tracking.utils.hrnet import argmax_2d_torch, preprocess_hrnet -from mouse_tracking.utils.pose import render_pose_overlay from mouse_tracking.utils.prediction_saver import prediction_saver from mouse_tracking.utils.segmentation import get_frame_masks from mouse_tracking.utils.timers import time_accumulator diff --git a/src/mouse_tracking/pytorch_inference/single_pose.py b/src/mouse_tracking/pytorch_inference/single_pose.py index 8207b097..795f8ffb 100644 --- a/src/mouse_tracking/pytorch_inference/single_pose.py +++ b/src/mouse_tracking/pytorch_inference/single_pose.py @@ -10,10 +10,10 @@ import torch.backends.cudnn as cudnn from mouse_tracking.models.model_definitions import SINGLE_MOUSE_POSE +from mouse_tracking.pose.render import render_pose_overlay from mouse_tracking.pytorch_inference.hrnet.config import cfg from mouse_tracking.pytorch_inference.hrnet.models import pose_hrnet from mouse_tracking.utils.hrnet import argmax_2d_torch, preprocess_hrnet -from mouse_tracking.utils.pose import render_pose_overlay from mouse_tracking.utils.prediction_saver import prediction_saver from mouse_tracking.utils.timers import time_accumulator from mouse_tracking.utils.writers import write_pose_v2_data diff --git a/src/mouse_tracking/utils/clip_video.py b/src/mouse_tracking/utils/clip_video.py index 1dcc93a8..a4ad5d39 
100644 --- a/src/mouse_tracking/utils/clip_video.py +++ b/src/mouse_tracking/utils/clip_video.py @@ -5,8 +5,8 @@ import numpy as np +from mouse_tracking.pose.inspect import find_first_pose_file from mouse_tracking.utils import writers -from mouse_tracking.utils.pose import find_first_pose_file from mouse_tracking.utils.timers import print_time diff --git a/src/mouse_tracking/utils/features.py b/src/mouse_tracking/utils/features.py new file mode 100644 index 00000000..4ff1d7b9 --- /dev/null +++ b/src/mouse_tracking/utils/features.py @@ -0,0 +1,155 @@ +"""Feature reading and inspecting utility functions.""" + +import os +from pathlib import Path + +import h5py +import numpy as np +import pandas as pd + + +class JABSFeature: + """Methods to interact with JABS feature data.""" + + def __init__(self, feature_file: Path): + """Initializes a feature object. + + Args: + feature_file: file to interact with feature data + """ + assert os.path.exists(feature_file) + + # Transforms feature file into a metadata object + self._file = feature_file + # Populate keys generates a pandas table to better index which features are available + self._populate_keys() + + @property + def feature_keys(self): + """Dataframe containing the available feature keys.""" + return self._feature_keys.copy() + + def get_key_data(self, key: str): + """Retrieves raw data contained within a feature file. + + Args: + key: fully defined key to extract from the feature file + + Returns: + np.ndarray containing the requested data + + Todo: + Have this function check for valid keys before crashing + """ + found_data = False + with h5py.File(self._file, "r") as f: + if key in f: + retrieved_data = f[key][...] + found_data = True + + # Try and search for a simplified key + if not found_data: + raise ValueError(f"Full keys only supported currently. 
Given {key}") + + return retrieved_data + + def get_window_feature( + self, + feature_key: str, + window_size: int, + window_op: str, + feature_module: str | None = None, + ): + """Retrieves the stored feature vector from a window feature. + + Args: + feature_key: key of the feature to retrieve + window_size: window size of the window feature + window_op: window operation for the feature + feature_module: Optional module which the features belong to + + Returns: + np.ndarray containing the feature data + + Raises: + KeyError if the feature key, window size, or window op are not present + """ + if feature_module is None: + feature_module = self.discover_feature_module(feature_key) + + if "window_op" not in self._feature_keys: + raise KeyError( + f"Feature file {self._file} does not contain window feature data." + ) + if not ( + (self._feature_keys["module"] == feature_module) + & (self._feature_keys["window_op"] == window_op) + & (self._feature_keys["feature"] == feature_key) + ).any(): + raise KeyError("Module, Window Op, and Feature key not available.") + + feature_key = f"features/window_features_{window_size}/{feature_module} {window_op} {feature_key}" + + with h5py.File(self._file, "r") as f: + feature_data = f[feature_key][:] + + return feature_data + + def discover_feature_module(self, feature_key: str): + """Discovers the feature module given a key. + + Args: + feature_key: feature to identify the module name + + Returns: + module string of the feature + + Raises: + ValueError if module is not unique + """ + discovered_feature_module = np.unique( + self._feature_keys.loc[ + self._feature_keys["feature"] == feature_key, "module" + ] + ) + if len(discovered_feature_module) != 1: + raise ValueError( + f"Feature {feature_key} does not map uniquely to a feature module. 
Found modules: {discovered_feature_module}" + ) + return discovered_feature_module[0] + + def _populate_keys(self): + """Populates the available module-feature pairs.""" + with h5py.File(self._file, "r") as f: + feature_grps = list(f["features"].keys()) + available_features = list(f["features/per_frame"].keys()) + + base_features = pd.DataFrame( + [["features/per_frame/" + x, *x.split(" ", 1)] for x in available_features], + columns=["key", "module", "feature"], + ) + + self._window_sizes = [ + x.split("_")[2] for x in feature_grps if x.startswith("window_features_") + ] + discovered_window_keys = [] + if len(self._window_sizes) > 0: + with h5py.File(self._file, "r") as f: + window_keys = list( + f[f"features/window_features_{self._window_sizes[0]}"].keys() + ) + for cur_window in self._window_sizes: + next_window_keys = pd.DataFrame( + [ + [ + f"features/window_features_{cur_window}/{x}", + *x.split(" ", 1), + cur_window, + ] + for x in window_keys + ], + columns=["key", "module", "feature", "window_size"], + ) + discovered_window_keys.append(next_window_keys) + + self._feature_keys = pd.concat([base_features, *discovered_window_keys]) diff --git a/src/mouse_tracking/utils/fecal_boli.py b/src/mouse_tracking/utils/fecal_boli.py index fc61bf88..f280bce9 100644 --- a/src/mouse_tracking/utils/fecal_boli.py +++ b/src/mouse_tracking/utils/fecal_boli.py @@ -3,9 +3,14 @@ import glob import h5py +import imageio import numpy as np import pandas as pd +from mouse_tracking.utils.rendering import plot_frame_info +from mouse_tracking.utils.static_objects import plot_keypoints +from mouse_tracking.utils.timers import print_time + def aggregate_folder_data(folder: str, depth: int = 2, num_bins: int = -1): """Aggregates fecal boli data in a folder into a table. 
def render_fecal_boli_video(in_video: str, in_pose: str, out_video: str):
    """
    Renders fecal boli predictions on the sampled frames of a video.

    One output frame is written per entry of the pose file's
    "dynamic_objects/fecal_boli/sample_indices" dataset, at 1 frame per second.
    Rendering stops early if a sampled frame cannot be read from the video.

    Args:
        in_video: The input video file
        in_pose: The input pose file
        out_video: The output video file
    """
    # Load the pose data first so a pose file without fecal boli data fails
    # before any output file is created.
    with h5py.File(in_pose, "r") as f:
        fecal_boli = f["dynamic_objects/fecal_boli/points"][...]
        fecal_boli_counts = f["dynamic_objects/fecal_boli/counts"][...]
        fecal_boli_frames = f["dynamic_objects/fecal_boli/sample_indices"][...]

    # Context managers close the reader and finalize the writer even when
    # rendering raises part-way through.
    with (
        imageio.get_reader(in_video) as video_reader,
        imageio.get_writer(out_video, fps=1) as vid_writer,
    ):
        # Iterating the sample indices directly also covers the empty case,
        # which previously raised IndexError on the first lookup.
        for prediction_idx, prediction_frame in enumerate(fecal_boli_frames):
            try:
                input_frame = video_reader.get_data(prediction_frame)
            except (IndexError, StopIteration):
                # Sample index lies beyond the end of the video.
                break

            fecal_boli_count_in_frame = fecal_boli_counts[prediction_idx]
            fecal_boli_data = fecal_boli[
                prediction_idx, : int(fecal_boli_count_in_frame)
            ]
            if fecal_boli_count_in_frame > 0:
                rendered_frame = plot_keypoints(
                    fecal_boli_data, input_frame, is_yx=True, radius=5, alpha=0.5
                )
            else:
                rendered_frame = input_frame

            rendered_frame = plot_frame_info(
                rendered_frame, f"Video Timestamp: {print_time(prediction_frame)}"
            )

            # Write the frame to the output video
            vid_writer.append_data(rendered_frame)
def plot_jabs_feature(feature_file: Path, feature: str, fps: float = 30.0):
    """
    Generates a scatter plot of a JABS feature against time in minutes.

    Args:
        feature_file: JABS feature file
        feature: JABS feature key to plot
        fps: frames per second used to convert frame index to minutes.
            Defaults to 30, the value that was previously hard-coded.

    Returns:
        matplotlib.figure.Figure of the plot

    Raises:
        ValueError when feature does not exist in the feature file, or when
        fps is not positive.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    # Accept plain strings as well; `.name` below needs a Path.
    feature_file = Path(feature_file)
    feature_obj = JABSFeature(feature_file)
    if feature not in feature_obj.feature_keys["key"].tolist():
        raise ValueError(
            f"Feature {feature} not present in feature file {feature_file.name}"
        )

    # NOTE(review): assumes the stored feature is one value per frame (1-D) -- confirm.
    feature_arr = feature_obj.get_key_data(feature)
    plot_data = pd.DataFrame(
        {
            # frame index -> seconds -> minutes
            "minute": np.arange(len(feature_arr)) / fps / 60,
            "val": feature_arr,
        }
    )
    plot = (
        p9.ggplot(plot_data, p9.aes(x="minute", y="val"))
        + p9.geom_point()
        + p9.theme_bw()
        + p9.labs(x="minute", y=feature)
    )

    return plot.draw()
-MIN_JABS_CONFIDENCE = 0.3 -MIN_JABS_KEYPOINTS = 3 - - -def convert_v2_to_v3(pose_data, conf_data, threshold: float = 0.3): - """Converts single mouse pose data into multimouse. - - Args: - pose_data: single mouse pose data of shape [frame, 12, 2] - conf_data: keypoint confidence data of shape [frame, 12] - threshold: threshold for filtering valid keypoint predictions - 0.3 is used in JABS - 0.4 is used for multi-mouse prediction code - 0.5 is a typical default in other software - - Returns: - tuple of (pose_data_v3, conf_data_v3, instance_count, instance_embedding, instance_track_id) - pose_data_v3: pose_data reformatted to v3 - conf_data_v3: conf_data reformatted to v3 - instance_count: instance count field for v3 files - instance_embedding: dummy data for embedding data field in v3 files - instance_track_id: tracklet data for v3 files - """ - pose_data_v3 = np.reshape(pose_data, [-1, 1, 12, 2]) - conf_data_v3 = np.reshape(conf_data, [-1, 1, 12]) - bad_pose_data = conf_data_v3 < threshold - pose_data_v3[np.repeat(np.expand_dims(bad_pose_data, -1), 2, axis=-1)] = 0 - conf_data_v3[bad_pose_data] = 0 - instance_count = np.full([pose_data_v3.shape[0]], 1, dtype=np.uint8) - instance_count[np.all(bad_pose_data, axis=-1).reshape(-1)] = 0 - instance_embedding = np.full(conf_data_v3.shape, 0, dtype=np.float32) - # Tracks can only be continuous blocks - instance_track_id = np.full(pose_data_v3.shape[:2], 0, dtype=np.uint32) - rle_starts, rle_durations, rle_values = rle(instance_count) - for i, (start, duration) in enumerate( - zip(rle_starts[rle_values == 1], rle_durations[rle_values == 1], strict=False) - ): - instance_track_id[start : start + duration] = i - return ( - pose_data_v3, - conf_data_v3, - instance_count, - instance_embedding, - instance_track_id, - ) - - -def convert_multi_to_v2(pose_data, conf_data, identity_data): - """Converts multi mouse pose data (v3+) into multiple single mouse (v2). 
- - Args: - pose_data: multi mouse pose data of shape [frame, max_animals, 12, 2] - conf_data: keypoint confidence data of shape [frame, max_animals, 12] - identity_data: identity data which indicates animal indices of shape [frame, max_animals] - - Returns: - list of tuples containing (id, pose_data_v2, conf_data_v2) - id: tracklet id - pose_data_v2: pose_data reformatted to v2 - conf_data_v2: conf_data reformatted to v2 - - Raises: - ValueError if an identity has 2 pose predictions in a single frame. - """ - invalid_poses = np.all(conf_data == 0, axis=-1) - id_values = np.unique(identity_data[~invalid_poses]) - masked_id_data = identity_data.copy().astype(np.int32) - # This is to handle id 0 (with 0-padding). -1 is an invalid id. - masked_id_data[invalid_poses] = -1 - - return_list = [] - for cur_id in id_values: - id_frames, id_idxs = np.where(masked_id_data == cur_id) - if len(id_frames) != len(set(id_frames)): - sorted_frames = np.sort(id_frames) - duplicated_frames = sorted_frames[:-1][ - sorted_frames[1:] == sorted_frames[:-1] - ] - msg = f"Identity {cur_id} contained multiple poses assigned on frames {duplicated_frames}." - raise ValueError(msg) - single_pose = np.zeros([len(pose_data), 12, 2], dtype=pose_data.dtype) - single_conf = np.zeros([len(pose_data), 12], dtype=conf_data.dtype) - single_pose[id_frames] = pose_data[id_frames, id_idxs] - single_conf[id_frames] = conf_data[id_frames, id_idxs] - - return_list.append((cur_id, single_pose, single_conf)) - - return return_list - - -def render_pose_overlay( - image: np.ndarray, - frame_points: np.ndarray, - exclude_points: list | None = None, - color: tuple = (255, 255, 255), -) -> np.ndarray: - """Renders a single pose on an image. - - Args: - image: image to render pose on - frame_points: keypoints to render. 
keypoints are ordered [y, x] - exclude_points: set of keypoint indices to exclude - color: color to render the pose - - Returns: - modified image - """ - if exclude_points is None: - exclude_points = [] - new_image = image.copy() - missing_keypoints = np.where(np.all(frame_points == 0, axis=-1))[0].tolist() - exclude_points = set(exclude_points + missing_keypoints) - - def gen_line_fragments(): - """Created lines to draw.""" - for curr_pt_indexes in CONNECTED_SEGMENTS: - curr_fragment = [] - for curr_pt_index in curr_pt_indexes: - if curr_pt_index in exclude_points: - if len(curr_fragment) >= 2: - yield curr_fragment - curr_fragment = [] - else: - curr_fragment.append(curr_pt_index) - if len(curr_fragment) >= 2: - yield curr_fragment - - line_pt_indexes = list(gen_line_fragments()) - - for curr_line_indexes in line_pt_indexes: - line_pts = np.array( - [(pt_x, pt_y) for pt_y, pt_x in frame_points[curr_line_indexes]], np.int32 - ) - if np.any(np.all(line_pts == 0, axis=-1)): - continue - cv2.polylines(new_image, [line_pts], False, (0, 0, 0), 2, cv2.LINE_AA) - cv2.polylines(new_image, [line_pts], False, color, 1, cv2.LINE_AA) - - for point_index in range(12): - if point_index in exclude_points: - continue - point_y, point_x = frame_points[point_index, :] - cv2.circle(new_image, (point_x, point_y), 3, (0, 0, 0), -1, cv2.LINE_AA) - cv2.circle(new_image, (point_x, point_y), 2, color, -1, cv2.LINE_AA) - - return new_image - - -def find_first_pose( - confidence, confidence_threshold: float = 0.3, num_keypoints: int = 12 -): - """Detects the first pose with all the keypoints. - - Args: - confidence: confidence matrix - confidence_threshold: minimum confidence to be considered a valid keypoint. See `convert_v2_to_v3` for additional notes on confidences - num_keypoints: number of keypoints - - Returns: - integer indicating the first frame when the pose was observed. 
- In the case of multi-animal, the first frame when any full pose was found - - Raises: - ValueError if no pose meets the criteria - """ - valid_keypoints = confidence > confidence_threshold - num_keypoints_in_pose = np.sum(valid_keypoints, axis=-1) - # Multi-mouse - if num_keypoints_in_pose.ndim == 2: - num_keypoints_in_pose = np.max(num_keypoints_in_pose, axis=-1) - - completed_pose_frames = np.argwhere(num_keypoints_in_pose >= num_keypoints) - if len(completed_pose_frames) == 0: - msg = f"No poses detected with {num_keypoints} keypoints and confidence threshold {confidence_threshold}" - raise ValueError(msg) - - return completed_pose_frames[0][0] - - -def find_first_pose_file( - pose_file, confidence_threshold: float = 0.3, num_keypoints: int = 12 -): - """Lazy wrapper for `find_first_pose` that reads in file data. - - Args: - pose_file: pose file to read confidence matrix from - confidence_threshold: see `find_first_pose` - num_keypoints: see `find_first_pose` - - Returns: - see `find_first_pose` - """ - with h5py.File(pose_file, "r") as f: - confidences = f["poseest/confidence"][...] - - return find_first_pose(confidences, confidence_threshold, num_keypoints) - - -def inspect_pose_v2(pose_file, pad: int = 150, duration: int = 108000): - """Inspects a single mouse pose file v2 for coverage metrics. 
- - Args: - pose_file: The pose file to inspect - pad: pad size expected in the beginning - duration: expected duration of experiment - - Returns: - Dict containing the following keyed data: - first_frame_pose: First frame where the pose data appeared - first_frame_full_high_conf: First frame with 12 keypoints at high confidence - pose_counts: total number of poses predicted - missing_poses: missing poses in the primary duration of the video - missing_keypoint_frames: number of frames which don't contain 12 keypoints in the primary duration - """ - with h5py.File(pose_file, "r") as f: - pose_version = f["poseest"].attrs["version"][0] - if pose_version != 2: - msg = f"Only v2 pose files are supported for inspection. {pose_file} is version {pose_version}" - raise ValueError(msg) - pose_quality = f["poseest/confidence"][:] - - num_keypoints = np.sum(pose_quality > MIN_JABS_CONFIDENCE, axis=1) - return_dict = {} - return_dict["first_frame_pose"] = safe_find_first(np.all(num_keypoints, axis=1)) - high_conf_keypoints = np.all(pose_quality > MIN_HIGH_CONFIDENCE, axis=2).squeeze(1) - return_dict["first_frame_full_high_conf"] = safe_find_first(high_conf_keypoints) - return_dict["pose_counts"] = np.sum(num_keypoints > MIN_JABS_CONFIDENCE) - return_dict["missing_poses"] = duration - np.sum( - (num_keypoints > MIN_JABS_CONFIDENCE)[pad : pad + duration] - ) - return_dict["missing_keypoint_frames"] = np.sum( - num_keypoints[pad : pad + duration] != 12 - ) - return return_dict - - -def inspect_pose_v6(pose_file, pad: int = 150, duration: int = 108000): - """Inspects a single mouse pose file v6 for coverage metrics. 
- - Args: - pose_file: The pose file to inspect - pad: duration of data skipped in the beginning (not observation period) - duration: observation duration of experiment - - Returns: - Dict containing the following keyed data: - pose_file: The pose file inspected - pose_hash: The blake2b hash of the pose file - video_name: The video name associated with the pose file (no extension) - video_duration: Duration of the video - corners_present: If the corners are present in the pose file - first_frame_pose: First frame where the pose data appeared - first_frame_full_high_conf: First frame with 12 keypoints > 0.75 confidence - first_frame_jabs: First frame with 3 keypoints > 0.3 confidence - first_frame_gait: First frame > 0.3 confidence for base tail and rear paws keypoints - first_frame_seg: First frame where segmentation data was assigned an id - pose_counts: Total number of poses predicted - seg_counts: Total number of segmentations matched with poses - missing_poses: Missing poses in the observation duration of the video - missing_segs: Missing segmentations in the observation duration of the video - pose_tracklets: Number of tracklets in the observation duration - missing_keypoint_frames: Number of frames which don't contain 12 keypoints in the observation duration - """ - with h5py.File(pose_file, "r") as f: - pose_version = f["poseest"].attrs["version"][0] - if pose_version < 6: - msg = f"Only v6+ pose files are supported for inspection. {pose_file} is version {pose_version}" - raise ValueError(msg) - pose_counts = f["poseest/instance_count"][:] - if np.max(pose_counts) > 1: - msg = f"Only single mouse pose files are supported for inspection. 
{pose_file} contains multiple instances" - raise ValueError(msg) - pose_quality = f["poseest/confidence"][:] - pose_tracks = f["poseest/instance_track_id"][:] - seg_ids = f["poseest/longterm_seg_id"][:] - corners_present = "static_objects/corners" in f - - num_keypoints = 12 - np.sum(pose_quality.squeeze(1) == 0, axis=1) - return_dict = {} - return_dict["pose_file"] = Path(pose_file).name - return_dict["pose_hash"] = hash_file(Path(pose_file)) - # Keep 2 folders if present for video name - folder_name = "/".join(Path(pose_file).parts[-3:-1]) + "/" - return_dict["video_name"] = folder_name + re.sub( - "_pose_est_v[0-9]+", "", Path(pose_file).stem - ) - return_dict["video_duration"] = pose_counts.shape[0] - return_dict["corners_present"] = corners_present - return_dict["first_frame_pose"] = safe_find_first(pose_counts > 0) - high_conf_keypoints = np.all(pose_quality > MIN_HIGH_CONFIDENCE, axis=2).squeeze(1) - return_dict["first_frame_full_high_conf"] = safe_find_first(high_conf_keypoints) - jabs_keypoints = np.sum(pose_quality > MIN_JABS_CONFIDENCE, axis=2).squeeze(1) - return_dict["first_frame_jabs"] = safe_find_first( - jabs_keypoints >= MIN_JABS_KEYPOINTS - ) - gait_keypoints = np.all( - pose_quality[:, :, [BASE_TAIL_INDEX, LEFT_REAR_PAW_INDEX, RIGHT_REAR_PAW_INDEX]] - > MIN_GAIT_CONFIDENCE, - axis=2, - ).squeeze(1) - return_dict["first_frame_gait"] = safe_find_first(gait_keypoints) - return_dict["first_frame_seg"] = safe_find_first(seg_ids > 0) - return_dict["pose_counts"] = np.sum(pose_counts) - return_dict["seg_counts"] = np.sum(seg_ids > 0) - return_dict["missing_poses"] = duration - np.sum(pose_counts[pad : pad + duration]) - return_dict["missing_segs"] = duration - np.sum(seg_ids[pad : pad + duration] > 0) - return_dict["pose_tracklets"] = len( - np.unique( - pose_tracks[pad : pad + duration][pose_counts[pad : pad + duration] == 1] - ) - ) - return_dict["missing_keypoint_frames"] = np.sum( - num_keypoints[pad : pad + duration] != 12 - ) - return return_dict 
def plot_frame_info(frame: np.ndarray, info_text: str):
    """
    Plots information on a video frame.

    The text is drawn twice at the same position: a thicker black pass first,
    then a thinner colored pass on top, which yields a black-bordered label.

    Args:
        frame: The video frame to annotate
        info_text: The text to display on the frame

    Returns:
        Copy of frame with text overlay
    """
    # Get a copy of the frame to draw on
    annotated_frame = frame.copy()

    # Text origin near the top left. A tuple is used because OpenCV converts
    # this argument to a Point and older bindings reject lists.
    text_position = (25, 25)

    # (color, thickness) per pass: border first, fill second.
    # NOTE(review): (241, 163, 64) is orange only for RGB-ordered frames -- confirm
    # the channel order of the frames callers pass in.
    text_passes = (
        ((0, 0, 0), 2),  # Black border
        ((241, 163, 64), 1),  # Orange fill
    )
    for color, thickness in text_passes:
        cv2.putText(
            annotated_frame,
            info_text,
            text_position,
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            color,
            thickness,
            cv2.LINE_AA,
        )

    return annotated_frame
@@ -39,18 +45,32 @@ def plot_keypoints( kps_ordered = np.flip(kp, axis=-1) if is_yx else kp if include_lines and kps_ordered.ndim == 2 and kps_ordered.shape[0] >= 1: img_copy = cv2.drawContours( - img_copy, [kps_ordered.astype(np.int32)], 0, (0, 0, 0), 2, cv2.LINE_AA + img_copy, + [kps_ordered.astype(np.int32)], + 0, + (0, 0, 0), + 1 + thickness, + cv2.LINE_AA, ) img_copy = cv2.drawContours( img_copy, [kps_ordered.astype(np.int32)], 0, color, 1, cv2.LINE_AA ) for _i, kp_data in enumerate(kps_ordered): _ = cv2.circle( - img_copy, (int(kp_data[0]), int(kp_data[1])), 3, (0, 0, 0), -1, cv2.LINE_AA + img_copy, + (int(kp_data[0]), int(kp_data[1])), + radius + thickness, + (0, 0, 0), + -1, + cv2.LINE_AA, ) _ = cv2.circle( - img_copy, (int(kp_data[0]), int(kp_data[1])), 2, color, -1, cv2.LINE_AA + img_copy, (int(kp_data[0]), int(kp_data[1])), radius, color, -1, cv2.LINE_AA ) + + if alpha != 1.0: + img_copy = cv2.addWeighted(img_copy, alpha, img.copy(), 1 - alpha, 0) + return img_copy diff --git a/src/mouse_tracking/utils/writers.py b/src/mouse_tracking/utils/writers.py index 9efcebcf..5f3102c3 100644 --- a/src/mouse_tracking/utils/writers.py +++ b/src/mouse_tracking/utils/writers.py @@ -1,5 +1,7 @@ """Functions related to saving data to pose files.""" +import os +import re from pathlib import Path import h5py @@ -7,7 +9,11 @@ from mouse_tracking.core.exceptions import InvalidPoseFileException from mouse_tracking.matching import hungarian_match_points_seg -from mouse_tracking.utils.pose import convert_v2_to_v3 +from mouse_tracking.pose.convert import multi_to_v2, v2_to_v3 +from mouse_tracking.pose.inspect import ( + get_contour_bounding_box, + get_keypoint_bounding_box, +) def promote_pose_data(pose_file, current_version: int, new_version: int): @@ -43,7 +49,7 @@ def promote_pose_data(pose_file, current_version: int, new_version: int): config_str = "unknown" model_str = "unknown" pose_data, conf_data, instance_count, instance_embedding, instance_track_id = ( - 
def downgrade_pose_file(pose_h5_path, disable_id: bool = False):
    """Downgrades a multi-mouse pose file into multiple single mouse pose files.

    One "<base>_animal_<id>_pose_est_v2.h5" file is written per identity
    returned by `multi_to_v2`, where <base> is the input path with its
    "_pose_est_v<N>.h5" suffix removed. Pixel scaling attributes are copied
    when both are present on the input.

    Args:
        pose_h5_path: input pose file
        disable_id: bool to disable identity embedding tracks (if available) and use tracklet data instead

    Raises:
        FileNotFoundError if the input file does not exist.
        InvalidPoseFileException if the file has no version attribute or is a
            version that carries no track data.
    """
    if not os.path.isfile(pose_h5_path):
        raise FileNotFoundError(f"ERROR: missing file: {pose_h5_path}")
    # Read in all the necessary data
    with h5py.File(pose_h5_path, "r") as pose_h5:
        if "version" in pose_h5["poseest"].attrs:
            major_version = pose_h5["poseest"].attrs["version"][0]
        else:
            raise InvalidPoseFileException(
                f"Pose file {pose_h5_path} did not have a valid version."
            )
        if major_version == 2:
            print(f"Pose file {pose_h5_path} is already v2. Exiting.")
            exit(0)

        all_points = pose_h5["poseest/points"][:]
        all_confidence = pose_h5["poseest/confidence"][:]
        if major_version >= 4 and not disable_id:
            all_track_id = pose_h5["poseest/instance_embed_id"][:]
        elif major_version >= 3:
            all_track_id = pose_h5["poseest/instance_track_id"][:]
        else:
            # Previously fell through and failed later with an unbound
            # `all_track_id`.
            raise InvalidPoseFileException(
                f"Pose file {pose_h5_path} is version {major_version}, which has no track data to downgrade."
            )
        try:
            config_str = pose_h5["poseest/points"].attrs["config"]
            model_str = pose_h5["poseest/points"].attrs["model"]
        except (KeyError, AttributeError):
            config_str = "unknown"
            model_str = "unknown"
        pose_attrs = pose_h5["poseest"].attrs
        if "cm_per_pixel" in pose_attrs and "cm_per_pixel_source" in pose_attrs:
            pixel_scaling = True
            px_per_cm = pose_h5["poseest"].attrs["cm_per_pixel"]
            source = pose_h5["poseest"].attrs["cm_per_pixel_source"]
        else:
            pixel_scaling = False

    downgraded_pose_data = multi_to_v2(all_points, all_confidence, all_track_id)
    # str() so that pathlib.Path inputs work with re.sub as well.
    new_file_base = re.sub(r"_pose_est_v[0-9]+\.h5", "", str(pose_h5_path))
    for animal_id, pose_data, conf_data in downgraded_pose_data:
        out_fname = f"{new_file_base}_animal_{animal_id}_pose_est_v2.h5"
        write_pose_v2_data(out_fname, pose_data, conf_data, config_str, model_str)
        if pixel_scaling:
            write_pixel_per_cm_attr(out_fname, px_per_cm, source)


def _require_pose_version(pose_h5, in_pose_f, minimum_version: int):
    """Returns the major version of an open pose file, enforcing a minimum.

    Raises:
        InvalidPoseFileException if the version attribute cannot be read or is
            below `minimum_version`.
    """
    try:
        current_version = pose_h5["poseest"].attrs["version"][0]
    except (KeyError, AttributeError, IndexError) as err:
        # The exception used to be constructed without `raise`, which let
        # execution continue into an UnboundLocalError.
        raise InvalidPoseFileException(
            f"Pose file {in_pose_f} does not have a version."
        ) from err
    if current_version < minimum_version:
        raise InvalidPoseFileException(
            f"Pose file {in_pose_f} is version {current_version}. Filtering is only implemented for pose file versions >= {minimum_version}."
        )
    return current_version


def _bounding_box_areas(boxes):
    """Returns the area of each bounding box.

    Index 1 minus index 0 along the third axis gives the box extent; the two
    entries of the last axis are multiplied to form the area.
    """
    # NOTE(review): assumes boxes are [frame, animal, (min, max), 2] -- confirm
    # against the bounding box helpers.
    boxes = boxes.astype(float)
    box_size = boxes[:, :, 1] - boxes[:, :, 0]
    return box_size[:, :, 0] * box_size[:, :, 1]


def filter_large_keypoints(in_pose_f: str | Path, area_threshold: float):
    """Unmarks identity of keypoints that exceed area threshold.

    For every (frame, animal) entry whose keypoint bounding box area is larger
    than the threshold, the embed identity is set to 0, the id mask to 1 and
    all keypoint confidences to 0. The pose file is modified in place.

    Args:
        in_pose_f: Input pose filename
        area_threshold: maximum pose bounding box area allowed

    Raises:
        InvalidPoseFileException if the pose file has no version or is not >= 4.
    """
    with h5py.File(in_pose_f, "r") as f:
        _require_pose_version(f, in_pose_f, 4)

        pose_data = f["poseest/points"][:]
        pose_confidence = f["poseest/confidence"][:]
        identity_data = f["poseest/instance_embed_id"][:]
        pose_masks = f["poseest/id_mask"][:]

    pose_boxes = get_keypoint_bounding_box(pose_data, pose_confidence)
    pose_box_area = _bounding_box_areas(pose_boxes)

    identities_to_unassign = np.where(pose_box_area > area_threshold)
    identity_data[identities_to_unassign] = 0
    pose_masks[identities_to_unassign] = 1
    pose_confidence[identities_to_unassign] = 0.0

    with h5py.File(in_pose_f, "a") as f:
        f["poseest/instance_embed_id"][:] = identity_data
        f["poseest/id_mask"][:] = pose_masks
        f["poseest/confidence"][:] = pose_confidence


def filter_large_contours(in_pose_f: str | Path, area_threshold: float):
    """Unmarks identity of contour data that exceed area threshold.

    For every (frame, animal) entry whose contour bounding box area is larger
    than the threshold, the long-term segmentation id is set to 0. The pose
    file is modified in place.

    Args:
        in_pose_f: Input pose filename
        area_threshold: maximum contour bounding box area allowed

    Raises:
        InvalidPoseFileException if the pose file has no version or is not >= 6.
    """
    with h5py.File(in_pose_f, "r") as f:
        _require_pose_version(f, in_pose_f, 6)

        seg_data = f["poseest/seg_data"][:]
        seg_ids = f["poseest/longterm_seg_id"][:]

    seg_boxes = get_contour_bounding_box(seg_data)
    seg_box_area = _bounding_box_areas(seg_boxes)

    identities_to_unassign = np.where(seg_box_area > area_threshold)
    seg_ids[identities_to_unassign] = 0

    with h5py.File(in_pose_f, "a") as f:
        f["poseest/longterm_seg_id"][:] = seg_ids
attrs_dict): @@ -55,7 +55,7 @@ class TestDowngradePoseFileErrorHandling: def test_missing_file_raises_file_not_found_error(self): """Test that missing input file raises FileNotFoundError.""" with ( - patch("mouse_tracking.pose.convert.os.path.isfile", return_value=False), + patch("mouse_tracking.utils.writers.os.path.isfile", return_value=False), pytest.raises( FileNotFoundError, match="ERROR: missing file: nonexistent.h5" ), @@ -70,8 +70,8 @@ def test_missing_version_attribute_raises_invalid_pose_file_exception(self): ) with ( - patch("mouse_tracking.pose.convert.os.path.isfile", return_value=True), - patch("mouse_tracking.pose.convert.h5py.File", return_value=mock_h5), + patch("mouse_tracking.utils.writers.os.path.isfile", return_value=True), + patch("mouse_tracking.utils.writers.h5py.File", return_value=mock_h5), pytest.raises( InvalidPoseFileException, match="Pose file test.h5 did not have a valid version", @@ -79,7 +79,7 @@ def test_missing_version_attribute_raises_invalid_pose_file_exception(self): ): downgrade_pose_file("test.h5") - @patch("mouse_tracking.pose.convert.exit") + @patch("mouse_tracking.utils.writers.exit") def test_v2_file_prints_message_and_exits(self, mock_exit): """Test that v2 files print message and exit gracefully.""" # Make exit raise SystemExit to actually terminate execution @@ -91,8 +91,8 @@ def test_v2_file_prints_message_and_exits(self, mock_exit): ) with ( - patch("mouse_tracking.pose.convert.os.path.isfile", return_value=True), - patch("mouse_tracking.pose.convert.h5py.File", return_value=mock_h5), + patch("mouse_tracking.utils.writers.os.path.isfile", return_value=True), + patch("mouse_tracking.utils.writers.h5py.File", return_value=mock_h5), patch("builtins.print") as mock_print, ): with pytest.raises(SystemExit) as exc_info: @@ -108,9 +108,9 @@ def test_v2_file_prints_message_and_exits(self, mock_exit): class TestDowngradePoseFileV3Processing: """Test successful processing of v3 pose files.""" - 
@patch("mouse_tracking.pose.convert.write_pixel_per_cm_attr") - @patch("mouse_tracking.pose.convert.write_pose_v2_data") - @patch("mouse_tracking.pose.convert.multi_to_v2") + @patch("mouse_tracking.utils.writers.write_pixel_per_cm_attr") + @patch("mouse_tracking.utils.writers.write_pose_v2_data") + @patch("mouse_tracking.utils.writers.multi_to_v2") def test_v3_file_basic_processing( self, mock_multi_to_v2, mock_write_v2, mock_write_pixel ): @@ -139,8 +139,8 @@ def test_v3_file_basic_processing( ] with ( - patch("mouse_tracking.pose.convert.os.path.isfile", return_value=True), - patch("mouse_tracking.pose.convert.h5py.File", return_value=mock_h5), + patch("mouse_tracking.utils.writers.os.path.isfile", return_value=True), + patch("mouse_tracking.utils.writers.h5py.File", return_value=mock_h5), ): downgrade_pose_file("test_pose_est_v3.h5") @@ -173,9 +173,9 @@ def test_v3_file_basic_processing( # Verify pixel scaling was not written (no pixel data) mock_write_pixel.assert_not_called() - @patch("mouse_tracking.pose.convert.write_pixel_per_cm_attr") - @patch("mouse_tracking.pose.convert.write_pose_v2_data") - @patch("mouse_tracking.pose.convert.multi_to_v2") + @patch("mouse_tracking.utils.writers.write_pixel_per_cm_attr") + @patch("mouse_tracking.utils.writers.write_pose_v2_data") + @patch("mouse_tracking.utils.writers.multi_to_v2") def test_v3_file_with_pixel_scaling( self, mock_multi_to_v2, mock_write_v2, mock_write_pixel ): @@ -205,8 +205,8 @@ def test_v3_file_with_pixel_scaling( ] with ( - patch("mouse_tracking.pose.convert.os.path.isfile", return_value=True), - patch("mouse_tracking.pose.convert.h5py.File", return_value=mock_h5), + patch("mouse_tracking.utils.writers.os.path.isfile", return_value=True), + patch("mouse_tracking.utils.writers.h5py.File", return_value=mock_h5), ): downgrade_pose_file("experiment_pose_est_v3.h5") @@ -215,8 +215,8 @@ def test_v3_file_with_pixel_scaling( "experiment_animal_1_pose_est_v2.h5", 0.1, "manual" ) - 
@patch("mouse_tracking.pose.convert.write_pose_v2_data") - @patch("mouse_tracking.pose.convert.multi_to_v2") + @patch("mouse_tracking.utils.writers.write_pose_v2_data") + @patch("mouse_tracking.utils.writers.multi_to_v2") def test_v3_file_missing_config_model_attributes( self, mock_multi_to_v2, mock_write_v2 ): @@ -242,8 +242,8 @@ def test_v3_file_missing_config_model_attributes( ] with ( - patch("mouse_tracking.pose.convert.os.path.isfile", return_value=True), - patch("mouse_tracking.pose.convert.h5py.File", return_value=mock_h5), + patch("mouse_tracking.utils.writers.os.path.isfile", return_value=True), + patch("mouse_tracking.utils.writers.h5py.File", return_value=mock_h5), ): downgrade_pose_file("test_pose_est_v3.h5") @@ -260,8 +260,8 @@ def test_v3_file_missing_config_model_attributes( class TestDowngradePoseFileV4Processing: """Test successful processing of v4+ pose files.""" - @patch("mouse_tracking.pose.convert.write_pose_v2_data") - @patch("mouse_tracking.pose.convert.multi_to_v2") + @patch("mouse_tracking.utils.writers.write_pose_v2_data") + @patch("mouse_tracking.utils.writers.multi_to_v2") def test_v4_file_uses_embed_id_by_default(self, mock_multi_to_v2, mock_write_v2): """Test that v4+ files use instance_embed_id by default.""" pose_data = np.random.rand(8, 3, 12, 2).astype(np.float32) @@ -289,8 +289,8 @@ def test_v4_file_uses_embed_id_by_default(self, mock_multi_to_v2, mock_write_v2) ] with ( - patch("mouse_tracking.pose.convert.os.path.isfile", return_value=True), - patch("mouse_tracking.pose.convert.h5py.File", return_value=mock_h5), + patch("mouse_tracking.utils.writers.os.path.isfile", return_value=True), + patch("mouse_tracking.utils.writers.h5py.File", return_value=mock_h5), ): downgrade_pose_file("data_pose_est_v4.h5") @@ -298,8 +298,8 @@ def test_v4_file_uses_embed_id_by_default(self, mock_multi_to_v2, mock_write_v2) args = mock_multi_to_v2.call_args[0] np.testing.assert_array_equal(args[2], embed_id) - 
@patch("mouse_tracking.pose.convert.write_pose_v2_data") - @patch("mouse_tracking.pose.convert.multi_to_v2") + @patch("mouse_tracking.utils.writers.write_pose_v2_data") + @patch("mouse_tracking.utils.writers.multi_to_v2") def test_v4_file_uses_track_id_when_disabled(self, mock_multi_to_v2, mock_write_v2): """Test that v4+ files use instance_track_id when disable_id=True.""" pose_data = np.random.rand(5, 2, 12, 2).astype(np.float32) @@ -326,8 +326,8 @@ def test_v4_file_uses_track_id_when_disabled(self, mock_multi_to_v2, mock_write_ ] with ( - patch("mouse_tracking.pose.convert.os.path.isfile", return_value=True), - patch("mouse_tracking.pose.convert.h5py.File", return_value=mock_h5), + patch("mouse_tracking.utils.writers.os.path.isfile", return_value=True), + patch("mouse_tracking.utils.writers.h5py.File", return_value=mock_h5), ): downgrade_pose_file("data_pose_est_v5.h5", disable_id=True) @@ -339,8 +339,8 @@ def test_v4_file_uses_track_id_when_disabled(self, mock_multi_to_v2, mock_write_ class TestDowngradePoseFileFilenameHandling: """Test filename pattern replacement functionality.""" - @patch("mouse_tracking.pose.convert.write_pose_v2_data") - @patch("mouse_tracking.pose.convert.multi_to_v2") + @patch("mouse_tracking.utils.writers.write_pose_v2_data") + @patch("mouse_tracking.utils.writers.multi_to_v2") def test_various_filename_patterns(self, mock_multi_to_v2, mock_write_v2): """Test that different version filename patterns are handled correctly.""" test_cases = [ @@ -356,9 +356,9 @@ def test_various_filename_patterns(self, mock_multi_to_v2, mock_write_v2): for input_file, expected_output in test_cases: with ( self._setup_basic_v3_mock(mock_multi_to_v2), - patch("mouse_tracking.pose.convert.os.path.isfile", return_value=True), + patch("mouse_tracking.utils.writers.os.path.isfile", return_value=True), patch( - "mouse_tracking.pose.convert.h5py.File", + "mouse_tracking.utils.writers.h5py.File", return_value=self.mock_h5, ), ): @@ -401,8 +401,8 @@ def 
_setup_basic_v3_mock(self, mock_multi_to_v2): class TestDowngradePoseFileEdgeCases: """Test edge cases and unusual scenarios.""" - @patch("mouse_tracking.pose.convert.write_pose_v2_data") - @patch("mouse_tracking.pose.convert.multi_to_v2") + @patch("mouse_tracking.utils.writers.write_pose_v2_data") + @patch("mouse_tracking.utils.writers.multi_to_v2") def test_empty_multi_to_v2_result(self, mock_multi_to_v2, mock_write_v2): """Test behavior when multi_to_v2 returns no animals.""" pose_data = np.zeros((5, 2, 12, 2), dtype=np.float32) @@ -424,16 +424,16 @@ def test_empty_multi_to_v2_result(self, mock_multi_to_v2, mock_write_v2): mock_multi_to_v2.return_value = [] # No animals found with ( - patch("mouse_tracking.pose.convert.os.path.isfile", return_value=True), - patch("mouse_tracking.pose.convert.h5py.File", return_value=mock_h5), + patch("mouse_tracking.utils.writers.os.path.isfile", return_value=True), + patch("mouse_tracking.utils.writers.h5py.File", return_value=mock_h5), ): downgrade_pose_file("empty_pose_est_v3.h5") # Verify no files were written mock_write_v2.assert_not_called() - @patch("mouse_tracking.pose.convert.write_pose_v2_data") - @patch("mouse_tracking.pose.convert.multi_to_v2") + @patch("mouse_tracking.utils.writers.write_pose_v2_data") + @patch("mouse_tracking.utils.writers.multi_to_v2") def test_single_animal_result(self, mock_multi_to_v2, mock_write_v2): """Test processing with only one animal in the data.""" pose_data = np.random.rand(10, 1, 12, 2).astype(np.float32) @@ -457,8 +457,8 @@ def test_single_animal_result(self, mock_multi_to_v2, mock_write_v2): ] with ( - patch("mouse_tracking.pose.convert.os.path.isfile", return_value=True), - patch("mouse_tracking.pose.convert.h5py.File", return_value=mock_h5), + patch("mouse_tracking.utils.writers.os.path.isfile", return_value=True), + patch("mouse_tracking.utils.writers.h5py.File", return_value=mock_h5), ): downgrade_pose_file("single_pose_est_v3.h5") @@ -471,8 +471,8 @@ def 
test_single_animal_result(self, mock_multi_to_v2, mock_write_v2): "single_model", ) - @patch("mouse_tracking.pose.convert.write_pose_v2_data") - @patch("mouse_tracking.pose.convert.multi_to_v2") + @patch("mouse_tracking.utils.writers.write_pose_v2_data") + @patch("mouse_tracking.utils.writers.multi_to_v2") def test_large_animal_ids(self, mock_multi_to_v2, mock_write_v2): """Test processing with large animal ID numbers.""" pose_data = np.random.rand(3, 2, 12, 2).astype(np.float32) @@ -497,8 +497,8 @@ def test_large_animal_ids(self, mock_multi_to_v2, mock_write_v2): ] with ( - patch("mouse_tracking.pose.convert.os.path.isfile", return_value=True), - patch("mouse_tracking.pose.convert.h5py.File", return_value=mock_h5), + patch("mouse_tracking.utils.writers.os.path.isfile", return_value=True), + patch("mouse_tracking.utils.writers.h5py.File", return_value=mock_h5), ): downgrade_pose_file("large_ids_pose_est_v3.h5") @@ -525,9 +525,9 @@ def test_large_animal_ids(self, mock_multi_to_v2, mock_write_v2): class TestDowngradePoseFileIntegration: """Test integration scenarios that combine multiple aspects.""" - @patch("mouse_tracking.pose.convert.write_pixel_per_cm_attr") - @patch("mouse_tracking.pose.convert.write_pose_v2_data") - @patch("mouse_tracking.pose.convert.multi_to_v2") + @patch("mouse_tracking.utils.writers.write_pixel_per_cm_attr") + @patch("mouse_tracking.utils.writers.write_pose_v2_data") + @patch("mouse_tracking.utils.writers.multi_to_v2") def test_realistic_multi_animal_v4_scenario( self, mock_multi_to_v2, mock_write_v2, mock_write_pixel ): @@ -567,8 +567,8 @@ def test_realistic_multi_animal_v4_scenario( ] with ( - patch("mouse_tracking.pose.convert.os.path.isfile", return_value=True), - patch("mouse_tracking.pose.convert.h5py.File", return_value=mock_h5), + patch("mouse_tracking.utils.writers.os.path.isfile", return_value=True), + patch("mouse_tracking.utils.writers.h5py.File", return_value=mock_h5), ): 
downgrade_pose_file("experiment_20241201_cage1_pose_est_v4.h5") @@ -599,8 +599,8 @@ def test_realistic_multi_animal_v4_scenario( args = mock_multi_to_v2.call_args[0] np.testing.assert_array_equal(args[2], embed_id) - @patch("mouse_tracking.pose.convert.write_pose_v2_data") - @patch("mouse_tracking.pose.convert.multi_to_v2") + @patch("mouse_tracking.utils.writers.write_pose_v2_data") + @patch("mouse_tracking.utils.writers.multi_to_v2") def test_v6_file_with_missing_optional_attributes( self, mock_multi_to_v2, mock_write_v2 ): @@ -641,8 +641,8 @@ def test_v6_file_with_missing_optional_attributes( ] with ( - patch("mouse_tracking.pose.convert.os.path.isfile", return_value=True), - patch("mouse_tracking.pose.convert.h5py.File", return_value=mock_h5), + patch("mouse_tracking.utils.writers.os.path.isfile", return_value=True), + patch("mouse_tracking.utils.writers.h5py.File", return_value=mock_h5), ): downgrade_pose_file("advanced_pose_est_v6.h5") diff --git a/tests/pose/inspect/test_inspect_pose_v2.py b/tests/pose/inspect/test_inspect_pose_v2.py index d6457a96..62738b83 100644 --- a/tests/pose/inspect/test_inspect_pose_v2.py +++ b/tests/pose/inspect/test_inspect_pose_v2.py @@ -32,6 +32,7 @@ def test_successful_inspection_basic( # Mock CONFIG constants mock_config.MIN_HIGH_CONFIDENCE = 0.75 mock_config.MIN_JABS_CONFIDENCE = 0.3 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 # Mock HDF5 file structure mock_file = MagicMock() @@ -45,6 +46,9 @@ def test_successful_inspection_basic( pose_quality[:100, :, :] = 0 # No confidence before frame 100 pose_quality[100:110000, :, :] = 0.8 # High confidence after frame 100 + # Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(num_frames, 1, 12, 2).astype(np.uint16) * 100 + # Mock dataset access def mock_getitem(key): if key == "poseest": @@ -53,6 +57,8 @@ def mock_getitem(key): return mock_poseest elif key == "poseest/confidence": return pose_quality + elif key == "poseest/points": + 
return pose_data else: raise KeyError(f"Key {key} not found") @@ -70,6 +76,7 @@ def mock_getitem(key): assert "pose_counts" in result assert "missing_poses" in result assert "missing_keypoint_frames" in result + assert "large_poses" in result assert result["first_frame_pose"] == 100 assert result["first_frame_full_high_conf"] == 100 @@ -93,6 +100,7 @@ def test_successful_inspection_with_detailed_calculations( # Mock CONFIG constants mock_config.MIN_HIGH_CONFIDENCE = 0.75 mock_config.MIN_JABS_CONFIDENCE = 0.3 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 mock_file = MagicMock() mock_h5py_file.return_value.__enter__.return_value = mock_file @@ -106,6 +114,9 @@ def test_successful_inspection_with_detailed_calculations( pose_quality[60:240, :, :8] = 0.4 pose_quality[80:220, :, :] = 0.8 + # Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(total_frames, 1, 12, 2).astype(np.uint16) * 100 + def mock_getitem(key): if key == "poseest": mock_poseest = MagicMock() @@ -113,6 +124,8 @@ def mock_getitem(key): return mock_poseest elif key == "poseest/confidence": return pose_quality + elif key == "poseest/points": + return pose_data mock_file.__getitem__.side_effect = mock_getitem @@ -249,6 +262,7 @@ def test_confidence_threshold_calculations( # Mock CONFIG constants mock_config.MIN_HIGH_CONFIDENCE = 0.75 mock_config.MIN_JABS_CONFIDENCE = 0.3 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 mock_file = MagicMock() mock_h5py_file.return_value.__enter__.return_value = mock_file @@ -257,10 +271,14 @@ def test_confidence_threshold_calculations( # Frame 0: No keypoints above threshold # Frame 1: Some keypoints above JABS threshold but not high confidence # Frame 2: All keypoints above high confidence threshold - pose_quality = np.zeros((100, 1, 12)) + num_frames = 100 + pose_quality = np.zeros((num_frames, 1, 12)) pose_quality[1, :, :5] = 0.4 # 5 keypoints above 0.3 pose_quality[2:, :, :] = 0.8 # All keypoints above 0.75 
+ # Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(num_frames, 1, 12, 2).astype(np.uint16) * 100 + def mock_getitem(key): if key == "poseest": mock_poseest = MagicMock() @@ -268,6 +286,8 @@ def mock_getitem(key): return mock_poseest elif key == "poseest/confidence": return pose_quality + elif key == "poseest/points": + return pose_data mock_file.__getitem__.side_effect = mock_getitem @@ -303,6 +323,7 @@ def test_pad_and_duration_calculations( # Mock CONFIG constants mock_config.MIN_HIGH_CONFIDENCE = 0.75 mock_config.MIN_JABS_CONFIDENCE = 0.3 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 mock_file = MagicMock() mock_h5py_file.return_value.__enter__.return_value = mock_file @@ -314,6 +335,9 @@ def test_pad_and_duration_calculations( 0.4 # Poses in frames 60-239, 8 keypoints > threshold ) + # Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(total_frames, 1, 12, 2).astype(np.uint16) * 100 + def mock_getitem(key): if key == "poseest": mock_poseest = MagicMock() @@ -321,6 +345,8 @@ def mock_getitem(key): return mock_poseest elif key == "poseest/confidence": return pose_quality + elif key == "poseest/points": + return pose_data mock_file.__getitem__.side_effect = mock_getitem @@ -355,17 +381,22 @@ def test_pose_counts_calculation( mock_config.MIN_HIGH_CONFIDENCE = 0.75 mock_config.MIN_JABS_CONFIDENCE = 0.3 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 mock_file = MagicMock() mock_h5py_file.return_value.__enter__.return_value = mock_file # Create specific test data - pose_quality = np.zeros((100, 1, 12)) + num_frames = 100 + pose_quality = np.zeros((num_frames, 1, 12)) # Frames 10-50: 5 keypoints above threshold # Frames 60-80: 3 keypoints above threshold pose_quality[10:50, :, :5] = 0.4 pose_quality[60:80, :, :3] = 0.5 + # Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(num_frames, 1, 12, 2).astype(np.uint16) * 100 + 
def mock_getitem(key): if key == "poseest": mock_poseest = MagicMock() @@ -373,6 +404,8 @@ def mock_getitem(key): return mock_poseest elif key == "poseest/confidence": return pose_quality + elif key == "poseest/points": + return pose_data mock_file.__getitem__.side_effect = mock_getitem mock_safe_find_first.return_value = 0 @@ -402,12 +435,17 @@ def test_empty_arrays(self, mock_config, mock_h5py_file, mock_safe_find_first): mock_config.MIN_HIGH_CONFIDENCE = 0.75 mock_config.MIN_JABS_CONFIDENCE = 0.3 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 mock_file = MagicMock() mock_h5py_file.return_value.__enter__.return_value = mock_file # Empty arrays - pose_quality = np.array([]).reshape(0, 1, 12) + num_frames = 0 + pose_quality = np.array([]).reshape(num_frames, 1, 12) + + # Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.array([]).reshape(num_frames, 1, 12, 2).astype(np.uint16) def mock_getitem(key): if key == "poseest": @@ -416,6 +454,8 @@ def mock_getitem(key): return mock_poseest elif key == "poseest/confidence": return pose_quality + elif key == "poseest/points": + return pose_data mock_file.__getitem__.side_effect = mock_getitem @@ -443,12 +483,17 @@ def test_all_zero_confidence( mock_config.MIN_HIGH_CONFIDENCE = 0.75 mock_config.MIN_JABS_CONFIDENCE = 0.3 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 mock_file = MagicMock() mock_h5py_file.return_value.__enter__.return_value = mock_file # All confidence values are zero - use enough frames for default pad+duration - pose_quality = np.zeros((110000, 1, 12)) # All zero confidence + num_frames = 110000 + pose_quality = np.zeros((num_frames, 1, 12)) # All zero confidence + + # Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(num_frames, 1, 12, 2).astype(np.uint16) * 100 def mock_getitem(key): if key == "poseest": @@ -457,6 +502,8 @@ def mock_getitem(key): return mock_poseest elif key == "poseest/confidence": return 
pose_quality + elif key == "poseest/points": + return pose_data mock_file.__getitem__.side_effect = mock_getitem @@ -488,6 +535,7 @@ def test_custom_pad_and_duration( mock_config.MIN_HIGH_CONFIDENCE = 0.75 mock_config.MIN_JABS_CONFIDENCE = 0.3 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 mock_file = MagicMock() mock_h5py_file.return_value.__enter__.return_value = mock_file @@ -496,6 +544,9 @@ def test_custom_pad_and_duration( total_frames = 60000 pose_quality = np.full((total_frames, 1, 12), 0.8) # All high confidence + # Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(total_frames, 1, 12, 2).astype(np.uint16) * 100 + def mock_getitem(key): if key == "poseest": mock_poseest = MagicMock() @@ -503,6 +554,8 @@ def mock_getitem(key): return mock_poseest elif key == "poseest/confidence": return pose_quality + elif key == "poseest/points": + return pose_data mock_file.__getitem__.side_effect = mock_getitem @@ -551,14 +604,20 @@ def test_threshold_boundary_conditions( mock_config.MIN_HIGH_CONFIDENCE = 0.75 mock_config.MIN_JABS_CONFIDENCE = threshold + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 mock_file = MagicMock() mock_h5py_file.return_value.__enter__.return_value = mock_file - # Single frame with one keypoint at specific confidence - pose_quality = np.zeros((1, 1, 12)) + # Use 10 frames to avoid squeeze creating a scalar (need at least 2 frames) + # Only first frame has keypoint at specific confidence + num_frames = 10 + pose_quality = np.zeros((num_frames, 1, 12)) pose_quality[0, 0, 0] = confidence_value + # Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(num_frames, 1, 12, 2).astype(np.uint16) * 100 + def mock_getitem(key): if key == "poseest": mock_poseest = MagicMock() @@ -566,12 +625,14 @@ def mock_getitem(key): return mock_poseest elif key == "poseest/confidence": return pose_quality + elif key == "poseest/points": + return pose_data 
mock_file.__getitem__.side_effect = mock_getitem mock_safe_find_first.return_value = 0 if expected_keypoints > 0 else -1 # Act - result = inspect_pose_v2(pose_file_path, pad=0, duration=1) + result = inspect_pose_v2(pose_file_path, pad=0, duration=10) # Assert expected_pose_counts = expected_keypoints @@ -594,12 +655,17 @@ def test_all_dependencies_called_correctly( # Mock CONFIG mock_config.MIN_HIGH_CONFIDENCE = 0.75 mock_config.MIN_JABS_CONFIDENCE = 0.3 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 # Mock HDF5 file mock_file = MagicMock() mock_h5py_file.return_value.__enter__.return_value = mock_file - pose_quality = np.full((100, 1, 12), 0.8) + num_frames = 100 + pose_quality = np.full((num_frames, 1, 12), 0.8) + + # Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(num_frames, 1, 12, 2).astype(np.uint16) * 100 def mock_getitem(key): if key == "poseest": @@ -608,6 +674,8 @@ def mock_getitem(key): return mock_poseest elif key == "poseest/confidence": return pose_quality + elif key == "poseest/points": + return pose_data mock_file.__getitem__.side_effect = mock_getitem @@ -627,6 +695,7 @@ def mock_getitem(key): "pose_counts", "missing_poses", "missing_keypoint_frames", + "large_poses", } assert set(result.keys()) == expected_keys @@ -642,12 +711,19 @@ def test_array_shape_handling( mock_config.MIN_HIGH_CONFIDENCE = 0.75 mock_config.MIN_JABS_CONFIDENCE = 0.3 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 mock_file = MagicMock() mock_h5py_file.return_value.__enter__.return_value = mock_file # v2 shape: [frames, instances, keypoints] same as v6, typically 1 instance - pose_quality = np.random.rand(1000, 1, 12) # 3D with single instance dimension + num_frames = 1000 + pose_quality = np.random.rand( + num_frames, 1, 12 + ) # 3D with single instance dimension + + # Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(num_frames, 1, 12, 2).astype(np.uint16) * 100 def 
mock_getitem(key): if key == "poseest": @@ -656,6 +732,8 @@ def mock_getitem(key): return mock_poseest elif key == "poseest/confidence": return pose_quality + elif key == "poseest/points": + return pose_data mock_file.__getitem__.side_effect = mock_getitem mock_safe_find_first.return_value = 0 diff --git a/tests/pose/inspect/test_inspect_pose_v6.py b/tests/pose/inspect/test_inspect_pose_v6.py index ff307ccf..eba01871 100644 --- a/tests/pose/inspect/test_inspect_pose_v6.py +++ b/tests/pose/inspect/test_inspect_pose_v6.py @@ -38,6 +38,9 @@ def test_successful_inspection_with_corners( mock_config.BASE_TAIL_INDEX = 9 mock_config.LEFT_REAR_PAW_INDEX = 7 mock_config.RIGHT_REAR_PAW_INDEX = 8 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 + mock_config.MID_TAIL_INDEX = 10 + mock_config.TIP_TAIL_INDEX = 11 # Mock HDF5 file structure mock_file = MagicMock() @@ -62,6 +65,9 @@ def test_successful_inspection_with_corners( seg_ids = np.zeros(num_frames, dtype=np.uint32) seg_ids[150:105000] = 1 # Segmentation starts at frame 150 + # Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(num_frames, 1, 12, 2).astype(np.uint16) * 100 + # Mock dataset access def mock_getitem(key): if key == "poseest": @@ -76,6 +82,8 @@ def mock_getitem(key): return pose_tracks elif key == "poseest/longterm_seg_id": return seg_ids + elif key == "poseest/points": + return pose_data else: raise KeyError(f"Key {key} not found") @@ -141,6 +149,9 @@ def test_successful_inspection_without_corners( mock_config.BASE_TAIL_INDEX = 9 mock_config.LEFT_REAR_PAW_INDEX = 7 mock_config.RIGHT_REAR_PAW_INDEX = 8 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 + mock_config.MID_TAIL_INDEX = 10 + mock_config.TIP_TAIL_INDEX = 11 # Mock HDF5 file structure mock_file = MagicMock() @@ -152,6 +163,9 @@ def test_successful_inspection_without_corners( pose_tracks = np.ones((1000, 1), dtype=np.uint32) seg_ids = np.ones(1000, dtype=np.uint32) + # Create pose data with 
shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(1000, 1, 12, 2).astype(np.uint16) * 100 + def mock_getitem(key): if key == "poseest": mock_poseest = MagicMock() @@ -165,6 +179,8 @@ def mock_getitem(key): return pose_tracks elif key == "poseest/longterm_seg_id": return seg_ids + elif key == "poseest/points": + return pose_data mock_file.__getitem__.side_effect = mock_getitem mock_file.__contains__.return_value = False # No corners @@ -273,6 +289,9 @@ def test_confidence_threshold_calculations( mock_config.BASE_TAIL_INDEX = 9 mock_config.LEFT_REAR_PAW_INDEX = 7 mock_config.RIGHT_REAR_PAW_INDEX = 8 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 + mock_config.MID_TAIL_INDEX = 10 + mock_config.TIP_TAIL_INDEX = 11 mock_file = MagicMock() mock_h5py_file.return_value.__enter__.return_value = mock_file @@ -290,6 +309,9 @@ def test_confidence_threshold_calculations( pose_tracks = np.ones((100, 1), dtype=np.uint32) seg_ids = np.ones(100, dtype=np.uint32) + # Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(100, 1, 12, 2).astype(np.uint16) * 100 + def mock_getitem(key): if key == "poseest": mock_poseest = MagicMock() @@ -303,6 +325,8 @@ def mock_getitem(key): return pose_tracks elif key == "poseest/longterm_seg_id": return seg_ids + elif key == "poseest/points": + return pose_data mock_file.__getitem__.side_effect = mock_getitem mock_file.__contains__.return_value = True @@ -352,6 +376,9 @@ def test_pad_and_duration_calculations( mock_config.BASE_TAIL_INDEX = 9 mock_config.LEFT_REAR_PAW_INDEX = 7 mock_config.RIGHT_REAR_PAW_INDEX = 8 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 + mock_config.MID_TAIL_INDEX = 10 + mock_config.TIP_TAIL_INDEX = 11 mock_file = MagicMock() mock_h5py_file.return_value.__enter__.return_value = mock_file @@ -367,6 +394,9 @@ def test_pad_and_duration_calculations( seg_ids = np.zeros(total_frames, dtype=np.uint32) seg_ids[70:230] = 1 # Segmentation in frames 70-229 
(160 frames) + # Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(total_frames, 1, 12, 2).astype(np.uint16) * 100 + def mock_getitem(key): if key == "poseest": mock_poseest = MagicMock() @@ -380,6 +410,8 @@ def mock_getitem(key): return pose_tracks elif key == "poseest/longterm_seg_id": return seg_ids + elif key == "poseest/points": + return pose_data mock_file.__getitem__.side_effect = mock_getitem mock_file.__contains__.return_value = False @@ -423,6 +455,9 @@ def test_tracklet_calculation( mock_config.BASE_TAIL_INDEX = 9 mock_config.LEFT_REAR_PAW_INDEX = 7 mock_config.RIGHT_REAR_PAW_INDEX = 8 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 + mock_config.MID_TAIL_INDEX = 10 + mock_config.TIP_TAIL_INDEX = 11 mock_file = MagicMock() mock_h5py_file.return_value.__enter__.return_value = mock_file @@ -442,6 +477,9 @@ def test_tracklet_calculation( pose_quality = np.full((total_frames, 1, 12), 0.8) seg_ids = np.ones(total_frames, dtype=np.uint32) + # Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(total_frames, 1, 12, 2).astype(np.uint16) * 100 + def mock_getitem(key): if key == "poseest": mock_poseest = MagicMock() @@ -455,6 +493,8 @@ def mock_getitem(key): return pose_tracks elif key == "poseest/longterm_seg_id": return seg_ids + elif key == "poseest/points": + return pose_data mock_file.__getitem__.side_effect = mock_getitem mock_file.__contains__.return_value = True @@ -516,6 +556,9 @@ def test_video_name_parsing( mock_config.BASE_TAIL_INDEX = 9 mock_config.LEFT_REAR_PAW_INDEX = 7 mock_config.RIGHT_REAR_PAW_INDEX = 8 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 + mock_config.MID_TAIL_INDEX = 10 + mock_config.TIP_TAIL_INDEX = 11 mock_file = MagicMock() mock_h5py_file.return_value.__enter__.return_value = mock_file @@ -526,6 +569,9 @@ def test_video_name_parsing( pose_tracks = np.ones((100, 1), dtype=np.uint32) seg_ids = np.ones(100, dtype=np.uint32) + # 
Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(100, 1, 12, 2).astype(np.uint16) * 100 + def mock_getitem(key): if key == "poseest": mock_poseest = MagicMock() @@ -539,6 +585,8 @@ def mock_getitem(key): return pose_tracks elif key == "poseest/longterm_seg_id": return seg_ids + elif key == "poseest/points": + return pose_data mock_file.__getitem__.side_effect = mock_getitem mock_file.__contains__.return_value = True @@ -629,6 +677,9 @@ def test_all_zero_confidence( mock_config.BASE_TAIL_INDEX = 9 mock_config.LEFT_REAR_PAW_INDEX = 7 mock_config.RIGHT_REAR_PAW_INDEX = 8 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 + mock_config.MID_TAIL_INDEX = 10 + mock_config.TIP_TAIL_INDEX = 11 mock_file = MagicMock() mock_h5py_file.return_value.__enter__.return_value = mock_file @@ -639,6 +690,9 @@ def test_all_zero_confidence( pose_tracks = np.ones((110000, 1), dtype=np.uint32) seg_ids = np.ones(110000, dtype=np.uint32) + # Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(110000, 1, 12, 2).astype(np.uint16) * 100 + def mock_getitem(key): if key == "poseest": mock_poseest = MagicMock() @@ -652,6 +706,8 @@ def mock_getitem(key): return pose_tracks elif key == "poseest/longterm_seg_id": return seg_ids + elif key == "poseest/points": + return pose_data mock_file.__getitem__.side_effect = mock_getitem mock_file.__contains__.return_value = True @@ -692,6 +748,9 @@ def test_custom_pad_and_duration( mock_config.BASE_TAIL_INDEX = 9 mock_config.LEFT_REAR_PAW_INDEX = 7 mock_config.RIGHT_REAR_PAW_INDEX = 8 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 + mock_config.MID_TAIL_INDEX = 10 + mock_config.TIP_TAIL_INDEX = 11 mock_file = MagicMock() mock_h5py_file.return_value.__enter__.return_value = mock_file @@ -703,6 +762,9 @@ def test_custom_pad_and_duration( pose_tracks = np.ones((total_frames, 1), dtype=np.uint32) seg_ids = np.ones(total_frames, dtype=np.uint32) + # Create pose data 
with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(total_frames, 1, 12, 2).astype(np.uint16) * 100 + def mock_getitem(key): if key == "poseest": mock_poseest = MagicMock() @@ -716,6 +778,8 @@ def mock_getitem(key): return pose_tracks elif key == "poseest/longterm_seg_id": return seg_ids + elif key == "poseest/points": + return pose_data mock_file.__getitem__.side_effect = mock_getitem mock_file.__contains__.return_value = True @@ -766,6 +830,9 @@ def test_all_dependencies_called_correctly( mock_config.BASE_TAIL_INDEX = 9 mock_config.LEFT_REAR_PAW_INDEX = 7 mock_config.RIGHT_REAR_PAW_INDEX = 8 + mock_config.OFA_MAX_EXPECTED_AREA_PX = 22500 # 150 * 150 + mock_config.MID_TAIL_INDEX = 10 + mock_config.TIP_TAIL_INDEX = 11 # Mock Path operations mock_path_instance = MagicMock() @@ -786,6 +853,9 @@ def test_all_dependencies_called_correctly( pose_tracks = np.ones((100, 1), dtype=np.uint32) seg_ids = np.ones(100, dtype=np.uint32) + # Create pose data with shape [frames, instances, keypoints, 2] + pose_data = np.random.rand(100, 1, 12, 2).astype(np.uint16) * 100 + def mock_getitem(key): if key == "poseest": mock_poseest = MagicMock() @@ -799,6 +869,8 @@ def mock_getitem(key): return pose_tracks elif key == "poseest/longterm_seg_id": return seg_ids + elif key == "poseest/points": + return pose_data mock_file.__getitem__.side_effect = mock_getitem mock_file.__contains__.return_value = True @@ -833,6 +905,7 @@ def mock_getitem(key): "pose_counts", "seg_counts", "missing_poses", + "large_poses", "missing_segs", "pose_tracklets", "missing_keypoint_frames", diff --git a/tests/utils/arrays/test_safe_find_first.py b/tests/utils/arrays/test_safe_find_first.py index d9276ce5..f4f82a5d 100644 --- a/tests/utils/arrays/test_safe_find_first.py +++ b/tests/utils/arrays/test_safe_find_first.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from mouse_tracking.utils.pose import safe_find_first +from mouse_tracking.utils.arrays import safe_find_first class 
TestSafeFindFirstBasicFunctionality: diff --git a/tests/utils/writers/test_promote_pose_data.py b/tests/utils/writers/test_promote_pose_data.py index f02bb491..63f26906 100644 --- a/tests/utils/writers/test_promote_pose_data.py +++ b/tests/utils/writers/test_promote_pose_data.py @@ -13,12 +13,12 @@ class TestPromotePoseDataV2ToV3: @patch("mouse_tracking.utils.writers.write_pose_v3_data") @patch("mouse_tracking.utils.writers.write_pose_v2_data") - @patch("mouse_tracking.utils.writers.convert_v2_to_v3") + @patch("mouse_tracking.utils.writers.v2_to_v3") @patch("mouse_tracking.utils.writers.h5py.File") def test_v2_to_v3_basic_promotion( self, mock_h5py_file, - mock_convert_v2_to_v3, + mock_v2_to_v3, mock_write_pose_v2_data, mock_write_pose_v3_data, ): @@ -52,7 +52,7 @@ def test_v2_to_v3_basic_promotion( instance_embedding = np.zeros((10, 1, 12), dtype=np.float32) instance_track_id = np.zeros((10, 1), dtype=np.uint32) - mock_convert_v2_to_v3.return_value = ( + mock_v2_to_v3.return_value = ( converted_pose_data, converted_conf_data, instance_count, @@ -72,8 +72,8 @@ def test_v2_to_v3_basic_promotion( expected_reshaped_conf = np.reshape(original_conf_data, [-1, 1, 12]) # Verify convert_v2_to_v3 was called with reshaped data - mock_convert_v2_to_v3.assert_called_once() - call_args = mock_convert_v2_to_v3.call_args[0] + mock_v2_to_v3.assert_called_once() + call_args = mock_v2_to_v3.call_args[0] np.testing.assert_array_equal(call_args[0], expected_reshaped_pose) np.testing.assert_array_equal(call_args[1], expected_reshaped_conf) @@ -91,12 +91,12 @@ def test_v2_to_v3_basic_promotion( @patch("mouse_tracking.utils.writers.write_pose_v3_data") @patch("mouse_tracking.utils.writers.write_pose_v2_data") - @patch("mouse_tracking.utils.writers.convert_v2_to_v3") + @patch("mouse_tracking.utils.writers.v2_to_v3") @patch("mouse_tracking.utils.writers.h5py.File") def test_v2_to_v3_missing_attributes( self, mock_h5py_file, - mock_convert_v2_to_v3, + mock_v2_to_v3, mock_write_pose_v2_data, 
mock_write_pose_v3_data, ): @@ -125,7 +125,7 @@ def test_v2_to_v3_missing_attributes( }[key] # Mock convert_v2_to_v3 return values - mock_convert_v2_to_v3.return_value = ( + mock_v2_to_v3.return_value = ( np.random.rand(5, 1, 12, 2), np.random.rand(5, 1, 12), np.ones(5, dtype=np.uint8), @@ -143,8 +143,8 @@ def test_v2_to_v3_missing_attributes( # Use assert_called_with to verify the exact arguments mock_write_pose_v2_data.assert_called_with( pose_file, - mock_convert_v2_to_v3.return_value[0], # pose_data - mock_convert_v2_to_v3.return_value[1], # conf_data + mock_v2_to_v3.return_value[0], # pose_data + mock_v2_to_v3.return_value[1], # conf_data "unknown", # config_str "unknown", # model_str ) @@ -152,12 +152,12 @@ def test_v2_to_v3_missing_attributes( @patch("mouse_tracking.utils.writers.write_pose_v4_data") @patch("mouse_tracking.utils.writers.write_pose_v3_data") @patch("mouse_tracking.utils.writers.write_pose_v2_data") - @patch("mouse_tracking.utils.writers.convert_v2_to_v3") + @patch("mouse_tracking.utils.writers.v2_to_v3") @patch("mouse_tracking.utils.writers.h5py.File") def test_v2_to_v4_skips_v3_promotion( self, mock_h5py_file, - mock_convert_v2_to_v3, + mock_v2_to_v3, mock_write_pose_v2_data, mock_write_pose_v3_data, mock_write_pose_v4_data, @@ -211,7 +211,7 @@ def mock_file_side_effect(filename, mode): mock_h5py_file.side_effect = mock_file_side_effect - mock_convert_v2_to_v3.return_value = ( + mock_v2_to_v3.return_value = ( np.random.rand(3, 1, 12, 2), np.random.rand(3, 1, 12), np.ones(3, dtype=np.uint8), @@ -224,7 +224,7 @@ def mock_file_side_effect(filename, mode): # Assert # Should call v2 to v3 conversion functions and then v4 functions - mock_convert_v2_to_v3.assert_called_once() + mock_v2_to_v3.assert_called_once() mock_write_pose_v2_data.assert_called_once() mock_write_pose_v3_data.assert_called_once() mock_write_pose_v4_data.assert_called_once() @@ -576,7 +576,7 @@ def mock_file_side_effect(filename, mode): side_effect=mock_file_side_effect, ), 
patch( - "mouse_tracking.utils.writers.convert_v2_to_v3", + "mouse_tracking.utils.writers.v2_to_v3", return_value=( np.random.rand(3, 1, 12, 2), np.random.rand(3, 1, 12), @@ -608,12 +608,12 @@ class TestPromotePoseDataIntegration: @patch("mouse_tracking.utils.writers.write_pose_v4_data") @patch("mouse_tracking.utils.writers.write_pose_v3_data") @patch("mouse_tracking.utils.writers.write_pose_v2_data") - @patch("mouse_tracking.utils.writers.convert_v2_to_v3") + @patch("mouse_tracking.utils.writers.v2_to_v3") @patch("mouse_tracking.utils.writers.h5py.File") def test_full_v2_to_v6_promotion( self, mock_h5py_file, - mock_convert_v2_to_v3, + mock_v2_to_v3, mock_write_pose_v2_data, mock_write_pose_v3_data, mock_write_pose_v4_data, @@ -665,7 +665,7 @@ def mock_file_side_effect(filename, mode): mock_h5py_file.side_effect = mock_file_side_effect # Mock convert function - mock_convert_v2_to_v3.return_value = ( + mock_v2_to_v3.return_value = ( np.random.rand(5, 1, 12, 2), np.random.rand(5, 1, 12), np.ones(5, dtype=np.uint8), diff --git a/uv.lock b/uv.lock index 06dd62a5..d3ef13dc 100644 --- a/uv.lock +++ b/uv.lock @@ -453,6 +453,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, ] +[[package]] +name = "mizani" +version = "0.9.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "matplotlib" }, + { name = "numpy" }, + { name = "pandas" }, + { name = "scipy" }, + { name = "tzdata", marker = "sys_platform == 'emscripten' or sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/22/8c/af5ab84c8e75efbe14cbc15558b46b259a068bab4b285000ae4fcddec1b0/mizani-0.9.3.tar.gz", hash = "sha256:fb61339e9e4711850e902ca286b1ae75255f483823d891aa0515b426d56c606d", size = 161687, 
upload-time = "2023-09-01T12:30:59.481Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/95/d4e33d3f5bc9fee5512637661208b6b595bda58e9b6a66fa867137761dd7/mizani-0.9.3-py3-none-any.whl", hash = "sha256:ac5d49b913de88dc2fb28d82141e9777b97407a6971a158f758093ad5bb820a1", size = 73742, upload-time = "2023-09-01T12:30:57.724Z" }, +] + [[package]] name = "ml-dtypes" version = "0.5.3" @@ -482,6 +498,7 @@ dependencies = [ { name = "opencv-python-headless" }, { name = "pandas" }, { name = "pillow" }, + { name = "plotnine" }, { name = "pydantic" }, { name = "pydantic-settings" }, { name = "scipy" }, @@ -497,6 +514,7 @@ cpu = [ { name = "torchvision", version = "0.21.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux'" }, ] gpu = [ + { name = "nvidia-cusparselt-cu12" }, { name = "tensorflow", extra = ["and-cuda"] }, { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux'" }, { name = "torch", version = "2.6.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "sys_platform == 'linux'" }, @@ -532,9 +550,11 @@ requires-dist = [ { name = "matplotlib", specifier = "==3.7.1" }, { name = "networkx", specifier = "==3.3" }, { name = "numpy", specifier = ">=1.26.0,<2.2.0" }, + { name = "nvidia-cusparselt-cu12", marker = "extra == 'gpu'", specifier = "==0.6.3" }, { name = "opencv-python-headless", specifier = "==4.8.0.76" }, { name = "pandas", specifier = "==2.0.3" }, { name = "pillow", specifier = "==9.4.0" }, + { name = "plotnine", specifier = ">=0.12.0" }, { name = "pydantic", specifier = "==2.7.4" }, { name = "pydantic-settings", specifier = ">=2.10.1" }, { name = "scipy", specifier = "==1.11.4" }, @@ -731,6 +751,16 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/73/ef/063500c25670fbd1cbb0cd3eb7c8a061585b53adb4dd8bf3492bb49b0df3/nvidia_cusparse_cu12-12.5.10.65-py3-none-win_amd64.whl", hash = 
"sha256:9e487468a22a1eaf1fbd1d2035936a905feb79c4ce5c2f67626764ee4f90227c", size = 362504719, upload-time = "2025-06-05T20:15:17.947Z" }, ] +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/62/da/4de092c61c6dea1fc9c936e69308a02531d122e12f1f649825934ad651b5/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8371549623ba601a06322af2133c4a44350575f5a3108fb75f3ef20b822ad5f1", size = 156402859, upload-time = "2024-10-16T02:23:17.184Z" }, + { url = "https://files.pythonhosted.org/packages/3b/9a/72ef35b399b0e183bc2e8f6f558036922d453c4d8237dab26c666a04244b/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46", size = 156785796, upload-time = "2024-10-15T21:29:17.709Z" }, + { url = "https://files.pythonhosted.org/packages/46/3e/9e1e394a02a06f694be2c97bbe47288bb7c90ea84c7e9cf88f7b28afe165/nvidia_cusparselt_cu12-0.6.3-py3-none-win_amd64.whl", hash = "sha256:3b325bcbd9b754ba43df5a311488fca11a6b5dc3d11df4d190c000cf1a0765c7", size = 155595972, upload-time = "2024-10-15T22:58:35.426Z" }, +] + [[package]] name = "nvidia-nccl-cu12" version = "2.27.7" @@ -830,6 +860,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ed/30/b97456e7063edac0e5a405128065f0cd2033adfe3716fb2256c186bd41d0/pandas-2.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:9ee1a69328d5c36c98d8e74db06f4ad518a1840e8ccb94a4ba86920986bb617e", size = 10664333, upload-time = "2023-06-28T23:16:39.209Z" }, ] +[[package]] +name = "patsy" +version = "1.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d1/81/74f6a65b848ffd16c18f920620ce999fe45fe27f01ab3911260ce4ed85e4/patsy-1.0.1.tar.gz", hash = 
"sha256:e786a9391eec818c054e359b737bbce692f051aee4c661f4141cc88fb459c0c4", size = 396010, upload-time = "2024-11-12T14:10:54.642Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/2b/b50d3d08ea0fc419c183a84210571eba005328efa62b6b98bc28e9ead32a/patsy-1.0.1-py2.py3-none-any.whl", hash = "sha256:751fb38f9e97e62312e921a1954b81e1bb2bcda4f5eeabaf94db251ee791509c", size = 232923, upload-time = "2024-11-12T14:10:52.85Z" }, +] + [[package]] name = "pillow" version = "9.4.0" @@ -851,6 +893,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5e/7c/293136a5171800001be33c21a51daaca68fae954b543e2c015a6bb81a716/Pillow-9.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:f6e78171be3fb7941f9910ea15b4b14ec27725865a73c15277bc39f5ca4f8391", size = 2475100, upload-time = "2023-01-02T02:52:51.402Z" }, ] +[[package]] +name = "plotnine" +version = "0.12.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "matplotlib" }, + { name = "mizani" }, + { name = "numpy" }, + { name = "pandas" }, + { name = "patsy" }, + { name = "scipy" }, + { name = "statsmodels" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1e/17/9d5225e607a89abee9f725e4c5104e5528807892c854faa5ddd6b083d0bd/plotnine-0.12.4.tar.gz", hash = "sha256:adc41a672503594445a8fa19872799253bd0784cdbd5a1cc16657a1dd20ba905", size = 5765782, upload-time = "2023-11-06T11:02:26.737Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5b/b5/fb81914804ad0d8e4a53118df343efdba1562de13275189cf2228ef8e3c1/plotnine-0.12.4-py3-none-any.whl", hash = "sha256:12748f346f107c33f3e0658ac46fbb052205ae7e97ffaf52be68310e5d29f799", size = 1266611, upload-time = "2023-11-06T11:02:10.994Z" }, +] + [[package]] name = "pluggy" version = "1.6.0" @@ -1166,6 +1226,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/5a/e7c31adbe875f2abbb91bd84cf2dc52d792b5a01506781dbcf25c91daf11/six-1.16.0-py2.py3-none-any.whl", hash = 
"sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254", size = 11053, upload-time = "2021-05-05T14:18:17.237Z" }, ] +[[package]] +name = "statsmodels" +version = "0.14.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "packaging" }, + { name = "pandas" }, + { name = "patsy" }, + { name = "scipy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/64/cc/8c1bf59bf8203dea1bf2ea811cfe667d7bcc6909c83d8afb02b08e30f50b/statsmodels-0.14.5.tar.gz", hash = "sha256:de260e58cccfd2ceddf835b55a357233d6ca853a1aa4f90f7553a52cc71c6ddf", size = 20525016, upload-time = "2025-07-07T12:14:23.195Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3a/2c/55b2a5d10c1a211ecab3f792021d2581bbe1c5ca0a1059f6715dddc6899d/statsmodels-0.14.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9fc2b5cdc0c95cba894849651fec1fa1511d365e3eb72b0cc75caac44077cd48", size = 10058241, upload-time = "2025-07-07T12:13:16.286Z" }, + { url = "https://files.pythonhosted.org/packages/66/d9/6967475805de06691e951072d05e40e3f1c71b6221bb92401193ee19bd2a/statsmodels-0.14.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b8d96b0bbaeabd3a557c35cc7249baa9cfbc6dd305c32a9f2cbdd7f46c037e7f", size = 9734017, upload-time = "2025-07-07T12:05:08.498Z" }, + { url = "https://files.pythonhosted.org/packages/df/a8/803c280419a7312e2472969fe72cf461c1210a27770a662cbe3b5cd7c6fe/statsmodels-0.14.5-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:145bc39b2cb201efb6c83cc3f2163c269e63b0d4809801853dec6f440bd3bc37", size = 10459677, upload-time = "2025-07-07T14:21:51.809Z" }, + { url = "https://files.pythonhosted.org/packages/a1/25/edf20acbd670934b02cd9344e29c9a03ce040122324b3491bb075ae76b2d/statsmodels-0.14.5-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d7c14fb2617bb819fb2532e1424e1da2b98a3419a80e95f33365a72d437d474e", size = 
10678631, upload-time = "2025-07-07T14:22:05.496Z" }, + { url = "https://files.pythonhosted.org/packages/64/22/8b1e38310272e766abd6093607000a81827420a3348f09eff08a9e54cbaf/statsmodels-0.14.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:1e9742d8a5ac38a3bfc4b7f4b0681903920f20cbbf466d72b1fd642033846108", size = 10699273, upload-time = "2025-07-07T14:22:19.487Z" }, + { url = "https://files.pythonhosted.org/packages/d1/6f/6de51f1077b7cef34611f1d6721392ea170153251b4d977efcf6d100f779/statsmodels-0.14.5-cp310-cp310-win_amd64.whl", hash = "sha256:1cab9e6fce97caf4239cdb2df375806937da5d0b7ba2699b13af33a07f438464", size = 9644785, upload-time = "2025-07-07T12:05:20.927Z" }, +] + [[package]] name = "sympy" version = "1.13.1"