Skip to content

Commit 264bc92

Browse files
timtreisclaude
andauthored
Fix datashader failing with single-category Categorical (#539)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c0ab403 commit 264bc92

File tree

3 files changed

+79
-16
lines changed

3 files changed

+79
-16
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,21 @@
6262
_Normalize = Normalize | abc.Sequence[Normalize]
6363

6464

65+
def _build_color_key_from_categorical(color_vector: pd.Categorical | np.ndarray | object) -> list[str] | None:
66+
"""Build a datashader ``color_key`` list from a categorical color vector.
67+
68+
Returns ``None`` when *color_vector* is not a :class:`pd.Categorical` or
69+
has no categories. Hex colours are stripped of their alpha channel;
70+
named colours (e.g. ``"red"``) are passed through unchanged.
71+
"""
72+
if not isinstance(getattr(color_vector, "dtype", None), pd.CategoricalDtype):
73+
return None
74+
cat_values = color_vector.categories.values # type: ignore[union-attr]
75+
if len(cat_values) == 0:
76+
return None
77+
return [_hex_no_alpha(x) if isinstance(x, str) and x.startswith("#") else x for x in cat_values]
78+
79+
6580
def _split_colorbar_params(params: dict[str, object] | None) -> tuple[dict[str, object], dict[str, object], str | None]:
6681
"""Split colorbar params into layout hints, Matplotlib kwargs, and label override."""
6782
layout: dict[str, object] = {}
@@ -301,7 +316,7 @@ def _render_shapes(
301316
# Render shapes with datashader
302317
color_by_categorical = col_for_color is not None and color_source_vector is not None
303318
aggregate_with_reduction = None
304-
if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1):
319+
if col_for_color is not None:
305320
if color_by_categorical:
306321
agg = cvs.polygons(
307322
transformed_element,
@@ -356,11 +371,7 @@ def _render_shapes(
356371
agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2)
357372
agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5)
358373

359-
color_key = (
360-
[_hex_no_alpha(x) for x in color_vector.categories.values]
361-
if isinstance(color_vector.dtype, pd.CategoricalDtype) and (len(color_vector.categories.values) > 1)
362-
else None
363-
)
374+
color_key = _build_color_key_from_categorical(color_vector)
364375

365376
if color_by_categorical or col_for_color is None:
366377
ds_cmap = None
@@ -814,7 +825,7 @@ def _render_points(
814825
if color_by_categorical and transformed_element[col_for_color].values.dtype == object:
815826
transformed_element[col_for_color] = transformed_element[col_for_color].astype("category")
816827
aggregate_with_reduction = None
817-
if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1):
828+
if col_for_color is not None:
818829
if color_by_categorical:
819830
agg = cvs.points(transformed_element, "x", "y", agg=ds.by(col_for_color, ds.count()))
820831
else:
@@ -851,15 +862,7 @@ def _render_points(
851862
agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2)
852863
agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5)
853864

854-
color_key: list[str] | None = (
855-
list(color_vector.categories.values)
856-
if isinstance(color_vector.dtype, pd.CategoricalDtype) and (len(color_vector.categories.values) > 1)
857-
else None
858-
)
859-
860-
# remove alpha from color if it's hex
861-
if color_key is not None and all(len(x) == 9 for x in color_key) and color_key[0][0] == "#":
862-
color_key = [x[:-2] for x in color_key]
865+
color_key = _build_color_key_from_categorical(color_vector)
863866
if isinstance(color_vector[0], str) and (
864867
color_vector is not None and all(len(x) == 9 for x in color_vector) and color_vector[0][0] == "#"
865868
):

tests/pl/test_render_points.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,3 +572,37 @@ def test_datashader_colors_points_from_table_obs(sdata_blobs: SpatialData):
572572
method="datashader",
573573
size=5,
574574
).pl.show()
575+
576+
577+
def test_plot_datashader_single_category_points(sdata_blobs: SpatialData):
578+
"""Datashader with a single-category Categorical must not raise.
579+
580+
Regression test for https://github.com/scverse/spatialdata-plot/issues/483.
581+
Before the fix, color_key was None when there was only 1 category, but ds.by()
582+
still produced a 3D DataArray, causing datashader to raise:
583+
ValueError: Color key must be provided, with at least as many colors as
584+
there are categorical fields
585+
"""
586+
n_obs = len(sdata_blobs["blobs_points"])
587+
obs = pd.DataFrame(
588+
{
589+
"instance_id": np.arange(n_obs),
590+
"region": pd.Categorical(["blobs_points"] * n_obs),
591+
"foo": pd.Categorical(["only_cat"] * n_obs),
592+
}
593+
)
594+
table = TableModel.parse(
595+
adata=AnnData(get_standard_RNG().normal(size=(n_obs, 3)), obs=obs),
596+
region="blobs_points",
597+
region_key="region",
598+
instance_key="instance_id",
599+
)
600+
sdata_blobs["single_cat_table"] = table
601+
602+
sdata_blobs.pl.render_points(
603+
"blobs_points",
604+
color="foo",
605+
table_name="single_cat_table",
606+
method="datashader",
607+
size=5,
608+
).pl.show()

tests/pl/test_render_shapes.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,3 +973,29 @@ def test_plot_can_handle_mixed_numeric_and_color_data(sdata_blobs: SpatialData):
973973
# Mixed numeric / non-numeric values should raise a TypeError
974974
with pytest.raises(TypeError, match="contains both numeric and non-numeric values"):
975975
sdata_blobs.pl.render_shapes(element="blobs_circles", color="mixed_data", na_color="gray").pl.show()
976+
977+
978+
def test_plot_datashader_single_category(sdata_blobs: SpatialData):
979+
"""Datashader with a single-category Categorical must not raise.
980+
981+
Regression test for https://github.com/scverse/spatialdata-plot/issues/483.
982+
Before the fix, color_key was None when there was only 1 category, but ds.by()
983+
still produced a 3D DataArray, causing datashader to raise:
984+
ValueError: Color key must be provided, with at least as many colors as
985+
there are categorical fields
986+
"""
987+
n_obs = len(sdata_blobs["blobs_polygons"])
988+
adata = AnnData(get_standard_RNG().normal(size=(n_obs, 10)))
989+
adata.obs = pd.DataFrame(get_standard_RNG().normal(size=(n_obs, 3)), columns=["a", "b", "c"])
990+
adata.obs["category"] = pd.Categorical(["only_cat"] * n_obs)
991+
adata.obs["instance_id"] = list(range(n_obs))
992+
adata.obs["region"] = "blobs_polygons"
993+
table = TableModel.parse(
994+
adata=adata,
995+
region_key="region",
996+
instance_key="instance_id",
997+
region="blobs_polygons",
998+
)
999+
sdata_blobs["table"] = table
1000+
1001+
sdata_blobs.pl.render_shapes(element="blobs_polygons", color="category", method="datashader").pl.show()

0 commit comments

Comments
 (0)