From 3cb5f9bd9a51986e681403a532d3b7fa97870478 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Wed, 21 Jan 2026 17:08:07 +0100 Subject: [PATCH 01/10] =?UTF-8?q?Files=20created/modified:=20=20=20?= =?UTF-8?q?=E2=94=8C=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=AC=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=90=20=20=20=E2=94=82=20=20=20=20=20=20=20=20=20=20=20=20?= =?UTF-8?q?=20File=20=20=20=20=20=20=20=20=20=20=20=20=20=20=E2=94=82=20?= =?UTF-8?q?=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20?= =?UTF-8?q?=20=20=20=20=20=20=20=20=20=20=20=20=20=20Changes=20=20=20=20?= =?UTF-8?q?=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20?= =?UTF-8?q?=20=20=20=20=20=20=20=20=20=20=20=E2=94=82=20=20=20=E2=94=9C?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=BC=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=A4?= =?UTF-8?q?=20=20=20=E2=94=82=20xarray=5Fplotly/figures.py=20=20=20=20=20?= =?UTF-8?q?=20=E2=94=82=20Renamed=20combine=5Ffigures=20=E2=86=92=20overla?= =?UTF-8?q?y=5Ffigures=20(with=20alias),=20added=20add=5Fsecondary=5Fy=20?= =?UTF-8?q?=E2=94=82=20=20=20=E2=94=9C=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=BC?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=A4=20=20=20=E2=94=82=20xarray=5Fplotly?= =?UTF-8?q?/=5F=5Finit=5F=5F.py=20=20=20=20=20=E2=94=82=20Exported=20overl?= =?UTF-8?q?ay=5Ffigures,=20add=5Fsecondary=5Fy,=20combine=5Ffigures=20=20?= =?UTF-8?q?=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=E2=94=82?= =?UTF-8?q?=20=20=20=E2=94=9C=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=BC=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=A4=20=20=20=E2=94=82=20tests/test=5Ffigures.py?= =?UTF-8?q?=20=20=20=20=20=20=20=20=20=E2=94=82=20Added=2016=20new=20tests?= =?UTF-8?q?=20for=20add=5Fsecondary=5Fy=20and=20alias=20verification=20(34?= =?UTF-8?q?=20total)=20=20=20=20=20=20=E2=94=82=20=20=20=E2=94=9C=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=BC=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=A4=20=20=20?= =?UTF-8?q?=E2=94=82=20docs/examples/combining.ipynb=20=E2=94=82=20New=20n?= =?UTF-8?q?otebook=20demonstrating=20both=20methods=20=20=20=20=20=20=20?= =?UTF-8?q?=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20=20?= =?UTF-8?q?=20=20=20=20=20=20=20=20=20=20=20=E2=94=82=20=20=20=E2=94=94?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=B4=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80?= =?UTF-8?q?=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=80=E2=94=98?= =?UTF-8?q?=20=20=20API:?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit from xarray_plotly import overlay_figures, add_secondary_y # Overlay traces on same axes combined = overlay_figures(area_fig, line_fig) # Dual y-axis (different scales) combined = add_secondary_y(temp_fig, precip_fig, secondary_y_title="Rain (mm)") Features: - overlay_figures: Supports facets, animation, multiple overlays - add_secondary_y: Supports animation, custom y-axis title - Both create deep copies (originals not modified) - Both validate compatibility and raise clear errors Test results: 99 tests passing --- docs/examples/combining.ipynb | 335 +++++++++++++++++++++ tests/test_figures.py | 552 ++++++++++++++++++++++++++++++++++ xarray_plotly/__init__.py | 4 + xarray_plotly/figures.py | 341 +++++++++++++++++++++ 4 files changed, 1232 insertions(+) create mode 100644 docs/examples/combining.ipynb create mode 100644 tests/test_figures.py create mode 100644 xarray_plotly/figures.py diff --git a/docs/examples/combining.ipynb b/docs/examples/combining.ipynb new file mode 100644 index 0000000..1096b56 --- /dev/null +++ b/docs/examples/combining.ipynb @@ -0,0 +1,335 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Combining Figures\n", + "\n", + "xarray-plotly provides helper functions to combine multiple figures:\n", + "\n", + "- **`overlay_figures`**: Overlay traces on the same axes\n", + "- **`add_secondary_y`**: Plot with two independent y-axes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import xarray as xr\n", + "\n", + "from xarray_plotly import add_secondary_y, config, overlay_figures, xpx\n", + "\n", + "config.notebook()" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## Sample Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "# Time series with categories\n", + "np.random.seed(42)\n", + "time = np.arange(50)\n", + "\n", + "sales = xr.DataArray(\n", + " np.cumsum(np.random.randn(50, 3), axis=0) + 100,\n", + " dims=[\"day\", \"product\"],\n", + " coords={\"day\": time, \"product\": [\"Widget\", \"Gadget\", \"Gizmo\"]},\n", + " name=\"Sales\",\n", + ")\n", + "\n", + "# Two variables with different scales\n", + "temperature = xr.DataArray(\n", + " 20 + 10 * np.sin(time / 10) + np.random.randn(50),\n", + " dims=[\"day\"],\n", + " coords={\"day\": time},\n", + " name=\"Temperature\",\n", + " attrs={\"units\": \"°C\"},\n", + ")\n", + "\n", + "precipitation = xr.DataArray(\n", + " np.maximum(0, 5 + 10 * np.random.randn(50)),\n", + " dims=[\"day\"],\n", + " coords={\"day\": time},\n", + " name=\"Precipitation\",\n", + " attrs={\"units\": \"mm\"},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "## overlay_figures\n", + "\n", + "Overlay multiple figures on the same axes. Useful for combining different plot types." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "# Area chart with line overlay\n", + "area_fig = xpx(sales).area()\n", + "line_fig = xpx(sales).line()\n", + "\n", + "# Update line style to make it visible on top of area\n", + "line_fig.update_traces(line={\"color\": \"black\", \"width\": 1})\n", + "\n", + "combined = overlay_figures(area_fig, line_fig)\n", + "combined.update_layout(title=\"Sales: Area with Line Overlay\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "### Multiple Overlays\n", + "\n", + "You can overlay more than two figures." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "# Three different visualizations of the same data\n", + "area = xpx(sales).area()\n", + "line = xpx(sales).line()\n", + "scatter = xpx(sales).scatter()\n", + "\n", + "# Style them differently\n", + "line.update_traces(line={\"color\": \"black\", \"width\": 1, \"dash\": \"dot\"})\n", + "scatter.update_traces(marker={\"color\": \"black\", \"size\": 4})\n", + "\n", + "combined = overlay_figures(area, line, scatter)\n", + "combined.update_layout(title=\"Sales: Area + Line + Scatter\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "### With Facets\n", + "\n", + "`overlay_figures` works with faceted figures as long as both have the same structure." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "area_faceted = xpx(sales).area(facet_col=\"product\")\n", + "line_faceted = xpx(sales).line(facet_col=\"product\")\n", + "line_faceted.update_traces(line={\"color\": \"black\", \"width\": 2})\n", + "\n", + "combined = overlay_figures(area_faceted, line_faceted)\n", + "combined.update_layout(title=\"Faceted: Area + Line per Product\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "### With Animation\n", + "\n", + "Animation frames are merged correctly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "# Create animated data\n", + "animated_data = xr.DataArray(\n", + " np.random.rand(20, 3, 5).cumsum(axis=0),\n", + " dims=[\"x\", \"category\", \"frame\"],\n", + " coords={\n", + " \"x\": np.arange(20),\n", + " \"category\": [\"A\", \"B\", \"C\"],\n", + " \"frame\": [1, 2, 3, 4, 5],\n", + " },\n", + " name=\"Value\",\n", + ")\n", + "\n", + "area_anim = xpx(animated_data).area(animation_frame=\"frame\")\n", + "line_anim = xpx(animated_data).line(animation_frame=\"frame\")\n", + "line_anim.update_traces(line={\"color\": \"black\", \"width\": 2})\n", + "\n", + "combined = overlay_figures(area_anim, line_anim)\n", + "combined.update_layout(title=\"Animated: Area + Line\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "12", + "metadata": {}, + "source": [ + "## add_secondary_y\n", + "\n", + "Plot two variables with different scales using independent y-axes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "# Temperature on left y-axis, precipitation on right y-axis\n", + "temp_fig = xpx(temperature).line()\n", + "temp_fig.update_traces(line={\"color\": \"red\"})\n", + "\n", + "precip_fig = xpx(precipitation).bar()\n", + "precip_fig.update_traces(marker={\"color\": \"blue\", \"opacity\": 0.6})\n", + "\n", + "combined = add_secondary_y(temp_fig, precip_fig)\n", + "combined.update_layout(title=\"Weather: Temperature & Precipitation\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "14", + "metadata": {}, + "source": [ + "### Custom Y-Axis Title\n", + "\n", + "Use `secondary_y_title` to customize the right y-axis label." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "temp_fig = xpx(temperature).line()\n", + "temp_fig.update_traces(line={\"color\": \"red\", \"width\": 2})\n", + "\n", + "precip_fig = xpx(precipitation).bar()\n", + "precip_fig.update_traces(marker={\"color\": \"steelblue\"})\n", + "\n", + "combined = add_secondary_y(\n", + " temp_fig,\n", + " precip_fig,\n", + " secondary_y_title=\"Rainfall (mm)\",\n", + ")\n", + "combined.update_layout(\n", + " title=\"Weather Data\",\n", + " yaxis_title=\"Temperature (°C)\",\n", + " legend={\"orientation\": \"h\", \"y\": 1.1},\n", + ")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "16", + "metadata": {}, + "source": [ + "### With Animation\n", + "\n", + "`add_secondary_y` also supports animated figures." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "# Create two animated variables with different scales\n", + "var1 = xr.DataArray(\n", + " np.random.rand(20, 5) * 100,\n", + " dims=[\"x\", \"frame\"],\n", + " coords={\"x\": np.arange(20), \"frame\": [1, 2, 3, 4, 5]},\n", + " name=\"Metric A\",\n", + ")\n", + "\n", + "var2 = xr.DataArray(\n", + " np.random.rand(20, 5) * 10,\n", + " dims=[\"x\", \"frame\"],\n", + " coords={\"x\": np.arange(20), \"frame\": [1, 2, 3, 4, 5]},\n", + " name=\"Metric B\",\n", + ")\n", + "\n", + "fig1 = xpx(var1).line(animation_frame=\"frame\")\n", + "fig2 = xpx(var2).bar(animation_frame=\"frame\")\n", + "fig2.update_traces(marker={\"opacity\": 0.5})\n", + "\n", + "combined = add_secondary_y(fig1, fig2)\n", + "combined.update_layout(title=\"Animated Dual Y-Axis\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "18", + "metadata": {}, + "source": [ + "## Limitations\n", + "\n", + "### overlay_figures\n", + "- Overlay must have same or fewer subplots than base\n", + "- Animation frames must match (or overlay must be static)\n", + "\n", + "### add_secondary_y\n", + "- Does not support faceted figures (subplots)\n", + "- Animation frames must match (or secondary must be static)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/test_figures.py b/tests/test_figures.py new file mode 100644 index 0000000..903a2f8 --- /dev/null +++ b/tests/test_figures.py @@ -0,0 +1,552 @@ +"""Tests for the figures module (overlay_figures, add_secondary_y).""" + +from __future__ import annotations + +import copy + +import numpy as np +import plotly.graph_objects as go +import pytest +import xarray as xr + +from xarray_plotly import add_secondary_y, combine_figures, overlay_figures, xpx + + +class TestCombineFiguresBasic: + """Basic tests for combine_figures function.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + """Set up test data.""" + self.da_2d = xr.DataArray( + np.random.rand(10, 3), + dims=["time", "cat"], + coords={"time": np.arange(10), "cat": ["A", "B", "C"]}, + name="value", + ) + + def test_no_overlays_returns_copy(self) -> None: + """Test that no overlays returns a deep copy of base.""" + base = xpx(self.da_2d).line() + result = combine_figures(base) + + assert isinstance(result, go.Figure) + assert len(result.data) == len(base.data) + # Verify it's a copy, not the same object + assert result is not base + assert result.data[0] is not base.data[0] + + def test_combine_two_static_figures(self) -> None: + """Test combining two static figures.""" + area_fig = xpx(self.da_2d).area() + line_fig = xpx(self.da_2d).line() + + combined = combine_figures(area_fig, line_fig) + + assert isinstance(combined, go.Figure) + expected_trace_count = len(area_fig.data) + len(line_fig.data) + assert len(combined.data) == expected_trace_count + + def test_preserves_base_layout(self) -> None: + """Test that base figure's layout is preserved.""" + area_fig = xpx(self.da_2d).area(title="My Area Plot") + line_fig = xpx(self.da_2d).line(title="My Line Plot") + + combined = combine_figures(area_fig, line_fig) + + assert combined.layout.title.text == "My Area Plot" + + def test_multiple_overlays(self) -> None: + """Test combining multiple overlays.""" + area_fig = xpx(self.da_2d).area() + line_fig = xpx(self.da_2d).line() + scatter_fig = xpx(self.da_2d).scatter() + + combined = combine_figures(area_fig, line_fig, scatter_fig) + + expected_count = len(area_fig.data) + len(line_fig.data) + len(scatter_fig.data) + assert len(combined.data) == expected_count + + def test_overlay_traces_added_in_order(self) -> None: + """Test that overlay traces are added after base traces.""" + # Create figures with distinguishable y values + da_1 = xr.DataArray([1, 2, 3], dims=["x"], name="first") + da_2 = xr.DataArray([10, 20, 30], dims=["x"], name="second") + + fig1 = xpx(da_1).line() + fig2 = xpx(da_2).line() + + combined = combine_figures(fig1, fig2) + + # First trace should have y values from fig1 + assert list(combined.data[0].y) == [1, 2, 3] + # Second trace should have y values from fig2 + assert list(combined.data[1].y) == [10, 20, 30] + + +class TestCombineFiguresFacets: + """Tests for combine_figures with faceted figures.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + """Set up test data.""" + self.da_3d = xr.DataArray( + np.random.rand(10, 3, 2), + dims=["time", "cat", "facet"], + coords={ + "time": np.arange(10), + "cat": ["A", "B", "C"], + "facet": ["left", "right"], + }, + name="value", + ) + + def test_matching_facet_structures(self) -> None: + """Test combining figures with matching facet structures.""" + area_fig = xpx(self.da_3d).area(facet_col="facet") + line_fig = xpx(self.da_3d).line(facet_col="facet") + + combined = combine_figures(area_fig, line_fig) + + assert isinstance(combined, go.Figure) + expected_count = len(area_fig.data) + len(line_fig.data) + assert len(combined.data) == expected_count + + def test_overlay_with_extra_subplots_raises(self) -> None: + """Test that overlay with extra subplots raises ValueError.""" + # Base without facets + base = xpx(self.da_3d.isel(facet=0)).line() + # Overlay with facets + overlay = xpx(self.da_3d).line(facet_col="facet") + + with pytest.raises(ValueError, match="subplots not present in base"): + combine_figures(base, overlay) + + def test_preserves_axis_references(self) -> None: + """Test that traces preserve their xaxis/yaxis references.""" + area_fig = xpx(self.da_3d).area(facet_col="facet") + line_fig = xpx(self.da_3d).line(facet_col="facet") + + combined = combine_figures(area_fig, line_fig) + + # Collect axis references from both original and combined + original_axes = set() + for trace in area_fig.data: + xaxis = getattr(trace, "xaxis", None) or "x" + yaxis = getattr(trace, "yaxis", None) or "y" + original_axes.add((xaxis, yaxis)) + + combined_axes = set() + for trace in combined.data: + xaxis = getattr(trace, "xaxis", None) or "x" + yaxis = getattr(trace, "yaxis", None) or "y" + combined_axes.add((xaxis, yaxis)) + + # Combined should have same axis structure + assert combined_axes == original_axes + + +class TestCombineFiguresAnimation: + """Tests for combine_figures with animated figures.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + """Set up test data.""" + self.da_3d = xr.DataArray( + np.random.rand(10, 3, 4), + dims=["x", "cat", "time"], + coords={ + "x": np.arange(10), + "cat": ["A", "B", "C"], + "time": [0, 1, 2, 3], + }, + name="value", + ) + + def test_matching_frames_merged(self) -> None: + """Test that matching animation frames are merged correctly.""" + area_fig = xpx(self.da_3d).area(animation_frame="time") + line_fig = xpx(self.da_3d).line(animation_frame="time") + + combined = combine_figures(area_fig, line_fig) + + assert isinstance(combined, go.Figure) + # Should have same number of frames + assert len(combined.frames) == len(area_fig.frames) + # Each frame should have more data + for i, frame in enumerate(combined.frames): + expected_data = len(area_fig.frames[i].data) + len(line_fig.frames[i].data) + assert len(frame.data) == expected_data + + def test_static_overlay_replicated_to_frames(self) -> None: + """Test that static overlay is replicated to all animation frames.""" + animated = xpx(self.da_3d).area(animation_frame="time") + static = xpx(self.da_3d.isel(time=0)).line() + + combined = combine_figures(animated, static) + + # Combined should have all frames from animated figure + assert len(combined.frames) == len(animated.frames) + + # Each frame should include the static traces + for frame in combined.frames: + # Frame data should include both animated and static traces + expected_count = len(animated.frames[0].data) + len(static.data) + assert len(frame.data) == expected_count + + def test_animated_overlay_on_static_base_raises(self) -> None: + """Test that animated overlay on static base raises ValueError.""" + static = xpx(self.da_3d.isel(time=0)).line() + animated = xpx(self.da_3d).area(animation_frame="time") + + with pytest.raises(ValueError, match="base figure does not"): + combine_figures(static, animated) + + def test_mismatched_frame_names_raises(self) -> None: + """Test that mismatched frame names raise ValueError.""" + da1 = xr.DataArray( + np.random.rand(10, 3), + dims=["x", "time"], + coords={"x": np.arange(10), "time": [0, 1, 2]}, + ) + da2 = xr.DataArray( + np.random.rand(10, 4), + dims=["x", "time"], + coords={"x": np.arange(10), "time": [0, 1, 2, 3]}, + ) + + fig1 = xpx(da1).line(animation_frame="time") + fig2 = xpx(da2).line(animation_frame="time") + + with pytest.raises(ValueError, match="frame names don't match"): + combine_figures(fig1, fig2) + + def test_frame_names_preserved(self) -> None: + """Test that frame names are preserved in combined figure.""" + area_fig = xpx(self.da_3d).area(animation_frame="time") + line_fig = xpx(self.da_3d).line(animation_frame="time") + + combined = combine_figures(area_fig, line_fig) + + original_names = {frame.name for frame in area_fig.frames} + combined_names = {frame.name for frame in combined.frames} + assert original_names == combined_names + + +class TestCombineFiguresFacetsAndAnimation: + """Tests for combine_figures with both facets and animation.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + """Set up test data.""" + self.da_4d = xr.DataArray( + np.random.rand(10, 3, 2, 4), + dims=["x", "cat", "facet", "time"], + coords={ + "x": np.arange(10), + "cat": ["A", "B", "C"], + "facet": ["left", "right"], + "time": [0, 1, 2, 3], + }, + name="value", + ) + + def test_facets_and_animation_combined(self) -> None: + """Test combining figures with both facets and animation.""" + area_fig = xpx(self.da_4d).area(facet_col="facet", animation_frame="time") + line_fig = xpx(self.da_4d).line(facet_col="facet", animation_frame="time") + + combined = combine_figures(area_fig, line_fig) + + assert isinstance(combined, go.Figure) + # Check trace count + expected_traces = len(area_fig.data) + len(line_fig.data) + assert len(combined.data) == expected_traces + # Check frame count + assert len(combined.frames) == len(area_fig.frames) + + def test_static_overlay_on_animated_faceted_base(self) -> None: + """Test static overlay replicated on animated faceted base.""" + animated = xpx(self.da_4d).area(facet_col="facet", animation_frame="time") + static = xpx(self.da_4d.isel(time=0)).line(facet_col="facet") + + combined = combine_figures(animated, static) + + # Should have same frames as animated + assert len(combined.frames) == len(animated.frames) + # Each frame should have combined trace count + for frame in combined.frames: + expected = len(animated.frames[0].data) + len(static.data) + assert len(frame.data) == expected + + +class TestCombineFiguresDeepCopy: + """Tests to ensure combine_figures creates deep copies.""" + + def test_base_not_modified(self) -> None: + """Test that base figure is not modified.""" + da = xr.DataArray(np.random.rand(10, 3), dims=["x", "cat"]) + base = xpx(da).area() + original_trace_count = len(base.data) + original_title = copy.deepcopy(base.layout.title) + + overlay = xpx(da).line() + _ = combine_figures(base, overlay) + + # Base should be unchanged + assert len(base.data) == original_trace_count + assert base.layout.title == original_title + + def test_overlay_not_modified(self) -> None: + """Test that overlay figure is not modified.""" + da = xr.DataArray(np.random.rand(10, 3), dims=["x", "cat"]) + base = xpx(da).area() + overlay = xpx(da).line() + original_trace_count = len(overlay.data) + + _ = combine_figures(base, overlay) + + # Overlay should be unchanged + assert len(overlay.data) == original_trace_count + + def test_combined_traces_independent(self) -> None: + """Test that combined traces are independent of originals.""" + da = xr.DataArray(np.random.rand(10, 3), dims=["x", "cat"]) + base = xpx(da).area() + overlay = xpx(da).line() + + combined = combine_figures(base, overlay) + + # Modify combined figure + combined.data[0].name = "modified" + + # Originals should be unchanged + assert base.data[0].name != "modified" + + +class TestOverlayFiguresAlias: + """Test that overlay_figures and combine_figures are equivalent.""" + + def test_overlay_figures_is_combine_figures(self) -> None: + """Test that overlay_figures is the same function as combine_figures.""" + assert overlay_figures is combine_figures + + def test_overlay_figures_works(self) -> None: + """Test that overlay_figures works correctly.""" + da = xr.DataArray(np.random.rand(10, 3), dims=["x", "cat"]) + area_fig = xpx(da).area() + line_fig = xpx(da).line() + + combined = overlay_figures(area_fig, line_fig) + + assert isinstance(combined, go.Figure) + expected_count = len(area_fig.data) + len(line_fig.data) + assert len(combined.data) == expected_count + + +class TestAddSecondaryYBasic: + """Basic tests for add_secondary_y function.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + """Set up test data.""" + self.temp = xr.DataArray( + [20, 22, 25, 23, 21], + dims=["time"], + coords={"time": [0, 1, 2, 3, 4]}, + name="Temperature", + ) + self.precip = xr.DataArray( + [0, 5, 12, 2, 8], + dims=["time"], + coords={"time": [0, 1, 2, 3, 4]}, + name="Precipitation", + ) + + def test_creates_secondary_y_axis(self) -> None: + """Test that secondary y-axis is created.""" + temp_fig = xpx(self.temp).line() + precip_fig = xpx(self.precip).bar() + + combined = add_secondary_y(temp_fig, precip_fig) + + assert isinstance(combined, go.Figure) + assert combined.layout.yaxis2 is not None + assert combined.layout.yaxis2.side == "right" + assert combined.layout.yaxis2.overlaying == "y" + + def test_secondary_traces_use_y2(self) -> None: + """Test that secondary figure traces are assigned to y2.""" + temp_fig = xpx(self.temp).line() + precip_fig = xpx(self.precip).bar() + + combined = add_secondary_y(temp_fig, precip_fig) + + # First trace (from temp_fig) should use default y + assert combined.data[0].yaxis is None or combined.data[0].yaxis == "y" + # Second trace (from precip_fig) should use y2 + assert combined.data[1].yaxis == "y2" + + def test_preserves_base_layout(self) -> None: + """Test that base figure's layout is preserved.""" + temp_fig = xpx(self.temp).line(title="My Temperature Plot") + precip_fig = xpx(self.precip).bar() + + combined = add_secondary_y(temp_fig, precip_fig) + + assert combined.layout.title.text == "My Temperature Plot" + + def test_total_trace_count(self) -> None: + """Test that all traces from both figures are included.""" + temp_fig = xpx(self.temp).line() + precip_fig = xpx(self.precip).bar() + + combined = add_secondary_y(temp_fig, precip_fig) + + expected_count = len(temp_fig.data) + len(precip_fig.data) + assert len(combined.data) == expected_count + + def test_secondary_y_title_from_secondary_figure(self) -> None: + """Test that secondary y-axis title comes from secondary figure.""" + temp_fig = xpx(self.temp).line() + precip_fig = xpx(self.precip).bar() + # Plotly Express sets y-axis title based on the data + + combined = add_secondary_y(temp_fig, precip_fig) + + # The secondary y-axis title should be set + assert combined.layout.yaxis2.title is not None + + def test_custom_secondary_y_title(self) -> None: + """Test that custom secondary y-axis title can be provided.""" + temp_fig = xpx(self.temp).line() + precip_fig = xpx(self.precip).bar() + + combined = add_secondary_y(temp_fig, precip_fig, secondary_y_title="Rain (mm)") + + assert combined.layout.yaxis2.title.text == "Rain (mm)" + + +class TestAddSecondaryYFacetsError: + """Tests for add_secondary_y error handling with facets.""" + + def test_base_with_facets_raises(self) -> None: + """Test that base figure with facets raises ValueError.""" + da = xr.DataArray( + np.random.rand(10, 2), + dims=["time", "facet"], + coords={"time": np.arange(10), "facet": ["A", "B"]}, + ) + base = xpx(da).line(facet_col="facet") + secondary = xpx(da.isel(facet=0)).bar() + + with pytest.raises(ValueError, match="Base figure has facets"): + add_secondary_y(base, secondary) + + def test_secondary_with_facets_raises(self) -> None: + """Test that secondary figure with facets raises ValueError.""" + da = xr.DataArray( + np.random.rand(10, 2), + dims=["time", "facet"], + coords={"time": np.arange(10), "facet": ["A", "B"]}, + ) + base = xpx(da.isel(facet=0)).line() + secondary = xpx(da).bar(facet_col="facet") + + with pytest.raises(ValueError, match="Secondary figure has facets"): + add_secondary_y(base, secondary) + + +class TestAddSecondaryYAnimation: + """Tests for add_secondary_y with animated figures.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + """Set up test data.""" + self.da_2d = xr.DataArray( + np.random.rand(10, 4), + dims=["x", "time"], + coords={"x": np.arange(10), "time": [0, 1, 2, 3]}, + name="value", + ) + + def test_matching_animation_frames(self) -> None: + """Test add_secondary_y with matching animation frames.""" + fig1 = xpx(self.da_2d).line(animation_frame="time") + fig2 = xpx(self.da_2d).bar(animation_frame="time") + + combined = add_secondary_y(fig1, fig2) + + assert len(combined.frames) == len(fig1.frames) + # Verify frame names match + for orig, comb in zip(fig1.frames, combined.frames, strict=False): + assert orig.name == comb.name + + def test_static_secondary_on_animated_base(self) -> None: + """Test static secondary replicated to all animation frames.""" + animated = xpx(self.da_2d).line(animation_frame="time") + static = xpx(self.da_2d.isel(time=0)).bar() + + combined = add_secondary_y(animated, static) + + assert len(combined.frames) == len(animated.frames) + # Each frame should have traces from both figures + for frame in combined.frames: + expected = len(animated.frames[0].data) + len(static.data) + assert len(frame.data) == expected + + def test_animated_secondary_on_static_base_raises(self) -> None: + """Test that animated secondary on static base raises ValueError.""" + static = xpx(self.da_2d.isel(time=0)).line() + animated = xpx(self.da_2d).bar(animation_frame="time") + + with pytest.raises(ValueError, match="base figure does not"): + add_secondary_y(static, animated) + + def test_mismatched_animation_frames_raises(self) -> None: + """Test that mismatched animation frames raise ValueError.""" + da1 = xr.DataArray( + np.random.rand(10, 3), + dims=["x", "time"], + coords={"x": np.arange(10), "time": [0, 1, 2]}, + ) + da2 = xr.DataArray( + np.random.rand(10, 4), + dims=["x", "time"], + coords={"x": np.arange(10), "time": [0, 1, 2, 3]}, + ) + + fig1 = xpx(da1).line(animation_frame="time") + fig2 = xpx(da2).bar(animation_frame="time") + + with pytest.raises(ValueError, match="frame names don't match"): + add_secondary_y(fig1, fig2) + + +class TestAddSecondaryYDeepCopy: + """Tests to ensure add_secondary_y creates deep copies.""" + + def test_base_not_modified(self) -> None: + """Test that base figure is not modified.""" + da = xr.DataArray([1, 2, 3, 4, 5], dims=["x"]) + base = xpx(da).line() + original_trace_count = len(base.data) + + secondary = xpx(da).bar() + _ = add_secondary_y(base, secondary) + + assert len(base.data) == original_trace_count + # Base should not have yaxis2 (check via to_plotly_json) + assert "yaxis2" not in base.layout.to_plotly_json() + + def test_secondary_not_modified(self) -> None: + """Test that secondary figure is not modified.""" + da = xr.DataArray([1, 2, 3, 4, 5], dims=["x"]) + base = xpx(da).line() + secondary = xpx(da).bar() + original_yaxis = secondary.data[0].yaxis + + _ = add_secondary_y(base, secondary) + + # Secondary traces should still use original yaxis + assert secondary.data[0].yaxis == original_yaxis diff --git a/xarray_plotly/__init__.py b/xarray_plotly/__init__.py index 7bc9539..cf254eb 100644 --- a/xarray_plotly/__init__.py +++ b/xarray_plotly/__init__.py @@ -53,13 +53,17 @@ from xarray_plotly import config from xarray_plotly.accessor import DataArrayPlotlyAccessor, DatasetPlotlyAccessor from xarray_plotly.common import SLOT_ORDERS, auto +from xarray_plotly.figures import add_secondary_y, combine_figures, overlay_figures __all__ = [ "SLOT_ORDERS", "DataArrayPlotlyAccessor", "DatasetPlotlyAccessor", + "add_secondary_y", "auto", + "combine_figures", "config", + "overlay_figures", "xpx", ] diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py new file mode 100644 index 0000000..f3edfc4 --- /dev/null +++ b/xarray_plotly/figures.py @@ -0,0 +1,341 @@ +""" +Helper functions for combining and manipulating Plotly figures. +""" + +from __future__ import annotations + +import copy +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import plotly.graph_objects as go + + +def _get_subplot_axes(fig: go.Figure) -> set[tuple[str, str]]: + """Extract (xaxis, yaxis) pairs from figure traces. + + Args: + fig: A Plotly figure. + + Returns: + Set of (xaxis, yaxis) tuples, e.g., {('x', 'y'), ('x2', 'y2')}. + """ + axes_pairs = set() + for trace in fig.data: + xaxis = getattr(trace, "xaxis", None) or "x" + yaxis = getattr(trace, "yaxis", None) or "y" + axes_pairs.add((xaxis, yaxis)) + return axes_pairs + + +def _validate_compatible_structure(base: go.Figure, overlay: go.Figure) -> None: + """Validate that overlay's subplot structure is compatible with base. + + Args: + base: The base figure. + overlay: The overlay figure to check. + + Raises: + ValueError: If overlay has subplots not present in base. + """ + base_axes = _get_subplot_axes(base) + overlay_axes = _get_subplot_axes(overlay) + + extra_axes = overlay_axes - base_axes + if extra_axes: + raise ValueError( + f"Overlay figure has subplots not present in base figure: {extra_axes}. " + "Ensure both figures have the same facet structure." + ) + + +def _validate_animation_compatibility(base: go.Figure, overlay: go.Figure) -> None: + """Validate animation frame compatibility between base and overlay. + + Args: + base: The base figure. + overlay: The overlay figure to check. + + Raises: + ValueError: If overlay has animation but base doesn't, or frame names don't match. + """ + base_has_frames = bool(base.frames) + overlay_has_frames = bool(overlay.frames) + + if overlay_has_frames and not base_has_frames: + raise ValueError( + "Overlay figure has animation frames but base figure does not. " + "Cannot add animated overlay to static base figure." + ) + + if base_has_frames and overlay_has_frames: + base_frame_names = {frame.name for frame in base.frames} + overlay_frame_names = {frame.name for frame in overlay.frames} + + if base_frame_names != overlay_frame_names: + missing_in_overlay = base_frame_names - overlay_frame_names + extra_in_overlay = overlay_frame_names - base_frame_names + msg = "Animation frame names don't match between base and overlay." + if missing_in_overlay: + msg += f" Missing in overlay: {missing_in_overlay}." + if extra_in_overlay: + msg += f" Extra in overlay: {extra_in_overlay}." + raise ValueError(msg) + + +def _merge_frames( + base: go.Figure, + overlays: list[go.Figure], + base_trace_count: int, + overlay_trace_counts: list[int], +) -> list: + """Merge animation frames from base and overlay figures. + + Args: + base: The base figure with animation frames. + overlays: List of overlay figures (may or may not have frames). + base_trace_count: Number of traces in the base figure. + overlay_trace_counts: Number of traces in each overlay figure. + + Returns: + List of merged frames. + """ + import plotly.graph_objects as go + + merged_frames = [] + + for base_frame in base.frames: + frame_name = base_frame.name + merged_data = list(base_frame.data) + + for overlay, _overlay_trace_count in zip(overlays, overlay_trace_counts, strict=False): + if overlay.frames: + # Find matching frame in overlay + overlay_frame = next((f for f in overlay.frames if f.name == frame_name), None) + if overlay_frame: + merged_data.extend(overlay_frame.data) + else: + # Static overlay: replicate traces to this frame + merged_data.extend(overlay.data) + + merged_frames.append( + go.Frame( + data=merged_data, + name=frame_name, + traces=list(range(base_trace_count + sum(overlay_trace_counts))), + ) + ) + + return merged_frames + + +def overlay_figures(base: go.Figure, *overlays: go.Figure) -> go.Figure: + """Overlay multiple Plotly figures on the same axes. + + Creates a new figure with the base figure's layout, sliders, and buttons, + with all overlay traces added on top. Correctly handles faceted figures + and animation frames. + + Args: + base: The base figure whose layout is preserved. + *overlays: One or more figures to overlay on the base. + + Returns: + A new combined figure. + + Raises: + ValueError: If overlay has subplots not in base, animation frames don't match, + or overlay has animation but base doesn't. + + Example: + >>> import numpy as np + >>> import xarray as xr + >>> from xarray_plotly import xpx, overlay_figures + >>> + >>> da = xr.DataArray(np.random.rand(10, 3), dims=["time", "cat"]) + >>> area_fig = xpx(da).area() + >>> line_fig = xpx(da).line() + >>> combined = overlay_figures(area_fig, line_fig) + >>> + >>> # With animation + >>> da3d = xr.DataArray(np.random.rand(10, 3, 4), dims=["x", "cat", "time"]) + >>> area = xpx(da3d).area(animation_frame="time") + >>> line = xpx(da3d).line(animation_frame="time") + >>> combined = overlay_figures(area, line) + """ + import plotly.graph_objects as go + + if not overlays: + # No overlays: return a deep copy of base + return copy.deepcopy(base) + + # Validate all overlays + for overlay in overlays: + _validate_compatible_structure(base, overlay) + _validate_animation_compatibility(base, overlay) + + # Create new figure with base's layout + combined = go.Figure(layout=copy.deepcopy(base.layout)) + + # Add all traces from base + for trace in base.data: + combined.add_trace(copy.deepcopy(trace)) + + # Add all traces from overlays + for overlay in overlays: + for trace in overlay.data: + combined.add_trace(copy.deepcopy(trace)) + + # Handle animation frames + if base.frames: + base_trace_count = len(base.data) + overlay_trace_counts = [len(overlay.data) for overlay in overlays] + merged_frames = _merge_frames(base, list(overlays), base_trace_count, overlay_trace_counts) + combined.frames = merged_frames + + return combined + + +# Backwards compatibility alias +combine_figures = overlay_figures + + +def add_secondary_y( + base: go.Figure, + secondary: go.Figure, + *, + secondary_y_title: str | None = None, +) -> go.Figure: + """Add a secondary y-axis with traces from another figure. + + Creates a new figure with the base figure's layout and a secondary y-axis + on the right side. All traces from the secondary figure are plotted against + the secondary y-axis. + + Args: + base: The base figure (left y-axis). + secondary: The figure whose traces use the secondary y-axis (right). + secondary_y_title: Optional title for the secondary y-axis. + If not provided, uses the secondary figure's y-axis title. + + Returns: + A new figure with both primary and secondary y-axes. + + Raises: + ValueError: If either figure has facets (subplots), or if animation + frames don't match. + + Example: + >>> import numpy as np + >>> import xarray as xr + >>> from xarray_plotly import xpx, add_secondary_y + >>> + >>> # Two variables with different scales + >>> temp = xr.DataArray([20, 22, 25, 23], dims=["time"], name="Temperature (°C)") + >>> precip = xr.DataArray([0, 5, 12, 2], dims=["time"], name="Precipitation (mm)") + >>> + >>> temp_fig = xpx(temp).line() + >>> precip_fig = xpx(precip).bar() + >>> combined = add_secondary_y(temp_fig, precip_fig) + """ + import plotly.graph_objects as go + + # Check for facets - not supported with secondary y + base_axes = _get_subplot_axes(base) + secondary_axes = _get_subplot_axes(secondary) + + if len(base_axes) > 1 or base_axes != {("x", "y")}: + raise ValueError( + "Base figure has facets (subplots). Secondary y-axis is not supported " + "with faceted figures." + ) + if len(secondary_axes) > 1 or secondary_axes != {("x", "y")}: + raise ValueError( + "Secondary figure has facets (subplots). Secondary y-axis is not supported " + "with faceted figures." + ) + + # Validate animation compatibility + _validate_animation_compatibility(base, secondary) + + # Create new figure with base's layout + combined = go.Figure(layout=copy.deepcopy(base.layout)) + + # Add all traces from base (primary y-axis) + for trace in base.data: + combined.add_trace(copy.deepcopy(trace)) + + # Add all traces from secondary, assigned to y2 + for trace in secondary.data: + trace_copy = copy.deepcopy(trace) + trace_copy.yaxis = "y2" + combined.add_trace(trace_copy) + + # Configure secondary y-axis + y2_title = secondary_y_title + if y2_title is None and secondary.layout.yaxis and secondary.layout.yaxis.title: + y2_title = secondary.layout.yaxis.title.text + + combined.update_layout( + yaxis2={ + "title": y2_title, + "overlaying": "y", + "side": "right", + }, + ) + + # Handle animation frames + if base.frames: + merged_frames = _merge_secondary_y_frames(base, secondary) + combined.frames = merged_frames + + return combined + + +def _merge_secondary_y_frames(base: go.Figure, secondary: go.Figure) -> list: + """Merge animation frames for secondary y-axis combination. + + Args: + base: The base figure with animation frames. + secondary: The secondary figure (may or may not have frames). + + Returns: + List of merged frames with secondary traces assigned to y2. + """ + import plotly.graph_objects as go + + merged_frames = [] + base_trace_count = len(base.data) + secondary_trace_count = len(secondary.data) + + for base_frame in base.frames: + frame_name = base_frame.name + merged_data = list(base_frame.data) + + if secondary.frames: + # Find matching frame in secondary + secondary_frame = next((f for f in secondary.frames if f.name == frame_name), None) + if secondary_frame: + # Add secondary frame data with y2 assignment + for trace_data in secondary_frame.data: + trace_copy = copy.deepcopy(trace_data) + if hasattr(trace_copy, "yaxis"): + trace_copy.yaxis = "y2" + merged_data.append(trace_copy) + else: + # Static secondary: replicate traces to this frame + for trace in secondary.data: + trace_copy = copy.deepcopy(trace) + if hasattr(trace_copy, "yaxis"): + trace_copy.yaxis = "y2" + merged_data.append(trace_copy) + + merged_frames.append( + go.Frame( + data=merged_data, + name=frame_name, + traces=list(range(base_trace_count + secondary_trace_count)), + ) + ) + + return merged_frames From 6c27ef6ee5359afe7984a53317e541f8778d554b Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Wed, 21 Jan 2026 17:16:27 +0100 Subject: [PATCH 02/10] Update notebook --- docs/examples/combining.ipynb | 275 ++++++++++++++++++---------------- 1 file changed, 145 insertions(+), 130 deletions(-) diff --git a/docs/examples/combining.ipynb b/docs/examples/combining.ipynb index 1096b56..1b8dff6 100644 --- a/docs/examples/combining.ipynb +++ b/docs/examples/combining.ipynb @@ -20,7 +20,7 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", + "import plotly.express as px\n", "import xarray as xr\n", "\n", "from xarray_plotly import add_secondary_y, config, overlay_figures, xpx\n", @@ -33,7 +33,7 @@ "id": "2", "metadata": {}, "source": [ - "## Sample Data" + "## Load Sample Data" ] }, { @@ -43,32 +43,46 @@ "metadata": {}, "outputs": [], "source": [ - "# Time series with categories\n", - "np.random.seed(42)\n", - "time = np.arange(50)\n", - "\n", - "sales = xr.DataArray(\n", - " np.cumsum(np.random.randn(50, 3), axis=0) + 100,\n", - " dims=[\"day\", \"product\"],\n", - " coords={\"day\": time, \"product\": [\"Widget\", \"Gadget\", \"Gizmo\"]},\n", - " name=\"Sales\",\n", + "# Stock prices\n", + "df_stocks = px.data.stocks().set_index(\"date\")\n", + "df_stocks.index = df_stocks.index.astype(\"datetime64[ns]\")\n", + "\n", + "stocks = xr.DataArray(\n", + " df_stocks.values,\n", + " dims=[\"date\", \"company\"],\n", + " coords={\"date\": df_stocks.index, \"company\": df_stocks.columns.tolist()},\n", + " name=\"price\",\n", + ")\n", + "\n", + "# Gapminder data (subset: a few countries)\n", + "df_gap = px.data.gapminder()\n", + "countries = [\"United States\", \"China\", \"Germany\", \"Brazil\"]\n", + "df_gap = df_gap[df_gap[\"country\"].isin(countries)]\n", + "\n", + "# Convert to xarray\n", + "gap_pop = df_gap.pivot(index=\"year\", columns=\"country\", values=\"pop\")\n", + "gap_gdp = df_gap.pivot(index=\"year\", columns=\"country\", values=\"gdpPercap\")\n", + "gap_life = df_gap.pivot(index=\"year\", columns=\"country\", values=\"lifeExp\")\n", + "\n", + "population = xr.DataArray(\n", + " gap_pop.values,\n", + " dims=[\"year\", \"country\"],\n", + " coords={\"year\": gap_pop.index.values, \"country\": gap_pop.columns.tolist()},\n", + " name=\"Population\",\n", ")\n", "\n", - "# Two variables with different scales\n", - "temperature = xr.DataArray(\n", - " 20 + 10 * np.sin(time / 10) + np.random.randn(50),\n", - " dims=[\"day\"],\n", - " coords={\"day\": time},\n", - " name=\"Temperature\",\n", - " attrs={\"units\": \"°C\"},\n", + "gdp_per_capita = xr.DataArray(\n", + " gap_gdp.values,\n", + " dims=[\"year\", \"country\"],\n", + " coords={\"year\": gap_gdp.index.values, \"country\": gap_gdp.columns.tolist()},\n", + " name=\"GDP per Capita\",\n", ")\n", "\n", - "precipitation = xr.DataArray(\n", - " np.maximum(0, 5 + 10 * np.random.randn(50)),\n", - " dims=[\"day\"],\n", - " coords={\"day\": time},\n", - " name=\"Precipitation\",\n", - " attrs={\"units\": \"mm\"},\n", + "life_expectancy = xr.DataArray(\n", + " gap_life.values,\n", + " dims=[\"year\", \"country\"],\n", + " coords={\"year\": gap_life.index.values, \"country\": gap_life.columns.tolist()},\n", + " name=\"Life Expectancy\",\n", ")" ] }, @@ -79,62 +93,79 @@ "source": [ "## overlay_figures\n", "\n", - "Overlay multiple figures on the same axes. Useful for combining different plot types." + "Overlay multiple figures on the same axes. Useful for showing data with a trend line, moving average, or different visualizations of related data." + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "### Stock Price with Moving Average" ] }, { "cell_type": "code", "execution_count": null, - "id": "5", + "id": "6", "metadata": {}, "outputs": [], "source": [ - "# Area chart with line overlay\n", - "area_fig = xpx(sales).area()\n", - "line_fig = xpx(sales).line()\n", + "# Select one company\n", + "goog = stocks.sel(company=\"GOOG\")\n", + "\n", + "# Calculate 20-day moving average\n", + "goog_ma = goog.rolling(date=20, center=True).mean()\n", + "goog_ma.name = \"20-day MA\"\n", + "\n", + "# Raw prices as scatter\n", + "price_fig = xpx(goog).scatter()\n", + "price_fig.update_traces(marker={\"size\": 4, \"opacity\": 0.5}, name=\"Daily Price\")\n", "\n", - "# Update line style to make it visible on top of area\n", - "line_fig.update_traces(line={\"color\": \"black\", \"width\": 1})\n", + "# Moving average as line\n", + "ma_fig = xpx(goog_ma).line()\n", + "ma_fig.update_traces(line={\"color\": \"red\", \"width\": 3}, name=\"20-day MA\")\n", "\n", - "combined = overlay_figures(area_fig, line_fig)\n", - "combined.update_layout(title=\"Sales: Area with Line Overlay\")\n", + "combined = overlay_figures(price_fig, ma_fig)\n", + "combined.update_layout(title=\"GOOG: Daily Price with Moving Average\")\n", "combined" ] }, { "cell_type": "markdown", - "id": "6", + "id": "7", "metadata": {}, "source": [ - "### Multiple Overlays\n", - "\n", - "You can overlay more than two figures." + "### Multiple Companies with Moving Averages" ] }, { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "8", "metadata": {}, "outputs": [], "source": [ - "# Three different visualizations of the same data\n", - "area = xpx(sales).area()\n", - "line = xpx(sales).line()\n", - "scatter = xpx(sales).scatter()\n", + "# Select a few companies\n", + "subset = stocks.sel(company=[\"GOOG\", \"AAPL\", \"MSFT\"])\n", + "subset_ma = subset.rolling(date=20, center=True).mean()\n", + "\n", + "# Raw as scatter (faded)\n", + "raw_fig = xpx(subset).scatter()\n", + "raw_fig.update_traces(marker={\"size\": 3, \"opacity\": 0.3})\n", "\n", - "# Style them differently\n", - "line.update_traces(line={\"color\": \"black\", \"width\": 1, \"dash\": \"dot\"})\n", - "scatter.update_traces(marker={\"color\": \"black\", \"size\": 4})\n", + "# MA as lines (bold)\n", + "ma_fig = xpx(subset_ma).line()\n", + "ma_fig.update_traces(line={\"width\": 3})\n", "\n", - "combined = overlay_figures(area, line, scatter)\n", - "combined.update_layout(title=\"Sales: Area + Line + Scatter\")\n", + "combined = overlay_figures(raw_fig, ma_fig)\n", + "combined.update_layout(title=\"Tech Stocks: Raw Prices + Moving Averages\")\n", "combined" ] }, { "cell_type": "markdown", - "id": "8", + "id": "9", "metadata": {}, "source": [ "### With Facets\n", @@ -145,55 +176,30 @@ { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "10", "metadata": {}, "outputs": [], "source": [ - "area_faceted = xpx(sales).area(facet_col=\"product\")\n", - "line_faceted = xpx(sales).line(facet_col=\"product\")\n", - "line_faceted.update_traces(line={\"color\": \"black\", \"width\": 2})\n", + "# Faceted by company\n", + "raw_faceted = xpx(subset).scatter(facet_col=\"company\")\n", + "raw_faceted.update_traces(marker={\"size\": 3, \"opacity\": 0.4})\n", "\n", - "combined = overlay_figures(area_faceted, line_faceted)\n", - "combined.update_layout(title=\"Faceted: Area + Line per Product\")\n", + "ma_faceted = xpx(subset_ma).line(facet_col=\"company\")\n", + "ma_faceted.update_traces(line={\"color\": \"red\", \"width\": 2})\n", + "\n", + "combined = overlay_figures(raw_faceted, ma_faceted)\n", + "combined.update_layout(title=\"Faceted: Price + Moving Average per Company\")\n", "combined" ] }, { "cell_type": "markdown", - "id": "10", - "metadata": {}, - "source": [ - "### With Animation\n", - "\n", - "Animation frames are merged correctly." - ] - }, - { - "cell_type": "code", - "execution_count": null, "id": "11", "metadata": {}, - "outputs": [], "source": [ - "# Create animated data\n", - "animated_data = xr.DataArray(\n", - " np.random.rand(20, 3, 5).cumsum(axis=0),\n", - " dims=[\"x\", \"category\", \"frame\"],\n", - " coords={\n", - " \"x\": np.arange(20),\n", - " \"category\": [\"A\", \"B\", \"C\"],\n", - " \"frame\": [1, 2, 3, 4, 5],\n", - " },\n", - " name=\"Value\",\n", - ")\n", - "\n", - "area_anim = xpx(animated_data).area(animation_frame=\"frame\")\n", - "line_anim = xpx(animated_data).line(animation_frame=\"frame\")\n", - "line_anim.update_traces(line={\"color\": \"black\", \"width\": 2})\n", + "## add_secondary_y\n", "\n", - "combined = overlay_figures(area_anim, line_anim)\n", - "combined.update_layout(title=\"Animated: Area + Line\")\n", - "combined" + "Plot two variables with different scales using independent y-axes. Essential when values have different magnitudes (e.g., population in millions vs GDP in thousands)." ] }, { @@ -201,9 +207,7 @@ "id": "12", "metadata": {}, "source": [ - "## add_secondary_y\n", - "\n", - "Plot two variables with different scales using independent y-axes." + "### Population vs GDP per Capita" ] }, { @@ -213,15 +217,21 @@ "metadata": {}, "outputs": [], "source": [ - "# Temperature on left y-axis, precipitation on right y-axis\n", - "temp_fig = xpx(temperature).line()\n", - "temp_fig.update_traces(line={\"color\": \"red\"})\n", + "# Select one country\n", + "us_pop = population.sel(country=\"United States\")\n", + "us_gdp = gdp_per_capita.sel(country=\"United States\")\n", + "\n", + "pop_fig = xpx(us_pop).bar()\n", + "pop_fig.update_traces(marker={\"color\": \"steelblue\", \"opacity\": 0.7}, name=\"Population\")\n", "\n", - "precip_fig = xpx(precipitation).bar()\n", - "precip_fig.update_traces(marker={\"color\": \"blue\", \"opacity\": 0.6})\n", + "gdp_fig = xpx(us_gdp).line()\n", + "gdp_fig.update_traces(line={\"color\": \"red\", \"width\": 3}, name=\"GDP per Capita\")\n", "\n", - "combined = add_secondary_y(temp_fig, precip_fig)\n", - "combined.update_layout(title=\"Weather: Temperature & Precipitation\")\n", + "combined = add_secondary_y(pop_fig, gdp_fig, secondary_y_title=\"GDP per Capita ($)\")\n", + "combined.update_layout(\n", + " title=\"United States: Population vs GDP per Capita\",\n", + " yaxis_title=\"Population\",\n", + ")\n", "combined" ] }, @@ -230,9 +240,7 @@ "id": "14", "metadata": {}, "source": [ - "### Custom Y-Axis Title\n", - "\n", - "Use `secondary_y_title` to customize the right y-axis label." + "### Life Expectancy vs GDP" ] }, { @@ -242,20 +250,19 @@ "metadata": {}, "outputs": [], "source": [ - "temp_fig = xpx(temperature).line()\n", - "temp_fig.update_traces(line={\"color\": \"red\", \"width\": 2})\n", + "china_life = life_expectancy.sel(country=\"China\")\n", + "china_gdp = gdp_per_capita.sel(country=\"China\")\n", "\n", - "precip_fig = xpx(precipitation).bar()\n", - "precip_fig.update_traces(marker={\"color\": \"steelblue\"})\n", + "life_fig = xpx(china_life).line()\n", + "life_fig.update_traces(line={\"color\": \"green\", \"width\": 3}, name=\"Life Expectancy\")\n", "\n", - "combined = add_secondary_y(\n", - " temp_fig,\n", - " precip_fig,\n", - " secondary_y_title=\"Rainfall (mm)\",\n", - ")\n", + "gdp_fig = xpx(china_gdp).line()\n", + "gdp_fig.update_traces(line={\"color\": \"orange\", \"width\": 3, \"dash\": \"dash\"}, name=\"GDP per Capita\")\n", + "\n", + "combined = add_secondary_y(life_fig, gdp_fig, secondary_y_title=\"GDP per Capita ($)\")\n", "combined.update_layout(\n", - " title=\"Weather Data\",\n", - " yaxis_title=\"Temperature (°C)\",\n", + " title=\"China: Life Expectancy vs GDP per Capita\",\n", + " yaxis_title=\"Life Expectancy (years)\",\n", " legend={\"orientation\": \"h\", \"y\": 1.1},\n", ")\n", "combined" @@ -266,9 +273,9 @@ "id": "16", "metadata": {}, "source": [ - "### With Animation\n", + "### Comparing Scales: Why Secondary Y-Axis Matters\n", "\n", - "`add_secondary_y` also supports animated figures." + "Without secondary y-axis, one variable dominates:" ] }, { @@ -278,33 +285,41 @@ "metadata": {}, "outputs": [], "source": [ - "# Create two animated variables with different scales\n", - "var1 = xr.DataArray(\n", - " np.random.rand(20, 5) * 100,\n", - " dims=[\"x\", \"frame\"],\n", - " coords={\"x\": np.arange(20), \"frame\": [1, 2, 3, 4, 5]},\n", - " name=\"Metric A\",\n", - ")\n", + "# Same data on single y-axis (population dwarfs GDP)\n", + "pop_fig = xpx(us_pop).line()\n", + "pop_fig.update_traces(name=\"Population\")\n", "\n", - "var2 = xr.DataArray(\n", - " np.random.rand(20, 5) * 10,\n", - " dims=[\"x\", \"frame\"],\n", - " coords={\"x\": np.arange(20), \"frame\": [1, 2, 3, 4, 5]},\n", - " name=\"Metric B\",\n", - ")\n", + "gdp_fig = xpx(us_gdp).line()\n", + "gdp_fig.update_traces(name=\"GDP per Capita\")\n", "\n", - "fig1 = xpx(var1).line(animation_frame=\"frame\")\n", - "fig2 = xpx(var2).bar(animation_frame=\"frame\")\n", - "fig2.update_traces(marker={\"opacity\": 0.5})\n", + "# With overlay_figures (same y-axis) - GDP looks flat\n", + "bad = overlay_figures(pop_fig, gdp_fig)\n", + "bad.update_layout(title=\"Same Y-Axis: GDP appears flat (scale mismatch)\")\n", + "bad" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [ + "# With add_secondary_y - both variables visible\n", + "pop_fig = xpx(us_pop).line()\n", + "pop_fig.update_traces(name=\"Population\", line={\"color\": \"blue\"})\n", "\n", - "combined = add_secondary_y(fig1, fig2)\n", - "combined.update_layout(title=\"Animated Dual Y-Axis\")\n", - "combined" + "gdp_fig = xpx(us_gdp).line()\n", + "gdp_fig.update_traces(name=\"GDP per Capita\", line={\"color\": \"red\"})\n", + "\n", + "good = add_secondary_y(pop_fig, gdp_fig, secondary_y_title=\"GDP ($)\")\n", + "good.update_layout(title=\"Dual Y-Axis: Both variables clearly visible\")\n", + "good" ] }, { "cell_type": "markdown", - "id": "18", + "id": "19", "metadata": {}, "source": [ "## Limitations\n", From 7f2cada94de586a346208375138b8719f531bd5b Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Wed, 21 Jan 2026 22:07:46 +0100 Subject: [PATCH 03/10] Update notebook --- docs/examples/combining.ipynb | 378 +++++++++++++++++++++++++++++----- 1 file changed, 331 insertions(+), 47 deletions(-) diff --git a/docs/examples/combining.ipynb b/docs/examples/combining.ipynb index 1b8dff6..ab2983e 100644 --- a/docs/examples/combining.ipynb +++ b/docs/examples/combining.ipynb @@ -196,15 +196,78 @@ "cell_type": "markdown", "id": "11", "metadata": {}, + "source": [ + "### With Animation\n", + "\n", + "Overlay animated figures - frames are merged correctly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "# Animate through countries, showing population bar + GDP line\n", + "# Both use the same animation dimension\n", + "pop_anim = xpx(population).bar(animation_frame=\"country\")\n", + "pop_anim.update_traces(marker={\"opacity\": 0.6})\n", + "\n", + "# Create a \"target\" line (e.g., some reference value)\n", + "pop_smooth = population.rolling(year=3, center=True).mean()\n", + "smooth_anim = xpx(pop_smooth).line(animation_frame=\"country\")\n", + "smooth_anim.update_traces(line={\"color\": \"red\", \"width\": 3})\n", + "\n", + "combined = overlay_figures(pop_anim, smooth_anim)\n", + "combined.update_layout(title=\"Population: Raw + Smoothed (animated by country)\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "13", + "metadata": {}, + "source": [ + "### Static Overlay on Animated Base\n", + "\n", + "A static figure can be overlaid on an animated one - the static traces appear in all frames." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "# Animated population\n", + "pop_anim = xpx(population).bar(animation_frame=\"country\")\n", + "pop_anim.update_traces(marker={\"opacity\": 0.7})\n", + "\n", + "# Static reference line (global average across all countries)\n", + "global_avg = population.mean(dim=\"country\")\n", + "avg_fig = xpx(global_avg).line()\n", + "avg_fig.update_traces(line={\"color\": \"black\", \"width\": 2, \"dash\": \"dash\"}, name=\"Global Avg\")\n", + "\n", + "combined = overlay_figures(pop_anim, avg_fig)\n", + "combined.update_layout(title=\"Population by Country vs Global Average\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, "source": [ "## add_secondary_y\n", "\n", - "Plot two variables with different scales using independent y-axes. Essential when values have different magnitudes (e.g., population in millions vs GDP in thousands)." + "Plot two variables with different scales using independent y-axes. Essential when values have different magnitudes." ] }, { "cell_type": "markdown", - "id": "12", + "id": "16", "metadata": {}, "source": [ "### Population vs GDP per Capita" @@ -213,7 +276,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -237,100 +300,321 @@ }, { "cell_type": "markdown", - "id": "14", + "id": "18", "metadata": {}, "source": [ - "### Life Expectancy vs GDP" + "### Why Secondary Y-Axis Matters\n", + "\n", + "Without it, one variable dominates due to scale mismatch:" ] }, { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "19", + "metadata": {}, + "outputs": [], + "source": [ + "# Same data on single y-axis - GDP looks flat because population is ~1e8, GDP is ~1e4\n", + "pop_fig = xpx(us_pop).line()\n", + "pop_fig.update_traces(name=\"Population\", line={\"color\": \"blue\"})\n", + "\n", + "gdp_fig = xpx(us_gdp).line()\n", + "gdp_fig.update_traces(name=\"GDP per Capita\", line={\"color\": \"red\"})\n", + "\n", + "bad = overlay_figures(pop_fig, gdp_fig)\n", + "bad.update_layout(title=\"overlay_figures: GDP invisible (scale mismatch)\")\n", + "bad" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", "metadata": {}, "outputs": [], "source": [ - "china_life = life_expectancy.sel(country=\"China\")\n", - "china_gdp = gdp_per_capita.sel(country=\"China\")\n", + "# With add_secondary_y - both variables visible\n", + "pop_fig = xpx(us_pop).line()\n", + "pop_fig.update_traces(name=\"Population\", line={\"color\": \"blue\", \"width\": 2})\n", "\n", - "life_fig = xpx(china_life).line()\n", - "life_fig.update_traces(line={\"color\": \"green\", \"width\": 3}, name=\"Life Expectancy\")\n", + "gdp_fig = xpx(us_gdp).line()\n", + "gdp_fig.update_traces(name=\"GDP per Capita\", line={\"color\": \"red\", \"width\": 2})\n", "\n", - "gdp_fig = xpx(china_gdp).line()\n", - "gdp_fig.update_traces(line={\"color\": \"orange\", \"width\": 3, \"dash\": \"dash\"}, name=\"GDP per Capita\")\n", + "good = add_secondary_y(pop_fig, gdp_fig, secondary_y_title=\"GDP per Capita ($)\")\n", + "good.update_layout(title=\"add_secondary_y: Both variables clearly visible\")\n", + "good" + ] + }, + { + "cell_type": "markdown", + "id": "21", + "metadata": {}, + "source": [ + "### With Animation\n", "\n", - "combined = add_secondary_y(life_fig, gdp_fig, secondary_y_title=\"GDP per Capita ($)\")\n", - "combined.update_layout(\n", - " title=\"China: Life Expectancy vs GDP per Capita\",\n", - " yaxis_title=\"Life Expectancy (years)\",\n", - " legend={\"orientation\": \"h\", \"y\": 1.1},\n", + "`add_secondary_y` supports animated figures with matching frames." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "# Animate through countries\n", + "pop_anim = xpx(population).bar(animation_frame=\"country\")\n", + "pop_anim.update_traces(marker={\"color\": \"steelblue\", \"opacity\": 0.7})\n", + "\n", + "gdp_anim = xpx(gdp_per_capita).line(animation_frame=\"country\")\n", + "gdp_anim.update_traces(line={\"color\": \"red\", \"width\": 3})\n", + "\n", + "combined = add_secondary_y(pop_anim, gdp_anim, secondary_y_title=\"GDP per Capita ($)\")\n", + "combined.update_layout(title=\"Population vs GDP (animated by country)\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "23", + "metadata": {}, + "source": [ + "### Static Secondary on Animated Base\n", + "\n", + "A static secondary figure is replicated to all animation frames." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [], + "source": [ + "# Animated population\n", + "pop_anim = xpx(population).bar(animation_frame=\"country\")\n", + "pop_anim.update_traces(marker={\"color\": \"steelblue\", \"opacity\": 0.7})\n", + "\n", + "# Static GDP reference (US only, shown in all frames)\n", + "us_gdp_static = xpx(us_gdp).line()\n", + "us_gdp_static.update_traces(\n", + " line={\"color\": \"red\", \"width\": 2, \"dash\": \"dash\"}, name=\"US GDP (reference)\"\n", ")\n", + "\n", + "combined = add_secondary_y(pop_anim, us_gdp_static, secondary_y_title=\"GDP per Capita ($)\")\n", + "combined.update_layout(title=\"Population (animated) vs US GDP (static reference)\")\n", "combined" ] }, { "cell_type": "markdown", - "id": "16", + "id": "25", "metadata": {}, "source": [ - "### Comparing Scales: Why Secondary Y-Axis Matters\n", + "---\n", "\n", - "Without secondary y-axis, one variable dominates:" + "## Limitations (with examples)\n", + "\n", + "Both functions validate inputs and raise clear errors when constraints are violated." + ] + }, + { + "cell_type": "markdown", + "id": "26", + "metadata": {}, + "source": [ + "### overlay_figures: Mismatched Facet Structure\n", + "\n", + "Overlay cannot have subplots that don't exist in base." ] }, { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "27", "metadata": {}, "outputs": [], "source": [ - "# Same data on single y-axis (population dwarfs GDP)\n", - "pop_fig = xpx(us_pop).line()\n", - "pop_fig.update_traces(name=\"Population\")\n", + "# Base: no facets\n", + "base = xpx(stocks.sel(company=\"GOOG\")).line()\n", "\n", - "gdp_fig = xpx(us_gdp).line()\n", - "gdp_fig.update_traces(name=\"GDP per Capita\")\n", + "# Overlay: has facets\n", + "overlay = xpx(stocks.sel(company=[\"GOOG\", \"AAPL\"])).line(facet_col=\"company\")\n", "\n", - "# With overlay_figures (same y-axis) - GDP looks flat\n", - "bad = overlay_figures(pop_fig, gdp_fig)\n", - "bad.update_layout(title=\"Same Y-Axis: GDP appears flat (scale mismatch)\")\n", - "bad" + "try:\n", + " overlay_figures(base, overlay)\n", + "except ValueError as e:\n", + " print(f\"ValueError: {e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "28", + "metadata": {}, + "source": [ + "### overlay_figures: Animated Overlay on Static Base\n", + "\n", + "Cannot add an animated overlay to a static base figure." ] }, { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "29", "metadata": {}, "outputs": [], "source": [ - "# With add_secondary_y - both variables visible\n", - "pop_fig = xpx(us_pop).line()\n", - "pop_fig.update_traces(name=\"Population\", line={\"color\": \"blue\"})\n", + "# Base: static\n", + "static_base = xpx(population.sel(country=\"United States\")).line()\n", "\n", - "gdp_fig = xpx(us_gdp).line()\n", - "gdp_fig.update_traces(name=\"GDP per Capita\", line={\"color\": \"red\"})\n", + "# Overlay: animated\n", + "animated_overlay = xpx(population).line(animation_frame=\"country\")\n", "\n", - "good = add_secondary_y(pop_fig, gdp_fig, secondary_y_title=\"GDP ($)\")\n", - "good.update_layout(title=\"Dual Y-Axis: Both variables clearly visible\")\n", - "good" + "try:\n", + " overlay_figures(static_base, animated_overlay)\n", + "except ValueError as e:\n", + " print(f\"ValueError: {e}\")" ] }, { "cell_type": "markdown", - "id": "19", + "id": "30", + "metadata": {}, + "source": [ + "### overlay_figures: Mismatched Animation Frames\n", + "\n", + "Animation frame names must match exactly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31", + "metadata": {}, + "outputs": [], + "source": [ + "# Different countries selected = different frame names\n", + "fig1 = xpx(population.sel(country=[\"United States\", \"China\"])).line(animation_frame=\"country\")\n", + "fig2 = xpx(population.sel(country=[\"Germany\", \"Brazil\"])).line(animation_frame=\"country\")\n", + "\n", + "try:\n", + " overlay_figures(fig1, fig2)\n", + "except ValueError as e:\n", + " print(f\"ValueError: {e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "32", + "metadata": {}, + "source": [ + "### add_secondary_y: No Facet Support\n", + "\n", + "Secondary y-axis doesn't work with faceted figures." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33", + "metadata": {}, + "outputs": [], + "source": [ + "# Base with facets\n", + "pop_faceted = xpx(population).bar(facet_col=\"country\")\n", + "\n", + "# Secondary (even without facets)\n", + "gdp_single = xpx(gdp_per_capita.sel(country=\"United States\")).line()\n", + "\n", + "try:\n", + " add_secondary_y(pop_faceted, gdp_single)\n", + "except ValueError as e:\n", + " print(f\"ValueError: {e}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34", + "metadata": {}, + "outputs": [], + "source": [ + "# Secondary with facets (base is fine)\n", + "pop_single = xpx(population.sel(country=\"United States\")).bar()\n", + "gdp_faceted = xpx(gdp_per_capita).line(facet_col=\"country\")\n", + "\n", + "try:\n", + " add_secondary_y(pop_single, gdp_faceted)\n", + "except ValueError as e:\n", + " print(f\"ValueError: {e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "35", + "metadata": {}, + "source": [ + "### add_secondary_y: Animated Secondary on Static Base\n", + "\n", + "Cannot add animated secondary to static base." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36", "metadata": {}, + "outputs": [], "source": [ - "## Limitations\n", + "# Static base\n", + "static_pop = xpx(population.sel(country=\"United States\")).bar()\n", "\n", - "### overlay_figures\n", - "- Overlay must have same or fewer subplots than base\n", - "- Animation frames must match (or overlay must be static)\n", + "# Animated secondary\n", + "animated_gdp = xpx(gdp_per_capita).line(animation_frame=\"country\")\n", + "\n", + "try:\n", + " add_secondary_y(static_pop, animated_gdp)\n", + "except ValueError as e:\n", + " print(f\"ValueError: {e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "37", + "metadata": {}, + "source": [ + "### add_secondary_y: Mismatched Animation Frames" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38", + "metadata": {}, + "outputs": [], + "source": [ + "# Different countries = different frames\n", + "pop_some = xpx(population.sel(country=[\"United States\", \"China\"])).bar(animation_frame=\"country\")\n", + "gdp_other = xpx(gdp_per_capita.sel(country=[\"Germany\", \"Brazil\"])).line(animation_frame=\"country\")\n", + "\n", + "try:\n", + " add_secondary_y(pop_some, gdp_other)\n", + "except ValueError as e:\n", + " print(f\"ValueError: {e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "39", + "metadata": {}, + "source": [ + "## Summary\n", "\n", - "### add_secondary_y\n", - "- Does not support faceted figures (subplots)\n", - "- Animation frames must match (or secondary must be static)" + "| Function | Facets | Animation | Static + Animated |\n", + "|----------|--------|-----------|-------------------|\n", + "| `overlay_figures` | Yes (must match) | Yes (frames must match) | Static overlay on animated base OK |\n", + "| `add_secondary_y` | No | Yes (frames must match) | Static secondary on animated base OK |" ] } ], From b27d0b6e1442f674f91494aa84291428e0b256ce Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Thu, 22 Jan 2026 08:07:45 +0100 Subject: [PATCH 04/10] =?UTF-8?q?=20=201.=20Added=20facet=20support=20to?= =?UTF-8?q?=20add=5Fsecondary=5Fy=20-=20The=20function=20now=20creates=20s?= =?UTF-8?q?econdary=20y-axes=20for=20each=20facet=20subplot=20(e.g.,=20y?= =?UTF-8?q?=E2=86=92y4,=20y2=E2=86=92y5,=20y3=E2=86=92y6)=20=20=202.=20Upd?= =?UTF-8?q?ated=20tests=20-=20Added=206=20new=20tests=20for=20faceted=20se?= =?UTF-8?q?condary=20y-axis:=20=20=20=20=20-=20test=5Fmatching=5Ffacets=5F?= =?UTF-8?q?works=20=20=20=20=20-=20test=5Ffacets=5Fcreates=5Fmultiple=5Fse?= =?UTF-8?q?condary=5Faxes=20=20=20=20=20-=20test=5Fsecondary=5Ftraces=5Fre?= =?UTF-8?q?mapped=5Fto=5Fcorrect=5Faxes=20=20=20=20=20-=20test=5Fmismatche?= =?UTF-8?q?d=5Ffacets=5Fraises=20=20=20=20=20-=20test=5Fmismatched=5Ffacet?= =?UTF-8?q?s=5Freversed=5Fraises=20=20=20=20=20-=20test=5Ffacets=5Fwith=5F?= =?UTF-8?q?custom=5Ftitle=20=20=203.=20Updated=20notebook=20(docs/examples?= =?UTF-8?q?/combining.ipynb):=20=20=20=20=20-=20Added=20new=20"With=20Face?= =?UTF-8?q?ts"=20section=20showing=20add=5Fsecondary=5Fy=20working=20with?= =?UTF-8?q?=20faceted=20figures=20=20=20=20=20-=20Changed=20"No=20Facet=20?= =?UTF-8?q?Support"=20limitation=20to=20"Mismatched=20Facet=20Structure"?= =?UTF-8?q?=20showing=20the=20error=20when=20structures=20don't=20match=20?= =?UTF-8?q?=20=20=20=20-=20Updated=20summary=20table:=20add=5Fsecondary=5F?= =?UTF-8?q?y=20now=20shows=20"Yes=20(must=20match)"=20for=20Facets=20=20?= =?UTF-8?q?=204.=20All=20103=20tests=20pass?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/examples/combining.ipynb | 80 +++++++++++++---------- tests/test_figures.py | 98 ++++++++++++++++++++++------ xarray_plotly/figures.py | 119 ++++++++++++++++++++++++---------- 3 files changed, 210 insertions(+), 87 deletions(-) diff --git a/docs/examples/combining.ipynb b/docs/examples/combining.ipynb index ab2983e..ce5543e 100644 --- a/docs/examples/combining.ipynb +++ b/docs/examples/combining.ipynb @@ -411,6 +411,35 @@ "cell_type": "markdown", "id": "25", "metadata": {}, + "source": [ + "### With Facets\n", + "\n", + "`add_secondary_y` works with faceted figures when both have the same facet structure." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", + "metadata": {}, + "outputs": [], + "source": [ + "# Faceted by country - both figures must have same facet structure\n", + "pop_faceted = xpx(population).bar(facet_col=\"country\")\n", + "pop_faceted.update_traces(marker={\"color\": \"steelblue\", \"opacity\": 0.7})\n", + "\n", + "gdp_faceted = xpx(gdp_per_capita).line(facet_col=\"country\")\n", + "gdp_faceted.update_traces(line={\"color\": \"red\", \"width\": 3})\n", + "\n", + "combined = add_secondary_y(pop_faceted, gdp_faceted, secondary_y_title=\"GDP per Capita ($)\")\n", + "combined.update_layout(title=\"Population vs GDP per Capita (faceted by country)\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "27", + "metadata": {}, "source": [ "---\n", "\n", @@ -421,7 +450,7 @@ }, { "cell_type": "markdown", - "id": "26", + "id": "28", "metadata": {}, "source": [ "### overlay_figures: Mismatched Facet Structure\n", @@ -432,7 +461,7 @@ { "cell_type": "code", "execution_count": null, - "id": "27", + "id": "29", "metadata": {}, "outputs": [], "source": [ @@ -450,7 +479,7 @@ }, { "cell_type": "markdown", - "id": "28", + "id": "30", "metadata": {}, "source": [ "### overlay_figures: Animated Overlay on Static Base\n", @@ -461,7 +490,7 @@ { "cell_type": "code", "execution_count": null, - "id": "29", + "id": "31", "metadata": {}, "outputs": [], "source": [ @@ -479,7 +508,7 @@ }, { "cell_type": "markdown", - "id": "30", + "id": "32", "metadata": {}, "source": [ "### overlay_figures: Mismatched Animation Frames\n", @@ -490,7 +519,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "33", "metadata": {}, "outputs": [], "source": [ @@ -506,25 +535,25 @@ }, { "cell_type": "markdown", - "id": "32", + "id": "34", "metadata": {}, "source": [ - "### add_secondary_y: No Facet Support\n", + "### add_secondary_y: Mismatched Facet Structure\n", "\n", - "Secondary y-axis doesn't work with faceted figures." + "Both figures must have the same facet structure." ] }, { "cell_type": "code", "execution_count": null, - "id": "33", + "id": "35", "metadata": {}, "outputs": [], "source": [ "# Base with facets\n", "pop_faceted = xpx(population).bar(facet_col=\"country\")\n", "\n", - "# Secondary (even without facets)\n", + "# Secondary without facets (different structure)\n", "gdp_single = xpx(gdp_per_capita.sel(country=\"United States\")).line()\n", "\n", "try:\n", @@ -533,26 +562,9 @@ " print(f\"ValueError: {e}\")" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "34", - "metadata": {}, - "outputs": [], - "source": [ - "# Secondary with facets (base is fine)\n", - "pop_single = xpx(population.sel(country=\"United States\")).bar()\n", - "gdp_faceted = xpx(gdp_per_capita).line(facet_col=\"country\")\n", - "\n", - "try:\n", - " add_secondary_y(pop_single, gdp_faceted)\n", - "except ValueError as e:\n", - " print(f\"ValueError: {e}\")" - ] - }, { "cell_type": "markdown", - "id": "35", + "id": "36", "metadata": {}, "source": [ "### add_secondary_y: Animated Secondary on Static Base\n", @@ -563,7 +575,7 @@ { "cell_type": "code", "execution_count": null, - "id": "36", + "id": "37", "metadata": {}, "outputs": [], "source": [ @@ -581,7 +593,7 @@ }, { "cell_type": "markdown", - "id": "37", + "id": "38", "metadata": {}, "source": [ "### add_secondary_y: Mismatched Animation Frames" @@ -590,7 +602,7 @@ { "cell_type": "code", "execution_count": null, - "id": "38", + "id": "39", "metadata": {}, "outputs": [], "source": [ @@ -606,7 +618,7 @@ }, { "cell_type": "markdown", - "id": "39", + "id": "40", "metadata": {}, "source": [ "## Summary\n", @@ -614,7 +626,7 @@ "| Function | Facets | Animation | Static + Animated |\n", "|----------|--------|-----------|-------------------|\n", "| `overlay_figures` | Yes (must match) | Yes (frames must match) | Static overlay on animated base OK |\n", - "| `add_secondary_y` | No | Yes (frames must match) | Static secondary on animated base OK |" + "| `add_secondary_y` | Yes (must match) | Yes (frames must match) | Static secondary on animated base OK |" ] } ], diff --git a/tests/test_figures.py b/tests/test_figures.py index 903a2f8..befba25 100644 --- a/tests/test_figures.py +++ b/tests/test_figures.py @@ -427,35 +427,93 @@ def test_custom_secondary_y_title(self) -> None: assert combined.layout.yaxis2.title.text == "Rain (mm)" -class TestAddSecondaryYFacetsError: - """Tests for add_secondary_y error handling with facets.""" +class TestAddSecondaryYFacets: + """Tests for add_secondary_y with faceted figures.""" - def test_base_with_facets_raises(self) -> None: - """Test that base figure with facets raises ValueError.""" - da = xr.DataArray( - np.random.rand(10, 2), + @pytest.fixture(autouse=True) + def setup(self) -> None: + """Set up test data.""" + self.da = xr.DataArray( + np.random.rand(10, 3), + dims=["time", "facet"], + coords={"time": np.arange(10), "facet": ["A", "B", "C"]}, + name="value", + ) + # Different scale for secondary + self.da_secondary = xr.DataArray( + np.random.rand(10, 3) * 1000, dims=["time", "facet"], - coords={"time": np.arange(10), "facet": ["A", "B"]}, + coords={"time": np.arange(10), "facet": ["A", "B", "C"]}, + name="large_value", ) - base = xpx(da).line(facet_col="facet") - secondary = xpx(da.isel(facet=0)).bar() - with pytest.raises(ValueError, match="Base figure has facets"): + def test_matching_facets_works(self) -> None: + """Test that matching facet structures work.""" + base = xpx(self.da).line(facet_col="facet") + secondary = xpx(self.da_secondary).bar(facet_col="facet") + + combined = add_secondary_y(base, secondary) + + assert isinstance(combined, go.Figure) + expected_traces = len(base.data) + len(secondary.data) + assert len(combined.data) == expected_traces + + def test_facets_creates_multiple_secondary_axes(self) -> None: + """Test that secondary y-axes are created for each facet.""" + base = xpx(self.da).line(facet_col="facet") + secondary = xpx(self.da_secondary).bar(facet_col="facet") + + combined = add_secondary_y(base, secondary) + + # Should have yaxis2 (secondary for y), yaxis5 (secondary for y2), etc. + # Base has y, y2, y3, so secondary should be y4, y5, y6 + layout_json = combined.layout.to_plotly_json() + assert "yaxis4" in layout_json + assert layout_json["yaxis4"]["overlaying"] == "y" + assert layout_json["yaxis4"]["side"] == "right" + + def test_secondary_traces_remapped_to_correct_axes(self) -> None: + """Test that secondary traces use correct secondary y-axes.""" + base = xpx(self.da).line(facet_col="facet") + secondary = xpx(self.da_secondary).bar(facet_col="facet") + + combined = add_secondary_y(base, secondary) + + # Get secondary trace y-axes + secondary_trace_yaxes = {trace.yaxis for trace in combined.data[len(base.data) :]} + # Should be y4, y5, y6 (secondary axes) + assert secondary_trace_yaxes == {"y4", "y5", "y6"} + + def test_mismatched_facets_raises(self) -> None: + """Test that mismatched facet structures raise ValueError.""" + # Base with facets + base = xpx(self.da).line(facet_col="facet") + # Secondary without facets + secondary = xpx(self.da.isel(facet=0)).bar() + + with pytest.raises(ValueError, match="same facet structure"): add_secondary_y(base, secondary) - def test_secondary_with_facets_raises(self) -> None: - """Test that secondary figure with facets raises ValueError.""" - da = xr.DataArray( - np.random.rand(10, 2), - dims=["time", "facet"], - coords={"time": np.arange(10), "facet": ["A", "B"]}, - ) - base = xpx(da.isel(facet=0)).line() - secondary = xpx(da).bar(facet_col="facet") + def test_mismatched_facets_reversed_raises(self) -> None: + """Test that mismatched facets raise (base without, secondary with).""" + # Base without facets + base = xpx(self.da.isel(facet=0)).line() + # Secondary with facets + secondary = xpx(self.da).bar(facet_col="facet") - with pytest.raises(ValueError, match="Secondary figure has facets"): + with pytest.raises(ValueError, match="same facet structure"): add_secondary_y(base, secondary) + def test_facets_with_custom_title(self) -> None: + """Test custom secondary y-axis title with facets.""" + base = xpx(self.da).line(facet_col="facet") + secondary = xpx(self.da_secondary).bar(facet_col="facet") + + combined = add_secondary_y(base, secondary, secondary_y_title="Custom Title") + + # Title should be on the first secondary axis + assert combined.layout.yaxis4.title.text == "Custom Title" + class TestAddSecondaryYAnimation: """Tests for add_secondary_y with animated figures.""" diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py index f3edfc4..82fccd7 100644 --- a/xarray_plotly/figures.py +++ b/xarray_plotly/figures.py @@ -200,6 +200,34 @@ def overlay_figures(base: go.Figure, *overlays: go.Figure) -> go.Figure: combine_figures = overlay_figures +def _build_secondary_y_mapping(base_axes: set[tuple[str, str]]) -> dict[str, str]: + """Build mapping from primary y-axes to secondary y-axes. + + Args: + base_axes: Set of (xaxis, yaxis) pairs from base figure. + + Returns: + Dict mapping primary yaxis names to secondary yaxis names. + E.g., {'y': 'y4', 'y2': 'y5', 'y3': 'y6'} + """ + primary_y_axes = sorted({yaxis for _, yaxis in base_axes}) + + # Find the highest existing yaxis number + max_y_num = 1 # 'y' is 1 + for yaxis in primary_y_axes: + num = 1 if yaxis == "y" else int(yaxis[1:]) + max_y_num = max(max_y_num, num) + + # Create mapping: primary_yaxis -> secondary_yaxis + y_mapping = {} + next_y_num = max_y_num + 1 + for yaxis in primary_y_axes: + y_mapping[yaxis] = f"y{next_y_num}" + next_y_num += 1 + + return y_mapping + + def add_secondary_y( base: go.Figure, secondary: go.Figure, @@ -208,9 +236,10 @@ def add_secondary_y( ) -> go.Figure: """Add a secondary y-axis with traces from another figure. - Creates a new figure with the base figure's layout and a secondary y-axis + Creates a new figure with the base figure's layout and secondary y-axes on the right side. All traces from the secondary figure are plotted against - the secondary y-axis. + the secondary y-axes. Supports faceted figures when both have matching + facet structure. Args: base: The base figure (left y-axis). @@ -222,7 +251,7 @@ def add_secondary_y( A new figure with both primary and secondary y-axes. Raises: - ValueError: If either figure has facets (subplots), or if animation + ValueError: If facet structures don't match, or if animation frames don't match. Example: @@ -237,27 +266,32 @@ def add_secondary_y( >>> temp_fig = xpx(temp).line() >>> precip_fig = xpx(precip).bar() >>> combined = add_secondary_y(temp_fig, precip_fig) + >>> + >>> # With facets + >>> data = xr.DataArray(np.random.rand(10, 3), dims=["x", "facet"]) + >>> fig1 = xpx(data).line(facet_col="facet") + >>> fig2 = xpx(data * 100).bar(facet_col="facet") # Different scale + >>> combined = add_secondary_y(fig1, fig2) """ import plotly.graph_objects as go - # Check for facets - not supported with secondary y + # Get axis pairs from both figures base_axes = _get_subplot_axes(base) secondary_axes = _get_subplot_axes(secondary) - if len(base_axes) > 1 or base_axes != {("x", "y")}: - raise ValueError( - "Base figure has facets (subplots). Secondary y-axis is not supported " - "with faceted figures." - ) - if len(secondary_axes) > 1 or secondary_axes != {("x", "y")}: + # Validate same facet structure + if base_axes != secondary_axes: raise ValueError( - "Secondary figure has facets (subplots). Secondary y-axis is not supported " - "with faceted figures." + f"Base and secondary figures must have the same facet structure. " + f"Base has {base_axes}, secondary has {secondary_axes}." ) # Validate animation compatibility _validate_animation_compatibility(base, secondary) + # Build mapping from primary y-axes to secondary y-axes + y_mapping = _build_secondary_y_mapping(base_axes) + # Create new figure with base's layout combined = go.Figure(layout=copy.deepcopy(base.layout)) @@ -265,42 +299,61 @@ def add_secondary_y( for trace in base.data: combined.add_trace(copy.deepcopy(trace)) - # Add all traces from secondary, assigned to y2 + # Add all traces from secondary, remapped to secondary y-axes for trace in secondary.data: trace_copy = copy.deepcopy(trace) - trace_copy.yaxis = "y2" + original_yaxis = getattr(trace_copy, "yaxis", None) or "y" + trace_copy.yaxis = y_mapping[original_yaxis] combined.add_trace(trace_copy) - # Configure secondary y-axis - y2_title = secondary_y_title - if y2_title is None and secondary.layout.yaxis and secondary.layout.yaxis.title: - y2_title = secondary.layout.yaxis.title.text - - combined.update_layout( - yaxis2={ - "title": y2_title, - "overlaying": "y", + # Configure secondary y-axes + for primary_yaxis, secondary_yaxis in y_mapping.items(): + # Get title - only set on first secondary axis or use provided title + title = None + if secondary_y_title is not None: + # Only set title on the first secondary axis to avoid repetition + if primary_yaxis == "y": + title = secondary_y_title + elif primary_yaxis == "y" and secondary.layout.yaxis and secondary.layout.yaxis.title: + # Try to get from secondary's layout + title = secondary.layout.yaxis.title.text + + # Configure the secondary axis + axis_config = { + "title": title, + "overlaying": primary_yaxis, "side": "right", - }, - ) + "anchor": "free" if primary_yaxis != "y" else None, + } + # Remove None values + axis_config = {k: v for k, v in axis_config.items() if v is not None} + + # Convert y2 -> yaxis2, y3 -> yaxis3, etc. for layout property name + layout_prop = "yaxis" if secondary_yaxis == "y" else f"yaxis{secondary_yaxis[1:]}" + combined.update_layout(**{layout_prop: axis_config}) # Handle animation frames if base.frames: - merged_frames = _merge_secondary_y_frames(base, secondary) + merged_frames = _merge_secondary_y_frames(base, secondary, y_mapping) combined.frames = merged_frames return combined -def _merge_secondary_y_frames(base: go.Figure, secondary: go.Figure) -> list: +def _merge_secondary_y_frames( + base: go.Figure, + secondary: go.Figure, + y_mapping: dict[str, str], +) -> list: """Merge animation frames for secondary y-axis combination. Args: base: The base figure with animation frames. secondary: The secondary figure (may or may not have frames). + y_mapping: Mapping from primary y-axis names to secondary y-axis names. Returns: - List of merged frames with secondary traces assigned to y2. + List of merged frames with secondary traces assigned to secondary y-axes. """ import plotly.graph_objects as go @@ -316,18 +369,18 @@ def _merge_secondary_y_frames(base: go.Figure, secondary: go.Figure) -> list: # Find matching frame in secondary secondary_frame = next((f for f in secondary.frames if f.name == frame_name), None) if secondary_frame: - # Add secondary frame data with y2 assignment + # Add secondary frame data with remapped y-axis for trace_data in secondary_frame.data: trace_copy = copy.deepcopy(trace_data) - if hasattr(trace_copy, "yaxis"): - trace_copy.yaxis = "y2" + original_yaxis = getattr(trace_copy, "yaxis", None) or "y" + trace_copy.yaxis = y_mapping.get(original_yaxis, original_yaxis) merged_data.append(trace_copy) else: # Static secondary: replicate traces to this frame for trace in secondary.data: trace_copy = copy.deepcopy(trace) - if hasattr(trace_copy, "yaxis"): - trace_copy.yaxis = "y2" + original_yaxis = getattr(trace_copy, "yaxis", None) or "y" + trace_copy.yaxis = y_mapping.get(original_yaxis, original_yaxis) merged_data.append(trace_copy) merged_frames.append( From afc29dec5749b183023f8908ab698be12a35f2e9 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Thu, 22 Jan 2026 08:10:05 +0100 Subject: [PATCH 05/10] Add notebook showing off manipulation options --- docs/examples/manipulation.ipynb | 587 +++++++++++++++++++++++++++++++ 1 file changed, 587 insertions(+) create mode 100644 docs/examples/manipulation.ipynb diff --git a/docs/examples/manipulation.ipynb b/docs/examples/manipulation.ipynb new file mode 100644 index 0000000..7f16aa1 --- /dev/null +++ b/docs/examples/manipulation.ipynb @@ -0,0 +1,587 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Figure Manipulation\n", + "\n", + "How to modify a [Plotly Figure](https://plotly.com/python/figure-structure/) after creation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import plotly.express as px\n", + "import plotly.io as pio\n", + "\n", + "pio.renderers.default = \"notebook_connected\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "df = px.data.gapminder().query(\"year == 2007\")\n", + "fig = px.scatter(\n", + " df, x=\"gdpPercap\", y=\"lifeExp\", color=\"continent\", size=\"pop\", hover_name=\"country\"\n", + ")\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "## update_layout\n", + "\n", + "Modify [layout properties](https://plotly.com/python/reference/layout/): title, legend, axes, margins." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "fig.update_layout(\n", + " title={\"text\": \"GDP vs Life Expectancy (2007)\", \"x\": 0.5},\n", + " xaxis_title=\"GDP per Capita ($)\",\n", + " yaxis_title=\"Life Expectancy (years)\",\n", + ")\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "## update_traces\n", + "\n", + "Modify [trace properties](https://plotly.com/python/reference/). Use `selector` to target specific traces." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "fig.update_traces(marker_opacity=0.8)\n", + "fig.update_traces(marker_line_width=2, selector={\"name\": \"Europe\"})\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "## update_xaxes / update_yaxes\n", + "\n", + "Modify [axis properties](https://plotly.com/python/axes/)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "fig.update_xaxes(type=\"log\", showgrid=True, gridcolor=\"lightgray\")\n", + "fig.update_yaxes(range=[40, 90])\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## add_hline / add_vline\n", + "\n", + "Add reference lines. See [shapes](https://plotly.com/python/shapes/)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "fig.add_hline(y=df[\"lifeExp\"].mean(), line_dash=\"dash\", line_color=\"gray\", annotation_text=\"Mean\")\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "11", + "metadata": {}, + "source": [ + "## add_annotation\n", + "\n", + "Add text annotations. See [annotations](https://plotly.com/python/text-and-annotations/)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "fig.add_annotation(x=4.5, y=82, text=\"Developed Nations\", showarrow=False, font_size=12)\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "13", + "metadata": {}, + "source": [ + "## Direct Access\n", + "\n", + "Access traces via `fig.data` and layout via `fig.layout`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "# Change first trace\n", + "fig.data[0].marker.symbol = \"diamond\"\n", + "\n", + "# Change layout directly\n", + "fig.layout.legend.title.text = \"Region\"\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "## Faceted Plots\n", + "\n", + "With facets, there are multiple axes: `xaxis`, `xaxis2`, etc." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "df_time = px.data.gapminder().query(\"country in ['United States', 'China', 'Germany']\")\n", + "fig2 = px.line(df_time, x=\"year\", y=\"gdpPercap\", color=\"country\", facet_col=\"country\")\n", + "fig2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "# Update all x-axes\n", + "fig2.update_xaxes(showgrid=False)\n", + "\n", + "# Update specific facet (col=2 is the middle one)\n", + "fig2.update_yaxes(type=\"log\", col=2)\n", + "\n", + "# Modify facet labels (stored as annotations)\n", + "fig2.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", + "fig2" + ] + }, + { + "cell_type": "markdown", + "id": "18", + "metadata": {}, + "source": [ + "### Facet Grid (rows and columns)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19", + "metadata": {}, + "outputs": [], + "source": [ + "df_facet = px.data.tips()\n", + "fig3 = px.histogram(df_facet, x=\"total_bill\", facet_row=\"sex\", facet_col=\"time\")\n", + "fig3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [], + "source": [ + "# Target specific cell in the grid\n", + "fig3.update_xaxes(range=[0, 40], row=1, col=1) # top-left only\n", + "\n", + "# Update entire row\n", + "fig3.update_yaxes(title_text=\"Count\", row=2)\n", + "\n", + "# Update entire column\n", + "fig3.update_xaxes(title_text=\"Bill ($)\", col=2)\n", + "\n", + "fig3.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", + "fig3" + ] + }, + { + "cell_type": "markdown", + "id": "21", + "metadata": {}, + "source": [ + "### Direct Axis Access\n", + "\n", + "Access axes directly via `fig.layout.xaxis`, `fig.layout.xaxis2`, etc." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "fig4 = px.scatter(df, x=\"gdpPercap\", y=\"lifeExp\", facet_col=\"continent\", facet_col_wrap=3)\n", + "\n", + "# See which axes exist\n", + "print(\"X axes:\", [k for k in fig4.layout.to_plotly_json() if k.startswith(\"xaxis\")])\n", + "print(\"Y axes:\", [k for k in fig4.layout.to_plotly_json() if k.startswith(\"yaxis\")])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23", + "metadata": {}, + "outputs": [], + "source": [ + "# Modify specific axis directly\n", + "fig4.layout.xaxis.type = \"log\"\n", + "fig4.layout.xaxis2.type = \"log\"\n", + "fig4.layout.yaxis.title.text = \"Life Exp\"\n", + "\n", + "fig4.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", + "fig4" + ] + }, + { + "cell_type": "markdown", + "id": "24", + "metadata": {}, + "source": [ + "### Shapes on Specific Facets\n", + "\n", + "Use `xref` and `yref` to target specific facet axes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25", + "metadata": {}, + "outputs": [], + "source": [ + "fig5 = px.scatter(df, x=\"gdpPercap\", y=\"lifeExp\", facet_col=\"continent\", facet_col_wrap=3)\n", + "fig5.update_xaxes(type=\"log\")\n", + "\n", + "# Add rectangle to first facet (x, y)\n", + "fig5.add_shape(\n", + " type=\"rect\",\n", + " x0=1000,\n", + " x1=10000,\n", + " y0=70,\n", + " y1=85,\n", + " fillcolor=\"lightblue\",\n", + " opacity=0.3,\n", + " line_width=0,\n", + " xref=\"x\",\n", + " yref=\"y\",\n", + ")\n", + "\n", + "# Add rectangle to second facet (x2, y2)\n", + "fig5.add_shape(\n", + " type=\"rect\",\n", + " x0=1000,\n", + " x1=10000,\n", + " y0=70,\n", + " y1=85,\n", + " fillcolor=\"lightgreen\",\n", + " opacity=0.3,\n", + " line_width=0,\n", + " xref=\"x2\",\n", + " yref=\"y2\",\n", + ")\n", + "\n", + "fig5.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", + "fig5" + ] + }, + { + "cell_type": "markdown", + "id": "26", + "metadata": {}, + "source": [ + "### Axis Matching\n", + "\n", + "Control whether facets share the same axis range." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27", + "metadata": {}, + "outputs": [], + "source": [ + "fig6 = px.histogram(df_facet, x=\"total_bill\", facet_col=\"day\")\n", + "\n", + "# Default: axes are matched (same range)\n", + "fig6.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", + "fig6" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28", + "metadata": {}, + "outputs": [], + "source": [ + "# Make y-axes independent (each facet auto-scales)\n", + "fig6.update_yaxes(matches=None)\n", + "fig6" + ] + }, + { + "cell_type": "markdown", + "id": "29", + "metadata": {}, + "source": [ + "## Animated Plots\n", + "\n", + "Animations have frames, sliders, and play buttons." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30", + "metadata": {}, + "outputs": [], + "source": [ + "df_anim = px.data.gapminder()\n", + "fig7 = px.scatter(\n", + " df_anim,\n", + " x=\"gdpPercap\",\n", + " y=\"lifeExp\",\n", + " size=\"pop\",\n", + " color=\"continent\",\n", + " hover_name=\"country\",\n", + " animation_frame=\"year\",\n", + " animation_group=\"country\",\n", + " log_x=True,\n", + " range_y=[25, 90],\n", + ")\n", + "fig7" + ] + }, + { + "cell_type": "markdown", + "id": "31", + "metadata": {}, + "source": [ + "### Animation Speed\n", + "\n", + "Modify frame duration and transition time." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32", + "metadata": {}, + "outputs": [], + "source": [ + "fig7.layout.updatemenus[0].buttons[0].args[1][\"frame\"][\"duration\"] = 200 # ms per frame\n", + "fig7.layout.updatemenus[0].buttons[0].args[1][\"transition\"][\"duration\"] = 100 # transition time\n", + "fig7" + ] + }, + { + "cell_type": "markdown", + "id": "33", + "metadata": {}, + "source": [ + "### Slider Styling" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34", + "metadata": {}, + "outputs": [], + "source": [ + "fig7.layout.sliders[0].currentvalue.prefix = \"Year: \"\n", + "fig7.layout.sliders[0].currentvalue.font.size = 16\n", + "fig7.layout.sliders[0].currentvalue.font.color = \"darkblue\"\n", + "fig7" + ] + }, + { + "cell_type": "markdown", + "id": "35", + "metadata": {}, + "source": [ + "### Play/Pause Button Styling" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36", + "metadata": {}, + "outputs": [], + "source": [ + "fig7.layout.updatemenus[0].bgcolor = \"lightgray\"\n", + "fig7.layout.updatemenus[0].font.color = \"black\"\n", + "fig7" + ] + }, + { + "cell_type": "markdown", + "id": "37", + "metadata": {}, + "source": [ + "### Modify Individual Frames\n", + "\n", + "Access frames via `fig.frames`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Number of frames: {len(fig7.frames)}\")\n", + "print(f\"Frame names: {[f.name for f in fig7.frames]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39", + "metadata": {}, + "outputs": [], + "source": [ + "# Change layout for a specific frame (e.g., add title showing year)\n", + "for frame in fig7.frames:\n", + " frame.layout = {\"title\": f\"Gapminder {frame.name}\"}\n", + "fig7" + ] + }, + { + "cell_type": "markdown", + "id": "40", + "metadata": {}, + "source": [ + "### Hide Slider or Buttons" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41", + "metadata": {}, + "outputs": [], + "source": [ + "fig8 = px.scatter(\n", + " df_anim,\n", + " x=\"gdpPercap\",\n", + " y=\"lifeExp\",\n", + " color=\"continent\",\n", + " animation_frame=\"year\",\n", + " log_x=True,\n", + " range_y=[25, 90],\n", + ")\n", + "\n", + "# Hide the slider\n", + "fig8.layout.sliders = []\n", + "\n", + "# Or hide the play button instead:\n", + "# fig8.layout.updatemenus = []\n", + "\n", + "fig8" + ] + }, + { + "cell_type": "markdown", + "id": "42", + "metadata": {}, + "source": [ + "## Method Chaining\n", + "\n", + "All `update_*` methods return the figure." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " px.scatter(df, x=\"gdpPercap\", y=\"lifeExp\", color=\"continent\")\n", + " .update_layout(title=\"Chained Example\")\n", + " .update_traces(marker_size=12)\n", + " .update_xaxes(type=\"log\")\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From acab4d3e70347666805cf67e167d786b980a9f54 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Thu, 22 Jan 2026 08:24:52 +0100 Subject: [PATCH 06/10] Updated the notebook to: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Use overlay_figures for adding traces - shown under "Easy: Adding traces to faceted/animated figures" 2. Keep just two helpers: - update_animation_traces() - the main pain point - set_animation_speed() - for the deeply nested API 3. Added facets + animation example showing the helper works for both So the final picture is: ┌─────────────────────────────────────────────────────┬──────────────────────────────────────────┐ │ What you want to do │ Solution │ ├─────────────────────────────────────────────────────┼──────────────────────────────────────────┤ │ Add traces to animated/faceted figures │ overlay_figures() ✅ already in library │ ├─────────────────────────────────────────────────────┼──────────────────────────────────────────┤ │ Update trace style (line_width, etc.) on animations │ update_animation_traces() - needs helper │ ├─────────────────────────────────────────────────────┼──────────────────────────────────────────┤ │ Change animation speed │ set_animation_speed() - needs helper │ ├─────────────────────────────────────────────────────┼──────────────────────────────────────────┤ │ Everything else │ Works out of the box │ └─────────────────────────────────────────────────────┴──────────────────────────────────────────┘ Should we add update_animation_traces() and set_animation_speed() to xarray_plotly.figures as proper exported functions? They're simple but solve real pain points. --- docs/examples/manipulation.ipynb | 578 ++++++++++++++++++------------- 1 file changed, 340 insertions(+), 238 deletions(-) diff --git a/docs/examples/manipulation.ipynb b/docs/examples/manipulation.ipynb index 7f16aa1..0024737 100644 --- a/docs/examples/manipulation.ipynb +++ b/docs/examples/manipulation.ipynb @@ -7,7 +7,7 @@ "source": [ "# Figure Manipulation\n", "\n", - "How to modify a [Plotly Figure](https://plotly.com/python/figure-structure/) after creation." + "What's easy, what's annoying, and how to work around it." ] }, { @@ -18,33 +18,39 @@ "outputs": [], "source": [ "import plotly.express as px\n", + "import plotly.graph_objects as go\n", "import plotly.io as pio\n", "\n", - "pio.renderers.default = \"notebook_connected\"" + "from xarray_plotly import overlay_figures\n", + "\n", + "pio.renderers.default = \"notebook_connected\"\n", + "\n", + "# Sample data\n", + "df = px.data.gapminder()\n", + "df_2007 = df.query(\"year == 2007\")\n", + "df_countries = df.query(\"country in ['United States', 'China', 'Germany', 'Brazil']\")" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "id": "2", "metadata": {}, - "outputs": [], "source": [ - "df = px.data.gapminder().query(\"year == 2007\")\n", - "fig = px.scatter(\n", - " df, x=\"gdpPercap\", y=\"lifeExp\", color=\"continent\", size=\"pop\", hover_name=\"country\"\n", - ")\n", - "fig" + "---\n", + "# Easy: Single Plots\n", + "\n", + "All standard manipulation methods work as expected." ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "id": "3", "metadata": {}, + "outputs": [], "source": [ - "## update_layout\n", - "\n", - "Modify [layout properties](https://plotly.com/python/reference/layout/): title, legend, axes, margins." + "fig = px.scatter(df_2007, x=\"gdpPercap\", y=\"lifeExp\", color=\"continent\", size=\"pop\")\n", + "fig" ] }, { @@ -54,11 +60,22 @@ "metadata": {}, "outputs": [], "source": [ - "fig.update_layout(\n", - " title={\"text\": \"GDP vs Life Expectancy (2007)\", \"x\": 0.5},\n", - " xaxis_title=\"GDP per Capita ($)\",\n", - " yaxis_title=\"Life Expectancy (years)\",\n", - ")\n", + "# Layout\n", + "fig.update_layout(title=\"GDP vs Life Expectancy\", template=\"plotly_white\")\n", + "\n", + "# All traces\n", + "fig.update_traces(marker_opacity=0.7)\n", + "\n", + "# Specific traces\n", + "fig.update_traces(marker_line_width=2, selector={\"name\": \"Europe\"})\n", + "\n", + "# Axes\n", + "fig.update_xaxes(type=\"log\", title=\"GDP per Capita\")\n", + "fig.update_yaxes(range=[40, 90])\n", + "\n", + "# Annotations and shapes\n", + "fig.add_hline(y=df_2007[\"lifeExp\"].mean(), line_dash=\"dash\", line_color=\"gray\")\n", + "\n", "fig" ] }, @@ -67,9 +84,10 @@ "id": "5", "metadata": {}, "source": [ - "## update_traces\n", + "---\n", + "# Easy: Faceted Plots\n", "\n", - "Modify [trace properties](https://plotly.com/python/reference/). Use `selector` to target specific traces." + "`update_traces`, `update_xaxes`, `update_yaxes` all work across facets." ] }, { @@ -79,83 +97,137 @@ "metadata": {}, "outputs": [], "source": [ - "fig.update_traces(marker_opacity=0.8)\n", - "fig.update_traces(marker_line_width=2, selector={\"name\": \"Europe\"})\n", + "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", facet_col=\"country\")\n", "fig" ] }, - { - "cell_type": "markdown", - "id": "7", - "metadata": {}, - "source": [ - "## update_xaxes / update_yaxes\n", - "\n", - "Modify [axis properties](https://plotly.com/python/axes/)." - ] - }, { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "7", "metadata": {}, "outputs": [], "source": [ - "fig.update_xaxes(type=\"log\", showgrid=True, gridcolor=\"lightgray\")\n", - "fig.update_yaxes(range=[40, 90])\n", + "# Update ALL traces across all facets\n", + "fig.update_traces(line_width=3)\n", + "\n", + "# Update ALL x-axes\n", + "fig.update_xaxes(showgrid=False)\n", + "\n", + "# Update ALL y-axes\n", + "fig.update_yaxes(showgrid=False, type=\"log\")\n", + "\n", + "# Clean up facet labels\n", + "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", + "\n", "fig" ] }, { "cell_type": "markdown", - "id": "9", + "id": "8", "metadata": {}, "source": [ - "## add_hline / add_vline\n", + "### Targeting specific facets\n", "\n", - "Add reference lines. See [shapes](https://plotly.com/python/shapes/)." + "Use `row=` and `col=` (1-indexed) to target specific facets." ] }, { "cell_type": "code", "execution_count": null, - "id": "10", + "id": "9", "metadata": {}, "outputs": [], "source": [ - "fig.add_hline(y=df[\"lifeExp\"].mean(), line_dash=\"dash\", line_color=\"gray\", annotation_text=\"Mean\")\n", + "fig = px.histogram(px.data.tips(), x=\"total_bill\", facet_row=\"sex\", facet_col=\"time\")\n", + "\n", + "# Target specific cell\n", + "fig.update_yaxes(title_text=\"Frequency\", row=1, col=1)\n", + "\n", + "# Target entire column\n", + "fig.update_xaxes(title_text=\"Bill ($)\", col=2)\n", + "\n", + "# Target entire row\n", + "fig.update_traces(marker_color=\"orange\", row=2)\n", + "\n", + "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", "fig" ] }, { "cell_type": "markdown", - "id": "11", + "id": "10", "metadata": {}, "source": [ - "## add_annotation\n", + "### Reference lines on facets\n", "\n", - "Add text annotations. See [annotations](https://plotly.com/python/text-and-annotations/)." + "`add_hline`/`add_vline` apply to all facets by default. Use `row=`/`col=` to target." ] }, { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "11", "metadata": {}, "outputs": [], "source": [ - "fig.add_annotation(x=4.5, y=82, text=\"Developed Nations\", showarrow=False, font_size=12)\n", + "fig = px.scatter(df_2007, x=\"gdpPercap\", y=\"lifeExp\", facet_col=\"continent\", facet_col_wrap=3)\n", + "fig.update_xaxes(type=\"log\")\n", + "\n", + "# Applies to ALL facets\n", + "fig.add_hline(y=70, line_dash=\"dash\", line_color=\"red\")\n", + "\n", + "# Specific facet only\n", + "fig.add_hline(y=50, line_dash=\"dot\", line_color=\"blue\", row=2, col=1)\n", + "\n", + "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", "fig" ] }, { "cell_type": "markdown", + "id": "12", + "metadata": {}, + "source": [ + "---\n", + "# Easy: Adding traces to faceted/animated figures\n", + "\n", + "Use `overlay_figures` to add traces. It handles facets and animation frames automatically." + ] + }, + { + "cell_type": "code", + "execution_count": null, "id": "13", "metadata": {}, + "outputs": [], "source": [ - "## Direct Access\n", + "# Animated scatter\n", + "fig = px.scatter(\n", + " df_countries,\n", + " x=\"gdpPercap\",\n", + " y=\"lifeExp\",\n", + " color=\"country\",\n", + " animation_frame=\"year\",\n", + " log_x=True,\n", + " range_y=[40, 85],\n", + ")\n", "\n", - "Access traces via `fig.data` and layout via `fig.layout`." + "# Create a figure with reference marker\n", + "ref = go.Figure(\n", + " go.Scatter(\n", + " x=[10000],\n", + " y=[75],\n", + " mode=\"markers\",\n", + " marker={\"size\": 20, \"symbol\": \"star\", \"color\": \"gold\"},\n", + " name=\"Target\",\n", + " )\n", + ")\n", + "\n", + "# Overlay - trace appears in all animation frames\n", + "combined = overlay_figures(fig, ref)\n", + "combined" ] }, { @@ -165,12 +237,37 @@ "metadata": {}, "outputs": [], "source": [ - "# Change first trace\n", - "fig.data[0].marker.symbol = \"diamond\"\n", + "# Faceted plot\n", + "fig = px.scatter(df_2007, x=\"gdpPercap\", y=\"lifeExp\", facet_col=\"continent\", facet_col_wrap=3)\n", + "fig.update_xaxes(type=\"log\")\n", "\n", - "# Change layout directly\n", - "fig.layout.legend.title.text = \"Region\"\n", - "fig" + "# Add reference to first facet (default axes x, y)\n", + "ref1 = go.Figure(\n", + " go.Scatter(\n", + " x=[5000],\n", + " y=[70],\n", + " mode=\"markers\",\n", + " marker={\"size\": 15, \"symbol\": \"star\", \"color\": \"gold\"},\n", + " name=\"Target 1\",\n", + " )\n", + ")\n", + "\n", + "# Add reference to second facet (axes x2, y2)\n", + "ref2 = go.Figure(\n", + " go.Scatter(\n", + " x=[20000],\n", + " y=[80],\n", + " mode=\"markers\",\n", + " marker={\"size\": 15, \"symbol\": \"star\", \"color\": \"red\"},\n", + " name=\"Target 2\",\n", + " xaxis=\"x2\",\n", + " yaxis=\"y2\", # specify target facet\n", + " )\n", + ")\n", + "\n", + "combined = overlay_figures(fig, ref1, ref2)\n", + "combined.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", + "combined" ] }, { @@ -178,9 +275,10 @@ "id": "15", "metadata": {}, "source": [ - "## Faceted Plots\n", + "---\n", + "# Annoying: Facet axis names\n", "\n", - "With facets, there are multiple axes: `xaxis`, `xaxis2`, etc." + "To target a specific facet with `add_shape`, `add_annotation`, or when adding traces via `overlay_figures`, you need to know the axis name (`x2`, `y3`, etc.)." ] }, { @@ -190,9 +288,12 @@ "metadata": {}, "outputs": [], "source": [ - "df_time = px.data.gapminder().query(\"country in ['United States', 'China', 'Germany']\")\n", - "fig2 = px.line(df_time, x=\"year\", y=\"gdpPercap\", color=\"country\", facet_col=\"country\")\n", - "fig2" + "fig = px.scatter(df_2007, x=\"gdpPercap\", y=\"lifeExp\", facet_col=\"continent\", facet_col_wrap=3)\n", + "\n", + "# Inspect axis names\n", + "layout_dict = fig.layout.to_plotly_json()\n", + "print(\"X axes:\", sorted([k for k in layout_dict if k.startswith(\"xaxis\")]))\n", + "print(\"Y axes:\", sorted([k for k in layout_dict if k.startswith(\"yaxis\")]))" ] }, { @@ -202,15 +303,9 @@ "metadata": {}, "outputs": [], "source": [ - "# Update all x-axes\n", - "fig2.update_xaxes(showgrid=False)\n", - "\n", - "# Update specific facet (col=2 is the middle one)\n", - "fig2.update_yaxes(type=\"log\", col=2)\n", - "\n", - "# Modify facet labels (stored as annotations)\n", - "fig2.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", - "fig2" + "# Check which trace uses which axis\n", + "for i, trace in enumerate(fig.data):\n", + " print(f\"Trace {i} ({trace.name}): xaxis={trace.xaxis or 'x'}, yaxis={trace.yaxis or 'y'}\")" ] }, { @@ -218,19 +313,18 @@ "id": "18", "metadata": {}, "source": [ - "### Facet Grid (rows and columns)" + "**Tip:** For simple cases, use `add_hline`/`add_vline` with `row=`/`col=` instead of `add_shape` - it handles axis mapping internally." ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "id": "19", "metadata": {}, - "outputs": [], "source": [ - "df_facet = px.data.tips()\n", - "fig3 = px.histogram(df_facet, x=\"total_bill\", facet_row=\"sex\", facet_col=\"time\")\n", - "fig3" + "---\n", + "# Annoying: Animation trace updates\n", + "\n", + "**This is the main pain point.** `update_traces()` does NOT update animation frames." ] }, { @@ -240,27 +334,22 @@ "metadata": {}, "outputs": [], "source": [ - "# Target specific cell in the grid\n", - "fig3.update_xaxes(range=[0, 40], row=1, col=1) # top-left only\n", - "\n", - "# Update entire row\n", - "fig3.update_yaxes(title_text=\"Count\", row=2)\n", - "\n", - "# Update entire column\n", - "fig3.update_xaxes(title_text=\"Bill ($)\", col=2)\n", - "\n", - "fig3.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", - "fig3" + "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", animation_frame=\"country\")\n", + "fig" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "id": "21", "metadata": {}, + "outputs": [], "source": [ - "### Direct Axis Access\n", + "# This only affects the INITIAL view, not the animation frames!\n", + "fig.update_traces(line_width=5, line_dash=\"dot\")\n", "\n", - "Access axes directly via `fig.layout.xaxis`, `fig.layout.xaxis2`, etc." + "print(f\"Base trace line_width: {fig.data[0].line.width}\")\n", + "print(f\"Frame 0 trace line_width: {fig.frames[0].data[0].line.width}\")" ] }, { @@ -270,37 +359,35 @@ "metadata": {}, "outputs": [], "source": [ - "fig4 = px.scatter(df, x=\"gdpPercap\", y=\"lifeExp\", facet_col=\"continent\", facet_col_wrap=3)\n", - "\n", - "# See which axes exist\n", - "print(\"X axes:\", [k for k in fig4.layout.to_plotly_json() if k.startswith(\"xaxis\")])\n", - "print(\"Y axes:\", [k for k in fig4.layout.to_plotly_json() if k.startswith(\"yaxis\")])" + "# When you play the animation, it reverts to the frame's original style\n", + "fig" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "id": "23", "metadata": {}, - "outputs": [], "source": [ - "# Modify specific axis directly\n", - "fig4.layout.xaxis.type = \"log\"\n", - "fig4.layout.xaxis2.type = \"log\"\n", - "fig4.layout.yaxis.title.text = \"Life Exp\"\n", - "\n", - "fig4.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", - "fig4" + "### Workaround: Update both base and frames" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "id": "24", "metadata": {}, + "outputs": [], "source": [ - "### Shapes on Specific Facets\n", + "def update_animation_traces(fig, **kwargs):\n", + " \"\"\"Update traces in both base figure and all animation frames.\n", "\n", - "Use `xref` and `yref` to target specific facet axes." + " Works with faceted figures too - updates all traces across all facets and frames.\n", + " \"\"\"\n", + " fig.update_traces(**kwargs)\n", + " for frame in fig.frames:\n", + " for trace in frame.data:\n", + " trace.update(**kwargs)\n", + " return fig" ] }, { @@ -310,63 +397,30 @@ "metadata": {}, "outputs": [], "source": [ - "fig5 = px.scatter(df, x=\"gdpPercap\", y=\"lifeExp\", facet_col=\"continent\", facet_col_wrap=3)\n", - "fig5.update_xaxes(type=\"log\")\n", - "\n", - "# Add rectangle to first facet (x, y)\n", - "fig5.add_shape(\n", - " type=\"rect\",\n", - " x0=1000,\n", - " x1=10000,\n", - " y0=70,\n", - " y1=85,\n", - " fillcolor=\"lightblue\",\n", - " opacity=0.3,\n", - " line_width=0,\n", - " xref=\"x\",\n", - " yref=\"y\",\n", - ")\n", + "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", animation_frame=\"country\")\n", "\n", - "# Add rectangle to second facet (x2, y2)\n", - "fig5.add_shape(\n", - " type=\"rect\",\n", - " x0=1000,\n", - " x1=10000,\n", - " y0=70,\n", - " y1=85,\n", - " fillcolor=\"lightgreen\",\n", - " opacity=0.3,\n", - " line_width=0,\n", - " xref=\"x2\",\n", - " yref=\"y2\",\n", - ")\n", + "update_animation_traces(fig, line_width=4, line_dash=\"dot\")\n", "\n", - "fig5.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", - "fig5" + "print(f\"Base trace line_width: {fig.data[0].line.width}\")\n", + "print(f\"Frame 0 trace line_width: {fig.frames[0].data[0].line.width}\")" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "id": "26", "metadata": {}, + "outputs": [], "source": [ - "### Axis Matching\n", - "\n", - "Control whether facets share the same axis range." + "fig" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "id": "27", "metadata": {}, - "outputs": [], "source": [ - "fig6 = px.histogram(df_facet, x=\"total_bill\", facet_col=\"day\")\n", - "\n", - "# Default: axes are matched (same range)\n", - "fig6.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", - "fig6" + "### Works with facets + animation" ] }, { @@ -376,9 +430,22 @@ "metadata": {}, "outputs": [], "source": [ - "# Make y-axes independent (each facet auto-scales)\n", - "fig6.update_yaxes(matches=None)\n", - "fig6" + "df_subset = df.query(\n", + " \"continent in ['Europe', 'Asia'] and country in ['Germany', 'France', 'China', 'Japan']\"\n", + ")\n", + "\n", + "fig = px.line(\n", + " df_subset,\n", + " x=\"year\",\n", + " y=\"gdpPercap\",\n", + " color=\"country\",\n", + " facet_col=\"continent\",\n", + " animation_frame=\"year\",\n", + ")\n", + "\n", + "update_animation_traces(fig, line_width=3)\n", + "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", + "fig" ] }, { @@ -386,62 +453,79 @@ "id": "29", "metadata": {}, "source": [ - "## Animated Plots\n", + "### What's affected\n", + "\n", + "Anything on **traces** needs the workaround for animations:\n", + "\n", + "| Property | Facets | Animation |\n", + "|----------|--------|-----------|\n", + "| `line_width` | ✅ `update_traces()` | ❌ needs `update_animation_traces()` |\n", + "| `line_dash` | ✅ `update_traces()` | ❌ needs `update_animation_traces()` |\n", + "| `line_color` | ✅ `update_traces()` | ❌ needs `update_animation_traces()` |\n", + "| `marker_size` | ✅ `update_traces()` | ❌ needs `update_animation_traces()` |\n", + "| `marker_symbol` | ✅ `update_traces()` | ❌ needs `update_animation_traces()` |\n", + "| `opacity` | ✅ `update_traces()` | ❌ needs `update_animation_traces()` |\n", "\n", - "Animations have frames, sliders, and play buttons." + "**Layout properties** (`update_layout`, `update_xaxes`, `update_yaxes`) work fine for animations." + ] + }, + { + "cell_type": "markdown", + "id": "30", + "metadata": {}, + "source": [ + "---\n", + "# Annoying: Animation speed\n", + "\n", + "The API to change animation speed is deeply nested." ] }, { "cell_type": "code", "execution_count": null, - "id": "30", + "id": "31", "metadata": {}, "outputs": [], "source": [ - "df_anim = px.data.gapminder()\n", - "fig7 = px.scatter(\n", - " df_anim,\n", + "fig = px.scatter(\n", + " df,\n", " x=\"gdpPercap\",\n", " y=\"lifeExp\",\n", - " size=\"pop\",\n", " color=\"continent\",\n", - " hover_name=\"country\",\n", + " size=\"pop\",\n", " animation_frame=\"year\",\n", - " animation_group=\"country\",\n", " log_x=True,\n", " range_y=[25, 90],\n", ")\n", - "fig7" + "\n", + "# This is... not intuitive\n", + "fig.layout.updatemenus[0].buttons[0].args[1][\"frame\"][\"duration\"] = 100 # faster\n", + "fig.layout.updatemenus[0].buttons[0].args[1][\"transition\"][\"duration\"] = 50\n", + "\n", + "fig" ] }, { "cell_type": "markdown", - "id": "31", + "id": "32", "metadata": {}, "source": [ - "### Animation Speed\n", - "\n", - "Modify frame duration and transition time." + "### Workaround: Helper function" ] }, { "cell_type": "code", "execution_count": null, - "id": "32", - "metadata": {}, - "outputs": [], - "source": [ - "fig7.layout.updatemenus[0].buttons[0].args[1][\"frame\"][\"duration\"] = 200 # ms per frame\n", - "fig7.layout.updatemenus[0].buttons[0].args[1][\"transition\"][\"duration\"] = 100 # transition time\n", - "fig7" - ] - }, - { - "cell_type": "markdown", "id": "33", "metadata": {}, + "outputs": [], "source": [ - "### Slider Styling" + "def set_animation_speed(fig, frame_duration=500, transition_duration=300):\n", + " \"\"\"Set animation speed in milliseconds.\"\"\"\n", + " if fig.layout.updatemenus:\n", + " fig.layout.updatemenus[0].buttons[0].args[1][\"frame\"][\"duration\"] = frame_duration\n", + " fig.layout.updatemenus[0].buttons[0].args[1][\"transition\"][\"duration\"] = transition_duration\n", + " return fig" ] }, { @@ -451,10 +535,18 @@ "metadata": {}, "outputs": [], "source": [ - "fig7.layout.sliders[0].currentvalue.prefix = \"Year: \"\n", - "fig7.layout.sliders[0].currentvalue.font.size = 16\n", - "fig7.layout.sliders[0].currentvalue.font.color = \"darkblue\"\n", - "fig7" + "fig = px.scatter(\n", + " df,\n", + " x=\"gdpPercap\",\n", + " y=\"lifeExp\",\n", + " color=\"continent\",\n", + " animation_frame=\"year\",\n", + " log_x=True,\n", + " range_y=[25, 90],\n", + ")\n", + "\n", + "set_animation_speed(fig, frame_duration=200, transition_duration=100)\n", + "fig" ] }, { @@ -462,7 +554,10 @@ "id": "35", "metadata": {}, "source": [ - "### Play/Pause Button Styling" + "---\n", + "# Annoying: Slider styling\n", + "\n", + "Verbose but straightforward." ] }, { @@ -472,9 +567,21 @@ "metadata": {}, "outputs": [], "source": [ - "fig7.layout.updatemenus[0].bgcolor = \"lightgray\"\n", - "fig7.layout.updatemenus[0].font.color = \"black\"\n", - "fig7" + "fig = px.scatter(\n", + " df,\n", + " x=\"gdpPercap\",\n", + " y=\"lifeExp\",\n", + " color=\"continent\",\n", + " animation_frame=\"year\",\n", + " log_x=True,\n", + " range_y=[25, 90],\n", + ")\n", + "\n", + "fig.layout.sliders[0].currentvalue.prefix = \"Year: \"\n", + "fig.layout.sliders[0].currentvalue.font.size = 16\n", + "fig.layout.sliders[0].pad.t = 50 # padding from top\n", + "\n", + "fig" ] }, { @@ -482,9 +589,7 @@ "id": "37", "metadata": {}, "source": [ - "### Modify Individual Frames\n", - "\n", - "Access frames via `fig.frames`." + "### Hide slider or play button" ] }, { @@ -494,40 +599,8 @@ "metadata": {}, "outputs": [], "source": [ - "print(f\"Number of frames: {len(fig7.frames)}\")\n", - "print(f\"Frame names: {[f.name for f in fig7.frames]}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "39", - "metadata": {}, - "outputs": [], - "source": [ - "# Change layout for a specific frame (e.g., add title showing year)\n", - "for frame in fig7.frames:\n", - " frame.layout = {\"title\": f\"Gapminder {frame.name}\"}\n", - "fig7" - ] - }, - { - "cell_type": "markdown", - "id": "40", - "metadata": {}, - "source": [ - "### Hide Slider or Buttons" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "41", - "metadata": {}, - "outputs": [], - "source": [ - "fig8 = px.scatter(\n", - " df_anim,\n", + "fig = px.scatter(\n", + " df,\n", " x=\"gdpPercap\",\n", " y=\"lifeExp\",\n", " color=\"continent\",\n", @@ -536,38 +609,67 @@ " range_y=[25, 90],\n", ")\n", "\n", - "# Hide the slider\n", - "fig8.layout.sliders = []\n", + "# Hide slider (keep play button)\n", + "fig.layout.sliders = []\n", "\n", - "# Or hide the play button instead:\n", - "# fig8.layout.updatemenus = []\n", + "# Or hide play button (keep slider):\n", + "# fig.layout.updatemenus = []\n", "\n", - "fig8" + "fig" ] }, { "cell_type": "markdown", - "id": "42", + "id": "39", "metadata": {}, "source": [ - "## Method Chaining\n", + "---\n", + "# Summary\n", "\n", - "All `update_*` methods return the figure." + "### Helper functions" ] }, { "cell_type": "code", "execution_count": null, - "id": "43", + "id": "40", "metadata": {}, "outputs": [], "source": [ - "(\n", - " px.scatter(df, x=\"gdpPercap\", y=\"lifeExp\", color=\"continent\")\n", - " .update_layout(title=\"Chained Example\")\n", - " .update_traces(marker_size=12)\n", - " .update_xaxes(type=\"log\")\n", - ")" + "def update_animation_traces(fig, **kwargs):\n", + " \"\"\"Update traces in both base figure and all animation frames.\"\"\"\n", + " fig.update_traces(**kwargs)\n", + " for frame in fig.frames:\n", + " for trace in frame.data:\n", + " trace.update(**kwargs)\n", + " return fig\n", + "\n", + "\n", + "def set_animation_speed(fig, frame_duration=500, transition_duration=300):\n", + " \"\"\"Set animation speed in milliseconds.\"\"\"\n", + " if fig.layout.updatemenus:\n", + " fig.layout.updatemenus[0].buttons[0].args[1][\"frame\"][\"duration\"] = frame_duration\n", + " fig.layout.updatemenus[0].buttons[0].args[1][\"transition\"][\"duration\"] = transition_duration\n", + " return fig" + ] + }, + { + "cell_type": "markdown", + "id": "41", + "metadata": {}, + "source": [ + "### Quick reference\n", + "\n", + "| Task | Facets | Animation | Solution |\n", + "|------|--------|-----------|----------|\n", + "| Update trace style | `update_traces()` | `update_animation_traces()` | Helper needed |\n", + "| Update axes | `update_xaxes()`/`update_yaxes()` | Same | ✅ Works |\n", + "| Update layout | `update_layout()` | Same | ✅ Works |\n", + "| Add reference line | `add_hline(row=, col=)` | `add_hline()` | ✅ Works |\n", + "| Add trace | `overlay_figures()` | `overlay_figures()` | ✅ Works |\n", + "| Add shape to specific facet | `add_shape(xref=\"x2\")` | Same | Need axis name |\n", + "| Change animation speed | N/A | `set_animation_speed()` | Helper needed |\n", + "| Facet labels | `for_each_annotation()` | Same | ✅ Works |" ] } ], From 130ea0352392a1f0c0a8d27d8d450a68399bc656 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Thu, 22 Jan 2026 09:12:06 +0100 Subject: [PATCH 07/10] Final Public API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit from xarray_plotly import ( xpx, # Main entry point - accessor with IDE completion overlay, # Combine figures on same axes add_secondary_y, # Dual y-axis plots update_animation_traces, # Update traces in animation frames config, # Configuration settings ) Changes Summary ┌─────────────────────────┬──────────────────────────────────┐ │ Removed │ Reason │ ├─────────────────────────┼──────────────────────────────────┤ │ overlay_figures │ Renamed to overlay │ ├─────────────────────────┼──────────────────────────────────┤ │ combine_figures │ Alias removed │ ├─────────────────────────┼──────────────────────────────────┤ │ SLOT_ORDERS │ Implementation detail │ ├─────────────────────────┼──────────────────────────────────┤ │ DataArrayPlotlyAccessor │ Users use xpx() │ ├─────────────────────────┼──────────────────────────────────┤ │ DatasetPlotlyAccessor │ Users use xpx() │ ├─────────────────────────┼──────────────────────────────────┤ │ auto │ Rarely needed │ ├─────────────────────────┼──────────────────────────────────┤ │ set_animation_speed │ Kept as local helper in notebook │ └─────────────────────────┴──────────────────────────────────┘ ┌─────────────────────────┬─────────────────────────┐ │ Added │ Reason │ ├─────────────────────────┼─────────────────────────┤ │ update_animation_traces │ Was hidden, now exposed │ └─────────────────────────┴─────────────────────────┘ --- docs/examples/combining.ipynb | 38 ++++++------ docs/examples/manipulation.ipynb | 101 ++++++++++++++++++------------- tests/test_figures.py | 92 +++++++++++----------------- xarray_plotly/__init__.py | 12 ++-- xarray_plotly/figures.py | 54 ++++++++++++++--- 5 files changed, 168 insertions(+), 129 deletions(-) diff --git a/docs/examples/combining.ipynb b/docs/examples/combining.ipynb index ce5543e..6755c20 100644 --- a/docs/examples/combining.ipynb +++ b/docs/examples/combining.ipynb @@ -9,7 +9,7 @@ "\n", "xarray-plotly provides helper functions to combine multiple figures:\n", "\n", - "- **`overlay_figures`**: Overlay traces on the same axes\n", + "- **`overlay`**: Overlay traces on the same axes\n", "- **`add_secondary_y`**: Plot with two independent y-axes" ] }, @@ -23,7 +23,7 @@ "import plotly.express as px\n", "import xarray as xr\n", "\n", - "from xarray_plotly import add_secondary_y, config, overlay_figures, xpx\n", + "from xarray_plotly import add_secondary_y, config, overlay, xpx\n", "\n", "config.notebook()" ] @@ -91,7 +91,7 @@ "id": "4", "metadata": {}, "source": [ - "## overlay_figures\n", + "## overlay\n", "\n", "Overlay multiple figures on the same axes. Useful for showing data with a trend line, moving average, or different visualizations of related data." ] @@ -126,7 +126,7 @@ "ma_fig = xpx(goog_ma).line()\n", "ma_fig.update_traces(line={\"color\": \"red\", \"width\": 3}, name=\"20-day MA\")\n", "\n", - "combined = overlay_figures(price_fig, ma_fig)\n", + "combined = overlay(price_fig, ma_fig)\n", "combined.update_layout(title=\"GOOG: Daily Price with Moving Average\")\n", "combined" ] @@ -158,7 +158,7 @@ "ma_fig = xpx(subset_ma).line()\n", "ma_fig.update_traces(line={\"width\": 3})\n", "\n", - "combined = overlay_figures(raw_fig, ma_fig)\n", + "combined = overlay(raw_fig, ma_fig)\n", "combined.update_layout(title=\"Tech Stocks: Raw Prices + Moving Averages\")\n", "combined" ] @@ -170,7 +170,7 @@ "source": [ "### With Facets\n", "\n", - "`overlay_figures` works with faceted figures as long as both have the same structure." + "`overlay` works with faceted figures as long as both have the same structure." ] }, { @@ -187,7 +187,7 @@ "ma_faceted = xpx(subset_ma).line(facet_col=\"company\")\n", "ma_faceted.update_traces(line={\"color\": \"red\", \"width\": 2})\n", "\n", - "combined = overlay_figures(raw_faceted, ma_faceted)\n", + "combined = overlay(raw_faceted, ma_faceted)\n", "combined.update_layout(title=\"Faceted: Price + Moving Average per Company\")\n", "combined" ] @@ -219,7 +219,7 @@ "smooth_anim = xpx(pop_smooth).line(animation_frame=\"country\")\n", "smooth_anim.update_traces(line={\"color\": \"red\", \"width\": 3})\n", "\n", - "combined = overlay_figures(pop_anim, smooth_anim)\n", + "combined = overlay(pop_anim, smooth_anim)\n", "combined.update_layout(title=\"Population: Raw + Smoothed (animated by country)\")\n", "combined" ] @@ -250,7 +250,7 @@ "avg_fig = xpx(global_avg).line()\n", "avg_fig.update_traces(line={\"color\": \"black\", \"width\": 2, \"dash\": \"dash\"}, name=\"Global Avg\")\n", "\n", - "combined = overlay_figures(pop_anim, avg_fig)\n", + "combined = overlay(pop_anim, avg_fig)\n", "combined.update_layout(title=\"Population by Country vs Global Average\")\n", "combined" ] @@ -322,8 +322,8 @@ "gdp_fig = xpx(us_gdp).line()\n", "gdp_fig.update_traces(name=\"GDP per Capita\", line={\"color\": \"red\"})\n", "\n", - "bad = overlay_figures(pop_fig, gdp_fig)\n", - "bad.update_layout(title=\"overlay_figures: GDP invisible (scale mismatch)\")\n", + "bad = overlay(pop_fig, gdp_fig)\n", + "bad.update_layout(title=\"overlay: GDP invisible (scale mismatch)\")\n", "bad" ] }, @@ -453,7 +453,7 @@ "id": "28", "metadata": {}, "source": [ - "### overlay_figures: Mismatched Facet Structure\n", + "### overlay: Mismatched Facet Structure\n", "\n", "Overlay cannot have subplots that don't exist in base." ] @@ -469,10 +469,10 @@ "base = xpx(stocks.sel(company=\"GOOG\")).line()\n", "\n", "# Overlay: has facets\n", - "overlay = xpx(stocks.sel(company=[\"GOOG\", \"AAPL\"])).line(facet_col=\"company\")\n", + "overlay_fig = xpx(stocks.sel(company=[\"GOOG\", \"AAPL\"])).line(facet_col=\"company\")\n", "\n", "try:\n", - " overlay_figures(base, overlay)\n", + " overlay(base, overlay_fig)\n", "except ValueError as e:\n", " print(f\"ValueError: {e}\")" ] @@ -482,7 +482,7 @@ "id": "30", "metadata": {}, "source": [ - "### overlay_figures: Animated Overlay on Static Base\n", + "### overlay: Animated Overlay on Static Base\n", "\n", "Cannot add an animated overlay to a static base figure." ] @@ -501,7 +501,7 @@ "animated_overlay = xpx(population).line(animation_frame=\"country\")\n", "\n", "try:\n", - " overlay_figures(static_base, animated_overlay)\n", + " overlay(static_base, animated_overlay)\n", "except ValueError as e:\n", " print(f\"ValueError: {e}\")" ] @@ -511,7 +511,7 @@ "id": "32", "metadata": {}, "source": [ - "### overlay_figures: Mismatched Animation Frames\n", + "### overlay: Mismatched Animation Frames\n", "\n", "Animation frame names must match exactly." ] @@ -528,7 +528,7 @@ "fig2 = xpx(population.sel(country=[\"Germany\", \"Brazil\"])).line(animation_frame=\"country\")\n", "\n", "try:\n", - " overlay_figures(fig1, fig2)\n", + " overlay(fig1, fig2)\n", "except ValueError as e:\n", " print(f\"ValueError: {e}\")" ] @@ -625,7 +625,7 @@ "\n", "| Function | Facets | Animation | Static + Animated |\n", "|----------|--------|-----------|-------------------|\n", - "| `overlay_figures` | Yes (must match) | Yes (frames must match) | Static overlay on animated base OK |\n", + "| `overlay` | Yes (must match) | Yes (frames must match) | Static overlay on animated base OK |\n", "| `add_secondary_y` | Yes (must match) | Yes (frames must match) | Static secondary on animated base OK |" ] } diff --git a/docs/examples/manipulation.ipynb b/docs/examples/manipulation.ipynb index 0024737..69bc372 100644 --- a/docs/examples/manipulation.ipynb +++ b/docs/examples/manipulation.ipynb @@ -21,7 +21,7 @@ "import plotly.graph_objects as go\n", "import plotly.io as pio\n", "\n", - "from xarray_plotly import overlay_figures\n", + "from xarray_plotly import overlay, update_animation_traces\n", "\n", "pio.renderers.default = \"notebook_connected\"\n", "\n", @@ -193,7 +193,7 @@ "---\n", "# Easy: Adding traces to faceted/animated figures\n", "\n", - "Use `overlay_figures` to add traces. It handles facets and animation frames automatically." + "Use `overlay` to add traces. It handles facets and animation frames automatically." ] }, { @@ -226,7 +226,7 @@ ")\n", "\n", "# Overlay - trace appears in all animation frames\n", - "combined = overlay_figures(fig, ref)\n", + "combined = overlay(fig, ref)\n", "combined" ] }, @@ -265,7 +265,7 @@ " )\n", ")\n", "\n", - "combined = overlay_figures(fig, ref1, ref2)\n", + "combined = overlay(fig, ref1, ref2)\n", "combined.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", "combined" ] @@ -278,7 +278,7 @@ "---\n", "# Annoying: Facet axis names\n", "\n", - "To target a specific facet with `add_shape`, `add_annotation`, or when adding traces via `overlay_figures`, you need to know the axis name (`x2`, `y3`, etc.)." + "To target a specific facet with `add_shape`, `add_annotation`, or when adding traces via `overlay`, you need to know the axis name (`x2`, `y3`, etc.)." ] }, { @@ -368,7 +368,9 @@ "id": "23", "metadata": {}, "source": [ - "### Workaround: Update both base and frames" + "### Solution: `update_animation_traces`\n", + "\n", + "xarray-plotly provides this helper to update both base traces and animation frames:" ] }, { @@ -378,16 +380,8 @@ "metadata": {}, "outputs": [], "source": [ - "def update_animation_traces(fig, **kwargs):\n", - " \"\"\"Update traces in both base figure and all animation frames.\n", - "\n", - " Works with faceted figures too - updates all traces across all facets and frames.\n", - " \"\"\"\n", - " fig.update_traces(**kwargs)\n", - " for frame in fig.frames:\n", - " for trace in frame.data:\n", - " trace.update(**kwargs)\n", - " return fig" + "# update_animation_traces is imported from xarray_plotly\n", + "# It updates traces in both base figure and all animation frames" ] }, { @@ -420,7 +414,9 @@ "id": "27", "metadata": {}, "source": [ - "### Works with facets + animation" + "### Selective updates with selector\n", + "\n", + "Use `selector` to target specific traces by name:" ] }, { @@ -429,6 +425,32 @@ "id": "28", "metadata": {}, "outputs": [], + "source": [ + "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", animation_frame=\"year\")\n", + "\n", + "# Update only one trace by name\n", + "update_animation_traces(fig, selector={\"name\": \"Germany\"}, line_width=5, line_dash=\"dot\")\n", + "\n", + "# Update multiple traces\n", + "update_animation_traces(fig, selector={\"name\": \"China\"}, line_color=\"red\", line_width=3)\n", + "\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "29", + "metadata": {}, + "source": [ + "### Works with facets + animation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30", + "metadata": {}, + "outputs": [], "source": [ "df_subset = df.query(\n", " \"continent in ['Europe', 'Asia'] and country in ['Germany', 'France', 'China', 'Japan']\"\n", @@ -450,7 +472,7 @@ }, { "cell_type": "markdown", - "id": "29", + "id": "31", "metadata": {}, "source": [ "### What's affected\n", @@ -471,7 +493,7 @@ }, { "cell_type": "markdown", - "id": "30", + "id": "32", "metadata": {}, "source": [ "---\n", @@ -483,7 +505,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "33", "metadata": {}, "outputs": [], "source": [ @@ -507,7 +529,7 @@ }, { "cell_type": "markdown", - "id": "32", + "id": "34", "metadata": {}, "source": [ "### Workaround: Helper function" @@ -516,7 +538,7 @@ { "cell_type": "code", "execution_count": null, - "id": "33", + "id": "35", "metadata": {}, "outputs": [], "source": [ @@ -531,7 +553,7 @@ { "cell_type": "code", "execution_count": null, - "id": "34", + "id": "36", "metadata": {}, "outputs": [], "source": [ @@ -551,7 +573,7 @@ }, { "cell_type": "markdown", - "id": "35", + "id": "37", "metadata": {}, "source": [ "---\n", @@ -563,7 +585,7 @@ { "cell_type": "code", "execution_count": null, - "id": "36", + "id": "38", "metadata": {}, "outputs": [], "source": [ @@ -586,7 +608,7 @@ }, { "cell_type": "markdown", - "id": "37", + "id": "39", "metadata": {}, "source": [ "### Hide slider or play button" @@ -595,7 +617,7 @@ { "cell_type": "code", "execution_count": null, - "id": "38", + "id": "40", "metadata": {}, "outputs": [], "source": [ @@ -620,31 +642,28 @@ }, { "cell_type": "markdown", - "id": "39", + "id": "41", "metadata": {}, "source": [ "---\n", "# Summary\n", "\n", - "### Helper functions" + "### Provided by xarray-plotly\n", + "\n", + "```python\n", + "from xarray_plotly import overlay, update_animation_traces\n", + "```\n", + "\n", + "### Local helper for animation speed" ] }, { "cell_type": "code", "execution_count": null, - "id": "40", + "id": "42", "metadata": {}, "outputs": [], "source": [ - "def update_animation_traces(fig, **kwargs):\n", - " \"\"\"Update traces in both base figure and all animation frames.\"\"\"\n", - " fig.update_traces(**kwargs)\n", - " for frame in fig.frames:\n", - " for trace in frame.data:\n", - " trace.update(**kwargs)\n", - " return fig\n", - "\n", - "\n", "def set_animation_speed(fig, frame_duration=500, transition_duration=300):\n", " \"\"\"Set animation speed in milliseconds.\"\"\"\n", " if fig.layout.updatemenus:\n", @@ -655,7 +674,7 @@ }, { "cell_type": "markdown", - "id": "41", + "id": "43", "metadata": {}, "source": [ "### Quick reference\n", @@ -666,7 +685,7 @@ "| Update axes | `update_xaxes()`/`update_yaxes()` | Same | ✅ Works |\n", "| Update layout | `update_layout()` | Same | ✅ Works |\n", "| Add reference line | `add_hline(row=, col=)` | `add_hline()` | ✅ Works |\n", - "| Add trace | `overlay_figures()` | `overlay_figures()` | ✅ Works |\n", + "| Add trace | `overlay()` | `overlay()` | ✅ Works |\n", "| Add shape to specific facet | `add_shape(xref=\"x2\")` | Same | Need axis name |\n", "| Change animation speed | N/A | `set_animation_speed()` | Helper needed |\n", "| Facet labels | `for_each_annotation()` | Same | ✅ Works |" diff --git a/tests/test_figures.py b/tests/test_figures.py index befba25..85ab92c 100644 --- a/tests/test_figures.py +++ b/tests/test_figures.py @@ -1,4 +1,4 @@ -"""Tests for the figures module (overlay_figures, add_secondary_y).""" +"""Tests for the figures module (overlay, add_secondary_y).""" from __future__ import annotations @@ -9,11 +9,11 @@ import pytest import xarray as xr -from xarray_plotly import add_secondary_y, combine_figures, overlay_figures, xpx +from xarray_plotly import add_secondary_y, overlay, xpx -class TestCombineFiguresBasic: - """Basic tests for combine_figures function.""" +class TestOverlayBasic: + """Basic tests for overlay function.""" @pytest.fixture(autouse=True) def setup(self) -> None: @@ -28,7 +28,7 @@ def setup(self) -> None: def test_no_overlays_returns_copy(self) -> None: """Test that no overlays returns a deep copy of base.""" base = xpx(self.da_2d).line() - result = combine_figures(base) + result = overlay(base) assert isinstance(result, go.Figure) assert len(result.data) == len(base.data) @@ -41,7 +41,7 @@ def test_combine_two_static_figures(self) -> None: area_fig = xpx(self.da_2d).area() line_fig = xpx(self.da_2d).line() - combined = combine_figures(area_fig, line_fig) + combined = overlay(area_fig, line_fig) assert isinstance(combined, go.Figure) expected_trace_count = len(area_fig.data) + len(line_fig.data) @@ -52,7 +52,7 @@ def test_preserves_base_layout(self) -> None: area_fig = xpx(self.da_2d).area(title="My Area Plot") line_fig = xpx(self.da_2d).line(title="My Line Plot") - combined = combine_figures(area_fig, line_fig) + combined = overlay(area_fig, line_fig) assert combined.layout.title.text == "My Area Plot" @@ -62,7 +62,7 @@ def test_multiple_overlays(self) -> None: line_fig = xpx(self.da_2d).line() scatter_fig = xpx(self.da_2d).scatter() - combined = combine_figures(area_fig, line_fig, scatter_fig) + combined = overlay(area_fig, line_fig, scatter_fig) expected_count = len(area_fig.data) + len(line_fig.data) + len(scatter_fig.data) assert len(combined.data) == expected_count @@ -76,7 +76,7 @@ def test_overlay_traces_added_in_order(self) -> None: fig1 = xpx(da_1).line() fig2 = xpx(da_2).line() - combined = combine_figures(fig1, fig2) + combined = overlay(fig1, fig2) # First trace should have y values from fig1 assert list(combined.data[0].y) == [1, 2, 3] @@ -84,8 +84,8 @@ def test_overlay_traces_added_in_order(self) -> None: assert list(combined.data[1].y) == [10, 20, 30] -class TestCombineFiguresFacets: - """Tests for combine_figures with faceted figures.""" +class TestOverlayFacets: + """Tests for overlay with faceted figures.""" @pytest.fixture(autouse=True) def setup(self) -> None: @@ -106,7 +106,7 @@ def test_matching_facet_structures(self) -> None: area_fig = xpx(self.da_3d).area(facet_col="facet") line_fig = xpx(self.da_3d).line(facet_col="facet") - combined = combine_figures(area_fig, line_fig) + combined = overlay(area_fig, line_fig) assert isinstance(combined, go.Figure) expected_count = len(area_fig.data) + len(line_fig.data) @@ -117,17 +117,17 @@ def test_overlay_with_extra_subplots_raises(self) -> None: # Base without facets base = xpx(self.da_3d.isel(facet=0)).line() # Overlay with facets - overlay = xpx(self.da_3d).line(facet_col="facet") + overlay_fig = xpx(self.da_3d).line(facet_col="facet") with pytest.raises(ValueError, match="subplots not present in base"): - combine_figures(base, overlay) + overlay(base, overlay_fig) def test_preserves_axis_references(self) -> None: """Test that traces preserve their xaxis/yaxis references.""" area_fig = xpx(self.da_3d).area(facet_col="facet") line_fig = xpx(self.da_3d).line(facet_col="facet") - combined = combine_figures(area_fig, line_fig) + combined = overlay(area_fig, line_fig) # Collect axis references from both original and combined original_axes = set() @@ -146,8 +146,8 @@ def test_preserves_axis_references(self) -> None: assert combined_axes == original_axes -class TestCombineFiguresAnimation: - """Tests for combine_figures with animated figures.""" +class TestOverlayAnimation: + """Tests for overlay with animated figures.""" @pytest.fixture(autouse=True) def setup(self) -> None: @@ -168,7 +168,7 @@ def test_matching_frames_merged(self) -> None: area_fig = xpx(self.da_3d).area(animation_frame="time") line_fig = xpx(self.da_3d).line(animation_frame="time") - combined = combine_figures(area_fig, line_fig) + combined = overlay(area_fig, line_fig) assert isinstance(combined, go.Figure) # Should have same number of frames @@ -183,7 +183,7 @@ def test_static_overlay_replicated_to_frames(self) -> None: animated = xpx(self.da_3d).area(animation_frame="time") static = xpx(self.da_3d.isel(time=0)).line() - combined = combine_figures(animated, static) + combined = overlay(animated, static) # Combined should have all frames from animated figure assert len(combined.frames) == len(animated.frames) @@ -200,7 +200,7 @@ def test_animated_overlay_on_static_base_raises(self) -> None: animated = xpx(self.da_3d).area(animation_frame="time") with pytest.raises(ValueError, match="base figure does not"): - combine_figures(static, animated) + overlay(static, animated) def test_mismatched_frame_names_raises(self) -> None: """Test that mismatched frame names raise ValueError.""" @@ -219,22 +219,22 @@ def test_mismatched_frame_names_raises(self) -> None: fig2 = xpx(da2).line(animation_frame="time") with pytest.raises(ValueError, match="frame names don't match"): - combine_figures(fig1, fig2) + overlay(fig1, fig2) def test_frame_names_preserved(self) -> None: """Test that frame names are preserved in combined figure.""" area_fig = xpx(self.da_3d).area(animation_frame="time") line_fig = xpx(self.da_3d).line(animation_frame="time") - combined = combine_figures(area_fig, line_fig) + combined = overlay(area_fig, line_fig) original_names = {frame.name for frame in area_fig.frames} combined_names = {frame.name for frame in combined.frames} assert original_names == combined_names -class TestCombineFiguresFacetsAndAnimation: - """Tests for combine_figures with both facets and animation.""" +class TestOverlayFacetsAndAnimation: + """Tests for overlay with both facets and animation.""" @pytest.fixture(autouse=True) def setup(self) -> None: @@ -256,7 +256,7 @@ def test_facets_and_animation_combined(self) -> None: area_fig = xpx(self.da_4d).area(facet_col="facet", animation_frame="time") line_fig = xpx(self.da_4d).line(facet_col="facet", animation_frame="time") - combined = combine_figures(area_fig, line_fig) + combined = overlay(area_fig, line_fig) assert isinstance(combined, go.Figure) # Check trace count @@ -270,7 +270,7 @@ def test_static_overlay_on_animated_faceted_base(self) -> None: animated = xpx(self.da_4d).area(facet_col="facet", animation_frame="time") static = xpx(self.da_4d.isel(time=0)).line(facet_col="facet") - combined = combine_figures(animated, static) + combined = overlay(animated, static) # Should have same frames as animated assert len(combined.frames) == len(animated.frames) @@ -280,8 +280,8 @@ def test_static_overlay_on_animated_faceted_base(self) -> None: assert len(frame.data) == expected -class TestCombineFiguresDeepCopy: - """Tests to ensure combine_figures creates deep copies.""" +class TestOverlayDeepCopy: + """Tests to ensure overlay creates deep copies.""" def test_base_not_modified(self) -> None: """Test that base figure is not modified.""" @@ -290,8 +290,8 @@ def test_base_not_modified(self) -> None: original_trace_count = len(base.data) original_title = copy.deepcopy(base.layout.title) - overlay = xpx(da).line() - _ = combine_figures(base, overlay) + overlay_fig = xpx(da).line() + _ = overlay(base, overlay_fig) # Base should be unchanged assert len(base.data) == original_trace_count @@ -301,21 +301,21 @@ def test_overlay_not_modified(self) -> None: """Test that overlay figure is not modified.""" da = xr.DataArray(np.random.rand(10, 3), dims=["x", "cat"]) base = xpx(da).area() - overlay = xpx(da).line() - original_trace_count = len(overlay.data) + overlay_fig = xpx(da).line() + original_trace_count = len(overlay_fig.data) - _ = combine_figures(base, overlay) + _ = overlay(base, overlay_fig) # Overlay should be unchanged - assert len(overlay.data) == original_trace_count + assert len(overlay_fig.data) == original_trace_count def test_combined_traces_independent(self) -> None: """Test that combined traces are independent of originals.""" da = xr.DataArray(np.random.rand(10, 3), dims=["x", "cat"]) base = xpx(da).area() - overlay = xpx(da).line() + overlay_fig = xpx(da).line() - combined = combine_figures(base, overlay) + combined = overlay(base, overlay_fig) # Modify combined figure combined.data[0].name = "modified" @@ -324,26 +324,6 @@ def test_combined_traces_independent(self) -> None: assert base.data[0].name != "modified" -class TestOverlayFiguresAlias: - """Test that overlay_figures and combine_figures are equivalent.""" - - def test_overlay_figures_is_combine_figures(self) -> None: - """Test that overlay_figures is the same function as combine_figures.""" - assert overlay_figures is combine_figures - - def test_overlay_figures_works(self) -> None: - """Test that overlay_figures works correctly.""" - da = xr.DataArray(np.random.rand(10, 3), dims=["x", "cat"]) - area_fig = xpx(da).area() - line_fig = xpx(da).line() - - combined = overlay_figures(area_fig, line_fig) - - assert isinstance(combined, go.Figure) - expected_count = len(area_fig.data) + len(line_fig.data) - assert len(combined.data) == expected_count - - class TestAddSecondaryYBasic: """Basic tests for add_secondary_y function.""" diff --git a/xarray_plotly/__init__.py b/xarray_plotly/__init__.py index cf254eb..29778e6 100644 --- a/xarray_plotly/__init__.py +++ b/xarray_plotly/__init__.py @@ -53,17 +53,19 @@ from xarray_plotly import config from xarray_plotly.accessor import DataArrayPlotlyAccessor, DatasetPlotlyAccessor from xarray_plotly.common import SLOT_ORDERS, auto -from xarray_plotly.figures import add_secondary_y, combine_figures, overlay_figures +from xarray_plotly.figures import ( + add_secondary_y, + overlay, + update_animation_traces, +) __all__ = [ "SLOT_ORDERS", - "DataArrayPlotlyAccessor", - "DatasetPlotlyAccessor", "add_secondary_y", "auto", - "combine_figures", "config", - "overlay_figures", + "overlay", + "update_animation_traces", "xpx", ] diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py index 82fccd7..9e478da 100644 --- a/xarray_plotly/figures.py +++ b/xarray_plotly/figures.py @@ -129,7 +129,7 @@ def _merge_frames( return merged_frames -def overlay_figures(base: go.Figure, *overlays: go.Figure) -> go.Figure: +def overlay(base: go.Figure, *overlays: go.Figure) -> go.Figure: """Overlay multiple Plotly figures on the same axes. Creates a new figure with the base figure's layout, sliders, and buttons, @@ -150,18 +150,18 @@ def overlay_figures(base: go.Figure, *overlays: go.Figure) -> go.Figure: Example: >>> import numpy as np >>> import xarray as xr - >>> from xarray_plotly import xpx, overlay_figures + >>> from xarray_plotly import xpx, overlay >>> >>> da = xr.DataArray(np.random.rand(10, 3), dims=["time", "cat"]) >>> area_fig = xpx(da).area() >>> line_fig = xpx(da).line() - >>> combined = overlay_figures(area_fig, line_fig) + >>> combined = overlay(area_fig, line_fig) >>> >>> # With animation >>> da3d = xr.DataArray(np.random.rand(10, 3, 4), dims=["x", "cat", "time"]) >>> area = xpx(da3d).area(animation_frame="time") >>> line = xpx(da3d).line(animation_frame="time") - >>> combined = overlay_figures(area, line) + >>> combined = overlay(area, line) """ import plotly.graph_objects as go @@ -196,10 +196,6 @@ def overlay_figures(base: go.Figure, *overlays: go.Figure) -> go.Figure: return combined -# Backwards compatibility alias -combine_figures = overlay_figures - - def _build_secondary_y_mapping(base_axes: set[tuple[str, str]]) -> dict[str, str]: """Build mapping from primary y-axes to secondary y-axes. @@ -392,3 +388,45 @@ def _merge_secondary_y_frames( ) return merged_frames + + +def update_animation_traces(fig: go.Figure, selector: dict | None = None, **kwargs) -> go.Figure: + """Update traces in both base figure and all animation frames. + + Plotly's `update_traces()` only updates the base figure, not animation frames. + This function updates both, ensuring trace styles persist during animation. + + Args: + fig: A Plotly figure, optionally with animation frames. + selector: Dict to match specific traces, e.g. ``{"name": "Germany"}``. + If None, updates all traces. + **kwargs: Trace properties to update, e.g. ``line_width=4``, ``line_dash="dot"``. + + Returns: + The modified figure (same object, mutated in place). + + Example: + >>> import plotly.express as px + >>> from xarray_plotly import update_animation_traces + >>> + >>> df = px.data.gapminder() + >>> fig = px.line(df, x="year", y="gdpPercap", color="country", animation_frame="continent") + >>> + >>> # Update all traces + >>> update_animation_traces(fig, line_width=3) + >>> + >>> # Update specific trace by name + >>> update_animation_traces(fig, selector={"name": "Germany"}, line_width=5, line_dash="dot") + """ + fig.update_traces(selector=selector, **kwargs) + + for frame in fig.frames: + for trace in frame.data: + if selector is None: + trace.update(**kwargs) + else: + # Check if trace matches all selector criteria + if all(getattr(trace, k, None) == v for k, v in selector.items()): + trace.update(**kwargs) + + return fig From cef8d06c03a5f2571311e63a63143aa51523dbb3 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Thu, 22 Jan 2026 09:14:07 +0100 Subject: [PATCH 08/10] Rename update_animation_traces to update_traces --- docs/examples/manipulation.ipynb | 34 ++++++++++++++++---------------- xarray_plotly/__init__.py | 4 ++-- xarray_plotly/figures.py | 8 ++++---- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/docs/examples/manipulation.ipynb b/docs/examples/manipulation.ipynb index 69bc372..6b7f160 100644 --- a/docs/examples/manipulation.ipynb +++ b/docs/examples/manipulation.ipynb @@ -21,7 +21,7 @@ "import plotly.graph_objects as go\n", "import plotly.io as pio\n", "\n", - "from xarray_plotly import overlay, update_animation_traces\n", + "from xarray_plotly import overlay, update_traces\n", "\n", "pio.renderers.default = \"notebook_connected\"\n", "\n", @@ -368,7 +368,7 @@ "id": "23", "metadata": {}, "source": [ - "### Solution: `update_animation_traces`\n", + "### Solution: `update_traces`\n", "\n", "xarray-plotly provides this helper to update both base traces and animation frames:" ] @@ -380,7 +380,7 @@ "metadata": {}, "outputs": [], "source": [ - "# update_animation_traces is imported from xarray_plotly\n", + "# update_traces is imported from xarray_plotly\n", "# It updates traces in both base figure and all animation frames" ] }, @@ -393,7 +393,7 @@ "source": [ "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", animation_frame=\"country\")\n", "\n", - "update_animation_traces(fig, line_width=4, line_dash=\"dot\")\n", + "update_traces(fig, line_width=4, line_dash=\"dot\")\n", "\n", "print(f\"Base trace line_width: {fig.data[0].line.width}\")\n", "print(f\"Frame 0 trace line_width: {fig.frames[0].data[0].line.width}\")" @@ -429,10 +429,10 @@ "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", animation_frame=\"year\")\n", "\n", "# Update only one trace by name\n", - "update_animation_traces(fig, selector={\"name\": \"Germany\"}, line_width=5, line_dash=\"dot\")\n", + "update_traces(fig, selector={\"name\": \"Germany\"}, line_width=5, line_dash=\"dot\")\n", "\n", "# Update multiple traces\n", - "update_animation_traces(fig, selector={\"name\": \"China\"}, line_color=\"red\", line_width=3)\n", + "update_traces(fig, selector={\"name\": \"China\"}, line_color=\"red\", line_width=3)\n", "\n", "fig" ] @@ -465,7 +465,7 @@ " animation_frame=\"year\",\n", ")\n", "\n", - "update_animation_traces(fig, line_width=3)\n", + "update_traces(fig, line_width=3)\n", "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", "fig" ] @@ -477,16 +477,16 @@ "source": [ "### What's affected\n", "\n", - "Anything on **traces** needs the workaround for animations:\n", + "Anything on **traces** needs the helper for animations:\n", "\n", "| Property | Facets | Animation |\n", "|----------|--------|-----------|\n", - "| `line_width` | ✅ `update_traces()` | ❌ needs `update_animation_traces()` |\n", - "| `line_dash` | ✅ `update_traces()` | ❌ needs `update_animation_traces()` |\n", - "| `line_color` | ✅ `update_traces()` | ❌ needs `update_animation_traces()` |\n", - "| `marker_size` | ✅ `update_traces()` | ❌ needs `update_animation_traces()` |\n", - "| `marker_symbol` | ✅ `update_traces()` | ❌ needs `update_animation_traces()` |\n", - "| `opacity` | ✅ `update_traces()` | ❌ needs `update_animation_traces()` |\n", + "| `line_width` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", + "| `line_dash` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", + "| `line_color` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", + "| `marker_size` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", + "| `marker_symbol` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", + "| `opacity` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", "\n", "**Layout properties** (`update_layout`, `update_xaxes`, `update_yaxes`) work fine for animations." ] @@ -651,7 +651,7 @@ "### Provided by xarray-plotly\n", "\n", "```python\n", - "from xarray_plotly import overlay, update_animation_traces\n", + "from xarray_plotly import overlay, update_traces\n", "```\n", "\n", "### Local helper for animation speed" @@ -681,13 +681,13 @@ "\n", "| Task | Facets | Animation | Solution |\n", "|------|--------|-----------|----------|\n", - "| Update trace style | `update_traces()` | `update_animation_traces()` | Helper needed |\n", + "| Update trace style | `fig.update_traces()` | `update_traces()` | xarray-plotly helper |\n", "| Update axes | `update_xaxes()`/`update_yaxes()` | Same | ✅ Works |\n", "| Update layout | `update_layout()` | Same | ✅ Works |\n", "| Add reference line | `add_hline(row=, col=)` | `add_hline()` | ✅ Works |\n", "| Add trace | `overlay()` | `overlay()` | ✅ Works |\n", "| Add shape to specific facet | `add_shape(xref=\"x2\")` | Same | Need axis name |\n", - "| Change animation speed | N/A | `set_animation_speed()` | Helper needed |\n", + "| Change animation speed | N/A | `set_animation_speed()` | Local helper |\n", "| Facet labels | `for_each_annotation()` | Same | ✅ Works |" ] } diff --git a/xarray_plotly/__init__.py b/xarray_plotly/__init__.py index 29778e6..d377b4c 100644 --- a/xarray_plotly/__init__.py +++ b/xarray_plotly/__init__.py @@ -56,7 +56,7 @@ from xarray_plotly.figures import ( add_secondary_y, overlay, - update_animation_traces, + update_traces, ) __all__ = [ @@ -65,7 +65,7 @@ "auto", "config", "overlay", - "update_animation_traces", + "update_traces", "xpx", ] diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py index 9e478da..2e1f658 100644 --- a/xarray_plotly/figures.py +++ b/xarray_plotly/figures.py @@ -390,7 +390,7 @@ def _merge_secondary_y_frames( return merged_frames -def update_animation_traces(fig: go.Figure, selector: dict | None = None, **kwargs) -> go.Figure: +def update_traces(fig: go.Figure, selector: dict | None = None, **kwargs) -> go.Figure: """Update traces in both base figure and all animation frames. Plotly's `update_traces()` only updates the base figure, not animation frames. @@ -407,16 +407,16 @@ def update_animation_traces(fig: go.Figure, selector: dict | None = None, **kwar Example: >>> import plotly.express as px - >>> from xarray_plotly import update_animation_traces + >>> from xarray_plotly import update_traces >>> >>> df = px.data.gapminder() >>> fig = px.line(df, x="year", y="gdpPercap", color="country", animation_frame="continent") >>> >>> # Update all traces - >>> update_animation_traces(fig, line_width=3) + >>> update_traces(fig, line_width=3) >>> >>> # Update specific trace by name - >>> update_animation_traces(fig, selector={"name": "Germany"}, line_width=5, line_dash="dot") + >>> update_traces(fig, selector={"name": "Germany"}, line_width=5, line_dash="dot") """ fig.update_traces(selector=selector, **kwargs) From 9b428e1c2951fdeb6b2a21ba92250d33ee3e96e8 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Thu, 22 Jan 2026 09:15:58 +0100 Subject: [PATCH 09/10] Make module callable --- xarray_plotly/__init__.py | 110 ++++++++++++++++++++++---------------- 1 file changed, 65 insertions(+), 45 deletions(-) diff --git a/xarray_plotly/__init__.py b/xarray_plotly/__init__.py index d377b4c..cafcddc 100644 --- a/xarray_plotly/__init__.py +++ b/xarray_plotly/__init__.py @@ -12,41 +12,43 @@ - **Customizable**: Returns Plotly Figure objects for further modification Usage: +Mos Recommended:: + + import xarray_plotly as xpx + + fig = xpx(da).line() # Create plots + combined = xpx.overlay(fig1, fig2) # Use helper functions + Accessor style:: import xarray_plotly fig = da.plotly.line() - fig = ds.plotly.line() # Dataset: all variables - - Function style (recommended for IDE completion):: - - from xarray_plotly import xpx - fig = xpx(da).line() - fig = xpx(ds).line() # Dataset: all variables Example: ```python import xarray as xr import numpy as np - from xarray_plotly import xpx + import xarray_plotly as xpx da = xr.DataArray( np.random.rand(10, 3, 2), dims=["time", "city", "scenario"], ) fig = xpx(da).line() # Auto: time->x, city->color, scenario->facet_col - fig = xpx(da).line(x="time", color="scenario") # Explicit - fig = xpx(da).line(color=None) # Skip slot - # Dataset: plot all variables (accessor or xpx) - ds = xr.Dataset({"temp": da, "precip": da}) - fig = xpx(ds).line() # "variable" dimension for color - fig = xpx(ds).line(facet_col="variable") # Facet by variable + # Combine figures + area = xpx(da).area() + line = xpx(da).line() + combined = xpx.overlay(area, line) ``` """ +from __future__ import annotations + +import sys +import types from importlib.metadata import version -from typing import overload +from typing import TYPE_CHECKING, overload from xarray import DataArray, Dataset, register_dataarray_accessor, register_dataset_accessor @@ -66,49 +68,67 @@ "config", "overlay", "update_traces", - "xpx", ] +__version__ = version("xarray_plotly") + +# Register the accessors +register_dataarray_accessor("plotly")(DataArrayPlotlyAccessor) +register_dataset_accessor("plotly")(DatasetPlotlyAccessor) + -@overload -def xpx(data: DataArray) -> DataArrayPlotlyAccessor: ... +class _CallableModule(types.ModuleType): + """A module that can be called as a function. + Enables the pattern:: -@overload -def xpx(data: Dataset) -> DatasetPlotlyAccessor: ... + import xarray_plotly as xpx + fig = xpx(da).line() # Call module as function + fig = xpx.overlay(a, b) # Access module attributes + """ + @overload + def __call__(self, data: DataArray) -> DataArrayPlotlyAccessor: ... -def xpx(data: DataArray | Dataset) -> DataArrayPlotlyAccessor | DatasetPlotlyAccessor: - """Get the plotly accessor for a DataArray or Dataset with full IDE code completion. + @overload + def __call__(self, data: Dataset) -> DatasetPlotlyAccessor: ... - This is an alternative to `da.plotly` / `ds.plotly` that provides proper type hints - and code completion in IDEs. + def __call__( + self, data: DataArray | Dataset + ) -> DataArrayPlotlyAccessor | DatasetPlotlyAccessor: + """Get the plotly accessor for a DataArray or Dataset. - Args: - data: The DataArray or Dataset to plot. + Args: + data: The DataArray or Dataset to plot. - Returns: - The accessor with plotting methods (line, bar, area, scatter, box, imshow). + Returns: + The accessor with plotting methods (line, bar, area, scatter, box, imshow, pie). - Example: - ```python - from xarray_plotly import xpx + Example: + ```python + import xarray_plotly as xpx - # DataArray - fig = xpx(da).line() # Full code completion works here + fig = xpx(da).line() + fig = xpx(ds).line(var="temperature") + ``` + """ + if isinstance(data, Dataset): + return DatasetPlotlyAccessor(data) + return DataArrayPlotlyAccessor(data) - # Dataset - fig = xpx(ds).line() # Plots all variables - fig = xpx(ds).line(var="temperature") # Single variable - ``` - """ - if isinstance(data, Dataset): - return DatasetPlotlyAccessor(data) - return DataArrayPlotlyAccessor(data) +# Make the module callable +sys.modules[__name__].__class__ = _CallableModule -__version__ = version("xarray_plotly") +# For type checking, expose the call signature +if TYPE_CHECKING: -# Register the accessors -register_dataarray_accessor("plotly")(DataArrayPlotlyAccessor) -register_dataset_accessor("plotly")(DatasetPlotlyAccessor) + @overload + def __call__(data: DataArray) -> DataArrayPlotlyAccessor: ... + + @overload + def __call__(data: Dataset) -> DatasetPlotlyAccessor: ... + + def __call__( + data: DataArray | Dataset, + ) -> DataArrayPlotlyAccessor | DatasetPlotlyAccessor: ... From 2f790a2d028294b5b4f9d69e3ed579d8f10ee040 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Thu, 22 Jan 2026 09:44:14 +0100 Subject: [PATCH 10/10] Revert --- xarray_plotly/__init__.py | 110 ++++++++++++++++---------------------- 1 file changed, 45 insertions(+), 65 deletions(-) diff --git a/xarray_plotly/__init__.py b/xarray_plotly/__init__.py index cafcddc..d377b4c 100644 --- a/xarray_plotly/__init__.py +++ b/xarray_plotly/__init__.py @@ -12,43 +12,41 @@ - **Customizable**: Returns Plotly Figure objects for further modification Usage: -Mos Recommended:: - - import xarray_plotly as xpx - - fig = xpx(da).line() # Create plots - combined = xpx.overlay(fig1, fig2) # Use helper functions - Accessor style:: import xarray_plotly fig = da.plotly.line() + fig = ds.plotly.line() # Dataset: all variables + + Function style (recommended for IDE completion):: + + from xarray_plotly import xpx + fig = xpx(da).line() + fig = xpx(ds).line() # Dataset: all variables Example: ```python import xarray as xr import numpy as np - import xarray_plotly as xpx + from xarray_plotly import xpx da = xr.DataArray( np.random.rand(10, 3, 2), dims=["time", "city", "scenario"], ) fig = xpx(da).line() # Auto: time->x, city->color, scenario->facet_col + fig = xpx(da).line(x="time", color="scenario") # Explicit + fig = xpx(da).line(color=None) # Skip slot - # Combine figures - area = xpx(da).area() - line = xpx(da).line() - combined = xpx.overlay(area, line) + # Dataset: plot all variables (accessor or xpx) + ds = xr.Dataset({"temp": da, "precip": da}) + fig = xpx(ds).line() # "variable" dimension for color + fig = xpx(ds).line(facet_col="variable") # Facet by variable ``` """ -from __future__ import annotations - -import sys -import types from importlib.metadata import version -from typing import TYPE_CHECKING, overload +from typing import overload from xarray import DataArray, Dataset, register_dataarray_accessor, register_dataset_accessor @@ -68,67 +66,49 @@ "config", "overlay", "update_traces", + "xpx", ] -__version__ = version("xarray_plotly") - -# Register the accessors -register_dataarray_accessor("plotly")(DataArrayPlotlyAccessor) -register_dataset_accessor("plotly")(DatasetPlotlyAccessor) - - -class _CallableModule(types.ModuleType): - """A module that can be called as a function. - Enables the pattern:: - - import xarray_plotly as xpx - fig = xpx(da).line() # Call module as function - fig = xpx.overlay(a, b) # Access module attributes - """ +@overload +def xpx(data: DataArray) -> DataArrayPlotlyAccessor: ... - @overload - def __call__(self, data: DataArray) -> DataArrayPlotlyAccessor: ... - @overload - def __call__(self, data: Dataset) -> DatasetPlotlyAccessor: ... +@overload +def xpx(data: Dataset) -> DatasetPlotlyAccessor: ... - def __call__( - self, data: DataArray | Dataset - ) -> DataArrayPlotlyAccessor | DatasetPlotlyAccessor: - """Get the plotly accessor for a DataArray or Dataset. - Args: - data: The DataArray or Dataset to plot. +def xpx(data: DataArray | Dataset) -> DataArrayPlotlyAccessor | DatasetPlotlyAccessor: + """Get the plotly accessor for a DataArray or Dataset with full IDE code completion. - Returns: - The accessor with plotting methods (line, bar, area, scatter, box, imshow, pie). + This is an alternative to `da.plotly` / `ds.plotly` that provides proper type hints + and code completion in IDEs. - Example: - ```python - import xarray_plotly as xpx + Args: + data: The DataArray or Dataset to plot. - fig = xpx(da).line() - fig = xpx(ds).line(var="temperature") - ``` - """ - if isinstance(data, Dataset): - return DatasetPlotlyAccessor(data) - return DataArrayPlotlyAccessor(data) + Returns: + The accessor with plotting methods (line, bar, area, scatter, box, imshow). + Example: + ```python + from xarray_plotly import xpx -# Make the module callable -sys.modules[__name__].__class__ = _CallableModule + # DataArray + fig = xpx(da).line() # Full code completion works here -# For type checking, expose the call signature -if TYPE_CHECKING: + # Dataset + fig = xpx(ds).line() # Plots all variables + fig = xpx(ds).line(var="temperature") # Single variable + ``` + """ + if isinstance(data, Dataset): + return DatasetPlotlyAccessor(data) + return DataArrayPlotlyAccessor(data) - @overload - def __call__(data: DataArray) -> DataArrayPlotlyAccessor: ... - @overload - def __call__(data: Dataset) -> DatasetPlotlyAccessor: ... +__version__ = version("xarray_plotly") - def __call__( - data: DataArray | Dataset, - ) -> DataArrayPlotlyAccessor | DatasetPlotlyAccessor: ... +# Register the accessors +register_dataarray_accessor("plotly")(DataArrayPlotlyAccessor) +register_dataset_accessor("plotly")(DatasetPlotlyAccessor)