diff --git a/linopy/model.py b/linopy/model.py index 48a8200b..431cbb17 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -135,7 +135,8 @@ def _validate_dataarray_bounds(arr: Any, coords: Any) -> Any: - Raises ``ValueError`` if the array has dimensions not in coords. - Raises ``ValueError`` if shared dimension coordinates don't match. - - Expands missing dimensions via ``expand_dims``. + - Expands missing dimensions via ``expand_dims`` and transposes the + result to the ``coords`` dimension order. """ if not isinstance(arr, DataArray): return arr @@ -171,10 +172,18 @@ def _validate_dataarray_bounds(arr: Any, coords: Any) -> Any: f"expected {expected_idx.tolist()}, got {actual_idx.tolist()}" ) - # Expand missing dimensions + # expand_dims prepends new dimensions and coordinates; restore coords order. expand = {k: v for k, v in expected.items() if k not in arr.dims} if expand: - arr = arr.expand_dims(expand) + arr = arr.expand_dims(expand).transpose(*expected) + target = [c for c in expected if c in arr.coords] + target += [c for c in arr.coords if c not in expected] + if list(arr.coords) != target: + arr = DataArray( + arr.variable, + coords={c: arr.coords[c] for c in target}, + name=arr.name, + ) return arr diff --git a/linopy/piecewise.py b/linopy/piecewise.py index ccc265a7..25a0ce17 100644 --- a/linopy/piecewise.py +++ b/linopy/piecewise.py @@ -1006,20 +1006,18 @@ def _broadcast_points( lin_exprs = [_to_linexpr(e) for e in exprs] - target_dims: set[str] = set() - for le in lin_exprs: - target_dims.update(str(d) for d in le.coord_dims) - - missing = target_dims - skip - {str(d) for d in points.dims} - if not missing: - return points + point_dims = {str(d) for d in points.dims} + # Iterate exprs/dims in order; a set would give a hash-dependent, + # run-varying expanded dimension order. expand_map: dict[str, list] = {} - for d in missing: - for le in lin_exprs: + for le in lin_exprs: + for dim in le.coord_dims: + d = str(dim) + if d in skip or d in point_dims or d in expand_map: + continue if d in le.coords: - expand_map[str(d)] = list(le.coords[d].values) - break + expand_map[d] = list(le.coords[d].values) if expand_map: points = points.expand_dims(expand_map) diff --git a/test/test_piecewise_constraints.py b/test/test_piecewise_constraints.py index c44af394..72b57265 100644 --- a/test/test_piecewise_constraints.py +++ b/test/test_piecewise_constraints.py @@ -1383,6 +1383,23 @@ def test_broadcast_over_extra_dims(self) -> None: assert "generator" in delta.dims assert "time" in delta.dims + def test_broadcast_points_dim_order_follows_exprs(self) -> None: + """Expanded dims follow the expression dim order, not set ordering.""" + import xarray as xr + + from linopy.piecewise import BREAKPOINT_DIM, _broadcast_points + + m = Model() + coords = [ + pd.Index(["v0", "v1"], name="alpha"), + pd.Index(["w0", "w1"], name="beta"), + pd.Index([0, 1], name="gamma"), + ] + x = m.add_variables(coords=coords, name="x") + points = xr.DataArray([0, 1, 2, 3], dims=[BREAKPOINT_DIM]) + out = _broadcast_points(points, 1 * x) + assert out.dims == ("alpha", "beta", "gamma", BREAKPOINT_DIM) + # =========================================================================== # NaN masking diff --git a/test/test_variable.py b/test/test_variable.py index b14b746e..d01d7fd8 100644 --- a/test/test_variable.py +++ b/test/test_variable.py @@ -446,6 +446,32 @@ def test_dataarray_broadcast_missing_dim(self, model: "Model") -> None: assert (var.data.lower.sel(space="a") == [1, 2, 3]).all() assert (var.data.lower.sel(space="b") == [1, 2, 3]).all() + @pytest.mark.parametrize( + "lower, upper", + [ + pytest.param(0, "da", id="scalar-lower+da-upper"), + pytest.param("da", 1, id="da-lower+scalar-upper"), + pytest.param("da", "da", id="da-lower+da-upper"), + ], + ) + def test_dataarray_broadcast_missing_dim_order( + self, model: "Model", lower: Any, upper: Any + ) -> None: + """Dimension order follows coords, not the type of the bounds (#706).""" + x = pd.Index(["a", "b", "c"], name="x") + y = pd.Index(["X", "Y"], name="y") + full = DataArray( + np.arange(6).reshape(3, 2), coords={"x": x, "y": y}, dims=["x", "y"] + ) + # bounds are DataArrays missing the 'y' dimension + da = full.sum("y") + lower = da if lower == "da" else lower + upper = da if upper == "da" else upper + var = model.add_variables(lower=lower, upper=upper, coords=[x, y], name="x") + assert var.dims == ("x", "y") + assert var.data.lower.dims == ("x", "y") + assert var.data.upper.dims == ("x", "y") + # -- Special coord formats --------------------------------------------- def test_multiindex_coords(self, model: "Model") -> None: