diff --git a/docs/examples/combining.ipynb b/docs/examples/combining.ipynb new file mode 100644 index 0000000..6755c20 --- /dev/null +++ b/docs/examples/combining.ipynb @@ -0,0 +1,646 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Combining Figures\n", + "\n", + "xarray-plotly provides helper functions to combine multiple figures:\n", + "\n", + "- **`overlay`**: Overlay traces on the same axes\n", + "- **`add_secondary_y`**: Plot with two independent y-axes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import plotly.express as px\n", + "import xarray as xr\n", + "\n", + "from xarray_plotly import add_secondary_y, config, overlay, xpx\n", + "\n", + "config.notebook()" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## Load Sample Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "# Stock prices\n", + "df_stocks = px.data.stocks().set_index(\"date\")\n", + "df_stocks.index = df_stocks.index.astype(\"datetime64[ns]\")\n", + "\n", + "stocks = xr.DataArray(\n", + " df_stocks.values,\n", + " dims=[\"date\", \"company\"],\n", + " coords={\"date\": df_stocks.index, \"company\": df_stocks.columns.tolist()},\n", + " name=\"price\",\n", + ")\n", + "\n", + "# Gapminder data (subset: a few countries)\n", + "df_gap = px.data.gapminder()\n", + "countries = [\"United States\", \"China\", \"Germany\", \"Brazil\"]\n", + "df_gap = df_gap[df_gap[\"country\"].isin(countries)]\n", + "\n", + "# Convert to xarray\n", + "gap_pop = df_gap.pivot(index=\"year\", columns=\"country\", values=\"pop\")\n", + "gap_gdp = df_gap.pivot(index=\"year\", columns=\"country\", values=\"gdpPercap\")\n", + "gap_life = df_gap.pivot(index=\"year\", columns=\"country\", values=\"lifeExp\")\n", + "\n", + "population = xr.DataArray(\n", + " gap_pop.values,\n", + " dims=[\"year\", \"country\"],\n", + " coords={\"year\": gap_pop.index.values, \"country\": gap_pop.columns.tolist()},\n", + " name=\"Population\",\n", + ")\n", + "\n", + "gdp_per_capita = xr.DataArray(\n", + " gap_gdp.values,\n", + " dims=[\"year\", \"country\"],\n", + " coords={\"year\": gap_gdp.index.values, \"country\": gap_gdp.columns.tolist()},\n", + " name=\"GDP per Capita\",\n", + ")\n", + "\n", + "life_expectancy = xr.DataArray(\n", + " gap_life.values,\n", + " dims=[\"year\", \"country\"],\n", + " coords={\"year\": gap_life.index.values, \"country\": gap_life.columns.tolist()},\n", + " name=\"Life Expectancy\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "## overlay\n", + "\n", + "Overlay multiple figures on the same axes. Useful for showing data with a trend line, moving average, or different visualizations of related data." + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "### Stock Price with Moving Average" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "# Select one company\n", + "goog = stocks.sel(company=\"GOOG\")\n", + "\n", + "# Calculate 20-day moving average\n", + "goog_ma = goog.rolling(date=20, center=True).mean()\n", + "goog_ma.name = \"20-day MA\"\n", + "\n", + "# Raw prices as scatter\n", + "price_fig = xpx(goog).scatter()\n", + "price_fig.update_traces(marker={\"size\": 4, \"opacity\": 0.5}, name=\"Daily Price\")\n", + "\n", + "# Moving average as line\n", + "ma_fig = xpx(goog_ma).line()\n", + "ma_fig.update_traces(line={\"color\": \"red\", \"width\": 3}, name=\"20-day MA\")\n", + "\n", + "combined = overlay(price_fig, ma_fig)\n", + "combined.update_layout(title=\"GOOG: Daily Price with Moving Average\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "### Multiple Companies with Moving Averages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "# Select a few companies\n", + "subset = stocks.sel(company=[\"GOOG\", \"AAPL\", \"MSFT\"])\n", + "subset_ma = subset.rolling(date=20, center=True).mean()\n", + "\n", + "# Raw as scatter (faded)\n", + "raw_fig = xpx(subset).scatter()\n", + "raw_fig.update_traces(marker={\"size\": 3, \"opacity\": 0.3})\n", + "\n", + "# MA as lines (bold)\n", + "ma_fig = xpx(subset_ma).line()\n", + "ma_fig.update_traces(line={\"width\": 3})\n", + "\n", + "combined = overlay(raw_fig, ma_fig)\n", + "combined.update_layout(title=\"Tech Stocks: Raw Prices + Moving Averages\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "### With Facets\n", + "\n", + "`overlay` works with faceted figures as long as both have the same structure." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "# Faceted by company\n", + "raw_faceted = xpx(subset).scatter(facet_col=\"company\")\n", + "raw_faceted.update_traces(marker={\"size\": 3, \"opacity\": 0.4})\n", + "\n", + "ma_faceted = xpx(subset_ma).line(facet_col=\"company\")\n", + "ma_faceted.update_traces(line={\"color\": \"red\", \"width\": 2})\n", + "\n", + "combined = overlay(raw_faceted, ma_faceted)\n", + "combined.update_layout(title=\"Faceted: Price + Moving Average per Company\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "11", + "metadata": {}, + "source": [ + "### With Animation\n", + "\n", + "Overlay animated figures - frames are merged correctly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "# Animate through countries, showing population bar + GDP line\n", + "# Both use the same animation dimension\n", + "pop_anim = xpx(population).bar(animation_frame=\"country\")\n", + "pop_anim.update_traces(marker={\"opacity\": 0.6})\n", + "\n", + "# Create a \"target\" line (e.g., some reference value)\n", + "pop_smooth = population.rolling(year=3, center=True).mean()\n", + "smooth_anim = xpx(pop_smooth).line(animation_frame=\"country\")\n", + "smooth_anim.update_traces(line={\"color\": \"red\", \"width\": 3})\n", + "\n", + "combined = overlay(pop_anim, smooth_anim)\n", + "combined.update_layout(title=\"Population: Raw + Smoothed (animated by country)\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "13", + "metadata": {}, + "source": [ + "### Static Overlay on Animated Base\n", + "\n", + "A static figure can be overlaid on an animated one - the static traces appear in all frames." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "# Animated population\n", + "pop_anim = xpx(population).bar(animation_frame=\"country\")\n", + "pop_anim.update_traces(marker={\"opacity\": 0.7})\n", + "\n", + "# Static reference line (global average across all countries)\n", + "global_avg = population.mean(dim=\"country\")\n", + "avg_fig = xpx(global_avg).line()\n", + "avg_fig.update_traces(line={\"color\": \"black\", \"width\": 2, \"dash\": \"dash\"}, name=\"Global Avg\")\n", + "\n", + "combined = overlay(pop_anim, avg_fig)\n", + "combined.update_layout(title=\"Population by Country vs Global Average\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "## add_secondary_y\n", + "\n", + "Plot two variables with different scales using independent y-axes. Essential when values have different magnitudes." + ] + }, + { + "cell_type": "markdown", + "id": "16", + "metadata": {}, + "source": [ + "### Population vs GDP per Capita" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "# Select one country\n", + "us_pop = population.sel(country=\"United States\")\n", + "us_gdp = gdp_per_capita.sel(country=\"United States\")\n", + "\n", + "pop_fig = xpx(us_pop).bar()\n", + "pop_fig.update_traces(marker={\"color\": \"steelblue\", \"opacity\": 0.7}, name=\"Population\")\n", + "\n", + "gdp_fig = xpx(us_gdp).line()\n", + "gdp_fig.update_traces(line={\"color\": \"red\", \"width\": 3}, name=\"GDP per Capita\")\n", + "\n", + "combined = add_secondary_y(pop_fig, gdp_fig, secondary_y_title=\"GDP per Capita ($)\")\n", + "combined.update_layout(\n", + " title=\"United States: Population vs GDP per Capita\",\n", + " yaxis_title=\"Population\",\n", + ")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "18", + "metadata": {}, + "source": [ + "### Why Secondary Y-Axis Matters\n", + "\n", + "Without it, one variable dominates due to scale mismatch:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19", + "metadata": {}, + "outputs": [], + "source": [ + "# Same data on single y-axis - GDP looks flat because population is ~1e8, GDP is ~1e4\n", + "pop_fig = xpx(us_pop).line()\n", + "pop_fig.update_traces(name=\"Population\", line={\"color\": \"blue\"})\n", + "\n", + "gdp_fig = xpx(us_gdp).line()\n", + "gdp_fig.update_traces(name=\"GDP per Capita\", line={\"color\": \"red\"})\n", + "\n", + "bad = overlay(pop_fig, gdp_fig)\n", + "bad.update_layout(title=\"overlay: GDP invisible (scale mismatch)\")\n", + "bad" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [], + "source": [ + "# With add_secondary_y - both variables visible\n", + "pop_fig = xpx(us_pop).line()\n", + "pop_fig.update_traces(name=\"Population\", line={\"color\": \"blue\", \"width\": 2})\n", + "\n", + "gdp_fig = xpx(us_gdp).line()\n", + "gdp_fig.update_traces(name=\"GDP per Capita\", line={\"color\": \"red\", \"width\": 2})\n", + "\n", + "good = add_secondary_y(pop_fig, gdp_fig, secondary_y_title=\"GDP per Capita ($)\")\n", + "good.update_layout(title=\"add_secondary_y: Both variables clearly visible\")\n", + "good" + ] + }, + { + "cell_type": "markdown", + "id": "21", + "metadata": {}, + "source": [ + "### With Animation\n", + "\n", + "`add_secondary_y` supports animated figures with matching frames." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "# Animate through countries\n", + "pop_anim = xpx(population).bar(animation_frame=\"country\")\n", + "pop_anim.update_traces(marker={\"color\": \"steelblue\", \"opacity\": 0.7})\n", + "\n", + "gdp_anim = xpx(gdp_per_capita).line(animation_frame=\"country\")\n", + "gdp_anim.update_traces(line={\"color\": \"red\", \"width\": 3})\n", + "\n", + "combined = add_secondary_y(pop_anim, gdp_anim, secondary_y_title=\"GDP per Capita ($)\")\n", + "combined.update_layout(title=\"Population vs GDP (animated by country)\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "23", + "metadata": {}, + "source": [ + "### Static Secondary on Animated Base\n", + "\n", + "A static secondary figure is replicated to all animation frames." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [], + "source": [ + "# Animated population\n", + "pop_anim = xpx(population).bar(animation_frame=\"country\")\n", + "pop_anim.update_traces(marker={\"color\": \"steelblue\", \"opacity\": 0.7})\n", + "\n", + "# Static GDP reference (US only, shown in all frames)\n", + "us_gdp_static = xpx(us_gdp).line()\n", + "us_gdp_static.update_traces(\n", + " line={\"color\": \"red\", \"width\": 2, \"dash\": \"dash\"}, name=\"US GDP (reference)\"\n", + ")\n", + "\n", + "combined = add_secondary_y(pop_anim, us_gdp_static, secondary_y_title=\"GDP per Capita ($)\")\n", + "combined.update_layout(title=\"Population (animated) vs US GDP (static reference)\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "25", + "metadata": {}, + "source": [ + "### With Facets\n", + "\n", + "`add_secondary_y` works with faceted figures when both have the same facet structure." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", + "metadata": {}, + "outputs": [], + "source": [ + "# Faceted by country - both figures must have same facet structure\n", + "pop_faceted = xpx(population).bar(facet_col=\"country\")\n", + "pop_faceted.update_traces(marker={\"color\": \"steelblue\", \"opacity\": 0.7})\n", + "\n", + "gdp_faceted = xpx(gdp_per_capita).line(facet_col=\"country\")\n", + "gdp_faceted.update_traces(line={\"color\": \"red\", \"width\": 3})\n", + "\n", + "combined = add_secondary_y(pop_faceted, gdp_faceted, secondary_y_title=\"GDP per Capita ($)\")\n", + "combined.update_layout(title=\"Population vs GDP per Capita (faceted by country)\")\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "27", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Limitations (with examples)\n", + "\n", + "Both functions validate inputs and raise clear errors when constraints are violated." + ] + }, + { + "cell_type": "markdown", + "id": "28", + "metadata": {}, + "source": [ + "### overlay: Mismatched Facet Structure\n", + "\n", + "Overlay cannot have subplots that don't exist in base." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29", + "metadata": {}, + "outputs": [], + "source": [ + "# Base: no facets\n", + "base = xpx(stocks.sel(company=\"GOOG\")).line()\n", + "\n", + "# Overlay: has facets\n", + "overlay_fig = xpx(stocks.sel(company=[\"GOOG\", \"AAPL\"])).line(facet_col=\"company\")\n", + "\n", + "try:\n", + " overlay(base, overlay_fig)\n", + "except ValueError as e:\n", + " print(f\"ValueError: {e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "30", + "metadata": {}, + "source": [ + "### overlay: Animated Overlay on Static Base\n", + "\n", + "Cannot add an animated overlay to a static base figure." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31", + "metadata": {}, + "outputs": [], + "source": [ + "# Base: static\n", + "static_base = xpx(population.sel(country=\"United States\")).line()\n", + "\n", + "# Overlay: animated\n", + "animated_overlay = xpx(population).line(animation_frame=\"country\")\n", + "\n", + "try:\n", + " overlay(static_base, animated_overlay)\n", + "except ValueError as e:\n", + " print(f\"ValueError: {e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "32", + "metadata": {}, + "source": [ + "### overlay: Mismatched Animation Frames\n", + "\n", + "Animation frame names must match exactly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33", + "metadata": {}, + "outputs": [], + "source": [ + "# Different countries selected = different frame names\n", + "fig1 = xpx(population.sel(country=[\"United States\", \"China\"])).line(animation_frame=\"country\")\n", + "fig2 = xpx(population.sel(country=[\"Germany\", \"Brazil\"])).line(animation_frame=\"country\")\n", + "\n", + "try:\n", + " overlay(fig1, fig2)\n", + "except ValueError as e:\n", + " print(f\"ValueError: {e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "34", + "metadata": {}, + "source": [ + "### add_secondary_y: Mismatched Facet Structure\n", + "\n", + "Both figures must have the same facet structure." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35", + "metadata": {}, + "outputs": [], + "source": [ + "# Base with facets\n", + "pop_faceted = xpx(population).bar(facet_col=\"country\")\n", + "\n", + "# Secondary without facets (different structure)\n", + "gdp_single = xpx(gdp_per_capita.sel(country=\"United States\")).line()\n", + "\n", + "try:\n", + " add_secondary_y(pop_faceted, gdp_single)\n", + "except ValueError as e:\n", + " print(f\"ValueError: {e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "36", + "metadata": {}, + "source": [ + "### add_secondary_y: Animated Secondary on Static Base\n", + "\n", + "Cannot add animated secondary to static base." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37", + "metadata": {}, + "outputs": [], + "source": [ + "# Static base\n", + "static_pop = xpx(population.sel(country=\"United States\")).bar()\n", + "\n", + "# Animated secondary\n", + "animated_gdp = xpx(gdp_per_capita).line(animation_frame=\"country\")\n", + "\n", + "try:\n", + " add_secondary_y(static_pop, animated_gdp)\n", + "except ValueError as e:\n", + " print(f\"ValueError: {e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "38", + "metadata": {}, + "source": [ + "### add_secondary_y: Mismatched Animation Frames" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39", + "metadata": {}, + "outputs": [], + "source": [ + "# Different countries = different frames\n", + "pop_some = xpx(population.sel(country=[\"United States\", \"China\"])).bar(animation_frame=\"country\")\n", + "gdp_other = xpx(gdp_per_capita.sel(country=[\"Germany\", \"Brazil\"])).line(animation_frame=\"country\")\n", + "\n", + "try:\n", + " add_secondary_y(pop_some, gdp_other)\n", + "except ValueError as e:\n", + " print(f\"ValueError: {e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "40", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "| Function | Facets | Animation | Static + Animated |\n", + "|----------|--------|-----------|-------------------|\n", + "| `overlay` | Yes (must match) | Yes (frames must match) | Static overlay on animated base OK |\n", + "| `add_secondary_y` | Yes (must match) | Yes (frames must match) | Static secondary on animated base OK |" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/manipulation.ipynb b/docs/examples/manipulation.ipynb new file mode 100644 index 0000000..6b7f160 --- /dev/null +++ b/docs/examples/manipulation.ipynb @@ -0,0 +1,708 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Figure Manipulation\n", + "\n", + "What's easy, what's annoying, and how to work around it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import plotly.express as px\n", + "import plotly.graph_objects as go\n", + "import plotly.io as pio\n", + "\n", + "from xarray_plotly import overlay, update_traces\n", + "\n", + "pio.renderers.default = \"notebook_connected\"\n", + "\n", + "# Sample data\n", + "df = px.data.gapminder()\n", + "df_2007 = df.query(\"year == 2007\")\n", + "df_countries = df.query(\"country in ['United States', 'China', 'Germany', 'Brazil']\")" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "---\n", + "# Easy: Single Plots\n", + "\n", + "All standard manipulation methods work as expected." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.scatter(df_2007, x=\"gdpPercap\", y=\"lifeExp\", color=\"continent\", size=\"pop\")\n", + "fig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "# Layout\n", + "fig.update_layout(title=\"GDP vs Life Expectancy\", template=\"plotly_white\")\n", + "\n", + "# All traces\n", + "fig.update_traces(marker_opacity=0.7)\n", + "\n", + "# Specific traces\n", + "fig.update_traces(marker_line_width=2, selector={\"name\": \"Europe\"})\n", + "\n", + "# Axes\n", + "fig.update_xaxes(type=\"log\", title=\"GDP per Capita\")\n", + "fig.update_yaxes(range=[40, 90])\n", + "\n", + "# Annotations and shapes\n", + "fig.add_hline(y=df_2007[\"lifeExp\"].mean(), line_dash=\"dash\", line_color=\"gray\")\n", + "\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "---\n", + "# Easy: Faceted Plots\n", + "\n", + "`update_traces`, `update_xaxes`, `update_yaxes` all work across facets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", facet_col=\"country\")\n", + "fig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "# Update ALL traces across all facets\n", + "fig.update_traces(line_width=3)\n", + "\n", + "# Update ALL x-axes\n", + "fig.update_xaxes(showgrid=False)\n", + "\n", + "# Update ALL y-axes\n", + "fig.update_yaxes(showgrid=False, type=\"log\")\n", + "\n", + "# Clean up facet labels\n", + "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", + "\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "### Targeting specific facets\n", + "\n", + "Use `row=` and `col=` (1-indexed) to target specific facets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.histogram(px.data.tips(), x=\"total_bill\", facet_row=\"sex\", facet_col=\"time\")\n", + "\n", + "# Target specific cell\n", + "fig.update_yaxes(title_text=\"Frequency\", row=1, col=1)\n", + "\n", + "# Target entire column\n", + "fig.update_xaxes(title_text=\"Bill ($)\", col=2)\n", + "\n", + "# Target entire row\n", + "fig.update_traces(marker_color=\"orange\", row=2)\n", + "\n", + "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "### Reference lines on facets\n", + "\n", + "`add_hline`/`add_vline` apply to all facets by default. Use `row=`/`col=` to target." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.scatter(df_2007, x=\"gdpPercap\", y=\"lifeExp\", facet_col=\"continent\", facet_col_wrap=3)\n", + "fig.update_xaxes(type=\"log\")\n", + "\n", + "# Applies to ALL facets\n", + "fig.add_hline(y=70, line_dash=\"dash\", line_color=\"red\")\n", + "\n", + "# Specific facet only\n", + "fig.add_hline(y=50, line_dash=\"dot\", line_color=\"blue\", row=2, col=1)\n", + "\n", + "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "12", + "metadata": {}, + "source": [ + "---\n", + "# Easy: Adding traces to faceted/animated figures\n", + "\n", + "Use `overlay` to add traces. It handles facets and animation frames automatically." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "# Animated scatter\n", + "fig = px.scatter(\n", + " df_countries,\n", + " x=\"gdpPercap\",\n", + " y=\"lifeExp\",\n", + " color=\"country\",\n", + " animation_frame=\"year\",\n", + " log_x=True,\n", + " range_y=[40, 85],\n", + ")\n", + "\n", + "# Create a figure with reference marker\n", + "ref = go.Figure(\n", + " go.Scatter(\n", + " x=[10000],\n", + " y=[75],\n", + " mode=\"markers\",\n", + " marker={\"size\": 20, \"symbol\": \"star\", \"color\": \"gold\"},\n", + " name=\"Target\",\n", + " )\n", + ")\n", + "\n", + "# Overlay - trace appears in all animation frames\n", + "combined = overlay(fig, ref)\n", + "combined" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "# Faceted plot\n", + "fig = px.scatter(df_2007, x=\"gdpPercap\", y=\"lifeExp\", facet_col=\"continent\", facet_col_wrap=3)\n", + "fig.update_xaxes(type=\"log\")\n", + "\n", + "# Add reference to first facet (default axes x, y)\n", + "ref1 = go.Figure(\n", + " go.Scatter(\n", + " x=[5000],\n", + " y=[70],\n", + " mode=\"markers\",\n", + " marker={\"size\": 15, \"symbol\": \"star\", \"color\": \"gold\"},\n", + " name=\"Target 1\",\n", + " )\n", + ")\n", + "\n", + "# Add reference to second facet (axes x2, y2)\n", + "ref2 = go.Figure(\n", + " go.Scatter(\n", + " x=[20000],\n", + " y=[80],\n", + " mode=\"markers\",\n", + " marker={\"size\": 15, \"symbol\": \"star\", \"color\": \"red\"},\n", + " name=\"Target 2\",\n", + " xaxis=\"x2\",\n", + " yaxis=\"y2\", # specify target facet\n", + " )\n", + ")\n", + "\n", + "combined = overlay(fig, ref1, ref2)\n", + "combined.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", + "combined" + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "---\n", + "# Annoying: Facet axis names\n", + "\n", + "To target a specific facet with `add_shape`, `add_annotation`, or when adding traces via `overlay`, you need to know the axis name (`x2`, `y3`, etc.)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.scatter(df_2007, x=\"gdpPercap\", y=\"lifeExp\", facet_col=\"continent\", facet_col_wrap=3)\n", + "\n", + "# Inspect axis names\n", + "layout_dict = fig.layout.to_plotly_json()\n", + "print(\"X axes:\", sorted([k for k in layout_dict if k.startswith(\"xaxis\")]))\n", + "print(\"Y axes:\", sorted([k for k in layout_dict if k.startswith(\"yaxis\")]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "# Check which trace uses which axis\n", + "for i, trace in enumerate(fig.data):\n", + " print(f\"Trace {i} ({trace.name}): xaxis={trace.xaxis or 'x'}, yaxis={trace.yaxis or 'y'}\")" + ] + }, + { + "cell_type": "markdown", + "id": "18", + "metadata": {}, + "source": [ + "**Tip:** For simple cases, use `add_hline`/`add_vline` with `row=`/`col=` instead of `add_shape` - it handles axis mapping internally." + ] + }, + { + "cell_type": "markdown", + "id": "19", + "metadata": {}, + "source": [ + "---\n", + "# Annoying: Animation trace updates\n", + "\n", + "**This is the main pain point.** `update_traces()` does NOT update animation frames." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", animation_frame=\"country\")\n", + "fig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21", + "metadata": {}, + "outputs": [], + "source": [ + "# This only affects the INITIAL view, not the animation frames!\n", + "fig.update_traces(line_width=5, line_dash=\"dot\")\n", + "\n", + "print(f\"Base trace line_width: {fig.data[0].line.width}\")\n", + "print(f\"Frame 0 trace line_width: {fig.frames[0].data[0].line.width}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "# When you play the animation, it reverts to the frame's original style\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "23", + "metadata": {}, + "source": [ + "### Solution: `update_traces`\n", + "\n", + "xarray-plotly provides this helper to update both base traces and animation frames:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [], + "source": [ + "# update_traces is imported from xarray_plotly\n", + "# It updates traces in both base figure and all animation frames" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25", + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", animation_frame=\"country\")\n", + "\n", + "update_traces(fig, line_width=4, line_dash=\"dot\")\n", + "\n", + "print(f\"Base trace line_width: {fig.data[0].line.width}\")\n", + "print(f\"Frame 0 trace line_width: {fig.frames[0].data[0].line.width}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", + "metadata": {}, + "outputs": [], + "source": [ + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "27", + "metadata": {}, + "source": [ + "### Selective updates with selector\n", + "\n", + "Use `selector` to target specific traces by name:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28", + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.line(df_countries, x=\"year\", y=\"gdpPercap\", color=\"country\", animation_frame=\"year\")\n", + "\n", + "# Update only one trace by name\n", + "update_traces(fig, selector={\"name\": \"Germany\"}, line_width=5, line_dash=\"dot\")\n", + "\n", + "# Update multiple traces\n", + "update_traces(fig, selector={\"name\": \"China\"}, line_color=\"red\", line_width=3)\n", + "\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "29", + "metadata": {}, + "source": [ + "### Works with facets + animation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30", + "metadata": {}, + "outputs": [], + "source": [ + "df_subset = df.query(\n", + " \"continent in ['Europe', 'Asia'] and country in ['Germany', 'France', 'China', 'Japan']\"\n", + ")\n", + "\n", + "fig = px.line(\n", + " df_subset,\n", + " x=\"year\",\n", + " y=\"gdpPercap\",\n", + " color=\"country\",\n", + " facet_col=\"continent\",\n", + " animation_frame=\"year\",\n", + ")\n", + "\n", + "update_traces(fig, line_width=3)\n", + "fig.for_each_annotation(lambda a: a.update(text=a.text.split(\"=\")[-1]))\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "31", + "metadata": {}, + "source": [ + "### What's affected\n", + "\n", + "Anything on **traces** needs the helper for animations:\n", + "\n", + "| Property | Facets | Animation |\n", + "|----------|--------|-----------|\n", + "| `line_width` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", + "| `line_dash` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", + "| `line_color` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", + "| `marker_size` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", + "| `marker_symbol` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", + "| `opacity` | ✅ `fig.update_traces()` | ❌ needs `update_traces()` |\n", + "\n", + "**Layout properties** (`update_layout`, `update_xaxes`, `update_yaxes`) work fine for animations." + ] + }, + { + "cell_type": "markdown", + "id": "32", + "metadata": {}, + "source": [ + "---\n", + "# Annoying: Animation speed\n", + "\n", + "The API to change animation speed is deeply nested." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33", + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.scatter(\n", + " df,\n", + " x=\"gdpPercap\",\n", + " y=\"lifeExp\",\n", + " color=\"continent\",\n", + " size=\"pop\",\n", + " animation_frame=\"year\",\n", + " log_x=True,\n", + " range_y=[25, 90],\n", + ")\n", + "\n", + "# This is... not intuitive\n", + "fig.layout.updatemenus[0].buttons[0].args[1][\"frame\"][\"duration\"] = 100 # faster\n", + "fig.layout.updatemenus[0].buttons[0].args[1][\"transition\"][\"duration\"] = 50\n", + "\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "34", + "metadata": {}, + "source": [ + "### Workaround: Helper function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35", + "metadata": {}, + "outputs": [], + "source": [ + "def set_animation_speed(fig, frame_duration=500, transition_duration=300):\n", + " \"\"\"Set animation speed in milliseconds.\"\"\"\n", + " if fig.layout.updatemenus:\n", + " fig.layout.updatemenus[0].buttons[0].args[1][\"frame\"][\"duration\"] = frame_duration\n", + " fig.layout.updatemenus[0].buttons[0].args[1][\"transition\"][\"duration\"] = transition_duration\n", + " return fig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36", + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.scatter(\n", + " df,\n", + " x=\"gdpPercap\",\n", + " y=\"lifeExp\",\n", + " color=\"continent\",\n", + " animation_frame=\"year\",\n", + " log_x=True,\n", + " range_y=[25, 90],\n", + ")\n", + "\n", + "set_animation_speed(fig, frame_duration=200, transition_duration=100)\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "37", + "metadata": {}, + "source": [ + "---\n", + "# Annoying: Slider styling\n", + "\n", + "Verbose but straightforward." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38", + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.scatter(\n", + " df,\n", + " x=\"gdpPercap\",\n", + " y=\"lifeExp\",\n", + " color=\"continent\",\n", + " animation_frame=\"year\",\n", + " log_x=True,\n", + " range_y=[25, 90],\n", + ")\n", + "\n", + "fig.layout.sliders[0].currentvalue.prefix = \"Year: \"\n", + "fig.layout.sliders[0].currentvalue.font.size = 16\n", + "fig.layout.sliders[0].pad.t = 50 # padding from top\n", + "\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "39", + "metadata": {}, + "source": [ + "### Hide slider or play button" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40", + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.scatter(\n", + " df,\n", + " x=\"gdpPercap\",\n", + " y=\"lifeExp\",\n", + " color=\"continent\",\n", + " animation_frame=\"year\",\n", + " log_x=True,\n", + " range_y=[25, 90],\n", + ")\n", + "\n", + "# Hide slider (keep play button)\n", + "fig.layout.sliders = []\n", + "\n", + "# Or hide play button (keep slider):\n", + "# fig.layout.updatemenus = []\n", + "\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "41", + "metadata": {}, + "source": [ + "---\n", + "# Summary\n", + "\n", + "### Provided by xarray-plotly\n", + "\n", + "```python\n", + "from xarray_plotly import overlay, update_traces\n", + "```\n", + "\n", + "### Local helper for animation speed" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42", + "metadata": {}, + "outputs": [], + "source": [ + "def set_animation_speed(fig, frame_duration=500, transition_duration=300):\n", + " \"\"\"Set animation speed in milliseconds.\"\"\"\n", + " if fig.layout.updatemenus:\n", + " fig.layout.updatemenus[0].buttons[0].args[1][\"frame\"][\"duration\"] = frame_duration\n", + " fig.layout.updatemenus[0].buttons[0].args[1][\"transition\"][\"duration\"] = transition_duration\n", + " return fig" + ] + }, + { + "cell_type": "markdown", + "id": "43", + "metadata": {}, + "source": [ + "### Quick reference\n", + "\n", + "| Task | Facets | Animation | Solution |\n", + "|------|--------|-----------|----------|\n", + "| Update trace style | `fig.update_traces()` | `update_traces()` | xarray-plotly helper |\n", + "| Update axes | `update_xaxes()`/`update_yaxes()` | Same | ✅ Works |\n", + "| Update layout | `update_layout()` | Same | ✅ Works |\n", + "| Add reference line | `add_hline(row=, col=)` | `add_hline()` | ✅ Works |\n", + "| Add trace | `overlay()` | `overlay()` | ✅ Works |\n", + "| Add shape to specific facet | `add_shape(xref=\"x2\")` | Same | Need axis name |\n", + "| Change animation speed | N/A | `set_animation_speed()` | Local helper |\n", + "| Facet labels | `for_each_annotation()` | Same | ✅ Works |" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/test_figures.py b/tests/test_figures.py new file mode 100644 index 0000000..85ab92c --- /dev/null +++ b/tests/test_figures.py @@ -0,0 +1,590 @@ +"""Tests for the figures module (overlay, add_secondary_y).""" + +from __future__ import annotations + +import copy + +import numpy as np +import plotly.graph_objects as go +import pytest +import xarray as xr + +from xarray_plotly import add_secondary_y, overlay, xpx + + +class TestOverlayBasic: + """Basic tests for overlay function.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + """Set up test data.""" + self.da_2d = xr.DataArray( + np.random.rand(10, 3), + dims=["time", "cat"], + coords={"time": np.arange(10), "cat": ["A", "B", "C"]}, + name="value", + ) + + def test_no_overlays_returns_copy(self) -> None: + """Test that no overlays returns a deep copy of base.""" + base = xpx(self.da_2d).line() + result = overlay(base) + + assert isinstance(result, go.Figure) + assert len(result.data) == len(base.data) + # Verify it's a copy, not the same object + assert result is not base + assert result.data[0] is not base.data[0] + + def test_combine_two_static_figures(self) -> None: + """Test combining two static figures.""" + area_fig = xpx(self.da_2d).area() + line_fig = xpx(self.da_2d).line() + + combined = overlay(area_fig, line_fig) + + assert isinstance(combined, go.Figure) + expected_trace_count = len(area_fig.data) + len(line_fig.data) + assert len(combined.data) == expected_trace_count + + def test_preserves_base_layout(self) -> None: + """Test that base figure's layout is preserved.""" + area_fig = xpx(self.da_2d).area(title="My Area Plot") + line_fig = xpx(self.da_2d).line(title="My Line Plot") + + combined = overlay(area_fig, line_fig) + + assert combined.layout.title.text == "My Area Plot" + + def test_multiple_overlays(self) -> None: + """Test combining multiple overlays.""" + area_fig = xpx(self.da_2d).area() + line_fig = xpx(self.da_2d).line() + scatter_fig = xpx(self.da_2d).scatter() + + combined = overlay(area_fig, line_fig, scatter_fig) + + expected_count = len(area_fig.data) + len(line_fig.data) + len(scatter_fig.data) + assert len(combined.data) == expected_count + + def test_overlay_traces_added_in_order(self) -> None: + """Test that overlay traces are added after base traces.""" + # Create figures with distinguishable y values + da_1 = xr.DataArray([1, 2, 3], dims=["x"], name="first") + da_2 = xr.DataArray([10, 20, 30], dims=["x"], name="second") + + fig1 = xpx(da_1).line() + fig2 = xpx(da_2).line() + + combined = overlay(fig1, fig2) + + # First trace should have y values from fig1 + assert list(combined.data[0].y) == [1, 2, 3] + # Second trace should have y values from fig2 + assert list(combined.data[1].y) == [10, 20, 30] + + +class TestOverlayFacets: + """Tests for overlay with faceted figures.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + """Set up test data.""" + self.da_3d = xr.DataArray( + np.random.rand(10, 3, 2), + dims=["time", "cat", "facet"], + coords={ + "time": np.arange(10), + "cat": ["A", "B", "C"], + "facet": ["left", "right"], + }, + name="value", + ) + + def test_matching_facet_structures(self) -> None: + """Test combining figures with matching facet structures.""" + area_fig = xpx(self.da_3d).area(facet_col="facet") + line_fig = xpx(self.da_3d).line(facet_col="facet") + + combined = overlay(area_fig, line_fig) + + assert isinstance(combined, go.Figure) + expected_count = len(area_fig.data) + len(line_fig.data) + assert len(combined.data) == expected_count + + def test_overlay_with_extra_subplots_raises(self) -> None: + """Test that overlay with extra subplots raises ValueError.""" + # Base without facets + base = xpx(self.da_3d.isel(facet=0)).line() + # Overlay with facets + overlay_fig = xpx(self.da_3d).line(facet_col="facet") + + with pytest.raises(ValueError, match="subplots not present in base"): + overlay(base, overlay_fig) + + def test_preserves_axis_references(self) -> None: + """Test that traces preserve their xaxis/yaxis references.""" + area_fig = xpx(self.da_3d).area(facet_col="facet") + line_fig = xpx(self.da_3d).line(facet_col="facet") + + combined = overlay(area_fig, line_fig) + + # Collect axis references from both original and combined + original_axes = set() + for trace in area_fig.data: + xaxis = getattr(trace, "xaxis", None) or "x" + yaxis = getattr(trace, "yaxis", None) or "y" + original_axes.add((xaxis, yaxis)) + + combined_axes = set() + for trace in combined.data: + xaxis = getattr(trace, "xaxis", None) or "x" + yaxis = getattr(trace, "yaxis", None) or "y" + combined_axes.add((xaxis, yaxis)) + + # Combined should have same axis structure + assert combined_axes == original_axes + + +class TestOverlayAnimation: + """Tests for overlay with animated figures.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + """Set up test data.""" + self.da_3d = xr.DataArray( + np.random.rand(10, 3, 4), + dims=["x", "cat", "time"], + coords={ + "x": np.arange(10), + "cat": ["A", "B", "C"], + "time": [0, 1, 2, 3], + }, + name="value", + ) + + def test_matching_frames_merged(self) -> None: + """Test that matching animation frames are merged correctly.""" + area_fig = xpx(self.da_3d).area(animation_frame="time") + line_fig = xpx(self.da_3d).line(animation_frame="time") + + combined = overlay(area_fig, line_fig) + + assert isinstance(combined, go.Figure) + # Should have same number of frames + assert len(combined.frames) == len(area_fig.frames) + # Each frame should have more data + for i, frame in enumerate(combined.frames): + expected_data = len(area_fig.frames[i].data) + len(line_fig.frames[i].data) + assert len(frame.data) == expected_data + + def test_static_overlay_replicated_to_frames(self) -> None: + """Test that static overlay is replicated to all animation frames.""" + animated = xpx(self.da_3d).area(animation_frame="time") + static = xpx(self.da_3d.isel(time=0)).line() + + combined = overlay(animated, static) + + # Combined should have all frames from animated figure + assert len(combined.frames) == len(animated.frames) + + # Each frame should include the static traces + for frame in combined.frames: + # Frame data should include both animated and static traces + expected_count = len(animated.frames[0].data) + len(static.data) + assert len(frame.data) == expected_count + + def test_animated_overlay_on_static_base_raises(self) -> None: + """Test that animated overlay on static base raises ValueError.""" + static = xpx(self.da_3d.isel(time=0)).line() + animated = xpx(self.da_3d).area(animation_frame="time") + + with pytest.raises(ValueError, match="base figure does not"): + overlay(static, animated) + + def test_mismatched_frame_names_raises(self) -> None: + """Test that mismatched frame names raise ValueError.""" + da1 = xr.DataArray( + np.random.rand(10, 3), + dims=["x", "time"], + coords={"x": np.arange(10), "time": [0, 1, 2]}, + ) + da2 = xr.DataArray( + np.random.rand(10, 4), + dims=["x", "time"], + coords={"x": np.arange(10), "time": [0, 1, 2, 3]}, + ) + + fig1 = xpx(da1).line(animation_frame="time") + fig2 = xpx(da2).line(animation_frame="time") + + with pytest.raises(ValueError, match="frame names don't match"): + overlay(fig1, fig2) + + def test_frame_names_preserved(self) -> None: + """Test that frame names are preserved in combined figure.""" + area_fig = xpx(self.da_3d).area(animation_frame="time") + line_fig = xpx(self.da_3d).line(animation_frame="time") + + combined = overlay(area_fig, line_fig) + + original_names = {frame.name for frame in area_fig.frames} + combined_names = {frame.name for frame in combined.frames} + assert original_names == combined_names + + +class TestOverlayFacetsAndAnimation: + """Tests for overlay with both facets and animation.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + """Set up test data.""" + self.da_4d = xr.DataArray( + np.random.rand(10, 3, 2, 4), + dims=["x", "cat", "facet", "time"], + coords={ + "x": np.arange(10), + "cat": ["A", "B", "C"], + "facet": ["left", "right"], + "time": [0, 1, 2, 3], + }, + name="value", + ) + + def test_facets_and_animation_combined(self) -> None: + """Test combining figures with both facets and animation.""" + area_fig = xpx(self.da_4d).area(facet_col="facet", animation_frame="time") + line_fig = xpx(self.da_4d).line(facet_col="facet", animation_frame="time") + + combined = overlay(area_fig, line_fig) + + assert isinstance(combined, go.Figure) + # Check trace count + expected_traces = len(area_fig.data) + len(line_fig.data) + assert len(combined.data) == expected_traces + # Check frame count + assert len(combined.frames) == len(area_fig.frames) + + def test_static_overlay_on_animated_faceted_base(self) -> None: + """Test static overlay replicated on animated faceted base.""" + animated = xpx(self.da_4d).area(facet_col="facet", animation_frame="time") + static = xpx(self.da_4d.isel(time=0)).line(facet_col="facet") + + combined = overlay(animated, static) + + # Should have same frames as animated + assert len(combined.frames) == len(animated.frames) + # Each frame should have combined trace count + for frame in combined.frames: + expected = len(animated.frames[0].data) + len(static.data) + assert len(frame.data) == expected + + +class TestOverlayDeepCopy: + """Tests to ensure overlay creates deep copies.""" + + def test_base_not_modified(self) -> None: + """Test that base figure is not modified.""" + da = xr.DataArray(np.random.rand(10, 3), dims=["x", "cat"]) + base = xpx(da).area() + original_trace_count = len(base.data) + original_title = copy.deepcopy(base.layout.title) + + overlay_fig = xpx(da).line() + _ = overlay(base, overlay_fig) + + # Base should be unchanged + assert len(base.data) == original_trace_count + assert base.layout.title == original_title + + def test_overlay_not_modified(self) -> None: + """Test that overlay figure is not modified.""" + da = xr.DataArray(np.random.rand(10, 3), dims=["x", "cat"]) + base = xpx(da).area() + overlay_fig = xpx(da).line() + original_trace_count = len(overlay_fig.data) + + _ = overlay(base, overlay_fig) + + # Overlay should be unchanged + assert len(overlay_fig.data) == original_trace_count + + def test_combined_traces_independent(self) -> None: + """Test that combined traces are independent of originals.""" + da = xr.DataArray(np.random.rand(10, 3), dims=["x", "cat"]) + base = xpx(da).area() + overlay_fig = xpx(da).line() + + combined = overlay(base, overlay_fig) + + # Modify combined figure + combined.data[0].name = "modified" + + # Originals should be unchanged + assert base.data[0].name != "modified" + + +class TestAddSecondaryYBasic: + """Basic tests for add_secondary_y function.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + """Set up test data.""" + self.temp = xr.DataArray( + [20, 22, 25, 23, 21], + dims=["time"], + coords={"time": [0, 1, 2, 3, 4]}, + name="Temperature", + ) + self.precip = xr.DataArray( + [0, 5, 12, 2, 8], + dims=["time"], + coords={"time": [0, 1, 2, 3, 4]}, + name="Precipitation", + ) + + def test_creates_secondary_y_axis(self) -> None: + """Test that secondary y-axis is created.""" + temp_fig = xpx(self.temp).line() + precip_fig = xpx(self.precip).bar() + + combined = add_secondary_y(temp_fig, precip_fig) + + assert isinstance(combined, go.Figure) + assert combined.layout.yaxis2 is not None + assert combined.layout.yaxis2.side == "right" + assert combined.layout.yaxis2.overlaying == "y" + + def test_secondary_traces_use_y2(self) -> None: + """Test that secondary figure traces are assigned to y2.""" + temp_fig = xpx(self.temp).line() + precip_fig = xpx(self.precip).bar() + + combined = add_secondary_y(temp_fig, precip_fig) + + # First trace (from temp_fig) should use default y + assert combined.data[0].yaxis is None or combined.data[0].yaxis == "y" + # Second trace (from precip_fig) should use y2 + assert combined.data[1].yaxis == "y2" + + def test_preserves_base_layout(self) -> None: + """Test that base figure's layout is preserved.""" + temp_fig = xpx(self.temp).line(title="My Temperature Plot") + precip_fig = xpx(self.precip).bar() + + combined = add_secondary_y(temp_fig, precip_fig) + + assert combined.layout.title.text == "My Temperature Plot" + + def test_total_trace_count(self) -> None: + """Test that all traces from both figures are included.""" + temp_fig = xpx(self.temp).line() + precip_fig = xpx(self.precip).bar() + + combined = add_secondary_y(temp_fig, precip_fig) + + expected_count = len(temp_fig.data) + len(precip_fig.data) + assert len(combined.data) == expected_count + + def test_secondary_y_title_from_secondary_figure(self) -> None: + """Test that secondary y-axis title comes from secondary figure.""" + temp_fig = xpx(self.temp).line() + precip_fig = xpx(self.precip).bar() + # Plotly Express sets y-axis title based on the data + + combined = add_secondary_y(temp_fig, precip_fig) + + # The secondary y-axis title should be set + assert combined.layout.yaxis2.title is not None + + def test_custom_secondary_y_title(self) -> None: + """Test that custom secondary y-axis title can be provided.""" + temp_fig = xpx(self.temp).line() + precip_fig = xpx(self.precip).bar() + + combined = add_secondary_y(temp_fig, precip_fig, secondary_y_title="Rain (mm)") + + assert combined.layout.yaxis2.title.text == "Rain (mm)" + + +class TestAddSecondaryYFacets: + """Tests for add_secondary_y with faceted figures.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + """Set up test data.""" + self.da = xr.DataArray( + np.random.rand(10, 3), + dims=["time", "facet"], + coords={"time": np.arange(10), "facet": ["A", "B", "C"]}, + name="value", + ) + # Different scale for secondary + self.da_secondary = xr.DataArray( + np.random.rand(10, 3) * 1000, + dims=["time", "facet"], + coords={"time": np.arange(10), "facet": ["A", "B", "C"]}, + name="large_value", + ) + + def test_matching_facets_works(self) -> None: + """Test that matching facet structures work.""" + base = xpx(self.da).line(facet_col="facet") + secondary = xpx(self.da_secondary).bar(facet_col="facet") + + combined = add_secondary_y(base, secondary) + + assert isinstance(combined, go.Figure) + expected_traces = len(base.data) + len(secondary.data) + assert len(combined.data) == expected_traces + + def test_facets_creates_multiple_secondary_axes(self) -> None: + """Test that secondary y-axes are created for each facet.""" + base = xpx(self.da).line(facet_col="facet") + secondary = xpx(self.da_secondary).bar(facet_col="facet") + + combined = add_secondary_y(base, secondary) + + # Should have yaxis2 (secondary for y), yaxis5 (secondary for y2), etc. + # Base has y, y2, y3, so secondary should be y4, y5, y6 + layout_json = combined.layout.to_plotly_json() + assert "yaxis4" in layout_json + assert layout_json["yaxis4"]["overlaying"] == "y" + assert layout_json["yaxis4"]["side"] == "right" + + def test_secondary_traces_remapped_to_correct_axes(self) -> None: + """Test that secondary traces use correct secondary y-axes.""" + base = xpx(self.da).line(facet_col="facet") + secondary = xpx(self.da_secondary).bar(facet_col="facet") + + combined = add_secondary_y(base, secondary) + + # Get secondary trace y-axes + secondary_trace_yaxes = {trace.yaxis for trace in combined.data[len(base.data) :]} + # Should be y4, y5, y6 (secondary axes) + assert secondary_trace_yaxes == {"y4", "y5", "y6"} + + def test_mismatched_facets_raises(self) -> None: + """Test that mismatched facet structures raise ValueError.""" + # Base with facets + base = xpx(self.da).line(facet_col="facet") + # Secondary without facets + secondary = xpx(self.da.isel(facet=0)).bar() + + with pytest.raises(ValueError, match="same facet structure"): + add_secondary_y(base, secondary) + + def test_mismatched_facets_reversed_raises(self) -> None: + """Test that mismatched facets raise (base without, secondary with).""" + # Base without facets + base = xpx(self.da.isel(facet=0)).line() + # Secondary with facets + secondary = xpx(self.da).bar(facet_col="facet") + + with pytest.raises(ValueError, match="same facet structure"): + add_secondary_y(base, secondary) + + def test_facets_with_custom_title(self) -> None: + """Test custom secondary y-axis title with facets.""" + base = xpx(self.da).line(facet_col="facet") + secondary = xpx(self.da_secondary).bar(facet_col="facet") + + combined = add_secondary_y(base, secondary, secondary_y_title="Custom Title") + + # Title should be on the first secondary axis + assert combined.layout.yaxis4.title.text == "Custom Title" + + +class TestAddSecondaryYAnimation: + """Tests for add_secondary_y with animated figures.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + """Set up test data.""" + self.da_2d = xr.DataArray( + np.random.rand(10, 4), + dims=["x", "time"], + coords={"x": np.arange(10), "time": [0, 1, 2, 3]}, + name="value", + ) + + def test_matching_animation_frames(self) -> None: + """Test add_secondary_y with matching animation frames.""" + fig1 = xpx(self.da_2d).line(animation_frame="time") + fig2 = xpx(self.da_2d).bar(animation_frame="time") + + combined = add_secondary_y(fig1, fig2) + + assert len(combined.frames) == len(fig1.frames) + # Verify frame names match + for orig, comb in zip(fig1.frames, combined.frames, strict=False): + assert orig.name == comb.name + + def test_static_secondary_on_animated_base(self) -> None: + """Test static secondary replicated to all animation frames.""" + animated = xpx(self.da_2d).line(animation_frame="time") + static = xpx(self.da_2d.isel(time=0)).bar() + + combined = add_secondary_y(animated, static) + + assert len(combined.frames) == len(animated.frames) + # Each frame should have traces from both figures + for frame in combined.frames: + expected = len(animated.frames[0].data) + len(static.data) + assert len(frame.data) == expected + + def test_animated_secondary_on_static_base_raises(self) -> None: + """Test that animated secondary on static base raises ValueError.""" + static = xpx(self.da_2d.isel(time=0)).line() + animated = xpx(self.da_2d).bar(animation_frame="time") + + with pytest.raises(ValueError, match="base figure does not"): + add_secondary_y(static, animated) + + def test_mismatched_animation_frames_raises(self) -> None: + """Test that mismatched animation frames raise ValueError.""" + da1 = xr.DataArray( + np.random.rand(10, 3), + dims=["x", "time"], + coords={"x": np.arange(10), "time": [0, 1, 2]}, + ) + da2 = xr.DataArray( + np.random.rand(10, 4), + dims=["x", "time"], + coords={"x": np.arange(10), "time": [0, 1, 2, 3]}, + ) + + fig1 = xpx(da1).line(animation_frame="time") + fig2 = xpx(da2).bar(animation_frame="time") + + with pytest.raises(ValueError, match="frame names don't match"): + add_secondary_y(fig1, fig2) + + +class TestAddSecondaryYDeepCopy: + """Tests to ensure add_secondary_y creates deep copies.""" + + def test_base_not_modified(self) -> None: + """Test that base figure is not modified.""" + da = xr.DataArray([1, 2, 3, 4, 5], dims=["x"]) + base = xpx(da).line() + original_trace_count = len(base.data) + + secondary = xpx(da).bar() + _ = add_secondary_y(base, secondary) + + assert len(base.data) == original_trace_count + # Base should not have yaxis2 (check via to_plotly_json) + assert "yaxis2" not in base.layout.to_plotly_json() + + def test_secondary_not_modified(self) -> None: + """Test that secondary figure is not modified.""" + da = xr.DataArray([1, 2, 3, 4, 5], dims=["x"]) + base = xpx(da).line() + secondary = xpx(da).bar() + original_yaxis = secondary.data[0].yaxis + + _ = add_secondary_y(base, secondary) + + # Secondary traces should still use original yaxis + assert secondary.data[0].yaxis == original_yaxis diff --git a/xarray_plotly/__init__.py b/xarray_plotly/__init__.py index 7bc9539..d377b4c 100644 --- a/xarray_plotly/__init__.py +++ b/xarray_plotly/__init__.py @@ -53,13 +53,19 @@ from xarray_plotly import config from xarray_plotly.accessor import DataArrayPlotlyAccessor, DatasetPlotlyAccessor from xarray_plotly.common import SLOT_ORDERS, auto +from xarray_plotly.figures import ( + add_secondary_y, + overlay, + update_traces, +) __all__ = [ "SLOT_ORDERS", - "DataArrayPlotlyAccessor", - "DatasetPlotlyAccessor", + "add_secondary_y", "auto", "config", + "overlay", + "update_traces", "xpx", ] diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py new file mode 100644 index 0000000..2e1f658 --- /dev/null +++ b/xarray_plotly/figures.py @@ -0,0 +1,432 @@ +""" +Helper functions for combining and manipulating Plotly figures. +""" + +from __future__ import annotations + +import copy +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import plotly.graph_objects as go + + +def _get_subplot_axes(fig: go.Figure) -> set[tuple[str, str]]: + """Extract (xaxis, yaxis) pairs from figure traces. + + Args: + fig: A Plotly figure. + + Returns: + Set of (xaxis, yaxis) tuples, e.g., {('x', 'y'), ('x2', 'y2')}. + """ + axes_pairs = set() + for trace in fig.data: + xaxis = getattr(trace, "xaxis", None) or "x" + yaxis = getattr(trace, "yaxis", None) or "y" + axes_pairs.add((xaxis, yaxis)) + return axes_pairs + + +def _validate_compatible_structure(base: go.Figure, overlay: go.Figure) -> None: + """Validate that overlay's subplot structure is compatible with base. + + Args: + base: The base figure. + overlay: The overlay figure to check. + + Raises: + ValueError: If overlay has subplots not present in base. + """ + base_axes = _get_subplot_axes(base) + overlay_axes = _get_subplot_axes(overlay) + + extra_axes = overlay_axes - base_axes + if extra_axes: + raise ValueError( + f"Overlay figure has subplots not present in base figure: {extra_axes}. " + "Ensure both figures have the same facet structure." + ) + + +def _validate_animation_compatibility(base: go.Figure, overlay: go.Figure) -> None: + """Validate animation frame compatibility between base and overlay. + + Args: + base: The base figure. + overlay: The overlay figure to check. + + Raises: + ValueError: If overlay has animation but base doesn't, or frame names don't match. + """ + base_has_frames = bool(base.frames) + overlay_has_frames = bool(overlay.frames) + + if overlay_has_frames and not base_has_frames: + raise ValueError( + "Overlay figure has animation frames but base figure does not. " + "Cannot add animated overlay to static base figure." + ) + + if base_has_frames and overlay_has_frames: + base_frame_names = {frame.name for frame in base.frames} + overlay_frame_names = {frame.name for frame in overlay.frames} + + if base_frame_names != overlay_frame_names: + missing_in_overlay = base_frame_names - overlay_frame_names + extra_in_overlay = overlay_frame_names - base_frame_names + msg = "Animation frame names don't match between base and overlay." + if missing_in_overlay: + msg += f" Missing in overlay: {missing_in_overlay}." + if extra_in_overlay: + msg += f" Extra in overlay: {extra_in_overlay}." + raise ValueError(msg) + + +def _merge_frames( + base: go.Figure, + overlays: list[go.Figure], + base_trace_count: int, + overlay_trace_counts: list[int], +) -> list: + """Merge animation frames from base and overlay figures. + + Args: + base: The base figure with animation frames. + overlays: List of overlay figures (may or may not have frames). + base_trace_count: Number of traces in the base figure. + overlay_trace_counts: Number of traces in each overlay figure. + + Returns: + List of merged frames. + """ + import plotly.graph_objects as go + + merged_frames = [] + + for base_frame in base.frames: + frame_name = base_frame.name + merged_data = list(base_frame.data) + + for overlay, _overlay_trace_count in zip(overlays, overlay_trace_counts, strict=False): + if overlay.frames: + # Find matching frame in overlay + overlay_frame = next((f for f in overlay.frames if f.name == frame_name), None) + if overlay_frame: + merged_data.extend(overlay_frame.data) + else: + # Static overlay: replicate traces to this frame + merged_data.extend(overlay.data) + + merged_frames.append( + go.Frame( + data=merged_data, + name=frame_name, + traces=list(range(base_trace_count + sum(overlay_trace_counts))), + ) + ) + + return merged_frames + + +def overlay(base: go.Figure, *overlays: go.Figure) -> go.Figure: + """Overlay multiple Plotly figures on the same axes. + + Creates a new figure with the base figure's layout, sliders, and buttons, + with all overlay traces added on top. Correctly handles faceted figures + and animation frames. + + Args: + base: The base figure whose layout is preserved. + *overlays: One or more figures to overlay on the base. + + Returns: + A new combined figure. + + Raises: + ValueError: If overlay has subplots not in base, animation frames don't match, + or overlay has animation but base doesn't. + + Example: + >>> import numpy as np + >>> import xarray as xr + >>> from xarray_plotly import xpx, overlay + >>> + >>> da = xr.DataArray(np.random.rand(10, 3), dims=["time", "cat"]) + >>> area_fig = xpx(da).area() + >>> line_fig = xpx(da).line() + >>> combined = overlay(area_fig, line_fig) + >>> + >>> # With animation + >>> da3d = xr.DataArray(np.random.rand(10, 3, 4), dims=["x", "cat", "time"]) + >>> area = xpx(da3d).area(animation_frame="time") + >>> line = xpx(da3d).line(animation_frame="time") + >>> combined = overlay(area, line) + """ + import plotly.graph_objects as go + + if not overlays: + # No overlays: return a deep copy of base + return copy.deepcopy(base) + + # Validate all overlays + for overlay in overlays: + _validate_compatible_structure(base, overlay) + _validate_animation_compatibility(base, overlay) + + # Create new figure with base's layout + combined = go.Figure(layout=copy.deepcopy(base.layout)) + + # Add all traces from base + for trace in base.data: + combined.add_trace(copy.deepcopy(trace)) + + # Add all traces from overlays + for overlay in overlays: + for trace in overlay.data: + combined.add_trace(copy.deepcopy(trace)) + + # Handle animation frames + if base.frames: + base_trace_count = len(base.data) + overlay_trace_counts = [len(overlay.data) for overlay in overlays] + merged_frames = _merge_frames(base, list(overlays), base_trace_count, overlay_trace_counts) + combined.frames = merged_frames + + return combined + + +def _build_secondary_y_mapping(base_axes: set[tuple[str, str]]) -> dict[str, str]: + """Build mapping from primary y-axes to secondary y-axes. + + Args: + base_axes: Set of (xaxis, yaxis) pairs from base figure. + + Returns: + Dict mapping primary yaxis names to secondary yaxis names. + E.g., {'y': 'y4', 'y2': 'y5', 'y3': 'y6'} + """ + primary_y_axes = sorted({yaxis for _, yaxis in base_axes}) + + # Find the highest existing yaxis number + max_y_num = 1 # 'y' is 1 + for yaxis in primary_y_axes: + num = 1 if yaxis == "y" else int(yaxis[1:]) + max_y_num = max(max_y_num, num) + + # Create mapping: primary_yaxis -> secondary_yaxis + y_mapping = {} + next_y_num = max_y_num + 1 + for yaxis in primary_y_axes: + y_mapping[yaxis] = f"y{next_y_num}" + next_y_num += 1 + + return y_mapping + + +def add_secondary_y( + base: go.Figure, + secondary: go.Figure, + *, + secondary_y_title: str | None = None, +) -> go.Figure: + """Add a secondary y-axis with traces from another figure. + + Creates a new figure with the base figure's layout and secondary y-axes + on the right side. All traces from the secondary figure are plotted against + the secondary y-axes. Supports faceted figures when both have matching + facet structure. + + Args: + base: The base figure (left y-axis). + secondary: The figure whose traces use the secondary y-axis (right). + secondary_y_title: Optional title for the secondary y-axis. + If not provided, uses the secondary figure's y-axis title. + + Returns: + A new figure with both primary and secondary y-axes. + + Raises: + ValueError: If facet structures don't match, or if animation + frames don't match. + + Example: + >>> import numpy as np + >>> import xarray as xr + >>> from xarray_plotly import xpx, add_secondary_y + >>> + >>> # Two variables with different scales + >>> temp = xr.DataArray([20, 22, 25, 23], dims=["time"], name="Temperature (°C)") + >>> precip = xr.DataArray([0, 5, 12, 2], dims=["time"], name="Precipitation (mm)") + >>> + >>> temp_fig = xpx(temp).line() + >>> precip_fig = xpx(precip).bar() + >>> combined = add_secondary_y(temp_fig, precip_fig) + >>> + >>> # With facets + >>> data = xr.DataArray(np.random.rand(10, 3), dims=["x", "facet"]) + >>> fig1 = xpx(data).line(facet_col="facet") + >>> fig2 = xpx(data * 100).bar(facet_col="facet") # Different scale + >>> combined = add_secondary_y(fig1, fig2) + """ + import plotly.graph_objects as go + + # Get axis pairs from both figures + base_axes = _get_subplot_axes(base) + secondary_axes = _get_subplot_axes(secondary) + + # Validate same facet structure + if base_axes != secondary_axes: + raise ValueError( + f"Base and secondary figures must have the same facet structure. " + f"Base has {base_axes}, secondary has {secondary_axes}." + ) + + # Validate animation compatibility + _validate_animation_compatibility(base, secondary) + + # Build mapping from primary y-axes to secondary y-axes + y_mapping = _build_secondary_y_mapping(base_axes) + + # Create new figure with base's layout + combined = go.Figure(layout=copy.deepcopy(base.layout)) + + # Add all traces from base (primary y-axis) + for trace in base.data: + combined.add_trace(copy.deepcopy(trace)) + + # Add all traces from secondary, remapped to secondary y-axes + for trace in secondary.data: + trace_copy = copy.deepcopy(trace) + original_yaxis = getattr(trace_copy, "yaxis", None) or "y" + trace_copy.yaxis = y_mapping[original_yaxis] + combined.add_trace(trace_copy) + + # Configure secondary y-axes + for primary_yaxis, secondary_yaxis in y_mapping.items(): + # Get title - only set on first secondary axis or use provided title + title = None + if secondary_y_title is not None: + # Only set title on the first secondary axis to avoid repetition + if primary_yaxis == "y": + title = secondary_y_title + elif primary_yaxis == "y" and secondary.layout.yaxis and secondary.layout.yaxis.title: + # Try to get from secondary's layout + title = secondary.layout.yaxis.title.text + + # Configure the secondary axis + axis_config = { + "title": title, + "overlaying": primary_yaxis, + "side": "right", + "anchor": "free" if primary_yaxis != "y" else None, + } + # Remove None values + axis_config = {k: v for k, v in axis_config.items() if v is not None} + + # Convert y2 -> yaxis2, y3 -> yaxis3, etc. for layout property name + layout_prop = "yaxis" if secondary_yaxis == "y" else f"yaxis{secondary_yaxis[1:]}" + combined.update_layout(**{layout_prop: axis_config}) + + # Handle animation frames + if base.frames: + merged_frames = _merge_secondary_y_frames(base, secondary, y_mapping) + combined.frames = merged_frames + + return combined + + +def _merge_secondary_y_frames( + base: go.Figure, + secondary: go.Figure, + y_mapping: dict[str, str], +) -> list: + """Merge animation frames for secondary y-axis combination. + + Args: + base: The base figure with animation frames. + secondary: The secondary figure (may or may not have frames). + y_mapping: Mapping from primary y-axis names to secondary y-axis names. + + Returns: + List of merged frames with secondary traces assigned to secondary y-axes. + """ + import plotly.graph_objects as go + + merged_frames = [] + base_trace_count = len(base.data) + secondary_trace_count = len(secondary.data) + + for base_frame in base.frames: + frame_name = base_frame.name + merged_data = list(base_frame.data) + + if secondary.frames: + # Find matching frame in secondary + secondary_frame = next((f for f in secondary.frames if f.name == frame_name), None) + if secondary_frame: + # Add secondary frame data with remapped y-axis + for trace_data in secondary_frame.data: + trace_copy = copy.deepcopy(trace_data) + original_yaxis = getattr(trace_copy, "yaxis", None) or "y" + trace_copy.yaxis = y_mapping.get(original_yaxis, original_yaxis) + merged_data.append(trace_copy) + else: + # Static secondary: replicate traces to this frame + for trace in secondary.data: + trace_copy = copy.deepcopy(trace) + original_yaxis = getattr(trace_copy, "yaxis", None) or "y" + trace_copy.yaxis = y_mapping.get(original_yaxis, original_yaxis) + merged_data.append(trace_copy) + + merged_frames.append( + go.Frame( + data=merged_data, + name=frame_name, + traces=list(range(base_trace_count + secondary_trace_count)), + ) + ) + + return merged_frames + + +def update_traces(fig: go.Figure, selector: dict | None = None, **kwargs) -> go.Figure: + """Update traces in both base figure and all animation frames. + + Plotly's `update_traces()` only updates the base figure, not animation frames. + This function updates both, ensuring trace styles persist during animation. + + Args: + fig: A Plotly figure, optionally with animation frames. + selector: Dict to match specific traces, e.g. ``{"name": "Germany"}``. + If None, updates all traces. + **kwargs: Trace properties to update, e.g. ``line_width=4``, ``line_dash="dot"``. + + Returns: + The modified figure (same object, mutated in place). + + Example: + >>> import plotly.express as px + >>> from xarray_plotly import update_traces + >>> + >>> df = px.data.gapminder() + >>> fig = px.line(df, x="year", y="gdpPercap", color="country", animation_frame="continent") + >>> + >>> # Update all traces + >>> update_traces(fig, line_width=3) + >>> + >>> # Update specific trace by name + >>> update_traces(fig, selector={"name": "Germany"}, line_width=5, line_dash="dot") + """ + fig.update_traces(selector=selector, **kwargs) + + for frame in fig.frames: + for trace in frame.data: + if selector is None: + trace.update(**kwargs) + else: + # Check if trace matches all selector criteria + if all(getattr(trace, k, None) == v for k, v in selector.items()): + trace.update(**kwargs) + + return fig