From c09f35ab366d99b6ed8b73f5d17e4df14e551cf3 Mon Sep 17 00:00:00 2001 From: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> Date: Fri, 4 Oct 2024 15:09:41 +0200 Subject: [PATCH] Improved `concatenate()` for non-unique element names (#720) * accept Iterable in concatenate * concatenate: automatic non-unique names resolution * docs, changelog * add test for len 1 iterable (grst code review) --- CHANGELOG.md | 1 + src/spatialdata/_core/_utils.py | 15 ++- src/spatialdata/_core/concatenate.py | 109 +++++++++++++++--- .../operations/test_spatialdata_operations.py | 49 +++++++- 4 files changed, 154 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bfb99d232..26d8a5141 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning][]. - Added `shortest_path` parameter to `get_transformation_between_coordinate_systems` - Added `get_pyramid_levels()` utils API +- Improved ergonomics of `concatenate()` when element names are non-unique #720 ## [0.2.3] - 2024-09-25 diff --git a/src/spatialdata/_core/_utils.py b/src/spatialdata/_core/_utils.py index 1c22c802a..dd27e9c8d 100644 --- a/src/spatialdata/_core/_utils.py +++ b/src/spatialdata/_core/_utils.py @@ -1,22 +1,27 @@ +from collections.abc import Iterable + from spatialdata._core.spatialdata import SpatialData -def _find_common_table_keys(sdatas: list[SpatialData]) -> set[str]: +def _find_common_table_keys(sdatas: Iterable[SpatialData]) -> set[str]: """ Find table keys present in more than one SpatialData object. Parameters ---------- sdatas - A list of SpatialData objects. + An `Iterable` of SpatialData objects. Returns ------- A set of common keys that are present in the tables of more than one SpatialData object. 
""" - common_keys = set(sdatas[0].tables.keys()) + common_keys: set[str] = set() - for sdata in sdatas[1:]: - common_keys.intersection_update(sdata.tables.keys()) + for sdata in sdatas: + if len(common_keys) == 0: + common_keys = set(sdata.tables.keys()) + else: + common_keys.intersection_update(sdata.tables.keys()) return common_keys diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index 8312d660e..b8548f741 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections import defaultdict +from collections.abc import Iterable from copy import copy # Should probably go up at the top from itertools import chain from typing import Any @@ -11,7 +12,7 @@ from spatialdata._core._utils import _find_common_table_keys from spatialdata._core.spatialdata import SpatialData -from spatialdata.models import TableModel +from spatialdata.models import SpatialElement, TableModel, get_table_keys __all__ = [ "concatenate", @@ -73,10 +74,12 @@ def _concatenate_tables( def concatenate( - sdatas: list[SpatialData], + sdatas: Iterable[SpatialData] | dict[str, SpatialData], region_key: str | None = None, instance_key: str | None = None, concatenate_tables: bool = False, + obs_names_make_unique: bool = True, + modify_tables_inplace: bool = False, **kwargs: Any, ) -> SpatialData: """ @@ -85,36 +88,74 @@ def concatenate( Parameters ---------- sdatas - The spatial data objects to concatenate. + The spatial data objects to concatenate. The names of the elements across the `SpatialData` objects must be + unique. If they are not unique, you can pass a dictionary with the suffixes as keys and the spatial data objects + as values. This will rename the names of each `SpatialElement` to ensure uniqueness of names across + `SpatialData` objects. See more on the notes. region_key The key to use for the region column in the concatenated object. 
- If all region_keys are the same, the `region_key` is used. + If `None` and all region_keys are the same, the `region_key` is used. instance_key The key to use for the instance column in the concatenated object. + If `None` and all instance_keys are the same, the `instance_key` is used. concatenate_tables Whether to merge the tables in case of having the same element name. + obs_names_make_unique + Whether to make the `obs_names` unique by calling `AnnData.obs_names_make_unique()` on each table of the + concatenated object. If you passed a dictionary with the suffixes as keys and the `SpatialData` objects as + values and if `concatenate_tables` is `True`, the `obs_names` will be made unique by adding the corresponding + suffix instead. + modify_tables_inplace + Whether to modify the tables in place. If `True`, the tables will be modified in place. If `False`, the tables + will be copied before modification. Copying is enabled by default but can be disabled for performance reasons. kwargs See :func:`anndata.concat` for more details. Returns ------- The concatenated :class:`spatialdata.SpatialData` object. + + Notes + ----- + If you pass a dictionary with the suffixes as keys and the `SpatialData` objects as values, the names of each + `SpatialElement` will be renamed to ensure uniqueness of names across `SpatialData` objects by adding the + corresponding suffix. To ensure the matching between existing table annotations, the `region` metadata of each + table, and the values of the `region_key` column in each table, will be altered by adding the suffix. In addition, + if both `obs_names_make_unique` and `concatenate_tables` are `True`, a suffix will be added to the `obs_names` of + each table. Finally, a suffix will be added to the name of each table iff `concatenate_tables` is `False`. + + If you need more control in the renaming, please give us feedback, as we are still trying to find the right balance + between ergonomics and control.
Also, you are welcome to copy and adjust the code of + `_fix_ensure_unique_element_names()` directly. """ + if not isinstance(sdatas, Iterable): + raise TypeError("`sdatas` must be a `Iterable`") + + if isinstance(sdatas, dict): + sdatas = _fix_ensure_unique_element_names( + sdatas, + rename_tables=not concatenate_tables, + rename_obs_names=obs_names_make_unique and concatenate_tables, + modify_tables_inplace=modify_tables_inplace, + ) + + ERROR_STR = ( + " must have unique names across the SpatialData objects to concatenate. Please pass a `dict[str, SpatialData]`" + " to `concatenate()` to address this (see docstring)." + ) + merged_images = {**{k: v for sdata in sdatas for k, v in sdata.images.items()}} if len(merged_images) != np.sum([len(sdata.images) for sdata in sdatas]): - raise KeyError("Images must have unique names across the SpatialData objects to concatenate") + raise KeyError("Images" + ERROR_STR) merged_labels = {**{k: v for sdata in sdatas for k, v in sdata.labels.items()}} if len(merged_labels) != np.sum([len(sdata.labels) for sdata in sdatas]): - raise KeyError("Labels must have unique names across the SpatialData objects to concatenate") + raise KeyError("Labels" + ERROR_STR) merged_points = {**{k: v for sdata in sdatas for k, v in sdata.points.items()}} if len(merged_points) != np.sum([len(sdata.points) for sdata in sdatas]): - raise KeyError("Points must have unique names across the SpatialData objects to concatenate") + raise KeyError("Points" + ERROR_STR) merged_shapes = {**{k: v for sdata in sdatas for k, v in sdata.shapes.items()}} if len(merged_shapes) != np.sum([len(sdata.shapes) for sdata in sdatas]): - raise KeyError("Shapes must have unique names across the SpatialData objects to concatenate") - - assert isinstance(sdatas, list), "sdatas must be a list" - assert len(sdatas) > 0, "sdatas must be a non-empty list" + raise KeyError("Shapes" + ERROR_STR) if not concatenate_tables: key_counts: dict[str, int] = defaultdict(int) @@ -124,8 
+165,8 @@ def concatenate( if any(value > 1 for value in key_counts.values()): warn( - "Duplicate table names found. Tables will be added with integer suffix. Set concatenate_tables to True" - "if concatenation is wished for instead.", + "Duplicate table names found. Tables will be added with integer suffix. Set `concatenate_tables` to " + "`True` if concatenation is wished for instead.", UserWarning, stacklevel=2, ) @@ -147,13 +188,17 @@ def concatenate( else: merged_tables[k] = v - return SpatialData( + sdata = SpatialData( images=merged_images, labels=merged_labels, points=merged_points, shapes=merged_shapes, tables=merged_tables, ) + if obs_names_make_unique: + for table in sdata.tables.values(): + table.obs_names_make_unique() + return sdata def _filter_table_in_coordinate_systems(table: AnnData, coordinate_systems: list[str]) -> AnnData: @@ -162,3 +207,39 @@ def _filter_table_in_coordinate_systems(table: AnnData, coordinate_systems: list new_table = table[table.obs[region_key].isin(coordinate_systems)].copy() new_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = new_table.obs[region_key].unique().tolist() return new_table + + +def _fix_ensure_unique_element_names( + sdatas: dict[str, SpatialData], + rename_tables: bool, + rename_obs_names: bool, + modify_tables_inplace: bool, +) -> list[SpatialData]: + elements_by_sdata: list[dict[str, SpatialElement]] = [] + tables_by_sdata: list[dict[str, AnnData]] = [] + for suffix, sdata in sdatas.items(): + elements = {f"{name}-{suffix}": el for _, name, el in sdata.gen_spatial_elements()} + elements_by_sdata.append(elements) + tables = {} + for name, table in sdata.tables.items(): + if not modify_tables_inplace: + table = table.copy() + + # fix the region_key column + region, region_key, _ = get_table_keys(table) + table.obs[region_key] = (table.obs[region_key].astype("str") + f"-{suffix}").astype("category") + table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = f"{region}-{suffix}" + + # fix the obs names + 
if rename_obs_names: + table.obs.index = table.obs.index.to_series().apply(lambda x, suffix=suffix: f"{x}-{suffix}") + + # fix the table name + new_name = f"{name}-{suffix}" if rename_tables else name + tables[new_name] = table + tables_by_sdata.append(tables) + sdatas_fixed = [] + for elements, tables in zip(elements_by_sdata, tables_by_sdata): + sdata = SpatialData.init_from_elements(elements, tables=tables) + sdatas_fixed.append(sdata) + return sdatas_fixed diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 8a59147f1..a0d7ea2cd 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -11,7 +11,14 @@ from spatialdata._core.operations._utils import transform_to_data_extent from spatialdata._core.spatialdata import SpatialData from spatialdata.datasets import blobs -from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel, get_table_keys +from spatialdata.models import ( + Image2DModel, + Labels2DModel, + PointsModel, + ShapesModel, + TableModel, + get_table_keys, +) from spatialdata.testing import assert_elements_dict_are_identical, assert_spatial_data_objects_are_identical from spatialdata.transformations.operations import get_transformation, set_transformation from spatialdata.transformations.transformations import ( @@ -284,6 +291,46 @@ def test_concatenate_sdatas(full_sdata: SpatialData) -> None: assert len(list(concatenated.gen_elements())) == 3 +@pytest.mark.parametrize("concatenate_tables", [True, False]) +@pytest.mark.parametrize("obs_names_make_unique", [True, False]) +def test_concatenate_sdatas_from_iterable(concatenate_tables: bool, obs_names_make_unique: bool) -> None: + sdata0 = blobs() + sdata1 = blobs() + + sdatas = {"sample0": sdata0, "sample1": sdata1} + with pytest.raises(KeyError, match="Images must have unique names across the SpatialData objects"): + _ = 
concatenate( + sdatas.values(), concatenate_tables=concatenate_tables, obs_names_make_unique=obs_names_make_unique + ) + merged = concatenate(sdatas, obs_names_make_unique=obs_names_make_unique, concatenate_tables=concatenate_tables) + + if concatenate_tables: + assert len(merged.tables) == 1 + table = merged["table"] + if obs_names_make_unique: + assert table.obs_names[0] == "1-sample0" + assert table.obs_names[-1] == "30-sample1" + else: + assert table.obs_names[0] == "1" + else: + assert merged["table-sample0"].obs_names[0] == "1" + assert sdata0["table"].obs_names[0] == "1" + + +def test_concatenate_sdatas_single_item() -> None: + sdata = blobs() + + def _n_elements(sdata: SpatialData) -> int: + return len([0 for _, _, _ in sdata.gen_elements()]) + + n = _n_elements(sdata) + assert n == _n_elements(concatenate([sdata])) + assert n == _n_elements(concatenate({"sample": sdata}.values())) + c = concatenate({"sample": sdata}) + assert n == _n_elements(c) + assert "blobs_image-sample" in c.images + + def test_locate_spatial_element(full_sdata: SpatialData) -> None: assert full_sdata.locate_element(full_sdata.images["image2d"])[0] == "images/image2d" im = full_sdata.images["image2d"]