diff --git a/pyproject.toml b/pyproject.toml index cc622694..c33548a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "opencv-python>=4.8.0", "rich>=13.0.0", "requests>=2.28.0", + "click>=8.0", ] [project.optional-dependencies] @@ -48,7 +49,7 @@ detection = ["inference-models>=0.19.0"] tune = ["optuna>=3.0.0"] [project.scripts] -trackers = "trackers.scripts.__main__:main" +trackers = "trackers.cli.__main__:main" [dependency-groups] dev = [ diff --git a/src/trackers/cli/__init__.py b/src/trackers/cli/__init__.py new file mode 100644 index 00000000..57226e88 --- /dev/null +++ b/src/trackers/cli/__init__.py @@ -0,0 +1,5 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ diff --git a/src/trackers/cli/__main__.py b/src/trackers/cli/__main__.py new file mode 100644 index 00000000..e6154fc8 --- /dev/null +++ b/src/trackers/cli/__main__.py @@ -0,0 +1,49 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +from __future__ import annotations + +import logging +import sys +import warnings + +import click + +from trackers.cli.download import download_command +from trackers.cli.eval import eval_command +from trackers.cli.track import track_command +from trackers.cli.tune import tune_command + + +@click.group( + context_settings={"help_option_names": ["-h", "--help"]}, +) +@click.version_option(package_name="trackers", prog_name="trackers") +@click.option("-v", "--verbose", count=True, help="Increase log verbosity (-v INFO, -vv DEBUG).") +def cli(verbose: int) -> None: + """Command-line tools for multi-object tracking.""" + level = {0: logging.WARNING, 1: logging.INFO}.get(verbose, logging.DEBUG) + logging.basicConfig(level=level, format="%(message)s", handlers=[logging.StreamHandler(sys.stderr)]) + + +cli.add_command(track_command, "track") +cli.add_command(eval_command, "eval") +cli.add_command(download_command, "download") +cli.add_command(tune_command, "tune") + + +def main() -> None: + """Main entry point for the trackers CLI.""" + warnings.warn( + "The trackers CLI is in beta. APIs may change in future releases.", + UserWarning, + stacklevel=2, + ) + cli() + + +if __name__ == "__main__": + main() diff --git a/src/trackers/cli/_options.py b/src/trackers/cli/_options.py new file mode 100644 index 00000000..d4621c39 --- /dev/null +++ b/src/trackers/cli/_options.py @@ -0,0 +1,107 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +from __future__ import annotations + +from collections.abc import Callable +from pathlib import Path +from typing import TypeVar + +import click + +F = TypeVar("F", bound=Callable) + +METRIC_CHOICES: list[str] = ["CLEAR", "HOTA", "Identity"] + + +def metrics_option(f: F) -> F: + """Shared --metrics option for eval and tune commands. + + Examples: + >>> import click + >>> @click.command() + ... @metrics_option + ... def cmd(metrics): pass + >>> cmd.params[0].name + 'metrics' + """ + return click.option( + "--metrics", + multiple=True, + default=("CLEAR",), + type=click.Choice(METRIC_CHOICES), + help="Metrics to compute. Repeat flag for multiple: --metrics CLEAR --metrics HOTA. Default: CLEAR.", + )(f) + + +def threshold_option(f: F) -> F: + """Shared --threshold option for eval and tune commands. + + Examples: + >>> import click + >>> @click.command() + ... @threshold_option + ... def cmd(threshold): pass + >>> cmd.params[0].name + 'threshold' + """ + return click.option( + "--threshold", + type=float, + default=0.5, + help="IoU threshold for CLEAR and Identity matching. Default: 0.5", + )(f) + + +def seqmap_option(f: F) -> F: + """Shared --seqmap option for eval and tune commands. + + Examples: + >>> import click + >>> @click.command() + ... @seqmap_option + ... def cmd(seqmap): pass + >>> cmd.params[0].name + 'seqmap' + """ + return click.option( + "--seqmap", + type=click.Path(path_type=Path), + default=None, + metavar="PATH", + help="Sequence map file listing sequences to evaluate.", + )(f) + + +def output_option(help_text: str = "Output file path.") -> Callable[[F], F]: + """Shared -o/--output option factory. + + Args: + help_text: Help text for the option. + + Returns: + Decorator that adds the output option to a command. + + Examples: + >>> import click + >>> @click.command() + ... @output_option("Output JSON file.") + ... def cmd(output): pass + >>> cmd.params[0].name + 'output' + """ + + def decorator(f: F) -> F: + return click.option( + "-o", + "--output", + type=click.Path(path_type=Path), + default=None, + metavar="PATH", + help=help_text, + )(f) + + return decorator diff --git a/src/trackers/cli/download.py b/src/trackers/cli/download.py new file mode 100644 index 00000000..f35d4d7b --- /dev/null +++ b/src/trackers/cli/download.py @@ -0,0 +1,87 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +from __future__ import annotations + +import click +from rich.console import Console +from rich.panel import Panel + +from trackers.datasets.download import _DEFAULT_CACHE_DIR, _DEFAULT_OUTPUT_DIR +from trackers.datasets.manifest import _DATASETS + + +@click.command("download") +@click.argument("dataset", required=False, default=None) +@click.option("--list", "show_list", is_flag=True, help="List available datasets, splits, and asset types.") +@click.option( + "--split", + default=None, + help="Comma-separated splits to download (e.g. train,val,test). If omitted, all available splits are downloaded.", +) +@click.option( + "--asset", + default=None, + help=( + "Comma-separated assets to download: annotations,frames,detections." + " If omitted, all available assets are downloaded." + ), +) +@click.option("-o", "--output", default=_DEFAULT_OUTPUT_DIR, help="Output directory (default: current directory).") +@click.option( + "--cache-dir", + "cache_dir", + default=_DEFAULT_CACHE_DIR, + help="Cache directory for downloaded ZIPs (default: ~/.cache/trackers).", +) +def download_command( + dataset: str | None, + show_list: bool, + split: str | None, + asset: str | None, + output: str, + cache_dir: str, +) -> None: + """Download benchmark tracking datasets.""" + if show_list: + _print_available() + return + + if not dataset: + raise click.UsageError("Please specify a dataset name or use --list.") + + from trackers.datasets.download import download_dataset + + split_list = [s.strip() for s in split.split(",")] if split else None + asset_list = [a.strip() for a in asset.split(",")] if asset else None + + try: + download_dataset( + dataset=dataset, + split=split_list, + asset=asset_list, + output=output, + cache_dir=cache_dir, + ) + except Exception as e: + raise click.ClickException(str(e)) from e + + +def _print_available() -> None: + """Print available datasets, splits, and asset types.""" + console = Console() + for name, dataset_info in _DATASETS.items(): + description = dataset_info.get("description", "") + splits_dict: dict[str, dict] = dataset_info.get("splits", {}) + + max_split_len = max(len(s) for s in splits_dict) if splits_dict else 0 + split_lines = [ + f"{split:<{max_split_len}} {', '.join(assets.keys())}" for split, assets in splits_dict.items() + ] + + body = f"{description}\n\n" + "\n".join(split_lines) + console.print(Panel(body, title=name.value, title_align="left")) + console.print() diff --git a/src/trackers/cli/eval.py b/src/trackers/cli/eval.py new file mode 100644 index 00000000..2e282ffe --- /dev/null +++ b/src/trackers/cli/eval.py @@ -0,0 +1,111 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +from __future__ import annotations + +from pathlib import Path +from typing import cast + +import click + +from trackers.cli._options import metrics_option, output_option, seqmap_option, threshold_option + + +@click.command("eval") +@click.option( + "--gt", + type=click.Path(path_type=Path), + default=None, + metavar="PATH", + help="Path to ground truth file (MOT format).", +) +@click.option( + "--tracker", + "tracker_path", + type=click.Path(path_type=Path), + default=None, + metavar="PATH", + help="Path to tracker predictions file (MOT format).", +) +@click.option( + "--gt-dir", + type=click.Path(path_type=Path), + default=None, + metavar="DIR", + help="Directory containing ground truth files.", +) +@click.option( + "--tracker-dir", + type=click.Path(path_type=Path), + default=None, + metavar="DIR", + help="Directory containing tracker prediction files.", +) +@seqmap_option +@metrics_option +@threshold_option +@click.option( + "--columns", multiple=True, default=(), metavar="COL", help="Metric columns to display. Default: auto-selected." +) +@output_option("Output file for results (JSON format).") +def eval_command( + gt: Path | None, + tracker_path: Path | None, + gt_dir: Path | None, + tracker_dir: Path | None, + seqmap: Path | None, + metrics: tuple[str, ...], + threshold: float, + columns: tuple[str, ...], + output: Path | None, +) -> None: + """Evaluate tracker predictions against ground truth.""" + single_mode = gt is not None and tracker_path is not None + benchmark_mode = gt_dir is not None and tracker_dir is not None + + if not single_mode and not benchmark_mode: + raise click.UsageError("Must specify either --gt/--tracker or --gt-dir/--tracker-dir") + + if single_mode and benchmark_mode: + raise click.UsageError("Cannot use both single sequence and benchmark mode") + + columns_list: list[str] | None = list(columns) if columns else None + metrics_list = list(metrics) + + from trackers.eval import evaluate_mot_sequence, evaluate_mot_sequences + + try: + if single_mode: + seq_result = evaluate_mot_sequence( + gt_path=cast(Path, gt), + tracker_path=cast(Path, tracker_path), + metrics=metrics_list, + threshold=threshold, + ) + print(seq_result.table(columns=columns_list)) + + if output: + output.parent.mkdir(parents=True, exist_ok=True) + output.write_text(seq_result.json()) + print(f"\nResults saved to: {output}") + else: + bench_result = evaluate_mot_sequences( + gt_dir=cast(Path, gt_dir), + tracker_dir=cast(Path, tracker_dir), + seqmap=seqmap, + metrics=metrics_list, + threshold=threshold, + ) + print(bench_result.table(columns=columns_list)) + + if output: + bench_result.save(output) + print(f"\nResults saved to: {output}") + + except FileNotFoundError as e: + raise click.ClickException(str(e)) from e + except ValueError as e: + raise click.ClickException(str(e)) from e diff --git a/src/trackers/cli/progress.py b/src/trackers/cli/progress.py new file mode 100644 index 00000000..67de2608 --- /dev/null +++ b/src/trackers/cli/progress.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +from __future__ import annotations + +import itertools +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +import cv2 +from rich.console import Console +from rich.live import Live +from rich.text import Text + +from trackers.io.video import IMAGE_EXTENSIONS + +_SPINNER_FRAMES = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏" +_STREAM_PREFIXES = ("rtsp://", "http://", "https://") +_ICON_OK = "✓" +_ICON_FAIL = "✗" + + +@dataclass +class _SourceInfo: + """Metadata about a frame source used to drive progress display. + + Attributes: + source_type: Kind of source (`video`, `image_dir`, `webcam`, + `stream`). + total_frames: Total frame count when known, `None` for unbounded + sources such as webcams and network streams. + fps: Source frame-rate when known, `None` otherwise. + """ + + source_type: Literal["video", "image_dir", "webcam", "stream"] + total_frames: int | None = None + fps: float | None = None + + +def _classify_source(source: str | Path | int) -> _SourceInfo: + """Classify a frame source and extract metadata. + + The function inspects *source* without consuming any frames so it can be + called before the main processing loop. + + Args: + source: The same value accepted by `frames_from_source`. + + Returns: + A `_SourceInfo` describing the source. + """ + if isinstance(source, int) or (isinstance(source, str) and source.isdigit()): + return _SourceInfo(source_type="webcam") + + source_str = str(source) + + if any(source_str.lower().startswith(p) for p in _STREAM_PREFIXES): + return _SourceInfo(source_type="stream") + + path = Path(source_str) + if path.is_dir(): + count = sum(1 for p in path.iterdir() if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS) + return _SourceInfo( + source_type="image_dir", + total_frames=count if count > 0 else None, + ) + + cap = cv2.VideoCapture(source_str) + if not cap.isOpened(): + # Cannot open; still classify as video - the real error will come + # from frames_from_source later. + return _SourceInfo(source_type="video") + + try: + total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) + return _SourceInfo( + source_type="video", + total_frames=total if total > 0 else None, + fps=fps if fps and fps > 0 else None, + ) + finally: + cap.release() + + +def _format_time(seconds: float) -> str: + """Format `seconds` as `H:MM:SS` or `M:SS`.""" + if seconds < 0: + return "--" + minutes, seconds_remainder = divmod(int(seconds), 60) + hours, minutes = divmod(minutes, 60) + if hours > 0: + return f"{hours}:{minutes:02d}:{seconds_remainder:02d}" + return f"{minutes}:{seconds_remainder:02d}" + + +class _TrackingProgress: + """Context-manager that renders a single live progress line. + + Args: + source_info: Source metadata returned by `_classify_source`. + console: Optional `Console` instance (useful for testing with a + `StringIO` file). + """ + + def __init__( + self, + source_info: _SourceInfo, + console: Console | None = None, + ) -> None: + self._source_info = source_info + self._console = console or Console() + self._frames_processed = 0 + self._start_time: float = 0.0 + self._spinner = itertools.cycle(_SPINNER_FRAMES) + self._live: Live | None = None + self._interrupted = False + + def update(self) -> None: + """Record one processed frame and refresh the display.""" + self._frames_processed += 1 + icon = next(self._spinner) + if self._live is not None: + self._live.update(self._build_line(icon)) + + def complete(self, *, interrupted: bool = False) -> None: + """Signal that the processing loop has ended. + + Must be called before leaving the `with` block so that `__exit__` + can render the correct final state. + + Args: + interrupted: `True` when the loop was terminated early (e.g. + display-quit). + """ + self._interrupted = interrupted + + def __enter__(self) -> _TrackingProgress: + self._start_time = time.monotonic() + self._live = Live( + self._build_line("⠋"), + console=self._console, + refresh_per_second=12, + transient=True, + ) + self._live.__enter__() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + if self._live is not None: + self._live.__exit__(None, None, None) + + icon, suffix = self._resolve_final_state(exc_type) + final = self._build_line(icon, show_eta=False, suffix=suffix) + self._console.print(final) + + @property + def _is_bounded(self) -> bool: + """Whether the source has a known total frame count.""" + return self._source_info.total_frames is not None + + def _resolve_final_state(self, exc_type: type[BaseException] | None) -> tuple[str, str]: + """Return `(icon, suffix)` for the final printed line.""" + is_real_error = exc_type is not None and not issubclass(exc_type, KeyboardInterrupt) + + if is_real_error: + return (_ICON_FAIL, "(source lost)") + + was_stopped_early = exc_type is not None or self._interrupted + + if was_stopped_early and self._is_bounded: + return (_ICON_FAIL, "(interrupted)") + + return (_ICON_OK, "") + + def _build_line( + self, + icon: str, + *, + show_eta: bool = True, + suffix: str = "", + ) -> Text: + """Compose the single-line progress string.""" + elapsed = time.monotonic() - self._start_time + fps = self._frames_processed / elapsed if elapsed > 0 else 0.0 + total = self._source_info.total_frames + + if total is not None: + total_str = str(total) + frames_part = f"{self._frames_processed:>{len(total_str)}} / {total_str}" + else: + frames_part = f"{self._frames_processed} / --" + + if total is not None and total > 0: + percentage = self._frames_processed / total * 100 + percentage_part = f"{percentage:>3.0f}%" + else: + percentage_part = " --" + + fps_part = f"{fps:>.1f} fps" + elapsed_part = f"{_format_time(elapsed)} elapsed" + + parts = [ + f"{icon} Tracking", + f"{frames_part} frames", + percentage_part, + fps_part, + elapsed_part, + ] + + if show_eta: + if total is not None and fps > 0: + remaining = (total - self._frames_processed) / fps + parts.append(f"eta {_format_time(remaining)}") + else: + parts.append("eta --") + + if suffix: + parts.append(suffix) + + return Text(" ".join(parts)) diff --git a/src/trackers/cli/track.py b/src/trackers/cli/track.py new file mode 100644 index 00000000..7b130959 --- /dev/null +++ b/src/trackers/cli/track.py @@ -0,0 +1,732 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +from __future__ import annotations + +import sys +from contextlib import nullcontext +from pathlib import Path +from typing import TYPE_CHECKING + +import click +import numpy as np +import supervision as sv + +from trackers import frames_from_source +from trackers.cli.progress import _classify_source, _SourceInfo, _TrackingProgress +from trackers.core.base import BaseTracker +from trackers.io.mot import _mot_frame_to_detections, _MOTOutput, load_mot_file +from trackers.io.paths import _resolve_video_output_path, _validate_output_path +from trackers.io.video import _DEFAULT_OUTPUT_FPS, _DisplayWindow, _VideoOutput +from trackers.utils.device import _best_device + +if TYPE_CHECKING: + from inference_models import AnyModel + +DEFAULT_MODEL = "rfdetr-nano" +DEFAULT_TRACKER = "bytetrack" +DEFAULT_CONFIDENCE = 0.5 +DEFAULT_DEVICE = "auto" + +COLOR_PALETTE = sv.ColorPalette.from_hex( + [ + "#ffff00", + "#ff9b00", + "#ff8080", + "#ff66b2", + "#ff66ff", + "#b266ff", + "#9999ff", + "#3399ff", + "#66ffff", + "#33ff99", + "#66ff66", + "#99ff00", + ] +) + + +def _make_tracker_param_callback(param_name: str) -> click.types.FuncParamType: + """Return a click callback that stores the value in ctx.obj under param_name. + + Args: + param_name: Key to store under in ctx.obj dict. + + Returns: + Click callback function. + + Examples: + >>> cb = _make_tracker_param_callback("min_score") + >>> callable(cb) + True + """ + + def _cb(ctx: click.Context, param: click.Parameter, value: object) -> object: + d = ctx.ensure_object(dict) + d[param_name] = value + return value + + return _cb + + +@click.command("track") +@click.option( + "--source", default=None, metavar="PATH", help="Video file, webcam index (0), RTSP URL, or image directory." +) +@click.option( + "--model", + default=None, + metavar="ID", + help=( + f"Model ID for detection. Pretrained: rfdetr-nano, rfdetr-base, etc." + f" Custom: workspace/project/version. Default: {DEFAULT_MODEL}" + ), +) +@click.option( + "--detections", + type=click.Path(path_type=Path), + default=None, + metavar="PATH", + help="Load pre-computed detections from MOT format file.", +) +@click.option( + "--model.confidence", + "model_confidence", + type=float, + default=DEFAULT_CONFIDENCE, + metavar="FLOAT", + help=f"Detection confidence threshold. Default: {DEFAULT_CONFIDENCE}", +) +@click.option( + "--model.device", + "model_device", + default=DEFAULT_DEVICE, + metavar="DEVICE", + help=f"Device: auto, cpu, cuda, cuda:0, mps. Default: {DEFAULT_DEVICE}", +) +@click.option( + "--model.api_key", "model_api_key", default=None, metavar="KEY", help="Roboflow API key for custom models." +) +@click.option( + "--classes", + default=None, + metavar="NAMES_OR_IDS", + help="Filter by class names or IDs (comma-separated, e.g., person,car).", +) +@click.option( + "--track_ids", default=None, metavar="IDS", help="Filter output by track IDs (comma-separated, e.g., 1,3,5)" +) +@click.option( + "--tracker", + default=DEFAULT_TRACKER, + metavar="ID", + help=f"Tracking algorithm. Default: {DEFAULT_TRACKER}", +) +@click.option( + "-o", + "--output", + "output", + type=click.Path(path_type=Path), + default=None, + metavar="PATH", + help="Output video file path.", +) +@click.option( + "--mot-output", + "mot_output", + type=click.Path(path_type=Path), + default=None, + metavar="PATH", + help="Output MOT format file path.", +) +@click.option("--overwrite", is_flag=True, help="Overwrite existing output files.") +@click.option("--display", is_flag=True, help="Show preview window.") +@click.option("--show-boxes/--no-boxes", "show_boxes", default=True, help="Draw bounding boxes. Default: True") +@click.option("--show-masks", "show_masks", is_flag=True, help="Draw segmentation masks (seg models only).") +@click.option("--show-labels", "show_labels", is_flag=True, help="Show class labels.") +@click.option("--show-ids/--no-ids", "show_ids", default=True, help="Show track IDs. Default: True") +@click.option("--show-confidence", "show_confidence", is_flag=True, help="Show confidence scores.") +@click.option("--show-trajectories", "show_trajectories", is_flag=True, help="Draw track trajectories.") +@click.pass_context +def track_command( + ctx: click.Context, + source: str | None, + model: str | None, + detections: Path | None, + model_confidence: float, + model_device: str, + model_api_key: str | None, + classes: str | None, + track_ids: str | None, + tracker: str, + output: Path | None, + mot_output: Path | None, + overwrite: bool, + display: bool, + show_boxes: bool, + show_masks: bool, + show_labels: bool, + show_ids: bool, + show_confidence: bool, + show_trajectories: bool, +) -> None: + """Track objects in video using detection and tracking.""" + needs_frames = output or display + + if source is None and not detections: + raise click.UsageError("--source is required when not using --detections.") + + if needs_frames and source is None: + raise click.UsageError("--source is required when using --output or --display.") + + if model is not None and detections is not None: + raise click.UsageError("--model and --detections are mutually exclusive.") + + if model is None and detections is None: + model = DEFAULT_MODEL + + if output: + _validate_output_path(_resolve_video_output_path(output), overwrite=overwrite) + if mot_output: + _validate_output_path(mot_output, overwrite=overwrite) + + if detections: + inference_model = None + detections_data = load_mot_file(detections) + class_names: list[str] = [] + else: + inference_model = _init_model( + model if model is not None else DEFAULT_MODEL, + device=model_device, + api_key=model_api_key, + ) + detections_data = None + class_names = getattr(inference_model, "class_names", []) + + class_filter = _resolve_class_filter(classes, class_names) + track_id_filter = _resolve_track_id_filter(track_ids) + + tracker_params = _extract_tracker_params_from_ctx(tracker, ctx) + mot_tracker = _init_tracker(tracker, **tracker_params) + + if source is not None: + rc = _run_with_source( + source=source, + mot_tracker=mot_tracker, + inference_model=inference_model, + detections_data=detections_data, + class_names=class_names, + class_filter=class_filter, + track_id_filter=track_id_filter, + model_confidence=model_confidence, + output=output, + mot_output=mot_output, + display=display, + show_boxes=show_boxes, + show_masks=show_masks, + show_labels=show_labels, + show_ids=show_ids, + show_confidence=show_confidence, + show_trajectories=show_trajectories, + ) + else: + rc = _run_frameless( + detections_data=detections_data, + class_filter=class_filter, + track_id_filter=track_id_filter, + mot_tracker=mot_tracker, + mot_output=mot_output, + ) + + if rc != 0: + sys.exit(rc) + + +_CLI_PRIMITIVE_TYPES: frozenset[type] = frozenset({int, float, str, bool}) + + +def _add_tracker_params(cmd: click.Command) -> None: + """Append dynamic tracker parameters from registry to a click Command. + + Only parameters with primitive CLI-friendly types (int, float, str, bool) + are added; class-type defaults would cause click to invoke the class with + no arguments at parse time. + + Args: + cmd: Click command to extend with tracker-specific options. + """ + existing_names = {p.name for p in cmd.params} + for tracker_id in BaseTracker._registered_trackers(): + info = BaseTracker._lookup_tracker(tracker_id) + if info is None: + continue + for param_name, param_info in info.parameters.items(): + opt_name = f"--tracker.{param_name}" + if param_name in existing_names: + continue + if param_info.param_type not in _CLI_PRIMITIVE_TYPES: + continue + existing_names.add(param_name) + + if param_info.param_type is bool: + if param_info.default_value: + decls = [f"--tracker.{param_name}/--no-tracker.{param_name}"] + else: + decls = [opt_name] + opt = click.Option( + decls, + default=param_info.default_value, + expose_value=False, + callback=_make_tracker_param_callback(param_name), + help=f"{param_info.description} Default: {param_info.default_value}", + ) + else: + opt = click.Option( + [opt_name], + type=param_info.param_type, + default=param_info.default_value, + metavar=param_info.param_type.__name__.upper(), + expose_value=False, + callback=_make_tracker_param_callback(param_name), + help=f"{param_info.description} Default: {param_info.default_value}", + ) + cmd.params.append(opt) + + +_add_tracker_params(track_command) + + +def _extract_tracker_params_from_ctx(tracker_id: str, ctx: click.Context) -> dict[str, object]: + """Extract tracker parameters collected by dynamic option callbacks. + + Args: + tracker_id: Registered tracker name. + ctx: Click context carrying values set by dynamic callbacks in ctx.obj. + + Returns: + Dictionary of tracker parameters with non-None values. + """ + info = BaseTracker._lookup_tracker(tracker_id) + if info is None: + return {} + obj = ctx.obj or {} + return {name: obj[name] for name in info.parameters if name in obj and obj[name] is not None} + + +def _run_frameless( + detections_data: dict | None, + class_filter: list[int] | None, + track_id_filter: list[int] | None, + mot_tracker: BaseTracker, + mot_output: Path | None, +) -> int: + """Run tracking from pre-computed detections without frame source. + + Args: + detections_data: Pre-loaded MOT detections keyed by frame index. + class_filter: Class IDs to keep, or None for all. + track_id_filter: Track IDs to keep in output, or None for all. + mot_tracker: Initialised tracker instance. + mot_output: Output path for MOT-format file, or None. + + Returns: + Exit code: 0 on success, 1 on error. + """ + if detections_data is None or not detections_data: + click.echo("Error: No detections found in file.", err=True) + return 1 + + total_frames = max(detections_data.keys()) + source_info = _SourceInfo(source_type="video", total_frames=total_frames) + + try: + with ( + _MOTOutput(mot_output) as mot, + _TrackingProgress(source_info) as progress, + ): + interrupted = False + for frame_idx in range(1, total_frames + 1): + if frame_idx in detections_data: + frame_detections = _mot_frame_to_detections(detections_data[frame_idx]) + else: + frame_detections = sv.Detections.empty() + + if class_filter is not None and len(frame_detections) > 0: + mask = np.isin(frame_detections.class_id, class_filter) + frame_detections = frame_detections[mask] # type: ignore[assignment] + + tracked = mot_tracker.update(frame_detections) + + if track_id_filter is not None and len(tracked) > 0: + if tracked.tracker_id is not None: + mask = np.isin(tracked.tracker_id.astype(int), track_id_filter) + tracked = tracked[mask] + + mot.write(frame_idx, tracked) + progress.update() + + progress.complete(interrupted=interrupted) + + except KeyboardInterrupt: + pass + + return 0 + + +def _run_with_source( + source: str, + mot_tracker: BaseTracker, + inference_model: AnyModel | None, + detections_data: dict | None, + class_names: list[str], + class_filter: list[int] | None, + track_id_filter: list[int] | None, + model_confidence: float, + output: Path | None, + mot_output: Path | None, + display: bool, + show_boxes: bool, + show_masks: bool, + show_labels: bool, + show_ids: bool, + show_confidence: bool, + show_trajectories: bool, +) -> int: + """Run tracking with a frame source (video, webcam, images). + + Args: + source: Path, webcam index, or RTSP URL. + mot_tracker: Initialised tracker instance. + inference_model: Detection model, or None when using pre-computed detections. + detections_data: Pre-loaded MOT detections, or None when using model. + class_names: Class names for label display. + class_filter: Class IDs to keep, or None for all. + track_id_filter: Track IDs to keep in output, or None for all. + model_confidence: Confidence threshold for detection filtering. + output: Output video path, or None. + mot_output: Output MOT-format path, or None. + display: Whether to show live preview window. + show_boxes: Draw bounding boxes. + show_masks: Draw segmentation masks. + show_labels: Show class labels on tracks. + show_ids: Show track IDs on tracks. + show_confidence: Show detection confidence scores. + show_trajectories: Draw track trajectories. + + Returns: + Exit code: 0 on success. + """ + frame_gen = frames_from_source(source) + source_info = _classify_source(source) + + annotators, label_annotator = _init_annotators( + show_boxes=show_boxes, + show_masks=show_masks, + show_labels=show_labels, + show_ids=show_ids, + show_confidence=show_confidence, + ) + trace_annotator = None + if show_trajectories: + trace_annotator = sv.TraceAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.TRACK) + + display_ctx = _DisplayWindow() if display else nullcontext() + + try: + with ( + _VideoOutput(output, fps=source_info.fps or _DEFAULT_OUTPUT_FPS) as video, + _MOTOutput(mot_output) as mot, + display_ctx as disp, + _TrackingProgress(source_info) as progress, + ): + interrupted = False + for frame_idx, frame in frame_gen: + if inference_model is not None: + frame_detections = _run_model(inference_model, frame, model_confidence) + elif detections_data is not None and frame_idx in detections_data: + frame_detections = _mot_frame_to_detections(detections_data[frame_idx]) + else: + frame_detections = sv.Detections.empty() + + if class_filter is not None and len(frame_detections) > 0: + mask = np.isin(frame_detections.class_id, class_filter) + frame_detections = frame_detections[mask] # type: ignore[assignment] + + tracked = mot_tracker.update(frame_detections, frame) + + if track_id_filter is not None and len(tracked) > 0: + if tracked.tracker_id is not None: + mask = np.isin(tracked.tracker_id.astype(int), track_id_filter) + tracked = tracked[mask] + + mot.write(frame_idx, tracked) + progress.update() + + if display or output: + annotated = frame.copy() + if trace_annotator is not None: + annotated = trace_annotator.annotate(annotated, tracked) + for annotator in annotators: + annotated = annotator.annotate(annotated, tracked) + if label_annotator is not None: + labeled = tracked[tracked.tracker_id != -1] + labels = _format_labels( + labeled, + class_names, + show_ids=show_ids, + show_labels=show_labels, + show_confidence=show_confidence, + ) + annotated = label_annotator.annotate(annotated, labeled, labels) + + video.write(annotated) + + if disp is not None: + disp.show(annotated) + if disp.quit_requested: + interrupted = True + break + + progress.complete(interrupted=interrupted) + + except KeyboardInterrupt: + pass + + return 0 + + +def _resolve_track_id_filter(track_ids_arg: str | None) -> list[int] | None: + """Resolve a comma-separated ``--track-ids`` value to a list of integer IDs. + + Args: + track_ids_arg: Raw ``--track-ids`` string (e.g. ``"1,3,5"``). ``None`` + means no filter. + + Returns: + List of integer track IDs, or ``None`` when no valid filter remains. + + Examples: + >>> _resolve_track_id_filter(None) is None + True + >>> _resolve_track_id_filter("1,3,5") + [1, 3, 5] + """ + if not track_ids_arg: + return None + + track_ids: list[int] = [] + for token in track_ids_arg.split(","): + token = token.strip() + try: + track_ids.append(int(token)) + except ValueError: + click.echo(f"Warning: '{token}' is not a valid track ID, skipping.", err=True) + return track_ids if track_ids else None + + +def _resolve_class_filter( + classes_arg: str | None, + class_names: list[str], +) -> list[int] | None: + """Resolve a comma-separated ``--classes`` value to a list of integer IDs. + + Each token is checked independently: if it parses as an ``int`` it is used + directly as a class ID; otherwise it is looked up by name in *class_names*. + Unknown names are printed as warnings and skipped. + + Args: + classes_arg: Raw ``--classes`` string (e.g. ``"person,car"`` or + ``"0,2"`` or ``"person,2"``). ``None`` means no filter. + class_names: Ordered list of class names where the index equals the + class ID (as provided by the model). + + Returns: + List of integer class IDs, or ``None`` when no valid filter remains. + + Examples: + >>> _resolve_class_filter(None, ["person", "car"]) is None + True + >>> _resolve_class_filter("person,car", ["person", "car"]) + [0, 1] + """ + if not classes_arg: + return None + + requested = [token.strip() for token in classes_arg.split(",")] + name_to_id = {name: i for i, name in enumerate(class_names)} + class_filter: list[int] = [] + for token in requested: + try: + class_filter.append(int(token)) + except ValueError: + if token in name_to_id: + class_filter.append(name_to_id[token]) + else: + click.echo(f"Warning: class '{token}' not found in model class list, skipping.", err=True) + return class_filter if class_filter else None + + +def _init_model( + model_id: str, + *, + device: str = DEFAULT_DEVICE, + api_key: str | None = None, +) -> AnyModel: + """Load detection model via inference-models. + + Args: + model_id: Model identifier (e.g., 'rfdetr-nano' or 'workspace/project/version'). + device: Device to load model on ('auto', 'cpu', 'cuda', 'mps'). + api_key: Roboflow API key for custom models. + + Returns: + Loaded model instance. + """ + try: + from inference_models import AutoModel + except ImportError as e: + raise click.ClickException( + "inference-models is required for model-based detection.\nInstall with: pip install 'trackers[detection]'" + ) from e + + resolved_device = _best_device() if device == DEFAULT_DEVICE else device + + return AutoModel.from_pretrained(model_id, api_key=api_key, device=resolved_device) + + +def _run_model(model: AnyModel, frame: np.ndarray, confidence: float) -> sv.Detections: + """Run model inference and return sv.Detections. + + Args: + model: Loaded detection model. + frame: BGR image array. + confidence: Minimum confidence threshold. + + Returns: + Filtered detections. + """ + predictions = model(frame) + if not predictions: + return sv.Detections.empty() + + detections = predictions[0].to_supervision() + + if len(detections) > 0 and detections.confidence is not None: + mask = detections.confidence >= confidence + detections = detections[mask] + + return detections + + +def _init_tracker(tracker_id: str, **kwargs: object) -> BaseTracker: + """Create tracker instance from registry. + + Args: + tracker_id: Registered tracker name (e.g., 'bytetrack', 'sort'). + **kwargs: Tracker-specific parameters. + + Returns: + Initialized tracker instance. + + Raises: + ValueError: If tracker_id is not registered. + """ + info = BaseTracker._lookup_tracker(tracker_id) + if info is None: + available = ", ".join(BaseTracker._registered_trackers()) + raise ValueError(f"Unknown tracker: '{tracker_id}'. Available: {available}") + + return info.tracker_class(**kwargs) + + +def _init_annotators( + show_boxes: bool = False, + show_masks: bool = False, + show_labels: bool = False, + show_ids: bool = False, + show_confidence: bool = False, +) -> tuple[list, sv.LabelAnnotator | None]: + """Initialize supervision annotators based on display options. + + Args: + show_boxes: Create BoxAnnotator. + show_masks: Create MaskAnnotator. + show_labels: Include class labels (triggers LabelAnnotator). + show_ids: Include track IDs (triggers LabelAnnotator). + show_confidence: Include confidence scores (triggers LabelAnnotator). + + Returns: + Tuple of (annotators list, label_annotator or None). + Label annotator is separate because it needs custom labels per frame. + """ + annotators: list = [] + label_annotator: sv.LabelAnnotator | None = None + + if show_boxes: + annotators.append(sv.BoxAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.TRACK)) + + if show_masks: + annotators.append(sv.MaskAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.TRACK)) + + if show_labels or show_ids or show_confidence: + label_annotator = sv.LabelAnnotator( + color=COLOR_PALETTE, + text_color=sv.Color.BLACK, + text_position=sv.Position.TOP_LEFT, + color_lookup=sv.ColorLookup.TRACK, + ) + + return annotators, label_annotator + + +def _format_labels( + detections: sv.Detections, + class_names: list[str], + *, + show_ids: bool = False, + show_labels: bool = False, + show_confidence: bool = False, +) -> list[str]: + """Generate label strings for each detection. + + Args: + detections: Detections to generate labels for. + class_names: List of class names for lookup. + show_ids: Include tracker IDs in labels. + show_labels: Include class names in labels. + show_confidence: Include confidence scores in labels. + + Returns: + List of label strings, one per detection. + + Examples: + >>> import numpy as np + >>> import supervision as sv + >>> dets = sv.Detections(xyxy=np.array([[0., 0., 1., 1.]])) + >>> _format_labels(dets, []) + [''] + """ + labels = [] + + for i in range(len(detections)): + parts = [] + + if show_ids and detections.tracker_id is not None: + parts.append(f"#{int(detections.tracker_id[i])}") + + if show_labels and detections.class_id is not None: + class_id = int(detections.class_id[i]) + if class_names and 0 <= class_id < len(class_names): + parts.append(class_names[class_id]) + else: + parts.append(str(class_id)) + + if show_confidence and detections.confidence is not None: + parts.append(f"{detections.confidence[i]:.2f}") + + labels.append(" ".join(parts)) + + return labels diff --git a/src/trackers/cli/tune.py b/src/trackers/cli/tune.py new file mode 100644 index 00000000..c32fd907 --- /dev/null +++ b/src/trackers/cli/tune.py @@ -0,0 +1,145 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +from __future__ import annotations + +import json +from pathlib import Path + +import click + +from trackers.cli._options import metrics_option, output_option, seqmap_option, threshold_option + + +@click.command("tune") +@click.option( + "--tracker", "tracker_id", required=True, metavar="ID", help="Tracker ID to tune (e.g. bytetrack, sort, ocsort)." +) +@click.option( + "--gt-dir", + type=click.Path(path_type=Path), + required=True, + metavar="DIR", + help="Directory containing ground-truth MOT files.", +) +@click.option( + "--detections-dir", + type=click.Path(path_type=Path), + required=True, + metavar="DIR", + help="Directory containing pre-computed detection files in MOT flat format (one {seq}.txt per sequence).", +) +@click.option( + "--objective", + default="HOTA", + type=click.Choice(["MOTA", "HOTA", "IDF1"]), + help="Scalar metric to maximise. Default: HOTA.", +) +@click.option( + "--n-trials", "n_trials", type=int, default=100, metavar="N", help="Number of Optuna trials to run. Default: 100." +) +@metrics_option +@threshold_option +@seqmap_option +@output_option("Output file for best parameters (JSON format).") +def tune_command( + tracker_id: str, + gt_dir: Path, + detections_dir: Path, + objective: str, + n_trials: int, + metrics: tuple[str, ...], + threshold: float, + seqmap: Path | None, + output: Path | None, +) -> None: + """Tune tracker hyperparameters via Optuna.""" + rc = tune( + tracker=tracker_id, + gt_dir=gt_dir, + detections_dir=detections_dir, + objective=objective, + n_trials=n_trials, + metrics=list(metrics), + threshold=threshold, + seqmap=seqmap, + output=output, + ) + if rc != 0: + raise click.exceptions.Exit(rc) + + +def tune( + tracker: str, + gt_dir: Path, + detections_dir: Path, + objective: str = "HOTA", + n_trials: int = 100, + metrics: list[str] | None = None, + threshold: float = 0.5, + seqmap: Path | None = None, + output: Path | None = None, +) -> int: + """Tune tracker hyperparameters using Optuna. + + Args: + tracker: Tracker ID to tune (e.g. bytetrack, sort). + gt_dir: Directory of ground-truth MOT files. + detections_dir: Directory of pre-computed detection files in MOT flat + format (one {seq}.txt per sequence). + objective: Scalar metric to maximise. Options: MOTA, HOTA, IDF1. + n_trials: Number of Optuna trials to run. + metrics: Metric families to compute. Options: CLEAR, HOTA, Identity. + Default: CLEAR. + threshold: IoU threshold for CLEAR and Identity matching. + seqmap: Sequence map file listing sequences to evaluate. + output: Output file path for best parameters (JSON format). + + Returns: + Exit code: 0 on success, 1 on error. + """ + if metrics is None: + metrics = ["CLEAR"] + + from trackers.tune import Tuner + + try: + tuner = Tuner( + tracker_id=tracker, + gt_dir=gt_dir, + detections_dir=detections_dir, + metrics=metrics, + objective=objective, + n_trials=n_trials, + threshold=threshold, + seqmap=seqmap, + ) + except (ValueError, ImportError, FileNotFoundError) as e: + click.echo(str(e), err=True) + return 1 + + try: + best_params = tuner.run() + except Exception as e: + click.echo(f"Error during tuning: {e}", err=True) + return 1 + + click.echo(f"\nBest parameters for {tracker}:") + for name, value in best_params.items(): + click.echo(f" {name}: {value}") + if tuner.study is not None: + click.echo(f"\nBest {objective}: {tuner.study.best_value:.4f}") + + if output: + try: + output.parent.mkdir(parents=True, exist_ok=True) + output.write_text(json.dumps(best_params, indent=2)) + except OSError as e: + click.echo(f"Error writing output: {e}", err=True) + return 1 + click.echo(f"\nResults saved to: {output}") + + return 0 diff --git a/src/trackers/scripts/__main__.py b/src/trackers/scripts/__main__.py index 0993f8c7..799dc482 100644 --- a/src/trackers/scripts/__main__.py +++ b/src/trackers/scripts/__main__.py @@ -1,70 +1,11 @@ -#!/usr/bin/env python # ------------------------------------------------------------------------ # Trackers # Copyright (c) 2026 Roboflow. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ -from __future__ import annotations - -import argparse -import sys -import warnings - - -def main() -> int: - """Main entry point for the trackers CLI.""" - # Beta warning - warnings.warn( - "The trackers CLI is in beta. APIs may change in future releases.", - UserWarning, - stacklevel=2, - ) - - parser = argparse.ArgumentParser( - prog="trackers", - description="Command-line tools for multi-object tracking.", - epilog="For more information, visit: https://github.com/roboflow/trackers", - ) - parser.add_argument( - "--version", - action="store_true", - help="Show version and exit.", - ) - - subparsers = parser.add_subparsers( - dest="command", - title="commands", - description="Available commands:", - ) - - # Import and register subcommands - from trackers.scripts.download import add_download_subparser - from trackers.scripts.eval import add_eval_subparser - from trackers.scripts.track import add_track_subparser - from trackers.scripts.tune import add_tune_subparser - - add_download_subparser(subparsers) - add_eval_subparser(subparsers) - add_track_subparser(subparsers) - add_tune_subparser(subparsers) - - # Parse arguments - args = parser.parse_args() - - if args.version: - from importlib.metadata import version - - print(f"trackers {version('trackers')}") - return 0 - - if args.command is None: - parser.print_help() - return 0 - - # Execute the command - return args.func(args) - +# Backward-compatibility shim — trackers.scripts is deprecated; use trackers.cli +from trackers.cli.__main__ import main if __name__ == "__main__": - sys.exit(main()) + main() diff --git a/src/trackers/scripts/download.py b/src/trackers/scripts/download.py index de8e461f..a8417089 100644 --- a/src/trackers/scripts/download.py +++ b/src/trackers/scripts/download.py @@ -1,109 +1,8 @@ -#!/usr/bin/env python # ------------------------------------------------------------------------ # Trackers # Copyright (c) 2026 Roboflow. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ -from __future__ import annotations - -import argparse -import sys - -from rich.console import Console -from rich.panel import Panel - -from trackers.datasets.download import _DEFAULT_CACHE_DIR, _DEFAULT_OUTPUT_DIR -from trackers.datasets.manifest import _DATASETS - - -def add_download_subparser( - subparsers: argparse._SubParsersAction, -) -> None: - """Add the download subcommand to the argument parser.""" - parser = subparsers.add_parser( - "download", - help="Download benchmark tracking datasets.", - description="Download tracking datasets from the official trackers bucket.", - ) - - parser.add_argument( - "--list", - action="store_true", - help="List available datasets, splits, and asset types.", - ) - parser.add_argument( - "dataset", - nargs="?", - help="Dataset name (e.g. mot17, sportsmot).", - ) - parser.add_argument( - "--split", - help="Comma-separated splits to download (e.g. train,val,test). " - "If omitted, all available splits are downloaded.", - ) - parser.add_argument( - "--asset", - help="Comma-separated assets to download: annotations,frames,detections. " - "If omitted, all available assets are downloaded.", - ) - parser.add_argument( - "-o", - "--output", - default=_DEFAULT_OUTPUT_DIR, - help="Output directory (default: current directory).", - ) - parser.add_argument( - "--cache-dir", - default=_DEFAULT_CACHE_DIR, - help="Cache directory for downloaded ZIPs (default: ~/.cache/trackers).", - ) - - parser.set_defaults(func=_run_download) - - -def _run_download(args: argparse.Namespace) -> int: - """Execute the download subcommand.""" - if args.list: - _print_available() - return 0 - - if not args.dataset: - print("Please specify a dataset name or use --list.", file=sys.stderr) - return 1 - - from trackers.datasets.download import download_dataset - - split_list = [s.strip() for s in args.split.split(",")] if args.split else None - asset_list = [a.strip() for a in args.asset.split(",")] if args.asset else None - - try: - download_dataset( - dataset=args.dataset, - split=split_list, - asset=asset_list, - output=args.output, - cache_dir=args.cache_dir, - ) - except Exception as e: - print(f"Error: {e}", file=sys.stderr) - return 1 - - return 0 - - -def _print_available() -> None: - """Print available datasets, splits, and asset types.""" - console = Console() - for name, dataset_info in _DATASETS.items(): - description = dataset_info.get("description", "") - splits_dict: dict[str, dict] = dataset_info.get("splits", {}) - - max_split_len = max(len(s) for s in splits_dict) if splits_dict else 0 - split_lines = [ - f"{split:<{max_split_len}} {', '.join(assets.keys())}" for split, assets in splits_dict.items() - ] - - body = f"{description}\n\n" + "\n".join(split_lines) - console.print(Panel(body, title=name.value, title_align="left")) - console.print() +# Backward-compatibility shim — trackers.scripts is deprecated; use trackers.cli +from trackers.cli.download import _print_available, download_command # noqa: F401 diff --git a/src/trackers/scripts/eval.py b/src/trackers/scripts/eval.py index 7bd25f21..848176d4 100644 --- a/src/trackers/scripts/eval.py +++ b/src/trackers/scripts/eval.py @@ -1,169 +1,8 @@ -#!/usr/bin/env python # ------------------------------------------------------------------------ # Trackers # Copyright (c) 2026 Roboflow. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ -from __future__ import annotations - -import argparse -import logging -import sys -from pathlib import Path - - -def add_eval_subparser(subparsers: argparse._SubParsersAction) -> None: - """Add the eval subcommand to the argument parser.""" - parser = subparsers.add_parser( - "eval", - help="Evaluate tracker predictions against ground truth.", - description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - - # Single sequence mode - single_group = parser.add_argument_group("single sequence evaluation") - single_group.add_argument( - "--gt", - type=Path, - metavar="PATH", - help="Path to ground truth file (MOT format).", - ) - single_group.add_argument( - "--tracker", - type=Path, - metavar="PATH", - help="Path to tracker predictions file (MOT format).", - ) - - # Benchmark mode - bench_group = parser.add_argument_group("benchmark evaluation") - bench_group.add_argument( - "--gt-dir", - type=Path, - metavar="DIR", - help="Directory containing ground truth files.", - ) - bench_group.add_argument( - "--tracker-dir", - type=Path, - metavar="DIR", - help="Directory containing tracker prediction files.", - ) - bench_group.add_argument( - "--seqmap", - type=Path, - metavar="PATH", - help="Sequence map file listing sequences to evaluate.", - ) - - # Common options - parser.add_argument( - "--metrics", - nargs="+", - default=["CLEAR"], - choices=["CLEAR", "HOTA", "Identity"], - help="Metrics to compute. Default: CLEAR. Options: CLEAR, HOTA, Identity", - ) - parser.add_argument( - "--threshold", - type=float, - default=0.5, - help="IoU threshold for CLEAR and Identity matching. Default: 0.5", - ) - parser.add_argument( - "--columns", - nargs="+", - default=None, - metavar="COL", - help=( - "Metric columns to display. Default: auto-selected based on metrics. " - "CLEAR: MOTA, MOTP, MODA, CLR_Re, CLR_Pr, MTR, PTR, MLR, sMOTA, " - "CLR_TP, CLR_FN, CLR_FP, IDSW, MT, PT, ML, Frag. " - "HOTA: HOTA, DetA, AssA, DetRe, DetPr, AssRe, AssPr, LocA. " - "Identity: IDF1, IDR, IDP, IDTP, IDFN, IDFP" - ), - ) - parser.add_argument( - "--output", - "-o", - type=Path, - metavar="PATH", - help="Output file for results (JSON format).", - ) - - parser.set_defaults(func=run_eval) - - -def run_eval(args: argparse.Namespace) -> int: - """Execute the eval command.""" - # Configure logging to show detection info - logging.basicConfig( - level=logging.INFO, - format="%(message)s", - handlers=[logging.StreamHandler(sys.stderr)], - ) - - # Validate arguments - single_mode = args.gt is not None and args.tracker is not None - benchmark_mode = args.gt_dir is not None and args.tracker_dir is not None - - if not single_mode and not benchmark_mode: - print( - "Error: Must specify either --gt/--tracker or --gt-dir/--tracker-dir", - file=sys.stderr, - ) - return 1 - - if single_mode and benchmark_mode: - print( - "Error: Cannot use both single sequence and benchmark mode", - file=sys.stderr, - ) - return 1 - - # Columns: None means auto-select based on available metrics - columns = args.columns - - # Import evaluation functions - from trackers.eval import evaluate_mot_sequence, evaluate_mot_sequences - - try: - if single_mode: - seq_result = evaluate_mot_sequence( - gt_path=args.gt, - tracker_path=args.tracker, - metrics=args.metrics, - threshold=args.threshold, - ) - print(seq_result.table(columns=columns)) - - # Save results if output specified - if args.output: - args.output.parent.mkdir(parents=True, exist_ok=True) - args.output.write_text(seq_result.json()) - print(f"\nResults saved to: {args.output}") - else: - bench_result = evaluate_mot_sequences( - gt_dir=args.gt_dir, - tracker_dir=args.tracker_dir, - seqmap=args.seqmap, - metrics=args.metrics, - threshold=args.threshold, - ) - print(bench_result.table(columns=columns)) - - # Save results if output specified - if args.output: - bench_result.save(args.output) - print(f"\nResults saved to: {args.output}") - - except FileNotFoundError as e: - print(f"Error: {e}", file=sys.stderr) - return 1 - except ValueError as e: - print(f"Error: {e}", file=sys.stderr) - return 1 - - return 0 +# Backward-compatibility shim — trackers.scripts is deprecated; use trackers.cli +from trackers.cli.eval import eval_command # noqa: F401 diff --git a/src/trackers/scripts/progress.py b/src/trackers/scripts/progress.py index 67de2608..17804d7f 100644 --- a/src/trackers/scripts/progress.py +++ b/src/trackers/scripts/progress.py @@ -1,232 +1,13 @@ -#!/usr/bin/env python # ------------------------------------------------------------------------ # Trackers # Copyright (c) 2026 Roboflow. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ -from __future__ import annotations - -import itertools -import time -from dataclasses import dataclass -from pathlib import Path -from typing import Literal - -import cv2 -from rich.console import Console -from rich.live import Live -from rich.text import Text - -from trackers.io.video import IMAGE_EXTENSIONS - -_SPINNER_FRAMES = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏" -_STREAM_PREFIXES = ("rtsp://", "http://", "https://") -_ICON_OK = "✓" -_ICON_FAIL = "✗" - - -@dataclass -class _SourceInfo: - """Metadata about a frame source used to drive progress display. - - Attributes: - source_type: Kind of source (`video`, `image_dir`, `webcam`, - `stream`). - total_frames: Total frame count when known, `None` for unbounded - sources such as webcams and network streams. - fps: Source frame-rate when known, `None` otherwise. - """ - - source_type: Literal["video", "image_dir", "webcam", "stream"] - total_frames: int | None = None - fps: float | None = None - - -def _classify_source(source: str | Path | int) -> _SourceInfo: - """Classify a frame source and extract metadata. - - The function inspects *source* without consuming any frames so it can be - called before the main processing loop. - - Args: - source: The same value accepted by `frames_from_source`. - - Returns: - A `_SourceInfo` describing the source. - """ - if isinstance(source, int) or (isinstance(source, str) and source.isdigit()): - return _SourceInfo(source_type="webcam") - - source_str = str(source) - - if any(source_str.lower().startswith(p) for p in _STREAM_PREFIXES): - return _SourceInfo(source_type="stream") - - path = Path(source_str) - if path.is_dir(): - count = sum(1 for p in path.iterdir() if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS) - return _SourceInfo( - source_type="image_dir", - total_frames=count if count > 0 else None, - ) - - cap = cv2.VideoCapture(source_str) - if not cap.isOpened(): - # Cannot open; still classify as video - the real error will come - # from frames_from_source later. - return _SourceInfo(source_type="video") - - try: - total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - fps = cap.get(cv2.CAP_PROP_FPS) - return _SourceInfo( - source_type="video", - total_frames=total if total > 0 else None, - fps=fps if fps and fps > 0 else None, - ) - finally: - cap.release() - - -def _format_time(seconds: float) -> str: - """Format `seconds` as `H:MM:SS` or `M:SS`.""" - if seconds < 0: - return "--" - minutes, seconds_remainder = divmod(int(seconds), 60) - hours, minutes = divmod(minutes, 60) - if hours > 0: - return f"{hours}:{minutes:02d}:{seconds_remainder:02d}" - return f"{minutes}:{seconds_remainder:02d}" - - -class _TrackingProgress: - """Context-manager that renders a single live progress line. - - Args: - source_info: Source metadata returned by `_classify_source`. - console: Optional `Console` instance (useful for testing with a - `StringIO` file). - """ - - def __init__( - self, - source_info: _SourceInfo, - console: Console | None = None, - ) -> None: - self._source_info = source_info - self._console = console or Console() - self._frames_processed = 0 - self._start_time: float = 0.0 - self._spinner = itertools.cycle(_SPINNER_FRAMES) - self._live: Live | None = None - self._interrupted = False - - def update(self) -> None: - """Record one processed frame and refresh the display.""" - self._frames_processed += 1 - icon = next(self._spinner) - if self._live is not None: - self._live.update(self._build_line(icon)) - - def complete(self, *, interrupted: bool = False) -> None: - """Signal that the processing loop has ended. - - Must be called before leaving the `with` block so that `__exit__` - can render the correct final state. - - Args: - interrupted: `True` when the loop was terminated early (e.g. - display-quit). - """ - self._interrupted = interrupted - - def __enter__(self) -> _TrackingProgress: - self._start_time = time.monotonic() - self._live = Live( - self._build_line("⠋"), - console=self._console, - refresh_per_second=12, - transient=True, - ) - self._live.__enter__() - return self - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: object, - ) -> None: - if self._live is not None: - self._live.__exit__(None, None, None) - - icon, suffix = self._resolve_final_state(exc_type) - final = self._build_line(icon, show_eta=False, suffix=suffix) - self._console.print(final) - - @property - def _is_bounded(self) -> bool: - """Whether the source has a known total frame count.""" - return self._source_info.total_frames is not None - - def _resolve_final_state(self, exc_type: type[BaseException] | None) -> tuple[str, str]: - """Return `(icon, suffix)` for the final printed line.""" - is_real_error = exc_type is not None and not issubclass(exc_type, KeyboardInterrupt) - - if is_real_error: - return (_ICON_FAIL, "(source lost)") - - was_stopped_early = exc_type is not None or self._interrupted - - if was_stopped_early and self._is_bounded: - return (_ICON_FAIL, "(interrupted)") - - return (_ICON_OK, "") - - def _build_line( - self, - icon: str, - *, - show_eta: bool = True, - suffix: str = "", - ) -> Text: - """Compose the single-line progress string.""" - elapsed = time.monotonic() - self._start_time - fps = self._frames_processed / elapsed if elapsed > 0 else 0.0 - total = self._source_info.total_frames - - if total is not None: - total_str = str(total) - frames_part = f"{self._frames_processed:>{len(total_str)}} / {total_str}" - else: - frames_part = f"{self._frames_processed} / --" - - if total is not None and total > 0: - percentage = self._frames_processed / total * 100 - percentage_part = f"{percentage:>3.0f}%" - else: - percentage_part = " --" - - fps_part = f"{fps:>.1f} fps" - elapsed_part = f"{_format_time(elapsed)} elapsed" - - parts = [ - f"{icon} Tracking", - f"{frames_part} frames", - percentage_part, - fps_part, - elapsed_part, - ] - - if show_eta: - if total is not None and fps > 0: - remaining = (total - self._frames_processed) / fps - parts.append(f"eta {_format_time(remaining)}") - else: - parts.append("eta --") - - if suffix: - parts.append(suffix) - - return Text(" ".join(parts)) +# Backward-compatibility shim — trackers.scripts is deprecated; use trackers.cli +from trackers.cli.progress import ( # noqa: F401 + _classify_source, + _format_time, + _SourceInfo, + _TrackingProgress, +) diff --git a/src/trackers/scripts/track.py b/src/trackers/scripts/track.py index 539a3a23..973e5588 100644 --- a/src/trackers/scripts/track.py +++ b/src/trackers/scripts/track.py @@ -1,738 +1,19 @@ -#!/usr/bin/env python # ------------------------------------------------------------------------ # Trackers # Copyright (c) 2026 Roboflow. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ -from __future__ import annotations - -import argparse -import sys -from contextlib import nullcontext -from pathlib import Path -from typing import TYPE_CHECKING - -import numpy as np -import supervision as sv - -from trackers import frames_from_source -from trackers.core.base import BaseTracker -from trackers.io.mot import _mot_frame_to_detections, _MOTOutput, load_mot_file -from trackers.io.paths import _resolve_video_output_path, _validate_output_path -from trackers.io.video import _DEFAULT_OUTPUT_FPS, _DisplayWindow, _VideoOutput -from trackers.scripts.progress import _classify_source, _SourceInfo, _TrackingProgress -from trackers.utils.device import _best_device - -if TYPE_CHECKING: - from inference_models import AnyModel - -# Defaults -DEFAULT_MODEL = "rfdetr-nano" -DEFAULT_TRACKER = "bytetrack" -DEFAULT_CONFIDENCE = 0.5 -DEFAULT_DEVICE = "auto" - -# Visualization -COLOR_PALETTE = sv.ColorPalette.from_hex( - [ - "#ffff00", - "#ff9b00", - "#ff8080", - "#ff66b2", - "#ff66ff", - "#b266ff", - "#9999ff", - "#3399ff", - "#66ffff", - "#33ff99", - "#66ff66", - "#99ff00", - ] +# Backward-compatibility shim — trackers.scripts is deprecated; use trackers.cli +from trackers.cli.track import ( # noqa: F401 + _format_labels, + _init_annotators, + _init_model, + _init_tracker, + _resolve_class_filter, + _resolve_track_id_filter, + _run_frameless, + _run_model, + _run_with_source, + track_command, ) - - -def add_track_subparser(subparsers: argparse._SubParsersAction) -> None: - """Add the track subcommand to the argument parser.""" - parser = subparsers.add_parser( - "track", - help="Track objects in video using detection and tracking.", - description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - - # Source options - source_group = parser.add_argument_group("source") - source_group.add_argument( - "--source", - type=str, - default=None, - metavar="PATH", - help="Video file, webcam index (0), RTSP URL, or image directory.", - ) - - # Detection options (mutually exclusive) - detection_group = parser.add_argument_group("detection") - det_mutex = detection_group.add_mutually_exclusive_group(required=False) - det_mutex.add_argument( - "--model", - type=str, - default=DEFAULT_MODEL, - metavar="ID", - help=( - "Model ID for detection. Pretrained: rfdetr-nano, rfdetr-base, etc. " - f"Custom: workspace/project/version. Default: {DEFAULT_MODEL}" - ), - ) - det_mutex.add_argument( - "--detections", - type=Path, - metavar="PATH", - help="Load pre-computed detections from MOT format file.", - ) - - # Model options - model_group = parser.add_argument_group("model options") - model_group.add_argument( - "--model.confidence", - type=float, - default=DEFAULT_CONFIDENCE, - dest="model_confidence", - metavar="FLOAT", - help=f"Detection confidence threshold. Default: {DEFAULT_CONFIDENCE}", - ) - model_group.add_argument( - "--model.device", - type=str, - default=DEFAULT_DEVICE, - dest="model_device", - metavar="DEVICE", - help=f"Device: auto, cpu, cuda, cuda:0, mps. Default: {DEFAULT_DEVICE}", - ) - model_group.add_argument( - "--model.api_key", - type=str, - default=None, - dest="model_api_key", - metavar="KEY", - help="Roboflow API key for custom models.", - ) - - # Filtering options - filter_group = parser.add_argument_group("filtering") - filter_group.add_argument( - "--classes", - type=str, - default=None, - metavar="NAMES_OR_IDS", - help="Filter by class names or IDs (comma-separated, e.g., person,car).", - ) - filter_group.add_argument( - "--track_ids", - type=str, - default=None, - metavar="IDS", - help="Filter output by track IDs (comma-separated, e.g., 1,3,5)", - ) - - # Tracker options - tracker_group = parser.add_argument_group("tracker options") - available_trackers = BaseTracker._registered_trackers() - tracker_group.add_argument( - "--tracker", - type=str, - default=DEFAULT_TRACKER, - choices=available_trackers if available_trackers else [DEFAULT_TRACKER, "sort"], - metavar="ID", - help=f"Tracking algorithm. Default: {DEFAULT_TRACKER}", - ) - - # Add dynamic tracker parameters - _add_tracker_params(tracker_group) - - # Output options - output_group = parser.add_argument_group("output") - output_group.add_argument( - "-o", - "--output", - type=Path, - default=None, - metavar="PATH", - help="Output video file path.", - ) - output_group.add_argument( - "--mot-output", - type=Path, - default=None, - dest="mot_output", - metavar="PATH", - help="Output MOT format file path.", - ) - output_group.add_argument( - "--overwrite", - action="store_true", - help="Overwrite existing output files.", - ) - - # Visualization options - vis_group = parser.add_argument_group("visualization") - vis_group.add_argument( - "--display", - action="store_true", - help="Show preview window.", - ) - vis_group.add_argument( - "--show-boxes", - action="store_true", - default=True, - dest="show_boxes", - help="Draw bounding boxes. Default: True", - ) - vis_group.add_argument( - "--no-boxes", - action="store_false", - dest="show_boxes", - help="Disable bounding boxes.", - ) - vis_group.add_argument( - "--show-masks", - action="store_true", - dest="show_masks", - help="Draw segmentation masks (seg models only).", - ) - vis_group.add_argument( - "--show-labels", - action="store_true", - dest="show_labels", - help="Show class labels.", - ) - vis_group.add_argument( - "--show-ids", - action="store_true", - default=True, - dest="show_ids", - help="Show track IDs. Default: True", - ) - vis_group.add_argument( - "--no-ids", - action="store_false", - dest="show_ids", - help="Disable track IDs.", - ) - vis_group.add_argument( - "--show-confidence", - action="store_true", - dest="show_confidence", - help="Show confidence scores.", - ) - vis_group.add_argument( - "--show-trajectories", - action="store_true", - dest="show_trajectories", - help="Draw track trajectories.", - ) - - parser.set_defaults(func=run_track) - - -def _add_tracker_params(group: argparse._ArgumentGroup) -> None: - """Add tracker-specific parameters from registry to argument group.""" - for tracker_id in BaseTracker._registered_trackers(): - info = BaseTracker._lookup_tracker(tracker_id) - if info is None: - continue - - for param_name, param_info in info.parameters.items(): - arg_name = f"--tracker.{param_name}" - dest_name = f"tracker_{param_name}" - - kwargs: dict = { - "dest": dest_name, - "default": param_info.default_value, - "help": f"{param_info.description} Default: {param_info.default_value}", - } - - if param_info.param_type is bool: - kwargs["action"] = "store_false" if param_info.default_value else "store_true" - else: - kwargs["type"] = param_info.param_type - kwargs["metavar"] = param_info.param_type.__name__.upper() - - try: - group.add_argument(arg_name, **kwargs) - except argparse.ArgumentError: - # Parameter already added by another tracker - pass - - -def run_track(args: argparse.Namespace) -> int: - """Execute the track command.""" - needs_frames = args.output or args.display - - if args.source is None and not args.detections: - print( - "Error: --source is required when not using --detections.", - file=sys.stderr, - ) - return 1 - - if needs_frames and args.source is None: - print( - "Error: --source is required when using --output or --display.", - file=sys.stderr, - ) - return 1 - - # Validate output paths - if args.output: - _validate_output_path(_resolve_video_output_path(args.output), overwrite=args.overwrite) - if args.mot_output: - _validate_output_path(args.mot_output, overwrite=args.overwrite) - - # Create detection source - if args.detections: - model = None - detections_data = load_mot_file(args.detections) - class_names: list[str] = [] - else: - model = _init_model( - args.model, - device=args.model_device, - api_key=args.model_api_key, - ) - detections_data = None - class_names = getattr(model, "class_names", []) - - # Resolve class filter (names and/or integer IDs) - class_filter = _resolve_class_filter(args.classes, class_names) - - track_id_filter = _resolve_track_id_filter(args.track_ids) - - # Create tracker - tracker_params = _extract_tracker_params(args.tracker, args) - tracker = _init_tracker(args.tracker, **tracker_params) - - if args.source is not None: - return _run_with_source( - args, - model, - detections_data, - class_names, - class_filter, - track_id_filter, - tracker, - ) - else: - return _run_frameless( - args, - detections_data, - class_filter, - track_id_filter, - tracker, - ) - - -def _run_frameless( - args: argparse.Namespace, - detections_data: dict | None, - class_filter: list[int] | None, - track_id_filter: list[int] | None, - tracker: BaseTracker, -) -> int: - """Run tracking from pre-computed detections without frame source.""" - if detections_data is None or not detections_data: - print("Error: No detections found in file.", file=sys.stderr) - return 1 - - total_frames = max(detections_data.keys()) - source_info = _SourceInfo(source_type="video", total_frames=total_frames) - - try: - with ( - _MOTOutput(args.mot_output) as mot, - _TrackingProgress(source_info) as progress, - ): - interrupted = False - for frame_idx in range(1, total_frames + 1): - if frame_idx in detections_data: - detections = _mot_frame_to_detections(detections_data[frame_idx]) - else: - detections = sv.Detections.empty() - - if class_filter is not None and len(detections) > 0: - mask = np.isin(detections.class_id, class_filter) - detections = detections[mask] # type: ignore[assignment] - - tracked = tracker.update(detections) - - if track_id_filter is not None and len(tracked) > 0: - if tracked.tracker_id is not None: - mask = np.isin(tracked.tracker_id.astype(int), track_id_filter) - tracked = tracked[mask] - - mot.write(frame_idx, tracked) - progress.update() - - progress.complete(interrupted=interrupted) - - except KeyboardInterrupt: - pass - - return 0 - - -def _run_with_source( - args: argparse.Namespace, - model, - detections_data: dict | None, - class_names: list[str], - class_filter: list[int] | None, - track_id_filter: list[int] | None, - tracker: BaseTracker, -) -> int: - """Run tracking with a frame source (video, webcam, images).""" - frame_gen = frames_from_source(args.source) - source_info = _classify_source(args.source) - - # Setup annotators - annotators, label_annotator = _init_annotators( - show_boxes=args.show_boxes, - show_masks=args.show_masks, - show_labels=args.show_labels, - show_ids=args.show_ids, - show_confidence=args.show_confidence, - ) - trace_annotator = None - if args.show_trajectories: - trace_annotator = sv.TraceAnnotator( - color=COLOR_PALETTE, - color_lookup=sv.ColorLookup.TRACK, - ) - - display_ctx = _DisplayWindow() if args.display else nullcontext() - - try: - with ( - _VideoOutput( - args.output, - fps=source_info.fps or _DEFAULT_OUTPUT_FPS, - ) as video, - _MOTOutput(args.mot_output) as mot, - display_ctx as display, - _TrackingProgress(source_info) as progress, - ): - interrupted = False - for frame_idx, frame in frame_gen: - # Get detections - if model is not None: - detections = _run_model(model, frame, args.model_confidence) - elif detections_data is not None and frame_idx in detections_data: - detections = _mot_frame_to_detections(detections_data[frame_idx]) - else: - detections = sv.Detections.empty() - - # Filter by class - if class_filter is not None and len(detections) > 0: - mask = np.isin(detections.class_id, class_filter) - detections = detections[mask] # type: ignore[assignment] - - # Run tracker - tracked = tracker.update(detections, frame) - - # Filter by track ID - if track_id_filter is not None and len(tracked) > 0: - if tracked.tracker_id is not None: - mask = np.isin(tracked.tracker_id.astype(int), track_id_filter) - tracked = tracked[mask] - - # Write MOT output - mot.write(frame_idx, tracked) - - progress.update() - - # Annotate and display/save frame - if args.display or args.output: - annotated = frame.copy() - if trace_annotator is not None: - annotated = trace_annotator.annotate(annotated, tracked) - for annotator in annotators: - annotated = annotator.annotate(annotated, tracked) - if label_annotator is not None: - labeled = tracked[tracked.tracker_id != -1] - labels = _format_labels( - labeled, - class_names, - show_ids=args.show_ids, - show_labels=args.show_labels, - show_confidence=args.show_confidence, - ) - annotated = label_annotator.annotate(annotated, labeled, labels) - - video.write(annotated) - - if display is not None: - display.show(annotated) - if display.quit_requested: - interrupted = True - break - - progress.complete(interrupted=interrupted) - - except KeyboardInterrupt: - pass - - return 0 - - -def _resolve_track_id_filter(track_ids_arg: str | None) -> list[int] | None: - """Resolve a comma-separated `--track-ids` value to a list of integer IDs. - - Args: - track_ids_arg: Raw `--track-ids` string (e.g. `"1,3,5"`). `None` - means no filter. - - Returns: - List of integer track IDs, or `None` when no valid filter remains. - """ - if not track_ids_arg: - return None - - track_ids: list[int] = [] - for token in track_ids_arg.split(","): - token = token.strip() - try: - track_ids.append(int(token)) - except ValueError: - print( - f"Warning: '{token}' is not a valid track ID, skipping.", - file=sys.stderr, - ) - return track_ids if track_ids else None - - -def _resolve_class_filter( - classes_arg: str | None, - class_names: list[str], -) -> list[int] | None: - """Resolve a comma-separated `--classes` value to a list of integer IDs. - - Each token is checked independently: if it parses as an `int` it is used - directly as a class ID; otherwise it is looked up by name in *class_names*. - Unknown names are printed as warnings and skipped. - - Args: - classes_arg: Raw `--classes` string (e.g. `"person,car"` or - `"0,2"` or `"person,2"`). `None` means no filter. - class_names: Ordered list of class names where the index equals the - class ID (as provided by the model). - - Returns: - List of integer class IDs, or `None` when no valid filter remains. - """ - if not classes_arg: - return None - - requested = [token.strip() for token in classes_arg.split(",")] - name_to_id = {name: i for i, name in enumerate(class_names)} - class_filter: list[int] = [] - for token in requested: - try: - class_filter.append(int(token)) - except ValueError: - if token in name_to_id: - class_filter.append(name_to_id[token]) - else: - print( - f"Warning: class '{token}' not found in model class list, skipping.", - file=sys.stderr, - ) - return class_filter if class_filter else None - - -def _init_model( - model_id: str, - *, - device: str = DEFAULT_DEVICE, - api_key: str | None = None, -) -> AnyModel: - """Load detection model via inference-models. - - Args: - model_id: Model identifier (e.g., 'rfdetr-nano' or 'workspace/project/version'). - device: Device to load model on ('auto', 'cpu', 'cuda', 'mps'). - api_key: Roboflow API key for custom models. - - Returns: - Loaded model instance. - """ - try: - from inference_models import AutoModel - except ImportError as e: - print( - "Error: inference-models is required for model-based detection.\n" - "Install with: pip install 'trackers[detection]'", - file=sys.stderr, - ) - raise SystemExit(1) from e - - resolved_device = _best_device() if device == DEFAULT_DEVICE else device - - return AutoModel.from_pretrained( - model_id, - api_key=api_key, - device=resolved_device, - ) - - -def _run_model(model: AnyModel, frame: np.ndarray, confidence: float) -> sv.Detections: - """Run model inference and return sv.Detections.""" - predictions = model(frame) - if not predictions: - return sv.Detections.empty() - - detections = predictions[0].to_supervision() - - # Filter by confidence - if len(detections) > 0 and detections.confidence is not None: - mask = detections.confidence >= confidence - detections = detections[mask] - - return detections - - -def _extract_tracker_params(tracker_id: str, args: argparse.Namespace) -> dict[str, object]: - """Extract tracker parameters from CLI args. - - Args: - tracker_id: Registered tracker name. - args: Parsed CLI arguments. - - Returns: - Dictionary of tracker parameters with non-None values. - """ - info = BaseTracker._lookup_tracker(tracker_id) - if info is None: - return {} - - params = {} - for param_name in info.parameters: - dest_name = f"tracker_{param_name}" - if hasattr(args, dest_name): - value = getattr(args, dest_name) - if value is not None: - params[param_name] = value - return params - - -def _init_tracker(tracker_id: str, **kwargs: object) -> BaseTracker: - """Create tracker instance from registry. - - Args: - tracker_id: Registered tracker name (e.g., 'bytetrack', 'sort'). - **kwargs: Tracker-specific parameters. - - Returns: - Initialized tracker instance. - - Raises: - ValueError: If tracker_id is not registered. - """ - info = BaseTracker._lookup_tracker(tracker_id) - if info is None: - available = ", ".join(BaseTracker._registered_trackers()) - raise ValueError(f"Unknown tracker: '{tracker_id}'. Available: {available}") - - return info.tracker_class(**kwargs) - - -def _init_annotators( - show_boxes: bool = False, - show_masks: bool = False, - show_labels: bool = False, - show_ids: bool = False, - show_confidence: bool = False, -) -> tuple[list, sv.LabelAnnotator | None]: - """Initialize supervision annotators based on display options. - - Args: - show_boxes: Create BoxAnnotator. - show_masks: Create MaskAnnotator. - show_labels: Include class labels (triggers LabelAnnotator). - show_ids: Include track IDs (triggers LabelAnnotator). - show_confidence: Include confidence scores (triggers LabelAnnotator). - - Returns: - Tuple of (annotators list, label_annotator or None). - Label annotator is separate because it needs custom labels per frame. - """ - annotators: list = [] - label_annotator: sv.LabelAnnotator | None = None - - if show_boxes: - annotators.append( - sv.BoxAnnotator( - color=COLOR_PALETTE, - color_lookup=sv.ColorLookup.TRACK, - ) - ) - - if show_masks: - annotators.append( - sv.MaskAnnotator( - color=COLOR_PALETTE, - color_lookup=sv.ColorLookup.TRACK, - ) - ) - - if show_labels or show_ids or show_confidence: - label_annotator = sv.LabelAnnotator( - color=COLOR_PALETTE, - text_color=sv.Color.BLACK, - text_position=sv.Position.TOP_LEFT, - color_lookup=sv.ColorLookup.TRACK, - ) - - return annotators, label_annotator - - -def _format_labels( - detections: sv.Detections, - class_names: list[str], - *, - show_ids: bool = False, - show_labels: bool = False, - show_confidence: bool = False, -) -> list[str]: - """Generate label strings for each detection. - - Args: - detections: Detections to generate labels for. - class_names: List of class names for lookup. - show_ids: Include tracker IDs in labels. - show_labels: Include class names in labels. - show_confidence: Include confidence scores in labels. - - Returns: - List of label strings, one per detection. - """ - labels = [] - - for i in range(len(detections)): - parts = [] - - if show_ids and detections.tracker_id is not None: - parts.append(f"#{int(detections.tracker_id[i])}") - - if show_labels and detections.class_id is not None: - class_id = int(detections.class_id[i]) - if class_names and 0 <= class_id < len(class_names): - parts.append(class_names[class_id]) - else: - parts.append(str(class_id)) - - if show_confidence and detections.confidence is not None: - parts.append(f"{detections.confidence[i]:.2f}") - - labels.append(" ".join(parts)) - - return labels diff --git a/src/trackers/scripts/tune.py b/src/trackers/scripts/tune.py index 0b889b45..2f3cba84 100644 --- a/src/trackers/scripts/tune.py +++ b/src/trackers/scripts/tune.py @@ -1,179 +1,8 @@ -#!/usr/bin/env python # ------------------------------------------------------------------------ # Trackers # Copyright (c) 2026 Roboflow. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ -from __future__ import annotations - -import argparse -import json -import sys -from pathlib import Path - - -def add_tune_subparser(subparsers: argparse._SubParsersAction) -> None: - """Add the tune subcommand to the argument parser.""" - parser = subparsers.add_parser( - "tune", - help="Tune tracker hyperparameters via Optuna.", - description=( - "Run Optuna-based hyperparameter optimisation for a registered " - "tracker using pre-computed detections and ground-truth MOT files." - ), - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - - parser.add_argument( - "--tracker", - required=True, - metavar="ID", - help="Tracker ID to tune (e.g. bytetrack, sort, ocsort).", - ) - parser.add_argument( - "--gt-dir", - type=Path, - required=True, - metavar="DIR", - help="Directory containing ground-truth MOT files.", - ) - parser.add_argument( - "--detections-dir", - type=Path, - required=True, - metavar="DIR", - help=("Directory containing pre-computed detection files in MOT flat format (one {seq}.txt per sequence)."), - ) - parser.add_argument( - "--objective", - default="HOTA", - choices=["MOTA", "HOTA", "IDF1"], - help="Scalar metric to maximise. Default: HOTA.", - ) - parser.add_argument( - "--n-trials", - type=int, - default=100, - metavar="N", - help="Number of Optuna trials to run. Default: 100.", - ) - parser.add_argument( - "--metrics", - nargs="+", - default=["CLEAR"], - choices=["CLEAR", "HOTA", "Identity"], - help=( - "Metric families to compute. Default: CLEAR. The family required " - "by --objective is added automatically if missing." - ), - ) - parser.add_argument( - "--threshold", - type=float, - default=0.5, - help="IoU threshold for CLEAR and Identity matching. Default: 0.5.", - ) - parser.add_argument( - "--seqmap", - type=Path, - metavar="PATH", - help="Sequence map file listing sequences to evaluate.", - ) - parser.add_argument( - "--output", - "-o", - type=Path, - metavar="PATH", - help="Output file for best parameters (JSON format).", - ) - - parser.set_defaults(func=run_tune) - - -def run_tune(args: argparse.Namespace) -> int: - """Execute the tune command.""" - return tune( - tracker=args.tracker, - gt_dir=args.gt_dir, - detections_dir=args.detections_dir, - objective=args.objective, - n_trials=args.n_trials, - metrics=args.metrics, - threshold=args.threshold, - seqmap=args.seqmap, - output=args.output, - ) - - -def tune( - tracker: str, - gt_dir: Path, - detections_dir: Path, - objective: str = "HOTA", - n_trials: int = 100, - metrics: list[str] | None = None, - threshold: float = 0.5, - seqmap: Path | None = None, - output: Path | None = None, -) -> int: - """Tune tracker hyperparameters using Optuna. - - Args: - tracker: Tracker ID to tune (e.g. bytetrack, sort). - gt_dir: Directory of ground-truth MOT files. - detections_dir: Directory of pre-computed detection files in MOT flat - format (one {seq}.txt per sequence). - objective: Scalar metric to maximise. Options: MOTA, HOTA, IDF1. - n_trials: Number of Optuna trials to run. - metrics: Metric families to compute. Options: CLEAR, HOTA, Identity. - Default: CLEAR. - threshold: IoU threshold for CLEAR and Identity matching. - seqmap: Sequence map file listing sequences to evaluate. - output: Output file path for best parameters (JSON format). - - Returns: - Exit code: 0 on success, 1 on error. - """ - if metrics is None: - metrics = ["CLEAR"] - - from trackers.tune import Tuner - - try: - tuner = Tuner( - tracker_id=tracker, - gt_dir=gt_dir, - detections_dir=detections_dir, - metrics=metrics, - objective=objective, - n_trials=n_trials, - threshold=threshold, - seqmap=seqmap, - ) - except (ValueError, ImportError, FileNotFoundError) as e: - print(str(e), file=sys.stderr) - return 1 - - try: - best_params = tuner.run() - except Exception as e: - print(f"Error during tuning: {e}", file=sys.stderr) - return 1 - - print(f"\nBest parameters for {tracker}:") - for name, value in best_params.items(): - print(f" {name}: {value}") - if tuner.study is not None: - print(f"\nBest {objective}: {tuner.study.best_value:.4f}") - - if output: - try: - output.parent.mkdir(parents=True, exist_ok=True) - output.write_text(json.dumps(best_params, indent=2)) - except OSError as e: - print(f"Error writing output: {e}", file=sys.stderr) - return 1 - print(f"\nResults saved to: {output}") - - return 0 +# Backward-compatibility shim — trackers.scripts is deprecated; use trackers.cli +from trackers.cli.tune import tune, tune_command # noqa: F401 diff --git a/tests/scripts/test_download.py b/tests/scripts/test_download.py index 94b3f573..4ead37c8 100644 --- a/tests/scripts/test_download.py +++ b/tests/scripts/test_download.py @@ -4,96 +4,59 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ +"""CLI-level tests for trackers/cli/download.py.""" + from __future__ import annotations -import argparse from unittest.mock import patch import pytest +from click.testing import CliRunner +from trackers.cli.__main__ import cli +from trackers.cli.download import _print_available from trackers.datasets.download import _DEFAULT_CACHE_DIR, _DEFAULT_OUTPUT_DIR -from trackers.scripts.download import ( - _print_available, - _run_download, - add_download_subparser, -) - - -def _parse_args(argv: list[str]) -> argparse.Namespace: - """Parse argv through a fresh download subparser and return the namespace.""" - parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers() - add_download_subparser(subparsers) - return parser.parse_args(argv) - - -class TestSubparserRegistration: - """Argument parsing and help strings.""" - - def test_list_flag(self) -> None: - """--list sets the flag to True.""" - args = _parse_args(["download", "--list"]) - assert args.list is True - - def test_list_flag_default_false(self) -> None: - """--list is False when omitted.""" - args = _parse_args(["download", "mot17"]) - assert args.list is False - def test_split_flag_accepts_comma_separated(self) -> None: - """--split accepts comma-separated values.""" - args = _parse_args(["download", "mot17", "--split", "train,val"]) - assert args.split == "train,val" - def test_asset_flag_accepts_comma_separated(self) -> None: - """--asset accepts comma-separated values.""" - args = _parse_args(["download", "mot17", "--asset", "frames,annotations"]) - assert args.asset == "frames,annotations" +class TestDownloadCommand: + """Argument parsing and routing for the download subcommand.""" - def test_output_directory_short_flag(self) -> None: - """-o sets the output directory.""" - args = _parse_args(["download", "mot17", "-o", "./datasets"]) - assert args.output == "./datasets" - - def test_cache_dir_flag(self) -> None: - """--cache-dir sets the cache directory.""" - args = _parse_args(["download", "mot17", "--cache-dir", "./cache"]) - assert args.cache_dir == "./cache" - - def test_dataset_positional(self) -> None: - """Dataset is captured as positional argument.""" - args = _parse_args(["download", "sportsmot"]) - assert args.dataset == "sportsmot" - - -class TestRunDownload: - """Execution of the download subcommand.""" - - def test_list_triggers_print(self) -> None: - """--list calls _print_available and returns 0.""" - args = _parse_args(["download", "--list"]) - - with patch("trackers.scripts.download._print_available") as mock_print: - rc = _run_download(args) - assert rc == 0 - mock_print.assert_called_once() + def test_list_flag_exits_zero(self) -> None: + """--list prints datasets and exits 0.""" + runner = CliRunner() + with patch("trackers.cli.download._print_available") as mock_print: + result = runner.invoke(cli, ["download", "--list"]) + assert result.exit_code == 0 + mock_print.assert_called_once() def test_list_takes_precedence_over_dataset(self) -> None: - """--list wins over dataset positional.""" - args = _parse_args(["download", "mot17", "--list"]) - - with patch("trackers.scripts.download._print_available") as mock_print: - rc = _run_download(args) - assert rc == 0 - mock_print.assert_called_once() - - def test_missing_dataset_exits_with_error(self, capsys: pytest.CaptureFixture[str]) -> None: - """No dataset and no --list prints error to stderr and returns 1.""" - args = _parse_args(["download"]) - rc = _run_download(args) - captured = capsys.readouterr() - assert rc == 1 - assert "Please specify a dataset" in captured.err + """--list wins over positional dataset argument.""" + runner = CliRunner() + with patch("trackers.cli.download._print_available") as mock_print: + result = runner.invoke(cli, ["download", "mot17", "--list"]) + assert result.exit_code == 0 + mock_print.assert_called_once() + + def test_missing_dataset_exits_nonzero(self) -> None: + """No dataset and no --list exits with non-zero code and error message.""" + runner = CliRunner() + result = runner.invoke(cli, ["download"]) + assert result.exit_code != 0 + assert "Please specify a dataset" in result.output + + def test_dataset_positional_accepted(self) -> None: + """Dataset positional argument is forwarded to download_dataset.""" + runner = CliRunner() + with patch("trackers.datasets.download.download_dataset") as mock_dl: + result = runner.invoke(cli, ["download", "mot17"]) + assert result.exit_code == 0 + mock_dl.assert_called_once_with( + dataset="mot17", + split=None, + asset=None, + output=_DEFAULT_OUTPUT_DIR, + cache_dir=_DEFAULT_CACHE_DIR, + ) @pytest.mark.parametrize( "split_arg,expected_splits", @@ -105,41 +68,17 @@ def test_missing_dataset_exits_with_error(self, capsys: pytest.CaptureFixture[st ) def test_split_comma_parsing(self, split_arg: str, expected_splits: list[str]) -> None: """--split values are split on commas and whitespace-stripped.""" - args = _parse_args(["download", "mot17", "--split", split_arg, "--asset", "annotations"]) - - with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) - assert rc == 0 - mock_dl.assert_called_once_with( - dataset="mot17", - split=expected_splits, - asset=["annotations"], - output=_DEFAULT_OUTPUT_DIR, - cache_dir=_DEFAULT_CACHE_DIR, - ) - - @pytest.mark.parametrize( - "split_arg,expected_splits", - [ - ("train,", ["train", ""]), - (",train", ["", "train"]), - ("train,,val", ["train", "", "val"]), - ], - ) - def test_split_comma_parsing_boundary(self, split_arg: str, expected_splits: list[str]) -> None: - """--split handles malformed comma inputs gracefully.""" - args = _parse_args(["download", "mot17", "--split", split_arg, "--asset", "annotations"]) - + runner = CliRunner() with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) - assert rc == 0 - mock_dl.assert_called_once_with( - dataset="mot17", - split=expected_splits, - asset=["annotations"], - output=_DEFAULT_OUTPUT_DIR, - cache_dir=_DEFAULT_CACHE_DIR, - ) + result = runner.invoke(cli, ["download", "mot17", "--split", split_arg, "--asset", "annotations"]) + assert result.exit_code == 0 + mock_dl.assert_called_once_with( + dataset="mot17", + split=expected_splits, + asset=["annotations"], + output=_DEFAULT_OUTPUT_DIR, + cache_dir=_DEFAULT_CACHE_DIR, + ) @pytest.mark.parametrize( "asset_arg,expected_assets", @@ -151,80 +90,86 @@ def test_split_comma_parsing_boundary(self, split_arg: str, expected_splits: lis ) def test_asset_comma_parsing(self, asset_arg: str, expected_assets: list[str]) -> None: """--asset values are split on commas and whitespace-stripped.""" - args = _parse_args(["download", "sportsmot", "--split", "train", "--asset", asset_arg]) - + runner = CliRunner() with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) - assert rc == 0 - mock_dl.assert_called_once_with( - dataset="sportsmot", - split=["train"], - asset=expected_assets, - output=_DEFAULT_OUTPUT_DIR, - cache_dir=_DEFAULT_CACHE_DIR, - ) + result = runner.invoke(cli, ["download", "sportsmot", "--split", "train", "--asset", asset_arg]) + assert result.exit_code == 0 + mock_dl.assert_called_once_with( + dataset="sportsmot", + split=["train"], + asset=expected_assets, + output=_DEFAULT_OUTPUT_DIR, + cache_dir=_DEFAULT_CACHE_DIR, + ) def test_none_splits_and_assets_when_omitted(self) -> None: """When --split and --asset are omitted, None is forwarded.""" - args = _parse_args(["download", "mot17"]) - + runner = CliRunner() with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) - assert rc == 0 - mock_dl.assert_called_once_with( - dataset="mot17", - split=None, - asset=None, - output=_DEFAULT_OUTPUT_DIR, - cache_dir=_DEFAULT_CACHE_DIR, - ) + result = runner.invoke(cli, ["download", "mot17"]) + assert result.exit_code == 0 + mock_dl.assert_called_once_with( + dataset="mot17", + split=None, + asset=None, + output=_DEFAULT_OUTPUT_DIR, + cache_dir=_DEFAULT_CACHE_DIR, + ) def test_output_directory_forwarded(self) -> None: """-o value is forwarded to download_dataset.""" - args = _parse_args(["download", "mot17", "-o", "/custom/path"]) - + runner = CliRunner() + with patch("trackers.datasets.download.download_dataset") as mock_dl: + result = runner.invoke(cli, ["download", "mot17", "-o", "/custom/path"]) + assert result.exit_code == 0 + mock_dl.assert_called_once_with( + dataset="mot17", + split=None, + asset=None, + output="/custom/path", + cache_dir=_DEFAULT_CACHE_DIR, + ) + + def test_cache_dir_forwarded(self) -> None: + """--cache-dir value is forwarded to download_dataset.""" + runner = CliRunner() with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) - assert rc == 0 - mock_dl.assert_called_once_with( - dataset="mot17", - split=None, - asset=None, - output="/custom/path", - cache_dir=_DEFAULT_CACHE_DIR, - ) - - def test_value_error_returns_exit_code(self) -> None: - """ValueError from download_dataset is caught and returns 1.""" - args = _parse_args(["download", "mot17"]) - - with patch( - "trackers.datasets.download.download_dataset", - side_effect=ValueError("bad dataset"), - ): - rc = _run_download(args) - assert rc == 1 + result = runner.invoke(cli, ["download", "mot17", "--cache-dir", "./cache"]) + assert result.exit_code == 0 + mock_dl.assert_called_once_with( + dataset="mot17", + split=None, + asset=None, + output=_DEFAULT_OUTPUT_DIR, + cache_dir="./cache", + ) + + def test_exception_from_download_exits_nonzero(self) -> None: + """Exception from download_dataset is caught and exits non-zero.""" + runner = CliRunner() + with patch("trackers.datasets.download.download_dataset", side_effect=ValueError("bad dataset")): + result = runner.invoke(cli, ["download", "mot17"]) + assert result.exit_code != 0 def test_split_with_spaces_stripped(self) -> None: """--split with spaces around commas strips whitespace.""" - args = _parse_args(["download", "mot17", "--split", "train , val", "--asset", "annotations"]) - + runner = CliRunner() with patch("trackers.datasets.download.download_dataset") as mock_dl: - rc = _run_download(args) - assert rc == 0 - mock_dl.assert_called_once_with( - dataset="mot17", - split=["train", "val"], - asset=["annotations"], - output=_DEFAULT_OUTPUT_DIR, - cache_dir=_DEFAULT_CACHE_DIR, - ) + result = runner.invoke(cli, ["download", "mot17", "--split", "train , val", "--asset", "annotations"]) + assert result.exit_code == 0 + mock_dl.assert_called_once_with( + dataset="mot17", + split=["train", "val"], + asset=["annotations"], + output=_DEFAULT_OUTPUT_DIR, + cache_dir=_DEFAULT_CACHE_DIR, + ) class TestPrintAvailable: """Output of --list.""" def test_prints_without_error(self, capsys: pytest.CaptureFixture[str]) -> None: - """_print_available runs without raising and does not leak output.""" + """_print_available runs without raising.""" _print_available() capsys.readouterr() diff --git a/tests/scripts/test_progress.py b/tests/scripts/test_progress.py index 91486afe..85d3bab3 100644 --- a/tests/scripts/test_progress.py +++ b/tests/scripts/test_progress.py @@ -17,7 +17,7 @@ import pytest from rich.console import Console -from trackers.scripts.progress import ( +from trackers.cli.progress import ( _classify_source, _format_time, _SourceInfo, @@ -129,7 +129,7 @@ def test_video_with_zero_frame_count(self) -> None: cv2.CAP_PROP_FPS: 30.0, }.get(prop, 0.0) - with patch("trackers.scripts.progress.cv2.VideoCapture", return_value=mock_cap): + with patch("trackers.cli.progress.cv2.VideoCapture", return_value=mock_cap): info = _classify_source("some_video.mp4") assert info.source_type == "video" diff --git a/tests/scripts/test_track.py b/tests/scripts/test_track.py index be3867ed..900d0c74 100644 --- a/tests/scripts/test_track.py +++ b/tests/scripts/test_track.py @@ -11,8 +11,10 @@ import numpy as np import pytest import supervision as sv +from click.testing import CliRunner -from trackers.scripts.track import ( +from trackers.cli.__main__ import cli +from trackers.cli.track import ( _format_labels, _init_annotators, _resolve_class_filter, @@ -20,6 +22,25 @@ ) +class TestTrackCommandCLI: + """Smoke tests for the `track` CLI command via CliRunner.""" + + def test_help_exits_cleanly(self) -> None: + """--help must exit 0 and not raise a TypeError from class-type defaults.""" + runner = CliRunner() + result = runner.invoke(cli, ["track", "--help"]) + assert result.exit_code == 0, result.output + assert "ABCMETA" not in result.output + + def test_missing_source_and_detections_raises_usage_error(self) -> None: + """Omitting --source and --detections must raise UsageError, not TypeError.""" + runner = CliRunner() + result = runner.invoke(cli, ["track"]) + assert result.exit_code != 0 + assert result.exception is None or isinstance(result.exception, SystemExit) + assert "source" in result.output.lower() or "usage" in result.output.lower() + + class TestInitAnnotators: @pytest.mark.parametrize( "flags,expected_types,has_label_annotator", diff --git a/tests/scripts/test_tune.py b/tests/scripts/test_tune.py index 43e85799..5e653dcd 100644 --- a/tests/scripts/test_tune.py +++ b/tests/scripts/test_tune.py @@ -4,100 +4,150 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ -"""CLI-level tests for trackers/scripts/tune.py.""" +"""CLI-level tests for trackers/cli/tune.py.""" from __future__ import annotations -import argparse import json from pathlib import Path from unittest.mock import MagicMock, patch import pytest +from click.testing import CliRunner -from trackers.scripts.tune import add_tune_subparser, run_tune, tune +from trackers.cli.__main__ import cli +from trackers.cli.tune import tune -def _make_parser() -> tuple[argparse.ArgumentParser, argparse._SubParsersAction]: - """Return a top-level parser with a subparsers group.""" - parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers() - return parser, subparsers +class TestTuneCommand: + """Click CLI surface for the tune subcommand.""" + def test_missing_required_args_exits_nonzero(self) -> None: + """tune without required flags exits with a non-zero code.""" + runner = CliRunner() + result = runner.invoke(cli, ["tune"]) + assert result.exit_code != 0 -class TestAddTuneSubparser: - @pytest.fixture - def minimal_args(self) -> argparse.Namespace: - """Parsed args with only required flags.""" - parser, subparsers = _make_parser() - add_tune_subparser(subparsers) - return parser.parse_args(["tune", "--tracker", "sort", "--gt-dir", "/gt", "--detections-dir", "/det"]) - - def test_registers_tune_subcommand(self) -> None: - """tune subcommand is accessible under the 'tune' name.""" - parser, subparsers = _make_parser() - add_tune_subparser(subparsers) - args = parser.parse_args(["tune", "--tracker", "sort", "--gt-dir", "/gt", "--detections-dir", "/det"]) - assert args.func is run_tune + def test_tracker_flag_accepted(self, tmp_path: Path) -> None: + """--tracker, --gt-dir, --detections-dir are parsed without error when Tuner raises.""" + gt_dir = tmp_path / "gt" + gt_dir.mkdir() + det_dir = tmp_path / "det" + det_dir.mkdir() + runner = CliRunner() + result = runner.invoke( + cli, + ["tune", "--tracker", "bytetrack", "--gt-dir", str(gt_dir), "--detections-dir", str(det_dir)], + ) + # bytetrack with empty dirs → exit 1 from tune(), not a click error + assert result.exit_code in (0, 1) - def test_required_args_parsed(self) -> None: - """--tracker, --gt-dir, and --detections-dir are required and parsed.""" - parser, subparsers = _make_parser() - add_tune_subparser(subparsers) - args = parser.parse_args( + @pytest.mark.parametrize("objective", ["MOTA", "HOTA", "IDF1"]) + def test_objective_choices_accepted(self, tmp_path: Path, objective: str) -> None: + """Valid --objective values are accepted (exit comes from Tuner, not click).""" + gt_dir = tmp_path / "gt" + gt_dir.mkdir() + det_dir = tmp_path / "det" + det_dir.mkdir() + runner = CliRunner() + result = runner.invoke( + cli, [ "tune", "--tracker", "bytetrack", "--gt-dir", - "/data/gt", + str(gt_dir), "--detections-dir", - "/data/det", - ] + str(det_dir), + "--objective", + objective, + ], ) - assert args.tracker == "bytetrack" - assert args.gt_dir == Path("/data/gt") - assert args.detections_dir == Path("/data/det") - - @pytest.mark.parametrize( - "flag,expected", - [ - ("objective", "HOTA"), - ("n_trials", 100), - ("threshold", 0.5), - ("seqmap", None), - ("output", None), - ], - ) - def test_optional_defaults(self, minimal_args: argparse.Namespace, flag: str, expected: object) -> None: - """Optional arguments have correct defaults when omitted.""" - assert getattr(minimal_args, flag) == expected - - def test_metrics_default(self, minimal_args: argparse.Namespace) -> None: - """--metrics defaults to ['CLEAR'] when not supplied.""" - assert minimal_args.metrics == ["CLEAR"] - - def test_output_flag_short_form(self) -> None: - """-o is an alias for --output.""" - parser, subparsers = _make_parser() - add_tune_subparser(subparsers) - args = parser.parse_args( + assert result.exit_code in (0, 1) + + def test_invalid_objective_rejected(self, tmp_path: Path) -> None: + """Unknown --objective value exits with click usage error (code 2).""" + runner = CliRunner() + result = runner.invoke( + cli, [ "tune", "--tracker", - "sort", + "bytetrack", "--gt-dir", - "/gt", + str(tmp_path), "--detections-dir", - "/det", - "-o", - "/out/params.json", - ] + str(tmp_path), + "--objective", + "UNKNOWN", + ], ) - assert args.output == Path("/out/params.json") + assert result.exit_code == 2 + + def test_n_trials_flag(self, tmp_path: Path) -> None: + """--n-trials is forwarded to tune().""" + gt_dir = tmp_path / "gt" + gt_dir.mkdir() + det_dir = tmp_path / "det" + det_dir.mkdir() + mock_tuner = MagicMock() + mock_tuner.run.return_value = {"high_thresh": 0.6} + mock_tuner.study = None + runner = CliRunner() + with patch("trackers.tune.Tuner", return_value=mock_tuner) as mock_cls: + runner.invoke( + cli, + [ + "tune", + "--tracker", + "bytetrack", + "--gt-dir", + str(gt_dir), + "--detections-dir", + str(det_dir), + "--n-trials", + "50", + ], + ) + _, kwargs = mock_cls.call_args + assert kwargs.get("n_trials") == 50 + + def test_output_flag_writes_json(self, tmp_path: Path) -> None: + """-o writes best parameters to a JSON file.""" + gt_dir = tmp_path / "gt" + gt_dir.mkdir() + det_dir = tmp_path / "det" + det_dir.mkdir() + output_path = tmp_path / "params.json" + best = {"high_thresh": 0.6} + mock_tuner = MagicMock() + mock_tuner.run.return_value = best + mock_tuner.study = None + runner = CliRunner() + with patch("trackers.tune.Tuner", return_value=mock_tuner): + result = runner.invoke( + cli, + [ + "tune", + "--tracker", + "bytetrack", + "--gt-dir", + str(gt_dir), + "--detections-dir", + str(det_dir), + "-o", + str(output_path), + ], + ) + assert result.exit_code == 0 + assert output_path.exists() + assert json.loads(output_path.read_text()) == best class TestTune: + """Unit tests for the tune() helper function (no CLI layer).""" + def test_returns_1_on_invalid_tracker(self, tmp_path: Path) -> None: """Invalid tracker ID causes tune() to return exit code 1.""" gt_dir = tmp_path / "gt" @@ -113,7 +163,6 @@ def test_returns_1_on_missing_files(self, tmp_path: Path) -> None: gt_dir.mkdir() det_dir = tmp_path / "det" det_dir.mkdir() - # bytetrack is registered; empty det_dir → FileNotFoundError via Tuner result = tune("bytetrack", gt_dir, det_dir) assert result == 1 @@ -121,10 +170,7 @@ def test_returns_1_on_import_error(self, tmp_path: Path) -> None: """ImportError (e.g. optuna not installed) causes tune() to return 1.""" gt_dir = tmp_path / "gt" det_dir = tmp_path / "det" - with patch( - "trackers.tune.Tuner", - side_effect=ImportError("optuna is required"), - ): + with patch("trackers.tune.Tuner", side_effect=ImportError("optuna is required")): result = tune("bytetrack", gt_dir, det_dir) assert result == 1 @@ -178,36 +224,3 @@ def test_returns_1_on_tuner_run_exception(self, tmp_path: Path) -> None: with patch("trackers.tune.Tuner", return_value=mock_tuner): result = tune("bytetrack", gt_dir, det_dir) assert result == 1 - - -class TestRunTune: - def test_delegates_to_tune_with_namespace_args(self, tmp_path: Path) -> None: - """run_tune() passes all argparse.Namespace fields to tune() correctly.""" - gt_dir = tmp_path / "gt" - det_dir = tmp_path / "det" - output_path = tmp_path / "params.json" - args = argparse.Namespace( - tracker="sort", - gt_dir=gt_dir, - detections_dir=det_dir, - objective="MOTA", - n_trials=50, - metrics=["CLEAR", "HOTA"], - threshold=0.3, - seqmap=None, - output=output_path, - ) - with patch("trackers.scripts.tune.tune", return_value=0) as mock_tune: - result = run_tune(args) - assert result == 0 - mock_tune.assert_called_once_with( - tracker="sort", - gt_dir=gt_dir, - detections_dir=det_dir, - objective="MOTA", - n_trials=50, - metrics=["CLEAR", "HOTA"], - threshold=0.3, - seqmap=None, - output=output_path, - ) diff --git a/uv.lock b/uv.lock index 5c102b60..2bf71c16 100644 --- a/uv.lock +++ b/uv.lock @@ -4026,6 +4026,7 @@ name = "trackers" version = "2.4.0" source = { editable = "." } dependencies = [ + { name = "click" }, { name = "numpy" }, { name = "opencv-python" }, { name = "requests" }, @@ -4069,6 +4070,7 @@ mypy-types = [ [package.metadata] requires-dist = [ + { name = "click", specifier = ">=8.0" }, { name = "inference-models", marker = "extra == 'detection'", specifier = ">=0.19.0" }, { name = "numpy", specifier = ">=2.0.2" }, { name = "opencv-python", specifier = ">=4.8.0" },