_aymcZRMO}cIxcz4{-D|8kW=X;#2i6j|?bC|z{2?vE2h=1d53U}zynLm{~
z4}S*dH9LEI6nQVmjdy)~*C!|E*Yz+(?=gT7LVX{ynf}7^jiRLTGkrpK7HhKfz0aJi
z8FY~ZcJSSk*RrD180*45dUkLHTLp@J+3a6#o540lGS-J3F6%UleF?D%tj&?@XS1hZ
zBkgW%)qsO@8J|Z_Vgo%>Z(*@OVj%m)eGtCFY;&lcWO2tU?`F5aDb8f>
z<#*iAW`pioEcui3A7&Gw45zcTb5=dV_QKZ=?CDKg9%ZRup(p#t%%sOyCR|IYY}uR+
zkF$5dSPwR*^4MAy3&wh~|FSz@Ws9JY{aD0-#I0;D6f%ij@zF1DvZoNAgEhO8iozW4BJw^a3j?oL_=bN)-UuUiYOou-
z(r@yrvPphe8METW=0Oz1dz%ok^hpGW01+SpM1Tko0V2@Q2^de+;4}!=n;!#?K~#H%
zK$`~QZf^akf(aIN{VU>K`iW=jLJk^2&>jbl-+c9hFjybPZ$5+yfj4YS)_3mp6hXY@{7j}!dx9?uCx5+(v@?Iirv?V+-pZ*zD_bS;jM9(yUaB9cjP^*j7YUkBbt8?2EsC=f?q-Uzhny8D
zuBDp%VlygtHJi-IbrzdUX1yo0qzWDCJAno|cJ|HQ`m=iDU1($8>v3y^_Z0bkkQ?vV
z*>4lFUV#e!F@xRM7~UcSQMrJ|#sVRtbRs|mhyW2F0z`la5CJ04XbAAV{%}{wIhmqb
zIkI-qQ^QySIF>Q+mUSGvC^IYbqRg=wnYr1c$4`Kl=)wQxd*K)6d-BVj_#U=B82+L)
ze_5r_HO*PV_av06`U@=f+4OV3m>q-!{Zo@0WZ+>%bEZAEFkH+@Xx)b138R8p=2}hO#=a^~HYc
z`Ce=-1|(4bABk3ZV~kb)q>oiD9cz`_o@s7e#aQlv%pjWKmBxcM+bfm&_SI5^U#}xr*RS@YQq{
zbbK@i{%C!Is;in$A0~Ve&Cv(zqGZ7}6CUTZQdlig+#TI1&Cyq8>Ex3oV_g4aBLn(d
znuG7p&qMo6S4nxP(^XzkGDS9*vDiZO#zYSKtd9ijyf_@;2vp`diZ@5^Cj=r!=|q4C
z5CI}U1c(3;AOb|7;S#WV^WcfT9C>%8=D;3T>QB6$WKf@{ynP!!7#Wv=jx5%1~E4VQoS8$_j426hn3`aaL9mh9N
zHbyX{zA-`}WgLP)eH>~)*%(HU`o;i(l+pi3eOwWavawZ!)Hk+(kTR|qM}1r|jxr7)
zAYWq}2bmvNl%u|}CxcAK73P?3Y_K5HamBfjE?1nRY%CX%`o@Y7DI3c@q--qWkTNdR
zKzqh244IBgEHE8cP@s%!Bv4*fD!w5PT)~d%xPl#JT)~bqu3$$QSFoduE7(!S73?U-
zE}hxOI+h3THe4CD|6~ER_U#E@`J7i9h0}7gh|9T(~4i9W%z{IWu@B)nd!lKFx
zz(s~lHStw=&Cx_;U@^*;o5on01K!@zgVZqI7coZH9JoV52Vuj|sD$F>@{*K$oTWJU
zraAa88}iUZS2qUX9O&FM=dYz0dw&{(a*k#q8`IN`lNf_^@Ozgcw6na{VkrcU;AxKL
zQaK!yjd46j3mF?JhZ1gnRzzry7@5+^V%%q(1BU=L2j6;v2k&Yp%VwoGHmEty6?*J6
z_f=Or3p}pr&RtiKe=ngl=tFJ2|MuklH|X}K!V>`^Km>>Y5g-CYfCvx)B0vO)z#ox7
zZASx!#Ne=>poha{7>i>@I2!jyWK4A?0z`la5CI}U1c(3;AOekrz>LWSd2vlLVvar5
z?d}M+DLv|~>cn8}`!-$or1sE+G?eK`1c(3;AOb{y2oM1xKm>@u??<4vvjf>QJMjB8
z9-Wy85CI}U1c(3;AOb{y2%KU9=3TsUBcL3GQxOMPs+H97wY3KUbw3h?&5_q?&5{|xQiFcxQiF+
z<1St(<1SvPkGpuGjJtTDKJMa$GVbDq`nZc1%D9Uc>f9~s*+QVJEFdcXC!gSol
z3uWBJ3uWB@3iBIVTgeDx7b__n`&CKV*mO$D#?DewHa3WovauPIl&?H*Xm)K!3ntTO
afuu 0
+ else:
+ fig, ax = plotting.with_matplotlib(data, mode=mode)
+ assert fig is not None and ax is not None
+
+
+@pytest.mark.parametrize(
+ 'engine,data_type', [(e, dt) for e in ['plotly', 'matplotlib'] for dt in ['dataset', 'dataframe', 'series']]
+)
+def test_pie_plots(engine, data_type):
+ """Test pie charts with all data types, including automatic summing."""
+ time = pd.date_range('2020-01-01', periods=5, freq='h')
+
+ # Single-value data
+ single_data = {
+ 'dataset': xr.Dataset({'A': xr.DataArray(10), 'B': xr.DataArray(20), 'C': xr.DataArray(30)}),
+ 'dataframe': pd.DataFrame({'A': [10], 'B': [20], 'C': [30]}),
+ 'series': pd.Series({'A': 10, 'B': 20, 'C': 30}),
+ }[data_type]
+
+ # Multi-dimensional data (for summing test)
+ multi_data = {
+ 'dataset': xr.Dataset(
+ {'A': (['time'], [1, 2, 3, 4, 5]), 'B': (['time'], [5, 5, 5, 5, 5])}, coords={'time': time}
+ ),
+ 'dataframe': pd.DataFrame({'A': [1, 2, 3, 4, 5], 'B': [5, 5, 5, 5, 5]}, index=time),
+ 'series': pd.Series([1, 2, 3, 4, 5], index=time, name='A'),
+ }[data_type]
+
+ for data in [single_data, multi_data]:
+ if engine == 'plotly':
+ fig = plotting.dual_pie_with_plotly(data, data)
+ assert fig is not None and len(fig.data) >= 2
+ if data is multi_data and data_type != 'series':
+ assert sum(fig.data[0].values) == pytest.approx(40)
+ else:
+ fig, axes = plotting.dual_pie_with_matplotlib(data, data)
+ assert fig is not None and len(axes) == 2
diff --git a/tests/test_results_plots.py b/tests/test_results_plots.py
index 35a219e31..a656f7c44 100644
--- a/tests/test_results_plots.py
+++ b/tests/test_results_plots.py
@@ -28,7 +28,7 @@ def plotting_engine(request):
@pytest.fixture(
params=[
- 'viridis', # Test string colormap
+ 'turbo', # Test string colormap
['#ff0000', '#00ff00', '#0000ff', '#ffff00', '#ff00ff', '#00ffff'], # Test color list
{
'Boiler(Q_th)|flow_rate': '#ff0000',
@@ -48,18 +48,29 @@ def test_results_plots(flow_system, plotting_engine, show, save, color_spec):
results['Boiler'].plot_node_balance(engine=plotting_engine, save=save, show=show, colors=color_spec)
- results.plot_heatmap(
- 'Speicher(Q_th_load)|flow_rate',
- heatmap_timeframes='D',
- heatmap_timesteps_per_frame='h',
- color_map='viridis', # Note: heatmap only accepts string colormap
- save=show,
- show=save,
- engine=plotting_engine,
- )
+ # Matplotlib doesn't support faceting/animation, so disable them for matplotlib engine
+ heatmap_kwargs = {
+ 'reshape_time': ('D', 'h'),
+ 'colors': 'turbo', # Note: heatmap only accepts string colormap
+ 'save': save,
+ 'show': show,
+ 'engine': plotting_engine,
+ }
+ if plotting_engine == 'matplotlib':
+ heatmap_kwargs['facet_by'] = None
+ heatmap_kwargs['animate_by'] = None
+
+ results.plot_heatmap('Speicher(Q_th_load)|flow_rate', **heatmap_kwargs)
results['Speicher'].plot_node_balance_pie(engine=plotting_engine, save=save, show=show, colors=color_spec)
- results['Speicher'].plot_charge_state(engine=plotting_engine)
+
+    # Matplotlib doesn't support faceting/animation for plot_charge_state, nor the 'area' mode
+ charge_state_kwargs = {'engine': plotting_engine}
+ if plotting_engine == 'matplotlib':
+ charge_state_kwargs['facet_by'] = None
+ charge_state_kwargs['animate_by'] = None
+ charge_state_kwargs['mode'] = 'stacked_bar' # 'area' not supported by matplotlib
+ results['Speicher'].plot_charge_state(**charge_state_kwargs)
plt.close('all')
From 03b12028bf5c47df81ebf45dd6ac424fee4f8c9f Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Thu, 30 Oct 2025 13:52:16 +0100
Subject: [PATCH 04/27] Merge main into feature/402-feature-silent-framework
---
examples/04_Scenarios/scenario_example.py | 139 +-
flixopt/components.py | 32 +-
flixopt/config.py | 63 +-
flixopt/elements.py | 59 +-
flixopt/plotting.py | 2134 +++++++++++----------
flixopt/results.py | 1383 ++++++++++---
6 files changed, 2549 insertions(+), 1261 deletions(-)
diff --git a/examples/04_Scenarios/scenario_example.py b/examples/04_Scenarios/scenario_example.py
index f06760603..d258d4142 100644
--- a/examples/04_Scenarios/scenario_example.py
+++ b/examples/04_Scenarios/scenario_example.py
@@ -8,20 +8,80 @@
import flixopt as fx
if __name__ == '__main__':
- # Create datetime array starting from '2020-01-01' for the given time period
- timesteps = pd.date_range('2020-01-01', periods=9, freq='h')
+ # Create datetime array starting from '2020-01-01' for one week
+ timesteps = pd.date_range('2020-01-01', periods=24 * 7, freq='h')
scenarios = pd.Index(['Base Case', 'High Demand'])
periods = pd.Index([2020, 2021, 2022])
# --- Create Time Series Data ---
- # Heat demand profile (e.g., kW) over time and corresponding power prices
- heat_demand_per_h = pd.DataFrame(
- {'Base Case': [30, 0, 90, 110, 110, 20, 20, 20, 20], 'High Demand': [30, 0, 100, 118, 125, 20, 20, 20, 20]},
- index=timesteps,
+ # Realistic daily patterns: morning/evening peaks, night/midday lows
+ np.random.seed(42)
+ n_hours = len(timesteps)
+
+ # Heat demand: 24-hour patterns (kW) for Base Case and High Demand scenarios
+ base_daily_pattern = np.array(
+ [22, 20, 18, 18, 20, 25, 40, 70, 95, 110, 85, 65, 60, 58, 62, 68, 75, 88, 105, 125, 130, 122, 95, 35]
+ )
+ high_daily_pattern = np.array(
+ [28, 25, 22, 22, 24, 30, 52, 88, 118, 135, 105, 80, 75, 72, 75, 82, 92, 108, 128, 148, 155, 145, 115, 48]
+ )
+
+ # Tile and add variation
+ base_demand = np.tile(base_daily_pattern, n_hours // 24 + 1)[:n_hours] * (
+ 1 + np.random.uniform(-0.05, 0.05, n_hours)
)
- power_prices = np.array([0.08, 0.09, 0.10])
+ high_demand = np.tile(high_daily_pattern, n_hours // 24 + 1)[:n_hours] * (
+ 1 + np.random.uniform(-0.07, 0.07, n_hours)
+ )
+
+ heat_demand_per_h = pd.DataFrame({'Base Case': base_demand, 'High Demand': high_demand}, index=timesteps)
+
+ # Power prices: hourly factors (night low, peak high) and period escalation (2020-2022)
+ hourly_price_factors = np.array(
+ [
+ 0.70,
+ 0.65,
+ 0.62,
+ 0.60,
+ 0.62,
+ 0.70,
+ 0.95,
+ 1.15,
+ 1.30,
+ 1.25,
+ 1.10,
+ 1.00,
+ 0.95,
+ 0.90,
+ 0.88,
+ 0.92,
+ 1.00,
+ 1.10,
+ 1.25,
+ 1.40,
+ 1.35,
+ 1.20,
+ 0.95,
+ 0.80,
+ ]
+ )
+ period_base_prices = np.array([0.075, 0.095, 0.135]) # β¬/kWh for 2020, 2021, 2022
+
+ price_series = np.zeros((n_hours, 3))
+ for period_idx, base_price in enumerate(period_base_prices):
+ price_series[:, period_idx] = (
+ np.tile(hourly_price_factors, n_hours // 24 + 1)[:n_hours]
+ * base_price
+ * (1 + np.random.uniform(-0.03, 0.03, n_hours))
+ )
- flow_system = fx.FlowSystem(timesteps=timesteps, periods=periods, scenarios=scenarios, weights=np.array([0.5, 0.6]))
+ power_prices = price_series.mean(axis=0)
+
+ # Scenario weights: probability of each scenario occurring
+ # Base Case: 60% probability, High Demand: 40% probability
+ scenario_weights = np.array([0.6, 0.4])
+
+ flow_system = fx.FlowSystem(timesteps=timesteps, periods=periods, scenarios=scenarios, weights=scenario_weights)
# --- Define Energy Buses ---
# These represent nodes, where the used medias are balanced (electricity, heat, and gas)
@@ -35,22 +95,24 @@
description='Kosten',
is_standard=True, # standard effect: no explicit value needed for costs
is_objective=True, # Minimizing costs as the optimization objective
- share_from_temporal={'CO2': 0.2},
+ share_from_temporal={'CO2': 0.2}, # Carbon price: 0.2 β¬/kg CO2 (e.g., carbon tax)
)
- # CO2 emissions effect with an associated cost impact
+ # CO2 emissions effect with constraint
+ # Maximum of 1000 kg CO2/hour represents a regulatory or voluntary emissions limit
CO2 = fx.Effect(
label='CO2',
unit='kg',
description='CO2_e-Emissionen',
- maximum_per_hour=1000, # Max CO2 emissions per hour
+ maximum_per_hour=1000, # Regulatory emissions limit: 1000 kg CO2/hour
)
# --- Define Flow System Components ---
# Boiler: Converts fuel (gas) into thermal energy (heat)
+ # Modern condensing gas boiler with realistic efficiency
boiler = fx.linear_converters.Boiler(
label='Boiler',
- eta=0.5,
+ eta=0.92, # Realistic efficiency for modern condensing gas boiler (92%)
Q_th=fx.Flow(
label='Q_th',
bus='FernwΓ€rme',
@@ -63,27 +125,28 @@
)
# Combined Heat and Power (CHP): Generates both electricity and heat from fuel
+ # Modern CHP unit with realistic efficiencies (total efficiency ~88%)
chp = fx.linear_converters.CHP(
label='CHP',
- eta_th=0.5,
- eta_el=0.4,
+ eta_th=0.48, # Realistic thermal efficiency (48%)
+ eta_el=0.40, # Realistic electrical efficiency (40%)
P_el=fx.Flow('P_el', bus='Strom', size=60, relative_minimum=5 / 60, on_off_parameters=fx.OnOffParameters()),
Q_th=fx.Flow('Q_th', bus='FernwΓ€rme'),
Q_fu=fx.Flow('Q_fu', bus='Gas'),
)
- # Storage: Energy storage system with charging and discharging capabilities
+ # Storage: Thermal energy storage system with charging and discharging capabilities
+ # Realistic thermal storage parameters (e.g., insulated hot water tank)
storage = fx.Storage(
label='Storage',
charging=fx.Flow('Q_th_load', bus='FernwΓ€rme', size=1000),
discharging=fx.Flow('Q_th_unload', bus='FernwΓ€rme', size=1000),
capacity_in_flow_hours=fx.InvestParameters(effects_of_investment=20, fixed_size=30, mandatory=True),
initial_charge_state=0, # Initial storage state: empty
- relative_maximum_charge_state=np.array([80, 70, 80, 80, 80, 80, 80, 80, 80]) * 0.01,
- relative_maximum_final_charge_state=0.8,
- eta_charge=0.9,
- eta_discharge=1, # Efficiency factors for charging/discharging
- relative_loss_per_hour=0.08, # 8% loss per hour. Absolute loss depends on current charge state
+ relative_maximum_final_charge_state=np.array([0.8, 0.5, 0.1]),
+ eta_charge=0.95, # Realistic charging efficiency (~95%)
+ eta_discharge=0.98, # Realistic discharging efficiency (~98%)
+ relative_loss_per_hour=np.array([0.008, 0.015]), # Realistic thermal losses: 0.8-1.5% per hour
prevent_simultaneous_charge_and_discharge=True, # Prevent charging and discharging at the same time
)
@@ -94,10 +157,22 @@
)
# Gas Source: Gas tariff source with associated costs and CO2 emissions
+ # Realistic gas prices varying by period (reflecting 2020-2022 energy crisis)
+ # 2020: 0.04 β¬/kWh, 2021: 0.06 β¬/kWh, 2022: 0.11 β¬/kWh
+ gas_prices_per_period = np.array([0.04, 0.06, 0.11])
+
+ # CO2 emissions factor for natural gas: ~0.202 kg CO2/kWh (realistic value)
+ gas_co2_emissions = 0.202
+
gas_source = fx.Source(
label='Gastarif',
outputs=[
- fx.Flow(label='Q_Gas', bus='Gas', size=1000, effects_per_flow_hour={costs.label: 0.04, CO2.label: 0.3})
+ fx.Flow(
+ label='Q_Gas',
+ bus='Gas',
+ size=1000,
+ effects_per_flow_hour={costs.label: gas_prices_per_period, CO2.label: gas_co2_emissions},
+ )
],
)
@@ -121,24 +196,26 @@
# --- Solve the Calculation and Save Results ---
calculation.solve(fx.solvers.HighsSolver(mip_gap=0, time_limit_seconds=30))
+ calculation.results.setup_colors(
+ {
+ 'CHP': 'red',
+ 'Greys': ['Gastarif', 'Einspeisung', 'Heat Demand'],
+ 'Storage': 'blue',
+ 'Boiler': 'orange',
+ }
+ )
+
calculation.results.plot_heatmap('CHP(Q_th)|flow_rate')
# --- Analyze Results ---
- calculation.results['FernwΓ€rme'].plot_node_balance_pie()
- calculation.results['FernwΓ€rme'].plot_node_balance(style='stacked_bar')
- calculation.results['Storage'].plot_node_balance()
+ calculation.results['FernwΓ€rme'].plot_node_balance(mode='stacked_bar')
calculation.results.plot_heatmap('CHP(Q_th)|flow_rate')
+ calculation.results['Storage'].plot_charge_state()
+ calculation.results['FernwΓ€rme'].plot_node_balance_pie(select={'period': 2020, 'scenario': 'Base Case'})
# Convert the results for the storage component to a dataframe and display
df = calculation.results['Storage'].node_balance_with_charge_state()
print(df)
- # Plot charge state using matplotlib
- fig, ax = calculation.results['Storage'].plot_charge_state(engine='matplotlib')
- # Customize the plot further if needed
- ax.set_title('Storage Charge State Over Time')
- # Or save the figure
- # fig.savefig('storage_charge_state.png')
-
# Save results to file for later usage
calculation.results.to_file()
diff --git a/flixopt/components.py b/flixopt/components.py
index c40e6af88..8f89378ae 100644
--- a/flixopt/components.py
+++ b/flixopt/components.py
@@ -11,6 +11,7 @@
import numpy as np
import xarray as xr
+from . import io as fx_io
from .core import PeriodicDataUser, PlausibilityError, TemporalData, TemporalDataUser
from .elements import Component, ComponentModel, Flow
from .features import InvestmentModel, PiecewiseModel
@@ -528,6 +529,15 @@ def _plausibility_checks(self) -> None:
f'{self.discharging.size.minimum_size=}, {self.discharging.size.maximum_size=}.'
)
+ def __repr__(self) -> str:
+ """Return string representation."""
+ # Use build_repr_from_init directly to exclude charging and discharging
+ return fx_io.build_repr_from_init(
+ self,
+ excluded_params={'self', 'label', 'charging', 'discharging', 'kwargs'},
+ skip_default_size=True,
+ ) + fx_io.format_flow_details(self)
+
@register_class_for_io
class Transmission(Component):
@@ -1304,16 +1314,18 @@ def __init__(
prevent_simultaneous_flow_rates: bool = False,
**kwargs,
):
- """
- Initialize a Sink (consumes flow from the system).
-
- Supports legacy `sink=` keyword for backward compatibility (deprecated): if `sink` is provided it is used as the single input flow and a DeprecationWarning is issued; specifying both `inputs` and `sink` raises ValueError.
-
- Parameters:
- label (str): Unique element label.
- inputs (list[Flow], optional): Input flows for the sink.
- meta_data (dict, optional): Arbitrary metadata attached to the element.
- prevent_simultaneous_flow_rates (bool, optional): If True, prevents simultaneous nonzero flow rates across the element's inputs by wiring that restriction into the base Component setup.
+ """Initialize a Sink (consumes flow from the system).
+
+ Supports legacy `sink=` keyword for backward compatibility (deprecated): if `sink` is provided
+ it is used as the single input flow and a DeprecationWarning is issued; specifying both
+ `inputs` and `sink` raises ValueError.
+
+ Args:
+ label: Unique element label.
+ inputs: Input flows for the sink.
+ meta_data: Arbitrary metadata attached to the element.
+ prevent_simultaneous_flow_rates: If True, prevents simultaneous nonzero flow rates
+ across the element's inputs by wiring that restriction into the base Component setup.
Note:
The deprecated `sink` kwarg is accepted for compatibility but will be removed in future releases.
diff --git a/flixopt/config.py b/flixopt/config.py
index a7549a3ec..670f86da2 100644
--- a/flixopt/config.py
+++ b/flixopt/config.py
@@ -8,7 +8,6 @@
from types import MappingProxyType
from typing import Literal
-import yaml
from rich.console import Console
from rich.logging import RichHandler
from rich.style import Style
@@ -54,6 +53,16 @@
'big_binary_bound': 100_000,
}
),
+ 'plotting': MappingProxyType(
+ {
+ 'default_show': True,
+ 'default_engine': 'plotly',
+ 'default_dpi': 300,
+ 'default_facet_cols': 3,
+ 'default_sequential_colorscale': 'turbo',
+ 'default_qualitative_colorscale': 'plotly',
+ }
+ ),
}
)
@@ -185,6 +194,42 @@ class Modeling:
epsilon: float = _DEFAULTS['modeling']['epsilon']
big_binary_bound: int = _DEFAULTS['modeling']['big_binary_bound']
+ class Plotting:
+ """Plotting configuration.
+
+ Configure backends via environment variables:
+ - Matplotlib: Set `MPLBACKEND` environment variable (e.g., 'Agg', 'TkAgg')
+ - Plotly: Set `PLOTLY_RENDERER` or use `plotly.io.renderers.default`
+
+ Attributes:
+ default_show: Default value for the `show` parameter in plot methods.
+ default_engine: Default plotting engine.
+ default_dpi: Default DPI for saved plots.
+ default_facet_cols: Default number of columns for faceted plots.
+ default_sequential_colorscale: Default colorscale for heatmaps and continuous data.
+ default_qualitative_colorscale: Default colormap for categorical plots (bar/line/area charts).
+
+ Examples:
+ ```python
+ # Set consistent theming
+ CONFIG.Plotting.plotly_template = 'plotly_dark'
+ CONFIG.apply()
+
+ # Configure default export and color settings
+ CONFIG.Plotting.default_dpi = 600
+ CONFIG.Plotting.default_sequential_colorscale = 'plasma'
+ CONFIG.Plotting.default_qualitative_colorscale = 'Dark24'
+ CONFIG.apply()
+ ```
+ """
+
+ default_show: bool = _DEFAULTS['plotting']['default_show']
+ default_engine: Literal['plotly', 'matplotlib'] = _DEFAULTS['plotting']['default_engine']
+ default_dpi: int = _DEFAULTS['plotting']['default_dpi']
+ default_facet_cols: int = _DEFAULTS['plotting']['default_facet_cols']
+ default_sequential_colorscale: str = _DEFAULTS['plotting']['default_sequential_colorscale']
+ default_qualitative_colorscale: str = _DEFAULTS['plotting']['default_qualitative_colorscale']
+
config_name: str = _DEFAULTS['config_name']
@classmethod
@@ -253,13 +298,15 @@ def load_from_file(cls, config_file: str | Path):
Raises:
FileNotFoundError: If the config file does not exist.
"""
+ # Import here to avoid circular import
+ from . import io as fx_io
+
config_path = Path(config_file)
if not config_path.exists():
raise FileNotFoundError(f'Config file not found: {config_file}')
- with config_path.open() as file:
- config_dict = yaml.safe_load(file) or {}
- cls._apply_config_dict(config_dict)
+ config_dict = fx_io.load_yaml(config_path)
+ cls._apply_config_dict(config_dict)
cls.apply()
@@ -319,6 +366,14 @@ def to_dict(cls) -> dict:
'epsilon': cls.Modeling.epsilon,
'big_binary_bound': cls.Modeling.big_binary_bound,
},
+ 'plotting': {
+ 'default_show': cls.Plotting.default_show,
+ 'default_engine': cls.Plotting.default_engine,
+ 'default_dpi': cls.Plotting.default_dpi,
+ 'default_facet_cols': cls.Plotting.default_facet_cols,
+ 'default_sequential_colorscale': cls.Plotting.default_sequential_colorscale,
+ 'default_qualitative_colorscale': cls.Plotting.default_qualitative_colorscale,
+ },
}
diff --git a/flixopt/elements.py b/flixopt/elements.py
index 25e399811..2a9a2cf4f 100644
--- a/flixopt/elements.py
+++ b/flixopt/elements.py
@@ -11,6 +11,7 @@
import numpy as np
import xarray as xr
+from . import io as fx_io
from .config import CONFIG
from .core import PlausibilityError, Scalar, TemporalData, TemporalDataUser
from .features import InvestmentModel, OnOffModel
@@ -86,10 +87,12 @@ def __init__(
super().__init__(label, meta_data=meta_data)
self.inputs: list[Flow] = inputs or []
self.outputs: list[Flow] = outputs or []
- self._check_unique_flow_labels()
self.on_off_parameters = on_off_parameters
self.prevent_simultaneous_flows: list[Flow] = prevent_simultaneous_flows or []
+ self._check_unique_flow_labels()
+ self._connect_flows()
+
self.flows: dict[str, Flow] = {flow.label: flow for flow in self.inputs + self.outputs}
def create_model(self, model: FlowSystemModel) -> ComponentModel:
@@ -115,6 +118,48 @@ def _check_unique_flow_labels(self):
def _plausibility_checks(self) -> None:
self._check_unique_flow_labels()
+ def _connect_flows(self):
+ # Inputs
+ for flow in self.inputs:
+ if flow.component not in ('UnknownComponent', self.label_full):
+ raise ValueError(
+ f'Flow "{flow.label}" already assigned to component "{flow.component}". '
+ f'Cannot attach to "{self.label_full}".'
+ )
+ flow.component = self.label_full
+ flow.is_input_in_component = True
+ # Outputs
+ for flow in self.outputs:
+ if flow.component not in ('UnknownComponent', self.label_full):
+ raise ValueError(
+ f'Flow "{flow.label}" already assigned to component "{flow.component}". '
+ f'Cannot attach to "{self.label_full}".'
+ )
+ flow.component = self.label_full
+ flow.is_input_in_component = False
+
+ # Validate prevent_simultaneous_flows: only allow local flows
+ if self.prevent_simultaneous_flows:
+ # Deduplicate while preserving order
+ seen = set()
+ self.prevent_simultaneous_flows = [
+ f for f in self.prevent_simultaneous_flows if id(f) not in seen and not seen.add(id(f))
+ ]
+ local = set(self.inputs + self.outputs)
+ foreign = [f for f in self.prevent_simultaneous_flows if f not in local]
+ if foreign:
+ names = ', '.join(f.label_full for f in foreign)
+ raise ValueError(
+ f'prevent_simultaneous_flows for "{self.label_full}" must reference its own flows. '
+ f'Foreign flows detected: {names}'
+ )
+
+ def __repr__(self) -> str:
+ """Return string representation with flow information."""
+ return fx_io.build_repr_from_init(
+ self, excluded_params={'self', 'label', 'inputs', 'outputs', 'kwargs'}, skip_default_size=True
+ ) + fx_io.format_flow_details(self)
+
@register_class_for_io
class Bus(Element):
@@ -207,11 +252,19 @@ def _plausibility_checks(self) -> None:
logger.warning(
f'In Bus {self.label_full}, the excess_penalty_per_flow_hour is 0. Use "None" or a value > 0.'
)
+ if len(self.inputs) == 0 and len(self.outputs) == 0:
+ raise ValueError(
+ f'Bus "{self.label_full}" has no Flows connected to it. Please remove it from the FlowSystem'
+ )
@property
def with_excess(self) -> bool:
return False if self.excess_penalty_per_flow_hour is None else True
+ def __repr__(self) -> str:
+ """Return string representation."""
+ return super().__repr__() + fx_io.format_flow_details(self)
+
@register_class_for_io
class Connection:
@@ -489,6 +542,10 @@ def size_is_fixed(self) -> bool:
# Wenn kein InvestParameters existiert --> True; Wenn Investparameter, den Wert davon nehmen
return False if (isinstance(self.size, InvestParameters) and self.size.fixed_size is None) else True
+ def _format_invest_params(self, params: InvestParameters) -> str:
+ """Format InvestParameters for display."""
+ return f'size: {params.format_for_repr()}'
+
class FlowModel(ElementModel):
element: Flow # Type hint
diff --git a/flixopt/plotting.py b/flixopt/plotting.py
index 356f013c0..045cf7e99 100644
--- a/flixopt/plotting.py
+++ b/flixopt/plotting.py
@@ -39,14 +39,17 @@
import plotly.express as px
import plotly.graph_objects as go
import plotly.offline
-from plotly.exceptions import PlotlyError
+import xarray as xr
+
+from .color_processing import process_colors
+from .config import CONFIG
if TYPE_CHECKING:
import pyvis
logger = logging.getLogger('flixopt')
-# Define the colors for the 'portland' colormap in matplotlib
+# Define the colors for the 'portland' colorscale in matplotlib
_portland_colors = [
[12 / 255, 51 / 255, 131 / 255], # Dark blue
[10 / 255, 136 / 255, 186 / 255], # Light blue
@@ -55,7 +58,7 @@
[217 / 255, 30 / 255, 30 / 255], # Red
]
-# Check if the colormap already exists before registering it
+# Check if the colorscale already exists before registering it
if hasattr(plt, 'colormaps'): # Matplotlib >= 3.7
registry = plt.colormaps
if 'portland' not in registry:
@@ -70,9 +73,9 @@
Color specifications can take several forms to accommodate different use cases:
-**Named Colormaps** (str):
- - Standard colormaps: 'viridis', 'plasma', 'cividis', 'tab10', 'Set1'
- - Energy-focused: 'portland' (custom flixopt colormap for energy systems)
+**Named colorscales** (str):
+ - Standard colorscales: 'turbo', 'plasma', 'cividis', 'tab10', 'Set1'
+ - Energy-focused: 'portland' (custom flixopt colorscale for energy systems)
- Backend-specific maps available in Plotly and Matplotlib
**Color Lists** (list[str]):
@@ -87,8 +90,8 @@
Examples:
```python
- # Named colormap
- colors = 'viridis' # Automatic color generation
+ # Named colorscale
+ colors = 'turbo' # Automatic color generation
# Explicit color list
colors = ['red', 'blue', 'green', '#FFD700']
@@ -111,7 +114,7 @@
References:
- HTML Color Names: https://htmlcolorcodes.com/color-names/
- - Matplotlib Colormaps: https://matplotlib.org/stable/tutorials/colors/colormaps.html
+        - Matplotlib colorscales: https://matplotlib.org/stable/tutorials/colors/colormaps.html
- Plotly Built-in Colorscales: https://plotly.com/python/builtin-colorscales/
"""
@@ -119,432 +122,520 @@
"""Identifier for the plotting engine to use."""
-class ColorProcessor:
- """Intelligent color management system for consistent multi-backend visualization.
+def _ensure_dataset(data: xr.Dataset | pd.DataFrame | pd.Series) -> xr.Dataset:
+ """Convert DataFrame or Series to Dataset if needed."""
+ if isinstance(data, xr.Dataset):
+ return data
+ elif isinstance(data, pd.DataFrame):
+ # Convert DataFrame to Dataset
+ return data.to_xarray()
+ elif isinstance(data, pd.Series):
+ # Convert Series to DataFrame first, then to Dataset
+ return data.to_frame().to_xarray()
+ else:
+ raise TypeError(f'Data must be xr.Dataset, pd.DataFrame, or pd.Series, got {type(data).__name__}')
- This class provides unified color processing across Plotly and Matplotlib backends,
- ensuring consistent visual appearance regardless of the plotting engine used.
- It handles color palette generation, named colormap translation, and intelligent
- color cycling for complex datasets with many categories.
- Key Features:
- **Backend Agnostic**: Automatic color format conversion between engines
- **Palette Management**: Support for named colormaps, custom palettes, and color lists
- **Intelligent Cycling**: Smart color assignment for datasets with many categories
- **Fallback Handling**: Graceful degradation when requested colormaps are unavailable
- **Energy System Colors**: Built-in palettes optimized for energy system visualization
+def _validate_plotting_data(data: xr.Dataset, allow_empty: bool = False) -> None:
+ """Validate dataset for plotting (checks for empty data, non-numeric types, etc.)."""
+ # Check for empty data
+ if not allow_empty and len(data.data_vars) == 0:
+ raise ValueError('Empty Dataset provided (no variables). Cannot create plot.')
+
+ # Check if dataset has any data (xarray uses nbytes for total size)
+ if all(data[var].size == 0 for var in data.data_vars) if len(data.data_vars) > 0 else True:
+ if not allow_empty and len(data.data_vars) > 0:
+ raise ValueError('Dataset has zero size. Cannot create plot.')
+ if len(data.data_vars) == 0:
+ return # Empty dataset, nothing to validate
+ return
+
+ # Check for non-numeric data types
+ for var in data.data_vars:
+ dtype = data[var].dtype
+ if not np.issubdtype(dtype, np.number):
+ raise TypeError(
+ f"Variable '{var}' has non-numeric dtype '{dtype}'. "
+ f'Plotting requires numeric data types (int, float, etc.).'
+ )
- Color Input Types:
- - **Named Colormaps**: 'viridis', 'plasma', 'portland', 'tab10', etc.
- - **Color Lists**: ['red', 'blue', 'green'] or ['#FF0000', '#0000FF', '#00FF00']
- - **Label Dictionaries**: {'Generator': 'red', 'Storage': 'blue', 'Load': 'green'}
+ # Warn about NaN/Inf values
+ for var in data.data_vars:
+ if np.isnan(data[var].values).any():
+ logger.debug(f"Variable '{var}' contains NaN values which may affect visualization.")
+ if np.isinf(data[var].values).any():
+ logger.debug(f"Variable '{var}' contains Inf values which may affect visualization.")
- Examples:
- Basic color processing:
- ```python
- # Initialize for Plotly backend
- processor = ColorProcessor(engine='plotly', default_colormap='viridis')
+def with_plotly(
+ data: xr.Dataset | pd.DataFrame | pd.Series,
+ mode: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar',
+ colors: ColorType | None = None,
+ title: str = '',
+ ylabel: str = '',
+ xlabel: str = '',
+ facet_by: str | list[str] | None = None,
+ animate_by: str | None = None,
+ facet_cols: int | None = None,
+ shared_yaxes: bool = True,
+ shared_xaxes: bool = True,
+ **px_kwargs: Any,
+) -> go.Figure:
+ """
+ Plot data with Plotly using facets (subplots) and/or animation for multidimensional data.
- # Process different color specifications
- colors = processor.process_colors('plasma', ['Gen1', 'Gen2', 'Storage'])
- colors = processor.process_colors(['red', 'blue', 'green'], ['A', 'B', 'C'])
- colors = processor.process_colors({'Wind': 'skyblue', 'Solar': 'gold'}, ['Wind', 'Solar', 'Gas'])
+ Uses Plotly Express for convenient faceting and animation with automatic styling.
- # Switch to Matplotlib
- processor = ColorProcessor(engine='matplotlib')
- mpl_colors = processor.process_colors('tab10', component_labels)
- ```
+ Args:
+ data: An xarray Dataset, pandas DataFrame, or pandas Series to plot.
+ mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for lines,
+ 'area' for stacked area charts, or 'grouped_bar' for grouped bar charts.
+ colors: Color specification (colorscale, list, or dict mapping labels to colors).
+ title: The main title of the plot.
+ ylabel: The label for the y-axis.
+ xlabel: The label for the x-axis.
+ facet_by: Dimension(s) to create facets for. Creates a subplot grid.
+ Can be a single dimension name or list of dimensions (max 2 for facet_row and facet_col).
+ If the dimension doesn't exist in the data, it will be silently ignored.
+ animate_by: Dimension to animate over. Creates animation frames.
+ If the dimension doesn't exist in the data, it will be silently ignored.
+ facet_cols: Number of columns in the facet grid (used when facet_by is single dimension).
+ shared_yaxes: Whether subplots share y-axes.
+ shared_xaxes: Whether subplots share x-axes.
+ **px_kwargs: Additional keyword arguments passed to the underlying Plotly Express function
+ (px.bar, px.line, px.area). These override default arguments if provided.
+ Examples: range_x=[0, 100], range_y=[0, 50], category_orders={...}, line_shape='linear'
- Energy system visualization:
+ Returns:
+ A Plotly figure object containing the faceted/animated plot. You can further customize
+ the returned figure using Plotly's methods (e.g., fig.update_traces(), fig.update_layout()).
+
+ Examples:
+ Simple plot:
```python
- # Specialized energy system palette
- energy_colors = {
- 'Natural_Gas': '#8B4513', # Brown
- 'Electricity': '#FFD700', # Gold
- 'Heat': '#FF4500', # Red-orange
- 'Cooling': '#87CEEB', # Sky blue
- 'Hydrogen': '#E6E6FA', # Lavender
- 'Battery': '#32CD32', # Lime green
- }
+ fig = with_plotly(dataset, mode='area', title='Energy Mix')
+ ```
+
+ Facet by scenario:
- processor = ColorProcessor('plotly')
- flow_colors = processor.process_colors(energy_colors, flow_labels)
+ ```python
+ fig = with_plotly(dataset, facet_by='scenario', facet_cols=2)
```
- Args:
- engine: Plotting backend ('plotly' or 'matplotlib'). Determines output color format.
- default_colormap: Fallback colormap when requested palettes are unavailable.
- Common options: 'viridis', 'plasma', 'tab10', 'portland'.
+ Animate by period:
- """
+ ```python
+ fig = with_plotly(dataset, animate_by='period')
+ ```
- def __init__(self, engine: PlottingEngine = 'plotly', default_colormap: str = 'viridis'):
- """Initialize the color processor with specified backend and defaults."""
- if engine not in ['plotly', 'matplotlib']:
- raise TypeError(f'engine must be "plotly" or "matplotlib", but is {engine}')
- self.engine = engine
- self.default_colormap = default_colormap
-
- def _generate_colors_from_colormap(self, colormap_name: str, num_colors: int) -> list[Any]:
- """
- Generate colors from a named colormap.
-
- Args:
- colormap_name: Name of the colormap
- num_colors: Number of colors to generate
-
- Returns:
- list of colors in the format appropriate for the engine
- """
- if self.engine == 'plotly':
- try:
- colorscale = px.colors.get_colorscale(colormap_name)
- except PlotlyError as e:
- logger.error(f"Colorscale '{colormap_name}' not found in Plotly. Using {self.default_colormap}: {e}")
- colorscale = px.colors.get_colorscale(self.default_colormap)
-
- # Generate evenly spaced points
- color_points = [i / (num_colors - 1) for i in range(num_colors)] if num_colors > 1 else [0]
- return px.colors.sample_colorscale(colorscale, color_points)
-
- else: # matplotlib
- try:
- cmap = plt.get_cmap(colormap_name, num_colors)
- except ValueError as e:
- logger.error(f"Colormap '{colormap_name}' not found in Matplotlib. Using {self.default_colormap}: {e}")
- cmap = plt.get_cmap(self.default_colormap, num_colors)
-
- return [cmap(i) for i in range(num_colors)]
-
- def _handle_color_list(self, colors: list[str], num_labels: int) -> list[str]:
- """
- Handle a list of colors, cycling if necessary.
-
- Args:
- colors: list of color strings
- num_labels: Number of labels that need colors
-
- Returns:
- list of colors matching the number of labels
- """
- if len(colors) == 0:
- logger.error(f'Empty color list provided. Using {self.default_colormap} instead.')
- return self._generate_colors_from_colormap(self.default_colormap, num_labels)
-
- if len(colors) < num_labels:
- logger.warning(
- f'Not enough colors provided ({len(colors)}) for all labels ({num_labels}). Colors will cycle.'
- )
- # Cycle through the colors
- color_iter = itertools.cycle(colors)
- return [next(color_iter) for _ in range(num_labels)]
- else:
- # Trim if necessary
- if len(colors) > num_labels:
- logger.warning(
- f'More colors provided ({len(colors)}) than labels ({num_labels}). Extra colors will be ignored.'
- )
- return colors[:num_labels]
-
- def _handle_color_dict(self, colors: dict[str, str], labels: list[str]) -> list[str]:
- """
- Handle a dictionary mapping labels to colors.
-
- Args:
- colors: Dictionary mapping labels to colors
- labels: list of labels that need colors
-
- Returns:
- list of colors in the same order as labels
- """
- if len(colors) == 0:
- logger.warning(f'Empty color dictionary provided. Using {self.default_colormap} instead.')
- return self._generate_colors_from_colormap(self.default_colormap, len(labels))
-
- # Find missing labels
- missing_labels = sorted(set(labels) - set(colors.keys()))
- if missing_labels:
- logger.warning(
- f'Some labels have no color specified: {missing_labels}. Using {self.default_colormap} for these.'
- )
+ Facet and animate:
- # Generate colors for missing labels
- missing_colors = self._generate_colors_from_colormap(self.default_colormap, len(missing_labels))
+ ```python
+ fig = with_plotly(dataset, facet_by='scenario', animate_by='period')
+ ```
- # Create a copy to avoid modifying the original
- colors_copy = colors.copy()
- for i, label in enumerate(missing_labels):
- colors_copy[label] = missing_colors[i]
- else:
- colors_copy = colors
-
- # Create color list in the same order as labels
- return [colors_copy[label] for label in labels]
-
- def process_colors(
- self,
- colors: ColorType,
- labels: list[str],
- return_mapping: bool = False,
- ) -> list[Any] | dict[str, Any]:
- """
- Process colors for the specified labels.
-
- Args:
- colors: Color specification (colormap name, list of colors, or label-to-color mapping)
- labels: list of data labels that need colors assigned
- return_mapping: If True, returns a dictionary mapping labels to colors;
- if False, returns a list of colors in the same order as labels
-
- Returns:
- Either a list of colors or a dictionary mapping labels to colors
- """
- if len(labels) == 0:
- logger.error('No labels provided for color assignment.')
- return {} if return_mapping else []
-
- # Process based on type of colors input
- if isinstance(colors, str):
- color_list = self._generate_colors_from_colormap(colors, len(labels))
- elif isinstance(colors, list):
- color_list = self._handle_color_list(colors, len(labels))
- elif isinstance(colors, dict):
- color_list = self._handle_color_dict(colors, labels)
- else:
- logger.error(
- f'Unsupported color specification type: {type(colors)}. Using {self.default_colormap} instead.'
- )
- color_list = self._generate_colors_from_colormap(self.default_colormap, len(labels))
+ Customize with Plotly Express kwargs:
- # Return either a list or a mapping
- if return_mapping:
- return {label: color_list[i] for i, label in enumerate(labels)}
- else:
- return color_list
+ ```python
+ fig = with_plotly(dataset, range_y=[0, 100], line_shape='linear')
+ ```
+ Further customize the returned figure:
-def with_plotly(
- data: pd.DataFrame,
- style: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar',
- colors: ColorType = 'viridis',
- title: str = '',
- ylabel: str = '',
- xlabel: str = 'Time in h',
- fig: go.Figure | None = None,
-) -> go.Figure:
+ ```python
+ fig = with_plotly(dataset, mode='line')
+ fig.update_traces(line={'width': 5, 'dash': 'dot'})
+ fig.update_layout(template='plotly_dark', width=1200, height=600)
+ ```
"""
- Plot a DataFrame with Plotly, using either stacked bars or stepped lines.
+ if colors is None:
+ colors = CONFIG.Plotting.default_qualitative_colorscale
- Args:
- data: A DataFrame containing the data to plot, where the index represents time (e.g., hours),
- and each column represents a separate data series.
- style: The plotting style. Use 'stacked_bar' for stacked bar charts, 'line' for stepped lines,
- or 'area' for stacked area charts.
- colors: Color specification, can be:
- - A string with a colorscale name (e.g., 'viridis', 'plasma')
- - A list of color strings (e.g., ['#ff0000', '#00ff00'])
- - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
- title: The title of the plot.
- ylabel: The label for the y-axis.
- xlabel: The label for the x-axis.
- fig: A Plotly figure object to plot on. If not provided, a new figure will be created.
+ if mode not in ('stacked_bar', 'line', 'area', 'grouped_bar'):
+ raise ValueError(f"'mode' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {mode!r}")
- Returns:
- A Plotly figure object containing the generated plot.
- """
- if style not in ('stacked_bar', 'line', 'area', 'grouped_bar'):
- raise ValueError(f"'style' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {style!r}")
- if data.empty:
- return go.Figure()
+ # Apply CONFIG defaults if not explicitly set
+ if facet_cols is None:
+ facet_cols = CONFIG.Plotting.default_facet_cols
- processed_colors = ColorProcessor(engine='plotly').process_colors(colors, list(data.columns))
+ # Ensure data is a Dataset and validate it
+ data = _ensure_dataset(data)
+ _validate_plotting_data(data, allow_empty=True)
- fig = fig if fig is not None else go.Figure()
+ # Handle empty data
+ if len(data.data_vars) == 0:
+ logger.error('with_plotly() got an empty Dataset.')
+ return go.Figure()
- if style == 'stacked_bar':
- for i, column in enumerate(data.columns):
- fig.add_trace(
- go.Bar(
- x=data.index,
- y=data[column],
- name=column,
- marker=dict(
- color=processed_colors[i], line=dict(width=0, color='rgba(0,0,0,0)')
- ), # Transparent line with 0 width
- )
- )
+ # Handle all-scalar datasets (where all variables have no dimensions)
+ # This occurs when all variables are scalar values with dims=()
+ if all(len(data[var].dims) == 0 for var in data.data_vars):
+ # Create a simple DataFrame with variable names as x-axis
+ variables = list(data.data_vars.keys())
+ values = [float(data[var].values) for var in data.data_vars]
- fig.update_layout(
- barmode='relative',
- bargap=0, # No space between bars
- bargroupgap=0, # No space between grouped bars
- )
- if style == 'grouped_bar':
- for i, column in enumerate(data.columns):
- fig.add_trace(go.Bar(x=data.index, y=data[column], name=column, marker=dict(color=processed_colors[i])))
-
- fig.update_layout(
- barmode='group',
- bargap=0.2, # No space between bars
- bargroupgap=0, # space between grouped bars
+ # Resolve colors
+ color_discrete_map = process_colors(
+ colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
)
- elif style == 'line':
- for i, column in enumerate(data.columns):
- fig.add_trace(
- go.Scatter(
- x=data.index,
- y=data[column],
- mode='lines',
- name=column,
- line=dict(shape='hv', color=processed_colors[i]),
- )
+ marker_colors = [color_discrete_map.get(var, '#636EFA') for var in variables]
+
+ # Create simple plot based on mode using go (not px) for better color control
+ if mode in ('stacked_bar', 'grouped_bar'):
+ fig = go.Figure(data=[go.Bar(x=variables, y=values, marker_color=marker_colors)])
+ elif mode == 'line':
+ fig = go.Figure(
+ data=[
+ go.Scatter(
+ x=variables,
+ y=values,
+ mode='lines+markers',
+ marker=dict(color=marker_colors, size=8),
+ line=dict(color='lightgray'),
+ )
+ ]
)
- elif style == 'area':
- data = data.copy()
- data[(data > -1e-5) & (data < 1e-5)] = 0 # Preventing issues with plotting
- # Split columns into positive, negative, and mixed categories
- positive_columns = list(data.columns[(data >= 0).where(~np.isnan(data), True).all()])
- negative_columns = list(data.columns[(data <= 0).where(~np.isnan(data), True).all()])
- negative_columns = [column for column in negative_columns if column not in positive_columns]
- mixed_columns = list(set(data.columns) - set(positive_columns + negative_columns))
-
- if mixed_columns:
- logger.error(
- f'Data for plotting stacked lines contains columns with both positive and negative values:'
- f' {mixed_columns}. These can not be stacked, and are printed as simple lines'
+ elif mode == 'area':
+ fig = go.Figure(
+ data=[
+ go.Scatter(
+ x=variables,
+ y=values,
+ fill='tozeroy',
+ marker=dict(color=marker_colors, size=8),
+ line=dict(color='lightgray'),
+ )
+ ]
)
-
- # Get color mapping for all columns
- colors_stacked = {column: processed_colors[i] for i, column in enumerate(data.columns)}
-
- for column in positive_columns + negative_columns:
- fig.add_trace(
- go.Scatter(
- x=data.index,
- y=data[column],
- mode='lines',
- name=column,
- line=dict(shape='hv', color=colors_stacked[column]),
- fill='tonexty',
- stackgroup='pos' if column in positive_columns else 'neg',
+ else:
+ raise ValueError('"mode" must be one of "stacked_bar", "grouped_bar", "line", "area"')
+
+ fig.update_layout(title=title, xaxis_title=xlabel, yaxis_title=ylabel, showlegend=False)
+ return fig
+
+ # Convert Dataset to long-form DataFrame for Plotly Express
+ # Structure: time, variable, value, scenario, period, ... (all dims as columns)
+ dim_names = list(data.dims)
+ df_long = data.to_dataframe().reset_index().melt(id_vars=dim_names, var_name='variable', value_name='value')
+
+ # Validate facet_by and animate_by dimensions exist in the data
+ available_dims = [col for col in df_long.columns if col not in ['variable', 'value']]
+
+ # Check facet_by dimensions
+ if facet_by is not None:
+ if isinstance(facet_by, str):
+ if facet_by not in available_dims:
+ logger.debug(
+ f"Dimension '{facet_by}' not found in data. Available dimensions: {available_dims}. "
+ f'Ignoring facet_by parameter.'
)
- )
-
- for column in mixed_columns:
- fig.add_trace(
- go.Scatter(
- x=data.index,
- y=data[column],
- mode='lines',
- name=column,
- line=dict(shape='hv', color=colors_stacked[column], dash='dash'),
+ facet_by = None
+ elif isinstance(facet_by, list):
+ # Filter out dimensions that don't exist
+ missing_dims = [dim for dim in facet_by if dim not in available_dims]
+ facet_by = [dim for dim in facet_by if dim in available_dims]
+ if missing_dims:
+ logger.debug(
+ f'Dimensions {missing_dims} not found in data. Available dimensions: {available_dims}. '
+ f'Using only existing dimensions: {facet_by if facet_by else "none"}.'
)
- )
+ if len(facet_by) == 0:
+ facet_by = None
+
+ # Check animate_by dimension
+ if animate_by is not None and animate_by not in available_dims:
+ logger.debug(
+ f"Dimension '{animate_by}' not found in data. Available dimensions: {available_dims}. "
+ f'Ignoring animate_by parameter.'
+ )
+ animate_by = None
+
+ # Setup faceting parameters for Plotly Express
+ facet_row = None
+ facet_col = None
+ if facet_by:
+ if isinstance(facet_by, str):
+ # Single facet dimension - use facet_col with facet_col_wrap
+ facet_col = facet_by
+ elif len(facet_by) == 1:
+ facet_col = facet_by[0]
+ elif len(facet_by) == 2:
+ # Two facet dimensions - use facet_row and facet_col
+ facet_row = facet_by[0]
+ facet_col = facet_by[1]
+ else:
+ raise ValueError(f'facet_by can have at most 2 dimensions, got {len(facet_by)}')
- # Update layout for better aesthetics
- fig.update_layout(
- title=title,
- yaxis=dict(
- title=ylabel,
- showgrid=True, # Enable grid lines on the y-axis
- gridcolor='lightgrey', # Customize grid line color
- gridwidth=0.5, # Customize grid line width
- ),
- xaxis=dict(
- title=xlabel,
- showgrid=True, # Enable grid lines on the x-axis
- gridcolor='lightgrey', # Customize grid line color
- gridwidth=0.5, # Customize grid line width
- ),
- plot_bgcolor='rgba(0,0,0,0)', # Transparent background
- paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background
- font=dict(size=14), # Increase font size for better readability
+ # Process colors
+ all_vars = df_long['variable'].unique().tolist()
+ color_discrete_map = process_colors(
+ colors, all_vars, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
)
+ # Determine which dimension to use for x-axis
+ # Collect dimensions used for faceting and animation
+ used_dims = set()
+ if facet_row:
+ used_dims.add(facet_row)
+ if facet_col:
+ used_dims.add(facet_col)
+ if animate_by:
+ used_dims.add(animate_by)
+
+ # Find available dimensions for x-axis (not used for faceting/animation)
+ x_candidates = [d for d in available_dims if d not in used_dims]
+
+ # Use 'time' if available, otherwise use the first available dimension
+ if 'time' in x_candidates:
+ x_dim = 'time'
+ elif len(x_candidates) > 0:
+ x_dim = x_candidates[0]
+ else:
+ # Fallback: use the first dimension (shouldn't happen in normal cases)
+ x_dim = available_dims[0] if available_dims else 'time'
+
+ # Create plot using Plotly Express based on mode
+ common_args = {
+ 'data_frame': df_long,
+ 'x': x_dim,
+ 'y': 'value',
+ 'color': 'variable',
+ 'facet_row': facet_row,
+ 'facet_col': facet_col,
+ 'animation_frame': animate_by,
+ 'color_discrete_map': color_discrete_map,
+ 'title': title,
+ 'labels': {'value': ylabel, x_dim: xlabel, 'variable': ''},
+ }
+
+ # Add facet_col_wrap for single facet dimension
+ if facet_col and not facet_row:
+ common_args['facet_col_wrap'] = facet_cols
+
+ # Add mode-specific defaults (before px_kwargs so they can be overridden)
+ if mode in ('line', 'area'):
+ common_args['line_shape'] = 'hv' # Stepped lines by default
+
+ # Allow callers to pass any px.* keyword args (e.g., category_orders, range_x/y, line_shape)
+ # These will override the defaults set above
+ if px_kwargs:
+ common_args.update(px_kwargs)
+
+ if mode == 'stacked_bar':
+ fig = px.bar(**common_args)
+ fig.update_traces(marker_line_width=0)
+ fig.update_layout(barmode='relative', bargap=0, bargroupgap=0)
+ elif mode == 'grouped_bar':
+ fig = px.bar(**common_args)
+ fig.update_layout(barmode='group', bargap=0.2, bargroupgap=0)
+ elif mode == 'line':
+ fig = px.line(**common_args)
+ elif mode == 'area':
+ # Use Plotly Express to create the area plot (preserves animation, legends, faceting)
+ fig = px.area(**common_args)
+
+ # Classify each variable based on its values
+ variable_classification = {}
+ for var in all_vars:
+ var_data = df_long[df_long['variable'] == var]['value']
+ var_data_clean = var_data[(var_data < -1e-5) | (var_data > 1e-5)]
+
+ if len(var_data_clean) == 0:
+ variable_classification[var] = 'zero'
+ else:
+ has_pos, has_neg = (var_data_clean > 0).any(), (var_data_clean < 0).any()
+ variable_classification[var] = (
+ 'mixed' if has_pos and has_neg else ('negative' if has_neg else 'positive')
+ )
+
+ # Log warning for mixed variables
+ mixed_vars = [v for v, c in variable_classification.items() if c == 'mixed']
+ if mixed_vars:
+ logger.warning(f'Variables with both positive and negative values: {mixed_vars}. Plotted as dashed lines.')
+
+ all_traces = list(fig.data)
+ for frame in fig.frames:
+ all_traces.extend(frame.data)
+
+ for trace in all_traces:
+ cls = variable_classification.get(trace.name, None)
+ # Only stack positive and negative, not mixed or zero
+ trace.stackgroup = cls if cls in ('positive', 'negative') else None
+
+ if cls in ('positive', 'negative'):
+ # Stacked area: add opacity to avoid hiding layers, remove line border
+ if hasattr(trace, 'line') and trace.line.color:
+ trace.fillcolor = trace.line.color
+ trace.line.width = 0
+ elif cls == 'mixed':
+ # Mixed variables: show as dashed line, not stacked
+ if hasattr(trace, 'line'):
+ trace.line.width = 2
+ trace.line.dash = 'dash'
+ if hasattr(trace, 'fill'):
+ trace.fill = None
+
+ # Update axes to share if requested (Plotly Express already handles this, but we can customize)
+ if not shared_yaxes:
+ fig.update_yaxes(matches=None)
+ if not shared_xaxes:
+ fig.update_xaxes(matches=None)
+
return fig
def with_matplotlib(
- data: pd.DataFrame,
- style: Literal['stacked_bar', 'line'] = 'stacked_bar',
- colors: ColorType = 'viridis',
+ data: xr.Dataset | pd.DataFrame | pd.Series,
+ mode: Literal['stacked_bar', 'line'] = 'stacked_bar',
+ colors: ColorType | None = None,
title: str = '',
ylabel: str = '',
xlabel: str = 'Time in h',
figsize: tuple[int, int] = (12, 6),
- fig: plt.Figure | None = None,
- ax: plt.Axes | None = None,
+ plot_kwargs: dict[str, Any] | None = None,
) -> tuple[plt.Figure, plt.Axes]:
"""
- Plot a DataFrame with Matplotlib using stacked bars or stepped lines.
+ Plot data with Matplotlib using stacked bars or stepped lines.
Args:
- data: A DataFrame containing the data to plot. The index should represent time (e.g., hours),
- and each column represents a separate data series.
- style: Plotting style. Use 'stacked_bar' for stacked bar charts or 'line' for stepped lines.
- colors: Color specification, can be:
- - A string with a colormap name (e.g., 'viridis', 'plasma')
+ data: An xarray Dataset, pandas DataFrame, or pandas Series to plot. After conversion to DataFrame,
+ the index represents time and each column represents a separate data series (variables).
+ mode: Plotting mode. Use 'stacked_bar' for stacked bar charts or 'line' for stepped lines.
+ colors: Color specification. Can be:
+ - A colorscale name (e.g., 'turbo', 'plasma')
- A list of color strings (e.g., ['#ff0000', '#00ff00'])
- - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
+ - A dict mapping column names to colors (e.g., {'Column1': '#ff0000'})
title: The title of the plot.
ylabel: The ylabel of the plot.
xlabel: The xlabel of the plot.
- figsize: Specify the size of the figure
- fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created.
- ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created.
+ figsize: Specify the size of the figure (width, height) in inches.
+ plot_kwargs: Optional dict of parameters to pass to ax.bar() or ax.step() plotting calls.
+ Use this to customize plot properties (e.g., linewidth, alpha, edgecolor).
Returns:
A tuple containing the Matplotlib figure and axes objects used for the plot.
Notes:
- - If `style` is 'stacked_bar', bars are stacked for both positive and negative values.
+ - If `mode` is 'stacked_bar', bars are stacked for both positive and negative values.
Negative values are stacked separately without extra labels in the legend.
- - If `style` is 'line', stepped lines are drawn for each data series.
+ - If `mode` is 'line', stepped lines are drawn for each data series.
"""
- if style not in ('stacked_bar', 'line'):
- raise ValueError(f"'style' must be one of {{'stacked_bar','line'}} for matplotlib, got {style!r}")
+ if colors is None:
+ colors = CONFIG.Plotting.default_qualitative_colorscale
- if fig is None or ax is None:
- fig, ax = plt.subplots(figsize=figsize)
+ if mode not in ('stacked_bar', 'line'):
+ raise ValueError(f"'mode' must be one of {{'stacked_bar','line'}} for matplotlib, got {mode!r}")
+
+ # Ensure data is a Dataset and validate it
+ data = _ensure_dataset(data)
+ _validate_plotting_data(data, allow_empty=True)
+
+ # Create new figure and axes
+ fig, ax = plt.subplots(figsize=figsize)
+
+ # Initialize plot_kwargs if not provided
+ if plot_kwargs is None:
+ plot_kwargs = {}
+
+ # Handle all-scalar datasets (where all variables have no dimensions)
+ # This occurs when all variables are scalar values with dims=()
+ if all(len(data[var].dims) == 0 for var in data.data_vars):
+ # Create simple bar/line plot with variable names as x-axis
+ variables = list(data.data_vars.keys())
+ values = [float(data[var].values) for var in data.data_vars]
+
+ # Resolve colors
+ color_discrete_map = process_colors(
+ colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
+ )
+ colors_list = [color_discrete_map.get(var, '#808080') for var in variables]
+
+ # Create plot based on mode
+ if mode == 'stacked_bar':
+ ax.bar(variables, values, color=colors_list, **plot_kwargs)
+ elif mode == 'line':
+ ax.plot(
+ variables,
+ values,
+ marker='o',
+ color=colors_list[0] if len(set(colors_list)) == 1 else None,
+ **plot_kwargs,
+ )
+ # If different colors, plot each point separately
+ if len(set(colors_list)) > 1:
+ ax.clear()
+ for i, (var, val) in enumerate(zip(variables, values, strict=False)):
+ ax.plot([i], [val], marker='o', color=colors_list[i], label=var, **plot_kwargs)
+ ax.set_xticks(range(len(variables)))
+ ax.set_xticklabels(variables)
+
+ ax.set_xlabel(xlabel, ha='center')
+ ax.set_ylabel(ylabel, va='center')
+ ax.set_title(title)
+ ax.grid(color='lightgrey', linestyle='-', linewidth=0.5, axis='y')
+ fig.tight_layout()
+
+ return fig, ax
+
+ # Resolve colors first (includes validation)
+ color_discrete_map = process_colors(
+ colors, list(data.data_vars), default_colorscale=CONFIG.Plotting.default_qualitative_colorscale
+ )
- processed_colors = ColorProcessor(engine='matplotlib').process_colors(colors, list(data.columns))
+ # Convert Dataset to DataFrame for matplotlib plotting (naturally wide-form)
+ df = data.to_dataframe()
- if style == 'stacked_bar':
- cumulative_positive = np.zeros(len(data))
- cumulative_negative = np.zeros(len(data))
- width = data.index.to_series().diff().dropna().min() # Minimum time difference
+ # Get colors in column order
+ processed_colors = [color_discrete_map.get(str(col), '#808080') for col in df.columns]
- for i, column in enumerate(data.columns):
- positive_values = np.clip(data[column], 0, None) # Keep only positive values
- negative_values = np.clip(data[column], None, 0) # Keep only negative values
+ if mode == 'stacked_bar':
+ cumulative_positive = np.zeros(len(df))
+ cumulative_negative = np.zeros(len(df))
+
+ # Robust bar width: handle datetime-like, numeric, and single-point indexes
+ if len(df.index) > 1:
+ delta = pd.Index(df.index).to_series().diff().dropna().min()
+ if hasattr(delta, 'total_seconds'): # datetime-like
+ width = delta.total_seconds() / 86400.0 # Matplotlib date units = days
+ else:
+ width = float(delta)
+ else:
+ width = 0.8 # reasonable default for a single bar
+
+ for i, column in enumerate(df.columns):
+ # Fill NaNs to avoid breaking stacking math
+ series = df[column].fillna(0)
+ positive_values = np.clip(series, 0, None) # Keep only positive values
+ negative_values = np.clip(series, None, 0) # Keep only negative values
# Plot positive bars
ax.bar(
- data.index,
+ df.index,
positive_values,
bottom=cumulative_positive,
color=processed_colors[i],
label=column,
width=width,
align='center',
+ **plot_kwargs,
)
cumulative_positive += positive_values.values
# Plot negative bars
ax.bar(
- data.index,
+ df.index,
negative_values,
bottom=cumulative_negative,
color=processed_colors[i],
label='', # No label for negative bars
width=width,
align='center',
+ **plot_kwargs,
)
cumulative_negative += negative_values.values
- elif style == 'line':
- for i, column in enumerate(data.columns):
- ax.step(data.index, data[column], where='post', color=processed_colors[i], label=column)
+ elif mode == 'line':
+ for i, column in enumerate(df.columns):
+ ax.step(df.index, df[column], where='post', color=processed_colors[i], label=column, **plot_kwargs)
# Aesthetics
ax.set_xlabel(xlabel, ha='center')
@@ -562,213 +653,110 @@ def with_matplotlib(
return fig, ax
-def heat_map_matplotlib(
- data: pd.DataFrame,
- color_map: str = 'viridis',
- title: str = '',
- xlabel: str = 'Period',
- ylabel: str = 'Step',
- figsize: tuple[float, float] = (12, 6),
-) -> tuple[plt.Figure, plt.Axes]:
+def reshape_data_for_heatmap(
+ data: xr.DataArray,
+ reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
+ | Literal['auto']
+ | None = 'auto',
+ facet_by: str | list[str] | None = None,
+ animate_by: str | None = None,
+ fill: Literal['ffill', 'bfill'] | None = 'ffill',
+) -> xr.DataArray:
"""
- Plots a DataFrame as a heatmap using Matplotlib. The columns of the DataFrame will be displayed on the x-axis,
- the index will be displayed on the y-axis, and the values will represent the 'heat' intensity in the plot.
+ Reshape data for heatmap visualization, handling time dimension intelligently.
- Args:
- data: A DataFrame containing the data to be visualized. The index will be used for the y-axis, and columns will be used for the x-axis.
- The values in the DataFrame will be represented as colors in the heatmap.
- color_map: The colormap to use for the heatmap. Default is 'viridis'. Matplotlib supports various colormaps like 'plasma', 'inferno', 'cividis', etc.
- title: The title of the plot.
- xlabel: The label for the x-axis.
- ylabel: The label for the y-axis.
- figsize: The size of the figure to create. Default is (12, 6), which results in a width of 12 inches and a height of 6 inches.
-
- Returns:
- A tuple containing the Matplotlib `Figure` and `Axes` objects. The `Figure` contains the overall plot, while the `Axes` is the area
- where the heatmap is drawn. These can be used for further customization or saving the plot to a file.
-
- Notes:
- - The y-axis is flipped so that the first row of the DataFrame is displayed at the top of the plot.
- - The color scale is normalized based on the minimum and maximum values in the DataFrame.
- - The x-axis labels (periods) are placed at the top of the plot.
- - The colorbar is added horizontally at the bottom of the plot, with a label.
- """
-
- # Get the min and max values for color normalization
- color_bar_min, color_bar_max = data.min().min(), data.max().max()
-
- # Create the heatmap plot
- fig, ax = plt.subplots(figsize=figsize)
- ax.pcolormesh(data.values, cmap=color_map, shading='auto')
- ax.invert_yaxis() # Flip the y-axis to start at the top
+ This function decides whether to reshape the 'time' dimension based on the reshape_time parameter:
+ - 'auto': Automatically reshapes if only 'time' dimension would remain for heatmap
+ - Tuple: Explicitly reshapes time with specified parameters
+ - None: No reshaping (returns data as-is)
- # Adjust ticks and labels for x and y axes
- ax.set_xticks(np.arange(len(data.columns)) + 0.5)
- ax.set_xticklabels(data.columns, ha='center')
- ax.set_yticks(np.arange(len(data.index)) + 0.5)
- ax.set_yticklabels(data.index, va='center')
-
- # Add labels to the axes
- ax.set_xlabel(xlabel, ha='center')
- ax.set_ylabel(ylabel, va='center')
- ax.set_title(title)
-
- # Position x-axis labels at the top
- ax.xaxis.set_label_position('top')
- ax.xaxis.set_ticks_position('top')
-
- # Add the colorbar
- sm1 = plt.cm.ScalarMappable(cmap=color_map, norm=plt.Normalize(vmin=color_bar_min, vmax=color_bar_max))
- sm1.set_array([])
- fig.colorbar(sm1, ax=ax, pad=0.12, aspect=15, fraction=0.2, orientation='horizontal')
-
- fig.tight_layout()
-
- return fig, ax
-
-
-def heat_map_plotly(
- data: pd.DataFrame,
- color_map: str = 'viridis',
- title: str = '',
- xlabel: str = 'Period',
- ylabel: str = 'Step',
- categorical_labels: bool = True,
-) -> go.Figure:
- """
- Plots a DataFrame as a heatmap using Plotly. The columns of the DataFrame will be mapped to the x-axis,
- and the index will be displayed on the y-axis. The values in the DataFrame will represent the 'heat' in the plot.
-
- Args:
- data: A DataFrame with the data to be visualized. The index will be used for the y-axis, and columns will be used for the x-axis.
- The values in the DataFrame will be represented as colors in the heatmap.
- color_map: The color scale to use for the heatmap. Default is 'viridis'. Plotly supports various color scales like 'Cividis', 'Inferno', etc.
- title: The title of the heatmap. Default is an empty string.
- xlabel: The label for the x-axis. Default is 'Period'.
- ylabel: The label for the y-axis. Default is 'Step'.
- categorical_labels: If True, the x and y axes are treated as categorical data (i.e., the index and columns will not be interpreted as continuous data).
- Default is True. If False, the axes are treated as continuous, which may be useful for time series or numeric data.
-
- Returns:
- A Plotly figure object containing the heatmap. This can be further customized and saved
- or displayed using `fig.show()`.
-
- Notes:
- The color bar is automatically scaled to the minimum and maximum values in the data.
- The y-axis is reversed to display the first row at the top.
- """
-
- color_bar_min, color_bar_max = data.min().min(), data.max().max() # Min and max values for color scaling
- # Define the figure
- fig = go.Figure(
- data=go.Heatmap(
- z=data.values,
- x=data.columns,
- y=data.index,
- colorscale=color_map,
- zmin=color_bar_min,
- zmax=color_bar_max,
- colorbar=dict(
- title=dict(text='Color Bar Label', side='right'),
- orientation='h',
- xref='container',
- yref='container',
- len=0.8, # Color bar length relative to plot
- x=0.5,
- y=0.1,
- ),
- )
- )
-
- # Set axis labels and style
- fig.update_layout(
- title=title,
- xaxis=dict(title=xlabel, side='top', type='category' if categorical_labels else None),
- yaxis=dict(title=ylabel, autorange='reversed', type='category' if categorical_labels else None),
- )
-
- return fig
-
-
-def reshape_to_2d(data_1d: np.ndarray, nr_of_steps_per_column: int) -> np.ndarray:
- """
- Reshapes a 1D numpy array into a 2D array suitable for plotting as a colormap.
-
- The reshaped array will have the number of rows corresponding to the steps per column
- (e.g., 24 hours per day) and columns representing time periods (e.g., days or months).
+ All non-time dimensions are preserved during reshaping.
Args:
- data_1d: A 1D numpy array with the data to reshape.
- nr_of_steps_per_column: The number of steps (rows) per column in the resulting 2D array. For example,
- this could be 24 (for hours) or 31 (for days in a month).
+ data: DataArray to reshape for heatmap visualization.
+ reshape_time: Reshaping configuration:
+ - 'auto' (default): Auto-reshape if needed based on facet_by/animate_by
+ - Tuple (timeframes, timesteps_per_frame): Explicit time reshaping
+ - None: No reshaping
+ facet_by: Dimension(s) used for faceting (used in 'auto' decision).
+ animate_by: Dimension used for animation (used in 'auto' decision).
+ fill: Method to fill missing values: 'ffill' or 'bfill'. Default is 'ffill'.
Returns:
- The reshaped 2D array. Each internal array corresponds to one column, with the specified number of steps.
- Each column might represents a time period (e.g., day, month, etc.).
- """
+ Reshaped DataArray. If time reshaping is applied, 'time' dimension is replaced
+ by 'timestep' and 'timeframe'. All other dimensions are preserved.
- # Step 1: Ensure the input is a 1D array.
- if data_1d.ndim != 1:
- raise ValueError('Input must be a 1D array')
-
- # Step 2: Convert data to float type to allow NaN padding
- if data_1d.dtype != np.float64:
- data_1d = data_1d.astype(np.float64)
-
- # Step 3: Calculate the number of columns required
- total_steps = len(data_1d)
- cols = len(data_1d) // nr_of_steps_per_column # Base number of columns
-
- # If there's a remainder, add an extra column to hold the remaining values
- if total_steps % nr_of_steps_per_column != 0:
- cols += 1
+ Examples:
+ Auto-reshaping:
- # Step 4: Pad the 1D data to match the required number of rows and columns
- padded_data = np.pad(
- data_1d, (0, cols * nr_of_steps_per_column - total_steps), mode='constant', constant_values=np.nan
- )
+ ```python
+ # Will auto-reshape because only 'time' remains after faceting/animation
+ data = reshape_data_for_heatmap(data, reshape_time='auto', facet_by='scenario', animate_by='period')
+ ```
- # Step 5: Reshape the padded data into a 2D array
- data_2d = padded_data.reshape(cols, nr_of_steps_per_column)
+ Explicit reshaping:
- return data_2d.T
+ ```python
+ # Explicitly reshape to daily pattern
+ data = reshape_data_for_heatmap(data, reshape_time=('D', 'h'))
+ ```
+ No reshaping:
-def heat_map_data_from_df(
- df: pd.DataFrame,
- periods: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'],
- steps_per_period: Literal['W', 'D', 'h', '15min', 'min'],
- fill: Literal['ffill', 'bfill'] | None = None,
-) -> pd.DataFrame:
+ ```python
+ # Keep data as-is
+ data = reshape_data_for_heatmap(data, reshape_time=None)
+ ```
"""
- Reshapes a DataFrame with a DateTime index into a 2D array for heatmap plotting,
- based on a specified sample rate.
- Only specific combinations of `periods` and `steps_per_period` are supported; invalid combinations raise an assertion.
-
- Args:
- df: A DataFrame with a DateTime index containing the data to reshape.
- periods: The time interval of each period (columns of the heatmap),
- such as 'YS' (year start), 'W' (weekly), 'D' (daily), 'h' (hourly) etc.
- steps_per_period: The time interval within each period (rows in the heatmap),
- such as 'YS' (year start), 'W' (weekly), 'D' (daily), 'h' (hourly) etc.
- fill: Method to fill missing values: 'ffill' for forward fill or 'bfill' for backward fill.
+ # If no time dimension, return data as-is
+ if 'time' not in data.dims:
+ return data
+
+ # Handle None (disabled) - return data as-is
+ if reshape_time is None:
+ return data
+
+ # Determine timeframes and timesteps_per_frame based on reshape_time parameter
+ if reshape_time == 'auto':
+ # Check if we need automatic time reshaping
+ facet_dims_used = []
+ if facet_by:
+ facet_dims_used = [facet_by] if isinstance(facet_by, str) else list(facet_by)
+ if animate_by:
+ facet_dims_used.append(animate_by)
+
+ # Get dimensions that would remain for heatmap
+ potential_heatmap_dims = [dim for dim in data.dims if dim not in facet_dims_used]
+
+ # Auto-reshape if only 'time' dimension remains
+ if len(potential_heatmap_dims) == 1 and potential_heatmap_dims[0] == 'time':
+ logger.debug(
+ "Auto-applying time reshaping: Only 'time' dimension remains after faceting/animation. "
+ "Using default timeframes='D' and timesteps_per_frame='h'. "
+ "To customize, use reshape_time=('D', 'h') or disable with reshape_time=None."
+ )
+ timeframes, timesteps_per_frame = 'D', 'h'
+ else:
+ # No reshaping needed
+ return data
+ elif isinstance(reshape_time, tuple):
+ # Explicit reshaping
+ timeframes, timesteps_per_frame = reshape_time
+ else:
+ raise ValueError(f"reshape_time must be 'auto', a tuple like ('D', 'h'), or None. Got: {reshape_time}")
- Returns:
- A DataFrame suitable for heatmap plotting, with rows representing steps within each period
- and columns representing each period.
- """
- assert pd.api.types.is_datetime64_any_dtype(df.index), (
- 'The index of the DataFrame must be datetime to transform it properly for a heatmap plot'
- )
+ # Validate that time is datetime
+ if not np.issubdtype(data.coords['time'].dtype, np.datetime64):
+ raise ValueError(f'Time dimension must be datetime-based, got {data.coords["time"].dtype}')
- # Define formats for different combinations of `periods` and `steps_per_period`
+ # Define formats for different combinations
formats = {
('YS', 'W'): ('%Y', '%W'),
('YS', 'D'): ('%Y', '%j'), # day of year
('YS', 'h'): ('%Y', '%j %H:00'),
('MS', 'D'): ('%Y-%m', '%d'), # day of month
('MS', 'h'): ('%Y-%m', '%d %H:00'),
- ('W', 'D'): ('%Y-w%W', '%w_%A'), # week and day of week (with prefix for proper sorting)
+ ('W', 'D'): ('%Y-w%W', '%w_%A'), # week and day of week
('W', 'h'): ('%Y-w%W', '%w_%A %H:00'),
('D', 'h'): ('%Y-%m-%d', '%H:00'), # Day and hour
('D', '15min'): ('%Y-%m-%d', '%H:%M'), # Day and minute
@@ -776,43 +764,64 @@ def heat_map_data_from_df(
('h', 'min'): ('%Y-%m-%d %H:00', '%M'), # minute of hour
}
- if df.empty:
- raise ValueError('DataFrame is empty.')
- diffs = df.index.to_series().diff().dropna()
- minimum_time_diff_in_min = diffs.min().total_seconds() / 60
- time_intervals = {'min': 1, '15min': 15, 'h': 60, 'D': 24 * 60, 'W': 7 * 24 * 60}
- if time_intervals[steps_per_period] > minimum_time_diff_in_min:
- logger.error(
- f'To compute the heatmap, the data was aggregated from {minimum_time_diff_in_min:.2f} min to '
- f'{time_intervals[steps_per_period]:.2f} min. Mean values are displayed.'
- )
-
- # Select the format based on the `periods` and `steps_per_period` combination
- format_pair = (periods, steps_per_period)
+ format_pair = (timeframes, timesteps_per_frame)
if format_pair not in formats:
raise ValueError(f'{format_pair} is not a valid format. Choose from {list(formats.keys())}')
period_format, step_format = formats[format_pair]
- df = df.sort_index() # Ensure DataFrame is sorted by time index
+ # Check if resampling is needed
+ if data.sizes['time'] > 1:
+ # Use NumPy for more efficient timedelta computation
+ time_values = data.coords['time'].values # Already numpy datetime64[ns]
+ # Calculate differences and convert to minutes
+ time_diffs = np.diff(time_values).astype('timedelta64[s]').astype(float) / 60.0
+ if time_diffs.size > 0:
+ min_time_diff_min = np.nanmin(time_diffs)
+ time_intervals = {'min': 1, '15min': 15, 'h': 60, 'D': 24 * 60, 'W': 7 * 24 * 60}
+ if time_intervals[timesteps_per_frame] > min_time_diff_min:
+ logger.warning(
+ f'Resampling data from {min_time_diff_min:.2f} min to '
+ f'{time_intervals[timesteps_per_frame]:.2f} min. Mean values are displayed.'
+ )
- resampled_data = df.resample(steps_per_period).mean() # Resample and fill any gaps with NaN
+ # Resample along time dimension
+ resampled = data.resample(time=timesteps_per_frame).mean()
- if fill == 'ffill': # Apply fill method if specified
- resampled_data = resampled_data.ffill()
+ # Apply fill if specified
+ if fill == 'ffill':
+ resampled = resampled.ffill(dim='time')
elif fill == 'bfill':
- resampled_data = resampled_data.bfill()
+ resampled = resampled.bfill(dim='time')
+
+ # Create period and step labels
+ time_values = pd.to_datetime(resampled.coords['time'].values)
+ period_labels = time_values.strftime(period_format)
+ step_labels = time_values.strftime(step_format)
+
+ # Handle special case for weekly day format
+ if '%w_%A' in step_format:
+ step_labels = pd.Series(step_labels).replace('0_Sunday', '7_Sunday').values
+
+ # Add period and step as coordinates
+ resampled = resampled.assign_coords(
+ {
+ 'timeframe': ('time', period_labels),
+ 'timestep': ('time', step_labels),
+ }
+ )
- resampled_data['period'] = resampled_data.index.strftime(period_format)
- resampled_data['step'] = resampled_data.index.strftime(step_format)
- if '%w_%A' in step_format: # Shift index of strings to ensure proper sorting
- resampled_data['step'] = resampled_data['step'].apply(
- lambda x: x.replace('0_Sunday', '7_Sunday') if '0_Sunday' in x else x
- )
+ # Convert to multi-index and unstack
+ resampled = resampled.set_index(time=['timeframe', 'timestep'])
+ result = resampled.unstack('time')
+
+ # Ensure timestep and timeframe come first in dimension order
+ # Get other dimensions
+ other_dims = [d for d in result.dims if d not in ['timestep', 'timeframe']]
- # Pivot the table so periods are columns and steps are indices
- df_pivoted = resampled_data.pivot(columns='period', index='step', values=df.columns[0])
+ # Reorder: timestep, timeframe, then other dimensions
+ result = result.transpose('timestep', 'timeframe', *other_dims)
- return df_pivoted
+ return result
def plot_network(
@@ -899,518 +908,653 @@ def plot_network(
)
-def pie_with_plotly(
- data: pd.DataFrame,
- colors: ColorType = 'viridis',
- title: str = '',
- legend_title: str = '',
- hole: float = 0.0,
- fig: go.Figure | None = None,
-) -> go.Figure:
+def preprocess_data_for_pie(
+ data: xr.Dataset | pd.DataFrame | pd.Series,
+ lower_percentage_threshold: float = 5.0,
+) -> pd.Series:
"""
- Create a pie chart with Plotly to visualize the proportion of values in a DataFrame.
+ Preprocess data for pie chart display.
+
+ Groups items that are individually below the threshold percentage into an "Other" category.
+ Converts various input types to a pandas Series for uniform handling.
Args:
- data: A DataFrame containing the data to plot. If multiple rows exist,
- they will be summed unless a specific index value is passed.
- colors: Color specification, can be:
- - A string with a colorscale name (e.g., 'viridis', 'plasma')
- - A list of color strings (e.g., ['#ff0000', '#00ff00'])
- - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
- title: The title of the plot.
- legend_title: The title for the legend.
- hole: Size of the hole in the center for creating a donut chart (0.0 to 1.0).
- fig: A Plotly figure object to plot on. If not provided, a new figure will be created.
+ data: Input data (xarray Dataset, DataFrame, or Series)
+ lower_percentage_threshold: Percentage threshold - items below this are grouped into "Other"
Returns:
- A Plotly figure object containing the generated pie chart.
+ Processed pandas Series with small items grouped into "Other"
+ """
+ # Convert to Series
+ if isinstance(data, xr.Dataset):
+ # Sum all dimensions for each variable to get total values
+ values = {}
+ for var in data.data_vars:
+ var_data = data[var]
+ if len(var_data.dims) > 0:
+ total_value = float(var_data.sum().item())
+ else:
+ total_value = float(var_data.item())
- Notes:
- - Negative values are not appropriate for pie charts and will be converted to absolute values with a warning.
- - If the data contains very small values (less than 1% of the total), they can be grouped into an "Other" category
- for better readability.
- - By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing.
+ # Handle negative values
+ if total_value < 0:
+ logger.warning(f'Negative value for {var}: {total_value}. Using absolute value.')
+ total_value = abs(total_value)
- """
- if data.empty:
- logger.error('Empty DataFrame provided for pie chart. Returning empty figure.')
- return go.Figure()
+ values[var] = total_value
- # Create a copy to avoid modifying the original DataFrame
- data_copy = data.copy()
+ series = pd.Series(values)
- # Check if any negative values and warn
- if (data_copy < 0).any().any():
- logger.error('Negative values detected in data. Using absolute values for pie chart.')
- data_copy = data_copy.abs()
+ elif isinstance(data, pd.DataFrame):
+ # Sum across all columns if DataFrame
+ series = data.sum(axis=0)
+ # Handle negative values
+ negative_mask = series < 0
+ if negative_mask.any():
+ logger.warning(f'Negative values found: {series[negative_mask].to_dict()}. Using absolute values.')
+ series = series.abs()
- # If data has multiple rows, sum them to get total for each column
- if len(data_copy) > 1:
- data_sum = data_copy.sum()
- else:
- data_sum = data_copy.iloc[0]
+ else: # pd.Series
+ series = data.copy()
+ # Handle negative values
+ negative_mask = series < 0
+ if negative_mask.any():
+ logger.warning(f'Negative values found: {series[negative_mask].to_dict()}. Using absolute values.')
+ series = series.abs()
- # Get labels (column names) and values
- labels = data_sum.index.tolist()
- values = data_sum.values.tolist()
+ # Only keep positive values
+ series = series[series > 0]
- # Apply color mapping using the unified color processor
- processed_colors = ColorProcessor(engine='plotly').process_colors(colors, labels)
+ if series.empty or lower_percentage_threshold <= 0:
+ return series
- # Create figure if not provided
- fig = fig if fig is not None else go.Figure()
+ # Calculate percentages
+ total = series.sum()
+ percentages = (series / total) * 100
- # Add pie trace
- fig.add_trace(
- go.Pie(
- labels=labels,
- values=values,
- hole=hole,
- marker=dict(colors=processed_colors),
- textinfo='percent+label+value',
- textposition='inside',
- insidetextorientation='radial',
+ # Find items below and above threshold
+ below_threshold = series[percentages < lower_percentage_threshold]
+ above_threshold = series[percentages >= lower_percentage_threshold]
+
+ # Only group if there are at least 2 items below threshold
+ if len(below_threshold) > 1:
+ # Create new series with items above threshold + "Other"
+ result = above_threshold.copy()
+ result['Other'] = below_threshold.sum()
+ return result
+
+ return series
+
+
+def dual_pie_with_plotly(
+ data_left: xr.Dataset | pd.DataFrame | pd.Series,
+ data_right: xr.Dataset | pd.DataFrame | pd.Series,
+ colors: ColorType | None = None,
+ title: str = '',
+ subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'),
+ legend_title: str = '',
+ hole: float = 0.2,
+ lower_percentage_group: float = 5.0,
+ text_info: str = 'percent+label',
+ text_position: str = 'inside',
+ hover_template: str = '%{label}: %{value} (%{percent})',
+) -> go.Figure:
+ """
+ Create two pie charts side by side with Plotly.
+
+ Args:
+ data_left: Data for the left pie chart. Variables are summed across all dimensions.
+ data_right: Data for the right pie chart. Variables are summed across all dimensions.
+ colors: Color specification (colorscale name, list of colors, or dict mapping)
+ title: The main title of the plot.
+ subtitles: Tuple containing the subtitles for (left, right) charts.
+ legend_title: The title for the legend.
+ hole: Size of the hole in the center for creating donut charts (0.0 to 1.0).
+        lower_percentage_group: Group segments whose individual share is below this percentage (0–100) into "Other".
+ hover_template: Template for hover text. Use %{label}, %{value}, %{percent}.
+ text_info: What to show on pie segments: 'label', 'percent', 'value', 'label+percent',
+ 'label+value', 'percent+value', 'label+percent+value', or 'none'.
+ text_position: Position of text: 'inside', 'outside', 'auto', or 'none'.
+
+ Returns:
+ Plotly Figure object
+ """
+ if colors is None:
+ colors = CONFIG.Plotting.default_qualitative_colorscale
+
+ # Preprocess data to Series
+ left_series = preprocess_data_for_pie(data_left, lower_percentage_group)
+ right_series = preprocess_data_for_pie(data_right, lower_percentage_group)
+
+ # Extract labels and values
+ left_labels = left_series.index.tolist()
+ left_values = left_series.values.tolist()
+
+ right_labels = right_series.index.tolist()
+ right_values = right_series.values.tolist()
+
+ # Get all unique labels for consistent coloring
+ all_labels = sorted(set(left_labels) | set(right_labels))
+
+ # Create color map
+ color_map = process_colors(colors, all_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale)
+
+ # Create figure
+ fig = go.Figure()
+
+ # Add left pie
+ if left_labels:
+ fig.add_trace(
+ go.Pie(
+ labels=left_labels,
+ values=left_values,
+ name=subtitles[0],
+ marker=dict(colors=[color_map.get(label, '#636EFA') for label in left_labels]),
+ hole=hole,
+ textinfo=text_info,
+ textposition=text_position,
+ hovertemplate=hover_template,
+ domain=dict(x=[0, 0.48]),
+ )
+ )
+
+ # Add right pie
+ if right_labels:
+ fig.add_trace(
+ go.Pie(
+ labels=right_labels,
+ values=right_values,
+ name=subtitles[1],
+ marker=dict(colors=[color_map.get(label, '#636EFA') for label in right_labels]),
+ hole=hole,
+ textinfo=text_info,
+ textposition=text_position,
+ hovertemplate=hover_template,
+ domain=dict(x=[0.52, 1]),
+ )
)
- )
- # Update layout for better aesthetics
+ # Update layout
fig.update_layout(
title=title,
legend_title=legend_title,
- plot_bgcolor='rgba(0,0,0,0)', # Transparent background
- paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background
- font=dict(size=14), # Increase font size for better readability
+ margin=dict(t=80, b=50, l=30, r=30),
)
return fig
-def pie_with_matplotlib(
- data: pd.DataFrame,
- colors: ColorType = 'viridis',
+def dual_pie_with_matplotlib(
+ data_left: xr.Dataset | pd.DataFrame | pd.Series,
+ data_right: xr.Dataset | pd.DataFrame | pd.Series,
+ colors: ColorType | None = None,
title: str = '',
- legend_title: str = 'Categories',
- hole: float = 0.0,
- figsize: tuple[int, int] = (10, 8),
- fig: plt.Figure | None = None,
- ax: plt.Axes | None = None,
-) -> tuple[plt.Figure, plt.Axes]:
+ subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'),
+ legend_title: str = '',
+ hole: float = 0.2,
+ lower_percentage_group: float = 5.0,
+ figsize: tuple[int, int] = (14, 7),
+) -> tuple[plt.Figure, list[plt.Axes]]:
"""
- Create a pie chart with Matplotlib to visualize the proportion of values in a DataFrame.
+ Create two pie charts side by side with Matplotlib.
Args:
- data: A DataFrame containing the data to plot. If multiple rows exist,
- they will be summed unless a specific index value is passed.
- colors: Color specification, can be:
- - A string with a colormap name (e.g., 'viridis', 'plasma')
- - A list of color strings (e.g., ['#ff0000', '#00ff00'])
- - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
- title: The title of the plot.
+ data_left: Data for the left pie chart.
+ data_right: Data for the right pie chart.
+ colors: Color specification (colorscale name, list of colors, or dict mapping)
+ title: The main title of the plot.
+ subtitles: Tuple containing the subtitles for (left, right) charts.
legend_title: The title for the legend.
- hole: Size of the hole in the center for creating a donut chart (0.0 to 1.0).
+ hole: Size of the hole in the center for creating donut charts (0.0 to 1.0).
+        lower_percentage_group: Percentage threshold; segments individually below it are grouped into an "Other" category.
figsize: The size of the figure (width, height) in inches.
- fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created.
- ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created.
Returns:
- A tuple containing the Matplotlib figure and axes objects used for the plot.
+ Tuple of (Figure, list of Axes)
+ """
+ if colors is None:
+ colors = CONFIG.Plotting.default_qualitative_colorscale
- Notes:
- - Negative values are not appropriate for pie charts and will be converted to absolute values with a warning.
- - If the data contains very small values (less than 1% of the total), they can be grouped into an "Other" category
- for better readability.
- - By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing.
+ # Preprocess data to Series
+ left_series = preprocess_data_for_pie(data_left, lower_percentage_group)
+ right_series = preprocess_data_for_pie(data_right, lower_percentage_group)
- """
- if data.empty:
- logger.error('Empty DataFrame provided for pie chart. Returning empty figure.')
- if fig is None or ax is None:
- fig, ax = plt.subplots(figsize=figsize)
- return fig, ax
+ # Extract labels and values
+ left_labels = left_series.index.tolist()
+ left_values = left_series.values.tolist()
- # Create a copy to avoid modifying the original DataFrame
- data_copy = data.copy()
+ right_labels = right_series.index.tolist()
+ right_values = right_series.values.tolist()
- # Check if any negative values and warn
- if (data_copy < 0).any().any():
- logger.error('Negative values detected in data. Using absolute values for pie chart.')
- data_copy = data_copy.abs()
+ # Get all unique labels for consistent coloring
+ all_labels = sorted(set(left_labels) | set(right_labels))
- # If data has multiple rows, sum them to get total for each column
- if len(data_copy) > 1:
- data_sum = data_copy.sum()
- else:
- data_sum = data_copy.iloc[0]
+ # Create color map (process_colors always returns a dict)
+ color_map = process_colors(colors, all_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale)
- # Get labels (column names) and values
- labels = data_sum.index.tolist()
- values = data_sum.values.tolist()
+ # Create figure
+ fig, axes = plt.subplots(1, 2, figsize=figsize)
- # Apply color mapping using the unified color processor
- processed_colors = ColorProcessor(engine='matplotlib').process_colors(colors, labels)
+ def draw_pie(ax, labels, values, subtitle):
+ """Draw a single pie chart."""
+ if not labels:
+ ax.set_title(subtitle)
+ ax.axis('off')
+ return
- # Create figure and axis if not provided
- if fig is None or ax is None:
- fig, ax = plt.subplots(figsize=figsize)
+ chart_colors = [color_map[label] for label in labels]
- # Draw the pie chart
- wedges, texts, autotexts = ax.pie(
- values,
- labels=labels,
- colors=processed_colors,
- autopct='%1.1f%%',
- startangle=90,
- shadow=False,
- wedgeprops=dict(width=0.5) if hole > 0 else None, # Set width for donut
- )
+ # Draw pie
+ wedges, texts, autotexts = ax.pie(
+ values,
+ labels=labels,
+ colors=chart_colors,
+ autopct='%1.1f%%',
+ startangle=90,
+ wedgeprops=dict(width=1 - hole) if hole > 0 else None,
+ )
+
+ # Style text
+ for autotext in autotexts:
+ autotext.set_fontsize(10)
+ autotext.set_color('white')
+ autotext.set_weight('bold')
+
+ ax.set_aspect('equal')
+ ax.set_title(subtitle, fontsize=14, pad=20)
- # Adjust the wedgeprops to make donut hole size consistent with plotly
- # For matplotlib, the hole size is determined by the wedge width
- # Convert hole parameter to wedge width
- if hole > 0:
- # Adjust hole size to match plotly's hole parameter
- # In matplotlib, wedge width is relative to the radius (which is 1)
- # For plotly, hole is a fraction of the radius
- wedge_width = 1 - hole
- for wedge in wedges:
- wedge.set_width(wedge_width)
-
- # Customize the appearance
- # Make autopct text more visible
- for autotext in autotexts:
- autotext.set_fontsize(10)
- autotext.set_color('white')
-
- # Set aspect ratio to be equal to ensure a circular pie
- ax.set_aspect('equal')
-
- # Add title
+ # Draw both pies
+ draw_pie(axes[0], left_labels, left_values, subtitles[0])
+ draw_pie(axes[1], right_labels, right_values, subtitles[1])
+
+ # Add main title
if title:
- ax.set_title(title, fontsize=16)
+ fig.suptitle(title, fontsize=16, y=0.98)
- # Create a legend if there are many segments
- if len(labels) > 6:
- ax.legend(wedges, labels, title=legend_title, loc='center left', bbox_to_anchor=(1, 0, 0.5, 1))
+ # Create unified legend
+ if left_labels or right_labels:
+ handles = [
+ plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_map[label], markersize=10)
+ for label in all_labels
+ ]
+
+ fig.legend(
+ handles=handles,
+ labels=all_labels,
+ title=legend_title,
+ loc='lower center',
+ bbox_to_anchor=(0.5, -0.02),
+ ncol=min(len(all_labels), 5),
+ )
+
+ fig.subplots_adjust(bottom=0.15)
- # Apply tight layout
fig.tight_layout()
- return fig, ax
+ return fig, axes
-def dual_pie_with_plotly(
- data_left: pd.Series,
- data_right: pd.Series,
- colors: ColorType = 'viridis',
+def heatmap_with_plotly(
+ data: xr.DataArray,
+ colors: ColorType | None = None,
title: str = '',
- subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'),
- legend_title: str = '',
- hole: float = 0.2,
- lower_percentage_group: float = 5.0,
- hover_template: str = '%{label}: %{value} (%{percent})',
- text_info: str = 'percent+label',
- text_position: str = 'inside',
+ facet_by: str | list[str] | None = None,
+ animate_by: str | None = None,
+ facet_cols: int | None = None,
+ reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
+ | Literal['auto']
+ | None = 'auto',
+ fill: Literal['ffill', 'bfill'] | None = 'ffill',
+ **imshow_kwargs: Any,
) -> go.Figure:
"""
- Create two pie charts side by side with Plotly, with consistent coloring across both charts.
+ Plot a heatmap visualization using Plotly's imshow with faceting and animation support.
+
+ This function creates heatmap visualizations from xarray DataArrays, supporting
+ multi-dimensional data through faceting (subplots) and animation. It automatically
+ handles dimension reduction and data reshaping for optimal heatmap display.
+
+ Automatic Time Reshaping:
+ If only the 'time' dimension remains after faceting/animation (making the data 1D),
+ the function automatically reshapes time into a 2D format using default values
+ (timeframes='D', timesteps_per_frame='h'). This creates a daily pattern heatmap
+ showing hours vs days.
Args:
- data_left: Series for the left pie chart.
- data_right: Series for the right pie chart.
- colors: Color specification, can be:
- - A string with a colorscale name (e.g., 'viridis', 'plasma')
- - A list of color strings (e.g., ['#ff0000', '#00ff00'])
- - A dictionary mapping category names to colors (e.g., {'Category1': '#ff0000'})
- title: The main title of the plot.
- subtitles: Tuple containing the subtitles for (left, right) charts.
- legend_title: The title for the legend.
- hole: Size of the hole in the center for creating donut charts (0.0 to 1.0).
- lower_percentage_group: Group segments whose cumulative share is below this percentage (0–100) into "Other".
- hover_template: Template for hover text. Use %{label}, %{value}, %{percent}.
- text_info: What to show on pie segments: 'label', 'percent', 'value', 'label+percent',
- 'label+value', 'percent+value', 'label+percent+value', or 'none'.
- text_position: Position of text: 'inside', 'outside', 'auto', or 'none'.
+ data: An xarray DataArray containing the data to visualize. Should have at least
+ 2 dimensions, or a 'time' dimension that can be reshaped into 2D.
+ colors: Color specification (colorscale name, list, or dict). Common options:
+ 'turbo', 'plasma', 'RdBu', 'portland'.
+ title: The main title of the heatmap.
+ facet_by: Dimension to create facets for. Creates a subplot grid.
+ Can be a single dimension name or list (only first dimension used).
+ Note: px.imshow only supports single-dimension faceting.
+ If the dimension doesn't exist in the data, it will be silently ignored.
+ animate_by: Dimension to animate over. Creates animation frames.
+ If the dimension doesn't exist in the data, it will be silently ignored.
+ facet_cols: Number of columns in the facet grid (used with facet_by).
+ reshape_time: Time reshaping configuration:
+ - 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension remains
+ - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours)
+ - None: Disable time reshaping (will error if only 1D time data)
+ fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'.
+ **imshow_kwargs: Additional keyword arguments to pass to plotly.express.imshow.
+ Common options include:
+ - aspect: 'auto', 'equal', or a number for aspect ratio
+ - zmin, zmax: Minimum and maximum values for color scale
+ - labels: Dict to customize axis labels
Returns:
- A Plotly figure object containing the generated dual pie chart.
- """
- from plotly.subplots import make_subplots
+ A Plotly figure object containing the heatmap visualization.
- # Check for empty data
- if data_left.empty and data_right.empty:
- logger.error('Both datasets are empty. Returning empty figure.')
- return go.Figure()
+ Examples:
+ Simple heatmap:
- # Create a subplot figure
- fig = make_subplots(
- rows=1, cols=2, specs=[[{'type': 'pie'}, {'type': 'pie'}]], subplot_titles=subtitles, horizontal_spacing=0.05
- )
+ ```python
+ fig = heatmap_with_plotly(data_array, colors='RdBu', title='Temperature Map')
+ ```
- # Process series to handle negative values and apply minimum percentage threshold
- def preprocess_series(series: pd.Series):
- """
- Preprocess a series for pie chart display by handling negative values
- and grouping the smallest parts together if they collectively represent
- less than the specified percentage threshold.
+ Facet by scenario:
- Args:
- series: The series to preprocess
+ ```python
+ fig = heatmap_with_plotly(data_array, facet_by='scenario', facet_cols=2)
+ ```
- Returns:
- A preprocessed pandas Series
- """
- # Handle negative values
- if (series < 0).any():
- logger.error('Negative values detected in data. Using absolute values for pie chart.')
- series = series.abs()
+ Animate by period:
- # Remove zeros
- series = series[series > 0]
+ ```python
+ fig = heatmap_with_plotly(data_array, animate_by='period')
+ ```
- # Apply minimum percentage threshold if needed
- if lower_percentage_group and not series.empty:
- total = series.sum()
- if total > 0:
- # Sort series by value (ascending)
- sorted_series = series.sort_values()
+ Automatic time reshaping (when only time dimension remains):
- # Calculate cumulative percentage contribution
- cumulative_percent = (sorted_series.cumsum() / total) * 100
+ ```python
+ # Data with dims ['time', 'scenario', 'period']
+ # After faceting and animation, only 'time' remains -> auto-reshapes to (timestep, timeframe)
+ fig = heatmap_with_plotly(data_array, facet_by='scenario', animate_by='period')
+ ```
- # Find entries that collectively make up less than lower_percentage_group
- to_group = cumulative_percent <= lower_percentage_group
+ Explicit time reshaping:
- if to_group.sum() > 1:
- # Create "Other" category for the smallest values that together are < threshold
- other_sum = sorted_series[to_group].sum()
+ ```python
+ fig = heatmap_with_plotly(data_array, facet_by='scenario', animate_by='period', reshape_time=('W', 'D'))
+ ```
+ """
+ if colors is None:
+ colors = CONFIG.Plotting.default_sequential_colorscale
- # Keep only values that aren't in the "Other" group
- result_series = series[~series.index.isin(sorted_series[to_group].index)]
+ # Apply CONFIG defaults if not explicitly set
+ if facet_cols is None:
+ facet_cols = CONFIG.Plotting.default_facet_cols
- # Add the "Other" category if it has a value
- if other_sum > 0:
- result_series['Other'] = other_sum
+ # Handle empty data
+ if data.size == 0:
+ return go.Figure()
- return result_series
+ # Apply time reshaping using the new unified function
+ data = reshape_data_for_heatmap(
+ data, reshape_time=reshape_time, facet_by=facet_by, animate_by=animate_by, fill=fill
+ )
- return series
+ # Get available dimensions
+ available_dims = list(data.dims)
- data_left_processed = preprocess_series(data_left)
- data_right_processed = preprocess_series(data_right)
+ # Validate and filter facet_by dimensions
+ if facet_by is not None:
+ if isinstance(facet_by, str):
+ if facet_by not in available_dims:
+ logger.debug(
+ f"Dimension '{facet_by}' not found in data. Available dimensions: {available_dims}. "
+ f'Ignoring facet_by parameter.'
+ )
+ facet_by = None
+ elif isinstance(facet_by, list):
+ missing_dims = [dim for dim in facet_by if dim not in available_dims]
+ facet_by = [dim for dim in facet_by if dim in available_dims]
+ if missing_dims:
+ logger.debug(
+ f'Dimensions {missing_dims} not found in data. Available dimensions: {available_dims}. '
+ f'Using only existing dimensions: {facet_by if facet_by else "none"}.'
+ )
+ if len(facet_by) == 0:
+ facet_by = None
+
+ # Validate animate_by dimension
+ if animate_by is not None and animate_by not in available_dims:
+ logger.debug(
+ f"Dimension '{animate_by}' not found in data. Available dimensions: {available_dims}. "
+ f'Ignoring animate_by parameter.'
+ )
+ animate_by = None
- # Get unique set of all labels for consistent coloring
- all_labels = sorted(set(data_left_processed.index) | set(data_right_processed.index))
+ # Determine which dimensions are used for faceting/animation
+ facet_dims = []
+ if facet_by:
+ facet_dims = [facet_by] if isinstance(facet_by, str) else facet_by
+ if animate_by:
+ facet_dims.append(animate_by)
- # Get consistent color mapping for both charts using our unified function
- color_map = ColorProcessor(engine='plotly').process_colors(colors, all_labels, return_mapping=True)
+ # Get remaining dimensions for the heatmap itself
+ heatmap_dims = [dim for dim in available_dims if dim not in facet_dims]
- # Function to create a pie trace with consistently mapped colors
- def create_pie_trace(data_series, side):
- if data_series.empty:
- return None
+ if len(heatmap_dims) < 2:
+ # Handle single-dimension case by adding variable name as a dimension
+ if len(heatmap_dims) == 1:
+ # Get the variable name, or use a default
+ var_name = data.name if data.name else 'value'
- labels = data_series.index.tolist()
- values = data_series.values.tolist()
- trace_colors = [color_map[label] for label in labels]
+ # Expand the DataArray by adding a new dimension with the variable name
+ data = data.expand_dims({'variable': [var_name]})
- return go.Pie(
- labels=labels,
- values=values,
- name=side,
- marker=dict(colors=trace_colors),
- hole=hole,
- textinfo=text_info,
- textposition=text_position,
- insidetextorientation='radial',
- hovertemplate=hover_template,
- sort=True, # Sort values by default (largest first)
- )
+ # Update available dimensions
+ available_dims = list(data.dims)
+ heatmap_dims = [dim for dim in available_dims if dim not in facet_dims]
- # Add left pie if data exists
- left_trace = create_pie_trace(data_left_processed, subtitles[0])
- if left_trace:
- left_trace.domain = dict(x=[0, 0.48])
- fig.add_trace(left_trace, row=1, col=1)
+ logger.debug(f'Only 1 dimension remaining for heatmap. Added variable dimension: {var_name}')
+ else:
+ # No dimensions at all - cannot create a heatmap
+ logger.error(
+ f'Heatmap requires at least 1 dimension. '
+ f'After faceting/animation, {len(heatmap_dims)} dimension(s) remain: {heatmap_dims}'
+ )
+ return go.Figure()
+
+ # Setup faceting parameters for Plotly Express
+ # Note: px.imshow only supports facet_col, not facet_row
+ facet_col_param = None
+ if facet_by:
+ if isinstance(facet_by, str):
+ facet_col_param = facet_by
+ elif len(facet_by) == 1:
+ facet_col_param = facet_by[0]
+ elif len(facet_by) >= 2:
+ # px.imshow doesn't support facet_row, so we can only facet by one dimension
+ # Use the first dimension and warn about the rest
+ facet_col_param = facet_by[0]
+ logger.warning(
+ f'px.imshow only supports faceting by a single dimension. '
+ f'Using {facet_by[0]} for faceting. Dimensions {facet_by[1:]} will be ignored. '
+ f'Consider using animate_by for additional dimensions.'
+ )
- # Add right pie if data exists
- right_trace = create_pie_trace(data_right_processed, subtitles[1])
- if right_trace:
- right_trace.domain = dict(x=[0.52, 1])
- fig.add_trace(right_trace, row=1, col=2)
+ # Create the imshow plot - px.imshow can work directly with xarray DataArrays
+ common_args = {
+ 'img': data,
+ 'color_continuous_scale': colors,
+ 'title': title,
+ }
- # Update layout
- fig.update_layout(
- title=title,
- legend_title=legend_title,
- plot_bgcolor='rgba(0,0,0,0)', # Transparent background
- paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background
- font=dict(size=14),
- margin=dict(t=80, b=50, l=30, r=30),
- )
+ # Add faceting if specified
+ if facet_col_param:
+ common_args['facet_col'] = facet_col_param
+ if facet_cols:
+ common_args['facet_col_wrap'] = facet_cols
+
+ # Add animation if specified
+ if animate_by:
+ common_args['animation_frame'] = animate_by
+
+ # Merge in additional imshow kwargs
+ common_args.update(imshow_kwargs)
+
+ try:
+ fig = px.imshow(**common_args)
+ except Exception as e:
+ logger.error(f'Error creating imshow plot: {e}. Falling back to basic heatmap.')
+ # Fallback: create a simple heatmap without faceting
+ fallback_args = {
+ 'img': data.values,
+ 'color_continuous_scale': colors,
+ 'title': title,
+ }
+ fallback_args.update(imshow_kwargs)
+ fig = px.imshow(**fallback_args)
return fig
-def dual_pie_with_matplotlib(
- data_left: pd.Series,
- data_right: pd.Series,
- colors: ColorType = 'viridis',
+def heatmap_with_matplotlib(
+ data: xr.DataArray,
+ colors: ColorType | None = None,
title: str = '',
- subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'),
- legend_title: str = '',
- hole: float = 0.2,
- lower_percentage_group: float = 5.0,
- figsize: tuple[int, int] = (14, 7),
- fig: plt.Figure | None = None,
- axes: list[plt.Axes] | None = None,
-) -> tuple[plt.Figure, list[plt.Axes]]:
+ figsize: tuple[float, float] = (12, 6),
+ reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
+ | Literal['auto']
+ | None = 'auto',
+ fill: Literal['ffill', 'bfill'] | None = 'ffill',
+ vmin: float | None = None,
+ vmax: float | None = None,
+ imshow_kwargs: dict[str, Any] | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ **kwargs: Any,
+) -> tuple[plt.Figure, plt.Axes]:
"""
- Create two pie charts side by side with Matplotlib, with consistent coloring across both charts.
- Leverages the existing pie_with_matplotlib function.
+ Plot a heatmap visualization using Matplotlib's imshow.
+
+ This function creates a basic 2D heatmap from an xarray DataArray using matplotlib's
+ imshow function. For multi-dimensional data, only the first two dimensions are used.
Args:
- data_left: Series for the left pie chart.
- data_right: Series for the right pie chart.
- colors: Color specification, can be:
- - A string with a colormap name (e.g., 'viridis', 'plasma')
- - A list of color strings (e.g., ['#ff0000', '#00ff00'])
- - A dictionary mapping category names to colors (e.g., {'Category1': '#ff0000'})
- title: The main title of the plot.
- subtitles: Tuple containing the subtitles for (left, right) charts.
- legend_title: The title for the legend.
- hole: Size of the hole in the center for creating donut charts (0.0 to 1.0).
- lower_percentage_group: Whether to group small segments (below percentage) into an "Other" category.
+ data: An xarray DataArray containing the data to visualize. Should have at least
+ 2 dimensions. If more than 2 dimensions exist, additional dimensions will
+ be reduced by taking the first slice.
+ colors: Color specification. Should be a colorscale name (e.g., 'turbo', 'RdBu').
+ title: The title of the heatmap.
figsize: The size of the figure (width, height) in inches.
- fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created.
- axes: A list of Matplotlib axes objects to plot on. If not provided, new axes will be created.
+ reshape_time: Time reshaping configuration:
+ - 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension
+ - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours)
+ - None: Disable time reshaping
+ fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'.
+ vmin: Minimum value for color scale. If None, uses data minimum.
+ vmax: Maximum value for color scale. If None, uses data maximum.
+ imshow_kwargs: Optional dict of parameters to pass to ax.imshow().
+ Use this to customize image properties (e.g., interpolation, aspect).
+ cbar_kwargs: Optional dict of parameters to pass to plt.colorbar().
+ Use this to customize colorbar properties (e.g., orientation, label).
+ **kwargs: Additional keyword arguments passed to ax.imshow().
+ Common options include:
+ - interpolation: 'nearest', 'bilinear', 'bicubic', etc.
+ - alpha: Transparency level (0-1)
+ - extent: [left, right, bottom, top] for axis limits
Returns:
- A tuple containing the Matplotlib figure and list of axes objects used for the plot.
- """
- # Check for empty data
- if data_left.empty and data_right.empty:
- logger.error('Both datasets are empty. Returning empty figure.')
- if fig is None:
- fig, axes = plt.subplots(1, 2, figsize=figsize)
- return fig, axes
-
- # Create figure and axes if not provided
- if fig is None or axes is None:
- fig, axes = plt.subplots(1, 2, figsize=figsize)
-
- # Process series to handle negative values and apply minimum percentage threshold
- def preprocess_series(series: pd.Series):
- """
- Preprocess a series for pie chart display by handling negative values
- and grouping the smallest parts together if they collectively represent
- less than the specified percentage threshold.
- """
- # Handle negative values
- if (series < 0).any():
- logger.error('Negative values detected in data. Using absolute values for pie chart.')
- series = series.abs()
-
- # Remove zeros
- series = series[series > 0]
-
- # Apply minimum percentage threshold if needed
- if lower_percentage_group and not series.empty:
- total = series.sum()
- if total > 0:
- # Sort series by value (ascending)
- sorted_series = series.sort_values()
-
- # Calculate cumulative percentage contribution
- cumulative_percent = (sorted_series.cumsum() / total) * 100
-
- # Find entries that collectively make up less than lower_percentage_group
- to_group = cumulative_percent <= lower_percentage_group
+ A tuple containing the Matplotlib figure and axes objects used for the plot.
- if to_group.sum() > 1:
- # Create "Other" category for the smallest values that together are < threshold
- other_sum = sorted_series[to_group].sum()
+ Notes:
+ - Matplotlib backend doesn't support faceting or animation. Use plotly engine for those features.
+ - The y-axis is automatically inverted to display data with origin at top-left.
+ - A colorbar is added to show the value scale.
- # Keep only values that aren't in the "Other" group
- result_series = series[~series.index.isin(sorted_series[to_group].index)]
+ Examples:
+ ```python
+ fig, ax = heatmap_with_matplotlib(data_array, colors='RdBu', title='Temperature')
+ plt.savefig('heatmap.png')
+ ```
- # Add the "Other" category if it has a value
- if other_sum > 0:
- result_series['Other'] = other_sum
+ Time reshaping:
- return result_series
+ ```python
+ fig, ax = heatmap_with_matplotlib(data_array, reshape_time=('D', 'h'))
+ ```
+ """
+ if colors is None:
+ colors = CONFIG.Plotting.default_sequential_colorscale
- return series
+ # Initialize kwargs if not provided
+ if imshow_kwargs is None:
+ imshow_kwargs = {}
+ if cbar_kwargs is None:
+ cbar_kwargs = {}
- # Preprocess data
- data_left_processed = preprocess_series(data_left)
- data_right_processed = preprocess_series(data_right)
+ # Merge any additional kwargs into imshow_kwargs
+ # This allows users to pass imshow options directly
+ imshow_kwargs.update(kwargs)
- # Convert Series to DataFrames for pie_with_matplotlib
- df_left = pd.DataFrame(data_left_processed).T if not data_left_processed.empty else pd.DataFrame()
- df_right = pd.DataFrame(data_right_processed).T if not data_right_processed.empty else pd.DataFrame()
+ # Handle empty data
+ if data.size == 0:
+ fig, ax = plt.subplots(figsize=figsize)
+ return fig, ax
- # Get unique set of all labels for consistent coloring
- all_labels = sorted(set(data_left_processed.index) | set(data_right_processed.index))
+ # Apply time reshaping using the new unified function
+ # Matplotlib doesn't support faceting/animation, so we pass None for those
+ data = reshape_data_for_heatmap(data, reshape_time=reshape_time, facet_by=None, animate_by=None, fill=fill)
- # Get consistent color mapping for both charts using our unified function
- color_map = ColorProcessor(engine='matplotlib').process_colors(colors, all_labels, return_mapping=True)
+ # Handle single-dimension case by adding variable name as a dimension
+ if isinstance(data, xr.DataArray) and len(data.dims) == 1:
+ var_name = data.name if data.name else 'value'
+ data = data.expand_dims({'variable': [var_name]})
+ logger.debug(f'Only 1 dimension in data. Added variable dimension: {var_name}')
- # Configure colors for each DataFrame based on the consistent mapping
- left_colors = [color_map[col] for col in df_left.columns] if not df_left.empty else []
- right_colors = [color_map[col] for col in df_right.columns] if not df_right.empty else []
+ # Create figure and axes
+ fig, ax = plt.subplots(figsize=figsize)
- # Create left pie chart
- if not df_left.empty:
- pie_with_matplotlib(data=df_left, colors=left_colors, title=subtitles[0], hole=hole, fig=fig, ax=axes[0])
- else:
- axes[0].set_title(subtitles[0])
- axes[0].axis('off')
+ # Extract data values
+ # If data has more than 2 dimensions, we need to reduce it
+ if isinstance(data, xr.DataArray):
+ # Get the first 2 dimensions
+ dims = list(data.dims)
+ if len(dims) > 2:
+ logger.warning(
+ f'Data has {len(dims)} dimensions: {dims}. '
+ f'Only the first 2 will be used for the heatmap. '
+ f'Use the plotly engine for faceting/animation support.'
+ )
+ # Select only the first 2 dimensions by taking first slice of others
+ selection = {dim: 0 for dim in dims[2:]}
+ data = data.isel(selection)
- # Create right pie chart
- if not df_right.empty:
- pie_with_matplotlib(data=df_right, colors=right_colors, title=subtitles[1], hole=hole, fig=fig, ax=axes[1])
+ values = data.values
+ x_labels = data.dims[1] if len(data.dims) > 1 else 'x'
+ y_labels = data.dims[0] if len(data.dims) > 0 else 'y'
else:
- axes[1].set_title(subtitles[1])
- axes[1].axis('off')
-
- # Add main title
- if title:
- fig.suptitle(title, fontsize=16, y=0.98)
+ values = data
+ x_labels = 'x'
+ y_labels = 'y'
+
+ # Create the heatmap using imshow with user customizations
+ imshow_defaults = {'cmap': colors, 'aspect': 'auto', 'origin': 'upper', 'vmin': vmin, 'vmax': vmax}
+ imshow_defaults.update(imshow_kwargs) # User kwargs override defaults
+ im = ax.imshow(values, **imshow_defaults)
+
+ # Add colorbar with user customizations
+ cbar_defaults = {'ax': ax, 'orientation': 'horizontal', 'pad': 0.1, 'aspect': 15, 'fraction': 0.05}
+ cbar_defaults.update(cbar_kwargs) # User kwargs override defaults
+ cbar = plt.colorbar(im, **cbar_defaults)
+
+ # Set colorbar label if not overridden by user
+ if 'label' not in cbar_kwargs:
+ cbar.set_label('Value')
+
+ # Set labels and title
+ ax.set_xlabel(str(x_labels).capitalize())
+ ax.set_ylabel(str(y_labels).capitalize())
+ ax.set_title(title)
- # Adjust layout
+ # Apply tight layout
fig.tight_layout()
- # Create a unified legend if both charts have data
- if not df_left.empty and not df_right.empty:
- # Remove individual legends
- for ax in axes:
- if ax.get_legend():
- ax.get_legend().remove()
-
- # Create handles for the unified legend
- handles = []
- labels_for_legend = []
-
- for label in all_labels:
- color = color_map[label]
- patch = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=label)
- handles.append(patch)
- labels_for_legend.append(label)
-
- # Add unified legend
- fig.legend(
- handles=handles,
- labels=labels_for_legend,
- title=legend_title,
- loc='lower center',
- bbox_to_anchor=(0.5, 0),
- ncol=min(len(all_labels), 5), # Limit columns to 5 for readability
- )
-
- # Add padding at the bottom for the legend
- fig.subplots_adjust(bottom=0.2)
-
- return fig, axes
+ return fig, ax
def export_figure(
@@ -1418,8 +1562,9 @@ def export_figure(
default_path: pathlib.Path,
default_filetype: str | None = None,
user_path: pathlib.Path | None = None,
- show: bool = True,
+ show: bool | None = None,
save: bool = False,
+ dpi: int | None = None,
) -> go.Figure | tuple[plt.Figure, plt.Axes]:
"""
Export a figure to a file and or show it.
@@ -1429,13 +1574,21 @@ def export_figure(
default_path: The default file path if no user filename is provided.
default_filetype: The default filetype if the path doesn't end with a filetype.
user_path: An optional user-specified file path.
- show: Whether to display the figure (default: True).
+ show: Whether to display the figure. If None, uses CONFIG.Plotting.default_show (default: None).
save: Whether to save the figure (default: False).
+ dpi: DPI (dots per inch) for saving Matplotlib figures. If None, uses CONFIG.Plotting.default_dpi.
Raises:
ValueError: If no default filetype is provided and the path doesn't specify a filetype.
TypeError: If the figure type is not supported.
"""
+ # Apply CONFIG defaults if not explicitly set
+ if show is None:
+ show = CONFIG.Plotting.default_show
+
+ if dpi is None:
+ dpi = CONFIG.Plotting.default_dpi
+
filename = user_path or default_path
filename = filename.with_name(filename.name.replace('|', '__'))
if filename.suffix == '':
@@ -1450,25 +1603,17 @@ def export_figure(
filename = filename.with_suffix('.html')
try:
- is_test_env = 'PYTEST_CURRENT_TEST' in os.environ
-
- if is_test_env:
- # Test environment: never open browser, only save if requested
- if save:
- fig.write_html(str(filename))
- # Ignore show flag in tests
- else:
- # Production environment: respect show and save flags
- if save and show:
- # Save and auto-open in browser
- plotly.offline.plot(fig, filename=str(filename))
- elif save and not show:
- # Save without opening
- fig.write_html(str(filename))
- elif show and not save:
- # Show interactively without saving
- fig.show()
- # If neither save nor show: do nothing
+ # Respect show and save flags (tests should set CONFIG.Plotting.default_show=False)
+ if save and show:
+ # Save and auto-open in browser
+ plotly.offline.plot(fig, filename=str(filename))
+ elif save and not show:
+ # Save without opening
+ fig.write_html(str(filename))
+ elif show and not save:
+ # Show interactively without saving
+ fig.show()
+ # If neither save nor show: do nothing
finally:
# Cleanup to prevent socket warnings
if hasattr(fig, '_renderer'):
@@ -1479,16 +1624,15 @@ def export_figure(
elif isinstance(figure_like, tuple):
fig, ax = figure_like
if show:
- # Only show if using interactive backend and not in test environment
+ # Only show if using interactive backend (tests should set CONFIG.Plotting.default_show=False)
backend = matplotlib.get_backend().lower()
is_interactive = backend not in {'agg', 'pdf', 'ps', 'svg', 'template'}
- is_test_env = 'PYTEST_CURRENT_TEST' in os.environ
- if is_interactive and not is_test_env:
+ if is_interactive:
plt.show()
if save:
- fig.savefig(str(filename), dpi=300)
+ fig.savefig(str(filename), dpi=dpi)
plt.close(fig) # Close figure to free memory
return fig, ax
diff --git a/flixopt/results.py b/flixopt/results.py
index 2e951af70..26eaf9d5d 100644
--- a/flixopt/results.py
+++ b/flixopt/results.py
@@ -1,7 +1,7 @@
from __future__ import annotations
+import copy
import datetime
-import json
import logging
import pathlib
import warnings
@@ -10,16 +10,18 @@
import linopy
import numpy as np
import pandas as pd
-import plotly
import xarray as xr
-import yaml
from . import io as fx_io
from . import plotting
+from .color_processing import process_colors
+from .config import CONFIG
from .flow_system import FlowSystem
+from .structure import CompositeContainerMixin, ElementContainer, ResultsContainer
if TYPE_CHECKING:
import matplotlib.pyplot as plt
+ import plotly
import pyvis
from .calculation import Calculation, SegmentedCalculation
@@ -29,13 +31,30 @@
logger = logging.getLogger('flixopt')
+def load_mapping_from_file(path: pathlib.Path) -> dict[str, str | list[str]]:
+ """Load color mapping from JSON or YAML file.
+
+ Tries loader based on file suffix first, with fallback to the other format.
+
+ Args:
+ path: Path to config file (.json or .yaml/.yml)
+
+ Returns:
+ Dictionary mapping components to colors or colorscales to component lists
+
+ Raises:
+ ValueError: If file cannot be loaded as JSON or YAML
+ """
+ return fx_io.load_config_file(path)
+
+
class _FlowSystemRestorationError(Exception):
"""Exception raised when a FlowSystem cannot be restored from dataset."""
pass
-class CalculationResults:
+class CalculationResults(CompositeContainerMixin['ComponentResults | BusResults | EffectResults | FlowResults']):
"""Comprehensive container for optimization calculation results and analysis tools.
This class provides unified access to all optimization results including flow rates,
@@ -107,6 +126,20 @@ class CalculationResults:
).mean()
```
+ Configure automatic color management for plots:
+
+ ```python
+ # Dict-based configuration:
+ results.setup_colors({'Solar*': 'Oranges', 'Wind*': 'Blues', 'Battery': 'green'})
+
+ # All plots automatically use configured colors (colors=None is the default)
+ results['ElectricityBus'].plot_node_balance()
+ results['Battery'].plot_charge_state()
+
+ # Override when needed
+ results['ElectricityBus'].plot_node_balance(colors='turbo') # Ignores setup
+ ```
+
Design Patterns:
**Factory Methods**: Use `from_file()` and `from_calculation()` for creation or access directly from `Calculation.results`
**Dictionary Access**: Use `results[element_label]` for element-specific results
@@ -137,8 +170,7 @@ def from_file(cls, folder: str | pathlib.Path, name: str) -> CalculationResults:
except Exception as e:
logger.critical(f'Could not load the linopy model "{name}" from file ("{paths.linopy_model}"): {e}')
- with open(paths.summary, encoding='utf-8') as f:
- summary = yaml.load(f, Loader=yaml.FullLoader)
+ summary = fx_io.load_yaml(paths.summary)
return cls(
solution=fx_io.load_dataset_from_netcdf(paths.solution),
@@ -195,8 +227,8 @@ def __init__(
if 'flow_system' in kwargs and flow_system_data is None:
flow_system_data = kwargs.pop('flow_system')
warnings.warn(
- "The 'flow_system' parameter is deprecated. Use 'flow_system_data' instead."
- "Acess is now by '.flow_system_data', while '.flow_system' returns the restored FlowSystem.",
+ "The 'flow_system' parameter is deprecated. Use 'flow_system_data' instead. "
+ "Access is now via '.flow_system_data', while '.flow_system' returns the restored FlowSystem.",
DeprecationWarning,
stacklevel=2,
)
@@ -207,13 +239,18 @@ def __init__(
self.name = name
self.model = model
self.folder = pathlib.Path(folder) if folder is not None else pathlib.Path.cwd() / 'results'
- self.components = {
+
+ # Create ResultsContainers for better access patterns
+ components_dict = {
label: ComponentResults(self, **infos) for label, infos in self.solution.attrs['Components'].items()
}
+ self.components = ResultsContainer(elements=components_dict, element_type_name='component results')
- self.buses = {label: BusResults(self, **infos) for label, infos in self.solution.attrs['Buses'].items()}
+ buses_dict = {label: BusResults(self, **infos) for label, infos in self.solution.attrs['Buses'].items()}
+ self.buses = ResultsContainer(elements=buses_dict, element_type_name='bus results')
- self.effects = {label: EffectResults(self, **infos) for label, infos in self.solution.attrs['Effects'].items()}
+ effects_dict = {label: EffectResults(self, **infos) for label, infos in self.solution.attrs['Effects'].items()}
+ self.effects = ResultsContainer(elements=effects_dict, element_type_name='effect results')
if 'Flows' not in self.solution.attrs:
warnings.warn(
@@ -221,15 +258,19 @@ def __init__(
'is not availlable. We recommend to evaluate your results with a version <2.2.0.',
stacklevel=2,
)
- self.flows = {}
+ flows_dict = {}
+ self._has_flow_data = False
else:
- self.flows = {
+ flows_dict = {
label: FlowResults(self, **infos) for label, infos in self.solution.attrs.get('Flows', {}).items()
}
+ self._has_flow_data = True
+ self.flows = ResultsContainer(elements=flows_dict, element_type_name='flow results')
self.timesteps_extra = self.solution.indexes['time']
self.hours_per_timestep = FlowSystem.calculate_hours_per_timestep(self.timesteps_extra)
self.scenarios = self.solution.indexes['scenario'] if 'scenario' in self.solution.indexes else None
+ self.periods = self.solution.indexes['period'] if 'period' in self.solution.indexes else None
self._effect_share_factors = None
self._flow_system = None
@@ -239,16 +280,24 @@ def __init__(
self._sizes = None
self._effects_per_component = None
- def __getitem__(self, key: str) -> ComponentResults | BusResults | EffectResults:
- if key in self.components:
- return self.components[key]
- if key in self.buses:
- return self.buses[key]
- if key in self.effects:
- return self.effects[key]
- if key in self.flows:
- return self.flows[key]
- raise KeyError(f'No element with label {key} found.')
+ self.colors: dict[str, str] = {}
+
+ def _get_container_groups(self) -> dict[str, ResultsContainer]:
+ """Return ordered container groups for CompositeContainerMixin."""
+ return {
+ 'Components': self.components,
+ 'Buses': self.buses,
+ 'Effects': self.effects,
+ 'Flows': self.flows,
+ }
+
+ def __repr__(self) -> str:
+ """Return grouped representation of all results."""
+ r = fx_io.format_title_with_underline(self.__class__.__name__, '=')
+ r += f'Name: "{self.name}"\nFolder: {self.folder}\n'
+ # Add grouped container view
+ r += '\n' + self._format_grouped_containers()
+ return r
@property
def storages(self) -> list[ComponentResults]:
@@ -305,6 +354,131 @@ def flow_system(self) -> FlowSystem:
logger.level = old_level
return self._flow_system
+ def setup_colors(
+ self,
+ config: dict[str, str | list[str]] | str | pathlib.Path | None = None,
+ default_colorscale: str | None = None,
+ ) -> dict[str, str]:
+ """
+ Setup colors for all variables across all elements. Overwrites existing ones.
+
+ Args:
+ config: Configuration for color assignment. Can be:
+ - dict: Maps components to colors/colorscales:
+ * 'component1': 'red' # Single component to single color
+ * 'component1': '#FF0000' # Single component to hex color
+ - OR maps colorscales to multiple components:
+ * 'colorscale_name': ['component1', 'component2'] # Colorscale across components
+ - str: Path to a JSON/YAML config file or a colorscale name to apply to all
+ - Path: Path to a JSON/YAML config file
+ - None: Use default_colorscale for all components
+ default_colorscale: Default colorscale for unconfigured components (default: 'turbo')
+
+ Examples:
+ setup_colors({
+ # Direct component-to-color mappings
+ 'Boiler1': '#FF0000',
+ 'CHP': 'darkred',
+ # Colorscale for multiple components
+ 'Oranges': ['Solar1', 'Solar2'],
+ 'Blues': ['Wind1', 'Wind2'],
+ 'Greens': ['Battery1', 'Battery2', 'Battery3'],
+ })
+
+ Returns:
+ Complete variable-to-color mapping dictionary
+ """
+
+ def get_all_variable_names(comp: str) -> list[str]:
+ """Collect all variables from the component, including flows and flow_hours."""
+ comp_object = self.components[comp]
+ var_names = [comp] + list(comp_object._variable_names)
+ for flow in comp_object.flows:
+ var_names.extend([flow, f'{flow}|flow_hours'])
+ return var_names
+
+ # Set default colorscale if not provided
+ if default_colorscale is None:
+ default_colorscale = CONFIG.Plotting.default_qualitative_colorscale
+
+ # Handle different config input types
+ if config is None:
+ # Apply default colorscale to all components
+ config_dict = {}
+ elif isinstance(config, (str, pathlib.Path)):
+ # Try to load from file first
+ config_path = pathlib.Path(config)
+ if config_path.exists():
+ # Load config from file using helper
+ config_dict = load_mapping_from_file(config_path)
+ else:
+ # Treat as colorscale name to apply to all components
+ all_components = list(self.components.keys())
+ config_dict = {config: all_components}
+ elif isinstance(config, dict):
+ config_dict = config
+ else:
+ raise TypeError(f'config must be dict, str, Path, or None, got {type(config)}')
+
+ # Step 1: Build component-to-color mapping
+ component_colors: dict[str, str] = {}
+
+ # Track which components are configured
+ configured_components = set()
+
+ # Process each configuration entry
+ for key, value in config_dict.items():
+ # Check if value is a list (colorscale -> [components])
+ # or a string (component -> color OR colorscale -> [components])
+
+ if isinstance(value, list):
+ # key is colorscale, value is list of components
+ # Format: 'Blues': ['Wind1', 'Wind2']
+ components = value
+ colorscale_name = key
+
+ # Validate components exist
+ for component in components:
+ if component not in self.components:
+ raise ValueError(f"Component '{component}' not found")
+
+ configured_components.update(components)
+
+ # Use process_colors to get one color per component from the colorscale
+ colors_for_components = process_colors(colorscale_name, components)
+ component_colors.update(colors_for_components)
+
+ elif isinstance(value, str):
+ # Check if key is an existing component
+ if key in self.components:
+ # Format: 'CHP': 'red' (component -> color)
+ component, color = key, value
+
+ configured_components.add(component)
+ component_colors[component] = color
+ else:
+ raise ValueError(f"Component '{key}' not found")
+ else:
+ raise TypeError(f'Config value must be str or list, got {type(value)}')
+
+ # Step 2: Assign colors to remaining unconfigured components
+ remaining_components = list(set(self.components.keys()) - configured_components)
+ if remaining_components:
+ # Use default colorscale to assign one color per remaining component
+ default_colors = process_colors(default_colorscale, remaining_components)
+ component_colors.update(default_colors)
+
+ # Step 3: Build variable-to-color mapping
+ # Clear existing colors to avoid stale keys
+ self.colors = {}
+ # Each component's variables all get the same color as the component
+ for component, color in component_colors.items():
+ variable_names = get_all_variable_names(component)
+ for var_name in variable_names:
+ self.colors[var_name] = color
+
+ return self.colors
+
def filter_solution(
self,
variable_dims: Literal['scalar', 'time', 'scenario', 'timeonly', 'scenarioonly'] | None = None,
@@ -388,6 +562,8 @@ def flow_rates(
To recombine filtered dataarrays, use `xr.concat` with dim 'flow':
>>> xr.concat([results.flow_rates(start='Fernwärme'), results.flow_rates(end='Fernwärme')], dim='flow')
"""
+ if not self._has_flow_data:
+ raise ValueError('Flow data is not available in this results object (pre-v2.2.0).')
if self._flow_rates is None:
self._flow_rates = self._assign_flow_coords(
xr.concat(
@@ -449,6 +625,8 @@ def sizes(
>>> xr.concat([results.sizes(start='Fernwärme'), results.sizes(end='Fernwärme')], dim='flow')
"""
+ if not self._has_flow_data:
+ raise ValueError('Flow data is not available in this results object (pre-v2.2.0).')
if self._sizes is None:
self._sizes = self._assign_flow_coords(
xr.concat(
@@ -461,11 +639,12 @@ def sizes(
def _assign_flow_coords(self, da: xr.DataArray):
# Add start and end coordinates
+ flows_list = list(self.flows.values())
da = da.assign_coords(
{
- 'start': ('flow', [flow.start for flow in self.flows.values()]),
- 'end': ('flow', [flow.end for flow in self.flows.values()]),
- 'component': ('flow', [flow.component for flow in self.flows.values()]),
+ 'start': ('flow', [flow.start for flow in flows_list]),
+ 'end': ('flow', [flow.end for flow in flows_list]),
+ 'component': ('flow', [flow.component for flow in flows_list]),
}
)
@@ -584,8 +763,6 @@ def _compute_effect_total(
temporal = temporal.sum('time')
if periodic.isnull().all():
return temporal.rename(f'{element}->{effect}')
- if 'time' in temporal.indexes:
- temporal = temporal.sum('time')
return periodic + temporal
total = xr.DataArray(0)
@@ -619,6 +796,30 @@ def _compute_effect_total(
total = xr.DataArray(np.nan)
return total.rename(f'{element}->{effect}({mode})')
+ def _create_template_for_mode(self, mode: Literal['temporal', 'periodic', 'total']) -> xr.DataArray:
+ """Create a template DataArray with the correct dimensions for a given mode.
+
+ Args:
+ mode: The calculation mode ('temporal', 'periodic', or 'total').
+
+ Returns:
+ A DataArray filled with NaN, with dimensions appropriate for the mode.
+ """
+ coords = {}
+ if mode == 'temporal':
+ coords['time'] = self.timesteps_extra
+ if self.periods is not None:
+ coords['period'] = self.periods
+ if self.scenarios is not None:
+ coords['scenario'] = self.scenarios
+
+ # Create template with appropriate shape
+ if coords:
+ shape = tuple(len(coords[dim]) for dim in coords)
+ return xr.DataArray(np.full(shape, np.nan, dtype=float), coords=coords, dims=list(coords.keys()))
+ else:
+ return xr.DataArray(np.nan)
+
def _create_effects_dataset(self, mode: Literal['temporal', 'periodic', 'total']) -> xr.Dataset:
"""Creates a dataset containing effect totals for all components (including their flows).
The dataset does contain the direct as well as the indirect effects of each component.
@@ -629,32 +830,23 @@ def _create_effects_dataset(self, mode: Literal['temporal', 'periodic', 'total']
Returns:
An xarray Dataset with components as dimension and effects as variables.
"""
+ # Create template with correct dimensions for this mode
+ template = self._create_template_for_mode(mode)
+
ds = xr.Dataset()
all_arrays = {}
- template = None # Template is needed to determine the dimensions of the arrays. This handles the case of no shares for an effect
-
components_list = list(self.components)
- # First pass: collect arrays and find template
+ # Collect arrays for all effects and components
for effect in self.effects:
effect_arrays = []
for component in components_list:
da = self._compute_effect_total(element=component, effect=effect, mode=mode, include_flows=True)
effect_arrays.append(da)
- if template is None and (da.dims or not da.isnull().all()):
- template = da
-
all_arrays[effect] = effect_arrays
- # Ensure we have a template
- if template is None:
- raise ValueError(
- f"No template with proper dimensions found for mode '{mode}'. "
- f'All computed arrays are scalars, which indicates a data issue.'
- )
-
- # Second pass: process all effects (guaranteed to include all)
+ # Process all effects: expand scalar NaN arrays to match template dimensions
for effect in self.effects:
dataarrays = all_arrays[effect]
component_arrays = []
@@ -687,68 +879,145 @@ def _create_effects_dataset(self, mode: Literal['temporal', 'periodic', 'total']
def plot_heatmap(
self,
- variable_name: str,
- heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] = 'D',
- heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] = 'h',
- color_map: str = 'portland',
+ variable_name: str | list[str],
save: bool | pathlib.Path = False,
- show: bool = True,
+ show: bool | None = None,
+ colors: plotting.ColorType | None = None,
engine: plotting.PlottingEngine = 'plotly',
+ select: dict[FlowSystemDimensions, Any] | None = None,
+ facet_by: str | list[str] | None = 'scenario',
+ animate_by: str | None = 'period',
+ facet_cols: int | None = None,
+ reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
+ | Literal['auto']
+ | None = 'auto',
+ fill: Literal['ffill', 'bfill'] | None = 'ffill',
+ # Deprecated parameters (kept for backwards compatibility)
indexer: dict[FlowSystemDimensions, Any] | None = None,
+ heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None,
+ heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None,
+ color_map: str | None = None,
+ **plot_kwargs: Any,
) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]:
"""
- Plots a heatmap of the solution of a variable.
+ Plots a heatmap visualization of a variable using imshow or time-based reshaping.
+
+ Supports multiple visualization features that can be combined:
+ - **Multi-variable**: Plot multiple variables on a single heatmap (creates 'variable' dimension)
+ - **Time reshaping**: Converts 'time' dimension into 2D (e.g., hours vs days)
+ - **Faceting**: Creates subplots for different dimension values
+ - **Animation**: Animates through dimension values (Plotly only)
Args:
- variable_name: The name of the variable to plot.
- heatmap_timeframes: The timeframes to use for the heatmap.
- heatmap_timesteps_per_frame: The timesteps per frame to use for the heatmap.
- color_map: The color map to use for the heatmap.
+ variable_name: The name of the variable to plot, or a list of variable names.
+ When a list is provided, variables are combined into a single DataArray
+ with a new 'variable' dimension.
save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location.
show: Whether to show the plot or not.
+ colors: Color scheme for the heatmap. See `flixopt.plotting.ColorType` for options.
engine: The engine to use for plotting. Can be either 'plotly' or 'matplotlib'.
- indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}.
- If None, uses first value for each dimension.
- If empty dict {}, uses all values.
+ select: Optional data selection dict. Supports single values, lists, slices, and index arrays.
+ Applied BEFORE faceting/animation/reshaping.
+ facet_by: Dimension(s) to create facets (subplots) for. Can be a single dimension name (str)
+ or list of dimensions. Each unique value combination creates a subplot. Ignored if not found.
+ animate_by: Dimension to animate over (Plotly only). Creates animation frames that cycle through
+ dimension values. Only one dimension can be animated. Ignored if not found.
+ facet_cols: Number of columns in the facet grid layout (default: 3).
+ reshape_time: Time reshaping configuration (default: 'auto'):
+ - 'auto': Automatically applies ('D', 'h') when only 'time' dimension remains
+ - Tuple: Explicit reshaping, e.g. ('D', 'h') for days vs hours,
+ ('MS', 'D') for months vs days, ('W', 'h') for weeks vs hours
+ - None: Disable auto-reshaping (will error if only 1D time data)
+ Supported timeframes: 'YS', 'MS', 'W', 'D', 'h', '15min', 'min'
+ fill: Method to fill missing values after reshape: 'ffill' (forward fill) or 'bfill' (backward fill).
+ Default is 'ffill'.
+ **plot_kwargs: Additional plotting customization options.
+ Common options:
+
+ - **dpi** (int): Export resolution for saved plots. Default: 300.
+
+ For heatmaps specifically:
+
+ - **vmin** (float): Minimum value for color scale (both engines).
+ - **vmax** (float): Maximum value for color scale (both engines).
+
+ For Matplotlib heatmaps:
+
+ - **imshow_kwargs** (dict): Additional kwargs for matplotlib's imshow (e.g., interpolation, aspect).
+ - **cbar_kwargs** (dict): Additional kwargs for colorbar customization.
Examples:
- Basic usage (uses first scenario, first period, all time):
+ Direct imshow mode (default):
+
+ >>> results.plot_heatmap('Battery|charge_state', select={'scenario': 'base'})
+
+ Facet by scenario:
- >>> results.plot_heatmap('Battery|charge_state')
+ >>> results.plot_heatmap('Boiler(Qth)|flow_rate', facet_by='scenario', facet_cols=2)
- Select specific scenario and period:
+ Animate by period:
- >>> results.plot_heatmap('Boiler(Qth)|flow_rate', indexer={'scenario': 'base', 'period': 2024})
+ >>> results.plot_heatmap('Boiler(Qth)|flow_rate', select={'scenario': 'base'}, animate_by='period')
- Time filtering (summer months only):
+ Time reshape mode - daily patterns:
+
+ >>> results.plot_heatmap('Boiler(Qth)|flow_rate', select={'scenario': 'base'}, reshape_time=('D', 'h'))
+
+ Combined: time reshaping with faceting and animation:
>>> results.plot_heatmap(
- ... 'Boiler(Qth)|flow_rate',
- ... indexer={
- ... 'scenario': 'base',
- ... 'time': results.solution.time[results.solution.time.dt.month.isin([6, 7, 8])],
- ... },
+ ... 'Boiler(Qth)|flow_rate', facet_by='scenario', animate_by='period', reshape_time=('D', 'h')
... )
- Save to specific location:
+ Multi-variable heatmap (variables as one axis):
>>> results.plot_heatmap(
- ... 'Boiler(Qth)|flow_rate', indexer={'scenario': 'base'}, save='path/to/my_heatmap.html'
+ ... ['Boiler(Q_th)|flow_rate', 'CHP(Q_th)|flow_rate', 'HeatStorage|charge_state'],
+ ... select={'scenario': 'base', 'period': 1},
+ ... reshape_time=None,
... )
- """
- dataarray = self.solution[variable_name]
+ Multi-variable with time reshaping:
+
+ >>> results.plot_heatmap(
+ ... ['Boiler(Q_th)|flow_rate', 'CHP(Q_th)|flow_rate'],
+ ... facet_by='scenario',
+ ... animate_by='period',
+ ... reshape_time=('D', 'h'),
+ ... )
+
+ High-resolution export with custom color range:
+
+ >>> results.plot_heatmap('Battery|charge_state', save=True, dpi=600, vmin=0, vmax=100)
+
+ Matplotlib heatmap with custom imshow settings:
+
+ >>> results.plot_heatmap(
+ ... 'Boiler(Q_th)|flow_rate',
+ ... engine='matplotlib',
+ ... imshow_kwargs={'interpolation': 'bilinear', 'aspect': 'auto'},
+ ... )
+ """
+ # Delegate to module-level plot_heatmap function
return plot_heatmap(
- dataarray=dataarray,
- name=variable_name,
+ data=self.solution[variable_name],
+ name=variable_name if isinstance(variable_name, str) else None,
folder=self.folder,
- heatmap_timeframes=heatmap_timeframes,
- heatmap_timesteps_per_frame=heatmap_timesteps_per_frame,
- color_map=color_map,
+ colors=colors,
save=save,
show=show,
engine=engine,
+ select=select,
+ facet_by=facet_by,
+ animate_by=animate_by,
+ facet_cols=facet_cols,
+ reshape_time=reshape_time,
+ fill=fill,
indexer=indexer,
+ heatmap_timeframes=heatmap_timeframes,
+ heatmap_timesteps_per_frame=heatmap_timesteps_per_frame,
+ color_map=color_map,
+ **plot_kwargs,
)
def plot_network(
@@ -805,14 +1074,13 @@ def to_file(
fx_io.save_dataset_to_netcdf(self.solution, paths.solution, compression=compression)
fx_io.save_dataset_to_netcdf(self.flow_system_data, paths.flow_system, compression=compression)
- with open(paths.summary, 'w', encoding='utf-8') as f:
- yaml.dump(self.summary, f, allow_unicode=True, sort_keys=False, indent=4, width=1000)
+ fx_io.save_yaml(self.summary, paths.summary, compact_numeric_lists=True)
if save_linopy_model:
if self.model is None:
logger.critical('No model in the CalculationResults. Saving the model is not possible.')
else:
- self.model.to_netcdf(paths.linopy_model, engine='h5netcdf')
+ self.model.to_netcdf(paths.linopy_model, engine='netcdf4')
if document_model:
if self.model is None:
@@ -856,6 +1124,14 @@ def constraints(self) -> linopy.Constraints:
raise ValueError('The linopy model is not available.')
return self._calculation_results.model.constraints[self._constraint_names]
+ def __repr__(self) -> str:
+ """Return string representation with element info and dataset preview."""
+ class_name = self.__class__.__name__
+ header = f'{class_name}: "{self.label}"'
+ sol = self.solution.copy(deep=False)
+ sol.attrs = {}
+ return f'{header}\n{"-" * len(header)}\n{repr(sol)}'
+
def filter_solution(
self,
variable_dims: Literal['scalar', 'time', 'scenario', 'timeonly', 'scenarioonly'] | None = None,
@@ -917,54 +1193,182 @@ def __init__(
def plot_node_balance(
self,
save: bool | pathlib.Path = False,
- show: bool = True,
- colors: plotting.ColorType = 'viridis',
+ show: bool | None = None,
+ colors: plotting.ColorType | None = None,
engine: plotting.PlottingEngine = 'plotly',
- indexer: dict[FlowSystemDimensions, Any] | None = None,
- mode: Literal['flow_rate', 'flow_hours'] = 'flow_rate',
- style: Literal['area', 'stacked_bar', 'line'] = 'stacked_bar',
+ select: dict[FlowSystemDimensions, Any] | None = None,
+ unit_type: Literal['flow_rate', 'flow_hours'] = 'flow_rate',
+ mode: Literal['area', 'stacked_bar', 'line'] = 'stacked_bar',
drop_suffix: bool = True,
+ facet_by: str | list[str] | None = 'scenario',
+ animate_by: str | None = 'period',
+ facet_cols: int | None = None,
+ # Deprecated parameter (kept for backwards compatibility)
+ indexer: dict[FlowSystemDimensions, Any] | None = None,
+ **plot_kwargs: Any,
) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]:
"""
- Plots the node balance of the Component or Bus.
+ Plots the node balance of the Component or Bus with optional faceting and animation.
+
Args:
save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location.
show: Whether to show the plot or not.
colors: The colors to use for the plot. See `flixopt.plotting.ColorType` for options.
engine: The engine to use for plotting. Can be either 'plotly' or 'matplotlib'.
- indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}.
- If None, uses first value for each dimension (except time).
- If empty dict {}, uses all values.
- style: The style to use for the dataset. Can be 'flow_rate' or 'flow_hours'.
+ select: Optional data selection dict. Supports:
+ - Single values: {'scenario': 'base', 'period': 2024}
+ - Multiple values: {'scenario': ['base', 'high', 'renewable']}
+ - Slices: {'time': slice('2024-01', '2024-06')}
+ - Index arrays: {'time': time_array}
+ Note: Applied BEFORE faceting/animation.
+ unit_type: The unit type to use for the dataset. Can be 'flow_rate' or 'flow_hours'.
- 'flow_rate': Returns the flow_rates of the Node.
- 'flow_hours': Returns the flow_hours of the Node. [flow_hours(t) = flow_rate(t) * dt(t)]. Renames suffixes to |flow_hours.
+ mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for stepped lines, or 'area' for stacked area charts.
drop_suffix: Whether to drop the suffix from the variable names.
+ facet_by: Dimension(s) to create facets (subplots) for. Can be a single dimension name (str)
+ or list of dimensions. Each unique value combination creates a subplot. Ignored if not found.
+ Example: 'scenario' creates one subplot per scenario.
+ Example: ['scenario', 'period'] creates a grid of subplots for each scenario-period combination.
+ animate_by: Dimension to animate over (Plotly only). Creates animation frames that cycle through
+ dimension values. Only one dimension can be animated. Ignored if not found.
+ facet_cols: Number of columns in the facet grid layout (default: 3).
+ **plot_kwargs: Additional plotting customization options passed to underlying plotting functions.
+
+ Common options:
+
+ - **dpi** (int): Export resolution in dots per inch. Default: 300.
+
+ **For Plotly engine** (`engine='plotly'`):
+
+ - Any Plotly Express parameter for px.bar()/px.line()/px.area()
+ Example: `range_y=[0, 100]`, `line_shape='linear'`
+
+ **For Matplotlib engine** (`engine='matplotlib'`):
+
+ - **plot_kwargs** (dict): Customize plot via `ax.bar()` or `ax.step()`.
+ Example: `plot_kwargs={'linewidth': 3, 'alpha': 0.7, 'edgecolor': 'black'}`
+
+ See :func:`flixopt.plotting.with_plotly` and :func:`flixopt.plotting.with_matplotlib`
+ for complete parameter reference.
+
+ Note: For Plotly, you can further customize the returned figure using `fig.update_traces()`
+ and `fig.update_layout()` after calling this method.
+
+ Examples:
+ Basic plot (current behavior):
+
+ >>> results['Boiler'].plot_node_balance()
+
+ Facet by scenario:
+
+ >>> results['Boiler'].plot_node_balance(facet_by='scenario', facet_cols=2)
+
+ Animate by period:
+
+ >>> results['Boiler'].plot_node_balance(animate_by='period')
+
+ Facet by scenario AND animate by period:
+
+ >>> results['Boiler'].plot_node_balance(facet_by='scenario', animate_by='period')
+
+ Select single scenario, then facet by period:
+
+ >>> results['Boiler'].plot_node_balance(select={'scenario': 'base'}, facet_by='period')
+
+ Select multiple scenarios and facet by them:
+
+ >>> results['Boiler'].plot_node_balance(
+ ... select={'scenario': ['base', 'high', 'renewable']}, facet_by='scenario'
+ ... )
+
+ Time range selection (summer months only):
+
+ >>> results['Boiler'].plot_node_balance(select={'time': slice('2024-06', '2024-08')}, facet_by='scenario')
+
+ High-resolution export for publication:
+
+ >>> results['Boiler'].plot_node_balance(engine='matplotlib', save='figure.png', dpi=600)
+
+ Plotly Express customization (e.g., set y-axis range):
+
+ >>> results['Boiler'].plot_node_balance(range_y=[0, 100])
+
+ Custom matplotlib appearance:
+
+ >>> results['Boiler'].plot_node_balance(engine='matplotlib', plot_kwargs={'linewidth': 3, 'alpha': 0.7})
+
+ Further customize Plotly figure after creation:
+
+ >>> fig = results['Boiler'].plot_node_balance(mode='line', show=False)
+ >>> fig.update_traces(line={'width': 5, 'dash': 'dot'})
+ >>> fig.update_layout(template='plotly_dark', width=1200, height=600)
+ >>> fig.show()
"""
- ds = self.node_balance(with_last_timestep=True, mode=mode, drop_suffix=drop_suffix, indexer=indexer)
+ # Handle deprecated indexer parameter
+ if indexer is not None:
+ # Check for conflict with new parameter
+ if select is not None:
+ raise ValueError(
+ "Cannot use both deprecated parameter 'indexer' and new parameter 'select'. Use only 'select'."
+ )
+
+ import warnings
+
+ warnings.warn(
+ "The 'indexer' parameter is deprecated and will be removed in a future version. Use 'select' instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ select = indexer
+
+ if engine not in {'plotly', 'matplotlib'}:
+ raise ValueError(f'Engine "{engine}" not supported. Use one of ["plotly", "matplotlib"]')
- ds, suffix_parts = _apply_indexer_to_data(ds, indexer, drop=True)
+ # Extract dpi for export_figure
+ dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi
+
+ # Don't pass select/indexer to node_balance - we'll apply it afterwards
+ ds = self.node_balance(with_last_timestep=False, unit_type=unit_type, drop_suffix=drop_suffix)
+
+ ds, suffix_parts = _apply_selection_to_data(ds, select=select, drop=True)
+
+ # Matplotlib requires only 'time' dimension; check for extras after selection
+ if engine == 'matplotlib':
+ extra_dims = [d for d in ds.dims if d != 'time']
+ if extra_dims:
+ raise ValueError(
+ f'Matplotlib engine only supports a single time axis, but found extra dimensions: {extra_dims}. '
+ f'Please use select={{...}} to reduce dimensions or switch to engine="plotly" for faceting/animation.'
+ )
suffix = '--' + '-'.join(suffix_parts) if suffix_parts else ''
- title = f'{self.label} (flow rates){suffix}' if mode == 'flow_rate' else f'{self.label} (flow hours){suffix}'
+ title = (
+ f'{self.label} (flow rates){suffix}' if unit_type == 'flow_rate' else f'{self.label} (flow hours){suffix}'
+ )
if engine == 'plotly':
figure_like = plotting.with_plotly(
- ds.to_dataframe(),
- colors=colors,
- style=style,
+ ds,
+ facet_by=facet_by,
+ animate_by=animate_by,
+ colors=colors if colors is not None else self._calculation_results.colors,
+ mode=mode,
title=title,
+ facet_cols=facet_cols,
+ xlabel='Time in h',
+ **plot_kwargs,
)
default_filetype = '.html'
- elif engine == 'matplotlib':
+ else:
figure_like = plotting.with_matplotlib(
- ds.to_dataframe(),
- colors=colors,
- style=style,
+ ds,
+ colors=colors if colors is not None else self._calculation_results.colors,
+ mode=mode,
title=title,
+ **plot_kwargs,
)
default_filetype = '.png'
- else:
- raise ValueError(f'Engine "{engine}" not supported. Use "plotly" or "matplotlib"')
return plotting.export_figure(
figure_like=figure_like,
@@ -973,19 +1377,31 @@ def plot_node_balance(
user_path=None if isinstance(save, bool) else pathlib.Path(save),
show=show,
save=True if save else False,
+ dpi=dpi,
)
def plot_node_balance_pie(
self,
lower_percentage_group: float = 5,
- colors: plotting.ColorType = 'viridis',
+ colors: plotting.ColorType | None = None,
text_info: str = 'percent+label+value',
save: bool | pathlib.Path = False,
- show: bool = True,
+ show: bool | None = None,
engine: plotting.PlottingEngine = 'plotly',
+ select: dict[FlowSystemDimensions, Any] | None = None,
+ # Deprecated parameter (kept for backwards compatibility)
indexer: dict[FlowSystemDimensions, Any] | None = None,
+ **plot_kwargs: Any,
) -> plotly.graph_objs.Figure | tuple[plt.Figure, list[plt.Axes]]:
"""Plot pie chart of flow hours distribution.
+
+ Note:
+ Pie charts require scalar data (no extra dimensions beyond time).
+ If your data has dimensions like 'scenario' or 'period', either:
+
+ - Use `select` to choose specific values: `select={'scenario': 'base', 'period': 2024}`
+ - Let auto-selection choose the first value (a warning will be logged)
+
Args:
lower_percentage_group: Percentage threshold for "Others" grouping.
colors: Color scheme. Also see plotly.
@@ -993,10 +1409,57 @@ def plot_node_balance_pie(
save: Whether to save plot.
show: Whether to display plot.
engine: Plotting engine ('plotly' or 'matplotlib').
- indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}.
- If None, uses first value for each dimension.
- If empty dict {}, uses all values.
+ select: Optional data selection dict. Supports single values, lists, slices, and index arrays.
+ Use this to select specific scenario/period before creating the pie chart.
+ **plot_kwargs: Additional plotting customization options.
+
+ Common options:
+
+ - **dpi** (int): Export resolution in dots per inch. Default: 300.
+ - **hover_template** (str): Hover text template (Plotly only).
+ Example: `hover_template='%{label}: %{value} (%{percent})'`
+ - **text_position** (str): Text position ('inside', 'outside', 'auto').
+ - **hole** (float): Size of donut hole (0.0 to 1.0).
+
+ See :func:`flixopt.plotting.dual_pie_with_plotly` for complete reference.
+
+ Examples:
+ Basic usage (auto-selects first scenario/period if present):
+
+ >>> results['Bus'].plot_node_balance_pie()
+
+ Explicitly select a scenario and period:
+
+ >>> results['Bus'].plot_node_balance_pie(select={'scenario': 'high_demand', 'period': 2030})
+
+ Create a donut chart with custom hover text:
+
+ >>> results['Bus'].plot_node_balance_pie(hole=0.4, hover_template='%{label}: %{value:.2f} (%{percent})')
+
+ High-resolution export:
+
+ >>> results['Bus'].plot_node_balance_pie(save='figure.png', dpi=600)
"""
+ # Handle deprecated indexer parameter
+ if indexer is not None:
+ # Check for conflict with new parameter
+ if select is not None:
+ raise ValueError(
+ "Cannot use both deprecated parameter 'indexer' and new parameter 'select'. Use only 'select'."
+ )
+
+ import warnings
+
+ warnings.warn(
+ "The 'indexer' parameter is deprecated and will be removed in a future version. Use 'select' instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ select = indexer
+
+ # Extract dpi for export_figure
+ dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi
+
inputs = sanitize_dataset(
ds=self.solution[self.inputs] * self._calculation_results.hours_per_timestep,
threshold=1e-5,
@@ -1012,25 +1475,58 @@ def plot_node_balance_pie(
drop_suffix='|',
)
- inputs, suffix_parts = _apply_indexer_to_data(inputs, indexer, drop=True)
- outputs, suffix_parts = _apply_indexer_to_data(outputs, indexer, drop=True)
- suffix = '--' + '-'.join(suffix_parts) if suffix_parts else ''
-
- title = f'{self.label} (total flow hours){suffix}'
+ inputs, suffix_parts_in = _apply_selection_to_data(inputs, select=select, drop=True)
+ outputs, suffix_parts_out = _apply_selection_to_data(outputs, select=select, drop=True)
+ suffix_parts = suffix_parts_in + suffix_parts_out
+ # Sum over time dimension
inputs = inputs.sum('time')
outputs = outputs.sum('time')
+ # Auto-select first value for any remaining dimensions (scenario, period, etc.)
+ # Pie charts need scalar data, so we automatically reduce extra dimensions
+ extra_dims_inputs = [dim for dim in inputs.dims if dim != 'time']
+ extra_dims_outputs = [dim for dim in outputs.dims if dim != 'time']
+ extra_dims = sorted(set(extra_dims_inputs + extra_dims_outputs))
+
+ if extra_dims:
+ auto_select = {}
+ for dim in extra_dims:
+ # Get first value of this dimension
+ if dim in inputs.coords:
+ first_val = inputs.coords[dim].values[0]
+ elif dim in outputs.coords:
+ first_val = outputs.coords[dim].values[0]
+ else:
+ continue
+ auto_select[dim] = first_val
+ logger.info(
+ f'Pie chart auto-selected {dim}={first_val} (first value). '
+ f'Use select={{"{dim}": value}} to choose a different value.'
+ )
+
+ # Apply auto-selection only for coords present in each dataset
+ inputs = inputs.sel({k: v for k, v in auto_select.items() if k in inputs.coords})
+ outputs = outputs.sel({k: v for k, v in auto_select.items() if k in outputs.coords})
+
+ # Update suffix with auto-selected values
+ auto_suffix_parts = [f'{dim}={val}' for dim, val in auto_select.items()]
+ suffix_parts.extend(auto_suffix_parts)
+
+ suffix = '--' + '-'.join(sorted(set(suffix_parts))) if suffix_parts else ''
+ title = f'{self.label} (total flow hours){suffix}'
+
if engine == 'plotly':
figure_like = plotting.dual_pie_with_plotly(
- data_left=inputs.to_pandas(),
- data_right=outputs.to_pandas(),
- colors=colors,
+ data_left=inputs,
+ data_right=outputs,
+ colors=colors if colors is not None else self._calculation_results.colors,
title=title,
text_info=text_info,
subtitles=('Inputs', 'Outputs'),
legend_title='Flows',
lower_percentage_group=lower_percentage_group,
+ **plot_kwargs,
)
default_filetype = '.html'
elif engine == 'matplotlib':
@@ -1038,11 +1534,12 @@ def plot_node_balance_pie(
figure_like = plotting.dual_pie_with_matplotlib(
data_left=inputs.to_pandas(),
data_right=outputs.to_pandas(),
- colors=colors,
+ colors=colors if colors is not None else self._calculation_results.colors,
title=title,
subtitles=('Inputs', 'Outputs'),
legend_title='Flows',
lower_percentage_group=lower_percentage_group,
+ **plot_kwargs,
)
default_filetype = '.png'
else:
@@ -1055,6 +1552,7 @@ def plot_node_balance_pie(
user_path=None if isinstance(save, bool) else pathlib.Path(save),
show=show,
save=True if save else False,
+ dpi=dpi,
)
def node_balance(
@@ -1063,8 +1561,10 @@ def node_balance(
negate_outputs: bool = False,
threshold: float | None = 1e-5,
with_last_timestep: bool = False,
- mode: Literal['flow_rate', 'flow_hours'] = 'flow_rate',
+ unit_type: Literal['flow_rate', 'flow_hours'] = 'flow_rate',
drop_suffix: bool = False,
+ select: dict[FlowSystemDimensions, Any] | None = None,
+ # Deprecated parameter (kept for backwards compatibility)
indexer: dict[FlowSystemDimensions, Any] | None = None,
) -> xr.Dataset:
"""
@@ -1074,14 +1574,29 @@ def node_balance(
negate_outputs: Whether to negate the output flow_rates of the Node.
threshold: The threshold for small values. Variables with all values below the threshold are dropped.
with_last_timestep: Whether to include the last timestep in the dataset.
- mode: The mode to use for the dataset. Can be 'flow_rate' or 'flow_hours'.
+ unit_type: The unit type to use for the dataset. Can be 'flow_rate' or 'flow_hours'.
- 'flow_rate': Returns the flow_rates of the Node.
- 'flow_hours': Returns the flow_hours of the Node. [flow_hours(t) = flow_rate(t) * dt(t)]. Renames suffixes to |flow_hours.
drop_suffix: Whether to drop the suffix from the variable names.
- indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}.
- If None, uses first value for each dimension.
- If empty dict {}, uses all values.
+ select: Optional data selection dict. Supports single values, lists, slices, and index arrays.
"""
+ # Handle deprecated indexer parameter
+ if indexer is not None:
+ # Check for conflict with new parameter
+ if select is not None:
+ raise ValueError(
+ "Cannot use both deprecated parameter 'indexer' and new parameter 'select'. Use only 'select'."
+ )
+
+ import warnings
+
+ warnings.warn(
+ "The 'indexer' parameter is deprecated and will be removed in a future version. Use 'select' instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ select = indexer
+
ds = self.solution[self.inputs + self.outputs]
ds = sanitize_dataset(
@@ -1100,9 +1615,9 @@ def node_balance(
drop_suffix='|' if drop_suffix else None,
)
- ds, _ = _apply_indexer_to_data(ds, indexer, drop=True)
+ ds, _ = _apply_selection_to_data(ds, select=select, drop=True)
- if mode == 'flow_hours':
+ if unit_type == 'flow_hours':
ds = ds * self._calculation_results.hours_per_timestep
ds = ds.rename_vars({var: var.replace('flow_rate', 'flow_hours') for var in ds.data_vars})
@@ -1134,75 +1649,221 @@ def charge_state(self) -> xr.DataArray:
def plot_charge_state(
self,
save: bool | pathlib.Path = False,
- show: bool = True,
- colors: plotting.ColorType = 'viridis',
+ show: bool | None = None,
+ colors: plotting.ColorType | None = None,
engine: plotting.PlottingEngine = 'plotly',
- style: Literal['area', 'stacked_bar', 'line'] = 'stacked_bar',
+ mode: Literal['area', 'stacked_bar', 'line'] = 'area',
+ select: dict[FlowSystemDimensions, Any] | None = None,
+ facet_by: str | list[str] | None = 'scenario',
+ animate_by: str | None = 'period',
+ facet_cols: int | None = None,
+ # Deprecated parameter (kept for backwards compatibility)
indexer: dict[FlowSystemDimensions, Any] | None = None,
+ **plot_kwargs: Any,
) -> plotly.graph_objs.Figure:
- """Plot storage charge state over time, combined with the node balance.
+ """Plot storage charge state over time, combined with the node balance with optional faceting and animation.
Args:
save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location.
show: Whether to show the plot or not.
colors: Color scheme. Also see plotly.
engine: Plotting engine to use. Only 'plotly' is implemented atm.
- style: The colors to use for the plot. See `flixopt.plotting.ColorType` for options.
- indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}.
- If None, uses first value for each dimension.
- If empty dict {}, uses all values.
+ mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for stepped lines, or 'area' for stacked area charts.
+ select: Optional data selection dict. Supports single values, lists, slices, and index arrays.
+ Applied BEFORE faceting/animation.
+ facet_by: Dimension(s) to create facets (subplots) for. Can be a single dimension name (str)
+ or list of dimensions. Each unique value combination creates a subplot. Ignored if not found.
+ animate_by: Dimension to animate over (Plotly only). Creates animation frames that cycle through
+ dimension values. Only one dimension can be animated. Ignored if not found.
+ facet_cols: Number of columns in the facet grid layout (default: 3).
+ **plot_kwargs: Additional plotting customization options passed to underlying plotting functions.
+
+ Common options:
+
+ - **dpi** (int): Export resolution in dots per inch. Default: 300.
+
+ **For Plotly engine:**
+
+ - Any Plotly Express parameter for px.bar()/px.line()/px.area()
+ Example: `range_y=[0, 100]`, `line_shape='linear'`
+
+ **For Matplotlib engine:**
+
+ - **plot_kwargs** (dict): Customize plot via `ax.bar()` or `ax.step()`.
+
+ See :func:`flixopt.plotting.with_plotly` and :func:`flixopt.plotting.with_matplotlib`
+ for complete parameter reference.
+
+ Note: For Plotly, you can further customize the returned figure using `fig.update_traces()`
+ and `fig.update_layout()` after calling this method.
Raises:
ValueError: If component is not a storage.
+
+ Examples:
+ Basic plot:
+
+ >>> results['Storage'].plot_charge_state()
+
+ Facet by scenario:
+
+ >>> results['Storage'].plot_charge_state(facet_by='scenario', facet_cols=2)
+
+ Animate by period:
+
+ >>> results['Storage'].plot_charge_state(animate_by='period')
+
+ Facet by scenario AND animate by period:
+
+ >>> results['Storage'].plot_charge_state(facet_by='scenario', animate_by='period')
+
+ Custom layout after creation:
+
+ >>> fig = results['Storage'].plot_charge_state(show=False)
+ >>> fig.update_layout(template='plotly_dark', height=800)
+ >>> fig.show()
+
+ High-resolution export:
+
+ >>> results['Storage'].plot_charge_state(save='storage.png', dpi=600)
"""
+ # Handle deprecated indexer parameter
+ if indexer is not None:
+ # Check for conflict with new parameter
+ if select is not None:
+ raise ValueError(
+ "Cannot use both deprecated parameter 'indexer' and new parameter 'select'. Use only 'select'."
+ )
+
+ import warnings
+
+ warnings.warn(
+ "The 'indexer' parameter is deprecated and will be removed in a future version. Use 'select' instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ select = indexer
+
+ # Extract dpi for export_figure
+ dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi
+
+ # Extract charge state line color (for overlay customization)
+ overlay_color = plot_kwargs.pop('charge_state_line_color', 'black')
+
if not self.is_storage:
raise ValueError(f'Cant plot charge_state. "{self.label}" is not a storage')
- ds = self.node_balance(with_last_timestep=True, indexer=indexer)
- charge_state = self.charge_state
+ # Get node balance and charge state
+ ds = self.node_balance(with_last_timestep=True).fillna(0)
+ charge_state_da = self.charge_state
- ds, suffix_parts = _apply_indexer_to_data(ds, indexer, drop=True)
- charge_state, suffix_parts = _apply_indexer_to_data(charge_state, indexer, drop=True)
+ # Apply select filtering
+ ds, suffix_parts = _apply_selection_to_data(ds, select=select, drop=True)
+ charge_state_da, _ = _apply_selection_to_data(charge_state_da, select=select, drop=True)
suffix = '--' + '-'.join(suffix_parts) if suffix_parts else ''
title = f'Operation Balance of {self.label}{suffix}'
if engine == 'plotly':
- fig = plotting.with_plotly(
- ds.to_dataframe(),
- colors=colors,
- style=style,
+ # Plot flows (node balance) with the specified mode
+ figure_like = plotting.with_plotly(
+ ds,
+ facet_by=facet_by,
+ animate_by=animate_by,
+ colors=colors if colors is not None else self._calculation_results.colors,
+ mode=mode,
title=title,
+ facet_cols=facet_cols,
+ xlabel='Time in h',
+ **plot_kwargs,
)
- # TODO: Use colors for charge state?
-
- charge_state = charge_state.to_dataframe()
- fig.add_trace(
- plotly.graph_objs.Scatter(
- x=charge_state.index, y=charge_state.values.flatten(), mode='lines', name=self._charge_state
- )
+ # Prepare charge_state as Dataset for plotting
+ charge_state_ds = xr.Dataset({self._charge_state: charge_state_da})
+
+ # Plot charge_state with mode='line' to get Scatter traces
+ charge_state_fig = plotting.with_plotly(
+ charge_state_ds,
+ facet_by=facet_by,
+ animate_by=animate_by,
+ colors=colors if colors is not None else self._calculation_results.colors,
+ mode='line', # Always line for charge_state
+ title='', # No title needed for this temp figure
+ facet_cols=facet_cols,
+ xlabel='Time in h',
+ **plot_kwargs,
)
+
+ # Add charge_state traces to the main figure
+ # This preserves subplot assignments and animation frames
+ for trace in charge_state_fig.data:
+ trace.line.width = 2 # Make charge_state line more prominent
+ trace.line.shape = 'linear' # Smooth line for charge state (not stepped like flows)
+ trace.line.color = overlay_color
+ figure_like.add_trace(trace)
+
+ # Also add traces from animation frames if they exist
+ # Both figures use the same animate_by parameter, so they should have matching frames
+ if hasattr(charge_state_fig, 'frames') and charge_state_fig.frames:
+ # Add charge_state traces to each frame
+ for i, frame in enumerate(charge_state_fig.frames):
+ if i < len(figure_like.frames):
+ for trace in frame.data:
+ trace.line.width = 2
+ trace.line.shape = 'linear' # Smooth line for charge state
+ trace.line.color = overlay_color
+ figure_like.frames[i].data = figure_like.frames[i].data + (trace,)
+
+ default_filetype = '.html'
elif engine == 'matplotlib':
+ # Matplotlib requires only 'time' dimension; check for extras after selection
+ extra_dims = [d for d in ds.dims if d != 'time']
+ if extra_dims:
+ raise ValueError(
+ f'Matplotlib engine only supports a single time axis, but found extra dimensions: {extra_dims}. '
+ f'Please use select={{...}} to reduce dimensions or switch to engine="plotly" for faceting/animation.'
+ )
+ # For matplotlib, plot flows (node balance), then add charge_state as line
fig, ax = plotting.with_matplotlib(
- ds.to_dataframe(),
- colors=colors,
- style=style,
+ ds,
+ colors=colors if colors is not None else self._calculation_results.colors,
+ mode=mode,
title=title,
+ **plot_kwargs,
)
- charge_state = charge_state.to_dataframe()
- ax.plot(charge_state.index, charge_state.values.flatten(), label=self._charge_state)
+ # Add charge_state as a line overlay
+ charge_state_df = charge_state_da.to_dataframe()
+ ax.plot(
+ charge_state_df.index,
+ charge_state_df.values.flatten(),
+ label=self._charge_state,
+ linewidth=2,
+ color=overlay_color,
+ )
+ # Recreate legend with the same styling as with_matplotlib
+ handles, labels = ax.get_legend_handles_labels()
+ ax.legend(
+ handles,
+ labels,
+ loc='upper center',
+ bbox_to_anchor=(0.5, -0.15),
+ ncol=5,
+ frameon=False,
+ )
fig.tight_layout()
- fig = fig, ax
+
+ figure_like = fig, ax
+ default_filetype = '.png'
return plotting.export_figure(
- fig,
+ figure_like=figure_like,
default_path=self._calculation_results.folder / title,
- default_filetype='.html',
+ default_filetype=default_filetype,
user_path=None if isinstance(save, bool) else pathlib.Path(save),
show=show,
save=True if save else False,
+ dpi=dpi,
)
def node_balance_with_charge_state(
@@ -1412,8 +2073,7 @@ def from_file(cls, folder: str | pathlib.Path, name: str) -> SegmentedCalculatio
folder = pathlib.Path(folder)
path = folder / name
logger.info(f'loading calculation "{name}" from file ("{path.with_suffix(".nc4")}")')
- with open(path.with_suffix('.json'), encoding='utf-8') as f:
- meta_data = json.load(f)
+ meta_data = fx_io.load_json(path.with_suffix('.json'))
return cls(
[CalculationResults.from_file(folder, sub_name) for sub_name in meta_data['sub_calculations']],
all_timesteps=pd.DatetimeIndex(
@@ -1441,6 +2101,7 @@ def __init__(
self.name = name
self.folder = pathlib.Path(folder) if folder is not None else pathlib.Path.cwd() / 'results'
self.hours_per_timestep = FlowSystem.calculate_hours_per_timestep(self.all_timesteps)
+ self._colors = {}
@property
def meta_data(self) -> dict[str, int | list[str]]:
@@ -1455,6 +2116,64 @@ def meta_data(self) -> dict[str, int | list[str]]:
def segment_names(self) -> list[str]:
return [segment.name for segment in self.segment_results]
+ @property
+ def colors(self) -> dict[str, str]:
+ return self._colors
+
+ @colors.setter
+ def colors(self, colors: dict[str, str]):
+ """Applies colors to all segments"""
+ self._colors = colors
+ for segment in self.segment_results:
+ segment.colors = copy.deepcopy(colors)
+
+ def setup_colors(
+ self,
+ config: dict[str, str | list[str]] | str | pathlib.Path | None = None,
+ default_colorscale: str | None = None,
+ ) -> dict[str, str]:
+ """
+ Setup colors for all variables across all segment results.
+
+ This method applies the same color configuration to all segments, ensuring
+ consistent visualization across the entire segmented calculation. The color
+ mapping is propagated to each segment's CalculationResults instance.
+
+ Args:
+ config: Configuration for color assignment. Can be:
+ - dict: Maps components to colors/colorscales:
+ * 'component1': 'red' # Single component to single color
+ * 'component1': '#FF0000' # Single component to hex color
+ - OR maps colorscales to multiple components:
+ * 'colorscale_name': ['component1', 'component2'] # Colorscale across components
+ - str: Path to a JSON/YAML config file or a colorscale name to apply to all
+ - Path: Path to a JSON/YAML config file
+ - None: Use default_colorscale for all components
+ default_colorscale: Default colorscale for unconfigured components (default: 'turbo')
+
+ Examples:
+ ```python
+ # Apply colors to all segments
+ segmented_results.setup_colors(
+ {
+ 'CHP': 'red',
+ 'Blues': ['Storage1', 'Storage2'],
+ 'Oranges': ['Solar1', 'Solar2'],
+ }
+ )
+
+ # Use a single colorscale for all components in all segments
+ segmented_results.setup_colors('portland')
+ ```
+
+ Returns:
+ Complete variable-to-color mapping dictionary from the first segment
+ (all segments will have the same mapping)
+ """
+ self.colors = self.segment_results[0].setup_colors(config=config, default_colorscale=default_colorscale)
+
+ return self.colors
+
def solution_without_overlap(self, variable_name: str) -> xr.DataArray:
"""Get variable solution removing segment overlaps.
@@ -1473,37 +2192,108 @@ def solution_without_overlap(self, variable_name: str) -> xr.DataArray:
def plot_heatmap(
self,
variable_name: str,
- heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] = 'D',
- heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] = 'h',
- color_map: str = 'portland',
+ reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
+ | Literal['auto']
+ | None = 'auto',
+ colors: plotting.ColorType | None = None,
save: bool | pathlib.Path = False,
- show: bool = True,
+ show: bool | None = None,
engine: plotting.PlottingEngine = 'plotly',
+ facet_by: str | list[str] | None = None,
+ animate_by: str | None = None,
+ facet_cols: int | None = None,
+ fill: Literal['ffill', 'bfill'] | None = 'ffill',
+ # Deprecated parameters (kept for backwards compatibility)
+ heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None,
+ heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None,
+ color_map: str | None = None,
+ **plot_kwargs: Any,
) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]:
"""Plot heatmap of variable solution across segments.
Args:
variable_name: Variable to plot.
- heatmap_timeframes: Time aggregation level.
- heatmap_timesteps_per_frame: Timesteps per frame.
- color_map: Color scheme. Also see plotly.
+ reshape_time: Time reshaping configuration (default: 'auto'):
+ - 'auto': Automatically applies ('D', 'h') when only 'time' dimension remains
+ - Tuple like ('D', 'h'): Explicit reshaping (days vs hours)
+ - None: Disable time reshaping
+ colors: Color scheme. See plotting.ColorType for options.
save: Whether to save plot.
show: Whether to display plot.
engine: Plotting engine.
+ facet_by: Dimension(s) to create facets (subplots) for.
+ animate_by: Dimension to animate over (Plotly only).
+ facet_cols: Number of columns in the facet grid layout.
+ fill: Method to fill missing values: 'ffill' or 'bfill'.
+ heatmap_timeframes: (Deprecated) Use reshape_time instead.
+ heatmap_timesteps_per_frame: (Deprecated) Use reshape_time instead.
+ color_map: (Deprecated) Use colors instead.
+ **plot_kwargs: Additional plotting customization options.
+ Common options:
+
+ - **dpi** (int): Export resolution for saved plots. Default: 300.
+ - **vmin** (float): Minimum value for color scale.
+ - **vmax** (float): Maximum value for color scale.
+
+ For Matplotlib heatmaps:
+
+ - **imshow_kwargs** (dict): Additional kwargs for matplotlib's imshow.
+ - **cbar_kwargs** (dict): Additional kwargs for colorbar customization.
Returns:
Figure object.
"""
+ # Handle deprecated parameters
+ if heatmap_timeframes is not None or heatmap_timesteps_per_frame is not None:
+ # Check for conflict with new parameter
+ if reshape_time != 'auto': # Check if user explicitly set reshape_time
+ raise ValueError(
+ "Cannot use both deprecated parameters 'heatmap_timeframes'/'heatmap_timesteps_per_frame' "
+ "and new parameter 'reshape_time'. Use only 'reshape_time'."
+ )
+
+ import warnings
+
+ warnings.warn(
+ "The 'heatmap_timeframes' and 'heatmap_timesteps_per_frame' parameters are deprecated. "
+ "Use 'reshape_time=(timeframes, timesteps_per_frame)' instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ # Override reshape_time if old parameters provided
+ if heatmap_timeframes is not None and heatmap_timesteps_per_frame is not None:
+ reshape_time = (heatmap_timeframes, heatmap_timesteps_per_frame)
+
+ if color_map is not None:
+ # Check for conflict with new parameter
+ if colors is not None: # Check if user explicitly set colors
+ raise ValueError(
+ "Cannot use both deprecated parameter 'color_map' and new parameter 'colors'. Use only 'colors'."
+ )
+
+ import warnings
+
+ warnings.warn(
+ "The 'color_map' parameter is deprecated. Use 'colors' instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ colors = color_map
+
return plot_heatmap(
- dataarray=self.solution_without_overlap(variable_name),
+ data=self.solution_without_overlap(variable_name),
name=variable_name,
folder=self.folder,
- heatmap_timeframes=heatmap_timeframes,
- heatmap_timesteps_per_frame=heatmap_timesteps_per_frame,
- color_map=color_map,
+ reshape_time=reshape_time,
+ colors=colors,
save=save,
show=show,
engine=engine,
+ facet_by=facet_by,
+ animate_by=animate_by,
+ facet_cols=facet_cols,
+ fill=fill,
+ **plot_kwargs,
)
def to_file(self, folder: str | pathlib.Path | None = None, name: str | None = None, compression: int = 5):
@@ -1527,69 +2317,227 @@ def to_file(self, folder: str | pathlib.Path | None = None, name: str | None = N
for segment in self.segment_results:
segment.to_file(folder=folder, name=segment.name, compression=compression)
- with open(path.with_suffix('.json'), 'w', encoding='utf-8') as f:
- json.dump(self.meta_data, f, indent=4, ensure_ascii=False)
+ fx_io.save_json(self.meta_data, path.with_suffix('.json'))
logger.info(f'Saved calculation "{name}" to {path}')
def plot_heatmap(
- dataarray: xr.DataArray,
- name: str,
- folder: pathlib.Path,
- heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] = 'D',
- heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] = 'h',
- color_map: str = 'portland',
+ data: xr.DataArray | xr.Dataset,
+ name: str | None = None,
+ folder: pathlib.Path | None = None,
+ colors: plotting.ColorType | None = None,
save: bool | pathlib.Path = False,
- show: bool = True,
+ show: bool | None = None,
engine: plotting.PlottingEngine = 'plotly',
+ select: dict[str, Any] | None = None,
+ facet_by: str | list[str] | None = None,
+ animate_by: str | None = None,
+ facet_cols: int | None = None,
+ reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']]
+ | Literal['auto']
+ | None = 'auto',
+ fill: Literal['ffill', 'bfill'] | None = 'ffill',
+ # Deprecated parameters (kept for backwards compatibility)
indexer: dict[str, Any] | None = None,
+ heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None,
+ heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None,
+ color_map: str | None = None,
+ **plot_kwargs: Any,
):
- """Plot heatmap of time series data.
+ """Plot heatmap visualization with support for multi-variable, faceting, and animation.
+
+ This function provides a standalone interface to the heatmap plotting capabilities,
+ supporting the same modern features as CalculationResults.plot_heatmap().
Args:
- dataarray: Data to plot.
- name: Variable name for title.
- folder: Save folder.
- heatmap_timeframes: Time aggregation level.
- heatmap_timesteps_per_frame: Timesteps per frame.
- color_map: Color scheme. Also see plotly.
- save: Whether to save plot.
- show: Whether to display plot.
- engine: Plotting engine.
- indexer: Optional selection dict, e.g., {'scenario': 'base', 'period': 2024}.
- If None, uses first value for each dimension.
- If empty dict {}, uses all values.
+ data: Data to plot. Can be a single DataArray or an xarray Dataset.
+ When a Dataset is provided, all data variables are combined along a new 'variable' dimension.
+ name: Optional name for the title. If not provided, uses the DataArray name or
+ generates a default title for Datasets.
+ folder: Save folder for the plot. Defaults to current directory if not provided.
+ colors: Color scheme for the heatmap. See `flixopt.plotting.ColorType` for options.
+ save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location.
+ show: Whether to show the plot or not.
+ engine: The engine to use for plotting. Can be either 'plotly' or 'matplotlib'.
+ select: Optional data selection dict. Supports single values, lists, slices, and index arrays.
+ facet_by: Dimension(s) to create facets (subplots) for. Can be a single dimension name (str)
+ or list of dimensions. Each unique value combination creates a subplot.
+ animate_by: Dimension to animate over (Plotly only). Creates animation frames.
+ facet_cols: Number of columns in the facet grid layout (default: 3).
+ reshape_time: Time reshaping configuration (default: 'auto'):
+ - 'auto': Automatically applies ('D', 'h') when only 'time' dimension remains
+ - Tuple: Explicit reshaping, e.g. ('D', 'h') for days vs hours
+ - None: Disable auto-reshaping
+ fill: Method to fill missing values after reshape: 'ffill' (forward fill) or 'bfill' (backward fill).
+ Default is 'ffill'.
+
+ Examples:
+ Single DataArray with time reshaping:
+
+ >>> plot_heatmap(data, name='Temperature', folder=Path('.'), reshape_time=('D', 'h'))
+
+ Dataset with multiple variables (facet by variable):
+
+ >>> dataset = xr.Dataset({'Boiler': data1, 'CHP': data2, 'Storage': data3})
+ >>> plot_heatmap(
+ ... dataset,
+ ... folder=Path('.'),
+ ... facet_by='variable',
+ ... reshape_time=('D', 'h'),
+ ... )
+
+ Dataset with animation by variable:
+
+ >>> plot_heatmap(dataset, animate_by='variable', reshape_time=('D', 'h'))
"""
- dataarray, suffix_parts = _apply_indexer_to_data(dataarray, indexer, drop=True)
+ # Handle deprecated heatmap time parameters
+ if heatmap_timeframes is not None or heatmap_timesteps_per_frame is not None:
+ # Check for conflict with new parameter
+ if reshape_time != 'auto': # User explicitly set reshape_time
+ raise ValueError(
+ "Cannot use both deprecated parameters 'heatmap_timeframes'/'heatmap_timesteps_per_frame' "
+ "and new parameter 'reshape_time'. Use only 'reshape_time'."
+ )
+
+ import warnings
+
+ warnings.warn(
+ "The 'heatmap_timeframes' and 'heatmap_timesteps_per_frame' parameters are deprecated. "
+ "Use 'reshape_time=(timeframes, timesteps_per_frame)' instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ # Override reshape_time if both old parameters provided
+ if heatmap_timeframes is not None and heatmap_timesteps_per_frame is not None:
+ reshape_time = (heatmap_timeframes, heatmap_timesteps_per_frame)
+
+ # Handle deprecated color_map parameter
+ if color_map is not None:
+ if colors is not None: # User explicitly set colors
+ raise ValueError(
+ "Cannot use both deprecated parameter 'color_map' and new parameter 'colors'. Use only 'colors'."
+ )
+
+ import warnings
+
+ warnings.warn(
+ "The 'color_map' parameter is deprecated. Use 'colors' instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ colors = color_map
+
+ # Handle deprecated indexer parameter
+ if indexer is not None:
+ # Check for conflict with new parameter
+ if select is not None: # User explicitly set select
+ raise ValueError(
+ "Cannot use both deprecated parameter 'indexer' and new parameter 'select'. Use only 'select'."
+ )
+
+ import warnings
+
+ warnings.warn(
+ "The 'indexer' parameter is deprecated. Use 'select' instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ select = indexer
+
+ # Convert Dataset to DataArray with 'variable' dimension
+ if isinstance(data, xr.Dataset):
+ # Extract all data variables from the Dataset
+ variable_names = list(data.data_vars)
+ dataarrays = [data[var] for var in variable_names]
+
+ # Combine into single DataArray with 'variable' dimension
+ data = xr.concat(dataarrays, dim='variable')
+ data = data.assign_coords(variable=variable_names)
+
+ # Use Dataset variable names for title if name not provided
+ if name is None:
+ title_name = f'Heatmap of {len(variable_names)} variables'
+ else:
+ title_name = name
+ else:
+ # Single DataArray
+ if name is None:
+ title_name = data.name if data.name else 'Heatmap'
+ else:
+ title_name = name
+
+ # Apply select filtering
+ data, suffix_parts = _apply_selection_to_data(data, select=select, drop=True)
suffix = '--' + '-'.join(suffix_parts) if suffix_parts else ''
- name = name if not suffix_parts else name + suffix
- heatmap_data = plotting.heat_map_data_from_df(
- dataarray.to_dataframe(name), heatmap_timeframes, heatmap_timesteps_per_frame, 'ffill'
- )
+ # Matplotlib heatmaps require at most 2D data
+ # Time dimension will be reshaped to 2D (timeframe x timestep), so can't have other dims alongside it
+ if engine == 'matplotlib':
+ dims = list(data.dims)
+
+ # If 'time' dimension exists and will be reshaped, we can't have any other dimensions
+ if 'time' in dims and len(dims) > 1 and reshape_time is not None:
+ extra_dims = [d for d in dims if d != 'time']
+ raise ValueError(
+ f'Matplotlib heatmaps with time reshaping cannot have additional dimensions. '
+ f'Found extra dimensions: {extra_dims}. '
+ f'Use select={{...}} to reduce to time only, use reshape_time=None, or switch to engine="plotly" for multi-dimensional support.'
+ )
+ # If no 'time' dimension (already reshaped or different data), allow at most 2 dimensions
+ elif 'time' not in dims and len(dims) > 2:
+ raise ValueError(
+ f'Matplotlib heatmaps support at most 2 dimensions, but data has {len(dims)}: {dims}. '
+ f'Use select={{...}} to reduce dimensions or switch to engine="plotly".'
+ )
+
+ # Build title
+ title = f'{title_name}{suffix}'
+ if isinstance(reshape_time, tuple):
+ timeframes, timesteps_per_frame = reshape_time
+ title += f' ({timeframes} vs {timesteps_per_frame})'
- xlabel, ylabel = f'timeframe [{heatmap_timeframes}]', f'timesteps [{heatmap_timesteps_per_frame}]'
+ # Extract dpi before passing to plotting functions
+ dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi
+ # Plot with appropriate engine
if engine == 'plotly':
- figure_like = plotting.heat_map_plotly(
- heatmap_data, title=name, color_map=color_map, xlabel=xlabel, ylabel=ylabel
+ figure_like = plotting.heatmap_with_plotly(
+ data=data,
+ facet_by=facet_by,
+ animate_by=animate_by,
+ colors=colors,
+ title=title,
+ facet_cols=facet_cols,
+ reshape_time=reshape_time,
+ fill=fill,
+ **plot_kwargs,
)
default_filetype = '.html'
elif engine == 'matplotlib':
- figure_like = plotting.heat_map_matplotlib(
- heatmap_data, title=name, color_map=color_map, xlabel=xlabel, ylabel=ylabel
+ figure_like = plotting.heatmap_with_matplotlib(
+ data=data,
+ colors=colors,
+ title=title,
+ reshape_time=reshape_time,
+ fill=fill,
+ **plot_kwargs,
)
default_filetype = '.png'
else:
raise ValueError(f'Engine "{engine}" not supported. Use "plotly" or "matplotlib"')
+ # Set default folder if not provided
+ if folder is None:
+ folder = pathlib.Path('.')
+
return plotting.export_figure(
figure_like=figure_like,
- default_path=folder / f'{name} ({heatmap_timeframes}-{heatmap_timesteps_per_frame})',
+ default_path=folder / title,
default_filetype=default_filetype,
user_path=None if isinstance(save, bool) else pathlib.Path(save),
show=show,
save=True if save else False,
+ dpi=dpi,
)
@@ -1787,8 +2735,13 @@ def apply_filter(array, coord_name: str, coord_values: Any | list[Any]):
if coord_name not in array.coords:
raise AttributeError(f"Missing required coordinate '{coord_name}'")
- # Convert single value to list
- val_list = [coord_values] if isinstance(coord_values, str) else coord_values
+ # Normalize to list for sequence-like inputs (excluding strings)
+ if isinstance(coord_values, str):
+ val_list = [coord_values]
+ elif isinstance(coord_values, (list, tuple, np.ndarray, pd.Index)):
+ val_list = list(coord_values)
+ else:
+ val_list = [coord_values]
# Verify coord_values exist
available = set(array[coord_name].values)
@@ -1798,7 +2751,7 @@ def apply_filter(array, coord_name: str, coord_values: Any | list[Any]):
# Apply filter
return array.where(
- array[coord_name].isin(val_list) if isinstance(coord_values, list) else array[coord_name] == coord_values,
+ array[coord_name].isin(val_list) if len(val_list) > 1 else array[coord_name] == val_list[0],
drop=True,
)
@@ -1817,36 +2770,26 @@ def apply_filter(array, coord_name: str, coord_values: Any | list[Any]):
return da
-def _apply_indexer_to_data(
- data: xr.DataArray | xr.Dataset, indexer: dict[str, Any] | None = None, drop=False
+def _apply_selection_to_data(
+ data: xr.DataArray | xr.Dataset,
+ select: dict[str, Any] | None = None,
+ drop=False,
) -> tuple[xr.DataArray | xr.Dataset, list[str]]:
"""
- Apply indexer selection or auto-select first values for non-time dimensions.
+ Apply selection to data.
Args:
data: xarray Dataset or DataArray
- indexer: Optional selection dict
- If None, uses first value for each dimension (except time).
- If empty dict {}, uses all values.
+ select: Optional selection dict
+ drop: Whether to drop dimensions after selection
Returns:
Tuple of (selected_data, selection_string)
"""
selection_string = []
- if indexer is not None:
- # User provided indexer
- data = data.sel(indexer, drop=drop)
- selection_string.extend(f'{v}[{k}]' for k, v in indexer.items())
- else:
- # Auto-select first value for each dimension except 'time'
- selection = {}
- for dim in data.dims:
- if dim != 'time' and dim in data.coords:
- first_value = data.coords[dim].values[0]
- selection[dim] = first_value
- selection_string.append(f'{first_value}[{dim}]')
- if selection:
- data = data.sel(selection, drop=drop)
+ if select:
+ data = data.sel(select, drop=drop)
+ selection_string.extend(f'{dim}={val}' for dim, val in select.items())
return data, selection_string
From fc42bc2fd39230b09ebbdbf43ae94eefb15ded52 Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Thu, 30 Oct 2025 14:03:05 +0100
Subject: [PATCH 05/27] Add extra log_to_console option to solvers.py
---
flixopt/solvers.py | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/flixopt/solvers.py b/flixopt/solvers.py
index 410d69434..36f993f95 100644
--- a/flixopt/solvers.py
+++ b/flixopt/solvers.py
@@ -19,12 +19,14 @@ class _Solver:
Args:
mip_gap: Acceptable relative optimality gap in [0.0, 1.0].
time_limit_seconds: Time limit in seconds.
+ log_to_console: If False, no output to console.
extra_options: Additional solver options merged into `options`.
"""
name: ClassVar[str]
mip_gap: float
time_limit_seconds: int
+ log_to_console: bool = True
extra_options: dict[str, Any] = field(default_factory=dict)
@property
@@ -45,6 +47,7 @@ class GurobiSolver(_Solver):
Args:
mip_gap: Acceptable relative optimality gap in [0.0, 1.0]; mapped to Gurobi `MIPGap`.
time_limit_seconds: Time limit in seconds; mapped to Gurobi `TimeLimit`.
+ log_to_console: If False, no output to console.
extra_options: Additional solver options merged into `options`.
"""
@@ -55,6 +58,7 @@ def _options(self) -> dict[str, Any]:
return {
'MIPGap': self.mip_gap,
'TimeLimit': self.time_limit_seconds,
+ 'LogToConsole': 1 if self.log_to_console else 0,
}
@@ -65,6 +69,7 @@ class HighsSolver(_Solver):
Attributes:
mip_gap: Acceptable relative optimality gap in [0.0, 1.0]; mapped to HiGHS `mip_rel_gap`.
time_limit_seconds: Time limit in seconds; mapped to HiGHS `time_limit`.
+ log_to_console: If False, no output to console.
extra_options: Additional solver options merged into `options`.
threads (int | None): Number of threads to use. If None, HiGHS chooses.
"""
@@ -78,4 +83,5 @@ def _options(self) -> dict[str, Any]:
'mip_rel_gap': self.mip_gap,
'time_limit': self.time_limit_seconds,
'threads': self.threads,
+ 'log_to_console': self.log_to_console,
}
From d3bcdc271acc7ec1d177fbafce9ab2ec842040d0 Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Thu, 30 Oct 2025 14:06:51 +0100
Subject: [PATCH 06/27] Use solver log_to_console instead of suppress_output in calculation.py
---
flixopt/calculation.py | 16 +++++++++-------
flixopt/io.py | 26 --------------------------
2 files changed, 9 insertions(+), 33 deletions(-)
diff --git a/flixopt/calculation.py b/flixopt/calculation.py
index 1dab78e57..d07a23793 100644
--- a/flixopt/calculation.py
+++ b/flixopt/calculation.py
@@ -10,6 +10,7 @@
from __future__ import annotations
+import copy
import logging
import math
import pathlib
@@ -612,13 +613,14 @@ def do_modeling_and_solve(
f'Following InvestmentModels were found: {invest_elements}'
)
- # Redirect solver stdout to null to avoid cluttering the output
- with fx_io.suppress_output():
- calculation.solve(
- solver,
- log_file=pathlib.Path(log_file) if log_file is not None else self.folder / f'{self.name}.log',
- log_main_results=log_main_results,
- )
+ solver_silent = copy.copy(solver)
+ solver_silent.log_to_console = False
+
+ calculation.solve(
+ solver_silent,
+ log_file=pathlib.Path(log_file) if log_file is not None else self.folder / f'{self.name}.log',
+ log_main_results=log_main_results,
+ )
progress_bar.close()
diff --git a/flixopt/io.py b/flixopt/io.py
index fa4ef4ebf..7f832ed0e 100644
--- a/flixopt/io.py
+++ b/flixopt/io.py
@@ -3,11 +3,8 @@
import inspect
import json
import logging
-import os
import pathlib
import re
-import sys
-from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
@@ -934,26 +931,3 @@ def build_metadata_info(parts: list[str], prefix: str = ' | ') -> str:
return ''
info = ' | '.join(parts)
return prefix + info if prefix else info
-
-
-@contextmanager
-def suppress_output():
- """Redirect both Python and C-level stdout/stderr to os.devnull."""
- with open(os.devnull, 'w') as devnull:
- # Save original file descriptors
- old_stdout_fd = os.dup(1)
- old_stderr_fd = os.dup(2)
- try:
- # Flush any pending text
- sys.stdout.flush()
- sys.stderr.flush()
- # Redirect low-level fds to devnull
- os.dup2(devnull.fileno(), 1)
- os.dup2(devnull.fileno(), 2)
- yield
- finally:
- # Restore fds
- os.dup2(old_stdout_fd, 1)
- os.dup2(old_stderr_fd, 2)
- os.close(old_stdout_fd)
- os.close(old_stderr_fd)
From 7931580360a20e0686f4b711ded8e165bff1d031 Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Thu, 30 Oct 2025 14:11:10 +0100
Subject: [PATCH 07/27] Add extra log_to_console option to config.py
---
flixopt/config.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/flixopt/config.py b/flixopt/config.py
index 670f86da2..83bdbe66f 100644
--- a/flixopt/config.py
+++ b/flixopt/config.py
@@ -28,6 +28,7 @@
'file': None,
'rich': False,
'console': False,
+ 'solver_to_console': True,
'max_file_size': 10_485_760, # 10MB
'backup_count': 5,
'date_format': '%Y-%m-%d %H:%M:%S',
@@ -104,6 +105,7 @@ class Logging:
file: Log file path for file logging.
console: Enable console output.
rich: Use Rich library for enhanced output.
+ solver_to_console: Enable solver output to console.
max_file_size: Max file size before rotation.
backup_count: Number of backup files to keep.
date_format: Date/time format string.
@@ -135,6 +137,7 @@ class Logging:
file: str | None = _DEFAULTS['logging']['file']
rich: bool = _DEFAULTS['logging']['rich']
console: bool | Literal['stdout', 'stderr'] = _DEFAULTS['logging']['console']
+ solver_to_console: bool = _DEFAULTS['logging']['solver_to_console']
max_file_size: int = _DEFAULTS['logging']['max_file_size']
backup_count: int = _DEFAULTS['logging']['backup_count']
date_format: str = _DEFAULTS['logging']['date_format']
@@ -346,6 +349,7 @@ def to_dict(cls) -> dict:
'file': cls.Logging.file,
'rich': cls.Logging.rich,
'console': cls.Logging.console,
+ 'solver_to_console': cls.Logging.solver_to_console,
'max_file_size': cls.Logging.max_file_size,
'backup_count': cls.Logging.backup_count,
'date_format': cls.Logging.date_format,
From 168ec617c6860d07f80a5a0db85afc276cd52155 Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Thu, 30 Oct 2025 14:13:27 +0100
Subject: [PATCH 08/27] Add solver_to_console assertions to config tests
---
tests/test_config.py | 7 +++++++
1 file changed, 7 insertions(+)
diff --git a/tests/test_config.py b/tests/test_config.py
index 60ed80555..ae3304188 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -28,6 +28,7 @@ def test_config_defaults(self):
assert CONFIG.Logging.file is None
assert CONFIG.Logging.rich is False
assert CONFIG.Logging.console is False
+ assert CONFIG.Logging.solver_to_console is True
assert CONFIG.Modeling.big == 10_000_000
assert CONFIG.Modeling.epsilon == 1e-5
assert CONFIG.Modeling.big_binary_bound == 100_000
@@ -104,6 +105,7 @@ def test_config_to_dict(self):
assert config_dict['logging']['console'] is True
assert config_dict['logging']['file'] is None
assert config_dict['logging']['rich'] is False
+ assert config_dict['logging']['solver_to_console'] is True
assert 'modeling' in config_dict
assert config_dict['modeling']['big'] == 10_000_000
@@ -423,6 +425,7 @@ def test_config_reset(self):
CONFIG.Logging.console = False
CONFIG.Logging.rich = True
CONFIG.Logging.file = '/tmp/test.log'
+ CONFIG.Logging.solver_to_console = False
CONFIG.Modeling.big = 99999999
CONFIG.Modeling.epsilon = 1e-8
CONFIG.Modeling.big_binary_bound = 500000
@@ -436,6 +439,7 @@ def test_config_reset(self):
assert CONFIG.Logging.console is False
assert CONFIG.Logging.rich is False
assert CONFIG.Logging.file is None
+ assert CONFIG.Logging.solver_to_console is True
assert CONFIG.Modeling.big == 10_000_000
assert CONFIG.Modeling.epsilon == 1e-5
assert CONFIG.Modeling.big_binary_bound == 100_000
@@ -457,6 +461,7 @@ def test_reset_matches_class_defaults(self):
CONFIG.Logging.file = '/tmp/test.log'
CONFIG.Logging.rich = True
CONFIG.Logging.console = True
+ CONFIG.Logging.solver_to_console = False
CONFIG.Modeling.big = 999999
CONFIG.Modeling.epsilon = 1e-10
CONFIG.Modeling.big_binary_bound = 999999
@@ -464,6 +469,7 @@ def test_reset_matches_class_defaults(self):
# Verify values are actually different from defaults
assert CONFIG.Logging.level != _DEFAULTS['logging']['level']
+ assert CONFIG.Logging.solver_to_console != _DEFAULTS['logging']['solver_to_console']
assert CONFIG.Modeling.big != _DEFAULTS['modeling']['big']
# Now reset
@@ -474,6 +480,7 @@ def test_reset_matches_class_defaults(self):
assert CONFIG.Logging.file == _DEFAULTS['logging']['file']
assert CONFIG.Logging.rich == _DEFAULTS['logging']['rich']
assert CONFIG.Logging.console == _DEFAULTS['logging']['console']
+ assert CONFIG.Logging.solver_to_console == _DEFAULTS['logging']['solver_to_console']
assert CONFIG.Modeling.big == _DEFAULTS['modeling']['big']
assert CONFIG.Modeling.epsilon == _DEFAULTS['modeling']['epsilon']
assert CONFIG.Modeling.big_binary_bound == _DEFAULTS['modeling']['big_binary_bound']
From 20602f9c0cb9c813adb6b74812f398f02303b488 Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Thu, 30 Oct 2025 14:27:28 +0100
Subject: [PATCH 09/27] Default solver log_to_console from CONFIG.Logging.solver_to_console
(gurobipy still has some issues...)
---
flixopt/solvers.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/flixopt/solvers.py b/flixopt/solvers.py
index 36f993f95..7d083eef4 100644
--- a/flixopt/solvers.py
+++ b/flixopt/solvers.py
@@ -8,6 +8,8 @@
from dataclasses import dataclass, field
from typing import Any, ClassVar
+from flixopt.config import CONFIG
+
logger = logging.getLogger('flixopt')
@@ -26,7 +28,7 @@ class _Solver:
name: ClassVar[str]
mip_gap: float
time_limit_seconds: int
- log_to_console: bool = True
+ log_to_console: bool = field(default_factory=lambda: CONFIG.Logging.solver_to_console)
extra_options: dict[str, Any] = field(default_factory=dict)
@property
From 95b921770be11e0d61705d788f014a851b733a79 Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Thu, 30 Oct 2025 14:33:35 +0100
Subject: [PATCH 10/27] Format logged solve duration with two decimals
---
flixopt/calculation.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/flixopt/calculation.py b/flixopt/calculation.py
index d07a23793..28dccb5ab 100644
--- a/flixopt/calculation.py
+++ b/flixopt/calculation.py
@@ -238,7 +238,7 @@ def solve(
**solver.options,
)
self.durations['solving'] = round(timeit.default_timer() - t_start, 2)
- logger.info(f'Model solved with {solver.name} in {self.durations["solving"]} seconds.')
+ logger.info(f'Model solved with {solver.name} in {self.durations["solving"]:.2f} seconds.')
logger.info(f'Model status after solve: {self.model.status}')
if self.model.status == 'warning':
@@ -628,7 +628,7 @@ def do_modeling_and_solve(
for key, value in calc.durations.items():
self.durations[key] += value
- logger.info(f'Model solved with {solver.name} in {self.durations["solving"]} seconds.')
+ logger.info(f'Model solved with {solver.name} in {self.durations["solving"]:.2f} seconds.')
self.results = SegmentedCalculationResults.from_calculation(self)
From 677f534e65c1ab656653824f79dc0c9f7218fd55 Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Thu, 30 Oct 2025 14:34:58 +0100
Subject: [PATCH 11/27] Use contextmanager to entirely suppress output in
SegmentedCalculation
---
flixopt/calculation.py | 19 ++++++++++---------
flixopt/io.py | 24 ++++++++++++++++++++++++
2 files changed, 34 insertions(+), 9 deletions(-)
diff --git a/flixopt/calculation.py b/flixopt/calculation.py
index 28dccb5ab..a55e453d1 100644
--- a/flixopt/calculation.py
+++ b/flixopt/calculation.py
@@ -573,7 +573,10 @@ def _create_sub_calculations(self):
)
def do_modeling_and_solve(
- self, solver: _Solver, log_file: pathlib.Path | None = None, log_main_results: bool = False
+ self,
+ solver: _Solver,
+ log_file: pathlib.Path | None = None,
+ log_main_results: bool = False,
) -> SegmentedCalculation:
logger.info(f'{"":#^80}')
logger.info(f'{" Segmented Solving ":#^80}')
@@ -613,14 +616,12 @@ def do_modeling_and_solve(
f'Following InvestmentModels were found: {invest_elements}'
)
- solver_silent = copy.copy(solver)
- solver_silent.log_to_console = False
-
- calculation.solve(
- solver_silent,
- log_file=pathlib.Path(log_file) if log_file is not None else self.folder / f'{self.name}.log',
- log_main_results=log_main_results,
- )
+ with fx_io.suppress_output():
+ calculation.solve(
+ solver,
+ log_file=pathlib.Path(log_file) if log_file is not None else self.folder / f'{self.name}.log',
+ log_main_results=log_main_results,
+ )
progress_bar.close()
diff --git a/flixopt/io.py b/flixopt/io.py
index 7f832ed0e..6a5544d7b 100644
--- a/flixopt/io.py
+++ b/flixopt/io.py
@@ -3,8 +3,11 @@
import inspect
import json
import logging
+import os
import pathlib
import re
+import sys
+from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
@@ -931,3 +934,24 @@ def build_metadata_info(parts: list[str], prefix: str = ' | ') -> str:
return ''
info = ' | '.join(parts)
return prefix + info if prefix else info
+
+
+@contextmanager
+def suppress_output():
+ """Suppress all console output including C-level output from Gurobi."""
+ old_stdout_fd = os.dup(1)
+ old_stderr_fd = os.dup(2)
+
+ try:
+ devnull_fd = os.open(os.devnull, os.O_WRONLY)
+ sys.stdout.flush()
+ sys.stderr.flush()
+ os.dup2(devnull_fd, 1)
+ os.dup2(devnull_fd, 2)
+ yield
+ finally:
+ os.dup2(old_stdout_fd, 1)
+ os.dup2(old_stderr_fd, 2)
+ os.close(devnull_fd)
+ os.close(old_stdout_fd)
+ os.close(old_stderr_fd)
From 767d8eca40fd7d89695b02f7bb115d4daed2789b Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Thu, 30 Oct 2025 14:36:40 +0100
Subject: [PATCH 12/27] Improve suppress_output()
---
flixopt/io.py | 47 +++++++++++++++++++++++++++++++++++++++++------
1 file changed, 41 insertions(+), 6 deletions(-)
diff --git a/flixopt/io.py b/flixopt/io.py
index 6a5544d7b..c5f839ed9 100644
--- a/flixopt/io.py
+++ b/flixopt/io.py
@@ -938,20 +938,55 @@ def build_metadata_info(parts: list[str], prefix: str = ' | ') -> str:
@contextmanager
def suppress_output():
- """Suppress all console output including C-level output from Gurobi."""
+ """
+ Suppress all console output including C-level output from solvers.
+
+ WARNING: Not thread-safe. Modifies global file descriptors.
+ Use only with sequential execution or multiprocessing.
+ """
+ # Save original file descriptors
old_stdout_fd = os.dup(1)
old_stderr_fd = os.dup(2)
+ devnull_fd = None
try:
+ # Open devnull
devnull_fd = os.open(os.devnull, os.O_WRONLY)
+
+ # Flush Python buffers before redirecting
sys.stdout.flush()
sys.stderr.flush()
+
+ # Redirect file descriptors to devnull
os.dup2(devnull_fd, 1)
os.dup2(devnull_fd, 2)
+
yield
+
finally:
- os.dup2(old_stdout_fd, 1)
- os.dup2(old_stderr_fd, 2)
- os.close(devnull_fd)
- os.close(old_stdout_fd)
- os.close(old_stderr_fd)
+ # Restore original file descriptors with nested try blocks
+ # to ensure all cleanup happens even if one step fails
+ try:
+ # Flush any buffered output in the redirected streams
+ sys.stdout.flush()
+ sys.stderr.flush()
+ except (OSError, ValueError):
+ pass # Stream might be closed or invalid
+
+ try:
+ os.dup2(old_stdout_fd, 1)
+ except OSError:
+ pass # Failed to restore stdout, continue cleanup
+
+ try:
+ os.dup2(old_stderr_fd, 2)
+ except OSError:
+ pass # Failed to restore stderr, continue cleanup
+
+ # Close all file descriptors
+ for fd in [devnull_fd, old_stdout_fd, old_stderr_fd]:
+ if fd is not None:
+ try:
+ os.close(fd)
+ except OSError:
+ pass # FD already closed or invalid
From faf4267e83fe6cf3bff8ffe3c24ae49aee902886 Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Thu, 30 Oct 2025 15:06:22 +0100
Subject: [PATCH 13/27] More options in config.py
---
flixopt/calculation.py | 4 +-
flixopt/config.py | 59 +++++++++++++++--
flixopt/solvers.py | 12 ++--
tests/test_config.py | 144 +++++++++++++++++++++++++++++++++++++++--
4 files changed, 200 insertions(+), 19 deletions(-)
diff --git a/flixopt/calculation.py b/flixopt/calculation.py
index a55e453d1..1728725b8 100644
--- a/flixopt/calculation.py
+++ b/flixopt/calculation.py
@@ -228,7 +228,7 @@ def fix_sizes(self, ds: xr.Dataset, decimal_rounding: int | None = 5) -> FullCal
return self
def solve(
- self, solver: _Solver, log_file: pathlib.Path | None = None, log_main_results: bool = True
+ self, solver: _Solver, log_file: pathlib.Path | None = None, log_main_results: bool | None = None
) -> FullCalculation:
t_start = timeit.default_timer()
@@ -253,7 +253,7 @@ def solve(
)
# Log the formatted output
- if log_main_results:
+ if log_main_results if log_main_results is not None else CONFIG.Solving.log_main_results:
logger.info(
f'{" Main Results ":#^80}\n'
+ yaml.dump(
diff --git a/flixopt/config.py b/flixopt/config.py
index 83bdbe66f..1507621fd 100644
--- a/flixopt/config.py
+++ b/flixopt/config.py
@@ -28,7 +28,6 @@
'file': None,
'rich': False,
'console': False,
- 'solver_to_console': True,
'max_file_size': 10_485_760, # 10MB
'backup_count': 5,
'date_format': '%Y-%m-%d %H:%M:%S',
@@ -64,6 +63,14 @@
'default_qualitative_colorscale': 'plotly',
}
),
+ 'solving': MappingProxyType(
+ {
+ 'mip_gap': 0.01,
+ 'time_limit_seconds': 300,
+ 'log_to_console': True,
+ 'log_main_results': True,
+ }
+ ),
}
)
@@ -76,6 +83,8 @@ class CONFIG:
Attributes:
Logging: Logging configuration.
Modeling: Optimization modeling parameters.
+ Solving: Solver configuration and default parameters.
+ Plotting: Plotting configuration.
config_name: Configuration name.
Examples:
@@ -92,6 +101,9 @@ class CONFIG:
level: DEBUG
console: true
file: app.log
+ solving:
+ mip_gap: 0.001
+ time_limit_seconds: 600
```
"""
@@ -105,7 +117,6 @@ class Logging:
file: Log file path for file logging.
console: Enable console output.
rich: Use Rich library for enhanced output.
- solver_to_console: Enable solver output to console.
max_file_size: Max file size before rotation.
backup_count: Number of backup files to keep.
date_format: Date/time format string.
@@ -137,7 +148,6 @@ class Logging:
file: str | None = _DEFAULTS['logging']['file']
rich: bool = _DEFAULTS['logging']['rich']
console: bool | Literal['stdout', 'stderr'] = _DEFAULTS['logging']['console']
- solver_to_console: bool = _DEFAULTS['logging']['solver_to_console']
max_file_size: int = _DEFAULTS['logging']['max_file_size']
backup_count: int = _DEFAULTS['logging']['backup_count']
date_format: str = _DEFAULTS['logging']['date_format']
@@ -197,6 +207,30 @@ class Modeling:
epsilon: float = _DEFAULTS['modeling']['epsilon']
big_binary_bound: int = _DEFAULTS['modeling']['big_binary_bound']
+ class Solving:
+ """Solver configuration and default parameters.
+
+ Attributes:
+ mip_gap: Default MIP gap tolerance for solver convergence.
+ time_limit_seconds: Default time limit in seconds for solver runs.
+ log_to_console: Whether solver should output to console.
+ log_main_results: Whether to log main results after solving.
+
+ Examples:
+ ```python
+ # Set tighter convergence and longer timeout
+ CONFIG.Solving.mip_gap = 0.001
+ CONFIG.Solving.time_limit_seconds = 600
+ CONFIG.Solving.log_to_console = False
+ CONFIG.apply()
+ ```
+ """
+
+ mip_gap: float = _DEFAULTS['solving']['mip_gap']
+ time_limit_seconds: int = _DEFAULTS['solving']['time_limit_seconds']
+ log_to_console: bool = _DEFAULTS['solving']['log_to_console']
+ log_main_results: bool = _DEFAULTS['solving']['log_main_results']
+
class Plotting:
"""Plotting configuration.
@@ -249,6 +283,12 @@ def reset(cls):
for key, value in _DEFAULTS['modeling'].items():
setattr(cls.Modeling, key, value)
+ for key, value in _DEFAULTS['solving'].items():
+ setattr(cls.Solving, key, value)
+
+ for key, value in _DEFAULTS['plotting'].items():
+ setattr(cls.Plotting, key, value)
+
cls.config_name = _DEFAULTS['config_name']
cls.apply()
@@ -332,6 +372,12 @@ def _apply_config_dict(cls, config_dict: dict):
elif key == 'modeling' and isinstance(value, dict):
for nested_key, nested_value in value.items():
setattr(cls.Modeling, nested_key, nested_value)
+ elif key == 'solving' and isinstance(value, dict):
+ for nested_key, nested_value in value.items():
+ setattr(cls.Solving, nested_key, nested_value)
+ elif key == 'plotting' and isinstance(value, dict):
+ for nested_key, nested_value in value.items():
+ setattr(cls.Plotting, nested_key, nested_value)
elif hasattr(cls, key):
setattr(cls, key, value)
@@ -349,7 +395,6 @@ def to_dict(cls) -> dict:
'file': cls.Logging.file,
'rich': cls.Logging.rich,
'console': cls.Logging.console,
- 'solver_to_console': cls.Logging.solver_to_console,
'max_file_size': cls.Logging.max_file_size,
'backup_count': cls.Logging.backup_count,
'date_format': cls.Logging.date_format,
@@ -370,6 +415,12 @@ def to_dict(cls) -> dict:
'epsilon': cls.Modeling.epsilon,
'big_binary_bound': cls.Modeling.big_binary_bound,
},
+ 'solving': {
+ 'mip_gap': cls.Solving.mip_gap,
+ 'time_limit_seconds': cls.Solving.time_limit_seconds,
+ 'log_to_console': cls.Solving.log_to_console,
+ 'log_main_results': cls.Solving.log_main_results,
+ },
'plotting': {
'default_show': cls.Plotting.default_show,
'default_engine': cls.Plotting.default_engine,
diff --git a/flixopt/solvers.py b/flixopt/solvers.py
index 7d083eef4..e5db61192 100644
--- a/flixopt/solvers.py
+++ b/flixopt/solvers.py
@@ -19,16 +19,16 @@ class _Solver:
Abstract base class for solvers.
Args:
- mip_gap: Acceptable relative optimality gap in [0.0, 1.0].
- time_limit_seconds: Time limit in seconds.
- log_to_console: If False, no output to console.
+ mip_gap: Acceptable relative optimality gap in [0.0, 1.0]. Defaults to CONFIG.Solving.mip_gap.
+ time_limit_seconds: Time limit in seconds. Defaults to CONFIG.Solving.time_limit_seconds.
+ log_to_console: If False, no output to console. Defaults to CONFIG.Solving.log_to_console.
extra_options: Additional solver options merged into `options`.
"""
name: ClassVar[str]
- mip_gap: float
- time_limit_seconds: int
- log_to_console: bool = field(default_factory=lambda: CONFIG.Logging.solver_to_console)
+ mip_gap: float = field(default_factory=lambda: CONFIG.Solving.mip_gap)
+ time_limit_seconds: int = field(default_factory=lambda: CONFIG.Solving.time_limit_seconds)
+ log_to_console: bool = field(default_factory=lambda: CONFIG.Solving.log_to_console)
extra_options: dict[str, Any] = field(default_factory=dict)
@property
diff --git a/tests/test_config.py b/tests/test_config.py
index ae3304188..a78330eb4 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -28,10 +28,13 @@ def test_config_defaults(self):
assert CONFIG.Logging.file is None
assert CONFIG.Logging.rich is False
assert CONFIG.Logging.console is False
- assert CONFIG.Logging.solver_to_console is True
assert CONFIG.Modeling.big == 10_000_000
assert CONFIG.Modeling.epsilon == 1e-5
assert CONFIG.Modeling.big_binary_bound == 100_000
+ assert CONFIG.Solving.mip_gap == 0.01
+ assert CONFIG.Solving.time_limit_seconds == 300
+ assert CONFIG.Solving.log_to_console is True
+ assert CONFIG.Solving.log_main_results is True
assert CONFIG.config_name == 'flixopt'
def test_module_initialization(self):
@@ -105,9 +108,13 @@ def test_config_to_dict(self):
assert config_dict['logging']['console'] is True
assert config_dict['logging']['file'] is None
assert config_dict['logging']['rich'] is False
- assert config_dict['logging']['solver_to_console'] is True
assert 'modeling' in config_dict
assert config_dict['modeling']['big'] == 10_000_000
+ assert 'solving' in config_dict
+ assert config_dict['solving']['mip_gap'] == 0.01
+ assert config_dict['solving']['time_limit_seconds'] == 300
+ assert config_dict['solving']['log_to_console'] is True
+ assert config_dict['solving']['log_main_results'] is True
def test_config_load_from_file(self, tmp_path):
"""Test loading configuration from YAML file."""
@@ -121,6 +128,10 @@ def test_config_load_from_file(self, tmp_path):
modeling:
big: 20000000
epsilon: 1e-6
+solving:
+ mip_gap: 0.001
+ time_limit_seconds: 600
+ log_main_results: false
"""
config_file.write_text(config_content)
@@ -132,6 +143,9 @@ def test_config_load_from_file(self, tmp_path):
assert CONFIG.Modeling.big == 20000000
# YAML may load epsilon as string, so convert for comparison
assert float(CONFIG.Modeling.epsilon) == 1e-6
+ assert CONFIG.Solving.mip_gap == 0.001
+ assert CONFIG.Solving.time_limit_seconds == 600
+ assert CONFIG.Solving.log_main_results is False
def test_config_load_from_file_not_found(self):
"""Test that loading from non-existent file raises error."""
@@ -266,6 +280,10 @@ def test_custom_config_yaml_complete(self, tmp_path):
big: 50000000
epsilon: 1e-4
big_binary_bound: 200000
+solving:
+ mip_gap: 0.005
+ time_limit_seconds: 900
+ log_main_results: false
"""
config_file.write_text(config_content)
@@ -280,6 +298,9 @@ def test_custom_config_yaml_complete(self, tmp_path):
assert CONFIG.Modeling.big == 50000000
assert float(CONFIG.Modeling.epsilon) == 1e-4
assert CONFIG.Modeling.big_binary_bound == 200000
+ assert CONFIG.Solving.mip_gap == 0.005
+ assert CONFIG.Solving.time_limit_seconds == 900
+ assert CONFIG.Solving.log_main_results is False
# Verify logging was applied
logger = logging.getLogger('flixopt')
@@ -425,10 +446,13 @@ def test_config_reset(self):
CONFIG.Logging.console = False
CONFIG.Logging.rich = True
CONFIG.Logging.file = '/tmp/test.log'
- CONFIG.Logging.solver_to_console = False
CONFIG.Modeling.big = 99999999
CONFIG.Modeling.epsilon = 1e-8
CONFIG.Modeling.big_binary_bound = 500000
+ CONFIG.Solving.mip_gap = 0.0001
+ CONFIG.Solving.time_limit_seconds = 1800
+ CONFIG.Solving.log_to_console = False
+ CONFIG.Solving.log_main_results = False
CONFIG.config_name = 'test_config'
# Reset should restore all defaults
@@ -439,10 +463,13 @@ def test_config_reset(self):
assert CONFIG.Logging.console is False
assert CONFIG.Logging.rich is False
assert CONFIG.Logging.file is None
- assert CONFIG.Logging.solver_to_console is True
assert CONFIG.Modeling.big == 10_000_000
assert CONFIG.Modeling.epsilon == 1e-5
assert CONFIG.Modeling.big_binary_bound == 100_000
+ assert CONFIG.Solving.mip_gap == 0.01
+ assert CONFIG.Solving.time_limit_seconds == 300
+ assert CONFIG.Solving.log_to_console is True
+ assert CONFIG.Solving.log_main_results is True
assert CONFIG.config_name == 'flixopt'
# Verify logging was also reset
@@ -461,16 +488,20 @@ def test_reset_matches_class_defaults(self):
CONFIG.Logging.file = '/tmp/test.log'
CONFIG.Logging.rich = True
CONFIG.Logging.console = True
- CONFIG.Logging.solver_to_console = False
CONFIG.Modeling.big = 999999
CONFIG.Modeling.epsilon = 1e-10
CONFIG.Modeling.big_binary_bound = 999999
+ CONFIG.Solving.mip_gap = 0.0001
+ CONFIG.Solving.time_limit_seconds = 9999
+ CONFIG.Solving.log_to_console = False
+ CONFIG.Solving.log_main_results = False
CONFIG.config_name = 'modified'
# Verify values are actually different from defaults
assert CONFIG.Logging.level != _DEFAULTS['logging']['level']
- assert CONFIG.Logging.solver_to_console != _DEFAULTS['logging']['solver_to_console']
assert CONFIG.Modeling.big != _DEFAULTS['modeling']['big']
+ assert CONFIG.Solving.mip_gap != _DEFAULTS['solving']['mip_gap']
+ assert CONFIG.Solving.log_to_console != _DEFAULTS['solving']['log_to_console']
# Now reset
CONFIG.reset()
@@ -480,8 +511,107 @@ def test_reset_matches_class_defaults(self):
assert CONFIG.Logging.file == _DEFAULTS['logging']['file']
assert CONFIG.Logging.rich == _DEFAULTS['logging']['rich']
assert CONFIG.Logging.console == _DEFAULTS['logging']['console']
- assert CONFIG.Logging.solver_to_console == _DEFAULTS['logging']['solver_to_console']
assert CONFIG.Modeling.big == _DEFAULTS['modeling']['big']
assert CONFIG.Modeling.epsilon == _DEFAULTS['modeling']['epsilon']
assert CONFIG.Modeling.big_binary_bound == _DEFAULTS['modeling']['big_binary_bound']
+ assert CONFIG.Solving.mip_gap == _DEFAULTS['solving']['mip_gap']
+ assert CONFIG.Solving.time_limit_seconds == _DEFAULTS['solving']['time_limit_seconds']
+ assert CONFIG.Solving.log_to_console == _DEFAULTS['solving']['log_to_console']
+ assert CONFIG.Solving.log_main_results == _DEFAULTS['solving']['log_main_results']
assert CONFIG.config_name == _DEFAULTS['config_name']
+
+ def test_solving_config_defaults(self):
+ """Test that CONFIG.Solving has correct default values."""
+ assert CONFIG.Solving.mip_gap == 0.01
+ assert CONFIG.Solving.time_limit_seconds == 300
+ assert CONFIG.Solving.log_to_console is True
+ assert CONFIG.Solving.log_main_results is True
+
+ def test_solving_config_modification(self):
+ """Test that CONFIG.Solving attributes can be modified."""
+ # Modify solving config
+ CONFIG.Solving.mip_gap = 0.005
+ CONFIG.Solving.time_limit_seconds = 600
+ CONFIG.Solving.log_main_results = False
+ CONFIG.apply()
+
+ # Verify modifications
+ assert CONFIG.Solving.mip_gap == 0.005
+ assert CONFIG.Solving.time_limit_seconds == 600
+ assert CONFIG.Solving.log_main_results is False
+
+ def test_solving_config_integration_with_solvers(self):
+ """Test that solvers use CONFIG.Solving defaults."""
+ from flixopt import solvers
+
+ # Test with default config
+ CONFIG.reset()
+ solver1 = solvers.HighsSolver()
+ assert solver1.mip_gap == CONFIG.Solving.mip_gap
+ assert solver1.time_limit_seconds == CONFIG.Solving.time_limit_seconds
+
+ # Modify config and create new solver
+ CONFIG.Solving.mip_gap = 0.002
+ CONFIG.Solving.time_limit_seconds = 900
+ CONFIG.apply()
+
+ solver2 = solvers.GurobiSolver()
+ assert solver2.mip_gap == 0.002
+ assert solver2.time_limit_seconds == 900
+
+ # Explicit values should override config
+ solver3 = solvers.HighsSolver(mip_gap=0.1, time_limit_seconds=60)
+ assert solver3.mip_gap == 0.1
+ assert solver3.time_limit_seconds == 60
+
+ def test_solving_config_yaml_loading(self, tmp_path):
+ """Test loading solving config from YAML file."""
+ config_file = tmp_path / 'solving_config.yaml'
+ config_content = """
+solving:
+ mip_gap: 0.0001
+ time_limit_seconds: 1200
+ log_main_results: false
+"""
+ config_file.write_text(config_content)
+
+ CONFIG.load_from_file(config_file)
+
+ assert CONFIG.Solving.mip_gap == 0.0001
+ assert CONFIG.Solving.time_limit_seconds == 1200
+ assert CONFIG.Solving.log_main_results is False
+
+ def test_solving_config_in_to_dict(self):
+ """Test that CONFIG.Solving is included in to_dict()."""
+ CONFIG.Solving.mip_gap = 0.003
+ CONFIG.Solving.time_limit_seconds = 450
+ CONFIG.Solving.log_main_results = False
+
+ config_dict = CONFIG.to_dict()
+
+ assert 'solving' in config_dict
+ assert config_dict['solving']['mip_gap'] == 0.003
+ assert config_dict['solving']['time_limit_seconds'] == 450
+ assert config_dict['solving']['log_main_results'] is False
+
+ def test_solving_config_persistence(self):
+ """Test that Solving config is independent of other configs."""
+ # Set custom solving values
+ CONFIG.Solving.mip_gap = 0.007
+ CONFIG.Solving.time_limit_seconds = 750
+
+ # Change and apply logging config
+ CONFIG.Logging.console = True
+ CONFIG.apply()
+
+ # Solving values should be unchanged
+ assert CONFIG.Solving.mip_gap == 0.007
+ assert CONFIG.Solving.time_limit_seconds == 750
+
+ # Change modeling config
+ CONFIG.Modeling.big = 99999999
+ CONFIG.apply()
+
+ # Solving values should still be unchanged
+ assert CONFIG.Solving.mip_gap == 0.007
+ assert CONFIG.Solving.time_limit_seconds == 750
From 69ffb13d61338775d88b76e17e5d369d30a4a74c Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Thu, 30 Oct 2025 15:42:15 +0100
Subject: [PATCH 14/27] Update CHANGELOG.md
---
CHANGELOG.md | 11 ++++++++-
flixopt/config.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 70 insertions(+), 1 deletion(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8d81c9a0b..d28ad16d1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -51,12 +51,21 @@ If upgrading from v2.x, see the [v3.0.0 release notes](https://github.com/flixOp
## [Unreleased] - ????-??-??
-**Summary**:
+**Summary**: Enhanced solver configuration with new CONFIG.Solving section for centralized solver parameter management.
If upgrading from v2.x, see the [v3.0.0 release notes](https://github.com/flixOpt/flixOpt/releases/tag/v3.0.0) and [Migration Guide](https://flixopt.github.io/flixopt/latest/user-guide/migration-guide-v3/).
### β¨ Added
+**Solver configuration:**
+- **New `CONFIG.Solving` configuration section** for centralized solver parameter management:
+ - `mip_gap`: Default MIP gap tolerance for solver convergence (default: 0.01)
+ - `time_limit_seconds`: Default time limit in seconds for solver runs (default: 300)
+ - `log_to_console`: Whether solver should output to console (default: True)
+ - `log_main_results`: Whether to log main results after solving (default: True)
+- Solvers (`HighsSolver`, `GurobiSolver`) now use `CONFIG.Solving` defaults for parameters, allowing global configuration
+- Solver parameters can still be explicitly overridden when creating solver instances
+
### π₯ Breaking Changes
### β»οΈ Changed
diff --git a/flixopt/config.py b/flixopt/config.py
index 1507621fd..a74740efb 100644
--- a/flixopt/config.py
+++ b/flixopt/config.py
@@ -431,6 +431,66 @@ def to_dict(cls) -> dict:
},
}
+ @classmethod
+ def silent(cls) -> type[CONFIG]:
+ """Configure for silent operation.
+
+ Disables console logging, solver output, and result logging
+ for clean production runs. Does not show plots. Automatically calls apply().
+ """
+ cls.Logging.console = False
+ cls.Plotting.default_show = False
+ cls.Logging.file = None
+ cls.Solving.log_to_console = False
+ cls.Solving.log_main_results = False
+ cls.apply()
+ return cls
+
+ @classmethod
+ def debug(cls) -> type[CONFIG]:
+ """Configure for debug mode with verbose output.
+
+ Enables console logging at DEBUG level and all solver output for
+ troubleshooting. Automatically calls apply().
+ """
+ cls.Logging.console = True
+ cls.Logging.level = 'DEBUG'
+ cls.Solving.log_to_console = True
+ cls.Solving.log_main_results = True
+ cls.apply()
+ return cls
+
+ @classmethod
+ def exploring(cls) -> type[CONFIG]:
+ """Configure for exploring flixopt
+
+ Enables console logging at INFO level and all solver output.
+ Also enables browser plotting for plotly with showing plots per default
+ """
+ cls.Logging.console = True
+ cls.Logging.level = 'INFO'
+ cls.Solving.log_to_console = True
+ cls.Solving.log_main_results = True
+ cls.browser_plotting()
+ cls.apply()
+ return cls
+
+ @classmethod
+ def browser_plotting(cls) -> type[CONFIG]:
+ """Configure for interactive usage with plotly to open plots in browser.
+
+ Sets plotly.io.renderers.default = 'browser'. Useful for running examples
+ and viewing interactive plots. Does NOT modify CONFIG.Plotting settings.
+ """
+ cls.Plotting.default_show = True
+ cls.apply()
+
+ import plotly.io as pio
+
+ pio.renderers.default = 'browser'
+
+ return cls
+
class MultilineFormatter(logging.Formatter):
"""Formatter that handles multi-line messages with consistent prefixes.
From 2fbbd3a072604c30583c44dc2c2310e861f19edb Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Thu, 30 Oct 2025 16:14:59 +0100
Subject: [PATCH 15/27] Use new Config options in examples
---
examples/00_Minmal/minimal_example.py | 3 +--
examples/01_Simple/simple_example.py | 7 +++----
examples/02_Complex/complex_example.py | 5 ++---
examples/02_Complex/complex_example_results.py | 5 ++---
examples/03_Calculation_types/example_calculation_types.py | 5 ++---
examples/04_Scenarios/scenario_example.py | 5 +++--
.../05_Two-stage-optimization/two_stage_optimization.py | 7 +++----
7 files changed, 16 insertions(+), 21 deletions(-)
diff --git a/examples/00_Minmal/minimal_example.py b/examples/00_Minmal/minimal_example.py
index 6a0ed3831..92e6801b2 100644
--- a/examples/00_Minmal/minimal_example.py
+++ b/examples/00_Minmal/minimal_example.py
@@ -9,8 +9,7 @@
import flixopt as fx
if __name__ == '__main__':
- fx.CONFIG.Logging.console = True
- fx.CONFIG.apply()
+ fx.CONFIG.silent()
flow_system = fx.FlowSystem(pd.date_range('2020-01-01', periods=3, freq='h'))
flow_system.add_elements(
diff --git a/examples/01_Simple/simple_example.py b/examples/01_Simple/simple_example.py
index 6b62d6712..fd5a3d9b7 100644
--- a/examples/01_Simple/simple_example.py
+++ b/examples/01_Simple/simple_example.py
@@ -8,9 +8,8 @@
import flixopt as fx
if __name__ == '__main__':
- # Enable console logging
- fx.CONFIG.Logging.console = True
- fx.CONFIG.apply()
+ fx.CONFIG.exploring()
+
# --- Create Time Series Data ---
# Heat demand profile (e.g., kW) over time and corresponding power prices
heat_demand_per_h = np.array([30, 0, 90, 110, 110, 20, 20, 20, 20])
@@ -101,7 +100,7 @@
flow_system.add_elements(costs, CO2, boiler, storage, chp, heat_sink, gas_source, power_sink)
# Visualize the flow system for validation purposes
- flow_system.plot_network(show=True)
+ flow_system.plot_network()
# --- Define and Run Calculation ---
# Create a calculation object to model the Flow System
diff --git a/examples/02_Complex/complex_example.py b/examples/02_Complex/complex_example.py
index 805cb08f6..b8ef76a03 100644
--- a/examples/02_Complex/complex_example.py
+++ b/examples/02_Complex/complex_example.py
@@ -9,9 +9,8 @@
import flixopt as fx
if __name__ == '__main__':
- # Enable console logging
- fx.CONFIG.Logging.console = True
- fx.CONFIG.apply()
+ fx.CONFIG.exploring()
+
# --- Experiment Options ---
# Configure options for testing various parameters and behaviors
check_penalty = False
diff --git a/examples/02_Complex/complex_example_results.py b/examples/02_Complex/complex_example_results.py
index 96d06dd04..edc2f7a1d 100644
--- a/examples/02_Complex/complex_example_results.py
+++ b/examples/02_Complex/complex_example_results.py
@@ -5,9 +5,8 @@
import flixopt as fx
if __name__ == '__main__':
- # Enable console logging
- fx.CONFIG.Logging.console = True
- fx.CONFIG.apply()
+ fx.CONFIG.exploring()
+
# --- Load Results ---
try:
results = fx.results.CalculationResults.from_file('results', 'complex example')
diff --git a/examples/03_Calculation_types/example_calculation_types.py b/examples/03_Calculation_types/example_calculation_types.py
index c5df50034..210747db9 100644
--- a/examples/03_Calculation_types/example_calculation_types.py
+++ b/examples/03_Calculation_types/example_calculation_types.py
@@ -11,9 +11,8 @@
import flixopt as fx
if __name__ == '__main__':
- # Enable console logging
- fx.CONFIG.Logging.console = True
- fx.CONFIG.apply()
+ fx.CONFIG.exploring()
+
# Calculation Types
full, segmented, aggregated = True, True, True
diff --git a/examples/04_Scenarios/scenario_example.py b/examples/04_Scenarios/scenario_example.py
index d258d4142..bf4f24617 100644
--- a/examples/04_Scenarios/scenario_example.py
+++ b/examples/04_Scenarios/scenario_example.py
@@ -8,6 +8,8 @@
import flixopt as fx
if __name__ == '__main__':
+ fx.CONFIG.exploring()
+
# Create datetime array starting from '2020-01-01' for one week
timesteps = pd.date_range('2020-01-01', periods=24 * 7, freq='h')
scenarios = pd.Index(['Base Case', 'High Demand'])
@@ -186,7 +188,7 @@
flow_system.add_elements(costs, CO2, boiler, storage, chp, heat_sink, gas_source, power_sink)
# Visualize the flow system for validation purposes
- flow_system.plot_network(show=True)
+ flow_system.plot_network()
# --- Define and Run Calculation ---
# Create a calculation object to model the Flow System
@@ -215,7 +217,6 @@
# Convert the results for the storage component to a dataframe and display
df = calculation.results['Storage'].node_balance_with_charge_state()
- print(df)
# Save results to file for later usage
calculation.results.to_file()
diff --git a/examples/05_Two-stage-optimization/two_stage_optimization.py b/examples/05_Two-stage-optimization/two_stage_optimization.py
index dde3ae069..b2be58cbe 100644
--- a/examples/05_Two-stage-optimization/two_stage_optimization.py
+++ b/examples/05_Two-stage-optimization/two_stage_optimization.py
@@ -7,7 +7,6 @@
While the final optimum might differ from the global optimum, the solving will be much faster.
"""
-import logging
import pathlib
import timeit
@@ -16,9 +15,9 @@
import flixopt as fx
-logger = logging.getLogger('flixopt')
-
if __name__ == '__main__':
+ fx.CONFIG.exploring()
+
# Data Import
data_import = pd.read_csv(
pathlib.Path(__file__).parent.parent / 'resources' / 'Zeitreihen2020.csv', index_col=0
@@ -136,7 +135,7 @@
timer_dispatch = timeit.default_timer() - start
if (calculation_dispatch.results.sizes().round(5) == calculation_sizing.results.sizes().round(5)).all().item():
- logger.info('Sizes were correctly equalized')
+ print('Sizes were correctly equalized')
else:
raise RuntimeError('Sizes were not correctly equalized')
From 209cdfd9155a69b274d0a4c638a706de67a7065e Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Thu, 30 Oct 2025 16:20:37 +0100
Subject: [PATCH 16/27] Set plotting backend in CI directly, overwriting all
 configs
---
.github/workflows/python-app.yaml | 2 ++
tests/conftest.py | 23 -----------------------
2 files changed, 2 insertions(+), 23 deletions(-)
diff --git a/.github/workflows/python-app.yaml b/.github/workflows/python-app.yaml
index f4dbc28c5..66ceceab4 100644
--- a/.github/workflows/python-app.yaml
+++ b/.github/workflows/python-app.yaml
@@ -24,6 +24,8 @@ concurrency:
env:
PYTHON_VERSION: "3.11"
+ MPLBACKEND: Agg # Non-interactive matplotlib backend for CI/testing
+ PLOTLY_RENDERER: json # Headless plotly renderer for CI/testing
jobs:
lint:
diff --git a/tests/conftest.py b/tests/conftest.py
index bd940b843..50c58e1ab 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -828,26 +828,3 @@ def cleanup_figures():
import matplotlib.pyplot as plt
plt.close('all')
-
-
-@pytest.fixture(scope='session', autouse=True)
-def set_test_environment():
- """
- Configure plotting for test environment.
-
- This fixture runs once per test session to:
- - Set matplotlib to use non-interactive 'Agg' backend
- - Set plotly to use non-interactive 'json' renderer
- - Prevent GUI windows from opening during tests
- """
- import matplotlib
-
- matplotlib.use('Agg') # Use non-interactive backend
-
- import plotly.io as pio
-
- pio.renderers.default = 'json' # Use non-interactive renderer
-
- fx.CONFIG.Plotting.default_show = False
-
- yield
From f529d9b68e704d39994a00b21ab407f8d14c6ec9 Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Thu, 30 Oct 2025 16:47:24 +0100
Subject: [PATCH 17/27] Fixed tqdm progress bar to respect CONFIG.silent()
---
flixopt/calculation.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/flixopt/calculation.py b/flixopt/calculation.py
index 1728725b8..37ea7a3db 100644
--- a/flixopt/calculation.py
+++ b/flixopt/calculation.py
@@ -589,6 +589,7 @@ def do_modeling_and_solve(
desc='Solving segments',
unit='segment',
file=sys.stdout, # Force tqdm to write to stdout instead of stderr
+ disable=not CONFIG.Solving.log_to_console, # Respect silent configuration
)
for i, calculation in progress_bar:
From 3ea3881557d999acb397832633e8e92dc8f31175 Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Thu, 30 Oct 2025 16:47:38 +0100
Subject: [PATCH 18/27] Replaced print() with framework logger
(examples/05_Two-stage-optimization/two_stage_optimization.py
---
examples/05_Two-stage-optimization/two_stage_optimization.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/examples/05_Two-stage-optimization/two_stage_optimization.py b/examples/05_Two-stage-optimization/two_stage_optimization.py
index b2be58cbe..9647e803c 100644
--- a/examples/05_Two-stage-optimization/two_stage_optimization.py
+++ b/examples/05_Two-stage-optimization/two_stage_optimization.py
@@ -7,6 +7,7 @@
While the final optimum might differ from the global optimum, the solving will be much faster.
"""
+import logging
import pathlib
import timeit
@@ -15,6 +16,8 @@
import flixopt as fx
+logger = logging.getLogger('flixopt')
+
if __name__ == '__main__':
fx.CONFIG.exploring()
@@ -135,7 +138,7 @@
timer_dispatch = timeit.default_timer() - start
if (calculation_dispatch.results.sizes().round(5) == calculation_sizing.results.sizes().round(5)).all().item():
- print('Sizes were correctly equalized')
+ logger.info('Sizes were correctly equalized')
else:
raise RuntimeError('Sizes were not correctly equalized')
From 284e3a525b0fb0ee1352736faa760fdbc5b1142f Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Thu, 30 Oct 2025 16:47:58 +0100
Subject: [PATCH 19/27] Added comprehensive tests for suppress_output()
---
tests/test_io.py | 107 +++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 107 insertions(+)
diff --git a/tests/test_io.py b/tests/test_io.py
index dbbc4cc72..83ac4251b 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -80,5 +80,112 @@ def test_flow_system_io(flow_system):
flow_system.__str__()
+def test_suppress_output_file_descriptors(tmp_path):
+ """Test that suppress_output() redirects file descriptors to /dev/null."""
+ import os
+ import sys
+
+ from flixopt.io import suppress_output
+
+ # Create temporary files to capture output
+ test_file = tmp_path / 'test_output.txt'
+
+ # Test that FD 1 (stdout) is redirected during suppression
+ with open(test_file, 'w') as f:
+ original_stdout_fd = os.dup(1) # Save original stdout FD
+ try:
+ # Redirect FD 1 to our test file
+ os.dup2(f.fileno(), 1)
+ os.write(1, b'before suppression\n')
+
+ with suppress_output():
+ # Inside suppress_output, writes should go to /dev/null, not our file
+ os.write(1, b'during suppression\n')
+
+ # After suppress_output, writes should go to our file again
+ os.write(1, b'after suppression\n')
+ finally:
+ # Restore original stdout
+ os.dup2(original_stdout_fd, 1)
+ os.close(original_stdout_fd)
+
+ # Read the file and verify content
+ content = test_file.read_text()
+ assert 'before suppression' in content
+ assert 'during suppression' not in content # This should NOT be in the file
+ assert 'after suppression' in content
+
+
+def test_suppress_output_python_level():
+ """Test that Python-level stdout/stderr continue to work after suppress_output()."""
+ import io
+ import sys
+
+ from flixopt.io import suppress_output
+
+ # Create a StringIO to capture Python-level output
+ captured_output = io.StringIO()
+
+ # After suppress_output exits, Python streams should be functional
+ with suppress_output():
+ pass # Just enter and exit the context
+
+ # Redirect sys.stdout to our StringIO
+ old_stdout = sys.stdout
+ try:
+ sys.stdout = captured_output
+ print('test message')
+ finally:
+ sys.stdout = old_stdout
+
+ # Verify Python-level stdout works
+ assert 'test message' in captured_output.getvalue()
+
+
+def test_suppress_output_exception_handling():
+ """Test that suppress_output() properly restores streams even on exception."""
+ import sys
+
+ from flixopt.io import suppress_output
+
+ # Save original file descriptors
+ original_stdout_fd = sys.stdout.fileno()
+ original_stderr_fd = sys.stderr.fileno()
+
+ try:
+ with suppress_output():
+ raise ValueError('Test exception')
+ except ValueError:
+ pass
+
+ # Verify streams are restored after exception
+ assert sys.stdout.fileno() == original_stdout_fd
+ assert sys.stderr.fileno() == original_stderr_fd
+
+ # Verify we can still write to stdout/stderr
+ sys.stdout.write('test after exception\n')
+ sys.stdout.flush()
+
+
+def test_suppress_output_c_level():
+ """Test that suppress_output() suppresses C-level output (file descriptor level)."""
+ import os
+ import sys
+
+ from flixopt.io import suppress_output
+
+ # This test verifies that even low-level C writes are suppressed
+ # by writing directly to file descriptor 1 (stdout)
+ with suppress_output():
+ # Try to write directly to FD 1 (stdout) - should be suppressed
+ os.write(1, b'C-level stdout write\n')
+ # Try to write directly to FD 2 (stderr) - should be suppressed
+ os.write(2, b'C-level stderr write\n')
+
+ # After exiting context, ensure streams work
+ sys.stdout.write('After C-level test\n')
+ sys.stdout.flush()
+
+
if __name__ == '__main__':
pytest.main(['-v', '--disable-warnings'])
From 8f613bc58386b290bfdf44e4fa61ca705e2cf2fb Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Fri, 31 Oct 2025 16:25:28 +0100
Subject: [PATCH 20/27] Remove unused import
---
flixopt/calculation.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/flixopt/calculation.py b/flixopt/calculation.py
index 37ea7a3db..ff6780bb2 100644
--- a/flixopt/calculation.py
+++ b/flixopt/calculation.py
@@ -10,7 +10,6 @@
from __future__ import annotations
-import copy
import logging
import math
import pathlib
From 2bd25bc7cb4d6c43f2df80b6e81b22e51f16f763 Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Fri, 31 Oct 2025 16:26:22 +0100
Subject: [PATCH 21/27] Ensure progress bar cleanup on exceptions.
---
flixopt/calculation.py | 63 +++++++++++++++++++++---------------------
1 file changed, 32 insertions(+), 31 deletions(-)
diff --git a/flixopt/calculation.py b/flixopt/calculation.py
index ff6780bb2..b643dbebf 100644
--- a/flixopt/calculation.py
+++ b/flixopt/calculation.py
@@ -591,39 +591,40 @@ def do_modeling_and_solve(
disable=not CONFIG.Solving.log_to_console, # Respect silent configuration
)
- for i, calculation in progress_bar:
- # Update progress bar description with current segment info
- progress_bar.set_description(
- f'Solving ({calculation.flow_system.timesteps[0]} -> {calculation.flow_system.timesteps[-1]})'
- )
-
- if i > 0 and self.nr_of_previous_values > 0:
- self._transfer_start_values(i)
-
- calculation.do_modeling()
-
- # Warn about Investments, but only in fist run
- if i == 0:
- invest_elements = [
- model.label_full
- for component in calculation.flow_system.components.values()
- for model in component.submodel.all_submodels
- if isinstance(model, InvestmentModel)
- ]
- if invest_elements:
- logger.critical(
- f'Investments are not supported in Segmented Calculation! '
- f'Following InvestmentModels were found: {invest_elements}'
- )
-
- with fx_io.suppress_output():
- calculation.solve(
- solver,
- log_file=pathlib.Path(log_file) if log_file is not None else self.folder / f'{self.name}.log',
- log_main_results=log_main_results,
+ try:
+ for i, calculation in progress_bar:
+ # Update progress bar description with current segment info
+ progress_bar.set_description(
+ f'Solving ({calculation.flow_system.timesteps[0]} -> {calculation.flow_system.timesteps[-1]})'
)
- progress_bar.close()
+ if i > 0 and self.nr_of_previous_values > 0:
+ self._transfer_start_values(i)
+
+ calculation.do_modeling()
+
+ # Warn about Investments, but only in fist run
+ if i == 0:
+ invest_elements = [
+ model.label_full
+ for component in calculation.flow_system.components.values()
+ for model in component.submodel.all_submodels
+ if isinstance(model, InvestmentModel)
+ ]
+ if invest_elements:
+ logger.critical(
+ f'Investments are not supported in Segmented Calculation! '
+ f'Following InvestmentModels were found: {invest_elements}'
+ )
+
+ with fx_io.suppress_output():
+ calculation.solve(
+ solver,
+ log_file=pathlib.Path(log_file) if log_file is not None else self.folder / f'{self.name}.log',
+ log_main_results=log_main_results,
+ )
+ finally:
+ progress_bar.close()
for calc in self.sub_calculations:
for key, value in calc.durations.items():
From 6d6f15efb884d3056f80cddbcea9860ea43e3dd0 Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Fri, 31 Oct 2025 16:28:24 +0100
Subject: [PATCH 22/27] Add test
---
tests/test_io.py | 36 ++++++++++++++++++++++++++++++++++++
1 file changed, 36 insertions(+)
diff --git a/tests/test_io.py b/tests/test_io.py
index 83ac4251b..6d225734e 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -187,5 +187,41 @@ def test_suppress_output_c_level():
sys.stdout.flush()
+def test_tqdm_cleanup_on_exception():
+ """Test that tqdm progress bar is properly cleaned up even when exceptions occur.
+
+ This test verifies the pattern used in SegmentedCalculation where a try/finally
+ block ensures progress_bar.close() is called even if an exception occurs.
+ """
+ from tqdm import tqdm
+
+ # Create a progress bar (disabled to avoid output during tests)
+ items = enumerate(range(5))
+ progress_bar = tqdm(items, total=5, desc='Test progress', disable=True)
+
+ # Track whether cleanup was called
+ cleanup_called = False
+ exception_raised = False
+
+ try:
+ try:
+ for idx, _ in progress_bar:
+ if idx == 2:
+ raise ValueError('Test exception')
+ finally:
+ # This should always execute, even with exception
+ progress_bar.close()
+ cleanup_called = True
+ except ValueError:
+ exception_raised = True
+
+ # Verify both that the exception was raised AND cleanup happened
+ assert exception_raised, 'Test exception should have been raised'
+ assert cleanup_called, 'Cleanup should have been called even with exception'
+
+ # Verify that close() is idempotent - calling it again should not raise
+ progress_bar.close() # Should not raise even if already closed
+
+
if __name__ == '__main__':
pytest.main(['-v', '--disable-warnings'])
From 3ad25a0c80977ce7f19771ce73913f9f2ca68283 Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Fri, 31 Oct 2025 16:49:56 +0100
Subject: [PATCH 23/27] Split method in SegmentedCalculation for better
 distinction between showing and hiding solver output
---
CHANGELOG.md | 1 +
flixopt/calculation.py | 126 +++++++++++++++++++++++++++--------------
2 files changed, 86 insertions(+), 41 deletions(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d28ad16d1..2e4912d48 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -67,6 +67,7 @@ If upgrading from v2.x, see the [v3.0.0 release notes](https://github.com/flixOp
- Solver parameters can still be explicitly overridden when creating solver instances
### π₯ Breaking Changes
+- Individual solver output is now hidden in **SegmentedCalculation**. To return t the prior behaviour, set `show_individual_solves=True` in `do_modeling_and_solve()`.
### β»οΈ Changed
diff --git a/flixopt/calculation.py b/flixopt/calculation.py
index b643dbebf..c0b2a54c6 100644
--- a/flixopt/calculation.py
+++ b/flixopt/calculation.py
@@ -571,60 +571,104 @@ def _create_sub_calculations(self):
f'({timesteps_of_segment[0]} -> {timesteps_of_segment[-1]}):'
)
+ def _solve_single_segment(
+ self,
+ i: int,
+ calculation: FullCalculation,
+ solver: _Solver,
+ log_file: pathlib.Path | None,
+ log_main_results: bool,
+ suppress_output: bool,
+ ) -> None:
+ """Solve a single segment calculation."""
+ if i > 0 and self.nr_of_previous_values > 0:
+ self._transfer_start_values(i)
+
+ calculation.do_modeling()
+
+ # Warn about Investments, but only in first run
+ if i == 0:
+ invest_elements = [
+ model.label_full
+ for component in calculation.flow_system.components.values()
+ for model in component.submodel.all_submodels
+ if isinstance(model, InvestmentModel)
+ ]
+ if invest_elements:
+ logger.critical(
+ f'Investments are not supported in Segmented Calculation! '
+ f'Following InvestmentModels were found: {invest_elements}'
+ )
+
+ log_path = pathlib.Path(log_file) if log_file is not None else self.folder / f'{self.name}.log'
+
+ if suppress_output:
+ with fx_io.suppress_output():
+ calculation.solve(solver, log_file=log_path, log_main_results=log_main_results)
+ else:
+ calculation.solve(solver, log_file=log_path, log_main_results=log_main_results)
+
def do_modeling_and_solve(
self,
solver: _Solver,
log_file: pathlib.Path | None = None,
log_main_results: bool = False,
+ show_individual_solves: bool = False,
) -> SegmentedCalculation:
+ """Model and solve all segments of the segmented calculation.
+
+ This method creates sub-calculations for each time segment, then iteratively
+ models and solves each segment. It supports two output modes: a progress bar
+ for compact output, or detailed individual solve information.
+
+ Args:
+ solver: The solver instance to use for optimization (e.g., Gurobi, HiGHS).
+ log_file: Optional path to the solver log file. If None, defaults to
+ folder/name.log.
+ log_main_results: Whether to log main results (objective, effects, etc.)
+ after each segment solve. Defaults to False.
+ show_individual_solves: If True, shows detailed output for each segment
+ solve with logger messages. If False (default), shows a compact progress
+ bar with suppressed solver output for cleaner display.
+
+ Returns:
+ Self, for method chaining.
+
+ Note:
+ The method automatically transfers all start values between segments to ensure
+ continuity of storage states and flow rates across segment boundaries.
+ """
logger.info(f'{"":#^80}')
logger.info(f'{" Segmented Solving ":#^80}')
self._create_sub_calculations()
- # Create tqdm progress bar with custom format that prints to stdout
- progress_bar = tqdm(
- enumerate(self.sub_calculations),
- total=len(self.sub_calculations),
- desc='Solving segments',
- unit='segment',
- file=sys.stdout, # Force tqdm to write to stdout instead of stderr
- disable=not CONFIG.Solving.log_to_console, # Respect silent configuration
- )
-
- try:
- for i, calculation in progress_bar:
- # Update progress bar description with current segment info
- progress_bar.set_description(
- f'Solving ({calculation.flow_system.timesteps[0]} -> {calculation.flow_system.timesteps[-1]})'
+ if show_individual_solves:
+ # Path 1: Show individual solves with detailed output
+ for i, calculation in enumerate(self.sub_calculations):
+ logger.info(
+ f'Solving segment {i + 1}/{len(self.sub_calculations)}: '
+ f'{calculation.flow_system.timesteps[0]} -> {calculation.flow_system.timesteps[-1]}'
)
+ self._solve_single_segment(i, calculation, solver, log_file, log_main_results, suppress_output=False)
+ else:
+ # Path 2: Show only progress bar with suppressed output
+ progress_bar = tqdm(
+ enumerate(self.sub_calculations),
+ total=len(self.sub_calculations),
+ desc='Solving segments',
+ unit='segment',
+ file=sys.stdout,
+ disable=not CONFIG.Solving.log_to_console,
+ )
- if i > 0 and self.nr_of_previous_values > 0:
- self._transfer_start_values(i)
-
- calculation.do_modeling()
-
- # Warn about Investments, but only in fist run
- if i == 0:
- invest_elements = [
- model.label_full
- for component in calculation.flow_system.components.values()
- for model in component.submodel.all_submodels
- if isinstance(model, InvestmentModel)
- ]
- if invest_elements:
- logger.critical(
- f'Investments are not supported in Segmented Calculation! '
- f'Following InvestmentModels were found: {invest_elements}'
- )
-
- with fx_io.suppress_output():
- calculation.solve(
- solver,
- log_file=pathlib.Path(log_file) if log_file is not None else self.folder / f'{self.name}.log',
- log_main_results=log_main_results,
+ try:
+ for i, calculation in progress_bar:
+ progress_bar.set_description(
+ f'Solving ({calculation.flow_system.timesteps[0]} -> {calculation.flow_system.timesteps[-1]})'
)
- finally:
- progress_bar.close()
+ self._solve_single_segment(i, calculation, solver, log_file, log_main_results, suppress_output=True)
+ finally:
+ progress_bar.close()
for calc in self.sub_calculations:
for key, value in calc.durations.items():
From 691d95c3f8105ce08fd33d9c68f4705975465d40 Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Fri, 31 Oct 2025 17:17:23 +0100
Subject: [PATCH 24/27] Use config show in examples
---
examples/02_Complex/complex_example_results.py | 2 +-
examples/03_Calculation_types/example_calculation_types.py | 2 +-
flixopt/calculation.py | 2 +-
3 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/examples/02_Complex/complex_example_results.py b/examples/02_Complex/complex_example_results.py
index edc2f7a1d..96191c4d8 100644
--- a/examples/02_Complex/complex_example_results.py
+++ b/examples/02_Complex/complex_example_results.py
@@ -18,7 +18,7 @@
) from e
# --- Basic overview ---
- results.plot_network(show=True)
+ results.plot_network()
results['FernwΓ€rme'].plot_node_balance()
# --- Detailed Plots ---
diff --git a/examples/03_Calculation_types/example_calculation_types.py b/examples/03_Calculation_types/example_calculation_types.py
index 210747db9..e339c1c24 100644
--- a/examples/03_Calculation_types/example_calculation_types.py
+++ b/examples/03_Calculation_types/example_calculation_types.py
@@ -164,7 +164,7 @@
a_kwk,
a_speicher,
)
- flow_system.plot_network(controls=False, show=True)
+ flow_system.plot_network()
# Calculations
calculations: list[fx.FullCalculation | fx.AggregatedCalculation | fx.SegmentedCalculation] = []
diff --git a/flixopt/calculation.py b/flixopt/calculation.py
index c0b2a54c6..875b3967b 100644
--- a/flixopt/calculation.py
+++ b/flixopt/calculation.py
@@ -370,7 +370,7 @@ def _perform_aggregation(self):
)
self.aggregation.cluster()
- self.aggregation.plot(show=True, save=self.folder / 'aggregation.html')
+ self.aggregation.plot(show=CONFIG.Plotting.default_show, save=self.folder / 'aggregation.html')
if self.aggregation_parameters.aggregate_data_and_fix_non_binary_vars:
ds = self.flow_system.to_dataset()
for name, series in self.aggregation.aggregated_data.items():
From 8a504ef77ed16dfa121683b4b27f78a9e421870c Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Fri, 31 Oct 2025 17:39:32 +0100
Subject: [PATCH 25/27] Use config show in results.plot_network()
---
flixopt/results.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/flixopt/results.py b/flixopt/results.py
index 26eaf9d5d..c02e5b769 100644
--- a/flixopt/results.py
+++ b/flixopt/results.py
@@ -1029,14 +1029,14 @@ def plot_network(
]
) = True,
path: pathlib.Path | None = None,
- show: bool = False,
+ show: bool | None = None,
) -> pyvis.network.Network | None:
"""Plot interactive network visualization of the system.
Args:
controls: Enable/disable interactive controls.
path: Save path for network HTML.
- show: Whether to display the plot.
+ show: Whether to display the plot. If None, uses CONFIG.Plotting.default_show.
"""
if path is None:
path = self.folder / f'{self.name}--network.html'
From 59b125a925871b5bc23a463b2d8d0d686be9cff1 Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Fri, 31 Oct 2025 17:39:40 +0100
Subject: [PATCH 26/27] Improve readability of code
---
flixopt/calculation.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/flixopt/calculation.py b/flixopt/calculation.py
index 875b3967b..feb077dcf 100644
--- a/flixopt/calculation.py
+++ b/flixopt/calculation.py
@@ -252,7 +252,8 @@ def solve(
)
# Log the formatted output
- if log_main_results if log_main_results is not None else CONFIG.Solving.log_main_results:
+ should_log = log_main_results if log_main_results is not None else CONFIG.Solving.log_main_results
+ if should_log:
logger.info(
f'{" Main Results ":#^80}\n'
+ yaml.dump(
From f3f54c94f30bea00e724f459b044791b8d4f4530 Mon Sep 17 00:00:00 2001
From: FBumann <117816358+FBumann@users.noreply.github.com>
Date: Sat, 1 Nov 2025 13:48:27 +0100
Subject: [PATCH 27/27] Typo
---
CHANGELOG.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2e4912d48..befccf890 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -67,7 +67,7 @@ If upgrading from v2.x, see the [v3.0.0 release notes](https://github.com/flixOp
- Solver parameters can still be explicitly overridden when creating solver instances
### π₯ Breaking Changes
-- Individual solver output is now hidden in **SegmentedCalculation**. To return t the prior behaviour, set `show_individual_solves=True` in `do_modeling_and_solve()`.
+- Individual solver output is now hidden in **SegmentedCalculation**. To return to the prior behaviour, set `show_individual_solves=True` in `do_modeling_and_solve()`.
### β»οΈ Changed