diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 2a51618d9..e9876c089 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -12,7 +12,7 @@ Thanks for your interest in contributing to FlixOpt! 🚀 2. **Install for Development** ```bash - pip install -e ".[full]" + pip install -e ".[full, dev, docs]" ``` 3. **Make Changes & Submit PR** diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b826569f..e8836e87d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,12 +54,21 @@ If upgrading from v2.x, see the [Migration Guide](https://flixopt.github.io/flix ### ✨ Added +- **Faceting and animation support for plots**: All plotting methods now support `facet_by` and `animate_by` parameters for creating subplot grids and animations with multidimensional data (scenarios, periods, etc.) +- **New `select` parameter**: Added to all plotting methods for flexible data selection using single values, lists, slices, and index arrays +- **Heatmap `fill` parameter**: Added `fill` parameter to heatmap plotting methods to control how missing values are filled after reshaping ('ffill' or 'bfill') +- **Dashed line styling**: Area plots now automatically style "mixed" variables (containing both positive and negative values) with dashed lines, while only stacking purely positive or negative variables ### 💥 Breaking Changes ### ♻️ Changed +- **Selection behavior**: Changed default selection behavior in plotting methods - no longer automatically selects first value for non-time dimensions. Use `select` parameter for explicit selection +- **Improved error messages**: Enhanced error messages when using matplotlib engine with multidimensional data, providing clearer guidance on dimension requirements +- Improved `scenario_example.py` +- Improved error handling in `plot_heatmap()` method for better dimension validation ### 🗑️ Deprecated +- **`indexer` parameter**: The `indexer` parameter in all plotting methods is deprecated in favor of the new `select` parameter with enhanced functionality ### 🔥 Removed @@ -75,6 +84,7 @@ If upgrading from v2.x, see the [Migration Guide](https://flixopt.github.io/flix - Improve docs visually with new Material theme and enhanced styling ### 👷 Development +- Renamed `_apply_indexer_to_data()` to `_apply_selection_to_data()` for consistency with new API ### 🚧 Known Issues diff --git a/examples/04_Scenarios/scenario_example.py b/examples/04_Scenarios/scenario_example.py index 6aa3c0c89..834e55782 100644 --- a/examples/04_Scenarios/scenario_example.py +++ b/examples/04_Scenarios/scenario_example.py @@ -8,20 +8,80 @@ import flixopt as fx if __name__ == '__main__': - # Create datetime array starting from '2020-01-01' for the given time period - timesteps = pd.date_range('2020-01-01', periods=9, freq='h') + # Create datetime array starting from '2020-01-01' for one week + timesteps = pd.date_range('2020-01-01', periods=24 * 7, freq='h') scenarios = pd.Index(['Base Case', 'High Demand']) periods = pd.Index([2020, 2021, 2022]) # --- Create Time Series Data --- - # Heat demand profile (e.g., kW) over time and corresponding power prices - heat_demand_per_h = pd.DataFrame( - {'Base Case': [30, 0, 90, 110, 110, 20, 20, 20, 20], 'High Demand': [30, 0, 100, 118, 125, 20, 20, 20, 20]}, - index=timesteps, + # Realistic daily patterns: morning/evening peaks, night/midday lows + np.random.seed(42) + n_hours = len(timesteps) + + # Heat demand: 24-hour patterns (kW) for Base Case and High Demand scenarios + base_daily_pattern = np.array( + [22, 20, 18, 18, 20, 25, 40, 70, 95, 110, 85, 65, 60, 58, 62, 68, 75, 88, 105, 125, 130, 122, 95, 35] + ) + high_daily_pattern = np.array( + [28, 25, 22, 22, 24, 30, 52, 88, 118, 135, 105, 80, 75, 72, 75, 82, 92, 108, 128, 148, 155, 145, 115, 48] + ) + + # Tile and add variation + base_demand = np.tile(base_daily_pattern, n_hours // 24 + 1)[:n_hours] * ( + 1 + np.random.uniform(-0.05, 0.05, n_hours) + ) + high_demand = np.tile(high_daily_pattern, n_hours // 24 + 1)[:n_hours] * ( + 1 + np.random.uniform(-0.07, 0.07, n_hours) + ) + + heat_demand_per_h = pd.DataFrame({'Base Case': base_demand, 'High Demand': high_demand}, index=timesteps) + + # Power prices: hourly factors (night low, peak high) and period escalation (2020-2022) + hourly_price_factors = np.array( + [ + 0.70, + 0.65, + 0.62, + 0.60, + 0.62, + 0.70, + 0.95, + 1.15, + 1.30, + 1.25, + 1.10, + 1.00, + 0.95, + 0.90, + 0.88, + 0.92, + 1.00, + 1.10, + 1.25, + 1.40, + 1.35, + 1.20, + 0.95, + 0.80, + ] ) - power_prices = np.array([0.08, 0.09, 0.10]) + period_base_prices = np.array([0.075, 0.095, 0.135]) # €/kWh for 2020, 2021, 2022 - flow_system = fx.FlowSystem(timesteps=timesteps, periods=periods, scenarios=scenarios, weights=np.array([0.5, 0.6])) + price_series = np.zeros((n_hours, 3)) + for period_idx, base_price in enumerate(period_base_prices): + price_series[:, period_idx] = ( + np.tile(hourly_price_factors, n_hours // 24 + 1)[:n_hours] + * base_price + * (1 + np.random.uniform(-0.03, 0.03, n_hours)) + ) + + power_prices = price_series.mean(axis=0) + + # Scenario weights: probability of each scenario occurring + # Base Case: 60% probability, High Demand: 40% probability + scenario_weights = np.array([0.6, 0.4]) + + flow_system = fx.FlowSystem(timesteps=timesteps, periods=periods, scenarios=scenarios, weights=scenario_weights) # --- Define Energy Buses --- # These represent nodes, where the used medias are balanced (electricity, heat, and gas) @@ -35,22 +95,24 @@ description='Kosten', is_standard=True, # standard effect: no explicit value needed for costs is_objective=True, # Minimizing costs as the optimization objective - share_from_temporal={'CO2': 0.2}, + share_from_temporal={'CO2': 0.2}, # Carbon price: 0.2 €/kg CO2 (e.g., carbon tax) ) - # CO2 emissions effect with an associated cost impact + # CO2 emissions effect with constraint + # Maximum of 1000 kg CO2/hour represents a regulatory or voluntary emissions limit CO2 = fx.Effect( label='CO2', unit='kg', description='CO2_e-Emissionen', - maximum_per_hour=1000, # Max CO2 emissions per hour + maximum_per_hour=1000, # Regulatory emissions limit: 1000 kg CO2/hour ) # --- Define Flow System Components --- # Boiler: Converts fuel (gas) into thermal energy (heat) + # Modern condensing gas boiler with realistic efficiency boiler = fx.linear_converters.Boiler( label='Boiler', - eta=0.5, + eta=0.92, # Realistic efficiency for modern condensing gas boiler (92%) Q_th=fx.Flow( label='Q_th', bus='Fernwärme', @@ -63,27 +125,28 @@ ) # Combined Heat and Power (CHP): Generates both electricity and heat from fuel + # Modern CHP unit with realistic efficiencies (total efficiency ~88%) chp = fx.linear_converters.CHP( label='CHP', - eta_th=0.5, - eta_el=0.4, + eta_th=0.48, # Realistic thermal efficiency (48%) + eta_el=0.40, # Realistic electrical efficiency (40%) P_el=fx.Flow('P_el', bus='Strom', size=60, relative_minimum=5 / 60, on_off_parameters=fx.OnOffParameters()), Q_th=fx.Flow('Q_th', bus='Fernwärme'), Q_fu=fx.Flow('Q_fu', bus='Gas'), ) - # Storage: Energy storage system with charging and discharging capabilities + # Storage: Thermal energy storage system with charging and discharging capabilities + # Realistic thermal storage parameters (e.g., insulated hot water tank) storage = fx.Storage( label='Storage', charging=fx.Flow('Q_th_load', bus='Fernwärme', size=1000), discharging=fx.Flow('Q_th_unload', bus='Fernwärme', size=1000), capacity_in_flow_hours=fx.InvestParameters(effects_of_investment=20, fixed_size=30, mandatory=True), initial_charge_state=0, # Initial storage state: empty - relative_maximum_charge_state=np.array([80, 70, 80, 80, 80, 80, 80, 80, 80]) * 0.01, - relative_maximum_final_charge_state=0.8, - eta_charge=0.9, - eta_discharge=1, # Efficiency factors for charging/discharging - relative_loss_per_hour=0.08, # 8% loss per hour. Absolute loss depends on current charge state + relative_maximum_final_charge_state=np.array([0.8, 0.5, 0.1]), + eta_charge=0.95, # Realistic charging efficiency (~95%) + eta_discharge=0.98, # Realistic discharging efficiency (~98%) + relative_loss_per_hour=np.array([0.008, 0.015]), # Realistic thermal losses: 0.8-1.5% per hour prevent_simultaneous_charge_and_discharge=True, # Prevent charging and discharging at the same time ) @@ -94,10 +157,22 @@ ) # Gas Source: Gas tariff source with associated costs and CO2 emissions + # Realistic gas prices varying by period (reflecting 2020-2022 energy crisis) + # 2020: 0.04 €/kWh, 2021: 0.06 €/kWh, 2022: 0.11 €/kWh + gas_prices_per_period = np.array([0.04, 0.06, 0.11]) + + # CO2 emissions factor for natural gas: ~0.202 kg CO2/kWh (realistic value) + gas_co2_emissions = 0.202 + gas_source = fx.Source( label='Gastarif', outputs=[ - fx.Flow(label='Q_Gas', bus='Gas', size=1000, effects_per_flow_hour={costs.label: 0.04, CO2.label: 0.3}) + fx.Flow( + label='Q_Gas', + bus='Gas', + size=1000, + effects_per_flow_hour={costs.label: gas_prices_per_period, CO2.label: gas_co2_emissions}, + ) ], ) @@ -124,21 +199,14 @@ calculation.results.plot_heatmap('CHP(Q_th)|flow_rate') # --- Analyze Results --- - calculation.results['Fernwärme'].plot_node_balance_pie() calculation.results['Fernwärme'].plot_node_balance(mode='stacked_bar') - calculation.results['Storage'].plot_node_balance() calculation.results.plot_heatmap('CHP(Q_th)|flow_rate') + calculation.results['Storage'].plot_charge_state() + calculation.results['Fernwärme'].plot_node_balance_pie(select={'period': 2020, 'scenario': 'Base Case'}) # Convert the results for the storage component to a dataframe and display df = calculation.results['Storage'].node_balance_with_charge_state() print(df) - # Plot charge state using matplotlib - fig, ax = calculation.results['Storage'].plot_charge_state(engine='matplotlib') - # Customize the plot further if needed - ax.set_title('Storage Charge State Over Time') - # Or save the figure - # fig.savefig('storage_charge_state.png') - # Save results to file for later usage calculation.results.to_file() diff --git a/flixopt/plotting.py b/flixopt/plotting.py index 218a8ab0e..bd1f3c2c4 100644 --- a/flixopt/plotting.py +++ b/flixopt/plotting.py @@ -39,6 +39,7 @@ import plotly.express as px import plotly.graph_objects as go import plotly.offline +import xarray as xr from plotly.exceptions import PlotlyError if TYPE_CHECKING: @@ -326,143 +327,269 @@ def process_colors( def with_plotly( - data: pd.DataFrame, + data: pd.DataFrame | xr.DataArray | xr.Dataset, mode: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar', colors: ColorType = 'viridis', title: str = '', ylabel: str = '', xlabel: str = 'Time in h', fig: go.Figure | None = None, + facet_by: str | list[str] | None = None, + animate_by: str | None = None, + facet_cols: int = 3, + shared_yaxes: bool = True, + shared_xaxes: bool = True, ) -> go.Figure: """ - Plot a DataFrame with Plotly, using either stacked bars or stepped lines. + Plot data with Plotly using facets (subplots) and/or animation for multidimensional data. + + Uses Plotly Express for convenient faceting and animation with automatic styling. + For simple plots without faceting, can optionally add to an existing figure. Args: - data: A DataFrame containing the data to plot, where the index represents time (e.g., hours), - and each column represents a separate data series. - mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for stepped lines, - or 'area' for stacked area charts. - colors: Color specification, can be: - - A string with a colorscale name (e.g., 'viridis', 'plasma') - - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'}) - title: The title of the plot. + data: A DataFrame or xarray DataArray/Dataset to plot. + mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for lines, + 'area' for stacked area charts, or 'grouped_bar' for grouped bar charts. + colors: Color specification (colormap, list, or dict mapping labels to colors). + title: The main title of the plot. ylabel: The label for the y-axis. xlabel: The label for the x-axis. - fig: A Plotly figure object to plot on. If not provided, a new figure will be created. + fig: A Plotly figure object to plot on (only for simple plots without faceting). + If not provided, a new figure will be created. + facet_by: Dimension(s) to create facets for. Creates a subplot grid. + Can be a single dimension name or list of dimensions (max 2 for facet_row and facet_col). + If the dimension doesn't exist in the data, it will be silently ignored. + animate_by: Dimension to animate over. Creates animation frames. + If the dimension doesn't exist in the data, it will be silently ignored. + facet_cols: Number of columns in the facet grid (used when facet_by is single dimension). + shared_yaxes: Whether subplots share y-axes. + shared_xaxes: Whether subplots share x-axes. Returns: - A Plotly figure object containing the generated plot. + A Plotly figure object containing the faceted/animated plot. + + Examples: + Simple plot: + + ```python + fig = with_plotly(df, mode='area', title='Energy Mix') + ``` + + Facet by scenario: + + ```python + fig = with_plotly(ds, facet_by='scenario', facet_cols=2) + ``` + + Animate by period: + + ```python + fig = with_plotly(ds, animate_by='period') + ``` + + Facet and animate: + + ```python + fig = with_plotly(ds, facet_by='scenario', animate_by='period') + ``` """ if mode not in ('stacked_bar', 'line', 'area', 'grouped_bar'): raise ValueError(f"'mode' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {mode!r}") - if data.empty: - return go.Figure() - processed_colors = ColorProcessor(engine='plotly').process_colors(colors, list(data.columns)) - - fig = fig if fig is not None else go.Figure() + # Handle empty data + if isinstance(data, pd.DataFrame) and data.empty: + return go.Figure() + elif isinstance(data, xr.DataArray) and data.size == 0: + return go.Figure() + elif isinstance(data, xr.Dataset) and len(data.data_vars) == 0: + return go.Figure() - if mode == 'stacked_bar': - for i, column in enumerate(data.columns): - fig.add_trace( - go.Bar( - x=data.index, - y=data[column], - name=column, - marker=dict( - color=processed_colors[i], line=dict(width=0, color='rgba(0,0,0,0)') - ), # Transparent line with 0 width + # Warn if fig parameter is used with faceting + if fig is not None and (facet_by is not None or animate_by is not None): + logger.warning('The fig parameter is ignored when using faceting or animation. Creating a new figure.') + fig = None + + # Convert xarray to long-form DataFrame for Plotly Express + if isinstance(data, (xr.DataArray, xr.Dataset)): + # Convert to long-form (tidy) DataFrame + # Structure: time, variable, value, scenario, period, ... (all dims as columns) + if isinstance(data, xr.Dataset): + # Stack all data variables into long format + df_long = data.to_dataframe().reset_index() + # Melt to get: time, scenario, period, ..., variable, value + id_vars = [dim for dim in data.dims] + value_vars = list(data.data_vars) + df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value') + else: + # DataArray + df_long = data.to_dataframe().reset_index() + if data.name: + df_long = df_long.rename(columns={data.name: 'value'}) + else: + # Unnamed DataArray, find the value column + non_dim_cols = [col for col in df_long.columns if col not in data.dims] + if len(non_dim_cols) != 1: + raise ValueError( + f'Expected exactly one non-dimension column for unnamed DataArray, ' + f'but found {len(non_dim_cols)}: {non_dim_cols}' + ) + value_col = non_dim_cols[0] + df_long = df_long.rename(columns={value_col: 'value'}) + df_long['variable'] = data.name or 'data' + else: + # Already a DataFrame - convert to long format for Plotly Express + df_long = data.reset_index() + if 'time' not in df_long.columns: + # First column is probably time + df_long = df_long.rename(columns={df_long.columns[0]: 'time'}) + # Melt to long format + id_vars = [ + col + for col in df_long.columns + if col in ['time', 'scenario', 'period'] + or col in (facet_by if isinstance(facet_by, list) else [facet_by] if facet_by else []) + ] + value_vars = [col for col in df_long.columns if col not in id_vars] + df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value') + + # Validate facet_by and animate_by dimensions exist in the data + available_dims = [col for col in df_long.columns if col not in ['variable', 'value']] + + # Check facet_by dimensions + if facet_by is not None: + if isinstance(facet_by, str): + if facet_by not in available_dims: + logger.debug( + f"Dimension '{facet_by}' not found in data. Available dimensions: {available_dims}. " + f'Ignoring facet_by parameter.' ) - ) - - fig.update_layout( - barmode='relative', - bargap=0, # No space between bars - bargroupgap=0, # No space between grouped bars + facet_by = None + elif isinstance(facet_by, list): + # Filter out dimensions that don't exist + missing_dims = [dim for dim in facet_by if dim not in available_dims] + facet_by = [dim for dim in facet_by if dim in available_dims] + if missing_dims: + logger.debug( + f'Dimensions {missing_dims} not found in data. Available dimensions: {available_dims}. ' + f'Using only existing dimensions: {facet_by if facet_by else "none"}.' + ) + if len(facet_by) == 0: + facet_by = None + + # Check animate_by dimension + if animate_by is not None and animate_by not in available_dims: + logger.debug( + f"Dimension '{animate_by}' not found in data. Available dimensions: {available_dims}. " + f'Ignoring animate_by parameter.' ) - if mode == 'grouped_bar': - for i, column in enumerate(data.columns): - fig.add_trace(go.Bar(x=data.index, y=data[column], name=column, marker=dict(color=processed_colors[i]))) + animate_by = None + + # Setup faceting parameters for Plotly Express + facet_row = None + facet_col = None + if facet_by: + if isinstance(facet_by, str): + # Single facet dimension - use facet_col with facet_col_wrap + facet_col = facet_by + elif len(facet_by) == 1: + facet_col = facet_by[0] + elif len(facet_by) == 2: + # Two facet dimensions - use facet_row and facet_col + facet_row = facet_by[0] + facet_col = facet_by[1] + else: + raise ValueError(f'facet_by can have at most 2 dimensions, got {len(facet_by)}') + + # Process colors + all_vars = df_long['variable'].unique().tolist() + processed_colors = ColorProcessor(engine='plotly').process_colors(colors, all_vars) + color_discrete_map = {var: color for var, color in zip(all_vars, processed_colors, strict=True)} + + # Create plot using Plotly Express based on mode + common_args = { + 'data_frame': df_long, + 'x': 'time', + 'y': 'value', + 'color': 'variable', + 'facet_row': facet_row, + 'facet_col': facet_col, + 'animation_frame': animate_by, + 'color_discrete_map': color_discrete_map, + 'title': title, + 'labels': {'value': ylabel, 'time': xlabel, 'variable': ''}, + } - fig.update_layout( - barmode='group', - bargap=0.2, # No space between bars - bargroupgap=0, # space between grouped bars - ) + # Add facet_col_wrap for single facet dimension + if facet_col and not facet_row: + common_args['facet_col_wrap'] = facet_cols + + if mode == 'stacked_bar': + fig = px.bar(**common_args) + fig.update_traces(marker_line_width=0) + fig.update_layout(barmode='relative', bargap=0, bargroupgap=0) + elif mode == 'grouped_bar': + fig = px.bar(**common_args) + fig.update_layout(barmode='group', bargap=0.2, bargroupgap=0) elif mode == 'line': - for i, column in enumerate(data.columns): - fig.add_trace( - go.Scatter( - x=data.index, - y=data[column], - mode='lines', - name=column, - line=dict(shape='hv', color=processed_colors[i]), - ) - ) + fig = px.line(**common_args, line_shape='hv') # Stepped lines elif mode == 'area': - data = data.copy() - data[(data > -1e-5) & (data < 1e-5)] = 0 # Preventing issues with plotting - # Split columns into positive, negative, and mixed categories - positive_columns = list(data.columns[(data >= 0).where(~np.isnan(data), True).all()]) - negative_columns = list(data.columns[(data <= 0).where(~np.isnan(data), True).all()]) - negative_columns = [column for column in negative_columns if column not in positive_columns] - mixed_columns = list(set(data.columns) - set(positive_columns + negative_columns)) - - if mixed_columns: - logger.error( - f'Data for plotting stacked lines contains columns with both positive and negative values:' - f' {mixed_columns}. These can not be stacked, and are printed as simple lines' - ) + # Use Plotly Express to create the area plot (preserves animation, legends, faceting) + fig = px.area(**common_args, line_shape='hv') - # Get color mapping for all columns - colors_stacked = {column: processed_colors[i] for i, column in enumerate(data.columns)} - - for column in positive_columns + negative_columns: - fig.add_trace( - go.Scatter( - x=data.index, - y=data[column], - mode='lines', - name=column, - line=dict(shape='hv', color=colors_stacked[column]), - fill='tonexty', - stackgroup='pos' if column in positive_columns else 'neg', - ) - ) + # Classify each variable based on its values + variable_classification = {} + for var in all_vars: + var_data = df_long[df_long['variable'] == var]['value'] + var_data_clean = var_data[(var_data < -1e-5) | (var_data > 1e-5)] - for column in mixed_columns: - fig.add_trace( - go.Scatter( - x=data.index, - y=data[column], - mode='lines', - name=column, - line=dict(shape='hv', color=colors_stacked[column], dash='dash'), + if len(var_data_clean) == 0: + variable_classification[var] = 'zero' + else: + has_pos, has_neg = (var_data_clean > 0).any(), (var_data_clean < 0).any() + variable_classification[var] = ( + 'mixed' if has_pos and has_neg else ('negative' if has_neg else 'positive') ) - ) - # Update layout for better aesthetics + # Log warning for mixed variables + mixed_vars = [v for v, c in variable_classification.items() if c == 'mixed'] + if mixed_vars: + logger.warning(f'Variables with both positive and negative values: {mixed_vars}. Plotted as dashed lines.') + + all_traces = list(fig.data) + for frame in fig.frames: + all_traces.extend(frame.data) + + for trace in all_traces: + cls = variable_classification.get(trace.name, None) + # Only stack positive and negative, not mixed or zero + trace.stackgroup = cls if cls in ('positive', 'negative') else None + + if cls in ('positive', 'negative'): + # Stacked area: add opacity to avoid hiding layers, remove line border + if hasattr(trace, 'line') and trace.line.color: + trace.fillcolor = trace.line.color + trace.line.width = 0 + elif cls == 'mixed': + # Mixed variables: show as dashed line, not stacked + if hasattr(trace, 'line'): + trace.line.width = 2 + trace.line.dash = 'dash' + if hasattr(trace, 'fill'): + trace.fill = None + + # Update layout with basic styling (Plotly Express handles sizing automatically) fig.update_layout( - title=title, - yaxis=dict( - title=ylabel, - showgrid=True, # Enable grid lines on the y-axis - gridcolor='lightgrey', # Customize grid line color - gridwidth=0.5, # Customize grid line width - ), - xaxis=dict( - title=xlabel, - showgrid=True, # Enable grid lines on the x-axis - gridcolor='lightgrey', # Customize grid line color - gridwidth=0.5, # Customize grid line width - ), - plot_bgcolor='rgba(0,0,0,0)', # Transparent background - paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background - font=dict(size=14), # Increase font size for better readability + plot_bgcolor='rgba(0,0,0,0)', + paper_bgcolor='rgba(0,0,0,0)', + font=dict(size=12), ) + # Update axes to share if requested (Plotly Express already handles this, but we can customize) + if not shared_yaxes: + fig.update_yaxes(matches=None) + if not shared_xaxes: + fig.update_xaxes(matches=None) + return fig @@ -562,213 +689,110 @@ def with_matplotlib( return fig, ax -def heat_map_matplotlib( - data: pd.DataFrame, - color_map: str = 'viridis', - title: str = '', - xlabel: str = 'Period', - ylabel: str = 'Step', - figsize: tuple[float, float] = (12, 6), -) -> tuple[plt.Figure, plt.Axes]: - """ - Plots a DataFrame as a heatmap using Matplotlib. The columns of the DataFrame will be displayed on the x-axis, - the index will be displayed on the y-axis, and the values will represent the 'heat' intensity in the plot. - - Args: - data: A DataFrame containing the data to be visualized. The index will be used for the y-axis, and columns will be used for the x-axis. - The values in the DataFrame will be represented as colors in the heatmap. - color_map: The colormap to use for the heatmap. Default is 'viridis'. Matplotlib supports various colormaps like 'plasma', 'inferno', 'cividis', etc. - title: The title of the plot. - xlabel: The label for the x-axis. - ylabel: The label for the y-axis. - figsize: The size of the figure to create. Default is (12, 6), which results in a width of 12 inches and a height of 6 inches. - - Returns: - A tuple containing the Matplotlib `Figure` and `Axes` objects. The `Figure` contains the overall plot, while the `Axes` is the area - where the heatmap is drawn. These can be used for further customization or saving the plot to a file. - - Notes: - - The y-axis is flipped so that the first row of the DataFrame is displayed at the top of the plot. - - The color scale is normalized based on the minimum and maximum values in the DataFrame. - - The x-axis labels (periods) are placed at the top of the plot. - - The colorbar is added horizontally at the bottom of the plot, with a label. - """ - - # Get the min and max values for color normalization - color_bar_min, color_bar_max = data.min().min(), data.max().max() - - # Create the heatmap plot - fig, ax = plt.subplots(figsize=figsize) - ax.pcolormesh(data.values, cmap=color_map, shading='auto') - ax.invert_yaxis() # Flip the y-axis to start at the top - - # Adjust ticks and labels for x and y axes - ax.set_xticks(np.arange(len(data.columns)) + 0.5) - ax.set_xticklabels(data.columns, ha='center') - ax.set_yticks(np.arange(len(data.index)) + 0.5) - ax.set_yticklabels(data.index, va='center') - - # Add labels to the axes - ax.set_xlabel(xlabel, ha='center') - ax.set_ylabel(ylabel, va='center') - ax.set_title(title) - - # Position x-axis labels at the top - ax.xaxis.set_label_position('top') - ax.xaxis.set_ticks_position('top') - - # Add the colorbar - sm1 = plt.cm.ScalarMappable(cmap=color_map, norm=plt.Normalize(vmin=color_bar_min, vmax=color_bar_max)) - sm1.set_array([]) - fig.colorbar(sm1, ax=ax, pad=0.12, aspect=15, fraction=0.2, orientation='horizontal') - - fig.tight_layout() - - return fig, ax - - -def heat_map_plotly( - data: pd.DataFrame, - color_map: str = 'viridis', - title: str = '', - xlabel: str = 'Period', - ylabel: str = 'Step', - categorical_labels: bool = True, -) -> go.Figure: - """ - Plots a DataFrame as a heatmap using Plotly. The columns of the DataFrame will be mapped to the x-axis, - and the index will be displayed on the y-axis. The values in the DataFrame will represent the 'heat' in the plot. - - Args: - data: A DataFrame with the data to be visualized. The index will be used for the y-axis, and columns will be used for the x-axis. - The values in the DataFrame will be represented as colors in the heatmap. - color_map: The color scale to use for the heatmap. Default is 'viridis'. Plotly supports various color scales like 'Cividis', 'Inferno', etc. - title: The title of the heatmap. Default is an empty string. - xlabel: The label for the x-axis. Default is 'Period'. - ylabel: The label for the y-axis. Default is 'Step'. - categorical_labels: If True, the x and y axes are treated as categorical data (i.e., the index and columns will not be interpreted as continuous data). - Default is True. If False, the axes are treated as continuous, which may be useful for time series or numeric data. - - Returns: - A Plotly figure object containing the heatmap. This can be further customized and saved - or displayed using `fig.show()`. - - Notes: - The color bar is automatically scaled to the minimum and maximum values in the data. - The y-axis is reversed to display the first row at the top. +def reshape_data_for_heatmap( + data: xr.DataArray, + reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] + | Literal['auto'] + | None = 'auto', + facet_by: str | list[str] | None = None, + animate_by: str | None = None, + fill: Literal['ffill', 'bfill'] | None = 'ffill', +) -> xr.DataArray: """ + Reshape data for heatmap visualization, handling time dimension intelligently. - color_bar_min, color_bar_max = data.min().min(), data.max().max() # Min and max values for color scaling - # Define the figure - fig = go.Figure( - data=go.Heatmap( - z=data.values, - x=data.columns, - y=data.index, - colorscale=color_map, - zmin=color_bar_min, - zmax=color_bar_max, - colorbar=dict( - title=dict(text='Color Bar Label', side='right'), - orientation='h', - xref='container', - yref='container', - len=0.8, # Color bar length relative to plot - x=0.5, - y=0.1, - ), - ) - ) - - # Set axis labels and style - fig.update_layout( - title=title, - xaxis=dict(title=xlabel, side='top', type='category' if categorical_labels else None), - yaxis=dict(title=ylabel, autorange='reversed', type='category' if categorical_labels else None), - ) - - return fig - - -def reshape_to_2d(data_1d: np.ndarray, nr_of_steps_per_column: int) -> np.ndarray: - """ - Reshapes a 1D numpy array into a 2D array suitable for plotting as a colormap. + This function decides whether to reshape the 'time' dimension based on the reshape_time parameter: + - 'auto': Automatically reshapes if only 'time' dimension would remain for heatmap + - Tuple: Explicitly reshapes time with specified parameters + - None: No reshaping (returns data as-is) - The reshaped array will have the number of rows corresponding to the steps per column - (e.g., 24 hours per day) and columns representing time periods (e.g., days or months). + All non-time dimensions are preserved during reshaping. Args: - data_1d: A 1D numpy array with the data to reshape. - nr_of_steps_per_column: The number of steps (rows) per column in the resulting 2D array. For example, - this could be 24 (for hours) or 31 (for days in a month). + data: DataArray to reshape for heatmap visualization. + reshape_time: Reshaping configuration: + - 'auto' (default): Auto-reshape if needed based on facet_by/animate_by + - Tuple (timeframes, timesteps_per_frame): Explicit time reshaping + - None: No reshaping + facet_by: Dimension(s) used for faceting (used in 'auto' decision). + animate_by: Dimension used for animation (used in 'auto' decision). + fill: Method to fill missing values: 'ffill' or 'bfill'. Default is 'ffill'. Returns: - The reshaped 2D array. Each internal array corresponds to one column, with the specified number of steps. - Each column might represents a time period (e.g., day, month, etc.). - """ - - # Step 1: Ensure the input is a 1D array. - if data_1d.ndim != 1: - raise ValueError('Input must be a 1D array') - - # Step 2: Convert data to float type to allow NaN padding - if data_1d.dtype != np.float64: - data_1d = data_1d.astype(np.float64) + Reshaped DataArray. If time reshaping is applied, 'time' dimension is replaced + by 'timestep' and 'timeframe'. All other dimensions are preserved. - # Step 3: Calculate the number of columns required - total_steps = len(data_1d) - cols = len(data_1d) // nr_of_steps_per_column # Base number of columns - - # If there's a remainder, add an extra column to hold the remaining values - if total_steps % nr_of_steps_per_column != 0: - cols += 1 + Examples: + Auto-reshaping: - # Step 4: Pad the 1D data to match the required number of rows and columns - padded_data = np.pad( - data_1d, (0, cols * nr_of_steps_per_column - total_steps), mode='constant', constant_values=np.nan - ) + ```python + # Will auto-reshape because only 'time' remains after faceting/animation + data = reshape_data_for_heatmap(data, reshape_time='auto', facet_by='scenario', animate_by='period') + ``` - # Step 5: Reshape the padded data into a 2D array - data_2d = padded_data.reshape(cols, nr_of_steps_per_column) + Explicit reshaping: - return data_2d.T + ```python + # Explicitly reshape to daily pattern + data = reshape_data_for_heatmap(data, reshape_time=('D', 'h')) + ``` + No reshaping: -def heat_map_data_from_df( - df: pd.DataFrame, - periods: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], - steps_per_period: Literal['W', 'D', 'h', '15min', 'min'], - fill: Literal['ffill', 'bfill'] | None = None, -) -> pd.DataFrame: + ```python + # Keep data as-is + data = reshape_data_for_heatmap(data, reshape_time=None) + ``` """ - Reshapes a DataFrame with a DateTime index into a 2D array for heatmap plotting, - based on a specified sample rate. - Only specific combinations of `periods` and `steps_per_period` are supported; invalid combinations raise an assertion. - - Args: - df: A DataFrame with a DateTime index containing the data to reshape. - periods: The time interval of each period (columns of the heatmap), - such as 'YS' (year start), 'W' (weekly), 'D' (daily), 'h' (hourly) etc. - steps_per_period: The time interval within each period (rows in the heatmap), - such as 'YS' (year start), 'W' (weekly), 'D' (daily), 'h' (hourly) etc. - fill: Method to fill missing values: 'ffill' for forward fill or 'bfill' for backward fill. + # If no time dimension, return data as-is + if 'time' not in data.dims: + return data + + # Handle None (disabled) - return data as-is + if reshape_time is None: + return data + + # Determine timeframes and timesteps_per_frame based on reshape_time parameter + if reshape_time == 'auto': + # Check if we need automatic time reshaping + facet_dims_used = [] + if facet_by: + facet_dims_used = [facet_by] if isinstance(facet_by, str) else list(facet_by) + if animate_by: + facet_dims_used.append(animate_by) + + # Get dimensions that would remain for heatmap + potential_heatmap_dims = [dim for dim in data.dims if dim not in facet_dims_used] + + # Auto-reshape if only 'time' dimension remains + if len(potential_heatmap_dims) == 1 and potential_heatmap_dims[0] == 'time': + logger.debug( + "Auto-applying time reshaping: Only 'time' dimension remains after faceting/animation. " + "Using default timeframes='D' and timesteps_per_frame='h'. " + "To customize, use reshape_time=('D', 'h') or disable with reshape_time=None." + ) + timeframes, timesteps_per_frame = 'D', 'h' + else: + # No reshaping needed + return data + elif isinstance(reshape_time, tuple): + # Explicit reshaping + timeframes, timesteps_per_frame = reshape_time + else: + raise ValueError(f"reshape_time must be 'auto', a tuple like ('D', 'h'), or None. Got: {reshape_time}") - Returns: - A DataFrame suitable for heatmap plotting, with rows representing steps within each period - and columns representing each period. - """ - assert pd.api.types.is_datetime64_any_dtype(df.index), ( - 'The index of the DataFrame must be datetime to transform it properly for a heatmap plot' - ) + # Validate that time is datetime + if not np.issubdtype(data.coords['time'].dtype, np.datetime64): + raise ValueError(f'Time dimension must be datetime-based, got {data.coords["time"].dtype}') - # Define formats for different combinations of `periods` and `steps_per_period` + # Define formats for different combinations formats = { ('YS', 'W'): ('%Y', '%W'), ('YS', 'D'): ('%Y', '%j'), # day of year ('YS', 'h'): ('%Y', '%j %H:00'), ('MS', 'D'): ('%Y-%m', '%d'), # day of month ('MS', 'h'): ('%Y-%m', '%d %H:00'), - ('W', 'D'): ('%Y-w%W', '%w_%A'), # week and day of week (with prefix for proper sorting) + ('W', 'D'): ('%Y-w%W', '%w_%A'), # week and day of week ('W', 'h'): ('%Y-w%W', '%w_%A %H:00'), ('D', 'h'): ('%Y-%m-%d', '%H:00'), # Day and hour ('D', '15min'): ('%Y-%m-%d', '%H:%M'), # Day and minute @@ -776,43 +800,64 @@ def heat_map_data_from_df( ('h', 'min'): ('%Y-%m-%d %H:00', '%M'), # minute of hour } - if df.empty: - raise ValueError('DataFrame is empty.') - diffs = df.index.to_series().diff().dropna() - minimum_time_diff_in_min = diffs.min().total_seconds() / 60 - time_intervals = {'min': 1, '15min': 15, 'h': 60, 'D': 24 * 60, 'W': 7 * 24 * 60} - if time_intervals[steps_per_period] > minimum_time_diff_in_min: - logger.error( - f'To compute the heatmap, the data was aggregated from {minimum_time_diff_in_min:.2f} min to ' - f'{time_intervals[steps_per_period]:.2f} min. Mean values are displayed.' - ) - - # Select the format based on the `periods` and `steps_per_period` combination - format_pair = (periods, steps_per_period) + format_pair = (timeframes, timesteps_per_frame) if format_pair not in formats: raise ValueError(f'{format_pair} is not a valid format. Choose from {list(formats.keys())}') period_format, step_format = formats[format_pair] - df = df.sort_index() # Ensure DataFrame is sorted by time index + # Check if resampling is needed + if data.sizes['time'] > 1: + # Use NumPy for more efficient timedelta computation + time_values = data.coords['time'].values # Already numpy datetime64[ns] + # Calculate differences and convert to minutes + time_diffs = np.diff(time_values).astype('timedelta64[s]').astype(float) / 60.0 + if time_diffs.size > 0: + min_time_diff_min = np.nanmin(time_diffs) + time_intervals = {'min': 1, '15min': 15, 'h': 60, 'D': 24 * 60, 'W': 7 * 24 * 60} + if time_intervals[timesteps_per_frame] > min_time_diff_min: + logger.warning( + f'Resampling data from {min_time_diff_min:.2f} min to ' + f'{time_intervals[timesteps_per_frame]:.2f} min. Mean values are displayed.' + ) - resampled_data = df.resample(steps_per_period).mean() # Resample and fill any gaps with NaN + # Resample along time dimension + resampled = data.resample(time=timesteps_per_frame).mean() - if fill == 'ffill': # Apply fill method if specified - resampled_data = resampled_data.ffill() + # Apply fill if specified + if fill == 'ffill': + resampled = resampled.ffill(dim='time') elif fill == 'bfill': - resampled_data = resampled_data.bfill() + resampled = resampled.bfill(dim='time') + + # Create period and step labels + time_values = pd.to_datetime(resampled.coords['time'].values) + period_labels = time_values.strftime(period_format) + step_labels = time_values.strftime(step_format) + + # Handle special case for weekly day format + if '%w_%A' in step_format: + step_labels = pd.Series(step_labels).replace('0_Sunday', '7_Sunday').values + + # Add period and step as coordinates + resampled = resampled.assign_coords( + { + 'timeframe': ('time', period_labels), + 'timestep': ('time', step_labels), + } + ) - resampled_data['period'] = resampled_data.index.strftime(period_format) - resampled_data['step'] = resampled_data.index.strftime(step_format) - if '%w_%A' in step_format: # Shift index of strings to ensure proper sorting - resampled_data['step'] = resampled_data['step'].apply( - lambda x: x.replace('0_Sunday', '7_Sunday') if '0_Sunday' in x else x - ) + # Convert to multi-index and unstack + resampled = resampled.set_index(time=['timeframe', 'timestep']) + result = resampled.unstack('time') - # Pivot the table so periods are columns and steps are indices - df_pivoted = resampled_data.pivot(columns='period', index='step', values=df.columns[0]) + # Ensure timestep and timeframe come first in dimension order + # Get other dimensions + other_dims = [d for d in result.dims if d not in ['timestep', 'timeframe']] - return df_pivoted + # Reorder: timestep, timeframe, then other dimensions + result = result.transpose('timestep', 'timeframe', *other_dims) + + return result def plot_network( @@ -1413,6 +1458,311 @@ def preprocess_series(series: pd.Series): return fig, axes +def heatmap_with_plotly( + data: xr.DataArray, + colors: ColorType = 'viridis', + title: str = '', + facet_by: str | list[str] | None = None, + animate_by: str | None = None, + facet_cols: int = 3, + reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] + | Literal['auto'] + | None = 'auto', + fill: Literal['ffill', 'bfill'] | None = 'ffill', +) -> go.Figure: + """ + Plot a heatmap visualization using Plotly's imshow with faceting and animation support. + + This function creates heatmap visualizations from xarray DataArrays, supporting + multi-dimensional data through faceting (subplots) and animation. It automatically + handles dimension reduction and data reshaping for optimal heatmap display. + + Automatic Time Reshaping: + If only the 'time' dimension remains after faceting/animation (making the data 1D), + the function automatically reshapes time into a 2D format using default values + (timeframes='D', timesteps_per_frame='h'). This creates a daily pattern heatmap + showing hours vs days. + + Args: + data: An xarray DataArray containing the data to visualize. Should have at least + 2 dimensions, or a 'time' dimension that can be reshaped into 2D. + colors: Color specification (colormap name, list, or dict). Common options: + 'viridis', 'plasma', 'RdBu', 'portland'. + title: The main title of the heatmap. + facet_by: Dimension to create facets for. Creates a subplot grid. + Can be a single dimension name or list (only first dimension used). + Note: px.imshow only supports single-dimension faceting. + If the dimension doesn't exist in the data, it will be silently ignored. + animate_by: Dimension to animate over. Creates animation frames. + If the dimension doesn't exist in the data, it will be silently ignored. + facet_cols: Number of columns in the facet grid (used with facet_by). + reshape_time: Time reshaping configuration: + - 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension remains + - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours) + - None: Disable time reshaping (will error if only 1D time data) + fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'. + + Returns: + A Plotly figure object containing the heatmap visualization. + + Examples: + Simple heatmap: + + ```python + fig = heatmap_with_plotly(data_array, colors='RdBu', title='Temperature Map') + ``` + + Facet by scenario: + + ```python + fig = heatmap_with_plotly(data_array, facet_by='scenario', facet_cols=2) + ``` + + Animate by period: + + ```python + fig = heatmap_with_plotly(data_array, animate_by='period') + ``` + + Automatic time reshaping (when only time dimension remains): + + ```python + # Data with dims ['time', 'scenario', 'period'] + # After faceting and animation, only 'time' remains -> auto-reshapes to (timestep, timeframe) + fig = heatmap_with_plotly(data_array, facet_by='scenario', animate_by='period') + ``` + + Explicit time reshaping: + + ```python + fig = heatmap_with_plotly(data_array, facet_by='scenario', animate_by='period', reshape_time=('W', 'D')) + ``` + """ + # Handle empty data + if data.size == 0: + return go.Figure() + + # Apply time reshaping using the new unified function + data = reshape_data_for_heatmap( + data, reshape_time=reshape_time, facet_by=facet_by, animate_by=animate_by, fill=fill + ) + + # Get available dimensions + available_dims = list(data.dims) + + # Validate and filter facet_by dimensions + if facet_by is not None: + if isinstance(facet_by, str): + if facet_by not in available_dims: + logger.debug( + f"Dimension '{facet_by}' not found in data. Available dimensions: {available_dims}. " + f'Ignoring facet_by parameter.' + ) + facet_by = None + elif isinstance(facet_by, list): + missing_dims = [dim for dim in facet_by if dim not in available_dims] + facet_by = [dim for dim in facet_by if dim in available_dims] + if missing_dims: + logger.debug( + f'Dimensions {missing_dims} not found in data. Available dimensions: {available_dims}. ' + f'Using only existing dimensions: {facet_by if facet_by else "none"}.' + ) + if len(facet_by) == 0: + facet_by = None + + # Validate animate_by dimension + if animate_by is not None and animate_by not in available_dims: + logger.debug( + f"Dimension '{animate_by}' not found in data. Available dimensions: {available_dims}. " + f'Ignoring animate_by parameter.' + ) + animate_by = None + + # Determine which dimensions are used for faceting/animation + facet_dims = [] + if facet_by: + facet_dims = [facet_by] if isinstance(facet_by, str) else facet_by + if animate_by: + facet_dims.append(animate_by) + + # Get remaining dimensions for the heatmap itself + heatmap_dims = [dim for dim in available_dims if dim not in facet_dims] + + if len(heatmap_dims) < 2: + # Need at least 2 dimensions for a heatmap + logger.error( + f'Heatmap requires at least 2 dimensions for rows and columns. ' + f'After faceting/animation, only {len(heatmap_dims)} dimension(s) remain: {heatmap_dims}' + ) + return go.Figure() + + # Setup faceting parameters for Plotly Express + # Note: px.imshow only supports facet_col, not facet_row + facet_col_param = None + if facet_by: + if isinstance(facet_by, str): + facet_col_param = facet_by + elif len(facet_by) == 1: + facet_col_param = facet_by[0] + elif len(facet_by) >= 2: + # px.imshow doesn't support facet_row, so we can only facet by one dimension + # Use the first dimension and warn about the rest + facet_col_param = facet_by[0] + logger.warning( + f'px.imshow only supports faceting by a single dimension. ' + f'Using {facet_by[0]} for faceting. Dimensions {facet_by[1:]} will be ignored. ' + f'Consider using animate_by for additional dimensions.' + ) + + # Create the imshow plot - px.imshow can work directly with xarray DataArrays + common_args = { + 'img': data, + 'color_continuous_scale': colors if isinstance(colors, str) else 'viridis', + 'title': title, + } + + # Add faceting if specified + if facet_col_param: + common_args['facet_col'] = facet_col_param + if facet_cols: + common_args['facet_col_wrap'] = facet_cols + + # Add animation if specified + if animate_by: + common_args['animation_frame'] = animate_by + + try: + fig = px.imshow(**common_args) + except Exception as e: + logger.error(f'Error creating imshow plot: {e}. Falling back to basic heatmap.') + # Fallback: create a simple heatmap without faceting + fig = px.imshow( + data.values, + color_continuous_scale=colors if isinstance(colors, str) else 'viridis', + title=title, + ) + + # Update layout with basic styling + fig.update_layout( + plot_bgcolor='rgba(0,0,0,0)', + paper_bgcolor='rgba(0,0,0,0)', + font=dict(size=12), + ) + + return fig + + +def heatmap_with_matplotlib( + data: xr.DataArray, + colors: ColorType = 'viridis', + title: str = '', + figsize: tuple[float, float] = (12, 6), + fig: plt.Figure | None = None, + ax: plt.Axes | None = None, + reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] + | Literal['auto'] + | None = 'auto', + fill: Literal['ffill', 'bfill'] | None = 'ffill', +) -> tuple[plt.Figure, plt.Axes]: + """ + Plot a heatmap visualization using Matplotlib's imshow. + + This function creates a basic 2D heatmap from an xarray DataArray using matplotlib's + imshow function. For multi-dimensional data, only the first two dimensions are used. + + Args: + data: An xarray DataArray containing the data to visualize. Should have at least + 2 dimensions. If more than 2 dimensions exist, additional dimensions will + be reduced by taking the first slice. + colors: Color specification. Should be a colormap name (e.g., 'viridis', 'RdBu'). + title: The title of the heatmap. + figsize: The size of the figure (width, height) in inches. + fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created. + ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created. + reshape_time: Time reshaping configuration: + - 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension + - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours) + - None: Disable time reshaping + fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'. + + Returns: + A tuple containing the Matplotlib figure and axes objects used for the plot. + + Notes: + - Matplotlib backend doesn't support faceting or animation. Use plotly engine for those features. + - The y-axis is automatically inverted to display data with origin at top-left. + - A colorbar is added to show the value scale. + + Examples: + ```python + fig, ax = heatmap_with_matplotlib(data_array, colors='RdBu', title='Temperature') + plt.savefig('heatmap.png') + ``` + + Time reshaping: + + ```python + fig, ax = heatmap_with_matplotlib(data_array, reshape_time=('D', 'h')) + ``` + """ + # Handle empty data + if data.size == 0: + if fig is None or ax is None: + fig, ax = plt.subplots(figsize=figsize) + return fig, ax + + # Apply time reshaping using the new unified function + # Matplotlib doesn't support faceting/animation, so we pass None for those + data = reshape_data_for_heatmap(data, reshape_time=reshape_time, facet_by=None, animate_by=None, fill=fill) + + # Create figure and axes if not provided + if fig is None or ax is None: + fig, ax = plt.subplots(figsize=figsize) + + # Extract data values + # If data has more than 2 dimensions, we need to reduce it + if isinstance(data, xr.DataArray): + # Get the first 2 dimensions + dims = list(data.dims) + if len(dims) > 2: + logger.warning( + f'Data has {len(dims)} dimensions: {dims}. ' + f'Only the first 2 will be used for the heatmap. ' + f'Use the plotly engine for faceting/animation support.' + ) + # Select only the first 2 dimensions by taking first slice of others + selection = {dim: 0 for dim in dims[2:]} + data = data.isel(selection) + + values = data.values + x_labels = data.dims[1] if len(data.dims) > 1 else 'x' + y_labels = data.dims[0] if len(data.dims) > 0 else 'y' + else: + values = data + x_labels = 'x' + y_labels = 'y' + + # Process colormap + cmap = colors if isinstance(colors, str) else 'viridis' + + # Create the heatmap using imshow + im = ax.imshow(values, cmap=cmap, aspect='auto', origin='upper') + + # Add colorbar + cbar = plt.colorbar(im, ax=ax, orientation='horizontal', pad=0.1, aspect=15, fraction=0.05) + cbar.set_label('Value') + + # Set labels and title + ax.set_xlabel(str(x_labels).capitalize()) + ax.set_ylabel(str(y_labels).capitalize()) + ax.set_title(title) + + # Apply tight layout + fig.tight_layout() + + return fig, ax + + def export_figure( figure_like: go.Figure | tuple[plt.Figure, plt.Axes], default_path: pathlib.Path, diff --git a/flixopt/results.py b/flixopt/results.py index b55d48744..a58f0dc1e 100644 --- a/flixopt/results.py +++ b/flixopt/results.py @@ -10,7 +10,6 @@ import linopy import numpy as np import pandas as pd -import plotly import xarray as xr import yaml @@ -20,6 +19,7 @@ if TYPE_CHECKING: import matplotlib.pyplot as plt + import plotly import pyvis from .calculation import Calculation, SegmentedCalculation @@ -195,8 +195,8 @@ def __init__( if 'flow_system' in kwargs and flow_system_data is None: flow_system_data = kwargs.pop('flow_system') warnings.warn( - "The 'flow_system' parameter is deprecated. Use 'flow_system_data' instead." - "Acess is now by '.flow_system_data', while '.flow_system' returns the restored FlowSystem.", + "The 'flow_system' parameter is deprecated. Use 'flow_system_data' instead. " + "Access is now via '.flow_system_data', while '.flow_system' returns the restored FlowSystem.", DeprecationWarning, stacklevel=2, ) @@ -687,68 +687,117 @@ def _create_effects_dataset(self, mode: Literal['temporal', 'periodic', 'total'] def plot_heatmap( self, - variable_name: str, - heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] = 'D', - heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] = 'h', - color_map: str = 'portland', + variable_name: str | list[str], save: bool | pathlib.Path = False, show: bool = True, + colors: plotting.ColorType = 'viridis', engine: plotting.PlottingEngine = 'plotly', + select: dict[FlowSystemDimensions, Any] | None = None, + facet_by: str | list[str] | None = 'scenario', + animate_by: str | None = 'period', + facet_cols: int = 3, + reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] + | Literal['auto'] + | None = 'auto', + fill: Literal['ffill', 'bfill'] | None = 'ffill', + # Deprecated parameters (kept for backwards compatibility) indexer: dict[FlowSystemDimensions, Any] | None = None, + heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None, + heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None, + color_map: str | None = None, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]: """ - Plots a heatmap of the solution of a variable. + Plots a heatmap visualization of a variable using imshow or time-based reshaping. + + Supports multiple visualization features that can be combined: + - **Multi-variable**: Plot multiple variables on a single heatmap (creates 'variable' dimension) + - **Time reshaping**: Converts 'time' dimension into 2D (e.g., hours vs days) + - **Faceting**: Creates subplots for different dimension values + - **Animation**: Animates through dimension values (Plotly only) Args: - variable_name: The name of the variable to plot. - heatmap_timeframes: The timeframes to use for the heatmap. - heatmap_timesteps_per_frame: The timesteps per frame to use for the heatmap. - color_map: The color map to use for the heatmap. + variable_name: The name of the variable to plot, or a list of variable names. + When a list is provided, variables are combined into a single DataArray + with a new 'variable' dimension. save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location. show: Whether to show the plot or not. + colors: Color scheme for the heatmap. See `flixopt.plotting.ColorType` for options. engine: The engine to use for plotting. Can be either 'plotly' or 'matplotlib'. - indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}. - If None, uses first value for each dimension. - If empty dict {}, uses all values. + select: Optional data selection dict. Supports single values, lists, slices, and index arrays. + Applied BEFORE faceting/animation/reshaping. + facet_by: Dimension(s) to create facets (subplots) for. Can be a single dimension name (str) + or list of dimensions. Each unique value combination creates a subplot. Ignored if not found. + animate_by: Dimension to animate over (Plotly only). Creates animation frames that cycle through + dimension values. Only one dimension can be animated. Ignored if not found. + facet_cols: Number of columns in the facet grid layout (default: 3). + reshape_time: Time reshaping configuration (default: 'auto'): + - 'auto': Automatically applies ('D', 'h') when only 'time' dimension remains + - Tuple: Explicit reshaping, e.g. ('D', 'h') for days vs hours, + ('MS', 'D') for months vs days, ('W', 'h') for weeks vs hours + - None: Disable auto-reshaping (will error if only 1D time data) + Supported timeframes: 'YS', 'MS', 'W', 'D', 'h', '15min', 'min' + fill: Method to fill missing values after reshape: 'ffill' (forward fill) or 'bfill' (backward fill). + Default is 'ffill'. Examples: - Basic usage (uses first scenario, first period, all time): + Direct imshow mode (default): + + >>> results.plot_heatmap('Battery|charge_state', select={'scenario': 'base'}) + + Facet by scenario: + + >>> results.plot_heatmap('Boiler(Qth)|flow_rate', facet_by='scenario', facet_cols=2) - >>> results.plot_heatmap('Battery|charge_state') + Animate by period: - Select specific scenario and period: + >>> results.plot_heatmap('Boiler(Qth)|flow_rate', select={'scenario': 'base'}, animate_by='period') - >>> results.plot_heatmap('Boiler(Qth)|flow_rate', indexer={'scenario': 'base', 'period': 2024}) + Time reshape mode - daily patterns: - Time filtering (summer months only): + >>> results.plot_heatmap('Boiler(Qth)|flow_rate', select={'scenario': 'base'}, reshape_time=('D', 'h')) + + Combined: time reshaping with faceting and animation: >>> results.plot_heatmap( - ... 'Boiler(Qth)|flow_rate', - ... indexer={ - ... 'scenario': 'base', - ... 'time': results.solution.time[results.solution.time.dt.month.isin([6, 7, 8])], - ... }, + ... 'Boiler(Qth)|flow_rate', facet_by='scenario', animate_by='period', reshape_time=('D', 'h') ... ) - Save to specific location: + Multi-variable heatmap (variables as one axis): >>> results.plot_heatmap( - ... 'Boiler(Qth)|flow_rate', indexer={'scenario': 'base'}, save='path/to/my_heatmap.html' + ... ['Boiler(Q_th)|flow_rate', 'CHP(Q_th)|flow_rate', 'HeatStorage|charge_state'], + ... select={'scenario': 'base', 'period': 1}, + ... reshape_time=None, ... ) - """ - dataarray = self.solution[variable_name] + Multi-variable with time reshaping: + + >>> results.plot_heatmap( + ... ['Boiler(Q_th)|flow_rate', 'CHP(Q_th)|flow_rate'], + ... facet_by='scenario', + ... animate_by='period', + ... reshape_time=('D', 'h'), + ... ) + """ + # Delegate to module-level plot_heatmap function return plot_heatmap( - dataarray=dataarray, - name=variable_name, + data=self.solution[variable_name], + name=variable_name if isinstance(variable_name, str) else None, folder=self.folder, - heatmap_timeframes=heatmap_timeframes, - heatmap_timesteps_per_frame=heatmap_timesteps_per_frame, - color_map=color_map, + colors=colors, save=save, show=show, engine=engine, + select=select, + facet_by=facet_by, + animate_by=animate_by, + facet_cols=facet_cols, + reshape_time=reshape_time, + fill=fill, indexer=indexer, + heatmap_timeframes=heatmap_timeframes, + heatmap_timesteps_per_frame=heatmap_timesteps_per_frame, + color_map=color_map, ) def plot_network( @@ -920,30 +969,107 @@ def plot_node_balance( show: bool = True, colors: plotting.ColorType = 'viridis', engine: plotting.PlottingEngine = 'plotly', - indexer: dict[FlowSystemDimensions, Any] | None = None, + select: dict[FlowSystemDimensions, Any] | None = None, unit_type: Literal['flow_rate', 'flow_hours'] = 'flow_rate', mode: Literal['area', 'stacked_bar', 'line'] = 'stacked_bar', drop_suffix: bool = True, + facet_by: str | list[str] | None = 'scenario', + animate_by: str | None = 'period', + facet_cols: int = 3, + # Deprecated parameter (kept for backwards compatibility) + indexer: dict[FlowSystemDimensions, Any] | None = None, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]: """ - Plots the node balance of the Component or Bus. + Plots the node balance of the Component or Bus with optional faceting and animation. + Args: save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location. show: Whether to show the plot or not. colors: The colors to use for the plot. See `flixopt.plotting.ColorType` for options. engine: The engine to use for plotting. Can be either 'plotly' or 'matplotlib'. - indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}. - If None, uses first value for each dimension (except time). - If empty dict {}, uses all values. + select: Optional data selection dict. Supports: + - Single values: {'scenario': 'base', 'period': 2024} + - Multiple values: {'scenario': ['base', 'high', 'renewable']} + - Slices: {'time': slice('2024-01', '2024-06')} + - Index arrays: {'time': time_array} + Note: Applied BEFORE faceting/animation. unit_type: The unit type to use for the dataset. Can be 'flow_rate' or 'flow_hours'. - 'flow_rate': Returns the flow_rates of the Node. - 'flow_hours': Returns the flow_hours of the Node. [flow_hours(t) = flow_rate(t) * dt(t)]. Renames suffixes to |flow_hours. mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for stepped lines, or 'area' for stacked area charts. drop_suffix: Whether to drop the suffix from the variable names. + facet_by: Dimension(s) to create facets (subplots) for. Can be a single dimension name (str) + or list of dimensions. Each unique value combination creates a subplot. Ignored if not found. + Example: 'scenario' creates one subplot per scenario. + Example: ['scenario', 'period'] creates a grid of subplots for each scenario-period combination. + animate_by: Dimension to animate over (Plotly only). Creates animation frames that cycle through + dimension values. Only one dimension can be animated. Ignored if not found. + facet_cols: Number of columns in the facet grid layout (default: 3). + + Examples: + Basic plot (current behavior): + + >>> results['Boiler'].plot_node_balance() + + Facet by scenario: + + >>> results['Boiler'].plot_node_balance(facet_by='scenario', facet_cols=2) + + Animate by period: + + >>> results['Boiler'].plot_node_balance(animate_by='period') + + Facet by scenario AND animate by period: + + >>> results['Boiler'].plot_node_balance(facet_by='scenario', animate_by='period') + + Select single scenario, then facet by period: + + >>> results['Boiler'].plot_node_balance(select={'scenario': 'base'}, facet_by='period') + + Select multiple scenarios and facet by them: + + >>> results['Boiler'].plot_node_balance( + ... select={'scenario': ['base', 'high', 'renewable']}, facet_by='scenario' + ... ) + + Time range selection (summer months only): + + >>> results['Boiler'].plot_node_balance(select={'time': slice('2024-06', '2024-08')}, facet_by='scenario') """ - ds = self.node_balance(with_last_timestep=True, unit_type=unit_type, drop_suffix=drop_suffix, indexer=indexer) + # Handle deprecated indexer parameter + if indexer is not None: + # Check for conflict with new parameter + if select is not None: + raise ValueError( + "Cannot use both deprecated parameter 'indexer' and new parameter 'select'. Use only 'select'." + ) + + import warnings + + warnings.warn( + "The 'indexer' parameter is deprecated and will be removed in a future version. Use 'select' instead.", + DeprecationWarning, + stacklevel=2, + ) + select = indexer - ds, suffix_parts = _apply_indexer_to_data(ds, indexer, drop=True) + if engine not in {'plotly', 'matplotlib'}: + raise ValueError(f'Engine "{engine}" not supported. Use one of ["plotly", "matplotlib"]') + + # Don't pass select/indexer to node_balance - we'll apply it afterwards + ds = self.node_balance(with_last_timestep=True, unit_type=unit_type, drop_suffix=drop_suffix) + + ds, suffix_parts = _apply_selection_to_data(ds, select=select, drop=True) + + # Matplotlib requires only 'time' dimension; check for extras after selection + if engine == 'matplotlib': + extra_dims = [d for d in ds.dims if d != 'time'] + if extra_dims: + raise ValueError( + f'Matplotlib engine only supports a single time axis, but found extra dimensions: {extra_dims}. ' + f'Please use select={{...}} to reduce dimensions or switch to engine="plotly" for faceting/animation.' + ) suffix = '--' + '-'.join(suffix_parts) if suffix_parts else '' title = ( @@ -952,13 +1078,16 @@ def plot_node_balance( if engine == 'plotly': figure_like = plotting.with_plotly( - ds.to_dataframe(), + ds, + facet_by=facet_by, + animate_by=animate_by, colors=colors, mode=mode, title=title, + facet_cols=facet_cols, ) default_filetype = '.html' - elif engine == 'matplotlib': + else: figure_like = plotting.with_matplotlib( ds.to_dataframe(), colors=colors, @@ -966,8 +1095,6 @@ def plot_node_balance( title=title, ) default_filetype = '.png' - else: - raise ValueError(f'Engine "{engine}" not supported. Use "plotly" or "matplotlib"') return plotting.export_figure( figure_like=figure_like, @@ -986,9 +1113,19 @@ def plot_node_balance_pie( save: bool | pathlib.Path = False, show: bool = True, engine: plotting.PlottingEngine = 'plotly', + select: dict[FlowSystemDimensions, Any] | None = None, + # Deprecated parameter (kept for backwards compatibility) indexer: dict[FlowSystemDimensions, Any] | None = None, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, list[plt.Axes]]: """Plot pie chart of flow hours distribution. + + Note: + Pie charts require scalar data (no extra dimensions beyond time). + If your data has dimensions like 'scenario' or 'period', either: + + - Use `select` to choose specific values: `select={'scenario': 'base', 'period': 2024}` + - Let auto-selection choose the first value (a warning will be logged) + Args: lower_percentage_group: Percentage threshold for "Others" grouping. colors: Color scheme. Also see plotly. @@ -996,10 +1133,35 @@ def plot_node_balance_pie( save: Whether to save plot. show: Whether to display plot. engine: Plotting engine ('plotly' or 'matplotlib'). - indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}. - If None, uses first value for each dimension. - If empty dict {}, uses all values. + select: Optional data selection dict. Supports single values, lists, slices, and index arrays. + Use this to select specific scenario/period before creating the pie chart. + + Examples: + Basic usage (auto-selects first scenario/period if present): + + >>> results['Bus'].plot_node_balance_pie() + + Explicitly select a scenario and period: + + >>> results['Bus'].plot_node_balance_pie(select={'scenario': 'high_demand', 'period': 2030}) """ + # Handle deprecated indexer parameter + if indexer is not None: + # Check for conflict with new parameter + if select is not None: + raise ValueError( + "Cannot use both deprecated parameter 'indexer' and new parameter 'select'. Use only 'select'." + ) + + import warnings + + warnings.warn( + "The 'indexer' parameter is deprecated and will be removed in a future version. Use 'select' instead.", + DeprecationWarning, + stacklevel=2, + ) + select = indexer + inputs = sanitize_dataset( ds=self.solution[self.inputs] * self._calculation_results.hours_per_timestep, threshold=1e-5, @@ -1015,15 +1177,46 @@ def plot_node_balance_pie( drop_suffix='|', ) - inputs, suffix_parts = _apply_indexer_to_data(inputs, indexer, drop=True) - outputs, suffix_parts = _apply_indexer_to_data(outputs, indexer, drop=True) - suffix = '--' + '-'.join(suffix_parts) if suffix_parts else '' - - title = f'{self.label} (total flow hours){suffix}' + inputs, suffix_parts = _apply_selection_to_data(inputs, select=select, drop=True) + outputs, suffix_parts = _apply_selection_to_data(outputs, select=select, drop=True) + # Sum over time dimension inputs = inputs.sum('time') outputs = outputs.sum('time') + # Auto-select first value for any remaining dimensions (scenario, period, etc.) + # Pie charts need scalar data, so we automatically reduce extra dimensions + extra_dims_inputs = [dim for dim in inputs.dims if dim != 'time'] + extra_dims_outputs = [dim for dim in outputs.dims if dim != 'time'] + extra_dims = list(set(extra_dims_inputs + extra_dims_outputs)) + + if extra_dims: + auto_select = {} + for dim in extra_dims: + # Get first value of this dimension + if dim in inputs.coords: + first_val = inputs.coords[dim].values[0] + elif dim in outputs.coords: + first_val = outputs.coords[dim].values[0] + else: + continue + auto_select[dim] = first_val + logger.info( + f'Pie chart auto-selected {dim}={first_val} (first value). ' + f'Use select={{"{dim}": value}} to choose a different value.' + ) + + # Apply auto-selection + inputs = inputs.sel(auto_select) + outputs = outputs.sel(auto_select) + + # Update suffix with auto-selected values + auto_suffix_parts = [f'{dim}={val}' for dim, val in auto_select.items()] + suffix_parts.extend(auto_suffix_parts) + + suffix = '--' + '-'.join(suffix_parts) if suffix_parts else '' + title = f'{self.label} (total flow hours){suffix}' + if engine == 'plotly': figure_like = plotting.dual_pie_with_plotly( data_left=inputs.to_pandas(), @@ -1068,6 +1261,8 @@ def node_balance( with_last_timestep: bool = False, unit_type: Literal['flow_rate', 'flow_hours'] = 'flow_rate', drop_suffix: bool = False, + select: dict[FlowSystemDimensions, Any] | None = None, + # Deprecated parameter (kept for backwards compatibility) indexer: dict[FlowSystemDimensions, Any] | None = None, ) -> xr.Dataset: """ @@ -1081,10 +1276,25 @@ def node_balance( - 'flow_rate': Returns the flow_rates of the Node. - 'flow_hours': Returns the flow_hours of the Node. [flow_hours(t) = flow_rate(t) * dt(t)]. Renames suffixes to |flow_hours. drop_suffix: Whether to drop the suffix from the variable names. - indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}. - If None, uses first value for each dimension. - If empty dict {}, uses all values. + select: Optional data selection dict. Supports single values, lists, slices, and index arrays. """ + # Handle deprecated indexer parameter + if indexer is not None: + # Check for conflict with new parameter + if select is not None: + raise ValueError( + "Cannot use both deprecated parameter 'indexer' and new parameter 'select'. Use only 'select'." + ) + + import warnings + + warnings.warn( + "The 'indexer' parameter is deprecated and will be removed in a future version. Use 'select' instead.", + DeprecationWarning, + stacklevel=2, + ) + select = indexer + ds = self.solution[self.inputs + self.outputs] ds = sanitize_dataset( @@ -1103,7 +1313,7 @@ def node_balance( drop_suffix='|' if drop_suffix else None, ) - ds, _ = _apply_indexer_to_data(ds, indexer, drop=True) + ds, _ = _apply_selection_to_data(ds, select=select, drop=True) if unit_type == 'flow_hours': ds = ds * self._calculation_results.hours_per_timestep @@ -1140,10 +1350,15 @@ def plot_charge_state( show: bool = True, colors: plotting.ColorType = 'viridis', engine: plotting.PlottingEngine = 'plotly', - mode: Literal['area', 'stacked_bar', 'line'] = 'stacked_bar', + mode: Literal['area', 'stacked_bar', 'line'] = 'area', + select: dict[FlowSystemDimensions, Any] | None = None, + facet_by: str | list[str] | None = 'scenario', + animate_by: str | None = 'period', + facet_cols: int = 3, + # Deprecated parameter (kept for backwards compatibility) indexer: dict[FlowSystemDimensions, Any] | None = None, ) -> plotly.graph_objs.Figure: - """Plot storage charge state over time, combined with the node balance. + """Plot storage charge state over time, combined with the node balance with optional faceting and animation. Args: save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location. @@ -1151,42 +1366,120 @@ def plot_charge_state( colors: Color scheme. Also see plotly. engine: Plotting engine to use. Only 'plotly' is implemented atm. mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for stepped lines, or 'area' for stacked area charts. - indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}. - If None, uses first value for each dimension. - If empty dict {}, uses all values. + select: Optional data selection dict. Supports single values, lists, slices, and index arrays. + Applied BEFORE faceting/animation. + facet_by: Dimension(s) to create facets (subplots) for. Can be a single dimension name (str) + or list of dimensions. Each unique value combination creates a subplot. Ignored if not found. + animate_by: Dimension to animate over (Plotly only). Creates animation frames that cycle through + dimension values. Only one dimension can be animated. Ignored if not found. + facet_cols: Number of columns in the facet grid layout (default: 3). Raises: ValueError: If component is not a storage. + + Examples: + Basic plot: + + >>> results['Storage'].plot_charge_state() + + Facet by scenario: + + >>> results['Storage'].plot_charge_state(facet_by='scenario', facet_cols=2) + + Animate by period: + + >>> results['Storage'].plot_charge_state(animate_by='period') + + Facet by scenario AND animate by period: + + >>> results['Storage'].plot_charge_state(facet_by='scenario', animate_by='period') """ + # Handle deprecated indexer parameter + if indexer is not None: + # Check for conflict with new parameter + if select is not None: + raise ValueError( + "Cannot use both deprecated parameter 'indexer' and new parameter 'select'. Use only 'select'." + ) + + import warnings + + warnings.warn( + "The 'indexer' parameter is deprecated and will be removed in a future version. Use 'select' instead.", + DeprecationWarning, + stacklevel=2, + ) + select = indexer + if not self.is_storage: raise ValueError(f'Cant plot charge_state. "{self.label}" is not a storage') - ds = self.node_balance(with_last_timestep=True, indexer=indexer) - charge_state = self.charge_state + # Get node balance and charge state + ds = self.node_balance(with_last_timestep=True) + charge_state_da = self.charge_state - ds, suffix_parts = _apply_indexer_to_data(ds, indexer, drop=True) - charge_state, suffix_parts = _apply_indexer_to_data(charge_state, indexer, drop=True) + # Apply select filtering + ds, suffix_parts = _apply_selection_to_data(ds, select=select, drop=True) + charge_state_da, _ = _apply_selection_to_data(charge_state_da, select=select, drop=True) suffix = '--' + '-'.join(suffix_parts) if suffix_parts else '' title = f'Operation Balance of {self.label}{suffix}' if engine == 'plotly': - fig = plotting.with_plotly( - ds.to_dataframe(), + # Plot flows (node balance) with the specified mode + figure_like = plotting.with_plotly( + ds, + facet_by=facet_by, + animate_by=animate_by, colors=colors, mode=mode, title=title, + facet_cols=facet_cols, ) - # TODO: Use colors for charge state? + # Create a dataset with just charge_state and plot it as lines + # This ensures proper handling of facets and animation + charge_state_ds = charge_state_da.to_dataset(name=self._charge_state) - charge_state = charge_state.to_dataframe() - fig.add_trace( - plotly.graph_objs.Scatter( - x=charge_state.index, y=charge_state.values.flatten(), mode='lines', name=self._charge_state - ) + # Plot charge_state with mode='line' to get Scatter traces + charge_state_fig = plotting.with_plotly( + charge_state_ds, + facet_by=facet_by, + animate_by=animate_by, + colors=colors, + mode='line', # Always line for charge_state + title='', # No title needed for this temp figure + facet_cols=facet_cols, ) + + # Add charge_state traces to the main figure + # This preserves subplot assignments and animation frames + for trace in charge_state_fig.data: + trace.line.width = 2 # Make charge_state line more prominent + trace.line.shape = 'linear' # Smooth line for charge state (not stepped like flows) + figure_like.add_trace(trace) + + # Also add traces from animation frames if they exist + # Both figures use the same animate_by parameter, so they should have matching frames + if hasattr(charge_state_fig, 'frames') and charge_state_fig.frames: + # Add charge_state traces to each frame + for i, frame in enumerate(charge_state_fig.frames): + if i < len(figure_like.frames): + for trace in frame.data: + trace.line.width = 2 + trace.line.shape = 'linear' # Smooth line for charge state + figure_like.frames[i].data = figure_like.frames[i].data + (trace,) + + default_filetype = '.html' elif engine == 'matplotlib': + # Matplotlib requires only 'time' dimension; check for extras after selection + extra_dims = [d for d in ds.dims if d != 'time'] + if extra_dims: + raise ValueError( + f'Matplotlib engine only supports a single time axis, but found extra dimensions: {extra_dims}. ' + f'Please use select={{...}} to reduce dimensions or switch to engine="plotly" for faceting/animation.' + ) + # For matplotlib, plot flows (node balance), then add charge_state as line fig, ax = plotting.with_matplotlib( ds.to_dataframe(), colors=colors, @@ -1194,15 +1487,25 @@ def plot_charge_state( title=title, ) - charge_state = charge_state.to_dataframe() - ax.plot(charge_state.index, charge_state.values.flatten(), label=self._charge_state) + # Add charge_state as a line overlay + charge_state_df = charge_state_da.to_dataframe() + ax.plot( + charge_state_df.index, + charge_state_df.values.flatten(), + label=self._charge_state, + linewidth=2, + color='black', + ) + ax.legend() fig.tight_layout() - fig = fig, ax + + figure_like = fig, ax + default_filetype = '.png' return plotting.export_figure( - fig, + figure_like=figure_like, default_path=self._calculation_results.folder / title, - default_filetype='.html', + default_filetype=default_filetype, user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, save=True if save else False, @@ -1476,37 +1779,95 @@ def solution_without_overlap(self, variable_name: str) -> xr.DataArray: def plot_heatmap( self, variable_name: str, - heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] = 'D', - heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] = 'h', - color_map: str = 'portland', + reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] + | Literal['auto'] + | None = 'auto', + colors: str = 'portland', save: bool | pathlib.Path = False, show: bool = True, engine: plotting.PlottingEngine = 'plotly', + facet_by: str | list[str] | None = None, + animate_by: str | None = None, + facet_cols: int = 3, + fill: Literal['ffill', 'bfill'] | None = 'ffill', + # Deprecated parameters (kept for backwards compatibility) + heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None, + heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None, + color_map: str | None = None, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]: """Plot heatmap of variable solution across segments. Args: variable_name: Variable to plot. - heatmap_timeframes: Time aggregation level. - heatmap_timesteps_per_frame: Timesteps per frame. - color_map: Color scheme. Also see plotly. + reshape_time: Time reshaping configuration (default: 'auto'): + - 'auto': Automatically applies ('D', 'h') when only 'time' dimension remains + - Tuple like ('D', 'h'): Explicit reshaping (days vs hours) + - None: Disable time reshaping + colors: Color scheme. See plotting.ColorType for options. save: Whether to save plot. show: Whether to display plot. engine: Plotting engine. + facet_by: Dimension(s) to create facets (subplots) for. + animate_by: Dimension to animate over (Plotly only). + facet_cols: Number of columns in the facet grid layout. + fill: Method to fill missing values: 'ffill' or 'bfill'. + heatmap_timeframes: (Deprecated) Use reshape_time instead. + heatmap_timesteps_per_frame: (Deprecated) Use reshape_time instead. + color_map: (Deprecated) Use colors instead. Returns: Figure object. """ + # Handle deprecated parameters + if heatmap_timeframes is not None or heatmap_timesteps_per_frame is not None: + # Check for conflict with new parameter + if reshape_time != 'auto': # Check if user explicitly set reshape_time + raise ValueError( + "Cannot use both deprecated parameters 'heatmap_timeframes'/'heatmap_timesteps_per_frame' " + "and new parameter 'reshape_time'. Use only 'reshape_time'." + ) + + import warnings + + warnings.warn( + "The 'heatmap_timeframes' and 'heatmap_timesteps_per_frame' parameters are deprecated. " + "Use 'reshape_time=(timeframes, timesteps_per_frame)' instead.", + DeprecationWarning, + stacklevel=2, + ) + # Override reshape_time if old parameters provided + if heatmap_timeframes is not None and heatmap_timesteps_per_frame is not None: + reshape_time = (heatmap_timeframes, heatmap_timesteps_per_frame) + + if color_map is not None: + # Check for conflict with new parameter + if colors != 'portland': # Check if user explicitly set colors + raise ValueError( + "Cannot use both deprecated parameter 'color_map' and new parameter 'colors'. Use only 'colors'." + ) + + import warnings + + warnings.warn( + "The 'color_map' parameter is deprecated. Use 'colors' instead.", + DeprecationWarning, + stacklevel=2, + ) + colors = color_map + return plot_heatmap( - dataarray=self.solution_without_overlap(variable_name), + data=self.solution_without_overlap(variable_name), name=variable_name, folder=self.folder, - heatmap_timeframes=heatmap_timeframes, - heatmap_timesteps_per_frame=heatmap_timesteps_per_frame, - color_map=color_map, + reshape_time=reshape_time, + colors=colors, save=save, show=show, engine=engine, + facet_by=facet_by, + animate_by=animate_by, + facet_cols=facet_cols, + fill=fill, ) def to_file(self, folder: str | pathlib.Path | None = None, name: str | None = None, compression: int = 5): @@ -1536,59 +1897,212 @@ def to_file(self, folder: str | pathlib.Path | None = None, name: str | None = N def plot_heatmap( - dataarray: xr.DataArray, - name: str, - folder: pathlib.Path, - heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] = 'D', - heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] = 'h', - color_map: str = 'portland', + data: xr.DataArray | xr.Dataset, + name: str | None = None, + folder: pathlib.Path | None = None, + colors: plotting.ColorType = 'viridis', save: bool | pathlib.Path = False, show: bool = True, engine: plotting.PlottingEngine = 'plotly', + select: dict[str, Any] | None = None, + facet_by: str | list[str] | None = None, + animate_by: str | None = None, + facet_cols: int = 3, + reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] + | Literal['auto'] + | None = 'auto', + fill: Literal['ffill', 'bfill'] | None = 'ffill', + # Deprecated parameters (kept for backwards compatibility) indexer: dict[str, Any] | None = None, + heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None, + heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None, + color_map: str | None = None, ): - """Plot heatmap of time series data. + """Plot heatmap visualization with support for multi-variable, faceting, and animation. + + This function provides a standalone interface to the heatmap plotting capabilities, + supporting the same modern features as CalculationResults.plot_heatmap(). Args: - dataarray: Data to plot. - name: Variable name for title. - folder: Save folder. - heatmap_timeframes: Time aggregation level. - heatmap_timesteps_per_frame: Timesteps per frame. - color_map: Color scheme. Also see plotly. - save: Whether to save plot. - show: Whether to display plot. - engine: Plotting engine. - indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}. - If None, uses first value for each dimension. - If empty dict {}, uses all values. + data: Data to plot. Can be a single DataArray or an xarray Dataset. + When a Dataset is provided, all data variables are combined along a new 'variable' dimension. + name: Optional name for the title. If not provided, uses the DataArray name or + generates a default title for Datasets. + folder: Save folder for the plot. Defaults to current directory if not provided. + colors: Color scheme for the heatmap. See `flixopt.plotting.ColorType` for options. + save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location. + show: Whether to show the plot or not. + engine: The engine to use for plotting. Can be either 'plotly' or 'matplotlib'. + select: Optional data selection dict. Supports single values, lists, slices, and index arrays. + facet_by: Dimension(s) to create facets (subplots) for. Can be a single dimension name (str) + or list of dimensions. Each unique value combination creates a subplot. + animate_by: Dimension to animate over (Plotly only). Creates animation frames. + facet_cols: Number of columns in the facet grid layout (default: 3). + reshape_time: Time reshaping configuration (default: 'auto'): + - 'auto': Automatically applies ('D', 'h') when only 'time' dimension remains + - Tuple: Explicit reshaping, e.g. ('D', 'h') for days vs hours + - None: Disable auto-reshaping + fill: Method to fill missing values after reshape: 'ffill' (forward fill) or 'bfill' (backward fill). + Default is 'ffill'. + + Examples: + Single DataArray with time reshaping: + + >>> plot_heatmap(data, name='Temperature', folder=Path('.'), reshape_time=('D', 'h')) + + Dataset with multiple variables (facet by variable): + + >>> dataset = xr.Dataset({'Boiler': data1, 'CHP': data2, 'Storage': data3}) + >>> plot_heatmap( + ... dataset, + ... folder=Path('.'), + ... facet_by='variable', + ... reshape_time=('D', 'h'), + ... ) + + Dataset with animation by variable: + + >>> plot_heatmap(dataset, animate_by='variable', reshape_time=('D', 'h')) """ - dataarray, suffix_parts = _apply_indexer_to_data(dataarray, indexer, drop=True) + # Handle deprecated heatmap time parameters + if heatmap_timeframes is not None or heatmap_timesteps_per_frame is not None: + # Check for conflict with new parameter + if reshape_time != 'auto': # User explicitly set reshape_time + raise ValueError( + "Cannot use both deprecated parameters 'heatmap_timeframes'/'heatmap_timesteps_per_frame' " + "and new parameter 'reshape_time'. Use only 'reshape_time'." + ) + + import warnings + + warnings.warn( + "The 'heatmap_timeframes' and 'heatmap_timesteps_per_frame' parameters are deprecated. " + "Use 'reshape_time=(timeframes, timesteps_per_frame)' instead.", + DeprecationWarning, + stacklevel=2, + ) + # Override reshape_time if both old parameters provided + if heatmap_timeframes is not None and heatmap_timesteps_per_frame is not None: + reshape_time = (heatmap_timeframes, heatmap_timesteps_per_frame) + + # Handle deprecated color_map parameter + if color_map is not None: + # Check for conflict with new parameter + if colors != 'viridis': # User explicitly set colors + raise ValueError( + "Cannot use both deprecated parameter 'color_map' and new parameter 'colors'. Use only 'colors'." + ) + + import warnings + + warnings.warn( + "The 'color_map' parameter is deprecated. Use 'colors' instead.", + DeprecationWarning, + stacklevel=2, + ) + colors = color_map + + # Handle deprecated indexer parameter + if indexer is not None: + # Check for conflict with new parameter + if select is not None: # User explicitly set select + raise ValueError( + "Cannot use both deprecated parameter 'indexer' and new parameter 'select'. Use only 'select'." + ) + + import warnings + + warnings.warn( + "The 'indexer' parameter is deprecated. Use 'select' instead.", + DeprecationWarning, + stacklevel=2, + ) + select = indexer + + # Convert Dataset to DataArray with 'variable' dimension + if isinstance(data, xr.Dataset): + # Extract all data variables from the Dataset + variable_names = list(data.data_vars) + dataarrays = [data[var] for var in variable_names] + + # Combine into single DataArray with 'variable' dimension + data = xr.concat(dataarrays, dim='variable') + data = data.assign_coords(variable=variable_names) + + # Use Dataset variable names for title if name not provided + if name is None: + title_name = f'Heatmap of {len(variable_names)} variables' + else: + title_name = name + else: + # Single DataArray + if name is None: + title_name = data.name if data.name else 'Heatmap' + else: + title_name = name + + # Apply select filtering + data, suffix_parts = _apply_selection_to_data(data, select=select, drop=True) suffix = '--' + '-'.join(suffix_parts) if suffix_parts else '' - name = name if not suffix_parts else name + suffix - heatmap_data = plotting.heat_map_data_from_df( - dataarray.to_dataframe(name), heatmap_timeframes, heatmap_timesteps_per_frame, 'ffill' - ) + # Matplotlib heatmaps require at most 2D data + # Time dimension will be reshaped to 2D (timeframe × timestep), so can't have other dims alongside it + if engine == 'matplotlib': + dims = list(data.dims) - xlabel, ylabel = f'timeframe [{heatmap_timeframes}]', f'timesteps [{heatmap_timesteps_per_frame}]' + # If 'time' dimension exists and will be reshaped, we can't have any other dimensions + if 'time' in dims and len(dims) > 1 and reshape_time is not None: + extra_dims = [d for d in dims if d != 'time'] + raise ValueError( + f'Matplotlib heatmaps with time reshaping cannot have additional dimensions. ' + f'Found extra dimensions: {extra_dims}. ' + f'Use select={{...}} to reduce to time only, use "reshape_time=None" or switch to engine="plotly" or use for multi-dimensional support.' + ) + # If no 'time' dimension (already reshaped or different data), allow at most 2 dimensions + elif 'time' not in dims and len(dims) > 2: + raise ValueError( + f'Matplotlib heatmaps support at most 2 dimensions, but data has {len(dims)}: {dims}. ' + f'Use select={{...}} to reduce dimensions or switch to engine="plotly".' + ) + # Build title + title = f'{title_name}{suffix}' + if isinstance(reshape_time, tuple): + timeframes, timesteps_per_frame = reshape_time + title += f' ({timeframes} vs {timesteps_per_frame})' + + # Plot with appropriate engine if engine == 'plotly': - figure_like = plotting.heat_map_plotly( - heatmap_data, title=name, color_map=color_map, xlabel=xlabel, ylabel=ylabel + figure_like = plotting.heatmap_with_plotly( + data=data, + facet_by=facet_by, + animate_by=animate_by, + colors=colors, + title=title, + facet_cols=facet_cols, + reshape_time=reshape_time, + fill=fill, ) default_filetype = '.html' elif engine == 'matplotlib': - figure_like = plotting.heat_map_matplotlib( - heatmap_data, title=name, color_map=color_map, xlabel=xlabel, ylabel=ylabel + figure_like = plotting.heatmap_with_matplotlib( + data=data, + colors=colors, + title=title, + reshape_time=reshape_time, + fill=fill, ) default_filetype = '.png' else: raise ValueError(f'Engine "{engine}" not supported. Use "plotly" or "matplotlib"') + # Set default folder if not provided + if folder is None: + folder = pathlib.Path('.') + return plotting.export_figure( figure_like=figure_like, - default_path=folder / f'{name} ({heatmap_timeframes}-{heatmap_timesteps_per_frame})', + default_path=folder / title, default_filetype=default_filetype, user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, @@ -1790,8 +2304,13 @@ def apply_filter(array, coord_name: str, coord_values: Any | list[Any]): if coord_name not in array.coords: raise AttributeError(f"Missing required coordinate '{coord_name}'") - # Convert single value to list - val_list = [coord_values] if isinstance(coord_values, str) else coord_values + # Normalize to list for sequence-like inputs (excluding strings) + if isinstance(coord_values, str): + val_list = [coord_values] + elif isinstance(coord_values, (list, tuple, np.ndarray, pd.Index)): + val_list = list(coord_values) + else: + val_list = [coord_values] # Verify coord_values exist available = set(array[coord_name].values) @@ -1801,7 +2320,7 @@ def apply_filter(array, coord_name: str, coord_values: Any | list[Any]): # Apply filter return array.where( - array[coord_name].isin(val_list) if isinstance(coord_values, list) else array[coord_name] == coord_values, + array[coord_name].isin(val_list) if len(val_list) > 1 else array[coord_name] == val_list[0], drop=True, ) @@ -1820,36 +2339,26 @@ def apply_filter(array, coord_name: str, coord_values: Any | list[Any]): return da -def _apply_indexer_to_data( - data: xr.DataArray | xr.Dataset, indexer: dict[str, Any] | None = None, drop=False +def _apply_selection_to_data( + data: xr.DataArray | xr.Dataset, + select: dict[str, Any] | None = None, + drop=False, ) -> tuple[xr.DataArray | xr.Dataset, list[str]]: """ - Apply indexer selection or auto-select first values for non-time dimensions. + Apply selection to data. Args: data: xarray Dataset or DataArray - indexer: Optional selection dict - If None, uses first value for each dimension (except time). - If empty dict {}, uses all values. + select: Optional selection dict + drop: Whether to drop dimensions after selection Returns: Tuple of (selected_data, selection_string) """ selection_string = [] - if indexer is not None: - # User provided indexer - data = data.sel(indexer, drop=drop) - selection_string.extend(f'{v}[{k}]' for k, v in indexer.items()) - else: - # Auto-select first value for each dimension except 'time' - selection = {} - for dim in data.dims: - if dim != 'time' and dim in data.coords: - first_value = data.coords[dim].values[0] - selection[dim] = first_value - selection_string.append(f'{first_value}[{dim}]') - if selection: - data = data.sel(selection, drop=drop) + if select: + data = data.sel(select, drop=drop) + selection_string.extend(f'{dim}={val}' for dim, val in select.items()) return data, selection_string diff --git a/tests/ressources/Sim1--flow_system.nc4 b/tests/ressources/Sim1--flow_system.nc4 new file mode 100644 index 000000000..b56abf52d Binary files /dev/null and b/tests/ressources/Sim1--flow_system.nc4 differ diff --git a/tests/ressources/Sim1--solution.nc4 b/tests/ressources/Sim1--solution.nc4 new file mode 100644 index 000000000..1e2785d85 Binary files /dev/null and b/tests/ressources/Sim1--solution.nc4 differ diff --git a/tests/ressources/Sim1--summary.yaml b/tests/ressources/Sim1--summary.yaml new file mode 100644 index 000000000..dfa95c4ef --- /dev/null +++ b/tests/ressources/Sim1--summary.yaml @@ -0,0 +1,92 @@ +Name: Sim1 +Number of timesteps: 9 +Calculation Type: FullCalculation +Constraints: 1501 +Variables: 1654 +Main Results: + Objective: 82.55 + Penalty: 0.0 + Effects: + costs [€]: + temporal: + - - 61.88 + - 68.4 + - - 58.99 + - 65.51 + - - 56.1 + - 62.62 + periodic: + - - 20.0 + - 20.0 + - - 20.0 + - 20.0 + - - 20.0 + - 20.0 + total: + - - 81.88 + - 88.4 + - - 78.99 + - 85.51 + - - 76.1 + - 82.62 + CO2 [kg]: + temporal: + - - 255.09 + - 274.65 + - - 255.09 + - 274.65 + - - 255.09 + - 274.65 + periodic: + - - -0.0 + - -0.0 + - - -0.0 + - -0.0 + - - -0.0 + - -0.0 + total: + - - 255.09 + - 274.65 + - - 255.09 + - 274.65 + - - 255.09 + - 274.65 + Invest-Decisions: + Invested: + Storage: + - - 30.0 + - 30.0 + - - 30.0 + - 30.0 + - - 30.0 + - 30.0 + Not invested: {} + Buses with excess: [] +Durations: + modeling: 0.83 + solving: 0.45 + saving: 0.0 +Config: + config_name: flixopt + logging: + level: INFO + file: null + rich: false + console: false + max_file_size: 10485760 + backup_count: 5 + date_format: '%Y-%m-%d %H:%M:%S' + format: '%(message)s' + console_width: 120 + show_path: false + show_logger_name: false + colors: + DEBUG: "\e[90m" + INFO: "\e[0m" + WARNING: "\e[33m" + ERROR: "\e[31m" + CRITICAL: "\e[1m\e[31m" + modeling: + big: 10000000 + epsilon: 1.0e-05 + big_binary_bound: 100000 diff --git a/tests/test_heatmap_reshape.py b/tests/test_heatmap_reshape.py new file mode 100644 index 000000000..092adff4e --- /dev/null +++ b/tests/test_heatmap_reshape.py @@ -0,0 +1,91 @@ +"""Test reshape_data_for_heatmap() for common use cases.""" + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from flixopt.plotting import reshape_data_for_heatmap + +# Set random seed for reproducible tests +np.random.seed(42) + + +@pytest.fixture +def hourly_week_data(): + """Typical use case: hourly data for a week.""" + time = pd.date_range('2024-01-01', periods=168, freq='h') + data = np.random.rand(168) * 100 + return xr.DataArray(data, dims=['time'], coords={'time': time}, name='power') + + +def test_daily_hourly_pattern(): + """Most common use case: reshape hourly data into days × hours for daily patterns.""" + time = pd.date_range('2024-01-01', periods=72, freq='h') + data = np.random.rand(72) * 100 + da = xr.DataArray(data, dims=['time'], coords={'time': time}) + + result = reshape_data_for_heatmap(da, reshape_time=('D', 'h')) + + assert 'timeframe' in result.dims and 'timestep' in result.dims + assert result.sizes['timeframe'] == 3 # 3 days + assert result.sizes['timestep'] == 24 # 24 hours + + +def test_weekly_daily_pattern(hourly_week_data): + """Common use case: reshape hourly data into weeks × days.""" + result = reshape_data_for_heatmap(hourly_week_data, reshape_time=('W', 'D')) + + assert 'timeframe' in result.dims and 'timestep' in result.dims + # 168 hours = 7 days = 1 week + assert result.sizes['timeframe'] == 1 # 1 week + assert result.sizes['timestep'] == 7 # 7 days + + +def test_with_irregular_data(): + """Real-world use case: data with missing timestamps needs filling.""" + time = pd.date_range('2024-01-01', periods=100, freq='15min') + data = np.random.rand(100) + # Randomly drop 30% to simulate real data gaps + keep = np.sort(np.random.choice(100, 70, replace=False)) # Must be sorted + da = xr.DataArray(data[keep], dims=['time'], coords={'time': time[keep]}) + + result = reshape_data_for_heatmap(da, reshape_time=('h', 'min'), fill='ffill') + + assert 'timeframe' in result.dims and 'timestep' in result.dims + # 100 * 15min = 1500min = 25h; reshaped to hours × minutes + assert result.sizes['timeframe'] == 25 # 25 hours + assert result.sizes['timestep'] == 60 # 60 minutes per hour + # Should handle irregular data without errors + + +def test_multidimensional_scenarios(): + """Use case: data with scenarios/periods that need to be preserved.""" + time = pd.date_range('2024-01-01', periods=48, freq='h') + scenarios = ['base', 'high'] + data = np.random.rand(48, 2) * 100 + + da = xr.DataArray(data, dims=['time', 'scenario'], coords={'time': time, 'scenario': scenarios}, name='demand') + + result = reshape_data_for_heatmap(da, reshape_time=('D', 'h')) + + # Should preserve scenario dimension + assert 'scenario' in result.dims + assert result.sizes['scenario'] == 2 + # 48 hours = 2 days × 24 hours + assert result.sizes['timeframe'] == 2 # 2 days + assert result.sizes['timestep'] == 24 # 24 hours + + +def test_no_reshape_returns_unchanged(): + """Use case: when reshape_time=None, return data as-is.""" + time = pd.date_range('2024-01-01', periods=24, freq='h') + da = xr.DataArray(np.random.rand(24), dims=['time'], coords={'time': time}) + + result = reshape_data_for_heatmap(da, reshape_time=None) + + xr.testing.assert_equal(result, da) + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/test_plots.py b/tests/test_plots.py deleted file mode 100644 index 61c26c510..000000000 --- a/tests/test_plots.py +++ /dev/null @@ -1,151 +0,0 @@ -""" -Manual test script for plots -""" - -import unittest - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import pytest - -from flixopt import plotting - - -@pytest.mark.slow -class TestPlots(unittest.TestCase): - def setUp(self): - np.random.seed(72) - - def tearDown(self): - """Cleanup matplotlib and plotly resources""" - plt.close('all') - # Force garbage collection to cleanup any lingering resources - import gc - - gc.collect() - - @staticmethod - def get_sample_data( - nr_of_columns: int = 7, - nr_of_periods: int = 10, - time_steps_per_period: int = 24, - drop_fraction_of_indices: float | None = None, - only_pos_or_neg: bool = True, - column_prefix: str = '', - ): - columns = [f'Region {i + 1}{column_prefix}' for i in range(nr_of_columns)] # More realistic column labels - values_per_column = nr_of_periods * time_steps_per_period - if only_pos_or_neg: - positive_data = np.abs(np.random.rand(values_per_column, nr_of_columns) * 100) - negative_data = -np.abs(np.random.rand(values_per_column, nr_of_columns) * 100) - data = pd.DataFrame( - np.concatenate([positive_data, negative_data], axis=1), - columns=[f'Region {i + 1}' for i in range(nr_of_columns)] - + [f'Region {i + 1} Negative' for i in range(nr_of_columns)], - ) - else: - data = pd.DataFrame( - np.random.randn(values_per_column, nr_of_columns) * 50 + 20, columns=columns - ) # Random data with both positive and negative values - data.index = pd.date_range('2023-01-01', periods=values_per_column, freq='h') - - if drop_fraction_of_indices: - # Randomly drop a percentage of rows to create irregular intervals - drop_indices = np.random.choice(data.index, int(len(data) * drop_fraction_of_indices), replace=False) - data = data.drop(drop_indices) - return data - - def test_bar_plots(self): - data = self.get_sample_data(nr_of_columns=10, nr_of_periods=1, time_steps_per_period=24) - # Create plotly figure (json renderer doesn't need .show()) - _ = plotting.with_plotly(data, 'stacked_bar') - plotting.with_matplotlib(data, 'stacked_bar') - plt.savefig(f'test_plot_{self._testMethodName}.png', bbox_inches='tight') - plt.close('all') # Close all figures to prevent memory leaks - - data = self.get_sample_data( - nr_of_columns=10, nr_of_periods=5, time_steps_per_period=24, drop_fraction_of_indices=0.3 - ) - # Create plotly figure (json renderer doesn't need .show()) - _ = plotting.with_plotly(data, 'stacked_bar') - plotting.with_matplotlib(data, 'stacked_bar') - plt.savefig(f'test_plot_{self._testMethodName}.png', bbox_inches='tight') - plt.close('all') # Close all figures to prevent memory leaks - - def test_line_plots(self): - data = self.get_sample_data(nr_of_columns=10, nr_of_periods=1, time_steps_per_period=24) - _ = plotting.with_plotly(data, 'line') - plotting.with_matplotlib(data, 'line') - plt.savefig(f'test_plot_{self._testMethodName}.png', bbox_inches='tight') - plt.close('all') # Close all figures to prevent memory leaks - - data = self.get_sample_data( - nr_of_columns=10, nr_of_periods=5, time_steps_per_period=24, drop_fraction_of_indices=0.3 - ) - _ = plotting.with_plotly(data, 'line') - plotting.with_matplotlib(data, 'line') - plt.savefig(f'test_plot_{self._testMethodName}.png', bbox_inches='tight') - plt.close('all') # Close all figures to prevent memory leaks - - def test_stacked_line_plots(self): - data = self.get_sample_data(nr_of_columns=10, nr_of_periods=1, time_steps_per_period=24) - _ = plotting.with_plotly(data, 'area') - - data = self.get_sample_data( - nr_of_columns=10, nr_of_periods=5, time_steps_per_period=24, drop_fraction_of_indices=0.3 - ) - _ = plotting.with_plotly(data, 'area') - - def test_heat_map_plots(self): - # Generate single-column data with datetime index for heatmap - data = self.get_sample_data(nr_of_columns=1, nr_of_periods=10, time_steps_per_period=24, only_pos_or_neg=False) - - # Convert data for heatmap plotting using 'day' as period and 'hour' steps - heatmap_data = plotting.reshape_to_2d(data.iloc[:, 0].values.flatten(), 24) - # Plotting heatmaps with Plotly and Matplotlib - _ = plotting.heat_map_plotly(pd.DataFrame(heatmap_data)) - plotting.heat_map_matplotlib(pd.DataFrame(heatmap_data)) - plt.savefig(f'test_plot_{self._testMethodName}.png', bbox_inches='tight') - plt.close('all') # Close all figures to prevent memory leaks - - def test_heat_map_plots_resampling(self): - date_range = pd.date_range(start='2023-01-01', end='2023-03-21', freq='5min') - - # Generate random data for the DataFrame, simulating some metric (e.g., energy consumption, temperature) - data = np.random.rand(len(date_range)) - - # Create the DataFrame with a datetime index - df = pd.DataFrame(data, index=date_range, columns=['value']) - - # Randomly drop a percentage of rows to create irregular intervals - drop_fraction = 0.3 # Fraction of data points to drop (30% in this case) - drop_indices = np.random.choice(df.index, int(len(df) * drop_fraction), replace=False) - df_irregular = df.drop(drop_indices) - - # Generate single-column data with datetime index for heatmap - data = df_irregular - # Convert data for heatmap plotting using 'day' as period and 'hour' steps - heatmap_data = plotting.heat_map_data_from_df(data, 'MS', 'D') - _ = plotting.heat_map_plotly(heatmap_data) - plotting.heat_map_matplotlib(pd.DataFrame(heatmap_data)) - plt.savefig(f'test_plot_{self._testMethodName}.png', bbox_inches='tight') - plt.close('all') # Close all figures to prevent memory leaks - - heatmap_data = plotting.heat_map_data_from_df(data, 'W', 'h', fill='ffill') - # Plotting heatmaps with Plotly and Matplotlib - _ = plotting.heat_map_plotly(pd.DataFrame(heatmap_data)) - plotting.heat_map_matplotlib(pd.DataFrame(heatmap_data)) - plt.savefig(f'test_plot_{self._testMethodName}.png', bbox_inches='tight') - plt.close('all') # Close all figures to prevent memory leaks - - heatmap_data = plotting.heat_map_data_from_df(data, 'D', 'h', fill='ffill') - # Plotting heatmaps with Plotly and Matplotlib - _ = plotting.heat_map_plotly(pd.DataFrame(heatmap_data)) - plotting.heat_map_matplotlib(pd.DataFrame(heatmap_data)) - plt.savefig(f'test_plot_{self._testMethodName}.png', bbox_inches='tight') - plt.close('all') # Close all figures to prevent memory leaks - - -if __name__ == '__main__': - pytest.main(['-v', '--disable-warnings']) diff --git a/tests/test_results_plots.py b/tests/test_results_plots.py index 35a219e31..1fd6cf7f5 100644 --- a/tests/test_results_plots.py +++ b/tests/test_results_plots.py @@ -48,18 +48,29 @@ def test_results_plots(flow_system, plotting_engine, show, save, color_spec): results['Boiler'].plot_node_balance(engine=plotting_engine, save=save, show=show, colors=color_spec) - results.plot_heatmap( - 'Speicher(Q_th_load)|flow_rate', - heatmap_timeframes='D', - heatmap_timesteps_per_frame='h', - color_map='viridis', # Note: heatmap only accepts string colormap - save=show, - show=save, - engine=plotting_engine, - ) + # Matplotlib doesn't support faceting/animation, so disable them for matplotlib engine + heatmap_kwargs = { + 'reshape_time': ('D', 'h'), + 'colors': 'viridis', # Note: heatmap only accepts string colormap + 'save': save, + 'show': show, + 'engine': plotting_engine, + } + if plotting_engine == 'matplotlib': + heatmap_kwargs['facet_by'] = None + heatmap_kwargs['animate_by'] = None + + results.plot_heatmap('Speicher(Q_th_load)|flow_rate', **heatmap_kwargs) results['Speicher'].plot_node_balance_pie(engine=plotting_engine, save=save, show=show, colors=color_spec) - results['Speicher'].plot_charge_state(engine=plotting_engine) + + # Matplotlib doesn't support faceting/animation for plot_charge_state, and 'area' mode + charge_state_kwargs = {'engine': plotting_engine} + if plotting_engine == 'matplotlib': + charge_state_kwargs['facet_by'] = None + charge_state_kwargs['animate_by'] = None + charge_state_kwargs['mode'] = 'stacked_bar' # 'area' not supported by matplotlib + results['Speicher'].plot_charge_state(**charge_state_kwargs) plt.close('all')