diff --git a/demos/openpose/src/demo.py b/demos/openpose/src/demo.py index 995227a0..f5967813 100644 --- a/demos/openpose/src/demo.py +++ b/demos/openpose/src/demo.py @@ -104,25 +104,24 @@ def __call__(self, image): detected_poses = datum.poseKeypoints if detected_poses is None: - tracked_objects = tracker.update(period=args.skip_frame) - continue - - detections = ( - [] - if not detected_poses.any() - else [ - Detection(p, scores=s) - for (p, s) in zip( - detected_poses[:, :, :2], detected_poses[:, :, 2] - ) - ] - ) + detections = [] + else: + detections = ( + [] + if not detected_poses.any() + else [ + Detection(p, scores=s) + for (p, s) in zip( + detected_poses[:, :, :2], detected_poses[:, :, 2] + ) + ] + ) tracked_objects = tracker.update( - detections=detections, period=args.skip_frame + detections=detections, ) norfair.draw_points(frame, detections) else: - tracked_objects = tracker.update(period=args.skip_frame) + tracked_objects = tracker.update() norfair.draw_tracked_objects(frame, tracked_objects) video.write(frame) diff --git a/demos/profiling/src/demo.py b/demos/profiling/src/demo.py index d7805e1f..ba74145d 100644 --- a/demos/profiling/src/demo.py +++ b/demos/profiling/src/demo.py @@ -52,7 +52,7 @@ def process_video( detector_time = time.time() tracked_objects = tracker.update( - detections=detections, period=frame_skip_period + detections=detections, ) tracker_time = time.time() diff --git a/demos/reid/src/demo.py b/demos/reid/src/demo.py index 36548d5b..c7e216eb 100644 --- a/demos/reid/src/demo.py +++ b/demos/reid/src/demo.py @@ -77,7 +77,7 @@ def main( else: detection.embedding = None - tracked_objects = tracker.update(detections=detections, period=skip_period) + tracked_objects = tracker.update(detections=detections) else: tracked_objects = tracker.update() draw_points(cv2_frame, detections) diff --git a/demos/sahi/src/demo.py b/demos/sahi/src/demo.py index fb43b269..05ab91ff 100644 --- a/demos/sahi/src/demo.py +++ b/demos/sahi/src/demo.py @@ -70,7 +70,7 @@ def main( result = get_prediction(frame, detection_model) detections = get_detections(result.object_prediction_list) - tracked_objects = tracker.update(detections=detections, period=skip_period) + tracked_objects = tracker.update(detections=detections) else: tracked_objects = tracker.update() diff --git a/norfair/tracker.py b/norfair/tracker.py index 275eceee..63ac5483 100644 --- a/norfair/tracker.py +++ b/norfair/tracker.py @@ -141,10 +141,12 @@ def __init__( self.reid_distance_threshold = reid_distance_threshold self._obj_factory = _TrackedObjectFactory() + self.skipped_frames = 1 + def update( self, detections: Optional[List["Detection"]] = None, - period: int = 1, + hit_counter_jump: int = None, coord_transformations: Optional[CoordinatesTransformation] = None, ) -> List["TrackedObject"]: """ @@ -161,10 +163,11 @@ def update( If no detections have been found in the current frame, or the user is purposely skipping frames to improve video processing time, this argument should be set to None or ignored, as the update function is needed to advance the state of the Kalman Filters inside the tracker. - period : int, optional - The user can chose not to run their detector on all frames, so as to process video faster. - This parameter sets every how many frames the detector is getting ran, - so that the tracker is aware of this situation and can handle it properly. + hit_counter_jump : int, optional + The user can choose not to run their detector on all frames, so as to process video faster. + This parameter sets by how much the hit_counter of a trackedObject can increase or decrease in a single update. + By default, the amount the hit_counter jumps corresponds to how many frmes were previously skipped. + A frame is considered skipped if detections is None. This argument can be reset on each frame processed, which is useful if the user is dynamically changing how many frames the detector is skipping on a video when working in real-time. @@ -176,6 +179,16 @@ def update( List[TrackedObject] The list of active tracked objects. """ + + # if no detection argument is provided, then assume frame was skipped + if detections is None: + hit_counter_jump = 0 + self.skipped_frames += 1 + else: + if hit_counter_jump is None: + hit_counter_jump = self.skipped_frames + self.skipped_frames = 1 + if coord_transformations is not None: for det in detections: det.update_coordinate_transformation(coord_transformations) @@ -201,7 +214,7 @@ def update( # Update tracker for obj in self.tracked_objects: - obj.tracker_step() + obj.tracker_step(hit_counter_jump) obj.update_coordinate_transformation(coord_transformations) # Update initialized tracked objects with detections @@ -214,7 +227,7 @@ def update( self.distance_threshold, [o for o in alive_objects if not o.is_initializing], detections, - period, + hit_counter_jump, ) # Update not yet initialized tracked objects with yet unmatched detections @@ -227,7 +240,7 @@ def update( self.distance_threshold, [o for o in alive_objects if o.is_initializing], unmatched_detections, - period, + hit_counter_jump, ) if self.reid_distance_function is not None: @@ -237,7 +250,7 @@ def update( self.reid_distance_threshold, unmatched_init_trackers + dead_objects, matched_not_init_trackers, - period, + hit_counter_jump, ) # Create new tracked objects from remaining unmatched detections @@ -249,7 +262,7 @@ def update( initialization_delay=self.initialization_delay, pointwise_hit_counter_max=self.pointwise_hit_counter_max, detection_threshold=self.detection_threshold, - period=period, + initial_hit_counter=1, filter_factory=self.filter_factory, past_detections_length=self.past_detections_length, reid_hit_counter_max=self.reid_hit_counter_max, @@ -289,7 +302,7 @@ def _update_objects_in_place( distance_threshold, objects: Sequence["TrackedObject"], candidates: Optional[Union[List["Detection"], List["TrackedObject"]]], - period: int, + hit_counter_jump: int, ): if candidates is not None and len(candidates) > 0: distance_matrix = distance_function.get_distances(objects, candidates) @@ -326,7 +339,9 @@ def _update_objects_in_place( matched_object = objects[match_obj_idx] if match_distance < distance_threshold: if isinstance(matched_candidate, Detection): - matched_object.hit(matched_candidate, period=period) + matched_object.hit( + matched_candidate, hit_counter_jump=hit_counter_jump + ) matched_object.last_distance = match_distance matched_objects.append(matched_object) elif isinstance(matched_candidate, TrackedObject): @@ -399,7 +414,7 @@ def create( initialization_delay: int, pointwise_hit_counter_max: int, detection_threshold: float, - period: int, + initial_hit_counter: int, filter_factory: "FilterFactory", past_detections_length: int, reid_hit_counter_max: Optional[int], @@ -412,7 +427,7 @@ def create( initialization_delay=initialization_delay, pointwise_hit_counter_max=pointwise_hit_counter_max, detection_threshold=detection_threshold, - period=period, + initial_hit_counter=initial_hit_counter, filter_factory=filter_factory, past_detections_length=past_detections_length, reid_hit_counter_max=reid_hit_counter_max, @@ -477,7 +492,7 @@ def __init__( initialization_delay: int, pointwise_hit_counter_max: int, detection_threshold: float, - period: int, + initial_hit_counter: int, filter_factory: "FilterFactory", past_detections_length: int, reid_hit_counter_max: Optional[int], @@ -491,11 +506,12 @@ def __init__( self.dim_points = initial_detection.absolute_points.shape[1] self.num_points = initial_detection.absolute_points.shape[0] self.hit_counter_max: int = hit_counter_max - self.pointwise_hit_counter_max: int = max(pointwise_hit_counter_max, period) + self.pointwise_hit_counter_max: int = max( + pointwise_hit_counter_max, initial_hit_counter + ) self.initialization_delay = initialization_delay self.detection_threshold: float = detection_threshold - self.initial_period: int = period - self.hit_counter: int = period + self.hit_counter: int = initial_hit_counter self.reid_hit_counter_max = reid_hit_counter_max self.reid_hit_counter: Optional[int] = None self.last_distance: Optional[float] = None @@ -534,14 +550,14 @@ def __init__( if coord_transformations is not None: self.update_coordinate_transformation(coord_transformations) - def tracker_step(self): + def tracker_step(self, hit_counter_jump): if self.reid_hit_counter is None: if self.hit_counter <= 0: self.reid_hit_counter = self.reid_hit_counter_max else: - self.reid_hit_counter -= 1 - self.hit_counter -= 1 - self.point_hit_counter -= 1 + self.reid_hit_counter -= hit_counter_jump + self.hit_counter -= hit_counter_jump + self.point_hit_counter -= hit_counter_jump self.age += 1 # Advances the tracker's state self.filter.predict() @@ -612,20 +628,22 @@ def get_estimate(self, absolute=False) -> np.ndarray: def live_points(self): return self.point_hit_counter > 0 - def hit(self, detection: "Detection", period: int = 1): + def hit(self, detection: "Detection", hit_counter_jump: int = 1): """Update tracked object with a new detection Parameters ---------- detection : Detection the new detection matched to this tracked object - period : int, optional - frames corresponding to the period of time since last update. + hit_counter_jump : int, optional + by how much will the hit counter be increased when it matches a detection """ self._conditionally_add_to_past_detections(detection) self.last_detection = detection - self.hit_counter = min(self.hit_counter + 2 * period, self.hit_counter_max) + self.hit_counter = min( + self.hit_counter + 2 * hit_counter_jump, self.hit_counter_max + ) if self.is_initializing and self.hit_counter > self.initialization_delay: self.is_initializing = False @@ -644,11 +662,11 @@ def hit(self, detection: "Detection", period: int = 1): H_pos = np.diag(matched_sensors_mask).astype( float ) # We measure x, y positions - self.point_hit_counter[points_over_threshold_mask] += 2 * period + self.point_hit_counter[points_over_threshold_mask] += 2 * hit_counter_jump else: points_over_threshold_mask = np.array([True] * self.num_points) H_pos = np.identity(self.num_points * self.dim_points) - self.point_hit_counter += 2 * period + self.point_hit_counter += 2 * hit_counter_jump self.point_hit_counter[ self.point_hit_counter >= self.pointwise_hit_counter_max ] = self.pointwise_hit_counter_max @@ -717,7 +735,7 @@ def _conditionally_add_to_past_detections(self, detection): def merge(self, tracked_object): """Merge with a not yet initialized TrackedObject instance""" self.reid_hit_counter = None - self.hit_counter = self.initial_period * 2 + self.hit_counter = self.initialization_delay + 1 self.point_hit_counter = tracked_object.point_hit_counter self.last_distance = tracked_object.last_distance self.current_min_distance = tracked_object.current_min_distance diff --git a/tests/test_tracker.py b/tests/test_tracker.py index 6e3b9e8a..54dc2660 100644 --- a/tests/test_tracker.py +++ b/tests/test_tracker.py @@ -87,7 +87,7 @@ def test_simple(filter_factory): # check that counter goes down to 0 wen no detections for counter in range(counter_max - 1, -1, -1): age += 1 - tracked_objects = tracker.update() + tracked_objects = tracker.update(detections=[]) assert len(tracked_objects) == 1 obj = tracked_objects[0] np.testing.assert_almost_equal( @@ -97,7 +97,7 @@ def test_simple(filter_factory): assert obj.hit_counter == counter # check that object dissapears in the next frame - assert len(tracker.update()) == 0 + assert len(tracker.update(detections=[])) == 0 @pytest.mark.parametrize( @@ -242,11 +242,11 @@ def test_count(delay): assert tracker.current_object_count == 1 for _ in range(delay + 1, 0, -1): - assert len(tracker.update()) == 1 + assert len(tracker.update(detections=[])) == 1 assert tracker.total_object_count == 1 assert tracker.current_object_count == 1 - assert len(tracker.update()) == 0 + assert len(tracker.update(detections=[])) == 0 assert tracker.total_object_count == 1 assert tracker.current_object_count == 0 @@ -265,11 +265,11 @@ def test_count(delay): assert tracker.current_object_count == 2 for _ in range(delay + 1, 0, -1): - assert len(tracker.update()) == 2 + assert len(tracker.update(detections=[])) == 2 assert tracker.total_object_count == 3 assert tracker.current_object_count == 2 - assert len(tracker.update()) == 0 + assert len(tracker.update(detections=[])) == 0 assert tracker.total_object_count == 3 assert tracker.current_object_count == 0 @@ -332,7 +332,7 @@ def dist(new_obj, tracked_obj): # check that object is dead if it doesn't get matched to any detections obj_id = tracked_objects[0].id for _ in range(hit_counter_max + 1): - tracked_objects = tracker.update() + tracked_objects = tracker.update(detections=[]) assert len(tracked_objects) == 0 # check that previous object gets back to life after reid matching @@ -345,7 +345,7 @@ def dist(new_obj, tracked_obj): # check that previous object gets eliminated after hit_counter_max + reid_hit_counter_max + 1 for _ in range(hit_counter_max + reid_hit_counter_max + 1): - tracked_objects = tracker.update() + tracked_objects = tracker.update(detections=[]) assert len(tracked_objects) == 0 for _ in range(2): tracked_objects = tracker.update([Detection(points=np.array([[1, 1]]))])