diff --git a/CHANGELOG.md b/CHANGELOG.md index 976e6316b..5817377f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,27 +53,21 @@ If upgrading from v2.x, see the [v3.0.0 release notes](https://github.com/flixOp If upgrading from v2.x, see the [v3.0.0 release notes](https://github.com/flixOpt/flixOpt/releases/tag/v3.0.0) and [Migration Guide](https://flixopt.github.io/flixopt/latest/user-guide/migration-guide-v3/). ### ✨ Added - -### 💥 Breaking Changes +- **Simplified color management**: Configure consistent plot colors with explicit component grouping + - Direct colors: `results.setup_colors({'Boiler1': '#FF0000', 'CHP': 'darkred'})` + - Grouped colors: `results.setup_colors({'oranges': ['Solar1', 'Solar2'], 'blues': ['Wind1', 'Wind2']})` + - Mixed approach: Combine direct and grouped colors in a single call + - File-based: `results.setup_colors('colors.yaml')` (YAML only) +- **Heatmap fill control**: Control missing value handling with `fill='ffill'` or `fill='bfill'` ### ♻️ Changed +- Plotting methods now use `color_manager` by default if configured ### 🗑️ Deprecated -### 🔥 Removed - ### 🐛 Fixed - -### 🔒 Security - -### 📦 Dependencies - -### 📝 Docs - -### 👷 Development -- Fixed concurrency issue in CI - -### 🚧 Known Issues +- Improved error messages for matplotlib with multidimensional data +- Better dimension validation in `plot_heatmap()` --- diff --git a/examples/00_Minmal/minimal_example.py b/examples/00_Minmal/minimal_example.py index 81b7c2dba..aab2797be 100644 --- a/examples/00_Minmal/minimal_example.py +++ b/examples/00_Minmal/minimal_example.py @@ -11,6 +11,7 @@ if __name__ == '__main__': # Enable console logging fx.CONFIG.Logging.console = True + fx.CONFIG.Plotting.default_show = True fx.CONFIG.apply() # --- Define the Flow System, that will hold all elements, and the time steps you want to model --- timesteps = pd.date_range('2020-01-01', periods=3, freq='h') diff --git a/examples/01_Simple/simple_example.py b/examples/01_Simple/simple_example.py index 906c24622..36cbd9d7c 100644 --- a/examples/01_Simple/simple_example.py +++ b/examples/01_Simple/simple_example.py @@ -10,6 +10,7 @@ if __name__ == '__main__': # Enable console logging fx.CONFIG.Logging.console = True + fx.CONFIG.Plotting.default_show = True fx.CONFIG.apply() # --- Create Time Series Data --- # Heat demand profile (e.g., kW) over time and corresponding power prices @@ -45,11 +46,9 @@ # --- Define Flow System Components --- # Boiler: Converts fuel (gas) into thermal energy (heat) - boiler = fx.linear_converters.Boiler( + boiler = fx.Source( label='Boiler', - eta=0.5, - Q_th=fx.Flow(label='Q_th', bus='Fernwärme', size=50, relative_minimum=0.1, relative_maximum=1), - Q_fu=fx.Flow(label='Q_fu', bus='Gas'), + outputs=[fx.Flow(label=str(i), bus='Fernwärme', size=5) for i in range(10)], ) # Combined Heat and Power (CHP): Generates both electricity and heat from fuel @@ -112,6 +111,9 @@ calculation.solve(fx.solvers.HighsSolver(mip_gap=0, time_limit_seconds=30)) # --- Analyze Results --- + # Colors are automatically assigned using default colormap + # Optional: Configure custom colors with + calculation.results.setup_colors({'CHP': 'red'}) calculation.results['Fernwärme'].plot_node_balance_pie() calculation.results['Fernwärme'].plot_node_balance() calculation.results['Storage'].plot_charge_state() diff --git a/examples/02_Complex/complex_example.py b/examples/02_Complex/complex_example.py index 805cb08f6..7a4742284 100644 --- a/examples/02_Complex/complex_example.py +++ b/examples/02_Complex/complex_example.py @@ -205,8 +205,11 @@ # You can analyze results directly or save them to file and reload them later. calculation.results.to_file() - # But let's plot some results anyway - calculation.results.plot_heatmap('BHKW2(Q_th)|flow_rate') - calculation.results['BHKW2'].plot_node_balance() - calculation.results['Speicher'].plot_charge_state() - calculation.results['Fernwärme'].plot_node_balance_pie() + # Optional: Configure custom colors (dict is simplest): + calculation.results.setup_colors({'BHKW': 'orange', 'Speicher': 'green'}) + + # Plot results (colors are automatically assigned to components) + calculation.results.plot_heatmap('BHKW2(Q_th)|flow_rate') # Heatmap uses continuous colors (not ColorManager) + calculation.results['BHKW2'].plot_node_balance() # Uses ColorManager + calculation.results['Speicher'].plot_charge_state() # Uses ColorManager + calculation.results['Fernwärme'].plot_node_balance_pie() # Uses ColorManager diff --git a/examples/02_Complex/complex_example_results.py b/examples/02_Complex/complex_example_results.py index 5020f71fe..8eba4de50 100644 --- a/examples/02_Complex/complex_example_results.py +++ b/examples/02_Complex/complex_example_results.py @@ -18,6 +18,11 @@ f'Original error: {e}' ) from e + # --- Configure Color Mapping for Consistent Plot Colors (Optional) --- + results.setup_colors({'Solar*': 'oranges', 'Wind*': 'blues'}) # Dict (simplest) + # results.setup_colors('colors.yaml') # Or from file + # results.setup_colors().add_rule('Solar*', 'oranges') # Or programmatic + # --- Basic overview --- results.plot_network(show=True) results['Fernwärme'].plot_node_balance() @@ -25,8 +30,9 @@ # --- Detailed Plots --- # In depth plot for individual flow rates ('__' is used as the delimiter between Component and Flow results.plot_heatmap('Wärmelast(Q_th_Last)|flow_rate') - for flow_rate in results['BHKW2'].inputs + results['BHKW2'].outputs: - results.plot_heatmap(flow_rate) + for bus in results.buses.values(): + bus.plot_node_balance_pie() + bus.plot_node_balance() # --- Plotting internal variables manually --- results.plot_heatmap('BHKW2(Q_th)|on') diff --git a/examples/03_Calculation_types/example_calculation_types.py b/examples/03_Calculation_types/example_calculation_types.py index 8df8e742f..f71a8eed7 100644 --- a/examples/03_Calculation_types/example_calculation_types.py +++ b/examples/03_Calculation_types/example_calculation_types.py @@ -202,35 +202,35 @@ def get_solutions(calcs: list, variable: str) -> xr.Dataset: # --- Plotting for comparison --- fx.plotting.with_plotly( - get_solutions(calculations, 'Speicher|charge_state').to_dataframe(), + get_solutions(calculations, 'Speicher|charge_state'), mode='line', title='Charge State Comparison', ylabel='Charge state', ).write_html('results/Charge State.html') fx.plotting.with_plotly( - get_solutions(calculations, 'BHKW2(Q_th)|flow_rate').to_dataframe(), + get_solutions(calculations, 'BHKW2(Q_th)|flow_rate'), mode='line', title='BHKW2(Q_th) Flow Rate Comparison', ylabel='Flow rate', ).write_html('results/BHKW2 Thermal Power.html') fx.plotting.with_plotly( - get_solutions(calculations, 'costs(temporal)|per_timestep').to_dataframe(), + get_solutions(calculations, 'costs(temporal)|per_timestep'), mode='line', title='Operation Cost Comparison', ylabel='Costs [€]', ).write_html('results/Operation Costs.html') fx.plotting.with_plotly( - pd.DataFrame(get_solutions(calculations, 'costs(temporal)|per_timestep').to_dataframe().sum()).T, + get_solutions(calculations, 'costs(temporal)|per_timestep').sum('time'), mode='stacked_bar', title='Total Cost Comparison', ylabel='Costs [€]', ).update_layout(barmode='group').write_html('results/Total Costs.html') fx.plotting.with_plotly( - pd.DataFrame([calc.durations for calc in calculations], index=[calc.name for calc in calculations]), + pd.DataFrame([calc.durations for calc in calculations], index=[calc.name for calc in calculations]).to_xarray(), mode='stacked_bar', ).update_layout(title='Duration Comparison', xaxis_title='Calculation type', yaxis_title='Time (s)').write_html( 'results/Speed Comparison.html' diff --git a/examples/04_Scenarios/scenario_example.py b/examples/04_Scenarios/scenario_example.py index 834e55782..c366f5084 100644 --- a/examples/04_Scenarios/scenario_example.py +++ b/examples/04_Scenarios/scenario_example.py @@ -8,6 +8,8 @@ import flixopt as fx if __name__ == '__main__': + fx.CONFIG.Plotting.default_show = True + fx.CONFIG.apply() # Create datetime array starting from '2020-01-01' for one week timesteps = pd.date_range('2020-01-01', periods=24 * 7, freq='h') scenarios = pd.Index(['Base Case', 'High Demand']) @@ -196,6 +198,8 @@ # --- Solve the Calculation and Save Results --- calculation.solve(fx.solvers.HighsSolver(mip_gap=0, time_limit_seconds=30)) + calculation.results.setup_colors() + calculation.results.plot_heatmap('CHP(Q_th)|flow_rate') # --- Analyze Results --- diff --git a/flixopt/aggregation.py b/flixopt/aggregation.py index 91ef618a9..18cac8013 100644 --- a/flixopt/aggregation.py +++ b/flixopt/aggregation.py @@ -21,6 +21,7 @@ TSAM_AVAILABLE = False from .components import Storage +from .config import CONFIG from .structure import ( FlowSystemModel, Submodel, @@ -141,7 +142,7 @@ def describe_clusters(self) -> str: def use_extreme_periods(self): return self.time_series_for_high_peaks or self.time_series_for_low_peaks - def plot(self, colormap: str = 'viridis', show: bool = True, save: pathlib.Path | None = None) -> go.Figure: + def plot(self, colormap: str | None = None, show: bool = True, save: pathlib.Path | None = None) -> go.Figure: from . import plotting df_org = self.original_data.copy().rename( @@ -150,10 +151,22 @@ def plot(self, colormap: str = 'viridis', show: bool = True, save: pathlib.Path df_agg = self.aggregated_data.copy().rename( columns={col: f'Aggregated - {col}' for col in self.aggregated_data.columns} ) - fig = plotting.with_plotly(df_org, 'line', colors=colormap) + fig = plotting.with_plotly( + df_org.to_xarray(), + 'line', + colors=colormap or CONFIG.Plotting.default_qualitative_colorscale, + xlabel='Time in h', + ) for trace in fig.data: trace.update(dict(line=dict(dash='dash'))) - fig = plotting.with_plotly(df_agg, 'line', colors=colormap, fig=fig) + fig2 = plotting.with_plotly( + df_agg.to_xarray(), + 'line', + colors=colormap or CONFIG.Plotting.default_qualitative_colorscale, + xlabel='Time in h', + ) + for trace in fig2.data: + fig.add_trace(trace) fig.update_layout( title='Original vs Aggregated Data (original = ---)', xaxis_title='Index', yaxis_title='Value' diff --git a/flixopt/config.py b/flixopt/config.py index a7549a3ec..14adb176f 100644 --- a/flixopt/config.py +++ b/flixopt/config.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import os import sys import warnings from logging.handlers import RotatingFileHandler @@ -54,6 +55,22 @@ 'big_binary_bound': 100_000, } ), + 'plotting': MappingProxyType( + { + 'plotly_renderer': 'browser', + 'plotly_template': 'plotly_white', + 'matplotlib_backend': None, + 'default_show': False, + 'default_save_path': None, + 'default_engine': 'plotly', + 'default_dpi': 300, + 'default_figure_width': None, + 'default_figure_height': None, + 'default_facet_cols': 3, + 'default_sequential_colorscale': 'turbo', + 'default_qualitative_colorscale': 'plotly', + } + ), } ) @@ -185,6 +202,86 @@ class Modeling: epsilon: float = _DEFAULTS['modeling']['epsilon'] big_binary_bound: int = _DEFAULTS['modeling']['big_binary_bound'] + class Plotting: + """Plotting configuration. + + Attributes: + plotly_renderer: Plotly renderer to use. + plotly_template: Plotly theme/template applied to all plots. + matplotlib_backend: Matplotlib backend to use. + default_show: Default value for the `show` parameter in plot methods. + default_save_path: Default directory for saving plots. + default_engine: Default plotting engine. + default_dpi: Default DPI for saved plots. + default_figure_width: Default plot width in pixels. + default_figure_height: Default plot height in pixels. + default_facet_cols: Default number of columns for faceted plots. + default_sequential_colorscale: Default colorscale for heatmaps and continuous data. + default_qualitative_colorscale: Default colormap for categorical plots (bar/line/area charts). + + Examples: + ```python + # Set consistent theming and rendering + CONFIG.Plotting.plotly_renderer = 'browser' + CONFIG.Plotting.plotly_template = 'plotly_dark' + CONFIG.Plotting.matplotlib_backend = 'TkAgg' + CONFIG.apply() + + # Configure default export and color settings + CONFIG.Plotting.default_dpi = 600 + CONFIG.Plotting.default_figure_width = 1200 + CONFIG.Plotting.default_figure_height = 800 + CONFIG.Plotting.default_sequential_colorscale = 'plasma' + CONFIG.Plotting.default_qualitative_colorscale = 'Dark24' + CONFIG.apply() + ``` + """ + + plotly_renderer: ( + Literal['browser', 'notebook', 'svg', 'png', 'pdf', 'jpeg', 'json', 'plotly_mimetype'] | None + ) = _DEFAULTS['plotting']['plotly_renderer'] + plotly_template: ( + Literal[ + 'plotly', + 'plotly_white', + 'plotly_dark', + 'ggplot2', + 'seaborn', + 'simple_white', + 'none', + 'gridon', + 'presentation', + 'xgridoff', + 'ygridoff', + ] + | None + ) = _DEFAULTS['plotting']['plotly_template'] + matplotlib_backend: ( + Literal[ + 'TkAgg', + 'Qt5Agg', + 'QtAgg', + 'WXAgg', + 'Agg', + 'Cairo', + 'PDF', + 'PS', + 'SVG', + 'WebAgg', + 'module://backend_interagg', + ] + | None + ) = _DEFAULTS['plotting']['matplotlib_backend'] + default_show: bool = _DEFAULTS['plotting']['default_show'] + default_save_path: str | None = _DEFAULTS['plotting']['default_save_path'] + default_engine: Literal['plotly', 'matplotlib'] = _DEFAULTS['plotting']['default_engine'] + default_dpi: int = _DEFAULTS['plotting']['default_dpi'] + default_figure_width: int | None = _DEFAULTS['plotting']['default_figure_width'] + default_figure_height: int | None = _DEFAULTS['plotting']['default_figure_height'] + default_facet_cols: int = _DEFAULTS['plotting']['default_facet_cols'] + default_sequential_colorscale: str = _DEFAULTS['plotting']['default_sequential_colorscale'] + default_qualitative_colorscale: str = _DEFAULTS['plotting']['default_qualitative_colorscale'] + config_name: str = _DEFAULTS['config_name'] @classmethod @@ -201,12 +298,15 @@ def reset(cls): for key, value in _DEFAULTS['modeling'].items(): setattr(cls.Modeling, key, value) + for key, value in _DEFAULTS['plotting'].items(): + setattr(cls.Plotting, key, value) + cls.config_name = _DEFAULTS['config_name'] cls.apply() @classmethod def apply(cls): - """Apply current configuration to logging system.""" + """Apply current configuration to logging and plotting systems.""" # Convert Colors class attributes to dict colors_dict = { 'DEBUG': cls.Logging.Colors.DEBUG, @@ -243,6 +343,13 @@ def apply(cls): colors=colors_dict, ) + # Apply plotting configuration + _apply_plotting_config( + plotly_renderer=cls.Plotting.plotly_renderer, + plotly_template=cls.Plotting.plotly_template, + matplotlib_backend=cls.Plotting.matplotlib_backend, + ) + @classmethod def load_from_file(cls, config_file: str | Path): """Load configuration from YAML file and apply it. @@ -282,6 +389,9 @@ def _apply_config_dict(cls, config_dict: dict): elif key == 'modeling' and isinstance(value, dict): for nested_key, nested_value in value.items(): setattr(cls.Modeling, nested_key, nested_value) + elif key == 'plotting' and isinstance(value, dict): + for nested_key, nested_value in value.items(): + setattr(cls.Plotting, nested_key, nested_value) elif hasattr(cls, key): setattr(cls, key, value) @@ -319,6 +429,20 @@ def to_dict(cls) -> dict: 'epsilon': cls.Modeling.epsilon, 'big_binary_bound': cls.Modeling.big_binary_bound, }, + 'plotting': { + 'plotly_renderer': cls.Plotting.plotly_renderer, + 'plotly_template': cls.Plotting.plotly_template, + 'matplotlib_backend': cls.Plotting.matplotlib_backend, + 'default_show': cls.Plotting.default_show, + 'default_save_path': cls.Plotting.default_save_path, + 'default_engine': cls.Plotting.default_engine, + 'default_dpi': cls.Plotting.default_dpi, + 'default_figure_width': cls.Plotting.default_figure_width, + 'default_figure_height': cls.Plotting.default_figure_height, + 'default_facet_cols': cls.Plotting.default_facet_cols, + 'default_sequential_colorscale': cls.Plotting.default_sequential_colorscale, + 'default_qualitative_colorscale': cls.Plotting.default_qualitative_colorscale, + }, } @@ -588,6 +712,94 @@ def _setup_logging( logger.addHandler(logging.NullHandler()) +def _apply_plotting_config( + plotly_renderer: Literal['browser', 'notebook', 'svg', 'png', 'pdf', 'jpeg', 'json', 'plotly_mimetype'] + | None = 'browser', + plotly_template: Literal[ + 'plotly', + 'plotly_white', + 'plotly_dark', + 'ggplot2', + 'seaborn', + 'simple_white', + 'none', + 'gridon', + 'presentation', + 'xgridoff', + 'ygridoff', + ] + | None = 'plotly', + matplotlib_backend: Literal[ + 'TkAgg', 'Qt5Agg', 'QtAgg', 'WXAgg', 'Agg', 'Cairo', 'PDF', 'PS', 'SVG', 'WebAgg', 'module://backend_interagg' + ] + | None = None, +) -> None: + """Apply plotting configuration to plotly and matplotlib. + + Args: + plotly_renderer: Plotly renderer to use. + plotly_template: Plotly template/theme to apply to all plots. + matplotlib_backend: Matplotlib backend to use. If None, the existing backend is not changed. + """ + # Configure Plotly renderer and template + try: + import plotly.io as pio + + if plotly_renderer is not None: + pio.renderers.default = plotly_renderer + + if plotly_template is not None: + pio.templates.default = plotly_template + except ImportError: + # Plotly not installed, skip configuration + pass + + # Configure Matplotlib backend + if matplotlib_backend is not None: + try: + import matplotlib + + # Check if pyplot has been imported yet + pyplot_imported = 'matplotlib.pyplot' in sys.modules + + if not pyplot_imported: + # Safe path: Set environment variable before pyplot import + # This is the preferred method as it avoids runtime backend switching + os.environ['MPLBACKEND'] = matplotlib_backend + logger.debug(f"Set MPLBACKEND environment variable to '{matplotlib_backend}'") + else: + # pyplot is already imported - check if we need to switch + current_backend = matplotlib.get_backend() + + if current_backend == matplotlib_backend: + logger.debug(f"matplotlib backend already set to '{matplotlib_backend}'") + else: + # Need to switch backend - check if it's safe + import matplotlib.pyplot as plt + + if len(plt.get_fignums()) > 0: + logger.warning( + f"Cannot switch matplotlib backend from '{current_backend}' to '{matplotlib_backend}': " + f'There are {len(plt.get_fignums())} open figures. Close all figures before changing backend, ' + f'or set CONFIG.Plotting.matplotlib_backend before importing matplotlib.pyplot.' + ) + else: + # No open figures - attempt safe backend switch + try: + plt.switch_backend(matplotlib_backend) + logger.info( + f"Switched matplotlib backend from '{current_backend}' to '{matplotlib_backend}'" + ) + except Exception as e: + logger.warning( + f"Failed to switch matplotlib backend from '{current_backend}' to '{matplotlib_backend}': {e}. " + f'Set CONFIG.Plotting.matplotlib_backend before importing matplotlib.pyplot to avoid this issue.' + ) + except ImportError: + # Matplotlib not installed, skip configuration + pass + + def change_logging_level(level_name: Literal['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']): """Change the logging level for the flixopt logger and all its handlers. diff --git a/flixopt/plotting.py b/flixopt/plotting.py index bd1f3c2c4..3c112d926 100644 --- a/flixopt/plotting.py +++ b/flixopt/plotting.py @@ -8,7 +8,8 @@ Key Features: **Dual Backend Support**: Seamless switching between Plotly and Matplotlib **Energy System Focus**: Specialized plots for power flows, storage states, emissions - **Color Management**: Intelligent color processing and palette management + **Color Management**: Intelligent color processing with ColorProcessor and component-based + ComponentColorManager for stable, pattern-matched coloring **Export Capabilities**: High-quality export for reports and publications **Integration Ready**: Designed for use with CalculationResults and standalone analysis @@ -42,6 +43,8 @@ import xarray as xr from plotly.exceptions import PlotlyError +from .config import CONFIG + if TYPE_CHECKING: import pyvis @@ -72,7 +75,7 @@ Color specifications can take several forms to accommodate different use cases: **Named Colormaps** (str): - - Standard colormaps: 'viridis', 'plasma', 'cividis', 'tab10', 'Set1' + - Standard colormaps: 'turbo', 'plasma', 'cividis', 'tab10', 'Set1' - Energy-focused: 'portland' (custom flixopt colormap for energy systems) - Backend-specific maps available in Plotly and Matplotlib @@ -89,7 +92,7 @@ Examples: ```python # Named colormap - colors = 'viridis' # Automatic color generation + colors = 'turbo' # Automatic color generation # Explicit color list colors = ['red', 'blue', 'green', '#FFD700'] @@ -136,57 +139,75 @@ class ColorProcessor: **Energy System Colors**: Built-in palettes optimized for energy system visualization Color Input Types: - - **Named Colormaps**: 'viridis', 'plasma', 'portland', 'tab10', etc. + - **Named Colormaps**: 'turbo', 'plasma', 'portland', etc. - **Color Lists**: ['red', 'blue', 'green'] or ['#FF0000', '#0000FF', '#00FF00'] - **Label Dictionaries**: {'Generator': 'red', 'Storage': 'blue', 'Load': 'green'} - Examples: - Basic color processing: - + Example: ```python - # Initialize for Plotly backend - processor = ColorProcessor(engine='plotly', default_colormap='viridis') - - # Process different color specifications + processor = ColorProcessor(engine='plotly', default_colorscale='turbo') colors = processor.process_colors('plasma', ['Gen1', 'Gen2', 'Storage']) - colors = processor.process_colors(['red', 'blue', 'green'], ['A', 'B', 'C']) - colors = processor.process_colors({'Wind': 'skyblue', 'Solar': 'gold'}, ['Wind', 'Solar', 'Gas']) - - # Switch to Matplotlib - processor = ColorProcessor(engine='matplotlib') - mpl_colors = processor.process_colors('tab10', component_labels) - ``` - - Energy system visualization: - - ```python - # Specialized energy system palette - energy_colors = { - 'Natural_Gas': '#8B4513', # Brown - 'Electricity': '#FFD700', # Gold - 'Heat': '#FF4500', # Red-orange - 'Cooling': '#87CEEB', # Sky blue - 'Hydrogen': '#E6E6FA', # Lavender - 'Battery': '#32CD32', # Lime green - } - - processor = ColorProcessor('plotly') - flow_colors = processor.process_colors(energy_colors, flow_labels) ``` Args: engine: Plotting backend ('plotly' or 'matplotlib'). Determines output color format. - default_colormap: Fallback colormap when requested palettes are unavailable. - Common options: 'viridis', 'plasma', 'tab10', 'portland'. + default_colorscale: Fallback colormap when requested palettes are unavailable. + Common options: 'turbo', 'plasma', 'portland'. """ - def __init__(self, engine: PlottingEngine = 'plotly', default_colormap: str = 'viridis'): + def __init__(self, engine: PlottingEngine = 'plotly', default_colorscale: str | None = None): """Initialize the color processor with specified backend and defaults.""" if engine not in ['plotly', 'matplotlib']: raise TypeError(f'engine must be "plotly" or "matplotlib", but is {engine}') self.engine = engine - self.default_colormap = default_colormap + self.default_colorscale = ( + default_colorscale if default_colorscale is not None else CONFIG.Plotting.default_qualitative_colorscale + ) + + def _get_sequential_colorscale(self, colormap_name: str, num_colors: int) -> list[str] | None: + try: + colorscale = px.colors.get_colorscale(colormap_name) + # Generate evenly spaced points + color_points = [i / (num_colors - 1) for i in range(num_colors)] if num_colors > 1 else [0] + return px.colors.sample_colorscale(colorscale, color_points) + except PlotlyError: + return None + + def _get_plotly_colormap_robust(self, colormap_name: str, num_colors: int) -> list[str]: + # First try qualitative color sequences (Dark24, Plotly, Set1, etc.) + colormap_title = colormap_name.title() + if hasattr(px.colors.qualitative, colormap_title): + color_list = getattr(px.colors.qualitative, colormap_title) + # Cycle through colors if we need more than available + return [color_list[i % len(color_list)] for i in range(num_colors)] + + # Then try sequential/continuous colorscales (turbo, plasma, etc.) + colors = self._get_sequential_colorscale(colormap_name, num_colors) + if colors is not None: + return colors + + # Fallback to default_colorscale + logger.warning(f"Colormap '{colormap_name}' not found in Plotly. Trying default '{self.default_colorscale}'") + + # Try default as qualitative + default_title = self.default_colorscale.title() + if hasattr(px.colors.qualitative, default_title): + color_list = getattr(px.colors.qualitative, default_title) + return [color_list[i % len(color_list)] for i in range(num_colors)] + + # Try default as sequential + colors = self._get_sequential_colorscale(self.default_colorscale, num_colors) + if colors is not None: + return colors + + # Ultimate fallback: use built-in Plotly qualitative colormap + logger.warning( + f"Both '{colormap_name}' and default '{self.default_colorscale}' not found. " + f"Using hardcoded fallback 'Plotly' colormap" + ) + color_list = px.colors.qualitative.Plotly + return [color_list[i % len(color_list)] for i in range(num_colors)] def _generate_colors_from_colormap(self, colormap_name: str, num_colors: int) -> list[Any]: """ @@ -200,22 +221,23 @@ def _generate_colors_from_colormap(self, colormap_name: str, num_colors: int) -> list of colors in the format appropriate for the engine """ if self.engine == 'plotly': - try: - colorscale = px.colors.get_colorscale(colormap_name) - except PlotlyError as e: - logger.error(f"Colorscale '{colormap_name}' not found in Plotly. Using {self.default_colormap}: {e}") - colorscale = px.colors.get_colorscale(self.default_colormap) - - # Generate evenly spaced points - color_points = [i / (num_colors - 1) for i in range(num_colors)] if num_colors > 1 else [0] - return px.colors.sample_colorscale(colorscale, color_points) + return self._get_plotly_colormap_robust(colormap_name, num_colors) else: # matplotlib try: cmap = plt.get_cmap(colormap_name, num_colors) except ValueError as e: - logger.error(f"Colormap '{colormap_name}' not found in Matplotlib. Using {self.default_colormap}: {e}") - cmap = plt.get_cmap(self.default_colormap, num_colors) + logger.warning( + f"Colormap '{colormap_name}' not found in Matplotlib. Trying default '{self.default_colorscale}': {e}" + ) + try: + cmap = plt.get_cmap(self.default_colorscale, num_colors) + except ValueError: + logger.warning( + f"Default colormap '{self.default_colorscale}' also not found in Matplotlib. " + f"Using hardcoded fallback 'tab10'" + ) + cmap = plt.get_cmap('tab10', num_colors) return [cmap(i) for i in range(num_colors)] @@ -231,8 +253,8 @@ def _handle_color_list(self, colors: list[str], num_labels: int) -> list[str]: list of colors matching the number of labels """ if len(colors) == 0: - logger.error(f'Empty color list provided. Using {self.default_colormap} instead.') - return self._generate_colors_from_colormap(self.default_colormap, num_labels) + logger.error(f'Empty color list provided. Using {self.default_colorscale} instead.') + return self._generate_colors_from_colormap(self.default_colorscale, num_labels) if len(colors) < num_labels: logger.warning( @@ -261,18 +283,18 @@ def _handle_color_dict(self, colors: dict[str, str], labels: list[str]) -> list[ list of colors in the same order as labels """ if len(colors) == 0: - logger.warning(f'Empty color dictionary provided. Using {self.default_colormap} instead.') - return self._generate_colors_from_colormap(self.default_colormap, len(labels)) + logger.warning(f'Empty color dictionary provided. Using {self.default_colorscale} instead.') + return self._generate_colors_from_colormap(self.default_colorscale, len(labels)) # Find missing labels missing_labels = sorted(set(labels) - set(colors.keys())) if missing_labels: logger.warning( - f'Some labels have no color specified: {missing_labels}. Using {self.default_colormap} for these.' + f'Some labels have no color specified: {missing_labels}. Using {self.default_colorscale} for these.' ) # Generate colors for missing labels - missing_colors = self._generate_colors_from_colormap(self.default_colormap, len(missing_labels)) + missing_colors = self._generate_colors_from_colormap(self.default_colorscale, len(missing_labels)) # Create a copy to avoid modifying the original colors_copy = colors.copy() @@ -315,9 +337,9 @@ def process_colors( color_list = self._handle_color_dict(colors, labels) else: logger.error( - f'Unsupported color specification type: {type(colors)}. Using {self.default_colormap} instead.' + f'Unsupported color specification type: {type(colors)}. Using {self.default_colorscale} instead.' ) - color_list = self._generate_colors_from_colormap(self.default_colormap, len(labels)) + color_list = self._generate_colors_from_colormap(self.default_colorscale, len(labels)) # Return either a list or a mapping if return_mapping: @@ -326,19 +348,350 @@ def process_colors( return color_list +class ComponentColorManager: + """Manage consistent colors for flow system components. + + Assign direct colors or group components to get shades from colorscales. + Colorscale families: blues, greens, oranges, reds, purples, teals, greys, etc. + + Example: + ```python + manager = ComponentColorManager() + manager.configure( + { + 'Boiler1': '#FF0000', # Direct color + 'oranges': ['Solar1', 'Solar2'], # Group gets orange shades + } + ) + colors = manager.get_variable_colors(['Boiler1(Bus_A)|flow']) + ``` + """ + + # Class-level colorscale family defaults (Plotly sequential palettes, reversed) + # Reversed so darker colors come first when assigning to components + DEFAULT_FAMILIES = { + 'blues': px.colors.sequential.Blues[7:0:-1], + 'greens': px.colors.sequential.Greens[7:0:-1], + 'reds': px.colors.sequential.Reds[7:0:-1], + 'purples': px.colors.sequential.Purples[7:0:-1], + 'oranges': px.colors.sequential.Oranges[7:0:-1], + 'teals': px.colors.sequential.Teal[7:0:-1], + 'greys': px.colors.sequential.Greys[7:0:-1], + 'pinks': px.colors.sequential.Pinkyl[7:0:-1], + 'peach': px.colors.sequential.Peach[7:0:-1], + 'burg': px.colors.sequential.Burg[7:0:-1], + 'sunsetdark': px.colors.sequential.Sunsetdark[7:0:-1], + 'mint': px.colors.sequential.Mint[7:0:-1], + 'emrld': px.colors.sequential.Emrld[7:0:-1], + 'darkmint': px.colors.sequential.Darkmint[7:0:-1], + } + + def __init__( + self, + components: list[str] | None = None, + default_colorscale: str | None = None, + ) -> None: + """Initialize component color manager. + + Args: + components: Optional list of all component names. If not provided, + components will be discovered from configure() calls. + default_colorscale: Default colormap for ungrouped components. + If None, uses CONFIG.Plotting.default_qualitative_colorscale. + """ + self.components = sorted(set(components)) if components else [] + self.default_colorscale = default_colorscale or CONFIG.Plotting.default_qualitative_colorscale + self.color_families = self.DEFAULT_FAMILIES.copy() + + # Computed colors: {component_name: color} + self._component_colors: dict[str, str] = {} + + # Variable color cache for performance: {variable_name: color} + self._variable_cache: dict[str, str] = {} + + # Auto-assign default colors if components provided + if self.components: + self._assign_default_colors() + + def __repr__(self) -> str: + return ( + f'ComponentColorManager(components={len(self.components)}, ' + f'colors_configured={len(self._component_colors)}, ' + f"default_colorscale='{self.default_colorscale}')" + ) + + def __str__(self) -> str: + lines = [ + 'ComponentColorManager', + f' Components: {len(self.components)}', + ] + + # Show first few components as examples + if self.components: + sample = self.components[:5] + if len(self.components) > 5: + sample_str = ', '.join(sample) + f', ... ({len(self.components) - 5} more)' + else: + sample_str = ', '.join(sample) + lines.append(f' [{sample_str}]') + + lines.append(f' Colors configured: {len(self._component_colors)}') + if self._component_colors: + for comp, color in list(self._component_colors.items())[:3]: + lines.append(f' - {comp}: {color}') + if len(self._component_colors) > 3: + lines.append(f' ... and {len(self._component_colors) - 3} more') + + lines.append(f' Default colormap: {self.default_colorscale}') + + return '\n'.join(lines) + + @classmethod + def from_flow_system(cls, flow_system, **kwargs): + """Create ComponentColorManager from a FlowSystem.""" + from .flow_system import FlowSystem + + if not isinstance(flow_system, FlowSystem): + raise TypeError(f'Expected FlowSystem, got {type(flow_system).__name__}') + + # Extract component names + components = list(flow_system.components.keys()) + + return cls(components=components, **kwargs) + + def configure(self, config: dict[str, str | list[str]] | str | pathlib.Path) -> ComponentColorManager: + """Configure component colors from dict or YAML file. + + Args: + config: Dict with 'component': 'color' or 'colorscale': ['comp1', 'comp2'], + or path to YAML file with same format. + """ + # Load from file if path provided + if isinstance(config, (str, pathlib.Path)): + config = self._load_config_from_file(config) + + if not isinstance(config, dict): + raise TypeError(f'Config must be dict or file path, got {type(config).__name__}') + + # Process config: distinguish between direct colors and grouped colors + for key, value in config.items(): + if isinstance(value, str): + # Direct assignment: component → color + self._component_colors[key] = value + # Add to components list if not already there + if key not in self.components: + self.components.append(key) + self.components.sort() + + elif isinstance(value, list): + # Group assignment: colorscale → [components] + colorscale_name = key + components = value + + # Sample N colors from the colorscale + colors = self._sample_colors_from_colorscale(colorscale_name, len(components)) + + # Assign each component a color + for component, color in zip(components, colors, strict=False): + self._component_colors[component] = color + # Add to components list if not already there + if component not in self.components: + self.components.append(component) + self.components.sort() + + else: + raise TypeError( + f'Invalid config value type for key "{key}". ' + f'Expected str (color) or list[str] (components), got {type(value).__name__}' + ) + + # Clear cache since colors changed + self._variable_cache.clear() + + return self + + def get_color(self, component: str) -> str: + """Get color for a component (defaults to grey if unknown).""" + return self._component_colors.get(component, '#808080') + + def extract_component(self, variable: str) -> str: + """Extract component name from variable name (e.g., 'Boiler1(Bus_A)|flow' → 'Boiler1').""" + component, _ = self._extract_component_and_flow(variable) + return component + + def _extract_component_and_flow(self, variable: str) -> tuple[str, str | None]: + # Try "Component(Flow)|attribute" format + if '(' in variable and ')' in variable: + component = variable.split('(')[0] + flow = variable.split('(')[1].split(')')[0] + return component, flow + + # Try "Component|attribute" format (no flow) + if '|' in variable: + return variable.split('|')[0], None + + # Just the component name itself + return variable, None + + def get_variable_color(self, variable: str) -> str: + """Get color for a variable (extracts component automatically).""" + # Check cache first + if variable in self._variable_cache: + return self._variable_cache[variable] + + # Extract component name from variable + component = self.extract_component(variable) + + # Get color for component + color = self.get_color(component) + + # Cache and return + self._variable_cache[variable] = color + return color + + def get_variable_colors(self, variables: list[str]) -> dict[str, str]: + """Get colors for multiple variables (main API for plotting functions).""" + return {var: self.get_variable_color(var) for var in variables} + + def to_dict(self) -> dict[str, str]: + """Get complete component→color mapping.""" + return self._component_colors.copy() + + # ==================== INTERNAL METHODS ==================== + + def _assign_default_colors(self) -> None: + colors = self._sample_colors_from_colorscale(self.default_colorscale, len(self.components)) + + for component, color in zip(self.components, colors, strict=False): + self._component_colors[component] = color + + @staticmethod + def _load_config_from_file(file_path: str | pathlib.Path) -> dict[str, str | list[str]]: + """Load color configuration from YAML file.""" + file_path = pathlib.Path(file_path) + + if not file_path.exists(): + raise FileNotFoundError(f'Color configuration file not found: {file_path}') + + # Only support YAML + suffix = file_path.suffix.lower() + if suffix not in ['.yaml', '.yml']: + raise ValueError(f'Unsupported file format: {suffix}. Only YAML (.yaml, .yml) is supported.') + + try: + import yaml + except ImportError as e: + raise ImportError( + 'PyYAML is required to load YAML config files. Install it with: pip install pyyaml' + ) from e + + with open(file_path, encoding='utf-8') as f: + config = yaml.safe_load(f) + + # Validate config structure + if not isinstance(config, dict): + raise ValueError(f'Invalid config file structure. Expected dict, got {type(config).__name__}') + + return config + + def _sample_colors_from_colorscale(self, colorscale_name: str, num_colors: int) -> list[str]: + # Check custom families first (ComponentColorManager-specific feature) + if colorscale_name in self.color_families: + color_list = self.color_families[colorscale_name] + # Cycle through colors if needed + if len(color_list) >= num_colors: + return color_list[:num_colors] + else: + return [color_list[i % len(color_list)] for i in range(num_colors)] + + # Delegate everything else to ColorProcessor (handles qualitative, sequential, fallbacks, cycling) + processor = ColorProcessor(engine='plotly', default_colorscale=self.default_colorscale) + return processor._generate_colors_from_colormap(colorscale_name, num_colors) + + +def _ensure_dataset(data: xr.Dataset | pd.DataFrame) -> xr.Dataset: + """Convert DataFrame to Dataset if needed.""" + if isinstance(data, xr.Dataset): + return data + elif isinstance(data, pd.DataFrame): + # Convert DataFrame to Dataset + return data.to_xarray() + else: + raise TypeError(f'Data must be xr.Dataset or pd.DataFrame, got {type(data).__name__}') + + +def _validate_plotting_data(data: xr.Dataset, allow_empty: bool = False) -> None: + """Validate dataset for plotting (checks for empty data, non-numeric types, etc.).""" + # Check for empty data + if not allow_empty and len(data.data_vars) == 0: + raise ValueError('Empty Dataset provided (no variables). Cannot create plot.') + + # Check if dataset has any data (xarray uses nbytes for total size) + if all(data[var].size == 0 for var in data.data_vars) if len(data.data_vars) > 0 else True: + if not allow_empty and len(data.data_vars) > 0: + raise ValueError('Dataset has zero size. Cannot create plot.') + if len(data.data_vars) == 0: + return # Empty dataset, nothing to validate + return + + # Check for non-numeric data types + for var in data.data_vars: + dtype = data[var].dtype + if not np.issubdtype(dtype, np.number): + raise TypeError( + f"Variable '{var}' has non-numeric dtype '{dtype}'. " + f'Plotting requires numeric data types (int, float, etc.).' + ) + + # Warn about NaN/Inf values + for var in data.data_vars: + if data[var].isnull().any(): + logger.debug(f"Variable '{var}' contains NaN values which may affect visualization.") + if np.isinf(data[var].values).any(): + logger.debug(f"Variable '{var}' contains Inf values which may affect visualization.") + + +def resolve_colors( + data: xr.Dataset, + colors: ColorType | ComponentColorManager, + engine: PlottingEngine = 'plotly', +) -> dict[str, str]: + """Resolve colors parameter to a dict mapping variable names to colors.""" + # Get variable names from Dataset (always strings and unique) + labels = list(data.data_vars.keys()) + + # If explicit dict provided, use it directly + if isinstance(colors, dict): + return colors + + # If string or list, use ColorProcessor (traditional behavior) + if isinstance(colors, (str, list)): + processor = ColorProcessor(engine=engine) + return processor.process_colors(colors, labels, return_mapping=True) + + if isinstance(colors, ComponentColorManager): + # Use color manager to resolve colors for variables + return colors.get_variable_colors(labels) + + raise TypeError(f'Wrong type passed to resolve_colors(): {type(colors)}') + + def with_plotly( - data: pd.DataFrame | xr.DataArray | xr.Dataset, + data: xr.Dataset | pd.DataFrame, mode: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar', - colors: ColorType = 'viridis', + colors: ColorType | ComponentColorManager | None = None, title: str = '', ylabel: str = '', - xlabel: str = 'Time in h', + xlabel: str = '', fig: go.Figure | None = None, facet_by: str | list[str] | None = None, animate_by: str | None = None, - facet_cols: int = 3, + facet_cols: int | None = None, shared_yaxes: bool = True, shared_xaxes: bool = True, + trace_kwargs: dict[str, Any] | None = None, + layout_kwargs: dict[str, Any] | None = None, + **px_kwargs: Any, ) -> go.Figure: """ Plot data with Plotly using facets (subplots) and/or animation for multidimensional data. @@ -347,10 +700,14 @@ def with_plotly( For simple plots without faceting, can optionally add to an existing figure. Args: - data: A DataFrame or xarray DataArray/Dataset to plot. + data: An xarray Dataset to plot. mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for lines, 'area' for stacked area charts, or 'grouped_bar' for grouped bar charts. - colors: Color specification (colormap, list, or dict mapping labels to colors). + colors: Color specification. Can be: + - A colormap name (e.g., 'turbo', 'plasma') + - A list of color strings (e.g., ['#ff0000', '#00ff00']) + - A dict mapping labels to colors (e.g., {'Solar': '#FFD700'}) + - A ComponentColorManager instance for pattern-based color rules with component grouping title: The main title of the plot. ylabel: The label for the y-axis. xlabel: The label for the x-axis. @@ -364,6 +721,12 @@ def with_plotly( facet_cols: Number of columns in the facet grid (used when facet_by is single dimension). shared_yaxes: Whether subplots share y-axes. shared_xaxes: Whether subplots share x-axes. + trace_kwargs: Optional dict of parameters to pass to fig.update_traces(). + Use this to customize trace properties (e.g., marker style, line width). + layout_kwargs: Optional dict of parameters to pass to fig.update_layout(). + Use this to customize layout properties (e.g., width, height, legend position). + **px_kwargs: Additional keyword arguments passed to the underlying Plotly Express function + (px.bar, px.line, px.area). These override default arguments if provided. Returns: A Plotly figure object containing the faceted/animated plot. @@ -372,85 +735,108 @@ def with_plotly( Simple plot: ```python - fig = with_plotly(df, mode='area', title='Energy Mix') + fig = with_plotly(dataset, mode='area', title='Energy Mix') ``` Facet by scenario: ```python - fig = with_plotly(ds, facet_by='scenario', facet_cols=2) + fig = with_plotly(dataset, facet_by='scenario', facet_cols=2) ``` Animate by period: ```python - fig = with_plotly(ds, animate_by='period') + fig = with_plotly(dataset, animate_by='period') ``` Facet and animate: ```python - fig = with_plotly(ds, facet_by='scenario', animate_by='period') + fig = with_plotly(dataset, facet_by='scenario', animate_by='period') + ``` + + Pattern-based colors with ComponentColorManager: + + ```python + manager = ComponentColorManager(['Solar', 'Wind', 'Battery', 'Gas']) + manager.add_grouping_rule('Solar', 'renewables', 'oranges', match_type='prefix') + manager.add_grouping_rule('Wind', 'renewables', 'blues', match_type='prefix') + manager.add_grouping_rule('Battery', 'storage', 'greens', match_type='contains') + manager.apply_colors() + fig = with_plotly(dataset, colors=manager, mode='area') ``` """ + if colors is None: + colors = CONFIG.Plotting.default_qualitative_colorscale + if mode not in ('stacked_bar', 'line', 'area', 'grouped_bar'): raise ValueError(f"'mode' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {mode!r}") + # Apply CONFIG defaults if not explicitly set + if facet_cols is None: + facet_cols = CONFIG.Plotting.default_facet_cols + + # Ensure data is a Dataset and validate it + data = _ensure_dataset(data) + _validate_plotting_data(data, allow_empty=True) + # Handle empty data - if isinstance(data, pd.DataFrame) and data.empty: - return go.Figure() - elif isinstance(data, xr.DataArray) and data.size == 0: - return go.Figure() - elif isinstance(data, xr.Dataset) and len(data.data_vars) == 0: + if len(data.data_vars) == 0: + logger.error('"with_plotly() got an empty Dataset.') return go.Figure() + # Handle all-scalar datasets (where all variables have no dimensions) + # This occurs when all variables are scalar values with dims=() + if all(len(data[var].dims) == 0 for var in data.data_vars): + # Create a simple DataFrame with variable names as x-axis + variables = list(data.data_vars.keys()) + values = [float(data[var].values) for var in data.data_vars] + + # Resolve colors + color_discrete_map = resolve_colors(data, colors, engine='plotly') + marker_colors = [color_discrete_map.get(var, '#636EFA') for var in variables] + + # Create simple plot based on mode using go (not px) for better color control + if mode in ('stacked_bar', 'grouped_bar'): + fig = go.Figure(data=[go.Bar(x=variables, y=values, marker_color=marker_colors)]) + elif mode == 'line': + fig = go.Figure( + data=[ + go.Scatter( + x=variables, + y=values, + mode='lines+markers', + marker=dict(color=marker_colors, size=8), + line=dict(color='lightgray'), + ) + ] + ) + elif mode == 'area': + fig = go.Figure( + data=[ + go.Scatter( + x=variables, + y=values, + fill='tozeroy', + marker=dict(color=marker_colors, size=8), + line=dict(color='lightgray'), + ) + ] + ) + + fig.update_layout(title=title, xaxis_title=xlabel, yaxis_title=ylabel, showlegend=False) + return fig + # Warn if fig parameter is used with faceting if fig is not None and (facet_by is not None or animate_by is not None): logger.warning('The fig parameter is ignored when using faceting or animation. Creating a new figure.') fig = None - # Convert xarray to long-form DataFrame for Plotly Express - if isinstance(data, (xr.DataArray, xr.Dataset)): - # Convert to long-form (tidy) DataFrame - # Structure: time, variable, value, scenario, period, ... (all dims as columns) - if isinstance(data, xr.Dataset): - # Stack all data variables into long format - df_long = data.to_dataframe().reset_index() - # Melt to get: time, scenario, period, ..., variable, value - id_vars = [dim for dim in data.dims] - value_vars = list(data.data_vars) - df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value') - else: - # DataArray - df_long = data.to_dataframe().reset_index() - if data.name: - df_long = df_long.rename(columns={data.name: 'value'}) - else: - # Unnamed DataArray, find the value column - non_dim_cols = [col for col in df_long.columns if col not in data.dims] - if len(non_dim_cols) != 1: - raise ValueError( - f'Expected exactly one non-dimension column for unnamed DataArray, ' - f'but found {len(non_dim_cols)}: {non_dim_cols}' - ) - value_col = non_dim_cols[0] - df_long = df_long.rename(columns={value_col: 'value'}) - df_long['variable'] = data.name or 'data' - else: - # Already a DataFrame - convert to long format for Plotly Express - df_long = data.reset_index() - if 'time' not in df_long.columns: - # First column is probably time - df_long = df_long.rename(columns={df_long.columns[0]: 'time'}) - # Melt to long format - id_vars = [ - col - for col in df_long.columns - if col in ['time', 'scenario', 'period'] - or col in (facet_by if isinstance(facet_by, list) else [facet_by] if facet_by else []) - ] - value_vars = [col for col in df_long.columns if col not in id_vars] - df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value') + # Convert Dataset to long-form DataFrame for Plotly Express + # Structure: time, variable, value, scenario, period, ... (all dims as columns) + dim_names = list(data.dims) + df_long = data.to_dataframe().reset_index().melt(id_vars=dim_names, var_name='variable', value_name='value') # Validate facet_by and animate_by dimensions exist in the data available_dims = [col for col in df_long.columns if col not in ['variable', 'value']] @@ -500,15 +886,38 @@ def with_plotly( else: raise ValueError(f'facet_by can have at most 2 dimensions, got {len(facet_by)}') - # Process colors + # Process colors using resolve_colors (handles validation and all color types) + color_discrete_map = resolve_colors(data, colors, engine='plotly') + + # Get unique variable names for area plot processing all_vars = df_long['variable'].unique().tolist() - processed_colors = ColorProcessor(engine='plotly').process_colors(colors, all_vars) - color_discrete_map = {var: color for var, color in zip(all_vars, processed_colors, strict=True)} + + # Determine which dimension to use for x-axis + # Collect dimensions used for faceting and animation + used_dims = set() + if facet_row: + used_dims.add(facet_row) + if facet_col: + used_dims.add(facet_col) + if animate_by: + used_dims.add(animate_by) + + # Find available dimensions for x-axis (not used for faceting/animation) + x_candidates = [d for d in available_dims if d not in used_dims] + + # Use 'time' if available, otherwise use the first available dimension + if 'time' in x_candidates: + x_dim = 'time' + elif len(x_candidates) > 0: + x_dim = x_candidates[0] + else: + # Fallback: use the first dimension (shouldn't happen in normal cases) + x_dim = available_dims[0] if available_dims else 'time' # Create plot using Plotly Express based on mode common_args = { 'data_frame': df_long, - 'x': 'time', + 'x': x_dim, 'y': 'value', 'color': 'variable', 'facet_row': facet_row, @@ -516,13 +925,16 @@ def with_plotly( 'animation_frame': animate_by, 'color_discrete_map': color_discrete_map, 'title': title, - 'labels': {'value': ylabel, 'time': xlabel, 'variable': ''}, + 'labels': {'value': ylabel, x_dim: xlabel, 'variable': ''}, } # Add facet_col_wrap for single facet dimension if facet_col and not facet_row: common_args['facet_col_wrap'] = facet_cols + # Apply user-provided Plotly Express kwargs (overrides defaults) + common_args.update(px_kwargs) + if mode == 'stacked_bar': fig = px.bar(**common_args) fig.update_traces(marker_line_width=0) @@ -577,50 +989,49 @@ def with_plotly( if hasattr(trace, 'fill'): trace.fill = None - # Update layout with basic styling (Plotly Express handles sizing automatically) - fig.update_layout( - plot_bgcolor='rgba(0,0,0,0)', - paper_bgcolor='rgba(0,0,0,0)', - font=dict(size=12), - ) - # Update axes to share if requested (Plotly Express already handles this, but we can customize) if not shared_yaxes: fig.update_yaxes(matches=None) if not shared_xaxes: fig.update_xaxes(matches=None) + # Apply user-provided trace and layout customizations + if trace_kwargs: + fig.update_traces(**trace_kwargs) + if layout_kwargs: + fig.update_layout(**layout_kwargs) + return fig def with_matplotlib( - data: pd.DataFrame, + data: xr.Dataset | pd.DataFrame, mode: Literal['stacked_bar', 'line'] = 'stacked_bar', - colors: ColorType = 'viridis', + colors: ColorType | ComponentColorManager | None = None, title: str = '', ylabel: str = '', xlabel: str = 'Time in h', figsize: tuple[int, int] = (12, 6), - fig: plt.Figure | None = None, - ax: plt.Axes | None = None, + plot_kwargs: dict[str, Any] | None = None, ) -> tuple[plt.Figure, plt.Axes]: """ - Plot a DataFrame with Matplotlib using stacked bars or stepped lines. + Plot data with Matplotlib using stacked bars or stepped lines. Args: - data: A DataFrame containing the data to plot. The index should represent time (e.g., hours), - and each column represents a separate data series. + data: An xarray Dataset to plot. After conversion to DataFrame, + the index represents time and each column represents a separate data series (variables). mode: Plotting mode. Use 'stacked_bar' for stacked bar charts or 'line' for stepped lines. - colors: Color specification, can be: - - A string with a colormap name (e.g., 'viridis', 'plasma') + colors: Color specification. Can be: + - A colormap name (e.g., 'turbo', 'plasma') - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'}) + - A dict mapping column names to colors (e.g., {'Column1': '#ff0000'}) + - A ComponentColorManager instance for pattern-based color rules with grouping and sorting title: The title of the plot. ylabel: The ylabel of the plot. xlabel: The xlabel of the plot. - figsize: Specify the size of the figure - fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created. - ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created. + figsize: Specify the size of the figure (width, height) in inches. + plot_kwargs: Optional dict of parameters to pass to ax.bar() or ax.step() plotting calls. + Use this to customize plot properties (e.g., linewidth, alpha, edgecolor). Returns: A tuple containing the Matplotlib figure and axes objects used for the plot. @@ -629,49 +1040,118 @@ def with_matplotlib( - If `mode` is 'stacked_bar', bars are stacked for both positive and negative values. Negative values are stacked separately without extra labels in the legend. - If `mode` is 'line', stepped lines are drawn for each data series. + + Examples: + With ComponentColorManager: + + ```python + manager = ComponentColorManager(['Solar', 'Wind', 'Coal']) + manager.add_grouping_rule('Solar', 'renewables', 'oranges', match_type='prefix') + manager.add_grouping_rule('Wind', 'renewables', 'blues', match_type='prefix') + manager.apply_colors() + fig, ax = with_matplotlib(dataset, colors=manager, mode='line') + ``` """ + if colors is None: + colors = CONFIG.Plotting.default_qualitative_colorscale + if mode not in ('stacked_bar', 'line'): raise ValueError(f"'mode' must be one of {{'stacked_bar','line'}} for matplotlib, got {mode!r}") - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + # Ensure data is a Dataset and validate it + data = _ensure_dataset(data) + _validate_plotting_data(data, allow_empty=True) + + # Create new figure and axes + fig, ax = plt.subplots(figsize=figsize) + + # Initialize plot_kwargs if not provided + if plot_kwargs is None: + plot_kwargs = {} + + # Handle all-scalar datasets (where all variables have no dimensions) + # This occurs when all variables are scalar values with dims=() + if all(len(data[var].dims) == 0 for var in data.data_vars): + # Create simple bar/line plot with variable names as x-axis + variables = list(data.data_vars.keys()) + values = [float(data[var].values) for var in data.data_vars] + + # Resolve colors + color_discrete_map = resolve_colors(data, colors, engine='matplotlib') + colors_list = [color_discrete_map.get(var, '#808080') for var in variables] + + # Create plot based on mode + if mode == 'stacked_bar': + ax.bar(variables, values, color=colors_list, **plot_kwargs) + elif mode == 'line': + ax.plot( + variables, + values, + marker='o', + color=colors_list[0] if len(set(colors_list)) == 1 else None, + **plot_kwargs, + ) + # If different colors, plot each point separately + if len(set(colors_list)) > 1: + ax.clear() + for i, (var, val) in enumerate(zip(variables, values, strict=False)): + ax.plot([i], [val], marker='o', color=colors_list[i], label=var, **plot_kwargs) + ax.set_xticks(range(len(variables))) + ax.set_xticklabels(variables) + + ax.set_xlabel(xlabel, ha='center') + ax.set_ylabel(ylabel, va='center') + ax.set_title(title) + ax.grid(color='lightgrey', linestyle='-', linewidth=0.5, axis='y') + fig.tight_layout() + + return fig, ax + + # Resolve colors first (includes validation) + color_discrete_map = resolve_colors(data, colors, engine='matplotlib') - processed_colors = ColorProcessor(engine='matplotlib').process_colors(colors, list(data.columns)) + # Convert Dataset to DataFrame for matplotlib plotting (naturally wide-form) + df = data.to_dataframe() + + # Get colors in column order + processed_colors = [color_discrete_map.get(str(col), '#808080') for col in df.columns] if mode == 'stacked_bar': - cumulative_positive = np.zeros(len(data)) - cumulative_negative = np.zeros(len(data)) - width = data.index.to_series().diff().dropna().min() # Minimum time difference + cumulative_positive = np.zeros(len(df)) + cumulative_negative = np.zeros(len(df)) + width = df.index.to_series().diff().dropna().min() # Minimum time difference - for i, column in enumerate(data.columns): - positive_values = np.clip(data[column], 0, None) # Keep only positive values - negative_values = np.clip(data[column], None, 0) # Keep only negative values + for i, column in enumerate(df.columns): + positive_values = np.clip(df[column], 0, None) # Keep only positive values + negative_values = np.clip(df[column], None, 0) # Keep only negative values # Plot positive bars ax.bar( - data.index, + df.index, positive_values, bottom=cumulative_positive, color=processed_colors[i], label=column, width=width, align='center', + **plot_kwargs, ) cumulative_positive += positive_values.values # Plot negative bars ax.bar( - data.index, + df.index, negative_values, bottom=cumulative_negative, color=processed_colors[i], label='', # No label for negative bars width=width, align='center', + **plot_kwargs, ) cumulative_negative += negative_values.values elif mode == 'line': - for i, column in enumerate(data.columns): - ax.step(data.index, data[column], where='post', color=processed_colors[i], label=column) + for i, column in enumerate(df.columns): + ax.step(df.index, df[column], where='post', color=processed_colors[i], label=column, **plot_kwargs) # Aesthetics ax.set_xlabel(xlabel, ha='center') @@ -945,62 +1425,97 @@ def plot_network( def pie_with_plotly( - data: pd.DataFrame, - colors: ColorType = 'viridis', + data: xr.Dataset | pd.DataFrame, + colors: ColorType | ComponentColorManager | None = None, title: str = '', legend_title: str = '', hole: float = 0.0, fig: go.Figure | None = None, + hover_template: str = '%{label}: %{value} (%{percent})', + text_info: str = 'percent+label+value', + text_position: str = 'inside', ) -> go.Figure: """ - Create a pie chart with Plotly to visualize the proportion of values in a DataFrame. + Create a pie chart with Plotly to visualize the proportion of values in a Dataset. Args: - data: A DataFrame containing the data to plot. If multiple rows exist, - they will be summed unless a specific index value is passed. + data: An xarray Dataset containing the data to plot. All dimensions will be summed + to get the total for each variable. colors: Color specification, can be: - - A string with a colorscale name (e.g., 'viridis', 'plasma') + - A string with a colorscale name (e.g., 'turbo', 'plasma') - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'}) + - A dictionary mapping variable names to colors (e.g., {'Solar': '#ff0000'}) + - A ComponentColorManager instance for pattern-based color rules title: The title of the plot. legend_title: The title for the legend. hole: Size of the hole in the center for creating a donut chart (0.0 to 1.0). fig: A Plotly figure object to plot on. If not provided, a new figure will be created. + hover_template: Template for hover text. Use %{label}, %{value}, %{percent}. + text_info: What to show on pie segments: 'label', 'percent', 'value', 'label+percent', + 'label+value', 'percent+value', 'label+percent+value', or 'none'. + text_position: Position of text: 'inside', 'outside', 'auto', or 'none'. Returns: A Plotly figure object containing the generated pie chart. Notes: - Negative values are not appropriate for pie charts and will be converted to absolute values with a warning. - - If the data contains very small values (less than 1% of the total), they can be grouped into an "Other" category - for better readability. - - By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing. + - All dimensions are summed to get total values for each variable. + - Scalar variables (with no dimensions) are used directly. + + Examples: + Simple pie chart: + + ```python + fig = pie_with_plotly(dataset, colors='turbo', title='Energy Mix') + ``` + + With ComponentColorManager: + ```python + manager = ComponentColorManager(['Solar', 'Wind', 'Coal']) + manager.add_grouping_rule('Solar', 'renewables', 'oranges', match_type='prefix') + manager.add_grouping_rule('Wind', 'renewables', 'blues', match_type='prefix') + manager.apply_colors() + fig = pie_with_plotly(dataset, colors=manager, title='Renewable Energy') + ``` """ - if data.empty: - logger.error('Empty DataFrame provided for pie chart. Returning empty figure.') + if colors is None: + colors = CONFIG.Plotting.default_qualitative_colorscale + + # Ensure data is a Dataset and validate it + data = _ensure_dataset(data) + _validate_plotting_data(data, allow_empty=True) + + if len(data.data_vars) == 0: + logger.error('Empty Dataset provided for pie chart. Returning empty figure.') return go.Figure() - # Create a copy to avoid modifying the original DataFrame - data_copy = data.copy() + # Sum all dimensions for each variable to get total values + labels = [] + values = [] - # Check if any negative values and warn - if (data_copy < 0).any().any(): - logger.error('Negative values detected in data. Using absolute values for pie chart.') - data_copy = data_copy.abs() + for var in data.data_vars: + var_data = data[var] - # If data has multiple rows, sum them to get total for each column - if len(data_copy) > 1: - data_sum = data_copy.sum() - else: - data_sum = data_copy.iloc[0] + # Sum across all dimensions to get total + if len(var_data.dims) > 0: + total_value = var_data.sum().item() + else: + # Scalar variable + total_value = var_data.item() + + # Check for negative values + if total_value < 0: + logger.warning(f'Negative value detected for {var}: {total_value}. Using absolute value.') + total_value = abs(total_value) - # Get labels (column names) and values - labels = data_sum.index.tolist() - values = data_sum.values.tolist() + labels.append(str(var)) + values.append(total_value) - # Apply color mapping using the unified color processor - processed_colors = ColorProcessor(engine='plotly').process_colors(colors, labels) + # Use resolve_colors for consistent color handling + color_discrete_map = resolve_colors(data, colors, engine='plotly') + processed_colors = [color_discrete_map.get(label, '#636EFA') for label in labels] # Create figure if not provided fig = fig if fig is not None else go.Figure() @@ -1012,91 +1527,111 @@ def pie_with_plotly( values=values, hole=hole, marker=dict(colors=processed_colors), - textinfo='percent+label+value', - textposition='inside', + textinfo=text_info, + textposition=text_position, insidetextorientation='radial', + hovertemplate=hover_template, ) ) - # Update layout for better aesthetics + # Update layout with plot-specific properties fig.update_layout( title=title, legend_title=legend_title, - plot_bgcolor='rgba(0,0,0,0)', # Transparent background - paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background - font=dict(size=14), # Increase font size for better readability ) return fig def pie_with_matplotlib( - data: pd.DataFrame, - colors: ColorType = 'viridis', + data: xr.Dataset | pd.DataFrame, + colors: ColorType | ComponentColorManager | None = None, title: str = '', legend_title: str = 'Categories', hole: float = 0.0, figsize: tuple[int, int] = (10, 8), - fig: plt.Figure | None = None, - ax: plt.Axes | None = None, ) -> tuple[plt.Figure, plt.Axes]: """ - Create a pie chart with Matplotlib to visualize the proportion of values in a DataFrame. + Create a pie chart with Matplotlib to visualize the proportion of values in a Dataset. Args: - data: A DataFrame containing the data to plot. If multiple rows exist, - they will be summed unless a specific index value is passed. + data: An xarray Dataset containing the data to plot. All dimensions will be summed + to get the total for each variable. colors: Color specification, can be: - - A string with a colormap name (e.g., 'viridis', 'plasma') + - A string with a colormap name (e.g., 'turbo', 'plasma') - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'}) + - A dictionary mapping variable names to colors (e.g., {'Solar': '#ff0000'}) + - A ComponentColorManager instance for pattern-based color rules title: The title of the plot. legend_title: The title for the legend. hole: Size of the hole in the center for creating a donut chart (0.0 to 1.0). figsize: The size of the figure (width, height) in inches. - fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created. - ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created. Returns: A tuple containing the Matplotlib figure and axes objects used for the plot. Notes: - Negative values are not appropriate for pie charts and will be converted to absolute values with a warning. - - If the data contains very small values (less than 1% of the total), they can be grouped into an "Other" category - for better readability. - - By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing. + - All dimensions are summed to get total values for each variable. + - Scalar variables (with no dimensions) are used directly. + + Examples: + Simple pie chart: + ```python + fig, ax = pie_with_matplotlib(dataset, colors='turbo', title='Energy Mix') + ``` + + With ComponentColorManager: + + ```python + manager = ComponentColorManager(['Solar', 'Wind', 'Coal']) + manager.add_grouping_rule('Solar', 'renewables', 'oranges', match_type='prefix') + manager.add_grouping_rule('Wind', 'renewables', 'blues', match_type='prefix') + manager.apply_colors() + fig, ax = pie_with_matplotlib(dataset, colors=manager, title='Renewable Energy') + ``` """ - if data.empty: - logger.error('Empty DataFrame provided for pie chart. Returning empty figure.') - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + if colors is None: + colors = CONFIG.Plotting.default_qualitative_colorscale + + # Ensure data is a Dataset and validate it + data = _ensure_dataset(data) + _validate_plotting_data(data, allow_empty=True) + + if len(data.data_vars) == 0: + logger.error('Empty Dataset provided for pie chart. Returning empty figure.') + fig, ax = plt.subplots(figsize=figsize) return fig, ax - # Create a copy to avoid modifying the original DataFrame - data_copy = data.copy() + # Sum all dimensions for each variable to get total values + labels = [] + values = [] - # Check if any negative values and warn - if (data_copy < 0).any().any(): - logger.error('Negative values detected in data. Using absolute values for pie chart.') - data_copy = data_copy.abs() + for var in data.data_vars: + var_data = data[var] - # If data has multiple rows, sum them to get total for each column - if len(data_copy) > 1: - data_sum = data_copy.sum() - else: - data_sum = data_copy.iloc[0] + # Sum across all dimensions to get total + if len(var_data.dims) > 0: + total_value = var_data.sum().item() + else: + # Scalar variable + total_value = var_data.item() - # Get labels (column names) and values - labels = data_sum.index.tolist() - values = data_sum.values.tolist() + # Check for negative values + if total_value < 0: + logger.warning(f'Negative value detected for {var}: {total_value}. Using absolute value.') + total_value = abs(total_value) - # Apply color mapping using the unified color processor - processed_colors = ColorProcessor(engine='matplotlib').process_colors(colors, labels) + labels.append(str(var)) + values.append(total_value) - # Create figure and axis if not provided - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + # Use resolve_colors for consistent color handling + color_discrete_map = resolve_colors(data, colors, engine='matplotlib') + processed_colors = [color_discrete_map.get(label, '#808080') for label in labels] + + # Create figure and axis + fig, ax = plt.subplots(figsize=figsize) # Draw the pie chart wedges, texts, autotexts = ax.pie( @@ -1144,9 +1679,9 @@ def pie_with_matplotlib( def dual_pie_with_plotly( - data_left: pd.Series, - data_right: pd.Series, - colors: ColorType = 'viridis', + data_left: xr.Dataset | pd.DataFrame, + data_right: xr.Dataset | pd.DataFrame, + colors: ColorType | ComponentColorManager | None = None, title: str = '', subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'), legend_title: str = '', @@ -1160,12 +1695,13 @@ def dual_pie_with_plotly( Create two pie charts side by side with Plotly, with consistent coloring across both charts. Args: - data_left: Series for the left pie chart. - data_right: Series for the right pie chart. + data_left: Dataset for the left pie chart. Variables are summed across all dimensions. + data_right: Dataset for the right pie chart. Variables are summed across all dimensions. colors: Color specification, can be: - - A string with a colorscale name (e.g., 'viridis', 'plasma') + - A string with a colorscale name (e.g., 'turbo', 'plasma') - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping category names to colors (e.g., {'Category1': '#ff0000'}) + - A dictionary mapping variable names to colors (e.g., {'Solar': '#ff0000'}) + - A ComponentColorManager instance for pattern-based color rules title: The main title of the plot. subtitles: Tuple containing the subtitles for (left, right) charts. legend_title: The title for the legend. @@ -1179,10 +1715,19 @@ def dual_pie_with_plotly( Returns: A Plotly figure object containing the generated dual pie chart. """ + if colors is None: + colors = CONFIG.Plotting.default_qualitative_colorscale + from plotly.subplots import make_subplots + # Ensure data is a Dataset and validate it + data_left = _ensure_dataset(data_left) + data_right = _ensure_dataset(data_right) + _validate_plotting_data(data_left, allow_empty=True) + _validate_plotting_data(data_right, allow_empty=True) + # Check for empty data - if data_left.empty and data_right.empty: + if len(data_left.data_vars) == 0 and len(data_right.data_vars) == 0: logger.error('Both datasets are empty. Returning empty figure.') return go.Figure() @@ -1191,71 +1736,52 @@ def dual_pie_with_plotly( rows=1, cols=2, specs=[[{'type': 'pie'}, {'type': 'pie'}]], subplot_titles=subtitles, horizontal_spacing=0.05 ) - # Process series to handle negative values and apply minimum percentage threshold - def preprocess_series(series: pd.Series): - """ - Preprocess a series for pie chart display by handling negative values - and grouping the smallest parts together if they collectively represent - less than the specified percentage threshold. - - Args: - series: The series to preprocess - - Returns: - A preprocessed pandas Series - """ - # Handle negative values - if (series < 0).any(): - logger.error('Negative values detected in data. Using absolute values for pie chart.') - series = series.abs() + # Helper function to extract labels and values from Dataset + def dataset_to_pie_data(dataset): + labels = [] + values = [] - # Remove zeros - series = series[series > 0] + for var in dataset.data_vars: + var_data = dataset[var] - # Apply minimum percentage threshold if needed - if lower_percentage_group and not series.empty: - total = series.sum() - if total > 0: - # Sort series by value (ascending) - sorted_series = series.sort_values() - - # Calculate cumulative percentage contribution - cumulative_percent = (sorted_series.cumsum() / total) * 100 - - # Find entries that collectively make up less than lower_percentage_group - to_group = cumulative_percent <= lower_percentage_group - - if to_group.sum() > 1: - # Create "Other" category for the smallest values that together are < threshold - other_sum = sorted_series[to_group].sum() - - # Keep only values that aren't in the "Other" group - result_series = series[~series.index.isin(sorted_series[to_group].index)] + # Sum across all dimensions + if len(var_data.dims) > 0: + total_value = float(var_data.sum().values) + else: + total_value = float(var_data.values) - # Add the "Other" category if it has a value - if other_sum > 0: - result_series['Other'] = other_sum + # Handle negative values + if total_value < 0: + logger.warning(f'Negative value for {var}: {total_value}. Using absolute value.') + total_value = abs(total_value) - return result_series + # Only include if value > 0 + if total_value > 0: + labels.append(str(var)) + values.append(total_value) - return series + return labels, values - data_left_processed = preprocess_series(data_left) - data_right_processed = preprocess_series(data_right) + # Get data for left and right + left_labels, left_values = dataset_to_pie_data(data_left) + right_labels, right_values = dataset_to_pie_data(data_right) - # Get unique set of all labels for consistent coloring - all_labels = sorted(set(data_left_processed.index) | set(data_right_processed.index)) + # Get unique set of all labels for consistent coloring across both pies + # Merge both datasets for color resolution + combined_vars = list(set(data_left.data_vars) | set(data_right.data_vars)) + combined_ds = xr.Dataset( + {var: data_left[var] if var in data_left.data_vars else data_right[var] for var in combined_vars} + ) - # Get consistent color mapping for both charts using our unified function - color_map = ColorProcessor(engine='plotly').process_colors(colors, all_labels, return_mapping=True) + # Use resolve_colors for consistent color handling + color_discrete_map = resolve_colors(combined_ds, colors, engine='plotly') + color_map = {label: color_discrete_map.get(label, '#636EFA') for label in left_labels + right_labels} # Function to create a pie trace with consistently mapped colors - def create_pie_trace(data_series, side): - if data_series.empty: + def create_pie_trace(labels, values, side): + if not labels: return None - labels = data_series.index.tolist() - values = data_series.values.tolist() trace_colors = [color_map[label] for label in labels] return go.Pie( @@ -1272,24 +1798,21 @@ def create_pie_trace(data_series, side): ) # Add left pie if data exists - left_trace = create_pie_trace(data_left_processed, subtitles[0]) + left_trace = create_pie_trace(left_labels, left_values, subtitles[0]) if left_trace: left_trace.domain = dict(x=[0, 0.48]) fig.add_trace(left_trace, row=1, col=1) # Add right pie if data exists - right_trace = create_pie_trace(data_right_processed, subtitles[1]) + right_trace = create_pie_trace(right_labels, right_values, subtitles[1]) if right_trace: right_trace.domain = dict(x=[0.52, 1]) fig.add_trace(right_trace, row=1, col=2) - # Update layout + # Update layout with plot-specific properties fig.update_layout( title=title, legend_title=legend_title, - plot_bgcolor='rgba(0,0,0,0)', # Transparent background - paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background - font=dict(size=14), margin=dict(t=80, b=50, l=30, r=30), ) @@ -1299,25 +1822,22 @@ def create_pie_trace(data_series, side): def dual_pie_with_matplotlib( data_left: pd.Series, data_right: pd.Series, - colors: ColorType = 'viridis', + colors: ColorType | None = None, title: str = '', subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'), legend_title: str = '', hole: float = 0.2, lower_percentage_group: float = 5.0, figsize: tuple[int, int] = (14, 7), - fig: plt.Figure | None = None, - axes: list[plt.Axes] | None = None, ) -> tuple[plt.Figure, list[plt.Axes]]: """ Create two pie charts side by side with Matplotlib, with consistent coloring across both charts. - Leverages the existing pie_with_matplotlib function. Args: data_left: Series for the left pie chart. data_right: Series for the right pie chart. colors: Color specification, can be: - - A string with a colormap name (e.g., 'viridis', 'plasma') + - A string with a colormap name (e.g., 'turbo', 'plasma') - A list of color strings (e.g., ['#ff0000', '#00ff00']) - A dictionary mapping category names to colors (e.g., {'Category1': '#ff0000'}) title: The main title of the plot. @@ -1326,23 +1846,21 @@ def dual_pie_with_matplotlib( hole: Size of the hole in the center for creating donut charts (0.0 to 1.0). lower_percentage_group: Whether to group small segments (below percentage) into an "Other" category. figsize: The size of the figure (width, height) in inches. - fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created. - axes: A list of Matplotlib axes objects to plot on. If not provided, new axes will be created. Returns: A tuple containing the Matplotlib figure and list of axes objects used for the plot. """ + if colors is None: + colors = CONFIG.Plotting.default_qualitative_colorscale + + # Create figure and axes + fig, axes = plt.subplots(1, 2, figsize=figsize) + # Check for empty data if data_left.empty and data_right.empty: logger.error('Both datasets are empty. Returning empty figure.') - if fig is None: - fig, axes = plt.subplots(1, 2, figsize=figsize) return fig, axes - # Create figure and axes if not provided - if fig is None or axes is None: - fig, axes = plt.subplots(1, 2, figsize=figsize) - # Process series to handle negative values and apply minimum percentage threshold def preprocess_series(series: pd.Series): """ @@ -1404,19 +1922,49 @@ def preprocess_series(series: pd.Series): left_colors = [color_map[col] for col in df_left.columns] if not df_left.empty else [] right_colors = [color_map[col] for col in df_right.columns] if not df_right.empty else [] + # Helper function to draw pie chart on a specific axis + def draw_pie_on_axis(ax, data_series, colors_list, subtitle, hole_size): + """Draw a pie chart on a specific matplotlib axis.""" + if data_series.empty: + ax.set_title(subtitle) + ax.axis('off') + return + + labels = list(data_series.index) + values = list(data_series.values) + + # Draw the pie chart + wedges, texts, autotexts = ax.pie( + values, + labels=labels, + colors=colors_list, + autopct='%1.1f%%', + startangle=90, + shadow=False, + wedgeprops=dict(width=0.5) if hole_size > 0 else None, + ) + + # Adjust hole size + if hole_size > 0: + wedge_width = 1 - hole_size + for wedge in wedges: + wedge.set_width(wedge_width) + + # Customize text + for autotext in autotexts: + autotext.set_fontsize(10) + autotext.set_color('white') + + # Set aspect ratio and title + ax.set_aspect('equal') + if subtitle: + ax.set_title(subtitle, fontsize=14) + # Create left pie chart - if not df_left.empty: - pie_with_matplotlib(data=df_left, colors=left_colors, title=subtitles[0], hole=hole, fig=fig, ax=axes[0]) - else: - axes[0].set_title(subtitles[0]) - axes[0].axis('off') + draw_pie_on_axis(axes[0], data_left_processed, left_colors, subtitles[0], hole) # Create right pie chart - if not df_right.empty: - pie_with_matplotlib(data=df_right, colors=right_colors, title=subtitles[1], hole=hole, fig=fig, ax=axes[1]) - else: - axes[1].set_title(subtitles[1]) - axes[1].axis('off') + draw_pie_on_axis(axes[1], data_right_processed, right_colors, subtitles[1], hole) # Add main title if title: @@ -1460,15 +2008,16 @@ def preprocess_series(series: pd.Series): def heatmap_with_plotly( data: xr.DataArray, - colors: ColorType = 'viridis', + colors: ColorType | None = None, title: str = '', facet_by: str | list[str] | None = None, animate_by: str | None = None, - facet_cols: int = 3, + facet_cols: int | None = None, reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] | Literal['auto'] | None = 'auto', fill: Literal['ffill', 'bfill'] | None = 'ffill', + **imshow_kwargs: Any, ) -> go.Figure: """ Plot a heatmap visualization using Plotly's imshow with faceting and animation support. @@ -1487,7 +2036,7 @@ def heatmap_with_plotly( data: An xarray DataArray containing the data to visualize. Should have at least 2 dimensions, or a 'time' dimension that can be reshaped into 2D. colors: Color specification (colormap name, list, or dict). Common options: - 'viridis', 'plasma', 'RdBu', 'portland'. + 'turbo', 'plasma', 'RdBu', 'portland'. title: The main title of the heatmap. facet_by: Dimension to create facets for. Creates a subplot grid. Can be a single dimension name or list (only first dimension used). @@ -1501,6 +2050,11 @@ def heatmap_with_plotly( - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours) - None: Disable time reshaping (will error if only 1D time data) fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'. + **imshow_kwargs: Additional keyword arguments to pass to plotly.express.imshow. + Common options include: + - aspect: 'auto', 'equal', or a number for aspect ratio + - zmin, zmax: Minimum and maximum values for color scale + - labels: Dict to customize axis labels Returns: A Plotly figure object containing the heatmap visualization. @@ -1538,6 +2092,13 @@ def heatmap_with_plotly( fig = heatmap_with_plotly(data_array, facet_by='scenario', animate_by='period', reshape_time=('W', 'D')) ``` """ + if colors is None: + colors = CONFIG.Plotting.default_sequential_colorscale + + # Apply CONFIG defaults if not explicitly set + if facet_cols is None: + facet_cols = CONFIG.Plotting.default_facet_cols + # Handle empty data if data.size == 0: return go.Figure() @@ -1589,12 +2150,26 @@ def heatmap_with_plotly( heatmap_dims = [dim for dim in available_dims if dim not in facet_dims] if len(heatmap_dims) < 2: - # Need at least 2 dimensions for a heatmap - logger.error( - f'Heatmap requires at least 2 dimensions for rows and columns. ' - f'After faceting/animation, only {len(heatmap_dims)} dimension(s) remain: {heatmap_dims}' - ) - return go.Figure() + # Handle single-dimension case by adding variable name as a dimension + if len(heatmap_dims) == 1: + # Get the variable name, or use a default + var_name = data.name if data.name else 'value' + + # Expand the DataArray by adding a new dimension with the variable name + data = data.expand_dims({'variable': [var_name]}) + + # Update available dimensions + available_dims = list(data.dims) + heatmap_dims = [dim for dim in available_dims if dim not in facet_dims] + + logger.debug(f'Only 1 dimension remaining for heatmap. Added variable dimension: {var_name}') + else: + # No dimensions at all - cannot create a heatmap + logger.error( + f'Heatmap requires at least 1 dimension. ' + f'After faceting/animation, {len(heatmap_dims)} dimension(s) remain: {heatmap_dims}' + ) + return go.Figure() # Setup faceting parameters for Plotly Express # Note: px.imshow only supports facet_col, not facet_row @@ -1617,7 +2192,7 @@ def heatmap_with_plotly( # Create the imshow plot - px.imshow can work directly with xarray DataArrays common_args = { 'img': data, - 'color_continuous_scale': colors if isinstance(colors, str) else 'viridis', + 'color_continuous_scale': colors if isinstance(colors, str) else CONFIG.Plotting.default_sequential_colorscale, 'title': title, } @@ -1631,38 +2206,41 @@ def heatmap_with_plotly( if animate_by: common_args['animation_frame'] = animate_by + # Merge in additional imshow kwargs + common_args.update(imshow_kwargs) + try: fig = px.imshow(**common_args) except Exception as e: logger.error(f'Error creating imshow plot: {e}. Falling back to basic heatmap.') # Fallback: create a simple heatmap without faceting - fig = px.imshow( - data.values, - color_continuous_scale=colors if isinstance(colors, str) else 'viridis', - title=title, - ) - - # Update layout with basic styling - fig.update_layout( - plot_bgcolor='rgba(0,0,0,0)', - paper_bgcolor='rgba(0,0,0,0)', - font=dict(size=12), - ) + fallback_args = { + 'img': data.values, + 'color_continuous_scale': colors + if isinstance(colors, str) + else CONFIG.Plotting.default_sequential_colorscale, + 'title': title, + } + fallback_args.update(imshow_kwargs) + fig = px.imshow(**fallback_args) return fig def heatmap_with_matplotlib( data: xr.DataArray, - colors: ColorType = 'viridis', + colors: ColorType | None = None, title: str = '', figsize: tuple[float, float] = (12, 6), - fig: plt.Figure | None = None, - ax: plt.Axes | None = None, reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] | Literal['auto'] | None = 'auto', fill: Literal['ffill', 'bfill'] | None = 'ffill', + vmin: float | None = None, + vmax: float | None = None, + imshow_kwargs: dict[str, Any] | None = None, + cbar_kwargs: dict[str, Any] | None = None, + **kwargs: Any, ) -> tuple[plt.Figure, plt.Axes]: """ Plot a heatmap visualization using Matplotlib's imshow. @@ -1674,16 +2252,25 @@ def heatmap_with_matplotlib( data: An xarray DataArray containing the data to visualize. Should have at least 2 dimensions. If more than 2 dimensions exist, additional dimensions will be reduced by taking the first slice. - colors: Color specification. Should be a colormap name (e.g., 'viridis', 'RdBu'). + colors: Color specification. Should be a colormap name (e.g., 'turbo', 'RdBu'). title: The title of the heatmap. figsize: The size of the figure (width, height) in inches. - fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created. - ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created. reshape_time: Time reshaping configuration: - 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours) - None: Disable time reshaping fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'. + vmin: Minimum value for color scale. If None, uses data minimum. + vmax: Maximum value for color scale. If None, uses data maximum. + imshow_kwargs: Optional dict of parameters to pass to ax.imshow(). + Use this to customize image properties (e.g., interpolation, aspect). + cbar_kwargs: Optional dict of parameters to pass to plt.colorbar(). + Use this to customize colorbar properties (e.g., orientation, label). + **kwargs: Additional keyword arguments passed to ax.imshow(). + Common options include: + - interpolation: 'nearest', 'bilinear', 'bicubic', etc. + - alpha: Transparency level (0-1) + - extent: [left, right, bottom, top] for axis limits Returns: A tuple containing the Matplotlib figure and axes objects used for the plot. @@ -1705,19 +2292,36 @@ def heatmap_with_matplotlib( fig, ax = heatmap_with_matplotlib(data_array, reshape_time=('D', 'h')) ``` """ + if colors is None: + colors = CONFIG.Plotting.default_sequential_colorscale + + # Initialize kwargs if not provided + if imshow_kwargs is None: + imshow_kwargs = {} + if cbar_kwargs is None: + cbar_kwargs = {} + + # Merge any additional kwargs into imshow_kwargs + # This allows users to pass imshow options directly + imshow_kwargs.update(kwargs) + # Handle empty data if data.size == 0: - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + fig, ax = plt.subplots(figsize=figsize) return fig, ax # Apply time reshaping using the new unified function # Matplotlib doesn't support faceting/animation, so we pass None for those data = reshape_data_for_heatmap(data, reshape_time=reshape_time, facet_by=None, animate_by=None, fill=fill) - # Create figure and axes if not provided - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + # Handle single-dimension case by adding variable name as a dimension + if isinstance(data, xr.DataArray) and len(data.dims) == 1: + var_name = data.name if data.name else 'value' + data = data.expand_dims({'variable': [var_name]}) + logger.debug(f'Only 1 dimension in data. Added variable dimension: {var_name}') + + # Create figure and axes + fig, ax = plt.subplots(figsize=figsize) # Extract data values # If data has more than 2 dimensions, we need to reduce it @@ -1743,14 +2347,21 @@ def heatmap_with_matplotlib( y_labels = 'y' # Process colormap - cmap = colors if isinstance(colors, str) else 'viridis' + cmap = colors if isinstance(colors, str) else CONFIG.Plotting.default_sequential_colorscale + + # Create the heatmap using imshow with user customizations + imshow_defaults = {'cmap': cmap, 'aspect': 'auto', 'origin': 'upper', 'vmin': vmin, 'vmax': vmax} + imshow_defaults.update(imshow_kwargs) # User kwargs override defaults + im = ax.imshow(values, **imshow_defaults) - # Create the heatmap using imshow - im = ax.imshow(values, cmap=cmap, aspect='auto', origin='upper') + # Add colorbar with user customizations + cbar_defaults = {'ax': ax, 'orientation': 'horizontal', 'pad': 0.1, 'aspect': 15, 'fraction': 0.05} + cbar_defaults.update(cbar_kwargs) # User kwargs override defaults + cbar = plt.colorbar(im, **cbar_defaults) - # Add colorbar - cbar = plt.colorbar(im, ax=ax, orientation='horizontal', pad=0.1, aspect=15, fraction=0.05) - cbar.set_label('Value') + # Set colorbar label if not overridden by user + if 'label' not in cbar_kwargs: + cbar.set_label('Value') # Set labels and title ax.set_xlabel(str(x_labels).capitalize()) @@ -1768,8 +2379,9 @@ def export_figure( default_path: pathlib.Path, default_filetype: str | None = None, user_path: pathlib.Path | None = None, - show: bool = True, + show: bool | None = None, save: bool = False, + dpi: int | None = None, ) -> go.Figure | tuple[plt.Figure, plt.Axes]: """ Export a figure to a file and or show it. @@ -1779,13 +2391,21 @@ def export_figure( default_path: The default file path if no user filename is provided. default_filetype: The default filetype if the path doesnt end with a filetype. user_path: An optional user-specified file path. - show: Whether to display the figure (default: True). + show: Whether to display the figure. If None, uses CONFIG.Plotting.default_show (default: None). save: Whether to save the figure (default: False). + dpi: DPI (dots per inch) for saving Matplotlib figures. If None, uses CONFIG.Plotting.default_dpi. Raises: ValueError: If no default filetype is provided and the path doesn't specify a filetype. TypeError: If the figure type is not supported. """ + # Apply CONFIG defaults if not explicitly set + if show is None: + show = CONFIG.Plotting.default_show + + if dpi is None: + dpi = CONFIG.Plotting.default_dpi + filename = user_path or default_path filename = filename.with_name(filename.name.replace('|', '__')) if filename.suffix == '': @@ -1795,30 +2415,32 @@ def export_figure( if isinstance(figure_like, plotly.graph_objs.Figure): fig = figure_like + + # Apply default dimensions if configured + layout_updates = {} + if CONFIG.Plotting.default_figure_width is not None: + layout_updates['width'] = CONFIG.Plotting.default_figure_width + if CONFIG.Plotting.default_figure_height is not None: + layout_updates['height'] = CONFIG.Plotting.default_figure_height + if layout_updates: + fig.update_layout(**layout_updates) + if filename.suffix != '.html': logger.warning(f'To save a Plotly figure, using .html. Adjusting suffix for {filename}') filename = filename.with_suffix('.html') try: - is_test_env = 'PYTEST_CURRENT_TEST' in os.environ - - if is_test_env: - # Test environment: never open browser, only save if requested - if save: - fig.write_html(str(filename)) - # Ignore show flag in tests - else: - # Production environment: respect show and save flags - if save and show: - # Save and auto-open in browser - plotly.offline.plot(fig, filename=str(filename)) - elif save and not show: - # Save without opening - fig.write_html(str(filename)) - elif show and not save: - # Show interactively without saving - fig.show() - # If neither save nor show: do nothing + # Respect show and save flags (tests should set CONFIG.Plotting.default_show=False) + if save and show: + # Save and auto-open in browser + plotly.offline.plot(fig, filename=str(filename)) + elif save and not show: + # Save without opening + fig.write_html(str(filename)) + elif show and not save: + # Show interactively without saving + fig.show() + # If neither save nor show: do nothing finally: # Cleanup to prevent socket warnings if hasattr(fig, '_renderer'): @@ -1829,16 +2451,15 @@ def export_figure( elif isinstance(figure_like, tuple): fig, ax = figure_like if show: - # Only show if using interactive backend and not in test environment + # Only show if using interactive backend (tests should set CONFIG.Plotting.default_show=False) backend = matplotlib.get_backend().lower() is_interactive = backend not in {'agg', 'pdf', 'ps', 'svg', 'template'} - is_test_env = 'PYTEST_CURRENT_TEST' in os.environ - if is_interactive and not is_test_env: + if is_interactive: plt.show() if save: - fig.savefig(str(filename), dpi=300) + fig.savefig(str(filename), dpi=dpi) plt.close(fig) # Close figure to free memory return fig, ax diff --git a/flixopt/results.py b/flixopt/results.py index 75f8f300e..abc71457c 100644 --- a/flixopt/results.py +++ b/flixopt/results.py @@ -15,6 +15,7 @@ from . import io as fx_io from . import plotting +from .config import CONFIG from .flow_system import FlowSystem if TYPE_CHECKING: @@ -69,6 +70,10 @@ class CalculationResults: effects: Dictionary mapping effect names to EffectResults objects timesteps_extra: Extended time index including boundary conditions hours_per_timestep: Duration of each timestep for proper energy calculations + color_manager: Optional ComponentColorManager for automatic component-based coloring in plots. + When set, all plotting methods automatically use this manager when colors='auto' + (the default). Use `setup_colors()` to create and configure one, or assign + an existing manager directly. Set to None to disable automatic coloring. Examples: Load and analyze saved results: @@ -107,6 +112,23 @@ class CalculationResults: ).mean() ``` + Configure automatic color management for plots: + + ```python + # Dict-based configuration (simplest): + results.setup_colors({'Solar*': 'oranges', 'Wind*': 'blues', 'Battery': 'greens'}) + + # Or programmatically: + results.setup_colors().add_rule('Solar*', 'oranges').add_rule('Wind*', 'blues') + + # All plots automatically use configured colors (colors='auto' is the default) + results['ElectricityBus'].plot_node_balance() + results['Battery'].plot_charge_state() + + # Override when needed + results['ElectricityBus'].plot_node_balance(colors='turbo') # Ignores mapper + ``` + Design Patterns: **Factory Methods**: Use `from_file()` and `from_calculation()` for creation or access directly from `Calculation.results` **Dictionary Access**: Use `results[element_label]` for element-specific results @@ -240,6 +262,9 @@ def __init__( self._sizes = None self._effects_per_component = None + # Color manager for intelligent plot coloring - None by default, user configures explicitly + self.color_manager: plotting.ComponentColorManager | None = None + def __getitem__(self, key: str) -> ComponentResults | BusResults | EffectResults: if key in self.components: return self.components[key] @@ -306,6 +331,72 @@ def flow_system(self) -> FlowSystem: logger.level = old_level return self._flow_system + def setup_colors( + self, + config: dict[str, str | list[str]] | str | pathlib.Path | None = None, + default_colorscale: str | None = None, + ) -> plotting.ComponentColorManager: + """Initialize and return a ColorManager for configuring plot colors. + + Convenience method that creates a ComponentColorManager with all components + registered and assigns it to `self.color_manager`. Optionally load configuration + from a dict or file. + + Args: + config: Optional color configuration: + - dict: Mixed {component: color} or {colorscale: [components]} mapping + - str/Path: Path to YAML file + - None: Create empty manager for manual config (default) + default_colorscale: Optional default colorscale to use. Defaults to CONFIG.Plotting.default_default_qualitative_colorscale + + Returns: + ComponentColorManager instance ready for configuration. + + Examples: + Dict-based configuration (mixed direct + grouped): + + ```python + results.setup_colors( + { + # Direct colors + 'Boiler1': '#FF0000', + 'CHP': 'darkred', + # Grouped colors + 'oranges': ['Solar1', 'Solar2'], + 'blues': ['Wind1', 'Wind2'], + 'greens': ['Battery1', 'Battery2', 'Battery3'], + } + ) + results['ElectricityBus'].plot_node_balance() + ``` + + Load from YAML file: + + ```python + # colors.yaml contains: + # Boiler1: '#FF0000' + # oranges: + # - Solar1 + # - Solar2 + results.setup_colors('colors.yaml') + ``` + + Disable automatic coloring: + + ```python + results.color_manager = None # Plots use default colorscales + ``` + """ + self.color_manager = plotting.ComponentColorManager.from_flow_system( + self.flow_system, default_colorscale=default_colorscale + ) + + # Apply configuration if provided + if config is not None: + self.color_manager.configure(config) + + return self.color_manager + def filter_solution( self, variable_dims: Literal['scalar', 'time', 'scenario', 'timeonly', 'scenarioonly'] | None = None, @@ -705,13 +796,13 @@ def plot_heatmap( self, variable_name: str | list[str], save: bool | pathlib.Path = False, - show: bool = True, - colors: plotting.ColorType = 'viridis', + show: bool | None = None, + colors: plotting.ColorType | None = None, engine: plotting.PlottingEngine = 'plotly', select: dict[FlowSystemDimensions, Any] | None = None, facet_by: str | list[str] | None = 'scenario', animate_by: str | None = 'period', - facet_cols: int = 3, + facet_cols: int | None = None, reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] | Literal['auto'] | None = 'auto', @@ -721,6 +812,7 @@ def plot_heatmap( heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None, heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None, color_map: str | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]: """ Plots a heatmap visualization of a variable using imshow or time-based reshaping. @@ -737,7 +829,8 @@ def plot_heatmap( with a new 'variable' dimension. save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location. show: Whether to show the plot or not. - colors: Color scheme for the heatmap. See `flixopt.plotting.ColorType` for options. + colors: Color scheme for the heatmap (default: None uses CONFIG.Plotting.default_sequential_colorscale). + See `flixopt.plotting.ColorType` for options. engine: The engine to use for plotting. Can be either 'plotly' or 'matplotlib'. select: Optional data selection dict. Supports single values, lists, slices, and index arrays. Applied BEFORE faceting/animation/reshaping. @@ -754,6 +847,20 @@ def plot_heatmap( Supported timeframes: 'YS', 'MS', 'W', 'D', 'h', '15min', 'min' fill: Method to fill missing values after reshape: 'ffill' (forward fill) or 'bfill' (backward fill). Default is 'ffill'. + **plot_kwargs: Additional plotting customization options. + Common options: + + - **dpi** (int): Export resolution for saved plots. Default: 300. + + For heatmaps specifically: + + - **vmin** (float): Minimum value for color scale (both engines). + - **vmax** (float): Maximum value for color scale (both engines). + + For Matplotlib heatmaps: + + - **imshow_kwargs** (dict): Additional kwargs for matplotlib's imshow (e.g., interpolation, aspect). + - **cbar_kwargs** (dict): Additional kwargs for colorbar customization. Examples: Direct imshow mode (default): @@ -794,6 +901,18 @@ def plot_heatmap( ... animate_by='period', ... reshape_time=('D', 'h'), ... ) + + High-resolution export with custom color range: + + >>> results.plot_heatmap('Battery|charge_state', save=True, dpi=600, vmin=0, vmax=100) + + Matplotlib heatmap with custom imshow settings: + + >>> results.plot_heatmap( + ... 'Boiler(Q_th)|flow_rate', + ... engine='matplotlib', + ... imshow_kwargs={'interpolation': 'bilinear', 'aspect': 'auto'}, + ... ) """ # Delegate to module-level plot_heatmap function return plot_heatmap( @@ -814,6 +933,7 @@ def plot_heatmap( heatmap_timeframes=heatmap_timeframes, heatmap_timesteps_per_frame=heatmap_timesteps_per_frame, color_map=color_map, + **plot_kwargs, ) def plot_network( @@ -982,8 +1102,8 @@ def __init__( def plot_node_balance( self, save: bool | pathlib.Path = False, - show: bool = True, - colors: plotting.ColorType = 'viridis', + show: bool | None = None, + colors: plotting.ColorType | None = None, engine: plotting.PlottingEngine = 'plotly', select: dict[FlowSystemDimensions, Any] | None = None, unit_type: Literal['flow_rate', 'flow_hours'] = 'flow_rate', @@ -991,9 +1111,10 @@ def plot_node_balance( drop_suffix: bool = True, facet_by: str | list[str] | None = 'scenario', animate_by: str | None = 'period', - facet_cols: int = 3, + facet_cols: int | None = None, # Deprecated parameter (kept for backwards compatibility) indexer: dict[FlowSystemDimensions, Any] | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]: """ Plots the node balance of the Component or Bus with optional faceting and animation. @@ -1001,7 +1122,12 @@ def plot_node_balance( Args: save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location. show: Whether to show the plot or not. - colors: The colors to use for the plot. See `flixopt.plotting.ColorType` for options. + colors: The colors to use for the plot. Options: + - None (default): Use `self.color_manager` if configured, else fall back to CONFIG.Plotting.default_qualitative_colorscale + - Colormap name string (e.g., 'turbo', 'plasma') + - List of color strings + - Dict mapping variable names to colors + Set `results.color_manager` to a `ComponentColorManager` for automatic component-based grouping. engine: The engine to use for plotting. Can be either 'plotly' or 'matplotlib'. select: Optional data selection dict. Supports: - Single values: {'scenario': 'base', 'period': 2024} @@ -1021,6 +1147,27 @@ def plot_node_balance( animate_by: Dimension to animate over (Plotly only). Creates animation frames that cycle through dimension values. Only one dimension can be animated. Ignored if not found. facet_cols: Number of columns in the facet grid layout (default: 3). + **plot_kwargs: Additional plotting customization options passed to underlying plotting functions. + + Common options: + + - **dpi** (int): Export resolution in dots per inch. Default: 300. + + **For Plotly engine** (`engine='plotly'`): + + - **trace_kwargs** (dict): Customize traces via `fig.update_traces()`. + Example: `trace_kwargs={'line': {'width': 5, 'dash': 'dot'}}` + - **layout_kwargs** (dict): Customize layout via `fig.update_layout()`. + Example: `layout_kwargs={'width': 1200, 'height': 600, 'template': 'plotly_dark'}` + - Any Plotly Express parameter for px.bar()/px.line()/px.area() + + **For Matplotlib engine** (`engine='matplotlib'`): + + - **plot_kwargs** (dict): Customize plot via `ax.bar()` or `ax.step()`. + Example: `plot_kwargs={'linewidth': 3, 'alpha': 0.7, 'edgecolor': 'black'}` + + See :func:`flixopt.plotting.with_plotly` and :func:`flixopt.plotting.with_matplotlib` + for complete parameter reference. Examples: Basic plot (current behavior): @@ -1052,6 +1199,24 @@ def plot_node_balance( Time range selection (summer months only): >>> results['Boiler'].plot_node_balance(select={'time': slice('2024-06', '2024-08')}, facet_by='scenario') + + High-resolution export for publication: + + >>> results['Boiler'].plot_node_balance(engine='matplotlib', save='figure.png', dpi=600) + + Custom Plotly theme and layout: + + >>> results['Boiler'].plot_node_balance( + ... layout_kwargs={'template': 'plotly_dark', 'width': 1200, 'height': 600} + ... ) + + Custom line styling: + + >>> results['Boiler'].plot_node_balance(mode='line', trace_kwargs={'line': {'width': 5, 'dash': 'dot'}}) + + Custom matplotlib appearance: + + >>> results['Boiler'].plot_node_balance(engine='matplotlib', plot_kwargs={'linewidth': 3, 'alpha': 0.7}) """ # Handle deprecated indexer parameter if indexer is not None: @@ -1073,11 +1238,24 @@ def plot_node_balance( if engine not in {'plotly', 'matplotlib'}: raise ValueError(f'Engine "{engine}" not supported. Use one of ["plotly", "matplotlib"]') + # Extract dpi for export_figure + dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi + # Don't pass select/indexer to node_balance - we'll apply it afterwards - ds = self.node_balance(with_last_timestep=True, unit_type=unit_type, drop_suffix=drop_suffix) + ds = self.node_balance(with_last_timestep=False, unit_type=unit_type, drop_suffix=drop_suffix) ds, suffix_parts = _apply_selection_to_data(ds, select=select, drop=True) + # Resolve colors: None -> color_manager if set -> CONFIG default -> explicit value + colors_to_use = ( + self._calculation_results.color_manager + if colors is None and self._calculation_results.color_manager is not None + else CONFIG.Plotting.default_qualitative_colorscale + if colors is None + else colors + ) + resolved_colors = plotting.resolve_colors(ds, colors_to_use, engine=engine) + # Matplotlib requires only 'time' dimension; check for extras after selection if engine == 'matplotlib': extra_dims = [d for d in ds.dims if d != 'time'] @@ -1097,18 +1275,21 @@ def plot_node_balance( ds, facet_by=facet_by, animate_by=animate_by, - colors=colors, + colors=resolved_colors, mode=mode, title=title, facet_cols=facet_cols, + xlabel='Time in h', + **plot_kwargs, ) default_filetype = '.html' else: figure_like = plotting.with_matplotlib( - ds.to_dataframe(), - colors=colors, + ds, + colors=resolved_colors, mode=mode, title=title, + **plot_kwargs, ) default_filetype = '.png' @@ -1119,19 +1300,21 @@ def plot_node_balance( user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, save=True if save else False, + dpi=dpi, ) def plot_node_balance_pie( self, lower_percentage_group: float = 5, - colors: plotting.ColorType = 'viridis', + colors: plotting.ColorType | None = None, text_info: str = 'percent+label+value', save: bool | pathlib.Path = False, - show: bool = True, + show: bool | None = None, engine: plotting.PlottingEngine = 'plotly', select: dict[FlowSystemDimensions, Any] | None = None, # Deprecated parameter (kept for backwards compatibility) indexer: dict[FlowSystemDimensions, Any] | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, list[plt.Axes]]: """Plot pie chart of flow hours distribution. @@ -1144,13 +1327,25 @@ def plot_node_balance_pie( Args: lower_percentage_group: Percentage threshold for "Others" grouping. - colors: Color scheme. Also see plotly. + colors: Color scheme (default: None uses color_manager if configured, + else falls back to CONFIG.Plotting.default_qualitative_colorscale). text_info: Information to display on pie slices. save: Whether to save plot. show: Whether to display plot. engine: Plotting engine ('plotly' or 'matplotlib'). select: Optional data selection dict. Supports single values, lists, slices, and index arrays. Use this to select specific scenario/period before creating the pie chart. + **plot_kwargs: Additional plotting customization options. + + Common options: + + - **dpi** (int): Export resolution in dots per inch. Default: 300. + - **hover_template** (str): Hover text template (Plotly only). + Example: `hover_template='%{label}: %{value} (%{percent})'` + - **text_position** (str): Text position ('inside', 'outside', 'auto'). + - **hole** (float): Size of donut hole (0.0 to 1.0). + + See :func:`flixopt.plotting.dual_pie_with_plotly` for complete reference. Examples: Basic usage (auto-selects first scenario/period if present): @@ -1160,6 +1355,14 @@ def plot_node_balance_pie( Explicitly select a scenario and period: >>> results['Bus'].plot_node_balance_pie(select={'scenario': 'high_demand', 'period': 2030}) + + Create a donut chart with custom hover text: + + >>> results['Bus'].plot_node_balance_pie(hole=0.4, hover_template='%{label}: %{value:.2f} (%{percent})') + + High-resolution export: + + >>> results['Bus'].plot_node_balance_pie(save='figure.png', dpi=600) """ # Handle deprecated indexer parameter if indexer is not None: @@ -1178,6 +1381,9 @@ def plot_node_balance_pie( ) select = indexer + # Extract dpi for export_figure + dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi + inputs = sanitize_dataset( ds=self.solution[self.inputs] * self._calculation_results.hours_per_timestep, threshold=1e-5, @@ -1233,16 +1439,30 @@ def plot_node_balance_pie( suffix = '--' + '-'.join(suffix_parts) if suffix_parts else '' title = f'{self.label} (total flow hours){suffix}' + # Combine inputs and outputs to resolve colors for all variables + combined_ds = xr.Dataset({**inputs.data_vars, **outputs.data_vars}) + + # Resolve colors: None -> color_manager if set -> CONFIG default -> explicit value + colors_to_use = ( + self._calculation_results.color_manager + if colors is None and self._calculation_results.color_manager is not None + else CONFIG.Plotting.default_qualitative_colorscale + if colors is None + else colors + ) + resolved_colors = plotting.resolve_colors(combined_ds, colors_to_use, engine=engine) + if engine == 'plotly': figure_like = plotting.dual_pie_with_plotly( - data_left=inputs.to_pandas(), - data_right=outputs.to_pandas(), - colors=colors, + data_left=inputs, + data_right=outputs, + colors=resolved_colors, title=title, text_info=text_info, subtitles=('Inputs', 'Outputs'), legend_title='Flows', lower_percentage_group=lower_percentage_group, + **plot_kwargs, ) default_filetype = '.html' elif engine == 'matplotlib': @@ -1250,11 +1470,12 @@ def plot_node_balance_pie( figure_like = plotting.dual_pie_with_matplotlib( data_left=inputs.to_pandas(), data_right=outputs.to_pandas(), - colors=colors, + colors=resolved_colors, title=title, subtitles=('Inputs', 'Outputs'), legend_title='Flows', lower_percentage_group=lower_percentage_group, + **plot_kwargs, ) default_filetype = '.png' else: @@ -1267,6 +1488,7 @@ def plot_node_balance_pie( user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, save=True if save else False, + dpi=dpi, ) def node_balance( @@ -1363,23 +1585,25 @@ def charge_state(self) -> xr.DataArray: def plot_charge_state( self, save: bool | pathlib.Path = False, - show: bool = True, - colors: plotting.ColorType = 'viridis', + show: bool | None = None, + colors: plotting.ColorType | None = None, engine: plotting.PlottingEngine = 'plotly', mode: Literal['area', 'stacked_bar', 'line'] = 'area', select: dict[FlowSystemDimensions, Any] | None = None, facet_by: str | list[str] | None = 'scenario', animate_by: str | None = 'period', - facet_cols: int = 3, + facet_cols: int | None = None, # Deprecated parameter (kept for backwards compatibility) indexer: dict[FlowSystemDimensions, Any] | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure: """Plot storage charge state over time, combined with the node balance with optional faceting and animation. Args: save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location. show: Whether to show the plot or not. - colors: Color scheme. Also see plotly. + colors: Color scheme (default: None uses color_manager if configured, + else falls back to CONFIG.Plotting.default_qualitative_colorscale). engine: Plotting engine to use. Only 'plotly' is implemented atm. mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for stepped lines, or 'area' for stacked area charts. select: Optional data selection dict. Supports single values, lists, slices, and index arrays. @@ -1389,6 +1613,24 @@ def plot_charge_state( animate_by: Dimension to animate over (Plotly only). Creates animation frames that cycle through dimension values. Only one dimension can be animated. Ignored if not found. facet_cols: Number of columns in the facet grid layout (default: 3). + **plot_kwargs: Additional plotting customization options passed to underlying plotting functions. + + Common options: + + - **dpi** (int): Export resolution in dots per inch. Default: 300. + + **For Plotly engine:** + + - **trace_kwargs** (dict): Customize traces via `fig.update_traces()`. + - **layout_kwargs** (dict): Customize layout via `fig.update_layout()`. + - Any Plotly Express parameter for px.bar()/px.line()/px.area() + + **For Matplotlib engine:** + + - **plot_kwargs** (dict): Customize plot via `ax.bar()` or `ax.step()`. + + See :func:`flixopt.plotting.with_plotly` and :func:`flixopt.plotting.with_matplotlib` + for complete parameter reference. Raises: ValueError: If component is not a storage. @@ -1409,6 +1651,14 @@ def plot_charge_state( Facet by scenario AND animate by period: >>> results['Storage'].plot_charge_state(facet_by='scenario', animate_by='period') + + Custom layout: + + >>> results['Storage'].plot_charge_state(layout_kwargs={'template': 'plotly_dark', 'height': 800}) + + High-resolution export: + + >>> results['Storage'].plot_charge_state(save='storage.png', dpi=600) """ # Handle deprecated indexer parameter if indexer is not None: @@ -1427,11 +1677,14 @@ def plot_charge_state( ) select = indexer + # Extract dpi for export_figure + dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi + if not self.is_storage: raise ValueError(f'Cant plot charge_state. "{self.label}" is not a storage') # Get node balance and charge state - ds = self.node_balance(with_last_timestep=True) + ds = self.node_balance(with_last_timestep=True).fillna(0) charge_state_da = self.charge_state # Apply select filtering @@ -1441,31 +1694,48 @@ def plot_charge_state( title = f'Operation Balance of {self.label}{suffix}' + # Combine flow balance and charge state for color resolution + # We need to include both in the color map for consistency + combined_ds = ds.assign({self._charge_state: charge_state_da}) + + # Resolve colors: None -> color_manager if set -> CONFIG default -> explicit value + colors_to_use = ( + self._calculation_results.color_manager + if colors is None and self._calculation_results.color_manager is not None + else CONFIG.Plotting.default_qualitative_colorscale + if colors is None + else colors + ) + resolved_colors = plotting.resolve_colors(combined_ds, colors_to_use, engine=engine) + if engine == 'plotly': # Plot flows (node balance) with the specified mode figure_like = plotting.with_plotly( ds, facet_by=facet_by, animate_by=animate_by, - colors=colors, + colors=resolved_colors, mode=mode, title=title, facet_cols=facet_cols, + xlabel='Time in h', + **plot_kwargs, ) - # Create a dataset with just charge_state and plot it as lines - # This ensures proper handling of facets and animation - charge_state_ds = charge_state_da.to_dataset(name=self._charge_state) + # Prepare charge_state as Dataset for plotting + charge_state_ds = xr.Dataset({self._charge_state: charge_state_da}) # Plot charge_state with mode='line' to get Scatter traces charge_state_fig = plotting.with_plotly( charge_state_ds, facet_by=facet_by, animate_by=animate_by, - colors=colors, + colors=resolved_colors, mode='line', # Always line for charge_state title='', # No title needed for this temp figure facet_cols=facet_cols, + xlabel='Time in h', + **plot_kwargs, ) # Add charge_state traces to the main figure @@ -1473,6 +1743,7 @@ def plot_charge_state( for trace in charge_state_fig.data: trace.line.width = 2 # Make charge_state line more prominent trace.line.shape = 'linear' # Smooth line for charge state (not stepped like flows) + trace.line.color = 'black' figure_like.add_trace(trace) # Also add traces from animation frames if they exist @@ -1497,10 +1768,11 @@ def plot_charge_state( ) # For matplotlib, plot flows (node balance), then add charge_state as line fig, ax = plotting.with_matplotlib( - ds.to_dataframe(), - colors=colors, + ds, + colors=resolved_colors, mode=mode, title=title, + **plot_kwargs, ) # Add charge_state as a line overlay @@ -1525,6 +1797,7 @@ def plot_charge_state( user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, save=True if save else False, + dpi=dpi, ) def node_balance_with_charge_state( @@ -1635,6 +1908,19 @@ class SegmentedCalculationResults: - Flow rate transitions at segment boundaries - Aggregated results over the full time horizon + Attributes: + segment_results: List of CalculationResults for each segment + all_timesteps: Complete time index spanning all segments + timesteps_per_segment: Number of timesteps in each segment + overlap_timesteps: Number of overlapping timesteps between segments + name: Identifier for this segmented calculation + folder: Directory path for result storage and loading + hours_per_timestep: Duration of each timestep + color_manager: Optional ComponentColorManager for automatic component-based coloring in plots. + When set, it is automatically propagated to all segment results, ensuring + consistent coloring across segments. Use `setup_colors()` to create + and configure one, or assign an existing manager directly. + Examples: Load and analyze segmented results: @@ -1690,6 +1976,17 @@ class SegmentedCalculationResults: storage_continuity = results.check_storage_continuity('Battery') ``` + Configure color management for consistent plotting across segments: + + ```python + # Dict-based configuration (simplest): + results.setup_colors({'Solar*': 'oranges', 'Wind*': 'blues', 'Battery': 'greens'}) + + # Colors automatically propagate to all segments + results.segment_results[0]['ElectricityBus'].plot_node_balance() + results.segment_results[1]['ElectricityBus'].plot_node_balance() # Same colors + ``` + Design Considerations: **Boundary Effects**: Monitor solution quality at segment interfaces where foresight is limited compared to full-horizon optimization. @@ -1764,6 +2061,9 @@ def __init__( self.folder = pathlib.Path(folder) if folder is not None else pathlib.Path.cwd() / 'results' self.hours_per_timestep = FlowSystem.calculate_hours_per_timestep(self.all_timesteps) + # Color manager for intelligent plot coloring - None by default, user configures explicitly + self.color_manager: plotting.ComponentColorManager | None = None + @property def meta_data(self) -> dict[str, int | list[str]]: return { @@ -1777,6 +2077,63 @@ def meta_data(self) -> dict[str, int | list[str]]: def segment_names(self) -> list[str]: return [segment.name for segment in self.segment_results] + def setup_colors( + self, + config: dict[str, str | list[str]] | str | pathlib.Path | None = None, + default_colorscale: str | None = None, + ) -> plotting.ComponentColorManager: + """Initialize and return a ColorManager that propagates to all segments. + + Convenience method that creates a ComponentColorManager with all components + registered and assigns it to `self.color_manager` and all segment results. + Optionally load configuration from a dict or file. + + Args: + config: Optional color configuration: + - dict: Mixed {component: color} or {colorscale: [components]} mapping + - str/Path: Path to YAML file + - None: Create empty manager for manual config (default) + default_colorscale: Optional default colorscale to use. Defaults to CONFIG.Plotting.default_default_qualitative_colorscale + + Returns: + ComponentColorManager instance ready for configuration (propagated to all segments). + + Examples: + Dict-based configuration (mixed direct + grouped): + + ```python + results.setup_colors( + { + 'Boiler1': '#FF0000', + 'oranges': ['Solar1', 'Solar2'], + 'blues': ['Wind1', 'Wind2'], + } + ) + + # All segments use the same colors + results.segment_results[0]['ElectricityBus'].plot_node_balance() + results.segment_results[1]['ElectricityBus'].plot_node_balance() + ``` + + Load from file: + + ```python + results.setup_colors('colors.yaml') + ``` + """ + self.color_manager = plotting.ComponentColorManager.from_flow_system( + self.flow_system, default_colorscale=default_colorscale + ) + # Propagate to all segment results for consistent coloring + for segment in self.segment_results: + segment.color_manager = self.color_manager + + # Apply configuration if provided + if config is not None: + self.color_manager.configure(config) + + return self.color_manager + def solution_without_overlap(self, variable_name: str) -> xr.DataArray: """Get variable solution removing segment overlaps. @@ -1798,18 +2155,19 @@ def plot_heatmap( reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] | Literal['auto'] | None = 'auto', - colors: str = 'portland', + colors: str | None = None, save: bool | pathlib.Path = False, - show: bool = True, + show: bool | None = None, engine: plotting.PlottingEngine = 'plotly', facet_by: str | list[str] | None = None, animate_by: str | None = None, - facet_cols: int = 3, + facet_cols: int | None = None, fill: Literal['ffill', 'bfill'] | None = 'ffill', # Deprecated parameters (kept for backwards compatibility) heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None, heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None, color_map: str | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]: """Plot heatmap of variable solution across segments. @@ -1819,7 +2177,8 @@ def plot_heatmap( - 'auto': Automatically applies ('D', 'h') when only 'time' dimension remains - Tuple like ('D', 'h'): Explicit reshaping (days vs hours) - None: Disable time reshaping - colors: Color scheme. See plotting.ColorType for options. + colors: Color scheme (default: None uses CONFIG.Plotting.default_sequential_colorscale). + See plotting.ColorType for options. save: Whether to save plot. show: Whether to display plot. engine: Plotting engine. @@ -1830,6 +2189,17 @@ def plot_heatmap( heatmap_timeframes: (Deprecated) Use reshape_time instead. heatmap_timesteps_per_frame: (Deprecated) Use reshape_time instead. color_map: (Deprecated) Use colors instead. + **plot_kwargs: Additional plotting customization options. + Common options: + + - **dpi** (int): Export resolution for saved plots. Default: 300. + - **vmin** (float): Minimum value for color scale. + - **vmax** (float): Maximum value for color scale. + + For Matplotlib heatmaps: + + - **imshow_kwargs** (dict): Additional kwargs for matplotlib's imshow. + - **cbar_kwargs** (dict): Additional kwargs for colorbar customization. Returns: Figure object. @@ -1857,7 +2227,7 @@ def plot_heatmap( if color_map is not None: # Check for conflict with new parameter - if colors != 'portland': # Check if user explicitly set colors + if colors is not None: # Check if user explicitly set colors raise ValueError( "Cannot use both deprecated parameter 'color_map' and new parameter 'colors'. Use only 'colors'." ) @@ -1884,6 +2254,7 @@ def plot_heatmap( animate_by=animate_by, facet_cols=facet_cols, fill=fill, + **plot_kwargs, ) def to_file(self, folder: str | pathlib.Path | None = None, name: str | None = None, compression: int = 5): @@ -1916,9 +2287,9 @@ def plot_heatmap( data: xr.DataArray | xr.Dataset, name: str | None = None, folder: pathlib.Path | None = None, - colors: plotting.ColorType = 'viridis', + colors: plotting.ColorType | None = None, save: bool | pathlib.Path = False, - show: bool = True, + show: bool | None = None, engine: plotting.PlottingEngine = 'plotly', select: dict[str, Any] | None = None, facet_by: str | list[str] | None = None, @@ -1933,6 +2304,7 @@ def plot_heatmap( heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None, heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None, color_map: str | None = None, + **plot_kwargs: Any, ): """Plot heatmap visualization with support for multi-variable, faceting, and animation. @@ -1945,7 +2317,8 @@ def plot_heatmap( name: Optional name for the title. If not provided, uses the DataArray name or generates a default title for Datasets. folder: Save folder for the plot. Defaults to current directory if not provided. - colors: Color scheme for the heatmap. See `flixopt.plotting.ColorType` for options. + colors: Color scheme for the heatmap (default: None uses CONFIG.Plotting.default_sequential_colorscale). + See `flixopt.plotting.ColorType` for options. save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location. show: Whether to show the plot or not. engine: The engine to use for plotting. Can be either 'plotly' or 'matplotlib'. @@ -2004,7 +2377,7 @@ def plot_heatmap( # Handle deprecated color_map parameter if color_map is not None: # Check for conflict with new parameter - if colors != 'viridis': # User explicitly set colors + if colors is not None: # User explicitly set colors raise ValueError( "Cannot use both deprecated parameter 'color_map' and new parameter 'colors'. Use only 'colors'." ) @@ -2087,6 +2460,13 @@ def plot_heatmap( timeframes, timesteps_per_frame = reshape_time title += f' ({timeframes} vs {timesteps_per_frame})' + # Apply CONFIG default if colors is None + if colors is None: + colors = CONFIG.Plotting.default_sequential_colorscale + + # Extract dpi before passing to plotting functions + dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi + # Plot with appropriate engine if engine == 'plotly': figure_like = plotting.heatmap_with_plotly( @@ -2098,6 +2478,7 @@ def plot_heatmap( facet_cols=facet_cols, reshape_time=reshape_time, fill=fill, + **plot_kwargs, ) default_filetype = '.html' elif engine == 'matplotlib': @@ -2107,6 +2488,7 @@ def plot_heatmap( title=title, reshape_time=reshape_time, fill=fill, + **plot_kwargs, ) default_filetype = '.png' else: @@ -2123,6 +2505,7 @@ def plot_heatmap( user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, save=True if save else False, + dpi=dpi, ) diff --git a/tests/conftest.py b/tests/conftest.py index ac5255562..98929e467 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -838,6 +838,7 @@ def set_test_environment(): This fixture runs once per test session to: - Set matplotlib to use non-interactive 'Agg' backend - Set plotly to use non-interactive 'json' renderer + - Configure flixopt to not show plots by default - Prevent GUI windows from opening during tests """ import matplotlib @@ -848,4 +849,8 @@ def set_test_environment(): pio.renderers.default = 'json' # Use non-interactive renderer + # Configure flixopt to not show plots in tests + fx.CONFIG.Plotting.default_show = False + fx.CONFIG.apply() + yield diff --git a/tests/test_component_color_manager.py b/tests/test_component_color_manager.py new file mode 100644 index 000000000..4d58cdc3c --- /dev/null +++ b/tests/test_component_color_manager.py @@ -0,0 +1,377 @@ +"""Tests for ComponentColorManager functionality.""" + +import numpy as np +import pytest +import xarray as xr + +from flixopt.plotting import ComponentColorManager, resolve_colors + + +class TestBasicFunctionality: + """Test basic ComponentColorManager functionality.""" + + def test_initialization_default(self): + """Test default initialization.""" + components = ['Solar_PV', 'Wind_Onshore', 'Coal_Plant'] + manager = ComponentColorManager(components) + + assert len(manager.components) == 3 + assert manager.default_colorscale == 'plotly' + assert 'Solar_PV' in manager.components + + def test_sorted_components(self): + """Test that components are sorted for stability.""" + components = ['C_Component', 'A_Component', 'B_Component'] + manager = ComponentColorManager(components) + + # Components should be sorted + assert manager.components == ['A_Component', 'B_Component', 'C_Component'] + + def test_default_color_assignment(self): + """Test that components get default colors on initialization.""" + components = ['Comp1', 'Comp2', 'Comp3'] + manager = ComponentColorManager(components) + + # Each component should have a color + for comp in components: + color = manager.get_color(comp) + assert color is not None + assert isinstance(color, str) + + def test_empty_initialization(self): + """Test initialization without components.""" + manager = ComponentColorManager() + assert len(manager.components) == 0 + + +class TestConfigureAPI: + """Test the configure() method with various inputs.""" + + def test_configure_direct_colors(self): + """Test direct color assignment (component → color).""" + manager = ComponentColorManager() + manager.configure({'Boiler1': '#FF0000', 'CHP': 'darkred', 'Storage': 'green'}) + + assert manager.get_color('Boiler1') == '#FF0000' + assert manager.get_color('CHP') == 'darkred' + assert manager.get_color('Storage') == 'green' + + def test_configure_grouped_colors(self): + """Test grouped color assignment (colorscale → list of components).""" + manager = ComponentColorManager() + manager.configure( + { + 'oranges': ['Solar1', 'Solar2'], + 'blues': ['Wind1', 'Wind2'], + } + ) + + # All should have colors + assert manager.get_color('Solar1') is not None + assert manager.get_color('Solar2') is not None + assert manager.get_color('Wind1') is not None + assert manager.get_color('Wind2') is not None + + # Solar components should have different shades + assert manager.get_color('Solar1') != manager.get_color('Solar2') + + # Wind components should have different shades + assert manager.get_color('Wind1') != manager.get_color('Wind2') + + def test_configure_mixed(self): + """Test mixed direct and grouped colors.""" + manager = ComponentColorManager() + manager.configure( + { + 'Boiler1': '#FF0000', + 'oranges': ['Solar1', 'Solar2'], + 'blues': ['Wind1', 'Wind2'], + } + ) + + # Direct color + assert manager.get_color('Boiler1') == '#FF0000' + + # Grouped colors + assert manager.get_color('Solar1') is not None + assert manager.get_color('Wind1') is not None + + def test_configure_updates_components_list(self): + """Test that configure() adds components to the list.""" + manager = ComponentColorManager() + assert len(manager.components) == 0 + + manager.configure({'Boiler1': '#FF0000', 'CHP': 'red'}) + + assert len(manager.components) == 2 + assert 'Boiler1' in manager.components + assert 'CHP' in manager.components + + +class TestColorFamilies: + """Test color family functionality.""" + + def test_default_families(self): + """Test that default families are available.""" + manager = ComponentColorManager([]) + + assert 'blues' in manager.color_families + assert 'oranges' in manager.color_families + assert 'greens' in manager.color_families + assert 'reds' in manager.color_families + + +class TestColorStability: + """Test color stability across different datasets.""" + + def test_same_component_same_color(self): + """Test that same component always gets same color.""" + manager = ComponentColorManager() + manager.configure( + { + 'oranges': ['Solar_PV'], + 'blues': ['Wind_Onshore'], + } + ) + + # Get colors multiple times + color1 = manager.get_color('Solar_PV') + color2 = manager.get_color('Solar_PV') + color3 = manager.get_color('Solar_PV') + + assert color1 == color2 == color3 + + def test_color_stability_with_different_datasets(self): + """Test that colors remain stable across different variable subsets.""" + manager = ComponentColorManager() + manager.configure( + { + 'oranges': ['Solar_PV'], + 'blues': ['Wind_Onshore'], + 'greys': ['Coal_Plant'], + 'reds': ['Gas_Plant'], + } + ) + + # Dataset 1: Only Solar and Wind + dataset1 = xr.Dataset( + { + 'Solar_PV(Bus)|flow_rate': (['time'], np.random.rand(10)), + 'Wind_Onshore(Bus)|flow_rate': (['time'], np.random.rand(10)), + }, + coords={'time': np.arange(10)}, + ) + + # Dataset 2: All components + dataset2 = xr.Dataset( + { + 'Solar_PV(Bus)|flow_rate': (['time'], np.random.rand(10)), + 'Wind_Onshore(Bus)|flow_rate': (['time'], np.random.rand(10)), + 'Coal_Plant(Bus)|flow_rate': (['time'], np.random.rand(10)), + 'Gas_Plant(Bus)|flow_rate': (['time'], np.random.rand(10)), + }, + coords={'time': np.arange(10)}, + ) + + colors1 = resolve_colors(dataset1, manager, engine='plotly') + colors2 = resolve_colors(dataset2, manager, engine='plotly') + + # Solar_PV and Wind_Onshore should have same colors in both datasets + assert colors1['Solar_PV(Bus)|flow_rate'] == colors2['Solar_PV(Bus)|flow_rate'] + assert colors1['Wind_Onshore(Bus)|flow_rate'] == colors2['Wind_Onshore(Bus)|flow_rate'] + + +class TestVariableExtraction: + """Test variable to component extraction.""" + + def test_extract_component_with_parentheses(self): + """Test extracting component from variable with parentheses.""" + manager = ComponentColorManager([]) + + variable = 'Solar_PV(ElectricityBus)|flow_rate' + component = manager.extract_component(variable) + + assert component == 'Solar_PV' + + def test_extract_component_with_pipe(self): + """Test extracting component from variable with pipe.""" + manager = ComponentColorManager([]) + + variable = 'Solar_PV|investment' + component = manager.extract_component(variable) + + assert component == 'Solar_PV' + + def test_extract_component_no_separators(self): + """Test extracting component from variable without separators.""" + manager = ComponentColorManager([]) + + variable = 'SimpleComponent' + component = manager.extract_component(variable) + + assert component == 'SimpleComponent' + + +class TestVariableColorResolution: + """Test getting colors for variables.""" + + def test_get_variable_color(self): + """Test getting color for a single variable.""" + manager = ComponentColorManager() + manager.configure({'oranges': ['Solar_PV']}) + + variable = 'Solar_PV(Bus)|flow_rate' + color = manager.get_variable_color(variable) + + assert color is not None + assert isinstance(color, str) + + def test_get_variable_colors_multiple(self): + """Test getting colors for multiple variables.""" + manager = ComponentColorManager() + manager.configure( + { + 'oranges': ['Solar_PV'], + 'blues': ['Wind_Onshore'], + 'greys': ['Coal_Plant'], + } + ) + + variables = ['Solar_PV(Bus)|flow_rate', 'Wind_Onshore(Bus)|flow_rate', 'Coal_Plant(Bus)|flow_rate'] + + colors = manager.get_variable_colors(variables) + + assert len(colors) == 3 + assert all(var in colors for var in variables) + assert all(isinstance(color, str) for color in colors.values()) + + def test_variable_extraction_in_color_resolution(self): + """Test that variable names are properly extracted to component names.""" + manager = ComponentColorManager() + manager.configure({'Solar_PV': '#FF0000'}) + + # Variable format with flow + variable_color = manager.get_variable_color('Solar_PV(Bus)|flow_rate') + component_color = manager.get_color('Solar_PV') + + # Should be the same color + assert variable_color == component_color + + +class TestIntegrationWithResolveColors: + """Test integration with resolve_colors function.""" + + def test_resolve_colors_with_manager(self): + """Test resolve_colors with ComponentColorManager.""" + manager = ComponentColorManager() + manager.configure( + { + 'oranges': ['Solar_PV'], + 'blues': ['Wind_Onshore'], + } + ) + + dataset = xr.Dataset( + { + 'Solar_PV(Bus)|flow_rate': (['time'], np.random.rand(10)), + 'Wind_Onshore(Bus)|flow_rate': (['time'], np.random.rand(10)), + }, + coords={'time': np.arange(10)}, + ) + + colors = resolve_colors(dataset, manager, engine='plotly') + + assert len(colors) == 2 + assert 'Solar_PV(Bus)|flow_rate' in colors + assert 'Wind_Onshore(Bus)|flow_rate' in colors + + def test_resolve_colors_with_dict(self): + """Test that resolve_colors still works with dict.""" + dataset = xr.Dataset( + {'var1': (['time'], np.random.rand(10)), 'var2': (['time'], np.random.rand(10))}, + coords={'time': np.arange(10)}, + ) + + color_dict = {'var1': '#FF0000', 'var2': '#00FF00'} + colors = resolve_colors(dataset, color_dict, engine='plotly') + + assert colors == color_dict + + +class TestMethodChaining: + """Test method chaining.""" + + def test_configure_returns_self(self): + """Test that configure() returns self for chaining.""" + manager = ComponentColorManager() + result = manager.configure({'Boiler': 'red'}) + + assert result is manager + + def test_chaining_with_initialization(self): + """Test method chaining with initialization.""" + # Test chaining configure() after __init__ + manager = ComponentColorManager(components=['Solar_PV', 'Wind_Onshore']) + manager.configure({'oranges': ['Solar_PV']}) + + assert len(manager.components) == 2 + assert manager.get_color('Solar_PV') is not None + + +class TestUnknownComponents: + """Test behavior with unknown components.""" + + def test_get_color_unknown_component(self): + """Test that unknown components get a default grey color.""" + manager = ComponentColorManager() + manager.configure({'Boiler': 'red'}) + + # Unknown component + color = manager.get_color('UnknownComponent') + + # Should return grey default + assert color == '#808080' + + def test_get_variable_color_unknown_component(self): + """Test that unknown components in variables get default color.""" + manager = ComponentColorManager() + manager.configure({'Boiler': 'red'}) + + # Unknown component + color = manager.get_variable_color('UnknownComponent(Bus)|flow') + + # Should return grey default + assert color == '#808080' + + +class TestColorCaching: + """Test that variable color caching works.""" + + def test_cache_is_used(self): + """Test that cache is used for repeated variable lookups.""" + manager = ComponentColorManager() + manager.configure({'Solar_PV': '#FF0000'}) + + # First call populates cache + color1 = manager.get_variable_color('Solar_PV(Bus)|flow_rate') + + # Second call should hit cache + color2 = manager.get_variable_color('Solar_PV(Bus)|flow_rate') + + assert color1 == color2 + assert 'Solar_PV(Bus)|flow_rate' in manager._variable_cache + + def test_cache_cleared_on_configure(self): + """Test that cache is cleared when colors are reconfigured.""" + manager = ComponentColorManager() + manager.configure({'Solar_PV': '#FF0000'}) + + # Populate cache + manager.get_variable_color('Solar_PV(Bus)|flow_rate') + assert len(manager._variable_cache) > 0 + + # Reconfigure + manager.configure({'Solar_PV': '#00FF00'}) + + # Cache should be cleared + assert len(manager._variable_cache) == 0 diff --git a/tests/test_plotting_api.py b/tests/test_plotting_api.py new file mode 100644 index 000000000..f59601dca --- /dev/null +++ b/tests/test_plotting_api.py @@ -0,0 +1,64 @@ +"""Smoke tests for plotting API robustness improvements.""" + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from flixopt import plotting + + +@pytest.fixture +def sample_dataset(): + """Create a sample xarray Dataset for testing.""" + time = np.arange(10) + data = xr.Dataset( + { + 'var1': (['time'], np.random.rand(10)), + 'var2': (['time'], np.random.rand(10)), + 'var3': (['time'], np.random.rand(10)), + }, + coords={'time': time}, + ) + return data + + +@pytest.fixture +def sample_dataframe(): + """Create a sample pandas DataFrame for testing.""" + time = np.arange(10) + df = pd.DataFrame({'var1': np.random.rand(10), 'var2': np.random.rand(10), 'var3': np.random.rand(10)}, index=time) + df.index.name = 'time' + return df + + +def test_kwargs_passthrough_plotly(sample_dataset): + """Test that backend-specific kwargs are passed through correctly.""" + fig = plotting.with_plotly( + sample_dataset, + mode='line', + trace_kwargs={'line': {'width': 5}}, + layout_kwargs={'width': 1200, 'height': 600}, + ) + assert fig.layout.width == 1200 + assert fig.layout.height == 600 + + +def test_dataframe_support_plotly(sample_dataframe): + """Test that DataFrames are accepted by plotting functions.""" + fig = plotting.with_plotly(sample_dataframe, mode='line') + assert fig is not None + + +def test_data_validation_non_numeric(): + """Test that validation catches non-numeric data.""" + data = xr.Dataset({'var1': (['time'], ['a', 'b', 'c'])}, coords={'time': [0, 1, 2]}) + + with pytest.raises(TypeError, match='non-numeric dtype'): + plotting.with_plotly(data) + + +def test_ensure_dataset_invalid_type(): + """Test that _ensure_dataset raises error for invalid types.""" + with pytest.raises(TypeError, match='must be xr.Dataset or pd.DataFrame'): + plotting._ensure_dataset([1, 2, 3]) diff --git a/tests/test_results_plots.py b/tests/test_results_plots.py index 1fd6cf7f5..a656f7c44 100644 --- a/tests/test_results_plots.py +++ b/tests/test_results_plots.py @@ -28,7 +28,7 @@ def plotting_engine(request): @pytest.fixture( params=[ - 'viridis', # Test string colormap + 'turbo', # Test string colormap ['#ff0000', '#00ff00', '#0000ff', '#ffff00', '#ff00ff', '#00ffff'], # Test color list { 'Boiler(Q_th)|flow_rate': '#ff0000', @@ -51,7 +51,7 @@ def test_results_plots(flow_system, plotting_engine, show, save, color_spec): # Matplotlib doesn't support faceting/animation, so disable them for matplotlib engine heatmap_kwargs = { 'reshape_time': ('D', 'h'), - 'colors': 'viridis', # Note: heatmap only accepts string colormap + 'colors': 'turbo', # Note: heatmap only accepts string colormap 'save': save, 'show': show, 'engine': plotting_engine,