From 8fdc7b0ed5c8a23992f0030048c894f22c504e7b Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 26 Jan 2026 17:36:34 +0100 Subject: [PATCH 1/4] Add new Colors and colors parameter to all methods --- xarray_plotly/__init__.py | 3 +- xarray_plotly/accessor.py | 47 +++++++++++++++++++++++- xarray_plotly/common.py | 75 +++++++++++++++++++++++++++++++++++++-- xarray_plotly/plotting.py | 71 ++++++++++++++++++++++++++++++++++++ 4 files changed, 191 insertions(+), 5 deletions(-) diff --git a/xarray_plotly/__init__.py b/xarray_plotly/__init__.py index d377b4c..1b35062 100644 --- a/xarray_plotly/__init__.py +++ b/xarray_plotly/__init__.py @@ -52,7 +52,7 @@ from xarray_plotly import config from xarray_plotly.accessor import DataArrayPlotlyAccessor, DatasetPlotlyAccessor -from xarray_plotly.common import SLOT_ORDERS, auto +from xarray_plotly.common import SLOT_ORDERS, Colors, auto from xarray_plotly.figures import ( add_secondary_y, overlay, @@ -61,6 +61,7 @@ __all__ = [ "SLOT_ORDERS", + "Colors", "add_secondary_y", "auto", "config", diff --git a/xarray_plotly/accessor.py b/xarray_plotly/accessor.py index 41f60b4..eb3ddf0 100644 --- a/xarray_plotly/accessor.py +++ b/xarray_plotly/accessor.py @@ -6,7 +6,7 @@ from xarray import DataArray, Dataset from xarray_plotly import plotting -from xarray_plotly.common import SlotValue, auto +from xarray_plotly.common import Colors, SlotValue, auto from xarray_plotly.config import _options @@ -53,6 +53,7 @@ def line( facet_col: SlotValue = auto, facet_row: SlotValue = auto, animation_frame: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """Create an interactive line plot. @@ -67,6 +68,7 @@ def line( facet_col: Dimension for subplot columns. Default: fifth dimension. facet_row: Dimension for subplot rows. Default: sixth dimension. animation_frame: Dimension for animation. Default: seventh dimension. + colors: Color specification (scale name, list, or dict). See module docs. **px_kwargs: Additional arguments passed to `plotly.express.line()`. Returns: @@ -81,6 +83,7 @@ def line( facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, + colors=colors, **px_kwargs, ) @@ -93,6 +96,7 @@ def bar( facet_col: SlotValue = auto, facet_row: SlotValue = auto, animation_frame: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """Create an interactive bar chart. @@ -106,6 +110,7 @@ def bar( facet_col: Dimension for subplot columns. Default: fourth dimension. facet_row: Dimension for subplot rows. Default: fifth dimension. animation_frame: Dimension for animation. Default: sixth dimension. + colors: Color specification (scale name, list, or dict). See module docs. **px_kwargs: Additional arguments passed to `plotly.express.bar()`. Returns: @@ -119,6 +124,7 @@ def bar( facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, + colors=colors, **px_kwargs, ) @@ -131,6 +137,7 @@ def area( facet_col: SlotValue = auto, facet_row: SlotValue = auto, animation_frame: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """Create an interactive stacked area chart. @@ -144,6 +151,7 @@ def area( facet_col: Dimension for subplot columns. Default: fourth dimension. facet_row: Dimension for subplot rows. Default: fifth dimension. animation_frame: Dimension for animation. Default: sixth dimension. + colors: Color specification (scale name, list, or dict). See module docs. **px_kwargs: Additional arguments passed to `plotly.express.area()`. Returns: @@ -157,6 +165,7 @@ def area( facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, + colors=colors, **px_kwargs, ) @@ -168,6 +177,7 @@ def fast_bar( facet_col: SlotValue = auto, facet_row: SlotValue = auto, animation_frame: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """Create a bar-like chart using stacked areas for better performance. @@ -180,6 +190,7 @@ def fast_bar( facet_col: Dimension for subplot columns. Default: third dimension. facet_row: Dimension for subplot rows. Default: fourth dimension. animation_frame: Dimension for animation. Default: fifth dimension. + colors: Color specification (scale name, list, or dict). See module docs. **px_kwargs: Additional arguments passed to `plotly.express.area()`. Returns: @@ -192,6 +203,7 @@ def fast_bar( facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, + colors=colors, **px_kwargs, ) @@ -205,6 +217,7 @@ def scatter( facet_col: SlotValue = auto, facet_row: SlotValue = auto, animation_frame: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """Create an interactive scatter plot. @@ -223,6 +236,7 @@ def scatter( facet_col: Dimension for subplot columns. Default: fourth dimension. facet_row: Dimension for subplot rows. Default: fifth dimension. animation_frame: Dimension for animation. Default: sixth dimension. + colors: Color specification (scale name, list, or dict). See module docs. **px_kwargs: Additional arguments passed to `plotly.express.scatter()`. Returns: @@ -237,6 +251,7 @@ def scatter( facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, + colors=colors, **px_kwargs, ) @@ -248,6 +263,7 @@ def box( facet_col: SlotValue = None, facet_row: SlotValue = None, animation_frame: SlotValue = None, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """Create an interactive box plot. @@ -263,6 +279,7 @@ def box( facet_col: Dimension for subplot columns. Default: None (aggregated). facet_row: Dimension for subplot rows. Default: None (aggregated). animation_frame: Dimension for animation. Default: None (aggregated). + colors: Color specification (scale name, list, or dict). See module docs. **px_kwargs: Additional arguments passed to `plotly.express.box()`. Returns: @@ -275,6 +292,7 @@ def box( facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, + colors=colors, **px_kwargs, ) @@ -286,6 +304,7 @@ def imshow( facet_col: SlotValue = auto, animation_frame: SlotValue = auto, robust: bool = False, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """Create an interactive heatmap image. @@ -303,6 +322,7 @@ def imshow( facet_col: Dimension for subplot columns. Default: third dimension. animation_frame: Dimension for animation. Default: fourth dimension. robust: If True, use 2nd/98th percentiles for color bounds (handles outliers). + colors: Color scale name (e.g., "Viridis", "RdBu"). See module docs. **px_kwargs: Additional arguments passed to `plotly.express.imshow()`. Use `zmin` and `zmax` to manually set color scale bounds. @@ -316,6 +336,7 @@ def imshow( facet_col=facet_col, animation_frame=animation_frame, robust=robust, + colors=colors, **px_kwargs, ) @@ -326,6 +347,7 @@ def pie( color: SlotValue = None, facet_col: SlotValue = auto, facet_row: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """Create an interactive pie chart. @@ -337,6 +359,7 @@ def pie( color: Dimension for color grouping. Default: None (uses names). facet_col: Dimension for subplot columns. Default: second dimension. facet_row: Dimension for subplot rows. Default: third dimension. + colors: Color specification (scale name, list, or dict). See module docs. **px_kwargs: Additional arguments passed to `plotly.express.pie()`. Returns: @@ -348,6 +371,7 @@ def pie( color=color, facet_col=facet_col, facet_row=facet_row, + colors=colors, **px_kwargs, ) @@ -427,6 +451,7 @@ def line( facet_col: SlotValue = auto, facet_row: SlotValue = auto, animation_frame: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """Create an interactive line plot. @@ -440,6 +465,7 @@ def line( facet_col: Dimension for subplot columns. facet_row: Dimension for subplot rows. animation_frame: Dimension for animation. + colors: Color specification (scale name, list, or dict). See module docs. **px_kwargs: Additional arguments passed to `plotly.express.line()`. Returns: @@ -455,6 +481,7 @@ def line( facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, + colors=colors, **px_kwargs, ) @@ -468,6 +495,7 @@ def bar( facet_col: SlotValue = auto, facet_row: SlotValue = auto, animation_frame: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """Create an interactive bar chart. @@ -480,6 +508,7 @@ def bar( facet_col: Dimension for subplot columns. facet_row: Dimension for subplot rows. animation_frame: Dimension for animation. + colors: Color specification (scale name, list, or dict). See module docs. **px_kwargs: Additional arguments passed to `plotly.express.bar()`. Returns: @@ -494,6 +523,7 @@ def bar( facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, + colors=colors, **px_kwargs, ) @@ -507,6 +537,7 @@ def area( facet_col: SlotValue = auto, facet_row: SlotValue = auto, animation_frame: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """Create an interactive stacked area chart. @@ -519,6 +550,7 @@ def area( facet_col: Dimension for subplot columns. facet_row: Dimension for subplot rows. animation_frame: Dimension for animation. + colors: Color specification (scale name, list, or dict). See module docs. **px_kwargs: Additional arguments passed to `plotly.express.area()`. Returns: @@ -533,6 +565,7 @@ def area( facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, + colors=colors, **px_kwargs, ) @@ -545,6 +578,7 @@ def fast_bar( facet_col: SlotValue = auto, facet_row: SlotValue = auto, animation_frame: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """Create a bar-like chart using stacked areas for better performance. @@ -556,6 +590,7 @@ def fast_bar( facet_col: Dimension for subplot columns. facet_row: Dimension for subplot rows. animation_frame: Dimension for animation. + colors: Color specification (scale name, list, or dict). See module docs. **px_kwargs: Additional arguments passed to `plotly.express.area()`. Returns: @@ -569,6 +604,7 @@ def fast_bar( facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, + colors=colors, **px_kwargs, ) @@ -583,6 +619,7 @@ def scatter( facet_col: SlotValue = auto, facet_row: SlotValue = auto, animation_frame: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """Create an interactive scatter plot. @@ -596,6 +633,7 @@ def scatter( facet_col: Dimension for subplot columns. facet_row: Dimension for subplot rows. animation_frame: Dimension for animation. + colors: Color specification (scale name, list, or dict). See module docs. **px_kwargs: Additional arguments passed to `plotly.express.scatter()`. Returns: @@ -611,6 +649,7 @@ def scatter( facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, + colors=colors, **px_kwargs, ) @@ -623,6 +662,7 @@ def box( facet_col: SlotValue = None, facet_row: SlotValue = None, animation_frame: SlotValue = None, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """Create an interactive box plot. @@ -634,6 +674,7 @@ def box( facet_col: Dimension for subplot columns. facet_row: Dimension for subplot rows. animation_frame: Dimension for animation. + colors: Color specification (scale name, list, or dict). See module docs. **px_kwargs: Additional arguments passed to `plotly.express.box()`. Returns: @@ -647,6 +688,7 @@ def box( facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, + colors=colors, **px_kwargs, ) @@ -658,6 +700,7 @@ def pie( color: SlotValue = None, facet_col: SlotValue = auto, facet_row: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """Create an interactive pie chart. @@ -668,6 +711,7 @@ def pie( color: Dimension for color grouping. facet_col: Dimension for subplot columns. facet_row: Dimension for subplot rows. + colors: Color specification (scale name, list, or dict). See module docs. **px_kwargs: Additional arguments passed to `plotly.express.pie()`. Returns: @@ -680,5 +724,6 @@ def pie( color=color, facet_col=facet_col, facet_row=facet_row, + colors=colors, **px_kwargs, ) diff --git a/xarray_plotly/common.py b/xarray_plotly/common.py index 898c980..743374a 100644 --- a/xarray_plotly/common.py +++ b/xarray_plotly/common.py @@ -2,13 +2,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import functools +import warnings +from collections.abc import Hashable, Mapping, Sequence +from typing import TYPE_CHECKING, Any + +import plotly.express as px from xarray_plotly.config import DEFAULT_SLOT_ORDERS, _options if TYPE_CHECKING: - from collections.abc import Hashable, Sequence - import pandas as pd from xarray import DataArray @@ -27,6 +30,15 @@ def __repr__(self) -> str: SlotValue = _AUTO | str | None """Type alias for slot values: auto, explicit dimension name, or None (skip).""" +Colors = str | Sequence[str] | Mapping[str, str] | None +"""Type alias for unified colors parameter. + +- str: Named color scale (e.g., "Viridis" for continuous, "D3" for discrete) +- Sequence[str]: List of colors for discrete sequence (e.g., ["red", "blue"]) +- Mapping[str, str]: Explicit mapping of values to colors (e.g., {"A": "red"}) +- None: Use Plotly defaults +""" + # Re-export for backward compatibility SLOT_ORDERS = DEFAULT_SLOT_ORDERS """Slot orders per plot type. @@ -227,3 +239,60 @@ def build_labels( labels[value_col] = get_label(darray, "value") return labels + + +@functools.cache +def _get_qualitative_scale_names() -> frozenset[str]: + """Get all named qualitative (discrete) color scales from Plotly.""" + return frozenset( + name + for name in dir(px.colors.qualitative) + if not name.startswith("_") and name[0].isupper() + ) + + +def resolve_colors(colors: Colors, px_kwargs: dict[str, Any]) -> dict[str, Any]: + """Map unified `colors` parameter to appropriate Plotly px_kwargs. + + Direct color_* kwargs take precedence and trigger a warning if + colors was also specified. + + Args: + colors: Unified color specification (str, list, dict, or None). + px_kwargs: Existing kwargs to pass to Plotly Express. + + Returns: + Updated px_kwargs with color parameters injected. + """ + if colors is None: + return px_kwargs + + # Check if any color_* kwarg is present - these take precedence + color_kwargs = [k for k in px_kwargs if k.startswith("color_")] + if color_kwargs: + warnings.warn( + f"`colors` parameter ignored because {color_kwargs[0]!r} " + f"was explicitly provided in px_kwargs.", + UserWarning, + stacklevel=3, + ) + return px_kwargs + + px_kwargs = px_kwargs.copy() + + if isinstance(colors, str): + # Check if it's a qualitative (discrete) palette name + if colors in _get_qualitative_scale_names(): + px_kwargs["color_discrete_sequence"] = getattr(px.colors.qualitative, colors) + else: + # Assume continuous scale + px_kwargs["color_continuous_scale"] = colors + elif isinstance(colors, Mapping): + px_kwargs["color_discrete_map"] = dict(colors) + elif isinstance(colors, Sequence): + px_kwargs["color_discrete_sequence"] = list(colors) + else: + msg = f"`colors` must be str, list, dict, or None, got {type(colors).__name__}" + raise TypeError(msg) + + return px_kwargs diff --git a/xarray_plotly/plotting.py b/xarray_plotly/plotting.py index b1e5085..2a78b64 100644 --- a/xarray_plotly/plotting.py +++ b/xarray_plotly/plotting.py @@ -11,12 +11,14 @@ import plotly.express as px from xarray_plotly.common import ( + Colors, SlotValue, assign_slots, auto, build_labels, get_label, get_value_col, + resolve_colors, to_dataframe, ) from xarray_plotly.figures import ( @@ -38,6 +40,7 @@ def line( facet_col: SlotValue = auto, facet_row: SlotValue = auto, animation_frame: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """ @@ -64,6 +67,13 @@ def line( Dimension for subplot rows. Default: sixth dimension. animation_frame Dimension for animation. Default: seventh dimension. + colors + Unified color specification. Can be: + - A named continuous scale (e.g., "Viridis") + - A named discrete palette (e.g., "D3", "Plotly") + - A list of colors (e.g., ["red", "blue", "green"]) + - A dict mapping values to colors (e.g., {"A": "red", "B": "blue"}) + Explicit color_* kwargs in px_kwargs take precedence. **px_kwargs Additional arguments passed to `plotly.express.line()`. @@ -71,6 +81,7 @@ def line( ------- plotly.graph_objects.Figure """ + px_kwargs = resolve_colors(colors, px_kwargs) slots = assign_slots( list(darray.dims), "line", @@ -111,6 +122,7 @@ def bar( facet_col: SlotValue = auto, facet_row: SlotValue = auto, animation_frame: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """ @@ -135,6 +147,13 @@ def bar( Dimension for subplot rows. Default: fifth dimension. animation_frame Dimension for animation. Default: sixth dimension. + colors + Unified color specification. Can be: + - A named continuous scale (e.g., "Viridis") + - A named discrete palette (e.g., "D3", "Plotly") + - A list of colors (e.g., ["red", "blue", "green"]) + - A dict mapping values to colors (e.g., {"A": "red", "B": "blue"}) + Explicit color_* kwargs in px_kwargs take precedence. **px_kwargs Additional arguments passed to `plotly.express.bar()`. @@ -142,6 +161,7 @@ def bar( ------- plotly.graph_objects.Figure """ + px_kwargs = resolve_colors(colors, px_kwargs) slots = assign_slots( list(darray.dims), "bar", @@ -263,6 +283,7 @@ def fast_bar( facet_col: SlotValue = auto, facet_row: SlotValue = auto, animation_frame: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """ @@ -293,6 +314,13 @@ def fast_bar( Dimension for subplot rows. Default: fourth dimension. animation_frame Dimension for animation. Default: fifth dimension. + colors + Unified color specification. Can be: + - A named continuous scale (e.g., "Viridis") + - A named discrete palette (e.g., "D3", "Plotly") + - A list of colors (e.g., ["red", "blue", "green"]) + - A dict mapping values to colors (e.g., {"A": "red", "B": "blue"}) + Explicit color_* kwargs in px_kwargs take precedence. **px_kwargs Additional arguments passed to `plotly.express.area()`. @@ -300,6 +328,7 @@ def fast_bar( ------- plotly.graph_objects.Figure """ + px_kwargs = resolve_colors(colors, px_kwargs) slots = assign_slots( list(darray.dims), "fast_bar", @@ -341,6 +370,7 @@ def area( facet_col: SlotValue = auto, facet_row: SlotValue = auto, animation_frame: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """ @@ -365,6 +395,13 @@ def area( Dimension for subplot rows. Default: fifth dimension. animation_frame Dimension for animation. Default: sixth dimension. + colors + Unified color specification. Can be: + - A named continuous scale (e.g., "Viridis") + - A named discrete palette (e.g., "D3", "Plotly") + - A list of colors (e.g., ["red", "blue", "green"]) + - A dict mapping values to colors (e.g., {"A": "red", "B": "blue"}) + Explicit color_* kwargs in px_kwargs take precedence. **px_kwargs Additional arguments passed to `plotly.express.area()`. @@ -372,6 +409,7 @@ def area( ------- plotly.graph_objects.Figure """ + px_kwargs = resolve_colors(colors, px_kwargs) slots = assign_slots( list(darray.dims), "area", @@ -409,6 +447,7 @@ def box( facet_col: SlotValue = None, facet_row: SlotValue = None, animation_frame: SlotValue = None, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """ @@ -433,6 +472,13 @@ def box( Dimension for subplot rows. Default: None (aggregated). animation_frame Dimension for animation. Default: None (aggregated). + colors + Unified color specification. Can be: + - A named continuous scale (e.g., "Viridis") + - A named discrete palette (e.g., "D3", "Plotly") + - A list of colors (e.g., ["red", "blue", "green"]) + - A dict mapping values to colors (e.g., {"A": "red", "B": "blue"}) + Explicit color_* kwargs in px_kwargs take precedence. **px_kwargs Additional arguments passed to `plotly.express.box()`. @@ -440,6 +486,7 @@ def box( ------- plotly.graph_objects.Figure """ + px_kwargs = resolve_colors(colors, px_kwargs) slots = assign_slots( list(darray.dims), "box", @@ -478,6 +525,7 @@ def scatter( facet_col: SlotValue = auto, facet_row: SlotValue = auto, animation_frame: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """ @@ -509,6 +557,13 @@ def scatter( Dimension for subplot rows. Default: fifth dimension. animation_frame Dimension for animation. Default: sixth dimension. + colors + Unified color specification. Can be: + - A named continuous scale (e.g., "Viridis") + - A named discrete palette (e.g., "D3", "Plotly") + - A list of colors (e.g., ["red", "blue", "green"]) + - A dict mapping values to colors (e.g., {"A": "red", "B": "blue"}) + Explicit color_* kwargs in px_kwargs take precedence. **px_kwargs Additional arguments passed to `plotly.express.scatter()`. @@ -516,6 +571,7 @@ def scatter( ------- plotly.graph_objects.Figure """ + px_kwargs = resolve_colors(colors, px_kwargs) # If y is a dimension, exclude it from slot assignment y_is_dim = y != "value" and y in darray.dims dims_for_slots = [d for d in darray.dims if d != y] if y_is_dim else list(darray.dims) @@ -565,6 +621,7 @@ def imshow( facet_col: SlotValue = auto, animation_frame: SlotValue = auto, robust: bool = False, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """ @@ -596,6 +653,11 @@ def imshow( robust If True, compute color bounds using 2nd and 98th percentiles for robustness against outliers. Default: False (uses min/max). + colors + Unified color specification. For imshow, typically a named + continuous scale (e.g., "Viridis", "RdBu"). Lists and dicts + are not applicable for heatmaps. + Explicit color_continuous_scale in px_kwargs takes precedence. **px_kwargs Additional arguments passed to `plotly.express.imshow()`. Use `zmin` and `zmax` to manually set color scale bounds. @@ -604,6 +666,7 @@ def imshow( ------- plotly.graph_objects.Figure """ + px_kwargs = resolve_colors(colors, px_kwargs) slots = assign_slots( list(darray.dims), "imshow", @@ -648,6 +711,7 @@ def pie( color: SlotValue = None, facet_col: SlotValue = auto, facet_row: SlotValue = auto, + colors: Colors = None, **px_kwargs: Any, ) -> go.Figure: """ @@ -668,6 +732,12 @@ def pie( Dimension for subplot columns. Default: second dimension. facet_row Dimension for subplot rows. Default: third dimension. + colors + Unified color specification. Can be: + - A named discrete palette (e.g., "D3", "Plotly") + - A list of colors (e.g., ["red", "blue", "green"]) + - A dict mapping values to colors (e.g., {"A": "red", "B": "blue"}) + Explicit color_* kwargs in px_kwargs take precedence. **px_kwargs Additional arguments passed to `plotly.express.pie()`. @@ -675,6 +745,7 @@ def pie( ------- plotly.graph_objects.Figure """ + px_kwargs = resolve_colors(colors, px_kwargs) slots = assign_slots( list(darray.dims), "pie", From 3ac4f49bb1d60f35619d5d99a129670fa5bbdffe Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 26 Jan 2026 17:38:14 +0100 Subject: [PATCH 2/4] Add tests for new colors parameter --- tests/test_accessor.py | 124 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) diff --git a/tests/test_accessor.py b/tests/test_accessor.py index a8f23d2..d14d682 100644 --- a/tests/test_accessor.py +++ b/tests/test_accessor.py @@ -401,3 +401,127 @@ def test_imshow_animation_consistent_bounds(self) -> None: coloraxis = fig.layout.coloraxis assert coloraxis.cmin == 0.0 assert coloraxis.cmax == 70.0 + + +class TestColorsParameter: + """Tests for the unified colors parameter.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + """Create test DataArrays.""" + self.da = xr.DataArray( + np.random.rand(10, 3), + dims=["time", "city"], + coords={"city": ["A", "B", "C"]}, + ) + + def test_colors_list_sets_discrete_sequence(self) -> None: + """Test that a list of colors sets color_discrete_sequence.""" + fig = self.da.plotly.line(colors=["red", "blue", "green"]) + # Check that traces have the expected colors + assert len(fig.data) == 3 + assert fig.data[0].line.color == "red" + assert fig.data[1].line.color == "blue" + assert fig.data[2].line.color == "green" + + def test_colors_dict_sets_discrete_map(self) -> None: + """Test that a dict sets color_discrete_map.""" + fig = self.da.plotly.line(colors={"A": "red", "B": "blue", "C": "green"}) + # Traces should be colored according to the mapping + assert len(fig.data) == 3 + # Find traces by name and check their color + colors_by_name = {trace.name: trace.line.color for trace in fig.data} + assert colors_by_name["A"] == "red" + assert colors_by_name["B"] == "blue" + assert colors_by_name["C"] == "green" + + def test_colors_continuous_scale_string(self) -> None: + """Test that a continuous scale name sets color_continuous_scale.""" + da = xr.DataArray( + np.random.rand(50, 2), + dims=["point", "coord"], + coords={"coord": ["x", "y"]}, + ) + fig = da.plotly.scatter(y="coord", x="point", color="value", colors="Viridis") + # Plotly Express uses coloraxis in the layout for continuous scales + # Check that the colorscale was applied to the coloraxis + assert fig.layout.coloraxis.colorscale is not None + colorscale = fig.layout.coloraxis.colorscale + # Viridis should be in the colorscale definition + assert any("viridis" in str(c).lower() for c in colorscale) or len(colorscale) > 0 + + def test_colors_qualitative_palette_string(self) -> None: + """Test that a qualitative palette name sets color_discrete_sequence.""" + import plotly.express as px + + fig = self.da.plotly.line(colors="D3") + # D3 palette should be applied - check first trace color is from D3 + d3_colors = px.colors.qualitative.D3 + assert fig.data[0].line.color in d3_colors + + def test_colors_ignored_with_warning_when_px_kwargs_present(self) -> None: + """Test that colors is ignored with warning when color_* kwargs are present.""" + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + fig = self.da.plotly.line( + colors="D3", color_discrete_sequence=["orange", "purple", "cyan"] + ) + # Should have raised a warning + assert len(w) == 1 + assert "colors" in str(w[0].message).lower() + assert "ignored" in str(w[0].message).lower() + # The explicit px_kwargs should take precedence + assert fig.data[0].line.color == "orange" + + def test_colors_none_uses_defaults(self) -> None: + """Test that colors=None uses Plotly defaults.""" + fig1 = self.da.plotly.line(colors=None) + fig2 = self.da.plotly.line() + # Both should produce the same result + assert fig1.data[0].line.color == fig2.data[0].line.color + + def test_colors_works_with_bar(self) -> None: + """Test colors parameter with bar chart.""" + fig = self.da.plotly.bar(colors=["#e41a1c", "#377eb8", "#4daf4a"]) + assert fig.data[0].marker.color == "#e41a1c" + + def test_colors_works_with_area(self) -> None: + """Test colors parameter with area chart.""" + fig = self.da.plotly.area(colors=["red", "green", "blue"]) + assert len(fig.data) == 3 + + def test_colors_works_with_scatter(self) -> None: + """Test colors parameter with scatter plot.""" + fig = self.da.plotly.scatter(colors=["red", "green", "blue"]) + assert len(fig.data) == 3 + + def test_colors_works_with_imshow(self) -> None: + """Test colors parameter with imshow (continuous scale).""" + da = xr.DataArray(np.random.rand(10, 10), dims=["y", "x"]) + fig = da.plotly.imshow(colors="RdBu") + # Plotly Express uses coloraxis in the layout for continuous scales + assert fig.layout.coloraxis.colorscale is not None + colorscale = fig.layout.coloraxis.colorscale + # RdBu should be in the colorscale definition + assert any("rdbu" in str(c).lower() for c in colorscale) or len(colorscale) > 0 + + def test_colors_works_with_pie(self) -> None: + """Test colors parameter with pie chart.""" + da = xr.DataArray([30, 40, 30], dims=["category"], coords={"category": ["A", "B", "C"]}) + fig = da.plotly.pie(colors={"A": "red", "B": "blue", "C": "green"}) + assert isinstance(fig, go.Figure) + + def test_colors_works_with_dataset(self) -> None: + """Test colors parameter works with Dataset accessor.""" + ds = xr.Dataset( + { + "temp": (["time"], np.random.rand(10)), + "precip": (["time"], np.random.rand(10)), + } + ) + fig = ds.plotly.line(colors=["red", "blue"]) + assert len(fig.data) == 2 + assert fig.data[0].line.color == "red" + assert fig.data[1].line.color == "blue" From 256c963116f524dc3fa887837a2f113fd7918f45 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 26 Jan 2026 17:43:31 +0100 Subject: [PATCH 3/4] Update notebook to show off new colors parameter --- docs/examples/kwargs.ipynb | 65 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/docs/examples/kwargs.ipynb b/docs/examples/kwargs.ipynb index d28e100..e708ade 100644 --- a/docs/examples/kwargs.ipynb +++ b/docs/examples/kwargs.ipynb @@ -159,6 +159,71 @@ "xpx(change).imshow(color_continuous_scale=\"RdBu_r\", color_continuous_midpoint=0)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## colors (unified parameter)\n", + "\n", + "The `colors` parameter provides a simpler way to set colors without remembering the exact Plotly parameter name. It automatically maps to the correct parameter based on the input type:\n", + "\n", + "| Input | Maps To |\n", + "|-------|---------|\n", + "| `\"Viridis\"` (continuous scale name) | `color_continuous_scale` |\n", + "| `\"D3\"` (qualitative palette name) | `color_discrete_sequence` |\n", + "| `[\"red\", \"blue\"]` (list) | `color_discrete_sequence` |\n", + "| `{\"A\": \"red\"}` (dict) | `color_discrete_map` |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Named qualitative palette\n", + "xpx(stocks).line(colors=\"D3\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# List of custom colors\n", + "xpx(stocks).line(colors=[\"#E63946\", \"#457B9D\", \"#2A9D8F\", \"#E9C46A\", \"#F4A261\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Dict for explicit mapping\n", + "xpx(stocks).line(\n", + " colors={\n", + " \"GOOG\": \"red\",\n", + " \"AAPL\": \"blue\",\n", + " \"AMZN\": \"green\",\n", + " \"FB\": \"purple\",\n", + " \"NFLX\": \"orange\",\n", + " \"MSFT\": \"brown\",\n", + " }\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Continuous scale for heatmaps\n", + "xpx(stocks).imshow(colors=\"Plasma\")" + ] + }, { "cell_type": "markdown", "metadata": {}, From ed947c3dee4faa25c6ed49793ed7bf669b9d24f1 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 26 Jan 2026 17:56:45 +0100 Subject: [PATCH 4/4] Improve test --- tests/test_accessor.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_accessor.py b/tests/test_accessor.py index d14d682..23685d7 100644 --- a/tests/test_accessor.py +++ b/tests/test_accessor.py @@ -468,10 +468,11 @@ def test_colors_ignored_with_warning_when_px_kwargs_present(self) -> None: fig = self.da.plotly.line( colors="D3", color_discrete_sequence=["orange", "purple", "cyan"] ) - # Should have raised a warning - assert len(w) == 1 - assert "colors" in str(w[0].message).lower() - assert "ignored" in str(w[0].message).lower() + # Should have raised a warning about colors being ignored + assert any( + "colors" in str(m.message).lower() and "ignored" in str(m.message).lower() + for m in w + ), "Expected warning about 'colors' being 'ignored' not found" # The explicit px_kwargs should take precedence assert fig.data[0].line.color == "orange"