diff --git a/src/jabs/project/project.py b/src/jabs/project/project.py
index dfc9eb4d..fd361414 100644
--- a/src/jabs/project/project.py
+++ b/src/jabs/project/project.py
@@ -621,21 +621,26 @@ def save_predictions(
         prediction_labels = np.full(
             (pose_est.num_identities, pose_est.num_frames), -1, dtype=np.int8
         )
+        # Probability shape is determined by mode, not by sniffing a sample
+        # array: multi-class predictions (class_names provided) store one column
+        # per class, binary predictions store a scalar per frame. Deciding from
+        # class_names keeps the shape correct even when `probabilities` is empty
+        # (e.g. a video with no identities to classify).
         prediction_prob: np.ndarray
-        if probabilities:
-            sample_prob = next(iter(probabilities.values()))
-            if sample_prob.ndim == 1:
-                prediction_prob = np.zeros_like(prediction_labels, dtype=np.float32)
-            elif sample_prob.ndim == 2:
-                prediction_prob = np.zeros(
-                    (pose_est.num_identities, pose_est.num_frames, sample_prob.shape[1]),
-                    dtype=np.float32,
-                )
-            else:
-                raise ValueError(f"Unsupported probability shape: {sample_prob.shape}")
+        if class_names is not None:
+            prediction_prob = np.zeros(
+                (pose_est.num_identities, pose_est.num_frames, len(class_names)),
+                dtype=np.float32,
+            )
         else:
             prediction_prob = np.zeros_like(prediction_labels, dtype=np.float32)
 
+        # Expected per-identity probability shape; checked explicitly before
+        # assignment because allocating from class_names (rather than from the
+        # array itself) means a mis-shaped input could otherwise broadcast
+        # silently (e.g. (n_frames, 1) duplicated across classes).
+        expected_prob_shape = prediction_prob.shape[1:]
+
         if postprocessed_predictions:
             postprocessed_labels = np.full(
                 (pose_est.num_identities, pose_est.num_frames), -1, dtype=np.int8
@@ -645,8 +650,14 @@ def save_predictions(
 
         # stack the numpy arrays
         for identity in predictions:
+            identity_prob = probabilities[identity]
+            if identity_prob.shape != expected_prob_shape:
+                raise ValueError(
+                    f"probability array for identity {identity} has shape "
+                    f"{identity_prob.shape}, expected {expected_prob_shape}"
+                )
             prediction_labels[identity] = predictions[identity]
-            prediction_prob[identity] = probabilities[identity]
+            prediction_prob[identity] = identity_prob
             if postprocessed_predictions:
                 postprocessed_labels[identity] = postprocessed_predictions[identity]
 
diff --git a/tests/project/test_project.py b/tests/project/test_project.py
index 91c9e790..25ec5be4 100644
--- a/tests/project/test_project.py
+++ b/tests/project/test_project.py
@@ -530,6 +530,139 @@ def _conflicting_labels(name: str) -> VideoLabels:
     assert project.get_overlapping_behavior_label_videos() == ["mmm.avi", "zzz.avi"]
 
 
+# ---------------------------------------------------------------------------
+# save_predictions probability-array shape
+# ---------------------------------------------------------------------------
+
+
+def _prediction_test_pose(num_identities: int, num_frames: int):
+    """Minimal pose stand-in exposing the attributes save_predictions reads."""
+    return type(
+        "PoseEstimation",
+        (object,),
+        {
+            "num_identities": num_identities,
+            "num_frames": num_frames,
+            "pose_file": "video1_pose_est_v6.h5",
+            "hash": "posehash",
+            "identity_to_track": None,
+            "external_identities": None,
+        },
+    )()
+
+
+def _prediction_test_classifier():
+    """Minimal classifier stand-in exposing the metadata write_predictions reads."""
+    return type(
+        "Classifier",
+        (object,),
+        {"classifier_file": "_multiclass.pickle", "classifier_hash": "clshash"},
+    )()
+
+
+def _bare_project(tmp_path: Path) -> Project:
+    return Project(
+        tmp_path,
+        enable_video_check=False,
+        enable_session_tracker=False,
+        validate_project_dir=False,
+    )
+
+
+def test_save_predictions_multiclass_empty_allocates_per_class_shape(tmp_path: Path) -> None:
+    """An empty multi-class video still allocates a (n_id, n_frames, n_classes) prob array.
+
+    Shape is derived from class_names, not from sniffing a sample probability
+    array, so an empty ``probabilities`` dict no longer falls through to a binary
+    2D shape (which would then fail BehaviorPrediction's shape validation).
+    """
+    project = _bare_project(tmp_path)
+    class_names = [MULTICLASS_NONE_BEHAVIOR, "Walk", "Run"]
+
+    project.save_predictions(
+        _prediction_test_pose(2, 5),
+        "video1.avi",
+        {},
+        {},
+        MULTICLASS_PREDICTION_KEY,
+        _prediction_test_classifier(),
+        class_names=class_names,
+    )
+
+    safe = to_safe_name(MULTICLASS_PREDICTION_KEY)
+    with h5py.File(project.project_paths.prediction_dir / "video1.h5", "r") as hf:
+        assert hf[f"predictions/{safe}/probabilities"].shape == (2, 5, 3)
+
+
+def test_save_predictions_multiclass_writes_per_class_probabilities(tmp_path: Path) -> None:
+    """Non-empty multi-class predictions write the per-identity per-class matrices."""
+    project = _bare_project(tmp_path)
+    class_names = [MULTICLASS_NONE_BEHAVIOR, "Walk", "Run"]
+    predictions = {0: np.array([1, 0, 2], dtype=np.int8)}
+    probabilities = {
+        0: np.array([[0.1, 0.8, 0.1], [0.7, 0.2, 0.1], [0.2, 0.3, 0.5]], dtype=np.float32)
+    }
+
+    project.save_predictions(
+        _prediction_test_pose(1, 3),
+        "video1.avi",
+        predictions,
+        probabilities,
+        MULTICLASS_PREDICTION_KEY,
+        _prediction_test_classifier(),
+        class_names=class_names,
+    )
+
+    safe = to_safe_name(MULTICLASS_PREDICTION_KEY)
+    with h5py.File(project.project_paths.prediction_dir / "video1.h5", "r") as hf:
+        probs = hf[f"predictions/{safe}/probabilities"][()]
+        assert probs.shape == (1, 3, 3)
+        np.testing.assert_allclose(probs[0], probabilities[0])
+
+
+def test_save_predictions_multiclass_rejects_misshaped_probabilities(tmp_path: Path) -> None:
+    """A per-identity probability array that doesn't match (n_frames, n_classes) raises.
+
+    Guards against a mis-shaped array (e.g. 1D or (n_frames, 1)) broadcasting
+    silently into the class-sized allocation instead of failing.
+    """
+    project = _bare_project(tmp_path)
+    class_names = [MULTICLASS_NONE_BEHAVIOR, "Walk", "Run"]
+    predictions = {0: np.array([1, 0, 2], dtype=np.int8)}
+    # 1D input would broadcast silently into the (3 frames, 3 classes) slot without a check
+    probabilities = {0: np.array([0.9, 0.4, 0.8], dtype=np.float32)}
+
+    with pytest.raises(ValueError, match="probability array for identity 0"):
+        project.save_predictions(
+            _prediction_test_pose(1, 3),
+            "video1.avi",
+            predictions,
+            probabilities,
+            MULTICLASS_PREDICTION_KEY,
+            _prediction_test_classifier(),
+            class_names=class_names,
+        )
+
+
+def test_save_predictions_binary_allocates_scalar_shape(tmp_path: Path) -> None:
+    """Binary predictions (no class_names) allocate a 2D (n_id, n_frames) prob array."""
+    project = _bare_project(tmp_path)
+    predictions = {0: np.array([1, 0, 1], dtype=np.int8)}
+    probabilities = {0: np.array([0.9, 0.4, 0.8], dtype=np.float32)}
+
+    project.save_predictions(
+        _prediction_test_pose(1, 3),
+        "video1.avi",
+        predictions,
+        probabilities,
+        "Walk",
+        _prediction_test_classifier(),
+    )
+
+    with h5py.File(project.project_paths.prediction_dir / "video1.h5", "r") as hf:
+        assert hf["predictions/Walk/probabilities"].shape == (1, 3)
+
+
 def test_rename_behavior_multiclass_updates_classifier_and_predictions(tmp_path: Path) -> None:
     """In multi-class mode, rename updates the shared classifier and prediction class_names.
 