Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 67 additions & 2 deletions xarray_plotly/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class DataArrayPlotlyAccessor:
```
"""

__all__: ClassVar = ["line", "bar", "area", "scatter", "box", "imshow"]
__all__: ClassVar = ["line", "bar", "area", "scatter", "box", "imshow", "pie"]

def __init__(self, darray: DataArray) -> None:
self._da = darray
Expand Down Expand Up @@ -274,6 +274,38 @@ def imshow(
**px_kwargs,
)

def pie(
self,
*,
names: SlotValue = auto,
color: SlotValue = None,
facet_col: SlotValue = auto,
facet_row: SlotValue = auto,
**px_kwargs: Any,
) -> go.Figure:
"""Create an interactive pie chart.
Slot order: names -> facet_col -> facet_row
Args:
names: Dimension for pie slice names/categories. Default: first dimension.
color: Dimension for color grouping. Default: None (uses names).
facet_col: Dimension for subplot columns. Default: second dimension.
facet_row: Dimension for subplot rows. Default: third dimension.
**px_kwargs: Additional arguments passed to `plotly.express.pie()`.
Returns:
Interactive Plotly Figure.
"""
return plotting.pie(
self._da,
names=names,
color=color,
facet_col=facet_col,
facet_row=facet_row,
**px_kwargs,
)


class DatasetPlotlyAccessor:
"""Plotly Express plotting accessor for xarray Dataset.
Expand Down Expand Up @@ -307,7 +339,7 @@ class DatasetPlotlyAccessor:
```
"""

__all__: ClassVar = ["line", "bar", "area", "scatter", "box"]
__all__: ClassVar = ["line", "bar", "area", "scatter", "box", "pie"]

def __init__(self, dataset: Dataset) -> None:
self._ds = dataset
Expand Down Expand Up @@ -519,3 +551,36 @@ def box(
animation_frame=animation_frame,
**px_kwargs,
)

def pie(
self,
var: str | None = None,
*,
names: SlotValue = auto,
color: SlotValue = None,
facet_col: SlotValue = auto,
facet_row: SlotValue = auto,
**px_kwargs: Any,
) -> go.Figure:
"""Create an interactive pie chart.
Args:
var: Variable to plot. If None, plots all variables with "variable" dimension.
names: Dimension for pie slice names/categories.
color: Dimension for color grouping.
facet_col: Dimension for subplot columns.
facet_row: Dimension for subplot rows.
**px_kwargs: Additional arguments passed to `plotly.express.pie()`.
Returns:
Interactive Plotly Figure.
"""
da = self._get_dataarray(var)
return plotting.pie(
da,
names=names,
color=color,
facet_col=facet_col,
facet_row=facet_row,
**px_kwargs,
)
1 change: 1 addition & 0 deletions xarray_plotly/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
),
"imshow": ("y", "x", "facet_col", "animation_frame"),
"box": ("x", "color", "facet_col", "facet_row", "animation_frame"),
"pie": ("names", "facet_col", "facet_row"),
}


Expand Down
61 changes: 61 additions & 0 deletions xarray_plotly/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,3 +446,64 @@ def imshow(
animation_frame=slots.get("animation_frame"),
**px_kwargs,
)


def pie(
darray: DataArray,
*,
names: SlotValue = auto,
color: SlotValue = None,
facet_col: SlotValue = auto,
facet_row: SlotValue = auto,
**px_kwargs: Any,
) -> go.Figure:
"""
Create an interactive pie chart from a DataArray.
The values are the DataArray values. Dimensions fill slots in order:
names -> facet_col -> facet_row
Parameters
----------
darray
The DataArray to plot.
names
Dimension for pie slice names/categories. Default: first dimension.
color
Dimension for color grouping. Default: None (uses names).
facet_col
Dimension for subplot columns. Default: second dimension.
facet_row
Dimension for subplot rows. Default: third dimension.
**px_kwargs
Additional arguments passed to `plotly.express.pie()`.
Returns
-------
plotly.graph_objects.Figure
"""
slots = assign_slots(
list(darray.dims),
"pie",
names=names,
facet_col=facet_col,
facet_row=facet_row,
)

df = to_dataframe(darray)
value_col = get_value_col(darray)
labels = {**build_labels(darray, slots, value_col), **px_kwargs.pop("labels", {})}

# Use names dimension for color if not explicitly set
color_col = color if color is not None else slots.get("names")

return px.pie(
df,
names=slots.get("names"),
values=value_col,
color=color_col,
facet_col=slots.get("facet_col"),
facet_row=slots.get("facet_row"),
labels=labels,
**px_kwargs,
)