Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions benchmarks/benchmarks/h5py_indexing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

import numpy as np

from anndata._core.index import _process_index_for_h5py


class ProcessIndexForH5py:
    """Benchmark ``_process_index_for_h5py`` on 1-D integer index arrays.

    The class uses the ``param_names`` / ``params`` / ``setup`` / ``time_*``
    layout (presumably consumed by airspeed velocity, given the
    ``benchmarks/`` location -- confirm against the benchmark config).
    Every ``(size, scenario)`` combination from ``params`` is timed.
    """

    # Names of the parameters, in the order the methods below receive them.
    param_names = ("size", "scenario")
    # Cartesian product of index lengths and index layouts.
    params = (
        [10_000, 100_000, 1_000_000],
        ["sorted_unique", "unsorted_unique", "duplicate_heavy"],
    )

    def setup(self, size, scenario):
        """Build the index array stored on ``self.idx`` for one combination.

        Parameters
        ----------
        size
            Number of entries in the generated index array.
        scenario
            ``"sorted_unique"`` yields ``0..size-1`` in ascending order,
            ``"unsorted_unique"`` yields a random permutation of that range,
            and any other value yields ``size`` draws from
            ``[0, max(1, size // 10))``, i.e. many repeated values.
        """
        # Fixed seed so every run times exactly the same input data.
        rng = np.random.default_rng(0)
        if scenario == "sorted_unique":
            self.idx = np.arange(size, dtype=np.int64)
        elif scenario == "unsorted_unique":
            self.idx = rng.permutation(size).astype(np.int64)
        else:
            # ``max(1, ...)`` keeps the exclusive upper bound positive for
            # sizes below 10.
            self.idx = rng.integers(0, max(1, size // 10), size=size, dtype=np.int64)

    def time_process_index_for_h5py(self, size, scenario):
        """Time one call on the array prepared by :meth:`setup`."""
        _process_index_for_h5py(self.idx)
Comment on lines +24 to +25

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't be benchmarking private methods. Please either remove the benchmark or run it against a public method.

32 changes: 3 additions & 29 deletions src/anndata/_core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def _subset_dataset(
) -> np.ndarray:
order: tuple[NDArray[np.integer] | slice, ...]
inv_order: tuple[NDArray[np.integer] | slice, ...]
order, inv_order = zip(*map(_index_order_and_inverse, subset_idx), strict=True)
order, inv_order = zip(*map(_process_index_for_h5py, subset_idx), strict=True)
# check for duplicates or multi-dimensional fancy indexing
array_dims = [i for i in order if isinstance(i, np.ndarray)]
has_duplicates = any(len(np.unique(i)) != len(i) for i in array_dims)
Expand All @@ -465,25 +465,7 @@ def _subset_dataset(
# For multi-dimensional indexing, bypass the sorting logic and use original indices
return _safe_fancy_index_h5py(d, subset_idx)
# from hdf5, then to real order
return d[order][inv_order]


@overload
def _index_order_and_inverse(
axis_idx: NDArray[np.integer] | NDArray[np.bool_],
) -> tuple[NDArray[np.integer], NDArray[np.integer]]: ...
@overload
def _index_order_and_inverse(axis_idx: slice) -> tuple[slice, slice]: ...
def _index_order_and_inverse(
axis_idx: _Index1DNorm,
) -> tuple[_Index1DNorm, NDArray[np.integer] | slice]:
"""Order and get inverse index array."""
if not isinstance(axis_idx, np.ndarray):
return axis_idx, slice(None)
if axis_idx.dtype == bool:
axis_idx = np.flatnonzero(axis_idx)
order = np.argsort(axis_idx)
return axis_idx[order], np.argsort(order)
return d[order][tuple(slice(None) if i is None else i for i in inv_order)]


@overload
Expand All @@ -503,16 +485,8 @@ def _process_index_for_h5py(
if idx.dtype == bool:
idx = np.flatnonzero(idx)

# For h5py fancy indexing, we need sorted indices
# But we also need to track how to reverse the sorting
unique, inverse = np.unique(idx, return_inverse=True)
return (
# Has duplicates - use unique + inverse mapping approach
(unique, inverse)
if len(unique) != len(idx)
# No duplicates - just sort and track reverse mapping
else _index_order_and_inverse(idx)
)
return unique, inverse


def _safe_fancy_index_h5py(
Expand Down
45 changes: 45 additions & 0 deletions tests/test_backed_hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,10 +423,54 @@ def h5py_test_data(tmp_path):
return h5_path, test_data


@pytest.mark.parametrize(
    ("idx", "expected"),
    [
        pytest.param(
            np.array([0, 1, 2], dtype=np.int64),
            (np.array([0, 1, 2], dtype=np.int64), np.array([0, 1, 2])),
            id="sorted_unique",
        ),
        pytest.param(
            np.array([3, 1, 0, 2], dtype=np.int64),
            (np.array([0, 1, 2, 3], dtype=np.int64), np.array([3, 1, 0, 2])),
            id="unsorted_unique",
        ),
        pytest.param(
            np.array([2, 2, 5, 3, 8, 10, 8], dtype=np.int64),
            (
                np.array([2, 3, 5, 8, 10], dtype=np.int64),
                np.array([0, 0, 2, 1, 3, 4, 3]),
            ),
            id="duplicate_heavy",
        ),
        pytest.param(
            np.array([True, False, True, False], dtype=bool),
            (np.array([0, 2], dtype=np.int64), np.array([0, 1])),
            id="boolean_mask",
        ),
        pytest.param(
            np.array([], dtype=np.int64),
            (np.array([], dtype=np.int64), np.array([], dtype=np.int64)),
            id="empty",
        ),
    ],
)
def test_process_index_for_h5py(idx, expected):
    """Check the sorted index and reverse mapping for each index layout.

    ``expected`` is the pair ``(processed, reverse)`` the helper should
    return for ``idx``.
    """
    from anndata._core.index import _process_index_for_h5py

    processed, reverse = _process_index_for_h5py(idx)
    expected_processed, expected_reverse = expected

    np.testing.assert_array_equal(processed, expected_processed)
    np.testing.assert_array_equal(reverse, expected_reverse)

    # Round-trip invariant: applying ``reverse`` to the processed index must
    # reproduce the positions that were originally requested. Boolean masks
    # are compared through their integer positions.
    requested = np.flatnonzero(idx) if idx.dtype == bool else idx
    np.testing.assert_array_equal(processed[reverse], requested)


@pytest.mark.parametrize(
("indices", "description"),
[
pytest.param((np.array([0, 1, 0, 2]),), "single_dimension_with_duplicates"),
pytest.param((np.array([3, 1, 0, 2]),), "single_dimension_unsorted_unique"),
pytest.param(
(np.array([0, 1, 2]), np.array([1, 2])), "multi_dimensional_no_duplicates"
),
Expand All @@ -450,6 +494,7 @@ def h5py_test_data(tmp_path):
(np.array([0, 1, 0]), [1, 2]), "mixed_indexing_with_slices_and_lists"
),
pytest.param((np.array([3, 1, 3, 0, 1]),), "unsorted_indices_with_duplicates"),
pytest.param((np.array([], dtype=np.int64),), "empty_indices"),
],
)
def test_safe_fancy_index_h5py_function(h5py_test_data, indices, description):
Expand Down
Loading