From b248ae2e1dc4d80c15d93b206cda3d97523de323 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 10 Jan 2026 15:24:06 +0100 Subject: [PATCH] 1. xarray_plotly/plotting.py: Added robust parameter to imshow() with global bounds computation 2. xarray_plotly/accessor.py: Added robust parameter to accessor method 3. tests/test_accessor.py: Added 4 tests for bounds behavior New behavior: - Default: Global min/max across all data (fixes animation consistency) - robust=True: Uses 2nd/98th percentile (handles outliers) - zmin/zmax: User override still works --- tests/test_accessor.py | 49 +++++++++++++++++++++++++++++++++++++++ xarray_plotly/accessor.py | 4 ++++ xarray_plotly/plotting.py | 20 ++++++++++++++++ 3 files changed, 73 insertions(+) diff --git a/tests/test_accessor.py b/tests/test_accessor.py index ddef38c..9786112 100644 --- a/tests/test_accessor.py +++ b/tests/test_accessor.py @@ -292,3 +292,52 @@ def test_box_all_variables(self) -> None: """Test box plot with all variables.""" fig = self.ds.plotly.box() assert isinstance(fig, go.Figure) + + +class TestImshowBounds: + """Tests for imshow global bounds and robust mode.""" + + def test_imshow_global_bounds(self) -> None: + """Test that imshow uses global min/max by default.""" + da = xr.DataArray( + np.array([[[1, 2], [3, 4]], [[5, 6], [7, 100]]]), + dims=["time", "y", "x"], + ) + fig = da.plotly.imshow(animation_frame="time") + # Check coloraxis for zmin/zmax (plotly stores them there) + coloraxis = fig.layout.coloraxis + assert coloraxis.cmin == 1.0 + assert coloraxis.cmax == 100.0 + + def test_imshow_robust_bounds(self) -> None: + """Test that robust=True uses percentile-based bounds.""" + # Create data with outlier + data = np.random.rand(10, 20) * 100 + data[0, 0] = 10000 # extreme outlier + da = xr.DataArray(data, dims=["y", "x"]) + + fig = da.plotly.imshow(robust=True) + # With robust=True, cmax should be much less than the outlier + coloraxis = fig.layout.coloraxis + assert coloraxis.cmax < 10000 + assert coloraxis.cmax < 200 # Should be around 98th percentile (~98) + + def test_imshow_user_zmin_zmax_override(self) -> None: + """Test that user-provided zmin/zmax overrides auto bounds.""" + da = xr.DataArray(np.random.rand(10, 20) * 100, dims=["y", "x"]) + fig = da.plotly.imshow(zmin=0, zmax=50) + coloraxis = fig.layout.coloraxis + assert coloraxis.cmin == 0 + assert coloraxis.cmax == 50 + + def test_imshow_animation_consistent_bounds(self) -> None: + """Test that animation frames have consistent color bounds.""" + da = xr.DataArray( + np.array([[[0, 10], [20, 30]], [[40, 50], [60, 70]]]), + dims=["time", "y", "x"], + ) + fig = da.plotly.imshow(animation_frame="time") + # All frames should use global min (0) and max (70) + coloraxis = fig.layout.coloraxis + assert coloraxis.cmin == 0.0 + assert coloraxis.cmax == 70.0 diff --git a/xarray_plotly/accessor.py b/xarray_plotly/accessor.py index ca66f59..95e1351 100644 --- a/xarray_plotly/accessor.py +++ b/xarray_plotly/accessor.py @@ -250,6 +250,7 @@ def imshow( y: SlotValue = auto, facet_col: SlotValue = auto, animation_frame: SlotValue = auto, + robust: bool = False, **px_kwargs: Any, ) -> go.Figure: """Create an interactive heatmap image. @@ -261,7 +262,9 @@ def imshow( y: Dimension for y-axis (rows). Default: first dimension. facet_col: Dimension for subplot columns. Default: third dimension. animation_frame: Dimension for animation. Default: fourth dimension. + robust: If True, use 2nd/98th percentiles for color bounds (handles outliers). **px_kwargs: Additional arguments passed to `plotly.express.imshow()`. + Use `zmin` and `zmax` to manually set color scale bounds. Returns: Interactive Plotly Figure. @@ -272,6 +275,7 @@ def imshow( y=y, facet_col=facet_col, animation_frame=animation_frame, + robust=robust, **px_kwargs, ) diff --git a/xarray_plotly/plotting.py b/xarray_plotly/plotting.py index a95e702..ed1142b 100644 --- a/xarray_plotly/plotting.py +++ b/xarray_plotly/plotting.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any +import numpy as np import plotly.express as px from xarray_plotly.common import ( @@ -398,6 +399,7 @@ def imshow( y: SlotValue = auto, facet_col: SlotValue = auto, animation_frame: SlotValue = auto, + robust: bool = False, **px_kwargs: Any, ) -> go.Figure: """ @@ -418,8 +420,12 @@ def imshow( Dimension for subplot columns. Default: third dimension. animation_frame Dimension for animation. Default: fourth dimension. + robust + If True, compute color bounds using 2nd and 98th percentiles + for robustness against outliers. Default: False. **px_kwargs Additional arguments passed to `plotly.express.imshow()`. + Use `zmin` and `zmax` to manually set color scale bounds. Returns ------- @@ -440,6 +446,20 @@ def imshow( ] plot_data = darray.transpose(*transpose_order) if transpose_order else darray + # Compute global color bounds if not provided + if "zmin" not in px_kwargs or "zmax" not in px_kwargs: + values = plot_data.values + if robust: + # Use percentiles for outlier robustness + zmin = float(np.nanpercentile(values, 2)) + zmax = float(np.nanpercentile(values, 98)) + else: + # Use global min/max across all data + zmin = float(np.nanmin(values)) + zmax = float(np.nanmax(values)) + px_kwargs.setdefault("zmin", zmin) + px_kwargs.setdefault("zmax", zmax) + return px.imshow( plot_data, facet_col=slots.get("facet_col"),