From 2c6461b061de5e7b2057adc1ee1876dc21071f1d Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Thu, 1 Jan 2026 21:30:17 +0100 Subject: [PATCH 01/62] Add dataset plot accessor --- flixopt/__init__.py | 3 + flixopt/dataset_plot_accessor.py | 446 +++++++++++++++++++++++++++++++ 2 files changed, 449 insertions(+) create mode 100644 flixopt/dataset_plot_accessor.py diff --git a/flixopt/__init__.py b/flixopt/__init__.py index d1a63a9c5..3cf219c38 100644 --- a/flixopt/__init__.py +++ b/flixopt/__init__.py @@ -14,6 +14,9 @@ # Import commonly used classes and functions from . import clustering, linear_converters, plotting, results, solvers + +# Register xr.Dataset.fxplot accessor (import triggers registration via decorator) +from . import dataset_plot_accessor as _ # noqa: F401 from .carrier import Carrier, CarrierContainer from .components import ( LinearConverter, diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py new file mode 100644 index 000000000..35f8db47c --- /dev/null +++ b/flixopt/dataset_plot_accessor.py @@ -0,0 +1,446 @@ +"""Dataset plot accessor for xarray Datasets. + +Provides convenient plotting methods for any xr.Dataset via the .fxplot accessor. +This is globally registered and available on all xr.Dataset objects when flixopt is imported. 
+ +Example: + >>> import flixopt + >>> import xarray as xr + >>> ds = xr.Dataset({'temp': (['time', 'location'], data)}) + >>> ds.fxplot.line() # Line plot of all variables + >>> ds.fxplot.stacked_bar() # Stacked bar chart + >>> ds.fxplot.heatmap('temp') # Heatmap of specific variable +""" + +from __future__ import annotations + +from typing import Any, Literal + +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +import xarray as xr + +from .color_processing import ColorType, process_colors +from .config import CONFIG + + +def _resolve_auto_facets( + ds: xr.Dataset, + facet_col: str | Literal['auto'] | None, + facet_row: str | Literal['auto'] | None, + animation_frame: str | Literal['auto'] | None = None, +) -> tuple[str | None, str | None, str | None]: + """Resolve 'auto' facet/animation dimensions based on available data dimensions. + + When 'auto' is specified, extra dimensions are assigned to slots based on: + - CONFIG.Plotting.extra_dim_priority: Order of dimensions (default: cluster -> period -> scenario) + - CONFIG.Plotting.dim_slot_priority: Order of slots (default: facet_col -> facet_row -> animation_frame) + + Args: + ds: Dataset to check for available dimensions. + facet_col: Dimension name, 'auto', or None. + facet_row: Dimension name, 'auto', or None. + animation_frame: Dimension name, 'auto', or None. + + Returns: + Tuple of (resolved_facet_col, resolved_facet_row, resolved_animation_frame). + Each is either a valid dimension name or None. 
+ """ + # Get available extra dimensions with size > 1, sorted by priority + available = {d for d in ds.dims if ds.sizes[d] > 1} + extra_dims = [d for d in CONFIG.Plotting.extra_dim_priority if d in available] + used: set[str] = set() + + # Map slot names to their input values + slots = { + 'facet_col': facet_col, + 'facet_row': facet_row, + 'animation_frame': animation_frame, + } + results: dict[str, str | None] = {'facet_col': None, 'facet_row': None, 'animation_frame': None} + + # First pass: resolve explicit dimensions (not 'auto' or None) to mark them as used + for slot_name, value in slots.items(): + if value is not None and value != 'auto': + if value in available and value not in used: + used.add(value) + results[slot_name] = value + + # Second pass: resolve 'auto' slots in dim_slot_priority order + dim_iter = iter(d for d in extra_dims if d not in used) + for slot_name in CONFIG.Plotting.dim_slot_priority: + if slots.get(slot_name) == 'auto': + next_dim = next(dim_iter, None) + if next_dim: + used.add(next_dim) + results[slot_name] = next_dim + + return results['facet_col'], results['facet_row'], results['animation_frame'] + + +def _dataset_to_long_df(ds: xr.Dataset, value_name: str = 'value', var_name: str = 'variable') -> pd.DataFrame: + """Convert xarray Dataset to long-form DataFrame for plotly express.""" + if not ds.data_vars: + return pd.DataFrame() + if all(ds[var].ndim == 0 for var in ds.data_vars): + rows = [{var_name: var, value_name: float(ds[var].values)} for var in ds.data_vars] + return pd.DataFrame(rows) + df = ds.to_dataframe().reset_index() + # Only use coordinates that are actually present as columns after reset_index + coord_cols = [c for c in ds.coords.keys() if c in df.columns] + return df.melt(id_vars=coord_cols, var_name=var_name, value_name=value_name) + + +@xr.register_dataset_accessor('fxplot') +class DatasetPlotAccessor: + """Plot accessor for any xr.Dataset. Access via ``dataset.fxplot``. 
+ + Provides convenient plotting methods that automatically handle multi-dimensional + data through faceting and animation. All methods return a Plotly Figure. + + This accessor is globally registered when flixopt is imported and works on + any xr.Dataset. + + Examples: + Basic usage:: + + import flixopt + import xarray as xr + + ds = xr.Dataset({'A': (['time'], [1, 2, 3]), 'B': (['time'], [3, 2, 1])}) + ds.fxplot.stacked_bar() + ds.fxplot.line() + ds.fxplot.area() + + With faceting:: + + ds.fxplot.stacked_bar(facet_col='scenario') + ds.fxplot.line(facet_col='period', animation_frame='scenario') + + Heatmap:: + + ds.fxplot.heatmap('temperature') + """ + + def __init__(self, xarray_obj: xr.Dataset) -> None: + """Initialize the accessor with an xr.Dataset object.""" + self._ds = xarray_obj + + def bar( + self, + *, + colors: ColorType | None = None, + title: str = '', + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = None, + animation_frame: str | Literal['auto'] | None = None, + facet_cols: int | None = None, + **px_kwargs: Any, + ) -> go.Figure: + """Create a grouped bar chart from the dataset. + + Args: + colors: Color specification (colorscale name, color list, or dict mapping). + title: Plot title. + facet_col: Dimension for column facets. 'auto' uses CONFIG priority. + facet_row: Dimension for row facets. 'auto' uses CONFIG priority. + animation_frame: Dimension for animation slider. + facet_cols: Number of columns in facet grid wrap. + **px_kwargs: Additional arguments passed to plotly.express.bar. + + Returns: + Plotly Figure. 
+ """ + df = _dataset_to_long_df(self._ds) + if df.empty: + return go.Figure() + + x_col = 'time' if 'time' in df.columns else df.columns[0] + variables = df['variable'].unique().tolist() + color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) + + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, facet_col, facet_row, animation_frame + ) + + facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols + fig_kwargs: dict[str, Any] = { + 'data_frame': df, + 'x': x_col, + 'y': 'value', + 'color': 'variable', + 'color_discrete_map': color_map, + 'title': title, + 'barmode': 'group', + **px_kwargs, + } + + if actual_facet_col: + fig_kwargs['facet_col'] = actual_facet_col + if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + fig_kwargs['facet_col_wrap'] = facet_col_wrap + if actual_facet_row: + fig_kwargs['facet_row'] = actual_facet_row + if actual_anim: + fig_kwargs['animation_frame'] = actual_anim + + return px.bar(**fig_kwargs) + + def stacked_bar( + self, + *, + colors: ColorType | None = None, + title: str = '', + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = None, + animation_frame: str | Literal['auto'] | None = None, + facet_cols: int | None = None, + **px_kwargs: Any, + ) -> go.Figure: + """Create a stacked bar chart from the dataset. + + Variables in the dataset become stacked segments. Positive and negative + values are stacked separately. + + Args: + colors: Color specification (colorscale name, color list, or dict mapping). + title: Plot title. + facet_col: Dimension for column facets. 'auto' uses CONFIG priority. + facet_row: Dimension for row facets. + animation_frame: Dimension for animation slider. + facet_cols: Number of columns in facet grid wrap. + **px_kwargs: Additional arguments passed to plotly.express.bar. + + Returns: + Plotly Figure. 
+ """ + df = _dataset_to_long_df(self._ds) + if df.empty: + return go.Figure() + + x_col = 'time' if 'time' in df.columns else df.columns[0] + variables = df['variable'].unique().tolist() + color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) + + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, facet_col, facet_row, animation_frame + ) + + facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols + fig_kwargs: dict[str, Any] = { + 'data_frame': df, + 'x': x_col, + 'y': 'value', + 'color': 'variable', + 'color_discrete_map': color_map, + 'title': title, + **px_kwargs, + } + + if actual_facet_col: + fig_kwargs['facet_col'] = actual_facet_col + if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + fig_kwargs['facet_col_wrap'] = facet_col_wrap + if actual_facet_row: + fig_kwargs['facet_row'] = actual_facet_row + if actual_anim: + fig_kwargs['animation_frame'] = actual_anim + + fig = px.bar(**fig_kwargs) + fig.update_layout(barmode='relative', bargap=0, bargroupgap=0) + fig.update_traces(marker_line_width=0) + return fig + + def line( + self, + *, + colors: ColorType | None = None, + title: str = '', + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = None, + animation_frame: str | Literal['auto'] | None = None, + facet_cols: int | None = None, + line_shape: str = 'hv', + **px_kwargs: Any, + ) -> go.Figure: + """Create a line chart from the dataset. + + Each variable in the dataset becomes a separate line. + + Args: + colors: Color specification (colorscale name, color list, or dict mapping). + title: Plot title. + facet_col: Dimension for column facets. 'auto' uses CONFIG priority. + facet_row: Dimension for row facets. + animation_frame: Dimension for animation slider. + facet_cols: Number of columns in facet grid wrap. 
+ line_shape: Line interpolation ('linear', 'hv', 'vh', 'hvh', 'vhv', 'spline'). + Default 'hv' for stepped lines. + **px_kwargs: Additional arguments passed to plotly.express.line. + + Returns: + Plotly Figure. + """ + df = _dataset_to_long_df(self._ds) + if df.empty: + return go.Figure() + + x_col = 'time' if 'time' in df.columns else df.columns[0] + variables = df['variable'].unique().tolist() + color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) + + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, facet_col, facet_row, animation_frame + ) + + facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols + fig_kwargs: dict[str, Any] = { + 'data_frame': df, + 'x': x_col, + 'y': 'value', + 'color': 'variable', + 'color_discrete_map': color_map, + 'title': title, + 'line_shape': line_shape, + **px_kwargs, + } + + if actual_facet_col: + fig_kwargs['facet_col'] = actual_facet_col + if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + fig_kwargs['facet_col_wrap'] = facet_col_wrap + if actual_facet_row: + fig_kwargs['facet_row'] = actual_facet_row + if actual_anim: + fig_kwargs['animation_frame'] = actual_anim + + return px.line(**fig_kwargs) + + def area( + self, + *, + colors: ColorType | None = None, + title: str = '', + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = None, + animation_frame: str | Literal['auto'] | None = None, + facet_cols: int | None = None, + line_shape: str = 'hv', + **px_kwargs: Any, + ) -> go.Figure: + """Create a stacked area chart from the dataset. + + Args: + colors: Color specification (colorscale name, color list, or dict mapping). + title: Plot title. + facet_col: Dimension for column facets. 'auto' uses CONFIG priority. + facet_row: Dimension for row facets. + animation_frame: Dimension for animation slider. + facet_cols: Number of columns in facet grid wrap. 
+ line_shape: Line interpolation. Default 'hv' for stepped. + **px_kwargs: Additional arguments passed to plotly.express.area. + + Returns: + Plotly Figure. + """ + df = _dataset_to_long_df(self._ds) + if df.empty: + return go.Figure() + + x_col = 'time' if 'time' in df.columns else df.columns[0] + variables = df['variable'].unique().tolist() + color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) + + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, facet_col, facet_row, animation_frame + ) + + facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols + fig_kwargs: dict[str, Any] = { + 'data_frame': df, + 'x': x_col, + 'y': 'value', + 'color': 'variable', + 'color_discrete_map': color_map, + 'title': title, + 'line_shape': line_shape, + **px_kwargs, + } + + if actual_facet_col: + fig_kwargs['facet_col'] = actual_facet_col + if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + fig_kwargs['facet_col_wrap'] = facet_col_wrap + if actual_facet_row: + fig_kwargs['facet_row'] = actual_facet_row + if actual_anim: + fig_kwargs['animation_frame'] = actual_anim + + return px.area(**fig_kwargs) + + def heatmap( + self, + variable: str | None = None, + *, + colors: str | list[str] | None = None, + title: str = '', + facet_col: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = None, + facet_cols: int | None = None, + **imshow_kwargs: Any, + ) -> go.Figure: + """Create a heatmap visualization. + + If the dataset has multiple variables, select one with the `variable` parameter. + If only one variable exists, it is used automatically. + + Args: + variable: Variable name to plot. Required if dataset has multiple variables. + If None and dataset has one variable, that variable is used. + colors: Colorscale name or list of colors. + title: Plot title. + facet_col: Dimension for column facets. 
+ animation_frame: Dimension for animation slider. + facet_cols: Number of columns in facet grid wrap. + **imshow_kwargs: Additional arguments passed to plotly.express.imshow. + + Returns: + Plotly Figure. + """ + # Select single variable + if variable is None: + if len(self._ds.data_vars) == 1: + variable = list(self._ds.data_vars)[0] + else: + raise ValueError( + f'Dataset has {len(self._ds.data_vars)} variables. ' + f"Please specify which variable to plot with variable='name'." + ) + + da = self._ds[variable] + + if da.size == 0: + return go.Figure() + + colors = colors or CONFIG.Plotting.default_sequential_colorscale + facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols + + actual_facet_col, _, actual_anim = _resolve_auto_facets(self._ds, facet_col, None, animation_frame) + + imshow_args: dict[str, Any] = { + 'img': da, + 'color_continuous_scale': colors, + 'title': title or variable, + **imshow_kwargs, + } + + if actual_facet_col and actual_facet_col in da.dims: + imshow_args['facet_col'] = actual_facet_col + if facet_col_wrap < da.sizes[actual_facet_col]: + imshow_args['facet_col_wrap'] = facet_col_wrap + + if actual_anim and actual_anim in da.dims: + imshow_args['animation_frame'] = actual_anim + + return px.imshow(**imshow_args) From f574c0ce3deba6315af0b66f7a47793c997e91e8 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Thu, 1 Jan 2026 21:33:35 +0100 Subject: [PATCH 02/62] Add fxplot acessor showcase --- docs/notebooks/fxplot_accessor_demo.ipynb | 323 ++++++++++++++++++++++ 1 file changed, 323 insertions(+) create mode 100644 docs/notebooks/fxplot_accessor_demo.ipynb diff --git a/docs/notebooks/fxplot_accessor_demo.ipynb b/docs/notebooks/fxplot_accessor_demo.ipynb new file mode 100644 index 000000000..d4d7b69b6 --- /dev/null +++ b/docs/notebooks/fxplot_accessor_demo.ipynb @@ -0,0 +1,323 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Dataset Plot Accessor Demo 
(`.fxplot`)\n", + "\n", + "This notebook demonstrates the new `.fxplot` accessor for `xr.Dataset` objects.\n", + "It provides convenient Plotly Express plotting methods with smart auto-faceting and coloring." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import xarray as xr" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create Sample Data\n", + "\n", + "Let's create a multi-dimensional dataset to demonstrate the plotting capabilities." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Simple time-series dataset\n", + "np.random.seed(42)\n", + "time = pd.date_range('2024-01-01', periods=24, freq='h')\n", + "\n", + "ds_simple = xr.Dataset(\n", + " {\n", + " 'Solar': (['time'], np.maximum(0, np.sin(np.linspace(0, 2 * np.pi, 24)) * 50 + np.random.randn(24) * 5)),\n", + " 'Wind': (['time'], np.abs(np.random.randn(24) * 20 + 30)),\n", + " 'Demand': (['time'], np.abs(np.sin(np.linspace(0, 2 * np.pi, 24) + 1) * 40 + 50 + np.random.randn(24) * 5)),\n", + " },\n", + " coords={'time': time},\n", + ")\n", + "\n", + "ds_simple" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Line Plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds_simple.fxplot.line(title='Energy Generation & Demand')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Stacked Bar Chart" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds_simple[['Solar', 'Wind']].fxplot.stacked_bar(title='Renewable Generation')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Area Chart" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + 
"source": [ + "ds_simple[['Solar', 'Wind']].fxplot.area(title='Stacked Area - Generation')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Grouped Bar Chart" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds_simple.fxplot.bar(title='Grouped Bar Chart')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Heatmap" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create 2D data for heatmap\n", + "ds_heatmap = xr.Dataset(\n", + " {\n", + " 'temperature': (['day', 'hour'], np.random.randn(7, 24) * 5 + 20),\n", + " },\n", + " coords={\n", + " 'day': pd.date_range('2024-01-01', periods=7, freq='D'),\n", + " 'hour': range(24),\n", + " },\n", + ")\n", + "\n", + "ds_heatmap.fxplot.heatmap('temperature', title='Temperature Heatmap')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Multi-Dimensional Data with Auto-Faceting\n", + "\n", + "The accessor automatically handles extra dimensions by assigning them to facets or animation based on CONFIG priority." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Dataset with scenario dimension\n", + "ds_scenarios = xr.Dataset(\n", + " {\n", + " 'Solar': (\n", + " ['time', 'scenario'],\n", + " np.column_stack(\n", + " [\n", + " np.maximum(0, np.sin(np.linspace(0, 2 * np.pi, 24)) * 50),\n", + " np.maximum(0, np.sin(np.linspace(0, 2 * np.pi, 24)) * 70), # High scenario\n", + " ]\n", + " ),\n", + " ),\n", + " 'Wind': (\n", + " ['time', 'scenario'],\n", + " np.column_stack(\n", + " [\n", + " np.abs(np.random.randn(24) * 20 + 30),\n", + " np.abs(np.random.randn(24) * 25 + 40), # High scenario\n", + " ]\n", + " ),\n", + " ),\n", + " },\n", + " coords={\n", + " 'time': time,\n", + " 'scenario': ['base', 'high'],\n", + " },\n", + ")\n", + "\n", + "ds_scenarios" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Auto-faceting assigns 'scenario' to facet_col\n", + "ds_scenarios.fxplot.line(title='Generation by Scenario (Auto-Faceted)')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Explicit faceting\n", + "ds_scenarios.fxplot.stacked_bar(facet_col='scenario', title='Stacked Bar by Scenario')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Animation Support" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use animation instead of faceting\n", + "ds_scenarios.fxplot.area(facet_col=None, animation_frame='scenario', title='Animated by Scenario')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom Colors" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Using a colorscale name\n", + "ds_simple.fxplot.line(colors='viridis', title='With Viridis Colorscale')" + ] + }, + { + "cell_type": "code", 
+ "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Using explicit color mapping\n", + "ds_simple.fxplot.stacked_bar(\n", + " colors={'Solar': 'gold', 'Wind': 'skyblue', 'Demand': 'salmon'}, title='With Custom Colors'\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Chaining with Plotly Methods\n", + "\n", + "Since all methods return `go.Figure`, you can chain Plotly's update methods." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " ds_simple.fxplot.line(title='Customized Plot')\n", + " .update_layout(xaxis_title='Time of Day', yaxis_title='Power (MW)', legend_title='Source', template='plotly_white')\n", + " .update_traces(line_width=2)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pre-filtering with xarray\n", + "\n", + "Filter data using xarray methods before plotting." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Select specific time range\n", + "ds_simple.sel(time=slice('2024-01-01 06:00', '2024-01-01 18:00')).fxplot.line(title='Daytime Only')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Select specific variables\n", + "ds_simple[['Solar', 'Wind']].fxplot.area(title='Renewables Only')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From edf89ef7271d59167d670d4cbe631566752c83e4 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Thu, 1 Jan 2026 21:50:16 +0100 Subject: [PATCH 03/62] The internal plot accessors now leverage the shared .fxplot implementation, reducing code duplication while maintaining the same 
functionality (data preparation, color resolution from components, PlotResult wrapping). --- flixopt/dataset_plot_accessor.py | 173 +++++++++++++++++++++++++++++++ flixopt/statistics_accessor.py | 129 ++--------------------- 2 files changed, 179 insertions(+), 123 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 35f8db47c..9f836c56c 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -444,3 +444,176 @@ def heatmap( imshow_args['animation_frame'] = actual_anim return px.imshow(**imshow_args) + + +@xr.register_dataarray_accessor('fxplot') +class DataArrayPlotAccessor: + """Plot accessor for any xr.DataArray. Access via ``dataarray.fxplot``. + + Provides convenient plotting methods. For bar/stacked_bar/line/area, + the DataArray is converted to a Dataset first. For heatmap, it works + directly with the DataArray. + + Examples: + Basic usage:: + + import flixopt + import xarray as xr + + da = xr.DataArray([1, 2, 3], dims=['time'], name='temperature') + da.fxplot.line() + da.fxplot.heatmap() + """ + + def __init__(self, xarray_obj: xr.DataArray) -> None: + """Initialize the accessor with an xr.DataArray object.""" + self._da = xarray_obj + + def _to_dataset(self) -> xr.Dataset: + """Convert DataArray to Dataset for plotting.""" + name = self._da.name or 'value' + return self._da.to_dataset(name=name) + + def bar( + self, + *, + colors: ColorType | None = None, + title: str = '', + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = None, + animation_frame: str | Literal['auto'] | None = None, + facet_cols: int | None = None, + **px_kwargs: Any, + ) -> go.Figure: + """Create a grouped bar chart. 
See DatasetPlotAccessor.bar for details.""" + return self._to_dataset().fxplot.bar( + colors=colors, + title=title, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + facet_cols=facet_cols, + **px_kwargs, + ) + + def stacked_bar( + self, + *, + colors: ColorType | None = None, + title: str = '', + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = None, + animation_frame: str | Literal['auto'] | None = None, + facet_cols: int | None = None, + **px_kwargs: Any, + ) -> go.Figure: + """Create a stacked bar chart. See DatasetPlotAccessor.stacked_bar for details.""" + return self._to_dataset().fxplot.stacked_bar( + colors=colors, + title=title, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + facet_cols=facet_cols, + **px_kwargs, + ) + + def line( + self, + *, + colors: ColorType | None = None, + title: str = '', + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = None, + animation_frame: str | Literal['auto'] | None = None, + facet_cols: int | None = None, + line_shape: str = 'hv', + **px_kwargs: Any, + ) -> go.Figure: + """Create a line chart. See DatasetPlotAccessor.line for details.""" + return self._to_dataset().fxplot.line( + colors=colors, + title=title, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + facet_cols=facet_cols, + line_shape=line_shape, + **px_kwargs, + ) + + def area( + self, + *, + colors: ColorType | None = None, + title: str = '', + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = None, + animation_frame: str | Literal['auto'] | None = None, + facet_cols: int | None = None, + line_shape: str = 'hv', + **px_kwargs: Any, + ) -> go.Figure: + """Create a stacked area chart. 
See DatasetPlotAccessor.area for details.""" + return self._to_dataset().fxplot.area( + colors=colors, + title=title, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + facet_cols=facet_cols, + line_shape=line_shape, + **px_kwargs, + ) + + def heatmap( + self, + *, + colors: str | list[str] | None = None, + title: str = '', + facet_col: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = None, + facet_cols: int | None = None, + **imshow_kwargs: Any, + ) -> go.Figure: + """Create a heatmap visualization directly from the DataArray. + + Args: + colors: Colorscale name or list of colors. + title: Plot title. + facet_col: Dimension for column facets. + animation_frame: Dimension for animation slider. + facet_cols: Number of columns in facet grid wrap. + **imshow_kwargs: Additional arguments passed to plotly.express.imshow. + + Returns: + Plotly Figure. + """ + da = self._da + + if da.size == 0: + return go.Figure() + + colors = colors or CONFIG.Plotting.default_sequential_colorscale + facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols + + # Use Dataset for facet resolution + ds_for_resolution = da.to_dataset(name='_temp') + actual_facet_col, _, actual_anim = _resolve_auto_facets(ds_for_resolution, facet_col, None, animation_frame) + + imshow_args: dict[str, Any] = { + 'img': da, + 'color_continuous_scale': colors, + 'title': title or (da.name if da.name else ''), + **imshow_kwargs, + } + + if actual_facet_col and actual_facet_col in da.dims: + imshow_args['facet_col'] = actual_facet_col + if facet_col_wrap < da.sizes[actual_facet_col]: + imshow_args['facet_col_wrap'] = facet_col_wrap + + if actual_anim and actual_anim in da.dims: + imshow_args['animation_frame'] = actual_anim + + return px.imshow(**imshow_args) diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index bee26a0e2..e01880f76 100644 --- a/flixopt/statistics_accessor.py +++ 
b/flixopt/statistics_accessor.py @@ -124,54 +124,6 @@ def _reshape_time_for_heatmap( return result.transpose('timestep', 'timeframe', *other_dims) -def _heatmap_figure( - data: xr.DataArray, - colors: str | list[str] | None = None, - title: str = '', - facet_col: str | None = None, - animation_frame: str | None = None, - facet_col_wrap: int | None = None, - **imshow_kwargs: Any, -) -> go.Figure: - """Create heatmap figure using px.imshow. - - Args: - data: DataArray with 2-4 dimensions. First two are heatmap axes. - colors: Colorscale name (str) or list of colors. Dicts are not supported - for heatmaps as color_continuous_scale requires a colorscale specification. - title: Plot title. - facet_col: Dimension for subplot columns. - animation_frame: Dimension for animation slider. - facet_col_wrap: Max columns before wrapping. - **imshow_kwargs: Additional args for px.imshow. - - Returns: - Plotly Figure. - """ - if data.size == 0: - return go.Figure() - - colors = colors or CONFIG.Plotting.default_sequential_colorscale - facet_col_wrap = facet_col_wrap or CONFIG.Plotting.default_facet_cols - - imshow_args: dict[str, Any] = { - 'img': data, - 'color_continuous_scale': colors, - 'title': title, - **imshow_kwargs, - } - - if facet_col and facet_col in data.dims: - imshow_args['facet_col'] = facet_col - if facet_col_wrap < data.sizes[facet_col]: - imshow_args['facet_col_wrap'] = facet_col_wrap - - if animation_frame and animation_frame in data.dims: - imshow_args['animation_frame'] = animation_frame - - return px.imshow(**imshow_args) - - # --- Helper functions --- @@ -308,69 +260,6 @@ def _dataset_to_long_df(ds: xr.Dataset, value_name: str = 'value', var_name: str return df.melt(id_vars=coord_cols, var_name=var_name, value_name=value_name) -def _create_stacked_bar( - ds: xr.Dataset, - colors: ColorType, - title: str, - facet_col: str | None, - facet_row: str | None, - animation_frame: str | None = None, - **plotly_kwargs: Any, -) -> go.Figure: - """Create a stacked bar 
chart from xarray Dataset.""" - df = _dataset_to_long_df(ds) - if df.empty: - return go.Figure() - x_col = 'time' if 'time' in df.columns else df.columns[0] - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) - fig = px.bar( - df, - x=x_col, - y='value', - color='variable', - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - color_discrete_map=color_map, - title=title, - **plotly_kwargs, - ) - fig.update_layout(barmode='relative', bargap=0, bargroupgap=0) - fig.update_traces(marker_line_width=0) - return fig - - -def _create_line( - ds: xr.Dataset, - colors: ColorType, - title: str, - facet_col: str | None, - facet_row: str | None, - animation_frame: str | None = None, - **plotly_kwargs: Any, -) -> go.Figure: - """Create a line chart from xarray Dataset.""" - df = _dataset_to_long_df(ds) - if df.empty: - return go.Figure() - x_col = 'time' if 'time' in df.columns else df.columns[0] - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) - return px.line( - df, - x=x_col, - y='value', - color='variable', - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - color_discrete_map=color_map, - title=title, - **plotly_kwargs, - ) - - # --- Statistics Accessor (data only) --- @@ -1507,8 +1396,7 @@ def balance( first_var = next(iter(ds.data_vars)) unit_label = ds[first_var].attrs.get('unit', '') - fig = _create_stacked_bar( - ds, + fig = ds.fxplot.stacked_bar( colors=colors, title=f'{node} [{unit_label}]' if unit_label else node, facet_col=actual_facet_col, @@ -1632,8 +1520,7 @@ def carrier_balance( first_var = next(iter(ds.data_vars)) unit_label = ds[first_var].attrs.get('unit', '') - fig = _create_stacked_bar( - ds, + fig = ds.fxplot.stacked_bar( colors=colors, title=f'{carrier.capitalize()} Balance 
[{unit_label}]' if unit_label else f'{carrier.capitalize()} Balance', facet_col=actual_facet_col, @@ -1766,8 +1653,7 @@ def heatmap( if has_multiple_vars: da = da.rename('') - fig = _heatmap_figure( - da, + fig = da.fxplot.heatmap( colors=colors, facet_col=actual_facet, animation_frame=actual_animation, @@ -1861,8 +1747,7 @@ def flows( first_var = next(iter(ds.data_vars)) unit_label = ds[first_var].attrs.get('unit', '') - fig = _create_line( - ds, + fig = ds.fxplot.line( colors=colors, title=f'Flows [{unit_label}]' if unit_label else 'Flows', facet_col=actual_facet_col, @@ -2038,8 +1923,7 @@ def sort_descending(arr: np.ndarray) -> np.ndarray: first_var = next(iter(ds.data_vars)) unit_label = ds[first_var].attrs.get('unit', '') - fig = _create_line( - result_ds, + fig = result_ds.fxplot.line( colors=colors, title=f'Duration Curve [{unit_label}]' if unit_label else 'Duration Curve', facet_col=actual_facet_col, @@ -2258,8 +2142,7 @@ def charge_states( ds, facet_col, facet_row, animation_frame ) - fig = _create_line( - ds, + fig = ds.fxplot.line( colors=colors, title='Storage Charge States', facet_col=actual_facet_col, From 2b7aa63f8380b7fff0dc4e4002c42d97866010f3 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Thu, 1 Jan 2026 21:54:55 +0100 Subject: [PATCH 04/62] Fix notebook --- docs/notebooks/fxplot_accessor_demo.ipynb | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/notebooks/fxplot_accessor_demo.ipynb b/docs/notebooks/fxplot_accessor_demo.ipynb index d4d7b69b6..71f9b8245 100644 --- a/docs/notebooks/fxplot_accessor_demo.ipynb +++ b/docs/notebooks/fxplot_accessor_demo.ipynb @@ -18,7 +18,11 @@ "source": [ "import numpy as np\n", "import pandas as pd\n", - "import xarray as xr" + "import xarray as xr\n", + "\n", + "import flixopt as fx\n", + "\n", + "fx.__version__" ] }, { From 8d500093665f808dbc1271595318a19509f70782 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> 
Date: Fri, 2 Jan 2026 09:56:29 +0100 Subject: [PATCH 05/62] 1. xlabel/ylabel parameters - Added to bar(), stacked_bar(), line(), area(), and duration_curve() methods in both DatasetPlotAccessor and DataArrayPlotAccessor 2. scatter() method - Plots two variables against each other with x and y parameters 3. pie() method - Creates pie charts from aggregated (scalar) dataset values, e.g. ds.sum('time').fxplot.pie() 4. duration_curve() method - Sorts values along the time dimension in descending order, with optional normalize parameter for percentage x-axis 5. CONFIG.Plotting.default_line_shape - New config option (default 'hv') that controls the default line shape for line(), area(), and duration_curve() methods --- flixopt/config.py | 3 + flixopt/dataset_plot_accessor.py | 260 ++++++++++++++++++++++++++++++- 2 files changed, 255 insertions(+), 8 deletions(-) diff --git a/flixopt/config.py b/flixopt/config.py index 7e7c784cb..ad5db2897 100644 --- a/flixopt/config.py +++ b/flixopt/config.py @@ -163,6 +163,7 @@ def format(self, record): 'default_facet_cols': 3, 'default_sequential_colorscale': 'turbo', 'default_qualitative_colorscale': 'plotly', + 'default_line_shape': 'hv', 'extra_dim_priority': ('cluster', 'period', 'scenario'), 'dim_slot_priority': ('facet_col', 'facet_row', 'animation_frame'), } @@ -585,6 +586,7 @@ class Plotting: default_facet_cols: int = _DEFAULTS['plotting']['default_facet_cols'] default_sequential_colorscale: str = _DEFAULTS['plotting']['default_sequential_colorscale'] default_qualitative_colorscale: str = _DEFAULTS['plotting']['default_qualitative_colorscale'] + default_line_shape: str = _DEFAULTS['plotting']['default_line_shape'] extra_dim_priority: tuple[str, ...] = _DEFAULTS['plotting']['extra_dim_priority'] dim_slot_priority: tuple[str, ...] 
= _DEFAULTS['plotting']['dim_slot_priority'] @@ -687,6 +689,7 @@ def to_dict(cls) -> dict: 'default_facet_cols': cls.Plotting.default_facet_cols, 'default_sequential_colorscale': cls.Plotting.default_sequential_colorscale, 'default_qualitative_colorscale': cls.Plotting.default_qualitative_colorscale, + 'default_line_shape': cls.Plotting.default_line_shape, 'extra_dim_priority': cls.Plotting.extra_dim_priority, 'dim_slot_priority': cls.Plotting.dim_slot_priority, }, diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 9f836c56c..877f8598b 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -132,6 +132,8 @@ def bar( *, colors: ColorType | None = None, title: str = '', + xlabel: str = '', + ylabel: str = '', facet_col: str | Literal['auto'] | None = 'auto', facet_row: str | Literal['auto'] | None = None, animation_frame: str | Literal['auto'] | None = None, @@ -143,6 +145,8 @@ def bar( Args: colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. + xlabel: X-axis label. + ylabel: Y-axis label. facet_col: Dimension for column facets. 'auto' uses CONFIG priority. facet_row: Dimension for row facets. 'auto' uses CONFIG priority. animation_frame: Dimension for animation slider. 
@@ -175,6 +179,10 @@ def bar( 'barmode': 'group', **px_kwargs, } + if xlabel: + fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), x_col: xlabel} + if ylabel: + fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} if actual_facet_col: fig_kwargs['facet_col'] = actual_facet_col @@ -192,6 +200,8 @@ def stacked_bar( *, colors: ColorType | None = None, title: str = '', + xlabel: str = '', + ylabel: str = '', facet_col: str | Literal['auto'] | None = 'auto', facet_row: str | Literal['auto'] | None = None, animation_frame: str | Literal['auto'] | None = None, @@ -206,6 +216,8 @@ def stacked_bar( Args: colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. + xlabel: X-axis label. + ylabel: Y-axis label. facet_col: Dimension for column facets. 'auto' uses CONFIG priority. facet_row: Dimension for row facets. animation_frame: Dimension for animation slider. @@ -237,6 +249,10 @@ def stacked_bar( 'title': title, **px_kwargs, } + if xlabel: + fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), x_col: xlabel} + if ylabel: + fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} if actual_facet_col: fig_kwargs['facet_col'] = actual_facet_col @@ -257,11 +273,13 @@ def line( *, colors: ColorType | None = None, title: str = '', + xlabel: str = '', + ylabel: str = '', facet_col: str | Literal['auto'] | None = 'auto', facet_row: str | Literal['auto'] | None = None, animation_frame: str | Literal['auto'] | None = None, facet_cols: int | None = None, - line_shape: str = 'hv', + line_shape: str | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a line chart from the dataset. @@ -271,12 +289,14 @@ def line( Args: colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. + xlabel: X-axis label. + ylabel: Y-axis label. facet_col: Dimension for column facets. 'auto' uses CONFIG priority. facet_row: Dimension for row facets. 
animation_frame: Dimension for animation slider. facet_cols: Number of columns in facet grid wrap. line_shape: Line interpolation ('linear', 'hv', 'vh', 'hvh', 'vhv', 'spline'). - Default 'hv' for stepped lines. + Default from CONFIG.Plotting.default_line_shape. **px_kwargs: Additional arguments passed to plotly.express.line. Returns: @@ -302,9 +322,13 @@ def line( 'color': 'variable', 'color_discrete_map': color_map, 'title': title, - 'line_shape': line_shape, + 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, **px_kwargs, } + if xlabel: + fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), x_col: xlabel} + if ylabel: + fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} if actual_facet_col: fig_kwargs['facet_col'] = actual_facet_col @@ -322,11 +346,13 @@ def area( *, colors: ColorType | None = None, title: str = '', + xlabel: str = '', + ylabel: str = '', facet_col: str | Literal['auto'] | None = 'auto', facet_row: str | Literal['auto'] | None = None, animation_frame: str | Literal['auto'] | None = None, facet_cols: int | None = None, - line_shape: str = 'hv', + line_shape: str | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a stacked area chart from the dataset. @@ -334,11 +360,13 @@ def area( Args: colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. + xlabel: X-axis label. + ylabel: Y-axis label. facet_col: Dimension for column facets. 'auto' uses CONFIG priority. facet_row: Dimension for row facets. animation_frame: Dimension for animation slider. facet_cols: Number of columns in facet grid wrap. - line_shape: Line interpolation. Default 'hv' for stepped. + line_shape: Line interpolation. Default from CONFIG.Plotting.default_line_shape. **px_kwargs: Additional arguments passed to plotly.express.area. 
Returns: @@ -364,9 +392,13 @@ def area( 'color': 'variable', 'color_discrete_map': color_map, 'title': title, - 'line_shape': line_shape, + 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, **px_kwargs, } + if xlabel: + fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), x_col: xlabel} + if ylabel: + fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} if actual_facet_col: fig_kwargs['facet_col'] = actual_facet_col @@ -445,6 +477,202 @@ def heatmap( return px.imshow(**imshow_args) + def scatter( + self, + x: str, + y: str, + *, + colors: ColorType | None = None, + title: str = '', + xlabel: str = '', + ylabel: str = '', + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = None, + animation_frame: str | Literal['auto'] | None = None, + facet_cols: int | None = None, + **px_kwargs: Any, + ) -> go.Figure: + """Create a scatter plot from two variables in the dataset. + + Args: + x: Variable name for x-axis. + y: Variable name for y-axis. + colors: Color specification (colorscale name, color list, or dict mapping). + title: Plot title. + xlabel: X-axis label. + ylabel: Y-axis label. + facet_col: Dimension for column facets. 'auto' uses CONFIG priority. + facet_row: Dimension for row facets. + animation_frame: Dimension for animation slider. + facet_cols: Number of columns in facet grid wrap. + **px_kwargs: Additional arguments passed to plotly.express.scatter. + + Returns: + Plotly Figure. + """ + if x not in self._ds.data_vars: + raise ValueError(f"Variable '{x}' not found in dataset. Available: {list(self._ds.data_vars)}") + if y not in self._ds.data_vars: + raise ValueError(f"Variable '{y}' not found in dataset. 
Available: {list(self._ds.data_vars)}") + + df = self._ds[[x, y]].to_dataframe().reset_index() + if df.empty: + return go.Figure() + + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, facet_col, facet_row, animation_frame + ) + + facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols + fig_kwargs: dict[str, Any] = { + 'data_frame': df, + 'x': x, + 'y': y, + 'title': title, + **px_kwargs, + } + if xlabel: + fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), x: xlabel} + if ylabel: + fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), y: ylabel} + + if actual_facet_col: + fig_kwargs['facet_col'] = actual_facet_col + if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + fig_kwargs['facet_col_wrap'] = facet_col_wrap + if actual_facet_row: + fig_kwargs['facet_row'] = actual_facet_row + if actual_anim: + fig_kwargs['animation_frame'] = actual_anim + + return px.scatter(**fig_kwargs) + + def pie( + self, + *, + colors: ColorType | None = None, + title: str = '', + **px_kwargs: Any, + ) -> go.Figure: + """Create a pie chart from aggregated dataset values. + + The dataset should be reduced to scalar values per variable (e.g., via .sum()). + Each variable becomes a slice of the pie. + + Args: + colors: Color specification (colorscale name, color list, or dict mapping). + title: Plot title. + **px_kwargs: Additional arguments passed to plotly.express.pie. + + Returns: + Plotly Figure. + + Example: + >>> ds.sum('time').fxplot.pie() # Sum over time, then pie chart + """ + # Check that all variables are scalar + non_scalar = [v for v in self._ds.data_vars if self._ds[v].ndim > 0] + if non_scalar: + raise ValueError( + f'Pie chart requires scalar values per variable. ' + f'Non-scalar variables: {non_scalar}. 
' + f"Try reducing first: ds.sum('time').fxplot.pie()" + ) + + names = list(self._ds.data_vars) + values = [float(self._ds[v].values) for v in names] + df = pd.DataFrame({'variable': names, 'value': values}) + + color_map = process_colors(colors, names, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) + + return px.pie( + df, + names='variable', + values='value', + title=title, + color='variable', + color_discrete_map=color_map, + **px_kwargs, + ) + + def duration_curve( + self, + *, + colors: ColorType | None = None, + title: str = '', + xlabel: str = '', + ylabel: str = '', + normalize: bool = True, + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = None, + animation_frame: str | Literal['auto'] | None = None, + facet_cols: int | None = None, + line_shape: str | None = None, + **px_kwargs: Any, + ) -> go.Figure: + """Create a duration curve (sorted values) from the dataset. + + Values are sorted in descending order along the 'time' dimension. + The x-axis shows duration (percentage or timesteps). + + Args: + colors: Color specification (colorscale name, color list, or dict mapping). + title: Plot title. + xlabel: X-axis label. Default 'Duration [%]' or 'Timesteps'. + ylabel: Y-axis label. + normalize: If True, x-axis shows percentage (0-100). If False, shows timestep index. + facet_col: Dimension for column facets. 'auto' uses CONFIG priority. + facet_row: Dimension for row facets. + animation_frame: Dimension for animation slider. + facet_cols: Number of columns in facet grid wrap. + line_shape: Line interpolation. Default from CONFIG.Plotting.default_line_shape. + **px_kwargs: Additional arguments passed to plotly.express.line. + + Returns: + Plotly Figure. 
+ """ + import numpy as np + + if 'time' not in self._ds.dims: + raise ValueError("Duration curve requires a 'time' dimension.") + + # Sort each variable along time dimension (descending) + sorted_ds = self._ds.copy() + for var in sorted_ds.data_vars: + da = sorted_ds[var] + # Sort along time axis + sorted_values = np.sort(da.values, axis=da.dims.index('time'))[::-1] + sorted_ds[var] = (da.dims, sorted_values) + + # Replace time coordinate with duration + n_timesteps = sorted_ds.sizes['time'] + if normalize: + duration_coord = np.linspace(0, 100, n_timesteps) + sorted_ds = sorted_ds.assign_coords({'time': duration_coord}) + sorted_ds = sorted_ds.rename({'time': 'duration_pct'}) + default_xlabel = 'Duration [%]' + else: + duration_coord = np.arange(n_timesteps) + sorted_ds = sorted_ds.assign_coords({'time': duration_coord}) + sorted_ds = sorted_ds.rename({'time': 'duration'}) + default_xlabel = 'Timesteps' + + # Use line plot + fig = sorted_ds.fxplot.line( + colors=colors, + title=title or 'Duration Curve', + xlabel=xlabel or default_xlabel, + ylabel=ylabel, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + facet_cols=facet_cols, + line_shape=line_shape, + **px_kwargs, + ) + + return fig + @xr.register_dataarray_accessor('fxplot') class DataArrayPlotAccessor: @@ -479,6 +707,8 @@ def bar( *, colors: ColorType | None = None, title: str = '', + xlabel: str = '', + ylabel: str = '', facet_col: str | Literal['auto'] | None = 'auto', facet_row: str | Literal['auto'] | None = None, animation_frame: str | Literal['auto'] | None = None, @@ -489,6 +719,8 @@ def bar( return self._to_dataset().fxplot.bar( colors=colors, title=title, + xlabel=xlabel, + ylabel=ylabel, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, @@ -501,6 +733,8 @@ def stacked_bar( *, colors: ColorType | None = None, title: str = '', + xlabel: str = '', + ylabel: str = '', facet_col: str | Literal['auto'] | None = 'auto', facet_row: str | 
Literal['auto'] | None = None, animation_frame: str | Literal['auto'] | None = None, @@ -511,6 +745,8 @@ def stacked_bar( return self._to_dataset().fxplot.stacked_bar( colors=colors, title=title, + xlabel=xlabel, + ylabel=ylabel, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, @@ -523,17 +759,21 @@ def line( *, colors: ColorType | None = None, title: str = '', + xlabel: str = '', + ylabel: str = '', facet_col: str | Literal['auto'] | None = 'auto', facet_row: str | Literal['auto'] | None = None, animation_frame: str | Literal['auto'] | None = None, facet_cols: int | None = None, - line_shape: str = 'hv', + line_shape: str | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a line chart. See DatasetPlotAccessor.line for details.""" return self._to_dataset().fxplot.line( colors=colors, title=title, + xlabel=xlabel, + ylabel=ylabel, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, @@ -547,17 +787,21 @@ def area( *, colors: ColorType | None = None, title: str = '', + xlabel: str = '', + ylabel: str = '', facet_col: str | Literal['auto'] | None = 'auto', facet_row: str | Literal['auto'] | None = None, animation_frame: str | Literal['auto'] | None = None, facet_cols: int | None = None, - line_shape: str = 'hv', + line_shape: str | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a stacked area chart. 
See DatasetPlotAccessor.area for details.""" return self._to_dataset().fxplot.area( colors=colors, title=title, + xlabel=xlabel, + ylabel=ylabel, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, From 1e70c78d9144df621b2d9c1b551bf055417b7e7f Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Fri, 2 Jan 2026 10:24:47 +0100 Subject: [PATCH 06/62] Fix faceting of pie --- flixopt/dataset_plot_accessor.py | 72 +++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 19 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 877f8598b..112d15412 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -552,16 +552,22 @@ def pie( *, colors: ColorType | None = None, title: str = '', + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = None, + facet_cols: int | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a pie chart from aggregated dataset values. - The dataset should be reduced to scalar values per variable (e.g., via .sum()). - Each variable becomes a slice of the pie. + The dataset should be reduced so each variable has at most one remaining + dimension (for faceting). For scalar values, a single pie is shown. Args: colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. + facet_col: Dimension for column facets. 'auto' uses CONFIG priority. + facet_row: Dimension for row facets. + facet_cols: Number of columns in facet grid wrap. **px_kwargs: Additional arguments passed to plotly.express.pie. 
Returns: @@ -569,31 +575,59 @@ def pie( Example: >>> ds.sum('time').fxplot.pie() # Sum over time, then pie chart + >>> ds.sum('time').fxplot.pie(facet_col='scenario') # Pie per scenario """ - # Check that all variables are scalar - non_scalar = [v for v in self._ds.data_vars if self._ds[v].ndim > 0] - if non_scalar: + # Check dimensionality - allow at most 1D for faceting + max_ndim = max((self._ds[v].ndim for v in self._ds.data_vars), default=0) + if max_ndim > 1: raise ValueError( - f'Pie chart requires scalar values per variable. ' - f'Non-scalar variables: {non_scalar}. ' - f"Try reducing first: ds.sum('time').fxplot.pie()" + 'Pie chart requires at most 1D data per variable (for faceting). ' + "Try reducing first: ds.sum('time').fxplot.pie()" ) names = list(self._ds.data_vars) - values = [float(self._ds[v].values) for v in names] - df = pd.DataFrame({'variable': names, 'value': values}) - color_map = process_colors(colors, names, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) - return px.pie( - df, - names='variable', - values='value', - title=title, - color='variable', - color_discrete_map=color_map, + # Scalar case - single pie + if max_ndim == 0: + values = [float(self._ds[v].values) for v in names] + df = pd.DataFrame({'variable': names, 'value': values}) + return px.pie( + df, + names='variable', + values='value', + title=title, + color='variable', + color_discrete_map=color_map, + **px_kwargs, + ) + + # 1D case - faceted pies + df = _dataset_to_long_df(self._ds) + if df.empty: + return go.Figure() + + actual_facet_col, actual_facet_row, _ = _resolve_auto_facets(self._ds, facet_col, facet_row, None) + + facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols + fig_kwargs: dict[str, Any] = { + 'data_frame': df, + 'names': 'variable', + 'values': 'value', + 'title': title, + 'color': 'variable', + 'color_discrete_map': color_map, **px_kwargs, - ) + } + + if actual_facet_col: + fig_kwargs['facet_col'] = actual_facet_col + if 
facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + fig_kwargs['facet_col_wrap'] = facet_col_wrap + if actual_facet_row: + fig_kwargs['facet_row'] = actual_facet_row + + return px.pie(**fig_kwargs) def duration_curve( self, From 7be17d02e786ca4b8a5080857e6043935586c1cc Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Fri, 2 Jan 2026 10:35:16 +0100 Subject: [PATCH 07/62] Improve auto dim handling --- flixopt/dataset_plot_accessor.py | 46 ++++++++++++++++---------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 112d15412..276081dee 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -135,8 +135,8 @@ def bar( xlabel: str = '', ylabel: str = '', facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = None, - animation_frame: str | Literal['auto'] | None = None, + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, **px_kwargs: Any, ) -> go.Figure: @@ -203,8 +203,8 @@ def stacked_bar( xlabel: str = '', ylabel: str = '', facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = None, - animation_frame: str | Literal['auto'] | None = None, + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, **px_kwargs: Any, ) -> go.Figure: @@ -276,8 +276,8 @@ def line( xlabel: str = '', ylabel: str = '', facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = None, - animation_frame: str | Literal['auto'] | None = None, + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, line_shape: str | None = None, **px_kwargs: Any, @@ 
-349,8 +349,8 @@ def area( xlabel: str = '', ylabel: str = '', facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = None, - animation_frame: str | Literal['auto'] | None = None, + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, line_shape: str | None = None, **px_kwargs: Any, @@ -418,7 +418,7 @@ def heatmap( colors: str | list[str] | None = None, title: str = '', facet_col: str | Literal['auto'] | None = 'auto', - animation_frame: str | Literal['auto'] | None = None, + animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, **imshow_kwargs: Any, ) -> go.Figure: @@ -487,8 +487,8 @@ def scatter( xlabel: str = '', ylabel: str = '', facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = None, - animation_frame: str | Literal['auto'] | None = None, + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, **px_kwargs: Any, ) -> go.Figure: @@ -553,7 +553,7 @@ def pie( colors: ColorType | None = None, title: str = '', facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = None, + facet_row: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, **px_kwargs: Any, ) -> go.Figure: @@ -638,8 +638,8 @@ def duration_curve( ylabel: str = '', normalize: bool = True, facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = None, - animation_frame: str | Literal['auto'] | None = None, + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, line_shape: str | None = None, **px_kwargs: Any, @@ -744,8 +744,8 @@ def bar( xlabel: str = '', ylabel: str = '', facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | 
Literal['auto'] | None = None, - animation_frame: str | Literal['auto'] | None = None, + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, **px_kwargs: Any, ) -> go.Figure: @@ -770,8 +770,8 @@ def stacked_bar( xlabel: str = '', ylabel: str = '', facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = None, - animation_frame: str | Literal['auto'] | None = None, + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, **px_kwargs: Any, ) -> go.Figure: @@ -796,8 +796,8 @@ def line( xlabel: str = '', ylabel: str = '', facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = None, - animation_frame: str | Literal['auto'] | None = None, + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, line_shape: str | None = None, **px_kwargs: Any, @@ -824,8 +824,8 @@ def area( xlabel: str = '', ylabel: str = '', facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = None, - animation_frame: str | Literal['auto'] | None = None, + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, line_shape: str | None = None, **px_kwargs: Any, @@ -850,7 +850,7 @@ def heatmap( colors: str | list[str] | None = None, title: str = '', facet_col: str | Literal['auto'] | None = 'auto', - animation_frame: str | Literal['auto'] | None = None, + animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, **imshow_kwargs: Any, ) -> go.Figure: From da72bb85a893433225c8fac10282f1e94fd873aa Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Fri, 2 Jan 2026 10:39:45 +0100 Subject: [PATCH 
08/62] Improve notebook --- docs/notebooks/fxplot_accessor_demo.ipynb | 268 +++++++++++++++++++--- 1 file changed, 238 insertions(+), 30 deletions(-) diff --git a/docs/notebooks/fxplot_accessor_demo.ipynb b/docs/notebooks/fxplot_accessor_demo.ipynb index 71f9b8245..934b819cb 100644 --- a/docs/notebooks/fxplot_accessor_demo.ipynb +++ b/docs/notebooks/fxplot_accessor_demo.ipynb @@ -151,9 +151,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Multi-Dimensional Data with Auto-Faceting\n", + "## Automatic Faceting & Animation\n", "\n", - "The accessor automatically handles extra dimensions by assigning them to facets or animation based on CONFIG priority." + "Extra dimensions are **automatically** assigned to `facet_col`, `facet_row`, and `animation_frame` based on CONFIG priority. Just call the plot method - no configuration needed!" ] }, { @@ -162,35 +162,20 @@ "metadata": {}, "outputs": [], "source": [ - "# Dataset with scenario dimension\n", - "ds_scenarios = xr.Dataset(\n", + "# Dataset with scenario AND period dimensions\n", + "ds_multi = xr.Dataset(\n", " {\n", - " 'Solar': (\n", - " ['time', 'scenario'],\n", - " np.column_stack(\n", - " [\n", - " np.maximum(0, np.sin(np.linspace(0, 2 * np.pi, 24)) * 50),\n", - " np.maximum(0, np.sin(np.linspace(0, 2 * np.pi, 24)) * 70), # High scenario\n", - " ]\n", - " ),\n", - " ),\n", - " 'Wind': (\n", - " ['time', 'scenario'],\n", - " np.column_stack(\n", - " [\n", - " np.abs(np.random.randn(24) * 20 + 30),\n", - " np.abs(np.random.randn(24) * 25 + 40), # High scenario\n", - " ]\n", - " ),\n", - " ),\n", + " 'Solar': (['time', 'scenario', 'period'], np.random.rand(24, 2, 3) * 50),\n", + " 'Wind': (['time', 'scenario', 'period'], np.random.rand(24, 2, 3) * 40 + 20),\n", " },\n", " coords={\n", " 'time': time,\n", " 'scenario': ['base', 'high'],\n", + " 'period': ['winter', 'spring', 'summer'],\n", " },\n", ")\n", "\n", - "ds_scenarios" + "ds_multi" ] }, { @@ -199,8 +184,8 @@ "metadata": {}, "outputs": [], 
"source": [ - "# Auto-faceting assigns 'scenario' to facet_col\n", - "ds_scenarios.fxplot.line(title='Generation by Scenario (Auto-Faceted)')" + "# Just call .line() - dimensions are auto-assigned to facet_col, facet_row, animation_frame\n", + "ds_multi.fxplot.line(title='Auto-Faceted: Just Works!')" ] }, { @@ -209,15 +194,17 @@ "metadata": {}, "outputs": [], "source": [ - "# Explicit faceting\n", - "ds_scenarios.fxplot.stacked_bar(facet_col='scenario', title='Stacked Bar by Scenario')" + "# Same for stacked bar - auto-assigns period to facet_col, scenario to animation\n", + "ds_multi.fxplot.stacked_bar(title='Stacked Bar: Also Just Works!')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Animation Support" + "## Customizing Facets & Animation\n", + "\n", + "Override auto-assignment when needed. Use `None` to disable a slot entirely." ] }, { @@ -226,8 +213,28 @@ "metadata": {}, "outputs": [], "source": [ - "# Use animation instead of faceting\n", - "ds_scenarios.fxplot.area(facet_col=None, animation_frame='scenario', title='Animated by Scenario')" + "# Swap: put scenario in facet_col, period in animation\n", + "ds_multi.fxplot.line(facet_col='scenario', animation_frame='period', title='Swapped: Scenario in Columns')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use both row and column facets - no animation\n", + "ds_multi.fxplot.area(facet_col='scenario', facet_row='period', animation_frame=None, title='Grid: Period × Scenario')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Or reduce dimensions with .sel() for a simpler plot\n", + "ds_multi.sel(scenario='base', period='summer').fxplot.line(title='Single Slice: No Faceting Needed')" ] }, { @@ -309,6 +316,207 @@ "# Select specific variables\n", "ds_simple[['Solar', 'Wind']].fxplot.area(title='Renewables Only')" ] + }, + { + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "## DataArray Accessor\n", + "\n", + "The `.fxplot` accessor also works on `xr.DataArray` objects directly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a DataArray\n", + "da = xr.DataArray(\n", + " np.random.randn(24, 7) * 5 + 20,\n", + " dims=['time', 'day'],\n", + " coords={\n", + " 'time': pd.date_range('2024-01-01', periods=24, freq='h'),\n", + " 'day': pd.date_range('2024-01-01', periods=7, freq='D'),\n", + " },\n", + " name='temperature',\n", + ")\n", + "\n", + "# Heatmap directly from DataArray\n", + "da.fxplot.heatmap(title='DataArray Heatmap')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Line plot from DataArray (converts to Dataset internally)\n", + "da_1d = xr.DataArray(\n", + " np.sin(np.linspace(0, 4 * np.pi, 100)) * 50,\n", + " dims=['time'],\n", + " coords={'time': pd.date_range('2024-01-01', periods=100, freq='h')},\n", + " name='signal',\n", + ")\n", + "da_1d.fxplot.line(title='DataArray Line Plot')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Axis Labels\n", + "\n", + "Use `xlabel` and `ylabel` parameters to customize axis labels." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds_simple.fxplot.line(title='Generation with Custom Axis Labels', xlabel='Time of Day', ylabel='Power [MW]')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Scatter Plot\n", + "\n", + "Plot two variables against each other." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Basic scatter plot\n", + "ds_simple.fxplot.scatter(\n", + " x='Solar', y='Demand', title='Solar vs Demand Correlation', xlabel='Solar Generation [MW]', ylabel='Demand [MW]'\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Scatter with faceting by period, for one scenario\n", + "ds_multi.sel(scenario='high').fxplot.scatter(\n", + " x='Solar', y='Wind', facet_col='period', title='Solar vs Wind by Period (High Scenario)'\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pie Chart\n", + "\n", + "Aggregate data to at most 1D per variable. Scalar data creates a single pie; 1D data creates faceted pies." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Single pie from scalar values (sum over time)\n", + "ds_simple[['Solar', 'Wind']].sum('time').fxplot.pie(\n", + " title='Total Generation by Source', colors={'Solar': 'gold', 'Wind': 'skyblue'}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Faceted pie - one pie per period (for one scenario)\n", + "ds_multi.sum('time').fxplot.pie(\n", + " facet_col='period',\n", + " title='Generation by Source per Period (Base Scenario)',\n", + " colors={'Solar': 'gold', 'Wind': 'skyblue'},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Duration Curve\n", + "\n", + "Sort values along the time dimension to create a duration curve. Useful for analyzing capacity utilization." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Duration curve with normalized x-axis (percentage)\n", + "ds_simple.fxplot.duration_curve(title='Generation Duration Curves')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Duration curve with absolute timesteps on x-axis\n", + "ds_simple.fxplot.duration_curve(normalize=False, title='Duration Curves (Timesteps)')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Duration curve with faceting by period (for one scenario)\n", + "ds_multi.sel(scenario='base').fxplot.duration_curve(facet_col='period', title='Duration Curves by Period')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Line Shape Configuration\n", + "\n", + "The default line shape is controlled by `CONFIG.Plotting.default_line_shape` (default: `'hv'` for step plots).\n", + "Override per-plot with the `line_shape` parameter. Options: `'linear'`, `'hv'`, `'vh'`, `'hvh'`, `'vhv'`, `'spline'`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Default step plot (hv)\n", + "ds_simple[['Solar']].fxplot.line(title='Default Step Plot (hv)')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Override to linear interpolation\n", + "ds_simple[['Solar']].fxplot.line(line_shape='linear', title='Linear Interpolation')" + ] } ], "metadata": { From 31bfb85d957061cec2ab3c1c4fbb39de894a0559 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Fri, 2 Jan 2026 10:44:35 +0100 Subject: [PATCH 09/62] Fix pie plot --- flixopt/dataset_plot_accessor.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 276081dee..ed235074b 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -554,19 +554,21 @@ def pie( title: str = '', facet_col: str | Literal['auto'] | None = 'auto', facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a pie chart from aggregated dataset values. - The dataset should be reduced so each variable has at most one remaining - dimension (for faceting). For scalar values, a single pie is shown. + Extra dimensions are auto-assigned to facet_col, facet_row, and animation_frame. + For scalar values, a single pie is shown. Args: colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. facet_col: Dimension for column facets. 'auto' uses CONFIG priority. - facet_row: Dimension for row facets. + facet_row: Dimension for row facets. 'auto' uses CONFIG priority. + animation_frame: Dimension for animation slider. 'auto' uses CONFIG priority. facet_cols: Number of columns in facet grid wrap. 
**px_kwargs: Additional arguments passed to plotly.express.pie. @@ -577,13 +579,7 @@ def pie( >>> ds.sum('time').fxplot.pie() # Sum over time, then pie chart >>> ds.sum('time').fxplot.pie(facet_col='scenario') # Pie per scenario """ - # Check dimensionality - allow at most 1D for faceting max_ndim = max((self._ds[v].ndim for v in self._ds.data_vars), default=0) - if max_ndim > 1: - raise ValueError( - 'Pie chart requires at most 1D data per variable (for faceting). ' - "Try reducing first: ds.sum('time').fxplot.pie()" - ) names = list(self._ds.data_vars) color_map = process_colors(colors, names, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) @@ -602,12 +598,14 @@ def pie( **px_kwargs, ) - # 1D case - faceted pies + # Multi-dimensional case - faceted/animated pies df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - actual_facet_col, actual_facet_row, _ = _resolve_auto_facets(self._ds, facet_col, facet_row, None) + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, facet_col, facet_row, animation_frame + ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { @@ -626,6 +624,8 @@ def pie( fig_kwargs['facet_col_wrap'] = facet_col_wrap if actual_facet_row: fig_kwargs['facet_row'] = actual_facet_row + if actual_anim: + fig_kwargs['animation_frame'] = actual_anim return px.pie(**fig_kwargs) From 450ec0e7a7db8e8909907066f4b2e8e60c4d9d42 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Fri, 2 Jan 2026 11:36:02 +0100 Subject: [PATCH 10/62] Logic order changed: 1. X-axis is now determined first using CONFIG.Plotting.x_dim_priority 2. 
Facets are resolved from remaining dimensions (x-axis excluded) x_dim_priority expanded: x_dim_priority = ('time', 'duration', 'duration_pct', 'period', 'scenario', 'cluster') - Time-like dims first, then common grouping dims as fallback - variable stays excluded (it's used for color, not x-axis) _get_x_dim() refactored: - Now takes dims: list[str] instead of a DataFrame - More versatile - works with any list of dimension names --- flixopt/config.py | 5 + flixopt/dataset_plot_accessor.py | 165 ++++++++++++++++++------------- 2 files changed, 102 insertions(+), 68 deletions(-) diff --git a/flixopt/config.py b/flixopt/config.py index ad5db2897..87d16615a 100644 --- a/flixopt/config.py +++ b/flixopt/config.py @@ -166,6 +166,7 @@ def format(self, record): 'default_line_shape': 'hv', 'extra_dim_priority': ('cluster', 'period', 'scenario'), 'dim_slot_priority': ('facet_col', 'facet_row', 'animation_frame'), + 'x_dim_priority': ('time', 'duration', 'duration_pct', 'period', 'scenario', 'cluster'), } ), 'solving': MappingProxyType( @@ -565,6 +566,8 @@ class Plotting: Default: ('cluster', 'period', 'scenario'). dim_slot_priority: Order of slots to fill with extra dimensions. Default: ('facet_col', 'facet_row', 'animation_frame'). + x_dim_priority: Order of dimensions to prefer for x-axis when 'auto'. + Default: ('time', 'duration', 'duration_pct', 'period', 'scenario', 'cluster'). Examples: ```python @@ -589,6 +592,7 @@ class Plotting: default_line_shape: str = _DEFAULTS['plotting']['default_line_shape'] extra_dim_priority: tuple[str, ...] = _DEFAULTS['plotting']['extra_dim_priority'] dim_slot_priority: tuple[str, ...] = _DEFAULTS['plotting']['dim_slot_priority'] + x_dim_priority: tuple[str, ...] = _DEFAULTS['plotting']['x_dim_priority'] class Carriers: """Default carrier definitions for common energy types. 
@@ -692,6 +696,7 @@ def to_dict(cls) -> dict: 'default_line_shape': cls.Plotting.default_line_shape, 'extra_dim_priority': cls.Plotting.extra_dim_priority, 'dim_slot_priority': cls.Plotting.dim_slot_priority, + 'x_dim_priority': cls.Plotting.x_dim_priority, }, } diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index ed235074b..34253327b 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -25,11 +25,34 @@ from .config import CONFIG +def _get_x_dim(dims: list[str], x: str | Literal['auto'] | None = 'auto') -> str: + """Determine the x-axis dimension from available dimensions. + + Args: + dims: List of available dimension names. + x: Explicit dimension name, 'auto' to use priority list, or None. + + Returns: + Dimension name to use for x-axis. + """ + if x and x != 'auto': + return x + + # Check priority list first + for dim in CONFIG.Plotting.x_dim_priority: + if dim in dims: + return dim + + # Fallback to first available dimension + return dims[0] if dims else '' + + def _resolve_auto_facets( ds: xr.Dataset, facet_col: str | Literal['auto'] | None, facet_row: str | Literal['auto'] | None, animation_frame: str | Literal['auto'] | None = None, + exclude_dims: set[str] | None = None, ) -> tuple[str | None, str | None, str | None]: """Resolve 'auto' facet/animation dimensions based on available data dimensions. @@ -42,13 +65,15 @@ def _resolve_auto_facets( facet_col: Dimension name, 'auto', or None. facet_row: Dimension name, 'auto', or None. animation_frame: Dimension name, 'auto', or None. + exclude_dims: Dimensions to exclude (e.g., x-axis dimension). Returns: Tuple of (resolved_facet_col, resolved_facet_row, resolved_animation_frame). Each is either a valid dimension name or None. 
""" - # Get available extra dimensions with size > 1, sorted by priority - available = {d for d in ds.dims if ds.sizes[d] > 1} + # Get available extra dimensions with size > 1, excluding specified dims + exclude = exclude_dims or set() + available = {d for d in ds.dims if ds.sizes[d] > 1 and d not in exclude} extra_dims = [d for d in CONFIG.Plotting.extra_dim_priority if d in available] used: set[str] = set() @@ -130,6 +155,7 @@ def __init__(self, xarray_obj: xr.Dataset) -> None: def bar( self, *, + x: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -143,6 +169,7 @@ def bar( """Create a grouped bar chart from the dataset. Args: + x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. xlabel: X-axis label. @@ -156,18 +183,20 @@ def bar( Returns: Plotly Figure. """ + # Determine x-axis first, then resolve facets from remaining dims + dims = list(self._ds.dims) + x_col = _get_x_dim(dims, x) + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, facet_col, facet_row, animation_frame, exclude_dims={x_col} + ) + df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - x_col = 'time' if 'time' in df.columns else df.columns[0] variables = df['variable'].unique().tolist() color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame - ) - facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { 'data_frame': df, @@ -198,6 +227,7 @@ def bar( def stacked_bar( self, *, + x: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -214,6 +244,7 @@ def stacked_bar( values are stacked separately. 
Args: + x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. xlabel: X-axis label. @@ -227,18 +258,20 @@ def stacked_bar( Returns: Plotly Figure. """ + # Determine x-axis first, then resolve facets from remaining dims + dims = list(self._ds.dims) + x_col = _get_x_dim(dims, x) + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, facet_col, facet_row, animation_frame, exclude_dims={x_col} + ) + df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - x_col = 'time' if 'time' in df.columns else df.columns[0] variables = df['variable'].unique().tolist() color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame - ) - facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { 'data_frame': df, @@ -271,6 +304,7 @@ def stacked_bar( def line( self, *, + x: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -287,6 +321,7 @@ def line( Each variable in the dataset becomes a separate line. Args: + x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. xlabel: X-axis label. @@ -302,18 +337,20 @@ def line( Returns: Plotly Figure. 
""" + # Determine x-axis first, then resolve facets from remaining dims + dims = list(self._ds.dims) + x_col = _get_x_dim(dims, x) + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, facet_col, facet_row, animation_frame, exclude_dims={x_col} + ) + df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - x_col = 'time' if 'time' in df.columns else df.columns[0] variables = df['variable'].unique().tolist() color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame - ) - facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { 'data_frame': df, @@ -344,6 +381,7 @@ def line( def area( self, *, + x: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -358,6 +396,7 @@ def area( """Create a stacked area chart from the dataset. Args: + x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. xlabel: X-axis label. @@ -372,18 +411,20 @@ def area( Returns: Plotly Figure. 
""" + # Determine x-axis first, then resolve facets from remaining dims + dims = list(self._ds.dims) + x_col = _get_x_dim(dims, x) + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, facet_col, facet_row, animation_frame, exclude_dims={x_col} + ) + df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - x_col = 'time' if 'time' in df.columns else df.columns[0] variables = df['variable'].unique().tolist() color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame - ) - facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { 'data_frame': df, @@ -629,41 +670,37 @@ def pie( return px.pie(**fig_kwargs) - def duration_curve( - self, - *, - colors: ColorType | None = None, - title: str = '', - xlabel: str = '', - ylabel: str = '', - normalize: bool = True, - facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = 'auto', - animation_frame: str | Literal['auto'] | None = 'auto', - facet_cols: int | None = None, - line_shape: str | None = None, - **px_kwargs: Any, - ) -> go.Figure: - """Create a duration curve (sorted values) from the dataset. + +@xr.register_dataset_accessor('fxstats') +class DatasetStatsAccessor: + """Statistics/transformation accessor for any xr.Dataset. Access via ``dataset.fxstats``. + + Provides data transformation methods that return new datasets. + Chain with ``.fxplot`` for visualization. + + Examples: + Duration curve:: + + ds.fxstats.to_duration_curve().fxplot.line() + """ + + def __init__(self, xarray_obj: xr.Dataset) -> None: + self._ds = xarray_obj + + def to_duration_curve(self, *, normalize: bool = True) -> xr.Dataset: + """Transform dataset to duration curve format (sorted values). 
Values are sorted in descending order along the 'time' dimension. - The x-axis shows duration (percentage or timesteps). + The time coordinate is replaced with duration (percentage or index). Args: - colors: Color specification (colorscale name, color list, or dict mapping). - title: Plot title. - xlabel: X-axis label. Default 'Duration [%]' or 'Timesteps'. - ylabel: Y-axis label. normalize: If True, x-axis shows percentage (0-100). If False, shows timestep index. - facet_col: Dimension for column facets. 'auto' uses CONFIG priority. - facet_row: Dimension for row facets. - animation_frame: Dimension for animation slider. - facet_cols: Number of columns in facet grid wrap. - line_shape: Line interpolation. Default from CONFIG.Plotting.default_line_shape. - **px_kwargs: Additional arguments passed to plotly.express.line. Returns: - Plotly Figure. + Transformed xr.Dataset with duration coordinate instead of time. + + Example: + >>> ds.fxstats.to_duration_curve().fxplot.line(title='Duration Curve') """ import numpy as np @@ -674,7 +711,7 @@ def duration_curve( sorted_ds = self._ds.copy() for var in sorted_ds.data_vars: da = sorted_ds[var] - # Sort along time axis + # Sort along time axis (descending) sorted_values = np.sort(da.values, axis=da.dims.index('time'))[::-1] sorted_ds[var] = (da.dims, sorted_values) @@ -684,28 +721,12 @@ def duration_curve( duration_coord = np.linspace(0, 100, n_timesteps) sorted_ds = sorted_ds.assign_coords({'time': duration_coord}) sorted_ds = sorted_ds.rename({'time': 'duration_pct'}) - default_xlabel = 'Duration [%]' else: duration_coord = np.arange(n_timesteps) sorted_ds = sorted_ds.assign_coords({'time': duration_coord}) sorted_ds = sorted_ds.rename({'time': 'duration'}) - default_xlabel = 'Timesteps' - - # Use line plot - fig = sorted_ds.fxplot.line( - colors=colors, - title=title or 'Duration Curve', - xlabel=xlabel or default_xlabel, - ylabel=ylabel, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - 
facet_cols=facet_cols, - line_shape=line_shape, - **px_kwargs, - ) - return fig + return sorted_ds @xr.register_dataarray_accessor('fxplot') @@ -739,6 +760,7 @@ def _to_dataset(self) -> xr.Dataset: def bar( self, *, + x: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -751,6 +773,7 @@ def bar( ) -> go.Figure: """Create a grouped bar chart. See DatasetPlotAccessor.bar for details.""" return self._to_dataset().fxplot.bar( + x=x, colors=colors, title=title, xlabel=xlabel, @@ -765,6 +788,7 @@ def bar( def stacked_bar( self, *, + x: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -777,6 +801,7 @@ def stacked_bar( ) -> go.Figure: """Create a stacked bar chart. See DatasetPlotAccessor.stacked_bar for details.""" return self._to_dataset().fxplot.stacked_bar( + x=x, colors=colors, title=title, xlabel=xlabel, @@ -791,6 +816,7 @@ def stacked_bar( def line( self, *, + x: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -804,6 +830,7 @@ def line( ) -> go.Figure: """Create a line chart. See DatasetPlotAccessor.line for details.""" return self._to_dataset().fxplot.line( + x=x, colors=colors, title=title, xlabel=xlabel, @@ -819,6 +846,7 @@ def line( def area( self, *, + x: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -832,6 +860,7 @@ def area( ) -> go.Figure: """Create a stacked area chart. 
See DatasetPlotAccessor.area for details.""" return self._to_dataset().fxplot.area( + x=x, colors=colors, title=title, xlabel=xlabel, From 29752381a1501f41e02777fe08f2440273db7970 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Fri, 2 Jan 2026 12:24:26 +0100 Subject: [PATCH 11/62] Add x parameter and x_dim_priority config to fxplot - Add `x` parameter to bar/stacked_bar/line/area for explicit x-axis control - Add CONFIG.Plotting.x_dim_priority for auto x-axis selection order - X-axis determined first, facets from remaining dimensions - Refactor _get_x_column -> _get_x_dim (takes dim list, not DataFrame) - Support scalar data (no dims) by using 'variable' as x-axis --- docs/notebooks/fxplot_accessor_demo.ipynb | 41 +++++++++++++++++------ flixopt/dataset_plot_accessor.py | 6 ++-- 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/docs/notebooks/fxplot_accessor_demo.ipynb b/docs/notebooks/fxplot_accessor_demo.ipynb index 934b819cb..bcf065cc8 100644 --- a/docs/notebooks/fxplot_accessor_demo.ipynb +++ b/docs/notebooks/fxplot_accessor_demo.ipynb @@ -198,6 +198,16 @@ "ds_multi.fxplot.stacked_bar(title='Stacked Bar: Also Just Works!')" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Same for stacked bar - auto-assigns period to facet_col, scenario to animation\n", + "ds_multi.sum('time').fxplot.stacked_bar(title='Stacked Bar: Also Just Works!', x='variable', colors=None)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -224,7 +234,9 @@ "outputs": [], "source": [ "# Use both row and column facets - no animation\n", - "ds_multi.fxplot.area(facet_col='scenario', facet_row='period', animation_frame=None, title='Grid: Period × Scenario')" + "ds_multi.sum('time').fxplot.area(\n", + " facet_col='scenario', facet_row='period', animation_frame=None, title='Grid: Period × Scenario'\n", + ")" ] }, { @@ -441,10 +453,9 @@ "metadata": {}, "outputs": [], 
"source": [ - "# Faceted pie - one pie per period (for one scenario)\n", + "# Faceted pie - auto-assigns scenario and period to facets\n", "ds_multi.sum('time').fxplot.pie(\n", - " facet_col='period',\n", - " title='Generation by Source per Period (Base Scenario)',\n", + " title='Generation by Source (Scenario × Period)',\n", " colors={'Solar': 'gold', 'Wind': 'skyblue'},\n", ")" ] @@ -455,7 +466,7 @@ "source": [ "## Duration Curve\n", "\n", - "Sort values along the time dimension to create a duration curve. Useful for analyzing capacity utilization." + "Use `.fxstats.to_duration_curve()` to transform data, then `.fxplot.line()` to plot. Clean separation of transformation and plotting." ] }, { @@ -465,7 +476,7 @@ "outputs": [], "source": [ "# Duration curve with normalized x-axis (percentage)\n", - "ds_simple.fxplot.duration_curve(title='Generation Duration Curves')" + "ds_simple.fxstats.to_duration_curve().fxplot.line(title='Duration Curves', xlabel='Duration [%]')" ] }, { @@ -474,8 +485,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Duration curve with absolute timesteps on x-axis\n", - "ds_simple.fxplot.duration_curve(normalize=False, title='Duration Curves (Timesteps)')" + "# Duration curve with absolute timesteps\n", + "ds_simple.fxstats.to_duration_curve(normalize=False).fxplot.line(title='Duration Curves', xlabel='Timesteps')" ] }, { @@ -484,8 +495,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Duration curve with faceting by period (for one scenario)\n", - "ds_multi.sel(scenario='base').fxplot.duration_curve(facet_col='period', title='Duration Curves by Period')" + "# Duration curve with auto-faceting - works seamlessly!\n", + "ds_multi.fxstats.to_duration_curve().fxplot.line(title='Duration Curves (Auto-Faceted)', xlabel='Duration [%]')" ] }, { @@ -526,8 +537,16 @@ "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", - 
"version": "3.11.0" + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" } }, "nbformat": 4, diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 34253327b..acc8a0d44 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -33,7 +33,7 @@ def _get_x_dim(dims: list[str], x: str | Literal['auto'] | None = 'auto') -> str x: Explicit dimension name, 'auto' to use priority list, or None. Returns: - Dimension name to use for x-axis. + Dimension name to use for x-axis. Returns 'variable' for scalar data. """ if x and x != 'auto': return x @@ -43,8 +43,8 @@ def _get_x_dim(dims: list[str], x: str | Literal['auto'] | None = 'auto') -> str if dim in dims: return dim - # Fallback to first available dimension - return dims[0] if dims else '' + # Fallback to first available dimension, or 'variable' for scalar data + return dims[0] if dims else 'variable' def _resolve_auto_facets( From 21350db6d56fd572f302e69d6d8b0ccf7744a1ee Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Fri, 2 Jan 2026 12:32:10 +0100 Subject: [PATCH 12/62] Add x parameter and smart dimension handling to fxplot - Add `x` parameter to bar/stacked_bar/line/area for explicit x-axis control - Add CONFIG.Plotting.x_dim_priority for auto x-axis selection Default: ('time', 'duration', 'duration_pct', 'period', 'scenario', 'cluster') - X-axis determined first, facets resolved from remaining dimensions - Refactor _get_x_column -> _get_x_dim (takes dim list, more versatile) - Support scalar data (no dims) by using 'variable' as x-axis - Skip color='variable' when x='variable' to avoid double encoding - Fix _dataset_to_long_df to use dims (not just coords) as id_vars --- flixopt/dataset_plot_accessor.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 
acc8a0d44..ba8bbe6d8 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -112,9 +112,9 @@ def _dataset_to_long_df(ds: xr.Dataset, value_name: str = 'value', var_name: str rows = [{var_name: var, value_name: float(ds[var].values)} for var in ds.data_vars] return pd.DataFrame(rows) df = ds.to_dataframe().reset_index() - # Only use coordinates that are actually present as columns after reset_index - coord_cols = [c for c in ds.coords.keys() if c in df.columns] - return df.melt(id_vars=coord_cols, var_name=var_name, value_name=value_name) + # Use dims (not just coords) as id_vars - dims without coords become integer indices + id_cols = [c for c in ds.dims if c in df.columns] + return df.melt(id_vars=id_cols, var_name=var_name, value_name=value_name) @xr.register_dataset_accessor('fxplot') @@ -202,12 +202,14 @@ def bar( 'data_frame': df, 'x': x_col, 'y': 'value', - 'color': 'variable', - 'color_discrete_map': color_map, 'title': title, 'barmode': 'group', **px_kwargs, } + # Only color by variable if it's not already on x-axis + if x_col != 'variable': + fig_kwargs['color'] = 'variable' + fig_kwargs['color_discrete_map'] = color_map if xlabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), x_col: xlabel} if ylabel: @@ -277,11 +279,13 @@ def stacked_bar( 'data_frame': df, 'x': x_col, 'y': 'value', - 'color': 'variable', - 'color_discrete_map': color_map, 'title': title, **px_kwargs, } + # Only color by variable if it's not already on x-axis + if x_col != 'variable': + fig_kwargs['color'] = 'variable' + fig_kwargs['color_discrete_map'] = color_map if xlabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), x_col: xlabel} if ylabel: @@ -356,12 +360,14 @@ def line( 'data_frame': df, 'x': x_col, 'y': 'value', - 'color': 'variable', - 'color_discrete_map': color_map, 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, **px_kwargs, } + # Only color by variable if it's not already on x-axis + if 
x_col != 'variable': + fig_kwargs['color'] = 'variable' + fig_kwargs['color_discrete_map'] = color_map if xlabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), x_col: xlabel} if ylabel: @@ -430,12 +436,14 @@ def area( 'data_frame': df, 'x': x_col, 'y': 'value', - 'color': 'variable', - 'color_discrete_map': color_map, 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, **px_kwargs, } + # Only color by variable if it's not already on x-axis + if x_col != 'variable': + fig_kwargs['color'] = 'variable' + fig_kwargs['color_discrete_map'] = color_map if xlabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), x_col: xlabel} if ylabel: From d1f1a39b7f59ba1e6ddcc6dc88aa0685c8bc40ec Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Fri, 2 Jan 2026 12:55:44 +0100 Subject: [PATCH 13/62] Add x parameter and smart dimension handling to fxplot - Add `x` parameter to bar/stacked_bar/line/area for explicit x-axis control - Add CONFIG.Plotting.x_dim_priority for auto x-axis selection Default: ('time', 'duration', 'duration_pct', 'period', 'scenario', 'cluster') - X-axis determined first, facets resolved from remaining dimensions - Refactor _get_x_column -> _get_x_dim (takes dim list, more versatile) - Support scalar data (no dims) by using 'variable' as x-axis - Skip color='variable' when x='variable' to avoid double encoding - Fix _dataset_to_long_df to use dims (not just coords) as id_vars - Ensure px_kwargs properly overrides all defaults (color, facets, etc.) 
--- flixopt/dataset_plot_accessor.py | 60 +++++++++++++++----------------- 1 file changed, 28 insertions(+), 32 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index ba8bbe6d8..cb9c7ac72 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -204,27 +204,26 @@ def bar( 'y': 'value', 'title': title, 'barmode': 'group', - **px_kwargs, } - # Only color by variable if it's not already on x-axis - if x_col != 'variable': + # Only color by variable if it's not already on x-axis (and user didn't override) + if x_col != 'variable' and 'color' not in px_kwargs: fig_kwargs['color'] = 'variable' fig_kwargs['color_discrete_map'] = color_map if xlabel: - fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), x_col: xlabel} + fig_kwargs['labels'] = {x_col: xlabel} if ylabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - if actual_facet_col: + if actual_facet_col and 'facet_col' not in px_kwargs: fig_kwargs['facet_col'] = actual_facet_col if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row: + if actual_facet_row and 'facet_row' not in px_kwargs: fig_kwargs['facet_row'] = actual_facet_row - if actual_anim: + if actual_anim and 'animation_frame' not in px_kwargs: fig_kwargs['animation_frame'] = actual_anim - return px.bar(**fig_kwargs) + return px.bar(**{**fig_kwargs, **px_kwargs}) def stacked_bar( self, @@ -280,27 +279,26 @@ def stacked_bar( 'x': x_col, 'y': 'value', 'title': title, - **px_kwargs, } - # Only color by variable if it's not already on x-axis - if x_col != 'variable': + # Only color by variable if it's not already on x-axis (and user didn't override) + if x_col != 'variable' and 'color' not in px_kwargs: fig_kwargs['color'] = 'variable' fig_kwargs['color_discrete_map'] = color_map if xlabel: - fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), x_col: xlabel} + 
fig_kwargs['labels'] = {x_col: xlabel} if ylabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - if actual_facet_col: + if actual_facet_col and 'facet_col' not in px_kwargs: fig_kwargs['facet_col'] = actual_facet_col if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row: + if actual_facet_row and 'facet_row' not in px_kwargs: fig_kwargs['facet_row'] = actual_facet_row - if actual_anim: + if actual_anim and 'animation_frame' not in px_kwargs: fig_kwargs['animation_frame'] = actual_anim - fig = px.bar(**fig_kwargs) + fig = px.bar(**{**fig_kwargs, **px_kwargs}) fig.update_layout(barmode='relative', bargap=0, bargroupgap=0) fig.update_traces(marker_line_width=0) return fig @@ -362,27 +360,26 @@ def line( 'y': 'value', 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, - **px_kwargs, } - # Only color by variable if it's not already on x-axis - if x_col != 'variable': + # Only color by variable if it's not already on x-axis (and user didn't override) + if x_col != 'variable' and 'color' not in px_kwargs: fig_kwargs['color'] = 'variable' fig_kwargs['color_discrete_map'] = color_map if xlabel: - fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), x_col: xlabel} + fig_kwargs['labels'] = {x_col: xlabel} if ylabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - if actual_facet_col: + if actual_facet_col and 'facet_col' not in px_kwargs: fig_kwargs['facet_col'] = actual_facet_col if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row: + if actual_facet_row and 'facet_row' not in px_kwargs: fig_kwargs['facet_row'] = actual_facet_row - if actual_anim: + if actual_anim and 'animation_frame' not in px_kwargs: fig_kwargs['animation_frame'] = actual_anim - return px.line(**fig_kwargs) + return 
px.line(**{**fig_kwargs, **px_kwargs}) def area( self, @@ -438,27 +435,26 @@ def area( 'y': 'value', 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, - **px_kwargs, } - # Only color by variable if it's not already on x-axis - if x_col != 'variable': + # Only color by variable if it's not already on x-axis (and user didn't override) + if x_col != 'variable' and 'color' not in px_kwargs: fig_kwargs['color'] = 'variable' fig_kwargs['color_discrete_map'] = color_map if xlabel: - fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), x_col: xlabel} + fig_kwargs['labels'] = {x_col: xlabel} if ylabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - if actual_facet_col: + if actual_facet_col and 'facet_col' not in px_kwargs: fig_kwargs['facet_col'] = actual_facet_col if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row: + if actual_facet_row and 'facet_row' not in px_kwargs: fig_kwargs['facet_row'] = actual_facet_row - if actual_anim: + if actual_anim and 'animation_frame' not in px_kwargs: fig_kwargs['animation_frame'] = actual_anim - return px.area(**fig_kwargs) + return px.area(**{**fig_kwargs, **px_kwargs}) def heatmap( self, From 9d40c82d1299dd33895a268f6a46e8c328932387 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Fri, 2 Jan 2026 14:41:21 +0100 Subject: [PATCH 14/62] Improve documentation --- .../recipes/plotting-custom-data.md | 129 +++--------------- flixopt/dataset_plot_accessor.py | 44 +----- mkdocs.yml | 1 + 3 files changed, 22 insertions(+), 152 deletions(-) diff --git a/docs/user-guide/recipes/plotting-custom-data.md b/docs/user-guide/recipes/plotting-custom-data.md index 3c539e6ce..8c19931f3 100644 --- a/docs/user-guide/recipes/plotting-custom-data.md +++ b/docs/user-guide/recipes/plotting-custom-data.md @@ -1,125 +1,30 @@ # Plotting Custom Data -The plot accessor 
(`flow_system.statistics.plot`) is designed for visualizing optimization results using element labels. If you want to create faceted plots with your own custom data (not from a FlowSystem), you can use Plotly Express directly with xarray data. +While the plot accessor (`flow_system.statistics.plot`) is designed for optimization results, you often need to plot custom xarray data. The `.fxplot` accessor provides the same convenience for any `xr.Dataset` or `xr.DataArray`. -## Faceted Plots with Custom xarray Data - -The key is converting your xarray Dataset to a long-form DataFrame that Plotly Express expects: +## Quick Example ```python +import flixopt as fx import xarray as xr -import pandas as pd -import plotly.express as px -# Your custom xarray Dataset -my_data = xr.Dataset({ - 'Solar': (['time', 'scenario'], solar_values), - 'Wind': (['time', 'scenario'], wind_values), - 'Demand': (['time', 'scenario'], demand_values), -}, coords={ - 'time': timestamps, - 'scenario': ['Base', 'High RE', 'Low Demand'] +ds = xr.Dataset({ + 'Solar': (['time'], solar_values), + 'Wind': (['time'], wind_values), }) -# Convert to long-form DataFrame for Plotly Express -df = ( - my_data - .to_dataframe() - .reset_index() - .melt( - id_vars=['time', 'scenario'], # Keep as columns - var_name='variable', - value_name='value' - ) -) - -# Faceted stacked bar chart -fig = px.bar( - df, - x='time', - y='value', - color='variable', - facet_col='scenario', - barmode='relative', - title='Energy Balance by Scenario' -) -fig.show() - -# Faceted line plot -fig = px.line( - df, - x='time', - y='value', - color='variable', - facet_col='scenario' -) -fig.show() - -# Faceted area chart -fig = px.area( - df, - x='time', - y='value', - color='variable', - facet_col='scenario' -) -fig.show() -``` - -## Common Plotly Express Faceting Options - -| Parameter | Description | -|-----------|-------------| -| `facet_col` | Dimension for column subplots | -| `facet_row` | Dimension for row subplots | -| 
`animation_frame` | Dimension for animation slider | -| `facet_col_wrap` | Number of columns before wrapping | - -```python -# Row and column facets -fig = px.line(df, x='time', y='value', color='variable', - facet_col='scenario', facet_row='region') - -# Animation over time periods -fig = px.bar(df, x='variable', y='value', color='variable', - animation_frame='period', barmode='group') - -# Wrap columns -fig = px.line(df, x='time', y='value', color='variable', - facet_col='scenario', facet_col_wrap=2) +# Plot directly - no conversion needed! +ds.fxplot.line(title='Energy Generation') +ds.fxplot.stacked_bar(title='Stacked Generation') ``` -## Heatmaps with Custom Data - -For heatmaps, you can pass 2D arrays directly to `px.imshow`: - -```python -import plotly.express as px - -# 2D data (e.g., days × hours) -heatmap_data = my_data['Solar'].sel(scenario='Base').values.reshape(365, 24) +## Full Documentation -fig = px.imshow( - heatmap_data, - labels={'x': 'Hour', 'y': 'Day', 'color': 'Power [kW]'}, - aspect='auto', - color_continuous_scale='portland' -) -fig.show() - -# Faceted heatmaps using subplots -from plotly.subplots import make_subplots -import plotly.graph_objects as go - -scenarios = ['Base', 'High RE'] -fig = make_subplots(rows=1, cols=len(scenarios), subplot_titles=scenarios) - -for i, scenario in enumerate(scenarios, 1): - data = my_data['Solar'].sel(scenario=scenario).values.reshape(365, 24) - fig.add_trace(go.Heatmap(z=data, colorscale='portland'), row=1, col=i) - -fig.update_layout(title='Solar Output by Scenario') -fig.show() -``` +For comprehensive documentation with interactive examples, see the [Custom Data Plotting](../../notebooks/fxplot_accessor_demo.ipynb) notebook which covers: -This approach gives you full control over your visualizations while leveraging Plotly's powerful faceting capabilities. 
+- All available plot methods (line, bar, stacked_bar, area, scatter, heatmap, pie) +- Automatic x-axis selection and faceting +- Custom colors and axis labels +- Duration curves with `.fxstats.to_duration_curve()` +- Configuration options +- Combining with xarray operations diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index cb9c7ac72..70f7990c1 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -1,16 +1,4 @@ -"""Dataset plot accessor for xarray Datasets. - -Provides convenient plotting methods for any xr.Dataset via the .fxplot accessor. -This is globally registered and available on all xr.Dataset objects when flixopt is imported. - -Example: - >>> import flixopt - >>> import xarray as xr - >>> ds = xr.Dataset({'temp': (['time', 'location'], data)}) - >>> ds.fxplot.line() # Line plot of all variables - >>> ds.fxplot.stacked_bar() # Stacked bar chart - >>> ds.fxplot.heatmap('temp') # Heatmap of specific variable -""" +"""Xarray accessors for plotting (``.fxplot``) and statistics (``.fxstats``).""" from __future__ import annotations @@ -26,15 +14,7 @@ def _get_x_dim(dims: list[str], x: str | Literal['auto'] | None = 'auto') -> str: - """Determine the x-axis dimension from available dimensions. - - Args: - dims: List of available dimension names. - x: Explicit dimension name, 'auto' to use priority list, or None. - - Returns: - Dimension name to use for x-axis. Returns 'variable' for scalar data. - """ + """Select x-axis dim from priority list, or 'variable' for scalar data.""" if x and x != 'auto': return x @@ -54,23 +34,7 @@ def _resolve_auto_facets( animation_frame: str | Literal['auto'] | None = None, exclude_dims: set[str] | None = None, ) -> tuple[str | None, str | None, str | None]: - """Resolve 'auto' facet/animation dimensions based on available data dimensions. 
- - When 'auto' is specified, extra dimensions are assigned to slots based on: - - CONFIG.Plotting.extra_dim_priority: Order of dimensions (default: cluster -> period -> scenario) - - CONFIG.Plotting.dim_slot_priority: Order of slots (default: facet_col -> facet_row -> animation_frame) - - Args: - ds: Dataset to check for available dimensions. - facet_col: Dimension name, 'auto', or None. - facet_row: Dimension name, 'auto', or None. - animation_frame: Dimension name, 'auto', or None. - exclude_dims: Dimensions to exclude (e.g., x-axis dimension). - - Returns: - Tuple of (resolved_facet_col, resolved_facet_row, resolved_animation_frame). - Each is either a valid dimension name or None. - """ + """Assign 'auto' facet slots from available dims using CONFIG priority lists.""" # Get available extra dimensions with size > 1, excluding specified dims exclude = exclude_dims or set() available = {d for d in ds.dims if ds.sizes[d] > 1 and d not in exclude} @@ -105,7 +69,7 @@ def _resolve_auto_facets( def _dataset_to_long_df(ds: xr.Dataset, value_name: str = 'value', var_name: str = 'variable') -> pd.DataFrame: - """Convert xarray Dataset to long-form DataFrame for plotly express.""" + """Convert Dataset to long-form DataFrame for Plotly Express.""" if not ds.data_vars: return pd.DataFrame() if all(ds[var].ndim == 0 for var in ds.data_vars): diff --git a/mkdocs.yml b/mkdocs.yml index 493937983..ca94a6302 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -73,6 +73,7 @@ nav: - Rolling Horizon: notebooks/08b-rolling-horizon.ipynb - Results: - Plotting: notebooks/09-plotting-and-data-access.ipynb + - Custom Data Plotting: notebooks/fxplot_accessor_demo.ipynb - API Reference: api-reference/ From bd314e03ed1625e9d4f68bae9be5fa36665445bd Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Fri, 2 Jan 2026 15:06:57 +0100 Subject: [PATCH 15/62] Fix notebook in docs --- docs/notebooks/fxplot_accessor_demo.ipynb | 15 +++++++++++++-- 1 file changed, 13 
insertions(+), 2 deletions(-) diff --git a/docs/notebooks/fxplot_accessor_demo.ipynb b/docs/notebooks/fxplot_accessor_demo.ipynb index bcf065cc8..db8684d82 100644 --- a/docs/notebooks/fxplot_accessor_demo.ipynb +++ b/docs/notebooks/fxplot_accessor_demo.ipynb @@ -25,6 +25,17 @@ "fx.__version__" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import plotly.io as pio\n", + "\n", + "pio.renderers.default = 'notebook_connected'" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -53,7 +64,7 @@ " coords={'time': time},\n", ")\n", "\n", - "ds_simple" + "ds_simple.to_dataframe().head()" ] }, { @@ -175,7 +186,7 @@ " },\n", ")\n", "\n", - "ds_multi" + "ds_multi.to_dataframe().head()" ] }, { From 22702e0550b0d327338172d45323581b09b5c113 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 12:50:47 +0100 Subject: [PATCH 16/62] 1. heatmap kwarg merge order - Now uses **{**imshow_args, **imshow_kwargs} so user can override 2. scatter unused colors - Removed the unused parameter 3. to_duration_curve sorting - Changed [::-1] to np.flip(..., axis=time_axis) for correct multi-dimensional handling 4. 
DataArrayPlotAccessor.heatmap - Same kwarg merge fix --- flixopt/dataset_plot_accessor.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 70f7990c1..fc38f730b 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -473,7 +473,6 @@ def heatmap( 'img': da, 'color_continuous_scale': colors, 'title': title or variable, - **imshow_kwargs, } if actual_facet_col and actual_facet_col in da.dims: @@ -484,14 +483,13 @@ def heatmap( if actual_anim and actual_anim in da.dims: imshow_args['animation_frame'] = actual_anim - return px.imshow(**imshow_args) + return px.imshow(**{**imshow_args, **imshow_kwargs}) def scatter( self, x: str, y: str, *, - colors: ColorType | None = None, title: str = '', xlabel: str = '', ylabel: str = '', @@ -506,7 +504,6 @@ def scatter( Args: x: Variable name for x-axis. y: Variable name for y-axis. - colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. xlabel: X-axis label. ylabel: Y-axis label. 
@@ -679,8 +676,9 @@ def to_duration_curve(self, *, normalize: bool = True) -> xr.Dataset: sorted_ds = self._ds.copy() for var in sorted_ds.data_vars: da = sorted_ds[var] - # Sort along time axis (descending) - sorted_values = np.sort(da.values, axis=da.dims.index('time'))[::-1] + time_axis = da.dims.index('time') + # Sort along time axis (descending) - use flip for correct axis + sorted_values = np.flip(np.sort(da.values, axis=time_axis), axis=time_axis) sorted_ds[var] = (da.dims, sorted_values) # Replace time coordinate with duration @@ -880,7 +878,6 @@ def heatmap( 'img': da, 'color_continuous_scale': colors, 'title': title or (da.name if da.name else ''), - **imshow_kwargs, } if actual_facet_col and actual_facet_col in da.dims: @@ -891,4 +888,4 @@ def heatmap( if actual_anim and actual_anim in da.dims: imshow_args['animation_frame'] = actual_anim - return px.imshow(**imshow_args) + return px.imshow(**{**imshow_args, **imshow_kwargs}) From ed33706bd57e9926f22f8639a842b80589e107d2 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 13:59:07 +0100 Subject: [PATCH 17/62] Improve docstrings --- flixopt/config.py | 10 +++------- flixopt/statistics_accessor.py | 4 ++-- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/flixopt/config.py b/flixopt/config.py index 87d16615a..454f8ad3e 100644 --- a/flixopt/config.py +++ b/flixopt/config.py @@ -563,11 +563,8 @@ class Plotting: default_sequential_colorscale: Default colorscale for heatmaps and continuous data. default_qualitative_colorscale: Default colormap for categorical plots (bar/line/area charts). extra_dim_priority: Order of extra dimensions when auto-assigning to slots. - Default: ('cluster', 'period', 'scenario'). dim_slot_priority: Order of slots to fill with extra dimensions. - Default: ('facet_col', 'facet_row', 'animation_frame'). x_dim_priority: Order of dimensions to prefer for x-axis when 'auto'. - Default: ('time', 'duration', 'duration_pct'). 
Examples: ```python @@ -576,10 +573,9 @@ class Plotting: CONFIG.Plotting.default_sequential_colorscale = 'plasma' CONFIG.Plotting.default_qualitative_colorscale = 'Dark24' - # Customize dimension handling - # With 2 extra dims (period, scenario): period → facet_col, scenario → facet_row - CONFIG.Plotting.extra_dim_priority = ('cluster', 'period', 'scenario') - CONFIG.Plotting.dim_slot_priority = ('facet_col', 'facet_row', 'animation_frame') + # Customize dimension handling for faceting + CONFIG.Plotting.extra_dim_priority = ('scenario', 'period', 'cluster') + CONFIG.Plotting.dim_slot_priority = ('facet_row', 'facet_col', 'animation_frame') ``` """ diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index e01880f76..1cbcbcd7d 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -189,8 +189,8 @@ def _resolve_auto_facets( """Resolve 'auto' facet/animation dimensions based on available data dimensions. When 'auto' is specified, extra dimensions are assigned to slots based on: - - CONFIG.Plotting.extra_dim_priority: Order of dimensions (default: cluster → period → scenario) - - CONFIG.Plotting.dim_slot_priority: Order of slots (default: facet_col → facet_row → animation_frame) + - CONFIG.Plotting.extra_dim_priority: Order of dimensions to assign. + - CONFIG.Plotting.dim_slot_priority: Order of slots to fill. Args: ds: Dataset to check for available dimensions. 
From 0c7965f357caa198ec2efff385b9b6e7a8344821 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 19:25:01 +0100 Subject: [PATCH 18/62] Update notebooks to not do file operations --- docs/notebooks/08a-aggregation.ipynb | 24 ++- docs/notebooks/08b-rolling-horizon.ipynb | 16 +- docs/notebooks/08c-clustering.ipynb | 15 +- .../08c2-clustering-storage-modes.ipynb | 15 +- .../08d-clustering-multiperiod.ipynb | 13 +- docs/notebooks/08e-clustering-internals.ipynb | 13 +- .../09-plotting-and-data-access.ipynb | 175 ++++++++---------- 7 files changed, 108 insertions(+), 163 deletions(-) diff --git a/docs/notebooks/08a-aggregation.ipynb b/docs/notebooks/08a-aggregation.ipynb index 6d0260539..410cd1715 100644 --- a/docs/notebooks/08a-aggregation.ipynb +++ b/docs/notebooks/08a-aggregation.ipynb @@ -59,21 +59,13 @@ "metadata": {}, "outputs": [], "source": [ - "from pathlib import Path\n", + "from data.generate_example_systems import create_district_heating_system\n", "\n", - "# Generate example data if not present (for local development)\n", - "data_file = Path('data/district_heating_system.nc4')\n", - "if not data_file.exists():\n", - " from data.generate_example_systems import create_district_heating_system\n", - "\n", - " fs = create_district_heating_system()\n", - " fs.to_netcdf(data_file)\n", - "\n", - "# Load the district heating system (real data from Zeitreihen2020.csv)\n", - "flow_system = fx.FlowSystem.from_netcdf(data_file)\n", + "flow_system = create_district_heating_system()\n", + "flow_system.connect_and_transform()\n", "\n", "timesteps = flow_system.timesteps\n", - "print(f'Loaded FlowSystem: {len(timesteps)} timesteps ({len(timesteps) / 96:.0f} days at 15-min resolution)')\n", + "print(f'FlowSystem: {len(timesteps)} timesteps ({len(timesteps) / 96:.0f} days at 15-min resolution)')\n", "print(f'Components: {list(flow_system.components.keys())}')" ] }, @@ -397,7 +389,13 @@ ] } ], - "metadata": {}, + "metadata": 
{ + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, "nbformat": 4, "nbformat_minor": 5 } diff --git a/docs/notebooks/08b-rolling-horizon.ipynb b/docs/notebooks/08b-rolling-horizon.ipynb index e43da8f2c..5032588fe 100644 --- a/docs/notebooks/08b-rolling-horizon.ipynb +++ b/docs/notebooks/08b-rolling-horizon.ipynb @@ -63,21 +63,13 @@ "metadata": {}, "outputs": [], "source": [ - "from pathlib import Path\n", + "from data.generate_example_systems import create_operational_system\n", "\n", - "# Generate example data if not present (for local development)\n", - "data_file = Path('data/operational_system.nc4')\n", - "if not data_file.exists():\n", - " from data.generate_example_systems import create_operational_system\n", - "\n", - " fs = create_operational_system()\n", - " fs.to_netcdf(data_file)\n", - "\n", - "# Load the operational system (real data from Zeitreihen2020.csv, two weeks)\n", - "flow_system = fx.FlowSystem.from_netcdf(data_file)\n", + "flow_system = create_operational_system()\n", + "flow_system.connect_and_transform()\n", "\n", "timesteps = flow_system.timesteps\n", - "print(f'Loaded FlowSystem: {len(timesteps)} timesteps ({len(timesteps) / 96:.0f} days at 15-min resolution)')\n", + "print(f'FlowSystem: {len(timesteps)} timesteps ({len(timesteps) / 96:.0f} days at 15-min resolution)')\n", "print(f'Components: {list(flow_system.components.keys())}')" ] }, diff --git a/docs/notebooks/08c-clustering.ipynb b/docs/notebooks/08c-clustering.ipynb index cf5b53b53..acd30ea94 100644 --- a/docs/notebooks/08c-clustering.ipynb +++ b/docs/notebooks/08c-clustering.ipynb @@ -27,7 +27,6 @@ "outputs": [], "source": [ "import timeit\n", - "from pathlib import Path\n", "\n", "import numpy as np\n", "import pandas as pd\n", @@ -56,19 +55,13 @@ "metadata": {}, "outputs": [], "source": [ - "# Generate example data if not present\n", - "data_file = Path('data/district_heating_system.nc4')\n", - "if not 
data_file.exists():\n", - " from data.generate_example_systems import create_district_heating_system\n", + "from data.generate_example_systems import create_district_heating_system\n", "\n", - " fs = create_district_heating_system()\n", - " fs.to_netcdf(data_file)\n", - "\n", - "# Load the district heating system\n", - "flow_system = fx.FlowSystem.from_netcdf(data_file)\n", + "flow_system = create_district_heating_system()\n", + "flow_system.connect_and_transform()\n", "\n", "timesteps = flow_system.timesteps\n", - "print(f'Loaded FlowSystem: {len(timesteps)} timesteps ({len(timesteps) / 96:.0f} days at 15-min resolution)')\n", + "print(f'FlowSystem: {len(timesteps)} timesteps ({len(timesteps) / 96:.0f} days at 15-min resolution)')\n", "print(f'Components: {list(flow_system.components.keys())}')" ] }, diff --git a/docs/notebooks/08c2-clustering-storage-modes.ipynb b/docs/notebooks/08c2-clustering-storage-modes.ipynb index 163cf8729..c99d25dbd 100644 --- a/docs/notebooks/08c2-clustering-storage-modes.ipynb +++ b/docs/notebooks/08c2-clustering-storage-modes.ipynb @@ -27,7 +27,6 @@ "outputs": [], "source": [ "import timeit\n", - "from pathlib import Path\n", "\n", "import numpy as np\n", "import pandas as pd\n", @@ -61,19 +60,13 @@ "metadata": {}, "outputs": [], "source": [ - "# Generate example data if not present\n", - "data_file = Path('data/seasonal_storage_system.nc4')\n", - "if not data_file.exists():\n", - " from data.generate_example_systems import create_seasonal_storage_system\n", + "from data.generate_example_systems import create_seasonal_storage_system\n", "\n", - " fs = create_seasonal_storage_system()\n", - " fs.to_netcdf(data_file)\n", - "\n", - "# Load the seasonal storage system\n", - "flow_system = fx.FlowSystem.from_netcdf(data_file)\n", + "flow_system = create_seasonal_storage_system()\n", + "flow_system.connect_and_transform()\n", "\n", "timesteps = flow_system.timesteps\n", - "print(f'Loaded FlowSystem: {len(timesteps)} timesteps 
({len(timesteps) / 24:.0f} days)')\n", + "print(f'FlowSystem: {len(timesteps)} timesteps ({len(timesteps) / 24:.0f} days)')\n", "print(f'Components: {list(flow_system.components.keys())}')" ] }, diff --git a/docs/notebooks/08d-clustering-multiperiod.ipynb b/docs/notebooks/08d-clustering-multiperiod.ipynb index 84ff468ea..31c47ff38 100644 --- a/docs/notebooks/08d-clustering-multiperiod.ipynb +++ b/docs/notebooks/08d-clustering-multiperiod.ipynb @@ -28,7 +28,6 @@ "outputs": [], "source": [ "import timeit\n", - "from pathlib import Path\n", "\n", "import numpy as np\n", "import pandas as pd\n", @@ -62,16 +61,10 @@ "metadata": {}, "outputs": [], "source": [ - "# Generate example data if not present\n", - "data_file = Path('data/multiperiod_system.nc4')\n", - "if not data_file.exists():\n", - " from data.generate_example_systems import create_multiperiod_system\n", + "from data.generate_example_systems import create_multiperiod_system\n", "\n", - " fs = create_multiperiod_system()\n", - " fs.to_netcdf(data_file)\n", - "\n", - "# Load the multi-period system\n", - "flow_system = fx.FlowSystem.from_netcdf(data_file)\n", + "flow_system = create_multiperiod_system()\n", + "flow_system.connect_and_transform()\n", "\n", "print(f'Timesteps: {len(flow_system.timesteps)} ({len(flow_system.timesteps) // 24} days)')\n", "print(f'Periods: {list(flow_system.periods.values)}')\n", diff --git a/docs/notebooks/08e-clustering-internals.ipynb b/docs/notebooks/08e-clustering-internals.ipynb index a0ac80ca7..066ec749c 100644 --- a/docs/notebooks/08e-clustering-internals.ipynb +++ b/docs/notebooks/08e-clustering-internals.ipynb @@ -26,21 +26,14 @@ "metadata": {}, "outputs": [], "source": [ - "from pathlib import Path\n", + "from data.generate_example_systems import create_district_heating_system\n", "\n", "import flixopt as fx\n", "\n", "fx.CONFIG.notebook()\n", "\n", - "# Load the district heating system\n", - "data_file = Path('data/district_heating_system.nc4')\n", - "if not 
data_file.exists():\n", - " from data.generate_example_systems import create_district_heating_system\n", - "\n", - " fs = create_district_heating_system()\n", - " fs.to_netcdf(data_file)\n", - "\n", - "flow_system = fx.FlowSystem.from_netcdf(data_file)" + "flow_system = create_district_heating_system()\n", + "flow_system.connect_and_transform()" ] }, { diff --git a/docs/notebooks/09-plotting-and-data-access.ipynb b/docs/notebooks/09-plotting-and-data-access.ipynb index a4803adf4..39fa788da 100644 --- a/docs/notebooks/09-plotting-and-data-access.ipynb +++ b/docs/notebooks/09-plotting-and-data-access.ipynb @@ -11,7 +11,6 @@ "\n", "This notebook covers:\n", "\n", - "- Loading saved FlowSystems from NetCDF files\n", "- Accessing data (flow rates, sizes, effects, charge states)\n", "- Time series plots (balance, flows, storage)\n", "- Aggregated plots (sizes, effects, duration curves)\n", @@ -36,7 +35,7 @@ "metadata": {}, "outputs": [], "source": [ - "from pathlib import Path\n", + "from data.generate_example_systems import create_complex_system, create_multiperiod_system, create_simple_system\n", "\n", "import flixopt as fx\n", "\n", @@ -48,9 +47,9 @@ "id": "3", "metadata": {}, "source": [ - "## Generate Example Data\n", + "## Generate Example Systems\n", "\n", - "First, run the script that generates three example FlowSystems with solutions:" + "First, create three example FlowSystems with solutions:" ] }, { @@ -60,35 +59,19 @@ "metadata": {}, "outputs": [], "source": [ - "# Run the generation script (only needed once, or to regenerate)\n", - "!python data/generate_example_systems.py > /dev/null 2>&1" - ] - }, - { - "cell_type": "markdown", - "id": "5", - "metadata": {}, - "source": [ - "## 1. 
Loading Saved FlowSystems\n", + "# Create and optimize the example systems\n", + "solver = fx.solvers.HighsSolver(mip_gap=0.01, log_to_console=False)\n", "\n", - "FlowSystems can be saved to and loaded from NetCDF files, preserving the full structure and solution:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6", - "metadata": {}, - "outputs": [], - "source": [ - "DATA_DIR = Path('data')\n", + "simple = create_simple_system()\n", + "simple.optimize(solver)\n", + "\n", + "complex_sys = create_complex_system()\n", + "complex_sys.optimize(solver)\n", "\n", - "# Load the three example systems\n", - "simple = fx.FlowSystem.from_netcdf(DATA_DIR / 'simple_system.nc4')\n", - "complex_sys = fx.FlowSystem.from_netcdf(DATA_DIR / 'complex_system.nc4')\n", - "multiperiod = fx.FlowSystem.from_netcdf(DATA_DIR / 'multiperiod_system.nc4')\n", + "multiperiod = create_multiperiod_system()\n", + "multiperiod.optimize(solver)\n", "\n", - "print('Loaded systems:')\n", + "print('Created systems:')\n", "print(f' simple: {len(simple.components)} components, {len(simple.buses)} buses')\n", "print(f' complex_sys: {len(complex_sys.components)} components, {len(complex_sys.buses)} buses')\n", "print(f' multiperiod: {len(multiperiod.components)} components, dims={dict(multiperiod.solution.sizes)}')" @@ -96,7 +79,7 @@ }, { "cell_type": "markdown", - "id": "7", + "id": "5", "metadata": {}, "source": [ "## 2. 
Quick Overview: Balance Plot\n", @@ -107,7 +90,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "6", "metadata": {}, "outputs": [], "source": [ @@ -117,7 +100,7 @@ }, { "cell_type": "markdown", - "id": "9", + "id": "7", "metadata": {}, "source": [ "### Accessing Plot Data\n", @@ -128,7 +111,7 @@ { "cell_type": "code", "execution_count": null, - "id": "10", + "id": "8", "metadata": {}, "outputs": [], "source": [ @@ -142,7 +125,7 @@ }, { "cell_type": "markdown", - "id": "11", + "id": "9", "metadata": {}, "source": [ "### Energy Totals\n", @@ -153,7 +136,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -167,7 +150,7 @@ }, { "cell_type": "markdown", - "id": "13", + "id": "11", "metadata": {}, "source": [ "## 3. Time Series Plots" @@ -175,7 +158,7 @@ }, { "cell_type": "markdown", - "id": "14", + "id": "12", "metadata": {}, "source": [ "### 3.1 Balance Plot\n", @@ -186,7 +169,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -196,7 +179,7 @@ }, { "cell_type": "markdown", - "id": "16", + "id": "14", "metadata": {}, "source": [ "### 3.2 Carrier Balance\n", @@ -207,7 +190,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -217,7 +200,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -226,7 +209,7 @@ }, { "cell_type": "markdown", - "id": "19", + "id": "17", "metadata": {}, "source": [ "### 3.3 Flow Rates\n", @@ -237,7 +220,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -248,7 +231,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -258,7 +241,7 @@ }, { "cell_type": "markdown", - "id": "22", + 
"id": "20", "metadata": {}, "source": [ "### 3.4 Storage Plot\n", @@ -269,7 +252,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -278,7 +261,7 @@ }, { "cell_type": "markdown", - "id": "24", + "id": "22", "metadata": {}, "source": [ "### 3.5 Charge States Plot\n", @@ -289,7 +272,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -298,7 +281,7 @@ }, { "cell_type": "markdown", - "id": "26", + "id": "24", "metadata": {}, "source": [ "## 4. Aggregated Plots" @@ -306,7 +289,7 @@ }, { "cell_type": "markdown", - "id": "27", + "id": "25", "metadata": {}, "source": [ "### 4.1 Sizes Plot\n", @@ -317,7 +300,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28", + "id": "26", "metadata": {}, "outputs": [], "source": [ @@ -326,7 +309,7 @@ }, { "cell_type": "markdown", - "id": "29", + "id": "27", "metadata": {}, "source": [ "### 4.2 Effects Plot\n", @@ -337,7 +320,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -347,7 +330,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "29", "metadata": {}, "outputs": [], "source": [ @@ -358,7 +341,7 @@ { "cell_type": "code", "execution_count": null, - "id": "32", + "id": "30", "metadata": {}, "outputs": [], "source": [ @@ -367,7 +350,7 @@ }, { "cell_type": "markdown", - "id": "33", + "id": "31", "metadata": {}, "source": [ "### 4.3 Duration Curve\n", @@ -378,7 +361,7 @@ { "cell_type": "code", "execution_count": null, - "id": "34", + "id": "32", "metadata": {}, "outputs": [], "source": [ @@ -388,7 +371,7 @@ { "cell_type": "code", "execution_count": null, - "id": "35", + "id": "33", "metadata": {}, "outputs": [], "source": [ @@ -398,7 +381,7 @@ }, { "cell_type": "markdown", - "id": "36", + "id": "34", "metadata": {}, "source": [ "## 5. 
Heatmaps\n", @@ -409,7 +392,7 @@ { "cell_type": "code", "execution_count": null, - "id": "37", + "id": "35", "metadata": {}, "outputs": [], "source": [ @@ -420,7 +403,7 @@ { "cell_type": "code", "execution_count": null, - "id": "38", + "id": "36", "metadata": {}, "outputs": [], "source": [ @@ -431,7 +414,7 @@ { "cell_type": "code", "execution_count": null, - "id": "39", + "id": "37", "metadata": {}, "outputs": [], "source": [ @@ -441,7 +424,7 @@ }, { "cell_type": "markdown", - "id": "40", + "id": "38", "metadata": {}, "source": [ "## 6. Sankey Diagrams\n", @@ -451,7 +434,7 @@ }, { "cell_type": "markdown", - "id": "41", + "id": "39", "metadata": {}, "source": [ "### 6.1 Flow Sankey\n", @@ -462,7 +445,7 @@ { "cell_type": "code", "execution_count": null, - "id": "42", + "id": "40", "metadata": {}, "outputs": [], "source": [ @@ -472,7 +455,7 @@ { "cell_type": "code", "execution_count": null, - "id": "43", + "id": "41", "metadata": {}, "outputs": [], "source": [ @@ -482,7 +465,7 @@ }, { "cell_type": "markdown", - "id": "44", + "id": "42", "metadata": {}, "source": [ "### 6.2 Sizes Sankey\n", @@ -493,7 +476,7 @@ { "cell_type": "code", "execution_count": null, - "id": "45", + "id": "43", "metadata": {}, "outputs": [], "source": [ @@ -502,7 +485,7 @@ }, { "cell_type": "markdown", - "id": "46", + "id": "44", "metadata": {}, "source": [ "### 6.3 Peak Flow Sankey\n", @@ -513,7 +496,7 @@ { "cell_type": "code", "execution_count": null, - "id": "47", + "id": "45", "metadata": {}, "outputs": [], "source": [ @@ -522,7 +505,7 @@ }, { "cell_type": "markdown", - "id": "48", + "id": "46", "metadata": {}, "source": [ "### 6.4 Effects Sankey\n", @@ -533,7 +516,7 @@ { "cell_type": "code", "execution_count": null, - "id": "49", + "id": "47", "metadata": {}, "outputs": [], "source": [ @@ -543,7 +526,7 @@ { "cell_type": "code", "execution_count": null, - "id": "50", + "id": "48", "metadata": {}, "outputs": [], "source": [ @@ -553,7 +536,7 @@ }, { "cell_type": "markdown", - "id": "51", + 
"id": "49", "metadata": {}, "source": [ "### 6.5 Filtering with `select`\n", @@ -564,7 +547,7 @@ { "cell_type": "code", "execution_count": null, - "id": "52", + "id": "50", "metadata": {}, "outputs": [], "source": [ @@ -574,7 +557,7 @@ }, { "cell_type": "markdown", - "id": "53", + "id": "51", "metadata": {}, "source": [ "## 7. Topology Visualization\n", @@ -584,7 +567,7 @@ }, { "cell_type": "markdown", - "id": "54", + "id": "52", "metadata": {}, "source": [ "### 7.1 Topology Plot\n", @@ -595,7 +578,7 @@ { "cell_type": "code", "execution_count": null, - "id": "55", + "id": "53", "metadata": {}, "outputs": [], "source": [ @@ -605,7 +588,7 @@ { "cell_type": "code", "execution_count": null, - "id": "56", + "id": "54", "metadata": {}, "outputs": [], "source": [ @@ -614,7 +597,7 @@ }, { "cell_type": "markdown", - "id": "57", + "id": "55", "metadata": {}, "source": [ "### 7.2 Topology Info\n", @@ -625,7 +608,7 @@ { "cell_type": "code", "execution_count": null, - "id": "58", + "id": "56", "metadata": {}, "outputs": [], "source": [ @@ -642,7 +625,7 @@ }, { "cell_type": "markdown", - "id": "59", + "id": "57", "metadata": {}, "source": [ "## 8. Multi-Period/Scenario Data\n", @@ -653,7 +636,7 @@ { "cell_type": "code", "execution_count": null, - "id": "60", + "id": "58", "metadata": {}, "outputs": [], "source": [ @@ -666,7 +649,7 @@ { "cell_type": "code", "execution_count": null, - "id": "61", + "id": "59", "metadata": {}, "outputs": [], "source": [ @@ -677,7 +660,7 @@ { "cell_type": "code", "execution_count": null, - "id": "62", + "id": "60", "metadata": {}, "outputs": [], "source": [ @@ -688,7 +671,7 @@ { "cell_type": "code", "execution_count": null, - "id": "63", + "id": "61", "metadata": {}, "outputs": [], "source": [ @@ -698,7 +681,7 @@ }, { "cell_type": "markdown", - "id": "64", + "id": "62", "metadata": {}, "source": [ "## 9. 
Color Customization\n", @@ -709,7 +692,7 @@ { "cell_type": "code", "execution_count": null, - "id": "65", + "id": "63", "metadata": {}, "outputs": [], "source": [ @@ -720,7 +703,7 @@ { "cell_type": "code", "execution_count": null, - "id": "66", + "id": "64", "metadata": {}, "outputs": [], "source": [ @@ -731,7 +714,7 @@ { "cell_type": "code", "execution_count": null, - "id": "67", + "id": "65", "metadata": {}, "outputs": [], "source": [ @@ -749,7 +732,7 @@ }, { "cell_type": "markdown", - "id": "68", + "id": "66", "metadata": {}, "source": [ "## 10. Exporting Results\n", @@ -760,7 +743,7 @@ { "cell_type": "code", "execution_count": null, - "id": "69", + "id": "67", "metadata": {}, "outputs": [], "source": [ @@ -775,7 +758,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70", + "id": "68", "metadata": {}, "outputs": [], "source": [ @@ -787,7 +770,7 @@ { "cell_type": "code", "execution_count": null, - "id": "71", + "id": "69", "metadata": {}, "outputs": [], "source": [ @@ -800,7 +783,7 @@ }, { "cell_type": "markdown", - "id": "72", + "id": "70", "metadata": {}, "source": [ "## Summary\n", From 3d257abfe684db398643dce8636321189821d79e Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 19:41:09 +0100 Subject: [PATCH 19/62] Add Comparison class --- flixopt/__init__.py | 2 + flixopt/comparison.py | 1038 +++++++++++++++++++++++++++++++++++++++++ flixopt/config.py | 2 +- 3 files changed, 1041 insertions(+), 1 deletion(-) create mode 100644 flixopt/comparison.py diff --git a/flixopt/__init__.py b/flixopt/__init__.py index 3cf219c38..64e08d7ac 100644 --- a/flixopt/__init__.py +++ b/flixopt/__init__.py @@ -18,6 +18,7 @@ # Register xr.Dataset.fxplot accessor (import triggers registration via decorator) from . 
class Comparison:
    """Compare multiple FlowSystems side-by-side.

    Combines solutions and statistics from multiple FlowSystems into unified
    xarray Datasets with a 'case' dimension. The existing plotting
    infrastructure automatically handles faceting by the 'case' dimension.

    Args:
        flow_systems: List of FlowSystems to compare (at least 2).
        names: Optional names for each case. If None, uses FlowSystem.name.

    Raises:
        ValueError: If fewer than 2 systems are given, if the number of names
            does not match the number of systems, or if names are not unique.

    Examples:
        ```python
        # Compare two systems (uses FlowSystem.name by default)
        comp = fx.Comparison([fs_base, fs_modified])

        # Or with custom names
        comp = fx.Comparison([fs_base, fs_modified], names=['baseline', 'modified'])

        # Side-by-side plots (auto-facets by 'case')
        comp.statistics.plot.balance('Heat')
        comp.statistics.flow_rates.fxplot.line()

        # Access combined data
        comp.solution  # xr.Dataset with 'case' dimension
        comp.statistics.flow_rates  # xr.Dataset with 'case' dimension

        # Compute differences relative to first case
        comp.diff()  # Returns xr.Dataset of differences
        comp.diff('baseline')  # Or specify reference by name
        ```
    """

    def __init__(self, flow_systems: list[FlowSystem], names: list[str] | None = None) -> None:
        if len(flow_systems) < 2:
            raise ValueError('Comparison requires at least 2 FlowSystems')

        self._systems = flow_systems
        # `is not None` (not truthiness): an explicitly passed empty list must
        # surface as a length mismatch below, not be silently replaced.
        self._names = names if names is not None else [fs.name for fs in flow_systems]

        if len(self._names) != len(self._systems):
            raise ValueError(
                f'Number of names ({len(self._names)}) must match number of FlowSystems ({len(self._systems)})'
            )

        if len(set(self._names)) != len(self._names):
            raise ValueError(f'Case names must be unique, got: {self._names}')

        # Lazily-built caches
        self._solution: xr.Dataset | None = None
        self._statistics: ComparisonStatistics | None = None

    @property
    def names(self) -> list[str]:
        """Case names for each FlowSystem."""
        return self._names

    @property
    def solution(self) -> xr.Dataset:
        """Combined solution Dataset with 'case' dimension.

        Raises:
            RuntimeError: If any FlowSystem has no solution yet.
        """
        if self._solution is None:
            datasets = []
            for fs, name in zip(self._systems, self._names, strict=True):
                if fs.solution is None:
                    raise RuntimeError(f"FlowSystem '{fs.name}' has no solution. Run optimize() first.")
                datasets.append(fs.solution.expand_dims(case=[name]))
            self._solution = xr.concat(datasets, dim='case')
        return self._solution

    @property
    def statistics(self) -> ComparisonStatistics:
        """Combined statistics accessor with 'case' dimension."""
        if self._statistics is None:
            self._statistics = ComparisonStatistics(self)
        return self._statistics

    def diff(self, reference: str | int = 0) -> xr.Dataset:
        """Compute differences relative to a reference case.

        Args:
            reference: Reference case name or index (default: 0, first case).

        Returns:
            Dataset with differences (each case minus reference), retaining
            the full 'case' dimension.

        Raises:
            ValueError: If the reference name is unknown or the index is out
                of range.
        """
        if isinstance(reference, str):
            if reference not in self._names:
                raise ValueError(f"Reference '{reference}' not found. Available: {self._names}")
            ref_idx = self._names.index(reference)
        else:
            if not -len(self._names) <= reference < len(self._names):
                raise ValueError(f'Reference index {reference} out of range for {len(self._names)} cases')
            ref_idx = reference

        # drop=True removes the leftover scalar 'case' coordinate from the
        # reference slice; otherwise the subtraction hits a coordinate
        # conflict on 'case' instead of broadcasting over all cases.
        ref_data = self.solution.isel(case=ref_idx, drop=True)
        return self.solution - ref_data
+ """ + + def __init__(self, comparison: Comparison) -> None: + self._comp = comparison + # Caches for dataset properties + self._flow_rates: xr.Dataset | None = None + self._flow_hours: xr.Dataset | None = None + self._flow_sizes: xr.Dataset | None = None + self._storage_sizes: xr.Dataset | None = None + self._sizes: xr.Dataset | None = None + self._charge_states: xr.Dataset | None = None + self._temporal_effects: xr.Dataset | None = None + self._periodic_effects: xr.Dataset | None = None + self._total_effects: xr.Dataset | None = None + # Caches for dict properties + self._carrier_colors: dict[str, str] | None = None + self._component_colors: dict[str, str] | None = None + self._bus_colors: dict[str, str] | None = None + self._carrier_units: dict[str, str] | None = None + self._effect_units: dict[str, str] | None = None + # Plot accessor + self._plot: ComparisonStatisticsPlot | None = None + + def _concat_property(self, prop_name: str) -> xr.Dataset: + """Concatenate a statistics property across all cases.""" + datasets = [] + for fs, name in zip(self._comp._systems, self._comp._names, strict=True): + ds = getattr(fs.statistics, prop_name) + datasets.append(ds.expand_dims(case=[name])) + return xr.concat(datasets, dim='case') + + def _merge_dict_property(self, prop_name: str) -> dict[str, str]: + """Merge a dict property from all cases (later cases override).""" + result: dict[str, str] = {} + for fs in self._comp._systems: + result.update(getattr(fs.statistics, prop_name)) + return result + + @property + def flow_rates(self) -> xr.Dataset: + """Combined flow rates with 'case' dimension.""" + if self._flow_rates is None: + self._flow_rates = self._concat_property('flow_rates') + return self._flow_rates + + @property + def flow_hours(self) -> xr.Dataset: + """Combined flow hours (energy) with 'case' dimension.""" + if self._flow_hours is None: + self._flow_hours = self._concat_property('flow_hours') + return self._flow_hours + + @property + def flow_sizes(self) -> 
xr.Dataset: + """Combined flow investment sizes with 'case' dimension.""" + if self._flow_sizes is None: + self._flow_sizes = self._concat_property('flow_sizes') + return self._flow_sizes + + @property + def storage_sizes(self) -> xr.Dataset: + """Combined storage capacity sizes with 'case' dimension.""" + if self._storage_sizes is None: + self._storage_sizes = self._concat_property('storage_sizes') + return self._storage_sizes + + @property + def sizes(self) -> xr.Dataset: + """Combined sizes (flow + storage) with 'case' dimension.""" + if self._sizes is None: + self._sizes = self._concat_property('sizes') + return self._sizes + + @property + def charge_states(self) -> xr.Dataset: + """Combined storage charge states with 'case' dimension.""" + if self._charge_states is None: + self._charge_states = self._concat_property('charge_states') + return self._charge_states + + @property + def temporal_effects(self) -> xr.Dataset: + """Combined temporal effects with 'case' dimension.""" + if self._temporal_effects is None: + self._temporal_effects = self._concat_property('temporal_effects') + return self._temporal_effects + + @property + def periodic_effects(self) -> xr.Dataset: + """Combined periodic effects with 'case' dimension.""" + if self._periodic_effects is None: + self._periodic_effects = self._concat_property('periodic_effects') + return self._periodic_effects + + @property + def total_effects(self) -> xr.Dataset: + """Combined total effects with 'case' dimension.""" + if self._total_effects is None: + self._total_effects = self._concat_property('total_effects') + return self._total_effects + + @property + def carrier_colors(self) -> dict[str, str]: + """Merged carrier colors from all cases.""" + if self._carrier_colors is None: + self._carrier_colors = self._merge_dict_property('carrier_colors') + return self._carrier_colors + + @property + def component_colors(self) -> dict[str, str]: + """Merged component colors from all cases.""" + if self._component_colors 
is None: + self._component_colors = self._merge_dict_property('component_colors') + return self._component_colors + + @property + def bus_colors(self) -> dict[str, str]: + """Merged bus colors from all cases.""" + if self._bus_colors is None: + self._bus_colors = self._merge_dict_property('bus_colors') + return self._bus_colors + + @property + def carrier_units(self) -> dict[str, str]: + """Merged carrier units from all cases.""" + if self._carrier_units is None: + self._carrier_units = self._merge_dict_property('carrier_units') + return self._carrier_units + + @property + def effect_units(self) -> dict[str, str]: + """Merged effect units from all cases.""" + if self._effect_units is None: + self._effect_units = self._merge_dict_property('effect_units') + return self._effect_units + + @property + def plot(self) -> ComparisonStatisticsPlot: + """Access plot methods for comparison statistics.""" + if self._plot is None: + self._plot = ComparisonStatisticsPlot(self) + return self._plot + + +class ComparisonStatisticsPlot: + """Plot accessor for comparison statistics. + + Mirrors StatisticsPlotAccessor methods, operating on combined data + from multiple FlowSystems. The 'case' dimension is automatically + used for faceting. 
+ """ + + def __init__(self, statistics: ComparisonStatistics) -> None: + self._stats = statistics + self._comp = statistics._comp + + def _get_first_stats_plot(self): + """Get StatisticsPlotAccessor from first FlowSystem for delegation.""" + return self._comp._systems[0].statistics.plot + + def balance( + self, + node: str, + *, + select: SelectType | None = None, + include: FilterType | None = None, + exclude: FilterType | None = None, + unit: Literal['flow_rate', 'flow_hours'] = 'flow_rate', + colors: ColorType = None, + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', + show: bool | None = None, + **plotly_kwargs: Any, + ) -> PlotResult: + """Plot node balance comparison across cases. + + See StatisticsPlotAccessor.balance for full documentation. + The 'case' dimension is automatically used for faceting. + """ + from .statistics_accessor import _apply_selection, _filter_by_pattern, _resolve_auto_facets + + # Get flow labels from first system (assumes same topology) + fs = self._comp._systems[0] + if node in fs.buses: + element = fs.buses[node] + elif node in fs.components: + element = fs.components[node] + else: + raise KeyError(f"'{node}' not found in buses or components") + + input_labels = [f.label_full for f in element.inputs] + output_labels = [f.label_full for f in element.outputs] + all_labels = input_labels + output_labels + filtered_labels = _filter_by_pattern(all_labels, include, exclude) + + if not filtered_labels: + import plotly.graph_objects as go + + return PlotResult(data=xr.Dataset(), figure=go.Figure()) + + # Get combined data + if unit == 'flow_rate': + ds = self._stats.flow_rates[[lbl for lbl in filtered_labels if lbl in self._stats.flow_rates]] + else: + ds = self._stats.flow_hours[[lbl for lbl in filtered_labels if lbl in self._stats.flow_hours]] + + # Negate inputs + for label in input_labels: + if label in ds: + ds[label] = 
-ds[label] + + ds = _apply_selection(ds, select) + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + ds, facet_col, facet_row, animation_frame + ) + + # Get unit label + unit_label = '' + if ds.data_vars: + first_var = next(iter(ds.data_vars)) + unit_label = ds[first_var].attrs.get('unit', '') + + fig = ds.fxplot.stacked_bar( + colors=colors, + title=f'{node} [{unit_label}]' if unit_label else node, + facet_col=actual_facet_col, + facet_row=actual_facet_row, + animation_frame=actual_anim, + **plotly_kwargs, + ) + + if show is None: + show = CONFIG.Plotting.default_show + if show: + fig.show() + + return PlotResult(data=ds, figure=fig) + + def carrier_balance( + self, + carrier: str, + *, + select: SelectType | None = None, + include: FilterType | None = None, + exclude: FilterType | None = None, + unit: Literal['flow_rate', 'flow_hours'] = 'flow_rate', + colors: ColorType = None, + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', + show: bool | None = None, + **plotly_kwargs: Any, + ) -> PlotResult: + """Plot carrier balance comparison across cases. + + See StatisticsPlotAccessor.carrier_balance for full documentation. 
+ """ + from .statistics_accessor import _apply_selection, _filter_by_pattern, _resolve_auto_facets + + carrier = carrier.lower() + fs = self._comp._systems[0] + + carrier_buses = [bus for bus in fs.buses.values() if bus.carrier == carrier] + if not carrier_buses: + raise KeyError(f"No buses found with carrier '{carrier}'") + + input_labels: list[str] = [] + output_labels: list[str] = [] + for bus in carrier_buses: + for flow in bus.inputs: + input_labels.append(flow.label_full) + for flow in bus.outputs: + output_labels.append(flow.label_full) + + all_labels = input_labels + output_labels + filtered_labels = _filter_by_pattern(all_labels, include, exclude) + + if not filtered_labels: + import plotly.graph_objects as go + + return PlotResult(data=xr.Dataset(), figure=go.Figure()) + + if unit == 'flow_rate': + ds = self._stats.flow_rates[[lbl for lbl in filtered_labels if lbl in self._stats.flow_rates]] + else: + ds = self._stats.flow_hours[[lbl for lbl in filtered_labels if lbl in self._stats.flow_hours]] + + for label in output_labels: + if label in ds: + ds[label] = -ds[label] + + ds = _apply_selection(ds, select) + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + ds, facet_col, facet_row, animation_frame + ) + + unit_label = '' + if ds.data_vars: + first_var = next(iter(ds.data_vars)) + unit_label = ds[first_var].attrs.get('unit', '') + + fig = ds.fxplot.stacked_bar( + colors=colors, + title=f'{carrier.capitalize()} Balance [{unit_label}]' if unit_label else f'{carrier.capitalize()} Balance', + facet_col=actual_facet_col, + facet_row=actual_facet_row, + animation_frame=actual_anim, + **plotly_kwargs, + ) + + if show is None: + show = CONFIG.Plotting.default_show + if show: + fig.show() + + return PlotResult(data=ds, figure=fig) + + def flows( + self, + *, + start: str | list[str] | None = None, + end: str | list[str] | None = None, + component: str | list[str] | None = None, + select: SelectType | None = None, + unit: 
Literal['flow_rate', 'flow_hours'] = 'flow_rate', + colors: ColorType = None, + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', + show: bool | None = None, + **plotly_kwargs: Any, + ) -> PlotResult: + """Plot flows comparison across cases. + + See StatisticsPlotAccessor.flows for full documentation. + """ + from .statistics_accessor import _apply_selection, _resolve_auto_facets + + ds = self._stats.flow_rates if unit == 'flow_rate' else self._stats.flow_hours + fs = self._comp._systems[0] + + if start is not None or end is not None or component is not None: + matching_labels = [] + starts = [start] if isinstance(start, str) else (start or []) + ends = [end] if isinstance(end, str) else (end or []) + components = [component] if isinstance(component, str) else (component or []) + + for flow in fs.flows.values(): + bus_label = flow.bus + comp_label = flow.component + + if flow.is_input_in_component: + if starts and bus_label not in starts: + continue + if ends and comp_label not in ends: + continue + else: + if starts and comp_label not in starts: + continue + if ends and bus_label not in ends: + continue + + if components and comp_label not in components: + continue + matching_labels.append(flow.label_full) + + ds = ds[[lbl for lbl in matching_labels if lbl in ds]] + + ds = _apply_selection(ds, select) + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + ds, facet_col, facet_row, animation_frame + ) + + unit_label = '' + if ds.data_vars: + first_var = next(iter(ds.data_vars)) + unit_label = ds[first_var].attrs.get('unit', '') + + fig = ds.fxplot.line( + colors=colors, + title=f'Flows [{unit_label}]' if unit_label else 'Flows', + facet_col=actual_facet_col, + facet_row=actual_facet_row, + animation_frame=actual_anim, + **plotly_kwargs, + ) + + if show is None: + show = CONFIG.Plotting.default_show + if show: + fig.show() + + return 
PlotResult(data=ds, figure=fig) + + def sizes( + self, + *, + max_size: float | None = 1e6, + select: SelectType | None = None, + colors: ColorType = None, + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', + show: bool | None = None, + **plotly_kwargs: Any, + ) -> PlotResult: + """Plot investment sizes comparison across cases. + + See StatisticsPlotAccessor.sizes for full documentation. + """ + import plotly.express as px + + from .color_processing import process_colors + from .statistics_accessor import _apply_selection, _dataset_to_long_df, _resolve_auto_facets + + ds = self._stats.sizes + ds = _apply_selection(ds, select) + + if max_size is not None and ds.data_vars: + valid_labels = [lbl for lbl in ds.data_vars if float(ds[lbl].max()) < max_size] + ds = ds[valid_labels] + + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + ds, facet_col, facet_row, animation_frame + ) + + df = _dataset_to_long_df(ds) + if df.empty: + import plotly.graph_objects as go + + fig = go.Figure() + else: + variables = df['variable'].unique().tolist() + color_map = process_colors(colors, variables) + fig = px.bar( + df, + x='variable', + y='value', + color='variable', + facet_col=actual_facet_col, + facet_row=actual_facet_row, + animation_frame=actual_anim, + color_discrete_map=color_map, + title='Investment Sizes', + labels={'variable': 'Flow', 'value': 'Size'}, + **plotly_kwargs, + ) + + if show is None: + show = CONFIG.Plotting.default_show + if show: + fig.show() + + return PlotResult(data=ds, figure=fig) + + def duration_curve( + self, + variables: str | list[str], + *, + select: SelectType | None = None, + normalize: bool = False, + colors: ColorType = None, + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', + show: bool | None = None, + 
**plotly_kwargs: Any, + ) -> PlotResult: + """Plot duration curves comparison across cases. + + See StatisticsPlotAccessor.duration_curve for full documentation. + """ + import numpy as np + + from .statistics_accessor import _apply_selection, _resolve_auto_facets + + if isinstance(variables, str): + variables = [variables] + + flow_rates = self._stats.flow_rates + solution = self._comp.solution + + normalized_vars = [] + for var in variables: + if var.endswith('|flow_rate'): + var = var[: -len('|flow_rate')] + normalized_vars.append(var) + + ds_parts = [] + for var in normalized_vars: + if var in flow_rates: + ds_parts.append(flow_rates[[var]]) + elif var in solution: + ds_parts.append(solution[[var]]) + else: + flow_rate_var = f'{var}|flow_rate' + if flow_rate_var in solution: + ds_parts.append(solution[[flow_rate_var]].rename({flow_rate_var: var})) + else: + raise KeyError(f"Variable '{var}' not found in flow_rates or solution") + + ds = xr.merge(ds_parts) + ds = _apply_selection(ds, select) + + if 'time' not in ds.dims: + raise ValueError('Duration curve requires time dimension') + + def sort_descending(arr: np.ndarray) -> np.ndarray: + return np.sort(arr)[::-1] + + result_ds = xr.apply_ufunc( + sort_descending, + ds, + input_core_dims=[['time']], + output_core_dims=[['time']], + vectorize=True, + ) + + duration_name = 'duration_pct' if normalize else 'duration' + result_ds = result_ds.rename({'time': duration_name}) + + n_timesteps = result_ds.sizes[duration_name] + duration_coord = np.linspace(0, 100, n_timesteps) if normalize else np.arange(n_timesteps) + result_ds = result_ds.assign_coords({duration_name: duration_coord}) + + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + result_ds, facet_col, facet_row, animation_frame + ) + + unit_label = '' + if ds.data_vars: + first_var = next(iter(ds.data_vars)) + unit_label = ds[first_var].attrs.get('unit', '') + + fig = result_ds.fxplot.line( + colors=colors, + title=f'Duration Curve 
[{unit_label}]' if unit_label else 'Duration Curve', + facet_col=actual_facet_col, + facet_row=actual_facet_row, + animation_frame=actual_anim, + **plotly_kwargs, + ) + + x_label = 'Duration [%]' if normalize else 'Timesteps' + fig.update_xaxes(title_text=x_label) + + if show is None: + show = CONFIG.Plotting.default_show + if show: + fig.show() + + return PlotResult(data=result_ds, figure=fig) + + def effects( + self, + aspect: Literal['total', 'temporal', 'periodic'] = 'total', + *, + effect: str | None = None, + by: Literal['component', 'contributor', 'time'] | None = None, + select: SelectType | None = None, + colors: ColorType = None, + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', + show: bool | None = None, + **plotly_kwargs: Any, + ) -> PlotResult: + """Plot effects comparison across cases. + + See StatisticsPlotAccessor.effects for full documentation. + """ + import plotly.express as px + + from .color_processing import process_colors + from .statistics_accessor import _apply_selection, _resolve_auto_facets + + if aspect == 'total': + effects_ds = self._stats.total_effects + elif aspect == 'temporal': + effects_ds = self._stats.temporal_effects + elif aspect == 'periodic': + effects_ds = self._stats.periodic_effects + else: + raise ValueError(f"Aspect '{aspect}' not valid. Choose from 'total', 'temporal', 'periodic'.") + + available_effects = list(effects_ds.data_vars) + + if effect is not None: + if effect not in available_effects: + raise ValueError(f"Effect '{effect}' not found. 
Available: {available_effects}") + effects_to_plot = [effect] + else: + effects_to_plot = available_effects + + effect_arrays = [] + for eff in effects_to_plot: + da = effects_ds[eff] + if by == 'contributor': + effect_arrays.append(da.expand_dims(effect=[eff])) + else: + da_grouped = da.groupby('component').sum() + effect_arrays.append(da_grouped.expand_dims(effect=[eff])) + + combined = xr.concat(effect_arrays, dim='effect') + combined = _apply_selection(combined.to_dataset(name='value'), select)['value'] + + if by is None: + if 'time' in combined.dims: + combined = combined.sum(dim='time') + if 'component' in combined.dims: + combined = combined.sum(dim='component') + if 'contributor' in combined.dims: + combined = combined.sum(dim='contributor') + x_col = 'effect' + color_col = 'effect' + elif by == 'component': + if 'time' in combined.dims: + combined = combined.sum(dim='time') + x_col = 'component' + color_col = 'effect' if len(effects_to_plot) > 1 else 'component' + elif by == 'contributor': + if 'time' in combined.dims: + combined = combined.sum(dim='time') + x_col = 'contributor' + color_col = 'effect' if len(effects_to_plot) > 1 else 'contributor' + elif by == 'time': + if 'time' not in combined.dims: + raise ValueError(f"Cannot plot by 'time' for aspect '{aspect}' - no time dimension.") + if 'component' in combined.dims: + combined = combined.sum(dim='component') + if 'contributor' in combined.dims: + combined = combined.sum(dim='contributor') + x_col = 'time' + color_col = 'effect' if len(effects_to_plot) > 1 else None + else: + raise ValueError(f"'by' must be one of 'component', 'contributor', 'time', or None, got {by!r}") + + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + combined.to_dataset(name='value'), facet_col, facet_row, animation_frame + ) + + df = combined.to_dataframe(name='value').reset_index() + + if color_col and color_col in df.columns: + color_items = df[color_col].unique().tolist() + color_map = 
process_colors(colors, color_items) + else: + color_map = None + + effect_label = effect if effect else 'Effects' + if effect and effect in effects_ds: + unit_label = effects_ds[effect].attrs.get('unit', '') + title = f'{effect_label} [{unit_label}]' if unit_label else effect_label + else: + title = effect_label + title = f'{title} ({aspect})' if by is None else f'{title} ({aspect}) by {by}' + + fig = px.bar( + df, + x=x_col, + y='value', + color=color_col, + color_discrete_map=color_map, + facet_col=actual_facet_col, + facet_row=actual_facet_row, + animation_frame=actual_anim, + title=title, + **plotly_kwargs, + ) + fig.update_layout(bargap=0, bargroupgap=0) + fig.update_traces(marker_line_width=0) + + if show is None: + show = CONFIG.Plotting.default_show + if show: + fig.show() + + return PlotResult(data=combined.to_dataset(name=aspect), figure=fig) + + def charge_states( + self, + storages: str | list[str] | None = None, + *, + select: SelectType | None = None, + colors: ColorType = None, + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', + show: bool | None = None, + **plotly_kwargs: Any, + ) -> PlotResult: + """Plot charge states comparison across cases. + + See StatisticsPlotAccessor.charge_states for full documentation. 
+ """ + from .statistics_accessor import _apply_selection, _resolve_auto_facets + + ds = self._stats.charge_states + + if storages is not None: + if isinstance(storages, str): + storages = [storages] + ds = ds[[s for s in storages if s in ds]] + + ds = _apply_selection(ds, select) + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + ds, facet_col, facet_row, animation_frame + ) + + fig = ds.fxplot.line( + colors=colors, + title='Storage Charge States', + facet_col=actual_facet_col, + facet_row=actual_facet_row, + animation_frame=actual_anim, + **plotly_kwargs, + ) + fig.update_yaxes(title_text='Charge State') + + if show is None: + show = CONFIG.Plotting.default_show + if show: + fig.show() + + return PlotResult(data=ds, figure=fig) + + def heatmap( + self, + variables: str | list[str], + *, + select: SelectType | None = None, + reshape: tuple[str, str] | Literal['auto'] | None = 'auto', + colors: str | list[str] | None = None, + facet_col: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = None, + show: bool | None = None, + **plotly_kwargs: Any, + ) -> PlotResult: + """Plot heatmap comparison across cases. + + See StatisticsPlotAccessor.heatmap for full documentation. 
+ """ + import pandas as pd + + from .statistics_accessor import _apply_selection, _reshape_time_for_heatmap, _resolve_auto_facets + + solution = self._comp.solution + + if isinstance(variables, str): + variables = [variables] + + # Resolve flow labels + resolved_variables = [] + for var in variables: + if var in solution: + resolved_variables.append(var) + elif '|' not in var: + flow_rate_var = f'{var}|flow_rate' + charge_state_var = f'{var}|charge_state' + if flow_rate_var in solution: + resolved_variables.append(flow_rate_var) + elif charge_state_var in solution: + resolved_variables.append(charge_state_var) + else: + resolved_variables.append(var) + else: + resolved_variables.append(var) + + ds = solution[resolved_variables] + ds = _apply_selection(ds, select) + + variable_names = list(ds.data_vars) + dataarrays = [ds[var] for var in variable_names] + da = xr.concat(dataarrays, dim=pd.Index(variable_names, name='variable')) + + is_clustered = 'cluster' in da.dims and da.sizes['cluster'] > 1 + has_multiple_vars = 'variable' in da.dims and da.sizes['variable'] > 1 + + if has_multiple_vars: + actual_facet = 'variable' + _, _, actual_animation = _resolve_auto_facets(da.to_dataset(name='value'), None, None, animation_frame) + if actual_animation == 'variable': + actual_animation = None + else: + actual_facet, _, actual_animation = _resolve_auto_facets( + da.to_dataset(name='value'), facet_col, None, animation_frame + ) + + if is_clustered and (reshape == 'auto' or reshape is None): + heatmap_dims = ['time', 'cluster'] + elif reshape and reshape != 'auto' and 'time' in da.dims: + da = _reshape_time_for_heatmap(da, reshape) + heatmap_dims = ['timestep', 'timeframe'] + elif reshape == 'auto' and 'time' in da.dims and not is_clustered: + da = _reshape_time_for_heatmap(da, ('D', 'h')) + heatmap_dims = ['timestep', 'timeframe'] + elif has_multiple_vars: + heatmap_dims = ['variable', 'time'] + actual_facet, _, actual_animation = _resolve_auto_facets( + 
da.to_dataset(name='value'), facet_col, None, animation_frame + ) + else: + available_dims = [d for d in da.dims if da.sizes[d] > 1] + if len(available_dims) >= 2: + heatmap_dims = available_dims[:2] + elif 'time' in da.dims: + heatmap_dims = ['time'] + else: + heatmap_dims = list(da.dims)[:1] + + keep_dims = set(heatmap_dims) | {d for d in [actual_facet, actual_animation] if d is not None} + for dim in [d for d in da.dims if d not in keep_dims]: + da = da.isel({dim: 0}, drop=True) if da.sizes[dim] > 1 else da.squeeze(dim, drop=True) + + dim_order = heatmap_dims + [d for d in [actual_facet, actual_animation] if d] + da = da.transpose(*dim_order) + + if has_multiple_vars: + da = da.rename('') + + fig = da.fxplot.heatmap( + colors=colors, + facet_col=actual_facet, + animation_frame=actual_animation, + **plotly_kwargs, + ) + + if show is None: + show = CONFIG.Plotting.default_show + if show: + fig.show() + + reshaped_ds = da.to_dataset(name='value') if isinstance(da, xr.DataArray) else da + return PlotResult(data=reshaped_ds, figure=fig) + + def storage( + self, + storage: str, + *, + select: SelectType | None = None, + unit: Literal['flow_rate', 'flow_hours'] = 'flow_rate', + colors: ColorType = None, + charge_state_color: str = 'black', + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', + show: bool | None = None, + **plotly_kwargs: Any, + ) -> PlotResult: + """Plot storage operation comparison across cases. + + See StatisticsPlotAccessor.storage for full documentation. 
+ """ + # Delegate to first system's plot method for now (complex subplot logic) + # This is a simplification - for full support would need to reimplement + from .statistics_accessor import _apply_selection, _resolve_auto_facets + + fs = self._comp._systems[0] + if storage not in fs.components: + raise KeyError(f"Storage '{storage}' not found in components") + + from .components import Storage as StorageComponent + + comp = fs.components[storage] + if not isinstance(comp, StorageComponent): + raise ValueError(f"'{storage}' is not a Storage component") + + # Get combined data + input_labels = [f.label_full for f in comp.inputs] + output_labels = [f.label_full for f in comp.outputs] + all_labels = input_labels + output_labels + + if unit == 'flow_rate': + ds = self._stats.flow_rates[[lbl for lbl in all_labels if lbl in self._stats.flow_rates]] + else: + ds = self._stats.flow_hours[[lbl for lbl in all_labels if lbl in self._stats.flow_hours]] + + for label in input_labels: + if label in ds: + ds[label] = -ds[label] + + charge_ds = self._stats.charge_states[[storage]] if storage in self._stats.charge_states else None + + ds = _apply_selection(ds, select) + if charge_ds is not None: + charge_ds = _apply_selection(charge_ds, select) + + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + ds, facet_col, facet_row, animation_frame + ) + + # Create stacked bar for flows + fig = ds.fxplot.stacked_bar( + colors=colors, + title=f'{storage} Operation', + facet_col=actual_facet_col, + facet_row=actual_facet_row, + animation_frame=actual_anim, + **plotly_kwargs, + ) + + if show is None: + show = CONFIG.Plotting.default_show + if show: + fig.show() + + # Combine data + if charge_ds is not None: + combined_ds = xr.merge([ds, charge_ds]) + else: + combined_ds = ds + + return PlotResult(data=combined_ds, figure=fig) diff --git a/flixopt/config.py b/flixopt/config.py index 454f8ad3e..afa525713 100644 --- a/flixopt/config.py +++ b/flixopt/config.py @@ -164,7 
+164,7 @@ def format(self, record): 'default_sequential_colorscale': 'turbo', 'default_qualitative_colorscale': 'plotly', 'default_line_shape': 'hv', - 'extra_dim_priority': ('cluster', 'period', 'scenario'), + 'extra_dim_priority': ('case', 'cluster', 'period', 'scenario'), 'dim_slot_priority': ('facet_col', 'facet_row', 'animation_frame'), 'x_dim_priority': ('time', 'duration', 'duration_pct', 'period', 'scenario', 'cluster'), } From 64cb85772894c516d2355a5b6a5e9a86ae25757a Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 19:42:47 +0100 Subject: [PATCH 20/62] Add Release notes --- CHANGELOG.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 20f2de7d7..5731b0726 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,6 +57,33 @@ Until here --> ### ✨ Added +**FlowSystem Comparison**: New `Comparison` class for comparing multiple FlowSystems side-by-side: + +```python +# Compare systems (uses FlowSystem.name by default) +comp = fx.Comparison([fs_base, fs_modified]) + +# Or with custom names +comp = fx.Comparison([fs1, fs2, fs3], names=['baseline', 'low_cost', 'high_eff']) + +# Side-by-side plots (auto-facets by 'case' dimension) +comp.statistics.plot.balance('Heat') +comp.statistics.flow_rates.fxplot.line() + +# Access combined data with 'case' dimension +comp.solution # xr.Dataset +comp.statistics.flow_rates # xr.Dataset + +# Compute differences relative to a reference case +comp.diff() # vs first case +comp.diff('baseline') # vs named case +``` + +- Concatenates solutions and statistics from multiple FlowSystems with a `'case'` dimension +- Mirrors all `StatisticsAccessor` properties (`flow_rates`, `flow_hours`, `sizes`, `charge_states`, `temporal_effects`, `periodic_effects`, `total_effects`) +- Mirrors all `StatisticsPlotAccessor` methods (`balance`, `carrier_balance`, `flows`, `sizes`, `duration_curve`, `effects`, `charge_states`, `heatmap`, 
`storage`) +- Existing plotting infrastructure automatically handles faceting by `'case'` + **Time-Series Clustering**: Reduce large time series to representative typical periods for faster investment optimization, then expand results back to full resolution. ```python From 3e3c6173a2c69ead5763ce9419309fe6b6eb0207 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 20:01:00 +0100 Subject: [PATCH 21/62] Add Comparison class to all Notebooks --- .../04-operational-constraints.ipynb | 37 +++++++++++++++--- docs/notebooks/05-multi-carrier-system.ipynb | 39 +++++++++++++++++-- docs/notebooks/07-scenarios-and-periods.ipynb | 1 + docs/notebooks/08a-aggregation.ipynb | 30 ++++++-------- docs/notebooks/08b-rolling-horizon.ipynb | 39 ++++++++----------- docs/notebooks/08c-clustering.ipynb | 25 +++--------- .../08c2-clustering-storage-modes.ipynb | 31 +++++++++++++-- .../08d-clustering-multiperiod.ipynb | 24 +++++++++--- 8 files changed, 150 insertions(+), 76 deletions(-) diff --git a/docs/notebooks/04-operational-constraints.ipynb b/docs/notebooks/04-operational-constraints.ipynb index fbb611d1c..8e171ff7a 100644 --- a/docs/notebooks/04-operational-constraints.ipynb +++ b/docs/notebooks/04-operational-constraints.ipynb @@ -126,7 +126,7 @@ "metadata": {}, "outputs": [], "source": [ - "flow_system = fx.FlowSystem(timesteps)\n", + "flow_system = fx.FlowSystem(timesteps, name='Constrained')\n", "\n", "# Define and register custom carriers\n", "flow_system.add_carriers(\n", @@ -347,7 +347,7 @@ "outputs": [], "source": [ "# Build unconstrained system\n", - "fs_unconstrained = fx.FlowSystem(timesteps)\n", + "fs_unconstrained = fx.FlowSystem(timesteps, name='Unconstrained')\n", "fs_unconstrained.add_carriers(\n", " fx.Carrier('gas', '#3498db', 'kW'),\n", " fx.Carrier('steam', '#87CEEB', 'kW_th', 'Process steam'),\n", @@ -385,6 +385,27 @@ "cell_type": "markdown", "id": "24", "metadata": {}, + "source": [ + "### Side-by-Side 
Comparison\n", + "\n", + "Use the `Comparison` class to visualize both systems together:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25", + "metadata": {}, + "outputs": [], + "source": [ + "comp = fx.Comparison([fs_unconstrained, flow_system])\n", + "comp.statistics.plot.effects()" + ] + }, + { + "cell_type": "markdown", + "id": "26", + "metadata": {}, "source": [ "### Energy Flow Sankey\n", "\n", @@ -394,7 +415,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "27", "metadata": {}, "outputs": [], "source": [ @@ -403,7 +424,7 @@ }, { "cell_type": "markdown", - "id": "26", + "id": "28", "metadata": {}, "source": [ "## Key Concepts\n", @@ -455,7 +476,13 @@ ] } ], - "metadata": {}, + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, "nbformat": 4, "nbformat_minor": 5 } diff --git a/docs/notebooks/05-multi-carrier-system.ipynb b/docs/notebooks/05-multi-carrier-system.ipynb index a1a9543fa..d4b36b79c 100644 --- a/docs/notebooks/05-multi-carrier-system.ipynb +++ b/docs/notebooks/05-multi-carrier-system.ipynb @@ -181,7 +181,7 @@ "metadata": {}, "outputs": [], "source": [ - "flow_system = fx.FlowSystem(timesteps)\n", + "flow_system = fx.FlowSystem(timesteps, name='With CHP')\n", "flow_system.add_carriers(\n", " fx.Carrier('gas', '#3498db', 'kW'),\n", " fx.Carrier('electricity', '#f1c40f', 'kW'),\n", @@ -415,7 +415,7 @@ "outputs": [], "source": [ "# Build system without CHP\n", - "fs_no_chp = fx.FlowSystem(timesteps)\n", + "fs_no_chp = fx.FlowSystem(timesteps, name='No CHP')\n", "fs_no_chp.add_carriers(\n", " fx.Carrier('gas', '#3498db', 'kW'),\n", " fx.Carrier('electricity', '#f1c40f', 'kW'),\n", @@ -468,6 +468,37 @@ "cell_type": "markdown", "id": "24", "metadata": {}, + "source": [ + "### Side-by-Side Comparison\n", + "\n", + "Use the `Comparison` class to visualize both systems together:" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "25", + "metadata": {}, + "outputs": [], + "source": [ + "comp = fx.Comparison([fs_no_chp, flow_system])\n", + "comp.statistics.plot.balance('Electricity')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", + "metadata": {}, + "outputs": [], + "source": [ + "comp.statistics.plot.balance('Heat')" + ] + }, + { + "cell_type": "markdown", + "id": "27", + "metadata": {}, "source": [ "### Energy Flow Sankey\n", "\n", @@ -477,7 +508,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -486,7 +517,7 @@ }, { "cell_type": "markdown", - "id": "26", + "id": "29", "metadata": {}, "source": [ "## Key Concepts\n", diff --git a/docs/notebooks/07-scenarios-and-periods.ipynb b/docs/notebooks/07-scenarios-and-periods.ipynb index db74afefb..f3cb90c61 100644 --- a/docs/notebooks/07-scenarios-and-periods.ipynb +++ b/docs/notebooks/07-scenarios-and-periods.ipynb @@ -215,6 +215,7 @@ " periods=periods,\n", " scenarios=scenarios,\n", " scenario_weights=scenario_weights,\n", + " name='Both Scenarios',\n", ")\n", "flow_system.add_carriers(\n", " fx.Carrier('gas', '#3498db', 'kW'),\n", diff --git a/docs/notebooks/08a-aggregation.ipynb b/docs/notebooks/08a-aggregation.ipynb index 410cd1715..7c2633aa6 100644 --- a/docs/notebooks/08a-aggregation.ipynb +++ b/docs/notebooks/08a-aggregation.ipynb @@ -172,6 +172,7 @@ "# Stage 2: Dispatch at full resolution with fixed sizes\n", "start = timeit.default_timer()\n", "fs_dispatch = flow_system.transform.fix_sizes(fs_sizing.statistics.sizes)\n", + "fs_dispatch.name = 'Two-Stage'\n", "fs_dispatch.optimize(solver)\n", "time_stage2 = timeit.default_timer() - start\n", "\n", @@ -199,6 +200,7 @@ "source": [ "start = timeit.default_timer()\n", "fs_full = flow_system.copy()\n", + "fs_full.name = 'Full Optimization'\n", "fs_full.optimize(solver)\n", "time_full = timeit.default_timer() - start\n", "\n", @@ -271,7 +273,9 @@ "id": 
"16", "metadata": {}, "source": [ - "## Visual Comparison: Heat Balance" + "## Visual Comparison: Heat Balance\n", + "\n", + "Compare the full optimization with the two-stage approach side-by-side:" ] }, { @@ -281,24 +285,14 @@ "metadata": {}, "outputs": [], "source": [ - "# Full optimization heat balance\n", - "fs_full.statistics.plot.balance('Heat')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "18", - "metadata": {}, - "outputs": [], - "source": [ - "# Two-stage optimization heat balance\n", - "fs_dispatch.statistics.plot.balance('Heat')" + "# Side-by-side comparison of full optimization vs two-stage\n", + "comp = fx.Comparison([fs_full, fs_dispatch])\n", + "comp.statistics.plot.balance('Heat')" ] }, { "cell_type": "markdown", - "id": "19", + "id": "18", "metadata": {}, "source": [ "### Energy Flow Sankey (Full Optimization)\n", @@ -309,7 +303,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -318,7 +312,7 @@ }, { "cell_type": "markdown", - "id": "21", + "id": "20", "metadata": {}, "source": [ "## When to Use Each Technique\n", @@ -358,7 +352,7 @@ }, { "cell_type": "markdown", - "id": "22", + "id": "21", "metadata": {}, "source": [ "## Summary\n", diff --git a/docs/notebooks/08b-rolling-horizon.ipynb b/docs/notebooks/08b-rolling-horizon.ipynb index 5032588fe..0de5e1657 100644 --- a/docs/notebooks/08b-rolling-horizon.ipynb +++ b/docs/notebooks/08b-rolling-horizon.ipynb @@ -94,6 +94,7 @@ "\n", "start = timeit.default_timer()\n", "fs_full = flow_system.copy()\n", + "fs_full.name = 'Full Optimization'\n", "fs_full.optimize(solver)\n", "time_full = timeit.default_timer() - start\n", "\n", @@ -133,6 +134,7 @@ "source": [ "start = timeit.default_timer()\n", "fs_rolling = flow_system.copy()\n", + "fs_rolling.name = 'Rolling Horizon'\n", "segments = fs_rolling.optimize.rolling_horizon(\n", " solver,\n", " horizon=192, # 2-day segments (192 timesteps at 15-min 
resolution)\n", @@ -179,7 +181,9 @@ "id": "11", "metadata": {}, "source": [ - "## Visualize: Heat Balance Comparison" + "## Visualize: Heat Balance Comparison\n", + "\n", + "Use the `Comparison` class to view both methods side-by-side:" ] }, { @@ -189,22 +193,13 @@ "metadata": {}, "outputs": [], "source": [ - "fs_full.statistics.plot.balance('Heat').figure.update_layout(title='Heat Balance (Full)')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "13", - "metadata": {}, - "outputs": [], - "source": [ - "fs_rolling.statistics.plot.balance('Heat').figure.update_layout(title='Heat Balance (Rolling)')" + "comp = fx.Comparison([fs_full, fs_rolling])\n", + "comp.statistics.plot.balance('Heat')" ] }, { "cell_type": "markdown", - "id": "14", + "id": "13", "metadata": {}, "source": [ "## Storage State Continuity\n", @@ -215,7 +210,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -239,7 +234,7 @@ }, { "cell_type": "markdown", - "id": "16", + "id": "15", "metadata": {}, "source": [ "## Inspect Individual Segments\n", @@ -250,7 +245,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -263,7 +258,7 @@ }, { "cell_type": "markdown", - "id": "18", + "id": "17", "metadata": {}, "source": [ "## Visualize Segment Overlaps\n", @@ -274,7 +269,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -291,7 +286,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -303,7 +298,7 @@ }, { "cell_type": "markdown", - "id": "21", + "id": "20", "metadata": {}, "source": [ "## When to Use Rolling Horizon\n", @@ -324,7 +319,7 @@ }, { "cell_type": "markdown", - "id": "22", + "id": "21", "metadata": {}, "source": [ "## API Reference\n", @@ -348,7 +343,7 @@ }, { "cell_type": 
"markdown", - "id": "23", + "id": "22", "metadata": {}, "source": [ "## Summary\n", diff --git a/docs/notebooks/08c-clustering.ipynb b/docs/notebooks/08c-clustering.ipynb index acd30ea94..435338005 100644 --- a/docs/notebooks/08c-clustering.ipynb +++ b/docs/notebooks/08c-clustering.ipynb @@ -108,6 +108,7 @@ "\n", "start = timeit.default_timer()\n", "fs_full = flow_system.copy()\n", + "fs_full.name = 'Full Optimization'\n", "fs_full.optimize(solver)\n", "time_full = timeit.default_timer() - start\n", "\n", @@ -265,6 +266,7 @@ "start = timeit.default_timer()\n", "\n", "fs_dispatch = flow_system.transform.fix_sizes(sizes_with_margin)\n", + "fs_dispatch.name = 'Two-Stage'\n", "fs_dispatch.optimize(solver)\n", "\n", "time_dispatch = timeit.default_timer() - start\n", @@ -355,6 +357,7 @@ "source": [ "# Expand the clustered solution to full resolution\n", "fs_expanded = fs_clustered.transform.expand_solution()\n", + "fs_expanded.name = 'Expanded from Clustering'\n", "\n", "print(f'Expanded: {len(fs_clustered.timesteps)} → {len(fs_expanded.timesteps)} timesteps')\n", "print(f'Cost: {fs_expanded.solution[\"costs\"].item():,.0f} €')" @@ -367,25 +370,9 @@ "metadata": {}, "outputs": [], "source": [ - "# Compare heat balance: Full vs Expanded\n", - "fig = make_subplots(rows=2, cols=1, shared_xaxes=True, subplot_titles=['Full Optimization', 'Expanded from Clustering'])\n", - "\n", - "# Full\n", - "for var in ['CHP(Q_th)', 'Boiler(Q_th)']:\n", - " values = fs_full.solution[f'{var}|flow_rate'].values\n", - " fig.add_trace(go.Scatter(x=fs_full.timesteps, y=values, name=var, legendgroup=var, showlegend=True), row=1, col=1)\n", - "\n", - "# Expanded\n", - "for var in ['CHP(Q_th)', 'Boiler(Q_th)']:\n", - " values = fs_expanded.solution[f'{var}|flow_rate'].values\n", - " fig.add_trace(\n", - " go.Scatter(x=fs_expanded.timesteps, y=values, name=var, legendgroup=var, showlegend=False), row=2, col=1\n", - " )\n", - "\n", - "fig.update_layout(height=500, title='Heat Production 
Comparison')\n", - "fig.update_yaxes(title_text='MW', row=1, col=1)\n", - "fig.update_yaxes(title_text='MW', row=2, col=1)\n", - "fig.show()" + "# Compare heat balance: Full vs Expanded using Comparison class\n", + "comp = fx.Comparison([fs_full, fs_expanded])\n", + "comp.statistics.plot.balance('Heat')" ] }, { diff --git a/docs/notebooks/08c2-clustering-storage-modes.ipynb b/docs/notebooks/08c2-clustering-storage-modes.ipynb index c99d25dbd..2c01b16eb 100644 --- a/docs/notebooks/08c2-clustering-storage-modes.ipynb +++ b/docs/notebooks/08c2-clustering-storage-modes.ipynb @@ -138,6 +138,7 @@ "\n", "start = timeit.default_timer()\n", "fs_full = flow_system.copy()\n", + "fs_full.name = 'Full Optimization'\n", "fs_full.optimize(solver)\n", "time_full = timeit.default_timer() - start\n", "\n", @@ -273,7 +274,9 @@ "# Expand clustered solutions to full resolution\n", "expanded_systems = {}\n", "for mode in storage_modes:\n", - " expanded_systems[mode] = clustered_systems[mode].transform.expand_solution()" + " fs_expanded = clustered_systems[mode].transform.expand_solution()\n", + " fs_expanded.name = f'Mode: {mode}'\n", + " expanded_systems[mode] = fs_expanded" ] }, { @@ -312,6 +315,28 @@ "cell_type": "markdown", "id": "14", "metadata": {}, + "source": [ + "### Side-by-Side Comparison\n", + "\n", + "Use the `Comparison` class to compare the full optimization with the recommended mode:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "# Compare full optimization with the recommended intercluster_cyclic mode\n", + "comp = fx.Comparison([fs_full, expanded_systems['intercluster_cyclic']])\n", + "comp.statistics.plot.balance('Heat')" + ] + }, + { + "cell_type": "markdown", + "id": "16", + "metadata": {}, "source": [ "## Interpretation\n", "\n", @@ -339,7 +364,7 @@ }, { "cell_type": "markdown", - "id": "15", + "id": "17", "metadata": {}, "source": [ "## When to Use Each Mode\n", @@ -374,7 +399,7 @@ 
}, { "cell_type": "markdown", - "id": "16", + "id": "18", "metadata": {}, "source": [ "## Summary\n", diff --git a/docs/notebooks/08d-clustering-multiperiod.ipynb b/docs/notebooks/08d-clustering-multiperiod.ipynb index 31c47ff38..0932f4754 100644 --- a/docs/notebooks/08d-clustering-multiperiod.ipynb +++ b/docs/notebooks/08d-clustering-multiperiod.ipynb @@ -141,6 +141,7 @@ "\n", "start = timeit.default_timer()\n", "fs_full = flow_system.copy()\n", + "fs_full.name = 'Full Optimization'\n", "fs_full.optimize(solver)\n", "time_full = timeit.default_timer() - start\n", "\n", @@ -346,6 +347,7 @@ "start = timeit.default_timer()\n", "\n", "fs_dispatch = flow_system.transform.fix_sizes(sizes_with_margin)\n", + "fs_dispatch.name = 'Two-Stage'\n", "fs_dispatch.optimize(solver)\n", "\n", "time_dispatch = timeit.default_timer() - start\n", @@ -435,9 +437,21 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "id": "25", "metadata": {}, + "outputs": [], + "source": [ + "# Side-by-side comparison using the Comparison class\n", + "comp = fx.Comparison([fs_full, fs_dispatch])\n", + "comp.statistics.plot.balance('Heat')" + ] + }, + { + "cell_type": "markdown", + "id": "26", + "metadata": {}, "source": [ "## Expand Clustered Solution to Full Resolution\n", "\n", @@ -447,7 +461,7 @@ { "cell_type": "code", "execution_count": null, - "id": "26", + "id": "27", "metadata": {}, "outputs": [], "source": [ @@ -461,7 +475,7 @@ { "cell_type": "code", "execution_count": null, - "id": "27", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -471,7 +485,7 @@ }, { "cell_type": "markdown", - "id": "28", + "id": "29", "metadata": {}, "source": [ "## Key Considerations for Multi-Period Clustering\n", @@ -505,7 +519,7 @@ }, { "cell_type": "markdown", - "id": "29", + "id": "30", "metadata": {}, "source": [ "## Summary\n", From ff75419dab13605800a156a81143432058fae4d6 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: 
Sat, 3 Jan 2026 20:19:27 +0100 Subject: [PATCH 22/62] Update comparison.py and add documentation --- docs/user-guide/results/index.md | 146 +++++++++ flixopt/comparison.py | 488 ++++++++++--------------------- 2 files changed, 304 insertions(+), 330 deletions(-) diff --git a/docs/user-guide/results/index.md b/docs/user-guide/results/index.md index a9b40f7f9..0137197ea 100644 --- a/docs/user-guide/results/index.md +++ b/docs/user-guide/results/index.md @@ -277,6 +277,152 @@ flow_system.statistics.plot.heatmap('Boiler(Q_th)|flow_rate') flow_system.to_netcdf('results/optimized_system.nc') ``` +## Comparing Multiple Systems + +Use the [`Comparison`][flixopt.comparison.Comparison] class to analyze and visualize multiple FlowSystems side-by-side. This is useful for: + +- Comparing different design alternatives (with/without CHP, different storage sizes) +- Analyzing optimization method trade-offs (full vs. two-stage, different aggregation levels) +- Sensitivity analysis (different scenarios, parameter variations) + +### Basic Usage + +```python +import flixopt as fx + +# Optimize two system variants +fs_baseline = create_system() +fs_baseline.name = 'Baseline' +fs_baseline.optimize(solver) + +fs_with_storage = create_system_with_storage() +fs_with_storage.name = 'With Storage' +fs_with_storage.optimize(solver) + +# Create comparison +comp = fx.Comparison([fs_baseline, fs_with_storage]) + +# Side-by-side balance plots (auto-faceted by 'case' dimension) +comp.statistics.plot.balance('Heat') + +# Access combined data with 'case' dimension +comp.statistics.flow_rates # xr.Dataset with dims: (time, case) +comp.solution # Combined solution dataset +``` + +### Requirements + +All FlowSystems in a comparison must have **matching dimensions** (time, period, scenario, etc.). 
If dimensions differ, use `.transform.sel()` to align them first: + +```python +# Systems with different scenarios +fs_both = flow_system # Has 'Mild Winter' and 'Harsh Winter' scenarios +fs_mild = flow_system.transform.sel(scenario='Mild Winter') # Single scenario + +# Cannot compare directly - dimension mismatch! +# fx.Comparison([fs_both, fs_mild]) # Raises ValueError + +# Instead, select matching dimensions +fs_both_mild = fs_both.transform.sel(scenario='Mild Winter') +comp = fx.Comparison([fs_both_mild, fs_mild]) # Works! +``` + +### Available Properties + +The `Comparison.statistics` accessor mirrors all `StatisticsAccessor` properties, returning combined datasets with an added `'case'` dimension: + +| Property | Description | +|----------|-------------| +| `flow_rates` | All flow rate variables | +| `flow_hours` | Flow hours (energy) | +| `sizes` | Component sizes | +| `storage_sizes` | Storage capacities | +| `charge_states` | Storage charge states | +| `temporal_effects` | Effects per timestep | +| `periodic_effects` | Investment effects | +| `total_effects` | Combined effects | + +### Available Plot Methods + +All standard plot methods work on the comparison, with the `'case'` dimension automatically used for faceting: + +```python +comp = fx.Comparison([fs_baseline, fs_modified]) + +# Balance plots - faceted by case +comp.statistics.plot.balance('Heat') +comp.statistics.plot.balance('Electricity', mode='area') + +# Flow plots +comp.statistics.plot.flows(component='CHP') + +# Effect breakdowns +comp.statistics.plot.effects() + +# Heatmaps +comp.statistics.plot.heatmap('Boiler(Q_th)') + +# Duration curves +comp.statistics.plot.duration_curve('CHP(Q_th)') + +# Storage plots +comp.statistics.plot.storage('Battery') +``` + +### Computing Differences + +Use the `diff()` method to compute differences relative to a reference case: + +```python +# Differences relative to first case (default) +differences = comp.diff() + +# Differences relative to specific case 
+differences = comp.diff(reference='Baseline') +differences = comp.diff(reference=0) # By index + +# Analyze differences +print(differences['costs']) # Cost difference per case +``` + +### Naming Systems + +System names come from `FlowSystem.name` by default. Override with the `names` parameter: + +```python +# Using FlowSystem.name (default) +fs1.name = 'Scenario A' +fs2.name = 'Scenario B' +comp = fx.Comparison([fs1, fs2]) + +# Or override explicitly +comp = fx.Comparison([fs1, fs2], names=['Base Case', 'Alternative']) +``` + +### Example: Comparing Optimization Methods + +```python +# Full optimization +fs_full = flow_system.copy() +fs_full.name = 'Full Optimization' +fs_full.optimize(solver) + +# Two-stage optimization +fs_sizing = flow_system.transform.resample('4h') +fs_sizing.optimize(solver) +fs_dispatch = flow_system.transform.fix_sizes(fs_sizing.statistics.sizes) +fs_dispatch.name = 'Two-Stage' +fs_dispatch.optimize(solver) + +# Compare results +comp = fx.Comparison([fs_full, fs_dispatch]) +comp.statistics.plot.balance('Heat') + +# Check cost difference +diff = comp.diff() +print(f"Cost difference: {diff['costs'].sel(case='Two-Stage').item():.0f} €") +``` + ## Next Steps - [Plotting Results](../results-plotting.md) - Detailed plotting documentation diff --git a/flixopt/comparison.py b/flixopt/comparison.py index bde698a44..899048bbb 100644 --- a/flixopt/comparison.py +++ b/flixopt/comparison.py @@ -27,10 +27,18 @@ class Comparison: xarray Datasets with a 'case' dimension. The existing plotting infrastructure automatically handles faceting by the 'case' dimension. + All FlowSystems must have matching dimensions (time, period, scenario, etc.). + Use `flow_system.transform.sel()` to align dimensions before comparing. + Args: - flow_systems: List of FlowSystems to compare. + flow_systems: List of FlowSystems to compare. All must be optimized + and have matching dimensions. names: Optional names for each case. If None, uses FlowSystem.name. 
+ Raises: + ValueError: If FlowSystems have mismatched dimensions. + RuntimeError: If any FlowSystem has no solution. + Examples: ```python # Compare two systems (uses FlowSystem.name by default) @@ -50,6 +58,12 @@ class Comparison: # Compute differences relative to first case comp.diff() # Returns xr.Dataset of differences comp.diff('baseline') # Or specify reference by name + + # For systems with different dimensions, align first: + fs_both = ... # Has scenario dimension + fs_mild = fs_both.transform.sel(scenario='Mild') # Select one scenario + fs_other = ... # Also select to match + comp = fx.Comparison([fs_mild, fs_other]) # Now dimensions match ``` """ @@ -68,10 +82,37 @@ def __init__(self, flow_systems: list[FlowSystem], names: list[str] | None = Non if len(set(self._names)) != len(self._names): raise ValueError(f'Case names must be unique, got: {self._names}') + # Validate all FlowSystems have solutions + for fs in flow_systems: + if fs.solution is None: + raise RuntimeError(f"FlowSystem '{fs.name}' has no solution. 
Run optimize() first.") + + # Validate matching dimensions across all FlowSystems + self._validate_matching_dimensions() + # Caches self._solution: xr.Dataset | None = None self._statistics: ComparisonStatistics | None = None + def _validate_matching_dimensions(self) -> None: + """Validate that all FlowSystems have matching dimensions.""" + reference = self._systems[0] + ref_dims = set(reference.solution.dims) + ref_name = self._names[0] + + for fs, name in zip(self._systems[1:], self._names[1:], strict=True): + fs_dims = set(fs.solution.dims) + if fs_dims != ref_dims: + missing = ref_dims - fs_dims + extra = fs_dims - ref_dims + msg_parts = [f"Dimension mismatch between '{ref_name}' and '{name}'."] + if missing: + msg_parts.append(f"Missing in '{name}': {missing}.") + if extra: + msg_parts.append(f"Extra in '{name}': {extra}.") + msg_parts.append('Use .transform.sel() to align dimensions before comparing.') + raise ValueError(' '.join(msg_parts)) + @property def names(self) -> list[str]: """Case names for each FlowSystem.""" @@ -83,11 +124,9 @@ def solution(self) -> xr.Dataset: if self._solution is None: datasets = [] for fs, name in zip(self._systems, self._names, strict=True): - if fs.solution is None: - raise RuntimeError(f"FlowSystem '{fs.name}' has no solution. 
Run optimize() first.") ds = fs.solution.expand_dims(case=[name]) datasets.append(ds) - self._solution = xr.concat(datasets, dim='case') + self._solution = xr.concat(datasets, dim='case', join='outer', fill_value=float('nan')) return self._solution @property @@ -151,7 +190,7 @@ def _concat_property(self, prop_name: str) -> xr.Dataset: for fs, name in zip(self._comp._systems, self._comp._names, strict=True): ds = getattr(fs.statistics, prop_name) datasets.append(ds.expand_dims(case=[name])) - return xr.concat(datasets, dim='case') + return xr.concat(datasets, dim='case', join='outer', fill_value=float('nan')) def _merge_dict_property(self, prop_name: str) -> dict[str, str]: """Merge a dict property from all cases (later cases override).""" @@ -278,9 +317,29 @@ def __init__(self, statistics: ComparisonStatistics) -> None: self._stats = statistics self._comp = statistics._comp - def _get_first_stats_plot(self): - """Get StatisticsPlotAccessor from first FlowSystem for delegation.""" - return self._comp._systems[0].statistics.plot + def _concat_plot_data(self, method_name: str, *args, **kwargs) -> xr.Dataset: + """Call a plot method on each system and concatenate the resulting data. + + This ensures all data variables from all systems are included, + even if topologies differ between systems. 
+ """ + # Disable show for individual calls, we'll handle it after combining + kwargs['show'] = False + datasets = [] + for fs, name in zip(self._comp._systems, self._comp._names, strict=True): + try: + plot_method = getattr(fs.statistics.plot, method_name) + result = plot_method(*args, **kwargs) + ds = result.data.expand_dims(case=[name]) + datasets.append(ds) + except (KeyError, ValueError): + # Node/element might not exist in this system - skip it + continue + + if not datasets: + return xr.Dataset() + + return xr.concat(datasets, dim='case', join='outer', fill_value=float('nan')) def balance( self, @@ -302,39 +361,16 @@ def balance( See StatisticsPlotAccessor.balance for full documentation. The 'case' dimension is automatically used for faceting. """ - from .statistics_accessor import _apply_selection, _filter_by_pattern, _resolve_auto_facets - - # Get flow labels from first system (assumes same topology) - fs = self._comp._systems[0] - if node in fs.buses: - element = fs.buses[node] - elif node in fs.components: - element = fs.components[node] - else: - raise KeyError(f"'{node}' not found in buses or components") + from .statistics_accessor import _resolve_auto_facets - input_labels = [f.label_full for f in element.inputs] - output_labels = [f.label_full for f in element.outputs] - all_labels = input_labels + output_labels - filtered_labels = _filter_by_pattern(all_labels, include, exclude) + # Get combined data from all systems + ds = self._concat_plot_data('balance', node, select=select, include=include, exclude=exclude, unit=unit) - if not filtered_labels: + if not ds.data_vars: import plotly.graph_objects as go return PlotResult(data=xr.Dataset(), figure=go.Figure()) - # Get combined data - if unit == 'flow_rate': - ds = self._stats.flow_rates[[lbl for lbl in filtered_labels if lbl in self._stats.flow_rates]] - else: - ds = self._stats.flow_hours[[lbl for lbl in filtered_labels if lbl in self._stats.flow_hours]] - - # Negate inputs - for label in 
input_labels: - if label in ds: - ds[label] = -ds[label] - - ds = _apply_selection(ds, select) actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( ds, facet_col, facet_row, animation_frame ) @@ -380,41 +416,18 @@ def carrier_balance( See StatisticsPlotAccessor.carrier_balance for full documentation. """ - from .statistics_accessor import _apply_selection, _filter_by_pattern, _resolve_auto_facets - - carrier = carrier.lower() - fs = self._comp._systems[0] - - carrier_buses = [bus for bus in fs.buses.values() if bus.carrier == carrier] - if not carrier_buses: - raise KeyError(f"No buses found with carrier '{carrier}'") + from .statistics_accessor import _resolve_auto_facets - input_labels: list[str] = [] - output_labels: list[str] = [] - for bus in carrier_buses: - for flow in bus.inputs: - input_labels.append(flow.label_full) - for flow in bus.outputs: - output_labels.append(flow.label_full) - - all_labels = input_labels + output_labels - filtered_labels = _filter_by_pattern(all_labels, include, exclude) + # Get combined data from all systems + ds = self._concat_plot_data( + 'carrier_balance', carrier, select=select, include=include, exclude=exclude, unit=unit + ) - if not filtered_labels: + if not ds.data_vars: import plotly.graph_objects as go return PlotResult(data=xr.Dataset(), figure=go.Figure()) - if unit == 'flow_rate': - ds = self._stats.flow_rates[[lbl for lbl in filtered_labels if lbl in self._stats.flow_rates]] - else: - ds = self._stats.flow_hours[[lbl for lbl in filtered_labels if lbl in self._stats.flow_hours]] - - for label in output_labels: - if label in ds: - ds[label] = -ds[label] - - ds = _apply_selection(ds, select) actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( ds, facet_col, facet_row, animation_frame ) @@ -459,39 +472,16 @@ def flows( See StatisticsPlotAccessor.flows for full documentation. 
""" - from .statistics_accessor import _apply_selection, _resolve_auto_facets - - ds = self._stats.flow_rates if unit == 'flow_rate' else self._stats.flow_hours - fs = self._comp._systems[0] - - if start is not None or end is not None or component is not None: - matching_labels = [] - starts = [start] if isinstance(start, str) else (start or []) - ends = [end] if isinstance(end, str) else (end or []) - components = [component] if isinstance(component, str) else (component or []) - - for flow in fs.flows.values(): - bus_label = flow.bus - comp_label = flow.component - - if flow.is_input_in_component: - if starts and bus_label not in starts: - continue - if ends and comp_label not in ends: - continue - else: - if starts and comp_label not in starts: - continue - if ends and bus_label not in ends: - continue - - if components and comp_label not in components: - continue - matching_labels.append(flow.label_full) - - ds = ds[[lbl for lbl in matching_labels if lbl in ds]] - - ds = _apply_selection(ds, select) + from .statistics_accessor import _resolve_auto_facets + + # Get combined data from all systems + ds = self._concat_plot_data('flows', start=start, end=end, component=component, select=select, unit=unit) + + if not ds.data_vars: + import plotly.graph_objects as go + + return PlotResult(data=xr.Dataset(), figure=go.Figure()) + actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( ds, facet_col, facet_row, animation_frame ) @@ -536,14 +526,15 @@ def sizes( import plotly.express as px from .color_processing import process_colors - from .statistics_accessor import _apply_selection, _dataset_to_long_df, _resolve_auto_facets + from .statistics_accessor import _dataset_to_long_df, _resolve_auto_facets - ds = self._stats.sizes - ds = _apply_selection(ds, select) + # Get combined data from all systems + ds = self._concat_plot_data('sizes', max_size=max_size, select=select) - if max_size is not None and ds.data_vars: - valid_labels = [lbl for lbl in 
ds.data_vars if float(ds[lbl].max()) < max_size] - ds = ds[valid_labels] + if not ds.data_vars: + import plotly.graph_objects as go + + return PlotResult(data=xr.Dataset(), figure=go.Figure()) actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( ds, facet_col, facet_row, animation_frame @@ -595,61 +586,18 @@ def duration_curve( See StatisticsPlotAccessor.duration_curve for full documentation. """ - import numpy as np - - from .statistics_accessor import _apply_selection, _resolve_auto_facets - - if isinstance(variables, str): - variables = [variables] - - flow_rates = self._stats.flow_rates - solution = self._comp.solution - - normalized_vars = [] - for var in variables: - if var.endswith('|flow_rate'): - var = var[: -len('|flow_rate')] - normalized_vars.append(var) - - ds_parts = [] - for var in normalized_vars: - if var in flow_rates: - ds_parts.append(flow_rates[[var]]) - elif var in solution: - ds_parts.append(solution[[var]]) - else: - flow_rate_var = f'{var}|flow_rate' - if flow_rate_var in solution: - ds_parts.append(solution[[flow_rate_var]].rename({flow_rate_var: var})) - else: - raise KeyError(f"Variable '{var}' not found in flow_rates or solution") - - ds = xr.merge(ds_parts) - ds = _apply_selection(ds, select) - - if 'time' not in ds.dims: - raise ValueError('Duration curve requires time dimension') - - def sort_descending(arr: np.ndarray) -> np.ndarray: - return np.sort(arr)[::-1] - - result_ds = xr.apply_ufunc( - sort_descending, - ds, - input_core_dims=[['time']], - output_core_dims=[['time']], - vectorize=True, - ) + from .statistics_accessor import _resolve_auto_facets + + # Get combined data from all systems + ds = self._concat_plot_data('duration_curve', variables, select=select, normalize=normalize) - duration_name = 'duration_pct' if normalize else 'duration' - result_ds = result_ds.rename({'time': duration_name}) + if not ds.data_vars: + import plotly.graph_objects as go - n_timesteps = result_ds.sizes[duration_name] - 
duration_coord = np.linspace(0, 100, n_timesteps) if normalize else np.arange(n_timesteps) - result_ds = result_ds.assign_coords({duration_name: duration_coord}) + return PlotResult(data=xr.Dataset(), figure=go.Figure()) actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - result_ds, facet_col, facet_row, animation_frame + ds, facet_col, facet_row, animation_frame ) unit_label = '' @@ -657,7 +605,7 @@ def sort_descending(arr: np.ndarray) -> np.ndarray: first_var = next(iter(ds.data_vars)) unit_label = ds[first_var].attrs.get('unit', '') - fig = result_ds.fxplot.line( + fig = ds.fxplot.line( colors=colors, title=f'Duration Curve [{unit_label}]' if unit_label else 'Duration Curve', facet_col=actual_facet_col, @@ -674,7 +622,7 @@ def sort_descending(arr: np.ndarray) -> np.ndarray: if show: fig.show() - return PlotResult(data=result_ds, figure=fig) + return PlotResult(data=ds, figure=fig) def effects( self, @@ -697,68 +645,38 @@ def effects( import plotly.express as px from .color_processing import process_colors - from .statistics_accessor import _apply_selection, _resolve_auto_facets - - if aspect == 'total': - effects_ds = self._stats.total_effects - elif aspect == 'temporal': - effects_ds = self._stats.temporal_effects - elif aspect == 'periodic': - effects_ds = self._stats.periodic_effects - else: - raise ValueError(f"Aspect '{aspect}' not valid. Choose from 'total', 'temporal', 'periodic'.") + from .statistics_accessor import _resolve_auto_facets - available_effects = list(effects_ds.data_vars) + # Get combined data from all systems + ds = self._concat_plot_data('effects', aspect, effect=effect, by=by, select=select) - if effect is not None: - if effect not in available_effects: - raise ValueError(f"Effect '{effect}' not found. 
Available: {available_effects}") - effects_to_plot = [effect] - else: - effects_to_plot = available_effects + if not ds.data_vars: + import plotly.graph_objects as go - effect_arrays = [] - for eff in effects_to_plot: - da = effects_ds[eff] - if by == 'contributor': - effect_arrays.append(da.expand_dims(effect=[eff])) - else: - da_grouped = da.groupby('component').sum() - effect_arrays.append(da_grouped.expand_dims(effect=[eff])) + return PlotResult(data=xr.Dataset(), figure=go.Figure()) - combined = xr.concat(effect_arrays, dim='effect') - combined = _apply_selection(combined.to_dataset(name='value'), select)['value'] + # The underlying effects method returns a Dataset with a single data var named after aspect + # Convert back to DataArray for processing + combined = ds[aspect] if aspect in ds else next(iter(ds.data_vars)) + if isinstance(combined, xr.Dataset): + combined = combined[next(iter(combined.data_vars))] + # Determine x_col and color_col based on dimensions if by is None: - if 'time' in combined.dims: - combined = combined.sum(dim='time') - if 'component' in combined.dims: - combined = combined.sum(dim='component') - if 'contributor' in combined.dims: - combined = combined.sum(dim='contributor') x_col = 'effect' color_col = 'effect' elif by == 'component': - if 'time' in combined.dims: - combined = combined.sum(dim='time') x_col = 'component' - color_col = 'effect' if len(effects_to_plot) > 1 else 'component' + color_col = 'effect' if 'effect' in combined.dims and combined.sizes.get('effect', 1) > 1 else 'component' elif by == 'contributor': - if 'time' in combined.dims: - combined = combined.sum(dim='time') x_col = 'contributor' - color_col = 'effect' if len(effects_to_plot) > 1 else 'contributor' + color_col = 'effect' if 'effect' in combined.dims and combined.sizes.get('effect', 1) > 1 else 'contributor' elif by == 'time': - if 'time' not in combined.dims: - raise ValueError(f"Cannot plot by 'time' for aspect '{aspect}' - no time dimension.") - if 
'component' in combined.dims: - combined = combined.sum(dim='component') - if 'contributor' in combined.dims: - combined = combined.sum(dim='contributor') x_col = 'time' - color_col = 'effect' if len(effects_to_plot) > 1 else None + color_col = 'effect' if 'effect' in combined.dims and combined.sizes.get('effect', 1) > 1 else None else: - raise ValueError(f"'by' must be one of 'component', 'contributor', 'time', or None, got {by!r}") + x_col = 'effect' + color_col = 'effect' actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( combined.to_dataset(name='value'), facet_col, facet_row, animation_frame @@ -773,12 +691,7 @@ def effects( color_map = None effect_label = effect if effect else 'Effects' - if effect and effect in effects_ds: - unit_label = effects_ds[effect].attrs.get('unit', '') - title = f'{effect_label} [{unit_label}]' if unit_label else effect_label - else: - title = effect_label - title = f'{title} ({aspect})' if by is None else f'{title} ({aspect}) by {by}' + title = f'{effect_label} ({aspect})' if by is None else f'{effect_label} ({aspect}) by {by}' fig = px.bar( df, @@ -800,7 +713,7 @@ def effects( if show: fig.show() - return PlotResult(data=combined.to_dataset(name=aspect), figure=fig) + return PlotResult(data=ds, figure=fig) def charge_states( self, @@ -818,16 +731,16 @@ def charge_states( See StatisticsPlotAccessor.charge_states for full documentation. 
""" - from .statistics_accessor import _apply_selection, _resolve_auto_facets + from .statistics_accessor import _resolve_auto_facets - ds = self._stats.charge_states + # Get combined data from all systems + ds = self._concat_plot_data('charge_states', storages, select=select) - if storages is not None: - if isinstance(storages, str): - storages = [storages] - ds = ds[[s for s in storages if s in ds]] + if not ds.data_vars: + import plotly.graph_objects as go + + return PlotResult(data=xr.Dataset(), figure=go.Figure()) - ds = _apply_selection(ds, select) actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( ds, facet_col, facet_row, animation_frame ) @@ -865,87 +778,33 @@ def heatmap( See StatisticsPlotAccessor.heatmap for full documentation. """ - import pandas as pd - - from .statistics_accessor import _apply_selection, _reshape_time_for_heatmap, _resolve_auto_facets - - solution = self._comp.solution - - if isinstance(variables, str): - variables = [variables] - - # Resolve flow labels - resolved_variables = [] - for var in variables: - if var in solution: - resolved_variables.append(var) - elif '|' not in var: - flow_rate_var = f'{var}|flow_rate' - charge_state_var = f'{var}|charge_state' - if flow_rate_var in solution: - resolved_variables.append(flow_rate_var) - elif charge_state_var in solution: - resolved_variables.append(charge_state_var) - else: - resolved_variables.append(var) - else: - resolved_variables.append(var) - - ds = solution[resolved_variables] - ds = _apply_selection(ds, select) - - variable_names = list(ds.data_vars) - dataarrays = [ds[var] for var in variable_names] - da = xr.concat(dataarrays, dim=pd.Index(variable_names, name='variable')) - - is_clustered = 'cluster' in da.dims and da.sizes['cluster'] > 1 - has_multiple_vars = 'variable' in da.dims and da.sizes['variable'] > 1 - - if has_multiple_vars: - actual_facet = 'variable' - _, _, actual_animation = _resolve_auto_facets(da.to_dataset(name='value'), None, None, 
animation_frame) - if actual_animation == 'variable': - actual_animation = None - else: - actual_facet, _, actual_animation = _resolve_auto_facets( - da.to_dataset(name='value'), facet_col, None, animation_frame - ) + from .statistics_accessor import _resolve_auto_facets - if is_clustered and (reshape == 'auto' or reshape is None): - heatmap_dims = ['time', 'cluster'] - elif reshape and reshape != 'auto' and 'time' in da.dims: - da = _reshape_time_for_heatmap(da, reshape) - heatmap_dims = ['timestep', 'timeframe'] - elif reshape == 'auto' and 'time' in da.dims and not is_clustered: - da = _reshape_time_for_heatmap(da, ('D', 'h')) - heatmap_dims = ['timestep', 'timeframe'] - elif has_multiple_vars: - heatmap_dims = ['variable', 'time'] - actual_facet, _, actual_animation = _resolve_auto_facets( - da.to_dataset(name='value'), facet_col, None, animation_frame - ) - else: - available_dims = [d for d in da.dims if da.sizes[d] > 1] - if len(available_dims) >= 2: - heatmap_dims = available_dims[:2] - elif 'time' in da.dims: - heatmap_dims = ['time'] - else: - heatmap_dims = list(da.dims)[:1] + # Get combined data from all systems + ds = self._concat_plot_data('heatmap', variables, select=select, reshape=reshape) - keep_dims = set(heatmap_dims) | {d for d in [actual_facet, actual_animation] if d is not None} - for dim in [d for d in da.dims if d not in keep_dims]: - da = da.isel({dim: 0}, drop=True) if da.sizes[dim] > 1 else da.squeeze(dim, drop=True) + if not ds.data_vars: + import plotly.graph_objects as go - dim_order = heatmap_dims + [d for d in [actual_facet, actual_animation] if d] - da = da.transpose(*dim_order) + return PlotResult(data=xr.Dataset(), figure=go.Figure()) + + # Convert to DataArray for heatmap plotting + if len(ds.data_vars) == 1: + da = ds[next(iter(ds.data_vars))] + else: + import pandas as pd - if has_multiple_vars: - da = da.rename('') + variable_names = list(ds.data_vars) + dataarrays = [ds[var] for var in variable_names] + da = 
xr.concat(dataarrays, dim=pd.Index(variable_names, name='variable')) + + actual_facet_col, _, actual_animation = _resolve_auto_facets( + da.to_dataset(name='value'), facet_col, None, animation_frame + ) fig = da.fxplot.heatmap( colors=colors, - facet_col=actual_facet, + facet_col=actual_facet_col, animation_frame=actual_animation, **plotly_kwargs, ) @@ -955,8 +814,7 @@ def heatmap( if show: fig.show() - reshaped_ds = da.to_dataset(name='value') if isinstance(da, xr.DataArray) else da - return PlotResult(data=reshaped_ds, figure=fig) + return PlotResult(data=ds, figure=fig) def storage( self, @@ -976,39 +834,15 @@ def storage( See StatisticsPlotAccessor.storage for full documentation. """ - # Delegate to first system's plot method for now (complex subplot logic) - # This is a simplification - for full support would need to reimplement - from .statistics_accessor import _apply_selection, _resolve_auto_facets - - fs = self._comp._systems[0] - if storage not in fs.components: - raise KeyError(f"Storage '{storage}' not found in components") - - from .components import Storage as StorageComponent + from .statistics_accessor import _resolve_auto_facets - comp = fs.components[storage] - if not isinstance(comp, StorageComponent): - raise ValueError(f"'{storage}' is not a Storage component") + # Get combined data from all systems + ds = self._concat_plot_data('storage', storage, select=select, unit=unit, charge_state_color=charge_state_color) - # Get combined data - input_labels = [f.label_full for f in comp.inputs] - output_labels = [f.label_full for f in comp.outputs] - all_labels = input_labels + output_labels - - if unit == 'flow_rate': - ds = self._stats.flow_rates[[lbl for lbl in all_labels if lbl in self._stats.flow_rates]] - else: - ds = self._stats.flow_hours[[lbl for lbl in all_labels if lbl in self._stats.flow_hours]] - - for label in input_labels: - if label in ds: - ds[label] = -ds[label] - - charge_ds = self._stats.charge_states[[storage]] if storage in 
self._stats.charge_states else None + if not ds.data_vars: + import plotly.graph_objects as go - ds = _apply_selection(ds, select) - if charge_ds is not None: - charge_ds = _apply_selection(charge_ds, select) + return PlotResult(data=xr.Dataset(), figure=go.Figure()) actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( ds, facet_col, facet_row, animation_frame @@ -1029,10 +863,4 @@ def storage( if show: fig.show() - # Combine data - if charge_ds is not None: - combined_ds = xr.merge([ds, charge_ds]) - else: - combined_ds = ds - - return PlotResult(data=combined_ds, figure=fig) + return PlotResult(data=ds, figure=fig) From c490f25ca7a0b167c0bfd9c269c2cfaf14bbdf91 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 21:27:09 +0100 Subject: [PATCH 23/62] =?UTF-8?q?=E2=8F=BA=20The=20class=20went=20from=20~?= =?UTF-8?q?560=20lines=20to=20~115=20lines.=20Key=20simplifications:?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. __getattr__ - dynamically delegates any method to the underlying accessor 2. _wrap_plot_method - single method that handles all the data collection and concatenation 3. 
_recreate_figure - infers plot type from the original figure and recreates with combined data Tradeoffs: - Less explicit type hints on method signatures (but still works the same) - Infers plot type from original figure rather than hardcoding per method - Automatically supports any new methods added to StatisticsPlotAccessor in the future --- docs/user-guide/results/index.md | 8 +- flixopt/comparison.py | 627 +++++-------------------------- 2 files changed, 102 insertions(+), 533 deletions(-) diff --git a/docs/user-guide/results/index.md b/docs/user-guide/results/index.md index 0137197ea..500a64cd9 100644 --- a/docs/user-guide/results/index.md +++ b/docs/user-guide/results/index.md @@ -312,19 +312,23 @@ comp.solution # Combined solution dataset ### Requirements -All FlowSystems in a comparison must have **matching dimensions** (time, period, scenario, etc.). If dimensions differ, use `.transform.sel()` to align them first: +All FlowSystems must have **matching core dimensions** (`time`, `cluster`, `period`, `scenario`). Auxiliary dimensions like `cluster_boundary` are ignored. If core dimensions differ, use `.transform.sel()` to align them first: ```python # Systems with different scenarios fs_both = flow_system # Has 'Mild Winter' and 'Harsh Winter' scenarios fs_mild = flow_system.transform.sel(scenario='Mild Winter') # Single scenario -# Cannot compare directly - dimension mismatch! +# Cannot compare directly - scenario dimension mismatch! # fx.Comparison([fs_both, fs_mild]) # Raises ValueError # Instead, select matching dimensions fs_both_mild = fs_both.transform.sel(scenario='Mild Winter') comp = fx.Comparison([fs_both_mild, fs_mild]) # Works! + +# Auxiliary dimensions are OK (e.g., expanded clustered solutions) +fs_expanded = fs_clustered.transform.expand_solution() # Has cluster_boundary dim +comp = fx.Comparison([fs_full, fs_expanded]) # Works! 
cluster_boundary is ignored ``` ### Available Properties diff --git a/flixopt/comparison.py b/flixopt/comparison.py index 899048bbb..06819fda4 100644 --- a/flixopt/comparison.py +++ b/flixopt/comparison.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any import xarray as xr @@ -94,18 +94,26 @@ def __init__(self, flow_systems: list[FlowSystem], names: list[str] | None = Non self._solution: xr.Dataset | None = None self._statistics: ComparisonStatistics | None = None + # Core dimensions that must match across FlowSystems + _CORE_DIMS = {'time', 'cluster', 'period', 'scenario'} + def _validate_matching_dimensions(self) -> None: - """Validate that all FlowSystems have matching dimensions.""" + """Validate that all FlowSystems have matching core dimensions. + + Only validates core dimensions (time, cluster, period, scenario). Auxiliary + dimensions like 'cluster_boundary' are ignored as they don't affect + the comparison logic. + """ reference = self._systems[0] - ref_dims = set(reference.solution.dims) + ref_core_dims = set(reference.solution.dims) & self._CORE_DIMS ref_name = self._names[0] for fs, name in zip(self._systems[1:], self._names[1:], strict=True): - fs_dims = set(fs.solution.dims) - if fs_dims != ref_dims: - missing = ref_dims - fs_dims - extra = fs_dims - ref_dims - msg_parts = [f"Dimension mismatch between '{ref_name}' and '{name}'."] + fs_core_dims = set(fs.solution.dims) & self._CORE_DIMS + if fs_core_dims != ref_core_dims: + missing = ref_core_dims - fs_core_dims + extra = fs_core_dims - ref_core_dims + msg_parts = [f"Core dimension mismatch between '{ref_name}' and '{name}'."] if missing: msg_parts.append(f"Missing in '{name}': {missing}.") if extra: msg_parts.append(f"Extra in '{name}': {extra}.") @@ -308,559 +316,116 @@ def plot(self) -> ComparisonStatisticsPlot: class ComparisonStatisticsPlot: """Plot accessor for comparison statistics. 
- Mirrors StatisticsPlotAccessor methods, operating on combined data - from multiple FlowSystems. The 'case' dimension is automatically - used for faceting. + Dynamically wraps StatisticsPlotAccessor methods, combining data from all + FlowSystems with a 'case' dimension for faceting. """ def __init__(self, statistics: ComparisonStatistics) -> None: self._stats = statistics self._comp = statistics._comp - def _concat_plot_data(self, method_name: str, *args, **kwargs) -> xr.Dataset: - """Call a plot method on each system and concatenate the resulting data. + def __getattr__(self, name: str): + """Dynamically delegate any plot method to underlying systems.""" + if name.startswith('_'): + raise AttributeError(name) + # Check if method exists on underlying accessor + if not hasattr(self._comp._systems[0].statistics.plot, name): + raise AttributeError(name) + return lambda *args, **kwargs: self._wrap_plot_method(name, *args, **kwargs) + + def _wrap_plot_method(self, method_name: str, *args, show: bool | None = None, **kwargs) -> PlotResult: + """Call plot method on each system and combine results.""" + import plotly.graph_objects as go - This ensures all data variables from all systems are included, - even if topologies differ between systems. 
- """ - # Disable show for individual calls, we'll handle it after combining - kwargs['show'] = False datasets = [] - for fs, name in zip(self._comp._systems, self._comp._names, strict=True): + last_result = None + + for fs, case_name in zip(self._comp._systems, self._comp._names, strict=True): try: - plot_method = getattr(fs.statistics.plot, method_name) - result = plot_method(*args, **kwargs) - ds = result.data.expand_dims(case=[name]) - datasets.append(ds) + method = getattr(fs.statistics.plot, method_name) + result = method(*args, show=False, **kwargs) + datasets.append(result.data.expand_dims(case=[case_name])) + last_result = result except (KeyError, ValueError): - # Node/element might not exist in this system - skip it + # Element might not exist in this system continue if not datasets: - return xr.Dataset() - - return xr.concat(datasets, dim='case', join='outer', fill_value=float('nan')) - - def balance( - self, - node: str, - *, - select: SelectType | None = None, - include: FilterType | None = None, - exclude: FilterType | None = None, - unit: Literal['flow_rate', 'flow_hours'] = 'flow_rate', - colors: ColorType = None, - facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = 'auto', - animation_frame: str | Literal['auto'] | None = 'auto', - show: bool | None = None, - **plotly_kwargs: Any, - ) -> PlotResult: - """Plot node balance comparison across cases. - - See StatisticsPlotAccessor.balance for full documentation. - The 'case' dimension is automatically used for faceting. 
- """ - from .statistics_accessor import _resolve_auto_facets - - # Get combined data from all systems - ds = self._concat_plot_data('balance', node, select=select, include=include, exclude=exclude, unit=unit) - - if not ds.data_vars: - import plotly.graph_objects as go - - return PlotResult(data=xr.Dataset(), figure=go.Figure()) - - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame - ) - - # Get unit label - unit_label = '' - if ds.data_vars: - first_var = next(iter(ds.data_vars)) - unit_label = ds[first_var].attrs.get('unit', '') - - fig = ds.fxplot.stacked_bar( - colors=colors, - title=f'{node} [{unit_label}]' if unit_label else node, - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, - **plotly_kwargs, - ) - - if show is None: - show = CONFIG.Plotting.default_show - if show: - fig.show() - - return PlotResult(data=ds, figure=fig) - - def carrier_balance( - self, - carrier: str, - *, - select: SelectType | None = None, - include: FilterType | None = None, - exclude: FilterType | None = None, - unit: Literal['flow_rate', 'flow_hours'] = 'flow_rate', - colors: ColorType = None, - facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = 'auto', - animation_frame: str | Literal['auto'] | None = 'auto', - show: bool | None = None, - **plotly_kwargs: Any, - ) -> PlotResult: - """Plot carrier balance comparison across cases. - - See StatisticsPlotAccessor.carrier_balance for full documentation. 
- """ - from .statistics_accessor import _resolve_auto_facets - - # Get combined data from all systems - ds = self._concat_plot_data( - 'carrier_balance', carrier, select=select, include=include, exclude=exclude, unit=unit - ) - - if not ds.data_vars: - import plotly.graph_objects as go - return PlotResult(data=xr.Dataset(), figure=go.Figure()) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame - ) - - unit_label = '' - if ds.data_vars: - first_var = next(iter(ds.data_vars)) - unit_label = ds[first_var].attrs.get('unit', '') - - fig = ds.fxplot.stacked_bar( - colors=colors, - title=f'{carrier.capitalize()} Balance [{unit_label}]' if unit_label else f'{carrier.capitalize()} Balance', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, - **plotly_kwargs, - ) - - if show is None: - show = CONFIG.Plotting.default_show - if show: - fig.show() - - return PlotResult(data=ds, figure=fig) - - def flows( - self, - *, - start: str | list[str] | None = None, - end: str | list[str] | None = None, - component: str | list[str] | None = None, - select: SelectType | None = None, - unit: Literal['flow_rate', 'flow_hours'] = 'flow_rate', - colors: ColorType = None, - facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = 'auto', - animation_frame: str | Literal['auto'] | None = 'auto', - show: bool | None = None, - **plotly_kwargs: Any, - ) -> PlotResult: - """Plot flows comparison across cases. - - See StatisticsPlotAccessor.flows for full documentation. 
- """ - from .statistics_accessor import _resolve_auto_facets - - # Get combined data from all systems - ds = self._concat_plot_data('flows', start=start, end=end, component=component, select=select, unit=unit) - - if not ds.data_vars: - import plotly.graph_objects as go - - return PlotResult(data=xr.Dataset(), figure=go.Figure()) - - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame - ) - - unit_label = '' - if ds.data_vars: - first_var = next(iter(ds.data_vars)) - unit_label = ds[first_var].attrs.get('unit', '') - - fig = ds.fxplot.line( - colors=colors, - title=f'Flows [{unit_label}]' if unit_label else 'Flows', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, - **plotly_kwargs, - ) - - if show is None: - show = CONFIG.Plotting.default_show - if show: - fig.show() - - return PlotResult(data=ds, figure=fig) - - def sizes( - self, - *, - max_size: float | None = 1e6, - select: SelectType | None = None, - colors: ColorType = None, - facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = 'auto', - animation_frame: str | Literal['auto'] | None = 'auto', - show: bool | None = None, - **plotly_kwargs: Any, - ) -> PlotResult: - """Plot investment sizes comparison across cases. - - See StatisticsPlotAccessor.sizes for full documentation. 
- """ - import plotly.express as px - - from .color_processing import process_colors - from .statistics_accessor import _dataset_to_long_df, _resolve_auto_facets - - # Get combined data from all systems - ds = self._concat_plot_data('sizes', max_size=max_size, select=select) - - if not ds.data_vars: - import plotly.graph_objects as go - - return PlotResult(data=xr.Dataset(), figure=go.Figure()) - - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame - ) - - df = _dataset_to_long_df(ds) - if df.empty: - import plotly.graph_objects as go - - fig = go.Figure() - else: - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables) - fig = px.bar( - df, - x='variable', - y='value', - color='variable', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, - color_discrete_map=color_map, - title='Investment Sizes', - labels={'variable': 'Flow', 'value': 'Size'}, - **plotly_kwargs, - ) - - if show is None: - show = CONFIG.Plotting.default_show - if show: - fig.show() + combined = xr.concat(datasets, dim='case', join='outer', fill_value=float('nan')) - return PlotResult(data=ds, figure=fig) - - def duration_curve( - self, - variables: str | list[str], - *, - select: SelectType | None = None, - normalize: bool = False, - colors: ColorType = None, - facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = 'auto', - animation_frame: str | Literal['auto'] | None = 'auto', - show: bool | None = None, - **plotly_kwargs: Any, - ) -> PlotResult: - """Plot duration curves comparison across cases. - - See StatisticsPlotAccessor.duration_curve for full documentation. 
- """ + # Recreate figure with combined data from .statistics_accessor import _resolve_auto_facets - # Get combined data from all systems - ds = self._concat_plot_data('duration_curve', variables, select=select, normalize=normalize) + facet_col = kwargs.pop('facet_col', 'auto') + facet_row = kwargs.pop('facet_row', 'auto') + animation_frame = kwargs.pop('animation_frame', 'auto') + colors = kwargs.get('colors') - if not ds.data_vars: - import plotly.graph_objects as go - - return PlotResult(data=xr.Dataset(), figure=go.Figure()) + actual_col, actual_row, actual_anim = _resolve_auto_facets(combined, facet_col, facet_row, animation_frame) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame - ) - - unit_label = '' - if ds.data_vars: - first_var = next(iter(ds.data_vars)) - unit_label = ds[first_var].attrs.get('unit', '') - - fig = ds.fxplot.line( - colors=colors, - title=f'Duration Curve [{unit_label}]' if unit_label else 'Duration Curve', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, - **plotly_kwargs, - ) - - x_label = 'Duration [%]' if normalize else 'Timesteps' - fig.update_xaxes(title_text=x_label) + # Determine plot type from last successful result's figure + fig = self._recreate_figure(combined, last_result, colors, actual_col, actual_row, actual_anim, kwargs) if show is None: show = CONFIG.Plotting.default_show if show: fig.show() - return PlotResult(data=ds, figure=fig) - - def effects( - self, - aspect: Literal['total', 'temporal', 'periodic'] = 'total', - *, - effect: str | None = None, - by: Literal['component', 'contributor', 'time'] | None = None, - select: SelectType | None = None, - colors: ColorType = None, - facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = 'auto', - animation_frame: str | Literal['auto'] | None = 'auto', - show: bool | None = None, - **plotly_kwargs: Any, - ) -> PlotResult: - 
"""Plot effects comparison across cases. - - See StatisticsPlotAccessor.effects for full documentation. - """ - import plotly.express as px - - from .color_processing import process_colors - from .statistics_accessor import _resolve_auto_facets + return PlotResult(data=combined, figure=fig) - # Get combined data from all systems - ds = self._concat_plot_data('effects', aspect, effect=effect, by=by, select=select) + def _recreate_figure( + self, ds: xr.Dataset, last_result: PlotResult | None, colors, facet_col, facet_row, anim, kwargs + ): + """Recreate figure with combined data, inferring plot type from original.""" + import plotly.graph_objects as go if not ds.data_vars: - import plotly.graph_objects as go + return go.Figure() - return PlotResult(data=xr.Dataset(), figure=go.Figure()) - - # The underlying effects method returns a Dataset with a single data var named after aspect - # Convert back to DataArray for processing - combined = ds[aspect] if aspect in ds else next(iter(ds.data_vars)) - if isinstance(combined, xr.Dataset): - combined = combined[next(iter(combined.data_vars))] - - # Determine x_col and color_col based on dimensions - if by is None: - x_col = 'effect' - color_col = 'effect' - elif by == 'component': - x_col = 'component' - color_col = 'effect' if 'effect' in combined.dims and combined.sizes.get('effect', 1) > 1 else 'component' - elif by == 'contributor': - x_col = 'contributor' - color_col = 'effect' if 'effect' in combined.dims and combined.sizes.get('effect', 1) > 1 else 'contributor' - elif by == 'time': - x_col = 'time' - color_col = 'effect' if 'effect' in combined.dims and combined.sizes.get('effect', 1) > 1 else None + # Infer plot type from original figure traces + if last_result and last_result.figure.data: + trace_type = type(last_result.figure.data[0]).__name__ else: - x_col = 'effect' - color_col = 'effect' - - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - combined.to_dataset(name='value'), facet_col, 
facet_row, animation_frame - ) - - df = combined.to_dataframe(name='value').reset_index() - - if color_col and color_col in df.columns: - color_items = df[color_col].unique().tolist() - color_map = process_colors(colors, color_items) - else: - color_map = None - - effect_label = effect if effect else 'Effects' - title = f'{effect_label} ({aspect})' if by is None else f'{effect_label} ({aspect}) by {by}' - - fig = px.bar( - df, - x=x_col, - y='value', - color=color_col, - color_discrete_map=color_map, - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, - title=title, - **plotly_kwargs, - ) - fig.update_layout(bargap=0, bargroupgap=0) - fig.update_traces(marker_line_width=0) - - if show is None: - show = CONFIG.Plotting.default_show - if show: - fig.show() - - return PlotResult(data=ds, figure=fig) - - def charge_states( - self, - storages: str | list[str] | None = None, - *, - select: SelectType | None = None, - colors: ColorType = None, - facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = 'auto', - animation_frame: str | Literal['auto'] | None = 'auto', - show: bool | None = None, - **plotly_kwargs: Any, - ) -> PlotResult: - """Plot charge states comparison across cases. - - See StatisticsPlotAccessor.charge_states for full documentation. 
- """ - from .statistics_accessor import _resolve_auto_facets - - # Get combined data from all systems - ds = self._concat_plot_data('charge_states', storages, select=select) - - if not ds.data_vars: - import plotly.graph_objects as go - - return PlotResult(data=xr.Dataset(), figure=go.Figure()) - - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame - ) - - fig = ds.fxplot.line( - colors=colors, - title='Storage Charge States', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, - **plotly_kwargs, - ) - fig.update_yaxes(title_text='Charge State') - - if show is None: - show = CONFIG.Plotting.default_show - if show: - fig.show() - - return PlotResult(data=ds, figure=fig) - - def heatmap( - self, - variables: str | list[str], - *, - select: SelectType | None = None, - reshape: tuple[str, str] | Literal['auto'] | None = 'auto', - colors: str | list[str] | None = None, - facet_col: str | Literal['auto'] | None = 'auto', - animation_frame: str | Literal['auto'] | None = None, - show: bool | None = None, - **plotly_kwargs: Any, - ) -> PlotResult: - """Plot heatmap comparison across cases. - - See StatisticsPlotAccessor.heatmap for full documentation. 
- """ - from .statistics_accessor import _resolve_auto_facets - - # Get combined data from all systems - ds = self._concat_plot_data('heatmap', variables, select=select, reshape=reshape) - - if not ds.data_vars: - import plotly.graph_objects as go - - return PlotResult(data=xr.Dataset(), figure=go.Figure()) - - # Convert to DataArray for heatmap plotting - if len(ds.data_vars) == 1: + trace_type = 'Bar' + + title = last_result.figure.layout.title.text if last_result else '' + + if trace_type in ('Bar',) and 'Scatter' not in str(last_result.figure.data): + # Check if it's a stacked bar (has barmode='relative') + barmode = getattr(last_result.figure.layout, 'barmode', None) if last_result else None + if barmode == 'relative': + return ds.fxplot.stacked_bar( + colors=colors, title=title, facet_col=facet_col, facet_row=facet_row, animation_frame=anim + ) + else: + # Regular bar - use px.bar via long-form data + from .color_processing import process_colors + from .statistics_accessor import _dataset_to_long_df + + df = _dataset_to_long_df(ds) + if df.empty: + return go.Figure() + import plotly.express as px + + color_map = process_colors(colors, df['variable'].unique().tolist()) + return px.bar( + df, + x='variable', + y='value', + color='variable', + facet_col=facet_col, + facet_row=facet_row, + animation_frame=anim, + color_discrete_map=color_map, + title=title, + ) + elif trace_type == 'Heatmap': da = ds[next(iter(ds.data_vars))] + return da.fxplot.heatmap(colors=colors, facet_col=facet_col, animation_frame=anim) else: - import pandas as pd - - variable_names = list(ds.data_vars) - dataarrays = [ds[var] for var in variable_names] - da = xr.concat(dataarrays, dim=pd.Index(variable_names, name='variable')) - - actual_facet_col, _, actual_animation = _resolve_auto_facets( - da.to_dataset(name='value'), facet_col, None, animation_frame - ) - - fig = da.fxplot.heatmap( - colors=colors, - facet_col=actual_facet_col, - animation_frame=actual_animation, - **plotly_kwargs, - 
) - - if show is None: - show = CONFIG.Plotting.default_show - if show: - fig.show() - - return PlotResult(data=ds, figure=fig) - - def storage( - self, - storage: str, - *, - select: SelectType | None = None, - unit: Literal['flow_rate', 'flow_hours'] = 'flow_rate', - colors: ColorType = None, - charge_state_color: str = 'black', - facet_col: str | Literal['auto'] | None = 'auto', - facet_row: str | Literal['auto'] | None = 'auto', - animation_frame: str | Literal['auto'] | None = 'auto', - show: bool | None = None, - **plotly_kwargs: Any, - ) -> PlotResult: - """Plot storage operation comparison across cases. - - See StatisticsPlotAccessor.storage for full documentation. - """ - from .statistics_accessor import _resolve_auto_facets - - # Get combined data from all systems - ds = self._concat_plot_data('storage', storage, select=select, unit=unit, charge_state_color=charge_state_color) - - if not ds.data_vars: - import plotly.graph_objects as go - - return PlotResult(data=xr.Dataset(), figure=go.Figure()) - - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame - ) - - # Create stacked bar for flows - fig = ds.fxplot.stacked_bar( - colors=colors, - title=f'{storage} Operation', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, - **plotly_kwargs, - ) - - if show is None: - show = CONFIG.Plotting.default_show - if show: - fig.show() - - return PlotResult(data=ds, figure=fig) + # Default to line plot + return ds.fxplot.line( + colors=colors, title=title, facet_col=facet_col, facet_row=facet_row, animation_frame=anim + ) From 64dc1f90242f72c1181efdee712c9ef0f0e5604e Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 21:46:49 +0100 Subject: [PATCH 24/62] =?UTF-8?q?=E2=8F=BA=20The=20class=20went=20from=20~?= =?UTF-8?q?560=20lines=20to=20~115=20lines.=20Key=20simplifications:?= MIME-Version: 1.0 Content-Type: 
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. __getattr__ - dynamically delegates any method to the underlying accessor 2. _wrap_plot_method - single method that handles all the data collection and concatenation 3. _recreate_figure - infers plot type from the original figure and recreates with combined data Tradeoffs: - Less explicit type hints on method signatures (but still works the same) - Infers plot type from original figure rather than hardcoding per method - Automatically supports any new methods added to StatisticsPlotAccessor in the future --- flixopt/comparison.py | 313 ++++++++++++++++++++++++++++++------------ 1 file changed, 225 insertions(+), 88 deletions(-) diff --git a/flixopt/comparison.py b/flixopt/comparison.py index 06819fda4..df6e0a3cd 100644 --- a/flixopt/comparison.py +++ b/flixopt/comparison.py @@ -316,116 +316,253 @@ def plot(self) -> ComparisonStatisticsPlot: class ComparisonStatisticsPlot: """Plot accessor for comparison statistics. - Dynamically wraps StatisticsPlotAccessor methods, combining data from all - FlowSystems with a 'case' dimension for faceting. + Wraps StatisticsPlotAccessor methods, combining data from all FlowSystems + with a 'case' dimension for faceting. 
""" def __init__(self, statistics: ComparisonStatistics) -> None: self._stats = statistics self._comp = statistics._comp - def __getattr__(self, name: str): - """Dynamically delegate any plot method to underlying systems.""" - if name.startswith('_'): - raise AttributeError(name) - # Check if method exists on underlying accessor - if not hasattr(self._comp._systems[0].statistics.plot, name): - raise AttributeError(name) - return lambda *args, **kwargs: self._wrap_plot_method(name, *args, **kwargs) - - def _wrap_plot_method(self, method_name: str, *args, show: bool | None = None, **kwargs) -> PlotResult: - """Call plot method on each system and combine results.""" - import plotly.graph_objects as go - + def _combine_data(self, method_name: str, *args, **kwargs) -> tuple[xr.Dataset, str]: + """Call plot method on each system and combine data. Returns (combined_data, title).""" datasets = [] - last_result = None + title = '' + kwargs['show'] = False for fs, case_name in zip(self._comp._systems, self._comp._names, strict=True): try: - method = getattr(fs.statistics.plot, method_name) - result = method(*args, show=False, **kwargs) + result = getattr(fs.statistics.plot, method_name)(*args, **kwargs) datasets.append(result.data.expand_dims(case=[case_name])) - last_result = result + title = result.figure.layout.title.text or title except (KeyError, ValueError): - # Element might not exist in this system continue if not datasets: - return PlotResult(data=xr.Dataset(), figure=go.Figure()) + return xr.Dataset(), '' - combined = xr.concat(datasets, dim='case', join='outer', fill_value=float('nan')) + return xr.concat(datasets, dim='case', join='outer', fill_value=float('nan')), title - # Recreate figure with combined data + def _resolve_facets(self, ds: xr.Dataset, facet_col='auto', facet_row='auto', animation_frame='auto'): + """Resolve auto facets.""" from .statistics_accessor import _resolve_auto_facets - facet_col = kwargs.pop('facet_col', 'auto') - facet_row = 
kwargs.pop('facet_row', 'auto') - animation_frame = kwargs.pop('animation_frame', 'auto') - colors = kwargs.get('colors') - - actual_col, actual_row, actual_anim = _resolve_auto_facets(combined, facet_col, facet_row, animation_frame) + return _resolve_auto_facets(ds, facet_col, facet_row, animation_frame) - # Determine plot type from last successful result's figure - fig = self._recreate_figure(combined, last_result, colors, actual_col, actual_row, actual_anim, kwargs) + def _finalize(self, ds: xr.Dataset, fig, show: bool | None) -> PlotResult: + """Handle show and return PlotResult.""" + import plotly.graph_objects as go if show is None: show = CONFIG.Plotting.default_show - if show: + if show and fig: fig.show() - - return PlotResult(data=combined, figure=fig) - - def _recreate_figure( - self, ds: xr.Dataset, last_result: PlotResult | None, colors, facet_col, facet_row, anim, kwargs - ): - """Recreate figure with combined data, inferring plot type from original.""" - import plotly.graph_objects as go - + return PlotResult(data=ds, figure=fig or go.Figure()) + + def balance( + self, + node: str, + *, + colors=None, + facet_col='auto', + facet_row='auto', + animation_frame='auto', + show: bool | None = None, + **kwargs, + ) -> PlotResult: + """Plot node balance comparison. 
See StatisticsPlotAccessor.balance.""" + ds, title = self._combine_data('balance', node, **kwargs) if not ds.data_vars: - return go.Figure() - - # Infer plot type from original figure traces - if last_result and last_result.figure.data: - trace_type = type(last_result.figure.data[0]).__name__ - else: - trace_type = 'Bar' - - title = last_result.figure.layout.title.text if last_result else '' - - if trace_type in ('Bar',) and 'Scatter' not in str(last_result.figure.data): - # Check if it's a stacked bar (has barmode='relative') - barmode = getattr(last_result.figure.layout, 'barmode', None) if last_result else None - if barmode == 'relative': - return ds.fxplot.stacked_bar( - colors=colors, title=title, facet_col=facet_col, facet_row=facet_row, animation_frame=anim - ) - else: - # Regular bar - use px.bar via long-form data - from .color_processing import process_colors - from .statistics_accessor import _dataset_to_long_df - - df = _dataset_to_long_df(ds) - if df.empty: - return go.Figure() - import plotly.express as px - - color_map = process_colors(colors, df['variable'].unique().tolist()) - return px.bar( - df, - x='variable', - y='value', - color='variable', - facet_col=facet_col, - facet_row=facet_row, - animation_frame=anim, - color_discrete_map=color_map, - title=title, - ) - elif trace_type == 'Heatmap': - da = ds[next(iter(ds.data_vars))] - return da.fxplot.heatmap(colors=colors, facet_col=facet_col, animation_frame=anim) - else: - # Default to line plot - return ds.fxplot.line( - colors=colors, title=title, facet_col=facet_col, facet_row=facet_row, animation_frame=anim - ) + return self._finalize(ds, None, show) + col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) + fig = ds.fxplot.stacked_bar(colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim) + return self._finalize(ds, fig, show) + + def carrier_balance( + self, + carrier: str, + *, + colors=None, + facet_col='auto', + facet_row='auto', + 
animation_frame='auto', + show: bool | None = None, + **kwargs, + ) -> PlotResult: + """Plot carrier balance comparison. See StatisticsPlotAccessor.carrier_balance.""" + ds, title = self._combine_data('carrier_balance', carrier, **kwargs) + if not ds.data_vars: + return self._finalize(ds, None, show) + col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) + fig = ds.fxplot.stacked_bar(colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim) + return self._finalize(ds, fig, show) + + def flows( + self, + *, + colors=None, + facet_col='auto', + facet_row='auto', + animation_frame='auto', + show: bool | None = None, + **kwargs, + ) -> PlotResult: + """Plot flows comparison. See StatisticsPlotAccessor.flows.""" + ds, title = self._combine_data('flows', **kwargs) + if not ds.data_vars: + return self._finalize(ds, None, show) + col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) + fig = ds.fxplot.line(colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim) + return self._finalize(ds, fig, show) + + def storage( + self, + storage: str, + *, + colors=None, + facet_col='auto', + facet_row='auto', + animation_frame='auto', + show: bool | None = None, + **kwargs, + ) -> PlotResult: + """Plot storage operation comparison. See StatisticsPlotAccessor.storage.""" + ds, title = self._combine_data('storage', storage, **kwargs) + if not ds.data_vars: + return self._finalize(ds, None, show) + col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) + fig = ds.fxplot.stacked_bar(colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim) + return self._finalize(ds, fig, show) + + def charge_states( + self, + storages=None, + *, + colors=None, + facet_col='auto', + facet_row='auto', + animation_frame='auto', + show: bool | None = None, + **kwargs, + ) -> PlotResult: + """Plot charge states comparison. 
See StatisticsPlotAccessor.charge_states.""" + ds, title = self._combine_data('charge_states', storages, **kwargs) + if not ds.data_vars: + return self._finalize(ds, None, show) + col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) + fig = ds.fxplot.line(colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim) + fig.update_yaxes(title_text='Charge State') + return self._finalize(ds, fig, show) + + def duration_curve( + self, + variables, + *, + normalize: bool = False, + colors=None, + facet_col='auto', + facet_row='auto', + animation_frame='auto', + show: bool | None = None, + **kwargs, + ) -> PlotResult: + """Plot duration curves comparison. See StatisticsPlotAccessor.duration_curve.""" + ds, title = self._combine_data('duration_curve', variables, normalize=normalize, **kwargs) + if not ds.data_vars: + return self._finalize(ds, None, show) + col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) + fig = ds.fxplot.line(colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim) + fig.update_xaxes(title_text='Duration [%]' if normalize else 'Timesteps') + return self._finalize(ds, fig, show) + + def sizes( + self, + *, + colors=None, + facet_col='auto', + facet_row='auto', + animation_frame='auto', + show: bool | None = None, + **kwargs, + ) -> PlotResult: + """Plot investment sizes comparison. 
See StatisticsPlotAccessor.sizes.""" + import plotly.express as px + + from .color_processing import process_colors + from .statistics_accessor import _dataset_to_long_df + + ds, title = self._combine_data('sizes', **kwargs) + if not ds.data_vars: + return self._finalize(ds, None, show) + col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) + df = _dataset_to_long_df(ds) + color_map = process_colors(colors, df['variable'].unique().tolist()) if not df.empty else None + fig = px.bar( + df, + x='variable', + y='value', + color='variable', + title=title, + facet_col=col, + facet_row=row, + animation_frame=anim, + color_discrete_map=color_map, + ) + return self._finalize(ds, fig, show) + + def effects( + self, + aspect='total', + *, + colors=None, + facet_col='auto', + facet_row='auto', + animation_frame='auto', + show: bool | None = None, + **kwargs, + ) -> PlotResult: + """Plot effects comparison. See StatisticsPlotAccessor.effects.""" + import plotly.express as px + + from .color_processing import process_colors + + ds, title = self._combine_data('effects', aspect, **kwargs) + if not ds.data_vars: + return self._finalize(ds, None, show) + col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) + + # Get the data array and convert to dataframe + da = ds[aspect] if aspect in ds else ds[next(iter(ds.data_vars))] + df = da.to_dataframe(name='value').reset_index() + + by = kwargs.get('by') + x_col = by if by else 'effect' + color_col = x_col if x_col in df.columns else None + color_map = process_colors(colors, df[color_col].unique().tolist()) if color_col else None + + fig = px.bar( + df, + x=x_col, + y='value', + color=color_col, + title=title, + facet_col=col, + facet_row=row, + animation_frame=anim, + color_discrete_map=color_map, + ) + fig.update_layout(bargap=0, bargroupgap=0) + fig.update_traces(marker_line_width=0) + return self._finalize(ds, fig, show) + + def heatmap( + self, variables, *, colors=None, 
facet_col='auto', animation_frame='auto', show: bool | None = None, **kwargs + ) -> PlotResult: + """Plot heatmap comparison. See StatisticsPlotAccessor.heatmap.""" + ds, _ = self._combine_data('heatmap', variables, **kwargs) + if not ds.data_vars: + return self._finalize(ds, None, show) + col, _, anim = self._resolve_facets(ds, facet_col, None, animation_frame) + da = ds[next(iter(ds.data_vars))] + fig = da.fxplot.heatmap(colors=colors, facet_col=col, animation_frame=anim) + return self._finalize(ds, fig, show) From d5d42593dffb66e54ab7aefdf5e574115e185f1c Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 21:51:21 +0100 Subject: [PATCH 25/62] Minor bugfix --- flixopt/comparison.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flixopt/comparison.py b/flixopt/comparison.py index df6e0a3cd..9b4c69005 100644 --- a/flixopt/comparison.py +++ b/flixopt/comparison.py @@ -328,7 +328,7 @@ def _combine_data(self, method_name: str, *args, **kwargs) -> tuple[xr.Dataset, """Call plot method on each system and combine data. Returns (combined_data, title).""" datasets = [] title = '' - kwargs['show'] = False + kwargs = {**kwargs, 'show': False} # Don't mutate original for fs, case_name in zip(self._comp._systems, self._comp._names, strict=True): try: From 273b03c06d86d47bb243e0586209514d1f53bda0 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 21:54:54 +0100 Subject: [PATCH 26/62] Now all methods properly split kwargs and pass plotly_kwargs to the figure creation. The _DATA_KWARGS mapping defines which kwargs affect data processing - everything else passes through to plotly. 
--- flixopt/comparison.py | 77 +++++++++++++++++++++++++++++++++---------- 1 file changed, 60 insertions(+), 17 deletions(-) diff --git a/flixopt/comparison.py b/flixopt/comparison.py index 9b4c69005..c020cb01d 100644 --- a/flixopt/comparison.py +++ b/flixopt/comparison.py @@ -320,10 +320,30 @@ class ComparisonStatisticsPlot: with a 'case' dimension for faceting. """ + # Data-related kwargs for each method (everything else is plotly kwargs) + _DATA_KWARGS: dict[str, set[str]] = { + 'balance': {'select', 'include', 'exclude', 'unit'}, + 'carrier_balance': {'select', 'include', 'exclude', 'unit'}, + 'flows': {'start', 'end', 'component', 'select', 'unit'}, + 'storage': {'select', 'unit', 'charge_state_color'}, + 'charge_states': {'select'}, + 'duration_curve': {'select', 'normalize'}, + 'sizes': {'max_size', 'select'}, + 'effects': {'effect', 'by', 'select'}, + 'heatmap': {'select', 'reshape'}, + } + def __init__(self, statistics: ComparisonStatistics) -> None: self._stats = statistics self._comp = statistics._comp + def _split_kwargs(self, method_name: str, kwargs: dict) -> tuple[dict, dict]: + """Split kwargs into data kwargs and plotly kwargs.""" + data_keys = self._DATA_KWARGS.get(method_name, set()) + data_kwargs = {k: v for k, v in kwargs.items() if k in data_keys} + plotly_kwargs = {k: v for k, v in kwargs.items() if k not in data_keys} + return data_kwargs, plotly_kwargs + def _combine_data(self, method_name: str, *args, **kwargs) -> tuple[xr.Dataset, str]: """Call plot method on each system and combine data. Returns (combined_data, title).""" datasets = [] @@ -371,11 +391,14 @@ def balance( **kwargs, ) -> PlotResult: """Plot node balance comparison. 
See StatisticsPlotAccessor.balance.""" - ds, title = self._combine_data('balance', node, **kwargs) + data_kw, plotly_kw = self._split_kwargs('balance', kwargs) + ds, title = self._combine_data('balance', node, **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) - fig = ds.fxplot.stacked_bar(colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim) + fig = ds.fxplot.stacked_bar( + colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim, **plotly_kw + ) return self._finalize(ds, fig, show) def carrier_balance( @@ -390,11 +413,14 @@ def carrier_balance( **kwargs, ) -> PlotResult: """Plot carrier balance comparison. See StatisticsPlotAccessor.carrier_balance.""" - ds, title = self._combine_data('carrier_balance', carrier, **kwargs) + data_kw, plotly_kw = self._split_kwargs('carrier_balance', kwargs) + ds, title = self._combine_data('carrier_balance', carrier, **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) - fig = ds.fxplot.stacked_bar(colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim) + fig = ds.fxplot.stacked_bar( + colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim, **plotly_kw + ) return self._finalize(ds, fig, show) def flows( @@ -408,11 +434,14 @@ def flows( **kwargs, ) -> PlotResult: """Plot flows comparison. 
See StatisticsPlotAccessor.flows.""" - ds, title = self._combine_data('flows', **kwargs) + data_kw, plotly_kw = self._split_kwargs('flows', kwargs) + ds, title = self._combine_data('flows', **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) - fig = ds.fxplot.line(colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim) + fig = ds.fxplot.line( + colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim, **plotly_kw + ) return self._finalize(ds, fig, show) def storage( @@ -427,11 +456,14 @@ def storage( **kwargs, ) -> PlotResult: """Plot storage operation comparison. See StatisticsPlotAccessor.storage.""" - ds, title = self._combine_data('storage', storage, **kwargs) + data_kw, plotly_kw = self._split_kwargs('storage', kwargs) + ds, title = self._combine_data('storage', storage, **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) - fig = ds.fxplot.stacked_bar(colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim) + fig = ds.fxplot.stacked_bar( + colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim, **plotly_kw + ) return self._finalize(ds, fig, show) def charge_states( @@ -446,11 +478,14 @@ def charge_states( **kwargs, ) -> PlotResult: """Plot charge states comparison. 
See StatisticsPlotAccessor.charge_states.""" - ds, title = self._combine_data('charge_states', storages, **kwargs) + data_kw, plotly_kw = self._split_kwargs('charge_states', kwargs) + ds, title = self._combine_data('charge_states', storages, **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) - fig = ds.fxplot.line(colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim) + fig = ds.fxplot.line( + colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim, **plotly_kw + ) fig.update_yaxes(title_text='Charge State') return self._finalize(ds, fig, show) @@ -467,11 +502,14 @@ def duration_curve( **kwargs, ) -> PlotResult: """Plot duration curves comparison. See StatisticsPlotAccessor.duration_curve.""" - ds, title = self._combine_data('duration_curve', variables, normalize=normalize, **kwargs) + data_kw, plotly_kw = self._split_kwargs('duration_curve', kwargs) + ds, title = self._combine_data('duration_curve', variables, normalize=normalize, **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) - fig = ds.fxplot.line(colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim) + fig = ds.fxplot.line( + colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim, **plotly_kw + ) fig.update_xaxes(title_text='Duration [%]' if normalize else 'Timesteps') return self._finalize(ds, fig, show) @@ -491,7 +529,8 @@ def sizes( from .color_processing import process_colors from .statistics_accessor import _dataset_to_long_df - ds, title = self._combine_data('sizes', **kwargs) + data_kw, plotly_kw = self._split_kwargs('sizes', kwargs) + ds, title = self._combine_data('sizes', **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) col, row, anim = self._resolve_facets(ds, facet_col, facet_row, 
animation_frame) @@ -507,6 +546,7 @@ def sizes( facet_row=row, animation_frame=anim, color_discrete_map=color_map, + **plotly_kw, ) return self._finalize(ds, fig, show) @@ -526,7 +566,8 @@ def effects( from .color_processing import process_colors - ds, title = self._combine_data('effects', aspect, **kwargs) + data_kw, plotly_kw = self._split_kwargs('effects', kwargs) + ds, title = self._combine_data('effects', aspect, **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) @@ -535,7 +576,7 @@ def effects( da = ds[aspect] if aspect in ds else ds[next(iter(ds.data_vars))] df = da.to_dataframe(name='value').reset_index() - by = kwargs.get('by') + by = data_kw.get('by') x_col = by if by else 'effect' color_col = x_col if x_col in df.columns else None color_map = process_colors(colors, df[color_col].unique().tolist()) if color_col else None @@ -550,6 +591,7 @@ def effects( facet_row=row, animation_frame=anim, color_discrete_map=color_map, + **plotly_kw, ) fig.update_layout(bargap=0, bargroupgap=0) fig.update_traces(marker_line_width=0) @@ -559,10 +601,11 @@ def heatmap( self, variables, *, colors=None, facet_col='auto', animation_frame='auto', show: bool | None = None, **kwargs ) -> PlotResult: """Plot heatmap comparison. 
See StatisticsPlotAccessor.heatmap.""" - ds, _ = self._combine_data('heatmap', variables, **kwargs) + data_kw, plotly_kw = self._split_kwargs('heatmap', kwargs) + ds, _ = self._combine_data('heatmap', variables, **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) col, _, anim = self._resolve_facets(ds, facet_col, None, animation_frame) da = ds[next(iter(ds.data_vars))] - fig = da.fxplot.heatmap(colors=colors, facet_col=col, animation_frame=anim) + fig = da.fxplot.heatmap(colors=colors, facet_col=col, animation_frame=anim, **plotly_kw) return self._finalize(ds, fig, show) From 9160b24580e6c5f59a908eb8a825dfd0830a083b Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 23:11:54 +0100 Subject: [PATCH 27/62] Now all methods properly split kwargs and pass plotly_kwargs to the figure creation. The _DATA_KWARGS mapping defines which kwargs affect data processing - everything else passes through to plotly. --- CHANGELOG.md | 19 +- docs/notebooks/08e-clustering-internals.ipynb | 12 +- .../data/generate_example_systems.py | 9 + flixopt/__init__.py | 31 -- flixopt/clustering/base.py | 168 +++++++++-- flixopt/components.py | 6 +- flixopt/core.py | 2 +- flixopt/elements.py | 15 +- flixopt/features.py | 22 +- flixopt/flow_system.py | 274 ++++++++++++++---- flixopt/optimization.py | 9 +- flixopt/optimize_accessor.py | 17 +- flixopt/statistics_accessor.py | 2 +- flixopt/structure.py | 154 +++------- flixopt/transform_accessor.py | 82 ++++-- mkdocs.yml | 4 + pyproject.toml | 8 +- tests/deprecated/test_scenarios.py | 13 +- tests/test_cluster_reduce_expand.py | 15 +- tests/test_clustering/test_base.py | 1 - tests/test_clustering/test_integration.py | 132 ++++++--- tests/test_clustering_io.py | 241 +++++++++++++++ tests/test_io_conversion.py | 5 + tests/test_scenarios.py | 13 +- tests/test_sel_isel_single_selection.py | 193 ++++++++++++ 25 files changed, 1067 insertions(+), 380 deletions(-) create mode 100644 
tests/test_clustering_io.py create mode 100644 tests/test_sel_isel_single_selection.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 5731b0726..46ac8047b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,7 +53,7 @@ Until here --> ## [5.1.0] - Upcoming -**Summary**: Time-series clustering for faster optimization with configurable storage behavior across typical periods. +**Summary**: Time-series clustering for faster optimization with configurable storage behavior across typical periods. Improved weights API with always-normalized scenario weights. ### ✨ Added @@ -148,6 +148,23 @@ charge_state = fs_expanded.solution['SeasonalPit|charge_state'] Use `'cyclic'` for short-term storage like batteries or hot water tanks where only daily patterns matter. Use `'independent'` for quick estimates when storage behavior isn't critical. +### 💥 Breaking Changes + +- `FlowSystem.scenario_weights` are now always normalized to sum to 1 when set (including after `.sel()` subsetting) + +### ♻️ Changed + +- `FlowSystem.weights` returns `dict[str, xr.DataArray]` (unit weights instead of `1.0` float fallback) +- `FlowSystemDimensions` type now includes `'cluster'` + +### 🗑️ Deprecated + +- `normalize_weights` parameter in `create_model()`, `build_model()`, `optimize()` + +### 🐛 Fixed + +- `temporal_weight` and `sum_temporal()` now use consistent implementation + ### 👷 Development **New Test Suites for Clustering**: diff --git a/docs/notebooks/08e-clustering-internals.ipynb b/docs/notebooks/08e-clustering-internals.ipynb index 066ec749c..506a01ed9 100644 --- a/docs/notebooks/08e-clustering-internals.ipynb +++ b/docs/notebooks/08e-clustering-internals.ipynb @@ -287,17 +287,7 @@ ] } ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.11" - } - }, + "metadata": {}, "nbformat": 4, "nbformat_minor": 5 } diff --git a/docs/notebooks/data/generate_example_systems.py 
b/docs/notebooks/data/generate_example_systems.py index ec645e8b2..c53322ef2 100644 --- a/docs/notebooks/data/generate_example_systems.py +++ b/docs/notebooks/data/generate_example_systems.py @@ -11,11 +11,13 @@ Run this script to regenerate the example data files. """ +import sys from pathlib import Path import numpy as np import pandas as pd +# Handle imports in different contexts (direct run, package import, mkdocs-jupyter) try: from .generate_realistic_profiles import ( ElectricityLoadGenerator, @@ -25,6 +27,13 @@ load_weather, ) except ImportError: + # Add data directory to path for mkdocs-jupyter context + try: + _data_dir = Path(__file__).parent + except NameError: + _data_dir = Path('docs/notebooks/data') + if str(_data_dir) not in sys.path: + sys.path.insert(0, str(_data_dir)) from generate_realistic_profiles import ( ElectricityLoadGenerator, GasPriceGenerator, diff --git a/flixopt/__init__.py b/flixopt/__init__.py index 64e08d7ac..6b226ea28 100644 --- a/flixopt/__init__.py +++ b/flixopt/__init__.py @@ -3,7 +3,6 @@ """ import logging -import warnings from importlib.metadata import PackageNotFoundError, version try: @@ -35,7 +34,6 @@ from .interface import InvestParameters, Piece, Piecewise, PiecewiseConversion, PiecewiseEffects, StatusParameters from .optimization import Optimization, SegmentedOptimization from .plot_result import PlotResult -from .structure import TimeSeriesWeights __all__ = [ 'TimeSeriesData', @@ -63,7 +61,6 @@ 'PiecewiseConversion', 'PiecewiseEffects', 'PlotResult', - 'TimeSeriesWeights', 'clustering', 'plotting', 'results', @@ -75,31 +72,3 @@ logger = logging.getLogger('flixopt') logger.setLevel(logging.WARNING) logger.addHandler(logging.NullHandler()) - -# === Runtime warning suppression for third-party libraries === -# These warnings are from dependencies and cannot be fixed by end users. -# They are suppressed at runtime to provide a cleaner user experience. 
-# These filters match the test configuration in pyproject.toml for consistency. - -# tsam: Time series aggregation library -# - UserWarning: Informational message about minimal value constraints during clustering. -warnings.filterwarnings( - 'ignore', - category=UserWarning, - message='.*minimal value.*exceeds.*', - module='tsam.timeseriesaggregation', # More specific if possible -) -# TODO: Might be able to fix it in flixopt? - -# linopy: Linear optimization library -# - UserWarning: Coordinate mismatch warnings that don't affect functionality and are expected. -warnings.filterwarnings( - 'ignore', category=UserWarning, message='Coordinates across variables not equal', module='linopy' -) -# - FutureWarning: join parameter default will change in future versions -warnings.filterwarnings( - 'ignore', - category=FutureWarning, - message="In a future version of xarray the default value for join will change from join='outer' to join='exact'", - module='linopy', -) diff --git a/flixopt/clustering/base.py b/flixopt/clustering/base.py index 2c442e3d5..4b31832e4 100644 --- a/flixopt/clustering/base.py +++ b/flixopt/clustering/base.py @@ -21,11 +21,11 @@ from typing import TYPE_CHECKING, Any import numpy as np +import pandas as pd import xarray as xr if TYPE_CHECKING: from ..color_processing import ColorType - from ..flow_system import FlowSystem from ..plot_result import PlotResult from ..statistics_accessor import SelectType @@ -98,6 +98,31 @@ def __repr__(self) -> str: f')' ) + def _create_reference_structure(self) -> tuple[dict, dict[str, xr.DataArray]]: + """Create reference structure for serialization.""" + ref = {'__class__': self.__class__.__name__} + arrays = {} + + # Store DataArrays with references + arrays[str(self.cluster_order.name)] = self.cluster_order + ref['cluster_order'] = f':::{self.cluster_order.name}' + + arrays[str(self.cluster_occurrences.name)] = self.cluster_occurrences + ref['cluster_occurrences'] = f':::{self.cluster_occurrences.name}' + + # 
Store scalar values + if isinstance(self.n_clusters, xr.DataArray): + n_clusters_name = self.n_clusters.name or 'n_clusters' + self.n_clusters = self.n_clusters.rename(n_clusters_name) + arrays[n_clusters_name] = self.n_clusters + ref['n_clusters'] = f':::{n_clusters_name}' + else: + ref['n_clusters'] = int(self.n_clusters) + + ref['timesteps_per_cluster'] = self.timesteps_per_cluster + + return ref, arrays + @property def n_original_periods(self) -> int: """Number of original periods (before clustering).""" @@ -172,7 +197,7 @@ def get_cluster_weight_per_timestep(self) -> xr.DataArray: name='cluster_weight', ) - def plot(self, show: bool | None = None): + def plot(self, show: bool | None = None) -> PlotResult: """Plot cluster assignment visualization. Shows which cluster each original period belongs to, and the @@ -281,10 +306,9 @@ def __post_init__(self): self.timestep_mapping = self.timestep_mapping.rename('timestep_mapping') # Ensure representative_weights is a DataArray + # Can be (cluster, time) for 2D structure or (time,) for flat structure if not isinstance(self.representative_weights, xr.DataArray): - self.representative_weights = xr.DataArray( - self.representative_weights, dims=['time'], name='representative_weights' - ) + self.representative_weights = xr.DataArray(self.representative_weights, name='representative_weights') elif self.representative_weights.name is None: self.representative_weights = self.representative_weights.rename('representative_weights') @@ -304,6 +328,37 @@ def __repr__(self) -> str: f')' ) + def _create_reference_structure(self) -> tuple[dict, dict[str, xr.DataArray]]: + """Create reference structure for serialization.""" + ref = {'__class__': self.__class__.__name__} + arrays = {} + + # Store DataArrays with references + arrays[str(self.timestep_mapping.name)] = self.timestep_mapping + ref['timestep_mapping'] = f':::{self.timestep_mapping.name}' + + arrays[str(self.representative_weights.name)] = self.representative_weights + 
ref['representative_weights'] = f':::{self.representative_weights.name}' + + # Store scalar values + if isinstance(self.n_representatives, xr.DataArray): + n_rep_name = self.n_representatives.name or 'n_representatives' + self.n_representatives = self.n_representatives.rename(n_rep_name) + arrays[n_rep_name] = self.n_representatives + ref['n_representatives'] = f':::{n_rep_name}' + else: + ref['n_representatives'] = int(self.n_representatives) + + # Store nested ClusterStructure if present + if self.cluster_structure is not None: + cs_ref, cs_arrays = self.cluster_structure._create_reference_structure() + ref['cluster_structure'] = cs_ref + arrays.update(cs_arrays) + + # Skip aggregated_data and original_data - not needed for serialization + + return ref, arrays + @property def n_original_timesteps(self) -> int: """Number of original timesteps (before aggregation).""" @@ -460,24 +515,50 @@ def validate(self) -> None: if max_idx >= n_rep: raise ValueError(f'timestep_mapping contains index {max_idx} but n_representatives is {n_rep}') - # Check weights length matches n_representatives - if len(self.representative_weights) != n_rep: - raise ValueError( - f'representative_weights has {len(self.representative_weights)} elements ' - f'but n_representatives is {n_rep}' - ) + # Check weights dimensions + # representative_weights should have (cluster,) dimension with n_clusters elements + # (plus optional period/scenario dimensions) + if self.cluster_structure is not None: + n_clusters = self.cluster_structure.n_clusters + if 'cluster' in self.representative_weights.dims: + weights_n_clusters = self.representative_weights.sizes['cluster'] + if weights_n_clusters != n_clusters: + raise ValueError( + f'representative_weights has {weights_n_clusters} clusters ' + f'but cluster_structure has {n_clusters}' + ) - # Check weights sum roughly equals original timesteps - weight_sum = float(self.representative_weights.sum().values) - n_original = self.n_original_timesteps - if 
abs(weight_sum - n_original) > 1e-6: - # Warning only - some aggregation methods may not preserve this exactly - import warnings + # Check weights sum roughly equals number of original periods + # (each weight is how many original periods that cluster represents) + # Sum should be checked per period/scenario slice, not across all dimensions + if self.cluster_structure is not None: + n_original_periods = self.cluster_structure.n_original_periods + # Sum over cluster dimension only (keep period/scenario if present) + weight_sum_per_slice = self.representative_weights.sum(dim='cluster') + # Check each slice + if weight_sum_per_slice.size == 1: + # Simple case: no period/scenario + weight_sum = float(weight_sum_per_slice.values) + if abs(weight_sum - n_original_periods) > 1e-6: + import warnings + + warnings.warn( + f'representative_weights sum ({weight_sum}) does not match ' + f'n_original_periods ({n_original_periods})', + stacklevel=2, + ) + else: + # Multi-dimensional: check each slice + for val in weight_sum_per_slice.values.flat: + if abs(float(val) - n_original_periods) > 1e-6: + import warnings - warnings.warn( - f'representative_weights sum ({weight_sum}) does not match n_original_timesteps ({n_original})', - stacklevel=2, - ) + warnings.warn( + f'representative_weights sum per slice ({float(val)}) does not match ' + f'n_original_periods ({n_original_periods})', + stacklevel=2, + ) + break # Only warn once class ClusteringPlotAccessor: @@ -911,7 +992,6 @@ class Clustering: Attributes: result: The ClusterResult from the aggregation backend. - original_flow_system: Reference to the FlowSystem before aggregation. backend_name: Name of the aggregation backend used (e.g., 'tsam', 'manual'). 
Example: @@ -923,9 +1003,23 @@ class Clustering: """ result: ClusterResult - original_flow_system: FlowSystem # FlowSystem - avoid circular import backend_name: str = 'unknown' + def _create_reference_structure(self) -> tuple[dict, dict[str, xr.DataArray]]: + """Create reference structure for serialization.""" + ref = {'__class__': self.__class__.__name__} + arrays = {} + + # Store nested ClusterResult + result_ref, result_arrays = self.result._create_reference_structure() + ref['result'] = result_ref + arrays.update(result_arrays) + + # Store scalar values + ref['backend_name'] = self.backend_name + + return ref, arrays + def __repr__(self) -> str: cs = self.result.cluster_structure if cs is not None: @@ -1024,6 +1118,22 @@ def cluster_start_positions(self) -> np.ndarray: n_timesteps = self.n_clusters * self.timesteps_per_period return np.arange(0, n_timesteps, self.timesteps_per_period) + @property + def original_timesteps(self) -> pd.DatetimeIndex: + """Original timesteps before clustering. + + Derived from the 'original_time' coordinate of timestep_mapping. + + Raises: + KeyError: If 'original_time' coordinate is missing from timestep_mapping. + """ + if 'original_time' not in self.result.timestep_mapping.coords: + raise KeyError( + "timestep_mapping is missing 'original_time' coordinate. " + 'This may indicate corrupted or incompatible clustering results.' + ) + return pd.DatetimeIndex(self.result.timestep_mapping.coords['original_time'].values) + def create_cluster_structure_from_mapping( timestep_mapping: xr.DataArray, @@ -1073,3 +1183,15 @@ def create_cluster_structure_from_mapping( n_clusters=n_clusters, timesteps_per_cluster=timesteps_per_cluster, ) + + +def _register_clustering_classes(): + """Register clustering classes for IO. + + Called from flow_system.py after all imports are complete to avoid circular imports. 
+ """ + from ..structure import CLASS_REGISTRY + + CLASS_REGISTRY['ClusterStructure'] = ClusterStructure + CLASS_REGISTRY['ClusterResult'] = ClusterResult + CLASS_REGISTRY['Clustering'] = Clustering diff --git a/flixopt/components.py b/flixopt/components.py index 0f2a6077e..390fc6f02 100644 --- a/flixopt/components.py +++ b/flixopt/components.py @@ -1257,10 +1257,12 @@ def _absolute_charge_state_bounds(self) -> tuple[xr.DataArray, xr.DataArray]: return -np.inf, np.inf elif isinstance(self.element.capacity_in_flow_hours, InvestParameters): cap_max = self.element.capacity_in_flow_hours.maximum_or_fixed_size * relative_upper_bound - return -cap_max, cap_max + # Adding 0.0 converts -0.0 to 0.0 (linopy LP writer bug workaround) + return -cap_max + 0.0, cap_max + 0.0 else: cap = self.element.capacity_in_flow_hours * relative_upper_bound - return -cap, cap + # Adding 0.0 converts -0.0 to 0.0 (linopy LP writer bug workaround) + return -cap + 0.0, cap + 0.0 def _do_modeling(self): """Create storage model with inter-cluster linking constraints. 
diff --git a/flixopt/core.py b/flixopt/core.py index fdcab029b..3d456fff1 100644 --- a/flixopt/core.py +++ b/flixopt/core.py @@ -15,7 +15,7 @@ logger = logging.getLogger('flixopt') -FlowSystemDimensions = Literal['time', 'period', 'scenario'] +FlowSystemDimensions = Literal['time', 'cluster', 'period', 'scenario'] """Possible dimensions of a FlowSystem.""" diff --git a/flixopt/elements.py b/flixopt/elements.py index ba2b72f80..0cee53738 100644 --- a/flixopt/elements.py +++ b/flixopt/elements.py @@ -677,14 +677,10 @@ def _do_modeling(self): self._constraint_flow_rate() # Total flow hours tracking (per period) - # Sum over all temporal dimensions (time, and cluster if present) - weighted_flow = self.flow_rate * self._model.aggregation_weight - # Get temporal_dims from aggregation_weight (not weighted_flow which has linopy's _term dim) - temporal_dims = [d for d in self._model.aggregation_weight.dims if d not in ('period', 'scenario')] ModelingPrimitives.expression_tracking_variable( model=self, name=f'{self.label_full}|total_flow_hours', - tracked_expression=weighted_flow.sum(temporal_dims), + tracked_expression=self._model.sum_temporal(self.flow_rate), bounds=( self.element.flow_hours_min if self.element.flow_hours_min is not None else 0, self.element.flow_hours_max if self.element.flow_hours_max is not None else None, @@ -841,9 +837,8 @@ def _create_bounds_for_load_factor(self): # Get the size (either from element or investment) size = self.investment.size if self.with_investment else self.element.size - # Sum over all temporal dimensions (time, and cluster if present) - temporal_dims = [d for d in self._model.aggregation_weight.dims if d not in ('period', 'scenario')] - total_hours = self._model.aggregation_weight.sum(temporal_dims) + # Total hours in the period (sum of temporal weights) + total_hours = self._model.temporal_weight.sum(self._model.temporal_dims) # Maximum load factor constraint if self.element.load_factor_max is not None: @@ -959,9 +954,7 @@ def 
_do_modeling(self): # Add virtual supply/demand to balance and penalty if needed if self.element.allows_imbalance: - imbalance_penalty = np.multiply( - self._model.aggregation_weight, self.element.imbalance_penalty_per_flow_hour - ) + imbalance_penalty = self.element.imbalance_penalty_per_flow_hour * self._model.timestep_duration self.virtual_supply = self.add_variables( lower=0, coords=self._model.get_coords(), short_name='virtual_supply' diff --git a/flixopt/features.py b/flixopt/features.py index e0a018a7f..289640ddd 100644 --- a/flixopt/features.py +++ b/flixopt/features.py @@ -196,22 +196,14 @@ def _do_modeling(self): inactive = self.add_variables(binary=True, short_name='inactive', coords=self._model.get_coords()) self.add_constraints(self.status + inactive == 1, short_name='complementary') - # 3. Total duration tracking using existing pattern - # Sum over all temporal dimensions (time, and cluster if present) - weighted_status = self.status * self._model.aggregation_weight - # Get temporal_dims from aggregation_weight (not weighted_status which has linopy's _term dim) - temporal_dims = [d for d in self._model.aggregation_weight.dims if d not in ('period', 'scenario')] - agg_weight_sum = self._model.aggregation_weight.sum(temporal_dims) + # 3. 
Total duration tracking + total_hours = self._model.temporal_weight.sum(self._model.temporal_dims) ModelingPrimitives.expression_tracking_variable( self, - tracked_expression=weighted_status.sum(temporal_dims), + tracked_expression=self._model.sum_temporal(self.status), bounds=( self.parameters.active_hours_min if self.parameters.active_hours_min is not None else 0, - self.parameters.active_hours_max - if self.parameters.active_hours_max is not None - else agg_weight_sum.max().item() - if hasattr(agg_weight_sum, 'max') - else agg_weight_sum, + self.parameters.active_hours_max if self.parameters.active_hours_max is not None else total_hours, ), short_name='active_hours', coords=['period', 'scenario'], @@ -631,10 +623,8 @@ def _do_modeling(self): # Add it to the total (cluster_weight handles cluster representation, defaults to 1.0) # Sum over all temporal dimensions (time, and cluster if present) - weighted_per_timestep = self.total_per_timestep * self._model.cluster_weight - # Get temporal_dims from total_per_timestep (linopy Variable) - its coords are the actual dims - temporal_dims = [d for d in self.total_per_timestep.dims if d not in ('period', 'scenario')] - self._eq_total.lhs -= weighted_per_timestep.sum(dim=temporal_dims) + weighted_per_timestep = self.total_per_timestep * self._model.weights.get('cluster', 1.0) + self._eq_total.lhs -= weighted_per_timestep.sum(dim=self._model.temporal_dims) def add_share( self, diff --git a/flixopt/flow_system.py b/flixopt/flow_system.py index c10a1defb..7e12029ed 100644 --- a/flixopt/flow_system.py +++ b/flixopt/flow_system.py @@ -40,11 +40,15 @@ from .clustering import Clustering from .solvers import _Solver - from .structure import TimeSeriesWeights from .types import Effect_TPS, Numeric_S, Numeric_TPS, NumericOrBool from .carrier import Carrier, CarrierContainer +# Register clustering classes for IO (deferred to avoid circular imports) +from .clustering.base import _register_clustering_classes + 
+_register_clustering_classes() + logger = logging.getLogger('flixopt') @@ -69,10 +73,10 @@ class FlowSystem(Interface, CompositeContainerMixin[Element]): scenario_weights: The weights of each scenario. If None, all scenarios have the same weight (normalized to 1). Period weights are always computed internally from the period index (like timestep_duration for time). The final `weights` array (accessible via `flow_system.model.objective_weights`) is computed as period_weights × normalized_scenario_weights, with normalization applied to the scenario weights by default. - cluster_weight: Weight for each timestep representing cluster representation count. - If None (default), all timesteps have weight 1.0. Used by cluster() to specify - how many original timesteps each cluster represents. Combined with timestep_duration - via aggregation_weight for proper time aggregation in clustered models. + cluster_weight: Weight for each cluster. + If None (default), all clusters have weight 1.0. Used by cluster() to specify + how many original timesteps each cluster represents. Multiply with timestep_duration + for proper time aggregation in clustered models. scenario_independent_sizes: Controls whether investment sizes are equalized across scenarios. 
- True: All sizes are shared/equalized across scenarios - False: All sizes are optimized separately per scenario @@ -201,10 +205,13 @@ def __init__( # Cluster weight for cluster() optimization (default 1.0) # Represents how many original timesteps each cluster represents # May have period/scenario dimensions if cluster() was used with those - self.cluster_weight = self.fit_to_model_coords( - 'cluster_weight', - np.ones(len(self.timesteps)) if cluster_weight is None else cluster_weight, - dims=['time', 'period', 'scenario'], # Gracefully ignores dims not present + self.cluster_weight: xr.DataArray | None = ( + self.fit_to_model_coords( + 'cluster_weight', + cluster_weight, + ) + if cluster_weight is not None + else None ) self.scenario_weights = scenario_weights # Use setter @@ -502,6 +509,9 @@ def _update_period_metadata( period index. This ensures period metadata stays synchronized with the actual periods after operations like selection. + When the period dimension is dropped (single value selected), this method + removes the scalar coordinate, period_weights DataArray, and cleans up attributes. + This is analogous to _update_time_metadata() for time-related metadata. Args: @@ -513,7 +523,16 @@ def _update_period_metadata( The same dataset with updated period-related attributes and data variables """ new_period_index = dataset.indexes.get('period') - if new_period_index is not None and len(new_period_index) >= 1: + + if new_period_index is None: + # Period dimension was dropped (single value selected) + if 'period' in dataset.coords: + dataset = dataset.drop_vars('period') + dataset = dataset.drop_vars(['period_weights'], errors='ignore') + dataset.attrs.pop('weight_of_last_period', None) + return dataset + + if len(new_period_index) >= 1: # Reuse stored weight_of_last_period when not explicitly overridden. # This is essential for single-period subsets where it cannot be inferred from intervals. 
if weight_of_last_period is None: @@ -542,6 +561,9 @@ def _update_scenario_metadata(cls, dataset: xr.Dataset) -> xr.Dataset: Recomputes or removes scenario weights. This ensures scenario metadata stays synchronized with the actual scenarios after operations like selection. + When the scenario dimension is dropped (single value selected), this method + removes the scalar coordinate, scenario_weights DataArray, and cleans up attributes. + This is analogous to _update_period_metadata() for time-related metadata. Args: @@ -551,7 +573,16 @@ def _update_scenario_metadata(cls, dataset: xr.Dataset) -> xr.Dataset: The same dataset with updated scenario-related attributes and data variables """ new_scenario_index = dataset.indexes.get('scenario') - if new_scenario_index is None or len(new_scenario_index) <= 1: + + if new_scenario_index is None: + # Scenario dimension was dropped (single value selected) + if 'scenario' in dataset.coords: + dataset = dataset.drop_vars('scenario') + dataset = dataset.drop_vars(['scenario_weights'], errors='ignore') + dataset.attrs.pop('scenario_weights', None) + return dataset + + if len(new_scenario_index) <= 1: dataset.attrs.pop('scenario_weights', None) return dataset @@ -645,13 +676,21 @@ def to_dataset(self, include_solution: bool = True) -> xr.Dataset: carriers_structure[name] = carrier_ref ds.attrs['carriers'] = json.dumps(carriers_structure) - # Include cluster info for clustered FlowSystems + # Include cluster info for clustered FlowSystems (old structure) if self.clusters is not None: ds.attrs['is_clustered'] = True ds.attrs['n_clusters'] = len(self.clusters) ds.attrs['timesteps_per_cluster'] = len(self.timesteps) ds.attrs['timestep_duration'] = float(self.timestep_duration.mean()) + # Serialize Clustering object if present (new structure) + if self.clustering is not None: + clustering_ref, clustering_arrays = self.clustering._create_reference_structure() + # Add clustering arrays with prefix + for name, arr in 
clustering_arrays.items(): + ds[f'clustering|{name}'] = arr + ds.attrs['clustering'] = json.dumps(clustering_ref) + # Add version info ds.attrs['flixopt_version'] = __version__ @@ -708,6 +747,11 @@ def from_dataset(cls, ds: xr.Dataset) -> FlowSystem: else None ) + # Resolve scenario_weights only if scenario dimension exists + scenario_weights = None + if ds.indexes.get('scenario') is not None and 'scenario_weights' in reference_structure: + scenario_weights = cls._resolve_dataarray_reference(reference_structure['scenario_weights'], arrays_dict) + # Create FlowSystem instance with constructor parameters flow_system = cls( timesteps=ds.indexes['time'], @@ -717,9 +761,7 @@ def from_dataset(cls, ds: xr.Dataset) -> FlowSystem: hours_of_last_timestep=reference_structure.get('hours_of_last_timestep'), hours_of_previous_timesteps=reference_structure.get('hours_of_previous_timesteps'), weight_of_last_period=reference_structure.get('weight_of_last_period'), - scenario_weights=cls._resolve_dataarray_reference(reference_structure['scenario_weights'], arrays_dict) - if 'scenario_weights' in reference_structure - else None, + scenario_weights=scenario_weights, cluster_weight=cluster_weight_for_constructor, scenario_independent_sizes=reference_structure.get('scenario_independent_sizes', True), scenario_independent_flow_rates=reference_structure.get('scenario_independent_flow_rates', False), @@ -765,6 +807,19 @@ def from_dataset(cls, ds: xr.Dataset) -> FlowSystem: carrier = cls._resolve_reference_structure(carrier_data, {}) flow_system._carriers.add(carrier) + # Restore Clustering object if present + if 'clustering' in reference_structure: + clustering_structure = json.loads(reference_structure['clustering']) + # Collect clustering arrays (prefixed with 'clustering|') + clustering_arrays = {} + for name, arr in ds.data_vars.items(): + if name.startswith('clustering|'): + # Remove 'clustering|' prefix (11 chars) + arr_name = name[11:] + clustering_arrays[arr_name] = arr + 
clustering = cls._resolve_reference_structure(clustering_structure, clustering_arrays) + flow_system.clustering = clustering + # Reconnect network to populate bus inputs/outputs (not stored in NetCDF). flow_system.connect_and_transform() @@ -1061,6 +1116,7 @@ def connect_and_transform(self): self._connect_network() self._register_missing_carriers() self._assign_element_colors() + for element in chain(self.components.values(), self.effects.values(), self.buses.values()): element.transform_data() @@ -1274,22 +1330,29 @@ def flow_carriers(self) -> dict[str, str]: return self._flow_carriers - def create_model(self, normalize_weights: bool = True) -> FlowSystemModel: + def create_model(self, normalize_weights: bool | None = None) -> FlowSystemModel: """ Create a linopy model from the FlowSystem. Args: - normalize_weights: Whether to automatically normalize the weights (periods and scenarios) to sum up to 1 when solving. + normalize_weights: Deprecated. Scenario weights are now always normalized in FlowSystem. """ + if normalize_weights is not None: + warnings.warn( + f'\n\nnormalize_weights parameter is deprecated and will be removed in {DEPRECATION_REMOVAL_VERSION}. ' + 'Scenario weights are now always normalized when set on FlowSystem.\n', + DeprecationWarning, + stacklevel=2, + ) if not self.connected_and_transformed: raise RuntimeError( 'FlowSystem is not connected_and_transformed. Call FlowSystem.connect_and_transform() first.' ) # System integrity was already validated in connect_and_transform() - self.model = FlowSystemModel(self, normalize_weights) + self.model = FlowSystemModel(self) return self.model - def build_model(self, normalize_weights: bool = True) -> FlowSystem: + def build_model(self, normalize_weights: bool | None = None) -> FlowSystem: """ Build the optimization model for this FlowSystem. @@ -1303,7 +1366,7 @@ def build_model(self, normalize_weights: bool = True) -> FlowSystem: before solving. 
Args: - normalize_weights: Whether to normalize scenario/period weights to sum to 1. + normalize_weights: Deprecated. Scenario weights are now always normalized in FlowSystem. Returns: Self, for method chaining. @@ -1313,8 +1376,15 @@ def build_model(self, normalize_weights: bool = True) -> FlowSystem: >>> print(flow_system.model.variables) # Inspect variables before solving >>> flow_system.solve(solver) """ + if normalize_weights is not None: + warnings.warn( + f'\n\nnormalize_weights parameter is deprecated and will be removed in {DEPRECATION_REMOVAL_VERSION}. ' + 'Scenario weights are now always normalized when set on FlowSystem.\n', + DeprecationWarning, + stacklevel=2, + ) self.connect_and_transform() - self.create_model(normalize_weights) + self.create_model() self.model.do_modeling() @@ -1862,27 +1932,85 @@ def storages(self) -> ElementContainer[Storage]: self._storages_cache = ElementContainer(storages, element_type_name='storages', truncate_repr=10) return self._storages_cache + @property + def dims(self) -> list[str]: + """Active dimension names. + + Returns: + List of active dimension names in order. + + Example: + >>> fs.dims + ['time'] # simple case + >>> fs_clustered.dims + ['cluster', 'time', 'period', 'scenario'] # full case + """ + result = [] + if self.clusters is not None: + result.append('cluster') + result.append('time') + if self.periods is not None: + result.append('period') + if self.scenarios is not None: + result.append('scenario') + return result + + @property + def indexes(self) -> dict[str, pd.Index]: + """Indexes for active dimensions. + + Returns: + Dict mapping dimension names to pandas Index objects. 
+ + Example: + >>> fs.indexes['time'] + DatetimeIndex(['2024-01-01', ...], dtype='datetime64[ns]', name='time') + """ + result: dict[str, pd.Index] = {} + if self.clusters is not None: + result['cluster'] = self.clusters + result['time'] = self.timesteps + if self.periods is not None: + result['period'] = self.periods + if self.scenarios is not None: + result['scenario'] = self.scenarios + return result + + @property + def temporal_dims(self) -> list[str]: + """Temporal dimensions for summing over time. + + Returns ['time', 'cluster'] for clustered systems, ['time'] otherwise. + """ + if self.clusters is not None: + return ['time', 'cluster'] + return ['time'] + + @property + def temporal_weight(self) -> xr.DataArray: + """Combined temporal weight (timestep_duration × cluster_weight). + + Use for converting rates to totals before summing. + Note: cluster_weight is used even without a clusters dimension. + """ + # Use cluster_weight directly if set, otherwise check weights dict, fallback to 1.0 + cluster_weight = self.weights.get('cluster', self.cluster_weight if self.cluster_weight is not None else 1.0) + return self.weights['time'] * cluster_weight + @property def coords(self) -> dict[FlowSystemDimensions, pd.Index]: """Active coordinates for variable creation. + .. deprecated:: + Use :attr:`indexes` instead. + Returns a dict of dimension names to coordinate arrays. When clustered, includes 'cluster' dimension before 'time'. Returns: Dict mapping dimension names to coordinate arrays. 
""" - active_coords: dict[str, pd.Index] = {} - - if self.clusters is not None: - active_coords['cluster'] = self.clusters - active_coords['time'] = self.timesteps - - if self.periods is not None: - active_coords['period'] = self.periods - if self.scenarios is not None: - active_coords['scenario'] = self.scenarios - return active_coords + return self.indexes @property def _use_true_cluster_dims(self) -> bool: @@ -1928,14 +2056,15 @@ def scenario_weights(self) -> xr.DataArray | None: @scenario_weights.setter def scenario_weights(self, value: Numeric_S | None) -> None: """ - Set scenario weights. + Set scenario weights (always normalized to sum to 1). Args: - value: Scenario weights to set (will be converted to DataArray with 'scenario' dimension) - or None to clear weights. + value: Scenario weights to set (will be converted to DataArray with 'scenario' dimension + and normalized to sum to 1), or None to clear weights. Raises: ValueError: If value is not None and no scenarios are defined in the FlowSystem. + ValueError: If weights sum to zero (cannot normalize). """ if value is None: self._scenario_weights = None @@ -1947,48 +2076,65 @@ def scenario_weights(self, value: Numeric_S | None) -> None: 'Either define scenarios in FlowSystem(scenarios=...) or set scenario_weights to None.' ) - self._scenario_weights = self.fit_to_model_coords('scenario_weights', value, dims=['scenario']) + weights = self.fit_to_model_coords('scenario_weights', value, dims=['scenario']) - @property - def weights(self) -> TimeSeriesWeights: - """Unified weighting system for time series aggregation. + # Normalize to sum to 1 + norm = weights.sum('scenario') + if np.isclose(norm, 0.0).any(): + raise ValueError('scenario_weights sum to 0; cannot normalize.') + self._scenario_weights = weights / norm - Returns a TimeSeriesWeights object providing a clean, unified interface - for all weight types used in flixopt. 
This is the recommended way to - access weights for new code (PyPSA-inspired design). + def _unit_weight(self, dim: str) -> xr.DataArray: + """Create a unit weight DataArray (all 1.0) for a dimension.""" + index = self.indexes[dim] + return xr.DataArray( + np.ones(len(index), dtype=float), + coords={dim: index}, + dims=[dim], + name=f'{dim}_weight', + ) - The temporal weight combines timestep_duration and cluster_weight, - which is the proper weight for summing over time. + @property + def weights(self) -> dict[str, xr.DataArray]: + """Weights for active dimensions (unit weights if not explicitly set). Returns: - TimeSeriesWeights with temporal, period, and scenario weights. + Dict mapping dimension names to weight DataArrays. + Keys match :attr:`dims` and :attr:`indexes`. Example: - >>> weights = flow_system.weights - >>> weighted_total = (flow_rate * weights.temporal).sum('time') - >>> # Or use the convenience method: - >>> weighted_total = weights.sum_over_time(flow_rate) + >>> fs.weights['time'] # timestep durations + >>> fs.weights['cluster'] # cluster weights (unit if not set) """ - from .structure import TimeSeriesWeights + result: dict[str, xr.DataArray] = {'time': self.timestep_duration} + if self.clusters is not None: + result['cluster'] = self.cluster_weight if self.cluster_weight is not None else self._unit_weight('cluster') + if self.periods is not None: + result['period'] = self.period_weights if self.period_weights is not None else self._unit_weight('period') + if self.scenarios is not None: + result['scenario'] = ( + self.scenario_weights if self.scenario_weights is not None else self._unit_weight('scenario') + ) + return result - return TimeSeriesWeights( - temporal=self.timestep_duration * self.cluster_weight, - period=self.period_weights, - scenario=self._scenario_weights, - ) + def sum_temporal(self, data: xr.DataArray) -> xr.DataArray: + """Sum data over temporal dimensions with full temporal weighting. 
- @property - def aggregation_weight(self) -> xr.DataArray: - """Combined weight for time aggregation. + Applies both timestep_duration and cluster_weight, then sums over temporal dimensions. + Use this to convert rates to totals (e.g., flow_rate → total_energy). + + Args: + data: Data with time dimension (and optionally cluster). + Typically a rate (e.g., flow_rate in MW, status as 0/1). - Combines timestep_duration (physical duration) and cluster_weight (cluster representation). - Use this for proper time aggregation in clustered models. + Returns: + Data summed over temporal dims with full temporal weighting applied. - Note: - This is equivalent to `weights.temporal`. The unified TimeSeriesWeights - interface (via `flow_system.weights`) is recommended for new code. + Example: + >>> total_energy = fs.sum_temporal(flow_rate) # MW → MWh total + >>> active_hours = fs.sum_temporal(status) # count → hours """ - return self.timestep_duration * self.cluster_weight + return (data * self.temporal_weight).sum(self.temporal_dims) @property def is_clustered(self) -> bool: diff --git a/flixopt/optimization.py b/flixopt/optimization.py index 6a1a87ce1..0b567387f 100644 --- a/flixopt/optimization.py +++ b/flixopt/optimization.py @@ -82,7 +82,7 @@ def _initialize_optimization_common( name: str, flow_system: FlowSystem, folder: pathlib.Path | None = None, - normalize_weights: bool = True, + normalize_weights: bool | None = None, ) -> None: """ Shared initialization logic for all optimization types. @@ -95,7 +95,7 @@ def _initialize_optimization_common( name: Name of the optimization flow_system: FlowSystem to optimize folder: Directory for saving results - normalize_weights: Whether to normalize scenario weights + normalize_weights: Deprecated. Scenario weights are now always normalized in FlowSystem. 
""" obj.name = name @@ -106,7 +106,8 @@ def _initialize_optimization_common( ) flow_system = flow_system.copy() - obj.normalize_weights = normalize_weights + # normalize_weights is deprecated but kept for backwards compatibility + obj.normalize_weights = True # Always True now flow_system._used_in_optimization = True @@ -186,7 +187,7 @@ def do_modeling(self) -> Optimization: t_start = timeit.default_timer() self.flow_system.connect_and_transform() - self.model = self.flow_system.create_model(self.normalize_weights) + self.model = self.flow_system.create_model() self.model.do_modeling() self.durations['modeling'] = round(timeit.default_timer() - t_start, 2) diff --git a/flixopt/optimize_accessor.py b/flixopt/optimize_accessor.py index f88cdf982..7aee930a4 100644 --- a/flixopt/optimize_accessor.py +++ b/flixopt/optimize_accessor.py @@ -53,7 +53,7 @@ def __init__(self, flow_system: FlowSystem) -> None: """ self._fs = flow_system - def __call__(self, solver: _Solver, normalize_weights: bool = True) -> FlowSystem: + def __call__(self, solver: _Solver, normalize_weights: bool | None = None) -> FlowSystem: """ Build and solve the optimization model in one step. @@ -64,7 +64,7 @@ def __call__(self, solver: _Solver, normalize_weights: bool = True) -> FlowSyste Args: solver: The solver to use (e.g., HighsSolver, GurobiSolver). - normalize_weights: Whether to normalize scenario/period weights to sum to 1. + normalize_weights: Deprecated. Scenario weights are now always normalized in FlowSystem. Returns: The FlowSystem, for method chaining. @@ -85,7 +85,18 @@ def __call__(self, solver: _Solver, normalize_weights: bool = True) -> FlowSyste >>> solution = flow_system.optimize(solver).solution """ - self._fs.build_model(normalize_weights) + if normalize_weights is not None: + import warnings + + from .config import DEPRECATION_REMOVAL_VERSION + + warnings.warn( + f'\n\nnormalize_weights parameter is deprecated and will be removed in {DEPRECATION_REMOVAL_VERSION}. 
' + 'Scenario weights are now always normalized when set on FlowSystem.\n', + DeprecationWarning, + stacklevel=2, + ) + self._fs.build_model() self._fs.solve(solver) return self._fs diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index 1cbcbcd7d..382ed1bf0 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -736,7 +736,7 @@ def get_contributor_type(contributor: str) -> str: # For total mode, sum temporal over time (apply cluster_weight for proper weighting) # Sum over all temporal dimensions (time, and cluster if present) if mode == 'total' and current_mode == 'temporal' and 'time' in da.dims: - weighted = da * self._fs.cluster_weight + weighted = da * self._fs.weights.get('cluster', 1.0) temporal_dims = [d for d in weighted.dims if d not in ('period', 'scenario')] da = weighted.sum(temporal_dims) if share_total is None: diff --git a/flixopt/structure.py b/flixopt/structure.py index 7996565e8..4b7734199 100644 --- a/flixopt/structure.py +++ b/flixopt/structure.py @@ -43,92 +43,6 @@ CLASS_REGISTRY = {} -@dataclass -class TimeSeriesWeights: - """Unified weighting system for time series aggregation (PyPSA-inspired). - - This class provides a clean, unified interface for time series weights, - combining the various weight types used in flixopt into a single object. - - Attributes: - temporal: Combined weight for temporal operations (timestep_duration × cluster_weight). - Applied to all time-summing operations. dims: [time] or [time, period, scenario] - period: Weight for each period in multi-period optimization. - dims: [period] or None - scenario: Weight for each scenario in stochastic optimization. - dims: [scenario] or None - objective: Optional override weight for objective function calculations. - If None, uses temporal weight. dims: [time] or [time, period, scenario] - storage: Optional override weight for storage balance equations. - If None, uses temporal weight. 
dims: [time] or [time, period, scenario] - - Example: - >>> # Access via FlowSystem - >>> weights = flow_system.weights - >>> weighted_sum = (flow_rate * weights.temporal).sum('time') - >>> - >>> # With period/scenario weighting - >>> total = weighted_sum * weights.period * weights.scenario - - Note: - For backwards compatibility, the existing properties (cluster_weight, - timestep_duration, aggregation_weight) are still available on FlowSystem - and FlowSystemModel. - """ - - temporal: xr.DataArray - period: xr.DataArray | None = None - scenario: xr.DataArray | None = None - objective: xr.DataArray | None = None - storage: xr.DataArray | None = None - - def __post_init__(self): - """Validate weights.""" - if not isinstance(self.temporal, xr.DataArray): - raise TypeError('temporal must be an xarray DataArray') - if 'time' not in self.temporal.dims: - raise ValueError("temporal must have 'time' dimension") - - @property - def effective_objective(self) -> xr.DataArray: - """Get effective objective weight (override or temporal).""" - return self.objective if self.objective is not None else self.temporal - - @property - def effective_storage(self) -> xr.DataArray: - """Get effective storage weight (override or temporal).""" - return self.storage if self.storage is not None else self.temporal - - def sum_over_time(self, data: xr.DataArray) -> xr.DataArray: - """Sum data over time dimension with proper weighting. - - Args: - data: DataArray with 'time' dimension. - - Returns: - Data summed over time with temporal weighting applied. - """ - if 'time' not in data.dims: - return data - return (data * self.temporal).sum('time') - - def apply_period_scenario_weights(self, data: xr.DataArray) -> xr.DataArray: - """Apply period and scenario weights to data. - - Args: - data: DataArray, optionally with 'period' and/or 'scenario' dims. - - Returns: - Data with period and scenario weights applied. 
- """ - result = data - if self.period is not None and 'period' in data.dims: - result = result * self.period - if self.scenario is not None and 'scenario' in data.dims: - result = result * self.scenario - return result - - def register_class_for_io(cls): """Register a class for serialization/deserialization.""" name = cls.__name__ @@ -176,13 +90,11 @@ class FlowSystemModel(linopy.Model, SubmodelsMixin): Args: flow_system: The flow_system that is used to create the model. - normalize_weights: Whether to automatically normalize the weights to sum up to 1 when solving. """ - def __init__(self, flow_system: FlowSystem, normalize_weights: bool): + def __init__(self, flow_system: FlowSystem): super().__init__(force_dim_names=True) self.flow_system = flow_system - self.normalize_weights = normalize_weights self.effects: EffectCollectionModel | None = None self.submodels: Submodels = Submodels({}) @@ -314,53 +226,63 @@ def hours_of_previous_timesteps(self): return self.flow_system.hours_of_previous_timesteps @property - def cluster_weight(self) -> xr.DataArray: - """Cluster weight for cluster() optimization. + def dims(self) -> list[str]: + """Active dimension names.""" + return self.flow_system.dims - Represents how many original timesteps each cluster represents. - Default is 1.0 for all timesteps. + @property + def indexes(self) -> dict[str, pd.Index]: + """Indexes for active dimensions.""" + return self.flow_system.indexes + + @property + def weights(self) -> dict[str, xr.DataArray]: + """Weights for active dimensions (unit weights if not set). + + Scenario weights are always normalized (handled by FlowSystem). """ - return self.flow_system.cluster_weight + return self.flow_system.weights @property - def aggregation_weight(self) -> xr.DataArray: - """Combined weight for time aggregation. + def temporal_dims(self) -> list[str]: + """Temporal dimensions for summing over time. - Combines timestep_duration (physical duration) and cluster_weight (cluster representation). 
- Use this for proper time aggregation in clustered models. + Returns ['time', 'cluster'] for clustered systems, ['time'] otherwise. """ - return self.timestep_duration * self.cluster_weight + return self.flow_system.temporal_dims + + @property + def temporal_weight(self) -> xr.DataArray: + """Combined temporal weight (timestep_duration × cluster_weight).""" + return self.flow_system.temporal_weight + + def sum_temporal(self, data: xr.DataArray) -> xr.DataArray: + """Sum data over temporal dimensions with full temporal weighting. + + Example: + >>> total_energy = model.sum_temporal(flow_rate) + """ + return self.flow_system.sum_temporal(data) @property def scenario_weights(self) -> xr.DataArray: """ - Scenario weights of model. With optional normalization. + Scenario weights of model (always normalized, via FlowSystem). + + Returns unit weights if no scenarios defined or no explicit weights set. """ if self.flow_system.scenarios is None: return xr.DataArray(1) if self.flow_system.scenario_weights is None: - scenario_weights = xr.DataArray( - np.ones(self.flow_system.scenarios.size, dtype=float), - coords={'scenario': self.flow_system.scenarios}, - dims=['scenario'], - name='scenario_weights', - ) - else: - scenario_weights = self.flow_system.scenario_weights - - if not self.normalize_weights: - return scenario_weights + return self.flow_system._unit_weight('scenario') - norm = scenario_weights.sum('scenario') - if np.isclose(norm, 0.0).any(): - raise ValueError('FlowSystemModel.scenario_weights: weights sum to 0; cannot normalize.') - return scenario_weights / norm + return self.flow_system.scenario_weights @property def objective_weights(self) -> xr.DataArray: """ - Objective weights of model. With optional normalization of scenario weights. + Objective weights of model (period_weights × scenario_weights). 
""" period_weights = self.flow_system.effects.objective_effect.submodel.period_weights scenario_weights = self.scenario_weights diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index 93a1e2247..3a13dbb63 100644 --- a/flixopt/transform_accessor.py +++ b/flixopt/transform_accessor.py @@ -8,6 +8,7 @@ from __future__ import annotations import logging +import warnings from collections import defaultdict from typing import TYPE_CHECKING, Any, Literal @@ -705,7 +706,10 @@ def cluster( addPeakMax=time_series_for_high_peaks or [], addPeakMin=time_series_for_low_peaks or [], ) - tsam_agg.createTypicalPeriods() + # Suppress tsam warning about minimal value constraints (informational, not actionable) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=UserWarning, message='.*minimal value.*exceeds.*') + tsam_agg.createTypicalPeriods() tsam_results[key] = tsam_agg cluster_orders[key] = tsam_agg.clusterOrder @@ -801,7 +805,11 @@ def _build_cluster_weight_for_key(key: tuple) -> xr.DataArray: da = TimeSeriesData.from_dataarray(da.assign_attrs(original_da.attrs)) ds_new_vars[name] = da - ds_new = xr.Dataset(ds_new_vars, attrs=ds.attrs) + # Copy attrs but remove cluster_weight - the clustered FlowSystem gets its own + # cluster_weight set after from_dataset (original reference has wrong shape) + new_attrs = dict(ds.attrs) + new_attrs.pop('cluster_weight', None) + ds_new = xr.Dataset(ds_new_vars, attrs=new_attrs) ds_new.attrs['timesteps_per_cluster'] = timesteps_per_cluster ds_new.attrs['timestep_duration'] = dt ds_new.attrs['n_clusters'] = actual_n_clusters @@ -848,6 +856,9 @@ def _build_cluster_occurrences_for_key(key: tuple) -> np.ndarray: timestep_mapping_slices = {} cluster_occurrences_slices = {} + # Use renamed timesteps as coordinates for multi-dimensional case + original_timesteps_coord = self._fs.timesteps.rename('original_time') + for p in periods: for s in scenarios: key = (p, s) @@ -855,7 +866,10 @@ def 
_build_cluster_occurrences_for_key(key: tuple) -> np.ndarray: cluster_orders[key], dims=['original_period'], name='cluster_order' ) timestep_mapping_slices[key] = xr.DataArray( - _build_timestep_mapping_for_key(key), dims=['original_time'], name='timestep_mapping' + _build_timestep_mapping_for_key(key), + dims=['original_time'], + coords={'original_time': original_timesteps_coord}, + name='timestep_mapping', ) cluster_occurrences_slices[key] = xr.DataArray( _build_cluster_occurrences_for_key(key), dims=['cluster'], name='cluster_occurrences' @@ -874,8 +888,13 @@ def _build_cluster_occurrences_for_key(key: tuple) -> np.ndarray: else: # Simple case: single (None, None) slice cluster_order_da = xr.DataArray(cluster_orders[first_key], dims=['original_period'], name='cluster_order') + # Use renamed timesteps as coordinates + original_timesteps_coord = self._fs.timesteps.rename('original_time') timestep_mapping_da = xr.DataArray( - _build_timestep_mapping_for_key(first_key), dims=['original_time'], name='timestep_mapping' + _build_timestep_mapping_for_key(first_key), + dims=['original_time'], + coords={'original_time': original_timesteps_coord}, + name='timestep_mapping', ) cluster_occurrences_da = xr.DataArray( _build_cluster_occurrences_for_key(first_key), dims=['cluster'], name='cluster_occurrences' @@ -888,16 +907,17 @@ def _build_cluster_occurrences_for_key(key: tuple) -> np.ndarray: timesteps_per_cluster=timesteps_per_cluster, ) - # Create representative_weights in flat format for ClusterResult compatibility - # This repeats each cluster's weight for all timesteps within that cluster - def _build_flat_weights_for_key(key: tuple) -> xr.DataArray: + # Create representative_weights with (cluster,) dimension only + # Each cluster has one weight (same for all timesteps within it) + def _build_cluster_weights_for_key(key: tuple) -> xr.DataArray: occurrences = cluster_occurrences_all[key] - weights = np.repeat([occurrences.get(c, 1) for c in range(actual_n_clusters)], 
timesteps_per_cluster) - return xr.DataArray(weights, dims=['time'], name='representative_weights') + # Shape: (n_clusters,) - one weight per cluster + weights = np.array([occurrences.get(c, 1) for c in range(actual_n_clusters)]) + return xr.DataArray(weights, dims=['cluster'], name='representative_weights') - flat_weights_slices = {key: _build_flat_weights_for_key(key) for key in cluster_occurrences_all} + weights_slices = {key: _build_cluster_weights_for_key(key) for key in cluster_occurrences_all} representative_weights = self._combine_slices_to_dataarray_generic( - flat_weights_slices, ['time'], periods, scenarios, 'representative_weights' + weights_slices, ['cluster'], periods, scenarios, 'representative_weights' ) aggregation_result = ClusterResult( @@ -911,7 +931,6 @@ def _build_flat_weights_for_key(key: tuple) -> xr.DataArray: reduced_fs.clustering = Clustering( result=aggregation_result, - original_flow_system=self._fs, backend_name='tsam', ) @@ -1127,19 +1146,20 @@ def expand_solution(self) -> FlowSystem: raise ValueError('No cluster structure available for expansion.') timesteps_per_cluster = cluster_structure.timesteps_per_cluster - original_fs: FlowSystem = info.original_flow_system n_clusters = ( int(cluster_structure.n_clusters) if isinstance(cluster_structure.n_clusters, (int, np.integer)) else int(cluster_structure.n_clusters.values) ) - has_periods = original_fs.periods is not None - has_scenarios = original_fs.scenarios is not None - periods = list(original_fs.periods) if has_periods else [None] - scenarios = list(original_fs.scenarios) if has_scenarios else [None] + # Get original timesteps from clustering, but periods/scenarios from the FlowSystem + # (the clustered FlowSystem preserves the same periods/scenarios) + original_timesteps = info.original_timesteps + has_periods = self._fs.periods is not None + has_scenarios = self._fs.scenarios is not None - original_timesteps = original_fs.timesteps + periods = list(self._fs.periods) if 
has_periods else [None] + scenarios = list(self._fs.scenarios) if has_scenarios else [None] n_original_timesteps = len(original_timesteps) n_reduced_timesteps = n_clusters * timesteps_per_cluster @@ -1151,11 +1171,23 @@ def expand_da(da: xr.DataArray) -> xr.DataArray: # 1. Expand FlowSystem data (with cluster_weight set to 1.0 for all timesteps) reduced_ds = self._fs.to_dataset(include_solution=False) - expanded_ds = xr.Dataset( - {name: expand_da(da) for name, da in reduced_ds.data_vars.items() if name != 'cluster_weight'}, - attrs=reduced_ds.attrs, - ) - expanded_ds.attrs['timestep_duration'] = original_fs.timestep_duration.values.tolist() + # Filter out cluster-related variables and copy attrs without clustering info + data_vars = { + name: expand_da(da) + for name, da in reduced_ds.data_vars.items() + if name != 'cluster_weight' and not name.startswith('clustering|') + } + attrs = { + k: v + for k, v in reduced_ds.attrs.items() + if k not in ('is_clustered', 'n_clusters', 'timesteps_per_cluster', 'clustering') + } + expanded_ds = xr.Dataset(data_vars, attrs=attrs) + # Compute timestep_duration from original timesteps + # Add extra timestep for duration calculation (assume same interval as last) + original_timesteps_extra = FlowSystem._create_timesteps_with_extra(original_timesteps, None) + timestep_duration = FlowSystem.calculate_timestep_duration(original_timesteps_extra) + expanded_ds.attrs['timestep_duration'] = timestep_duration.values.tolist() # Create cluster_weight with value 1.0 for all timesteps (no weighting needed for expanded) # Use _combine_slices_to_dataarray for consistent multi-dim handling @@ -1201,8 +1233,8 @@ def expand_da(da: xr.DataArray) -> xr.DataArray: soc_boundary_per_timestep = soc_boundary_per_timestep.assign_coords(time=original_timesteps) # Apply self-discharge decay to SOC_boundary based on time within period - # Get the storage's relative_loss_per_hour from original flow system - storage = original_fs.storages[storage_name] + # 
Get the storage's relative_loss_per_hour from the clustered flow system + storage = self._fs.storages.get(storage_name) if storage is not None: # Time within period for each timestep (0, 1, 2, ..., timesteps_per_cluster-1, 0, 1, ...) time_within_period = np.arange(n_original_timesteps) % timesteps_per_cluster diff --git a/mkdocs.yml b/mkdocs.yml index ca94a6302..ab2e9309f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -234,6 +234,10 @@ plugins: allow_errors: false include_source: true include_requirejs: true + ignore: + - "notebooks/data/*.py" # Data generation scripts, not notebooks + execute_ignore: + - "notebooks/data/*.py" - plotly diff --git a/pyproject.toml b/pyproject.toml index 8c4749797..561f00f57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -209,17 +209,11 @@ filterwarnings = [ "ignore:SegmentedResults is deprecated:DeprecationWarning:flixopt", "ignore:ClusteredOptimization is deprecated:DeprecationWarning:flixopt", - # === Treat flixopt warnings as errors (strict mode for our code) === + # === Treat most flixopt warnings as errors (strict mode for our code) === # This ensures we catch deprecations, future changes, and user warnings in our own code "error::DeprecationWarning:flixopt", "error::FutureWarning:flixopt", "error::UserWarning:flixopt", - - # === Third-party warnings (mirrored from __init__.py) === - "ignore:.*minimal value.*exceeds.*:UserWarning:tsam", - "ignore:Coordinates across variables not equal:UserWarning:linopy", - "ignore:.*join will change from join='outer' to join='exact'.*:FutureWarning:linopy", - "ignore:numpy\\.ndarray size changed:RuntimeWarning", "ignore:.*network visualization is still experimental.*:UserWarning:flixopt", ] diff --git a/tests/deprecated/test_scenarios.py b/tests/deprecated/test_scenarios.py index 65ea62d81..2699647ad 100644 --- a/tests/deprecated/test_scenarios.py +++ b/tests/deprecated/test_scenarios.py @@ -341,12 +341,14 @@ def test_scenarios_selection(flow_system_piecewise_conversion_scenarios): assert 
flow_system.scenarios.equals(flow_system_full.scenarios[0:2]) - np.testing.assert_allclose(flow_system.scenario_weights.values, flow_system_full.scenario_weights[0:2]) + # Scenario weights are always normalized - subset is re-normalized to sum to 1 + subset_weights = flow_system_full.scenario_weights[0:2] + expected_normalized = subset_weights / subset_weights.sum() + np.testing.assert_allclose(flow_system.scenario_weights.values, expected_normalized.values) - # Optimize using new API with normalize_weights=False + # Optimize using new API flow_system.optimize( fx.solvers.GurobiSolver(mip_gap=0.01, time_limit_seconds=60), - normalize_weights=False, ) # Penalty has same structure as other effects: 'Penalty' is the total, 'Penalty(temporal)' and 'Penalty(periodic)' are components @@ -769,7 +771,10 @@ def test_weights_selection(): # Verify weights are correctly sliced assert fs_subset.scenarios.equals(pd.Index(['base', 'high'], name='scenario')) - np.testing.assert_allclose(fs_subset.scenario_weights.values, custom_scenario_weights[[0, 2]]) + # Scenario weights are always normalized - subset is re-normalized to sum to 1 + subset_weights = np.array([0.3, 0.2]) # Original weights for selected scenarios + expected_normalized = subset_weights / subset_weights.sum() + np.testing.assert_allclose(fs_subset.scenario_weights.values, expected_normalized) # Verify weights are 1D with just scenario dimension (no period dimension) assert fs_subset.scenario_weights.dims == ('scenario',) diff --git a/tests/test_cluster_reduce_expand.py b/tests/test_cluster_reduce_expand.py index af2864563..7072fe22e 100644 --- a/tests/test_cluster_reduce_expand.py +++ b/tests/test_cluster_reduce_expand.py @@ -292,8 +292,9 @@ def test_cluster_with_scenarios(timesteps_8_days, scenarios_2): assert info is not None assert info.result.cluster_structure is not None assert info.result.cluster_structure.n_clusters == 2 - # Original FlowSystem had scenarios - assert info.original_flow_system.scenarios is not 
None + # Clustered FlowSystem preserves scenarios + assert fs_reduced.scenarios is not None + assert len(fs_reduced.scenarios) == 2 def test_cluster_and_expand_with_scenarios(solver_fixture, timesteps_8_days, scenarios_2): @@ -465,8 +466,8 @@ def test_storage_cluster_mode_intercluster_cyclic(self, solver_fixture, timestep assert 'cluster_boundary' in soc_boundary.dims # First and last SOC_boundary values should be equal (cyclic constraint) - first_soc = float(soc_boundary.isel(cluster_boundary=0).values) - last_soc = float(soc_boundary.isel(cluster_boundary=-1).values) + first_soc = soc_boundary.isel(cluster_boundary=0).item() + last_soc = soc_boundary.isel(cluster_boundary=-1).item() assert_allclose(first_soc, last_soc, rtol=1e-6) @@ -543,15 +544,15 @@ def test_expanded_charge_state_matches_manual_calculation(self, solver_fixture, # Manual verification for first few timesteps of first period p = 0 # First period cluster = int(cluster_order[p]) - soc_b = float(soc_boundary.isel(cluster_boundary=p).values) + soc_b = soc_boundary.isel(cluster_boundary=p).item() for t in [0, 5, 12, 23]: global_t = p * timesteps_per_cluster + t - delta_e = float(cs_clustered.isel(cluster=cluster, time=t).values) + delta_e = cs_clustered.isel(cluster=cluster, time=t).item() decay = (1 - loss_rate) ** t expected = soc_b * decay + delta_e expected_clipped = max(0.0, expected) - actual = float(cs_expanded.isel(time=global_t).values) + actual = cs_expanded.isel(time=global_t).item() assert_allclose( actual, diff --git a/tests/test_clustering/test_base.py b/tests/test_clustering/test_base.py index a6c4d8cc7..9c63f25f6 100644 --- a/tests/test_clustering/test_base.py +++ b/tests/test_clustering/test_base.py @@ -152,7 +152,6 @@ def test_creation(self): info = Clustering( result=result, - original_flow_system=None, # Would be FlowSystem in practice backend_name='tsam', ) diff --git a/tests/test_clustering/test_integration.py b/tests/test_clustering/test_integration.py index 587e39160..2d04a51c1 
100644 --- a/tests/test_clustering/test_integration.py +++ b/tests/test_clustering/test_integration.py @@ -5,85 +5,121 @@ import pytest import xarray as xr -from flixopt import FlowSystem, TimeSeriesWeights +from flixopt import FlowSystem -class TestTimeSeriesWeights: - """Tests for TimeSeriesWeights class.""" +class TestWeights: + """Tests for FlowSystem.weights dict property.""" - def test_creation(self): - """Test TimeSeriesWeights creation.""" - temporal = xr.DataArray([1.0, 1.0, 1.0], dims=['time']) - weights = TimeSeriesWeights(temporal=temporal) + def test_weights_is_dict(self): + """Test weights returns a dict.""" + fs = FlowSystem(timesteps=pd.date_range('2024-01-01', periods=24, freq='h')) + weights = fs.weights + + assert isinstance(weights, dict) + assert 'time' in weights + + def test_time_weight(self): + """Test weights['time'] returns timestep_duration.""" + fs = FlowSystem(timesteps=pd.date_range('2024-01-01', periods=24, freq='h')) + weights = fs.weights + + # For hourly data, timestep_duration is 1.0 + assert float(weights['time'].mean()) == 1.0 + + def test_cluster_not_in_weights_when_non_clustered(self): + """Test weights doesn't have 'cluster' key for non-clustered systems.""" + fs = FlowSystem(timesteps=pd.date_range('2024-01-01', periods=24, freq='h')) + weights = fs.weights + + # Non-clustered: 'cluster' not in weights + assert 'cluster' not in weights - assert 'time' in weights.temporal.dims - assert float(weights.temporal.sum().values) == 3.0 + def test_temporal_dims_non_clustered(self): + """Test temporal_dims is ['time'] for non-clustered systems.""" + fs = FlowSystem(timesteps=pd.date_range('2024-01-01', periods=24, freq='h')) + + assert fs.temporal_dims == ['time'] - def test_invalid_no_time_dim(self): - """Test error when temporal has no time dimension.""" - temporal = xr.DataArray([1.0, 1.0], dims=['other']) + def test_temporal_weight(self): + """Test temporal_weight returns time * cluster.""" + fs = 
FlowSystem(timesteps=pd.date_range('2024-01-01', periods=24, freq='h')) - with pytest.raises(ValueError, match='time'): - TimeSeriesWeights(temporal=temporal) + expected = fs.weights['time'] * fs.weights.get('cluster', 1.0) + xr.testing.assert_equal(fs.temporal_weight, expected) - def test_sum_over_time(self): - """Test sum_over_time convenience method.""" - temporal = xr.DataArray([2.0, 3.0, 1.0], dims=['time'], coords={'time': [0, 1, 2]}) - weights = TimeSeriesWeights(temporal=temporal) + def test_sum_temporal(self): + """Test sum_temporal applies full temporal weighting (time * cluster) and sums.""" + fs = FlowSystem(timesteps=pd.date_range('2024-01-01', periods=3, freq='h')) - data = xr.DataArray([10.0, 20.0, 30.0], dims=['time'], coords={'time': [0, 1, 2]}) - result = weights.sum_over_time(data) + # Input is a rate (e.g., flow_rate in MW) + data = xr.DataArray([10.0, 20.0, 30.0], dims=['time'], coords={'time': fs.timesteps}) - # 10*2 + 20*3 + 30*1 = 20 + 60 + 30 = 110 - assert float(result.values) == 110.0 + result = fs.sum_temporal(data) - def test_effective_objective(self): - """Test effective_objective property.""" - temporal = xr.DataArray([1.0, 1.0], dims=['time']) - objective = xr.DataArray([2.0, 2.0], dims=['time']) + # For hourly non-clustered: temporal = time * cluster = 1.0 * 1.0 = 1.0 + # result = sum(data * temporal) = sum(data) = 60 + assert float(result.values) == 60.0 - # Without override - weights1 = TimeSeriesWeights(temporal=temporal) - assert np.array_equal(weights1.effective_objective.values, temporal.values) - # With override - weights2 = TimeSeriesWeights(temporal=temporal, objective=objective) - assert np.array_equal(weights2.effective_objective.values, objective.values) +class TestFlowSystemDimsIndexesWeights: + """Tests for FlowSystem.dims, .indexes, .weights properties.""" + def test_dims_property(self): + """Test that FlowSystem.dims returns active dimension names.""" + fs = FlowSystem(timesteps=pd.date_range('2024-01-01', 
periods=24, freq='h')) -class TestFlowSystemWeightsProperty: - """Tests for FlowSystem.weights property.""" + dims = fs.dims + assert dims == ['time'] - def test_weights_property_exists(self): - """Test that FlowSystem has weights property.""" + def test_indexes_property(self): + """Test that FlowSystem.indexes returns active indexes.""" fs = FlowSystem(timesteps=pd.date_range('2024-01-01', periods=24, freq='h')) - weights = fs.weights - assert isinstance(weights, TimeSeriesWeights) + indexes = fs.indexes + assert isinstance(indexes, dict) + assert 'time' in indexes + assert len(indexes['time']) == 24 - def test_weights_temporal_equals_aggregation_weight(self): - """Test that weights.temporal equals aggregation_weight.""" + def test_weights_keys_match_dims(self): + """Test that weights.keys() is subset of dims (only 'time' for simple case).""" fs = FlowSystem(timesteps=pd.date_range('2024-01-01', periods=24, freq='h')) - weights = fs.weights - aggregation_weight = fs.aggregation_weight + # For non-clustered, weights only has 'time' + assert set(fs.weights.keys()) == {'time'} - np.testing.assert_array_almost_equal(weights.temporal.values, aggregation_weight.values) + def test_temporal_weight_calculation(self): + """Test that temporal_weight = timestep_duration * cluster_weight.""" + fs = FlowSystem(timesteps=pd.date_range('2024-01-01', periods=24, freq='h')) + + expected = fs.timestep_duration * 1.0 # cluster is 1.0 for non-clustered + + np.testing.assert_array_almost_equal(fs.temporal_weight.values, expected.values) def test_weights_with_cluster_weight(self): - """Test weights property includes cluster_weight.""" + """Test weights property includes cluster_weight when provided.""" # Create FlowSystem with custom cluster_weight timesteps = pd.date_range('2024-01-01', periods=24, freq='h') - cluster_weight = np.array([2.0] * 12 + [1.0] * 12) # First 12h weighted 2x + cluster_weight = xr.DataArray( + np.array([2.0] * 12 + [1.0] * 12), + dims=['time'], + 
coords={'time': timesteps}, + ) fs = FlowSystem(timesteps=timesteps, cluster_weight=cluster_weight) weights = fs.weights - # temporal = timestep_duration * cluster_weight - # timestep_duration is 1h for all, so temporal = cluster_weight + # cluster weight should be in weights (FlowSystem has cluster_weight set) + # But note: 'cluster' only appears in weights if clusters dimension exists + # Since we didn't set clusters, 'cluster' won't be in weights + # The cluster_weight is applied via temporal_weight + assert 'cluster' not in weights # No cluster dimension + + # temporal_weight = timestep_duration * cluster_weight + # timestep_duration is 1h for all expected = 1.0 * cluster_weight - np.testing.assert_array_almost_equal(weights.temporal.values, expected) + np.testing.assert_array_almost_equal(fs.temporal_weight.values, expected.values) class TestClusterMethod: diff --git a/tests/test_clustering_io.py b/tests/test_clustering_io.py new file mode 100644 index 000000000..483cdc447 --- /dev/null +++ b/tests/test_clustering_io.py @@ -0,0 +1,241 @@ +"""Tests for clustering serialization and deserialization.""" + +import numpy as np +import pandas as pd +import pytest + +import flixopt as fx + + +@pytest.fixture +def simple_system_24h(): + """Create a simple flow system with 24 hourly timesteps.""" + timesteps = pd.date_range('2023-01-01', periods=24, freq='h') + + fs = fx.FlowSystem(timesteps) + fs.add_elements( + fx.Bus('heat'), + fx.Effect('costs', unit='EUR', description='costs', is_objective=True, is_standard=True), + ) + fs.add_elements( + fx.Sink('demand', inputs=[fx.Flow('in', bus='heat', fixed_relative_profile=np.ones(24), size=10)]), + fx.Source('source', outputs=[fx.Flow('out', bus='heat', size=50, effects_per_flow_hour={'costs': 0.05})]), + ) + return fs + + +@pytest.fixture +def simple_system_8_days(): + """Create a simple flow system with 8 days of hourly timesteps.""" + timesteps = pd.date_range('2023-01-01', periods=8 * 24, freq='h') + + # Create varying 
demand profile with different patterns for different days + # 4 "weekdays" with high demand, 4 "weekend" days with low demand + hourly_pattern = np.sin(np.linspace(0, 2 * np.pi, 24)) * 0.5 + 0.5 + weekday_profile = hourly_pattern * 1.5 # Higher demand + weekend_profile = hourly_pattern * 0.5 # Lower demand + demand_profile = np.concatenate( + [ + weekday_profile, + weekday_profile, + weekday_profile, + weekday_profile, + weekend_profile, + weekend_profile, + weekend_profile, + weekend_profile, + ] + ) + + fs = fx.FlowSystem(timesteps) + fs.add_elements( + fx.Bus('heat'), + fx.Effect('costs', unit='EUR', description='costs', is_objective=True, is_standard=True), + ) + fs.add_elements( + fx.Sink('demand', inputs=[fx.Flow('in', bus='heat', fixed_relative_profile=demand_profile, size=10)]), + fx.Source('source', outputs=[fx.Flow('out', bus='heat', size=50, effects_per_flow_hour={'costs': 0.05})]), + ) + return fs + + +class TestClusteringRoundtrip: + """Test that clustering survives dataset roundtrip.""" + + def test_clustering_to_dataset_has_clustering_attrs(self, simple_system_8_days): + """Clustered FlowSystem dataset should have clustering info.""" + fs = simple_system_8_days + fs_clustered = fs.transform.cluster(n_clusters=2, cluster_duration='1D') + + ds = fs_clustered.to_dataset(include_solution=False) + + # Check that clustering attrs are present + assert 'clustering' in ds.attrs + + # Check that clustering arrays are present with prefix + clustering_vars = [name for name in ds.data_vars if name.startswith('clustering|')] + assert len(clustering_vars) > 0 + + def test_clustering_roundtrip_preserves_clustering_object(self, simple_system_8_days): + """Clustering object should be restored after roundtrip.""" + fs = simple_system_8_days + fs_clustered = fs.transform.cluster(n_clusters=2, cluster_duration='1D') + + # Roundtrip + ds = fs_clustered.to_dataset(include_solution=False) + fs_restored = fx.FlowSystem.from_dataset(ds) + + # Clustering should be restored + 
assert fs_restored.clustering is not None + assert fs_restored.clustering.backend_name == 'tsam' + + def test_clustering_roundtrip_preserves_n_clusters(self, simple_system_8_days): + """Number of clusters should be preserved after roundtrip.""" + fs = simple_system_8_days + fs_clustered = fs.transform.cluster(n_clusters=2, cluster_duration='1D') + + ds = fs_clustered.to_dataset(include_solution=False) + fs_restored = fx.FlowSystem.from_dataset(ds) + + assert fs_restored.clustering.n_clusters == 2 + + def test_clustering_roundtrip_preserves_timesteps_per_cluster(self, simple_system_8_days): + """Timesteps per cluster should be preserved after roundtrip.""" + fs = simple_system_8_days + fs_clustered = fs.transform.cluster(n_clusters=2, cluster_duration='1D') + + ds = fs_clustered.to_dataset(include_solution=False) + fs_restored = fx.FlowSystem.from_dataset(ds) + + assert fs_restored.clustering.timesteps_per_cluster == 24 + + def test_clustering_roundtrip_preserves_original_timesteps(self, simple_system_8_days): + """Original timesteps should be preserved after roundtrip.""" + fs = simple_system_8_days + fs_clustered = fs.transform.cluster(n_clusters=2, cluster_duration='1D') + original_timesteps = fs_clustered.clustering.original_timesteps + + ds = fs_clustered.to_dataset(include_solution=False) + fs_restored = fx.FlowSystem.from_dataset(ds) + + pd.testing.assert_index_equal(fs_restored.clustering.original_timesteps, original_timesteps) + + def test_clustering_roundtrip_preserves_timestep_mapping(self, simple_system_8_days): + """Timestep mapping should be preserved after roundtrip.""" + fs = simple_system_8_days + fs_clustered = fs.transform.cluster(n_clusters=2, cluster_duration='1D') + original_mapping = fs_clustered.clustering.timestep_mapping.values.copy() + + ds = fs_clustered.to_dataset(include_solution=False) + fs_restored = fx.FlowSystem.from_dataset(ds) + + np.testing.assert_array_equal(fs_restored.clustering.timestep_mapping.values, original_mapping) + + 
+class TestClusteringWithSolutionRoundtrip: + """Test that clustering with solution survives roundtrip.""" + + def test_expand_solution_after_roundtrip(self, simple_system_8_days, solver_fixture): + """expand_solution should work after loading from dataset.""" + fs = simple_system_8_days + fs_clustered = fs.transform.cluster(n_clusters=2, cluster_duration='1D') + + # Solve + fs_clustered.optimize(solver_fixture) + + # Roundtrip + ds = fs_clustered.to_dataset(include_solution=True) + fs_restored = fx.FlowSystem.from_dataset(ds) + + # expand_solution should work + fs_expanded = fs_restored.transform.expand_solution() + + # Check expanded FlowSystem has correct number of timesteps + assert len(fs_expanded.timesteps) == 8 * 24 + + def test_expand_solution_after_netcdf_roundtrip(self, simple_system_8_days, tmp_path, solver_fixture): + """expand_solution should work after loading from NetCDF file.""" + fs = simple_system_8_days + fs_clustered = fs.transform.cluster(n_clusters=2, cluster_duration='1D') + + # Solve + fs_clustered.optimize(solver_fixture) + + # Save to NetCDF + nc_path = tmp_path / 'clustered.nc' + fs_clustered.to_netcdf(nc_path) + + # Load from NetCDF + fs_restored = fx.FlowSystem.from_netcdf(nc_path) + + # expand_solution should work + fs_expanded = fs_restored.transform.expand_solution() + + # Check expanded FlowSystem has correct number of timesteps + assert len(fs_expanded.timesteps) == 8 * 24 + + +class TestClusteringDerivedProperties: + """Test derived properties on Clustering object.""" + + def test_original_timesteps_property(self, simple_system_8_days): + """original_timesteps property should return correct DatetimeIndex.""" + fs = simple_system_8_days + original_timesteps = fs.timesteps + + fs_clustered = fs.transform.cluster(n_clusters=2, cluster_duration='1D') + + # Check values are equal (name attribute may differ) + pd.testing.assert_index_equal( + fs_clustered.clustering.original_timesteps, + original_timesteps, + check_names=False, + ) + + 
def test_simple_system_has_no_periods_or_scenarios(self, simple_system_8_days): + """Clustered simple system should preserve that it has no periods/scenarios.""" + fs = simple_system_8_days + fs_clustered = fs.transform.cluster(n_clusters=2, cluster_duration='1D') + + # FlowSystem without periods/scenarios should remain so after clustering + assert fs_clustered.periods is None + assert fs_clustered.scenarios is None + + +class TestClusteringWithScenarios: + """Test clustering IO with scenarios.""" + + @pytest.fixture + def system_with_scenarios(self): + """Create a flow system with scenarios.""" + timesteps = pd.date_range('2023-01-01', periods=4 * 24, freq='h') + scenarios = pd.Index(['Low', 'High'], name='scenario') + + # Create varying demand profile for clustering + demand_profile = np.tile(np.sin(np.linspace(0, 2 * np.pi, 24)) * 0.5 + 0.5, 4) + + fs = fx.FlowSystem(timesteps, scenarios=scenarios) + fs.add_elements( + fx.Bus('heat'), + fx.Effect('costs', unit='EUR', description='costs', is_objective=True, is_standard=True), + ) + fs.add_elements( + fx.Sink('demand', inputs=[fx.Flow('in', bus='heat', fixed_relative_profile=demand_profile, size=10)]), + fx.Source('source', outputs=[fx.Flow('out', bus='heat', size=50, effects_per_flow_hour={'costs': 0.05})]), + ) + return fs + + def test_clustering_roundtrip_preserves_scenarios(self, system_with_scenarios): + """Scenarios should be preserved after clustering and roundtrip.""" + fs = system_with_scenarios + fs_clustered = fs.transform.cluster(n_clusters=2, cluster_duration='1D') + + ds = fs_clustered.to_dataset(include_solution=False) + fs_restored = fx.FlowSystem.from_dataset(ds) + + # Scenarios should be preserved in the FlowSystem itself + pd.testing.assert_index_equal( + fs_restored.scenarios, + pd.Index(['Low', 'High'], name='scenario'), + check_names=False, + ) diff --git a/tests/test_io_conversion.py b/tests/test_io_conversion.py index 33bda8c91..7775f987a 100644 --- a/tests/test_io_conversion.py +++ 
b/tests/test_io_conversion.py @@ -762,6 +762,11 @@ def test_v4_reoptimized_objective_matches_original(self, result_name): new_objective = float(fs.solution['objective'].item()) new_effect_total = float(fs.solution[objective_effect_label].sum().item()) + # Skip comparison for scenarios test case - scenario weights are now always normalized, + # which changes the objective value when loading old results with non-normalized weights + if result_name == '04_scenarios': + pytest.skip('Scenario weights are now always normalized - old results have different weights') + # Verify objective matches (within tolerance) assert new_objective == pytest.approx(old_objective, rel=1e-5, abs=1), ( f'Objective mismatch for {result_name}: new={new_objective}, old={old_objective}' diff --git a/tests/test_scenarios.py b/tests/test_scenarios.py index 65ea62d81..2699647ad 100644 --- a/tests/test_scenarios.py +++ b/tests/test_scenarios.py @@ -341,12 +341,14 @@ def test_scenarios_selection(flow_system_piecewise_conversion_scenarios): assert flow_system.scenarios.equals(flow_system_full.scenarios[0:2]) - np.testing.assert_allclose(flow_system.scenario_weights.values, flow_system_full.scenario_weights[0:2]) + # Scenario weights are always normalized - subset is re-normalized to sum to 1 + subset_weights = flow_system_full.scenario_weights[0:2] + expected_normalized = subset_weights / subset_weights.sum() + np.testing.assert_allclose(flow_system.scenario_weights.values, expected_normalized.values) - # Optimize using new API with normalize_weights=False + # Optimize using new API flow_system.optimize( fx.solvers.GurobiSolver(mip_gap=0.01, time_limit_seconds=60), - normalize_weights=False, ) # Penalty has same structure as other effects: 'Penalty' is the total, 'Penalty(temporal)' and 'Penalty(periodic)' are components @@ -769,7 +771,10 @@ def test_weights_selection(): # Verify weights are correctly sliced assert fs_subset.scenarios.equals(pd.Index(['base', 'high'], name='scenario')) - 
np.testing.assert_allclose(fs_subset.scenario_weights.values, custom_scenario_weights[[0, 2]]) + # Scenario weights are always normalized - subset is re-normalized to sum to 1 + subset_weights = np.array([0.3, 0.2]) # Original weights for selected scenarios + expected_normalized = subset_weights / subset_weights.sum() + np.testing.assert_allclose(fs_subset.scenario_weights.values, expected_normalized) # Verify weights are 1D with just scenario dimension (no period dimension) assert fs_subset.scenario_weights.dims == ('scenario',) diff --git a/tests/test_sel_isel_single_selection.py b/tests/test_sel_isel_single_selection.py new file mode 100644 index 000000000..4d84ced51 --- /dev/null +++ b/tests/test_sel_isel_single_selection.py @@ -0,0 +1,193 @@ +"""Tests for sel/isel with single period/scenario selection.""" + +import numpy as np +import pandas as pd +import pytest + +import flixopt as fx + + +@pytest.fixture +def fs_with_scenarios(): + """FlowSystem with scenarios for testing single selection.""" + timesteps = pd.date_range('2023-01-01', periods=24, freq='h') + scenarios = pd.Index(['A', 'B', 'C'], name='scenario') + scenario_weights = np.array([0.5, 0.3, 0.2]) + + fs = fx.FlowSystem(timesteps, scenarios=scenarios, scenario_weights=scenario_weights) + fs.add_elements( + fx.Bus('heat'), + fx.Effect('costs', unit='EUR', description='costs', is_objective=True, is_standard=True), + ) + fs.add_elements( + fx.Sink('demand', inputs=[fx.Flow('in', bus='heat', fixed_relative_profile=np.ones(24), size=10)]), + fx.Source('source', outputs=[fx.Flow('out', bus='heat', size=50, effects_per_flow_hour={'costs': 0.05})]), + ) + return fs + + +@pytest.fixture +def fs_with_periods(): + """FlowSystem with periods for testing single selection.""" + timesteps = pd.date_range('2023-01-01', periods=24, freq='h') + periods = pd.Index([2020, 2030, 2040], name='period') + + fs = fx.FlowSystem(timesteps, periods=periods, weight_of_last_period=10) + fs.add_elements( + fx.Bus('heat'), + 
fx.Effect('costs', unit='EUR', description='costs', is_objective=True, is_standard=True), + ) + fs.add_elements( + fx.Sink('demand', inputs=[fx.Flow('in', bus='heat', fixed_relative_profile=np.ones(24), size=10)]), + fx.Source('source', outputs=[fx.Flow('out', bus='heat', size=50, effects_per_flow_hour={'costs': 0.05})]), + ) + return fs + + +@pytest.fixture +def fs_with_periods_and_scenarios(): + """FlowSystem with both periods and scenarios.""" + timesteps = pd.date_range('2023-01-01', periods=24, freq='h') + periods = pd.Index([2020, 2030], name='period') + scenarios = pd.Index(['Low', 'High'], name='scenario') + + fs = fx.FlowSystem(timesteps, periods=periods, scenarios=scenarios, weight_of_last_period=10) + fs.add_elements( + fx.Bus('heat'), + fx.Effect('costs', unit='EUR', description='costs', is_objective=True, is_standard=True), + ) + fs.add_elements( + fx.Sink('demand', inputs=[fx.Flow('in', bus='heat', fixed_relative_profile=np.ones(24), size=10)]), + fx.Source('source', outputs=[fx.Flow('out', bus='heat', size=50, effects_per_flow_hour={'costs': 0.05})]), + ) + return fs + + +class TestIselSingleScenario: + """Test isel with single scenario selection.""" + + def test_isel_single_scenario_drops_dimension(self, fs_with_scenarios): + """Selecting a single scenario with isel should drop the scenario dimension.""" + fs_selected = fs_with_scenarios.transform.isel(scenario=0) + + assert fs_selected.scenarios is None + assert 'scenario' not in fs_selected.to_dataset().dims + + def test_isel_single_scenario_removes_scenario_weights(self, fs_with_scenarios): + """scenario_weights should be removed when scenario dimension is dropped.""" + fs_selected = fs_with_scenarios.transform.isel(scenario=0) + + ds = fs_selected.to_dataset() + assert 'scenario_weights' not in ds.data_vars + assert 'scenario_weights' not in ds.attrs + + def test_isel_single_scenario_preserves_time(self, fs_with_scenarios): + """Time dimension should be preserved.""" + fs_selected = 
fs_with_scenarios.transform.isel(scenario=0) + + assert len(fs_selected.timesteps) == 24 + + def test_isel_single_scenario_roundtrip(self, fs_with_scenarios): + """FlowSystem should survive to_dataset/from_dataset roundtrip after single selection.""" + fs_selected = fs_with_scenarios.transform.isel(scenario=0) + + ds = fs_selected.to_dataset() + fs_restored = fx.FlowSystem.from_dataset(ds) + + assert fs_restored.scenarios is None + assert len(fs_restored.timesteps) == 24 + + +class TestSelSingleScenario: + """Test sel with single scenario selection.""" + + def test_sel_single_scenario_drops_dimension(self, fs_with_scenarios): + """Selecting a single scenario with sel should drop the scenario dimension.""" + fs_selected = fs_with_scenarios.transform.sel(scenario='B') + + assert fs_selected.scenarios is None + + +class TestIselSinglePeriod: + """Test isel with single period selection.""" + + def test_isel_single_period_drops_dimension(self, fs_with_periods): + """Selecting a single period with isel should drop the period dimension.""" + fs_selected = fs_with_periods.transform.isel(period=0) + + assert fs_selected.periods is None + assert 'period' not in fs_selected.to_dataset().dims + + def test_isel_single_period_removes_period_weights(self, fs_with_periods): + """period_weights should be removed when period dimension is dropped.""" + fs_selected = fs_with_periods.transform.isel(period=0) + + ds = fs_selected.to_dataset() + assert 'period_weights' not in ds.data_vars + assert 'weight_of_last_period' not in ds.attrs + + def test_isel_single_period_roundtrip(self, fs_with_periods): + """FlowSystem should survive roundtrip after single period selection.""" + fs_selected = fs_with_periods.transform.isel(period=0) + + ds = fs_selected.to_dataset() + fs_restored = fx.FlowSystem.from_dataset(ds) + + assert fs_restored.periods is None + + +class TestSelSinglePeriod: + """Test sel with single period selection.""" + + def test_sel_single_period_drops_dimension(self, 
fs_with_periods): + """Selecting a single period with sel should drop the period dimension.""" + fs_selected = fs_with_periods.transform.sel(period=2030) + + assert fs_selected.periods is None + + +class TestMixedSelection: + """Test mixed selections (single + multiple).""" + + def test_single_period_multiple_scenarios(self, fs_with_periods_and_scenarios): + """Single period but multiple scenarios should only drop period.""" + fs_selected = fs_with_periods_and_scenarios.transform.isel(period=0) + + assert fs_selected.periods is None + assert fs_selected.scenarios is not None + assert len(fs_selected.scenarios) == 2 + + def test_multiple_periods_single_scenario(self, fs_with_periods_and_scenarios): + """Multiple periods but single scenario should only drop scenario.""" + fs_selected = fs_with_periods_and_scenarios.transform.isel(scenario=0) + + assert fs_selected.periods is not None + assert len(fs_selected.periods) == 2 + assert fs_selected.scenarios is None + + def test_single_period_single_scenario(self, fs_with_periods_and_scenarios): + """Single period and single scenario should drop both.""" + fs_selected = fs_with_periods_and_scenarios.transform.isel(period=0, scenario=0) + + assert fs_selected.periods is None + assert fs_selected.scenarios is None + + +class TestSliceSelection: + """Test that slice selection preserves dimensions.""" + + def test_slice_scenarios_preserves_dimension(self, fs_with_scenarios): + """Slice selection should preserve dimension even with 1 element.""" + # Select a slice that results in 2 elements + fs_selected = fs_with_scenarios.transform.isel(scenario=slice(0, 2)) + + assert fs_selected.scenarios is not None + assert len(fs_selected.scenarios) == 2 + + def test_list_selection_preserves_dimension(self, fs_with_scenarios): + """List selection should preserve dimension even with 1 element.""" + fs_selected = fs_with_scenarios.transform.isel(scenario=[0]) + + # List selection should preserve dimension + assert fs_selected.scenarios is 
not None + assert len(fs_selected.scenarios) == 1 From 85911e8e3a93fc0e044563fc0c7ed9d2d355e5f8 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 23:29:39 +0100 Subject: [PATCH 28/62] Improve documentation and improve CHANGELOG.md --- CHANGELOG.md | 70 ++++++++++++++++++++++++++++- docs/user-guide/results-plotting.md | 3 ++ mkdocs.yml | 7 ++- 3 files changed, 78 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f60497d80..191a8c28c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,7 +53,7 @@ Until here --> ## [5.1.0] - Upcoming -**Summary**: Time-series clustering for faster optimization with configurable storage behavior across typical periods. Improved weights API with always-normalized scenario weights. +**Summary**: Major feature release introducing time-series clustering for faster optimization and the new `fxplot` accessor for universal xarray plotting. Includes configurable storage behavior across typical periods and improved weights API. ### ✨ Added @@ -121,6 +121,44 @@ charge_state = fs_expanded.solution['SeasonalPit|charge_state'] Use `'cyclic'` for short-term storage like batteries or hot water tanks where only daily patterns matter. Use `'independent'` for quick estimates when storage behavior isn't critical. +**FXPlot Accessor**: New global xarray accessors for universal plotting with automatic faceting and smart dimension handling. Works on any xarray Dataset, not just flixopt results. 
+ +```python +import flixopt as fx # Registers accessors automatically + +# Plot any xarray Dataset with automatic faceting +dataset.fxplot.bar(x='component') +dataset.fxplot.area(x='time') +dataset.fxplot.heatmap(x='time', y='component') +dataset.fxplot.line(x='time', facet_col='scenario') + +# DataArray support +data_array.fxplot.line() + +# Statistics transformations +dataset.fxstats.to_duration_curve() +``` + +**Available Plot Methods**: + +| Method | Description | +|--------|-------------| +| `.fxplot.bar()` | Grouped bar charts | +| `.fxplot.stacked_bar()` | Stacked bar charts | +| `.fxplot.line()` | Line charts with faceting | +| `.fxplot.area()` | Stacked area charts | +| `.fxplot.heatmap()` | Heatmap visualizations | +| `.fxplot.scatter()` | Scatter plots | +| `.fxplot.pie()` | Pie charts with faceting | +| `.fxstats.to_duration_curve()` | Transform to duration curve format | + +**Key Features**: + +- **Auto-faceting**: Automatically assigns extra dimensions (period, scenario, cluster) to `facet_col`, `facet_row`, or `animation_frame` +- **Smart x-axis**: Intelligently selects x dimension based on priority (time > duration > period > scenario) +- **Universal**: Works on any xarray Dataset/DataArray, not limited to flixopt +- **Configurable**: Customize via `CONFIG.Plotting` (colorscales, facet columns, line shapes) + ### 💥 Breaking Changes - `FlowSystem.scenario_weights` are now always normalized to sum to 1 when set (including after `.sel()` subsetting) @@ -134,10 +172,35 @@ charge_state = fs_expanded.solution['SeasonalPit|charge_state'] - `normalize_weights` parameter in `create_model()`, `build_model()`, `optimize()` +**Topology method name simplifications** (old names still work with deprecation warnings, removal in v6.0.0): + +| Old (v5.0.0) | New (v5.1.0) | +|--------------|--------------| +| `topology.plot_network()` | `topology.plot()` | +| `topology.start_network_app()` | `topology.start_app()` | +| `topology.stop_network_app()` | 
`topology.stop_app()` | +| `topology.network_infos()` | `topology.infos()` | + +Note: `topology.plot()` now renders a Sankey diagram. The old PyVis visualization is available via `topology.plot_legacy()`. + ### 🐛 Fixed - `temporal_weight` and `sum_temporal()` now use consistent implementation +### 📝 Docs + +**New Documentation Pages:** + +- [Time-Series Clustering Guide](https://flixopt.github.io/flixopt/latest/user-guide/optimization/clustering/) - Comprehensive guide to clustering workflows + +**New Jupyter Notebooks:** + +- **08c-clustering.ipynb** - Introduction to time-series clustering +- **08c2-clustering-storage-modes.ipynb** - Comparison of all 4 storage cluster modes +- **08d-clustering-multiperiod.ipynb** - Clustering with periods and scenarios +- **08e-clustering-internals.ipynb** - Understanding clustering internals +- **fxplot_accessor_demo.ipynb** - Demo of the new fxplot accessor + ### 👷 Development **New Test Suites for Clustering**: @@ -147,6 +210,11 @@ charge_state = fs_expanded.solution['SeasonalPit|charge_state'] - `TestMultiPeriodClustering`: Tests for clustering with periods and scenarios dimensions - `TestPeakSelection`: Tests for `time_series_for_high_peaks` and `time_series_for_low_peaks` parameters +**New Test Suites for Other Features**: + +- `test_clustering_io.py` - Tests for clustering serialization roundtrip +- `test_sel_isel_single_selection.py` - Tests for transform selection methods + --- diff --git a/docs/user-guide/results-plotting.md b/docs/user-guide/results-plotting.md index 1ecd26aa1..28e3d2b2b 100644 --- a/docs/user-guide/results-plotting.md +++ b/docs/user-guide/results-plotting.md @@ -2,6 +2,9 @@ After solving an optimization, flixOpt provides a powerful plotting API to visualize and analyze your results. The API is designed to be intuitive and chainable, giving you quick access to common plots while still allowing deep customization. +!!! 
tip "Plotting Custom Data" + For plotting arbitrary xarray data (not just flixopt results), see the [Custom Data Plotting](recipes/plotting-custom-data.md) guide which covers the `.fxplot` accessor. + ## The Plot Accessor All plotting is accessed through the `statistics.plot` accessor on your FlowSystem: diff --git a/mkdocs.yml b/mkdocs.yml index ab2e9309f..6ac519130 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -69,8 +69,13 @@ nav: - Piecewise Effects: notebooks/06c-piecewise-effects.ipynb - Scaling: - Scenarios: notebooks/07-scenarios-and-periods.ipynb - - Clustering: notebooks/08a-aggregation.ipynb + - Resampling: notebooks/08a-aggregation.ipynb - Rolling Horizon: notebooks/08b-rolling-horizon.ipynb + - Clustering: + - Basics: notebooks/08c-clustering.ipynb + - Storage Modes: notebooks/08c2-clustering-storage-modes.ipynb + - Multi-Period: notebooks/08d-clustering-multiperiod.ipynb + - Internals: notebooks/08e-clustering-internals.ipynb - Results: - Plotting: notebooks/09-plotting-and-data-access.ipynb - Custom Data Plotting: notebooks/fxplot_accessor_demo.ipynb From 0b8440bd5ce0ace77287e3346f5e0e1a1d11981c Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 23:47:31 +0100 Subject: [PATCH 29/62] Fix core dims --- flixopt/comparison.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flixopt/comparison.py b/flixopt/comparison.py index c020cb01d..b615f9ae8 100644 --- a/flixopt/comparison.py +++ b/flixopt/comparison.py @@ -95,7 +95,8 @@ def __init__(self, flow_systems: list[FlowSystem], names: list[str] | None = Non self._statistics: ComparisonStatistics | None = None # Core dimensions that must match across FlowSystems - _CORE_DIMS = {'time', 'cluster', 'period', 'scenario'} + # Note: 'cluster' and 'cluster_boundary' are auxiliary dimensions from clustering + _CORE_DIMS = {'time', 'period', 'scenario'} def _validate_matching_dimensions(self) -> None: """Validate that all FlowSystems have 
matching core dimensions. From e4cd2701b5b69c09749cdc1d150faa6e6265fdd9 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 23:48:26 +0100 Subject: [PATCH 30/62] FIx CHangelog and change to v6.0.0 --- CHANGELOG.md | 43 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 191a8c28c..9fe871469 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,9 +51,12 @@ If upgrading from v2.x, see the [v3.0.0 release notes](https://github.com/flixOp Until here --> -## [5.1.0] - Upcoming +## [6.0.0] - Upcoming -**Summary**: Major feature release introducing time-series clustering for faster optimization and the new `fxplot` accessor for universal xarray plotting. Includes configurable storage behavior across typical periods and improved weights API. +**Summary**: Major release introducing time-series clustering with storage inter-cluster linking, the new `fxplot` accessor for universal xarray plotting, and removal of deprecated v5.0 classes. Includes configurable storage behavior across typical periods and improved weights API. + +!!! warning "Breaking Changes" + This release removes `ClusteredOptimization` and `ClusteringParameters` which were deprecated in v5.0.0. Use `flow_system.transform.cluster()` instead. See [Migration](#migration-from-clusteredoptimization) below. ### ✨ Added @@ -183,6 +186,42 @@ dataset.fxstats.to_duration_curve() Note: `topology.plot()` now renders a Sankey diagram. The old PyVis visualization is available via `topology.plot_legacy()`. 
+### 🔥 Removed + +**Clustering classes removed** (deprecated in v5.0.0): + +- `ClusteredOptimization` class - Use `flow_system.transform.cluster()` then `optimize()` +- `ClusteringParameters` class - Parameters are now passed directly to `transform.cluster()` +- `flixopt/clustering.py` module - Restructured to `flixopt/clustering/` package with new classes + +#### Migration from ClusteredOptimization + +=== "v5.x (Old - No longer works)" + ```python + from flixopt import ClusteredOptimization, ClusteringParameters + + params = ClusteringParameters(hours_per_period=24, nr_of_periods=8) + calc = ClusteredOptimization('model', flow_system, params) + calc.do_modeling_and_solve(solver) + results = calc.results + ``` + +=== "v6.0.0 (New)" + ```python + # Cluster using transform accessor + fs_clustered = flow_system.transform.cluster( + n_clusters=8, # was: nr_of_periods + cluster_duration='1D', # was: hours_per_period=24 + ) + fs_clustered.optimize(solver) + + # Results on the clustered FlowSystem + costs = fs_clustered.solution['costs'].item() + + # Expand back to full resolution if needed + fs_expanded = fs_clustered.transform.expand_solution() + ``` + ### 🐛 Fixed - `temporal_weight` and `sum_temporal()` now use consistent implementation From c20f94f5cd9e83d384a6c83f0d2fef6b4ad48a95 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 23:50:17 +0100 Subject: [PATCH 31/62] FIx CHangelog and change to v6.0.0 --- CHANGELOG.md | 23 ++++++++++++++++++++--- flixopt/config.py | 2 +- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fe871469..1d99886c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -173,12 +173,29 @@ dataset.fxstats.to_duration_curve() ### 🗑️ Deprecated +The following items are deprecated and will be removed in **v7.0.0**: + +**Classes** (use FlowSystem methods instead): + +- `Optimization` class → Use `flow_system.optimize(solver)` +- `SegmentedOptimization` 
class → Use `flow_system.optimize.rolling_horizon()` +- `Results` class → Use `flow_system.solution` and `flow_system.statistics` +- `SegmentedResults` class → Use segment FlowSystems directly + +**FlowSystem methods** (use `transform` accessor instead): + +- `flow_system.sel()` → Use `flow_system.transform.sel()` +- `flow_system.isel()` → Use `flow_system.transform.isel()` +- `flow_system.resample()` → Use `flow_system.transform.resample()` + +**Parameters:** + - `normalize_weights` parameter in `create_model()`, `build_model()`, `optimize()` -**Topology method name simplifications** (old names still work with deprecation warnings, removal in v6.0.0): +**Topology method name simplifications** (old names still work with deprecation warnings, removal in v7.0.0): -| Old (v5.0.0) | New (v5.1.0) | -|--------------|--------------| +| Old (v5.x) | New (v6.0.0) | +|------------|--------------| | `topology.plot_network()` | `topology.plot()` | | `topology.start_network_app()` | `topology.start_app()` | | `topology.stop_network_app()` | `topology.stop_app()` | diff --git a/flixopt/config.py b/flixopt/config.py index 454f8ad3e..602652252 100644 --- a/flixopt/config.py +++ b/flixopt/config.py @@ -30,7 +30,7 @@ logging.addLevelName(SUCCESS_LEVEL, 'SUCCESS') # Deprecation removal version - update this when planning the next major version -DEPRECATION_REMOVAL_VERSION = '6.0.0' +DEPRECATION_REMOVAL_VERSION = '7.0.0' class MultilineFormatter(logging.Formatter): From c6c9a75ce506e270560e15bbb1d546e235c561dc Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sat, 3 Jan 2026 23:50:37 +0100 Subject: [PATCH 32/62] Fix Changelog and change to v6.0.0 --- CHANGELOG.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d99886c8..a09539d81 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -182,11 +182,15 @@ The following items are deprecated and will be removed in **v7.0.0**: - `Results` class → Use 
`flow_system.solution` and `flow_system.statistics` - `SegmentedResults` class → Use segment FlowSystems directly -**FlowSystem methods** (use `transform` accessor instead): +**FlowSystem methods** (use `transform` or `topology` accessor instead): - `flow_system.sel()` → Use `flow_system.transform.sel()` - `flow_system.isel()` → Use `flow_system.transform.isel()` - `flow_system.resample()` → Use `flow_system.transform.resample()` +- `flow_system.plot_network()` → Use `flow_system.topology.plot()` +- `flow_system.start_network_app()` → Use `flow_system.topology.start_app()` +- `flow_system.stop_network_app()` → Use `flow_system.topology.stop_app()` +- `flow_system.network_infos()` → Use `flow_system.topology.infos()` **Parameters:** From 3d8e6008ed25bcfc1c521fdf6fa734f56c7ce240 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 11:49:26 +0100 Subject: [PATCH 33/62] Enhanced Clustering Control New Parameters Added to cluster() Method | Parameter | Type | Default | Purpose | |-------------------------|-------------------------------|----------------------|--------------------------------------------------------------------------------------------------------------------| | cluster_method | Literal[...] | 'k_means' | Clustering algorithm ('k_means', 'hierarchical', 'k_medoids', 'k_maxoids', 'averaging') | | representation_method | Literal[...] | 'meanRepresentation' | How clusters are represented ('meanRepresentation', 'medoidRepresentation', 'distributionAndMinMaxRepresentation') | | extreme_period_method | Literal[...] 
| 'new_cluster_center' | How peaks are integrated ('None', 'append', 'new_cluster_center', 'replace_cluster_center') | | rescale_cluster_periods | bool | True | Rescale clusters to match original means | | random_state | int | None | None | Random seed for reproducibility | | predef_cluster_order | np.ndarray | list[int] | None | None | Manual clustering assignments | | **tsam_kwargs | Any | - | Pass-through for any tsam parameter | Clustering Quality Metrics Access via fs.clustering.metrics after clustering - returns a DataFrame with RMSE, MAE, and other accuracy indicators per time series. Files Modified 1. flixopt/transform_accessor.py - Updated cluster() signature and tsam call 2. flixopt/clustering/base.py - Added metrics field to Clustering class 3. tests/test_clustering/test_integration.py - Added tests for new parameters 4. docs/user-guide/optimization/clustering.md - Updated documentation --- docs/user-guide/optimization/clustering.md | 56 ++++++++++++++++++ flixopt/clustering/base.py | 2 + flixopt/transform_accessor.py | 46 +++++++++++++-- tests/test_clustering/test_integration.py | 68 ++++++++++++++++++++++ 4 files changed, 168 insertions(+), 4 deletions(-) diff --git a/docs/user-guide/optimization/clustering.md b/docs/user-guide/optimization/clustering.md index 7ec5faac1..aca1b3eaf 100644 --- a/docs/user-guide/optimization/clustering.md +++ b/docs/user-guide/optimization/clustering.md @@ -52,6 +52,10 @@ flow_rates = fs_expanded.solution['Boiler(Q_th)|flow_rate'] | `cluster_duration` | Duration of each cluster | `'1D'`, `'24h'`, or `24` (hours) | | `time_series_for_high_peaks` | Time series where peak clusters must be captured | `['HeatDemand(Q)|fixed_relative_profile']` | | `time_series_for_low_peaks` | Time series where minimum clusters must be captured | `['SolarGen(P)|fixed_relative_profile']` | +| `cluster_method` | Clustering algorithm | `'k_means'`, `'hierarchical'`, `'k_medoids'` | +| `representation_method` | How clusters are represented | 
`'meanRepresentation'`, `'medoidRepresentation'` | +| `random_state` | Random seed for reproducibility | `42` | +| `rescale_cluster_periods` | Rescale clusters to match original means | `True` (default) | ### Peak Selection @@ -68,6 +72,58 @@ fs_clustered = flow_system.transform.cluster( Without peak selection, the clustering algorithm might average out extreme days, leading to undersized equipment. +### Advanced Clustering Options + +Fine-tune the clustering algorithm with advanced parameters: + +```python +fs_clustered = flow_system.transform.cluster( + n_clusters=8, + cluster_duration='1D', + cluster_method='hierarchical', # Alternative to k_means + representation_method='medoidRepresentation', # Use actual periods, not averages + rescale_cluster_periods=True, # Match original time series means + random_state=42, # Reproducible results +) +``` + +**Available clustering algorithms** (`cluster_method`): + +| Method | Description | +|--------|-------------| +| `'k_means'` | Fast, good for most cases (default) | +| `'hierarchical'` | Produces consistent hierarchical groupings | +| `'k_medoids'` | Uses actual periods as representatives | +| `'k_maxoids'` | Maximizes representativeness | +| `'averaging'` | Simple averaging of similar periods | + +For advanced tsam parameters not exposed directly, use `**kwargs`: + +```python +# Pass any tsam.TimeSeriesAggregation parameter +fs_clustered = flow_system.transform.cluster( + n_clusters=8, + cluster_duration='1D', + sameMean=True, # Normalize all time series to same mean + sortValues=True, # Cluster by duration curves instead of shape +) +``` + +### Clustering Quality Metrics + +After clustering, access quality metrics to evaluate the aggregation accuracy: + +```python +fs_clustered = flow_system.transform.cluster(n_clusters=8, cluster_duration='1D') + +# Access clustering metrics +metrics = fs_clustered.clustering.metrics +print(metrics) + +# Metrics include RMSE, MAE per time series +# Use these to assess if more 
clusters are needed +``` + ## Storage Modes Storage behavior during clustering is controlled via the `cluster_mode` parameter: diff --git a/flixopt/clustering/base.py b/flixopt/clustering/base.py index 4b31832e4..34aba2ded 100644 --- a/flixopt/clustering/base.py +++ b/flixopt/clustering/base.py @@ -993,6 +993,7 @@ class Clustering: Attributes: result: The ClusterResult from the aggregation backend. backend_name: Name of the aggregation backend used (e.g., 'tsam', 'manual'). + metrics: Clustering quality metrics (RMSE, MAE, etc.) per time series. Example: >>> fs_clustered = flow_system.transform.cluster(n_clusters=8, cluster_duration='1D') @@ -1004,6 +1005,7 @@ class Clustering: result: ClusterResult backend_name: str = 'unknown' + metrics: pd.DataFrame | None = None def _create_reference_structure(self) -> tuple[dict, dict[str, xr.DataArray]]: """Create reference structure for serialization.""" diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index 3a13dbb63..010499f46 100644 --- a/flixopt/transform_accessor.py +++ b/flixopt/transform_accessor.py @@ -582,6 +582,17 @@ def cluster( weights: dict[str, float] | None = None, time_series_for_high_peaks: list[str] | None = None, time_series_for_low_peaks: list[str] | None = None, + cluster_method: Literal['k_means', 'k_medoids', 'hierarchical', 'k_maxoids', 'averaging'] = 'k_means', + representation_method: Literal[ + 'meanRepresentation', 'medoidRepresentation', 'distributionAndMinMaxRepresentation' + ] = 'meanRepresentation', + extreme_period_method: Literal[ + 'None', 'append', 'new_cluster_center', 'replace_cluster_center' + ] = 'new_cluster_center', + rescale_cluster_periods: bool = True, + random_state: int | None = None, + predef_cluster_order: np.ndarray | list[int] | None = None, + **tsam_kwargs: Any, ) -> FlowSystem: """ Create a FlowSystem with reduced timesteps using typical clusters. 
@@ -607,6 +618,24 @@ def cluster( time_series_for_high_peaks: Time series labels for explicitly selecting high-value clusters. **Recommended** for demand time series to capture peak demand days. time_series_for_low_peaks: Time series labels for explicitly selecting low-value clusters. + cluster_method: Clustering algorithm to use. Options: + ``'k_means'`` (default), ``'k_medoids'``, ``'hierarchical'``, + ``'k_maxoids'``, ``'averaging'``. + representation_method: How cluster representatives are computed. Options: + ``'meanRepresentation'`` (default), ``'medoidRepresentation'``, + ``'distributionAndMinMaxRepresentation'``. + extreme_period_method: How extreme periods (peaks) are integrated. Options: + ``'new_cluster_center'`` (default), ``'None'``, ``'append'``, + ``'replace_cluster_center'``. + rescale_cluster_periods: If True (default), rescale cluster periods so their + weighted mean matches the original time series mean. + random_state: Random seed for reproducible clustering results. If None, + results may vary between runs. + predef_cluster_order: Predefined cluster assignments for manual clustering. + Array of cluster indices (0 to n_clusters-1) for each original period. + If provided, clustering is skipped and these assignments are used directly. + **tsam_kwargs: Additional keyword arguments passed to + ``tsam.TimeSeriesAggregation``. See tsam documentation for all options. Returns: A new FlowSystem with reduced timesteps (only typical clusters). 
@@ -680,7 +709,10 @@ def cluster( tsam_results: dict[tuple, tsam.TimeSeriesAggregation] = {} cluster_orders: dict[tuple, np.ndarray] = {} cluster_occurrences_all: dict[tuple, dict] = {} - use_extreme_periods = bool(time_series_for_high_peaks or time_series_for_low_peaks) + + # Set random seed for reproducibility + if random_state is not None: + np.random.seed(random_state) for period_label in periods: for scenario_label in scenarios: @@ -700,11 +732,15 @@ def cluster( noTypicalPeriods=n_clusters, hoursPerPeriod=hours_per_cluster, resolution=dt, - clusterMethod='k_means', - extremePeriodMethod='new_cluster_center' if use_extreme_periods else 'None', + clusterMethod=cluster_method, + extremePeriodMethod=extreme_period_method, + representationMethod=representation_method, + rescaleClusterPeriods=rescale_cluster_periods, + predefClusterOrder=predef_cluster_order, weightDict={name: w for name, w in clustering_weights.items() if name in df.columns}, addPeakMax=time_series_for_high_peaks or [], addPeakMin=time_series_for_low_peaks or [], + **tsam_kwargs, ) # Suppress tsam warning about minimal value constraints (informational, not actionable) with warnings.catch_warnings(): @@ -715,9 +751,10 @@ def cluster( cluster_orders[key] = tsam_agg.clusterOrder cluster_occurrences_all[key] = tsam_agg.clusterPeriodNoOccur - # Use first result for structure + # Use first result for structure and metrics first_key = (periods[0], scenarios[0]) first_tsam = tsam_results[first_key] + clustering_metrics = first_tsam.accuracyIndicators() n_reduced_timesteps = len(first_tsam.typicalPeriods) actual_n_clusters = len(first_tsam.clusterPeriodNoOccur) @@ -932,6 +969,7 @@ def _build_cluster_weights_for_key(key: tuple) -> xr.DataArray: reduced_fs.clustering = Clustering( result=aggregation_result, backend_name='tsam', + metrics=clustering_metrics, ) return reduced_fs diff --git a/tests/test_clustering/test_integration.py b/tests/test_clustering/test_integration.py index 2d04a51c1..2bcd0b022 100644 
--- a/tests/test_clustering/test_integration.py +++ b/tests/test_clustering/test_integration.py @@ -170,6 +170,74 @@ def test_cluster_reduces_timesteps(self): assert len(fs_clustered.timesteps) * len(fs_clustered.clusters) == 48 +class TestClusterAdvancedOptions: + """Tests for advanced clustering options.""" + + @pytest.fixture + def basic_flow_system(self): + """Create a basic FlowSystem for testing.""" + pytest.importorskip('tsam') + from flixopt import Bus, Flow, Sink, Source + from flixopt.core import TimeSeriesData + + n_hours = 168 # 7 days + fs = FlowSystem(timesteps=pd.date_range('2024-01-01', periods=n_hours, freq='h')) + + demand_data = np.sin(np.linspace(0, 14 * np.pi, n_hours)) + 2 + bus = Bus('electricity') + grid_flow = Flow('grid_in', bus='electricity', size=100) + demand_flow = Flow( + 'demand_out', bus='electricity', size=100, fixed_relative_profile=TimeSeriesData(demand_data / 100) + ) + source = Source('grid', outputs=[grid_flow]) + sink = Sink('demand', inputs=[demand_flow]) + fs.add_elements(source, sink, bus) + return fs + + def test_cluster_method_parameter(self, basic_flow_system): + """Test that cluster_method parameter works.""" + fs_clustered = basic_flow_system.transform.cluster( + n_clusters=2, cluster_duration='1D', cluster_method='hierarchical' + ) + assert len(fs_clustered.clusters) == 2 + + def test_random_state_reproducibility(self, basic_flow_system): + """Test that random_state produces reproducible results.""" + fs1 = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D', random_state=42) + fs2 = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D', random_state=42) + + # Same random state should produce identical cluster orders + xr.testing.assert_equal(fs1.clustering.cluster_order, fs2.clustering.cluster_order) + + def test_metrics_available(self, basic_flow_system): + """Test that clustering metrics are available after clustering.""" + fs_clustered = 
basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D') + + assert fs_clustered.clustering.metrics is not None + assert isinstance(fs_clustered.clustering.metrics, pd.DataFrame) + assert len(fs_clustered.clustering.metrics) > 0 + + def test_representation_method_parameter(self, basic_flow_system): + """Test that representation_method parameter works.""" + fs_clustered = basic_flow_system.transform.cluster( + n_clusters=2, cluster_duration='1D', representation_method='medoidRepresentation' + ) + assert len(fs_clustered.clusters) == 2 + + def test_rescale_cluster_periods_parameter(self, basic_flow_system): + """Test that rescale_cluster_periods parameter works.""" + fs_clustered = basic_flow_system.transform.cluster( + n_clusters=2, cluster_duration='1D', rescale_cluster_periods=False + ) + assert len(fs_clustered.clusters) == 2 + + def test_tsam_kwargs_passthrough(self, basic_flow_system): + """Test that additional kwargs are passed to tsam.""" + # sameMean is a valid tsam parameter + fs_clustered = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D', sameMean=True) + assert len(fs_clustered.clusters) == 2 + + class TestClusteringModuleImports: """Tests for flixopt.clustering module imports.""" From 0abdb002036c896d0b4d85b5d72702d5869c7da7 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 12:17:57 +0100 Subject: [PATCH 34/62] =?UTF-8?q?=20=20Dimension=20renamed:=20original=5Fp?= =?UTF-8?q?eriod=20=E2=86=92=20original=5Fcluster=20=20=20Property=20renam?= =?UTF-8?q?ed:=20n=5Foriginal=5Fperiods=20=E2=86=92=20n=5Foriginal=5Fclust?= =?UTF-8?q?ers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/user-guide/optimization/clustering.md | 8 +-- flixopt/clustering/base.py | 54 ++++++++------- flixopt/clustering/intercluster_helpers.py | 6 +- flixopt/components.py | 45 ++++++------ flixopt/transform_accessor.py | 81 
+++++++++++++++++----- tests/test_cluster_reduce_expand.py | 6 +- tests/test_clustering/test_base.py | 10 +-- tests/test_clustering/test_integration.py | 34 ++++++++- 8 files changed, 164 insertions(+), 80 deletions(-) diff --git a/docs/user-guide/optimization/clustering.md b/docs/user-guide/optimization/clustering.md index aca1b3eaf..793fbf8fe 100644 --- a/docs/user-guide/optimization/clustering.md +++ b/docs/user-guide/optimization/clustering.md @@ -116,12 +116,12 @@ After clustering, access quality metrics to evaluate the aggregation accuracy: ```python fs_clustered = flow_system.transform.cluster(n_clusters=8, cluster_duration='1D') -# Access clustering metrics +# Access clustering metrics (xr.Dataset) metrics = fs_clustered.clustering.metrics -print(metrics) +print(metrics) # Shows RMSE, MAE, etc. per time series -# Metrics include RMSE, MAE per time series -# Use these to assess if more clusters are needed +# Access specific metric +rmse = metrics['RMSE'] # xr.DataArray with dims [time_series, period?, scenario?] ``` ## Storage Modes diff --git a/flixopt/clustering/base.py b/flixopt/clustering/base.py index 34aba2ded..9c900593a 100644 --- a/flixopt/clustering/base.py +++ b/flixopt/clustering/base.py @@ -38,15 +38,15 @@ class ClusterStructure: which is needed for proper storage state-of-charge tracking across typical periods when using cluster(). - Note: "original_period" here refers to the original time chunks before - clustering (e.g., 365 original days), NOT the model's "period" dimension - (years/months). Each original time chunk gets assigned to a cluster. + Note: The "original_cluster" dimension indexes the original cluster-sized + time segments (e.g., 0..364 for 365 days), NOT the model's "period" dimension + (years). Each original segment gets assigned to a representative cluster. Attributes: - cluster_order: Maps each original time chunk index to its cluster ID. 
- dims: [original_period] for simple case, or - [original_period, period, scenario] for multi-period/scenario systems. - Values are cluster indices (0 to n_clusters-1). + cluster_order: Maps original cluster index → representative cluster ID. + dims: [original_cluster] for simple case, or + [original_cluster, period, scenario] for multi-period/scenario systems. + Values are cluster IDs (0 to n_clusters-1). cluster_occurrences: Count of how many original time chunks each cluster represents. dims: [cluster] for simple case, or [cluster, period, scenario] for multi-dim. n_clusters: Number of distinct clusters (typical periods). @@ -60,7 +60,7 @@ class ClusterStructure: - timesteps_per_cluster: 24 (for hourly data) For multi-scenario (e.g., 2 scenarios): - - cluster_order: shape (365, 2) with dims [original_period, scenario] + - cluster_order: shape (365, 2) with dims [original_cluster, scenario] - cluster_occurrences: shape (8, 2) with dims [cluster, scenario] """ @@ -73,7 +73,7 @@ def __post_init__(self): """Validate and ensure proper DataArray formatting.""" # Ensure cluster_order is a DataArray with proper dims if not isinstance(self.cluster_order, xr.DataArray): - self.cluster_order = xr.DataArray(self.cluster_order, dims=['original_period'], name='cluster_order') + self.cluster_order = xr.DataArray(self.cluster_order, dims=['original_cluster'], name='cluster_order') elif self.cluster_order.name is None: self.cluster_order = self.cluster_order.rename('cluster_order') @@ -92,7 +92,7 @@ def __repr__(self) -> str: occ = [int(self.cluster_occurrences.sel(cluster=c).values) for c in range(n_clusters)] return ( f'ClusterStructure(\n' - f' {self.n_original_periods} original periods → {n_clusters} clusters\n' + f' {self.n_original_clusters} original periods → {n_clusters} clusters\n' f' timesteps_per_cluster={self.timesteps_per_cluster}\n' f' occurrences={occ}\n' f')' @@ -124,9 +124,9 @@ def _create_reference_structure(self) -> tuple[dict, dict[str, xr.DataArray]]: return 
ref, arrays @property - def n_original_periods(self) -> int: + def n_original_clusters(self) -> int: """Number of original periods (before clustering).""" - return len(self.cluster_order.coords['original_period']) + return len(self.cluster_order.coords['original_cluster']) @property def has_multi_dims(self) -> bool: @@ -236,7 +236,7 @@ def plot(self, show: bool | None = None) -> PlotResult: y=[1] * len(df), color='Cluster', color_continuous_scale='Viridis', - title=f'Cluster Assignment ({self.n_original_periods} periods → {n_clusters} clusters)', + title=f'Cluster Assignment ({self.n_original_clusters} periods → {n_clusters} clusters)', ) fig.update_layout(yaxis_visible=False, coloraxis_colorbar_title='Cluster') @@ -532,30 +532,30 @@ def validate(self) -> None: # (each weight is how many original periods that cluster represents) # Sum should be checked per period/scenario slice, not across all dimensions if self.cluster_structure is not None: - n_original_periods = self.cluster_structure.n_original_periods + n_original_clusters = self.cluster_structure.n_original_clusters # Sum over cluster dimension only (keep period/scenario if present) weight_sum_per_slice = self.representative_weights.sum(dim='cluster') # Check each slice if weight_sum_per_slice.size == 1: # Simple case: no period/scenario weight_sum = float(weight_sum_per_slice.values) - if abs(weight_sum - n_original_periods) > 1e-6: + if abs(weight_sum - n_original_clusters) > 1e-6: import warnings warnings.warn( f'representative_weights sum ({weight_sum}) does not match ' - f'n_original_periods ({n_original_periods})', + f'n_original_clusters ({n_original_clusters})', stacklevel=2, ) else: # Multi-dimensional: check each slice for val in weight_sum_per_slice.values.flat: - if abs(float(val) - n_original_periods) > 1e-6: + if abs(float(val) - n_original_clusters) > 1e-6: import warnings warnings.warn( f'representative_weights sum per slice ({float(val)}) does not match ' - f'n_original_periods 
({n_original_periods})', + f'n_original_clusters ({n_original_clusters})', stacklevel=2, ) break # Only warn once @@ -993,7 +993,9 @@ class Clustering: Attributes: result: The ClusterResult from the aggregation backend. backend_name: Name of the aggregation backend used (e.g., 'tsam', 'manual'). - metrics: Clustering quality metrics (RMSE, MAE, etc.) per time series. + metrics: Clustering quality metrics (RMSE, MAE, etc.) as xr.Dataset. + Each metric (e.g., 'RMSE', 'MAE') is a DataArray with dims + ``[time_series, period?, scenario?]``. Example: >>> fs_clustered = flow_system.transform.cluster(n_clusters=8, cluster_duration='1D') @@ -1005,7 +1007,7 @@ class Clustering: result: ClusterResult backend_name: str = 'unknown' - metrics: pd.DataFrame | None = None + metrics: xr.Dataset | None = None def _create_reference_structure(self) -> tuple[dict, dict[str, xr.DataArray]]: """Create reference structure for serialization.""" @@ -1028,7 +1030,7 @@ def __repr__(self) -> str: n_clusters = ( int(cs.n_clusters) if isinstance(cs.n_clusters, (int, np.integer)) else int(cs.n_clusters.values) ) - structure_info = f'{cs.n_original_periods} periods → {n_clusters} clusters' + structure_info = f'{cs.n_original_clusters} periods → {n_clusters} clusters' else: structure_info = 'no structure' return f'Clustering(\n backend={self.backend_name!r}\n {structure_info}\n)' @@ -1073,11 +1075,11 @@ def n_clusters(self) -> int: return int(n) if isinstance(n, (int, np.integer)) else int(n.values) @property - def n_original_periods(self) -> int: + def n_original_clusters(self) -> int: """Number of original periods (before clustering).""" if self.result.cluster_structure is None: raise ValueError('No cluster_structure available') - return self.result.cluster_structure.n_original_periods + return self.result.cluster_structure.n_original_clusters @property def timesteps_per_period(self) -> int: @@ -1154,17 +1156,17 @@ def create_cluster_structure_from_mapping( ClusterStructure derived from the 
mapping. """ n_original = len(timestep_mapping) - n_original_periods = n_original // timesteps_per_cluster + n_original_clusters = n_original // timesteps_per_cluster # Determine cluster order from the mapping # Each original period maps to the cluster of its first timestep cluster_order = [] - for p in range(n_original_periods): + for p in range(n_original_clusters): start_idx = p * timesteps_per_cluster cluster_idx = int(timestep_mapping.isel(original_time=start_idx).values) // timesteps_per_cluster cluster_order.append(cluster_idx) - cluster_order_da = xr.DataArray(cluster_order, dims=['original_period'], name='cluster_order') + cluster_order_da = xr.DataArray(cluster_order, dims=['original_cluster'], name='cluster_order') # Count occurrences of each cluster unique_clusters = np.unique(cluster_order) diff --git a/flixopt/clustering/intercluster_helpers.py b/flixopt/clustering/intercluster_helpers.py index d2a5eb9d3..a89a80862 100644 --- a/flixopt/clustering/intercluster_helpers.py +++ b/flixopt/clustering/intercluster_helpers.py @@ -132,7 +132,7 @@ def extract_capacity_bounds( def build_boundary_coords( - n_original_periods: int, + n_original_clusters: int, flow_system: FlowSystem, ) -> tuple[dict, list[str]]: """Build coordinates and dimensions for SOC_boundary variable. @@ -146,7 +146,7 @@ def build_boundary_coords( multi-period or stochastic optimizations. Args: - n_original_periods: Number of original (non-aggregated) time periods. + n_original_clusters: Number of original (non-aggregated) time periods. For example, if a year is clustered into 8 typical days but originally had 365 days, this would be 365. flow_system: The FlowSystem containing optional period/scenario dimensions. 
@@ -163,7 +163,7 @@ def build_boundary_coords( >>> coords['cluster_boundary'] array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]) """ - n_boundaries = n_original_periods + 1 + n_boundaries = n_original_clusters + 1 coords = {'cluster_boundary': np.arange(n_boundaries)} dims = ['cluster_boundary'] diff --git a/flixopt/components.py b/flixopt/components.py index 390fc6f02..e962791d8 100644 --- a/flixopt/components.py +++ b/flixopt/components.py @@ -1195,7 +1195,7 @@ class InterclusterStorageModel(StorageModel): Variables Created ----------------- - ``SOC_boundary``: Absolute SOC at each original period boundary. - Shape: (n_original_periods + 1,) plus any period/scenario dimensions. + Shape: (n_original_clusters + 1,) plus any period/scenario dimensions. Constraints Created ------------------- @@ -1330,7 +1330,7 @@ def _add_intercluster_linking(self) -> None: else int(cluster_structure.n_clusters.values) ) timesteps_per_cluster = cluster_structure.timesteps_per_cluster - n_original_periods = cluster_structure.n_original_periods + n_original_clusters = cluster_structure.n_original_clusters cluster_order = cluster_structure.cluster_order # 1. Constrain ΔE = 0 at cluster starts @@ -1338,7 +1338,7 @@ def _add_intercluster_linking(self) -> None: # 2. Create SOC_boundary variable flow_system = self._model.flow_system - boundary_coords, boundary_dims = build_boundary_coords(n_original_periods, flow_system) + boundary_coords, boundary_dims = build_boundary_coords(n_original_clusters, flow_system) capacity_bounds = extract_capacity_bounds(self.element.capacity_in_flow_hours, boundary_coords, boundary_dims) soc_boundary = self.add_variables( @@ -1360,12 +1360,14 @@ def _add_intercluster_linking(self) -> None: delta_soc = self._compute_delta_soc(n_clusters, timesteps_per_cluster) # 5. 
Add linking constraints - self._add_linking_constraints(soc_boundary, delta_soc, cluster_order, n_original_periods, timesteps_per_cluster) + self._add_linking_constraints( + soc_boundary, delta_soc, cluster_order, n_original_clusters, timesteps_per_cluster + ) # 6. Add cyclic or initial constraint if self.element.cluster_mode == 'intercluster_cyclic': self.add_constraints( - soc_boundary.isel(cluster_boundary=0) == soc_boundary.isel(cluster_boundary=n_original_periods), + soc_boundary.isel(cluster_boundary=0) == soc_boundary.isel(cluster_boundary=n_original_clusters), short_name='cyclic', ) else: @@ -1375,7 +1377,8 @@ def _add_intercluster_linking(self) -> None: if isinstance(initial, str): # 'equals_final' means cyclic self.add_constraints( - soc_boundary.isel(cluster_boundary=0) == soc_boundary.isel(cluster_boundary=n_original_periods), + soc_boundary.isel(cluster_boundary=0) + == soc_boundary.isel(cluster_boundary=n_original_clusters), short_name='initial_SOC_boundary', ) else: @@ -1389,7 +1392,7 @@ def _add_intercluster_linking(self) -> None: soc_boundary, cluster_order, capacity_bounds.has_investment, - n_original_periods, + n_original_clusters, timesteps_per_cluster, ) @@ -1438,7 +1441,7 @@ def _add_linking_constraints( soc_boundary: xr.DataArray, delta_soc: xr.DataArray, cluster_order: xr.DataArray, - n_original_periods: int, + n_original_clusters: int, timesteps_per_cluster: int, ) -> None: """Add constraints linking consecutive SOC_boundary values. @@ -1455,17 +1458,17 @@ def _add_linking_constraints( soc_boundary: SOC_boundary variable. delta_soc: Net SOC change per cluster. cluster_order: Mapping from original periods to representative clusters. - n_original_periods: Number of original (non-clustered) periods. + n_original_clusters: Number of original (non-clustered) periods. timesteps_per_cluster: Number of timesteps in each cluster period. 
""" soc_after = soc_boundary.isel(cluster_boundary=slice(1, None)) soc_before = soc_boundary.isel(cluster_boundary=slice(None, -1)) # Rename for alignment - soc_after = soc_after.rename({'cluster_boundary': 'original_period'}) - soc_after = soc_after.assign_coords(original_period=np.arange(n_original_periods)) - soc_before = soc_before.rename({'cluster_boundary': 'original_period'}) - soc_before = soc_before.assign_coords(original_period=np.arange(n_original_periods)) + soc_after = soc_after.rename({'cluster_boundary': 'original_cluster'}) + soc_after = soc_after.assign_coords(original_cluster=np.arange(n_original_clusters)) + soc_before = soc_before.rename({'cluster_boundary': 'original_cluster'}) + soc_before = soc_before.assign_coords(original_cluster=np.arange(n_original_clusters)) # Get delta_soc for each original period using cluster_order delta_soc_ordered = delta_soc.isel(cluster=cluster_order) @@ -1484,7 +1487,7 @@ def _add_combined_bound_constraints( soc_boundary: xr.DataArray, cluster_order: xr.DataArray, has_investment: bool, - n_original_periods: int, + n_original_clusters: int, timesteps_per_cluster: int, ) -> None: """Add constraints ensuring actual SOC stays within bounds. @@ -1498,21 +1501,21 @@ def _add_combined_bound_constraints( middle, and end of each cluster. With 2D (cluster, time) structure, we simply select charge_state at a - given time offset, then reorder by cluster_order to get original_period order. + given time offset, then reorder by cluster_order to get original_cluster order. Args: soc_boundary: SOC_boundary variable. cluster_order: Mapping from original periods to clusters. has_investment: Whether the storage has investment sizing. - n_original_periods: Number of original periods. + n_original_clusters: Number of original periods. timesteps_per_cluster: Timesteps in each cluster. 
""" charge_state = self.charge_state # soc_d: SOC at start of each original period soc_d = soc_boundary.isel(cluster_boundary=slice(None, -1)) - soc_d = soc_d.rename({'cluster_boundary': 'original_period'}) - soc_d = soc_d.assign_coords(original_period=np.arange(n_original_periods)) + soc_d = soc_d.rename({'cluster_boundary': 'original_cluster'}) + soc_d = soc_d.assign_coords(original_cluster=np.arange(n_original_clusters)) # Get self-discharge rate for decay calculation # Keep as DataArray to respect per-period/scenario values @@ -1523,13 +1526,13 @@ def _add_combined_bound_constraints( for sample_name, offset in zip(['start', 'mid', 'end'], sample_offsets, strict=False): # With 2D structure: select time offset, then reorder by cluster_order cs_at_offset = charge_state.isel(time=offset) # Shape: (cluster, ...) - # Reorder to original_period order using cluster_order indexer + # Reorder to original_cluster order using cluster_order indexer cs_t = cs_at_offset.isel(cluster=cluster_order) # Suppress xarray warning about index loss - we immediately assign new coords anyway with warnings.catch_warnings(): warnings.filterwarnings('ignore', message='.*does not create an index anymore.*') - cs_t = cs_t.rename({'cluster': 'original_period'}) - cs_t = cs_t.assign_coords(original_period=np.arange(n_original_periods)) + cs_t = cs_t.rename({'cluster': 'original_cluster'}) + cs_t = cs_t.assign_coords(original_cluster=np.arange(n_original_clusters)) # Apply decay factor (1-loss)^t to SOC_boundary per Eq. 
9 decay_t = (1 - rel_loss) ** offset diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index 010499f46..51fcb6f6f 100644 --- a/flixopt/transform_accessor.py +++ b/flixopt/transform_accessor.py @@ -591,7 +591,7 @@ def cluster( ] = 'new_cluster_center', rescale_cluster_periods: bool = True, random_state: int | None = None, - predef_cluster_order: np.ndarray | list[int] | None = None, + predef_cluster_order: xr.DataArray | np.ndarray | list[int] | None = None, **tsam_kwargs: Any, ) -> FlowSystem: """ @@ -634,6 +634,9 @@ def cluster( predef_cluster_order: Predefined cluster assignments for manual clustering. Array of cluster indices (0 to n_clusters-1) for each original period. If provided, clustering is skipped and these assignments are used directly. + For multi-dimensional FlowSystems, use an xr.DataArray with dims + ``[original_cluster, period?, scenario?]`` to specify different assignments + per period/scenario combination. **tsam_kwargs: Additional keyword arguments passed to ``tsam.TimeSeriesAggregation``. See tsam documentation for all options. 
@@ -714,6 +717,9 @@ def cluster( if random_state is not None: np.random.seed(random_state) + # Collect metrics per (period, scenario) slice + clustering_metrics_all: dict[tuple, pd.DataFrame] = {} + for period_label in periods: for scenario_label in scenarios: key = (period_label, scenario_label) @@ -725,6 +731,16 @@ def cluster( if selector: logger.info(f'Clustering {", ".join(f"{k}={v}" for k, v in selector.items())}...') + # Handle predef_cluster_order for multi-dimensional case + predef_order_slice = None + if predef_cluster_order is not None: + if isinstance(predef_cluster_order, xr.DataArray): + # Extract slice for this (period, scenario) combination + predef_order_slice = predef_cluster_order.sel(**selector, drop=True).values + else: + # Simple array/list - use directly + predef_order_slice = predef_cluster_order + # Use tsam directly clustering_weights = weights or self._calculate_clustering_weights(temporaly_changing_ds) tsam_agg = tsam.TimeSeriesAggregation( @@ -736,7 +752,7 @@ def cluster( extremePeriodMethod=extreme_period_method, representationMethod=representation_method, rescaleClusterPeriods=rescale_cluster_periods, - predefClusterOrder=predef_cluster_order, + predefClusterOrder=predef_order_slice, weightDict={name: w for name, w in clustering_weights.items() if name in df.columns}, addPeakMax=time_series_for_high_peaks or [], addPeakMin=time_series_for_low_peaks or [], @@ -750,11 +766,44 @@ def cluster( tsam_results[key] = tsam_agg cluster_orders[key] = tsam_agg.clusterOrder cluster_occurrences_all[key] = tsam_agg.clusterPeriodNoOccur + clustering_metrics_all[key] = tsam_agg.accuracyIndicators() - # Use first result for structure and metrics + # Use first result for structure first_key = (periods[0], scenarios[0]) first_tsam = tsam_results[first_key] - clustering_metrics = first_tsam.accuracyIndicators() + + # Convert metrics to xr.Dataset with period/scenario dims if multi-dimensional + if len(clustering_metrics_all) == 1: + # Simple case: convert 
single DataFrame to Dataset + metrics_df = clustering_metrics_all[first_key] + clustering_metrics = xr.Dataset( + { + col: xr.DataArray( + metrics_df[col].values, dims=['time_series'], coords={'time_series': metrics_df.index} + ) + for col in metrics_df.columns + } + ) + else: + # Multi-dim case: combine metrics into Dataset with period/scenario dims + # First, get the metric columns from any DataFrame + sample_df = next(iter(clustering_metrics_all.values())) + metric_names = list(sample_df.columns) + time_series_names = list(sample_df.index) + + # Build DataArrays for each metric + data_vars = {} + for metric in metric_names: + # Shape: (time_series, period?, scenario?) + slices = {} + for (p, s), df in clustering_metrics_all.items(): + slices[(p, s)] = xr.DataArray(df[metric].values, dims=['time_series']) + + da = self._combine_slices_to_dataarray_generic(slices, ['time_series'], periods, scenarios, metric) + da = da.assign_coords(time_series=time_series_names) + data_vars[metric] = da + + clustering_metrics = xr.Dataset(data_vars) n_reduced_timesteps = len(first_tsam.typicalPeriods) actual_n_clusters = len(first_tsam.clusterPeriodNoOccur) @@ -888,7 +937,7 @@ def _build_cluster_occurrences_for_key(key: tuple) -> np.ndarray: # Build multi-dimensional arrays if has_periods or has_scenarios: # Multi-dimensional case: build arrays for each (period, scenario) combination - # cluster_order: dims [original_period, period?, scenario?] + # cluster_order: dims [original_cluster, period?, scenario?] 
cluster_order_slices = {} timestep_mapping_slices = {} cluster_occurrences_slices = {} @@ -900,7 +949,7 @@ def _build_cluster_occurrences_for_key(key: tuple) -> np.ndarray: for s in scenarios: key = (p, s) cluster_order_slices[key] = xr.DataArray( - cluster_orders[key], dims=['original_period'], name='cluster_order' + cluster_orders[key], dims=['original_cluster'], name='cluster_order' ) timestep_mapping_slices[key] = xr.DataArray( _build_timestep_mapping_for_key(key), @@ -914,7 +963,7 @@ def _build_cluster_occurrences_for_key(key: tuple) -> np.ndarray: # Combine slices into multi-dimensional DataArrays cluster_order_da = self._combine_slices_to_dataarray_generic( - cluster_order_slices, ['original_period'], periods, scenarios, 'cluster_order' + cluster_order_slices, ['original_cluster'], periods, scenarios, 'cluster_order' ) timestep_mapping_da = self._combine_slices_to_dataarray_generic( timestep_mapping_slices, ['original_time'], periods, scenarios, 'timestep_mapping' @@ -924,7 +973,7 @@ def _build_cluster_occurrences_for_key(key: tuple) -> np.ndarray: ) else: # Simple case: single (None, None) slice - cluster_order_da = xr.DataArray(cluster_orders[first_key], dims=['original_period'], name='cluster_order') + cluster_order_da = xr.DataArray(cluster_orders[first_key], dims=['original_cluster'], name='cluster_order') # Use renamed timesteps as coordinates original_timesteps_coord = self._fs.timesteps.rename('original_time') timestep_mapping_da = xr.DataArray( @@ -1034,7 +1083,7 @@ def _combine_slices_to_dataarray_generic( Args: slices: Dict mapping (period, scenario) tuples to DataArrays. - base_dims: Base dimensions of each slice (e.g., ['original_period'] or ['original_time']). + base_dims: Base dimensions of each slice (e.g., ['original_cluster'] or ['original_time']). periods: List of period labels ([None] if no periods dimension). scenarios: List of scenario labels ([None] if no scenarios dimension). name: Name for the resulting DataArray. 
@@ -1123,7 +1172,7 @@ def expand_solution(self) -> FlowSystem: disaggregates the FlowSystem by: 1. Expanding all time series data from typical clusters to full timesteps 2. Expanding the solution by mapping each typical cluster back to all - original segments it represents + original clusters it represents For FlowSystems with periods and/or scenarios, each (period, scenario) combination is expanded using its own cluster assignment. @@ -1159,7 +1208,7 @@ def expand_solution(self) -> FlowSystem: Note: The expanded FlowSystem repeats the typical cluster values for all - segments belonging to the same cluster. Both input data and solution + original clusters belonging to the same cluster. Both input data and solution are consistently expanded, so they match. This is an approximation - the actual dispatch at full resolution would differ due to intra-cluster variations in time series data. @@ -1261,12 +1310,12 @@ def expand_da(da: xr.DataArray) -> xr.DataArray: expanded_charge_state = expanded_fs._solution[charge_state_name] # Map each original timestep to its original period index - original_period_indices = np.arange(n_original_timesteps) // timesteps_per_cluster + original_cluster_indices = np.arange(n_original_timesteps) // timesteps_per_cluster # Select SOC_boundary for each timestep (boundary[d] for period d) - # SOC_boundary has dim 'cluster_boundary', we select indices 0..n_original_periods-1 + # SOC_boundary has dim 'cluster_boundary', we select indices 0..n_original_clusters-1 soc_boundary_per_timestep = soc_boundary.isel( - cluster_boundary=xr.DataArray(original_period_indices, dims=['time']) + cluster_boundary=xr.DataArray(original_cluster_indices, dims=['time']) ) soc_boundary_per_timestep = soc_boundary_per_timestep.assign_coords(time=original_timesteps) @@ -1293,14 +1342,14 @@ def expand_da(da: xr.DataArray) -> xr.DataArray: expanded_fs._solution[charge_state_name] = combined_charge_state.assign_attrs(expanded_charge_state.attrs) n_combinations = 
len(periods) * len(scenarios) - n_original_segments = cluster_structure.n_original_periods + n_original_clusters = cluster_structure.n_original_clusters logger.info( f'Expanded FlowSystem from {n_reduced_timesteps} to {n_original_timesteps} timesteps ' f'({n_clusters} clusters' + ( f', {n_combinations} period/scenario combinations)' if n_combinations > 1 - else f' → {n_original_segments} original segments)' + else f' → {n_original_clusters} original clusters)' ) ) diff --git a/tests/test_cluster_reduce_expand.py b/tests/test_cluster_reduce_expand.py index 7072fe22e..b54eeb56e 100644 --- a/tests/test_cluster_reduce_expand.py +++ b/tests/test_cluster_reduce_expand.py @@ -449,9 +449,9 @@ def test_storage_cluster_mode_intercluster(self, solver_fixture, timesteps_8_day soc_boundary = fs_clustered.solution['Battery|SOC_boundary'] assert 'cluster_boundary' in soc_boundary.dims - # Number of boundaries = n_original_periods + 1 - n_original_periods = fs_clustered.clustering.result.cluster_structure.n_original_periods - assert soc_boundary.sizes['cluster_boundary'] == n_original_periods + 1 + # Number of boundaries = n_original_clusters + 1 + n_original_clusters = fs_clustered.clustering.result.cluster_structure.n_original_clusters + assert soc_boundary.sizes['cluster_boundary'] == n_original_clusters + 1 def test_storage_cluster_mode_intercluster_cyclic(self, solver_fixture, timesteps_8_days): """Storage with cluster_mode='intercluster_cyclic' - linked with yearly cycling.""" diff --git a/tests/test_clustering/test_base.py b/tests/test_clustering/test_base.py index 9c63f25f6..9cca4de81 100644 --- a/tests/test_clustering/test_base.py +++ b/tests/test_clustering/test_base.py @@ -17,7 +17,7 @@ class TestClusterStructure: def test_basic_creation(self): """Test basic ClusterStructure creation.""" - cluster_order = xr.DataArray([0, 1, 0, 1, 2, 0], dims=['original_period']) + cluster_order = xr.DataArray([0, 1, 0, 1, 2, 0], dims=['original_cluster']) cluster_occurrences = 
xr.DataArray([3, 2, 1], dims=['cluster']) structure = ClusterStructure( @@ -29,7 +29,7 @@ def test_basic_creation(self): assert structure.n_clusters == 3 assert structure.timesteps_per_cluster == 24 - assert structure.n_original_periods == 6 + assert structure.n_original_clusters == 6 def test_creation_from_numpy(self): """Test ClusterStructure creation from numpy arrays.""" @@ -42,12 +42,12 @@ def test_creation_from_numpy(self): assert isinstance(structure.cluster_order, xr.DataArray) assert isinstance(structure.cluster_occurrences, xr.DataArray) - assert structure.n_original_periods == 5 + assert structure.n_original_clusters == 5 def test_get_cluster_weight_per_timestep(self): """Test weight calculation per timestep.""" structure = ClusterStructure( - cluster_order=xr.DataArray([0, 1, 0], dims=['original_period']), + cluster_order=xr.DataArray([0, 1, 0], dims=['original_cluster']), cluster_occurrences=xr.DataArray([2, 1], dims=['cluster']), n_clusters=2, timesteps_per_cluster=4, @@ -136,7 +136,7 @@ def test_basic_creation(self): structure = create_cluster_structure_from_mapping(mapping, timesteps_per_cluster=4) assert structure.timesteps_per_cluster == 4 - assert structure.n_original_periods == 3 + assert structure.n_original_clusters == 3 class TestClustering: diff --git a/tests/test_clustering/test_integration.py b/tests/test_clustering/test_integration.py index 2bcd0b022..d6dd3d2e7 100644 --- a/tests/test_clustering/test_integration.py +++ b/tests/test_clustering/test_integration.py @@ -214,8 +214,9 @@ def test_metrics_available(self, basic_flow_system): fs_clustered = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D') assert fs_clustered.clustering.metrics is not None - assert isinstance(fs_clustered.clustering.metrics, pd.DataFrame) - assert len(fs_clustered.clustering.metrics) > 0 + assert isinstance(fs_clustered.clustering.metrics, xr.Dataset) + assert 'time_series' in fs_clustered.clustering.metrics.dims + assert 
len(fs_clustered.clustering.metrics.data_vars) > 0 def test_representation_method_parameter(self, basic_flow_system): """Test that representation_method parameter works.""" @@ -237,6 +238,35 @@ def test_tsam_kwargs_passthrough(self, basic_flow_system): fs_clustered = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D', sameMean=True) assert len(fs_clustered.clusters) == 2 + def test_metrics_with_periods(self): + """Test that metrics have period dimension for multi-period FlowSystems.""" + pytest.importorskip('tsam') + from flixopt import Bus, Flow, Sink, Source + from flixopt.core import TimeSeriesData + + n_hours = 168 # 7 days + fs = FlowSystem( + timesteps=pd.date_range('2024-01-01', periods=n_hours, freq='h'), + periods=pd.Index([2025, 2030], name='period'), + ) + + demand_data = np.sin(np.linspace(0, 14 * np.pi, n_hours)) + 2 + bus = Bus('electricity') + grid_flow = Flow('grid_in', bus='electricity', size=100) + demand_flow = Flow( + 'demand_out', bus='electricity', size=100, fixed_relative_profile=TimeSeriesData(demand_data / 100) + ) + source = Source('grid', outputs=[grid_flow]) + sink = Sink('demand', inputs=[demand_flow]) + fs.add_elements(source, sink, bus) + + fs_clustered = fs.transform.cluster(n_clusters=2, cluster_duration='1D') + + # Metrics should have period dimension + assert fs_clustered.clustering.metrics is not None + assert 'period' in fs_clustered.clustering.metrics.dims + assert len(fs_clustered.clustering.metrics.period) == 2 + class TestClusteringModuleImports: """Tests for flixopt.clustering module imports.""" From 21f96c2d1d3cccb040a6b917658f18fab4d380fe Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 12:45:49 +0100 Subject: [PATCH 35/62] Problem: Expanded FlowSystem from clustering didn't have the extra timestep that regular FlowSystems have. 
Root Cause: In expand_solution(), the solution was only indexed by original_timesteps (n elements) instead of original_timesteps_extra (n+1 elements). Fix in flixopt/transform_accessor.py: 1. Reindex solution to timesteps_extra (line 1296-1298): - Added expanded_fs._solution.reindex(time=original_timesteps_extra) for consistency with non-expanded FlowSystems 2. Fill extra timestep for charge_state (lines 1300-1333): - Added special handling to properly fill the extra timestep for storage charge_state variables using the last cluster's extra timestep value 3. Updated intercluster storage handling (lines 1340-1388): - Modified to work with original_timesteps_extra instead of just original_timesteps - The extra timestep now correctly gets the final SOC boundary value with proper decay applied Tests updated in tests/test_cluster_reduce_expand.py: - Updated 4 assertions that check solution time coordinates to expect 193 (192 + 1 extra) instead of 192 --- flixopt/transform_accessor.py | 51 ++++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index 51fcb6f6f..b466d928a 100644 --- a/flixopt/transform_accessor.py +++ b/flixopt/transform_accessor.py @@ -1249,18 +1249,38 @@ def expand_solution(self) -> FlowSystem: scenarios = list(self._fs.scenarios) if has_scenarios else [None] n_original_timesteps = len(original_timesteps) n_reduced_timesteps = n_clusters * timesteps_per_cluster + n_original_clusters = cluster_structure.n_original_clusters # Expand function using ClusterResult.expand_data() - handles multi-dimensional cases - def expand_da(da: xr.DataArray) -> xr.DataArray: + # For charge_state with cluster dim, also includes the extra timestep + last_original_cluster_idx = (n_original_timesteps - 1) // timesteps_per_cluster + + def expand_da(da: xr.DataArray, var_name: str = '') -> xr.DataArray: if 'time' not in da.dims: return da.copy() - return 
info.result.expand_data(da, original_time=original_timesteps) + expanded = info.result.expand_data(da, original_time=original_timesteps) + + # For charge_state with cluster dim, append the extra timestep value + if var_name.endswith('|charge_state') and 'cluster' in da.dims: + # Get extra timestep from last cluster using vectorized selection + cluster_order = cluster_structure.cluster_order # (n_original_clusters,) or with period/scenario + if cluster_order.ndim == 1: + last_cluster = int(cluster_order[last_original_cluster_idx]) + extra_val = da.isel(cluster=last_cluster, time=-1) + else: + # Multi-dimensional: select last cluster for each period/scenario slice + last_clusters = cluster_order.isel(original_cluster=last_original_cluster_idx) + extra_val = da.isel(cluster=last_clusters, time=-1) + extra_val = extra_val.expand_dims(time=[original_timesteps_extra[-1]]) + expanded = xr.concat([expanded, extra_val], dim='time') + + return expanded # 1. Expand FlowSystem data (with cluster_weight set to 1.0 for all timesteps) reduced_ds = self._fs.to_dataset(include_solution=False) # Filter out cluster-related variables and copy attrs without clustering info data_vars = { - name: expand_da(da) + name: expand_da(da, name) for name, da in reduced_ds.data_vars.items() if name != 'cluster_weight' and not name.startswith('clustering|') } @@ -1288,17 +1308,22 @@ def expand_da(da: xr.DataArray) -> xr.DataArray: expanded_fs = FlowSystem.from_dataset(expanded_ds) # 2. 
Expand solution + # charge_state variables get their extra timestep via expand_da; others get NaN via reindex reduced_solution = self._fs.solution expanded_fs._solution = xr.Dataset( - {name: expand_da(da) for name, da in reduced_solution.data_vars.items()}, + {name: expand_da(da, name) for name, da in reduced_solution.data_vars.items()}, attrs=reduced_solution.attrs, ) + # Reindex to timesteps_extra for consistency with non-expanded FlowSystems + # (variables without extra timestep data will have NaN at the final timestep) + expanded_fs._solution = expanded_fs._solution.reindex(time=original_timesteps_extra) # 3. Combine charge_state with SOC_boundary for InterclusterStorageModel storages # For intercluster storages, charge_state is relative (ΔE) and can be negative. # Per Blanke et al. (2022) Eq. 9, actual SOC at time t in period d is: # SOC(t) = SOC_boundary[d] * (1 - loss)^t_within_period + charge_state(t) # where t_within_period is hours from period start (accounts for self-discharge decay). 
+ n_original_timesteps_extra = len(original_timesteps_extra) soc_boundary_vars = [name for name in reduced_solution.data_vars if name.endswith('|SOC_boundary')] for soc_boundary_name in soc_boundary_vars: storage_name = soc_boundary_name.rsplit('|', 1)[0] @@ -1309,24 +1334,31 @@ def expand_da(da: xr.DataArray) -> xr.DataArray: soc_boundary = reduced_solution[soc_boundary_name] expanded_charge_state = expanded_fs._solution[charge_state_name] - # Map each original timestep to its original period index - original_cluster_indices = np.arange(n_original_timesteps) // timesteps_per_cluster + # Map each original timestep (including extra) to its original period index + # The extra timestep belongs to the last period + original_cluster_indices = np.minimum( + np.arange(n_original_timesteps_extra) // timesteps_per_cluster, + n_original_clusters - 1, + ) # Select SOC_boundary for each timestep (boundary[d] for period d) # SOC_boundary has dim 'cluster_boundary', we select indices 0..n_original_clusters-1 soc_boundary_per_timestep = soc_boundary.isel( cluster_boundary=xr.DataArray(original_cluster_indices, dims=['time']) ) - soc_boundary_per_timestep = soc_boundary_per_timestep.assign_coords(time=original_timesteps) + soc_boundary_per_timestep = soc_boundary_per_timestep.assign_coords(time=original_timesteps_extra) # Apply self-discharge decay to SOC_boundary based on time within period # Get the storage's relative_loss_per_hour from the clustered flow system storage = self._fs.storages.get(storage_name) if storage is not None: # Time within period for each timestep (0, 1, 2, ..., timesteps_per_cluster-1, 0, 1, ...) 
- time_within_period = np.arange(n_original_timesteps) % timesteps_per_cluster + # The extra timestep is at index timesteps_per_cluster (one past the last within-cluster index) + time_within_period = np.arange(n_original_timesteps_extra) % timesteps_per_cluster + # The extra timestep gets the correct decay (timesteps_per_cluster) + time_within_period[-1] = timesteps_per_cluster time_within_period_da = xr.DataArray( - time_within_period, dims=['time'], coords={'time': original_timesteps} + time_within_period, dims=['time'], coords={'time': original_timesteps_extra} ) # Decay factor: (1 - loss)^t, using mean loss over time # Keep as DataArray to respect per-period/scenario values @@ -1342,7 +1374,6 @@ def expand_da(da: xr.DataArray) -> xr.DataArray: expanded_fs._solution[charge_state_name] = combined_charge_state.assign_attrs(expanded_charge_state.attrs) n_combinations = len(periods) * len(scenarios) - n_original_clusters = cluster_structure.n_original_clusters logger.info( f'Expanded FlowSystem from {n_reduced_timesteps} to {n_original_timesteps} timesteps ' f'({n_clusters} clusters' From 8ffd18587a6940a69e9beee4740a7391fa739096 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 13:18:47 +0100 Subject: [PATCH 36/62] - 'variable' is treated as a special valid facet value (since it exists in the melted DataFrame from data_var names, not as a dimension) - When facet_row='variable' or facet_col='variable' is passed, it's passed through directly - In line(), when faceting by variable, it's not also used for color (avoids double encoding) --- flixopt/dataset_plot_accessor.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index fc38f730b..270293165 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -34,7 +34,11 @@ def _resolve_auto_facets( animation_frame: str | Literal['auto'] 
| None = None, exclude_dims: set[str] | None = None, ) -> tuple[str | None, str | None, str | None]: - """Assign 'auto' facet slots from available dims using CONFIG priority lists.""" + """Assign 'auto' facet slots from available dims using CONFIG priority lists. + + Special handling for 'variable': exists in melted DataFrame (from data_var names), + not as a dimension, so it's always valid as an explicit facet request. + """ # Get available extra dimensions with size > 1, excluding specified dims exclude = exclude_dims or set() available = {d for d in ds.dims if ds.sizes[d] > 1 and d not in exclude} @@ -50,9 +54,10 @@ def _resolve_auto_facets( results: dict[str, str | None] = {'facet_col': None, 'facet_row': None, 'animation_frame': None} # First pass: resolve explicit dimensions (not 'auto' or None) to mark them as used + # 'variable' is special - exists in melted df from data_var names, not in ds.dims for slot_name, value in slots.items(): if value is not None and value != 'auto': - if value in available and value not in used: + if value == 'variable' or (value in available and value not in used): used.add(value) results[slot_name] = value @@ -325,8 +330,13 @@ def line( 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, } - # Only color by variable if it's not already on x-axis (and user didn't override) - if x_col != 'variable' and 'color' not in px_kwargs: + # Only color by variable if it's not used for faceting or x-axis (and user didn't override) + if ( + x_col != 'variable' + and actual_facet_col != 'variable' + and actual_facet_row != 'variable' + and 'color' not in px_kwargs + ): fig_kwargs['color'] = 'variable' fig_kwargs['color_discrete_map'] = color_map if xlabel: From 57c9cb1aace43618b79599bf173b08c57aae35c3 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 13:51:42 +0100 Subject: [PATCH 37/62] Add variable and color to auto resolving in fxplot --- flixopt/config.py | 
6 +- flixopt/dataset_plot_accessor.py | 146 ++++++++++++++++++++----------- 2 files changed, 96 insertions(+), 56 deletions(-) diff --git a/flixopt/config.py b/flixopt/config.py index 602652252..9793f9ba2 100644 --- a/flixopt/config.py +++ b/flixopt/config.py @@ -164,9 +164,9 @@ def format(self, record): 'default_sequential_colorscale': 'turbo', 'default_qualitative_colorscale': 'plotly', 'default_line_shape': 'hv', - 'extra_dim_priority': ('cluster', 'period', 'scenario'), - 'dim_slot_priority': ('facet_col', 'facet_row', 'animation_frame'), - 'x_dim_priority': ('time', 'duration', 'duration_pct', 'period', 'scenario', 'cluster'), + 'extra_dim_priority': ('variable', 'cluster', 'period', 'scenario'), + 'dim_slot_priority': ('color', 'facet_col', 'facet_row', 'animation_frame'), + 'x_dim_priority': ('time', 'duration', 'duration_pct', 'variable', 'period', 'scenario', 'cluster'), } ), 'solving': MappingProxyType( diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 270293165..403a22a46 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -13,14 +13,25 @@ from .config import CONFIG -def _get_x_dim(dims: list[str], x: str | Literal['auto'] | None = 'auto') -> str: - """Select x-axis dim from priority list, or 'variable' for scalar data.""" +def _get_x_dim(dims: list[str], n_data_vars: int = 1, x: str | Literal['auto'] | None = 'auto') -> str: + """Select x-axis dim from priority list, or 'variable' for scalar data. + + Args: + dims: List of available dimensions. + n_data_vars: Number of data variables (for 'variable' availability). + x: Explicit x-axis choice or 'auto'. 
+ """ if x and x != 'auto': return x + # 'variable' is available when there are multiple data_vars + available = set(dims) + if n_data_vars > 1: + available.add('variable') + # Check priority list first for dim in CONFIG.Plotting.x_dim_priority: - if dim in dims: + if dim in available: return dim # Fallback to first available dimension, or 'variable' for scalar data @@ -29,35 +40,47 @@ def _get_x_dim(dims: list[str], x: str | Literal['auto'] | None = 'auto') -> str def _resolve_auto_facets( ds: xr.Dataset, + color: str | Literal['auto'] | None, facet_col: str | Literal['auto'] | None, facet_row: str | Literal['auto'] | None, animation_frame: str | Literal['auto'] | None = None, exclude_dims: set[str] | None = None, -) -> tuple[str | None, str | None, str | None]: +) -> tuple[str | None, str | None, str | None, str | None]: """Assign 'auto' facet slots from available dims using CONFIG priority lists. - Special handling for 'variable': exists in melted DataFrame (from data_var names), - not as a dimension, so it's always valid as an explicit facet request. + 'variable' is treated like a dimension - available when len(data_vars) > 1. + It exists in the melted DataFrame from data_var names, not in ds.dims. + + Returns: + Tuple of (color, facet_col, facet_row, animation_frame). 
""" # Get available extra dimensions with size > 1, excluding specified dims exclude = exclude_dims or set() available = {d for d in ds.dims if ds.sizes[d] > 1 and d not in exclude} + # 'variable' is available when there are multiple data_vars + if len(ds.data_vars) > 1: + available.add('variable') extra_dims = [d for d in CONFIG.Plotting.extra_dim_priority if d in available] used: set[str] = set() # Map slot names to their input values slots = { + 'color': color, 'facet_col': facet_col, 'facet_row': facet_row, 'animation_frame': animation_frame, } - results: dict[str, str | None] = {'facet_col': None, 'facet_row': None, 'animation_frame': None} + results: dict[str, str | None] = { + 'color': None, + 'facet_col': None, + 'facet_row': None, + 'animation_frame': None, + } # First pass: resolve explicit dimensions (not 'auto' or None) to mark them as used - # 'variable' is special - exists in melted df from data_var names, not in ds.dims for slot_name, value in slots.items(): if value is not None and value != 'auto': - if value == 'variable' or (value in available and value not in used): + if value in available and value not in used: used.add(value) results[slot_name] = value @@ -70,7 +93,7 @@ def _resolve_auto_facets( used.add(next_dim) results[slot_name] = next_dim - return results['facet_col'], results['facet_row'], results['animation_frame'] + return results['color'], results['facet_col'], results['facet_row'], results['animation_frame'] def _dataset_to_long_df(ds: xr.Dataset, value_name: str = 'value', var_name: str = 'variable') -> pd.DataFrame: @@ -125,6 +148,7 @@ def bar( self, *, x: str | Literal['auto'] | None = 'auto', + color: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -139,6 +163,8 @@ def bar( Args: x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + color: Dimension for color grouping. 
'auto' uses 'variable' (data_var names) + if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. xlabel: X-axis label. @@ -154,17 +180,20 @@ def bar( """ # Determine x-axis first, then resolve facets from remaining dims dims = list(self._ds.dims) - x_col = _get_x_dim(dims, x) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame, exclude_dims={x_col} + x_col = _get_x_dim(dims, len(self._ds.data_vars), x) + actual_color, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, color, facet_col, facet_row, animation_frame, exclude_dims={x_col} ) df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) + # Get color labels from the resolved color column + color_labels = df[actual_color].unique().tolist() if actual_color and actual_color in df.columns else [] + color_map = process_colors( + colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale + ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { @@ -174,9 +203,8 @@ def bar( 'title': title, 'barmode': 'group', } - # Only color by variable if it's not already on x-axis (and user didn't override) - if x_col != 'variable' and 'color' not in px_kwargs: - fig_kwargs['color'] = 'variable' + if actual_color and 'color' not in px_kwargs: + fig_kwargs['color'] = actual_color fig_kwargs['color_discrete_map'] = color_map if xlabel: fig_kwargs['labels'] = {x_col: xlabel} @@ -198,6 +226,7 @@ def stacked_bar( self, *, x: str | Literal['auto'] | None = 'auto', + color: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -215,6 +244,8 @@ def 
stacked_bar( Args: x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) + if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. xlabel: X-axis label. @@ -230,17 +261,20 @@ def stacked_bar( """ # Determine x-axis first, then resolve facets from remaining dims dims = list(self._ds.dims) - x_col = _get_x_dim(dims, x) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame, exclude_dims={x_col} + x_col = _get_x_dim(dims, len(self._ds.data_vars), x) + actual_color, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, color, facet_col, facet_row, animation_frame, exclude_dims={x_col} ) df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) + # Get color labels from the resolved color column + color_labels = df[actual_color].unique().tolist() if actual_color and actual_color in df.columns else [] + color_map = process_colors( + colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale + ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { @@ -249,9 +283,8 @@ def stacked_bar( 'y': 'value', 'title': title, } - # Only color by variable if it's not already on x-axis (and user didn't override) - if x_col != 'variable' and 'color' not in px_kwargs: - fig_kwargs['color'] = 'variable' + if actual_color and 'color' not in px_kwargs: + fig_kwargs['color'] = actual_color fig_kwargs['color_discrete_map'] = color_map if xlabel: fig_kwargs['labels'] = {x_col: xlabel} @@ -276,6 +309,7 @@ def line( self, *, x: str | Literal['auto'] | None = 'auto', + color: str | 
Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -293,6 +327,8 @@ def line( Args: x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) + if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. xlabel: X-axis label. @@ -310,17 +346,20 @@ def line( """ # Determine x-axis first, then resolve facets from remaining dims dims = list(self._ds.dims) - x_col = _get_x_dim(dims, x) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame, exclude_dims={x_col} + x_col = _get_x_dim(dims, len(self._ds.data_vars), x) + actual_color, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, color, facet_col, facet_row, animation_frame, exclude_dims={x_col} ) df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) + # Get color labels from the resolved color column + color_labels = df[actual_color].unique().tolist() if actual_color and actual_color in df.columns else [] + color_map = process_colors( + colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale + ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { @@ -330,14 +369,8 @@ def line( 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, } - # Only color by variable if it's not used for faceting or x-axis (and user didn't override) - if ( - x_col != 'variable' - and actual_facet_col != 'variable' - and actual_facet_row != 'variable' - and 'color' not in px_kwargs - ): - fig_kwargs['color'] = 'variable' + if actual_color and 'color' 
not in px_kwargs: + fig_kwargs['color'] = actual_color fig_kwargs['color_discrete_map'] = color_map if xlabel: fig_kwargs['labels'] = {x_col: xlabel} @@ -359,6 +392,7 @@ def area( self, *, x: str | Literal['auto'] | None = 'auto', + color: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -374,6 +408,8 @@ def area( Args: x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) + if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. xlabel: X-axis label. @@ -390,17 +426,20 @@ def area( """ # Determine x-axis first, then resolve facets from remaining dims dims = list(self._ds.dims) - x_col = _get_x_dim(dims, x) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame, exclude_dims={x_col} + x_col = _get_x_dim(dims, len(self._ds.data_vars), x) + actual_color, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, color, facet_col, facet_row, animation_frame, exclude_dims={x_col} ) df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) + # Get color labels from the resolved color column + color_labels = df[actual_color].unique().tolist() if actual_color and actual_color in df.columns else [] + color_map = process_colors( + colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale + ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { @@ -410,9 +449,8 @@ def area( 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, } - # Only color by variable if it's not already on 
x-axis (and user didn't override) - if x_col != 'variable' and 'color' not in px_kwargs: - fig_kwargs['color'] = 'variable' + if actual_color and 'color' not in px_kwargs: + fig_kwargs['color'] = actual_color fig_kwargs['color_discrete_map'] = color_map if xlabel: fig_kwargs['labels'] = {x_col: xlabel} @@ -477,7 +515,7 @@ def heatmap( colors = colors or CONFIG.Plotting.default_sequential_colorscale facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - actual_facet_col, _, actual_anim = _resolve_auto_facets(self._ds, facet_col, None, animation_frame) + _, actual_facet_col, _, actual_anim = _resolve_auto_facets(self._ds, None, facet_col, None, animation_frame) imshow_args: dict[str, Any] = { 'img': da, @@ -535,8 +573,8 @@ def scatter( if df.empty: return go.Figure() - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame + _, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, None, facet_col, facet_row, animation_frame ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols @@ -619,8 +657,8 @@ def pie( if df.empty: return go.Figure() - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame + _, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( + self._ds, None, facet_col, facet_row, animation_frame ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols @@ -882,7 +920,9 @@ def heatmap( # Use Dataset for facet resolution ds_for_resolution = da.to_dataset(name='_temp') - actual_facet_col, _, actual_anim = _resolve_auto_facets(ds_for_resolution, facet_col, None, animation_frame) + _, actual_facet_col, _, actual_anim = _resolve_auto_facets( + ds_for_resolution, None, facet_col, None, animation_frame + ) imshow_args: dict[str, Any] = { 'img': da, From df1fac1e30de08bea3eba6fa5368968aca6efb6c Mon Sep 17 00:00:00 2001 From: FBumann 
<117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 13:53:58 +0100 Subject: [PATCH 38/62] Added 'variable' to both priority lists and updated the logic to treat it consistently: flixopt/config.py: 'extra_dim_priority': ('variable', 'cluster', 'period', 'scenario'), 'x_dim_priority': ('time', 'duration', 'duration_pct', 'variable', 'period', 'scenario', 'cluster'), flixopt/dataset_plot_accessor.py: - _get_x_dim: Now takes n_data_vars parameter; 'variable' is available when > 1 - _resolve_auto_facets: 'variable' is available when len(data_vars) > 1 and respects exclude_dims Behavior: - 'variable' is treated like any other dimension in the priority system - Only available when there are multiple data_vars - Properly excluded when already used (e.g., for x-axis) --- flixopt/clustering/base.py | 181 ++++++++++++------------------- flixopt/dataset_plot_accessor.py | 4 +- 2 files changed, 69 insertions(+), 116 deletions(-) diff --git a/flixopt/clustering/base.py b/flixopt/clustering/base.py index 9c900593a..f914514b4 100644 --- a/flixopt/clustering/base.py +++ b/flixopt/clustering/base.py @@ -197,20 +197,20 @@ def get_cluster_weight_per_timestep(self) -> xr.DataArray: name='cluster_weight', ) - def plot(self, show: bool | None = None) -> PlotResult: + def plot(self, colors: str | list[str] | None = None, show: bool | None = None) -> PlotResult: """Plot cluster assignment visualization. Shows which cluster each original period belongs to, and the number of occurrences per cluster. Args: + colors: Colorscale name (str) or list of colors. + Defaults to CONFIG.Plotting.default_sequential_colorscale. show: Whether to display the figure. Defaults to CONFIG.Plotting.default_show. Returns: PlotResult containing the figure and underlying data. 
""" - import plotly.express as px - from ..config import CONFIG from ..plot_result import PlotResult @@ -218,27 +218,24 @@ def plot(self, show: bool | None = None) -> PlotResult: int(self.n_clusters) if isinstance(self.n_clusters, (int, np.integer)) else int(self.n_clusters.values) ) - # Create DataFrame for plotting - import pandas as pd - cluster_order = self.get_cluster_order_for_slice() - df = pd.DataFrame( - { - 'Original Period': range(1, len(cluster_order) + 1), - 'Cluster': cluster_order, - } + + # Build DataArray for fxplot heatmap + cluster_da = xr.DataArray( + cluster_order.reshape(1, -1), + dims=['y', 'original_cluster'], + coords={'y': ['Cluster'], 'original_cluster': range(1, len(cluster_order) + 1)}, + name='cluster_assignment', ) - # Bar chart showing cluster assignment - fig = px.bar( - df, - x='Original Period', - y=[1] * len(df), - color='Cluster', - color_continuous_scale='Viridis', + # Use fxplot.heatmap for smart defaults + colorscale = colors or CONFIG.Plotting.default_sequential_colorscale + fig = cluster_da.fxplot.heatmap( + colors=colorscale, title=f'Cluster Assignment ({self.n_original_clusters} periods → {n_clusters} clusters)', ) - fig.update_layout(yaxis_visible=False, coloraxis_colorbar_title='Cluster') + fig.update_yaxes(showticklabels=False) + fig.update_coloraxes(colorbar_title='Cluster') # Build data for PlotResult data = xr.Dataset( @@ -585,8 +582,8 @@ def compare( *, select: SelectType | None = None, colors: ColorType | None = None, - facet_col: str | None = 'period', - facet_row: str | None = 'scenario', + facet_col: str | None = 'auto', + facet_row: str | None = 'auto', show: bool | None = None, **plotly_kwargs: Any, ) -> PlotResult: @@ -600,8 +597,10 @@ def compare( or None to plot all time-varying variables. select: xarray-style selection dict, e.g. {'scenario': 'Base Case'}. colors: Color specification (colorscale name, color list, or label-to-color dict). - facet_col: Dimension for subplot columns (default: 'period'). 
- facet_row: Dimension for subplot rows (default: 'scenario'). + facet_col: Dimension for subplot columns. 'auto' uses CONFIG priority. + Use 'variable' to create separate columns per variable. + facet_row: Dimension for subplot rows. 'auto' uses CONFIG priority. + Use 'variable' to create separate rows per variable. show: Whether to display the figure. Defaults to CONFIG.Plotting.default_show. **plotly_kwargs: Additional arguments passed to plotly. @@ -610,9 +609,7 @@ def compare( PlotResult containing the comparison figure and underlying data. """ import pandas as pd - import plotly.express as px - from ..color_processing import process_colors from ..config import CONFIG from ..plot_result import PlotResult from ..statistics_accessor import _apply_selection @@ -626,7 +623,7 @@ def compare( resolved_variables = self._resolve_variables(variables) - # Build Dataset with 'representation' dimension for Original/Clustered + # Build Dataset with variables as data_vars data_vars = {} for var in resolved_variables: original = result.original_data[var] @@ -650,54 +647,34 @@ def compare( { var: xr.DataArray( [sorted_vars[(var, r)] for r in ['Original', 'Clustered']], - dims=['representation', 'rank'], - coords={'representation': ['Original', 'Clustered'], 'rank': range(n)}, + dims=['representation', 'duration'], + coords={'representation': ['Original', 'Clustered'], 'duration': range(n)}, ) for var in resolved_variables } ) - # Resolve facets (only for timeseries) - actual_facet_col = facet_col if kind == 'timeseries' and facet_col in ds.dims else None - actual_facet_row = facet_row if kind == 'timeseries' and facet_row in ds.dims else None - - # Convert to long-form DataFrame - df = ds.to_dataframe().reset_index() - coord_cols = [c for c in ds.coords.keys() if c in df.columns] - df = df.melt(id_vars=coord_cols, var_name='variable', value_name='value') - - variable_labels = df['variable'].unique().tolist() - color_map = process_colors(colors, variable_labels, 
CONFIG.Plotting.default_qualitative_colorscale) - - # Set x-axis and title based on kind - x_col = 'time' if kind == 'timeseries' else 'rank' + # Set title based on kind if kind == 'timeseries': title = ( 'Original vs Clustered' if len(resolved_variables) > 1 else f'Original vs Clustered: {resolved_variables[0]}' ) - labels = {} else: title = 'Duration Curve' if len(resolved_variables) > 1 else f'Duration Curve: {resolved_variables[0]}' - labels = {'rank': 'Hours (sorted)', 'value': 'Value'} - fig = px.line( - df, - x=x_col, - y='value', - color='variable', - line_dash='representation', - facet_col=actual_facet_col, - facet_row=actual_facet_row, + # Use fxplot for smart defaults with line_dash for representation + fig = ds.fxplot.line( + colors=colors, title=title, - labels=labels, - color_discrete_map=color_map, + facet_col=facet_col, + facet_row=facet_row, + line_dash='representation', **plotly_kwargs, ) - if actual_facet_row or actual_facet_col: - fig.update_yaxes(matches=None) - fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) + fig.update_yaxes(matches=None) + fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) plot_result = PlotResult(data=ds, figure=fig) @@ -743,8 +720,8 @@ def heatmap( *, select: SelectType | None = None, colors: str | list[str] | None = None, - facet_col: str | None = 'period', - animation_frame: str | None = 'scenario', + facet_col: str | None = 'auto', + animation_frame: str | None = 'auto', show: bool | None = None, **plotly_kwargs: Any, ) -> PlotResult: @@ -762,8 +739,8 @@ def heatmap( colors: Colorscale name (str) or list of colors for heatmap coloring. Dicts are not supported for heatmaps. Defaults to CONFIG.Plotting.default_sequential_colorscale. - facet_col: Dimension to facet on columns (default: 'period'). - animation_frame: Dimension for animation slider (default: 'scenario'). + facet_col: Dimension to facet on columns. 'auto' uses CONFIG priority. 
+ animation_frame: Dimension for animation slider. 'auto' uses CONFIG priority. show: Whether to display the figure. Defaults to CONFIG.Plotting.default_show. **plotly_kwargs: Additional arguments passed to plotly. @@ -773,7 +750,6 @@ def heatmap( The data has 'cluster' variable with time dimension, matching original timesteps. """ import pandas as pd - import plotly.express as px from ..config import CONFIG from ..plot_result import PlotResult @@ -833,34 +809,24 @@ def heatmap( else: cluster_da = cluster_slices[(None, None)] - # Resolve facet_col and animation_frame - only use if dimension exists - actual_facet_col = facet_col if facet_col and facet_col in cluster_da.dims else None - actual_animation = animation_frame if animation_frame and animation_frame in cluster_da.dims else None - # Add dummy y dimension for heatmap visualization (single row) heatmap_da = cluster_da.expand_dims('y', axis=-1) heatmap_da = heatmap_da.assign_coords(y=['Cluster']) + heatmap_da.name = 'cluster_assignment' - colorscale = colors or CONFIG.Plotting.default_sequential_colorscale - - # Use px.imshow with xr.DataArray - fig = px.imshow( - heatmap_da, - color_continuous_scale=colorscale, - facet_col=actual_facet_col, - animation_frame=actual_animation, + # Use fxplot.heatmap for smart defaults + fig = heatmap_da.fxplot.heatmap( + colors=colors, title='Cluster Assignments', - labels={'time': 'Time', 'color': 'Cluster'}, + facet_col=facet_col, + animation_frame=animation_frame, aspect='auto', **plotly_kwargs, ) - # Clean up facet labels - if actual_facet_col: - fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) - - # Hide y-axis since it's just a single row + # Clean up: hide y-axis since it's just a single row fig.update_yaxes(showticklabels=False) + fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) # Data is exactly what we plotted (without dummy y dimension) cluster_da.name = 'cluster' @@ -880,21 +846,21 @@ def clusters( *, select: SelectType 
| None = None, colors: ColorType | None = None, - facet_col_wrap: int | None = None, + facet_cols: int | None = None, show: bool | None = None, **plotly_kwargs: Any, ) -> PlotResult: """Plot each cluster's typical period profile. - Shows each cluster as a separate faceted subplot. Useful for - understanding what each cluster represents. + Shows each cluster as a separate faceted subplot with all variables + colored differently. Useful for understanding what each cluster represents. Args: variables: Variable(s) to plot. Can be a string, list of strings, or None to plot all time-varying variables. select: xarray-style selection dict, e.g. {'scenario': 'Base Case'}. colors: Color specification (colorscale name, color list, or label-to-color dict). - facet_col_wrap: Max columns before wrapping facets. + facet_cols: Max columns before wrapping facets. Defaults to CONFIG.Plotting.default_facet_cols. show: Whether to display the figure. Defaults to CONFIG.Plotting.default_show. @@ -903,10 +869,6 @@ def clusters( Returns: PlotResult containing the figure and underlying data. 
""" - import pandas as pd - import plotly.express as px - - from ..color_processing import process_colors from ..config import CONFIG from ..plot_result import PlotResult from ..statistics_accessor import _apply_selection @@ -929,45 +891,36 @@ def clusters( n_clusters = int(cs.n_clusters) if isinstance(cs.n_clusters, (int, np.integer)) else int(cs.n_clusters.values) timesteps_per_cluster = cs.timesteps_per_cluster - # Build long-form DataFrame with cluster labels including occurrence counts - rows = [] + # Build Dataset with cluster dimension, using labels with occurrence counts + cluster_labels = [ + f'Cluster {c} (×{int(cs.cluster_occurrences.sel(cluster=c).values)})' for c in range(n_clusters) + ] + data_vars = {} for var in resolved_variables: data = aggregated_data[var].values data_by_cluster = data.reshape(n_clusters, timesteps_per_cluster) data_vars[var] = xr.DataArray( data_by_cluster, - dims=['cluster', 'timestep'], - coords={'cluster': range(n_clusters), 'timestep': range(timesteps_per_cluster)}, + dims=['cluster', 'time'], + coords={'cluster': cluster_labels, 'time': range(timesteps_per_cluster)}, ) - for c in range(n_clusters): - occurrence = int(cs.cluster_occurrences.sel(cluster=c).values) - label = f'Cluster {c} (×{occurrence})' - for t in range(timesteps_per_cluster): - rows.append({'cluster': label, 'timestep': t, 'value': data_by_cluster[c, t], 'variable': var}) - df = pd.DataFrame(rows) - - cluster_labels = df['cluster'].unique().tolist() - color_map = process_colors(colors, cluster_labels, CONFIG.Plotting.default_qualitative_colorscale) - facet_col_wrap = facet_col_wrap or CONFIG.Plotting.default_facet_cols + + ds = xr.Dataset(data_vars) title = 'Clusters' if len(resolved_variables) > 1 else f'Clusters: {resolved_variables[0]}' - fig = px.line( - df, - x='timestep', - y='value', - facet_col='cluster', - facet_row='variable' if len(resolved_variables) > 1 else None, - facet_col_wrap=facet_col_wrap if len(resolved_variables) == 1 else None, + # 
Use fxplot for smart defaults + fig = ds.fxplot.line( + colors=colors, title=title, - color_discrete_map=color_map, + facet_col='cluster', + facet_cols=facet_cols, **plotly_kwargs, ) - fig.update_layout(showlegend=False) - if len(resolved_variables) > 1: - fig.update_yaxes(matches=None) - fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) + fig.update_yaxes(matches=None) + fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) + # Include occurrences in result data data_vars['occurrences'] = cs.cluster_occurrences result_data = xr.Dataset(data_vars) plot_result = PlotResult(data=result_data, figure=fig) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 403a22a46..73b20b436 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -57,8 +57,8 @@ def _resolve_auto_facets( # Get available extra dimensions with size > 1, excluding specified dims exclude = exclude_dims or set() available = {d for d in ds.dims if ds.sizes[d] > 1 and d not in exclude} - # 'variable' is available when there are multiple data_vars - if len(ds.data_vars) > 1: + # 'variable' is available when there are multiple data_vars (and not excluded) + if len(ds.data_vars) > 1 and 'variable' not in exclude: available.add('variable') extra_dims = [d for d in CONFIG.Plotting.extra_dim_priority if d in available] used: set[str] = set() From 829c34244b6edce8b446e4ece70cfa3c09e6385d Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 14:10:57 +0100 Subject: [PATCH 39/62] Improve plotting, especially for clustering --- docs/notebooks/08c-clustering.ipynb | 272 ++++++++++++++++++---------- flixopt/clustering/base.py | 26 ++- tests/test_cluster_reduce_expand.py | 8 +- 3 files changed, 202 insertions(+), 104 deletions(-) diff --git a/docs/notebooks/08c-clustering.ipynb b/docs/notebooks/08c-clustering.ipynb index 0e9cda7b7..d07512cac 100644 --- 
a/docs/notebooks/08c-clustering.ipynb +++ b/docs/notebooks/08c-clustering.ipynb @@ -28,10 +28,8 @@ "source": [ "import timeit\n", "\n", - "import numpy as np\n", "import pandas as pd\n", - "import plotly.graph_objects as go\n", - "from plotly.subplots import make_subplots\n", + "import xarray as xr\n", "\n", "import flixopt as fx\n", "\n", @@ -73,18 +71,13 @@ "outputs": [], "source": [ "# Visualize input data\n", - "heat_demand = flow_system.components['HeatDemand'].inputs[0].fixed_relative_profile\n", - "electricity_price = flow_system.components['GridBuy'].outputs[0].effects_per_flow_hour['costs']\n", - "\n", - "fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.1)\n", - "fig.add_trace(go.Scatter(x=timesteps, y=heat_demand.values, name='Heat Demand', line=dict(width=0.5)), row=1, col=1)\n", - "fig.add_trace(\n", - " go.Scatter(x=timesteps, y=electricity_price.values, name='Electricity Price', line=dict(width=0.5)), row=2, col=1\n", + "input_ds = xr.Dataset(\n", + " {\n", + " 'Heat Demand': flow_system.components['HeatDemand'].inputs[0].fixed_relative_profile,\n", + " 'Electricity Price': flow_system.components['GridBuy'].outputs[0].effects_per_flow_hour['costs'],\n", + " }\n", ")\n", - "fig.update_layout(height=400, title='One Month of Input Data')\n", - "fig.update_yaxes(title_text='Heat Demand [MW]', row=1, col=1)\n", - "fig.update_yaxes(title_text='El. 
Price [€/MWh]', row=2, col=1)\n", - "fig.show()" + "input_ds.fxplot.line(facet_row='variable', title='One Month of Input Data')" ] }, { @@ -154,11 +147,16 @@ " n_clusters=8, # 8 typical days\n", " cluster_duration='1D', # Daily clustering\n", " time_series_for_high_peaks=peak_series, # Capture peak demand day\n", + " random_state=42, # Reproducible results\n", ")\n", "\n", "time_clustering = timeit.default_timer() - start\n", - "print(f'Clustering time: {time_clustering:.1f} seconds')\n", - "print(f'Reduced: {len(flow_system.timesteps)} → {len(fs_clustered.timesteps)} timesteps')" + "\n", + "print(\n", + " f'Clustering: {len(flow_system.timesteps)} → {len(fs_clustered.timesteps) * len(fs_clustered.clusters)} timesteps'\n", + ")\n", + "print(f' Clusters: {len(fs_clustered.clusters)}')\n", + "print(f' Time: {time_clustering:.2f}s')" ] }, { @@ -188,7 +186,7 @@ "source": [ "## Understanding the Clustering\n", "\n", - "The clustering algorithm groups similar days together. Let's inspect the cluster structure:" + "The clustering algorithm groups similar days together. 
Access all metadata via `fs.clustering`:" ] }, { @@ -198,26 +196,110 @@ "metadata": {}, "outputs": [], "source": [ - "# Show clustering info\n", - "info = fs_clustered.clustering\n", - "cs = info.result.cluster_structure\n", - "print('Clustering Configuration:')\n", - "print(f' Number of typical periods: {cs.n_clusters}')\n", - "print(f' Timesteps per period: {cs.timesteps_per_cluster}')\n", - "print(f' Total reduced timesteps: {cs.n_clusters * cs.timesteps_per_cluster}')\n", - "print(f' Cluster order (first 10 days): {cs.cluster_order.values[:10]}...')\n", + "# Access clustering metadata directly\n", + "clustering = fs_clustered.clustering\n", + "clustering" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "# Key properties\n", + "print(f'Clusters: {clustering.n_clusters}')\n", + "print(f'Original segments (days): {clustering.n_original_clusters}')\n", + "print(f'Timesteps per cluster: {clustering.timesteps_per_cluster}')\n", + "print(f'\\nCluster occurrences: {clustering.occurrences.values}')\n", + "print(f'Cluster order: {clustering.cluster_order.values}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "# Quality metrics - how well do the clusters represent the original data?\n", + "# Lower RMSE/MAE = better representation\n", + "clustering.metrics.to_dataframe().style.format('{:.3f}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "# Visual comparison: original vs clustered time series\n", + "clustering.plot.compare()" + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "## Advanced Clustering Options\n", + "\n", + "The `cluster()` method exposes many parameters for fine-tuning:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": 
[], + "source": [ + "# Try different clustering algorithms\n", + "fs_hierarchical = flow_system.transform.cluster(\n", + " n_clusters=8,\n", + " cluster_duration='1D',\n", + " cluster_method='hierarchical', # Alternative: 'k_means' (default), 'k_medoids', 'averaging'\n", + " random_state=42,\n", + ")\n", "\n", - "# Show how many times each cluster appears\n", - "cluster_order = cs.cluster_order.values\n", - "unique, counts = np.unique(cluster_order, return_counts=True)\n", - "print('\\nCluster occurrences:')\n", - "for cluster_id, count in zip(unique, counts, strict=False):\n", - " print(f' Cluster {cluster_id}: {count} days')" + "# Compare cluster assignments between algorithms\n", + "print('k_means clusters: ', fs_clustered.clustering.cluster_order.values)\n", + "print('hierarchical clusters:', fs_hierarchical.clustering.cluster_order.values)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "# Compare RMSE between algorithms\n", + "print('Quality comparison (RMSE for HeatDemand):')\n", + "print(\n", + " f' k_means: {float(fs_clustered.clustering.metrics[\"RMSE\"].sel(time_series=\"HeatDemand(Q_th)|fixed_relative_profile\")):.4f}'\n", + ")\n", + "print(\n", + " f' hierarchical: {float(fs_hierarchical.clustering.metrics[\"RMSE\"].sel(time_series=\"HeatDemand(Q_th)|fixed_relative_profile\")):.4f}'\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize cluster structure with heatmap\n", + "clustering.plot.heatmap()" ] }, { "cell_type": "markdown", - "id": "12", + "id": "19", "metadata": {}, "source": [ "## Method 3: Two-Stage Workflow (Recommended)\n", @@ -235,7 +317,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -256,7 +338,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "21", "metadata": 
{}, "outputs": [], "source": [ @@ -279,7 +361,7 @@ }, { "cell_type": "markdown", - "id": "15", + "id": "22", "metadata": {}, "source": [ "## Compare Results" @@ -288,7 +370,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -337,7 +419,7 @@ }, { "cell_type": "markdown", - "id": "17", + "id": "24", "metadata": {}, "source": [ "## Expand Solution to Full Resolution\n", @@ -349,7 +431,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -363,34 +445,29 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "26", "metadata": {}, "outputs": [], "source": [ - "# Compare heat balance: Full vs Expanded\n", - "fig = make_subplots(rows=2, cols=1, shared_xaxes=True, subplot_titles=['Full Optimization', 'Expanded from Clustering'])\n", + "# Compare heat production: Full vs Expanded\n", + "heat_flows = ['CHP(Q_th)|flow_rate', 'Boiler(Q_th)|flow_rate']\n", "\n", - "# Full\n", - "for var in ['CHP(Q_th)', 'Boiler(Q_th)']:\n", - " values = fs_full.solution[f'{var}|flow_rate'].values\n", - " fig.add_trace(go.Scatter(x=fs_full.timesteps, y=values, name=var, legendgroup=var, showlegend=True), row=1, col=1)\n", - "\n", - "# Expanded\n", - "for var in ['CHP(Q_th)', 'Boiler(Q_th)']:\n", - " values = fs_expanded.solution[f'{var}|flow_rate'].values\n", - " fig.add_trace(\n", - " go.Scatter(x=fs_expanded.timesteps, y=values, name=var, legendgroup=var, showlegend=False), row=2, col=1\n", - " )\n", + "# Create comparison dataset\n", + "comparison_ds = xr.Dataset(\n", + " {\n", + " name.replace('|flow_rate', ''): xr.concat(\n", + " [fs_full.solution[name], fs_expanded.solution[name]], dim=pd.Index(['Full', 'Expanded'], name='method')\n", + " )\n", + " for name in heat_flows\n", + " }\n", + ")\n", "\n", - "fig.update_layout(height=500, title='Heat Production Comparison')\n", - "fig.update_yaxes(title_text='MW', row=1, col=1)\n", 
- "fig.update_yaxes(title_text='MW', row=2, col=1)\n", - "fig.show()" + "comparison_ds.fxplot.line(facet_col='variable', color='method', title='Heat Production Comparison')" ] }, { "cell_type": "markdown", - "id": "20", + "id": "27", "metadata": {}, "source": [ "## Visualize Clustered Heat Balance" @@ -399,7 +476,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -409,33 +486,55 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "29", "metadata": {}, "outputs": [], "source": [ - "fs_expanded.statistics.plot.storage('Storage')" + "fs_expanded.statistics.plot.storage('Storage').data.to_dataframe()" ] }, { "cell_type": "markdown", - "id": "23", + "id": "30", "metadata": {}, "source": [ "## API Reference\n", "\n", "### `transform.cluster()` Parameters\n", "\n", - "| Parameter | Type | Description |\n", - "|-----------|------|-------------|\n", - "| `n_clusters` | `int` | Number of typical periods (e.g., 8 typical days) |\n", - "| `cluster_duration` | `str \\| float` | Duration per cluster ('1D', '24h') or hours |\n", - "| `weights` | `dict[str, float]` | Optional weights for time series in clustering |\n", - "| `time_series_for_high_peaks` | `list[str]` | **Essential**: Force inclusion of peak periods |\n", - "| `time_series_for_low_peaks` | `list[str]` | Force inclusion of minimum periods |\n", + "| Parameter | Type | Default | Description |\n", + "|-----------|------|---------|-------------|\n", + "| `n_clusters` | `int` | - | Number of typical periods (e.g., 8 typical days) |\n", + "| `cluster_duration` | `str \\| float` | - | Duration per cluster ('1D', '24h') or hours |\n", + "| `weights` | `dict[str, float]` | None | Optional weights for time series in clustering |\n", + "| `time_series_for_high_peaks` | `list[str]` | None | **Essential**: Force inclusion of peak periods |\n", + "| `time_series_for_low_peaks` | `list[str]` | None | Force inclusion of minimum periods 
|\n", + "| `cluster_method` | `str` | 'k_means' | Algorithm: 'k_means', 'hierarchical', 'k_medoids', 'k_maxoids', 'averaging' |\n", + "| `representation_method` | `str` | 'meanRepresentation' | 'meanRepresentation', 'medoidRepresentation', 'distributionAndMinMaxRepresentation' |\n", + "| `extreme_period_method` | `str` | 'new_cluster_center' | How peaks are integrated: 'None', 'append', 'new_cluster_center', 'replace_cluster_center' |\n", + "| `rescale_cluster_periods` | `bool` | True | Rescale clusters to match original means |\n", + "| `random_state` | `int` | None | Random seed for reproducibility |\n", + "| `predef_cluster_order` | `array` | None | Manual cluster assignments |\n", + "| `**tsam_kwargs` | - | - | Additional tsam parameters |\n", + "\n", + "### Clustering Object Properties\n", + "\n", + "After clustering, access metadata via `fs.clustering`:\n", + "\n", + "| Property | Description |\n", + "|----------|-------------|\n", + "| `n_clusters` | Number of clusters |\n", + "| `n_original_clusters` | Number of original time segments (e.g., 365 days) |\n", + "| `timesteps_per_cluster` | Timesteps in each cluster (e.g., 24 for daily) |\n", + "| `cluster_order` | xr.DataArray mapping original segment → cluster ID |\n", + "| `occurrences` | How many original segments each cluster represents |\n", + "| `metrics` | xr.Dataset with RMSE, MAE per time series |\n", + "| `plot.compare()` | Compare original vs clustered time series |\n", + "| `plot.heatmap()` | Visualize cluster structure |\n", "\n", "### Storage Behavior\n", "\n", - "Each `Storage` component has a `cluster_storage_mode` parameter that controls how it behaves during clustering:\n", + "Each `Storage` component has a `cluster_mode` parameter:\n", "\n", "| Mode | Description |\n", "|------|-------------|\n", @@ -444,37 +543,12 @@ "| `'cyclic'` | Each cluster is independent but cyclic (start = end) |\n", "| `'independent'` | Each cluster is independent, free start/end |\n", "\n", - "For a detailed 
comparison of storage modes, see [08c2-clustering-storage-modes](08c2-clustering-storage-modes.ipynb).\n", - "\n", - "### Peak Forcing Format\n", - "\n", - "```python\n", - "time_series_for_high_peaks = ['ComponentName(FlowName)|fixed_relative_profile']\n", - "```\n", - "\n", - "### Recommended Workflow\n", - "\n", - "```python\n", - "# Stage 1: Fast sizing\n", - "fs_sizing = flow_system.transform.cluster(\n", - " n_clusters=8,\n", - " cluster_duration='1D',\n", - " time_series_for_high_peaks=['Demand(Flow)|fixed_relative_profile'],\n", - ")\n", - "fs_sizing.optimize(solver)\n", - "\n", - "# Apply safety margin\n", - "sizes = {k: v.item() * 1.05 for k, v in fs_sizing.statistics.sizes.items()}\n", - "\n", - "# Stage 2: Accurate dispatch\n", - "fs_dispatch = flow_system.transform.fix_sizes(sizes)\n", - "fs_dispatch.optimize(solver)\n", - "```" + "For a detailed comparison of storage modes, see [08c2-clustering-storage-modes](08c2-clustering-storage-modes.ipynb)." ] }, { "cell_type": "markdown", - "id": "24", + "id": "31", "metadata": {}, "source": [ "## Summary\n", @@ -485,13 +559,17 @@ "- Apply **peak forcing** to capture extreme demand days\n", "- Use **two-stage optimization** for fast yet accurate investment decisions\n", "- **Expand solutions** back to full resolution with `expand_solution()`\n", + "- Access **clustering metadata** via `fs.clustering` (metrics, cluster_order, occurrences)\n", + "- Use **advanced options** like different algorithms and reproducible random states\n", "\n", "### Key Takeaways\n", "\n", "1. **Always use peak forcing** (`time_series_for_high_peaks`) for demand time series\n", "2. **Add safety margin** (5-10%) when fixing sizes from clustering\n", "3. **Two-stage is recommended**: clustering for sizing, full resolution for dispatch\n", - "4. **Storage handling** is configurable via `storage_mode`\n", + "4. **Storage handling** is configurable via `cluster_mode`\n", + "5. **Use `random_state`** for reproducible results\n", + "6. 
**Check metrics** to evaluate clustering quality\n", "\n", "### Next Steps\n", "\n", diff --git a/flixopt/clustering/base.py b/flixopt/clustering/base.py index f914514b4..ab9590aae 100644 --- a/flixopt/clustering/base.py +++ b/flixopt/clustering/base.py @@ -582,6 +582,8 @@ def compare( *, select: SelectType | None = None, colors: ColorType | None = None, + color: str | None = 'auto', + line_dash: str | None = 'representation', facet_col: str | None = 'auto', facet_row: str | None = 'auto', show: bool | None = None, @@ -597,6 +599,10 @@ def compare( or None to plot all time-varying variables. select: xarray-style selection dict, e.g. {'scenario': 'Base Case'}. colors: Color specification (colorscale name, color list, or label-to-color dict). + color: Dimension for line colors. 'auto' uses CONFIG priority (typically 'variable'). + Use 'representation' to color by Original/Clustered instead of line_dash. + line_dash: Dimension for line dash styles. Defaults to 'representation'. + Set to None to disable line dash differentiation. facet_col: Dimension for subplot columns. 'auto' uses CONFIG priority. Use 'variable' to create separate columns per variable. facet_row: Dimension for subplot rows. 'auto' uses CONFIG priority. 
@@ -664,13 +670,20 @@ def compare( else: title = 'Duration Curve' if len(resolved_variables) > 1 else f'Duration Curve: {resolved_variables[0]}' - # Use fxplot for smart defaults with line_dash for representation + # Use fxplot for smart defaults + line_kwargs = {} + if line_dash is not None: + line_kwargs['line_dash'] = line_dash + if line_dash == 'representation': + line_kwargs['line_dash_map'] = {'Original': 'dot', 'Clustered': 'solid'} + fig = ds.fxplot.line( colors=colors, + color=color, title=title, facet_col=facet_col, facet_row=facet_row, - line_dash='representation', + **line_kwargs, **plotly_kwargs, ) fig.update_yaxes(matches=None) @@ -846,6 +859,8 @@ def clusters( *, select: SelectType | None = None, colors: ColorType | None = None, + color: str | None = 'auto', + facet_col: str | None = 'cluster', facet_cols: int | None = None, show: bool | None = None, **plotly_kwargs: Any, @@ -860,6 +875,10 @@ def clusters( or None to plot all time-varying variables. select: xarray-style selection dict, e.g. {'scenario': 'Base Case'}. colors: Color specification (colorscale name, color list, or label-to-color dict). + color: Dimension for line colors. 'auto' uses CONFIG priority (typically 'variable'). + Use 'cluster' to color by cluster instead of faceting. + facet_col: Dimension for subplot columns. Defaults to 'cluster'. + Use 'variable' to facet by variable instead. facet_cols: Max columns before wrapping facets. Defaults to CONFIG.Plotting.default_facet_cols. show: Whether to display the figure. 
@@ -912,8 +931,9 @@ def clusters( # Use fxplot for smart defaults fig = ds.fxplot.line( colors=colors, + color=color, title=title, - facet_col='cluster', + facet_col=facet_col, facet_cols=facet_cols, **plotly_kwargs, ) diff --git a/tests/test_cluster_reduce_expand.py b/tests/test_cluster_reduce_expand.py index b54eeb56e..4059470ee 100644 --- a/tests/test_cluster_reduce_expand.py +++ b/tests/test_cluster_reduce_expand.py @@ -167,7 +167,7 @@ def test_expand_solution_enables_statistics_accessor(solver_fixture, timesteps_8 # These should work without errors flow_rates = fs_expanded.statistics.flow_rates assert 'Boiler(Q_th)' in flow_rates - assert len(flow_rates['Boiler(Q_th)'].coords['time']) == 192 + assert len(flow_rates['Boiler(Q_th)'].coords['time']) == 193 # 192 + 1 extra timestep flow_hours = fs_expanded.statistics.flow_hours assert 'Boiler(Q_th)' in flow_hours @@ -321,7 +321,7 @@ def test_cluster_and_expand_with_scenarios(solver_fixture, timesteps_8_days, sce flow_var = 'Boiler(Q_th)|flow_rate' assert flow_var in fs_expanded.solution assert 'scenario' in fs_expanded.solution[flow_var].dims - assert len(fs_expanded.solution[flow_var].coords['time']) == 192 + assert len(fs_expanded.solution[flow_var].coords['time']) == 193 # 192 + 1 extra timestep def test_expand_solution_maps_scenarios_independently(solver_fixture, timesteps_8_days, scenarios_2): @@ -693,7 +693,7 @@ def test_expand_solution_with_periods(self, solver_fixture, timesteps_8_days, pe # Solution should have period dimension flow_var = 'Boiler(Q_th)|flow_rate' assert 'period' in fs_expanded.solution[flow_var].dims - assert len(fs_expanded.solution[flow_var].coords['time']) == 192 + assert len(fs_expanded.solution[flow_var].coords['time']) == 193 # 192 + 1 extra timestep def test_cluster_with_periods_and_scenarios(self, solver_fixture, timesteps_8_days, periods_2, scenarios_2): """Clustering should work with both periods and scenarios.""" @@ -719,7 +719,7 @@ def 
test_cluster_with_periods_and_scenarios(self, solver_fixture, timesteps_8_da fs_expanded = fs_clustered.transform.expand_solution() assert 'period' in fs_expanded.solution[flow_var].dims assert 'scenario' in fs_expanded.solution[flow_var].dims - assert len(fs_expanded.solution[flow_var].coords['time']) == 192 + assert len(fs_expanded.solution[flow_var].coords['time']) == 193 # 192 + 1 extra timestep # ==================== Peak Selection Tests ==================== From ed80f89893be2dcc1503ea808f22a1859f96a17b Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 14:15:35 +0100 Subject: [PATCH 40/62] Drop cluster index when expanding --- flixopt/transform_accessor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index b466d928a..8ec350c95 100644 --- a/flixopt/transform_accessor.py +++ b/flixopt/transform_accessor.py @@ -1271,6 +1271,8 @@ def expand_da(da: xr.DataArray, var_name: str = '') -> xr.DataArray: # Multi-dimensional: select last cluster for each period/scenario slice last_clusters = cluster_order.isel(original_cluster=last_original_cluster_idx) extra_val = da.isel(cluster=last_clusters, time=-1) + # Drop 'cluster' coord created by advanced indexing (non-dim coord from isel) + extra_val = extra_val.drop_vars('cluster', errors='ignore') extra_val = extra_val.expand_dims(time=[original_timesteps_extra[-1]]) expanded = xr.concat([expanded, extra_val], dim='time') From 3dc7eec2d238bedd9179d03854f32ef5fb58cef8 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 15:21:33 +0100 Subject: [PATCH 41/62] Fix storage expansion --- flixopt/transform_accessor.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index 8ec350c95..43b15d440 100644 --- a/flixopt/transform_accessor.py +++ 
b/flixopt/transform_accessor.py @@ -1271,8 +1271,8 @@ def expand_da(da: xr.DataArray, var_name: str = '') -> xr.DataArray: # Multi-dimensional: select last cluster for each period/scenario slice last_clusters = cluster_order.isel(original_cluster=last_original_cluster_idx) extra_val = da.isel(cluster=last_clusters, time=-1) - # Drop 'cluster' coord created by advanced indexing (non-dim coord from isel) - extra_val = extra_val.drop_vars('cluster', errors='ignore') + # Drop 'cluster'/'time' coords created by isel (kept as non-dim coords) + extra_val = extra_val.drop_vars(['cluster', 'time'], errors='ignore') extra_val = extra_val.expand_dims(time=[original_timesteps_extra[-1]]) expanded = xr.concat([expanded, extra_val], dim='time') @@ -1363,10 +1363,15 @@ def expand_da(da: xr.DataArray, var_name: str = '') -> xr.DataArray: time_within_period, dims=['time'], coords={'time': original_timesteps_extra} ) # Decay factor: (1 - loss)^t, using mean loss over time - # Keep as DataArray to respect per-period/scenario values loss_value = storage.relative_loss_per_hour.mean('time') if (loss_value > 0).any(): decay_da = (1 - loss_value) ** time_within_period_da + if 'cluster' in decay_da.dims: + # Map each timestep to its cluster's decay value + cluster_per_timestep = cluster_structure.cluster_order.values[original_cluster_indices] + decay_da = decay_da.isel(cluster=xr.DataArray(cluster_per_timestep, dims=['time'])).drop_vars( + 'cluster', errors='ignore' + ) soc_boundary_per_timestep = soc_boundary_per_timestep * decay_da # Combine: actual_SOC = SOC_boundary * decay + charge_state @@ -1375,6 +1380,14 @@ def expand_da(da: xr.DataArray, var_name: str = '') -> xr.DataArray: combined_charge_state = (expanded_charge_state + soc_boundary_per_timestep).clip(min=0) expanded_fs._solution[charge_state_name] = combined_charge_state.assign_attrs(expanded_charge_state.attrs) + # Remove SOC_boundary variables - they're cluster-specific and now incorporated into charge_state + for 
soc_boundary_name in soc_boundary_vars: + if soc_boundary_name in expanded_fs._solution: + del expanded_fs._solution[soc_boundary_name] + # Also drop the cluster_boundary coordinate (orphaned after removing SOC_boundary) + if 'cluster_boundary' in expanded_fs._solution.coords: + expanded_fs._solution = expanded_fs._solution.drop_vars('cluster_boundary') + n_combinations = len(periods) * len(scenarios) logger.info( f'Expanded FlowSystem from {n_reduced_timesteps} to {n_original_timesteps} timesteps ' From 6b0579f6bb24ac9b0e56c4d1f3b964f0017dc825 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 15:38:41 +0100 Subject: [PATCH 42/62] Improve clustering --- docs/notebooks/08c-clustering.ipynb | 72 +++++++++++++++++++++++------ flixopt/transform_accessor.py | 27 +++++------ 2 files changed, 73 insertions(+), 26 deletions(-) diff --git a/docs/notebooks/08c-clustering.ipynb b/docs/notebooks/08c-clustering.ipynb index d07512cac..4d8b8a121 100644 --- a/docs/notebooks/08c-clustering.ipynb +++ b/docs/notebooks/08c-clustering.ipynb @@ -301,6 +301,50 @@ "cell_type": "markdown", "id": "19", "metadata": {}, + "source": [ + "### Manual Cluster Assignment\n", + "\n", + "When comparing design variants or performing sensitivity analysis, you often want to\n", + "use the **same cluster structure** across different FlowSystem configurations.\n", + "Use `predef_cluster_order` to ensure comparable results:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [], + "source": [ + "# Save the cluster order from our optimized system\n", + "cluster_order = fs_clustered.clustering.cluster_order.values\n", + "print(f'Cluster order to reuse: {cluster_order}')\n", + "\n", + "# Now modify the FlowSystem (e.g., increase storage capacity limits)\n", + "flow_system_modified = flow_system.copy()\n", + "flow_system_modified.components['Storage'].capacity_in_flow_hours.maximum_size = 2000 # 
Larger storage option\n", + "\n", + "# Cluster with the SAME cluster structure for fair comparison\n", + "fs_modified_clustered = flow_system_modified.transform.cluster(\n", + " n_clusters=8,\n", + " cluster_duration='1D',\n", + " predef_cluster_order=cluster_order, # Reuse cluster assignments\n", + ")\n", + "\n", + "# Optimize the modified system\n", + "fs_modified_clustered.optimize(solver)\n", + "\n", + "print('\\nComparison (same cluster structure):')\n", + "print(f' Original storage size: {fs_clustered.statistics.sizes[\"Storage\"].item():.0f}')\n", + "print(f' Modified storage size: {fs_modified_clustered.statistics.sizes[\"Storage\"].item():.0f}')\n", + "print(f' Original cost: {fs_clustered.solution[\"costs\"].item():,.0f} €')\n", + "print(f' Modified cost: {fs_modified_clustered.solution[\"costs\"].item():,.0f} €')" + ] + }, + { + "cell_type": "markdown", + "id": "21", + "metadata": {}, "source": [ "## Method 3: Two-Stage Workflow (Recommended)\n", "\n", @@ -317,7 +361,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -338,7 +382,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -361,7 +405,7 @@ }, { "cell_type": "markdown", - "id": "22", + "id": "24", "metadata": {}, "source": [ "## Compare Results" @@ -370,7 +414,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -419,7 +463,7 @@ }, { "cell_type": "markdown", - "id": "24", + "id": "26", "metadata": {}, "source": [ "## Expand Solution to Full Resolution\n", @@ -431,7 +475,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "27", "metadata": {}, "outputs": [], "source": [ @@ -445,7 +489,7 @@ { "cell_type": "code", "execution_count": null, - "id": "26", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -467,7 +511,7 @@ }, { "cell_type": "markdown", 
- "id": "27", + "id": "29", "metadata": {}, "source": [ "## Visualize Clustered Heat Balance" @@ -476,7 +520,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28", + "id": "30", "metadata": {}, "outputs": [], "source": [ @@ -486,16 +530,16 @@ { "cell_type": "code", "execution_count": null, - "id": "29", + "id": "31", "metadata": {}, "outputs": [], "source": [ - "fs_expanded.statistics.plot.storage('Storage').data.to_dataframe()" + "fs_expanded.statistics.plot.storage('Storage')" ] }, { "cell_type": "markdown", - "id": "30", + "id": "32", "metadata": {}, "source": [ "## API Reference\n", @@ -548,7 +592,7 @@ }, { "cell_type": "markdown", - "id": "31", + "id": "33", "metadata": {}, "source": [ "## Summary\n", @@ -561,6 +605,7 @@ "- **Expand solutions** back to full resolution with `expand_solution()`\n", "- Access **clustering metadata** via `fs.clustering` (metrics, cluster_order, occurrences)\n", "- Use **advanced options** like different algorithms and reproducible random states\n", + "- **Manually assign clusters** using `predef_cluster_order`\n", "\n", "### Key Takeaways\n", "\n", @@ -570,6 +615,7 @@ "4. **Storage handling** is configurable via `cluster_mode`\n", "5. **Use `random_state`** for reproducible results\n", "6. **Check metrics** to evaluate clustering quality\n", + "7. 
**Use `predef_cluster_order`** to reproduce or define custom cluster assignments\n", "\n", "### Next Steps\n", "\n", diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index 43b15d440..e3a41a3ba 100644 --- a/flixopt/transform_accessor.py +++ b/flixopt/transform_accessor.py @@ -582,13 +582,11 @@ def cluster( weights: dict[str, float] | None = None, time_series_for_high_peaks: list[str] | None = None, time_series_for_low_peaks: list[str] | None = None, - cluster_method: Literal['k_means', 'k_medoids', 'hierarchical', 'k_maxoids', 'averaging'] = 'k_means', + cluster_method: Literal['k_means', 'k_medoids', 'hierarchical', 'k_maxoids', 'averaging'] = 'hierarchical', representation_method: Literal[ 'meanRepresentation', 'medoidRepresentation', 'distributionAndMinMaxRepresentation' - ] = 'meanRepresentation', - extreme_period_method: Literal[ - 'None', 'append', 'new_cluster_center', 'replace_cluster_center' - ] = 'new_cluster_center', + ] = 'medoidRepresentation', + extreme_period_method: Literal['append', 'new_cluster_center', 'replace_cluster_center'] | None = None, rescale_cluster_periods: bool = True, random_state: int | None = None, predef_cluster_order: xr.DataArray | np.ndarray | list[int] | None = None, @@ -602,7 +600,7 @@ def cluster( through time series aggregation using the tsam package. The method: - 1. Performs time series clustering using tsam (k-means) + 1. Performs time series clustering using tsam (hierarchical by default) 2. Extracts only the typical clusters (not all original timesteps) 3. Applies timestep weighting for accurate cost representation 4. Handles storage states between clusters based on each Storage's ``cluster_mode`` @@ -619,18 +617,19 @@ def cluster( clusters. **Recommended** for demand time series to capture peak demand days. time_series_for_low_peaks: Time series labels for explicitly selecting low-value clusters. cluster_method: Clustering algorithm to use. 
Options: - ``'k_means'`` (default), ``'k_medoids'``, ``'hierarchical'``, + ``'hierarchical'`` (default), ``'k_means'``, ``'k_medoids'``, ``'k_maxoids'``, ``'averaging'``. representation_method: How cluster representatives are computed. Options: - ``'meanRepresentation'`` (default), ``'medoidRepresentation'``, + ``'medoidRepresentation'`` (default), ``'meanRepresentation'``, ``'distributionAndMinMaxRepresentation'``. extreme_period_method: How extreme periods (peaks) are integrated. Options: - ``'new_cluster_center'`` (default), ``'None'``, ``'append'``, - ``'replace_cluster_center'``. + ``None`` (default, no special handling), ``'append'``, + ``'new_cluster_center'``, ``'replace_cluster_center'``. rescale_cluster_periods: If True (default), rescale cluster periods so their weighted mean matches the original time series mean. - random_state: Random seed for reproducible clustering results. If None, - results may vary between runs. + random_state: Random seed for reproducible clustering results. Only relevant + for non-deterministic methods like ``'k_means'``. The default + ``'hierarchical'`` method is deterministic. predef_cluster_order: Predefined cluster assignments for manual clustering. Array of cluster indices (0 to n_clusters-1) for each original period. If provided, clustering is skipped and these assignments are used directly. 
@@ -743,13 +742,15 @@ def cluster( # Use tsam directly clustering_weights = weights or self._calculate_clustering_weights(temporaly_changing_ds) + # tsam expects 'None' as a string, not Python None + tsam_extreme_method = 'None' if extreme_period_method is None else extreme_period_method tsam_agg = tsam.TimeSeriesAggregation( df, noTypicalPeriods=n_clusters, hoursPerPeriod=hours_per_cluster, resolution=dt, clusterMethod=cluster_method, - extremePeriodMethod=extreme_period_method, + extremePeriodMethod=tsam_extreme_method, representationMethod=representation_method, rescaleClusterPeriods=rescale_cluster_periods, predefClusterOrder=predef_order_slice, From b2539d86a166c0473aedf46e95ca02df9a3d4286 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 19:45:41 +0100 Subject: [PATCH 43/62] fix scatter plot faceting --- flixopt/dataset_plot_accessor.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 73b20b436..e2802cb04 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -590,13 +590,15 @@ def scatter( if ylabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), y: ylabel} - if actual_facet_col: + # Only use facets if the column actually exists in the dataframe + # (scatter uses wide format, so 'variable' column doesn't exist) + if actual_facet_col and actual_facet_col in df.columns: fig_kwargs['facet_col'] = actual_facet_col if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row: + if actual_facet_row and actual_facet_row in df.columns: fig_kwargs['facet_row'] = actual_facet_row - if actual_anim: + if actual_anim and actual_anim in df.columns: fig_kwargs['animation_frame'] = actual_anim return px.scatter(**fig_kwargs) From e48ff177f177f3f7e4b6584a44f80abd739cc07f Mon Sep 17 00:00:00 2001 
From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 19:48:56 +0100 Subject: [PATCH 44/62] =?UTF-8?q?=E2=8F=BA=20Fixed=20the=20documentation?= =?UTF-8?q?=20in=20the=20notebook:?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Cell 32 (API Reference table): Updated defaults to 'hierarchical', 'medoidRepresentation', and None 2. Cell 16: Swapped the example to show k_means as the alternative (since hierarchical is now default) 3. Cell 17: Updated variable names to match 4. Cell 33 (Key Takeaways): Clarified that random_state is only needed for non-deterministic methods like 'k_means' The code review --- docs/notebooks/08c-clustering.ipynb | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/notebooks/08c-clustering.ipynb b/docs/notebooks/08c-clustering.ipynb index 4d8b8a121..35b91f6eb 100644 --- a/docs/notebooks/08c-clustering.ipynb +++ b/docs/notebooks/08c-clustering.ipynb @@ -257,16 +257,16 @@ "outputs": [], "source": [ "# Try different clustering algorithms\n", - "fs_hierarchical = flow_system.transform.cluster(\n", + "fs_kmeans = flow_system.transform.cluster(\n", " n_clusters=8,\n", " cluster_duration='1D',\n", - " cluster_method='hierarchical', # Alternative: 'k_means' (default), 'k_medoids', 'averaging'\n", + " cluster_method='k_means', # Alternative: 'hierarchical' (default), 'k_medoids', 'averaging'\n", " random_state=42,\n", ")\n", "\n", "# Compare cluster assignments between algorithms\n", - "print('k_means clusters: ', fs_clustered.clustering.cluster_order.values)\n", - "print('hierarchical clusters:', fs_hierarchical.clustering.cluster_order.values)" + "print('hierarchical clusters:', fs_clustered.clustering.cluster_order.values)\n", + "print('k_means clusters: ', fs_kmeans.clustering.cluster_order.values)" ] }, { @@ -279,10 +279,10 @@ "# Compare RMSE between algorithms\n", "print('Quality comparison (RMSE for HeatDemand):')\n", 
"print(\n", - " f' k_means: {float(fs_clustered.clustering.metrics[\"RMSE\"].sel(time_series=\"HeatDemand(Q_th)|fixed_relative_profile\")):.4f}'\n", + " f' hierarchical: {float(fs_clustered.clustering.metrics[\"RMSE\"].sel(time_series=\"HeatDemand(Q_th)|fixed_relative_profile\")):.4f}'\n", ")\n", "print(\n", - " f' hierarchical: {float(fs_hierarchical.clustering.metrics[\"RMSE\"].sel(time_series=\"HeatDemand(Q_th)|fixed_relative_profile\")):.4f}'\n", + " f' k_means: {float(fs_kmeans.clustering.metrics[\"RMSE\"].sel(time_series=\"HeatDemand(Q_th)|fixed_relative_profile\")):.4f}'\n", ")" ] }, @@ -553,11 +553,11 @@ "| `weights` | `dict[str, float]` | None | Optional weights for time series in clustering |\n", "| `time_series_for_high_peaks` | `list[str]` | None | **Essential**: Force inclusion of peak periods |\n", "| `time_series_for_low_peaks` | `list[str]` | None | Force inclusion of minimum periods |\n", - "| `cluster_method` | `str` | 'k_means' | Algorithm: 'k_means', 'hierarchical', 'k_medoids', 'k_maxoids', 'averaging' |\n", - "| `representation_method` | `str` | 'meanRepresentation' | 'meanRepresentation', 'medoidRepresentation', 'distributionAndMinMaxRepresentation' |\n", - "| `extreme_period_method` | `str` | 'new_cluster_center' | How peaks are integrated: 'None', 'append', 'new_cluster_center', 'replace_cluster_center' |\n", + "| `cluster_method` | `str` | 'hierarchical' | Algorithm: 'hierarchical', 'k_means', 'k_medoids', 'k_maxoids', 'averaging' |\n", + "| `representation_method` | `str` | 'medoidRepresentation' | 'medoidRepresentation', 'meanRepresentation', 'distributionAndMinMaxRepresentation' |\n", + "| `extreme_period_method` | `str \\| None` | None | How peaks are integrated: None, 'append', 'new_cluster_center', 'replace_cluster_center' |\n", "| `rescale_cluster_periods` | `bool` | True | Rescale clusters to match original means |\n", - "| `random_state` | `int` | None | Random seed for reproducibility |\n", + "| `random_state` | `int` | None | 
Random seed for reproducibility (only needed for non-deterministic methods like 'k_means') |\n", "| `predef_cluster_order` | `array` | None | Manual cluster assignments |\n", "| `**tsam_kwargs` | - | - | Additional tsam parameters |\n", "\n", @@ -613,7 +613,7 @@ "2. **Add safety margin** (5-10%) when fixing sizes from clustering\n", "3. **Two-stage is recommended**: clustering for sizing, full resolution for dispatch\n", "4. **Storage handling** is configurable via `cluster_mode`\n", - "5. **Use `random_state`** for reproducible results\n", + "5. **Use `random_state`** for reproducible results when using non-deterministic methods like 'k_means' (the default 'hierarchical' is deterministic)\n", "6. **Check metrics** to evaluate clustering quality\n", "7. **Use `predef_cluster_order`** to reproduce or define custom cluster assignments\n", "\n", From 285e07b59b6550be0f45ec8e44d09ed4c26cf623 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 19:53:30 +0100 Subject: [PATCH 45/62] 1. Error handling for accuracyIndicators() - Added try/except with warning log and empty DataFrame fallback, plus handling empty DataFrames when building the metrics Dataset 2. Random state to tsam - Replaced global np.random.seed() with passing seed parameter directly to tsam's TimeSeriesAggregation 3. tsam_kwargs conflict validation - Added validation that raises ValueError if user tries to override explicit parameters via **tsam_kwargs (including seed) 4. predef_cluster_order validation - Added dimension validation for DataArray inputs, checking they match the FlowSystem's period/scenario structure 5. 
Out-of-bounds fix - Clamped last_original_cluster_idx to n_original_clusters - 1 to handle partial clusters at the end --- flixopt/transform_accessor.py | 104 +++++++++++++++++++++++++--------- 1 file changed, 78 insertions(+), 26 deletions(-) diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index e3a41a3ba..210725d7d 100644 --- a/flixopt/transform_accessor.py +++ b/flixopt/transform_accessor.py @@ -707,15 +707,46 @@ def cluster( ds = self._fs.to_dataset(include_solution=False) + # Validate tsam_kwargs doesn't override explicit parameters + reserved_tsam_keys = { + 'noTypicalPeriods', + 'hoursPerPeriod', + 'resolution', + 'clusterMethod', + 'extremePeriodMethod', + 'representationMethod', + 'rescaleClusterPeriods', + 'predefClusterOrder', + 'weightDict', + 'addPeakMax', + 'addPeakMin', + 'seed', # Controlled by random_state parameter + } + conflicts = reserved_tsam_keys & set(tsam_kwargs.keys()) + if conflicts: + raise ValueError( + f'Cannot override explicit parameters via tsam_kwargs: {conflicts}. ' + f'Use the corresponding cluster() parameters instead.' + ) + + # Validate predef_cluster_order dimensions if it's a DataArray + if isinstance(predef_cluster_order, xr.DataArray): + expected_dims = {'original_cluster'} + if has_periods: + expected_dims.add('period') + if has_scenarios: + expected_dims.add('scenario') + if set(predef_cluster_order.dims) != expected_dims: + raise ValueError( + f'predef_cluster_order dimensions {set(predef_cluster_order.dims)} ' + f'do not match expected {expected_dims} for this FlowSystem.' 
+ ) + # Cluster each (period, scenario) combination using tsam directly tsam_results: dict[tuple, tsam.TimeSeriesAggregation] = {} cluster_orders: dict[tuple, np.ndarray] = {} cluster_occurrences_all: dict[tuple, dict] = {} - # Set random seed for reproducibility - if random_state is not None: - np.random.seed(random_state) - # Collect metrics per (period, scenario) slice clustering_metrics_all: dict[tuple, pd.DataFrame] = {} @@ -744,21 +775,24 @@ def cluster( clustering_weights = weights or self._calculate_clustering_weights(temporaly_changing_ds) # tsam expects 'None' as a string, not Python None tsam_extreme_method = 'None' if extreme_period_method is None else extreme_period_method - tsam_agg = tsam.TimeSeriesAggregation( - df, - noTypicalPeriods=n_clusters, - hoursPerPeriod=hours_per_cluster, - resolution=dt, - clusterMethod=cluster_method, - extremePeriodMethod=tsam_extreme_method, - representationMethod=representation_method, - rescaleClusterPeriods=rescale_cluster_periods, - predefClusterOrder=predef_order_slice, - weightDict={name: w for name, w in clustering_weights.items() if name in df.columns}, - addPeakMax=time_series_for_high_peaks or [], - addPeakMin=time_series_for_low_peaks or [], - **tsam_kwargs, - ) + # Build tsam kwargs, including random_state if provided + tsam_init_kwargs: dict[str, Any] = { + 'noTypicalPeriods': n_clusters, + 'hoursPerPeriod': hours_per_cluster, + 'resolution': dt, + 'clusterMethod': cluster_method, + 'extremePeriodMethod': tsam_extreme_method, + 'representationMethod': representation_method, + 'rescaleClusterPeriods': rescale_cluster_periods, + 'predefClusterOrder': predef_order_slice, + 'weightDict': {name: w for name, w in clustering_weights.items() if name in df.columns}, + 'addPeakMax': time_series_for_high_peaks or [], + 'addPeakMin': time_series_for_low_peaks or [], + } + # Pass random_state to tsam instead of setting global np.random.seed() + if random_state is not None: + tsam_init_kwargs['seed'] = random_state + 
tsam_agg = tsam.TimeSeriesAggregation(df, **tsam_init_kwargs, **tsam_kwargs) # Suppress tsam warning about minimal value constraints (informational, not actionable) with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=UserWarning, message='.*minimal value.*exceeds.*') @@ -767,16 +801,26 @@ def cluster( tsam_results[key] = tsam_agg cluster_orders[key] = tsam_agg.clusterOrder cluster_occurrences_all[key] = tsam_agg.clusterPeriodNoOccur - clustering_metrics_all[key] = tsam_agg.accuracyIndicators() + # Compute accuracy metrics with error handling + try: + clustering_metrics_all[key] = tsam_agg.accuracyIndicators() + except Exception as e: + logger.warning(f'Failed to compute clustering metrics for {key}: {e}') + clustering_metrics_all[key] = pd.DataFrame() # Use first result for structure first_key = (periods[0], scenarios[0]) first_tsam = tsam_results[first_key] # Convert metrics to xr.Dataset with period/scenario dims if multi-dimensional - if len(clustering_metrics_all) == 1: + # Filter out empty DataFrames (from failed accuracyIndicators calls) + non_empty_metrics = {k: v for k, v in clustering_metrics_all.items() if not v.empty} + if not non_empty_metrics: + # All metrics failed - create empty Dataset + clustering_metrics = xr.Dataset() + elif len(non_empty_metrics) == 1 or len(clustering_metrics_all) == 1: # Simple case: convert single DataFrame to Dataset - metrics_df = clustering_metrics_all[first_key] + metrics_df = non_empty_metrics.get(first_key) or next(iter(non_empty_metrics.values())) clustering_metrics = xr.Dataset( { col: xr.DataArray( @@ -787,8 +831,8 @@ def cluster( ) else: # Multi-dim case: combine metrics into Dataset with period/scenario dims - # First, get the metric columns from any DataFrame - sample_df = next(iter(clustering_metrics_all.values())) + # First, get the metric columns from any non-empty DataFrame + sample_df = next(iter(non_empty_metrics.values())) metric_names = list(sample_df.columns) time_series_names = 
list(sample_df.index) @@ -798,7 +842,11 @@ def cluster( # Shape: (time_series, period?, scenario?) slices = {} for (p, s), df in clustering_metrics_all.items(): - slices[(p, s)] = xr.DataArray(df[metric].values, dims=['time_series']) + if df.empty: + # Use NaN for failed metrics + slices[(p, s)] = xr.DataArray(np.full(len(time_series_names), np.nan), dims=['time_series']) + else: + slices[(p, s)] = xr.DataArray(df[metric].values, dims=['time_series']) da = self._combine_slices_to_dataarray_generic(slices, ['time_series'], periods, scenarios, metric) da = da.assign_coords(time_series=time_series_names) @@ -1254,7 +1302,11 @@ def expand_solution(self) -> FlowSystem: # Expand function using ClusterResult.expand_data() - handles multi-dimensional cases # For charge_state with cluster dim, also includes the extra timestep - last_original_cluster_idx = (n_original_timesteps - 1) // timesteps_per_cluster + # Clamp to valid bounds to handle partial clusters at the end + last_original_cluster_idx = min( + (n_original_timesteps - 1) // timesteps_per_cluster, + n_original_clusters - 1, + ) def expand_da(da: xr.DataArray, var_name: str = '') -> xr.DataArray: if 'time' not in da.dims: From c126115dfa00df77ce78adca1755e796ab421ec5 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 20:43:47 +0100 Subject: [PATCH 46/62] 1. DataFrame truth ambiguity - Changed non_empty_metrics.get(first_key) or next(...) to explicit if metrics_df is None: check 2. 
removed random state --- docs/notebooks/08c-clustering.ipynb | 10 ++---- flixopt/transform_accessor.py | 42 ++++++++++------------- tests/test_clustering/test_integration.py | 10 +++--- 3 files changed, 26 insertions(+), 36 deletions(-) diff --git a/docs/notebooks/08c-clustering.ipynb b/docs/notebooks/08c-clustering.ipynb index 35b91f6eb..9676b6992 100644 --- a/docs/notebooks/08c-clustering.ipynb +++ b/docs/notebooks/08c-clustering.ipynb @@ -147,7 +147,6 @@ " n_clusters=8, # 8 typical days\n", " cluster_duration='1D', # Daily clustering\n", " time_series_for_high_peaks=peak_series, # Capture peak demand day\n", - " random_state=42, # Reproducible results\n", ")\n", "\n", "time_clustering = timeit.default_timer() - start\n", @@ -261,7 +260,6 @@ " n_clusters=8,\n", " cluster_duration='1D',\n", " cluster_method='k_means', # Alternative: 'hierarchical' (default), 'k_medoids', 'averaging'\n", - " random_state=42,\n", ")\n", "\n", "# Compare cluster assignments between algorithms\n", @@ -557,7 +555,6 @@ "| `representation_method` | `str` | 'medoidRepresentation' | 'medoidRepresentation', 'meanRepresentation', 'distributionAndMinMaxRepresentation' |\n", "| `extreme_period_method` | `str \\| None` | None | How peaks are integrated: None, 'append', 'new_cluster_center', 'replace_cluster_center' |\n", "| `rescale_cluster_periods` | `bool` | True | Rescale clusters to match original means |\n", - "| `random_state` | `int` | None | Random seed for reproducibility (only needed for non-deterministic methods like 'k_means') |\n", "| `predef_cluster_order` | `array` | None | Manual cluster assignments |\n", "| `**tsam_kwargs` | - | - | Additional tsam parameters |\n", "\n", @@ -604,7 +601,7 @@ "- Use **two-stage optimization** for fast yet accurate investment decisions\n", "- **Expand solutions** back to full resolution with `expand_solution()`\n", "- Access **clustering metadata** via `fs.clustering` (metrics, cluster_order, occurrences)\n", - "- Use **advanced options** like 
different algorithms and reproducible random states\n", + "- Use **advanced options** like different algorithms\n", "- **Manually assign clusters** using `predef_cluster_order`\n", "\n", "### Key Takeaways\n", @@ -613,9 +610,8 @@ "2. **Add safety margin** (5-10%) when fixing sizes from clustering\n", "3. **Two-stage is recommended**: clustering for sizing, full resolution for dispatch\n", "4. **Storage handling** is configurable via `cluster_mode`\n", - "5. **Use `random_state`** for reproducible results when using non-deterministic methods like 'k_means' (the default 'hierarchical' is deterministic)\n", - "6. **Check metrics** to evaluate clustering quality\n", - "7. **Use `predef_cluster_order`** to reproduce or define custom cluster assignments\n", + "5. **Check metrics** to evaluate clustering quality\n", + "6. **Use `predef_cluster_order`** to reproduce or define custom cluster assignments\n", "\n", "### Next Steps\n", "\n", diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index 210725d7d..6a5b51caa 100644 --- a/flixopt/transform_accessor.py +++ b/flixopt/transform_accessor.py @@ -588,7 +588,6 @@ def cluster( ] = 'medoidRepresentation', extreme_period_method: Literal['append', 'new_cluster_center', 'replace_cluster_center'] | None = None, rescale_cluster_periods: bool = True, - random_state: int | None = None, predef_cluster_order: xr.DataArray | np.ndarray | list[int] | None = None, **tsam_kwargs: Any, ) -> FlowSystem: @@ -627,9 +626,6 @@ def cluster( ``'new_cluster_center'``, ``'replace_cluster_center'``. rescale_cluster_periods: If True (default), rescale cluster periods so their weighted mean matches the original time series mean. - random_state: Random seed for reproducible clustering results. Only relevant - for non-deterministic methods like ``'k_means'``. The default - ``'hierarchical'`` method is deterministic. predef_cluster_order: Predefined cluster assignments for manual clustering. 
Array of cluster indices (0 to n_clusters-1) for each original period. If provided, clustering is skipped and these assignments are used directly. @@ -720,7 +716,6 @@ def cluster( 'weightDict', 'addPeakMax', 'addPeakMin', - 'seed', # Controlled by random_state parameter } conflicts = reserved_tsam_keys & set(tsam_kwargs.keys()) if conflicts: @@ -775,24 +770,21 @@ def cluster( clustering_weights = weights or self._calculate_clustering_weights(temporaly_changing_ds) # tsam expects 'None' as a string, not Python None tsam_extreme_method = 'None' if extreme_period_method is None else extreme_period_method - # Build tsam kwargs, including random_state if provided - tsam_init_kwargs: dict[str, Any] = { - 'noTypicalPeriods': n_clusters, - 'hoursPerPeriod': hours_per_cluster, - 'resolution': dt, - 'clusterMethod': cluster_method, - 'extremePeriodMethod': tsam_extreme_method, - 'representationMethod': representation_method, - 'rescaleClusterPeriods': rescale_cluster_periods, - 'predefClusterOrder': predef_order_slice, - 'weightDict': {name: w for name, w in clustering_weights.items() if name in df.columns}, - 'addPeakMax': time_series_for_high_peaks or [], - 'addPeakMin': time_series_for_low_peaks or [], - } - # Pass random_state to tsam instead of setting global np.random.seed() - if random_state is not None: - tsam_init_kwargs['seed'] = random_state - tsam_agg = tsam.TimeSeriesAggregation(df, **tsam_init_kwargs, **tsam_kwargs) + tsam_agg = tsam.TimeSeriesAggregation( + df, + noTypicalPeriods=n_clusters, + hoursPerPeriod=hours_per_cluster, + resolution=dt, + clusterMethod=cluster_method, + extremePeriodMethod=tsam_extreme_method, + representationMethod=representation_method, + rescaleClusterPeriods=rescale_cluster_periods, + predefClusterOrder=predef_order_slice, + weightDict={name: w for name, w in clustering_weights.items() if name in df.columns}, + addPeakMax=time_series_for_high_peaks or [], + addPeakMin=time_series_for_low_peaks or [], + **tsam_kwargs, + ) # Suppress 
tsam warning about minimal value constraints (informational, not actionable) with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=UserWarning, message='.*minimal value.*exceeds.*') @@ -820,7 +812,9 @@ def cluster( clustering_metrics = xr.Dataset() elif len(non_empty_metrics) == 1 or len(clustering_metrics_all) == 1: # Simple case: convert single DataFrame to Dataset - metrics_df = non_empty_metrics.get(first_key) or next(iter(non_empty_metrics.values())) + metrics_df = non_empty_metrics.get(first_key) + if metrics_df is None: + metrics_df = next(iter(non_empty_metrics.values())) clustering_metrics = xr.Dataset( { col: xr.DataArray( diff --git a/tests/test_clustering/test_integration.py b/tests/test_clustering/test_integration.py index d6dd3d2e7..16c638c95 100644 --- a/tests/test_clustering/test_integration.py +++ b/tests/test_clustering/test_integration.py @@ -201,12 +201,12 @@ def test_cluster_method_parameter(self, basic_flow_system): ) assert len(fs_clustered.clusters) == 2 - def test_random_state_reproducibility(self, basic_flow_system): - """Test that random_state produces reproducible results.""" - fs1 = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D', random_state=42) - fs2 = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D', random_state=42) + def test_hierarchical_is_deterministic(self, basic_flow_system): + """Test that hierarchical clustering (default) produces deterministic results.""" + fs1 = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D') + fs2 = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D') - # Same random state should produce identical cluster orders + # Hierarchical clustering should produce identical cluster orders xr.testing.assert_equal(fs1.clustering.cluster_order, fs2.clustering.cluster_order) def test_metrics_available(self, basic_flow_system): From 14887214f59bea1d60b66cafeebd7d884e307798 Mon Sep 17 00:00:00 2001 From: 
FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 4 Jan 2026 21:51:45 +0100 Subject: [PATCH 47/62] Fix pie plot animation frame and add warnings for unassigned dims --- flixopt/dataset_plot_accessor.py | 37 ++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index e2802cb04..23c88bdb7 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -2,6 +2,7 @@ from __future__ import annotations +import logging from typing import Any, Literal import pandas as pd @@ -12,6 +13,8 @@ from .color_processing import ColorType, process_colors from .config import CONFIG +logger = logging.getLogger('flixopt') + def _get_x_dim(dims: list[str], n_data_vars: int = 1, x: str | Literal['auto'] | None = 'auto') -> str: """Select x-axis dim from priority list, or 'variable' for scalar data. @@ -93,6 +96,25 @@ def _resolve_auto_facets( used.add(next_dim) results[slot_name] = next_dim + # Warn if any dimensions were not assigned to any slot + # Only count slots that were available (passed as 'auto' or explicit dim, not None) + available_slot_count = sum(1 for v in slots.values() if v is not None) + unassigned = available - used + if unassigned: + if available_slot_count < 4: + # Some slots weren't available (e.g., pie doesn't support animation_frame) + unavailable_slots = [k for k, v in slots.items() if v is None] + logger.warning( + f'Dimensions {unassigned} not assigned to any plot dimension. ' + f'Not available for this plot type: {unavailable_slots}. ' + f'Reduce dimensions before plotting (e.g., .sel(), .isel(), .mean()).' + ) + else: + logger.warning( + f'Dimensions {unassigned} not assigned to color/facet/animation. ' + f'Reduce dimensions before plotting (e.g., .sel(), .isel(), .mean()).' 
+ ) + return results['color'], results['facet_col'], results['facet_row'], results['animation_frame'] @@ -610,21 +632,22 @@ def pie( title: str = '', facet_col: str | Literal['auto'] | None = 'auto', facet_row: str | Literal['auto'] | None = 'auto', - animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a pie chart from aggregated dataset values. - Extra dimensions are auto-assigned to facet_col, facet_row, and animation_frame. + Extra dimensions are auto-assigned to facet_col and facet_row. For scalar values, a single pie is shown. + Note: + ``px.pie()`` does not support animation_frame, so only facets are available. + Args: colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. facet_col: Dimension for column facets. 'auto' uses CONFIG priority. facet_row: Dimension for row facets. 'auto' uses CONFIG priority. - animation_frame: Dimension for animation slider. 'auto' uses CONFIG priority. facet_cols: Number of columns in facet grid wrap. **px_kwargs: Additional arguments passed to plotly.express.pie. 
@@ -654,14 +677,12 @@ def pie( **px_kwargs, ) - # Multi-dimensional case - faceted/animated pies + # Multi-dimensional case - faceted pies (px.pie doesn't support animation_frame) df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - _, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, None, facet_col, facet_row, animation_frame - ) + _, actual_facet_col, actual_facet_row, _ = _resolve_auto_facets(self._ds, None, facet_col, facet_row, None) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { @@ -680,8 +701,6 @@ def pie( fig_kwargs['facet_col_wrap'] = facet_col_wrap if actual_facet_row: fig_kwargs['facet_row'] = actual_facet_row - if actual_anim: - fig_kwargs['animation_frame'] = actual_anim return px.pie(**fig_kwargs) From e18966bc8db978444e841453520c27aeb94a4f31 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 11:43:40 +0100 Subject: [PATCH 48/62] Change logger warning to regular warning --- flixopt/dataset_plot_accessor.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 735393da8..9377aed2e 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -2,7 +2,7 @@ from __future__ import annotations -import logging +import warnings from typing import Any, Literal import pandas as pd @@ -13,8 +13,6 @@ from .color_processing import ColorType, process_colors from .config import CONFIG -logger = logging.getLogger('flixopt') - def _get_x_dim(dims: list[str], n_data_vars: int = 1, x: str | Literal['auto'] | None = 'auto') -> str: """Select x-axis dim from priority list, or 'variable' for scalar data. 
@@ -104,15 +102,17 @@ def _resolve_auto_facets( if available_slot_count < 4: # Some slots weren't available (e.g., pie doesn't support animation_frame) unavailable_slots = [k for k, v in slots.items() if v is None] - logger.warning( + warnings.warn( f'Dimensions {unassigned} not assigned to any plot dimension. ' f'Not available for this plot type: {unavailable_slots}. ' - f'Reduce dimensions before plotting (e.g., .sel(), .isel(), .mean()).' + f'Reduce dimensions before plotting (e.g., .sel(), .isel(), .mean()).', + stacklevel=3, ) else: - logger.warning( + warnings.warn( f'Dimensions {unassigned} not assigned to color/facet/animation. ' - f'Reduce dimensions before plotting (e.g., .sel(), .isel(), .mean()).' + f'Reduce dimensions before plotting (e.g., .sel(), .isel(), .mean()).', + stacklevel=3, ) return results['color'], results['facet_col'], results['facet_row'], results['animation_frame'] From 87ce35139c70e86ba5ecda8669c45566f545c1c6 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 12:04:37 +0100 Subject: [PATCH 49/62] =?UTF-8?q?=E2=8F=BA=20The=20centralized=20slot=20as?= =?UTF-8?q?signment=20system=20is=20now=20complete.=20Here's=20a=20summary?= =?UTF-8?q?=20of=20the=20changes=20made:?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changes Made 1. flixopt/config.py - Replaced three separate config attributes (extra_dim_priority, dim_slot_priority, x_dim_priority) with a single unified dim_priority tuple - Updated CONFIG.Plotting class docstring and attribute definitions - Updated to_dict() method to use the new attribute - The new priority order: ('time', 'duration', 'duration_pct', 'variable', 'cluster', 'period', 'scenario') 2. 
flixopt/dataset_plot_accessor.py - Created new assign_slots() function that centralizes all dimension-to-slot assignment logic - Fixed slot fill order: x → color → facet_col → facet_row → animation_frame - Updated all plot methods (bar, stacked_bar, line, area, heatmap, scatter, pie) to use assign_slots() - Removed old _get_x_dim() and _resolve_auto_facets() functions - Updated docstrings to reference dim_priority instead of x_dim_priority 3. flixopt/statistics_accessor.py - Updated _resolve_auto_facets() to use the new assign_slots() function internally - Added import for assign_slots from dataset_plot_accessor Key Design Decisions - Single priority list controls all auto-assignment - Slots are filled in fixed order based on availability - None means a slot is not available for that plot type - 'auto' triggers auto-assignment from priority list - Explicit string values override auto-assignment --- flixopt/config.py | 23 +-- flixopt/dataset_plot_accessor.py | 328 +++++++++++++++---------------- flixopt/statistics_accessor.py | 39 +--- 3 files changed, 171 insertions(+), 219 deletions(-) diff --git a/flixopt/config.py b/flixopt/config.py index 9793f9ba2..fce943eb1 100644 --- a/flixopt/config.py +++ b/flixopt/config.py @@ -164,9 +164,7 @@ def format(self, record): 'default_sequential_colorscale': 'turbo', 'default_qualitative_colorscale': 'plotly', 'default_line_shape': 'hv', - 'extra_dim_priority': ('variable', 'cluster', 'period', 'scenario'), - 'dim_slot_priority': ('color', 'facet_col', 'facet_row', 'animation_frame'), - 'x_dim_priority': ('time', 'duration', 'duration_pct', 'variable', 'period', 'scenario', 'cluster'), + 'dim_priority': ('time', 'duration', 'duration_pct', 'variable', 'cluster', 'period', 'scenario'), } ), 'solving': MappingProxyType( @@ -562,9 +560,9 @@ class Plotting: default_facet_cols: Default number of columns for faceted plots. default_sequential_colorscale: Default colorscale for heatmaps and continuous data. 
default_qualitative_colorscale: Default colormap for categorical plots (bar/line/area charts). - extra_dim_priority: Order of extra dimensions when auto-assigning to slots. - dim_slot_priority: Order of slots to fill with extra dimensions. - x_dim_priority: Order of dimensions to prefer for x-axis when 'auto'. + dim_priority: Priority order for assigning dimensions to plot slots (x, color, facet, etc.). + Dimensions are assigned to slots in order: x → y → color → facet_col → facet_row → animation_frame. + 'value' represents the y-axis values (from data_var names after melting). Examples: ```python @@ -573,9 +571,8 @@ class Plotting: CONFIG.Plotting.default_sequential_colorscale = 'plasma' CONFIG.Plotting.default_qualitative_colorscale = 'Dark24' - # Customize dimension handling for faceting - CONFIG.Plotting.extra_dim_priority = ('scenario', 'period', 'cluster') - CONFIG.Plotting.dim_slot_priority = ('facet_row', 'facet_col', 'animation_frame') + # Customize dimension priority for auto-assignment + CONFIG.Plotting.dim_priority = ('time', 'scenario', 'variable', 'period', 'cluster') ``` """ @@ -586,9 +583,7 @@ class Plotting: default_sequential_colorscale: str = _DEFAULTS['plotting']['default_sequential_colorscale'] default_qualitative_colorscale: str = _DEFAULTS['plotting']['default_qualitative_colorscale'] default_line_shape: str = _DEFAULTS['plotting']['default_line_shape'] - extra_dim_priority: tuple[str, ...] = _DEFAULTS['plotting']['extra_dim_priority'] - dim_slot_priority: tuple[str, ...] = _DEFAULTS['plotting']['dim_slot_priority'] - x_dim_priority: tuple[str, ...] = _DEFAULTS['plotting']['x_dim_priority'] + dim_priority: tuple[str, ...] = _DEFAULTS['plotting']['dim_priority'] class Carriers: """Default carrier definitions for common energy types. 
@@ -690,9 +685,7 @@ def to_dict(cls) -> dict: 'default_sequential_colorscale': cls.Plotting.default_sequential_colorscale, 'default_qualitative_colorscale': cls.Plotting.default_qualitative_colorscale, 'default_line_shape': cls.Plotting.default_line_shape, - 'extra_dim_priority': cls.Plotting.extra_dim_priority, - 'dim_slot_priority': cls.Plotting.dim_slot_priority, - 'x_dim_priority': cls.Plotting.x_dim_priority, + 'dim_priority': cls.Plotting.dim_priority, }, } diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 9377aed2e..746227e45 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -14,94 +14,86 @@ from .config import CONFIG -def _get_x_dim(dims: list[str], n_data_vars: int = 1, x: str | Literal['auto'] | None = 'auto') -> str: - """Select x-axis dim from priority list, or 'variable' for scalar data. - - Args: - dims: List of available dimensions. - n_data_vars: Number of data variables (for 'variable' availability). - x: Explicit x-axis choice or 'auto'. - """ - if x and x != 'auto': - return x - - # 'variable' is available when there are multiple data_vars - available = set(dims) - if n_data_vars > 1: - available.add('variable') - - # Check priority list first - for dim in CONFIG.Plotting.x_dim_priority: - if dim in available: - return dim - - # Fallback to first available dimension, or 'variable' for scalar data - return dims[0] if dims else 'variable' - - -def _resolve_auto_facets( +def assign_slots( ds: xr.Dataset, - color: str | Literal['auto'] | None, - facet_col: str | Literal['auto'] | None, - facet_row: str | Literal['auto'] | None, - animation_frame: str | Literal['auto'] | None = None, - exclude_dims: set[str] | None = None, -) -> tuple[str | None, str | None, str | None, str | None]: - """Assign 'auto' facet slots from available dims using CONFIG priority lists. 
+ *, + x: str | Literal['auto'] | None = 'auto', + color: str | Literal['auto'] | None = 'auto', + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', +) -> dict[str, str | None]: + """Assign dimensions to plot slots using CONFIG.Plotting.dim_priority. + + Slot fill order: x → color → facet_col → facet_row → animation_frame. + Dimensions are assigned in priority order from CONFIG.Plotting.dim_priority. + + Slot values: + - 'auto': auto-assign from available dims using priority + - None: skip this slot (not available for this plot type) + - str: use this specific dimension + + 'variable' is treated as a dimension when len(data_vars) > 1. It represents + the data_var names column in the melted DataFrame. - 'variable' is treated like a dimension - available when len(data_vars) > 1. - It exists in the melted DataFrame from data_var names, not in ds.dims. + Args: + ds: Dataset to analyze for available dimensions. + x: X-axis dimension. 'auto' assigns first available from priority. + color: Color grouping dimension. + facet_col: Column faceting dimension. + facet_row: Row faceting dimension. + animation_frame: Animation slider dimension. Returns: - Tuple of (color, facet_col, facet_row, animation_frame). + Dict with keys 'x', 'color', 'facet_col', 'facet_row', 'animation_frame' + and values being assigned dimension names (or None if slot skipped/unfilled). 
""" - # Get available extra dimensions with size > 1, excluding specified dims - exclude = exclude_dims or set() - available = {d for d in ds.dims if ds.sizes[d] > 1 and d not in exclude} - # 'variable' is available when there are multiple data_vars (and not excluded) - if len(ds.data_vars) > 1 and 'variable' not in exclude: + # Get available dimensions with size > 1 + available = {d for d in ds.dims if ds.sizes[d] > 1} + # 'variable' is available when there are multiple data_vars + if len(ds.data_vars) > 1: available.add('variable') - extra_dims = [d for d in CONFIG.Plotting.extra_dim_priority if d in available] - used: set[str] = set() - # Map slot names to their input values + # Get priority-ordered list of available dims + priority_dims = [d for d in CONFIG.Plotting.dim_priority if d in available] + # Add any available dims not in priority list (fallback) + priority_dims.extend(d for d in available if d not in priority_dims) + + # Slot specification in fill order slots = { + 'x': x, 'color': color, 'facet_col': facet_col, 'facet_row': facet_row, 'animation_frame': animation_frame, } - results: dict[str, str | None] = { - 'color': None, - 'facet_col': None, - 'facet_row': None, - 'animation_frame': None, - } + # Fixed fill order for 'auto' assignment + slot_order = ('x', 'color', 'facet_col', 'facet_row', 'animation_frame') + + results: dict[str, str | None] = {k: None for k in slot_order} + used: set[str] = set() # First pass: resolve explicit dimensions (not 'auto' or None) to mark them as used for slot_name, value in slots.items(): if value is not None and value != 'auto': - if value in available and value not in used: - used.add(value) - results[slot_name] = value - - # Second pass: resolve 'auto' slots in dim_slot_priority order - dim_iter = iter(d for d in extra_dims if d not in used) - for slot_name in CONFIG.Plotting.dim_slot_priority: - if slots.get(slot_name) == 'auto': + used.add(value) + results[slot_name] = value + + # Second pass: resolve 'auto' 
slots in fixed fill order + dim_iter = iter(d for d in priority_dims if d not in used) + for slot_name in slot_order: + if slots[slot_name] == 'auto': next_dim = next(dim_iter, None) if next_dim: used.add(next_dim) results[slot_name] = next_dim # Warn if any dimensions were not assigned to any slot - # Only count slots that were available (passed as 'auto' or explicit dim, not None) - available_slot_count = sum(1 for v in slots.values() if v is not None) unassigned = available - used if unassigned: - if available_slot_count < 4: - # Some slots weren't available (e.g., pie doesn't support animation_frame) - unavailable_slots = [k for k, v in slots.items() if v is None] + available_slots = [k for k, v in slots.items() if v is not None] + unavailable_slots = [k for k, v in slots.items() if v is None] + if unavailable_slots: warnings.warn( f'Dimensions {unassigned} not assigned to any plot dimension. ' f'Not available for this plot type: {unavailable_slots}. ' @@ -110,12 +102,12 @@ def _resolve_auto_facets( ) else: warnings.warn( - f'Dimensions {unassigned} not assigned to color/facet/animation. ' + f'Dimensions {unassigned} not assigned to any plot dimension ({available_slots}). ' f'Reduce dimensions before plotting (e.g., .sel(), .isel(), .mean()).', stacklevel=3, ) - return results['color'], results['facet_col'], results['facet_row'], results['animation_frame'] + return results def _dataset_to_long_df(ds: xr.Dataset, value_name: str = 'value', var_name: str = 'variable') -> pd.DataFrame: @@ -184,7 +176,7 @@ def bar( """Create a grouped bar chart from the dataset. Args: - x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.dim_priority. color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). @@ -200,11 +192,8 @@ def bar( Returns: Plotly Figure. 
""" - # Determine x-axis first, then resolve facets from remaining dims - dims = list(self._ds.dims) - x_col = _get_x_dim(dims, len(self._ds.data_vars), x) - actual_color, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, color, facet_col, facet_row, animation_frame, exclude_dims={x_col} + slots = assign_slots( + self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) df = _dataset_to_long_df(self._ds) @@ -212,7 +201,7 @@ def bar( return go.Figure() # Get color labels from the resolved color column - color_labels = df[actual_color].unique().tolist() if actual_color and actual_color in df.columns else [] + color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] color_map = process_colors( colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale ) @@ -220,27 +209,27 @@ def bar( facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { 'data_frame': df, - 'x': x_col, + 'x': slots['x'], 'y': 'value', 'title': title, 'barmode': 'group', } - if actual_color and 'color' not in px_kwargs: - fig_kwargs['color'] = actual_color + if slots['color'] and 'color' not in px_kwargs: + fig_kwargs['color'] = slots['color'] fig_kwargs['color_discrete_map'] = color_map - if xlabel: - fig_kwargs['labels'] = {x_col: xlabel} + if xlabel and slots['x']: + fig_kwargs['labels'] = {slots['x']: xlabel} if ylabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - if actual_facet_col and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + if slots['facet_col'] and 'facet_col' not in px_kwargs: + fig_kwargs['facet_col'] = slots['facet_col'] + if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): fig_kwargs['facet_col_wrap'] = facet_col_wrap 
- if actual_facet_row and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = actual_facet_row - if actual_anim and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = actual_anim + if slots['facet_row'] and 'facet_row' not in px_kwargs: + fig_kwargs['facet_row'] = slots['facet_row'] + if slots['animation_frame'] and 'animation_frame' not in px_kwargs: + fig_kwargs['animation_frame'] = slots['animation_frame'] return px.bar(**{**fig_kwargs, **px_kwargs}) @@ -265,7 +254,7 @@ def stacked_bar( values are stacked separately. Args: - x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.dim_priority. color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). @@ -281,11 +270,8 @@ def stacked_bar( Returns: Plotly Figure. """ - # Determine x-axis first, then resolve facets from remaining dims - dims = list(self._ds.dims) - x_col = _get_x_dim(dims, len(self._ds.data_vars), x) - actual_color, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, color, facet_col, facet_row, animation_frame, exclude_dims={x_col} + slots = assign_slots( + self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) df = _dataset_to_long_df(self._ds) @@ -293,7 +279,7 @@ def stacked_bar( return go.Figure() # Get color labels from the resolved color column - color_labels = df[actual_color].unique().tolist() if actual_color and actual_color in df.columns else [] + color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] color_map = process_colors( colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale ) @@ -301,26 +287,26 @@ def stacked_bar( facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols 
fig_kwargs: dict[str, Any] = { 'data_frame': df, - 'x': x_col, + 'x': slots['x'], 'y': 'value', 'title': title, } - if actual_color and 'color' not in px_kwargs: - fig_kwargs['color'] = actual_color + if slots['color'] and 'color' not in px_kwargs: + fig_kwargs['color'] = slots['color'] fig_kwargs['color_discrete_map'] = color_map - if xlabel: - fig_kwargs['labels'] = {x_col: xlabel} + if xlabel and slots['x']: + fig_kwargs['labels'] = {slots['x']: xlabel} if ylabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - if actual_facet_col and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + if slots['facet_col'] and 'facet_col' not in px_kwargs: + fig_kwargs['facet_col'] = slots['facet_col'] + if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = actual_facet_row - if actual_anim and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = actual_anim + if slots['facet_row'] and 'facet_row' not in px_kwargs: + fig_kwargs['facet_row'] = slots['facet_row'] + if slots['animation_frame'] and 'animation_frame' not in px_kwargs: + fig_kwargs['animation_frame'] = slots['animation_frame'] fig = px.bar(**{**fig_kwargs, **px_kwargs}) fig.update_layout(barmode='relative', bargap=0, bargroupgap=0) @@ -348,7 +334,7 @@ def line( Each variable in the dataset becomes a separate line. Args: - x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.dim_priority. color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). @@ -366,11 +352,8 @@ def line( Returns: Plotly Figure. 
""" - # Determine x-axis first, then resolve facets from remaining dims - dims = list(self._ds.dims) - x_col = _get_x_dim(dims, len(self._ds.data_vars), x) - actual_color, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, color, facet_col, facet_row, animation_frame, exclude_dims={x_col} + slots = assign_slots( + self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) df = _dataset_to_long_df(self._ds) @@ -378,7 +361,7 @@ def line( return go.Figure() # Get color labels from the resolved color column - color_labels = df[actual_color].unique().tolist() if actual_color and actual_color in df.columns else [] + color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] color_map = process_colors( colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale ) @@ -386,27 +369,27 @@ def line( facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { 'data_frame': df, - 'x': x_col, + 'x': slots['x'], 'y': 'value', 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, } - if actual_color and 'color' not in px_kwargs: - fig_kwargs['color'] = actual_color + if slots['color'] and 'color' not in px_kwargs: + fig_kwargs['color'] = slots['color'] fig_kwargs['color_discrete_map'] = color_map - if xlabel: - fig_kwargs['labels'] = {x_col: xlabel} + if xlabel and slots['x']: + fig_kwargs['labels'] = {slots['x']: xlabel} if ylabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - if actual_facet_col and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + if slots['facet_col'] and 'facet_col' not in px_kwargs: + fig_kwargs['facet_col'] = slots['facet_col'] + if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): 
fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = actual_facet_row - if actual_anim and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = actual_anim + if slots['facet_row'] and 'facet_row' not in px_kwargs: + fig_kwargs['facet_row'] = slots['facet_row'] + if slots['animation_frame'] and 'animation_frame' not in px_kwargs: + fig_kwargs['animation_frame'] = slots['animation_frame'] return px.line(**{**fig_kwargs, **px_kwargs}) @@ -429,7 +412,7 @@ def area( """Create a stacked area chart from the dataset. Args: - x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.dim_priority. color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). @@ -446,11 +429,8 @@ def area( Returns: Plotly Figure. 
""" - # Determine x-axis first, then resolve facets from remaining dims - dims = list(self._ds.dims) - x_col = _get_x_dim(dims, len(self._ds.data_vars), x) - actual_color, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, color, facet_col, facet_row, animation_frame, exclude_dims={x_col} + slots = assign_slots( + self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) df = _dataset_to_long_df(self._ds) @@ -458,7 +438,7 @@ def area( return go.Figure() # Get color labels from the resolved color column - color_labels = df[actual_color].unique().tolist() if actual_color and actual_color in df.columns else [] + color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] color_map = process_colors( colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale ) @@ -466,27 +446,27 @@ def area( facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { 'data_frame': df, - 'x': x_col, + 'x': slots['x'], 'y': 'value', 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, } - if actual_color and 'color' not in px_kwargs: - fig_kwargs['color'] = actual_color + if slots['color'] and 'color' not in px_kwargs: + fig_kwargs['color'] = slots['color'] fig_kwargs['color_discrete_map'] = color_map - if xlabel: - fig_kwargs['labels'] = {x_col: xlabel} + if xlabel and slots['x']: + fig_kwargs['labels'] = {slots['x']: xlabel} if ylabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - if actual_facet_col and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + if slots['facet_col'] and 'facet_col' not in px_kwargs: + fig_kwargs['facet_col'] = slots['facet_col'] + if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): 
fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = actual_facet_row - if actual_anim and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = actual_anim + if slots['facet_row'] and 'facet_row' not in px_kwargs: + fig_kwargs['facet_row'] = slots['facet_row'] + if slots['animation_frame'] and 'animation_frame' not in px_kwargs: + fig_kwargs['animation_frame'] = slots['animation_frame'] return px.area(**{**fig_kwargs, **px_kwargs}) @@ -537,7 +517,10 @@ def heatmap( colors = colors or CONFIG.Plotting.default_sequential_colorscale facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - _, actual_facet_col, _, actual_anim = _resolve_auto_facets(self._ds, None, facet_col, None, animation_frame) + # Heatmap uses imshow - x/y come from array axes, color is continuous + slots = assign_slots( + self._ds, x=None, color=None, facet_col=facet_col, facet_row=None, animation_frame=animation_frame + ) imshow_args: dict[str, Any] = { 'img': da, @@ -545,13 +528,13 @@ def heatmap( 'title': title or variable, } - if actual_facet_col and actual_facet_col in da.dims: - imshow_args['facet_col'] = actual_facet_col - if facet_col_wrap < da.sizes[actual_facet_col]: + if slots['facet_col'] and slots['facet_col'] in da.dims: + imshow_args['facet_col'] = slots['facet_col'] + if facet_col_wrap < da.sizes[slots['facet_col']]: imshow_args['facet_col_wrap'] = facet_col_wrap - if actual_anim and actual_anim in da.dims: - imshow_args['animation_frame'] = actual_anim + if slots['animation_frame'] and slots['animation_frame'] in da.dims: + imshow_args['animation_frame'] = slots['animation_frame'] return px.imshow(**{**imshow_args, **imshow_kwargs}) @@ -595,8 +578,9 @@ def scatter( if df.empty: return go.Figure() - _, actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, None, facet_col, facet_row, animation_frame + # Scatter uses explicit x/y variable names, not 
dimensions + slots = assign_slots( + self._ds, x=None, color=None, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols @@ -614,14 +598,14 @@ def scatter( # Only use facets if the column actually exists in the dataframe # (scatter uses wide format, so 'variable' column doesn't exist) - if actual_facet_col and actual_facet_col in df.columns: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + if slots['facet_col'] and slots['facet_col'] in df.columns: + fig_kwargs['facet_col'] = slots['facet_col'] + if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row and actual_facet_row in df.columns: - fig_kwargs['facet_row'] = actual_facet_row - if actual_anim and actual_anim in df.columns: - fig_kwargs['animation_frame'] = actual_anim + if slots['facet_row'] and slots['facet_row'] in df.columns: + fig_kwargs['facet_row'] = slots['facet_row'] + if slots['animation_frame'] and slots['animation_frame'] in df.columns: + fig_kwargs['animation_frame'] = slots['animation_frame'] return px.scatter(**fig_kwargs) @@ -682,8 +666,10 @@ def pie( if df.empty: return go.Figure() - # Note: px.pie doesn't support animation_frame - actual_facet_col, actual_facet_row, _ = _resolve_auto_facets(self._ds, facet_col, facet_row, None) + # Pie uses 'variable' for names and 'value' for values, no x/color/animation_frame + slots = assign_slots( + self._ds, x=None, color=None, facet_col=facet_col, facet_row=facet_row, animation_frame=None + ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols fig_kwargs: dict[str, Any] = { @@ -696,12 +682,12 @@ def pie( **px_kwargs, } - if actual_facet_col: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + if 
slots['facet_col']: + fig_kwargs['facet_col'] = slots['facet_col'] + if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row: - fig_kwargs['facet_row'] = actual_facet_row + if slots['facet_row']: + fig_kwargs['facet_row'] = slots['facet_row'] return px.pie(**fig_kwargs) @@ -940,10 +926,10 @@ def heatmap( colors = colors or CONFIG.Plotting.default_sequential_colorscale facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - # Use Dataset for facet resolution + # Heatmap uses imshow - x/y come from array axes, color is continuous ds_for_resolution = da.to_dataset(name='_temp') - _, actual_facet_col, _, actual_anim = _resolve_auto_facets( - ds_for_resolution, None, facet_col, None, animation_frame + slots = assign_slots( + ds_for_resolution, x=None, color=None, facet_col=facet_col, facet_row=None, animation_frame=animation_frame ) imshow_args: dict[str, Any] = { @@ -952,12 +938,12 @@ def heatmap( 'title': title or (da.name if da.name else ''), } - if actual_facet_col and actual_facet_col in da.dims: - imshow_args['facet_col'] = actual_facet_col - if facet_col_wrap < da.sizes[actual_facet_col]: + if slots['facet_col'] and slots['facet_col'] in da.dims: + imshow_args['facet_col'] = slots['facet_col'] + if facet_col_wrap < da.sizes[slots['facet_col']]: imshow_args['facet_col_wrap'] = facet_col_wrap - if actual_anim and actual_anim in da.dims: - imshow_args['animation_frame'] = actual_anim + if slots['animation_frame'] and slots['animation_frame'] in da.dims: + imshow_args['animation_frame'] = slots['animation_frame'] return px.imshow(**{**imshow_args, **imshow_kwargs}) diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index 382ed1bf0..dc43287fc 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -31,6 +31,7 @@ from .color_processing import ColorType, hex_to_rgba, process_colors from .config import CONFIG +from 
.dataset_plot_accessor import assign_slots from .plot_result import PlotResult if TYPE_CHECKING: @@ -188,9 +189,7 @@ def _resolve_auto_facets( ) -> tuple[str | None, str | None, str | None]: """Resolve 'auto' facet/animation dimensions based on available data dimensions. - When 'auto' is specified, extra dimensions are assigned to slots based on: - - CONFIG.Plotting.extra_dim_priority: Order of dimensions to assign. - - CONFIG.Plotting.dim_slot_priority: Order of slots to fill. + Uses assign_slots with x=None and color=None to only resolve facet/animation slots. Args: ds: Dataset to check for available dimensions. @@ -202,36 +201,10 @@ def _resolve_auto_facets( Tuple of (resolved_facet_col, resolved_facet_row, resolved_animation_frame). Each is either a valid dimension name or None. """ - # Get available extra dimensions with size > 1, sorted by priority - available = {d for d in ds.dims if ds.sizes[d] > 1} - extra_dims = [d for d in CONFIG.Plotting.extra_dim_priority if d in available] - used: set[str] = set() - - # Map slot names to their input values - slots = { - 'facet_col': facet_col, - 'facet_row': facet_row, - 'animation_frame': animation_frame, - } - results: dict[str, str | None] = {'facet_col': None, 'facet_row': None, 'animation_frame': None} - - # First pass: resolve explicit dimensions (not 'auto' or None) to mark them as used - for slot_name, value in slots.items(): - if value is not None and value != 'auto': - if value in available and value not in used: - used.add(value) - results[slot_name] = value - - # Second pass: resolve 'auto' slots in dim_slot_priority order - dim_iter = iter(d for d in extra_dims if d not in used) - for slot_name in CONFIG.Plotting.dim_slot_priority: - if slots.get(slot_name) == 'auto': - next_dim = next(dim_iter, None) - if next_dim: - used.add(next_dim) - results[slot_name] = next_dim - - return results['facet_col'], results['facet_row'], results['animation_frame'] + slots = assign_slots( + ds, x=None, color=None, 
facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame + ) + return slots['facet_col'], slots['facet_row'], slots['animation_frame'] def _resolve_facets( From 947ccd9615f78c70cd38f336f9d715a84278c028 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 12:07:33 +0100 Subject: [PATCH 50/62] Add slot_order to config --- flixopt/config.py | 13 ++++++++++--- flixopt/dataset_plot_accessor.py | 8 ++++---- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/flixopt/config.py b/flixopt/config.py index fce943eb1..3bc3d5ebf 100644 --- a/flixopt/config.py +++ b/flixopt/config.py @@ -165,6 +165,7 @@ def format(self, record): 'default_qualitative_colorscale': 'plotly', 'default_line_shape': 'hv', 'dim_priority': ('time', 'duration', 'duration_pct', 'variable', 'cluster', 'period', 'scenario'), + 'slot_priority': ('x', 'color', 'facet_col', 'facet_row', 'animation_frame'), } ), 'solving': MappingProxyType( @@ -560,9 +561,10 @@ class Plotting: default_facet_cols: Default number of columns for faceted plots. default_sequential_colorscale: Default colorscale for heatmaps and continuous data. default_qualitative_colorscale: Default colormap for categorical plots (bar/line/area charts). - dim_priority: Priority order for assigning dimensions to plot slots (x, color, facet, etc.). - Dimensions are assigned to slots in order: x → y → color → facet_col → facet_row → animation_frame. - 'value' represents the y-axis values (from data_var names after melting). + dim_priority: Priority order for assigning dimensions to plot slots. + Dimensions are assigned to slots based on this order. + slot_priority: Order in which slots are filled during auto-assignment. + Default: x → color → facet_col → facet_row → animation_frame. 
Examples: ```python @@ -573,6 +575,9 @@ class Plotting: # Customize dimension priority for auto-assignment CONFIG.Plotting.dim_priority = ('time', 'scenario', 'variable', 'period', 'cluster') + + # Change slot fill order (e.g., prioritize facets over color) + CONFIG.Plotting.slot_priority = ('x', 'facet_col', 'facet_row', 'color', 'animation_frame') ``` """ @@ -584,6 +589,7 @@ class Plotting: default_qualitative_colorscale: str = _DEFAULTS['plotting']['default_qualitative_colorscale'] default_line_shape: str = _DEFAULTS['plotting']['default_line_shape'] dim_priority: tuple[str, ...] = _DEFAULTS['plotting']['dim_priority'] + slot_priority: tuple[str, ...] = _DEFAULTS['plotting']['slot_priority'] class Carriers: """Default carrier definitions for common energy types. @@ -686,6 +692,7 @@ def to_dict(cls) -> dict: 'default_qualitative_colorscale': cls.Plotting.default_qualitative_colorscale, 'default_line_shape': cls.Plotting.default_line_shape, 'dim_priority': cls.Plotting.dim_priority, + 'slot_priority': cls.Plotting.slot_priority, }, } diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 746227e45..ee3e82399 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -59,7 +59,7 @@ def assign_slots( # Add any available dims not in priority list (fallback) priority_dims.extend(d for d in available if d not in priority_dims) - # Slot specification in fill order + # Slot specification slots = { 'x': x, 'color': color, @@ -67,8 +67,8 @@ def assign_slots( 'facet_row': facet_row, 'animation_frame': animation_frame, } - # Fixed fill order for 'auto' assignment - slot_order = ('x', 'color', 'facet_col', 'facet_row', 'animation_frame') + # Slot fill order from config + slot_order = CONFIG.Plotting.slot_priority results: dict[str, str | None] = {k: None for k in slot_order} used: set[str] = set() @@ -79,7 +79,7 @@ def assign_slots( used.add(value) results[slot_name] = value - # Second pass: resolve 'auto' slots in 
fixed fill order + # Second pass: resolve 'auto' slots in config-defined fill order dim_iter = iter(d for d in priority_dims if d not in used) for slot_name in slot_order: if slots[slot_name] == 'auto': From b1336f6de867d8a22d780f0a4d353bfbc774fbab Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 12:20:21 +0100 Subject: [PATCH 51/62] Add new assign_slots() method --- flixopt/dataset_plot_accessor.py | 162 +++++++++------------- flixopt/statistics_accessor.py | 224 +++++++++++++++++-------------- 2 files changed, 184 insertions(+), 202 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index ee3e82399..9afce67dd 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -22,11 +22,11 @@ def assign_slots( facet_col: str | Literal['auto'] | None = 'auto', facet_row: str | Literal['auto'] | None = 'auto', animation_frame: str | Literal['auto'] | None = 'auto', + exclude_dims: set[str] | None = None, ) -> dict[str, str | None]: """Assign dimensions to plot slots using CONFIG.Plotting.dim_priority. - Slot fill order: x → color → facet_col → facet_row → animation_frame. - Dimensions are assigned in priority order from CONFIG.Plotting.dim_priority. + Dimensions are assigned in priority order to slots based on CONFIG.Plotting.slot_priority. Slot values: - 'auto': auto-assign from available dims using priority @@ -43,15 +43,17 @@ def assign_slots( facet_col: Column faceting dimension. facet_row: Row faceting dimension. animation_frame: Animation slider dimension. + exclude_dims: Dimensions to exclude from auto-assignment (e.g., already used for x elsewhere). Returns: Dict with keys 'x', 'color', 'facet_col', 'facet_row', 'animation_frame' and values being assigned dimension names (or None if slot skipped/unfilled). 
""" - # Get available dimensions with size > 1 - available = {d for d in ds.dims if ds.sizes[d] > 1} - # 'variable' is available when there are multiple data_vars - if len(ds.data_vars) > 1: + # Get available dimensions with size > 1, excluding specified dims + exclude = exclude_dims or set() + available = {d for d in ds.dims if ds.sizes[d] > 1 and d not in exclude} + # 'variable' is available when there are multiple data_vars (and not excluded) + if len(ds.data_vars) > 1 and 'variable' not in exclude: available.add('variable') # Get priority-ordered list of available dims @@ -110,6 +112,34 @@ def assign_slots( return results +def _build_fig_kwargs( + slots: dict[str, str | None], + ds_sizes: dict[str, int], + px_kwargs: dict[str, Any], + facet_cols: int | None = None, +) -> dict[str, Any]: + """Build plotly express kwargs from slot assignments. + + Adds facet/animation args only if slots are assigned and not overridden in px_kwargs. + Handles facet_col_wrap based on dimension size. + """ + facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols + result: dict[str, Any] = {} + + # Add facet/animation kwargs from slots (skip if None or already in px_kwargs) + for slot in ('color', 'facet_col', 'facet_row', 'animation_frame'): + if slots.get(slot) and slot not in px_kwargs: + result[slot] = slots[slot] + + # Add facet_col_wrap if facet_col is set and dimension is large enough + if result.get('facet_col'): + dim_size = ds_sizes.get(result['facet_col'], facet_col_wrap + 1) + if facet_col_wrap < dim_size: + result['facet_col_wrap'] = facet_col_wrap + + return result + + def _dataset_to_long_df(ds: xr.Dataset, value_name: str = 'value', var_name: str = 'variable') -> pd.DataFrame: """Convert Dataset to long-form DataFrame for Plotly Express.""" if not ds.data_vars: @@ -195,42 +225,24 @@ def bar( slots = assign_slots( self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) - df = _dataset_to_long_df(self._ds) if 
df.empty: return go.Figure() - # Get color labels from the resolved color column color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] - color_map = process_colors( - colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale - ) + color_map = process_colors(colors, color_labels, CONFIG.Plotting.default_qualitative_colorscale) - facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - fig_kwargs: dict[str, Any] = { + labels = {**(({slots['x']: xlabel}) if xlabel and slots['x'] else {}), **({'value': ylabel} if ylabel else {})} + fig_kwargs = { 'data_frame': df, 'x': slots['x'], 'y': 'value', 'title': title, 'barmode': 'group', + 'color_discrete_map': color_map, + **({'labels': labels} if labels else {}), + **_build_fig_kwargs(slots, dict(self._ds.sizes), px_kwargs, facet_cols), } - if slots['color'] and 'color' not in px_kwargs: - fig_kwargs['color'] = slots['color'] - fig_kwargs['color_discrete_map'] = color_map - if xlabel and slots['x']: - fig_kwargs['labels'] = {slots['x']: xlabel} - if ylabel: - fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - - if slots['facet_col'] and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = slots['facet_col'] - if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): - fig_kwargs['facet_col_wrap'] = facet_col_wrap - if slots['facet_row'] and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = slots['facet_row'] - if slots['animation_frame'] and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = slots['animation_frame'] - return px.bar(**{**fig_kwargs, **px_kwargs}) def stacked_bar( @@ -273,41 +285,23 @@ def stacked_bar( slots = assign_slots( self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) - df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - # Get color labels from the resolved color 
column color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] - color_map = process_colors( - colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale - ) + color_map = process_colors(colors, color_labels, CONFIG.Plotting.default_qualitative_colorscale) - facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - fig_kwargs: dict[str, Any] = { + labels = {**(({slots['x']: xlabel}) if xlabel and slots['x'] else {}), **({'value': ylabel} if ylabel else {})} + fig_kwargs = { 'data_frame': df, 'x': slots['x'], 'y': 'value', 'title': title, + 'color_discrete_map': color_map, + **({'labels': labels} if labels else {}), + **_build_fig_kwargs(slots, dict(self._ds.sizes), px_kwargs, facet_cols), } - if slots['color'] and 'color' not in px_kwargs: - fig_kwargs['color'] = slots['color'] - fig_kwargs['color_discrete_map'] = color_map - if xlabel and slots['x']: - fig_kwargs['labels'] = {slots['x']: xlabel} - if ylabel: - fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - - if slots['facet_col'] and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = slots['facet_col'] - if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): - fig_kwargs['facet_col_wrap'] = facet_col_wrap - if slots['facet_row'] and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = slots['facet_row'] - if slots['animation_frame'] and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = slots['animation_frame'] - fig = px.bar(**{**fig_kwargs, **px_kwargs}) fig.update_layout(barmode='relative', bargap=0, bargroupgap=0) fig.update_traces(marker_line_width=0) @@ -355,42 +349,24 @@ def line( slots = assign_slots( self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) - df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - # Get color labels from the resolved color column 
color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] - color_map = process_colors( - colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale - ) + color_map = process_colors(colors, color_labels, CONFIG.Plotting.default_qualitative_colorscale) - facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - fig_kwargs: dict[str, Any] = { + labels = {**(({slots['x']: xlabel}) if xlabel and slots['x'] else {}), **({'value': ylabel} if ylabel else {})} + fig_kwargs = { 'data_frame': df, 'x': slots['x'], 'y': 'value', 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, + 'color_discrete_map': color_map, + **({'labels': labels} if labels else {}), + **_build_fig_kwargs(slots, dict(self._ds.sizes), px_kwargs, facet_cols), } - if slots['color'] and 'color' not in px_kwargs: - fig_kwargs['color'] = slots['color'] - fig_kwargs['color_discrete_map'] = color_map - if xlabel and slots['x']: - fig_kwargs['labels'] = {slots['x']: xlabel} - if ylabel: - fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - - if slots['facet_col'] and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = slots['facet_col'] - if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): - fig_kwargs['facet_col_wrap'] = facet_col_wrap - if slots['facet_row'] and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = slots['facet_row'] - if slots['animation_frame'] and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = slots['animation_frame'] - return px.line(**{**fig_kwargs, **px_kwargs}) def area( @@ -432,42 +408,24 @@ def area( slots = assign_slots( self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) - df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - # Get color labels from the resolved color column color_labels = 
df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] - color_map = process_colors( - colors, color_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale - ) + color_map = process_colors(colors, color_labels, CONFIG.Plotting.default_qualitative_colorscale) - facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - fig_kwargs: dict[str, Any] = { + labels = {**(({slots['x']: xlabel}) if xlabel and slots['x'] else {}), **({'value': ylabel} if ylabel else {})} + fig_kwargs = { 'data_frame': df, 'x': slots['x'], 'y': 'value', 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, + 'color_discrete_map': color_map, + **({'labels': labels} if labels else {}), + **_build_fig_kwargs(slots, dict(self._ds.sizes), px_kwargs, facet_cols), } - if slots['color'] and 'color' not in px_kwargs: - fig_kwargs['color'] = slots['color'] - fig_kwargs['color_discrete_map'] = color_map - if xlabel and slots['x']: - fig_kwargs['labels'] = {slots['x']: xlabel} - if ylabel: - fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - - if slots['facet_col'] and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = slots['facet_col'] - if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): - fig_kwargs['facet_col_wrap'] = facet_col_wrap - if slots['facet_row'] and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = slots['facet_row'] - if slots['animation_frame'] and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = slots['animation_frame'] - return px.area(**{**fig_kwargs, **px_kwargs}) def heatmap( diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index dc43287fc..da6a859f9 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -181,43 +181,8 @@ def _filter_by_carrier(ds: xr.Dataset, carrier: str | list[str] | None) -> xr.Da return ds[matching_vars] if matching_vars else 
xr.Dataset() -def _resolve_auto_facets( - ds: xr.Dataset, - facet_col: str | Literal['auto'] | None, - facet_row: str | Literal['auto'] | None, - animation_frame: str | Literal['auto'] | None = None, -) -> tuple[str | None, str | None, str | None]: - """Resolve 'auto' facet/animation dimensions based on available data dimensions. - - Uses assign_slots with x=None and color=None to only resolve facet/animation slots. - - Args: - ds: Dataset to check for available dimensions. - facet_col: Dimension name, 'auto', or None. - facet_row: Dimension name, 'auto', or None. - animation_frame: Dimension name, 'auto', or None. - - Returns: - Tuple of (resolved_facet_col, resolved_facet_row, resolved_animation_frame). - Each is either a valid dimension name or None. - """ - slots = assign_slots( - ds, x=None, color=None, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame - ) - return slots['facet_col'], slots['facet_row'], slots['animation_frame'] - - -def _resolve_facets( - ds: xr.Dataset, - facet_col: str | Literal['auto'] | None, - facet_row: str | Literal['auto'] | None, -) -> tuple[str | None, str | None]: - """Resolve facet dimensions, returning None if not present in data. - - Legacy wrapper for _resolve_auto_facets for backward compatibility. 
- """ - resolved_col, resolved_row, _ = _resolve_auto_facets(ds, facet_col, facet_row, None) - return resolved_col, resolved_row +# Default dimensions to exclude from facet auto-assignment (typically x-axis dimensions) +_FACET_EXCLUDE_DIMS = {'time', 'duration', 'duration_pct'} def _dataset_to_long_df(ds: xr.Dataset, value_name: str = 'value', var_name: str = 'variable') -> pd.DataFrame: @@ -1355,8 +1320,14 @@ def balance( ds[label] = -ds[label] ds = _apply_selection(ds, select) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame + slots = assign_slots( + ds, + x=None, + color=None, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, ) # Build color map from Element.color attributes if no colors specified @@ -1372,9 +1343,9 @@ def balance( fig = ds.fxplot.stacked_bar( colors=colors, title=f'{node} [{unit_label}]' if unit_label else node, - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + animation_frame=slots['animation_frame'], **plotly_kwargs, ) @@ -1466,8 +1437,14 @@ def carrier_balance( ds[label] = -ds[label] ds = _apply_selection(ds, select) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame + slots = assign_slots( + ds, + x=None, + color=None, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, ) # Use cached component colors for flows @@ -1496,9 +1473,9 @@ def carrier_balance( fig = ds.fxplot.stacked_bar( colors=colors, title=f'{carrier.capitalize()} Balance [{unit_label}]' if unit_label else f'{carrier.capitalize()} Balance', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + 
animation_frame=slots['animation_frame'], **plotly_kwargs, ) @@ -1570,17 +1547,22 @@ def heatmap( # Determine facet and animation from available dims has_multiple_vars = 'variable' in da.dims and da.sizes['variable'] > 1 - if has_multiple_vars: - actual_facet = 'variable' - # Resolve animation using auto logic, excluding 'variable' which is used for facet - _, _, actual_animation = _resolve_auto_facets(da.to_dataset(name='value'), None, None, animation_frame) - if actual_animation == 'variable': - actual_animation = None - else: - # Resolve facet and animation using auto logic - actual_facet, _, actual_animation = _resolve_auto_facets( - da.to_dataset(name='value'), facet_col, None, animation_frame - ) + # Get slot assignments (heatmap only uses facet_col and animation_frame) + slots = assign_slots( + da.to_dataset(name='value'), + x=None, + color=None, + facet_col='variable' if has_multiple_vars else facet_col, + facet_row=None, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, + ) + resolved_facet = slots['facet_col'] + resolved_anim = slots['animation_frame'] + + # Don't use 'variable' for animation if it's used for facet + if resolved_anim == 'variable' and has_multiple_vars: + resolved_anim = None # Determine heatmap dimensions based on data structure if is_clustered and (reshape == 'auto' or reshape is None): @@ -1588,21 +1570,27 @@ def heatmap( heatmap_dims = ['time', 'cluster'] elif reshape and reshape != 'auto' and 'time' in da.dims: # Non-clustered with explicit reshape: reshape time to (day, hour) etc. 
- # Extra dims will be handled via facet/animation or dropped da = _reshape_time_for_heatmap(da, reshape) heatmap_dims = ['timestep', 'timeframe'] elif reshape == 'auto' and 'time' in da.dims and not is_clustered: # Auto mode for non-clustered: use default ('D', 'h') reshape - # Extra dims will be handled via facet/animation or dropped da = _reshape_time_for_heatmap(da, ('D', 'h')) heatmap_dims = ['timestep', 'timeframe'] elif has_multiple_vars: # Can't reshape but have multiple vars: use variable + time as heatmap axes heatmap_dims = ['variable', 'time'] - # variable is now a heatmap dim, use period/scenario for facet/animation - actual_facet, _, actual_animation = _resolve_auto_facets( - da.to_dataset(name='value'), facet_col, None, animation_frame + # variable is now a heatmap dim, reassign facet + slots = assign_slots( + da.to_dataset(name='value'), + x=None, + color=None, + facet_col=facet_col, + facet_row=None, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS | {'variable'}, ) + resolved_facet = slots['facet_col'] + resolved_anim = slots['animation_frame'] else: # Fallback: use first two available dimensions available_dims = [d for d in da.dims if da.sizes[d] > 1] @@ -1614,12 +1602,12 @@ def heatmap( heatmap_dims = list(da.dims)[:1] # Keep only dims we need - keep_dims = set(heatmap_dims) | {d for d in [actual_facet, actual_animation] if d is not None} + keep_dims = set(heatmap_dims) | {d for d in [resolved_facet, resolved_anim] if d is not None} for dim in [d for d in da.dims if d not in keep_dims]: da = da.isel({dim: 0}, drop=True) if da.sizes[dim] > 1 else da.squeeze(dim, drop=True) # Transpose to expected order - dim_order = heatmap_dims + [d for d in [actual_facet, actual_animation] if d] + dim_order = heatmap_dims + [d for d in [resolved_facet, resolved_anim] if d] da = da.transpose(*dim_order) # Clear name for multiple variables (colorbar would show first var's name) @@ -1628,8 +1616,8 @@ def heatmap( fig = da.fxplot.heatmap( 
colors=colors, - facet_col=actual_facet, - animation_frame=actual_animation, + facet_col=resolved_facet, + animation_frame=resolved_anim, **plotly_kwargs, ) @@ -1710,8 +1698,14 @@ def flows( ds = ds[[lbl for lbl in matching_labels if lbl in ds]] ds = _apply_selection(ds, select) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame + slots = assign_slots( + ds, + x=None, + color=None, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, ) # Get unit label from first data variable's attributes @@ -1723,9 +1717,9 @@ def flows( fig = ds.fxplot.line( colors=colors, title=f'Flows [{unit_label}]' if unit_label else 'Flows', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + animation_frame=slots['animation_frame'], **plotly_kwargs, ) @@ -1771,8 +1765,14 @@ def sizes( valid_labels = [lbl for lbl in ds.data_vars if float(ds[lbl].max()) < max_size] ds = ds[valid_labels] - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame + slots = assign_slots( + ds, + x=None, + color=None, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, ) df = _dataset_to_long_df(ds) @@ -1786,9 +1786,9 @@ def sizes( x='variable', y='value', color='variable', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + animation_frame=slots['animation_frame'], color_discrete_map=color_map, title='Investment Sizes', labels={'variable': 'Flow', 'value': 'Size'}, @@ -1886,8 +1886,14 @@ def sort_descending(arr: np.ndarray) -> np.ndarray: duration_coord = np.linspace(0, 100, n_timesteps) if normalize else np.arange(n_timesteps) result_ds = 
result_ds.assign_coords({duration_name: duration_coord}) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - result_ds, facet_col, facet_row, animation_frame + slots = assign_slots( + result_ds, + x=None, + color=None, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, ) # Get unit label from first data variable's attributes @@ -1899,9 +1905,9 @@ def sort_descending(arr: np.ndarray) -> np.ndarray: fig = result_ds.fxplot.line( colors=colors, title=f'Duration Curve [{unit_label}]' if unit_label else 'Duration Curve', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + animation_frame=slots['animation_frame'], **plotly_kwargs, ) @@ -2031,8 +2037,14 @@ def effects( raise ValueError(f"'by' must be one of 'component', 'contributor', 'time', or None, got {by!r}") # Resolve facets - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - combined.to_dataset(name='value'), facet_col, facet_row, animation_frame + slots = assign_slots( + combined.to_dataset(name='value'), + x=None, + color=None, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, ) # Convert to DataFrame for plotly express @@ -2060,9 +2072,9 @@ def effects( y='value', color=color_col, color_discrete_map=color_map, - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + animation_frame=slots['animation_frame'], title=title, **plotly_kwargs, ) @@ -2111,16 +2123,22 @@ def charge_states( ds = ds[[s for s in storages if s in ds]] ds = _apply_selection(ds, select) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame + slots = assign_slots( + ds, + x=None, + color=None, + 
facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, ) fig = ds.fxplot.line( colors=colors, title='Storage Charge States', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + animation_frame=slots['animation_frame'], **plotly_kwargs, ) fig.update_yaxes(title_text='Charge State') @@ -2204,8 +2222,14 @@ def storage( # Apply selection ds = _apply_selection(ds, select) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame + slots = assign_slots( + ds, + x=None, + color=None, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=_FACET_EXCLUDE_DIMS, ) # Build color map @@ -2227,9 +2251,9 @@ def storage( x='time', y='value', color='variable', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + animation_frame=slots['animation_frame'], color_discrete_map=color_map, title=f'{storage} Operation ({unit})', **plotly_kwargs, @@ -2244,9 +2268,9 @@ def storage( charge_df, x='time', y='value', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=slots['facet_col'], + facet_row=slots['facet_row'], + animation_frame=slots['animation_frame'], ) # Get the primary y-axes from the bar figure to create matching secondary axes From 28bb631a9460645c737f6fe57a54a9ff63cff2af Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 12:33:14 +0100 Subject: [PATCH 52/62] Add new assign_slots() method --- flixopt/dataset_plot_accessor.py | 38 ++++++- flixopt/statistics_accessor.py | 171 ++++++++----------------------- 2 files changed, 75 insertions(+), 134 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py 
b/flixopt/dataset_plot_accessor.py index 9afce67dd..6451037f5 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -201,6 +201,7 @@ def bar( facet_row: str | Literal['auto'] | None = 'auto', animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, + exclude_dims: set[str] | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a grouped bar chart from the dataset. @@ -217,13 +218,20 @@ def bar( facet_row: Dimension for row facets. 'auto' uses CONFIG priority. animation_frame: Dimension for animation slider. facet_cols: Number of columns in facet grid wrap. + exclude_dims: Dimensions to exclude from auto-assignment. **px_kwargs: Additional arguments passed to plotly.express.bar. Returns: Plotly Figure. """ slots = assign_slots( - self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame + self._ds, + x=x, + color=color, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=exclude_dims, ) df = _dataset_to_long_df(self._ds) if df.empty: @@ -258,6 +266,7 @@ def stacked_bar( facet_row: str | Literal['auto'] | None = 'auto', animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, + exclude_dims: set[str] | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a stacked bar chart from the dataset. @@ -283,7 +292,13 @@ def stacked_bar( Plotly Figure. 
""" slots = assign_slots( - self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame + self._ds, + x=x, + color=color, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=exclude_dims, ) df = _dataset_to_long_df(self._ds) if df.empty: @@ -321,6 +336,7 @@ def line( animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, line_shape: str | None = None, + exclude_dims: set[str] | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a line chart from the dataset. @@ -347,7 +363,13 @@ def line( Plotly Figure. """ slots = assign_slots( - self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame + self._ds, + x=x, + color=color, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=exclude_dims, ) df = _dataset_to_long_df(self._ds) if df.empty: @@ -383,6 +405,7 @@ def area( animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, line_shape: str | None = None, + exclude_dims: set[str] | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a stacked area chart from the dataset. @@ -406,7 +429,13 @@ def area( Plotly Figure. """ slots = assign_slots( - self._ds, x=x, color=color, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame + self._ds, + x=x, + color=color, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=exclude_dims, ) df = _dataset_to_long_df(self._ds) if df.empty: @@ -777,6 +806,7 @@ def stacked_bar( facet_row: str | Literal['auto'] | None = 'auto', animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, + exclude_dims: set[str] | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a stacked bar chart. 
See DatasetPlotAccessor.stacked_bar for details.""" diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index da6a859f9..26a4bf55c 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -31,7 +31,6 @@ from .color_processing import ColorType, hex_to_rgba, process_colors from .config import CONFIG -from .dataset_plot_accessor import assign_slots from .plot_result import PlotResult if TYPE_CHECKING: @@ -181,10 +180,6 @@ def _filter_by_carrier(ds: xr.Dataset, carrier: str | list[str] | None) -> xr.Da return ds[matching_vars] if matching_vars else xr.Dataset() -# Default dimensions to exclude from facet auto-assignment (typically x-axis dimensions) -_FACET_EXCLUDE_DIMS = {'time', 'duration', 'duration_pct'} - - def _dataset_to_long_df(ds: xr.Dataset, value_name: str = 'value', var_name: str = 'variable') -> pd.DataFrame: """Convert xarray Dataset to long-form DataFrame for plotly express.""" if not ds.data_vars: @@ -1320,15 +1315,6 @@ def balance( ds[label] = -ds[label] ds = _apply_selection(ds, select) - slots = assign_slots( - ds, - x=None, - color=None, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) # Build color map from Element.color attributes if no colors specified if colors is None: @@ -1343,9 +1329,9 @@ def balance( fig = ds.fxplot.stacked_bar( colors=colors, title=f'{node} [{unit_label}]' if unit_label else node, - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kwargs, ) @@ -1437,15 +1423,6 @@ def carrier_balance( ds[label] = -ds[label] ds = _apply_selection(ds, select) - slots = assign_slots( - ds, - x=None, - color=None, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) # Use cached component colors for flows if 
colors is None: @@ -1473,9 +1450,9 @@ def carrier_balance( fig = ds.fxplot.stacked_bar( colors=colors, title=f'{carrier.capitalize()} Balance [{unit_label}]' if unit_label else f'{carrier.capitalize()} Balance', - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kwargs, ) @@ -1547,18 +1524,18 @@ def heatmap( # Determine facet and animation from available dims has_multiple_vars = 'variable' in da.dims and da.sizes['variable'] > 1 - # Get slot assignments (heatmap only uses facet_col and animation_frame) - slots = assign_slots( - da.to_dataset(name='value'), - x=None, - color=None, - facet_col='variable' if has_multiple_vars else facet_col, - facet_row=None, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) - resolved_facet = slots['facet_col'] - resolved_anim = slots['animation_frame'] + # For heatmap, facet defaults to 'variable' if multiple vars + # Resolve 'auto' to None for heatmap (no auto-faceting by time etc.) 
+ if facet_col == 'auto': + resolved_facet = 'variable' if has_multiple_vars else None + else: + resolved_facet = facet_col + + # Resolve animation_frame - 'auto' means None for heatmap (no auto-animation) + if animation_frame == 'auto': + resolved_anim = None + else: + resolved_anim = animation_frame # Don't use 'variable' for animation if it's used for facet if resolved_anim == 'variable' and has_multiple_vars: @@ -1579,18 +1556,8 @@ def heatmap( elif has_multiple_vars: # Can't reshape but have multiple vars: use variable + time as heatmap axes heatmap_dims = ['variable', 'time'] - # variable is now a heatmap dim, reassign facet - slots = assign_slots( - da.to_dataset(name='value'), - x=None, - color=None, - facet_col=facet_col, - facet_row=None, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS | {'variable'}, - ) - resolved_facet = slots['facet_col'] - resolved_anim = slots['animation_frame'] + # variable is now a heatmap dim, use user's facet choice + resolved_facet = facet_col else: # Fallback: use first two available dimensions available_dims = [d for d in da.dims if da.sizes[d] > 1] @@ -1698,15 +1665,6 @@ def flows( ds = ds[[lbl for lbl in matching_labels if lbl in ds]] ds = _apply_selection(ds, select) - slots = assign_slots( - ds, - x=None, - color=None, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) # Get unit label from first data variable's attributes unit_label = '' @@ -1717,9 +1675,9 @@ def flows( fig = ds.fxplot.line( colors=colors, title=f'Flows [{unit_label}]' if unit_label else 'Flows', - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kwargs, ) @@ -1765,16 +1723,6 @@ def sizes( valid_labels = [lbl for lbl in ds.data_vars if float(ds[lbl].max()) < max_size] ds = ds[valid_labels] - slots = assign_slots( - ds, 
- x=None, - color=None, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) - df = _dataset_to_long_df(ds) if df.empty: fig = go.Figure() @@ -1786,9 +1734,9 @@ def sizes( x='variable', y='value', color='variable', - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, color_discrete_map=color_map, title='Investment Sizes', labels={'variable': 'Flow', 'value': 'Size'}, @@ -1886,16 +1834,6 @@ def sort_descending(arr: np.ndarray) -> np.ndarray: duration_coord = np.linspace(0, 100, n_timesteps) if normalize else np.arange(n_timesteps) result_ds = result_ds.assign_coords({duration_name: duration_coord}) - slots = assign_slots( - result_ds, - x=None, - color=None, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) - # Get unit label from first data variable's attributes unit_label = '' if ds.data_vars: @@ -1905,9 +1843,9 @@ def sort_descending(arr: np.ndarray) -> np.ndarray: fig = result_ds.fxplot.line( colors=colors, title=f'Duration Curve [{unit_label}]' if unit_label else 'Duration Curve', - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kwargs, ) @@ -2037,15 +1975,6 @@ def effects( raise ValueError(f"'by' must be one of 'component', 'contributor', 'time', or None, got {by!r}") # Resolve facets - slots = assign_slots( - combined.to_dataset(name='value'), - x=None, - color=None, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) # Convert to DataFrame for plotly express df = combined.to_dataframe(name='value').reset_index() @@ -2072,9 +2001,9 @@ def effects( y='value', color=color_col, 
color_discrete_map=color_map, - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, title=title, **plotly_kwargs, ) @@ -2123,22 +2052,13 @@ def charge_states( ds = ds[[s for s in storages if s in ds]] ds = _apply_selection(ds, select) - slots = assign_slots( - ds, - x=None, - color=None, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) fig = ds.fxplot.line( colors=colors, title='Storage Charge States', - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kwargs, ) fig.update_yaxes(title_text='Charge State') @@ -2222,15 +2142,6 @@ def storage( # Apply selection ds = _apply_selection(ds, select) - slots = assign_slots( - ds, - x=None, - color=None, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, - exclude_dims=_FACET_EXCLUDE_DIMS, - ) # Build color map flow_labels = [lbl for lbl in ds.data_vars if lbl != 'charge_state'] @@ -2251,9 +2162,9 @@ def storage( x='time', y='value', color='variable', - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, color_discrete_map=color_map, title=f'{storage} Operation ({unit})', **plotly_kwargs, @@ -2268,9 +2179,9 @@ def storage( charge_df, x='time', y='value', - facet_col=slots['facet_col'], - facet_row=slots['facet_row'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, ) # Get the primary y-axes from the bar figure to create matching secondary axes From 4f8407a6353adc32195789a8393f65aefd1802a0 Mon Sep 17 00:00:00 2001 From: FBumann 
<117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 12:54:45 +0100 Subject: [PATCH 53/62] Fix heatmap and convert all to use fxplot --- flixopt/dataset_plot_accessor.py | 24 +++++- flixopt/statistics_accessor.py | 131 +++++++++++++------------------ 2 files changed, 73 insertions(+), 82 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 6451037f5..a310c94b8 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -504,9 +504,17 @@ def heatmap( colors = colors or CONFIG.Plotting.default_sequential_colorscale facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - # Heatmap uses imshow - x/y come from array axes, color is continuous + # Heatmap uses imshow - first 2 dims are the x/y axes of the heatmap + # Exclude these from slot assignment + heatmap_axes = set(list(da.dims)[:2]) if len(da.dims) >= 2 else set() slots = assign_slots( - self._ds, x=None, color=None, facet_col=facet_col, facet_row=None, animation_frame=animation_frame + self._ds, + x=None, + color=None, + facet_col=facet_col, + facet_row=None, + animation_frame=animation_frame, + exclude_dims=heatmap_axes, ) imshow_args: dict[str, Any] = { @@ -914,10 +922,18 @@ def heatmap( colors = colors or CONFIG.Plotting.default_sequential_colorscale facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - # Heatmap uses imshow - x/y come from array axes, color is continuous + # Heatmap uses imshow - first 2 dims are the x/y axes of the heatmap + # Exclude these from slot assignment + heatmap_axes = set(list(da.dims)[:2]) if len(da.dims) >= 2 else set() ds_for_resolution = da.to_dataset(name='_temp') slots = assign_slots( - ds_for_resolution, x=None, color=None, facet_col=facet_col, facet_row=None, animation_frame=animation_frame + ds_for_resolution, + x=None, + color=None, + facet_col=facet_col, + facet_row=None, + animation_frame=animation_frame, + exclude_dims=heatmap_axes, ) imshow_args: 
dict[str, Any] = { diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index 26a4bf55c..1df6704ea 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -31,6 +31,7 @@ from .color_processing import ColorType, hex_to_rgba, process_colors from .config import CONFIG +from .dataset_plot_accessor import assign_slots from .plot_result import PlotResult if TYPE_CHECKING: @@ -1520,28 +1521,9 @@ def heatmap( # Check if data is clustered (has cluster dimension with size > 1) is_clustered = 'cluster' in da.dims and da.sizes['cluster'] > 1 - - # Determine facet and animation from available dims has_multiple_vars = 'variable' in da.dims and da.sizes['variable'] > 1 - # For heatmap, facet defaults to 'variable' if multiple vars - # Resolve 'auto' to None for heatmap (no auto-faceting by time etc.) - if facet_col == 'auto': - resolved_facet = 'variable' if has_multiple_vars else None - else: - resolved_facet = facet_col - - # Resolve animation_frame - 'auto' means None for heatmap (no auto-animation) - if animation_frame == 'auto': - resolved_anim = None - else: - resolved_anim = animation_frame - - # Don't use 'variable' for animation if it's used for facet - if resolved_anim == 'variable' and has_multiple_vars: - resolved_anim = None - - # Determine heatmap dimensions based on data structure + # Apply time reshape if needed (creates timestep/timeframe dims) if is_clustered and (reshape == 'auto' or reshape is None): # Clustered data: use (time, cluster) as natural 2D heatmap axes heatmap_dims = ['time', 'cluster'] @@ -1556,26 +1538,34 @@ def heatmap( elif has_multiple_vars: # Can't reshape but have multiple vars: use variable + time as heatmap axes heatmap_dims = ['variable', 'time'] - # variable is now a heatmap dim, use user's facet choice - resolved_facet = facet_col else: # Fallback: use first two available dimensions available_dims = [d for d in da.dims if da.sizes[d] > 1] - if len(available_dims) >= 2: - heatmap_dims = 
available_dims[:2] - elif 'time' in da.dims: - heatmap_dims = ['time'] - else: - heatmap_dims = list(da.dims)[:1] + heatmap_dims = available_dims[:2] if len(available_dims) >= 2 else list(da.dims)[:2] + + # Resolve facet/animation using assign_slots, excluding heatmap dims + ds_temp = da.to_dataset(name='_temp') + slots = assign_slots( + ds_temp, + x=None, + color=None, + facet_col=facet_col, + facet_row=None, + animation_frame=animation_frame, + exclude_dims=set(heatmap_dims), + ) - # Keep only dims we need - keep_dims = set(heatmap_dims) | {d for d in [resolved_facet, resolved_anim] if d is not None} + # Keep only dims we need (heatmap axes + facet/animation) + keep_dims = set(heatmap_dims) | {d for d in [slots['facet_col'], slots['animation_frame']] if d} for dim in [d for d in da.dims if d not in keep_dims]: da = da.isel({dim: 0}, drop=True) if da.sizes[dim] > 1 else da.squeeze(dim, drop=True) - # Transpose to expected order - dim_order = heatmap_dims + [d for d in [resolved_facet, resolved_anim] if d] - da = da.transpose(*dim_order) + # Transpose to expected order (heatmap dims first) + dim_order = [d for d in heatmap_dims if d in da.dims] + [ + d for d in [slots['facet_col'], slots['animation_frame']] if d and d in da.dims + ] + if len(dim_order) == len(da.dims): + da = da.transpose(*dim_order) # Clear name for multiple variables (colorbar would show first var's name) if has_multiple_vars: @@ -1583,8 +1573,8 @@ def heatmap( fig = da.fxplot.heatmap( colors=colors, - facet_col=resolved_facet, - animation_frame=resolved_anim, + facet_col=slots['facet_col'], + animation_frame=slots['animation_frame'], **plotly_kwargs, ) @@ -1723,23 +1713,18 @@ def sizes( valid_labels = [lbl for lbl in ds.data_vars if float(ds[lbl].max()) < max_size] ds = ds[valid_labels] - df = _dataset_to_long_df(ds) - if df.empty: + if not ds.data_vars: fig = go.Figure() else: - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables) - fig = px.bar( - df, + 
fig = ds.fxplot.bar( x='variable', - y='value', color='variable', + colors=colors, + title='Investment Sizes', + ylabel='Size', facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, - color_discrete_map=color_map, - title='Investment Sizes', - labels={'variable': 'Flow', 'value': 'Size'}, **plotly_kwargs, ) @@ -1974,11 +1959,14 @@ def effects( else: raise ValueError(f"'by' must be one of 'component', 'contributor', 'time', or None, got {by!r}") - # Resolve facets - # Convert to DataFrame for plotly express df = combined.to_dataframe(name='value').reset_index() + # Resolve facet/animation: 'auto' means None for DataFrames (no dimension priority) + resolved_facet_col = None if facet_col == 'auto' else facet_col + resolved_facet_row = None if facet_row == 'auto' else facet_row + resolved_animation = None if animation_frame == 'auto' else animation_frame + # Build color map if color_col and color_col in df.columns: color_items = df[color_col].unique().tolist() @@ -2001,9 +1989,9 @@ def effects( y='value', color=color_col, color_discrete_map=color_map, - facet_col=facet_col, - facet_row=facet_row, - animation_frame=animation_frame, + facet_col=resolved_facet_col, + facet_row=resolved_facet_row, + animation_frame=resolved_animation, title=title, **plotly_kwargs, ) @@ -2143,57 +2131,45 @@ def storage( # Apply selection ds = _apply_selection(ds, select) - # Build color map + # Separate flow data from charge_state flow_labels = [lbl for lbl in ds.data_vars if lbl != 'charge_state'] + flow_ds = ds[flow_labels] + charge_da = ds['charge_state'] + + # Build color map for flows if colors is None: colors = self._get_color_map_for_balance(storage, flow_labels) - color_map = process_colors(colors, flow_labels) - color_map['charge_state'] = 'black' - # Convert to long-form DataFrame - df = _dataset_to_long_df(ds) - - # Create figure with facets using px.bar for flows - flow_df = df[df['variable'] != 'charge_state'] - charge_df = df[df['variable'] == 
'charge_state'] - - fig = px.bar( - flow_df, + # Create stacked bar chart for flows using fxplot + fig = flow_ds.fxplot.stacked_bar( x='time', - y='value', color='variable', + colors=colors, + title=f'{storage} Operation ({unit})', facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, - color_discrete_map=color_map, - title=f'{storage} Operation ({unit})', **plotly_kwargs, ) - fig.update_layout(bargap=0, bargroupgap=0) - fig.update_traces(marker_line_width=0) # Add charge state as line on secondary y-axis - if not charge_df.empty: - # Create line figure with same facets to get matching trace structure - line_fig = px.line( - charge_df, + if charge_da.size > 0: + # Create line figure with same facets + line_fig = charge_da.fxplot.line( x='time', - y='value', + color=None, # Single line, no color grouping facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame, ) # Get the primary y-axes from the bar figure to create matching secondary axes - # px creates axes named: yaxis, yaxis2, yaxis3, etc. primary_yaxes = [key for key in fig.layout if key.startswith('yaxis')] # For each primary y-axis, create a secondary y-axis for i, primary_key in enumerate(sorted(primary_yaxes, key=lambda x: int(x[5:]) if x[5:] else 0)): - # Determine secondary axis name (y -> y2, y2 -> y3 pattern won't work) - # Instead use a consistent offset: yaxis -> yaxis10, yaxis2 -> yaxis11, etc. 
primary_num = primary_key[5:] if primary_key[5:] else '1' - secondary_num = int(primary_num) + 100 # Use high offset to avoid conflicts + secondary_num = int(primary_num) + 100 secondary_key = f'yaxis{secondary_num}' secondary_anchor = f'x{primary_num}' if primary_num != '1' else 'x' @@ -2207,14 +2183,13 @@ def storage( # Add line traces with correct axis assignments for i, trace in enumerate(line_fig.data): - # Map trace index to secondary y-axis primary_num = i + 1 if i > 0 else 1 secondary_yaxis = f'y{primary_num + 100}' trace.name = 'charge_state' trace.line = dict(color=charge_state_color, width=2) trace.yaxis = secondary_yaxis - trace.showlegend = i == 0 # Only show legend for first trace + trace.showlegend = i == 0 trace.legendgroup = 'charge_state' fig.add_trace(trace) From a4d46811c889b652f78b1ae1f1a95a0af96d1777 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 13:08:01 +0100 Subject: [PATCH 54/62] Fix heatmap --- flixopt/dataset_plot_accessor.py | 88 ++++++++++++++++++++------------ flixopt/statistics_accessor.py | 48 ++++++++--------- 2 files changed, 75 insertions(+), 61 deletions(-) diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index a310c94b8..47cb0564a 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -505,17 +505,24 @@ def heatmap( facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols # Heatmap uses imshow - first 2 dims are the x/y axes of the heatmap - # Exclude these from slot assignment - heatmap_axes = set(list(da.dims)[:2]) if len(da.dims) >= 2 else set() - slots = assign_slots( - self._ds, - x=None, - color=None, - facet_col=facet_col, - facet_row=None, - animation_frame=animation_frame, - exclude_dims=heatmap_axes, - ) + # Only call assign_slots if we need to resolve 'auto' values + if facet_col == 'auto' or animation_frame == 'auto': + heatmap_axes = set(list(da.dims)[:2]) if len(da.dims) >= 2 else 
set() + slots = assign_slots( + self._ds, + x=None, + color=None, + facet_col=facet_col, + facet_row=None, + animation_frame=animation_frame, + exclude_dims=heatmap_axes, + ) + resolved_facet = slots['facet_col'] + resolved_animation = slots['animation_frame'] + else: + # Values already resolved (or None), use directly without re-resolving + resolved_facet = facet_col + resolved_animation = animation_frame imshow_args: dict[str, Any] = { 'img': da, @@ -523,13 +530,17 @@ def heatmap( 'title': title or variable, } - if slots['facet_col'] and slots['facet_col'] in da.dims: - imshow_args['facet_col'] = slots['facet_col'] - if facet_col_wrap < da.sizes[slots['facet_col']]: + if resolved_facet and resolved_facet in da.dims: + imshow_args['facet_col'] = resolved_facet + if facet_col_wrap < da.sizes[resolved_facet]: imshow_args['facet_col_wrap'] = facet_col_wrap - if slots['animation_frame'] and slots['animation_frame'] in da.dims: - imshow_args['animation_frame'] = slots['animation_frame'] + if resolved_animation and resolved_animation in da.dims: + imshow_args['animation_frame'] = resolved_animation + + # Use binary_string=False to handle non-numeric coords (e.g., string labels) + if 'binary_string' not in imshow_kwargs: + imshow_args['binary_string'] = False return px.imshow(**{**imshow_args, **imshow_kwargs}) @@ -923,18 +934,25 @@ def heatmap( facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols # Heatmap uses imshow - first 2 dims are the x/y axes of the heatmap - # Exclude these from slot assignment - heatmap_axes = set(list(da.dims)[:2]) if len(da.dims) >= 2 else set() - ds_for_resolution = da.to_dataset(name='_temp') - slots = assign_slots( - ds_for_resolution, - x=None, - color=None, - facet_col=facet_col, - facet_row=None, - animation_frame=animation_frame, - exclude_dims=heatmap_axes, - ) + # Only call assign_slots if we need to resolve 'auto' values + if facet_col == 'auto' or animation_frame == 'auto': + heatmap_axes = set(list(da.dims)[:2]) if 
len(da.dims) >= 2 else set() + ds_for_resolution = da.to_dataset(name='_temp') + slots = assign_slots( + ds_for_resolution, + x=None, + color=None, + facet_col=facet_col, + facet_row=None, + animation_frame=animation_frame, + exclude_dims=heatmap_axes, + ) + resolved_facet = slots['facet_col'] + resolved_animation = slots['animation_frame'] + else: + # Values already resolved (or None), use directly without re-resolving + resolved_facet = facet_col + resolved_animation = animation_frame imshow_args: dict[str, Any] = { 'img': da, @@ -942,12 +960,16 @@ def heatmap( 'title': title or (da.name if da.name else ''), } - if slots['facet_col'] and slots['facet_col'] in da.dims: - imshow_args['facet_col'] = slots['facet_col'] - if facet_col_wrap < da.sizes[slots['facet_col']]: + if resolved_facet and resolved_facet in da.dims: + imshow_args['facet_col'] = resolved_facet + if facet_col_wrap < da.sizes[resolved_facet]: imshow_args['facet_col_wrap'] = facet_col_wrap - if slots['animation_frame'] and slots['animation_frame'] in da.dims: - imshow_args['animation_frame'] = slots['animation_frame'] + if resolved_animation and resolved_animation in da.dims: + imshow_args['animation_frame'] = resolved_animation + + # Use binary_string=False to handle non-numeric coords (e.g., string labels) + if 'binary_string' not in imshow_kwargs: + imshow_args['binary_string'] = False return px.imshow(**{**imshow_args, **imshow_kwargs}) diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index 1df6704ea..2ac9060ac 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -31,7 +31,6 @@ from .color_processing import ColorType, hex_to_rgba, process_colors from .config import CONFIG -from .dataset_plot_accessor import assign_slots from .plot_result import PlotResult if TYPE_CHECKING: @@ -1472,7 +1471,7 @@ def heatmap( reshape: tuple[str, str] | Literal['auto'] | None = 'auto', colors: str | list[str] | None = None, facet_col: str | Literal['auto'] | 
None = 'auto', - animation_frame: str | Literal['auto'] | None = None, + animation_frame: str | Literal['auto'] | None = 'auto', show: bool | None = None, **plotly_kwargs: Any, ) -> PlotResult: @@ -1523,6 +1522,11 @@ def heatmap( is_clustered = 'cluster' in da.dims and da.sizes['cluster'] > 1 has_multiple_vars = 'variable' in da.dims and da.sizes['variable'] > 1 + # Count extra dims (beyond time) - if too many, skip reshape to avoid dimension explosion + extra_dims = [d for d in da.dims if d not in ('time', 'variable') and da.sizes[d] > 1] + # Max dims for heatmap: 2 axes + facet_col + animation_frame = 4 + can_reshape = len(extra_dims) <= 2 # Leave room for facet and animation + # Apply time reshape if needed (creates timestep/timeframe dims) if is_clustered and (reshape == 'auto' or reshape is None): # Clustered data: use (time, cluster) as natural 2D heatmap axes @@ -1531,8 +1535,8 @@ def heatmap( # Non-clustered with explicit reshape: reshape time to (day, hour) etc. da = _reshape_time_for_heatmap(da, reshape) heatmap_dims = ['timestep', 'timeframe'] - elif reshape == 'auto' and 'time' in da.dims and not is_clustered: - # Auto mode for non-clustered: use default ('D', 'h') reshape + elif reshape == 'auto' and 'time' in da.dims and not is_clustered and can_reshape: + # Auto mode for non-clustered: use default ('D', 'h') reshape only if not too many dims da = _reshape_time_for_heatmap(da, ('D', 'h')) heatmap_dims = ['timestep', 'timeframe'] elif has_multiple_vars: @@ -1543,38 +1547,26 @@ def heatmap( available_dims = [d for d in da.dims if da.sizes[d] > 1] heatmap_dims = available_dims[:2] if len(available_dims) >= 2 else list(da.dims)[:2] - # Resolve facet/animation using assign_slots, excluding heatmap dims - ds_temp = da.to_dataset(name='_temp') - slots = assign_slots( - ds_temp, - x=None, - color=None, - facet_col=facet_col, - facet_row=None, - animation_frame=animation_frame, - exclude_dims=set(heatmap_dims), - ) - - # Keep only dims we need (heatmap axes + 
facet/animation) - keep_dims = set(heatmap_dims) | {d for d in [slots['facet_col'], slots['animation_frame']] if d} - for dim in [d for d in da.dims if d not in keep_dims]: - da = da.isel({dim: 0}, drop=True) if da.sizes[dim] > 1 else da.squeeze(dim, drop=True) + # Transpose so heatmap dims come first (px.imshow uses first 2 dims as y/x axes) + other_dims = [d for d in da.dims if d not in heatmap_dims] + dim_order = [d for d in heatmap_dims if d in da.dims] + other_dims + # Always transpose to ensure correct dim order (even if seemingly equal, xarray dim order matters) + da = da.transpose(*dim_order) - # Transpose to expected order (heatmap dims first) - dim_order = [d for d in heatmap_dims if d in da.dims] + [ - d for d in [slots['facet_col'], slots['animation_frame']] if d and d in da.dims - ] - if len(dim_order) == len(da.dims): - da = da.transpose(*dim_order) + # Squeeze single-element dims (except heatmap axes) to avoid 3D shape errors + for dim in list(da.dims): + if dim not in heatmap_dims and da.sizes[dim] == 1: + da = da.squeeze(dim, drop=True) # Clear name for multiple variables (colorbar would show first var's name) if has_multiple_vars: da = da.rename('') + # Let fxplot handle slot assignment for facet/animation fig = da.fxplot.heatmap( colors=colors, - facet_col=slots['facet_col'], - animation_frame=slots['animation_frame'], + facet_col=facet_col, + animation_frame=animation_frame, **plotly_kwargs, ) From ae5655dd4ceea31f8cc4b147aaf71b158d9846d5 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 13:14:06 +0100 Subject: [PATCH 55/62] Fix heatmap --- flixopt/statistics_accessor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index 2ac9060ac..bbd61980f 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -1523,9 +1523,11 @@ def heatmap( has_multiple_vars = 'variable' in da.dims and 
da.sizes['variable'] > 1 # Count extra dims (beyond time) - if too many, skip reshape to avoid dimension explosion + # Reshape adds 1 dim (time -> timestep + timeframe), so check available slots extra_dims = [d for d in da.dims if d not in ('time', 'variable') and da.sizes[d] > 1] - # Max dims for heatmap: 2 axes + facet_col + animation_frame = 4 - can_reshape = len(extra_dims) <= 2 # Leave room for facet and animation + # Count available slots: 'auto' means available, None/explicit means not available + available_slots = (1 if facet_col == 'auto' else 0) + (1 if animation_frame == 'auto' else 0) + can_reshape = len(extra_dims) <= available_slots # Apply time reshape if needed (creates timestep/timeframe dims) if is_clustered and (reshape == 'auto' or reshape is None): From 56b183810d7d16e7132f478b890e2343a81c8aeb Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 13:16:49 +0100 Subject: [PATCH 56/62] Fix heatmap --- flixopt/statistics_accessor.py | 122 +++++++++++++++------------------ 1 file changed, 57 insertions(+), 65 deletions(-) diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index bbd61980f..09d75f145 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -127,6 +127,55 @@ def _reshape_time_for_heatmap( # --- Helper functions --- +def _prepare_for_heatmap( + da: xr.DataArray, + reshape: tuple[str, str] | Literal['auto'] | None, + facet_col: str | Literal['auto'] | None, + animation_frame: str | Literal['auto'] | None, +) -> xr.DataArray: + """Prepare DataArray for heatmap: determine axes, reshape if needed, transpose/squeeze.""" + is_clustered = 'cluster' in da.dims and da.sizes['cluster'] > 1 + has_time = 'time' in da.dims + has_multi_vars = da.sizes.get('variable', 1) > 1 + + # Determine heatmap axes and apply reshape if needed + if is_clustered and reshape in ('auto', None): + heatmap_dims = ['time', 'cluster'] + elif reshape and reshape != 
'auto' and has_time: + da = _reshape_time_for_heatmap(da, reshape) + heatmap_dims = ['timestep', 'timeframe'] + elif reshape == 'auto' and has_time and not is_clustered: + # Check if we have room for extra dims after reshaping (adds 1 dim: time -> timestep + timeframe) + extra_dims = [d for d in da.dims if d not in ('time', 'variable') and da.sizes[d] > 1] + available_slots = (facet_col == 'auto') + (animation_frame == 'auto') + if len(extra_dims) <= available_slots: + da = _reshape_time_for_heatmap(da, ('D', 'h')) + heatmap_dims = ['timestep', 'timeframe'] + elif has_multi_vars: + heatmap_dims = ['variable', 'time'] + else: + heatmap_dims = [d for d in da.dims if da.sizes[d] > 1][:2] or list(da.dims)[:2] + elif has_multi_vars: + heatmap_dims = ['variable', 'time'] + else: + heatmap_dims = [d for d in da.dims if da.sizes[d] > 1][:2] or list(da.dims)[:2] + + # Transpose: heatmap dims first, then others + other_dims = [d for d in da.dims if d not in heatmap_dims] + da = da.transpose(*[d for d in heatmap_dims if d in da.dims], *other_dims) + + # Squeeze single-element dims (except heatmap axes) + for dim in list(da.dims): + if dim not in heatmap_dims and da.sizes[dim] == 1: + da = da.squeeze(dim, drop=True) + + # Clear name for multiple variables (colorbar would show first var's name) + if has_multi_vars: + da = da.rename('') + + return da + + def _filter_by_pattern( names: list[str], include: FilterType | None, @@ -1503,82 +1552,25 @@ def heatmap( PlotResult with processed data and figure. 
""" solution = self._stats._require_solution() - if isinstance(variables, str): variables = [variables] - # Resolve flow labels to variable names - resolved_variables = self._resolve_variable_names(variables, solution) + # Resolve, select, and stack into single DataArray + resolved = self._resolve_variable_names(variables, solution) + ds = _apply_selection(solution[resolved], select) + da = xr.concat([ds[v] for v in ds.data_vars], dim=pd.Index(list(ds.data_vars), name='variable')) - ds = solution[resolved_variables] - ds = _apply_selection(ds, select) + # Prepare for heatmap (reshape, transpose, squeeze) + da = _prepare_for_heatmap(da, reshape, facet_col, animation_frame) - # Stack variables into single DataArray - variable_names = list(ds.data_vars) - dataarrays = [ds[var] for var in variable_names] - da = xr.concat(dataarrays, dim=pd.Index(variable_names, name='variable')) - - # Check if data is clustered (has cluster dimension with size > 1) - is_clustered = 'cluster' in da.dims and da.sizes['cluster'] > 1 - has_multiple_vars = 'variable' in da.dims and da.sizes['variable'] > 1 - - # Count extra dims (beyond time) - if too many, skip reshape to avoid dimension explosion - # Reshape adds 1 dim (time -> timestep + timeframe), so check available slots - extra_dims = [d for d in da.dims if d not in ('time', 'variable') and da.sizes[d] > 1] - # Count available slots: 'auto' means available, None/explicit means not available - available_slots = (1 if facet_col == 'auto' else 0) + (1 if animation_frame == 'auto' else 0) - can_reshape = len(extra_dims) <= available_slots - - # Apply time reshape if needed (creates timestep/timeframe dims) - if is_clustered and (reshape == 'auto' or reshape is None): - # Clustered data: use (time, cluster) as natural 2D heatmap axes - heatmap_dims = ['time', 'cluster'] - elif reshape and reshape != 'auto' and 'time' in da.dims: - # Non-clustered with explicit reshape: reshape time to (day, hour) etc. 
- da = _reshape_time_for_heatmap(da, reshape) - heatmap_dims = ['timestep', 'timeframe'] - elif reshape == 'auto' and 'time' in da.dims and not is_clustered and can_reshape: - # Auto mode for non-clustered: use default ('D', 'h') reshape only if not too many dims - da = _reshape_time_for_heatmap(da, ('D', 'h')) - heatmap_dims = ['timestep', 'timeframe'] - elif has_multiple_vars: - # Can't reshape but have multiple vars: use variable + time as heatmap axes - heatmap_dims = ['variable', 'time'] - else: - # Fallback: use first two available dimensions - available_dims = [d for d in da.dims if da.sizes[d] > 1] - heatmap_dims = available_dims[:2] if len(available_dims) >= 2 else list(da.dims)[:2] - - # Transpose so heatmap dims come first (px.imshow uses first 2 dims as y/x axes) - other_dims = [d for d in da.dims if d not in heatmap_dims] - dim_order = [d for d in heatmap_dims if d in da.dims] + other_dims - # Always transpose to ensure correct dim order (even if seemingly equal, xarray dim order matters) - da = da.transpose(*dim_order) - - # Squeeze single-element dims (except heatmap axes) to avoid 3D shape errors - for dim in list(da.dims): - if dim not in heatmap_dims and da.sizes[dim] == 1: - da = da.squeeze(dim, drop=True) - - # Clear name for multiple variables (colorbar would show first var's name) - if has_multiple_vars: - da = da.rename('') - - # Let fxplot handle slot assignment for facet/animation - fig = da.fxplot.heatmap( - colors=colors, - facet_col=facet_col, - animation_frame=animation_frame, - **plotly_kwargs, - ) + fig = da.fxplot.heatmap(colors=colors, facet_col=facet_col, animation_frame=animation_frame, **plotly_kwargs) if show is None: show = CONFIG.Plotting.default_show if show: fig.show() - reshaped_ds = da.to_dataset(name='value') if isinstance(da, xr.DataArray) else da - return PlotResult(data=reshaped_ds, figure=fig) + return PlotResult(data=da.to_dataset(name='value'), figure=fig) def flows( self, From 
56719e8d278b51d81261df82f5794d9e6b28cd05 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 13:25:21 +0100 Subject: [PATCH 57/62] Fix heatmap --- flixopt/statistics_accessor.py | 63 ++++++++++++++++------------------ 1 file changed, 29 insertions(+), 34 deletions(-) diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index 09d75f145..e3581e4e3 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -134,46 +134,41 @@ def _prepare_for_heatmap( animation_frame: str | Literal['auto'] | None, ) -> xr.DataArray: """Prepare DataArray for heatmap: determine axes, reshape if needed, transpose/squeeze.""" + + def finalize(da: xr.DataArray, heatmap_dims: list[str]) -> xr.DataArray: + """Transpose, squeeze, and clear name if needed.""" + other = [d for d in da.dims if d not in heatmap_dims] + da = da.transpose(*[d for d in heatmap_dims if d in da.dims], *other) + for dim in [d for d in da.dims if d not in heatmap_dims and da.sizes[d] == 1]: + da = da.squeeze(dim, drop=True) + return da.rename('') if da.sizes.get('variable', 1) > 1 else da + + def fallback_dims() -> list[str]: + """Default dims: (variable, time) if multi-var, else first 2 dims with size > 1.""" + if da.sizes.get('variable', 1) > 1: + return ['variable', 'time'] + dims = [d for d in da.dims if da.sizes[d] > 1][:2] + return dims if len(dims) >= 2 else list(da.dims)[:2] + is_clustered = 'cluster' in da.dims and da.sizes['cluster'] > 1 has_time = 'time' in da.dims - has_multi_vars = da.sizes.get('variable', 1) > 1 - # Determine heatmap axes and apply reshape if needed + # Clustered: use (time, cluster) as natural 2D if is_clustered and reshape in ('auto', None): - heatmap_dims = ['time', 'cluster'] - elif reshape and reshape != 'auto' and has_time: - da = _reshape_time_for_heatmap(da, reshape) - heatmap_dims = ['timestep', 'timeframe'] - elif reshape == 'auto' and has_time and not is_clustered: - # Check if 
we have room for extra dims after reshaping (adds 1 dim: time -> timestep + timeframe) - extra_dims = [d for d in da.dims if d not in ('time', 'variable') and da.sizes[d] > 1] - available_slots = (facet_col == 'auto') + (animation_frame == 'auto') - if len(extra_dims) <= available_slots: - da = _reshape_time_for_heatmap(da, ('D', 'h')) - heatmap_dims = ['timestep', 'timeframe'] - elif has_multi_vars: - heatmap_dims = ['variable', 'time'] - else: - heatmap_dims = [d for d in da.dims if da.sizes[d] > 1][:2] or list(da.dims)[:2] - elif has_multi_vars: - heatmap_dims = ['variable', 'time'] - else: - heatmap_dims = [d for d in da.dims if da.sizes[d] > 1][:2] or list(da.dims)[:2] - - # Transpose: heatmap dims first, then others - other_dims = [d for d in da.dims if d not in heatmap_dims] - da = da.transpose(*[d for d in heatmap_dims if d in da.dims], *other_dims) - - # Squeeze single-element dims (except heatmap axes) - for dim in list(da.dims): - if dim not in heatmap_dims and da.sizes[dim] == 1: - da = da.squeeze(dim, drop=True) + return finalize(da, ['time', 'cluster']) + + # Explicit reshape: always apply + if reshape and reshape != 'auto' and has_time: + return finalize(_reshape_time_for_heatmap(da, reshape), ['timestep', 'timeframe']) - # Clear name for multiple variables (colorbar would show first var's name) - if has_multi_vars: - da = da.rename('') + # Auto reshape (non-clustered): apply only if extra dims fit in available slots + if reshape == 'auto' and has_time: + extra = [d for d in da.dims if d not in ('time', 'variable') and da.sizes[d] > 1] + slots = (facet_col == 'auto') + (animation_frame == 'auto') + if len(extra) <= slots: + return finalize(_reshape_time_for_heatmap(da, ('D', 'h')), ['timestep', 'timeframe']) - return da + return finalize(da, fallback_dims()) def _filter_by_pattern( From df825de3eb2381c0bc2e131a721a7591e063469d Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 13:42:14 +0100 
Subject: [PATCH 58/62] Merge remote-tracking branch 'origin/feature/tsam-params' into feature/comparison # Conflicts: # docs/notebooks/08c-clustering.ipynb # flixopt/config.py --- .github/workflows/docs.yaml | 50 ++- .github/workflows/release.yaml | 8 +- .github/workflows/tests.yaml | 8 +- CHANGELOG.md | 153 ++++++- CITATION.cff | 4 +- docs/notebooks/data/tutorial_data.py | 246 +++++++++++ docs/user-guide/optimization/clustering.md | 56 +++ docs/user-guide/results-plotting.md | 3 + flixopt/__init__.py | 2 +- flixopt/clustering/base.py | 255 ++++++----- flixopt/clustering/intercluster_helpers.py | 6 +- flixopt/comparison.py | 104 ++--- flixopt/components.py | 45 +- flixopt/dataset_plot_accessor.py | 464 ++++++++++++--------- flixopt/statistics_accessor.py | 342 +++++---------- flixopt/transform_accessor.py | 230 ++++++++-- mkdocs.yml | 9 +- pyproject.toml | 12 +- tests/test_cluster_reduce_expand.py | 14 +- tests/test_clustering/test_base.py | 10 +- tests/test_clustering/test_integration.py | 98 +++++ 21 files changed, 1415 insertions(+), 704 deletions(-) create mode 100644 docs/notebooks/data/tutorial_data.py diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index b6121b23b..7a9fb3e66 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -36,7 +36,7 @@ jobs: runs-on: ubuntu-24.04 timeout-minutes: 30 steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 @@ -57,7 +57,31 @@ jobs: - name: Install dependencies run: uv pip install --system ".[docs,full]" + - name: Get notebook cache key + id: notebook-cache-key + run: | + # Hash notebooks + flixopt source code (sorted for stable cache keys) + HASH=$(find docs/notebooks -name '*.ipynb' | sort | xargs cat | cat - <(find flixopt -name '*.py' | sort | xargs cat) | sha256sum | cut -d' ' -f1) + echo "hash=$HASH" >> $GITHUB_OUTPUT + + - name: Cache executed notebooks + uses: actions/cache@v4 + id: notebook-cache + with: + path: docs/notebooks/*.ipynb 
+ key: notebooks-${{ steps.notebook-cache-key.outputs.hash }} + + - name: Execute notebooks in parallel + if: steps.notebook-cache.outputs.cache-hit != 'true' + run: | + # Execute all notebooks in parallel (4 at a time) + # Run from notebooks directory so relative imports work + cd docs/notebooks && find . -name '*.ipynb' -print0 | \ + xargs -0 -P 4 -I {} jupyter execute --inplace {} + - name: Build docs + env: + MKDOCS_JUPYTER_EXECUTE: "false" run: mkdocs build --strict - uses: actions/upload-artifact@v4 @@ -74,7 +98,7 @@ jobs: permissions: contents: write steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 @@ -95,12 +119,34 @@ jobs: - name: Install dependencies run: uv pip install --system ".[docs,full]" + - name: Get notebook cache key + id: notebook-cache-key + run: | + # Hash notebooks + flixopt source code (sorted for stable cache keys) + HASH=$(find docs/notebooks -name '*.ipynb' | sort | xargs cat | cat - <(find flixopt -name '*.py' | sort | xargs cat) | sha256sum | cut -d' ' -f1) + echo "hash=$HASH" >> $GITHUB_OUTPUT + + - name: Cache executed notebooks + uses: actions/cache@v4 + id: notebook-cache + with: + path: docs/notebooks/*.ipynb + key: notebooks-${{ steps.notebook-cache-key.outputs.hash }} + + - name: Execute notebooks in parallel + if: steps.notebook-cache.outputs.cache-hit != 'true' + run: | + cd docs/notebooks && find . 
-name '*.ipynb' -print0 | \ + xargs -0 -P 4 -I {} jupyter execute --inplace {} + - name: Configure Git run: | git config user.name "github-actions[bot]" git config user.email "41898282+github-actions[bot]@users.noreply.github.com" - name: Deploy docs + env: + MKDOCS_JUPYTER_EXECUTE: "false" run: | VERSION=${{ inputs.version }} VERSION=${VERSION#v} diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 2bf49bfc1..598540e57 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -16,7 +16,7 @@ jobs: outputs: prepared: ${{ steps.validate.outputs.prepared }} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Validate commit message id: validate @@ -45,7 +45,7 @@ jobs: app-id: ${{ vars.RELEASE_BOT_APP_ID }} private-key: ${{ secrets.RELEASE_BOT_PRIVATE_KEY }} - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 ref: main @@ -88,7 +88,7 @@ jobs: if: needs.check-preparation.outputs.prepared == 'true' runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: astral-sh/setup-uv@v6 with: @@ -166,7 +166,7 @@ jobs: permissions: contents: write steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: actions/setup-python@v6 with: diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index e8713993b..39534b461 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -22,7 +22,7 @@ jobs: lint: runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: astral-sh/setup-uv@v6 with: @@ -47,7 +47,7 @@ jobs: matrix: python-version: ['3.11', '3.12', '3.13', '3.14'] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: astral-sh/setup-uv@v6 with: @@ -71,7 +71,7 @@ jobs: # Only run on main branch or when called by release workflow (not on PRs) if: github.event_name != 'pull_request' steps: - - uses: actions/checkout@v5 + - uses: 
actions/checkout@v6 - uses: astral-sh/setup-uv@v6 with: @@ -92,7 +92,7 @@ jobs: runs-on: ubuntu-24.04 needs: lint steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: astral-sh/setup-uv@v6 with: diff --git a/CHANGELOG.md b/CHANGELOG.md index 46ac8047b..bad4e4d52 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,9 +51,12 @@ If upgrading from v2.x, see the [v3.0.0 release notes](https://github.com/flixOp Until here --> -## [5.1.0] - Upcoming +## [6.0.0] - Upcoming -**Summary**: Time-series clustering for faster optimization with configurable storage behavior across typical periods. Improved weights API with always-normalized scenario weights. +**Summary**: Major release introducing time-series clustering with storage inter-cluster linking, the new `fxplot` accessor for universal xarray plotting, and removal of deprecated v5.0 classes. Includes configurable storage behavior across typical periods and improved weights API. + +!!! warning "Breaking Changes" + This release removes `ClusteredOptimization` and `ClusteringParameters` which were deprecated in v5.0.0. Use `flow_system.transform.cluster()` instead. See [Migration](#migration-from-clusteredoptimization) below. ### ✨ Added @@ -148,6 +151,44 @@ charge_state = fs_expanded.solution['SeasonalPit|charge_state'] Use `'cyclic'` for short-term storage like batteries or hot water tanks where only daily patterns matter. Use `'independent'` for quick estimates when storage behavior isn't critical. +**FXPlot Accessor**: New global xarray accessors for universal plotting with automatic faceting and smart dimension handling. Works on any xarray Dataset, not just flixopt results. 
+ +```python +import flixopt as fx # Registers accessors automatically + +# Plot any xarray Dataset with automatic faceting +dataset.fxplot.bar(x='component') +dataset.fxplot.area(x='time') +dataset.fxplot.heatmap(x='time', y='component') +dataset.fxplot.line(x='time', facet_col='scenario') + +# DataArray support +data_array.fxplot.line() + +# Statistics transformations +dataset.fxstats.to_duration_curve() +``` + +**Available Plot Methods**: + +| Method | Description | +|--------|-------------| +| `.fxplot.bar()` | Grouped bar charts | +| `.fxplot.stacked_bar()` | Stacked bar charts | +| `.fxplot.line()` | Line charts with faceting | +| `.fxplot.area()` | Stacked area charts | +| `.fxplot.heatmap()` | Heatmap visualizations | +| `.fxplot.scatter()` | Scatter plots | +| `.fxplot.pie()` | Pie charts with faceting | +| `.fxstats.to_duration_curve()` | Transform to duration curve format | + +**Key Features**: + +- **Auto-faceting**: Automatically assigns extra dimensions (period, scenario, cluster) to `facet_col`, `facet_row`, or `animation_frame` +- **Smart x-axis**: Intelligently selects x dimension based on priority (time > duration > period > scenario) +- **Universal**: Works on any xarray Dataset/DataArray, not limited to flixopt +- **Configurable**: Customize via `CONFIG.Plotting` (colorscales, facet columns, line shapes) + ### 💥 Breaking Changes - `FlowSystem.scenario_weights` are now always normalized to sum to 1 when set (including after `.sel()` subsetting) @@ -159,12 +200,94 @@ charge_state = fs_expanded.solution['SeasonalPit|charge_state'] ### 🗑️ Deprecated +The following items are deprecated and will be removed in **v7.0.0**: + +**Classes** (use FlowSystem methods instead): + +- `Optimization` class → Use `flow_system.optimize(solver)` +- `SegmentedOptimization` class → Use `flow_system.optimize.rolling_horizon()` +- `Results` class → Use `flow_system.solution` and `flow_system.statistics` +- `SegmentedResults` class → Use segment FlowSystems directly + 
+**FlowSystem methods** (use `transform` or `topology` accessor instead): + +- `flow_system.sel()` → Use `flow_system.transform.sel()` +- `flow_system.isel()` → Use `flow_system.transform.isel()` +- `flow_system.resample()` → Use `flow_system.transform.resample()` +- `flow_system.plot_network()` → Use `flow_system.topology.plot()` +- `flow_system.start_network_app()` → Use `flow_system.topology.start_app()` +- `flow_system.stop_network_app()` → Use `flow_system.topology.stop_app()` +- `flow_system.network_infos()` → Use `flow_system.topology.infos()` + +**Parameters:** + - `normalize_weights` parameter in `create_model()`, `build_model()`, `optimize()` +**Topology method name simplifications** (old names still work with deprecation warnings, removal in v7.0.0): + +| Old (v5.x) | New (v6.0.0) | +|------------|--------------| +| `topology.plot_network()` | `topology.plot()` | +| `topology.start_network_app()` | `topology.start_app()` | +| `topology.stop_network_app()` | `topology.stop_app()` | +| `topology.network_infos()` | `topology.infos()` | + +Note: `topology.plot()` now renders a Sankey diagram. The old PyVis visualization is available via `topology.plot_legacy()`. 
+ +### 🔥 Removed + +**Clustering classes removed** (deprecated in v5.0.0): + +- `ClusteredOptimization` class - Use `flow_system.transform.cluster()` then `optimize()` +- `ClusteringParameters` class - Parameters are now passed directly to `transform.cluster()` +- `flixopt/clustering.py` module - Restructured to `flixopt/clustering/` package with new classes + +#### Migration from ClusteredOptimization + +=== "v5.x (Old - No longer works)" + ```python + from flixopt import ClusteredOptimization, ClusteringParameters + + params = ClusteringParameters(hours_per_period=24, nr_of_periods=8) + calc = ClusteredOptimization('model', flow_system, params) + calc.do_modeling_and_solve(solver) + results = calc.results + ``` + +=== "v6.0.0 (New)" + ```python + # Cluster using transform accessor + fs_clustered = flow_system.transform.cluster( + n_clusters=8, # was: nr_of_periods + cluster_duration='1D', # was: hours_per_period=24 + ) + fs_clustered.optimize(solver) + + # Results on the clustered FlowSystem + costs = fs_clustered.solution['costs'].item() + + # Expand back to full resolution if needed + fs_expanded = fs_clustered.transform.expand_solution() + ``` + ### 🐛 Fixed - `temporal_weight` and `sum_temporal()` now use consistent implementation +### 📝 Docs + +**New Documentation Pages:** + +- [Time-Series Clustering Guide](https://flixopt.github.io/flixopt/latest/user-guide/optimization/clustering/) - Comprehensive guide to clustering workflows + +**New Jupyter Notebooks:** + +- **08c-clustering.ipynb** - Introduction to time-series clustering +- **08c2-clustering-storage-modes.ipynb** - Comparison of all 4 storage cluster modes +- **08d-clustering-multiperiod.ipynb** - Clustering with periods and scenarios +- **08e-clustering-internals.ipynb** - Understanding clustering internals +- **fxplot_accessor_demo.ipynb** - Demo of the new fxplot accessor + ### 👷 Development **New Test Suites for Clustering**: @@ -174,8 +297,34 @@ charge_state = 
fs_expanded.solution['SeasonalPit|charge_state'] - `TestMultiPeriodClustering`: Tests for clustering with periods and scenarios dimensions - `TestPeakSelection`: Tests for `time_series_for_high_peaks` and `time_series_for_low_peaks` parameters +**New Test Suites for Other Features**: + +- `test_clustering_io.py` - Tests for clustering serialization roundtrip +- `test_sel_isel_single_selection.py` - Tests for transform selection methods + --- +## [5.0.4] - 2026-01-05 + +**Summary**: Dependency updates. + +### 🐛 Fixed + +- Fixed netcdf dependency + +### 📦 Dependencies + +- Updated `mkdocs-material` to v9.7.1 +- Updated `mkdocstrings-python` to v1.19.0 +- Updated `ruff` to v0.14.10 +- Updated `pymdown-extensions` to v10.19.1 +- Updated `werkzeug` to v3.1.4 + +### 👷 Development + +- Updated `actions/checkout` action to v6 + +--- ## [5.0.3] - 2025-12-18 diff --git a/CITATION.cff b/CITATION.cff index 8da3f2727..1af00dcae 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -2,8 +2,8 @@ cff-version: 1.2.0 message: "If you use this software, please cite it as below and consider citing the related publication." type: software title: "flixopt" -version: 5.0.3 -date-released: 2025-12-18 +version: 5.0.4 +date-released: 2026-01-05 url: "https://github.com/flixOpt/flixopt" repository-code: "https://github.com/flixOpt/flixopt" license: MIT diff --git a/docs/notebooks/data/tutorial_data.py b/docs/notebooks/data/tutorial_data.py new file mode 100644 index 000000000..3b4997e0a --- /dev/null +++ b/docs/notebooks/data/tutorial_data.py @@ -0,0 +1,246 @@ +"""Generate tutorial data for notebooks 01-07. + +These functions return data (timesteps, profiles, prices) rather than full FlowSystems, +so notebooks can demonstrate building systems step by step. + +Usage: + from data.tutorial_data import get_quickstart_data, get_heat_system_data, ... 
+""" + +import numpy as np +import pandas as pd +import xarray as xr + + +def get_quickstart_data() -> dict: + """Data for 01-quickstart: minimal 4-hour example. + + Returns: + dict with: timesteps, heat_demand (xr.DataArray) + """ + timesteps = pd.date_range('2024-01-15 08:00', periods=4, freq='h') + heat_demand = xr.DataArray( + [30, 50, 45, 25], + dims=['time'], + coords={'time': timesteps}, + name='Heat Demand [kW]', + ) + return { + 'timesteps': timesteps, + 'heat_demand': heat_demand, + } + + +def get_heat_system_data() -> dict: + """Data for 02-heat-system: one week with storage. + + Returns: + dict with: timesteps, heat_demand, gas_price (arrays) + """ + timesteps = pd.date_range('2024-01-15', periods=168, freq='h') + hours = np.arange(168) + hour_of_day = hours % 24 + day_of_week = (hours // 24) % 7 + + # Office heat demand pattern + base_demand = np.where((hour_of_day >= 7) & (hour_of_day <= 18), 80, 30) + weekend_factor = np.where(day_of_week >= 5, 0.5, 1.0) + np.random.seed(42) + heat_demand = base_demand * weekend_factor + np.random.normal(0, 5, len(timesteps)) + heat_demand = np.clip(heat_demand, 20, 100) + + # Time-of-use gas prices + gas_price = np.where((hour_of_day >= 6) & (hour_of_day <= 22), 0.08, 0.05) + + return { + 'timesteps': timesteps, + 'heat_demand': heat_demand, + 'gas_price': gas_price, + } + + +def get_investment_data() -> dict: + """Data for 03-investment-optimization: solar pool heating. 
+ + Returns: + dict with: timesteps, solar_profile, pool_demand, costs + """ + timesteps = pd.date_range('2024-07-15', periods=168, freq='h') + hours = np.arange(168) + hour_of_day = hours % 24 + + # Solar profile + solar_profile = np.maximum(0, np.sin((hour_of_day - 6) * np.pi / 12)) * 0.8 + solar_profile = np.where((hour_of_day >= 6) & (hour_of_day <= 20), solar_profile, 0) + np.random.seed(42) + solar_profile = solar_profile * np.random.uniform(0.6, 1.0, len(timesteps)) + + # Pool demand + pool_demand = np.where((hour_of_day >= 8) & (hour_of_day <= 22), 150, 50) + + return { + 'timesteps': timesteps, + 'solar_profile': solar_profile, + 'pool_demand': pool_demand, + 'gas_price': 0.12, + 'solar_cost_per_kw_week': 20 / 52, + 'tank_cost_per_kwh_week': 1.5 / 52, + } + + +def get_constraints_data() -> dict: + """Data for 04-operational-constraints: factory steam demand. + + Returns: + dict with: timesteps, steam_demand + """ + timesteps = pd.date_range('2024-03-11', periods=72, freq='h') + hours = np.arange(72) + hour_of_day = hours % 24 + + # Shift-based demand + steam_demand = np.select( + [ + (hour_of_day >= 6) & (hour_of_day < 14), + (hour_of_day >= 14) & (hour_of_day < 22), + ], + [400, 350], + default=80, + ).astype(float) + + np.random.seed(123) + steam_demand = steam_demand + np.random.normal(0, 20, len(steam_demand)) + steam_demand = np.clip(steam_demand, 50, 450) + + return { + 'timesteps': timesteps, + 'steam_demand': steam_demand, + } + + +def get_multicarrier_data() -> dict: + """Data for 05-multi-carrier-system: hospital CHP. 
+ + Returns: + dict with: timesteps, electricity_demand, heat_demand, prices + """ + timesteps = pd.date_range('2024-02-05', periods=168, freq='h') + hours = np.arange(168) + hour_of_day = hours % 24 + + # Electricity demand + elec_base = 150 + elec_daily = 100 * np.sin((hour_of_day - 6) * np.pi / 12) + elec_daily = np.maximum(0, elec_daily) + electricity_demand = elec_base + elec_daily + + # Heat demand + heat_pattern = np.select( + [ + (hour_of_day >= 5) & (hour_of_day < 9), + (hour_of_day >= 9) & (hour_of_day < 17), + (hour_of_day >= 17) & (hour_of_day < 22), + ], + [350, 250, 300], + default=200, + ).astype(float) + + np.random.seed(456) + electricity_demand += np.random.normal(0, 15, len(timesteps)) + heat_demand = heat_pattern + np.random.normal(0, 20, len(timesteps)) + electricity_demand = np.clip(electricity_demand, 100, 300) + heat_demand = np.clip(heat_demand, 150, 400) + + # Prices + elec_buy_price = np.where((hour_of_day >= 7) & (hour_of_day <= 21), 0.35, 0.20) + + return { + 'timesteps': timesteps, + 'electricity_demand': electricity_demand, + 'heat_demand': heat_demand, + 'elec_buy_price': elec_buy_price, + 'elec_sell_price': 0.12, + 'gas_price': 0.05, + } + + +def get_time_varying_data() -> dict: + """Data for 06a-time-varying-parameters: heat pump with variable COP. 
+ + Returns: + dict with: timesteps, outdoor_temp, heat_demand, cop + """ + timesteps = pd.date_range('2024-01-22', periods=168, freq='h') + hours = np.arange(168) + hour_of_day = hours % 24 + + # Outdoor temperature + temp_base = 2 + temp_amplitude = 5 + outdoor_temp = temp_base + temp_amplitude * np.sin((hour_of_day - 6) * np.pi / 12) + np.random.seed(789) + outdoor_temp = outdoor_temp + np.repeat(np.random.uniform(-3, 3, 7), 24) + + # Heat demand (inversely related to temperature) + heat_demand = 200 - 8 * outdoor_temp + heat_demand = np.clip(heat_demand, 100, 300) + + # COP calculation + t_supply = 45 + 273.15 + t_source = outdoor_temp + 273.15 + carnot_cop = t_supply / (t_supply - t_source) + cop = np.clip(0.45 * carnot_cop, 2.0, 5.0) + + return { + 'timesteps': timesteps, + 'outdoor_temp': outdoor_temp, + 'heat_demand': heat_demand, + 'cop': cop, + } + + +def get_scenarios_data() -> dict: + """Data for 07-scenarios-and-periods: multi-year planning. + + Returns: + dict with: timesteps, periods, scenarios, weights, heat_demand (DataFrame), prices + """ + timesteps = pd.date_range('2024-01-15', periods=168, freq='h') + periods = pd.Index([2024, 2025, 2026], name='period') + scenarios = pd.Index(['Mild Winter', 'Harsh Winter'], name='scenario') + scenario_weights = np.array([0.6, 0.4]) + + hours = np.arange(168) + hour_of_day = hours % 24 + + # Base pattern + daily_pattern = np.select( + [ + (hour_of_day >= 6) & (hour_of_day < 9), + (hour_of_day >= 9) & (hour_of_day < 17), + (hour_of_day >= 17) & (hour_of_day < 22), + ], + [180, 120, 160], + default=100, + ).astype(float) + + np.random.seed(42) + noise = np.random.normal(0, 10, len(timesteps)) + + mild_demand = np.clip(daily_pattern * 0.8 + noise, 60, 200) + harsh_demand = np.clip(daily_pattern * 1.3 + noise * 1.5, 100, 280) + + heat_demand = pd.DataFrame( + {'Mild Winter': mild_demand, 'Harsh Winter': harsh_demand}, + index=timesteps, + ) + + return { + 'timesteps': timesteps, + 'periods': periods, + 
'scenarios': scenarios, + 'scenario_weights': scenario_weights, + 'heat_demand': heat_demand, + 'gas_prices': np.array([0.06, 0.08, 0.10]), + 'elec_prices': np.array([0.28, 0.34, 0.43]), + } diff --git a/docs/user-guide/optimization/clustering.md b/docs/user-guide/optimization/clustering.md index 7ec5faac1..793fbf8fe 100644 --- a/docs/user-guide/optimization/clustering.md +++ b/docs/user-guide/optimization/clustering.md @@ -52,6 +52,10 @@ flow_rates = fs_expanded.solution['Boiler(Q_th)|flow_rate'] | `cluster_duration` | Duration of each cluster | `'1D'`, `'24h'`, or `24` (hours) | | `time_series_for_high_peaks` | Time series where peak clusters must be captured | `['HeatDemand(Q)|fixed_relative_profile']` | | `time_series_for_low_peaks` | Time series where minimum clusters must be captured | `['SolarGen(P)|fixed_relative_profile']` | +| `cluster_method` | Clustering algorithm | `'k_means'`, `'hierarchical'`, `'k_medoids'` | +| `representation_method` | How clusters are represented | `'meanRepresentation'`, `'medoidRepresentation'` | +| `random_state` | Random seed for reproducibility | `42` | +| `rescale_cluster_periods` | Rescale clusters to match original means | `True` (default) | ### Peak Selection @@ -68,6 +72,58 @@ fs_clustered = flow_system.transform.cluster( Without peak selection, the clustering algorithm might average out extreme days, leading to undersized equipment. 
+### Advanced Clustering Options + +Fine-tune the clustering algorithm with advanced parameters: + +```python +fs_clustered = flow_system.transform.cluster( + n_clusters=8, + cluster_duration='1D', + cluster_method='hierarchical', # Alternative to k_means + representation_method='medoidRepresentation', # Use actual periods, not averages + rescale_cluster_periods=True, # Match original time series means + random_state=42, # Reproducible results +) +``` + +**Available clustering algorithms** (`cluster_method`): + +| Method | Description | +|--------|-------------| +| `'k_means'` | Fast, good for most cases (default) | +| `'hierarchical'` | Produces consistent hierarchical groupings | +| `'k_medoids'` | Uses actual periods as representatives | +| `'k_maxoids'` | Maximizes representativeness | +| `'averaging'` | Simple averaging of similar periods | + +For advanced tsam parameters not exposed directly, use `**kwargs`: + +```python +# Pass any tsam.TimeSeriesAggregation parameter +fs_clustered = flow_system.transform.cluster( + n_clusters=8, + cluster_duration='1D', + sameMean=True, # Normalize all time series to same mean + sortValues=True, # Cluster by duration curves instead of shape +) +``` + +### Clustering Quality Metrics + +After clustering, access quality metrics to evaluate the aggregation accuracy: + +```python +fs_clustered = flow_system.transform.cluster(n_clusters=8, cluster_duration='1D') + +# Access clustering metrics (xr.Dataset) +metrics = fs_clustered.clustering.metrics +print(metrics) # Shows RMSE, MAE, etc. per time series + +# Access specific metric +rmse = metrics['RMSE'] # xr.DataArray with dims [time_series, period?, scenario?] 
+``` + ## Storage Modes Storage behavior during clustering is controlled via the `cluster_mode` parameter: diff --git a/docs/user-guide/results-plotting.md b/docs/user-guide/results-plotting.md index 1ecd26aa1..28e3d2b2b 100644 --- a/docs/user-guide/results-plotting.md +++ b/docs/user-guide/results-plotting.md @@ -2,6 +2,9 @@ After solving an optimization, flixOpt provides a powerful plotting API to visualize and analyze your results. The API is designed to be intuitive and chainable, giving you quick access to common plots while still allowing deep customization. +!!! tip "Plotting Custom Data" + For plotting arbitrary xarray data (not just flixopt results), see the [Custom Data Plotting](recipes/plotting-custom-data.md) guide which covers the `.fxplot` accessor. + ## The Plot Accessor All plotting is accessed through the `statistics.plot` accessor on your FlowSystem: diff --git a/flixopt/__init__.py b/flixopt/__init__.py index 6b226ea28..b84b82a4f 100644 --- a/flixopt/__init__.py +++ b/flixopt/__init__.py @@ -68,7 +68,7 @@ 'solvers', ] -# Initialize logger with default configuration (silent: WARNING level, NullHandler) +# Initialize logger with default configuration (silent: WARNING level, NullHandler). logger = logging.getLogger('flixopt') logger.setLevel(logging.WARNING) logger.addHandler(logging.NullHandler()) diff --git a/flixopt/clustering/base.py b/flixopt/clustering/base.py index 4b31832e4..ab9590aae 100644 --- a/flixopt/clustering/base.py +++ b/flixopt/clustering/base.py @@ -38,15 +38,15 @@ class ClusterStructure: which is needed for proper storage state-of-charge tracking across typical periods when using cluster(). - Note: "original_period" here refers to the original time chunks before - clustering (e.g., 365 original days), NOT the model's "period" dimension - (years/months). Each original time chunk gets assigned to a cluster. 
+ Note: The "original_cluster" dimension indexes the original cluster-sized + time segments (e.g., 0..364 for 365 days), NOT the model's "period" dimension + (years). Each original segment gets assigned to a representative cluster. Attributes: - cluster_order: Maps each original time chunk index to its cluster ID. - dims: [original_period] for simple case, or - [original_period, period, scenario] for multi-period/scenario systems. - Values are cluster indices (0 to n_clusters-1). + cluster_order: Maps original cluster index → representative cluster ID. + dims: [original_cluster] for simple case, or + [original_cluster, period, scenario] for multi-period/scenario systems. + Values are cluster IDs (0 to n_clusters-1). cluster_occurrences: Count of how many original time chunks each cluster represents. dims: [cluster] for simple case, or [cluster, period, scenario] for multi-dim. n_clusters: Number of distinct clusters (typical periods). @@ -60,7 +60,7 @@ class ClusterStructure: - timesteps_per_cluster: 24 (for hourly data) For multi-scenario (e.g., 2 scenarios): - - cluster_order: shape (365, 2) with dims [original_period, scenario] + - cluster_order: shape (365, 2) with dims [original_cluster, scenario] - cluster_occurrences: shape (8, 2) with dims [cluster, scenario] """ @@ -73,7 +73,7 @@ def __post_init__(self): """Validate and ensure proper DataArray formatting.""" # Ensure cluster_order is a DataArray with proper dims if not isinstance(self.cluster_order, xr.DataArray): - self.cluster_order = xr.DataArray(self.cluster_order, dims=['original_period'], name='cluster_order') + self.cluster_order = xr.DataArray(self.cluster_order, dims=['original_cluster'], name='cluster_order') elif self.cluster_order.name is None: self.cluster_order = self.cluster_order.rename('cluster_order') @@ -92,7 +92,7 @@ def __repr__(self) -> str: occ = [int(self.cluster_occurrences.sel(cluster=c).values) for c in range(n_clusters)] return ( f'ClusterStructure(\n' - f' 
{self.n_original_periods} original periods → {n_clusters} clusters\n' + f' {self.n_original_clusters} original periods → {n_clusters} clusters\n' f' timesteps_per_cluster={self.timesteps_per_cluster}\n' f' occurrences={occ}\n' f')' @@ -124,9 +124,9 @@ def _create_reference_structure(self) -> tuple[dict, dict[str, xr.DataArray]]: return ref, arrays @property - def n_original_periods(self) -> int: + def n_original_clusters(self) -> int: """Number of original periods (before clustering).""" - return len(self.cluster_order.coords['original_period']) + return len(self.cluster_order.coords['original_cluster']) @property def has_multi_dims(self) -> bool: @@ -197,20 +197,20 @@ def get_cluster_weight_per_timestep(self) -> xr.DataArray: name='cluster_weight', ) - def plot(self, show: bool | None = None) -> PlotResult: + def plot(self, colors: str | list[str] | None = None, show: bool | None = None) -> PlotResult: """Plot cluster assignment visualization. Shows which cluster each original period belongs to, and the number of occurrences per cluster. Args: + colors: Colorscale name (str) or list of colors. + Defaults to CONFIG.Plotting.default_sequential_colorscale. show: Whether to display the figure. Defaults to CONFIG.Plotting.default_show. Returns: PlotResult containing the figure and underlying data. 
""" - import plotly.express as px - from ..config import CONFIG from ..plot_result import PlotResult @@ -218,27 +218,24 @@ def plot(self, show: bool | None = None) -> PlotResult: int(self.n_clusters) if isinstance(self.n_clusters, (int, np.integer)) else int(self.n_clusters.values) ) - # Create DataFrame for plotting - import pandas as pd - cluster_order = self.get_cluster_order_for_slice() - df = pd.DataFrame( - { - 'Original Period': range(1, len(cluster_order) + 1), - 'Cluster': cluster_order, - } + + # Build DataArray for fxplot heatmap + cluster_da = xr.DataArray( + cluster_order.reshape(1, -1), + dims=['y', 'original_cluster'], + coords={'y': ['Cluster'], 'original_cluster': range(1, len(cluster_order) + 1)}, + name='cluster_assignment', ) - # Bar chart showing cluster assignment - fig = px.bar( - df, - x='Original Period', - y=[1] * len(df), - color='Cluster', - color_continuous_scale='Viridis', - title=f'Cluster Assignment ({self.n_original_periods} periods → {n_clusters} clusters)', + # Use fxplot.heatmap for smart defaults + colorscale = colors or CONFIG.Plotting.default_sequential_colorscale + fig = cluster_da.fxplot.heatmap( + colors=colorscale, + title=f'Cluster Assignment ({self.n_original_clusters} periods → {n_clusters} clusters)', ) - fig.update_layout(yaxis_visible=False, coloraxis_colorbar_title='Cluster') + fig.update_yaxes(showticklabels=False) + fig.update_coloraxes(colorbar_title='Cluster') # Build data for PlotResult data = xr.Dataset( @@ -532,30 +529,30 @@ def validate(self) -> None: # (each weight is how many original periods that cluster represents) # Sum should be checked per period/scenario slice, not across all dimensions if self.cluster_structure is not None: - n_original_periods = self.cluster_structure.n_original_periods + n_original_clusters = self.cluster_structure.n_original_clusters # Sum over cluster dimension only (keep period/scenario if present) weight_sum_per_slice = self.representative_weights.sum(dim='cluster') # Check 
each slice if weight_sum_per_slice.size == 1: # Simple case: no period/scenario weight_sum = float(weight_sum_per_slice.values) - if abs(weight_sum - n_original_periods) > 1e-6: + if abs(weight_sum - n_original_clusters) > 1e-6: import warnings warnings.warn( f'representative_weights sum ({weight_sum}) does not match ' - f'n_original_periods ({n_original_periods})', + f'n_original_clusters ({n_original_clusters})', stacklevel=2, ) else: # Multi-dimensional: check each slice for val in weight_sum_per_slice.values.flat: - if abs(float(val) - n_original_periods) > 1e-6: + if abs(float(val) - n_original_clusters) > 1e-6: import warnings warnings.warn( f'representative_weights sum per slice ({float(val)}) does not match ' - f'n_original_periods ({n_original_periods})', + f'n_original_clusters ({n_original_clusters})', stacklevel=2, ) break # Only warn once @@ -585,8 +582,10 @@ def compare( *, select: SelectType | None = None, colors: ColorType | None = None, - facet_col: str | None = 'period', - facet_row: str | None = 'scenario', + color: str | None = 'auto', + line_dash: str | None = 'representation', + facet_col: str | None = 'auto', + facet_row: str | None = 'auto', show: bool | None = None, **plotly_kwargs: Any, ) -> PlotResult: @@ -600,8 +599,14 @@ def compare( or None to plot all time-varying variables. select: xarray-style selection dict, e.g. {'scenario': 'Base Case'}. colors: Color specification (colorscale name, color list, or label-to-color dict). - facet_col: Dimension for subplot columns (default: 'period'). - facet_row: Dimension for subplot rows (default: 'scenario'). + color: Dimension for line colors. 'auto' uses CONFIG priority (typically 'variable'). + Use 'representation' to color by Original/Clustered instead of line_dash. + line_dash: Dimension for line dash styles. Defaults to 'representation'. + Set to None to disable line dash differentiation. + facet_col: Dimension for subplot columns. 'auto' uses CONFIG priority. 
+ Use 'variable' to create separate columns per variable. + facet_row: Dimension for subplot rows. 'auto' uses CONFIG priority. + Use 'variable' to create separate rows per variable. show: Whether to display the figure. Defaults to CONFIG.Plotting.default_show. **plotly_kwargs: Additional arguments passed to plotly. @@ -610,9 +615,7 @@ def compare( PlotResult containing the comparison figure and underlying data. """ import pandas as pd - import plotly.express as px - from ..color_processing import process_colors from ..config import CONFIG from ..plot_result import PlotResult from ..statistics_accessor import _apply_selection @@ -626,7 +629,7 @@ def compare( resolved_variables = self._resolve_variables(variables) - # Build Dataset with 'representation' dimension for Original/Clustered + # Build Dataset with variables as data_vars data_vars = {} for var in resolved_variables: original = result.original_data[var] @@ -650,54 +653,41 @@ def compare( { var: xr.DataArray( [sorted_vars[(var, r)] for r in ['Original', 'Clustered']], - dims=['representation', 'rank'], - coords={'representation': ['Original', 'Clustered'], 'rank': range(n)}, + dims=['representation', 'duration'], + coords={'representation': ['Original', 'Clustered'], 'duration': range(n)}, ) for var in resolved_variables } ) - # Resolve facets (only for timeseries) - actual_facet_col = facet_col if kind == 'timeseries' and facet_col in ds.dims else None - actual_facet_row = facet_row if kind == 'timeseries' and facet_row in ds.dims else None - - # Convert to long-form DataFrame - df = ds.to_dataframe().reset_index() - coord_cols = [c for c in ds.coords.keys() if c in df.columns] - df = df.melt(id_vars=coord_cols, var_name='variable', value_name='value') - - variable_labels = df['variable'].unique().tolist() - color_map = process_colors(colors, variable_labels, CONFIG.Plotting.default_qualitative_colorscale) - - # Set x-axis and title based on kind - x_col = 'time' if kind == 'timeseries' else 'rank' + # Set 
title based on kind if kind == 'timeseries': title = ( 'Original vs Clustered' if len(resolved_variables) > 1 else f'Original vs Clustered: {resolved_variables[0]}' ) - labels = {} else: title = 'Duration Curve' if len(resolved_variables) > 1 else f'Duration Curve: {resolved_variables[0]}' - labels = {'rank': 'Hours (sorted)', 'value': 'Value'} - - fig = px.line( - df, - x=x_col, - y='value', - color='variable', - line_dash='representation', - facet_col=actual_facet_col, - facet_row=actual_facet_row, + + # Use fxplot for smart defaults + line_kwargs = {} + if line_dash is not None: + line_kwargs['line_dash'] = line_dash + if line_dash == 'representation': + line_kwargs['line_dash_map'] = {'Original': 'dot', 'Clustered': 'solid'} + + fig = ds.fxplot.line( + colors=colors, + color=color, title=title, - labels=labels, - color_discrete_map=color_map, + facet_col=facet_col, + facet_row=facet_row, + **line_kwargs, **plotly_kwargs, ) - if actual_facet_row or actual_facet_col: - fig.update_yaxes(matches=None) - fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) + fig.update_yaxes(matches=None) + fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) plot_result = PlotResult(data=ds, figure=fig) @@ -743,8 +733,8 @@ def heatmap( *, select: SelectType | None = None, colors: str | list[str] | None = None, - facet_col: str | None = 'period', - animation_frame: str | None = 'scenario', + facet_col: str | None = 'auto', + animation_frame: str | None = 'auto', show: bool | None = None, **plotly_kwargs: Any, ) -> PlotResult: @@ -762,8 +752,8 @@ def heatmap( colors: Colorscale name (str) or list of colors for heatmap coloring. Dicts are not supported for heatmaps. Defaults to CONFIG.Plotting.default_sequential_colorscale. - facet_col: Dimension to facet on columns (default: 'period'). - animation_frame: Dimension for animation slider (default: 'scenario'). + facet_col: Dimension to facet on columns. 'auto' uses CONFIG priority. 
+ animation_frame: Dimension for animation slider. 'auto' uses CONFIG priority. show: Whether to display the figure. Defaults to CONFIG.Plotting.default_show. **plotly_kwargs: Additional arguments passed to plotly. @@ -773,7 +763,6 @@ def heatmap( The data has 'cluster' variable with time dimension, matching original timesteps. """ import pandas as pd - import plotly.express as px from ..config import CONFIG from ..plot_result import PlotResult @@ -833,34 +822,24 @@ def heatmap( else: cluster_da = cluster_slices[(None, None)] - # Resolve facet_col and animation_frame - only use if dimension exists - actual_facet_col = facet_col if facet_col and facet_col in cluster_da.dims else None - actual_animation = animation_frame if animation_frame and animation_frame in cluster_da.dims else None - # Add dummy y dimension for heatmap visualization (single row) heatmap_da = cluster_da.expand_dims('y', axis=-1) heatmap_da = heatmap_da.assign_coords(y=['Cluster']) + heatmap_da.name = 'cluster_assignment' - colorscale = colors or CONFIG.Plotting.default_sequential_colorscale - - # Use px.imshow with xr.DataArray - fig = px.imshow( - heatmap_da, - color_continuous_scale=colorscale, - facet_col=actual_facet_col, - animation_frame=actual_animation, + # Use fxplot.heatmap for smart defaults + fig = heatmap_da.fxplot.heatmap( + colors=colors, title='Cluster Assignments', - labels={'time': 'Time', 'color': 'Cluster'}, + facet_col=facet_col, + animation_frame=animation_frame, aspect='auto', **plotly_kwargs, ) - # Clean up facet labels - if actual_facet_col: - fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) - - # Hide y-axis since it's just a single row + # Clean up: hide y-axis since it's just a single row fig.update_yaxes(showticklabels=False) + fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) # Data is exactly what we plotted (without dummy y dimension) cluster_da.name = 'cluster' @@ -880,21 +859,27 @@ def clusters( *, select: SelectType 
| None = None, colors: ColorType | None = None, - facet_col_wrap: int | None = None, + color: str | None = 'auto', + facet_col: str | None = 'cluster', + facet_cols: int | None = None, show: bool | None = None, **plotly_kwargs: Any, ) -> PlotResult: """Plot each cluster's typical period profile. - Shows each cluster as a separate faceted subplot. Useful for - understanding what each cluster represents. + Shows each cluster as a separate faceted subplot with all variables + colored differently. Useful for understanding what each cluster represents. Args: variables: Variable(s) to plot. Can be a string, list of strings, or None to plot all time-varying variables. select: xarray-style selection dict, e.g. {'scenario': 'Base Case'}. colors: Color specification (colorscale name, color list, or label-to-color dict). - facet_col_wrap: Max columns before wrapping facets. + color: Dimension for line colors. 'auto' uses CONFIG priority (typically 'variable'). + Use 'cluster' to color by cluster instead of faceting. + facet_col: Dimension for subplot columns. Defaults to 'cluster'. + Use 'variable' to facet by variable instead. + facet_cols: Max columns before wrapping facets. Defaults to CONFIG.Plotting.default_facet_cols. show: Whether to display the figure. Defaults to CONFIG.Plotting.default_show. @@ -903,10 +888,6 @@ def clusters( Returns: PlotResult containing the figure and underlying data. 
""" - import pandas as pd - import plotly.express as px - - from ..color_processing import process_colors from ..config import CONFIG from ..plot_result import PlotResult from ..statistics_accessor import _apply_selection @@ -929,45 +910,37 @@ def clusters( n_clusters = int(cs.n_clusters) if isinstance(cs.n_clusters, (int, np.integer)) else int(cs.n_clusters.values) timesteps_per_cluster = cs.timesteps_per_cluster - # Build long-form DataFrame with cluster labels including occurrence counts - rows = [] + # Build Dataset with cluster dimension, using labels with occurrence counts + cluster_labels = [ + f'Cluster {c} (×{int(cs.cluster_occurrences.sel(cluster=c).values)})' for c in range(n_clusters) + ] + data_vars = {} for var in resolved_variables: data = aggregated_data[var].values data_by_cluster = data.reshape(n_clusters, timesteps_per_cluster) data_vars[var] = xr.DataArray( data_by_cluster, - dims=['cluster', 'timestep'], - coords={'cluster': range(n_clusters), 'timestep': range(timesteps_per_cluster)}, + dims=['cluster', 'time'], + coords={'cluster': cluster_labels, 'time': range(timesteps_per_cluster)}, ) - for c in range(n_clusters): - occurrence = int(cs.cluster_occurrences.sel(cluster=c).values) - label = f'Cluster {c} (×{occurrence})' - for t in range(timesteps_per_cluster): - rows.append({'cluster': label, 'timestep': t, 'value': data_by_cluster[c, t], 'variable': var}) - df = pd.DataFrame(rows) - - cluster_labels = df['cluster'].unique().tolist() - color_map = process_colors(colors, cluster_labels, CONFIG.Plotting.default_qualitative_colorscale) - facet_col_wrap = facet_col_wrap or CONFIG.Plotting.default_facet_cols + + ds = xr.Dataset(data_vars) title = 'Clusters' if len(resolved_variables) > 1 else f'Clusters: {resolved_variables[0]}' - fig = px.line( - df, - x='timestep', - y='value', - facet_col='cluster', - facet_row='variable' if len(resolved_variables) > 1 else None, - facet_col_wrap=facet_col_wrap if len(resolved_variables) == 1 else None, + # 
Use fxplot for smart defaults + fig = ds.fxplot.line( + colors=colors, + color=color, title=title, - color_discrete_map=color_map, + facet_col=facet_col, + facet_cols=facet_cols, **plotly_kwargs, ) - fig.update_layout(showlegend=False) - if len(resolved_variables) > 1: - fig.update_yaxes(matches=None) - fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) + fig.update_yaxes(matches=None) + fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) + # Include occurrences in result data data_vars['occurrences'] = cs.cluster_occurrences result_data = xr.Dataset(data_vars) plot_result = PlotResult(data=result_data, figure=fig) @@ -993,6 +966,9 @@ class Clustering: Attributes: result: The ClusterResult from the aggregation backend. backend_name: Name of the aggregation backend used (e.g., 'tsam', 'manual'). + metrics: Clustering quality metrics (RMSE, MAE, etc.) as xr.Dataset. + Each metric (e.g., 'RMSE', 'MAE') is a DataArray with dims + ``[time_series, period?, scenario?]``. 
Example: >>> fs_clustered = flow_system.transform.cluster(n_clusters=8, cluster_duration='1D') @@ -1004,6 +980,7 @@ class Clustering: result: ClusterResult backend_name: str = 'unknown' + metrics: xr.Dataset | None = None def _create_reference_structure(self) -> tuple[dict, dict[str, xr.DataArray]]: """Create reference structure for serialization.""" @@ -1026,7 +1003,7 @@ def __repr__(self) -> str: n_clusters = ( int(cs.n_clusters) if isinstance(cs.n_clusters, (int, np.integer)) else int(cs.n_clusters.values) ) - structure_info = f'{cs.n_original_periods} periods → {n_clusters} clusters' + structure_info = f'{cs.n_original_clusters} periods → {n_clusters} clusters' else: structure_info = 'no structure' return f'Clustering(\n backend={self.backend_name!r}\n {structure_info}\n)' @@ -1071,11 +1048,11 @@ def n_clusters(self) -> int: return int(n) if isinstance(n, (int, np.integer)) else int(n.values) @property - def n_original_periods(self) -> int: + def n_original_clusters(self) -> int: """Number of original periods (before clustering).""" if self.result.cluster_structure is None: raise ValueError('No cluster_structure available') - return self.result.cluster_structure.n_original_periods + return self.result.cluster_structure.n_original_clusters @property def timesteps_per_period(self) -> int: @@ -1152,17 +1129,17 @@ def create_cluster_structure_from_mapping( ClusterStructure derived from the mapping. 
""" n_original = len(timestep_mapping) - n_original_periods = n_original // timesteps_per_cluster + n_original_clusters = n_original // timesteps_per_cluster # Determine cluster order from the mapping # Each original period maps to the cluster of its first timestep cluster_order = [] - for p in range(n_original_periods): + for p in range(n_original_clusters): start_idx = p * timesteps_per_cluster cluster_idx = int(timestep_mapping.isel(original_time=start_idx).values) // timesteps_per_cluster cluster_order.append(cluster_idx) - cluster_order_da = xr.DataArray(cluster_order, dims=['original_period'], name='cluster_order') + cluster_order_da = xr.DataArray(cluster_order, dims=['original_cluster'], name='cluster_order') # Count occurrences of each cluster unique_clusters = np.unique(cluster_order) diff --git a/flixopt/clustering/intercluster_helpers.py b/flixopt/clustering/intercluster_helpers.py index d2a5eb9d3..a89a80862 100644 --- a/flixopt/clustering/intercluster_helpers.py +++ b/flixopt/clustering/intercluster_helpers.py @@ -132,7 +132,7 @@ def extract_capacity_bounds( def build_boundary_coords( - n_original_periods: int, + n_original_clusters: int, flow_system: FlowSystem, ) -> tuple[dict, list[str]]: """Build coordinates and dimensions for SOC_boundary variable. @@ -146,7 +146,7 @@ def build_boundary_coords( multi-period or stochastic optimizations. Args: - n_original_periods: Number of original (non-aggregated) time periods. + n_original_clusters: Number of original (non-aggregated) time periods. For example, if a year is clustered into 8 typical days but originally had 365 days, this would be 365. flow_system: The FlowSystem containing optional period/scenario dimensions. 
@@ -163,7 +163,7 @@ def build_boundary_coords( >>> coords['cluster_boundary'] array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]) """ - n_boundaries = n_original_periods + 1 + n_boundaries = n_original_clusters + 1 coords = {'cluster_boundary': np.arange(n_boundaries)} dims = ['cluster_boundary'] diff --git a/flixopt/comparison.py b/flixopt/comparison.py index b615f9ae8..d7db72975 100644 --- a/flixopt/comparison.py +++ b/flixopt/comparison.py @@ -364,12 +364,6 @@ def _combine_data(self, method_name: str, *args, **kwargs) -> tuple[xr.Dataset, return xr.concat(datasets, dim='case', join='outer', fill_value=float('nan')), title - def _resolve_facets(self, ds: xr.Dataset, facet_col='auto', facet_row='auto', animation_frame='auto'): - """Resolve auto facets.""" - from .statistics_accessor import _resolve_auto_facets - - return _resolve_auto_facets(ds, facet_col, facet_row, animation_frame) - def _finalize(self, ds: xr.Dataset, fig, show: bool | None) -> PlotResult: """Handle show and return PlotResult.""" import plotly.graph_objects as go @@ -396,9 +390,13 @@ def balance( ds, title = self._combine_data('balance', node, **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) - col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) fig = ds.fxplot.stacked_bar( - colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim, **plotly_kw + colors=colors, + title=title, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + **plotly_kw, ) return self._finalize(ds, fig, show) @@ -418,9 +416,13 @@ def carrier_balance( ds, title = self._combine_data('carrier_balance', carrier, **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) - col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) fig = ds.fxplot.stacked_bar( - colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim, **plotly_kw + colors=colors, + title=title, + 
facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + **plotly_kw, ) return self._finalize(ds, fig, show) @@ -439,9 +441,13 @@ def flows( ds, title = self._combine_data('flows', **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) - col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) fig = ds.fxplot.line( - colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim, **plotly_kw + colors=colors, + title=title, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + **plotly_kw, ) return self._finalize(ds, fig, show) @@ -461,9 +467,13 @@ def storage( ds, title = self._combine_data('storage', storage, **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) - col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) fig = ds.fxplot.stacked_bar( - colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim, **plotly_kw + colors=colors, + title=title, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + **plotly_kw, ) return self._finalize(ds, fig, show) @@ -483,9 +493,13 @@ def charge_states( ds, title = self._combine_data('charge_states', storages, **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) - col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) fig = ds.fxplot.line( - colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim, **plotly_kw + colors=colors, + title=title, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + **plotly_kw, ) fig.update_yaxes(title_text='Charge State') return self._finalize(ds, fig, show) @@ -507,9 +521,13 @@ def duration_curve( ds, title = self._combine_data('duration_curve', variables, normalize=normalize, **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) - col, row, anim = self._resolve_facets(ds, facet_col, 
facet_row, animation_frame) fig = ds.fxplot.line( - colors=colors, title=title, facet_col=col, facet_row=row, animation_frame=anim, **plotly_kw + colors=colors, + title=title, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + **plotly_kw, ) fig.update_xaxes(title_text='Duration [%]' if normalize else 'Timesteps') return self._finalize(ds, fig, show) @@ -525,28 +543,19 @@ def sizes( **kwargs, ) -> PlotResult: """Plot investment sizes comparison. See StatisticsPlotAccessor.sizes.""" - import plotly.express as px - - from .color_processing import process_colors - from .statistics_accessor import _dataset_to_long_df - data_kw, plotly_kw = self._split_kwargs('sizes', kwargs) ds, title = self._combine_data('sizes', **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) - col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) - df = _dataset_to_long_df(ds) - color_map = process_colors(colors, df['variable'].unique().tolist()) if not df.empty else None - fig = px.bar( - df, + fig = ds.fxplot.bar( x='variable', - y='value', color='variable', + colors=colors, title=title, - facet_col=col, - facet_row=row, - animation_frame=anim, - color_discrete_map=color_map, + ylabel='Size', + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kw, ) return self._finalize(ds, fig, show) @@ -563,35 +572,27 @@ def effects( **kwargs, ) -> PlotResult: """Plot effects comparison. 
See StatisticsPlotAccessor.effects.""" - import plotly.express as px - - from .color_processing import process_colors - data_kw, plotly_kw = self._split_kwargs('effects', kwargs) ds, title = self._combine_data('effects', aspect, **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) - col, row, anim = self._resolve_facets(ds, facet_col, facet_row, animation_frame) - # Get the data array and convert to dataframe + # Get the data array da = ds[aspect] if aspect in ds else ds[next(iter(ds.data_vars))] - df = da.to_dataframe(name='value').reset_index() by = data_kw.get('by') x_col = by if by else 'effect' - color_col = x_col if x_col in df.columns else None - color_map = process_colors(colors, df[color_col].unique().tolist()) if color_col else None - fig = px.bar( - df, + # Convert to Dataset along 'effect' dimension (each effect becomes a variable) + plot_ds = da.to_dataset(dim='effect') if 'effect' in da.dims else da.to_dataset(name=aspect) + fig = plot_ds.fxplot.bar( x=x_col, - y='value', - color=color_col, + color=x_col, + colors=colors, title=title, - facet_col=col, - facet_row=row, - animation_frame=anim, - color_discrete_map=color_map, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kw, ) fig.update_layout(bargap=0, bargroupgap=0) @@ -606,7 +607,6 @@ def heatmap( ds, _ = self._combine_data('heatmap', variables, **data_kw) if not ds.data_vars: return self._finalize(ds, None, show) - col, _, anim = self._resolve_facets(ds, facet_col, None, animation_frame) da = ds[next(iter(ds.data_vars))] - fig = da.fxplot.heatmap(colors=colors, facet_col=col, animation_frame=anim, **plotly_kw) + fig = da.fxplot.heatmap(colors=colors, facet_col=facet_col, animation_frame=animation_frame, **plotly_kw) return self._finalize(ds, fig, show) diff --git a/flixopt/components.py b/flixopt/components.py index 390fc6f02..e962791d8 100644 --- a/flixopt/components.py +++ b/flixopt/components.py @@ -1195,7 +1195,7 @@ class 
InterclusterStorageModel(StorageModel): Variables Created ----------------- - ``SOC_boundary``: Absolute SOC at each original period boundary. - Shape: (n_original_periods + 1,) plus any period/scenario dimensions. + Shape: (n_original_clusters + 1,) plus any period/scenario dimensions. Constraints Created ------------------- @@ -1330,7 +1330,7 @@ def _add_intercluster_linking(self) -> None: else int(cluster_structure.n_clusters.values) ) timesteps_per_cluster = cluster_structure.timesteps_per_cluster - n_original_periods = cluster_structure.n_original_periods + n_original_clusters = cluster_structure.n_original_clusters cluster_order = cluster_structure.cluster_order # 1. Constrain ΔE = 0 at cluster starts @@ -1338,7 +1338,7 @@ def _add_intercluster_linking(self) -> None: # 2. Create SOC_boundary variable flow_system = self._model.flow_system - boundary_coords, boundary_dims = build_boundary_coords(n_original_periods, flow_system) + boundary_coords, boundary_dims = build_boundary_coords(n_original_clusters, flow_system) capacity_bounds = extract_capacity_bounds(self.element.capacity_in_flow_hours, boundary_coords, boundary_dims) soc_boundary = self.add_variables( @@ -1360,12 +1360,14 @@ def _add_intercluster_linking(self) -> None: delta_soc = self._compute_delta_soc(n_clusters, timesteps_per_cluster) # 5. Add linking constraints - self._add_linking_constraints(soc_boundary, delta_soc, cluster_order, n_original_periods, timesteps_per_cluster) + self._add_linking_constraints( + soc_boundary, delta_soc, cluster_order, n_original_clusters, timesteps_per_cluster + ) # 6. 
Add cyclic or initial constraint if self.element.cluster_mode == 'intercluster_cyclic': self.add_constraints( - soc_boundary.isel(cluster_boundary=0) == soc_boundary.isel(cluster_boundary=n_original_periods), + soc_boundary.isel(cluster_boundary=0) == soc_boundary.isel(cluster_boundary=n_original_clusters), short_name='cyclic', ) else: @@ -1375,7 +1377,8 @@ def _add_intercluster_linking(self) -> None: if isinstance(initial, str): # 'equals_final' means cyclic self.add_constraints( - soc_boundary.isel(cluster_boundary=0) == soc_boundary.isel(cluster_boundary=n_original_periods), + soc_boundary.isel(cluster_boundary=0) + == soc_boundary.isel(cluster_boundary=n_original_clusters), short_name='initial_SOC_boundary', ) else: @@ -1389,7 +1392,7 @@ def _add_intercluster_linking(self) -> None: soc_boundary, cluster_order, capacity_bounds.has_investment, - n_original_periods, + n_original_clusters, timesteps_per_cluster, ) @@ -1438,7 +1441,7 @@ def _add_linking_constraints( soc_boundary: xr.DataArray, delta_soc: xr.DataArray, cluster_order: xr.DataArray, - n_original_periods: int, + n_original_clusters: int, timesteps_per_cluster: int, ) -> None: """Add constraints linking consecutive SOC_boundary values. @@ -1455,17 +1458,17 @@ def _add_linking_constraints( soc_boundary: SOC_boundary variable. delta_soc: Net SOC change per cluster. cluster_order: Mapping from original periods to representative clusters. - n_original_periods: Number of original (non-clustered) periods. + n_original_clusters: Number of original (non-clustered) periods. timesteps_per_cluster: Number of timesteps in each cluster period. 
""" soc_after = soc_boundary.isel(cluster_boundary=slice(1, None)) soc_before = soc_boundary.isel(cluster_boundary=slice(None, -1)) # Rename for alignment - soc_after = soc_after.rename({'cluster_boundary': 'original_period'}) - soc_after = soc_after.assign_coords(original_period=np.arange(n_original_periods)) - soc_before = soc_before.rename({'cluster_boundary': 'original_period'}) - soc_before = soc_before.assign_coords(original_period=np.arange(n_original_periods)) + soc_after = soc_after.rename({'cluster_boundary': 'original_cluster'}) + soc_after = soc_after.assign_coords(original_cluster=np.arange(n_original_clusters)) + soc_before = soc_before.rename({'cluster_boundary': 'original_cluster'}) + soc_before = soc_before.assign_coords(original_cluster=np.arange(n_original_clusters)) # Get delta_soc for each original period using cluster_order delta_soc_ordered = delta_soc.isel(cluster=cluster_order) @@ -1484,7 +1487,7 @@ def _add_combined_bound_constraints( soc_boundary: xr.DataArray, cluster_order: xr.DataArray, has_investment: bool, - n_original_periods: int, + n_original_clusters: int, timesteps_per_cluster: int, ) -> None: """Add constraints ensuring actual SOC stays within bounds. @@ -1498,21 +1501,21 @@ def _add_combined_bound_constraints( middle, and end of each cluster. With 2D (cluster, time) structure, we simply select charge_state at a - given time offset, then reorder by cluster_order to get original_period order. + given time offset, then reorder by cluster_order to get original_cluster order. Args: soc_boundary: SOC_boundary variable. cluster_order: Mapping from original periods to clusters. has_investment: Whether the storage has investment sizing. - n_original_periods: Number of original periods. + n_original_clusters: Number of original periods. timesteps_per_cluster: Timesteps in each cluster. 
""" charge_state = self.charge_state # soc_d: SOC at start of each original period soc_d = soc_boundary.isel(cluster_boundary=slice(None, -1)) - soc_d = soc_d.rename({'cluster_boundary': 'original_period'}) - soc_d = soc_d.assign_coords(original_period=np.arange(n_original_periods)) + soc_d = soc_d.rename({'cluster_boundary': 'original_cluster'}) + soc_d = soc_d.assign_coords(original_cluster=np.arange(n_original_clusters)) # Get self-discharge rate for decay calculation # Keep as DataArray to respect per-period/scenario values @@ -1523,13 +1526,13 @@ def _add_combined_bound_constraints( for sample_name, offset in zip(['start', 'mid', 'end'], sample_offsets, strict=False): # With 2D structure: select time offset, then reorder by cluster_order cs_at_offset = charge_state.isel(time=offset) # Shape: (cluster, ...) - # Reorder to original_period order using cluster_order indexer + # Reorder to original_cluster order using cluster_order indexer cs_t = cs_at_offset.isel(cluster=cluster_order) # Suppress xarray warning about index loss - we immediately assign new coords anyway with warnings.catch_warnings(): warnings.filterwarnings('ignore', message='.*does not create an index anymore.*') - cs_t = cs_t.rename({'cluster': 'original_period'}) - cs_t = cs_t.assign_coords(original_period=np.arange(n_original_periods)) + cs_t = cs_t.rename({'cluster': 'original_cluster'}) + cs_t = cs_t.assign_coords(original_cluster=np.arange(n_original_clusters)) # Apply decay factor (1-loss)^t to SOC_boundary per Eq. 
9 decay_t = (1 - rel_loss) ** offset diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index fc38f730b..47cb0564a 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -2,6 +2,7 @@ from __future__ import annotations +import warnings from typing import Any, Literal import pandas as pd @@ -13,59 +14,130 @@ from .config import CONFIG -def _get_x_dim(dims: list[str], x: str | Literal['auto'] | None = 'auto') -> str: - """Select x-axis dim from priority list, or 'variable' for scalar data.""" - if x and x != 'auto': - return x - - # Check priority list first - for dim in CONFIG.Plotting.x_dim_priority: - if dim in dims: - return dim - - # Fallback to first available dimension, or 'variable' for scalar data - return dims[0] if dims else 'variable' - - -def _resolve_auto_facets( +def assign_slots( ds: xr.Dataset, - facet_col: str | Literal['auto'] | None, - facet_row: str | Literal['auto'] | None, - animation_frame: str | Literal['auto'] | None = None, + *, + x: str | Literal['auto'] | None = 'auto', + color: str | Literal['auto'] | None = 'auto', + facet_col: str | Literal['auto'] | None = 'auto', + facet_row: str | Literal['auto'] | None = 'auto', + animation_frame: str | Literal['auto'] | None = 'auto', exclude_dims: set[str] | None = None, -) -> tuple[str | None, str | None, str | None]: - """Assign 'auto' facet slots from available dims using CONFIG priority lists.""" - # Get available extra dimensions with size > 1, excluding specified dims +) -> dict[str, str | None]: + """Assign dimensions to plot slots using CONFIG.Plotting.dim_priority. + + Dimensions are assigned in priority order to slots based on CONFIG.Plotting.slot_priority. + + Slot values: + - 'auto': auto-assign from available dims using priority + - None: skip this slot (not available for this plot type) + - str: use this specific dimension + + 'variable' is treated as a dimension when len(data_vars) > 1. 
It represents + the data_var names column in the melted DataFrame. + + Args: + ds: Dataset to analyze for available dimensions. + x: X-axis dimension. 'auto' assigns first available from priority. + color: Color grouping dimension. + facet_col: Column faceting dimension. + facet_row: Row faceting dimension. + animation_frame: Animation slider dimension. + exclude_dims: Dimensions to exclude from auto-assignment (e.g., already used for x elsewhere). + + Returns: + Dict with keys 'x', 'color', 'facet_col', 'facet_row', 'animation_frame' + and values being assigned dimension names (or None if slot skipped/unfilled). + """ + # Get available dimensions with size > 1, excluding specified dims exclude = exclude_dims or set() available = {d for d in ds.dims if ds.sizes[d] > 1 and d not in exclude} - extra_dims = [d for d in CONFIG.Plotting.extra_dim_priority if d in available] - used: set[str] = set() + # 'variable' is available when there are multiple data_vars (and not excluded) + if len(ds.data_vars) > 1 and 'variable' not in exclude: + available.add('variable') - # Map slot names to their input values + # Get priority-ordered list of available dims + priority_dims = [d for d in CONFIG.Plotting.dim_priority if d in available] + # Add any available dims not in priority list (fallback) + priority_dims.extend(d for d in available if d not in priority_dims) + + # Slot specification slots = { + 'x': x, + 'color': color, 'facet_col': facet_col, 'facet_row': facet_row, 'animation_frame': animation_frame, } - results: dict[str, str | None] = {'facet_col': None, 'facet_row': None, 'animation_frame': None} + # Slot fill order from config + slot_order = CONFIG.Plotting.slot_priority + + results: dict[str, str | None] = {k: None for k in slot_order} + used: set[str] = set() # First pass: resolve explicit dimensions (not 'auto' or None) to mark them as used for slot_name, value in slots.items(): if value is not None and value != 'auto': - if value in available and value not in used: 
- used.add(value) - results[slot_name] = value - - # Second pass: resolve 'auto' slots in dim_slot_priority order - dim_iter = iter(d for d in extra_dims if d not in used) - for slot_name in CONFIG.Plotting.dim_slot_priority: - if slots.get(slot_name) == 'auto': + used.add(value) + results[slot_name] = value + + # Second pass: resolve 'auto' slots in config-defined fill order + dim_iter = iter(d for d in priority_dims if d not in used) + for slot_name in slot_order: + if slots[slot_name] == 'auto': next_dim = next(dim_iter, None) if next_dim: used.add(next_dim) results[slot_name] = next_dim - return results['facet_col'], results['facet_row'], results['animation_frame'] + # Warn if any dimensions were not assigned to any slot + unassigned = available - used + if unassigned: + available_slots = [k for k, v in slots.items() if v is not None] + unavailable_slots = [k for k, v in slots.items() if v is None] + if unavailable_slots: + warnings.warn( + f'Dimensions {unassigned} not assigned to any plot dimension. ' + f'Not available for this plot type: {unavailable_slots}. ' + f'Reduce dimensions before plotting (e.g., .sel(), .isel(), .mean()).', + stacklevel=3, + ) + else: + warnings.warn( + f'Dimensions {unassigned} not assigned to any plot dimension ({available_slots}). ' + f'Reduce dimensions before plotting (e.g., .sel(), .isel(), .mean()).', + stacklevel=3, + ) + + return results + + +def _build_fig_kwargs( + slots: dict[str, str | None], + ds_sizes: dict[str, int], + px_kwargs: dict[str, Any], + facet_cols: int | None = None, +) -> dict[str, Any]: + """Build plotly express kwargs from slot assignments. + + Adds facet/animation args only if slots are assigned and not overridden in px_kwargs. + Handles facet_col_wrap based on dimension size. 
+ """ + facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols + result: dict[str, Any] = {} + + # Add facet/animation kwargs from slots (skip if None or already in px_kwargs) + for slot in ('color', 'facet_col', 'facet_row', 'animation_frame'): + if slots.get(slot) and slot not in px_kwargs: + result[slot] = slots[slot] + + # Add facet_col_wrap if facet_col is set and dimension is large enough + if result.get('facet_col'): + dim_size = ds_sizes.get(result['facet_col'], facet_col_wrap + 1) + if facet_col_wrap < dim_size: + result['facet_col_wrap'] = facet_col_wrap + + return result def _dataset_to_long_df(ds: xr.Dataset, value_name: str = 'value', var_name: str = 'variable') -> pd.DataFrame: @@ -120,6 +192,7 @@ def bar( self, *, x: str | Literal['auto'] | None = 'auto', + color: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -128,12 +201,15 @@ def bar( facet_row: str | Literal['auto'] | None = 'auto', animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, + exclude_dims: set[str] | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a grouped bar chart from the dataset. Args: - x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.dim_priority. + color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) + if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. xlabel: X-axis label. @@ -142,57 +218,46 @@ def bar( facet_row: Dimension for row facets. 'auto' uses CONFIG priority. animation_frame: Dimension for animation slider. facet_cols: Number of columns in facet grid wrap. + exclude_dims: Dimensions to exclude from auto-assignment. **px_kwargs: Additional arguments passed to plotly.express.bar. Returns: Plotly Figure. 
""" - # Determine x-axis first, then resolve facets from remaining dims - dims = list(self._ds.dims) - x_col = _get_x_dim(dims, x) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame, exclude_dims={x_col} + slots = assign_slots( + self._ds, + x=x, + color=color, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=exclude_dims, ) - df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) + color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] + color_map = process_colors(colors, color_labels, CONFIG.Plotting.default_qualitative_colorscale) - facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - fig_kwargs: dict[str, Any] = { + labels = {**(({slots['x']: xlabel}) if xlabel and slots['x'] else {}), **({'value': ylabel} if ylabel else {})} + fig_kwargs = { 'data_frame': df, - 'x': x_col, + 'x': slots['x'], 'y': 'value', 'title': title, 'barmode': 'group', + 'color_discrete_map': color_map, + **({'labels': labels} if labels else {}), + **_build_fig_kwargs(slots, dict(self._ds.sizes), px_kwargs, facet_cols), } - # Only color by variable if it's not already on x-axis (and user didn't override) - if x_col != 'variable' and 'color' not in px_kwargs: - fig_kwargs['color'] = 'variable' - fig_kwargs['color_discrete_map'] = color_map - if xlabel: - fig_kwargs['labels'] = {x_col: xlabel} - if ylabel: - fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - - if actual_facet_col and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): - fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row and 
'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = actual_facet_row - if actual_anim and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = actual_anim - return px.bar(**{**fig_kwargs, **px_kwargs}) def stacked_bar( self, *, x: str | Literal['auto'] | None = 'auto', + color: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -201,6 +266,7 @@ def stacked_bar( facet_row: str | Literal['auto'] | None = 'auto', animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, + exclude_dims: set[str] | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a stacked bar chart from the dataset. @@ -209,7 +275,9 @@ def stacked_bar( values are stacked separately. Args: - x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.dim_priority. + color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) + if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. xlabel: X-axis label. @@ -223,45 +291,32 @@ def stacked_bar( Returns: Plotly Figure. 
""" - # Determine x-axis first, then resolve facets from remaining dims - dims = list(self._ds.dims) - x_col = _get_x_dim(dims, x) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame, exclude_dims={x_col} + slots = assign_slots( + self._ds, + x=x, + color=color, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=exclude_dims, ) - df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) + color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] + color_map = process_colors(colors, color_labels, CONFIG.Plotting.default_qualitative_colorscale) - facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - fig_kwargs: dict[str, Any] = { + labels = {**(({slots['x']: xlabel}) if xlabel and slots['x'] else {}), **({'value': ylabel} if ylabel else {})} + fig_kwargs = { 'data_frame': df, - 'x': x_col, + 'x': slots['x'], 'y': 'value', 'title': title, + 'color_discrete_map': color_map, + **({'labels': labels} if labels else {}), + **_build_fig_kwargs(slots, dict(self._ds.sizes), px_kwargs, facet_cols), } - # Only color by variable if it's not already on x-axis (and user didn't override) - if x_col != 'variable' and 'color' not in px_kwargs: - fig_kwargs['color'] = 'variable' - fig_kwargs['color_discrete_map'] = color_map - if xlabel: - fig_kwargs['labels'] = {x_col: xlabel} - if ylabel: - fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - - if actual_facet_col and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): - fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row and 'facet_row' not in 
px_kwargs: - fig_kwargs['facet_row'] = actual_facet_row - if actual_anim and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = actual_anim - fig = px.bar(**{**fig_kwargs, **px_kwargs}) fig.update_layout(barmode='relative', bargap=0, bargroupgap=0) fig.update_traces(marker_line_width=0) @@ -271,6 +326,7 @@ def line( self, *, x: str | Literal['auto'] | None = 'auto', + color: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -280,6 +336,7 @@ def line( animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, line_shape: str | None = None, + exclude_dims: set[str] | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a line chart from the dataset. @@ -287,7 +344,9 @@ def line( Each variable in the dataset becomes a separate line. Args: - x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.dim_priority. + color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) + if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. xlabel: X-axis label. @@ -303,52 +362,40 @@ def line( Returns: Plotly Figure. 
""" - # Determine x-axis first, then resolve facets from remaining dims - dims = list(self._ds.dims) - x_col = _get_x_dim(dims, x) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame, exclude_dims={x_col} + slots = assign_slots( + self._ds, + x=x, + color=color, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=exclude_dims, ) - df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) + color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] + color_map = process_colors(colors, color_labels, CONFIG.Plotting.default_qualitative_colorscale) - facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - fig_kwargs: dict[str, Any] = { + labels = {**(({slots['x']: xlabel}) if xlabel and slots['x'] else {}), **({'value': ylabel} if ylabel else {})} + fig_kwargs = { 'data_frame': df, - 'x': x_col, + 'x': slots['x'], 'y': 'value', 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, + 'color_discrete_map': color_map, + **({'labels': labels} if labels else {}), + **_build_fig_kwargs(slots, dict(self._ds.sizes), px_kwargs, facet_cols), } - # Only color by variable if it's not already on x-axis (and user didn't override) - if x_col != 'variable' and 'color' not in px_kwargs: - fig_kwargs['color'] = 'variable' - fig_kwargs['color_discrete_map'] = color_map - if xlabel: - fig_kwargs['labels'] = {x_col: xlabel} - if ylabel: - fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - - if actual_facet_col and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): - fig_kwargs['facet_col_wrap'] 
= facet_col_wrap - if actual_facet_row and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = actual_facet_row - if actual_anim and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = actual_anim - return px.line(**{**fig_kwargs, **px_kwargs}) def area( self, *, x: str | Literal['auto'] | None = 'auto', + color: str | Literal['auto'] | None = 'auto', colors: ColorType | None = None, title: str = '', xlabel: str = '', @@ -358,12 +405,15 @@ def area( animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, line_shape: str | None = None, + exclude_dims: set[str] | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a stacked area chart from the dataset. Args: - x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.x_dim_priority. + x: Dimension for x-axis. 'auto' uses CONFIG.Plotting.dim_priority. + color: Dimension for color grouping. 'auto' uses 'variable' (data_var names) + if available, otherwise uses CONFIG priority. colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. xlabel: X-axis label. @@ -378,46 +428,33 @@ def area( Returns: Plotly Figure. 
""" - # Determine x-axis first, then resolve facets from remaining dims - dims = list(self._ds.dims) - x_col = _get_x_dim(dims, x) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame, exclude_dims={x_col} + slots = assign_slots( + self._ds, + x=x, + color=color, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, + exclude_dims=exclude_dims, ) - df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) + color_labels = df[slots['color']].unique().tolist() if slots['color'] and slots['color'] in df.columns else [] + color_map = process_colors(colors, color_labels, CONFIG.Plotting.default_qualitative_colorscale) - facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - fig_kwargs: dict[str, Any] = { + labels = {**(({slots['x']: xlabel}) if xlabel and slots['x'] else {}), **({'value': ylabel} if ylabel else {})} + fig_kwargs = { 'data_frame': df, - 'x': x_col, + 'x': slots['x'], 'y': 'value', 'title': title, 'line_shape': line_shape or CONFIG.Plotting.default_line_shape, + 'color_discrete_map': color_map, + **({'labels': labels} if labels else {}), + **_build_fig_kwargs(slots, dict(self._ds.sizes), px_kwargs, facet_cols), } - # Only color by variable if it's not already on x-axis (and user didn't override) - if x_col != 'variable' and 'color' not in px_kwargs: - fig_kwargs['color'] = 'variable' - fig_kwargs['color_discrete_map'] = color_map - if xlabel: - fig_kwargs['labels'] = {x_col: xlabel} - if ylabel: - fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), 'value': ylabel} - - if actual_facet_col and 'facet_col' not in px_kwargs: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): - fig_kwargs['facet_col_wrap'] 
= facet_col_wrap - if actual_facet_row and 'facet_row' not in px_kwargs: - fig_kwargs['facet_row'] = actual_facet_row - if actual_anim and 'animation_frame' not in px_kwargs: - fig_kwargs['animation_frame'] = actual_anim - return px.area(**{**fig_kwargs, **px_kwargs}) def heatmap( @@ -467,7 +504,25 @@ def heatmap( colors = colors or CONFIG.Plotting.default_sequential_colorscale facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - actual_facet_col, _, actual_anim = _resolve_auto_facets(self._ds, facet_col, None, animation_frame) + # Heatmap uses imshow - first 2 dims are the x/y axes of the heatmap + # Only call assign_slots if we need to resolve 'auto' values + if facet_col == 'auto' or animation_frame == 'auto': + heatmap_axes = set(list(da.dims)[:2]) if len(da.dims) >= 2 else set() + slots = assign_slots( + self._ds, + x=None, + color=None, + facet_col=facet_col, + facet_row=None, + animation_frame=animation_frame, + exclude_dims=heatmap_axes, + ) + resolved_facet = slots['facet_col'] + resolved_animation = slots['animation_frame'] + else: + # Values already resolved (or None), use directly without re-resolving + resolved_facet = facet_col + resolved_animation = animation_frame imshow_args: dict[str, Any] = { 'img': da, @@ -475,13 +530,17 @@ def heatmap( 'title': title or variable, } - if actual_facet_col and actual_facet_col in da.dims: - imshow_args['facet_col'] = actual_facet_col - if facet_col_wrap < da.sizes[actual_facet_col]: + if resolved_facet and resolved_facet in da.dims: + imshow_args['facet_col'] = resolved_facet + if facet_col_wrap < da.sizes[resolved_facet]: imshow_args['facet_col_wrap'] = facet_col_wrap - if actual_anim and actual_anim in da.dims: - imshow_args['animation_frame'] = actual_anim + if resolved_animation and resolved_animation in da.dims: + imshow_args['animation_frame'] = resolved_animation + + # Use binary_string=False to handle non-numeric coords (e.g., string labels) + if 'binary_string' not in imshow_kwargs: + 
imshow_args['binary_string'] = False return px.imshow(**{**imshow_args, **imshow_kwargs}) @@ -525,8 +584,9 @@ def scatter( if df.empty: return go.Figure() - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame + # Scatter uses explicit x/y variable names, not dimensions + slots = assign_slots( + self._ds, x=None, color=None, facet_col=facet_col, facet_row=facet_row, animation_frame=animation_frame ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols @@ -542,14 +602,16 @@ def scatter( if ylabel: fig_kwargs['labels'] = {**fig_kwargs.get('labels', {}), y: ylabel} - if actual_facet_col: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + # Only use facets if the column actually exists in the dataframe + # (scatter uses wide format, so 'variable' column doesn't exist) + if slots['facet_col'] and slots['facet_col'] in df.columns: + fig_kwargs['facet_col'] = slots['facet_col'] + if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row: - fig_kwargs['facet_row'] = actual_facet_row - if actual_anim: - fig_kwargs['animation_frame'] = actual_anim + if slots['facet_row'] and slots['facet_row'] in df.columns: + fig_kwargs['facet_row'] = slots['facet_row'] + if slots['animation_frame'] and slots['animation_frame'] in df.columns: + fig_kwargs['animation_frame'] = slots['animation_frame'] return px.scatter(**fig_kwargs) @@ -560,21 +622,22 @@ def pie( title: str = '', facet_col: str | Literal['auto'] | None = 'auto', facet_row: str | Literal['auto'] | None = 'auto', - animation_frame: str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a pie chart from aggregated dataset values. - Extra dimensions are auto-assigned to facet_col, facet_row, and animation_frame. 
+ Extra dimensions are auto-assigned to facet_col and facet_row. For scalar values, a single pie is shown. + Note: + ``px.pie()`` does not support animation_frame, so only facets are available. + Args: colors: Color specification (colorscale name, color list, or dict mapping). title: Plot title. facet_col: Dimension for column facets. 'auto' uses CONFIG priority. facet_row: Dimension for row facets. 'auto' uses CONFIG priority. - animation_frame: Dimension for animation slider. 'auto' uses CONFIG priority. facet_cols: Number of columns in facet grid wrap. **px_kwargs: Additional arguments passed to plotly.express.pie. @@ -604,13 +667,14 @@ def pie( **px_kwargs, ) - # Multi-dimensional case - faceted/animated pies + # Multi-dimensional case - faceted pies (px.pie doesn't support animation_frame) df = _dataset_to_long_df(self._ds) if df.empty: return go.Figure() - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - self._ds, facet_col, facet_row, animation_frame + # Pie uses 'variable' for names and 'value' for values, no x/color/animation_frame + slots = assign_slots( + self._ds, x=None, color=None, facet_col=facet_col, facet_row=facet_row, animation_frame=None ) facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols @@ -624,14 +688,12 @@ def pie( **px_kwargs, } - if actual_facet_col: - fig_kwargs['facet_col'] = actual_facet_col - if facet_col_wrap < self._ds.sizes.get(actual_facet_col, facet_col_wrap + 1): + if slots['facet_col']: + fig_kwargs['facet_col'] = slots['facet_col'] + if facet_col_wrap < self._ds.sizes.get(slots['facet_col'], facet_col_wrap + 1): fig_kwargs['facet_col_wrap'] = facet_col_wrap - if actual_facet_row: - fig_kwargs['facet_row'] = actual_facet_row - if actual_anim: - fig_kwargs['animation_frame'] = actual_anim + if slots['facet_row']: + fig_kwargs['facet_row'] = slots['facet_row'] return px.pie(**fig_kwargs) @@ -763,6 +825,7 @@ def stacked_bar( facet_row: str | Literal['auto'] | None = 'auto', animation_frame: 
str | Literal['auto'] | None = 'auto', facet_cols: int | None = None, + exclude_dims: set[str] | None = None, **px_kwargs: Any, ) -> go.Figure: """Create a stacked bar chart. See DatasetPlotAccessor.stacked_bar for details.""" @@ -870,9 +933,26 @@ def heatmap( colors = colors or CONFIG.Plotting.default_sequential_colorscale facet_col_wrap = facet_cols or CONFIG.Plotting.default_facet_cols - # Use Dataset for facet resolution - ds_for_resolution = da.to_dataset(name='_temp') - actual_facet_col, _, actual_anim = _resolve_auto_facets(ds_for_resolution, facet_col, None, animation_frame) + # Heatmap uses imshow - first 2 dims are the x/y axes of the heatmap + # Only call assign_slots if we need to resolve 'auto' values + if facet_col == 'auto' or animation_frame == 'auto': + heatmap_axes = set(list(da.dims)[:2]) if len(da.dims) >= 2 else set() + ds_for_resolution = da.to_dataset(name='_temp') + slots = assign_slots( + ds_for_resolution, + x=None, + color=None, + facet_col=facet_col, + facet_row=None, + animation_frame=animation_frame, + exclude_dims=heatmap_axes, + ) + resolved_facet = slots['facet_col'] + resolved_animation = slots['animation_frame'] + else: + # Values already resolved (or None), use directly without re-resolving + resolved_facet = facet_col + resolved_animation = animation_frame imshow_args: dict[str, Any] = { 'img': da, @@ -880,12 +960,16 @@ def heatmap( 'title': title or (da.name if da.name else ''), } - if actual_facet_col and actual_facet_col in da.dims: - imshow_args['facet_col'] = actual_facet_col - if facet_col_wrap < da.sizes[actual_facet_col]: + if resolved_facet and resolved_facet in da.dims: + imshow_args['facet_col'] = resolved_facet + if facet_col_wrap < da.sizes[resolved_facet]: imshow_args['facet_col_wrap'] = facet_col_wrap - if actual_anim and actual_anim in da.dims: - imshow_args['animation_frame'] = actual_anim + if resolved_animation and resolved_animation in da.dims: + imshow_args['animation_frame'] = resolved_animation + + # Use 
binary_string=False to handle non-numeric coords (e.g., string labels) + if 'binary_string' not in imshow_kwargs: + imshow_args['binary_string'] = False return px.imshow(**{**imshow_args, **imshow_kwargs}) diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index 382ed1bf0..e3581e4e3 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -127,6 +127,50 @@ def _reshape_time_for_heatmap( # --- Helper functions --- +def _prepare_for_heatmap( + da: xr.DataArray, + reshape: tuple[str, str] | Literal['auto'] | None, + facet_col: str | Literal['auto'] | None, + animation_frame: str | Literal['auto'] | None, +) -> xr.DataArray: + """Prepare DataArray for heatmap: determine axes, reshape if needed, transpose/squeeze.""" + + def finalize(da: xr.DataArray, heatmap_dims: list[str]) -> xr.DataArray: + """Transpose, squeeze, and clear name if needed.""" + other = [d for d in da.dims if d not in heatmap_dims] + da = da.transpose(*[d for d in heatmap_dims if d in da.dims], *other) + for dim in [d for d in da.dims if d not in heatmap_dims and da.sizes[d] == 1]: + da = da.squeeze(dim, drop=True) + return da.rename('') if da.sizes.get('variable', 1) > 1 else da + + def fallback_dims() -> list[str]: + """Default dims: (variable, time) if multi-var, else first 2 dims with size > 1.""" + if da.sizes.get('variable', 1) > 1: + return ['variable', 'time'] + dims = [d for d in da.dims if da.sizes[d] > 1][:2] + return dims if len(dims) >= 2 else list(da.dims)[:2] + + is_clustered = 'cluster' in da.dims and da.sizes['cluster'] > 1 + has_time = 'time' in da.dims + + # Clustered: use (time, cluster) as natural 2D + if is_clustered and reshape in ('auto', None): + return finalize(da, ['time', 'cluster']) + + # Explicit reshape: always apply + if reshape and reshape != 'auto' and has_time: + return finalize(_reshape_time_for_heatmap(da, reshape), ['timestep', 'timeframe']) + + # Auto reshape (non-clustered): apply only if extra dims fit in 
available slots + if reshape == 'auto' and has_time: + extra = [d for d in da.dims if d not in ('time', 'variable') and da.sizes[d] > 1] + slots = (facet_col == 'auto') + (animation_frame == 'auto') + if len(extra) <= slots: + return finalize(_reshape_time_for_heatmap(da, ('D', 'h')), ['timestep', 'timeframe']) + + return finalize(da, fallback_dims()) + + def _filter_by_pattern( names: list[str], include: FilterType | None, @@ -180,73 +224,6 @@ def _filter_by_carrier(ds: xr.Dataset, carrier: str | list[str] | None) -> xr.Da return ds[matching_vars] if matching_vars else xr.Dataset() -def _resolve_auto_facets( - ds: xr.Dataset, - facet_col: str | Literal['auto'] | None, - facet_row: str | Literal['auto'] | None, - animation_frame: str | Literal['auto'] | None = None, -) -> tuple[str | None, str | None, str | None]: - """Resolve 'auto' facet/animation dimensions based on available data dimensions. - - When 'auto' is specified, extra dimensions are assigned to slots based on: - - CONFIG.Plotting.extra_dim_priority: Order of dimensions to assign. - - CONFIG.Plotting.dim_slot_priority: Order of slots to fill. - - Args: - ds: Dataset to check for available dimensions. - facet_col: Dimension name, 'auto', or None. - facet_row: Dimension name, 'auto', or None. - animation_frame: Dimension name, 'auto', or None. - - Returns: - Tuple of (resolved_facet_col, resolved_facet_row, resolved_animation_frame). - Each is either a valid dimension name or None. 
- """ - # Get available extra dimensions with size > 1, sorted by priority - available = {d for d in ds.dims if ds.sizes[d] > 1} - extra_dims = [d for d in CONFIG.Plotting.extra_dim_priority if d in available] - used: set[str] = set() - - # Map slot names to their input values - slots = { - 'facet_col': facet_col, - 'facet_row': facet_row, - 'animation_frame': animation_frame, - } - results: dict[str, str | None] = {'facet_col': None, 'facet_row': None, 'animation_frame': None} - - # First pass: resolve explicit dimensions (not 'auto' or None) to mark them as used - for slot_name, value in slots.items(): - if value is not None and value != 'auto': - if value in available and value not in used: - used.add(value) - results[slot_name] = value - - # Second pass: resolve 'auto' slots in dim_slot_priority order - dim_iter = iter(d for d in extra_dims if d not in used) - for slot_name in CONFIG.Plotting.dim_slot_priority: - if slots.get(slot_name) == 'auto': - next_dim = next(dim_iter, None) - if next_dim: - used.add(next_dim) - results[slot_name] = next_dim - - return results['facet_col'], results['facet_row'], results['animation_frame'] - - -def _resolve_facets( - ds: xr.Dataset, - facet_col: str | Literal['auto'] | None, - facet_row: str | Literal['auto'] | None, -) -> tuple[str | None, str | None]: - """Resolve facet dimensions, returning None if not present in data. - - Legacy wrapper for _resolve_auto_facets for backward compatibility. 
- """ - resolved_col, resolved_row, _ = _resolve_auto_facets(ds, facet_col, facet_row, None) - return resolved_col, resolved_row - - def _dataset_to_long_df(ds: xr.Dataset, value_name: str = 'value', var_name: str = 'variable') -> pd.DataFrame: """Convert xarray Dataset to long-form DataFrame for plotly express.""" if not ds.data_vars: @@ -1382,9 +1359,6 @@ def balance( ds[label] = -ds[label] ds = _apply_selection(ds, select) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame - ) # Build color map from Element.color attributes if no colors specified if colors is None: @@ -1399,9 +1373,9 @@ def balance( fig = ds.fxplot.stacked_bar( colors=colors, title=f'{node} [{unit_label}]' if unit_label else node, - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kwargs, ) @@ -1493,9 +1467,6 @@ def carrier_balance( ds[label] = -ds[label] ds = _apply_selection(ds, select) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame - ) # Use cached component colors for flows if colors is None: @@ -1523,9 +1494,9 @@ def carrier_balance( fig = ds.fxplot.stacked_bar( colors=colors, title=f'{carrier.capitalize()} Balance [{unit_label}]' if unit_label else f'{carrier.capitalize()} Balance', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kwargs, ) @@ -1544,7 +1515,7 @@ def heatmap( reshape: tuple[str, str] | Literal['auto'] | None = 'auto', colors: str | list[str] | None = None, facet_col: str | Literal['auto'] | None = 'auto', - animation_frame: str | Literal['auto'] | None = None, + animation_frame: str | Literal['auto'] | None = 'auto', show: bool | None = None, **plotly_kwargs: Any, ) -> PlotResult: 
@@ -1576,97 +1547,25 @@ def heatmap( PlotResult with processed data and figure. """ solution = self._stats._require_solution() - if isinstance(variables, str): variables = [variables] - # Resolve flow labels to variable names - resolved_variables = self._resolve_variable_names(variables, solution) - - ds = solution[resolved_variables] - ds = _apply_selection(ds, select) - - # Stack variables into single DataArray - variable_names = list(ds.data_vars) - dataarrays = [ds[var] for var in variable_names] - da = xr.concat(dataarrays, dim=pd.Index(variable_names, name='variable')) - - # Check if data is clustered (has cluster dimension with size > 1) - is_clustered = 'cluster' in da.dims and da.sizes['cluster'] > 1 - - # Determine facet and animation from available dims - has_multiple_vars = 'variable' in da.dims and da.sizes['variable'] > 1 - - if has_multiple_vars: - actual_facet = 'variable' - # Resolve animation using auto logic, excluding 'variable' which is used for facet - _, _, actual_animation = _resolve_auto_facets(da.to_dataset(name='value'), None, None, animation_frame) - if actual_animation == 'variable': - actual_animation = None - else: - # Resolve facet and animation using auto logic - actual_facet, _, actual_animation = _resolve_auto_facets( - da.to_dataset(name='value'), facet_col, None, animation_frame - ) - - # Determine heatmap dimensions based on data structure - if is_clustered and (reshape == 'auto' or reshape is None): - # Clustered data: use (time, cluster) as natural 2D heatmap axes - heatmap_dims = ['time', 'cluster'] - elif reshape and reshape != 'auto' and 'time' in da.dims: - # Non-clustered with explicit reshape: reshape time to (day, hour) etc. 
- # Extra dims will be handled via facet/animation or dropped - da = _reshape_time_for_heatmap(da, reshape) - heatmap_dims = ['timestep', 'timeframe'] - elif reshape == 'auto' and 'time' in da.dims and not is_clustered: - # Auto mode for non-clustered: use default ('D', 'h') reshape - # Extra dims will be handled via facet/animation or dropped - da = _reshape_time_for_heatmap(da, ('D', 'h')) - heatmap_dims = ['timestep', 'timeframe'] - elif has_multiple_vars: - # Can't reshape but have multiple vars: use variable + time as heatmap axes - heatmap_dims = ['variable', 'time'] - # variable is now a heatmap dim, use period/scenario for facet/animation - actual_facet, _, actual_animation = _resolve_auto_facets( - da.to_dataset(name='value'), facet_col, None, animation_frame - ) - else: - # Fallback: use first two available dimensions - available_dims = [d for d in da.dims if da.sizes[d] > 1] - if len(available_dims) >= 2: - heatmap_dims = available_dims[:2] - elif 'time' in da.dims: - heatmap_dims = ['time'] - else: - heatmap_dims = list(da.dims)[:1] - - # Keep only dims we need - keep_dims = set(heatmap_dims) | {d for d in [actual_facet, actual_animation] if d is not None} - for dim in [d for d in da.dims if d not in keep_dims]: - da = da.isel({dim: 0}, drop=True) if da.sizes[dim] > 1 else da.squeeze(dim, drop=True) + # Resolve, select, and stack into single DataArray + resolved = self._resolve_variable_names(variables, solution) + ds = _apply_selection(solution[resolved], select) + da = xr.concat([ds[v] for v in ds.data_vars], dim=pd.Index(list(ds.data_vars), name='variable')) - # Transpose to expected order - dim_order = heatmap_dims + [d for d in [actual_facet, actual_animation] if d] - da = da.transpose(*dim_order) + # Prepare for heatmap (reshape, transpose, squeeze) + da = _prepare_for_heatmap(da, reshape, facet_col, animation_frame) - # Clear name for multiple variables (colorbar would show first var's name) - if has_multiple_vars: - da = da.rename('') - - fig = 
da.fxplot.heatmap( - colors=colors, - facet_col=actual_facet, - animation_frame=actual_animation, - **plotly_kwargs, - ) + fig = da.fxplot.heatmap(colors=colors, facet_col=facet_col, animation_frame=animation_frame, **plotly_kwargs) if show is None: show = CONFIG.Plotting.default_show if show: fig.show() - reshaped_ds = da.to_dataset(name='value') if isinstance(da, xr.DataArray) else da - return PlotResult(data=reshaped_ds, figure=fig) + return PlotResult(data=da.to_dataset(name='value'), figure=fig) def flows( self, @@ -1737,9 +1636,6 @@ def flows( ds = ds[[lbl for lbl in matching_labels if lbl in ds]] ds = _apply_selection(ds, select) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame - ) # Get unit label from first data variable's attributes unit_label = '' @@ -1750,9 +1646,9 @@ def flows( fig = ds.fxplot.line( colors=colors, title=f'Flows [{unit_label}]' if unit_label else 'Flows', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kwargs, ) @@ -1798,27 +1694,18 @@ def sizes( valid_labels = [lbl for lbl in ds.data_vars if float(ds[lbl].max()) < max_size] ds = ds[valid_labels] - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame - ) - - df = _dataset_to_long_df(ds) - if df.empty: + if not ds.data_vars: fig = go.Figure() else: - variables = df['variable'].unique().tolist() - color_map = process_colors(colors, variables) - fig = px.bar( - df, + fig = ds.fxplot.bar( x='variable', - y='value', color='variable', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, - color_discrete_map=color_map, + colors=colors, title='Investment Sizes', - labels={'variable': 'Flow', 'value': 'Size'}, + ylabel='Size', + facet_col=facet_col, + facet_row=facet_row, + 
animation_frame=animation_frame, **plotly_kwargs, ) @@ -1913,10 +1800,6 @@ def sort_descending(arr: np.ndarray) -> np.ndarray: duration_coord = np.linspace(0, 100, n_timesteps) if normalize else np.arange(n_timesteps) result_ds = result_ds.assign_coords({duration_name: duration_coord}) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - result_ds, facet_col, facet_row, animation_frame - ) - # Get unit label from first data variable's attributes unit_label = '' if ds.data_vars: @@ -1926,9 +1809,9 @@ def sort_descending(arr: np.ndarray) -> np.ndarray: fig = result_ds.fxplot.line( colors=colors, title=f'Duration Curve [{unit_label}]' if unit_label else 'Duration Curve', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kwargs, ) @@ -2057,14 +1940,14 @@ def effects( else: raise ValueError(f"'by' must be one of 'component', 'contributor', 'time', or None, got {by!r}") - # Resolve facets - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - combined.to_dataset(name='value'), facet_col, facet_row, animation_frame - ) - # Convert to DataFrame for plotly express df = combined.to_dataframe(name='value').reset_index() + # Resolve facet/animation: 'auto' means None for DataFrames (no dimension priority) + resolved_facet_col = None if facet_col == 'auto' else facet_col + resolved_facet_row = None if facet_row == 'auto' else facet_row + resolved_animation = None if animation_frame == 'auto' else animation_frame + # Build color map if color_col and color_col in df.columns: color_items = df[color_col].unique().tolist() @@ -2087,9 +1970,9 @@ def effects( y='value', color=color_col, color_discrete_map=color_map, - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=resolved_facet_col, + facet_row=resolved_facet_row, + animation_frame=resolved_animation, 
title=title, **plotly_kwargs, ) @@ -2138,16 +2021,13 @@ def charge_states( ds = ds[[s for s in storages if s in ds]] ds = _apply_selection(ds, select) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame - ) fig = ds.fxplot.line( colors=colors, title='Storage Charge States', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kwargs, ) fig.update_yaxes(title_text='Charge State') @@ -2231,61 +2111,46 @@ def storage( # Apply selection ds = _apply_selection(ds, select) - actual_facet_col, actual_facet_row, actual_anim = _resolve_auto_facets( - ds, facet_col, facet_row, animation_frame - ) - # Build color map + # Separate flow data from charge_state flow_labels = [lbl for lbl in ds.data_vars if lbl != 'charge_state'] + flow_ds = ds[flow_labels] + charge_da = ds['charge_state'] + + # Build color map for flows if colors is None: colors = self._get_color_map_for_balance(storage, flow_labels) - color_map = process_colors(colors, flow_labels) - color_map['charge_state'] = 'black' - # Convert to long-form DataFrame - df = _dataset_to_long_df(ds) - - # Create figure with facets using px.bar for flows - flow_df = df[df['variable'] != 'charge_state'] - charge_df = df[df['variable'] == 'charge_state'] - - fig = px.bar( - flow_df, + # Create stacked bar chart for flows using fxplot + fig = flow_ds.fxplot.stacked_bar( x='time', - y='value', color='variable', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, - color_discrete_map=color_map, + colors=colors, title=f'{storage} Operation ({unit})', + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kwargs, ) - fig.update_layout(bargap=0, bargroupgap=0) - fig.update_traces(marker_line_width=0) # Add charge state as line on secondary y-axis - if not charge_df.empty: - 
# Create line figure with same facets to get matching trace structure - line_fig = px.line( - charge_df, + if charge_da.size > 0: + # Create line figure with same facets + line_fig = charge_da.fxplot.line( x='time', - y='value', - facet_col=actual_facet_col, - facet_row=actual_facet_row, - animation_frame=actual_anim, + color=None, # Single line, no color grouping + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, ) # Get the primary y-axes from the bar figure to create matching secondary axes - # px creates axes named: yaxis, yaxis2, yaxis3, etc. primary_yaxes = [key for key in fig.layout if key.startswith('yaxis')] # For each primary y-axis, create a secondary y-axis for i, primary_key in enumerate(sorted(primary_yaxes, key=lambda x: int(x[5:]) if x[5:] else 0)): - # Determine secondary axis name (y -> y2, y2 -> y3 pattern won't work) - # Instead use a consistent offset: yaxis -> yaxis10, yaxis2 -> yaxis11, etc. primary_num = primary_key[5:] if primary_key[5:] else '1' - secondary_num = int(primary_num) + 100 # Use high offset to avoid conflicts + secondary_num = int(primary_num) + 100 secondary_key = f'yaxis{secondary_num}' secondary_anchor = f'x{primary_num}' if primary_num != '1' else 'x' @@ -2299,14 +2164,13 @@ def storage( # Add line traces with correct axis assignments for i, trace in enumerate(line_fig.data): - # Map trace index to secondary y-axis primary_num = i + 1 if i > 0 else 1 secondary_yaxis = f'y{primary_num + 100}' trace.name = 'charge_state' trace.line = dict(color=charge_state_color, width=2) trace.yaxis = secondary_yaxis - trace.showlegend = i == 0 # Only show legend for first trace + trace.showlegend = i == 0 trace.legendgroup = 'charge_state' fig.add_trace(trace) diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index 3a13dbb63..6a5b51caa 100644 --- a/flixopt/transform_accessor.py +++ b/flixopt/transform_accessor.py @@ -582,6 +582,14 @@ def cluster( weights: dict[str, float] | None = 
None, time_series_for_high_peaks: list[str] | None = None, time_series_for_low_peaks: list[str] | None = None, + cluster_method: Literal['k_means', 'k_medoids', 'hierarchical', 'k_maxoids', 'averaging'] = 'hierarchical', + representation_method: Literal[ + 'meanRepresentation', 'medoidRepresentation', 'distributionAndMinMaxRepresentation' + ] = 'medoidRepresentation', + extreme_period_method: Literal['append', 'new_cluster_center', 'replace_cluster_center'] | None = None, + rescale_cluster_periods: bool = True, + predef_cluster_order: xr.DataArray | np.ndarray | list[int] | None = None, + **tsam_kwargs: Any, ) -> FlowSystem: """ Create a FlowSystem with reduced timesteps using typical clusters. @@ -591,7 +599,7 @@ def cluster( through time series aggregation using the tsam package. The method: - 1. Performs time series clustering using tsam (k-means) + 1. Performs time series clustering using tsam (hierarchical by default) 2. Extracts only the typical clusters (not all original timesteps) 3. Applies timestep weighting for accurate cost representation 4. Handles storage states between clusters based on each Storage's ``cluster_mode`` @@ -607,6 +615,25 @@ def cluster( time_series_for_high_peaks: Time series labels for explicitly selecting high-value clusters. **Recommended** for demand time series to capture peak demand days. time_series_for_low_peaks: Time series labels for explicitly selecting low-value clusters. + cluster_method: Clustering algorithm to use. Options: + ``'hierarchical'`` (default), ``'k_means'``, ``'k_medoids'``, + ``'k_maxoids'``, ``'averaging'``. + representation_method: How cluster representatives are computed. Options: + ``'medoidRepresentation'`` (default), ``'meanRepresentation'``, + ``'distributionAndMinMaxRepresentation'``. + extreme_period_method: How extreme periods (peaks) are integrated. Options: + ``None`` (default, no special handling), ``'append'``, + ``'new_cluster_center'``, ``'replace_cluster_center'``. 
+ rescale_cluster_periods: If True (default), rescale cluster periods so their + weighted mean matches the original time series mean. + predef_cluster_order: Predefined cluster assignments for manual clustering. + Array of cluster indices (0 to n_clusters-1) for each original period. + If provided, clustering is skipped and these assignments are used directly. + For multi-dimensional FlowSystems, use an xr.DataArray with dims + ``[original_cluster, period?, scenario?]`` to specify different assignments + per period/scenario combination. + **tsam_kwargs: Additional keyword arguments passed to + ``tsam.TimeSeriesAggregation``. See tsam documentation for all options. Returns: A new FlowSystem with reduced timesteps (only typical clusters). @@ -676,11 +703,47 @@ def cluster( ds = self._fs.to_dataset(include_solution=False) + # Validate tsam_kwargs doesn't override explicit parameters + reserved_tsam_keys = { + 'noTypicalPeriods', + 'hoursPerPeriod', + 'resolution', + 'clusterMethod', + 'extremePeriodMethod', + 'representationMethod', + 'rescaleClusterPeriods', + 'predefClusterOrder', + 'weightDict', + 'addPeakMax', + 'addPeakMin', + } + conflicts = reserved_tsam_keys & set(tsam_kwargs.keys()) + if conflicts: + raise ValueError( + f'Cannot override explicit parameters via tsam_kwargs: {conflicts}. ' + f'Use the corresponding cluster() parameters instead.' + ) + + # Validate predef_cluster_order dimensions if it's a DataArray + if isinstance(predef_cluster_order, xr.DataArray): + expected_dims = {'original_cluster'} + if has_periods: + expected_dims.add('period') + if has_scenarios: + expected_dims.add('scenario') + if set(predef_cluster_order.dims) != expected_dims: + raise ValueError( + f'predef_cluster_order dimensions {set(predef_cluster_order.dims)} ' + f'do not match expected {expected_dims} for this FlowSystem.' 
+ ) + # Cluster each (period, scenario) combination using tsam directly tsam_results: dict[tuple, tsam.TimeSeriesAggregation] = {} cluster_orders: dict[tuple, np.ndarray] = {} cluster_occurrences_all: dict[tuple, dict] = {} - use_extreme_periods = bool(time_series_for_high_peaks or time_series_for_low_peaks) + + # Collect metrics per (period, scenario) slice + clustering_metrics_all: dict[tuple, pd.DataFrame] = {} for period_label in periods: for scenario_label in scenarios: @@ -693,18 +756,34 @@ def cluster( if selector: logger.info(f'Clustering {", ".join(f"{k}={v}" for k, v in selector.items())}...') + # Handle predef_cluster_order for multi-dimensional case + predef_order_slice = None + if predef_cluster_order is not None: + if isinstance(predef_cluster_order, xr.DataArray): + # Extract slice for this (period, scenario) combination + predef_order_slice = predef_cluster_order.sel(**selector, drop=True).values + else: + # Simple array/list - use directly + predef_order_slice = predef_cluster_order + # Use tsam directly clustering_weights = weights or self._calculate_clustering_weights(temporaly_changing_ds) + # tsam expects 'None' as a string, not Python None + tsam_extreme_method = 'None' if extreme_period_method is None else extreme_period_method tsam_agg = tsam.TimeSeriesAggregation( df, noTypicalPeriods=n_clusters, hoursPerPeriod=hours_per_cluster, resolution=dt, - clusterMethod='k_means', - extremePeriodMethod='new_cluster_center' if use_extreme_periods else 'None', + clusterMethod=cluster_method, + extremePeriodMethod=tsam_extreme_method, + representationMethod=representation_method, + rescaleClusterPeriods=rescale_cluster_periods, + predefClusterOrder=predef_order_slice, weightDict={name: w for name, w in clustering_weights.items() if name in df.columns}, addPeakMax=time_series_for_high_peaks or [], addPeakMin=time_series_for_low_peaks or [], + **tsam_kwargs, ) # Suppress tsam warning about minimal value constraints (informational, not actionable) with 
warnings.catch_warnings(): @@ -714,10 +793,60 @@ def cluster( tsam_results[key] = tsam_agg cluster_orders[key] = tsam_agg.clusterOrder cluster_occurrences_all[key] = tsam_agg.clusterPeriodNoOccur + # Compute accuracy metrics with error handling + try: + clustering_metrics_all[key] = tsam_agg.accuracyIndicators() + except Exception as e: + logger.warning(f'Failed to compute clustering metrics for {key}: {e}') + clustering_metrics_all[key] = pd.DataFrame() # Use first result for structure first_key = (periods[0], scenarios[0]) first_tsam = tsam_results[first_key] + + # Convert metrics to xr.Dataset with period/scenario dims if multi-dimensional + # Filter out empty DataFrames (from failed accuracyIndicators calls) + non_empty_metrics = {k: v for k, v in clustering_metrics_all.items() if not v.empty} + if not non_empty_metrics: + # All metrics failed - create empty Dataset + clustering_metrics = xr.Dataset() + elif len(non_empty_metrics) == 1 or len(clustering_metrics_all) == 1: + # Simple case: convert single DataFrame to Dataset + metrics_df = non_empty_metrics.get(first_key) + if metrics_df is None: + metrics_df = next(iter(non_empty_metrics.values())) + clustering_metrics = xr.Dataset( + { + col: xr.DataArray( + metrics_df[col].values, dims=['time_series'], coords={'time_series': metrics_df.index} + ) + for col in metrics_df.columns + } + ) + else: + # Multi-dim case: combine metrics into Dataset with period/scenario dims + # First, get the metric columns from any non-empty DataFrame + sample_df = next(iter(non_empty_metrics.values())) + metric_names = list(sample_df.columns) + time_series_names = list(sample_df.index) + + # Build DataArrays for each metric + data_vars = {} + for metric in metric_names: + # Shape: (time_series, period?, scenario?) 
+ slices = {} + for (p, s), df in clustering_metrics_all.items(): + if df.empty: + # Use NaN for failed metrics + slices[(p, s)] = xr.DataArray(np.full(len(time_series_names), np.nan), dims=['time_series']) + else: + slices[(p, s)] = xr.DataArray(df[metric].values, dims=['time_series']) + + da = self._combine_slices_to_dataarray_generic(slices, ['time_series'], periods, scenarios, metric) + da = da.assign_coords(time_series=time_series_names) + data_vars[metric] = da + + clustering_metrics = xr.Dataset(data_vars) n_reduced_timesteps = len(first_tsam.typicalPeriods) actual_n_clusters = len(first_tsam.clusterPeriodNoOccur) @@ -851,7 +980,7 @@ def _build_cluster_occurrences_for_key(key: tuple) -> np.ndarray: # Build multi-dimensional arrays if has_periods or has_scenarios: # Multi-dimensional case: build arrays for each (period, scenario) combination - # cluster_order: dims [original_period, period?, scenario?] + # cluster_order: dims [original_cluster, period?, scenario?] cluster_order_slices = {} timestep_mapping_slices = {} cluster_occurrences_slices = {} @@ -863,7 +992,7 @@ def _build_cluster_occurrences_for_key(key: tuple) -> np.ndarray: for s in scenarios: key = (p, s) cluster_order_slices[key] = xr.DataArray( - cluster_orders[key], dims=['original_period'], name='cluster_order' + cluster_orders[key], dims=['original_cluster'], name='cluster_order' ) timestep_mapping_slices[key] = xr.DataArray( _build_timestep_mapping_for_key(key), @@ -877,7 +1006,7 @@ def _build_cluster_occurrences_for_key(key: tuple) -> np.ndarray: # Combine slices into multi-dimensional DataArrays cluster_order_da = self._combine_slices_to_dataarray_generic( - cluster_order_slices, ['original_period'], periods, scenarios, 'cluster_order' + cluster_order_slices, ['original_cluster'], periods, scenarios, 'cluster_order' ) timestep_mapping_da = self._combine_slices_to_dataarray_generic( timestep_mapping_slices, ['original_time'], periods, scenarios, 'timestep_mapping' @@ -887,7 +1016,7 @@ def 
_build_cluster_occurrences_for_key(key: tuple) -> np.ndarray: ) else: # Simple case: single (None, None) slice - cluster_order_da = xr.DataArray(cluster_orders[first_key], dims=['original_period'], name='cluster_order') + cluster_order_da = xr.DataArray(cluster_orders[first_key], dims=['original_cluster'], name='cluster_order') # Use renamed timesteps as coordinates original_timesteps_coord = self._fs.timesteps.rename('original_time') timestep_mapping_da = xr.DataArray( @@ -932,6 +1061,7 @@ def _build_cluster_weights_for_key(key: tuple) -> xr.DataArray: reduced_fs.clustering = Clustering( result=aggregation_result, backend_name='tsam', + metrics=clustering_metrics, ) return reduced_fs @@ -996,7 +1126,7 @@ def _combine_slices_to_dataarray_generic( Args: slices: Dict mapping (period, scenario) tuples to DataArrays. - base_dims: Base dimensions of each slice (e.g., ['original_period'] or ['original_time']). + base_dims: Base dimensions of each slice (e.g., ['original_cluster'] or ['original_time']). periods: List of period labels ([None] if no periods dimension). scenarios: List of scenario labels ([None] if no scenarios dimension). name: Name for the resulting DataArray. @@ -1085,7 +1215,7 @@ def expand_solution(self) -> FlowSystem: disaggregates the FlowSystem by: 1. Expanding all time series data from typical clusters to full timesteps 2. Expanding the solution by mapping each typical cluster back to all - original segments it represents + original clusters it represents For FlowSystems with periods and/or scenarios, each (period, scenario) combination is expanded using its own cluster assignment. @@ -1121,7 +1251,7 @@ def expand_solution(self) -> FlowSystem: Note: The expanded FlowSystem repeats the typical cluster values for all - segments belonging to the same cluster. Both input data and solution + original clusters belonging to the same cluster. Both input data and solution are consistently expanded, so they match. 
This is an approximation - the actual dispatch at full resolution would differ due to intra-cluster variations in time series data. @@ -1162,18 +1292,44 @@ def expand_solution(self) -> FlowSystem: scenarios = list(self._fs.scenarios) if has_scenarios else [None] n_original_timesteps = len(original_timesteps) n_reduced_timesteps = n_clusters * timesteps_per_cluster + n_original_clusters = cluster_structure.n_original_clusters # Expand function using ClusterResult.expand_data() - handles multi-dimensional cases - def expand_da(da: xr.DataArray) -> xr.DataArray: + # For charge_state with cluster dim, also includes the extra timestep + # Clamp to valid bounds to handle partial clusters at the end + last_original_cluster_idx = min( + (n_original_timesteps - 1) // timesteps_per_cluster, + n_original_clusters - 1, + ) + + def expand_da(da: xr.DataArray, var_name: str = '') -> xr.DataArray: if 'time' not in da.dims: return da.copy() - return info.result.expand_data(da, original_time=original_timesteps) + expanded = info.result.expand_data(da, original_time=original_timesteps) + + # For charge_state with cluster dim, append the extra timestep value + if var_name.endswith('|charge_state') and 'cluster' in da.dims: + # Get extra timestep from last cluster using vectorized selection + cluster_order = cluster_structure.cluster_order # (n_original_clusters,) or with period/scenario + if cluster_order.ndim == 1: + last_cluster = int(cluster_order[last_original_cluster_idx]) + extra_val = da.isel(cluster=last_cluster, time=-1) + else: + # Multi-dimensional: select last cluster for each period/scenario slice + last_clusters = cluster_order.isel(original_cluster=last_original_cluster_idx) + extra_val = da.isel(cluster=last_clusters, time=-1) + # Drop 'cluster'/'time' coords created by isel (kept as non-dim coords) + extra_val = extra_val.drop_vars(['cluster', 'time'], errors='ignore') + extra_val = extra_val.expand_dims(time=[original_timesteps_extra[-1]]) + expanded = 
xr.concat([expanded, extra_val], dim='time') + + return expanded # 1. Expand FlowSystem data (with cluster_weight set to 1.0 for all timesteps) reduced_ds = self._fs.to_dataset(include_solution=False) # Filter out cluster-related variables and copy attrs without clustering info data_vars = { - name: expand_da(da) + name: expand_da(da, name) for name, da in reduced_ds.data_vars.items() if name != 'cluster_weight' and not name.startswith('clustering|') } @@ -1201,17 +1357,22 @@ def expand_da(da: xr.DataArray) -> xr.DataArray: expanded_fs = FlowSystem.from_dataset(expanded_ds) # 2. Expand solution + # charge_state variables get their extra timestep via expand_da; others get NaN via reindex reduced_solution = self._fs.solution expanded_fs._solution = xr.Dataset( - {name: expand_da(da) for name, da in reduced_solution.data_vars.items()}, + {name: expand_da(da, name) for name, da in reduced_solution.data_vars.items()}, attrs=reduced_solution.attrs, ) + # Reindex to timesteps_extra for consistency with non-expanded FlowSystems + # (variables without extra timestep data will have NaN at the final timestep) + expanded_fs._solution = expanded_fs._solution.reindex(time=original_timesteps_extra) # 3. Combine charge_state with SOC_boundary for InterclusterStorageModel storages # For intercluster storages, charge_state is relative (ΔE) and can be negative. # Per Blanke et al. (2022) Eq. 9, actual SOC at time t in period d is: # SOC(t) = SOC_boundary[d] * (1 - loss)^t_within_period + charge_state(t) # where t_within_period is hours from period start (accounts for self-discharge decay). 
+ n_original_timesteps_extra = len(original_timesteps_extra) soc_boundary_vars = [name for name in reduced_solution.data_vars if name.endswith('|SOC_boundary')] for soc_boundary_name in soc_boundary_vars: storage_name = soc_boundary_name.rsplit('|', 1)[0] @@ -1222,30 +1383,42 @@ def expand_da(da: xr.DataArray) -> xr.DataArray: soc_boundary = reduced_solution[soc_boundary_name] expanded_charge_state = expanded_fs._solution[charge_state_name] - # Map each original timestep to its original period index - original_period_indices = np.arange(n_original_timesteps) // timesteps_per_cluster + # Map each original timestep (including extra) to its original period index + # The extra timestep belongs to the last period + original_cluster_indices = np.minimum( + np.arange(n_original_timesteps_extra) // timesteps_per_cluster, + n_original_clusters - 1, + ) # Select SOC_boundary for each timestep (boundary[d] for period d) - # SOC_boundary has dim 'cluster_boundary', we select indices 0..n_original_periods-1 + # SOC_boundary has dim 'cluster_boundary', we select indices 0..n_original_clusters-1 soc_boundary_per_timestep = soc_boundary.isel( - cluster_boundary=xr.DataArray(original_period_indices, dims=['time']) + cluster_boundary=xr.DataArray(original_cluster_indices, dims=['time']) ) - soc_boundary_per_timestep = soc_boundary_per_timestep.assign_coords(time=original_timesteps) + soc_boundary_per_timestep = soc_boundary_per_timestep.assign_coords(time=original_timesteps_extra) # Apply self-discharge decay to SOC_boundary based on time within period # Get the storage's relative_loss_per_hour from the clustered flow system storage = self._fs.storages.get(storage_name) if storage is not None: # Time within period for each timestep (0, 1, 2, ..., timesteps_per_cluster-1, 0, 1, ...) 
- time_within_period = np.arange(n_original_timesteps) % timesteps_per_cluster + # The extra timestep is at index timesteps_per_cluster (one past the last within-cluster index) + time_within_period = np.arange(n_original_timesteps_extra) % timesteps_per_cluster + # The extra timestep gets the correct decay (timesteps_per_cluster) + time_within_period[-1] = timesteps_per_cluster time_within_period_da = xr.DataArray( - time_within_period, dims=['time'], coords={'time': original_timesteps} + time_within_period, dims=['time'], coords={'time': original_timesteps_extra} ) # Decay factor: (1 - loss)^t, using mean loss over time - # Keep as DataArray to respect per-period/scenario values loss_value = storage.relative_loss_per_hour.mean('time') if (loss_value > 0).any(): decay_da = (1 - loss_value) ** time_within_period_da + if 'cluster' in decay_da.dims: + # Map each timestep to its cluster's decay value + cluster_per_timestep = cluster_structure.cluster_order.values[original_cluster_indices] + decay_da = decay_da.isel(cluster=xr.DataArray(cluster_per_timestep, dims=['time'])).drop_vars( + 'cluster', errors='ignore' + ) soc_boundary_per_timestep = soc_boundary_per_timestep * decay_da # Combine: actual_SOC = SOC_boundary * decay + charge_state @@ -1254,15 +1427,22 @@ def expand_da(da: xr.DataArray) -> xr.DataArray: combined_charge_state = (expanded_charge_state + soc_boundary_per_timestep).clip(min=0) expanded_fs._solution[charge_state_name] = combined_charge_state.assign_attrs(expanded_charge_state.attrs) + # Remove SOC_boundary variables - they're cluster-specific and now incorporated into charge_state + for soc_boundary_name in soc_boundary_vars: + if soc_boundary_name in expanded_fs._solution: + del expanded_fs._solution[soc_boundary_name] + # Also drop the cluster_boundary coordinate (orphaned after removing SOC_boundary) + if 'cluster_boundary' in expanded_fs._solution.coords: + expanded_fs._solution = expanded_fs._solution.drop_vars('cluster_boundary') + 
n_combinations = len(periods) * len(scenarios) - n_original_segments = cluster_structure.n_original_periods logger.info( f'Expanded FlowSystem from {n_reduced_timesteps} to {n_original_timesteps} timesteps ' f'({n_clusters} clusters' + ( f', {n_combinations} period/scenario combinations)' if n_combinations > 1 - else f' → {n_original_segments} original segments)' + else f' → {n_original_clusters} original clusters)' ) ) diff --git a/mkdocs.yml b/mkdocs.yml index ab2e9309f..9eed96ad6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -69,8 +69,13 @@ nav: - Piecewise Effects: notebooks/06c-piecewise-effects.ipynb - Scaling: - Scenarios: notebooks/07-scenarios-and-periods.ipynb - - Clustering: notebooks/08a-aggregation.ipynb + - Aggregation: notebooks/08a-aggregation.ipynb - Rolling Horizon: notebooks/08b-rolling-horizon.ipynb + - Clustering: + - Introduction: notebooks/08c-clustering.ipynb + - Storage Modes: notebooks/08c2-clustering-storage-modes.ipynb + - Multi-Period: notebooks/08d-clustering-multiperiod.ipynb + - Internals: notebooks/08e-clustering-internals.ipynb - Results: - Plotting: notebooks/09-plotting-and-data-access.ipynb - Custom Data Plotting: notebooks/fxplot_accessor_demo.ipynb @@ -230,7 +235,7 @@ plugins: separator: '[\s\u200b\-_,:!=\[\]()"`/]+|\.(?!\d)|&[lg]t;|(?!\b)(?=[A-Z][a-z])' - mkdocs-jupyter: - execute: true # Execute notebooks during build + execute: !ENV [MKDOCS_JUPYTER_EXECUTE, true] # CI pre-executes in parallel allow_errors: false include_source: true include_requirejs: true diff --git a/pyproject.toml b/pyproject.toml index 561f00f57..f80f83557 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "xarray >= 2024.2.0, < 2026.0", # CalVer: allow through next calendar year # Optimization and data handling "linopy >= 0.5.1, < 0.6", # Widened from patch pin to minor range - "netcdf4 >= 1.6.1, < 2", + "netcdf4 >= 1.6.1, < 1.7.4", # 1.7.4 missing wheels, revert to < 2 later # Utilities "pyyaml >= 6.0.0, < 7", "colorlog 
>= 6.8.0, < 7", @@ -79,7 +79,7 @@ dev = [ "pytest==8.4.2", "pytest-xdist==3.8.0", "nbformat==5.10.4", - "ruff==0.13.3", + "ruff==0.14.10", "pre-commit==4.3.0", "pyvis==0.3.2", "tsam==2.3.9", @@ -89,14 +89,14 @@ dev = [ "dash-cytoscape==1.0.2", "dash-daq==0.6.0", "networkx==3.0.0", - "werkzeug==3.0.0", + "werkzeug==3.1.4", ] # Documentation building docs = [ "mkdocs==1.6.1", - "mkdocs-material==9.6.23", - "mkdocstrings-python==1.18.2", + "mkdocs-material==9.7.1", + "mkdocstrings-python==1.19.0", "mkdocs-table-reader-plugin==3.1.0", "mkdocs-gen-files==0.5.0", "mkdocs-include-markdown-plugin==7.2.0", @@ -104,7 +104,7 @@ docs = [ "mkdocs-plotly-plugin==0.1.3", "mkdocs-jupyter==0.25.1", "markdown-include==0.8.1", - "pymdown-extensions==10.16.1", + "pymdown-extensions==10.19.1", "pygments==2.19.2", "mike==2.1.3", "mkdocs-git-revision-date-localized-plugin==1.5.0", diff --git a/tests/test_cluster_reduce_expand.py b/tests/test_cluster_reduce_expand.py index 7072fe22e..4059470ee 100644 --- a/tests/test_cluster_reduce_expand.py +++ b/tests/test_cluster_reduce_expand.py @@ -167,7 +167,7 @@ def test_expand_solution_enables_statistics_accessor(solver_fixture, timesteps_8 # These should work without errors flow_rates = fs_expanded.statistics.flow_rates assert 'Boiler(Q_th)' in flow_rates - assert len(flow_rates['Boiler(Q_th)'].coords['time']) == 192 + assert len(flow_rates['Boiler(Q_th)'].coords['time']) == 193 # 192 + 1 extra timestep flow_hours = fs_expanded.statistics.flow_hours assert 'Boiler(Q_th)' in flow_hours @@ -321,7 +321,7 @@ def test_cluster_and_expand_with_scenarios(solver_fixture, timesteps_8_days, sce flow_var = 'Boiler(Q_th)|flow_rate' assert flow_var in fs_expanded.solution assert 'scenario' in fs_expanded.solution[flow_var].dims - assert len(fs_expanded.solution[flow_var].coords['time']) == 192 + assert len(fs_expanded.solution[flow_var].coords['time']) == 193 # 192 + 1 extra timestep def test_expand_solution_maps_scenarios_independently(solver_fixture, 
timesteps_8_days, scenarios_2): @@ -449,9 +449,9 @@ def test_storage_cluster_mode_intercluster(self, solver_fixture, timesteps_8_day soc_boundary = fs_clustered.solution['Battery|SOC_boundary'] assert 'cluster_boundary' in soc_boundary.dims - # Number of boundaries = n_original_periods + 1 - n_original_periods = fs_clustered.clustering.result.cluster_structure.n_original_periods - assert soc_boundary.sizes['cluster_boundary'] == n_original_periods + 1 + # Number of boundaries = n_original_clusters + 1 + n_original_clusters = fs_clustered.clustering.result.cluster_structure.n_original_clusters + assert soc_boundary.sizes['cluster_boundary'] == n_original_clusters + 1 def test_storage_cluster_mode_intercluster_cyclic(self, solver_fixture, timesteps_8_days): """Storage with cluster_mode='intercluster_cyclic' - linked with yearly cycling.""" @@ -693,7 +693,7 @@ def test_expand_solution_with_periods(self, solver_fixture, timesteps_8_days, pe # Solution should have period dimension flow_var = 'Boiler(Q_th)|flow_rate' assert 'period' in fs_expanded.solution[flow_var].dims - assert len(fs_expanded.solution[flow_var].coords['time']) == 192 + assert len(fs_expanded.solution[flow_var].coords['time']) == 193 # 192 + 1 extra timestep def test_cluster_with_periods_and_scenarios(self, solver_fixture, timesteps_8_days, periods_2, scenarios_2): """Clustering should work with both periods and scenarios.""" @@ -719,7 +719,7 @@ def test_cluster_with_periods_and_scenarios(self, solver_fixture, timesteps_8_da fs_expanded = fs_clustered.transform.expand_solution() assert 'period' in fs_expanded.solution[flow_var].dims assert 'scenario' in fs_expanded.solution[flow_var].dims - assert len(fs_expanded.solution[flow_var].coords['time']) == 192 + assert len(fs_expanded.solution[flow_var].coords['time']) == 193 # 192 + 1 extra timestep # ==================== Peak Selection Tests ==================== diff --git a/tests/test_clustering/test_base.py b/tests/test_clustering/test_base.py index 
9c63f25f6..9cca4de81 100644 --- a/tests/test_clustering/test_base.py +++ b/tests/test_clustering/test_base.py @@ -17,7 +17,7 @@ class TestClusterStructure: def test_basic_creation(self): """Test basic ClusterStructure creation.""" - cluster_order = xr.DataArray([0, 1, 0, 1, 2, 0], dims=['original_period']) + cluster_order = xr.DataArray([0, 1, 0, 1, 2, 0], dims=['original_cluster']) cluster_occurrences = xr.DataArray([3, 2, 1], dims=['cluster']) structure = ClusterStructure( @@ -29,7 +29,7 @@ def test_basic_creation(self): assert structure.n_clusters == 3 assert structure.timesteps_per_cluster == 24 - assert structure.n_original_periods == 6 + assert structure.n_original_clusters == 6 def test_creation_from_numpy(self): """Test ClusterStructure creation from numpy arrays.""" @@ -42,12 +42,12 @@ def test_creation_from_numpy(self): assert isinstance(structure.cluster_order, xr.DataArray) assert isinstance(structure.cluster_occurrences, xr.DataArray) - assert structure.n_original_periods == 5 + assert structure.n_original_clusters == 5 def test_get_cluster_weight_per_timestep(self): """Test weight calculation per timestep.""" structure = ClusterStructure( - cluster_order=xr.DataArray([0, 1, 0], dims=['original_period']), + cluster_order=xr.DataArray([0, 1, 0], dims=['original_cluster']), cluster_occurrences=xr.DataArray([2, 1], dims=['cluster']), n_clusters=2, timesteps_per_cluster=4, @@ -136,7 +136,7 @@ def test_basic_creation(self): structure = create_cluster_structure_from_mapping(mapping, timesteps_per_cluster=4) assert structure.timesteps_per_cluster == 4 - assert structure.n_original_periods == 3 + assert structure.n_original_clusters == 3 class TestClustering: diff --git a/tests/test_clustering/test_integration.py b/tests/test_clustering/test_integration.py index 2d04a51c1..16c638c95 100644 --- a/tests/test_clustering/test_integration.py +++ b/tests/test_clustering/test_integration.py @@ -170,6 +170,104 @@ def test_cluster_reduces_timesteps(self): assert 
len(fs_clustered.timesteps) * len(fs_clustered.clusters) == 48 +class TestClusterAdvancedOptions: + """Tests for advanced clustering options.""" + + @pytest.fixture + def basic_flow_system(self): + """Create a basic FlowSystem for testing.""" + pytest.importorskip('tsam') + from flixopt import Bus, Flow, Sink, Source + from flixopt.core import TimeSeriesData + + n_hours = 168 # 7 days + fs = FlowSystem(timesteps=pd.date_range('2024-01-01', periods=n_hours, freq='h')) + + demand_data = np.sin(np.linspace(0, 14 * np.pi, n_hours)) + 2 + bus = Bus('electricity') + grid_flow = Flow('grid_in', bus='electricity', size=100) + demand_flow = Flow( + 'demand_out', bus='electricity', size=100, fixed_relative_profile=TimeSeriesData(demand_data / 100) + ) + source = Source('grid', outputs=[grid_flow]) + sink = Sink('demand', inputs=[demand_flow]) + fs.add_elements(source, sink, bus) + return fs + + def test_cluster_method_parameter(self, basic_flow_system): + """Test that cluster_method parameter works.""" + fs_clustered = basic_flow_system.transform.cluster( + n_clusters=2, cluster_duration='1D', cluster_method='hierarchical' + ) + assert len(fs_clustered.clusters) == 2 + + def test_hierarchical_is_deterministic(self, basic_flow_system): + """Test that hierarchical clustering (default) produces deterministic results.""" + fs1 = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D') + fs2 = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D') + + # Hierarchical clustering should produce identical cluster orders + xr.testing.assert_equal(fs1.clustering.cluster_order, fs2.clustering.cluster_order) + + def test_metrics_available(self, basic_flow_system): + """Test that clustering metrics are available after clustering.""" + fs_clustered = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D') + + assert fs_clustered.clustering.metrics is not None + assert isinstance(fs_clustered.clustering.metrics, xr.Dataset) + assert 
'time_series' in fs_clustered.clustering.metrics.dims + assert len(fs_clustered.clustering.metrics.data_vars) > 0 + + def test_representation_method_parameter(self, basic_flow_system): + """Test that representation_method parameter works.""" + fs_clustered = basic_flow_system.transform.cluster( + n_clusters=2, cluster_duration='1D', representation_method='medoidRepresentation' + ) + assert len(fs_clustered.clusters) == 2 + + def test_rescale_cluster_periods_parameter(self, basic_flow_system): + """Test that rescale_cluster_periods parameter works.""" + fs_clustered = basic_flow_system.transform.cluster( + n_clusters=2, cluster_duration='1D', rescale_cluster_periods=False + ) + assert len(fs_clustered.clusters) == 2 + + def test_tsam_kwargs_passthrough(self, basic_flow_system): + """Test that additional kwargs are passed to tsam.""" + # sameMean is a valid tsam parameter + fs_clustered = basic_flow_system.transform.cluster(n_clusters=2, cluster_duration='1D', sameMean=True) + assert len(fs_clustered.clusters) == 2 + + def test_metrics_with_periods(self): + """Test that metrics have period dimension for multi-period FlowSystems.""" + pytest.importorskip('tsam') + from flixopt import Bus, Flow, Sink, Source + from flixopt.core import TimeSeriesData + + n_hours = 168 # 7 days + fs = FlowSystem( + timesteps=pd.date_range('2024-01-01', periods=n_hours, freq='h'), + periods=pd.Index([2025, 2030], name='period'), + ) + + demand_data = np.sin(np.linspace(0, 14 * np.pi, n_hours)) + 2 + bus = Bus('electricity') + grid_flow = Flow('grid_in', bus='electricity', size=100) + demand_flow = Flow( + 'demand_out', bus='electricity', size=100, fixed_relative_profile=TimeSeriesData(demand_data / 100) + ) + source = Source('grid', outputs=[grid_flow]) + sink = Sink('demand', inputs=[demand_flow]) + fs.add_elements(source, sink, bus) + + fs_clustered = fs.transform.cluster(n_clusters=2, cluster_duration='1D') + + # Metrics should have period dimension + assert 
fs_clustered.clustering.metrics is not None + assert 'period' in fs_clustered.clustering.metrics.dims + assert len(fs_clustered.clustering.metrics.period) == 2 + + class TestClusteringModuleImports: """Tests for flixopt.clustering module imports.""" From e613755fc201e0af114f17e9ecfaa1b385f8affd Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 13:54:03 +0100 Subject: [PATCH 59/62] comparison.py: 1. Removed _resolve_facets method - fxplot handles 'auto' resolution internally 2. Updated all methods to pass facet params directly to fxplot 3. sizes now uses ds.fxplot.bar() instead of px.bar 4. effects now uses ds.fxplot.bar() with proper column naming statistics_accessor.py: 1. Simplified effects method significantly: - Works directly with Dataset (no DataArray concat/conversion) - Uses dict.get for aspect selection - Cleaner aggregation logic - Returns Dataset with effects as data variables - Uses fxplot.bar instead of px.bar The code is now consistent - all plotting methods in both StatisticsPlotAccessor and ComparisonStatisticsPlot use fxplot for centralized dimension/slot handling. 
--- flixopt/comparison.py | 13 ++-- flixopt/statistics_accessor.py | 127 +++++++++++---------------------- 2 files changed, 45 insertions(+), 95 deletions(-) diff --git a/flixopt/comparison.py b/flixopt/comparison.py index d7db72975..5219642d2 100644 --- a/flixopt/comparison.py +++ b/flixopt/comparison.py @@ -577,17 +577,14 @@ def effects( if not ds.data_vars: return self._finalize(ds, None, show) - # Get the data array - da = ds[aspect] if aspect in ds else ds[next(iter(ds.data_vars))] - by = data_kw.get('by') - x_col = by if by else 'effect' + # After to_dataset(dim='effect'), effects become variables -> 'variable' column + x_col = by if by else 'variable' + color_col = 'variable' if len(ds.data_vars) > 1 else x_col - # Convert to Dataset along 'effect' dimension (each effect becomes a variable) - plot_ds = da.to_dataset(dim='effect') if 'effect' in da.dims else da.to_dataset(name=aspect) - fig = plot_ds.fxplot.bar( + fig = ds.fxplot.bar( x=x_col, - color=x_col, + color=color_col, colors=colors, title=title, facet_col=facet_col, diff --git a/flixopt/statistics_accessor.py b/flixopt/statistics_accessor.py index e3581e4e3..536c6beaf 100644 --- a/flixopt/statistics_accessor.py +++ b/flixopt/statistics_accessor.py @@ -25,7 +25,6 @@ import numpy as np import pandas as pd -import plotly.express as px import plotly.graph_objects as go import xarray as xr @@ -1867,113 +1866,67 @@ def effects( self._stats._require_solution() # Get the appropriate effects dataset based on aspect - if aspect == 'total': - effects_ds = self._stats.total_effects - elif aspect == 'temporal': - effects_ds = self._stats.temporal_effects - elif aspect == 'periodic': - effects_ds = self._stats.periodic_effects - else: + effects_ds = { + 'total': self._stats.total_effects, + 'temporal': self._stats.temporal_effects, + 'periodic': self._stats.periodic_effects, + }.get(aspect) + if effects_ds is None: raise ValueError(f"Aspect '{aspect}' not valid. 
Choose from 'total', 'temporal', 'periodic'.") - # Get available effects (data variables in the dataset) - available_effects = list(effects_ds.data_vars) - - # Filter to specific effect if requested + # Filter to specific effect(s) and apply selection if effect is not None: - if effect not in available_effects: - raise ValueError(f"Effect '{effect}' not found. Available: {available_effects}") - effects_to_plot = [effect] + if effect not in effects_ds: + raise ValueError(f"Effect '{effect}' not found. Available: {list(effects_ds.data_vars)}") + ds = effects_ds[[effect]] else: - effects_to_plot = available_effects - - # Build a combined DataArray with effect dimension - effect_arrays = [] - for eff in effects_to_plot: - da = effects_ds[eff] - if by == 'contributor': - # Keep individual contributors (flows) - no groupby - effect_arrays.append(da.expand_dims(effect=[eff])) - else: - # Group by component (sum over contributor within each component) - da_grouped = da.groupby('component').sum() - effect_arrays.append(da_grouped.expand_dims(effect=[eff])) + ds = effects_ds - combined = xr.concat(effect_arrays, dim='effect') + # Group by component (default) unless by='contributor' + if by != 'contributor' and 'contributor' in ds.dims: + ds = ds.groupby('component').sum() - # Apply selection - combined = _apply_selection(combined.to_dataset(name='value'), select)['value'] + ds = _apply_selection(ds, select) - # Group by the specified dimension + # Sum over dimensions based on 'by' parameter if by is None: - # Aggregate totals per effect - sum over all dimensions except effect - if 'time' in combined.dims: - combined = combined.sum(dim='time') - if 'component' in combined.dims: - combined = combined.sum(dim='component') - if 'contributor' in combined.dims: - combined = combined.sum(dim='contributor') - x_col = 'effect' - color_col = 'effect' + for dim in ['time', 'component', 'contributor']: + if dim in ds.dims: + ds = ds.sum(dim=dim) + x_col, color_col = 'variable', 
'variable' elif by == 'component': - # Sum over time if present - if 'time' in combined.dims: - combined = combined.sum(dim='time') + if 'time' in ds.dims: + ds = ds.sum(dim='time') x_col = 'component' - color_col = 'effect' if len(effects_to_plot) > 1 else 'component' + color_col = 'variable' if len(ds.data_vars) > 1 else 'component' elif by == 'contributor': - # Sum over time if present - if 'time' in combined.dims: - combined = combined.sum(dim='time') + if 'time' in ds.dims: + ds = ds.sum(dim='time') x_col = 'contributor' - color_col = 'effect' if len(effects_to_plot) > 1 else 'contributor' + color_col = 'variable' if len(ds.data_vars) > 1 else 'contributor' elif by == 'time': - if 'time' not in combined.dims: + if 'time' not in ds.dims: raise ValueError(f"Cannot plot by 'time' for aspect '{aspect}' - no time dimension.") - # Sum over components or contributors - if 'component' in combined.dims: - combined = combined.sum(dim='component') - if 'contributor' in combined.dims: - combined = combined.sum(dim='contributor') + for dim in ['component', 'contributor']: + if dim in ds.dims: + ds = ds.sum(dim=dim) x_col = 'time' - color_col = 'effect' if len(effects_to_plot) > 1 else None + color_col = 'variable' if len(ds.data_vars) > 1 else None else: raise ValueError(f"'by' must be one of 'component', 'contributor', 'time', or None, got {by!r}") - # Convert to DataFrame for plotly express - df = combined.to_dataframe(name='value').reset_index() - - # Resolve facet/animation: 'auto' means None for DataFrames (no dimension priority) - resolved_facet_col = None if facet_col == 'auto' else facet_col - resolved_facet_row = None if facet_row == 'auto' else facet_row - resolved_animation = None if animation_frame == 'auto' else animation_frame - - # Build color map - if color_col and color_col in df.columns: - color_items = df[color_col].unique().tolist() - color_map = process_colors(colors, color_items) - else: - color_map = None + # Build title + effect_label = effect or 
'Effects' + title = f'{effect_label} ({aspect})' if by is None else f'{effect_label} ({aspect}) by {by}' - # Build title with unit if single effect - effect_label = effect if effect else 'Effects' - if effect and effect in effects_ds: - unit_label = effects_ds[effect].attrs.get('unit', '') - title = f'{effect_label} [{unit_label}]' if unit_label else effect_label - else: - title = effect_label - title = f'{title} ({aspect})' if by is None else f'{title} ({aspect}) by {by}' - - fig = px.bar( - df, + fig = ds.fxplot.bar( x=x_col, - y='value', color=color_col, - color_discrete_map=color_map, - facet_col=resolved_facet_col, - facet_row=resolved_facet_row, - animation_frame=resolved_animation, + colors=colors, title=title, + facet_col=facet_col, + facet_row=facet_row, + animation_frame=animation_frame, **plotly_kwargs, ) fig.update_layout(bargap=0, bargroupgap=0) @@ -1984,7 +1937,7 @@ def effects( if show: fig.show() - return PlotResult(data=combined.to_dataset(name=aspect), figure=fig) + return PlotResult(data=ds, figure=fig) def charge_states( self, From c6da15f800eb1f2c459b593c1b73bfa10452595b Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 13:57:51 +0100 Subject: [PATCH 60/62] Squeeze singleton dims in heatmap() --- flixopt/clustering/base.py | 5 +++++ flixopt/dataset_plot_accessor.py | 26 ++++++++++++++++++++++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/flixopt/clustering/base.py b/flixopt/clustering/base.py index ab9590aae..0f154484b 100644 --- a/flixopt/clustering/base.py +++ b/flixopt/clustering/base.py @@ -827,6 +827,11 @@ def heatmap( heatmap_da = heatmap_da.assign_coords(y=['Cluster']) heatmap_da.name = 'cluster_assignment' + # Reorder dims so 'time' and 'y' are first (heatmap x/y axes) + # Other dims (period, scenario) will be used for faceting/animation + target_order = ['time', 'y'] + [d for d in heatmap_da.dims if d not in ('time', 'y')] + heatmap_da = 
heatmap_da.transpose(*target_order) + # Use fxplot.heatmap for smart defaults fig = heatmap_da.fxplot.heatmap( colors=colors, diff --git a/flixopt/dataset_plot_accessor.py b/flixopt/dataset_plot_accessor.py index 47cb0564a..6c833e652 100644 --- a/flixopt/dataset_plot_accessor.py +++ b/flixopt/dataset_plot_accessor.py @@ -525,7 +525,6 @@ def heatmap( resolved_animation = animation_frame imshow_args: dict[str, Any] = { - 'img': da, 'color_continuous_scale': colors, 'title': title or variable, } @@ -538,6 +537,18 @@ def heatmap( if resolved_animation and resolved_animation in da.dims: imshow_args['animation_frame'] = resolved_animation + # Squeeze singleton dimensions not used for faceting/animation + # px.imshow can't handle extra singleton dims in multi-dimensional data + dims_to_preserve = set(list(da.dims)[:2]) # First 2 dims are heatmap x/y axes + if resolved_facet and resolved_facet in da.dims: + dims_to_preserve.add(resolved_facet) + if resolved_animation and resolved_animation in da.dims: + dims_to_preserve.add(resolved_animation) + for dim in list(da.dims): + if dim not in dims_to_preserve and da.sizes[dim] == 1: + da = da.squeeze(dim) + imshow_args['img'] = da + # Use binary_string=False to handle non-numeric coords (e.g., string labels) if 'binary_string' not in imshow_kwargs: imshow_args['binary_string'] = False @@ -955,7 +966,6 @@ def heatmap( resolved_animation = animation_frame imshow_args: dict[str, Any] = { - 'img': da, 'color_continuous_scale': colors, 'title': title or (da.name if da.name else ''), } @@ -968,6 +978,18 @@ def heatmap( if resolved_animation and resolved_animation in da.dims: imshow_args['animation_frame'] = resolved_animation + # Squeeze singleton dimensions not used for faceting/animation + # px.imshow can't handle extra singleton dims in multi-dimensional data + dims_to_preserve = set(list(da.dims)[:2]) # First 2 dims are heatmap x/y axes + if resolved_facet and resolved_facet in da.dims: + dims_to_preserve.add(resolved_facet) + if 
resolved_animation and resolved_animation in da.dims: + dims_to_preserve.add(resolved_animation) + for dim in list(da.dims): + if dim not in dims_to_preserve and da.sizes[dim] == 1: + da = da.squeeze(dim) + imshow_args['img'] = da + # Use binary_string=False to handle non-numeric coords (e.g., string labels) if 'binary_string' not in imshow_kwargs: imshow_args['binary_string'] = False From e4c6510621bdb9d8ea4c37066f44e9499326f2d8 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 17:21:24 +0100 Subject: [PATCH 61/62] Replaced print statements with class repr --- docs/notebooks/08c-clustering.ipynb | 161 ++++++++++++++++++++++++---- 1 file changed, 139 insertions(+), 22 deletions(-) diff --git a/docs/notebooks/08c-clustering.ipynb b/docs/notebooks/08c-clustering.ipynb index 130d49a7f..6d85e60ba 100644 --- a/docs/notebooks/08c-clustering.ipynb +++ b/docs/notebooks/08c-clustering.ipynb @@ -143,6 +143,7 @@ " cluster_duration='1D', # Daily clustering\n", " time_series_for_high_peaks=peak_series, # Capture peak demand day\n", ")\n", + "fs_clustered.name = 'Clustered (8 days)'\n", "\n", "time_clustering = timeit.default_timer() - start" ] @@ -194,9 +195,127 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "id": "13", "metadata": {}, + "outputs": [], + "source": [ + "# Quality metrics - how well do the clusters represent the original data?\n", + "# Lower RMSE/MAE = better representation\n", + "clustering.metrics.to_dataframe().style.format('{:.3f}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "# Visual comparison: original vs clustered time series\n", + "clustering.plot.compare()" + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "## Advanced Clustering Options\n", + "\n", + "The `cluster()` method exposes many parameters for fine-tuning:" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "# Try different clustering algorithms\n", + "fs_kmeans = flow_system.transform.cluster(\n", + " n_clusters=8,\n", + " cluster_duration='1D',\n", + " cluster_method='k_means', # Alternative: 'hierarchical' (default), 'k_medoids', 'averaging'\n", + ")\n", + "\n", + "fs_kmeans.clustering" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "# Compare quality metrics between algorithms\n", + "pd.DataFrame(\n", + " {\n", + " 'hierarchical': fs_clustered.clustering.metrics.to_dataframe().iloc[0],\n", + " 'k_means': fs_kmeans.clustering.metrics.to_dataframe().iloc[0],\n", + " }\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize cluster structure with heatmap\n", + "clustering.plot.heatmap()" + ] + }, + { + "cell_type": "markdown", + "id": "19", + "metadata": {}, + "source": [ + "### Manual Cluster Assignment\n", + "\n", + "When comparing design variants or performing sensitivity analysis, you often want to\n", + "use the **same cluster structure** across different FlowSystem configurations.\n", + "Use `predef_cluster_order` to ensure comparable results:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [], + "source": [ + "# Save the cluster order from our optimized system\n", + "cluster_order = fs_clustered.clustering.cluster_order.values\n", + "\n", + "# Now modify the FlowSystem (e.g., increase storage capacity limits)\n", + "flow_system_modified = flow_system.copy()\n", + "flow_system_modified.components['Storage'].capacity_in_flow_hours.maximum_size = 2000 # Larger storage option\n", + "\n", + "# Cluster with the SAME cluster structure for fair comparison\n", + "fs_modified_clustered = 
flow_system_modified.transform.cluster(\n", + " n_clusters=8,\n", + " cluster_duration='1D',\n", + " predef_cluster_order=cluster_order, # Reuse cluster assignments\n", + ")\n", + "fs_modified_clustered.name = 'Modified (larger storage limit)'\n", + "\n", + "# Optimize the modified system\n", + "fs_modified_clustered.optimize(solver)\n", + "\n", + "# Compare results using Comparison class\n", + "fx.Comparison([fs_clustered, fs_modified_clustered])" + ] + }, + { + "cell_type": "markdown", + "id": "21", + "metadata": {}, "source": [ "## Method 3: Two-Stage Workflow (Recommended)\n", "\n", @@ -213,7 +332,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -225,7 +344,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -244,7 +363,7 @@ }, { "cell_type": "markdown", - "id": "16", + "id": "24", "metadata": {}, "source": [ "## Compare Results" @@ -253,7 +372,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -302,7 +421,7 @@ }, { "cell_type": "markdown", - "id": "18", + "id": "26", "metadata": {}, "source": [ "## Expand Solution to Full Resolution\n", @@ -314,7 +433,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "27", "metadata": {}, "outputs": [], "source": [ @@ -325,31 +444,29 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "28", "metadata": {}, "outputs": [], "source": [ + "# Compare heat production: Full vs Expanded\n", + "heat_flows = ['CHP(Q_th)|flow_rate', 'Boiler(Q_th)|flow_rate']\n", + "\n", "# Create comparison dataset\n", "comparison_ds = xr.Dataset(\n", " {\n", - " 'CHP Heat': xr.concat(\n", - " [fs_full.solution['CHP(Q_th)|flow_rate'], fs_expanded.solution['CHP(Q_th)|flow_rate']],\n", - " dim=pd.Index(['Full', 'Expanded'], name='method'),\n", - " ),\n", - " 'Boiler Heat': 
xr.concat(\n", - " [fs_full.solution['Boiler(Q_th)|flow_rate'], fs_expanded.solution['Boiler(Q_th)|flow_rate']],\n", - " dim=pd.Index(['Full', 'Expanded'], name='method'),\n", - " ),\n", + " name.replace('|flow_rate', ''): xr.concat(\n", + " [fs_full.solution[name], fs_expanded.solution[name]], dim=pd.Index(['Full', 'Expanded'], name='method')\n", + " )\n", + " for name in heat_flows\n", " }\n", ")\n", "\n", - "# Compare heat production: Full vs Expanded\n", "comparison_ds.fxplot.line(facet_col='variable', color='method', title='Heat Production Comparison')" ] }, { "cell_type": "markdown", - "id": "21", + "id": "29", "metadata": {}, "source": [ "## Visualize Clustered Heat Balance" @@ -358,7 +475,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "30", "metadata": {}, "outputs": [], "source": [ @@ -368,7 +485,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "31", "metadata": {}, "outputs": [], "source": [ @@ -377,7 +494,7 @@ }, { "cell_type": "markdown", - "id": "24", + "id": "32", "metadata": {}, "source": [ "## API Reference\n", @@ -454,7 +571,7 @@ }, { "cell_type": "markdown", - "id": "25", + "id": "33", "metadata": {}, "source": [ "## Summary\n", From a53efa147935505ccee3ca55811b3575c212f58f Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 5 Jan 2026 17:51:43 +0100 Subject: [PATCH 62/62] 1. 08a-aggregation.ipynb cell 16: Removed corrupted markdown tag from markdown source 2. 
flixopt/comparison.py line 75: Added fallback for None names: # Before self._names = names or [fs.name for fs in flow_systems] # After self._names = names or [fs.name or f'System {i}' for i, fs in enumerate(flow_systems)] --- docs/notebooks/08a-aggregation.ipynb | 2 +- flixopt/comparison.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/notebooks/08a-aggregation.ipynb b/docs/notebooks/08a-aggregation.ipynb index 73d0feea3..e26c19223 100644 --- a/docs/notebooks/08a-aggregation.ipynb +++ b/docs/notebooks/08a-aggregation.ipynb @@ -273,7 +273,7 @@ "id": "16", "metadata": {}, "source": [ - "markdown## Visual Comparison: Heat Balance\n", + "## Visual Comparison: Heat Balance\n", "\n", "Compare the full optimization with the two-stage approach side-by-side:" ] diff --git a/flixopt/comparison.py b/flixopt/comparison.py index 5219642d2..63a00a0f1 100644 --- a/flixopt/comparison.py +++ b/flixopt/comparison.py @@ -72,7 +72,7 @@ def __init__(self, flow_systems: list[FlowSystem], names: list[str] | None = Non raise ValueError('Comparison requires at least 2 FlowSystems') self._systems = flow_systems - self._names = names or [fs.name for fs in flow_systems] + self._names = names or [fs.name or f'System {i}' for i, fs in enumerate(flow_systems)] if len(self._names) != len(self._systems): raise ValueError(