diff --git a/CHANGELOG.md b/CHANGELOG.md
index 01165c249..4f8f8e128 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -301,6 +301,7 @@ fs.transform.cluster(
- `FlowSystem.weights` returns `dict[str, xr.DataArray]` (unit weights instead of `1.0` float fallback)
- `FlowSystemDimensions` type now includes `'cluster'`
+- `statistics.plot.balance()`, `carrier_balance()`, and `storage()` now use `xarray_plotly.fast_bar()` internally (styled stacked areas for better performance)
### 🗑️ Deprecated
diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py
index 9b7aeb6b4..55259b0ba 100644
--- a/flixopt/statistics_accessor.py
+++ b/flixopt/statistics_accessor.py
@@ -27,6 +27,7 @@
import pandas as pd
import plotly.graph_objects as go
import xarray as xr
+from xarray_plotly.figures import update_traces
from .color_processing import ColorType, hex_to_rgba, process_colors
from .config import CONFIG
@@ -146,108 +147,6 @@ def _reshape_time_for_heatmap(
return result.transpose('timestep', 'timeframe', *other_dims)
-def _iter_all_traces(fig: go.Figure):
- """Iterate over all traces in a figure, including animation frames.
-
- Yields traces from fig.data first, then from each frame in fig.frames.
- Useful for applying styling to all traces including those in animations.
-
- Args:
- fig: Plotly Figure.
-
- Yields:
- Each trace object from the figure.
- """
- yield from fig.data
- for frame in getattr(fig, 'frames', []) or []:
- yield from frame.data
-
-
-def _style_area_as_bar(fig: go.Figure) -> None:
- """Style area chart traces to look like bar charts with proper pos/neg stacking.
-
- Iterates over all traces in fig.data and fig.frames (for animations),
- setting stepped line shape, removing line borders, making fills opaque,
- and assigning stackgroups based on whether values are positive or negative.
-
- Handles faceting + animation combinations by building color and classification
- maps from trace names in the base figure.
-
- Args:
- fig: Plotly Figure with area chart traces.
- """
- import plotly.express as px
-
- default_colors = px.colors.qualitative.Plotly
-
- # Build color map from base figure traces
- # trace.name -> color
- color_map: dict[str, str] = {}
- for i, trace in enumerate(fig.data):
- if hasattr(trace, 'line') and trace.line and trace.line.color:
- color_map[trace.name] = trace.line.color
- else:
- color_map[trace.name] = default_colors[i % len(default_colors)]
-
- # Classify traces by aggregating sign info across ALL traces (including animation frames)
- # trace.name -> 'positive'|'negative'|'mixed'|'zero'
- class_map: dict[str, str] = {}
- sign_flags: dict[str, dict[str, bool]] = {} # trace.name -> {'has_pos': bool, 'has_neg': bool}
-
- for trace in _iter_all_traces(fig):
- if trace.name not in sign_flags:
- sign_flags[trace.name] = {'has_pos': False, 'has_neg': False}
-
- y_vals = trace.y
- if y_vals is not None and len(y_vals) > 0:
- y_arr = np.asarray(y_vals)
- y_clean = y_arr[np.abs(y_arr) > 1e-9]
- if len(y_clean) > 0:
- if np.any(y_clean > 0):
- sign_flags[trace.name]['has_pos'] = True
- if np.any(y_clean < 0):
- sign_flags[trace.name]['has_neg'] = True
-
- # Compute class_map from aggregated sign flags
- for name, flags in sign_flags.items():
- has_pos, has_neg = flags['has_pos'], flags['has_neg']
- if has_pos and has_neg:
- class_map[name] = 'mixed'
- elif has_neg:
- class_map[name] = 'negative'
- elif has_pos:
- class_map[name] = 'positive'
- else:
- class_map[name] = 'zero'
-
- def style_trace(trace: go.Scatter) -> None:
- """Apply bar-like styling to a single trace."""
- # Look up color by trace name
- color = color_map.get(trace.name, default_colors[0])
-
- # Look up classification
- cls = class_map.get(trace.name, 'positive')
-
- # Set stackgroup based on classification (positive and negative stack separately)
- if cls in ('positive', 'negative'):
- trace.stackgroup = cls
- trace.fillcolor = color
- trace.line = dict(width=0, color=color, shape='hv')
- elif cls == 'mixed':
- # Mixed: show as dashed line, no stacking
- trace.stackgroup = None
- trace.fill = None
- trace.line = dict(width=2, color=color, shape='hv', dash='dash')
- else: # zero
- trace.stackgroup = None
- trace.fill = None
- trace.line = dict(width=0, color=color, shape='hv')
-
- # Style all traces (main + animation frames)
- for trace in _iter_all_traces(fig):
- style_trace(trace)
-
-
def _apply_unified_hover(fig: go.Figure, unit: str = '', decimals: int = 1) -> None:
"""Apply unified hover mode with clean formatting to any Plotly figure.
@@ -264,9 +163,8 @@ def _apply_unified_hover(fig: go.Figure, unit: str = '', decimals: int = 1) -> N
unit_suffix = f' {unit}' if unit else ''
hover_template = f'%{{fullData.name}}: %{{y:.{decimals}f}}{unit_suffix}'
- # Apply to all traces (main + animation frames)
- for trace in _iter_all_traces(fig):
- trace.hovertemplate = hover_template
+ # Apply to all traces (main + animation frames) using xarray_plotly helper
+ update_traces(fig, hovertemplate=hover_template)
# Layout settings for unified hover
fig.update_layout(hovermode='x unified')
@@ -1650,13 +1548,11 @@ def balance(
unit_label = ds[first_var].attrs.get('unit', '')
_apply_slot_defaults(plotly_kwargs, 'balance')
- fig = ds.plotly.area(
+ fig = ds.plotly.fast_bar(
title=f'{node} [{unit_label}]' if unit_label else node,
- line_shape='hv',
**color_kwargs,
**plotly_kwargs,
)
- _style_area_as_bar(fig)
_apply_unified_hover(fig, unit=unit_label)
if show is None:
@@ -1775,13 +1671,11 @@ def carrier_balance(
unit_label = ds[first_var].attrs.get('unit', '')
_apply_slot_defaults(plotly_kwargs, 'carrier_balance')
- fig = ds.plotly.area(
+ fig = ds.plotly.fast_bar(
title=f'{carrier.capitalize()} Balance [{unit_label}]' if unit_label else f'{carrier.capitalize()} Balance',
- line_shape='hv',
**color_kwargs,
**plotly_kwargs,
)
- _style_area_as_bar(fig)
_apply_unified_hover(fig, unit=unit_label)
if show is None:
@@ -2380,13 +2274,11 @@ def storage(
# Create stacked area chart for flows (styled as bar)
_apply_slot_defaults(plotly_kwargs, 'storage')
- fig = flow_ds.plotly.area(
+ fig = flow_ds.plotly.fast_bar(
title=f'{storage} Operation [{unit_label}]' if unit_label else f'{storage} Operation',
- line_shape='hv',
**color_kwargs,
**plotly_kwargs,
)
- _style_area_as_bar(fig)
_apply_unified_hover(fig, unit=unit_label)
# Add charge state as line on secondary y-axis
diff --git a/pyproject.toml b/pyproject.toml
index 0ebb15d99..ac39fd48a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,7 +47,7 @@ dependencies = [
# Visualization
"matplotlib >= 3.5.2, < 4",
"plotly >= 5.15.0, < 7",
- "xarray_plotly >= 0.0.3, < 1",
+ "xarray_plotly >= 0.0.10, < 1",
]
[project.optional-dependencies]