diff --git a/docs/examples/fast_bar.ipynb b/docs/examples/fast_bar.ipynb new file mode 100644 index 0000000..0f4d2df --- /dev/null +++ b/docs/examples/fast_bar.ipynb @@ -0,0 +1,297 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Fast Bar Charts\n", + "\n", + "The `fast_bar()` method creates bar-like visualizations using stacked areas. This renders much faster than actual bar charts for large datasets because it uses a single polygon per trace instead of individual rectangles." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import xarray as xr\n", + "\n", + "from xarray_plotly import config, xpx\n", + "\n", + "config.notebook()" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## Basic Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "# Quarterly revenue data by product and region\n", + "np.random.seed(42)\n", + "da = xr.DataArray(\n", + " np.random.rand(4, 3, 2) * 100 + 50,\n", + " dims=[\"quarter\", \"product\", \"region\"],\n", + " coords={\n", + " \"quarter\": [\"Q1\", \"Q2\", \"Q3\", \"Q4\"],\n", + " \"product\": [\"Widgets\", \"Gadgets\", \"Gizmos\"],\n", + " \"region\": [\"North\", \"South\"],\n", + " },\n", + " name=\"revenue\",\n", + ")\n", + "\n", + "xpx(da).fast_bar()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "# Comparison with regular bar()\n", + "xpx(da).bar()" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "## With Faceting" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "xpx(da).fast_bar(facet_col=\"region\")" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "## With Animation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "# Multi-year data for animation\n", + "np.random.seed(123)\n", + "da_anim = xr.DataArray(\n", + " np.random.rand(4, 3, 5) * 100 + 20,\n", + " dims=[\"quarter\", \"product\", \"year\"],\n", + " coords={\n", + " \"quarter\": [\"Q1\", \"Q2\", \"Q3\", \"Q4\"],\n", + " \"product\": [\"Widgets\", \"Gadgets\", \"Gizmos\"],\n", + " \"year\": [2020, 2021, 2022, 2023, 2024],\n", + " },\n", + " name=\"revenue\",\n", + ")\n", + "\n", + "xpx(da_anim).fast_bar(animation_frame=\"year\")" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## Faceting + Animation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "# 4D data: quarter x product x region x year\n", + "np.random.seed(456)\n", + "da_4d = xr.DataArray(\n", + " np.random.rand(4, 3, 2, 4) * 80 + 30,\n", + " dims=[\"quarter\", \"product\", \"region\", \"year\"],\n", + " coords={\n", + " \"quarter\": [\"Q1\", \"Q2\", \"Q3\", \"Q4\"],\n", + " \"product\": [\"Widgets\", \"Gadgets\", \"Gizmos\"],\n", + " \"region\": [\"North\", \"South\"],\n", + " \"year\": [2021, 2022, 2023, 2024],\n", + " },\n", + " name=\"revenue\",\n", + ")\n", + "\n", + "xpx(da_4d).fast_bar(facet_col=\"region\", animation_frame=\"year\")" + ] + }, + { + "cell_type": "markdown", + "id": "11", + "metadata": {}, + "source": [ + "## Positive and Negative Values\n", + "\n", + "`fast_bar()` classifies each trace by its values:\n", + "- **Purely positive** → stacks upward\n", + "- **Purely negative** → stacks downward\n", + "- **Mixed signs** → warning + dashed line (use `bar()` instead)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "# Profit (positive) and Loss (negative) - stacks correctly\n", + "np.random.seed(789)\n", + "da_split = xr.DataArray(\n", + " np.column_stack(\n", + " [\n", + " np.random.rand(6) * 80 + 20, # Revenue: positive\n", + " -np.random.rand(6) * 50 - 10, # Costs: negative\n", + " ]\n", + " ),\n", + " dims=[\"month\", \"category\"],\n", + " coords={\n", + " \"month\": [\"Jan\", \"Feb\", \"Mar\", \"Apr\", \"May\", \"Jun\"],\n", + " \"category\": [\"Revenue\", \"Costs\"],\n", + " },\n", + " name=\"financials\",\n", + ")\n", + "\n", + "xpx(da_split).fast_bar()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "# With animation - sign classification is consistent across frames\n", + "np.random.seed(321)\n", + "da_split_anim = xr.DataArray(\n", + " np.stack(\n", + " [\n", + " np.column_stack([np.random.rand(6) * 80 + 20, -np.random.rand(6) * 50 - 10])\n", + " for _ in range(4)\n", + " ],\n", + " axis=-1,\n", + " ),\n", + " dims=[\"month\", \"category\", \"year\"],\n", + " coords={\n", + " \"month\": [\"Jan\", \"Feb\", \"Mar\", \"Apr\", \"May\", \"Jun\"],\n", + " \"category\": [\"Revenue\", \"Costs\"],\n", + " \"year\": [2021, 2022, 2023, 2024],\n", + " },\n", + " name=\"financials\",\n", + ")\n", + "\n", + "xpx(da_split_anim).fast_bar(animation_frame=\"year\")" + ] + }, + { + "cell_type": "markdown", + "id": "14", + "metadata": {}, + "source": [ + "## Mixed Sign Warning\n", + "\n", + "When a trace has both positive and negative values, `fast_bar()` shows a warning and displays it as a dashed line:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "# Both columns have mixed signs - triggers warning\n", + "da_mixed = xr.DataArray(\n", + " np.array(\n", + " [\n", + " [50, -30],\n", + " [-40, 60],\n", + " [30, -50],\n", + " [-20, 40],\n", + " ]\n", + " ),\n", + " dims=[\"month\", \"category\"],\n", + " coords={\n", + " \"month\": [\"Jan\", \"Feb\", \"Mar\", \"Apr\"],\n", + " \"category\": [\"A\", \"B\"],\n", + " },\n", + ")\n", + "\n", + "# This will show a warning\n", + "xpx(da_mixed).fast_bar()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "# For mixed data, use bar() instead\n", + "xpx(da_mixed).bar()" + ] + }, + { + "cell_type": "markdown", + "id": "17", + "metadata": {}, + "source": [ + "## When to Use\n", + "\n", + "| Method | Use when... |\n", + "|--------|-------------|\n", + "| `fast_bar()` | Large datasets, animations, performance matters, data is same-sign per trace |\n", + "| `bar()` | Need grouped bars, pattern fills, or have mixed +/- values per trace |\n", + "| `area()` | Want smooth continuous fills |" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mkdocs.yml b/mkdocs.yml index 1a1f0f1..5a38157 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -73,4 +73,7 @@ nav: - Dimensions & Facets: examples/dimensions.ipynb - Plotly Express Options: examples/kwargs.ipynb - Figure Customization: examples/figure.ipynb + - Combining Figures: examples/combining.ipynb + - Figure Manipulation: examples/manipulation.ipynb + - Fast Bar Charts: examples/fast_bar.ipynb - API Reference: api.md diff --git a/tests/test_accessor.py b/tests/test_accessor.py index 9786112..a8f23d2 100644 --- a/tests/test_accessor.py +++ b/tests/test_accessor.py @@ -146,6 +146,66 @@ def test_area_returns_figure(self) -> None: fig = self.da_2d.plotly.area() assert isinstance(fig, go.Figure) + def test_fast_bar_returns_figure(self) -> None: + """Test that fast_bar() returns a Plotly Figure.""" + fig = self.da_2d.plotly.fast_bar() + assert isinstance(fig, go.Figure) + + def test_fast_bar_trace_styling(self) -> None: + """Test that fast_bar applies correct trace styling.""" + fig = self.da_2d.plotly.fast_bar() + for trace in fig.data: + assert trace.line.width == 0 + assert trace.line.shape == "hv" + assert trace.fillcolor is not None + + def test_fast_bar_animation_frames(self) -> None: + """Test that fast_bar styling applies to animation frames.""" + da = xr.DataArray( + np.random.rand(5, 3, 4), + dims=["time", "city", "year"], + ) + fig = da.plotly.fast_bar(animation_frame="year") + assert len(fig.frames) > 0 + for frame in fig.frames: + for trace in frame.data: + assert trace.line.width == 0 + assert trace.line.shape == "hv" + assert trace.fillcolor is not None + + def test_fast_bar_mixed_signs_dashed(self) -> None: + """Test that fast_bar shows mixed-sign traces as dashed lines.""" + da = xr.DataArray( + np.array([[50, -30], [-40, 60]]), # Both columns have mixed signs + dims=["time", "category"], + ) + fig = da.plotly.fast_bar() + # Mixed traces should have no stacking and dashed lines + for trace in fig.data: + assert trace.stackgroup is None + assert trace.line.dash == "dash" + + def test_fast_bar_separate_sign_columns(self) -> None: + """Test that fast_bar uses separate stackgroups when columns have different signs.""" + da = xr.DataArray( + np.array([[50, -30], [60, -40]]), # Column 0 positive, column 1 negative + dims=["time", "category"], + ) + fig = da.plotly.fast_bar() + stackgroups = {trace.stackgroup for trace in fig.data} + assert "positive" in stackgroups + assert "negative" in stackgroups + + def test_fast_bar_same_sign_stacks(self) -> None: + """Test that fast_bar uses stacking for same-sign data.""" + da = xr.DataArray( + np.random.rand(5, 3) * 100, + dims=["time", "category"], + ) + fig = da.plotly.fast_bar() + for trace in fig.data: + assert trace.stackgroup is not None + def test_scatter_returns_figure(self) -> None: """Test that scatter() returns a Plotly Figure.""" fig = self.da_2d.plotly.scatter() diff --git a/xarray_plotly/accessor.py b/xarray_plotly/accessor.py index ff45d26..41f60b4 100644 --- a/xarray_plotly/accessor.py +++ b/xarray_plotly/accessor.py @@ -34,7 +34,7 @@ class DataArrayPlotlyAccessor: ``` """ - __all__: ClassVar = ["line", "bar", "area", "scatter", "box", "imshow", "pie"] + __all__: ClassVar = ["line", "bar", "fast_bar", "area", "scatter", "box", "imshow", "pie"] def __init__(self, darray: DataArray) -> None: self._da = darray @@ -160,6 +160,41 @@ def area( **px_kwargs, ) + def fast_bar( + self, + *, + x: SlotValue = auto, + color: SlotValue = auto, + facet_col: SlotValue = auto, + facet_row: SlotValue = auto, + animation_frame: SlotValue = auto, + **px_kwargs: Any, + ) -> go.Figure: + """Create a bar-like chart using stacked areas for better performance. + + Slot order: x -> color -> facet_col -> facet_row -> animation_frame + + Args: + x: Dimension for x-axis. Default: first dimension. + color: Dimension for color/stacking. Default: second dimension. + facet_col: Dimension for subplot columns. Default: third dimension. + facet_row: Dimension for subplot rows. Default: fourth dimension. + animation_frame: Dimension for animation. Default: fifth dimension. + **px_kwargs: Additional arguments passed to `plotly.express.area()`. + + Returns: + Interactive Plotly Figure. + """ + return plotting.fast_bar( + self._da, + x=x, + color=color, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + **px_kwargs, + ) + def scatter( self, *, @@ -349,7 +384,7 @@ class DatasetPlotlyAccessor: ``` """ - __all__: ClassVar = ["line", "bar", "area", "scatter", "box", "pie"] + __all__: ClassVar = ["line", "bar", "fast_bar", "area", "scatter", "box", "pie"] def __init__(self, dataset: Dataset) -> None: self._ds = dataset @@ -501,6 +536,42 @@ def area( **px_kwargs, ) + def fast_bar( + self, + var: str | None = None, + *, + x: SlotValue = auto, + color: SlotValue = auto, + facet_col: SlotValue = auto, + facet_row: SlotValue = auto, + animation_frame: SlotValue = auto, + **px_kwargs: Any, + ) -> go.Figure: + """Create a bar-like chart using stacked areas for better performance. + + Args: + var: Variable to plot. If None, plots all variables with "variable" dimension. + x: Dimension for x-axis. + color: Dimension for color/stacking. + facet_col: Dimension for subplot columns. + facet_row: Dimension for subplot rows. + animation_frame: Dimension for animation. + **px_kwargs: Additional arguments passed to `plotly.express.area()`. + + Returns: + Interactive Plotly Figure. + """ + da = self._get_dataarray(var) + return plotting.fast_bar( + da, + x=x, + color=color, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + **px_kwargs, + ) + def scatter( self, var: str | None = None, diff --git a/xarray_plotly/config.py b/xarray_plotly/config.py index e28cab0..0f704a3 100644 --- a/xarray_plotly/config.py +++ b/xarray_plotly/config.py @@ -26,6 +26,7 @@ "animation_frame", ), "bar": ("x", "color", "pattern_shape", "facet_col", "facet_row", "animation_frame"), + "fast_bar": ("x", "color", "facet_col", "facet_row", "animation_frame"), "area": ( "x", "color", diff --git a/xarray_plotly/plotting.py b/xarray_plotly/plotting.py index 638de94..69acf6c 100644 --- a/xarray_plotly/plotting.py +++ b/xarray_plotly/plotting.py @@ -4,6 +4,7 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any import numpy as np @@ -167,6 +168,169 @@ def bar( ) +def _classify_trace_sign(y_values: np.ndarray) -> str: + """Classify a trace as 'positive', 'negative', or 'mixed' based on its values.""" + y_arr = np.asarray(y_values) + y_clean = y_arr[np.isfinite(y_arr) & (np.abs(y_arr) > 1e-9)] + if len(y_clean) == 0: + return "zero" + has_pos = bool(np.any(y_clean > 0)) + has_neg = bool(np.any(y_clean < 0)) + if has_pos and has_neg: + return "mixed" + elif has_neg: + return "negative" + elif has_pos: + return "positive" + return "zero" + + +def _style_traces_as_bars(fig: go.Figure) -> None: + """Style area chart traces to look like bar charts with proper pos/neg stacking. + + Classifies each trace (by name) across all data and animation frames, + then assigns stackgroups: positive traces stack upward, negative stack downward. + """ + # Collect all traces (main + animation frames) + all_traces = list(fig.data) + for frame in fig.frames: + all_traces.extend(frame.data) + + # Classify each trace name by aggregating sign info across all occurrences + sign_flags: dict[str, dict[str, bool]] = {} + for trace in all_traces: + if trace.name not in sign_flags: + sign_flags[trace.name] = {"has_pos": False, "has_neg": False} + if trace.y is not None and len(trace.y) > 0: + y_arr = np.asarray(trace.y) + y_clean = y_arr[np.isfinite(y_arr) & (np.abs(y_arr) > 1e-9)] + if len(y_clean) > 0: + if np.any(y_clean > 0): + sign_flags[trace.name]["has_pos"] = True + if np.any(y_clean < 0): + sign_flags[trace.name]["has_neg"] = True + + # Build classification map + class_map: dict[str, str] = {} + mixed_traces: list[str] = [] + for name, flags in sign_flags.items(): + if flags["has_pos"] and flags["has_neg"]: + class_map[name] = "mixed" + mixed_traces.append(name) + elif flags["has_neg"]: + class_map[name] = "negative" + elif flags["has_pos"]: + class_map[name] = "positive" + else: + class_map[name] = "zero" + + # Warn about mixed traces + if mixed_traces: + warnings.warn( + f"fast_bar: traces {mixed_traces} have mixed positive/negative values " + "and cannot be stacked. They are shown as dashed lines. " + "Consider using bar() for proper stacking of mixed data.", + UserWarning, + stacklevel=3, + ) + + # Apply styling to all traces + for trace in all_traces: + color = trace.line.color + cls = class_map.get(trace.name, "positive") + + if cls in ("positive", "negative"): + trace.stackgroup = cls + trace.fillcolor = color + trace.line = {"width": 0, "color": color, "shape": "hv"} + elif cls == "mixed": + # Mixed: no stacking, show as dashed line + trace.stackgroup = None + trace.fill = None + trace.line = {"width": 2, "color": color, "shape": "hv", "dash": "dash"} + else: # zero + trace.stackgroup = None + trace.fill = None + trace.line = {"width": 0, "color": color, "shape": "hv"} + + +def fast_bar( + darray: DataArray, + *, + x: SlotValue = auto, + color: SlotValue = auto, + facet_col: SlotValue = auto, + facet_row: SlotValue = auto, + animation_frame: SlotValue = auto, + **px_kwargs: Any, +) -> go.Figure: + """ + Create a bar-like chart using stacked areas for better performance. + + Uses `px.area` with stepped lines and no outline to create a bar-like + appearance. Renders faster than `bar()` for large datasets because it + uses a single polygon per trace instead of individual rectangles. + + The y-axis shows DataArray values. Dimensions fill slots in order: + x -> color -> facet_col -> facet_row -> animation_frame + + Traces are classified by their values: purely positive traces stack upward, + purely negative traces stack downward. Traces with mixed signs are shown + as dashed lines without stacking. + + Parameters + ---------- + darray + The DataArray to plot. + x + Dimension for x-axis. Default: first dimension. + color + Dimension for color/stacking. Default: second dimension. + facet_col + Dimension for subplot columns. Default: third dimension. + facet_row + Dimension for subplot rows. Default: fourth dimension. + animation_frame + Dimension for animation. Default: fifth dimension. + **px_kwargs + Additional arguments passed to `plotly.express.area()`. + + Returns + ------- + plotly.graph_objects.Figure + """ + slots = assign_slots( + list(darray.dims), + "fast_bar", + x=x, + color=color, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + ) + + df = to_dataframe(darray) + value_col = get_value_col(darray) + labels = {**build_labels(darray, slots, value_col), **px_kwargs.pop("labels", {})} + + fig = px.area( + df, + x=slots.get("x"), + y=value_col, + color=slots.get("color"), + facet_col=slots.get("facet_col"), + facet_row=slots.get("facet_row"), + animation_frame=slots.get("animation_frame"), + line_shape="hv", + labels=labels, + **px_kwargs, + ) + + _style_traces_as_bars(fig) + + return fig + + def area( darray: DataArray, *,