Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 55 additions & 137 deletions flixopt/clustering/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@
from ..statistics_accessor import SelectType


def _select_dims(da: xr.DataArray, period: str | None = None, scenario: str | None = None) -> xr.DataArray:
"""Select from DataArray by period/scenario if those dimensions exist."""
if 'period' in da.dims and period is not None:
da = da.sel(period=period)
if 'scenario' in da.dims and scenario is not None:
da = da.sel(scenario=scenario)
return da


@dataclass
class ClusterStructure:
"""Structure information for inter-cluster storage linking.
Expand Down Expand Up @@ -152,12 +161,7 @@ def get_cluster_order_for_slice(self, period: str | None = None, scenario: str |
Returns:
1D numpy array of cluster indices for the specified slice.
"""
order = self.cluster_order
if 'period' in order.dims and period is not None:
order = order.sel(period=period)
if 'scenario' in order.dims and scenario is not None:
order = order.sel(scenario=scenario)
return order.values.astype(int)
return _select_dims(self.cluster_order, period, scenario).values.astype(int)

def get_cluster_occurrences_for_slice(
self, period: str | None = None, scenario: str | None = None
Expand All @@ -171,12 +175,8 @@ def get_cluster_occurrences_for_slice(
Returns:
Dict mapping cluster ID to occurrence count.
"""
occurrences = self.cluster_occurrences
if 'period' in occurrences.dims and period is not None:
occurrences = occurrences.sel(period=period)
if 'scenario' in occurrences.dims and scenario is not None:
occurrences = occurrences.sel(scenario=scenario)
return {int(c): int(occurrences.sel(cluster=c).values) for c in occurrences.coords['cluster'].values}
occ = _select_dims(self.cluster_occurrences, period, scenario)
return {int(c): int(occ.sel(cluster=c).values) for c in occ.coords['cluster'].values}

def plot(self, colors: str | list[str] | None = None, show: bool | None = None) -> PlotResult:
"""Plot cluster assignment visualization.
Expand Down Expand Up @@ -372,12 +372,7 @@ def get_timestep_mapping_for_slice(self, period: str | None = None, scenario: st
Returns:
1D numpy array of representative timestep indices for the specified slice.
"""
mapping = self.timestep_mapping
if 'period' in mapping.dims and period is not None:
mapping = mapping.sel(period=period)
if 'scenario' in mapping.dims and scenario is not None:
mapping = mapping.sel(scenario=scenario)
return mapping.values.astype(int)
return _select_dims(self.timestep_mapping, period, scenario).values.astype(int)

def expand_data(self, aggregated: xr.DataArray, original_time: xr.DataArray | None = None) -> xr.DataArray:
"""Expand aggregated data back to original timesteps.
Expand All @@ -400,89 +395,52 @@ def expand_data(self, aggregated: xr.DataArray, original_time: xr.DataArray | No
>>> expanded = result.expand_data(aggregated_values)
>>> len(expanded.time) == len(original_timesteps) # True
"""
import pandas as pd

if original_time is None:
if self.original_data is None:
raise ValueError('original_time required when original_data is not available')
original_time = self.original_data.coords['time']

timestep_mapping = self.timestep_mapping
has_periods = 'period' in timestep_mapping.dims
has_scenarios = 'scenario' in timestep_mapping.dims
has_cluster_dim = 'cluster' in aggregated.dims
timesteps_per_cluster = self.cluster_structure.timesteps_per_cluster if has_cluster_dim else None

# Simple case: no period/scenario dimensions
if not has_periods and not has_scenarios:
mapping = timestep_mapping.values
def _expand_slice(mapping: np.ndarray, data: xr.DataArray) -> np.ndarray:
"""Expand a single slice using the mapping."""
if has_cluster_dim:
# 2D cluster structure: convert flat indices to (cluster, time_within)
# Use cluster_structure's timesteps_per_cluster, not aggregated.sizes['time']
# because the solution may include extra timesteps (timesteps_extra)
timesteps_per_cluster = self.cluster_structure.timesteps_per_cluster
cluster_ids = mapping // timesteps_per_cluster
time_within = mapping % timesteps_per_cluster
expanded_values = aggregated.values[cluster_ids, time_within]
else:
expanded_values = aggregated.values[mapping]
return xr.DataArray(
expanded_values,
coords={'time': original_time},
dims=['time'],
attrs=aggregated.attrs,
)
return data.values[cluster_ids, time_within]
return data.values[mapping]

# Multi-dimensional: expand each (period, scenario) slice and recombine
periods = list(timestep_mapping.coords['period'].values) if has_periods else [None]
scenarios = list(timestep_mapping.coords['scenario'].values) if has_scenarios else [None]

expanded_slices: dict[tuple, xr.DataArray] = {}
for p in periods:
for s in scenarios:
# Get mapping for this slice
mapping_slice = timestep_mapping
if p is not None:
mapping_slice = mapping_slice.sel(period=p)
if s is not None:
mapping_slice = mapping_slice.sel(scenario=s)
mapping = mapping_slice.values

# Select the data slice
selector = {}
if p is not None and 'period' in aggregated.dims:
selector['period'] = p
if s is not None and 'scenario' in aggregated.dims:
selector['scenario'] = s

slice_da = aggregated.sel(**selector, drop=True) if selector else aggregated

if has_cluster_dim:
# 2D cluster structure: convert flat indices to (cluster, time_within)
# Use cluster_structure's timesteps_per_cluster, not slice_da.sizes['time']
# because the solution may include extra timesteps (timesteps_extra)
timesteps_per_cluster = self.cluster_structure.timesteps_per_cluster
cluster_ids = mapping // timesteps_per_cluster
time_within = mapping % timesteps_per_cluster
expanded_values = slice_da.values[cluster_ids, time_within]
expanded = xr.DataArray(expanded_values, dims=['time'])
else:
expanded = slice_da.isel(time=xr.DataArray(mapping, dims=['time']))
expanded_slices[(p, s)] = expanded.assign_coords(time=original_time)

# Recombine slices using xr.concat
if has_periods and has_scenarios:
period_arrays = []
for p in periods:
scenario_arrays = [expanded_slices[(p, s)] for s in scenarios]
period_arrays.append(xr.concat(scenario_arrays, dim=pd.Index(scenarios, name='scenario')))
result = xr.concat(period_arrays, dim=pd.Index(periods, name='period'))
elif has_periods:
result = xr.concat([expanded_slices[(p, None)] for p in periods], dim=pd.Index(periods, name='period'))
else:
result = xr.concat(
[expanded_slices[(None, s)] for s in scenarios], dim=pd.Index(scenarios, name='scenario')
# Simple case: no period/scenario dimensions
extra_dims = [d for d in timestep_mapping.dims if d != 'original_time']
if not extra_dims:
expanded_values = _expand_slice(timestep_mapping.values, aggregated)
return xr.DataArray(expanded_values, coords={'time': original_time}, dims=['time'], attrs=aggregated.attrs)

# Multi-dimensional: expand each slice and recombine
dim_coords = {d: list(timestep_mapping.coords[d].values) for d in extra_dims}
expanded_slices = {}
for combo in np.ndindex(*[len(v) for v in dim_coords.values()]):
selector = {d: dim_coords[d][i] for d, i in zip(extra_dims, combo, strict=True)}
mapping = _select_dims(timestep_mapping, **selector).values
data_slice = (
_select_dims(aggregated, **selector) if any(d in aggregated.dims for d in selector) else aggregated
)
expanded_slices[tuple(selector.values())] = xr.DataArray(
_expand_slice(mapping, data_slice), coords={'time': original_time}, dims=['time']
)

# Concatenate iteratively along each extra dimension
result_arrays = expanded_slices
for dim in reversed(extra_dims):
dim_vals = dim_coords[dim]
grouped = {}
for key, arr in result_arrays.items():
rest_key = key[:-1] if len(key) > 1 else ()
grouped.setdefault(rest_key, []).append(arr)
result_arrays = {k: xr.concat(v, dim=pd.Index(dim_vals, name=dim)) for k, v in grouped.items()}
result = list(result_arrays.values())[0]
return result.transpose('time', ...).assign_attrs(aggregated.attrs)

def validate(self) -> None:
Expand Down Expand Up @@ -748,8 +706,6 @@ def heatmap(
PlotResult containing the heatmap figure and cluster assignment data.
The data has 'cluster' variable with time dimension, matching original timesteps.
"""
import pandas as pd

from ..config import CONFIG
from ..plot_result import PlotResult
from ..statistics_accessor import _apply_selection
Expand All @@ -760,63 +716,25 @@ def heatmap(
raise ValueError('No cluster structure available')

cluster_order_da = cs.cluster_order
timesteps_per_period = cs.timesteps_per_cluster
timesteps_per_cluster = cs.timesteps_per_cluster
original_time = result.original_data.coords['time'] if result.original_data is not None else None

# Apply selection if provided
if select:
cluster_order_da = _apply_selection(cluster_order_da.to_dataset(name='cluster'), select)['cluster']

# Check for multi-dimensional data
has_periods = 'period' in cluster_order_da.dims
has_scenarios = 'scenario' in cluster_order_da.dims

# Get dimension values
periods = list(cluster_order_da.coords['period'].values) if has_periods else [None]
scenarios = list(cluster_order_da.coords['scenario'].values) if has_scenarios else [None]

# Build cluster assignment per timestep for each (period, scenario) slice
cluster_slices: dict[tuple, xr.DataArray] = {}
for p in periods:
for s in scenarios:
cluster_order = cs.get_cluster_order_for_slice(period=p, scenario=s)
# Expand: each cluster repeated timesteps_per_period times
cluster_per_timestep = np.repeat(cluster_order, timesteps_per_period)
cluster_slices[(p, s)] = xr.DataArray(
cluster_per_timestep,
dims=['time'],
coords={'time': original_time} if original_time is not None else None,
)

# Combine slices into multi-dimensional DataArray
if has_periods and has_scenarios:
period_arrays = []
for p in periods:
scenario_arrays = [cluster_slices[(p, s)] for s in scenarios]
period_arrays.append(xr.concat(scenario_arrays, dim=pd.Index(scenarios, name='scenario')))
cluster_da = xr.concat(period_arrays, dim=pd.Index(periods, name='period'))
elif has_periods:
cluster_da = xr.concat(
[cluster_slices[(p, None)] for p in periods],
dim=pd.Index(periods, name='period'),
)
elif has_scenarios:
cluster_da = xr.concat(
[cluster_slices[(None, s)] for s in scenarios],
dim=pd.Index(scenarios, name='scenario'),
)
else:
cluster_da = cluster_slices[(None, None)]
# Expand cluster_order to per-timestep: repeat each value timesteps_per_cluster times
# Uses np.repeat along axis=0 (original_cluster dim)
extra_dims = [d for d in cluster_order_da.dims if d != 'original_cluster']
expanded_values = np.repeat(cluster_order_da.values, timesteps_per_cluster, axis=0)
coords = {'time': original_time} if original_time is not None else {}
coords.update({d: cluster_order_da.coords[d].values for d in extra_dims})
cluster_da = xr.DataArray(expanded_values, dims=['time'] + extra_dims, coords=coords)

# Add dummy y dimension for heatmap visualization (single row)
heatmap_da = cluster_da.expand_dims('y', axis=-1)
heatmap_da = heatmap_da.assign_coords(y=['Cluster'])
heatmap_da = cluster_da.expand_dims('y', axis=-1).assign_coords(y=['Cluster'])
heatmap_da.name = 'cluster_assignment'

# Reorder dims so 'time' and 'y' are first (heatmap x/y axes)
# Other dims (period, scenario) will be used for faceting/animation
target_order = ['time', 'y'] + [d for d in heatmap_da.dims if d not in ('time', 'y')]
heatmap_da = heatmap_da.transpose(*target_order)
heatmap_da = heatmap_da.transpose('time', 'y', ...)

# Use fxplot.heatmap for smart defaults
fig = heatmap_da.fxplot.heatmap(
Expand Down
61 changes: 35 additions & 26 deletions flixopt/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def consecutive_duration_tracking(
maximum_duration: xr.DataArray | None = None,
duration_dim: str = 'time',
duration_per_step: int | float | xr.DataArray = None,
previous_duration: xr.DataArray = 0,
previous_duration: xr.DataArray | None = 0,
) -> tuple[dict[str, linopy.Variable], dict[str, linopy.Constraint]]:
"""Creates consecutive duration tracking for a binary state variable.

Expand All @@ -278,7 +278,8 @@ def consecutive_duration_tracking(
maximum_duration: Optional maximum consecutive duration (upper bound on duration variable)
duration_dim: Dimension name to track duration along (default 'time')
duration_per_step: Time increment per step in duration_dim
previous_duration: Initial duration value before first timestep (default 0)
previous_duration: Initial duration value before first timestep (default 0). If None,
no initial constraint is added (relaxed initial state).

Returns:
Tuple of (duration_variable, constraints_dict)
Expand All @@ -287,7 +288,8 @@ def consecutive_duration_tracking(
if not isinstance(model, Submodel):
raise ValueError('ModelingPrimitives.consecutive_duration_tracking() can only be used with a Submodel')

mega = duration_per_step.sum(duration_dim) + previous_duration # Big-M value
# Big-M value (use 0 for previous_duration if None)
mega = duration_per_step.sum(duration_dim) + (previous_duration if previous_duration is not None else 0)

# Duration variable
duration = model.add_variables(
Expand Down Expand Up @@ -320,11 +322,13 @@ def consecutive_duration_tracking(
)

# Initial condition: duration[0] = (duration_per_step[0] + previous_duration) * state[0]
constraints['initial'] = model.add_constraints(
duration.isel({duration_dim: 0})
== (duration_per_step.isel({duration_dim: 0}) + previous_duration) * state.isel({duration_dim: 0}),
name=f'{duration.name}|initial',
)
# Skipped if previous_duration is None (relaxed initial state)
if previous_duration is not None:
constraints['initial'] = model.add_constraints(
duration.isel({duration_dim: 0})
== (duration_per_step.isel({duration_dim: 0}) + previous_duration) * state.isel({duration_dim: 0}),
name=f'{duration.name}|initial',
)

# Minimum duration constraint if provided
if minimum_duration is not None:
Expand All @@ -335,17 +339,18 @@ def consecutive_duration_tracking(
name=f'{duration.name}|lb',
)

# Handle initial condition for minimum duration
prev = (
float(previous_duration)
if not isinstance(previous_duration, xr.DataArray)
else float(previous_duration.max().item())
)
min0 = float(minimum_duration.isel({duration_dim: 0}).max().item())
if prev > 0 and prev < min0:
constraints['initial_lb'] = model.add_constraints(
state.isel({duration_dim: 0}) == 1, name=f'{duration.name}|initial_lb'
# Handle initial condition for minimum duration (skip if previous_duration is None)
if previous_duration is not None:
prev = (
float(previous_duration)
if not isinstance(previous_duration, xr.DataArray)
else float(previous_duration.max().item())
)
min0 = float(minimum_duration.isel({duration_dim: 0}).max().item())
if prev > 0 and prev < min0:
constraints['initial_lb'] = model.add_constraints(
state.isel({duration_dim: 0}) == 1, name=f'{duration.name}|initial_lb'
)

variables = {'duration': duration}

Expand Down Expand Up @@ -578,9 +583,9 @@ def state_transition_bounds(
activate: linopy.Variable,
deactivate: linopy.Variable,
name: str,
previous_state: float | xr.DataArray = 0,
previous_state: float | xr.DataArray | None = 0,
coord: str = 'time',
) -> tuple[linopy.Constraint, linopy.Constraint, linopy.Constraint]:
) -> tuple[linopy.Constraint, linopy.Constraint | None, linopy.Constraint]:
"""Creates state transition constraints for binary state variables.

Tracks transitions between active (1) and inactive (0) states using
Expand All @@ -598,7 +603,8 @@ def state_transition_bounds(
activate: Binary variable for transitions from inactive to active (0→1)
deactivate: Binary variable for transitions from active to inactive (1→0)
name: Base name for constraints
previous_state: State value before first timestep (default 0)
previous_state: State value before first timestep (default 0). If None,
no initial constraint is added (relaxed initial state).
coord: Time dimension name (default 'time')

Returns:
Expand All @@ -614,11 +620,14 @@ def state_transition_bounds(
name=f'{name}|transition',
)

# Initial state transition for t = 0
initial = model.add_constraints(
activate.isel({coord: 0}) - deactivate.isel({coord: 0}) == state.isel({coord: 0}) - previous_state,
name=f'{name}|initial',
)
# Initial state transition for t = 0 (skipped if previous_state is None for relaxed initial state)
if previous_state is not None:
initial = model.add_constraints(
activate.isel({coord: 0}) - deactivate.isel({coord: 0}) == state.isel({coord: 0}) - previous_state,
name=f'{name}|initial',
)
else:
initial = None

# At most one transition per timestep (mutual exclusivity)
mutex = model.add_constraints(activate + deactivate <= 1, name=f'{name}|mutex')
Expand Down
Loading
Loading