From c4a362d6ce604dd1070e3f46d4d67d004259b21a Mon Sep 17 00:00:00 2001 From: Ekin-Kahraman Date: Sat, 2 May 2026 13:14:10 +0100 Subject: [PATCH 1/7] feat(concat): aligned_axis_key_join for on-axis key joining Add a parameter to ad.concat() that separates how on-axis keys are joined (obs/var columns; obsm/obsp or varm/varp keys) from how off-axis indices are aligned. Default None falls back to join for backward compatibility. Closes #2374 --- src/anndata/_core/merge.py | 198 ++++++++++++++++++++-- tests/test_concatenate.py | 327 +++++++++++++++++++++++++++++++++++++ 2 files changed, 512 insertions(+), 13 deletions(-) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index 0bff842ae..0e88fc6ba 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -1136,6 +1136,115 @@ def outer_concat_aligned_mapping( return result +def _concat_aligned_mapping_split_join( # noqa: PLR0913 + mappings, + *, + key_join: Join_T, + content_join: Join_T, + fill_value=None, + axis=0, + concat_axis=None, + index=None, + force_lazy: bool = False, +): + """Concatenate aligned mappings (obsm/varm style) with separate key and + content joins. ``key_join`` selects which keys appear in the result; + ``content_join`` selects how shared keys' values are aligned along the + off-axis dimension (e.g. inner intersects DataFrame columns, outer unions + them). Used for ``concat(aligned_axis_key_join=...)`` when the on-axis key + join differs from the off-axis ``join``. + """ + if concat_axis is None: + concat_axis = axis + keys = union_keys(mappings) if key_join == "outer" else intersect_keys(mappings) + ns = [m.parent.shape[axis] for m in mappings] + + result = {} + for k in keys: + els = [m.get(k, MissingVal) for m in mappings] + any_missing = any(is_missing(el) for el in els) + present_els = [el for el in els if not_missing(el)] + + if content_join == "inner": + # Inner content alignment: intersect the off-axis dimension among + # values that are actually present, then reindex everything to that + # intersection. Missing entries get a filler matching the shape so + # the downstream concat can stack them. + if any_missing and any(isinstance(el, AwkArray) for el in present_els): + msg = ( + "Combining `aligned_axis_key_join` with `join='inner'` is " + "not yet implemented for awkward arrays in `obsm`/`varm` " + "when the key is missing from at least one input. Use the " + "same value for `join` and `aligned_axis_key_join`, or " + "drop the affected awkward entries before concatenating." + ) + raise NotImplementedError(msg) + if all(isinstance(el, pd.DataFrame) for el in present_els): + common_cols = reduce( + lambda x, y: x.intersection(y), + (el.columns for el in present_els), + ) + cur_reindexers = [ + Reindexer(el.columns, common_cols) + if not_missing(el) + else ( + lambda _, n=n, cols=common_cols, fv=fill_value: pd.DataFrame( + np.nan if fv is None else fv, + index=range(n), + columns=cols, + ) + ) + for el, n in zip(els, ns, strict=True) + ] + # Use an empty filler so concat_arrays' DataFrame check passes; + # the lambda reindexers above replace these with proper DataFrames. + off_axis_size = 0 + else: + inner_present = gen_inner_reindexers( + present_els, new_index=index, axis=concat_axis + ) + target_idx = inner_present[0].new_idx + present_iter = iter(inner_present) + cur_reindexers = [ + next(present_iter) + if not_missing(el) + else Reindexer(target_idx, target_idx) + for el in els + ] + off_axis_size = len(target_idx) + else: + cur_reindexers = gen_outer_reindexers( + els, ns, new_index=index, axis=concat_axis + ) + off_axis_size = 0 + if any(isinstance(e, DaskArray) for e in els if not_missing(e)): + if not isinstance(cur_reindexers[0], Reindexer): # pragma: no cover + msg = "Cannot re-index a dask array without a Reindexer" + raise ValueError(msg) + off_axis_size = cur_reindexers[0].idx.shape[0] + + result[k] = concat_arrays( + [ + el + if not_missing(el) + else missing_element( + n, + axis=concat_axis, + els=els, + fill_value=fill_value, + off_axis_size=off_axis_size, + ) + for el, n in zip(els, ns, strict=True) + ], + cur_reindexers, + axis=concat_axis, + index=index, + fill_value=fill_value if any_missing else None, + force_lazy=force_lazy, + ) + return result + + def concat_pairwise_mapping( mappings: Collection[Mapping], shapes: Collection[int], join_keys=intersect_keys ): @@ -1447,6 +1556,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 *, axis: Literal["obs", 0, "var", 1] = "obs", join: Join_T = "inner", + aligned_axis_key_join: Join_T | None = None, merge: StrategiesLiteral | Callable | None = None, uns_merge: StrategiesLiteral | Callable | None = None, label: str | None = None, @@ -1471,6 +1581,12 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 How to align values when concatenating. If "outer", the union of the other axis is taken. If "inner", the intersection. See :doc:`concatenation <../concatenation>` for more. + aligned_axis_key_join + How to join keys on the *concatenation axis* itself: columns of `obs`/`var`, + and keys of `obsm`/`obsp` (or `varm`/`varp` when concatenating along `axis="var"`). + Use "outer" to take the union of these keys, "inner" to take the intersection. + Defaults to `None`, in which case `join` is used for both the off-axis index + alignment and the on-axis key join (the historical behaviour). merge How elements not aligned to the axis being concatenated along are selected. Currently implemented strategies include: @@ -1651,6 +1767,16 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 {'c': {'c.c': 5}} >>> dict(ad.concat([a, b, c], uns_merge="first").uns) {'a': 1, 'b': 2, 'c': {'c.a': 3, 'c.b': 4, 'c.c': 5}} + + `aligned_axis_key_join` controls on-axis key joining (obs columns, + obsm/obsp keys) independently of the off-axis index `join`. The default + of `None` falls back to `join`, preserving existing behaviour. To keep + the union of `var` indices but the intersection of obs columns: + + >>> ad.concat([a, b], join="outer", aligned_axis_key_join="inner").obs.columns.tolist() + ['group'] + >>> ad.concat([a, b], join="outer").obs.columns.tolist() + ['group', 'measure'] """ from anndata._core.xarray import Dataset2D @@ -1660,6 +1786,21 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 merge = resolve_merge_strategy(merge) uns_merge = resolve_merge_strategy(uns_merge) + # Resolve the on-axis key join. When aligned_axis_key_join is None, the + # historical behaviour applies and `join` controls both off-axis index + # alignment and on-axis key joining. When set, it overrides the on-axis + # key joining for obs/obsm/obsp (or var/varm/varp when axis="var"). + if aligned_axis_key_join is not None and aligned_axis_key_join not in ( + "inner", + "outer", + ): + msg = ( + f"`aligned_axis_key_join` must be one of 'inner', 'outer', or None, " + f"got {aligned_axis_key_join!r}" + ) + raise ValueError(msg) + aligned_join = aligned_axis_key_join if aligned_axis_key_join is not None else join + if isinstance(adatas, Mapping): if keys is not None: msg = ( @@ -1701,7 +1842,9 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 ] # Annotation for concatenation axis - check_combinable_cols([getattr(a, axis_name).columns for a in adatas], join=join) + check_combinable_cols( + [getattr(a, axis_name).columns for a in adatas], join=aligned_join + ) annotations = [getattr(a, axis_name) for a in adatas] are_any_annotations_dataframes = any( isinstance(a, pd.DataFrame) for a in annotations @@ -1712,14 +1855,14 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 ) concat_annot = pd.concat( unify_dtypes(annotations_in_memory), - join=join, + join=aligned_join, ignore_index=True, ) concat_annot.index = concat_indices else: concat_annot = concat_dataset2d_on_annot_axis( annotations, - join, + aligned_join, force_lazy=force_lazy, concat_indices=concat_indices, ) @@ -1758,33 +1901,61 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 X = concat_Xs(adatas, reindexers, axis=axis, fill_value=fill_value) + # Helper bindings for off-axis-shaped objects (layers): keep using `join`. if join == "inner": concat_aligned_mapping = inner_concat_aligned_mapping - join_keys = intersect_keys elif join == "outer": concat_aligned_mapping = partial( outer_concat_aligned_mapping, fill_value=fill_value ) - join_keys = union_keys else: msg = f"{join=} should have been validated above by pd.concat" raise AssertionError(msg) + # Pairwise key joining (obsp/varp) is purely about which keys appear in + # the result; the per-key block-diagonal does not have a content-alignment + # axis, so a single setting is sufficient here. + if aligned_join == "inner": + aligned_join_keys = intersect_keys + elif aligned_join == "outer": + aligned_join_keys = union_keys + else: + msg = f"{aligned_join=} should have been validated" + raise AssertionError(msg) + layers = concat_aligned_mapping( [a.layers for a in adatas], axis=axis, reindexers=reindexers ) - concat_mapping = concat_aligned_mapping( - [getattr(a, f"{axis_name}m") for a in adatas], - axis=axis, - concat_axis=0, - index=concat_indices, - force_lazy=force_lazy, - ) + + # obsm/varm: aligned_axis_key_join controls which keys appear, while the + # content alignment (e.g. DataFrame columns within a shared key) follows + # `join`. When the two settings agree we can use the existing helpers; + # when they diverge we use a split-join helper. + obsm_mappings = [getattr(a, f"{axis_name}m") for a in adatas] + if aligned_join == join: + concat_mapping = concat_aligned_mapping( + obsm_mappings, + axis=axis, + concat_axis=0, + index=concat_indices, + force_lazy=force_lazy, + ) + else: + concat_mapping = _concat_aligned_mapping_split_join( + obsm_mappings, + key_join=aligned_join, + content_join=join, + fill_value=fill_value, + axis=axis, + concat_axis=0, + index=concat_indices, + force_lazy=force_lazy, + ) if pairwise: concat_pairwise = concat_pairwise_mapping( mappings=[getattr(a, f"{axis_name}p") for a in adatas], shapes=[a.shape[axis] for a in adatas], - join_keys=join_keys, + join_keys=aligned_join_keys, ) else: concat_pairwise = {} @@ -1816,6 +1987,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 for a in adatas ], join=join, + aligned_axis_key_join=aligned_axis_key_join, label=label, keys=keys, index_unique=index_unique, diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index 2fa87d651..8e3603c37 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -1875,3 +1875,330 @@ def test_1d_concat(): adata = AnnData(np.ones((5, 20)), obsm={"1d-array": np.ones(5)}) concated = concat([adata, adata]) assert concated.obsm["1d-array"].shape == (10, 1) + + +# ------------------------------------------------------------------ +# aligned_axis_key_join (issue #2374) +# ------------------------------------------------------------------ + + +def _adatas_with_partial_overlap_along_axis(axis_name): + """Two AnnData with different on-axis annotation columns and aligned-mapping keys.""" + if axis_name == "var": + obs_a = pd.DataFrame(index=["row1", "row2"]) + obs_b = pd.DataFrame(index=["row1", "row2"]) + var_a = pd.DataFrame( + {"shared": ["a", "b"], "only_a": [1, 2]}, index=["v1", "v2"] + ) + var_b = pd.DataFrame( + {"shared": ["c", "d"], "only_b": [3, 4]}, index=["v3", "v4"] + ) + else: + obs_a = pd.DataFrame( + {"shared": ["a", "b"], "only_a": [1, 2]}, index=["s1", "s2"] + ) + obs_b = pd.DataFrame( + {"shared": ["c", "d"], "only_b": [3, 4]}, index=["s3", "s4"] + ) + var_a = pd.DataFrame(index=["v1", "v2"]) + var_b = pd.DataFrame(index=["v1", "v2"]) + + a = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=obs_a, + var=var_a, + **{ + f"{axis_name}m": { + "shared_m": np.ones((2, 3)), + "only_a_m": np.zeros((2, 3)), + } + }, + ) + b = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=obs_b, + var=var_b, + **{ + f"{axis_name}m": { + "shared_m": 2 * np.ones((2, 3)), + "only_b_m": np.ones((2, 3)), + } + }, + ) + return a, b + + +@pytest.mark.parametrize("axis_name", ["obs", "var"]) +def test_aligned_axis_key_join_default_falls_back_to_join(axis_name): + """When `aligned_axis_key_join=None`, the on-axis behaviour matches `join`.""" + a, b = _adatas_with_partial_overlap_along_axis(axis_name) + axis = 0 if axis_name == "obs" else 1 + + inner = concat([a, b], axis=axis) + inner_explicit = concat([a, b], axis=axis, aligned_axis_key_join=None) + assert list(getattr(inner, axis_name).columns) == list( + getattr(inner_explicit, axis_name).columns + ) + assert list(getattr(inner, f"{axis_name}m").keys()) == list( + getattr(inner_explicit, f"{axis_name}m").keys() + ) + # Inner is the historical default, only "shared" should remain. + assert list(getattr(inner, axis_name).columns) == ["shared"] + assert list(getattr(inner, f"{axis_name}m").keys()) == ["shared_m"] + + +@pytest.mark.parametrize("axis_name", ["obs", "var"]) +def test_aligned_axis_key_join_outer_with_inner_join(axis_name): + """`aligned_axis_key_join="outer"` unions on-axis keys while leaving off-axis as inner.""" + a, b = _adatas_with_partial_overlap_along_axis(axis_name) + axis = 0 if axis_name == "obs" else 1 + + res = concat([a, b], axis=axis, join="inner", aligned_axis_key_join="outer") + cols = list(getattr(res, axis_name).columns) + keys = list(getattr(res, f"{axis_name}m").keys()) + assert set(cols) == {"shared", "only_a", "only_b"} + assert set(keys) == {"shared_m", "only_a_m", "only_b_m"} + + +@pytest.mark.parametrize("axis_name", ["obs", "var"]) +def test_aligned_axis_key_join_inner_with_outer_join(axis_name): + """`aligned_axis_key_join="inner"` intersects on-axis keys while leaving off-axis as outer.""" + a, b = _adatas_with_partial_overlap_along_axis(axis_name) + axis = 0 if axis_name == "obs" else 1 + + res = concat([a, b], axis=axis, join="outer", aligned_axis_key_join="inner") + cols = list(getattr(res, axis_name).columns) + keys = list(getattr(res, f"{axis_name}m").keys()) + assert cols == ["shared"] + assert keys == ["shared_m"] + + +def test_aligned_axis_key_join_does_not_affect_layers(): + """Layers key-joining follows `join`, not `aligned_axis_key_join`.""" + a = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s1", "s2"]), + var=pd.DataFrame(index=["v1", "v2"]), + layers={ + "shared_layer": np.ones((2, 2)), + "only_a_layer": np.zeros((2, 2)), + }, + ) + b = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s3", "s4"]), + var=pd.DataFrame(index=["v1", "v2"]), + layers={ + "shared_layer": 2 * np.ones((2, 2)), + "only_b_layer": np.ones((2, 2)), + }, + ) + + # join="inner" + aligned="outer": layers stay intersection. + res = concat([a, b], join="inner", aligned_axis_key_join="outer") + assert sorted(res.layers.keys()) == ["shared_layer"] + + # join="outer" + aligned="inner": layers go union. + res = concat([a, b], join="outer", aligned_axis_key_join="inner") + assert sorted(res.layers.keys()) == [ + "only_a_layer", + "only_b_layer", + "shared_layer", + ] + + +def test_aligned_axis_key_join_does_not_affect_alt_axis_mappings(): + """When concatenating along obs, varm/varp follow `merge`, not `aligned_axis_key_join`.""" + a = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s1", "s2"]), + var=pd.DataFrame(index=["v1", "v2"]), + varm={ + "shared_varm": np.ones((2, 3)), + "only_a_varm": np.zeros((2, 3)), + }, + ) + b = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s3", "s4"]), + var=pd.DataFrame(index=["v1", "v2"]), + varm={ + "shared_varm": 2 * np.ones((2, 3)), + "only_b_varm": np.ones((2, 3)), + }, + ) + + # The contract: aligned_axis_key_join only governs the on-axis (obs here). + # varm is alt-axis, controlled entirely by `merge`. Compare results + # produced with and without aligned_axis_key_join for several merge + # strategies — varm contents must be identical. + for merge_strategy in (None, "same", "first", "unique", "only"): + kwargs = {"axis": "obs", "merge": merge_strategy} + baseline = concat([a, b], **kwargs) + with_inner = concat([a, b], aligned_axis_key_join="inner", **kwargs) + with_outer = concat([a, b], aligned_axis_key_join="outer", **kwargs) + assert sorted(baseline.varm.keys()) == sorted(with_inner.varm.keys()), ( + f"varm keys diverged under merge={merge_strategy!r}" + ) + assert sorted(baseline.varm.keys()) == sorted(with_outer.varm.keys()), ( + f"varm keys diverged under merge={merge_strategy!r}" + ) + + +def test_aligned_axis_key_join_obsp_pairwise(): + """Pairwise on-axis (obsp) key joining responds to `aligned_axis_key_join`.""" + a = AnnData( + X=sparse.csr_matrix(np.eye(3, dtype=np.float64)), + obs=pd.DataFrame(index=["s1", "s2", "s3"]), + obsp={ + "shared_obsp": sparse.csr_matrix(np.eye(3, dtype=np.float64)), + "only_a_obsp": sparse.csr_matrix(np.eye(3, dtype=np.float64)), + }, + ) + b = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s4", "s5"]), + obsp={ + "shared_obsp": sparse.csr_matrix(np.eye(2, dtype=np.float64)), + "only_b_obsp": sparse.csr_matrix(np.eye(2, dtype=np.float64)), + }, + ) + outer = concat([a, b], pairwise=True, aligned_axis_key_join="outer") + inner = concat([a, b], pairwise=True, aligned_axis_key_join="inner") + assert sorted(outer.obsp.keys()) == ["only_a_obsp", "only_b_obsp", "shared_obsp"] + assert sorted(inner.obsp.keys()) == ["shared_obsp"] + + +def test_aligned_axis_key_join_invalid_value(): + """Invalid `aligned_axis_key_join` raises a clear error.""" + a = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s1", "s2"]), + var=pd.DataFrame(index=["v1", "v2"]), + ) + with pytest.raises(ValueError, match="aligned_axis_key_join"): + concat([a, a], aligned_axis_key_join="banana") + + +def test_aligned_axis_key_join_obsm_dataframe_content_follows_join(): + """When `aligned_axis_key_join` differs from `join`, the on-axis key set + follows `aligned_axis_key_join` but the per-key content alignment (e.g. + DataFrame columns inside a shared `obsm[k]`) follows `join`. + """ + a = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s1", "s2"]), + var=pd.DataFrame(index=["v1", "v2"]), + obsm={ + "df": pd.DataFrame({"x": [1, 2], "a_only": [10, 20]}, index=["s1", "s2"]) + }, + ) + b = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s3", "s4"]), + var=pd.DataFrame(index=["v1", "v2"]), + obsm={ + "df": pd.DataFrame({"x": [3, 4], "b_only": [30, 40]}, index=["s3", "s4"]) + }, + ) + + # join="inner" + aligned="outer": obsm key set unioned (still just "df"), + # but the columns inside df should stay intersected. + res = concat([a, b], join="inner", aligned_axis_key_join="outer") + assert list(res.obsm["df"].columns) == ["x"] + + # join="outer" + aligned="inner": obsm key set intersected (still "df"), + # df columns should be unioned per join="outer". + res = concat([a, b], join="outer", aligned_axis_key_join="inner") + assert sorted(res.obsm["df"].columns) == ["a_only", "b_only", "x"] + + +def test_aligned_axis_key_join_inner_content_with_missing_key(): + """When `join="inner"` and `aligned_axis_key_join="outer"` with 3+ + inputs and a key present in only some, content alignment intersects + among the present values rather than unioning them. + """ + a = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s1", "s2"]), + var=pd.DataFrame(index=["v1", "v2"]), + obsm={ + "df": pd.DataFrame({"x": [1, 2], "a_only": [10, 20]}, index=["s1", "s2"]) + }, + ) + b = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s3", "s4"]), + var=pd.DataFrame(index=["v1", "v2"]), + obsm={ + "df": pd.DataFrame({"x": [3, 4], "b_only": [30, 40]}, index=["s3", "s4"]) + }, + ) + c = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s5", "s6"]), + var=pd.DataFrame(index=["v1", "v2"]), + ) + + # DataFrame obsm: inner intersection among present DataFrames is just "x" + res = concat([a, b, c], aligned_axis_key_join="outer") + assert list(res.obsm["df"].columns) == ["x"] + + # ndarray obsm with mismatched widths: inner = min width + a_arr = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s1", "s2"]), + var=pd.DataFrame(index=["v1", "v2"]), + obsm={"arr": np.ones((2, 3))}, + ) + b_arr = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s3", "s4"]), + var=pd.DataFrame(index=["v1", "v2"]), + obsm={"arr": np.ones((2, 5))}, + ) + c_empty = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s5", "s6"]), + var=pd.DataFrame(index=["v1", "v2"]), + ) + res_arr = concat([a_arr, b_arr, c_empty], aligned_axis_key_join="outer") + assert res_arr.obsm["arr"].shape == (6, 3) + + +def test_aligned_axis_key_join_reproduces_issue_2374_example(): + """Direct reproduction of the example from issue #2374. Verifies that + `aligned_axis_key_join` controls on-axis annotation columns (obs/var) + independently from `join`, which keeps controlling the off-axis index + (var_names when axis=0, obs_names when axis=1). + """ + adatas = [ + AnnData( + X=np.ones((1, 2)), + obs=pd.DataFrame({"1": ["a"], "2": ["b"]}), + var=pd.DataFrame(index=["1", "2"]), + ), + AnnData( + X=np.ones((1, 2)), + obs=pd.DataFrame({"1": ["a"], "3": ["b"]}), + var=pd.DataFrame(index=["1", "3"]), + ), + ] + + # Existing behaviour (unchanged): join controls both axes. + r = concat(adatas, join="inner") + assert list(r.obs.columns) == ["1"] + assert list(r.var_names) == ["1"] + r = concat(adatas, join="outer") + assert sorted(r.obs.columns) == ["1", "2", "3"] + assert sorted(r.var_names) == ["1", "2", "3"] + + # New behaviour: outer on-axis columns, inner off-axis index. + r = concat(adatas, join="inner", aligned_axis_key_join="outer") + assert sorted(r.obs.columns) == ["1", "2", "3"] + assert list(r.var_names) == ["1"] + + # Converse: inner on-axis columns, outer off-axis index. + r = concat(adatas, join="outer", aligned_axis_key_join="inner") + assert list(r.obs.columns) == ["1"] + assert sorted(r.var_names) == ["1", "2", "3"] From aa6a37fc97d4620b4f426ed2e41a6b2d0b92f3d6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 3 May 2026 18:40:22 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anndata/_core/merge.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index 0e88fc6ba..e01cefb06 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -1773,7 +1773,9 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 of `None` falls back to `join`, preserving existing behaviour. To keep the union of `var` indices but the intersection of obs columns: - >>> ad.concat([a, b], join="outer", aligned_axis_key_join="inner").obs.columns.tolist() + >>> ad.concat( + ... [a, b], join="outer", aligned_axis_key_join="inner" + ... ).obs.columns.tolist() ['group'] >>> ad.concat([a, b], join="outer").obs.columns.tolist() ['group', 'measure'] From a819bf6a3468036ef00b9c9c9711ab99e166c1a0 Mon Sep 17 00:00:00 2001 From: Ekin-Kahraman Date: Sun, 3 May 2026 20:07:09 +0100 Subject: [PATCH 3/7] test(concat): cover awkward+missing+inner NotImplementedError branch Add coverage for the unimplemented branch in _concat_aligned_mapping_split_join where awkward arrays meet inner content-join with missing keys, lifting codecov/patch above target. --- tests/test_concatenate.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index 8e3603c37..072c36020 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -2045,6 +2045,32 @@ def test_aligned_axis_key_join_does_not_affect_alt_axis_mappings(): ) +def test_aligned_axis_key_join_awkward_inner_missing_key_raises(): + """Awkward arrays with inner content-join + missing keys raises NotImplementedError.""" + import awkward as ak + + a = AnnData( + X=np.eye(2, dtype=np.float64), + obs=pd.DataFrame(index=["s1", "s2"]), + var=pd.DataFrame(index=["v1", "v2"]), + obsm={ + "shared_awk": ak.Array([[1, 2], [3]]), + "only_a_awk": ak.Array([[4], [5, 6]]), + }, + ) + b = AnnData( + X=np.eye(2, dtype=np.float64), + obs=pd.DataFrame(index=["s3", "s4"]), + var=pd.DataFrame(index=["v1", "v2"]), + obsm={"shared_awk": ak.Array([[7, 8], [9]])}, + ) + # join="inner" forces inner content-join; aligned_axis_key_join="outer" + # forces outer key set, so `only_a_awk` is missing from `b`. The combo of + # awkward + missing + inner content is the unimplemented branch. + with pytest.raises(NotImplementedError, match="awkward"): + concat([a, b], join="inner", aligned_axis_key_join="outer") + + def test_aligned_axis_key_join_obsp_pairwise(): """Pairwise on-axis (obsp) key joining responds to `aligned_axis_key_join`.""" a = AnnData( From a7bbd0c5418f0aa5261267b0fce3d5f44fb81a03 Mon Sep 17 00:00:00 2001 From: Ekin-Kahraman Date: Tue, 5 May 2026 21:27:21 +0100 Subject: [PATCH 4/7] refactor(concat): reuse inner/outer helpers + extend aligned_axis_key_join to layers Addresses the two review points on #2416: 1. Reuse the existing helpers instead of a separate split helper. Adds `keys: Iterable | None = None` to `inner_concat_aligned_mapping` and `outer_concat_aligned_mapping`. Default behaviour is unchanged (`intersect_keys` / `union_keys` respectively). When the caller passes an explicit `keys` set, the iterated key set is overridden; `inner_concat_aligned_mapping` additionally takes `fill_value` and handles entries missing from a subset of mappings (the outer-key + inner-content combination). Drops `_concat_aligned_mapping_split_join`; the obsm/varm callsite in `concat()` collapses to a single dispatch. 2. Layers now respect aligned_axis_key_join. The on-axis layer-name set is controlled by `aligned_axis_key_join`; the off-axis (alt-axis) alignment of each kept layer still follows `join` via the precomputed X-axis reindexers. Subtlety: in the `join="inner"` + `aligned_axis_key_join="outer"` path, the missing-key inner branch must honour the caller's reindexers rather than regenerating from `present_els` only. The present-only path would intersect over the present subset, leaving one-sided layers at their original alt-axis width and breaking the AnnData invariant that every layer shares X's alt-axis. The helper now takes the precomputed `reindexers[i]` for present entries and inserts an identity Reindexer for missing ones; the `missing_element` filler uses the matching alt-axis size. Tests: replaces `test_aligned_axis_key_join_does_not_affect_layers` with two tests asserting layers do follow the new contract: - `outer` key join + `inner` content: layer names unioned, all kept layers shaped to the inner alt-axis intersection (with content spot-checked for fill behaviour); - `inner` key join + `outer` content: layer names intersected, kept layer shaped to the outer alt-axis union. Existing 13 obsm/varm/obsp tests unchanged. Default `aligned_axis_key_join=None` still routes to the historical single-knob behaviour. `concat_on_disk` remains out of scope. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/anndata/_core/merge.py | 305 +++++++++++++++++++------------------ tests/test_concatenate.py | 88 +++++++++-- 2 files changed, 232 insertions(+), 161 deletions(-) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index e01cefb06..3c51ca30c 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -948,7 +948,7 @@ def concat_arrays( # noqa: PLR0911, PLR0912 ) -def inner_concat_aligned_mapping( +def inner_concat_aligned_mapping( # noqa: PLR0913 mappings, *, reindexers=None, @@ -956,22 +956,126 @@ def inner_concat_aligned_mapping( axis=0, concat_axis=None, force_lazy: bool = False, + keys=None, + fill_value=None, ): + """Inner-content concat of aligned mappings. + + By default iterates ``intersect_keys(mappings)``. Pass ``keys`` to override + the iterated key set (e.g. ``union_keys(mappings)`` for an outer key join + paired with inner content alignment); missing entries are then filled + with ``fill_value``. + """ if concat_axis is None: concat_axis = axis + if keys is None: + keys = intersect_keys(mappings) + result = {} + ns = [m.parent.shape[axis] for m in mappings] - for k in intersect_keys(mappings): - els = [m[k] for m in mappings] - if reindexers is None: - cur_reindexers = gen_inner_reindexers( - els, new_index=index, axis=concat_axis + for k in keys: + els = [m.get(k, MissingVal) for m in mappings] + any_missing = any(is_missing(el) for el in els) + + if not any_missing: + # All keys present in all mappings — the default + # ``intersect_keys`` path. No fill_value handling needed. + if reindexers is None: + cur_reindexers = gen_inner_reindexers( + els, new_index=index, axis=concat_axis + ) + else: + cur_reindexers = reindexers + result[k] = concat_arrays( + els, + cur_reindexers, + index=index, + axis=concat_axis, + force_lazy=force_lazy, + ) + continue + + # Missing-key path: only reachable when caller passes an explicit + # ``keys`` set wider than ``intersect_keys`` (the outer-key + inner- + # content combination from ``concat(aligned_axis_key_join=...)``). + present_els = [el for el in els if not_missing(el)] + + if any(isinstance(el, AwkArray) for el in present_els): + msg = ( + "Combining `aligned_axis_key_join` with `join='inner'` is " + "not yet implemented for awkward arrays in `obsm`/`varm` " + "when the key is missing from at least one input. Use the " + "same value for `join` and `aligned_axis_key_join`, or " + "drop the affected awkward entries before concatenating." + ) + raise NotImplementedError(msg) + + if reindexers is not None: + # Caller-provided reindexers already encode the alt-axis alignment + # across *all* mappings (e.g. the gene-axis intersection for + # ``layers`` when ``join="inner"``). The present-only reindexers + # below would intersect over the present subset only, which is + # wrong for layers because their alt-axis must match X's. Honour + # the caller's reindexers and drop in an identity reindexer for + # missing entries; the filler created by ``missing_element`` + # below uses the matching alt-axis size. + target_idx = reindexers[0].new_idx + cur_reindexers = [ + reindexers[i] if not_missing(el) else Reindexer(target_idx, target_idx) + for i, el in enumerate(els) + ] + off_axis_size = len(target_idx) + elif all(isinstance(el, pd.DataFrame) for el in present_els): + common_cols = reduce( + lambda x, y: x.intersection(y), + (el.columns for el in present_els), ) + cur_reindexers = [ + Reindexer(el.columns, common_cols) + if not_missing(el) + else ( + lambda _, n=n, cols=common_cols, fv=fill_value: pd.DataFrame( + np.nan if fv is None else fv, + index=range(n), + columns=cols, + ) + ) + for el, n in zip(els, ns, strict=True) + ] + off_axis_size = 0 else: - cur_reindexers = reindexers + inner_present = gen_inner_reindexers( + present_els, new_index=index, axis=concat_axis + ) + target_idx = inner_present[0].new_idx + present_iter = iter(inner_present) + cur_reindexers = [ + next(present_iter) + if not_missing(el) + else Reindexer(target_idx, target_idx) + for el in els + ] + off_axis_size = len(target_idx) result[k] = concat_arrays( - els, cur_reindexers, index=index, axis=concat_axis, force_lazy=force_lazy + [ + el + if not_missing(el) + else missing_element( + n, + axis=concat_axis, + els=els, + fill_value=fill_value, + off_axis_size=off_axis_size, + ) + for el, n in zip(els, ns, strict=True) + ], + cur_reindexers, + axis=concat_axis, + index=index, + fill_value=fill_value, + force_lazy=force_lazy, ) return result @@ -1081,7 +1185,7 @@ def missing_element( return xp.zeros(shape, dtype=bool) -def outer_concat_aligned_mapping( +def outer_concat_aligned_mapping( # noqa: PLR0913 mappings, *, reindexers=None, @@ -1090,13 +1194,23 @@ def outer_concat_aligned_mapping( concat_axis=None, fill_value=None, force_lazy: bool = False, + keys=None, ): + """Outer-content concat of aligned mappings. + + By default iterates ``union_keys(mappings)``. Pass ``keys`` to override + the iterated key set (e.g. ``intersect_keys(mappings)`` for an inner key + join paired with outer content alignment). + """ if concat_axis is None: concat_axis = axis + if keys is None: + keys = union_keys(mappings) + result = {} ns = [m.parent.shape[axis] for m in mappings] - for k in union_keys(mappings): + for k in keys: els = [m.get(k, MissingVal) for m in mappings] if reindexers is None: cur_reindexers = gen_outer_reindexers( @@ -1136,115 +1250,6 @@ def outer_concat_aligned_mapping( return result -def _concat_aligned_mapping_split_join( # noqa: PLR0913 - mappings, - *, - key_join: Join_T, - content_join: Join_T, - fill_value=None, - axis=0, - concat_axis=None, - index=None, - force_lazy: bool = False, -): - """Concatenate aligned mappings (obsm/varm style) with separate key and - content joins. ``key_join`` selects which keys appear in the result; - ``content_join`` selects how shared keys' values are aligned along the - off-axis dimension (e.g. inner intersects DataFrame columns, outer unions - them). Used for ``concat(aligned_axis_key_join=...)`` when the on-axis key - join differs from the off-axis ``join``. - """ - if concat_axis is None: - concat_axis = axis - keys = union_keys(mappings) if key_join == "outer" else intersect_keys(mappings) - ns = [m.parent.shape[axis] for m in mappings] - - result = {} - for k in keys: - els = [m.get(k, MissingVal) for m in mappings] - any_missing = any(is_missing(el) for el in els) - present_els = [el for el in els if not_missing(el)] - - if content_join == "inner": - # Inner content alignment: intersect the off-axis dimension among - # values that are actually present, then reindex everything to that - # intersection. Missing entries get a filler matching the shape so - # the downstream concat can stack them. - if any_missing and any(isinstance(el, AwkArray) for el in present_els): - msg = ( - "Combining `aligned_axis_key_join` with `join='inner'` is " - "not yet implemented for awkward arrays in `obsm`/`varm` " - "when the key is missing from at least one input. Use the " - "same value for `join` and `aligned_axis_key_join`, or " - "drop the affected awkward entries before concatenating." - ) - raise NotImplementedError(msg) - if all(isinstance(el, pd.DataFrame) for el in present_els): - common_cols = reduce( - lambda x, y: x.intersection(y), - (el.columns for el in present_els), - ) - cur_reindexers = [ - Reindexer(el.columns, common_cols) - if not_missing(el) - else ( - lambda _, n=n, cols=common_cols, fv=fill_value: pd.DataFrame( - np.nan if fv is None else fv, - index=range(n), - columns=cols, - ) - ) - for el, n in zip(els, ns, strict=True) - ] - # Use an empty filler so concat_arrays' DataFrame check passes; - # the lambda reindexers above replace these with proper DataFrames. - off_axis_size = 0 - else: - inner_present = gen_inner_reindexers( - present_els, new_index=index, axis=concat_axis - ) - target_idx = inner_present[0].new_idx - present_iter = iter(inner_present) - cur_reindexers = [ - next(present_iter) - if not_missing(el) - else Reindexer(target_idx, target_idx) - for el in els - ] - off_axis_size = len(target_idx) - else: - cur_reindexers = gen_outer_reindexers( - els, ns, new_index=index, axis=concat_axis - ) - off_axis_size = 0 - if any(isinstance(e, DaskArray) for e in els if not_missing(e)): - if not isinstance(cur_reindexers[0], Reindexer): # pragma: no cover - msg = "Cannot re-index a dask array without a Reindexer" - raise ValueError(msg) - off_axis_size = cur_reindexers[0].idx.shape[0] - - result[k] = concat_arrays( - [ - el - if not_missing(el) - else missing_element( - n, - axis=concat_axis, - els=els, - fill_value=fill_value, - off_axis_size=off_axis_size, - ) - for el, n in zip(els, ns, strict=True) - ], - cur_reindexers, - axis=concat_axis, - index=index, - fill_value=fill_value if any_missing else None, - force_lazy=force_lazy, - ) - return result - - def concat_pairwise_mapping( mappings: Collection[Mapping], shapes: Collection[int], join_keys=intersect_keys ): @@ -1583,10 +1588,13 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 for more. aligned_axis_key_join How to join keys on the *concatenation axis* itself: columns of `obs`/`var`, - and keys of `obsm`/`obsp` (or `varm`/`varp` when concatenating along `axis="var"`). - Use "outer" to take the union of these keys, "inner" to take the intersection. - Defaults to `None`, in which case `join` is used for both the off-axis index - alignment and the on-axis key join (the historical behaviour). + keys of `obsm`/`obsp` (or `varm`/`varp` when concatenating along `axis="var"`), + and keys of `layers`. Use "outer" to take the union of these keys, "inner" + to take the intersection. The off-axis content of each value (e.g. the var + index of an obsm DataFrame, or the gene axis of a layer) still follows + `join`. Defaults to `None`, in which case `join` is used for both the + off-axis index alignment and the on-axis key join (the historical + behaviour). merge How elements not aligned to the axis being concatenated along are selected. Currently implemented strategies include: @@ -1903,7 +1911,12 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 X = concat_Xs(adatas, reindexers, axis=axis, fill_value=fill_value) - # Helper bindings for off-axis-shaped objects (layers): keep using `join`. + # `join` controls the *off-axis* content alignment shared by X, layers, + # and obsm-style values within each shared key. `aligned_join` controls + # which keys appear on the on-axis side (obs/var columns, obsm/obsp + # keys, layers keys). The two settings are independent; when + # `aligned_axis_key_join=None`, `aligned_join == join` and behaviour + # reduces to the historical single-knob path. if join == "inner": concat_aligned_mapping = inner_concat_aligned_mapping elif join == "outer": @@ -1925,34 +1938,30 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 msg = f"{aligned_join=} should have been validated" raise AssertionError(msg) + # Layers: `aligned_join` selects which layer keys appear; `reindexers` + # carries the off-axis alignment from X so each kept layer is aligned to + # the same alt-axis as X. + layer_mappings = [a.layers for a in adatas] layers = concat_aligned_mapping( - [a.layers for a in adatas], axis=axis, reindexers=reindexers + layer_mappings, + axis=axis, + reindexers=reindexers, + keys=aligned_join_keys(layer_mappings), ) - # obsm/varm: aligned_axis_key_join controls which keys appear, while the - # content alignment (e.g. DataFrame columns within a shared key) follows - # `join`. When the two settings agree we can use the existing helpers; - # when they diverge we use a split-join helper. + # obsm/varm: aligned_axis_key_join controls which keys appear; the content + # alignment within a shared key follows `join`. The pre-computed + # `aligned_join_keys` selects the on-axis key set; the inner/outer helper + # selected above performs the off-axis content alignment. obsm_mappings = [getattr(a, f"{axis_name}m") for a in adatas] - if aligned_join == join: - concat_mapping = concat_aligned_mapping( - obsm_mappings, - axis=axis, - concat_axis=0, - index=concat_indices, - force_lazy=force_lazy, - ) - else: - concat_mapping = _concat_aligned_mapping_split_join( - obsm_mappings, - key_join=aligned_join, - content_join=join, - fill_value=fill_value, - axis=axis, - concat_axis=0, - index=concat_indices, - force_lazy=force_lazy, - ) + concat_mapping = concat_aligned_mapping( + obsm_mappings, + axis=axis, + concat_axis=0, + index=concat_indices, + force_lazy=force_lazy, + keys=aligned_join_keys(obsm_mappings), + ) if pairwise: concat_pairwise = concat_pairwise_mapping( mappings=[getattr(a, f"{axis_name}p") for a in adatas], diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index 072c36020..4e724e80a 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -1973,39 +1973,101 @@ def test_aligned_axis_key_join_inner_with_outer_join(axis_name): assert keys == ["shared_m"] -def test_aligned_axis_key_join_does_not_affect_layers(): - """Layers key-joining follows `join`, not `aligned_axis_key_join`.""" +def test_aligned_axis_key_join_layer_keys_unioned_with_inner_content(): + """`aligned_axis_key_join="outer"` + `join="inner"` unions layer keys + while aligning each layer's off-axis (var) to the inner intersection. + + Mirrors the obsm contract: which keys appear is on-axis (controlled by + `aligned_axis_key_join`); how each kept value aligns along the alt-axis + is off-axis (controlled by `join`). + """ a = AnnData( - X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + X=np.ones((2, 3), dtype=np.float64), obs=pd.DataFrame(index=["s1", "s2"]), - var=pd.DataFrame(index=["v1", "v2"]), + var=pd.DataFrame(index=["v1", "v2", "v3"]), layers={ - "shared_layer": np.ones((2, 2)), - "only_a_layer": np.zeros((2, 2)), + "shared_layer": np.full((2, 3), 1.0), + "only_a_layer": np.full((2, 3), 7.0), }, ) b = AnnData( - X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + X=np.ones((2, 2), dtype=np.float64), obs=pd.DataFrame(index=["s3", "s4"]), var=pd.DataFrame(index=["v1", "v2"]), layers={ - "shared_layer": 2 * np.ones((2, 2)), - "only_b_layer": np.ones((2, 2)), + "shared_layer": np.full((2, 2), 2.0), + "only_b_layer": np.full((2, 2), 9.0), }, ) - # join="inner" + aligned="outer": layers stay intersection. res = concat([a, b], join="inner", aligned_axis_key_join="outer") - assert sorted(res.layers.keys()) == ["shared_layer"] - # join="outer" + aligned="inner": layers go union. - res = concat([a, b], join="outer", aligned_axis_key_join="inner") + # Outer key join: every layer name appears. assert sorted(res.layers.keys()) == [ "only_a_layer", "only_b_layer", "shared_layer", ] + # Inner content join: alt-axis (var) is the intersection {v1, v2}. + n_total_cells = 4 + n_inner_genes = 2 + assert res.shape == (n_total_cells, n_inner_genes) + for k in ("shared_layer", "only_a_layer", "only_b_layer"): + assert res.layers[k].shape == (n_total_cells, n_inner_genes), ( + f"layer {k!r} should be aligned to inner alt-axis, " + f"got shape {res.layers[k].shape}" + ) + + # Spot-check content. shared_layer is present in both; values 1.0 from a + # and 2.0 from b stack into the inner gene set. + np.testing.assert_array_equal( + np.asarray(res.layers["shared_layer"]), + np.array([[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]), + ) + # only_a_layer is filled (with the missing-element default) for b's rows. + only_a = np.asarray(res.layers["only_a_layer"]) + np.testing.assert_array_equal(only_a[:2], np.array([[7.0, 7.0], [7.0, 7.0]])) + # only_b_layer is filled for a's rows; b's rows carry 9.0. + only_b = np.asarray(res.layers["only_b_layer"]) + np.testing.assert_array_equal(only_b[2:], np.array([[9.0, 9.0], [9.0, 9.0]])) + + +def test_aligned_axis_key_join_layer_keys_intersected_with_outer_content(): + """`aligned_axis_key_join="inner"` + `join="outer"` intersects layer + keys while aligning each kept layer's off-axis (var) to the outer + union. + """ + a = AnnData( + X=np.ones((2, 3), dtype=np.float64), + obs=pd.DataFrame(index=["s1", "s2"]), + var=pd.DataFrame(index=["v1", "v2", "v3"]), + layers={ + "shared_layer": np.full((2, 3), 1.0), + "only_a_layer": np.full((2, 3), 7.0), + }, + ) + b = AnnData( + X=np.ones((2, 2), dtype=np.float64), + obs=pd.DataFrame(index=["s3", "s4"]), + var=pd.DataFrame(index=["v1", "v2"]), + layers={ + "shared_layer": np.full((2, 2), 2.0), + "only_b_layer": np.full((2, 2), 9.0), + }, + ) + + res = concat([a, b], join="outer", aligned_axis_key_join="inner") + + # Inner key join: only the layer present in every input survives. + assert sorted(res.layers.keys()) == ["shared_layer"] + + # Outer content join: alt-axis (var) is the union {v1, v2, v3}. + n_total_cells = 4 + n_outer_genes = 3 + assert res.shape == (n_total_cells, n_outer_genes) + assert res.layers["shared_layer"].shape == (n_total_cells, n_outer_genes) + def test_aligned_axis_key_join_does_not_affect_alt_axis_mappings(): """When concatenating along obs, varm/varp follow `merge`, not `aligned_axis_key_join`.""" From f4ea166b1ebe319b805adba936961d82802d5e86 Mon Sep 17 00:00:00 2001 From: Ekin-Kahraman Date: Tue, 5 May 2026 21:49:53 +0100 Subject: [PATCH 5/7] docs(release-notes): add 2416.feat.md fragment Captures the user-facing addition of `aligned_axis_key_join` for towncrier-style aggregation. Mirrors the format of the other `*.feat.md` fragments under docs/release-notes/. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/release-notes/2416.feat.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/release-notes/2416.feat.md diff --git a/docs/release-notes/2416.feat.md b/docs/release-notes/2416.feat.md new file mode 100644 index 000000000..68cbf55d7 --- /dev/null +++ b/docs/release-notes/2416.feat.md @@ -0,0 +1 @@ +Add `aligned_axis_key_join` to {func}`anndata.concat` for controlling on-axis annotation and aligned-mapping key joins (obs/var columns, obsm/varm and obsp/varp keys, layers keys) independently of the off-axis index alignment controlled by `join` {user}`Ekin-Kahraman` From dac79dd7d3838a0f41a41789e6265771a120c4bc Mon Sep 17 00:00:00 2001 From: Ekin-Kahraman Date: Mon, 11 May 2026 15:55:23 +0100 Subject: [PATCH 6/7] test(concat): update layer-key assertions for X-as-None layer post #1707 #1707 (feat!: Unify X and layers) moved X into the layers mapping under the `None` key. Two aligned_axis_key_join layer-key tests asserted `sorted(res.layers.keys())` against named keys only, which now fails with TypeError when None is present. Switched to set comparison and explicitly included the `None` (X) key in both the outer-union and inner-intersection expectations. --- tests/test_concatenate.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index 55ed74e01..502c95864 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -2018,12 +2018,14 @@ def test_aligned_axis_key_join_layer_keys_unioned_with_inner_content(): res = concat([a, b], join="inner", aligned_axis_key_join="outer") - # Outer key join: every layer name appears. - assert sorted(res.layers.keys()) == [ + # Outer key join: every layer name appears. Post #1707, X is stored + # under the `None` layer key, so it shows up here too. + assert set(res.layers.keys()) == { + None, "only_a_layer", "only_b_layer", "shared_layer", - ] + } # Inner content join: alt-axis (var) is the intersection {v1, v2}. n_total_cells = 4 @@ -2075,8 +2077,10 @@ def test_aligned_axis_key_join_layer_keys_intersected_with_outer_content(): res = concat([a, b], join="outer", aligned_axis_key_join="inner") - # Inner key join: only the layer present in every input survives. - assert sorted(res.layers.keys()) == ["shared_layer"] + # Inner key join: only the layer keys present in every input survive. + # Post #1707, X is stored under the `None` layer key and is present in + # both inputs, so it survives the intersection alongside "shared_layer". + assert set(res.layers.keys()) == {None, "shared_layer"} # Outer content join: alt-axis (var) is the union {v1, v2, v3}. n_total_cells = 4 From c47a97f69f80983f7bc6418bd2086eac1da12c5b Mon Sep 17 00:00:00 2001 From: Ekin-Kahraman Date: Fri, 15 May 2026 21:34:55 +0100 Subject: [PATCH 7/7] refactor(concat): make aligned mapping key joins explicit --- src/anndata/_core/merge.py | 53 ++++--------------- .../multi_files/_anncollection.py | 5 +- 2 files changed, 12 insertions(+), 46 deletions(-) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index d4f6aa0bc..a882348f1 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -10,7 +10,7 @@ from functools import partial, reduce, singledispatch from itertools import repeat from operator import and_, or_, sub -from typing import TYPE_CHECKING, Literal, cast +from typing import TYPE_CHECKING, Literal, cast, get_args import numpy as np import pandas as pd @@ -18,6 +18,7 @@ from scipy import sparse from anndata._core.file_backing import to_memory +from anndata._types import Join_T from anndata._warnings import ExperimentalFeatureWarning from ..compat import ( @@ -35,6 +36,8 @@ from .index import _subset, make_slice from .xarray import Dataset2D +JOIN_OPTIONS = get_args(Join_T.__value__) + if TYPE_CHECKING: from collections.abc import Collection, Generator, Iterable, Sequence from typing import Any @@ -42,8 +45,6 @@ from numpy.typing import NDArray from pandas.api.extensions import ExtensionDtype - from anndata._types import Join_T - from ..compat import XDataArray from ..types import SupportsArrayApi @@ -951,25 +952,17 @@ def concat_arrays( # noqa: PLR0911, PLR0912 def inner_concat_aligned_mapping( # noqa: PLR0913 mappings, *, + keys, reindexers=None, index=None, axis=0, concat_axis=None, force_lazy: bool = False, - keys=None, fill_value=None, ): - """Inner-content concat of aligned mappings. - - By default iterates ``intersect_keys(mappings)``. Pass ``keys`` to override - the iterated key set (e.g. ``union_keys(mappings)`` for an outer key join - paired with inner content alignment); missing entries are then filled - with ``fill_value``. - """ + """Inner-content concat of aligned mappings over an explicit key set.""" if concat_axis is None: concat_axis = axis - if keys is None: - keys = intersect_keys(mappings) result = {} ns = [m.parent.shape[axis] for m in mappings] @@ -1188,24 +1181,17 @@ def missing_element( def outer_concat_aligned_mapping( # noqa: PLR0913 mappings, *, + keys, reindexers=None, index=None, axis=0, concat_axis=None, fill_value=None, force_lazy: bool = False, - keys=None, ): - """Outer-content concat of aligned mappings. - - By default iterates ``union_keys(mappings)``. Pass ``keys`` to override - the iterated key set (e.g. ``intersect_keys(mappings)`` for an inner key - join paired with outer content alignment). - """ + """Outer-content concat of aligned mappings over an explicit key set.""" if concat_axis is None: concat_axis = axis - if keys is None: - keys = union_keys(mappings) result = {} ns = [m.parent.shape[axis] for m in mappings] @@ -1778,10 +1764,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 merge = resolve_merge_strategy(merge) uns_merge = resolve_merge_strategy(uns_merge) - if aligned_axis_key_join is not None and aligned_axis_key_join not in ( - "inner", - "outer", - ): + if aligned_axis_key_join is not None and aligned_axis_key_join not in JOIN_OPTIONS: msg = ( f"`aligned_axis_key_join` must be one of 'inner', 'outer', or None, " f"got {aligned_axis_key_join!r}" @@ -1906,14 +1889,6 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 xr.merge(annotations_with_only_dask, join=join, compat="override") ) - # `join` controls the *off-axis* content alignment shared by layers - # and obsm-style values within each shared key. `aligned_join` controls - # which keys appear on the on-axis side (obs/var columns, obsm/obsp - # keys, layers keys). The two settings are independent; when - # `aligned_axis_key_join=None`, `aligned_join == join` and behaviour - # reduces to the historical single-knob path. Post #1707, X is a - # layer entry rather than a separate field, so the layers mapping - # implicitly carries X through the same `aligned_join` path. if join == "inner": concat_aligned_mapping = inner_concat_aligned_mapping elif join == "outer": @@ -1924,9 +1899,6 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 msg = f"{join=} should have been validated above by pd.concat" raise AssertionError(msg) - # Pairwise key joining (obsp/varp) is purely about which keys appear in - # the result; the per-key block-diagonal does not have a content-alignment - # axis, so a single setting is sufficient here. if aligned_join == "inner": aligned_join_keys = intersect_keys elif aligned_join == "outer": @@ -1935,9 +1907,6 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 msg = f"{aligned_join=} should have been validated" raise AssertionError(msg) - # Layers: `aligned_join` selects which layer keys appear; `reindexers` - # carries the off-axis alignment from X so each kept layer is aligned to - # the same alt-axis as X. layer_mappings = [a.layers for a in adatas] layers = concat_aligned_mapping( layer_mappings, @@ -1946,10 +1915,6 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 keys=aligned_join_keys(layer_mappings), ) - # obsm/varm: aligned_axis_key_join controls which keys appear; the content - # alignment within a shared key follows `join`. The pre-computed - # `aligned_join_keys` selects the on-axis key set; the inner/outer helper - # selected above performs the off-axis content alignment. obsm_mappings = [getattr(a, f"{axis_name}m") for a in adatas] concat_mapping = concat_aligned_mapping( obsm_mappings, diff --git a/src/anndata/experimental/multi_files/_anncollection.py b/src/anndata/experimental/multi_files/_anncollection.py index b1ed27b78..b76208b9d 100644 --- a/src/anndata/experimental/multi_files/_anncollection.py +++ b/src/anndata/experimental/multi_files/_anncollection.py @@ -14,7 +14,7 @@ from ..._core.aligned_mapping import AxisArrays from ..._core.anndata import AnnData from ..._core.index import _normalize_index, _normalize_indices -from ..._core.merge import concat_arrays, inner_concat_aligned_mapping +from ..._core.merge import concat_arrays, inner_concat_aligned_mapping, intersect_keys from ..._core.sparse_dataset import BaseCompressedSparseDataset from ..._core.views import _resolve_idx from ...compat import old_positionals @@ -768,8 +768,9 @@ def __init__( # noqa: PLR0912, PLR0913, PLR0915 if join_obsm == "inner": view_attrs.remove("obsm") self._attrs.append("obsm") + obsm_mappings = [a.obsm for a in adatas] self._obsm = inner_concat_aligned_mapping( - [a.obsm for a in adatas], index=self.obs_names + obsm_mappings, keys=intersect_keys(obsm_mappings), index=self.obs_names ) self._obsm = ( AxisArrays(self, axis=0, store={}) if self._obsm == {} else self._obsm