diff --git a/linopy/common.py b/linopy/common.py index 09f67355..4499ab12 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -9,7 +9,7 @@ import operator import os -from collections.abc import Callable, Generator, Hashable, Iterable, Sequence +from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence from functools import partial, reduce, wraps from pathlib import Path from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload @@ -210,25 +210,116 @@ def numpy_to_dataarray( return DataArray(arr, coords=coords, dims=dims, **kwargs) -def as_dataarray( +def _coords_to_mapping( + coords: CoordsLike, +) -> Mapping[Hashable, Any]: + """ + Normalize coords to a mapping of ``{dim_name: index_values}``. + + Handles both dict-like coords and sequence coords (e.g. a list of + ``pd.Index`` or ``pd.RangeIndex``). + """ + if is_dict_like(coords): + return coords + + result: dict[Hashable, Any] = {} + for idx in coords: + if isinstance(idx, pd.Index): + name = idx.name + elif isinstance(idx, DataArray): + name = idx.dims[0] if idx.dims else None + else: + name = None + if name is None: + msg = ( + "Cannot determine dimension name from coords sequence element " + f"{idx!r}. Use a dict or ensure each element has a .name attribute." + ) + raise ValueError(msg) + result[name] = idx + return result + + +def _validate_dataarray_coords( + da: DataArray, + coords: CoordsLike, +) -> DataArray: + """ + Validate that a DataArray's coordinates match the expected coords. + + For shared dimensions, coordinates must match exactly. Extra dimensions + in the DataArray (not in coords) raise ``ValueError``. Missing dimensions + (in coords but not in the DataArray) are broadcast via ``expand_dims``. + + Parameters + ---------- + da : DataArray + The input DataArray to validate. + coords : CoordsLike + The expected coordinates. + + Returns + ------- + DataArray + The validated (and possibly broadcast) DataArray. + + Raises + ------ + ValueError + If the DataArray has extra dimensions or mismatched coordinates. + """ + coords_map = _coords_to_mapping(coords) + + extra_dims = set(da.dims) - set(coords_map) + if extra_dims: + raise ValueError( + f"DataArray has dimensions not present in coords: {extra_dims}" + ) + + for dim in da.dims: + if dim not in coords_map: + continue + expected = pd.Index(coords_map[dim]) + actual = pd.Index(da.coords[dim].values) + if not expected.equals(actual): + raise ValueError( + f"Coordinates for dimension '{dim}' do not match: " + f"expected {expected.tolist()}, got {actual.tolist()}" + ) + + # Broadcast to dimensions present in coords but not in the DataArray + missing_dims = set(coords_map) - set(da.dims) + for dim in missing_dims: + da = da.expand_dims({dim: coords_map[dim]}) + + return da + + +def ensure_dataarray( arr: Any, coords: CoordsLike | None = None, dims: DimsLike | None = None, **kwargs: Any, ) -> DataArray: """ - Convert an object to a DataArray. + Convert an object to a DataArray without coordinate validation. + + DataArray inputs pass through unchanged. For all other types, ``coords`` + and ``dims`` are forwarded to the DataArray constructor. Parameters ---------- arr: The input object. coords (Union[dict, list, None]): - The coordinates for the DataArray. If None, default coordinates will be used. + The coordinates for the DataArray. If None, default coordinates + will be used. dims (Union[list, None]): - The dimensions for the DataArray. If None, the dimensions will be automatically generated. + The dimensions for the DataArray. If None, the dimensions will + be automatically generated. **kwargs: - Additional keyword arguments to be passed to the DataArray constructor. + Additional keyword arguments to be passed to the DataArray + constructor. Returns ------- @@ -245,7 +336,6 @@ def as_dataarray( arr = DataArray(float(arr), coords=coords, dims=dims, **kwargs) elif isinstance(arr, int | float | str | bool | list): arr = DataArray(arr, coords=coords, dims=dims, **kwargs) - elif not isinstance(arr, DataArray): supported_types = [ np.number, @@ -267,6 +357,56 @@ def as_dataarray( return arr +def as_dataarray( + arr: Any, + coords: CoordsLike | None = None, + dims: DimsLike | None = None, + **kwargs: Any, +) -> DataArray: + """ + Convert an object to a DataArray with coordinate validation. + + When ``coords`` is provided the result is guaranteed to have exactly + those dimensions and coordinate values. For DataArray inputs this + means the array is validated (shared dims must match, extra dims are + rejected, missing dims are broadcast). For other input types the + coords are forwarded to the DataArray constructor. + + Parameters + ---------- + arr: + The input object. + coords (Union[dict, list, None]): + The coordinates for the DataArray. If None, default coordinates + will be used. + dims (Union[list, None]): + The dimensions for the DataArray. If None, the dimensions will + be automatically generated. + **kwargs: + Additional keyword arguments to be passed to the DataArray + constructor. + + Returns + ------- + DataArray: + The converted DataArray. + """ + is_input_dataarray = isinstance(arr, DataArray) + arr = ensure_dataarray(arr, coords=coords, dims=dims, **kwargs) + + if is_input_dataarray and coords is not None: + # Normalize plain sequence coords to a mapping using dims or arr.dims + if not is_dict_like(coords): + if dims is not None: + dim_names = list(dims) if isinstance(dims, Iterable) else [dims] + else: + dim_names = list(arr.dims) + coords = dict(zip(dim_names, coords)) + arr = _validate_dataarray_coords(arr, coords) + + return arr + + def broadcast_mask(mask: DataArray, labels: DataArray) -> DataArray: """ Broadcast a boolean mask to match the shape of labels. diff --git a/linopy/expressions.py b/linopy/expressions.py index d2ae9022..fb5c0065 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -51,6 +51,7 @@ check_common_keys_values, check_has_nulls, check_has_nulls_polars, + ensure_dataarray, fill_missing_coords, filter_nulls_polars, forward_as_properties, @@ -597,7 +598,7 @@ def _add_constant( # so that missing data does not silently propagate through arithmetic. if np.isscalar(other) and join is None: return self.assign(const=self.const.fillna(0) + other) - da = as_dataarray(other, coords=self.coords, dims=self.coord_dims) + da = ensure_dataarray(other, coords=self.coords, dims=self.coord_dims) self_const, da, needs_data_reindex = self._align_constant( da, fill_value=0, join=join ) @@ -626,7 +627,7 @@ def _apply_constant_op( - factor (other) is filled with fill_value (0 for mul, 1 for div) - coeffs and const are filled with 0 (additive identity) """ - factor = as_dataarray(other, coords=self.coords, dims=self.coord_dims) + factor = ensure_dataarray(other, coords=self.coords, dims=self.coord_dims) self_const, factor, needs_data_reindex = self._align_constant( factor, fill_value=fill_value, join=join ) @@ -1142,7 +1143,7 @@ def to_constraint( ) if isinstance(rhs, SUPPORTED_CONSTANT_TYPES): - rhs = as_dataarray(rhs, coords=self.coords, dims=self.coord_dims) + rhs = ensure_dataarray(rhs, coords=self.coords, dims=self.coord_dims) extra_dims = set(rhs.dims) - set(self.coord_dims) if extra_dims: @@ -1705,7 +1706,7 @@ def __matmul__( Matrix multiplication with other, similar to xarray dot. """ if not isinstance(other, LinearExpression | variables.Variable): - other = as_dataarray(other, coords=self.coords, dims=self.coord_dims) + other = ensure_dataarray(other, coords=self.coords, dims=self.coord_dims) common_dims = list(set(self.coord_dims).intersection(other.dims)) return (self * other).sum(dim=common_dims) @@ -2191,7 +2192,7 @@ def __matmul__( "Higher order non-linear expressions are not yet supported." ) - other = as_dataarray(other, coords=self.coords, dims=self.coord_dims) + other = ensure_dataarray(other, coords=self.coords, dims=self.coord_dims) common_dims = list(set(self.coord_dims).intersection(other.dims)) return (self * other).sum(dim=common_dims) diff --git a/linopy/model.py b/linopy/model.py index 54334411..6074535e 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -30,6 +30,7 @@ assign_multiindex_safe, best_int, broadcast_mask, + ensure_dataarray, maybe_replace_signs, replace_by_map, set_int_index, @@ -615,7 +616,9 @@ def add_variables( self._check_valid_dim_names(data) if mask is not None: - mask = as_dataarray(mask, coords=data.coords, dims=data.dims).astype(bool) + mask = ensure_dataarray(mask, coords=data.coords, dims=data.dims).astype( + bool + ) mask = broadcast_mask(mask, data.labels) # Auto-mask based on NaN in bounds (use numpy for speed) @@ -777,14 +780,14 @@ def add_constraints( name = f"con{self._connameCounter}" self._connameCounter += 1 if sign is not None: - sign = maybe_replace_signs(as_dataarray(sign)) + sign = maybe_replace_signs(ensure_dataarray(sign)) # Capture original RHS for auto-masking before constraint creation # (NaN values in RHS are lost during constraint creation) # Use numpy for speed instead of xarray's notnull() original_rhs_mask = None if self.auto_mask and rhs is not None: - rhs_da = as_dataarray(rhs) + rhs_da = ensure_dataarray(rhs) original_rhs_mask = (rhs_da.coords, rhs_da.dims, ~np.isnan(rhs_da.values)) if isinstance(lhs, LinearExpression): @@ -837,14 +840,16 @@ def add_constraints( mask = ( rhs_mask if mask is None - else (as_dataarray(mask).astype(bool) & rhs_mask) + else (ensure_dataarray(mask).astype(bool) & rhs_mask) ) data["labels"] = -1 (data,) = xr.broadcast(data, exclude=[TERM_DIM]) if mask is not None: - mask = as_dataarray(mask, coords=data.coords, dims=data.dims).astype(bool) + mask = ensure_dataarray(mask, coords=data.coords, dims=data.dims).astype( + bool + ) mask = broadcast_mask(mask, data.labels) # Auto-mask based on null expressions or NaN RHS (use numpy for speed) diff --git a/linopy/variables.py b/linopy/variables.py index 4332a037..8f3be95b 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -33,10 +33,10 @@ from linopy.common import ( LabelPositionIndex, LocIndexer, - as_dataarray, assign_multiindex_safe, check_has_nulls, check_has_nulls_polars, + ensure_dataarray, filter_nulls_polars, format_string_as_variable_name, generate_indices_for_printout, @@ -321,7 +321,7 @@ def to_linexpr( linopy.LinearExpression Linear expression with the variables and coefficients. """ - coefficient = as_dataarray(coefficient, coords=self.coords, dims=self.dims) + coefficient = ensure_dataarray(coefficient, coords=self.coords, dims=self.dims) coefficient = coefficient.reindex_like(self.labels, fill_value=0) coefficient = coefficient.fillna(0) ds = Dataset({"coeffs": coefficient, "vars": self.labels}).expand_dims( diff --git a/test/test_common.py b/test/test_common.py index f1190024..32654772 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -414,6 +414,64 @@ def test_as_dataarray_with_unsupported_type() -> None: as_dataarray(lambda x: 1, dims=["dim1"], coords=[["a"]]) +def test_as_dataarray_dataarray_matching_coords() -> None: + da = DataArray([1, 2, 3], dims=["x"], coords={"x": [0, 1, 2]}) + result = as_dataarray(da, coords={"x": [0, 1, 2]}) + assert_equal(result, da) + + +def test_as_dataarray_dataarray_mismatched_coords() -> None: + da = DataArray([1, 2, 3], dims=["x"], coords={"x": [10, 11, 12]}) + with pytest.raises(ValueError, match="do not match"): + as_dataarray(da, coords={"x": [0, 1, 2]}) + + +def test_as_dataarray_dataarray_extra_dims() -> None: + da = DataArray([[1, 2], [3, 4]], dims=["x", "y"]) + with pytest.raises(ValueError, match="not present in coords"): + as_dataarray(da, coords={"x": [0, 1]}) + + +def test_as_dataarray_dataarray_broadcast_missing_dims() -> None: + da = DataArray([1, 2], dims=["x"], coords={"x": [0, 1]}) + result = as_dataarray(da, coords={"x": [0, 1], "y": [10, 20, 30]}) + assert result.dims == ("y", "x") # expand_dims prepends + assert result.shape == (3, 2) + # Values should be broadcast + assert (result.sel(y=10) == da).all() + + +def test_as_dataarray_dataarray_sequence_coords() -> None: + time = pd.RangeIndex(3, name="time") + da = DataArray([1, 2, 3], dims=["time"], coords={"time": [0, 1, 2]}) + result = as_dataarray(da, coords=[time]) + assert_equal(result, da) + + +def test_as_dataarray_dataarray_sequence_coords_mismatch() -> None: + time = pd.RangeIndex(5, name="time") + da = DataArray([1, 2, 3], dims=["time"], coords={"time": [0, 1, 2]}) + with pytest.raises(ValueError, match="do not match"): + as_dataarray(da, coords=[time]) + + +def test_add_variables_dataarray_bounds_validated() -> None: + m = Model() + time = pd.RangeIndex(3, name="time") + lower = DataArray([0, 0, 0], dims=["time"], coords={"time": [10, 11, 12]}) + with pytest.raises(ValueError, match="do not match"): + m.add_variables(lower=lower, coords=[time], name="x") + + +def test_add_variables_dataarray_bounds_broadcast() -> None: + m = Model() + time = pd.RangeIndex(3, name="time") + space = pd.Index(["a", "b"], name="space") + lower = DataArray([0, 0, 0], dims=["time"], coords={"time": [0, 1, 2]}) + v = m.add_variables(lower=lower, coords=[time, space], name="x") + assert v.shape == (3, 2) + + def test_best_int() -> None: # Test for int8 assert best_int(127) == np.int8