From 4a2351c3c50b7564e362d4b2cdd6a23e90b39ab4 Mon Sep 17 00:00:00 2001 From: jirka <6035284+Borda@users.noreply.github.com> Date: Tue, 12 May 2026 15:01:17 +0200 Subject: [PATCH 01/13] =?UTF-8?q?rename=20scripts=E2=86=92cli,=20add=20jso?= =?UTF-8?q?nargparse=20with=20--config=20YAML=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename `trackers/scripts/` → `trackers/cli/` and update `pyproject.toml` entry point to `trackers.cli.__main__:main` - Replace `argparse.ArgumentParser` with `jsonargparse.ArgumentParser` in `__main__.py`; all subcommand parsers now use `jsonargparse.DefaultHelpFormatter` - Add `--config` argument with `action="config"` — any subcommand run can now be specified as a YAML file (e.g. `trackers track --config run.yaml`) - Update handler type annotations from `argparse.Namespace` to `jsonargparse.Namespace`; `_add_tracker_params` / `add_*_subparser` use `Any` to accept both argparse and jsonargparse subparser objects - Add `jsonargparse>=4.48.0` to core dependencies; only transitive dep is PyYAML (depth 1, zero own deps) - Migrate all tests from `tests/scripts/` → `tests/cli/`; update patch paths (`trackers.scripts.*` → `trackers.cli.*`) --- Co-authored-by: Claude Code --- pyproject.toml | 3 ++- src/trackers/{scripts => cli}/__init__.py | 0 src/trackers/{scripts => cli}/__main__.py | 20 ++++++++++------ src/trackers/{scripts => cli}/download.py | 10 ++++---- src/trackers/{scripts => cli}/eval.py | 10 ++++---- src/trackers/{scripts => cli}/progress.py | 0 src/trackers/{scripts => cli}/track.py | 29 ++++++++++------------- src/trackers/{scripts => cli}/tune.py | 10 ++++---- tests/cli/__init__.py | 5 ++++ tests/{scripts => cli}/test_download.py | 8 +++---- tests/{scripts => cli}/test_progress.py | 4 ++-- tests/{scripts => cli}/test_track.py | 2 +- tests/{scripts => cli}/test_tune.py | 6 ++--- uv.lock | 14 +++++++++++ 14 files changed, 74 insertions(+), 47 deletions(-) rename src/trackers/{scripts => cli}/__init__.py (100%) rename src/trackers/{scripts => cli}/__main__.py (75%) rename src/trackers/{scripts => cli}/download.py (93%) rename src/trackers/{scripts => cli}/eval.py (96%) rename src/trackers/{scripts => cli}/progress.py (100%) rename src/trackers/{scripts => cli}/track.py (96%) rename src/trackers/{scripts => cli}/tune.py (96%) create mode 100644 tests/cli/__init__.py rename tests/{scripts => cli}/test_download.py (97%) rename tests/{scripts => cli}/test_progress.py (98%) rename tests/{scripts => cli}/test_track.py (99%) rename tests/{scripts => cli}/test_tune.py (97%) diff --git a/pyproject.toml b/pyproject.toml index cc622694..514168c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "opencv-python>=4.8.0", "rich>=13.0.0", "requests>=2.28.0", + "jsonargparse>=4.48.0", ] [project.optional-dependencies] @@ -48,7 +49,7 @@ detection = ["inference-models>=0.19.0"] tune = ["optuna>=3.0.0"] [project.scripts] -trackers = "trackers.scripts.__main__:main" +trackers = "trackers.cli.__main__:main" [dependency-groups] dev = [ diff --git a/src/trackers/scripts/__init__.py b/src/trackers/cli/__init__.py similarity index 100% rename from src/trackers/scripts/__init__.py rename to src/trackers/cli/__init__.py diff --git a/src/trackers/scripts/__main__.py b/src/trackers/cli/__main__.py similarity index 75% rename from src/trackers/scripts/__main__.py rename to src/trackers/cli/__main__.py index 0993f8c7..afa12bb4 100644 --- a/src/trackers/scripts/__main__.py +++ b/src/trackers/cli/__main__.py @@ -7,10 +7,11 @@ from __future__ import annotations -import argparse import sys import warnings +import jsonargparse + def main() -> int: """Main entry point for the trackers CLI.""" @@ -21,7 +22,7 @@ def main() -> int: stacklevel=2, ) - parser = argparse.ArgumentParser( + parser = jsonargparse.ArgumentParser( prog="trackers", description="Command-line tools for multi-object tracking.", epilog="For more information, visit: https://github.com/roboflow/trackers", @@ -31,18 +32,23 @@ def main() -> int: action="store_true", help="Show version and exit.", ) + parser.add_argument( + "--config", + action="config", + help="Path to a YAML/JSON config file with default argument values.", + ) - subparsers = parser.add_subparsers( + subparsers = parser.add_subparsers( # type: ignore[var-annotated] dest="command", title="commands", description="Available commands:", ) # Import and register subcommands - from trackers.scripts.download import add_download_subparser - from trackers.scripts.eval import add_eval_subparser - from trackers.scripts.track import add_track_subparser - from trackers.scripts.tune import add_tune_subparser + from trackers.cli.download import add_download_subparser + from trackers.cli.eval import add_eval_subparser + from trackers.cli.track import add_track_subparser + from trackers.cli.tune import add_tune_subparser add_download_subparser(subparsers) add_eval_subparser(subparsers) diff --git a/src/trackers/scripts/download.py b/src/trackers/cli/download.py similarity index 93% rename from src/trackers/scripts/download.py rename to src/trackers/cli/download.py index de8e461f..7982fe05 100644 --- a/src/trackers/scripts/download.py +++ b/src/trackers/cli/download.py @@ -7,9 +7,10 @@ from __future__ import annotations -import argparse import sys +from typing import Any +import jsonargparse from rich.console import Console from rich.panel import Panel @@ -17,14 +18,13 @@ from trackers.datasets.manifest import _DATASETS -def add_download_subparser( - subparsers: argparse._SubParsersAction, -) -> None: +def add_download_subparser(subparsers: Any) -> None: """Add the download subcommand to the argument parser.""" parser = subparsers.add_parser( "download", help="Download benchmark tracking datasets.", description="Download tracking datasets from the official trackers bucket.", + formatter_class=jsonargparse.DefaultHelpFormatter, ) parser.add_argument( @@ -62,7 +62,7 @@ def add_download_subparser( parser.set_defaults(func=_run_download) -def _run_download(args: argparse.Namespace) -> int: +def _run_download(args: jsonargparse.Namespace) -> int: """Execute the download subcommand.""" if args.list: _print_available() diff --git a/src/trackers/scripts/eval.py b/src/trackers/cli/eval.py similarity index 96% rename from src/trackers/scripts/eval.py rename to src/trackers/cli/eval.py index 7bd25f21..a5568c28 100644 --- a/src/trackers/scripts/eval.py +++ b/src/trackers/cli/eval.py @@ -7,19 +7,21 @@ from __future__ import annotations -import argparse import logging import sys from pathlib import Path +from typing import Any +import jsonargparse -def add_eval_subparser(subparsers: argparse._SubParsersAction) -> None: + +def add_eval_subparser(subparsers: Any) -> None: """Add the eval subcommand to the argument parser.""" parser = subparsers.add_parser( "eval", help="Evaluate tracker predictions against ground truth.", description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, + formatter_class=jsonargparse.DefaultHelpFormatter, ) # Single sequence mode @@ -96,7 +98,7 @@ def add_eval_subparser(subparsers: argparse._SubParsersAction) -> None: parser.set_defaults(func=run_eval) -def run_eval(args: argparse.Namespace) -> int: +def run_eval(args: jsonargparse.Namespace) -> int: """Execute the eval command.""" # Configure logging to show detection info logging.basicConfig( diff --git a/src/trackers/scripts/progress.py b/src/trackers/cli/progress.py similarity index 100% rename from src/trackers/scripts/progress.py rename to src/trackers/cli/progress.py diff --git a/src/trackers/scripts/track.py b/src/trackers/cli/track.py similarity index 96% rename from src/trackers/scripts/track.py rename to src/trackers/cli/track.py index 539a3a23..d76bc73a 100644 --- a/src/trackers/scripts/track.py +++ b/src/trackers/cli/track.py @@ -7,21 +7,21 @@ from __future__ import annotations -import argparse import sys -from contextlib import nullcontext +from contextlib import nullcontext, suppress from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any +import jsonargparse import numpy as np import supervision as sv from trackers import frames_from_source +from trackers.cli.progress import _classify_source, _SourceInfo, _TrackingProgress from trackers.core.base import BaseTracker from trackers.io.mot import _mot_frame_to_detections, _MOTOutput, load_mot_file from trackers.io.paths import _resolve_video_output_path, _validate_output_path from trackers.io.video import _DEFAULT_OUTPUT_FPS, _DisplayWindow, _VideoOutput -from trackers.scripts.progress import _classify_source, _SourceInfo, _TrackingProgress from trackers.utils.device import _best_device if TYPE_CHECKING: @@ -52,13 +52,13 @@ ) -def add_track_subparser(subparsers: argparse._SubParsersAction) -> None: +def add_track_subparser(subparsers: Any) -> None: """Add the track subcommand to the argument parser.""" parser = subparsers.add_parser( "track", help="Track objects in video using detection and tracking.", description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, + formatter_class=jsonargparse.DefaultHelpFormatter, ) # Source options @@ -235,7 +235,7 @@ def add_track_subparser(subparsers: argparse._SubParsersAction) -> None: parser.set_defaults(func=run_track) -def _add_tracker_params(group: argparse._ArgumentGroup) -> None: +def _add_tracker_params(group: Any) -> None: """Add tracker-specific parameters from registry to argument group.""" for tracker_id in BaseTracker._registered_trackers(): info = BaseTracker._lookup_tracker(tracker_id) @@ -258,14 +258,11 @@ def _add_tracker_params(group: argparse._ArgumentGroup) -> None: kwargs["type"] = param_info.param_type kwargs["metavar"] = param_info.param_type.__name__.upper() - try: + with suppress(Exception): group.add_argument(arg_name, **kwargs) - except argparse.ArgumentError: - # Parameter already added by another tracker - pass -def run_track(args: argparse.Namespace) -> int: +def run_track(args: jsonargparse.Namespace) -> int: """Execute the track command.""" needs_frames = args.output or args.display @@ -333,7 +330,7 @@ def run_track(args: argparse.Namespace) -> int: def _run_frameless( - args: argparse.Namespace, + args: jsonargparse.Namespace, detections_data: dict | None, class_filter: list[int] | None, track_id_filter: list[int] | None, @@ -382,8 +379,8 @@ def _run_frameless( def _run_with_source( - args: argparse.Namespace, - model, + args: jsonargparse.Namespace, + model: AnyModel | None, detections_data: dict | None, class_names: list[str], class_filter: list[int] | None, @@ -600,7 +597,7 @@ def _run_model(model: AnyModel, frame: np.ndarray, confidence: float) -> sv.Dete return detections -def _extract_tracker_params(tracker_id: str, args: argparse.Namespace) -> dict[str, object]: +def _extract_tracker_params(tracker_id: str, args: jsonargparse.Namespace) -> dict[str, object]: """Extract tracker parameters from CLI args. Args: diff --git a/src/trackers/scripts/tune.py b/src/trackers/cli/tune.py similarity index 96% rename from src/trackers/scripts/tune.py rename to src/trackers/cli/tune.py index 0b889b45..a8471871 100644 --- a/src/trackers/scripts/tune.py +++ b/src/trackers/cli/tune.py @@ -7,13 +7,15 @@ from __future__ import annotations -import argparse import json import sys from pathlib import Path +from typing import Any +import jsonargparse -def add_tune_subparser(subparsers: argparse._SubParsersAction) -> None: + +def add_tune_subparser(subparsers: Any) -> None: """Add the tune subcommand to the argument parser.""" parser = subparsers.add_parser( "tune", @@ -22,7 +24,7 @@ def add_tune_subparser(subparsers: argparse._SubParsersAction) -> None: "Run Optuna-based hyperparameter optimisation for a registered " "tracker using pre-computed detections and ground-truth MOT files." ), - formatter_class=argparse.RawDescriptionHelpFormatter, + formatter_class=jsonargparse.DefaultHelpFormatter, ) parser.add_argument( @@ -91,7 +93,7 @@ def add_tune_subparser(subparsers: argparse._SubParsersAction) -> None: parser.set_defaults(func=run_tune) -def run_tune(args: argparse.Namespace) -> int: +def run_tune(args: jsonargparse.Namespace) -> int: """Execute the tune command.""" return tune( tracker=args.tracker, diff --git a/tests/cli/__init__.py b/tests/cli/__init__.py new file mode 100644 index 00000000..57226e88 --- /dev/null +++ b/tests/cli/__init__.py @@ -0,0 +1,5 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ diff --git a/tests/scripts/test_download.py b/tests/cli/test_download.py similarity index 97% rename from tests/scripts/test_download.py rename to tests/cli/test_download.py index 94b3f573..7f19ff0f 100644 --- a/tests/scripts/test_download.py +++ b/tests/cli/test_download.py @@ -11,12 +11,12 @@ import pytest -from trackers.datasets.download import _DEFAULT_CACHE_DIR, _DEFAULT_OUTPUT_DIR -from trackers.scripts.download import ( +from trackers.cli.download import ( _print_available, _run_download, add_download_subparser, ) +from trackers.datasets.download import _DEFAULT_CACHE_DIR, _DEFAULT_OUTPUT_DIR def _parse_args(argv: list[str]) -> argparse.Namespace: @@ -73,7 +73,7 @@ def test_list_triggers_print(self) -> None: """--list calls _print_available and returns 0.""" args = _parse_args(["download", "--list"]) - with patch("trackers.scripts.download._print_available") as mock_print: + with patch("trackers.cli.download._print_available") as mock_print: rc = _run_download(args) assert rc == 0 mock_print.assert_called_once() @@ -82,7 +82,7 @@ def test_list_takes_precedence_over_dataset(self) -> None: """--list wins over dataset positional.""" args = _parse_args(["download", "mot17", "--list"]) - with patch("trackers.scripts.download._print_available") as mock_print: + with patch("trackers.cli.download._print_available") as mock_print: rc = _run_download(args) assert rc == 0 mock_print.assert_called_once() diff --git a/tests/scripts/test_progress.py b/tests/cli/test_progress.py similarity index 98% rename from tests/scripts/test_progress.py rename to tests/cli/test_progress.py index 91486afe..85d3bab3 100644 --- a/tests/scripts/test_progress.py +++ b/tests/cli/test_progress.py @@ -17,7 +17,7 @@ import pytest from rich.console import Console -from trackers.scripts.progress import ( +from trackers.cli.progress import ( _classify_source, _format_time, _SourceInfo, @@ -129,7 +129,7 @@ def test_video_with_zero_frame_count(self) -> None: cv2.CAP_PROP_FPS: 30.0, }.get(prop, 0.0) - with patch("trackers.scripts.progress.cv2.VideoCapture", return_value=mock_cap): + with patch("trackers.cli.progress.cv2.VideoCapture", return_value=mock_cap): info = _classify_source("some_video.mp4") assert info.source_type == "video" diff --git a/tests/scripts/test_track.py b/tests/cli/test_track.py similarity index 99% rename from tests/scripts/test_track.py rename to tests/cli/test_track.py index be3867ed..5ac9708e 100644 --- a/tests/scripts/test_track.py +++ b/tests/cli/test_track.py @@ -12,7 +12,7 @@ import pytest import supervision as sv -from trackers.scripts.track import ( +from trackers.cli.track import ( _format_labels, _init_annotators, _resolve_class_filter, diff --git a/tests/scripts/test_tune.py b/tests/cli/test_tune.py similarity index 97% rename from tests/scripts/test_tune.py rename to tests/cli/test_tune.py index 43e85799..1d6733b5 100644 --- a/tests/scripts/test_tune.py +++ b/tests/cli/test_tune.py @@ -4,7 +4,7 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ -"""CLI-level tests for trackers/scripts/tune.py.""" +"""CLI-level tests for trackers/cli/tune.py.""" from __future__ import annotations @@ -15,7 +15,7 @@ import pytest -from trackers.scripts.tune import add_tune_subparser, run_tune, tune +from trackers.cli.tune import add_tune_subparser, run_tune, tune def _make_parser() -> tuple[argparse.ArgumentParser, argparse._SubParsersAction]: @@ -197,7 +197,7 @@ def test_delegates_to_tune_with_namespace_args(self, tmp_path: Path) -> None: seqmap=None, output=output_path, ) - with patch("trackers.scripts.tune.tune", return_value=0) as mock_tune: + with patch("trackers.cli.tune.tune", return_value=0) as mock_tune: result = run_tune(args) assert result == 0 mock_tune.assert_called_once_with( diff --git a/uv.lock b/uv.lock index 5c102b60..7fbdc6c1 100644 --- a/uv.lock +++ b/uv.lock @@ -1183,6 +1183,18 @@ version = "3.0.1" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/5e/73/e01e4c5e11ad0494f4407a3f623ad4d87714909f50b17a06ed121034ff6e/jsmin-3.0.1.tar.gz", hash = "sha256:c0959a121ef94542e807a674142606f7e90214a2b3d1eb17300244bbb5cc2bfc", size = 13925, upload-time = "2022-01-16T20:35:59.13Z" } +[[package]] +name = "jsonargparse" +version = "4.48.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyyaml" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fa/03/fb33f57f4987eb5eef2f221dbeccb482b6b221ae97161498ff2e4ce41c55/jsonargparse-4.48.0.tar.gz", hash = "sha256:128f0897951190a08820c282b92408e2e9a508ef6d439f02bdb87244171e77d8", size = 122074, upload-time = "2026-04-10T06:52:40.309Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/e9/c922101c1e80455d4b44b766b353dafc990da350228fc2515790e5949dd5/jsonargparse-4.48.0-py3-none-any.whl", hash = "sha256:c6a92fd71eb256437371750bb11f436b9c3294da2535f1b0406346816f04be16", size = 131277, upload-time = "2026-04-10T06:52:37.394Z" }, +] + [[package]] name = "keyring" version = "25.6.0" @@ -4026,6 +4038,7 @@ name = "trackers" version = "2.4.0" source = { editable = "." } dependencies = [ + { name = "jsonargparse" }, { name = "numpy" }, { name = "opencv-python" }, { name = "requests" }, @@ -4070,6 +4083,7 @@ mypy-types = [ [package.metadata] requires-dist = [ { name = "inference-models", marker = "extra == 'detection'", specifier = ">=0.19.0" }, + { name = "jsonargparse", specifier = ">=4.48.0" }, { name = "numpy", specifier = ">=2.0.2" }, { name = "opencv-python", specifier = ">=4.8.0" }, { name = "optuna", marker = "extra == 'tune'", specifier = ">=3.0.0" }, From e9cc3d901939414e7d245338e8f3fe65987791c2 Mon Sep 17 00:00:00 2001 From: jirka <6035284+Borda@users.noreply.github.com> Date: Tue, 12 May 2026 16:42:15 +0200 Subject: [PATCH 02/13] =?UTF-8?q?rewrite=20cli=20with=20jsonargparse.CLI()?= =?UTF-8?q?=20=E2=80=94=201272=E2=86=92867=20lines=20(-32%)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - replace add_subparsers boilerplate with CLI({"track": ..., "eval": ..., "tune": ..., "download": ...}) on typed entry points; subcommands and --config YAML are derived automatically from function signatures - track.py: collapse 735→547 lines by dropping argparse plumbing; group dynamic per-tracker overrides under a TrackerParams dataclass (--tracker_params.=...) so jsonargparse renders one nested help group - __main__.py: 76→40 lines; main() now warns, dispatches via CLI, and returns the entry-point exit code - download.py: positional `dataset` becomes `--dataset` (jsonargparse has no positionals); `--list` renamed to `--list_available` - update tests/cli/test_download.py and test_tune.py to call the new typed functions directly and exercise the CLI via jsonargparse.CLI(...args=[...]) instead of constructing argparse subparsers - 498 non-integration tests pass; all five `--help` screens render --- Co-authored-by: Claude Code --- src/trackers/cli/__main__.py | 62 +--- src/trackers/cli/download.py | 91 ++--- src/trackers/cli/eval.py | 188 +++------- src/trackers/cli/track.py | 700 +++++++++++++---------------------- src/trackers/cli/tune.py | 114 +----- tests/cli/test_download.py | 117 ++---- tests/cli/test_tune.py | 210 ++++++----- 7 files changed, 502 insertions(+), 980 deletions(-) diff --git a/src/trackers/cli/__main__.py b/src/trackers/cli/__main__.py index afa12bb4..8e3c521c 100644 --- a/src/trackers/cli/__main__.py +++ b/src/trackers/cli/__main__.py @@ -5,71 +5,35 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ +"""Command-line entry point for the trackers package.""" + from __future__ import annotations import sys import warnings -import jsonargparse +from jsonargparse import CLI + +from trackers.cli.download import download +from trackers.cli.eval import eval_cmd +from trackers.cli.track import track +from trackers.cli.tune import tune def main() -> int: - """Main entry point for the trackers CLI.""" - # Beta warning + """Dispatch to track / eval / tune / download via jsonargparse CLI.""" warnings.warn( "The trackers CLI is in beta. APIs may change in future releases.", UserWarning, stacklevel=2, ) - - parser = jsonargparse.ArgumentParser( + rc = CLI( + {"track": track, "eval": eval_cmd, "tune": tune, "download": download}, + as_positional=False, prog="trackers", description="Command-line tools for multi-object tracking.", - epilog="For more information, visit: https://github.com/roboflow/trackers", - ) - parser.add_argument( - "--version", - action="store_true", - help="Show version and exit.", - ) - parser.add_argument( - "--config", - action="config", - help="Path to a YAML/JSON config file with default argument values.", - ) - - subparsers = parser.add_subparsers( # type: ignore[var-annotated] - dest="command", - title="commands", - description="Available commands:", ) - - # Import and register subcommands - from trackers.cli.download import add_download_subparser - from trackers.cli.eval import add_eval_subparser - from trackers.cli.track import add_track_subparser - from trackers.cli.tune import add_tune_subparser - - add_download_subparser(subparsers) - add_eval_subparser(subparsers) - add_track_subparser(subparsers) - add_tune_subparser(subparsers) - - # Parse arguments - args = parser.parse_args() - - if args.version: - from importlib.metadata import version - - print(f"trackers {version('trackers')}") - return 0 - - if args.command is None: - parser.print_help() - return 0 - - # Execute the command - return args.func(args) + return int(rc) if rc is not None else 0 if __name__ == "__main__": diff --git a/src/trackers/cli/download.py b/src/trackers/cli/download.py index 7982fe05..73becaf5 100644 --- a/src/trackers/cli/download.py +++ b/src/trackers/cli/download.py @@ -5,12 +5,12 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ +"""``trackers download`` subcommand — fetch benchmark tracking datasets.""" + from __future__ import annotations import sys -from typing import Any -import jsonargparse from rich.console import Console from rich.panel import Panel @@ -18,72 +18,51 @@ from trackers.datasets.manifest import _DATASETS -def add_download_subparser(subparsers: Any) -> None: - """Add the download subcommand to the argument parser.""" - parser = subparsers.add_parser( - "download", - help="Download benchmark tracking datasets.", - description="Download tracking datasets from the official trackers bucket.", - formatter_class=jsonargparse.DefaultHelpFormatter, - ) - - parser.add_argument( - "--list", - action="store_true", - help="List available datasets, splits, and asset types.", - ) - parser.add_argument( - "dataset", - nargs="?", - help="Dataset name (e.g. mot17, sportsmot).", - ) - parser.add_argument( - "--split", - help="Comma-separated splits to download (e.g. train,val,test). " - "If omitted, all available splits are downloaded.", - ) - parser.add_argument( - "--asset", - help="Comma-separated assets to download: annotations,frames,detections. " - "If omitted, all available assets are downloaded.", - ) - parser.add_argument( - "-o", - "--output", - default=_DEFAULT_OUTPUT_DIR, - help="Output directory (default: current directory).", - ) - parser.add_argument( - "--cache-dir", - default=_DEFAULT_CACHE_DIR, - help="Cache directory for downloaded ZIPs (default: ~/.cache/trackers).", - ) - - parser.set_defaults(func=_run_download) - - -def _run_download(args: jsonargparse.Namespace) -> int: - """Execute the download subcommand.""" - if args.list: +def download( + dataset: str | None = None, + split: str | None = None, + asset: str | None = None, + output: str = _DEFAULT_OUTPUT_DIR, + cache_dir: str = _DEFAULT_CACHE_DIR, + list_available: bool = False, +) -> int: + """Download benchmark tracking datasets from the official trackers bucket. + + Args: + dataset: Dataset name (e.g. ``mot17``, ``sportsmot``). Required unless + ``list_available`` is set. + split: Comma-separated splits to download (e.g. ``train,val,test``). + ``None`` selects every available split. + asset: Comma-separated assets to download (``annotations,frames,detections``). + ``None`` selects every available asset. + output: Output directory. Defaults to the current working directory. + cache_dir: Cache directory for downloaded ZIPs. + list_available: When ``True``, print the available datasets, splits, and + asset types, then exit. + + Returns: + Exit code: ``0`` on success, ``1`` on error. + """ + if list_available: _print_available() return 0 - if not args.dataset: - print("Please specify a dataset name or use --list.", file=sys.stderr) + if not dataset: + print("Please specify a dataset name or use --list_available.", file=sys.stderr) return 1 from trackers.datasets.download import download_dataset - split_list = [s.strip() for s in args.split.split(",")] if args.split else None - asset_list = [a.strip() for a in args.asset.split(",")] if args.asset else None + split_list = [s.strip() for s in split.split(",")] if split else None + asset_list = [a.strip() for a in asset.split(",")] if asset else None try: download_dataset( - dataset=args.dataset, + dataset=dataset, split=split_list, asset=asset_list, - output=args.output, - cache_dir=args.cache_dir, + output=output, + cache_dir=cache_dir, ) except Exception as e: print(f"Error: {e}", file=sys.stderr) diff --git a/src/trackers/cli/eval.py b/src/trackers/cli/eval.py index a5568c28..cd1c2576 100644 --- a/src/trackers/cli/eval.py +++ b/src/trackers/cli/eval.py @@ -5,166 +5,98 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ +"""``trackers eval`` subcommand — evaluate tracker predictions against ground truth.""" + from __future__ import annotations import logging import sys from pathlib import Path -from typing import Any - -import jsonargparse - -def add_eval_subparser(subparsers: Any) -> None: - """Add the eval subcommand to the argument parser.""" - parser = subparsers.add_parser( - "eval", - help="Evaluate tracker predictions against ground truth.", - description=__doc__, - formatter_class=jsonargparse.DefaultHelpFormatter, - ) - - # Single sequence mode - single_group = parser.add_argument_group("single sequence evaluation") - single_group.add_argument( - "--gt", - type=Path, - metavar="PATH", - help="Path to ground truth file (MOT format).", - ) - single_group.add_argument( - "--tracker", - type=Path, - metavar="PATH", - help="Path to tracker predictions file (MOT format).", - ) - # Benchmark mode - bench_group = parser.add_argument_group("benchmark evaluation") - bench_group.add_argument( - "--gt-dir", - type=Path, - metavar="DIR", - help="Directory containing ground truth files.", - ) - bench_group.add_argument( - "--tracker-dir", - type=Path, - metavar="DIR", - help="Directory containing tracker prediction files.", - ) - bench_group.add_argument( - "--seqmap", - type=Path, - metavar="PATH", - help="Sequence map file listing sequences to evaluate.", - ) +def eval_cmd( + gt: Path | None = None, + tracker: Path | None = None, + gt_dir: Path | None = None, + tracker_dir: Path | None = None, + seqmap: Path | None = None, + metrics: list[str] | None = None, + threshold: float = 0.5, + columns: list[str] | None = None, + output: Path | None = None, +) -> int: + """Evaluate tracker predictions against ground-truth MOT files. + + Two modes: + + - Single sequence: pass ``gt`` and ``tracker``. + - Benchmark: pass ``gt_dir`` and ``tracker_dir`` (with optional ``seqmap``). + + Args: + gt: Ground-truth file (MOT format) for single-sequence mode. + tracker: Tracker predictions file (MOT format) for single-sequence mode. + gt_dir: Directory of ground-truth files for benchmark mode. + tracker_dir: Directory of tracker prediction files for benchmark mode. + seqmap: Sequence map listing sequences to evaluate. + metrics: Metrics to compute. Options: ``CLEAR``, ``HOTA``, ``Identity``. + Defaults to ``["CLEAR"]``. + threshold: IoU threshold for CLEAR and Identity matching. + columns: Metric columns to display. ``None`` auto-selects from + available metrics. + output: Output JSON file for results. + + Returns: + Exit code: ``0`` on success, ``1`` on error. + """ + metrics = metrics or ["CLEAR"] - # Common options - parser.add_argument( - "--metrics", - nargs="+", - default=["CLEAR"], - choices=["CLEAR", "HOTA", "Identity"], - help="Metrics to compute. Default: CLEAR. Options: CLEAR, HOTA, Identity", - ) - parser.add_argument( - "--threshold", - type=float, - default=0.5, - help="IoU threshold for CLEAR and Identity matching. Default: 0.5", - ) - parser.add_argument( - "--columns", - nargs="+", - default=None, - metavar="COL", - help=( - "Metric columns to display. Default: auto-selected based on metrics. " - "CLEAR: MOTA, MOTP, MODA, CLR_Re, CLR_Pr, MTR, PTR, MLR, sMOTA, " - "CLR_TP, CLR_FN, CLR_FP, IDSW, MT, PT, ML, Frag. " - "HOTA: HOTA, DetA, AssA, DetRe, DetPr, AssRe, AssPr, LocA. " - "Identity: IDF1, IDR, IDP, IDTP, IDFN, IDFP" - ), - ) - parser.add_argument( - "--output", - "-o", - type=Path, - metavar="PATH", - help="Output file for results (JSON format).", - ) - - parser.set_defaults(func=run_eval) - - -def run_eval(args: jsonargparse.Namespace) -> int: - """Execute the eval command.""" - # Configure logging to show detection info logging.basicConfig( level=logging.INFO, format="%(message)s", handlers=[logging.StreamHandler(sys.stderr)], ) - # Validate arguments - single_mode = args.gt is not None and args.tracker is not None - benchmark_mode = args.gt_dir is not None and args.tracker_dir is not None + single_mode = gt is not None and tracker is not None + benchmark_mode = gt_dir is not None and tracker_dir is not None if not single_mode and not benchmark_mode: - print( - "Error: Must specify either --gt/--tracker or --gt-dir/--tracker-dir", - file=sys.stderr, - ) + print("Error: Must specify either --gt/--tracker or --gt_dir/--tracker_dir", file=sys.stderr) return 1 if single_mode and benchmark_mode: - print( - "Error: Cannot use both single sequence and benchmark mode", - file=sys.stderr, - ) + print("Error: Cannot use both single sequence and benchmark mode", file=sys.stderr) return 1 - # Columns: None means auto-select based on available metrics - columns = args.columns - - # Import evaluation functions from trackers.eval import evaluate_mot_sequence, evaluate_mot_sequences try: if single_mode: + assert gt is not None and tracker is not None # noqa: S101 — narrows for type checker seq_result = evaluate_mot_sequence( - gt_path=args.gt, - tracker_path=args.tracker, - metrics=args.metrics, - threshold=args.threshold, + gt_path=gt, + tracker_path=tracker, + metrics=metrics, + threshold=threshold, ) print(seq_result.table(columns=columns)) - - # Save results if output specified - if args.output: - args.output.parent.mkdir(parents=True, exist_ok=True) - args.output.write_text(seq_result.json()) - print(f"\nResults saved to: {args.output}") + if output: + output.parent.mkdir(parents=True, exist_ok=True) + output.write_text(seq_result.json()) + print(f"\nResults saved to: {output}") else: + assert gt_dir is not None and tracker_dir is not None # noqa: S101 — narrows for type checker bench_result = evaluate_mot_sequences( - gt_dir=args.gt_dir, - tracker_dir=args.tracker_dir, - seqmap=args.seqmap, - metrics=args.metrics, - threshold=args.threshold, + gt_dir=gt_dir, + tracker_dir=tracker_dir, + seqmap=seqmap, + metrics=metrics, + threshold=threshold, ) print(bench_result.table(columns=columns)) - - # Save results if output specified - if args.output: - bench_result.save(args.output) - print(f"\nResults saved to: {args.output}") - - except FileNotFoundError as e: - print(f"Error: {e}", file=sys.stderr) - return 1 - except ValueError as e: + if output: + bench_result.save(output) + print(f"\nResults saved to: {output}") + except (FileNotFoundError, ValueError) as e: print(f"Error: {e}", file=sys.stderr) return 1 diff --git a/src/trackers/cli/track.py b/src/trackers/cli/track.py index d76bc73a..e5c6d394 100644 --- a/src/trackers/cli/track.py +++ b/src/trackers/cli/track.py @@ -5,14 +5,16 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ +"""``trackers track`` subcommand — run a detector + tracker over a video source.""" + from __future__ import annotations import sys -from contextlib import nullcontext, suppress +from contextlib import nullcontext +from dataclasses import asdict, dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING -import jsonargparse import numpy as np import supervision as sv @@ -52,292 +54,173 @@ ) -def add_track_subparser(subparsers: Any) -> None: - """Add the track subcommand to the argument parser.""" - parser = subparsers.add_parser( - "track", - help="Track objects in video using detection and tracking.", - description=__doc__, - formatter_class=jsonargparse.DefaultHelpFormatter, - ) - - # Source options - source_group = parser.add_argument_group("source") - source_group.add_argument( - "--source", - type=str, - default=None, - metavar="PATH", - help="Video file, webcam index (0), RTSP URL, or image directory.", - ) - - # Detection options (mutually exclusive) - detection_group = parser.add_argument_group("detection") - det_mutex = detection_group.add_mutually_exclusive_group(required=False) - det_mutex.add_argument( - "--model", - type=str, - default=DEFAULT_MODEL, - metavar="ID", - help=( - "Model ID for detection. Pretrained: rfdetr-nano, rfdetr-base, etc. " - f"Custom: workspace/project/version. Default: {DEFAULT_MODEL}" - ), - ) - det_mutex.add_argument( - "--detections", - type=Path, - metavar="PATH", - help="Load pre-computed detections from MOT format file.", - ) - - # Model options - model_group = parser.add_argument_group("model options") - model_group.add_argument( - "--model.confidence", - type=float, - default=DEFAULT_CONFIDENCE, - dest="model_confidence", - metavar="FLOAT", - help=f"Detection confidence threshold. Default: {DEFAULT_CONFIDENCE}", - ) - model_group.add_argument( - "--model.device", - type=str, - default=DEFAULT_DEVICE, - dest="model_device", - metavar="DEVICE", - help=f"Device: auto, cpu, cuda, cuda:0, mps. Default: {DEFAULT_DEVICE}", - ) - model_group.add_argument( - "--model.api_key", - type=str, - default=None, - dest="model_api_key", - metavar="KEY", - help="Roboflow API key for custom models.", - ) - - # Filtering options - filter_group = parser.add_argument_group("filtering") - filter_group.add_argument( - "--classes", - type=str, - default=None, - metavar="NAMES_OR_IDS", - help="Filter by class names or IDs (comma-separated, e.g., person,car).", - ) - filter_group.add_argument( - "--track_ids", - type=str, - default=None, - metavar="IDS", - help="Filter output by track IDs (comma-separated, e.g., 1,3,5)", - ) - - # Tracker options - tracker_group = parser.add_argument_group("tracker options") - available_trackers = BaseTracker._registered_trackers() - tracker_group.add_argument( - "--tracker", - type=str, - default=DEFAULT_TRACKER, - choices=available_trackers if available_trackers else [DEFAULT_TRACKER, "sort"], - metavar="ID", - help=f"Tracking algorithm. Default: {DEFAULT_TRACKER}", - ) - - # Add dynamic tracker parameters - _add_tracker_params(tracker_group) - - # Output options - output_group = parser.add_argument_group("output") - output_group.add_argument( - "-o", - "--output", - type=Path, - default=None, - metavar="PATH", - help="Output video file path.", - ) - output_group.add_argument( - "--mot-output", - type=Path, - default=None, - dest="mot_output", - metavar="PATH", - help="Output MOT format file path.", - ) - output_group.add_argument( - "--overwrite", - action="store_true", - help="Overwrite existing output files.", - ) - - # Visualization options - vis_group = parser.add_argument_group("visualization") - vis_group.add_argument( - "--display", - action="store_true", - help="Show preview window.", - ) - vis_group.add_argument( - "--show-boxes", - action="store_true", - default=True, - dest="show_boxes", - help="Draw bounding boxes. Default: True", - ) - vis_group.add_argument( - "--no-boxes", - action="store_false", - dest="show_boxes", - help="Disable bounding boxes.", - ) - vis_group.add_argument( - "--show-masks", - action="store_true", - dest="show_masks", - help="Draw segmentation masks (seg models only).", - ) - vis_group.add_argument( - "--show-labels", - action="store_true", - dest="show_labels", - help="Show class labels.", - ) - vis_group.add_argument( - "--show-ids", - action="store_true", - default=True, - dest="show_ids", - help="Show track IDs. Default: True", - ) - vis_group.add_argument( - "--no-ids", - action="store_false", - dest="show_ids", - help="Disable track IDs.", - ) - vis_group.add_argument( - "--show-confidence", - action="store_true", - dest="show_confidence", - help="Show confidence scores.", - ) - vis_group.add_argument( - "--show-trajectories", - action="store_true", - dest="show_trajectories", - help="Draw track trajectories.", - ) - - parser.set_defaults(func=run_track) - - -def _add_tracker_params(group: Any) -> None: - """Add tracker-specific parameters from registry to argument group.""" - for tracker_id in BaseTracker._registered_trackers(): - info = BaseTracker._lookup_tracker(tracker_id) - if info is None: - continue - - for param_name, param_info in info.parameters.items(): - arg_name = f"--tracker.{param_name}" - dest_name = f"tracker_{param_name}" - - kwargs: dict = { - "dest": dest_name, - "default": param_info.default_value, - "help": f"{param_info.description} Default: {param_info.default_value}", - } - - if param_info.param_type is bool: - kwargs["action"] = "store_false" if param_info.default_value else "store_true" - else: - kwargs["type"] = param_info.param_type - kwargs["metavar"] = param_info.param_type.__name__.upper() +@dataclass +class TrackerParams: + """Optional tracker-specific parameters. + + Union of parameters across all registered trackers; each tracker only + receives the keys it knows about. Fields left as ``None`` are dropped + before instantiation so the tracker's own defaults apply. + + Attributes: + lost_track_buffer: Frames to keep a lost track before discarding. + frame_rate: Source frame rate for time-based logic. + track_activation_threshold: Detection score needed to spawn a track. + minimum_consecutive_frames: Consecutive matches to confirm a track. + minimum_iou_threshold: IoU threshold for SORT/OC-SORT association. + minimum_iou_threshold_first_assoc: BoT-SORT first-stage IoU. + minimum_iou_threshold_second_assoc: BoT-SORT second-stage IoU. + minimum_iou_threshold_unconfirmed_assoc: BoT-SORT unconfirmed IoU. + high_conf_det_threshold: High-confidence detection threshold. + direction_consistency_weight: OC-SORT direction consistency weight. + delta_t: OC-SORT velocity delta horizon. + enable_cmc: BoT-SORT camera motion compensation toggle. + cmc_method: BoT-SORT CMC method name. + cmc_downscale: BoT-SORT CMC downscale factor. + instant_first_frame_activation: BoT-SORT first-frame activation toggle. + """ - with suppress(Exception): - group.add_argument(arg_name, **kwargs) + lost_track_buffer: int | None = None + frame_rate: float | None = None + track_activation_threshold: float | None = None + minimum_consecutive_frames: int | None = None + minimum_iou_threshold: float | None = None + minimum_iou_threshold_first_assoc: float | None = None + minimum_iou_threshold_second_assoc: float | None = None + minimum_iou_threshold_unconfirmed_assoc: float | None = None + high_conf_det_threshold: float | None = None + direction_consistency_weight: float | None = None + delta_t: int | None = None + enable_cmc: bool | None = None + cmc_method: str | None = None + cmc_downscale: int | None = None + instant_first_frame_activation: bool | None = None + + +def track( + source: str | None = None, + model: str = DEFAULT_MODEL, + detections: Path | None = None, + confidence: float = DEFAULT_CONFIDENCE, + device: str = DEFAULT_DEVICE, + api_key: str | None = None, + classes: str | None = None, + track_ids: str | None = None, + tracker: str = DEFAULT_TRACKER, + tracker_params: TrackerParams | None = None, + output: Path | None = None, + mot_output: Path | None = None, + overwrite: bool = False, + display: bool = False, + show_boxes: bool = True, + show_masks: bool = False, + show_labels: bool = False, + show_ids: bool = True, + show_confidence: bool = False, + show_trajectories: bool = False, +) -> int: + """Run detection and tracking over a video, webcam, RTSP, or image directory. + Args: + source: Video file, webcam index (e.g. ``"0"``), RTSP URL, or image + directory. Required unless ``detections`` is supplied. + model: Detection model ID (e.g. ``rfdetr-nano``) or + ``workspace/project/version`` for a Roboflow custom model. + detections: Path to a pre-computed MOT-format detections file. Mutually + exclusive with ``model``. + confidence: Detection confidence threshold. + device: Inference device: ``auto``, ``cpu``, ``cuda``, ``cuda:0``, + ``mps``. + api_key: Roboflow API key for custom models. + classes: Comma-separated class names or IDs to keep + (e.g. ``person,car``). + track_ids: Comma-separated track IDs to keep in the output + (e.g. ``1,3,5``). + tracker: Tracking algorithm ID. Discoverable via + ``BaseTracker._registered_trackers()``. + tracker_params: Optional tracker parameters; only fields matching the + chosen tracker's ``__init__`` are forwarded. + output: Output annotated-video path. + mot_output: Output MOT-format predictions path. + overwrite: Overwrite existing output files. + display: Show a preview window during tracking. + show_boxes: Draw bounding boxes. + show_masks: Draw segmentation masks (segmentation models only). + show_labels: Draw class labels. + show_ids: Draw track IDs. + show_confidence: Draw confidence scores. + show_trajectories: Draw track trajectories (trails). -def run_track(args: jsonargparse.Namespace) -> int: - """Execute the track command.""" - needs_frames = args.output or args.display + Returns: + Exit code: ``0`` on success, ``1`` on validation error. + """ + needs_frames = output is not None or display - if args.source is None and not args.detections: - print( - "Error: --source is required when not using --detections.", - file=sys.stderr, - ) + if source is None and detections is None: + print("Error: --source is required when not using --detections.", file=sys.stderr) return 1 - - if needs_frames and args.source is None: - print( - "Error: --source is required when using --output or --display.", - file=sys.stderr, - ) + if needs_frames and source is None: + print("Error: --source is required when using --output or --display.", file=sys.stderr) return 1 - # Validate output paths - if args.output: - _validate_output_path(_resolve_video_output_path(args.output), overwrite=args.overwrite) - if args.mot_output: - _validate_output_path(args.mot_output, overwrite=args.overwrite) + if output: + _validate_output_path(_resolve_video_output_path(output), overwrite=overwrite) + if mot_output: + _validate_output_path(mot_output, overwrite=overwrite) - # Create detection source - if args.detections: - model = None - detections_data = load_mot_file(args.detections) + if detections is not None: + model_obj: AnyModel | None = None + detections_data: dict | None = load_mot_file(detections) class_names: list[str] = [] else: - model = _init_model( - args.model, - device=args.model_device, - api_key=args.model_api_key, - ) + model_obj = _init_model(model, device=device, api_key=api_key) detections_data = None - class_names = getattr(model, "class_names", []) - - # Resolve class filter (names and/or integer IDs) - class_filter = _resolve_class_filter(args.classes, class_names) + class_names = getattr(model_obj, "class_names", []) - track_id_filter = _resolve_track_id_filter(args.track_ids) + class_filter = _resolve_class_filter(classes, class_names) + track_id_filter = _resolve_track_id_filter(track_ids) + tracker_obj = _init_tracker(tracker, tracker_params) - # Create tracker - tracker_params = _extract_tracker_params(args.tracker, args) - tracker = _init_tracker(args.tracker, **tracker_params) - - if args.source is not None: + if source is not None: return _run_with_source( - args, - model, - detections_data, - class_names, - class_filter, - track_id_filter, - tracker, - ) - else: - return _run_frameless( - args, - detections_data, - class_filter, - track_id_filter, - tracker, + source=source, + model=model_obj, + confidence=confidence, + detections_data=detections_data, + class_names=class_names, + class_filter=class_filter, + track_id_filter=track_id_filter, + tracker=tracker_obj, + output=output, + mot_output=mot_output, + display=display, + show_boxes=show_boxes, + show_masks=show_masks, + show_labels=show_labels, + show_ids=show_ids, + show_confidence=show_confidence, + show_trajectories=show_trajectories, ) + return _run_frameless( + detections_data=detections_data, + class_filter=class_filter, + track_id_filter=track_id_filter, + tracker=tracker_obj, + mot_output=mot_output, + ) + def _run_frameless( - args: jsonargparse.Namespace, + *, detections_data: dict | None, class_filter: list[int] | None, track_id_filter: list[int] | None, tracker: BaseTracker, + mot_output: Path | None, ) -> int: - """Run tracking from pre-computed detections without frame source.""" - if detections_data is None or not detections_data: + """Run tracking from pre-computed detections without a frame source.""" + if not detections_data: print("Error: No detections found in file.", file=sys.stderr) return 1 @@ -345,33 +228,27 @@ def _run_frameless( source_info = _SourceInfo(source_type="video", total_frames=total_frames) try: - with ( - _MOTOutput(args.mot_output) as mot, - _TrackingProgress(source_info) as progress, - ): - interrupted = False + with _MOTOutput(mot_output) as mot, _TrackingProgress(source_info) as progress: for frame_idx in range(1, total_frames + 1): if frame_idx in detections_data: - detections = _mot_frame_to_detections(detections_data[frame_idx]) + dets = _mot_frame_to_detections(detections_data[frame_idx]) else: - detections = sv.Detections.empty() + dets = sv.Detections.empty() - if class_filter is not None and len(detections) > 0: - mask = np.isin(detections.class_id, class_filter) - detections = detections[mask] # type: ignore[assignment] + if class_filter is not None and len(dets) > 0 and dets.class_id is not None: + mask = np.isin(dets.class_id, class_filter) + dets = dets[mask] # type: ignore[assignment] - tracked = tracker.update(detections) + tracked = tracker.update(dets) - if track_id_filter is not None and len(tracked) > 0: - if tracked.tracker_id is not None: - mask = np.isin(tracked.tracker_id.astype(int), track_id_filter) - tracked = tracked[mask] + if track_id_filter is not None and len(tracked) > 0 and tracked.tracker_id is not None: + mask = np.isin(tracked.tracker_id.astype(int), track_id_filter) + tracked = tracked[mask] # type: ignore[assignment] mot.write(frame_idx, tracked) progress.update() - progress.complete(interrupted=interrupted) - + progress.complete(interrupted=False) except KeyboardInterrupt: pass @@ -379,102 +256,96 @@ def _run_frameless( def _run_with_source( - args: jsonargparse.Namespace, + *, + source: str, model: AnyModel | None, + confidence: float, detections_data: dict | None, class_names: list[str], class_filter: list[int] | None, track_id_filter: list[int] | None, tracker: BaseTracker, + output: Path | None, + mot_output: Path | None, + display: bool, + show_boxes: bool, + show_masks: bool, + show_labels: bool, + show_ids: bool, + show_confidence: bool, + show_trajectories: bool, ) -> int: """Run tracking with a frame source (video, webcam, images).""" - frame_gen = frames_from_source(args.source) - source_info = _classify_source(args.source) + frame_gen = frames_from_source(source) + source_info = _classify_source(source) - # Setup annotators annotators, label_annotator = _init_annotators( - show_boxes=args.show_boxes, - show_masks=args.show_masks, - show_labels=args.show_labels, - show_ids=args.show_ids, - show_confidence=args.show_confidence, + show_boxes=show_boxes, + show_masks=show_masks, + show_labels=show_labels, + show_ids=show_ids, + show_confidence=show_confidence, ) - trace_annotator = None - if args.show_trajectories: - trace_annotator = sv.TraceAnnotator( - color=COLOR_PALETTE, - color_lookup=sv.ColorLookup.TRACK, - ) - - display_ctx = _DisplayWindow() if args.display else nullcontext() + trace_annotator = ( + sv.TraceAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.TRACK) if show_trajectories else None + ) + display_ctx = _DisplayWindow() if display else nullcontext() try: with ( - _VideoOutput( - args.output, - fps=source_info.fps or _DEFAULT_OUTPUT_FPS, - ) as video, - _MOTOutput(args.mot_output) as mot, - display_ctx as display, + _VideoOutput(output, fps=source_info.fps or _DEFAULT_OUTPUT_FPS) as video, + _MOTOutput(mot_output) as mot, + display_ctx as display_win, _TrackingProgress(source_info) as progress, ): interrupted = False for frame_idx, frame in frame_gen: - # Get detections if model is not None: - detections = _run_model(model, frame, args.model_confidence) + dets = _run_model(model, frame, confidence) elif detections_data is not None and frame_idx in detections_data: - detections = _mot_frame_to_detections(detections_data[frame_idx]) + dets = _mot_frame_to_detections(detections_data[frame_idx]) else: - detections = sv.Detections.empty() + dets = sv.Detections.empty() - # Filter by class - if class_filter is not None and len(detections) > 0: - mask = np.isin(detections.class_id, class_filter) - detections = detections[mask] # type: ignore[assignment] + if class_filter is not None and len(dets) > 0 and dets.class_id is not None: + mask = np.isin(dets.class_id, class_filter) + dets = dets[mask] # type: ignore[assignment] - # Run tracker - tracked = tracker.update(detections, frame) + tracked = tracker.update(dets, frame) - # Filter by track ID - if track_id_filter is not None and len(tracked) > 0: - if tracked.tracker_id is not None: - mask = np.isin(tracked.tracker_id.astype(int), track_id_filter) - tracked = tracked[mask] + if track_id_filter is not None and len(tracked) > 0 and tracked.tracker_id is not None: + mask = np.isin(tracked.tracker_id.astype(int), track_id_filter) + tracked = tracked[mask] # type: ignore[assignment] - # Write MOT output mot.write(frame_idx, tracked) - progress.update() - # Annotate and display/save frame - if args.display or args.output: + if display or output: annotated = frame.copy() if trace_annotator is not None: annotated = trace_annotator.annotate(annotated, tracked) - for annotator in annotators: - annotated = annotator.annotate(annotated, tracked) + for ann in annotators: + annotated = ann.annotate(annotated, tracked) if label_annotator is not None: labeled = tracked[tracked.tracker_id != -1] labels = _format_labels( labeled, class_names, - show_ids=args.show_ids, - show_labels=args.show_labels, - show_confidence=args.show_confidence, + show_ids=show_ids, + show_labels=show_labels, + show_confidence=show_confidence, ) annotated = label_annotator.annotate(annotated, labeled, labels) video.write(annotated) - if display is not None: - display.show(annotated) - if display.quit_requested: + if display_win is not None: + display_win.show(annotated) + if display_win.quit_requested: interrupted = True break progress.complete(interrupted=interrupted) - except KeyboardInterrupt: pass @@ -482,81 +353,68 @@ def _run_with_source( def _resolve_track_id_filter(track_ids_arg: str | None) -> list[int] | None: - """Resolve a comma-separated `--track-ids` value to a list of integer IDs. + """Resolve a comma-separated ``track_ids`` string to a list of integer IDs. Args: - track_ids_arg: Raw `--track-ids` string (e.g. `"1,3,5"`). `None` + track_ids_arg: Raw ``--track_ids`` string (e.g. ``"1,3,5"``). ``None`` means no filter. Returns: - List of integer track IDs, or `None` when no valid filter remains. + List of integer track IDs, or ``None`` when no valid filter remains. """ if not track_ids_arg: return None track_ids: list[int] = [] - for token in track_ids_arg.split(","): - token = token.strip() + for raw in track_ids_arg.split(","): + token = raw.strip() try: track_ids.append(int(token)) except ValueError: - print( - f"Warning: '{token}' is not a valid track ID, skipping.", - file=sys.stderr, - ) - return track_ids if track_ids else None + print(f"Warning: '{token}' is not a valid track ID, skipping.", file=sys.stderr) + return track_ids or None -def _resolve_class_filter( - classes_arg: str | None, - class_names: list[str], -) -> list[int] | None: - """Resolve a comma-separated `--classes` value to a list of integer IDs. +def _resolve_class_filter(classes_arg: str | None, class_names: list[str]) -> list[int] | None: + """Resolve a comma-separated ``classes`` string to a list of integer IDs. - Each token is checked independently: if it parses as an `int` it is used - directly as a class ID; otherwise it is looked up by name in *class_names*. + Each token is checked independently: if it parses as an ``int`` it is used + directly as a class ID; otherwise it is looked up by name in ``class_names``. Unknown names are printed as warnings and skipped. Args: - classes_arg: Raw `--classes` string (e.g. `"person,car"` or - `"0,2"` or `"person,2"`). `None` means no filter. + classes_arg: Raw ``--classes`` string (e.g. ``"person,car"`` or + ``"0,2"`` or ``"person,2"``). ``None`` means no filter. class_names: Ordered list of class names where the index equals the class ID (as provided by the model). Returns: - List of integer class IDs, or `None` when no valid filter remains. + List of integer class IDs, or ``None`` when no valid filter remains. """ if not classes_arg: return None - requested = [token.strip() for token in classes_arg.split(",")] name_to_id = {name: i for i, name in enumerate(class_names)} class_filter: list[int] = [] - for token in requested: + for raw in classes_arg.split(","): + token = raw.strip() try: class_filter.append(int(token)) except ValueError: if token in name_to_id: class_filter.append(name_to_id[token]) else: - print( - f"Warning: class '{token}' not found in model class list, skipping.", - file=sys.stderr, - ) - return class_filter if class_filter else None + print(f"Warning: class '{token}' not found in model class list, skipping.", file=sys.stderr) + return class_filter or None -def _init_model( - model_id: str, - *, - device: str = DEFAULT_DEVICE, - api_key: str | None = None, -) -> AnyModel: - """Load detection model via inference-models. +def _init_model(model_id: str, *, device: str = DEFAULT_DEVICE, api_key: str | None = None) -> AnyModel: + """Load detection model via ``inference-models``. Args: - model_id: Model identifier (e.g., 'rfdetr-nano' or 'workspace/project/version'). - device: Device to load model on ('auto', 'cpu', 'cuda', 'mps'). + model_id: Model identifier (e.g. ``rfdetr-nano`` or + ``workspace/project/version``). + device: Device to load model on (``auto``, ``cpu``, ``cuda``, ``mps``). api_key: Roboflow API key for custom models. Returns: @@ -573,72 +431,45 @@ def _init_model( raise SystemExit(1) from e resolved_device = _best_device() if device == DEFAULT_DEVICE else device - - return AutoModel.from_pretrained( - model_id, - api_key=api_key, - device=resolved_device, - ) + return AutoModel.from_pretrained(model_id, api_key=api_key, device=resolved_device) def _run_model(model: AnyModel, frame: np.ndarray, confidence: float) -> sv.Detections: - """Run model inference and return sv.Detections.""" + """Run model inference, filter by confidence, return ``sv.Detections``.""" predictions = model(frame) if not predictions: return sv.Detections.empty() - detections = predictions[0].to_supervision() - - # Filter by confidence - if len(detections) > 0 and detections.confidence is not None: - mask = detections.confidence >= confidence - detections = detections[mask] - - return detections - - -def _extract_tracker_params(tracker_id: str, args: jsonargparse.Namespace) -> dict[str, object]: - """Extract tracker parameters from CLI args. - - Args: - tracker_id: Registered tracker name. - args: Parsed CLI arguments. - - Returns: - Dictionary of tracker parameters with non-None values. - """ - info = BaseTracker._lookup_tracker(tracker_id) - if info is None: - return {} + dets = predictions[0].to_supervision() + if len(dets) > 0 and dets.confidence is not None: + dets = dets[dets.confidence >= confidence] + return dets - params = {} - for param_name in info.parameters: - dest_name = f"tracker_{param_name}" - if hasattr(args, dest_name): - value = getattr(args, dest_name) - if value is not None: - params[param_name] = value - return params +def _init_tracker(tracker_id: str, params: TrackerParams | None) -> BaseTracker: + """Create a tracker instance from the registry. -def _init_tracker(tracker_id: str, **kwargs: object) -> BaseTracker: - """Create tracker instance from registry. + Only fields the chosen tracker accepts are forwarded; ``None`` values are + always dropped so the tracker's own defaults apply. Args: - tracker_id: Registered tracker name (e.g., 'bytetrack', 'sort'). - **kwargs: Tracker-specific parameters. + tracker_id: Registered tracker name (e.g. ``bytetrack``, ``sort``). + params: Optional tracker parameter overrides. Returns: - Initialized tracker instance. + Initialised tracker instance. Raises: - ValueError: If tracker_id is not registered. + ValueError: If ``tracker_id`` is not registered. """ info = BaseTracker._lookup_tracker(tracker_id) if info is None: available = ", ".join(BaseTracker._registered_trackers()) raise ValueError(f"Unknown tracker: '{tracker_id}'. Available: {available}") + raw = asdict(params) if params is not None else {} + accepted = set(info.parameters) + kwargs = {k: v for k, v in raw.items() if v is not None and k in accepted} return info.tracker_class(**kwargs) @@ -649,38 +480,26 @@ def _init_annotators( show_ids: bool = False, show_confidence: bool = False, ) -> tuple[list, sv.LabelAnnotator | None]: - """Initialize supervision annotators based on display options. + """Initialise supervision annotators based on display options. Args: - show_boxes: Create BoxAnnotator. - show_masks: Create MaskAnnotator. - show_labels: Include class labels (triggers LabelAnnotator). - show_ids: Include track IDs (triggers LabelAnnotator). - show_confidence: Include confidence scores (triggers LabelAnnotator). + show_boxes: Create ``BoxAnnotator``. + show_masks: Create ``MaskAnnotator``. + show_labels: Include class labels (triggers ``LabelAnnotator``). + show_ids: Include track IDs (triggers ``LabelAnnotator``). + show_confidence: Include confidence scores (triggers ``LabelAnnotator``). Returns: - Tuple of (annotators list, label_annotator or None). - Label annotator is separate because it needs custom labels per frame. + Tuple of (annotators list, label_annotator or None). Label annotator is + separate because it needs custom labels per frame. """ annotators: list = [] label_annotator: sv.LabelAnnotator | None = None if show_boxes: - annotators.append( - sv.BoxAnnotator( - color=COLOR_PALETTE, - color_lookup=sv.ColorLookup.TRACK, - ) - ) - + annotators.append(sv.BoxAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.TRACK)) if show_masks: - annotators.append( - sv.MaskAnnotator( - color=COLOR_PALETTE, - color_lookup=sv.ColorLookup.TRACK, - ) - ) - + annotators.append(sv.MaskAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.TRACK)) if show_labels or show_ids or show_confidence: label_annotator = sv.LabelAnnotator( color=COLOR_PALETTE, @@ -688,7 +507,6 @@ def _init_annotators( text_position=sv.Position.TOP_LEFT, color_lookup=sv.ColorLookup.TRACK, ) - return annotators, label_annotator @@ -713,23 +531,17 @@ def _format_labels( List of label strings, one per detection. """ labels = [] - for i in range(len(detections)): - parts = [] - + parts: list[str] = [] if show_ids and detections.tracker_id is not None: parts.append(f"#{int(detections.tracker_id[i])}") - if show_labels and detections.class_id is not None: class_id = int(detections.class_id[i]) if class_names and 0 <= class_id < len(class_names): parts.append(class_names[class_id]) else: parts.append(str(class_id)) - if show_confidence and detections.confidence is not None: parts.append(f"{detections.confidence[i]:.2f}") - labels.append(" ".join(parts)) - return labels diff --git a/src/trackers/cli/tune.py b/src/trackers/cli/tune.py index a8471871..bee54650 100644 --- a/src/trackers/cli/tune.py +++ b/src/trackers/cli/tune.py @@ -5,107 +5,13 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ +"""``trackers tune`` subcommand — Optuna hyperparameter optimisation.""" + from __future__ import annotations import json import sys from pathlib import Path -from typing import Any - -import jsonargparse - - -def add_tune_subparser(subparsers: Any) -> None: - """Add the tune subcommand to the argument parser.""" - parser = subparsers.add_parser( - "tune", - help="Tune tracker hyperparameters via Optuna.", - description=( - "Run Optuna-based hyperparameter optimisation for a registered " - "tracker using pre-computed detections and ground-truth MOT files." - ), - formatter_class=jsonargparse.DefaultHelpFormatter, - ) - - parser.add_argument( - "--tracker", - required=True, - metavar="ID", - help="Tracker ID to tune (e.g. bytetrack, sort, ocsort).", - ) - parser.add_argument( - "--gt-dir", - type=Path, - required=True, - metavar="DIR", - help="Directory containing ground-truth MOT files.", - ) - parser.add_argument( - "--detections-dir", - type=Path, - required=True, - metavar="DIR", - help=("Directory containing pre-computed detection files in MOT flat format (one {seq}.txt per sequence)."), - ) - parser.add_argument( - "--objective", - default="HOTA", - choices=["MOTA", "HOTA", "IDF1"], - help="Scalar metric to maximise. Default: HOTA.", - ) - parser.add_argument( - "--n-trials", - type=int, - default=100, - metavar="N", - help="Number of Optuna trials to run. Default: 100.", - ) - parser.add_argument( - "--metrics", - nargs="+", - default=["CLEAR"], - choices=["CLEAR", "HOTA", "Identity"], - help=( - "Metric families to compute. Default: CLEAR. The family required " - "by --objective is added automatically if missing." - ), - ) - parser.add_argument( - "--threshold", - type=float, - default=0.5, - help="IoU threshold for CLEAR and Identity matching. Default: 0.5.", - ) - parser.add_argument( - "--seqmap", - type=Path, - metavar="PATH", - help="Sequence map file listing sequences to evaluate.", - ) - parser.add_argument( - "--output", - "-o", - type=Path, - metavar="PATH", - help="Output file for best parameters (JSON format).", - ) - - parser.set_defaults(func=run_tune) - - -def run_tune(args: jsonargparse.Namespace) -> int: - """Execute the tune command.""" - return tune( - tracker=args.tracker, - gt_dir=args.gt_dir, - detections_dir=args.detections_dir, - objective=args.objective, - n_trials=args.n_trials, - metrics=args.metrics, - threshold=args.threshold, - seqmap=args.seqmap, - output=args.output, - ) def tune( @@ -122,20 +28,22 @@ def tune( """Tune tracker hyperparameters using Optuna. Args: - tracker: Tracker ID to tune (e.g. bytetrack, sort). + tracker: Tracker ID to tune (e.g. ``bytetrack``, ``sort``, ``ocsort``). gt_dir: Directory of ground-truth MOT files. detections_dir: Directory of pre-computed detection files in MOT flat - format (one {seq}.txt per sequence). - objective: Scalar metric to maximise. Options: MOTA, HOTA, IDF1. + format (one ``{seq}.txt`` per sequence). + objective: Scalar metric to maximise. Options: ``MOTA``, ``HOTA``, + ``IDF1``. n_trials: Number of Optuna trials to run. - metrics: Metric families to compute. Options: CLEAR, HOTA, Identity. - Default: CLEAR. + metrics: Metric families to compute. Options: ``CLEAR``, ``HOTA``, + ``Identity``. Default: ``["CLEAR"]``. The family required by + ``objective`` is added automatically if missing. threshold: IoU threshold for CLEAR and Identity matching. seqmap: Sequence map file listing sequences to evaluate. - output: Output file path for best parameters (JSON format). + output: Output JSON file for best parameters. Returns: - Exit code: 0 on success, 1 on error. + Exit code: ``0`` on success, ``1`` on error. """ if metrics is None: metrics = ["CLEAR"] diff --git a/tests/cli/test_download.py b/tests/cli/test_download.py index 7f19ff0f..852db7a5 100644 --- a/tests/cli/test_download.py +++ b/tests/cli/test_download.py @@ -4,93 +4,38 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ +"""CLI-level tests for trackers/cli/download.py.""" + from __future__ import annotations -import argparse from unittest.mock import patch import pytest -from trackers.cli.download import ( - _print_available, - _run_download, - add_download_subparser, -) +from trackers.cli.download import _print_available, download from trackers.datasets.download import _DEFAULT_CACHE_DIR, _DEFAULT_OUTPUT_DIR -def _parse_args(argv: list[str]) -> argparse.Namespace: - """Parse argv through a fresh download subparser and return the namespace.""" - parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers() - add_download_subparser(subparsers) - return parser.parse_args(argv) - - -class TestSubparserRegistration: - """Argument parsing and help strings.""" - - def test_list_flag(self) -> None: - """--list sets the flag to True.""" - args = _parse_args(["download", "--list"]) - assert args.list is True - - def test_list_flag_default_false(self) -> None: - """--list is False when omitted.""" - args = _parse_args(["download", "mot17"]) - assert args.list is False - - def test_split_flag_accepts_comma_separated(self) -> None: - """--split accepts comma-separated values.""" - args = _parse_args(["download", "mot17", "--split", "train,val"]) - assert args.split == "train,val" - - def test_asset_flag_accepts_comma_separated(self) -> None: - """--asset accepts comma-separated values.""" - args = _parse_args(["download", "mot17", "--asset", "frames,annotations"]) - assert args.asset == "frames,annotations" - - def test_output_directory_short_flag(self) -> None: - """-o sets the output directory.""" - args = _parse_args(["download", "mot17", "-o", "./datasets"]) - assert args.output == "./datasets" - - def test_cache_dir_flag(self) -> None: - """--cache-dir sets the cache directory.""" - args = _parse_args(["download", "mot17", "--cache-dir", "./cache"]) - assert args.cache_dir == "./cache" - - def test_dataset_positional(self) -> None: - """Dataset is captured as positional argument.""" - args = _parse_args(["download", "sportsmot"]) - assert args.dataset == "sportsmot" - - -class TestRunDownload: +class TestDownload: """Execution of the download subcommand.""" def test_list_triggers_print(self) -> None: - """--list calls _print_available and returns 0.""" - args = _parse_args(["download", "--list"]) - + """list_available=True calls _print_available and returns 0.""" with patch("trackers.cli.download._print_available") as mock_print: - rc = _run_download(args) + rc = download(list_available=True) assert rc == 0 mock_print.assert_called_once() def test_list_takes_precedence_over_dataset(self) -> None: - """--list wins over dataset positional.""" - args = _parse_args(["download", "mot17", "--list"]) - + """list_available=True wins over dataset argument.""" with patch("trackers.cli.download._print_available") as mock_print: - rc = _run_download(args) + rc = download(dataset="mot17", list_available=True) assert rc == 0 mock_print.assert_called_once() def test_missing_dataset_exits_with_error(self, capsys: pytest.CaptureFixture[str]) -> None: - """No dataset and no --list prints error to stderr and returns 1.""" - args = _parse_args(["download"]) - rc = _run_download(args) + """No dataset and no list_available prints error to stderr and returns 1.""" + rc = download() captured = capsys.readouterr() assert rc == 1 assert "Please specify a dataset" in captured.err @@ -104,11 +49,9 @@ def test_missing_dataset_exits_with_error(self, capsys: pytest.CaptureFixture[st ], ) def test_split_comma_parsing(self, split_arg: str, expected_splits: list[str]) -> None: - """--split values are split on commas and whitespace-stripped.""" - args = _parse_args(["download", "mot17", "--split", split_arg, "--asset", "annotations"]) - + """split values are split on commas and whitespace-stripped.""" with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) + rc = download(dataset="mot17", split=split_arg, asset="annotations") assert rc == 0 mock_dl.assert_called_once_with( dataset="mot17", @@ -127,11 +70,9 @@ def test_split_comma_parsing(self, split_arg: str, expected_splits: list[str]) - ], ) def test_split_comma_parsing_boundary(self, split_arg: str, expected_splits: list[str]) -> None: - """--split handles malformed comma inputs gracefully.""" - args = _parse_args(["download", "mot17", "--split", split_arg, "--asset", "annotations"]) - + """split handles malformed comma inputs gracefully.""" with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) + rc = download(dataset="mot17", split=split_arg, asset="annotations") assert rc == 0 mock_dl.assert_called_once_with( dataset="mot17", @@ -150,11 +91,9 @@ def test_split_comma_parsing_boundary(self, split_arg: str, expected_splits: lis ], ) def test_asset_comma_parsing(self, asset_arg: str, expected_assets: list[str]) -> None: - """--asset values are split on commas and whitespace-stripped.""" - args = _parse_args(["download", "sportsmot", "--split", "train", "--asset", asset_arg]) - + """asset values are split on commas and whitespace-stripped.""" with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) + rc = download(dataset="sportsmot", split="train", asset=asset_arg) assert rc == 0 mock_dl.assert_called_once_with( dataset="sportsmot", @@ -165,11 +104,9 @@ def test_asset_comma_parsing(self, asset_arg: str, expected_assets: list[str]) - ) def test_none_splits_and_assets_when_omitted(self) -> None: - """When --split and --asset are omitted, None is forwarded.""" - args = _parse_args(["download", "mot17"]) - + """When split and asset are omitted, None is forwarded.""" with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) + rc = download(dataset="mot17") assert rc == 0 mock_dl.assert_called_once_with( dataset="mot17", @@ -180,11 +117,9 @@ def test_none_splits_and_assets_when_omitted(self) -> None: ) def test_output_directory_forwarded(self) -> None: - """-o value is forwarded to download_dataset.""" - args = _parse_args(["download", "mot17", "-o", "/custom/path"]) - + """output value is forwarded to download_dataset.""" with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) + rc = download(dataset="mot17", output="/custom/path") assert rc == 0 mock_dl.assert_called_once_with( dataset="mot17", @@ -196,21 +131,17 @@ def test_output_directory_forwarded(self) -> None: def test_value_error_returns_exit_code(self) -> None: """ValueError from download_dataset is caught and returns 1.""" - args = _parse_args(["download", "mot17"]) - with patch( "trackers.datasets.download.download_dataset", side_effect=ValueError("bad dataset"), ): - rc = _run_download(args) + rc = download(dataset="mot17") assert rc == 1 def test_split_with_spaces_stripped(self) -> None: - """--split with spaces around commas strips whitespace.""" - args = _parse_args(["download", "mot17", "--split", "train , val", "--asset", "annotations"]) - + """split with spaces around commas strips whitespace.""" with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) + rc = download(dataset="mot17", split="train , val", asset="annotations") assert rc == 0 mock_dl.assert_called_once_with( dataset="mot17", @@ -222,7 +153,7 @@ def test_split_with_spaces_stripped(self) -> None: class TestPrintAvailable: - """Output of --list.""" + """Output of list_available.""" def test_prints_without_error(self, capsys: pytest.CaptureFixture[str]) -> None: """_print_available runs without raising and does not leak output.""" diff --git a/tests/cli/test_tune.py b/tests/cli/test_tune.py index 1d6733b5..5ae943dc 100644 --- a/tests/cli/test_tune.py +++ b/tests/cli/test_tune.py @@ -8,93 +8,13 @@ from __future__ import annotations -import argparse import json from pathlib import Path from unittest.mock import MagicMock, patch import pytest -from trackers.cli.tune import add_tune_subparser, run_tune, tune - - -def _make_parser() -> tuple[argparse.ArgumentParser, argparse._SubParsersAction]: - """Return a top-level parser with a subparsers group.""" - parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers() - return parser, subparsers - - -class TestAddTuneSubparser: - @pytest.fixture - def minimal_args(self) -> argparse.Namespace: - """Parsed args with only required flags.""" - parser, subparsers = _make_parser() - add_tune_subparser(subparsers) - return parser.parse_args(["tune", "--tracker", "sort", "--gt-dir", "/gt", "--detections-dir", "/det"]) - - def test_registers_tune_subcommand(self) -> None: - """tune subcommand is accessible under the 'tune' name.""" - parser, subparsers = _make_parser() - add_tune_subparser(subparsers) - args = parser.parse_args(["tune", "--tracker", "sort", "--gt-dir", "/gt", "--detections-dir", "/det"]) - assert args.func is run_tune - - def test_required_args_parsed(self) -> None: - """--tracker, --gt-dir, and --detections-dir are required and parsed.""" - parser, subparsers = _make_parser() - add_tune_subparser(subparsers) - args = parser.parse_args( - [ - "tune", - "--tracker", - "bytetrack", - "--gt-dir", - "/data/gt", - "--detections-dir", - "/data/det", - ] - ) - assert args.tracker == "bytetrack" - assert args.gt_dir == Path("/data/gt") - assert args.detections_dir == Path("/data/det") - - @pytest.mark.parametrize( - "flag,expected", - [ - ("objective", "HOTA"), - ("n_trials", 100), - ("threshold", 0.5), - ("seqmap", None), - ("output", None), - ], - ) - def test_optional_defaults(self, minimal_args: argparse.Namespace, flag: str, expected: object) -> None: - """Optional arguments have correct defaults when omitted.""" - assert getattr(minimal_args, flag) == expected - - def test_metrics_default(self, minimal_args: argparse.Namespace) -> None: - """--metrics defaults to ['CLEAR'] when not supplied.""" - assert minimal_args.metrics == ["CLEAR"] - - def test_output_flag_short_form(self) -> None: - """-o is an alias for --output.""" - parser, subparsers = _make_parser() - add_tune_subparser(subparsers) - args = parser.parse_args( - [ - "tune", - "--tracker", - "sort", - "--gt-dir", - "/gt", - "--detections-dir", - "/det", - "-o", - "/out/params.json", - ] - ) - assert args.output == Path("/out/params.json") +from trackers.cli.tune import tune class TestTune: @@ -180,34 +100,110 @@ def test_returns_1_on_tuner_run_exception(self, tmp_path: Path) -> None: assert result == 1 -class TestRunTune: - def test_delegates_to_tune_with_namespace_args(self, tmp_path: Path) -> None: - """run_tune() passes all argparse.Namespace fields to tune() correctly.""" +class TestCliInvocation: + """tune() is wired into the jsonargparse CLI with the expected args.""" + + @staticmethod + def _invoke(args: list[str], spy: list[dict]) -> object: + """Run jsonargparse.CLI() with a recording spy for `tune`. + + The spy mirrors the real signature so jsonargparse can introspect it. + """ + from jsonargparse import CLI + + from trackers.cli.tune import tune as real_tune + + def spy_tune( + tracker: str, + gt_dir: Path, + detections_dir: Path, + objective: str = "HOTA", + n_trials: int = 100, + metrics: list[str] | None = None, + threshold: float = 0.5, + seqmap: Path | None = None, + output: Path | None = None, + ) -> int: + spy.append( + dict( + tracker=tracker, + gt_dir=gt_dir, + detections_dir=detections_dir, + objective=objective, + n_trials=n_trials, + metrics=metrics, + threshold=threshold, + seqmap=seqmap, + output=output, + ) + ) + return 0 + + # Copy the docstring so jsonargparse's introspection matches the real function. + spy_tune.__doc__ = real_tune.__doc__ + return CLI({"tune": spy_tune}, as_positional=False, args=args) + + def test_cli_dispatch_to_tune(self, tmp_path: Path) -> None: + """jsonargparse.CLI() parses the tune subcommand and forwards args.""" gt_dir = tmp_path / "gt" det_dir = tmp_path / "det" - output_path = tmp_path / "params.json" - args = argparse.Namespace( - tracker="sort", - gt_dir=gt_dir, - detections_dir=det_dir, - objective="MOTA", - n_trials=50, - metrics=["CLEAR", "HOTA"], - threshold=0.3, - seqmap=None, - output=output_path, + spy: list[dict] = [] + result = self._invoke( + [ + "tune", + "--tracker", + "sort", + "--gt_dir", + str(gt_dir), + "--detections_dir", + str(det_dir), + "--objective", + "MOTA", + "--n_trials", + "50", + ], + spy, ) - with patch("trackers.cli.tune.tune", return_value=0) as mock_tune: - result = run_tune(args) assert result == 0 - mock_tune.assert_called_once_with( - tracker="sort", - gt_dir=gt_dir, - detections_dir=det_dir, - objective="MOTA", - n_trials=50, - metrics=["CLEAR", "HOTA"], - threshold=0.3, - seqmap=None, - output=output_path, + assert len(spy) == 1 + assert spy[0]["tracker"] == "sort" + assert spy[0]["gt_dir"] == gt_dir + assert spy[0]["detections_dir"] == det_dir + assert spy[0]["objective"] == "MOTA" + assert spy[0]["n_trials"] == 50 + + @pytest.mark.parametrize( + "flag,arg_value,attr,expected", + [ + ("--objective", "HOTA", "objective", "HOTA"), + ("--n_trials", "100", "n_trials", 100), + ("--threshold", "0.5", "threshold", 0.5), + ], + ) + def test_cli_defaults( + self, + tmp_path: Path, + flag: str, + arg_value: str, + attr: str, + expected: object, + ) -> None: + """Optional flags carry their declared defaults when invoked via CLI.""" + gt_dir = tmp_path / "gt" + det_dir = tmp_path / "det" + spy: list[dict] = [] + self._invoke( + [ + "tune", + "--tracker", + "sort", + "--gt_dir", + str(gt_dir), + "--detections_dir", + str(det_dir), + flag, + arg_value, + ], + spy, ) + assert spy[0][attr] == expected From 2939d0e1515d66ef28286703ef06a6f67f1a3086 Mon Sep 17 00:00:00 2001 From: jirka <6035284+Borda@users.noreply.github.com> Date: Wed, 20 May 2026 20:44:44 +0200 Subject: [PATCH 03/13] feat(cli): group track args into dataclasses for --help grouping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - add DetectionOptions, FilteringOptions, OutputOptions, VisualizationOptions, ShowOptions dataclasses; each renders as a named group in --help via jsonargparse[signatures] - track() signature collapses 19 flat args into 6 typed group params (detection, filters, out, vis, show) + flat source and tracker; body unpacks to local vars so internals unchanged - switch jsonargparse>=4.48.0 → jsonargparse[signatures]>=4.48.0 to pull in docstring-parser (group titles and per-field help) - flag renames: --model → --detection.model, --confidence → --detection.confidence, --output → --out.output, --show_boxes → --show.boxes, etc. --- Co-authored-by: Claude Code --- pyproject.toml | 2 +- src/trackers/cli/track.py | 158 ++++++++++++++++++++++++++++---------- uv.lock | 41 +++++++++- 3 files changed, 156 insertions(+), 45 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 514168c3..3f428944 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dependencies = [ "opencv-python>=4.8.0", "rich>=13.0.0", "requests>=2.28.0", - "jsonargparse>=4.48.0", + "jsonargparse[signatures]>=4.48.0", ] [project.optional-dependencies] diff --git a/src/trackers/cli/track.py b/src/trackers/cli/track.py index e5c6d394..a153b4ca 100644 --- a/src/trackers/cli/track.py +++ b/src/trackers/cli/track.py @@ -54,6 +54,91 @@ ) +@dataclass +class DetectionOptions: + """Detection model and inference settings. + + Attributes: + model: Model ID (e.g. ``rfdetr-nano``) or + ``workspace/project/version`` for a Roboflow custom model. + Ignored when ``detections`` is set. + detections: Path to a pre-computed MOT-format detections file. + Mutually exclusive with ``model``; supply one or the other. + confidence: Detection confidence threshold. + device: Inference device: ``auto``, ``cpu``, ``cuda``, ``cuda:0``, + ``mps``. + api_key: Roboflow API key (required for private custom models). + """ + + model: str = DEFAULT_MODEL + detections: Path | None = None + confidence: float = DEFAULT_CONFIDENCE + device: str = DEFAULT_DEVICE + api_key: str | None = None + + +@dataclass +class FilteringOptions: + """Detection and track filters. + + Attributes: + classes: Comma-separated class names or IDs to keep + (e.g. ``person,car`` or ``0,2``). + track_ids: Comma-separated track IDs to keep in the output + (e.g. ``1,3,5``). + """ + + classes: str | None = None + track_ids: str | None = None + + +@dataclass +class OutputOptions: + """Output paths and write options. + + Attributes: + output: Annotated-video output path. + mot_output: MOT-format predictions output path. + overwrite: Overwrite existing output files without prompting. + """ + + output: Path | None = None + mot_output: Path | None = None + overwrite: bool = False + + +@dataclass +class VisualizationOptions: + """Live preview and display settings. + + Attributes: + display: Show a live preview window during tracking. + """ + + display: bool = False + + +@dataclass +class ShowOptions: + """Annotation elements to draw on each frame. + + Attributes: + boxes: Draw bounding boxes around detections. + masks: Draw segmentation masks (segmentation models only). + labels: Draw class labels. + ids: Draw track IDs. + confidence: Draw detection confidence scores. + trajectories: Draw track trajectory trails. + """ + + boxes: bool = True + masks: bool = False + labels: bool = False + ids: bool = True + confidence: bool = False + trajectories: bool = False + + @dataclass class TrackerParams: """Optional tracker-specific parameters. @@ -99,61 +184,50 @@ class TrackerParams: def track( source: str | None = None, - model: str = DEFAULT_MODEL, - detections: Path | None = None, - confidence: float = DEFAULT_CONFIDENCE, - device: str = DEFAULT_DEVICE, - api_key: str | None = None, - classes: str | None = None, - track_ids: str | None = None, + detection: DetectionOptions = DetectionOptions(), + filters: FilteringOptions = FilteringOptions(), tracker: str = DEFAULT_TRACKER, tracker_params: TrackerParams | None = None, - output: Path | None = None, - mot_output: Path | None = None, - overwrite: bool = False, - display: bool = False, - show_boxes: bool = True, - show_masks: bool = False, - show_labels: bool = False, - show_ids: bool = True, - show_confidence: bool = False, - show_trajectories: bool = False, + out: OutputOptions = OutputOptions(), + vis: VisualizationOptions = VisualizationOptions(), + show: ShowOptions = ShowOptions(), ) -> int: """Run detection and tracking over a video, webcam, RTSP, or image directory. Args: source: Video file, webcam index (e.g. ``"0"``), RTSP URL, or image - directory. Required unless ``detections`` is supplied. - model: Detection model ID (e.g. ``rfdetr-nano``) or - ``workspace/project/version`` for a Roboflow custom model. - detections: Path to a pre-computed MOT-format detections file. Mutually - exclusive with ``model``. - confidence: Detection confidence threshold. - device: Inference device: ``auto``, ``cpu``, ``cuda``, ``cuda:0``, - ``mps``. - api_key: Roboflow API key for custom models. - classes: Comma-separated class names or IDs to keep - (e.g. ``person,car``). - track_ids: Comma-separated track IDs to keep in the output - (e.g. ``1,3,5``). + directory. Required unless ``detection.detections`` is supplied. + detection: Detection model and inference options. + filters: Class and track-ID filters applied to detections and tracks. tracker: Tracking algorithm ID. Discoverable via ``BaseTracker._registered_trackers()``. - tracker_params: Optional tracker parameters; only fields matching the - chosen tracker's ``__init__`` are forwarded. - output: Output annotated-video path. - mot_output: Output MOT-format predictions path. - overwrite: Overwrite existing output files. - display: Show a preview window during tracking. - show_boxes: Draw bounding boxes. - show_masks: Draw segmentation masks (segmentation models only). - show_labels: Draw class labels. - show_ids: Draw track IDs. - show_confidence: Draw confidence scores. - show_trajectories: Draw track trajectories (trails). + tracker_params: Optional tracker parameter overrides; only fields + matching the chosen tracker's ``__init__`` are forwarded. + out: Output path and overwrite options. + vis: Live preview and display options. + show: Annotation elements to draw on each frame. Returns: Exit code: ``0`` on success, ``1`` on validation error. """ + model = detection.model + detections = detection.detections + confidence = detection.confidence + device = detection.device + api_key = detection.api_key + classes = filters.classes + track_ids = filters.track_ids + output = out.output + mot_output = out.mot_output + overwrite = out.overwrite + display = vis.display + show_boxes = show.boxes + show_masks = show.masks + show_labels = show.labels + show_ids = show.ids + show_confidence = show.confidence + show_trajectories = show.trajectories + needs_frames = output is not None or display if source is None and detections is None: diff --git a/uv.lock b/uv.lock index 7fbdc6c1..f451cf00 100644 --- a/uv.lock +++ b/uv.lock @@ -604,6 +604,15 @@ version = "0.6.2" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/a2/55/8f8cab2afd404cf578136ef2cc5dfb50baa1761b68c9da1fb1e4eed343c9/docopt-0.6.2.tar.gz", hash = "sha256:49b3a825280bd66b3aa83585ef59c4a8c82f2c8a522dbe754a8bc8d08c85c491", size = 25901, upload-time = "2014-06-16T11:18:57.406Z" } +[[package]] +name = "docstring-parser" +version = "0.18.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/4d/f332313098c1de1b2d2ff91cf2674415cc7cddab2ca1b01ae29774bd5fdf/docstring_parser-0.18.0.tar.gz", hash = "sha256:292510982205c12b1248696f44959db3cdd1740237a968ea1e2e7a900eeb2015", size = 29341, upload-time = "2026-04-14T04:09:19.867Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/5f/ed01f9a3cdffbd5a008556fc7b2a08ddb1cc6ace7effa7340604b1d16699/docstring_parser-0.18.0-py3-none-any.whl", hash = "sha256:b3fcbed555c47d8479be0796ef7e19c2670d428d72e96da63f3a40122860374b", size = 22484, upload-time = "2026-04-14T04:09:18.638Z" }, +] + [[package]] name = "docutils" version = "0.21.2" @@ -1058,6 +1067,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/20/b0/36bd937216ec521246249be3bf9855081de4c5e06a0c9b4219dbeda50373/importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd", size = 27656, upload-time = "2025-04-27T15:29:00.214Z" }, ] +[[package]] +name = "importlib-resources" +version = "7.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e4/06/b56dfa750b44e86157093bc8fca0ab81dccbf5260510de4eaf1cb69b5b99/importlib_resources-7.1.0.tar.gz", hash = "sha256:0722d4c6212489c530f2a145a34c0a7a3b4721bc96a15fada5930e2a0b760708", size = 44985, upload-time = "2026-04-12T16:36:09.232Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/db/55a262f3606bebcae07cc14095338471ad7c0bbcaa37707e6f0ee49725b7/importlib_resources-7.1.0-py3-none-any.whl", hash = "sha256:1bd7b48b4088eddb2cd16382150bb515af0bd2c70128194392725f82ad2c96a1", size = 37232, upload-time = "2026-04-12T16:36:08.219Z" }, +] + [[package]] name = "inference-models" version = "0.27.2" @@ -1195,6 +1213,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5d/e9/c922101c1e80455d4b44b766b353dafc990da350228fc2515790e5949dd5/jsonargparse-4.48.0-py3-none-any.whl", hash = "sha256:c6a92fd71eb256437371750bb11f436b9c3294da2535f1b0406346816f04be16", size = 131277, upload-time = "2026-04-10T06:52:37.394Z" }, ] +[package.optional-dependencies] +signatures = [ + { name = "docstring-parser" }, + { name = "typeshed-client" }, +] + [[package]] name = "keyring" version = "25.6.0" @@ -4038,7 +4062,7 @@ name = "trackers" version = "2.4.0" source = { editable = "." } dependencies = [ - { name = "jsonargparse" }, + { name = "jsonargparse", extra = ["signatures"] }, { name = "numpy" }, { name = "opencv-python" }, { name = "requests" }, @@ -4083,7 +4107,7 @@ mypy-types = [ [package.metadata] requires-dist = [ { name = "inference-models", marker = "extra == 'detection'", specifier = ">=0.19.0" }, - { name = "jsonargparse", specifier = ">=4.48.0" }, + { name = "jsonargparse", extras = ["signatures"], specifier = ">=4.48.0" }, { name = "numpy", specifier = ">=2.0.2" }, { name = "opencv-python", specifier = ">=4.8.0" }, { name = "optuna", marker = "extra == 'tune'", specifier = ">=3.0.0" }, @@ -4209,6 +4233,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/96/080db0afdf2c5cc5fe512b41354e8d114fe8f65e9510c56ff8dfd40216ce/types_requests-2.33.0.20260508-py3-none-any.whl", hash = "sha256:fa01459cca184229713df03709db46a905325906d27e042cd4fd7ea3d15d3400", size = 20722, upload-time = "2026-05-08T04:50:55.548Z" }, ] +[[package]] +name = "typeshed-client" +version = "2.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "importlib-resources" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/85/7d/62fbae352d5fb7ce5ef4d9ca73bf7a9b02b790d2524ab6ef1e0e799a5d1b/typeshed_client-2.11.0.tar.gz", hash = "sha256:0b8f2ab88f611f5e97b70d2a8123942d3d7d5c74cee8ae694db83422f32f9481", size = 522774, upload-time = "2026-05-01T14:51:52.38Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/22/fa16b462157bd869dfad528f5637506b9430ca63d48fb536ecf4cc78481a/typeshed_client-2.11.0-py3-none-any.whl", hash = "sha256:5745e0990b80b29a286b22d68f81779c5c7adf1cac8969eeafba44b73b486c36", size = 787609, upload-time = "2026-05-01T14:51:51.005Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0" From 34fae36b0b3f8fe6cc0a174ca718e0dd9d4c4178 Mon Sep 17 00:00:00 2001 From: jirka <6035284+Borda@users.noreply.github.com> Date: Wed, 20 May 2026 20:55:09 +0200 Subject: [PATCH 04/13] refine(cli): render bool args as --flag/--no-flag pairs - add _BoolFlagParser subclass that swaps type=bool args to ActionYesNo(yes_prefix="", no_prefix="no-"); all plain bool fields now render as --show.ids/--no-show.ids instead of --show.ids {true,false} - pass parser_class=_BoolFlagParser to CLI(); bool|None fields (e.g. TrackerParams.enable_cmc) unaffected by the type is bool guard --- Co-authored-by: Claude Code --- src/trackers/cli/__main__.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/trackers/cli/__main__.py b/src/trackers/cli/__main__.py index 8e3c521c..20379935 100644 --- a/src/trackers/cli/__main__.py +++ b/src/trackers/cli/__main__.py @@ -12,7 +12,7 @@ import sys import warnings -from jsonargparse import CLI +from jsonargparse import CLI, ActionYesNo, ArgumentParser from trackers.cli.download import download from trackers.cli.eval import eval_cmd @@ -20,6 +20,16 @@ from trackers.cli.tune import tune +class _BoolFlagParser(ArgumentParser): + """Render plain ``bool`` fields as ``--flag`` / ``--no-flag`` pairs.""" + + def add_argument(self, *args, **kwargs): # type: ignore[override] + if kwargs.get("type") is bool: + kwargs.pop("type") + kwargs["action"] = ActionYesNo(yes_prefix="", no_prefix="no-") + return super().add_argument(*args, **kwargs) + + def main() -> int: """Dispatch to track / eval / tune / download via jsonargparse CLI.""" warnings.warn( @@ -32,6 +42,7 @@ def main() -> int: as_positional=False, prog="trackers", description="Command-line tools for multi-object tracking.", + parser_class=_BoolFlagParser, ) return int(rc) if rc is not None else 0 From 3fe830ebe6d6aca4b7440876c9a7bea806fd1fd5 Mon Sep 17 00:00:00 2001 From: jirka <6035284+Borda@users.noreply.github.com> Date: Wed, 20 May 2026 21:31:03 +0200 Subject: [PATCH 05/13] docs(cli): update all CLI examples for jsonargparse flag names MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Migrate flag names: --model→--detection.model, --classes→--filters.classes, --display→--vis.display, --show-*→--show.*, --mot-output→--out.mot_output, --output→--out.output - Fix list syntax: --metrics CLEAR HOTA Identity→--metrics '[CLEAR,HOTA,Identity]', --columns MOTA HOTA IDF1→--columns '[MOTA,HOTA,IDF1]' (space-separated fails with jsonargparse union validation) - Fix download flags: positional dataset→--dataset, --list→--list_available, --cache-dir→--cache_dir - Fix hyphen→underscore: --gt-dir→--gt_dir, --tracker-dir→--tracker_dir, --detections-dir→--detections_dir, --n-trials→--n_trials - Validated against trackers track/eval/tune/download --help output --- Co-authored-by: Claude Code --- README.md | 18 +++++----- docs/index.md | 18 +++++----- docs/learn/detection-quality.md | 63 +++++++++++++++++---------------- docs/learn/download.md | 26 +++++++------- docs/learn/evaluate.md | 25 ++++++------- docs/learn/track.md | 62 ++++++++++++++++---------------- docs/learn/tune.md | 22 ++++++------ 7 files changed, 118 insertions(+), 116 deletions(-) diff --git a/README.md b/README.md index 7a003a17..a2b94e7f 100644 --- a/README.md +++ b/README.md @@ -75,11 +75,11 @@ Prefer the terminal? Point `trackers track` at a video, webcam feed, RTSP stream ```bash trackers track \ --source video.mp4 \ - --output output.mp4 \ - --model rfdetr-medium \ + --out.output output.mp4 \ + --detection.model rfdetr-medium \ --tracker bytetrack \ - --show-labels \ - --show-trajectories + --show.labels \ + --show.trajectories ``` For all CLI options, see the [tracking guide](https://trackers.roboflow.com/develop/learn/track/). @@ -103,10 +103,10 @@ Once you have tracking results, you want to know how good they are. `trackers ev ```bash trackers eval \ - --gt-dir ./data/mot17/val \ - --tracker-dir results \ - --metrics CLEAR HOTA Identity \ - --columns MOTA HOTA IDF1 + --gt_dir ./data/mot17/val \ + --tracker_dir results \ + --metrics '[CLEAR,HOTA,Identity]' \ + --columns '[MOTA,HOTA,IDF1]' ``` ``` @@ -130,7 +130,7 @@ For the full evaluation workflow, see the [evaluation guide](https://trackers.ro Need benchmark data to evaluate against? `trackers download` pulls MOT17, SportsMOT, and other supported datasets with a single command, handling splits and assets selectively so you only download what you need. ```bash -trackers download mot17 \ +trackers download --dataset mot17 \ --split val \ --asset annotations,detections ``` diff --git a/docs/index.md b/docs/index.md index 06fc8053..befd0a75 100644 --- a/docs/index.md +++ b/docs/index.md @@ -41,11 +41,11 @@ Point at a video, webcam, RTSP stream, or image directory. Get tracked output. ```bash trackers track \ --source video.mp4 \ - --output output.mp4 \ - --model rfdetr-medium \ + --out.output output.mp4 \ + --detection.model rfdetr-medium \ --tracker bytetrack \ - --show-labels \ - --show-trajectories + --show.labels \ + --show.trajectories ``` For all CLI options, see the [tracking guide](learn/track.md). @@ -86,10 +86,10 @@ Benchmark your tracker against ground truth with standard MOT metrics. ```text trackers eval \ - --gt-dir ./data/mot17/val \ - --tracker-dir results \ - --metrics CLEAR HOTA Identity \ - --columns MOTA HOTA IDF1 + --gt_dir ./data/mot17/val \ + --tracker_dir results \ + --metrics '[CLEAR,HOTA,Identity]' \ + --columns '[MOTA,HOTA,IDF1]' ``` ``` @@ -130,7 +130,7 @@ For detailed benchmarks and tuned configurations, see the [tracker comparison](t Pull benchmark datasets for evaluation with a single command. ```bash -trackers download mot17 \ +trackers download --dataset mot17 \ --split val \ --asset annotations,detections ``` diff --git a/docs/learn/detection-quality.md b/docs/learn/detection-quality.md index 4ea4a909..54e073e4 100644 --- a/docs/learn/detection-quality.md +++ b/docs/learn/detection-quality.md @@ -47,7 +47,8 @@ We pick three models that span a wide accuracy range on COCO, from a lightweight Pull the MOT17 validation split. You need frames for detection and annotations for evaluation. ```text -trackers download mot17 \ +trackers download \ + --dataset mot17 \ --split val \ --asset frames,annotations \ --output ./data @@ -66,10 +67,10 @@ Run ByteTrack with default parameters three times, changing only the detection m ```bash trackers track \ --source ./data/mot17/val/MOT17-13-FRCNN/img1 \ - --model yolo26n-640 \ + --detection.model yolo26n-640 \ --tracker bytetrack \ - --classes person \ - --mot-output results/yolo26n/MOT17-13-FRCNN.txt + --filters.classes person \ + --out.mot_output results/yolo26n/MOT17-13-FRCNN.txt ``` === "All sequences" @@ -78,10 +79,10 @@ Run ByteTrack with default parameters three times, changing only the detection m for seq in MOT17-02-FRCNN MOT17-04-FRCNN MOT17-05-FRCNN MOT17-09-FRCNN MOT17-10-FRCNN MOT17-11-FRCNN MOT17-13-FRCNN; do trackers track \ --source ./data/mot17/val/$seq/img1 \ - --model yolo26n-640 \ + --detection.model yolo26n-640 \ --tracker bytetrack \ - --classes person \ - --mot-output results/yolo26n/$seq.txt + --filters.classes person \ + --out.mot_output results/yolo26n/$seq.txt done ``` @@ -97,10 +98,10 @@ Run ByteTrack with default parameters three times, changing only the detection m ```bash trackers track \ --source ./data/mot17/val/MOT17-13-FRCNN/img1 \ - --model rfdetr-nano \ + --detection.model rfdetr-nano \ --tracker bytetrack \ - --classes person \ - --mot-output results/rfdetr-nano/MOT17-13-FRCNN.txt + --filters.classes person \ + --out.mot_output results/rfdetr-nano/MOT17-13-FRCNN.txt ``` === "All sequences" @@ -109,10 +110,10 @@ Run ByteTrack with default parameters three times, changing only the detection m for seq in MOT17-02-FRCNN MOT17-04-FRCNN MOT17-05-FRCNN MOT17-09-FRCNN MOT17-10-FRCNN MOT17-11-FRCNN MOT17-13-FRCNN; do trackers track \ --source ./data/mot17/val/$seq/img1 \ - --model rfdetr-nano \ + --detection.model rfdetr-nano \ --tracker bytetrack \ - --classes person \ - --mot-output results/rfdetr-nano/$seq.txt + --filters.classes person \ + --out.mot_output results/rfdetr-nano/$seq.txt done ``` @@ -128,10 +129,10 @@ Run ByteTrack with default parameters three times, changing only the detection m ```bash trackers track \ --source ./data/mot17/val/MOT17-13-FRCNN/img1 \ - --model rfdetr-medium \ + --detection.model rfdetr-medium \ --tracker bytetrack \ - --classes person \ - --mot-output results/rfdetr-medium/MOT17-13-FRCNN.txt + --filters.classes person \ + --out.mot_output results/rfdetr-medium/MOT17-13-FRCNN.txt ``` === "All sequences" @@ -140,10 +141,10 @@ Run ByteTrack with default parameters three times, changing only the detection m for seq in MOT17-02-FRCNN MOT17-04-FRCNN MOT17-05-FRCNN MOT17-09-FRCNN MOT17-10-FRCNN MOT17-11-FRCNN MOT17-13-FRCNN; do trackers track \ --source ./data/mot17/val/$seq/img1 \ - --model rfdetr-medium \ + --detection.model rfdetr-medium \ --tracker bytetrack \ - --classes person \ - --mot-output results/rfdetr-medium/$seq.txt + --filters.classes person \ + --out.mot_output results/rfdetr-medium/$seq.txt done ``` @@ -162,10 +163,10 @@ Evaluate each run against ground truth using CLEAR, HOTA, and Identity metrics. ```text trackers eval \ - --gt-dir ./data/mot17/val \ - --tracker-dir results/yolo26n \ - --metrics CLEAR HOTA Identity \ - --columns MOTA HOTA IDF1 + --gt_dir ./data/mot17/val \ + --tracker_dir results/yolo26n \ + --metrics '[CLEAR,HOTA,Identity]' \ + --columns '[MOTA,HOTA,IDF1]' ``` **Output:** @@ -180,10 +181,10 @@ COMBINED 23.444 32.874 34.411 ```text trackers eval \ - --gt-dir ./data/mot17/val \ - --tracker-dir results/rfdetr-nano \ - --metrics CLEAR HOTA Identity \ - --columns MOTA HOTA IDF1 + --gt_dir ./data/mot17/val \ + --tracker_dir results/rfdetr-nano \ + --metrics '[CLEAR,HOTA,Identity]' \ + --columns '[MOTA,HOTA,IDF1]' ``` **Output:** @@ -198,10 +199,10 @@ COMBINED 25.667 35.735 38.182 ```text trackers eval \ - --gt-dir ./data/mot17/val \ - --tracker-dir results/rfdetr-medium \ - --metrics CLEAR HOTA Identity \ - --columns MOTA HOTA IDF1 + --gt_dir ./data/mot17/val \ + --tracker_dir results/rfdetr-medium \ + --metrics '[CLEAR,HOTA,Identity]' \ + --columns '[MOTA,HOTA,IDF1]' ``` **Output:** diff --git a/docs/learn/download.md b/docs/learn/download.md index 68b12e31..70dba25d 100644 --- a/docs/learn/download.md +++ b/docs/learn/download.md @@ -44,10 +44,10 @@ The table below lists every dataset you can download, along with its splits, ass === "CLI" - Use `--list` to print available datasets, splits, and asset types. + Use `--list_available` to print available datasets, splits, and asset types. ```text - trackers download --list + trackers download --list_available ``` === "Python" @@ -72,7 +72,7 @@ Pass a dataset name to download all of its splits and assets. Download the full MOT17 dataset. ```text - trackers download mot17 + trackers download --dataset mot17 ``` === "Python" @@ -96,15 +96,15 @@ Full datasets can be large. Narrow your download to specific splits and asset ty Use `--split` and `--asset` to filter by split, asset type, or both. ```text - trackers download mot17 --split train --asset annotations + trackers download --dataset mot17 --split train --asset annotations ``` ```text - trackers download mot17 --split train,val --asset annotations,frames + trackers download --dataset mot17 --split train,val --asset annotations,frames ``` ```text - trackers download sportsmot --split val --asset annotations + trackers download --dataset sportsmot --split val --asset annotations ``` === "Python" @@ -152,7 +152,7 @@ Dataset files are extracted to the current directory by default. Set a custom ou Use `--output` to extract into a custom directory. ```text - trackers download mot17 \ + trackers download --dataset mot17 \ --split train,val \ --asset annotations,frames \ --output ./datasets @@ -209,13 +209,13 @@ Every downloaded ZIP is saved to `~/.cache/trackers` and verified with an MD5 ch === "CLI" - Use `--cache-dir` to store ZIPs in a custom location. + Use `--cache_dir` to store ZIPs in a custom location. ```text - trackers download mot17 \ + trackers download --dataset mot17 \ --split train \ --asset annotations \ - --cache-dir ./my-cache + --cache_dir ./my-cache ``` === "Python" @@ -254,12 +254,12 @@ All arguments accepted by the `trackers download` command. - dataset + --dataset Dataset name to download. Options: mot17, sportsmot. — - --list + --list_available List available datasets, splits, and asset types without downloading. false @@ -279,7 +279,7 @@ All arguments accepted by the `trackers download` command. . - --cache-dir + --cache_dir Directory for caching downloaded ZIP files. Cached files are verified by MD5 and reused across runs. ~/.cache/trackers diff --git a/docs/learn/evaluate.md b/docs/learn/evaluate.md index 45cd8eca..fd4e220d 100644 --- a/docs/learn/evaluate.md +++ b/docs/learn/evaluate.md @@ -41,7 +41,8 @@ Use `trackers download` to pull ground-truth annotations and detections from sup Fetch MOT17 validation annotations and detections from the command line. ```text - trackers download mot17 \ + trackers download \ + --dataset mot17 \ --split val \ --asset annotations,detections \ --output ./data @@ -89,13 +90,13 @@ For more download options, see the [download guide](download.md). Feed the pre-computed detections into a tracker and write the results to a file for evaluation. -Pass `--detections` to provide input detections and `--mot-output` to save the tracker output in MOT format. +Pass `--detection.detections` to provide input detections and `--out.mot_output` to save the tracker output in MOT format. ```text trackers track \ - --detections ./data/mot17/val/MOT17-02-FRCNN/det/det.txt \ + --detection.detections ./data/mot17/val/MOT17-02-FRCNN/det/det.txt \ --tracker bytetrack \ - --mot-output results/MOT17-02-FRCNN.txt + --out.mot_output results/MOT17-02-FRCNN.txt ``` --- @@ -108,8 +109,8 @@ Compare the tracker output against ground truth to compute standard MOT metrics. trackers eval \ --gt ./data/mot17/val/MOT17-02-FRCNN/gt/gt.txt \ --tracker results/MOT17-02-FRCNN.txt \ - --metrics CLEAR HOTA Identity \ - --columns MOTA HOTA IDF1 + --metrics '[CLEAR,HOTA,Identity]' \ + --columns '[MOTA,HOTA,IDF1]' ``` **Output:** @@ -148,10 +149,10 @@ Evaluate all sequences at once and get per-sequence results plus a combined aggr ```text trackers eval \ - --gt-dir ./data/mot17/val \ - --tracker-dir results \ - --metrics CLEAR HOTA Identity \ - --columns MOTA HOTA IDF1 \ + --gt_dir ./data/mot17/val \ + --tracker_dir results \ + --metrics '[CLEAR,HOTA,Identity]' \ + --columns '[MOTA,HOTA,IDF1]' \ --output results.json ``` @@ -204,12 +205,12 @@ All arguments accepted by `trackers eval`. — - --gt-dir + --gt_dir Directory containing ground-truth files for multi-sequence evaluation. — - --tracker-dir + --tracker_dir Directory containing tracker prediction files for multi-sequence evaluation. — diff --git a/docs/learn/track.md b/docs/learn/track.md index c4b66ac7..8511a904 100644 --- a/docs/learn/track.md +++ b/docs/learn/track.md @@ -93,14 +93,14 @@ Trackers assign stable IDs to detections across frames, maintaining object ident === "CLI" - Select a tracker with `--tracker` and tune its behavior with `--tracker.*` arguments. + Select a tracker with `--tracker` and tune its behavior with `--tracker_params.*` arguments. ```text trackers track \ --source source.mp4 \ --tracker bytetrack \ - --tracker.lost_track_buffer 60 \ - --tracker.minimum_consecutive_frames 5 + --tracker_params.lost_track_buffer 60 \ + --tracker_params.minimum_consecutive_frames 5 ``` === "Python" @@ -139,15 +139,15 @@ Trackers don't detect objects—they link detections across frames. A detection === "CLI" - Configure detection with `--model.*` arguments. Filter by confidence and class before tracking. + Configure detection with `--detection.*` arguments. Filter by confidence and class before tracking. ```text trackers track \ --source source.mp4 \ - --model rfdetr-medium \ - --model.confidence 0.3 \ - --model.device cuda \ - --classes person,car + --detection.model rfdetr-medium \ + --detection.confidence 0.3 \ + --detection.device cuda \ + --filters.classes person,car ``` === "Python" @@ -188,10 +188,10 @@ Visualization renders tracking results for debugging, demos, and qualitative eva ```text trackers track \ --source source.mp4 \ - --display \ - --show-labels \ - --show-confidence \ - --show-trajectories + --vis.display \ + --show.labels \ + --show.confidence \ + --show.trajectories ``` === "Python" @@ -274,7 +274,7 @@ Save tracking results as annotated video files or display them in real time. Specify an output path to save annotated video. ```text - trackers track --source source.mp4 --output output.mp4 --overwrite + trackers track --source source.mp4 --out.output output.mp4 --out.overwrite ``` === "Python" @@ -342,37 +342,37 @@ All arguments accepted by the `trackers track` command. — - --output + --out.output Path for output video. If a directory is given, saves as output.mp4 inside it. none - --overwrite + --out.overwrite Allow overwriting existing output files. Without this flag, existing files cause an error. false - --model + --detection.model Model identifier. Pretrained: rfdetr-nano, rfdetr-small, rfdetr-medium, rfdetr-large. Segmentation: rfdetr-seg-*. rfdetr-nano - --model.confidence + --detection.confidence Minimum confidence threshold. Lower values increase recall but may add noise. 0.5 - --model.device + --detection.device Compute device. Options: auto, cpu, cuda, cuda:0, mps. auto - --model.api_key + --detection.api_key Roboflow API key for custom hosted models. none - --classes + --filters.classes Comma-separated class names or IDs to track. Example: person,car or 0,2. all @@ -382,57 +382,57 @@ All arguments accepted by the `trackers track` command. bytetrack - --tracker.lost_track_buffer + --tracker_params.lost_track_buffer Frames to retain a track without detections. Higher values improve occlusion handling but risk ID drift. 30 - --tracker.track_activation_threshold + --tracker_params.track_activation_threshold Minimum confidence to start a new track. Lower values catch more objects but increase false positives. 0.25 - --tracker.minimum_consecutive_frames + --tracker_params.minimum_consecutive_frames Consecutive detections required before a track is confirmed. Suppresses spurious detections. 3 - --tracker.minimum_iou_threshold + --tracker_params.minimum_iou_threshold Minimum IoU overlap to match a detection to an existing track. Higher values require tighter alignment. 0.3 - --display + --vis.display Opens a live preview window. Press q or ESC to quit. false - --show-boxes + --show.boxes Draw bounding boxes around tracked objects. true - --show-masks + --show.masks Draw segmentation masks. Only available with rfdetr-seg-* models. false - --show-confidence + --show.confidence Show detection confidence scores in labels. false - --show-labels + --show.labels Show class names in labels. false - --show-ids + --show.ids Show tracker IDs in labels. true - --show-trajectories + --show.trajectories Draw motion trails showing recent positions of each track. false diff --git a/docs/learn/tune.md b/docs/learn/tune.md index 8b5ebd02..3c8c3e6b 100644 --- a/docs/learn/tune.md +++ b/docs/learn/tune.md @@ -65,11 +65,11 @@ For detections, use `id=-1`. For more details on the format and evaluation workf ```text trackers tune \ --tracker bytetrack \ - --gt-dir ./data/gt \ - --detections-dir ./data/detections \ + --gt_dir ./data/gt \ + --detections_dir ./data/detections \ --objective HOTA \ - --metrics CLEAR HOTA Identity \ - --n-trials 50 \ + --metrics '[CLEAR,HOTA,Identity]' \ + --n_trials 50 \ --output ./results/bytetrack-best.json ``` @@ -111,8 +111,8 @@ MOT17-09-FRCNN ```text trackers tune \ --tracker bytetrack \ - --gt-dir ./data/gt \ - --detections-dir ./data/detections \ + --gt_dir ./data/gt \ + --detections_dir ./data/detections \ --seqmap ./seqmap.txt ``` @@ -174,12 +174,12 @@ All arguments accepted by `trackers tune`. — - --gt-dir + --gt_dir Directory with ground-truth MOT files ({sequence}.txt). — - --detections-dir + --detections_dir Directory with detection MOT files ({sequence}.txt), one file per sequence. — @@ -189,7 +189,7 @@ All arguments accepted by `trackers tune`. HOTA - --n-trials + --n_trials Number of Optuna trials to run. 100 @@ -206,10 +206,10 @@ All arguments accepted by `trackers tune`. --seqmap Optional path to a sequence map file. When set, only listed sequences are tuned. - all files in --detections-dir + all files in --detections_dir - --output, -o + --output Path to save best parameters as JSON. None From 492873dba3d29d8828bbde3791c1fb83defb16cf Mon Sep 17 00:00:00 2001 From: jirka <6035284+Borda@users.noreply.github.com> Date: Wed, 20 May 2026 21:41:54 +0200 Subject: [PATCH 06/13] =?UTF-8?q?refine(cli):=20rename=20OutputOptions.mot?= =?UTF-8?q?=5Foutput=20=E2=86=92=20mot=5Fresults?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename field and all internal usages in track.py (11 sites) - Update --out.mot_results in evaluate.md and detection-quality.md docs --- Co-authored-by: Claude Code --- docs/learn/detection-quality.md | 12 ++++++------ docs/learn/evaluate.md | 4 ++-- src/trackers/cli/track.py | 22 +++++++++++----------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/docs/learn/detection-quality.md b/docs/learn/detection-quality.md index 54e073e4..12cfd89d 100644 --- a/docs/learn/detection-quality.md +++ b/docs/learn/detection-quality.md @@ -70,7 +70,7 @@ Run ByteTrack with default parameters three times, changing only the detection m --detection.model yolo26n-640 \ --tracker bytetrack \ --filters.classes person \ - --out.mot_output results/yolo26n/MOT17-13-FRCNN.txt + --out.mot_results results/yolo26n/MOT17-13-FRCNN.txt ``` === "All sequences" @@ -82,7 +82,7 @@ Run ByteTrack with default parameters three times, changing only the detection m --detection.model yolo26n-640 \ --tracker bytetrack \ --filters.classes person \ - --out.mot_output results/yolo26n/$seq.txt + --out.mot_results results/yolo26n/$seq.txt done ``` @@ -101,7 +101,7 @@ Run ByteTrack with default parameters three times, changing only the detection m --detection.model rfdetr-nano \ --tracker bytetrack \ --filters.classes person \ - --out.mot_output results/rfdetr-nano/MOT17-13-FRCNN.txt + --out.mot_results results/rfdetr-nano/MOT17-13-FRCNN.txt ``` === "All sequences" @@ -113,7 +113,7 @@ Run ByteTrack with default parameters three times, changing only the detection m --detection.model rfdetr-nano \ --tracker bytetrack \ --filters.classes person \ - --out.mot_output results/rfdetr-nano/$seq.txt + --out.mot_results results/rfdetr-nano/$seq.txt done ``` @@ -132,7 +132,7 @@ Run ByteTrack with default parameters three times, changing only the detection m --detection.model rfdetr-medium \ --tracker bytetrack \ --filters.classes person \ - --out.mot_output results/rfdetr-medium/MOT17-13-FRCNN.txt + --out.mot_results results/rfdetr-medium/MOT17-13-FRCNN.txt ``` === "All sequences" @@ -144,7 +144,7 @@ Run ByteTrack with default parameters three times, changing only the detection m --detection.model rfdetr-medium \ --tracker bytetrack \ --filters.classes person \ - --out.mot_output results/rfdetr-medium/$seq.txt + --out.mot_results results/rfdetr-medium/$seq.txt done ``` diff --git a/docs/learn/evaluate.md b/docs/learn/evaluate.md index fd4e220d..821a6348 100644 --- a/docs/learn/evaluate.md +++ b/docs/learn/evaluate.md @@ -90,13 +90,13 @@ For more download options, see the [download guide](download.md). Feed the pre-computed detections into a tracker and write the results to a file for evaluation. -Pass `--detection.detections` to provide input detections and `--out.mot_output` to save the tracker output in MOT format. +Pass `--detection.detections` to provide input detections and `--out.mot_results` to save the tracker output in MOT format. ```text trackers track \ --detection.detections ./data/mot17/val/MOT17-02-FRCNN/det/det.txt \ --tracker bytetrack \ - --out.mot_output results/MOT17-02-FRCNN.txt + --out.mot_results results/MOT17-02-FRCNN.txt ``` --- diff --git a/src/trackers/cli/track.py b/src/trackers/cli/track.py index a153b4ca..67124052 100644 --- a/src/trackers/cli/track.py +++ b/src/trackers/cli/track.py @@ -98,12 +98,12 @@ class OutputOptions: Attributes: output: Annotated-video output path. - mot_output: MOT-format predictions output path. + mot_results: MOT-format predictions output path. overwrite: Overwrite existing output files without prompting. """ output: Path | None = None - mot_output: Path | None = None + mot_results: Path | None = None overwrite: bool = False @@ -218,7 +218,7 @@ def track( classes = filters.classes track_ids = filters.track_ids output = out.output - mot_output = out.mot_output + mot_results = out.mot_results overwrite = out.overwrite display = vis.display show_boxes = show.boxes @@ -239,8 +239,8 @@ def track( if output: _validate_output_path(_resolve_video_output_path(output), overwrite=overwrite) - if mot_output: - _validate_output_path(mot_output, overwrite=overwrite) + if mot_results: + _validate_output_path(mot_results, overwrite=overwrite) if detections is not None: model_obj: AnyModel | None = None @@ -266,7 +266,7 @@ def track( track_id_filter=track_id_filter, tracker=tracker_obj, output=output, - mot_output=mot_output, + mot_results=mot_results, display=display, show_boxes=show_boxes, show_masks=show_masks, @@ -281,7 +281,7 @@ def track( class_filter=class_filter, track_id_filter=track_id_filter, tracker=tracker_obj, - mot_output=mot_output, + mot_results=mot_results, ) @@ -291,7 +291,7 @@ def _run_frameless( class_filter: list[int] | None, track_id_filter: list[int] | None, tracker: BaseTracker, - mot_output: Path | None, + mot_results: Path | None, ) -> int: """Run tracking from pre-computed detections without a frame source.""" if not detections_data: @@ -302,7 +302,7 @@ def _run_frameless( source_info = _SourceInfo(source_type="video", total_frames=total_frames) try: - with _MOTOutput(mot_output) as mot, _TrackingProgress(source_info) as progress: + with _MOTOutput(mot_results) as mot, _TrackingProgress(source_info) as progress: for frame_idx in range(1, total_frames + 1): if frame_idx in detections_data: dets = _mot_frame_to_detections(detections_data[frame_idx]) @@ -340,7 +340,7 @@ def _run_with_source( track_id_filter: list[int] | None, tracker: BaseTracker, output: Path | None, - mot_output: Path | None, + mot_results: Path | None, display: bool, show_boxes: bool, show_masks: bool, @@ -368,7 +368,7 @@ def _run_with_source( try: with ( _VideoOutput(output, fps=source_info.fps or _DEFAULT_OUTPUT_FPS) as video, - _MOTOutput(mot_output) as mot, + _MOTOutput(mot_results) as mot, display_ctx as display_win, _TrackingProgress(source_info) as progress, ): From 123e75075360e895623d859cdf141c2810481593 Mon Sep 17 00:00:00 2001 From: jirka <6035284+Borda@users.noreply.github.com> Date: Thu, 21 May 2026 11:46:23 +0200 Subject: [PATCH 07/13] feat(cli): expose iou_variant in TrackerParams Add iou_variant string field to TrackerParams; _init_tracker converts it to a BaseIoU instance (IoU/GIoU/DIoU/CIoU/BIoU) before forwarding to any tracker that accepts the iou kwarg. Document in track.md. --- Co-authored-by: Claude Code --- docs/learn/track.md | 5 +++++ src/trackers/cli/track.py | 20 ++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/docs/learn/track.md b/docs/learn/track.md index 9cac25a1..dc4e8d44 100644 --- a/docs/learn/track.md +++ b/docs/learn/track.md @@ -401,6 +401,11 @@ All arguments accepted by the `trackers track` command. Minimum IoU overlap to match a detection to an existing track. Higher values require tighter alignment. 0.3 + + --tracker_params.iou_variant + IoU similarity metric for data association. Options: iou, giou, diou, ciou, biou. Applies to all trackers. + iou + --vis.display Opens a live preview window. Press q or ESC to quit. diff --git a/src/trackers/cli/track.py b/src/trackers/cli/track.py index 67124052..6e935746 100644 --- a/src/trackers/cli/track.py +++ b/src/trackers/cli/track.py @@ -25,6 +25,15 @@ from trackers.io.paths import _resolve_video_output_path, _validate_output_path from trackers.io.video import _DEFAULT_OUTPUT_FPS, _DisplayWindow, _VideoOutput from trackers.utils.device import _best_device +from trackers.utils.iou import BaseIoU, BIoU, CIoU, DIoU, GIoU, IoU + +_IOU_VARIANTS: dict[str, type[BaseIoU]] = { + "iou": IoU, + "giou": GIoU, + "diou": DIoU, + "ciou": CIoU, + "biou": BIoU, +} if TYPE_CHECKING: from inference_models import AnyModel @@ -163,6 +172,9 @@ class TrackerParams: cmc_method: BoT-SORT CMC method name. cmc_downscale: BoT-SORT CMC downscale factor. instant_first_frame_activation: BoT-SORT first-frame activation toggle. + iou_variant: IoU similarity metric for data association. One of + ``iou`` (standard), ``giou``, ``diou``, ``ciou``, ``biou``. + Applies to all trackers. Defaults to ``iou``. """ lost_track_buffer: int | None = None @@ -180,6 +192,7 @@ class TrackerParams: cmc_method: str | None = None cmc_downscale: int | None = None instant_first_frame_activation: bool | None = None + iou_variant: str | None = None def track( @@ -542,8 +555,15 @@ def _init_tracker(tracker_id: str, params: TrackerParams | None) -> BaseTracker: raise ValueError(f"Unknown tracker: '{tracker_id}'. Available: {available}") raw = asdict(params) if params is not None else {} + iou_variant = raw.pop("iou_variant", None) accepted = set(info.parameters) kwargs = {k: v for k, v in raw.items() if v is not None and k in accepted} + if iou_variant is not None and "iou" in accepted: + iou_cls = _IOU_VARIANTS.get(iou_variant.lower()) + if iou_cls is None: + valid = ", ".join(_IOU_VARIANTS) + raise ValueError(f"Unknown iou_variant '{iou_variant}'. Valid: {valid}") + kwargs["iou"] = iou_cls() return info.tracker_class(**kwargs) From eee9e6c8d5384300d5818fa4abc12edea9c3e623 Mon Sep 17 00:00:00 2001 From: jirka <6035284+Borda@users.noreply.github.com> Date: Thu, 21 May 2026 12:01:41 +0200 Subject: [PATCH 08/13] refactor(cli): move IoU variant registry to iou.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add `_VARIANTS` dict and `variant_from_name()` factory to `iou.py` — single source of truth; new variants only need one-file change - Remove `_IOU_VARIANTS` dict and 6 IoU class imports from `cli/track.py`; replace with `from trackers.utils.iou import variant_from_name` - Warn (`UserWarning`) in `_init_tracker` when `iou_variant` is supplied but the chosen tracker has no `iou` param, instead of silently dropping it --- Co-authored-by: Claude Code --- src/trackers/cli/track.py | 26 +++++++++++--------------- src/trackers/utils/iou.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/src/trackers/cli/track.py b/src/trackers/cli/track.py index 6e935746..afb29c99 100644 --- a/src/trackers/cli/track.py +++ b/src/trackers/cli/track.py @@ -10,6 +10,7 @@ from __future__ import annotations import sys +import warnings from contextlib import nullcontext from dataclasses import asdict, dataclass from pathlib import Path @@ -25,15 +26,7 @@ from trackers.io.paths import _resolve_video_output_path, _validate_output_path from trackers.io.video import _DEFAULT_OUTPUT_FPS, _DisplayWindow, _VideoOutput from trackers.utils.device import _best_device -from trackers.utils.iou import BaseIoU, BIoU, CIoU, DIoU, GIoU, IoU - -_IOU_VARIANTS: dict[str, type[BaseIoU]] = { - "iou": IoU, - "giou": GIoU, - "diou": DIoU, - "ciou": CIoU, - "biou": BIoU, -} +from trackers.utils.iou import variant_from_name if TYPE_CHECKING: from inference_models import AnyModel @@ -558,12 +551,15 @@ def _init_tracker(tracker_id: str, params: TrackerParams | None) -> BaseTracker: iou_variant = raw.pop("iou_variant", None) accepted = set(info.parameters) kwargs = {k: v for k, v in raw.items() if v is not None and k in accepted} - if iou_variant is not None and "iou" in accepted: - iou_cls = _IOU_VARIANTS.get(iou_variant.lower()) - if iou_cls is None: - valid = ", ".join(_IOU_VARIANTS) - raise ValueError(f"Unknown iou_variant '{iou_variant}'. Valid: {valid}") - kwargs["iou"] = iou_cls() + if iou_variant is not None: + if "iou" in accepted: + kwargs["iou"] = variant_from_name(iou_variant) + else: + warnings.warn( + f"Tracker '{tracker_id}' does not support iou_variant; '{iou_variant}' will be ignored.", + UserWarning, + stacklevel=2, + ) return info.tracker_class(**kwargs) diff --git a/src/trackers/utils/iou.py b/src/trackers/utils/iou.py index 9325bf68..64982bb0 100644 --- a/src/trackers/utils/iou.py +++ b/src/trackers/utils/iou.py @@ -400,3 +400,37 @@ def _compute(self, boxes_1: np.ndarray, boxes_2: np.ndarray) -> np.ndarray: def normalize_for_fusion(self, similarity_matrix: np.ndarray) -> np.ndarray: return (similarity_matrix + 1.0) / 2.0 + + +_VARIANTS: dict[str, type[BaseIoU]] = { + "iou": IoU, + "giou": GIoU, + "diou": DIoU, + "ciou": CIoU, + "biou": BIoU, +} + + +def variant_from_name(name: str) -> BaseIoU: + """Resolve a variant name (case-insensitive) to a default-constructed instance. + + Args: + name: One of ``iou``, ``giou``, ``diou``, ``ciou``, ``biou`` + (case-insensitive). + + Returns: + A default-constructed instance of the matching :class:`BaseIoU` subclass. + + Raises: + ValueError: If ``name`` does not match any known variant. + + Examples: + >>> isinstance(variant_from_name("giou"), GIoU) + True + >>> isinstance(variant_from_name("BIOU"), BIoU) + True + """ + try: + return _VARIANTS[name.lower()]() + except KeyError as exc: + raise ValueError(f"Unknown IoU variant {name!r}. Valid: {sorted(_VARIANTS)}") from exc From e6fa27732600a371e12465e5a7a3d91b8e63949e Mon Sep 17 00:00:00 2001 From: jirka <6035284+Borda@users.noreply.github.com> Date: Mon, 25 May 2026 09:56:09 +0200 Subject: [PATCH 09/13] refactor(cli): drop [signatures] extra from jsonargparse dep Copilot review #3281138854: jsonargparse[signatures] pulls in docstring-parser and typeshed-client. No code in trackers/cli/ imports from these packages; plain jsonargparse>=4.48.0 suffices. --- Co-authored-by: Claude Code --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6877a82d..be3a557b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dependencies = [ "opencv-python>=4.8.0", "rich>=13.0.0", "requests>=2.28.0", - "jsonargparse[signatures]>=4.48.0", + "jsonargparse>=4.48.0", ] [project.optional-dependencies] From 437c7f05d67367a5f3cab9bc91639fb5b86ce5cf Mon Sep 17 00:00:00 2001 From: jirka <6035284+Borda@users.noreply.github.com> Date: Mon, 25 May 2026 09:56:44 +0200 Subject: [PATCH 10/13] fix(cli): catch ValueError from _init_tracker, return exit code 1 Copilot review #3281138773: unknown tracker id or iou_variant name raised unhandled ValueError through jsonargparse.CLI, surfacing as a traceback instead of a clean CLI error. Wrap the _init_tracker call in track() and print to stderr before returning 1. --- Co-authored-by: Claude Code --- src/trackers/cli/track.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/trackers/cli/track.py b/src/trackers/cli/track.py index afb29c99..8859ba91 100644 --- a/src/trackers/cli/track.py +++ b/src/trackers/cli/track.py @@ -259,7 +259,11 @@ def track( class_filter = _resolve_class_filter(classes, class_names) track_id_filter = _resolve_track_id_filter(track_ids) - tracker_obj = _init_tracker(tracker, tracker_params) + try: + tracker_obj = _init_tracker(tracker, tracker_params) + except ValueError as e: + print(f"Error: {e}", file=sys.stderr) + return 1 if source is not None: return _run_with_source( From 79ff6fb129e0b0963404c5902d9ad88b618ed99e Mon Sep 17 00:00:00 2001 From: jirka <6035284+Borda@users.noreply.github.com> Date: Mon, 25 May 2026 09:58:35 +0200 Subject: [PATCH 11/13] fix(cli): avoid mutable default instances in track() signature Copilot review #3281138823: dataclass defaults created once at import time are shared across calls. Change option-group params to None defaults; instantiate inside the function body so each call gets a fresh instance. --- Co-authored-by: Claude Code --- src/trackers/cli/track.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/trackers/cli/track.py b/src/trackers/cli/track.py index 8859ba91..83672c6d 100644 --- a/src/trackers/cli/track.py +++ b/src/trackers/cli/track.py @@ -190,13 +190,13 @@ class TrackerParams: def track( source: str | None = None, - detection: DetectionOptions = DetectionOptions(), - filters: FilteringOptions = FilteringOptions(), + detection: DetectionOptions | None = None, + filters: FilteringOptions | None = None, tracker: str = DEFAULT_TRACKER, tracker_params: TrackerParams | None = None, - out: OutputOptions = OutputOptions(), - vis: VisualizationOptions = VisualizationOptions(), - show: ShowOptions = ShowOptions(), + out: OutputOptions | None = None, + vis: VisualizationOptions | None = None, + show: ShowOptions | None = None, ) -> int: """Run detection and tracking over a video, webcam, RTSP, or image directory. @@ -216,6 +216,16 @@ def track( Returns: Exit code: ``0`` on success, ``1`` on validation error. """ + if detection is None: + detection = DetectionOptions() + if filters is None: + filters = FilteringOptions() + if out is None: + out = OutputOptions() + if vis is None: + vis = VisualizationOptions() + if show is None: + show = ShowOptions() model = detection.model detections = detection.detections confidence = detection.confidence From d49cf28f8de17314aebcaad561eb81bc8371e61a Mon Sep 17 00:00:00 2001 From: jirka <6035284+Borda@users.noreply.github.com> Date: Mon, 25 May 2026 10:15:39 +0200 Subject: [PATCH 12/13] test(utils): add parametrized tests for variant_from_name() Copilot review #3281138907: variant_from_name() is new CLI-facing logic but had no unit test coverage. Add TestVariantFromName class with tests for valid lookups, case-insensitivity, and invalid-name ValueError with repr(name) in the error message. --- Co-authored-by: Claude Code --- tests/utils/test_iou.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_iou.py b/tests/utils/test_iou.py index 8c992558..6eddf05e 100644 --- a/tests/utils/test_iou.py +++ b/tests/utils/test_iou.py @@ -12,7 +12,7 @@ torch = pytest.importorskip("torch") torchvision = pytest.importorskip("torchvision") -from trackers.utils.iou import BaseIoU, BIoU, CIoU, DIoU, GIoU, IoU # noqa: E402 +from trackers.utils.iou import BaseIoU, BIoU, CIoU, DIoU, GIoU, IoU, variant_from_name # noqa: E402 def _torchvision_giou(boxes_1: np.ndarray, boxes_2: np.ndarray) -> np.ndarray: @@ -460,3 +460,34 @@ def test_inverted_coords_gives_zero_or_negative_similarity(self, metric: BaseIoU result = metric.compute(boxes_a, boxes_b) assert result.shape == (1, 1) assert np.isfinite(result).all(), "Inverted-coord box should not produce NaN/inf" + + +class TestVariantFromName: + """Tests for variant_from_name() registry lookup.""" + + @pytest.mark.parametrize( + ("name", "expected_type"), + [ + ("iou", IoU), + ("giou", GIoU), + ("diou", DIoU), + ("ciou", CIoU), + ("biou", BIoU), + ], + ) + def test_valid_names_return_correct_instance(self, name: str, expected_type: type) -> None: + """Each lowercase variant name resolves to the right BaseIoU subclass.""" + result = variant_from_name(name) + assert isinstance(result, expected_type) + + @pytest.mark.parametrize("name", ["IOU", "GIoU", "BIOU", "DiOU", "CIou"]) + def test_case_insensitive_lookup(self, name: str) -> None: + """Lookup is case-insensitive — any casing resolves without error.""" + result = variant_from_name(name) + assert isinstance(result, BaseIoU) + + @pytest.mark.parametrize("name", ["", "foo", "wiou", "iou2"]) + def test_invalid_name_raises_value_error(self, name: str) -> None: + """Unknown names raise ValueError; repr(name) appears in the error message.""" + with pytest.raises(ValueError, match=repr(name)): + variant_from_name(name) From e52f7d501a4f0c41f6be622220a79af03d684bdd Mon Sep 17 00:00:00 2001 From: jirka <6035284+Borda@users.noreply.github.com> Date: Mon, 25 May 2026 10:23:00 +0200 Subject: [PATCH 13/13] test(cli): verify --config file support via jsonargparse parser Copilot review #3281138882 (reclassified [req]): --config is a key behaviour change with no test coverage. Add TestConfigFileSupport covering: YAML values applied, CLI args override config, nested dataclass fields parsed from config. --- Co-authored-by: Claude Code --- tests/cli/test_main.py | 57 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 tests/cli/test_main.py diff --git a/tests/cli/test_main.py b/tests/cli/test_main.py new file mode 100644 index 00000000..9f4c4e19 --- /dev/null +++ b/tests/cli/test_main.py @@ -0,0 +1,57 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +"""Tests for trackers.cli.__main__ — jsonargparse CLI integration.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest +import yaml +from jsonargparse import ArgumentParser + +from trackers.cli.track import track + + +@pytest.fixture() +def track_parser() -> ArgumentParser: + """ArgumentParser built from the track() signature with --config support.""" + parser = ArgumentParser(exit_on_error=False) + parser.add_function_arguments(track) + parser.add_argument("--config", action="config") + return parser + + +class TestConfigFileSupport: + """Verify jsonargparse --config flag behaviour for the track subcommand.""" + + def test_config_value_applied_to_tracker(self, track_parser: ArgumentParser, tmp_path: Path) -> None: + """YAML --config value is parsed into the track() namespace.""" + cfg = tmp_path / "run.yaml" + cfg.write_text(yaml.dump({"tracker": "sort"})) + + ns = track_parser.parse_args(["--config", str(cfg)]) + + assert ns.tracker == "sort" + + def test_cli_arg_overrides_config_value(self, track_parser: ArgumentParser, tmp_path: Path) -> None: + """Explicit CLI arg takes precedence over the --config file value.""" + cfg = tmp_path / "run.yaml" + cfg.write_text(yaml.dump({"tracker": "sort"})) + + ns = track_parser.parse_args(["--config", str(cfg), "--tracker", "bytetrack"]) + + assert ns.tracker == "bytetrack" + + def test_nested_dataclass_field_in_config(self, track_parser: ArgumentParser, tmp_path: Path) -> None: + """Nested DetectionOptions fields can be set via --config.""" + cfg = tmp_path / "run.yaml" + cfg.write_text(yaml.dump({"detection": {"confidence": 0.3}})) + + ns = track_parser.parse_args(["--config", str(cfg)]) + + assert ns.detection.confidence == pytest.approx(0.3)