Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,52 @@ def test_box_all_variables(self) -> None:
"""Test box plot with all variables."""
fig = self.ds.plotly.box()
assert isinstance(fig, go.Figure)


class TestImshowBounds:
"""Tests for imshow global bounds and robust mode."""

def test_imshow_global_bounds(self) -> None:
"""Test that imshow uses global min/max by default."""
da = xr.DataArray(
np.array([[[1, 2], [3, 4]], [[5, 6], [7, 100]]]),
dims=["time", "y", "x"],
)
fig = da.plotly.imshow(animation_frame="time")
# Check coloraxis for zmin/zmax (plotly stores them there)
coloraxis = fig.layout.coloraxis
assert coloraxis.cmin == 1.0
assert coloraxis.cmax == 100.0

def test_imshow_robust_bounds(self) -> None:
"""Test that robust=True uses percentile-based bounds."""
# Create data with outlier
data = np.random.rand(10, 20) * 100
data[0, 0] = 10000 # extreme outlier
da = xr.DataArray(data, dims=["y", "x"])

fig = da.plotly.imshow(robust=True)
# With robust=True, cmax should be much less than the outlier
coloraxis = fig.layout.coloraxis
assert coloraxis.cmax < 10000
assert coloraxis.cmax < 200 # Should be around 98th percentile (~98)

def test_imshow_user_zmin_zmax_override(self) -> None:
"""Test that user-provided zmin/zmax overrides auto bounds."""
da = xr.DataArray(np.random.rand(10, 20) * 100, dims=["y", "x"])
fig = da.plotly.imshow(zmin=0, zmax=50)
coloraxis = fig.layout.coloraxis
assert coloraxis.cmin == 0
assert coloraxis.cmax == 50

def test_imshow_animation_consistent_bounds(self) -> None:
"""Test that animation frames have consistent color bounds."""
da = xr.DataArray(
np.array([[[0, 10], [20, 30]], [[40, 50], [60, 70]]]),
dims=["time", "y", "x"],
)
fig = da.plotly.imshow(animation_frame="time")
# All frames should use global min (0) and max (70)
coloraxis = fig.layout.coloraxis
assert coloraxis.cmin == 0.0
assert coloraxis.cmax == 70.0
4 changes: 4 additions & 0 deletions xarray_plotly/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def imshow(
y: SlotValue = auto,
facet_col: SlotValue = auto,
animation_frame: SlotValue = auto,
robust: bool = False,
**px_kwargs: Any,
) -> go.Figure:
"""Create an interactive heatmap image.
Expand All @@ -261,7 +262,9 @@ def imshow(
y: Dimension for y-axis (rows). Default: first dimension.
facet_col: Dimension for subplot columns. Default: third dimension.
animation_frame: Dimension for animation. Default: fourth dimension.
robust: If True, use 2nd/98th percentiles for color bounds (handles outliers).
**px_kwargs: Additional arguments passed to `plotly.express.imshow()`.
Use `zmin` and `zmax` to manually set color scale bounds.

Returns:
Interactive Plotly Figure.
Expand All @@ -272,6 +275,7 @@ def imshow(
y=y,
facet_col=facet_col,
animation_frame=animation_frame,
robust=robust,
**px_kwargs,
)

Expand Down
20 changes: 20 additions & 0 deletions xarray_plotly/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from typing import TYPE_CHECKING, Any

import numpy as np
import plotly.express as px

from xarray_plotly.common import (
Expand Down Expand Up @@ -398,6 +399,7 @@ def imshow(
y: SlotValue = auto,
facet_col: SlotValue = auto,
animation_frame: SlotValue = auto,
robust: bool = False,
**px_kwargs: Any,
) -> go.Figure:
"""
Expand All @@ -418,8 +420,12 @@ def imshow(
Dimension for subplot columns. Default: third dimension.
animation_frame
Dimension for animation. Default: fourth dimension.
robust
If True, compute color bounds using 2nd and 98th percentiles
for robustness against outliers. Default: False.
**px_kwargs
Additional arguments passed to `plotly.express.imshow()`.
Use `zmin` and `zmax` to manually set color scale bounds.

Returns
-------
Expand All @@ -440,6 +446,20 @@ def imshow(
]
plot_data = darray.transpose(*transpose_order) if transpose_order else darray

# Compute global color bounds if not provided
if "zmin" not in px_kwargs or "zmax" not in px_kwargs:
values = plot_data.values
if robust:
# Use percentiles for outlier robustness
zmin = float(np.nanpercentile(values, 2))
zmax = float(np.nanpercentile(values, 98))
else:
# Use global min/max across all data
zmin = float(np.nanmin(values))
zmax = float(np.nanmax(values))
px_kwargs.setdefault("zmin", zmin)
px_kwargs.setdefault("zmax", zmax)

return px.imshow(
plot_data,
facet_col=slots.get("facet_col"),
Expand Down