diff --git a/docs/changelog.md b/docs/changelog.md index b56a0b5fa..cf612f33d 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -11,6 +11,8 @@ date_modified: 2026-04-30 - Added [#2267](https://github.com/roboflow/supervision/pull/2267): [`DetectionDataset.as_coco`](https://supervision.roboflow.com/latest/datasets/core/#supervision.dataset.core.DetectionDataset.as_coco) and `save_coco_annotations` now accept `starting_image_id` and `starting_annotation_id` parameters (both default to `1`, preserving existing behavior) and return a `(next_image_id, next_annotation_id)` tuple. Feed the returned values into the next split's call to produce globally unique COCO ids across train/valid/test exports. Fixes id collisions reported in [#768](https://github.com/roboflow/supervision/issues/768). **Note**: the return type changes from `None` to `tuple[int, int]` — callers that assert `result is None` must be updated. +- Added [#2027](https://github.com/roboflow/supervision/issues/2027): [`sv.InferenceSlicer`](https://supervision.roboflow.com/latest/detection/tools/inference_slicer/#supervision.detection.tools.inference_slicer.InferenceSlicer) now accepts an open rasterio-style dataset in addition to in-memory images. Each tile is read lazily via a windowed read instead of loading the whole image, enabling tiled inference on multi-GB aerial/drone GeoTIFFs without running out of memory. Detection is duck-typed, so `rasterio` stays an optional dependency installable via `pip install "supervision[geotiff]"` and the core library imports no rasterio symbols. A geographic (non-projected) CRS raises `ValueError`. + ### 0.28.0 Apr 30, 2026 - Added [#2159](https://github.com/roboflow/supervision/pull/2159): [`sv.CompactMask`](https://supervision.roboflow.com/latest/detection/compact_mask/#supervision.detection.compact_mask.CompactMask) for memory-efficient mask storage. Masks are stored as crop-region bounding boxes plus RLE-encoded data instead of full-resolution bitmaps, reducing memory by up to 240× for sparse masks. Integrates transparently with `sv.Detections.mask` — filtering, merging, and `area` all work without materialising the full array. diff --git a/pyproject.toml b/pyproject.toml index 855cdc436..269054878 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,9 @@ dependencies = [ "scipy>=1.10", "tqdm>=4.62.3" ] +optional-dependencies.geotiff = [ + "rasterio>=1.3", +] optional-dependencies.metrics = [ "pandas>=2", ] diff --git a/src/supervision/detection/tools/inference_slicer.py b/src/supervision/detection/tools/inference_slicer.py index c26a7584a..71b5445a8 100644 --- a/src/supervision/detection/tools/inference_slicer.py +++ b/src/supervision/detection/tools/inference_slicer.py @@ -4,7 +4,7 @@ import warnings from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any +from typing import Any, Protocol import numpy as np import numpy.typing as npt @@ -20,6 +20,35 @@ from supervision.utils.internal import SupervisionWarnings +class WindowedRasterDataset(Protocol): + """Structural type for a rasterio-style dataset read window-by-window. + + Matched structurally (see [`_is_windowed_raster`][]) rather than by import so + `rasterio` stays an optional dependency — any object exposing these members + works. `rasterio.io.DatasetReader` satisfies this protocol. + """ + + width: int + height: int + crs: Any + + def read(self, window: Any) -> npt.NDArray[Any]: ... + + +def _is_windowed_raster(image: object) -> bool: + """Duck-type check for a rasterio-style dataset that supports windowed reads. + + Avoids importing rasterio so it remains an optional dependency. numpy arrays + and PIL images do not expose this combination of attributes. + """ + return ( + callable(getattr(image, "read", None)) + and hasattr(image, "crs") + and hasattr(image, "width") + and hasattr(image, "height") + ) + + def move_detections( detections: Detections, offset: npt.NDArray[Any], @@ -138,6 +167,24 @@ def callback(tile): image = Image.open("example.png") detections = slicer(image) ``` + + ```python + import rasterio + import supervision as sv + + def callback(tile): # tile is (H, W, C); select/convert bands as needed + ... + + slicer = sv.InferenceSlicer(callback, slice_wh=640, overlap_wh=100) + + with rasterio.open("large_orthomosaic.tif") as dataset: + detections = slicer(dataset) + ``` + + Passing an open rasterio dataset reads each tile lazily via a windowed + read, so multi-GB GeoTIFFs never need to be loaded into memory at once. + `rasterio` is an optional dependency installable via + `pip install "supervision[geotiff]"`. """ def __init__( @@ -175,7 +222,7 @@ def __init__( self._obb_thread_workers_warned: bool = False self._obb_thread_workers_lock = threading.Lock() - def __call__(self, image: ImageType) -> Detections: + def __call__(self, image: ImageType | WindowedRasterDataset) -> Detections: """ Perform tiled inference on the full image and return merged detections. @@ -188,13 +235,31 @@ def __call__(self, image: ImageType) -> Detections: once per slicer instance. Args: - image: The full image to run inference on. + image: The full image to run inference on. In addition to in-memory + images (NumPy arrays or PIL images), this also accepts an open + rasterio-style dataset. When a dataset is provided, each tile is + read lazily via a windowed read instead of loading the whole image + into memory, enabling tiled inference on multi-GB GeoTIFFs. Tiles + read from a dataset preserve the source dtype (e.g. ``uint16`` for + 16-bit sensors) and keep every band; convert or select bands to + the dtype/channels your model expects inside the callback. Returns: Merged detections across all slices. """ detections_list: list[Detections] = [] - resolution_wh = get_image_resolution_wh(image) + if _is_windowed_raster(image): + crs = image.crs + if crs is not None and not crs.is_projected: + raise ValueError( + "InferenceSlicer requires a projected coordinate reference " + "system for pixel-space tiled inference on a raster dataset. " + f"The provided dataset uses a geographic CRS ({crs}). Reproject " + "it to a projected CRS (e.g. with `gdalwarp`) before slicing." + ) + resolution_wh = (image.width, image.height) + else: + resolution_wh = get_image_resolution_wh(image) offsets = self._generate_offset( resolution_wh=resolution_wh, @@ -272,7 +337,9 @@ def __call__(self, image: ImageType) -> Detections: ) return merged - def _run_callback(self, image: ImageType, offset: npt.NDArray[Any]) -> Detections: + def _run_callback( + self, image: ImageType | WindowedRasterDataset, offset: npt.NDArray[Any] + ) -> Detections: """ Run detection callback on a sliced portion of the image and adjust coordinates. @@ -284,7 +351,20 @@ def _run_callback(self, image: ImageType, offset: npt.NDArray[Any]) -> Detection Returns: Detections adjusted to the full image coordinate system. """ - image_slice = crop_image(image=image, xyxy=offset) + if _is_windowed_raster(image): + x_min, y_min, x_max, y_max = (int(v) for v in offset) + # rasterio tuple window: + # ((row_start, row_stop), (col_start, col_stop)) + window = ((y_min, y_max), (x_min, x_max)) + bands = image.read(window=window) # shape (channels, height, width) + image_slice = np.ascontiguousarray( + np.transpose(bands, (1, 2, 0)) + ) # -> (H, W, C) + resolution_wh = (image.width, image.height) + else: + image_slice = crop_image(image=image, xyxy=offset) + resolution_wh = get_image_resolution_wh(image) + detections = self.callback(image_slice) if ( @@ -299,7 +379,6 @@ def _run_callback(self, image: ImageType, offset: npt.NDArray[Any]) -> Detection image_shape=(slice_h, slice_w), ) - resolution_wh = get_image_resolution_wh(image) # Fast-path: skip locking and bounds checking when the warning has already # been emitted or when there are no detections to inspect. needs_warning_check = ( diff --git a/tests/detection/tools/test_inference_slicer_geotiff.py b/tests/detection/tools/test_inference_slicer_geotiff.py new file mode 100644 index 000000000..dd6beb934 --- /dev/null +++ b/tests/detection/tools/test_inference_slicer_geotiff.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from supervision.detection.core import Detections +from supervision.detection.tools.inference_slicer import InferenceSlicer + + +class _FakeCRS: + """Minimal rasterio-style CRS stub exposing only `is_projected`.""" + + def __init__(self, is_projected: bool): + self.is_projected = is_projected + + def __repr__(self) -> str: + kind = "projected" if self.is_projected else "geographic" + return f"_FakeCRS({kind})" + + +class _FakeRasterDataset: + """Lightweight rasterio-style dataset supporting windowed reads. + + Mimics the duck-typed interface that ``InferenceSlicer`` relies on without + requiring ``rasterio`` to be installed. + """ + + def __init__(self, image_hwc: np.ndarray, crs: object | None = None): + self._image = image_hwc # numpy (H, W, C) + self.height, self.width = image_hwc.shape[:2] + self.crs = crs # None or object with .is_projected + + def read(self, window: tuple[tuple[int, int], tuple[int, int]]) -> np.ndarray: + (row_start, row_stop), (col_start, col_stop) = window + crop = self._image[row_start:row_stop, col_start:col_stop, :] + return np.transpose(crop, (2, 0, 1)) # (C, H, W) like rasterio + + +def _fixed_detection_callback(_: np.ndarray) -> Detections: + """Return a constant detection for every tile.""" + return Detections( + xyxy=np.array([[0, 0, 10, 10]], dtype=float), + confidence=np.array([0.9]), + class_id=np.array([0]), + ) + + +def _sortable(detections: Detections) -> np.ndarray: + """Sort detection boxes so two runs can be compared order-independently.""" + return np.array( + sorted(detections.xyxy.tolist()), + dtype=float, + ) + + +def test_windowed_raster_matches_in_memory_array() -> None: + # Arrange + rng = np.random.default_rng(42) + image = rng.integers(0, 255, size=(256, 256, 3), dtype=np.uint8) + dataset = _FakeRasterDataset(image, crs=_FakeCRS(is_projected=True)) + slicer = InferenceSlicer( + callback=_fixed_detection_callback, + slice_wh=128, + overlap_wh=0, + ) + + # Act + detections_array = slicer(image) + detections_raster = slicer(dataset) + + # Assert + assert np.array_equal(_sortable(detections_array), _sortable(detections_raster)) + + +def test_windowed_raster_reads_correct_window_content() -> None: + """The windowed read must return the same pixels crop_image would, so the + callback sees identical tile content for both input types.""" + # Arrange + rng = np.random.default_rng(7) + image = rng.integers(0, 255, size=(128, 192, 3), dtype=np.uint8) + dataset = _FakeRasterDataset(image) + + seen_array_tiles: list[np.ndarray] = [] + seen_raster_tiles: list[np.ndarray] = [] + + def recording_callback(sink: list[np.ndarray]): + def callback(tile: np.ndarray) -> Detections: + sink.append(tile.copy()) + return Detections.empty() + + return callback + + slicer_array = InferenceSlicer( + callback=recording_callback(seen_array_tiles), + slice_wh=64, + overlap_wh=0, + ) + slicer_raster = InferenceSlicer( + callback=recording_callback(seen_raster_tiles), + slice_wh=64, + overlap_wh=0, + ) + + # Act + slicer_array(image) + slicer_raster(dataset) + + # Assert + assert len(seen_array_tiles) == len(seen_raster_tiles) + for array_tile, raster_tile in zip(seen_array_tiles, seen_raster_tiles): + assert np.array_equal(array_tile, raster_tile) + + +def test_windowed_raster_matches_in_memory_array_with_overlap() -> None: + """Overlapping tiles must read identical windows for both input types.""" + # Arrange + rng = np.random.default_rng(99) + image = rng.integers(0, 255, size=(200, 220, 3), dtype=np.uint8) + dataset = _FakeRasterDataset(image) + + seen_array_tiles: list[np.ndarray] = [] + seen_raster_tiles: list[np.ndarray] = [] + + def recording_callback(sink: list[np.ndarray]): + def callback(tile: np.ndarray) -> Detections: + sink.append(tile.copy()) + return Detections.empty() + + return callback + + slicer_array = InferenceSlicer( + callback=recording_callback(seen_array_tiles), + slice_wh=96, + overlap_wh=32, + ) + slicer_raster = InferenceSlicer( + callback=recording_callback(seen_raster_tiles), + slice_wh=96, + overlap_wh=32, + ) + + # Act + slicer_array(image) + slicer_raster(dataset) + + # Assert + assert len(seen_array_tiles) == len(seen_raster_tiles) > 1 + for array_tile, raster_tile in zip(seen_array_tiles, seen_raster_tiles): + assert np.array_equal(array_tile, raster_tile) + + +def test_windowed_raster_preserves_band_dtype() -> None: + """Tiles read from a dataset keep the source dtype (e.g. uint16).""" + # Arrange + rng = np.random.default_rng(5) + image = rng.integers(0, 4000, size=(128, 128, 3), dtype=np.uint16) + dataset = _FakeRasterDataset(image) + + seen: list[np.ndarray] = [] + + def callback(tile: np.ndarray) -> Detections: + seen.append(tile) + return Detections.empty() + + slicer = InferenceSlicer(callback=callback, slice_wh=64, overlap_wh=0) + + # Act + slicer(dataset) + + # Assert + assert seen + assert all(tile.dtype == np.uint16 for tile in seen) + + +def test_windowed_raster_with_no_crs_works() -> None: + # Arrange + image = np.zeros((128, 128, 3), dtype=np.uint8) + dataset = _FakeRasterDataset(image, crs=None) + slicer = InferenceSlicer( + callback=_fixed_detection_callback, + slice_wh=64, + overlap_wh=0, + ) + + # Act + detections = slicer(dataset) + + # Assert + assert len(detections) == 4 + + +def test_windowed_raster_with_geographic_crs_raises() -> None: + # Arrange + image = np.zeros((128, 128, 3), dtype=np.uint8) + dataset = _FakeRasterDataset(image, crs=_FakeCRS(is_projected=False)) + slicer = InferenceSlicer( + callback=_fixed_detection_callback, + slice_wh=64, + overlap_wh=0, + ) + + # Act / Assert + with pytest.raises(ValueError, match="projected coordinate reference"): + slicer(dataset) + + +def test_windowed_raster_with_projected_crs_does_not_raise() -> None: + # Arrange + image = np.zeros((128, 128, 3), dtype=np.uint8) + dataset = _FakeRasterDataset(image, crs=_FakeCRS(is_projected=True)) + slicer = InferenceSlicer( + callback=_fixed_detection_callback, + slice_wh=64, + overlap_wh=0, + ) + + # Act + detections = slicer(dataset) + + # Assert + assert len(detections) == 4 + + +def test_real_rasterio_memoryfile_integration() -> None: + """Integration check against a real rasterio dataset, skipped if rasterio + is not installed.""" + pytest.importorskip("rasterio") + from rasterio.io import MemoryFile + + # Arrange + rng = np.random.default_rng(123) + image = rng.integers(0, 255, size=(128, 128, 3), dtype=np.uint8) + bands = np.transpose(image, (2, 0, 1)) # (C, H, W) + + slicer = InferenceSlicer( + callback=_fixed_detection_callback, + slice_wh=64, + overlap_wh=0, + ) + detections_array = slicer(image) + + profile = { + "driver": "GTiff", + "height": image.shape[0], + "width": image.shape[1], + "count": image.shape[2], + "dtype": image.dtype, + } + + # Act + with MemoryFile() as memfile: + with memfile.open(**profile) as dst: + dst.write(bands) + with memfile.open() as dataset: + detections_raster = slicer(dataset) + + # Assert + assert np.array_equal(_sortable(detections_array), _sortable(detections_raster))