Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 97 additions & 3 deletions benchmarks/benchmarks/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from types import MappingProxyType
from typing import TYPE_CHECKING

import h5py
import numpy as np
import pandas as pd
import zarr
Expand All @@ -25,6 +26,41 @@ def make_alternating_mask(n):
return mask_alternating


def make_integer_indexers() -> MappingProxyType:
rng = np.random.default_rng(42)
fragmented_sample = rng.choice(10_000, size=2_048, replace=False)
return MappingProxyType({
"single_run": np.arange(2_048),
"multiple_runs": np.concatenate([
np.arange(0, 512),
np.arange(1_500, 2_012),
np.arange(3_000, 3_512),
np.arange(7_000, 7_512),
]),
"fragmented_sorted": np.sort(fragmented_sample),
"fragmented_unsorted": rng.permutation(fragmented_sample),
"clustered_shuffled": rng.permutation(
np.concatenate([
np.arange(100, 612),
np.arange(1_500, 2_012),
np.arange(3_000, 3_512),
np.arange(7_000, 7_512),
])
),
"clustered_duplicates": rng.permutation(
np.concatenate([
np.repeat(np.arange(100, 356), 2),
np.repeat(np.arange(1_500, 1_756), 2),
np.repeat(np.arange(3_000, 3_256), 2),
np.repeat(np.arange(7_000, 7_256), 2),
])
),
})


INT_INDEXERS = make_integer_indexers()


class SparseCSRContiguousSlice:
_indexers = MappingProxyType({
"0:1000": slice(0, 1000),
Expand Down Expand Up @@ -54,7 +90,7 @@ def setup_cache(self):
format="csr",
random_state=np.random.default_rng(42),
)
g = zarr.group(self.filepath)
g = zarr.open(self.filepath, mode="w")
write_elem(g, "X", X)

def setup(self, index: str, use_dask: bool): # noqa: FBT001
Expand Down Expand Up @@ -84,6 +120,64 @@ def peakmem_getitem_adata(self, *_):
res.compute()


class SparseBackedIntegerIndexing:
filepath = "data"
params = (
("h5ad", "zarr"),
("csr", "csc"),
list(INT_INDEXERS.keys()),
)
param_names = ("store_type", "sparse_format", "index_case")

def setup_cache(self):
rng = np.random.default_rng(42)
csr = sparse.random(
10_000,
10_000,
density=0.01,
format="csr",
random_state=rng,
)
csc = csr.tocsc()
for store_type in ["h5ad", "zarr"]:
path = f"{self.filepath}.{store_type}"
if store_type == "h5ad":
with h5py.File(path, mode="w") as f:
write_elem(f, "csr", csr)
write_elem(f, "csc", csc)
else:
g = zarr.open(path, mode="w")
write_elem(g, "csr", csr)
write_elem(g, "csc", csc)

def setup(self, store_type: str, sparse_format: str, index_case: str):
self._h5_file = None
if store_type == "h5ad":
self._h5_file = h5py.File(f"{self.filepath}.h5ad", mode="r")
self.group = self._h5_file
else:
self.group = zarr.open(f"{self.filepath}.zarr", mode="r")
self.x = sparse_dataset(self.group[sparse_format])
self.index = INT_INDEXERS[index_case]
self.is_csr = sparse_format == "csr"

def teardown(self, *_):
if self._h5_file is not None:
self._h5_file.close()

def time_getitem(self, *_):
if self.is_csr:
self.x[self.index, :]
else:
self.x[:, self.index]

def peakmem_getitem(self, *_):
if self.is_csr:
self.x[self.index, :]
else:
self.x[:, self.index]


class SparseCSRDaskConcat:
filepath = "data.zarr"

Expand All @@ -98,7 +192,7 @@ def setup_cache(self):
format="csr",
random_state=np.random.default_rng(42),
)
g = zarr.group(self.filepath)
g = zarr.open(self.filepath, mode="w")
write_elem(g, "X", X)

def setup(self, *_):
Expand Down Expand Up @@ -146,7 +240,7 @@ def setup_cache(self):
format="csr",
random_state=np.random.default_rng(42),
)
g = zarr.group(self.filepath)
g = zarr.open(self.filepath, mode="w")
write_elem(g, "X", X)

def setup(self, *_):
Expand Down
58 changes: 52 additions & 6 deletions src/anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ def slice_as_int(s: slice, l: int) -> int:
return out[0]


def _contiguous_slices_from_sorted_indices(indices: np.ndarray) -> list[slice]:
split_points = np.flatnonzero(np.diff(indices) != 1) + 1
starts = np.concatenate(([0], split_points))
stops = np.concatenate((split_points, [len(indices)]))
return [
slice(indices[start], indices[stop - 1] + 1)
for start, stop in zip(starts, stops, strict=False)
]


@dataclass
class BackedSparseMatrix[ArrayT: _ArrayStorageType]:
"""\
Expand Down Expand Up @@ -275,18 +285,54 @@ def _get_sliceXslice(
def _get_arrayXslice(
self, major_index: Sequence | np.ndarray, minor_index: slice
) -> SparseMatrixType:
major_index = np.asarray(major_index)
if major_index.dtype == bool:
major_index = np.flatnonzero(major_index)
if len(major_index) == 0:
return self.memory_format(
(0, self.minor_axis_size)
if self.format == "csr"
else (self.minor_axis_size, 0)
)
if major_index.dtype == bool:
major_index = np.where(major_index)
out_shape = self._gen_maj_min_tuple(len(major_index), self.minor_axis_size)
return self.memory_format(
self.get_compressed_vectors(major_index), shape=out_shape
)[self._gen_maj_min_tuple(slice(None), minor_index)]
if np.any(major_index < 0):
return self.memory_format(
(self.data[...], self.indices[...], self.indptr[...]),
shape=self.shape,
)[self._gen_maj_min_tuple(major_index, minor_index)]
if np.any(major_index >= self.shape[self.major_axis]):
max_index = major_index.max()
msg = f"index ({max_index}) out of range"
raise IndexError(msg)

unique_major_index = np.unique(major_index)
run_count = 1 + np.count_nonzero(np.diff(unique_major_index) != 1)
mean_slice_length = len(unique_major_index) / run_count
if mean_slice_length <= 7:
out_shape = self._gen_maj_min_tuple(len(major_index), self.minor_axis_size)
return self.memory_format(
self.get_compressed_vectors(major_index), shape=out_shape
)[self._gen_maj_min_tuple(slice(None), minor_index)]

original_major_index = np.asarray(major_index)
inverse = np.searchsorted(unique_major_index, original_major_index)
if run_count == 1:
compressed_vectors = self._get_contiguous_compressed_slice(
slice(unique_major_index[0], unique_major_index[-1] + 1)
)
else:
slices = _contiguous_slices_from_sorted_indices(unique_major_index)
compressed_vectors = self.get_compressed_vectors_for_slices(slices)
sub = self.memory_format(
compressed_vectors,
shape=self._gen_maj_min_tuple(
len(unique_major_index), self.minor_axis_size
),
)
if len(unique_major_index) != len(original_major_index) or not np.array_equal(
unique_major_index, original_major_index
):
sub = sub[self._gen_maj_min_tuple(inverse, slice(None))]
return sub[self._gen_maj_min_tuple(slice(None), minor_index)]

def subset_by_major_axis_mask(
self: BackedSparseMatrix, mask: np.ndarray
Expand Down
Loading
Loading