diff --git a/CHANGELOG.md b/CHANGELOG.md index dda77c01a..ca9600f04 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,16 +53,22 @@ If upgrading from v2.x, see the [v3.0.0 release notes](https://github.com/flixOp If upgrading from v2.x, see the [v3.0.0 release notes](https://github.com/flixOpt/flixOpt/releases/tag/v3.0.0) and [Migration Guide](https://flixopt.github.io/flixopt/latest/user-guide/migration-guide-v3/). ### ✨ Added +- Support for plotting kwargs in `results.py`, passed to plotly express and matplotlib. ### 💥 Breaking Changes ### ♻️ Changed +- **Template integration**: Plotly templates now fully control plot styling without hardcoded overrides +- **Dataset first plotting**: Underlying plotting methods in `plotting.py` now use `xr.Dataset` as the main datatype. DataFrames are automatically converted via `_ensure_dataset()`. Both DataFrames and Datasets can be passed to plotting functions without code changes. ### 🗑️ Deprecated ### 🔥 Removed +- Removed `plotting.pie_with_plotly()` method as it was not used ### 🐛 Fixed +- Improved error messages for `engine='matplotlib'` with multidimensional data +- Better dimension validation in `results.plot_heatmap()` ### 🔒 Security diff --git a/examples/02_Complex/complex_example_results.py b/examples/02_Complex/complex_example_results.py index 5020f71fe..96d06dd04 100644 --- a/examples/02_Complex/complex_example_results.py +++ b/examples/02_Complex/complex_example_results.py @@ -25,8 +25,9 @@ # --- Detailed Plots --- # In depth plot for individual flow rates ('__' is used as the delimiter between Component and Flow results.plot_heatmap('Wärmelast(Q_th_Last)|flow_rate') - for flow_rate in results['BHKW2'].inputs + results['BHKW2'].outputs: - results.plot_heatmap(flow_rate) + for bus in results.buses.values(): + bus.plot_node_balance_pie(show=False, save=f'results/{bus.label}--pie.html') + bus.plot_node_balance(show=False, save=f'results/{bus.label}--balance.html') # --- Plotting internal variables manually --- results.plot_heatmap('BHKW2(Q_th)|on') diff --git a/examples/03_Calculation_types/example_calculation_types.py b/examples/03_Calculation_types/example_calculation_types.py index 8df8e742f..c5df50034 100644 --- a/examples/03_Calculation_types/example_calculation_types.py +++ b/examples/03_Calculation_types/example_calculation_types.py @@ -202,35 +202,38 @@ def get_solutions(calcs: list, variable: str) -> xr.Dataset: # --- Plotting for comparison --- fx.plotting.with_plotly( - get_solutions(calculations, 'Speicher|charge_state').to_dataframe(), + get_solutions(calculations, 'Speicher|charge_state'), mode='line', title='Charge State Comparison', ylabel='Charge state', + xlabel='Time in h', ).write_html('results/Charge State.html') fx.plotting.with_plotly( - get_solutions(calculations, 'BHKW2(Q_th)|flow_rate').to_dataframe(), + get_solutions(calculations, 'BHKW2(Q_th)|flow_rate'), mode='line', title='BHKW2(Q_th) Flow Rate Comparison', ylabel='Flow rate', + xlabel='Time in h', ).write_html('results/BHKW2 Thermal Power.html') fx.plotting.with_plotly( - get_solutions(calculations, 'costs(temporal)|per_timestep').to_dataframe(), + get_solutions(calculations, 'costs(temporal)|per_timestep'), mode='line', title='Operation Cost Comparison', ylabel='Costs [€]', + xlabel='Time in h', ).write_html('results/Operation Costs.html') fx.plotting.with_plotly( - pd.DataFrame(get_solutions(calculations, 'costs(temporal)|per_timestep').to_dataframe().sum()).T, + get_solutions(calculations, 'costs(temporal)|per_timestep').sum('time'), mode='stacked_bar', title='Total Cost Comparison', ylabel='Costs [€]', ).update_layout(barmode='group').write_html('results/Total Costs.html') fx.plotting.with_plotly( - pd.DataFrame([calc.durations for calc in calculations], index=[calc.name for calc in calculations]), + pd.DataFrame([calc.durations for calc in calculations], index=[calc.name for calc in calculations]).to_xarray(), mode='stacked_bar', ).update_layout(title='Duration Comparison', xaxis_title='Calculation type', yaxis_title='Time (s)').write_html( 'results/Speed Comparison.html' diff --git a/flixopt/aggregation.py b/flixopt/aggregation.py index 91ef618a9..53770e140 100644 --- a/flixopt/aggregation.py +++ b/flixopt/aggregation.py @@ -150,13 +150,17 @@ def plot(self, colormap: str = 'viridis', show: bool = True, save: pathlib.Path df_agg = self.aggregated_data.copy().rename( columns={col: f'Aggregated - {col}' for col in self.aggregated_data.columns} ) - fig = plotting.with_plotly(df_org, 'line', colors=colormap) + fig = plotting.with_plotly(df_org.to_xarray(), 'line', colors=colormap, xlabel='Time in h') for trace in fig.data: trace.update(dict(line=dict(dash='dash'))) - fig = plotting.with_plotly(df_agg, 'line', colors=colormap, fig=fig) + fig2 = plotting.with_plotly(df_agg.to_xarray(), 'line', colors=colormap, xlabel='Time in h') + for trace in fig2.data: + fig.add_trace(trace) fig.update_layout( - title='Original vs Aggregated Data (original = ---)', xaxis_title='Index', yaxis_title='Value' + title='Original vs Aggregated Data (original = ---)', + xaxis_title='Time in h', + yaxis_title='Value', ) plotting.export_figure( diff --git a/flixopt/plotting.py b/flixopt/plotting.py index bd1f3c2c4..a024c97fc 100644 --- a/flixopt/plotting.py +++ b/flixopt/plotting.py @@ -42,6 +42,8 @@ import xarray as xr from plotly.exceptions import PlotlyError +from .config import CONFIG + if TYPE_CHECKING: import pyvis @@ -326,36 +328,99 @@ def process_colors( return color_list +def _ensure_dataset(data: xr.Dataset | pd.DataFrame | pd.Series) -> xr.Dataset: + """Convert DataFrame or Series to Dataset if needed.""" + if isinstance(data, xr.Dataset): + return data + elif isinstance(data, pd.DataFrame): + # Convert DataFrame to Dataset + return data.to_xarray() + elif isinstance(data, pd.Series): + # Convert Series to DataFrame first, then to Dataset + return data.to_frame().to_xarray() + else: + raise TypeError(f'Data must be xr.Dataset, pd.DataFrame, or pd.Series, got {type(data).__name__}') + + +def _validate_plotting_data(data: xr.Dataset, allow_empty: bool = False) -> None: + """Validate dataset for plotting (checks for empty data, non-numeric types, etc.).""" + # Check for empty data + if not allow_empty and len(data.data_vars) == 0: + raise ValueError('Empty Dataset provided (no variables). Cannot create plot.') + + # Check if dataset has any data (xarray uses nbytes for total size) + if all(data[var].size == 0 for var in data.data_vars) if len(data.data_vars) > 0 else True: + if not allow_empty and len(data.data_vars) > 0: + raise ValueError('Dataset has zero size. Cannot create plot.') + if len(data.data_vars) == 0: + return # Empty dataset, nothing to validate + return + + # Check for non-numeric data types + for var in data.data_vars: + dtype = data[var].dtype + if not np.issubdtype(dtype, np.number): + raise TypeError( + f"Variable '{var}' has non-numeric dtype '{dtype}'. " + f'Plotting requires numeric data types (int, float, etc.).' + ) + + # Warn about NaN/Inf values + for var in data.data_vars: + if np.isnan(data[var].values).any(): + logger.debug(f"Variable '{var}' contains NaN values which may affect visualization.") + if np.isinf(data[var].values).any(): + logger.debug(f"Variable '{var}' contains Inf values which may affect visualization.") + + +def resolve_colors( + data: xr.Dataset, + colors: ColorType, + engine: PlottingEngine = 'plotly', +) -> dict[str, str]: + """Resolve colors parameter to a dict mapping variable names to colors.""" + # Get variable names from Dataset (always strings and unique) + labels = list(data.data_vars.keys()) + + # If explicit dict provided, use it directly + if isinstance(colors, dict): + return colors + + # If string or list, use ColorProcessor (traditional behavior) + if isinstance(colors, (str, list)): + processor = ColorProcessor(engine=engine) + return processor.process_colors(colors, labels, return_mapping=True) + + raise TypeError(f'Wrong type passed to resolve_colors(): {type(colors)}') + + def with_plotly( - data: pd.DataFrame | xr.DataArray | xr.Dataset, + data: xr.Dataset | pd.DataFrame | pd.Series, mode: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar', colors: ColorType = 'viridis', title: str = '', ylabel: str = '', - xlabel: str = 'Time in h', - fig: go.Figure | None = None, + xlabel: str = '', facet_by: str | list[str] | None = None, animate_by: str | None = None, - facet_cols: int = 3, + facet_cols: int | None = None, shared_yaxes: bool = True, shared_xaxes: bool = True, + **px_kwargs: Any, ) -> go.Figure: """ Plot data with Plotly using facets (subplots) and/or animation for multidimensional data. Uses Plotly Express for convenient faceting and animation with automatic styling. - For simple plots without faceting, can optionally add to an existing figure. Args: - data: A DataFrame or xarray DataArray/Dataset to plot. + data: An xarray Dataset, pandas DataFrame, or pandas Series to plot. mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for lines, 'area' for stacked area charts, or 'grouped_bar' for grouped bar charts. colors: Color specification (colormap, list, or dict mapping labels to colors). title: The main title of the plot. ylabel: The label for the y-axis. xlabel: The label for the x-axis. - fig: A Plotly figure object to plot on (only for simple plots without faceting). - If not provided, a new figure will be created. facet_by: Dimension(s) to create facets for. Creates a subplot grid. Can be a single dimension name or list of dimensions (max 2 for facet_row and facet_col). If the dimension doesn't exist in the data, it will be silently ignored. @@ -364,93 +429,113 @@ def with_plotly( facet_cols: Number of columns in the facet grid (used when facet_by is single dimension). shared_yaxes: Whether subplots share y-axes. shared_xaxes: Whether subplots share x-axes. + **px_kwargs: Additional keyword arguments passed to the underlying Plotly Express function + (px.bar, px.line, px.area). These override default arguments if provided. + Examples: range_x=[0, 100], range_y=[0, 50], category_orders={...}, line_shape='linear' Returns: - A Plotly figure object containing the faceted/animated plot. + A Plotly figure object containing the faceted/animated plot. You can further customize + the returned figure using Plotly's methods (e.g., fig.update_traces(), fig.update_layout()). Examples: Simple plot: ```python - fig = with_plotly(df, mode='area', title='Energy Mix') + fig = with_plotly(dataset, mode='area', title='Energy Mix') ``` Facet by scenario: ```python - fig = with_plotly(ds, facet_by='scenario', facet_cols=2) + fig = with_plotly(dataset, facet_by='scenario', facet_cols=2) ``` Animate by period: ```python - fig = with_plotly(ds, animate_by='period') + fig = with_plotly(dataset, animate_by='period') ``` Facet and animate: ```python - fig = with_plotly(ds, facet_by='scenario', animate_by='period') + fig = with_plotly(dataset, facet_by='scenario', animate_by='period') + ``` + + Customize with Plotly Express kwargs: + + ```python + fig = with_plotly(dataset, range_y=[0, 100], line_shape='linear') + ``` + + Further customize the returned figure: + + ```python + fig = with_plotly(dataset, mode='line') + fig.update_traces(line={'width': 5, 'dash': 'dot'}) + fig.update_layout(template='plotly_dark', width=1200, height=600) ``` """ if mode not in ('stacked_bar', 'line', 'area', 'grouped_bar'): raise ValueError(f"'mode' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {mode!r}") + # Ensure data is a Dataset and validate it + data = _ensure_dataset(data) + _validate_plotting_data(data, allow_empty=True) + # Handle empty data - if isinstance(data, pd.DataFrame) and data.empty: - return go.Figure() - elif isinstance(data, xr.DataArray) and data.size == 0: - return go.Figure() - elif isinstance(data, xr.Dataset) and len(data.data_vars) == 0: + if len(data.data_vars) == 0: + logger.error('with_plotly() got an empty Dataset.') return go.Figure() - # Warn if fig parameter is used with faceting - if fig is not None and (facet_by is not None or animate_by is not None): - logger.warning('The fig parameter is ignored when using faceting or animation. Creating a new figure.') - fig = None - - # Convert xarray to long-form DataFrame for Plotly Express - if isinstance(data, (xr.DataArray, xr.Dataset)): - # Convert to long-form (tidy) DataFrame - # Structure: time, variable, value, scenario, period, ... (all dims as columns) - if isinstance(data, xr.Dataset): - # Stack all data variables into long format - df_long = data.to_dataframe().reset_index() - # Melt to get: time, scenario, period, ..., variable, value - id_vars = [dim for dim in data.dims] - value_vars = list(data.data_vars) - df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value') - else: - # DataArray - df_long = data.to_dataframe().reset_index() - if data.name: - df_long = df_long.rename(columns={data.name: 'value'}) - else: - # Unnamed DataArray, find the value column - non_dim_cols = [col for col in df_long.columns if col not in data.dims] - if len(non_dim_cols) != 1: - raise ValueError( - f'Expected exactly one non-dimension column for unnamed DataArray, ' - f'but found {len(non_dim_cols)}: {non_dim_cols}' + # Handle all-scalar datasets (where all variables have no dimensions) + # This occurs when all variables are scalar values with dims=() + if all(len(data[var].dims) == 0 for var in data.data_vars): + # Create a simple DataFrame with variable names as x-axis + variables = list(data.data_vars.keys()) + values = [float(data[var].values) for var in data.data_vars] + + # Resolve colors + color_discrete_map = resolve_colors(data, colors, engine='plotly') + marker_colors = [color_discrete_map.get(var, '#636EFA') for var in variables] + + # Create simple plot based on mode using go (not px) for better color control + if mode in ('stacked_bar', 'grouped_bar'): + fig = go.Figure(data=[go.Bar(x=variables, y=values, marker_color=marker_colors)]) + elif mode == 'line': + fig = go.Figure( + data=[ + go.Scatter( + x=variables, + y=values, + mode='lines+markers', + marker=dict(color=marker_colors, size=8), + line=dict(color='lightgray'), ) - value_col = non_dim_cols[0] - df_long = df_long.rename(columns={value_col: 'value'}) - df_long['variable'] = data.name or 'data' - else: - # Already a DataFrame - convert to long format for Plotly Express - df_long = data.reset_index() - if 'time' not in df_long.columns: - # First column is probably time - df_long = df_long.rename(columns={df_long.columns[0]: 'time'}) - # Melt to long format - id_vars = [ - col - for col in df_long.columns - if col in ['time', 'scenario', 'period'] - or col in (facet_by if isinstance(facet_by, list) else [facet_by] if facet_by else []) - ] - value_vars = [col for col in df_long.columns if col not in id_vars] - df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value') + ] + ) + elif mode == 'area': + fig = go.Figure( + data=[ + go.Scatter( + x=variables, + y=values, + fill='tozeroy', + marker=dict(color=marker_colors, size=8), + line=dict(color='lightgray'), + ) + ] + ) + else: + raise ValueError('"mode" must be one of "stacked_bar", "grouped_bar", "line", "area"') + + fig.update_layout(title=title, xaxis_title=xlabel, yaxis_title=ylabel, showlegend=False) + return fig + + # Convert Dataset to long-form DataFrame for Plotly Express + # Structure: time, variable, value, scenario, period, ... (all dims as columns) + dim_names = list(data.dims) + df_long = data.to_dataframe().reset_index().melt(id_vars=dim_names, var_name='variable', value_name='value') # Validate facet_by and animate_by dimensions exist in the data available_dims = [col for col in df_long.columns if col not in ['variable', 'value']] @@ -505,10 +590,32 @@ def with_plotly( processed_colors = ColorProcessor(engine='plotly').process_colors(colors, all_vars) color_discrete_map = {var: color for var, color in zip(all_vars, processed_colors, strict=True)} + # Determine which dimension to use for x-axis + # Collect dimensions used for faceting and animation + used_dims = set() + if facet_row: + used_dims.add(facet_row) + if facet_col: + used_dims.add(facet_col) + if animate_by: + used_dims.add(animate_by) + + # Find available dimensions for x-axis (not used for faceting/animation) + x_candidates = [d for d in available_dims if d not in used_dims] + + # Use 'time' if available, otherwise use the first available dimension + if 'time' in x_candidates: + x_dim = 'time' + elif len(x_candidates) > 0: + x_dim = x_candidates[0] + else: + # Fallback: use the first dimension (shouldn't happen in normal cases) + x_dim = available_dims[0] if available_dims else 'time' + # Create plot using Plotly Express based on mode common_args = { 'data_frame': df_long, - 'x': 'time', + 'x': x_dim, 'y': 'value', 'color': 'variable', 'facet_row': facet_row, @@ -516,13 +623,22 @@ def with_plotly( 'animation_frame': animate_by, 'color_discrete_map': color_discrete_map, 'title': title, - 'labels': {'value': ylabel, 'time': xlabel, 'variable': ''}, + 'labels': {'value': ylabel, x_dim: xlabel, 'variable': ''}, } # Add facet_col_wrap for single facet dimension if facet_col and not facet_row: common_args['facet_col_wrap'] = facet_cols + # Add mode-specific defaults (before px_kwargs so they can be overridden) + if mode in ('line', 'area'): + common_args['line_shape'] = 'hv' # Stepped lines by default + + # Allow callers to pass any px.* keyword args (e.g., category_orders, range_x/y, line_shape) + # These will override the defaults set above + if px_kwargs: + common_args.update(px_kwargs) + if mode == 'stacked_bar': fig = px.bar(**common_args) fig.update_traces(marker_line_width=0) @@ -531,10 +647,10 @@ def with_plotly( fig = px.bar(**common_args) fig.update_layout(barmode='group', bargap=0.2, bargroupgap=0) elif mode == 'line': - fig = px.line(**common_args, line_shape='hv') # Stepped lines + fig = px.line(**common_args) elif mode == 'area': # Use Plotly Express to create the area plot (preserves animation, legends, faceting) - fig = px.area(**common_args, line_shape='hv') + fig = px.area(**common_args) # Classify each variable based on its values variable_classification = {} @@ -577,13 +693,6 @@ def with_plotly( if hasattr(trace, 'fill'): trace.fill = None - # Update layout with basic styling (Plotly Express handles sizing automatically) - fig.update_layout( - plot_bgcolor='rgba(0,0,0,0)', - paper_bgcolor='rgba(0,0,0,0)', - font=dict(size=12), - ) - # Update axes to share if requested (Plotly Express already handles this, but we can customize) if not shared_yaxes: fig.update_yaxes(matches=None) @@ -594,33 +703,32 @@ def with_plotly( def with_matplotlib( - data: pd.DataFrame, + data: xr.Dataset | pd.DataFrame | pd.Series, mode: Literal['stacked_bar', 'line'] = 'stacked_bar', colors: ColorType = 'viridis', title: str = '', ylabel: str = '', xlabel: str = 'Time in h', figsize: tuple[int, int] = (12, 6), - fig: plt.Figure | None = None, - ax: plt.Axes | None = None, + plot_kwargs: dict[str, Any] | None = None, ) -> tuple[plt.Figure, plt.Axes]: """ - Plot a DataFrame with Matplotlib using stacked bars or stepped lines. + Plot data with Matplotlib using stacked bars or stepped lines. Args: - data: A DataFrame containing the data to plot. The index should represent time (e.g., hours), - and each column represents a separate data series. + data: An xarray Dataset, pandas DataFrame, or pandas Series to plot. After conversion to DataFrame, + the index represents time and each column represents a separate data series (variables). mode: Plotting mode. Use 'stacked_bar' for stacked bar charts or 'line' for stepped lines. - colors: Color specification, can be: - - A string with a colormap name (e.g., 'viridis', 'plasma') + colors: Color specification. Can be: + - A colormap name (e.g., 'turbo', 'plasma') - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'}) + - A dict mapping column names to colors (e.g., {'Column1': '#ff0000'}) title: The title of the plot. ylabel: The ylabel of the plot. xlabel: The xlabel of the plot. - figsize: Specify the size of the figure - fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created. - ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created. + figsize: Specify the size of the figure (width, height) in inches. + plot_kwargs: Optional dict of parameters to pass to ax.bar() or ax.step() plotting calls. + Use this to customize plot properties (e.g., linewidth, alpha, edgecolor). Returns: A tuple containing the Matplotlib figure and axes objects used for the plot. @@ -633,45 +741,111 @@ def with_matplotlib( if mode not in ('stacked_bar', 'line'): raise ValueError(f"'mode' must be one of {{'stacked_bar','line'}} for matplotlib, got {mode!r}") - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + # Ensure data is a Dataset and validate it + data = _ensure_dataset(data) + _validate_plotting_data(data, allow_empty=True) + + # Create new figure and axes + fig, ax = plt.subplots(figsize=figsize) + + # Initialize plot_kwargs if not provided + if plot_kwargs is None: + plot_kwargs = {} + + # Handle all-scalar datasets (where all variables have no dimensions) + # This occurs when all variables are scalar values with dims=() + if all(len(data[var].dims) == 0 for var in data.data_vars): + # Create simple bar/line plot with variable names as x-axis + variables = list(data.data_vars.keys()) + values = [float(data[var].values) for var in data.data_vars] + + # Resolve colors + color_discrete_map = resolve_colors(data, colors, engine='matplotlib') + colors_list = [color_discrete_map.get(var, '#808080') for var in variables] + + # Create plot based on mode + if mode == 'stacked_bar': + ax.bar(variables, values, color=colors_list, **plot_kwargs) + elif mode == 'line': + ax.plot( + variables, + values, + marker='o', + color=colors_list[0] if len(set(colors_list)) == 1 else None, + **plot_kwargs, + ) + # If different colors, plot each point separately + if len(set(colors_list)) > 1: + ax.clear() + for i, (var, val) in enumerate(zip(variables, values, strict=False)): + ax.plot([i], [val], marker='o', color=colors_list[i], label=var, **plot_kwargs) + ax.set_xticks(range(len(variables))) + ax.set_xticklabels(variables) + + ax.set_xlabel(xlabel, ha='center') + ax.set_ylabel(ylabel, va='center') + ax.set_title(title) + ax.grid(color='lightgrey', linestyle='-', linewidth=0.5, axis='y') + fig.tight_layout() + + return fig, ax + + # Resolve colors first (includes validation) + color_discrete_map = resolve_colors(data, colors, engine='matplotlib') - processed_colors = ColorProcessor(engine='matplotlib').process_colors(colors, list(data.columns)) + # Convert Dataset to DataFrame for matplotlib plotting (naturally wide-form) + df = data.to_dataframe() + + # Get colors in column order + processed_colors = [color_discrete_map.get(str(col), '#808080') for col in df.columns] if mode == 'stacked_bar': - cumulative_positive = np.zeros(len(data)) - cumulative_negative = np.zeros(len(data)) - width = data.index.to_series().diff().dropna().min() # Minimum time difference + cumulative_positive = np.zeros(len(df)) + cumulative_negative = np.zeros(len(df)) + + # Robust bar width: handle datetime-like, numeric, and single-point indexes + if len(df.index) > 1: + delta = pd.Index(df.index).to_series().diff().dropna().min() + if hasattr(delta, 'total_seconds'): # datetime-like + width = delta.total_seconds() / 86400.0 # Matplotlib date units = days + else: + width = float(delta) + else: + width = 0.8 # reasonable default for a single bar - for i, column in enumerate(data.columns): - positive_values = np.clip(data[column], 0, None) # Keep only positive values - negative_values = np.clip(data[column], None, 0) # Keep only negative values + for i, column in enumerate(df.columns): + # Fill NaNs to avoid breaking stacking math + series = df[column].fillna(0) + positive_values = np.clip(series, 0, None) # Keep only positive values + negative_values = np.clip(series, None, 0) # Keep only negative values # Plot positive bars ax.bar( - data.index, + df.index, positive_values, bottom=cumulative_positive, color=processed_colors[i], label=column, width=width, align='center', + **plot_kwargs, ) cumulative_positive += positive_values.values # Plot negative bars ax.bar( - data.index, + df.index, negative_values, bottom=cumulative_negative, color=processed_colors[i], label='', # No label for negative bars width=width, align='center', + **plot_kwargs, ) cumulative_negative += negative_values.values elif mode == 'line': - for i, column in enumerate(data.columns): - ax.step(data.index, data[column], where='post', color=processed_colors[i], label=column) + for i, column in enumerate(df.columns): + ax.step(df.index, df[column], where='post', color=processed_colors[i], label=column, **plot_kwargs) # Aesthetics ax.set_xlabel(xlabel, ha='center') @@ -944,228 +1118,104 @@ def plot_network( ) -def pie_with_plotly( - data: pd.DataFrame, - colors: ColorType = 'viridis', - title: str = '', - legend_title: str = '', - hole: float = 0.0, - fig: go.Figure | None = None, -) -> go.Figure: +def preprocess_data_for_pie( + data: xr.Dataset | pd.DataFrame | pd.Series, + lower_percentage_threshold: float = 5.0, +) -> pd.Series: """ - Create a pie chart with Plotly to visualize the proportion of values in a DataFrame. + Preprocess data for pie chart display. - Args: - data: A DataFrame containing the data to plot. If multiple rows exist, - they will be summed unless a specific index value is passed. - colors: Color specification, can be: - - A string with a colorscale name (e.g., 'viridis', 'plasma') - - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'}) - title: The title of the plot. - legend_title: The title for the legend. - hole: Size of the hole in the center for creating a donut chart (0.0 to 1.0). - fig: A Plotly figure object to plot on. If not provided, a new figure will be created. - - Returns: - A Plotly figure object containing the generated pie chart. - - Notes: - - Negative values are not appropriate for pie charts and will be converted to absolute values with a warning. - - If the data contains very small values (less than 1% of the total), they can be grouped into an "Other" category - for better readability. - - By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing. - - """ - if data.empty: - logger.error('Empty DataFrame provided for pie chart. Returning empty figure.') - return go.Figure() - - # Create a copy to avoid modifying the original DataFrame - data_copy = data.copy() - - # Check if any negative values and warn - if (data_copy < 0).any().any(): - logger.error('Negative values detected in data. Using absolute values for pie chart.') - data_copy = data_copy.abs() - - # If data has multiple rows, sum them to get total for each column - if len(data_copy) > 1: - data_sum = data_copy.sum() - else: - data_sum = data_copy.iloc[0] - - # Get labels (column names) and values - labels = data_sum.index.tolist() - values = data_sum.values.tolist() - - # Apply color mapping using the unified color processor - processed_colors = ColorProcessor(engine='plotly').process_colors(colors, labels) - - # Create figure if not provided - fig = fig if fig is not None else go.Figure() - - # Add pie trace - fig.add_trace( - go.Pie( - labels=labels, - values=values, - hole=hole, - marker=dict(colors=processed_colors), - textinfo='percent+label+value', - textposition='inside', - insidetextorientation='radial', - ) - ) - - # Update layout for better aesthetics - fig.update_layout( - title=title, - legend_title=legend_title, - plot_bgcolor='rgba(0,0,0,0)', # Transparent background - paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background - font=dict(size=14), # Increase font size for better readability - ) - - return fig - - -def pie_with_matplotlib( - data: pd.DataFrame, - colors: ColorType = 'viridis', - title: str = '', - legend_title: str = 'Categories', - hole: float = 0.0, - figsize: tuple[int, int] = (10, 8), - fig: plt.Figure | None = None, - ax: plt.Axes | None = None, -) -> tuple[plt.Figure, plt.Axes]: - """ - Create a pie chart with Matplotlib to visualize the proportion of values in a DataFrame. + Groups items that are individually below the threshold percentage into an "Other" category. + Converts various input types to a pandas Series for uniform handling. Args: - data: A DataFrame containing the data to plot. If multiple rows exist, - they will be summed unless a specific index value is passed. - colors: Color specification, can be: - - A string with a colormap name (e.g., 'viridis', 'plasma') - - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'}) - title: The title of the plot. - legend_title: The title for the legend. - hole: Size of the hole in the center for creating a donut chart (0.0 to 1.0). - figsize: The size of the figure (width, height) in inches. - fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created. - ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created. + data: Input data (xarray Dataset, DataFrame, or Series) + lower_percentage_threshold: Percentage threshold - items below this are grouped into "Other" Returns: - A tuple containing the Matplotlib figure and axes objects used for the plot. - - Notes: - - Negative values are not appropriate for pie charts and will be converted to absolute values with a warning. - - If the data contains very small values (less than 1% of the total), they can be grouped into an "Other" category - for better readability. - - By default, the sum of all columns is used for the pie chart. For time series data, consider preprocessing. - + Processed pandas Series with small items grouped into "Other" """ - if data.empty: - logger.error('Empty DataFrame provided for pie chart. Returning empty figure.') - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) - return fig, ax + # Convert to Series + if isinstance(data, xr.Dataset): + # Sum all dimensions for each variable to get total values + values = {} + for var in data.data_vars: + var_data = data[var] + if len(var_data.dims) > 0: + total_value = float(var_data.sum().item()) + else: + total_value = float(var_data.item()) - # Create a copy to avoid modifying the original DataFrame - data_copy = data.copy() + # Handle negative values + if total_value < 0: + logger.warning(f'Negative value for {var}: {total_value}. Using absolute value.') + total_value = abs(total_value) - # Check if any negative values and warn - if (data_copy < 0).any().any(): - logger.error('Negative values detected in data. Using absolute values for pie chart.') - data_copy = data_copy.abs() + values[var] = total_value - # If data has multiple rows, sum them to get total for each column - if len(data_copy) > 1: - data_sum = data_copy.sum() - else: - data_sum = data_copy.iloc[0] + series = pd.Series(values) - # Get labels (column names) and values - labels = data_sum.index.tolist() - values = data_sum.values.tolist() + elif isinstance(data, pd.DataFrame): + # Sum across all columns if DataFrame + series = data.sum(axis=0) + # Handle negative values + negative_mask = series < 0 + if negative_mask.any(): + logger.warning(f'Negative values found: {series[negative_mask].to_dict()}. Using absolute values.') + series = series.abs() - # Apply color mapping using the unified color processor - processed_colors = ColorProcessor(engine='matplotlib').process_colors(colors, labels) + else: # pd.Series + series = data.copy() + # Handle negative values + negative_mask = series < 0 + if negative_mask.any(): + logger.warning(f'Negative values found: {series[negative_mask].to_dict()}. Using absolute values.') + series = series.abs() - # Create figure and axis if not provided - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + # Only keep positive values + series = series[series > 0] - # Draw the pie chart - wedges, texts, autotexts = ax.pie( - values, - labels=labels, - colors=processed_colors, - autopct='%1.1f%%', - startangle=90, - shadow=False, - wedgeprops=dict(width=0.5) if hole > 0 else None, # Set width for donut - ) + if series.empty or lower_percentage_threshold <= 0: + return series - # Adjust the wedgeprops to make donut hole size consistent with plotly - # For matplotlib, the hole size is determined by the wedge width - # Convert hole parameter to wedge width - if hole > 0: - # Adjust hole size to match plotly's hole parameter - # In matplotlib, wedge width is relative to the radius (which is 1) - # For plotly, hole is a fraction of the radius - wedge_width = 1 - hole - for wedge in wedges: - wedge.set_width(wedge_width) - - # Customize the appearance - # Make autopct text more visible - for autotext in autotexts: - autotext.set_fontsize(10) - autotext.set_color('white') - - # Set aspect ratio to be equal to ensure a circular pie - ax.set_aspect('equal') - - # Add title - if title: - ax.set_title(title, fontsize=16) + # Calculate percentages + total = series.sum() + percentages = (series / total) * 100 - # Create a legend if there are many segments - if len(labels) > 6: - ax.legend(wedges, labels, title=legend_title, loc='center left', bbox_to_anchor=(1, 0, 0.5, 1)) + # Find items below and above threshold + below_threshold = series[percentages < lower_percentage_threshold] + above_threshold = series[percentages >= lower_percentage_threshold] - # Apply tight layout - fig.tight_layout() + # Only group if there are at least 2 items below threshold + if len(below_threshold) > 1: + # Create new series with items above threshold + "Other" + result = above_threshold.copy() + result['Other'] = below_threshold.sum() + return result - return fig, ax + return series def dual_pie_with_plotly( - data_left: pd.Series, - data_right: pd.Series, + data_left: xr.Dataset | pd.DataFrame | pd.Series, + data_right: xr.Dataset | pd.DataFrame | pd.Series, colors: ColorType = 'viridis', title: str = '', subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'), legend_title: str = '', hole: float = 0.2, lower_percentage_group: float = 5.0, - hover_template: str = '%{label}: %{value} (%{percent})', text_info: str = 'percent+label', text_position: str = 'inside', + hover_template: str = '%{label}: %{value} (%{percent})', ) -> go.Figure: """ - Create two pie charts side by side with Plotly, with consistent coloring across both charts. + Create two pie charts side by side with Plotly. Args: - data_left: Series for the left pie chart. - data_right: Series for the right pie chart. - colors: Color specification, can be: - - A string with a colorscale name (e.g., 'viridis', 'plasma') - - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping category names to colors (e.g., {'Category1': '#ff0000'}) + data_left: Data for the left pie chart. Variables are summed across all dimensions. + data_right: Data for the right pie chart. Variables are summed across all dimensions. + colors: Color specification (colorscale name, list of colors, or dict mapping) title: The main title of the plot. subtitles: Tuple containing the subtitles for (left, right) charts. legend_title: The title for the legend. @@ -1177,119 +1227,64 @@ def dual_pie_with_plotly( text_position: Position of text: 'inside', 'outside', 'auto', or 'none'. Returns: - A Plotly figure object containing the generated dual pie chart. + Plotly Figure object """ - from plotly.subplots import make_subplots + # Preprocess data to Series + left_series = preprocess_data_for_pie(data_left, lower_percentage_group) + right_series = preprocess_data_for_pie(data_right, lower_percentage_group) - # Check for empty data - if data_left.empty and data_right.empty: - logger.error('Both datasets are empty. Returning empty figure.') - return go.Figure() + # Extract labels and values + left_labels = left_series.index.tolist() + left_values = left_series.values.tolist() - # Create a subplot figure - fig = make_subplots( - rows=1, cols=2, specs=[[{'type': 'pie'}, {'type': 'pie'}]], subplot_titles=subtitles, horizontal_spacing=0.05 - ) + right_labels = right_series.index.tolist() + right_values = right_series.values.tolist() - # Process series to handle negative values and apply minimum percentage threshold - def preprocess_series(series: pd.Series): - """ - Preprocess a series for pie chart display by handling negative values - and grouping the smallest parts together if they collectively represent - less than the specified percentage threshold. - - Args: - series: The series to preprocess - - Returns: - A preprocessed pandas Series - """ - # Handle negative values - if (series < 0).any(): - logger.error('Negative values detected in data. Using absolute values for pie chart.') - series = series.abs() - - # Remove zeros - series = series[series > 0] - - # Apply minimum percentage threshold if needed - if lower_percentage_group and not series.empty: - total = series.sum() - if total > 0: - # Sort series by value (ascending) - sorted_series = series.sort_values() - - # Calculate cumulative percentage contribution - cumulative_percent = (sorted_series.cumsum() / total) * 100 - - # Find entries that collectively make up less than lower_percentage_group - to_group = cumulative_percent <= lower_percentage_group + # Get all unique labels for consistent coloring + all_labels = sorted(set(left_labels) | set(right_labels)) - if to_group.sum() > 1: - # Create "Other" category for the smallest values that together are < threshold - other_sum = sorted_series[to_group].sum() - - # Keep only values that aren't in the "Other" group - result_series = series[~series.index.isin(sorted_series[to_group].index)] - - # Add the "Other" category if it has a value - if other_sum > 0: - result_series['Other'] = other_sum - - return result_series - - return series - - data_left_processed = preprocess_series(data_left) - data_right_processed = preprocess_series(data_right) - - # Get unique set of all labels for consistent coloring - all_labels = sorted(set(data_left_processed.index) | set(data_right_processed.index)) - - # Get consistent color mapping for both charts using our unified function + # Create color map color_map = ColorProcessor(engine='plotly').process_colors(colors, all_labels, return_mapping=True) - # Function to create a pie trace with consistently mapped colors - def create_pie_trace(data_series, side): - if data_series.empty: - return None - - labels = data_series.index.tolist() - values = data_series.values.tolist() - trace_colors = [color_map[label] for label in labels] - - return go.Pie( - labels=labels, - values=values, - name=side, - marker=dict(colors=trace_colors), - hole=hole, - textinfo=text_info, - textposition=text_position, - insidetextorientation='radial', - hovertemplate=hover_template, - sort=True, # Sort values by default (largest first) + # Create figure + fig = go.Figure() + + # Add left pie + if left_labels: + fig.add_trace( + go.Pie( + labels=left_labels, + values=left_values, + name=subtitles[0], + marker=dict(colors=[color_map.get(label, '#636EFA') for label in left_labels]), + hole=hole, + textinfo=text_info, + textposition=text_position, + hovertemplate=hover_template, + domain=dict(x=[0, 0.48]), + ) ) - # Add left pie if data exists - left_trace = create_pie_trace(data_left_processed, subtitles[0]) - if left_trace: - left_trace.domain = dict(x=[0, 0.48]) - fig.add_trace(left_trace, row=1, col=1) - - # Add right pie if data exists - right_trace = create_pie_trace(data_right_processed, subtitles[1]) - if right_trace: - right_trace.domain = dict(x=[0.52, 1]) - fig.add_trace(right_trace, row=1, col=2) + # Add right pie + if right_labels: + fig.add_trace( + go.Pie( + labels=right_labels, + values=right_values, + name=subtitles[1], + marker=dict(colors=[color_map.get(label, '#636EFA') for label in right_labels]), + hole=hole, + textinfo=text_info, + textposition=text_position, + hovertemplate=hover_template, + domain=dict(x=[0.52, 1]), + ) + ) # Update layout fig.update_layout( title=title, legend_title=legend_title, - plot_bgcolor='rgba(0,0,0,0)', # Transparent background - paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background - font=dict(size=14), margin=dict(t=80, b=50, l=30, r=30), ) @@ -1297,8 +1292,8 @@ def create_pie_trace(data_series, side): def dual_pie_with_matplotlib( - data_left: pd.Series, - data_right: pd.Series, + data_left: xr.Dataset | pd.DataFrame | pd.Series, + data_right: xr.Dataset | pd.DataFrame | pd.Series, colors: ColorType = 'viridis', title: str = '', subtitles: tuple[str, str] = ('Left Chart', 'Right Chart'), @@ -1306,154 +1301,99 @@ def dual_pie_with_matplotlib( hole: float = 0.2, lower_percentage_group: float = 5.0, figsize: tuple[int, int] = (14, 7), - fig: plt.Figure | None = None, - axes: list[plt.Axes] | None = None, ) -> tuple[plt.Figure, list[plt.Axes]]: """ - Create two pie charts side by side with Matplotlib, with consistent coloring across both charts. - Leverages the existing pie_with_matplotlib function. + Create two pie charts side by side with Matplotlib. Args: - data_left: Series for the left pie chart. - data_right: Series for the right pie chart. - colors: Color specification, can be: - - A string with a colormap name (e.g., 'viridis', 'plasma') - - A list of color strings (e.g., ['#ff0000', '#00ff00']) - - A dictionary mapping category names to colors (e.g., {'Category1': '#ff0000'}) + data_left: Data for the left pie chart. + data_right: Data for the right pie chart. + colors: Color specification (colormap name, list of colors, or dict mapping) title: The main title of the plot. subtitles: Tuple containing the subtitles for (left, right) charts. legend_title: The title for the legend. hole: Size of the hole in the center for creating donut charts (0.0 to 1.0). lower_percentage_group: Whether to group small segments (below percentage) into an "Other" category. figsize: The size of the figure (width, height) in inches. - fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created. - axes: A list of Matplotlib axes objects to plot on. If not provided, new axes will be created. Returns: - A tuple containing the Matplotlib figure and list of axes objects used for the plot. + Tuple of (Figure, list of Axes) """ - # Check for empty data - if data_left.empty and data_right.empty: - logger.error('Both datasets are empty. Returning empty figure.') - if fig is None: - fig, axes = plt.subplots(1, 2, figsize=figsize) - return fig, axes - - # Create figure and axes if not provided - if fig is None or axes is None: - fig, axes = plt.subplots(1, 2, figsize=figsize) - - # Process series to handle negative values and apply minimum percentage threshold - def preprocess_series(series: pd.Series): - """ - Preprocess a series for pie chart display by handling negative values - and grouping the smallest parts together if they collectively represent - less than the specified percentage threshold. - """ - # Handle negative values - if (series < 0).any(): - logger.error('Negative values detected in data. Using absolute values for pie chart.') - series = series.abs() - - # Remove zeros - series = series[series > 0] - - # Apply minimum percentage threshold if needed - if lower_percentage_group and not series.empty: - total = series.sum() - if total > 0: - # Sort series by value (ascending) - sorted_series = series.sort_values() - - # Calculate cumulative percentage contribution - cumulative_percent = (sorted_series.cumsum() / total) * 100 + # Preprocess data to Series + left_series = preprocess_data_for_pie(data_left, lower_percentage_group) + right_series = preprocess_data_for_pie(data_right, lower_percentage_group) - # Find entries that collectively make up less than lower_percentage_group - to_group = cumulative_percent <= lower_percentage_group + # Extract labels and values + left_labels = left_series.index.tolist() + left_values = left_series.values.tolist() - if to_group.sum() > 1: - # Create "Other" category for the smallest values that together are < threshold - other_sum = sorted_series[to_group].sum() + right_labels = right_series.index.tolist() + right_values = right_series.values.tolist() - # Keep only values that aren't in the "Other" group - result_series = series[~series.index.isin(sorted_series[to_group].index)] + # Get all unique labels for consistent coloring + all_labels = sorted(set(left_labels) | set(right_labels)) - # Add the "Other" category if it has a value - if other_sum > 0: - result_series['Other'] = other_sum - - return result_series - - return series + # Create color map + color_map = ColorProcessor(engine='matplotlib').process_colors(colors, all_labels, return_mapping=True) - # Preprocess data - data_left_processed = preprocess_series(data_left) - data_right_processed = preprocess_series(data_right) + # Create figure + fig, axes = plt.subplots(1, 2, figsize=figsize) - # Convert Series to DataFrames for pie_with_matplotlib - df_left = pd.DataFrame(data_left_processed).T if not data_left_processed.empty else pd.DataFrame() - df_right = pd.DataFrame(data_right_processed).T if not data_right_processed.empty else pd.DataFrame() + def draw_pie(ax, labels, values, subtitle): + """Draw a single pie chart.""" + if not labels: + ax.set_title(subtitle) + ax.axis('off') + return - # Get unique set of all labels for consistent coloring - all_labels = sorted(set(data_left_processed.index) | set(data_right_processed.index)) + chart_colors = [color_map[label] for label in labels] - # Get consistent color mapping for both charts using our unified function - color_map = ColorProcessor(engine='matplotlib').process_colors(colors, all_labels, return_mapping=True) + # Draw pie + wedges, texts, autotexts = ax.pie( + values, + labels=labels, + colors=chart_colors, + autopct='%1.1f%%', + startangle=90, + wedgeprops=dict(width=1 - hole) if hole > 0 else None, + ) - # Configure colors for each DataFrame based on the consistent mapping - left_colors = [color_map[col] for col in df_left.columns] if not df_left.empty else [] - right_colors = [color_map[col] for col in df_right.columns] if not df_right.empty else [] + # Style text + for autotext in autotexts: + autotext.set_fontsize(10) + autotext.set_color('white') + autotext.set_weight('bold') - # Create left pie chart - if not df_left.empty: - pie_with_matplotlib(data=df_left, colors=left_colors, title=subtitles[0], hole=hole, fig=fig, ax=axes[0]) - else: - axes[0].set_title(subtitles[0]) - axes[0].axis('off') + ax.set_aspect('equal') + ax.set_title(subtitle, fontsize=14, pad=20) - # Create right pie chart - if not df_right.empty: - pie_with_matplotlib(data=df_right, colors=right_colors, title=subtitles[1], hole=hole, fig=fig, ax=axes[1]) - else: - axes[1].set_title(subtitles[1]) - axes[1].axis('off') + # Draw both pies + draw_pie(axes[0], left_labels, left_values, subtitles[0]) + draw_pie(axes[1], right_labels, right_values, subtitles[1]) # Add main title if title: fig.suptitle(title, fontsize=16, y=0.98) - # Adjust layout - fig.tight_layout() - - # Create a unified legend if both charts have data - if not df_left.empty and not df_right.empty: - # Remove individual legends - for ax in axes: - if ax.get_legend(): - ax.get_legend().remove() - - # Create handles for the unified legend - handles = [] - labels_for_legend = [] - - for label in all_labels: - color = color_map[label] - patch = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=label) - handles.append(patch) - labels_for_legend.append(label) + # Create unified legend + if left_labels or right_labels: + handles = [ + plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_map[label], markersize=10) + for label in all_labels + ] - # Add unified legend fig.legend( handles=handles, - labels=labels_for_legend, + labels=all_labels, title=legend_title, loc='lower center', - bbox_to_anchor=(0.5, 0), - ncol=min(len(all_labels), 5), # Limit columns to 5 for readability + bbox_to_anchor=(0.5, -0.02), + ncol=min(len(all_labels), 5), ) - # Add padding at the bottom for the legend - fig.subplots_adjust(bottom=0.2) + fig.subplots_adjust(bottom=0.15) + + fig.tight_layout() return fig, axes @@ -1469,6 +1409,7 @@ def heatmap_with_plotly( | Literal['auto'] | None = 'auto', fill: Literal['ffill', 'bfill'] | None = 'ffill', + **imshow_kwargs: Any, ) -> go.Figure: """ Plot a heatmap visualization using Plotly's imshow with faceting and animation support. @@ -1501,6 +1442,11 @@ def heatmap_with_plotly( - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours) - None: Disable time reshaping (will error if only 1D time data) fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'. + **imshow_kwargs: Additional keyword arguments to pass to plotly.express.imshow. + Common options include: + - aspect: 'auto', 'equal', or a number for aspect ratio + - zmin, zmax: Minimum and maximum values for color scale + - labels: Dict to customize axis labels Returns: A Plotly figure object containing the heatmap visualization. @@ -1589,12 +1535,26 @@ def heatmap_with_plotly( heatmap_dims = [dim for dim in available_dims if dim not in facet_dims] if len(heatmap_dims) < 2: - # Need at least 2 dimensions for a heatmap - logger.error( - f'Heatmap requires at least 2 dimensions for rows and columns. ' - f'After faceting/animation, only {len(heatmap_dims)} dimension(s) remain: {heatmap_dims}' - ) - return go.Figure() + # Handle single-dimension case by adding variable name as a dimension + if len(heatmap_dims) == 1: + # Get the variable name, or use a default + var_name = data.name if data.name else 'value' + + # Expand the DataArray by adding a new dimension with the variable name + data = data.expand_dims({'variable': [var_name]}) + + # Update available dimensions + available_dims = list(data.dims) + heatmap_dims = [dim for dim in available_dims if dim not in facet_dims] + + logger.debug(f'Only 1 dimension remaining for heatmap. Added variable dimension: {var_name}') + else: + # No dimensions at all - cannot create a heatmap + logger.error( + f'Heatmap requires at least 1 dimension. ' + f'After faceting/animation, {len(heatmap_dims)} dimension(s) remain: {heatmap_dims}' + ) + return go.Figure() # Setup faceting parameters for Plotly Express # Note: px.imshow only supports facet_col, not facet_row @@ -1631,23 +1591,21 @@ def heatmap_with_plotly( if animate_by: common_args['animation_frame'] = animate_by + # Merge in additional imshow kwargs + common_args.update(imshow_kwargs) + try: fig = px.imshow(**common_args) except Exception as e: logger.error(f'Error creating imshow plot: {e}. Falling back to basic heatmap.') # Fallback: create a simple heatmap without faceting - fig = px.imshow( - data.values, - color_continuous_scale=colors if isinstance(colors, str) else 'viridis', - title=title, - ) - - # Update layout with basic styling - fig.update_layout( - plot_bgcolor='rgba(0,0,0,0)', - paper_bgcolor='rgba(0,0,0,0)', - font=dict(size=12), - ) + fallback_args = { + 'img': data.values, + 'color_continuous_scale': colors if isinstance(colors, str) else 'viridis', + 'title': title, + } + fallback_args.update(imshow_kwargs) + fig = px.imshow(**fallback_args) return fig @@ -1657,12 +1615,15 @@ def heatmap_with_matplotlib( colors: ColorType = 'viridis', title: str = '', figsize: tuple[float, float] = (12, 6), - fig: plt.Figure | None = None, - ax: plt.Axes | None = None, reshape_time: tuple[Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'], Literal['W', 'D', 'h', '15min', 'min']] | Literal['auto'] | None = 'auto', fill: Literal['ffill', 'bfill'] | None = 'ffill', + vmin: float | None = None, + vmax: float | None = None, + imshow_kwargs: dict[str, Any] | None = None, + cbar_kwargs: dict[str, Any] | None = None, + **kwargs: Any, ) -> tuple[plt.Figure, plt.Axes]: """ Plot a heatmap visualization using Matplotlib's imshow. @@ -1674,16 +1635,25 @@ def heatmap_with_matplotlib( data: An xarray DataArray containing the data to visualize. Should have at least 2 dimensions. If more than 2 dimensions exist, additional dimensions will be reduced by taking the first slice. - colors: Color specification. Should be a colormap name (e.g., 'viridis', 'RdBu'). + colors: Color specification. Should be a colormap name (e.g., 'turbo', 'RdBu'). title: The title of the heatmap. figsize: The size of the figure (width, height) in inches. - fig: A Matplotlib figure object to plot on. If not provided, a new figure will be created. - ax: A Matplotlib axes object to plot on. If not provided, a new axes will be created. reshape_time: Time reshaping configuration: - 'auto' (default): Automatically applies ('D', 'h') if only 'time' dimension - Tuple like ('D', 'h'): Explicit time reshaping (days vs hours) - None: Disable time reshaping fill: Method to fill missing values when reshaping time: 'ffill' or 'bfill'. Default is 'ffill'. + vmin: Minimum value for color scale. If None, uses data minimum. + vmax: Maximum value for color scale. If None, uses data maximum. + imshow_kwargs: Optional dict of parameters to pass to ax.imshow(). + Use this to customize image properties (e.g., interpolation, aspect). + cbar_kwargs: Optional dict of parameters to pass to plt.colorbar(). + Use this to customize colorbar properties (e.g., orientation, label). + **kwargs: Additional keyword arguments passed to ax.imshow(). + Common options include: + - interpolation: 'nearest', 'bilinear', 'bicubic', etc. + - alpha: Transparency level (0-1) + - extent: [left, right, bottom, top] for axis limits Returns: A tuple containing the Matplotlib figure and axes objects used for the plot. @@ -1705,19 +1675,33 @@ def heatmap_with_matplotlib( fig, ax = heatmap_with_matplotlib(data_array, reshape_time=('D', 'h')) ``` """ + # Initialize kwargs if not provided + if imshow_kwargs is None: + imshow_kwargs = {} + if cbar_kwargs is None: + cbar_kwargs = {} + + # Merge any additional kwargs into imshow_kwargs + # This allows users to pass imshow options directly + imshow_kwargs.update(kwargs) + # Handle empty data if data.size == 0: - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + fig, ax = plt.subplots(figsize=figsize) return fig, ax # Apply time reshaping using the new unified function # Matplotlib doesn't support faceting/animation, so we pass None for those data = reshape_data_for_heatmap(data, reshape_time=reshape_time, facet_by=None, animate_by=None, fill=fill) - # Create figure and axes if not provided - if fig is None or ax is None: - fig, ax = plt.subplots(figsize=figsize) + # Handle single-dimension case by adding variable name as a dimension + if isinstance(data, xr.DataArray) and len(data.dims) == 1: + var_name = data.name if data.name else 'value' + data = data.expand_dims({'variable': [var_name]}) + logger.debug(f'Only 1 dimension in data. Added variable dimension: {var_name}') + + # Create figure and axes + fig, ax = plt.subplots(figsize=figsize) # Extract data values # If data has more than 2 dimensions, we need to reduce it @@ -1745,12 +1729,19 @@ def heatmap_with_matplotlib( # Process colormap cmap = colors if isinstance(colors, str) else 'viridis' - # Create the heatmap using imshow - im = ax.imshow(values, cmap=cmap, aspect='auto', origin='upper') + # Create the heatmap using imshow with user customizations + imshow_defaults = {'cmap': cmap, 'aspect': 'auto', 'origin': 'upper', 'vmin': vmin, 'vmax': vmax} + imshow_defaults.update(imshow_kwargs) # User kwargs override defaults + im = ax.imshow(values, **imshow_defaults) + + # Add colorbar with user customizations + cbar_defaults = {'ax': ax, 'orientation': 'horizontal', 'pad': 0.1, 'aspect': 15, 'fraction': 0.05} + cbar_defaults.update(cbar_kwargs) # User kwargs override defaults + cbar = plt.colorbar(im, **cbar_defaults) - # Add colorbar - cbar = plt.colorbar(im, ax=ax, orientation='horizontal', pad=0.1, aspect=15, fraction=0.05) - cbar.set_label('Value') + # Set colorbar label if not overridden by user + if 'label' not in cbar_kwargs: + cbar.set_label('Value') # Set labels and title ax.set_xlabel(str(x_labels).capitalize()) @@ -1770,6 +1761,7 @@ def export_figure( user_path: pathlib.Path | None = None, show: bool = True, save: bool = False, + dpi: int = 300, ) -> go.Figure | tuple[plt.Figure, plt.Axes]: """ Export a figure to a file and or show it. @@ -1781,6 +1773,7 @@ def export_figure( user_path: An optional user-specified file path. show: Whether to display the figure (default: True). save: Whether to save the figure (default: False). + dpi: DPI (dots per inch) for saving Matplotlib figures. If None, Matplotlib rcParams are used. Raises: ValueError: If no default filetype is provided and the path doesn't specify a filetype. @@ -1838,7 +1831,7 @@ def export_figure( plt.show() if save: - fig.savefig(str(filename), dpi=300) + fig.savefig(str(filename), dpi=dpi) plt.close(fig) # Close figure to free memory return fig, ax diff --git a/flixopt/results.py b/flixopt/results.py index 75f8f300e..576ff9ec1 100644 --- a/flixopt/results.py +++ b/flixopt/results.py @@ -107,6 +107,20 @@ class CalculationResults: ).mean() ``` + Configure automatic color management for plots: + + ```python + # Dict-based configuration: + results.setup_colors({'Solar*': 'Oranges', 'Wind*': 'Blues', 'Battery': 'green'}) + + # All plots automatically use configured colors (colors=None is the default) + results['ElectricityBus'].plot_node_balance() + results['Battery'].plot_charge_state() + + # Override when needed + results['ElectricityBus'].plot_node_balance(colors='turbo') # Ignores setup + ``` + Design Patterns: **Factory Methods**: Use `from_file()` and `from_calculation()` for creation or access directly from `Calculation.results` **Dictionary Access**: Use `results[element_label]` for element-specific results @@ -721,6 +735,7 @@ def plot_heatmap( heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None, heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None, color_map: str | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]: """ Plots a heatmap visualization of a variable using imshow or time-based reshaping. @@ -754,6 +769,20 @@ def plot_heatmap( Supported timeframes: 'YS', 'MS', 'W', 'D', 'h', '15min', 'min' fill: Method to fill missing values after reshape: 'ffill' (forward fill) or 'bfill' (backward fill). Default is 'ffill'. + **plot_kwargs: Additional plotting customization options. + Common options: + + - **dpi** (int): Export resolution for saved plots. Default: 300. + + For heatmaps specifically: + + - **vmin** (float): Minimum value for color scale (both engines). + - **vmax** (float): Maximum value for color scale (both engines). + + For Matplotlib heatmaps: + + - **imshow_kwargs** (dict): Additional kwargs for matplotlib's imshow (e.g., interpolation, aspect). + - **cbar_kwargs** (dict): Additional kwargs for colorbar customization. Examples: Direct imshow mode (default): @@ -794,6 +823,18 @@ def plot_heatmap( ... animate_by='period', ... reshape_time=('D', 'h'), ... ) + + High-resolution export with custom color range: + + >>> results.plot_heatmap('Battery|charge_state', save=True, dpi=600, vmin=0, vmax=100) + + Matplotlib heatmap with custom imshow settings: + + >>> results.plot_heatmap( + ... 'Boiler(Q_th)|flow_rate', + ... engine='matplotlib', + ... imshow_kwargs={'interpolation': 'bilinear', 'aspect': 'auto'}, + ... ) """ # Delegate to module-level plot_heatmap function return plot_heatmap( @@ -814,6 +855,7 @@ def plot_heatmap( heatmap_timeframes=heatmap_timeframes, heatmap_timesteps_per_frame=heatmap_timesteps_per_frame, color_map=color_map, + **plot_kwargs, ) def plot_network( @@ -994,6 +1036,7 @@ def plot_node_balance( facet_cols: int = 3, # Deprecated parameter (kept for backwards compatibility) indexer: dict[FlowSystemDimensions, Any] | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]: """ Plots the node balance of the Component or Bus with optional faceting and animation. @@ -1021,6 +1064,27 @@ def plot_node_balance( animate_by: Dimension to animate over (Plotly only). Creates animation frames that cycle through dimension values. Only one dimension can be animated. Ignored if not found. facet_cols: Number of columns in the facet grid layout (default: 3). + **plot_kwargs: Additional plotting customization options passed to underlying plotting functions. + + Common options: + + - **dpi** (int): Export resolution in dots per inch. Default: 300. + + **For Plotly engine** (`engine='plotly'`): + + - Any Plotly Express parameter for px.bar()/px.line()/px.area() + Example: `range_y=[0, 100]`, `line_shape='linear'` + + **For Matplotlib engine** (`engine='matplotlib'`): + + - **plot_kwargs** (dict): Customize plot via `ax.bar()` or `ax.step()`. + Example: `plot_kwargs={'linewidth': 3, 'alpha': 0.7, 'edgecolor': 'black'}` + + See :func:`flixopt.plotting.with_plotly` and :func:`flixopt.plotting.with_matplotlib` + for complete parameter reference. + + Note: For Plotly, you can further customize the returned figure using `fig.update_traces()` + and `fig.update_layout()` after calling this method. Examples: Basic plot (current behavior): @@ -1052,6 +1116,25 @@ def plot_node_balance( Time range selection (summer months only): >>> results['Boiler'].plot_node_balance(select={'time': slice('2024-06', '2024-08')}, facet_by='scenario') + + High-resolution export for publication: + + >>> results['Boiler'].plot_node_balance(engine='matplotlib', save='figure.png', dpi=600) + + Plotly Express customization (e.g., set y-axis range): + + >>> results['Boiler'].plot_node_balance(range_y=[0, 100]) + + Custom matplotlib appearance: + + >>> results['Boiler'].plot_node_balance(engine='matplotlib', plot_kwargs={'linewidth': 3, 'alpha': 0.7}) + + Further customize Plotly figure after creation: + + >>> fig = results['Boiler'].plot_node_balance(mode='line', show=False) + >>> fig.update_traces(line={'width': 5, 'dash': 'dot'}) + >>> fig.update_layout(template='plotly_dark', width=1200, height=600) + >>> fig.show() """ # Handle deprecated indexer parameter if indexer is not None: @@ -1073,8 +1156,11 @@ def plot_node_balance( if engine not in {'plotly', 'matplotlib'}: raise ValueError(f'Engine "{engine}" not supported. Use one of ["plotly", "matplotlib"]') + # Extract dpi for export_figure + dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi + # Don't pass select/indexer to node_balance - we'll apply it afterwards - ds = self.node_balance(with_last_timestep=True, unit_type=unit_type, drop_suffix=drop_suffix) + ds = self.node_balance(with_last_timestep=False, unit_type=unit_type, drop_suffix=drop_suffix) ds, suffix_parts = _apply_selection_to_data(ds, select=select, drop=True) @@ -1101,14 +1187,17 @@ def plot_node_balance( mode=mode, title=title, facet_cols=facet_cols, + xlabel='Time in h', + **plot_kwargs, ) default_filetype = '.html' else: figure_like = plotting.with_matplotlib( - ds.to_dataframe(), + ds, colors=colors, mode=mode, title=title, + **plot_kwargs, ) default_filetype = '.png' @@ -1119,6 +1208,7 @@ def plot_node_balance( user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, save=True if save else False, + dpi=dpi, ) def plot_node_balance_pie( @@ -1132,6 +1222,7 @@ def plot_node_balance_pie( select: dict[FlowSystemDimensions, Any] | None = None, # Deprecated parameter (kept for backwards compatibility) indexer: dict[FlowSystemDimensions, Any] | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, list[plt.Axes]]: """Plot pie chart of flow hours distribution. @@ -1151,6 +1242,17 @@ def plot_node_balance_pie( engine: Plotting engine ('plotly' or 'matplotlib'). select: Optional data selection dict. Supports single values, lists, slices, and index arrays. Use this to select specific scenario/period before creating the pie chart. + **plot_kwargs: Additional plotting customization options. + + Common options: + + - **dpi** (int): Export resolution in dots per inch. Default: 300. + - **hover_template** (str): Hover text template (Plotly only). + Example: `hover_template='%{label}: %{value} (%{percent})'` + - **text_position** (str): Text position ('inside', 'outside', 'auto'). + - **hole** (float): Size of donut hole (0.0 to 1.0). + + See :func:`flixopt.plotting.dual_pie_with_plotly` for complete reference. Examples: Basic usage (auto-selects first scenario/period if present): @@ -1160,6 +1262,14 @@ def plot_node_balance_pie( Explicitly select a scenario and period: >>> results['Bus'].plot_node_balance_pie(select={'scenario': 'high_demand', 'period': 2030}) + + Create a donut chart with custom hover text: + + >>> results['Bus'].plot_node_balance_pie(hole=0.4, hover_template='%{label}: %{value:.2f} (%{percent})') + + High-resolution export: + + >>> results['Bus'].plot_node_balance_pie(save='figure.png', dpi=600) """ # Handle deprecated indexer parameter if indexer is not None: @@ -1178,6 +1288,9 @@ def plot_node_balance_pie( ) select = indexer + # Extract dpi for export_figure + dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi + inputs = sanitize_dataset( ds=self.solution[self.inputs] * self._calculation_results.hours_per_timestep, threshold=1e-5, @@ -1193,8 +1306,9 @@ def plot_node_balance_pie( drop_suffix='|', ) - inputs, suffix_parts = _apply_selection_to_data(inputs, select=select, drop=True) - outputs, suffix_parts = _apply_selection_to_data(outputs, select=select, drop=True) + inputs, suffix_parts_in = _apply_selection_to_data(inputs, select=select, drop=True) + outputs, suffix_parts_out = _apply_selection_to_data(outputs, select=select, drop=True) + suffix_parts = suffix_parts_in + suffix_parts_out # Sum over time dimension inputs = inputs.sum('time') @@ -1204,7 +1318,7 @@ def plot_node_balance_pie( # Pie charts need scalar data, so we automatically reduce extra dimensions extra_dims_inputs = [dim for dim in inputs.dims if dim != 'time'] extra_dims_outputs = [dim for dim in outputs.dims if dim != 'time'] - extra_dims = list(set(extra_dims_inputs + extra_dims_outputs)) + extra_dims = sorted(set(extra_dims_inputs + extra_dims_outputs)) if extra_dims: auto_select = {} @@ -1222,27 +1336,28 @@ def plot_node_balance_pie( f'Use select={{"{dim}": value}} to choose a different value.' ) - # Apply auto-selection - inputs = inputs.sel(auto_select) - outputs = outputs.sel(auto_select) + # Apply auto-selection only for coords present in each dataset + inputs = inputs.sel({k: v for k, v in auto_select.items() if k in inputs.coords}) + outputs = outputs.sel({k: v for k, v in auto_select.items() if k in outputs.coords}) # Update suffix with auto-selected values auto_suffix_parts = [f'{dim}={val}' for dim, val in auto_select.items()] suffix_parts.extend(auto_suffix_parts) - suffix = '--' + '-'.join(suffix_parts) if suffix_parts else '' + suffix = '--' + '-'.join(sorted(set(suffix_parts))) if suffix_parts else '' title = f'{self.label} (total flow hours){suffix}' if engine == 'plotly': figure_like = plotting.dual_pie_with_plotly( - data_left=inputs.to_pandas(), - data_right=outputs.to_pandas(), + data_left=inputs, + data_right=outputs, colors=colors, title=title, text_info=text_info, subtitles=('Inputs', 'Outputs'), legend_title='Flows', lower_percentage_group=lower_percentage_group, + **plot_kwargs, ) default_filetype = '.html' elif engine == 'matplotlib': @@ -1255,6 +1370,7 @@ def plot_node_balance_pie( subtitles=('Inputs', 'Outputs'), legend_title='Flows', lower_percentage_group=lower_percentage_group, + **plot_kwargs, ) default_filetype = '.png' else: @@ -1267,6 +1383,7 @@ def plot_node_balance_pie( user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, save=True if save else False, + dpi=dpi, ) def node_balance( @@ -1373,6 +1490,7 @@ def plot_charge_state( facet_cols: int = 3, # Deprecated parameter (kept for backwards compatibility) indexer: dict[FlowSystemDimensions, Any] | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure: """Plot storage charge state over time, combined with the node balance with optional faceting and animation. @@ -1389,6 +1507,26 @@ def plot_charge_state( animate_by: Dimension to animate over (Plotly only). Creates animation frames that cycle through dimension values. Only one dimension can be animated. Ignored if not found. facet_cols: Number of columns in the facet grid layout (default: 3). + **plot_kwargs: Additional plotting customization options passed to underlying plotting functions. + + Common options: + + - **dpi** (int): Export resolution in dots per inch. Default: 300. + + **For Plotly engine:** + + - Any Plotly Express parameter for px.bar()/px.line()/px.area() + Example: `range_y=[0, 100]`, `line_shape='linear'` + + **For Matplotlib engine:** + + - **plot_kwargs** (dict): Customize plot via `ax.bar()` or `ax.step()`. + + See :func:`flixopt.plotting.with_plotly` and :func:`flixopt.plotting.with_matplotlib` + for complete parameter reference. + + Note: For Plotly, you can further customize the returned figure using `fig.update_traces()` + and `fig.update_layout()` after calling this method. Raises: ValueError: If component is not a storage. @@ -1409,6 +1547,16 @@ def plot_charge_state( Facet by scenario AND animate by period: >>> results['Storage'].plot_charge_state(facet_by='scenario', animate_by='period') + + Custom layout after creation: + + >>> fig = results['Storage'].plot_charge_state(show=False) + >>> fig.update_layout(template='plotly_dark', height=800) + >>> fig.show() + + High-resolution export: + + >>> results['Storage'].plot_charge_state(save='storage.png', dpi=600) """ # Handle deprecated indexer parameter if indexer is not None: @@ -1427,11 +1575,17 @@ def plot_charge_state( ) select = indexer + # Extract dpi for export_figure + dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi + + # Extract charge state line color (for overlay customization) + overlay_color = plot_kwargs.pop('charge_state_line_color', 'black') + if not self.is_storage: raise ValueError(f'Cant plot charge_state. "{self.label}" is not a storage') # Get node balance and charge state - ds = self.node_balance(with_last_timestep=True) + ds = self.node_balance(with_last_timestep=True).fillna(0) charge_state_da = self.charge_state # Apply select filtering @@ -1451,11 +1605,12 @@ def plot_charge_state( mode=mode, title=title, facet_cols=facet_cols, + xlabel='Time in h', + **plot_kwargs, ) - # Create a dataset with just charge_state and plot it as lines - # This ensures proper handling of facets and animation - charge_state_ds = charge_state_da.to_dataset(name=self._charge_state) + # Prepare charge_state as Dataset for plotting + charge_state_ds = xr.Dataset({self._charge_state: charge_state_da}) # Plot charge_state with mode='line' to get Scatter traces charge_state_fig = plotting.with_plotly( @@ -1466,6 +1621,8 @@ def plot_charge_state( mode='line', # Always line for charge_state title='', # No title needed for this temp figure facet_cols=facet_cols, + xlabel='Time in h', + **plot_kwargs, ) # Add charge_state traces to the main figure @@ -1473,6 +1630,7 @@ def plot_charge_state( for trace in charge_state_fig.data: trace.line.width = 2 # Make charge_state line more prominent trace.line.shape = 'linear' # Smooth line for charge state (not stepped like flows) + trace.line.color = overlay_color figure_like.add_trace(trace) # Also add traces from animation frames if they exist @@ -1484,6 +1642,7 @@ def plot_charge_state( for trace in frame.data: trace.line.width = 2 trace.line.shape = 'linear' # Smooth line for charge state + trace.line.color = overlay_color figure_like.frames[i].data = figure_like.frames[i].data + (trace,) default_filetype = '.html' @@ -1497,10 +1656,11 @@ def plot_charge_state( ) # For matplotlib, plot flows (node balance), then add charge_state as line fig, ax = plotting.with_matplotlib( - ds.to_dataframe(), + ds, colors=colors, mode=mode, title=title, + **plot_kwargs, ) # Add charge_state as a line overlay @@ -1510,9 +1670,18 @@ def plot_charge_state( charge_state_df.values.flatten(), label=self._charge_state, linewidth=2, - color='black', + color=overlay_color, + ) + # Recreate legend with the same styling as with_matplotlib + handles, labels = ax.get_legend_handles_labels() + ax.legend( + handles, + labels, + loc='upper center', + bbox_to_anchor=(0.5, -0.15), + ncol=5, + frameon=False, ) - ax.legend() fig.tight_layout() figure_like = fig, ax @@ -1525,6 +1694,7 @@ def plot_charge_state( user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, save=True if save else False, + dpi=dpi, ) def node_balance_with_charge_state( @@ -1810,6 +1980,7 @@ def plot_heatmap( heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None, heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None, color_map: str | None = None, + **plot_kwargs: Any, ) -> plotly.graph_objs.Figure | tuple[plt.Figure, plt.Axes]: """Plot heatmap of variable solution across segments. @@ -1830,6 +2001,17 @@ def plot_heatmap( heatmap_timeframes: (Deprecated) Use reshape_time instead. heatmap_timesteps_per_frame: (Deprecated) Use reshape_time instead. color_map: (Deprecated) Use colors instead. + **plot_kwargs: Additional plotting customization options. + Common options: + + - **dpi** (int): Export resolution for saved plots. Default: 300. + - **vmin** (float): Minimum value for color scale. + - **vmax** (float): Maximum value for color scale. + + For Matplotlib heatmaps: + + - **imshow_kwargs** (dict): Additional kwargs for matplotlib's imshow. + - **cbar_kwargs** (dict): Additional kwargs for colorbar customization. Returns: Figure object. @@ -1884,6 +2066,7 @@ def plot_heatmap( animate_by=animate_by, facet_cols=facet_cols, fill=fill, + **plot_kwargs, ) def to_file(self, folder: str | pathlib.Path | None = None, name: str | None = None, compression: int = 5): @@ -1933,6 +2116,7 @@ def plot_heatmap( heatmap_timeframes: Literal['YS', 'MS', 'W', 'D', 'h', '15min', 'min'] | None = None, heatmap_timesteps_per_frame: Literal['W', 'D', 'h', '15min', 'min'] | None = None, color_map: str | None = None, + **plot_kwargs: Any, ): """Plot heatmap visualization with support for multi-variable, faceting, and animation. @@ -2087,6 +2271,9 @@ def plot_heatmap( timeframes, timesteps_per_frame = reshape_time title += f' ({timeframes} vs {timesteps_per_frame})' + # Extract dpi before passing to plotting functions + dpi = plot_kwargs.pop('dpi', None) # None uses CONFIG.Plotting.default_dpi + # Plot with appropriate engine if engine == 'plotly': figure_like = plotting.heatmap_with_plotly( @@ -2098,6 +2285,7 @@ def plot_heatmap( facet_cols=facet_cols, reshape_time=reshape_time, fill=fill, + **plot_kwargs, ) default_filetype = '.html' elif engine == 'matplotlib': @@ -2107,6 +2295,7 @@ def plot_heatmap( title=title, reshape_time=reshape_time, fill=fill, + **plot_kwargs, ) default_filetype = '.png' else: @@ -2123,6 +2312,7 @@ def plot_heatmap( user_path=None if isinstance(save, bool) else pathlib.Path(save), show=show, save=True if save else False, + dpi=dpi, ) diff --git a/tests/test_plotting_api.py b/tests/test_plotting_api.py new file mode 100644 index 000000000..141623cae --- /dev/null +++ b/tests/test_plotting_api.py @@ -0,0 +1,138 @@ +"""Smoke tests for plotting API robustness improvements.""" + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from flixopt import plotting + + +@pytest.fixture +def sample_dataset(): + """Create a sample xarray Dataset for testing.""" + rng = np.random.default_rng(0) + time = np.arange(10) + data = xr.Dataset( + { + 'var1': (['time'], rng.random(10)), + 'var2': (['time'], rng.random(10)), + 'var3': (['time'], rng.random(10)), + }, + coords={'time': time}, + ) + return data + + +@pytest.fixture +def sample_dataframe(): + """Create a sample pandas DataFrame for testing.""" + rng = np.random.default_rng(1) + time = np.arange(10) + df = pd.DataFrame({'var1': rng.random(10), 'var2': rng.random(10), 'var3': rng.random(10)}, index=time) + df.index.name = 'time' + return df + + +def test_kwargs_passthrough_plotly(sample_dataset): + """Test that px_kwargs are passed through and figure can be customized after creation.""" + # Test that px_kwargs are passed through + fig = plotting.with_plotly( + sample_dataset, + mode='line', + range_y=[0, 100], + ) + assert list(fig.layout.yaxis.range) == [0, 100] + + # Test that figure can be customized after creation + fig.update_traces(line={'width': 5}) + fig.update_layout(width=1200, height=600) + assert fig.layout.width == 1200 + assert fig.layout.height == 600 + assert all(getattr(t, 'line', None) and t.line.width == 5 for t in fig.data) + + +def test_dataframe_support_plotly(sample_dataframe): + """Test that DataFrames are accepted by plotting functions.""" + fig = plotting.with_plotly(sample_dataframe, mode='line') + assert fig is not None + + +def test_data_validation_non_numeric(): + """Test that validation catches non-numeric data.""" + data = xr.Dataset({'var1': (['time'], ['a', 'b', 'c'])}, coords={'time': [0, 1, 2]}) + + with pytest.raises(TypeError, match='non-?numeric'): + plotting.with_plotly(data) + + +def test_ensure_dataset_invalid_type(): + """Test that invalid types raise error via the public API.""" + with pytest.raises(TypeError, match='xr\\.Dataset|pd\\.DataFrame'): + plotting.with_plotly([1, 2, 3], mode='line') + + +@pytest.mark.parametrize( + 'engine,mode,data_type', + [ + *[ + (e, m, dt) + for e in ['plotly', 'matplotlib'] + for m in ['stacked_bar', 'line', 'area', 'grouped_bar'] + for dt in ['dataset', 'dataframe', 'series'] + if not (e == 'matplotlib' and m in ['area', 'grouped_bar']) + ], + ], +) +def test_all_data_types_and_modes(engine, mode, data_type): + """Test that Dataset, DataFrame, and Series work with all plotting modes.""" + time = pd.date_range('2020-01-01', periods=5, freq='h') + + data = { + 'dataset': xr.Dataset( + {'A': (['time'], [1, 2, 3, 4, 5]), 'B': (['time'], [5, 4, 3, 2, 1])}, coords={'time': time} + ), + 'dataframe': pd.DataFrame({'A': [1, 2, 3, 4, 5], 'B': [5, 4, 3, 2, 1]}, index=time), + 'series': pd.Series([1, 2, 3, 4, 5], index=time, name='A'), + }[data_type] + + if engine == 'plotly': + fig = plotting.with_plotly(data, mode=mode) + assert fig is not None and len(fig.data) > 0 + else: + fig, ax = plotting.with_matplotlib(data, mode=mode) + assert fig is not None and ax is not None + + +@pytest.mark.parametrize( + 'engine,data_type', [(e, dt) for e in ['plotly', 'matplotlib'] for dt in ['dataset', 'dataframe', 'series']] +) +def test_pie_plots(engine, data_type): + """Test pie charts with all data types, including automatic summing.""" + time = pd.date_range('2020-01-01', periods=5, freq='h') + + # Single-value data + single_data = { + 'dataset': xr.Dataset({'A': xr.DataArray(10), 'B': xr.DataArray(20), 'C': xr.DataArray(30)}), + 'dataframe': pd.DataFrame({'A': [10], 'B': [20], 'C': [30]}), + 'series': pd.Series({'A': 10, 'B': 20, 'C': 30}), + }[data_type] + + # Multi-dimensional data (for summing test) + multi_data = { + 'dataset': xr.Dataset( + {'A': (['time'], [1, 2, 3, 4, 5]), 'B': (['time'], [5, 5, 5, 5, 5])}, coords={'time': time} + ), + 'dataframe': pd.DataFrame({'A': [1, 2, 3, 4, 5], 'B': [5, 5, 5, 5, 5]}, index=time), + 'series': pd.Series([1, 2, 3, 4, 5], index=time, name='A'), + }[data_type] + + for data in [single_data, multi_data]: + if engine == 'plotly': + fig = plotting.dual_pie_with_plotly(data, data) + assert fig is not None and len(fig.data) >= 2 + if data is multi_data and data_type != 'series': + assert sum(fig.data[0].values) == pytest.approx(40) + else: + fig, axes = plotting.dual_pie_with_matplotlib(data, data) + assert fig is not None and len(axes) == 2