Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 89 additions & 3 deletions tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,26 @@
class TestXpxFunction:
"""Tests for the xpx() function."""

def test_xpx_returns_accessor(self) -> None:
"""Test that xpx() returns a DataArrayPlotlyAccessor."""
def test_xpx_returns_dataarray_accessor(self) -> None:
"""Test that xpx() returns a DataArrayPlotlyAccessor for DataArray."""
da = xr.DataArray(np.random.rand(10), dims=["time"])
accessor = xpx(da)
assert hasattr(accessor, "line")
assert hasattr(accessor, "bar")
assert hasattr(accessor, "scatter")
assert hasattr(accessor, "imshow")

def test_xpx_equivalent_to_accessor(self) -> None:
def test_xpx_returns_dataset_accessor(self) -> None:
"""Test that xpx() returns a DatasetPlotlyAccessor for Dataset."""
ds = xr.Dataset({"temp": (["time"], np.random.rand(10))})
accessor = xpx(ds)
assert hasattr(accessor, "line")
assert hasattr(accessor, "bar")
assert hasattr(accessor, "scatter")
# Dataset accessor should not have imshow
assert not hasattr(accessor, "imshow")

def test_xpx_dataarray_equivalent_to_accessor(self) -> None:
"""Test that xpx(da).line() works the same as da.plotly.line()."""
da = xr.DataArray(
np.random.rand(10, 3),
Expand All @@ -36,6 +47,19 @@ def test_xpx_equivalent_to_accessor(self) -> None:
assert isinstance(fig1, go.Figure)
assert isinstance(fig2, go.Figure)

def test_xpx_dataset_equivalent_to_accessor(self) -> None:
"""Test that xpx(ds).line() works the same as ds.plotly.line()."""
ds = xr.Dataset(
{
"temperature": (["time", "city"], np.random.rand(10, 3)),
"humidity": (["time", "city"], np.random.rand(10, 3)),
}
)
fig1 = xpx(ds).line()
fig2 = ds.plotly.line()
assert isinstance(fig1, go.Figure)
assert isinstance(fig2, go.Figure)


class TestDataArrayPxplot:
"""Tests for DataArray.plotly accessor."""
Expand Down Expand Up @@ -206,3 +230,65 @@ def test_value_label_from_attrs(self) -> None:
"""Test that value labels are extracted from attributes."""
fig = self.da.plotly.line()
assert isinstance(fig, go.Figure)


class TestDatasetPlotlyAccessor:
"""Tests for Dataset.plotly accessor."""

@pytest.fixture(autouse=True)
def setup(self) -> None:
"""Set up test data."""
self.ds = xr.Dataset(
{
"temperature": (["time", "city"], np.random.rand(10, 3)),
"humidity": (["time", "city"], np.random.rand(10, 3)),
},
coords={
"time": pd.date_range("2020", periods=10),
"city": ["NYC", "LA", "Chicago"],
},
)

def test_accessor_exists(self) -> None:
"""Test that plotly accessor is available on Dataset."""
assert hasattr(self.ds, "plotly")
assert hasattr(self.ds.plotly, "line")
assert hasattr(self.ds.plotly, "bar")
assert hasattr(self.ds.plotly, "area")
assert hasattr(self.ds.plotly, "scatter")
assert hasattr(self.ds.plotly, "box")

def test_line_all_variables(self) -> None:
"""Test line plot with all variables."""
fig = self.ds.plotly.line()
assert isinstance(fig, go.Figure)

def test_line_single_variable(self) -> None:
"""Test line plot with single variable."""
fig = self.ds.plotly.line(var="temperature")
assert isinstance(fig, go.Figure)

def test_line_variable_as_facet(self) -> None:
"""Test line plot with variable as facet."""
fig = self.ds.plotly.line(facet_col="variable")
assert isinstance(fig, go.Figure)

def test_bar_all_variables(self) -> None:
"""Test bar plot with all variables."""
fig = self.ds.plotly.bar()
assert isinstance(fig, go.Figure)

def test_area_all_variables(self) -> None:
"""Test area plot with all variables."""
fig = self.ds.plotly.area()
assert isinstance(fig, go.Figure)

def test_scatter_all_variables(self) -> None:
"""Test scatter plot with all variables."""
fig = self.ds.plotly.scatter()
assert isinstance(fig, go.Figure)

def test_box_all_variables(self) -> None:
"""Test box plot with all variables."""
fig = self.ds.plotly.box()
assert isinstance(fig, go.Figure)
45 changes: 36 additions & 9 deletions xarray_plotly/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Interactive Plotly Express plotting for xarray.

This package provides a `plotly` accessor for xarray DataArray objects,
This package provides a `plotly` accessor for xarray DataArray and Dataset objects,
enabling interactive visualization with Plotly Express.

Features:
- **Interactive plots**: Zoom, pan, hover, toggle traces
- **Automatic dimension assignment**: Dimensions fill slots (x, color, facet) by position
- **Multiple plot types**: line, bar, area, scatter, box, imshow
- **Dataset support**: Plot all variables at once with "variable" dimension
- **Faceting and animation**: Built-in subplot grids and animated plots
- **Customizable**: Returns Plotly Figure objects for further modification

Expand All @@ -15,11 +16,13 @@

import xarray_plotly
fig = da.plotly.line()
fig = ds.plotly.line() # Dataset: all variables

Function style (recommended for IDE completion)::

from xarray_plotly import xpx
fig = xpx(da).line()
fig = xpx(ds).line() # Dataset: all variables

Example:
```python
Expand All @@ -34,48 +37,72 @@
fig = xpx(da).line() # Auto: time->x, city->color, scenario->facet_col
fig = xpx(da).line(x="time", color="scenario") # Explicit
fig = xpx(da).line(color=None) # Skip slot

# Dataset: plot all variables (accessor or xpx)
ds = xr.Dataset({"temp": da, "precip": da})
fig = xpx(ds).line() # "variable" dimension for color
fig = xpx(ds).line(facet_col="variable") # Facet by variable
```
"""

from importlib.metadata import version
from typing import overload

from xarray import DataArray, register_dataarray_accessor
from xarray import DataArray, Dataset, register_dataarray_accessor, register_dataset_accessor

from xarray_plotly import config
from xarray_plotly.accessor import DataArrayPlotlyAccessor
from xarray_plotly.accessor import DataArrayPlotlyAccessor, DatasetPlotlyAccessor
from xarray_plotly.common import SLOT_ORDERS, auto

__all__ = [
"SLOT_ORDERS",
"DataArrayPlotlyAccessor",
"DatasetPlotlyAccessor",
"auto",
"config",
"xpx",
]


def xpx(da: DataArray) -> DataArrayPlotlyAccessor:
"""Get the plotly accessor for a DataArray with full IDE code completion.
@overload
def xpx(data: DataArray) -> DataArrayPlotlyAccessor: ...


@overload
def xpx(data: Dataset) -> DatasetPlotlyAccessor: ...


This is an alternative to `da.plotly` that provides proper type hints
def xpx(data: DataArray | Dataset) -> DataArrayPlotlyAccessor | DatasetPlotlyAccessor:
"""Get the plotly accessor for a DataArray or Dataset with full IDE code completion.

This is an alternative to `da.plotly` / `ds.plotly` that provides proper type hints
and code completion in IDEs.

Args:
da: The DataArray to plot.
data: The DataArray or Dataset to plot.

Returns:
The accessor with plotting methods (line, bar, area, scatter, box, imshow).

Example:
```python
from xarray_plotly import xpx

# DataArray
fig = xpx(da).line() # Full code completion works here

# Dataset
fig = xpx(ds).line() # Plots all variables
fig = xpx(ds).line(var="temperature") # Single variable
```
"""
return DataArrayPlotlyAccessor(da)
if isinstance(data, Dataset):
return DatasetPlotlyAccessor(data)
return DataArrayPlotlyAccessor(data)


__version__ = version("xarray_plotly")

# Register the accessor
# Register the accessors
register_dataarray_accessor("plotly")(DataArrayPlotlyAccessor)
register_dataset_accessor("plotly")(DatasetPlotlyAccessor)
Loading