diff --git a/README.md b/README.md index 7eab7e4a..0dbdbb8f 100644 --- a/README.md +++ b/README.md @@ -75,11 +75,11 @@ Prefer the terminal? Point `trackers track` at a video, webcam feed, RTSP stream ```bash trackers track \ --source video.mp4 \ - --output output.mp4 \ - --model rfdetr-medium \ + --out.output output.mp4 \ + --detection.model rfdetr-medium \ --tracker bytetrack \ - --show-labels \ - --show-trajectories + --show.labels \ + --show.trajectories ``` For all CLI options, see the [tracking guide](https://trackers.roboflow.com/develop/learn/track/). @@ -103,10 +103,10 @@ Once you have tracking results, you want to know how good they are. `trackers ev ```bash trackers eval \ - --gt-dir ./data/mot17/val \ - --tracker-dir results \ - --metrics CLEAR HOTA Identity \ - --columns MOTA HOTA IDF1 + --gt_dir ./data/mot17/val \ + --tracker_dir results \ + --metrics '[CLEAR,HOTA,Identity]' \ + --columns '[MOTA,HOTA,IDF1]' ``` ``` @@ -130,7 +130,7 @@ For the full evaluation workflow, see the [evaluation guide](https://trackers.ro Need benchmark data to evaluate against? `trackers download` pulls MOT17, SportsMOT, and other supported datasets with a single command, handling splits and assets selectively so you only download what you need. ```bash -trackers download mot17 \ +trackers download --dataset mot17 \ --split val \ --asset annotations,detections ``` diff --git a/docs/index.md b/docs/index.md index 80d354d4..3596506c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -41,11 +41,11 @@ Point at a video, webcam, RTSP stream, or image directory. Get tracked output. ```bash trackers track \ --source video.mp4 \ - --output output.mp4 \ - --model rfdetr-medium \ + --out.output output.mp4 \ + --detection.model rfdetr-medium \ --tracker bytetrack \ - --show-labels \ - --show-trajectories + --show.labels \ + --show.trajectories ``` For all CLI options, see the [tracking guide](learn/track.md). @@ -86,10 +86,10 @@ Benchmark your tracker against ground truth with standard MOT metrics. ```text trackers eval \ - --gt-dir ./data/mot17/val \ - --tracker-dir results \ - --metrics CLEAR HOTA Identity \ - --columns MOTA HOTA IDF1 + --gt_dir ./data/mot17/val \ + --tracker_dir results \ + --metrics '[CLEAR,HOTA,Identity]' \ + --columns '[MOTA,HOTA,IDF1]' ``` ``` @@ -130,7 +130,7 @@ For detailed benchmarks and tuned configurations, see the [tracker comparison](t Pull benchmark datasets for evaluation with a single command. ```bash -trackers download mot17 \ +trackers download --dataset mot17 \ --split val \ --asset annotations,detections ``` diff --git a/docs/learn/detection-quality.md b/docs/learn/detection-quality.md index 4ea4a909..12cfd89d 100644 --- a/docs/learn/detection-quality.md +++ b/docs/learn/detection-quality.md @@ -47,7 +47,8 @@ We pick three models that span a wide accuracy range on COCO, from a lightweight Pull the MOT17 validation split. You need frames for detection and annotations for evaluation. ```text -trackers download mot17 \ +trackers download \ + --dataset mot17 \ --split val \ --asset frames,annotations \ --output ./data @@ -66,10 +67,10 @@ Run ByteTrack with default parameters three times, changing only the detection m ```bash trackers track \ --source ./data/mot17/val/MOT17-13-FRCNN/img1 \ - --model yolo26n-640 \ + --detection.model yolo26n-640 \ --tracker bytetrack \ - --classes person \ - --mot-output results/yolo26n/MOT17-13-FRCNN.txt + --filters.classes person \ + --out.mot_results results/yolo26n/MOT17-13-FRCNN.txt ``` === "All sequences" @@ -78,10 +79,10 @@ Run ByteTrack with default parameters three times, changing only the detection m for seq in MOT17-02-FRCNN MOT17-04-FRCNN MOT17-05-FRCNN MOT17-09-FRCNN MOT17-10-FRCNN MOT17-11-FRCNN MOT17-13-FRCNN; do trackers track \ --source ./data/mot17/val/$seq/img1 \ - --model yolo26n-640 \ + --detection.model yolo26n-640 \ --tracker bytetrack \ - --classes person \ - --mot-output results/yolo26n/$seq.txt + --filters.classes person \ + --out.mot_results results/yolo26n/$seq.txt done ``` @@ -97,10 +98,10 @@ Run ByteTrack with default parameters three times, changing only the detection m ```bash trackers track \ --source ./data/mot17/val/MOT17-13-FRCNN/img1 \ - --model rfdetr-nano \ + --detection.model rfdetr-nano \ --tracker bytetrack \ - --classes person \ - --mot-output results/rfdetr-nano/MOT17-13-FRCNN.txt + --filters.classes person \ + --out.mot_results results/rfdetr-nano/MOT17-13-FRCNN.txt ``` === "All sequences" @@ -109,10 +110,10 @@ Run ByteTrack with default parameters three times, changing only the detection m for seq in MOT17-02-FRCNN MOT17-04-FRCNN MOT17-05-FRCNN MOT17-09-FRCNN MOT17-10-FRCNN MOT17-11-FRCNN MOT17-13-FRCNN; do trackers track \ --source ./data/mot17/val/$seq/img1 \ - --model rfdetr-nano \ + --detection.model rfdetr-nano \ --tracker bytetrack \ - --classes person \ - --mot-output results/rfdetr-nano/$seq.txt + --filters.classes person \ + --out.mot_results results/rfdetr-nano/$seq.txt done ``` @@ -128,10 +129,10 @@ Run ByteTrack with default parameters three times, changing only the detection m ```bash trackers track \ --source ./data/mot17/val/MOT17-13-FRCNN/img1 \ - --model rfdetr-medium \ + --detection.model rfdetr-medium \ --tracker bytetrack \ - --classes person \ - --mot-output results/rfdetr-medium/MOT17-13-FRCNN.txt + --filters.classes person \ + --out.mot_results results/rfdetr-medium/MOT17-13-FRCNN.txt ``` === "All sequences" @@ -140,10 +141,10 @@ Run ByteTrack with default parameters three times, changing only the detection m for seq in MOT17-02-FRCNN MOT17-04-FRCNN MOT17-05-FRCNN MOT17-09-FRCNN MOT17-10-FRCNN MOT17-11-FRCNN MOT17-13-FRCNN; do trackers track \ --source ./data/mot17/val/$seq/img1 \ - --model rfdetr-medium \ + --detection.model rfdetr-medium \ --tracker bytetrack \ - --classes person \ - --mot-output results/rfdetr-medium/$seq.txt + --filters.classes person \ + --out.mot_results results/rfdetr-medium/$seq.txt done ``` @@ -162,10 +163,10 @@ Evaluate each run against ground truth using CLEAR, HOTA, and Identity metrics. ```text trackers eval \ - --gt-dir ./data/mot17/val \ - --tracker-dir results/yolo26n \ - --metrics CLEAR HOTA Identity \ - --columns MOTA HOTA IDF1 + --gt_dir ./data/mot17/val \ + --tracker_dir results/yolo26n \ + --metrics '[CLEAR,HOTA,Identity]' \ + --columns '[MOTA,HOTA,IDF1]' ``` **Output:** @@ -180,10 +181,10 @@ COMBINED 23.444 32.874 34.411 ```text trackers eval \ - --gt-dir ./data/mot17/val \ - --tracker-dir results/rfdetr-nano \ - --metrics CLEAR HOTA Identity \ - --columns MOTA HOTA IDF1 + --gt_dir ./data/mot17/val \ + --tracker_dir results/rfdetr-nano \ + --metrics '[CLEAR,HOTA,Identity]' \ + --columns '[MOTA,HOTA,IDF1]' ``` **Output:** @@ -198,10 +199,10 @@ COMBINED 25.667 35.735 38.182 ```text trackers eval \ - --gt-dir ./data/mot17/val \ - --tracker-dir results/rfdetr-medium \ - --metrics CLEAR HOTA Identity \ - --columns MOTA HOTA IDF1 + --gt_dir ./data/mot17/val \ + --tracker_dir results/rfdetr-medium \ + --metrics '[CLEAR,HOTA,Identity]' \ + --columns '[MOTA,HOTA,IDF1]' ``` **Output:** diff --git a/docs/learn/download.md b/docs/learn/download.md index 68b12e31..70dba25d 100644 --- a/docs/learn/download.md +++ b/docs/learn/download.md @@ -44,10 +44,10 @@ The table below lists every dataset you can download, along with its splits, ass === "CLI" - Use `--list` to print available datasets, splits, and asset types. + Use `--list_available` to print available datasets, splits, and asset types. ```text - trackers download --list + trackers download --list_available ``` === "Python" @@ -72,7 +72,7 @@ Pass a dataset name to download all of its splits and assets. Download the full MOT17 dataset. ```text - trackers download mot17 + trackers download --dataset mot17 ``` === "Python" @@ -96,15 +96,15 @@ Full datasets can be large. Narrow your download to specific splits and asset ty Use `--split` and `--asset` to filter by split, asset type, or both. ```text - trackers download mot17 --split train --asset annotations + trackers download --dataset mot17 --split train --asset annotations ``` ```text - trackers download mot17 --split train,val --asset annotations,frames + trackers download --dataset mot17 --split train,val --asset annotations,frames ``` ```text - trackers download sportsmot --split val --asset annotations + trackers download --dataset sportsmot --split val --asset annotations ``` === "Python" @@ -152,7 +152,7 @@ Dataset files are extracted to the current directory by default. Set a custom ou Use `--output` to extract into a custom directory. ```text - trackers download mot17 \ + trackers download --dataset mot17 \ --split train,val \ --asset annotations,frames \ --output ./datasets @@ -209,13 +209,13 @@ Every downloaded ZIP is saved to `~/.cache/trackers` and verified with an MD5 ch === "CLI" - Use `--cache-dir` to store ZIPs in a custom location. + Use `--cache_dir` to store ZIPs in a custom location. ```text - trackers download mot17 \ + trackers download --dataset mot17 \ --split train \ --asset annotations \ - --cache-dir ./my-cache + --cache_dir ./my-cache ``` === "Python" @@ -254,12 +254,12 @@ All arguments accepted by the `trackers download` command. - dataset + --dataset Dataset name to download. Options: mot17, sportsmot. — - --list + --list_available List available datasets, splits, and asset types without downloading. false @@ -279,7 +279,7 @@ All arguments accepted by the `trackers download` command. . - --cache-dir + --cache_dir Directory for caching downloaded ZIP files. Cached files are verified by MD5 and reused across runs. ~/.cache/trackers diff --git a/docs/learn/evaluate.md b/docs/learn/evaluate.md index 45cd8eca..821a6348 100644 --- a/docs/learn/evaluate.md +++ b/docs/learn/evaluate.md @@ -41,7 +41,8 @@ Use `trackers download` to pull ground-truth annotations and detections from sup Fetch MOT17 validation annotations and detections from the command line. ```text - trackers download mot17 \ + trackers download \ + --dataset mot17 \ --split val \ --asset annotations,detections \ --output ./data @@ -89,13 +90,13 @@ For more download options, see the [download guide](download.md). Feed the pre-computed detections into a tracker and write the results to a file for evaluation. -Pass `--detections` to provide input detections and `--mot-output` to save the tracker output in MOT format. +Pass `--detection.detections` to provide input detections and `--out.mot_results` to save the tracker output in MOT format. ```text trackers track \ - --detections ./data/mot17/val/MOT17-02-FRCNN/det/det.txt \ + --detection.detections ./data/mot17/val/MOT17-02-FRCNN/det/det.txt \ --tracker bytetrack \ - --mot-output results/MOT17-02-FRCNN.txt + --out.mot_results results/MOT17-02-FRCNN.txt ``` --- @@ -108,8 +109,8 @@ Compare the tracker output against ground truth to compute standard MOT metrics. trackers eval \ --gt ./data/mot17/val/MOT17-02-FRCNN/gt/gt.txt \ --tracker results/MOT17-02-FRCNN.txt \ - --metrics CLEAR HOTA Identity \ - --columns MOTA HOTA IDF1 + --metrics '[CLEAR,HOTA,Identity]' \ + --columns '[MOTA,HOTA,IDF1]' ``` **Output:** @@ -148,10 +149,10 @@ Evaluate all sequences at once and get per-sequence results plus a combined aggr ```text trackers eval \ - --gt-dir ./data/mot17/val \ - --tracker-dir results \ - --metrics CLEAR HOTA Identity \ - --columns MOTA HOTA IDF1 \ + --gt_dir ./data/mot17/val \ + --tracker_dir results \ + --metrics '[CLEAR,HOTA,Identity]' \ + --columns '[MOTA,HOTA,IDF1]' \ --output results.json ``` @@ -204,12 +205,12 @@ All arguments accepted by `trackers eval`. — - --gt-dir + --gt_dir Directory containing ground-truth files for multi-sequence evaluation. — - --tracker-dir + --tracker_dir Directory containing tracker prediction files for multi-sequence evaluation. — diff --git a/docs/learn/track.md b/docs/learn/track.md index 57027705..dc4e8d44 100644 --- a/docs/learn/track.md +++ b/docs/learn/track.md @@ -93,14 +93,14 @@ Trackers assign stable IDs to detections across frames, maintaining object ident === "CLI" - Select a tracker with `--tracker` and tune its behavior with `--tracker.*` arguments. + Select a tracker with `--tracker` and tune its behavior with `--tracker_params.*` arguments. ```text trackers track \ --source source.mp4 \ --tracker bytetrack \ - --tracker.lost_track_buffer 60 \ - --tracker.minimum_consecutive_frames 5 + --tracker_params.lost_track_buffer 60 \ + --tracker_params.minimum_consecutive_frames 5 ``` === "Python" @@ -139,15 +139,15 @@ Trackers don't detect objects—they link detections across frames. A detection === "CLI" - Configure detection with `--model.*` arguments. Filter by confidence and class before tracking. + Configure detection with `--detection.*` arguments. Filter by confidence and class before tracking. ```text trackers track \ --source source.mp4 \ - --model rfdetr-medium \ - --model.confidence 0.3 \ - --model.device cuda \ - --classes person,car + --detection.model rfdetr-medium \ + --detection.confidence 0.3 \ + --detection.device cuda \ + --filters.classes person,car ``` === "Python" @@ -188,10 +188,10 @@ Visualization renders tracking results for debugging, demos, and qualitative eva ```text trackers track \ --source source.mp4 \ - --display \ - --show-labels \ - --show-confidence \ - --show-trajectories + --vis.display \ + --show.labels \ + --show.confidence \ + --show.trajectories ``` === "Python" @@ -274,7 +274,7 @@ Save tracking results as annotated video files or display them in real time. Specify an output path to save annotated video. ```text - trackers track --source source.mp4 --output output.mp4 --overwrite + trackers track --source source.mp4 --out.output output.mp4 --out.overwrite ``` === "Python" @@ -342,37 +342,37 @@ All arguments accepted by the `trackers track` command. — - --output + --out.output Path for output video. If a directory is given, saves as output.mp4 inside it. none - --overwrite + --out.overwrite Allow overwriting existing output files. Without this flag, existing files cause an error. false - --model + --detection.model Model identifier. Pretrained: rfdetr-nano, rfdetr-small, rfdetr-medium, rfdetr-large. Segmentation: rfdetr-seg-*. rfdetr-nano - --model.confidence + --detection.confidence Minimum confidence threshold. Lower values increase recall but may add noise. 0.5 - --model.device + --detection.device Compute device. Options: auto, cpu, cuda, cuda:0, mps. auto - --model.api_key + --detection.api_key Roboflow API key for custom hosted models. none - --classes + --filters.classes Comma-separated class names or IDs to track. Example: person,car or 0,2. all @@ -382,57 +382,62 @@ All arguments accepted by the `trackers track` command. bytetrack - --tracker.lost_track_buffer + --tracker_params.lost_track_buffer Frames to retain a track without detections. Higher values improve occlusion handling but risk ID drift. 30 - --tracker.track_activation_threshold + --tracker_params.track_activation_threshold Minimum confidence to start a new track. Lower values catch more objects but increase false positives. 0.25 - --tracker.minimum_consecutive_frames + --tracker_params.minimum_consecutive_frames Consecutive detections required before a track is confirmed. Suppresses spurious detections. 3 - --tracker.minimum_iou_threshold + --tracker_params.minimum_iou_threshold Minimum IoU overlap to match a detection to an existing track. Higher values require tighter alignment. 0.3 - --display + --tracker_params.iou_variant + IoU similarity metric for data association. Options: iou, giou, diou, ciou, biou. Applies to all trackers. + iou + + + --vis.display Opens a live preview window. Press q or ESC to quit. false - --show-boxes + --show.boxes Draw bounding boxes around tracked objects. true - --show-masks + --show.masks Draw segmentation masks. Only available with rfdetr-seg-* models. false - --show-confidence + --show.confidence Show detection confidence scores in labels. false - --show-labels + --show.labels Show class names in labels. false - --show-ids + --show.ids Show tracker IDs in labels. true - --show-trajectories + --show.trajectories Draw motion trails showing recent positions of each track. false diff --git a/docs/learn/tune.md b/docs/learn/tune.md index fc8d1d03..231d4e56 100644 --- a/docs/learn/tune.md +++ b/docs/learn/tune.md @@ -74,9 +74,9 @@ override the same key in `search_space` if present, and are returned from ```text trackers tune \ --tracker botsort \ - --gt-dir ./data/gt \ - --detections-dir ./data/detections \ - --fixed-params '{"enable_cmc": false}' + --gt_dir ./data/gt \ + --detections_dir ./data/detections \ + --fixed_params '{"enable_cmc": false}' ``` Images are read from `{images_dir}/{sequence}/img1/` using MOT-style stems: @@ -116,11 +116,11 @@ For detections, use `id=-1`. For more details on the format and evaluation workf ```text trackers tune \ --tracker bytetrack \ - --gt-dir ./data/gt \ - --detections-dir ./data/detections \ + --gt_dir ./data/gt \ + --detections_dir ./data/detections \ --objective HOTA \ - --metrics CLEAR HOTA Identity \ - --n-trials 50 \ + --metrics '[CLEAR,HOTA,Identity]' \ + --n_trials 50 \ --output ./results/bytetrack-best.json ``` @@ -163,8 +163,8 @@ MOT17-09-FRCNN ```text trackers tune \ --tracker bytetrack \ - --gt-dir ./data/gt \ - --detections-dir ./data/detections \ + --gt_dir ./data/gt \ + --detections_dir ./data/detections \ --seqmap ./seqmap.txt ``` @@ -226,12 +226,12 @@ All arguments accepted by `trackers tune`. — - --gt-dir + --gt_dir Directory with ground-truth MOT files ({sequence}.txt). — - --detections-dir + --detections_dir Directory with detection MOT files ({sequence}.txt), one file per sequence. — @@ -241,7 +241,7 @@ All arguments accepted by `trackers tune`. HOTA - --n-trials + --n_trials Number of Optuna trials to run. 100 @@ -258,7 +258,7 @@ All arguments accepted by `trackers tune`. --seqmap Optional path to a sequence map file. When set, only listed sequences are tuned. - all files in --detections-dir + all files in --detections_dir --seed @@ -266,7 +266,7 @@ All arguments accepted by `trackers tune`. None - --output, -o + --output Path to save best parameters as JSON. None diff --git a/pyproject.toml b/pyproject.toml index b010f719..be3a557b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "opencv-python>=4.8.0", "rich>=13.0.0", "requests>=2.28.0", + "jsonargparse>=4.48.0", ] [project.optional-dependencies] @@ -48,7 +49,7 @@ detection = ["inference-models>=0.19.0"] tune = ["optuna>=3.0.0"] [project.scripts] -trackers = "trackers.scripts.__main__:main" +trackers = "trackers.cli.__main__:main" [dependency-groups] dev = [ diff --git a/src/trackers/scripts/__init__.py b/src/trackers/cli/__init__.py similarity index 100% rename from src/trackers/scripts/__init__.py rename to src/trackers/cli/__init__.py diff --git a/src/trackers/cli/__main__.py b/src/trackers/cli/__main__.py new file mode 100644 index 00000000..20379935 --- /dev/null +++ b/src/trackers/cli/__main__.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +"""Command-line entry point for the trackers package.""" + +from __future__ import annotations + +import sys +import warnings + +from jsonargparse import CLI, ActionYesNo, ArgumentParser + +from trackers.cli.download import download +from trackers.cli.eval import eval_cmd +from trackers.cli.track import track +from trackers.cli.tune import tune + + +class _BoolFlagParser(ArgumentParser): + """Render plain ``bool`` fields as ``--flag`` / ``--no-flag`` pairs.""" + + def add_argument(self, *args, **kwargs): # type: ignore[override] + if kwargs.get("type") is bool: + kwargs.pop("type") + kwargs["action"] = ActionYesNo(yes_prefix="", no_prefix="no-") + return super().add_argument(*args, **kwargs) + + +def main() -> int: + """Dispatch to track / eval / tune / download via jsonargparse CLI.""" + warnings.warn( + "The trackers CLI is in beta. APIs may change in future releases.", + UserWarning, + stacklevel=2, + ) + rc = CLI( + {"track": track, "eval": eval_cmd, "tune": tune, "download": download}, + as_positional=False, + prog="trackers", + description="Command-line tools for multi-object tracking.", + parser_class=_BoolFlagParser, + ) + return int(rc) if rc is not None else 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/trackers/cli/download.py b/src/trackers/cli/download.py new file mode 100644 index 00000000..73becaf5 --- /dev/null +++ b/src/trackers/cli/download.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +"""``trackers download`` subcommand — fetch benchmark tracking datasets.""" + +from __future__ import annotations + +import sys + +from rich.console import Console +from rich.panel import Panel + +from trackers.datasets.download import _DEFAULT_CACHE_DIR, _DEFAULT_OUTPUT_DIR +from trackers.datasets.manifest import _DATASETS + + +def download( + dataset: str | None = None, + split: str | None = None, + asset: str | None = None, + output: str = _DEFAULT_OUTPUT_DIR, + cache_dir: str = _DEFAULT_CACHE_DIR, + list_available: bool = False, +) -> int: + """Download benchmark tracking datasets from the official trackers bucket. + + Args: + dataset: Dataset name (e.g. ``mot17``, ``sportsmot``). Required unless + ``list_available`` is set. + split: Comma-separated splits to download (e.g. ``train,val,test``). + ``None`` selects every available split. + asset: Comma-separated assets to download (``annotations,frames,detections``). + ``None`` selects every available asset. + output: Output directory. Defaults to the current working directory. + cache_dir: Cache directory for downloaded ZIPs. + list_available: When ``True``, print the available datasets, splits, and + asset types, then exit. + + Returns: + Exit code: ``0`` on success, ``1`` on error. + """ + if list_available: + _print_available() + return 0 + + if not dataset: + print("Please specify a dataset name or use --list_available.", file=sys.stderr) + return 1 + + from trackers.datasets.download import download_dataset + + split_list = [s.strip() for s in split.split(",")] if split else None + asset_list = [a.strip() for a in asset.split(",")] if asset else None + + try: + download_dataset( + dataset=dataset, + split=split_list, + asset=asset_list, + output=output, + cache_dir=cache_dir, + ) + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + return 1 + + return 0 + + +def _print_available() -> None: + """Print available datasets, splits, and asset types.""" + console = Console() + for name, dataset_info in _DATASETS.items(): + description = dataset_info.get("description", "") + splits_dict: dict[str, dict] = dataset_info.get("splits", {}) + + max_split_len = max(len(s) for s in splits_dict) if splits_dict else 0 + split_lines = [ + f"{split:<{max_split_len}} {', '.join(assets.keys())}" for split, assets in splits_dict.items() + ] + + body = f"{description}\n\n" + "\n".join(split_lines) + console.print(Panel(body, title=name.value, title_align="left")) + console.print() diff --git a/src/trackers/cli/eval.py b/src/trackers/cli/eval.py new file mode 100644 index 00000000..cd1c2576 --- /dev/null +++ b/src/trackers/cli/eval.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +"""``trackers eval`` subcommand — evaluate tracker predictions against ground truth.""" + +from __future__ import annotations + +import logging +import sys +from pathlib import Path + + +def eval_cmd( + gt: Path | None = None, + tracker: Path | None = None, + gt_dir: Path | None = None, + tracker_dir: Path | None = None, + seqmap: Path | None = None, + metrics: list[str] | None = None, + threshold: float = 0.5, + columns: list[str] | None = None, + output: Path | None = None, +) -> int: + """Evaluate tracker predictions against ground-truth MOT files. + + Two modes: + + - Single sequence: pass ``gt`` and ``tracker``. + - Benchmark: pass ``gt_dir`` and ``tracker_dir`` (with optional ``seqmap``). + + Args: + gt: Ground-truth file (MOT format) for single-sequence mode. + tracker: Tracker predictions file (MOT format) for single-sequence mode. + gt_dir: Directory of ground-truth files for benchmark mode. + tracker_dir: Directory of tracker prediction files for benchmark mode. + seqmap: Sequence map listing sequences to evaluate. + metrics: Metrics to compute. Options: ``CLEAR``, ``HOTA``, ``Identity``. + Defaults to ``["CLEAR"]``. + threshold: IoU threshold for CLEAR and Identity matching. + columns: Metric columns to display. ``None`` auto-selects from + available metrics. + output: Output JSON file for results. + + Returns: + Exit code: ``0`` on success, ``1`` on error. + """ + metrics = metrics or ["CLEAR"] + + logging.basicConfig( + level=logging.INFO, + format="%(message)s", + handlers=[logging.StreamHandler(sys.stderr)], + ) + + single_mode = gt is not None and tracker is not None + benchmark_mode = gt_dir is not None and tracker_dir is not None + + if not single_mode and not benchmark_mode: + print("Error: Must specify either --gt/--tracker or --gt_dir/--tracker_dir", file=sys.stderr) + return 1 + + if single_mode and benchmark_mode: + print("Error: Cannot use both single sequence and benchmark mode", file=sys.stderr) + return 1 + + from trackers.eval import evaluate_mot_sequence, evaluate_mot_sequences + + try: + if single_mode: + assert gt is not None and tracker is not None # noqa: S101 — narrows for type checker + seq_result = evaluate_mot_sequence( + gt_path=gt, + tracker_path=tracker, + metrics=metrics, + threshold=threshold, + ) + print(seq_result.table(columns=columns)) + if output: + output.parent.mkdir(parents=True, exist_ok=True) + output.write_text(seq_result.json()) + print(f"\nResults saved to: {output}") + else: + assert gt_dir is not None and tracker_dir is not None # noqa: S101 — narrows for type checker + bench_result = evaluate_mot_sequences( + gt_dir=gt_dir, + tracker_dir=tracker_dir, + seqmap=seqmap, + metrics=metrics, + threshold=threshold, + ) + print(bench_result.table(columns=columns)) + if output: + bench_result.save(output) + print(f"\nResults saved to: {output}") + except (FileNotFoundError, ValueError) as e: + print(f"Error: {e}", file=sys.stderr) + return 1 + + return 0 diff --git a/src/trackers/scripts/progress.py b/src/trackers/cli/progress.py similarity index 100% rename from src/trackers/scripts/progress.py rename to src/trackers/cli/progress.py diff --git a/src/trackers/cli/track.py b/src/trackers/cli/track.py new file mode 100644 index 00000000..83672c6d --- /dev/null +++ b/src/trackers/cli/track.py @@ -0,0 +1,651 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +"""``trackers track`` subcommand — run a detector + tracker over a video source.""" + +from __future__ import annotations + +import sys +import warnings +from contextlib import nullcontext +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np +import supervision as sv + +from trackers import frames_from_source +from trackers.cli.progress import _classify_source, _SourceInfo, _TrackingProgress +from trackers.core.base import BaseTracker +from trackers.io.mot import _mot_frame_to_detections, _MOTOutput, load_mot_file +from trackers.io.paths import _resolve_video_output_path, _validate_output_path +from trackers.io.video import _DEFAULT_OUTPUT_FPS, _DisplayWindow, _VideoOutput +from trackers.utils.device import _best_device +from trackers.utils.iou import variant_from_name + +if TYPE_CHECKING: + from inference_models import AnyModel + +# Defaults +DEFAULT_MODEL = "rfdetr-nano" +DEFAULT_TRACKER = "bytetrack" +DEFAULT_CONFIDENCE = 0.5 +DEFAULT_DEVICE = "auto" + +# Visualization +COLOR_PALETTE = sv.ColorPalette.from_hex( + [ + "#ffff00", + "#ff9b00", + "#ff8080", + "#ff66b2", + "#ff66ff", + "#b266ff", + "#9999ff", + "#3399ff", + "#66ffff", + "#33ff99", + "#66ff66", + "#99ff00", + ] +) + + +@dataclass +class DetectionOptions: + """Detection model and inference settings. + + Attributes: + model: Model ID (e.g. ``rfdetr-nano``) or + ``workspace/project/version`` for a Roboflow custom model. + Ignored when ``detections`` is set. + detections: Path to a pre-computed MOT-format detections file. + Mutually exclusive with ``model``; supply one or the other. + confidence: Detection confidence threshold. + device: Inference device: ``auto``, ``cpu``, ``cuda``, ``cuda:0``, + ``mps``. + api_key: Roboflow API key (required for private custom models). + """ + + model: str = DEFAULT_MODEL + detections: Path | None = None + confidence: float = DEFAULT_CONFIDENCE + device: str = DEFAULT_DEVICE + api_key: str | None = None + + +@dataclass +class FilteringOptions: + """Detection and track filters. + + Attributes: + classes: Comma-separated class names or IDs to keep + (e.g. ``person,car`` or ``0,2``). + track_ids: Comma-separated track IDs to keep in the output + (e.g. ``1,3,5``). + """ + + classes: str | None = None + track_ids: str | None = None + + +@dataclass +class OutputOptions: + """Output paths and write options. + + Attributes: + output: Annotated-video output path. + mot_results: MOT-format predictions output path. + overwrite: Overwrite existing output files without prompting. + """ + + output: Path | None = None + mot_results: Path | None = None + overwrite: bool = False + + +@dataclass +class VisualizationOptions: + """Live preview and display settings. + + Attributes: + display: Show a live preview window during tracking. + """ + + display: bool = False + + +@dataclass +class ShowOptions: + """Annotation elements to draw on each frame. + + Attributes: + boxes: Draw bounding boxes around detections. + masks: Draw segmentation masks (segmentation models only). + labels: Draw class labels. + ids: Draw track IDs. + confidence: Draw detection confidence scores. + trajectories: Draw track trajectory trails. + """ + + boxes: bool = True + masks: bool = False + labels: bool = False + ids: bool = True + confidence: bool = False + trajectories: bool = False + + +@dataclass +class TrackerParams: + """Optional tracker-specific parameters. + + Union of parameters across all registered trackers; each tracker only + receives the keys it knows about. Fields left as ``None`` are dropped + before instantiation so the tracker's own defaults apply. + + Attributes: + lost_track_buffer: Frames to keep a lost track before discarding. + frame_rate: Source frame rate for time-based logic. + track_activation_threshold: Detection score needed to spawn a track. + minimum_consecutive_frames: Consecutive matches to confirm a track. + minimum_iou_threshold: IoU threshold for SORT/OC-SORT association. + minimum_iou_threshold_first_assoc: BoT-SORT first-stage IoU. + minimum_iou_threshold_second_assoc: BoT-SORT second-stage IoU. + minimum_iou_threshold_unconfirmed_assoc: BoT-SORT unconfirmed IoU. + high_conf_det_threshold: High-confidence detection threshold. + direction_consistency_weight: OC-SORT direction consistency weight. + delta_t: OC-SORT velocity delta horizon. + enable_cmc: BoT-SORT camera motion compensation toggle. + cmc_method: BoT-SORT CMC method name. + cmc_downscale: BoT-SORT CMC downscale factor. + instant_first_frame_activation: BoT-SORT first-frame activation toggle. + iou_variant: IoU similarity metric for data association. One of + ``iou`` (standard), ``giou``, ``diou``, ``ciou``, ``biou``. + Applies to all trackers. Defaults to ``iou``. + """ + + lost_track_buffer: int | None = None + frame_rate: float | None = None + track_activation_threshold: float | None = None + minimum_consecutive_frames: int | None = None + minimum_iou_threshold: float | None = None + minimum_iou_threshold_first_assoc: float | None = None + minimum_iou_threshold_second_assoc: float | None = None + minimum_iou_threshold_unconfirmed_assoc: float | None = None + high_conf_det_threshold: float | None = None + direction_consistency_weight: float | None = None + delta_t: int | None = None + enable_cmc: bool | None = None + cmc_method: str | None = None + cmc_downscale: int | None = None + instant_first_frame_activation: bool | None = None + iou_variant: str | None = None + + +def track( + source: str | None = None, + detection: DetectionOptions | None = None, + filters: FilteringOptions | None = None, + tracker: str = DEFAULT_TRACKER, + tracker_params: TrackerParams | None = None, + out: OutputOptions | None = None, + vis: VisualizationOptions | None = None, + show: ShowOptions | None = None, +) -> int: + """Run detection and tracking over a video, webcam, RTSP, or image directory. + + Args: + source: Video file, webcam index (e.g. ``"0"``), RTSP URL, or image + directory. Required unless ``detection.detections`` is supplied. + detection: Detection model and inference options. + filters: Class and track-ID filters applied to detections and tracks. + tracker: Tracking algorithm ID. Discoverable via + ``BaseTracker._registered_trackers()``. + tracker_params: Optional tracker parameter overrides; only fields + matching the chosen tracker's ``__init__`` are forwarded. + out: Output path and overwrite options. + vis: Live preview and display options. + show: Annotation elements to draw on each frame. + + Returns: + Exit code: ``0`` on success, ``1`` on validation error. + """ + if detection is None: + detection = DetectionOptions() + if filters is None: + filters = FilteringOptions() + if out is None: + out = OutputOptions() + if vis is None: + vis = VisualizationOptions() + if show is None: + show = ShowOptions() + model = detection.model + detections = detection.detections + confidence = detection.confidence + device = detection.device + api_key = detection.api_key + classes = filters.classes + track_ids = filters.track_ids + output = out.output + mot_results = out.mot_results + overwrite = out.overwrite + display = vis.display + show_boxes = show.boxes + show_masks = show.masks + show_labels = show.labels + show_ids = show.ids + show_confidence = show.confidence + show_trajectories = show.trajectories + + needs_frames = output is not None or display + + if source is None and detections is None: + print("Error: --source is required when not using --detections.", file=sys.stderr) + return 1 + if needs_frames and source is None: + print("Error: --source is required when using --output or --display.", file=sys.stderr) + return 1 + + if output: + _validate_output_path(_resolve_video_output_path(output), overwrite=overwrite) + if mot_results: + _validate_output_path(mot_results, overwrite=overwrite) + + if detections is not None: + model_obj: AnyModel | None = None + detections_data: dict | None = load_mot_file(detections) + class_names: list[str] = [] + else: + model_obj = _init_model(model, device=device, api_key=api_key) + detections_data = None + class_names = getattr(model_obj, "class_names", []) + + class_filter = _resolve_class_filter(classes, class_names) + track_id_filter = _resolve_track_id_filter(track_ids) + try: + tracker_obj = _init_tracker(tracker, tracker_params) + except ValueError as e: + print(f"Error: {e}", file=sys.stderr) + return 1 + + if source is not None: + return _run_with_source( + source=source, + model=model_obj, + confidence=confidence, + detections_data=detections_data, + class_names=class_names, + class_filter=class_filter, + track_id_filter=track_id_filter, + tracker=tracker_obj, + output=output, + mot_results=mot_results, + display=display, + show_boxes=show_boxes, + show_masks=show_masks, + show_labels=show_labels, + show_ids=show_ids, + show_confidence=show_confidence, + show_trajectories=show_trajectories, + ) + + return _run_frameless( + detections_data=detections_data, + class_filter=class_filter, + track_id_filter=track_id_filter, + tracker=tracker_obj, + mot_results=mot_results, + ) + + +def _run_frameless( + *, + detections_data: dict | None, + class_filter: list[int] | None, + track_id_filter: list[int] | None, + tracker: BaseTracker, + mot_results: Path | None, +) -> int: + """Run tracking from pre-computed detections without a frame source.""" + if not detections_data: + print("Error: No detections found in file.", file=sys.stderr) + return 1 + + total_frames = max(detections_data.keys()) + source_info = _SourceInfo(source_type="video", total_frames=total_frames) + + try: + with _MOTOutput(mot_results) as mot, _TrackingProgress(source_info) as progress: + for frame_idx in range(1, total_frames + 1): + if frame_idx in detections_data: + dets = _mot_frame_to_detections(detections_data[frame_idx]) + else: + dets = sv.Detections.empty() + + if class_filter is not None and len(dets) > 0 and dets.class_id is not None: + mask = np.isin(dets.class_id, class_filter) + dets = dets[mask] # type: ignore[assignment] + + tracked = tracker.update(dets) + + if track_id_filter is not None and len(tracked) > 0 and tracked.tracker_id is not None: + mask = np.isin(tracked.tracker_id.astype(int), track_id_filter) + tracked = tracked[mask] # type: ignore[assignment] + + mot.write(frame_idx, tracked) + progress.update() + + progress.complete(interrupted=False) + except KeyboardInterrupt: + pass + + return 0 + + +def _run_with_source( + *, + source: str, + model: AnyModel | None, + confidence: float, + detections_data: dict | None, + class_names: list[str], + class_filter: list[int] | None, + track_id_filter: list[int] | None, + tracker: BaseTracker, + output: Path | None, + mot_results: Path | None, + display: bool, + show_boxes: bool, + show_masks: bool, + show_labels: bool, + show_ids: bool, + show_confidence: bool, + show_trajectories: bool, +) -> int: + """Run tracking with a frame source (video, webcam, images).""" + frame_gen = frames_from_source(source) + source_info = _classify_source(source) + + annotators, label_annotator = _init_annotators( + show_boxes=show_boxes, + show_masks=show_masks, + show_labels=show_labels, + show_ids=show_ids, + show_confidence=show_confidence, + ) + trace_annotator = ( + sv.TraceAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.TRACK) if show_trajectories else None + ) + display_ctx = _DisplayWindow() if display else nullcontext() + + try: + with ( + _VideoOutput(output, fps=source_info.fps or _DEFAULT_OUTPUT_FPS) as video, + _MOTOutput(mot_results) as mot, + display_ctx as display_win, + _TrackingProgress(source_info) as progress, + ): + interrupted = False + for frame_idx, frame in frame_gen: + if model is not None: + dets = _run_model(model, frame, confidence) + elif detections_data is not None and frame_idx in detections_data: + dets = _mot_frame_to_detections(detections_data[frame_idx]) + else: + dets = sv.Detections.empty() + + if class_filter is not None and len(dets) > 0 and dets.class_id is not None: + mask = np.isin(dets.class_id, class_filter) + dets = dets[mask] # type: ignore[assignment] + + tracked = tracker.update(dets, frame) + + if track_id_filter is not None and len(tracked) > 0 and tracked.tracker_id is not None: + mask = np.isin(tracked.tracker_id.astype(int), track_id_filter) + tracked = tracked[mask] # type: ignore[assignment] + + mot.write(frame_idx, tracked) + progress.update() + + if display or output: + annotated = frame.copy() + if trace_annotator is not None: + annotated = trace_annotator.annotate(annotated, tracked) + for ann in annotators: + annotated = ann.annotate(annotated, tracked) + if label_annotator is not None: + labeled = tracked[tracked.tracker_id != -1] + labels = _format_labels( + labeled, + class_names, + show_ids=show_ids, + show_labels=show_labels, + show_confidence=show_confidence, + ) + annotated = label_annotator.annotate(annotated, labeled, labels) + + video.write(annotated) + + if display_win is not None: + display_win.show(annotated) + if display_win.quit_requested: + interrupted = True + break + + progress.complete(interrupted=interrupted) + except KeyboardInterrupt: + pass + + return 0 + + +def _resolve_track_id_filter(track_ids_arg: str | None) -> list[int] | None: + """Resolve a comma-separated ``track_ids`` string to a list of integer IDs. + + Args: + track_ids_arg: Raw ``--track_ids`` string (e.g. ``"1,3,5"``). ``None`` + means no filter. + + Returns: + List of integer track IDs, or ``None`` when no valid filter remains. + """ + if not track_ids_arg: + return None + + track_ids: list[int] = [] + for raw in track_ids_arg.split(","): + token = raw.strip() + try: + track_ids.append(int(token)) + except ValueError: + print(f"Warning: '{token}' is not a valid track ID, skipping.", file=sys.stderr) + return track_ids or None + + +def _resolve_class_filter(classes_arg: str | None, class_names: list[str]) -> list[int] | None: + """Resolve a comma-separated ``classes`` string to a list of integer IDs. + + Each token is checked independently: if it parses as an ``int`` it is used + directly as a class ID; otherwise it is looked up by name in ``class_names``. + Unknown names are printed as warnings and skipped. + + Args: + classes_arg: Raw ``--classes`` string (e.g. ``"person,car"`` or + ``"0,2"`` or ``"person,2"``). ``None`` means no filter. + class_names: Ordered list of class names where the index equals the + class ID (as provided by the model). + + Returns: + List of integer class IDs, or ``None`` when no valid filter remains. + """ + if not classes_arg: + return None + + name_to_id = {name: i for i, name in enumerate(class_names)} + class_filter: list[int] = [] + for raw in classes_arg.split(","): + token = raw.strip() + try: + class_filter.append(int(token)) + except ValueError: + if token in name_to_id: + class_filter.append(name_to_id[token]) + else: + print(f"Warning: class '{token}' not found in model class list, skipping.", file=sys.stderr) + return class_filter or None + + +def _init_model(model_id: str, *, device: str = DEFAULT_DEVICE, api_key: str | None = None) -> AnyModel: + """Load detection model via ``inference-models``. + + Args: + model_id: Model identifier (e.g. ``rfdetr-nano`` or + ``workspace/project/version``). + device: Device to load model on (``auto``, ``cpu``, ``cuda``, ``mps``). + api_key: Roboflow API key for custom models. + + Returns: + Loaded model instance. + """ + try: + from inference_models import AutoModel + except ImportError as e: + print( + "Error: inference-models is required for model-based detection.\n" + "Install with: pip install 'trackers[detection]'", + file=sys.stderr, + ) + raise SystemExit(1) from e + + resolved_device = _best_device() if device == DEFAULT_DEVICE else device + return AutoModel.from_pretrained(model_id, api_key=api_key, device=resolved_device) + + +def _run_model(model: AnyModel, frame: np.ndarray, confidence: float) -> sv.Detections: + """Run model inference, filter by confidence, return ``sv.Detections``.""" + predictions = model(frame) + if not predictions: + return sv.Detections.empty() + + dets = predictions[0].to_supervision() + if len(dets) > 0 and dets.confidence is not None: + dets = dets[dets.confidence >= confidence] + return dets + + +def _init_tracker(tracker_id: str, params: TrackerParams | None) -> BaseTracker: + """Create a tracker instance from the registry. + + Only fields the chosen tracker accepts are forwarded; ``None`` values are + always dropped so the tracker's own defaults apply. + + Args: + tracker_id: Registered tracker name (e.g. ``bytetrack``, ``sort``). + params: Optional tracker parameter overrides. + + Returns: + Initialised tracker instance. + + Raises: + ValueError: If ``tracker_id`` is not registered. + """ + info = BaseTracker._lookup_tracker(tracker_id) + if info is None: + available = ", ".join(BaseTracker._registered_trackers()) + raise ValueError(f"Unknown tracker: '{tracker_id}'. Available: {available}") + + raw = asdict(params) if params is not None else {} + iou_variant = raw.pop("iou_variant", None) + accepted = set(info.parameters) + kwargs = {k: v for k, v in raw.items() if v is not None and k in accepted} + if iou_variant is not None: + if "iou" in accepted: + kwargs["iou"] = variant_from_name(iou_variant) + else: + warnings.warn( + f"Tracker '{tracker_id}' does not support iou_variant; '{iou_variant}' will be ignored.", + UserWarning, + stacklevel=2, + ) + return info.tracker_class(**kwargs) + + +def _init_annotators( + show_boxes: bool = False, + show_masks: bool = False, + show_labels: bool = False, + show_ids: bool = False, + show_confidence: bool = False, +) -> tuple[list, sv.LabelAnnotator | None]: + """Initialise supervision annotators based on display options. + + Args: + show_boxes: Create ``BoxAnnotator``. + show_masks: Create ``MaskAnnotator``. + show_labels: Include class labels (triggers ``LabelAnnotator``). + show_ids: Include track IDs (triggers ``LabelAnnotator``). + show_confidence: Include confidence scores (triggers ``LabelAnnotator``). + + Returns: + Tuple of (annotators list, label_annotator or None). Label annotator is + separate because it needs custom labels per frame. + """ + annotators: list = [] + label_annotator: sv.LabelAnnotator | None = None + + if show_boxes: + annotators.append(sv.BoxAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.TRACK)) + if show_masks: + annotators.append(sv.MaskAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.TRACK)) + if show_labels or show_ids or show_confidence: + label_annotator = sv.LabelAnnotator( + color=COLOR_PALETTE, + text_color=sv.Color.BLACK, + text_position=sv.Position.TOP_LEFT, + color_lookup=sv.ColorLookup.TRACK, + ) + return annotators, label_annotator + + +def _format_labels( + detections: sv.Detections, + class_names: list[str], + *, + show_ids: bool = False, + show_labels: bool = False, + show_confidence: bool = False, +) -> list[str]: + """Generate label strings for each detection. + + Args: + detections: Detections to generate labels for. + class_names: List of class names for lookup. + show_ids: Include tracker IDs in labels. + show_labels: Include class names in labels. + show_confidence: Include confidence scores in labels. + + Returns: + List of label strings, one per detection. + """ + labels = [] + for i in range(len(detections)): + parts: list[str] = [] + if show_ids and detections.tracker_id is not None: + parts.append(f"#{int(detections.tracker_id[i])}") + if show_labels and detections.class_id is not None: + class_id = int(detections.class_id[i]) + if class_names and 0 <= class_id < len(class_names): + parts.append(class_names[class_id]) + else: + parts.append(str(class_id)) + if show_confidence and detections.confidence is not None: + parts.append(f"{detections.confidence[i]:.2f}") + labels.append(" ".join(parts)) + return labels diff --git a/src/trackers/cli/tune.py b/src/trackers/cli/tune.py new file mode 100644 index 00000000..45edaf43 --- /dev/null +++ b/src/trackers/cli/tune.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +"""``trackers tune`` subcommand — Optuna hyperparameter optimisation.""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path + + +def tune( + tracker: str, + gt_dir: Path, + detections_dir: Path, + objective: str = "HOTA", + n_trials: int = 100, + metrics: list[str] | None = None, + threshold: float = 0.5, + seqmap: Path | None = None, + fixed_params: dict | None = None, + images_dir: Path | None = None, + enqueue_defaults: bool = True, + seed: int | None = None, + output: Path | None = None, +) -> int: + """Tune tracker hyperparameters using Optuna. + + Args: + tracker: Tracker ID to tune (e.g. ``bytetrack``, ``sort``, ``ocsort``). + gt_dir: Directory of ground-truth MOT files. + detections_dir: Directory of pre-computed detection files in MOT flat + format (one ``{seq}.txt`` per sequence). + objective: Scalar metric to maximise. Options: ``MOTA``, ``HOTA``, + ``IDF1``. + n_trials: Number of Optuna trials to run. + metrics: Metric families to compute. Options: ``CLEAR``, ``HOTA``, + ``Identity``. Default: ``["CLEAR"]``. The family required by + ``objective`` is added automatically if missing. + threshold: IoU threshold for CLEAR and Identity matching. + seqmap: Sequence map file listing sequences to evaluate. + fixed_params: Tracker ``__init__`` kwargs held constant across all + trials (e.g. ``{"enable_cmc": False}``). + images_dir: MOT-style image root for frame-based features such as CMC. + Frames are read from ``{images_dir}/{sequence}/img1/``. + enqueue_defaults: When ``True`` (default), the first trial evaluates + the tracker's default parameters before Optuna sampling begins. + seed: Random seed for Optuna's TPE sampler for reproducible runs. + output: Output JSON file for best parameters. + + Returns: + Exit code: ``0`` on success, ``1`` on error. + """ + if metrics is None: + metrics = ["CLEAR"] + + from trackers.tune import Tuner + + try: + tuner = Tuner( + tracker_id=tracker, + gt_dir=gt_dir, + detections_dir=detections_dir, + metrics=metrics, + objective=objective, + n_trials=n_trials, + threshold=threshold, + seqmap=seqmap, + fixed_params=fixed_params, + images_dir=images_dir, + enqueue_defaults=enqueue_defaults, + seed=seed, + ) + except (ValueError, ImportError, FileNotFoundError) as e: + print(str(e), file=sys.stderr) + return 1 + + try: + best_params = tuner.run() + except Exception as e: + print(f"Error during tuning: {e}", file=sys.stderr) + return 1 + + print(f"\nBest parameters for {tracker}:") + for name, value in best_params.items(): + print(f" {name}: {value}") + if tuner.study is not None: + print(f"\nBest {objective}: {tuner.study.best_value:.4f}") + + if output: + try: + output.parent.mkdir(parents=True, exist_ok=True) + output.write_text(json.dumps(best_params, indent=2)) + except OSError as e: + print(f"Error writing output: {e}", file=sys.stderr) + return 1 + print(f"\nResults saved to: {output}") + + return 0 diff --git a/src/trackers/scripts/__main__.py b/src/trackers/scripts/__main__.py deleted file mode 100644 index 0993f8c7..00000000 --- a/src/trackers/scripts/__main__.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python -# ------------------------------------------------------------------------ -# Trackers -# Copyright (c) 2026 Roboflow. All Rights Reserved. -# Licensed under the Apache License, Version 2.0 [see LICENSE for details] -# ------------------------------------------------------------------------ - -from __future__ import annotations - -import argparse -import sys -import warnings - - -def main() -> int: - """Main entry point for the trackers CLI.""" - # Beta warning - warnings.warn( - "The trackers CLI is in beta. APIs may change in future releases.", - UserWarning, - stacklevel=2, - ) - - parser = argparse.ArgumentParser( - prog="trackers", - description="Command-line tools for multi-object tracking.", - epilog="For more information, visit: https://github.com/roboflow/trackers", - ) - parser.add_argument( - "--version", - action="store_true", - help="Show version and exit.", - ) - - subparsers = parser.add_subparsers( - dest="command", - title="commands", - description="Available commands:", - ) - - # Import and register subcommands - from trackers.scripts.download import add_download_subparser - from trackers.scripts.eval import add_eval_subparser - from trackers.scripts.track import add_track_subparser - from trackers.scripts.tune import add_tune_subparser - - add_download_subparser(subparsers) - add_eval_subparser(subparsers) - add_track_subparser(subparsers) - add_tune_subparser(subparsers) - - # Parse arguments - args = parser.parse_args() - - if args.version: - from importlib.metadata import version - - print(f"trackers {version('trackers')}") - return 0 - - if args.command is None: - parser.print_help() - return 0 - - # Execute the command - return args.func(args) - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/src/trackers/scripts/download.py b/src/trackers/scripts/download.py deleted file mode 100644 index de8e461f..00000000 --- a/src/trackers/scripts/download.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python -# ------------------------------------------------------------------------ -# Trackers -# Copyright (c) 2026 Roboflow. All Rights Reserved. -# Licensed under the Apache License, Version 2.0 [see LICENSE for details] -# ------------------------------------------------------------------------ - -from __future__ import annotations - -import argparse -import sys - -from rich.console import Console -from rich.panel import Panel - -from trackers.datasets.download import _DEFAULT_CACHE_DIR, _DEFAULT_OUTPUT_DIR -from trackers.datasets.manifest import _DATASETS - - -def add_download_subparser( - subparsers: argparse._SubParsersAction, -) -> None: - """Add the download subcommand to the argument parser.""" - parser = subparsers.add_parser( - "download", - help="Download benchmark tracking datasets.", - description="Download tracking datasets from the official trackers bucket.", - ) - - parser.add_argument( - "--list", - action="store_true", - help="List available datasets, splits, and asset types.", - ) - parser.add_argument( - "dataset", - nargs="?", - help="Dataset name (e.g. mot17, sportsmot).", - ) - parser.add_argument( - "--split", - help="Comma-separated splits to download (e.g. train,val,test). " - "If omitted, all available splits are downloaded.", - ) - parser.add_argument( - "--asset", - help="Comma-separated assets to download: annotations,frames,detections. " - "If omitted, all available assets are downloaded.", - ) - parser.add_argument( - "-o", - "--output", - default=_DEFAULT_OUTPUT_DIR, - help="Output directory (default: current directory).", - ) - parser.add_argument( - "--cache-dir", - default=_DEFAULT_CACHE_DIR, - help="Cache directory for downloaded ZIPs (default: ~/.cache/trackers).", - ) - - parser.set_defaults(func=_run_download) - - -def _run_download(args: argparse.Namespace) -> int: - """Execute the download subcommand.""" - if args.list: - _print_available() - return 0 - - if not args.dataset: - print("Please specify a dataset name or use --list.", file=sys.stderr) - return 1 - - from trackers.datasets.download import download_dataset - - split_list = [s.strip() for s in args.split.split(",")] if args.split else None - asset_list = [a.strip() for a in args.asset.split(",")] if args.asset else None - - try: - download_dataset( - dataset=args.dataset, - split=split_list, - asset=asset_list, - output=args.output, - cache_dir=args.cache_dir, - ) - except Exception as e: - print(f"Error: {e}", file=sys.stderr) - return 1 - - return 0 - - -def _print_available() -> None: - """Print available datasets, splits, and asset types.""" - console = Console() - for name, dataset_info in _DATASETS.items(): - description = dataset_info.get("description", "") - splits_dict: dict[str, dict] = dataset_info.get("splits", {}) - - max_split_len = max(len(s) for s in splits_dict) if splits_dict else 0 - split_lines = [ - f"{split:<{max_split_len}} {', '.join(assets.keys())}" for split, assets in splits_dict.items() - ] - - body = f"{description}\n\n" + "\n".join(split_lines) - console.print(Panel(body, title=name.value, title_align="left")) - console.print() diff --git a/src/trackers/scripts/eval.py b/src/trackers/scripts/eval.py deleted file mode 100644 index 7bd25f21..00000000 --- a/src/trackers/scripts/eval.py +++ /dev/null @@ -1,169 +0,0 @@ -#!/usr/bin/env python -# ------------------------------------------------------------------------ -# Trackers -# Copyright (c) 2026 Roboflow. All Rights Reserved. -# Licensed under the Apache License, Version 2.0 [see LICENSE for details] -# ------------------------------------------------------------------------ - -from __future__ import annotations - -import argparse -import logging -import sys -from pathlib import Path - - -def add_eval_subparser(subparsers: argparse._SubParsersAction) -> None: - """Add the eval subcommand to the argument parser.""" - parser = subparsers.add_parser( - "eval", - help="Evaluate tracker predictions against ground truth.", - description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - - # Single sequence mode - single_group = parser.add_argument_group("single sequence evaluation") - single_group.add_argument( - "--gt", - type=Path, - metavar="PATH", - help="Path to ground truth file (MOT format).", - ) - single_group.add_argument( - "--tracker", - type=Path, - metavar="PATH", - help="Path to tracker predictions file (MOT format).", - ) - - # Benchmark mode - bench_group = parser.add_argument_group("benchmark evaluation") - bench_group.add_argument( - "--gt-dir", - type=Path, - metavar="DIR", - help="Directory containing ground truth files.", - ) - bench_group.add_argument( - "--tracker-dir", - type=Path, - metavar="DIR", - help="Directory containing tracker prediction files.", - ) - bench_group.add_argument( - "--seqmap", - type=Path, - metavar="PATH", - help="Sequence map file listing sequences to evaluate.", - ) - - # Common options - parser.add_argument( - "--metrics", - nargs="+", - default=["CLEAR"], - choices=["CLEAR", "HOTA", "Identity"], - help="Metrics to compute. Default: CLEAR. Options: CLEAR, HOTA, Identity", - ) - parser.add_argument( - "--threshold", - type=float, - default=0.5, - help="IoU threshold for CLEAR and Identity matching. Default: 0.5", - ) - parser.add_argument( - "--columns", - nargs="+", - default=None, - metavar="COL", - help=( - "Metric columns to display. Default: auto-selected based on metrics. " - "CLEAR: MOTA, MOTP, MODA, CLR_Re, CLR_Pr, MTR, PTR, MLR, sMOTA, " - "CLR_TP, CLR_FN, CLR_FP, IDSW, MT, PT, ML, Frag. " - "HOTA: HOTA, DetA, AssA, DetRe, DetPr, AssRe, AssPr, LocA. " - "Identity: IDF1, IDR, IDP, IDTP, IDFN, IDFP" - ), - ) - parser.add_argument( - "--output", - "-o", - type=Path, - metavar="PATH", - help="Output file for results (JSON format).", - ) - - parser.set_defaults(func=run_eval) - - -def run_eval(args: argparse.Namespace) -> int: - """Execute the eval command.""" - # Configure logging to show detection info - logging.basicConfig( - level=logging.INFO, - format="%(message)s", - handlers=[logging.StreamHandler(sys.stderr)], - ) - - # Validate arguments - single_mode = args.gt is not None and args.tracker is not None - benchmark_mode = args.gt_dir is not None and args.tracker_dir is not None - - if not single_mode and not benchmark_mode: - print( - "Error: Must specify either --gt/--tracker or --gt-dir/--tracker-dir", - file=sys.stderr, - ) - return 1 - - if single_mode and benchmark_mode: - print( - "Error: Cannot use both single sequence and benchmark mode", - file=sys.stderr, - ) - return 1 - - # Columns: None means auto-select based on available metrics - columns = args.columns - - # Import evaluation functions - from trackers.eval import evaluate_mot_sequence, evaluate_mot_sequences - - try: - if single_mode: - seq_result = evaluate_mot_sequence( - gt_path=args.gt, - tracker_path=args.tracker, - metrics=args.metrics, - threshold=args.threshold, - ) - print(seq_result.table(columns=columns)) - - # Save results if output specified - if args.output: - args.output.parent.mkdir(parents=True, exist_ok=True) - args.output.write_text(seq_result.json()) - print(f"\nResults saved to: {args.output}") - else: - bench_result = evaluate_mot_sequences( - gt_dir=args.gt_dir, - tracker_dir=args.tracker_dir, - seqmap=args.seqmap, - metrics=args.metrics, - threshold=args.threshold, - ) - print(bench_result.table(columns=columns)) - - # Save results if output specified - if args.output: - bench_result.save(args.output) - print(f"\nResults saved to: {args.output}") - - except FileNotFoundError as e: - print(f"Error: {e}", file=sys.stderr) - return 1 - except ValueError as e: - print(f"Error: {e}", file=sys.stderr) - return 1 - - return 0 diff --git a/src/trackers/scripts/track.py b/src/trackers/scripts/track.py deleted file mode 100644 index 539a3a23..00000000 --- a/src/trackers/scripts/track.py +++ /dev/null @@ -1,738 +0,0 @@ -#!/usr/bin/env python -# ------------------------------------------------------------------------ -# Trackers -# Copyright (c) 2026 Roboflow. All Rights Reserved. -# Licensed under the Apache License, Version 2.0 [see LICENSE for details] -# ------------------------------------------------------------------------ - -from __future__ import annotations - -import argparse -import sys -from contextlib import nullcontext -from pathlib import Path -from typing import TYPE_CHECKING - -import numpy as np -import supervision as sv - -from trackers import frames_from_source -from trackers.core.base import BaseTracker -from trackers.io.mot import _mot_frame_to_detections, _MOTOutput, load_mot_file -from trackers.io.paths import _resolve_video_output_path, _validate_output_path -from trackers.io.video import _DEFAULT_OUTPUT_FPS, _DisplayWindow, _VideoOutput -from trackers.scripts.progress import _classify_source, _SourceInfo, _TrackingProgress -from trackers.utils.device import _best_device - -if TYPE_CHECKING: - from inference_models import AnyModel - -# Defaults -DEFAULT_MODEL = "rfdetr-nano" -DEFAULT_TRACKER = "bytetrack" -DEFAULT_CONFIDENCE = 0.5 -DEFAULT_DEVICE = "auto" - -# Visualization -COLOR_PALETTE = sv.ColorPalette.from_hex( - [ - "#ffff00", - "#ff9b00", - "#ff8080", - "#ff66b2", - "#ff66ff", - "#b266ff", - "#9999ff", - "#3399ff", - "#66ffff", - "#33ff99", - "#66ff66", - "#99ff00", - ] -) - - -def add_track_subparser(subparsers: argparse._SubParsersAction) -> None: - """Add the track subcommand to the argument parser.""" - parser = subparsers.add_parser( - "track", - help="Track objects in video using detection and tracking.", - description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - - # Source options - source_group = parser.add_argument_group("source") - source_group.add_argument( - "--source", - type=str, - default=None, - metavar="PATH", - help="Video file, webcam index (0), RTSP URL, or image directory.", - ) - - # Detection options (mutually exclusive) - detection_group = parser.add_argument_group("detection") - det_mutex = detection_group.add_mutually_exclusive_group(required=False) - det_mutex.add_argument( - "--model", - type=str, - default=DEFAULT_MODEL, - metavar="ID", - help=( - "Model ID for detection. Pretrained: rfdetr-nano, rfdetr-base, etc. " - f"Custom: workspace/project/version. Default: {DEFAULT_MODEL}" - ), - ) - det_mutex.add_argument( - "--detections", - type=Path, - metavar="PATH", - help="Load pre-computed detections from MOT format file.", - ) - - # Model options - model_group = parser.add_argument_group("model options") - model_group.add_argument( - "--model.confidence", - type=float, - default=DEFAULT_CONFIDENCE, - dest="model_confidence", - metavar="FLOAT", - help=f"Detection confidence threshold. Default: {DEFAULT_CONFIDENCE}", - ) - model_group.add_argument( - "--model.device", - type=str, - default=DEFAULT_DEVICE, - dest="model_device", - metavar="DEVICE", - help=f"Device: auto, cpu, cuda, cuda:0, mps. Default: {DEFAULT_DEVICE}", - ) - model_group.add_argument( - "--model.api_key", - type=str, - default=None, - dest="model_api_key", - metavar="KEY", - help="Roboflow API key for custom models.", - ) - - # Filtering options - filter_group = parser.add_argument_group("filtering") - filter_group.add_argument( - "--classes", - type=str, - default=None, - metavar="NAMES_OR_IDS", - help="Filter by class names or IDs (comma-separated, e.g., person,car).", - ) - filter_group.add_argument( - "--track_ids", - type=str, - default=None, - metavar="IDS", - help="Filter output by track IDs (comma-separated, e.g., 1,3,5)", - ) - - # Tracker options - tracker_group = parser.add_argument_group("tracker options") - available_trackers = BaseTracker._registered_trackers() - tracker_group.add_argument( - "--tracker", - type=str, - default=DEFAULT_TRACKER, - choices=available_trackers if available_trackers else [DEFAULT_TRACKER, "sort"], - metavar="ID", - help=f"Tracking algorithm. Default: {DEFAULT_TRACKER}", - ) - - # Add dynamic tracker parameters - _add_tracker_params(tracker_group) - - # Output options - output_group = parser.add_argument_group("output") - output_group.add_argument( - "-o", - "--output", - type=Path, - default=None, - metavar="PATH", - help="Output video file path.", - ) - output_group.add_argument( - "--mot-output", - type=Path, - default=None, - dest="mot_output", - metavar="PATH", - help="Output MOT format file path.", - ) - output_group.add_argument( - "--overwrite", - action="store_true", - help="Overwrite existing output files.", - ) - - # Visualization options - vis_group = parser.add_argument_group("visualization") - vis_group.add_argument( - "--display", - action="store_true", - help="Show preview window.", - ) - vis_group.add_argument( - "--show-boxes", - action="store_true", - default=True, - dest="show_boxes", - help="Draw bounding boxes. Default: True", - ) - vis_group.add_argument( - "--no-boxes", - action="store_false", - dest="show_boxes", - help="Disable bounding boxes.", - ) - vis_group.add_argument( - "--show-masks", - action="store_true", - dest="show_masks", - help="Draw segmentation masks (seg models only).", - ) - vis_group.add_argument( - "--show-labels", - action="store_true", - dest="show_labels", - help="Show class labels.", - ) - vis_group.add_argument( - "--show-ids", - action="store_true", - default=True, - dest="show_ids", - help="Show track IDs. Default: True", - ) - vis_group.add_argument( - "--no-ids", - action="store_false", - dest="show_ids", - help="Disable track IDs.", - ) - vis_group.add_argument( - "--show-confidence", - action="store_true", - dest="show_confidence", - help="Show confidence scores.", - ) - vis_group.add_argument( - "--show-trajectories", - action="store_true", - dest="show_trajectories", - help="Draw track trajectories.", - ) - - parser.set_defaults(func=run_track) - - -def _add_tracker_params(group: argparse._ArgumentGroup) -> None: - """Add tracker-specific parameters from registry to argument group.""" - for tracker_id in BaseTracker._registered_trackers(): - info = BaseTracker._lookup_tracker(tracker_id) - if info is None: - continue - - for param_name, param_info in info.parameters.items(): - arg_name = f"--tracker.{param_name}" - dest_name = f"tracker_{param_name}" - - kwargs: dict = { - "dest": dest_name, - "default": param_info.default_value, - "help": f"{param_info.description} Default: {param_info.default_value}", - } - - if param_info.param_type is bool: - kwargs["action"] = "store_false" if param_info.default_value else "store_true" - else: - kwargs["type"] = param_info.param_type - kwargs["metavar"] = param_info.param_type.__name__.upper() - - try: - group.add_argument(arg_name, **kwargs) - except argparse.ArgumentError: - # Parameter already added by another tracker - pass - - -def run_track(args: argparse.Namespace) -> int: - """Execute the track command.""" - needs_frames = args.output or args.display - - if args.source is None and not args.detections: - print( - "Error: --source is required when not using --detections.", - file=sys.stderr, - ) - return 1 - - if needs_frames and args.source is None: - print( - "Error: --source is required when using --output or --display.", - file=sys.stderr, - ) - return 1 - - # Validate output paths - if args.output: - _validate_output_path(_resolve_video_output_path(args.output), overwrite=args.overwrite) - if args.mot_output: - _validate_output_path(args.mot_output, overwrite=args.overwrite) - - # Create detection source - if args.detections: - model = None - detections_data = load_mot_file(args.detections) - class_names: list[str] = [] - else: - model = _init_model( - args.model, - device=args.model_device, - api_key=args.model_api_key, - ) - detections_data = None - class_names = getattr(model, "class_names", []) - - # Resolve class filter (names and/or integer IDs) - class_filter = _resolve_class_filter(args.classes, class_names) - - track_id_filter = _resolve_track_id_filter(args.track_ids) - - # Create tracker - tracker_params = _extract_tracker_params(args.tracker, args) - tracker = _init_tracker(args.tracker, **tracker_params) - - if args.source is not None: - return _run_with_source( - args, - model, - detections_data, - class_names, - class_filter, - track_id_filter, - tracker, - ) - else: - return _run_frameless( - args, - detections_data, - class_filter, - track_id_filter, - tracker, - ) - - -def _run_frameless( - args: argparse.Namespace, - detections_data: dict | None, - class_filter: list[int] | None, - track_id_filter: list[int] | None, - tracker: BaseTracker, -) -> int: - """Run tracking from pre-computed detections without frame source.""" - if detections_data is None or not detections_data: - print("Error: No detections found in file.", file=sys.stderr) - return 1 - - total_frames = max(detections_data.keys()) - source_info = _SourceInfo(source_type="video", total_frames=total_frames) - - try: - with ( - _MOTOutput(args.mot_output) as mot, - _TrackingProgress(source_info) as progress, - ): - interrupted = False - for frame_idx in range(1, total_frames + 1): - if frame_idx in detections_data: - detections = _mot_frame_to_detections(detections_data[frame_idx]) - else: - detections = sv.Detections.empty() - - if class_filter is not None and len(detections) > 0: - mask = np.isin(detections.class_id, class_filter) - detections = detections[mask] # type: ignore[assignment] - - tracked = tracker.update(detections) - - if track_id_filter is not None and len(tracked) > 0: - if tracked.tracker_id is not None: - mask = np.isin(tracked.tracker_id.astype(int), track_id_filter) - tracked = tracked[mask] - - mot.write(frame_idx, tracked) - progress.update() - - progress.complete(interrupted=interrupted) - - except KeyboardInterrupt: - pass - - return 0 - - -def _run_with_source( - args: argparse.Namespace, - model, - detections_data: dict | None, - class_names: list[str], - class_filter: list[int] | None, - track_id_filter: list[int] | None, - tracker: BaseTracker, -) -> int: - """Run tracking with a frame source (video, webcam, images).""" - frame_gen = frames_from_source(args.source) - source_info = _classify_source(args.source) - - # Setup annotators - annotators, label_annotator = _init_annotators( - show_boxes=args.show_boxes, - show_masks=args.show_masks, - show_labels=args.show_labels, - show_ids=args.show_ids, - show_confidence=args.show_confidence, - ) - trace_annotator = None - if args.show_trajectories: - trace_annotator = sv.TraceAnnotator( - color=COLOR_PALETTE, - color_lookup=sv.ColorLookup.TRACK, - ) - - display_ctx = _DisplayWindow() if args.display else nullcontext() - - try: - with ( - _VideoOutput( - args.output, - fps=source_info.fps or _DEFAULT_OUTPUT_FPS, - ) as video, - _MOTOutput(args.mot_output) as mot, - display_ctx as display, - _TrackingProgress(source_info) as progress, - ): - interrupted = False - for frame_idx, frame in frame_gen: - # Get detections - if model is not None: - detections = _run_model(model, frame, args.model_confidence) - elif detections_data is not None and frame_idx in detections_data: - detections = _mot_frame_to_detections(detections_data[frame_idx]) - else: - detections = sv.Detections.empty() - - # Filter by class - if class_filter is not None and len(detections) > 0: - mask = np.isin(detections.class_id, class_filter) - detections = detections[mask] # type: ignore[assignment] - - # Run tracker - tracked = tracker.update(detections, frame) - - # Filter by track ID - if track_id_filter is not None and len(tracked) > 0: - if tracked.tracker_id is not None: - mask = np.isin(tracked.tracker_id.astype(int), track_id_filter) - tracked = tracked[mask] - - # Write MOT output - mot.write(frame_idx, tracked) - - progress.update() - - # Annotate and display/save frame - if args.display or args.output: - annotated = frame.copy() - if trace_annotator is not None: - annotated = trace_annotator.annotate(annotated, tracked) - for annotator in annotators: - annotated = annotator.annotate(annotated, tracked) - if label_annotator is not None: - labeled = tracked[tracked.tracker_id != -1] - labels = _format_labels( - labeled, - class_names, - show_ids=args.show_ids, - show_labels=args.show_labels, - show_confidence=args.show_confidence, - ) - annotated = label_annotator.annotate(annotated, labeled, labels) - - video.write(annotated) - - if display is not None: - display.show(annotated) - if display.quit_requested: - interrupted = True - break - - progress.complete(interrupted=interrupted) - - except KeyboardInterrupt: - pass - - return 0 - - -def _resolve_track_id_filter(track_ids_arg: str | None) -> list[int] | None: - """Resolve a comma-separated `--track-ids` value to a list of integer IDs. - - Args: - track_ids_arg: Raw `--track-ids` string (e.g. `"1,3,5"`). `None` - means no filter. - - Returns: - List of integer track IDs, or `None` when no valid filter remains. - """ - if not track_ids_arg: - return None - - track_ids: list[int] = [] - for token in track_ids_arg.split(","): - token = token.strip() - try: - track_ids.append(int(token)) - except ValueError: - print( - f"Warning: '{token}' is not a valid track ID, skipping.", - file=sys.stderr, - ) - return track_ids if track_ids else None - - -def _resolve_class_filter( - classes_arg: str | None, - class_names: list[str], -) -> list[int] | None: - """Resolve a comma-separated `--classes` value to a list of integer IDs. - - Each token is checked independently: if it parses as an `int` it is used - directly as a class ID; otherwise it is looked up by name in *class_names*. - Unknown names are printed as warnings and skipped. - - Args: - classes_arg: Raw `--classes` string (e.g. `"person,car"` or - `"0,2"` or `"person,2"`). `None` means no filter. - class_names: Ordered list of class names where the index equals the - class ID (as provided by the model). - - Returns: - List of integer class IDs, or `None` when no valid filter remains. - """ - if not classes_arg: - return None - - requested = [token.strip() for token in classes_arg.split(",")] - name_to_id = {name: i for i, name in enumerate(class_names)} - class_filter: list[int] = [] - for token in requested: - try: - class_filter.append(int(token)) - except ValueError: - if token in name_to_id: - class_filter.append(name_to_id[token]) - else: - print( - f"Warning: class '{token}' not found in model class list, skipping.", - file=sys.stderr, - ) - return class_filter if class_filter else None - - -def _init_model( - model_id: str, - *, - device: str = DEFAULT_DEVICE, - api_key: str | None = None, -) -> AnyModel: - """Load detection model via inference-models. - - Args: - model_id: Model identifier (e.g., 'rfdetr-nano' or 'workspace/project/version'). - device: Device to load model on ('auto', 'cpu', 'cuda', 'mps'). - api_key: Roboflow API key for custom models. - - Returns: - Loaded model instance. - """ - try: - from inference_models import AutoModel - except ImportError as e: - print( - "Error: inference-models is required for model-based detection.\n" - "Install with: pip install 'trackers[detection]'", - file=sys.stderr, - ) - raise SystemExit(1) from e - - resolved_device = _best_device() if device == DEFAULT_DEVICE else device - - return AutoModel.from_pretrained( - model_id, - api_key=api_key, - device=resolved_device, - ) - - -def _run_model(model: AnyModel, frame: np.ndarray, confidence: float) -> sv.Detections: - """Run model inference and return sv.Detections.""" - predictions = model(frame) - if not predictions: - return sv.Detections.empty() - - detections = predictions[0].to_supervision() - - # Filter by confidence - if len(detections) > 0 and detections.confidence is not None: - mask = detections.confidence >= confidence - detections = detections[mask] - - return detections - - -def _extract_tracker_params(tracker_id: str, args: argparse.Namespace) -> dict[str, object]: - """Extract tracker parameters from CLI args. - - Args: - tracker_id: Registered tracker name. - args: Parsed CLI arguments. - - Returns: - Dictionary of tracker parameters with non-None values. - """ - info = BaseTracker._lookup_tracker(tracker_id) - if info is None: - return {} - - params = {} - for param_name in info.parameters: - dest_name = f"tracker_{param_name}" - if hasattr(args, dest_name): - value = getattr(args, dest_name) - if value is not None: - params[param_name] = value - return params - - -def _init_tracker(tracker_id: str, **kwargs: object) -> BaseTracker: - """Create tracker instance from registry. - - Args: - tracker_id: Registered tracker name (e.g., 'bytetrack', 'sort'). - **kwargs: Tracker-specific parameters. - - Returns: - Initialized tracker instance. - - Raises: - ValueError: If tracker_id is not registered. - """ - info = BaseTracker._lookup_tracker(tracker_id) - if info is None: - available = ", ".join(BaseTracker._registered_trackers()) - raise ValueError(f"Unknown tracker: '{tracker_id}'. Available: {available}") - - return info.tracker_class(**kwargs) - - -def _init_annotators( - show_boxes: bool = False, - show_masks: bool = False, - show_labels: bool = False, - show_ids: bool = False, - show_confidence: bool = False, -) -> tuple[list, sv.LabelAnnotator | None]: - """Initialize supervision annotators based on display options. - - Args: - show_boxes: Create BoxAnnotator. - show_masks: Create MaskAnnotator. - show_labels: Include class labels (triggers LabelAnnotator). - show_ids: Include track IDs (triggers LabelAnnotator). - show_confidence: Include confidence scores (triggers LabelAnnotator). - - Returns: - Tuple of (annotators list, label_annotator or None). - Label annotator is separate because it needs custom labels per frame. - """ - annotators: list = [] - label_annotator: sv.LabelAnnotator | None = None - - if show_boxes: - annotators.append( - sv.BoxAnnotator( - color=COLOR_PALETTE, - color_lookup=sv.ColorLookup.TRACK, - ) - ) - - if show_masks: - annotators.append( - sv.MaskAnnotator( - color=COLOR_PALETTE, - color_lookup=sv.ColorLookup.TRACK, - ) - ) - - if show_labels or show_ids or show_confidence: - label_annotator = sv.LabelAnnotator( - color=COLOR_PALETTE, - text_color=sv.Color.BLACK, - text_position=sv.Position.TOP_LEFT, - color_lookup=sv.ColorLookup.TRACK, - ) - - return annotators, label_annotator - - -def _format_labels( - detections: sv.Detections, - class_names: list[str], - *, - show_ids: bool = False, - show_labels: bool = False, - show_confidence: bool = False, -) -> list[str]: - """Generate label strings for each detection. - - Args: - detections: Detections to generate labels for. - class_names: List of class names for lookup. - show_ids: Include tracker IDs in labels. - show_labels: Include class names in labels. - show_confidence: Include confidence scores in labels. - - Returns: - List of label strings, one per detection. - """ - labels = [] - - for i in range(len(detections)): - parts = [] - - if show_ids and detections.tracker_id is not None: - parts.append(f"#{int(detections.tracker_id[i])}") - - if show_labels and detections.class_id is not None: - class_id = int(detections.class_id[i]) - if class_names and 0 <= class_id < len(class_names): - parts.append(class_names[class_id]) - else: - parts.append(str(class_id)) - - if show_confidence and detections.confidence is not None: - parts.append(f"{detections.confidence[i]:.2f}") - - labels.append(" ".join(parts)) - - return labels diff --git a/src/trackers/scripts/tune.py b/src/trackers/scripts/tune.py deleted file mode 100644 index 03457432..00000000 --- a/src/trackers/scripts/tune.py +++ /dev/null @@ -1,230 +0,0 @@ -#!/usr/bin/env python -# ------------------------------------------------------------------------ -# Trackers -# Copyright (c) 2026 Roboflow. All Rights Reserved. -# Licensed under the Apache License, Version 2.0 [see LICENSE for details] -# ------------------------------------------------------------------------ - -from __future__ import annotations - -import argparse -import json -import sys -from pathlib import Path - - -def add_tune_subparser(subparsers: argparse._SubParsersAction) -> None: - """Add the tune subcommand to the argument parser.""" - parser = subparsers.add_parser( - "tune", - help="Tune tracker hyperparameters via Optuna.", - description=( - "Run Optuna-based hyperparameter optimisation for a registered " - "tracker using pre-computed detections and ground-truth MOT files." - ), - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - - parser.add_argument( - "--tracker", - required=True, - metavar="ID", - help="Tracker ID to tune (e.g. bytetrack, sort, ocsort).", - ) - parser.add_argument( - "--gt-dir", - type=Path, - required=True, - metavar="DIR", - help="Directory containing ground-truth MOT files.", - ) - parser.add_argument( - "--detections-dir", - type=Path, - required=True, - metavar="DIR", - help=("Directory containing pre-computed detection files in MOT flat format (one {seq}.txt per sequence)."), - ) - parser.add_argument( - "--objective", - default="HOTA", - choices=["MOTA", "HOTA", "IDF1"], - help="Scalar metric to maximise. Default: HOTA.", - ) - parser.add_argument( - "--n-trials", - type=int, - default=100, - metavar="N", - help="Number of Optuna trials to run. Default: 100.", - ) - parser.add_argument( - "--metrics", - nargs="+", - default=["CLEAR"], - choices=["CLEAR", "HOTA", "Identity"], - help=( - "Metric families to compute. Default: CLEAR. The family required " - "by --objective is added automatically if missing." - ), - ) - parser.add_argument( - "--threshold", - type=float, - default=0.5, - help="IoU threshold for CLEAR and Identity matching. Default: 0.5.", - ) - parser.add_argument( - "--seqmap", - type=Path, - metavar="PATH", - help="Sequence map file listing sequences to evaluate.", - ) - parser.add_argument( - "--fixed-params", - type=str, - metavar="JSON", - help=("JSON object of tracker kwargs held fixed for every trial (e.g. '{\"enable_cmc\": false}')."), - ) - parser.add_argument( - "--images-dir", - type=Path, - metavar="DIR", - help="MOT image root ({sequence}/img1/) for trackers that need frames (e.g. BoTSORT CMC).", - ) - parser.add_argument( - "--no-enqueue-defaults", - action="store_true", - help="Skip the baseline trial that uses tracker/search_space defaults.", - ) - parser.add_argument( - "--seed", - type=int, - default=None, - metavar="N", - help="Random seed for Optuna sampling (reproducible hyperparameter trials).", - ) - parser.add_argument( - "--output", - "-o", - type=Path, - metavar="PATH", - help="Output file for best parameters (JSON format).", - ) - - parser.set_defaults(func=run_tune) - - -def run_tune(args: argparse.Namespace) -> int: - """Execute the tune command.""" - fixed_params = None - if args.fixed_params is not None: - try: - fixed_params = json.loads(args.fixed_params) - except json.JSONDecodeError as e: - print(f"Invalid --fixed-params JSON: {e}", file=sys.stderr) - return 1 - if not isinstance(fixed_params, dict): - print("--fixed-params must be a JSON object", file=sys.stderr) - return 1 - - return tune( - tracker=args.tracker, - gt_dir=args.gt_dir, - detections_dir=args.detections_dir, - objective=args.objective, - n_trials=args.n_trials, - metrics=args.metrics, - threshold=args.threshold, - seqmap=args.seqmap, - fixed_params=fixed_params, - images_dir=args.images_dir, - enqueue_defaults=not args.no_enqueue_defaults, - seed=args.seed, - output=args.output, - ) - - -def tune( - tracker: str, - gt_dir: Path, - detections_dir: Path, - objective: str = "HOTA", - n_trials: int = 100, - metrics: list[str] | None = None, - threshold: float = 0.5, - seqmap: Path | None = None, - fixed_params: dict | None = None, - images_dir: Path | None = None, - enqueue_defaults: bool = True, - seed: int | None = None, - output: Path | None = None, -) -> int: - """Tune tracker hyperparameters using Optuna. - - Args: - tracker: Tracker ID to tune (e.g. bytetrack, sort). - gt_dir: Directory of ground-truth MOT files. - detections_dir: Directory of pre-computed detection files in MOT flat - format (one {seq}.txt per sequence). - objective: Scalar metric to maximise. Options: MOTA, HOTA, IDF1. - n_trials: Number of Optuna trials to run. - metrics: Metric families to compute. Options: CLEAR, HOTA, Identity. - Default: CLEAR. - threshold: IoU threshold for CLEAR and Identity matching. - seqmap: Sequence map file listing sequences to evaluate. - enqueue_defaults: Whether to run a baseline trial before sampling. - fixed_params: Tracker kwargs held constant for every trial. - images_dir: MOT image root for frame-based features (e.g. CMC). - seed: Random seed for Optuna's TPE sampler. - output: Output file path for best parameters (JSON format). - - Returns: - Exit code: 0 on success, 1 on error. - """ - if metrics is None: - metrics = ["CLEAR"] - - from trackers.tune import Tuner - - try: - tuner = Tuner( - tracker_id=tracker, - gt_dir=gt_dir, - detections_dir=detections_dir, - metrics=metrics, - objective=objective, - n_trials=n_trials, - enqueue_defaults=enqueue_defaults, - fixed_params=fixed_params, - images_dir=images_dir, - seed=seed, - threshold=threshold, - seqmap=seqmap, - ) - except (ValueError, ImportError, FileNotFoundError) as e: - print(str(e), file=sys.stderr) - return 1 - - try: - best_params = tuner.run() - except Exception as e: - print(f"Error during tuning: {e}", file=sys.stderr) - return 1 - - print(f"\nBest parameters for {tracker}:") - for name, value in best_params.items(): - print(f" {name}: {value}") - if tuner.study is not None: - print(f"\nBest {objective}: {tuner.study.best_value:.4f}") - - if output: - try: - output.parent.mkdir(parents=True, exist_ok=True) - output.write_text(json.dumps(best_params, indent=2)) - except OSError as e: - print(f"Error writing output: {e}", file=sys.stderr) - return 1 - print(f"\nResults saved to: {output}") - - return 0 diff --git a/src/trackers/utils/iou.py b/src/trackers/utils/iou.py index 9325bf68..64982bb0 100644 --- a/src/trackers/utils/iou.py +++ b/src/trackers/utils/iou.py @@ -400,3 +400,37 @@ def _compute(self, boxes_1: np.ndarray, boxes_2: np.ndarray) -> np.ndarray: def normalize_for_fusion(self, similarity_matrix: np.ndarray) -> np.ndarray: return (similarity_matrix + 1.0) / 2.0 + + +_VARIANTS: dict[str, type[BaseIoU]] = { + "iou": IoU, + "giou": GIoU, + "diou": DIoU, + "ciou": CIoU, + "biou": BIoU, +} + + +def variant_from_name(name: str) -> BaseIoU: + """Resolve a variant name (case-insensitive) to a default-constructed instance. + + Args: + name: One of ``iou``, ``giou``, ``diou``, ``ciou``, ``biou`` + (case-insensitive). + + Returns: + A default-constructed instance of the matching :class:`BaseIoU` subclass. + + Raises: + ValueError: If ``name`` does not match any known variant. + + Examples: + >>> isinstance(variant_from_name("giou"), GIoU) + True + >>> isinstance(variant_from_name("BIOU"), BIoU) + True + """ + try: + return _VARIANTS[name.lower()]() + except KeyError as exc: + raise ValueError(f"Unknown IoU variant {name!r}. Valid: {sorted(_VARIANTS)}") from exc diff --git a/tests/cli/__init__.py b/tests/cli/__init__.py new file mode 100644 index 00000000..57226e88 --- /dev/null +++ b/tests/cli/__init__.py @@ -0,0 +1,5 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ diff --git a/tests/scripts/test_download.py b/tests/cli/test_download.py similarity index 55% rename from tests/scripts/test_download.py rename to tests/cli/test_download.py index 94b3f573..852db7a5 100644 --- a/tests/scripts/test_download.py +++ b/tests/cli/test_download.py @@ -4,93 +4,38 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ +"""CLI-level tests for trackers/cli/download.py.""" + from __future__ import annotations -import argparse from unittest.mock import patch import pytest +from trackers.cli.download import _print_available, download from trackers.datasets.download import _DEFAULT_CACHE_DIR, _DEFAULT_OUTPUT_DIR -from trackers.scripts.download import ( - _print_available, - _run_download, - add_download_subparser, -) - - -def _parse_args(argv: list[str]) -> argparse.Namespace: - """Parse argv through a fresh download subparser and return the namespace.""" - parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers() - add_download_subparser(subparsers) - return parser.parse_args(argv) - - -class TestSubparserRegistration: - """Argument parsing and help strings.""" - - def test_list_flag(self) -> None: - """--list sets the flag to True.""" - args = _parse_args(["download", "--list"]) - assert args.list is True - - def test_list_flag_default_false(self) -> None: - """--list is False when omitted.""" - args = _parse_args(["download", "mot17"]) - assert args.list is False - - def test_split_flag_accepts_comma_separated(self) -> None: - """--split accepts comma-separated values.""" - args = _parse_args(["download", "mot17", "--split", "train,val"]) - assert args.split == "train,val" - - def test_asset_flag_accepts_comma_separated(self) -> None: - """--asset accepts comma-separated values.""" - args = _parse_args(["download", "mot17", "--asset", "frames,annotations"]) - assert args.asset == "frames,annotations" - - def test_output_directory_short_flag(self) -> None: - """-o sets the output directory.""" - args = _parse_args(["download", "mot17", "-o", "./datasets"]) - assert args.output == "./datasets" - - def test_cache_dir_flag(self) -> None: - """--cache-dir sets the cache directory.""" - args = _parse_args(["download", "mot17", "--cache-dir", "./cache"]) - assert args.cache_dir == "./cache" - def test_dataset_positional(self) -> None: - """Dataset is captured as positional argument.""" - args = _parse_args(["download", "sportsmot"]) - assert args.dataset == "sportsmot" - -class TestRunDownload: +class TestDownload: """Execution of the download subcommand.""" def test_list_triggers_print(self) -> None: - """--list calls _print_available and returns 0.""" - args = _parse_args(["download", "--list"]) - - with patch("trackers.scripts.download._print_available") as mock_print: - rc = _run_download(args) + """list_available=True calls _print_available and returns 0.""" + with patch("trackers.cli.download._print_available") as mock_print: + rc = download(list_available=True) assert rc == 0 mock_print.assert_called_once() def test_list_takes_precedence_over_dataset(self) -> None: - """--list wins over dataset positional.""" - args = _parse_args(["download", "mot17", "--list"]) - - with patch("trackers.scripts.download._print_available") as mock_print: - rc = _run_download(args) + """list_available=True wins over dataset argument.""" + with patch("trackers.cli.download._print_available") as mock_print: + rc = download(dataset="mot17", list_available=True) assert rc == 0 mock_print.assert_called_once() def test_missing_dataset_exits_with_error(self, capsys: pytest.CaptureFixture[str]) -> None: - """No dataset and no --list prints error to stderr and returns 1.""" - args = _parse_args(["download"]) - rc = _run_download(args) + """No dataset and no list_available prints error to stderr and returns 1.""" + rc = download() captured = capsys.readouterr() assert rc == 1 assert "Please specify a dataset" in captured.err @@ -104,11 +49,9 @@ def test_missing_dataset_exits_with_error(self, capsys: pytest.CaptureFixture[st ], ) def test_split_comma_parsing(self, split_arg: str, expected_splits: list[str]) -> None: - """--split values are split on commas and whitespace-stripped.""" - args = _parse_args(["download", "mot17", "--split", split_arg, "--asset", "annotations"]) - + """split values are split on commas and whitespace-stripped.""" with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) + rc = download(dataset="mot17", split=split_arg, asset="annotations") assert rc == 0 mock_dl.assert_called_once_with( dataset="mot17", @@ -127,11 +70,9 @@ def test_split_comma_parsing(self, split_arg: str, expected_splits: list[str]) - ], ) def test_split_comma_parsing_boundary(self, split_arg: str, expected_splits: list[str]) -> None: - """--split handles malformed comma inputs gracefully.""" - args = _parse_args(["download", "mot17", "--split", split_arg, "--asset", "annotations"]) - + """split handles malformed comma inputs gracefully.""" with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) + rc = download(dataset="mot17", split=split_arg, asset="annotations") assert rc == 0 mock_dl.assert_called_once_with( dataset="mot17", @@ -150,11 +91,9 @@ def test_split_comma_parsing_boundary(self, split_arg: str, expected_splits: lis ], ) def test_asset_comma_parsing(self, asset_arg: str, expected_assets: list[str]) -> None: - """--asset values are split on commas and whitespace-stripped.""" - args = _parse_args(["download", "sportsmot", "--split", "train", "--asset", asset_arg]) - + """asset values are split on commas and whitespace-stripped.""" with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) + rc = download(dataset="sportsmot", split="train", asset=asset_arg) assert rc == 0 mock_dl.assert_called_once_with( dataset="sportsmot", @@ -165,11 +104,9 @@ def test_asset_comma_parsing(self, asset_arg: str, expected_assets: list[str]) - ) def test_none_splits_and_assets_when_omitted(self) -> None: - """When --split and --asset are omitted, None is forwarded.""" - args = _parse_args(["download", "mot17"]) - + """When split and asset are omitted, None is forwarded.""" with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) + rc = download(dataset="mot17") assert rc == 0 mock_dl.assert_called_once_with( dataset="mot17", @@ -180,11 +117,9 @@ def test_none_splits_and_assets_when_omitted(self) -> None: ) def test_output_directory_forwarded(self) -> None: - """-o value is forwarded to download_dataset.""" - args = _parse_args(["download", "mot17", "-o", "/custom/path"]) - + """output value is forwarded to download_dataset.""" with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) + rc = download(dataset="mot17", output="/custom/path") assert rc == 0 mock_dl.assert_called_once_with( dataset="mot17", @@ -196,21 +131,17 @@ def test_output_directory_forwarded(self) -> None: def test_value_error_returns_exit_code(self) -> None: """ValueError from download_dataset is caught and returns 1.""" - args = _parse_args(["download", "mot17"]) - with patch( "trackers.datasets.download.download_dataset", side_effect=ValueError("bad dataset"), ): - rc = _run_download(args) + rc = download(dataset="mot17") assert rc == 1 def test_split_with_spaces_stripped(self) -> None: - """--split with spaces around commas strips whitespace.""" - args = _parse_args(["download", "mot17", "--split", "train , val", "--asset", "annotations"]) - + """split with spaces around commas strips whitespace.""" with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) + rc = download(dataset="mot17", split="train , val", asset="annotations") assert rc == 0 mock_dl.assert_called_once_with( dataset="mot17", @@ -222,7 +153,7 @@ def test_split_with_spaces_stripped(self) -> None: class TestPrintAvailable: - """Output of --list.""" + """Output of list_available.""" def test_prints_without_error(self, capsys: pytest.CaptureFixture[str]) -> None: """_print_available runs without raising and does not leak output.""" diff --git a/tests/cli/test_main.py b/tests/cli/test_main.py new file mode 100644 index 00000000..9f4c4e19 --- /dev/null +++ b/tests/cli/test_main.py @@ -0,0 +1,57 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +"""Tests for trackers.cli.__main__ — jsonargparse CLI integration.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest +import yaml +from jsonargparse import ArgumentParser + +from trackers.cli.track import track + + +@pytest.fixture() +def track_parser() -> ArgumentParser: + """ArgumentParser built from the track() signature with --config support.""" + parser = ArgumentParser(exit_on_error=False) + parser.add_function_arguments(track) + parser.add_argument("--config", action="config") + return parser + + +class TestConfigFileSupport: + """Verify jsonargparse --config flag behaviour for the track subcommand.""" + + def test_config_value_applied_to_tracker(self, track_parser: ArgumentParser, tmp_path: Path) -> None: + """YAML --config value is parsed into the track() namespace.""" + cfg = tmp_path / "run.yaml" + cfg.write_text(yaml.dump({"tracker": "sort"})) + + ns = track_parser.parse_args(["--config", str(cfg)]) + + assert ns.tracker == "sort" + + def test_cli_arg_overrides_config_value(self, track_parser: ArgumentParser, tmp_path: Path) -> None: + """Explicit CLI arg takes precedence over the --config file value.""" + cfg = tmp_path / "run.yaml" + cfg.write_text(yaml.dump({"tracker": "sort"})) + + ns = track_parser.parse_args(["--config", str(cfg), "--tracker", "bytetrack"]) + + assert ns.tracker == "bytetrack" + + def test_nested_dataclass_field_in_config(self, track_parser: ArgumentParser, tmp_path: Path) -> None: + """Nested DetectionOptions fields can be set via --config.""" + cfg = tmp_path / "run.yaml" + cfg.write_text(yaml.dump({"detection": {"confidence": 0.3}})) + + ns = track_parser.parse_args(["--config", str(cfg)]) + + assert ns.detection.confidence == pytest.approx(0.3) diff --git a/tests/scripts/test_progress.py b/tests/cli/test_progress.py similarity index 98% rename from tests/scripts/test_progress.py rename to tests/cli/test_progress.py index 91486afe..85d3bab3 100644 --- a/tests/scripts/test_progress.py +++ b/tests/cli/test_progress.py @@ -17,7 +17,7 @@ import pytest from rich.console import Console -from trackers.scripts.progress import ( +from trackers.cli.progress import ( _classify_source, _format_time, _SourceInfo, @@ -129,7 +129,7 @@ def test_video_with_zero_frame_count(self) -> None: cv2.CAP_PROP_FPS: 30.0, }.get(prop, 0.0) - with patch("trackers.scripts.progress.cv2.VideoCapture", return_value=mock_cap): + with patch("trackers.cli.progress.cv2.VideoCapture", return_value=mock_cap): info = _classify_source("some_video.mp4") assert info.source_type == "video" diff --git a/tests/scripts/test_track.py b/tests/cli/test_track.py similarity index 99% rename from tests/scripts/test_track.py rename to tests/cli/test_track.py index be3867ed..5ac9708e 100644 --- a/tests/scripts/test_track.py +++ b/tests/cli/test_track.py @@ -12,7 +12,7 @@ import pytest import supervision as sv -from trackers.scripts.track import ( +from trackers.cli.track import ( _format_labels, _init_annotators, _resolve_class_filter, diff --git a/tests/scripts/test_tune.py b/tests/cli/test_tune.py similarity index 52% rename from tests/scripts/test_tune.py rename to tests/cli/test_tune.py index 24169e20..3a8f1be1 100644 --- a/tests/scripts/test_tune.py +++ b/tests/cli/test_tune.py @@ -4,97 +4,17 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ -"""CLI-level tests for trackers/scripts/tune.py.""" +"""CLI-level tests for trackers/cli/tune.py.""" from __future__ import annotations -import argparse import json from pathlib import Path from unittest.mock import MagicMock, patch import pytest -from trackers.scripts.tune import add_tune_subparser, run_tune, tune - - -def _make_parser() -> tuple[argparse.ArgumentParser, argparse._SubParsersAction]: - """Return a top-level parser with a subparsers group.""" - parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers() - return parser, subparsers - - -class TestAddTuneSubparser: - @pytest.fixture - def minimal_args(self) -> argparse.Namespace: - """Parsed args with only required flags.""" - parser, subparsers = _make_parser() - add_tune_subparser(subparsers) - return parser.parse_args(["tune", "--tracker", "sort", "--gt-dir", "/gt", "--detections-dir", "/det"]) - - def test_registers_tune_subcommand(self) -> None: - """tune subcommand is accessible under the 'tune' name.""" - parser, subparsers = _make_parser() - add_tune_subparser(subparsers) - args = parser.parse_args(["tune", "--tracker", "sort", "--gt-dir", "/gt", "--detections-dir", "/det"]) - assert args.func is run_tune - - def test_required_args_parsed(self) -> None: - """--tracker, --gt-dir, and --detections-dir are required and parsed.""" - parser, subparsers = _make_parser() - add_tune_subparser(subparsers) - args = parser.parse_args( - [ - "tune", - "--tracker", - "bytetrack", - "--gt-dir", - "/data/gt", - "--detections-dir", - "/data/det", - ] - ) - assert args.tracker == "bytetrack" - assert args.gt_dir == Path("/data/gt") - assert args.detections_dir == Path("/data/det") - - @pytest.mark.parametrize( - "flag,expected", - [ - ("objective", "HOTA"), - ("n_trials", 100), - ("threshold", 0.5), - ("seqmap", None), - ("output", None), - ], - ) - def test_optional_defaults(self, minimal_args: argparse.Namespace, flag: str, expected: object) -> None: - """Optional arguments have correct defaults when omitted.""" - assert getattr(minimal_args, flag) == expected - - def test_metrics_default(self, minimal_args: argparse.Namespace) -> None: - """--metrics defaults to ['CLEAR'] when not supplied.""" - assert minimal_args.metrics == ["CLEAR"] - - def test_output_flag_short_form(self) -> None: - """-o is an alias for --output.""" - parser, subparsers = _make_parser() - add_tune_subparser(subparsers) - args = parser.parse_args( - [ - "tune", - "--tracker", - "sort", - "--gt-dir", - "/gt", - "--detections-dir", - "/det", - "-o", - "/out/params.json", - ] - ) - assert args.output == Path("/out/params.json") +from trackers.cli.tune import tune class TestTune: @@ -180,42 +100,118 @@ def test_returns_1_on_tuner_run_exception(self, tmp_path: Path) -> None: assert result == 1 -class TestRunTune: - def test_delegates_to_tune_with_namespace_args(self, tmp_path: Path) -> None: - """run_tune() passes all argparse.Namespace fields to tune() correctly.""" +class TestCliInvocation: + """tune() is wired into the jsonargparse CLI with the expected args.""" + + @staticmethod + def _invoke(args: list[str], spy: list[dict]) -> object: + """Run jsonargparse.CLI() with a recording spy for `tune`. + + The spy mirrors the real signature so jsonargparse can introspect it. + """ + from jsonargparse import CLI + + from trackers.cli.tune import tune as real_tune + + def spy_tune( + tracker: str, + gt_dir: Path, + detections_dir: Path, + objective: str = "HOTA", + n_trials: int = 100, + metrics: list[str] | None = None, + threshold: float = 0.5, + seqmap: Path | None = None, + fixed_params: dict | None = None, + images_dir: Path | None = None, + enqueue_defaults: bool = True, + seed: int | None = None, + output: Path | None = None, + ) -> int: + spy.append( + dict( + tracker=tracker, + gt_dir=gt_dir, + detections_dir=detections_dir, + objective=objective, + n_trials=n_trials, + metrics=metrics, + threshold=threshold, + seqmap=seqmap, + fixed_params=fixed_params, + images_dir=images_dir, + enqueue_defaults=enqueue_defaults, + seed=seed, + output=output, + ) + ) + return 0 + + # Copy the docstring so jsonargparse's introspection matches the real function. + spy_tune.__doc__ = real_tune.__doc__ + return CLI({"tune": spy_tune}, as_positional=False, args=args) + + def test_cli_dispatch_to_tune(self, tmp_path: Path) -> None: + """jsonargparse.CLI() parses the tune subcommand and forwards args.""" gt_dir = tmp_path / "gt" det_dir = tmp_path / "det" - output_path = tmp_path / "params.json" - args = argparse.Namespace( - tracker="sort", - gt_dir=gt_dir, - detections_dir=det_dir, - objective="MOTA", - n_trials=50, - metrics=["CLEAR", "HOTA"], - threshold=0.3, - seqmap=None, - fixed_params=None, - images_dir=None, - no_enqueue_defaults=False, - seed=None, - output=output_path, + spy: list[dict] = [] + result = self._invoke( + [ + "tune", + "--tracker", + "sort", + "--gt_dir", + str(gt_dir), + "--detections_dir", + str(det_dir), + "--objective", + "MOTA", + "--n_trials", + "50", + ], + spy, ) - with patch("trackers.scripts.tune.tune", return_value=0) as mock_tune: - result = run_tune(args) assert result == 0 - mock_tune.assert_called_once_with( - tracker="sort", - gt_dir=gt_dir, - detections_dir=det_dir, - objective="MOTA", - n_trials=50, - metrics=["CLEAR", "HOTA"], - threshold=0.3, - seqmap=None, - fixed_params=None, - images_dir=None, - enqueue_defaults=True, - seed=None, - output=output_path, + assert len(spy) == 1 + assert spy[0]["tracker"] == "sort" + assert spy[0]["gt_dir"] == gt_dir + assert spy[0]["detections_dir"] == det_dir + assert spy[0]["objective"] == "MOTA" + assert spy[0]["n_trials"] == 50 + + @pytest.mark.parametrize( + "flag,arg_value,attr,expected", + [ + ("--objective", "HOTA", "objective", "HOTA"), + ("--n_trials", "100", "n_trials", 100), + ("--threshold", "0.5", "threshold", 0.5), + ], + ) + def test_cli_defaults( + self, + tmp_path: Path, + flag: str, + arg_value: str, + attr: str, + expected: object, + ) -> None: + """Optional flags carry their declared defaults when invoked via CLI.""" + gt_dir = tmp_path / "gt" + det_dir = tmp_path / "det" + spy: list[dict] = [] + self._invoke( + [ + "tune", + "--tracker", + "sort", + "--gt_dir", + str(gt_dir), + "--detections_dir", + str(det_dir), + flag, + arg_value, + ], + spy, ) + assert spy[0][attr] == expected diff --git a/tests/utils/test_iou.py b/tests/utils/test_iou.py index 8c992558..6eddf05e 100644 --- a/tests/utils/test_iou.py +++ b/tests/utils/test_iou.py @@ -12,7 +12,7 @@ torch = pytest.importorskip("torch") torchvision = pytest.importorskip("torchvision") -from trackers.utils.iou import BaseIoU, BIoU, CIoU, DIoU, GIoU, IoU # noqa: E402 +from trackers.utils.iou import BaseIoU, BIoU, CIoU, DIoU, GIoU, IoU, variant_from_name # noqa: E402 def _torchvision_giou(boxes_1: np.ndarray, boxes_2: np.ndarray) -> np.ndarray: @@ -460,3 +460,34 @@ def test_inverted_coords_gives_zero_or_negative_similarity(self, metric: BaseIoU result = metric.compute(boxes_a, boxes_b) assert result.shape == (1, 1) assert np.isfinite(result).all(), "Inverted-coord box should not produce NaN/inf" + + +class TestVariantFromName: + """Tests for variant_from_name() registry lookup.""" + + @pytest.mark.parametrize( + ("name", "expected_type"), + [ + ("iou", IoU), + ("giou", GIoU), + ("diou", DIoU), + ("ciou", CIoU), + ("biou", BIoU), + ], + ) + def test_valid_names_return_correct_instance(self, name: str, expected_type: type) -> None: + """Each lowercase variant name resolves to the right BaseIoU subclass.""" + result = variant_from_name(name) + assert isinstance(result, expected_type) + + @pytest.mark.parametrize("name", ["IOU", "GIoU", "BIOU", "DiOU", "CIou"]) + def test_case_insensitive_lookup(self, name: str) -> None: + """Lookup is case-insensitive — any casing resolves without error.""" + result = variant_from_name(name) + assert isinstance(result, BaseIoU) + + @pytest.mark.parametrize("name", ["", "foo", "wiou", "iou2"]) + def test_invalid_name_raises_value_error(self, name: str) -> None: + """Unknown names raise ValueError; repr(name) appears in the error message.""" + with pytest.raises(ValueError, match=repr(name)): + variant_from_name(name) diff --git a/uv.lock b/uv.lock index 08d017d8..bab634d3 100644 --- a/uv.lock +++ b/uv.lock @@ -649,6 +649,15 @@ version = "0.6.2" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/a2/55/8f8cab2afd404cf578136ef2cc5dfb50baa1761b68c9da1fb1e4eed343c9/docopt-0.6.2.tar.gz", hash = "sha256:49b3a825280bd66b3aa83585ef59c4a8c82f2c8a522dbe754a8bc8d08c85c491", size = 25901, upload-time = "2014-06-16T11:18:57.406Z" } +[[package]] +name = "docstring-parser" +version = "0.18.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/4d/f332313098c1de1b2d2ff91cf2674415cc7cddab2ca1b01ae29774bd5fdf/docstring_parser-0.18.0.tar.gz", hash = "sha256:292510982205c12b1248696f44959db3cdd1740237a968ea1e2e7a900eeb2015", size = 29341, upload-time = "2026-04-14T04:09:19.867Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/5f/ed01f9a3cdffbd5a008556fc7b2a08ddb1cc6ace7effa7340604b1d16699/docstring_parser-0.18.0-py3-none-any.whl", hash = "sha256:b3fcbed555c47d8479be0796ef7e19c2670d428d72e96da63f3a40122860374b", size = 22484, upload-time = "2026-04-14T04:09:18.638Z" }, +] + [[package]] name = "docutils" version = "0.21.2" @@ -1103,6 +1112,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/20/b0/36bd937216ec521246249be3bf9855081de4c5e06a0c9b4219dbeda50373/importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd", size = 27656, upload-time = "2025-04-27T15:29:00.214Z" }, ] +[[package]] +name = "importlib-resources" +version = "7.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e4/06/b56dfa750b44e86157093bc8fca0ab81dccbf5260510de4eaf1cb69b5b99/importlib_resources-7.1.0.tar.gz", hash = "sha256:0722d4c6212489c530f2a145a34c0a7a3b4721bc96a15fada5930e2a0b760708", size = 44985, upload-time = "2026-04-12T16:36:09.232Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/db/55a262f3606bebcae07cc14095338471ad7c0bbcaa37707e6f0ee49725b7/importlib_resources-7.1.0-py3-none-any.whl", hash = "sha256:1bd7b48b4088eddb2cd16382150bb515af0bd2c70128194392725f82ad2c96a1", size = 37232, upload-time = "2026-04-12T16:36:08.219Z" }, +] + [[package]] name = "inference-models" version = "0.27.2" @@ -1228,6 +1246,24 @@ version = "3.0.1" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/5e/73/e01e4c5e11ad0494f4407a3f623ad4d87714909f50b17a06ed121034ff6e/jsmin-3.0.1.tar.gz", hash = "sha256:c0959a121ef94542e807a674142606f7e90214a2b3d1eb17300244bbb5cc2bfc", size = 13925, upload-time = "2022-01-16T20:35:59.13Z" } +[[package]] +name = "jsonargparse" +version = "4.48.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyyaml" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fa/03/fb33f57f4987eb5eef2f221dbeccb482b6b221ae97161498ff2e4ce41c55/jsonargparse-4.48.0.tar.gz", hash = "sha256:128f0897951190a08820c282b92408e2e9a508ef6d439f02bdb87244171e77d8", size = 122074, upload-time = "2026-04-10T06:52:40.309Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/e9/c922101c1e80455d4b44b766b353dafc990da350228fc2515790e5949dd5/jsonargparse-4.48.0-py3-none-any.whl", hash = "sha256:c6a92fd71eb256437371750bb11f436b9c3294da2535f1b0406346816f04be16", size = 131277, upload-time = "2026-04-10T06:52:37.394Z" }, +] + +[package.optional-dependencies] +signatures = [ + { name = "docstring-parser" }, + { name = "typeshed-client" }, +] + [[package]] name = "keyring" version = "25.6.0" @@ -4069,6 +4105,7 @@ name = "trackers" version = "2.4.0" source = { editable = "." } dependencies = [ + { name = "jsonargparse", extra = ["signatures"] }, { name = "numpy" }, { name = "opencv-python" }, { name = "requests" }, @@ -4115,6 +4152,7 @@ mypy-types = [ [package.metadata] requires-dist = [ { name = "inference-models", marker = "extra == 'detection'", specifier = ">=0.19.0" }, + { name = "jsonargparse", extras = ["signatures"], specifier = ">=4.48.0" }, { name = "numpy", specifier = ">=2.0.2" }, { name = "opencv-python", specifier = ">=4.8.0" }, { name = "optuna", marker = "extra == 'tune'", specifier = ">=3.0.0" }, @@ -4249,6 +4287,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/96/080db0afdf2c5cc5fe512b41354e8d114fe8f65e9510c56ff8dfd40216ce/types_requests-2.33.0.20260508-py3-none-any.whl", hash = "sha256:fa01459cca184229713df03709db46a905325906d27e042cd4fd7ea3d15d3400", size = 20722, upload-time = "2026-05-08T04:50:55.548Z" }, ] +[[package]] +name = "typeshed-client" +version = "2.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "importlib-resources" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/85/7d/62fbae352d5fb7ce5ef4d9ca73bf7a9b02b790d2524ab6ef1e0e799a5d1b/typeshed_client-2.11.0.tar.gz", hash = "sha256:0b8f2ab88f611f5e97b70d2a8123942d3d7d5c74cee8ae694db83422f32f9481", size = 522774, upload-time = "2026-05-01T14:51:52.38Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/22/fa16b462157bd869dfad528f5637506b9430ca63d48fb536ecf4cc78481a/typeshed_client-2.11.0-py3-none-any.whl", hash = "sha256:5745e0990b80b29a286b22d68f81779c5c7adf1cac8969eeafba44b73b486c36", size = 787609, upload-time = "2026-05-01T14:51:51.005Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0"