diff --git a/CHANGELOG.md b/CHANGELOG.md index 01165c249..4f8f8e128 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -301,6 +301,7 @@ fs.transform.cluster( - `FlowSystem.weights` returns `dict[str, xr.DataArray]` (unit weights instead of `1.0` float fallback) - `FlowSystemDimensions` type now includes `'cluster'` +- `statistics.plot.balance()`, `carrier_balance()`, and `storage()` now use `xarray_plotly.fast_bar()` internally (styled stacked areas for better performance) ### 🗑️ Deprecated diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index 9b7aeb6b4..55259b0ba 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -27,6 +27,7 @@ import pandas as pd import plotly.graph_objects as go import xarray as xr +from xarray_plotly.figures import update_traces from .color_processing import ColorType, hex_to_rgba, process_colors from .config import CONFIG @@ -146,108 +147,6 @@ def _reshape_time_for_heatmap( return result.transpose('timestep', 'timeframe', *other_dims) -def _iter_all_traces(fig: go.Figure): - """Iterate over all traces in a figure, including animation frames. - - Yields traces from fig.data first, then from each frame in fig.frames. - Useful for applying styling to all traces including those in animations. - - Args: - fig: Plotly Figure. - - Yields: - Each trace object from the figure. - """ - yield from fig.data - for frame in getattr(fig, 'frames', []) or []: - yield from frame.data - - -def _style_area_as_bar(fig: go.Figure) -> None: - """Style area chart traces to look like bar charts with proper pos/neg stacking. - - Iterates over all traces in fig.data and fig.frames (for animations), - setting stepped line shape, removing line borders, making fills opaque, - and assigning stackgroups based on whether values are positive or negative. - - Handles faceting + animation combinations by building color and classification - maps from trace names in the base figure. - - Args: - fig: Plotly Figure with area chart traces. - """ - import plotly.express as px - - default_colors = px.colors.qualitative.Plotly - - # Build color map from base figure traces - # trace.name -> color - color_map: dict[str, str] = {} - for i, trace in enumerate(fig.data): - if hasattr(trace, 'line') and trace.line and trace.line.color: - color_map[trace.name] = trace.line.color - else: - color_map[trace.name] = default_colors[i % len(default_colors)] - - # Classify traces by aggregating sign info across ALL traces (including animation frames) - # trace.name -> 'positive'|'negative'|'mixed'|'zero' - class_map: dict[str, str] = {} - sign_flags: dict[str, dict[str, bool]] = {} # trace.name -> {'has_pos': bool, 'has_neg': bool} - - for trace in _iter_all_traces(fig): - if trace.name not in sign_flags: - sign_flags[trace.name] = {'has_pos': False, 'has_neg': False} - - y_vals = trace.y - if y_vals is not None and len(y_vals) > 0: - y_arr = np.asarray(y_vals) - y_clean = y_arr[np.abs(y_arr) > 1e-9] - if len(y_clean) > 0: - if np.any(y_clean > 0): - sign_flags[trace.name]['has_pos'] = True - if np.any(y_clean < 0): - sign_flags[trace.name]['has_neg'] = True - - # Compute class_map from aggregated sign flags - for name, flags in sign_flags.items(): - has_pos, has_neg = flags['has_pos'], flags['has_neg'] - if has_pos and has_neg: - class_map[name] = 'mixed' - elif has_neg: - class_map[name] = 'negative' - elif has_pos: - class_map[name] = 'positive' - else: - class_map[name] = 'zero' - - def style_trace(trace: go.Scatter) -> None: - """Apply bar-like styling to a single trace.""" - # Look up color by trace name - color = color_map.get(trace.name, default_colors[0]) - - # Look up classification - cls = class_map.get(trace.name, 'positive') - - # Set stackgroup based on classification (positive and negative stack separately) - if cls in ('positive', 'negative'): - trace.stackgroup = cls - trace.fillcolor = color - trace.line = dict(width=0, color=color, shape='hv') - elif cls == 'mixed': - # Mixed: show as dashed line, no stacking - trace.stackgroup = None - trace.fill = None - trace.line = dict(width=2, color=color, shape='hv', dash='dash') - else: # zero - trace.stackgroup = None - trace.fill = None - trace.line = dict(width=0, color=color, shape='hv') - - # Style all traces (main + animation frames) - for trace in _iter_all_traces(fig): - style_trace(trace) - - def _apply_unified_hover(fig: go.Figure, unit: str = '', decimals: int = 1) -> None: """Apply unified hover mode with clean formatting to any Plotly figure. @@ -264,9 +163,8 @@ def _apply_unified_hover(fig: go.Figure, unit: str = '', decimals: int = 1) -> N unit_suffix = f' {unit}' if unit else '' hover_template = f'%{{fullData.name}}: %{{y:.{decimals}f}}{unit_suffix}' - # Apply to all traces (main + animation frames) - for trace in _iter_all_traces(fig): - trace.hovertemplate = hover_template + # Apply to all traces (main + animation frames) using xarray_plotly helper + update_traces(fig, hovertemplate=hover_template) # Layout settings for unified hover fig.update_layout(hovermode='x unified') @@ -1650,13 +1548,11 @@ def balance( unit_label = ds[first_var].attrs.get('unit', '') _apply_slot_defaults(plotly_kwargs, 'balance') - fig = ds.plotly.area( + fig = ds.plotly.fast_bar( title=f'{node} [{unit_label}]' if unit_label else node, - line_shape='hv', **color_kwargs, **plotly_kwargs, ) - _style_area_as_bar(fig) _apply_unified_hover(fig, unit=unit_label) if show is None: @@ -1775,13 +1671,11 @@ def carrier_balance( unit_label = ds[first_var].attrs.get('unit', '') _apply_slot_defaults(plotly_kwargs, 'carrier_balance') - fig = ds.plotly.area( + fig = ds.plotly.fast_bar( title=f'{carrier.capitalize()} Balance [{unit_label}]' if unit_label else f'{carrier.capitalize()} Balance', - line_shape='hv', **color_kwargs, **plotly_kwargs, ) - _style_area_as_bar(fig) _apply_unified_hover(fig, unit=unit_label) if show is None: @@ -2380,13 +2274,11 @@ def storage( # Create stacked area chart for flows (styled as bar) _apply_slot_defaults(plotly_kwargs, 'storage') - fig = flow_ds.plotly.area( + fig = flow_ds.plotly.fast_bar( title=f'{storage} Operation [{unit_label}]' if unit_label else f'{storage} Operation', - line_shape='hv', **color_kwargs, **plotly_kwargs, ) - _style_area_as_bar(fig) _apply_unified_hover(fig, unit=unit_label) # Add charge state as line on secondary y-axis diff --git a/pyproject.toml b/pyproject.toml index 0ebb15d99..ac39fd48a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ dependencies = [ # Visualization "matplotlib >= 3.5.2, < 4", "plotly >= 5.15.0, < 7", - "xarray_plotly >= 0.0.3, < 1", + "xarray_plotly >= 0.0.10, < 1", ] [project.optional-dependencies]