Skip to content

Commit 6338980

Browse files
timtreisclaude
andcommitted
Add col_for_color to labels, enabling literal color values like color='red'
Labels now use the same color/col_for_color split as shapes and points, so `render_labels(color="red")` is correctly recognized as a literal color instead of being treated as a column name. Fixes #470 and #478. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1554eb2 commit 6338980

5 files changed

Lines changed: 34 additions & 25 deletions

File tree

src/spatialdata_plot/pl/basic.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,7 @@ def render_labels(
748748
sdata.plotting_tree[f"{n_steps + 1}_render_labels"] = LabelsRenderParams(
749749
element=element,
750750
color=param_values["color"],
751+
col_for_color=param_values["col_for_color"],
751752
groups=param_values["groups"],
752753
contour_px=param_values["contour_px"],
753754
cmap_params=cmap_params,
@@ -1121,14 +1122,13 @@ def _draw_colorbar(
11211122

11221123
if wanted_labels_on_this_cs:
11231124
table = params_copy.table_name
1124-
if table is not None:
1125-
assert isinstance(params_copy.color, str)
1126-
colors = sc.get.obs_df(sdata[table], [params_copy.color])
1127-
if isinstance(colors[params_copy.color].dtype, pd.CategoricalDtype):
1125+
if table is not None and params_copy.col_for_color is not None:
1126+
colors = sc.get.obs_df(sdata[table], [params_copy.col_for_color])
1127+
if isinstance(colors[params_copy.col_for_color].dtype, pd.CategoricalDtype):
11281128
_maybe_set_colors(
11291129
source=sdata[table],
11301130
target=sdata[table],
1131-
key=params_copy.color,
1131+
key=params_copy.col_for_color,
11321132
palette=params_copy.palette,
11331133
)
11341134

src/spatialdata_plot/pl/render.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,7 +1262,7 @@ def _render_labels(
12621262
table_name = render_params.table_name
12631263
table_layer = render_params.table_layer
12641264
palette = render_params.palette
1265-
color = render_params.color
1265+
col_for_color = render_params.col_for_color
12661266
groups = render_params.groups
12671267
scale = render_params.scale
12681268

@@ -1311,23 +1311,25 @@ def _render_labels(
13111311

13121312
_, trans_data = _prepare_transformation(label, coordinate_system, ax)
13131313

1314+
na_color = render_params.color if render_params.color else render_params.cmap_params.na_color
13141315
color_source_vector, color_vector, categorical = _set_color_source_vec(
13151316
sdata=sdata_filt,
13161317
element=label,
13171318
element_name=element,
1318-
value_to_plot=color,
1319+
value_to_plot=col_for_color,
13191320
groups=groups,
13201321
palette=palette,
1321-
na_color=render_params.cmap_params.na_color,
1322+
na_color=na_color,
13221323
cmap_params=render_params.cmap_params,
13231324
table_name=table_name,
13241325
table_layer=table_layer,
1326+
render_type="labels",
13251327
coordinate_system=coordinate_system,
13261328
)
13271329

13281330
# rasterize could have removed labels from label
13291331
# only problematic if color is specified
1330-
if rasterize and color is not None:
1332+
if rasterize and col_for_color is not None:
13311333
labels_in_rasterized_image = np.unique(label.values)
13321334
mask = np.isin(instance_id, labels_in_rasterized_image)
13331335
instance_id = instance_id[mask]
@@ -1405,15 +1407,15 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
14051407
colorbar_requested = _should_request_colorbar(
14061408
render_params.colorbar,
14071409
has_mappable=cax is not None,
1408-
is_continuous=color is not None and color_source_vector is None and not categorical,
1410+
is_continuous=col_for_color is not None and color_source_vector is None and not categorical,
14091411
)
14101412

14111413
_ = _decorate_axs(
14121414
ax=ax,
14131415
cax=cax,
14141416
fig_params=fig_params,
14151417
adata=table,
1416-
value_to_plot=color,
1418+
value_to_plot=col_for_color,
14171419
color_source_vector=color_source_vector,
14181420
color_vector=color_vector,
14191421
palette=palette,
@@ -1429,7 +1431,7 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
14291431
colorbar_requests=colorbar_requests,
14301432
colorbar_label=_resolve_colorbar_label(
14311433
render_params.colorbar_params,
1432-
color if isinstance(color, str) else None,
1434+
col_for_color if isinstance(col_for_color, str) else None,
14331435
),
14341436
scalebar_dx=scalebar_params.scalebar_dx,
14351437
scalebar_units=scalebar_params.scalebar_units,

src/spatialdata_plot/pl/render_params.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,8 @@ class LabelsRenderParams:
278278

279279
cmap_params: CmapParams
280280
element: str
281-
color: str | None = None
281+
color: Color | None = None
282+
col_for_color: str | None = None
282283
groups: str | list[str] | None = None
283284
contour_px: int | None = None
284285
outline: bool = False

src/spatialdata_plot/pl/utils.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,7 @@ def _set_color_source_vec(
981981
alpha: float = 1.0,
982982
table_name: str | None = None,
983983
table_layer: str | None = None,
984-
render_type: Literal["points"] | None = None,
984+
render_type: Literal["points", "labels"] | None = None,
985985
coordinate_system: str | None = None,
986986
) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]:
987987
if value_to_plot is None and element is not None:
@@ -1451,7 +1451,7 @@ def _get_categorical_color_mapping(
14511451
alpha: float = 1,
14521452
groups: list[str] | str | None = None,
14531453
palette: list[str] | str | None = None,
1454-
render_type: Literal["points"] | None = None,
1454+
render_type: Literal["points", "labels"] | None = None,
14551455
) -> Mapping[str, str]:
14561456
if not isinstance(color_source_vector, Categorical):
14571457
raise TypeError(f"Expected `categories` to be a `Categorical`, but got {type(color_source_vector).__name__}")
@@ -2138,15 +2138,15 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
21382138
}:
21392139
if not isinstance(color, str | tuple | list):
21402140
raise TypeError("Parameter 'color' must be a string or a tuple/list of floats.")
2141-
if element_type in {"shapes", "points"}:
2141+
if element_type in {"shapes", "points", "labels"}:
21422142
if _is_color_like(color):
21432143
logger.info("Value for parameter 'color' appears to be a color, using it as such.")
21442144
param_dict["col_for_color"] = None
21452145
param_dict["color"] = Color(color)
21462146
if param_dict["color"].alpha_is_user_defined():
21472147
if element_type == "points" and param_dict.get("alpha") is None:
21482148
param_dict["alpha"] = param_dict["color"].get_alpha_as_float()
2149-
elif element_type == "shapes" and param_dict.get("fill_alpha") is None:
2149+
elif element_type in {"shapes", "labels"} and param_dict.get("fill_alpha") is None:
21502150
param_dict["fill_alpha"] = param_dict["color"].get_alpha_as_float()
21512151
else:
21522152
logger.info(
@@ -2158,7 +2158,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
21582158
param_dict["color"] = None
21592159
else:
21602160
raise ValueError(f"{color} is not a valid RGB(A) array and therefore can't be used as 'color' value.")
2161-
elif "color" in param_dict and element_type != "labels":
2161+
elif "color" in param_dict and element_type != "images":
21622162
param_dict["col_for_color"] = None
21632163

21642164
outline_width = param_dict.get("outline_width")
@@ -2455,15 +2455,18 @@ def _validate_label_render_params(
24552455
element_params[el]["table_layer"] = param_dict["table_layer"]
24562456

24572457
element_params[el]["table_name"] = None
2458-
element_params[el]["color"] = None
2459-
color = param_dict["color"]
2460-
if color is not None:
2461-
color, table_name = _validate_col_for_column_table(sdata, el, color, param_dict["table_name"], labels=True)
2458+
element_params[el]["color"] = param_dict["color"] # literal Color or None
2459+
element_params[el]["col_for_color"] = None
2460+
if (col_for_color := param_dict["col_for_color"]) is not None:
2461+
col_for_color, table_name = _validate_col_for_column_table(
2462+
sdata, el, col_for_color, param_dict["table_name"], labels=True
2463+
)
24622464
element_params[el]["table_name"] = table_name
2463-
element_params[el]["color"] = color
2465+
element_params[el]["col_for_color"] = col_for_color
24642466

2465-
element_params[el]["palette"] = param_dict["palette"] if element_params[el]["table_name"] is not None else None
2466-
element_params[el]["groups"] = param_dict["groups"] if element_params[el]["table_name"] is not None else None
2467+
has_col = element_params[el]["col_for_color"] is not None
2468+
element_params[el]["palette"] = param_dict["palette"] if has_col else None
2469+
element_params[el]["groups"] = param_dict["groups"] if has_col else None
24672470
element_params[el]["colorbar"] = param_dict["colorbar"]
24682471
element_params[el]["colorbar_params"] = param_dict["colorbar_params"]
24692472

tests/pl/test_render_labels.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ def test_plot_can_stack_render_labels(self, sdata_blobs: SpatialData):
8484
.pl.show()
8585
)
8686

87+
def test_plot_can_color_by_color_name(self, sdata_blobs: SpatialData):
88+
sdata_blobs.pl.render_labels("blobs_labels", color="red").pl.show()
89+
8790
def test_plot_can_color_labels_by_continuous_variable(self, sdata_blobs: SpatialData):
8891
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum").pl.show()
8992

0 commit comments

Comments
 (0)