diff --git a/tests/test_accessor.py b/tests/test_accessor.py index 64f8f67..ddef38c 100644 --- a/tests/test_accessor.py +++ b/tests/test_accessor.py @@ -15,15 +15,26 @@ class TestXpxFunction: """Tests for the xpx() function.""" - def test_xpx_returns_accessor(self) -> None: - """Test that xpx() returns a DataArrayPlotlyAccessor.""" + def test_xpx_returns_dataarray_accessor(self) -> None: + """Test that xpx() returns a DataArrayPlotlyAccessor for DataArray.""" da = xr.DataArray(np.random.rand(10), dims=["time"]) accessor = xpx(da) assert hasattr(accessor, "line") assert hasattr(accessor, "bar") assert hasattr(accessor, "scatter") + assert hasattr(accessor, "imshow") - def test_xpx_equivalent_to_accessor(self) -> None: + def test_xpx_returns_dataset_accessor(self) -> None: + """Test that xpx() returns a DatasetPlotlyAccessor for Dataset.""" + ds = xr.Dataset({"temp": (["time"], np.random.rand(10))}) + accessor = xpx(ds) + assert hasattr(accessor, "line") + assert hasattr(accessor, "bar") + assert hasattr(accessor, "scatter") + # Dataset accessor should not have imshow + assert not hasattr(accessor, "imshow") + + def test_xpx_dataarray_equivalent_to_accessor(self) -> None: """Test that xpx(da).line() works the same as da.plotly.line().""" da = xr.DataArray( np.random.rand(10, 3), @@ -36,6 +47,19 @@ def test_xpx_equivalent_to_accessor(self) -> None: assert isinstance(fig1, go.Figure) assert isinstance(fig2, go.Figure) + def test_xpx_dataset_equivalent_to_accessor(self) -> None: + """Test that xpx(ds).line() works the same as ds.plotly.line().""" + ds = xr.Dataset( + { + "temperature": (["time", "city"], np.random.rand(10, 3)), + "humidity": (["time", "city"], np.random.rand(10, 3)), + } + ) + fig1 = xpx(ds).line() + fig2 = ds.plotly.line() + assert isinstance(fig1, go.Figure) + assert isinstance(fig2, go.Figure) + class TestDataArrayPxplot: """Tests for DataArray.plotly accessor.""" @@ -206,3 +230,65 @@ def test_value_label_from_attrs(self) -> None: """Test that value labels are extracted from attributes.""" fig = self.da.plotly.line() assert isinstance(fig, go.Figure) + + +class TestDatasetPlotlyAccessor: + """Tests for Dataset.plotly accessor.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + """Set up test data.""" + self.ds = xr.Dataset( + { + "temperature": (["time", "city"], np.random.rand(10, 3)), + "humidity": (["time", "city"], np.random.rand(10, 3)), + }, + coords={ + "time": pd.date_range("2020", periods=10), + "city": ["NYC", "LA", "Chicago"], + }, + ) + + def test_accessor_exists(self) -> None: + """Test that plotly accessor is available on Dataset.""" + assert hasattr(self.ds, "plotly") + assert hasattr(self.ds.plotly, "line") + assert hasattr(self.ds.plotly, "bar") + assert hasattr(self.ds.plotly, "area") + assert hasattr(self.ds.plotly, "scatter") + assert hasattr(self.ds.plotly, "box") + + def test_line_all_variables(self) -> None: + """Test line plot with all variables.""" + fig = self.ds.plotly.line() + assert isinstance(fig, go.Figure) + + def test_line_single_variable(self) -> None: + """Test line plot with single variable.""" + fig = self.ds.plotly.line(var="temperature") + assert isinstance(fig, go.Figure) + + def test_line_variable_as_facet(self) -> None: + """Test line plot with variable as facet.""" + fig = self.ds.plotly.line(facet_col="variable") + assert isinstance(fig, go.Figure) + + def test_bar_all_variables(self) -> None: + """Test bar plot with all variables.""" + fig = self.ds.plotly.bar() + assert isinstance(fig, go.Figure) + + def test_area_all_variables(self) -> None: + """Test area plot with all variables.""" + fig = self.ds.plotly.area() + assert isinstance(fig, go.Figure) + + def test_scatter_all_variables(self) -> None: + """Test scatter plot with all variables.""" + fig = self.ds.plotly.scatter() + assert isinstance(fig, go.Figure) + + def test_box_all_variables(self) -> None: + """Test box plot with all variables.""" + fig = self.ds.plotly.box() + assert isinstance(fig, go.Figure) diff --git a/xarray_plotly/__init__.py b/xarray_plotly/__init__.py index 68b3393..7bc9539 100644 --- a/xarray_plotly/__init__.py +++ b/xarray_plotly/__init__.py @@ -1,12 +1,13 @@ """Interactive Plotly Express plotting for xarray. -This package provides a `plotly` accessor for xarray DataArray objects, +This package provides a `plotly` accessor for xarray DataArray and Dataset objects, enabling interactive visualization with Plotly Express. Features: - **Interactive plots**: Zoom, pan, hover, toggle traces - **Automatic dimension assignment**: Dimensions fill slots (x, color, facet) by position - **Multiple plot types**: line, bar, area, scatter, box, imshow + - **Dataset support**: Plot all variables at once with "variable" dimension - **Faceting and animation**: Built-in subplot grids and animated plots - **Customizable**: Returns Plotly Figure objects for further modification @@ -15,11 +16,13 @@ import xarray_plotly fig = da.plotly.line() + fig = ds.plotly.line() # Dataset: all variables Function style (recommended for IDE completion):: from xarray_plotly import xpx fig = xpx(da).line() + fig = xpx(ds).line() # Dataset: all variables Example: ```python @@ -34,34 +37,49 @@ fig = xpx(da).line() # Auto: time->x, city->color, scenario->facet_col fig = xpx(da).line(x="time", color="scenario") # Explicit fig = xpx(da).line(color=None) # Skip slot + + # Dataset: plot all variables (accessor or xpx) + ds = xr.Dataset({"temp": da, "precip": da}) + fig = xpx(ds).line() # "variable" dimension for color + fig = xpx(ds).line(facet_col="variable") # Facet by variable ``` """ from importlib.metadata import version +from typing import overload -from xarray import DataArray, register_dataarray_accessor +from xarray import DataArray, Dataset, register_dataarray_accessor, register_dataset_accessor from xarray_plotly import config -from xarray_plotly.accessor import DataArrayPlotlyAccessor +from xarray_plotly.accessor import DataArrayPlotlyAccessor, DatasetPlotlyAccessor from xarray_plotly.common import SLOT_ORDERS, auto __all__ = [ "SLOT_ORDERS", "DataArrayPlotlyAccessor", + "DatasetPlotlyAccessor", "auto", "config", "xpx", ] -def xpx(da: DataArray) -> DataArrayPlotlyAccessor: - """Get the plotly accessor for a DataArray with full IDE code completion. +@overload +def xpx(data: DataArray) -> DataArrayPlotlyAccessor: ... + + +@overload +def xpx(data: Dataset) -> DatasetPlotlyAccessor: ... + - This is an alternative to `da.plotly` that provides proper type hints +def xpx(data: DataArray | Dataset) -> DataArrayPlotlyAccessor | DatasetPlotlyAccessor: + """Get the plotly accessor for a DataArray or Dataset with full IDE code completion. + + This is an alternative to `da.plotly` / `ds.plotly` that provides proper type hints and code completion in IDEs. Args: - da: The DataArray to plot. + data: The DataArray or Dataset to plot. Returns: The accessor with plotting methods (line, bar, area, scatter, box, imshow). @@ -69,13 +87,22 @@ def xpx(da: DataArray) -> DataArrayPlotlyAccessor: Example: ```python from xarray_plotly import xpx + + # DataArray fig = xpx(da).line() # Full code completion works here + + # Dataset + fig = xpx(ds).line() # Plots all variables + fig = xpx(ds).line(var="temperature") # Single variable ``` """ - return DataArrayPlotlyAccessor(da) + if isinstance(data, Dataset): + return DatasetPlotlyAccessor(data) + return DataArrayPlotlyAccessor(data) __version__ = version("xarray_plotly") -# Register the accessor +# Register the accessors register_dataarray_accessor("plotly")(DataArrayPlotlyAccessor) +register_dataset_accessor("plotly")(DatasetPlotlyAccessor) diff --git a/xarray_plotly/accessor.py b/xarray_plotly/accessor.py index d8bccb8..3c2dc36 100644 --- a/xarray_plotly/accessor.py +++ b/xarray_plotly/accessor.py @@ -1,9 +1,9 @@ -"""Accessor classes for Plotly Express plotting on DataArray.""" +"""Accessor classes for Plotly Express plotting on DataArray and Dataset.""" from typing import Any, ClassVar import plotly.graph_objects as go -from xarray import DataArray +from xarray import DataArray, Dataset from xarray_plotly import plotting from xarray_plotly.common import SlotValue, auto @@ -273,3 +273,249 @@ def imshow( animation_frame=animation_frame, **px_kwargs, ) + + +class DatasetPlotlyAccessor: + """Plotly Express plotting accessor for xarray Dataset. + + Plot a single variable or all variables at once. When plotting all variables, + a "variable" dimension is created that can be assigned to any slot. + + Available methods: line, bar, area, scatter, box + + Args: + dataset: The Dataset to plot. + + Example: + ```python + import xarray as xr + import numpy as np + + ds = xr.Dataset({ + "temperature": (["time", "city"], np.random.rand(10, 3)), + "humidity": (["time", "city"], np.random.rand(10, 3)), + }) + + # Plot all variables - "variable" dimension auto-assigned to color + fig = ds.plotly.line() + + # Control where "variable" goes + fig = ds.plotly.line(facet_col="variable") + + # Plot single variable + fig = ds.plotly.line(var="temperature") + ``` + """ + + __all__: ClassVar = ["line", "bar", "area", "scatter", "box"] + + def __init__(self, dataset: Dataset) -> None: + self._ds = dataset + + def __dir__(self) -> list[str]: + """List available plot methods.""" + return list(self.__all__) + list(super().__dir__()) + + def _get_dataarray(self, var: str | None) -> DataArray: + """Get DataArray from Dataset, either single var or all via to_array().""" + if var is None: + return self._ds.to_array(dim="variable") + return self._ds[var] + + def line( + self, + var: str | None = None, + *, + x: SlotValue = auto, + color: SlotValue = auto, + line_dash: SlotValue = auto, + symbol: SlotValue = auto, + facet_col: SlotValue = auto, + facet_row: SlotValue = auto, + animation_frame: SlotValue = auto, + **px_kwargs: Any, + ) -> go.Figure: + """Create an interactive line plot. + + Args: + var: Variable to plot. If None, plots all variables with "variable" dimension. + x: Dimension for x-axis. + color: Dimension for color grouping. + line_dash: Dimension for line dash style. + symbol: Dimension for marker symbol. + facet_col: Dimension for subplot columns. + facet_row: Dimension for subplot rows. + animation_frame: Dimension for animation. + **px_kwargs: Additional arguments passed to `plotly.express.line()`. + + Returns: + Interactive Plotly Figure. + """ + da = self._get_dataarray(var) + return plotting.line( + da, + x=x, + color=color, + line_dash=line_dash, + symbol=symbol, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + **px_kwargs, + ) + + def bar( + self, + var: str | None = None, + *, + x: SlotValue = auto, + color: SlotValue = auto, + pattern_shape: SlotValue = auto, + facet_col: SlotValue = auto, + facet_row: SlotValue = auto, + animation_frame: SlotValue = auto, + **px_kwargs: Any, + ) -> go.Figure: + """Create an interactive bar chart. + + Args: + var: Variable to plot. If None, plots all variables with "variable" dimension. + x: Dimension for x-axis. + color: Dimension for color grouping. + pattern_shape: Dimension for bar fill pattern. + facet_col: Dimension for subplot columns. + facet_row: Dimension for subplot rows. + animation_frame: Dimension for animation. + **px_kwargs: Additional arguments passed to `plotly.express.bar()`. + + Returns: + Interactive Plotly Figure. + """ + da = self._get_dataarray(var) + return plotting.bar( + da, + x=x, + color=color, + pattern_shape=pattern_shape, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + **px_kwargs, + ) + + def area( + self, + var: str | None = None, + *, + x: SlotValue = auto, + color: SlotValue = auto, + pattern_shape: SlotValue = auto, + facet_col: SlotValue = auto, + facet_row: SlotValue = auto, + animation_frame: SlotValue = auto, + **px_kwargs: Any, + ) -> go.Figure: + """Create an interactive stacked area chart. + + Args: + var: Variable to plot. If None, plots all variables with "variable" dimension. + x: Dimension for x-axis. + color: Dimension for color/stacking. + pattern_shape: Dimension for fill pattern. + facet_col: Dimension for subplot columns. + facet_row: Dimension for subplot rows. + animation_frame: Dimension for animation. + **px_kwargs: Additional arguments passed to `plotly.express.area()`. + + Returns: + Interactive Plotly Figure. + """ + da = self._get_dataarray(var) + return plotting.area( + da, + x=x, + color=color, + pattern_shape=pattern_shape, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + **px_kwargs, + ) + + def scatter( + self, + var: str | None = None, + *, + x: SlotValue = auto, + y: SlotValue | str = "value", + color: SlotValue = auto, + symbol: SlotValue = auto, + facet_col: SlotValue = auto, + facet_row: SlotValue = auto, + animation_frame: SlotValue = auto, + **px_kwargs: Any, + ) -> go.Figure: + """Create an interactive scatter plot. + + Args: + var: Variable to plot. If None, plots all variables with "variable" dimension. + x: Dimension for x-axis. + y: What to plot on y-axis. Default "value" uses DataArray values. + color: Dimension for color grouping, or "value" for DataArray values. + symbol: Dimension for marker symbol. + facet_col: Dimension for subplot columns. + facet_row: Dimension for subplot rows. + animation_frame: Dimension for animation. + **px_kwargs: Additional arguments passed to `plotly.express.scatter()`. + + Returns: + Interactive Plotly Figure. + """ + da = self._get_dataarray(var) + return plotting.scatter( + da, + x=x, + y=y, + color=color, + symbol=symbol, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + **px_kwargs, + ) + + def box( + self, + var: str | None = None, + *, + x: SlotValue = auto, + color: SlotValue = None, + facet_col: SlotValue = None, + facet_row: SlotValue = None, + animation_frame: SlotValue = None, + **px_kwargs: Any, + ) -> go.Figure: + """Create an interactive box plot. + + Args: + var: Variable to plot. If None, plots all variables with "variable" dimension. + x: Dimension for x-axis categories. + color: Dimension for color grouping. + facet_col: Dimension for subplot columns. + facet_row: Dimension for subplot rows. + animation_frame: Dimension for animation. + **px_kwargs: Additional arguments passed to `plotly.express.box()`. + + Returns: + Interactive Plotly Figure. + """ + da = self._get_dataarray(var) + return plotting.box( + da, + x=x, + color=color, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + **px_kwargs, + )