diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py
index 03ac68dad..fac9b82fc 100644
--- a/src/anndata/_core/anndata.py
+++ b/src/anndata/_core/anndata.py
@@ -1945,6 +1945,7 @@ def write_zarr(
         store: StoreLike,
         *,
         chunks: tuple[int, ...] | None = None,
+        consolidate_metadata: bool = True,
         convert_strings_to_categoricals: bool = True,
     ):
         """\
@@ -1956,6 +1957,10 @@ def write_zarr(
             The filename, a :class:`~typing.MutableMapping`, or a Zarr storage class.
         chunks
             Chunk shape.
+        consolidate_metadata
+            Whether to consolidate the Zarr metadata once writing has finished.
+            Pass `False` if the target will be written to again, since stores
+            with consolidated metadata cannot be overwritten or edited.
         convert_strings_to_categoricals
             Convert string columns to categorical.
         """
@@ -1972,6 +1977,7 @@ def write_zarr(
             store,
             self,
             chunks=chunks,
+            consolidate_metadata=consolidate_metadata,
             convert_strings_to_categoricals=convert_strings_to_categoricals,
         )
 
diff --git a/src/anndata/_io/specs/registry.py b/src/anndata/_io/specs/registry.py
index 51726e4e2..19bcf1d71 100644
--- a/src/anndata/_io/specs/registry.py
+++ b/src/anndata/_io/specs/registry.py
@@ -386,14 +386,8 @@ def write_elem(
         if is_consolidated:
             msg = "Cannot overwrite/edit a store with consolidated metadata"
             raise ValueError(msg)
-        if k == "/":
-            if isinstance(store, ZarrGroup):
-                from zarr.core.sync import sync
-
-                sync(store.store.clear())
-            else:
-                store.clear()
-        elif k in store:
+        # Never clear "/": `store` may be a sub-group sharing its store with siblings.
+        if k != "/" and k in store:
             del store[k]
 
         # Normalize array-API (e.g., JAX/CuPy) even if not AnnData
diff --git a/src/anndata/_io/zarr.py b/src/anndata/_io/zarr.py
index 06d16d909..689c9b1a5 100644
--- a/src/anndata/_io/zarr.py
+++ b/src/anndata/_io/zarr.py
@@ -30,6 +30,7 @@ def write_zarr(
     adata: AnnData,
     *,
     chunks: tuple[int, ...] | None = None,
+    consolidate_metadata: bool = True,
     convert_strings_to_categoricals: bool = True,
     **ds_kwargs,
 ) -> None:
@@ -55,7 +56,8 @@ def callback(
         write_func(store, elem_name, elem, dataset_kwargs=dataset_kwargs)
 
     write_dispatched(f, "/", adata, callback=callback, dataset_kwargs=ds_kwargs)
-    zarr.consolidate_metadata(f.store)
+    if consolidate_metadata:
+        zarr.consolidate_metadata(f.store)
 
 
 def read_zarr(store: PathLike[str] | str | MutableMapping | zarr.Group) -> AnnData:
@@ -147,7 +149,10 @@ def open_write_group(
 ) -> zarr.Group:
+    if isinstance(store, zarr.Group):
+        # An already-open group is used as-is; `mode` and `kwargs` do not apply.
+        return store
     if "zarr_format" not in kwargs:
         kwargs["zarr_format"] = settings.zarr_write_format
     return zarr.open_group(store, mode=mode, **kwargs)
 
 
 def is_group_consolidated(group: zarr.Group) -> bool:
diff --git a/src/anndata/experimental/backed/_io.py b/src/anndata/experimental/backed/_io.py
index 2cf33a918..f468ae850 100644
--- a/src/anndata/experimental/backed/_io.py
+++ b/src/anndata/experimental/backed/_io.py
@@ -15,7 +15,7 @@
 from ..._core.xarray import requires_xarray
 from ..._settings import settings
 from ...compat import ZarrGroup
-from ...utils import get_literal_members, warn
+from ...utils import get_literal_members
 from .. import read_dispatched
 
 if TYPE_CHECKING:
@@ -106,13 +106,13 @@ def read_lazy(
         import zarr
 
         if not isinstance(store, ZarrGroup):
-            try:
-                f = zarr.open_consolidated(store, mode="r")
-            except ValueError:
-                msg = "Did not read zarr as consolidated. Consider consolidating your metadata."
-                warn(msg, UserWarning)
-                has_keys = False
-                f = zarr.open_group(store, mode="r")
+            from anndata._io.zarr import is_group_consolidated
+
+            # Open first, then inspect: `is_group_consolidated` needs the opened
+            # group. `zarr.open_group` is expected to load consolidated metadata
+            # when the store provides it (`use_consolidated=None` default).
+            f = zarr.open_group(store, mode="r")
+            has_keys = is_group_consolidated(f)
         else:
             f = store
     elif is_store_arg_h5_store:
diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py
index 3359b2ff8..146748e82 100644
--- a/tests/test_readwrite.py
+++ b/tests/test_readwrite.py
@@ -977,3 +977,52 @@ def test_write_elem_version_mismatch(tmp_path: Path):
     ad.io.write_elem(g, "/", adata)
     adata_roundtripped = ad.read_zarr(g)
     assert_equal(adata_roundtripped, adata)
+
+
+@pytest.mark.zarr_io
+def test_write_zarr_store(tmp_path: Path):
+    """An `AnnData` written to an open `zarr.Group` reads back unchanged."""
+    zarr_path = tmp_path / "foo.zarr"
+    adata = ad.AnnData(np.ones((10, 10)))
+    g = zarr.open_group(
+        zarr_path,
+        mode="w",
+    )
+    adata.write_zarr(g)
+    adata_roundtripped = ad.read_zarr(g)
+    assert_equal(adata_roundtripped, adata)
+
+
+@pytest.mark.zarr_io
+def test_write_zarr_store_overwrite(tmp_path: Path):
+    """A second write to the same unconsolidated group replaces the first."""
+    zarr_path = tmp_path / "foo.zarr"
+    adata1 = ad.AnnData(np.ones((10, 10)))
+    g = zarr.open_group(
+        zarr_path,
+        mode="w",
+    )
+    adata1.write_zarr(g, consolidate_metadata=False)
+    adata2 = ad.AnnData(np.zeros((5, 5)))
+
+    adata2.write_zarr(g, consolidate_metadata=False)
+    adata_roundtripped = ad.read_zarr(g)
+    assert_equal(adata_roundtripped, adata2)
+
+
+@pytest.mark.zarr_io
+def test_write_zarr_store_separate_groups(tmp_path: Path):
+    """Writes to sibling groups of one store do not clobber each other."""
+    zarr_path = tmp_path / "foo.zarr"
+    adata1 = ad.AnnData(np.ones((10, 10)))
+    adata2 = ad.AnnData(np.zeros((5, 5)))
+    g = zarr.open_group(
+        zarr_path,
+        mode="w",
+    )
+    g1 = g.create_group("g1")
+    g2 = g.create_group("g2")
+    adata1.write_zarr(g1, consolidate_metadata=False)
+    adata2.write_zarr(g2, consolidate_metadata=False)
+    assert_equal(ad.read_zarr(g1), adata1)
+    assert_equal(ad.read_zarr(g2), adata2)