diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 38183fa8..e1d28d3b 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -293,7 +293,9 @@ def render_shapes( ----- - Empty geometries will be removed at the time of plotting. - An `outline_width` of 0.0 leads to no border being plotted. - - When passing a color-like to 'color', this has precedence over the potential existence as a column name. + - If ``color`` is a string that is both a matplotlib color name and a column name in the + element or an annotating table, a ``ValueError`` is raised. Disambiguate by passing + a hex string (e.g. ``"#ffa500"``) or an RGB(A) tuple, or by renaming the column. Returns ------- diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 4e872c8f..dd0daee6 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -2280,6 +2280,41 @@ def _validate_show_parameters( ) +def _check_color_column_collision( + sdata: SpatialData, + elements: list[str], + color: str, + element_type: str, +) -> None: + """Raise if ``color`` is a color-like string that also names a column in the element or its tables.""" + matches: list[str] = [] + for el in elements: + if element_type in {"shapes", "points"}: + try: + el_cols = sdata[el].columns + except (KeyError, AttributeError): + el_cols = () + if color in el_cols: + matches.append(f"element '{el}'") + continue + try: + tables = get_element_annotators(sdata, el) + except (KeyError, ValueError): + tables = set() + for t in tables: + adata = sdata[t] + if color in adata.obs.columns or color in adata.var_names: + matches.append(f"table '{t}' (annotating '{el}')") + break + if matches: + locations = ", ".join(matches) + raise ValueError( + f"`color={color!r}` is ambiguous: it is a valid matplotlib color name AND a column " + f"name in {locations}. Disambiguate by either passing an unambiguous color form " + f"(hex string like '#ffa500' or an RGB(A) tuple), or by renaming the column." + ) + + def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[str, Any]: colorbar = param_dict.get("colorbar", "auto") if colorbar not in {True, False, None, "auto"}: @@ -2330,7 +2365,8 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if not isinstance(color, str | tuple | list): raise TypeError("Parameter 'color' must be a string or a tuple/list of floats.") if _is_color_like(color): - logger.info("Value for parameter 'color' appears to be a color, using it as such.") + if isinstance(color, str): + _check_color_column_collision(param_dict["sdata"], param_dict["element"], color, element_type) param_dict["col_for_color"] = None param_dict["color"] = Color(color) if param_dict["color"].alpha_is_user_defined(): diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py index a648b3d8..467a5efe 100644 --- a/tests/pl/test_utils.py +++ b/tests/pl/test_utils.py @@ -1,3 +1,4 @@ +import geopandas as gpd import matplotlib import matplotlib.pyplot as plt import numpy as np @@ -5,7 +6,10 @@ import pytest import scanpy as sc import xarray as xr +from anndata import AnnData +from shapely.geometry import Point from spatialdata import SpatialData +from spatialdata.models import PointsModel, ShapesModel, TableModel import spatialdata_plot from spatialdata_plot.pl.render_params import Color @@ -380,3 +384,39 @@ def test_exact_match_selects_that_scale(self): multiscale = self._make_multiscale((3, 500, 500), [2, 2]) result = _multiscale_to_spatial_image(multiscale, dpi=100.0, width=2.5, height=2.5) assert result.sizes["x"] == 250 + + +def test_color_column_collision_on_element_columns_raises(): + # regression test for #619, element-column path: points with an "orange" column + color="orange". + points = PointsModel.parse(pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0], "orange": [0.1, 0.2, 0.3]})) + sdata = SpatialData(points={"pts": points}) + + with pytest.raises(ValueError, match=r"color='orange'.*ambiguous.*element 'pts'"): + sdata.pl.render_points("pts", color="orange") + + sdata.pl.render_points("pts", color="#ffa500") + sdata.pl.render_points("pts", color=(1.0, 0.65, 0.0)) + + +def test_color_column_collision_on_annotating_table_raises(): + # regression test for #619, table path: element has no "orange" column but its annotating table does. + shapes = ShapesModel.parse(gpd.GeoDataFrame({"geometry": [Point(i, 0) for i in range(3)], "radius": [0.5] * 3})) + obs = pd.DataFrame( + { + "region": pd.Categorical(["s"] * 3), + "instance_id": list(range(3)), + "orange": pd.Categorical(["A", "B", "A"]), + } + ) + table = TableModel.parse( + AnnData(X=np.zeros((3, 1)), obs=obs), + region="s", + region_key="region", + instance_key="instance_id", + ) + sdata = SpatialData(shapes={"s": shapes}, tables={"t": table}) + + with pytest.raises(ValueError, match=r"color='orange'.*ambiguous.*table 't'"): + sdata.pl.render_shapes("s", color="orange") + + sdata.pl.render_shapes("s", color="#ffa500")