class MaskManager:
    """Manage McByte mask generation and propagation.

    The manager follows the original McByte timing: masks for the current frame
    are prepared before association, but they are initialized/updated from tracker
    outputs produced on the previous frame.

    Args:
        mask_generator: Component that creates masks from tracklet boxes. It is
            only used to (re-)initialize the propagator.
        mask_propagator: Optional component that carries masks from one frame to
            the next. Without it, ``get_updated_masks`` always returns ``None``.
    """

    def __init__(
        self,
        mask_generator: MaskGenerator,
        mask_propagator: MaskPropagator | None = None,
    ) -> None:
        self.mask_generator = mask_generator
        self.mask_propagator = mask_propagator
        # True while the propagator holds masks seeded from an unbroken chain of
        # consecutive frames.
        self._initialized = False

    def reset(self) -> None:
        """Drop the propagation state so the next update re-generates masks."""
        self._initialized = False
        if self.mask_propagator is not None:
            self.mask_propagator.reset()

    def get_updated_masks(
        self,
        frame: np.ndarray,
        previous_frame: np.ndarray | None,
        previous_tracklets: list[TrackletSnapshot],
    ) -> MaskOutput | None:
        """Return masks for the current frame.

        No masks are returned until at least one previous frame and previous
        tracker output are available.

        If a propagator is configured, masks are initialized from the previous
        frame and propagated to the current frame. If propagation is unavailable
        or fails, ``None`` is returned to avoid using stale or misaligned masks.

        Args:
            frame: Current frame; passed to the propagator.
            previous_frame: Frame on which ``previous_tracklets`` were produced,
                or ``None`` when no previous frame is available.
            previous_tracklets: Tracker output of the previous frame.

        Returns:
            The propagator's output for ``frame``, or ``None`` when the previous
            inputs are missing, no propagator is configured, or propagation
            returned ``None``.
        """
        if previous_frame is None or not previous_tracklets:
            # The frame-to-frame chain is broken. Masks seeded earlier belong to
            # an older frame and older tracker IDs, so they must not be
            # propagated once inputs become available again.
            if self._initialized:
                self.reset()
            return None

        if self.mask_propagator is None:
            return None

        if not self._initialized:
            mask_output = self.mask_generator.generate(previous_frame, previous_tracklets)
            self.mask_propagator.initialize(previous_frame, mask_output)
            self._initialized = True

        propagated_output = self.mask_propagator.propagate(frame)
        if propagated_output is None:
            # Propagation failed: force a fresh initialization on the next call.
            self._initialized = False
        return propagated_output
@dataclass(frozen=True, eq=False)
class TrackletSnapshot:
    """Minimal tracker state needed by mask components.

    Equality is defined explicitly because the auto-generated dataclass
    ``__eq__``/``__hash__`` cannot handle ``np.ndarray`` fields: comparing two
    distinct multi-element arrays raises ``ValueError`` and hashing an array
    raises ``TypeError``.

    Attributes:
        tracker_id: Identifier of the track this snapshot belongs to.
        xyxy: Bounding box coordinates as an array (x1, y1, x2, y2).
    """

    tracker_id: int
    xyxy: np.ndarray

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TrackletSnapshot):
            return NotImplemented
        return self.tracker_id == other.tracker_id and bool(np.array_equal(self.xyxy, other.xyxy))

    def __hash__(self) -> int:
        # Equal snapshots share a tracker_id, so hashing on it alone is
        # consistent with __eq__ and avoids hashing the array.
        return hash(self.tracker_id)


@dataclass(frozen=True, eq=False)
class MaskOutput:
    """Mask information produced before McByte association.

    Attributes:
        masks: Stacked masks, or ``None`` when no mask array is available.
        tracklet_mask_dict: Maps a tracker ID to the index of its mask in
            ``masks``.
        mask_avg_prob_dict: Optional per-tracker float values; presumably the
            average mask probability per tracker ID (confirm with the mask
            backend).
    """

    masks: np.ndarray | None
    tracklet_mask_dict: dict[int, int]
    mask_avg_prob_dict: dict[int, float] | None = None

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, MaskOutput):
            return NotImplemented
        if self.masks is None or other.masks is None:
            masks_equal = self.masks is other.masks
        else:
            masks_equal = bool(np.array_equal(self.masks, other.masks))
        return (
            masks_equal
            and self.tracklet_mask_dict == other.tracklet_mask_dict
            and self.mask_avg_prob_dict == other.mask_avg_prob_dict
        )

    # Instances hold dicts and an array, so they are deliberately unhashable.
    __hash__ = None  # type: ignore[assignment]


