Skip to content

Commit

Permalink
unfixed full implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
melonora committed Feb 24, 2024
1 parent 406fb1b commit 7e76133
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 70 deletions.
76 changes: 39 additions & 37 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import OrderedDict
from pathlib import Path
from typing import Any, Union
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -48,11 +49,10 @@
_prepare_cmap_norm,
_prepare_params_plot,
_set_outline,
_update_element_table_mapping_colors,
_update_element_table_mapping_points_shapes_colors,
_validate_colors_element_table_mapping_points_shapes,
_validate_render_params,
_validate_show_parameters,
save_fig,
save_fig, _update_element_table_mapping_label_colors,
)
from spatialdata_plot.pp.utils import _verify_plotting_tree

Expand Down Expand Up @@ -776,24 +776,26 @@ def show(
wanted_elements = []

for cmd, params in render_cmds:
# We create a copy here as the wanted elements can change from one cs to another.
params_copy = deepcopy(params)
if cmd == "render_images" and has_images:
wants_images = True
wanted_images = params.elements if params.elements is not None else list(sdata.images.keys())
wanted_images = params_copy.elements if params_copy.elements is not None else list(sdata.images.keys())
wanted_images_on_this_cs = [
image
for image in wanted_images
if cs in set(get_transformation(sdata.images[image], get_all=True).keys())
]
wanted_elements.extend(wanted_images_on_this_cs)
if wanted_images_on_this_cs:
rasterize = (params.scale is None) or (
isinstance(params.scale, str)
and params.scale != "full"
rasterize = (params_copy.scale is None) or (
isinstance(params_copy.scale, str)
and params_copy.scale != "full"
and (dpi is not None or figsize is not None)
)
_render_images(
sdata=sdata,
render_params=params,
render_params=params_copy,
coordinate_system=cs,
ax=ax,
fig_params=fig_params,
Expand All @@ -804,23 +806,23 @@ def show(

elif cmd == "render_shapes" and has_shapes:
wants_shapes = True
wanted_shapes = params.elements if params.elements is not None else list(sdata.shapes.keys())
wanted_shapes = params_copy.elements if params_copy.elements is not None else list(sdata.shapes.keys())
wanted_shapes_on_this_cs = [
shape
for shape in wanted_shapes
if cs in set(get_transformation(sdata.shapes[shape], get_all=True).keys())
]
if wanted_shapes_on_this_cs:
params = _create_initial_element_table_mapping(sdata, params, wanted_shapes_on_this_cs)
params = _update_element_table_mapping_points_shapes_colors(
sdata, params, wanted_shapes_on_this_cs
params_copy = _create_initial_element_table_mapping(sdata, params_copy, wanted_shapes_on_this_cs)
params_copy = _validate_colors_element_table_mapping_points_shapes(
sdata, params_copy, wanted_shapes_on_this_cs
)

wanted_elements.extend(wanted_shapes_on_this_cs)
if wanted_shapes_on_this_cs:
_render_shapes(
sdata=sdata,
render_params=params,
render_params=params_copy,
coordinate_system=cs,
ax=ax,
fig_params=fig_params,
Expand All @@ -830,24 +832,24 @@ def show(

elif cmd == "render_points" and has_points:
wants_points = True
wanted_points = params.elements if params.elements is not None else list(sdata.points.keys())
wanted_points = params_copy.elements if params_copy.elements is not None else list(sdata.points.keys())
wanted_points_on_this_cs = [
point
for point in wanted_points
if cs in set(get_transformation(sdata.points[point], get_all=True).keys())
]

if wanted_points_on_this_cs:
params = _create_initial_element_table_mapping(sdata, params, wanted_points_on_this_cs)
params = _update_element_table_mapping_points_shapes_colors(
sdata, params, wanted_points_on_this_cs
params_copy = _create_initial_element_table_mapping(sdata, params_copy, wanted_points_on_this_cs)
params_copy = _validate_colors_element_table_mapping_points_shapes(
sdata, params_copy, wanted_points_on_this_cs
)

wanted_elements.extend(wanted_points_on_this_cs)
if wanted_points_on_this_cs:
_render_points(
sdata=sdata,
render_params=params,
render_params=params_copy,
coordinate_system=cs,
ax=ax,
fig_params=fig_params,
Expand All @@ -857,7 +859,7 @@ def show(

elif cmd == "render_labels" and has_labels:
wants_labels = True
wanted_labels = params.elements if params.elements is not None else list(sdata.labels.keys())
wanted_labels = params_copy.elements if params_copy.elements is not None else list(sdata.labels.keys())
wanted_labels_on_this_cs = [
label
for label in wanted_labels
Expand All @@ -867,36 +869,36 @@ def show(

if wanted_labels_on_this_cs:
# Create element to table mapping and check whether specified color columns are in tables.
params = _create_initial_element_table_mapping(sdata, params, wanted_labels_on_this_cs)
params = _update_element_table_mapping_colors(sdata, params, wanted_labels_on_this_cs)

if isinstance(params.color, list):
element_table_mapping = params.element_table_mapping
params.color = (
[params.color[0] if value is not None else None for value in element_table_mapping.values()]
if len(params.color) == 1
else params.color
params_copy = _create_initial_element_table_mapping(sdata, params_copy, wanted_labels_on_this_cs)
params_copy = _update_element_table_mapping_label_colors(sdata, params_copy, wanted_labels_on_this_cs)

if isinstance(params_copy.color, list):
element_table_mapping = params_copy.element_table_mapping
params_copy.color = (
[params_copy.color[0] if value is not None else None for value in element_table_mapping.values()]
if len(params_copy.color) == 1
else params_copy.color
)
for index, table in enumerate(params.element_table_mapping.values()):
for index, table in enumerate(params_copy.element_table_mapping.values()):
if table is None:
continue
colors = sc.get.obs_df(sdata[table], params.color[index])
if isinstance(colors, pd.CategoricalDtype):
colors = sc.get.obs_df(sdata[table], params_copy.color[index])
if isinstance(colors.dtype, pd.CategoricalDtype):
_maybe_set_colors(
source=sdata[table],
target=sdata[table],
key=params.color[index],
palette=params.palette,
key=params_copy.color[index],
palette=params_copy.palette,
)

rasterize = (params.scale is None) or (
isinstance(params.scale, str)
and params.scale != "full"
rasterize = (params_copy.scale is None) or (
isinstance(params_copy.scale, str)
and params_copy.scale != "full"
and (dpi is not None or figsize is not None)
)
_render_labels(
sdata=sdata,
render_params=params,
render_params=params_copy,
coordinate_system=cs,
ax=ax,
fig_params=fig_params,
Expand Down
44 changes: 23 additions & 21 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _render_shapes(
legend_params: LegendParams,
) -> None:
elements = render_params.elements
table_name = render_params.element_table_mapping
element_table_mapping = render_params.element_table_mapping

if render_params.groups is not None:
if isinstance(render_params.groups, str):
Expand All @@ -70,18 +70,19 @@ def _render_shapes(

sdata_filt = sdata.filter_by_coordinate_system(
coordinate_system=coordinate_system,
filter_table=sdata.get(table_name) is not None,
filter_table=any(value is not None for value in element_table_mapping.values()),
)

if elements is None:
elements = list(sdata_filt.shapes.keys())

for e in elements:
for index, e in enumerate(elements):
shapes = sdata.shapes[e]
n_shapes = sum(len(s) for s in shapes)

if sdata.get(table_name) is None:
table = AnnData(None, obs=pd.DataFrame(index=pd.Index(np.arange(n_shapes), dtype=str)))
table_name = element_table_mapping.get(e)
if table_name is None:
table = None
else:
_, region_key, _ = get_table_keys(sdata[table_name])
table = sdata[table_name][sdata[table_name].obs[region_key].isin([e])]
Expand All @@ -91,10 +92,10 @@ def _render_shapes(
sdata=sdata_filt,
element=sdata_filt.shapes[e],
element_name=e,
value_to_plot=render_params.col_for_color,
value_to_plot=render_params.col_for_color[index],
groups=render_params.groups,
palette=render_params.palette,
na_color=render_params.color or render_params.cmap_params.na_color,
na_color=render_params.color[index] or render_params.cmap_params.na_color[index],
cmap_params=render_params.cmap_params,
table_name=table_name,
)
Expand Down Expand Up @@ -158,18 +159,18 @@ def _render_shapes(
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color)
):
# necessary in case different shapes elements are annotated with one table
if color_source_vector is not None and render_params.col_for_color is not None:
if color_source_vector is not None and render_params.col_for_color[index] is not None:
color_source_vector = color_source_vector.remove_unused_categories()

# False if user specified color-like with 'color' parameter
colorbar = False if render_params.col_for_color is None else legend_params.colorbar
colorbar = False if render_params.col_for_color[index] is None else legend_params.colorbar

_ = _decorate_axs(
ax=ax,
cax=cax,
fig_params=fig_params,
adata=table,
value_to_plot=render_params.col_for_color,
value_to_plot=render_params.col_for_color[index],
color_source_vector=color_source_vector,
palette=palette,
alpha=render_params.fill_alpha,
Expand All @@ -195,19 +196,20 @@ def _render_points(
legend_params: LegendParams,
) -> None:
elements = render_params.elements
table_name = render_params.element_table_mapping
element_table_mapping = render_params.element_table_mapping

sdata_filt = sdata.filter_by_coordinate_system(
coordinate_system=coordinate_system,
filter_table=sdata.get(table_name) is not None,
filter_table=any(value is not None for value in element_table_mapping.values()),
)

if elements is None:
elements = list(sdata_filt.points.keys())

for e in elements:
for index, e in enumerate(elements):
points = sdata.points[e]
col_for_color = render_params.col_for_color
col_for_color = render_params.col_for_color[index]
table_name = element_table_mapping.get(e)

coords = ["x", "y"]
if col_for_color is not None:
Expand All @@ -231,29 +233,29 @@ def _render_points(
points = dask.dataframe.from_pandas(points, npartitions=1)
sdata_filt.points[e] = PointsModel.parse(points, coordinates={"x": "x", "y": "y"})

if render_params.col_for_color is not None:
cols = sc.get.obs_df(adata, render_params.col_for_color)
if col_for_color is not None:
cols = sc.get.obs_df(adata, col_for_color)
# maybe set color based on type
if is_categorical_dtype(cols):
if isinstance(cols.dtype, pd.CategoricalDtype):
_maybe_set_colors(
source=adata,
target=adata,
key=render_params.col_for_color,
key=col_for_color,
palette=render_params.palette,
)

# when user specified a single color, we overwrite na with it
default_color = (
render_params.color
if render_params.col_for_color is None and render_params.color is not None
render_params.color[index]
if col_for_color is None and render_params.color[index] is not None
else render_params.cmap_params.na_color
)

color_source_vector, color_vector, _ = _set_color_source_vec(
sdata=sdata_filt,
element=points,
element_name=e,
value_to_plot=render_params.col_for_color,
value_to_plot=render_params.col_for_color[index],
groups=render_params.groups,
palette=render_params.palette,
na_color=default_color,
Expand Down
Loading

0 comments on commit 7e76133

Please sign in to comment.