Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1945,6 +1945,7 @@ def write_zarr(
store: StoreLike,
*,
chunks: tuple[int, ...] | None = None,
consolidate_metadata: bool = True,
convert_strings_to_categoricals: bool = True,
):
"""\
Expand All @@ -1956,6 +1957,8 @@ def write_zarr(
The filename, a :class:`~typing.MutableMapping`, or a Zarr storage class.
chunks
Chunk shape.
consolidate_metadata
Whether to consolidate Zarr metadata.
convert_strings_to_categoricals
Convert string columns to categorical.
"""
Expand All @@ -1972,6 +1975,7 @@ def write_zarr(
store,
self,
chunks=chunks,
consolidate_metadata=consolidate_metadata,
convert_strings_to_categoricals=convert_strings_to_categoricals,
)

Expand Down
8 changes: 1 addition & 7 deletions src/anndata/_io/specs/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,14 +386,8 @@ def write_elem(
if is_consolidated:
msg = "Cannot overwrite/edit a store with consolidated metadata"
raise ValueError(msg)
if k == "/":
if isinstance(store, ZarrGroup):
from zarr.core.sync import sync

sync(store.store.clear())
else:
store.clear()
elif k in store:
if k != "/" and k in store:
del store[k]

# Normalize array-API (e.g., JAX/CuPy) even if not AnnData
Expand Down
10 changes: 8 additions & 2 deletions src/anndata/_io/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def write_zarr(
adata: AnnData,
*,
chunks: tuple[int, ...] | None = None,
consolidate_metadata: bool = True,
convert_strings_to_categoricals: bool = True,
**ds_kwargs,
) -> None:
Expand All @@ -55,7 +56,8 @@ def callback(
write_func(store, elem_name, elem, dataset_kwargs=dataset_kwargs)

write_dispatched(f, "/", adata, callback=callback, dataset_kwargs=ds_kwargs)
zarr.consolidate_metadata(f.store)
if consolidate_metadata:
zarr.consolidate_metadata(f.store)


def read_zarr(store: PathLike[str] | str | MutableMapping | zarr.Group) -> AnnData:
Expand Down Expand Up @@ -147,7 +149,11 @@ def open_write_group(
) -> zarr.Group:
if "zarr_format" not in kwargs:
kwargs["zarr_format"] = settings.zarr_write_format
return zarr.open_group(store, mode=mode, **kwargs)
return (
zarr.open_group(store, mode=mode, **kwargs)
if not isinstance(store, zarr.Group)
else store
)


def is_group_consolidated(group: zarr.Group) -> bool:
Expand Down
18 changes: 10 additions & 8 deletions src/anndata/experimental/backed/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ..._core.xarray import requires_xarray
from ..._settings import settings
from ...compat import ZarrGroup
from ...utils import get_literal_members, warn
from ...utils import get_literal_members
from .. import read_dispatched

if TYPE_CHECKING:
Expand Down Expand Up @@ -106,13 +106,15 @@ def read_lazy(
import zarr

if not isinstance(store, ZarrGroup):
try:
f = zarr.open_consolidated(store, mode="r")
except ValueError:
msg = "Did not read zarr as consolidated. Consider consolidating your metadata."
warn(msg, UserWarning)
has_keys = False
f = zarr.open_group(store, mode="r")
from anndata._io.zarr import is_group_consolidated

has_keys = is_group_consolidated()
f = (
zarr.open_consolidated(store, mode="r")
if has_keys
else zarr.open_group(store, mode="r")
)

else:
f = store
elif is_store_arg_h5_store:
Expand Down
47 changes: 47 additions & 0 deletions tests/test_readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,7 @@ def test_write_elem_consolidated(tmp_path: Path):

@pytest.mark.zarr_io
def test_write_elem_version_mismatch(tmp_path: Path):
tmp_path = Path("foo")
zarr_path = tmp_path / "foo.zarr"
adata = ad.AnnData(np.ones((10, 10)))
g = zarr.open_group(
Expand All @@ -977,3 +978,49 @@ def test_write_elem_version_mismatch(tmp_path: Path):
ad.io.write_elem(g, "/", adata)
adata_roundtripped = ad.read_zarr(g)
assert_equal(adata_roundtripped, adata)


@pytest.mark.zarr_io
def test_write_zarr_store(tmp_path: Path):
zarr_path = tmp_path / "foo.zarr"
adata = ad.AnnData(np.ones((10, 10)))
g = zarr.open_group(
zarr_path,
mode="w",
)
adata.write_zarr(g)
adata_roundtripped = ad.read_zarr(g)
assert_equal(adata_roundtripped, adata)


@pytest.mark.zarr_io
def test_write_zarr_store_overwrite(tmp_path: Path):
zarr_path = tmp_path / "foo.zarr"
adata1 = ad.AnnData(np.ones((10, 10)))
g = zarr.open_group(
zarr_path,
mode="w",
)
adata1.write_zarr(g, consolidate_metadata=False)
adata2 = ad.AnnData(np.zeros((5, 5)))

adata2.write_zarr(g, consolidate_metadata=False)
adata_roundtripped = ad.read_zarr(g)
assert_equal(adata_roundtripped, adata2)


@pytest.mark.zarr_io
def test_write_zarr_store_separate_groups(tmp_path: Path):
zarr_path = tmp_path / "foo.zarr"
adata1 = ad.AnnData(np.ones((10, 10)))
adata2 = ad.AnnData(np.zeros((5, 5)))
g = zarr.open_group(
zarr_path,
mode="w",
)
g1 = g.create_group("g1")
g2 = g.create_group("g2")
adata1.write_zarr(g1, consolidate_metadata=False)
adata2.write_zarr(g2, consolidate_metadata=False)
assert_equal(ad.read_zarr(g1), adata1)
assert_equal(ad.read_zarr(g2), adata2)
Loading