|
62 | 62 | _Normalize = Normalize | abc.Sequence[Normalize] |
63 | 63 |
|
64 | 64 |
|
| 65 | +def _build_color_key_from_categorical(color_vector: pd.Categorical | np.ndarray | object) -> list[str] | None: |
| 66 | + """Build a datashader ``color_key`` list from a categorical color vector. |
| 67 | +
|
| 68 | + Returns ``None`` when *color_vector* is not a :class:`pd.Categorical` or |
| 69 | + has no categories. Hex colours are stripped of their alpha channel; |
| 70 | + named colours (e.g. ``"red"``) are passed through unchanged. |
| 71 | + """ |
| 72 | + if not isinstance(getattr(color_vector, "dtype", None), pd.CategoricalDtype): |
| 73 | + return None |
| 74 | + cat_values = color_vector.categories.values # type: ignore[union-attr] |
| 75 | + if len(cat_values) == 0: |
| 76 | + return None |
| 77 | + return [_hex_no_alpha(x) if isinstance(x, str) and x.startswith("#") else x for x in cat_values] |
| 78 | + |
| 79 | + |
65 | 80 | def _split_colorbar_params(params: dict[str, object] | None) -> tuple[dict[str, object], dict[str, object], str | None]: |
66 | 81 | """Split colorbar params into layout hints, Matplotlib kwargs, and label override.""" |
67 | 82 | layout: dict[str, object] = {} |
@@ -301,7 +316,7 @@ def _render_shapes( |
301 | 316 | # Render shapes with datashader |
302 | 317 | color_by_categorical = col_for_color is not None and color_source_vector is not None |
303 | 318 | aggregate_with_reduction = None |
304 | | - if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1): |
| 319 | + if col_for_color is not None: |
305 | 320 | if color_by_categorical: |
306 | 321 | agg = cvs.polygons( |
307 | 322 | transformed_element, |
@@ -356,11 +371,7 @@ def _render_shapes( |
356 | 371 | agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2) |
357 | 372 | agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) |
358 | 373 |
|
359 | | - color_key = ( |
360 | | - [_hex_no_alpha(x) for x in color_vector.categories.values] |
361 | | - if isinstance(color_vector.dtype, pd.CategoricalDtype) and (len(color_vector.categories.values) > 1) |
362 | | - else None |
363 | | - ) |
| 374 | + color_key = _build_color_key_from_categorical(color_vector) |
364 | 375 |
|
365 | 376 | if color_by_categorical or col_for_color is None: |
366 | 377 | ds_cmap = None |
@@ -814,7 +825,7 @@ def _render_points( |
814 | 825 | if color_by_categorical and transformed_element[col_for_color].values.dtype == object: |
815 | 826 | transformed_element[col_for_color] = transformed_element[col_for_color].astype("category") |
816 | 827 | aggregate_with_reduction = None |
817 | | - if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1): |
| 828 | + if col_for_color is not None: |
818 | 829 | if color_by_categorical: |
819 | 830 | agg = cvs.points(transformed_element, "x", "y", agg=ds.by(col_for_color, ds.count())) |
820 | 831 | else: |
@@ -851,15 +862,7 @@ def _render_points( |
851 | 862 | agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2) |
852 | 863 | agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) |
853 | 864 |
|
854 | | - color_key: list[str] | None = ( |
855 | | - list(color_vector.categories.values) |
856 | | - if isinstance(color_vector.dtype, pd.CategoricalDtype) and (len(color_vector.categories.values) > 1) |
857 | | - else None |
858 | | - ) |
859 | | - |
860 | | - # remove alpha from color if it's hex |
861 | | - if color_key is not None and all(len(x) == 9 for x in color_key) and color_key[0][0] == "#": |
862 | | - color_key = [x[:-2] for x in color_key] |
| 865 | + color_key = _build_color_key_from_categorical(color_vector) |
863 | 866 | if isinstance(color_vector[0], str) and ( |
864 | 867 | color_vector is not None and all(len(x) == 9 for x in color_vector) and color_vector[0][0] == "#" |
865 | 868 | ): |
|
0 commit comments