From 7e76133b58ae4466476eada0dc48ad452bbd23e1 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sat, 24 Feb 2024 17:13:52 +0100 Subject: [PATCH] unfixed full implementation --- src/spatialdata_plot/pl/basic.py | 76 ++++++++++---------- src/spatialdata_plot/pl/render.py | 44 ++++++------ src/spatialdata_plot/pl/utils.py | 114 ++++++++++++++++++++++++++---- 3 files changed, 164 insertions(+), 70 deletions(-) diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 29b54acc..1bf340b9 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -4,6 +4,7 @@ from collections import OrderedDict from pathlib import Path from typing import Any, Union +from copy import deepcopy import matplotlib.pyplot as plt import numpy as np @@ -48,11 +49,10 @@ _prepare_cmap_norm, _prepare_params_plot, _set_outline, - _update_element_table_mapping_colors, - _update_element_table_mapping_points_shapes_colors, + _validate_colors_element_table_mapping_points_shapes, _validate_render_params, _validate_show_parameters, - save_fig, + save_fig, _update_element_table_mapping_label_colors, ) from spatialdata_plot.pp.utils import _verify_plotting_tree @@ -776,9 +776,11 @@ def show( wanted_elements = [] for cmd, params in render_cmds: + # We create a copy here as the wanted elements can change from one cs to another. + params_copy = deepcopy(params) if cmd == "render_images" and has_images: wants_images = True - wanted_images = params.elements if params.elements is not None else list(sdata.images.keys()) + wanted_images = params_copy.elements if params_copy.elements is not None else list(sdata.images.keys()) wanted_images_on_this_cs = [ image for image in wanted_images @@ -786,14 +788,14 @@ def show( ] wanted_elements.extend(wanted_images_on_this_cs) if wanted_images_on_this_cs: - rasterize = (params.scale is None) or ( - isinstance(params.scale, str) - and params.scale != "full" + rasterize = (params_copy.scale is None) or ( + isinstance(params_copy.scale, str) + and params_copy.scale != "full" and (dpi is not None or figsize is not None) ) _render_images( sdata=sdata, - render_params=params, + render_params=params_copy, coordinate_system=cs, ax=ax, fig_params=fig_params, @@ -804,23 +806,23 @@ def show( elif cmd == "render_shapes" and has_shapes: wants_shapes = True - wanted_shapes = params.elements if params.elements is not None else list(sdata.shapes.keys()) + wanted_shapes = params_copy.elements if params_copy.elements is not None else list(sdata.shapes.keys()) wanted_shapes_on_this_cs = [ shape for shape in wanted_shapes if cs in set(get_transformation(sdata.shapes[shape], get_all=True).keys()) ] if wanted_shapes_on_this_cs: - params = _create_initial_element_table_mapping(sdata, params, wanted_shapes_on_this_cs) - params = _update_element_table_mapping_points_shapes_colors( - sdata, params, wanted_shapes_on_this_cs + params_copy = _create_initial_element_table_mapping(sdata, params_copy, wanted_shapes_on_this_cs) + params_copy = _validate_colors_element_table_mapping_points_shapes( + sdata, params_copy, wanted_shapes_on_this_cs ) wanted_elements.extend(wanted_shapes_on_this_cs) if wanted_shapes_on_this_cs: _render_shapes( sdata=sdata, - render_params=params, + render_params=params_copy, coordinate_system=cs, ax=ax, fig_params=fig_params, @@ -830,7 +832,7 @@ def show( elif cmd == "render_points" and has_points: wants_points = True - wanted_points = params.elements if params.elements is not None else list(sdata.points.keys()) + wanted_points = params_copy.elements if params_copy.elements is not None else list(sdata.points.keys()) wanted_points_on_this_cs = [ point for point in wanted_points @@ -838,16 +840,16 @@ def show( ] if wanted_points_on_this_cs: - params = _create_initial_element_table_mapping(sdata, params, wanted_points_on_this_cs) - params = _update_element_table_mapping_points_shapes_colors( - sdata, params, wanted_points_on_this_cs + params_copy = _create_initial_element_table_mapping(sdata, params_copy, wanted_points_on_this_cs) + params_copy = _validate_colors_element_table_mapping_points_shapes( + sdata, params_copy, wanted_points_on_this_cs ) wanted_elements.extend(wanted_points_on_this_cs) if wanted_points_on_this_cs: _render_points( sdata=sdata, - render_params=params, + render_params=params_copy, coordinate_system=cs, ax=ax, fig_params=fig_params, @@ -857,7 +859,7 @@ def show( elif cmd == "render_labels" and has_labels: wants_labels = True - wanted_labels = params.elements if params.elements is not None else list(sdata.labels.keys()) + wanted_labels = params_copy.elements if params_copy.elements is not None else list(sdata.labels.keys()) wanted_labels_on_this_cs = [ label for label in wanted_labels @@ -867,36 +869,36 @@ def show( if wanted_labels_on_this_cs: # Create element to table mapping and check whether specified color columns are in tables. - params = _create_initial_element_table_mapping(sdata, params, wanted_labels_on_this_cs) - params = _update_element_table_mapping_colors(sdata, params, wanted_labels_on_this_cs) - - if isinstance(params.color, list): - element_table_mapping = params.element_table_mapping - params.color = ( - [params.color[0] if value is not None else None for value in element_table_mapping.values()] - if len(params.color) == 1 - else params.color + params_copy = _create_initial_element_table_mapping(sdata, params_copy, wanted_labels_on_this_cs) + params_copy = _update_element_table_mapping_label_colors(sdata, params_copy, wanted_labels_on_this_cs) + + if isinstance(params_copy.color, list): + element_table_mapping = params_copy.element_table_mapping + params_copy.color = ( + [params_copy.color[0] if value is not None else None for value in element_table_mapping.values()] + if len(params_copy.color) == 1 + else params_copy.color ) - for index, table in enumerate(params.element_table_mapping.values()): + for index, table in enumerate(params_copy.element_table_mapping.values()): if table is None: continue - colors = sc.get.obs_df(sdata[table], params.color[index]) - if isinstance(colors, pd.CategoricalDtype): + colors = sc.get.obs_df(sdata[table], params_copy.color[index]) + if isinstance(colors.dtype, pd.CategoricalDtype): _maybe_set_colors( source=sdata[table], target=sdata[table], - key=params.color[index], - palette=params.palette, + key=params_copy.color[index], + palette=params_copy.palette, ) - rasterize = (params.scale is None) or ( - isinstance(params.scale, str) - and params.scale != "full" + rasterize = (params_copy.scale is None) or ( + isinstance(params_copy.scale, str) + and params_copy.scale != "full" and (dpi is not None or figsize is not None) ) _render_labels( sdata=sdata, - render_params=params, + render_params=params_copy, coordinate_system=cs, ax=ax, fig_params=fig_params, diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 65ef9586..04d4256d 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -60,7 +60,7 @@ def _render_shapes( legend_params: LegendParams, ) -> None: elements = render_params.elements - table_name = render_params.element_table_mapping + element_table_mapping = render_params.element_table_mapping if render_params.groups is not None: if isinstance(render_params.groups, str): @@ -70,18 +70,19 @@ def _render_shapes( sdata_filt = sdata.filter_by_coordinate_system( coordinate_system=coordinate_system, - filter_table=sdata.get(table_name) is not None, + filter_table=any(value is not None for value in element_table_mapping.values()), ) if elements is None: elements = list(sdata_filt.shapes.keys()) - for e in elements: + for index, e in enumerate(elements): shapes = sdata.shapes[e] n_shapes = sum(len(s) for s in shapes) - if sdata.get(table_name) is None: - table = AnnData(None, obs=pd.DataFrame(index=pd.Index(np.arange(n_shapes), dtype=str))) + table_name = element_table_mapping.get(e) + if table_name is None: + table = None else: _, region_key, _ = get_table_keys(sdata[table_name]) table = sdata[table_name][sdata[table_name].obs[region_key].isin([e])] @@ -91,10 +92,10 @@ def _render_shapes( sdata=sdata_filt, element=sdata_filt.shapes[e], element_name=e, - value_to_plot=render_params.col_for_color, + value_to_plot=render_params.col_for_color[index], groups=render_params.groups, palette=render_params.palette, - na_color=render_params.color or render_params.cmap_params.na_color, + na_color=render_params.color[index] or render_params.cmap_params.na_color[index], cmap_params=render_params.cmap_params, table_name=table_name, ) @@ -158,18 +159,18 @@ def _render_shapes( len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color) ): # necessary in case different shapes elements are annotated with one table - if color_source_vector is not None and render_params.col_for_color is not None: + if color_source_vector is not None and render_params.col_for_color[index] is not None: color_source_vector = color_source_vector.remove_unused_categories() # False if user specified color-like with 'color' parameter - colorbar = False if render_params.col_for_color is None else legend_params.colorbar + colorbar = False if render_params.col_for_color[index] is None else legend_params.colorbar _ = _decorate_axs( ax=ax, cax=cax, fig_params=fig_params, adata=table, - value_to_plot=render_params.col_for_color, + value_to_plot=render_params.col_for_color[index], color_source_vector=color_source_vector, palette=palette, alpha=render_params.fill_alpha, @@ -195,19 +196,20 @@ def _render_points( legend_params: LegendParams, ) -> None: elements = render_params.elements - table_name = render_params.element_table_mapping + element_table_mapping = render_params.element_table_mapping sdata_filt = sdata.filter_by_coordinate_system( coordinate_system=coordinate_system, - filter_table=sdata.get(table_name) is not None, + filter_table=any(value is not None for value in element_table_mapping.values()), ) if elements is None: elements = list(sdata_filt.points.keys()) - for e in elements: + for index, e in enumerate(elements): points = sdata.points[e] - col_for_color = render_params.col_for_color + col_for_color = render_params.col_for_color[index] + table_name = element_table_mapping.get(e) coords = ["x", "y"] if col_for_color is not None: @@ -231,21 +233,21 @@ def _render_points( points = dask.dataframe.from_pandas(points, npartitions=1) sdata_filt.points[e] = PointsModel.parse(points, coordinates={"x": "x", "y": "y"}) - if render_params.col_for_color is not None: - cols = sc.get.obs_df(adata, render_params.col_for_color) + if col_for_color is not None: + cols = sc.get.obs_df(adata, col_for_color) # maybe set color based on type - if is_categorical_dtype(cols): + if isinstance(cols.dtype, pd.CategoricalDtype): _maybe_set_colors( source=adata, target=adata, - key=render_params.col_for_color, + key=col_for_color, palette=render_params.palette, ) # when user specified a single color, we overwrite na with it default_color = ( - render_params.color - if render_params.col_for_color is None and render_params.color is not None + render_params.color[index] + if col_for_color is None and render_params.color[index] is not None else render_params.cmap_params.na_color ) @@ -253,7 +255,7 @@ def _render_points( sdata=sdata_filt, element=points, element_name=e, - value_to_plot=render_params.col_for_color, + value_to_plot=render_params.col_for_color[index], groups=render_params.groups, palette=render_params.palette, na_color=default_color, diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index fb7addbd..16832807 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -386,7 +386,7 @@ def _set_outline( if outline_width == 0.0: outline = False if outline_width < 0.0: - logger.warning(f"Negative line widths are not allowed, changing {outline_width} to {(-1)*outline_width}") + logger.warning(f"Negative line widths are not allowed, changing {outline_width} to {(-1) * outline_width}") outline_width *= -1 # the default black and white colors can be changed using the contour_config parameter @@ -621,7 +621,7 @@ def _set_color_source_vec( color_source_vector = vals[value_to_plot] # numerical case, return early - if not is_categorical_dtype(color_source_vector): + if not isinstance(color_source_vector.dtype, pd.CategoricalDtype): if palette is not None: logger.warning( "Ignoring categorical palette which is given for a continuous variable. " @@ -666,7 +666,7 @@ def _map_color_seg( ) -> ArrayLike: cell_id = np.array(cell_id) - if is_categorical_dtype(color_vector): + if isinstance(color_vector.dtype, pd.CategoricalDtype): if isinstance(na_color, tuple) and len(na_color) == 4 and np.any(color_source_vector.isna()): cell_id[color_source_vector.isna()] = 0 val_im: ArrayLike = map_array(seg, cell_id, color_vector.codes + 1) @@ -773,9 +773,9 @@ def _decorate_axs( ax: Axes, cax: PatchCollection, fig_params: FigParams, - adata: AnnData, value_to_plot: str | None, color_source_vector: pd.Series[CategoricalDtype], + adata: AnnData | None = None, palette: ListedColormap | str | list[str] | None = None, alpha: float = 1.0, na_color: str | tuple[float, ...] = (0.0, 0.0, 0.0, 0.0), @@ -799,7 +799,7 @@ def _decorate_axs( path_effect = [] # Adding legends - if is_categorical_dtype(color_source_vector): + if isinstance(color_source_vector.dtype, pd.CategoricalDtype): # order of clusters should agree to palette order clusters = color_source_vector.unique() clusters = clusters[~clusters.isnull()] @@ -1002,7 +1002,7 @@ def _translate_image( def _convert_polygon_to_linestrings(polygon: Polygon) -> list[LineString]: b = polygon.boundary.coords - linestrings = [LineString(b[k : k + 2]) for k in range(len(b) - 1)] + linestrings = [LineString(b[k: k + 2]) for k in range(len(b) - 1)] return [list(ls.coords) for ls in linestrings] @@ -1093,11 +1093,11 @@ def _get_valid_cs( and any(e in elements for e in cs_mapping[cs]) or not elements and ( - (len(sdata.images.keys()) > 0 and render_images) - or (len(sdata.labels.keys()) > 0 and render_labels) - or (len(sdata.points.keys()) > 0 and render_points) - or (len(sdata.shapes.keys()) > 0 and render_shapes) - ) + (len(sdata.images.keys()) > 0 and render_images) + or (len(sdata.labels.keys()) > 0 and render_labels) + or (len(sdata.points.keys()) > 0 and render_points) + or (len(sdata.shapes.keys()) > 0 and render_shapes) + ) ): # not nice, but ruff wants it (SIM114) valid_cs.append(cs) else: @@ -1334,7 +1334,7 @@ def _create_initial_element_table_mapping( return params -def _update_element_table_mapping_colors(sdata, params, render_elements): +def _update_element_table_mapping_label_colors(sdata, params, render_elements): element_table_mapping = params.element_table_mapping if params.color is not None: params.color = [params.color] if isinstance(params.color, str) else params.color @@ -1370,6 +1370,60 @@ def _update_element_table_mapping_colors(sdata, params, render_elements): return params +def _validate_colors_element_table_mapping_points_shapes(sdata, params, render_elements): + element_table_mapping = params.element_table_mapping + if len(params.color) == 1: + color = params.color[0] + col_color = params.col_for_color[0] + # This means that we are dealing with colors that are color like + if color is not None: + params.color = [color] * len(render_elements) + params.col_for_color = [None] * len(render_elements) + else: + if col_color is not None: + params.color = [None] * len(render_elements) + params.col_for_color = [] + for element_name in render_elements: + for table_name in element_table_mapping[element_name].copy(): + if ( + col_color not in sdata[table_name].obs.columns + and col_color not in sdata[table_name].var_names + and col_color not in sdata[element_name].columns + ): + element_table_mapping[element_name].remove(table_name) + params.col_for_color.append(None) + else: + params.col_for_color.append(col_color) + else: + params.col_for_color = [None] * len(render_elements) + else: + assert len(params.color) == len(render_elements), f"The number of given colors and elements to render is not equal. Either provide one color or a list with one color for each element." + for index, color in enumerate(params.color): + if color is None: + element_name = render_elements[index] + col_color = params.col_for_color[index] + for table_name in element_table_mapping[element_name].copy(): + if ( + col_color not in sdata[table_name].obs.columns + and col_color not in sdata[table_name].var_names + and col_color not in sdata[element_name].columns + ): + element_table_mapping[element_name].remove(table_name) + for index, element_name in enumerate(render_elements): + # We only want one table value per element and only when there is a color column in the table + if params.col_for_color[index] is not None: + table_set = element_table_mapping[element_name] + if len(table_set) != 1: + raise ValueError(f"More than one table found with color column {params.col_for_color[index]}.") + element_table_mapping[element_name] = next(iter(table_set)) if len(table_set) != 0 else None + if element_table_mapping[element_name] is None: + warnings.warn(f"No table found with color column {params.col_for_color[index]} to render {element_name}") + else: + element_table_mapping[element_name] = None + params.element_table_mapping = element_table_mapping + return params + + def _validate_show_parameters( coordinate_systems: list[str] | str | None, legend_fontsize: int | float | _FontSize | None, @@ -1581,6 +1635,42 @@ def _validate_render_params( if not colors.is_color_like(outline_color): raise TypeError("Parameter 'outline_color' must be color-like.") + if element_type in ["points", "shapes"]: + if color is not None: + if not isinstance(color, list): + if colors.is_color_like(color): + logger.info("Value for parameter 'color' appears to be a color, using it as such.") + color = [color] + col_for_color = [None] + else: + if not isinstance(color, str): + raise TypeError( + "Parameter 'color' must be a string indicating which color " + + "in sdata.table to use for coloring the shapes." + ) + col_for_color = [color] + color = [None] + else: + col_for_color = [] + for index, c in enumerate(color): + if colors.is_color_like(c): + logger.info(f"Value `{c}` in list 'color' appears to be a color, using it as such.") + color[index] = c + col_for_color.append(None) + else: + if not isinstance(c, str): + raise TypeError( + f"Value `{c}` in list Parameter 'color' must be a string indicating which color " + + "in sdata.table to use for coloring the shapes or should be color-like." + ) + col_for_color.append(c) + color[index] = None + else: + color = [color] + col_for_color = [None] + params_dict["color"] = color + params_dict["col_for_color"] = col_for_color + if element_type == "points": if not isinstance(size, (float, int)): raise TypeError("Parameter 'size' must be numeric.")