class MaskGenerator(ABC):
    """Generate masks from tracklet boxes."""

    @abstractmethod
    def generate(
        self,
        frame: np.ndarray,
        tracklets: list[TrackletSnapshot],
    ) -> MaskOutput:
        """Generate masks for the given tracklet snapshots.

        Args:
            frame: Frame the tracklet boxes refer to.
            tracklets: Snapshots to generate masks for.

        Returns:
            The generated masks together with the tracker-ID-to-mask mapping.
        """


class MaskPropagator(ABC):
    """Propagate masks from one frame to the next."""

    @abstractmethod
    def reset(self) -> None:
        """Reset propagation state."""

    @abstractmethod
    def initialize(
        self,
        frame: np.ndarray,
        mask_output: MaskOutput,
    ) -> None:
        """Initialize propagation state.

        Args:
            frame: Frame on which ``mask_output`` was generated.
            mask_output: Masks to start propagating from.
        """

    @abstractmethod
    def propagate(
        self,
        frame: np.ndarray,
    ) -> MaskOutput | None:
        """Propagate masks to the current frame.

        Args:
            frame: Frame to propagate the masks to.

        Returns:
            Masks for ``frame``, or ``None`` when no masks are available.
        """
def _copy_mask_output(mask_output: MaskOutput) -> MaskOutput:
    """Return a ``MaskOutput`` whose array and dicts are copies of the input's."""
    return MaskOutput(
        masks=None if mask_output.masks is None else mask_output.masks.copy(),
        tracklet_mask_dict=mask_output.tracklet_mask_dict.copy(),
        mask_avg_prob_dict=(
            None if mask_output.mask_avg_prob_dict is None else mask_output.mask_avg_prob_dict.copy()
        ),
    )


class DummyBoxMaskGenerator(MaskGenerator):
    """Generate rectangular binary masks from tracklet bounding boxes."""

    def generate(
        self,
        frame: np.ndarray,
        tracklets: list[TrackletSnapshot],
    ) -> MaskOutput:
        """Fill one boolean mask per tracklet with its clipped bounding box.

        Args:
            frame: Frame whose first two dimensions give the mask height and
                width.
            tracklets: Snapshots whose ``xyxy`` boxes are rasterized.

        Returns:
            A ``MaskOutput`` with a ``(len(tracklets), height, width)`` boolean
            array and a mapping from tracker ID to mask index.
        """
        height, width = frame.shape[:2]
        masks = np.zeros((len(tracklets), height, width), dtype=bool)
        tracklet_mask_dict: dict[int, int] = {}
        upper_bounds = np.array([width, height, width, height], dtype=float)

        for mask_index, tracklet in enumerate(tracklets):
            # Sanitize and clip in floating point before the integer cast:
            # casting non-finite or out-of-range floats to int is undefined and
            # could otherwise yield arbitrary coordinates.
            box = np.nan_to_num(np.asarray(tracklet.xyxy, dtype=float), nan=0.0)
            x1, y1, x2, y2 = np.clip(box, 0.0, upper_bounds).astype(int)

            masks[mask_index, y1:y2, x1:x2] = True
            tracklet_mask_dict[tracklet.tracker_id] = mask_index

        return MaskOutput(
            masks=masks,
            tracklet_mask_dict=tracklet_mask_dict,
            mask_avg_prob_dict=None,
        )


class DummyIdentityMaskPropagator(MaskPropagator):
    """Return the last initialized mask output unchanged."""

    def __init__(self) -> None:
        self._mask_output: MaskOutput | None = None

    def reset(self) -> None:
        """Forget the stored mask output."""
        self._mask_output = None

    def initialize(
        self,
        frame: np.ndarray,
        mask_output: MaskOutput,
    ) -> None:
        """Store a private copy of ``mask_output``; ``frame`` is not used."""
        self._mask_output = _copy_mask_output(mask_output)

    def propagate(
        self,
        frame: np.ndarray,
    ) -> MaskOutput | None:
        """Return a copy of the stored output, or ``None`` if uninitialized.

        ``frame`` is not used: the masks are returned exactly as initialized.
        """
        if self._mask_output is None:
            return None

        # Hand out a copy so callers cannot mutate the stored state.
        return _copy_mask_output(self._mask_output)
