Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions src/jabs/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,21 +621,26 @@ def save_predictions(
prediction_labels = np.full(
(pose_est.num_identities, pose_est.num_frames), -1, dtype=np.int8
)
# Probability shape is determined by mode, not by sniffing a sample
# array: multi-class predictions (class_names provided) store one column
# per class, binary predictions store a scalar per frame. Deciding from
# class_names keeps the shape correct even when `probabilities` is empty
# (e.g. a video with no identities to classify).
prediction_prob: np.ndarray
if probabilities:
sample_prob = next(iter(probabilities.values()))
if sample_prob.ndim == 1:
prediction_prob = np.zeros_like(prediction_labels, dtype=np.float32)
elif sample_prob.ndim == 2:
prediction_prob = np.zeros(
(pose_est.num_identities, pose_est.num_frames, sample_prob.shape[1]),
dtype=np.float32,
)
else:
raise ValueError(f"Unsupported probability shape: {sample_prob.shape}")
if class_names is not None:
prediction_prob = np.zeros(
(pose_est.num_identities, pose_est.num_frames, len(class_names)),
dtype=np.float32,
)
else:
prediction_prob = np.zeros_like(prediction_labels, dtype=np.float32)
Comment thread
gbeane marked this conversation as resolved.

# Expected per-identity probability shape; checked explicitly before
# assignment because allocating from class_names (rather than from the
# array itself) means a mis-shaped input could otherwise broadcast
# silently (e.g. (n_frames, 1) duplicated across classes).
expected_prob_shape = prediction_prob.shape[1:]

if postprocessed_predictions:
postprocessed_labels = np.full(
(pose_est.num_identities, pose_est.num_frames), -1, dtype=np.int8
Expand All @@ -645,8 +650,14 @@ def save_predictions(

# stack the numpy arrays
for identity in predictions:
identity_prob = probabilities[identity]
if identity_prob.shape != expected_prob_shape:
raise ValueError(
f"probability array for identity {identity} has shape "
f"{identity_prob.shape}, expected {expected_prob_shape}"
)
prediction_labels[identity] = predictions[identity]
prediction_prob[identity] = probabilities[identity]
prediction_prob[identity] = identity_prob
if postprocessed_predictions:
postprocessed_labels[identity] = postprocessed_predictions[identity]

Expand Down
133 changes: 133 additions & 0 deletions tests/project/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,139 @@ def _conflicting_labels(name: str) -> VideoLabels:
assert project.get_overlapping_behavior_label_videos() == ["mmm.avi", "zzz.avi"]


# ---------------------------------------------------------------------------
# save_predictions probability-array shape
# ---------------------------------------------------------------------------


def _prediction_test_pose(num_identities: int, num_frames: int):
"""Minimal pose stand-in exposing the attributes save_predictions reads."""
return type(
"PoseEstimation",
(object,),
{
"num_identities": num_identities,
"num_frames": num_frames,
"pose_file": "video1_pose_est_v6.h5",
"hash": "posehash",
"identity_to_track": None,
"external_identities": None,
},
)()


def _prediction_test_classifier():
"""Minimal classifier stand-in exposing the metadata write_predictions reads."""
return type(
"Classifier",
(object,),
{"classifier_file": "_multiclass.pickle", "classifier_hash": "clshash"},
)()


def _bare_project(tmp_path: Path) -> Project:
return Project(
tmp_path,
enable_video_check=False,
enable_session_tracker=False,
validate_project_dir=False,
)


def test_save_predictions_multiclass_empty_allocates_per_class_shape(tmp_path: Path) -> None:
"""An empty multi-class video still allocates a (n_id, n_frames, n_classes) prob array.

Shape is derived from class_names, not from sniffing a sample probability
array, so an empty ``probabilities`` dict no longer falls through to a binary
2D shape (which would then fail BehaviorPrediction's shape validation).
"""
project = _bare_project(tmp_path)
class_names = [MULTICLASS_NONE_BEHAVIOR, "Walk", "Run"]

project.save_predictions(
_prediction_test_pose(2, 5),
"video1.avi",
{},
{},
MULTICLASS_PREDICTION_KEY,
_prediction_test_classifier(),
class_names=class_names,
)

safe = to_safe_name(MULTICLASS_PREDICTION_KEY)
with h5py.File(project.project_paths.prediction_dir / "video1.h5", "r") as hf:
assert hf[f"predictions/{safe}/probabilities"].shape == (2, 5, 3)


def test_save_predictions_multiclass_writes_per_class_probabilities(tmp_path: Path) -> None:
"""Non-empty multi-class predictions write the per-identity per-class matrices."""
project = _bare_project(tmp_path)
class_names = [MULTICLASS_NONE_BEHAVIOR, "Walk", "Run"]
predictions = {0: np.array([1, 0, 2], dtype=np.int8)}
probabilities = {
0: np.array([[0.1, 0.8, 0.1], [0.7, 0.2, 0.1], [0.2, 0.3, 0.5]], dtype=np.float32)
}

project.save_predictions(
_prediction_test_pose(1, 3),
"video1.avi",
predictions,
probabilities,
MULTICLASS_PREDICTION_KEY,
_prediction_test_classifier(),
class_names=class_names,
)

safe = to_safe_name(MULTICLASS_PREDICTION_KEY)
with h5py.File(project.project_paths.prediction_dir / "video1.h5", "r") as hf:
probs = hf[f"predictions/{safe}/probabilities"][()]
assert probs.shape == (1, 3, 3)
np.testing.assert_allclose(probs[0], probabilities[0])


def test_save_predictions_multiclass_rejects_misshaped_probabilities(tmp_path: Path) -> None:
"""A per-identity probability array that doesn't match (n_frames, n_classes) raises.

Guards against a mis-shaped array (e.g. 1D or (n_frames, 1)) broadcasting
silently into the class-sized allocation instead of failing.
"""
project = _bare_project(tmp_path)
class_names = [MULTICLASS_NONE_BEHAVIOR, "Walk", "Run"]
predictions = {0: np.array([1, 0, 2], dtype=np.int8)}
# 1D probabilities would broadcast across the 3 class columns without a check
probabilities = {0: np.array([0.9, 0.4, 0.8], dtype=np.float32)}

with pytest.raises(ValueError, match="probability array for identity 0"):
project.save_predictions(
_prediction_test_pose(1, 3),
"video1.avi",
predictions,
probabilities,
MULTICLASS_PREDICTION_KEY,
_prediction_test_classifier(),
class_names=class_names,
)


def test_save_predictions_binary_allocates_scalar_shape(tmp_path: Path) -> None:
"""Binary predictions (no class_names) allocate a 2D (n_id, n_frames) prob array."""
project = _bare_project(tmp_path)
predictions = {0: np.array([1, 0, 1], dtype=np.int8)}
probabilities = {0: np.array([0.9, 0.4, 0.8], dtype=np.float32)}

project.save_predictions(
_prediction_test_pose(1, 3),
"video1.avi",
predictions,
probabilities,
"Walk",
_prediction_test_classifier(),
)

with h5py.File(project.project_paths.prediction_dir / "video1.h5", "r") as hf:
assert hf["predictions/Walk/probabilities"].shape == (1, 3)


def test_rename_behavior_multiclass_updates_classifier_and_predictions(tmp_path: Path) -> None:
"""In multi-class mode, rename updates the shared classifier and prediction class_names.

Expand Down
Loading