Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix table name and add unit test for colors passed via .uns #413

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ def _set_color_source_vec(
color_source_vector = pd.Categorical(color_source_vector) # convert, e.g., `pd.Series`

color_mapping = _get_categorical_color_mapping(
adata=sdata.table,
adata=sdata.tables[table_name] if table_name is not None else sdata.table,
cluster_key=value_to_plot,
color_source_vector=color_source_vector,
groups=groups,
Expand Down
Binary file modified tests/_images/Labels_can_annotate_labels_with_table_layer.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
60 changes: 58 additions & 2 deletions tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@
import scanpy as sc
from anndata import AnnData
from spatial_image import to_spatial_image
from spatialdata import SpatialData, deepcopy, get_element_instances
from spatialdata import SpatialData, bounding_box_query, deepcopy, get_element_instances
from spatialdata.models import TableModel

import spatialdata_plot # noqa: F401
from tests.conftest import DPI, PlotTester, PlotTesterMeta

RNG = np.random.default_rng(seed=42)
sc.pl.set_rcParams_defaults()
sc.set_figure_params(dpi=DPI, color_map="viridis")
matplotlib.use("agg") # same as GitHub action runner
Expand Down Expand Up @@ -214,7 +213,63 @@ def test_plot_label_categorical_color(self, sdata_blobs: SpatialData):
self._make_tablemodel_with_categorical_labels(sdata_blobs, labels_name="blobs_labels")
sdata_blobs.pl.render_labels("blobs_labels", color="category").pl.show()

def test_plot_label_categorical_color_and_colors_in_uns(self, sdata_blobs: SpatialData):
self._make_tablemodel_with_categorical_labels(sdata_blobs, labels_name="blobs_labels")
# purple, green, yellow
sdata_blobs["other_table"].uns["category_colors"] = ["#800080", "#008000", "#FFFF00"]
# placeholder, otherwise "category_colors" will be ignored
sdata_blobs["other_table"].uns["category"] = "__value__"
sdata_blobs.pl.render_labels("blobs_labels", color="category").pl.show()

def test_plot_label_categorical_color_and_colors_in_uns_query_uns_colors_removed(self, sdata_blobs: SpatialData):
self._make_tablemodel_with_categorical_labels(sdata_blobs, labels_name="blobs_labels")
# purple, green, yellow
sdata_blobs["other_table"].uns["category_colors"] = ["#800080", "#008000", "#FFFF00"]
# placeholder, otherwise "category_colors" will be ignored
sdata_blobs["other_table"].uns["category"] = "__value__"
sdata_blobs = bounding_box_query(
sdata_blobs,
axes=("y", "x"),
min_coordinate=[0, 0],
max_coordinate=[100, 100],
target_coordinate_system="global",
)
# we would expect colors purple and yellow for a and c, but we see default colors blue and orange,
# Reason: "category_colors" is removed by `.filter_by_coordinate_system` in
# `spatialdata_plot.pl.render._render_labels`.
# Why? Because `.bounding_box_query` removes "category_colors" that are not in the query,
# but restores original number of catergories in `.obs["category"]`, see https://github.com/scverse/anndata/issues/997,
# leading to mismatch and removal of "category_colors" by `.filter_by_coordinate_system`.
assert all(sdata_blobs["other_table"].obs["category"].unique() == ["a", "c"])
assert all(sdata_blobs["other_table"].uns["category_colors"] == ["#800080", "#FFFF00"])
# but due to https://github.com/scverse/anndata/issues/997:
assert all(sdata_blobs["other_table"].obs["category"].cat.categories == ["a", "b", "c"])
sdata_blobs.pl.render_labels("blobs_labels", color="category").pl.show()

def test_plot_label_categorical_color_and_colors_in_uns_query_workaround(self, sdata_blobs: SpatialData):
self._make_tablemodel_with_categorical_labels(sdata_blobs, labels_name="blobs_labels")
# purple, green, yellow
sdata_blobs["other_table"].uns["category_colors"] = ["#800080", "#008000", "#FFFF00"]
# placeholder, otherwise "category_colors" will be ignored
sdata_blobs["other_table"].uns["category"] = "__value__"
sdata_blobs = bounding_box_query(
sdata_blobs,
axes=("y", "x"),
min_coordinate=[0, 0],
max_coordinate=[100, 100],
target_coordinate_system="global",
)
assert all(sdata_blobs["other_table"].obs["category"].unique() == ["a", "c"])
assert all(sdata_blobs["other_table"].uns["category_colors"] == ["#800080", "#FFFF00"])
# but due to https://github.com/scverse/anndata/issues/997:
assert all(sdata_blobs["other_table"].obs["category"].cat.categories == ["a", "b", "c"])
sdata_blobs["other_table"].obs["category"] = (
sdata_blobs["other_table"].obs["category"].cat.remove_unused_categories()
)
sdata_blobs.pl.render_labels("blobs_labels", color="category").pl.show()

def _make_tablemodel_with_categorical_labels(self, sdata_blobs, labels_name: str):
RNG = np.random.default_rng(seed=42)
instances = get_element_instances(sdata_blobs[labels_name])
n_obs = len(instances)
adata = AnnData(
Expand All @@ -235,5 +290,6 @@ def _make_tablemodel_with_categorical_labels(self, sdata_blobs, labels_name: str
sdata_blobs["other_table"].obs["category"] = sdata_blobs["other_table"].obs["category"].astype("category")

def test_plot_can_annotate_labels_with_table_layer(self, sdata_blobs: SpatialData):
RNG = np.random.default_rng(seed=42)
sdata_blobs["table"].layers["normalized"] = RNG.random(sdata_blobs["table"].X.shape)
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", table_layer="normalized").pl.show()
Loading