+ minimum_iou_threshold_unconfirmed_assoc: Minimum similarity threshold for + matching unconfirmed tracks. + high_conf_det_threshold: Confidence threshold used to split detections + into high- and low-confidence groups. + enable_cmc: Whether to enable camera motion compensation. + cmc_method: Camera motion compensation method. + cmc_downscale: Downscale factor used by camera motion compensation. + instant_first_frame_activation: Whether tracks spawned on the first frame + receive confirmed IDs immediately. + state_estimator_class: State estimator class used by McByte tracklets. + iou: IoU implementation used for association. + enable_mask_manager: Whether to create the default dummy mask manager. + mask_manager: Optional custom mask manager instance. + """ + tracker_id = "mcbyte" def __init__( @@ -42,6 +81,8 @@ def __init__( instant_first_frame_activation: bool = True, state_estimator_class: type[BaseStateEstimator] = XCYCWHStateEstimator, iou: BaseIoU | None = None, + enable_mask_manager: bool = False, + mask_manager: MaskManager | None = None, ) -> None: # Calculate maximum frames without update based on lost_track_buffer and # frame_rate. This scales the buffer based on the frame rate to ensure @@ -62,6 +103,17 @@ def __init__( self.enable_cmc = enable_cmc self.cmc = CMC(CMCConfig(method=cmc_method, downscale=cmc_downscale)) if enable_cmc else None + self.mask_manager = mask_manager + if self.mask_manager is None and enable_mask_manager: + self.mask_manager = MaskManager( + mask_generator=DummyBoxMaskGenerator(), + mask_propagator=DummyIdentityMaskPropagator(), + ) + + self._previous_frame: np.ndarray | None = None + self._previous_tracklets: list[TrackletSnapshot] = [] + self._last_mask_output: MaskOutput | None = None + def update( self, detections: sv.Detections, @@ -90,9 +142,24 @@ def update( """ self.frame_id += 1 + # For the convenience and better understanding. McByte processes uses previous + # frame and current frame. 
It is better to keep the method argument as "frame", + # as in case of the other trackers. + current_frame = frame + + if self.mask_manager is not None and current_frame is not None: + self._last_mask_output = self.mask_manager.get_updated_masks( + frame=current_frame, + previous_frame=self._previous_frame, + previous_tracklets=self._previous_tracklets, + ) + else: + self._last_mask_output = None + if len(self.tracks) == 0 and len(detections) == 0: result = sv.Detections.empty() result.tracker_id = np.array([], dtype=int) + self._store_previous_mask_inputs(current_frame, result) return result out_det_indices: list[int] = [] @@ -132,9 +199,9 @@ def update( unconfirmed_tracks.append(track) # CMC: apply to all predicted tracks before association - if self.enable_cmc and self.cmc is not None and frame is not None: + if self.enable_cmc and self.cmc is not None and current_frame is not None: mask_boxes = high_boxes if len(high_boxes) > 0 else None - H = self.cmc.estimate(frame, mask_boxes) + H = self.cmc.estimate(current_frame, mask_boxes) CMC.apply_batch(H, self.tracks) # Step 1: associate high-confidence detections to confirmed + lost tracks. 
def _store_previous_mask_inputs(
    self,
    frame: np.ndarray | None,
    detections: sv.Detections,
) -> None:
    """Remember this frame's tracker output for the next mask update.

    The stored frame and snapshots are cleared first, so every early exit
    leaves no inputs behind. They are only repopulated when a mask manager is
    configured, a frame was supplied, and at least one detection carries a
    non-negative tracker ID.

    Args:
        frame: Frame processed in this update, or ``None``.
        detections: Tracker output of this update.
    """
    self._previous_frame = None
    self._previous_tracklets = []

    if self.mask_manager is None or frame is None:
        return

    tracker_ids = detections.tracker_id
    if tracker_ids is None:
        return

    # Detections with a negative tracker ID are skipped.
    snapshots = [
        TrackletSnapshot(tracker_id=int(track_id), xyxy=box.copy())
        for box, track_id in zip(detections.xyxy, tracker_ids)
        if not track_id < 0
    ]

    if snapshots:
        # Copy the frame so later changes to the caller's array cannot alter it.
        self._previous_frame = frame.copy()
        self._previous_tracklets = snapshots
