diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..2e8b124 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,25 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + groups: + dev-dependencies: + patterns: + - "pytest*" + - "mypy" + - "ruff" + - "pre-commit" + - "nbstripout" + docs-dependencies: + patterns: + - "mkdocs*" + - "pooch" + - "netcdf4" + - "jupyter" + + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" diff --git a/.github/workflows/dependabot-auto-merge.yml b/.github/workflows/dependabot-auto-merge.yml new file mode 100644 index 0000000..3a701d5 --- /dev/null +++ b/.github/workflows/dependabot-auto-merge.yml @@ -0,0 +1,24 @@ +name: Dependabot auto-merge +on: pull_request + +permissions: + contents: write + pull-requests: write + +jobs: + dependabot: + runs-on: ubuntu-latest + if: github.actor == 'dependabot[bot]' + steps: + - name: Dependabot metadata + id: metadata + uses: dependabot/fetch-metadata@v2 + with: + github-token: "${{ secrets.GITHUB_TOKEN }}" + + - name: Auto-merge minor and patch updates + if: steps.metadata.outputs.update-type == 'version-update:semver-patch' || steps.metadata.outputs.update-type == 'version-update:semver-minor' + run: gh pr merge --auto --squash "$PR_URL" + env: + PR_URL: ${{ github.event.pull_request.html_url }} + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/docs/examples/advanced.ipynb b/docs/examples/advanced.ipynb index 1b23f46..71bda7c 100644 --- a/docs/examples/advanced.ipynb +++ b/docs/examples/advanced.ipynb @@ -6,7 +6,7 @@ "source": [ "# Advanced Usage\n", "\n", - "This notebook covers advanced patterns and customization options." + "This notebook covers advanced Plotly customization. All kwargs are passed directly to [Plotly Express](https://plotly.com/python/plotly-express/), and figures can be modified using the full [Plotly Graph Objects API](https://plotly.com/python/graph-objects/)." ] }, { @@ -15,22 +15,13 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", - "import pandas as pd\n", + "import plotly.express as px\n", + "import plotly.graph_objects as go\n", "import xarray as xr\n", "\n", "from xarray_plotly import config, xpx\n", "\n", - "config.notebook() # Configure Plotly for notebook rendering" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Working with xarray Attributes\n", - "\n", - "xarray_plotly automatically uses metadata from xarray attributes for labels:" + "config.notebook()" ] }, { @@ -39,37 +30,37 @@ "metadata": {}, "outputs": [], "source": [ - "da = xr.DataArray(\n", - " np.random.randn(30, 3).cumsum(axis=0) + 15,\n", - " dims=[\"time\", \"station\"],\n", - " coords={\n", - " \"time\": pd.date_range(\"2024-01-01\", periods=30, freq=\"D\"),\n", - " \"station\": [\"Alpine\", \"Coastal\", \"Urban\"],\n", - " },\n", - " name=\"temperature\",\n", - " attrs={\n", - " \"long_name\": \"Air Temperature\",\n", - " \"units\": \"°C\",\n", - " \"standard_name\": \"air_temperature\",\n", - " },\n", - ")\n", + "# Load sample data\n", + "df_stocks = px.data.stocks().set_index(\"date\")\n", + "df_stocks.index = df_stocks.index.astype(\"datetime64[ns]\")\n", "\n", - "# Add coordinate attributes\n", - "da.coords[\"time\"].attrs = {\"long_name\": \"Time\", \"units\": \"days\"}\n", - "da.coords[\"station\"].attrs = {\"long_name\": \"Measurement Station\"}\n", + "stocks = xr.DataArray(\n", + " df_stocks.values,\n", + " dims=[\"date\", \"company\"],\n", + " coords={\"date\": df_stocks.index, \"company\": df_stocks.columns.tolist()},\n", + " name=\"price\",\n", + ")\n", "\n", - "# Labels are automatically extracted from attrs\n", - "fig = xpx(da).line(title=\"Temperature with Auto-Labels\")\n", - "fig" + "df_gap = px.data.gapminder()\n", + "countries = [\"United States\", \"China\", \"Germany\", \"Brazil\", \"Nigeria\"]\n", + "df_life = df_gap[df_gap[\"country\"].isin(countries)].pivot(\n", + " index=\"year\", columns=\"country\", values=\"lifeExp\"\n", + ")\n", + "life_exp = xr.DataArray(\n", + " df_life.values,\n", + " dims=[\"year\", \"country\"],\n", + " coords={\"year\": df_life.index, \"country\": df_life.columns.tolist()},\n", + " name=\"life_exp\",\n", + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Configuring Label Behavior\n", + "## Passing Kwargs to Plotly Express\n", "\n", - "Use `config.set_options()` to control how labels are extracted from attributes:" + "All keyword arguments are passed directly to [Plotly Express functions](https://plotly.com/python-api-reference/plotly.express.html):" ] }, { @@ -78,72 +69,47 @@ "metadata": {}, "outputs": [], "source": [ - "# Disable units in labels\n", - "with config.set_options(label_include_units=False):\n", - " fig = xpx(da).line(title=\"Without Units in Labels\")\n", + "# All px.line kwargs work: template, labels, colors, etc.\n", + "fig = xpx(stocks).line(\n", + " title=\"Stock Performance\",\n", + " template=\"plotly_white\",\n", + " labels={\"price\": \"Normalized Price\", \"date\": \"Date\", \"company\": \"Ticker\"},\n", + " color_discrete_sequence=px.colors.qualitative.Set2,\n", + ")\n", "fig" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Overriding Labels\n", - "\n", - "You can override the automatic labels:" - ] - }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "fig = xpx(da).line(\n", - " labels={\n", - " \"temperature\": \"Temp (°C)\",\n", - " \"time\": \"Date\",\n", - " \"station\": \"Location\",\n", - " },\n", - " title=\"Custom Labels\",\n", + "# px.imshow kwargs: colorscale, midpoint, aspect\n", + "# See: https://plotly.com/python/imshow/\n", + "life_change = life_exp - life_exp.isel(year=0)\n", + "\n", + "fig = xpx(life_change).imshow(\n", + " color_continuous_scale=\"RdBu\",\n", + " color_continuous_midpoint=0,\n", + " aspect=\"auto\",\n", + " title=\"Life Expectancy Change Since 1952\",\n", ")\n", "fig" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Advanced Dimension Assignment\n", - "\n", - "### Complex Slot Assignments" - ] - }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "np.random.seed(42)\n", - "\n", - "da_complex = xr.DataArray(\n", - " np.random.randn(20, 3, 2, 2),\n", - " dims=[\"time\", \"city\", \"scenario\", \"model\"],\n", - " coords={\n", - " \"time\": pd.date_range(\"2024-01-01\", periods=20),\n", - " \"city\": [\"NYC\", \"LA\", \"Chicago\"],\n", - " \"scenario\": [\"SSP2\", \"SSP5\"],\n", - " \"model\": [\"GCM-A\", \"GCM-B\"],\n", - " },\n", - " name=\"projection\",\n", - ")\n", - "\n", - "# Use line_dash for one dimension, color for another\n", - "fig = xpx(da_complex.sel(city=\"NYC\")).line(\n", - " color=\"scenario\",\n", - " line_dash=\"model\",\n", - " title=\"Multiple Visual Encodings\",\n", + "# px.bar kwargs: barmode, text_auto\n", + "# See: https://plotly.com/python/bar-charts/\n", + "fig = xpx(life_exp.sel(year=[1952, 1982, 2007])).bar(\n", + " barmode=\"group\",\n", + " text_auto=\".1f\",\n", + " title=\"Life Expectancy\",\n", ")\n", "fig" ] @@ -152,23 +118,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Reducing Dimensions Before Plotting\n", + "## Modifying Figures After Creation\n", "\n", - "When you have more dimensions than slots, reduce them first:" + "All methods return a Plotly `Figure`. Use the [Figure API](https://plotly.com/python/creating-and-updating-figures/) to customize:" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "# Average over model dimension\n", - "fig = xpx(da_complex.mean(\"model\")).line(\n", - " facet_col=\"city\",\n", - " title=\"Ensemble Mean by City\",\n", - ")\n", - "fig" + "### update_layout\n", + "\n", + "Modify [layout properties](https://plotly.com/python/reference/layout/): legend, margins, fonts, etc." ] }, { @@ -177,10 +138,18 @@ "metadata": {}, "outputs": [], "source": [ - "# Select a specific slice\n", - "fig = xpx(da_complex.sel(scenario=\"SSP5\", model=\"GCM-A\")).line(\n", - " facet_col=\"city\",\n", - " title=\"SSP5 / GCM-A Projections\",\n", + "fig = xpx(stocks).line()\n", + "\n", + "fig.update_layout(\n", + " title={\"text\": \"Stock Prices\", \"x\": 0.5, \"font\": {\"size\": 24}},\n", + " legend={\n", + " \"orientation\": \"h\",\n", + " \"yanchor\": \"bottom\",\n", + " \"y\": 1.02,\n", + " \"xanchor\": \"center\",\n", + " \"x\": 0.5,\n", + " },\n", + " margin={\"t\": 80},\n", ")\n", "fig" ] @@ -189,14 +158,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Custom Styling" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Themes" + "### update_traces\n", + "\n", + "Modify [trace properties](https://plotly.com/python/reference/): line width, markers, opacity, etc." ] }, { @@ -205,34 +169,21 @@ "metadata": {}, "outputs": [], "source": [ - "da_simple = da.sel(station=\"Urban\")\n", - "\n", - "fig = xpx(da_simple).line(\n", - " template=\"plotly_dark\",\n", - " title=\"Dark Theme\",\n", - ")\n", + "fig = xpx(stocks).line()\n", + "fig.update_traces(line={\"width\": 3})\n", + "fig.update_layout(title=\"Thicker Lines\")\n", "fig" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Custom Colors" - ] - }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "import plotly.express as px\n", - "\n", - "fig = xpx(da).line(\n", - " color_discrete_sequence=px.colors.qualitative.Set2,\n", - " title=\"Custom Color Palette\",\n", - ")\n", + "fig = xpx(stocks).scatter()\n", + "fig.update_traces(marker={\"size\": 12, \"opacity\": 0.7, \"line\": {\"width\": 1, \"color\": \"white\"}})\n", + "fig.update_layout(title=\"Custom Markers\")\n", "fig" ] }, @@ -240,7 +191,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Heatmap Colorscales" + "### update_xaxes / update_yaxes\n", + "\n", + "Modify [axis properties](https://plotly.com/python/axes/): range slider, tick format, etc." ] }, { @@ -249,18 +202,9 @@ "metadata": {}, "outputs": [], "source": [ - "da_2d = xr.DataArray(\n", - " np.random.randn(20, 30),\n", - " dims=[\"lat\", \"lon\"],\n", - " name=\"anomaly\",\n", - ")\n", - "\n", - "# Diverging colorscale centered at zero\n", - "fig = xpx(da_2d).imshow(\n", - " color_continuous_scale=\"RdBu_r\",\n", - " color_continuous_midpoint=0,\n", - " title=\"Diverging Colorscale\",\n", - ")\n", + "fig = xpx(stocks).line(title=\"With Range Slider\")\n", + "fig.update_xaxes(rangeslider_visible=True)\n", + "fig.update_yaxes(tickformat=\".0%\")\n", "fig" ] }, @@ -268,9 +212,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Post-Creation Customization\n", + "### Adding Shapes and Annotations\n", "\n", - "All plots return Plotly `Figure` objects that can be extensively customized:" + "Add reference lines, [shapes](https://plotly.com/python/shapes/), and [annotations](https://plotly.com/python/text-and-annotations/):" ] }, { @@ -279,23 +223,20 @@ "metadata": {}, "outputs": [], "source": [ - "fig = xpx(da).line()\n", + "fig = xpx(stocks).line(title=\"With Reference Lines and Annotations\")\n", "\n", - "# Add horizontal reference line\n", - "fig.add_hline(y=15, line_dash=\"dash\", line_color=\"gray\", annotation_text=\"Reference\")\n", + "fig.add_hline(y=1.0, line_dash=\"dash\", line_color=\"gray\", annotation_text=\"Baseline\")\n", + "fig.add_vline(x=\"2018-10-01\", line_dash=\"dot\", line_color=\"red\")\n", "\n", - "# Update layout\n", - "fig.update_layout(\n", - " title=\"Temperature with Reference Line\",\n", - " legend={\n", - " \"orientation\": \"h\",\n", - " \"yanchor\": \"bottom\",\n", - " \"y\": 1.02,\n", - " \"xanchor\": \"right\",\n", - " \"x\": 1,\n", - " },\n", + "fig.add_annotation(\n", + " x=\"2018-10-01\",\n", + " y=1.4,\n", + " text=\"Market correction\",\n", + " showarrow=True,\n", + " arrowhead=2,\n", + " ax=40,\n", + " ay=-40,\n", ")\n", - "\n", "fig" ] }, @@ -303,7 +244,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Modifying Traces" + "### Adding Traces with Graph Objects\n", + "\n", + "Add custom traces using [Plotly Graph Objects](https://plotly.com/python/graph-objects/):" ] }, { @@ -312,12 +255,21 @@ "metadata": {}, "outputs": [], "source": [ - "fig = xpx(da).line()\n", + "fig = xpx(stocks.sel(company=\"GOOG\")).line(title=\"GOOG with Moving Average\")\n", "\n", - "# Make all lines thicker\n", - "fig.update_traces(line_width=3)\n", + "# Add moving average as a new trace\n", + "goog = stocks.sel(company=\"GOOG\")\n", + "ma_20 = goog.rolling(date=20, center=True).mean()\n", "\n", - "fig.update_layout(title=\"Thicker Lines\")\n", + "fig.add_trace(\n", + " go.Scatter(\n", + " x=ma_20.coords[\"date\"].values,\n", + " y=ma_20.values,\n", + " mode=\"lines\",\n", + " name=\"20-day MA\",\n", + " line={\"dash\": \"dash\", \"color\": \"red\", \"width\": 2},\n", + " )\n", + ")\n", "fig" ] }, @@ -325,27 +277,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Exporting Figures" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Interactive HTML\n", + "## Exporting Figures\n", "\n", - "```python\n", - "fig.write_html(\"interactive_plot.html\")\n", - "```\n", - "\n", - "### Static Images\n", - "\n", - "Requires `kaleido`: `pip install kaleido`\n", + "See [static image export](https://plotly.com/python/static-image-export/) and [HTML export](https://plotly.com/python/interactive-html-export/).\n", "\n", "```python\n", - "fig.write_image(\"plot.png\", scale=2) # High resolution\n", - "fig.write_image(\"plot.svg\") # Vector format\n", - "fig.write_image(\"plot.pdf\") # PDF\n", + "# Interactive HTML\n", + "fig.write_html(\"plot.html\")\n", + "\n", + "# Static images (requires: pip install kaleido)\n", + "fig.write_image(\"plot.png\", scale=2)\n", + "fig.write_image(\"plot.svg\")\n", + "fig.write_image(\"plot.pdf\")\n", "```" ] }, @@ -353,59 +296,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Integration Examples" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### With xarray operations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Rolling mean\n", - "da_smooth = da.rolling(time=7, center=True).mean()\n", + "## More Resources\n", "\n", - "fig = xpx(da_smooth).line(\n", - " title=\"7-Day Rolling Mean\",\n", - ")\n", - "fig" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Groupby operations\n", - "da_monthly = xr.DataArray(\n", - " np.random.randn(365, 3).cumsum(axis=0),\n", - " dims=[\"time\", \"category\"],\n", - " coords={\n", - " \"time\": pd.date_range(\"2024-01-01\", periods=365),\n", - " \"category\": [\"A\", \"B\", \"C\"],\n", - " },\n", - " name=\"value\",\n", - ")\n", - "\n", - "monthly_mean = da_monthly.groupby(\"time.month\").mean()\n", - "\n", - "fig = xpx(monthly_mean).line(\n", - " title=\"Monthly Climatology\",\n", - ")\n", - "fig.update_xaxes(\n", - " tickmode=\"array\",\n", - " tickvals=list(range(1, 13)),\n", - " ticktext=[\"Jan\", \"Feb\", \"Mar\", \"Apr\", \"May\", \"Jun\", \"Jul\", \"Aug\", \"Sep\", \"Oct\", \"Nov\", \"Dec\"],\n", - ")\n", - "fig" + "- [Plotly Express API](https://plotly.com/python-api-reference/plotly.express.html)\n", + "- [Figure Reference](https://plotly.com/python/reference/)\n", + "- [Plotly Tutorials](https://plotly.com/python/)" ] } ], diff --git a/docs/examples/plot-types.ipynb b/docs/examples/plot-types.ipynb index 94ab5b7..1f068a7 100644 --- a/docs/examples/plot-types.ipynb +++ b/docs/examples/plot-types.ipynb @@ -15,8 +15,7 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", - "import pandas as pd\n", + "import plotly.express as px\n", "import xarray as xr\n", "\n", "from xarray_plotly import config, xpx\n", @@ -28,9 +27,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Sample Data\n", + "## Load Sample Data\n", "\n", - "Let's create some sample data to work with:" + "We'll use plotly's built-in datasets converted to xarray:\n", + "- **stocks**: Tech company stock prices over time\n", + "- **gapminder**: Country statistics (life expectancy, GDP, population) by year" ] }, { @@ -39,40 +40,62 @@ "metadata": {}, "outputs": [], "source": [ - "np.random.seed(42)\n", + "# Stock prices: 2D (date, company)\n", + "df_stocks = px.data.stocks().set_index(\"date\")\n", + "df_stocks.index = df_stocks.index.astype(\"datetime64[ns]\")\n", "\n", - "# Time series data\n", - "da_ts = xr.DataArray(\n", - " np.random.randn(30, 3).cumsum(axis=0),\n", - " dims=[\"time\", \"category\"],\n", - " coords={\n", - " \"time\": pd.date_range(\"2024-01-01\", periods=30),\n", - " \"category\": [\"A\", \"B\", \"C\"],\n", - " },\n", - " name=\"value\",\n", + "stocks = xr.DataArray(\n", + " df_stocks.values,\n", + " dims=[\"date\", \"company\"],\n", + " coords={\"date\": df_stocks.index, \"company\": df_stocks.columns.tolist()},\n", + " name=\"price\",\n", + ")\n", + "print(f\"stocks: {dict(stocks.sizes)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Gapminder: pivot to create multi-dimensional arrays\n", + "df_gap = px.data.gapminder()\n", + "\n", + "# Life expectancy: 2D (year, country) - select a few countries\n", + "countries = [\"United States\", \"China\", \"Germany\", \"Brazil\", \"Nigeria\"]\n", + "df_life = df_gap[df_gap[\"country\"].isin(countries)].pivot(\n", + " index=\"year\", columns=\"country\", values=\"lifeExp\"\n", ")\n", "\n", - "# 2D grid data\n", - "da_2d = xr.DataArray(\n", - " np.random.rand(20, 30),\n", - " dims=[\"lat\", \"lon\"],\n", - " coords={\n", - " \"lat\": np.linspace(-90, 90, 20),\n", - " \"lon\": np.linspace(-180, 180, 30),\n", - " },\n", - " name=\"temperature\",\n", + "life_exp = xr.DataArray(\n", + " df_life.values,\n", + " dims=[\"year\", \"country\"],\n", + " coords={\"year\": df_life.index, \"country\": df_life.columns.tolist()},\n", + " name=\"life_expectancy\",\n", + " attrs={\"units\": \"years\"},\n", ")\n", + "print(f\"life_exp: {dict(life_exp.sizes)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# GDP per capita by continent and year (aggregated)\n", + "df_continent = df_gap.groupby([\"continent\", \"year\"])[\"gdpPercap\"].mean().reset_index()\n", + "df_gdp = df_continent.pivot(index=\"year\", columns=\"continent\", values=\"gdpPercap\")\n", "\n", - "# Categorical data\n", - "da_cat = xr.DataArray(\n", - " np.random.rand(4, 3) * 100,\n", - " dims=[\"product\", \"region\"],\n", - " coords={\n", - " \"product\": [\"Widget\", \"Gadget\", \"Gizmo\", \"Thingamajig\"],\n", - " \"region\": [\"North\", \"South\", \"West\"],\n", - " },\n", - " name=\"sales\",\n", - ")" + "gdp = xr.DataArray(\n", + " df_gdp.values,\n", + " dims=[\"year\", \"continent\"],\n", + " coords={\"year\": df_gdp.index, \"continent\": df_gdp.columns.tolist()},\n", + " name=\"gdp_per_capita\",\n", + " attrs={\"units\": \"USD\"},\n", + ")\n", + "print(f\"gdp: {dict(gdp.sizes)}\")" ] }, { @@ -90,24 +113,17 @@ "metadata": {}, "outputs": [], "source": [ - "fig = xpx(da_ts).line(title=\"Line Plot\")\n", + "fig = xpx(stocks).line(title=\"Stock Prices Over Time\")\n", "fig" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### With markers" - ] - }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "fig = xpx(da_ts).line(markers=True, title=\"Line Plot with Markers\")\n", + "fig = xpx(life_exp).line(title=\"Life Expectancy by Country\", markers=True)\n", "fig" ] }, @@ -126,7 +142,8 @@ "metadata": {}, "outputs": [], "source": [ - "fig = xpx(da_cat).bar(title=\"Stacked Bar Chart\")\n", + "# GDP by continent for selected years\n", + "fig = xpx(gdp.sel(year=[1952, 1977, 2007])).bar(title=\"GDP per Capita by Continent\")\n", "fig" ] }, @@ -143,7 +160,7 @@ "metadata": {}, "outputs": [], "source": [ - "fig = xpx(da_cat).bar(barmode=\"group\", title=\"Grouped Bar Chart\")\n", + "fig = xpx(gdp.sel(year=[1952, 1977, 2007])).bar(barmode=\"group\", title=\"GDP per Capita (Grouped)\")\n", "fig" ] }, @@ -162,18 +179,19 @@ "metadata": {}, "outputs": [], "source": [ - "# Use absolute values for stacking to make sense\n", - "da_positive = xr.DataArray(\n", - " np.abs(np.random.randn(30, 3)) * 10,\n", - " dims=[\"time\", \"source\"],\n", - " coords={\n", - " \"time\": pd.date_range(\"2024-01-01\", periods=30),\n", - " \"source\": [\"Solar\", \"Wind\", \"Hydro\"],\n", - " },\n", - " name=\"energy\",\n", + "# Population by continent over time\n", + "df_pop = df_gap.groupby([\"continent\", \"year\"])[\"pop\"].sum().reset_index()\n", + "df_pop_pivot = df_pop.pivot(index=\"year\", columns=\"continent\", values=\"pop\")\n", + "\n", + "population = xr.DataArray(\n", + " df_pop_pivot.values / 1e9, # Convert to billions\n", + " dims=[\"year\", \"continent\"],\n", + " coords={\"year\": df_pop_pivot.index, \"continent\": df_pop_pivot.columns.tolist()},\n", + " name=\"population\",\n", + " attrs={\"units\": \"billions\"},\n", ")\n", "\n", - "fig = xpx(da_positive).area(title=\"Stacked Area Chart\")\n", + "fig = xpx(population).area(title=\"World Population by Continent\")\n", "fig" ] }, @@ -192,7 +210,7 @@ "metadata": {}, "outputs": [], "source": [ - "fig = xpx(da_ts).scatter(title=\"Scatter Plot\")\n", + "fig = xpx(stocks).scatter(title=\"Stock Prices Scatter\")\n", "fig" ] }, @@ -200,9 +218,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Dimension vs Dimension (Geographic style)\n", + "## Box Plot\n", "\n", - "You can plot one dimension against another, with values shown as color:" + "Best for showing distributions:" ] }, { @@ -211,34 +229,19 @@ "metadata": {}, "outputs": [], "source": [ - "fig = xpx(da_2d).scatter(x=\"lon\", y=\"lat\", color=\"value\", title=\"Lat/Lon Scatter\")\n", + "# Stock price distributions by company\n", + "fig = xpx(stocks).box(title=\"Stock Price Distributions\")\n", "fig" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Box Plot\n", - "\n", - "Best for showing distributions:" - ] - }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "# Create data with more samples\n", - "da_dist = xr.DataArray(\n", - " np.random.randn(100, 4) + np.array([0, 1, 2, 3]),\n", - " dims=[\"sample\", \"group\"],\n", - " coords={\"group\": [\"Control\", \"Treatment A\", \"Treatment B\", \"Treatment C\"]},\n", - " name=\"response\",\n", - ")\n", - "\n", - "fig = xpx(da_dist).box(title=\"Box Plot\")\n", + "# Life expectancy distributions by country\n", + "fig = xpx(life_exp).box(title=\"Life Expectancy Distribution by Country\")\n", "fig" ] }, @@ -257,7 +260,7 @@ "metadata": {}, "outputs": [], "source": [ - "fig = xpx(da_2d).imshow(title=\"Heatmap\")\n", + "fig = xpx(life_exp).imshow(title=\"Life Expectancy Heatmap\")\n", "fig" ] }, @@ -274,9 +277,9 @@ "metadata": {}, "outputs": [], "source": [ - "fig = xpx(da_2d).imshow(\n", + "fig = xpx(gdp).imshow(\n", " color_continuous_scale=\"Viridis\",\n", - " title=\"Heatmap with Viridis colorscale\",\n", + " title=\"GDP per Capita by Continent and Year\",\n", ")\n", "fig" ] @@ -296,22 +299,27 @@ "metadata": {}, "outputs": [], "source": [ - "# 3D data for faceting\n", - "da_3d = xr.DataArray(\n", - " np.random.randn(30, 3, 2).cumsum(axis=0),\n", - " dims=[\"time\", \"city\", \"scenario\"],\n", - " coords={\n", - " \"time\": pd.date_range(\"2024-01-01\", periods=30),\n", - " \"city\": [\"NYC\", \"LA\", \"Chicago\"],\n", - " \"scenario\": [\"Low\", \"High\"],\n", - " },\n", - " name=\"value\",\n", + "# Create 3D data: life expectancy by year, country, and metric\n", + "# We'll add GDP as another \"metric\" dimension\n", + "df_metrics = df_gap[df_gap[\"country\"].isin(countries)].pivot(\n", + " index=\"year\", columns=\"country\", values=\"gdpPercap\"\n", + ")\n", + "gdp_countries = xr.DataArray(\n", + " df_metrics.values,\n", + " dims=[\"year\", \"country\"],\n", + " coords={\"year\": df_metrics.index, \"country\": df_metrics.columns.tolist()},\n", + " name=\"gdp_per_capita\",\n", + ")\n", + "\n", + "# Combine into 3D array\n", + "combined = xr.concat(\n", + " [life_exp, gdp_countries / 1000], # Scale GDP to thousands\n", + " dim=xr.Variable(\"metric\", [\"Life Exp (years)\", \"GDP (thousands USD)\"]),\n", ")\n", "\n", - "fig = xpx(da_3d).line(\n", - " facet_col=\"city\",\n", - " facet_row=\"scenario\",\n", - " title=\"Faceted Line Plot\",\n", + "fig = xpx(combined).line(\n", + " facet_col=\"metric\",\n", + " title=\"Country Comparison: Life Expectancy and GDP\",\n", ")\n", "fig" ] @@ -331,35 +339,28 @@ "metadata": {}, "outputs": [], "source": [ - "# Create monthly data\n", - "da_monthly = xr.DataArray(\n", - " np.random.rand(12, 4) * 100,\n", - " dims=[\"month\", \"product\"],\n", - " coords={\n", - " \"month\": [\n", - " \"Jan\",\n", - " \"Feb\",\n", - " \"Mar\",\n", - " \"Apr\",\n", - " \"May\",\n", - " \"Jun\",\n", - " \"Jul\",\n", - " \"Aug\",\n", - " \"Sep\",\n", - " \"Oct\",\n", - " \"Nov\",\n", - " \"Dec\",\n", - " ],\n", - " \"product\": [\"A\", \"B\", \"C\", \"D\"],\n", - " },\n", - " name=\"sales\",\n", + "# Animated bar chart of GDP by continent over time\n", + "fig = xpx(gdp).bar(\n", + " x=\"continent\",\n", + " animation_frame=\"year\",\n", + " title=\"GDP per Capita by Continent (Animated)\",\n", + " range_y=[0, 35000],\n", ")\n", - "\n", - "fig = xpx(da_monthly).bar(\n", - " x=\"product\",\n", - " animation_frame=\"month\",\n", - " title=\"Monthly Sales (Animated)\",\n", - " range_y=[0, 120],\n", + "fig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Animated line showing life expectancy evolution\n", + "fig = xpx(life_exp).bar(\n", + " x=\"country\",\n", + " animation_frame=\"year\",\n", + " title=\"Life Expectancy by Country (Animated)\",\n", + " range_y=[0, 85],\n", ")\n", "fig" ] diff --git a/docs/getting-started.ipynb b/docs/getting-started.ipynb index cd80907..11c0bfa 100644 --- a/docs/getting-started.ipynb +++ b/docs/getting-started.ipynb @@ -41,8 +41,7 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", - "import pandas as pd\n", + "import plotly.express as px\n", "import xarray as xr\n", "\n", "from xarray_plotly import config, xpx\n", @@ -54,9 +53,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Create Sample Data\n", + "## Load Sample Data\n", "\n", - "Let's create a DataArray with multiple dimensions:" + "We'll use plotly's built-in stock price data and convert it to an xarray DataArray:" ] }, { @@ -65,22 +64,20 @@ "metadata": {}, "outputs": [], "source": [ - "# Create sample climate data\n", - "np.random.seed(42)\n", - "\n", - "da = xr.DataArray(\n", - " np.random.randn(50, 3, 2).cumsum(axis=0), # Random walk\n", - " dims=[\"time\", \"city\", \"scenario\"],\n", - " coords={\n", - " \"time\": pd.date_range(\"2020-01-01\", periods=50, freq=\"D\"),\n", - " \"city\": [\"New York\", \"Los Angeles\", \"Chicago\"],\n", - " \"scenario\": [\"baseline\", \"warming\"],\n", - " },\n", - " name=\"temperature\",\n", - " attrs={\"long_name\": \"Temperature Anomaly\", \"units\": \"°C\"},\n", + "# Load stock prices from plotly\n", + "df = px.data.stocks()\n", + "df = df.set_index(\"date\")\n", + "df.index = df.index.astype(\"datetime64[ns]\")\n", + "\n", + "# Convert to xarray DataArray\n", + "stocks = xr.DataArray(\n", + " df.values,\n", + " dims=[\"date\", \"company\"],\n", + " coords={\"date\": df.index, \"company\": df.columns.tolist()},\n", + " name=\"price\",\n", + " attrs={\"long_name\": \"Stock Price\", \"units\": \"normalized\"},\n", ")\n", - "\n", - "da.to_dataframe()" + "stocks" ] }, { @@ -98,8 +95,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Dimensions auto-assign: time→x, city→color, scenario→facet_col\n", - "fig = xpx(da).line()\n", + "# Dimensions auto-assign: date->x, company->color\n", + "fig = xpx(stocks).line()\n", "fig" ] }, @@ -124,9 +121,8 @@ "\n", "| Dimension | Slot |\n", "|-----------|------|\n", - "| time (1st) | x-axis |\n", - "| city (2nd) | color |\n", - "| scenario (3rd) | facet_col |\n", + "| date (1st) | x-axis |\n", + "| company (2nd) | color |\n", "\n", "You can override this with explicit assignments:" ] @@ -137,8 +133,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Put scenario on color, city on facets\n", - "fig = xpx(da).line(color=\"scenario\", facet_col=\"city\")\n", + "# Put company on x-axis, date on color (just first few dates)\n", + "fig = xpx(stocks.isel(date=[0, 25, 50, 75, 100])).bar(x=\"company\", color=\"date\")\n", "fig" ] }, @@ -157,8 +153,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Skip color, so city goes to line_dash instead\n", - "fig = xpx(da.sel(scenario=\"baseline\")).line(color=None)\n", + "# Skip color, so company goes to line_dash instead\n", + "fig = xpx(stocks.sel(company=[\"GOOG\", \"AAPL\", \"MSFT\"])).line(color=None)\n", "fig" ] }, @@ -177,12 +173,12 @@ "metadata": {}, "outputs": [], "source": [ - "fig = xpx(da).line()\n", + "fig = xpx(stocks).line()\n", "\n", "fig.update_layout(\n", - " title=\"Temperature Anomaly Projections\",\n", + " title=\"Tech Stock Performance (2018-2019)\",\n", " template=\"plotly_white\",\n", - " legend_title_text=\"City\",\n", + " legend_title_text=\"Company\",\n", ")\n", "\n", "fig" @@ -201,9 +197,9 @@ "metadata": {}, "outputs": [], "source": [ - "fig = xpx(da).line(\n", - " title=\"Temperature Trends\",\n", - " color_discrete_sequence=[\"#E63946\", \"#457B9D\", \"#2A9D8F\"],\n", + "fig = xpx(stocks).line(\n", + " title=\"Stock Prices\",\n", + " color_discrete_sequence=[\"#E63946\", \"#457B9D\", \"#2A9D8F\", \"#E9C46A\", \"#F4A261\", \"#264653\"],\n", " template=\"simple_white\",\n", ")\n", "fig" diff --git a/pyproject.toml b/pyproject.toml index 5f7f927..145b399 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,19 +32,20 @@ Repository = "https://github.com/FBumann/xarray_plotly" [project.optional-dependencies] dev = [ - "pytest>=7.0", - "pytest-cov>=4.0", - "mypy>=1.0", - "ruff>=0.4", - "pre-commit>=3.0", - "nbstripout>=0.6", + "pytest==8.3.5", + "pytest-cov==6.0.0", + "mypy==1.14.1", + "ruff==0.9.2", + "pre-commit==4.0.1", + "nbstripout==0.8.1", ] docs = [ - "mkdocs>=1.5", - "mkdocs-material>=9.0", - "mkdocstrings[python]>=0.24", - "mkdocs-jupyter>=0.24", - "mkdocs-plotly-plugin>=0.1", + "mkdocs==1.6.1", + "mkdocs-material==9.5.49", + "mkdocstrings[python]==0.27.0", + "mkdocs-jupyter==0.25.1", + "mkdocs-plotly-plugin==0.1.3", + "jupyter==1.1.1", ] [project.entry-points."xarray.backends"]