diff --git a/flixopt/clustering/__init__.py b/flixopt/clustering/__init__.py index 9059fb682..9f12f2a02 100644 --- a/flixopt/clustering/__init__.py +++ b/flixopt/clustering/__init__.py @@ -6,7 +6,6 @@ Example usage: - # Cluster a FlowSystem to reduce timesteps from tsam import ExtremeConfig fs_clustered = flow_system.transform.cluster( @@ -15,21 +14,14 @@ extremes=ExtremeConfig(method='new_cluster', max_value=['Demand|fixed_relative_profile']), ) - # Access clustering structure (available before AND after IO) clustering = fs_clustered.clustering print(f'Number of clusters: {clustering.n_clusters}') - print(f'Clustering info: {clustering.clustering_result}') # tsam_xarray ClusteringResult + print(f'Clustering result: {clustering.clustering_result}') - # Access tsam_xarray AggregationResult for detailed analysis - # NOTE: Only available BEFORE saving/loading. Lost after IO. + # Access tsam_xarray AggregationResult (only before saving/loading) result = clustering.aggregation_result - result.cluster_representatives # DataArray with aggregated time series - result.accuracy # AccuracyMetrics (rmse, mae) - - # Save and load - structure preserved, AggregationResult access lost - fs_clustered.to_netcdf('system.nc') - # Use include_original_data=False for smaller files (~38% reduction) - fs_clustered.to_netcdf('system.nc', include_original_data=False) + result.cluster_representatives # DataArray + result.accuracy # AccuracyMetrics # Expand back to full resolution fs_expanded = fs_clustered.transform.expand() diff --git a/flixopt/clustering/base.py b/flixopt/clustering/base.py index ae3d689c4..b78bf89b3 100644 --- a/flixopt/clustering/base.py +++ b/flixopt/clustering/base.py @@ -11,42 +11,15 @@ import json from typing import TYPE_CHECKING, Any -import numpy as np import pandas as pd -import xarray as xr if TYPE_CHECKING: from pathlib import Path + import xarray as xr from tsam_xarray import AggregationResult as TsamXarrayAggregationResult from tsam_xarray import ClusteringResult - from ..color_processing import ColorType - from ..plot_result import PlotResult - from ..statistics_accessor import SelectType - -from ..statistics_accessor import _build_color_kwargs - - -def _apply_slot_defaults(plotly_kwargs: dict, defaults: dict[str, str | None]) -> None: - """Apply default slot assignments to plotly kwargs. - - Args: - plotly_kwargs: The kwargs dict to update (modified in place). - defaults: Default slot assignments. None values block slots. - """ - for slot, value in defaults.items(): - plotly_kwargs.setdefault(slot, value) - - -def _select_dims(da: xr.DataArray, period: Any = None, scenario: Any = None) -> xr.DataArray: - """Select from DataArray by period/scenario if those dimensions exist.""" - if 'period' in da.dims and period is not None: - da = da.sel(period=period) - if 'scenario' in da.dims and scenario is not None: - da = da.sel(scenario=scenario) - return da - class Clustering: """Clustering information for a FlowSystem. @@ -68,12 +41,6 @@ def __init__( self, clustering_result: ClusteringResult | dict | None = None, original_timesteps: pd.DatetimeIndex | list[str] | None = None, - original_data: xr.Dataset | None = None, - aggregated_data: xr.Dataset | None = None, - _metrics: xr.Dataset | None = None, - # These are for reconstruction from serialization - _original_data_refs: list[str] | None = None, - _metrics_refs: list[str] | None = None, # Internal: tsam_xarray AggregationResult for full data access _aggregation_result: TsamXarrayAggregationResult | None = None, # Internal: mapping from renamed dims back to originals (e.g., _period -> period) @@ -82,6 +49,8 @@ def __init__( # The IO resolver passes serialized dict keys as kwargs to __init__(). # Remove once all users have re-saved their netcdf files with the new format. results: Any = None, + # Legacy kwargs ignored (removed: original_data, aggregated_data, _metrics, refs) + **_ignored: Any, ): from tsam_xarray import ClusteringResult as ClusteringResultClass @@ -137,41 +106,6 @@ def __init__( if self._clustering_result.time_coords is None and len(self.original_timesteps) > 0: object.__setattr__(self._clustering_result, 'time_coords', self.original_timesteps) - self._metrics = _metrics - - # Handle reconstructed data from refs (list of DataArrays) - if _original_data_refs is not None and isinstance(_original_data_refs, list): - # These are resolved DataArrays from the structure resolver - if all(isinstance(da, xr.DataArray) for da in _original_data_refs): - # Rename 'original_time' back to 'time' and strip 'original_data|' prefix - data_vars = {} - for da in _original_data_refs: - if 'original_time' in da.dims: - da = da.rename({'original_time': 'time'}) - # Strip 'original_data|' prefix from name (added during serialization) - name = da.name - if name.startswith('original_data|'): - name = name[14:] # len('original_data|') = 14 - data_vars[name] = da.rename(name) - self.original_data = xr.Dataset(data_vars) - else: - self.original_data = original_data - else: - self.original_data = original_data - - self.aggregated_data = aggregated_data - - if _metrics_refs is not None and isinstance(_metrics_refs, list): - if all(isinstance(da, xr.DataArray) for da in _metrics_refs): - # Strip 'metrics|' prefix from name (added during serialization) - data_vars = {} - for da in _metrics_refs: - name = da.name - if name.startswith('metrics|'): - name = name[8:] # len('metrics|') = 8 - data_vars[name] = da.rename(name) - self._metrics = xr.Dataset(data_vars) - @staticmethod def _clustering_result_from_dict(d: dict) -> ClusteringResult: """Create ClusteringResult from serialized dict.""" @@ -308,7 +242,7 @@ def disaggregate(self, data: xr.DataArray) -> xr.DataArray: on the result. Args: - data: DataArray with ``(cluster, time)`` dims from the clustered FlowSystem. + data: DataArray with ``(cluster, time)`` or ``(cluster, segment)`` dims. Returns: DataArray with ``time`` dim restored to original timesteps. @@ -369,8 +303,8 @@ def from_json( ) -> Clustering: """Load a clustering from JSON. - The loaded Clustering has full apply() support because ClusteringResult - is fully preserved via serialization. + The loaded Clustering has full apply() and disaggregate() support + because ClusteringResult is fully preserved via serialization. Args: path: Path to the JSON file. @@ -399,61 +333,18 @@ def from_json( original_timesteps=original_timesteps, ) - def _create_reference_structure(self, include_original_data: bool = True) -> tuple[dict, dict[str, xr.DataArray]]: + def _create_reference_structure(self) -> tuple[dict, dict[str, xr.DataArray]]: """Create serialization structure for to_dataset(). - Args: - include_original_data: Whether to include original_data in serialization. - Set to False for smaller files when plot.compare() isn't needed after IO. - Defaults to True. - Returns: Tuple of (reference_dict, arrays_dict). """ - arrays = {} - - # Collect original_data arrays - # Rename 'time' to 'original_time' to avoid conflict with clustered FlowSystem's time coord - original_data_refs = None - if include_original_data and self.original_data is not None: - original_data_refs = [] - # Use variables for faster access (avoids _construct_dataarray overhead) - variables = self.original_data.variables - for name in self.original_data.data_vars: - var = variables[name] - ref_name = f'original_data|{name}' - # Rename time dim to avoid xarray alignment issues - if 'time' in var.dims: - new_dims = tuple('original_time' if d == 'time' else d for d in var.dims) - arrays[ref_name] = xr.Variable(new_dims, var.values, attrs=var.attrs) - else: - arrays[ref_name] = var - original_data_refs.append(f':::{ref_name}') - - # NOTE: aggregated_data is NOT serialized - it's identical to the FlowSystem's - # main data arrays and would be redundant. After loading, aggregated_data is - # reconstructed from the FlowSystem's dataset. - - # Collect metrics arrays - metrics_refs = None - if self._metrics is not None: - metrics_refs = [] - # Use variables for faster access (avoids _construct_dataarray overhead) - metrics_vars = self._metrics.variables - for name in self._metrics.data_vars: - ref_name = f'metrics|{name}' - arrays[ref_name] = metrics_vars[name] - metrics_refs.append(f':::{ref_name}') - reference = { '__class__': 'Clustering', 'clustering_result': self._clustering_result.to_dict(), 'original_timesteps': [ts.isoformat() for ts in self.original_timesteps], - '_original_data_refs': original_data_refs, - '_metrics_refs': metrics_refs, } - - return reference, arrays + return reference, {} # ========================================================================== # Access to tsam_xarray AggregationResult @@ -485,19 +376,6 @@ def _require_full_data(self, operation: str) -> None: f'Use apply_clustering() to get full results.' ) - # ========================================================================== - # Visualization - # ========================================================================== - - @property - def plot(self) -> ClusteringPlotAccessor: - """Access plotting methods for clustering visualization. - - Returns: - ClusteringPlotAccessor with compare(), heatmap(), and clusters() methods. - """ - return ClusteringPlotAccessor(self) - def __repr__(self) -> str: return ( f'Clustering(\n' @@ -508,374 +386,6 @@ def __repr__(self) -> str: ) -class ClusteringPlotAccessor: - """Plot accessor for Clustering objects. - - Provides visualization methods for comparing original vs aggregated data - and understanding the clustering structure. - """ - - def __init__(self, clustering: Clustering): - self._clustering = clustering - - def compare( - self, - kind: str = 'timeseries', - variables: str | list[str] | None = None, - *, - select: SelectType | None = None, - colors: ColorType | None = None, - show: bool | None = None, - data_only: bool = False, - **plotly_kwargs: Any, - ) -> PlotResult: - """Compare original vs aggregated data. - - Args: - kind: Type of comparison plot. - - 'timeseries': Time series comparison (default) - - 'duration_curve': Sorted duration curve comparison - variables: Variable(s) to plot. Can be a string, list of strings, - or None to plot all time-varying variables. - select: xarray-style selection dict, e.g. {'scenario': 'Base Case'}. - colors: Color specification (colorscale name, color list, or label-to-color dict). - show: Whether to display the figure. - Defaults to CONFIG.Plotting.default_show. - data_only: If True, skip figure creation and return only data. - **plotly_kwargs: Additional arguments passed to plotly (e.g., color, line_dash, - facet_col, facet_row). Defaults: x='time'/'duration', color='variable', - line_dash='representation', symbol=None. - - Returns: - PlotResult containing the comparison figure and underlying data. - """ - import plotly.graph_objects as go - - from ..config import CONFIG - from ..plot_result import PlotResult - from ..statistics_accessor import _apply_selection - - if kind not in ('timeseries', 'duration_curve'): - raise ValueError(f"Unknown kind '{kind}'. Use 'timeseries' or 'duration_curve'.") - - clustering = self._clustering - if clustering.original_data is None or clustering.aggregated_data is None: - raise ValueError('No original/aggregated data available for comparison') - - resolved_variables = self._resolve_variables(variables) - - # Build Dataset with variables as data_vars - data_vars = {} - for var in resolved_variables: - original = clustering.original_data[var] - clustered = clustering.disaggregate(clustering.aggregated_data[var]) - if clustering.is_segmented: - clustered = clustered.ffill(dim='time') - combined = xr.concat([original, clustered], dim=pd.Index(['Original', 'Clustered'], name='representation')) - data_vars[var] = combined - ds = xr.Dataset(data_vars) - - ds = _apply_selection(ds, select) - - if kind == 'duration_curve': - sorted_vars = {} - # Use variables for faster access (avoids _construct_dataarray overhead) - variables = ds.variables - rep_values = ds.coords['representation'].values - rep_idx = {rep: i for i, rep in enumerate(rep_values)} - for var in ds.data_vars: - data = variables[var].values - for rep in rep_values: - # Direct numpy indexing instead of .sel() - values = np.sort(data[rep_idx[rep]].flatten())[::-1] - sorted_vars[(var, rep)] = values - # Get length from first sorted array - n = len(next(iter(sorted_vars.values()))) - ds = xr.Dataset( - { - var: xr.DataArray( - [sorted_vars[(var, r)] for r in ['Original', 'Clustered']], - dims=['representation', 'duration'], - coords={'representation': ['Original', 'Clustered'], 'duration': range(n)}, - ) - for var in resolved_variables - } - ) - - title = ( - ( - 'Original vs Clustered' - if len(resolved_variables) > 1 - else f'Original vs Clustered: {resolved_variables[0]}' - ) - if kind == 'timeseries' - else ('Duration Curve' if len(resolved_variables) > 1 else f'Duration Curve: {resolved_variables[0]}') - ) - - # Early return for data_only mode - if data_only: - return PlotResult(data=ds, figure=go.Figure()) - - # Apply slot defaults - defaults = { - 'x': 'duration' if kind == 'duration_curve' else 'time', - 'color': 'variable', - 'line_dash': 'representation', - 'line_dash_map': {'Original': 'dot', 'Clustered': 'solid'}, - 'symbol': None, # Block symbol slot - } - _apply_slot_defaults(plotly_kwargs, defaults) - - color_kwargs = _build_color_kwargs(colors, list(ds.data_vars)) - fig = ds.plotly.line( - title=title, - **color_kwargs, - **plotly_kwargs, - ) - fig.update_yaxes(matches=None) - fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) - - plot_result = PlotResult(data=ds, figure=fig) - - if show is None: - show = CONFIG.Plotting.default_show - if show: - plot_result.show() - - return plot_result - - def _get_time_varying_variables(self) -> list[str]: - """Get list of time-varying variables from original data that also exist in aggregated data.""" - if self._clustering.original_data is None: - return [] - # Get variables that exist in both original and aggregated data - aggregated_vars = ( - set(self._clustering.aggregated_data.data_vars) - if self._clustering.aggregated_data is not None - else set(self._clustering.original_data.data_vars) - ) - return [ - name - for name in self._clustering.original_data.data_vars - if name in aggregated_vars - and 'time' in self._clustering.original_data[name].dims - and not np.isclose( - self._clustering.original_data[name].min(), - self._clustering.original_data[name].max(), - ) - ] - - def _resolve_variables(self, variables: str | list[str] | None) -> list[str]: - """Resolve variables parameter to a list of valid variable names.""" - time_vars = self._get_time_varying_variables() - if not time_vars: - raise ValueError('No time-varying variables found') - - if variables is None: - return time_vars - elif isinstance(variables, str): - if variables not in time_vars: - raise ValueError(f"Variable '{variables}' not found. Available: {time_vars}") - return [variables] - else: - invalid = [v for v in variables if v not in time_vars] - if invalid: - raise ValueError(f'Variables {invalid} not found. Available: {time_vars}') - return list(variables) - - def heatmap( - self, - *, - select: SelectType | None = None, - colors: str | list[str] | None = None, - show: bool | None = None, - data_only: bool = False, - **plotly_kwargs: Any, - ) -> PlotResult: - """Plot cluster assignments over time as a heatmap timeline. - - Shows which cluster each timestep belongs to as a horizontal color bar. - The x-axis is time, color indicates cluster assignment. This visualization - aligns with time series data, making it easy to correlate cluster - assignments with other plots. - - For multi-period/scenario data, uses faceting and/or animation. - - Args: - select: xarray-style selection dict, e.g. {'scenario': 'Base Case'}. - colors: Colorscale name (str) or list of colors for heatmap coloring. - Dicts are not supported for heatmaps. - Defaults to plotly template's sequential colorscale. - show: Whether to display the figure. - Defaults to CONFIG.Plotting.default_show. - data_only: If True, skip figure creation and return only data. - **plotly_kwargs: Additional arguments passed to plotly (e.g., facet_col, animation_frame). - - Returns: - PlotResult containing the heatmap figure and cluster assignment data. - The data has 'cluster' variable with time dimension, matching original timesteps. - """ - import plotly.graph_objects as go - - from ..config import CONFIG - from ..plot_result import PlotResult - from ..statistics_accessor import _apply_selection - - clustering = self._clustering - cluster_assignments = clustering.cluster_assignments - timesteps_per_cluster = clustering.timesteps_per_cluster - original_time = clustering.original_timesteps - - if select: - cluster_assignments = _apply_selection(cluster_assignments.to_dataset(name='cluster'), select)['cluster'] - - # Expand cluster_assignments to per-timestep - extra_dims = [d for d in cluster_assignments.dims if d != 'original_cluster'] - expanded_values = np.repeat(cluster_assignments.values, timesteps_per_cluster, axis=0) - - coords = {'time': original_time} - coords.update({d: cluster_assignments.coords[d].values for d in extra_dims}) - cluster_da = xr.DataArray(expanded_values, dims=['time'] + extra_dims, coords=coords) - cluster_da.name = 'cluster' - - # Early return for data_only mode - if data_only: - return PlotResult(data=xr.Dataset({'cluster': cluster_da}), figure=go.Figure()) - - heatmap_da = cluster_da.expand_dims('y', axis=-1).assign_coords(y=['Cluster']) - heatmap_da.name = 'cluster_assignment' - heatmap_da = heatmap_da.transpose('time', 'y', ...) - - # Use plotly.imshow for heatmap - # Only pass color_continuous_scale if explicitly provided (template handles default) - if colors is not None: - plotly_kwargs.setdefault('color_continuous_scale', colors) - fig = heatmap_da.plotly.imshow( - title='Cluster Assignments', - aspect='auto', - **plotly_kwargs, - ) - - fig.update_yaxes(showticklabels=False) - fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) - - # Data is exactly what we plotted (without dummy y dimension) - data = xr.Dataset({'cluster': cluster_da}) - plot_result = PlotResult(data=data, figure=fig) - - if show is None: - show = CONFIG.Plotting.default_show - if show: - plot_result.show() - - return plot_result - - def clusters( - self, - variables: str | list[str] | None = None, - *, - select: SelectType | None = None, - colors: ColorType | None = None, - show: bool | None = None, - data_only: bool = False, - **plotly_kwargs: Any, - ) -> PlotResult: - """Plot each cluster's typical period profile. - - Shows each cluster as a separate faceted subplot with all variables - colored differently. Useful for understanding what each cluster represents. - - Args: - variables: Variable(s) to plot. Can be a string, list of strings, - or None to plot all time-varying variables. - select: xarray-style selection dict, e.g. {'scenario': 'Base Case'}. - colors: Color specification (colorscale name, color list, or label-to-color dict). - show: Whether to display the figure. - Defaults to CONFIG.Plotting.default_show. - data_only: If True, skip figure creation and return only data. - **plotly_kwargs: Additional arguments passed to plotly (e.g., color, facet_col, - facet_col_wrap). Defaults: x='time', color='variable', symbol=None. - - Returns: - PlotResult containing the figure and underlying data. - """ - import plotly.graph_objects as go - - from ..config import CONFIG - from ..plot_result import PlotResult - from ..statistics_accessor import _apply_selection - - clustering = self._clustering - if clustering.aggregated_data is None: - raise ValueError('No aggregated data available') - - aggregated_data = _apply_selection(clustering.aggregated_data, select) - resolved_variables = self._resolve_variables(variables) - - n_clusters = clustering.n_clusters - timesteps_per_cluster = clustering.timesteps_per_cluster - cluster_occurrences = clustering.cluster_occurrences - - # Build cluster labels - occ_extra_dims = [d for d in cluster_occurrences.dims if d != 'cluster'] - if occ_extra_dims: - cluster_labels = [f'Cluster {c}' for c in range(n_clusters)] - else: - cluster_labels = [ - f'Cluster {c} (×{int(cluster_occurrences.sel(cluster=c).values)})' for c in range(n_clusters) - ] - - data_vars = {} - for var in resolved_variables: - da = aggregated_data[var] - if 'cluster' in da.dims: - data_by_cluster = da.values - else: - data_by_cluster = da.values.reshape(n_clusters, timesteps_per_cluster) - data_vars[var] = xr.DataArray( - data_by_cluster, - dims=['cluster', 'time'], - coords={'cluster': cluster_labels, 'time': range(timesteps_per_cluster)}, - ) - - ds = xr.Dataset(data_vars) - - # Early return for data_only mode (include occurrences in result) - if data_only: - data_vars['occurrences'] = cluster_occurrences - return PlotResult(data=xr.Dataset(data_vars), figure=go.Figure()) - - title = 'Clusters' if len(resolved_variables) > 1 else f'Clusters: {resolved_variables[0]}' - - # Apply slot defaults - defaults = { - 'x': 'time', - 'color': 'variable', - 'symbol': None, # Block symbol slot - } - _apply_slot_defaults(plotly_kwargs, defaults) - - color_kwargs = _build_color_kwargs(colors, list(ds.data_vars)) - fig = ds.plotly.line( - title=title, - **color_kwargs, - **plotly_kwargs, - ) - fig.update_yaxes(matches=None) - fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) - - data_vars['occurrences'] = cluster_occurrences - result_data = xr.Dataset(data_vars) - plot_result = PlotResult(data=result_data, figure=fig) - - if show is None: - show = CONFIG.Plotting.default_show - if show: - plot_result.show() - - return plot_result - - def _register_clustering_classes(): """Register clustering classes for IO.""" from ..structure import CLASS_REGISTRY diff --git a/flixopt/flow_system.py b/flixopt/flow_system.py index e838c5480..df71e2ff5 100644 --- a/flixopt/flow_system.py +++ b/flixopt/flow_system.py @@ -696,7 +696,7 @@ def _create_reference_structure(self) -> tuple[dict, dict[str, xr.DataArray]]: return reference_structure, all_extracted_arrays - def to_dataset(self, include_solution: bool = True, include_original_data: bool = True) -> xr.Dataset: + def to_dataset(self, include_solution: bool = True) -> xr.Dataset: """ Convert the FlowSystem to an xarray Dataset. Ensures FlowSystem is connected before serialization. @@ -714,10 +714,6 @@ def to_dataset(self, include_solution: bool = True, include_original_data: bool include_solution: Whether to include the optimization solution in the dataset. Defaults to True. Set to False to get only the FlowSystem structure without solution data (useful for copying or saving templates). - include_original_data: Whether to include clustering.original_data in the dataset. - Defaults to True. Set to False for smaller files (~38% reduction) when - clustering.plot.compare() isn't needed after loading. The core workflow - (optimize → expand) works without original_data. Returns: xr.Dataset: Dataset containing all DataArrays with structure in attributes @@ -734,7 +730,7 @@ def to_dataset(self, include_solution: bool = True, include_original_data: bool base_ds = super().to_dataset() # Add FlowSystem-specific data (solution, clustering, metadata) - return fx_io.flow_system_to_dataset(self, base_ds, include_solution, include_original_data) + return fx_io.flow_system_to_dataset(self, base_ds, include_solution) @classmethod def from_dataset(cls, ds: xr.Dataset) -> FlowSystem: @@ -766,7 +762,6 @@ def to_netcdf( path: str | pathlib.Path, compression: int = 5, overwrite: bool = False, - include_original_data: bool = True, ): """ Save the FlowSystem to a NetCDF file. @@ -779,9 +774,6 @@ def to_netcdf( path: The path to the netCDF file. Parent directories are created if they don't exist. compression: The compression level to use when saving the file (0-9). overwrite: If True, overwrite existing file. If False, raise error if file exists. - include_original_data: Whether to include clustering.original_data in the file. - Defaults to True. Set to False for smaller files (~38% reduction) when - clustering.plot.compare() isn't needed after loading. Raises: FileExistsError: If overwrite=False and file already exists. @@ -801,7 +793,7 @@ def to_netcdf( self.name = path.stem try: - ds = self.to_dataset(include_original_data=include_original_data) + ds = self.to_dataset() fx_io.save_dataset_to_netcdf(ds, path, compression=compression) logger.info(f'Saved FlowSystem to {path}') except Exception as e: diff --git a/flixopt/io.py b/flixopt/io.py index 592bdad79..20d302204 100644 --- a/flixopt/io.py +++ b/flixopt/io.py @@ -1858,13 +1858,6 @@ def _restore_clustering( clustering = fs_cls._resolve_reference_structure(clustering_structure, clustering_arrays) flow_system.clustering = clustering - # Reconstruct aggregated_data from FlowSystem's main data arrays - if clustering.aggregated_data is None and main_var_names: - from .core import drop_constant_arrays - - main_vars = {name: arrays_dict[name] for name in main_var_names} - clustering.aggregated_data = drop_constant_arrays(xr.Dataset(main_vars), dim='time') - # Restore cluster_weight from clustering's cluster_occurrences if hasattr(clustering, 'cluster_occurrences'): flow_system.cluster_weight = clustering.cluster_occurrences.rename('cluster_weight') @@ -1904,7 +1897,6 @@ def to_dataset( flow_system: FlowSystem, base_dataset: xr.Dataset, include_solution: bool = True, - include_original_data: bool = True, ) -> xr.Dataset: """Convert FlowSystem-specific data to dataset. @@ -1915,7 +1907,6 @@ def to_dataset( flow_system: The FlowSystem to serialize base_dataset: Dataset from parent class with basic structure include_solution: Whether to include optimization solution - include_original_data: Whether to include clustering.original_data Returns: Complete dataset with all FlowSystem data @@ -1931,7 +1922,7 @@ def to_dataset( ds = cls._add_carriers_to_dataset(ds, flow_system._carriers) # Add clustering - ds = cls._add_clustering_to_dataset(ds, flow_system.clustering, include_original_data) + ds = cls._add_clustering_to_dataset(ds, flow_system.clustering) # Add variable categories ds = cls._add_variable_categories_to_dataset(ds, flow_system._variable_categories) @@ -1996,17 +1987,13 @@ def _add_clustering_to_dataset( cls, ds: xr.Dataset, clustering: Any, - include_original_data: bool, ) -> xr.Dataset: """Add clustering object to dataset.""" if clustering is not None: - clustering_ref, clustering_arrays = clustering._create_reference_structure( - include_original_data=include_original_data - ) - # Add clustering arrays with prefix using batch assignment - # (individual ds[name] = arr assignments are slow) - prefixed_arrays = {f'{cls.CLUSTERING_PREFIX}{name}': arr for name, arr in clustering_arrays.items()} - ds = ds.assign(prefixed_arrays) + clustering_ref, clustering_arrays = clustering._create_reference_structure() + if clustering_arrays: + prefixed_arrays = {f'{cls.CLUSTERING_PREFIX}{name}': arr for name, arr in clustering_arrays.items()} + ds = ds.assign(prefixed_arrays) ds.attrs['clustering'] = json.dumps(clustering_ref, ensure_ascii=False) return ds @@ -2064,7 +2051,6 @@ def flow_system_to_dataset( flow_system: FlowSystem, base_dataset: xr.Dataset, include_solution: bool = True, - include_original_data: bool = True, ) -> xr.Dataset: """Convert FlowSystem-specific data to dataset. @@ -2075,7 +2061,6 @@ def flow_system_to_dataset( flow_system: The FlowSystem to serialize base_dataset: Dataset from parent class with basic structure include_solution: Whether to include optimization solution - include_original_data: Whether to include clustering.original_data Returns: Complete dataset with all FlowSystem data @@ -2083,4 +2068,4 @@ def flow_system_to_dataset( See Also: FlowSystemDatasetIO: Class containing the implementation """ - return FlowSystemDatasetIO.to_dataset(flow_system, base_dataset, include_solution, include_original_data) + return FlowSystemDatasetIO.to_dataset(flow_system, base_dataset, include_solution) diff --git a/flixopt/transform_accessor.py b/flixopt/transform_accessor.py index ebc2447d6..bcd0b23bb 100644 --- a/flixopt/transform_accessor.py +++ b/flixopt/transform_accessor.py @@ -82,11 +82,6 @@ def _unrename(self, da: xr.DataArray) -> xr.DataArray: renames = {k: v for k, v in self._unrename_map.items() if k in da.dims} return da.rename(renames) if renames else da - def _unrename_ds(self, ds: xr.Dataset) -> xr.Dataset: - """Rename tsam_xarray output dims back to original names in Dataset.""" - renames = {k: v for k, v in self._unrename_map.items() if k in ds.dims} - return ds.rename(renames) if renames else ds - def build_cluster_weights(self) -> xr.DataArray: """Build cluster_weight DataArray from aggregation result. @@ -142,32 +137,6 @@ def build_segment_durations(self) -> xr.DataArray: other_dims = [d for d in da.dims if d not in ('cluster', 'time')] return self._unrename(da.transpose('cluster', 'time', *other_dims).rename('timestep_duration')) - def build_metrics(self) -> xr.Dataset: - """Build clustering metrics Dataset from aggregation result. - - Returns: - Dataset with RMSE, MAE, RMSE_duration metrics. - """ - accuracy = self._agg_result.accuracy - try: - data_vars = {} - for metric_name, metric_da in [ - ('RMSE', accuracy.rmse), - ('MAE', accuracy.mae), - ('RMSE_duration', accuracy.rmse_duration), - ]: - # Rename the variable dimension to 'time_series' - known_metric_dims = {'period', 'scenario'} | set(self._unrename_map.keys()) - unknown_dims = [d for d in metric_da.dims if d not in known_metric_dims] - assert len(unknown_dims) == 1, f'Expected 1 variable dim in {metric_name}, got {unknown_dims}' - variable_dim = unknown_dims[0] - da = metric_da.rename({variable_dim: 'time_series'}) - data_vars[metric_name] = da - return self._unrename_ds(xr.Dataset(data_vars)) - except Exception as e: - logger.warning(f'Failed to compute clustering metrics: {e}') - return xr.Dataset() - def build_reduced_dataset(self, ds: xr.Dataset, typical_das: dict[str, xr.DataArray]) -> xr.Dataset: """Build the reduced dataset with (cluster, time) structure. @@ -235,13 +204,11 @@ def build(self, ds: xr.Dataset) -> FlowSystem: Reduced FlowSystem with clustering metadata attached. """ from .clustering import Clustering - from .core import drop_constant_arrays from .flow_system import FlowSystem # Build all components cluster_weight = self.build_cluster_weights() typical_das = self.build_typical_periods() - metrics = self.build_metrics() ds_new = self.build_reduced_dataset(ds, typical_das) # Add segment durations if segmented @@ -272,9 +239,6 @@ def build(self, ds: xr.Dataset) -> FlowSystem: # Create Clustering object with full AggregationResult access reduced_fs.clustering = Clustering( original_timesteps=self._fs.timesteps, - original_data=drop_constant_arrays(ds, dim='time'), - aggregated_data=drop_constant_arrays(ds_new, dim='time'), - _metrics=metrics if metrics.data_vars else None, _aggregation_result=self._agg_result, _unrename_map=self._unrename_map, ) diff --git a/tests/test_clustering/test_base.py b/tests/test_clustering/test_base.py index ed8831633..f69de4cdf 100644 --- a/tests/test_clustering/test_base.py +++ b/tests/test_clustering/test_base.py @@ -148,77 +148,3 @@ def test_multi_period_clustering(self, mock_cr_factory): assert clustering.n_clusters == 2 assert 'period' in clustering.cluster_occurrences.dims assert clustering.dim_names == ['period'] - - -class TestClusteringPlotAccessor: - """Tests for ClusteringPlotAccessor.""" - - @pytest.fixture - def clustering_with_data(self): - """Create Clustering with original and aggregated data.""" - - class MockClusteringResult: - n_clusters = 2 - n_original_periods = 3 - n_timesteps_per_period = 24 - cluster_assignments = (0, 1, 0) - period_duration = 24.0 - n_segments = None - segment_assignments = None - cluster_centers = (0, 1) - - cr_result = _make_clustering_result({(): MockClusteringResult()}, []) - original_timesteps = pd.date_range('2024-01-01', periods=72, freq='h') - - original_data = xr.Dataset( - { - 'col1': xr.DataArray(np.random.randn(72), dims=['time'], coords={'time': original_timesteps}), - } - ) - aggregated_data = xr.Dataset( - { - 'col1': xr.DataArray( - np.random.randn(2, 24), - dims=['cluster', 'time'], - coords={'cluster': [0, 1], 'time': pd.date_range('2000-01-01', periods=24, freq='h')}, - ), - } - ) - - return Clustering( - clustering_result=cr_result, - original_timesteps=original_timesteps, - original_data=original_data, - aggregated_data=aggregated_data, - ) - - def test_plot_accessor_exists(self, clustering_with_data): - """Test that plot accessor is available.""" - assert hasattr(clustering_with_data, 'plot') - assert hasattr(clustering_with_data.plot, 'compare') - assert hasattr(clustering_with_data.plot, 'heatmap') - assert hasattr(clustering_with_data.plot, 'clusters') - - def test_compare_requires_data(self): - """Test compare() raises when no data available.""" - - class MockClusteringResult: - n_clusters = 2 - n_original_periods = 2 - n_timesteps_per_period = 24 - cluster_assignments = (0, 1) - period_duration = 24.0 - n_segments = None - segment_assignments = None - cluster_centers = (0, 1) - - cr_result = _make_clustering_result({(): MockClusteringResult()}, []) - original_timesteps = pd.date_range('2024-01-01', periods=48, freq='h') - - clustering = Clustering( - clustering_result=cr_result, - original_timesteps=original_timesteps, - ) - - with pytest.raises(ValueError, match='No original/aggregated data'): - clustering.plot.compare() diff --git a/tests/test_clustering/test_clustering_io.py b/tests/test_clustering/test_clustering_io.py index 811fd3b7c..93769b167 100644 --- a/tests/test_clustering/test_clustering_io.py +++ b/tests/test_clustering/test_clustering_io.py @@ -70,13 +70,9 @@ def test_clustering_to_dataset_has_clustering_attrs(self, simple_system_8_days): ds = fs_clustered.to_dataset(include_solution=False) - # Check that clustering attrs are present + # Check that clustering attrs are present (serialized as JSON string) assert 'clustering' in ds.attrs - # Check that clustering arrays are present with prefix - clustering_vars = [name for name in ds.data_vars if name.startswith('clustering|')] - assert len(clustering_vars) > 0 - def test_clustering_roundtrip_preserves_clustering_object(self, simple_system_8_days): """Clustering object should be restored after roundtrip.""" from flixopt.clustering import Clustering diff --git a/tests/test_clustering/test_expansion_regression.py b/tests/test_clustering/test_expansion_regression.py new file mode 100644 index 000000000..1bce3b4e7 --- /dev/null +++ b/tests/test_clustering/test_expansion_regression.py @@ -0,0 +1,157 @@ +"""Regression tests for cluster → optimize → expand numerical equivalence. + +These tests verify that the expanded solution values match known reference +values, catching any changes in the clustering/expansion pipeline. +""" + +import numpy as np +import pandas as pd +import pytest + +import flixopt as fx + +tsam = pytest.importorskip('tsam') + + +@pytest.fixture +def system_with_storage(): + """System with storage (tests charge_state) and effects (tests segment totals).""" + ts = pd.date_range('2020-01-01', periods=192, freq='h') # 8 days + demand = np.sin(np.linspace(0, 16 * np.pi, 192)) * 10 + 15 + + fs = fx.FlowSystem(ts) + fs.add_elements( + fx.Bus('Heat'), + fx.Bus('Gas'), + fx.Effect('costs', '€', is_standard=True, is_objective=True), + fx.Sink('D', inputs=[fx.Flow('Q', bus='Heat', fixed_relative_profile=demand, size=1)]), + fx.Source('G', outputs=[fx.Flow('Gas', bus='Gas', effects_per_flow_hour=0.05)]), + fx.linear_converters.Boiler( + 'B', + thermal_efficiency=0.9, + fuel_flow=fx.Flow('Q_fu', bus='Gas'), + thermal_flow=fx.Flow('Q_th', bus='Heat'), + ), + fx.Storage( + 'S', + capacity_in_flow_hours=50, + initial_charge_state=0.5, + charging=fx.Flow('in', bus='Heat', size=10), + discharging=fx.Flow('out', bus='Heat', size=10), + ), + ) + return fs + + +class TestNonSegmentedExpansion: + """Test that non-segmented cluster → expand produces correct values.""" + + def test_expanded_objective_matches(self, system_with_storage, solver_fixture): + fs_c = system_with_storage.transform.cluster(n_clusters=2, cluster_duration='1D') + fs_c.optimize(solver_fixture) + fs_e = fs_c.transform.expand() + + assert fs_e.solution['objective'].item() == pytest.approx(160.0, abs=1e-6) + + def test_expanded_flow_rates(self, system_with_storage, solver_fixture): + fs_c = system_with_storage.transform.cluster(n_clusters=2, cluster_duration='1D') + fs_c.optimize(solver_fixture) + fs_e = fs_c.transform.expand() + + sol = fs_e.solution + assert float(np.nansum(sol['B(Q_th)|flow_rate'].values)) == pytest.approx(2880.0, abs=1e-6) + assert float(np.nansum(sol['D(Q)|flow_rate'].values)) == pytest.approx(2880.0, abs=1e-6) + assert float(np.nansum(sol['G(Gas)|flow_rate'].values)) == pytest.approx(3200.0, abs=1e-6) + + def test_expanded_costs(self, system_with_storage, solver_fixture): + fs_c = system_with_storage.transform.cluster(n_clusters=2, cluster_duration='1D') + fs_c.optimize(solver_fixture) + fs_e = fs_c.transform.expand() + + sol = fs_e.solution + assert float(np.nansum(sol['costs(temporal)|per_timestep'].values)) == pytest.approx(160.0, abs=1e-6) + assert float(np.nansum(sol['G(Gas)->costs(temporal)'].values)) == pytest.approx(160.0, abs=1e-6) + + def test_expanded_storage(self, system_with_storage, solver_fixture): + fs_c = system_with_storage.transform.cluster(n_clusters=2, cluster_duration='1D') + fs_c.optimize(solver_fixture) + fs_e = fs_c.transform.expand() + + sol = fs_e.solution + # Storage dispatch varies by solver — check charge_state is non-trivial + assert float(np.nansum(sol['S|charge_state'].values)) > 0 + # Net discharge should be ~0 (balanced storage) + assert float(np.nansum(sol['S|netto_discharge'].values)) == pytest.approx(0, abs=1e-4) + + def test_expanded_shapes(self, system_with_storage, solver_fixture): + fs_c = system_with_storage.transform.cluster(n_clusters=2, cluster_duration='1D') + fs_c.optimize(solver_fixture) + fs_e = fs_c.transform.expand() + + sol = fs_e.solution + # 192 original timesteps + 1 extra boundary = 193 + for name in sol.data_vars: + if 'time' in sol[name].dims: + assert sol[name].sizes['time'] == 193, f'{name} has wrong time size' + + +class TestSegmentedExpansion: + """Test that segmented cluster → expand produces correct values.""" + + def test_expanded_objective_matches(self, system_with_storage, solver_fixture): + fs_c = system_with_storage.transform.cluster( + n_clusters=2, cluster_duration='1D', segments=tsam.SegmentConfig(n_segments=6) + ) + fs_c.optimize(solver_fixture) + fs_e = fs_c.transform.expand() + + assert fs_e.solution['objective'].item() == pytest.approx(160.0, abs=1e-6) + + def test_expanded_flow_rates(self, system_with_storage, solver_fixture): + fs_c = system_with_storage.transform.cluster( + n_clusters=2, cluster_duration='1D', segments=tsam.SegmentConfig(n_segments=6) + ) + fs_c.optimize(solver_fixture) + fs_e = fs_c.transform.expand() + + sol = fs_e.solution + assert float(np.nansum(sol['B(Q_th)|flow_rate'].values)) == pytest.approx(2880.0, abs=1e-6) + assert float(np.nansum(sol['D(Q)|flow_rate'].values)) == pytest.approx(2880.0, abs=1e-6) + assert float(np.nansum(sol['G(Gas)|flow_rate'].values)) == pytest.approx(3200.0, abs=1e-6) + + def test_expanded_costs(self, system_with_storage, solver_fixture): + fs_c = system_with_storage.transform.cluster( + n_clusters=2, cluster_duration='1D', segments=tsam.SegmentConfig(n_segments=6) + ) + fs_c.optimize(solver_fixture) + fs_e = fs_c.transform.expand() + + sol = fs_e.solution + assert float(np.nansum(sol['costs(temporal)|per_timestep'].values)) == pytest.approx(160.0, abs=1e-6) + assert float(np.nansum(sol['G(Gas)->costs(temporal)'].values)) == pytest.approx(160.0, abs=1e-6) + + def test_expanded_shapes(self, system_with_storage, solver_fixture): + fs_c = system_with_storage.transform.cluster( + n_clusters=2, cluster_duration='1D', segments=tsam.SegmentConfig(n_segments=6) + ) + fs_c.optimize(solver_fixture) + fs_e = fs_c.transform.expand() + + sol = fs_e.solution + for name in sol.data_vars: + if 'time' in sol[name].dims: + assert sol[name].sizes['time'] == 193, f'{name} has wrong time size' + + def test_no_nans_in_expanded_flow_rates(self, system_with_storage, solver_fixture): + """Segmented expansion must ffill — no NaNs in flow rates (except extra boundary).""" + fs_c = system_with_storage.transform.cluster( + n_clusters=2, cluster_duration='1D', segments=tsam.SegmentConfig(n_segments=6) + ) + fs_c.optimize(solver_fixture) + fs_e = fs_c.transform.expand() + + sol = fs_e.solution + for name in ['B(Q_th)|flow_rate', 'D(Q)|flow_rate', 'G(Gas)|flow_rate']: + # Exclude last timestep (extra boundary, may be NaN for non-state variables) + vals = sol[name].isel(time=slice(None, -1)) + assert not vals.isnull().any(), f'{name} has NaN values after expansion'