Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,9 @@ def render_shapes(
-----
- Empty geometries will be removed at the time of plotting.
- An `outline_width` of 0.0 leads to no border being plotted.
- When passing a color-like to 'color', this has precedence over the potential existence as a column name.
- If ``color`` is a string that is both a matplotlib color name and a column name in the
element or an annotating table, a ``ValueError`` is raised. Disambiguate by passing
a hex string (e.g. ``"#ffa500"``) or an RGB(A) tuple, or by renaming the column.

Returns
-------
Expand Down
38 changes: 37 additions & 1 deletion src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2280,6 +2280,41 @@ def _validate_show_parameters(
)


def _check_color_column_collision(
sdata: SpatialData,
elements: list[str],
color: str,
element_type: str,
) -> None:
"""Raise if ``color`` is a color-like string that also names a column in the element or its tables."""
matches: list[str] = []
for el in elements:
if element_type in {"shapes", "points"}:
try:
el_cols = sdata[el].columns
except (KeyError, AttributeError):
el_cols = ()
if color in el_cols:
matches.append(f"element '{el}'")
continue
try:
tables = get_element_annotators(sdata, el)
except (KeyError, ValueError):
tables = set()
for t in tables:
adata = sdata[t]
if color in adata.obs.columns or color in adata.var_names:
matches.append(f"table '{t}' (annotating '{el}')")
break
if matches:
locations = ", ".join(matches)
raise ValueError(
f"`color={color!r}` is ambiguous: it is a valid matplotlib color name AND a column "
f"name in {locations}. Disambiguate by either passing an unambiguous color form "
f"(hex string like '#ffa500' or an RGB(A) tuple), or by renaming the column."
)


def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[str, Any]:
colorbar = param_dict.get("colorbar", "auto")
if colorbar not in {True, False, None, "auto"}:
Expand Down Expand Up @@ -2330,7 +2365,8 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
if not isinstance(color, str | tuple | list):
raise TypeError("Parameter 'color' must be a string or a tuple/list of floats.")
if _is_color_like(color):
logger.info("Value for parameter 'color' appears to be a color, using it as such.")
if isinstance(color, str):
_check_color_column_collision(param_dict["sdata"], param_dict["element"], color, element_type)
param_dict["col_for_color"] = None
param_dict["color"] = Color(color)
if param_dict["color"].alpha_is_user_defined():
Expand Down
40 changes: 40 additions & 0 deletions tests/pl/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import geopandas as gpd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
import scanpy as sc
import xarray as xr
from anndata import AnnData
from shapely.geometry import Point
from spatialdata import SpatialData
from spatialdata.models import PointsModel, ShapesModel, TableModel

import spatialdata_plot
from spatialdata_plot.pl.render_params import Color
Expand Down Expand Up @@ -380,3 +384,39 @@ def test_exact_match_selects_that_scale(self):
multiscale = self._make_multiscale((3, 500, 500), [2, 2])
result = _multiscale_to_spatial_image(multiscale, dpi=100.0, width=2.5, height=2.5)
assert result.sizes["x"] == 250


def test_color_column_collision_on_element_columns_raises():
# regression test for #619, element-column path: points with an "orange" column + color="orange".
points = PointsModel.parse(pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0], "orange": [0.1, 0.2, 0.3]}))
sdata = SpatialData(points={"pts": points})

with pytest.raises(ValueError, match=r"color='orange'.*ambiguous.*element 'pts'"):
sdata.pl.render_points("pts", color="orange")

sdata.pl.render_points("pts", color="#ffa500")
sdata.pl.render_points("pts", color=(1.0, 0.65, 0.0))


def test_color_column_collision_on_annotating_table_raises():
# regression test for #619, table path: element has no "orange" column but its annotating table does.
shapes = ShapesModel.parse(gpd.GeoDataFrame({"geometry": [Point(i, 0) for i in range(3)], "radius": [0.5] * 3}))
obs = pd.DataFrame(
{
"region": pd.Categorical(["s"] * 3),
"instance_id": list(range(3)),
"orange": pd.Categorical(["A", "B", "A"]),
}
)
table = TableModel.parse(
AnnData(X=np.zeros((3, 1)), obs=obs),
region="s",
region_key="region",
instance_key="instance_id",
)
sdata = SpatialData(shapes={"s": shapes}, tables={"t": table})

with pytest.raises(ValueError, match=r"color='orange'.*ambiguous.*table 't'"):
sdata.pl.render_shapes("s", color="orange")

sdata.pl.render_shapes("s", color="#ffa500")
Loading