diff --git a/CHANGELOG.md b/CHANGELOG.md index ca9600f04..8eb16a4c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,17 +54,39 @@ If upgrading from v2.x, see the [v3.0.0 release notes](https://github.com/flixOp ### ✨ Added - Support for plotting kwargs in `results.py`, passed to plotly express and matplotlib. +- **Color management system**: New `color_processing.py` module with `process_colors()` function for unified color handling across plotting backends + - Supports flexible color inputs: colorscale names (e.g., 'turbo', 'plasma'), color lists, and label-to-color dictionaries + - Automatic fallback handling when requested colorscales are unavailable + - Seamless integration with both Plotly and Matplotlib colorscales + - Automatic rgba→hex color conversion for Matplotlib compatibility +- **Component color grouping**: Added `setup_colors()` method to `CalculationResults` and `SegmentedCalculationResults` to create color mappings with similar colors for all variables of a component + - Allows grouping components by custom colorscales: `{'CHP': 'red', 'Greys': ['Gastarif', 'Einspeisung'], 'Storage': 'blue'}` + - Colors are automatically assigned using default colorscale if not specified + - For segmented calculations, colors are propagated to all segments for consistent visualization + - Explicit `colors` arguments in plot methods override configured colors (when provided) +- **Plotting configuration**: New `CONFIG.Plotting` section with extensive customization options: + - `default_show`: Control default visibility of plots + - `default_engine`: Choose between 'plotly' or 'matplotlib' + - `default_dpi`: Configure resolution for saved plots (with matplotlib) + - `default_facet_cols`: Set default columns for faceted plots + - `default_sequential_colorscale`: Default for heatmaps and continuous data (default: 'turbo') + - `default_qualitative_colorscale`: Default for categorical plots (default: 'plotly') ### 💥 Breaking Changes ### ♻️ Changed - **Template integration**: Plotly templates now fully control plot styling without hardcoded overrides - **Dataset first plotting**: Underlying plotting methods in `plotting.py` now use `xr.Dataset` as the main datatype. DataFrames are automatically converted via `_ensure_dataset()`. Both DataFrames and Datasets can be passed to plotting functions without code changes. +- **Color terminology**: Standardized terminology from "colormap" to "colorscale" throughout the codebase for consistency with Plotly conventions +- **Default colorscales**: Changed default sequential colorscale from 'viridis' to 'turbo' for better perceptual uniformity; qualitative colorscale now defaults to 'plotly' +- **Aggregation plotting**: `Aggregation.plot()` now respects `CONFIG.Plotting.default_qualitative_colorscale` and uses `process_colors()` for consistent color handling ### 🗑️ Deprecated ### 🔥 Removed -- Removed `plotting.pie_with_plotly()` method as it was not used +- Removed `plotting.pie_with_plotly()` method as it was not used +- Removed `ColorProcessor` class - replaced by simpler `process_colors()` function +- Removed `resolve_colors()` helper function - color resolution now handled directly by `process_colors()` ### 🐛 Fixed - Improved error messages for `engine='matplotlib'` with multidimensional data @@ -76,9 +98,15 @@ If upgrading from v2.x, see the [v3.0.0 release notes](https://github.com/flixOp ### 📝 Docs - Moved `linked_periods` into correct section of the docstring (was in deprecated params) +- Updated terminology in docstrings from "colormap" to "colorscale" for consistency +- Enhanced examples to demonstrate `setup_colors()` usage: + - `simple_example.py`: Shows automatic color assignment and optional custom configuration + - `scenario_example.py`: Demonstrates component grouping with custom colorscales ### 👷 Development - Fixed concurrency issue in CI +- **Code architecture**: Extracted color processing logic into dedicated `color_processing.py` module for better separation of concerns +- Refactored from class-based (`ColorProcessor`) to function-based color handling for simpler API and reduced complexity ### 🚧 Known Issues diff --git a/examples/01_Simple/simple_example.py b/examples/01_Simple/simple_example.py index 906c24622..6b62d6712 100644 --- a/examples/01_Simple/simple_example.py +++ b/examples/01_Simple/simple_example.py @@ -112,6 +112,9 @@ calculation.solve(fx.solvers.HighsSolver(mip_gap=0, time_limit_seconds=30)) # --- Analyze Results --- + # Colors are automatically assigned using default colormap + # Optional: Configure custom colors with + calculation.results.setup_colors() calculation.results['Fernwärme'].plot_node_balance_pie() calculation.results['Fernwärme'].plot_node_balance() calculation.results['Storage'].plot_charge_state() diff --git a/examples/04_Scenarios/scenario_example.py b/examples/04_Scenarios/scenario_example.py index 834e55782..d258d4142 100644 --- a/examples/04_Scenarios/scenario_example.py +++ b/examples/04_Scenarios/scenario_example.py @@ -196,6 +196,15 @@ # --- Solve the Calculation and Save Results --- calculation.solve(fx.solvers.HighsSolver(mip_gap=0, time_limit_seconds=30)) + calculation.results.setup_colors( + { + 'CHP': 'red', + 'Greys': ['Gastarif', 'Einspeisung', 'Heat Demand'], + 'Storage': 'blue', + 'Boiler': 'orange', + } + ) + calculation.results.plot_heatmap('CHP(Q_th)|flow_rate') # --- Analyze Results --- diff --git a/flixopt/aggregation.py b/flixopt/aggregation.py index 53770e140..cd0fdde3c 100644 --- a/flixopt/aggregation.py +++ b/flixopt/aggregation.py @@ -20,7 +20,9 @@ except ImportError: TSAM_AVAILABLE = False +from .color_processing import process_colors from .components import Storage +from .config import CONFIG from .structure import ( FlowSystemModel, Submodel, @@ -141,7 +143,7 @@ def describe_clusters(self) -> str: def use_extreme_periods(self): return self.time_series_for_high_peaks or self.time_series_for_low_peaks - def plot(self, colormap: str = 'viridis', show: bool = True, save: pathlib.Path | None = None) -> go.Figure: + def plot(self, colormap: str | None = None, show: bool = True, save: pathlib.Path | None = None) -> go.Figure: from . import plotting df_org = self.original_data.copy().rename( @@ -150,10 +152,13 @@ def plot(self, colormap: str = 'viridis', show: bool = True, save: pathlib.Path df_agg = self.aggregated_data.copy().rename( columns={col: f'Aggregated - {col}' for col in self.aggregated_data.columns} ) - fig = plotting.with_plotly(df_org.to_xarray(), 'line', colors=colormap, xlabel='Time in h') + colors = list( + process_colors(colormap or CONFIG.Plotting.default_qualitative_colorscale, list(df_org.columns)).values() + ) + fig = plotting.with_plotly(df_org.to_xarray(), 'line', colors=colors, xlabel='Time in h') for trace in fig.data: trace.update(dict(line=dict(dash='dash'))) - fig2 = plotting.with_plotly(df_agg.to_xarray(), 'line', colors=colormap, xlabel='Time in h') + fig2 = plotting.with_plotly(df_agg.to_xarray(), 'line', colors=colors, xlabel='Time in h') for trace in fig2.data: fig.add_trace(trace) diff --git a/flixopt/color_processing.py b/flixopt/color_processing.py new file mode 100644 index 000000000..2959acc82 --- /dev/null +++ b/flixopt/color_processing.py @@ -0,0 +1,261 @@ +"""Simplified color handling for visualization. + +This module provides clean color processing that transforms various input formats +into a label-to-color mapping dictionary, without needing to know about the plotting engine. +""" + +from __future__ import annotations + +import logging + +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +import plotly.express as px +from plotly.exceptions import PlotlyError + +logger = logging.getLogger('flixopt') + + +def _rgb_string_to_hex(color: str) -> str: + """Convert Plotly RGB/RGBA string format to hex. + + Args: + color: Color in format 'rgb(R, G, B)', 'rgba(R, G, B, A)' or already in hex + + Returns: + Color in hex format '#RRGGBB' + """ + color = color.strip() + + # If already hex, return as-is + if color.startswith('#'): + return color + + # Try to parse rgb() or rgba() + try: + if color.startswith('rgb('): + # Extract RGB values from 'rgb(R, G, B)' format + rgb_str = color[4:-1] # Remove 'rgb(' and ')' + elif color.startswith('rgba('): + # Extract RGBA values from 'rgba(R, G, B, A)' format + rgb_str = color[5:-1] # Remove 'rgba(' and ')' + else: + return color + + # Split on commas and parse first three components + components = rgb_str.split(',') + if len(components) < 3: + return color + + # Parse and clamp the first three components + r = max(0, min(255, int(round(float(components[0].strip()))))) + g = max(0, min(255, int(round(float(components[1].strip()))))) + b = max(0, min(255, int(round(float(components[2].strip()))))) + + return f'#{r:02x}{g:02x}{b:02x}' + except (ValueError, IndexError): + # If parsing fails, return original + return color + + +def process_colors( + colors: None | str | list[str] | dict[str, str], + labels: list[str], + default_colorscale: str = 'turbo', +) -> dict[str, str]: + """Process color input and return a label-to-color mapping. + + This function takes flexible color input and always returns a dictionary + mapping each label to a specific color string. The plotting engine can then + use this mapping as needed. + + Args: + colors: Color specification in one of four formats: + - None: Use the default colorscale + - str: Name of a colorscale (e.g., 'turbo', 'plasma', 'Set1', 'portland') + - list[str]: List of color strings (hex, named colors, etc.) + - dict[str, str]: Direct label-to-color mapping + labels: List of labels that need colors assigned + default_colorscale: Fallback colorscale name if requested scale not found + + Returns: + Dictionary mapping each label to a color string + + Examples: + >>> # Using None - applies default colorscale + >>> process_colors(None, ['A', 'B', 'C']) + {'A': '#0d0887', 'B': '#7e03a8', 'C': '#cc4778'} + + >>> # Using a colorscale name + >>> process_colors('plasma', ['A', 'B', 'C']) + {'A': '#0d0887', 'B': '#7e03a8', 'C': '#cc4778'} + + >>> # Using a list of colors + >>> process_colors(['red', 'blue', 'green'], ['A', 'B', 'C']) + {'A': 'red', 'B': 'blue', 'C': 'green'} + + >>> # Using a pre-made mapping + >>> process_colors({'A': 'red', 'B': 'blue'}, ['A', 'B', 'C']) + {'A': 'red', 'B': 'blue', 'C': '#0d0887'} # C gets color from default scale + """ + if not labels: + return {} + + # Case 1: Already a mapping dictionary + if isinstance(colors, dict): + return _fill_missing_colors(colors, labels, default_colorscale) + + # Case 2: None or colorscale name (string) + if colors is None or isinstance(colors, str): + colorscale_name = colors if colors is not None else default_colorscale + color_list = _get_colors_from_scale(colorscale_name, len(labels), default_colorscale) + return dict(zip(labels, color_list, strict=False)) + + # Case 3: List of colors + if isinstance(colors, list): + if len(colors) == 0: + logger.warning(f'Empty color list provided. Using {default_colorscale} instead.') + color_list = _get_colors_from_scale(default_colorscale, len(labels), default_colorscale) + return dict(zip(labels, color_list, strict=False)) + + if len(colors) < len(labels): + logger.debug( + f'Not enough colors provided ({len(colors)}) for all labels ({len(labels)}). Colors will cycle.' + ) + + # Cycle through colors if we don't have enough + return {label: colors[i % len(colors)] for i, label in enumerate(labels)} + + raise TypeError(f'colors must be None, str, list, or dict, got {type(colors)}') + + +def _fill_missing_colors( + color_mapping: dict[str, str], + labels: list[str], + default_colorscale: str, +) -> dict[str, str]: + """Fill in missing labels in a color mapping using a colorscale. + + Args: + color_mapping: Partial label-to-color mapping + labels: All labels that need colors + default_colorscale: Colorscale to use for missing labels + + Returns: + Complete label-to-color mapping + """ + missing_labels = [label for label in labels if label not in color_mapping] + + if not missing_labels: + return color_mapping.copy() + + # Log warning about missing labels + logger.debug(f'Labels missing colors: {missing_labels}. Using {default_colorscale} for these.') + + # Get colors for missing labels + missing_colors = _get_colors_from_scale(default_colorscale, len(missing_labels), default_colorscale) + + # Combine existing and new colors + result = color_mapping.copy() + result.update(dict(zip(missing_labels, missing_colors, strict=False))) + return result + + +def _get_colors_from_scale( + colorscale_name: str, + num_colors: int, + fallback_scale: str, +) -> list[str]: + """Extract a list of colors from a named colorscale. + + Tries to get colors from the named scale (Plotly first, then Matplotlib), + falls back to the fallback scale if not found. + + Args: + colorscale_name: Name of the colorscale to try + num_colors: Number of colors needed + fallback_scale: Fallback colorscale name if first fails + + Returns: + List of color strings (hex format) + """ + # Try to get the requested colorscale + colors = _try_get_colorscale(colorscale_name, num_colors) + + if colors is not None: + return colors + + # Fallback to default + logger.warning(f"Colorscale '{colorscale_name}' not found. Using '{fallback_scale}' instead.") + + colors = _try_get_colorscale(fallback_scale, num_colors) + + if colors is not None: + return colors + + # Ultimate fallback: just use basic colors + logger.warning(f"Fallback colorscale '{fallback_scale}' also not found. Using basic colors.") + basic_colors = [ + '#1f77b4', + '#ff7f0e', + '#2ca02c', + '#d62728', + '#9467bd', + '#8c564b', + '#e377c2', + '#7f7f7f', + '#bcbd22', + '#17becf', + ] + return [basic_colors[i % len(basic_colors)] for i in range(num_colors)] + + +def _try_get_colorscale(colorscale_name: str, num_colors: int) -> list[str] | None: + """Try to get colors from Plotly or Matplotlib colorscales. + + Tries Plotly colorscales first (both qualitative and sequential), + then falls back to Matplotlib colorscales. + + Args: + colorscale_name: Name of the colorscale + num_colors: Number of colors needed + + Returns: + List of color strings (hex format) if successful, None if colorscale not found + """ + # First try Plotly qualitative (discrete) color sequences + colorscale_title = colorscale_name.title() + if hasattr(px.colors.qualitative, colorscale_title): + color_list = getattr(px.colors.qualitative, colorscale_title) + # Convert to hex format for matplotlib compatibility + return [_rgb_string_to_hex(color_list[i % len(color_list)]) for i in range(num_colors)] + + # Then try Plotly sequential/continuous colorscales + try: + colorscale = px.colors.get_colorscale(colorscale_name) + # Sample evenly from the colorscale + if num_colors == 1: + sample_points = [0.5] + else: + sample_points = [i / (num_colors - 1) for i in range(num_colors)] + colors = px.colors.sample_colorscale(colorscale, sample_points) + # Convert to hex format for matplotlib compatibility + return [_rgb_string_to_hex(c) for c in colors] + except (PlotlyError, ValueError): + pass + + # Finally try Matplotlib colorscales + try: + cmap = plt.get_cmap(colorscale_name) + + # Sample evenly from the colorscale + if num_colors == 1: + colors = [cmap(0.5)] + else: + colors = [cmap(i / (num_colors - 1)) for i in range(num_colors)] + + # Convert RGBA tuples to hex strings + return [mcolors.rgb2hex(color[:3]) for color in colors] + + except (ValueError, KeyError): + return None diff --git a/flixopt/config.py b/flixopt/config.py index a7549a3ec..b7162e55f 100644 --- a/flixopt/config.py +++ b/flixopt/config.py @@ -54,6 +54,16 @@ 'big_binary_bound': 100_000, } ), + 'plotting': MappingProxyType( + { + 'default_show': True, + 'default_engine': 'plotly', + 'default_dpi': 300, + 'default_facet_cols': 3, + 'default_sequential_colorscale': 'turbo', + 'default_qualitative_colorscale': 'plotly', + } + ), } ) @@ -185,6 +195,42 @@ class Modeling: epsilon: float = _DEFAULTS['modeling']['epsilon'] big_binary_bound: int = _DEFAULTS['modeling']['big_binary_bound'] + class Plotting: + """Plotting configuration. + + Configure backends via environment variables: + - Matplotlib: Set `MPLBACKEND` environment variable (e.g., 'Agg', 'TkAgg') + - Plotly: Set `PLOTLY_RENDERER` or use `plotly.io.renderers.default` + + Attributes: + default_show: Default value for the `show` parameter in plot methods. + default_engine: Default plotting engine. + default_dpi: Default DPI for saved plots. + default_facet_cols: Default number of columns for faceted plots. + default_sequential_colorscale: Default colorscale for heatmaps and continuous data. + default_qualitative_colorscale: Default colormap for categorical plots (bar/line/area charts). + + Examples: + ```python + # Set consistent theming + CONFIG.Plotting.plotly_template = 'plotly_dark' + CONFIG.apply() + + # Configure default export and color settings + CONFIG.Plotting.default_dpi = 600 + CONFIG.Plotting.default_sequential_colorscale = 'plasma' + CONFIG.Plotting.default_qualitative_colorscale = 'Dark24' + CONFIG.apply() + ``` + """ + + default_show: bool = _DEFAULTS['plotting']['default_show'] + default_engine: Literal['plotly', 'matplotlib'] = _DEFAULTS['plotting']['default_engine'] + default_dpi: int = _DEFAULTS['plotting']['default_dpi'] + default_facet_cols: int = _DEFAULTS['plotting']['default_facet_cols'] + default_sequential_colorscale: str = _DEFAULTS['plotting']['default_sequential_colorscale'] + default_qualitative_colorscale: str = _DEFAULTS['plotting']['default_qualitative_colorscale'] + config_name: str = _DEFAULTS['config_name'] @classmethod @@ -319,6 +365,14 @@ def to_dict(cls) -> dict: 'epsilon': cls.Modeling.epsilon, 'big_binary_bound': cls.Modeling.big_binary_bound, }, + 'plotting': { + 'default_show': cls.Plotting.default_show, + 'default_engine': cls.Plotting.default_engine, + 'default_dpi': cls.Plotting.default_dpi, + 'default_facet_cols': cls.Plotting.default_facet_cols, + 'default_sequential_colorscale': cls.Plotting.default_sequential_colorscale, + 'default_qualitative_colorscale': cls.Plotting.default_qualitative_colorscale, + }, } diff --git a/flixopt/flow_system.py b/flixopt/flow_system.py index ad43c183b..fd0f6a98d 100644 --- a/flixopt/flow_system.py +++ b/flixopt/flow_system.py @@ -4,7 +4,6 @@ from __future__ import annotations -import json import logging import warnings from typing import TYPE_CHECKING, Any, Literal, Optional @@ -13,6 +12,7 @@ import pandas as pd import xarray as xr +from .config import CONFIG from .core import ( ConversionError, DataConverter, @@ -484,7 +484,7 @@ def plot_network( | list[ Literal['nodes', 'edges', 'layout', 'interaction', 'manipulation', 'physics', 'selection', 'renderer'] ] = True, - show: bool = False, + show: bool | None = None, ) -> pyvis.network.Network | None: """ Visualizes the network structure of a FlowSystem using PyVis, saving it as an interactive HTML file. @@ -514,7 +514,9 @@ def plot_network( from . import plotting node_infos, edge_infos = self.network_infos() - return plotting.plot_network(node_infos, edge_infos, path, controls, show) + return plotting.plot_network( + node_infos, edge_infos, path, controls, show if show is not None else CONFIG.Plotting.default_show + ) def start_network_app(self): """Visualizes the network structure of a FlowSystem using Dash, Cytoscape, and networkx. diff --git a/flixopt/plotting.py b/flixopt/plotting.py index a024c97fc..045cf7e99 100644 --- a/flixopt/plotting.py +++ b/flixopt/plotting.py @@ -40,8 +40,8 @@ import plotly.graph_objects as go import plotly.offline import xarray as xr -from plotly.exceptions import PlotlyError +from .color_processing import process_colors from .config import CONFIG if TYPE_CHECKING: @@ -49,7 +49,7 @@ logger = logging.getLogger('flixopt') -# Define the colors for the 'portland' colormap in matplotlib +# Define the colors for the 'portland' colorscale in matplotlib _portland_colors = [ [12 / 255, 51 / 255, 131 / 255], # Dark blue [10 / 255, 136 / 255, 186 / 255], # Light blue @@ -58,7 +58,7 @@ [217 / 255, 30 / 255, 30 / 255], # Red ] -# Check if the colormap already exists before registering it +# Check if the colorscale already exists before registering it if hasattr(plt, 'colormaps'): # Matplotlib >= 3.7 registry = plt.colormaps if 'portland' not in registry: @@ -73,9 +73,9 @@ Color specifications can take several forms to accommodate different use cases: -**Named Colormaps** (str): - - Standard colormaps: 'viridis', 'plasma', 'cividis', 'tab10', 'Set1' - - Energy-focused: 'portland' (custom flixopt colormap for energy systems) +**Named colorscales** (str): + - Standard colorscales: 'turbo', 'plasma', 'cividis', 'tab10', 'Set1' + - Energy-focused: 'portland' (custom flixopt colorscale for energy systems) - Backend-specific maps available in Plotly and Matplotlib **Color Lists** (list[str]): @@ -90,8 +90,8 @@ Examples: ```python - # Named colormap - colors = 'viridis' # Automatic color generation + # Named colorscale + colors = 'turbo' # Automatic color generation # Explicit color list colors = ['red', 'blue', 'green', '#FFD700'] @@ -114,7 +114,7 @@ References: - HTML Color Names: https://htmlcolorcodes.com/color-names/ - - Matplotlib Colormaps: https://matplotlib.org/stable/tutorials/colors/colormaps.html + - Matplotlib colorscales: https://matplotlib.org/stable/tutorials/colors/colorscales.html - Plotly Built-in Colorscales: https://plotly.com/python/builtin-colorscales/ """ @@ -122,212 +122,6 @@ """Identifier for the plotting engine to use.""" -class ColorProcessor: - """Intelligent color management system for consistent multi-backend visualization. - - This class provides unified color processing across Plotly and Matplotlib backends, - ensuring consistent visual appearance regardless of the plotting engine used. - It handles color palette generation, named colormap translation, and intelligent - color cycling for complex datasets with many categories. - - Key Features: - **Backend Agnostic**: Automatic color format conversion between engines - **Palette Management**: Support for named colormaps, custom palettes, and color lists - **Intelligent Cycling**: Smart color assignment for datasets with many categories - **Fallback Handling**: Graceful degradation when requested colormaps are unavailable - **Energy System Colors**: Built-in palettes optimized for energy system visualization - - Color Input Types: - - **Named Colormaps**: 'viridis', 'plasma', 'portland', 'tab10', etc. - - **Color Lists**: ['red', 'blue', 'green'] or ['#FF0000', '#0000FF', '#00FF00'] - - **Label Dictionaries**: {'Generator': 'red', 'Storage': 'blue', 'Load': 'green'} - - Examples: - Basic color processing: - - ```python - # Initialize for Plotly backend - processor = ColorProcessor(engine='plotly', default_colormap='viridis') - - # Process different color specifications - colors = processor.process_colors('plasma', ['Gen1', 'Gen2', 'Storage']) - colors = processor.process_colors(['red', 'blue', 'green'], ['A', 'B', 'C']) - colors = processor.process_colors({'Wind': 'skyblue', 'Solar': 'gold'}, ['Wind', 'Solar', 'Gas']) - - # Switch to Matplotlib - processor = ColorProcessor(engine='matplotlib') - mpl_colors = processor.process_colors('tab10', component_labels) - ``` - - Energy system visualization: - - ```python - # Specialized energy system palette - energy_colors = { - 'Natural_Gas': '#8B4513', # Brown - 'Electricity': '#FFD700', # Gold - 'Heat': '#FF4500', # Red-orange - 'Cooling': '#87CEEB', # Sky blue - 'Hydrogen': '#E6E6FA', # Lavender - 'Battery': '#32CD32', # Lime green - } - - processor = ColorProcessor('plotly') - flow_colors = processor.process_colors(energy_colors, flow_labels) - ``` - - Args: - engine: Plotting backend ('plotly' or 'matplotlib'). Determines output color format. - default_colormap: Fallback colormap when requested palettes are unavailable. - Common options: 'viridis', 'plasma', 'tab10', 'portland'. - - """ - - def __init__(self, engine: PlottingEngine = 'plotly', default_colormap: str = 'viridis'): - """Initialize the color processor with specified backend and defaults.""" - if engine not in ['plotly', 'matplotlib']: - raise TypeError(f'engine must be "plotly" or "matplotlib", but is {engine}') - self.engine = engine - self.default_colormap = default_colormap - - def _generate_colors_from_colormap(self, colormap_name: str, num_colors: int) -> list[Any]: - """ - Generate colors from a named colormap. - - Args: - colormap_name: Name of the colormap - num_colors: Number of colors to generate - - Returns: - list of colors in the format appropriate for the engine - """ - if self.engine == 'plotly': - try: - colorscale = px.colors.get_colorscale(colormap_name) - except PlotlyError as e: - logger.error(f"Colorscale '{colormap_name}' not found in Plotly. Using {self.default_colormap}: {e}") - colorscale = px.colors.get_colorscale(self.default_colormap) - - # Generate evenly spaced points - color_points = [i / (num_colors - 1) for i in range(num_colors)] if num_colors > 1 else [0] - return px.colors.sample_colorscale(colorscale, color_points) - - else: # matplotlib - try: - cmap = plt.get_cmap(colormap_name, num_colors) - except ValueError as e: - logger.error(f"Colormap '{colormap_name}' not found in Matplotlib. Using {self.default_colormap}: {e}") - cmap = plt.get_cmap(self.default_colormap, num_colors) - - return [cmap(i) for i in range(num_colors)] - - def _handle_color_list(self, colors: list[str], num_labels: int) -> list[str]: - """ - Handle a list of colors, cycling if necessary. - - Args: - colors: list of color strings - num_labels: Number of labels that need colors - - Returns: - list of colors matching the number of labels - """ - if len(colors) == 0: - logger.error(f'Empty color list provided. Using {self.default_colormap} instead.') - return self._generate_colors_from_colormap(self.default_colormap, num_labels) - - if len(colors) < num_labels: - logger.warning( - f'Not enough colors provided ({len(colors)}) for all labels ({num_labels}). Colors will cycle.' - ) - # Cycle through the colors - color_iter = itertools.cycle(colors) - return [next(color_iter) for _ in range(num_labels)] - else: - # Trim if necessary - if len(colors) > num_labels: - logger.warning( - f'More colors provided ({len(colors)}) than labels ({num_labels}). Extra colors will be ignored.' - ) - return colors[:num_labels] - - def _handle_color_dict(self, colors: dict[str, str], labels: list[str]) -> list[str]: - """ - Handle a dictionary mapping labels to colors. - - Args: - colors: Dictionary mapping labels to colors - labels: list of labels that need colors - - Returns: - list of colors in the same order as labels - """ - if len(colors) == 0: - logger.warning(f'Empty color dictionary provided. Using {self.default_colormap} instead.') - return self._generate_colors_from_colormap(self.default_colormap, len(labels)) - - # Find missing labels - missing_labels = sorted(set(labels) - set(colors.keys())) - if missing_labels: - logger.warning( - f'Some labels have no color specified: {missing_labels}. Using {self.default_colormap} for these.' - ) - - # Generate colors for missing labels - missing_colors = self._generate_colors_from_colormap(self.default_colormap, len(missing_labels)) - - # Create a copy to avoid modifying the original - colors_copy = colors.copy() - for i, label in enumerate(missing_labels): - colors_copy[label] = missing_colors[i] - else: - colors_copy = colors - - # Create color list in the same order as labels - return [colors_copy[label] for label in labels] - - def process_colors( - self, - colors: ColorType, - labels: list[str], - return_mapping: bool = False, - ) -> list[Any] | dict[str, Any]: - """ - Process colors for the specified labels. - - Args: - colors: Color specification (colormap name, list of colors, or label-to-color mapping) - labels: list of data labels that need colors assigned - return_mapping: If True, returns a dictionary mapping labels to colors; - if False, returns a list of colors in the same order as labels - - Returns: - Either a list of colors or a dictionary mapping labels to colors - """ - if len(labels) == 0: - logger.error('No labels provided for color assignment.') - return {} if return_mapping else [] - - # Process based on type of colors input - if isinstance(colors, str): - color_list = self._generate_colors_from_colormap(colors, len(labels)) - elif isinstance(colors, list): - color_list = self._handle_color_list(colors, len(labels)) - elif isinstance(colors, dict): - color_list = self._handle_color_dict(colors, labels) - else: - logger.error( - f'Unsupported color specification type: {type(colors)}. Using {self.default_colormap} instead.' - ) - color_list = self._generate_colors_from_colormap(self.default_colormap, len(labels)) - - # Return either a list or a mapping - if return_mapping: - return {label: color_list[i] for i, label in enumerate(labels)} - else: - return color_list - - def _ensure_dataset(data: xr.Dataset | pd.DataFrame | pd.Series) -> xr.Dataset: """Convert DataFrame or Series to Dataset if needed.""" if isinstance(data, xr.Dataset): @@ -373,31 +167,10 @@ def _validate_plotting_data(data: xr.Dataset, allow_empty: bool = False) -> None logger.debug(f"Variable '{var}' contains Inf values which may affect visualization.") -def resolve_colors( - data: xr.Dataset, - colors: ColorType, - engine: PlottingEngine = 'plotly', -) -> dict[str, str]: - """Resolve colors parameter to a dict mapping variable names to colors.""" - # Get variable names from Dataset (always strings and unique) - labels = list(data.data_vars.keys()) - - # If explicit dict provided, use it directly - if isinstance(colors, dict): - return colors - - # If string or list, use ColorProcessor (traditional behavior) - if isinstance(colors, (str, list)): - processor = ColorProcessor(engine=engine) - return processor.process_colors(colors, labels, return_mapping=True) - - raise TypeError(f'Wrong type passed to resolve_colors(): {type(colors)}') - - def with_plotly( data: xr.Dataset | pd.DataFrame | pd.Series, mode: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar', - colors: ColorType = 'viridis', + colors: ColorType | None = None, title: str = '', ylabel: str = '', xlabel: str = '', @@ -417,7 +190,7 @@ def with_plotly( data: An xarray Dataset, pandas DataFrame, or pandas Series to plot. mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for lines, 'area' for stacked area charts, or 'grouped_bar' for grouped bar charts. - colors: Color specification (colormap, list, or dict mapping labels to colors). + colors: Color specification (colorscale, list, or dict mapping labels to colors). title: The main title of the plot. ylabel: The label for the y-axis. xlabel: The label for the x-axis. @@ -476,9 +249,16 @@ def with_plotly( fig.update_layout(template='plotly_dark', width=1200, height=600) ``` """ + if colors is None: + colors = CONFIG.Plotting.default_qualitative_colorscale + if mode not in ('stacked_bar', 'line', 'area', 'grouped_bar'): raise ValueError(f"'mode' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {mode!r}") + # Apply CONFIG defaults if not explicitly set + if facet_cols is None: + facet_cols = CONFIG.Plotting.default_facet_cols + # Ensure data is a Dataset and validate it data = _ensure_dataset(data) _validate_plotting_data(data, allow_empty=True) @@ -496,7 +276,9 @@ def with_plotly( values = [float(data[var].values) for var in data.data_vars] # Resolve colors - color_discrete_map = resolve_colors(data, colors, engine='plotly') + color_discrete_map = process_colors( + colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale + ) marker_colors = [color_discrete_map.get(var, '#636EFA') for var in variables] # Create simple plot based on mode using go (not px) for better color control @@ -587,8 +369,9 @@ def with_plotly( # Process colors all_vars = df_long['variable'].unique().tolist() - processed_colors = ColorProcessor(engine='plotly').process_colors(colors, all_vars) - color_discrete_map = {var: color for var, color in zip(all_vars, processed_colors, strict=True)} + color_discrete_map = process_colors( + colors, all_vars, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale + ) # Determine which dimension to use for x-axis # Collect dimensions used for faceting and animation @@ -705,7 +488,7 @@ def with_plotly( def with_matplotlib( data: xr.Dataset | pd.DataFrame | pd.Series, mode: Literal['stacked_bar', 'line'] = 'stacked_bar', - colors: ColorType = 'viridis', + colors: ColorType | None = None, title: str = '', ylabel: str = '', xlabel: str = 'Time in h', @@ -720,7 +503,7 @@ def with_matplotlib( the index represents time and each column represents a separate data series (variables). mode: Plotting mode. Use 'stacked_bar' for stacked bar charts or 'line' for stepped lines. colors: Color specification. Can be: - - A colormap name (e.g., 'turbo', 'plasma') + - A colorscale name (e.g., 'turbo', 'plasma') - A list of color strings (e.g., ['#ff0000', '#00ff00']) - A dict mapping column names to colors (e.g., {'Column1': '#ff0000'}) title: The title of the plot. @@ -738,6 +521,9 @@ def with_matplotlib( Negative values are stacked separately without extra labels in the legend. - If `mode` is 'line', stepped lines are drawn for each data series. """ + if colors is None: + colors = CONFIG.Plotting.default_qualitative_colorscale + if mode not in ('stacked_bar', 'line'): raise ValueError(f"'mode' must be one of {{'stacked_bar','line'}} for matplotlib, got {mode!r}") @@ -760,7 +546,9 @@ def with_matplotlib( values = [float(data[var].values) for var in data.data_vars] # Resolve colors - color_discrete_map = resolve_colors(data, colors, engine='matplotlib') + color_discrete_map = process_colors( + colors, variables, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale + ) colors_list = [color_discrete_map.get(var, '#808080') for var in variables] # Create plot based on mode @@ -791,7 +579,9 @@ def with_matplotlib( return fig, ax # Resolve colors first (includes validation) - color_discrete_map = resolve_colors(data, colors, engine='matplotlib') + color_discrete_map = process_colors( + colors, list(data.data_vars), default_colorscale=CONFIG.Plotting.default_qualitative_colorscale + ) # Convert Dataset to DataFrame for matplotlib plotting (naturally wide-form) df = data.to_dataframe() @@ -1199,7 +989,7 @@ def preprocess_data_for_pie( def dual_pie_with_plotly( data_left: xr.Dataset | pd.DataFrame | pd.Series, data_right: xr.Dataset | pd.DataFrame | pd.Series, - colors: ColorType = 'viridis', + colors: ColorType | None = None, title: str = '', subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'), legend_title: str = '', @@ -1229,6 +1019,9 @@ def dual_pie_with_plotly( Returns: Plotly Figure object """ + if colors is None: + colors = CONFIG.Plotting.default_qualitative_colorscale + # Preprocess data to Series left_series = preprocess_data_for_pie(data_left, lower_percentage_group) right_series = preprocess_data_for_pie(data_right, lower_percentage_group) @@ -1244,7 +1037,7 @@ def dual_pie_with_plotly( all_labels = sorted(set(left_labels) | set(right_labels)) # Create color map - color_map = ColorProcessor(engine='plotly').process_colors(colors, all_labels, return_mapping=True) + color_map = process_colors(colors, all_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) # Create figure fig = go.Figure() @@ -1294,7 +1087,7 @@ def dual_pie_with_plotly( def dual_pie_with_matplotlib( data_left: xr.Dataset | pd.DataFrame | pd.Series, data_right: xr.Dataset | pd.DataFrame | pd.Series, - colors: ColorType = 'viridis', + colors: ColorType | None = None, title: str = '', subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'), legend_title: str = '', @@ -1308,7 +1101,7 @@ def dual_pie_with_matplotlib( Args: data_left: Data for the left pie chart. data_right: Data for the right pie chart. - colors: Color specification (colormap name, list of colors, or dict mapping) + colors: Color specification (colorscale name, list of colors, or dict mapping) title: The main title of the plot. subtitles: Tuple containing the subtitles for (left, right) charts. legend_title: The title for the legend. @@ -1319,6 +1112,9 @@ def dual_pie_with_matplotlib( Returns: Tuple of (Figure, list of Axes) """ + if colors is None: + colors = CONFIG.Plotting.default_qualitative_colorscale + # Preprocess data to Series left_series = preprocess_data_for_pie(data_left, lower_percentage_group) right_series = preprocess_data_for_pie(data_right, lower_percentage_group) @@ -1333,8 +1129,8 @@ def dual_pie_with_matplotlib( # Get all unique labels for consistent coloring all_labels = sorted(set(left_labels) | set(right_labels)) - # Create color map - color_map = ColorProcessor(engine='matplotlib').process_colors(colors, all_labels, return_mapping=True) + # Create color map (process_colors always returns a dict) + color_map = process_colors(colors, all_labels, default_colorscale=CONFIG.Plotting.default_qualitative_colorscale) # Create figure fig, axes = plt.subplots(1, 2, figsize=figsize) @@ -1400,11 +1196,11 @@ def draw_pie(ax, labels, values, subtitle): def heatmap_with_plotly( data: xr.DataArray, - colors: ColorType = 'viridis', + colors: ColorType | None = None, title: str = '', facet_by: str | list[str] | None = None, animate_by: str | None = None, - facet_cols: int = 3, + facet_cols: int | None = None, reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] | Literal['auto'] | None = 'auto', @@ -1427,8 +1223,8 @@ def heatmap_with_plotly( Args: data: An xarray DataArray containing the data to visualize. Should have at least 2 dimensions, or a 'time' dimension that can be reshaped into 2D. - colors: Color specification (colormap name, list, or dict). Common options: - 'viridis', 'plasma', 'RdBu', 'portland'. + colors: Color specification (colorscale name, list, or dict). Common options: + 'turbo', 'plasma', 'RdBu', 'portland'. title: The main title of the heatmap. facet_by: Dimension to create facets for. Creates a subplot grid. Can be a single dimension name or list (only first dimension used). @@ -1484,6 +1280,13 @@ def heatmap_with_plotly( fig = heatmap_with_plotly(data_array, facet_by='scenario', animate_by='period', reshape_time=('W', 'D')) ``` """ + if colors is None: + colors = CONFIG.Plotting.default_sequential_colorscale + + # Apply CONFIG defaults if not explicitly set + if facet_cols is None: + facet_cols = CONFIG.Plotting.default_facet_cols + # Handle empty data if data.size == 0: return go.Figure() @@ -1577,7 +1380,7 @@ def heatmap_with_plotly( # Create the imshow plot - px.imshow can work directly with xarray DataArrays common_args = { 'img': data, - 'color_continuous_scale': colors if isinstance(colors, str) else 'viridis', + 'color_continuous_scale': colors, 'title': title, } @@ -1601,7 +1404,7 @@ def heatmap_with_plotly( # Fallback: create a simple heatmap without faceting fallback_args = { 'img': data.values, - 'color_continuous_scale': colors if isinstance(colors, str) else 'viridis', + 'color_continuous_scale': colors, 'title': title, } fallback_args.update(imshow_kwargs) @@ -1612,7 +1415,7 @@ def heatmap_with_plotly( def heatmap_with_matplotlib( data: xr.DataArray, - colors: ColorType = 'viridis', + colors: ColorType | None = None, title: str = '', figsize: tuple[float, float] = (12, 6), reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] @@ -1635,7 +1438,7 @@ def heatmap_with_matplotlib( data: An xarray DataArray containing the data to visualize. Should have at least 2 dimensions. If more than 2 dimensions exist, additional dimensions will be reduced by taking the first slice. - colors: Color specification. Should be a colormap name (e.g., 'turbo', 'RdBu'). + colors: Color specification. Should be a colorscale name (e.g., 'turbo', 'RdBu'). title: The title of the heatmap. figsize: The size of the figure (width, height) in inches. reshape_time: Time reshaping configuration: @@ -1675,6 +1478,9 @@ def heatmap_with_matplotlib( fig, ax = heatmap_with_matplotlib(data_array, reshape_time=('D', 'h')) ``` """ + if colors is None: + colors = CONFIG.Plotting.default_sequential_colorscale + # Initialize kwargs if not provided if imshow_kwargs is None: imshow_kwargs = {} @@ -1726,11 +1532,8 @@ def heatmap_with_matplotlib( x_labels = 'x' y_labels = 'y' - # Process colormap - cmap = colors if isinstance(colors, str) else 'viridis' - # Create the heatmap using imshow with user customizations - imshow_defaults = {'cmap': cmap, 'aspect': 'auto', 'origin': 'upper', 'vmin': vmin, 'vmax': vmax} + imshow_defaults = {'cmap': colors, 'aspect': 'auto', 'origin': 'upper', 'vmin': vmin, 'vmax': vmax} imshow_defaults.update(imshow_kwargs) # User kwargs override defaults im = ax.imshow(values, **imshow_defaults) @@ -1759,9 +1562,9 @@ def export_figure( default_path: pathlib.Path, default_filetype: str | None = None, user_path: pathlib.Path | None = None, - show: bool = True, + show: bool | None = None, save: bool = False, - dpi: int = 300, + dpi: int | None = None, ) -> go.Figure | tuple[plt.Figure, plt.Axes]: """ Export a figure to a file and or show it. @@ -1771,14 +1574,21 @@ def export_figure( default_path: The default file path if no user filename is provided. default_filetype: The default filetype if the path doesnt end with a filetype. user_path: An optional user-specified file path. - show: Whether to display the figure (default: True). + show: Whether to display the figure. If None, uses CONFIG.Plotting.default_show (default: None). save: Whether to save the figure (default: False). - dpi: DPI (dots per inch) for saving Matplotlib figures. If None, Matplotlib rcParams are used. + dpi: DPI (dots per inch) for saving Matplotlib figures. If None, uses CONFIG.Plotting.default_dpi. Raises: ValueError: If no default filetype is provided and the path doesn't specify a filetype. TypeError: If the figure type is not supported. """ + # Apply CONFIG defaults if not explicitly set + if show is None: + show = CONFIG.Plotting.default_show + + if dpi is None: + dpi = CONFIG.Plotting.default_dpi + filename = user_path or default_path filename = filename.with_name(filename.name.replace('|', '__')) if filename.suffix == '': @@ -1793,25 +1603,17 @@ def export_figure( filename = filename.with_suffix('.html') try: - is_test_env = 'PYTEST_CURRENT_TEST' in os.environ - - if is_test_env: - # Test environment: never open browser, only save if requested - if save: - fig.write_html(str(filename)) - # Ignore show flag in tests - else: - # Production environment: respect show and save flags - if save and show: - # Save and auto-open in browser - plotly.offline.plot(fig, filename=str(filename)) - elif save and not show: - # Save without opening - fig.write_html(str(filename)) - elif show and not save: - # Show interactively without saving - fig.show() - # If neither save nor show: do nothing + # Respect show and save flags (tests should set CONFIG.Plotting.default_show=False) + if save and show: + # Save and auto-open in browser + plotly.offline.plot(fig, filename=str(filename)) + elif save and not show: + # Save without opening + fig.write_html(str(filename)) + elif show and not save: + # Show interactively without saving + fig.show() + # If neither save nor show: do nothing finally: # Cleanup to prevent socket warnings if hasattr(fig, '_renderer'): @@ -1822,12 +1624,11 @@ def export_figure( elif isinstance(figure_like, tuple): fig, ax = figure_like if show: - # Only show if using interactive backend and not in test environment + # Only show if using interactive backend (tests should set CONFIG.Plotting.default_show=False) backend = matplotlib.get_backend().lower() is_interactive = backend not in {'agg', 'pdf', 'ps', 'svg', 'template'} - is_test_env = 'PYTEST_CURRENT_TEST' in os.environ - if is_interactive and not is_test_env: + if is_interactive: plt.show() if save: diff --git a/flixopt/results.py b/flixopt/results.py index 576ff9ec1..847ee5a7f 100644 --- a/flixopt/results.py +++ b/flixopt/results.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import datetime import json import logging @@ -15,6 +16,8 @@ from . import io as fx_io from . import plotting +from .color_processing import process_colors +from .config import CONFIG from .flow_system import FlowSystem if TYPE_CHECKING: @@ -29,6 +32,57 @@ logger = logging.getLogger('flixopt') +def load_mapping_from_file(path: pathlib.Path) -> dict[str, str | list[str]]: + """Load color mapping from JSON or YAML file. + + Tries loader based on file suffix first, with fallback to the other format. + + Args: + path: Path to config file (.json or .yaml/.yml) + + Returns: + Dictionary mapping components to colors or colorscales to component lists + + Raises: + ValueError: If file cannot be loaded as JSON or YAML + """ + suffix = path.suffix.lower() + + if suffix == '.json': + # Try JSON first, fallback to YAML + try: + with open(path) as f: + return json.load(f) + except json.JSONDecodeError: + try: + with open(path) as f: + return yaml.safe_load(f) + except Exception: + raise ValueError(f'Could not load config from {path}') from None + elif suffix in {'.yaml', '.yml'}: + # Try YAML first, fallback to JSON + try: + with open(path) as f: + return yaml.safe_load(f) + except yaml.YAMLError: + try: + with open(path) as f: + return json.load(f) + except Exception: + raise ValueError(f'Could not load config from {path}') from None + else: + # Unknown extension, try both starting with JSON + try: + with open(path) as f: + return json.load(f) + except json.JSONDecodeError: + try: + with open(path) as f: + return yaml.safe_load(f) + except Exception: + raise ValueError(f'Could not load config from {path}') from None + + class _FlowSystemRestorationError(Exception): """Exception raised when a FlowSystem cannot be restored from dataset.""" @@ -254,6 +308,8 @@ def __init__( self._sizes = None self._effects_per_component = None + self.colors: dict[str, str] = {} + def __getitem__(self, key: str) -> ComponentResults | BusResults | EffectResults: if key in self.components: return self.components[key] @@ -320,6 +376,131 @@ def flow_system(self) -> FlowSystem: logger.level = old_level return self._flow_system + def setup_colors( + self, + config: dict[str, str | list[str]] | str | pathlib.Path | None = None, + default_colorscale: str | None = None, + ) -> dict[str, str]: + """ + Setup colors for all variables across all elements. Overwrites existing ones. + + Args: + config: Configuration for color assignment. Can be: + - dict: Maps components to colors/colorscales: + * 'component1': 'red' # Single component to single color + * 'component1': '#FF0000' # Single component to hex color + - OR maps colorscales to multiple components: + * 'colorscale_name': ['component1', 'component2'] # Colorscale across components + - str: Path to a JSON/YAML config file or a colorscale name to apply to all + - Path: Path to a JSON/YAML config file + - None: Use default_colorscale for all components + default_colorscale: Default colorscale for unconfigured components (default: 'turbo') + + Examples: + setup_colors({ + # Direct component-to-color mappings + 'Boiler1': '#FF0000', + 'CHP': 'darkred', + # Colorscale for multiple components + 'Oranges': ['Solar1', 'Solar2'], + 'Blues': ['Wind1', 'Wind2'], + 'Greens': ['Battery1', 'Battery2', 'Battery3'], + }) + + Returns: + Complete variable-to-color mapping dictionary + """ + + def get_all_variable_names(comp: str) -> list[str]: + """Collect all variables from the component, including flows and flow_hours.""" + comp_object = self.components[comp] + var_names = [comp] + list(comp_object._variable_names) + for flow in comp_object.flows: + var_names.extend([flow, f'{flow}|flow_hours']) + return var_names + + # Set default colorscale if not provided + if default_colorscale is None: + default_colorscale = CONFIG.Plotting.default_qualitative_colorscale + + # Handle different config input types + if config is None: + # Apply default colorscale to all components + config_dict = {} + elif isinstance(config, (str, pathlib.Path)): + # Try to load from file first + config_path = pathlib.Path(config) + if config_path.exists(): + # Load config from file using helper + config_dict = load_mapping_from_file(config_path) + else: + # Treat as colorscale name to apply to all components + all_components = list(self.components.keys()) + config_dict = {config: all_components} + elif isinstance(config, dict): + config_dict = config + else: + raise TypeError(f'config must be dict, str, Path, or None, got {type(config)}') + + # Step 1: Build component-to-color mapping + component_colors: dict[str, str] = {} + + # Track which components are configured + configured_components = set() + + # Process each configuration entry + for key, value in config_dict.items(): + # Check if value is a list (colorscale -> [components]) + # or a string (component -> color OR colorscale -> [components]) + + if isinstance(value, list): + # key is colorscale, value is list of components + # Format: 'Blues': ['Wind1', 'Wind2'] + components = value + colorscale_name = key + + # Validate components exist + for component in components: + if component not in self.components: + raise ValueError(f"Component '{component}' not found") + + configured_components.update(components) + + # Use process_colors to get one color per component from the colorscale + colors_for_components = process_colors(colorscale_name, components) + component_colors.update(colors_for_components) + + elif isinstance(value, str): + # Check if key is an existing component + if key in self.components: + # Format: 'CHP': 'red' (component -> color) + component, color = key, value + + configured_components.add(component) + component_colors[component] = color + else: + raise ValueError(f"Component '{key}' not found") + else: + raise TypeError(f'Config value must be str or list, got {type(value)}') + + # Step 2: Assign colors to remaining unconfigured components + remaining_components = list(set(self.components.keys()) - configured_components) + if remaining_components: + # Use default colorscale to assign one color per remaining component + default_colors = process_colors(default_colorscale, remaining_components) + component_colors.update(default_colors) + + # Step 3: Build variable-to-color mapping + # Clear existing colors to avoid stale keys + self.colors = {} + # Each component's variables all get the same color as the component + for component, color in component_colors.items(): + variable_names = get_all_variable_names(component) + for var_name in variable_names: + self.colors[var_name] = color + + return self.colors + def filter_solution( self, variable_dims: Literal['scalar', 'time', 'scenario', 'timeonly', 'scenarioonly'] | None = None, @@ -719,13 +900,13 @@ def plot_heatmap( self, variable_name: str | list[str], save: bool | pathlib.Path = False, - show: bool = True, - colors: plotting.ColorType = 'viridis', + show: bool | None = None, + colors: plotting.ColorType | None = None, engine: plotting.PlottingEngine = 'plotly', select: dict[FlowSystemDimensions, Any] | None = None, facet_by: str | list[str] | None = 'scenario', animate_by: str | None = 'period', - facet_cols: int = 3, + facet_cols: int | None = None, reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] | Literal['auto'] | None = 'auto', @@ -1024,8 +1205,8 @@ def __init__( def plot_node_balance( self, save: bool | pathlib.Path = False, - show: bool = True, - colors: plotting.ColorType = 'viridis', + show: bool | None = None, + colors: plotting.ColorType | None = None, engine: plotting.PlottingEngine = 'plotly', select: dict[FlowSystemDimensions, Any] | None = None, unit_type: Literal['flow_rate', 'flow_hours'] = 'flow_rate', @@ -1033,7 +1214,7 @@ def plot_node_balance( drop_suffix: bool = True, facet_by: str | list[str] | None = 'scenario', animate_by: str | None = 'period', - facet_cols: int = 3, + facet_cols: int | None = None, # Deprecated parameter (kept for backwards compatibility) indexer: dict[FlowSystemDimensions, Any] | None = None, **plot_kwargs: Any, @@ -1183,7 +1364,7 @@ def plot_node_balance( ds, facet_by=facet_by, animate_by=animate_by, - colors=colors, + colors=colors if colors is not None else self._calculation_results.colors, mode=mode, title=title, facet_cols=facet_cols, @@ -1194,7 +1375,7 @@ def plot_node_balance( else: figure_like = plotting.with_matplotlib( ds, - colors=colors, + colors=colors if colors is not None else self._calculation_results.colors, mode=mode, title=title, **plot_kwargs, @@ -1214,10 +1395,10 @@ def plot_node_balance( def plot_node_balance_pie( self, lower_percentage_group: float = 5, - colors: plotting.ColorType = 'viridis', + colors: plotting.ColorType | None = None, text_info: str = 'percent+label+value', save: bool | pathlib.Path = False, - show: bool = True, + show: bool | None = None, engine: plotting.PlottingEngine = 'plotly', select: dict[FlowSystemDimensions, Any] | None = None, # Deprecated parameter (kept for backwards compatibility) @@ -1351,7 +1532,7 @@ def plot_node_balance_pie( figure_like = plotting.dual_pie_with_plotly( data_left=inputs, data_right=outputs, - colors=colors, + colors=colors if colors is not None else self._calculation_results.colors, title=title, text_info=text_info, subtitles=('Inputs', 'Outputs'), @@ -1365,7 +1546,7 @@ def plot_node_balance_pie( figure_like = plotting.dual_pie_with_matplotlib( data_left=inputs.to_pandas(), data_right=outputs.to_pandas(), - colors=colors, + colors=colors if colors is not None else self._calculation_results.colors, title=title, subtitles=('Inputs', 'Outputs'), legend_title='Flows', @@ -1480,14 +1661,14 @@ def charge_state(self) -> xr.DataArray: def plot_charge_state( self, save: bool | pathlib.Path = False, - show: bool = True, - colors: plotting.ColorType = 'viridis', + show: bool | None = None, + colors: plotting.ColorType | None = None, engine: plotting.PlottingEngine = 'plotly', mode: Literal['area', 'stacked_bar', 'line'] = 'area', select: dict[FlowSystemDimensions, Any] | None = None, facet_by: str | list[str] | None = 'scenario', animate_by: str | None = 'period', - facet_cols: int = 3, + facet_cols: int | None = None, # Deprecated parameter (kept for backwards compatibility) indexer: dict[FlowSystemDimensions, Any] | None = None, **plot_kwargs: Any, @@ -1601,7 +1782,7 @@ def plot_charge_state( ds, facet_by=facet_by, animate_by=animate_by, - colors=colors, + colors=colors if colors is not None else self._calculation_results.colors, mode=mode, title=title, facet_cols=facet_cols, @@ -1617,7 +1798,7 @@ def plot_charge_state( charge_state_ds, facet_by=facet_by, animate_by=animate_by, - colors=colors, + colors=colors if colors is not None else self._calculation_results.colors, mode='line', # Always line for charge_state title='', # No title needed for this temp figure facet_cols=facet_cols, @@ -1657,7 +1838,7 @@ def plot_charge_state( # For matplotlib, plot flows (node balance), then add charge_state as line fig, ax = plotting.with_matplotlib( ds, - colors=colors, + colors=colors if colors is not None else self._calculation_results.colors, mode=mode, title=title, **plot_kwargs, @@ -1933,6 +2114,7 @@ def __init__( self.name = name self.folder = pathlib.Path(folder) if folder is not None else pathlib.Path.cwd() / 'results' self.hours_per_timestep = FlowSystem.calculate_hours_per_timestep(self.all_timesteps) + self._colors = {} @property def meta_data(self) -> dict[str, int | list[str]]: @@ -1947,6 +2129,64 @@ def meta_data(self) -> dict[str, int | list[str]]: def segment_names(self) -> list[str]: return [segment.name for segment in self.segment_results] + @property + def colors(self) -> dict[str, str]: + return self._colors + + @colors.setter + def colors(self, colors: dict[str, str]): + """Applies colors to all segments""" + self._colors = colors + for segment in self.segment_results: + segment.colors = copy.deepcopy(colors) + + def setup_colors( + self, + config: dict[str, str | list[str]] | str | pathlib.Path | None = None, + default_colorscale: str | None = None, + ) -> dict[str, str]: + """ + Setup colors for all variables across all segment results. + + This method applies the same color configuration to all segments, ensuring + consistent visualization across the entire segmented calculation. The color + mapping is propagated to each segment's CalculationResults instance. + + Args: + config: Configuration for color assignment. Can be: + - dict: Maps components to colors/colorscales: + * 'component1': 'red' # Single component to single color + * 'component1': '#FF0000' # Single component to hex color + - OR maps colorscales to multiple components: + * 'colorscale_name': ['component1', 'component2'] # Colorscale across components + - str: Path to a JSON/YAML config file or a colorscale name to apply to all + - Path: Path to a JSON/YAML config file + - None: Use default_colorscale for all components + default_colorscale: Default colorscale for unconfigured components (default: 'turbo') + + Examples: + ```python + # Apply colors to all segments + segmented_results.setup_colors( + { + 'CHP': 'red', + 'Blues': ['Storage1', 'Storage2'], + 'Oranges': ['Solar1', 'Solar2'], + } + ) + + # Use a single colorscale for all components in all segments + segmented_results.setup_colors('portland') + ``` + + Returns: + Complete variable-to-color mapping dictionary from the first segment + (all segments will have the same mapping) + """ + self.colors = self.segment_results[0].setup_colors(config=config, default_colorscale=default_colorscale) + + return self.colors + def solution_without_overlap(self, variable_name: str) -> xr.DataArray: """Get variable solution removing segment overlaps. @@ -1968,13 +2208,13 @@ def plot_heatmap( reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] | Literal['auto'] | None = 'auto', - colors: str = 'portland', + colors: plotting.ColorType | None = None, save: bool | pathlib.Path = False, - show: bool = True, + show: bool | None = None, engine: plotting.PlottingEngine = 'plotly', facet_by: str | list[str] | None = None, animate_by: str | None = None, - facet_cols: int = 3, + facet_cols: int | None = None, fill: Literal['ffill', 'bfill'] | None = 'ffill', # Deprecated parameters (kept for backwards compatibility) heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None, @@ -2039,7 +2279,7 @@ def plot_heatmap( if color_map is not None: # Check for conflict with new parameter - if colors != 'portland': # Check if user explicitly set colors + if colors is not None: # Check if user explicitly set colors raise ValueError( "Cannot use both deprecated parameter 'color_map' and new parameter 'colors'. Use only 'colors'." ) @@ -2099,14 +2339,14 @@ def plot_heatmap( data: xr.DataArray | xr.Dataset, name: str | None = None, folder: pathlib.Path | None = None, - colors: plotting.ColorType = 'viridis', + colors: plotting.ColorType | None = None, save: bool | pathlib.Path = False, - show: bool = True, + show: bool | None = None, engine: plotting.PlottingEngine = 'plotly', select: dict[str, Any] | None = None, facet_by: str | list[str] | None = None, animate_by: str | None = None, - facet_cols: int = 3, + facet_cols: int | None = None, reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] | Literal['auto'] | None = 'auto', @@ -2187,8 +2427,7 @@ def plot_heatmap( # Handle deprecated color_map parameter if color_map is not None: - # Check for conflict with new parameter - if colors != 'viridis': # User explicitly set colors + if colors is not None: # User explicitly set colors raise ValueError( "Cannot use both deprecated parameter 'color_map' and new parameter 'colors'. Use only 'colors'." ) diff --git a/tests/conftest.py b/tests/conftest.py index ac5255562..bd940b843 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -848,4 +848,6 @@ def set_test_environment(): pio.renderers.default = 'json' # Use non-interactive renderer + fx.CONFIG.Plotting.default_show = False + yield diff --git a/tests/test_results_plots.py b/tests/test_results_plots.py index 1fd6cf7f5..a656f7c44 100644 --- a/tests/test_results_plots.py +++ b/tests/test_results_plots.py @@ -28,7 +28,7 @@ def plotting_engine(request): @pytest.fixture( params=[ - 'viridis', # Test string colormap + 'turbo', # Test string colormap ['#ff0000', '#00ff00', '#0000ff', '#ffff00', '#ff00ff', '#00ffff'], # Test color list { 'Boiler(Q_th)|flow_rate': '#ff0000', @@ -51,7 +51,7 @@ def test_results_plots(flow_system, plotting_engine, show, save, color_spec): # Matplotlib doesn't support faceting/animation, so disable them for matplotlib engine heatmap_kwargs = { 'reshape_time': ('D', 'h'), - 'colors': 'viridis', # Note: heatmap only accepts string colormap + 'colors': 'turbo', # Note: heatmap only accepts string colormap 'save': save, 'show': show, 'engine': plotting_engine,