From 9ed035dfb731c5aeab072dd8b095d6f0c3eee65b Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Sun, 7 Sep 2025 20:52:23 +0200 Subject: [PATCH 1/9] feat: add sos constraints --- linopy/io.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/linopy/io.py b/linopy/io.py index 56fe033d..6663c078 100644 --- a/linopy/io.py +++ b/linopy/io.py @@ -549,6 +549,13 @@ def to_lp_file( slice_size=slice_size, explicit_coordinate_names=explicit_coordinate_names, ) + sos_to_file( + m, + f=f, + progress=progress, + slice_size=slice_size, + explicit_coordinate_names=explicit_coordinate_names, + ) f.write(b"end\n") logger.info(f" Writing time: {round(time.time() - start, 2)}s") From 79ff38241ae1bb474974bb6af52a48525b6ff0f1 Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Sun, 7 Sep 2025 23:15:23 +0200 Subject: [PATCH 2/9] Add documentation (claude) --- doc/sos-constraints.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/sos-constraints.rst b/doc/sos-constraints.rst index 37dd72d2..95e04942 100644 --- a/doc/sos-constraints.rst +++ b/doc/sos-constraints.rst @@ -309,3 +309,4 @@ See Also - :doc:`creating-variables`: Creating variables with coordinates - :doc:`creating-constraints`: Adding regular constraints - :doc:`user-guide`: General linopy usage patterns +- Example notebook: ``examples/sos-constraints-example.ipynb`` From a248863da63d0d487a7ff1d71ff47c7fe29d8cb3 Mon Sep 17 00:00:00 2001 From: Fabian Date: Tue, 2 Dec 2025 15:21:52 +0100 Subject: [PATCH 3/9] fix: type annotations and docs for SOS constraints - Add return type annotations to add_sos_constraints and add_sos - Fix iterate_slices call with list instead of tuple - Fix groupby for multi-dimensional SOS using stack/unstack - Add defensive check in remove_sos_constraints - Clarify direct API support (Gurobi only) in docs --- doc/sos-constraints.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/doc/sos-constraints.rst b/doc/sos-constraints.rst index 95e04942..37dd72d2 100644 --- a/doc/sos-constraints.rst +++ b/doc/sos-constraints.rst @@ -309,4 +309,3 @@ See Also - :doc:`creating-variables`: Creating variables with coordinates - :doc:`creating-constraints`: Adding regular constraints - :doc:`user-guide`: General linopy usage patterns -- Example notebook: ``examples/sos-constraints-example.ipynb`` From d206b812d8100b93ad9f9f6bc63602f79e555ef0 Mon Sep 17 00:00:00 2001 From: Fabian Date: Thu, 18 Dec 2025 11:21:31 +0100 Subject: [PATCH 4/9] fix: handle masked SOS + safer simplify --- linopy/expressions.py | 33 ++++++++++++++++++++++----------- linopy/io.py | 29 ++++++++++++++++++++++++----- test/test_linear_expression.py | 4 ++++ test/test_sos_constraints.py | 16 ++++++++++++++++ 4 files changed, 66 insertions(+), 16 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 10e243de..17067d5a 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -1496,7 +1496,7 @@ def _simplify_row(vars_row: np.ndarray, coeffs_row: np.ndarray) -> np.ndarray: # Filter out invalid entries mask = (vars_row != -1) & (coeffs_row != 0) & ~np.isnan(coeffs_row) - valid_vars = vars_row[mask] + valid_vars = vars_row[mask].astype(np.int64, copy=False) valid_coeffs = coeffs_row[mask] if len(valid_vars) == 0: @@ -1508,15 +1508,11 @@ def _simplify_row(vars_row: np.ndarray, coeffs_row: np.ndarray) -> np.ndarray: ] ) - # Use bincount to sum coefficients for each variable ID efficiently - max_var = int(valid_vars.max()) - summed = np.bincount( - valid_vars, weights=valid_coeffs, minlength=max_var + 1 - ) - - # Get non-zero entries - unique_vars = np.where(summed != 0)[0] - unique_coeffs = summed[unique_vars] + unique_vars, inverse = np.unique(valid_vars, return_inverse=True) + summed = np.bincount(inverse, weights=valid_coeffs) + nonzero = summed != 0 + unique_vars = unique_vars[nonzero] + unique_coeffs = summed[nonzero] # Pad to match input length result_vars = np.full(input_len, -1, dtype=float) @@ -1696,6 +1692,17 @@ def from_tuples( This is the same as calling ``10*x + y`` + 1 but a bit more performant. """ + if model is None: + for t in tuples: + if isinstance(t, tuple) and len(t) == 2: + _, var = t + if isinstance(var, variables.ScalarVariable): + model = var.model + break + if isinstance(var, variables.Variable): + model = var.model + break + def process_one( t: tuple[ConstantLike, str | Variable | ScalarVariable] | tuple[ConstantLike] @@ -1731,7 +1738,11 @@ def process_one( raise TypeError("Expected variable as second element of tuple.") if model is None: - model = expr.model # TODO: Ensure equality of models + model = expr.model + elif expr.model is not model: + raise ValueError( + "All variables in tuples must belong to the same model." + ) return expr if len(t) == 1: diff --git a/linopy/io.py b/linopy/io.py index 6663c078..87f66a4c 100644 --- a/linopy/io.py +++ b/linopy/io.py @@ -357,7 +357,7 @@ def sos_to_file( Write out SOS constraints of a model to an LP file. """ names = m.variables.sos - if not len(list(names)): + if not len(names): return print_variable, _ = get_printers( @@ -380,11 +380,24 @@ def sos_to_file( other_dims = [dim for dim in var.labels.dims if dim != sos_dim] for var_slice in var.iterate_slices(slice_size, other_dims): ds = var_slice.labels.to_dataset() - ds["sos_labels"] = ds["labels"].isel({sos_dim: 0}) ds["weights"] = ds.coords[sos_dim] + sos_labels = ( + ds["labels"] + .where(ds["labels"] != -1) + .min(dim=sos_dim, skipna=True) + .fillna(-1) + .astype(int) + ) + ds["sos_labels"] = sos_labels + df = to_polars(ds) + df = df.filter(pl.col("labels").ne(-1) & pl.col("sos_labels").ne(-1)) + if df.height == 0: + continue + + df = df.sort(["sos_labels", "weights"]) - df = df.group_by("sos_labels").agg( + df = df.group_by("sos_labels", maintain_order=True).agg( pl.concat_str( *print_variable(pl.col("labels")), pl.lit(":"), pl.col("weights") ) @@ -394,7 +407,7 @@ def sos_to_file( columns = [ pl.lit("s"), - pl.col("sos_labels"), + pl.col("sos_labels").cast(pl.Int64), pl.lit(f": S{sos_type} :: "), pl.col("var_weights"), ] @@ -787,7 +800,13 @@ def add_sos(s: xr.DataArray, sos_type: int, sos_dim: str) -> None: s = s.squeeze() indices = s.values.flatten().tolist() weights = s.coords[sos_dim].values.tolist() - model.addSOS(sos_type, x[indices].tolist(), weights) + pairs = [(i, w) for i, w in zip(indices, weights) if i != -1] + if not pairs: + return + indices_filtered, weights_filtered = zip(*pairs) + model.addSOS( + sos_type, x[list(indices_filtered)].tolist(), list(weights_filtered) + ) others = [dim for dim in var.labels.dims if dim != sos_dim] if not others: diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index a75ace3f..897f50f5 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -1102,6 +1102,10 @@ def test_linear_expression_from_tuples(x: Variable, y: Variable) -> None: expr5 = LinearExpression.from_tuples(1, model=x.model) assert isinstance(expr5, LinearExpression) + expr6 = LinearExpression.from_tuples(1, (10, x), (1, y)) + assert isinstance(expr6, LinearExpression) + assert (expr6.const == 1).all() + def test_linear_expression_from_tuples_bad_calls( m: Model, x: Variable, y: Variable diff --git a/test/test_sos_constraints.py b/test/test_sos_constraints.py index 5d94162e..225cc470 100644 --- a/test/test_sos_constraints.py +++ b/test/test_sos_constraints.py @@ -60,6 +60,22 @@ def test_sos_constraints_written_to_lp(tmp_path: Path) -> None: assert "3.5" in content +def test_sos_constraints_written_to_lp_with_mask(tmp_path: Path) -> None: + m = Model() + breakpoints = pd.Index([0.0, 1.5, 3.5], name="bp") + mask = pd.Series([False, True, True], index=breakpoints) + lambdas = m.add_variables(coords=[breakpoints], name="lambda", mask=mask) + m.add_sos_constraints(lambdas, sos_type=2, sos_dim="bp") + + fn = tmp_path / "sos_mask.lp" + m.to_file(fn, io_api="lp") + content = fn.read_text() + + sos_section = content.split("\nsos\n", 1)[1].split("\nend\n", 1)[0] + assert "s-1" not in sos_section + assert "0.0" not in sos_section + + @pytest.mark.skipif("gurobi" not in available_solvers, reason="Gurobipy not installed") def test_to_gurobipy_emits_sos_constraints() -> None: gurobipy = pytest.importorskip("gurobipy") From 9caea89edc66bf09c11a740b9d2485d785b92676 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 25 Jan 2026 11:32:23 +0100 Subject: [PATCH 5/9] Add add_piecewise_constraint() method Files Modified MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. linopy/model.py - Added add_piecewise_constraint method (lines 593-827) 2. test/test_piecewise_constraints.py - Created comprehensive test suite (21 tests) Method Signature def add_piecewise_constraint( self, vars: Variable | dict[str, Variable], breakpoints: DataArray, link_dim: str | None = None, dim: str = "breakpoint", mask: DataArray | None = None, name: str | None = None, ) -> Constraint Features Implemented - Single Variable Support: Pass a single Variable directly - Multiple Variables: Pass a dict of Variables with link_dim to link them - Auto-detect link_dim: When vars is a dict, automatically detects which breakpoints dimension matches the dict keys - NaN Masking: Auto-detects masked values from NaN in breakpoints - Explicit Masking: User-provided mask support - Multi-dimensional: Works with variables that have additional coordinates (generators, timesteps, etc.) - Auto-naming: Generates names like pwl0, pwl1 automatically - Custom naming: User can specify custom names SOS2 Formulation The method creates: 1. Lambda (λ) variables with bounds [0, 1] for each breakpoint 2. SOS2 constraint ensuring at most two adjacent λ values are non-zero 3. Convexity constraint: Σλ = 1 4. Linking constraints: var = Σ(λ × breakpoint) for each variable Test Results All 21 tests pass including: - Basic single/multiple variable cases - Auto-detection of link_dim - Masking (NaN and explicit) - Multi-dimensional cases - Input validation errors - LP file output - Solver integration tests with Gurobi --- linopy/model.py | 236 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 236 insertions(+) diff --git a/linopy/model.py b/linopy/model.py index 657b2d45..c0d413ed 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -591,6 +591,242 @@ def add_sos_constraints( variable.attrs.update(sos_type=sos_type, sos_dim=sos_dim) + def add_piecewise_constraint( + self, + vars: Variable | dict[str, Variable], + breakpoints: DataArray, + link_dim: str | None = None, + dim: str = "breakpoint", + mask: DataArray | None = None, + name: str | None = None, + ) -> Constraint: + """ + Add a piecewise linear constraint using SOS2 formulation. + + This method creates a piecewise linear constraint that links one or more + variables together via a set of breakpoints. It uses the SOS2 (Special + Ordered Set of type 2) formulation with lambda (interpolation) variables. + + The SOS2 formulation ensures that at most two adjacent lambda variables + can be non-zero, effectively selecting a segment of the piecewise linear + function. + + Parameters + ---------- + vars : Variable or dict[str, Variable] + The variable(s) to be linked by the piecewise constraint. + - If a single Variable is passed, the breakpoints directly specify + the piecewise points for that variable. + - If a dict is passed, the keys must match coordinates in `link_dim` + of the breakpoints, allowing multiple variables to be linked. + breakpoints : xr.DataArray + The breakpoint values defining the piecewise linear function. + Must have `dim` as one of its dimensions. If `vars` is a dict, + must also have `link_dim` dimension with coordinates matching the + dict keys. + link_dim : str, optional + The dimension in breakpoints that links to different variables. + Required when `vars` is a dict. If None and `vars` is a dict, + will attempt to auto-detect from breakpoints dimensions. + dim : str, default "breakpoint" + The dimension in breakpoints that represents the breakpoint index. + This dimension's coordinates must be numeric (used as SOS2 weights). + mask : xr.DataArray, optional + Boolean mask indicating which piecewise constraints are valid. + If None, auto-detected from NaN values in breakpoints. + name : str, optional + Base name for the generated variables and constraints. + If None, auto-generates names like "pwl0", "pwl1", etc. + + Returns + ------- + Constraint + The convexity constraint (sum of lambda = 1). Lambda variables + and other constraints can be accessed via: + - `model.variables[f"{name}_lambda"]` + - `model.constraints[f"{name}_convex"]` + - `model.constraints[f"{name}_link_{var_name}"]` + + Raises + ------ + ValueError + If vars is not a Variable or dict of Variables. + If breakpoints doesn't have the required dim dimension. + If link_dim cannot be auto-detected when vars is a dict. + If link_dim coordinates don't match dict keys. + If dim coordinates are not numeric. + + Examples + -------- + Single variable piecewise constraint: + + >>> m = Model() + >>> x = m.add_variables(name="x") + >>> breakpoints = xr.DataArray([0, 10, 50, 100], dims=["bp"]) + >>> m.add_piecewise_constraint(x, breakpoints, dim="bp") + + Multiple linked variables (e.g., power-efficiency curve): + + >>> m = Model() + >>> generators = ["gen1", "gen2"] + >>> power = m.add_variables(coords=[generators], name="power") + >>> efficiency = m.add_variables(coords=[generators], name="efficiency") + >>> breakpoints = xr.DataArray( + ... [[0, 50, 100], [0.8, 0.95, 0.9]], + ... coords={"var": ["power", "efficiency"], "bp": [0, 1, 2]}, + ... ) + >>> m.add_piecewise_constraint( + ... {"power": power, "efficiency": efficiency}, + ... breakpoints, + ... link_dim="var", + ... dim="bp", + ... ) + + Notes + ----- + The piecewise linear constraint is formulated using SOS2 variables: + + 1. Lambda variables λ_i with bounds [0, 1] are created for each breakpoint + 2. SOS2 constraint ensures at most two adjacent λ_i can be non-zero + 3. Convexity constraint: Σ λ_i = 1 + 4. Linking constraints: var = Σ λ_i × breakpoint_i (for each variable) + """ + # Step 1: Input validation + if not isinstance(vars, Variable | dict): + raise ValueError( + f"'vars' must be a Variable or dict of Variables, got {type(vars)}" + ) + + if dim not in breakpoints.dims: + raise ValueError( + f"breakpoints must have dimension '{dim}', " + f"but only has dimensions {list(breakpoints.dims)}" + ) + + # Validate dim coordinates are numeric (required for SOS2 weights) + if not pd.api.types.is_numeric_dtype(breakpoints.coords[dim]): + raise ValueError( + f"Breakpoint dimension '{dim}' must have numeric coordinates " + f"for SOS2 weights, but got {breakpoints.coords[dim].dtype}" + ) + + # Step 2: Normalize vars to dict + if isinstance(vars, Variable): + vars_dict: dict[str, Variable] = {vars.name: vars} + single_var = True + else: + vars_dict = vars + single_var = False + + # Validate all variables exist in model + for var_name, var in vars_dict.items(): + if var.name not in self.variables: + raise ValueError(f"Variable '{var.name}' not found in model") + + # Step 3: Auto-detect or validate link_dim + if not single_var: + if link_dim is None: + # Try to auto-detect link_dim from breakpoints + for d in breakpoints.dims: + if d == dim: + continue + coords_set = set(str(c) for c in breakpoints.coords[d].values) + if coords_set == set(vars_dict.keys()): + link_dim = str(d) + break + if link_dim is None: + raise ValueError( + "Could not auto-detect link_dim. Please specify it explicitly. " + f"Breakpoint dimensions: {list(breakpoints.dims)}, " + f"variable keys: {list(vars_dict.keys())}" + ) + else: + # Validate link_dim exists and matches dict keys + if link_dim not in breakpoints.dims: + raise ValueError( + f"link_dim '{link_dim}' not found in breakpoints dimensions " + f"{list(breakpoints.dims)}" + ) + coords_set = set(str(c) for c in breakpoints.coords[link_dim].values) + if coords_set != set(vars_dict.keys()): + raise ValueError( + f"link_dim '{link_dim}' coordinates {coords_set} " + f"don't match variable keys {set(vars_dict.keys())}" + ) + + # Step 4: Compute mask from NaN values if not provided + if mask is None: + mask = ~breakpoints.isnull() + + # Step 5: Determine lambda coordinates (all dims except link_dim) + # Lambda has all dims from breakpoints except link_dim + excluded_dims = set() + if link_dim is not None: + excluded_dims.add(link_dim) + + lambda_dims = [d for d in breakpoints.dims if d not in excluded_dims] + lambda_coords = [ + pd.Index(breakpoints.coords[d].values, name=d) for d in lambda_dims + ] + + # Step 6: Generate names + if name is None: + # Find unused pwl name + i = 0 + while f"pwl{i}_lambda" in self.variables: + i += 1 + name = f"pwl{i}" + + lambda_name = f"{name}_lambda" + convex_name = f"{name}_convex" + + # Step 7: Compute lambda mask + # Lambda variable is valid if ANY of its breakpoints across link_dim are valid + if link_dim is not None: + lambda_mask = mask.any(dim=link_dim) + else: + # For single var case, use mask directly (collapsed along other dims if needed) + lambda_mask = mask + + # Step 8: Create lambda variables + lambda_var = self.add_variables( + lower=0, + upper=1, + coords=lambda_coords, + name=lambda_name, + mask=lambda_mask, + ) + + # Step 9: Add SOS2 constraint on lambda variables + self.add_sos_constraints(lambda_var, sos_type=2, sos_dim=dim) + + # Step 10: Add convexity constraint (sum of lambda = 1) + convex_con = self.add_constraints( + lambda_var.sum(dim=dim) == 1, + name=convex_name, + ) + + # Step 11: Add linking constraints for each variable + for var_name, var in vars_dict.items(): + # Get breakpoints for this variable + if link_dim is not None: + bp_for_var = breakpoints.sel({link_dim: var_name}) + else: + bp_for_var = breakpoints + + # Compute weighted sum: sum(lambda * breakpoints) + weighted_sum = (lambda_var * bp_for_var).sum(dim=dim) + + # Add linking constraint: var == weighted_sum + link_con_name = f"{name}_link_{var_name}" + self.add_constraints( + var == weighted_sum, + name=link_con_name, + ) + + # Step 12: Return the convexity constraint as the primary reference + return convex_con + def add_constraints( self, lhs: VariableLike From 8d2d2ac1240d4227236a5e439689a6d52ea523fb Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 25 Jan 2026 11:40:07 +0100 Subject: [PATCH 6/9] Summary of Changes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Performance Improvements 1. Single Variable Case Handled Directly: No longer normalizes to dict, handles the single variable case with a direct path. 2. Single Linking Constraint for Dict Case: Instead of creating N separate linking constraints (one per variable), now creates a single constraint that covers all variables using merge() to stack variable expressions along link_dim. Code Quality Improvements 1. Added Constants in linopy/constants.py: - PWL_LAMBDA_SUFFIX = "_lambda" - PWL_CONVEX_SUFFIX = "_convex" - PWL_LINK_SUFFIX = "_link" - DEFAULT_BREAKPOINT_DIM = "breakpoint" 2. Replaced magic strings with constants throughout the implementation and tests. Result The constraint structure is now: - Before: pwl0_link_power, pwl0_link_efficiency (separate constraints) - After: pwl0_link (single constraint covering all variables) Constraints: ['pwl0_convex', 'pwl0_link'] Constraint `pwl0_link` [var: 2, generator: 2]: [power, gen1]: +1 power[gen1] - 0 λ[0] - 50 λ[1] - 100 λ[2] - 150 λ[3] = 0 [power, gen2]: +1 power[gen2] - 0 λ[0] - 50 λ[1] - 100 λ[2] - 150 λ[3] = 0 [efficiency, gen1]: +1 efficiency[gen1] - 0 λ[0] - 0.85 λ[1] - 0.92 λ[2] - 0.88 λ[3] = 0 [efficiency, gen2]: +1 efficiency[gen2] - 0 λ[0] - 0.85 λ[1] - 0.92 λ[2] - 0.88 λ[3] = 0 --- linopy/constants.py | 6 ++ linopy/model.py | 210 ++++++++++++++++++++++++-------------------- 2 files changed, 121 insertions(+), 95 deletions(-) diff --git a/linopy/constants.py b/linopy/constants.py index 021a9a10..35e0367f 100644 --- a/linopy/constants.py +++ b/linopy/constants.py @@ -35,6 +35,12 @@ TERM_DIM = "_term" STACKED_TERM_DIM = "_stacked_term" + +# Piecewise linear constraint constants +PWL_LAMBDA_SUFFIX = "_lambda" +PWL_CONVEX_SUFFIX = "_convex" +PWL_LINK_SUFFIX = "_link" +DEFAULT_BREAKPOINT_DIM = "breakpoint" GROUPED_TERM_DIM = "_grouped_term" GROUP_DIM = "_group" FACTOR_DIM = "_factor" diff --git a/linopy/model.py b/linopy/model.py index c0d413ed..ed7b3fcf 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -35,9 +35,13 @@ to_path, ) from linopy.constants import ( + DEFAULT_BREAKPOINT_DIM, GREATER_EQUAL, HELPER_DIMS, LESS_EQUAL, + PWL_CONVEX_SUFFIX, + PWL_LAMBDA_SUFFIX, + PWL_LINK_SUFFIX, TERM_DIM, ModelStatus, TerminationCondition, @@ -47,6 +51,7 @@ LinearExpression, QuadraticExpression, ScalarLinearExpression, + merge, ) from linopy.io import ( to_block_files, @@ -596,7 +601,7 @@ def add_piecewise_constraint( vars: Variable | dict[str, Variable], breakpoints: DataArray, link_dim: str | None = None, - dim: str = "breakpoint", + dim: str = DEFAULT_BREAKPOINT_DIM, mask: DataArray | None = None, name: str | None = None, ) -> Constraint: @@ -645,7 +650,7 @@ def add_piecewise_constraint( and other constraints can be accessed via: - `model.variables[f"{name}_lambda"]` - `model.constraints[f"{name}_convex"]` - - `model.constraints[f"{name}_link_{var_name}"]` + - `model.constraints[f"{name}_link"]` Raises ------ @@ -691,12 +696,7 @@ def add_piecewise_constraint( 3. Convexity constraint: Σ λ_i = 1 4. Linking constraints: var = Σ λ_i × breakpoint_i (for each variable) """ - # Step 1: Input validation - if not isinstance(vars, Variable | dict): - raise ValueError( - f"'vars' must be a Variable or dict of Variables, got {type(vars)}" - ) - + # Input validation if dim not in breakpoints.dims: raise ValueError( f"breakpoints must have dimension '{dim}', " @@ -710,121 +710,141 @@ def add_piecewise_constraint( f"for SOS2 weights, but got {breakpoints.coords[dim].dtype}" ) - # Step 2: Normalize vars to dict + # Generate base name if not provided + if name is None: + i = 0 + while f"pwl{i}{PWL_LAMBDA_SUFFIX}" in self.variables: + i += 1 + name = f"pwl{i}" + + lambda_name = f"{name}{PWL_LAMBDA_SUFFIX}" + convex_name = f"{name}{PWL_CONVEX_SUFFIX}" + link_name = f"{name}{PWL_LINK_SUFFIX}" + + # Handle single Variable case directly (no dict normalization) if isinstance(vars, Variable): - vars_dict: dict[str, Variable] = {vars.name: vars} - single_var = True - else: - vars_dict = vars - single_var = False + if vars.name not in self.variables: + raise ValueError(f"Variable '{vars.name}' not found in model") + + # Compute mask from NaN values if not provided + if mask is None: + mask = ~breakpoints.isnull() + + # Lambda coordinates: all dims from breakpoints + lambda_dims = list(breakpoints.dims) + lambda_coords = [ + pd.Index(breakpoints.coords[d].values, name=d) for d in lambda_dims + ] + + # Create lambda variables + lambda_var = self.add_variables( + lower=0, upper=1, coords=lambda_coords, name=lambda_name, mask=mask + ) + + # Add SOS2 constraint + self.add_sos_constraints(lambda_var, sos_type=2, sos_dim=dim) + + # Add convexity constraint: sum(lambda) = 1 + convex_con = self.add_constraints( + lambda_var.sum(dim=dim) == 1, name=convex_name + ) + + # Add single linking constraint: var = sum(lambda * breakpoints) + weighted_sum = (lambda_var * breakpoints).sum(dim=dim) + self.add_constraints(vars == weighted_sum, name=link_name) + + return convex_con + + # Handle dict of Variables case + if not isinstance(vars, dict): + raise ValueError( + f"'vars' must be a Variable or dict of Variables, got {type(vars)}" + ) + + vars_dict = vars # Validate all variables exist in model for var_name, var in vars_dict.items(): if var.name not in self.variables: raise ValueError(f"Variable '{var.name}' not found in model") - # Step 3: Auto-detect or validate link_dim - if not single_var: + # Auto-detect or validate link_dim + if link_dim is None: + # Try to auto-detect link_dim from breakpoints + for d in breakpoints.dims: + if d == dim: + continue + coords_set = set(str(c) for c in breakpoints.coords[d].values) + if coords_set == set(vars_dict.keys()): + link_dim = str(d) + break if link_dim is None: - # Try to auto-detect link_dim from breakpoints - for d in breakpoints.dims: - if d == dim: - continue - coords_set = set(str(c) for c in breakpoints.coords[d].values) - if coords_set == set(vars_dict.keys()): - link_dim = str(d) - break - if link_dim is None: - raise ValueError( - "Could not auto-detect link_dim. Please specify it explicitly. " - f"Breakpoint dimensions: {list(breakpoints.dims)}, " - f"variable keys: {list(vars_dict.keys())}" - ) - else: - # Validate link_dim exists and matches dict keys - if link_dim not in breakpoints.dims: - raise ValueError( - f"link_dim '{link_dim}' not found in breakpoints dimensions " - f"{list(breakpoints.dims)}" - ) - coords_set = set(str(c) for c in breakpoints.coords[link_dim].values) - if coords_set != set(vars_dict.keys()): - raise ValueError( - f"link_dim '{link_dim}' coordinates {coords_set} " - f"don't match variable keys {set(vars_dict.keys())}" - ) + raise ValueError( + "Could not auto-detect link_dim. Please specify it explicitly. " + f"Breakpoint dimensions: {list(breakpoints.dims)}, " + f"variable keys: {list(vars_dict.keys())}" + ) + else: + # Validate link_dim exists and matches dict keys + if link_dim not in breakpoints.dims: + raise ValueError( + f"link_dim '{link_dim}' not found in breakpoints dimensions " + f"{list(breakpoints.dims)}" + ) + coords_set = set(str(c) for c in breakpoints.coords[link_dim].values) + if coords_set != set(vars_dict.keys()): + raise ValueError( + f"link_dim '{link_dim}' coordinates {coords_set} " + f"don't match variable keys {set(vars_dict.keys())}" + ) - # Step 4: Compute mask from NaN values if not provided + # Compute mask from NaN values if not provided if mask is None: mask = ~breakpoints.isnull() - # Step 5: Determine lambda coordinates (all dims except link_dim) - # Lambda has all dims from breakpoints except link_dim - excluded_dims = set() - if link_dim is not None: - excluded_dims.add(link_dim) - - lambda_dims = [d for d in breakpoints.dims if d not in excluded_dims] + # Lambda coordinates: all dims from breakpoints except link_dim + lambda_dims = [d for d in breakpoints.dims if d != link_dim] lambda_coords = [ pd.Index(breakpoints.coords[d].values, name=d) for d in lambda_dims ] - # Step 6: Generate names - if name is None: - # Find unused pwl name - i = 0 - while f"pwl{i}_lambda" in self.variables: - i += 1 - name = f"pwl{i}" - - lambda_name = f"{name}_lambda" - convex_name = f"{name}_convex" + # Lambda mask: valid if ANY breakpoint across link_dim is valid + lambda_mask = mask.any(dim=link_dim) - # Step 7: Compute lambda mask - # Lambda variable is valid if ANY of its breakpoints across link_dim are valid - if link_dim is not None: - lambda_mask = mask.any(dim=link_dim) - else: - # For single var case, use mask directly (collapsed along other dims if needed) - lambda_mask = mask - - # Step 8: Create lambda variables + # Create lambda variables lambda_var = self.add_variables( - lower=0, - upper=1, - coords=lambda_coords, - name=lambda_name, - mask=lambda_mask, + lower=0, upper=1, coords=lambda_coords, name=lambda_name, mask=lambda_mask ) - # Step 9: Add SOS2 constraint on lambda variables + # Add SOS2 constraint self.add_sos_constraints(lambda_var, sos_type=2, sos_dim=dim) - # Step 10: Add convexity constraint (sum of lambda = 1) + # Add convexity constraint: sum(lambda) = 1 convex_con = self.add_constraints( - lambda_var.sum(dim=dim) == 1, - name=convex_name, + lambda_var.sum(dim=dim) == 1, name=convex_name ) - # Step 11: Add linking constraints for each variable - for var_name, var in vars_dict.items(): - # Get breakpoints for this variable - if link_dim is not None: - bp_for_var = breakpoints.sel({link_dim: var_name}) - else: - bp_for_var = breakpoints + # Stack all variables into a single expression along link_dim for single constraint + # Get the link_dim coordinates in the order they appear in breakpoints + link_coords = list(breakpoints.coords[link_dim].values) - # Compute weighted sum: sum(lambda * breakpoints) - weighted_sum = (lambda_var * bp_for_var).sum(dim=dim) + # Convert each variable to a LinearExpression and assign link_dim coordinate + var_exprs = [] + for k in link_coords: + var = vars_dict[str(k)] + expr = var.to_linexpr() + # Expand dims to add link_dim coordinate + expr_data = expr.data.expand_dims({link_dim: [k]}) + var_exprs.append(LinearExpression(expr_data, self)) - # Add linking constraint: var == weighted_sum - link_con_name = f"{name}_link_{var_name}" - self.add_constraints( - var == weighted_sum, - name=link_con_name, - ) + # Concatenate all variable expressions along link_dim + stacked_vars_expr = merge(var_exprs, dim=link_dim) + + # Add single linking constraint for all variables: + # stacked_vars == (lambda * breakpoints).sum(dim=dim) + weighted_sum = (lambda_var * breakpoints).sum(dim=dim) + self.add_constraints(stacked_vars_expr == weighted_sum, name=link_name) - # Step 12: Return the convexity constraint as the primary reference return convex_con def add_constraints( From 1f95be83f863059604153a49782111394d6df6d6 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 25 Jan 2026 11:45:38 +0100 Subject: [PATCH 7/9] Performance Optimizations 1. Single Variable Case: Use breakpoints.coords directly # Before: Creating pd.Index objects in a loop lambda_coords = [ pd.Index(breakpoints.coords[d].values, name=d) for d in lambda_dims ] # After: Pass coords directly lambda_var = self.add_variables( lower=0, upper=1, coords=breakpoints.coords, ... ) 2. Dict Case: Build stacked expression directly # Before: Multiple intermediate objects for k in link_coords: expr = var.to_linexpr() # Creates LinearExpression expr_data = expr.data.expand_dims({link_dim: [k]}) # Creates Dataset var_exprs.append(LinearExpression(expr_data, self)) # Creates another LinearExpression stacked_vars_expr = merge(var_exprs, dim=link_dim) # Merges all # After: Direct Dataset construction from labels labels_list = [] for k in link_coords: labels_list.append(var.labels.expand_dims({link_dim: [k]})) stacked_labels = xr.concat(labels_list, dim=link_dim) # Single concat # Build Dataset directly stacked_expr_data = Dataset({ "coeffs": xr.ones_like(stacked_labels, dtype=float).expand_dims(TERM_DIM), "vars": stacked_labels.expand_dims(TERM_DIM), "const": xr.zeros_like(...), }) stacked_vars_expr = LinearExpression(stacked_expr_data, self) # Single object 3. Combined validation with expression building Variable existence validation now happens in the same loop that collects labels, avoiding a separate validation pass. Key Benefits: - Fewer intermediate objects: Avoids creating N LinearExpression objects + merge - Direct Dataset construction: Builds the final structure in one shot - Single xr.concat call: Instead of multiple expand_dims + merge operations - Removed merge import: No longer needed --- linopy/model.py | 78 ++++++++++++++++++++++++++----------------------- 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/linopy/model.py b/linopy/model.py index ed7b3fcf..dee82bc9 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -51,7 +51,6 @@ LinearExpression, QuadraticExpression, ScalarLinearExpression, - merge, ) from linopy.io import ( to_block_files, @@ -726,19 +725,17 @@ def add_piecewise_constraint( if vars.name not in self.variables: raise ValueError(f"Variable '{vars.name}' not found in model") - # Compute mask from NaN values if not provided + # Compute mask from NaN values only if not provided if mask is None: mask = ~breakpoints.isnull() - # Lambda coordinates: all dims from breakpoints - lambda_dims = list(breakpoints.dims) - lambda_coords = [ - pd.Index(breakpoints.coords[d].values, name=d) for d in lambda_dims - ] - - # Create lambda variables + # Create lambda variables using breakpoints coords directly lambda_var = self.add_variables( - lower=0, upper=1, coords=lambda_coords, name=lambda_name, mask=mask + lower=0, + upper=1, + coords=breakpoints.coords, + name=lambda_name, + mask=mask, ) # Add SOS2 constraint @@ -762,27 +759,23 @@ def add_piecewise_constraint( ) vars_dict = vars + vars_keys = set(vars_dict.keys()) - # Validate all variables exist in model - for var_name, var in vars_dict.items(): - if var.name not in self.variables: - raise ValueError(f"Variable '{var.name}' not found in model") - - # Auto-detect or validate link_dim + # Auto-detect or validate link_dim (combined with variable validation) if link_dim is None: # Try to auto-detect link_dim from breakpoints for d in breakpoints.dims: if d == dim: continue coords_set = set(str(c) for c in breakpoints.coords[d].values) - if coords_set == set(vars_dict.keys()): + if coords_set == vars_keys: link_dim = str(d) break if link_dim is None: raise ValueError( "Could not auto-detect link_dim. Please specify it explicitly. " f"Breakpoint dimensions: {list(breakpoints.dims)}, " - f"variable keys: {list(vars_dict.keys())}" + f"variable keys: {list(vars_keys)}" ) else: # Validate link_dim exists and matches dict keys @@ -792,22 +785,23 @@ def add_piecewise_constraint( f"{list(breakpoints.dims)}" ) coords_set = set(str(c) for c in breakpoints.coords[link_dim].values) - if coords_set != set(vars_dict.keys()): + if coords_set != vars_keys: raise ValueError( f"link_dim '{link_dim}' coordinates {coords_set} " - f"don't match variable keys {set(vars_dict.keys())}" + f"don't match variable keys {vars_keys}" ) - # Compute mask from NaN values if not provided - if mask is None: - mask = ~breakpoints.isnull() - - # Lambda coordinates: all dims from breakpoints except link_dim - lambda_dims = [d for d in breakpoints.dims if d != link_dim] + # Lambda coordinates: breakpoints coords without link_dim (as list of Index) lambda_coords = [ - pd.Index(breakpoints.coords[d].values, name=d) for d in lambda_dims + pd.Index(breakpoints.coords[d].values, name=d) + for d in breakpoints.dims + if d != link_dim ] + # Compute mask from NaN values only if not provided + if mask is None: + mask = ~breakpoints.isnull() + # Lambda mask: valid if ANY breakpoint across link_dim is valid lambda_mask = mask.any(dim=link_dim) @@ -824,21 +818,31 @@ def add_piecewise_constraint( lambda_var.sum(dim=dim) == 1, name=convex_name ) - # Stack all variables into a single expression along link_dim for single constraint - # Get the link_dim coordinates in the order they appear in breakpoints + # Build stacked expression efficiently by concatenating variable labels directly link_coords = list(breakpoints.coords[link_dim].values) - # Convert each variable to a LinearExpression and assign link_dim coordinate - var_exprs = [] + # Collect labels and validate variables exist in single pass + labels_list = [] for k in link_coords: var = vars_dict[str(k)] - expr = var.to_linexpr() - # Expand dims to add link_dim coordinate - expr_data = expr.data.expand_dims({link_dim: [k]}) - var_exprs.append(LinearExpression(expr_data, self)) + if var.name not in self.variables: + raise ValueError(f"Variable '{var.name}' not found in model") + labels_list.append(var.labels.expand_dims({link_dim: [k]})) - # Concatenate all variable expressions along link_dim - stacked_vars_expr = merge(var_exprs, dim=link_dim) + # Stack labels directly using xr.concat (more efficient than merge) + stacked_labels = xr.concat(labels_list, dim=link_dim) + + # Build LinearExpression Dataset directly with coeffs=1, avoiding intermediate objects + stacked_expr_data = Dataset( + { + "coeffs": xr.ones_like(stacked_labels, dtype=float).expand_dims( + TERM_DIM + ), + "vars": stacked_labels.expand_dims(TERM_DIM), + "const": xr.zeros_like(stacked_labels.isel({link_dim: 0}), dtype=float), + } + ) + stacked_vars_expr = LinearExpression(stacked_expr_data, self) # Add single linking constraint for all variables: # stacked_vars == (lambda * breakpoints).sum(dim=dim) From 8aeb1c7f198f1e203ff35dec2fdbaeb4cb2584c1 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 25 Jan 2026 13:06:57 +0100 Subject: [PATCH 8/9] Summary of Improvements 1. Skip NaN Check Parameter (skip_nan_check: bool = False) # Before: Always computed O(n) scan if mask is None: mask = ~breakpoints.isnull() # After: Skip when user knows data is clean m.add_piecewise_constraint(x, breakpoints, dim='bp', skip_nan_check=True) 2. Counter-Based Name Generation # Before: Loop until finding unused name - O(n) worst case i = 0 while f"pwl{i}{PWL_LAMBDA_SUFFIX}" in self.variables: i += 1 # After: O(1) counter increment name = f"pwl{self._pwlCounter}" self._pwlCounter += 1 3. Expression Support # Now supports LinearExpression, not just Variable m.add_piecewise_constraint(x + y, breakpoints, dim='bp') # Also works in dict form m.add_piecewise_constraint({ 'total': x + y, 'cost': cost_expr }, breakpoints, link_dim='var', dim='bp') 4. Refactored with Helper Methods Extracted common logic into reusable private methods: - _to_linexpr() - Convert Variable/LinearExpression to LinearExpression - _compute_pwl_mask() - Handle mask computation with skip_nan_check - _resolve_pwl_link_dim() - Auto-detect or validate link_dim - _build_stacked_expr() - Build stacked expression from dict Code Structure (Before vs After) Before: ~250 lines, duplicated logic between single/dict cases After: ~180 lines main method + 4 helper methods, DRY code Test Coverage - 22 tests total (added tests for expression support and skip_nan_check) - All tests pass - Linting passes --- linopy/model.py | 263 +++++++++++++++++++++++++----------------------- 1 file changed, 136 insertions(+), 127 deletions(-) diff --git a/linopy/model.py b/linopy/model.py index dee82bc9..db101e20 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -134,6 +134,7 @@ class Model: "_cCounter", "_varnameCounter", "_connameCounter", + "_pwlCounter", "_blocks", # TODO: check if these should not be mutable "_chunk", @@ -184,6 +185,7 @@ def __init__( self._cCounter: int = 0 self._varnameCounter: int = 0 self._connameCounter: int = 0 + self._pwlCounter: int = 0 self._blocks: DataArray | None = None self._chunk: T_Chunks = chunk @@ -597,19 +599,21 @@ def add_sos_constraints( def add_piecewise_constraint( self, - vars: Variable | dict[str, Variable], + expr: Variable | LinearExpression | dict[str, Variable | LinearExpression], breakpoints: DataArray, link_dim: str | None = None, dim: str = DEFAULT_BREAKPOINT_DIM, mask: DataArray | None = None, name: str | None = None, + skip_nan_check: bool = False, ) -> Constraint: """ Add a piecewise linear constraint using SOS2 formulation. This method creates a piecewise linear constraint that links one or more - variables together via a set of breakpoints. It uses the SOS2 (Special - Ordered Set of type 2) formulation with lambda (interpolation) variables. + variables/expressions together via a set of breakpoints. It uses the SOS2 + (Special Ordered Set of type 2) formulation with lambda (interpolation) + variables. The SOS2 formulation ensures that at most two adjacent lambda variables can be non-zero, effectively selecting a segment of the piecewise linear @@ -617,30 +621,34 @@ def add_piecewise_constraint( Parameters ---------- - vars : Variable or dict[str, Variable] - The variable(s) to be linked by the piecewise constraint. - - If a single Variable is passed, the breakpoints directly specify - the piecewise points for that variable. + expr : Variable, LinearExpression, or dict of these + The variable(s) or expression(s) to be linked by the piecewise constraint. + - If a single Variable/LinearExpression is passed, the breakpoints + directly specify the piecewise points for that expression. - If a dict is passed, the keys must match coordinates in `link_dim` - of the breakpoints, allowing multiple variables to be linked. + of the breakpoints, allowing multiple expressions to be linked. breakpoints : xr.DataArray The breakpoint values defining the piecewise linear function. - Must have `dim` as one of its dimensions. If `vars` is a dict, + Must have `dim` as one of its dimensions. If `expr` is a dict, must also have `link_dim` dimension with coordinates matching the dict keys. link_dim : str, optional - The dimension in breakpoints that links to different variables. - Required when `vars` is a dict. If None and `vars` is a dict, + The dimension in breakpoints that links to different expressions. + Required when `expr` is a dict. If None and `expr` is a dict, will attempt to auto-detect from breakpoints dimensions. dim : str, default "breakpoint" The dimension in breakpoints that represents the breakpoint index. This dimension's coordinates must be numeric (used as SOS2 weights). mask : xr.DataArray, optional Boolean mask indicating which piecewise constraints are valid. - If None, auto-detected from NaN values in breakpoints. + If None, auto-detected from NaN values in breakpoints (unless + skip_nan_check is True). name : str, optional Base name for the generated variables and constraints. If None, auto-generates names like "pwl0", "pwl1", etc. + skip_nan_check : bool, default False + If True, skip automatic NaN detection in breakpoints. Use this + when you know breakpoints contain no NaN values for better performance. Returns ------- @@ -654,9 +662,9 @@ def add_piecewise_constraint( Raises ------ ValueError - If vars is not a Variable or dict of Variables. + If expr is not a Variable, LinearExpression, or dict of these. If breakpoints doesn't have the required dim dimension. - If link_dim cannot be auto-detected when vars is a dict. + If link_dim cannot be auto-detected when expr is a dict. If link_dim coordinates don't match dict keys. If dim coordinates are not numeric. @@ -669,6 +677,14 @@ def add_piecewise_constraint( >>> breakpoints = xr.DataArray([0, 10, 50, 100], dims=["bp"]) >>> m.add_piecewise_constraint(x, breakpoints, dim="bp") + Using an expression: + + >>> m = Model() + >>> x = m.add_variables(name="x") + >>> y = m.add_variables(name="y") + >>> breakpoints = xr.DataArray([0, 10, 50, 100], dims=["bp"]) + >>> m.add_piecewise_constraint(x + y, breakpoints, dim="bp") + Multiple linked variables (e.g., power-efficiency curve): >>> m = Model() @@ -693,163 +709,156 @@ def add_piecewise_constraint( 1. Lambda variables λ_i with bounds [0, 1] are created for each breakpoint 2. SOS2 constraint ensures at most two adjacent λ_i can be non-zero 3. Convexity constraint: Σ λ_i = 1 - 4. Linking constraints: var = Σ λ_i × breakpoint_i (for each variable) + 4. Linking constraints: expr = Σ λ_i × breakpoint_i (for each expression) """ - # Input validation + # --- Input validation --- if dim not in breakpoints.dims: raise ValueError( f"breakpoints must have dimension '{dim}', " f"but only has dimensions {list(breakpoints.dims)}" ) - # Validate dim coordinates are numeric (required for SOS2 weights) if not pd.api.types.is_numeric_dtype(breakpoints.coords[dim]): raise ValueError( f"Breakpoint dimension '{dim}' must have numeric coordinates " f"for SOS2 weights, but got {breakpoints.coords[dim].dtype}" ) - # Generate base name if not provided + # --- Generate names using counter --- if name is None: - i = 0 - while f"pwl{i}{PWL_LAMBDA_SUFFIX}" in self.variables: - i += 1 - name = f"pwl{i}" + name = f"pwl{self._pwlCounter}" + self._pwlCounter += 1 lambda_name = f"{name}{PWL_LAMBDA_SUFFIX}" convex_name = f"{name}{PWL_CONVEX_SUFFIX}" link_name = f"{name}{PWL_LINK_SUFFIX}" - # Handle single Variable case directly (no dict normalization) - if isinstance(vars, Variable): - if vars.name not in self.variables: - raise ValueError(f"Variable '{vars.name}' not found in model") - - # Compute mask from NaN values only if not provided - if mask is None: - mask = ~breakpoints.isnull() - - # Create lambda variables using breakpoints coords directly - lambda_var = self.add_variables( - lower=0, - upper=1, - coords=breakpoints.coords, - name=lambda_name, - mask=mask, - ) - - # Add SOS2 constraint - self.add_sos_constraints(lambda_var, sos_type=2, sos_dim=dim) + # --- Determine lambda coordinates, mask, and target expression --- + is_single = isinstance(expr, Variable | LinearExpression) + is_dict = isinstance(expr, dict) - # Add convexity constraint: sum(lambda) = 1 - convex_con = self.add_constraints( - lambda_var.sum(dim=dim) == 1, name=convex_name - ) - - # Add single linking constraint: var = sum(lambda * breakpoints) - weighted_sum = (lambda_var * breakpoints).sum(dim=dim) - self.add_constraints(vars == weighted_sum, name=link_name) - - return convex_con - - # Handle dict of Variables case - if not isinstance(vars, dict): + if not is_single and not is_dict: raise ValueError( - f"'vars' must be a Variable or dict of Variables, got {type(vars)}" + f"'expr' must be a Variable, LinearExpression, or dict of these, " + f"got {type(expr)}" ) - vars_dict = vars - vars_keys = set(vars_dict.keys()) + if is_single: + # Single expression case + target_expr = self._to_linexpr(expr) + lambda_coords = breakpoints.coords + lambda_mask = self._compute_pwl_mask(mask, breakpoints, skip_nan_check) - # Auto-detect or validate link_dim (combined with variable validation) - if link_dim is None: - # Try to auto-detect link_dim from breakpoints - for d in breakpoints.dims: - if d == dim: - continue - coords_set = set(str(c) for c in breakpoints.coords[d].values) - if coords_set == vars_keys: - link_dim = str(d) - break - if link_dim is None: - raise ValueError( - "Could not auto-detect link_dim. Please specify it explicitly. " - f"Breakpoint dimensions: {list(breakpoints.dims)}, " - f"variable keys: {list(vars_keys)}" - ) else: - # Validate link_dim exists and matches dict keys - if link_dim not in breakpoints.dims: - raise ValueError( - f"link_dim '{link_dim}' not found in breakpoints dimensions " - f"{list(breakpoints.dims)}" - ) - coords_set = set(str(c) for c in breakpoints.coords[link_dim].values) - if coords_set != vars_keys: - raise ValueError( - f"link_dim '{link_dim}' coordinates {coords_set} " - f"don't match variable keys {vars_keys}" - ) + # Dict case - need to validate link_dim and build stacked expression + expr_dict = expr + expr_keys = set(expr_dict.keys()) - # Lambda coordinates: breakpoints coords without link_dim (as list of Index) - lambda_coords = [ - pd.Index(breakpoints.coords[d].values, name=d) - for d in breakpoints.dims - if d != link_dim - ] + # Auto-detect or validate link_dim + link_dim = self._resolve_pwl_link_dim(link_dim, breakpoints, dim, expr_keys) + + # Build lambda coordinates (exclude link_dim) + lambda_coords = [ + pd.Index(breakpoints.coords[d].values, name=d) + for d in breakpoints.dims + if d != link_dim + ] - # Compute mask from NaN values only if not provided - if mask is None: - mask = ~breakpoints.isnull() + # Compute mask + base_mask = self._compute_pwl_mask(mask, breakpoints, skip_nan_check) + lambda_mask = base_mask.any(dim=link_dim) - # Lambda mask: valid if ANY breakpoint across link_dim is valid - lambda_mask = mask.any(dim=link_dim) + # Build stacked expression from dict + target_expr = self._build_stacked_expr(expr_dict, breakpoints, link_dim) - # Create lambda variables + # --- Common: Create lambda, SOS2, convexity, and linking constraints --- lambda_var = self.add_variables( lower=0, upper=1, coords=lambda_coords, name=lambda_name, mask=lambda_mask ) - # Add SOS2 constraint self.add_sos_constraints(lambda_var, sos_type=2, sos_dim=dim) - # Add convexity constraint: sum(lambda) = 1 convex_con = self.add_constraints( lambda_var.sum(dim=dim) == 1, name=convex_name ) - # Build stacked expression efficiently by concatenating variable labels directly - link_coords = list(breakpoints.coords[link_dim].values) + weighted_sum = (lambda_var * breakpoints).sum(dim=dim) + self.add_constraints(target_expr == weighted_sum, name=link_name) - # Collect labels and validate variables exist in single pass - labels_list = [] - for k in link_coords: - var = vars_dict[str(k)] - if var.name not in self.variables: - raise ValueError(f"Variable '{var.name}' not found in model") - labels_list.append(var.labels.expand_dims({link_dim: [k]})) + return convex_con - # Stack labels directly using xr.concat (more efficient than merge) - stacked_labels = xr.concat(labels_list, dim=link_dim) + def _to_linexpr(self, expr: Variable | LinearExpression) -> LinearExpression: + """Convert Variable or LinearExpression to LinearExpression.""" + if isinstance(expr, LinearExpression): + return expr + return expr.to_linexpr() - # Build LinearExpression Dataset directly with coeffs=1, avoiding intermediate objects - stacked_expr_data = Dataset( - { - "coeffs": xr.ones_like(stacked_labels, dtype=float).expand_dims( - TERM_DIM - ), - "vars": stacked_labels.expand_dims(TERM_DIM), - "const": xr.zeros_like(stacked_labels.isel({link_dim: 0}), dtype=float), - } - ) - stacked_vars_expr = LinearExpression(stacked_expr_data, self) + def _compute_pwl_mask( + self, + mask: DataArray | None, + breakpoints: DataArray, + skip_nan_check: bool, + ) -> DataArray | None: + """Compute mask for piecewise constraint, optionally skipping NaN check.""" + if mask is not None: + return mask + if skip_nan_check: + return None + return ~breakpoints.isnull() - # Add single linking constraint for all variables: - # stacked_vars == (lambda * breakpoints).sum(dim=dim) - weighted_sum = (lambda_var * breakpoints).sum(dim=dim) - self.add_constraints(stacked_vars_expr == weighted_sum, name=link_name) + def _resolve_pwl_link_dim( + self, + link_dim: str | None, + breakpoints: DataArray, + dim: str, + expr_keys: set[str], + ) -> str: + """Auto-detect or validate link_dim for dict case.""" + if link_dim is None: + for d in breakpoints.dims: + if d == dim: + continue + coords_set = set(str(c) for c in breakpoints.coords[d].values) + if coords_set == expr_keys: + return str(d) + raise ValueError( + "Could not auto-detect link_dim. Please specify it explicitly. " + f"Breakpoint dimensions: {list(breakpoints.dims)}, " + f"expression keys: {list(expr_keys)}" + ) - return convex_con + if link_dim not in breakpoints.dims: + raise ValueError( + f"link_dim '{link_dim}' not found in breakpoints dimensions " + f"{list(breakpoints.dims)}" + ) + coords_set = set(str(c) for c in breakpoints.coords[link_dim].values) + if coords_set != expr_keys: + raise ValueError( + f"link_dim '{link_dim}' coordinates {coords_set} " + f"don't match expression keys {expr_keys}" + ) + return link_dim + + def _build_stacked_expr( + self, + expr_dict: dict[str, Variable | LinearExpression], + breakpoints: DataArray, + link_dim: str, + ) -> LinearExpression: + """Build a stacked LinearExpression from a dict of Variables/Expressions.""" + link_coords = list(breakpoints.coords[link_dim].values) + + # Collect expression data and stack + expr_data_list = [] + for k in link_coords: + e = expr_dict[str(k)] + linexpr = self._to_linexpr(e) + expr_data_list.append(linexpr.data.expand_dims({link_dim: [k]})) + + # Concatenate along link_dim + stacked_data = xr.concat(expr_data_list, dim=link_dim) + return LinearExpression(stacked_data, self) def add_constraints( self, From 9b3b771ca69bb577a99cacb5981bd40878651753 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Sun, 25 Jan 2026 13:15:48 +0100 Subject: [PATCH 9/9] FInalize rebase --- linopy/expressions.py | 33 +++++++++++-------------------- linopy/io.py | 36 +++++----------------------------- test/test_linear_expression.py | 4 ---- test/test_sos_constraints.py | 16 --------------- 4 files changed, 16 insertions(+), 73 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 17067d5a..10e243de 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -1496,7 +1496,7 @@ def _simplify_row(vars_row: np.ndarray, coeffs_row: np.ndarray) -> np.ndarray: # Filter out invalid entries mask = (vars_row != -1) & (coeffs_row != 0) & ~np.isnan(coeffs_row) - valid_vars = vars_row[mask].astype(np.int64, copy=False) + valid_vars = vars_row[mask] valid_coeffs = coeffs_row[mask] if len(valid_vars) == 0: @@ -1508,11 +1508,15 @@ def _simplify_row(vars_row: np.ndarray, coeffs_row: np.ndarray) -> np.ndarray: ] ) - unique_vars, inverse = np.unique(valid_vars, return_inverse=True) - summed = np.bincount(inverse, weights=valid_coeffs) - nonzero = summed != 0 - unique_vars = unique_vars[nonzero] - unique_coeffs = summed[nonzero] + # Use bincount to sum coefficients for each variable ID efficiently + max_var = int(valid_vars.max()) + summed = np.bincount( + valid_vars, weights=valid_coeffs, minlength=max_var + 1 + ) + + # Get non-zero entries + unique_vars = np.where(summed != 0)[0] + unique_coeffs = summed[unique_vars] # Pad to match input length result_vars = np.full(input_len, -1, dtype=float) @@ -1692,17 +1696,6 @@ def from_tuples( This is the same as calling ``10*x + y`` + 1 but a bit more performant. """ - if model is None: - for t in tuples: - if isinstance(t, tuple) and len(t) == 2: - _, var = t - if isinstance(var, variables.ScalarVariable): - model = var.model - break - if isinstance(var, variables.Variable): - model = var.model - break - def process_one( t: tuple[ConstantLike, str | Variable | ScalarVariable] | tuple[ConstantLike] @@ -1738,11 +1731,7 @@ def process_one( raise TypeError("Expected variable as second element of tuple.") if model is None: - model = expr.model - elif expr.model is not model: - raise ValueError( - "All variables in tuples must belong to the same model." - ) + model = expr.model # TODO: Ensure equality of models return expr if len(t) == 1: diff --git a/linopy/io.py b/linopy/io.py index 87f66a4c..56fe033d 100644 --- a/linopy/io.py +++ b/linopy/io.py @@ -357,7 +357,7 @@ def sos_to_file( Write out SOS constraints of a model to an LP file. """ names = m.variables.sos - if not len(names): + if not len(list(names)): return print_variable, _ = get_printers( @@ -380,24 +380,11 @@ def sos_to_file( other_dims = [dim for dim in var.labels.dims if dim != sos_dim] for var_slice in var.iterate_slices(slice_size, other_dims): ds = var_slice.labels.to_dataset() + ds["sos_labels"] = ds["labels"].isel({sos_dim: 0}) ds["weights"] = ds.coords[sos_dim] - sos_labels = ( - ds["labels"] - .where(ds["labels"] != -1) - .min(dim=sos_dim, skipna=True) - .fillna(-1) - .astype(int) - ) - ds["sos_labels"] = sos_labels - df = to_polars(ds) - df = df.filter(pl.col("labels").ne(-1) & pl.col("sos_labels").ne(-1)) - if df.height == 0: - continue - - df = df.sort(["sos_labels", "weights"]) - df = df.group_by("sos_labels", maintain_order=True).agg( + df = df.group_by("sos_labels").agg( pl.concat_str( *print_variable(pl.col("labels")), pl.lit(":"), pl.col("weights") ) @@ -407,7 +394,7 @@ def sos_to_file( columns = [ pl.lit("s"), - pl.col("sos_labels").cast(pl.Int64), + pl.col("sos_labels"), pl.lit(f": S{sos_type} :: "), pl.col("var_weights"), ] @@ -562,13 +549,6 @@ def to_lp_file( slice_size=slice_size, explicit_coordinate_names=explicit_coordinate_names, ) - sos_to_file( - m, - f=f, - progress=progress, - slice_size=slice_size, - explicit_coordinate_names=explicit_coordinate_names, - ) f.write(b"end\n") logger.info(f" Writing time: {round(time.time() - start, 2)}s") @@ -800,13 +780,7 @@ def add_sos(s: xr.DataArray, sos_type: int, sos_dim: str) -> None: s = s.squeeze() indices = s.values.flatten().tolist() weights = s.coords[sos_dim].values.tolist() - pairs = [(i, w) for i, w in zip(indices, weights) if i != -1] - if not pairs: - return - indices_filtered, weights_filtered = zip(*pairs) - model.addSOS( - sos_type, x[list(indices_filtered)].tolist(), list(weights_filtered) - ) + model.addSOS(sos_type, x[indices].tolist(), weights) others = [dim for dim in var.labels.dims if dim != sos_dim] if not others: diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 897f50f5..a75ace3f 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -1102,10 +1102,6 @@ def test_linear_expression_from_tuples(x: Variable, y: Variable) -> None: expr5 = LinearExpression.from_tuples(1, model=x.model) assert isinstance(expr5, LinearExpression) - expr6 = LinearExpression.from_tuples(1, (10, x), (1, y)) - assert isinstance(expr6, LinearExpression) - assert (expr6.const == 1).all() - def test_linear_expression_from_tuples_bad_calls( m: Model, x: Variable, y: Variable diff --git a/test/test_sos_constraints.py b/test/test_sos_constraints.py index 225cc470..5d94162e 100644 --- a/test/test_sos_constraints.py +++ b/test/test_sos_constraints.py @@ -60,22 +60,6 @@ def test_sos_constraints_written_to_lp(tmp_path: Path) -> None: assert "3.5" in content -def test_sos_constraints_written_to_lp_with_mask(tmp_path: Path) -> None: - m = Model() - breakpoints = pd.Index([0.0, 1.5, 3.5], name="bp") - mask = pd.Series([False, True, True], index=breakpoints) - lambdas = m.add_variables(coords=[breakpoints], name="lambda", mask=mask) - m.add_sos_constraints(lambdas, sos_type=2, sos_dim="bp") - - fn = tmp_path / "sos_mask.lp" - m.to_file(fn, io_api="lp") - content = fn.read_text() - - sos_section = content.split("\nsos\n", 1)[1].split("\nend\n", 1)[0] - assert "s-1" not in sos_section - assert "0.0" not in sos_section - - @pytest.mark.skipif("gurobi" not in available_solvers, reason="Gurobipy not installed") def test_to_gurobipy_emits_sos_constraints() -> None: gurobipy = pytest.importorskip("gurobipy")