diff --git a/CHANGELOG.md b/CHANGELOG.md index dda77c01a..955dd6b69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,16 +53,28 @@ If upgrading from v2.x, see the [v3.0.0 release notes](https://github.com/flixOp If upgrading from v2.x, see the [v3.0.0 release notes](https://github.com/flixOpt/flixOpt/releases/tag/v3.0.0) and [Migration Guide](https://flixopt.github.io/flixopt/latest/user-guide/migration-guide-v3/). ### ✨ Added - -### 💥 Breaking Changes +- **Simplified color management**: Configure consistent plot colors with explicit component grouping + - Direct colors: `results.setup_colors({'Boiler1': '#FF0000', 'CHP': 'darkred'})` + - Grouped colors: `results.setup_colors({'oranges': ['Solar1', 'Solar2'], 'blues': ['Wind1', 'Wind2']})` + - Mixed approach: Combine direct and grouped colors in a single call + - File-based: `results.setup_colors('colors.yaml')` (YAML only) +- **Heatmap fill control**: Control missing value handling with `fill='ffill'` or `fill='bfill'` +- **New CONFIG options for plot styling** + - `CONFIG.Plotting.default_sequential_colorscale` - Falls back to template's sequential colorscale when `None` + - `CONFIG.Plotting.default_qualitative_colorscale` - Falls back to template's colorway when `None` + - `CONFIG.Plotting.default_show` defaults to `True` - set to None to prevent unwanted GUI windows ### ♻️ Changed +- **Template integration**: Plotly templates now fully control plot styling without hardcoded overrides + - Removed hardcoded `plot_bgcolor`, `paper_bgcolor`, and `font` settings from plotting functions + - Change template via `CONFIG.Plotting.plotly_template = 'plotly_dark'; CONFIG.apply()` +- Plotting methods now use `color_manager` by default if configured ### 🗑️ Deprecated -### 🔥 Removed - ### 🐛 Fixed +- Improved error messages for matplotlib with multidimensional data +- Better dimension validation in `plot_heatmap()` ### 🔒 Security diff --git a/examples/01_Simple/simple_example.py b/examples/01_Simple/simple_example.py index 906c24622..5b828b60c 100644 --- a/examples/01_Simple/simple_example.py +++ b/examples/01_Simple/simple_example.py @@ -112,6 +112,8 @@ calculation.solve(fx.solvers.HighsSolver(mip_gap=0, time_limit_seconds=30)) # --- Analyze Results --- + # Optional: Configure custom colors with + calculation.results.setup_colors({'CHP': 'red', 'Boiler': 'orange'}) calculation.results['Fernwärme'].plot_node_balance_pie() calculation.results['Fernwärme'].plot_node_balance() calculation.results['Storage'].plot_charge_state() diff --git a/examples/02_Complex/complex_example.py b/examples/02_Complex/complex_example.py index 805cb08f6..77623d14c 100644 --- a/examples/02_Complex/complex_example.py +++ b/examples/02_Complex/complex_example.py @@ -206,6 +206,8 @@ calculation.results.to_file() # But let's plot some results anyway + # Optional: Configure custom colors (dict is simplest): + calculation.results.setup_colors({'BHKW*': 'orange', 'Speicher': 'blue'}) calculation.results.plot_heatmap('BHKW2(Q_th)|flow_rate') calculation.results['BHKW2'].plot_node_balance() calculation.results['Speicher'].plot_charge_state() diff --git a/examples/02_Complex/complex_example_results.py b/examples/02_Complex/complex_example_results.py index 5020f71fe..b53b1ee7b 100644 --- a/examples/02_Complex/complex_example_results.py +++ b/examples/02_Complex/complex_example_results.py @@ -18,6 +18,9 @@ f'Original error: {e}' ) from e + # --- Configure Color Mapping for Consistent Plot Colors (Optional) --- + results.setup_colors({'Solar*': 'oranges', 'Wind*': 'blues'}) # Dict (simplest) + # --- Basic overview --- results.plot_network(show=True) results['Fernwärme'].plot_node_balance() @@ -25,8 +28,9 @@ # --- Detailed Plots --- # In depth plot for individual flow rates ('__' is used as the delimiter between Component and Flow results.plot_heatmap('Wärmelast(Q_th_Last)|flow_rate') - for flow_rate in results['BHKW2'].inputs + results['BHKW2'].outputs: - results.plot_heatmap(flow_rate) + for bus in results.buses.values(): + bus.plot_node_balance_pie() + bus.plot_node_balance() # --- Plotting internal variables manually --- results.plot_heatmap('BHKW2(Q_th)|on') diff --git a/examples/03_Calculation_types/example_calculation_types.py b/examples/03_Calculation_types/example_calculation_types.py index 8df8e742f..f71a8eed7 100644 --- a/examples/03_Calculation_types/example_calculation_types.py +++ b/examples/03_Calculation_types/example_calculation_types.py @@ -202,35 +202,35 @@ def get_solutions(calcs: list, variable: str) -> xr.Dataset: # --- Plotting for comparison --- fx.plotting.with_plotly( - get_solutions(calculations, 'Speicher|charge_state').to_dataframe(), + get_solutions(calculations, 'Speicher|charge_state'), mode='line', title='Charge State Comparison', ylabel='Charge state', ).write_html('results/Charge State.html') fx.plotting.with_plotly( - get_solutions(calculations, 'BHKW2(Q_th)|flow_rate').to_dataframe(), + get_solutions(calculations, 'BHKW2(Q_th)|flow_rate'), mode='line', title='BHKW2(Q_th) Flow Rate Comparison', ylabel='Flow rate', ).write_html('results/BHKW2 Thermal Power.html') fx.plotting.with_plotly( - get_solutions(calculations, 'costs(temporal)|per_timestep').to_dataframe(), + get_solutions(calculations, 'costs(temporal)|per_timestep'), mode='line', title='Operation Cost Comparison', ylabel='Costs [€]', ).write_html('results/Operation Costs.html') fx.plotting.with_plotly( - pd.DataFrame(get_solutions(calculations, 'costs(temporal)|per_timestep').to_dataframe().sum()).T, + get_solutions(calculations, 'costs(temporal)|per_timestep').sum('time'), mode='stacked_bar', title='Total Cost Comparison', ylabel='Costs [€]', ).update_layout(barmode='group').write_html('results/Total Costs.html') fx.plotting.with_plotly( - pd.DataFrame([calc.durations for calc in calculations], index=[calc.name for calc in calculations]), + pd.DataFrame([calc.durations for calc in calculations], index=[calc.name for calc in calculations]).to_xarray(), mode='stacked_bar', ).update_layout(title='Duration Comparison', xaxis_title='Calculation type', yaxis_title='Time (s)').write_html( 'results/Speed Comparison.html' diff --git a/examples/04_Scenarios/scenario_example.py b/examples/04_Scenarios/scenario_example.py index 834e55782..993349421 100644 --- a/examples/04_Scenarios/scenario_example.py +++ b/examples/04_Scenarios/scenario_example.py @@ -196,6 +196,8 @@ # --- Solve the Calculation and Save Results --- calculation.solve(fx.solvers.HighsSolver(mip_gap=0, time_limit_seconds=30)) + calculation.results.setup_colors() + calculation.results.plot_heatmap('CHP(Q_th)|flow_rate') # --- Analyze Results --- diff --git a/flixopt/aggregation.py b/flixopt/aggregation.py index 91ef618a9..18cac8013 100644 --- a/flixopt/aggregation.py +++ b/flixopt/aggregation.py @@ -21,6 +21,7 @@ TSAM_AVAILABLE = False from .components import Storage +from .config import CONFIG from .structure import ( FlowSystemModel, Submodel, @@ -141,7 +142,7 @@ def describe_clusters(self) -> str: def use_extreme_periods(self): return self.time_series_for_high_peaks or self.time_series_for_low_peaks - def plot(self, colormap: str = 'viridis', show: bool = True, save: pathlib.Path | None = None) -> go.Figure: + def plot(self, colormap: str | None = None, show: bool = True, save: pathlib.Path | None = None) -> go.Figure: from . import plotting df_org = self.original_data.copy().rename( @@ -150,10 +151,22 @@ def plot(self, colormap: str = 'viridis', show: bool = True, save: pathlib.Path df_agg = self.aggregated_data.copy().rename( columns={col: f'Aggregated - {col}' for col in self.aggregated_data.columns} ) - fig = plotting.with_plotly(df_org, 'line', colors=colormap) + fig = plotting.with_plotly( + df_org.to_xarray(), + 'line', + colors=colormap or CONFIG.Plotting.default_qualitative_colorscale, + xlabel='Time in h', + ) for trace in fig.data: trace.update(dict(line=dict(dash='dash'))) - fig = plotting.with_plotly(df_agg, 'line', colors=colormap, fig=fig) + fig2 = plotting.with_plotly( + df_agg.to_xarray(), + 'line', + colors=colormap or CONFIG.Plotting.default_qualitative_colorscale, + xlabel='Time in h', + ) + for trace in fig2.data: + fig.add_trace(trace) fig.update_layout( title='Original vs Aggregated Data (original = ---)', xaxis_title='Index', yaxis_title='Value' diff --git a/flixopt/config.py b/flixopt/config.py index a7549a3ec..957cd320e 100644 --- a/flixopt/config.py +++ b/flixopt/config.py @@ -54,6 +54,20 @@ 'big_binary_bound': 100_000, } ), + 'plotting': MappingProxyType( + { + 'plotly_template': 'plotly_white', + 'default_show': True, + 'default_save_path': None, + 'default_engine': 'plotly', + 'default_dpi': 300, + 'default_figure_width': None, + 'default_figure_height': None, + 'default_facet_cols': 3, + 'default_sequential_colorscale': 'turbo', + 'default_qualitative_colorscale': 'plotly', + } + ), } ) @@ -185,6 +199,67 @@ class Modeling: epsilon: float = _DEFAULTS['modeling']['epsilon'] big_binary_bound: int = _DEFAULTS['modeling']['big_binary_bound'] + class Plotting: + """Plotting configuration. + + Configure backends via environment variables: + - Matplotlib: Set `MPLBACKEND` environment variable (e.g., 'Agg', 'TkAgg') + - Plotly: Set `PLOTLY_RENDERER` or use `plotly.io.renderers.default` + + Attributes: + plotly_template: Plotly theme/template applied to all plots. + default_show: Default value for the `show` parameter in plot methods. + default_save_path: Default directory for saving plots. + default_engine: Default plotting engine. + default_dpi: Default DPI for saved plots. + default_figure_width: Default plot width in pixels. + default_figure_height: Default plot height in pixels. + default_facet_cols: Default number of columns for faceted plots. + default_sequential_colorscale: Default colorscale for heatmaps and continuous data. + default_qualitative_colorscale: Default colormap for categorical plots (bar/line/area charts). + + Examples: + ```python + # Set consistent theming + CONFIG.Plotting.plotly_template = 'plotly_dark' + CONFIG.apply() + + # Configure default export and color settings + CONFIG.Plotting.default_dpi = 600 + CONFIG.Plotting.default_figure_width = 1200 + CONFIG.Plotting.default_figure_height = 800 + CONFIG.Plotting.default_sequential_colorscale = 'plasma' + CONFIG.Plotting.default_qualitative_colorscale = 'Dark24' + CONFIG.apply() + ``` + """ + + plotly_template: ( + Literal[ + 'plotly', + 'plotly_white', + 'plotly_dark', + 'ggplot2', + 'seaborn', + 'simple_white', + 'none', + 'gridon', + 'presentation', + 'xgridoff', + 'ygridoff', + ] + | None + ) = _DEFAULTS['plotting']['plotly_template'] + default_show: bool = _DEFAULTS['plotting']['default_show'] + default_save_path: str | None = _DEFAULTS['plotting']['default_save_path'] + default_engine: Literal['plotly', 'matplotlib'] = _DEFAULTS['plotting']['default_engine'] + default_dpi: int = _DEFAULTS['plotting']['default_dpi'] + default_figure_width: int | None = _DEFAULTS['plotting']['default_figure_width'] + default_figure_height: int | None = _DEFAULTS['plotting']['default_figure_height'] + default_facet_cols: int = _DEFAULTS['plotting']['default_facet_cols'] + default_sequential_colorscale: str = _DEFAULTS['plotting']['default_sequential_colorscale'] + default_qualitative_colorscale: str = _DEFAULTS['plotting']['default_qualitative_colorscale'] + config_name: str = _DEFAULTS['config_name'] @classmethod @@ -201,12 +276,15 @@ def reset(cls): for key, value in _DEFAULTS['modeling'].items(): setattr(cls.Modeling, key, value) + for key, value in _DEFAULTS['plotting'].items(): + setattr(cls.Plotting, key, value) + cls.config_name = _DEFAULTS['config_name'] cls.apply() @classmethod def apply(cls): - """Apply current configuration to logging system.""" + """Apply current configuration to logging and plotting systems.""" # Convert Colors class attributes to dict colors_dict = { 'DEBUG': cls.Logging.Colors.DEBUG, @@ -243,6 +321,11 @@ def apply(cls): colors=colors_dict, ) + # Apply plotting configuration + _apply_plotting_config( + plotly_template=cls.Plotting.plotly_template, + ) + @classmethod def load_from_file(cls, config_file: str | Path): """Load configuration from YAML file and apply it. @@ -282,6 +365,9 @@ def _apply_config_dict(cls, config_dict: dict): elif key == 'modeling' and isinstance(value, dict): for nested_key, nested_value in value.items(): setattr(cls.Modeling, nested_key, nested_value) + elif key == 'plotting' and isinstance(value, dict): + for nested_key, nested_value in value.items(): + setattr(cls.Plotting, nested_key, nested_value) elif hasattr(cls, key): setattr(cls, key, value) @@ -319,6 +405,18 @@ def to_dict(cls) -> dict: 'epsilon': cls.Modeling.epsilon, 'big_binary_bound': cls.Modeling.big_binary_bound, }, + 'plotting': { + 'plotly_template': cls.Plotting.plotly_template, + 'default_show': cls.Plotting.default_show, + 'default_save_path': cls.Plotting.default_save_path, + 'default_engine': cls.Plotting.default_engine, + 'default_dpi': cls.Plotting.default_dpi, + 'default_figure_width': cls.Plotting.default_figure_width, + 'default_figure_height': cls.Plotting.default_figure_height, + 'default_facet_cols': cls.Plotting.default_facet_cols, + 'default_sequential_colorscale': cls.Plotting.default_sequential_colorscale, + 'default_qualitative_colorscale': cls.Plotting.default_qualitative_colorscale, + }, } @@ -588,6 +686,43 @@ def _setup_logging( logger.addHandler(logging.NullHandler()) +def _apply_plotting_config( + plotly_template: Literal[ + 'plotly', + 'plotly_white', + 'plotly_dark', + 'ggplot2', + 'seaborn', + 'simple_white', + 'none', + 'gridon', + 'presentation', + 'xgridoff', + 'ygridoff', + ] + | None = 'plotly', +) -> None: + """Apply plotting configuration to plotly. + + Args: + plotly_template: Plotly template/theme to apply to all plots. + + Note: + Configure backends via environment variables: + - Matplotlib: Set MPLBACKEND environment variable before importing matplotlib + - Plotly: Set PLOTLY_RENDERER or use plotly.io.renderers.default directly + """ + # Configure Plotly template + try: + import plotly.io as pio + + if plotly_template is not None: + pio.templates.default = plotly_template + except ImportError: + # Plotly not installed, skip configuration + pass + + def change_logging_level(level_name: Literal['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']): """Change the logging level for the flixopt logger and all its handlers. diff --git a/flixopt/plotting.py b/flixopt/plotting.py index bd1f3c2c4..452a53efc 100644 --- a/flixopt/plotting.py +++ b/flixopt/plotting.py @@ -8,7 +8,7 @@ Key Features: **Dual Backend Support**: Seamless switching between Plotly and Matplotlib **Energy System Focus**: Specialized plots for power flows, storage states, emissions - **Color Management**: Intelligent color processing and palette management + **Color Management**: Intelligent color processing with ColorProcessor for flexible coloring **Export Capabilities**: High-quality export for reports and publications **Integration Ready**: Designed for use with CalculationResults and standalone analysis @@ -42,6 +42,8 @@ import xarray as xr from plotly.exceptions import PlotlyError +from .config import CONFIG + if TYPE_CHECKING: import pyvis @@ -72,7 +74,7 @@ Color specifications can take several forms to accommodate different use cases: **Named Colormaps** (str): - - Standard colormaps: 'viridis', 'plasma', 'cividis', 'tab10', 'Set1' + - Standard colormaps: 'turbo', 'plasma', 'cividis', 'tab10', 'Set1' - Energy-focused: 'portland' (custom flixopt colormap for energy systems) - Backend-specific maps available in Plotly and Matplotlib @@ -89,7 +91,7 @@ Examples: ```python # Named colormap - colors = 'viridis' # Automatic color generation + colors = 'turbo' # Automatic color generation # Explicit color list colors = ['red', 'blue', 'green', '#FFD700'] @@ -136,57 +138,75 @@ class ColorProcessor: **Energy System Colors**: Built-in palettes optimized for energy system visualization Color Input Types: - - **Named Colormaps**: 'viridis', 'plasma', 'portland', 'tab10', etc. + - **Named Colormaps**: 'turbo', 'plasma', 'portland', etc. - **Color Lists**: ['red', 'blue', 'green'] or ['#FF0000', '#0000FF', '#00FF00'] - **Label Dictionaries**: {'Generator': 'red', 'Storage': 'blue', 'Load': 'green'} - Examples: - Basic color processing: - + Example: ```python - # Initialize for Plotly backend - processor = ColorProcessor(engine='plotly', default_colormap='viridis') - - # Process different color specifications + processor = ColorProcessor(engine='plotly', default_colorscale='turbo') colors = processor.process_colors('plasma', ['Gen1', 'Gen2', 'Storage']) - colors = processor.process_colors(['red', 'blue', 'green'], ['A', 'B', 'C']) - colors = processor.process_colors({'Wind': 'skyblue', 'Solar': 'gold'}, ['Wind', 'Solar', 'Gas']) - - # Switch to Matplotlib - processor = ColorProcessor(engine='matplotlib') - mpl_colors = processor.process_colors('tab10', component_labels) - ``` - - Energy system visualization: - - ```python - # Specialized energy system palette - energy_colors = { - 'Natural_Gas': '#8B4513', # Brown - 'Electricity': '#FFD700', # Gold - 'Heat': '#FF4500', # Red-orange - 'Cooling': '#87CEEB', # Sky blue - 'Hydrogen': '#E6E6FA', # Lavender - 'Battery': '#32CD32', # Lime green - } - - processor = ColorProcessor('plotly') - flow_colors = processor.process_colors(energy_colors, flow_labels) ``` Args: engine: Plotting backend ('plotly' or 'matplotlib'). Determines output color format. - default_colormap: Fallback colormap when requested palettes are unavailable. - Common options: 'viridis', 'plasma', 'tab10', 'portland'. + default_colorscale: Fallback colormap when requested palettes are unavailable. + Common options: 'turbo', 'plasma', 'portland'. """ - def __init__(self, engine: PlottingEngine = 'plotly', default_colormap: str = 'viridis'): + def __init__(self, engine: PlottingEngine = 'plotly', default_colorscale: str | None = None): """Initialize the color processor with specified backend and defaults.""" if engine not in ['plotly', 'matplotlib']: raise TypeError(f'engine must be "plotly" or "matplotlib", but is {engine}') self.engine = engine - self.default_colormap = default_colormap + self.default_colorscale = ( + default_colorscale if default_colorscale is not None else CONFIG.Plotting.default_qualitative_colorscale + ) + + def _get_sequential_colorscale(self, colormap_name: str, num_colors: int) -> list[str] | None: + try: + colorscale = px.colors.get_colorscale(colormap_name) + # Generate evenly spaced points + color_points = [i / (num_colors - 1) for i in range(num_colors)] if num_colors > 1 else [0] + return px.colors.sample_colorscale(colorscale, color_points) + except PlotlyError: + return None + + def _get_plotly_colormap_robust(self, colormap_name: str, num_colors: int) -> list[str]: + # First try qualitative color sequences (Dark24, Plotly, Set1, etc.) + colormap_title = colormap_name.title() + if hasattr(px.colors.qualitative, colormap_title): + color_list = getattr(px.colors.qualitative, colormap_title) + # Cycle through colors if we need more than available + return [color_list[i % len(color_list)] for i in range(num_colors)] + + # Then try sequential/continuous colorscales (turbo, plasma, etc.) + colors = self._get_sequential_colorscale(colormap_name, num_colors) + if colors is not None: + return colors + + # Fallback to default_colorscale + logger.warning(f"Colormap '{colormap_name}' not found in Plotly. Trying default '{self.default_colorscale}'") + + # Try default as qualitative + default_title = self.default_colorscale.title() + if hasattr(px.colors.qualitative, default_title): + color_list = getattr(px.colors.qualitative, default_title) + return [color_list[i % len(color_list)] for i in range(num_colors)] + + # Try default as sequential + colors = self._get_sequential_colorscale(self.default_colorscale, num_colors) + if colors is not None: + return colors + + # Ultimate fallback: use built-in Plotly qualitative colormap + logger.warning( + f"Both '{colormap_name}' and default '{self.default_colorscale}' not found. " + f"Using hardcoded fallback 'Plotly' colormap" + ) + color_list = px.colors.qualitative.Plotly + return [color_list[i % len(color_list)] for i in range(num_colors)] def _generate_colors_from_colormap(self, colormap_name: str, num_colors: int) -> list[Any]: """ @@ -200,22 +220,23 @@ def _generate_colors_from_colormap(self, colormap_name: str, num_colors: int) -> list of colors in the format appropriate for the engine """ if self.engine == 'plotly': - try: - colorscale = px.colors.get_colorscale(colormap_name) - except PlotlyError as e: - logger.error(f"Colorscale '{colormap_name}' not found in Plotly. Using {self.default_colormap}: {e}") - colorscale = px.colors.get_colorscale(self.default_colormap) - - # Generate evenly spaced points - color_points = [i / (num_colors - 1) for i in range(num_colors)] if num_colors > 1 else [0] - return px.colors.sample_colorscale(colorscale, color_points) + return self._get_plotly_colormap_robust(colormap_name, num_colors) else: # matplotlib try: cmap = plt.get_cmap(colormap_name, num_colors) except ValueError as e: - logger.error(f"Colormap '{colormap_name}' not found in Matplotlib. Using {self.default_colormap}: {e}") - cmap = plt.get_cmap(self.default_colormap, num_colors) + logger.warning( + f"Colormap '{colormap_name}' not found in Matplotlib. Trying default '{self.default_colorscale}': {e}" + ) + try: + cmap = plt.get_cmap(self.default_colorscale, num_colors) + except ValueError: + logger.warning( + f"Default colormap '{self.default_colorscale}' also not found in Matplotlib. " + f"Using hardcoded fallback 'tab10'" + ) + cmap = plt.get_cmap('tab10', num_colors) return [cmap(i) for i in range(num_colors)] @@ -231,8 +252,8 @@ def _handle_color_list(self, colors: list[str], num_labels: int) -> list[str]: list of colors matching the number of labels """ if len(colors) == 0: - logger.error(f'Empty color list provided. Using {self.default_colormap} instead.') - return self._generate_colors_from_colormap(self.default_colormap, num_labels) + logger.error(f'Empty color list provided. Using {self.default_colorscale} instead.') + return self._generate_colors_from_colormap(self.default_colorscale, num_labels) if len(colors) < num_labels: logger.warning( @@ -261,18 +282,18 @@ def _handle_color_dict(self, colors: dict[str, str], labels: list[str]) -> list[ list of colors in the same order as labels """ if len(colors) == 0: - logger.warning(f'Empty color dictionary provided. Using {self.default_colormap} instead.') - return self._generate_colors_from_colormap(self.default_colormap, len(labels)) + logger.warning(f'Empty color dictionary provided. Using {self.default_colorscale} instead.') + return self._generate_colors_from_colormap(self.default_colorscale, len(labels)) # Find missing labels missing_labels = sorted(set(labels) - set(colors.keys())) if missing_labels: logger.warning( - f'Some labels have no color specified: {missing_labels}. Using {self.default_colormap} for these.' + f'Some labels have no color specified: {missing_labels}. Using {self.default_colorscale} for these.' ) # Generate colors for missing labels - missing_colors = self._generate_colors_from_colormap(self.default_colormap, len(missing_labels)) + missing_colors = self._generate_colors_from_colormap(self.default_colorscale, len(missing_labels)) # Create a copy to avoid modifying the original colors_copy = colors.copy() @@ -315,9 +336,9 @@ def process_colors( color_list = self._handle_color_dict(colors, labels) else: logger.error( - f'Unsupported color specification type: {type(colors)}. Using {self.default_colormap} instead.' + f'Unsupported color specification type: {type(colors)}. Using {self.default_colorscale} instead.' ) - color_list = self._generate_colors_from_colormap(self.default_colormap, len(labels)) + color_list = self._generate_colors_from_colormap(self.default_colorscale, len(labels)) # Return either a list or a mapping if return_mapping: @@ -326,19 +347,361 @@ def process_colors( return color_list +class ElementColorResolver: + """Resolve colors for flow system elements with pattern matching and colorscale support. + + Works with any element type (components, buses, flows) that has a _variable_names attribute. + Handles pattern matching, colorscale detection/sampling, and variable expansion. + + This centralizes all color resolution logic in the plotting module, keeping it separate + from the CalculationResults class. + + Example: + ```python + resolver = ElementColorResolver(results.components) + colors = resolver.resolve({'Solar*': 'Oranges', 'Wind*': 'Blues'}) + # Returns: {'Solar1(Bus)|flow_rate': '#ff8c00', 'Solar2(Bus)|flow_rate': '#ff7700', ...} + ``` + """ + + def __init__( + self, + elements: dict, + default_colorscale: str | None = None, + engine: PlottingEngine = 'plotly', + ): + """Initialize resolver. + + Args: + elements: Dict of element_label → element object (must have _variable_names attribute) + default_colorscale: Default colorscale for unmapped elements + engine: Plotting engine for ColorProcessor ('plotly' or 'matplotlib') + """ + self.elements = elements + self.processor = ColorProcessor( + engine=engine, + default_colorscale=default_colorscale or CONFIG.Plotting.default_qualitative_colorscale, + ) + + def resolve( + self, + config: dict[str, str | list[str]] | str | pathlib.Path | None = None, + reset: bool = True, + existing_colors: dict[str, str] | None = None, + ) -> dict[str, str]: + """Resolve config to variable→color dict. + + Args: + config: Color configuration: + - dict: Component/pattern to color/colorscale mapping + - str/Path: Path to YAML file + - None: Use default colorscale for all elements + reset: If True, reset all existing colors. If False, merge with existing at variable level. + existing_colors: Existing variable→color dict (for variable-level merging when reset=False) + + Returns: + dict[str, str]: Complete variable→color mapping + + Examples: + ```python + # Direct assignment + resolver.resolve({'Boiler1': 'red'}) + + # Pattern with color + resolver.resolve({'Solar*': 'orange'}) + + # Pattern with colorscale + resolver.resolve({'Solar*': 'Oranges'}) + + # Family grouping + resolver.resolve({'oranges': ['Solar1', 'Solar2']}) + + # Merge mode (preserves existing) + resolver.resolve({'NewComp': 'blue'}, reset=False, existing_colors=existing) + ``` + """ + # Load from file if needed + if isinstance(config, (str, pathlib.Path)): + config = load_yaml_config(config) + + # Resolve element colors (with pattern matching) + element_colors = self._resolve_element_colors(config) + + # Expand to variables + variable_colors = self._expand_to_variables(element_colors) + + # Variable-level merge: preserve existing variable colors not in new mapping + if not reset and existing_colors: + # Start with existing, then update with new + merged = existing_colors.copy() + merged.update(variable_colors) + return merged + + return variable_colors + + def _resolve_element_colors( + self, + config: dict[str, str | list[str]] | None, + ) -> dict[str, str]: + """Resolve config to element→color dict with pattern matching. + + Args: + config: Configuration dict or None + + Returns: + dict[str, str]: Element name → color mapping + """ + import fnmatch + + element_names = list(self.elements.keys()) + element_colors = {} + + # If no config, use default colorscale for all elements + if config is None: + colors = self.processor._generate_colors_from_colormap( + self.processor.default_colorscale, len(element_names) + ) + return dict(zip(element_names, colors, strict=False)) + + # Process config entries + for key, value in config.items(): + if isinstance(value, str): + # Check if key is a pattern or direct element name + if '*' in key or '?' in key: + # Pattern matching + matched = [e for e in element_names if fnmatch.fnmatch(e, key)] + if is_colorscale(value): + # Sample colorscale for matched elements + colors = self.processor._generate_colors_from_colormap(value, len(matched)) + element_colors.update(zip(matched, colors, strict=False)) + else: + # Apply same color to all matched elements + for elem in matched: + element_colors[elem] = value + else: + # Direct element→color assignment + element_colors[key] = value + + elif isinstance(value, list): + # Family grouping: colorscale → [elements] + colors = self.processor._generate_colors_from_colormap(key, len(value)) + element_colors.update(zip(value, colors, strict=False)) + + # Fill in missing elements with default colorscale + missing = [e for e in element_names if e not in element_colors] + if missing: + colors = self.processor._generate_colors_from_colormap(self.processor.default_colorscale, len(missing)) + element_colors.update(zip(missing, colors, strict=False)) + + return element_colors + + def _expand_to_variables(self, element_colors: dict[str, str]) -> dict[str, str]: + """Map element colors to all their variables. + + Args: + element_colors: Element name → color mapping + + Returns: + dict[str, str]: Variable name → color mapping + """ + variable_colors = {} + for element_name, color in element_colors.items(): + if element_name in self.elements: + # Access _variable_names from element object (ComponentResults, BusResults, etc.) + variable_colors[element_name] = color + for var in self.elements[element_name]._variable_names: + variable_colors[var] = color + return variable_colors + + +def load_yaml_config(path: str | pathlib.Path) -> dict[str, str | list[str]]: + """Load YAML color configuration file. + + Args: + path: Path to YAML file + + Returns: + dict: Color configuration + + Raises: + FileNotFoundError: If file doesn't exist + ValueError: If file is not valid YAML dict + + Example: + ```python + # colors.yaml: + # Boiler1: red + # Solar*: Oranges + # oranges: + # - Solar1 + # - Solar2 + + config = load_yaml_config('colors.yaml') + ``` + """ + import yaml + + path = pathlib.Path(path) + if not path.exists(): + raise FileNotFoundError(f'Color configuration file not found: {path}') + + with open(path, encoding='utf-8') as f: + config = yaml.safe_load(f) + + if not isinstance(config, dict): + raise ValueError(f'Invalid config file structure. Expected dict, got {type(config).__name__}') + + return config + + +def is_colorscale(name: str) -> bool: + """Check if string is a colorscale name vs a direct color. + + Args: + name: Color or colorscale name + + Returns: + bool: True if it's a colorscale, False if it's a direct color + + Examples: + ```python + is_colorscale('#FF0000') # False (hex color) + is_colorscale('red') # False (CSS color) + is_colorscale('Oranges') # True (Plotly colorscale) + is_colorscale('viridis') # True (matplotlib colormap) + ``` + """ + # Direct color patterns + if name.startswith('#') or name.startswith('rgb'): + return False + + # Check if it's a known CSS color (common colors) + common_colors = { + 'red', + 'blue', + 'green', + 'yellow', + 'orange', + 'purple', + 'pink', + 'brown', + 'black', + 'white', + 'gray', + 'grey', + 'cyan', + 'magenta', + 'lime', + 'navy', + 'teal', + 'aqua', + 'maroon', + 'olive', + 'silver', + 'gold', + 'indigo', + 'violet', + } + if name.lower() in common_colors: + return False + + # Check Plotly colorscales + try: + import plotly.express as px + + if hasattr(px.colors.qualitative, name.title()) or hasattr(px.colors.sequential, name.title()): + return True + except Exception: + pass + + # Check matplotlib colorscales + try: + import matplotlib.pyplot as plt + + return name in plt.colormaps() + except Exception: + return False + + +def _ensure_dataset(data: xr.Dataset | pd.DataFrame) -> xr.Dataset: + """Convert DataFrame to Dataset if needed.""" + if isinstance(data, xr.Dataset): + return data + elif isinstance(data, pd.DataFrame): + # Convert DataFrame to Dataset + return data.to_xarray() + else: + raise TypeError(f'Data must be xr.Dataset or pd.DataFrame, got {type(data).__name__}') + + +def _validate_plotting_data(data: xr.Dataset, allow_empty: bool = False) -> None: + """Validate dataset for plotting (checks for empty data, non-numeric types, etc.).""" + # Check for empty data + if not allow_empty and len(data.data_vars) == 0: + raise ValueError('Empty Dataset provided (no variables). Cannot create plot.') + + # Check if dataset has any data (xarray uses nbytes for total size) + if all(data[var].size == 0 for var in data.data_vars) if len(data.data_vars) > 0 else True: + if not allow_empty and len(data.data_vars) > 0: + raise ValueError('Dataset has zero size. Cannot create plot.') + if len(data.data_vars) == 0: + return # Empty dataset, nothing to validate + return + + # Check for non-numeric data types + for var in data.data_vars: + dtype = data[var].dtype + if not np.issubdtype(dtype, np.number): + raise TypeError( + f"Variable '{var}' has non-numeric dtype '{dtype}'. " + f'Plotting requires numeric data types (int, float, etc.).' + ) + + # Warn about NaN/Inf values + for var in data.data_vars: + if data[var].isnull().any(): + logger.debug(f"Variable '{var}' contains NaN values which may affect visualization.") + if np.isinf(data[var].values).any(): + logger.debug(f"Variable '{var}' contains Inf values which may affect visualization.") + + +def resolve_colors( + data: xr.Dataset, + colors: ColorType, + engine: PlottingEngine = 'plotly', +) -> dict[str, str]: + """Resolve colors parameter to a dict mapping variable names to colors.""" + # Get variable names from Dataset (always strings and unique) + labels = list(data.data_vars.keys()) + + # If explicit dict provided, use it directly + if isinstance(colors, dict): + return colors + + # If string or list, use ColorProcessor (traditional behavior) + if isinstance(colors, (str, list)): + processor = ColorProcessor(engine=engine) + return processor.process_colors(colors, labels, return_mapping=True) + + raise TypeError(f'Wrong type passed to resolve_colors(): {type(colors)}') + + def with_plotly( - data: pd.DataFrame | xr.DataArray | xr.Dataset, + data: xr.Dataset | pd.DataFrame, mode: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar', - colors: ColorType = 'viridis', + colors: ColorType | None = None, title: str = '', ylabel: str = '', - xlabel: str = 'Time in h', + xlabel: str = '', fig: go.Figure | None = None, facet_by: str | list[str] | None = None, animate_by: str | None = None, - facet_cols: int = 3, + facet_cols: int | None = None, shared_yaxes: bool = True, shared_xaxes: bool = True, + trace_kwargs: dict[str, Any] | None = None, + layout_kwargs: dict[str, Any] | None = None, + **px_kwargs: Any, ) -> go.Figure: """ Plot data with Plotly using facets (subplots) and/or animation for multidimensional data. @@ -347,10 +710,13 @@ def with_plotly( For simple plots without faceting, can optionally add to an existing figure. Args: - data: A DataFrame or xarray DataArray/Dataset to plot. + data: An xarray Dataset to plot. mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for lines, 'area' for stacked area charts, or 'grouped_bar' for grouped bar charts. - colors: Color specification (colormap, list, or dict mapping labels to colors). + colors: Color specification. Can be: + - A colormap name (e.g., 'turbo', 'plasma') + - A list of color strings (e.g., ['#ff0000', '#00ff00']) + - A dict mapping labels to colors (e.g., {'Solar': '#FFD700'}) title: The main title of the plot. ylabel: The label for the y-axis. xlabel: The label for the x-axis. @@ -364,6 +730,12 @@ def with_plotly( facet_cols: Number of columns in the facet grid (used when facet_by is single dimension). shared_yaxes: Whether subplots share y-axes. shared_xaxes: Whether subplots share x-axes. + trace_kwargs: Optional dict of parameters to pass to fig.update_traces(). + Use this to customize trace properties (e.g., marker style, line width). + layout_kwargs: Optional dict of parameters to pass to fig.update_layout(). + Use this to customize layout properties (e.g., width, height, legend position). + **px_kwargs: Additional keyword arguments passed to the underlying Plotly Express function + (px.bar, px.line, px.area). These override default arguments if provided. Returns: A Plotly figure object containing the faceted/animated plot. @@ -372,85 +744,104 @@ def with_plotly( Simple plot: ```python - fig = with_plotly(df, mode='area', title='Energy Mix') + fig = with_plotly(dataset, mode='area', title='Energy Mix') ``` Facet by scenario: ```python - fig = with_plotly(ds, facet_by='scenario', facet_cols=2) + fig = with_plotly(dataset, facet_by='scenario', facet_cols=2) ``` Animate by period: ```python - fig = with_plotly(ds, animate_by='period') + fig = with_plotly(dataset, animate_by='period') ``` Facet and animate: ```python - fig = with_plotly(ds, facet_by='scenario', animate_by='period') + fig = with_plotly(dataset, facet_by='scenario', animate_by='period') + ``` + + Custom color mapping: + + ```python + colors = {'Solar': 'orange', 'Wind': 'blue', 'Battery': 'green', 'Gas': 'red'} + fig = with_plotly(dataset, colors=colors, mode='area') ``` """ + if colors is None: + colors = CONFIG.Plotting.default_qualitative_colorscale + if mode not in ('stacked_bar', 'line', 'area', 'grouped_bar'): raise ValueError(f"'mode' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {mode!r}") + # Apply CONFIG defaults if not explicitly set + if facet_cols is None: + facet_cols = CONFIG.Plotting.default_facet_cols + + # Ensure data is a Dataset and validate it + data = _ensure_dataset(data) + _validate_plotting_data(data, allow_empty=True) + # Handle empty data - if isinstance(data, pd.DataFrame) and data.empty: - return go.Figure() - elif isinstance(data, xr.DataArray) and data.size == 0: - return go.Figure() - elif isinstance(data, xr.Dataset) and len(data.data_vars) == 0: + if len(data.data_vars) == 0: + logger.error('"with_plotly() got an empty Dataset.') return go.Figure() + # Handle all-scalar datasets (where all variables have no dimensions) + # This occurs when all variables are scalar values with dims=() + if all(len(data[var].dims) == 0 for var in data.data_vars): + # Create a simple DataFrame with variable names as x-axis + variables = list(data.data_vars.keys()) + values = [float(data[var].values) for var in data.data_vars] + + # Resolve colors + color_discrete_map = resolve_colors(data, colors, engine='plotly') + marker_colors = [color_discrete_map.get(var, '#636EFA') for var in variables] + + # Create simple plot based on mode using go (not px) for better color control + if mode in ('stacked_bar', 'grouped_bar'): + fig = go.Figure(data=[go.Bar(x=variables, y=values, marker_color=marker_colors)]) + elif mode == 'line': + fig = go.Figure( + data=[ + go.Scatter( + x=variables, + y=values, + mode='lines+markers', + marker=dict(color=marker_colors, size=8), + line=dict(color='lightgray'), + ) + ] + ) + elif mode == 'area': + fig = go.Figure( + data=[ + go.Scatter( + x=variables, + y=values, + fill='tozeroy', + marker=dict(color=marker_colors, size=8), + line=dict(color='lightgray'), + ) + ] + ) + + fig.update_layout(title=title, xaxis_title=xlabel, yaxis_title=ylabel, showlegend=False) + return fig + # Warn if fig parameter is used with faceting if fig is not None and (facet_by is not None or animate_by is not None): logger.warning('The fig parameter is ignored when using faceting or animation. Creating a new figure.') fig = None - # Convert xarray to long-form DataFrame for Plotly Express - if isinstance(data, (xr.DataArray, xr.Dataset)): - # Convert to long-form (tidy) DataFrame - # Structure: time, variable, value, scenario, period, ... (all dims as columns) - if isinstance(data, xr.Dataset): - # Stack all data variables into long format - df_long = data.to_dataframe().reset_index() - # Melt to get: time, scenario, period, ..., variable, value - id_vars = [dim for dim in data.dims] - value_vars = list(data.data_vars) - df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value') - else: - # DataArray - df_long = data.to_dataframe().reset_index() - if data.name: - df_long = df_long.rename(columns={data.name: 'value'}) - else: - # Unnamed DataArray, find the value column - non_dim_cols = [col for col in df_long.columns if col not in data.dims] - if len(non_dim_cols) != 1: - raise ValueError( - f'Expected exactly one non-dimension column for unnamed DataArray, ' - f'but found {len(non_dim_cols)}: {non_dim_cols}' - ) - value_col = non_dim_cols[0] - df_long = df_long.rename(columns={value_col: 'value'}) - df_long['variable'] = data.name or 'data' - else: - # Already a DataFrame - convert to long format for Plotly Express - df_long = data.reset_index() - if 'time' not in df_long.columns: - # First column is probably time - df_long = df_long.rename(columns={df_long.columns[0]: 'time'}) - # Melt to long format - id_vars = [ - col - for col in df_long.columns - if col in ['time', 'scenario', 'period'] - or col in (facet_by if isinstance(facet_by, list) else [facet_by] if facet_by else []) - ] - value_vars = [col for col in df_long.columns if col not in id_vars] - df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value') + # Convert Dataset to long-form DataFrame for Plotly Express + # Structure: time, variable, value, scenario, period, ... (all dims as columns) + dim_names = list(data.dims) + df_long = data.to_dataframe().reset_index().melt(id_vars=dim_names, var_name='variable', value_name='value') # Validate facet_by and animate_by dimensions exist in the data available_dims = [col for col in df_long.columns if col not in ['variable', 'value']] @@ -500,15 +891,38 @@ def with_plotly( else: raise ValueError(f'facet_by can have at most 2 dimensions, got {len(facet_by)}') - # Process colors + # Process colors using resolve_colors (handles validation and all color types) + color_discrete_map = resolve_colors(data, colors, engine='plotly') + + # Get unique variable names for area plot processing all_vars = df_long['variable'].unique().tolist() - processed_colors = ColorProcessor(engine='plotly').process_colors(colors, all_vars) - color_discrete_map = {var: color for var, color in zip(all_vars, processed_colors, strict=True)} + + # Determine which dimension to use for x-axis + # Collect dimensions used for faceting and animation + used_dims = set() + if facet_row: + used_dims.add(facet_row) + if facet_col: + used_dims.add(facet_col) + if animate_by: + used_dims.add(animate_by) + + # Find available dimensions for x-axis (not used for faceting/animation) + x_candidates = [d for d in available_dims if d not in used_dims] + + # Use 'time' if available, otherwise use the first available dimension + if 'time' in x_candidates: + x_dim = 'time' + elif len(x_candidates) > 0: + x_dim = x_candidates[0] + else: + # Fallback: use the first dimension (shouldn't happen in normal cases) + x_dim = available_dims[0] if available_dims else 'time' # Create plot using Plotly Express based on mode common_args = { 'data_frame': df_long, - 'x': 'time', + 'x': x_dim, 'y': 'value', 'color': 'variable', 'facet_row': facet_row, @@ -516,13 +930,16 @@ def with_plotly( 'animation_frame': animate_by, 'color_discrete_map': color_discrete_map, 'title': title, - 'labels': {'value': ylabel, 'time': xlabel, 'variable': ''}, + 'labels': {'value': ylabel, x_dim: xlabel, 'variable': ''}, } # Add facet_col_wrap for single facet dimension if facet_col and not facet_row: common_args['facet_col_wrap'] = facet_cols + # Apply user-provided Plotly Express kwargs (overrides defaults) + common_args.update(px_kwargs) + if mode == 'stacked_bar': fig = px.bar(**common_args) fig.update_traces(marker_line_width=0) @@ -577,50 +994,48 @@ def with_plotly( if hasattr(trace, 'fill'): trace.fill = None - # Update layout with basic styling (Plotly Express handles sizing automatically) - fig.update_layout( - plot_bgcolor='rgba(0,0,0,0)', - paper_bgcolor='rgba(0,0,0,0)', - font=dict(size=12), - ) - # Update axes to share if requested (Plotly Express already handles this, but we can customize) if not shared_yaxes: fig.update_yaxes(matches=None) if not shared_xaxes: fig.update_xaxes(matches=None) + # Apply user-provided trace and layout customizations + if trace_kwargs: + fig.update_traces(**trace_kwargs) + if layout_kwargs: + fig.update_layout(**layout_kwargs) + return fig def with_matplotlib( - data: pd.DataFrame, + data: xr.Dataset | pd.DataFrame, mode: Literal['stacked_bar', 'line'] = 'stacked_bar', - colors: ColorType = 'viridis', + colors: ColorType | None = None, title: str = '', ylabel: str = '', xlabel: str = 'Time in h', figsize: tuple[int, int] = (12, 6), - fig: plt.Figure | None = None, - ax: plt.Axes | None = None, + plot_kwargs: dict[str, Any] | None = None, ) -> tuple[plt.Figure, plt.Axes]: """ - Plot a DataFrame with Matplotlib using stacked bars or stepped lines. + Plot data with Matplotlib using stacked bars or stepped lines. Args: - data: A DataFrame containing the data to plot. The index should represent time (e.g., hours), - and each column represents a separate data series. + data: An xarray Dataset to plot. After conversion to DataFrame, + the index represents time and each column represents a separate data series (variables). mode: Plotting mode. Use 'stacked_bar' for stacked bar charts or 'line' for stepped lines. - colors: Color specification, can be: - - A string with a colormap name (e.g., 'viridis', 'plasma') + colors: Color specification. Can be: + - A colormap name (e.g., 'turbo', 'plasma') - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'}) + - A dict mapping column names to colors (e.g., {'Column1': '#ff0000'}) title: The title of the plot. ylabel: The ylabel of the plot. xlabel: The xlabel of the plot. - figsize: Specify the size of the figure - fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created. - ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created. + figsize: Specify the size of the figure (width, height) in inches. + plot_kwargs: Optional dict of parameters to pass to ax.bar() or ax.step() plotting calls. + Use this to customize plot properties (e.g., linewidth, alpha, edgecolor). Returns: A tuple containing the Matplotlib figure and axes objects used for the plot. @@ -629,49 +1044,115 @@ def with_matplotlib( - If `mode` is 'stacked_bar', bars are stacked for both positive and negative values. Negative values are stacked separately without extra labels in the legend. - If `mode` is 'line', stepped lines are drawn for each data series. + + Examples: + Custom color mapping: + + ```python + colors = {'Solar': 'orange', 'Wind': 'blue', 'Coal': 'red'} + fig, ax = with_matplotlib(dataset, colors=colors, mode='line') + ``` """ + if colors is None: + colors = CONFIG.Plotting.default_qualitative_colorscale + if mode not in ('stacked_bar', 'line'): raise ValueError(f"'mode' must be one of {{'stacked_bar','line'}} for matplotlib, got {mode!r}") - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + # Ensure data is a Dataset and validate it + data = _ensure_dataset(data) + _validate_plotting_data(data, allow_empty=True) + + # Create new figure and axes + fig, ax = plt.subplots(figsize=figsize) + + # Initialize plot_kwargs if not provided + if plot_kwargs is None: + plot_kwargs = {} + + # Handle all-scalar datasets (where all variables have no dimensions) + # This occurs when all variables are scalar values with dims=() + if all(len(data[var].dims) == 0 for var in data.data_vars): + # Create simple bar/line plot with variable names as x-axis + variables = list(data.data_vars.keys()) + values = [float(data[var].values) for var in data.data_vars] + + # Resolve colors + color_discrete_map = resolve_colors(data, colors, engine='matplotlib') + colors_list = [color_discrete_map.get(var, '#808080') for var in variables] + + # Create plot based on mode + if mode == 'stacked_bar': + ax.bar(variables, values, color=colors_list, **plot_kwargs) + elif mode == 'line': + ax.plot( + variables, + values, + marker='o', + color=colors_list[0] if len(set(colors_list)) == 1 else None, + **plot_kwargs, + ) + # If different colors, plot each point separately + if len(set(colors_list)) > 1: + ax.clear() + for i, (var, val) in enumerate(zip(variables, values, strict=False)): + ax.plot([i], [val], marker='o', color=colors_list[i], label=var, **plot_kwargs) + ax.set_xticks(range(len(variables))) + ax.set_xticklabels(variables) + + ax.set_xlabel(xlabel, ha='center') + ax.set_ylabel(ylabel, va='center') + ax.set_title(title) + ax.grid(color='lightgrey', linestyle='-', linewidth=0.5, axis='y') + fig.tight_layout() - processed_colors = ColorProcessor(engine='matplotlib').process_colors(colors, list(data.columns)) + return fig, ax + + # Resolve colors first (includes validation) + color_discrete_map = resolve_colors(data, colors, engine='matplotlib') + + # Convert Dataset to DataFrame for matplotlib plotting (naturally wide-form) + df = data.to_dataframe() + + # Get colors in column order + processed_colors = [color_discrete_map.get(str(col), '#808080') for col in df.columns] if mode == 'stacked_bar': - cumulative_positive = np.zeros(len(data)) - cumulative_negative = np.zeros(len(data)) - width = data.index.to_series().diff().dropna().min() # Minimum time difference + cumulative_positive = np.zeros(len(df)) + cumulative_negative = np.zeros(len(df)) + width = df.index.to_series().diff().dropna().min() # Minimum time difference - for i, column in enumerate(data.columns): - positive_values = np.clip(data[column], 0, None) # Keep only positive values - negative_values = np.clip(data[column], None, 0) # Keep only negative values + for i, column in enumerate(df.columns): + positive_values = np.clip(df[column], 0, None) # Keep only positive values + negative_values = np.clip(df[column], None, 0) # Keep only negative values # Plot positive bars ax.bar( - data.index, + df.index, positive_values, bottom=cumulative_positive, color=processed_colors[i], label=column, width=width, align='center', + **plot_kwargs, ) cumulative_positive += positive_values.values # Plot negative bars ax.bar( - data.index, + df.index, negative_values, bottom=cumulative_negative, color=processed_colors[i], label='', # No label for negative bars width=width, align='center', + **plot_kwargs, ) cumulative_negative += negative_values.values elif mode == 'line': - for i, column in enumerate(data.columns): - ax.step(data.index, data[column], where='post', color=processed_colors[i], label=column) + for i, column in enumerate(df.columns): + ax.step(df.index, df[column], where='post', color=processed_colors[i], label=column, **plot_kwargs) # Aesthetics ax.set_xlabel(xlabel, ha='center') @@ -945,62 +1426,93 @@ def plot_network( def pie_with_plotly( - data: pd.DataFrame, - colors: ColorType = 'viridis', + data: xr.Dataset | pd.DataFrame, + colors: ColorType | None = None, title: str = '', legend_title: str = '', hole: float = 0.0, fig: go.Figure | None = None, + hover_template: str = '%{label}: %{value} (%{percent})', + text_info: str = 'percent+label+value', + text_position: str = 'inside', ) -> go.Figure: """ - Create a pie chart with Plotly to visualize the proportion of values in a DataFrame. + Create a pie chart with Plotly to visualize the proportion of values in a Dataset. Args: - data: A DataFrame containing the data to plot. If multiple rows exist, - they will be summed unless a specific index value is passed. + data: An xarray Dataset containing the data to plot. All dimensions will be summed + to get the total for each variable. colors: Color specification, can be: - - A string with a colorscale name (e.g., 'viridis', 'plasma') + - A string with a colorscale name (e.g., 'turbo', 'plasma') - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'}) + - A dictionary mapping variable names to colors (e.g., {'Solar': '#ff0000'}) title: The title of the plot. legend_title: The title for the legend. hole: Size of the hole in the center for creating a donut chart (0.0 to 1.0). fig: A Plotly figure object to plot on. If not provided, a new figure will be created. + hover_template: Template for hover text. Use %{label}, %{value}, %{percent}. + text_info: What to show on pie segments: 'label', 'percent', 'value', 'label+percent', + 'label+value', 'percent+value', 'label+percent+value', or 'none'. + text_position: Position of text: 'inside', 'outside', 'auto', or 'none'. Returns: A Plotly figure object containing the generated pie chart. Notes: - Negative values are not appropriate for pie charts and will be converted to absolute values with a warning. - - If the data contains very small values (less than 1% of the total), they can be grouped into an "Other" category - for better readability. - - By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing. + - All dimensions are summed to get total values for each variable. + - Scalar variables (with no dimensions) are used directly. + + Examples: + Simple pie chart: + + ```python + fig = pie_with_plotly(dataset, colors='turbo', title='Energy Mix') + ``` + + Custom color mapping: + ```python + colors = {'Solar': 'orange', 'Wind': 'blue', 'Coal': 'red'} + fig = pie_with_plotly(dataset, colors=colors, title='Renewable Energy') + ``` """ - if data.empty: - logger.error('Empty DataFrame provided for pie chart. Returning empty figure.') + if colors is None: + colors = CONFIG.Plotting.default_qualitative_colorscale + + # Ensure data is a Dataset and validate it + data = _ensure_dataset(data) + _validate_plotting_data(data, allow_empty=True) + + if len(data.data_vars) == 0: + logger.error('Empty Dataset provided for pie chart. Returning empty figure.') return go.Figure() - # Create a copy to avoid modifying the original DataFrame - data_copy = data.copy() + # Sum all dimensions for each variable to get total values + labels = [] + values = [] - # Check if any negative values and warn - if (data_copy < 0).any().any(): - logger.error('Negative values detected in data. Using absolute values for pie chart.') - data_copy = data_copy.abs() + for var in data.data_vars: + var_data = data[var] - # If data has multiple rows, sum them to get total for each column - if len(data_copy) > 1: - data_sum = data_copy.sum() - else: - data_sum = data_copy.iloc[0] + # Sum across all dimensions to get total + if len(var_data.dims) > 0: + total_value = var_data.sum().item() + else: + # Scalar variable + total_value = var_data.item() - # Get labels (column names) and values - labels = data_sum.index.tolist() - values = data_sum.values.tolist() + # Check for negative values + if total_value < 0: + logger.warning(f'Negative value detected for {var}: {total_value}. Using absolute value.') + total_value = abs(total_value) - # Apply color mapping using the unified color processor - processed_colors = ColorProcessor(engine='plotly').process_colors(colors, labels) + labels.append(str(var)) + values.append(total_value) + + # Use resolve_colors for consistent color handling + color_discrete_map = resolve_colors(data, colors, engine='plotly') + processed_colors = [color_discrete_map.get(label, '#636EFA') for label in labels] # Create figure if not provided fig = fig if fig is not None else go.Figure() @@ -1012,91 +1524,107 @@ def pie_with_plotly( values=values, hole=hole, marker=dict(colors=processed_colors), - textinfo='percent+label+value', - textposition='inside', + textinfo=text_info, + textposition=text_position, insidetextorientation='radial', + hovertemplate=hover_template, ) ) - # Update layout for better aesthetics + # Update layout with plot-specific properties fig.update_layout( title=title, legend_title=legend_title, - plot_bgcolor='rgba(0,0,0,0)', # Transparent background - paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background - font=dict(size=14), # Increase font size for better readability ) return fig def pie_with_matplotlib( - data: pd.DataFrame, - colors: ColorType = 'viridis', + data: xr.Dataset | pd.DataFrame, + colors: ColorType | None = None, title: str = '', legend_title: str = 'Categories', hole: float = 0.0, figsize: tuple[int, int] = (10, 8), - fig: plt.Figure | None = None, - ax: plt.Axes | None = None, ) -> tuple[plt.Figure, plt.Axes]: """ - Create a pie chart with Matplotlib to visualize the proportion of values in a DataFrame. + Create a pie chart with Matplotlib to visualize the proportion of values in a Dataset. Args: - data: A DataFrame containing the data to plot. If multiple rows exist, - they will be summed unless a specific index value is passed. + data: An xarray Dataset containing the data to plot. All dimensions will be summed + to get the total for each variable. colors: Color specification, can be: - - A string with a colormap name (e.g., 'viridis', 'plasma') + - A string with a colormap name (e.g., 'turbo', 'plasma') - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'}) + - A dictionary mapping variable names to colors (e.g., {'Solar': '#ff0000'}) title: The title of the plot. legend_title: The title for the legend. hole: Size of the hole in the center for creating a donut chart (0.0 to 1.0). figsize: The size of the figure (width, height) in inches. - fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created. - ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created. Returns: A tuple containing the Matplotlib figure and axes objects used for the plot. Notes: - Negative values are not appropriate for pie charts and will be converted to absolute values with a warning. - - If the data contains very small values (less than 1% of the total), they can be grouped into an "Other" category - for better readability. - - By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing. + - All dimensions are summed to get total values for each variable. + - Scalar variables (with no dimensions) are used directly. + + Examples: + Simple pie chart: + ```python + fig, ax = pie_with_matplotlib(dataset, colors='turbo', title='Energy Mix') + ``` + + Custom color mapping: + + ```python + colors = {'Solar': 'orange', 'Wind': 'blue', 'Coal': 'red'} + fig, ax = pie_with_matplotlib(dataset, colors=colors, title='Renewable Energy') + ``` """ - if data.empty: - logger.error('Empty DataFrame provided for pie chart. Returning empty figure.') - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + if colors is None: + colors = CONFIG.Plotting.default_qualitative_colorscale + + # Ensure data is a Dataset and validate it + data = _ensure_dataset(data) + _validate_plotting_data(data, allow_empty=True) + + if len(data.data_vars) == 0: + logger.error('Empty Dataset provided for pie chart. Returning empty figure.') + fig, ax = plt.subplots(figsize=figsize) return fig, ax - # Create a copy to avoid modifying the original DataFrame - data_copy = data.copy() + # Sum all dimensions for each variable to get total values + labels = [] + values = [] - # Check if any negative values and warn - if (data_copy < 0).any().any(): - logger.error('Negative values detected in data. Using absolute values for pie chart.') - data_copy = data_copy.abs() + for var in data.data_vars: + var_data = data[var] - # If data has multiple rows, sum them to get total for each column - if len(data_copy) > 1: - data_sum = data_copy.sum() - else: - data_sum = data_copy.iloc[0] + # Sum across all dimensions to get total + if len(var_data.dims) > 0: + total_value = var_data.sum().item() + else: + # Scalar variable + total_value = var_data.item() - # Get labels (column names) and values - labels = data_sum.index.tolist() - values = data_sum.values.tolist() + # Check for negative values + if total_value < 0: + logger.warning(f'Negative value detected for {var}: {total_value}. Using absolute value.') + total_value = abs(total_value) - # Apply color mapping using the unified color processor - processed_colors = ColorProcessor(engine='matplotlib').process_colors(colors, labels) + labels.append(str(var)) + values.append(total_value) - # Create figure and axis if not provided - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + # Use resolve_colors for consistent color handling + color_discrete_map = resolve_colors(data, colors, engine='matplotlib') + processed_colors = [color_discrete_map.get(label, '#808080') for label in labels] + + # Create figure and axis + fig, ax = plt.subplots(figsize=figsize) # Draw the pie chart wedges, texts, autotexts = ax.pie( @@ -1144,9 +1672,9 @@ def pie_with_matplotlib( def dual_pie_with_plotly( - data_left: pd.Series, - data_right: pd.Series, - colors: ColorType = 'viridis', + data_left: xr.Dataset | pd.DataFrame, + data_right: xr.Dataset | pd.DataFrame, + colors: ColorType | None = None, title: str = '', subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'), legend_title: str = '', @@ -1160,12 +1688,12 @@ def dual_pie_with_plotly( Create two pie charts side by side with Plotly, with consistent coloring across both charts. Args: - data_left: Series for the left pie chart. - data_right: Series for the right pie chart. + data_left: Dataset for the left pie chart. Variables are summed across all dimensions. + data_right: Dataset for the right pie chart. Variables are summed across all dimensions. colors: Color specification, can be: - - A string with a colorscale name (e.g., 'viridis', 'plasma') + - A string with a colorscale name (e.g., 'turbo', 'plasma') - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping category names to colors (e.g., {'Category1': '#ff0000'}) + - A dictionary mapping variable names to colors (e.g., {'Solar': '#ff0000'}) title: The main title of the plot. subtitles: Tuple containing the subtitles for (left, right) charts. legend_title: The title for the legend. @@ -1179,10 +1707,19 @@ def dual_pie_with_plotly( Returns: A Plotly figure object containing the generated dual pie chart. """ + if colors is None: + colors = CONFIG.Plotting.default_qualitative_colorscale + from plotly.subplots import make_subplots + # Ensure data is a Dataset and validate it + data_left = _ensure_dataset(data_left) + data_right = _ensure_dataset(data_right) + _validate_plotting_data(data_left, allow_empty=True) + _validate_plotting_data(data_right, allow_empty=True) + # Check for empty data - if data_left.empty and data_right.empty: + if len(data_left.data_vars) == 0 and len(data_right.data_vars) == 0: logger.error('Both datasets are empty. Returning empty figure.') return go.Figure() @@ -1191,71 +1728,52 @@ def dual_pie_with_plotly( rows=1, cols=2, specs=[[{'type': 'pie'}, {'type': 'pie'}]], subplot_titles=subtitles, horizontal_spacing=0.05 ) - # Process series to handle negative values and apply minimum percentage threshold - def preprocess_series(series: pd.Series): - """ - Preprocess a series for pie chart display by handling negative values - and grouping the smallest parts together if they collectively represent - less than the specified percentage threshold. - - Args: - series: The series to preprocess + # Helper function to extract labels and values from Dataset + def dataset_to_pie_data(dataset): + labels = [] + values = [] - Returns: - A preprocessed pandas Series - """ - # Handle negative values - if (series < 0).any(): - logger.error('Negative values detected in data. Using absolute values for pie chart.') - series = series.abs() - - # Remove zeros - series = series[series > 0] + for var in dataset.data_vars: + var_data = dataset[var] - # Apply minimum percentage threshold if needed - if lower_percentage_group and not series.empty: - total = series.sum() - if total > 0: - # Sort series by value (ascending) - sorted_series = series.sort_values() - - # Calculate cumulative percentage contribution - cumulative_percent = (sorted_series.cumsum() / total) * 100 - - # Find entries that collectively make up less than lower_percentage_group - to_group = cumulative_percent <= lower_percentage_group - - if to_group.sum() > 1: - # Create "Other" category for the smallest values that together are < threshold - other_sum = sorted_series[to_group].sum() - - # Keep only values that aren't in the "Other" group - result_series = series[~series.index.isin(sorted_series[to_group].index)] + # Sum across all dimensions + if len(var_data.dims) > 0: + total_value = float(var_data.sum().values) + else: + total_value = float(var_data.values) - # Add the "Other" category if it has a value - if other_sum > 0: - result_series['Other'] = other_sum + # Handle negative values + if total_value < 0: + logger.warning(f'Negative value for {var}: {total_value}. Using absolute value.') + total_value = abs(total_value) - return result_series + # Only include if value > 0 + if total_value > 0: + labels.append(str(var)) + values.append(total_value) - return series + return labels, values - data_left_processed = preprocess_series(data_left) - data_right_processed = preprocess_series(data_right) + # Get data for left and right + left_labels, left_values = dataset_to_pie_data(data_left) + right_labels, right_values = dataset_to_pie_data(data_right) - # Get unique set of all labels for consistent coloring - all_labels = sorted(set(data_left_processed.index) | set(data_right_processed.index)) + # Get unique set of all labels for consistent coloring across both pies + # Merge both datasets for color resolution + combined_vars = list(set(data_left.data_vars) | set(data_right.data_vars)) + combined_ds = xr.Dataset( + {var: data_left[var] if var in data_left.data_vars else data_right[var] for var in combined_vars} + ) - # Get consistent color mapping for both charts using our unified function - color_map = ColorProcessor(engine='plotly').process_colors(colors, all_labels, return_mapping=True) + # Use resolve_colors for consistent color handling + color_discrete_map = resolve_colors(combined_ds, colors, engine='plotly') + color_map = {label: color_discrete_map.get(label, '#636EFA') for label in left_labels + right_labels} # Function to create a pie trace with consistently mapped colors - def create_pie_trace(data_series, side): - if data_series.empty: + def create_pie_trace(labels, values, side): + if not labels: return None - labels = data_series.index.tolist() - values = data_series.values.tolist() trace_colors = [color_map[label] for label in labels] return go.Pie( @@ -1272,24 +1790,21 @@ def create_pie_trace(data_series, side): ) # Add left pie if data exists - left_trace = create_pie_trace(data_left_processed, subtitles[0]) + left_trace = create_pie_trace(left_labels, left_values, subtitles[0]) if left_trace: left_trace.domain = dict(x=[0, 0.48]) fig.add_trace(left_trace, row=1, col=1) # Add right pie if data exists - right_trace = create_pie_trace(data_right_processed, subtitles[1]) + right_trace = create_pie_trace(right_labels, right_values, subtitles[1]) if right_trace: right_trace.domain = dict(x=[0.52, 1]) fig.add_trace(right_trace, row=1, col=2) - # Update layout + # Update layout with plot-specific properties fig.update_layout( title=title, legend_title=legend_title, - plot_bgcolor='rgba(0,0,0,0)', # Transparent background - paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background - font=dict(size=14), margin=dict(t=80, b=50, l=30, r=30), ) @@ -1299,25 +1814,22 @@ def create_pie_trace(data_series, side): def dual_pie_with_matplotlib( data_left: pd.Series, data_right: pd.Series, - colors: ColorType = 'viridis', + colors: ColorType | None = None, title: str = '', subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'), legend_title: str = '', hole: float = 0.2, lower_percentage_group: float = 5.0, figsize: tuple[int, int] = (14, 7), - fig: plt.Figure | None = None, - axes: list[plt.Axes] | None = None, ) -> tuple[plt.Figure, list[plt.Axes]]: """ Create two pie charts side by side with Matplotlib, with consistent coloring across both charts. - Leverages the existing pie_with_matplotlib function. Args: data_left: Series for the left pie chart. data_right: Series for the right pie chart. colors: Color specification, can be: - - A string with a colormap name (e.g., 'viridis', 'plasma') + - A string with a colormap name (e.g., 'turbo', 'plasma') - A list of color strings (e.g., ['#ff0000', '#00ff00']) - A dictionary mapping category names to colors (e.g., {'Category1': '#ff0000'}) title: The main title of the plot. @@ -1326,23 +1838,21 @@ def dual_pie_with_matplotlib( hole: Size of the hole in the center for creating donut charts (0.0 to 1.0). lower_percentage_group: Whether to group small segments (below percentage) into an "Other" category. figsize: The size of the figure (width, height) in inches. - fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created. - axes: A list of Matplotlib axes objects to plot on. If not provided, new axes will be created. Returns: A tuple containing the Matplotlib figure and list of axes objects used for the plot. """ + if colors is None: + colors = CONFIG.Plotting.default_qualitative_colorscale + + # Create figure and axes + fig, axes = plt.subplots(1, 2, figsize=figsize) + # Check for empty data if data_left.empty and data_right.empty: logger.error('Both datasets are empty. Returning empty figure.') - if fig is None: - fig, axes = plt.subplots(1, 2, figsize=figsize) return fig, axes - # Create figure and axes if not provided - if fig is None or axes is None: - fig, axes = plt.subplots(1, 2, figsize=figsize) - # Process series to handle negative values and apply minimum percentage threshold def preprocess_series(series: pd.Series): """ @@ -1404,19 +1914,49 @@ def preprocess_series(series: pd.Series): left_colors = [color_map[col] for col in df_left.columns] if not df_left.empty else [] right_colors = [color_map[col] for col in df_right.columns] if not df_right.empty else [] + # Helper function to draw pie chart on a specific axis + def draw_pie_on_axis(ax, data_series, colors_list, subtitle, hole_size): + """Draw a pie chart on a specific matplotlib axis.""" + if data_series.empty: + ax.set_title(subtitle) + ax.axis('off') + return + + labels = list(data_series.index) + values = list(data_series.values) + + # Draw the pie chart + wedges, texts, autotexts = ax.pie( + values, + labels=labels, + colors=colors_list, + autopct='%1.1f%%', + startangle=90, + shadow=False, + wedgeprops=dict(width=0.5) if hole_size > 0 else None, + ) + + # Adjust hole size + if hole_size > 0: + wedge_width = 1 - hole_size + for wedge in wedges: + wedge.set_width(wedge_width) + + # Customize text + for autotext in autotexts: + autotext.set_fontsize(10) + autotext.set_color('white') + + # Set aspect ratio and title + ax.set_aspect('equal') + if subtitle: + ax.set_title(subtitle, fontsize=14) + # Create left pie chart - if not df_left.empty: - pie_with_matplotlib(data=df_left, colors=left_colors, title=subtitles[0], hole=hole, fig=fig, ax=axes[0]) - else: - axes[0].set_title(subtitles[0]) - axes[0].axis('off') + draw_pie_on_axis(axes[0], data_left_processed, left_colors, subtitles[0], hole) # Create right pie chart - if not df_right.empty: - pie_with_matplotlib(data=df_right, colors=right_colors, title=subtitles[1], hole=hole, fig=fig, ax=axes[1]) - else: - axes[1].set_title(subtitles[1]) - axes[1].axis('off') + draw_pie_on_axis(axes[1], data_right_processed, right_colors, subtitles[1], hole) # Add main title if title: @@ -1460,15 +2000,16 @@ def preprocess_series(series: pd.Series): def heatmap_with_plotly( data: xr.DataArray, - colors: ColorType = 'viridis', + colors: ColorType | None = None, title: str = '', facet_by: str | list[str] | None = None, animate_by: str | None = None, - facet_cols: int = 3, + facet_cols: int | None = None, reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] | Literal['auto'] | None = 'auto', fill: Literal['ffill', 'bfill'] | None = 'ffill', + **imshow_kwargs: Any, ) -> go.Figure: """ Plot a heatmap visualization using Plotly's imshow with faceting and animation support. @@ -1487,7 +2028,7 @@ def heatmap_with_plotly( data: An xarray DataArray containing the data to visualize. Should have at least 2 dimensions, or a 'time' dimension that can be reshaped into 2D. colors: Color specification (colormap name, list, or dict). Common options: - 'viridis', 'plasma', 'RdBu', 'portland'. + 'turbo', 'plasma', 'RdBu', 'portland'. title: The main title of the heatmap. facet_by: Dimension to create facets for. Creates a subplot grid. Can be a single dimension name or list (only first dimension used). @@ -1501,6 +2042,11 @@ def heatmap_with_plotly( - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours) - None: Disable time reshaping (will error if only 1D time data) fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'. + **imshow_kwargs: Additional keyword arguments to pass to plotly.express.imshow. + Common options include: + - aspect: 'auto', 'equal', or a number for aspect ratio + - zmin, zmax: Minimum and maximum values for color scale + - labels: Dict to customize axis labels Returns: A Plotly figure object containing the heatmap visualization. @@ -1538,6 +2084,13 @@ def heatmap_with_plotly( fig = heatmap_with_plotly(data_array, facet_by='scenario', animate_by='period', reshape_time=('W', 'D')) ``` """ + if colors is None: + colors = CONFIG.Plotting.default_sequential_colorscale + + # Apply CONFIG defaults if not explicitly set + if facet_cols is None: + facet_cols = CONFIG.Plotting.default_facet_cols + # Handle empty data if data.size == 0: return go.Figure() @@ -1589,12 +2142,26 @@ def heatmap_with_plotly( heatmap_dims = [dim for dim in available_dims if dim not in facet_dims] if len(heatmap_dims) < 2: - # Need at least 2 dimensions for a heatmap - logger.error( - f'Heatmap requires at least 2 dimensions for rows and columns. ' - f'After faceting/animation, only {len(heatmap_dims)} dimension(s) remain: {heatmap_dims}' - ) - return go.Figure() + # Handle single-dimension case by adding variable name as a dimension + if len(heatmap_dims) == 1: + # Get the variable name, or use a default + var_name = data.name if data.name else 'value' + + # Expand the DataArray by adding a new dimension with the variable name + data = data.expand_dims({'variable': [var_name]}) + + # Update available dimensions + available_dims = list(data.dims) + heatmap_dims = [dim for dim in available_dims if dim not in facet_dims] + + logger.debug(f'Only 1 dimension remaining for heatmap. Added variable dimension: {var_name}') + else: + # No dimensions at all - cannot create a heatmap + logger.error( + f'Heatmap requires at least 1 dimension. ' + f'After faceting/animation, {len(heatmap_dims)} dimension(s) remain: {heatmap_dims}' + ) + return go.Figure() # Setup faceting parameters for Plotly Express # Note: px.imshow only supports facet_col, not facet_row @@ -1617,7 +2184,7 @@ def heatmap_with_plotly( # Create the imshow plot - px.imshow can work directly with xarray DataArrays common_args = { 'img': data, - 'color_continuous_scale': colors if isinstance(colors, str) else 'viridis', + 'color_continuous_scale': colors if isinstance(colors, str) else CONFIG.Plotting.default_sequential_colorscale, 'title': title, } @@ -1631,38 +2198,41 @@ def heatmap_with_plotly( if animate_by: common_args['animation_frame'] = animate_by + # Merge in additional imshow kwargs + common_args.update(imshow_kwargs) + try: fig = px.imshow(**common_args) except Exception as e: logger.error(f'Error creating imshow plot: {e}. Falling back to basic heatmap.') # Fallback: create a simple heatmap without faceting - fig = px.imshow( - data.values, - color_continuous_scale=colors if isinstance(colors, str) else 'viridis', - title=title, - ) - - # Update layout with basic styling - fig.update_layout( - plot_bgcolor='rgba(0,0,0,0)', - paper_bgcolor='rgba(0,0,0,0)', - font=dict(size=12), - ) + fallback_args = { + 'img': data.values, + 'color_continuous_scale': colors + if isinstance(colors, str) + else CONFIG.Plotting.default_sequential_colorscale, + 'title': title, + } + fallback_args.update(imshow_kwargs) + fig = px.imshow(**fallback_args) return fig def heatmap_with_matplotlib( data: xr.DataArray, - colors: ColorType = 'viridis', + colors: ColorType | None = None, title: str = '', figsize: tuple[float, float] = (12, 6), - fig: plt.Figure | None = None, - ax: plt.Axes | None = None, reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] | Literal['auto'] | None = 'auto', fill: Literal['ffill', 'bfill'] | None = 'ffill', + vmin: float | None = None, + vmax: float | None = None, + imshow_kwargs: dict[str, Any] | None = None, + cbar_kwargs: dict[str, Any] | None = None, + **kwargs: Any, ) -> tuple[plt.Figure, plt.Axes]: """ Plot a heatmap visualization using Matplotlib's imshow. @@ -1674,16 +2244,25 @@ def heatmap_with_matplotlib( data: An xarray DataArray containing the data to visualize. Should have at least 2 dimensions. If more than 2 dimensions exist, additional dimensions will be reduced by taking the first slice. - colors: Color specification. Should be a colormap name (e.g., 'viridis', 'RdBu'). + colors: Color specification. Should be a colormap name (e.g., 'turbo', 'RdBu'). title: The title of the heatmap. figsize: The size of the figure (width, height) in inches. - fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created. - ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created. reshape_time: Time reshaping configuration: - 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours) - None: Disable time reshaping fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'. + vmin: Minimum value for color scale. If None, uses data minimum. + vmax: Maximum value for color scale. If None, uses data maximum. + imshow_kwargs: Optional dict of parameters to pass to ax.imshow(). + Use this to customize image properties (e.g., interpolation, aspect). + cbar_kwargs: Optional dict of parameters to pass to plt.colorbar(). + Use this to customize colorbar properties (e.g., orientation, label). + **kwargs: Additional keyword arguments passed to ax.imshow(). + Common options include: + - interpolation: 'nearest', 'bilinear', 'bicubic', etc. + - alpha: Transparency level (0-1) + - extent: [left, right, bottom, top] for axis limits Returns: A tuple containing the Matplotlib figure and axes objects used for the plot. @@ -1705,19 +2284,36 @@ def heatmap_with_matplotlib( fig, ax = heatmap_with_matplotlib(data_array, reshape_time=('D', 'h')) ``` """ + if colors is None: + colors = CONFIG.Plotting.default_sequential_colorscale + + # Initialize kwargs if not provided + if imshow_kwargs is None: + imshow_kwargs = {} + if cbar_kwargs is None: + cbar_kwargs = {} + + # Merge any additional kwargs into imshow_kwargs + # This allows users to pass imshow options directly + imshow_kwargs.update(kwargs) + # Handle empty data if data.size == 0: - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + fig, ax = plt.subplots(figsize=figsize) return fig, ax # Apply time reshaping using the new unified function # Matplotlib doesn't support faceting/animation, so we pass None for those data = reshape_data_for_heatmap(data, reshape_time=reshape_time, facet_by=None, animate_by=None, fill=fill) - # Create figure and axes if not provided - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + # Handle single-dimension case by adding variable name as a dimension + if isinstance(data, xr.DataArray) and len(data.dims) == 1: + var_name = data.name if data.name else 'value' + data = data.expand_dims({'variable': [var_name]}) + logger.debug(f'Only 1 dimension in data. Added variable dimension: {var_name}') + + # Create figure and axes + fig, ax = plt.subplots(figsize=figsize) # Extract data values # If data has more than 2 dimensions, we need to reduce it @@ -1743,14 +2339,21 @@ def heatmap_with_matplotlib( y_labels = 'y' # Process colormap - cmap = colors if isinstance(colors, str) else 'viridis' + cmap = colors if isinstance(colors, str) else CONFIG.Plotting.default_sequential_colorscale + + # Create the heatmap using imshow with user customizations + imshow_defaults = {'cmap': cmap, 'aspect': 'auto', 'origin': 'upper', 'vmin': vmin, 'vmax': vmax} + imshow_defaults.update(imshow_kwargs) # User kwargs override defaults + im = ax.imshow(values, **imshow_defaults) - # Create the heatmap using imshow - im = ax.imshow(values, cmap=cmap, aspect='auto', origin='upper') + # Add colorbar with user customizations + cbar_defaults = {'ax': ax, 'orientation': 'horizontal', 'pad': 0.1, 'aspect': 15, 'fraction': 0.05} + cbar_defaults.update(cbar_kwargs) # User kwargs override defaults + cbar = plt.colorbar(im, **cbar_defaults) - # Add colorbar - cbar = plt.colorbar(im, ax=ax, orientation='horizontal', pad=0.1, aspect=15, fraction=0.05) - cbar.set_label('Value') + # Set colorbar label if not overridden by user + if 'label' not in cbar_kwargs: + cbar.set_label('Value') # Set labels and title ax.set_xlabel(str(x_labels).capitalize()) @@ -1768,8 +2371,9 @@ def export_figure( default_path: pathlib.Path, default_filetype: str | None = None, user_path: pathlib.Path | None = None, - show: bool = True, + show: bool | None = None, save: bool = False, + dpi: int | None = None, ) -> go.Figure | tuple[plt.Figure, plt.Axes]: """ Export a figure to a file and or show it. @@ -1779,13 +2383,21 @@ def export_figure( default_path: The default file path if no user filename is provided. default_filetype: The default filetype if the path doesnt end with a filetype. user_path: An optional user-specified file path. - show: Whether to display the figure (default: True). + show: Whether to display the figure. If None, uses CONFIG.Plotting.default_show (default: None). save: Whether to save the figure (default: False). + dpi: DPI (dots per inch) for saving Matplotlib figures. If None, uses CONFIG.Plotting.default_dpi. Raises: ValueError: If no default filetype is provided and the path doesn't specify a filetype. TypeError: If the figure type is not supported. """ + # Apply CONFIG defaults if not explicitly set + if show is None: + show = CONFIG.Plotting.default_show + + if dpi is None: + dpi = CONFIG.Plotting.default_dpi + filename = user_path or default_path filename = filename.with_name(filename.name.replace('|', '__')) if filename.suffix == '': @@ -1795,30 +2407,32 @@ def export_figure( if isinstance(figure_like, plotly.graph_objs.Figure): fig = figure_like + + # Apply default dimensions if configured + layout_updates = {} + if CONFIG.Plotting.default_figure_width is not None: + layout_updates['width'] = CONFIG.Plotting.default_figure_width + if CONFIG.Plotting.default_figure_height is not None: + layout_updates['height'] = CONFIG.Plotting.default_figure_height + if layout_updates: + fig.update_layout(**layout_updates) + if filename.suffix != '.html': logger.warning(f'To save a Plotly figure, using .html. Adjusting suffix for {filename}') filename = filename.with_suffix('.html') try: - is_test_env = 'PYTEST_CURRENT_TEST' in os.environ - - if is_test_env: - # Test environment: never open browser, only save if requested - if save: - fig.write_html(str(filename)) - # Ignore show flag in tests - else: - # Production environment: respect show and save flags - if save and show: - # Save and auto-open in browser - plotly.offline.plot(fig, filename=str(filename)) - elif save and not show: - # Save without opening - fig.write_html(str(filename)) - elif show and not save: - # Show interactively without saving - fig.show() - # If neither save nor show: do nothing + # Respect show and save flags (tests should set CONFIG.Plotting.default_show=False) + if save and show: + # Save and auto-open in browser + plotly.offline.plot(fig, filename=str(filename)) + elif save and not show: + # Save without opening + fig.write_html(str(filename)) + elif show and not save: + # Show interactively without saving + fig.show() + # If neither save nor show: do nothing finally: # Cleanup to prevent socket warnings if hasattr(fig, '_renderer'): @@ -1829,16 +2443,15 @@ def export_figure( elif isinstance(figure_like, tuple): fig, ax = figure_like if show: - # Only show if using interactive backend and not in test environment + # Only show if using interactive backend (tests should set CONFIG.Plotting.default_show=False) backend = matplotlib.get_backend().lower() is_interactive = backend not in {'agg', 'pdf', 'ps', 'svg', 'template'} - is_test_env = 'PYTEST_CURRENT_TEST' in os.environ - if is_interactive and not is_test_env: + if is_interactive: plt.show() if save: - fig.savefig(str(filename), dpi=300) + fig.savefig(str(filename), dpi=dpi) plt.close(fig) # Close figure to free memory return fig, ax diff --git a/flixopt/results.py b/flixopt/results.py index 75f8f300e..80dcc16ea 100644 --- a/flixopt/results.py +++ b/flixopt/results.py @@ -15,6 +15,7 @@ from . import io as fx_io from . import plotting +from .config import CONFIG from .flow_system import FlowSystem if TYPE_CHECKING: @@ -69,6 +70,10 @@ class CalculationResults: effects: Dictionary mapping effect names to EffectResults objects timesteps_extra: Extended time index including boundary conditions hours_per_timestep: Duration of each timestep for proper energy calculations + colors: Optional dict mapping variable names to colors for automatic coloring in plots. + When set, all plotting methods automatically use these colors when colors=None + (the default). Use `setup_colors()` to configure colors, which returns this dict. + Set to None to disable automatic coloring. Examples: Load and analyze saved results: @@ -107,6 +112,20 @@ class CalculationResults: ).mean() ``` + Configure automatic color management for plots: + + ```python + # Dict-based configuration: + results.setup_colors({'Solar*': 'Oranges', 'Wind*': 'Blues', 'Battery': 'green'}) + + # All plots automatically use configured colors (colors=None is the default) + results['ElectricityBus'].plot_node_balance() + results['Battery'].plot_charge_state() + + # Override when needed + results['ElectricityBus'].plot_node_balance(colors='turbo') # Ignores setup + ``` + Design Patterns: **Factory Methods**: Use `from_file()` and `from_calculation()` for creation or access directly from `Calculation.results` **Dictionary Access**: Use `results[element_label]` for element-specific results @@ -240,6 +259,9 @@ def __init__( self._sizes = None self._effects_per_component = None + # Color dict for intelligent plot coloring - None by default, user configures explicitly + self.colors: dict[str, str] | None = None + def __getitem__(self, key: str) -> ComponentResults | BusResults | EffectResults: if key in self.components: return self.components[key] @@ -306,6 +328,104 @@ def flow_system(self) -> FlowSystem: logger.level = old_level return self._flow_system + def setup_colors( + self, + config: dict[str, str | list[str]] | str | pathlib.Path | None = None, + *, + default_colorscale: str | None = None, + reset: bool = True, + ) -> dict[str, str]: + """Configure colors for plotting. Returns variable→color dict. + + Supports multiple configuration styles: + - Direct assignment: {'Boiler1': 'red'} + - Pattern matching: {'Solar*': 'orange'} or {'Solar*': 'Oranges'} + - Family grouping: {'oranges': ['Solar1', 'Solar2']} + + Args: + config: Optional color configuration: + - dict: Component/pattern to color/colorscale mapping + - str/Path: Path to YAML file + - None: Use default colorscale for all components + default_colorscale: Default colorscale for unmapped components. + Defaults to CONFIG.Plotting.default_qualitative_colorscale + reset: If True, reset all existing colors before applying config. + If False, only update/add specified components (default: True) + + Returns: + dict[str, str]: Complete variable→color mapping + + Examples: + Direct color assignment: + + ```python + results.setup_colors({'Boiler1': 'red', 'CHP': 'darkred'}) + ``` + + Pattern matching with color: + + ```python + results.setup_colors({'Solar*': 'orange', 'Wind*': 'blue'}) + ``` + + Pattern matching with colorscale (generates shades): + + ```python + results.setup_colors({'Solar*': 'Oranges', 'Wind*': 'Blues'}) + ``` + + Family grouping (colorscale samples): + + ```python + results.setup_colors( + { + 'oranges': ['Solar1', 'Solar2'], + 'blues': ['Wind1', 'Wind2'], + } + ) + ``` + + Load from YAML file: + + ```python + # colors.yaml: + # Boiler1: red + # Solar*: Oranges + # oranges: + # - Solar1 + # - Solar2 + results.setup_colors('colors.yaml') + ``` + + Merge with existing colors: + + ```python + results.setup_colors({'Boiler1': 'red'}) + results.setup_colors({'CHP': 'blue'}, reset=False) # Keeps Boiler1 red + ``` + + Disable automatic coloring: + + ```python + results.colors = None # Plots use default colorscales + ``` + """ + # Create resolver and delegate + resolver = plotting.ElementColorResolver( + self.components, + default_colorscale=default_colorscale, + engine='plotly', + ) + + # Resolve colors (with variable-level merging if reset=False) + self.colors = resolver.resolve( + config=config, + reset=reset, + existing_colors=None if reset else self.colors, + ) + + return self.colors + def filter_solution( self, variable_dims: Literal['scalar', 'time', 'scenario', 'timeonly', 'scenarioonly'] | None = None, @@ -705,13 +825,13 @@ def plot_heatmap( self, variable_name: str | list[str], save: bool | pathlib.Path = False, - show: bool = True, - colors: plotting.ColorType = 'viridis', + show: bool | None = None, + colors: plotting.ColorType | None = None, engine: plotting.PlottingEngine = 'plotly', select: dict[FlowSystemDimensions, Any] | None = None, facet_by: str | list[str] | None = 'scenario', animate_by: str | None = 'period', - facet_cols: int = 3, + facet_cols: int | None = None, reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] | Literal['auto'] | None = 'auto', @@ -721,6 +841,7 @@ def plot_heatmap( heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None, heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None, color_map: str | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]: """ Plots a heatmap visualization of a variable using imshow or time-based reshaping. @@ -737,7 +858,8 @@ def plot_heatmap( with a new 'variable' dimension. save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location. show: Whether to show the plot or not. - colors: Color scheme for the heatmap. See `flixopt.plotting.ColorType` for options. + colors: Color scheme for the heatmap (default: None uses CONFIG.Plotting.default_sequential_colorscale). + See `flixopt.plotting.ColorType` for options. engine: The engine to use for plotting. Can be either 'plotly' or 'matplotlib'. select: Optional data selection dict. Supports single values, lists, slices, and index arrays. Applied BEFORE faceting/animation/reshaping. @@ -754,6 +876,20 @@ def plot_heatmap( Supported timeframes: 'YS', 'MS', 'W', 'D', 'h', '15min', 'min' fill: Method to fill missing values after reshape: 'ffill' (forward fill) or 'bfill' (backward fill). Default is 'ffill'. + **plot_kwargs: Additional plotting customization options. + Common options: + + - **dpi** (int): Export resolution for saved plots. Default: 300. + + For heatmaps specifically: + + - **vmin** (float): Minimum value for color scale (both engines). + - **vmax** (float): Maximum value for color scale (both engines). + + For Matplotlib heatmaps: + + - **imshow_kwargs** (dict): Additional kwargs for matplotlib's imshow (e.g., interpolation, aspect). + - **cbar_kwargs** (dict): Additional kwargs for colorbar customization. Examples: Direct imshow mode (default): @@ -794,6 +930,18 @@ def plot_heatmap( ... animate_by='period', ... reshape_time=('D', 'h'), ... ) + + High-resolution export with custom color range: + + >>> results.plot_heatmap('Battery|charge_state', save=True, dpi=600, vmin=0, vmax=100) + + Matplotlib heatmap with custom imshow settings: + + >>> results.plot_heatmap( + ... 'Boiler(Q_th)|flow_rate', + ... engine='matplotlib', + ... imshow_kwargs={'interpolation': 'bilinear', 'aspect': 'auto'}, + ... ) """ # Delegate to module-level plot_heatmap function return plot_heatmap( @@ -814,6 +962,7 @@ def plot_heatmap( heatmap_timeframes=heatmap_timeframes, heatmap_timesteps_per_frame=heatmap_timesteps_per_frame, color_map=color_map, + **plot_kwargs, ) def plot_network( @@ -982,8 +1131,8 @@ def __init__( def plot_node_balance( self, save: bool | pathlib.Path = False, - show: bool = True, - colors: plotting.ColorType = 'viridis', + show: bool | None = None, + colors: plotting.ColorType | None = None, engine: plotting.PlottingEngine = 'plotly', select: dict[FlowSystemDimensions, Any] | None = None, unit_type: Literal['flow_rate', 'flow_hours'] = 'flow_rate', @@ -991,9 +1140,10 @@ def plot_node_balance( drop_suffix: bool = True, facet_by: str | list[str] | None = 'scenario', animate_by: str | None = 'period', - facet_cols: int = 3, + facet_cols: int | None = None, # Deprecated parameter (kept for backwards compatibility) indexer: dict[FlowSystemDimensions, Any] | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]: """ Plots the node balance of the Component or Bus with optional faceting and animation. @@ -1001,7 +1151,12 @@ def plot_node_balance( Args: save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location. show: Whether to show the plot or not. - colors: The colors to use for the plot. See `flixopt.plotting.ColorType` for options. + colors: The colors to use for the plot. Options: + - None (default): Use `self.colors` dict if configured, else fall back to CONFIG.Plotting.default_qualitative_colorscale + - Colormap name string (e.g., 'turbo', 'plasma') + - List of color strings + - Dict mapping variable names to colors + Use `results.setup_colors()` to configure automatic component-based coloring. engine: The engine to use for plotting. Can be either 'plotly' or 'matplotlib'. select: Optional data selection dict. Supports: - Single values: {'scenario': 'base', 'period': 2024} @@ -1021,6 +1176,27 @@ def plot_node_balance( animate_by: Dimension to animate over (Plotly only). Creates animation frames that cycle through dimension values. Only one dimension can be animated. Ignored if not found. facet_cols: Number of columns in the facet grid layout (default: 3). + **plot_kwargs: Additional plotting customization options passed to underlying plotting functions. + + Common options: + + - **dpi** (int): Export resolution in dots per inch. Default: 300. + + **For Plotly engine** (`engine='plotly'`): + + - **trace_kwargs** (dict): Customize traces via `fig.update_traces()`. + Example: `trace_kwargs={'line': {'width': 5, 'dash': 'dot'}}` + - **layout_kwargs** (dict): Customize layout via `fig.update_layout()`. + Example: `layout_kwargs={'width': 1200, 'height': 600, 'template': 'plotly_dark'}` + - Any Plotly Express parameter for px.bar()/px.line()/px.area() + + **For Matplotlib engine** (`engine='matplotlib'`): + + - **plot_kwargs** (dict): Customize plot via `ax.bar()` or `ax.step()`. + Example: `plot_kwargs={'linewidth': 3, 'alpha': 0.7, 'edgecolor': 'black'}` + + See :func:`flixopt.plotting.with_plotly` and :func:`flixopt.plotting.with_matplotlib` + for complete parameter reference. Examples: Basic plot (current behavior): @@ -1052,6 +1228,24 @@ def plot_node_balance( Time range selection (summer months only): >>> results['Boiler'].plot_node_balance(select={'time': slice('2024-06', '2024-08')}, facet_by='scenario') + + High-resolution export for publication: + + >>> results['Boiler'].plot_node_balance(engine='matplotlib', save='figure.png', dpi=600) + + Custom Plotly theme and layout: + + >>> results['Boiler'].plot_node_balance( + ... layout_kwargs={'template': 'plotly_dark', 'width': 1200, 'height': 600} + ... ) + + Custom line styling: + + >>> results['Boiler'].plot_node_balance(mode='line', trace_kwargs={'line': {'width': 5, 'dash': 'dot'}}) + + Custom matplotlib appearance: + + >>> results['Boiler'].plot_node_balance(engine='matplotlib', plot_kwargs={'linewidth': 3, 'alpha': 0.7}) """ # Handle deprecated indexer parameter if indexer is not None: @@ -1073,11 +1267,18 @@ def plot_node_balance( if engine not in {'plotly', 'matplotlib'}: raise ValueError(f'Engine "{engine}" not supported. Use one of ["plotly", "matplotlib"]') + # Extract dpi for export_figure + dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi + # Don't pass select/indexer to node_balance - we'll apply it afterwards - ds = self.node_balance(with_last_timestep=True, unit_type=unit_type, drop_suffix=drop_suffix) + ds = self.node_balance(with_last_timestep=False, unit_type=unit_type, drop_suffix=drop_suffix) ds, suffix_parts = _apply_selection_to_data(ds, select=select, drop=True) + # Resolve colors: None -> colors dict if set -> CONFIG default -> explicit value + colors_to_use = colors or self._calculation_results.colors or CONFIG.Plotting.default_qualitative_colorscale + resolved_colors = plotting.resolve_colors(ds, colors_to_use, engine=engine) + # Matplotlib requires only 'time' dimension; check for extras after selection if engine == 'matplotlib': extra_dims = [d for d in ds.dims if d != 'time'] @@ -1097,18 +1298,21 @@ def plot_node_balance( ds, facet_by=facet_by, animate_by=animate_by, - colors=colors, + colors=resolved_colors, mode=mode, title=title, facet_cols=facet_cols, + xlabel='Time in h', + **plot_kwargs, ) default_filetype = '.html' else: figure_like = plotting.with_matplotlib( - ds.to_dataframe(), - colors=colors, + ds, + colors=resolved_colors, mode=mode, title=title, + **plot_kwargs, ) default_filetype = '.png' @@ -1119,19 +1323,21 @@ def plot_node_balance( user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, save=True if save else False, + dpi=dpi, ) def plot_node_balance_pie( self, lower_percentage_group: float = 5, - colors: plotting.ColorType = 'viridis', + colors: plotting.ColorType | None = None, text_info: str = 'percent+label+value', save: bool | pathlib.Path = False, - show: bool = True, + show: bool | None = None, engine: plotting.PlottingEngine = 'plotly', select: dict[FlowSystemDimensions, Any] | None = None, # Deprecated parameter (kept for backwards compatibility) indexer: dict[FlowSystemDimensions, Any] | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, list[plt.Axes]]: """Plot pie chart of flow hours distribution. @@ -1144,13 +1350,25 @@ def plot_node_balance_pie( Args: lower_percentage_group: Percentage threshold for "Others" grouping. - colors: Color scheme. Also see plotly. + colors: Color scheme (default: None uses colors dict if configured, + else falls back to CONFIG.Plotting.default_qualitative_colorscale). text_info: Information to display on pie slices. save: Whether to save plot. show: Whether to display plot. engine: Plotting engine ('plotly' or 'matplotlib'). select: Optional data selection dict. Supports single values, lists, slices, and index arrays. Use this to select specific scenario/period before creating the pie chart. + **plot_kwargs: Additional plotting customization options. + + Common options: + + - **dpi** (int): Export resolution in dots per inch. Default: 300. + - **hover_template** (str): Hover text template (Plotly only). + Example: `hover_template='%{label}: %{value} (%{percent})'` + - **text_position** (str): Text position ('inside', 'outside', 'auto'). + - **hole** (float): Size of donut hole (0.0 to 1.0). + + See :func:`flixopt.plotting.dual_pie_with_plotly` for complete reference. Examples: Basic usage (auto-selects first scenario/period if present): @@ -1160,6 +1378,14 @@ def plot_node_balance_pie( Explicitly select a scenario and period: >>> results['Bus'].plot_node_balance_pie(select={'scenario': 'high_demand', 'period': 2030}) + + Create a donut chart with custom hover text: + + >>> results['Bus'].plot_node_balance_pie(hole=0.4, hover_template='%{label}: %{value:.2f} (%{percent})') + + High-resolution export: + + >>> results['Bus'].plot_node_balance_pie(save='figure.png', dpi=600) """ # Handle deprecated indexer parameter if indexer is not None: @@ -1178,6 +1404,9 @@ def plot_node_balance_pie( ) select = indexer + # Extract dpi for export_figure + dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi + inputs = sanitize_dataset( ds=self.solution[self.inputs] * self._calculation_results.hours_per_timestep, threshold=1e-5, @@ -1233,16 +1462,24 @@ def plot_node_balance_pie( suffix = '--' + '-'.join(suffix_parts) if suffix_parts else '' title = f'{self.label} (total flow hours){suffix}' + # Combine inputs and outputs to resolve colors for all variables + combined_ds = xr.Dataset({**inputs.data_vars, **outputs.data_vars}) + + # Resolve colors: None -> colors dict if set -> CONFIG default -> explicit value + colors_to_use = colors or self._calculation_results.colors or CONFIG.Plotting.default_qualitative_colorscale + resolved_colors = plotting.resolve_colors(combined_ds, colors_to_use, engine=engine) + if engine == 'plotly': figure_like = plotting.dual_pie_with_plotly( - data_left=inputs.to_pandas(), - data_right=outputs.to_pandas(), - colors=colors, + data_left=inputs, + data_right=outputs, + colors=resolved_colors, title=title, text_info=text_info, subtitles=('Inputs', 'Outputs'), legend_title='Flows', lower_percentage_group=lower_percentage_group, + **plot_kwargs, ) default_filetype = '.html' elif engine == 'matplotlib': @@ -1250,11 +1487,12 @@ def plot_node_balance_pie( figure_like = plotting.dual_pie_with_matplotlib( data_left=inputs.to_pandas(), data_right=outputs.to_pandas(), - colors=colors, + colors=resolved_colors, title=title, subtitles=('Inputs', 'Outputs'), legend_title='Flows', lower_percentage_group=lower_percentage_group, + **plot_kwargs, ) default_filetype = '.png' else: @@ -1267,6 +1505,7 @@ def plot_node_balance_pie( user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, save=True if save else False, + dpi=dpi, ) def node_balance( @@ -1363,23 +1602,25 @@ def charge_state(self) -> xr.DataArray: def plot_charge_state( self, save: bool | pathlib.Path = False, - show: bool = True, - colors: plotting.ColorType = 'viridis', + show: bool | None = None, + colors: plotting.ColorType | None = None, engine: plotting.PlottingEngine = 'plotly', mode: Literal['area', 'stacked_bar', 'line'] = 'area', select: dict[FlowSystemDimensions, Any] | None = None, facet_by: str | list[str] | None = 'scenario', animate_by: str | None = 'period', - facet_cols: int = 3, + facet_cols: int | None = None, # Deprecated parameter (kept for backwards compatibility) indexer: dict[FlowSystemDimensions, Any] | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure: """Plot storage charge state over time, combined with the node balance with optional faceting and animation. Args: save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location. show: Whether to show the plot or not. - colors: Color scheme. Also see plotly. + colors: Color scheme (default: None uses colors dict if configured, + else falls back to CONFIG.Plotting.default_qualitative_colorscale). engine: Plotting engine to use. Only 'plotly' is implemented atm. mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for stepped lines, or 'area' for stacked area charts. select: Optional data selection dict. Supports single values, lists, slices, and index arrays. @@ -1389,6 +1630,24 @@ def plot_charge_state( animate_by: Dimension to animate over (Plotly only). Creates animation frames that cycle through dimension values. Only one dimension can be animated. Ignored if not found. facet_cols: Number of columns in the facet grid layout (default: 3). + **plot_kwargs: Additional plotting customization options passed to underlying plotting functions. + + Common options: + + - **dpi** (int): Export resolution in dots per inch. Default: 300. + + **For Plotly engine:** + + - **trace_kwargs** (dict): Customize traces via `fig.update_traces()`. + - **layout_kwargs** (dict): Customize layout via `fig.update_layout()`. + - Any Plotly Express parameter for px.bar()/px.line()/px.area() + + **For Matplotlib engine:** + + - **plot_kwargs** (dict): Customize plot via `ax.bar()` or `ax.step()`. + + See :func:`flixopt.plotting.with_plotly` and :func:`flixopt.plotting.with_matplotlib` + for complete parameter reference. Raises: ValueError: If component is not a storage. @@ -1409,6 +1668,14 @@ def plot_charge_state( Facet by scenario AND animate by period: >>> results['Storage'].plot_charge_state(facet_by='scenario', animate_by='period') + + Custom layout: + + >>> results['Storage'].plot_charge_state(layout_kwargs={'template': 'plotly_dark', 'height': 800}) + + High-resolution export: + + >>> results['Storage'].plot_charge_state(save='storage.png', dpi=600) """ # Handle deprecated indexer parameter if indexer is not None: @@ -1427,11 +1694,14 @@ def plot_charge_state( ) select = indexer + # Extract dpi for export_figure + dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi + if not self.is_storage: raise ValueError(f'Cant plot charge_state. "{self.label}" is not a storage') # Get node balance and charge state - ds = self.node_balance(with_last_timestep=True) + ds = self.node_balance(with_last_timestep=True).fillna(0) charge_state_da = self.charge_state # Apply select filtering @@ -1441,31 +1711,42 @@ def plot_charge_state( title = f'Operation Balance of {self.label}{suffix}' + # Combine flow balance and charge state for color resolution + # We need to include both in the color map for consistency + combined_ds = ds.assign({self._charge_state: charge_state_da}) + + # Resolve colors: None -> colors dict if set -> CONFIG default -> explicit value + colors_to_use = colors or self._calculation_results.colors or CONFIG.Plotting.default_qualitative_colorscale + resolved_colors = plotting.resolve_colors(combined_ds, colors_to_use, engine=engine) + if engine == 'plotly': # Plot flows (node balance) with the specified mode figure_like = plotting.with_plotly( ds, facet_by=facet_by, animate_by=animate_by, - colors=colors, + colors=resolved_colors, mode=mode, title=title, facet_cols=facet_cols, + xlabel='Time in h', + **plot_kwargs, ) - # Create a dataset with just charge_state and plot it as lines - # This ensures proper handling of facets and animation - charge_state_ds = charge_state_da.to_dataset(name=self._charge_state) + # Prepare charge_state as Dataset for plotting + charge_state_ds = xr.Dataset({self._charge_state: charge_state_da}) # Plot charge_state with mode='line' to get Scatter traces charge_state_fig = plotting.with_plotly( charge_state_ds, facet_by=facet_by, animate_by=animate_by, - colors=colors, + colors=resolved_colors, mode='line', # Always line for charge_state title='', # No title needed for this temp figure facet_cols=facet_cols, + xlabel='Time in h', + **plot_kwargs, ) # Add charge_state traces to the main figure @@ -1473,6 +1754,7 @@ def plot_charge_state( for trace in charge_state_fig.data: trace.line.width = 2 # Make charge_state line more prominent trace.line.shape = 'linear' # Smooth line for charge state (not stepped like flows) + trace.line.color = 'black' figure_like.add_trace(trace) # Also add traces from animation frames if they exist @@ -1497,10 +1779,11 @@ def plot_charge_state( ) # For matplotlib, plot flows (node balance), then add charge_state as line fig, ax = plotting.with_matplotlib( - ds.to_dataframe(), - colors=colors, + ds, + colors=resolved_colors, mode=mode, title=title, + **plot_kwargs, ) # Add charge_state as a line overlay @@ -1525,6 +1808,7 @@ def plot_charge_state( user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, save=True if save else False, + dpi=dpi, ) def node_balance_with_charge_state( @@ -1635,6 +1919,19 @@ class SegmentedCalculationResults: - Flow rate transitions at segment boundaries - Aggregated results over the full time horizon + Attributes: + segment_results: List of CalculationResults for each segment + all_timesteps: Complete time index spanning all segments + timesteps_per_segment: Number of timesteps in each segment + overlap_timesteps: Number of overlapping timesteps between segments + name: Identifier for this segmented calculation + folder: Directory path for result storage and loading + hours_per_timestep: Duration of each timestep + colors: Optional dict mapping variable names to colors for automatic coloring in plots. + When set, it is automatically propagated to all segment results, ensuring + consistent coloring across segments. Use `setup_colors()` to configure + colors across all segments. + Examples: Load and analyze segmented results: @@ -1690,6 +1987,17 @@ class SegmentedCalculationResults: storage_continuity = results.check_storage_continuity('Battery') ``` + Configure color management for consistent plotting across segments: + + ```python + # Dict-based configuration: + results.setup_colors({'Solar*': 'Oranges', 'Wind*': 'Blues', 'Battery': 'green'}) + + # Colors automatically propagate to all segments + results.segment_results[0]['ElectricityBus'].plot_node_balance() + results.segment_results[1]['ElectricityBus'].plot_node_balance() # Same colors + ``` + Design Considerations: **Boundary Effects**: Monitor solution quality at segment interfaces where foresight is limited compared to full-horizon optimization. @@ -1764,6 +2072,9 @@ def __init__( self.folder = pathlib.Path(folder) if folder is not None else pathlib.Path.cwd() / 'results' self.hours_per_timestep = FlowSystem.calculate_hours_per_timestep(self.all_timesteps) + # Color dict for intelligent plot coloring - None by default, user configures explicitly + self.colors: dict[str, str] | None = None + @property def meta_data(self) -> dict[str, int | list[str]]: return { @@ -1777,6 +2088,66 @@ def meta_data(self) -> dict[str, int | list[str]]: def segment_names(self) -> list[str]: return [segment.name for segment in self.segment_results] + def setup_colors( + self, + config: dict[str, str | list[str]] | str | pathlib.Path | None = None, + *, + default_colorscale: str | None = None, + reset: bool = True, + ) -> dict[str, str]: + """Configure colors for all segments. Returns variable→color dict. + + Colors are set on the first segment and then propagated to all other + segments for consistent coloring across the entire segmented calculation. + + Args: + config: Optional color configuration: + - dict: Component/pattern to color/colorscale mapping + - str/Path: Path to YAML file + - None: Use default colorscale for all components + default_colorscale: Default colorscale for unmapped components. + Defaults to CONFIG.Plotting.default_qualitative_colorscale + reset: If True, reset all existing colors before applying config. + If False, only update/add specified components (default: True) + + Returns: + dict[str, str]: Complete variable→color mapping + + Examples: + Dict-based configuration: + + ```python + results.setup_colors( + { + 'Boiler1': 'red', + 'Solar*': 'Oranges', + 'oranges': ['Solar1', 'Solar2'], + } + ) + + # All segments use the same colors + results.segment_results[0]['ElectricityBus'].plot_node_balance() + results.segment_results[1]['ElectricityBus'].plot_node_balance() + ``` + + Load from file: + + ```python + results.setup_colors('colors.yaml') + ``` + """ + # Setup colors on first segment + self.segment_results[0].setup_colors(config, default_colorscale=default_colorscale, reset=reset) + + # Propagate to all other segments + for segment in self.segment_results[1:]: + segment.colors = self.segment_results[0].colors + + # Store reference + self.colors = self.segment_results[0].colors + + return self.colors + def solution_without_overlap(self, variable_name: str) -> xr.DataArray: """Get variable solution removing segment overlaps. @@ -1798,18 +2169,19 @@ def plot_heatmap( reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] | Literal['auto'] | None = 'auto', - colors: str = 'portland', + colors: str | None = None, save: bool | pathlib.Path = False, - show: bool = True, + show: bool | None = None, engine: plotting.PlottingEngine = 'plotly', facet_by: str | list[str] | None = None, animate_by: str | None = None, - facet_cols: int = 3, + facet_cols: int | None = None, fill: Literal['ffill', 'bfill'] | None = 'ffill', # Deprecated parameters (kept for backwards compatibility) heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None, heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None, color_map: str | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]: """Plot heatmap of variable solution across segments. @@ -1819,7 +2191,8 @@ def plot_heatmap( - 'auto': Automatically applies ('D', 'h') when only 'time' dimension remains - Tuple like ('D', 'h'): Explicit reshaping (days vs hours) - None: Disable time reshaping - colors: Color scheme. See plotting.ColorType for options. + colors: Color scheme (default: None uses CONFIG.Plotting.default_sequential_colorscale). + See plotting.ColorType for options. save: Whether to save plot. show: Whether to display plot. engine: Plotting engine. @@ -1830,6 +2203,17 @@ def plot_heatmap( heatmap_timeframes: (Deprecated) Use reshape_time instead. heatmap_timesteps_per_frame: (Deprecated) Use reshape_time instead. color_map: (Deprecated) Use colors instead. + **plot_kwargs: Additional plotting customization options. + Common options: + + - **dpi** (int): Export resolution for saved plots. Default: 300. + - **vmin** (float): Minimum value for color scale. + - **vmax** (float): Maximum value for color scale. + + For Matplotlib heatmaps: + + - **imshow_kwargs** (dict): Additional kwargs for matplotlib's imshow. + - **cbar_kwargs** (dict): Additional kwargs for colorbar customization. Returns: Figure object. @@ -1857,7 +2241,7 @@ def plot_heatmap( if color_map is not None: # Check for conflict with new parameter - if colors != 'portland': # Check if user explicitly set colors + if colors is not None: # Check if user explicitly set colors raise ValueError( "Cannot use both deprecated parameter 'color_map' and new parameter 'colors'. Use only 'colors'." ) @@ -1884,6 +2268,7 @@ def plot_heatmap( animate_by=animate_by, facet_cols=facet_cols, fill=fill, + **plot_kwargs, ) def to_file(self, folder: str | pathlib.Path | None = None, name: str | None = None, compression: int = 5): @@ -1916,9 +2301,9 @@ def plot_heatmap( data: xr.DataArray | xr.Dataset, name: str | None = None, folder: pathlib.Path | None = None, - colors: plotting.ColorType = 'viridis', + colors: plotting.ColorType | None = None, save: bool | pathlib.Path = False, - show: bool = True, + show: bool | None = None, engine: plotting.PlottingEngine = 'plotly', select: dict[str, Any] | None = None, facet_by: str | list[str] | None = None, @@ -1933,6 +2318,7 @@ def plot_heatmap( heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None, heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None, color_map: str | None = None, + **plot_kwargs: Any, ): """Plot heatmap visualization with support for multi-variable, faceting, and animation. @@ -1945,7 +2331,8 @@ def plot_heatmap( name: Optional name for the title. If not provided, uses the DataArray name or generates a default title for Datasets. folder: Save folder for the plot. Defaults to current directory if not provided. - colors: Color scheme for the heatmap. See `flixopt.plotting.ColorType` for options. + colors: Color scheme for the heatmap (default: None uses CONFIG.Plotting.default_sequential_colorscale). + See `flixopt.plotting.ColorType` for options. save: Whether to save the plot or not. If a path is provided, the plot will be saved at that location. show: Whether to show the plot or not. engine: The engine to use for plotting. Can be either 'plotly' or 'matplotlib'. @@ -2004,7 +2391,7 @@ def plot_heatmap( # Handle deprecated color_map parameter if color_map is not None: # Check for conflict with new parameter - if colors != 'viridis': # User explicitly set colors + if colors is not None: # User explicitly set colors raise ValueError( "Cannot use both deprecated parameter 'color_map' and new parameter 'colors'. Use only 'colors'." ) @@ -2087,6 +2474,13 @@ def plot_heatmap( timeframes, timesteps_per_frame = reshape_time title += f' ({timeframes} vs {timesteps_per_frame})' + # Apply CONFIG default if colors is None + if colors is None: + colors = CONFIG.Plotting.default_sequential_colorscale + + # Extract dpi before passing to plotting functions + dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi + # Plot with appropriate engine if engine == 'plotly': figure_like = plotting.heatmap_with_plotly( @@ -2098,6 +2492,7 @@ def plot_heatmap( facet_cols=facet_cols, reshape_time=reshape_time, fill=fill, + **plot_kwargs, ) default_filetype = '.html' elif engine == 'matplotlib': @@ -2107,6 +2502,7 @@ def plot_heatmap( title=title, reshape_time=reshape_time, fill=fill, + **plot_kwargs, ) default_filetype = '.png' else: @@ -2123,6 +2519,7 @@ def plot_heatmap( user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, save=True if save else False, + dpi=dpi, ) diff --git a/tests/conftest.py b/tests/conftest.py index ac5255562..98929e467 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -838,6 +838,7 @@ def set_test_environment(): This fixture runs once per test session to: - Set matplotlib to use non-interactive 'Agg' backend - Set plotly to use non-interactive 'json' renderer + - Configure flixopt to not show plots by default - Prevent GUI windows from opening during tests """ import matplotlib @@ -848,4 +849,8 @@ def set_test_environment(): pio.renderers.default = 'json' # Use non-interactive renderer + # Configure flixopt to not show plots in tests + fx.CONFIG.Plotting.default_show = False + fx.CONFIG.apply() + yield diff --git a/tests/test_plotting_api.py b/tests/test_plotting_api.py new file mode 100644 index 000000000..f59601dca --- /dev/null +++ b/tests/test_plotting_api.py @@ -0,0 +1,64 @@ +"""Smoke tests for plotting API robustness improvements.""" + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from flixopt import plotting + + +@pytest.fixture +def sample_dataset(): + """Create a sample xarray Dataset for testing.""" + time = np.arange(10) + data = xr.Dataset( + { + 'var1': (['time'], np.random.rand(10)), + 'var2': (['time'], np.random.rand(10)), + 'var3': (['time'], np.random.rand(10)), + }, + coords={'time': time}, + ) + return data + + +@pytest.fixture +def sample_dataframe(): + """Create a sample pandas DataFrame for testing.""" + time = np.arange(10) + df = pd.DataFrame({'var1': np.random.rand(10), 'var2': np.random.rand(10), 'var3': np.random.rand(10)}, index=time) + df.index.name = 'time' + return df + + +def test_kwargs_passthrough_plotly(sample_dataset): + """Test that backend-specific kwargs are passed through correctly.""" + fig = plotting.with_plotly( + sample_dataset, + mode='line', + trace_kwargs={'line': {'width': 5}}, + layout_kwargs={'width': 1200, 'height': 600}, + ) + assert fig.layout.width == 1200 + assert fig.layout.height == 600 + + +def test_dataframe_support_plotly(sample_dataframe): + """Test that DataFrames are accepted by plotting functions.""" + fig = plotting.with_plotly(sample_dataframe, mode='line') + assert fig is not None + + +def test_data_validation_non_numeric(): + """Test that validation catches non-numeric data.""" + data = xr.Dataset({'var1': (['time'], ['a', 'b', 'c'])}, coords={'time': [0, 1, 2]}) + + with pytest.raises(TypeError, match='non-numeric dtype'): + plotting.with_plotly(data) + + +def test_ensure_dataset_invalid_type(): + """Test that _ensure_dataset raises error for invalid types.""" + with pytest.raises(TypeError, match='must be xr.Dataset or pd.DataFrame'): + plotting._ensure_dataset([1, 2, 3]) diff --git a/tests/test_results_plots.py b/tests/test_results_plots.py index 1fd6cf7f5..a656f7c44 100644 --- a/tests/test_results_plots.py +++ b/tests/test_results_plots.py @@ -28,7 +28,7 @@ def plotting_engine(request): @pytest.fixture( params=[ - 'viridis', # Test string colormap + 'turbo', # Test string colormap ['#ff0000', '#00ff00', '#0000ff', '#ffff00', '#ff00ff', '#00ffff'], # Test color list { 'Boiler(Q_th)|flow_rate': '#ff0000', @@ -51,7 +51,7 @@ def test_results_plots(flow_system, plotting_engine, show, save, color_spec): # Matplotlib doesn't support faceting/animation, so disable them for matplotlib engine heatmap_kwargs = { 'reshape_time': ('D', 'h'), - 'colors': 'viridis', # Note: heatmap only accepts string colormap + 'colors': 'turbo', # Note: heatmap only accepts string colormap 'save': save, 'show': show, 'engine': plotting_engine,