Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 147 additions & 7 deletions linopy/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import operator
import os
from collections.abc import Callable, Generator, Hashable, Iterable, Sequence
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence
from functools import partial, reduce, wraps
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
Expand Down Expand Up @@ -210,25 +210,116 @@ def numpy_to_dataarray(
return DataArray(arr, coords=coords, dims=dims, **kwargs)


def as_dataarray(
def _coords_to_mapping(
coords: CoordsLike,
) -> Mapping[Hashable, Any]:
"""
Normalize coords to a mapping of ``{dim_name: index_values}``.

Handles both dict-like coords and sequence coords (e.g. a list of
``pd.Index`` or ``pd.RangeIndex``).
"""
if is_dict_like(coords):
return coords

result: dict[Hashable, Any] = {}
for idx in coords:
if isinstance(idx, pd.Index):
name = idx.name
elif isinstance(idx, DataArray):
name = idx.dims[0] if idx.dims else None
else:
name = None
if name is None:
msg = (
"Cannot determine dimension name from coords sequence element "
f"{idx!r}. Use a dict or ensure each element has a .name attribute."
)
raise ValueError(msg)
result[name] = idx
return result


def _validate_dataarray_coords(
da: DataArray,
coords: CoordsLike,
) -> DataArray:
"""
Validate that a DataArray's coordinates match the expected coords.

For shared dimensions, coordinates must match exactly. Extra dimensions
in the DataArray (not in coords) raise ``ValueError``. Missing dimensions
(in coords but not in the DataArray) are broadcast via ``expand_dims``.

Parameters
----------
da : DataArray
The input DataArray to validate.
coords : CoordsLike
The expected coordinates.

Returns
-------
DataArray
The validated (and possibly broadcast) DataArray.

Raises
------
ValueError
If the DataArray has extra dimensions or mismatched coordinates.
"""
coords_map = _coords_to_mapping(coords)

extra_dims = set(da.dims) - set(coords_map)
if extra_dims:
raise ValueError(
f"DataArray has dimensions not present in coords: {extra_dims}"
)

for dim in da.dims:
if dim not in coords_map:
continue
expected = pd.Index(coords_map[dim])
actual = pd.Index(da.coords[dim].values)
if not expected.equals(actual):
raise ValueError(
f"Coordinates for dimension '{dim}' do not match: "
f"expected {expected.tolist()}, got {actual.tolist()}"
)

# Broadcast to dimensions present in coords but not in the DataArray
missing_dims = set(coords_map) - set(da.dims)
for dim in missing_dims:
da = da.expand_dims({dim: coords_map[dim]})

return da


def ensure_dataarray(
arr: Any,
coords: CoordsLike | None = None,
dims: DimsLike | None = None,
**kwargs: Any,
) -> DataArray:
"""
Convert an object to a DataArray.
Convert an object to a DataArray without coordinate validation.

DataArray inputs pass through unchanged. For all other types, ``coords``
and ``dims`` are forwarded to the DataArray constructor.

Parameters
----------
arr:
The input object.
coords (Union[dict, list, None]):
The coordinates for the DataArray. If None, default coordinates will be used.
The coordinates for the DataArray. If None, default coordinates
will be used.
dims (Union[list, None]):
The dimensions for the DataArray. If None, the dimensions will be automatically generated.
The dimensions for the DataArray. If None, the dimensions will
be automatically generated.
**kwargs:
Additional keyword arguments to be passed to the DataArray constructor.
Additional keyword arguments to be passed to the DataArray
constructor.

Returns
-------
Expand All @@ -245,7 +336,6 @@ def as_dataarray(
arr = DataArray(float(arr), coords=coords, dims=dims, **kwargs)
elif isinstance(arr, int | float | str | bool | list):
arr = DataArray(arr, coords=coords, dims=dims, **kwargs)

elif not isinstance(arr, DataArray):
supported_types = [
np.number,
Expand All @@ -267,6 +357,56 @@ def as_dataarray(
return arr


def as_dataarray(
arr: Any,
coords: CoordsLike | None = None,
dims: DimsLike | None = None,
**kwargs: Any,
) -> DataArray:
"""
Convert an object to a DataArray with coordinate validation.

When ``coords`` is provided the result is guaranteed to have exactly
those dimensions and coordinate values. For DataArray inputs this
means the array is validated (shared dims must match, extra dims are
rejected, missing dims are broadcast). For other input types the
coords are forwarded to the DataArray constructor.

Parameters
----------
arr:
The input object.
coords (Union[dict, list, None]):
The coordinates for the DataArray. If None, default coordinates
will be used.
dims (Union[list, None]):
The dimensions for the DataArray. If None, the dimensions will
be automatically generated.
**kwargs:
Additional keyword arguments to be passed to the DataArray
constructor.

Returns
-------
DataArray:
The converted DataArray.
"""
is_input_dataarray = isinstance(arr, DataArray)
arr = ensure_dataarray(arr, coords=coords, dims=dims, **kwargs)

if is_input_dataarray and coords is not None:
# Normalize plain sequence coords to a mapping using dims or arr.dims
if not is_dict_like(coords):
if dims is not None:
dim_names = list(dims) if isinstance(dims, Iterable) else [dims]
else:
dim_names = list(arr.dims)
coords = dict(zip(dim_names, coords))
arr = _validate_dataarray_coords(arr, coords)

return arr


def broadcast_mask(mask: DataArray, labels: DataArray) -> DataArray:
"""
Broadcast a boolean mask to match the shape of labels.
Expand Down
11 changes: 6 additions & 5 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
check_common_keys_values,
check_has_nulls,
check_has_nulls_polars,
ensure_dataarray,
fill_missing_coords,
filter_nulls_polars,
forward_as_properties,
Expand Down Expand Up @@ -597,7 +598,7 @@ def _add_constant(
# so that missing data does not silently propagate through arithmetic.
if np.isscalar(other) and join is None:
return self.assign(const=self.const.fillna(0) + other)
da = as_dataarray(other, coords=self.coords, dims=self.coord_dims)
da = ensure_dataarray(other, coords=self.coords, dims=self.coord_dims)
self_const, da, needs_data_reindex = self._align_constant(
da, fill_value=0, join=join
)
Expand Down Expand Up @@ -626,7 +627,7 @@ def _apply_constant_op(
- factor (other) is filled with fill_value (0 for mul, 1 for div)
- coeffs and const are filled with 0 (additive identity)
"""
factor = as_dataarray(other, coords=self.coords, dims=self.coord_dims)
factor = ensure_dataarray(other, coords=self.coords, dims=self.coord_dims)
self_const, factor, needs_data_reindex = self._align_constant(
factor, fill_value=fill_value, join=join
)
Expand Down Expand Up @@ -1142,7 +1143,7 @@ def to_constraint(
)

if isinstance(rhs, SUPPORTED_CONSTANT_TYPES):
rhs = as_dataarray(rhs, coords=self.coords, dims=self.coord_dims)
rhs = ensure_dataarray(rhs, coords=self.coords, dims=self.coord_dims)

extra_dims = set(rhs.dims) - set(self.coord_dims)
if extra_dims:
Expand Down Expand Up @@ -1705,7 +1706,7 @@ def __matmul__(
Matrix multiplication with other, similar to xarray dot.
"""
if not isinstance(other, LinearExpression | variables.Variable):
other = as_dataarray(other, coords=self.coords, dims=self.coord_dims)
other = ensure_dataarray(other, coords=self.coords, dims=self.coord_dims)

common_dims = list(set(self.coord_dims).intersection(other.dims))
return (self * other).sum(dim=common_dims)
Expand Down Expand Up @@ -2191,7 +2192,7 @@ def __matmul__(
"Higher order non-linear expressions are not yet supported."
)

other = as_dataarray(other, coords=self.coords, dims=self.coord_dims)
other = ensure_dataarray(other, coords=self.coords, dims=self.coord_dims)
common_dims = list(set(self.coord_dims).intersection(other.dims))
return (self * other).sum(dim=common_dims)

Expand Down
15 changes: 10 additions & 5 deletions linopy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
assign_multiindex_safe,
best_int,
broadcast_mask,
ensure_dataarray,
maybe_replace_signs,
replace_by_map,
set_int_index,
Expand Down Expand Up @@ -615,7 +616,9 @@ def add_variables(
self._check_valid_dim_names(data)

if mask is not None:
mask = as_dataarray(mask, coords=data.coords, dims=data.dims).astype(bool)
mask = ensure_dataarray(mask, coords=data.coords, dims=data.dims).astype(
bool
)
mask = broadcast_mask(mask, data.labels)

# Auto-mask based on NaN in bounds (use numpy for speed)
Expand Down Expand Up @@ -777,14 +780,14 @@ def add_constraints(
name = f"con{self._connameCounter}"
self._connameCounter += 1
if sign is not None:
sign = maybe_replace_signs(as_dataarray(sign))
sign = maybe_replace_signs(ensure_dataarray(sign))

# Capture original RHS for auto-masking before constraint creation
# (NaN values in RHS are lost during constraint creation)
# Use numpy for speed instead of xarray's notnull()
original_rhs_mask = None
if self.auto_mask and rhs is not None:
rhs_da = as_dataarray(rhs)
rhs_da = ensure_dataarray(rhs)
original_rhs_mask = (rhs_da.coords, rhs_da.dims, ~np.isnan(rhs_da.values))

if isinstance(lhs, LinearExpression):
Expand Down Expand Up @@ -837,14 +840,16 @@ def add_constraints(
mask = (
rhs_mask
if mask is None
else (as_dataarray(mask).astype(bool) & rhs_mask)
else (ensure_dataarray(mask).astype(bool) & rhs_mask)
)

data["labels"] = -1
(data,) = xr.broadcast(data, exclude=[TERM_DIM])

if mask is not None:
mask = as_dataarray(mask, coords=data.coords, dims=data.dims).astype(bool)
mask = ensure_dataarray(mask, coords=data.coords, dims=data.dims).astype(
bool
)
mask = broadcast_mask(mask, data.labels)

# Auto-mask based on null expressions or NaN RHS (use numpy for speed)
Expand Down
4 changes: 2 additions & 2 deletions linopy/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
from linopy.common import (
LabelPositionIndex,
LocIndexer,
as_dataarray,
assign_multiindex_safe,
check_has_nulls,
check_has_nulls_polars,
ensure_dataarray,
filter_nulls_polars,
format_string_as_variable_name,
generate_indices_for_printout,
Expand Down Expand Up @@ -321,7 +321,7 @@ def to_linexpr(
linopy.LinearExpression
Linear expression with the variables and coefficients.
"""
coefficient = as_dataarray(coefficient, coords=self.coords, dims=self.dims)
coefficient = ensure_dataarray(coefficient, coords=self.coords, dims=self.dims)
coefficient = coefficient.reindex_like(self.labels, fill_value=0)
coefficient = coefficient.fillna(0)
ds = Dataset({"coeffs": coefficient, "vars": self.labels}).expand_dims(
Expand Down
58 changes: 58 additions & 0 deletions test/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,64 @@ def test_as_dataarray_with_unsupported_type() -> None:
as_dataarray(lambda x: 1, dims=["dim1"], coords=[["a"]])


def test_as_dataarray_dataarray_matching_coords() -> None:
da = DataArray([1, 2, 3], dims=["x"], coords={"x": [0, 1, 2]})
result = as_dataarray(da, coords={"x": [0, 1, 2]})
assert_equal(result, da)


def test_as_dataarray_dataarray_mismatched_coords() -> None:
da = DataArray([1, 2, 3], dims=["x"], coords={"x": [10, 11, 12]})
with pytest.raises(ValueError, match="do not match"):
as_dataarray(da, coords={"x": [0, 1, 2]})


def test_as_dataarray_dataarray_extra_dims() -> None:
da = DataArray([[1, 2], [3, 4]], dims=["x", "y"])
with pytest.raises(ValueError, match="not present in coords"):
as_dataarray(da, coords={"x": [0, 1]})


def test_as_dataarray_dataarray_broadcast_missing_dims() -> None:
da = DataArray([1, 2], dims=["x"], coords={"x": [0, 1]})
result = as_dataarray(da, coords={"x": [0, 1], "y": [10, 20, 30]})
assert result.dims == ("y", "x") # expand_dims prepends
assert result.shape == (3, 2)
# Values should be broadcast
assert (result.sel(y=10) == da).all()


def test_as_dataarray_dataarray_sequence_coords() -> None:
time = pd.RangeIndex(3, name="time")
da = DataArray([1, 2, 3], dims=["time"], coords={"time": [0, 1, 2]})
result = as_dataarray(da, coords=[time])
assert_equal(result, da)


def test_as_dataarray_dataarray_sequence_coords_mismatch() -> None:
time = pd.RangeIndex(5, name="time")
da = DataArray([1, 2, 3], dims=["time"], coords={"time": [0, 1, 2]})
with pytest.raises(ValueError, match="do not match"):
as_dataarray(da, coords=[time])


def test_add_variables_dataarray_bounds_validated() -> None:
m = Model()
time = pd.RangeIndex(3, name="time")
lower = DataArray([0, 0, 0], dims=["time"], coords={"time": [10, 11, 12]})
with pytest.raises(ValueError, match="do not match"):
m.add_variables(lower=lower, coords=[time], name="x")


def test_add_variables_dataarray_bounds_broadcast() -> None:
m = Model()
time = pd.RangeIndex(3, name="time")
space = pd.Index(["a", "b"], name="space")
lower = DataArray([0, 0, 0], dims=["time"], coords={"time": [0, 1, 2]})
v = m.add_variables(lower=lower, coords=[time, space], name="x")
assert v.shape == (3, 2)


def test_best_int() -> None:
# Test for int8
assert best_int(127) == np.int8
Expand Down
Loading