diff --git a/src/trackers/core/mcbyte/__init__.py b/src/trackers/core/mcbyte/__init__.py new file mode 100644 index 00000000..207d6459 --- /dev/null +++ b/src/trackers/core/mcbyte/__init__.py @@ -0,0 +1,8 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +from .tracker import McByteTracker + +__all__ = ["McByteTracker"] diff --git a/src/trackers/core/mcbyte/tracker.py b/src/trackers/core/mcbyte/tracker.py new file mode 100644 index 00000000..9edecd9e --- /dev/null +++ b/src/trackers/core/mcbyte/tracker.py @@ -0,0 +1,343 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +from typing import cast + +import numpy as np +import supervision as sv +from deprecate import deprecated +from scipy.optimize import linear_sum_assignment + +from trackers.core.base import BaseTracker +from trackers.core.mcbyte.tracklet import McByteTracklet +from trackers.core.mcbyte.utils import _fuse_score, get_alive_tracklets +from trackers.utils.cmc import CMC, CMCConfig, CMCMethod +from trackers.utils.detections import default_confidences +from trackers.utils.iou import BaseIoU, IoU +from trackers.utils.state_representations import ( + BaseStateEstimator, + XCYCWHStateEstimator, +) + + +class McByteTracker(BaseTracker): + tracker_id = "mcbyte" + + def __init__( + self, + lost_track_buffer: int = 30, + frame_rate: float = 30.0, + track_activation_threshold: float = 0.7, + minimum_consecutive_frames: int = 2, + minimum_iou_threshold_first_assoc: float = 0.2, + minimum_iou_threshold_second_assoc: float = 0.5, + minimum_iou_threshold_unconfirmed_assoc: float = 0.3, + high_conf_det_threshold: float = 0.6, + enable_cmc: bool = True, + cmc_method: CMCMethod = "sparseOptFlow", + cmc_downscale: int = 2, + instant_first_frame_activation: bool = True, + state_estimator_class: type[BaseStateEstimator] = XCYCWHStateEstimator, + iou: BaseIoU | None = None, + ) -> None: + # Calculate maximum frames without update based on lost_track_buffer and + # frame_rate. This scales the buffer based on the frame rate to ensure + # consistent time-based tracking across different frame rates. + self.maximum_frames_without_update = int(frame_rate / 30.0 * lost_track_buffer) + self.minimum_consecutive_frames = minimum_consecutive_frames + self.minimum_iou_threshold_first_assoc = minimum_iou_threshold_first_assoc + self.minimum_iou_threshold_second_assoc = minimum_iou_threshold_second_assoc + self.minimum_iou_threshold_unconfirmed_assoc = minimum_iou_threshold_unconfirmed_assoc + self.track_activation_threshold = track_activation_threshold + self.high_conf_det_threshold = high_conf_det_threshold + self.instant_first_frame_activation = instant_first_frame_activation + self.tracks: list[McByteTracklet] = [] + self.state_estimator_class = state_estimator_class + self.iou = iou if iou is not None else IoU() + self.frame_id: int = 0 + + self.enable_cmc = enable_cmc + self.cmc = CMC(CMCConfig(method=cmc_method, downscale=cmc_downscale)) if enable_cmc else None + + def update( + self, + detections: sv.Detections, + frame: np.ndarray | None = None, + ) -> sv.Detections: + """ + Update the tracker with detections from the current frame. + + This is the main per-frame entry point. + + Args: + detections: Supervision detections for the current frame. Must include + ``.xyxy``. Confidence (`detections.confidence`) is optional but + recommended. This method does not mutate the input detections; + it returns a new ``sv.Detections`` with ``tracker_id`` assigned. + + Returns: + New sv.Detections with tracker_id assigned for each detection. + Confirmed tracks have tracker_id >= 0; unconfirmed tracks have + tracker_id of -1. + + Notes: + - If CMC is enabled, pass the current video frame via ``frame`` so the + tracker can estimate a global affine transform and warp predicted + track states before association. + """ + self.frame_id += 1 + + if len(self.tracks) == 0 and len(detections) == 0: + result = sv.Detections.empty() + result.tracker_id = np.array([], dtype=int) + return result + + out_det_indices: list[int] = [] + out_tracker_ids: list[int] = [] + + # Predict new locations for existing tracks + for tracker in self.tracks: + tracker.predict() + + detection_boxes = detections.xyxy + confidences = default_confidences(detections) + + # Split indices into high / low / discarded by confidence + high_mask = confidences >= self.high_conf_det_threshold + low_mask = (confidences > 0.1) & (~high_mask) + + high_indices = np.where(high_mask)[0] + low_indices = np.where(low_mask)[0] + + high_boxes = detection_boxes[high_indices] + low_boxes = detection_boxes[low_indices] + high_scores = confidences[high_indices] + + # Split tracks into confirmed, unconfirmed, and lost. + # After predict(), time_since_update == 1 means the track was matched in + # the previous frame ("tracked"), while time_since_update > 1 means the + # track has been unmatched for multiple frames ("lost"). + confirmed_tracks: list[McByteTracklet] = [] + unconfirmed_tracks: list[McByteTracklet] = [] + lost_tracks: list[McByteTracklet] = [] + for track in self.tracks: + if track.time_since_update > 1: + lost_tracks.append(track) + elif track.number_of_successful_updates >= self.minimum_consecutive_frames: + confirmed_tracks.append(track) + else: + unconfirmed_tracks.append(track) + + # CMC: apply to all predicted tracks before association + if self.enable_cmc and self.cmc is not None and frame is not None: + mask_boxes = high_boxes if len(high_boxes) > 0 else None + H = self.cmc.estimate(frame, mask_boxes) + CMC.apply_batch(H, self.tracks) + # Step 1: associate high-confidence detections to confirmed + lost tracks. + # Lost tracks are included here (following the original ByteTrack), and + # IoU is fused with detection scores. + strack_pool = confirmed_tracks + lost_tracks + iou_matrix = self._get_iou_matrix(strack_pool, high_boxes) + iou_matrix = _fuse_score(self.iou.normalize_for_fusion(iou_matrix), high_scores) + matched, unmatched_pool, unmatched_high = self._get_associated_indices( + iou_matrix, self.minimum_iou_threshold_first_assoc + ) + + for row, col in matched: + track = strack_pool[row] + track.update(high_boxes[col]) + if track.number_of_successful_updates >= self.minimum_consecutive_frames and track.tracker_id == -1: + track.tracker_id = McByteTracklet.get_next_tracker_id() + out_det_indices.append(int(high_indices[col])) + out_tracker_ids.append(track.tracker_id) + + # Step 2: associate low-confidence detections to remaining *tracked* tracks + # only (excluding lost tracks, following the original ByteTrack). + # No score fusing in second association. + remaining_tracked = [strack_pool[i] for i in unmatched_pool if strack_pool[i].time_since_update == 1] + iou_matrix = self._get_iou_matrix(remaining_tracked, low_boxes) + matched, _, unmatched_low = self._get_associated_indices(iou_matrix, self.minimum_iou_threshold_second_assoc) + + for row, col in matched: + track = remaining_tracked[row] + track.update(low_boxes[col]) + if track.number_of_successful_updates >= self.minimum_consecutive_frames and track.tracker_id == -1: + track.tracker_id = McByteTracklet.get_next_tracker_id() + out_det_indices.append(int(low_indices[col])) + out_tracker_ids.append(track.tracker_id) + + # Unmatched low-confidence detections + for det_local_idx in sorted(unmatched_low): + out_det_indices.append(int(low_indices[det_local_idx])) + out_tracker_ids.append(-1) + + # Step 3: match unconfirmed tracks with remaining unmatched high-confidence + # detections (with score fusing, following the original ByteTrack). + # Unmatched unconfirmed tracks are removed (not kept as lost). + unmatched_high_list = sorted(unmatched_high) + unmatched_uc_indices: list[int] = list(range(len(unconfirmed_tracks))) + + if len(unconfirmed_tracks) > 0 and len(unmatched_high_list) > 0: + uh_boxes = high_boxes[unmatched_high_list] + uh_scores = high_scores[unmatched_high_list] + + iou_matrix = self._get_iou_matrix(unconfirmed_tracks, uh_boxes) + iou_matrix = _fuse_score(self.iou.normalize_for_fusion(iou_matrix), uh_scores) + matched_uc, unmatched_uc_indices, remaining_uh = self._get_associated_indices( + iou_matrix, self.minimum_iou_threshold_unconfirmed_assoc + ) + + for row, col in matched_uc: + track = unconfirmed_tracks[row] + orig_high_idx = unmatched_high_list[col] + track.update(high_boxes[orig_high_idx]) + if track.number_of_successful_updates >= self.minimum_consecutive_frames and track.tracker_id == -1: + track.tracker_id = McByteTracklet.get_next_tracker_id() + out_det_indices.append(int(high_indices[orig_high_idx])) + out_tracker_ids.append(track.tracker_id) + + # Only remaining unmatched high-conf dets proceed to spawning + unmatched_high = [unmatched_high_list[i] for i in remaining_uh] + + # Remove unmatched unconfirmed tracks (following original ByteTrack, + # which marks them as removed rather than keeping them as lost). + if len(unmatched_uc_indices) > 0: + remove_ids = {id(unconfirmed_tracks[i]) for i in unmatched_uc_indices} + self.tracks = [t for t in self.tracks if id(t) not in remove_ids] + + # Spawn new tracks from unmatched high-confidence detections + self._spawn_new_tracks( + detection_boxes, + confidences, + unmatched_high, + high_indices, + out_det_indices, + out_tracker_ids, + is_first_frame=(self.frame_id == 1), + ) + + # Kill lost tracks + self.tracks = get_alive_tracklets( + tracklets=self.tracks, + maximum_frames_without_update=self.maximum_frames_without_update, + minimum_consecutive_frames=self.minimum_consecutive_frames, + ) + + # Build final detections + if not out_det_indices: + result = sv.Detections.empty() + result.tracker_id = np.array([], dtype=int) + return result + + idx = np.array(out_det_indices) + result = cast(sv.Detections, detections[idx]) + result.tracker_id = np.array(out_tracker_ids, dtype=int) + return result + + def _get_iou_matrix(self, tracklets: list[McByteTracklet], detections: np.ndarray) -> np.ndarray: + if len(tracklets) == 0: + tracklet_boxes = np.empty((0, 4)) + else: + tracklet_boxes = np.array([tracklet.get_state_bbox() for tracklet in tracklets]) + return self.iou.compute(tracklet_boxes, detections) + + def _get_associated_indices( + self, + similarity_matrix: np.ndarray, + min_similarity_thresh: float, + ) -> tuple[list[tuple[int, int]], list[int], list[int]]: + """ + Associate detections to tracks based on Similarity (IoU) using the + Jonker-Volgenant algorithm approach with no initialization instead of the + Hungarian algorithm as mentioned in the SORT paper, but it solves the + assignment problem in an optimal way. + + Args: + similarity_matrix: Similarity matrix between tracks (rows) and detections + (columns). min_similarity_thresh: Minimum similarity threshold for a valid + match. + + Returns: + matched: List of ``(tracker_idx, detection_idx)`` tuples for + associations that meet the similarity threshold. + unmatched_tracks: Sorted list of track indices not matched to any + detection. + unmatched_detections: Sorted list of detection indices not matched + to any track. + """ + matched_indices = [] + n_tracks, n_detections = similarity_matrix.shape + unmatched_tracks = set(range(n_tracks)) + unmatched_detections = set(range(n_detections)) + + if n_tracks > 0 and n_detections > 0: + row_indices, col_indices = linear_sum_assignment(similarity_matrix, maximize=True) + for row, col in zip(row_indices, col_indices): + if similarity_matrix[row, col] >= min_similarity_thresh: + matched_indices.append((row, col)) + unmatched_tracks.remove(row) + unmatched_detections.remove(col) + + # Return sorted lists for deterministic order across Python runtimes. + return matched_indices, sorted(unmatched_tracks), sorted(unmatched_detections) + + def _spawn_new_tracks( + self, + detection_boxes: np.ndarray, + confidences: np.ndarray, + unmatched_high_local: list[int], + high_indices: np.ndarray, + out_det_indices: list[int], + out_tracker_ids: list[int], + is_first_frame: bool = False, + ) -> None: + """Create new tracklets from unmatched high-confidence detections. + + On the very first frame, new tracklets are immediately activated with a + real tracker ID, following the original ByteTrack convention where + ``activate()`` sets ``is_activated = True`` only when + ``frame_id == 1``. + """ + for det_local_idx in unmatched_high_local: + global_idx = int(high_indices[det_local_idx]) + conf = float(confidences[global_idx]) + if conf >= self.track_activation_threshold: + tracklet = McByteTracklet( + initial_bbox=detection_boxes[global_idx], + state_estimator_class=self.state_estimator_class, + ) + if is_first_frame and self.instant_first_frame_activation: + tracklet.tracker_id = McByteTracklet.get_next_tracker_id() + self.tracks.append(tracklet) + out_det_indices.append(global_idx) + out_tracker_ids.append(tracklet.tracker_id) + + def reset(self) -> None: + """Reset tracker state by clearing all tracks and resetting ID counter. + Call this method when switching to a new video or scene. + """ + self.tracks = [] + self.frame_id = 0 + McByteTracklet.count_id = 0 + if self.cmc is not None: + self.cmc.reset() + + @deprecated(target=None, deprecated_in="2.5", remove_in="3.0") + def apply_cmc_batch(self, H: np.ndarray | None) -> None: + """Apply CMC to all active tracks. + + .. deprecated:: 2.5 + Use CMC.apply_batch(H, self.tracks) directly. + + Args: + H: 2x3 affine transform matrix returned by CMC.estimate(). + If None, this method is a no-op. + + Examples: + >>> tracker = McByteTracker() + >>> tracker.apply_cmc_batch(None) # no-op + """ + CMC.apply_batch(H, self.tracks) diff --git a/src/trackers/core/mcbyte/tracklet.py b/src/trackers/core/mcbyte/tracklet.py new file mode 100644 index 00000000..a3175c4e --- /dev/null +++ b/src/trackers/core/mcbyte/tracklet.py @@ -0,0 +1,229 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +from __future__ import annotations + +import numpy as np + +from trackers.utils.base_tracklet import BaseTracklet +from trackers.utils.cmc import CMC +from trackers.utils.converters import xyxy_to_xywh +from trackers.utils.state_representations import ( + BaseStateEstimator, + XCYCSRStateEstimator, + XCYCWHStateEstimator, + XYXYStateEstimator, +) + + +class McByteTracklet(BaseTracklet): + """Tracklet for the McByte tracker. + + Uses ``XCYCWHStateEstimator`` (center + width/height) by default, + mirroring the original BoT-SORT Kalman filter model - also used in original McByte. + + * **Scale-aware noise**: ``Q``, ``R`` and the initial ``P`` are computed + from the current width / height of the tracked object each frame, so + that uncertainty scales with object size. + * **Width / height clamping** after every predict and update step. + * ``predict()`` increments ``time_since_update``: unmatched tracks are + never explicitly fed ``update(None)``. + * ``number_of_successful_updates`` counts every successful measurement + update (never reset on a miss). + * ``apply_cmc(H)`` applies a 2x3 affine camera-motion transform to the + internal Kalman state and covariance. + """ + + count_id: int = 0 + + # Noise sigma constants (scale-aware noise for McByte) + _SIGMA_P: float = 0.05 + _SIGMA_V: float = 0.00625 + _SIGMA_M: float = 0.05 + + def __init__( + self, + initial_bbox: np.ndarray, + state_estimator_class: type[BaseStateEstimator] = XCYCWHStateEstimator, + ) -> None: + super().__init__(initial_bbox, state_estimator_class) + self._configure_initial_noise(initial_bbox) + # Count initial bbox as first successful update so that + # number_of_successful_updates starts at 1. + self.number_of_successful_updates = 1 + + def _configure_initial_noise(self, bbox: np.ndarray) -> None: + """Set initial P, Q, R based on the first detection's size.""" + measurement = xyxy_to_xywh(bbox) + w, h = float(measurement[2]), float(measurement[3]) + self._set_scale_aware_noise(w, h, initial=True) + + def _set_scale_aware_noise(self, w: float, h: float, *, initial: bool = False) -> None: + sp, sv, sm = self._SIGMA_P, self._SIGMA_V, self._SIGMA_M + + if isinstance(self.state_estimator, XCYCSRStateEstimator): + s = np.sqrt(max(w * h, 1e-6)) + Q = np.diag( + [ + (sp * w) ** 2, + (sp * h) ** 2, + (sp * s) ** 2, + (sp * 1.0) ** 2, + (sv * w) ** 2, + (sv * h) ** 2, + (sv * s) ** 2, + ] + ) + R = np.diag( + [ + (sm * w) ** 2, + (sm * h) ** 2, + (sm * s) ** 2, + (sm * 1.0) ** 2, + ] + ) + else: + Q = np.diag( + [ + (sp * w) ** 2, + (sp * h) ** 2, + (sp * w) ** 2, + (sp * h) ** 2, + (sv * w) ** 2, + (sv * h) ** 2, + (sv * w) ** 2, + (sv * h) ** 2, + ] + ) + R = np.diag( + [ + (sm * w) ** 2, + (sm * h) ** 2, + (sm * w) ** 2, + (sm * h) ** 2, + ] + ) + + if initial: + if isinstance(self.state_estimator, XCYCSRStateEstimator): + s = np.sqrt(max(w * h, 1e-6)) + P = np.diag( + [ + (2 * sp * w) ** 2, + (2 * sp * h) ** 2, + (2 * sp * s) ** 2, + (2 * sp * 1.0) ** 2, + (10 * sv * w) ** 2, + (10 * sv * h) ** 2, + (10 * sv * s) ** 2, + ] + ) + else: + P = np.diag( + [ + (2 * sp * w) ** 2, + (2 * sp * h) ** 2, + (2 * sp * w) ** 2, + (2 * sp * h) ** 2, + (10 * sv * w) ** 2, + (10 * sv * h) ** 2, + (10 * sv * w) ** 2, + (10 * sv * h) ** 2, + ] + ) + self.state_estimator.set_kf_covariances(R=R, Q=Q, P=P) + else: + self.state_estimator.set_kf_covariances(R=R, Q=Q) + + def _refresh_noise_from_state(self) -> None: + """Recompute Q and R from the current bbox size.""" + bbox = self.state_estimator.state_to_bbox() + w = max(float(bbox[2] - bbox[0]), 1e-3) + h = max(float(bbox[3] - bbox[1]), 1e-3) + self._set_scale_aware_noise(w, h) + + @staticmethod + def _clamp_xyxy_state(kf_x: np.ndarray) -> None: + """Ensure XYXY state keeps valid box corners.""" + if kf_x[2, 0] <= kf_x[0, 0]: + kf_x[2, 0] = kf_x[0, 0] + 1e-3 + if kf_x[3, 0] <= kf_x[1, 0]: + kf_x[3, 0] = kf_x[1, 0] + 1e-3 + + @staticmethod + def _clamp_xcycwh_state(kf_x: np.ndarray) -> None: + """Ensure XCYCWH state keeps positive width and height.""" + kf_x[2, 0] = max(kf_x[2, 0], 1e-3) + kf_x[3, 0] = max(kf_x[3, 0], 1e-3) + + @staticmethod + def _clamp_xcycsr_state(kf_x: np.ndarray) -> None: + """Ensure XCYCSR state keeps positive scale and aspect ratio.""" + kf_x[2, 0] = max(kf_x[2, 0], 1e-3) + kf_x[3, 0] = max(kf_x[3, 0], 1e-3) + + def _clamp_state_bbox(self) -> None: + """Clamp geometric components based on active state representation.""" + kf_x = self.state_estimator.kf.x + if isinstance(self.state_estimator, XYXYStateEstimator): + self._clamp_xyxy_state(kf_x) + elif isinstance(self.state_estimator, XCYCWHStateEstimator): + self._clamp_xcycwh_state(kf_x) + elif isinstance(self.state_estimator, XCYCSRStateEstimator): + self._clamp_xcycsr_state(kf_x) + + def update(self, bbox: np.ndarray) -> None: + """Update tracklet with a new observation. + + In the McByte flow **only matched tracks** call ``update(bbox)`` + with an actual bounding box. Unmatched tracks simply skip + ``update`` (their ``time_since_update`` is incremented in + ``predict`` instead). + """ + self._refresh_noise_from_state() + self.state_estimator.update(bbox) + self._clamp_state_bbox() + self.time_since_update = 0 + self.number_of_successful_updates += 1 + + def predict(self) -> np.ndarray: + """Predict the next bounding-box position. + + Increments ``time_since_update`` to track how many frames have + elapsed since the last matched measurement — this replaces the + ``update(None)`` call used in ByteTrack/SORT. + """ + self._refresh_noise_from_state() + self.state_estimator.predict() + self._clamp_state_bbox() + self.age += 1 + self.time_since_update += 1 + return self.state_estimator.state_to_bbox() + + def get_state_bbox(self) -> np.ndarray: + """Return the current bounding-box estimate in xyxy format.""" + return self.state_estimator.state_to_bbox() + + def apply_cmc(self, H: np.ndarray | None) -> None: + """Apply a 2x3 affine camera-motion transform **in place**. + + Delegates to :meth:`CMC.apply_batch` with ``[self]`` as the + tracklet list. See that method for full documentation of the + transform convention, state-representation handling, and covariance + update rules. + + Args: + H: 2x3 affine transform matrix. If ``None``, this is a no-op. + + Examples: + >>> import numpy as np + >>> bbox = np.array([10.0, 20.0, 50.0, 80.0]) + >>> tracklet = McByteTracklet(bbox) + >>> H = np.array([[1.0, 0.0, 5.0], [0.0, 1.0, -3.0]], dtype=np.float32) + >>> tracklet.apply_cmc(H) + >>> tracklet.apply_cmc(None) # no-op + """ + CMC.apply_batch(H, [self]) diff --git a/src/trackers/core/mcbyte/utils.py b/src/trackers/core/mcbyte/utils.py new file mode 100644 index 00000000..0af28a33 --- /dev/null +++ b/src/trackers/core/mcbyte/utils.py @@ -0,0 +1,61 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +from collections.abc import Sequence + +import numpy as np + +from trackers.core.mcbyte.tracklet import McByteTracklet + + +def get_alive_tracklets( + tracklets: Sequence[McByteTracklet], + minimum_consecutive_frames: int, + maximum_frames_without_update: int, +) -> list[McByteTracklet]: + """ + Remove dead or immature lost tracklets and return alive ones. + + A tracklet is kept if it is within ``maximum_frames_without_update`` **and** + it is either mature (enough successful updates) or was just updated this + frame. + + Args: + tracklets: List of McByteTracklet objects. + minimum_consecutive_frames: Number of successful updates that an object + must have before it is considered a 'valid' track. + maximum_frames_without_update: Maximum number of frames without update + before a track is considered dead. + + Returns: + List of alive tracklets. + """ + alive_tracklets = [] + for tracker in tracklets: + is_mature = tracker.number_of_successful_updates >= minimum_consecutive_frames + is_active = tracker.time_since_update == 0 + if tracker.time_since_update < maximum_frames_without_update and (is_mature or is_active): + alive_tracklets.append(tracker) + return alive_tracklets + + +def _fuse_score(iou_similarity: np.ndarray, scores: np.ndarray) -> np.ndarray: + """Fuse IoU similarity matrix with detection confidence scores. + + Following the original ByteTrack implementation, the IoU similarity is + multiplied element-wise by the detection scores. This biases the + association toward higher-confidence detections. + + Args: + iou_similarity: IoU similarity matrix of shape ``(n_tracks, n_dets)``. + scores: Detection confidence scores of shape ``(n_dets,)``. + + Returns: + Fused similarity matrix of the same shape. + """ + if iou_similarity.size == 0: + return iou_similarity + return iou_similarity * scores[np.newaxis, :] diff --git a/tests/core/test_mcbyte_tracker.py b/tests/core/test_mcbyte_tracker.py new file mode 100644 index 00000000..12417995 --- /dev/null +++ b/tests/core/test_mcbyte_tracker.py @@ -0,0 +1,43 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +from __future__ import annotations + +import numpy as np +import supervision as sv + +from trackers.core.mcbyte.tracker import McByteTracker + + +def _detection(xyxy: tuple[float, float, float, float], conf: float = 0.9) -> sv.Detections: + return sv.Detections( + xyxy=np.array([xyxy], dtype=np.float32), + confidence=np.array([conf], dtype=np.float32), + ) + + +def _make_frame(h: int = 480, w: int = 640, seed: int = 42) -> np.ndarray: + rng = np.random.default_rng(seed) + return rng.integers(0, 255, (h, w, 3), dtype=np.uint8) + + +def test_mcbyte_instantiates_and_updates_with_frame_and_sparse_opt_flow_cmc_returns_ids() -> None: + """McByteTracker can update with a frame and CMC enabled, returning track IDs.""" + tracker = McByteTracker( + enable_cmc=True, + cmc_method="sparseOptFlow", + minimum_consecutive_frames=2, + ) + + frame = _make_frame() + + for _ in range(5): + result = tracker.update(_detection((100.0, 100.0, 200.0, 200.0)), frame) + + assert len(result) == 1 + assert result.tracker_id is not None + assert result.tracker_id[0] >= 0 + assert len(tracker.tracks) == 1