""" self.tracks = [] self.frame_id = 0 McByteTracklet.count_id = 0 + self._previous_frame = None + self._previous_tracklets = [] + self._last_mask_output = None + if self.mask_manager is not None: + self.mask_manager.reset() if self.cmc is not None: self.cmc.reset() diff --git a/tests/core/test_mcbyte_mask_manager.py b/tests/core/test_mcbyte_mask_manager.py new file mode 100644 index 00000000..5e0c8e2b --- /dev/null +++ b/tests/core/test_mcbyte_mask_manager.py @@ -0,0 +1,161 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +from __future__ import annotations + +import numpy as np + +from trackers.core.mcbyte.mask_manager import MaskManager +from trackers.core.mcbyte.masks import TrackletSnapshot +from trackers.core.mcbyte.masks.dummy import DummyBoxMaskGenerator, DummyIdentityMaskPropagator + + +def _make_frame(h: int = 100, w: int = 120) -> np.ndarray: + return np.zeros((h, w, 3), dtype=np.uint8) + + +def test_dummy_box_mask_generator_returns_expected_shape() -> None: + generator = DummyBoxMaskGenerator() + frame = _make_frame() + + output = generator.generate( + frame=frame, + tracklets=[ + TrackletSnapshot( + tracker_id=7, + xyxy=np.array([10, 20, 30, 40], dtype=np.float32), + ) + ], + ) + + assert output.masks is not None + assert output.masks.shape == (1, 100, 120) + assert output.tracklet_mask_dict == {7: 0} + + +def test_dummy_box_mask_generator_fills_detection_box() -> None: + generator = DummyBoxMaskGenerator() + frame = _make_frame() + + output = generator.generate( + frame=frame, + tracklets=[ + TrackletSnapshot( + tracker_id=7, + xyxy=np.array([10, 20, 30, 40], dtype=np.float32), + ) + ], + ) + + assert output.masks is not None + assert output.masks[0, 20:40, 10:30].all() + assert not output.masks[0, :10, 
def _snapshot(tracker_id: int, box: tuple[int, int, int, int]) -> TrackletSnapshot:
    """Build a tracklet snapshot with a float32 ``xyxy`` box."""
    return TrackletSnapshot(tracker_id=tracker_id, xyxy=np.array(box, dtype=np.float32))


def _dummy_manager() -> MaskManager:
    """Build a manager wired with the dummy generator and propagator."""
    return MaskManager(
        mask_generator=DummyBoxMaskGenerator(),
        mask_propagator=DummyIdentityMaskPropagator(),
    )


def test_mask_manager_returns_none_without_previous_frame_or_tracklets() -> None:
    # Neither a previous frame nor previous tracklets are supplied.
    result = _dummy_manager().get_updated_masks(
        frame=_make_frame(),
        previous_frame=None,
        previous_tracklets=[],
    )

    assert result is None


def test_mask_manager_returns_none_without_propagator() -> None:
    generator_only = MaskManager(
        mask_generator=DummyBoxMaskGenerator(),
        mask_propagator=None,
    )

    result = generator_only.get_updated_masks(
        frame=_make_frame(),
        previous_frame=_make_frame(),
        previous_tracklets=[_snapshot(3, (5, 6, 25, 30))],
    )

    assert result is None


def test_mask_manager_uses_propagator_after_initialization() -> None:
    manager = _dummy_manager()

    initial = manager.get_updated_masks(
        frame=_make_frame(),
        previous_frame=_make_frame(),
        previous_tracklets=[_snapshot(3, (5, 6, 25, 30))],
    )
    # The second call passes different tracklets; the expected mapping is still
    # the one established by the first call.
    followup = manager.get_updated_masks(
        frame=_make_frame(),
        previous_frame=_make_frame(),
        previous_tracklets=[_snapshot(99, (50, 50, 70, 70))],
    )

    assert initial is not None
    assert followup is not None
    assert followup.tracklet_mask_dict == initial.tracklet_mask_dict


def test_mask_manager_reset_clears_state() -> None:
    manager = _dummy_manager()

    before = manager.get_updated_masks(
        frame=_make_frame(),
        previous_frame=_make_frame(),
        previous_tracklets=[_snapshot(3, (5, 6, 25, 30))],
    )

    manager.reset()

    # After the reset, masks must be generated from the new tracklets.
    after = manager.get_updated_masks(
        frame=_make_frame(),
        previous_frame=_make_frame(),
        previous_tracklets=[_snapshot(9, (40, 40, 60, 60))],
    )

    assert before is not None
    assert after is not None
    assert after.tracklet_mask_dict == {9: 0}